Advertisement
Guest User

Untitled

a guest
Apr 23rd, 2019
115
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.29 KB | None | 0 0
  1. class QNET():
  2. def batch_train(self, batch_size=64):
  3. """Implement Double DQN Algorithm, batch training"""
  4. if self.exp.get_num() < self.exp.get_min():
  5. #The number of experiences is not enough for batch training
  6. return
  7.  
  8. # get a batch of experiences
  9. state, action, reward, next_state, done = self.exp.get_batch(batch_size)
  10. state = state.reshape(batch_size, self.in_units)
  11. next_state = next_state.reshape(batch_size, self.in_units)
  12.  
  13. # get actions by Q-network
  14. qnet_q_values = self.session.run(self.q, feed_dict={self.x:next_state})
  15. qnet_actions = np.argmax(qnet_q_values, axis=1)
  16.  
  17. # calculate estimated Q-values with qnet_actions by using Target-network
  18. tnet_q_values = self.session.run(self.tnet.q, feed_dict={self.tnet.x:next_state})
  19. tnet_q = [np.take(tnet_q_values[i], qnet_actions[i]) for i in range(batch_size)]
  20.  
  21. # Update Q-values of Q-network
  22. qnet_update_q = [r+0.95*q if not d else r for r, q, d in zip(reward, tnet_q, done)]
  23.  
  24. # optimization
  25. indices=[[i,action[i]] for i in range(batch_size)]
  26. feed_dict={self.x:state, self.target:qnet_update_q, self.selected_idx:indices}
  27. self.session.run(self.train_opt, feed_dict)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement