"""Dense associative memory language model over TinyStories bit encodings."""
import numpy as np
import jax.numpy as jnp
from functools import partial
import jax
from tqdm import tqdm
from load_tinystories import (
    tokenize_with_punctuation,
    load_encodings,
    BiDict
)


class DenseAssociativeMemory:
    def __init__(self, M=20, N=20, temperature=0.1, batch_size=16, degree=3, h=0.01):
        """Dense associative memory over windows of M words, each encoded as N bits."""
        self.M = M
        self.N = N
        self.temperature = temperature
        self.batch_size = batch_size
        self.degree = degree
        self.h = h
        self.key = jax.random.PRNGKey(0)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_polynomial_interaction(self, batch):
        """Compute polynomial interaction tensors for a single batch."""
        interactions = []
        mean_pattern = jnp.mean(batch, axis=0, keepdims=True)
        centered_batch = batch - mean_pattern

        for d in range(2, self.degree + 1):
            # Use a smaller scaling factor for higher orders
            scale = 1.0 / (d * len(batch))
            interaction = jnp.zeros((self.M * self.N, self.M * self.N))

            def process_seq(interaction, seq):
                term = jnp.outer(seq.flatten(), seq.flatten())
                # Clip, then apply a sign-preserving fractional power with a
                # reduced exponent: a plain jnp.power(term, d/4) returns NaN
                # for negative entries and a non-integer exponent.
                term = jnp.clip(term, -1.0, 1.0)
                term = jnp.sign(term) * jnp.power(jnp.abs(term), d / 4)
                return interaction + term * scale

            interaction = jax.lax.fori_loop(
                0, len(centered_batch),
                lambda i, acc: process_seq(acc, centered_batch[i]),
                interaction
            )
            interactions.append(interaction)

        return interactions, mean_pattern

    @partial(jax.jit, static_argnums=(0,))
    def _compute_energy(self, sequence, interactions, mean_pattern, h_field):
        """Compute energy with polynomial interactions."""
        # Ensure proper shapes: column vectors on both sides
        sequence = sequence.reshape(-1, 1)
        mean_pattern = mean_pattern.reshape(-1, 1)

        # Center and normalize the sequence
        centered_seq = sequence - mean_pattern
        norm = jnp.linalg.norm(centered_seq) + 1e-8
        centered_seq = centered_seq / norm

        energy = 0.0
        for d, interaction in enumerate(interactions, 2):
            # Row vector on the left for the quadratic form
            seq_flat = centered_seq.reshape(1, -1)
            # Compute the quadratic-form term with numerical stability
            term = seq_flat @ interaction @ centered_seq
            term = jnp.clip(term, -100.0, 100.0)  # prevent overflow
            energy += -jnp.sum(term) / (d * d)    # stronger damping for higher orders

        # Add a scaled magnetic field term
        field_term = -jnp.sum(h_field * sequence) * 0.1
        energy = energy + field_term

        # Clip final energy to prevent extreme values
        return jnp.clip(energy, -100.0, 100.0)

    def prepare_sequences(self, encoded_stories):
        """Slice stories into overlapping M-word windows of N-bit codes."""
        sequences = []
        max_sequences = 20000  # cap on the number of training windows

        # Process up to the first 1000 stories
        for story in tqdm(encoded_stories[:1000], desc="Processing stories"):
            if len(story) >= self.M:
                # Take several overlapping windows from each story
                step = max(1, (len(story) - self.M) // 5)
                for i in range(0, len(story) - self.M + 1, step):
                    if len(sequences) >= max_sequences:
                        break
                    word_group = story[i:i + self.M]
                    bits = []
                    for word in word_group:
                        bits.extend([int(bit) for bit in word])
                    sequences.append(bits)
            if len(sequences) >= max_sequences:
                break

        # Convert to a JAX array shaped (num_sequences, M*N, 1)
        print(f"\nCollected {len(sequences)} sequences")
        sequences = jnp.array(sequences[:max_sequences]).reshape(-1, self.M * self.N, 1)
        return sequences

    def compute_interactions(self, sequences):
        """Compute interaction tensors in batches and average them."""
        all_interactions = None
        mean_pattern = None
        num_batches = 0

        # Process in small batches
        for i in tqdm(range(0, len(sequences), self.batch_size),
                      desc="Computing interactions"):
            batch = sequences[i:i + self.batch_size]
            batch_interactions, batch_mean = self._compute_polynomial_interaction(batch)

            if all_interactions is None:
                all_interactions = [jnp.zeros_like(inter) for inter in batch_interactions]
                mean_pattern = jnp.zeros_like(batch_mean)

            # Accumulate interactions and mean
            for j, inter in enumerate(batch_interactions):
                all_interactions[j] += inter
            mean_pattern += batch_mean
            num_batches += 1

        # Average the accumulated values
        all_interactions = [inter / num_batches for inter in all_interactions]
        mean_pattern = mean_pattern / num_batches

        # Create the magnetic field
        h_field = self.h * jnp.ones((self.M * self.N, 1))

        return all_interactions, h_field, mean_pattern

    def predict_next(self, partial_sequence, vocab, interactions, h_field, mean_pattern):
        """Predict the next word by scoring every vocabulary completion."""
        possible_words = list(vocab.values())
        word_energies = []

        for word in possible_words:
            complete_sequence = partial_sequence + word
            if len(complete_sequence) == self.M * self.N:
                # Column-vector shape for the candidate sequence
                complete_vec = jnp.array(
                    [int(bit) for bit in complete_sequence]
                ).reshape(-1, 1)
                try:
                    energy = float(self._compute_energy(
                        complete_vec, interactions, mean_pattern, h_field
                    ))
                    if not (jnp.isnan(energy) or jnp.isinf(energy)):
                        word_energies.append((word, energy))
                except Exception:
                    continue  # skip candidates that fail to evaluate

        if not word_energies:
            # Fallback: return a random word if all energies are invalid.
            # vocab is a BiDict, so a bit-vector key maps back to its word.
            self.key, subkey = jax.random.split(self.key)
            random_idx = int(jax.random.randint(subkey, (), 0, len(possible_words)))
            return vocab[possible_words[random_idx]], 0.0

        # Sort by energy (lower is better)
        word_energies.sort(key=lambda x: x[1])

        # Take the top-k candidates
        k = min(10, len(word_energies))
        top_k = word_energies[:k]

        # Normalize energies with numerical stability
        energies = jnp.array([e[1] for e in top_k])
        energies = energies - jnp.min(energies)
        max_energy = jnp.max(energies)
        if max_energy > 1e-8:
            energies = energies / max_energy

        # Sample using a stable softmax over negative energies
        probs = jnp.exp(-energies / self.temperature)
        probs = probs / (jnp.sum(probs) + 1e-8)

        self.key, subkey = jax.random.split(self.key)
        selected_idx = int(jax.random.choice(subkey, k, p=probs))
        best_word, min_energy = top_k[selected_idx]

        for word, vector in vocab.items():
            if vector == best_word:
                return word, min_energy
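# ---------------------------------------------------------------------------
# A minimal smoke test for the energy machinery (an addition for illustration,
# not part of the original pipeline). It stores a few random binary patterns
# and compares the energy of a stored pattern against fresh noise; the model
# sizes and the function name are assumptions chosen to keep the test fast.
def _energy_smoke_test():
    model = DenseAssociativeMemory(M=4, N=4, batch_size=2, degree=3)
    key = jax.random.PRNGKey(1)
    # Four random binary patterns, shaped (num, M*N, 1) like prepare_sequences
    patterns = jax.random.bernoulli(key, 0.5, (4, 16, 1)).astype(jnp.float32)
    interactions, h_field, mean_pattern = model.compute_interactions(patterns)
    stored = model._compute_energy(patterns[0], interactions, mean_pattern, h_field)
    noise = jax.random.bernoulli(
        jax.random.PRNGKey(2), 0.5, (16, 1)
    ).astype(jnp.float32)
    fresh = model._compute_energy(noise, interactions, mean_pattern, h_field)
    print(f"stored pattern: {float(stored):.4f}, random noise: {float(fresh):.4f}")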
def main():
    # Load saved encodings
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Initialize the model with adjusted parameters
    model = DenseAssociativeMemory(
        M=32,            # words per window
        N=32,            # bits per word code
        temperature=0.1,
        batch_size=32,   # sequences per interaction batch
        degree=10,       # highest polynomial interaction order (d = 2..degree)
        h=0.01
    )

    # Prepare sequences
    print("Preparing sequences...")
    sequences = model.prepare_sequences(encoded_stories)
    print(f"Prepared {len(sequences)} sequences")

    # Compute interactions
    print("Computing interactions...")
    interactions, h_field, mean_pattern = model.compute_interactions(sequences)

    # Get input from the user; the seed must fill an M-1 word window
    print("\nEnter your story:")
    sentence = input(f"Enter a sentence (at least {model.M - 1} words): ")

    initial_tokens = tokenize_with_punctuation(sentence)
    if len(initial_tokens) < model.M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {model.M - 1}.")
        return

    # Predict the continuation
    print("\nPredicting continuation...")
    current_tokens = initial_tokens[:model.M - 1]
    predictions = []
    energies = []
    D = 10  # number of words to predict

    for _ in tqdm(range(D), desc="Generating words"):
        # Convert the current tokens to a binary sequence
        partial_sequence = ""
        for token in current_tokens:
            partial_sequence += vocab[token]

        # Predict the next word
        predicted_word, energy = model.predict_next(
            partial_sequence, vocab, interactions, h_field, mean_pattern
        )
        predictions.append(predicted_word)
        energies.append(energy)

        # Slide the window: drop the oldest token, append the prediction
        current_tokens = current_tokens[1:] + [predicted_word]

    # Print results
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predictions))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predictions, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")
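# ---------------------------------------------------------------------------
# A tiny offline prediction example (illustrative, not from the original
# script). It builds a hypothetical three-word vocabulary with 4-bit codes and
# asks the model for one next-word prediction; real runs use the BiDict
# vocabulary produced by load_tinystories.
def _toy_prediction_example():
    model = DenseAssociativeMemory(M=3, N=4, batch_size=2, degree=3)
    vocab = {"the": "0001", "cat": "0010", "sat": "0100"}  # hypothetical codes
    # Two training windows of M = 3 words, i.e. M*N = 12 bits each
    seqs = jnp.array(
        [[int(b) for b in "000100100100"],   # "the cat sat"
         [int(b) for b in "001000010100"]],  # "cat the sat"
        dtype=jnp.float32,
    ).reshape(-1, 12, 1)
    interactions, h_field, mean_pattern = model.compute_interactions(seqs)
    # Seed with M-1 = 2 words ("the cat") and predict the third
    word, energy = model.predict_next(
        "00010010", vocab, interactions, h_field, mean_pattern
    )
    print(f"predicted {word!r} with energy {energy:.4f}")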
if __name__ == "__main__":
    main()
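# Usage sketch: run load_tinystories.py first so load_encodings() can find the
# saved vocabulary and encoded stories, then run this script directly. It
# builds interaction tensors over up to 20,000 training windows, prompts for a
# seed sentence of at least M-1 = 31 tokens, and prints a 10-word continuation
# with the energy of each predicted word.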