hey
This commit is contained in:
@@ -16,10 +16,22 @@ class BiDict:
|
||||
self.vec_to_word = {}
|
||||
|
||||
def __setitem__(self, word, vector):
|
||||
self.word_to_vec[word] = vector
|
||||
self.vec_to_word[vector] = word
|
||||
# Convert numpy array to tuple for hashing
|
||||
if isinstance(vector, np.ndarray):
|
||||
vector_tuple = tuple(vector.flatten())
|
||||
else:
|
||||
vector_tuple = tuple(vector)
|
||||
|
||||
# Convert vector to string of 1s and 0s
|
||||
vector_str = ''.join(str(int(x)) for x in vector_tuple)
|
||||
|
||||
self.word_to_vec[word] = vector_str
|
||||
self.vec_to_word[vector_str] = word
|
||||
|
||||
def __getitem__(self, key):
|
||||
# If key is a numpy array, convert to string
|
||||
if isinstance(key, np.ndarray):
|
||||
key = ''.join(str(int(x)) for x in key.flatten())
|
||||
# Try word_to_vec first, then vec_to_word
|
||||
return self.word_to_vec.get(key) or self.vec_to_word.get(key)
|
||||
|
||||
@@ -52,6 +64,22 @@ def tokenize_with_punctuation(text):
|
||||
# Filter out empty strings and pure whitespace, but keep punctuation
|
||||
return [token for token in tokens if token.strip() or token in '.,!?;:"\'()[]{}']
|
||||
|
||||
def make_binary_tokens(unique_tokens, N=12):
|
||||
"""
|
||||
Create binary vectors for tokens.
|
||||
Each vector is N bits long, containing only 0s and 1s.
|
||||
"""
|
||||
# Generate random binary vectors (0s and 1s only)
|
||||
codes = np.random.randint(0, 2, size=(len(unique_tokens), N))
|
||||
|
||||
token_to_vector = BiDict()
|
||||
for i, w in enumerate(unique_tokens):
|
||||
# Convert to string of 0s and 1s directly
|
||||
binary_str = ''.join(str(int(x)) for x in codes[i])
|
||||
token_to_vector[w] = binary_str
|
||||
return token_to_vector
|
||||
|
||||
|
||||
def get_vocabulary(stories, N=12):
|
||||
"""
|
||||
Create vocabulary from the given stories.
|
||||
@@ -62,7 +90,6 @@ def get_vocabulary(stories, N=12):
|
||||
for story in stories:
|
||||
tokens = tokenize_with_punctuation(story)
|
||||
all_tokens.update(tokens)
|
||||
|
||||
# Sort tokens for consistent encoding
|
||||
unique_tokens = sorted(all_tokens)
|
||||
|
||||
@@ -72,15 +99,8 @@ def get_vocabulary(stories, N=12):
|
||||
raise ValueError(f"Vocabulary size ({num_tokens}) exceeds {N}-bit capacity ({2**N})")
|
||||
|
||||
# Generate all possible N-bit numbers
|
||||
all_possible = list(range(2**N))
|
||||
np.random.shuffle(all_possible)
|
||||
|
||||
# Create unique random binary numbers for each token
|
||||
token_to_vector = BiDict()
|
||||
for i, token in enumerate(unique_tokens):
|
||||
binary = format(all_possible[i], f'0{N}b')
|
||||
token_to_vector[token] = binary
|
||||
|
||||
|
||||
token_to_vector = make_binary_tokens(unique_tokens, N=N)
|
||||
return token_to_vector
|
||||
|
||||
def save_encodings(vocab, encoded_stories, stories, filename='encodings.pkl'):
|
||||
@@ -101,11 +121,9 @@ def load_encodings(filename='encodings.pkl'):
|
||||
return data['vocabulary'], data['encoded_stories'], data['original_stories']
|
||||
return None, None, None
|
||||
|
||||
def encode_stories(n_stories=30, force_encode=False, N=12):
|
||||
def encode_stories(n_stories=200, force_encode=False, N=12, batch_size=50):
|
||||
"""
|
||||
Encode the first n stories into N-bit vectors.
|
||||
If encodings exist and force_encode is False, load from file.
|
||||
Otherwise, create new encodings and save them.
|
||||
Encode stories in batches to reduce memory usage.
|
||||
"""
|
||||
if not force_encode:
|
||||
vocab, encoded_stories, stories = load_encodings()
|
||||
@@ -114,46 +132,80 @@ def encode_stories(n_stories=30, force_encode=False, N=12):
|
||||
return vocab, encoded_stories, stories
|
||||
|
||||
ds = load_tinystories()
|
||||
stories = [ds['train'][i]['text'] for i in range(n_stories)]
|
||||
print(stories)
|
||||
# Get vocabulary mapping with specified N
|
||||
vocab = get_vocabulary(stories, N=N)
|
||||
|
||||
# Encode stories
|
||||
# Process stories in batches
|
||||
stories = []
|
||||
encoded_stories = []
|
||||
for story in stories:
|
||||
tokens = tokenize_with_punctuation(story)
|
||||
encoded_tokens = [vocab[token] for token in tokens]
|
||||
encoded_stories.append(encoded_tokens)
|
||||
all_tokens = set()
|
||||
|
||||
# Save the encodings
|
||||
# First pass: collect vocabulary
|
||||
print("Building vocabulary...")
|
||||
for i in tqdm(range(0, n_stories, batch_size)):
|
||||
batch = [ds['train'][j]['text'] for j in range(i, min(i + batch_size, n_stories))]
|
||||
for story in batch:
|
||||
tokens = tokenize_with_punctuation(story)
|
||||
all_tokens.update(tokens)
|
||||
|
||||
# Create vocabulary
|
||||
unique_tokens = sorted(all_tokens)
|
||||
vocab = make_binary_tokens(unique_tokens, N=N)
|
||||
|
||||
# Second pass: encode stories
|
||||
print("Encoding stories...")
|
||||
for i in tqdm(range(0, n_stories, batch_size)):
|
||||
batch = [ds['train'][j]['text'] for j in range(i, min(i + batch_size, n_stories))]
|
||||
|
||||
batch_stories = []
|
||||
batch_encoded = []
|
||||
|
||||
for story in batch:
|
||||
tokens = tokenize_with_punctuation(story)
|
||||
encoded_tokens = [vocab[token] for token in tokens]
|
||||
batch_stories.append(story)
|
||||
batch_encoded.append(encoded_tokens)
|
||||
|
||||
stories.extend(batch_stories)
|
||||
encoded_stories.extend(batch_encoded)
|
||||
|
||||
# Save intermediate results
|
||||
if (i + batch_size) % (batch_size * 4) == 0:
|
||||
save_encodings(vocab, encoded_stories, stories)
|
||||
print(f"Saved progress: {i + batch_size}/{n_stories} stories")
|
||||
|
||||
# Final save
|
||||
save_encodings(vocab, encoded_stories, stories)
|
||||
print("Created and saved new encodings")
|
||||
|
||||
return vocab, encoded_stories, stories
|
||||
|
||||
def get_word_sequences(encoded_stories, M=100, N=12):
|
||||
def get_word_sequences(encoded_stories, M=100, N=12, batch_size=32):
|
||||
"""
|
||||
Get sequences of M consecutive words from encoded stories.
|
||||
Each word is N bits long.
|
||||
Process in batches to reduce memory usage.
|
||||
"""
|
||||
M_N_sequences = []
|
||||
sequences = []
|
||||
|
||||
# Process each story with progress bar
|
||||
for story in tqdm(encoded_stories, desc="Generating sequences"):
|
||||
# Only process if story has enough words
|
||||
if len(story) >= M:
|
||||
# Get groups of M words, shifting by 1 word each time
|
||||
for i in range(len(story) - M + 1):
|
||||
word_group = story[i:i + M]
|
||||
# Convert words to bit array
|
||||
bits = []
|
||||
for word in word_group:
|
||||
bits.extend([int(bit) for bit in word])
|
||||
vector = np.array(bits).reshape(M * N, 1)
|
||||
M_N_sequences.append(vector)
|
||||
# Process stories in batches
|
||||
for i in tqdm(range(0, len(encoded_stories), batch_size), desc="Generating sequences"):
|
||||
batch = encoded_stories[i:i + batch_size]
|
||||
batch_sequences = []
|
||||
|
||||
for story in batch:
|
||||
if len(story) >= M:
|
||||
for j in range(len(story) - M + 1):
|
||||
word_group = story[j:j + M]
|
||||
bits = []
|
||||
for word in word_group:
|
||||
bits.extend([int(bit) for bit in word])
|
||||
vector = np.array(bits).reshape(M * N, 1)
|
||||
batch_sequences.append(vector)
|
||||
|
||||
sequences.extend(batch_sequences)
|
||||
|
||||
# Free memory
|
||||
del batch_sequences
|
||||
|
||||
return np.array(M_N_sequences)
|
||||
return np.array(sequences)
|
||||
|
||||
def sequence_to_words(sequence, N=12):
|
||||
"""
|
||||
@@ -165,19 +217,40 @@ def sequence_to_words(sequence, N=12):
|
||||
words = [''.join(bits[i:i + N]) for i in range(0, len(bits), N)]
|
||||
return words
|
||||
|
||||
def calculate_energy(sequences):
|
||||
def calculate_energy(sequences, batch_size=32):
|
||||
"""
|
||||
Calculate the energy of each sequence.
|
||||
Calculate the energy of sequences using batched processing.
|
||||
Returns energies and Hamiltonian matrix.
|
||||
"""
|
||||
num_sequences = len(sequences)
|
||||
seq_length = sequences[0].shape[0]
|
||||
|
||||
# Initialize Hamiltonian matrix
|
||||
hamiltonian = np.zeros((seq_length, seq_length))
|
||||
energies = []
|
||||
hamiltonian = 0
|
||||
for seq in sequences:
|
||||
energy = -seq.dot(seq.T)/2
|
||||
hamiltonian += energy
|
||||
energies.append(energy)
|
||||
plt.semilogy(-np.linalg.eigvals(hamiltonian), ".")
|
||||
plt.show()
|
||||
return energies, hamiltonian
|
||||
|
||||
# Process sequences in batches
|
||||
for i in tqdm(range(0, num_sequences, batch_size), desc="Calculating energies"):
|
||||
batch = sequences[i:min(i + batch_size, num_sequences)]
|
||||
batch = np.array(batch) # Convert batch to numpy array
|
||||
|
||||
# Calculate batch energies
|
||||
batch_energies = np.sum(batch * batch.transpose(0, 2, 1), axis=(1, 2)) / -2
|
||||
energies.extend(batch_energies)
|
||||
|
||||
# Update Hamiltonian
|
||||
batch_hamiltonian = np.sum(np.matmul(batch, batch.transpose(0, 2, 1)), axis=0)
|
||||
hamiltonian += batch_hamiltonian
|
||||
|
||||
# Free memory
|
||||
del batch
|
||||
del batch_energies
|
||||
del batch_hamiltonian
|
||||
|
||||
# Normalize Hamiltonian
|
||||
hamiltonian = hamiltonian / num_sequences
|
||||
|
||||
return np.array(energies), hamiltonian
|
||||
|
||||
def retrieve_sequences(sequences, partial_sequence, vocab, W, M=10, N=12, temperature=1.0):
|
||||
"""
|
||||
@@ -202,7 +275,7 @@ def retrieve_sequences(sequences, partial_sequence, vocab, W, M=10, N=12, temper
|
||||
|
||||
# Calculate energy using Ising Hamiltonian
|
||||
energy_matrix = complete_vec.T.dot(W).dot(complete_vec)
|
||||
energy = -0.5 * float(energy_matrix[0, 0])
|
||||
energy = float(energy_matrix[0, 0])
|
||||
|
||||
word_energies.append((word, energy))
|
||||
|
||||
@@ -266,14 +339,18 @@ def predict_sequence(initial_sequence, vocab, sequences, W, D=10, M=100, N=12, t
|
||||
return predictions, energies
|
||||
|
||||
if __name__ == "__main__":
|
||||
N = 13 # Define N as a constant
|
||||
M = 10 # Define M as a constant
|
||||
D = 3 # Number of words to predict
|
||||
temperature = 1.0 # Increased temperature for more diversity
|
||||
N = 20 # Define N as a constant
|
||||
M = 20 # Define M as a constant
|
||||
D = 10 # Number of words to predict
|
||||
temperature = 0.10
|
||||
batch_size = 50 # Added batch size parameter
|
||||
|
||||
print("Loading and encoding stories...")
|
||||
# Force new encoding to ensure consistency
|
||||
vocab, encoded_stories, original_stories = encode_stories(force_encode=True, N=N)
|
||||
vocab, encoded_stories, original_stories = encode_stories(
|
||||
force_encode=True,
|
||||
N=N,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
print("\nGenerating training sequences...")
|
||||
# Get sequences for training
|
||||
@@ -303,7 +380,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Print results
|
||||
print("\nOriginal sequence:")
|
||||
print(" ".join(initial_tokens[-10:])) # Last 10 tokens of initial sequence
|
||||
print(" ".join(initial_tokens)) # Last 10 tokens of initial sequence
|
||||
print("\nPredicted sequence:")
|
||||
print(" ".join(predicted_words))
|
||||
print("\nEnergies:")
|
||||
@@ -312,26 +389,3 @@ if __name__ == "__main__":
|
||||
print(" ".join(story_tokens[M-1:M-1+D])) # Next D actual words
|
||||
else:
|
||||
print(f"Story too short. Needs at least {M-1} tokens, but has {len(story_tokens)}")
|
||||
|
||||
# # Print example
|
||||
# print(f"Total vocabulary size: {len(vocab)}")
|
||||
# print("\nExample encoding for first story:")
|
||||
# print("Original:", original_stories[0])
|
||||
# print("First few tokens and their encodings:")
|
||||
# tokens = tokenize_with_punctuation(original_stories[0])
|
||||
# for token, encoding in zip(tokens[:10], encoded_stories[0][:10]):
|
||||
# print(f"'{token}' -> {encoding}")
|
||||
|
||||
# # Get statistics about vector usage
|
||||
# total_unique_in_vocab = len(vocab)
|
||||
# total_unique_used = len(set([vec for story in encoded_stories for vec in story]))
|
||||
# total_vectors = sum(len(story) for story in encoded_stories)
|
||||
|
||||
# print(f"\nTotal unique vectors in vocabulary: {total_unique_in_vocab}")
|
||||
# print(f"Total unique vectors used in stories: {total_unique_used}")
|
||||
# print(f"Total word occurrences: {total_vectors}")
|
||||
# print(encoded_stories[0])
|
||||
|
||||
# print(sequences)
|
||||
# plt.imshow(energies[0])
|
||||
# plt.show()
|
||||
|
||||
Reference in New Issue
Block a user