from keras.models import Sequential, save_model, load_model
from keras.layers import Dense
from collections import deque
import numpy as np
import random

# Deep Q Learning Agent + Maximin
#
# This version provides only one value per input, which
# indicates the score expected in that state.
# This is because the algorithm tries to find the
# best final state among the possible candidate states,
# in contrast to the traditional way of finding the best
# action for a particular state.
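#
# Illustration: the network acts as a state-value estimator,
# V(state) -> expected score, and action selection happens outside the model.
# A caller enumerates the candidate next states and keeps the best one,
# roughly like this (env.get_next_states() is a hypothetical environment
# method, not defined in this file):
#
#     candidates = env.get_next_states()          # {action: resulting_state}
#     best = agent.best_state(candidates.values())
#
# A traditional DQN would instead output one Q-value per action for the
# current state and take the argmax over that output vector.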
class DQNAgent:

    '''Deep Q Learning Agent + Maximin

    Args:
        state_size (int): Size of the input domain
        mem_size (int): Size of the replay buffer
        discount (float): How important future rewards are compared to immediate ones [0,1]
        epsilon (float): Exploration value (probability of returning a random value) at the start
        epsilon_min (float): At what epsilon value the agent stops decrementing it
        epsilon_stop_episode (int): At what episode the agent stops decreasing the exploration variable
        n_neurons (list(int)): List with the number of neurons in each inner layer
        activations (list): List with the activations used in each inner layer, as well as the output
        loss (obj): Loss function
        optimizer (obj): Optimizer used
        replay_start_size (int): Minimum memory size needed before training starts
    '''

    def __init__(self, state_size, mem_size=10000, discount=0.95,
                 epsilon=1, epsilon_min=0, epsilon_stop_episode=500,
                 n_neurons=[32, 32], activations=['relu', 'relu', 'linear'],
                 loss='mse', optimizer='adam', replay_start_size=None):

        assert len(activations) == len(n_neurons) + 1

        self.state_size = state_size
        self.memory = deque(maxlen=mem_size)
        self.discount = discount
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = (self.epsilon - self.epsilon_min) / epsilon_stop_episode
        self.n_neurons = n_neurons
        self.activations = activations
        self.loss = loss
        self.optimizer = optimizer
        if not replay_start_size:
            # Default: wait until the buffer is half full before training
            replay_start_size = mem_size // 2
        self.replay_start_size = replay_start_size
        self.model = self._build_model()


    def _build_model(self):
        '''Builds a Keras deep neural network model'''
        model = Sequential()
        model.add(Dense(self.n_neurons[0], input_dim=self.state_size, activation=self.activations[0]))

        for i in range(1, len(self.n_neurons)):
            model.add(Dense(self.n_neurons[i], activation=self.activations[i]))

        model.add(Dense(1, activation=self.activations[-1]))

        model.compile(loss=self.loss, optimizer=self.optimizer)

        return model


    def add_to_memory(self, current_state, next_state, reward, done):
        '''Adds a play to the replay memory buffer'''
        self.memory.append((current_state, next_state, reward, done))


    def random_value(self):
        '''Random score for a certain action'''
        return random.random()


    def predict_value(self, state):
        '''Predicts the score for a certain state'''
        return self.model.predict(state)[0]


    def act(self, state):
        '''Returns the expected score of a certain state'''
        state = np.reshape(state, [1, self.state_size])
        if random.random() <= self.epsilon:
            return self.random_value()
        else:
            return self.predict_value(state)


    def best_state(self, states):
        '''Returns the best state for a given collection of states'''
        max_value = None
        best_state = None

        if random.random() <= self.epsilon:
            return random.choice(list(states))

        else:
            for state in states:
                value = self.predict_value(np.reshape(state, [1, self.state_size]))
                if max_value is None or value > max_value:
                    max_value = value
                    best_state = state

        return best_state


    def train(self, batch_size=32, epochs=3):
        '''Trains the agent'''
        n = len(self.memory)

        if n >= self.replay_start_size and n >= batch_size:

            batch = random.sample(self.memory, batch_size)

            # Get the expected score for the next states, in batch (better performance)
            next_states = np.array([x[1] for x in batch])
            next_qs = [x[0] for x in self.model.predict(next_states)]

            x = []
            y = []

            # Build xy structure to fit the model in batch (better performance)
            for i, (state, _, reward, done) in enumerate(batch):
                if not done:
                    # Partial Q formula: reward plus discounted value of the chosen next state
                    new_q = reward + self.discount * next_qs[i]
                else:
                    new_q = reward

                x.append(state)
                y.append(new_q)

            # Fit the model to the given values
            self.model.fit(np.array(x), np.array(y), batch_size=batch_size, epochs=epochs, verbose=0)

            # Update the exploration variable
            if self.epsilon > self.epsilon_min:
                self.epsilon -= self.epsilon_decay

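
# ----------------------------------------------------------------------------
# Usage sketch: a minimal, self-contained training loop showing how this
# maximin agent is meant to be driven. DummyEnv and its methods (reset,
# get_next_states, play) are hypothetical stand-ins for a real environment
# (e.g. a Tetris wrapper) that exposes every candidate next state.
if __name__ == '__main__':

    class DummyEnv:
        '''Toy environment: each of 5 actions leads to a random 4-feature state.'''

        def reset(self):
            self.steps = 0
            return np.zeros(4)

        def get_next_states(self):
            # Map each available action to the state it would produce
            self._candidates = {a: np.random.rand(4) for a in range(5)}
            return self._candidates

        def play(self, action):
            self.steps += 1
            reward = float(self._candidates[action].sum())  # toy reward
            done = self.steps >= 20
            return reward, done

    env = DummyEnv()
    # Small buffer and replay_start_size so training kicks in quickly
    agent = DQNAgent(state_size=4, mem_size=1000, replay_start_size=64,
                     epsilon_stop_episode=50)

    for episode in range(10):
        current_state = env.reset()
        done = False
        total_reward = 0

        while not done:
            next_states = env.get_next_states()              # {action: state}
            best = agent.best_state(next_states.values())
            action = next(a for a, s in next_states.items() if s is best)

            reward, done = env.play(action)
            agent.add_to_memory(current_state, best, reward, done)
            current_state = best
            total_reward += reward

        agent.train(batch_size=32, epochs=1)
        print(f'episode {episode}: total reward {total_reward:.2f}')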