import random

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam


def _build_model(self):
    # Two hidden layers of 24 units; the linear output layer emits one Q-value per action.
    model = Sequential()
    model.add(Dense(24, input_dim=self.state_size, activation="relu"))
    model.add(Dense(24, activation="relu"))
    model.add(Dense(self.action_size, activation="linear"))
    model.compile(optimizer=Adam(lr=self.learning_rate), loss="mse")
    return model


def get_action(self, state):
    # Explore: with probability epsilon, take a random action.
    if np.random.rand() < self.epsilon:
        return random.randrange(self.action_size)

    # Exploit: otherwise use the model to predict the Q-values and select the max.
    q_values = self.model.predict(state)
    return np.argmax(q_values[0])


def replay(self, batch_size):
    # Wait until enough transitions have been collected to fill a minibatch.
    if len(self.memory) < batch_size:
        return

    # Decay the exploration rate.
    self.epsilon *= self.epsilon_decay
    self.epsilon = max(self.epsilon_min, self.epsilon)

    minibatch = random.sample(self.memory, batch_size)

    state_batch, q_values_batch = [], []
    for state, action, reward, next_state, done in minibatch:
        # Get predictions for all actions for the current state.
        q_values = self.model.predict(state)

        # If the episode ended here, the target is just the reward; otherwise add
        # the discounted future reward predicted by the target network.
        if done:
            q_values[0][action] = reward
        else:
            future_reward = max(self.target_model.predict(next_state)[0])
            q_values[0][action] = reward + self.gamma * future_reward

        state_batch.append(state[0])
        q_values_batch.append(q_values[0])

    # Re-fit the model to move it closer to these newly calculated targets.
    self.model.fit(np.array(state_batch), np.array(q_values_batch),
                   batch_size=batch_size, epochs=1, verbose=0)

    self.update_weights()


def update_weights(self):
    # Soft-update (Polyak-average) the target network towards the online network.
    weights = self.model.get_weights()
    target_weights = self.target_model.get_weights()

    for i in range(len(target_weights)):
        target_weights[i] = weights[i] * self.tau + target_weights[i] * (1 - self.tau)

    self.target_model.set_weights(target_weights)
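

# --- Hypothetical surrounding context (not part of the original paste) -----------------
# The four functions above reference attributes (state_size, action_size, memory, epsilon,
# gamma, tau, model, target_model) that are defined elsewhere in the original agent. The
# sketch below is a minimal, assumed reconstruction of that missing context: the DQNAgent
# class name, the hyperparameter values, the remember() helper and the CartPole-v1
# training loop are illustrative guesses, not the author's code.

from collections import deque
import gym


class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)   # replay buffer of (s, a, r, s', done) tuples
        self.gamma = 0.95                  # discount factor for future rewards
        self.epsilon = 1.0                 # initial exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.tau = 0.125                   # soft-update rate for the target network
        self.model = self._build_model()
        self.target_model = self._build_model()

    def remember(self, state, action, reward, next_state, done):
        # Store one transition for later replay.
        self.memory.append((state, action, reward, next_state, done))


# Attach the functions from the paste as methods (equivalently, they could simply be
# defined inside the class body).
DQNAgent._build_model = _build_model
DQNAgent.get_action = get_action
DQNAgent.replay = replay
DQNAgent.update_weights = update_weights


# Example training loop against CartPole-v1, assuming the classic gym step API.
if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    agent = DQNAgent(env.observation_space.shape[0], env.action_space.n)
    batch_size = 32

    for episode in range(500):
        state = env.reset().reshape(1, -1)
        done = False
        while not done:
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            next_state = next_state.reshape(1, -1)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
        agent.replay(batch_size)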