Advertisement
NLinker

Monte Carlo Tree Search

Dec 13th, 2017
279
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.25 KB | None | 0 0
  1. from __future__ import division
  2.  
  3. import time
  4. from math import log, sqrt
  5. from random import choice
  6.  
  7.  
  8. class Stat(object):
  9.     __slots__ = ('value', 'visits')
  10.  
  11.     def __init__(self, value=0, visits=0):
  12.         self.value = value
  13.         self.visits = visits
  14.  
  15.  
class UCT(object):
    """Monte Carlo Tree Search player using UCB1 action selection (UCT).

    `board` is a game-rules adapter (pack_state, legal_actions, next_state,
    current_player, is_ended, ...).  Per-node statistics live in
    `self.stats`, keyed by (player, packed_state) where `player` is the one
    who moved INTO that state.  Subclasses supply `end_values` (presumably a
    callable mapping a finished history to {player: payoff} -- confirm
    against the board's win_values/points_values), plus `action_template`
    and `calculate_action_values` for reporting.
    """

    def __init__(self, board, **kwargs):
        self.board = board
        # Packed states of the real game as it is actually played.
        self.history = []
        # (player, packed_state) -> Stat; cleared at the start of each
        # get_action() call and shared by all simulations within it.
        self.stats = {}

        # Deepest node expanded during the last search (diagnostic).
        self.max_depth = 0
        # Diagnostics gathered by the last get_action() call.
        self.data = {}

        # Wall-clock search budget in seconds.
        self.calculation_time = float(kwargs.get('time', 30))
        # Upper bound on the number of moves in a single playout.
        self.max_actions = int(kwargs.get('max_actions', 1000))

        # Exploration constant, increase for more exploratory actions,
        # decrease to prefer actions with known higher win rates.
        self.C = float(kwargs.get('C', 1.4))

    def update(self, state):
        # Record a move made in the real game.
        self.history.append(self.board.pack_state(state))

    def display(self, state, action):
        # Pack the raw state/action, then let the board render them.
        state = self.board.pack_state(state)
        action = self.board.pack_action(action)
        return self.board.display(state, action)

    def winner_message(self, winners):
        # Delegate result formatting entirely to the board.
        return self.board.winner_message(winners)

    def get_action(self):
        # Causes the AI to calculate the best action from the
        # current game state and return it.
        # Returns None when there are no legal actions at all.

        self.max_depth = 0
        self.data = {}
        self.stats.clear()

        state = self.history[-1]
        player = self.board.current_player(state)
        legal = self.board.legal_actions(self.history[:])

        # Bail out early if there is no real choice to be made.
        if not legal:
            return
        if len(legal) == 1:
            return self.board.unpack_action(legal[0])

        # Run as many playouts as the time budget allows.
        games = 0
        begin = time.time()
        while time.time() - begin < self.calculation_time:
            self.run_simulation()
            games += 1

        # Display the number of calls of `run_simulation` and the
        # time elapsed.
        self.data.update(games=games, max_depth=self.max_depth,
                         time=str(time.time() - begin))
        print self.data['games'], self.data['time']
        print "Maximum depth searched:", self.max_depth

        # Store and display the stats for each possible action.
        self.data['actions'] = self.calculate_action_values(state, player, legal)
        for m in self.data['actions']:
            print self.action_template.format(**m)

        # Pick the action with the highest average value; subclasses
        # return the list sorted best-first.
        return self.board.unpack_action(self.data['actions'][0]['action'])

    def run_simulation(self):
        # Plays out a "random" game from the current position,
        # then updates the statistics tables with the result.

        # A bit of an optimization here, so we have a local
        # variable lookup instead of an attribute access each loop.
        stats = self.stats

        visited_states = set()
        history_copy = self.history[:]
        state = history_copy[-1]
        player = self.board.current_player(state)

        # At most one NEW node is added to `stats` per simulation;
        # `expand` is flipped off after the first addition.
        expand = True
        for t in xrange(1, self.max_actions + 1):
            legal = self.board.legal_actions(history_copy)
            actions_states = [(p, self.board.next_state(state, p)) for p in legal]

            if all((player, S) in stats for p, S in actions_states):
                # If we have stats on all of the legal actions here, use UCB1.
                log_total = log(
                    sum(stats[(player, S)].visits for p, S in actions_states) or 1)
                # Maximize average value plus the exploration bonus;
                # the `or 1` guards the divisions against zero visits.
                value, action, state = max(
                    ((stats[(player, S)].value / (stats[(player, S)].visits or 1)) +
                     self.C * sqrt(log_total / (stats[(player, S)].visits or 1)), p, S)
                    for p, S in actions_states
                )
            else:
                # Otherwise, just make an arbitrary decision.
                action, state = choice(actions_states)

            history_copy.append(state)

            # `player` here and below refers to the player
            # who moved into that particular state.
            if expand and (player, state) not in stats:
                expand = False
                stats[(player, state)] = Stat()
                if t > self.max_depth:
                    self.max_depth = t

            visited_states.add((player, state))

            player = self.board.current_player(state)
            if self.board.is_ended(history_copy):
                break

        # Back-propagation: credit every tracked node we visited with the
        # playout's final value for the player who moved into it.
        end_values = self.end_values(history_copy)
        for player, state in visited_states:
            if (player, state) not in stats:
                # Node was never expanded into the tree; skip it.
                continue
            S = stats[(player, state)]
            S.visits += 1
            S.value += end_values[player]
  138.  
  139. class UCTWins(UCT):
  140.     action_template = "{action}: {percent:.2f}% ({wins} / {plays})"
  141.  
  142.     def __init__(self, board, **kwargs):
  143.         super(UCTWins, self).__init__(board, **kwargs)
  144.         self.end_values = board.win_values
  145.  
  146.     def calculate_action_values(self, state, player, legal):
  147.         actions_states = ((p, self.board.next_state(state, p)) for p in legal)
  148.         return sorted(
  149.             ({'action': p,
  150.               'percent': 100 * self.stats[(player, S)].value / self.stats[(player, S)].visits,
  151.               'wins': self.stats[(player, S)].value,
  152.               'plays': self.stats[(player, S)].visits}
  153.              for p, S in actions_states),
  154.             key=lambda x: (x['percent'], x['plays']),
  155.             reverse=True
  156.         )
  157.  
  158.  
  159. class UCTValues(UCT):
  160.     action_template = "{action}: {average:.1f} ({sum} / {plays})"
  161.  
  162.     def __init__(self, board, **kwargs):
  163.         super(UCTValues, self).__init__(board, **kwargs)
  164.         self.end_values = board.points_values
  165.  
  166.     def calculate_action_values(self, state, player, legal):
  167.         actions_states = ((p, self.board.next_state(state, p)) for p in legal)
  168.         return sorted(
  169.             ({'action': p,
  170.               'average': self.stats[(player, S)].value / self.stats[(player, S)].visits,
  171.               'sum': self.stats[(player, S)].value,
  172.               'plays': self.stats[(player, S)].visits}
  173.              for p, S in actions_states),
  174.             key=lambda x: (x['average'], x['plays']),
  175.             reverse=True
  176.         )
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement