SHARE
TWEET

Untitled

a guest Jul 19th, 2019 87 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import gym
  2. import torch
  3. import random
  4. import collections
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. import torch.optim as optim
  8. import time
  9.  
  10.  
# Hyperparameters
learning_rate = 0.0005   # Adam step size for the Q-network optimizer
gamma         = 0.999    # discount factor applied to bootstrapped future value
buffer_limit  = 50000    # maximum number of transitions kept in the replay buffer
batch_size    = 32       # minibatch size sampled from the buffer per training step
  17.  
  18. class ReplayBuffer():
  19.     def __init__(self):
  20.         self.buffer = collections.deque(maxlen=buffer_limit)
  21.  
  22.     def put(self, transition):
  23.         self.buffer.append(transition)
  24.  
  25.     def sample(self, n):
  26.         mini_batch = random.sample(self.buffer, n)
  27.         s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []
  28.  
  29.         for transition in mini_batch:
  30.             s, a, r, s_prime, done_mask = transition
  31.             s_lst.append(s)
  32.             a_lst.append([a])
  33.             r_lst.append([r])
  34.             s_prime_lst.append(s_prime)
  35.             done_mask_lst.append([done_mask])
  36.  
  37.         return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
  38.                torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \
  39.                torch.tensor(done_mask_lst)
  40.  
  41.     def size(self):
  42.         return len(self.buffer)
  43.  
  44.  
  45. class Qnet(nn.Module):
  46.     def __init__(self):
  47.         super(Qnet, self).__init__()
  48.         self.fc1 = nn.Linear(2, 200)
  49.         self.fc2 = nn.Linear(200, 3)
  50.  
  51.     def forward(self, x):
  52.         x = F.relu(self.fc1(x))
  53.         x = self.fc2(x)
  54.         return x
  55.  
  56.  
  57. q = Qnet()
  58. q.load_state_dict(torch.load("model"))
  59. q_target = Qnet()
  60. q_target.load_state_dict(q.state_dict())
  61.  
  62. env = gym.make('MountainCar-v0')
  63. memory = ReplayBuffer()
  64.  
  65. print_interval = 10
  66. score = 0.0
  67. optimizer = optim.Adam(q.parameters(), lr=learning_rate)
  68. loss = 0
  69.  
  70. for n_epi in range(10000):
  71.     exploration = 0.05 + 0.5/((1+n_epi)**0.5)
  72.  
  73.     s = env.reset()
  74.  
  75.     for t in range(600):
  76.         if random.random() < exploration:
  77.             a = torch.tensor(random.randint(0,2))
  78.         else:
  79.             a = torch.argmax(q.forward(torch.from_numpy(s).float()))
  80.  
  81.         s_prime, r, done, info = env.step(a.item())
  82.         done_mask = 0.0 if done else 1.0
  83.         memory.put((s, a.item(), r/100, s_prime, done_mask))
  84.         s = s_prime
  85.         score += r
  86.  
  87.         if done:
  88.             break
  89.  
  90.     if memory.size() >= 2000:
  91.         s,a,r,s_prime,done_mask = memory.sample(batch_size)
  92.         q_out = q(s)
  93.         q_a = q_out.gather(1,a)
  94.         q_target_out = q_target(s_prime)
  95.         # max_q_prime = q_target_out.gather(1,a)
  96.         max_q_prime = q_target_out.max(1)[0].unsqueeze(1)
  97.         target = r + gamma*max_q_prime*done_mask
  98.  
  99.         loss = F.smooth_l1_loss(q_a, target)
  100.  
  101.         optimizer.zero_grad()
  102.         loss.backward()
  103.         optimizer.step()
  104.  
  105.     if n_epi%print_interval == 0 and n_epi != 0:
  106.         q_target.load_state_dict(q.state_dict())
  107.         print("n_epi = ", n_epi, " loss = ", loss, " score = ", score/print_interval)
  108.         score = 0
  109.  
  110.     if n_epi%1000 == 0 and n_epi != 0:
  111.         s = env.reset()
  112.  
  113.         for step_index in range(200):
  114.             env.render()
  115.             action = torch.argmax(q.forward(torch.from_numpy(s).float()))
  116.             s, reward, done, info = env.step(action.item())
  117.             time.sleep(0.0005)
  118.             if done:
  119.                 break
  120.  
  121.         env.close()
  122.  
  123.  
  124. torch.save(q.state_dict(), "model")
  125. env.close()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top