Add JAX implementation of spin glass model with batched processing
This commit is contained in:
170
spin_glass_jax.py
Normal file
170
spin_glass_jax.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from functools import partial
|
||||
from load_tinystories import (
|
||||
tokenize_with_punctuation,
|
||||
load_encodings,
|
||||
BiDict
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
class SpinGlassJAX:
    """Spin-glass (Hopfield-style) sequence model with JAX-batched energy evaluation.

    A sequence of M words, each encoded as N bits, is flattened into a
    (M*N, 1) column vector. Training sequences define a weight matrix
    W = mean(s s^T); candidate continuations are scored by the quadratic
    energy -0.5 * s^T W s and sampled via a Boltzmann distribution.
    """

    def __init__(self, M=100, N=13, temperature=1.0, batch_size=32):
        self.M = M                      # sequence length (words per window)
        self.N = N                      # bits per word
        self.temperature = temperature  # Boltzmann temperature for sampling
        self.batch_size = batch_size    # chunk size for energy evaluation
        # Fixed seed for reproducible sampling; split on every draw.
        self.key = jax.random.PRNGKey(0)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_weights(self, sequences):
        """Return the (M*N, M*N) weight matrix: mean outer product over the batch.

        Args:
            sequences: array of shape (num_sequences, M*N, 1).
        """
        return jnp.mean(
            jnp.matmul(sequences, jnp.swapaxes(sequences, 1, 2)),
            axis=0
        )

    @partial(jax.jit, static_argnums=(0,))
    def _compute_energy(self, sequence, W):
        """Return the scalar spin-glass energy -0.5 * s^T W s for one sequence."""
        return -0.5 * jnp.squeeze(sequence.T @ W @ sequence)

    @partial(jax.jit, static_argnums=(0,))
    def _raw_batch_energies(self, sequences, W):
        """Return un-normalized energies, shape (batch,), for a batch of sequences.

        Kept separate from normalization so energies from several batches can
        be concatenated and normalized once, globally — normalizing per batch
        would make the probabilities depend on batch boundaries.
        """
        return jax.vmap(lambda s: self._compute_energy(s, W))(sequences)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_batch_energies(self, sequences, W):
        """Compute per-batch-normalized (energies, probs) for a batch.

        NOTE(review): normalization here is local to the given batch; prefer
        `_raw_batch_energies` + a single `_normalize_energies` call when
        results from multiple batches must be comparable.
        """
        energies = self._raw_batch_energies(sequences, W)
        return self._normalize_energies(energies)

    @partial(jax.jit, static_argnums=(0,))
    def _normalize_energies(self, energies):
        """Shift/scale energies into [0, 1] and return (energies, Boltzmann probs).

        The 1e-10 guard avoids division by zero when all energies are equal.
        """
        energies = energies - jnp.min(energies)
        energies = energies / (jnp.max(energies) + 1e-10)
        probs = jnp.exp(-energies / self.temperature)
        return energies, probs / jnp.sum(probs)

    def prepare_sequences(self, encoded_stories):
        """Convert bit-encoded stories into a (num_windows, M*N, 1) JAX array.

        Every length-M sliding window of every sufficiently long story becomes
        one training sequence; each word contributes N bits.
        """
        sequences = []
        for story in tqdm(encoded_stories, desc="Processing stories"):
            if len(story) >= self.M:
                for i in range(len(story) - self.M + 1):
                    word_group = story[i:i + self.M]
                    # Flatten the M words' bit strings into one int list.
                    bits = [int(bit) for word in word_group for bit in word]
                    sequences.append(bits)

        sequences = jnp.array(sequences)
        return sequences.reshape(-1, self.M * self.N, 1)

    def predict_next(self, partial_sequence, vocab, training_sequences):
        """Sample the next word given a bit-string prefix of (M-1) words.

        Args:
            partial_sequence: concatenated bit string of the current context.
            vocab: mapping word -> bit string (length N).
            training_sequences: array of shape (num_windows, M*N, 1).

        Returns:
            (word, energy): the sampled word and its normalized energy.

        Raises:
            ValueError: if no vocab word yields a complete M*N-bit sequence.
        """
        # Build candidate words and their completed bit vectors TOGETHER so
        # indices stay aligned even if some vocab entries are skipped.
        # (The original code indexed into the full vocab list with an index
        # drawn over a possibly shorter candidate list.)
        candidate_words = []
        candidate_vecs = []
        for word, bits in vocab.items():
            complete_sequence = partial_sequence + bits
            if len(complete_sequence) == self.M * self.N:
                candidate_words.append(word)
                candidate_vecs.append([int(bit) for bit in complete_sequence])

        if not candidate_vecs:
            raise ValueError(
                "No vocabulary word completes the partial sequence to length "
                f"{self.M * self.N}"
            )

        candidates = jnp.array(candidate_vecs).reshape(-1, self.M * self.N, 1)

        # Compute weights once for all candidates.
        W = self._compute_weights(training_sequences)

        # Raw energies per batch; normalize ONCE over all candidates so the
        # probabilities do not depend on batch boundaries.
        raw_energies = []
        for i in range(0, len(candidates), self.batch_size):
            batch = candidates[i:i + self.batch_size]
            raw_energies.append(self._raw_batch_energies(batch, W))
        energies, probs = self._normalize_energies(jnp.concatenate(raw_energies))

        # Sample next word; index range matches the candidate list exactly.
        self.key, subkey = jax.random.split(self.key)
        selected_idx = int(jax.random.choice(subkey, len(candidate_words), p=probs))

        return candidate_words[selected_idx], float(energies[selected_idx])
|
||||
|
||||
def main():
    """Interactive driver: load encodings, train the spin-glass model, and
    generate a 10-word continuation of a user-supplied story prefix.

    Requires `load_tinystories.py` to have been run first so that saved
    encodings exist on disk.
    """
    # Load saved encodings (word -> bit-string vocab plus encoded stories).
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Initialize model
    model = SpinGlassJAX(M=100, N=13, temperature=1.0, batch_size=32)

    # Prepare training sequences
    print("Preparing training sequences...")
    training_sequences = model.prepare_sequences(encoded_stories)
    print(f"Prepared {len(training_sequences)} sequences")

    # Get input from user
    print("\nEnter your story:")
    sentence = input("Enter a sentence (at least 99 words): ")
    initial_tokens = tokenize_with_punctuation(sentence)

    # The model needs M-1 context words to predict word M.
    if len(initial_tokens) < model.M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {model.M-1}.")
        return

    # Guard against tokens the vocabulary cannot encode; without this,
    # vocab[token] below raises an uncaught KeyError mid-generation.
    unknown = sorted({t for t in initial_tokens[:model.M - 1] if t not in vocab})
    if unknown:
        print(f"Cannot encode these words (not in vocabulary): {', '.join(unknown)}")
        return

    # Predict sequence
    print("\nPredicting continuation...")
    current_tokens = initial_tokens[:model.M - 1]
    predictions = []
    energies = []

    D = 10  # Number of words to predict
    for _ in tqdm(range(D), desc="Generating words"):
        # Convert current tokens to one concatenated binary sequence.
        partial_sequence = "".join(vocab[token] for token in current_tokens)

        # Predict next word
        predicted_word, energy = model.predict_next(
            partial_sequence,
            vocab,
            training_sequences
        )

        predictions.append(predicted_word)
        energies.append(energy)

        # Slide the context window forward by one word.
        current_tokens = current_tokens[1:] + [predicted_word]

    # Print results
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predictions))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predictions, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")
|
||||
|
||||
# Entry-point guard: run the interactive demo only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user