from load_tinystories import (
    tokenize_with_punctuation,
    load_encodings,
    predict_sequence,
    get_word_sequences,
    BiDict,
)


def get_user_sentence(min_tokens=99):
    """Prompt the user for a sentence until it tokenizes to enough tokens.

    Args:
        min_tokens: Minimum number of tokens required. Defaults to 99 so a
            caller using sequence length M=100 gets M-1 context tokens.
            NOTE: tokenize_with_punctuation counts punctuation as tokens,
            so the "words" wording in the prompt is approximate.

    Returns:
        The tokenized sentence, at least ``min_tokens`` tokens long.
    """
    while True:
        sentence = input(f"Enter a sentence (at least {min_tokens} words): ")
        tokens = tokenize_with_punctuation(sentence)
        if len(tokens) >= min_tokens:
            return tokens
        print(f"Sentence too short. Got {len(tokens)} tokens, need {min_tokens}.")


def main():
    """Load saved encodings, read a seed sentence from the user, and print
    the predicted continuation with per-word energies."""
    # Load encodings produced by a prior run of load_tinystories.py.
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Model hyperparameters.
    M = 100            # Sequence length (context window)
    N = 13             # Bits per word
    D = 10             # Number of words to predict
    temperature = 1.0  # Sampling temperature for prediction

    print("Loading training sequences...")
    sequences = get_word_sequences(encoded_stories=encoded_stories, M=M, N=N)
    print(f"Loaded {len(sequences)} training sequences")

    # Require M-1 seed tokens so the model has a full context window
    # (the M-th slot is the word being predicted).
    print("\nI'll help you continue your story!")
    initial_tokens = get_user_sentence(min_tokens=M - 1)

    # Predict the next D words from the first M-1 tokens of the seed.
    print("\nPredicting next words...")
    predicted_words, energies = predict_sequence(
        initial_tokens[:M - 1],  # Use first M-1 tokens as context
        vocab,
        sequences,
        D=D,
        M=M,
        N=N,
        temperature=temperature,
    )

    # Print results.
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))  # Last 10 tokens of the seed

    print("\nPredicted continuation:")
    print(" ".join(predicted_words))

    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predicted_words, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")


if __name__ == "__main__":
    main()