import jax
import jax.numpy as jnp
from functools import partial
from load_tinystories import (
    tokenize_with_punctuation,
    load_encodings,
    BiDict
)
from tqdm import tqdm


# NOTE: static_argnums must index this function's own parameters.
# The original used (3, 4), which is out of range for a 3-arg function
# and raises at call time; M and N are parameters 1 and 2.
@partial(jax.jit, static_argnums=(1, 2))
def calculate_weights(sequences, M, N):
    """
    Calculate the Hopfield-style weight matrix using JAX for GPU acceleration.

    Args:
        sequences: array of shape (num_sequences, M*N, 1) of {0,1} bit vectors.
        M: words per sequence (static for jit; kept for interface parity).
        N: bits per word (static for jit; kept for interface parity).

    Returns:
        (M*N, M*N) matrix: the mean of the outer products s @ s.T over all
        training sequences.
    """
    sequences_jax = jnp.asarray(sequences)
    # Sum of outer products over the batch, then normalize by batch size.
    W = jnp.sum(
        jnp.matmul(sequences_jax, sequences_jax.transpose(0, 2, 1)),
        axis=0,
    )
    return W / sequences_jax.shape[0]


# The original decorator used static_argnums=(3, 4, 5) — out of range for a
# 3-arg function. temperature is an ordinary traced scalar, so plain jit works.
@jax.jit
def calculate_energies(complete_sequences, W, temperature):
    """
    Calculate energies and sampling probabilities for candidate completions.

    Args:
        complete_sequences: (K, M*N, 1) batch of candidate bit vectors.
        W: (M*N, M*N) weight matrix from calculate_weights.
        temperature: softmax temperature for converting energies to
            probabilities (lower = greedier).

    Returns:
        (energies, probabilities): two (K,) arrays; energies are min-max
        normalized to [0, 1], probabilities sum to 1.
    """
    # Hopfield energy E(s) = -1/2 * s^T W s, computed as a batched
    # quadratic form. The original multiplied (K, M*N, 1) @ (M*N, M*N),
    # which is dimensionally invalid; the transpose must come first.
    quad = jnp.matmul(
        jnp.matmul(complete_sequences.transpose(0, 2, 1), W),
        complete_sequences,
    )  # shape (K, 1, 1)
    energies = -0.5 * quad.reshape(-1)

    # Min-max normalize so the temperature scale is input-independent.
    energies = energies - jnp.min(energies)
    energies = energies / (jnp.max(energies) + 1e-10)

    # Boltzmann distribution over candidates.
    probabilities = jnp.exp(-energies / temperature)
    probabilities = probabilities / jnp.sum(probabilities)
    return energies, probabilities


def get_word_sequences_gpu(encoded_stories, M=100, N=12):
    """
    Build all length-M sliding windows of encoded words as bit vectors.

    Args:
        encoded_stories: iterable of stories; each story is a sequence of
            N-character bit strings (one per word).
        M: window length in words.
        N: bits per word.

    Returns:
        jnp array of shape (num_windows, M*N, 1).
    """
    windows = []
    for story in tqdm(encoded_stories, desc="Generating sequences"):
        if len(story) < M:
            continue
        for i in range(len(story) - M + 1):
            # Flatten the M words of the window into one bit list.
            bits = [int(bit) for word in story[i:i + M] for bit in word]
            windows.append(bits)
    # Single host->device transfer at the end instead of one per window.
    return jnp.array(windows, dtype=jnp.int32).reshape(-1, M * N, 1)


def retrieve_sequences_gpu(sequences, partial_sequence, vocab,
                           M=100, N=12, temperature=1.0):
    """
    Sample the next word completing `partial_sequence`, GPU-accelerated.

    Args:
        sequences: (num_windows, M*N, 1) training windows.
        partial_sequence: bit string of the (M-1) context words, length (M-1)*N.
        vocab: BiDict mapping token -> N-bit string.
        M: words per full sequence.
        N: bits per word.
        temperature: sampling temperature.

    Returns:
        (token, bit_string, energy) for the sampled word; token is None if
        the winning bit string has no vocab entry (should not happen).
    """
    W = calculate_weights(sequences, M, N)

    # Keep candidate words aligned with their completed vectors. The original
    # sampled an index over ALL vocab entries while `probabilities` only
    # covered the length-filtered completions — a shape/alignment bug.
    candidate_words = []
    complete_vecs = []
    for word_bits in vocab.values():
        complete_sequence = partial_sequence + word_bits
        if len(complete_sequence) == M * N:
            candidate_words.append(word_bits)
            complete_vecs.append([int(bit) for bit in complete_sequence])
    complete_sequences = jnp.array(complete_vecs, dtype=jnp.int32).reshape(-1, M * N, 1)

    energies, probabilities = calculate_energies(complete_sequences, W, temperature)

    # NOTE(review): a fixed PRNGKey(0) makes every call sample identically —
    # presumably intended for reproducibility; confirm with the author.
    # int(...) is required: Python lists cannot be indexed by a JAX scalar.
    selected_idx = int(jax.random.choice(
        jax.random.PRNGKey(0),
        len(candidate_words),
        p=probabilities,
    ))
    best_word = candidate_words[selected_idx]
    min_energy = float(energies[selected_idx])

    # Map the winning bit string back to its token.
    for token, bits in vocab.items():
        if bits == best_word:
            return token, best_word, min_energy
    # Explicit fallback: the original fell through returning bare None,
    # which would crash the caller's 3-tuple unpack.
    return None, best_word, min_energy


def predict_sequence_gpu(initial_sequence, vocab, sequences,
                         D=10, M=100, N=12, temperature=1.0):
    """
    Autoregressively predict D words following `initial_sequence`.

    Args:
        initial_sequence: list of M-1 context tokens.
        vocab: BiDict mapping token -> N-bit string.
        sequences: training windows from get_word_sequences_gpu.
        D: number of words to predict.
        M: words per full sequence.
        N: bits per word.
        temperature: sampling temperature.

    Returns:
        (predictions, energies): D predicted tokens and their energies.
    """
    current_tokens = initial_sequence.copy()
    predictions = []
    energies = []
    for _ in tqdm(range(D), desc="Predicting words"):
        # Concatenate the context tokens' bit encodings.
        partial_sequence = "".join(vocab[token] for token in current_tokens)
        predicted_word, _, energy = retrieve_sequences_gpu(
            sequences, partial_sequence, vocab,
            M=M, N=N, temperature=temperature,
        )
        predictions.append(predicted_word)
        energies.append(energy)
        # Slide the context window: drop the oldest token, append the new one.
        current_tokens = current_tokens[1:] + [predicted_word]
    return predictions, energies


def main():
    """Load encodings, build training windows, and predict a continuation."""
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Constants
    M = 100
    N = 13
    D = 10
    temperature = 1.0

    print("Loading training sequences...")
    sequences = get_word_sequences_gpu(encoded_stories=encoded_stories, M=M, N=N)
    print(f"Loaded {len(sequences)} training sequences")

    # Get sentence from user
    print("\nEnter your story:")
    sentence = input("Enter a sentence (at least 99 words): ")
    initial_tokens = tokenize_with_punctuation(sentence)
    if len(initial_tokens) < M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {M-1}.")
        return

    # Predict next words
    print("\nPredicting next words...")
    predicted_words, energies = predict_sequence_gpu(
        initial_tokens[:M - 1], vocab, sequences,
        D=D, M=M, N=N, temperature=temperature,
    )

    # Print results
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predicted_words))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predicted_words, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")


if __name__ == "__main__":
    main()