import torch
import torch.nn as nn
from transformers import get_cosine_schedule_with_warmup
import random
import re
import time
import os

def save_checkpoint(model, optimizer, scheduler, epoch, loss, filepath):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'lr': optimizer.param_groups[0]['lr']
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved at epoch {epoch}")

EMBED_DIM = 32        # roughly vocab size / 10-20
LAYERS = 8            # Rule of thumb: aim for roughly 10-100 training tokens per parameter; the "Chinchilla" paper suggests ~20 tokens per parameter
HEADS = 2             # roughly embed_dim / 32-64
SEQ_LEN = 64
BATCH_SIZE = 64       # shrink this if you run out of memory, but bigger runs faster
EPOCHS = 50
DIM_FEEDFORWARD = EMBED_DIM * 4  # *4 is the common choice; *2 keeps the model smaller
LEARNING_RATE = 0.005            # start lower if removing the scheduler
TRAINING_TEXT = "trainingData/tinyStories/tiny_stories_10_shrunk200x.txt"
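
# A rough back-of-the-envelope check of the "tokens per parameter" rule of thumb in the
# comments above. This is an illustrative sketch, not part of the original script:
# 12 * EMBED_DIM^2 per layer is a common approximation for the attention + feed-forward
# weights when DIM_FEEDFORWARD = 4 * EMBED_DIM, and the embedding table is tied with
# the output layer, so it is only counted once.
def chinchilla_check(vocab_size, num_tokens):
    embedding_params = vocab_size * EMBED_DIM
    per_layer_params = 12 * EMBED_DIM ** 2
    total_params = embedding_params + LAYERS * per_layer_params
    print(f"~{total_params} params vs {num_tokens} tokens: "
          f"{num_tokens / total_params:.1f} tokens per parameter (Chinchilla suggests ~20)")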

def get_training_text(shrink_factor=1):
    with open(TRAINING_TEXT) as f:
        text = f.read()
    text = text[:len(text) // shrink_factor]
    return clean_text(text)

def clean_text(s):
    # Split punctuation into its own tokens, lowercase, and replace everything else with spaces
    s = s.replace('.', ' .').replace(',', ' ,').replace('\n', ' ').replace('!', ' !').replace('?', ' ?')
    s = s.lower()
    return re.sub(r'[^a-zA-Z,\.\!\?\ ]', ' ', s)

class SimpleTokenizer:
    def __init__(self, text):
        # Count how often each word appears
        seen = {}
        words = text.split()
        for word in words:
            if word not in seen:
                seen[word] = 1
            else:
                seen[word] += 1
        # Vocabulary: special tokens first, then every word seen in the text
        self.words = ['<pad>', '<unk>'] + [word for word, count in seen.items()]  # add "if count >= 5" to drop rare words
        self.word_to_id = {word: i for i, word in enumerate(self.words)}
        self.pad_id = 0
        self.unk_id = 1

    def encode(self, text):
        words = text.lower().split()
        return [self.word_to_id.get(word, self.unk_id) for word in words]  # Use <unk> for unknown words

    def decode(self, ids):
        words = [self.words[i] for i in ids]
        return ' '.join(words)
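
# A quick round-trip illustration (not part of the original script; the toy strings are
# made up). Words never seen at construction time map to <unk>:
#   tok = SimpleTokenizer("the cat sat . the dog sat .")
#   ids = tok.encode("the zebra sat .")   # "zebra" is unseen -> unk_id
#   print(tok.decode(ids))                # "the <unk> sat ."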

class SimpleLLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        # Convert words to numbers (embeddings)
        self.word_embeddings = nn.Embedding(vocab_size, EMBED_DIM)
        nn.init.normal_(self.word_embeddings.weight, mean=0, std=0.02)
        # Learn position in sentence
        self.position_embeddings = nn.Embedding(SEQ_LEN, EMBED_DIM)
        # The transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=EMBED_DIM,
            nhead=HEADS,
            dim_feedforward=DIM_FEEDFORWARD,
            batch_first=True,
            norm_first=True,
            dropout=0.1  # REMOVE (set to 0)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=LAYERS)
        # Convert back to word predictions
        self.output_layer = nn.Linear(EMBED_DIM, vocab_size, bias=False)  # todo confirm bias=False makes sense
        self.output_layer.weight = self.word_embeddings.weight  # todo confirm tying weights makes sense
        # REMOVE
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # x is a batch of sequences of word IDs (2D tensor: [batch_size, seq_len])
        batch_size, seq_len = x.shape
        # Create positions: [0, 1, 2, 3, ...]
        positions = torch.arange(seq_len, device=x.device).expand(batch_size, -1)
        # Convert words and positions to embeddings
        word_embeds = self.word_embeddings(x)
        pos_embeds = self.position_embeddings(positions)
        # Add them together (REMOVE dropout wrapper before giving to students)
        embeddings = self.dropout(word_embeds + pos_embeds)
        # Create a causal mask so the model can't cheat by looking at future words during training
        attention_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=x.device), diagonal=1)
        # Pass through transformer
        output = self.transformer(embeddings, mask=attention_mask)
        # Convert back to word predictions (logits over the vocabulary)
        return self.output_layer(output)
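
# A minimal shape check (illustrative, not part of the original script): a batch of two
# sequences of SEQ_LEN token IDs should come back as [2, SEQ_LEN, vocab_size] logits.
#   m = SimpleLLM(vocab_size=100)
#   dummy = torch.randint(0, 100, (2, SEQ_LEN))
#   print(m(dummy).shape)                 # torch.Size([2, 64, 100])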

def create_training_data(text, tokenizer):
    words = tokenizer.encode(text)
    data = []
    # Create sequences of SEQ_LEN words
    for i in range(len(words) - SEQ_LEN):
        input_seq = words[i:i + SEQ_LEN]
        target_seq = words[i + 1:i + SEQ_LEN + 1]  # Next word for each position
        data.append((input_seq, target_seq))
    return data
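
# The input/target shift in miniature (illustrative only): for token IDs [5, 6, 7, 8]
# and a window of 3, the input is [5, 6, 7] and the target is [6, 7, 8], the same
# sequence moved one position to the left, so every position predicts its next word.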

def train_model(model, training_data, tokenizer, validation_data):
    """Train the model to predict the next word"""
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    loss_function = nn.CrossEntropyLoss()

    # REMOVE for students
    total_batches = len(training_data) // BATCH_SIZE
    total_steps = total_batches * EPOCHS
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    best_loss = float('inf')
    for epoch in range(EPOCHS):
        model.train()  # simple_eval() switches to eval mode, so switch back each epoch
        total_loss = 0
        random.shuffle(training_data)  # REMOVE

        # REMOVE. Show loss progress at intervals in the epoch
        interval = total_batches // 100

        # Process in small batches (training_data is a list of (input, target) pairs, each SEQ_LEN tokens long)
        batch_count = 0
        for i in range(0, len(training_data), BATCH_SIZE):
            batch = training_data[i:i + BATCH_SIZE]
            if len(batch) == 0:
                continue
            batch_count += 1

            # Convert to tensors
            inputs = torch.tensor([item[0] for item in batch])
            targets = torch.tensor([item[1] for item in batch])

            # Make predictions
            predictions = model(inputs)  # output = 3D tensor: [batch_size, sequence_length, vocab_size]

            # Calculate loss
            loss = loss_function(predictions.reshape(-1, model.vocab_size), targets.reshape(-1))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()  # REMOVE
            total_loss += loss.item()

            # REMOVE: more instrumentation
            if interval > 0 and batch_count % interval == 0:
                current_avg = total_loss / batch_count
                print(f"{current_avg:.3f}")

        current_lr = optimizer.param_groups[0]['lr']
        avg_loss = total_loss / batch_count
        print(f"Epoch {epoch + 1}: Average loss = {avg_loss:.3f}, LR = {current_lr:.4f}, elapsed = {(time.time() - start) / 60:.1f} minutes")
        simple_eval(model, tokenizer, validation_data)

        # Save checkpoint every few epochs
        if (epoch + 1) % 5 == 0:  # Save every 5 epochs
            save_checkpoint(model, optimizer, scheduler, epoch, avg_loss,
                            f'checkpoint_epoch_{epoch+1}.pt')

        # Also save the best model so far (best = lowest average training loss)
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(model, optimizer, scheduler, epoch, avg_loss,
                            'best_model.pt')

def generate_text(model, tokenizer, prompt, length=20):
    """Generate new text starting with a prompt"""
    model.eval()  # Switch to evaluation mode
    words = tokenizer.encode(prompt)

    for _ in range(length):
        # Use the last SEQ_LEN words as input, left-padding short prompts with <pad>
        if len(words) >= SEQ_LEN:
            input_seq = words[-SEQ_LEN:]
        else:
            input_seq = [tokenizer.pad_id] * (SEQ_LEN - len(words)) + words

        # Get the entire prediction (more data than we need)
        input_tensor = torch.tensor([input_seq])
        with torch.no_grad():
            output = model(input_tensor)

        # Get the prediction for the last position [batch 0, last word, all vocabulary scores]
        last_word_predictions = output[0, -1, :]

        # Don't let the model generate pad or unk
        # REMOVE
        last_word_predictions[tokenizer.pad_id] = -float('inf')
        last_word_predictions[tokenizer.unk_id] = -float('inf')

        # Discourage repetition: if the last two tokens match, forbid a third in a row
        # REMOVE
        if len(words) >= 2 and words[-1] == words[-2]:
            last_word_predictions[words[-1]] = -float('inf')

        # Convert raw scores (logits) to probabilities; dividing by 0.8 is a temperature that sharpens the distribution
        probabilities = nn.functional.softmax(last_word_predictions / 0.8, dim=0)

        # option 1: typical way to select the next word randomly
        # next_word_id = torch.multinomial(probabilities, 1).item()

        # option 2: more explicit so students can figure it out on their own (works the same as how we assign snacks in advisory)
        next_word_id = 0
        random_val = torch.rand(1)
        cumulative = 0
        for i in range(len(probabilities)):
            cumulative += probabilities[i]
            if cumulative >= random_val:
                next_word_id = i
                break

        words.append(next_word_id)

    return tokenizer.decode(words)
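
# For comparison (an illustrative alternative, not in the original script): greedy decoding
# takes the single most likely token instead of sampling; it is deterministic and tends to
# repeat itself on small models, which is why the sampling above is used.
#   next_word_id = torch.argmax(probabilities).item()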

# todo: simplify for student use
def evaluate_loss(model, data):
    model.eval()
    loss_function = nn.CrossEntropyLoss()
    total_loss = 0
    batch_count = 0
    with torch.no_grad():
        for i in range(0, len(data), BATCH_SIZE):
            batch = data[i:i + BATCH_SIZE]
            if len(batch) < BATCH_SIZE:
                continue
            inputs = torch.tensor([x[0] for x in batch])
            targets = torch.tensor([x[1] for x in batch])
            predictions = model(inputs)
            # print(predictions.shape)  # should be [BATCH_SIZE, SEQ_LEN, vocab_size]
            # print(predictions[0, 0, :10])
            loss = loss_function(predictions.reshape(-1, model.vocab_size), targets.reshape(-1))
            total_loss += loss.item()
            batch_count += 1
    return total_loss / batch_count

def top_k_similar_tokens(token_id, embedding_matrix, k=5):
    # Normalize embeddings
    normed = nn.functional.normalize(embedding_matrix, dim=1)
    # Get embedding for the given token
    query = normed[token_id].unsqueeze(0)  # shape (1, dim)
    # Compute cosine similarities to all tokens
    similarities = torch.matmul(query, normed.T).squeeze()  # shape (vocab_size,)
    # Get top k+1, then drop the first result (the token itself, similarity 1.0)
    topk = torch.topk(similarities, k + 1)
    topk_indices = topk.indices.tolist()
    topk_scores = topk.values.tolist()
    return list(zip(topk_indices, topk_scores))[1:]

def simple_eval(model, tokenizer, validation_data):
    # Which words have the most similar embeddings to a few probe words?
    for word in ["dog", "the", "good", "king", "tim"]:
        token = tokenizer.encode(word)[0]
        similar = top_k_similar_tokens(token, model.word_embeddings.weight)
        print([tokenizer.decode([token]) for (token, similarity) in similar])
    # Sample a few continuations
    for prompt in ["children play", "my favorite animal is", "tom and tim were a little"]:
        print(generate_text(model, tokenizer, prompt, length=40))
        print("")
    val_loss = evaluate_loss(model, validation_data)
    print(f"Validation loss: {val_loss:.3f}")

#
# MAIN
#
start = time.time()

training_text = get_training_text()
tokenizer = SimpleTokenizer(training_text)
vocab_size = len(tokenizer.words)
print(f"Vocab: {vocab_size}")

model = SimpleLLM(vocab_size)
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

# Create training data
all_data = create_training_data(training_text, tokenizer)
# shrink data in order to overfit
# all_data = all_data[:1000]
random.shuffle(all_data)
print(tokenizer.decode(all_data[0][0]))

split = int(0.95 * len(all_data))  # 95% for training, 5% for validation
training_data = all_data[:split]
validation_data = all_data[split:]
print(f"Created {len(training_data)} training examples")

train_model(model, training_data, tokenizer, validation_data)

train_loss_eval = evaluate_loss(model, training_data)
print(f"Training loss (eval mode): {train_loss_eval:.3f}")
val_loss = evaluate_loss(model, validation_data)
print(f"Validation loss: {val_loss:.3f}")

# demo: examine which words have the closest embeddings
for i in range(20):
    word = input("what word would you like to examine? ")
    try:
        token = tokenizer.encode(word)[0]
        similar = top_k_similar_tokens(token, model.word_embeddings.weight)
        print([(tokenizer.decode([token]), similarity) for (token, similarity) in similar])
    except Exception:
        print("try again")

test_prompts = ["the cat", "the sun", "children play"]
for prompt in test_prompts:
    print(generate_text(model, tokenizer, prompt, length=50))

while True:
    user_prompt = input("\nEnter a prompt (or 'quit' to exit): ")
    if user_prompt.lower() == 'quit':
        break
    try:
        print(generate_text(model, tokenizer, user_prompt, length=50))
    except Exception:
        print("try again")