import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import gym
import numpy as np
import pybullet_envs

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# On-policy rollout buffer; filled during interaction and cleared after every PPO update.
class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(ActorCritic, self).__init__()
        # actor: shared trunk, then a tanh-squashed mean (range -1 to 1) and a
        # separate head whose exp() supplies the diagonal of the covariance
        self.actor_head = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh()
        )
        self.actor_mu = nn.Sequential(nn.Linear(32, action_dim),
                                      nn.Tanh())
        self.actor_sigma = nn.Linear(32, action_dim)
        # critic
        self.critic = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )
        # fixed-variance fallback; unused below because the variance is learned by actor_sigma
        self.action_var = torch.full((action_dim,), action_std*action_std).to(device)

    def forward(self):
        raise NotImplementedError

    def act(self, state, memory):
        actor_head = self.actor_head(state)

        action_mean = self.actor_mu(actor_head)

        action_sigma = torch.exp(self.actor_sigma(actor_head)).squeeze()
        cov_mat = torch.diag(action_sigma).to(device)

        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()

    def evaluate(self, state, action):
        actor_head = self.actor_head(state)
        action_mean = self.actor_mu(actor_head)
        action_var = torch.exp(self.actor_sigma(actor_head)).expand_as(action_mean)
        # action_var = self.action_var.expand_as(action_mean)  # fixed-variance alternative
        cov_mat = torch.diag_embed(action_var).to(device)

        dist = MultivariateNormal(action_mean, cov_mat)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_value = self.critic(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy

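# Optional, illustrative aside (a minimal sketch, not used by the training code): with a
# diagonal covariance, the MultivariateNormal above factorizes into independent 1-D
# Normals, so the joint log-prob is the sum of per-dimension log-probs. The helper name
# and numbers below are arbitrary choices made only for this check.
def _check_diag_gaussian_logprob(action_dim=3):
    mean = torch.zeros(action_dim)
    var = torch.full((action_dim,), 0.25)                 # per-dimension variance
    mvn = MultivariateNormal(mean, torch.diag(var))       # joint, diagonal covariance
    ind = torch.distributions.Normal(mean, var.sqrt())    # independent factors (std = sqrt(var))
    sample = mvn.sample()
    # joint log-prob equals the sum of per-dimension log-probs
    assert torch.allclose(mvn.log_prob(sample), ind.log_prob(sample).sum())
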
class PPO:
    def __init__(self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

        self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state, memory):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.policy_old.act(state, memory).cpu().data.numpy().flatten()

    def update(self, memory):
        # Monte Carlo estimate of returns, computed backwards through the rollout
        # (e.g. three steps of reward 1 with gamma = 0.99 give [2.9701, 1.99, 1.0]):
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalizing the rewards (float32 so they match the network outputs):
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(memory.states).to(device), 1).detach()
        old_actions = torch.squeeze(torch.stack(memory.actions).to(device), 1).detach()
        old_logprobs = torch.squeeze(torch.stack(memory.logprobs), 1).to(device).detach()

        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values:
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Finding the ratio (pi_theta / pi_theta_old):
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding the surrogate loss:
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy

            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())

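# The clipped surrogate computed in PPO.update above is the standard PPO objective,
#   L_clip = E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ],
# so with eps_clip = 0.2 a ratio of, say, 1.5 on a positive-advantage sample contributes
# at most 1.2 * A_t, which is what caps the size of each policy step (illustrative
# numbers, not taken from a run of this script).
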
def main():
    ############## Hyperparameters ##############
    env_name = "HalfCheetahBulletEnv-v0"
    render = False
    solved_reward = 2200        # stop training if avg_reward > solved_reward
    log_interval = 20           # print avg reward in the interval
    max_episodes = 10000        # max training episodes
    max_timesteps = 1500        # max timesteps in one episode

    update_timestep = 4000      # update policy every n timesteps
    action_std = 0.5            # std passed to ActorCritic (only fills its unused action_var buffer; sigma is learned)
    K_epochs = 80               # update policy for K epochs
    eps_clip = 0.2              # clip parameter for PPO
    gamma = 0.99                # discount factor

    lr = 0.0003                 # parameters for Adam optimizer
    betas = (0.9, 0.999)

    random_seed = None
    #############################################

    # creating environment
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    if random_seed:
        print("Random Seed: {}".format(random_seed))
        torch.manual_seed(random_seed)
        env.seed(random_seed)
        np.random.seed(random_seed)

    memory = Memory()
    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip)
    print(lr, betas)

    # logging variables
    running_reward = 0
    avg_length = 0
    time_step = 0

    # training loop
    for i_episode in range(1, max_episodes+1):
        state = env.reset()
        for t in range(max_timesteps):
            time_step += 1
            # Running policy_old:
            action = ppo.select_action(state, memory)
            state, reward, done, _ = env.step(action)

            # Saving reward and is_terminals:
            memory.rewards.append(reward)
            memory.is_terminals.append(done)

            # update if it's time
            if time_step % update_timestep == 0:
                ppo.update(memory)
                memory.clear_memory()
                time_step = 0
            running_reward += reward
            if render:
                env.render()
            if done:
                break

        avg_length += t

        # # stop training if avg_reward > solved_reward
        # if running_reward > (log_interval*solved_reward):
        #     print("########## Solved! ##########")
        #     torch.save(ppo.policy.state_dict(), './PPO_continuous_solved_{}.pth'.format(env_name))
        #     break

        # save every 50 episodes
        if i_episode % 50 == 0:
            torch.save(ppo.policy.state_dict(), './PPO_sigma_{}.pth'.format(env_name))

        # logging
        if i_episode % log_interval == 0:
            avg_length = int(avg_length/log_interval)
            running_reward = int((running_reward/log_interval))

            print('Episode {} \t Avg length: {} \t Avg reward: {}'.format(i_episode, avg_length, running_reward))
            running_reward = 0
            avg_length = 0

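# Optional evaluation sketch (a minimal sketch, not part of the training run above): it
# assumes a checkpoint written by the torch.save call in main() and rolls out the
# deterministic mean action instead of sampling. The function name, episode count and
# default arguments are illustrative choices, not taken from the original script.
def evaluate_checkpoint(env_name="HalfCheetahBulletEnv-v0",
                        checkpoint="./PPO_sigma_HalfCheetahBulletEnv-v0.pth",
                        n_episodes=3, render=False):
    env = gym.make(env_name)
    policy = ActorCritic(env.observation_space.shape[0],
                         env.action_space.shape[0], action_std=0.5).to(device)
    policy.load_state_dict(torch.load(checkpoint, map_location=device))
    policy.eval()
    for ep in range(1, n_episodes+1):
        state = env.reset()
        ep_reward, done = 0.0, False
        while not done:
            with torch.no_grad():
                s = torch.FloatTensor(state.reshape(1, -1)).to(device)
                # greedy action: take the Gaussian mean rather than a sample
                action = policy.actor_mu(policy.actor_head(s)).cpu().numpy().flatten()
            state, reward, done, _ = env.step(action)
            ep_reward += reward
            if render:
                env.render()
        print('Eval episode {} \t reward: {:.1f}'.format(ep, ep_reward))
    env.close()
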
if __name__ == '__main__':
    main()