Hamiltonian Ensemble Sampler
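A Hamiltonian sampler is a Markov chain Monte Carlo (MCMC) algorithm that uses the gradient of the log probability to simulate Hamiltonian dynamics. Each simulated trajectory can carry a chain far from its starting point while keeping a high acceptance probability, so the sampler typically explores probability distributions more efficiently than random-walk methods, especially in high dimensions.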
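To make the idea concrete, here is a minimal sketch of a textbook HMC transition in JAX: draw a Gaussian momentum, integrate with the leapfrog scheme, then accept or reject with a Metropolis test on the change in total energy. The standard normal target, the function names, and the use of jax.vmap to update each chain independently are illustrative assumptions; this is not hemcee's hmc_walk_move, whose ensemble update may couple chains differently. The defaults mirror the class defaults step_size=0.1 and L=10.

```python
import jax
import jax.numpy as jnp


def log_prob(x):
    # Illustrative target: standard normal log density (up to a constant).
    return -0.5 * jnp.sum(x ** 2)


# Gradient of the log density for a single position vector.
grad_log_prob = jax.grad(log_prob)


def leapfrog(x, p, step_size, num_steps):
    """Leapfrog integration of H(x, p) = -log_prob(x) + 0.5 * |p|^2."""
    p = p + 0.5 * step_size * grad_log_prob(x)  # opening half step in momentum
    for _ in range(num_steps - 1):
        x = x + step_size * p                    # full step in position
        p = p + step_size * grad_log_prob(x)     # full step in momentum
    x = x + step_size * p                        # last full step in position
    p = p + 0.5 * step_size * grad_log_prob(x)   # closing half step in momentum
    return x, p


def hmc_step(key, x, step_size, num_steps):
    """One HMC transition: draw momentum, integrate, Metropolis accept/reject."""
    key_p, key_u = jax.random.split(key)
    p0 = jax.random.normal(key_p, x.shape)
    x1, p1 = leapfrog(x, p0, step_size, num_steps)
    # Total energy = potential (-log_prob) + kinetic (0.5 * |p|^2).
    h0 = -log_prob(x) + 0.5 * jnp.sum(p0 ** 2)
    h1 = -log_prob(x1) + 0.5 * jnp.sum(p1 ** 2)
    accept_prob = jnp.minimum(1.0, jnp.exp(h0 - h1))
    accepted = jax.random.uniform(key_u) < accept_prob
    return jnp.where(accepted, x1, x), accept_prob


def ensemble_hmc_step(key, ensemble, step_size=0.1, num_steps=10):
    """Apply independent HMC transitions to every row of a (total_chains, dim) array."""
    keys = jax.random.split(key, ensemble.shape[0])

    def step(k, x):
        return hmc_step(k, x, step_size, num_steps)

    return jax.vmap(step)(keys, ensemble)
```

Because leapfrog is time-reversible and volume-preserving, the simple energy-difference test above is a valid Metropolis correction; a small step size keeps the energy error, and therefore the rejection rate, low.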
- class hemcee.samplers.hamiltonian_ensemble.HamiltonianEnsembleSampler(total_chains: int, dim: int, log_prob: typing.Callable, step_size: float = 0.1, L: int = 10, move=<function hmc_walk_move>, backend: hemcee.backend.backend.Backend = None, adapt_inital_step_size: bool = True, adapt_step_size: bool | hemcee.adaptation.dual_averaging.DAParameters = True, adapt_length: bool | hemcee.adaptation.chees.ChEESParameters = True)
Bases: BaseSampler
Hamiltonian ensemble sampler with optional dual averaging and ChEES adaptation.
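ChEES (Change in the Estimator of the Expected Square; Hoffman, Radul and Sountsov, 2021) tunes the trajectory length by maximizing how much each chain's squared distance from the target mean changes over one trajectory, using many parallel chains to estimate the expectation. The sketch below only evaluates that criterion across an ensemble; the acceptance-probability weighting and the gradient-based trajectory-length update described in the paper are omitted, and the function name is hypothetical. hemcee's adapter, configured through ChEESParameters, may compute it differently.

```python
import jax.numpy as jnp


def chees_criterion(x_old, x_new):
    """Ensemble estimate of ChEES = 1/4 * E[(|x_new - mean|^2 - |x_old - mean|^2)^2].

    x_old, x_new: arrays of shape (total_chains, dim) holding each chain's
    position before and after one HMC trajectory.
    """
    # Center each set of positions on its ensemble mean (an estimate of E[x]).
    centered_old = x_old - jnp.mean(x_old, axis=0)
    centered_new = x_new - jnp.mean(x_new, axis=0)
    # Squared distance of every chain from the mean, before and after.
    sq_old = jnp.sum(centered_old ** 2, axis=-1)
    sq_new = jnp.sum(centered_new ** 2, axis=-1)
    return 0.25 * jnp.mean((sq_new - sq_old) ** 2)
```

A trajectory that leaves every chain where it started scores zero, while one that moves chains substantially toward or away from the mean scores high, which is why maximizing ChEES favors trajectory lengths that decorrelate second moments quickly.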
- Attributes:
  - total_chains (int): Total number of ensemble chains.
  - dim (int): Dimensionality of the target distribution.
  - log_prob (Callable): Vectorized log-probability function.
  - grad_log_prob (Callable): Vectorized gradient of the log probability.
  - step_size (float): Leapfrog step size.
  - L (int): Number of leapfrog steps per move.
  - move (Callable): Proposal function updating each ensemble group.
  - adapter (Adapter): Adapter for step size and integration time adaptation.
  - adapter_state: Current state of the adapter.
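The step_size attribute is tuned during warmup when step-size adaptation is enabled. As an illustration, here is a plain-Python sketch of the standard Nesterov dual-averaging update from the NUTS paper (Hoffman and Gelman, 2014), which pushes the observed acceptance probability toward a target. The class name and the hyperparameter defaults (target 0.65, gamma 0.05, t0 10, kappa 0.75) are assumptions taken from that paper, not from hemcee's DAParameters.

```python
import math


class DualAveraging:
    """Nesterov dual averaging for the leapfrog step size (hypothetical helper).

    Defaults follow Hoffman and Gelman (2014); hemcee's DAParameters may differ.
    """

    def __init__(self, initial_step_size, target_accept=0.65,
                 gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * initial_step_size)  # shrinkage point for log step size
        self.target_accept = target_accept
        self.gamma = gamma
        self.t0 = t0
        self.kappa = kappa
        self.m = 0                  # number of updates so far
        self.h_bar = 0.0            # running mean of (target - observed acceptance)
        self.log_step = math.log(initial_step_size)
        self.log_step_bar = 0.0     # averaged iterate, used once warmup ends

    def update(self, accept_prob):
        """Feed one observed acceptance probability; return the next step size."""
        self.m += 1
        w = 1.0 / (self.m + self.t0)
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target_accept - accept_prob)
        self.log_step = self.mu - math.sqrt(self.m) / self.gamma * self.h_bar
        eta = self.m ** (-self.kappa)
        self.log_step_bar = eta * self.log_step + (1.0 - eta) * self.log_step_bar
        return math.exp(self.log_step)

    def final_step_size(self):
        """Step size to freeze after warmup."""
        return math.exp(self.log_step_bar)
```

When acceptance stays below the target the step size shrinks, and when it stays above the target the step size grows; freezing the averaged iterate after warmup avoids ending on a noisy final value.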
- get_acceptance_prob()
- get_autocorr(discard, thin)
- get_chain(discard: int = 0, thin: int = 1, flat: bool = False) → Array
- get_logprob(discard: int = 0, thin: int = 1, flat: bool = False) → Array
- run_mcmc(key: PRNGKey, initial_state: Array, num_samples: int, warmup: int = 1000, thin_by=1, batch_size: int = None, show_progress: bool = False) → Tuple[Array, dict]
Run the Hamiltonian ensemble sampler.
- Args:
  - key (jax.random.PRNGKey): Random number generator key.
  - initial_state (jnp.ndarray): Initial ensemble state with shape (total_chains, dim).
  - num_samples (int): Number of post-warmup samples to retain.
  - warmup (int): Number of warmup iterations. Defaults to 1000.
  - thin_by (int): Keep every thin_by-th sample. Defaults to 1 (no thinning).
  - show_progress (bool): Whether to display a progress bar. Defaults to False.
- Returns:
tuple[jnp.ndarray, dict]: Post-warmup samples and diagnostics containing acceptance rates and dual averaging state.
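get_chain and get_logprob take discard, thin and flat arguments, and run_mcmc applies thin_by while sampling. The helper below is hypothetical and only illustrates these slicing semantics, assuming the chain is stored with shape (num_steps, total_chains, dim) as in emcee; the layout hemcee actually returns is not stated here and should be checked before relying on it.

```python
import jax.numpy as jnp


def select_samples(chain, discard=0, thin=1, flat=False):
    """Hypothetical illustration of discard/thin/flat semantics.

    Assumes chain has shape (num_steps, total_chains, dim).
    """
    # Drop the first `discard` steps, then keep every `thin`-th remaining step.
    selected = chain[discard::thin]
    if flat:
        # Merge the step and chain axes into one sample axis: (kept * total_chains, dim).
        selected = selected.reshape(-1, selected.shape[-1])
    return selected
```

For example, with 10 stored steps, discard=2 and thin=3 keep steps 2, 5 and 8; with flat=True and 4 chains those become 12 samples.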