import jax
import jax.numpy as jnp
from functools import partial

from load_tinystories import (
    tokenize_with_punctuation,
    load_encodings,
    BiDict,
)
from tqdm import tqdm
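
# Assumed interface of load_tinystories (inferred from usage below):
# `load_encodings()` returns (vocab, encoded_stories, original_stories),
# where `vocab` is a BiDict mapping each word to an N-bit string and each
# encoded story is a list of such bit strings.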


@partial(jax.jit, static_argnums=(1, 2))
def calculate_weights(sequences, M, N):
    """
    Calculate the weight matrix using JAX for GPU acceleration.
    """
    # Stack sequences into a single (num_sequences, M*N, 1) array
    sequences_jax = jnp.array(sequences)

    # Hebbian rule: sum the outer products x @ x.T over all sequences
    W = jnp.sum(jnp.matmul(sequences_jax, sequences_jax.transpose(0, 2, 1)), axis=0)

    # Normalize by the number of training sequences
    W = W / len(sequences)
    return W
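
# Toy example of the Hebbian rule above (a sketch; values illustrative):
# two sequences of M=2 words with N=2 bits each give a 4x4 matrix that
# averages their outer products.
#
#   toy = jnp.array([[[1], [0], [1], [0]],
#                    [[0], [1], [0], [1]]])
#   W_toy = calculate_weights(toy, 2, 2)   # shape (4, 4)
#   # W_toy[0, 2] == 0.5: bits 0 and 2 co-occur in half the sequences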


@jax.jit
def calculate_energies(complete_sequences, W, temperature):
    """
    Calculate energies for all possible completions using JAX.
    """
    # Quadratic energy E(x) = -0.5 * x^T W x per candidate:
    # (num, 1, M*N) @ (M*N, M*N) @ (num, M*N, 1) -> (num, 1, 1)
    energies = -0.5 * jnp.matmul(
        jnp.matmul(complete_sequences.transpose(0, 2, 1), W),
        complete_sequences
    ).reshape(-1)

    # Normalize energies to [0, 1]
    energies = energies - jnp.min(energies)
    energies = energies / (jnp.max(energies) + 1e-10)

    # Boltzmann distribution over candidates
    probabilities = jnp.exp(-energies / temperature)
    probabilities = probabilities / jnp.sum(probabilities)

    return energies, probabilities
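
# Temperature intuition (illustrative numbers): after min-max scaling the
# energies lie in [0, 1], so exp(-E/T) spreads candidates by at most a
# factor of e^(1/T). For example:
#
#   E = [0.0, 0.5, 1.0], T = 0.5
#   exp(-E/T) = [1.000, 0.368, 0.135]  ->  p ~= [0.665, 0.245, 0.090]
#
# Small T approaches greedy selection; large T approaches uniform sampling.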


def get_word_sequences_gpu(encoded_stories, M=100, N=12):
    """
    Get all windows of M consecutive words as (M*N, 1) bit vectors.
    """
    sequences = []

    for story in tqdm(encoded_stories, desc="Generating sequences"):
        if len(story) >= M:
            for i in range(len(story) - M + 1):
                word_group = story[i:i + M]
                bits = []
                for word in word_group:
                    bits.extend([int(bit) for bit in word])
                sequences.append(bits)

    # Single host-to-device transfer: shape (num_sequences, M*N, 1)
    return jnp.array(sequences).reshape(-1, M * N, 1)
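
# Scale note (illustrative arithmetic): a story of length L contributes
# max(0, L - M + 1) windows, each an (M*N, 1) vector. With M=100, N=12,
# 1,000 stories of 150 words give 51 * 1,000 = 51,000 windows, i.e. a
# (51000, 1200, 1) array of roughly 245 MB at 4 bytes per element.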


def retrieve_sequences_gpu(sequences, partial_sequence, vocab, M=100, N=12, temperature=1.0, key=None):
    """
    GPU-accelerated version of sequence retrieval using JAX.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    # Calculate weight matrix
    W = calculate_weights(sequences, M, N)

    # Build complete sequences for all candidate words, keeping the
    # candidate list aligned with the rows of `complete_sequences` so the
    # sampled index maps back to the right word vector
    valid_words = []
    complete_sequences = []
    for word_vector in vocab.values():
        complete_sequence = partial_sequence + word_vector
        if len(complete_sequence) == M * N:
            complete_vec = jnp.array([int(bit) for bit in complete_sequence]).reshape(M * N, 1)
            complete_sequences.append(complete_vec)
            valid_words.append(word_vector)

    complete_sequences = jnp.array(complete_sequences)

    # Calculate energies and probabilities
    energies, probabilities = calculate_energies(complete_sequences, W, temperature)

    # Sample a candidate index from the Boltzmann distribution
    selected_idx = int(jax.random.choice(key, len(valid_words), p=probabilities))

    best_word = valid_words[selected_idx]
    min_energy = float(energies[selected_idx])

    # Find the word corresponding to the selected bit vector
    for word, vector in vocab.items():
        if vector == best_word:
            return word, best_word, min_energy
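
# Usage sketch (hypothetical values; assumes `vocab` maps words to N-bit
# strings and `tokens` holds at least M-1 context words):
#
#   context_bits = "".join(vocab[t] for t in tokens[-(100 - 1):])
#   word, bits, e = retrieve_sequences_gpu(
#       sequences, context_bits, vocab, M=100, N=12,
#       temperature=0.5, key=jax.random.PRNGKey(42),
#   )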


def predict_sequence_gpu(initial_sequence, vocab, sequences, D=10, M=100, N=12, temperature=1.0):
    """
    GPU-accelerated version of sequence prediction.
    """
    current_tokens = initial_sequence.copy()
    predictions = []
    energies = []
    key = jax.random.PRNGKey(0)

    for _ in tqdm(range(D), desc="Predicting words"):
        # Encode the current (M-1)-token context as one bit string
        partial_sequence = "".join(vocab[token] for token in current_tokens)

        # Fresh subkey per step so each draw uses different randomness
        key, subkey = jax.random.split(key)
        predicted_word, _, energy = retrieve_sequences_gpu(
            sequences,
            partial_sequence,
            vocab,
            M=M,
            N=N,
            temperature=temperature,
            key=subkey
        )

        predictions.append(predicted_word)
        energies.append(energy)

        # Slide the window: drop the oldest token, append the prediction
        current_tokens = current_tokens[1:] + [predicted_word]

    return predictions, energies
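
# Window behaviour (illustrative): with M=4 and context ["a", "b", "c"],
# predicting "d" shifts the context to ["b", "c", "d"], so every step
# conditions on the M-1 most recent tokens only.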


def main():
    # Load saved encodings
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Constants
    M = 100
    N = 13
    D = 10
    temperature = 1.0

    print("Loading training sequences...")
    sequences = get_word_sequences_gpu(encoded_stories=encoded_stories, M=M, N=N)
    print(f"Loaded {len(sequences)} training sequences")

    # Get sentence from user
    print("\nEnter your story:")
    sentence = input(f"Enter a sentence (at least {M - 1} words): ")
    initial_tokens = tokenize_with_punctuation(sentence)

    if len(initial_tokens) < M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {M - 1}.")
        return

    # Predict next words, conditioning on the last M-1 tokens of the input
    print("\nPredicting next words...")
    predicted_words, energies = predict_sequence_gpu(
        initial_tokens[-(M - 1):],
        vocab,
        sequences,
        D=D,
        M=M,
        N=N,
        temperature=temperature
    )

    # Print results
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predicted_words))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predicted_words, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")


if __name__ == "__main__":
    main()