API Reference

Below are the main functions and classes provided by the Spin Sampler library.

Sampler Class

class spin_sampler.Sampler(J: ndarray, T: float, mode: str = 'single_chain', backend: str = 'numpy')

Bases: object

Gibbs sampler for spin systems, supporting several sampling modes.

Attributes:

  • J (np.ndarray): Coupling matrix of shape (N, N), or a stack of coupling matrices of shape (N_walkers, N, N) in ‘multi_couplings’ mode.

  • T (float): Temperature.

  • mode (str): Sampling mode (‘single_chain’, ‘multi_chain’, ‘multi_couplings’).

  • backend (str): Backend to use (‘numpy’, ‘numba’, ‘jax’).
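The three modes differ in the array shapes they expect, as listed in the Gibbs-step docstrings further down. The summary below is written as a hypothetical helper, expected_shapes, which is not part of spin_sampler; it only restates those documented shapes.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def expected_shapes(mode: str, N: int, N_walkers: int = 1) -> Tuple[Shape, Shape]:
    """Hypothetical helper: (state shape, coupling shape) for each sampling mode.

    Shapes are copied from the gibbs_step_* docstrings; spin_sampler itself
    may not expose anything like this.
    """
    if mode == "single_chain":
        return (N,), (N, N)                        # one chain, one J
    if mode == "multi_chain":
        return (N_walkers, N), (N, N)              # many chains sharing one J
    if mode == "multi_couplings":
        return (N_walkers, N), (N_walkers, N, N)   # many chains, one J per chain
    raise ValueError(f"unknown mode: {mode!r}")
```

For example, expected_shapes("multi_couplings", N=100, N_walkers=8) returns ((8, 100), (8, 100, 100)).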

sample(initial_state, N_samples=1, dt_samples=1, rnd_ord=True, seed=None, store=False, progress=False)

Run Gibbs sampling as a generator.

Parameters:

  • initial_state: Initial spin configuration. If None, continue from the last state in the chain.

  • N_samples: Total number of steps to perform.

  • dt_samples: Thinning interval; keep one state every ‘dt_samples’ steps to reduce correlation between samples.

  • rnd_ord: If True, update spins in random order.

  • seed: Seed for random number generation (int); optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.

  • store: If True, store the sampled states in memory.

  • progress: If True, display a progress bar.

Yields:

  • The sampled state at each step (after thinning).
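The thinning behaviour described by N_samples and dt_samples can be sketched as a plain generator, assuming N_samples counts individual Gibbs steps and a state is yielded after every dt_samples-th step. The name thinned_samples, the generic step callable and the chain list standing in for store=True are all hypothetical; this is not the library's implementation.

```python
from typing import Callable, Iterator, List, Optional

import numpy as np


def thinned_samples(
    step: Callable[[np.ndarray], np.ndarray],
    initial_state: np.ndarray,
    n_steps: int,
    dt: int = 1,
    chain: Optional[List[np.ndarray]] = None,
) -> Iterator[np.ndarray]:
    """Hypothetical sketch: apply `step` n_steps times, yield every dt-th state.

    If `chain` is a list, each yielded state is also copied into it,
    mimicking the store=True option.
    """
    state = initial_state
    for t in range(1, n_steps + 1):
        state = step(state)
        if t % dt == 0:              # thinning: skip the dt - 1 intermediate states
            if chain is not None:
                chain.append(state.copy())
            yield state
```

With a toy step that adds 1, ten steps thinned with dt=3 yield the states after steps 3, 6 and 9.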

run_gibbs(initial_state, N_samples=1, dt_samples=1, rnd_ord=True, seed=None, store=False, progress=False)

Run ‘sample’ for ‘N_samples’ steps and return the final state.

Parameters:

  • initial_state: Initial spin configuration. If None, continue from the last state in the chain.

  • N_samples: Total number of steps to perform.

  • dt_samples: Thinning interval; keep one state every ‘dt_samples’ steps to reduce correlation between samples.

  • rnd_ord: If True, update spins in random order.

  • seed: Seed for random number generation (int); optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.

  • store: If True, store the sampled states in memory.

  • progress: If True, display a progress bar.

Returns:

  • np.ndarray: The final sampled state after ‘N_samples’ steps.
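Since run_gibbs is described as iterating ‘sample’ and keeping only the final state, the underlying pattern is "exhaust a generator, return its last item". A library-independent sketch, where last_item is a hypothetical name:

```python
from collections import deque
from typing import Iterable, TypeVar

Item = TypeVar("Item")


def last_item(items: Iterable[Item], default: Item) -> Item:
    """Hypothetical helper: consume `items` fully and return the final element."""
    tail = deque(items, maxlen=1)  # memory use stays O(1): only the newest item is kept
    return tail[0] if tail else default
```

Under this reading, run_gibbs behaves like last_item applied to the generator returned by sample, without holding every intermediate state in memory.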

step(S, rnd_ord=True, key=None)

Perform one Gibbs sampling step of the state S.

Parameters:

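  • S: Current spin configuration.

  • rnd_ord: If True, update spins in random order.

  • key: PRNG key for the ‘jax’ backend; a dummy argument for the other backends.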

get_chain()

Get the stored chain of sampled states.

Returns:

  • Array of sampled states.

reset_chain()

Reset the stored chain of sampled states.

Functions

spin_sampler.define_hopfield_model(N, p, N_walkers=1, mode='single_chain', backend='numpy', seed=None)

Define the Hopfield model: generate p random patterns and the corresponding coupling matrix. Each pattern is a random vector of N spins with values in {-1, 1}. The coupling matrix follows the Hebbian learning rule, J = 1/N * patterns @ patterns.T, with the diagonal set to zero.

Parameters:

  • N: Number of spins or neurons.

  • p: Number of random patterns.

  • N_walkers: Number of chains to run in parallel.

  • mode: Sampling mode (‘single_chain’, ‘multi_chain’, ‘multi_couplings’).

  • backend: Backend to use (‘numpy’, ‘numba’, ‘jax’).

  • seed: Seed for random number generation (int); optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.

Returns:

  • J: Coupling matrix (shape (N, N), or (N_walkers, N, N) in ‘multi_couplings’ mode).

  • patterns: Generated patterns (shape (N, p), or (N_walkers, N, p) in ‘multi_couplings’ mode).
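A minimal NumPy sketch of the Hebbian construction above, for a single set of patterns. The helper name hebbian_couplings and the use of rng.choice to draw the ±1 patterns are assumptions for illustration, not the library's code.

```python
import numpy as np


def hebbian_couplings(patterns: np.ndarray) -> np.ndarray:
    """Hypothetical helper: J = patterns @ patterns.T / N with zero diagonal.

    `patterns` has shape (N, p): one ±1 pattern per column.
    """
    N = patterns.shape[0]
    J = patterns @ patterns.T / N   # J_ij = (1/N) * sum_mu xi_i^mu * xi_j^mu
    np.fill_diagonal(J, 0.0)        # no self-coupling
    return J


rng = np.random.default_rng(0)
N, p = 200, 5
patterns = rng.choice([-1, 1], size=(N, p))  # assumed: i.i.d. uniform ±1 entries
J = hebbian_couplings(patterns)
```

When p is much smaller than N, each stored pattern is typically a fixed point of the zero-temperature update S -> sign(J @ S), which is what makes the patterns retrievable as memories.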

spin_sampler.define_SK_model(N, N_walkers=1, mode='single_chain', backend='numpy', seed=None)

Define the Sherrington-Kirkpatrick (SK) model coupling matrix. J is symmetric, its off-diagonal entries are drawn from a normal distribution with mean 0 and variance 1/N, and its diagonal is set to zero.

Parameters:

  • N: Number of spins or neurons.

  • N_walkers: Number of chains to run in parallel.

  • mode: Sampling mode (‘single_chain’, ‘multi_chain’, ‘multi_couplings’).

  • backend: Backend to use (‘numpy’, ‘numba’, ‘jax’).

  • seed: Seed for random number generation (int); optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.

Returns:

  • J: Coupling matrix (shape (N, N), or (N_walkers, N, N) in ‘multi_couplings’ mode).
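One way to build such a matrix in NumPy, as a sketch only (the library's exact procedure is not documented here): draw a full Gaussian matrix with standard deviation 1/√N, keep the strict upper triangle, and mirror it, so each pair (i, j) shares a single draw of variance 1/N. The name sk_couplings is hypothetical.

```python
import numpy as np


def sk_couplings(N: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: symmetric J, off-diagonal ~ Normal(0, 1/N), zero diagonal."""
    A = rng.normal(0.0, 1.0 / np.sqrt(N), size=(N, N))  # std 1/sqrt(N) -> variance 1/N
    J = np.triu(A, k=1)      # keep only the strict upper triangle
    return J + J.T           # mirror it; the diagonal stays exactly zero


rng = np.random.default_rng(1)
N = 400
J = sk_couplings(N, rng)
```

A common alternative, (A + A.T) / sqrt(2 * N) with A standard normal, gives the same off-diagonal variance but needs its diagonal zeroed explicitly.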

spin_sampler.initialize_spins(N, N_walkers=1, mode='single_chain', backend='numpy', seed=None, config='random', ref_spin=None, m0=None)

Initialize spin states according to the chosen configuration.

Parameters:

  • N: Number of spins or neurons.

  • N_walkers: Number of chains to run in parallel.

  • mode: Sampling mode (‘single_chain’, ‘multi_chain’, ‘multi_couplings’).

  • backend: Backend to use (‘numpy’, ‘numba’, ‘jax’).

  • seed: Seed for random number generation (int); optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.

  • config: Initialization configuration (‘random’, ‘magnetized’).

  • ref_spin: Reference configuration for magnetized initialization (shape (N,) or (N_walkers, N)).

  • m0: Initial magnetization level for magnetized initialization (float in (0, 1) or array of shape (N_walkers,)).

Returns:

  • S0: Initial spin configuration (shape (N,) or (N_walkers, N)).
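The docstring does not say how a ‘magnetized’ state is generated from ref_spin and m0. One plausible scheme, shown as an assumption and not as the library's method: copy ref_spin and flip each spin independently with probability (1 - m0)/2, so that the expected overlap (1/N) Σ_i S_i ref_i equals 1 - 2(1 - m0)/2 = m0. The name magnetized_spins is hypothetical.

```python
import numpy as np


def magnetized_spins(ref_spin: np.ndarray, m0, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical sketch: flip each spin of ref_spin with probability (1 - m0) / 2.

    ref_spin has shape (N,) or (N_walkers, N); m0 is a float or an array of
    shape (N_walkers,). The expected overlap with ref_spin is then m0.
    """
    m0 = np.asarray(m0, dtype=float)
    if m0.ndim == 1:
        m0 = m0[:, None]                       # one level per walker, broadcast over spins
    flip = rng.random(ref_spin.shape) < (1.0 - m0) / 2.0
    return np.where(flip, -ref_spin, ref_spin)
```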

Gibbs steps

spin_sampler.gibbs_steps.gibbs_step_single_chain(S, J, T, rnd_ord=True, key=None)

Perform one full Gibbs update of all spins of a single chain with coupling matrix J.

Parameters:

  • S: Spin configuration (shape (N,)).

  • J: Coupling matrix (shape (N, N)).

  • T: Temperature.

  • rnd_ord: If True, update spins in random order.

  • key: Dummy variable for compatibility with JAX functions.

Returns:

  • Updated spin configuration and dummy variable.
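The update rule itself is not spelled out above. A minimal NumPy sketch of one heat-bath (Gibbs) sweep, assuming the usual Ising convention E(S) = -1/2 Σ_ij J_ij S_i S_j with no external field and k_B = 1, a symmetric J with zero diagonal, and therefore P(S_i = +1 | rest) = 1/(1 + exp(-2 h_i / T)) with local field h_i = Σ_j J_ij S_j. The name gibbs_sweep and the explicit rng argument are hypothetical.

```python
import numpy as np


def gibbs_sweep(S: np.ndarray, J: np.ndarray, T: float,
                rng: np.random.Generator, rnd_ord: bool = True) -> np.ndarray:
    """Hypothetical sketch: one heat-bath sweep over a single chain, S of shape (N,).

    Assumes E(S) = -1/2 * S @ J @ S, no external field, k_B = 1.
    """
    S = S.copy()                                   # leave the caller's array untouched
    N = S.shape[0]
    order = rng.permutation(N) if rnd_ord else np.arange(N)
    for i in order:
        h = J[i] @ S                               # local field h_i = sum_j J_ij S_j
        # P(S_i = +1 | rest) = 1 / (1 + exp(-2h/T)) = (1 + tanh(h/T)) / 2
        p_up = 0.5 * (1.0 + np.tanh(h / T))
        S[i] = 1 if rng.random() < p_up else -1
    return S
```

Writing the acceptance probability with tanh instead of exp avoids overflow warnings at very low temperature, where h/T can be large.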

spin_sampler.gibbs_steps.gibbs_step_multi_chain(S, J, T, rnd_ord=True, key=None)

Perform one full Gibbs update of all spins for multiple chains that share the same J.

Parameters:

  • S: Spin configuration (shape (N_walkers, N)).

  • J: Coupling matrix (shape (N, N)).

  • T: Temperature.

  • rnd_ord: If True, update spins in random order.

  • key: Dummy variable for compatibility with JAX functions.

Returns:

  • Updated spin configuration and dummy variable.
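With a shared J, spin i can be updated in every chain at once, because the local fields of all walkers are the single matrix-vector product S @ J[i]. A sketch under the same assumed convention as a single-chain heat-bath update (E = -1/2 S·J·S, no field, k_B = 1); the name gibbs_sweep_multi_chain and the choice of one spin order shared by all walkers are assumptions.

```python
import numpy as np


def gibbs_sweep_multi_chain(S: np.ndarray, J: np.ndarray, T: float,
                            rng: np.random.Generator, rnd_ord: bool = True) -> np.ndarray:
    """Hypothetical sketch: one heat-bath sweep for N_walkers chains sharing J.

    S has shape (N_walkers, N), J has shape (N, N). All walkers visit the
    spins in the same (possibly random) order.
    """
    S = S.copy()
    n_walkers, N = S.shape
    order = rng.permutation(N) if rnd_ord else np.arange(N)
    for i in order:
        h = S @ J[i]                               # local field of spin i, shape (N_walkers,)
        p_up = 0.5 * (1.0 + np.tanh(h / T))        # P(S_i = +1 | rest), per walker
        S[:, i] = np.where(rng.random(n_walkers) < p_up, 1, -1)
    return S
```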

spin_sampler.gibbs_steps.gibbs_step_multi_couplings(S, J, T, rnd_ord=True, key=None)

Perform one full Gibbs update of all spins for multiple chains, each with its own J.

Parameters:

  • S: Spin configuration (shape (N_walkers, N)).

  • J: Coupling matrices (shape (N_walkers, N, N)).

  • T: Temperature.

  • rnd_ord: If True, update spins in random order.

  • key: Dummy variable for compatibility with JAX functions.

Returns:

  • Updated spin configuration and dummy variable.
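When every walker has its own coupling matrix, the local field of spin i in walker w is h[w] = Σ_j J[w, i, j] S[w, j], which np.einsum computes for all walkers in one call. A sketch under the same assumed heat-bath convention (E = -1/2 S·J·S, no field, k_B = 1); gibbs_sweep_multi_couplings is a hypothetical name and the shared spin order across walkers is an assumption.

```python
import numpy as np


def gibbs_sweep_multi_couplings(S: np.ndarray, J: np.ndarray, T: float,
                                rng: np.random.Generator, rnd_ord: bool = True) -> np.ndarray:
    """Hypothetical sketch: one heat-bath sweep, one coupling matrix per walker.

    S has shape (N_walkers, N), J has shape (N_walkers, N, N).
    """
    S = S.copy()
    n_walkers, N = S.shape
    order = rng.permutation(N) if rnd_ord else np.arange(N)
    for i in order:
        h = np.einsum("wj,wj->w", J[:, i, :], S)   # h[w] = sum_j J[w, i, j] * S[w, j]
        p_up = 0.5 * (1.0 + np.tanh(h / T))        # P(S_i = +1 | rest), per walker
        S[:, i] = np.where(rng.random(n_walkers) < p_up, 1, -1)
    return S
```

As a check of the per-walker couplings: at very low temperature, a ferromagnetic walker started from all +1 stays aligned, while an all-to-all antiferromagnetic walker started from the same state ends the sweep with zero magnetization.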

spin_sampler.gibbs_steps.gibbs_step_single_chain_jax(S, J, T, rnd_ord=False, key=None)

Perform one full Gibbs update of all spins of a single chain with coupling matrix J.

Parameters:

  • S: Spin configuration (shape (N,)).

  • J: Coupling matrix (shape (N, N)).

  • T: Temperature.

  • rnd_ord: If True, update spins in random order.

  • key: PRNG key for randomness.

Returns:

  • Updated spin configuration and new PRNG key.

spin_sampler.gibbs_steps.gibbs_step_multi_chain_jax(S, J, T, rnd_ord=False, key=None)

Perform one full Gibbs update of all spins for multiple chains that share the same J.

Parameters:

  • S: Spin configuration (shape (N_walkers, N)).

  • J: Coupling matrix (shape (N, N)).

  • T: Temperature.

  • rnd_ord: If True, update spins in random order.

  • key: PRNG key for randomness.

Returns:

  • Updated spin configuration and new PRNG key.

spin_sampler.gibbs_steps.gibbs_step_multi_couplings_jax(S, J, T, rnd_ord=False, key=None)

Perform one full Gibbs update of all spins for multiple chains, each with its own J.

Parameters:

  • S: Spin configuration (shape (N_walkers, N)).

  • J: Coupling matrices (shape (N_walkers, N, N)).

  • T: Temperature.

  • rnd_ord: If True, update spins in random order.

  • key: PRNG key for randomness.

Returns:

  • Updated spin configuration and new PRNG key.