import jax
import jax.numpy as jnp
from functools import partial

from load_tinystories import (
    tokenize_with_punctuation,
    load_encodings,
    BiDict,
)
from tqdm import tqdm


class SpinGlassJAX:
    """Spin-glass-style language model.

    Hebbian weights are learned as the mean outer product of M-word windows
    (each word encoded as an N-bit string), and the next word is sampled from
    a Boltzmann distribution over candidate completions' Ising energies.

    NOTE(review): the bit codes are used as {0, 1} spins, not the conventional
    {-1, +1} — presumably intentional, but worth confirming against the
    encoder in load_tinystories.
    """

    def __init__(self, M=100, N=13, temperature=1.0, batch_size=32):
        self.M = M                      # words per context window
        self.N = N                      # bits per word code
        self.temperature = temperature  # Boltzmann temperature for sampling
        self.batch_size = batch_size    # candidates scored per jitted call
        self.key = jax.random.PRNGKey(0)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_weights(self, sequences):
        """Hebbian weight matrix: mean outer product of training sequences.

        Args:
            sequences: (num_seqs, M*N, 1) array of 0/1 spins.

        Returns:
            (M*N, M*N) weight matrix.
        """
        return jnp.mean(
            jnp.matmul(sequences, jnp.swapaxes(sequences, 1, 2)), axis=0
        )

    @partial(jax.jit, static_argnums=(0,))
    def _compute_energy(self, sequence, W):
        """Ising energy -1/2 * s^T W s for one (M*N, 1) column vector."""
        return -0.5 * jnp.squeeze(sequence.T @ W @ sequence)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_raw_energies(self, sequences, W):
        """Raw (unnormalized) energies for a batch of (batch, M*N, 1) spins."""
        return jax.vmap(lambda s: self._compute_energy(s, W))(sequences)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_batch_energies(self, sequences, W):
        """Normalized energies and probabilities for ONE batch.

        NOTE: normalization is relative to this batch only. When scoring
        candidates split across several batches, use _compute_raw_energies
        per batch and call _normalize_energies once on the concatenated
        result (as predict_next does) so scores stay comparable.
        """
        raw = self._compute_raw_energies(sequences, W)
        return self._normalize_energies(raw)

    @partial(jax.jit, static_argnums=(0,))
    def _normalize_energies(self, energies):
        """Min-max normalize energies to [0, 1]; return Boltzmann probs.

        Returns:
            (energies, probs): normalized energies and a probability vector
            summing to 1. The 1e-10 guards against division by zero when all
            energies are equal.
        """
        energies = energies - jnp.min(energies)
        energies = energies / (jnp.max(energies) + 1e-10)
        probs = jnp.exp(-energies / self.temperature)
        return energies, probs / jnp.sum(probs)

    def prepare_sequences(self, encoded_stories):
        """Slide an M-word window over each story and flatten codes to bits.

        Args:
            encoded_stories: iterable of stories; each story is a sequence of
                N-character bit-string word codes.

        Returns:
            (num_windows, M*N, 1) JAX array of 0/1 values.
        """
        sequences = []
        for story in tqdm(encoded_stories, desc="Processing stories"):
            if len(story) < self.M:
                continue  # too short to contain a full window
            for i in range(len(story) - self.M + 1):
                bits = []
                for code in story[i:i + self.M]:
                    bits.extend(int(bit) for bit in code)
                sequences.append(bits)
        # reshape is safe even for zero windows: (0, M*N, 1).
        return jnp.array(sequences).reshape(-1, self.M * self.N, 1)

    def predict_next(self, partial_sequence, vocab, training_sequences):
        """Sample the next word given a partial bit-string context.

        Args:
            partial_sequence: concatenated bit codes of the (M-1)-word context.
            vocab: mapping token -> N-bit code string.
            training_sequences: (num_seqs, M*N, 1) array for weight estimation.

        Returns:
            (token, energy): the sampled token and its normalized energy.

        Raises:
            ValueError: if no vocabulary code completes the context to exactly
                M*N bits.
        """
        target_len = self.M * self.N

        # Keep candidate tokens and their bit vectors in ONE aligned list.
        # (Previously the sampled index was applied to the full vocab list
        # while probabilities were computed over the filtered candidates,
        # which could misalign if any entry was skipped.)
        candidate_words = []
        candidate_vecs = []
        for token, code in vocab.items():
            complete = partial_sequence + code
            if len(complete) == target_len:
                candidate_words.append(token)
                candidate_vecs.append([int(bit) for bit in complete])

        if not candidate_vecs:
            raise ValueError(
                f"No vocabulary entry completes the context to {target_len} "
                f"bits (context has {len(partial_sequence)} bits)."
            )

        candidates = jnp.array(candidate_vecs).reshape(-1, target_len, 1)

        # Weights depend only on the training data; computed once per call.
        W = self._compute_weights(training_sequences)

        # Raw energies per batch, then ONE global normalization so that
        # probabilities are comparable across batches. (Normalizing each
        # batch independently gave every batch equal probability mass
        # regardless of its candidates' energies.)
        raw_batches = [
            self._compute_raw_energies(candidates[i:i + self.batch_size], W)
            for i in range(0, len(candidates), self.batch_size)
        ]
        energies, probs = self._normalize_energies(jnp.concatenate(raw_batches))

        # Boltzmann sampling of the next token.
        self.key, subkey = jax.random.split(self.key)
        selected_idx = int(
            jax.random.choice(subkey, len(candidate_words), p=probs)
        )
        return candidate_words[selected_idx], float(energies[selected_idx])


def main():
    """Interactive driver: load encodings, train weights, generate 10 words."""
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    model = SpinGlassJAX(M=100, N=13, temperature=1.0, batch_size=32)

    print("Preparing training sequences...")
    training_sequences = model.prepare_sequences(encoded_stories)
    print(f"Prepared {len(training_sequences)} sequences")

    print("\nEnter your story:")
    sentence = input("Enter a sentence (at least 99 words): ")
    initial_tokens = tokenize_with_punctuation(sentence)

    if len(initial_tokens) < model.M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {model.M-1}.")
        return

    # Fail early with a clear message instead of a KeyError mid-generation.
    unknown = [t for t in initial_tokens if t not in vocab]
    if unknown:
        print(f"Tokens not in vocabulary: {unknown[:10]}")
        return

    print("\nPredicting continuation...")
    current_tokens = initial_tokens[:model.M - 1]
    predictions = []
    energies = []
    D = 10  # number of words to generate

    for _ in tqdm(range(D), desc="Generating words"):
        # Concatenate the context window's bit codes in one pass.
        partial_sequence = "".join(vocab[token] for token in current_tokens)

        predicted_word, energy = model.predict_next(
            partial_sequence, vocab, training_sequences
        )
        predictions.append(predicted_word)
        energies.append(energy)

        # Slide the window: drop the oldest token, append the prediction.
        current_tokens = current_tokens[1:] + [predicted_word]

    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predictions))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predictions, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")


if __name__ == "__main__":
    main()