heyyyy
old/dense_associative_memory.py (new file, 260 lines)
@@ -0,0 +1,260 @@
import numpy as np
import jax.numpy as jnp
from functools import partial
import jax
from tqdm import tqdm
from load_tinystories import (
    tokenize_with_punctuation,
    load_encodings,
    BiDict
)


class DenseAssociativeMemory:
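    """
    Dense associative memory over binary word encodings.

    Each pattern is a flattened (M*N)-bit vector built from N-bit word codes.
    Training accumulates one interaction matrix J_d per polynomial degree
    d = 2..degree, and (per _compute_energy below, ignoring clipping) a
    pattern x is scored as

        E(x) = -sum_d (x_c^T J_d x_c) / d^2  -  0.1 * sum(h * x),

    where x_c is x centered on the training mean and L2-normalised.
    Prediction completes a partial sequence with each candidate word and
    softmax-samples among the lowest-energy completions.
    """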
    def __init__(self, M=20, N=20, temperature=0.1, batch_size=16, degree=3, h=0.01):
        """M words per sequence, N bits per word; defaults keep sequence length and batch size small."""
        self.M = M
        self.N = N
        self.temperature = temperature
        self.batch_size = batch_size
        self.degree = degree
        self.h = h
        self.key = jax.random.PRNGKey(0)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_polynomial_interaction(self, batch):
        """Compute polynomial interaction matrices for a single batch."""
        interactions = []
        mean_pattern = jnp.mean(batch, axis=0, keepdims=True)
        centered_batch = batch - mean_pattern

        for d in range(2, self.degree + 1):
            # Use a smaller scaling factor for higher orders
            scale = 1.0 / (d * len(batch))
            interaction = jnp.zeros((self.M * self.N, self.M * self.N))

            def process_seq(interaction, seq):
                term = jnp.outer(seq.flatten(), seq.flatten())
                # Clip values and use a smaller exponent to prevent overflow
                term = jnp.clip(term, -1.0, 1.0)
                term = jnp.power(term, d / 4)  # reduced power
                return interaction + term * scale

            interaction = jax.lax.fori_loop(
                0, len(centered_batch),
                lambda i, acc: process_seq(acc, centered_batch[i]),
                interaction
            )

            interactions.append(interaction)

        return interactions, mean_pattern

    @partial(jax.jit, static_argnums=(0,))
    def _compute_energy(self, sequence, interactions, mean_pattern, h_field):
        """Compute energy with polynomial interactions."""
        # Ensure proper shapes
        sequence = sequence.reshape(-1, 1)          # column vector
        mean_pattern = mean_pattern.reshape(-1, 1)  # same shape as sequence

        # Center and normalize the sequence
        centered_seq = sequence - mean_pattern
        norm = jnp.linalg.norm(centered_seq) + 1e-8
        centered_seq = centered_seq / norm

        energy = 0.0
        for d, interaction in enumerate(interactions, 2):
            # Row vector for the first multiplication
            seq_flat = centered_seq.reshape(1, -1)
            # Quadratic form, clipped for numerical stability
            term = seq_flat @ interaction @ centered_seq
            term = jnp.clip(term, -100.0, 100.0)  # prevent overflow
            energy += -jnp.sum(term) / (d * d)    # stronger scaling for higher orders

        # Add scaled magnetic field term
        field_term = -jnp.sum(h_field * sequence) * 0.1  # scale down field contribution
        energy = energy + field_term

        # Clip final energy to prevent extreme values
        return jnp.clip(energy, -100.0, 100.0)

    def prepare_sequences(self, encoded_stories):
        """Slice stories into overlapping M-word windows of N-bit codes."""
        sequences = []
        max_sequences = 20000

        for story in tqdm(encoded_stories[:1000], desc="Processing stories"):
            if len(story) >= self.M:
                # Take several overlapping sequences from each story
                step = max(1, (len(story) - self.M) // 5)
                for i in range(0, len(story) - self.M + 1, step):
                    if len(sequences) >= max_sequences:
                        break

                    word_group = story[i:i + self.M]
                    bits = []
                    for word in word_group:
                        bits.extend([int(bit) for bit in word])
                    sequences.append(bits)

            if len(sequences) >= max_sequences:
                break

        # Convert to a JAX array of shape (num_sequences, M * N, 1)
        print(f"\nCollected {len(sequences)} sequences")
        sequences = jnp.array(sequences[:max_sequences]).reshape(-1, self.M * self.N, 1)
        return sequences

    def compute_interactions(self, sequences):
        """Compute interaction matrices in batches and average them."""
        all_interactions = None
        mean_pattern = None
        num_batches = 0

        # Process in small batches
        for i in tqdm(range(0, len(sequences), self.batch_size), desc="Computing interactions"):
            batch = sequences[i:i + self.batch_size]
            batch_interactions, batch_mean = self._compute_polynomial_interaction(batch)

            if all_interactions is None:
                all_interactions = [jnp.zeros_like(inter) for inter in batch_interactions]
                mean_pattern = jnp.zeros_like(batch_mean)

            # Accumulate interactions and mean
            for j, inter in enumerate(batch_interactions):
                all_interactions[j] += inter
            mean_pattern += batch_mean
            num_batches += 1

        # Average the accumulated values
        all_interactions = [inter / num_batches for inter in all_interactions]
        mean_pattern = mean_pattern / num_batches

        # Create magnetic field
        h_field = self.h * jnp.ones((self.M * self.N, 1))

        return all_interactions, h_field, mean_pattern

    def predict_next(self, partial_sequence, vocab, interactions, h_field, mean_pattern):
        """Predict the next word using DAM dynamics."""
        possible_words = list(vocab.values())
        word_energies = []

        for word in possible_words:
            complete_sequence = partial_sequence + word
            if len(complete_sequence) == self.M * self.N:
                # Ensure proper shape for the sequence vector
                complete_vec = jnp.array([int(bit) for bit in complete_sequence]).reshape(-1, 1)
                try:
                    energy = float(self._compute_energy(complete_vec, interactions, mean_pattern, h_field))
                    if not (jnp.isnan(energy) or jnp.isinf(energy)):
                        word_energies.append((word, energy))
                except Exception:
                    continue  # Skip candidates that fail to evaluate

        if not word_energies:
            # Fallback: return a random word if all energies are invalid
            self.key, subkey = jax.random.split(self.key)
            random_idx = int(jax.random.randint(subkey, (), 0, len(possible_words)))
            return vocab[possible_words[random_idx]], 0.0

        # Sort by energy (lowest first)
        word_energies.sort(key=lambda x: x[1])

        # Take the top k candidates
        k = min(10, len(word_energies))
        top_k = word_energies[:k]

        # Normalize energies with numerical stability
        energies = jnp.array([e[1] for e in top_k])
        energies = energies - jnp.min(energies)
        max_energy = jnp.max(energies)
        if max_energy > 1e-8:
            energies = energies / max_energy

        # Sample using a stable softmax over the negated energies
        probs = jnp.exp(-energies / self.temperature)
        probs = probs / (jnp.sum(probs) + 1e-8)

        self.key, subkey = jax.random.split(self.key)
        selected_idx = int(jax.random.choice(subkey, k, p=probs))
        best_word, min_energy = top_k[selected_idx]

        # Map the chosen bit pattern back to its word
        for word, vector in vocab.items():
            if vector == best_word:
                return word, min_energy


def main():
    # Load saved encodings
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Initialize model with adjusted parameters
    model = DenseAssociativeMemory(
        M=32,             # words per sequence
        N=32,             # bits per word
        temperature=0.1,
        batch_size=32,
        degree=10,        # highest polynomial interaction order
        h=0.01
    )

    # Prepare sequences
    print("Preparing sequences...")
    sequences = model.prepare_sequences(encoded_stories)
    print(f"Prepared {len(sequences)} sequences")

    # Compute interactions
    print("Computing interactions...")
    interactions, h_field, mean_pattern = model.compute_interactions(sequences)

    # Get input from the user
    print("\nEnter your story:")
    sentence = input(f"Enter a sentence (at least {model.M - 1} words): ")
    initial_tokens = tokenize_with_punctuation(sentence)

    if len(initial_tokens) < model.M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {model.M - 1}.")
        return

    # Predict the continuation
    print("\nPredicting continuation...")
    current_tokens = initial_tokens[:model.M - 1]
    predictions = []
    energies = []

    D = 10  # number of words to predict
    for _ in tqdm(range(D), desc="Generating words"):
        # Convert current tokens to a binary sequence
        partial_sequence = ""
        for token in current_tokens:
            partial_sequence += vocab[token]

        # Predict the next word
        predicted_word, energy = model.predict_next(
            partial_sequence,
            vocab,
            interactions,
            h_field,
            mean_pattern
        )

        predictions.append(predicted_word)
        energies.append(energy)

        # Slide the window forward by one word
        current_tokens = current_tokens[1:] + [predicted_word]

    # Print results
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predictions))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predictions, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")


if __name__ == "__main__":
    main()
@@ -354,9 +354,9 @@ def predict_sequence(initial_sequence, vocab, sequences, W, D=10, M=100, N=12, t
 
 if __name__ == "__main__":
     N = 20  # Define N as a constant
-    M = 100  # Define M as a constant
+    M = 30  # Define M as a constant
     D = 10  # Number of words to predict
-    temperature = 1
+    temperature = 0.01
 
     batch_size = 50  # Added batch size parameter