API Reference
Below are the main functions and classes provided by the Spin Sampler library.
Sampler Class
- class spin_sampler.Sampler(J: ndarray, T: float, mode: str = 'single_chain', backend: str = 'numpy')
Bases: object
Gibbs sampler for spin systems, supporting several sampling modes.
Attributes:
J (np.ndarray): Coupling matrix of shape (N, N), or a stack of coupling matrices of shape (N_walkers, N, N) in ‘multi_couplings’ mode.
T (float): Temperature.
mode (str): Sampling mode (‘single_chain’, ‘multi_chain’, ‘multi_couplings’).
backend (str): Backend to use (‘numpy’, ‘numba’, ‘jax’).
- sample(initial_state, N_samples=1, dt_samples=1, rnd_ord=True, seed=None, store=False, progress=False)
Run Gibbs sampling as a generator.
Parameters:
initial_state: Initial spin configuration. If None, the last state of the chain is used.
N_samples: Total number of steps to perform.
dt_samples: Thinning interval; a state is kept every ‘dt_samples’ steps to reduce correlation between samples.
rnd_ord: If True, update spins in random order.
seed: Integer seed for random number generation; optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.
store: If True, store the sampled states in memory.
progress: If True, display a progress bar.
Yields:
The sampled state at each step (after thinning).
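The way N_samples and dt_samples interact can be pictured with a small, library-independent generator. This sketch assumes that a state is yielded after every dt_samples-th step, so that N_samples steps produce N_samples // dt_samples states; `thinned_chain` and `toy_step` are hypothetical names, and `toy_step` merely stands in for one Gibbs sweep.

```python
from typing import Callable, Iterator

import numpy as np


def thinned_chain(
    step: Callable[[np.ndarray], np.ndarray],
    state: np.ndarray,
    n_steps: int,
    dt: int = 1,
) -> Iterator[np.ndarray]:
    """Run `n_steps` updates and yield the state after every `dt`-th one."""
    for t in range(1, n_steps + 1):
        state = step(state)
        if t % dt == 0:
            # Yield a copy so that later in-place updates cannot alter it.
            yield state.copy()


def toy_step(s: np.ndarray) -> np.ndarray:
    # Stand-in for one Gibbs sweep: flip every spin.
    return -s


samples = list(thinned_chain(toy_step, np.ones(3), n_steps=10, dt=5))
print(len(samples))  # 2
```

Under this reading, ten steps with a thinning interval of five yield the states after steps 5 and 10 only.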
- run_gibbs(initial_state, N_samples=1, dt_samples=1, rnd_ord=True, seed=None, store=False, progress=False)
Run the ‘sample’ generator for ‘N_samples’ steps and return the final state.
Parameters:
initial_state: Initial spin configuration. If None, the last state of the chain is used.
N_samples: Total number of steps to perform.
dt_samples: Thinning interval; a state is kept every ‘dt_samples’ steps to reduce correlation between samples.
rnd_ord: If True, update spins in random order.
seed: Integer seed for random number generation; optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.
store: If True, store the sampled states in memory.
progress: If True, display a progress bar.
Returns:
np.ndarray: The final sampled state after ‘N_samples’ steps.
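Because run_gibbs returns only the final state, it behaves like draining a generator and keeping its last item. A minimal sketch of that pattern using only the standard library; `last_item` and `countdown` are hypothetical helpers, with `countdown` standing in for a call to sample.

```python
from collections import deque
from typing import Iterable, Iterator, TypeVar

Item = TypeVar("Item")


def last_item(iterable: Iterable[Item]) -> Item:
    """Consume an iterable and return its final element."""
    # A deque with maxlen=1 keeps only the most recent item while consuming.
    tail = deque(iterable, maxlen=1)
    if not tail:
        raise ValueError("iterable produced no items")
    return tail[0]


def countdown(n: int) -> Iterator[int]:
    # Toy generator standing in for Sampler.sample(...).
    for k in range(n, 0, -1):
        yield k


print(last_item(countdown(5)))  # 1
```

Using a bounded deque avoids storing every intermediate state, which matters when store=False and N_samples is large.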
Functions
- spin_sampler.define_hopfield_model(N, p, N_walkers=1, mode='single_chain', backend='numpy', seed=None)
Defines the Hopfield model patterns and coupling matrix. The patterns are random binary vectors of size N with entries in {-1, 1}. The coupling matrix is built from the patterns with the Hebbian learning rule, J = (1/N) * patterns @ patterns.T, and its diagonal is set to zero.
Parameters:
N: Number of spins or neurons.
p: Number of random patterns.
N_walkers: Number of chains to run in parallel.
mode: Sampling mode (‘single_chain’, ‘multi_chain’, ‘multi_couplings’).
backend: Backend to use (‘numpy’, ‘numba’, ‘jax’).
seed: Integer seed for random number generation; optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.
Returns:
J: Coupling matrix (shape (N, N) or (N_walkers, N, N)).
patterns: Generated patterns (shape (N, p) or (N_walkers, N, p)).
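A plain NumPy sketch of the Hebbian construction for the single-chain case, written from the formula above rather than from the library's source; `hebbian_couplings` is a hypothetical name, and drawing patterns with `numpy.random.Generator.choice` is an assumption.

```python
from typing import Optional, Tuple

import numpy as np


def hebbian_couplings(N: int, p: int, seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Return J with shape (N, N) and patterns with shape (N, p)."""
    rng = np.random.default_rng(seed)
    # Random binary patterns with entries in {-1, +1}, one pattern per column.
    patterns = rng.choice(np.array([-1, 1]), size=(N, p))
    # J_ij = (1/N) * sum_mu xi_i^mu xi_j^mu, i.e. (1/N) * patterns @ patterns.T.
    J = patterns @ patterns.T / N
    # Remove self-couplings.
    np.fill_diagonal(J, 0.0)
    return J, patterns


J, xi = hebbian_couplings(N=100, p=5, seed=0)
```

With a single pattern xi, zeroing the diagonal gives J @ xi = (N - 1)/N * xi, so the stored pattern is aligned with its own local field.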
- spin_sampler.define_SK_model(N, N_walkers=1, mode='single_chain', backend='numpy', seed=None)
Defines the Sherrington-Kirkpatrick model coupling matrix. The coupling matrix J is symmetric and constructed with entries drawn from a normal distribution with mean 0 and variance 1/N, and the diagonal set to zero.
Parameters:
N: Number of spins or neurons.
N_walkers: Number of chains to run in parallel.
mode: Sampling mode (‘single_chain’, ‘multi_chain’, ‘multi_couplings’).
backend: Backend to use (‘numpy’, ‘numba’, ‘jax’).
seed: Integer seed for random number generation; optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.
Returns:
J: Coupling matrix (shape (N, N) or (N_walkers, N, N)).
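The SK construction can be sketched in NumPy as follows. The description above does not say how symmetry is enforced; this sketch assumes one common choice, drawing the upper triangle and mirroring it, so that every off-diagonal entry has variance 1/N. `sk_couplings` is a hypothetical name.

```python
from typing import Optional

import numpy as np


def sk_couplings(N: int, seed: Optional[int] = None) -> np.ndarray:
    """Symmetric N x N matrix with N(0, 1/N) off-diagonal entries and zero diagonal."""
    rng = np.random.default_rng(seed)
    # Standard deviation 1/sqrt(N) gives variance 1/N.
    A = rng.normal(loc=0.0, scale=1.0 / np.sqrt(N), size=(N, N))
    # Keep the strict upper triangle (k=1 excludes the diagonal) and mirror it.
    upper = np.triu(A, k=1)
    return upper + upper.T


J = sk_couplings(200, seed=0)
```

Symmetrizing as (A + A.T) / sqrt(2) is an alternative that also preserves variance 1/N off the diagonal.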
- spin_sampler.initialize_spins(N, N_walkers=1, mode='single_chain', backend='numpy', seed=None, config='random', ref_spin=None, m0=None)
Initialize the spin states according to the chosen configuration.
Parameters:
N: Number of spins or neurons.
N_walkers: Number of chains to run in parallel.
mode: Sampling mode (‘single_chain’, ‘multi_chain’, ‘multi_couplings’).
backend: Backend to use (‘numpy’, ‘numba’, ‘jax’).
seed: Integer seed for random number generation; optional for the ‘numpy’ and ‘numba’ backends, mandatory for ‘jax’.
config: Initialization configuration (‘random’, ‘magnetized’).
ref_spin: Reference configuration for magnetized initialization (shape (N,) or (N_walkers,N)).
m0: Initial magnetization level for magnetized initialization (float in (0,1) or array of shape (N_walkers,)).
Returns:
S0: Initial spin configuration (shape (N,) or (N_walkers, N)).
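The description does not specify how ‘magnetized’ states are generated. One plausible scheme, assumed here and not taken from the library, keeps each spin of ref_spin with probability (1 + m0)/2 and flips it otherwise, so the expected overlap mean(S0 * ref_spin) equals m0. `magnetized_init` is a hypothetical name.

```python
from typing import Optional

import numpy as np


def magnetized_init(ref_spin, m0, seed: Optional[int] = None) -> np.ndarray:
    """Random state whose expected overlap with ref_spin is m0 (hypothetical scheme)."""
    rng = np.random.default_rng(seed)
    ref_spin = np.asarray(ref_spin)
    m0 = np.asarray(m0, dtype=float)
    if m0.ndim == 1:
        # One magnetization per walker: broadcast over the spin axis.
        m0 = m0[:, None]
    # Keep the reference sign with probability (1 + m0) / 2, flip otherwise.
    keep = rng.random(ref_spin.shape) < (1.0 + m0) / 2.0
    return np.where(keep, ref_spin, -ref_spin)


S0 = magnetized_init(np.ones(8), m0=0.5, seed=0)
```

The expectation follows from (1 + m0)/2 - (1 - m0)/2 = m0; individual draws fluctuate around it by roughly 1/sqrt(N).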
Gibbs steps
- spin_sampler.gibbs_steps.gibbs_step_single_chain(S, J, T, rnd_ord=True, key=None)
Perform one full sweep of Gibbs updates over all spins of a single chain with coupling matrix J.
Parameters:
S: Spin configuration (shape (N,)).
J: Coupling matrix (shape (N, N)).
T: Temperature.
rnd_ord: If True, update spins in random order.
key: Dummy variable for compatibility with JAX functions.
Returns:
Updated spin configuration and dummy variable.
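A single-chain sweep can be sketched in NumPy as a heat-bath update. The sketch assumes the energy E(S) = -1/2 * S @ J @ S with symmetric J, zero diagonal and no external field, which gives P(s_i = +1 | rest) = 1 / (1 + exp(-2 h_i / T)) with local field h_i = J[i] @ S; the equivalent form (1 + tanh(h_i / T)) / 2 avoids overflow at low T. `gibbs_sweep` is a hypothetical name, and it takes a `numpy.random.Generator` instead of the dummy key.

```python
import numpy as np


def gibbs_sweep(S: np.ndarray, J: np.ndarray, T: float,
                rng: np.random.Generator, rnd_ord: bool = True) -> np.ndarray:
    """One heat-bath sweep over all N spins of a single chain (returns a new array)."""
    S = S.copy()
    N = S.shape[0]
    order = rng.permutation(N) if rnd_ord else np.arange(N)
    for i in order:
        h = J[i] @ S  # local field on spin i from the current configuration
        # P(s_i = +1) = 1 / (1 + exp(-2h/T)) written via tanh to avoid overflow.
        p_up = 0.5 * (1.0 + np.tanh(h / T))
        S[i] = 1 if rng.random() < p_up else -1
    return S


rng = np.random.default_rng(0)
J = np.array([[0.0, 1.0], [1.0, 0.0]])
S = gibbs_sweep(np.array([1.0, -1.0]), J, T=1.0, rng=rng)
```

For this two-spin system with J_12 = 1 and T = 1, each sweep leaves the spins aligned with probability (1 + tanh 1)/2 ≈ 0.88, matching the Boltzmann weight e / (e + 1/e).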
- spin_sampler.gibbs_steps.gibbs_step_multi_chain(S, J, T, rnd_ord=True, key=None)
Perform one full sweep of Gibbs updates over all spins of multiple chains that share the same coupling matrix J.
Parameters:
S: Spin configuration (shape (N_walkers, N)).
J: Coupling matrix (shape (N, N)).
T: Temperature.
rnd_ord: If True, update spins in random order.
key: Dummy variable for compatibility with JAX functions.
Returns:
Updated spin configuration and dummy variable.
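With a shared J, the sweep vectorizes over walkers: the local fields on spin i for every chain come from one matrix-vector product, S @ J[i]. This NumPy sketch assumes E(S) = -1/2 * S @ J @ S per walker (symmetric, zero-diagonal J, no field) with heat-bath updates and, as a further assumption, one visiting order shared by all walkers; `gibbs_sweep_walkers` is a hypothetical name.

```python
import numpy as np


def gibbs_sweep_walkers(S: np.ndarray, J: np.ndarray, T: float,
                        rng: np.random.Generator, rnd_ord: bool = True) -> np.ndarray:
    """One heat-bath sweep for S of shape (N_walkers, N) sharing J of shape (N, N)."""
    S = S.copy()
    n_walkers, N = S.shape
    order = rng.permutation(N) if rnd_ord else np.arange(N)
    for i in order:
        h = S @ J[i]  # local field on spin i in every walker, shape (N_walkers,)
        p_up = 0.5 * (1.0 + np.tanh(h / T))
        # Independent uniform draw per walker.
        S[:, i] = np.where(rng.random(n_walkers) < p_up, 1, -1)
    return S


rng = np.random.default_rng(0)
S = gibbs_sweep_walkers(np.ones((100, 2)), np.array([[0.0, 1.0], [1.0, 0.0]]), 1.0, rng)
```

The Python loop runs over spins only, so its cost per sweep does not grow with the number of walkers beyond the vectorized arithmetic.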
- spin_sampler.gibbs_steps.gibbs_step_multi_couplings(S, J, T, rnd_ord=True, key=None)
Perform one full sweep of Gibbs updates over all spins of multiple chains, each with its own coupling matrix.
Parameters:
S: Spin configuration (shape (N_walkers, N)).
J: Coupling matrices (shape (N_walkers, N, N)).
T: Temperature.
rnd_ord: If True, update spins in random order.
key: Dummy variable for compatibility with JAX functions.
Returns:
Updated spin configuration and dummy variable.
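When every walker has its own matrix, the local field needs a per-walker contraction, written here with `numpy.einsum`. The sketch keeps the same assumptions as a per-walker heat-bath update (E = -1/2 * S @ J[w] @ S, no field, one visiting order for all walkers); `gibbs_sweep_couplings` is a hypothetical name.

```python
import numpy as np


def gibbs_sweep_couplings(S: np.ndarray, J: np.ndarray, T: float,
                          rng: np.random.Generator, rnd_ord: bool = True) -> np.ndarray:
    """One heat-bath sweep for S of shape (N_walkers, N) and J of shape (N_walkers, N, N)."""
    S = S.copy()
    n_walkers, N = S.shape
    order = rng.permutation(N) if rnd_ord else np.arange(N)
    for i in order:
        # h[w] = sum_j J[w, i, j] * S[w, j]: each walker uses its own matrix.
        h = np.einsum("wj,wj->w", J[:, i, :], S)
        p_up = 0.5 * (1.0 + np.tanh(h / T))
        S[:, i] = np.where(rng.random(n_walkers) < p_up, 1, -1)
    return S


rng = np.random.default_rng(0)
ferro = np.array([[0.0, 1.0], [1.0, 0.0]])
S = gibbs_sweep_couplings(np.ones((2, 2)), np.stack([ferro, -ferro]), 0.01, rng)
```

At very low temperature a ferromagnetic pair (J_12 > 0) ends the sweep aligned and an antiferromagnetic pair (J_12 < 0) anti-aligned, whichever spin is visited first.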
- spin_sampler.gibbs_steps.gibbs_step_single_chain_jax(S, J, T, rnd_ord=False, key=None)
Perform one full sweep of Gibbs updates over all spins of a single chain with coupling matrix J (JAX implementation).
Parameters:
S: Spin configuration (shape (N,)).
J: Coupling matrix (shape (N, N)).
T: Temperature.
rnd_ord: If True, update spins in random order.
key: PRNG key for randomness.
Returns:
Updated spin configuration and new PRNG key.
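A JAX version has to pass randomness explicitly through PRNG keys, which is why these functions return a new key. The sketch below assumes the same heat-bath rule (E = -1/2 * S @ J @ S, no field) and floating-point spins; it splits the incoming key once per sweep and runs the spin loop with `jax.lax.fori_loop`. `gibbs_sweep_jax` is a hypothetical name, not the library function.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("rnd_ord",))
def gibbs_sweep_jax(S, J, T, key, rnd_ord=False):
    """One heat-bath sweep over all spins; returns (new_S, new_key)."""
    N = S.shape[0]
    # Keep one key for the caller and consume two subkeys in this sweep.
    key, perm_key, unif_key = jax.random.split(key, 3)
    order = jax.random.permutation(perm_key, N) if rnd_ord else jnp.arange(N)
    u = jax.random.uniform(unif_key, (N,))  # one uniform draw per spin update

    def update(k, S):
        i = order[k]
        h = J[i] @ S  # local field on spin i
        p_up = 0.5 * (1.0 + jnp.tanh(h / T))
        return S.at[i].set(jnp.where(u[k] < p_up, 1.0, -1.0))

    S = jax.lax.fori_loop(0, N, update, S)
    return S, key


key = jax.random.PRNGKey(0)
J = jnp.array([[0.0, 1.0], [1.0, 0.0]])
S, key = gibbs_sweep_jax(jnp.array([1.0, -1.0]), J, 1.0, key, rnd_ord=True)
```

Marking rnd_ord as static lets the Python `if` choose the visiting order at trace time, at the cost of one compilation per value.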
- spin_sampler.gibbs_steps.gibbs_step_multi_chain_jax(S, J, T, rnd_ord=False, key=None)
Perform one full sweep of Gibbs updates over all spins of multiple chains that share the same coupling matrix J (JAX implementation).
Parameters:
S: Spin configuration (shape (N_walkers, N)).
J: Coupling matrix (shape (N, N)).
T: Temperature.
rnd_ord: If True, update spins in random order.
key: PRNG key for randomness.
Returns:
Updated spin configuration and new PRNG key.
- spin_sampler.gibbs_steps.gibbs_step_multi_couplings_jax(S, J, T, rnd_ord=False, key=None)
Perform one full sweep of Gibbs updates over all spins of multiple chains, each with its own coupling matrix (JAX implementation).
Parameters:
S: Spin configuration (shape (N_walkers, N)).
J: Coupling matrices (shape (N_walkers, N, N)).
T: Temperature.
rnd_ord: If True, update spins in random order.
key: PRNG key for randomness.
Returns:
Updated spin configuration and new PRNG key.
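For the multi-walker JAX variants, one natural sketch writes a per-chain sweep once and batches it with `jax.vmap`: mapping over S and the keys only reproduces the shared-J case, while also mapping over J gives one coupling matrix per walker. Everything here is an assumption (heat-bath rule with E = -1/2 * S @ J @ S, fixed visiting order, one key per walker), and `sweep_one`, `sweep_multi_chain`, `sweep_multi_couplings` and `step` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def sweep_one(S, J, T, key):
    """Heat-bath sweep of one chain in fixed order 0..N-1 (float spins)."""
    N = S.shape[0]
    u = jax.random.uniform(key, (N,))  # one uniform draw per spin update

    def update(i, S):
        p_up = 0.5 * (1.0 + jnp.tanh(J[i] @ S / T))
        return S.at[i].set(jnp.where(u[i] < p_up, 1.0, -1.0))

    return jax.lax.fori_loop(0, N, update, S)


# Shared J: batch over states and keys; J and T are broadcast to all walkers.
sweep_multi_chain = jax.jit(jax.vmap(sweep_one, in_axes=(0, None, None, 0)))
# One J per walker: batch over states, coupling matrices and keys.
sweep_multi_couplings = jax.jit(jax.vmap(sweep_one, in_axes=(0, 0, None, 0)))


def step(sweep, S, J, T, key):
    """Give each walker its own subkey, run the batched sweep, return (S, key)."""
    key, sub = jax.random.split(key)
    walker_keys = jax.random.split(sub, S.shape[0])
    return sweep(S, J, T, walker_keys), key
```

Splitting a fresh key per walker keeps the chains statistically independent while the returned key carries the stream forward to the next call.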