  1. """
  2. PPO MLPlay Module
  3. This module implements a Proximal Policy Optimization (PPO) agent for Unity games
  4. using the MLGame3D framework and Unity ML-Agents PPO implementation.
  5. """
  6. import csv
  7. import os
  8. import time
  9. import numpy as np
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. from typing import Dict, Any, List
  14.  
# Check if CUDA is available
device = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("cpu")
)

# Resume training configuration
RESUME_TRAINING = False
MODEL_SAVE_DIR = "./models/20250527_162128" if RESUME_TRAINING else f"./models/{time.strftime('%Y%m%d_%H%M%S')}"
MODEL_LOAD_PATH = f"{MODEL_SAVE_DIR}/model_latest.pt"
OPTIMIZER_LOAD_PATH = f"{MODEL_SAVE_DIR}/optimizer_latest.pt"

# Built-in configuration
TRAINING_MODE = True  # Set to False to use a pre-trained model

# PPO hyperparameters
LEARNING_RATE = 3e-4
GAMMA = 0.99  # Discount factor
GAE_LAMBDA = 0.95  # GAE parameter
CLIP_RATIO = 0.2  # Lower clip ratio to encourage safer early updates
VALUE_COEF = 0.5  # Value loss coefficient
ENTROPY_COEF = 0.022  # Best: 0.03
FINAL_ENTROPY_COEF = 0.003
DECAY_RATE = 0.97
MAX_GRAD_NORM = 0.5  # Gradient clipping
UPDATE_EPOCHS = 6  # Number of iterations per update
BUFFER_SIZE = 2048  # Experience buffer size
BATCH_SIZE = 64  # Batch size
UPDATE_FREQUENCY = 2048  # Update frequency
SAVE_FREQUENCY = 5  # Save frequency (episodes)
DECAY_STEPS = 100000

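# Worked example of the entropy schedule (illustrative numbers, not taken from a run):
# _update_policy() below computes
#     entropy_coef = max(FINAL_ENTROPY_COEF, ENTROPY_COEF * DECAY_RATE ** (total_steps / UPDATE_FREQUENCY))
# so after 50 updates (total_steps = 50 * 2048):
#     0.97 ** 50 ≈ 0.218, hence entropy_coef ≈ max(0.003, 0.022 * 0.218) ≈ 0.0048
# and the coefficient reaches its floor of FINAL_ENTROPY_COEF = 0.003 after roughly 66 updates.
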
# Model parameters
HIDDEN_SIZE = 128

# Reward weights
REWARD_WEIGHTS = {
    "checkpoint": 1.5,  # Passing checkpoint
    "progress": 2,      # Moving toward goal
    "health": 0.5,      # Health change
    "item_pickup": 0,   # Picking up items
    "item_use": 0,      # Using items
    "completion": 2     # Completing level
}


class ObservationProcessor:
    """Class for processing game observations"""

    def __init__(self):
        # Dummy observation used to compute the observation size at initialization
        dummy_obs = {
            "agent_position": [0.0, 0.0, 0.0],
            "target_position": [0.0, 0.0, 0.0],
            "agent_forward_direction": [0.0, 1.0],
            "terrain_grid": [[{"terrain_type": 0.0} for _ in range(5)] for _ in range(5)],
            "agent_health_normalized": 1.0,
            "last_checkpoint_index": -1.0,
            "agent_velocity": [0.0, 0.0],
            "other_players": [
                {"relative_position": [0.0, 0.0, 0.0]},
                {"relative_position": [0.0, 0.0, 0.0]},
                {"relative_position": [0.0, 0.0, 0.0]}
            ]
        }
        self.observation_size = self._calculate_observation_size(dummy_obs)
        print(f"Observation size calculated: {self.observation_size}")

    def process(self, observations: Dict[str, Any]) -> torch.Tensor:
        """
        Process observation data

        Args:
            observations: Game observation dictionary

        Returns:
            Flattened observation tensor (target offset, forward direction,
            5x5 terrain grid, mud positions, other players)
        """
        flattened = self._flatten_observations(observations)
        return torch.tensor(flattened, dtype=torch.float32)

    def get_size(self) -> int:
        """Return the size of processed observations"""
        return self.observation_size

    def _flatten_observations(self, observations: Dict[str, Any]) -> List[float]:
        agent = observations["agent_position"]
        target = observations["target_position"]
        dx = target[0] - agent[0]
        dz = target[2] - agent[2]
        # Replace normalization with scaled dx, dz
        flattened = [dx / 30.0, dz / 30.0]
        # flattened.append(distance_to_target)

        agent_forward = np.array(observations["agent_forward_direction"])
        agent_forward_normalized = agent_forward / (np.linalg.norm(agent_forward) + 1e-6)
        flattened.extend(agent_forward_normalized.tolist())

        # Add normalized agent health (currently disabled)
        # flattened.append(observations.get("agent_health_normalized", 1.0))

        # Add last checkpoint index (currently disabled)
        # checkpoint_index = observations.get("last_checkpoint_index", -1.0)
        # normalized_checkpoint = checkpoint_index / 10.0
        # flattened.append(normalized_checkpoint)

        terrain_types = []
        for row in observations.get("terrain_grid", []):
            for cell in row:
                terrain_types.append(cell["terrain_type"])
        normalized_terrain = [terrain_type + 1 for terrain_type in terrain_types]
        flattened.extend(normalized_terrain)

        # Up to five mud objects, encoded as scaled relative positions and padded to a fixed size
        mud_positions = []
        for obj in observations.get("nearby_map_objects", []):
            if obj["object_type"] == 1.0:  # mud
                mud_dx, mud_dz = obj["relative_position"]
                mud_positions.append(mud_dx / 10.0)
                mud_positions.append(mud_dz / 10.0)
        mud_positions = mud_positions[:10]  # keep the observation size fixed if more than five muds are reported
        while len(mud_positions) < 10:
            mud_positions.append(1.1)
        flattened.extend(mud_positions)

        other_players_positions = []
        for player in observations.get("other_players", [])[:3]:
            other_players_positions.append(player["relative_position"][0] / 30.0)
            other_players_positions.append(player["relative_position"][2] / 30.0)
        flattened.extend(other_players_positions)

        # Add agent_velocity, normalized (currently disabled)
        # velocity = np.array(observations.get("agent_velocity", [0.0, 0.0]))
        # velocity_norm = np.linalg.norm(velocity)
        # normalized_velocity = (velocity / velocity_norm) if velocity_norm > 0 else np.array([0.0, 0.0])
        # flattened.extend(normalized_velocity.tolist())

        return flattened

    def _calculate_observation_size(self, sample_obs: Dict[str, Any]) -> int:
        """
        Calculate the size of the observation space

        Returns:
            Size of the flattened observation vector
        """
        obs = dict(sample_obs)
        obs.setdefault("agent_health_normalized", 1.0)
        obs.setdefault("last_checkpoint_index", -1.0)
        return len(self._flatten_observations(obs))

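# Layout of the flattened observation built by _flatten_observations() above
# (45 values in total with the dummy observation used in __init__):
#   [0:2]   (dx, dz) offset to the target, scaled by 1/30
#   [2:4]   normalized agent forward direction
#   [4:29]  5x5 terrain grid, terrain_type + 1 per cell
#   [29:39] up to five mud objects as (dx, dz) / 10, padded with 1.1
#   [39:45] (dx, dz) / 30 of up to three other players
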

class ActionProcessor:
    """Class for processing actions"""

    def __init__(self, action_space_info=None):
        self.action_space_info = action_space_info
        self.action_size = 2  # Output: 2 continuous actions (ax, az)

    def create_action(self, network_output):
        """
        Convert network output to a game action

        Args:
            network_output: Neural network output tensor of shape (2,),
                the mean of (ax, az)

        Returns:
            Tuple[np.ndarray, np.ndarray]: continuous actions and fixed discrete actions
        """
        if network_output.dim() > 1:
            network_output = network_output.squeeze(0)
        continuous_action = network_output.cpu().numpy().astype(np.float32)

        # Discrete actions: fixed [0, 0]
        discrete_action = np.array([0, 0], dtype=np.int32)

        return continuous_action, discrete_action

    def get_size(self):
        """Return the size of the action space output"""
        return self.action_size  # only 2 continuous outputs

    def _process_discrete_action(self, network_output):
        """Process a discrete action from logits"""
        discrete_logits = network_output[-2:]
        discrete_probs = torch.sigmoid(discrete_logits)
        return [int(prob > 0.5) for prob in discrete_probs]

class RewardCalculator:
    """Class for calculating rewards"""

    def __init__(self):
        self.reward_weights = REWARD_WEIGHTS
        self.prev_checkpoint_index = -1
        self.prev_distance_to_target = float('inf')
        self.prev_health = 0
        self.prev_inventory_count = 0

    def calculate(self, observations, reward, done, info, prev_observations):
        """Return the shaping reward for the transition from prev_observations to observations"""
        additional_reward = 0.0
        agent_pos = np.array(observations["agent_position"])
        target_pos = np.array(observations["target_position"])
        curr_dist = np.linalg.norm(agent_pos[[0, 2]] - target_pos[[0, 2]])
        prev_agent_pos = np.array(prev_observations["agent_position"])
        prev_dist = np.linalg.norm(prev_agent_pos[[0, 2]] - target_pos[[0, 2]])
        current_health = observations["agent_health_normalized"]
        prev_health = prev_observations["agent_health_normalized"]

        current_checkpoint = observations.get("last_checkpoint_index", -1)
        prev_checkpoint = prev_observations.get("last_checkpoint_index", -1)

        # Checkpoint reward
        if current_checkpoint > prev_checkpoint:
            additional_reward += self.reward_weights["checkpoint"]

        # Progress reward (only while the agent is alive)
        if prev_health > 0.0:
            progress = prev_dist - curr_dist
            additional_reward += self.reward_weights["progress"] * progress

        # Health reward: currently disabled

        # Death penalty
        if prev_health > 0.0 and current_health == 0.0:
            additional_reward += -2.0 * self.reward_weights["health"]

        # Completion reward
        if current_checkpoint == 1.0 and prev_checkpoint == 0.0:
            additional_reward += self.reward_weights["completion"]

        return additional_reward

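# Worked example of the shaping reward above (illustrative numbers only):
# passing a checkpoint while closing the distance to the target from 10.0 to 9.5
# and staying alive yields
#     REWARD_WEIGHTS["checkpoint"] + REWARD_WEIGHTS["progress"] * (10.0 - 9.5) = 1.5 + 2 * 0.5 = 2.5
# while dying on the same step would instead add -2.0 * REWARD_WEIGHTS["health"] = -1.0.
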

class ExperienceBuffer:
    """Experience buffer class"""

    def __init__(self, capacity, model):
        self.observations = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.next_observations = []
        self.action_log_probs = []
        self.values = []
        self.capacity = capacity
        self.model = model

    def add(self, observation, action, reward, done, next_observation, action_log_prob, value):
        """Add an experience"""
        self.observations.append(observation)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)
        self.next_observations.append(next_observation)
        self.action_log_probs.append(action_log_prob)
        self.values.append(value)

    def clear(self):
        """Clear the buffer"""
        self.observations.clear()
        self.actions.clear()
        self.rewards.clear()
        self.dones.clear()
        self.next_observations.clear()
        self.action_log_probs.clear()
        self.values.clear()

    def get_batches(self, batch_size):
        """Yield shuffled mini-batches, including the sampled indices so that
        per-step advantages and returns can be aligned with each batch."""
        indices = np.arange(len(self.observations))
        np.random.shuffle(indices)

        for start in range(0, len(indices), batch_size):
            batch_indices = indices[start:start + batch_size]
            yield (
                batch_indices,
                [self.observations[i] for i in batch_indices],
                [self.actions[i] for i in batch_indices],
                [self.rewards[i] for i in batch_indices],
                [self.dones[i] for i in batch_indices],
                [self.next_observations[i] for i in batch_indices],
                [self.action_log_probs[i] for i in batch_indices],
                [self.values[i] for i in batch_indices],
            )

    def compute_advantages(self, gamma, lam):
        """Compute GAE advantages; returns (advantages, returns) (returns are not standardized)"""
        advantages = []
        returns = []
        gae = 0
        next_value = 0
        for t in reversed(range(len(self.rewards))):
            next_value = self.values[t + 1] if t + 1 < len(self.values) else 0
            delta = self.rewards[t] + gamma * next_value * (1 - self.dones[t]) - self.values[t]
            gae = delta + gamma * lam * (1 - self.dones[t]) * gae
            advantages.insert(0, gae)
            returns.insert(0, gae + self.values[t])
        return advantages, returns

    def __len__(self):
        return len(self.observations)

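# Worked GAE example for compute_advantages() above (illustrative numbers only),
# with gamma = 0.99, lam = 0.95, rewards = [1.0, 0.0], values = [0.5, 0.4], dones = [0, 1]:
#   t = 1: delta = 0.0 - 0.4 = -0.4               -> advantage = -0.4,   return = 0.0
#   t = 0: delta = 1.0 + 0.99 * 0.4 - 0.5 = 0.896 -> advantage = 0.896 + 0.99 * 0.95 * (-0.4) ≈ 0.520,
#                                                    return ≈ 0.520 + 0.5 ≈ 1.020
# The done flag at t = 1 masks the bootstrap term, so the terminal advantage reduces to r_1 - V(s_1).
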
class PPOModel(nn.Module):
    def __init__(self, observation_size, action_size):
        super(PPOModel, self).__init__()

        # Shared feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Linear(observation_size, HIDDEN_SIZE),
            nn.Tanh(),
            nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
            nn.Tanh()
        )

        # Earlier hybrid / discrete policy heads, kept for reference:
        # if isinstance(action_size, tuple):
        #     continuous_size, discrete_size = action_size
        #     self.continuous_policy = nn.Sequential(
        #         nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
        #         nn.Tanh(),
        #         nn.Linear(HIDDEN_SIZE, continuous_size * 2)  # mean and log_std
        #     )
        #     self.discrete_policy = nn.Sequential(
        #         nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
        #         nn.Tanh(),
        #         nn.Linear(HIDDEN_SIZE, discrete_size)
        #     )
        #     self.action_type = "hybrid"
        # elif isinstance(action_size, int):
        #     if action_size > 10:
        #         self.policy = nn.Sequential(
        #             nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
        #             nn.Tanh(),
        #             nn.Linear(HIDDEN_SIZE, action_size * 2)
        #         )
        #         self.action_type = "continuous"
        #     else:
        #         self.policy = nn.Sequential(
        #             nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
        #             nn.Tanh(),
        #             nn.Linear(HIDDEN_SIZE, action_size)
        #         )
        #         self.action_type = "discrete"

        # Continuous Gaussian policy head: outputs mean and log_std for each action dimension
        self.policy = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
            nn.Tanh(),
            nn.Linear(HIDDEN_SIZE, action_size * 2)
        )
        self.action_type = "continuous"

        # Value head
        self.value = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
            nn.Tanh(),
            nn.Linear(HIDDEN_SIZE, 1)
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        if self.action_type == "hybrid":
            continuous_params = self.continuous_policy(features)
            discrete_logits = self.discrete_policy(features)
            value = self.value(features)
            return continuous_params, discrete_logits, value
        else:
            policy_output = self.policy(features)
            value = self.value(features)
            return policy_output, value

    def act(self, x):
        output, value = self.forward(x)
        if self.action_type == "continuous":
            mean, log_std = torch.chunk(output, 2, dim=-1)
            std = torch.exp(log_std.clamp(min=-2, max=0.7))
            dist = torch.distributions.Normal(mean, std)
            action = dist.sample()
            log_prob = dist.log_prob(action).sum(dim=-1)

            # Debug logging of the policy distribution and value estimate
            mean_list = mean.squeeze(0).tolist() if mean.dim() > 1 else mean.tolist()
            log_std_list = log_std.squeeze(0).tolist() if log_std.dim() > 1 else log_std.tolist()
            value_item = value.item() if isinstance(value, torch.Tensor) else value

            dx_dz_str = ""
            if x is not None and x.size(-1) >= 2:
                dx = x[0].item() if x.dim() == 1 else x[0, 0].item()
                dz = x[1].item() if x.dim() == 1 else x[0, 1].item()
                dx_dz_str = f",{dx:.6f},{dz:.6f}"
            with open("act_debug.csv", "a") as f:
                f.write(",".join([f"{v:.6f}" for v in mean_list + log_std_list]) + f",{value_item:.6f}{dx_dz_str}\n")

            return action, log_prob, value
        else:
            dist = torch.distributions.Categorical(logits=output)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            return action, log_prob, value.squeeze(-1)

    def evaluate_actions(self, x, actions):
        output, value = self.forward(x)
        if self.action_type == "continuous":
            mean, log_std = torch.chunk(output, 2, dim=-1)
            std = torch.exp(log_std.clamp(min=-2, max=0.7))
            dist = torch.distributions.Normal(mean, std)
            log_probs = dist.log_prob(actions).sum(dim=-1)
            entropy = dist.entropy().sum(-1)
        elif self.action_type == "discrete":
            dist = torch.distributions.Categorical(logits=output)
            log_probs = dist.log_prob(actions)
            entropy = dist.entropy()
        else:  # hybrid
            continuous_params, discrete_logits, _ = output
            mean, log_std = torch.chunk(continuous_params, 2, dim=-1)
            std = torch.exp(log_std.clamp(-2, 0.7))
            continuous_dist = torch.distributions.Normal(mean, std)
            continuous_actions, discrete_actions = actions
            continuous_log_probs = continuous_dist.log_prob(continuous_actions).sum(-1)
            continuous_entropy = continuous_dist.entropy().sum(-1)

            discrete_dist = torch.distributions.Categorical(logits=discrete_logits)
            discrete_log_probs = discrete_dist.log_prob(discrete_actions)
            discrete_entropy = discrete_dist.entropy()

            log_probs = continuous_log_probs + discrete_log_probs
            entropy = continuous_entropy + discrete_entropy
        return log_probs, entropy, value

    def save(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(self.state_dict(), path)

    def load(self, path):
        if os.path.exists(path):
            try:
                self.load_state_dict(torch.load(path, map_location=device))
                print(f"Model loaded from {path}")
                return True
            except RuntimeError as e:
                print(f"Error loading model from {path}: {e}")
                return False
        else:
            print(f"Model file not found: {path}")
            return False

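# Minimal usage sketch of PPOModel (assumes the 45-dimensional observation produced by
# ObservationProcessor and the 2-dimensional continuous action head above; note that
# act() appends a row to act_debug.csv as a side effect):
#     model = PPOModel(observation_size=45, action_size=2).to(device)
#     obs = torch.zeros(45, device=device)
#     action, log_prob, value = model.act(obs)  # action.shape == (2,), log_prob is scalar, value has shape (1,)
#     log_probs, entropy, values = model.evaluate_actions(obs.unsqueeze(0), action.unsqueeze(0))
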

class MLPlay:
    """
    MLPlay class using the PPO algorithm

    This class implements the PPO algorithm and can either train a model during
    gameplay or run a pre-trained model.
    """

    def __init__(self, action_space_info=None):
        """
        Initialize the MLPlay instance

        Args:
            action_space_info: Action space information
        """
        # Set name
        self.name = "PPO_MLPlay"

        # Initialize components
        self.observation_processor = ObservationProcessor()
        self.action_processor = ActionProcessor(action_space_info)
        self.reward_calculator = RewardCalculator()

        # Training mode
        self.training_mode = TRAINING_MODE

        # Initialize state
        self.real_prev_observations = None
        self.prev_observations = None
        self.prev_action = None
        self.prev_action_log_prob = None
        self.prev_value = None
        self.episode_rewards = []
        self.total_steps = 0
        self.accumulate_steps = 0
        self.episode_count = 0

        # Create model directory
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

        # Wait for the first observation to initialize the model
        self.model = None
        self.optimizer = None
        self.experience_buffer = None

        if not RESUME_TRAINING:
            # Reset the debug/metric logs with headers matching the rows written below
            for fname, header in [
                ("reward_log.csv", "episode,reward,policy_loss,value_loss,entropy,advantage,advantage_std\n"),
                ("log_std_monitor.csv", "episode,policy_loss,value_loss,entropy_loss,entropy_coef\n"),
                ("value_debug.csv", "step,value_mean,target_mean,value_loss,grad_norm\n"),
                ("act_debug.csv", "mean_x,mean_z,log_std_x,log_std_z,value,dx,dz\n")
            ]:
                with open(fname, "w") as f:
                    f.write(header)

        print(f"PPO MLPlay initialized (Training mode: {self.training_mode})")

    def reset(self):
        """Reset the MLPlay instance"""
        self.prev_observations = None
        self.prev_action = None
        self.prev_action_log_prob = None
        self.prev_value = None
        self.episode_rewards = []

    def update(self,
               observations: Dict[str, np.ndarray],
               done: bool = False,
               info: Dict[str, Any] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Process observations and return (continuous, discrete) actions
        """
        try:
            raw_obs = observations

            # Episode timer expired: stand still and drop the stale transition
            if observations["current_time_normalized"] >= 1.0:
                self.real_prev_observations = observations
                self.prev_observations = None
                return np.array([0.0, 0.0], dtype=np.float32), np.array([0, 0], dtype=np.int32)

            # Already finished the level: keep standing still
            if observations["last_checkpoint_index"] == 1.0:
                if self.real_prev_observations is not None and self.real_prev_observations["last_checkpoint_index"] == 1.0:
                    self.real_prev_observations = observations
                    self.prev_observations = None
                    return np.array([0.0, 0.0], dtype=np.float32), np.array([0, 0], dtype=np.int32)

            # Skip frames where the agent is dead, launched into the air, or stuck during respawn
            if self.real_prev_observations is not None:
                prev_pos = self.real_prev_observations["agent_position"]
                prev_hp = self.real_prev_observations["agent_health_normalized"]
                curr_pos = raw_obs["agent_position"]
                curr_hp = raw_obs["agent_health_normalized"]
                if ((prev_hp == 0.0 and curr_hp == 0.0) or (curr_pos[1] > 4)
                        or (prev_pos[0] == curr_pos[0] and prev_pos[2] == curr_pos[2] and prev_pos[1] > 1.4)):
                    self.real_prev_observations = raw_obs
                    return np.array([0.0, 0.0], dtype=np.float32), np.array([0, 0], dtype=np.int32)

            current_obs = self.observation_processor.process(raw_obs).to(device)

            if self.model is None:
                # Initialize the model once the input/output sizes are known
                self._initialize_model(self.observation_processor.get_size(), self.action_processor.get_size())

            with torch.no_grad():
                action_tensor, log_prob, value = self.model.act(current_obs)
            action, discrete_action = self.action_processor.create_action(action_tensor)

            if self.prev_observations is not None and self.training_mode:
                shaped_reward = self.reward_calculator.calculate(
                    raw_obs, reward=0.0, done=done, info=info, prev_observations=self.prev_observations
                )
                self.experience_buffer.add(
                    observation=self.observation_processor.process(self.prev_observations),
                    action=self.prev_action,
                    reward=shaped_reward,
                    done=done,
                    next_observation=current_obs,
                    action_log_prob=self.prev_action_log_prob.detach() if self.prev_action_log_prob is not None else None,
                    # Store the value estimate of the observation being added, not of the current one
                    value=self.prev_value if self.prev_value is not None else value.item()
                )
                self.episode_rewards.append(shaped_reward)

                self.total_steps += 1
                self.accumulate_steps += 1
                print(f"[Step {self.total_steps}] reward: {shaped_reward}, action_tensor: {action_tensor.cpu().numpy()}, log_prob: {log_prob.item():.4f}")
                if self.total_steps % UPDATE_FREQUENCY == 0:
                    self._update_policy()
                    self._save_model()
                    self.experience_buffer.clear()

            self.real_prev_observations = raw_obs
            self.prev_observations = raw_obs
            self.prev_action = action_tensor
            self.prev_action_log_prob = log_prob
            self.prev_value = value.item()

            # print(f"[Step {self.total_steps}] direction: {action}")
            return action, discrete_action
        except Exception:
            import traceback
            traceback.print_exc()
            return np.array([0.0, 0.0], dtype=np.float32), np.array([0, 0], dtype=np.int32)

    def _initialize_model(self, observation_size, action_size):
        """Initialize the model and related components"""
        self.model = PPOModel(observation_size, action_size).to(device)
        self.base_learning_rate = LEARNING_RATE
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.base_learning_rate)
        self.experience_buffer = ExperienceBuffer(BUFFER_SIZE, self.model)

        if RESUME_TRAINING and os.path.exists(MODEL_LOAD_PATH) and os.path.exists(OPTIMIZER_LOAD_PATH):
            self.model.load(MODEL_LOAD_PATH)
            checkpoint = torch.load(OPTIMIZER_LOAD_PATH, map_location=device)
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.total_steps = (checkpoint.get('episode', 0) + 1) * UPDATE_FREQUENCY
            self.episode_count = checkpoint.get('episode', 0)
        elif not self.training_mode:
            loaded = self.model.load(MODEL_LOAD_PATH)
            if not loaded:
                print(f"Failed to load model from {MODEL_LOAD_PATH}. Will run with an untrained model.")

    def _update_policy(self):
        """Update the PPO policy"""
        # Exponentially decay the entropy coefficient, floored at FINAL_ENTROPY_COEF
        decay_factor = DECAY_RATE ** (self.total_steps / UPDATE_FREQUENCY)
        entropy_coef = max(FINAL_ENTROPY_COEF, ENTROPY_COEF * decay_factor)
        # Linear alternative:
        # decay_ratio = self.total_steps / DECAY_STEPS
        # entropy_coef = ENTROPY_COEF - (ENTROPY_COEF - FINAL_ENTROPY_COEF) * decay_ratio

        advantages, returns = self.experience_buffer.compute_advantages(GAMMA, GAE_LAMBDA)
        advantages = torch.tensor(advantages, dtype=torch.float32).to(device)
        returns = torch.tensor(returns, dtype=torch.float32).to(device)

        # Initialize accumulators for logging
        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        batch_count = 0

        for _ in range(UPDATE_EPOCHS):
            for batch in self.experience_buffer.get_batches(BATCH_SIZE):
                idx_b, obs_b, act_b, rew_b, done_b, next_obs_b, logp_b, val_b = batch
                obs_b = torch.stack(obs_b).to(device)
                act_b = torch.stack(act_b).to(device)
                logp_b = torch.stack(logp_b).to(device)
                val_b = torch.tensor(val_b, dtype=torch.float32).to(device)

                # Select the advantages/returns that correspond to the sampled indices
                idx_t = torch.as_tensor(idx_b, dtype=torch.long, device=device)
                adv_b = advantages[idx_t]
                adv_b = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)
                vt_b = returns[idx_t]

                new_log_probs, entropy, values = self.model.evaluate_actions(obs_b, act_b)
                ratio = torch.exp(new_log_probs - logp_b)

                surr1 = ratio * adv_b
                surr2 = torch.clamp(ratio, 1.0 - CLIP_RATIO, 1.0 + CLIP_RATIO) * adv_b
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = F.mse_loss(values.squeeze(-1), vt_b.squeeze(-1))
                entropy_loss = entropy.mean()

                loss = policy_loss + VALUE_COEF * value_loss - entropy_coef * entropy_loss
                with open("log_std_monitor.csv", "a") as f:
                    f.write(
                        f"{self.episode_count},{policy_loss.item():.7f},{value_loss.item():.7f},{entropy_loss.item():.7f},{entropy_coef:.7f}\n"
                    )
                self.optimizer.zero_grad()
                torch.autograd.set_detect_anomaly(True)  # debug aid; slows training, disable once stable
                loss.backward()
                critic_last_layer = self.model.value[-1]
                grad_norm = critic_last_layer.weight.grad.norm().item() if critic_last_layer.weight.grad is not None else 0.0
                nn.utils.clip_grad_norm_(self.model.parameters(), MAX_GRAD_NORM)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy_loss.item()
                batch_count += 1

                with open("value_debug.csv", "a") as f:
                    f.write(f"{self.total_steps},{values.mean().item():.4f},{vt_b.mean().item():.4f},{value_loss.item():.4f},{grad_norm:.6f}\n")

        # Log statistics and the total reward of the latest episode to reward_log.csv
        total_reward = sum(self.episode_rewards)
        if batch_count > 0:
            avg_policy_loss = total_policy_loss / batch_count
            avg_value_loss = total_value_loss / batch_count
            avg_entropy = total_entropy / batch_count
        else:
            avg_policy_loss = 0.0
            avg_value_loss = 0.0
            avg_entropy = 0.0
        avg_advantage = advantages.mean().item()
        advantage_std = advantages.std().item()
        with open("reward_log.csv", "a") as f:
            f.write(
                f"{self.episode_count},{total_reward:.2f},{avg_policy_loss:.4f},{avg_value_loss:.4f},{avg_entropy:.6f},{avg_advantage:.4f},{advantage_std:.4f}\n"
            )
        self.episode_rewards = []
        self.episode_count += 1

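    # Note on the loss assembled in _update_policy() above: the policy term is the
    # standard PPO clipped surrogate,
    #     L_clip = E[ min(r_t * A_t, clip(r_t, 1 - CLIP_RATIO, 1 + CLIP_RATIO) * A_t) ],
    # with probability ratio r_t = exp(new_log_prob - old_log_prob); the total loss adds
    # VALUE_COEF * value_loss and subtracts entropy_coef * entropy to keep exploration up.
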
    def _save_model(self):
        """Save the model and optimizer state"""
        if self.model is not None:
            # Save the latest and a per-update snapshot of the model
            self.model.save(f"{MODEL_SAVE_DIR}/model_latest.pt")
            self.model.save(f"{MODEL_SAVE_DIR}/model_{self.episode_count}.pt")

            torch.save({
                'optimizer_state_dict': self.optimizer.state_dict(),
                'step': self.total_steps,
                'episode': self.episode_count
            }, f"{MODEL_SAVE_DIR}/optimizer_latest.pt")
            torch.save({
                'optimizer_state_dict': self.optimizer.state_dict(),
                'step': self.total_steps,
                'episode': self.episode_count
            }, f"{MODEL_SAVE_DIR}/optimizer_{self.episode_count}.pt")