mccube._solvers

Defines custom solvers for performing MCC in diffrax.

See diffrax.AbstractSolver for further information on the solvers API.

MCCSolver

Bases: AbstractWrappedSolver[_SolverState]

Markov chain cubature solver for diffrax.diffeqsolve.

Composes a diffrax.AbstractSolver with a mccube.AbstractRecombinationKernel, such that a single step of the solver is equivalent to the evaluation of an approximate cubature kernel from \(t_0\) to \(t_1\).

If the recombination kernel is not present, the wrapped solver step is equivalent to an exact cubature kernel from \(t_0\) to \(t_1\). Such a kernel will fail to preserve the shape of \(y(t)\) across steps and, thus, is incompatible with diffrax.

However, one can subdivide the time interval into n_substeps sub-steps and compute the exact kernel for each sub-step before finally applying the shape-preserving recombination kernel. This provides a useful dial for tuning the trade-off between memory usage and recombination information loss.
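The expand-then-recombine pattern described above can be sketched in plain Python. This is a toy illustration, not mccube's implementation: it assumes a 1-D state, a hypothetical two-point cubature (paths of \(\pm\sqrt{h}\)), an Euler step for the drift \(-y\) and diffusion \(\sqrt{2}\), and recombination by uniform subsampling. The helper names exact_kernel_step, recombine, and mcc_step are made up for this sketch.

```python
import math
import random


def exact_kernel_step(particles, h):
    """Toy exact kernel: one Euler step per particle along each of k=2 paths."""
    out = []
    for y in particles:
        for xi in (-1.0, 1.0):  # hypothetical two-point cubature vectors
            out.append(y - y * h + math.sqrt(2.0) * xi * math.sqrt(h))
    return out  # length grows from n to 2 * n


def recombine(particles, n, rng):
    """Toy recombination: uniform subsampling (without replacement) back to n."""
    return rng.sample(particles, n)


def mcc_step(particles, t0, t1, n_substeps, rng):
    """Apply n_substeps exact kernels over [t0, t1], then recombine once."""
    n = len(particles)
    h = (t1 - t0) / n_substeps
    expanded = particles
    for _ in range(n_substeps):
        expanded = exact_kernel_step(expanded, h)
    return expanded, recombine(expanded, n, rng)


rng = random.Random(0)
y0 = [1.0] * 4
expanded, y1 = mcc_step(y0, 0.0, 0.1, n_substeps=3, rng=rng)
print(len(y0), len(expanded), len(y1))  # 4 32 4
```

The intermediate state holds \(4 \cdot 2^3 = 32\) particles, while the returned state has the same length as the input, which is the property diffrax needs from each solver step.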

Example
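import jax.numpy as jnp
import jax.random as jr
import mccube
from diffrax import diffeqsolve, Euler, ODETerm, WeaklyDiagonalControlTerm

key, rng_key = jr.split(jr.key(42))
t0, t1 = 0.0, 1.0
dt0 = 0.001
# Initial state: 32 particles in 8 dimensions, packed together with their weights.
particles = jnp.ones((32,8))
weights = jr.uniform(rng_key, (32,))
y0 = mccube.pack_particles(particles, weights)
n, d = y0.shape

# Cubature formula and the control path it induces.
gaussian_cubature = mccube.Hadamard(mccube.GaussianRegion(d))
cubature_control = mccube.LocalLinearCubaturePath(gaussian_cubature)
# Drift and diffusion terms, combined into a single MCC term.
ode = ODETerm(lambda t,y,args: -y)
cde = WeaklyDiagonalControlTerm(lambda t,y,args: jnp.sqrt(2), cubature_control)
terms = mccube.MCCTerm(ode, cde)

# Recombine back to n particles after every step.
kernel = mccube.MonteCarloKernel(n, key=key)
solver = mccube.MCCSolver(Euler(), kernel, n_substeps=2, weighted=True)
# diffeqsolve takes the terms first, then the solver, and requires dt0.
sol = diffeqsolve(terms, solver, t0, t1, dt0, y0)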

Attributes:

  • solver (AbstractSolver[_SolverState]) –

    a diffrax.AbstractSolver which, in conjunction with the terms, defines an exact cubature kernel. Note: support is only provided for the diffrax.Euler solver at present.

  • recombination_kernel (AbstractRecombinationKernel) –

    a callable which takes the interval end time, \(t_1\), the potentially expanded state \(y(t_1)\), and any additional \(\text{args}\), and yields a recombined (shape preserved) state \(\hat{y}(t_1)\).

  • n_substeps (int) –

    the number of steps, \(n_s\), to subdivide the interval \([t_0, t_1]\) into. Equivalently, the number of exact kernel evaluations. Note: memory scales with \(\mathcal{O}(k^{n_s})\), where \(k\) is the number of cubature vectors/paths.
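To get a feel for the memory note on n_substeps, the following sketch counts the particles held just before recombination. It assumes each particle branches into exactly \(k\) particles per sub-step, so \(n\) particles become \(n k^{n_s}\). The helper peak_particle_count and the values n = 32, k = 16 are hypothetical, chosen only for illustration.

```python
def peak_particle_count(n, k, n_substeps):
    """Particles held before recombination, assuming k-way branching per sub-step."""
    return n * k ** n_substeps


# Hypothetical sizes: n = 32 particles, k = 16 cubature vectors.
for n_s in (1, 2, 3):
    print(n_s, peak_particle_count(32, 16, n_s))
# 1 512
# 2 8192
# 3 131072
```

Under these assumptions each extra sub-step multiplies the intermediate state by \(k\), so small values of n_substeps are usually the practical choice.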