hey
@@ -217,16 +217,18 @@ def sequence_to_words(sequence, N=12):
     words = [''.join(bits[i:i + N]) for i in range(0, len(bits), N)]
     return words

-def calculate_energy(sequences, batch_size=32):
+def calculate_energy(sequences, batch_size=32, h=0.1):
     """
-    Calculate the energy of sequences using batched processing.
-    Returns energies and Hamiltonian matrix.
+    Calculate the energy of sequences using batched processing with magnetic field.
+    Returns energies and weight matrix W.
+    h: magnetic field strength
     """
     num_sequences = len(sequences)
     seq_length = sequences[0].shape[0]

-    # Initialize Hamiltonian matrix
-    hamiltonian = np.zeros((seq_length, seq_length))
+    # Initialize weight matrix and magnetic field
+    W = np.zeros((seq_length, seq_length))
+    h_field = h * np.ones(seq_length).reshape(-1, 1)  # Uniform magnetic field
     energies = []

     # Process sequences in batches
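Note for reviewers: the quantity the revised calculate_energy evaluates is the standard Ising energy with an external field, E(s) = -1/2 * s^T W s - h^T s. A minimal self-contained sketch of that computation (the helper name and the toy values for s, W, and h_field are illustrative, not from this repository):

import numpy as np

# Sketch of the Ising energy the new code computes, assuming a sequence is
# a column vector of 0/1 bits as elsewhere in this diff.
def ising_energy(s, W, h_field):
    s = s.reshape(-1, 1)                        # column vector of shape (L, 1)
    spin_spin = -0.5 * (s.T @ W @ s).item()     # interaction term: -1/2 s^T W s
    magnetic = -(h_field.T @ s).item()          # field term: -h^T s
    return spin_spin + magnetic

# Toy check (hypothetical values):
s = np.array([1, 0, 1, 1])
W = np.ones((4, 4)) - np.eye(4)                 # uniform couplings, zero diagonal
h_field = 0.1 * np.ones(4).reshape(-1, 1)
print(ising_energy(s, W, h_field))              # -3.3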
@@ -234,37 +236,42 @@ def calculate_energy(sequences, batch_size=32):
         batch = sequences[i:min(i + batch_size, num_sequences)]
         batch = np.array(batch)  # Convert batch to numpy array

-        # Calculate batch energies
-        batch_energies = np.sum(batch * batch.transpose(0, 2, 1), axis=(1, 2)) / -2
+        # Calculate batch contribution to weight matrix (Hebbian learning)
+        for seq in batch:
+            W += np.dot(seq, seq.T)
+
+        # Calculate batch energies including magnetic field
+        batch_energies = []
+        for seq in batch:
+            # E = -1/2 * s^T * W * s - h * sum(s)
+            # Properly extract scalar values from matrix multiplications
+            spin_spin_matrix = seq.T.dot(W).dot(seq)
+            spin_spin = -0.5 * float(spin_spin_matrix[0, 0])
+
+            magnetic_matrix = h_field.T.dot(seq)
+            magnetic = -float(magnetic_matrix[0, 0])
+
+            energy = spin_spin + magnetic
+            batch_energies.append(energy)

         energies.extend(batch_energies)

-        # Update Hamiltonian
-        batch_hamiltonian = np.sum(np.matmul(batch, batch.transpose(0, 2, 1)), axis=0)
-        hamiltonian += batch_hamiltonian
-        # Free memory
-        del batch
-        del batch_energies
-        del batch_hamiltonian
-
-    # Normalize Hamiltonian
-    hamiltonian = hamiltonian / num_sequences
+    # Normalize weight matrix
+    W = W / num_sequences

-    return np.array(energies), hamiltonian
+    return np.array(energies), W, h_field

-def retrieve_sequences(sequences, partial_sequence, vocab, W, M=10, N=12, temperature=1.0):
+def retrieve_sequences(sequences, partial_sequence, vocab, W, M=10, N=12, temperature=1.0, h=0.1):
     """
-    Retrieve the most likely next word using Ising Hamiltonian with temperature.
-    Uses associative memory to retrieve the last word of the sequence.
+    Retrieve the most likely next word using Ising Hamiltonian with magnetic field.
     """
-    # Convert partial sequence to vector
-    partial_vec = np.array([int(bit) for bit in partial_sequence]).reshape(-1, 1)

     # Get all possible words from vocabulary
     possible_words = list(vocab.values())

-    # Calculate weights matrix (Hebbian learning)
-    # Calculate energies for all possible words
+    # Create magnetic field
+    h_field = h * np.ones(M * N).reshape(-1, 1)
+    # Calculate energies for all possible completions
     word_energies = []

     for word in possible_words:
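The per-sequence loop above, together with the final division by num_sequences, is the Hopfield outer-product (Hebbian) storage rule W = (1/P) * sum_mu s_mu s_mu^T. One observable consequence of the committed ordering: each batch's energies are evaluated against the partially accumulated W, so sequences in early batches are scored against fewer stored patterns than later ones. An equivalent one-shot construction, as a hedged sketch (hypothetical helper; assumes the sequences stack to shape (P, L, 1) as the batching code expects):

import numpy as np

def hebbian_weights(sequences):
    S = np.asarray(sequences).squeeze(-1)   # (P, L): one stored pattern per row
    return S.T @ S / len(S)                 # W = (1/P) * sum_mu s_mu s_mu^T

# Toy usage with three hypothetical 0/1 patterns of length 6:
rng = np.random.default_rng(0)
patterns = [rng.integers(0, 2, size=(6, 1)) for _ in range(3)]
print(hebbian_weights(patterns).shape)      # (6, 6)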
@@ -273,29 +280,36 @@ def retrieve_sequences(sequences, partial_sequence, vocab, W, M=10, N=12, temper
         if len(complete_sequence) == M*N:  # Ensure correct length
             complete_vec = np.array([int(bit) for bit in complete_sequence]).reshape(M * N, 1)

-            # Calculate energy using Ising Hamiltonian
-            energy_matrix = complete_vec.T.dot(W).dot(complete_vec)
-            energy = float(energy_matrix[0, 0])
+            # Calculate energy with both interaction and magnetic field terms
+            spin_spin = 0
+            for seq in sequences:
+                # Properly extract scalar from matrix multiplication
+                overlap_matrix = complete_vec.T.dot(seq)
+                overlap = overlap_matrix[0, 0]  # Extract single scalar value
+                spin_spin -= overlap * overlap

-            word_energies.append((word, energy))
+            # Extract scalar from magnetic field contribution
+            magnetic_matrix = h_field.T.dot(complete_vec)
+            magnetic = -float(magnetic_matrix[0, 0])
+            total_energy = spin_spin + magnetic
+
+            word_energies.append((word, total_energy))

     # Sort by energy
     word_energies.sort(key=lambda x: x[1])

-    # Normalize energies to prevent overflow
+    # Normalize energies
     energies = np.array([e[1] for e in word_energies])
-    energies = energies - np.min(energies)  # Shift to make minimum energy 0
-    energies = energies / np.max(energies) if np.max(energies) > 0 else energies  # Scale to [0,1]
+    energies = energies - np.min(energies)
+    max_energy = np.max(energies)
+    if max_energy > 0:
+        energies = energies / max_energy

-    # Calculate probabilities with normalized energies
+    # Calculate probabilities with Boltzmann distribution
     probabilities = np.exp(-energies/temperature)
     probabilities = probabilities / np.sum(probabilities)

-    # Check for valid probabilities
-    if np.any(np.isnan(probabilities)):
-        # Fallback to uniform distribution if numerical issues occur
-        probabilities = np.ones(len(word_energies)) / len(word_energies)
-
+    # Sample from distribution
     selected_idx = np.random.choice(len(word_energies), p=probabilities)
     best_word, min_energy = word_energies[selected_idx]
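Why the NaN fallback could be deleted: after the shift and scale above, every energy lies in [0, 1], so exp(-energies/temperature) lies in [exp(-1/T), 1] and the normalizing sum is at least 1 for any temperature > 0; the division can no longer produce NaN. A standalone sketch of the sampling step (hypothetical helper name; assumes word_energies is the list of (word, energy) pairs built above):

import numpy as np

def sample_word(word_energies, temperature=1.0):
    energies = np.array([e for _, e in word_energies], dtype=float)
    energies -= energies.min()                       # minimum energy becomes 0
    if energies.max() > 0:
        energies /= energies.max()                   # scale into [0, 1]
    probabilities = np.exp(-energies / temperature)  # Boltzmann weights
    probabilities /= probabilities.sum()             # sum >= 1, so this is safe
    idx = np.random.choice(len(word_energies), p=probabilities)
    return word_energies[idx]

# Low temperature concentrates probability on the lowest-energy word;
# high temperature approaches uniform sampling over the vocabulary.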
@@ -340,9 +354,10 @@ def predict_sequence(initial_sequence, vocab, sequences, W, D=10, M=100, N=12, t

 if __name__ == "__main__":
     N = 20  # Define N as a constant
-    M = 20  # Define M as a constant
+    M = 100  # Define M as a constant
     D = 10  # Number of words to predict
-    temperature = 0.10
+    temperature = 1
+
     batch_size = 50  # Added batch size parameter

     print("Loading and encoding stories...")
@@ -360,7 +375,7 @@ if __name__ == "__main__":

     # Get initial sequence from first story
     story_tokens = tokenize_with_punctuation(original_stories[0])
-    _, W = calculate_energy(sequences)
+    _, W, _ = calculate_energy(sequences)

     # Make sure we have enough tokens for M=100
     if len(story_tokens) >= M-1:
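Since calculate_energy now returns three values, any caller not shown in this diff has to unpack accordingly, e.g. (sketch, with the keyword values only illustrative):

energies, W, h_field = calculate_energy(sequences, batch_size=50, h=0.1)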