Advertisement
Guest User

Untitled

a guest
Jun 27th, 2019
155
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.00 KB | None | 0 0
  1. class DQNAgent:
  2. def __init__(self, state_size, action_size):
  3. self.state_size = state_size
  4. self.action_size = action_size
  5. self.memory = deque(maxlen=2000)
  6. self.gamma = GAMMA
  7. self.epsilon = START_EPSILON
  8. self.epsilon_min = EPSILON_MIN
  9. self.epsilon_decay = EPSILON_DECAY
  10. self.learning_rate = LEARNING_RATE
  11. self.model = self.build_model()
  12. self.target_model = self.build_model()
  13.  
  14. def build_model(self):
  15. model = Sequential()
  16. model.add(Dense(16, input_dim=self.state_size, activation='relu'))
  17. model.add(Dense(32, activation='relu'))
  18. model.add(Dense(self.action_size, activation='linear'))
  19. model.compile(loss='mse',optimizer=Adam(lr=self.learning_rate))
  20. return model
  21.  
  22. def remember(self, state, action, reward, next_state, done):
  23. self.memory.append((state, action, reward, next_state, done))
  24.  
  25. def act(self, state):
  26. if np.random.rand() <= self.epsilon:
  27. return random.randrange(self.action_size)
  28. self.model.set_weights(self.target_model.get_weights())
  29. act_values = self.model.predict(state)
  30. return np.argmax(act_values[0])
  31.  
  32. def replay(self, batch_size):
  33. minibatch = random.sample(self.memory, batch_size)
  34. for state, action, reward, next_state, done in minibatch:
  35. action_t = np.argmax(self.model.predict(state)[0])
  36. target = reward
  37. if not done:
  38. target = (reward + self.gamma *
  39. (self.target_model.predict(next_state)[0][action_t]))
  40. target_f = self.target_model.predict(state)
  41. target_f[0][action] = target
  42. self.target_model.fit(state, target_f, epochs=1, verbose=0)
  43. if self.epsilon > self.epsilon_min:
  44. self.epsilon *= self.epsilon_decay
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement