# Event handling and callbacks in ODE solvers

The differential equation solvers in ProbNum are able to handle events. At the moment, an event can either be a set of grid-points that must be included in the posterior, or a state for which a condition function

$$\text{condition}: \mathbb{R}^d \rightarrow \{0, 1\}$$

evaluates to True. This notebook explains how both kinds of events can be used with ProbNum (some examples are taken from https://diffeq.sciml.ai/stable/features/callback_functions/).

## Quickstart

What is the easiest way to force events into your ODE solution? Let us define a simple, linear ODE that describes exponential decay.

[1]:

# Make inline plots vector graphics instead of raster graphics
%matplotlib inline
from IPython.display import set_matplotlib_formats

set_matplotlib_formats("pdf", "svg")

# Plotting
import matplotlib.pyplot as plt

plt.style.use("../../probnum.mplstyle")

[2]:

from probnum import diffeq, randvars, randprocs, problems
import numpy as np

# For easy modification of the states in the callbacks
import dataclasses

[3]:

def f(t, y):
    # Right-hand side of the ODE: exponential decay
    return -y

def df(t, y):
    # Jacobian of f with respect to y
    return -1.0 * np.eye(len(y))

t0 = 0.0
tmax = 5.0
y0 = np.array([4])
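
For reference, this initial value problem has a closed-form solution, which is handy for sanity-checking the plots in this notebook. The following is a standard result for linear scalar ODEs, written with the values of y0 and tmax defined above:

$$\dot y(t) = -y(t), \quad y(0) = 4 \quad \Longrightarrow \quad y(t) = 4 e^{-t}, \qquad y(5) = 4 e^{-5} \approx 0.027.$$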


To show off the ability to include a set number of grid-points, let us define a dense grid in a subset of the integration domain.

[4]:

time_stops = np.linspace(3.5, 4.0, 50)


To force the ODE solver to include these time-stamps, just pass them to probsolve_ivp. We pick a large relative tolerance so that the posterior samples visibly differ from each other; the ODE is so simple that it is solved very accurately even with large steps.

[5]:

probsol = diffeq.probsolve_ivp(
    f,
    t0,
    tmax,
    y0,
    time_stops=time_stops,
    rtol=0.8,
)

# Draw 10 samples from the posterior and plot.
rng = np.random.default_rng(seed=2)
samples = probsol.sample(size=10, rng=rng)
for sample in samples:
    plt.plot(probsol.locations, sample, "o-", color="C0")
plt.show()


Observe how there is a dense gathering of grid-points between 3.5 and 4.0. These are our events!

The same works for e.g. perturbsolve_ivp. Let us compute 10 perturbed solutions, so the plots look similar to the samples from the posterior of the probabilistic solver.

[6]:

# every solve is random
rng = np.random.default_rng()

time_stops = np.linspace(3.5, 4.0, 100)
perturbsols = [
    diffeq.perturbsolve_ivp(
        f=f,
        t0=t0,
        tmax=tmax,
        y0=y0,
        rng=rng,
        noise_scale=0.05,
        time_stops=time_stops,
    )
    for _ in range(10)
]

for perturbsol in perturbsols:
    plt.plot(perturbsol.locations, perturbsol.states.mean, "o-", color="C1")
plt.show()


Again, observe how there are many locations between 3.5 and 4.0.
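
To build some intuition for what forcing time stops into a solution means, here is a minimal, library-free sketch of an adaptive stepper that shortens a step whenever it would jump over a pending stop. It is not ProbNum's implementation: the function adaptive_grid_with_stops and the step-size rule propose_step are hypothetical names used only for illustration.

```python
def adaptive_grid_with_stops(t0, tmax, propose_step, time_stops):
    """Return the time grid of a toy adaptive stepper that hits every time stop.

    `propose_step(t)` is a hypothetical stand-in for a solver's step-size rule
    and is assumed to return a strictly positive step.
    """
    # Only stops strictly inside the integration domain matter; drop duplicates.
    stops = sorted({s for s in time_stops if t0 < s < tmax})
    grid = [t0]
    t = t0
    i = 0  # index of the next pending stop
    while t < tmax:
        t_next = min(t + propose_step(t), tmax)
        # Shorten the step if it would reach or jump over the next pending stop.
        if i < len(stops) and t_next >= stops[i]:
            t_next = stops[i]
            i += 1
        grid.append(t_next)
        t = t_next
    return grid


# A constant step of 0.8 would normally skip 3.5, 3.75 and 4.0.
grid = adaptive_grid_with_stops(0.0, 5.0, lambda t: 0.8, [3.5, 3.75, 4.0])
print(grid)
```

The resulting grid contains each requested stop exactly, which is why the plots above show a cluster of locations between 3.5 and 4.0 even though the tolerance would allow much larger steps there.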

## Discrete callback events

It is also possible to modify the solver states whenever an event happens. This is not possible via the top-level interface functions (e.g. probsolve_ivp); instead, we have to build an ODE solver from scratch (see the respective notebook for an explanation thereof).

[7]:

# Construct IVP, prior, linearization, diffusion, and initialization
ivp = problems.InitialValueProblem(t0=t0, tmax=tmax, y0=y0, f=f, df=df)
prior_process = randprocs.markov.integrator.IntegratedWienerProcess(
    initarg=ivp.t0,
    num_derivatives=1,
    wiener_process_dimension=ivp.dimension,
    forward_implementation="sqrt",
    backward_implementation="sqrt",
)
diffmodel = randprocs.markov.continuous.PiecewiseConstantDiffusion(t0=t0)
rk_init = diffeq.odefilter.initialization_routines.RungeKuttaInitialization()
ode_residual = diffeq.odefilter.information_operators.ODEResidual(1, ivp.dimension)
ek1 = diffeq.odefilter.approx_strategies.EK1()
firststep = diffeq.stepsize.propose_firststep(ivp)
steprule = diffeq.stepsize.AdaptiveSteps(firststep=firststep, atol=1e-1, rtol=1e-1)

solver = diffeq.odefilter.ODEFilter(
    steprule,
    prior_process=prior_process,
    information_operator=ode_residual,
    approx_strategy=ek1,
    initialization_routine=rk_init,
    diffusion_model=diffmodel,
    with_smoothing=False,
)


To describe a discrete event, we define a condition function that checks whether the current time-point is either 2.0 or 4.0. At both locations, we reset the current state to $$y=6$$. Careful: the state of a filter-based solver consists of $$[y, \dot y, \ddot y, ...]$$, so the replacement has to provide every component. With num_derivatives=1 the state is $$[y, \dot y]$$, and since $$\dot y = -y$$, the new mean is $$[6, -6]$$.

Let us construct both functions and pass them to a DiscreteCallback. Since the solver is unlikely to stop at exactly 2.0 or 4.0 on its own, let us force these locations into the ODE solver posterior via stop_at.

[8]:

def condition(state: diffeq.ODESolverState) -> bool:
    return state.t in [2.0, 4.0]

def replace(state: diffeq.ODESolverState) -> diffeq.ODESolverState:
    """Replace an ODE solver state whenever a condition is True."""
    new_mean = np.array([6.0, -6])
    new_rv = randvars.Normal(
        new_mean, cov=0 * state.rv.cov, cov_cholesky=0 * state.rv.cov_cholesky
    )
    return dataclasses.replace(state, rv=new_rv)

callback = diffeq.callbacks.DiscreteCallback(condition=condition, replace=replace)

odesol = solver.solve(ivp=ivp, stop_at=[2.0, 4.0], callbacks=callback)

[9]:

plt.plot(odesol.locations, odesol.states.mean, "o-")
plt.show()
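
The condition/replace pattern used above can also be sketched without ProbNum, which may help to see what a discrete callback does at each step. The example below is a rough illustration under stated assumptions, not the library's mechanism: ToyState and apply_callback are hypothetical names, the integrator is plain explicit Euler, and the condition compares times with a tolerance (via math.isclose) instead of exact membership, which is a deliberate variation on the notebook's condition function.

```python
import dataclasses
import math


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a solver state: time, value, and derivative."""

    t: float
    y: float
    dy: float


def condition(state):
    # Tolerance-based check, a variation on `state.t in [2.0, 4.0]`.
    return any(math.isclose(state.t, s, abs_tol=1e-12) for s in (2.0, 4.0))


def replace(state):
    # Reset y to 6 and keep the derivative consistent with dy/dt = -y.
    return dataclasses.replace(state, y=6.0, dy=-6.0)


def apply_callback(state):
    # A discrete callback replaces the state only where the condition holds.
    return replace(state) if condition(state) else state


# Explicit Euler on dy/dt = -y; the grid contains the event locations 2.0 and 4.0.
grid = [0.5 * k for k in range(11)]  # 0.0, 0.5, ..., 5.0
state = ToyState(t=0.0, y=4.0, dy=-4.0)
trajectory = [state]
for t_next in grid[1:]:
    h = t_next - state.t
    y_next = state.y + h * state.dy
    state = apply_callback(ToyState(t=t_next, y=y_next, dy=-y_next))
    trajectory.append(state)

for s in trajectory:
    print(s.t, s.y)
```

As in the ProbNum example, the events only fire because the grid contains 2.0 and 4.0; if the stepper never landed on those times, the condition would never evaluate to True and the state would never be reset.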