ondross

Toy LLM Code

Jul 4th, 2025
import torch
import torch.nn as nn
from transformers import get_cosine_schedule_with_warmup
import random
import re
import time
import os


def save_checkpoint(model, optimizer, scheduler, epoch, loss, filepath):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'lr': optimizer.param_groups[0]['lr']
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved at epoch {epoch}")


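# Not part of the original paste: a minimal companion loader for the checkpoints
# written by save_checkpoint above, using the same dict keys. It restores model,
# optimizer, and scheduler state in place and returns (epoch, loss) so training
# could be resumed. Treat it as a sketch rather than a tested utility.
def load_checkpoint(model, optimizer, scheduler, filepath):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

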
EMBED_DIM = 32    # rule of thumb: roughly vocab size / 10-20
LAYERS = 8        # data-sizing rule of thumb: aim for roughly 10-100 training tokens per parameter; the "Chinchilla" paper suggests about 20 tokens per parameter
HEADS = 2         # roughly EMBED_DIM / 32-64
SEQ_LEN = 64
BATCH_SIZE = 64   # shrink this if you run out of memory, but bigger batches run faster
EPOCHS = 50
DIM_FEEDFORWARD = EMBED_DIM * 4  # x4 is most common; x2 keeps the model smaller
LEARNING_RATE = 0.005  # start lower if you remove the scheduler
TRAINING_TEXT = "trainingData/tinyStories/tiny_stories_10_shrunk200x.txt"


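# Not part of the original paste: a small helper for the tokens-per-parameter
# rule of thumb mentioned above. Given a model and the number of training
# tokens, it reports the ratio so it can be compared against the ~20 tokens per
# parameter the Chinchilla paper suggests. Sketch only; the name is my own.
def tokens_per_parameter(model, num_tokens):
    num_params = sum(p.numel() for p in model.parameters())
    ratio = num_tokens / num_params
    print(f"{num_tokens} tokens / {num_params} parameters = {ratio:.1f} tokens per parameter")
    return ratio

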
def get_training_text(shrink_factor=1):
    with open(TRAINING_TEXT) as f:
        text = f.read()
        text = text[:len(text) // shrink_factor]
        return clean_text(text)

def clean_text(s):
    # Split punctuation off the preceding word so it becomes its own token
    s = s.replace('.', ' .').replace(',', ' ,').replace('\n', ' ').replace('!', ' !').replace('?', ' ?')
    s = s.lower()
    return re.sub(r'[^a-zA-Z,\.\!\?\ ]', ' ', s)


class SimpleTokenizer:
    def __init__(self, text):
        # Count how often each word appears
        seen = {}
        words = text.split()
        for word in words:
            if word not in seen:
                seen[word] = 1
            else:
                seen[word] += 1

        self.words = ['<pad>', '<unk>'] + [word for word, count in seen.items()]  # if count >= 5]

        self.word_to_id = {word: i for i, word in enumerate(self.words)}
        self.pad_id = 0
        self.unk_id = 1

    def encode(self, text):
        words = text.lower().split()
        return [self.word_to_id.get(word, self.unk_id) for word in words]  # Use <unk> for unknown words

    def decode(self, ids):
        words = [self.words[i] for i in ids]
        return ' '.join(words)

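# Not part of the original paste: a tiny illustration of how the tokenizer is
# used. It is defined here for reference and never called by the script.
def demo_tokenizer():
    tok = SimpleTokenizer("the cat sat . the dog ran .")
    ids = tok.encode("the cat ran .")
    print(ids)                  # word ids; exact numbers depend on vocab order
    print(tok.decode(ids))      # "the cat ran ."
    print(tok.encode("zebra"))  # unknown word -> [tok.unk_id]

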
class SimpleLLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

        # Convert words to numbers (embeddings)
        self.word_embeddings = nn.Embedding(vocab_size, EMBED_DIM)
        nn.init.normal_(self.word_embeddings.weight, mean=0, std=0.02)

        # Learn position in sentence
        self.position_embeddings = nn.Embedding(SEQ_LEN, EMBED_DIM)

        # The transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=EMBED_DIM,
            nhead=HEADS,
            dim_feedforward=DIM_FEEDFORWARD,
            batch_first=True,
            norm_first=True,
            dropout=0.1  # REMOVE (set to 0)
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=LAYERS)

        # Convert back to word predictions
        self.output_layer = nn.Linear(EMBED_DIM, vocab_size, bias=False)  # todo confirm bias false makes sense
        self.output_layer.weight = self.word_embeddings.weight  # todo confirm tying weights makes sense

        # REMOVE
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # x is a batch of sequences of word IDs (2d array)
        batch_size, seq_len = x.shape

        # Create positions: [0, 1, 2, 3, ...]
        positions = torch.arange(seq_len, device=x.device).expand(batch_size, -1)

        # Convert words and positions to embeddings
        word_embeds = self.word_embeddings(x)
        pos_embeds = self.position_embeddings(positions)

        # Add them together (REMOVE dropout wrapper before giving to students)
        embeddings = self.dropout(word_embeds + pos_embeds)

        # Create a causal mask so the model can't cheat by looking at future words during training
        attention_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=x.device), diagonal=1)

        # Pass through transformer
        output = self.transformer(embeddings, mask=attention_mask)

        # Convert back to word predictions (logits over the vocabulary)
        return self.output_layer(output)

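# Not part of the original paste: a quick shape check, handy when modifying the
# model. A dummy batch of token ids should come out as logits with shape
# [batch_size, seq_len, vocab_size]. Defined for reference, not called below.
def demo_model_shapes():
    tiny_model = SimpleLLM(vocab_size=100)
    dummy_batch = torch.randint(0, 100, (BATCH_SIZE, SEQ_LEN))
    with torch.no_grad():
        logits = tiny_model(dummy_batch)
    print(logits.shape)  # e.g. torch.Size([64, 64, 100]) with the settings above

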
def create_training_data(text, tokenizer):
    words = tokenizer.encode(text)
    data = []

    # Create sequences of SEQ_LEN words
    for i in range(len(words) - SEQ_LEN):
        input_seq = words[i:i + SEQ_LEN]
        target_seq = words[i + 1:i + SEQ_LEN + 1]  # Next word for each position
        data.append((input_seq, target_seq))

    return data

def train_model(model, training_data, tokenizer, validation_data):
    """Train the model to predict the next word"""
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    loss_function = nn.CrossEntropyLoss()

    # REMOVE for students
    total_batches = len(training_data) // BATCH_SIZE
    total_steps = total_batches * EPOCHS
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    best_loss = float('inf')
    for epoch in range(EPOCHS):
        model.train()  # simple_eval below leaves the model in eval mode, so switch back
        total_loss = 0
        random.shuffle(training_data)  # REMOVE

        # REMOVE. Show loss progress at intervals in the epoch
        interval = total_batches // 100

        # Process in small batches (training_data is a list of (input, target) pairs, each SEQ_LEN tokens long)
        batch_count = 0
        for i in range(0, len(training_data), BATCH_SIZE):
            batch = training_data[i:i + BATCH_SIZE]
            if len(batch) == 0:
                continue
            batch_count += 1

            # Convert to tensors
            inputs = torch.tensor([item[0] for item in batch])
            targets = torch.tensor([item[1] for item in batch])

            # Make predictions
            predictions = model(inputs)  # output = 3d tensor: [batch_size, sequence_length, vocab_size]

            # Calculate loss
            loss = loss_function(predictions.reshape(-1, model.vocab_size), targets.reshape(-1))

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()  # REMOVE

            total_loss += loss.item()

            # REMOVE: more instrumentation
            if interval > 0 and batch_count % interval == 0:
                current_avg = total_loss / batch_count
                print(f"{current_avg:.3f}")

        current_lr = optimizer.param_groups[0]['lr']
        avg_loss = total_loss / batch_count
        print(f"Epoch {epoch + 1}: Average loss = {avg_loss:.3f}, LR = {current_lr:.4f}, elapsed = {(time.time() - start) / 60:.1f} minutes")
        simple_eval(model, tokenizer, validation_data)

        # Save checkpoint every few epochs
        if (epoch + 1) % 5 == 0:  # Save every 5 epochs
            save_checkpoint(model, optimizer, scheduler, epoch, avg_loss,
                            f'checkpoint_epoch_{epoch+1}.pt')

        # Also save the model with the best average training loss so far
        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(model, optimizer, scheduler, epoch, avg_loss,
                            'best_model.pt')



def generate_text(model, tokenizer, prompt, length=20):
    """Generate new text starting with a prompt"""
    model.eval()  # Switch to evaluation mode

    words = tokenizer.encode(prompt)

    for _ in range(length):
        # Use the last SEQ_LEN words as input, left-padding shorter prompts
        if len(words) >= SEQ_LEN:
            input_seq = words[-SEQ_LEN:]
        else:
            input_seq = [tokenizer.pad_id] * (SEQ_LEN - len(words)) + words

        # Get the entire prediction (more data than we need)
        input_tensor = torch.tensor([input_seq])
        with torch.no_grad():
            output = model(input_tensor)

        # Get the prediction for the last position [batch 0, last word, all logits]
        last_word_predictions = output[0, -1, :]

        # Don't let the model generate pad or unk tokens
        # REMOVE
        last_word_predictions[tokenizer.pad_id] = -float('inf')
        last_word_predictions[tokenizer.unk_id] = -float('inf')

        # If the last two tokens are identical, forbid a third repeat
        # REMOVE
        if len(words) >= 2 and words[-1] == words[-2]:
            last_word_predictions[words[-1]] = -float('inf')

        # Convert raw scores (logits) to probabilities, with a 0.8 sampling temperature
        probabilities = nn.functional.softmax(last_word_predictions / .8, dim=0)

        # option 1: typical way to select the next word randomly
        # next_word_id = torch.multinomial(probabilities, 1).item()

        # option 2: more explicit, so students can figure it out on their own
        # (works the same way we assign snacks in advisory)
        next_word_id = 0
        random_val = torch.rand(1)
        cumulative = 0
        for i in range(len(probabilities)):
            cumulative += probabilities[i]
            if cumulative >= random_val:
                next_word_id = i
                break

        words.append(next_word_id)

    return tokenizer.decode(words)


# todo: simplify for student use
def evaluate_loss(model, data):
    model.eval()
    loss_function = nn.CrossEntropyLoss()
    total_loss = 0
    batch_count = 0

    with torch.no_grad():
        for i in range(0, len(data), BATCH_SIZE):
            batch = data[i:i + BATCH_SIZE]
            if len(batch) < BATCH_SIZE:
                continue
            inputs = torch.tensor([x[0] for x in batch])
            targets = torch.tensor([x[1] for x in batch])
            predictions = model(inputs)

            # print(predictions.shape)     # should be [BATCH_SIZE, SEQ_LEN, vocab_size]
            # print(predictions[0,0,:10])

            loss = loss_function(predictions.reshape(-1, model.vocab_size), targets.reshape(-1))
            total_loss += loss.item()
            batch_count += 1
    return total_loss / batch_count


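# Not part of the original paste: cross-entropy loss is easier to interpret as
# perplexity (e^loss), roughly how many words the model is "choosing between"
# at each step. A small convenience wrapper around evaluate_loss, sketch only.
def evaluate_perplexity(model, data):
    import math
    loss = evaluate_loss(model, data)
    perplexity = math.exp(loss)
    print(f"Loss {loss:.3f} -> perplexity {perplexity:.1f}")
    return perplexity

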
def top_k_similar_tokens(token_id, embedding_matrix, k=5):
    # Normalize embeddings
    normed = nn.functional.normalize(embedding_matrix, dim=1)

    # Get embedding for the given token
    query = normed[token_id].unsqueeze(0)  # shape (1, dim)

    # Compute cosine similarities to all tokens
    similarities = torch.matmul(query, normed.T).squeeze()  # shape (vocab_size,)

    # Get top k+1, then drop the first result (the token itself, similarity 1.0)
    topk = torch.topk(similarities, k + 1)

    topk_indices = topk.indices.tolist()
    topk_scores = topk.values.tolist()

    return list(zip(topk_indices, topk_scores))[1:]

def simple_eval(model, tokenizer, validation_data):
    for word in ["dog", "the", "good", "king", "tim"]:
        token = tokenizer.encode(word)[0]
        similar = top_k_similar_tokens(token, model.word_embeddings.weight)
        print([tokenizer.decode([similar_token]) for (similar_token, similarity) in similar])

    for prompt in ["children play", "my favorite animal is", "tom and tim were a little"]:
        print(generate_text(model, tokenizer, prompt, length=40))
        print("")

    val_loss = evaluate_loss(model, validation_data)
    print(f"Validation loss: {val_loss:.3f}")


#
# MAIN
#


start = time.time()

training_text = get_training_text()
tokenizer = SimpleTokenizer(training_text)
vocab_size = len(tokenizer.words)


print(f"Vocab: {vocab_size}")
model = SimpleLLM(vocab_size)
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

# Create training data
all_data = create_training_data(training_text, tokenizer)

# shrink data in order to overfit
# all_data = all_data[:1000]


random.shuffle(all_data)
print(tokenizer.decode(all_data[0][0]))  # show one sample training sequence
split = int(0.95 * len(all_data))  # 95% for training, 5% for validation
training_data = all_data[:split]
validation_data = all_data[split:]

print(f"Created {len(training_data)} training examples")


train_model(model, training_data, tokenizer, validation_data)

train_loss_eval = evaluate_loss(model, training_data)
print(f"Training loss (eval mode): {train_loss_eval:.3f}")

val_loss = evaluate_loss(model, validation_data)
print(f"Validation loss: {val_loss:.3f}")

# demo: examine which words have the closest embeddings
for i in range(20):
    word = input("what word would you like to examine? ")
    try:
        token = tokenizer.encode(word)[0]
        similar = top_k_similar_tokens(token, model.word_embeddings.weight)
        print([(tokenizer.decode([similar_token]), similarity) for (similar_token, similarity) in similar])
    except Exception:
        print("try again")


test_prompts = ["the cat", "the sun", "children play"]
for prompt in test_prompts:
    print(generate_text(model, tokenizer, prompt, length=50))

while True:
    user_prompt = input("\nEnter a prompt (or 'quit' to exit): ")
    if user_prompt.strip().lower() == 'quit':
        break
    try:
        print(generate_text(model, tokenizer, user_prompt, length=50))
    except Exception:
        print("try again")