# Small network game that has different blobs
# moving around the screen
import contextlib
import sys
import torch
from collections import deque
from torch import nn
import torch.nn.functional as F
import numpy as np
import random
import os
import time

# Silence pygame's startup banner while importing it
with contextlib.redirect_stdout(None):
    import pygame
from client import Network

# Constants
PLAYER_RADIUS = 10
START_VEL = 9
BALL_RADIUS = 4
TRAP_RADIUS = 10
W, H = 300, 300
SAVE_INTERVAL = 100  # Save weights every 100 episodes
polyak = 0.995       # Polyak coefficient for the soft target-network update

# Model
class DQN(nn.Module):
    def __init__(self, input_dim, hidden, n_actions, dropout_prob=0.01):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden)
        self.dropout1 = nn.Dropout(dropout_prob)
        self.fc2 = nn.Linear(hidden, hidden // 2)
        self.dropout2 = nn.Dropout(dropout_prob)
        self.fc3 = nn.Linear(hidden // 2, hidden // 4)
        self.dropout3 = nn.Dropout(dropout_prob)
        self.out = nn.Linear(hidden // 4, n_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = F.relu(self.fc3(x))
        x = self.dropout3(x)
        return self.out(x)

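# Shape sketch (illustrative; the Agent below constructs DQN(state_size, 192, action_size)):
#   net = DQN(input_dim=171, hidden=192, n_actions=4)
#   q = net(torch.zeros(1, 171))   # -> tensor of shape (1, 4), one Q-value per action
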
# Define memory for Experience Replay
class ReplayMemory():
    def __init__(self, maxlen):
        self.memory = deque([], maxlen=maxlen)

    def append(self, transition):
        self.memory.append(transition)

    def sample(self, sample_size):
        # Convert the deque to a list before sampling
        return random.sample(list(self.memory), sample_size)

    def __len__(self):
        return len(self.memory)


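# Usage sketch (illustrative): each transition is a (state, action, reward, next_state, done)
# tuple; uniform random sampling breaks the temporal correlation between consecutive frames
# before the batched update in Agent.train().
#   memory = ReplayMemory(25000)
#   memory.append((s, a, r, s_next, done))
#   batch = memory.sample(192) if len(memory) >= 192 else []
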
class Agent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Hyperparameters
        self.learning_rate = 1e-3
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.999998
        self.batch_size = 192
        self.memory = ReplayMemory(25000)
        self.update_target_every = 10

        # Two networks
        self.policy_net = DQN(state_size, 192, action_size).to(self.device)
        self.target_net = DQN(state_size, 192, action_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.learning_rate, weight_decay=1e-5)
        self.loss_fn = nn.MSELoss()

        self.steps = 0
        self.episode = 0

    def get_state(self, player, balls, traps, players):
        is_at_left_wall = 1.0 if player['x'] == 0 else 0.0
        is_at_right_wall = 1.0 if player['x'] == W else 0.0
        is_at_top_wall = 1.0 if player['y'] == 0 else 0.0
        is_at_bottom_wall = 1.0 if player['y'] == H else 0.0

        state = [
            player['x'] / W,
            player['y'] / H,
            is_at_left_wall,
            is_at_right_wall,
            is_at_top_wall,
            is_at_bottom_wall,
            float(player['alive']),
            player['score'] / 1000,
            (START_VEL - round(player["score"] / 14)) / START_VEL
        ]

        # Add closest balls
        sorted_balls = sorted(balls, key=lambda b: (b[0] - player['x']) ** 2 + (b[1] - player['y']) ** 2)
        for i in range(min(60, len(sorted_balls))):
            state.extend([(sorted_balls[i][0] - player['x']) / W, (sorted_balls[i][1] - player['y']) / H])
        for i in range(60 - min(60, len(sorted_balls))):
            state.extend([0, 0])

        # Add closest traps
        sorted_traps = sorted(traps, key=lambda t: (t[0] - player['x']) ** 2 + (t[1] - player['y']) ** 2)
        for i in range(min(15, len(sorted_traps))):
            state.extend([(sorted_traps[i][0] - player['x']) / W, (sorted_traps[i][1] - player['y']) / H])
        for i in range(15 - min(15, len(sorted_traps))):
            state.extend([0, 0])

        # Add other players
        player_list = [p for p in players.values() if p['name'] != player['name']]
        if player_list:
            sorted_players = sorted(player_list,
                                    key=lambda p: (p['x'] - player['x']) ** 2 + (p['y'] - player['y']) ** 2)
            for i in range(min(3, len(sorted_players))):
                other_player = sorted_players[i]
                state.extend([
                    (other_player['x'] - player['x']) / W,
                    (other_player['y'] - player['y']) / H,
                    other_player['score'] / 1000,
                    0 if other_player.get('score', 0) > player.get('score', 0) else 1
                ])
            for i in range(3 - min(3, len(sorted_players))):
                state.extend([0, 0, 0, 0])
        else:
            for i in range(3):
                state.extend([0, 0, 0, 0])

        return torch.FloatTensor(state).unsqueeze(0).to(self.device)

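    # State layout (as built above): 9 player features + 60 balls * 2 + 15 traps * 2
    # + 3 opponents * 4 = 171 values, matching state_size in main(). Missing entities
    # are zero-padded so the vector length stays constant.
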
    def act(self, state, training=True):
        if training and random.random() < self.epsilon:
            return random.randint(0, self.action_size - 1)

        with torch.no_grad():
            q_values = self.policy_net(state)
            return q_values.argmax().item()

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def train(self):
        if len(self.memory) < self.batch_size:
            return 0.0  # Return 0 as the loss value when there are not enough samples yet

        # Draw a batch as a list of transitions
        batch = self.memory.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors
        states = torch.cat(states)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.cat(next_states)
        dones = torch.FloatTensor(dones).to(self.device)

        # Current Q values
        current_q = self.policy_net(states).gather(1, actions.unsqueeze(1))

        # Target Q values (Double DQN: policy net selects the action, target net evaluates it)
        with torch.no_grad():
            next_action_indices = self.policy_net(next_states).argmax(1)
            next_q = self.target_net(next_states).gather(1, next_action_indices.unsqueeze(1)).squeeze(1)
            target_q = rewards + (1 - dones) * self.gamma * next_q

        # Compute loss
        loss = self.loss_fn(current_q.squeeze(), target_q)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Decay epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

        # Soft-update the target network every few steps (Polyak averaging)
        self.steps += 1
        if self.steps % self.update_target_every == 0:
            with torch.no_grad():
                for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):
                    target_param.data.mul_(polyak).add_(policy_param.data, alpha=1 - polyak)

        return loss.item()

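    # Update rule implemented above, written out for reference:
    #   y = r + (1 - done) * gamma * Q_target(s', argmax_a Q_policy(s', a))
    #   theta_target <- polyak * theta_target + (1 - polyak) * theta_policy
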
    def save(self, filename):
        torch.save({
            'policy_state_dict': self.policy_net.state_dict(),
            'target_state_dict': self.target_net.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'episode': self.episode,
        }, filename)

    def load(self, filename):
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine
        checkpoint = torch.load(filename, map_location=self.device)
        self.policy_net.load_state_dict(checkpoint['policy_state_dict'])
        self.target_net.load_state_dict(checkpoint['target_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.epsilon = checkpoint['epsilon']
        self.episode = checkpoint['episode']
        print(f"Loaded model from {filename}, epsilon: {self.epsilon}")


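# Checkpoint sketch (illustrative filename): a run can be resumed by calling
# agent.load("dqn_model_ep100.pt") before the game loop; main() below does this
# automatically when model_file points at an existing checkpoint.
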
# Game initialization
pygame.font.init()
NAME_FONT = pygame.font.SysFont("comicsans", 20)
TIME_FONT = pygame.font.SysFont("comicsans", 30)
SCORE_FONT = pygame.font.SysFont("comicsans", 26)

def main(name, train_mode=True, model_file=None):
    # Setup pygame window
    WIN = pygame.display.set_mode((W, H))
    pygame.display.set_caption("Blobs - DQN Agent")

    # Connect to server
    server = Network()
    current_id = server.connect(name)
    balls, traps, players, game_time, episodes_count = server.send("get")

    # Initialize agent
    state_size = 9 + 60 * 2 + 15 * 2 + 3 * 4  # player + balls + traps + other players
    action_size = 4  # Left, Right, Up, Down
    agent = Agent(state_size, action_size)

    if model_file and os.path.exists(model_file):
        agent.load(model_file)

    clock = pygame.time.Clock()
    run = True
    total_reward = 0

    while run:
        clock.tick(30)
        player = players[current_id]

        # Get current state
        state = agent.get_state(player, balls, traps, players)

        # Choose action
        action = agent.act(state, training=train_mode)

        # Execute action
        vel = START_VEL - round(player["score"] / 14)
        if vel <= 1:
            vel = 1

        if action == 0:  # Left
            player["x"] = max(0, player["x"] - vel)
        elif action == 1:  # Right
            player["x"] = min(W, player["x"] + vel)
        elif action == 2:  # Up
            player["y"] = max(0, player["y"] - vel)
        elif action == 3:  # Down
            player["y"] = min(H, player["y"] + vel)

        # Send move to server
        data = "move " + str(player["x"]) + " " + str(player["y"])
        balls, traps, players, game_time, episodes_count = server.send(data)

        # Get new state and reward
        next_state = agent.get_state(player, balls, traps, players)
        reward = player["reward"]
        done = not player.get("alive", True)

        if train_mode:
            # Store experience and train
            agent.remember(state, action, reward, next_state, done)
            loss = agent.train()

        total_reward += reward

        # Print debug info
        print(
            f"Ep: {agent.episode} Step: {agent.steps} Act: {action} Done: {done} Reward: {reward:.2f} Eps: {agent.epsilon:.2f} Total: {total_reward:.2f}")

        if done:
            agent.episode += 1
            total_reward = 0

            if agent.episode % SAVE_INTERVAL == 0:
                agent.save(f"dqn_model_ep{agent.episode}.pt")

        # Handle events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                run = False
            if event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
                run = False

        # Render
        WIN.fill((255, 255, 255))

        # Draw game elements
        for ball in balls:
            pygame.draw.circle(WIN, ball[2], (ball[0], ball[1]), BALL_RADIUS)

        for trap in traps:
            pygame.draw.circle(WIN, trap[2], (trap[0], trap[1]), TRAP_RADIUS)

        for p in sorted(players.values(), key=lambda x: x["score"]):
            pygame.draw.circle(WIN, p["color"], (p["x"], p["y"]), PLAYER_RADIUS)
            text = NAME_FONT.render(p["name"], 1, (0, 0, 0))
            WIN.blit(text, (p["x"] - text.get_width() / 2, p["y"] - text.get_height() / 2))

        # Draw UI
        text = TIME_FONT.render(f"Score: {player['score']}", 1, (0, 0, 0))
        WIN.blit(text, (10, 10))

        text = TIME_FONT.render(f"Time: {game_time}", 1, (0, 0, 0))
        WIN.blit(text, (10, 40))

        if not train_mode:
            text = TIME_FONT.render("EVALUATION MODE", 1, (255, 0, 0))
            WIN.blit(text, (W // 2 - text.get_width() // 2, 10))

        pygame.display.update()

    # Clean up
    if train_mode:
        agent.save("dqn_model_final.pt")
    server.disconnect()
    pygame.quit()


if __name__ == "__main__":
    # To train: python game.py --train
    # To evaluate: python game.py --model dqn_model_final.pt
    # (Argument parsing is currently commented out; the hard-coded call below always trains.)

    import argparse
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", action="store_true", help="Train the DQN agent")
    parser.add_argument("--model", type=str, help="Path to model file for evaluation")
    args = parser.parse_args()
    """
    os.environ['SDL_VIDEO_WINDOW_POS'] = "%d,%d" % (0, 30)
    main("dqn_agent", train_mode=True, model_file="dqn_model_final.pt")
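    # Evaluation sketch (illustrative): run with exploration and learning disabled,
    # which also shows the "EVALUATION MODE" banner, e.g.
    #   main("dqn_agent", train_mode=False, model_file="dqn_model_final.pt")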