import sys
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import pygame
from collections import deque
from game_environment import SnakeEnv
import imageio
import matplotlib.pyplot as plt
import time

# Hyperparameters
GAMMA = 0.99                 # discount factor for future rewards
LR = 1e-3                    # Adam learning rate
BATCH_SIZE = 64              # transitions sampled per training step
MEMORY_SIZE = 10000          # replay buffer capacity
EPSILON_START = 1.0          # initial exploration rate
EPSILON_END = 0.05           # exploration floor
EPSILON_DECAY = 0.995        # multiplicative decay applied once per episode
TARGET_UPDATE_FREQ = 10      # episodes between target-network syncs
EPISODES = 500               # total training episodes

RENDER_EVERY = 20            # render (and record) every Nth episode

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}\n")

# Convolutional Q-network: maps the 4-channel grid observation to one Q-value per action
class DQN(nn.Module):
    def __init__(self, grid_h, grid_w, n_actions):
        super().__init__()

        # The observation has 4 channels (body, head, food, direction),
        # so Conv2d must expect in_channels=4.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=3, padding=1),  # -> (32, grid_h, grid_w)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),                          # -> (64, grid_h, grid_w)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),                          # -> (64, grid_h, grid_w)
            nn.ReLU()
        )

        # With kernel_size=3 and padding=1 the spatial dimensions are unchanged,
        # so after the three conv layers the flattened size is:
        conv_output_size = 64 * grid_h * grid_w
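        # e.g. for a 20 x 20 grid this is 64 * 20 * 20 = 25600,
        # the value that was previously hard-coded in the first Linear layer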

        self.fc = nn.Sequential(
            nn.Flatten(),                      # Flattens (batch, 64, grid_h, grid_w) -> (batch, 64*grid_h*grid_w)
            nn.Linear(conv_output_size, 512),  # Now uses the *actual* conv_output_size, not a hard-coded 25600
            nn.ReLU(),
            nn.Linear(512, n_actions)          # One Q-value per action
        )

    def forward(self, x):
        # Expect x of shape (batch, 4, grid_h, grid_w)
        x = self.conv(x)   # -> (batch, 64, grid_h, grid_w)
        return self.fc(x)  # -> (batch, n_actions)

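# Quick shape sanity check (sketch only; assumes a hypothetical 20 x 20 grid and 4 actions):
#   net = DQN(grid_h=20, grid_w=20, n_actions=4)
#   q = net(torch.zeros(1, 4, 20, 20))   # -> torch.Size([1, 4])
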
# FIFO replay buffer; uniform random sampling decorrelates consecutive transitions
class ReplayMemory:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        # transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

def train(model_path=None):

    start_time = time.time()

    env = SnakeEnv()
    grid_h, grid_w = env.grid_height, env.grid_width
    n_actions = 4  # right, left, up, down

    policy_net = DQN(grid_h, grid_w, n_actions).to(device)
    target_net = DQN(grid_h, grid_w, n_actions).to(device)

    if model_path is not None:
        print(f"Loading model from: {model_path}")
        policy_net.load_state_dict(torch.load(model_path, map_location=device))

    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()
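    # eval() only changes behavior for layers like dropout/batchnorm (none here),
    # but it documents that the target network is never trained directly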

    optimizer = optim.Adam(policy_net.parameters(), lr=LR)
    memory = ReplayMemory(MEMORY_SIZE)

    # If a pre-trained model was loaded, run it greedily (epsilon = 0).
    # To continue training a pre-trained model instead, set epsilon = EPSILON_START here.
    epsilon = 0 if model_path is not None else EPSILON_START

    best_rendered_score = -float('inf')
    best_score = -float('inf')
    video_filename = 'best_snake_episode.mp4'
    model_save_path = "best_dqn_model.pth"

    all_scores = []
    all_losses = []

    for episode in range(EPISODES):
        obs = env.reset()
        total_reward = 0
        done = False

        render = (episode % RENDER_EVERY == 0)
        frames = []
        episode_losses = []

        while not done:

            if render:
                pygame.event.pump()
                env.window.fill('black')
                env.draw_snake()
                env.draw_food()
                env.display_score()
                pygame.display.flip()
                env.clock.tick(30)

                frame = pygame.surfarray.array3d(env.window)
                frame = frame.transpose([1, 0, 2])  # transpose (Width, Height, Channel) to (Height, Width, Channel)
                frames.append(frame)

            # Epsilon-greedy action selection: random action with probability epsilon,
            # otherwise the greedy action from the policy network
            if random.random() < epsilon:
                action = random.randint(0, n_actions - 1)
            else:
                with torch.no_grad():
                    obs_tensor = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
                    action = policy_net(obs_tensor).argmax().item()

            # Step through the environment and store the transition
            next_obs, reward, done = env.step(action)
            memory.push((obs, action, reward, next_obs, done))
            obs = next_obs
            total_reward += reward

            # Train the network once the replay buffer holds at least one full batch
            if len(memory) >= BATCH_SIZE:
                batch = memory.sample(BATCH_SIZE)
                obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = zip(*batch)

                # Convert to tensors:
                obs_batch = torch.tensor(np.stack(obs_batch), dtype=torch.float32, device=device)
                action_batch = torch.tensor(action_batch, dtype=torch.int64, device=device).unsqueeze(1)
                reward_batch = torch.tensor(reward_batch, dtype=torch.float32, device=device).unsqueeze(1)
                next_obs_batch = torch.tensor(np.stack(next_obs_batch), dtype=torch.float32, device=device)
                done_batch = torch.tensor(done_batch, dtype=torch.float32, device=device).unsqueeze(1)
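                # Each per-sample scalar (action, reward, done) is unsqueezed to shape
                # (BATCH_SIZE, 1) so it lines up with the gathered Q-values below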

                # 1. Forward pass: Q-values of the actions that were actually taken
                q_values = policy_net(obs_batch).gather(1, action_batch)
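                # Bellman target: y = r + GAMMA * max_a' Q_target(s', a'),
                # with the bootstrap term zeroed for terminal transitions via (1 - done)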
                with torch.no_grad():
                    max_next_q = target_net(next_obs_batch).max(1)[0].unsqueeze(1)
                    target_q = reward_batch + GAMMA * max_next_q * (1 - done_batch)

                # 2. Compute the loss between predicted and target Q-values
                loss = nn.MSELoss()(q_values, target_q)

                # 3. Zero out gradients, since PyTorch accumulates them across
                #    backward passes even after they have been applied
                optimizer.zero_grad()

                # 4. Backpropagate
                loss.backward()

                # 5. Gradient descent step
                optimizer.step()

                # Track the loss for visualization
                episode_losses.append(loss.item())

        score = env.snake_size - env.STARTING_SIZE

        # Track metrics for visualization
        avg_loss = np.mean(episode_losses) if episode_losses else 0
        all_losses.append(avg_loss)
        all_scores.append(score)

        # Periodically copy the online weights into the target network; keeping the
        # bootstrap targets fixed between syncs stabilizes training
        if episode % TARGET_UPDATE_FREQ == 0:
            target_net.load_state_dict(policy_net.state_dict())

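        # With EPSILON_DECAY = 0.995, epsilon is still roughly 0.08 after 500 episodes
        # (0.995**500 ~= 0.082); the EPSILON_END floor of 0.05 is only reached after ~600 episodes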
        epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
        print(f"Episode {episode}, Total reward: {round(total_reward, 2)}, Score: {score}, Epsilon: {epsilon:.3f}")

        # NOTE: this only saves the best-scoring episode among the episodes that were actually rendered
        if render and score > best_rendered_score:
            best_rendered_score = score
            print(f"New best score {best_rendered_score} at episode {episode}. Saving video...")
            imageio.mimwrite(video_filename, frames, fps=15, codec='libx264', quality=8)

        if score > best_score:  # best score across all episodes
            best_score = score
            print(f"New best model score: {best_score}. Saving model weights...")
            torch.save(policy_net.state_dict(), model_save_path)


    # After training completes, plot smoothed metrics
    window = 50
    episodes = np.arange(len(all_scores))
    rolling_scores = np.convolve(all_scores, np.ones(window)/window, mode='valid')
    rolling_losses = np.convolve(all_losses, np.ones(window)/window, mode='valid')

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

    ax1.plot(episodes[len(episodes) - len(rolling_scores):], rolling_scores, label=f"Avg Score ({window}-episode window)")
    ax1.set_title("Smoothed Score per Episode")
    ax1.set_xlabel("Episode")
    ax1.set_ylabel("Score")
    ax1.grid(True)
    ax1.legend()

    ax2.plot(episodes[len(episodes) - len(rolling_losses):], rolling_losses, label=f"Avg Loss ({window}-episode window)", color="orange")
    ax2.set_title("Smoothed Loss per Episode")
    ax2.set_xlabel("Episode")
    ax2.set_ylabel("Loss")
    ax2.grid(True)
    ax2.legend()

    plt.tight_layout()
    plt.savefig("training_metrics.png")

    elapsed = time.time() - start_time
    minutes, seconds = divmod(int(elapsed), 60)
    print(f"\nTotal training time: {minutes} minutes and {seconds} seconds")

    print(f"\nBest score from a single episode in training: {best_score}")

if __name__ == "__main__":
    try:
        # Optional CLI argument: path to a saved model; defaults to None (train from scratch)
        model_path = sys.argv[1] if len(sys.argv) > 1 else None
        train(model_path)
    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")
        pygame.quit()
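
# Usage sketch (assuming this file is saved as e.g. train_dqn.py and game_environment.py
# providing SnakeEnv is on the same path):
#   python train_dqn.py                      # train from scratch
#   python train_dqn.py best_dqn_model.pth   # load saved weights and run greedily (epsilon = 0)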