Add old/dense_associative_memory.py
old/dense_associative_memory.py (new file, 260 lines)
@@ -0,0 +1,260 @@
import numpy as np
import jax.numpy as jnp
from functools import partial
import jax
from tqdm import tqdm
from load_tinystories import (
    tokenize_with_punctuation,
    load_encodings,
    BiDict
)
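
# Overview (a summary of the code below): TinyStories text is encoded as
# fixed-width bit patterns (M words of N bits each), polynomial interaction
# matrices are accumulated over batches of those patterns, and continuations
# are generated by scoring every candidate next word with an energy function
# and sampling among the lowest-energy candidates.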


class DenseAssociativeMemory:
    def __init__(self, M=20, N=20, temperature=0.1, batch_size=16, degree=3, h=0.01):
        """M words per window, N bits per word, polynomial interactions up to
        `degree`, softmax sampling temperature, and external field strength h."""
        self.M = M
        self.N = N
        self.temperature = temperature
        self.batch_size = batch_size
        self.degree = degree
        self.h = h
        self.key = jax.random.PRNGKey(0)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_polynomial_interaction(self, batch):
        """Compute the degree-2..degree interaction matrices for a single batch."""
        interactions = []
        mean_pattern = jnp.mean(batch, axis=0, keepdims=True)
        centered_batch = batch - mean_pattern

        for d in range(2, self.degree + 1):
            # Use a smaller scaling factor for higher orders
            scale = 1.0 / (d * len(batch))
            interaction = jnp.zeros((self.M * self.N, self.M * self.N))

            def process_seq(interaction, seq):
                term = jnp.outer(seq.flatten(), seq.flatten())
                # Clip values and use a reduced, sign-preserving exponent so that
                # fractional powers of negative entries do not produce NaNs
                term = jnp.clip(term, -1.0, 1.0)
                term = jnp.sign(term) * jnp.power(jnp.abs(term), d / 4)
                return interaction + term * scale

            interaction = jax.lax.fori_loop(
                0, len(centered_batch),
                lambda i, acc: process_seq(acc, centered_batch[i]),
                interaction
            )

            interactions.append(interaction)

        return interactions, mean_pattern
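
    # In symbols, a rough sketch of what the loop above computes (s_b denotes the
    # b-th centered pattern in the batch and B the batch size; this notation is
    # introduced here only for illustration):
    #     T_d ~ (1 / (d * B)) * sum_b sign(C_b) * |C_b|^(d/4),  C_b = clip(s_b s_b^T, -1, 1)
    # One (M*N, M*N) matrix is produced per degree d = 2..self.degree; they are
    # contracted against a query pattern in _compute_energy below.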

    @partial(jax.jit, static_argnums=(0,))
    def _compute_energy(self, sequence, interactions, mean_pattern, h_field):
        """Compute energy with polynomial interactions"""
        # Ensure proper shapes
        sequence = sequence.reshape(-1, 1)          # make column vector
        mean_pattern = mean_pattern.reshape(-1, 1)  # ensure same shape

        # Center and normalize the sequence
        centered_seq = sequence - mean_pattern
        norm = jnp.linalg.norm(centered_seq) + 1e-8
        centered_seq = centered_seq / norm

        energy = 0.0
        for d, interaction in enumerate(interactions, 2):
            # Reshape for matrix multiplication
            seq_flat = centered_seq.reshape(1, -1)  # row vector for the first multiplication
            # Compute energy with numerical stability
            term = seq_flat @ interaction @ centered_seq
            term = jnp.clip(term, -100.0, 100.0)    # prevent overflow
            energy += -jnp.sum(term) / (d * d)      # stronger scaling for higher orders

        # Add scaled magnetic field term
        field_term = -jnp.sum(h_field * sequence) * 0.1  # scale down field contribution
        energy = energy + field_term

        # Clip final energy to prevent extreme values
        return jnp.clip(energy, -100.0, 100.0)
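
    # As implemented above, the energy of a query is roughly
    #     E = clip( sum_d -(s^T T_d s) / d^2  -  0.1 * sum(h * x), -100, 100 )
    # where x is the raw bit vector and s its centered, L2-normalised version.
    # For comparison, the classical dense associative memory energy is
    # E(sigma) = -sum_mu F(xi_mu . sigma) with a polynomial F; here the pairwise
    # matrices T_d act as a stand-in for those higher-order pattern terms.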

    def prepare_sequences(self, encoded_stories):
        """Slice encoded stories into overlapping M-word windows and flatten
        each window into an (M*N)-bit vector."""
        sequences = []
        max_sequences = 20000  # cap on the number of training sequences

        # Process the first 1000 stories
        for story in tqdm(encoded_stories[:1000], desc="Processing stories"):
            if len(story) >= self.M:
                # Take several overlapping windows from each story
                step = max(1, (len(story) - self.M) // 5)  # smaller step size
                for i in range(0, len(story) - self.M + 1, step):
                    if len(sequences) >= max_sequences:
                        break

                    word_group = story[i:i + self.M]
                    bits = []
                    for word in word_group:
                        bits.extend([int(bit) for bit in word])
                    sequences.append(bits)

            if len(sequences) >= max_sequences:
                break

        # Convert to a JAX array and reshape
        print(f"\nCollected {len(sequences)} sequences")
        sequences = jnp.array(sequences[:max_sequences]).reshape(-1, self.M * self.N, 1)
        return sequences
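
    # Small worked example of the encoding assumed above (hypothetical 4-bit
    # vocabulary, not the real one produced by load_tinystories): with M=2, N=4
    # and vocab = {"the": "1010", "cat": "0110"}, the window ["the", "cat"]
    # becomes bits = [1,0,1,0, 0,1,1,0], i.e. one flat M*N = 8-bit pattern.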

    def compute_interactions(self, sequences):
        """Compute interactions in batches"""
        all_interactions = None
        mean_pattern = None
        num_batches = 0

        # Process in small batches
        for i in tqdm(range(0, len(sequences), self.batch_size), desc="Computing interactions"):
            batch = sequences[i:i + self.batch_size]
            batch_interactions, batch_mean = self._compute_polynomial_interaction(batch)

            if all_interactions is None:
                all_interactions = [jnp.zeros_like(inter) for inter in batch_interactions]
                mean_pattern = jnp.zeros_like(batch_mean)

            # Accumulate interactions and mean
            for j, inter in enumerate(batch_interactions):
                all_interactions[j] += inter
            mean_pattern += batch_mean
            num_batches += 1

        # Average the accumulated values
        all_interactions = [inter / num_batches for inter in all_interactions]
        mean_pattern = mean_pattern / num_batches

        # Create magnetic field
        h_field = self.h * jnp.ones((self.M * self.N, 1))

        return all_interactions, h_field, mean_pattern
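
    # Note on the averaging above: each batch contributes its per-batch matrices
    # and mean with equal weight, so
    #     T_d ~ (1 / num_batches) * sum_batches T_d^(batch).
    # If len(sequences) is not a multiple of batch_size, the smaller final batch
    # is weighted the same as the full ones, which slightly biases the estimate.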

    def predict_next(self, partial_sequence, vocab, interactions, h_field, mean_pattern):
        """Score every vocabulary word by the energy of the completed sequence,
        then sample among the lowest-energy candidates."""
        possible_words = list(vocab.values())
        word_energies = []

        for word in possible_words:
            complete_sequence = partial_sequence + word
            if len(complete_sequence) == self.M * self.N:
                # Ensure proper shape for the sequence vector
                complete_vec = jnp.array([int(bit) for bit in complete_sequence]).reshape(-1, 1)
                try:
                    energy = float(self._compute_energy(complete_vec, interactions, mean_pattern, h_field))
                    if not (jnp.isnan(energy) or jnp.isinf(energy)):
                        word_energies.append((word, energy))
                except Exception:
                    continue  # skip candidates whose energy cannot be evaluated

        if not word_energies:
            # Fallback: return a random word if all energies are invalid
            self.key, subkey = jax.random.split(self.key)
            random_idx = jax.random.randint(subkey, (), 0, len(possible_words))
            return vocab[possible_words[random_idx]], 0.0

        # Sort by energy (lowest first)
        word_energies.sort(key=lambda x: x[1])

        # Take the top k candidates
        k = min(10, len(word_energies))
        top_k = word_energies[:k]

        # Normalize energies with numerical stability
        energies = jnp.array([e[1] for e in top_k])
        energies = energies - jnp.min(energies)
        max_energy = jnp.max(energies)
        if max_energy > 1e-8:
            energies = energies / max_energy

        # Sample using a stable softmax over negative energies
        probs = jnp.exp(-energies / self.temperature)
        probs = probs / (jnp.sum(probs) + 1e-8)

        self.key, subkey = jax.random.split(self.key)
        selected_idx = jax.random.choice(subkey, k, p=probs)
        best_word, min_energy = top_k[selected_idx]

        # Map the selected bit pattern back to its word
        for word, vector in vocab.items():
            if vector == best_word:
                return word, min_energy
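
    # Worked example of the sampling step (illustrative numbers only): with
    # normalised energies [0.0, 0.5, 1.0] and temperature 0.1, the weights are
    # exp(0) = 1, exp(-5) ~ 0.0067 and exp(-10) ~ 4.5e-5, giving probabilities
    # of roughly [0.993, 0.007, 0.00005] -- the lowest-energy word dominates,
    # but nearby candidates keep a small chance of being chosen.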


def main():
    # Load saved encodings
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Initialize model with adjusted parameters
    model = DenseAssociativeMemory(
        M=32,            # keep sequence length manageable
        N=32,            # bits per word
        temperature=0.1,
        batch_size=32,
        degree=10,       # polynomial interactions up to degree 10
        h=0.01
    )

    # Prepare sequences
    print("Preparing sequences...")
    sequences = model.prepare_sequences(encoded_stories)
    print(f"Prepared {len(sequences)} sequences")

    # Compute interactions
    print("Computing interactions...")
    interactions, h_field, mean_pattern = model.compute_interactions(sequences)

    # Get input from user
    print("\nEnter your story:")
    sentence = input(f"Enter a sentence (at least {model.M - 1} words): ")
    initial_tokens = tokenize_with_punctuation(sentence)

    if len(initial_tokens) < model.M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {model.M - 1}.")
        return

    # Predict sequence
    print("\nPredicting continuation...")
    current_tokens = initial_tokens[:model.M - 1]
    predictions = []
    energies = []

    D = 10  # number of words to predict
    for _ in tqdm(range(D), desc="Generating words"):
        # Convert current tokens to a binary sequence
        partial_sequence = ""
        for token in current_tokens:
            partial_sequence += vocab[token]

        # Predict next word
        predicted_word, energy = model.predict_next(
            partial_sequence,
            vocab,
            interactions,
            h_field,
            mean_pattern
        )

        predictions.append(predicted_word)
        energies.append(energy)

        # Update the sliding window of current tokens
        current_tokens = current_tokens[1:] + [predicted_word]

    # Print results
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predictions))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predictions, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")


if __name__ == "__main__":
    main()
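
# Assumed workflow, inferred from the load_encodings() check in main():
#   1. run load_tinystories.py once to build and save the vocabulary / encodings,
#   2. then run this script and paste a sentence of at least M-1 (= 31) tokens
#      when prompted; it prints a 10-word continuation together with the energy
#      of each predicted word.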