# Imports needed by this excerpt. The environment (env), the networks
# (policy_net, target_net), the optimizer, the replay memory and helpers such
# as select_action(), get_screen(), plot_durations(), episode_durations and the
# hyperparameters (BATCH_SIZE, GAMMA, TARGET_UPDATE, device) are assumed to be
# defined earlier in the script this was excerpted from.
from itertools import count

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for a
    # detailed explanation). This converts a batch-array of Transitions to a
    # Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (dtype=torch.bool is preferred on newer PyTorch releases).
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)),
                                  device=device, dtype=torch.uint8)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states using the target network;
    # final states keep a value of zero.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values.
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss.
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    # Optimize the model, clipping gradients to [-1, 1].
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()

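For reference, the quantity optimize_model() minimizes is the Huber (smooth L1) loss of the temporal-difference error; with the default threshold of 1 used by F.smooth_l1_loss this is

\[
\delta = Q(s_t, a_t) - \bigl(r_t + \gamma \max_{a'} Q_{\text{target}}(s_{t+1}, a')\bigr),
\qquad
\mathcal{L}(\delta) =
\begin{cases}
\tfrac{1}{2}\,\delta^2 & \text{if } |\delta| \le 1,\\
|\delta| - \tfrac{1}{2} & \text{otherwise,}
\end{cases}
\]

with \(Q_{\text{target}}(s_{t+1}, \cdot) = 0\) whenever \(s_{t+1}\) is a final state, matching the zero-filled next_state_values above.
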
num_episodes = 50
for i_episode in range(num_episodes):
    # Initialize the environment and state.
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    for t in count():
        # Select and perform an action.
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)

        # Observe the new state as a difference of screens.
        last_screen = current_screen
        current_screen = get_screen()
        if not done:
            next_state = current_screen - last_screen
        else:
            next_state = None

        # Store the transition in memory.
        memory.push(state, action, next_state, reward)

        # Move to the next state.
        state = next_state

        # Perform one step of the optimization (on the policy network).
        optimize_model()
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break
    # Update the target network every TARGET_UPDATE episodes.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

print('Complete')
env.render()
env.close()
plt.ioff()
plt.show()
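
The excerpt above relies on a replay buffer and a Transition container defined elsewhere in the original script. Below is a minimal sketch of compatible definitions: the Transition fields and the memory.push()/memory.sample()/len(memory) interface are taken from the code above, while the ring-buffer implementation and the capacity of 10000 are assumptions.

import random
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, *args):
        # Store a Transition, overwriting the oldest entry once full.
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        # Uniform random minibatch.
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)


memory = ReplayMemory(10000)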