# Source code for spin_sampler.spin_systems

import numpy as np
import jax.numpy as jnp
import jax

# Per-backend dtypes: "spins" is used for spin configurations (initialize_spins),
# "reals" for real-valued arrays such as patterns and coupling matrices.
type_spins_jax = type_reals_jax = jnp.float32

type_spins_np = type_reals_np = np.float64

# Define different models

def define_hopfield_model(N, p, N_walkers=1, mode='single_chain', backend='numpy', seed=None):
    """
    Defines the Hopfield model patterns and coupling matrix.

    The patterns are random binary vectors of size N with values in {-1, 1}.
    The coupling matrix J is constructed using the Hebbian learning rule,
    J = 1/N * patterns @ patterns.T, with the diagonal set to zero. Every
    diagonal entry of patterns @ patterns.T equals p (sum of p squared +/-1
    values), so the diagonal is removed by subtracting p * I.

    Parameters:
    -----------
    - N: Number of spins or neurons.
    - p: Number of random patterns.
    - N_walkers: Number of independent coupling matrices; must be != 1 in
      'multi_couplings' mode and 1 otherwise (it is forced to 1 in 'multi_chain' mode).
    - mode: Sampling mode ('single_chain', 'multi_chain', 'multi_couplings').
    - backend: Backend to use ('numpy', 'numba', 'jax'). Every backend other than
      'jax' draws from NumPy's global random generator.
    - seed: Optional seed for random number generation (int) / mandatory for jax.
      For non-jax backends a given seed re-seeds NumPy's global generator.

    Returns:
    ----------
    - J: Coupling matrix, shape (N, N), or (N_walkers, N, N) in 'multi_couplings' mode.
    - patterns: Generated patterns, shape (N, p), or (N_walkers, N, p) in 'multi_couplings' mode.

    Raises:
    -------
    - ValueError: on an inconsistent mode / N_walkers combination, or when
      backend == 'jax' and no seed is given.
    """
    if mode == 'multi_chain':
        # 'multi_chain' produces a single (N, N) coupling matrix.
        N_walkers = 1
    if mode == 'multi_couplings' and N_walkers == 1:
        raise ValueError('The number of walkers is 1 only in single chain mode')
    if mode != 'multi_couplings' and N_walkers != 1:
        raise ValueError('The number of walkers must be > 1 only in multi couplings mode')
    if backend == 'jax' and seed is None:
        raise ValueError('The seed is mandatory for jax')

    # Initialize random seed / key
    if seed is not None:
        if backend == 'jax':
            key = jax.random.PRNGKey(seed)
            key, subkey = jax.random.split(key)
        else:
            np.random.seed(seed)

    # A leading walker axis is present only in 'multi_couplings' mode.
    shape = (N_walkers, N, p) if mode == 'multi_couplings' else (N, p)

    # `@` with swapaxes on the last two axes is a batched matmul: for 2-D input it is
    # patterns @ patterns.T, for 3-D input it is applied per walker. p * eye(N)
    # broadcasts over the optional walker axis.
    if backend == 'jax':
        patterns = jax.random.choice(subkey, jnp.array([-1, 1]), shape=shape).astype(type_reals_jax)
        J = 1/N * (patterns @ jnp.swapaxes(patterns, -1, -2) - p * jnp.eye(N))
    else:
        patterns = np.random.choice([-1, 1], size=shape).astype(type_reals_np)
        J = 1/N * (patterns @ np.swapaxes(patterns, -1, -2) - p * np.eye(N))
    return J, patterns
def define_SK_model(N, N_walkers=1, mode='single_chain', backend='numpy', seed=None):
    """
    Defines the Sherrington-Kirkpatrick model coupling matrix.

    The coupling matrix J is symmetric, with off-diagonal entries drawn from a
    normal distribution with mean 0 and variance 1/N, and the diagonal set to zero.
    Symmetry is obtained as (A + A^T) / sqrt(2), where A has i.i.d. N(0, 1/N)
    entries; the sqrt(2) keeps the off-diagonal variance at 1/N.

    Parameters:
    -----------
    - N: Number of spins or neurons.
    - N_walkers: Number of independent coupling matrices; must be != 1 in
      'multi_couplings' mode and 1 otherwise (it is forced to 1 in 'multi_chain' mode).
    - mode: Sampling mode ('single_chain', 'multi_chain', 'multi_couplings').
    - backend: Backend to use ('numpy', 'numba', 'jax'). Every backend other than
      'jax' draws from NumPy's global random generator.
    - seed: Optional seed for random number generation (int) / mandatory for jax.

    Returns:
    ----------
    - J: Coupling matrix, shape (N, N), or (N_walkers, N, N) in 'multi_couplings' mode.

    Raises:
    -------
    - ValueError: on an inconsistent mode / N_walkers combination, or when
      backend == 'jax' and no seed is given.
    """
    if mode == 'multi_chain':
        # 'multi_chain' produces a single (N, N) coupling matrix.
        N_walkers = 1
    if mode == 'multi_couplings' and N_walkers == 1:
        raise ValueError('The number of walkers is 1 only in single chain mode')
    if mode != 'multi_couplings' and N_walkers != 1:
        raise ValueError('The number of walkers must be > 1 only in multi couplings mode')
    if backend == 'jax' and seed is None:
        raise ValueError('The seed is mandatory for jax')

    # Initialize random seed / key
    if seed is not None:
        if backend == 'jax':
            key = jax.random.PRNGKey(seed)
            key, subkey = jax.random.split(key)
        else:
            np.random.seed(seed)

    # A leading walker axis is present only in 'multi_couplings' mode.
    shape = (N_walkers, N, N) if mode == 'multi_couplings' else (N, N)

    # The boolean eye(N) mask broadcasts over the optional walker axis, so the
    # diagonal of every matrix is zeroed in one vectorized step.
    if backend == 'jax':
        J = jax.random.normal(subkey, shape, dtype=type_reals_jax) / jnp.sqrt(N)
        J = (J + jnp.swapaxes(J, -1, -2)) / jnp.sqrt(2)  # Make symmetric with same variance
        J = jnp.where(jnp.eye(N, dtype=bool), 0.0, J)
    else:
        J = np.random.normal(0, 1/np.sqrt(N), size=shape).astype(type_reals_np)
        J = (J + np.swapaxes(J, -1, -2)) / np.sqrt(2)  # Make symmetric with same variance
        J = np.where(np.eye(N, dtype=bool), 0.0, J)
    return J
def define_random_model(N, N_walkers=1, mode='single_chain', backend='numpy', seed=None):
    """
    Defines a random coupling matrix with i.i.d. zero-mean Gaussian entries of variance 1/N.

    Unlike define_SK_model, the matrix is neither symmetrized nor has its
    diagonal zeroed. Used for testing purposes.

    Parameters:
    -----------
    - N: Number of spins or neurons.
    - N_walkers: Number of independent coupling matrices; must be != 1 in
      'multi_couplings' mode and 1 otherwise (it is forced to 1 in 'multi_chain' mode).
    - mode: Sampling mode ('single_chain', 'multi_chain', 'multi_couplings').
    - backend: Backend to use ('numpy', 'numba', 'jax'). Every backend other than
      'jax' draws from NumPy's global random generator.
    - seed: Optional seed for random number generation (int) / mandatory for jax.

    Returns:
    ----------
    - J: Coupling matrix, shape (N, N), or (N_walkers, N, N) in 'multi_couplings' mode.

    Raises:
    -------
    - ValueError: on an inconsistent mode / N_walkers combination, or when
      backend == 'jax' and no seed is given.
    """
    if mode == 'multi_chain':
        # 'multi_chain' produces a single (N, N) coupling matrix.
        N_walkers = 1
    if mode == 'multi_couplings' and N_walkers == 1:
        raise ValueError('The number of walkers is 1 only in single chain mode')
    if mode != 'multi_couplings' and N_walkers != 1:
        raise ValueError('The number of walkers must be > 1 only in multi couplings mode')
    if backend == 'jax' and seed is None:
        raise ValueError('The seed is mandatory for jax')

    # Initialize random seed / key
    if seed is not None:
        if backend == 'jax':
            key = jax.random.PRNGKey(seed)
            key, subkey = jax.random.split(key)
        else:
            np.random.seed(seed)

    # Shape of the coupling matrix (leading walker axis only in 'multi_couplings' mode)
    shape = (N_walkers, N, N) if mode == 'multi_couplings' else (N, N)
    if backend == 'jax':
        J = jax.random.normal(subkey, shape, dtype=type_reals_jax) / jnp.sqrt(N)
    else:
        J = np.random.normal(0, 1/np.sqrt(N), size=shape).astype(type_reals_np)
    return J


# Define different initializations
def initialize_spins(N, N_walkers=1, mode='single_chain', backend='numpy', seed=None,
                     config='random', ref_spin=None, m0=None):
    """
    Initialize the spin states depending on the configuration.

    - 'random': every spin is drawn uniformly from {-1, +1}.
    - 'magnetized': spins are first drawn at random; then, for each walker w,
      int(m0[w] * N) sites chosen uniformly without replacement are overwritten
      with the corresponding entries of ref_spin.

    With the jax backend every walker gets its own PRNG key for the site
    selection, so walkers (even with equal m0) pick independent site subsets.

    Parameters:
    -----------
    - N: Number of spins or neurons.
    - N_walkers: Number of chains to run in parallel (must be 1 in 'single_chain'
      mode and must not be 1 otherwise).
    - mode: Sampling mode ('single_chain', 'multi_chain', 'multi_couplings').
    - backend: Backend to use ('numpy', 'numba', 'jax'). Every backend other than
      'jax' draws from NumPy's global random generator.
    - seed: Optional seed for random number generation (int) / mandatory for jax.
    - config: Initialization configuration ('random', 'magnetized').
    - ref_spin: Reference configuration for magnetized initialization
      (shape (N,) in 'single_chain' mode, (N_walkers, N) otherwise).
    - m0: Fraction of sites copied from ref_spin, in the open interval (0, 1):
      either a scalar shared by all walkers or a sequence/array of shape (N_walkers,).

    Returns:
    ----------
    - S0: Initial spin configuration (shape (N,) in 'single_chain' mode, (N_walkers, N) otherwise).

    Raises:
    -------
    - ValueError: on an inconsistent mode / N_walkers combination, a missing jax
      seed, an unknown config, or invalid ref_spin / m0 for 'magnetized'.
    """
    if mode != 'single_chain' and N_walkers == 1:
        raise ValueError('The number of walkers is 1 only in single chain mode')
    if mode == 'single_chain' and N_walkers != 1:
        raise ValueError('The number of walkers must be > 1 only in multi couplings mode')
    if backend == 'jax' and seed is None:
        raise ValueError('The seed is mandatory for jax')

    # Shape expected for ref_spin (and of the returned S0).
    ref_shape = (N,) if mode == 'single_chain' else (N_walkers, N)

    if config not in ['random', 'magnetized']:
        raise ValueError("config must be 'random' or 'magnetized'")
    if config == 'magnetized':
        if ref_spin is None or m0 is None:
            raise ValueError("For 'magnetized' config, ref_spin and m0 must be provided")
        if ref_spin.shape != ref_shape:
            raise ValueError(f"ref_spin shape must be {ref_shape}")
        # Any scalar (Python float, NumPy scalar, 0-d array) is shared by all walkers;
        # anything else must already hold one value per walker.
        m0 = np.asarray(m0, dtype=float)
        if m0.ndim == 0:
            m0 = np.full(N_walkers, float(m0))
        if m0.shape != (N_walkers,):
            raise ValueError(f"m0 must be a float in (0,1) or an array of shape ({N_walkers},)")
        if np.any(m0 <= 0) or np.any(m0 >= 1):
            raise ValueError("All elements of m0 must be in the interval (0,1)")

    # Initialize random seed / key
    if seed is not None:
        if backend == 'jax':
            key = jax.random.PRNGKey(seed)
            key, subkey = jax.random.split(key)
        else:
            np.random.seed(seed)

    # Work with a walker axis in every mode; it is dropped again for 'single_chain'.
    shape = (N_walkers, N)
    if backend == 'jax':
        S0 = jax.random.choice(subkey, jnp.array([1, -1]), shape=shape).astype(type_spins_jax)
    else:
        S0 = np.random.choice([-1, 1], size=shape).astype(type_spins_np)

    if config == 'magnetized':
        # Read-only view with a walker axis; no copy needed.
        ref_spins = ref_spin[None, :] if mode == 'single_chain' else ref_spin
        if backend == 'jax':
            # `subkey` was already consumed by the S0 draw above; derive fresh,
            # independent keys (one per walker) from the remaining `key` so the
            # clamped sites are not shared across walkers or correlated with S0.
            walker_keys = jax.random.split(key, N_walkers)
            for w in range(N_walkers):
                idx = jax.random.choice(walker_keys[w], N, shape=(int(m0[w] * N),), replace=False)
                S0 = S0.at[w, idx].set(ref_spins[w, idx])
        else:
            # Same sequence of global-RNG draws as sampling all index sets first.
            for w in range(N_walkers):
                idx = np.random.choice(N, size=int(m0[w] * N), replace=False)
                S0[w, idx] = ref_spins[w, idx]

    if mode == 'single_chain':
        S0 = S0[0]
    return S0