mccube._solvers
Defines custom solvers for performing MCC in diffrax.
See diffrax.AbstractSolver for further information on the solvers API.
MCCSolver
Bases: AbstractWrappedSolver[_SolverState]
Markov chain cubature solver for diffrax.diffeqsolve.
Composes a diffrax.AbstractSolver with a mccube.AbstractRecombinationKernel,
such that a single step of the solver is equivalent to the evaluation of an
approximate cubature kernel from \(t_0\) to \(t_1\).
If the recombination kernel is not present, the wrapped solver step is equivalent to an exact cubature kernel from \(t_0\) to \(t_1\). Such a kernel will fail to preserve the shape of \(y(t)\) across steps and, thus, is incompatible with diffrax.
However, one can subdivide the time interval into n_substeps sub-steps, evaluate the exact
kernel on each sub-step, and only then apply the shape-preserving recombination kernel.
This provides a useful dial for tuning the trade-off between memory usage and the
information lost through recombination.
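The sub-stepping described above can be sketched in plain JAX. Each sub-step applies an Euler step along every one of \(k\) cubature vectors, so the particle count grows by a factor of \(k\) per sub-step. A recombination then resamples the cloud back to \(n\) particles. This is a minimal illustration, not mccube's implementation: the \(\pm\sqrt{d}\,e_i\) degree-3 Gaussian cubature (so \(k = 2d\)), the uniform-resampling recombination, and every function name here are hypothetical stand-ins. The weighted case is ignored.

```python
import jax.numpy as jnp
import jax.random as jr


def cubature_vectors(d):
    # Hypothetical degree-3 Gaussian cubature: the 2d points +/- sqrt(d) * e_i
    eye = jnp.eye(d)
    return jnp.sqrt(d) * jnp.concatenate([eye, -eye], axis=0)  # (k, d), k = 2d


def exact_step(y, dt, drift, diffusion, vectors):
    # Euler step along each cubature path: (m, d) -> (m * k, d)
    base = y + drift(y) * dt  # deterministic part, shared by all k children
    scale = jnp.broadcast_to(diffusion(y), y.shape)[:, None, :]  # (m, 1, d)
    noise = scale * (jnp.sqrt(dt) * vectors)[None, :, :]  # (m, k, d)
    return (base[:, None, :] + noise).reshape(-1, y.shape[-1])


def monte_carlo_recombine(key, y, n):
    # Hypothetical recombination: uniformly resample n of the expanded particles
    idx = jr.choice(key, y.shape[0], shape=(n,), replace=False)
    return y[idx]


def mcc_step(key, y, t0, t1, n_substeps, drift, diffusion):
    n, d = y.shape
    vectors = cubature_vectors(d)
    dt = (t1 - t0) / n_substeps
    for _ in range(n_substeps):  # static loop: the shape grows on every sub-step
        y = exact_step(y, dt, drift, diffusion, vectors)
    # y now holds n * k**n_substeps particles; shrink it back to n
    return monte_carlo_recombine(key, y, n)


y0 = jnp.zeros((32, 4))
y1 = mcc_step(jr.key(0), y0, 0.0, 1.0, 2, lambda y: -y, lambda y: jnp.sqrt(2.0))
print(y1.shape)  # (32, 4)
```

With \(d = 4\) (so \(k = 8\)) and two sub-steps, the intermediate cloud holds \(32 \cdot 8^2 = 2048\) particles before recombination. Raising n_substeps keeps more of the exact kernel's information but grows this intermediate array geometrically.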
Example
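```python
import jax.numpy as jnp
import jax.random as jr
import mccube
from diffrax import diffeqsolve, Euler, ODETerm, WeaklyDiagonalControlTerm

key, rng_key = jr.split(jr.key(42))
t0, t1 = 0.0, 1.0
dt0 = 0.001
particles = jnp.ones((32, 8))
weights = jr.uniform(rng_key, (32,))
# Pack the particles and their weights into a single state array.
y0 = mccube.pack_particles(particles, weights)
# Take the cubature dimension from the unpacked particles, since the packed
# state may carry an extra weight column.
n, d = particles.shape

# Gaussian cubature and the associated piecewise-linear cubature control paths.
gaussian_cubature = mccube.Hadamard(mccube.GaussianRegion(d))
cubature_control = mccube.LocalLinearCubaturePath(gaussian_cubature)
# Ornstein-Uhlenbeck dynamics: dy = -y dt + sqrt(2) dW.
ode = ODETerm(lambda t, y, args: -y)
cde = WeaklyDiagonalControlTerm(lambda t, y, args: jnp.sqrt(2), cubature_control)
terms = mccube.MCCTerm(ode, cde)
# Recombination kernel that reduces the expanded state back to n particles.
kernel = mccube.MonteCarloKernel(n, key=key)
solver = mccube.MCCSolver(Euler(), kernel, n_substeps=2, weighted=True)

# diffeqsolve takes (terms, solver, t0, t1, dt0, y0).
sol = diffeqsolve(terms, solver, t0, t1, dt0, y0)
```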
Attributes:
- solver (AbstractSolver[_SolverState]): a diffrax.AbstractSolver which, in conjunction with the terms, defines an exact cubature kernel. Note: support is only provided for the diffrax.Euler solver at present.
- recombination_kernel (AbstractRecombinationKernel): a callable which takes the interval end time, \(t_1\), the potentially expanded state \(y(t_1)\), and any additional \(\text{args}\), and yields a recombined (shape-preserved) state \(\hat{y}(t_1)\).
- n_substeps (int): the number of steps, \(n_s\), into which the interval \([t_0, t_1]\) is subdivided; equivalently, the number of exact kernel evaluations. Note: memory scales as \(\mathcal{O}(k^{n_s})\), where \(k\) is the number of cubature vectors/paths.
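To make the \(\mathcal{O}(k^{n_s})\) memory note concrete, the helper below counts the particles held just before recombination. It assumes that each of \(n\) particles is expanded by all \(k\) cubature vectors on every one of the \(n_s\) sub-steps, giving \(n \cdot k^{n_s}\). The function name and the example value \(k = 16\) are illustrative assumptions, not mccube API.

```python
def particles_before_recombination(n: int, k: int, n_substeps: int) -> int:
    """Particles held after n_substeps exact kernel evaluations: n * k**n_substeps."""
    return n * k**n_substeps


# 32 particles, a hypothetical 16-point cubature, 1 to 3 sub-steps
for n_s in (1, 2, 3):
    print(n_s, particles_before_recombination(32, 16, n_s))
# 1 512
# 2 8192
# 3 131072
```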