first
This commit is contained in:
61
predict_story.py
Normal file
61
predict_story.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from load_tinystories import (
|
||||
tokenize_with_punctuation,
|
||||
load_encodings,
|
||||
predict_sequence,
|
||||
get_word_sequences,
|
||||
BiDict
|
||||
)
|
||||
|
||||
def get_user_sentence(min_tokens=99):
    """Prompt on stdin until the user enters a long-enough sentence.

    The raw input is split with ``tokenize_with_punctuation`` and the prompt
    is repeated until the resulting token count reaches ``min_tokens``.

    Args:
        min_tokens: Minimum number of tokens the input must contain.
            Defaults to 99, which matches the ``M - 1`` context length that
            ``main`` passes to the predictor (M = 100).

    Returns:
        The token sequence produced by ``tokenize_with_punctuation`` for the
        first input that is long enough.
    """
    while True:
        sentence = input(f"Enter a sentence (at least {min_tokens} words): ")
        tokens = tokenize_with_punctuation(sentence)
        # The length check is on tokens (as produced by the tokenizer), not
        # on whitespace-separated words.
        if len(tokens) >= min_tokens:
            return tokens
        print(
            f"Sentence too short. Got {len(tokens)} tokens, "
            f"need {min_tokens}."
        )
|
||||
|
||||
def main():
    """Interactively continue a user-supplied story.

    Loads the saved encodings, builds the training sequences, asks the user
    for a sentence, and prints the words returned by ``predict_sequence``
    together with the energy value reported for each of them.

    Returns early (after printing a hint) when ``load_encodings`` yields no
    vocabulary.
    """
    # Load saved encodings; the original stories are not needed here.
    vocab, encoded_stories, _original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Constants
    M = 100  # Sequence length
    N = 13  # Bits per word
    D = 10  # Number of words to predict
    temperature = 1.0

    print("Loading training sequences...")
    sequences = get_word_sequences(encoded_stories=encoded_stories, M=M, N=N)
    print(f"Loaded {len(sequences)} training sequences")

    # Get sentence from user
    print("\nI'll help you continue your story!")
    initial_tokens = get_user_sentence()

    # Condition on the LAST M-1 tokens so that the prediction really is a
    # continuation of where the input ended. Taking the first M-1 tokens
    # would make the "Your input ended with" display below disagree with
    # the context actually used whenever the input is longer than M-1.
    context = initial_tokens[-(M - 1):]

    # Predict next words
    print("\nPredicting next words...")
    predicted_words, energies = predict_sequence(
        context,
        vocab,
        sequences,
        D=D,
        M=M,
        N=N,
        temperature=temperature,
    )

    # Print results
    print("\nYour input ended with:")
    print(" ".join(context[-10:]))  # Last 10 tokens of the context used
    print("\nPredicted continuation:")
    print(" ".join(predicted_words))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predicted_words, energies), start=1):
        print(f"Word {i} ('{word}'): {energy:.4f}")
|
||||
|
||||
# Run the interactive predictor only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user