Events
Events allow a differential equation solve to be interrupted, by terminating it before `t1` is reached.
diffrax.Event
Can be used to terminate the solve early if a condition, or one of multiple conditions, is triggered. It allows for both boolean and continuous condition functions. In the latter case, a root finder can be used to find the exact time of the event. Boolean and continuous conditions can be used together.
Instances of this class should be passed as the `event` argument of `diffrax.diffeqsolve`.
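A simplified, pure-Python sketch of how a single solver step could be checked against several conditions at once, using the rules described under `cond_fn` below: a boolean condition fires on the step at whose end it is `True`, and a real-valued condition fires when it changes sign over the step. This is not diffrax's implementation; the helper `triggered`, the plain `dict` standing in for a PyTree, and the example conditions `height` and `too_late` are all hypothetical.

```python
def triggered(cond_fns, t0, y0, t1, y1, args=None):
    """Report which conditions fire on the step from (t0, y0) to (t1, y1).

    `cond_fns` is a dict used here as a stand-in for a PyTree of functions.
    """
    fired = {}
    for name, fn in cond_fns.items():
        c1 = fn(t1, y1, args)
        if isinstance(c1, bool):
            # Boolean condition: fires on the step at whose end it is True.
            fired[name] = c1
        else:
            # Real-valued condition: fires when its sign changes over the step.
            c0 = fn(t0, y0, args)
            fired[name] = (c0 > 0 and c1 <= 0) or (c0 < 0 and c1 >= 0)
    return fired


def height(t, y, args, **kwargs):
    x, _ = y
    return x  # real-valued condition


def too_late(t, y, args, **kwargs):
    return t > 5.0  # boolean condition


conds = {"height": height, "too_late": too_late}
# A step on which the height goes from positive to negative before t = 5.
print(triggered(conds, 1.5, (1.0, -12.0), 1.6, (-0.24, -12.8)))
```

Because each condition is evaluated independently, boolean and real-valued conditions can sit in the same collection, in line with the statement above that the two kinds can be used together.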
__init__(self, cond_fn: PyTree[Callable[..., typing.Union[bool, float, int]]], root_finder: Optional[optimistix._root_find.AbstractRootFinder] = None)
Arguments:

- `cond_fn`: A function or PyTree of functions `f(t, y, args, **kwargs) -> c`, each returning either a boolean or a real number. If the return value is a boolean, then the solve will terminate on the first step on which `c` becomes `True`. If the return value is a real number, then the solve will terminate on the step when `c` changes sign.
- `root_finder`: An optional root finder to use for finding the exact time of the event. If the triggered condition function returns a real number, then the final time will be the time at which that real number equals zero. (If the triggered condition function returns a boolean, then the returned time will just be the end of the step on which it becomes `True`.) `optimistix.Newton` would be a typical choice here.
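To illustrate what `root_finder` does once a real-valued condition has fired, here is a hand-written scalar Newton iteration in plain Python. It is a sketch, not `optimistix.Newton`: the central-difference derivative, the stopping rule \(|\Delta t| < \text{atol} + \text{rtol}\,|t|\), and treating the condition as a function of time alone (for example via an interpolation of the solution over the step) are assumptions made here. `math.cos` plays the role of a condition that changes sign on the step \([1.5, 1.6]\).

```python
import math


def newton_event_time(c, t_guess, rtol=1e-5, atol=1e-5, h=1e-6, max_steps=50):
    """Find t with c(t) == 0 by Newton's method, starting from t_guess.

    `c` is assumed to be the condition evaluated along the solution,
    written as a function of time only.
    """
    t = t_guess
    for _ in range(max_steps):
        # Central-difference estimate of dc/dt.
        dc_dt = (c(t + h) - c(t - h)) / (2 * h)
        t_next = t - c(t) / dc_dt
        # Stop once the Newton update is small relative to the tolerances.
        if abs(t_next - t) < atol + rtol * abs(t_next):
            return t_next
        t = t_next
    raise RuntimeError("Newton iteration did not converge")


# A real-valued condition that changes sign on the step [t_a, t_b].
t_a, t_b = 1.5, 1.6
assert math.cos(t_a) > 0 > math.cos(t_b)
t_event = newton_event_time(math.cos, t_guess=t_b)
print(t_event)  # close to pi / 2
```

Starting from the end of the triggered step, the iteration returns a time close to \(\pi/2\), where the condition is zero, rather than the step end 1.6 that a boolean condition would report.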
Example
Consider a bouncing ball dropped from some initial height \(x_0\). We can model the ball by the 2-dimensional ODE
\(\frac{dx_t}{dt} = v_t, \quad \frac{dv_t}{dt} = -g,\)
where \(x_t\) is the height of the ball, \(v_t\) its velocity, and \(g\) the gravitational acceleration. With \(g=8\), this corresponds to the vector field:
```python
import jax.numpy as jnp

def vector_field(t, y, args):
    _, v = y
    return jnp.array([v, -8.0])
```
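Since \(g\) is constant, this ODE can also be integrated by hand, which gives reference values to check the numerical event time against. For a ball dropped from rest (\(v_0 = 0\)):

\[
v_t = -g t, \qquad x_t = x_0 - \tfrac{1}{2} g t^2,
\]

so the event \(x_t = 0\) occurs at

\[
t^* = \sqrt{\frac{2 x_0}{g}}, \qquad v_{t^*} = -g t^* = -\sqrt{2 g x_0}.
\]

With the values used below (\(x_0 = 10\), \(g = 8\)) this gives \(t^* = \sqrt{2.5} \approx 1.581\) and \(v_{t^*} = -\sqrt{160} \approx -12.649\).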
Figuring out exactly when the ball hits the ground amounts to solving the ODE until the event \(x_t=0\) is triggered. This can be done by using the real-valued condition function:
```python
def cond_fn(t, y, args, **kwargs):
    x, _ = y
    return x
```
With \(x_0=10\), this would yield:
```python
import diffrax
import jax.numpy as jnp
import optimistix as optx

y0 = jnp.array([10.0, 0.0])  # initial height x_0 = 10, initial velocity 0
t0 = 0
t1 = jnp.inf  # no fixed end time: the event terminates the solve
dt0 = 0.1
term = diffrax.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)  # rtol, atol, norm
event = diffrax.Event(cond_fn, root_finder)
solver = diffrax.Tsit5()
sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)
print(f"Event time: {sol.ts[0]}") # Event time: 1.58...
print(f"Velocity at event time: {sol.ys[0, 1]}") # Velocity at event time: -12.64...
```
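As an independent, diffrax-free check of these numbers, the same event can be reproduced in plain Python. This is a sketch under stated assumptions, not what diffrax does internally: classical RK4 with the same fixed step `dt0 = 0.1` stands in for `diffrax.Tsit5`, and bisection on the length of the final sub-step stands in for `optx.Newton`. The names `rk4_step` and `solve_until_ground` are hypothetical. The function also returns the end of the step on which the event fired, which in this sketch is the time a boolean condition such as `x <= 0` would report.

```python
import math

G = 8.0  # gravitational acceleration used in the example


def vector_field(t, y):
    x, v = y
    return (v, -G)


def rk4_step(t, y, dt):
    # One classical Runge-Kutta step (a stand-in for diffrax.Tsit5).
    def shift(y, k, s):
        return tuple(yi + s * ki for yi, ki in zip(y, k))

    k1 = vector_field(t, y)
    k2 = vector_field(t + dt / 2, shift(y, k1, dt / 2))
    k3 = vector_field(t + dt / 2, shift(y, k2, dt / 2))
    k4 = vector_field(t + dt, shift(y, k3, dt))
    return tuple(
        yi + dt / 6 * (a + 2 * b + 2 * c + d)
        for yi, a, b, c, d in zip(y, k1, k2, k3, k4)
    )


def solve_until_ground(y0, dt=0.1, tol=1e-10, max_steps=10_000):
    t, y = 0.0, y0
    for _ in range(max_steps):
        y_next = rk4_step(t, y, dt)
        if y_next[0] <= 0.0:  # the height x changed sign on this step
            # Bisect on the sub-step length s in [0, dt] so that x(t + s) = 0.
            lo, hi = 0.0, dt
            while hi - lo > tol:
                mid = 0.5 * (lo + hi)
                if rk4_step(t, y, mid)[0] > 0.0:
                    lo = mid
                else:
                    hi = mid
            s = 0.5 * (lo + hi)
            return t + dt, t + s, rk4_step(t, y, s)
        t, y = t + dt, y_next
    raise RuntimeError("event not triggered within max_steps")


t_step_end, t_event, y_event = solve_until_ground((10.0, 0.0))
print(t_step_end, t_event, y_event[1])
```

Because RK4 integrates this quadratic-in-time trajectory exactly (up to rounding), `t_event` agrees with \(\sqrt{2.5} \approx 1.581\) to within the bisection tolerance, while `t_step_end` is about 1.6: the difference is what the root finder buys.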
diffrax.steady_state_event(rtol: Optional[float] = None, atol: Optional[float] = None, norm: Optional[Callable[[PyTree[Array]], Union[float, int]]] = None)
Create a condition function that terminates the solve once a steady state is achieved. The returned function should be passed as the `cond_fn` argument of `diffrax.Event`.
Arguments:

- `rtol`, `atol`, `norm`: the solve will terminate once `norm(f) < atol + rtol * norm(y)`, where `f` is the result of evaluating the vector field. Will default to the values used in the `stepsize_controller` if they are not specified here.
Returns:

A function `f(t, y, args, **kwargs)` that can be passed to `diffrax.Event(cond_fn=..., ...)`.
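The steady-state criterion can be sketched in plain Python as a factory that returns a boolean condition function. The name `make_steady_state_cond`, its explicit `vector_field` parameter, the RMS norm, and the explicit-Euler loop are all assumptions for illustration; per the description above, the real `diffrax.steady_state_event` evaluates the vector field itself and falls back to the stepsize controller's tolerances when `rtol`, `atol` or `norm` are not given.

```python
import math


def rms_norm(xs):
    # Root-mean-square norm; assumed here in place of the controller's norm.
    return math.sqrt(sum(x * x for x in xs) / len(xs))


def make_steady_state_cond(vector_field, rtol, atol, norm=rms_norm):
    # Hypothetical analogue of diffrax.steady_state_event. Unlike the real
    # function, it takes the vector field explicitly and has no defaults
    # for rtol/atol.
    def cond_fn(t, y, args, **kwargs):
        f = vector_field(t, y, args)
        # Boolean condition: True once norm(f) < atol + rtol * norm(y).
        return norm(f) < atol + rtol * norm(y)

    return cond_fn


def decay(t, y, args):
    # Linear decay dy/dt = -k * y, with the rate k passed as `args`.
    return [-args * yi for yi in y]


def euler_until(cond_fn, vector_field, y0, args, dt=0.01, max_steps=100_000):
    # Fixed-step explicit Euler that stops on the first step whose end
    # state satisfies the boolean condition.
    t, y = 0.0, list(y0)
    for _ in range(max_steps):
        f = vector_field(t, y, args)
        y = [yi + dt * fi for yi, fi in zip(y, f)]
        t += dt
        if cond_fn(t, y, args):
            return t, y
    raise RuntimeError("steady state not reached within max_steps")


cond = make_steady_state_cond(decay, rtol=1e-3, atol=1e-6)
t_ss, y_ss = euler_until(cond, decay, [1.0, 2.0], args=0.5)
print(t_ss, y_ss)
```

For this linear decay \(f = -k y\), so the test reduces to \(\lVert y \rVert < \text{atol} / (k - \text{rtol})\): the solve stops once the state itself has decayed to roughly \(2 \times 10^{-6}\).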