MCCube: Markov chain cubature via JAX
MCCube provides tools for performing Markov chain cubature (MCC) in Diffrax.
Key features:
- Custom terms, paths, and solvers that provide a painless means to perform MCC in Diffrax.
- A small library of recombination kernels, conventional cubature formulae, and metrics.
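To illustrate what a conventional cubature formula is, the sketch below builds a degree-3 formula for the standard Gaussian N(0, I_d) from a Sylvester Hadamard matrix. The 2m points ±h_r (the rows of an order-m Hadamard matrix, with m ≥ d a power of two, truncated to the first d columns) carry equal weights 1/(2m) and reproduce every moment up to order three. This is a plain-JAX sketch under those assumptions, not MCCube's `Hadamard` class; `sylvester_hadamard` and `hadamard_gaussian_cubature` are hypothetical names.

```python
import jax.numpy as jnp


def sylvester_hadamard(order: int) -> jnp.ndarray:
    """Sylvester construction of a Hadamard matrix; `order` must be a power of two."""
    if order < 1 or order & (order - 1):
        raise ValueError("order must be a power of two")
    h = jnp.ones((1, 1))
    while h.shape[0] < order:
        # Doubling step: [[H, H], [H, -H]] keeps the columns mutually orthogonal.
        h = jnp.block([[h, h], [h, -h]])
    return h


def hadamard_gaussian_cubature(d: int):
    """Hypothetical degree-3 cubature formula for N(0, I_d) (illustrative sketch).

    Points are +/- the rows of an order-m Hadamard matrix (m >= d, power of two),
    truncated to d columns; all 2m points share the weight 1 / (2m).
    """
    m = 1
    while m < d:
        m *= 2
    h = sylvester_hadamard(m)[:, :d]
    points = jnp.concatenate([h, -h], axis=0)
    weights = jnp.full(points.shape[0], 1.0 / points.shape[0])
    return points, weights


points, weights = hadamard_gaussian_cubature(10)
mean = weights @ points  # zero by the +/- symmetry
cov = (points * weights[:, None]).T @ points  # identity, since H^T H = m I
print(points.shape)  # (32, 10): m = 16 for d = 10
```

Because the truncated columns of a Hadamard matrix are orthogonal with squared norm m, the weighted second moment is exactly the identity, and the ± symmetry makes all odd moments vanish.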
Installation
To install the base package:
Requires Python 3.9+, Diffrax 0.5.0+, and Equinox 0.11.3+.
By default, a CPU-only version of JAX will be installed. To make use of other JAX/XLA-compatible accelerators (GPUs/TPUs), please follow the JAX installation instructions. Windows support for JAX is currently experimental; WSL2 is the recommended approach for using JAX on Windows.
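A quick way to confirm which backend JAX picked up after installation (for example, whether an accelerator build is actually in use) is to query the default backend and the visible devices. Both `jax.default_backend()` and `jax.devices()` are standard JAX functions; on a default CPU-only install the backend is reported as `cpu`.

```python
import jax

# Name of the platform JAX will use by default (e.g. "cpu", "gpu", or "tpu").
backend = jax.default_backend()
# Devices available on that default backend.
devices = jax.devices()
print(f"JAX backend: {backend}")
print(f"Devices: {devices}")
```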
Documentation
Available at https://mccube.readthedocs.io/.
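What is Markov chain cubature?
MCC is an approach to constructing a Cubature on Wiener Space that does not suffer from exponential growth in the number of particles over time (particle-count explosion). Iterating a cubature formula with m paths multiplies the particle count by m at every step, so n particles become n·m^k after k steps; MCC avoids this by applying (partitioned) recombination in the (approximate) cubature kernel, which keeps the number of particles bounded.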
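The particle-count explosion, and how recombination removes it, can be seen in a toy setting. The sketch below assumes a simple Euler-type cubature step in which each particle is drifted and then branched along every point of a degree-3 Gaussian cubature formula (m = 4 points in d = 2). It uses uniform subsampling without replacement as a crude stand-in for recombination; genuine recombination algorithms instead select and reweight particles so that chosen moments are preserved. The helpers `cubature_step` and `monte_carlo_recombine` are hypothetical and are not part of MCCube's API.

```python
import jax.numpy as jnp
import jax.random as jr


def cubature_step(particles, increments, drift, dt, sigma):
    """Drift each particle, then branch it along every cubature increment.

    particles: (n, d); increments: (m, d) -> returns (n * m, d).
    """
    n, d = particles.shape
    m = increments.shape[0]
    drifted = particles + dt * drift(particles)  # (n, d)
    # Scale the unit cubature points to Brownian increments over a step of size dt.
    branched = drifted[:, None, :] + sigma * jnp.sqrt(dt) * increments[None, :, :]
    return branched.reshape(n * m, d)


def monte_carlo_recombine(key, particles, n):
    """Keep n particles chosen uniformly without replacement (a stand-in for recombination)."""
    idx = jr.choice(key, particles.shape[0], shape=(n,), replace=False)
    return particles[idx]


# Degree-3 cubature for N(0, I_2): +/- the rows of a 2x2 Hadamard matrix, equal weights.
h = jnp.array([[1.0, 1.0], [1.0, -1.0]])
increments = jnp.concatenate([h, -h], axis=0)  # m = 4 points


def drift(y):
    return -y  # Ornstein-Uhlenbeck-type drift, chosen only for illustration


n, steps, dt, sigma = 8, 3, 0.1, jnp.sqrt(2.0)
y_naive = jnp.zeros((n, 2))  # plain iterated cubature
y_mcc = jnp.zeros((n, 2))  # cubature followed by subsampling at every step
key = jr.key(0)
for _ in range(steps):
    y_naive = cubature_step(y_naive, increments, drift, dt, sigma)
    key, subkey = jr.split(key)
    y_mcc = monte_carlo_recombine(subkey, cubature_step(y_mcc, increments, drift, dt, sigma), n)

print(y_naive.shape)  # (512, 2), i.e. n * m**steps = 8 * 4**3
print(y_mcc.shape)  # (8, 2)
```

The naive ensemble grows by a factor of m = 4 per step, while the subsampled ensemble returns to n = 8 particles after every step.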
Example
```python
import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import multivariate_normal

from mccube import (
    GaussianRegion,
    Hadamard,
    LocalLinearCubaturePath,
    MCCSolver,
    MCCTerm,
    MonteCarloKernel,
    gaussian_wasserstein_metric,
)

key = jr.key(42)
n, d = 512, 10  # number of particles and state dimension
t0 = 0.0
epochs = 512
dt0 = 0.05
t1 = t0 + dt0 * epochs
y0 = jnp.ones((n, d))  # initial particle ensemble

# Target distribution: Gaussian with mean 2 in every coordinate and covariance 3 * I.
target_mean = 2 * jnp.ones(d)
target_cov = 3 * jnp.eye(d)


def logdensity(p):
    return multivariate_normal.logpdf(p, mean=target_mean, cov=target_cov)


# Overdamped Langevin dynamics, dY = grad(log p)(Y) dt + sqrt(2) dW, for every particle.
ode = diffrax.ODETerm(lambda t, p, args: jax.vmap(jax.grad(logdensity))(p))
cde = diffrax.WeaklyDiagonalControlTerm(
    lambda t, p, args: jnp.sqrt(2.0),
    # Control paths built from a Hadamard cubature formula on a d-dimensional Gaussian region.
    LocalLinearCubaturePath(Hadamard(GaussianRegion(d))),
)
terms = MCCTerm(ode, cde)
# Euler steps; particles are recombined by a Monte Carlo kernel configured with n.
solver = MCCSolver(diffrax.Euler(), MonteCarloKernel(n, key=key))
sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0)

# Compare the final particle ensemble with the target Gaussian.
res_mean = jnp.mean(sol.ys[-1], axis=0)
res_cov = jnp.cov(sol.ys[-1], rowvar=False)
metric = gaussian_wasserstein_metric((target_mean, res_mean), (target_cov, res_cov))
print(f"Result 2-Wasserstein distance: {metric}")
```
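The example drives overdamped Langevin dynamics, dY = ∇log p(Y) dt + √2 dW, with cubature paths and recombination. For comparison, the sketch below runs the same SDE with an ordinary Euler-Maruyama Monte Carlo scheme (Gaussian noise, no cubature, no recombination) on the same target, initial ensemble, and time grid, and scores the result with the closed-form 2-Wasserstein distance between Gaussians, W2² = |m1 − m2|² + tr(S1 + S2 − 2 (S2^(1/2) S1 S2^(1/2))^(1/2)). It is a plain-JAX baseline under these assumptions, not MCCube's implementation: `langevin_em`, `sqrtm_psd`, and `gaussian_w2` are hypothetical helpers, and `gaussian_w2` returns W2 itself (not its square), which may differ from the convention used by `gaussian_wasserstein_metric`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import multivariate_normal


def langevin_em(key, logdensity, y0, dt, steps):
    """Euler-Maruyama for dY = grad(log p)(Y) dt + sqrt(2) dW, applied to every particle."""
    score = jax.vmap(jax.grad(logdensity))

    def step(y, k):
        noise = jr.normal(k, y.shape)
        return y + dt * score(y) + jnp.sqrt(2.0 * dt) * noise, None

    y, _ = jax.lax.scan(step, y0, jr.split(key, steps))
    return y


def sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, v = jnp.linalg.eigh(0.5 * (a + a.T))  # symmetrise to absorb round-off
    return (v * jnp.sqrt(jnp.clip(w, 0.0))) @ v.T


def gaussian_w2(mean1, cov1, mean2, cov2):
    """2-Wasserstein distance between N(mean1, cov1) and N(mean2, cov2).

    W2^2 = |mean1 - mean2|^2 + tr(cov1 + cov2 - 2 (cov2^1/2 cov1 cov2^1/2)^1/2)
    """
    root2 = sqrtm_psd(cov2)
    cross = sqrtm_psd(root2 @ cov1 @ root2)
    sq = jnp.sum((mean1 - mean2) ** 2) + jnp.trace(cov1 + cov2 - 2.0 * cross)
    return jnp.sqrt(jnp.clip(sq, 0.0))  # clip tiny negative round-off before sqrt


# Same target, initial ensemble, and time grid as the MCC example.
n, d, dt0, epochs = 512, 10, 0.05, 512
target_mean = 2 * jnp.ones(d)
target_cov = 3 * jnp.eye(d)


def logdensity(p):
    return multivariate_normal.logpdf(p, mean=target_mean, cov=target_cov)


y = langevin_em(jr.key(0), logdensity, jnp.ones((n, d)), dt0, epochs)
dist = gaussian_w2(target_mean, target_cov, jnp.mean(y, axis=0), jnp.cov(y, rowvar=False))
print(f"Monte Carlo baseline 2-Wasserstein distance: {dist}")
```

Because the target here is Gaussian, the distance only compares first and second moments of the final ensemble, which is the same kind of summary the example computes from `res_mean` and `res_cov`.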
Citation
Please cite this repository if it has been useful in your work:
@software{mccube2023github,
author={},
title={{MCC}ube: Markov chain cubature via {JAX}},
url={},
version={<insert current release tag>},
year={2023},
}
See Also
Some other Python/JAX packages that you may find interesting:
- Markov-Chain-Cubature: A PyTorch implementation of Markov chain cubature.
- PySR: High-performance symbolic regression in Python and Julia.
- Equinox: A JAX library for parameterised functions.
- Diffrax: A JAX library providing numerical differential equation solvers.
- Lineax: A JAX library for linear solves and linear least squares.
- OTT-JAX: A JAX library for optimal transport.