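# This paste is the optimization step and training loop of a DQN (deep
# Q-learning) agent for Gym's CartPole, in the style of the PyTorch DQN
# tutorial. The surrounding setup is not included; a minimal sketch of the
# assumed globals follows (the exact values are assumptions):

import gym
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make('CartPole-v0').unwrapped

BATCH_SIZE = 128     # transitions sampled per optimization step (assumed)
GAMMA = 0.999        # discount factor (assumed)
TARGET_UPDATE = 10   # episodes between target-network syncs (assumed)

episode_durations = []

# policy_net, target_net and optimizer are also assumed to exist: two copies
# of the same Q-network (screen tensor in, one Q value per action out), with
# target_net initialized from policy_net and kept in eval mode, e.g.
#   optimizer = optim.RMSprop(policy_net.parameters())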
import torch.nn.functional as F
import matplotlib.pyplot as plt
from itertools import count


def optimize_model():
    # Skip optimization until the replay memory holds a full batch.
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for a
    # detailed explanation): a list of Transitions becomes one Transition of
    # batched fields.
    batch = Transition(*zip(*transitions))
    # Compute a mask of non-final states and concatenate the batch elements.
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)),
                                  device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Compute Q(s_t, a): the model computes Q(s_t) for every action, then we
    # select the columns of the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # Compute V(s_{t+1}) for all next states; final states keep a value of 0.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # Compute the expected Q values: r + GAMMA * max_a Q(s_{t+1}, a).
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    # Compute the Huber loss between predicted and expected Q values.
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))
    # Optimize the model, clamping gradients to stabilize training.
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
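
# The replay memory used by optimize_model() and the training loop below is
# not included in this paste. A minimal sketch matching its usage here
# (push/sample/len and the four Transition fields); the capacity value is an
# assumption:

import random
from collections import namedtuple

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Cyclic buffer holding the most recent transitions."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        # Save a transition, overwriting the oldest one when full.
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


memory = ReplayMemory(10000)  # capacity of 10000 is an assumption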
num_episodes = 50
for i_episode in range(num_episodes):
    # Initialize the environment and state.
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    # The state is the difference between two consecutive screens, which
    # gives the network motion information.
    state = current_screen - last_screen
    for t in count():
        # Select and perform an action.
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        # Observe the new state.
        last_screen = current_screen
        current_screen = get_screen()
        if not done:
            next_state = current_screen - last_screen
        else:
            next_state = None
        # Store the transition in memory.
        memory.push(state, action, next_state, reward)
        # Move to the next state.
        state = next_state
        # Perform one step of the optimization (on the policy network).
        optimize_model()
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break
    # Update the target network, copying all weights from the policy network.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

print('Complete')
env.render()
env.close()
plt.ioff()
plt.show()
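
# For reference, the remaining helpers assumed above but not shown in this
# paste; in a complete script they would be defined before the training loop.
# All three are sketches: the epsilon schedule constants and the screen
# preprocessing are assumptions, and the original get_screen() also crops
# around the cart.

import math
import numpy as np
import torchvision.transforms as T

EPS_START = 0.9   # initial exploration rate (assumed)
EPS_END = 0.05    # final exploration rate (assumed)
EPS_DECAY = 200   # decay speed of the exploration rate (assumed)
n_actions = env.action_space.n
steps_done = 0


def select_action(state):
    """Epsilon-greedy action selection over policy_net."""
    global steps_done
    # Exploration probability decays exponentially with the step count.
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if random.random() > eps_threshold:
        with torch.no_grad():
            # Exploit: pick the action with the largest predicted Q value.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        # Explore: pick a random action.
        return torch.tensor([[random.randrange(n_actions)]],
                            device=device, dtype=torch.long)


resize = T.Compose([T.ToPILImage(),
                    T.Resize(40),
                    T.ToTensor()])


def get_screen():
    """Return the rendered frame as a (1, C, H, W) float tensor on device."""
    # gym renders HWC uint8; convert to CHW float in [0, 1].
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Downscale and add a batch dimension.
    return resize(screen).unsqueeze(0).to(device)


def plot_durations():
    """Plot episode durations so far (a minimal stand-in for the original)."""
    plt.figure(2)
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(episode_durations)
    plt.pause(0.001)  # give the plot window time to update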