# cimai/spin_glass_jax.py
# Spin-glass language model over bit-encoded TinyStories tokens, using JAX.
import jax
import jax.numpy as jnp
from functools import partial
from load_tinystories import (
tokenize_with_punctuation,
load_encodings,
BiDict
)
from tqdm import tqdm
class SpinGlassJAX:
    """Spin-glass sequence model over bit-encoded tokens, accelerated with JAX.

    A "sequence" is M consecutive words, each encoded as an N-bit string,
    flattened into an (M*N, 1) spin vector. Weights are the batch-mean outer
    product of training sequences; candidate continuations are scored by the
    quadratic energy -1/2 s^T W s and sampled from a Boltzmann distribution
    at the configured temperature.
    """

    def __init__(self, M=100, N=13, temperature=1.0, batch_size=32):
        self.M = M  # sequence length (words per window)
        self.N = N  # bits per word
        self.temperature = temperature  # Boltzmann sampling temperature
        self.batch_size = batch_size  # candidates scored per JAX call
        self.key = jax.random.PRNGKey(0)  # PRNG state for sampling

    @partial(jax.jit, static_argnums=(0,))
    def _compute_weights(self, sequences):
        """Mean outer product over the batch: W = E[s s^T], shape (M*N, M*N)."""
        return jnp.mean(
            jnp.matmul(sequences, jnp.swapaxes(sequences, 1, 2)),
            axis=0
        )

    @partial(jax.jit, static_argnums=(0,))
    def _compute_energy(self, sequence, W):
        """Quadratic spin-glass energy -1/2 s^T W s for one (M*N, 1) sequence."""
        return -0.5 * jnp.squeeze(sequence.T @ W @ sequence)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_raw_energies(self, sequences, W):
        """Un-normalized energies for a (batch, M*N, 1) stack of sequences."""
        return jax.vmap(lambda s: self._compute_energy(s, W))(sequences)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_batch_energies(self, sequences, W):
        """Batch-locally normalized energies and probabilities.

        NOTE: kept for backward compatibility. Normalization here is local to
        the batch; predict_next normalizes once across ALL candidates instead,
        because per-batch normalization makes probabilities incomparable
        between batches.
        """
        return self._normalize_energies(self._compute_raw_energies(sequences, W))

    @partial(jax.jit, static_argnums=(0,))
    def _normalize_energies(self, energies):
        """Min-max normalize energies to [0, 1]; return (energies, probs).

        probs is the Boltzmann distribution exp(-E / temperature) over the
        normalized energies. The 1e-10 guards against division by zero when
        all energies are equal.
        """
        energies = energies - jnp.min(energies)
        energies = energies / (jnp.max(energies) + 1e-10)
        probs = jnp.exp(-energies / self.temperature)
        return energies, probs / jnp.sum(probs)

    def prepare_sequences(self, encoded_stories):
        """Slide an M-word window over each story and flatten to spin vectors.

        Args:
            encoded_stories: iterable of stories; each story is a list of
                N-character bit strings (one per word).

        Returns:
            jnp array of shape (num_windows, M*N, 1).
        """
        sequences = []
        for story in tqdm(encoded_stories, desc="Processing stories"):
            if len(story) >= self.M:  # skip stories shorter than one window
                for i in range(len(story) - self.M + 1):
                    word_group = story[i:i + self.M]
                    bits = []
                    for word in word_group:
                        bits.extend([int(bit) for bit in word])
                    sequences.append(bits)
        # Convert to JAX array and reshape to column vectors
        sequences = jnp.array(sequences)
        return sequences.reshape(-1, self.M * self.N, 1)

    def predict_next(self, partial_sequence, vocab, training_sequences):
        """Sample the next word given an (M-1)-word partial bit sequence.

        Args:
            partial_sequence: bit string of length (M-1)*N.
            vocab: mapping of word -> N-bit string.
            training_sequences: array (num_windows, M*N, 1) from
                prepare_sequences, used to build the weight matrix.

        Returns:
            (word, energy): the sampled word and its normalized energy in [0, 1].

        Raises:
            ValueError: if no vocab entry completes the sequence to M*N bits.
        """
        target_len = self.M * self.N
        # Keep words aligned with candidate rows, so the sampled index maps
        # back to the correct word even when some vocab entries are filtered
        # out (the original indexed into the unfiltered word list).
        candidate_words = []
        candidate_bits = []
        for word, vector in vocab.items():
            completed = partial_sequence + vector
            if len(completed) == target_len:
                candidate_words.append(word)
                candidate_bits.append([int(bit) for bit in completed])
        if not candidate_bits:
            raise ValueError(
                "No vocabulary entry completes the partial sequence to M*N bits"
            )
        candidates = jnp.array(candidate_bits).reshape(-1, target_len, 1)

        # Weight matrix is computed once per call from the training windows.
        W = self._compute_weights(training_sequences)

        # Score candidates in batches, but normalize ONCE over all raw
        # energies: normalizing per batch (as _compute_batch_energies does)
        # yields probabilities that are not comparable across batches.
        raw_energies = [
            self._compute_raw_energies(candidates[i:i + self.batch_size], W)
            for i in range(0, len(candidates), self.batch_size)
        ]
        energies, probs = self._normalize_energies(jnp.concatenate(raw_energies))

        # Boltzmann-sample a candidate index.
        self.key, subkey = jax.random.split(self.key)
        selected_idx = int(
            jax.random.choice(subkey, len(candidate_words), p=probs)
        )
        return candidate_words[selected_idx], float(energies[selected_idx])
def main():
    """Interactive driver: load encodings, build the model, and generate text."""
    vocab, stories_encoded, _stories_raw = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Model hyperparameters mirror the encoding: 100-word windows, 13 bits/word.
    model = SpinGlassJAX(M=100, N=13, temperature=1.0, batch_size=32)

    print("Preparing training sequences...")
    train_seqs = model.prepare_sequences(stories_encoded)
    print(f"Prepared {len(train_seqs)} sequences")

    # Collect the seed sentence from the user.
    print("\nEnter your story:")
    sentence = input("Enter a sentence (at least 99 words): ")
    tokens = tokenize_with_punctuation(sentence)
    if len(tokens) < model.M - 1:
        print(f"Sentence too short. Got {len(tokens)} tokens, need {model.M-1}.")
        return

    print("\nPredicting continuation...")
    window = tokens[:model.M - 1]  # rolling (M-1)-token context
    generated = []
    generated_energies = []
    num_words = 10  # number of words to predict

    for _ in tqdm(range(num_words), desc="Generating words"):
        # Encode the current context window as one long bit string.
        bit_context = "".join(vocab[tok] for tok in window)
        word, energy = model.predict_next(bit_context, vocab, train_seqs)
        generated.append(word)
        generated_energies.append(energy)
        # Slide the window forward by one token.
        window = window[1:] + [word]

    print("\nYour input ended with:")
    print(" ".join(tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(generated))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(generated, generated_energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")


if __name__ == "__main__":
    main()