commit 729ae0b8c78f9bc2b566eea292f34de906f3ce04
Author: alireza
Date:   Wed Feb 19 13:44:33 2025 +0330

    Add JAX implementation of spin glass model with batched processing

diff --git a/spin_glass_jax.py b/spin_glass_jax.py
new file mode 100644
index 0000000..e5bcd27
--- /dev/null
+++ b/spin_glass_jax.py
@@ -0,0 +1,170 @@
+import jax
+import jax.numpy as jnp
+from functools import partial
+from load_tinystories import (
+    tokenize_with_punctuation,
+    load_encodings,
+    BiDict
+)
+from tqdm import tqdm
+
+
+class SpinGlassJAX:
+    def __init__(self, M=100, N=13, temperature=1.0, batch_size=32):
+        self.M = M  # words per sequence window
+        self.N = N  # bits per word
+        self.temperature = temperature
+        self.batch_size = batch_size
+        self.key = jax.random.PRNGKey(0)
+
+    @partial(jax.jit, static_argnums=(0,))
+    def _compute_weights(self, sequences):
+        """Compute the coupling matrix as the mean outer product of the sequences."""
+        # sequences: (num_sequences, M*N, 1) -> W: (M*N, M*N)
+        return jnp.mean(
+            jnp.matmul(sequences, jnp.swapaxes(sequences, 1, 2)),
+            axis=0
+        )
+
+    @partial(jax.jit, static_argnums=(0,))
+    def _compute_energy(self, sequence, W):
+        """Compute the energy E = -1/2 * s^T W s of a single sequence."""
+        return -0.5 * jnp.squeeze(sequence.T @ W @ sequence)
+
+    @partial(jax.jit, static_argnums=(0,))
+    def _compute_batch_energies(self, sequences, W):
+        """Compute raw energies for a batch of sequences.
+
+        Normalization is deferred to the caller so that energies from
+        different batches stay on a common scale.
+        """
+        # sequences: (batch_size, M*N, 1)
+        return jax.vmap(lambda s: self._compute_energy(s, W))(sequences)
+
+    @partial(jax.jit, static_argnums=(0,))
+    def _normalize_energies(self, energies):
+        """Min-max normalize energies and convert them to Boltzmann probabilities."""
+        energies = energies - jnp.min(energies)
+        energies = energies / (jnp.max(energies) + 1e-10)
+        probs = jnp.exp(-energies / self.temperature)
+        return energies, probs / jnp.sum(probs)
+
+    def prepare_sequences(self, encoded_stories):
+        """Build sliding-window training sequences and stack them into a JAX array."""
+        sequences = []
+
+        for story in tqdm(encoded_stories, desc="Processing stories"):
+            if len(story) >= self.M:
+                for i in range(len(story) - self.M + 1):
+                    word_group = story[i:i + self.M]
+                    bits = []
+                    for word in word_group:
+                        bits.extend([int(bit) for bit in word])
+                    sequences.append(bits)
+
+        # Convert to a JAX array of shape (num_sequences, M*N, 1)
+        sequences = jnp.array(sequences)
+        return sequences.reshape(-1, self.M * self.N, 1)
+
+    def predict_next(self, partial_sequence, vocab, training_sequences):
+        """Predict the next word given a partial bit sequence."""
+        # Build a complete candidate sequence for every vocabulary encoding
+        candidate_encodings = []
+        complete_sequences = []
+        for encoding in vocab.values():
+            complete_sequence = partial_sequence + encoding
+            if len(complete_sequence) == self.M * self.N:
+                candidate_encodings.append(encoding)
+                complete_sequences.append([int(bit) for bit in complete_sequence])
+
+        # Convert to a JAX array of shape (num_candidates, M*N, 1)
+        complete_sequences = jnp.array(complete_sequences).reshape(-1, self.M * self.N, 1)
+
+        # Compute the coupling matrix once per call
+        W = self._compute_weights(training_sequences)
+
+        # Compute raw energies batch by batch, then normalize over all
+        # candidates at once so probabilities are comparable across batches
+        all_energies = []
+        for i in range(0, len(complete_sequences), self.batch_size):
+            batch = complete_sequences[i:i + self.batch_size]
+            all_energies.append(self._compute_batch_energies(batch, W))
+
+        energies, probs = self._normalize_energies(jnp.concatenate(all_energies))
+
+        # Sample the next word from the Boltzmann distribution
+        self.key, subkey = jax.random.split(self.key)
+        selected_idx = int(jax.random.choice(subkey, len(candidate_encodings), p=probs))
+
+        best_encoding = candidate_encodings[selected_idx]
+        selected_energy = float(energies[selected_idx])
+
+        # Map the selected encoding back to its word
+        for word, encoding in vocab.items():
+            if encoding == best_encoding:
+                return word, selected_energy
+
+
+def main():
+    # Load saved encodings
+    vocab, encoded_stories, original_stories = load_encodings()
+    if vocab is None:
+        print("No saved encodings found. Please run load_tinystories.py first.")
+        return
+
+    # Initialize model
+    model = SpinGlassJAX(M=100, N=13, temperature=1.0, batch_size=32)
+
+    # Prepare training sequences
+    print("Preparing training sequences...")
+    training_sequences = model.prepare_sequences(encoded_stories)
+    print(f"Prepared {len(training_sequences)} sequences")
+
+    # Get input from user
+    print("\nEnter your story:")
+    sentence = input(f"Enter a sentence (at least {model.M - 1} words): ")
+    initial_tokens = tokenize_with_punctuation(sentence)
+
+    if len(initial_tokens) < model.M - 1:
+        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {model.M - 1}.")
+        return
+
+    # Predict sequence
+    print("\nPredicting continuation...")
+    current_tokens = initial_tokens[:model.M - 1]
+    predictions = []
+    energies = []
+
+    D = 10  # Number of words to predict
+    for _ in tqdm(range(D), desc="Generating words"):
+        # Convert the current tokens to a binary sequence
+        partial_sequence = "".join(vocab[token] for token in current_tokens)
+
+        # Predict the next word
+        predicted_word, energy = model.predict_next(
+            partial_sequence,
+            vocab,
+            training_sequences
+        )
+
+        predictions.append(predicted_word)
+        energies.append(energy)
+
+        # Slide the window forward by one word
+        current_tokens = current_tokens[1:] + [predicted_word]
+
+    # Print results
+    print("\nYour input ended with:")
+    print(" ".join(initial_tokens[-10:]))
+    print("\nPredicted continuation:")
+    print(" ".join(predictions))
+    print("\nEnergies of predictions:")
+    for i, (word, energy) in enumerate(zip(predictions, energies)):
+        print(f"Word {i+1} ('{word}'): {energy:.4f}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file