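# PPO for continuous control on HalfCheetahBulletEnv-v0 (PyBullet), using a
# multivariate Gaussian policy whose per-dimension variance comes from a learned
# sigma head, a clipped surrogate objective, and Monte Carlo returns as critic targets.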
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import gym
import numpy as np
import pybullet_envs  # imported for its side effect: registers the Bullet envs (e.g. HalfCheetahBulletEnv-v0) with gym

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Memory:
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []

    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(ActorCritic, self).__init__()
        # action mean range -1 to 1
        self.actor_head = nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.Tanh(),
                nn.Linear(64, 32),
                nn.Tanh()
                )
        self.actor_mu = nn.Sequential(nn.Linear(32, action_dim),
                nn.Tanh())
        self.actor_sigma = nn.Linear(32, action_dim)
        # critic
        self.critic = nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.Tanh(),
                nn.Linear(64, 32),
                nn.Tanh(),
                nn.Linear(32, 1)
                )
        self.action_var = torch.full((action_dim,), action_std*action_std).to(device)
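        # NOTE: act() and evaluate() build the covariance from the learned actor_sigma
        # head, so this fixed action_var (= action_std**2) is not used in the current
        # code path; it is presumably kept from a constant-std variant of the policy.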

    def forward(self):
        raise NotImplementedError

    def act(self, state, memory):
        actor_head = self.actor_head(state)

        action_mean = self.actor_mu(actor_head)

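        # Diagonal Gaussian policy: exp() of the sigma head is used directly as the
        # per-dimension variance, i.e. the diagonal of the covariance matrix below.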
        action_sigma = torch.exp(self.actor_sigma(actor_head)).squeeze()
        cov_mat = torch.diag(action_sigma).to(device)

        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        action_logprob = dist.log_prob(action)

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)

        return action.detach()

    def evaluate(self, state, action):
        actor_head = self.actor_head(state)
        action_mean = self.actor_mu(actor_head)
        action_var = torch.exp(self.actor_sigma(actor_head)).expand_as(action_mean)

        # alternative: fixed variance from the constructor's action_std
        #action_var = self.action_var.expand_as(action_mean)
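        # diag_embed builds one diagonal covariance matrix per state in the batch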
        cov_mat = torch.diag_embed(action_var).to(device)

        dist = MultivariateNormal(action_mean, cov_mat)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_value = self.critic(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy

class PPO:
    def __init__(self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs

        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)

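        # policy_old is a frozen snapshot used to collect rollouts; it is re-synced
        # with the updated policy at the end of each update() call.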
        self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())

        self.MseLoss = nn.MSELoss()

    def select_action(self, state, memory):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.policy_old.act(state, memory).cpu().data.numpy().flatten()

    def update(self, memory):
        # Monte Carlo estimate of rewards:
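        # G_t = r_t + gamma * G_{t+1}, computed backwards and reset to 0 at episode
        # boundaries; these returns serve both as the critic targets and as the basis
        # for the advantage estimate (return - V(s)).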
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)

        # Normalizing the rewards:
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)

        # convert list to tensor
        old_states = torch.squeeze(torch.stack(memory.states).to(device), 1).detach()
        old_actions = torch.squeeze(torch.stack(memory.actions).to(device), 1).detach()
        old_logprobs = torch.squeeze(torch.stack(memory.logprobs), 1).to(device).detach()

        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values:
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)

            # Finding the ratio (pi_theta / pi_theta__old):
            ratios = torch.exp(logprobs - old_logprobs.detach())

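            # PPO clipped objective: maximize min(r_t * A_t, clip(r_t, 1-eps, 1+eps) * A_t)
            # plus a 0.01 entropy bonus, while minimizing a 0.5-weighted value MSE;
            # the sign flip below turns this into a loss to minimize.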
            # Finding Surrogate Loss:
            advantages = rewards - state_values.detach()
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy

            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())

def main():
    ############## Hyperparameters ##############
    env_name = "HalfCheetahBulletEnv-v0"
    render = False
    solved_reward = 2200        # stop training if avg_reward > solved_reward
    log_interval = 20           # print avg reward in the interval
    max_episodes = 10000        # max training episodes
    max_timesteps = 1500        # max timesteps in one episode

    update_timestep = 4000      # update policy every n timesteps
    action_std = 0.5            # std for the fixed action_var buffer (unused here; sigma is learned)
    K_epochs = 80               # update policy for K epochs
    eps_clip = 0.2              # clip parameter for PPO
    gamma = 0.99                # discount factor

    lr = 0.0003                 # parameters for Adam optimizer
    betas = (0.9, 0.999)

    random_seed = None
    #############################################

    # creating environment
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    if random_seed:
        print("Random Seed: {}".format(random_seed))
        torch.manual_seed(random_seed)
        env.seed(random_seed)
        np.random.seed(random_seed)

    memory = Memory()
    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip)
    print(lr, betas)

    # logging variables
    running_reward = 0
    avg_length = 0
    time_step = 0

    # training loop
    for i_episode in range(1, max_episodes+1):
        state = env.reset()
        for t in range(max_timesteps):
            time_step += 1
            # Running policy_old:
            action = ppo.select_action(state, memory)
            state, reward, done, _ = env.step(action)

            # Saving reward and is_terminals:
            memory.rewards.append(reward)
            memory.is_terminals.append(done)

            # update if it's time
            if time_step % update_timestep == 0:
                ppo.update(memory)
                memory.clear_memory()
                time_step = 0
            running_reward += reward
            if render:
                env.render()
            if done:
                break

        avg_length += t

        # # stop training if avg_reward > solved_reward
        # if running_reward > (log_interval*solved_reward):
        #     print("########## Solved! ##########")
        #     torch.save(ppo.policy.state_dict(), './PPO_continuous_solved_{}.pth'.format(env_name))
        #     break

        # save every 50 episodes
        if i_episode % 50 == 0:
            torch.save(ppo.policy.state_dict(), './PPO_sigma_{}.pth'.format(env_name))

        # logging
        if i_episode % log_interval == 0:
            avg_length = int(avg_length/log_interval)
            running_reward = int((running_reward/log_interval))

            print('Episode {} \t Avg length: {} \t Avg reward: {}'.format(i_episode, avg_length, running_reward))
            running_reward = 0
            avg_length = 0

if __name__ == '__main__':
    main()