SaveAt

diffrax.SaveAt

Determines what to save as output from the differential equation solve.

Instances of this class should be passed as the saveat argument of diffrax.diffeqsolve.

__init__(self, *, t0: bool = False, t1: bool = False, ts: Union[NoneType, Sequence[Union[float, int]], Real[Array, 'times']] = None, steps: bool = False, fn: Callable = <function save_y>, subs: PyTree[SubSaveAt] = None, dense: bool = False, solver_state: bool = False, controller_state: bool = False, made_jump: bool = False)

Main Arguments:

  • t0: If True, save the initial input y0.
  • t1: If True, save the output at t1.
  • ts: An array (or sequence) of times at which to save the output.
  • steps: If True, save the output at every step of the numerical solver.
  • dense: If True, save dense output, which can later be evaluated at any point in the interval \([t_0, t_1]\) via sol = diffeqsolve(...); sol.evaluate(...).

Other Arguments:

These arguments are used less frequently.

  • fn: A function fn(t, y, args) which specifies what to save into sol.ys when using t0, t1, ts or steps. Defaults to fn(t, y, args) -> y, so that the evolving solution is saved. Providing a custom fn can be useful, for example, to save only statistics of your solution so as to reduce memory usage.

  • subs: A PyTree of diffrax.SubSaveAt, which allows for finer-grained control over what is saved. Each SubSaveAt specifies a function fn together with the times (t0, t1, ts, steps) at which to evaluate it. sol.ts and sol.ys will then be PyTrees with the same structure as subs, with each leaf saving what the corresponding SubSaveAt specifies. The arguments SaveAt(t0=..., t1=..., ts=..., steps=..., fn=...) are just a convenience for passing a single SubSaveAt as SaveAt(subs=SubSaveAt(t0=..., t1=..., ts=..., steps=..., fn=...)). This is useful when you need different functions of the output saved at different times; see the examples below.

  • solver_state: If True, save the internal state of the numerical solver at t1; accessible as sol.solver_state.

  • controller_state: If True, save the internal state of the step size controller at t1; accessible as sol.controller_state.

  • made_jump: If True, save the internal state of the jump tracker at t1; accessible as sol.made_jump.

Example

When solving a large PDE system, it may be the case that saving the full output y at all timesteps is too memory-intensive. Instead, we may prefer to save the full solution only at the final time, and to save only statistics of the solution as it evolves. We can do this as follows:

```python
import diffrax
import jax.numpy as jnp

t0 = 0
t1 = 100
ts = jnp.linspace(t0, t1, 1000)

def statistics(t, y, args):
    return jnp.mean(y), jnp.std(y)

final_subsaveat = diffrax.SubSaveAt(t1=True)  # full state, at t1 only
evolving_subsaveat = diffrax.SubSaveAt(ts=ts, fn=statistics)  # statistics, at every time in ts
saveat = diffrax.SaveAt(subs=[final_subsaveat, evolving_subsaveat])

sol = diffrax.diffeqsolve(..., t0=t0, t1=t1, saveat=saveat)
(y1, evolving_stats) = sol.ys  # PyTree of the same structure as `SaveAt(subs=...)`.
evolving_means, evolving_stds = evolving_stats
```

As another example, suppose you are solving a 2-dimensional ODE and want to save each component of its solution at different times (perhaps because you are comparing your model against data, and each dimension has data observed at different times). This can be done as follows:

```python
import diffrax

y0 = (y0_a, y0_b)
ts_a = ...
ts_b = ...
subsaveat_a = diffrax.SubSaveAt(ts=ts_a, fn=lambda t, y, args: y[0])  # first component only
subsaveat_b = diffrax.SubSaveAt(ts=ts_b, fn=lambda t, y, args: y[1])  # second component only
saveat = diffrax.SaveAt(subs=[subsaveat_a, subsaveat_b])
sol = diffrax.diffeqsolve(..., y0=y0, saveat=saveat)
y_a, y_b = sol.ys  # PyTree of the same structure as `SaveAt(subs=...)`.
# `sol.ts` will equal `(ts_a, ts_b)`.
```

diffrax.SubSaveAt

Used for finer-grained control over what is saved. A PyTree of these should be passed to SaveAt(subs=...).

See diffrax.SaveAt for more details on how this is used. (This is a relatively niche feature and most users will probably not need to use SubSaveAt.)

__init__(self, t0: bool = False, t1: bool = False, ts: Union[NoneType, Sequence[Union[float, int]], Real[Array, 'times']] = None, steps: bool = False, fn: Callable = <function save_y>)

Arguments:

  • t0: If True, save the initial input y0.
  • t1: If True, save the output at t1.
  • ts: An array (or sequence) of times at which to save the output.
  • steps: If True, save the output at every step of the numerical solver.
  • fn: A function fn(t, y, args) which specifies what to save into sol.ys when using t0, t1, ts or steps. Defaults to fn(t, y, args) -> y, so that the evolving solution is saved. Providing a custom fn can be useful, for example, to save only statistics of your solution so as to reduce memory usage.