SHARE
TWEET

Untitled

a guest Jun 27th, 2019 60 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. class DQNAgent:
  2.         def __init__(self, state_size, action_size):
  3.             self.state_size = state_size
  4.             self.action_size = action_size
  5.             self.memory = deque(maxlen=2000)
  6.             self.gamma = GAMMA  
  7.             self.epsilon = START_EPSILON  
  8.             self.epsilon_min = EPSILON_MIN
  9.             self.epsilon_decay = EPSILON_DECAY
  10.             self.learning_rate = LEARNING_RATE
  11.             self.model = self.build_model()
  12.             self.target_model = self.build_model()
  13.  
  14.         def build_model(self):
  15.             model = Sequential()
  16.             model.add(Dense(16, input_dim=self.state_size, activation='relu'))
  17.             model.add(Dense(32, activation='relu'))
  18.             model.add(Dense(self.action_size, activation='linear'))
  19.             model.compile(loss='mse',optimizer=Adam(lr=self.learning_rate))
  20.             return model
  21.  
  22.         def remember(self, state, action, reward, next_state, done):
  23.             self.memory.append((state, action, reward, next_state, done))
  24.  
  25.         def act(self, state):
  26.             if np.random.rand() <= self.epsilon:
  27.                 return random.randrange(self.action_size)
  28.             self.model.set_weights(self.target_model.get_weights())
  29.             act_values = self.model.predict(state)
  30.             return np.argmax(act_values[0])  
  31.  
  32.         def replay(self, batch_size):
  33.             minibatch = random.sample(self.memory, batch_size)
  34.             for state, action, reward, next_state, done in minibatch:
  35.                 action_t = np.argmax(self.model.predict(state)[0])
  36.                 target = reward
  37.                 if not done:
  38.                     target = (reward + self.gamma *
  39.                               (self.target_model.predict(next_state)[0][action_t]))
  40.                 target_f = self.target_model.predict(state)
  41.                 target_f[0][action] = target
  42.                 self.target_model.fit(state, target_f, epochs=1, verbose=0)
  43.             if self.epsilon > self.epsilon_min:
  44.                 self.epsilon *= self.epsilon_decay
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top