Stochastic Runge-Kutta (SRK) demonstration¤
The `AbstractSRK` class takes a `StochasticButcherTableau` and implements the corresponding SRK method.
Depending on the tableau, the resulting method can either be used for general SDEs, or just for ones with additive noise.
The additive-noise-only methods are somewhat faster, but will fail if the noise is not additive.
Nevertheless, even in the additive noise case, the diffusion vector field can depend on time (just not on the state).
To account for time-dependent noise, the SRK adds a term to the output of each step, which allows it to still maintain its usual strong order of convergence.
The SRK is capable of utilising various types of time Lévy area, depending on the tableau provided. It can use:
- just the Brownian motion $W$, without any Lévy area,
- $W$ and the space-time Lévy area $H$,
- $W$, $H$ and the space-time-time Lévy area $K$.
For more information see the documentation of the StochasticButcherTableau
class.
First we will demonstrate an additive-noise-only SRK method, the ShARK method, on an SDE with additive, time-dependent noise.
We will compare various additive-noise-only SRK methods as well as some general SRK methods proposed by Foster.
%env JAX_PLATFORM_NAME=cuda
from test.helpers import (
get_mlp_sde,
get_time_sde,
simple_sde_order,
)
from warnings import simplefilter
import diffrax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from diffrax import (
diffeqsolve,
GeneralShARK,
ShARK,
SlowRK,
SpaceTimeLevyArea,
SPaRK,
SRA1,
)
from jax import config
# Silence FutureWarning messages so they do not clutter the notebook output.
simplefilter("ignore", category=FutureWarning)
# Enable 64-bit floating point in JAX; the SDEs below are built with float64.
config.update("jax_enable_x64", True)
# Print arrays with 4 digits of precision and without scientific notation.
jnp.set_printoptions(precision=4, suppress=True)
# Plotting
def draw_order(results):
    """Plot one convergence curve on log-log axes and print its order.

    Args:
        results: a triple that unpacks as `(steps, errs, order)`: the
            x-values (average number of steps), the matching RMS errors and
            the estimated order of convergence.
    """
    step_counts, rms_errors, conv_order = results

    plt.plot(step_counts, rms_errors)
    plt.xscale("log")
    plt.yscale("log")

    # Label the x-axis with the actual step counts, shown as integers.
    tick_values = list(map(int, step_counts))
    plt.xticks(ticks=tick_values, labels=tick_values)

    plt.xlabel("average number of steps")
    plt.ylabel("RMS error")
    plt.show()

    print(f"Order of convergence: {conv_order:.4f}")
def plot_sol_general(sol):
    """Plot the saved values `sol.ys` against the save times `sol.ts`."""
    times, values = sol.ts, sol.ys
    plt.plot(times, values)
    plt.show()
def draw_order_multiple(results_list, names_list, title=None):
    """Overlay several convergence curves on one log-log plot.

    Args:
        results_list: iterable of `(steps, errs, order)` triples, one per
            method.
        names_list: legend labels, paired element-wise with `results_list`.
        title: optional figure title; omitted when `None`.
    """
    plt.figure(dpi=200)
    if title is not None:
        plt.title(title)

    summary_lines = ["Orders of convergence:"]
    for (step_counts, rms_errors, conv_order), label in zip(
        results_list, names_list
    ):
        plt.plot(step_counts, rms_errors, label=label)
        summary_lines.append(f"{label}: {conv_order:.4f}")
    # Every line, including the last one, is terminated by a newline.
    orders_text = "\n".join(summary_lines) + "\n"

    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("average number of steps")
    plt.ylabel("RMS error")
    plt.legend()

    # Write the orders in the bottom-left corner of the axes; `transAxes`
    # makes (0.05, 0.05) a fraction of the axes rather than a data coordinate.
    axes = plt.gca()
    plt.text(
        0.05,
        0.05,
        orders_text,
        transform=axes.transAxes,
        verticalalignment="bottom",
        fontsize=10,
    )
    plt.show()
# Floating-point precision used when constructing every example SDE.
dtype = jnp.float64

# Time horizons: the example trajectories are solved on [t0, t1], while the
# SDEs used for the convergence studies end at `t_short`.
t0 = 0.0
t1 = 16.0
t_short = 4.0
t_long = 32.0

# Random keys: `key` seeds the Brownian motion behind the example
# trajectories, `sde_key` is handed to the SDE constructors, and `keys` holds
# `num_samples` keys passed to the convergence-order helpers.
key = jr.PRNGKey(2)
sde_key = jr.PRNGKey(11)
num_samples = 1000
keys = jr.split(jr.PRNGKey(5678), num_samples)

# Save the solution after every solver step.
save_at_solver_steps = diffrax.SaveAt(steps=True)

# Lévy area type that the strong-order helpers forward to `simple_sde_order`.
levy_area = SpaceTimeLevyArea
def constant_step_strong_order(keys, sde, solver, levels, bm_tol=None):
    """Run `simple_sde_order` for `solver` on `sde` with constant step sizes.

    At a given level the solver is stepped on a uniform grid of `2**level`
    intervals spanning `[sde.t0, sde.t1]`. The solution is saved on the grid
    belonging to `levels[0]`.

    Args:
        keys: random keys forwarded to `simple_sde_order`.
        sde: problem description exposing `t0` and `t1`.
        solver: solver under test. It is passed to `simple_sde_order` twice
            (presumably as both the tested and the reference solver).
        levels: pair of levels; `levels[0]` fixes the save grid and
            `levels[1]` the default `bm_tol`.
        bm_tol: tolerance forwarded to `simple_sde_order`. When `None` it
            defaults to `(sde.t1 - sde.t0) * 2**-(levels[1] + 5)`.

    Returns:
        Whatever `simple_sde_order` returns; elsewhere in this file the result
        is unpacked as `(steps, errs, order)`.
    """

    def _uniform_grid(level):
        # 2**level equal intervals need 2**level + 1 grid points.
        return jnp.linspace(sde.t0, sde.t1, 2**level + 1, endpoint=True)

    def get_controller(level):
        # The pair is `(None, controller)`; the first entry is presumably the
        # initial step size, which `StepTo` does not need.
        step_to = diffrax.StepTo(ts=_uniform_grid(level))
        return None, step_to

    if bm_tol is None:
        time_span = sde.t1 - sde.t0
        bm_tol = time_span * (2 ** -(levels[1] + 5))

    coarse_saveat = diffrax.SaveAt(ts=_uniform_grid(levels[0]))

    return simple_sde_order(
        keys,
        sde,
        solver,
        solver,
        levels,
        get_controller,
        coarse_saveat,
        bm_tol,
        levy_area,
        ref_solution=None,
    )
def pid_strong_order(keys, sde, solver, levels, bm_tol=2**-18):
    """Run `simple_sde_order` for `solver` on `sde` with adaptive PID stepping.

    The solution is saved at 65 equally spaced times on `[sde.t0, sde.t1]`,
    and the controller is told to step exactly to those times (`step_ts`).
    Both tolerances halve with each level: `rtol = 2**-(level - 1)` and
    `atol = 2**-(level + 3)`.

    Args:
        keys: random keys forwarded to `simple_sde_order`.
        sde: problem description exposing `t0` and `t1`.
        solver: solver under test. It is passed to `simple_sde_order` twice
            (presumably as both the tested and the reference solver).
        levels: levels forwarded to `simple_sde_order`; each level is mapped
            to a controller by `get_pid`.
        bm_tol: tolerance forwarded to `simple_sde_order`.

    Returns:
        Whatever `simple_sde_order` returns; elsewhere in this file the result
        is unpacked as `(steps, errs, order)`.
    """
    save_ts_pid = jnp.linspace(sde.t0, sde.t1, 65, endpoint=True)

    def get_pid(level):
        controller = diffrax.PIDController(
            pcoeff=0.1,
            icoeff=0.3,
            rtol=2 ** -(level - 1),
            atol=2 ** -(level + 3),
            dtmin=2**-14,
            step_ts=save_ts_pid,
        )
        # Return the same `(None, controller)` pair shape that
        # `constant_step_strong_order` passes in this argument slot of
        # `simple_sde_order`. The original returned the bare controller,
        # which is inconsistent with that working call path.
        # NOTE(review): confirm against `simple_sde_order` that the first
        # entry is the initial step size and that `None` is accepted here.
        return None, controller

    saveat_pid = diffrax.SaveAt(ts=save_ts_pid)
    return simple_sde_order(
        keys,
        sde,
        solver,
        solver,
        levels,
        get_pid,
        saveat_pid,
        bm_tol,
        levy_area,
        ref_solution=None,
    )
def _example_terms(sde):
    """Return the terms of `sde` driven by the object from `sde.get_bm(key, ...)`."""
    brownian = sde.get_bm(key, levy_area=SpaceTimeLevyArea, tol=2**-10)
    return sde.get_terms(brownian)


# SDE with additive, time-dependent noise (noise_dim=7). The `_short` variant
# ends at `t_short` and is the one used for the convergence studies.
time_sde = get_time_sde(t0, t1, dtype=dtype, noise_dim=7, key=sde_key)
terms_time_sde = _example_terms(time_sde)
time_sde_short = get_time_sde(t0, t_short, dtype=dtype, noise_dim=7, key=sde_key)

# General SDE built by `get_mlp_sde` with noise_dim=3.
mlp_sde = get_mlp_sde(t0, t1, dtype=dtype, key=sde_key, noise_dim=3)
terms_mlp_sde = _example_terms(mlp_sde)
mlp_sde_short = get_mlp_sde(t0, t_short, dtype=dtype, key=sde_key, noise_dim=3)

# The same constructor with noise_dim=1, used as the commutative-noise example.
commutative_sde = get_mlp_sde(t0, t1, dtype=dtype, key=sde_key, noise_dim=1)
terms_commutative_sde = _example_terms(commutative_sde)
commutative_sde_short = get_mlp_sde(
    t0, t_short, dtype=dtype, key=sde_key, noise_dim=1
)
# Example trajectory of the general SDE on which the solvers are compared,
# computed with GeneralShARK and saved after every solver step.
sol_general = diffeqsolve(
    terms_mlp_sde,
    GeneralShARK(),
    t0=t0,
    t1=t1,
    dt0=0.02,
    y0=mlp_sde.y0,
    args=mlp_sde.args,
    saveat=save_at_solver_steps,
)
plot_sol_general(sol_general)
ShARK¤
`ShARK` is an SRK method for additive-noise SDEs. It uses two vector-field evaluations per step and has strong order 1.5, but applied to an Underdamped Langevin SDE it has order 2.
While it has the same order as `SRA1`, it has a better proportionality constant.
Based on equation (6.1) in Foster, J., dos Reis, G., & Strange, C. (2023). High order splitting methods for SDEs satisfying a commutativity condition. arXiv [Math.NA] http://arxiv.org/abs/2210.17543
General ShARK¤
`GeneralShARK` is a generalisation of the ShARK method which works for any SDE, not only those with additive noise. It uses three evaluations of the vector field per step and has the following strong orders of convergence:
- 2 for the Underdamped Langevin SDEs
- 1.5 for additive noise SDEs
- 1 for commutative noise SDEs
- 0.5 for general SDEs.
SRA1¤
`SRA1` is another method for additive-noise SDEs.
It normally has strong order 1.5, but when applied to an Underdamped Langevin SDE it has order 2. It natively supports adaptive stepping via an embedded method for error estimation. It uses two evaluations of the vector field per step.
Based on the SRA1 method from A. Rößler, Runge–Kutta methods for the strong approximation of solutions of stochastic differential equations, SIAM Journal on Numerical Analysis, 48 (2010), pp. 922–952.
Shifted Additive-noise Euler (SEA)¤
This variant of the Euler-Maruyama method makes use of the space-time Lévy area, which improves its local error to $O(h^2)$, compared to $O(h^{1.5})$ for standard Euler-Maruyama.
### The "Splitting Path Runge-Kutta" (SPaRK) method

This is a general Stochastic Runge-Kutta method with 3 evaluations of the vector field per step, based on Definition 1.6 from Foster, J. (2023). On the convergence of adaptive approximations for stochastic differential equations. arXiv [Math.NA]. Retrieved from http://arxiv.org/abs/2311.14201
For general SDEs, this has order 0.5. When the noise is commutative it has order 1. When the noise is additive it has order 1.5. For the Underdamped Langevin SDE it has order 2. Requires the space-time Lévy area H. It also natively supports adaptive time-stepping.
SLOW-RK¤
This is a general Stochastic Runge-Kutta method with 7 stages (2 evaluations of the drift vector field and 5 evaluations of the diffusion vector field) per step. Remarkably, it has order 1.5 for commutative noise SDEs and order 0.5 for general SDEs. Devised by James Foster.
# Strong-order comparison on the general SDE: SlowRK, SPaRK and GeneralShARK
# are run with the same levels and their error curves share one graph.
out_SLOWRK_mlp_sde, out_SPaRK_mlp_sde, out_GenShARK_mlp_sde = (
    constant_step_strong_order(keys, mlp_sde_short, solver, levels=(4, 10))
    for solver in (SlowRK(), SPaRK(), GeneralShARK())
)
draw_order_multiple(
    results_list=[out_SLOWRK_mlp_sde, out_SPaRK_mlp_sde, out_GenShARK_mlp_sde],
    names_list=["SlowRK", "SPaRK", "GeneralShARK"],
    title="Order of convergence on a general SDE",
)
Commutative noise SDEs¤
# Example trajectory of the commutative-noise SDE on which the solvers are
# compared, computed with GeneralShARK and saved after every solver step.
sol_commutative = diffeqsolve(
    terms_commutative_sde,
    GeneralShARK(),
    t0=t0,
    t1=t1,
    dt0=0.02,
    y0=commutative_sde.y0,
    args=commutative_sde.args,
    saveat=save_at_solver_steps,
)
plot_sol_general(sol_commutative)
# Strong-order comparison on the commutative-noise SDE: SlowRK, SPaRK and
# GeneralShARK are run with the same levels and plotted on one graph.
# The SPaRK result was previously bound to the misspelled name
# `out_SPaRK_commutive_sde`; it now follows the `*_commutative_sde` pattern.
(
    out_SLOWRK_commutative_sde,
    out_SPaRK_commutative_sde,
    out_GenShARK_commutative_sde,
) = (
    constant_step_strong_order(keys, commutative_sde_short, solver, levels=(4, 10))
    for solver in (SlowRK(), SPaRK(), GeneralShARK())
)
draw_order_multiple(
    [
        out_SLOWRK_commutative_sde,
        out_SPaRK_commutative_sde,
        out_GenShARK_commutative_sde,
    ],
    ["SlowRK", "SPaRK", "GeneralShARK"],
    title="Order of convergence on a commutative noise SDE",
)
Additive noise SDEs¤
# Example trajectory of the additive-noise SDE on which the solvers are
# compared, computed with ShARK and saved after every solver step.
sol_additive = diffeqsolve(
    terms_time_sde,
    ShARK(),
    t0=t0,
    t1=t1,
    dt0=0.02,
    y0=time_sde.y0,
    args=time_sde.args,
    saveat=save_at_solver_steps,
)
plot_sol_general(sol_additive)
# Strong-order comparison on the additive-noise SDE: the three general
# solvers plus the additive-noise-only ShARK and SRA1 are run with the same
# levels and their error curves share one graph.
(
    out_SLOWRK_time_sde,
    out_SPaRK_time_sde,
    out_GenShARK_time_sde,
    out_ShARK_time_sde,
    out_SRA1_time_sde,
) = (
    constant_step_strong_order(keys, time_sde_short, solver, levels=(7, 12))
    for solver in (SlowRK(), SPaRK(), GeneralShARK(), ShARK(), SRA1())
)
draw_order_multiple(
    results_list=[
        out_SLOWRK_time_sde,
        out_SPaRK_time_sde,
        out_GenShARK_time_sde,
        out_ShARK_time_sde,
        out_SRA1_time_sde,
    ],
    names_list=["SlowRK", "SPaRK", "GeneralShARK", "ShARK", "SRA1"],
    title="Order of convergence on an additive noise SDE",
)