This commit is contained in:
alireza
2025-02-23 15:15:17 +03:30
parent f3f7fcda8a
commit df03b985ed
7 changed files with 262 additions and 2 deletions


@@ -0,0 +1,260 @@
import numpy as np
import jax.numpy as jnp
from functools import partial
import jax
from tqdm import tqdm
from load_tinystories import (
    tokenize_with_punctuation,
    load_encodings,
    BiDict,
)

class DenseAssociativeMemory:
    def __init__(self, M=20, N=20, temperature=0.1, batch_size=16, degree=3, h=0.01):
        """Dense associative memory over binary word codes.

        M: words per sequence window; N: bits per word code;
        temperature: softmax temperature used when sampling predictions;
        degree: highest polynomial interaction order (orders 2..degree);
        h: magnetic-field strength. Defaults use a reduced sequence
        length and batch size to keep memory usage manageable.
        """
        self.M = M
        self.N = N
        self.temperature = temperature
        self.batch_size = batch_size
        self.degree = degree
        self.h = h
        self.key = jax.random.PRNGKey(0)

    @partial(jax.jit, static_argnums=(0,))
    def _compute_polynomial_interaction(self, batch):
        """Compute interactions for a single batch"""
        interactions = []
        mean_pattern = jnp.mean(batch, axis=0, keepdims=True)
        centered_batch = batch - mean_pattern
        for d in range(2, self.degree + 1):
            # Use a smaller scaling factor for higher orders
            scale = 1.0 / (d * len(batch))
            interaction = jnp.zeros((self.M * self.N, self.M * self.N))

            def process_seq(interaction, seq):
                term = jnp.outer(seq.flatten(), seq.flatten())
                # Clip, then apply a signed fractional power: a plain
                # jnp.power with exponent d/4 would return NaN for the
                # negative entries, so exponentiate the magnitude and
                # restore the sign.
                term = jnp.clip(term, -1.0, 1.0)
                term = jnp.sign(term) * jnp.power(jnp.abs(term), d / 4)
                return interaction + term * scale

            interaction = jax.lax.fori_loop(
                0, len(centered_batch),
                lambda i, acc: process_seq(acc, centered_batch[i]),
                interaction,
            )
            interactions.append(interaction)
        return interactions, mean_pattern

    @partial(jax.jit, static_argnums=(0,))
    def _compute_energy(self, sequence, interactions, mean_pattern, h_field):
        """Compute energy with polynomial interactions"""
        # Ensure proper shapes
        sequence = sequence.reshape(-1, 1)          # column vector
        mean_pattern = mean_pattern.reshape(-1, 1)  # same shape as sequence
        # Center and normalize the sequence
        centered_seq = sequence - mean_pattern
        norm = jnp.linalg.norm(centered_seq) + 1e-8
        centered_seq = centered_seq / norm
        energy = 0.0
        for d, interaction in enumerate(interactions, 2):
            # Row vector for the first multiplication
            seq_flat = centered_seq.reshape(1, -1)
            # Quadratic form, with numerical stability
            term = seq_flat @ interaction @ centered_seq
            term = jnp.clip(term, -100.0, 100.0)  # prevent overflow
            energy += -jnp.sum(term) / (d * d)    # stronger damping for higher orders
        # Add the scaled magnetic-field term
        field_term = -jnp.sum(h_field * sequence) * 0.1  # scale down field contribution
        energy = energy + field_term
        # Clip the final energy to prevent extreme values
        return jnp.clip(energy, -100.0, 100.0)

    def prepare_sequences(self, encoded_stories):
        """Slice stories into overlapping M-word windows encoded as M*N-bit patterns."""
        sequences = []
        max_sequences = 20000  # increased cap on stored sequences
        # Process more stories
        for story in tqdm(encoded_stories[:1000], desc="Processing stories"):
            if len(story) >= self.M:
                # Take several overlapping windows from each story
                step = max(1, (len(story) - self.M) // 5)  # smaller step size
                for i in range(0, len(story) - self.M + 1, step):
                    if len(sequences) >= max_sequences:
                        break
                    word_group = story[i:i + self.M]
                    bits = []
                    for word in word_group:
                        bits.extend([int(bit) for bit in word])
                    sequences.append(bits)
            if len(sequences) >= max_sequences:
                break
        # Convert to a JAX array and reshape to column vectors
        print(f"\nCollected {len(sequences)} sequences")
        sequences = jnp.array(sequences[:max_sequences]).reshape(-1, self.M * self.N, 1)
        return sequences

    def compute_interactions(self, sequences):
        """Compute interactions in batches"""
        all_interactions = None
        mean_pattern = None
        num_batches = 0
        # Process in small batches
        for i in tqdm(range(0, len(sequences), self.batch_size), desc="Computing interactions"):
            batch = sequences[i:i + self.batch_size]
            batch_interactions, batch_mean = self._compute_polynomial_interaction(batch)
            if all_interactions is None:
                all_interactions = [jnp.zeros_like(inter) for inter in batch_interactions]
                mean_pattern = jnp.zeros_like(batch_mean)
            # Accumulate interactions and the running mean
            for j, inter in enumerate(batch_interactions):
                all_interactions[j] += inter
            mean_pattern += batch_mean
            num_batches += 1
        # Average the accumulated values
        all_interactions = [inter / num_batches for inter in all_interactions]
        mean_pattern = mean_pattern / num_batches
        # Create the magnetic field
        h_field = self.h * jnp.ones((self.M * self.N, 1))
        return all_interactions, h_field, mean_pattern

    def predict_next(self, partial_sequence, vocab, interactions, h_field, mean_pattern):
        """Predict the next word using DAM dynamics"""
        possible_words = list(vocab.values())
        word_energies = []
        for word in possible_words:
            complete_sequence = partial_sequence + word
            if len(complete_sequence) == self.M * self.N:
                # Ensure proper shape for the sequence vector
                complete_vec = jnp.array([int(bit) for bit in complete_sequence]).reshape(-1, 1)
                try:
                    energy = float(self._compute_energy(complete_vec, interactions, mean_pattern, h_field))
                    if not (jnp.isnan(energy) or jnp.isinf(energy)):
                        word_energies.append((word, energy))
                except Exception:
                    continue  # skip candidates whose energy computation fails
        if not word_energies:
            # Fallback: return a random word if all energies are invalid
            self.key, subkey = jax.random.split(self.key)
            random_idx = int(jax.random.randint(subkey, (), 0, len(possible_words)))
            return vocab[possible_words[random_idx]], 0.0
        # Sort by energy, lowest first
        word_energies.sort(key=lambda x: x[1])
        # Take the top-k candidates
        k = min(10, len(word_energies))
        top_k = word_energies[:k]
        # Normalize energies with numerical stability
        energies = jnp.array([e[1] for e in top_k])
        energies = energies - jnp.min(energies)
        max_energy = jnp.max(energies)
        if max_energy > 1e-8:
            energies = energies / max_energy
        # Sample using a stable softmax
        probs = jnp.exp(-energies / self.temperature)
        probs = probs / (jnp.sum(probs) + 1e-8)
        self.key, subkey = jax.random.split(self.key)
        selected_idx = int(jax.random.choice(subkey, k, p=probs))
        best_word, min_energy = top_k[selected_idx]
        # Map the winning bit vector back to its word via the BiDict,
        # mirroring the reverse lookup used in the fallback branch above
        return vocab[best_word], min_energy


def main():
    # Load saved encodings
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return
    # Initialize the model with adjusted parameters
    model = DenseAssociativeMemory(
        M=32,           # keep the sequence length manageable
        N=32,           # increased bits per word
        temperature=0.1,
        batch_size=32,  # increased batch size
        degree=10,      # polynomial interactions up to degree 10
        h=0.01,
    )
    # Prepare sequences
    print("Preparing sequences...")
    sequences = model.prepare_sequences(encoded_stories)
    print(f"Prepared {len(sequences)} sequences")
    # Compute interactions
    print("Computing interactions...")
    interactions, h_field, mean_pattern = model.compute_interactions(sequences)
    # Get input from the user
    print("\nEnter your story:")
    sentence = input(f"Enter a sentence (at least {model.M - 1} words): ")
    initial_tokens = tokenize_with_punctuation(sentence)
    if len(initial_tokens) < model.M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {model.M - 1}.")
        return
    # Predict the continuation
    print("\nPredicting continuation...")
    current_tokens = initial_tokens[:model.M - 1]
    predictions = []
    energies = []
    D = 10  # number of words to predict
    for _ in tqdm(range(D), desc="Generating words"):
        # Convert the current tokens to a binary sequence
        partial_sequence = ""
        for token in current_tokens:
            partial_sequence += vocab[token]
        # Predict the next word
        predicted_word, energy = model.predict_next(
            partial_sequence,
            vocab,
            interactions,
            h_field,
            mean_pattern,
        )
        predictions.append(predicted_word)
        energies.append(energy)
        # Slide the window: drop the oldest token, append the prediction
        current_tokens = current_tokens[1:] + [predicted_word]
    # Print results
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predictions))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predictions, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")


if __name__ == "__main__":
    main()


@@ -354,9 +354,9 @@ def predict_sequence(initial_sequence, vocab, sequences, W, D=10, M=100, N=12, t
 if __name__ == "__main__":
     N = 20  # Define N as a constant
-    M = 100  # Define M as a constant
+    M = 30  # Define M as a constant
     D = 10  # Number of words to predict
-    temperature = 1
+    temperature = 0.01
     batch_size = 50  # Added batch size parameter