Hamiltonian Ensemble Sampler

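A Hamiltonian sampler is a Markov chain Monte Carlo (MCMC) algorithm that simulates Hamiltonian dynamics, driven by the gradient of the log probability, to propose distant moves that are still accepted with high probability. This lets it explore continuous probability distributions more efficiently than random-walk methods, especially in high dimensions.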
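For intuition, one Metropolis-adjusted HMC transition for a single chain can be sketched in JAX as below. This is a minimal sketch assuming an identity mass matrix and a fixed step size and number of leapfrog steps (playing the roles of step_size and L). It is not hemcee's ensemble move (hmc_walk_move), and the helper names leapfrog and hmc_step are hypothetical.

```python
import jax
import jax.numpy as jnp


def leapfrog(position, momentum, grad_log_prob, step_size, num_steps):
    """Leapfrog integration of Hamiltonian dynamics with an identity mass matrix (hypothetical helper)."""
    momentum = momentum + 0.5 * step_size * grad_log_prob(position)  # initial half step for momentum
    for _ in range(num_steps - 1):
        position = position + step_size * momentum                  # full step for position
        momentum = momentum + step_size * grad_log_prob(position)   # full step for momentum
    position = position + step_size * momentum
    momentum = momentum + 0.5 * step_size * grad_log_prob(position)  # final half step for momentum
    return position, momentum


def hmc_step(key, position, log_prob, step_size=0.1, num_steps=10):
    """One HMC transition for a single chain (hypothetical helper, not hemcee's API)."""
    grad_log_prob = jax.grad(log_prob)
    key_momentum, key_accept = jax.random.split(key)
    momentum = jax.random.normal(key_momentum, position.shape)
    new_position, new_momentum = leapfrog(position, momentum, grad_log_prob, step_size, num_steps)
    # Hamiltonian H(x, p) = -log p(x) + |p|^2 / 2.
    h_old = -log_prob(position) + 0.5 * jnp.sum(momentum ** 2)
    h_new = -log_prob(new_position) + 0.5 * jnp.sum(new_momentum ** 2)
    # Metropolis correction: accept with probability min(1, exp(h_old - h_new)).
    accepted = jnp.log(jax.random.uniform(key_accept)) < h_old - h_new
    return jnp.where(accepted, new_position, position), accepted


# Example: draw from a 3-dimensional standard Gaussian.
log_prob = lambda x: -0.5 * jnp.sum(x ** 2)
key = jax.random.PRNGKey(0)
x = jnp.zeros(3)
num_accepted = 0
for _ in range(50):
    key, subkey = jax.random.split(key)
    x, accepted = hmc_step(subkey, x, log_prob)
    num_accepted += int(accepted)
```

Because the leapfrog integrator nearly conserves H, the acceptance rate stays high even though each transition moves the chain a long way; an ensemble sampler runs many such chains and lets them share information.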

class hemcee.samplers.hamiltonian_ensemble.HamiltonianEnsembleSampler(total_chains: int, dim: int, log_prob: typing.Callable, step_size: float = 0.1, L: int = 10, move=<function hmc_walk_move>, backend: hemcee.backend.backend.Backend = None, adapt_inital_step_size: bool = True, adapt_step_size: bool | hemcee.adaptation.dual_averaging.DAParameters = True, adapt_length: bool | hemcee.adaptation.chees.ChEESParameters = True)

Bases: BaseSampler

Hamiltonian ensemble sampler with optional dual averaging and ChEES adaptation.
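Dual averaging tunes the step size during warmup so that the average acceptance probability approaches a target value. The sketch below follows the dual-averaging scheme of Hoffman and Gelman (2014) with the constants suggested there. The names DualAveragingState, da_init and da_update and all default values are assumptions for illustration; hemcee's DAParameters fields and defaults may differ.

```python
import math
from typing import NamedTuple


class DualAveragingState(NamedTuple):
    iteration: int       # m, number of updates performed so far
    h_bar: float         # running average of (target_accept - observed acceptance)
    log_step: float      # log step size to use at the next iteration
    log_step_avg: float  # averaged log step size, used once warmup ends
    mu: float            # shrinkage point, log(10 * initial step size)


def da_init(initial_step_size: float) -> DualAveragingState:
    """Hypothetical initializer; log_step_avg starts at 0 (step size 1) as in the original scheme."""
    return DualAveragingState(0, 0.0, math.log(initial_step_size), 0.0,
                              math.log(10.0 * initial_step_size))


def da_update(state, accept_prob, target_accept=0.65, gamma=0.05, t0=10.0, kappa=0.75):
    """One dual-averaging update from the acceptance probability of the latest transition."""
    m = state.iteration + 1
    eta = 1.0 / (m + t0)
    # Acceptance below target makes h_bar positive, which shrinks the step size.
    h_bar = (1.0 - eta) * state.h_bar + eta * (target_accept - accept_prob)
    log_step = state.mu - math.sqrt(m) / gamma * h_bar
    # Polyak-style averaging with decaying weight m^(-kappa) stabilizes the final value.
    weight = m ** (-kappa)
    log_step_avg = weight * log_step + (1.0 - weight) * state.log_step_avg
    return DualAveragingState(m, h_bar, log_step, log_step_avg, state.mu)
```

During warmup the sampler would run with exp(log_step); after warmup it would freeze the step size at exp(log_step_avg).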

Attributes:

total_chains (int): Total number of ensemble chains.
dim (int): Dimensionality of the target distribution.
log_prob (Callable): Vectorized log-probability function.
grad_log_prob (Callable): Vectorized gradient of the log probability.
step_size (float): Leapfrog step size.
L (int): Number of leapfrog steps per move.
move (Callable): Proposal function updating each ensemble group.
adapter (Adapter): Adapter for step size and integration time adaptation.
adapter_state: Current state of the adapter.
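The log_prob and grad_log_prob attributes are described as vectorized. Assuming this means functions that map an ensemble of shape (n, dim) to per-chain values, one way to build both from a single-point log density is jax.vmap combined with jax.grad. The standard Gaussian target and the names below are illustrative only; check hemcee's own documentation for the exact shape contract it expects of log_prob.

```python
import jax
import jax.numpy as jnp


def log_prob_single(x):
    """Unnormalized log density of a standard Gaussian at one point x of shape (dim,)."""
    return -0.5 * jnp.sum(x ** 2)


# Map over the leading (chain) axis so a whole ensemble is evaluated in one call.
log_prob = jax.vmap(log_prob_single)                 # (n, dim) -> (n,)
grad_log_prob = jax.vmap(jax.grad(log_prob_single))  # (n, dim) -> (n, dim)

ensemble = jnp.ones((4, 2))           # 4 chains in 2 dimensions
print(log_prob(ensemble))             # [-1. -1. -1. -1.]
print(grad_log_prob(ensemble).shape)  # (4, 2)
```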

get_acceptance_prob()
get_autocorr(discard, thin)
get_chain(discard: int = 0, thin: int = 1, flat: bool = False) -> Array
get_logprob(discard: int = 0, thin: int = 1, flat: bool = False) -> Array
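The discard, thin and flat arguments of get_chain and get_logprob appear to follow emcee's conventions. Assuming samples are stored with shape (num_steps, total_chains, dim), and log probabilities with shape (num_steps, total_chains), their effect can be sketched on a plain array as below. select_samples is a hypothetical helper, and hemcee's internal storage layout may differ.

```python
import jax.numpy as jnp


def select_samples(chain, discard=0, thin=1, flat=False):
    """Hypothetical emcee-style selection; assumes axis 0 is the step axis and axis 1 the chain axis."""
    out = chain[discard::thin]                    # drop the first `discard` steps, keep every `thin`-th
    if flat:
        out = out.reshape((-1,) + out.shape[2:])  # merge the step and chain axes
    return out


chain = jnp.arange(10 * 4 * 2).reshape(10, 4, 2)                  # 10 steps, 4 chains, dim = 2
print(select_samples(chain, discard=2, thin=3).shape)             # (3, 4, 2): steps 2, 5, 8
print(select_samples(chain, discard=2, thin=3, flat=True).shape)  # (12, 2)
```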
run_mcmc(key: PRNGKey, initial_state: Array, num_samples: int, warmup: int = 1000, thin_by=1, batch_size: int = None, show_progress: bool = False) -> Tuple[Array, dict]

Run the Hamiltonian ensemble sampler.

Args:

key (jax.random.PRNGKey): Random number generator key.
initial_state (jnp.ndarray): Initial ensemble state with shape (total_chains, dim).
num_samples (int): Number of post-warmup samples to retain.
warmup (int): Number of warmup iterations. Defaults to 1000.
thin_by (int): Keep every thin_by-th sample. Defaults to 1 (no thinning).
show_progress (bool): Whether to display a progress bar. Defaults to False.

Returns:

tuple[jnp.ndarray, dict]: Post-warmup samples and diagnostics containing acceptance rates and dual averaging state.