Files
cimai/spin_glass_gpu.py
alireza 8940a7d7f9 first
2025-02-19 13:45:07 +03:30

185 lines
5.6 KiB
Python

import jax
import jax.numpy as jnp
from functools import partial
from load_tinystories import (
tokenize_with_punctuation,
load_encodings,
BiDict
)
from tqdm import tqdm
# NOTE: the original decorator used static_argnums=(3, 4), which is out of
# range for a 3-parameter function and makes JAX raise on the first call.
# M and N (indices 1, 2) are plain ints, so they are the static arguments.
@partial(jax.jit, static_argnums=(1, 2))
def calculate_weights(sequences, M, N):
    """
    Calculate the Hebbian weight matrix using JAX for GPU acceleration.

    Args:
        sequences: (num_sequences, M*N, 1) array of training bit vectors.
        M: words per window (not used in the computation itself; kept for
            the caller-facing interface and marked static under jit).
        N: bits per word (likewise unused but kept for the interface).

    Returns:
        (M*N, M*N) weight matrix: the mean outer product x @ x.T over the
        training vectors.
    """
    sequences_jax = jnp.asarray(sequences)
    # Sum the per-sample outer products over the batch dimension.
    W = jnp.sum(jnp.matmul(sequences_jax, sequences_jax.transpose(0, 2, 1)), axis=0)
    # Normalize by the number of training windows.
    return W / sequences_jax.shape[0]
# NOTE: the original decorator used static_argnums=(3, 4, 5), out of range
# for a 3-parameter function (JAX raises on call). All three arguments are
# traceable (two arrays and a float), so a plain jit is sufficient.
@jax.jit
def calculate_energies(complete_sequences, W, temperature):
    """
    Calculate energies and sampling probabilities for all candidate completions.

    Args:
        complete_sequences: (K, M*N, 1) array of candidate bit vectors.
        W: (M*N, M*N) weight matrix from calculate_weights.
        temperature: Boltzmann temperature for the probability softmax.

    Returns:
        (energies, probabilities): two (K,) arrays. Energies are shifted to a
        minimum of 0 and scaled into [0, 1]; probabilities are the normalized
        exp(-energy / temperature).
    """
    # Quadratic form x^T W x per candidate. The original computed
    # matmul(complete_sequences, W), which is shape-incompatible
    # ((K, M*N, 1) @ (M*N, M*N)); the row vector must come first.
    quad = jnp.matmul(
        complete_sequences.transpose(0, 2, 1),
        jnp.matmul(W, complete_sequences),
    )
    energies = -0.5 * quad.reshape(-1)
    # Normalize energies into [0, 1] so temperature has a consistent scale.
    energies = energies - jnp.min(energies)
    energies = energies / (jnp.max(energies) + 1e-10)
    # Boltzmann probabilities over the candidates.
    probabilities = jnp.exp(-energies / temperature)
    probabilities = probabilities / jnp.sum(probabilities)
    return energies, probabilities
def get_word_sequences_gpu(encoded_stories, M=100, N=12):
    """
    Get all length-M sliding windows of encoded words as bit vectors.

    Args:
        encoded_stories: iterable of stories, each a list of N-bit code
            strings (one per word).
        M: words per window.
        N: bits per word code.

    Returns:
        (num_windows, M*N, 1) integer array; one flattened bit vector per
        window. Returns a well-shaped empty (0, M*N, 1) array when no story
        is long enough.
    """
    # Collect plain Python bit lists and convert to a device array once at
    # the end — the original built a jnp array per window inside the loop,
    # causing one host->device transfer per window.
    windows = []
    for story in tqdm(encoded_stories, desc="Generating sequences"):
        if len(story) < M:
            continue
        for i in range(len(story) - M + 1):
            bits = [int(bit) for word in story[i:i + M] for bit in word]
            windows.append(bits)
    if not windows:
        return jnp.zeros((0, M * N, 1), dtype=jnp.int32)
    return jnp.array(windows).reshape(len(windows), M * N, 1)
def retrieve_sequences_gpu(sequences, partial_sequence, vocab, M=100, N=12, temperature=1.0, key=None):
    """
    GPU-accelerated retrieval of the next word for a partial bit sequence.

    Appends each vocabulary code to partial_sequence, scores the valid
    completions with the spin-glass energy, and samples one according to the
    Boltzmann probabilities.

    Args:
        sequences: (num_windows, M*N, 1) array of training bit vectors.
        partial_sequence: bit string encoding the current (M-1)-word context.
        vocab: mapping of word -> N-bit code string (presumably a BiDict from
            load_encodings — confirm against caller).
        M: words per window.
        N: bits per word code.
        temperature: sampling temperature passed to calculate_energies.
        key: optional jax PRNG key. Defaults to PRNGKey(0), preserving the
            original deterministic behavior for existing callers.

    Returns:
        (word, code, energy): the sampled word token, its bit-string code,
        and the (normalized) energy of the chosen completion.

    Raises:
        ValueError: if no vocabulary code completes the context to M*N bits.
    """
    # Keep tokens, codes, and bit vectors aligned by index. The original
    # sampled an index into the FULL vocabulary using probabilities computed
    # over the FILTERED candidate list, which mismatches whenever any
    # completion is not exactly M*N bits long.
    candidate_words = []
    candidate_codes = []
    candidate_bits = []
    for word, code in vocab.items():
        complete_sequence = partial_sequence + code
        if len(complete_sequence) == M * N:
            candidate_words.append(word)
            candidate_codes.append(code)
            candidate_bits.append([int(bit) for bit in complete_sequence])
    if not candidate_words:
        raise ValueError("No vocabulary code completes the partial sequence to M*N bits.")
    complete_sequences = jnp.array(candidate_bits).reshape(len(candidate_bits), M * N, 1)
    # Weight matrix from the training windows.
    W = calculate_weights(sequences, M, N)
    # Energies and Boltzmann probabilities for every candidate completion.
    energies, probabilities = calculate_energies(complete_sequences, W, temperature)
    if key is None:
        key = jax.random.PRNGKey(0)  # original behavior: fixed seed
    selected_idx = int(jax.random.choice(key, len(candidate_words), p=probabilities))
    # Indexing the aligned lists replaces the original linear reverse lookup
    # over vocab.items(), which could fall through and return None.
    return candidate_words[selected_idx], candidate_codes[selected_idx], float(energies[selected_idx])
def predict_sequence_gpu(initial_sequence, vocab, sequences, D=10, M=100, N=12, temperature=1.0):
    """
    GPU-accelerated prediction of the next D words.

    Repeatedly encodes the current token window as one bit string, asks
    retrieve_sequences_gpu for the next word, then slides the window
    forward by one token.

    Args:
        initial_sequence: list of seed tokens (length M-1 expected).
        vocab: mapping of token -> bit-string code.
        sequences: training bit-vector array for the weight matrix.
        D: number of words to predict.
        M, N: window length and bits per word.
        temperature: sampling temperature forwarded to retrieval.

    Returns:
        (predictions, energies): the D predicted tokens and the energy of
        each chosen completion.
    """
    window = list(initial_sequence)
    predictions = []
    energies = []
    for _ in tqdm(range(D), desc="Predicting words"):
        # Concatenate the codes of every token in the current window.
        partial_sequence = "".join(vocab[token] for token in window)
        predicted_word, _, energy = retrieve_sequences_gpu(
            sequences,
            partial_sequence,
            vocab,
            M=M,
            N=N,
            temperature=temperature,
        )
        predictions.append(predicted_word)
        energies.append(energy)
        # Slide the window: drop the oldest token, append the prediction.
        window = window[1:] + [predicted_word]
    return predictions, energies
def main():
    """Interactive demo: load encodings, read a seed sentence, predict words."""
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return
    # Model constants: window length, bits per word, words to predict.
    M, N, D = 100, 13, 10
    temperature = 1.0
    print("Loading training sequences...")
    sequences = get_word_sequences_gpu(encoded_stories=encoded_stories, M=M, N=N)
    print(f"Loaded {len(sequences)} training sequences")
    # Read the seed story from the user.
    print("\nEnter your story:")
    sentence = input("Enter a sentence (at least 99 words): ")
    initial_tokens = tokenize_with_punctuation(sentence)
    if len(initial_tokens) < M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {M-1}.")
        return
    # Predict the continuation from the first M-1 tokens.
    print("\nPredicting next words...")
    predicted_words, energies = predict_sequence_gpu(
        initial_tokens[:M - 1],
        vocab,
        sequences,
        D=D,
        M=M,
        N=N,
        temperature=temperature,
    )
    # Report results.
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predicted_words))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predicted_words, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")
# Run the interactive prediction demo only when executed as a script.
if __name__ == "__main__":
    main()