Files
cimai/predict_story.py
alireza 8940a7d7f9 first
2025-02-19 13:45:07 +03:30

61 lines
1.8 KiB
Python

from load_tinystories import (
tokenize_with_punctuation,
load_encodings,
predict_sequence,
get_word_sequences,
BiDict
)
def get_user_sentence(min_tokens: int = 99):
    """Prompt the user for a sentence until it contains enough tokens.

    Each line read from standard input is tokenized with
    ``tokenize_with_punctuation``. The prompt is repeated, after reporting
    the shortfall, until the token count reaches ``min_tokens``.

    Args:
        min_tokens: Minimum number of tokens the input must contain.
            Defaults to 99, the value that used to be hard-coded here.

    Returns:
        The tokens produced by ``tokenize_with_punctuation`` for the first
        input that is long enough.
    """
    while True:
        sentence = input(f"Enter a sentence (at least {min_tokens} words): ")
        tokens = tokenize_with_punctuation(sentence)
        if len(tokens) >= min_tokens:
            return tokens
        # Tell the user how far off they were before asking again.
        print(f"Sentence too short. Got {len(tokens)} tokens, need {min_tokens}.")
def main():
    """Interactively continue a user-supplied story.

    Steps:
        1. Load the saved vocabulary and encoded stories via
           ``load_encodings``; print a hint and return if ``vocab`` is None.
        2. Build the training sequences with ``get_word_sequences``.
        3. Ask the user for a sentence (``get_user_sentence``).
        4. Call ``predict_sequence`` to predict ``D`` further words from the
           last ``M - 1`` tokens of the input.
        5. Print the tail of the input, the predicted words and the energy
           reported for each prediction.

    Returns:
        None.
    """
    # Load saved encodings (the third element, the original stories, is not
    # needed for prediction).
    vocab, encoded_stories, _ = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Constants
    M = 100  # Sequence length
    N = 13  # Bits per word
    D = 10  # Number of words to predict
    temperature = 1.0

    print("Loading training sequences...")
    sequences = get_word_sequences(encoded_stories=encoded_stories, M=M, N=N)
    print(f"Loaded {len(sequences)} training sequences")

    # Get sentence from user
    print("\nI'll help you continue your story!")
    initial_tokens = get_user_sentence()

    # Condition on the LAST M-1 tokens so the predicted continuation follows
    # the end of the user's text -- the same tail that is echoed below as
    # "Your input ended with". Taking the first M-1 tokens would silently
    # ignore everything past token M-1 for longer inputs. For an input of
    # exactly M-1 tokens both slices are identical.
    context_tokens = initial_tokens[-(M - 1):]

    # Predict next words
    print("\nPredicting next words...")
    predicted_words, energies = predict_sequence(
        context_tokens,
        vocab,
        sequences,
        D=D,
        M=M,
        N=N,
        temperature=temperature,
    )

    # Print results
    print("\nYour input ended with:")
    print(" ".join(context_tokens[-10:]))  # Last 10 tokens of the context
    print("\nPredicted continuation:")
    print(" ".join(predicted_words))
    print("\nEnergies of predictions:")
    for position, (word, energy) in enumerate(zip(predicted_words, energies), start=1):
        print(f"Word {position} ('{word}'): {energy:.4f}")
# Run the interactive demo only when this file is executed as a script,
# so that importing the module has no side effects.
if __name__ == "__main__":
    main()