import random
import numpy as np


class Agent():
    def __init__(self, player1=True, epsilon=1, eps_dec=0.000001, eps_min=0.08):
        if player1:
            self.marker = 1
        else:
            self.marker = -1
        self.epsilon = epsilon # probability of taking a random (exploratory) move
        self.epsilon_min = eps_min
        self.epsilon_dec = eps_dec

        self.Q_table = {} # maps (state, action) -> estimated value of taking that action in that state
        self.return_value = {} # maps (state, action) -> total return, for computing a mean return
        self.return_number = {} # maps (state, action) -> number of visits, for computing a mean return
        self.visited = [] # (state, action) pairs visited during the current episode

    # Decays the agent's epsilon (called at the end of each episode)
    def update_epsilon(self):
        self.epsilon -= self.epsilon_dec
        self.epsilon = max(self.epsilon_min, self.epsilon)

    def select_action(self, state):
        t_state = tuple(state)
        legal_moves = [idx for idx, i in enumerate(t_state) if i == 0]
        if (t_state, 0) not in self.Q_table: # lazily add unseen states to the tables
            for move in range(9):
                self.Q_table[(t_state, move)] = 0.05 # small optimistic initial value for every move
                self.return_value[(t_state, move)] = 0.0
                self.return_number[(t_state, move)] = 0

        if random.random() > self.epsilon: # exploit: choose the best-valued legal move
            action_scores = [self.Q_table[(t_state, move)] for move in legal_moves] # value of each legal action
            action_pos = np.argmax(action_scores) # index of the best legal action
            action = legal_moves[action_pos] # map back to a board position
        else: # explore: choose a random legal move
            action = random.choice(legal_moves)
        self.visited.append((t_state, action)) # remember the (state, action) pairs visited this episode
        return action

    def learn(self, winner):
        # self.marker is 1 for the X player and -1 for the O player; winner is
        # 1 (X won), -1 (O won) or 0 (draw), so comparing winner against the
        # marker gives each agent its own reward for the same episode.

        # Episode reward
        if self.marker == winner:
            reward = 1 # reward for a win
        elif self.marker == -winner: # episode was lost by the current agent
            reward = -1 # reward for a loss
        else:
            reward = 0 # reward for a draw

        # Accumulate the discounted terminal reward into the return tables:
        # moves closer to the end of the game are discounted less.
        # (These tables are only bookkeeping; the Q-table update below does not use them.)
        if reward != 0:
            episode_length = len(self.visited)
            for idx, (state, action) in enumerate(self.visited):
                G = reward * 0.99 ** (episode_length - 1 - idx) # discounted return from this move
                self.return_value[(state, action)] += G

        for (state, action) in self.visited:
            self.return_number[(state, action)] += 1 # count the visit

        # Update every state before the terminal one, bootstrapping from the
        # best legal action available in the next visited state
        for idx, (state, action) in enumerate(self.visited[:-1]):
            next_state, _ = self.visited[idx + 1]
            max_Q = max(self.Q_table[(next_state, a)] for a in range(9) if next_state[a] == 0)
            self.Q_table[(state, action)] += 0.1 * (reward + 0.97 * max_Q - self.Q_table[(state, action)])

        # Update the last state directly towards the terminal reward
        (last_state, last_action) = self.visited[-1]
        self.Q_table[(last_state, last_action)] += 0.1 * (reward - self.Q_table[(last_state, last_action)])

        # clear the states visited this episode
        self.visited = []


class Board():
    def __init__(self):
        self.board = [0, 0, 0, 0, 0, 0, 0, 0, 0] # 0 = empty, 1 = X, -1 = O

    def reset(self):
        self.board = [0, 0, 0, 0, 0, 0, 0, 0, 0]

    def placePiece(self, position, player1=True):
        if self.board[position] != 0:
            print("Invalid move made")
            return # refuse to overwrite an occupied square

        if player1:
            self.board[position] = 1
        else:
            self.board[position] = -1

    def __str__(self):
        def to_symbol(x):
            if x == 1: return 'x'
            elif x == -1: return 'o'
            else: return '_'

        rep = list(map(to_symbol, self.board))
        rep.insert(3, '\n') # split the nine cells into three rows
        rep.insert(7, '\n')

        return ''.join(rep)

    def gameDone(self):
        if self.board.count(0) == 0: # board is full
            return True
        elif self.gameWinner() != 0: # somebody has won
            return True
        else:
            return False

    def gameWinner(self): # lines through the centre share a single board[4] != 0 check
        if self.board[4] != 0:
            if self.board[0] == self.board[4] == self.board[8]: # Diagonal
                return self.board[4]
            elif self.board[2] == self.board[4] == self.board[6]: # Diagonal
                return self.board[4]
            elif self.board[3] == self.board[4] == self.board[5]: # Middle row
                return self.board[4]
            elif self.board[1] == self.board[4] == self.board[7]: # Middle column
                return self.board[4]

        if self.board[0] != 0:
            if self.board[0] == self.board[1] == self.board[2]: # Top row
                return self.board[0]
            elif self.board[0] == self.board[3] == self.board[6]: # Left column
                return self.board[0]

        if self.board[8] != 0:
            if self.board[6] == self.board[7] == self.board[8]: # Bottom row
                return self.board[8]
            elif self.board[2] == self.board[5] == self.board[8]: # Right column
                return self.board[8]
        return 0


if __name__ == '__main__':
    player1 = Agent() # player 1 plays X's
    player2 = Agent(player1=False) # player 2 plays O's
    players = [player1, player2]
    wins = [] # per-game result: 1 (player1 win), -1 (player2 win), 0 (draw)

    board = Board()

    n_games = 1000000
    for game in range(n_games):
        board.reset() # start each game from an empty board
        while not board.gameDone():
            board.placePiece(player1.select_action(board.board))
            if board.gameDone():
                break
            board.placePiece(player2.select_action(board.board), player1=False)
        winner = board.gameWinner() # player1 win: 1, draw: 0, player2 win: -1
        for player in players:
            player.learn(winner) # learn from the finished episode
            player.update_epsilon() # decay exploration

        wins.append(winner)

        if game % 10000 == 0 and game != 0:
            print("Played " + str(game) + " games")
            print("X win rate over the last 10000 games: {}%".format((wins[-10000:].count(1) / 10000) * 100))
            print("O win rate over the last 10000 games: {}%".format((wins[-10000:].count(-1) / 10000) * 100))
            print("Draw rate over the last 10000 games: {}%".format((wins[-10000:].count(0) / 10000) * 100))
            print("Player epsilon is {}".format(player1.epsilon))
            print("")

        if game > 990000: # show the final boards once training is nearly done
            print(board)
            print("winner is " + str(wins[-1]))
            print("")
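

# A minimal sketch of how the trained agents above could be pitted against each
# other without exploration, assuming training has already run; the name
# play_greedy_game is illustrative and not part of the training loop.
def play_greedy_game(p1, p2):
    b = Board()
    p1.epsilon = 0.0 # act greedily: no random exploration
    p2.epsilon = 0.0
    while not b.gameDone():
        b.placePiece(p1.select_action(b.board))
        if b.gameDone():
            break
        b.placePiece(p2.select_action(b.board), player1=False)
    p1.visited = [] # discard the visit history so it does not leak into later learning
    p2.visited = []
    print(b)
    return b.gameWinner() # 1: p1 (X) wins, -1: p2 (O) wins, 0: draw

# Example usage after the training loop: print(play_greedy_game(player1, player2))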