heyyyy

old/spin_glass_gpu.py (new file, 185 lines)
@@ -0,0 +1,185 @@
import jax
import jax.numpy as jnp
from functools import partial
from load_tinystories import (
    tokenize_with_punctuation,
    load_encodings,
    BiDict
)
from tqdm import tqdm


# M and N (argument positions 1 and 2) are marked static so jit
# specializes the kernel on the sequence geometry.
@partial(jax.jit, static_argnums=(1, 2))
def calculate_weights(sequences, M, N):
    """
    Calculate weight matrix using JAX for GPU acceleration.
    """
    # Convert sequences to a JAX array of shape (num_sequences, M*N, 1)
    sequences_jax = jnp.array(sequences)

    # Hebbian rule: sum the outer product of each stored sequence with itself
    W = jnp.sum(jnp.matmul(sequences_jax, sequences_jax.transpose(0, 2, 1)), axis=0)

    # Normalize by the number of stored sequences
    W = W / len(sequences)
    return W
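

# Illustrative sketch, not part of the original commit: a quick check that
# the batched matmul above is the Hebbian sum of outer products, which a
# single einsum also expresses. The toy shapes are assumptions.
def _demo_hebbian_weights():
    import numpy as np
    seqs = jnp.array(np.random.randint(0, 2, size=(5, 6, 1)))  # 5 patterns, M*N = 6
    W_matmul = calculate_weights(seqs, 2, 3)                   # M=2, N=3
    W_einsum = jnp.einsum('bik,bjk->ij', seqs, seqs) / seqs.shape[0]
    assert jnp.allclose(W_matmul, W_einsum)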


# temperature stays a traced argument, so no static_argnums are needed here.
@jax.jit
def calculate_energies(complete_sequences, W, temperature):
    """
    Calculate energies for all possible completions using JAX.
    """
    # Hopfield energy E = -1/2 * s^T W s per candidate: squeeze the
    # (num_candidates, M*N, 1) stack to (num_candidates, M*N) and contract
    # with W in one einsum.
    S = complete_sequences.squeeze(-1)
    energies = -0.5 * jnp.einsum('ki,ij,kj->k', S, W, S)

    # Rescale energies to [0, 1] so temperature acts on a fixed range
    energies = energies - jnp.min(energies)
    energies = energies / (jnp.max(energies) + 1e-10)

    # Boltzmann probabilities over the candidates
    probabilities = jnp.exp(-energies / temperature)
    probabilities = probabilities / jnp.sum(probabilities)

    return energies, probabilities
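

# Illustrative sketch, not part of the original commit: the exp/normalize
# pair above is exactly a softmax over -E/T, and jax.nn.softmax computes it
# with better numerical stability. The toy values are assumptions.
def _demo_boltzmann_softmax():
    energies = jnp.array([0.0, 0.5, 1.0])
    T = 1.0
    p_manual = jnp.exp(-energies / T) / jnp.sum(jnp.exp(-energies / T))
    p_softmax = jax.nn.softmax(-energies / T)
    assert jnp.allclose(p_manual, p_softmax, atol=1e-6)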


def get_word_sequences_gpu(encoded_stories, M=100, N=12):
    """
    Get sequences of M consecutive words, optimized for GPU.
    """
    sequences = []

    for story in tqdm(encoded_stories, desc="Generating sequences"):
        if len(story) >= M:
            for i in range(len(story) - M + 1):
                word_group = story[i:i + M]
                bits = []
                for word in word_group:
                    bits.extend([int(bit) for bit in word])
                # Collect plain Python lists here; creating one device array
                # per window would pay a host-to-device transfer every time.
                sequences.append(bits)

    # Single device transfer: shape (num_windows, M*N, 1)
    return jnp.array(sequences).reshape(-1, M * N, 1)
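

# Illustrative sketch, not part of the original commit: the window loop can
# also be vectorized with numpy's sliding_window_view before the single
# device transfer. The toy story and sizes are assumptions.
def _demo_vectorized_windows():
    import numpy as np
    M, N = 2, 3
    story_bits = np.array([int(b) for w in ["001", "010", "100"] for b in w])
    # All M*N-bit windows that start on a word boundary (every N bits)
    windows = np.lib.stride_tricks.sliding_window_view(story_bits, M * N)[::N]
    return jnp.array(windows).reshape(-1, M * N, 1)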


def retrieve_sequences_gpu(sequences, partial_sequence, vocab, M=100, N=12,
                           temperature=1.0, key=None):
    """
    GPU-accelerated version of sequence retrieval using JAX.
    """
    # Callers can pass a fresh PRNG key per call; reusing one fixed key
    # would make every sampling step deterministic.
    if key is None:
        key = jax.random.PRNGKey(0)

    # Calculate weight matrix
    W = calculate_weights(sequences, M, N)

    # Build the completed sequence for every candidate word, keeping the
    # candidate list parallel to the stacked array so sampled indices map
    # back to the right word.
    candidate_words = []
    complete_sequences = []
    for word in vocab.values():
        complete_sequence = partial_sequence + word
        if len(complete_sequence) == M * N:
            complete_vec = jnp.array([int(bit) for bit in complete_sequence]).reshape(M * N, 1)
            candidate_words.append(word)
            complete_sequences.append(complete_vec)

    complete_sequences = jnp.array(complete_sequences)

    # Calculate energies and probabilities
    energies, probabilities = calculate_energies(complete_sequences, W, temperature)

    # Sample a candidate from the Boltzmann distribution; cast the traced
    # index to a Python int before indexing the list.
    selected_idx = int(jax.random.choice(key, len(candidate_words), p=probabilities))

    best_word = candidate_words[selected_idx]
    min_energy = float(energies[selected_idx])

    # Map the selected bit pattern back to its token
    for word, vector in vocab.items():
        if vector == best_word:
            return word, best_word, min_energy
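

# Illustrative sketch, not part of the original commit: a minimal call with
# a made-up three-word vocabulary of 3-bit codes and M=2 words per window,
# so a one-word partial sequence (3 bits) completes to M*N = 6 bits. All
# names and sizes here are assumptions.
def _demo_retrieve_usage():
    vocab = {"cat": "001", "sat": "010", "mat": "100"}
    stored = jnp.array([[int(b) for b in "001010"],
                        [int(b) for b in "010100"]]).reshape(2, 6, 1)
    word, bits, energy = retrieve_sequences_gpu(
        stored, "001", vocab, M=2, N=3, temperature=1.0,
        key=jax.random.PRNGKey(42))
    print(word, bits, energy)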


def predict_sequence_gpu(initial_sequence, vocab, sequences, D=10, M=100, N=12,
                         temperature=1.0, seed=0):
    """
    GPU-accelerated version of sequence prediction.
    """
    current_tokens = initial_sequence.copy()
    predictions = []
    energies = []
    key = jax.random.PRNGKey(seed)

    for _ in tqdm(range(D), desc="Predicting words"):
        # Encode the current window of tokens as one bit string
        partial_sequence = "".join(vocab[token] for token in current_tokens)

        # Split the key so each step samples with fresh randomness
        key, subkey = jax.random.split(key)
        predicted_word, _, energy = retrieve_sequences_gpu(
            sequences,
            partial_sequence,
            vocab,
            M=M,
            N=N,
            temperature=temperature,
            key=subkey
        )

        predictions.append(predicted_word)
        energies.append(energy)

        # Slide the context window forward by one token
        current_tokens = current_tokens[1:] + [predicted_word]

    return predictions, energies
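

# Illustrative sketch, not part of the original commit: a two-step rollout
# over the same toy vocabulary as _demo_retrieve_usage. All names and sizes
# are assumptions.
def _demo_predict_usage():
    vocab = {"cat": "001", "sat": "010", "mat": "100"}
    stored = jnp.array([[int(b) for b in "001010"],
                        [int(b) for b in "010100"]]).reshape(2, 6, 1)
    words, energies = predict_sequence_gpu(["cat"], vocab, stored, D=2, M=2, N=3)
    print(words, energies)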


def main():
    # Load saved encodings
    vocab, encoded_stories, original_stories = load_encodings()
    if vocab is None:
        print("No saved encodings found. Please run load_tinystories.py first.")
        return

    # Constants
    M = 100
    N = 13
    D = 10
    temperature = 1.0

    print("Loading training sequences...")
    sequences = get_word_sequences_gpu(encoded_stories=encoded_stories, M=M, N=N)
    print(f"Loaded {len(sequences)} training sequences")

    # Get sentence from user
    print("\nEnter your story:")
    sentence = input(f"Enter a sentence (at least {M-1} words): ")
    initial_tokens = tokenize_with_punctuation(sentence)

    if len(initial_tokens) < M - 1:
        print(f"Sentence too short. Got {len(initial_tokens)} tokens, need {M-1}.")
        return

    # Predict next words
    print("\nPredicting next words...")
    predicted_words, energies = predict_sequence_gpu(
        initial_tokens[:M-1],
        vocab,
        sequences,
        D=D,
        M=M,
        N=N,
        temperature=temperature
    )

    # Print results
    print("\nYour input ended with:")
    print(" ".join(initial_tokens[-10:]))
    print("\nPredicted continuation:")
    print(" ".join(predicted_words))
    print("\nEnergies of predictions:")
    for i, (word, energy) in enumerate(zip(predicted_words, energies)):
        print(f"Word {i+1} ('{word}'): {energy:.4f}")


if __name__ == "__main__":
    main()