Advertisement
Guest User

Untitled

a guest
Dec 11th, 2019
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.31 KB | None | 0 0
  1. from buffer import Buffer, build_dqn, load_game
  2. import numpy as np
  3.  
  4.  
  5. class Agent:
  6. def __init__(self, alpha, gamma, n_actions, epsilon, batch_size, input_shape, epsilon_dec, epsilon_end,
  7. memory_size,file_name,activations):
  8. print('cacat')
  9.  
  10. self.alpha = alpha
  11. self.gamma = gamma
  12. self.epsilon = epsilon
  13. self.epsilon_dec = epsilon_dec
  14. self.epsilon_end = epsilon_end
  15. self.action_space = [i for i in range(n_actions)]
  16. self.batch_size = batch_size
  17. self.experiences = Buffer(memory_size, input_shape, n_actions)
  18.  
  19. self.q_eval = build_dqn(alpha, n_actions, input_shape, 256, 256,activations)
  20. self.file=file_name
  21.  
  22. def add_experience(self, state, action, reward, new_state):
  23. self.experiences.store_transition(state, new_state, reward, action)
  24.  
  25. def choose_action(self, state):
  26. state = state[np.newaxis, :]
  27. rand = np.random.random()
  28. if rand < self.epsilon:
  29. action = np.random.choice(self.action_space)
  30. else:
  31. # print(state)
  32. actions = self.q_eval.predict(state)
  33. action = np.argmax(actions)
  34.  
  35. return action
  36.  
  37. def learn(self):
  38. if self.experiences.contor > self.batch_size:
  39. state, action, reward, new_state = self.experiences.get_batch(self.batch_size)
  40.  
  41. # action_values = np.array(self.action_space, dtype=np.int8)
  42. # print(action_values, self.action_space)
  43. action_indices = np.dot(action, self.action_space)
  44. # print(action, action_indices)
  45.  
  46. target = self.q_eval.predict(state)
  47. # print(new_state)
  48. new_values = self.q_eval.predict(new_state)
  49.  
  50. batch_index = np.arange(self.batch_size, dtype=np.int32)
  51. # print(batch_index)
  52.  
  53. target[batch_index, action_indices] = reward + self.gamma * np.max(new_values, axis=1)
  54. # print(target[batch_index, action_indices])
  55.  
  56. self.q_eval.fit(state, target, verbose=0)
  57. self.epsilon = self.epsilon * self.epsilon_dec if self.epsilon > self.epsilon_end else self.epsilon_end
  58.  
  59. def save_game(self):
  60. self.q_eval.save(str(self.file))
  61.  
  62. def load_game(self):
  63. self.q_eval= load_game(self.file)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement