
mccube._kernels.base

AbstractKernel

Bases: Module

Abstract base class for all Kernels.

Let a time-evolving PyTree of constant structure \(P\), whose leaves are two-dimensional arrays with constant trailing dimension \(d\) and time-dependent leading dimension \(k\), be denoted \(p(t)\) and referred to as particles.

Given \(p(t_0)\), a (transition) kernel \(f \colon (t_0, p(t_0), \text{args}) \mapsto p(t_1)\) defines the state of the particles at some (future) time \(t_1 \ge t_0\), provided that \(t_1 - t_0\) is sufficiently small. Alternatively, \(p(t)\) can be viewed as the solution of some differential equation, with the kernel playing the role of a (non-linear) vector-field control product.
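As a rough illustration of the kernel contract, the sketch below implements one explicit Euler step \(p(t_1) = p(t_0) + (t_1 - t_0)\,\mu(p(t_0))\) for a hypothetical drift \(\mu\), applied leafwise over a PyTree of particles. The class `EulerDriftKernel`, the `drift` callable, and passing the step size via `args["dt"]` are all assumptions made for illustration; this is not a mccube kernel and does not subclass `AbstractKernel`.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


class EulerDriftKernel:
    """Hypothetical kernel: one explicit Euler step of dp/dt = drift(p).

    Not part of mccube; it only mirrors the (t, particles, args) -> particles
    call signature of AbstractKernel.
    """

    def __init__(self, drift):
        self.drift = drift  # maps a [k, d] array to a [k, d] array

    def __call__(self, t, particles, args, weighted=False):
        # Assumption: the step size t_1 - t_0 is supplied through args.
        # t and weighted are unused here because the drift is time-homogeneous
        # and this sketch does not handle particle weights.
        dt = args["dt"]
        # Apply the same update to every leaf, preserving the PyTree structure.
        return jtu.tree_map(lambda p: p + dt * self.drift(p), particles)


kernel = EulerDriftKernel(drift=lambda p: -p)
p0 = {"x": jnp.ones((4, 2)), "v": jnp.zeros((4, 2))}
p1 = kernel(0.0, p0, {"dt": 0.1})  # every entry of p1["x"] is 1 - 0.1 = 0.9
```

Because the update is mapped over the leaves, the output keeps the PyTree structure and leaf shapes of the input, which is what a non-recombining kernel is expected to do.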

__call__ abstractmethod

__call__(t: RealScalarLike, particles: Particles, args: Args, weighted: bool = False) -> Particles

Transform the particles.

Parameters:

  • t (RealScalarLike) –

    current time; particle state observation time, \(t\).

  • particles (Particles) –

    particles to transform, \(p(t_0)\).

  • args (Args) –

    additional static arguments passed to the transform.

Returns:

  • Particles

    A PyTree of transformed particles \(p(t)\) with the same PyTree structure and trailing dimension \(d\) as the input particles, \(p(t_0)\).
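The return contract can be checked mechanically. The helper below is a hypothetical sketch (not a mccube function) that assumes "dimension" refers to the trailing dimension \(d\) of each leaf, since recombination and partitioning kernels are allowed to change the leading dimension.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def satisfies_kernel_contract(particles_in, particles_out):
    """Hypothetical helper: check the Returns contract of a kernel call.

    The output must have the same PyTree structure as the input, and every
    leaf must keep its trailing dimension d. The leading dimension(s) may
    change, e.g. under recombination or partitioning.
    """
    if jtu.tree_structure(particles_in) != jtu.tree_structure(particles_out):
        return False
    leaves_in = jtu.tree_leaves(particles_in)
    leaves_out = jtu.tree_leaves(particles_out)
    return all(a.shape[-1] == b.shape[-1] for a, b in zip(leaves_in, leaves_out))


p_in = {"x": jnp.ones((8, 3))}
ok = satisfies_kernel_contract(p_in, {"x": jnp.ones((4, 3))})  # recombined: True
```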

Source code in mccube/_kernels/base.py
@abc.abstractmethod
def __call__(
    self,
    t: RealScalarLike,
    particles: Particles,
    args: Args,
    weighted: bool = False,
) -> Particles:
    r"""Transform the particles.

    Args:
        t: current time; particle state observation time, $t$.
        particles: particles to transform, $p(t_0)$.
        args: additional static arguments passed to the transform.

    Returns:
        A PyTree of transformed particles $p(t)$ with the same PyTree structure and
            dimension as the input particles, $p(t_0)$.
    """
    ...

AbstractPartitioningKernel

Bases: AbstractKernel

Indicates (but does not check) that the kernel performs partitioning.

Partitioning requires the kernel to reshape each leaf of the particles PyTree from the shape [n, d] to the shape [m, n/m, d], where \(m\) is the number of partitions. Implicitly, \(n\) must be divisible by \(m\), so that all partitions are of equal size.
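The required reshape can be sketched as follows. `naive_partition` is a hypothetical function that simply splits each leaf into contiguous blocks; a real partitioning kernel (such as `mccube.BinaryTreePartitioningKernel`) presumably reorders the particles before splitting, which this sketch deliberately omits. The per-leaf counts are given as a PyTree of ints, matching the `partition_count` attribute type below.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def naive_partition(particles, partition_count):
    """Hypothetical: split each [n, d] leaf into [m, n // m, d] contiguous blocks.

    partition_count is a PyTree of ints with the same structure as particles.
    No reordering is done; only the shape requirement is illustrated.
    """

    def split(leaf, m):
        n, d = leaf.shape
        if n % m != 0:
            raise ValueError(f"{n} particles cannot be split into {m} equal partitions")
        return leaf.reshape(m, n // m, d)

    return jtu.tree_map(split, particles, partition_count)


p = {"x": jnp.arange(64.0).reshape(32, 2), "v": jnp.zeros((12, 2))}
out = naive_partition(p, {"x": 4, "v": 3})  # shapes (4, 8, 2) and (3, 4, 2)
```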

Attributes:

  • partition_count (PyTree[int, Particles] | None) –

    indicates the requested number of partitions, \(m\).

AbstractRecombinationKernel

Bases: AbstractKernel

Indicates (but does not check) that the kernel performs recombination.

Recombination requires the kernel to strictly reduce the size of the leading dimension \(n\) of each leaf in the particles PyTree. The reduced set of recombined particles \(p(t_1)\) is expected to be, in some abstract sense, as representative as possible of the input particles \(p(t_0)\).
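One of the simplest recombinations is uniform random subsampling. The sketch below applies it to a single [n, d] leaf; `subsample_recombine` is a hypothetical stand-in chosen for illustration and is not claimed to be how `mccube.MonteCarloKernel` is implemented.

```python
import jax.numpy as jnp
import jax.random as jr


def subsample_recombine(key, particles, recombination_count):
    """Hypothetical Monte Carlo recombination of a single [n, d] leaf.

    Draws recombination_count rows uniformly at random without replacement,
    so the leading dimension shrinks from n to recombination_count.
    """
    n = particles.shape[0]
    if recombination_count >= n:
        raise ValueError("recombination must strictly reduce the leading dimension")
    idx = jr.choice(key, n, shape=(recombination_count,), replace=False)
    return particles[idx]


x = jnp.arange(20.0).reshape(10, 2)
y = subsample_recombine(jr.key(0), x, 4)  # 4 distinct rows of x
```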

Attributes:

  • recombination_count (PyTree[int, Particles] | None) –

    indicates the requested size of the recombined dimension.

PartitioningRecombinationKernel

PartitioningRecombinationKernel(partitioning_kernel: AbstractPartitioningKernel, recombination_kernel: AbstractRecombinationKernel)

Bases: AbstractRecombinationKernel

Composes a partitioning kernel with a recombination kernel.

The recombination kernel is applied independently to each partition generated by the partitioning kernel.
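The composition can be sketched for a single leaf with `jax.vmap`: partition, recombine each partition with its own random key, then flatten back to two dimensions. The function `partition_then_recombine` is hypothetical; a contiguous reshape stands in for a real partitioning kernel and random subsampling stands in for a real recombination kernel, so this shows only the data flow, not mccube's implementation.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def partition_then_recombine(key, particles, partition_count, per_partition_count):
    """Hypothetical composition for a single [n, d] leaf.

    1. Naively partition [n, d] -> [m, n // m, d] (contiguous blocks).
    2. Subsample per_partition_count rows from each partition independently.
    3. Flatten back to [m * per_partition_count, d].
    """
    n, d = particles.shape
    m = partition_count
    if n % m != 0:
        raise ValueError("n must be divisible by the number of partitions")
    parts = particles.reshape(m, n // m, d)
    keys = jr.split(key, m)  # one independent key per partition

    def recombine(k, part):
        idx = jr.choice(k, part.shape[0], shape=(per_partition_count,), replace=False)
        return part[idx]

    reduced = jax.vmap(recombine)(keys, parts)  # [m, per_partition_count, d]
    return reduced.reshape(m * per_partition_count, d)


y0 = jnp.arange(512.0).reshape(64, 8)
out = partition_then_recombine(jr.key(1), y0, 4, 8)  # shape (32, 8)
```

Each output block of 8 rows is drawn only from its own partition of 16 rows, which is what "applied independently to each partition" means in practice.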

Example
import jax.numpy as jnp
import jax.random as jr
import mccube

key = jr.key(42)

y0 = jnp.ones((64, 8))
n, d = y0.shape
n_out = n // 2  # total number of particles to keep
n_partitions = 4
partitioning_kernel = mccube.BinaryTreePartitioningKernel(n_partitions)
# The recombination kernel acts on each partition separately, so its
# recombination_count is the per-partition share, n_out // n_partitions.
recombination_kernel = mccube.MonteCarloKernel(n_out // n_partitions, key=key)
kernel = mccube.PartitioningRecombinationKernel(
    partitioning_kernel,
    recombination_kernel,
)
# t and args are passed as placeholders (Ellipsis) in this example.
result = kernel(..., y0, ...)
# result == jnp.ones((32, 8))

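Attributes:

  • partitioning_kernel (AbstractPartitioningKernel) –

    the kernel that splits the particles into partitions.

  • recombination_kernel (AbstractRecombinationKernel) –

    the kernel applied independently to each partition.

  • recombination_count (PyTree[int, Particles] | None) –

    the total size of the recombined dimension: the leafwise product of recombination_kernel.recombination_count and partitioning_kernel.partition_count.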

Source code in mccube/_kernels/base.py
def __init__(
    self,
    partitioning_kernel: AbstractPartitioningKernel,
    recombination_kernel: AbstractRecombinationKernel,
):
    self.partitioning_kernel = partitioning_kernel
    self.recombination_kernel = recombination_kernel
    self.recombination_count = jtu.tree_map(
        operator.mul,
        self.recombination_kernel.recombination_count,
        self.partitioning_kernel.partition_count,
    )
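The constructor's count arithmetic can be reproduced with plain integer PyTrees. The per-leaf counts below are hypothetical values chosen for illustration; the `x` leaf matches the Example above, where 8 particles per partition across 4 partitions gives the requested 32.

```python
import operator

import jax.tree_util as jtu

# Hypothetical per-leaf counts for a two-leaf particles PyTree.
partition_count = {"x": 4, "v": 2}
per_partition_recombination_count = {"x": 8, "v": 8}

# Mirrors __init__: the total count is the leafwise product m * r, i.e. the
# number of particles left in each leaf after every partition is recombined.
total_recombination_count = jtu.tree_map(
    operator.mul, per_partition_recombination_count, partition_count
)
# total_recombination_count == {"x": 32, "v": 16}
```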