from datasets import load_dataset
import numpy as np
from collections import Counter
import re
import pickle
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


class BiDict:
    """
    Bidirectional dictionary for word-to-vector and vector-to-word mappings.
    """
    def __init__(self):
        self.word_to_vec = {}
        self.vec_to_word = {}

    def __setitem__(self, word, vector):
        self.word_to_vec[word] = vector
        self.vec_to_word[vector] = word

    def __getitem__(self, key):
        # Look the key up as a word first, then fall back to the reverse
        # (vector-to-word) map; raise KeyError if it is in neither.
        if key in self.word_to_vec:
            return self.word_to_vec[key]
        return self.vec_to_word[key]

    def __len__(self):
        return len(self.word_to_vec)

    def items(self):
        return self.word_to_vec.items()

    def values(self):
        return self.word_to_vec.values()


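# Illustrative BiDict round-trip (the word and bit pattern below are made up,
# not taken from the real vocabulary):
#
#   vocab = BiDict()
#   vocab["cat"] = "0000000000101"   # store word -> N-bit code
#   vocab["cat"]                     # -> "0000000000101" (forward lookup)
#   vocab["0000000000101"]           # -> "cat" (reverse lookup)
#   len(vocab)                       # -> 1

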
def load_tinystories():
    """
    Load the TinyStories dataset from Hugging Face.
    Returns the dataset object containing train and validation splits.
    """
    ds = load_dataset("roneneldan/TinyStories")
    return ds


def tokenize_with_punctuation(text):
    """
    Split text into lowercase word and punctuation tokens.
    Whitespace is matched by the pattern but discarded; each punctuation
    mark becomes its own token.
    """
    # Pattern matches punctuation marks, runs of whitespace, or alphanumeric words
    pattern = r'([.,!?;:"\'()\[\]{}]|\s+|[a-zA-Z0-9]+)'
    tokens = re.findall(pattern, text.lower())
    # Filter out empty strings and pure whitespace, but keep punctuation
    return [token for token in tokens if token.strip() or token in '.,!?;:"\'()[]{}']


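# Example of the tokenizer's behaviour (traced by hand on an illustrative input):
#
#   tokenize_with_punctuation('Tom said, "hi!"')
#   -> ['tom', 'said', ',', '"', 'hi', '!', '"']

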
def get_vocabulary(stories, N=12):
    """
    Create a vocabulary from the given stories.
    Returns a bidirectional dictionary mapping each token to a unique
    N-bit binary string (and back).
    """
    # Get all unique tokens across all stories
    all_tokens = set()
    for story in stories:
        tokens = tokenize_with_punctuation(story)
        all_tokens.update(tokens)

    # Sort tokens for consistent encoding
    unique_tokens = sorted(all_tokens)

    # Make sure the vocabulary fits into N bits
    num_tokens = len(unique_tokens)
    if num_tokens > 2**N:
        raise ValueError(f"Vocabulary size ({num_tokens}) exceeds {N}-bit capacity ({2**N})")

    # Generate all possible N-bit numbers and shuffle them
    # (the shuffle is unseeded, so code assignments differ between runs)
    all_possible = list(range(2**N))
    np.random.shuffle(all_possible)

    # Assign a unique random N-bit binary string to each token
    token_to_vector = BiDict()
    for i, token in enumerate(unique_tokens):
        binary = format(all_possible[i], f'0{N}b')
        token_to_vector[token] = binary

    return token_to_vector


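# With N=12, every token is encoded as a zero-padded 12-character binary string,
# e.g. format(1677, '012b') == '011010001101'. Which token receives which code is
# random (the code pool is shuffled without a fixed seed), so the mapping is only
# stable across runs because encode_stories() pickles it to disk.

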
def save_encodings(vocab, encoded_stories, stories, filename='encodings.pkl'):
    """Save the vocabulary, encoded stories, and original stories to a pickle file."""
    data = {
        'vocabulary': vocab,
        'encoded_stories': encoded_stories,
        'original_stories': stories
    }
    with open(filename, 'wb') as f:
        pickle.dump(data, f)


def load_encodings(filename='encodings.pkl'):
    """Load encodings from the pickle file if it exists; otherwise return (None, None, None)."""
    if os.path.exists(filename):
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        return data['vocabulary'], data['encoded_stories'], data['original_stories']
    return None, None, None


def encode_stories(n_stories=30, force_encode=False, N=12):
    """
    Encode the first n_stories stories into sequences of N-bit vectors.
    If saved encodings exist and force_encode is False, load them from file.
    Otherwise, create new encodings and save them.
    """
    if not force_encode:
        vocab, encoded_stories, stories = load_encodings()
        if vocab is not None:
            print("Loaded existing encodings from file")
            return vocab, encoded_stories, stories

    ds = load_tinystories()
    stories = [ds['train'][i]['text'] for i in range(n_stories)]
    print(stories)

    # Get vocabulary mapping with the specified N
    vocab = get_vocabulary(stories, N=N)

    # Encode each story as a list of N-bit binary strings
    encoded_stories = []
    for story in stories:
        tokens = tokenize_with_punctuation(story)
        encoded_tokens = [vocab[token] for token in tokens]
        encoded_stories.append(encoded_tokens)

    # Save the encodings
    save_encodings(vocab, encoded_stories, stories)
    print("Created and saved new encodings")

    return vocab, encoded_stories, stories


def get_word_sequences(encoded_stories, M=100, N=12):
    """
    Get sliding windows of M consecutive words from the encoded stories.
    Each word is N bits long, so each window becomes an (M*N, 1) bit vector.
    """
    M_N_sequences = []

    # Process each story with a progress bar
    for story in tqdm(encoded_stories, desc="Generating sequences"):
        # Only process stories that have enough words
        if len(story) >= M:
            # Take groups of M words, shifting by one word each time
            for i in range(len(story) - M + 1):
                word_group = story[i:i + M]
                # Concatenate the words into a single flat bit array
                bits = []
                for word in word_group:
                    bits.extend([int(bit) for bit in word])
                vector = np.array(bits).reshape(M * N, 1)
                M_N_sequences.append(vector)

    return np.array(M_N_sequences)


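# Shape check (follows directly from the code above): with M=10 and N=13 each
# window is a (130, 1) column vector of 0/1 bits, and a story of T >= M tokens
# contributes T - M + 1 overlapping windows, so the returned array has shape
# (num_windows, M*N, 1).

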
def sequence_to_words(sequence, N=12):
    """
    Convert a sequence vector back into a list of N-bit words.
    """
    # Flatten the (M*N, 1) column vector into a list of bit characters
    bits = [str(int(bit[0])) for bit in sequence]
    # Split into N-bit chunks
    words = [''.join(bits[i:i + N]) for i in range(0, len(bits), N)]
    return words


def calculate_energy(sequences):
    """
    Build the Hebbian coupling matrix from the stored sequences.
    Each sequence contributes -x.x^T/2; the accumulated sum is used as the
    weight matrix W of the associative memory. Returns the per-sequence
    contributions and the accumulated matrix.
    """
    energies = []
    hamiltonian = 0
    for seq in sequences:
        # Outer product of the bit vector with itself (Hebbian rule)
        energy = -seq.dot(seq.T)/2
        hamiltonian += energy
        energies.append(energy)
    # Plot the spectrum; the matrix is symmetric, so eigvalsh is appropriate
    plt.semilogy(-np.linalg.eigvalsh(hamiltonian), ".")
    plt.show()
    return energies, hamiltonian


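# Sketch of the convention used here (standard Hopfield/Hebbian reading): the
# stored patterns x_s are the 0/1 sequence vectors, the coupling matrix is
#
#     W = -(1/2) * sum_s x_s x_s^T
#
# and the energy assigned to a candidate state v in retrieve_sequences() is
# E(v) = v^T W v = -(1/2) * sum_s (v . x_s)^2, so states that overlap strongly
# with the stored sequences sit at lower energy.

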
def retrieve_sequences(sequences, partial_sequence, vocab, W, M=10, N=12, temperature=1.0):
    """
    Retrieve the next word using the Ising/Hopfield energy with temperature.
    The partial sequence supplies the first M-1 words; every vocabulary word is
    tried as the final word and one is sampled from a Boltzmann distribution
    over the resulting energies. (The `sequences` argument is kept for signature
    compatibility but is not used here.)
    """
    # Convert partial sequence to a column vector of bits
    partial_vec = np.array([int(bit) for bit in partial_sequence]).reshape(-1, 1)

    # Get all possible words (N-bit binary strings) from the vocabulary
    possible_words = list(vocab.values())

    # Calculate energies for all possible completions
    word_energies = []

    for word in possible_words:
        # Create the complete sequence with this word appended
        complete_sequence = partial_sequence + word
        if len(complete_sequence) == M*N:  # Ensure correct length
            complete_vec = np.array([int(bit) for bit in complete_sequence]).reshape(M * N, 1)

            # Ising/Hopfield energy; W already carries the -1/2 factor from
            # calculate_energy, so E = v^T W v (lower = stronger match)
            energy_matrix = complete_vec.T.dot(W).dot(complete_vec)
            energy = float(energy_matrix[0, 0])

            word_energies.append((word, energy))

    # Sort by energy (lowest first)
    word_energies.sort(key=lambda x: x[1])

    # Normalize energies to prevent overflow in the exponential
    energies = np.array([e[1] for e in word_energies])
    energies = energies - np.min(energies)  # Shift so the minimum energy is 0
    energies = energies / np.max(energies) if np.max(energies) > 0 else energies  # Scale to [0, 1]

    # Boltzmann probabilities over the normalized energies
    probabilities = np.exp(-energies/temperature)
    probabilities = probabilities / np.sum(probabilities)

    # Fall back to a uniform distribution if numerical issues occur
    if np.any(np.isnan(probabilities)):
        probabilities = np.ones(len(word_energies)) / len(word_energies)

    selected_idx = np.random.choice(len(word_energies), p=probabilities)
    best_word, selected_energy = word_energies[selected_idx]

    # Map the selected binary vector back to its word via the reverse lookup
    return vocab[best_word], best_word, selected_energy


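# Sampling behaviour implied by the code above: after shifting and scaling, the
# normalized energies lie in [0, 1] and the next word is drawn with
# p_i proportional to exp(-E_i / temperature). Small temperatures approach a
# greedy argmin over the energies; large temperatures approach a uniform choice
# over the vocabulary.

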
def predict_sequence(initial_sequence, vocab, sequences, W, D=10, M=100, N=12, temperature=1.0):
    """
    Predict D words iteratively by sliding the (M-1)-word context window.
    """
    current_tokens = initial_sequence.copy()
    predictions = []
    energies = []

    # Progress bar over the D predictions
    for _ in tqdm(range(D), desc="Predicting words"):
        # Convert the current tokens to one long binary string
        partial_sequence = ""
        for token in current_tokens:
            partial_sequence += vocab[token]

        # Predict the next word
        predicted_word, _, energy = retrieve_sequences(
            sequences,
            partial_sequence,
            vocab,
            W=W,
            M=M,
            N=N,
            temperature=temperature
        )

        predictions.append(predicted_word)
        energies.append(energy)

        # Slide the window: drop the first token and append the predicted word
        current_tokens = current_tokens[1:] + [predicted_word]

    return predictions, energies


if __name__ == "__main__":
    N = 13             # Bits per word
    M = 10             # Words per sequence window
    D = 3              # Number of words to predict
    temperature = 1.0  # Sampling temperature (higher = more diverse)

    print("Loading and encoding stories...")
    # Force new encoding so the vocabulary matches the chosen N
    vocab, encoded_stories, original_stories = encode_stories(force_encode=True, N=N)

    print("\nGenerating training sequences...")
    # Get sequences for training
    sequences = get_word_sequences(encoded_stories=encoded_stories, M=M, N=N)
    print(f"Number of training sequences: {len(sequences)}")
    print(f"Sequence shape: {sequences[0].shape if len(sequences) > 0 else 'No sequences found'}")

    # Get the initial context from the first story
    story_tokens = tokenize_with_punctuation(original_stories[0])
    _, W = calculate_energy(sequences)

    # Make sure the story has at least M-1 tokens of context
    if len(story_tokens) >= M-1:
        initial_tokens = story_tokens[:M-1]

        # Predict the next D words
        predicted_words, energies = predict_sequence(
            initial_tokens,
            vocab,
            sequences,
            W=W,
            D=D,
            M=M,
            N=N,
            temperature=temperature
        )

        # Print results
        print("\nOriginal sequence:")
        print(" ".join(initial_tokens[-10:]))  # Last 10 tokens of the initial context
        print("\nPredicted sequence:")
        print(" ".join(predicted_words))
        print("\nEnergies:")
        print(energies)
        print("\nActual next words:")
        print(" ".join(story_tokens[M-1:M-1+D]))  # Next D actual words
    else:
        print(f"Story too short. Needs at least {M-1} tokens, but has {len(story_tokens)}")


    # # Print example
    # print(f"Total vocabulary size: {len(vocab)}")
    # print("\nExample encoding for first story:")
    # print("Original:", original_stories[0])
    # print("First few tokens and their encodings:")
    # tokens = tokenize_with_punctuation(original_stories[0])
    # for token, encoding in zip(tokens[:10], encoded_stories[0][:10]):
    #     print(f"'{token}' -> {encoding}")

    # # Get statistics about vector usage
    # total_unique_in_vocab = len(vocab)
    # total_unique_used = len(set([vec for story in encoded_stories for vec in story]))
    # total_vectors = sum(len(story) for story in encoded_stories)

    # print(f"\nTotal unique vectors in vocabulary: {total_unique_in_vocab}")
    # print(f"Total unique vectors used in stories: {total_unique_used}")
    # print(f"Total word occurrences: {total_vectors}")
    # print(encoded_stories[0])

    # print(sequences)
    # plt.imshow(energies[0])
    # plt.show()