import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque

from snake_game import GameEnv
from utils import DQN

# Environment
env = GameEnv()

# Hyperparameters
learning_rate = 0.001
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.995
batch_size = 64
target_update_freq = 1000
memory_size = 10000
episodes = 1000

# Initialize Q-networks: the policy net is trained, the target net provides stable targets
input_dim = env.input_length()
output_dim = env.output_length()
policy_net = DQN(input_dim, output_dim)
target_net = DQN(input_dim, output_dim)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
memory = deque(maxlen=memory_size)  # Replay buffer


# Choose an action using an epsilon-greedy policy
def select_action(state, epsilon):
    if random.random() < epsilon:
        return random.randint(0, output_dim - 1)  # Explore
    else:
        with torch.no_grad():  # No gradients needed for action selection
            state = torch.FloatTensor(state).unsqueeze(0)
            q_values = policy_net(state)
        return torch.argmax(q_values).item()  # Exploit


# Optimize the policy network on a minibatch sampled from the replay buffer
def optimize_model():
    if len(memory) < batch_size:
        return
    batch = random.sample(memory, batch_size)
    state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*batch)

    state_batch = torch.FloatTensor(np.array(state_batch, dtype=np.float32))
    action_batch = torch.LongTensor(action_batch).unsqueeze(1)
    reward_batch = torch.FloatTensor(reward_batch)
    next_state_batch = torch.FloatTensor(np.array(next_state_batch, dtype=np.float32))
    done_batch = torch.FloatTensor(done_batch)

    # Q-values of the actions actually taken in the sampled states
    q_values = policy_net(state_batch).gather(1, action_batch).squeeze()

    # Target Q-values using Double DQN:
    # the policy net selects the next action, the target net evaluates it
    with torch.no_grad():
        next_action_batch = policy_net(next_state_batch).argmax(1)
        max_next_q_values = target_net(next_state_batch).gather(1, next_action_batch.unsqueeze(1)).squeeze()
        target_q_values = reward_batch + gamma * max_next_q_values * (1 - done_batch)

    loss = nn.MSELoss()(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# Main training loop
rewards_per_episode = []
steps_done = 0

for episode in range(episodes):
    state = env.reset()
    episode_reward = 0
    episode_done = False

    while not episode_done:
        # Select and execute an action
        action = select_action(state, epsilon)
        next_state, reward, done = env.step(action)

        # Store transition in the replay buffer
        memory.append((state, action, reward, next_state, done))

        # Visualization
        env.render()

        # Update state and bookkeeping
        state = next_state
        episode_done = done
        episode_reward += reward

        # Optimize the policy network
        optimize_model()

        # Periodically sync the target network with the policy network
        if steps_done % target_update_freq == 0:
            target_net.load_state_dict(policy_net.state_dict())
        steps_done += 1

    # Decay epsilon after each episode
    epsilon = max(epsilon_min, epsilon_decay * epsilon)
    rewards_per_episode.append(episode_reward)
    print(f"Episode {episode + 1}/{episodes} - Reward: {episode_reward} - Epsilon: {epsilon:.4f}")

env.close()

# Save the trained model
torch.save(policy_net.state_dict(), "snake.pth")
print("Model saved successfully!")

# Plot the rewards per episode
import matplotlib.pyplot as plt

plt.plot(rewards_per_episode)
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('DQN')
plt.show()