TaylorModeInitialization

class probnum.diffeq.odefiltsmooth.initialization_routines.TaylorModeInitialization

Bases: probnum.diffeq.odefiltsmooth.initialization_routines.InitializationRoutine

Initialize a probabilistic ODE solver with Taylor-mode automatic differentiation.

This routine requires JAX. For an explanation of what happens under the hood, see [1].

The implementation is inspired by https://github.com/jacobjinkelly/easy-neural-ode/blob/master/latent_ode.py. See also [2].
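The core idea of Taylor-mode initialization can be illustrated without JAX. Represent the solution by its truncated Taylor series at t0 (coefficients c_k = y^(k)(t0) / k!), evaluate the vector field with truncated power-series arithmetic, and use y' = f(y) to obtain each next coefficient as c_{k+1} = [f(y)]_k / (k + 1). The sketch below is a minimal, plain-Python version of that recurrence, not probnum's JAX-based implementation. The helper names (series_mul, vanderpol_field, taylor_mode_init) are hypothetical, and the van der Pol field with stiffness parameter mu = 10 is an assumption chosen to match the van der Pol example in this docstring.

```python
from math import factorial


def series_mul(a, b):
    """Cauchy product of two truncated Taylor series of equal length."""
    return [sum(a[i] * b[k - i] for i in range(k + 1)) for k in range(len(a))]


def vanderpol_field(y1, y2, mu):
    """Van der Pol vector field (assumed form), evaluated on truncated Taylor series."""
    one = [1.0] + [0.0] * (len(y1) - 1)
    damping = [o - s for o, s in zip(one, series_mul(y1, y1))]  # series of 1 - y1^2
    f2 = [mu * d - x for d, x in zip(series_mul(damping, y2), y1)]
    return list(y2), f2


def taylor_mode_init(y0, order, mu=10.0):
    """Return [y1, y1', ..., y2, y2', ...] at t0, up to `order` derivatives."""
    # Taylor coefficients c_k = y^(k)(t0) / k!, one list per coordinate.
    c1, c2 = [float(y0[0])], [float(y0[1])]
    for k in range(order):
        # The k-th coefficient of f(y(t)) depends only on c_0, ..., c_k,
        # and y' = f(y) implies c_{k+1} = [f(y)]_k / (k + 1).
        f1, f2 = vanderpol_field(c1, c2, mu)
        c1.append(f1[k] / (k + 1))
        c2.append(f2[k] / (k + 1))
    # Convert Taylor coefficients back to derivatives: y^(k) = k! * c_k.
    derivs1 = [c * factorial(k) for k, c in enumerate(c1)]
    derivs2 = [c * factorial(k) for k, c in enumerate(c2)]
    return derivs1 + derivs2


mean = taylor_mode_init(y0=[2.0, 0.0], order=3)
print([round(m, 8) for m in mean])  # [2.0, 0.0, -2.0, 60.0, 0.0, -2.0, 60.0, -1798.0]
```

The result agrees with the printed mean of the van der Pol example. For clarity, this sketch recomputes the whole series at every step; JAX's experimental `jet` module automates the same kind of Taylor-series propagation for general JAX-traceable functions.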

References

[1] Krämer, N. and Hennig, P., Stable implementation of probabilistic ODE solvers, arXiv:2012.10106, 2020.

[2] Kelly, J., Bettencourt, J., Johnson, M., and Duvenaud, D., Learning differential equations that are easy to solve, NeurIPS 2020.

Examples

>>> import sys, pytest
>>> if not sys.platform.startswith('linux'):
...     pytest.skip()
>>> import numpy as np
>>> from probnum.randvars import Normal
>>> from probnum.problems.zoo.diffeq import threebody_jax, vanderpol_jax
>>> from probnum.statespace import IBM
>>> from probnum.randprocs import MarkovProcess

Compute the initial values of the restricted three-body problem as follows. First, set up the IVP.

>>> ivp = threebody_jax()
>>> print(ivp.y0)
[ 0.994       0.          0.         -2.00158511]

Construct the prior process.

>>> prior = IBM(ordint=3, spatialdim=4)
>>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
>>> prior_process = MarkovProcess(transition=prior, initrv=initrv, initarg=ivp.t0)

Initialize with Taylor-mode autodiff.

>>> taylor_init = TaylorModeInitialization()
>>> improved_initrv = taylor_init(ivp=ivp, prior_process=prior_process)

Print the results.

>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[ 0.994       0.          0.         -2.00158511]
>>> print(improved_initrv.mean)
[ 9.94000000e-01  0.00000000e+00 -3.15543023e+02  0.00000000e+00
  0.00000000e+00 -2.00158511e+00  0.00000000e+00  9.99720945e+04
  0.00000000e+00 -3.15543023e+02  0.00000000e+00  6.39028111e+07
 -2.00158511e+00  0.00000000e+00  9.99720945e+04  0.00000000e+00]
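The call proj2coord(0) recovers y0 from the improved mean. Judging from the printed mean, the state stacks the derivatives coordinate by coordinate (y1, y1', y1'', y1''', y2, y2', ...), so proj2coord(0) behaves like a 0/1 selection matrix that picks every (ordint + 1)-th entry. The sketch below assumes that ordering; proj2coord_matrix and matvec are hypothetical pure-Python helpers, not part of probnum's API.

```python
def proj2coord_matrix(ordint, spatialdim, coord):
    """Hypothetical helper: select derivative `coord` of each spatial coordinate,
    assuming a coordinate-major state of size (ordint + 1) * spatialdim."""
    dim = (ordint + 1) * spatialdim
    rows = []
    for i in range(spatialdim):
        row = [0.0] * dim
        row[i * (ordint + 1) + coord] = 1.0
        rows.append(row)
    return rows


def matvec(matrix, vector):
    """Plain-Python matrix-vector product."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


# Improved mean of the three-body example, copied from the printed output.
mean_3body = [9.94e-01, 0.0, -3.15543023e+02, 0.0,
              0.0, -2.00158511e+00, 0.0, 9.99720945e+04,
              0.0, -3.15543023e+02, 0.0, 6.39028111e+07,
              -2.00158511e+00, 0.0, 9.99720945e+04, 0.0]

P0 = proj2coord_matrix(ordint=3, spatialdim=4, coord=0)
print(matvec(P0, mean_3body))  # [0.994, 0.0, 0.0, -2.00158511]
```

With coord=1, the same helper picks the first derivatives, [0.0, -2.00158511, -315.543023, 0.0]. Their first two entries equal the last two entries of y0, as expected if the state of the three-body problem consists of positions followed by velocities.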

Compute the initial values of the van der Pol oscillator as follows. First, set up the IVP and prior process, then initialize with Taylor-mode autodiff.

>>> ivp = vanderpol_jax()
>>> print(ivp.y0)
[2. 0.]
>>> prior = IBM(ordint=3, spatialdim=2)
>>> initrv = Normal(mean=np.zeros(prior.dimension), cov=np.eye(prior.dimension))
>>> prior_process = MarkovProcess(transition=prior, initrv=initrv, initarg=ivp.t0)
>>> taylor_init = TaylorModeInitialization()
>>> improved_initrv = taylor_init(ivp=ivp, prior_process=prior_process)

Print the results.

>>> print(prior_process.transition.proj2coord(0) @ improved_initrv.mean)
[2. 0.]
>>> print(improved_initrv.mean)
[    2.     0.    -2.    60.     0.    -2.    60. -1798.]
>>> print(improved_initrv.std)
[0. 0. 0. 0. 0. 0. 0. 0.]
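The printed van der Pol mean can be checked by hand. Assuming vanderpol_jax uses $y_1' = y_2$, $y_2' = \mu (1 - y_1^2) y_2 - y_1$ with $\mu = 10$ (the value consistent with the printed mean), and assuming the ordering $(y_1, y_1', y_1'', y_1^{(3)}, y_2, y_2', y_2'', y_2^{(3)})$, repeated differentiation at $y(t_0) = (2, 0)$ gives:

```latex
\begin{aligned}
y_1' &= y_2 = 0, \\
y_2' &= \mu (1 - y_1^2) y_2 - y_1 = 10 \cdot (-3) \cdot 0 - 2 = -2, \\
y_1'' &= y_2' = -2, \\
y_2'' &= \mu \big( -2 y_1 y_1' y_2 + (1 - y_1^2) y_2' \big) - y_1'
       = 10 \cdot \big( 0 + (-3)(-2) \big) - 0 = 60, \\
y_1^{(3)} &= y_2'' = 60, \\
y_2^{(3)} &= \mu \big( -2 \big( (y_1')^2 y_2 + y_1 y_1'' y_2 + y_1 y_1' y_2' \big)
             - 2 y_1 y_1' y_2' + (1 - y_1^2) y_2'' \big) - y_1'' \\
          &= 10 \cdot \big( 0 + (-3) \cdot 60 \big) + 2 = -1798.
\end{aligned}
```

Every term containing $y_1'$ or $y_2$ vanishes at $t_0$. The nested chain-rule expressions grow quickly with the derivative order, which is the bookkeeping that Taylor-mode differentiation handles automatically.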

Attributes Summary

is_exact

Exactness of the computed initial values.

requires_jax

Whether the implementation of the routine relies on JAX.

Methods Summary

__call__(ivp, prior_process)

Compute the initial value and its time derivatives for ivp and prior_process.

Attributes Documentation

is_exact

Exactness of the computed initial values.

Some initialization routines yield the exact initial derivatives; others only yield approximations.

Return type

bool

requires_jax

Whether the implementation of the routine relies on JAX.

Return type

bool

Methods Documentation

__call__(ivp, prior_process)

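Compute the initial value of ivp together with its time derivatives and return them as a random variable. The number of derivatives and their ordering follow the state of prior_process; for example, with IBM(ordint=3, ...) the derivatives up to third order are computed.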

Return type

RandomVariable