Source code for spin_sampler.gibbs_steps

import numpy as np
from spin_sampler.utils import prob_plus , prob_plus_jax
import jax.numpy as jnp
import jax.lax as lax
import jax
from jax import jit
from line_profiler import profile

# Data types
# NumPy dtype aliases: 8-bit signed integers and 32-bit floats.
# NOTE(review): neither alias is referenced in the visible part of this module;
# presumably they are the intended dtypes for spin arrays and for real-valued
# quantities (couplings, fields) -- confirm against callers.
type_spins_np = np.int8
type_reals_np = np.float32



# NUMPY FUNCTIONS
@profile
def gibbs_step_single_chain(S, J, T, rnd_ord=True, key=None):
    """
    Perform one Gibbs sweep over every spin of a single chain.

    Parameters:
    -----------
    - S: Spin configuration, shape (N,). It is written to in place.
    - J: Coupling matrix, shape (N, N). Row i is used for the field on spin i.
    - T: Temperature.
    - rnd_ord: If True, visit the spins in a random order; otherwise in index order.
    - key: Unused placeholder so the signature matches the JAX variants.

    Returns:
    --------
    - The updated spin configuration and the untouched `key`.
    """
    inv_temp = 1 / T
    n_spins = len(S)

    # Order in which the sites are visited during this sweep
    if rnd_ord:
        sweep_order = np.random.permutation(n_spins)
    else:
        sweep_order = np.arange(n_spins)

    # Sequential updates: each site sees the spins already refreshed in this sweep
    for site in sweep_order:
        local_field = S @ J[site]
        p_up = prob_plus(inv_temp * local_field)
        draw = np.random.rand()
        if draw < p_up:
            S[site] = 1
        else:
            S[site] = -1

    return S, key
@profile
def gibbs_step_multi_chain(S, J, T, rnd_ord=True, key=None):
    """
    Perform one Gibbs sweep for several chains that share one coupling matrix.

    Parameters:
    -----------
    - S: Spin configurations, shape (N_walkers, N). It is written to in place.
    - J: Coupling matrix, shape (N, N). Row i is used for the field on spin i.
    - T: Temperature.
    - rnd_ord: If True, visit the spins in a random order; otherwise in index order.
      The same order is used for every chain.
    - key: Unused placeholder so the signature matches the JAX variants.

    Returns:
    --------
    - The updated spin configurations and the untouched `key`.
    """
    inv_temp = 1 / T
    n_chains, n_spins = S.shape

    # Order in which the sites are visited during this sweep
    if rnd_ord:
        sweep_order = np.random.permutation(n_spins)
    else:
        sweep_order = np.arange(n_spins)

    # One site at a time, all chains at once
    for site in sweep_order:
        local_fields = S @ J[site]
        p_up = prob_plus(inv_temp * local_fields)
        draws = np.random.rand(n_chains)
        goes_up = draws < p_up
        S[:, site] = np.where(goes_up, 1, -1)

    return S, key
@profile
def gibbs_step_multi_couplings(S, J, T, rnd_ord=True, key=None):
    """
    Perform one Gibbs sweep for several chains, each with its own coupling matrix.

    Parameters:
    -----------
    - S: Spin configurations, shape (N_walkers, N). It is written to in place.
    - J: Coupling matrices, shape (N_walkers, N, N). For walker w, row i of J[w]
      is used for the field on spin i.
    - T: Temperature.
    - rnd_ord: If True, visit the spins in a random order; otherwise in index order.
      The same order is used for every chain.
    - key: Unused placeholder so the signature matches the JAX variants.

    Returns:
    --------
    - The updated spin configurations and the untouched `key`.
    """
    inv_temp = 1 / T
    n_chains, n_spins = S.shape

    # Order in which the sites are visited during this sweep
    if rnd_ord:
        sweep_order = np.random.permutation(n_spins)
    else:
        sweep_order = np.arange(n_spins)

    # One site at a time, all chains at once
    for site in sweep_order:
        # Per-walker field: sum over j of J[w, site, j] * S[w, j]
        weighted = J[:, site] * S
        local_fields = weighted.sum(axis=1)
        p_up = prob_plus(inv_temp * local_fields)
        draws = np.random.rand(n_chains)
        goes_up = draws < p_up
        S[:, site] = np.where(goes_up, 1, -1)

    return S, key
# JAX FUNCTIONS
@jit
def gibbs_step_single_chain_jax(S, J, T, rnd_ord=False, key=None):
    """
    Perform one Gibbs sweep over every spin of a single chain (JAX version).

    Parameters:
    -----------
    - S: Spin configuration, shape (N,).
    - J: Coupling matrix, shape (N, N). Row i is used for the field on spin i.
    - T: Temperature.
    - rnd_ord: If True, visit the spins in a random order; otherwise in index order.
    - key: PRNG key. It is passed straight to `jax.random.split`, so a real key
      is required in practice despite the `None` default.

    Returns:
    --------
    - The updated spin configuration and the advanced PRNG key.
    """
    n_spins = S.shape[0]
    inv_temp = 1 / T

    # One split for the sweep order; the remaining key is carried through the loop
    key, order_key = jax.random.split(key)
    sweep_order = lax.cond(
        rnd_ord,
        lambda _: jax.random.permutation(order_key, n_spins),  # random order
        lambda _: jnp.arange(n_spins),                          # index order
        operand=None,
    )

    def sweep_body(step, carry):
        spins, rng = carry
        site = sweep_order[step]
        rng, draw_key = jax.random.split(rng)

        # Local field acting on the selected spin
        local_field = jnp.dot(J[site], spins)

        # Probability that the selected spin becomes +1
        p_up = prob_plus_jax(inv_temp * local_field)

        # Draw the new value from {+1, -1} with weights (p_up, 1 - p_up)
        outcomes = jnp.array([1, -1])
        weights = jnp.array([p_up, 1 - p_up])
        picked = jax.random.choice(draw_key, outcomes, p=weights)

        return spins.at[site].set(picked), rng

    S, key = lax.fori_loop(0, n_spins, sweep_body, (S, key))
    return S, key
@jit
def gibbs_step_multi_chain_jax(S, J, T, rnd_ord=False, key=None):
    """
    Perform one Gibbs sweep for several chains sharing one coupling matrix (JAX version).

    Parameters:
    -----------
    - S: Spin configurations, shape (N_walkers, N).
    - J: Coupling matrix, shape (N, N). Row i is used for the field on spin i.
    - T: Temperature.
    - rnd_ord: If True, visit the spins in a random order; otherwise in index order.
      The same order is used for every chain.
    - key: PRNG key. It is passed straight to `jax.random.split`, so a real key
      is required in practice despite the `None` default.

    Returns:
    --------
    - The updated spin configurations and the advanced PRNG key.
    """
    n_chains, n_spins = S.shape
    inv_temp = 1 / T

    # One split for the sweep order; the remaining key is carried through the loop
    key, order_key = jax.random.split(key)
    sweep_order = lax.cond(
        rnd_ord,
        lambda _: jax.random.permutation(order_key, n_spins),  # random order
        lambda _: jnp.arange(n_spins),                          # index order
        operand=None,
    )

    def sweep_body(step, carry):
        spins, rng = carry
        site = sweep_order[step]
        rng, draw_key = jax.random.split(rng)

        # Local field on the selected spin, one value per chain: shape (N_walkers,)
        local_fields = spins @ J[site]

        # Probability that the selected spin becomes +1 in each chain
        p_up = prob_plus_jax(inv_temp * local_fields)

        # One uniform draw per chain decides the new spin value
        draws = jax.random.uniform(draw_key, shape=(n_chains,))
        proposed = jnp.where(draws < p_up, 1, -1)

        return spins.at[:, site].set(proposed), rng

    S, key = lax.fori_loop(0, n_spins, sweep_body, (S, key))
    return S, key
# @profile
@jit
def gibbs_step_multi_couplings_jax(S, J, T, rnd_ord=False, key=None):
    """
    Perform one Gibbs sweep for several chains, each with its own coupling matrix (JAX version).

    Parameters:
    -----------
    - S: Spin configurations, shape (N_walkers, N).
    - J: Coupling matrices, shape (N_walkers, N, N). For walker w, row i of J[w]
      is used for the field on spin i.
    - T: Temperature.
    - rnd_ord: If True, visit the spins in a random order; otherwise in index order.
      The same order is used for every chain.
    - key: PRNG key. It is passed straight to `jax.random.split`, so a real key
      is required in practice despite the `None` default.

    Returns:
    --------
    - The updated spin configurations and the advanced PRNG key.
    """
    n_chains, n_spins = S.shape
    inv_temp = 1 / T

    # One split for the sweep order; the remaining key is carried through the loop
    key, order_key = jax.random.split(key)
    sweep_order = lax.cond(
        rnd_ord,
        lambda _: jax.random.permutation(order_key, n_spins),  # random order
        lambda _: jnp.arange(n_spins),                          # index order
        operand=None,
    )

    def sweep_body(step, carry):
        spins, rng = carry
        site = sweep_order[step]
        rng, draw_key = jax.random.split(rng)

        # Per-walker local field h[w] = sum_j J[w, site, j] * S[w, j], shape (N_walkers,).
        # The ROW of each coupling matrix is used, matching the NumPy
        # gibbs_step_multi_couplings (J[:, i]) and the other JAX kernels (J[id]).
        # Indexing the column (J[:, :, site]) only agrees when every J[w] is symmetric.
        local_fields = jnp.sum(J[:, site, :] * spins, axis=1)

        # Probability that the selected spin becomes +1 in each chain
        p_up = prob_plus_jax(inv_temp * local_fields)

        # One uniform draw per chain decides the new spin value
        draws = jax.random.uniform(draw_key, shape=(n_chains,))
        proposed = jnp.where(draws < p_up, 1, -1)

        return spins.at[:, site].set(proposed), rng

    S, key = lax.fori_loop(0, n_spins, sweep_body, (S, key))
    return S, key