heyyyy
old/load_tinystories.py | 406 lines (new file)
@@ -0,0 +1,406 @@
from datasets import load_dataset
import numpy as np
from collections import Counter
import re
import pickle
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

class BiDict:
    """
    Bidirectional dictionary for word-to-vector and vector-to-word mappings
    """
    def __init__(self):
        self.word_to_vec = {}
        self.vec_to_word = {}

    def __setitem__(self, word, vector):
        # Convert numpy array to tuple for hashing
        if isinstance(vector, np.ndarray):
            vector_tuple = tuple(vector.flatten())
        else:
            vector_tuple = tuple(vector)

        # Convert vector to string of 1s and 0s
        vector_str = ''.join(str(int(x)) for x in vector_tuple)

        self.word_to_vec[word] = vector_str
        self.vec_to_word[vector_str] = word

    def __getitem__(self, key):
        # If key is a numpy array, convert to string
        if isinstance(key, np.ndarray):
            key = ''.join(str(int(x)) for x in key.flatten())
        # Try word_to_vec first, then vec_to_word
        return self.word_to_vec.get(key) or self.vec_to_word.get(key)

    def __len__(self):
        return len(self.word_to_vec)

    def items(self):
        return self.word_to_vec.items()

    def values(self):
        return self.word_to_vec.values()

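# Illustrative usage (not exercised by the pipeline below): BiDict maps both
# ways between a word and its bit-string code, which is what lets predicted
# bit vectors be decoded back into words.
#
#   >>> bd = BiDict()
#   >>> bd["cat"] = np.array([1, 0, 1, 1])   # stored as the string '1011'
#   >>> bd["cat"]
#   '1011'
#   >>> bd["1011"]
#   'cat'
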
def load_tinystories():
    """
    Load the TinyStories dataset from Hugging Face.
    Returns the dataset object containing train and validation splits.
    """
    ds = load_dataset("roneneldan/TinyStories")
    return ds

def tokenize_with_punctuation(text):
    """
    Split text into words and punctuation marks as separate tokens.
    The text is lowercased; whitespace is dropped and each punctuation
    mark becomes its own token.
    """
    # Pattern matches punctuation marks, runs of whitespace, or alphanumeric words
    pattern = r'([.,!?;:"\'()\[\]{}]|\s+|[a-zA-Z0-9]+)'
    tokens = re.findall(pattern, text.lower())
    # Filter out empty strings and pure whitespace, but keep punctuation
    return [token for token in tokens if token.strip() or token in '.,!?;:"\'()[]{}']

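# Illustrative output of the tokenizer:
#
#   >>> tokenize_with_punctuation("Once upon a time, Lily smiled!")
#   ['once', 'upon', 'a', 'time', ',', 'lily', 'smiled', '!']
#
# Whitespace is discarded and everything is lowercased, so the original
# spacing cannot be recovered exactly from the tokens.
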
def make_binary_tokens(unique_tokens, N=12):
    """
    Create binary vectors for tokens.
    Each vector is N bits long, containing only 0s and 1s.
    """
    if len(unique_tokens) > 2**N:
        raise ValueError(f"Vocabulary size ({len(unique_tokens)}) exceeds {N}-bit capacity ({2**N})")

    # Draw distinct integers so every token gets a unique N-bit code;
    # independently random bits could collide and silently overwrite
    # entries in BiDict.vec_to_word
    codes = np.random.choice(2**N, size=len(unique_tokens), replace=False)

    token_to_vector = BiDict()
    for i, w in enumerate(unique_tokens):
        # Format each code as a string of exactly N 0s and 1s
        binary_str = format(int(codes[i]), f'0{N}b')
        token_to_vector[w] = binary_str
    return token_to_vector

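# Illustrative result (the exact bits vary run to run): every token receives a
# distinct random N-bit string.
#
#   >>> vocab = make_binary_tokens(['cat', 'dog', 'the'], N=12)
#   >>> len(vocab['cat'])
#   12
#   >>> vocab['cat'] != vocab['dog']
#   True
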
def get_vocabulary(stories, N=12):
    """
    Create vocabulary from the given stories.
    Returns a bidirectional dictionary mapping words and vectors.
    """
    # Get all unique tokens across all stories
    all_tokens = set()
    for story in stories:
        tokens = tokenize_with_punctuation(story)
        all_tokens.update(tokens)
    # Sort tokens for consistent encoding
    unique_tokens = sorted(all_tokens)

    # Check that N bits can give every token a unique vector
    num_tokens = len(unique_tokens)
    if num_tokens > 2**N:
        raise ValueError(f"Vocabulary size ({num_tokens}) exceeds {N}-bit capacity ({2**N})")

    token_to_vector = make_binary_tokens(unique_tokens, N=N)
    return token_to_vector

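# Capacity of the code, spelled out: N bits address 2**N distinct vectors, so
# the default N=12 supports up to 2**12 = 4096 tokens and the N=20 used in
# __main__ below supports up to 2**20 = 1,048,576 tokens.
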
def save_encodings(vocab, encoded_stories, stories, filename='encodings.pkl'):
    """Save the encodings and vocabulary to a pickle file"""
    data = {
        'vocabulary': vocab,
        'encoded_stories': encoded_stories,
        'original_stories': stories
    }
    with open(filename, 'wb') as f:
        pickle.dump(data, f)

def load_encodings(filename='encodings.pkl'):
    """Load encodings from pickle file if it exists"""
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        return data['vocabulary'], data['encoded_stories'], data['original_stories']
    return None, None, None

def encode_stories(n_stories=200, force_encode=False, N=12, batch_size=50):
    """
    Encode stories in batches to reduce memory usage.
    """
    if not force_encode:
        vocab, encoded_stories, stories = load_encodings()
        if vocab is not None:
            print("Loaded existing encodings from file")
            return vocab, encoded_stories, stories

    ds = load_tinystories()

    # Process stories in batches
    stories = []
    encoded_stories = []
    all_tokens = set()

    # First pass: collect vocabulary
    print("Building vocabulary...")
    for i in tqdm(range(0, n_stories, batch_size)):
        batch = [ds['train'][j]['text'] for j in range(i, min(i + batch_size, n_stories))]
        for story in batch:
            tokens = tokenize_with_punctuation(story)
            all_tokens.update(tokens)

    # Create vocabulary
    unique_tokens = sorted(all_tokens)
    vocab = make_binary_tokens(unique_tokens, N=N)

    # Second pass: encode stories
    print("Encoding stories...")
    for i in tqdm(range(0, n_stories, batch_size)):
        batch = [ds['train'][j]['text'] for j in range(i, min(i + batch_size, n_stories))]

        batch_stories = []
        batch_encoded = []

        for story in batch:
            tokens = tokenize_with_punctuation(story)
            encoded_tokens = [vocab[token] for token in tokens]
            batch_stories.append(story)
            batch_encoded.append(encoded_tokens)

        stories.extend(batch_stories)
        encoded_stories.extend(batch_encoded)

        # Save intermediate results
        if (i + batch_size) % (batch_size * 4) == 0:
            save_encodings(vocab, encoded_stories, stories)
            print(f"Saved progress: {i + batch_size}/{n_stories} stories")

    # Final save
    save_encodings(vocab, encoded_stories, stories)
    print("Created and saved new encodings")

    return vocab, encoded_stories, stories

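# Shape of what encode_stories returns (illustrative):
#
#   vocab            -> BiDict mapping token <-> N-bit string
#   encoded_stories  -> one list per story, with an N-bit string per token
#   original_stories -> the raw story texts
#
#   >>> vocab, encoded, originals = encode_stories(n_stories=10, N=12)
#   >>> len(encoded[0]) == len(tokenize_with_punctuation(originals[0]))
#   True
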
def get_word_sequences(encoded_stories, M=100, N=12, batch_size=32):
    """
    Get sequences of M consecutive words from encoded stories.
    Process in batches to reduce memory usage.
    """
    sequences = []

    # Process stories in batches
    for i in tqdm(range(0, len(encoded_stories), batch_size), desc="Generating sequences"):
        batch = encoded_stories[i:i + batch_size]
        batch_sequences = []

        for story in batch:
            if len(story) >= M:
                for j in range(len(story) - M + 1):
                    word_group = story[j:j + M]
                    bits = []
                    for word in word_group:
                        bits.extend([int(bit) for bit in word])
                    vector = np.array(bits).reshape(M * N, 1)
                    batch_sequences.append(vector)

        sequences.extend(batch_sequences)

        # Free memory
        del batch_sequences

    return np.array(sequences)

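# Worked example: a story of 35 tokens with M=30 yields 35 - 30 + 1 = 6
# overlapping windows, each flattened to a column vector of shape
# (M * N, 1) = (30 * 20, 1) = (600, 1) with the settings used in __main__.
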
def sequence_to_words(sequence, N=12):
    """
    Convert a sequence vector back into a list of N-bit words
    """
    # Convert sequence to flat list of bits
    bits = [str(int(bit[0])) for bit in sequence]
    # Split into N-bit chunks
    words = [''.join(bits[i:i + N]) for i in range(0, len(bits), N)]
    return words

def calculate_energy(sequences, batch_size=32, h=0.1):
    """
    Calculate the energy of sequences using batched processing with magnetic field.
    Returns energies, the weight matrix W, and the magnetic field h_field.
    h: magnetic field strength
    """
    num_sequences = len(sequences)
    seq_length = sequences[0].shape[0]

    # Initialize weight matrix and magnetic field
    W = np.zeros((seq_length, seq_length))
    h_field = h * np.ones(seq_length).reshape(-1, 1)  # Uniform magnetic field
    energies = []

    # First pass: accumulate the Hebbian weight matrix over all sequences
    for i in tqdm(range(0, num_sequences, batch_size), desc="Building weights"):
        batch = np.array(sequences[i:min(i + batch_size, num_sequences)])
        for seq in batch:
            W += np.dot(seq, seq.T)

    # Normalize weight matrix
    W = W / num_sequences

    # Second pass: compute energies with the final, normalized weights
    for i in tqdm(range(0, num_sequences, batch_size), desc="Calculating energies"):
        batch = np.array(sequences[i:min(i + batch_size, num_sequences)])
        for seq in batch:
            # E = -1/2 * s^T * W * s - h * sum(s)
            spin_spin = -0.5 * float(seq.T.dot(W).dot(seq)[0, 0])
            magnetic = -float(h_field.T.dot(seq)[0, 0])
            energies.append(spin_spin + magnetic)

    return np.array(energies), W, h_field

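# Note on the model: W is the Hebbian outer-product ("Hopfield") memory of the
# training windows, W = (1/P) * sum_p s_p s_p^T over the P stored windows, and
# the energy of a state s is the Ising-style form
#
#   E(s) = -1/2 * s^T W s - h * sum_i s_i
#
# where the "spins" here are the 0/1 bits of a window; lower energy means the
# window overlaps more strongly with the stored patterns.
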
def retrieve_sequences(sequences, partial_sequence, vocab, W, M=10, N=12, temperature=1.0, h=0.1):
    """
    Retrieve the most likely next word using the Ising Hamiltonian with magnetic field.
    partial_sequence is the bit string of the M-1 context words; each candidate
    word from the vocabulary is appended and the completed window is scored.
    """
    # Get all possible words (bit strings) from vocabulary
    possible_words = list(vocab.values())

    # Create magnetic field
    h_field = h * np.ones(M * N).reshape(-1, 1)

    # Calculate energies for all possible completions
    word_energies = []

    for word in possible_words:
        # Create complete sequence with this word
        complete_sequence = partial_sequence + word
        if len(complete_sequence) == M * N:  # Ensure correct length
            complete_vec = np.array([int(bit) for bit in complete_sequence]).reshape(M * N, 1)

            # Interaction term: overlap with every stored window (up to a positive
            # scale this equals s^T W s for the Hebbian W, so W itself is not used here)
            spin_spin = 0
            for seq in sequences:
                overlap = complete_vec.T.dot(seq)[0, 0]
                spin_spin -= overlap * overlap

            # Magnetic field contribution
            magnetic = -float(h_field.T.dot(complete_vec)[0, 0])
            total_energy = spin_spin + magnetic

            word_energies.append((word, total_energy))

    # Sort by energy (lowest first)
    word_energies.sort(key=lambda x: x[1])

    # Normalize energies to [0, 1]
    energies = np.array([e[1] for e in word_energies])
    energies = energies - np.min(energies)
    max_energy = np.max(energies)
    if max_energy > 0:
        energies = energies / max_energy

    # Calculate probabilities with a Boltzmann distribution
    probabilities = np.exp(-energies / temperature)
    probabilities = probabilities / np.sum(probabilities)

    # Sample a completion from the distribution
    selected_idx = np.random.choice(len(word_energies), p=probabilities)
    selected_vec, selected_energy = word_energies[selected_idx]

    # Find the word corresponding to the binary vector
    for word, vector in vocab.items():
        if vector == selected_vec:
            return word, selected_vec, selected_energy

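# Worked example of the Boltzmann sampling step: with normalized energies
# [0.0, 0.5, 1.0],
#
#   temperature = 1.0  -> probabilities roughly [0.51, 0.31, 0.19]
#   temperature = 0.01 -> probabilities roughly [1.0, 2e-22, 4e-44]
#
# so the temperature = 0.01 used in __main__ makes the choice effectively greedy.
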
def predict_sequence(initial_sequence, vocab, sequences, W, D=10, M=100, N=12, temperature=1.0):
    """
    Predict D words iteratively by sliding the window.
    """
    current_tokens = initial_sequence.copy()
    predictions = []
    energies = []

    # Add progress bar for predictions
    for _ in tqdm(range(D), desc="Predicting words"):
        # Convert current tokens to binary sequence
        partial_sequence = ""
        for token in current_tokens:
            partial_sequence += vocab[token]

        # Predict next word
        predicted_word, _, energy = retrieve_sequences(
            sequences,
            partial_sequence,
            vocab,
            W=W,
            M=M,
            N=N,
            temperature=temperature
        )

        predictions.append(predicted_word)
        energies.append(energy)

        # Slide window: remove first token and add predicted word
        current_tokens = current_tokens[1:] + [predicted_word]

    return predictions, energies

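# Illustration of the sliding context window: with M - 1 = 2 context tokens the
# loop proceeds like
#
#   ['once', 'upon']  -> predicts 'a'
#   ['upon', 'a']     -> predicts 'time'
#   ['a', 'time']     -> predicts ','
#
# keeping the context M - 1 tokens long while D words are generated.
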
if __name__ == "__main__":
    N = 20   # Bits per token code
    M = 30   # Window length in tokens (M - 1 context tokens + 1 predicted token)
    D = 10   # Number of words to predict
    temperature = 0.01

    batch_size = 50  # Batch size for encoding

    print("Loading and encoding stories...")
    vocab, encoded_stories, original_stories = encode_stories(
        force_encode=True,
        N=N,
        batch_size=batch_size
    )

    print("\nGenerating training sequences...")
    # Get sequences for training
    sequences = get_word_sequences(encoded_stories=encoded_stories, M=M, N=N)
    print(f"Number of training sequences: {len(sequences)}")
    print(f"Sequence shape: {sequences[0].shape if len(sequences) > 0 else 'No sequences found'}")

    # Get initial sequence from first story
    story_tokens = tokenize_with_punctuation(original_stories[0])
    _, W, _ = calculate_energy(sequences)

    # Make sure the story provides at least M - 1 context tokens
    if len(story_tokens) >= M - 1:
        initial_tokens = story_tokens[:M - 1]

        # Predict next D words
        predicted_words, energies = predict_sequence(
            initial_tokens,
            vocab,
            sequences,
            W=W,
            D=D,
            M=M,
            N=N,
            temperature=temperature
        )

        # Print results
        print("\nOriginal sequence:")
        print(" ".join(initial_tokens))  # The M - 1 context tokens
        print("\nPredicted sequence:")
        print(" ".join(predicted_words))
        print("\nEnergies:")
        print(energies)
        print("\nActual next words:")
        print(" ".join(story_tokens[M - 1:M - 1 + D]))  # Next D actual words
    else:
        print(f"Story too short. Needs at least {M - 1} tokens, but has {len(story_tokens)}")