import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import math
import numpy as np
import gym
import matplotlib.pyplot as plt


class DQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.linear1 = nn.Linear(input_dim, 16)
        self.linear2 = nn.Linear(16, 32)
        self.linear3 = nn.Linear(32, 32)
        self.linear4 = nn.Linear(32, output_dim)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        return self.linear4(x)
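
# Sketch (not part of the original paste): sanity-check the network with a
# dummy CartPole observation (4 inputs -> 2 Q-values).
_check_net = DQN(4, 2)
print(_check_net(torch.zeros(4)).shape)  # expected: torch.Size([2])
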
final_epsilon = 0.05
initial_epsilon = 1
epsilon_decay = 5000
steps_done = 0


def select_action(state):
    # Epsilon-greedy selection: epsilon decays exponentially with steps_done.
    global steps_done
    sample = random.random()
    eps_threshold = final_epsilon + (initial_epsilon - final_epsilon) * math.exp(-1. * steps_done / epsilon_decay)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            state = torch.Tensor(state)
            q_calc = model(state)
            return int(torch.argmax(q_calc))
    else:
        return random.randint(0, 1)
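
# Sketch (not part of the original paste): print a few points of the epsilon
# schedule above to see how exploration anneals from initial_epsilon to final_epsilon.
for s in (0, 1000, 5000, 20000):
    eps = final_epsilon + (initial_epsilon - final_epsilon) * math.exp(-1. * s / epsilon_decay)
    print("steps_done={}: epsilon={:.3f}".format(s, eps))
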
class ReplayMemory(object):  # Stores [state, action, reward, next_state, done]

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = [[], [], [], [], []]

    def push(self, data):
        """Saves a transition, discarding the oldest one once capacity is reached."""
        if len(self.memory[0]) >= self.capacity:
            for col in self.memory:
                col.pop(0)
        for idx, point in enumerate(data):
            self.memory[idx].append(point)

    def sample(self, batch_size):
        """Returns a random batch as five column lists (one per field)."""
        rows = random.sample(range(0, len(self.memory[0])), batch_size)
        experiences = [[], [], [], [], []]
        for row in rows:
            for col in range(5):
                experiences[col].append(self.memory[col][row])
        return experiences

    def __len__(self):
        return len(self.memory[0])
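
# Sketch (not part of the original paste): quick check of the column-wise
# [state, action, reward, next_state, done] layout used by ReplayMemory.
_buf = ReplayMemory(capacity=4)
_buf.push([[0.0, 0.0, 0.0, 0.0], 0, 1.0, [0.1, 0.0, 0.0, 0.0], False])
_buf.push([[0.1, 0.0, 0.0, 0.0], 1, 1.0, [0.0, 0.0, 0.0, 0.0], True])
print(_buf.sample(2))  # five lists, each of length 2
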
input_dim, output_dim = 4, 2
model = DQN(input_dim, output_dim)
target_net = DQN(input_dim, output_dim)
target_net.load_state_dict(model.state_dict())
target_net.eval()
tau = 1          # copy the online weights into the target net every tau episodes
discount = 0.99  # gamma in the Bellman target

learning_rate = 1e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

memory = ReplayMemory(65536)
BATCH_SIZE = 128

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return 0
    experiences = memory.sample(BATCH_SIZE)
    state_batch = torch.Tensor(experiences[0])
    action_batch = torch.LongTensor(experiences[1]).unsqueeze(1)
    reward_batch = torch.Tensor(experiences[2])
    next_state_batch = torch.Tensor(experiences[3])
    done_batch = experiences[4]

    # Q(s, a) for the actions that were actually taken
    pred_q = model(state_batch).gather(1, action_batch)

    next_state_q_vals = torch.zeros(BATCH_SIZE)
    for idx, next_state in enumerate(next_state_batch):
        if done_batch[idx]:
            next_state_q_vals[idx] = -1  # terminal states are assigned a fixed value of -1
        else:
            # .max in pytorch returns (values, indices); we only want the values
            next_state_q_vals[idx] = target_net(next_state).max(0)[0].detach()

    # TD target: r + discount * max_a' Q_target(s', a')
    better_pred = (reward_batch + discount * next_state_q_vals).unsqueeze(1)

    loss = F.smooth_l1_loss(pred_q, better_pred)
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        param.grad.data.clamp_(-1, 1)  # clip gradients to [-1, 1]
    optimizer.step()
    return loss

env = gym.make('CartPole-v0')
for i_episode in range(300):
    model.train()
    target_net.eval()
    observation = env.reset()
    episode_loss = 0
    if i_episode % tau == 0:
        target_net.load_state_dict(model.state_dict())
    for t in range(200):
        # env.render()
        state = observation
        action = select_action(observation)
        observation, reward, done, _ = env.step(action)

        if done:
            next_state = [0, 0, 0, 0]
        else:
            next_state = observation

        memory.push([state, action, reward, next_state, done])
        episode_loss += float(optimize_model())
        if done:
            print("Episode {} finished after {} timesteps (loss {:.3f})".format(i_episode, t + 1, episode_loss))
            break
env.close()
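
# Sketch (not part of the original paste): greedy evaluation of the trained
# policy, assuming the same old-style gym reset/step API used above, plus a
# simple duration plot so the matplotlib import is put to use.
eval_env = gym.make('CartPole-v0')
model.eval()
durations = []
for _ in range(10):
    obs = eval_env.reset()
    for t in range(200):
        with torch.no_grad():
            action = int(torch.argmax(model(torch.Tensor(obs))))
        obs, _, done, _ = eval_env.step(action)
        if done:
            break
    durations.append(t + 1)
eval_env.close()
print("Mean greedy duration over 10 episodes:", np.mean(durations))

plt.plot(durations)
plt.xlabel("evaluation episode")
plt.ylabel("duration (timesteps)")
plt.show()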