# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy
import os
import random
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter

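# A minimal double-integrator environment: acceleration control of a point mass with
# explicit Euler dynamics at dt = 0.1,
#   vel_{k+1} = vel_k + 0.1 * action,   pos_{k+1} = pos_k + 0.1 * vel_k,
# an observation of [pos - target, vel], and a reward that penalizes distance to the target.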
class DoubleIntegrator(gym.Env):

    def __init__(self, render_mode=None):
        super().__init__()
        self.pos = 0.0
        self.vel = 0.0
        self.target = 0.0
        self.curr_step = 0
        self.max_steps = 300
        self.terminated = False
        self.truncated = False
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
        self.observation_space = gym.spaces.Box(low=-5, high=5, shape=(2,))

    def step(self, action):
        action = action[0]
        # Penalize distance to the target position (computed before the state update).
        reward = -10.0 * abs(self.pos - self.target)
        # Explicit Euler integration with dt = 0.1 (position uses the previous velocity).
        vel = self.vel + 0.1 * action
        pos = self.pos + 0.1 * self.vel
        self.vel = vel
        self.pos = pos
        self.curr_step += 1

        # Hitting the step limit is a time-limit truncation, not a true termination.
        if self.curr_step >= self.max_steps:
            self.truncated = True

        return self._get_obs(), reward, self.terminated, self.truncated, self._get_info()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.pos = 0.0
        self.vel = 0.0
        # Sample a target uniformly in [-5, 5] using the seeded RNG.
        self.target = self.np_random.uniform(-5.0, 5.0)
        self.curr_step = 0
        self.terminated = False
        self.truncated = False
        return self._get_obs(), self._get_info()

    def _get_obs(self):
        return np.array([self.pos - self.target, self.vel], dtype=np.float32)

    def _get_info(self):
        return {"target": self.target, "pos": self.pos}


@dataclass
class Args:
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """sets `torch.backends.cudnn.deterministic` to this value"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""

    # Algorithm specific arguments
    env_id: str = "DoubleIntegrator"
    """the id of the environment"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    num_steps: int = 2048
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for generalized advantage estimation"""
    num_minibatches: int = 32
    """the number of mini-batches"""
    update_epochs: int = 10
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.0
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: float = None
    """the target KL divergence threshold"""

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""

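# The runtime fields above are derived in __main__:
#   batch_size     = num_envs * num_steps            (2048 with the defaults here)
#   minibatch_size = batch_size // num_minibatches   (64 with the defaults here)
#   num_iterations = total_timesteps // batch_size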
def make_env(env_id, idx, capture_video, run_name, gamma):
    def thunk():
        if capture_video and idx == 0:
            # NOTE: this branch assumes `env_id` is registered with gymnasium and supports rgb_array rendering.
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            # env = gym.make(env_id)
            env = DoubleIntegrator()
        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = gym.wrappers.ClipAction(env)
        env = gym.wrappers.NormalizeObservation(env)
        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
        env = gym.wrappers.NormalizeReward(env, gamma=gamma)
        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
        return env

    return thunk

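# Note on the wrapper stack above: observations are normalized with a running mean/std and then
# clipped to [-10, 10]; rewards are normalized using a discounted return estimate (gamma) and then
# clipped; ClipAction keeps the sampled Gaussian actions inside the env's Box bounds.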

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    # Orthogonal weight init with a configurable gain and constant bias (standard PPO implementation detail).
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
        )
        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)

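# The Agent above parameterizes a diagonal Gaussian policy: the mean comes from actor_mean(x) and
# the state-independent standard deviation from exp(actor_logstd). Per-dimension log-probabilities
# and entropies are summed over the action dimensions, and the critic outputs the state value V(x).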


if __name__ == "__main__":
    args = tyro.cli(Args)
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    args.num_iterations = args.total_timesteps // args.batch_size
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

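    # Rollout buffers are shaped (num_steps, num_envs, *space_shape) and are flattened into a
    # single batch of num_steps * num_envs transitions before the PPO update.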
    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs, _ = envs.reset(seed=args.seed)
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(args.num_envs).to(device)

    for iteration in range(1, args.num_iterations + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (iteration - 1.0) / args.num_iterations
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
            next_done = np.logical_or(terminations, truncations)
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device)

            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info and "episode" in info:
                        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                        writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0
            for t in reversed(range(args.num_steps)):
                if t == args.num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
            returns = advantages + values

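        # Generalized advantage estimation (computed above):
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
        # and the value targets are returns_t = A_t + V(s_t).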
        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

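                # approx_kl above uses the low-variance estimator E[(r - 1) - log r] from the blog
                # post linked above; clipfrac tracks the fraction of samples whose ratio left the
                # clipping range, a useful diagnostic for how aggressive each update is.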
                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

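                # The policy loss above is the PPO clipped surrogate objective: maximizing
                #   min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)
                # is written here as minimizing the max of the two negated terms.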
                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

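                # With clip_vloss, the value update is limited to a clip_coef-sized move away from
                # the rollout-time value estimate (b_values), mirroring the policy ratio clipping.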
                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None and approx_kl > args.target_kl:
                break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

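        # explained_variance = 1 - Var(returns - values) / Var(returns): values near 1 mean the
        # value function predicts the empirical returns well, near 0 (or negative) means it does not.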
        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        torch.save(agent.state_dict(), model_path)
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.ppo_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=Agent,
            device=device,
            gamma=args.gamma,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()
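
# Example usage (a sketch; the filename ppo_double_integrator.py is assumed, and tyro exposes the
# Args fields above as command-line flags):
#   python ppo_double_integrator.py --total-timesteps 1000000 --seed 1
#   python ppo_double_integrator.py --track --wandb-project-name cleanRL   # log to Weights & Biases
#   tensorboard --logdir runs                                              # inspect the logged metrics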