diff --git a/load_tinystories.py b/load_tinystories.py
index 5b11797..56e3a88 100644
--- a/load_tinystories.py
+++ b/load_tinystories.py
@@ -16,10 +16,22 @@ class BiDict:
         self.vec_to_word = {}
 
     def __setitem__(self, word, vector):
-        self.word_to_vec[word] = vector
-        self.vec_to_word[vector] = word
+        # Convert numpy array to tuple for hashing
+        if isinstance(vector, np.ndarray):
+            vector_tuple = tuple(vector.flatten())
+        else:
+            vector_tuple = tuple(vector)
+
+        # Convert vector to string of 1s and 0s
+        vector_str = ''.join(str(int(x)) for x in vector_tuple)
+
+        self.word_to_vec[word] = vector_str
+        self.vec_to_word[vector_str] = word
 
     def __getitem__(self, key):
+        # If key is a numpy array, convert to string
+        if isinstance(key, np.ndarray):
+            key = ''.join(str(int(x)) for x in key.flatten())
         # Try word_to_vec first, then vec_to_word
         return self.word_to_vec.get(key) or self.vec_to_word.get(key)
 
@@ -52,6 +64,22 @@ def tokenize_with_punctuation(text):
     # Filter out empty strings and pure whitespace, but keep punctuation
     return [token for token in tokens if token.strip() or token in '.,!?;:"\'()[]{}']
 
+def make_binary_tokens(unique_tokens, N=12):
+    """
+    Create binary vectors for tokens.
+    Each vector is N bits long, containing only 0s and 1s.
+    """
+    # Generate random binary vectors (0s and 1s only)
+    codes = np.random.randint(0, 2, size=(len(unique_tokens), N))
+
+    token_to_vector = BiDict()
+    for i, w in enumerate(unique_tokens):
+        # Convert to string of 0s and 1s directly
+        binary_str = ''.join(str(int(x)) for x in codes[i])
+        token_to_vector[w] = binary_str
+    return token_to_vector
+
+
 def get_vocabulary(stories, N=12):
     """
     Create vocabulary from the given stories.
@@ -62,7 +90,6 @@ def get_vocabulary(stories, N=12):
     for story in stories:
         tokens = tokenize_with_punctuation(story)
         all_tokens.update(tokens)
-
     # Sort tokens for consistent encoding
     unique_tokens = sorted(all_tokens)
 
@@ -72,15 +99,8 @@
         raise ValueError(f"Vocabulary size ({num_tokens}) exceeds {N}-bit capacity ({2**N})")
 
     # Generate all possible N-bit numbers
-    all_possible = list(range(2**N))
-    np.random.shuffle(all_possible)
-
-    # Create unique random binary numbers for each token
-    token_to_vector = BiDict()
-    for i, token in enumerate(unique_tokens):
-        binary = format(all_possible[i], f'0{N}b')
-        token_to_vector[token] = binary
-
+
+    token_to_vector = make_binary_tokens(unique_tokens, N=N)
     return token_to_vector
 
 def save_encodings(vocab, encoded_stories, stories, filename='encodings.pkl'):
@@ -101,11 +121,9 @@ def load_encodings(filename='encodings.pkl'):
             return data['vocabulary'], data['encoded_stories'], data['original_stories']
     return None, None, None
 
-def encode_stories(n_stories=30, force_encode=False, N=12):
+def encode_stories(n_stories=200, force_encode=False, N=12, batch_size=50):
     """
-    Encode the first n stories into N-bit vectors.
-    If encodings exist and force_encode is False, load from file.
-    Otherwise, create new encodings and save them.
+    Encode stories in batches to reduce memory usage.
""" if not force_encode: vocab, encoded_stories, stories = load_encodings() @@ -114,46 +132,80 @@ def encode_stories(n_stories=30, force_encode=False, N=12): return vocab, encoded_stories, stories ds = load_tinystories() - stories = [ds['train'][i]['text'] for i in range(n_stories)] - print(stories) - # Get vocabulary mapping with specified N - vocab = get_vocabulary(stories, N=N) - # Encode stories + # Process stories in batches + stories = [] encoded_stories = [] - for story in stories: - tokens = tokenize_with_punctuation(story) - encoded_tokens = [vocab[token] for token in tokens] - encoded_stories.append(encoded_tokens) + all_tokens = set() - # Save the encodings + # First pass: collect vocabulary + print("Building vocabulary...") + for i in tqdm(range(0, n_stories, batch_size)): + batch = [ds['train'][j]['text'] for j in range(i, min(i + batch_size, n_stories))] + for story in batch: + tokens = tokenize_with_punctuation(story) + all_tokens.update(tokens) + + # Create vocabulary + unique_tokens = sorted(all_tokens) + vocab = make_binary_tokens(unique_tokens, N=N) + + # Second pass: encode stories + print("Encoding stories...") + for i in tqdm(range(0, n_stories, batch_size)): + batch = [ds['train'][j]['text'] for j in range(i, min(i + batch_size, n_stories))] + + batch_stories = [] + batch_encoded = [] + + for story in batch: + tokens = tokenize_with_punctuation(story) + encoded_tokens = [vocab[token] for token in tokens] + batch_stories.append(story) + batch_encoded.append(encoded_tokens) + + stories.extend(batch_stories) + encoded_stories.extend(batch_encoded) + + # Save intermediate results + if (i + batch_size) % (batch_size * 4) == 0: + save_encodings(vocab, encoded_stories, stories) + print(f"Saved progress: {i + batch_size}/{n_stories} stories") + + # Final save save_encodings(vocab, encoded_stories, stories) print("Created and saved new encodings") return vocab, encoded_stories, stories -def get_word_sequences(encoded_stories, M=100, N=12): +def get_word_sequences(encoded_stories, M=100, N=12, batch_size=32): """ Get sequences of M consecutive words from encoded stories. - Each word is N bits long. + Process in batches to reduce memory usage. 
""" - M_N_sequences = [] + sequences = [] - # Process each story with progress bar - for story in tqdm(encoded_stories, desc="Generating sequences"): - # Only process if story has enough words - if len(story) >= M: - # Get groups of M words, shifting by 1 word each time - for i in range(len(story) - M + 1): - word_group = story[i:i + M] - # Convert words to bit array - bits = [] - for word in word_group: - bits.extend([int(bit) for bit in word]) - vector = np.array(bits).reshape(M * N, 1) - M_N_sequences.append(vector) + # Process stories in batches + for i in tqdm(range(0, len(encoded_stories), batch_size), desc="Generating sequences"): + batch = encoded_stories[i:i + batch_size] + batch_sequences = [] + + for story in batch: + if len(story) >= M: + for j in range(len(story) - M + 1): + word_group = story[j:j + M] + bits = [] + for word in word_group: + bits.extend([int(bit) for bit in word]) + vector = np.array(bits).reshape(M * N, 1) + batch_sequences.append(vector) + + sequences.extend(batch_sequences) + + # Free memory + del batch_sequences - return np.array(M_N_sequences) + return np.array(sequences) def sequence_to_words(sequence, N=12): """ @@ -165,19 +217,40 @@ def sequence_to_words(sequence, N=12): words = [''.join(bits[i:i + N]) for i in range(0, len(bits), N)] return words -def calculate_energy(sequences): +def calculate_energy(sequences, batch_size=32): """ - Calculate the energy of each sequence. + Calculate the energy of sequences using batched processing. + Returns energies and Hamiltonian matrix. """ + num_sequences = len(sequences) + seq_length = sequences[0].shape[0] + + # Initialize Hamiltonian matrix + hamiltonian = np.zeros((seq_length, seq_length)) energies = [] - hamiltonian = 0 - for seq in sequences: - energy = -seq.dot(seq.T)/2 - hamiltonian += energy - energies.append(energy) - plt.semilogy(-np.linalg.eigvals(hamiltonian), ".") - plt.show() - return energies, hamiltonian + + # Process sequences in batches + for i in tqdm(range(0, num_sequences, batch_size), desc="Calculating energies"): + batch = sequences[i:min(i + batch_size, num_sequences)] + batch = np.array(batch) # Convert batch to numpy array + + # Calculate batch energies + batch_energies = np.sum(batch * batch.transpose(0, 2, 1), axis=(1, 2)) / -2 + energies.extend(batch_energies) + + # Update Hamiltonian + batch_hamiltonian = np.sum(np.matmul(batch, batch.transpose(0, 2, 1)), axis=0) + hamiltonian += batch_hamiltonian + + # Free memory + del batch + del batch_energies + del batch_hamiltonian + + # Normalize Hamiltonian + hamiltonian = hamiltonian / num_sequences + + return np.array(energies), hamiltonian def retrieve_sequences(sequences, partial_sequence, vocab, W, M=10, N=12, temperature=1.0): """ @@ -202,7 +275,7 @@ def retrieve_sequences(sequences, partial_sequence, vocab, W, M=10, N=12, temper # Calculate energy using Ising Hamiltonian energy_matrix = complete_vec.T.dot(W).dot(complete_vec) - energy = -0.5 * float(energy_matrix[0, 0]) + energy = float(energy_matrix[0, 0]) word_energies.append((word, energy)) @@ -266,14 +339,18 @@ def predict_sequence(initial_sequence, vocab, sequences, W, D=10, M=100, N=12, t return predictions, energies if __name__ == "__main__": - N = 13 # Define N as a constant - M = 10 # Define M as a constant - D = 3 # Number of words to predict - temperature = 1.0 # Increased temperature for more diversity + N = 20 # Define N as a constant + M = 20 # Define M as a constant + D = 10 # Number of words to predict + temperature = 0.10 + batch_size = 50 # Added batch size 
 
     print("Loading and encoding stories...")
-    # Force new encoding to ensure consistency
-    vocab, encoded_stories, original_stories = encode_stories(force_encode=True, N=N)
+    vocab, encoded_stories, original_stories = encode_stories(
+        force_encode=True,
+        N=N,
+        batch_size=batch_size
+    )
 
     print("\nGenerating training sequences...")
     # Get sequences for training
@@ -303,7 +380,7 @@ if __name__ == "__main__":
 
     # Print results
     print("\nOriginal sequence:")
-    print(" ".join(initial_tokens[-10:]))  # Last 10 tokens of initial sequence
+    print(" ".join(initial_tokens))  # Full initial sequence
     print("\nPredicted sequence:")
    print(" ".join(predicted_words))
     print("\nEnergies:")
@@ -312,26 +389,3 @@ if __name__ == "__main__":
         print(" ".join(story_tokens[M-1:M-1+D]))  # Next D actual words
     else:
         print(f"Story too short. Needs at least {M-1} tokens, but has {len(story_tokens)}")
-
-    # # Print example
-    # print(f"Total vocabulary size: {len(vocab)}")
-    # print("\nExample encoding for first story:")
-    # print("Original:", original_stories[0])
-    # print("First few tokens and their encodings:")
-    # tokens = tokenize_with_punctuation(original_stories[0])
-    # for token, encoding in zip(tokens[:10], encoded_stories[0][:10]):
-    #     print(f"'{token}' -> {encoding}")
-
-    # # Get statistics about vector usage
-    # total_unique_in_vocab = len(vocab)
-    # total_unique_used = len(set([vec for story in encoded_stories for vec in story]))
-    # total_vectors = sum(len(story) for story in encoded_stories)
-
-    # print(f"\nTotal unique vectors in vocabulary: {total_unique_in_vocab}")
-    # print(f"Total unique vectors used in stories: {total_unique_used}")
-    # print(f"Total word occurrences: {total_vectors}")
-    # print(encoded_stories[0])
-
-    # print(sequences)
-    # plt.imshow(energies[0])
-    # plt.show()
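
Note on the new encoding path: make_binary_tokens draws each N-bit code independently with np.random.randint, so unlike the removed range(2**N) + shuffle approach it does not guarantee one unique code per token, and a duplicate code would silently map two words onto the same BiDict entry. Below is a minimal sketch of a post-patch sanity check, not part of the patch itself; it assumes the patched file is importable as load_tinystories with its dependencies installed, and the sample sentence and the V = 5000 vocabulary size are illustrative assumptions only.

# Sketch only: exercises the encoding path added in this patch.
# Assumes load_tinystories is importable; sample text and V are hypothetical.
import numpy as np

from load_tinystories import make_binary_tokens, tokenize_with_punctuation

N = 20
tokens = sorted(set(tokenize_with_punctuation("Once upon a time, there was a little dog.")))
vocab = make_binary_tokens(tokens, N=N)

# Round-trip through BiDict: word -> N-bit string -> word.
for word in tokens:
    code = vocab[word]
    assert len(code) == N and set(code) <= {"0", "1"}
    assert vocab[code] == word, f"duplicate code or lookup failure for {word!r}"

# Independently drawn codes are not guaranteed unique; birthday-bound estimate
# of the chance that at least two of V tokens share the same N-bit code.
V = 5000
p_collision = 1 - np.exp(-V * (V - 1) / (2 * 2 ** N))
print(f"P(collision) ~ {p_collision:.2%} for V={V}, N={N}")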