diff --git a/old/dense_associative_memory.py b/old/dense_associative_memory.py
new file mode 100644
index 0000000..4c8aec0
--- /dev/null
+++ b/old/dense_associative_memory.py
@@ -0,0 +1,260 @@
+import numpy as np
+import jax.numpy as jnp
+from functools import partial
+import jax
+from tqdm import tqdm
+from load_tinystories import (
+    tokenize_with_punctuation,
+    load_encodings,
+    BiDict
+)
+
+class DenseAssociativeMemory:
+    def __init__(self, M=20, N=20, temperature=0.1, batch_size=16, degree=3, h=0.01):
+        """M words per window, N bits per word, polynomial interactions up
+        to order `degree`, magnetic field strength `h`."""
+        self.M = M
+        self.N = N
+        self.temperature = temperature
+        self.batch_size = batch_size
+        self.degree = degree
+        self.h = h
+        self.key = jax.random.PRNGKey(0)
+
+    @partial(jax.jit, static_argnums=(0,))
+    def _compute_polynomial_interaction(self, batch):
+        """Compute interaction matrices for a single batch."""
+        interactions = []
+        mean_pattern = jnp.mean(batch, axis=0, keepdims=True)
+        centered_batch = batch - mean_pattern
+
+        for d in range(2, self.degree + 1):
+            # Smaller scaling factor for higher orders
+            scale = 1.0 / (d * len(batch))
+            interaction = jnp.zeros((self.M * self.N, self.M * self.N))
+
+            def process_seq(interaction, seq):
+                term = jnp.outer(seq.flatten(), seq.flatten())
+                # Clip, then apply a sign-preserving fractional power: a plain
+                # jnp.power(term, d/4) returns NaN for negative entries
+                # whenever d/4 is not an integer.
+                term = jnp.clip(term, -1.0, 1.0)
+                term = jnp.sign(term) * jnp.power(jnp.abs(term), d / 4)  # reduced power to prevent overflow
+                return interaction + term * scale
+
+            interaction = jax.lax.fori_loop(
+                0, len(centered_batch),
+                lambda i, acc: process_seq(acc, centered_batch[i]),
+                interaction
+            )
+
+            interactions.append(interaction)
+
+        return interactions, mean_pattern
+
+    @partial(jax.jit, static_argnums=(0,))
+    def _compute_energy(self, sequence, interactions, mean_pattern, h_field):
+        """Compute energy with polynomial interactions."""
+        # Ensure both vectors have the same column shape
+        sequence = sequence.reshape(-1, 1)
+        mean_pattern = mean_pattern.reshape(-1, 1)
+
+        # Center and normalize the sequence
+        centered_seq = sequence - mean_pattern
+        norm = jnp.linalg.norm(centered_seq) + 1e-8
+        centered_seq = centered_seq / norm
+
+        energy = 0.0
+        for d, interaction in enumerate(interactions, 2):
+            # Row vector @ matrix @ column vector yields a 1x1 energy term
+            seq_flat = centered_seq.reshape(1, -1)
+            term = seq_flat @ interaction @ centered_seq
+            term = jnp.clip(term, -100.0, 100.0)  # prevent overflow
+            energy += -jnp.sum(term) / (d * d)    # damp higher orders more strongly
+
+        # Add the scaled magnetic field term
+        field_term = -jnp.sum(h_field * sequence) * 0.1
+        energy = energy + field_term
+
+        # Clip the final energy to prevent extreme values
+        return jnp.clip(energy, -100.0, 100.0)
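+
+    # In equation form, the energy computed above for the centered,
+    # normalized pattern s and the order-d interaction matrices J_d is
+    #
+    #     E(s) = -sum_{d=2}^{degree} (s^T J_d s) / d^2  -  0.1 * <h_field, sequence>
+    #
+    # so lower energy means the candidate sequence aligns better with the
+    # correlations stored in the J_d matrices.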
+
+    def prepare_sequences(self, encoded_stories):
+        """Slice stories into overlapping M-word windows of bits."""
+        sequences = []
+        max_sequences = 20000  # cap on the total number of sequences
+
+        for story in tqdm(encoded_stories[:1000], desc="Processing stories"):
+            if len(story) >= self.M:
+                # Take several overlapping windows from each story
+                step = max(1, (len(story) - self.M) // 5)
+                for i in range(0, len(story) - self.M + 1, step):
+                    if len(sequences) >= max_sequences:
+                        break
+
+                    word_group = story[i:i + self.M]
+                    bits = []
+                    for word in word_group:
+                        bits.extend([int(bit) for bit in word])
+                    sequences.append(bits)
+
+            if len(sequences) >= max_sequences:
+                break
+
+        # Convert to a JAX array of column vectors
+        print(f"\nCollected {len(sequences)} sequences")
+        sequences = jnp.array(sequences[:max_sequences]).reshape(-1, self.M * self.N, 1)
+        return sequences
+
+    def compute_interactions(self, sequences):
+        """Compute interaction matrices in batches and average them."""
+        all_interactions = None
+        mean_pattern = None
+        num_batches = 0
+
+        for i in tqdm(range(0, len(sequences), self.batch_size), desc="Computing interactions"):
+            batch = sequences[i:i + self.batch_size]
+            batch_interactions, batch_mean = self._compute_polynomial_interaction(batch)
+
+            if all_interactions is None:
+                all_interactions = [jnp.zeros_like(inter) for inter in batch_interactions]
+                mean_pattern = jnp.zeros_like(batch_mean)
+
+            # Accumulate interactions and the mean, once per batch
+            for j, inter in enumerate(batch_interactions):
+                all_interactions[j] += inter
+            mean_pattern += batch_mean
+            num_batches += 1
+
+        # Average the accumulated values
+        all_interactions = [inter / num_batches for inter in all_interactions]
+        mean_pattern = mean_pattern / num_batches
+
+        # Create the magnetic field
+        h_field = self.h * jnp.ones((self.M * self.N, 1))
+
+        return all_interactions, h_field, mean_pattern
+
+    def predict_next(self, partial_sequence, vocab, interactions, h_field, mean_pattern):
+        """Predict the next word using DAM dynamics."""
+        possible_words = list(vocab.values())
+        word_energies = []
+
+        for word in possible_words:
+            complete_sequence = partial_sequence + word
+            if len(complete_sequence) == self.M * self.N:
+                # Build a column vector from the candidate bit string
+                complete_vec = jnp.array([int(bit) for bit in complete_sequence]).reshape(-1, 1)
+                try:
+                    energy = float(self._compute_energy(complete_vec, interactions, mean_pattern, h_field))
+                    if not (jnp.isnan(energy) or jnp.isinf(energy)):
+                        word_energies.append((word, energy))
+                except Exception:
+                    continue  # skip candidates whose energy cannot be evaluated
+
+        if not word_energies:
+            # Fallback: return a random word if all energies are invalid
+            self.key, subkey = jax.random.split(self.key)
+            random_idx = int(jax.random.randint(subkey, (), 0, len(possible_words)))
+            return vocab[possible_words[random_idx]], 0.0
+
+        # Sort by energy, lowest first
+        word_energies.sort(key=lambda x: x[1])
+
+        # Take the top-k candidates
+        k = min(10, len(word_energies))
+        top_k = word_energies[:k]
+
+        # Normalize energies to [0, 1] for numerical stability
+        energies = jnp.array([e[1] for e in top_k])
+        energies = energies - jnp.min(energies)
+        max_energy = jnp.max(energies)
+        if max_energy > 1e-8:
+            energies = energies / max_energy
+
+        # Sample with a temperature-scaled softmax over negative energies
+        probs = jnp.exp(-energies / self.temperature)
+        probs = probs / (jnp.sum(probs) + 1e-8)
+
+        self.key, subkey = jax.random.split(self.key)
+        selected_idx = int(jax.random.choice(subkey, k, p=probs))
+        best_word, min_energy = top_k[selected_idx]
+
+        # Map the selected bit vector back to its word (vocab is a BiDict
+        # of word -> bit string)
+        for word, vector in vocab.items():
+            if vector == best_word:
+                return word, min_energy
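+
+# Example usage, as a sketch: `vocab` and `encoded_stories` are assumed to
+# come from load_tinystories.load_encodings, with each vocab value an N-bit
+# string, and `prefix_bits` stands for the concatenated bit string of M-1
+# known words.
+#
+#     model = DenseAssociativeMemory(M=4, N=8, degree=3)
+#     sequences = model.prepare_sequences(encoded_stories)
+#     interactions, h_field, mean_pattern = model.compute_interactions(sequences)
+#     word, energy = model.predict_next(prefix_bits, vocab, interactions,
+#                                       h_field, mean_pattern)
+#
+# predict_next scores every vocabulary word by appending its bits to the
+# prefix, keeps the lowest-energy top-k candidates, and samples one with a
+# softmax at the configured temperature.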
+
+def main():
+    # Load saved encodings
+    vocab, encoded_stories, original_stories = load_encodings()
+    if vocab is None:
+        print("No saved encodings found. Please run load_tinystories.py first.")
+        return
+
+    # Initialize the model
+    model = DenseAssociativeMemory(
+        M=32,             # sequence length in words
+        N=32,             # bits per word
+        temperature=0.1,
+        batch_size=32,
+        degree=10,        # interactions up to 10th order
+        h=0.01
+    )
+
+    # Prepare sequences
+    print("Preparing sequences...")
+    sequences = model.prepare_sequences(encoded_stories)
+    print(f"Prepared {len(sequences)} sequences")
+
+    # Compute interactions
+    print("Computing interactions...")
+    interactions, h_field, mean_pattern = model.compute_interactions(sequences)
+
+    # Get input from the user
+    print("\nEnter your story:")
+    sentence = input(f"Enter a sentence (at least {model.M - 1} words): ")
+    initial_tokens = tokenize_with_punctuation(sentence)
+
+    if len(initial_tokens) < model.M - 1:
+        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {model.M - 1}.")
+        return
+
+    # Predict the continuation
+    print("\nPredicting continuation...")
+    current_tokens = initial_tokens[:model.M - 1]
+    predictions = []
+    energies = []
+
+    D = 10  # number of words to predict
+    for _ in tqdm(range(D), desc="Generating words"):
+        # Convert the current tokens to a binary sequence
+        partial_sequence = ""
+        for token in current_tokens:
+            partial_sequence += vocab[token]
+
+        # Predict the next word
+        predicted_word, energy = model.predict_next(
+            partial_sequence,
+            vocab,
+            interactions,
+            h_field,
+            mean_pattern
+        )
+
+        predictions.append(predicted_word)
+        energies.append(energy)
+
+        # Slide the window: drop the oldest token, append the prediction
+        current_tokens = current_tokens[1:] + [predicted_word]
+
+    # Print results
+    print("\nYour input ended with:")
+    print(" ".join(initial_tokens[-10:]))
+    print("\nPredicted continuation:")
+    print(" ".join(predictions))
+    print("\nEnergies of predictions:")
+    for i, (word, energy) in enumerate(zip(predictions, energies)):
+        print(f"Word {i+1} ('{word}'): {energy:.4f}")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/load_tinystories.py b/old/load_tinystories.py
similarity index 99%
rename from load_tinystories.py
rename to old/load_tinystories.py
index 606b42c..498996c 100644
--- a/load_tinystories.py
+++ b/old/load_tinystories.py
@@ -354,9 +354,9 @@ def predict_sequence(initial_sequence, vocab, sequences, W, D=10, M=100, N=12, t
 if __name__ == "__main__":
     N = 20  # Define N as a constant
-    M = 100  # Define M as a constant
+    M = 30  # Define M as a constant
     D = 10  # Number of words to predict
-    temperature = 1
+    temperature = 0.01
     batch_size = 50  # Added batch size parameter
diff --git a/predict_story.py b/old/predict_story.py
similarity index 100%
rename from predict_story.py
rename to old/predict_story.py
diff --git a/sentences.txt b/old/sentences.txt
similarity index 100%
rename from sentences.txt
rename to old/sentences.txt
diff --git a/spin_glass.py b/old/spin_glass.py
similarity index 100%
rename from spin_glass.py
rename to old/spin_glass.py
diff --git a/spin_glass_gpu.py b/old/spin_glass_gpu.py
similarity index 100%
rename from spin_glass_gpu.py
rename to old/spin_glass_gpu.py
diff --git a/spin_glass_jax.py b/old/spin_glass_jax.py
similarity index 100%
rename from spin_glass_jax.py
rename to old/spin_glass_jax.py