Advertisement
Not a member of Pastebin yet?
Sign Up —
it unlocks many cool features!
class QNET():
    def batch_train(self, batch_size=64, gamma=0.95):
        """Run one Double DQN optimization step on a sampled minibatch.

        Double DQN decouples action *selection* (online Q-network) from action
        *evaluation* (target network) to reduce the over-estimation bias of
        vanilla DQN.

        Args:
            batch_size: number of experiences to sample from the replay buffer.
            gamma: discount factor applied to the bootstrapped next-state value
                (previously hard-coded as 0.95; default preserves old behavior).

        Returns:
            None. Returns early (no-op) while the replay buffer holds fewer
            experiences than its configured minimum.
        """
        if self.exp.get_num() < self.exp.get_min():
            # Replay buffer not warm enough yet for a meaningful batch.
            return
        # Sample a batch of transitions from the experience replay buffer.
        # NOTE(review): assumes get_batch returns array-likes of length
        # batch_size — confirm against the buffer implementation.
        state, action, reward, next_state, done = self.exp.get_batch(batch_size)
        state = state.reshape(batch_size, self.in_units)
        next_state = next_state.reshape(batch_size, self.in_units)
        # Selection: the online Q-network picks the greedy action for s'.
        qnet_q_values = self.session.run(self.q, feed_dict={self.x: next_state})
        qnet_actions = np.argmax(qnet_q_values, axis=1)
        # Evaluation: the target network scores that chosen action.
        tnet_q_values = self.session.run(self.tnet.q,
                                         feed_dict={self.tnet.x: next_state})
        # Vectorized gather (row i, column qnet_actions[i]) — replaces the
        # per-row np.take loop with a single fancy-indexing call.
        tnet_q = tnet_q_values[np.arange(batch_size), qnet_actions]
        # TD targets: r + gamma * Q_target(s', argmax_a Q_online(s', a));
        # terminal transitions bootstrap nothing and use the bare reward.
        qnet_update_q = [r + gamma * q if not d else r
                         for r, q, d in zip(reward, tnet_q, done)]
        # One optimizer step on the Q-values of the actions actually taken.
        indices = [[i, action[i]] for i in range(batch_size)]
        feed_dict = {self.x: state,
                     self.target: qnet_update_q,
                     self.selected_idx: indices}
        self.session.run(self.train_opt, feed_dict)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement