SaveAt
diffrax.SaveAt
Determines what to save as output from the differential equation solve.
Instances of this class should be passed as the saveat argument of diffrax.diffeqsolve.
__init__(self, *, t0: bool = False, t1: bool = False, ts: Union[NoneType, Sequence[Union[float, int]], Real[Array, 'times']] = None, steps: bool = False, fn: Callable = <function save_y>, subs: PyTree[SubSaveAt] = None, dense: bool = False, solver_state: bool = False, controller_state: bool = False, made_jump: bool = False)
Main Arguments:
- t0: If True, save the initial input y0.
- t1: If True, save the output at t1.
- ts: Some array of times at which to save the output.
- steps: If True, save the output at every step of the numerical solver.
- dense: If True, save dense output, which can later be evaluated at any point in the interval \([t_0, t_1]\) via sol = diffeqsolve(...); sol.evaluate(...).
Other Arguments:
These arguments are used less frequently.
- fn: A function fn(t, y, args) which specifies what to save into sol.ys when using t0, t1, ts or steps. Defaults to fn(t, y, args) -> y, so that the evolving solution is saved. For example, this can be useful to save only statistics of your solution, so as to reduce memory usage.
- subs: Some PyTree of diffrax.SubSaveAt, which allows for finer-grained control over what is saved. Each SubSaveAt specifies a combination of a function fn and some times t0, t1, ts, steps at which to evaluate it. sol.ts and sol.ys will then be PyTrees of the same structure as subs, with each leaf of the PyTree saving what the corresponding SubSaveAt specifies. The arguments SaveAt(t0=..., t1=..., ts=..., steps=..., fn=...) are actually just a convenience for passing a single SubSaveAt as SaveAt(subs=SubSaveAt(t0=..., t1=..., ts=..., steps=..., fn=...)). This functionality can be useful when you need different functions of the output saved at different times; see the examples below.
- solver_state: If True, save the internal state of the numerical solver at t1; accessible as sol.solver_state.
- controller_state: If True, save the internal state of the step size controller at t1; accessible as sol.controller_state.
- made_jump: If True, save the internal state of the jump tracker at t1; accessible as sol.made_jump.
Example
When solving a large PDE system, it may be the case that saving the full output y at all timesteps is too memory-intensive. Instead, we may prefer to save only the full final value, and only save statistics of the evolving solution. We can do this by:
import diffrax
import jax.numpy as jnp

t0 = 0
t1 = 100
ts = jnp.linspace(t0, t1, 1000)

def statistics(t, y, args):
    # Save only summary statistics rather than the full state.
    return jnp.mean(y), jnp.std(y)

final_subsaveat = diffrax.SubSaveAt(t1=True)
evolving_subsaveat = diffrax.SubSaveAt(ts=ts, fn=statistics)
saveat = diffrax.SaveAt(subs=[final_subsaveat, evolving_subsaveat])
sol = diffrax.diffeqsolve(..., t0=t0, t1=t1, saveat=saveat)
(y1, evolving_stats) = sol.ys  # PyTree of the same structure as `SaveAt(subs=...)`.
evolving_means, evolving_stds = evolving_stats
As another example, it may be the case that you are solving a 2-dimensional ODE, and want to save each component of its solution at different times. (Perhaps because you are comparing your model against data, and each dimension has data observed at different times.) This can be done through:
y0 = (y0_a, y0_b)
ts_a = ...
ts_b = ...
subsaveat_a = diffrax.SubSaveAt(ts=ts_a, fn=lambda t, y, args: y[0])
subsaveat_b = diffrax.SubSaveAt(ts=ts_b, fn=lambda t, y, args: y[1])
saveat = diffrax.SaveAt(subs=[subsaveat_a, subsaveat_b])
sol = diffrax.diffeqsolve(..., y0=y0, saveat=saveat)
y_a, y_b = sol.ys # PyTree of the same structure as `SaveAt(subs=...)`.
# `sol.ts` will equal `(ts_a, ts_b)`.
diffrax.SubSaveAt
Used for finer-grained control over what is saved. A PyTree of these should be passed to SaveAt(subs=...).
See diffrax.SaveAt for more details on how this is used. (This is a relatively niche feature and most users will probably not need to use SubSaveAt.)
__init__(self, t0: bool = False, t1: bool = False, ts: Union[NoneType, Sequence[Union[float, int]], Real[Array, 'times']] = None, steps: bool = False, fn: Callable = <function save_y>)
Arguments:
- t0: If True, save the initial input y0.
- t1: If True, save the output at t1.
- ts: Some array of times at which to save the output.
- steps: If True, save the output at every step of the numerical solver.
- fn: A function fn(t, y, args) which specifies what to save into sol.ys when using t0, t1, ts or steps. Defaults to fn(t, y, args) -> y, so that the evolving solution is saved. This can be useful to save only statistics of your solution, so as to reduce memory usage.