from load_tinystories import (
|
|
tokenize_with_punctuation,
|
|
load_encodings,
|
|
predict_sequence,
|
|
get_word_sequences,
|
|
BiDict
|
|
)
|
|
|
|
def get_user_sentence(min_tokens=99):
    """Prompt the user for a sentence and validate its token count.

    Loops until the input, tokenized by ``tokenize_with_punctuation``,
    yields at least ``min_tokens`` tokens.

    Args:
        min_tokens: Minimum number of tokens required. Defaults to 99,
            which matches M - 1 for the sequence length M = 100 used in
            ``main`` (the model consumes the first M - 1 tokens).

    Returns:
        The tokenized user input (at least ``min_tokens`` tokens long).
    """
    while True:
        sentence = input(f"Enter a sentence (at least {min_tokens} words): ")
        tokens = tokenize_with_punctuation(sentence)
        if len(tokens) >= min_tokens:
            return tokens
        # Keep prompting until the requirement is met.
        print(f"Sentence too short. Got {len(tokens)} tokens, need {min_tokens}.")
|
|
|
|
def main():
    """Load saved encodings, read a user sentence, and print a predicted continuation."""
    # Load saved encodings
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Model constants
    seq_len = 100       # M: sequence length
    bits_per_word = 13  # N: bits per word
    num_predict = 10    # D: number of words to predict
    sampling_temp = 1.0

    print("Loading training sequences...")
    sequences = get_word_sequences(
        encoded_stories=encoded_stories, M=seq_len, N=bits_per_word
    )
    print(f"Loaded {len(sequences)} training sequences")

    # Get sentence from user
    print("\nI'll help you continue your story!")
    initial_tokens = get_user_sentence()

    # Predict next words
    print("\nPredicting next words...")
    predicted_words, energies = predict_sequence(
        initial_tokens[:seq_len - 1],  # only the first M-1 tokens feed the model
        vocab,
        sequences,
        D=num_predict,
        M=seq_len,
        N=bits_per_word,
        temperature=sampling_temp,
    )

    # Print results
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))  # last 10 tokens of the raw input
    print("\nPredicted continuation:")
    print(" ".join(predicted_words))
    print("\nEnergies of predictions:")
    for rank, (word, energy) in enumerate(zip(predicted_words, energies), start=1):
        print(f"Word {rank} ('{word}'): {energy:.4f}")
|
|
|
|
# Script entry point: run the interactive flow only when executed directly.
if __name__ == "__main__":
    main()