--- Runs Counterfactual Regret Minimization (CFR) to approximately
-- solve a game represented by a complete game tree.
--
-- As this class does full solving from the root of the game with no
-- limited lookahead, it is not used in continual re-solving. It is provided
-- simply for convenience.
-- @classmod tree_cfr

local arguments = require 'Settings.arguments'
local constants = require 'Settings.constants'
local game_settings = require 'Settings.game_settings'
local card_tools = require 'Game.card_tools'
require 'TerminalEquity.terminal_equity'

local TreeCFR = torch.class('TreeCFR')

--- Constructor
function TreeCFR:__init()
  --for ease of implementation, we use a small epsilon rather than zero when working with regrets
  self.regret_epsilon = 1/1000000000
  self._cached_terminal_equities = {}
end

--- Gets an evaluator for player equities at a terminal node.
--
-- Caches the result to minimize creation of @{terminal_equity|TerminalEquity}
-- objects.
-- @param node the terminal node to evaluate
-- @return a @{terminal_equity|TerminalEquity} evaluator for the node
-- @local
function TreeCFR:_get_terminal_equity(node)
  local cached = self._cached_terminal_equities[node.board]
  if cached == nil then
    cached = TerminalEquity()
    cached:set_board(node.board)
    self._cached_terminal_equities[node.board] = cached
  end

  return cached
end

--- Recursively walks the tree, applying the CFR algorithm.
-- @param node the current node in the tree
-- @param iter the current iteration number
-- @local
function TreeCFR:cfrs_iter_dfs( node, iter )

  assert(node.current_player == constants.players.P1 or node.current_player == constants.players.P2 or node.current_player == constants.players.chance)

  local opponent_index = 3 - node.current_player

  --dimensions in tensor
  local action_dimension = 1
  local card_dimension = 2

  --compute values using terminal_equity in terminal nodes
  if(node.terminal) then

    local terminal_equity = self:_get_terminal_equity(node)

    local values = node.ranges_absolute:clone():fill(0)

    if(node.type == constants.node_types.terminal_fold) then
      terminal_equity:tree_node_fold_value(node.ranges_absolute, values, opponent_index)
    else
      terminal_equity:tree_node_call_value(node.ranges_absolute, values)
    end
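
    --terminal equity gives each hand's expected winnings as a fraction of
    --the pot (given both players' ranges), so multiplying by the pot size
    --converts the values to chips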
    --multiply by the pot
    values = values * node.pot
    node.cf_values = values:viewAs(node.ranges_absolute)
  else

    local actions_count = #node.children
    local current_strategy = nil

    if node.current_player == constants.players.chance then
      current_strategy = node.strategy
    else
      --we have to compute the current strategy at the beginning of each iteration

      --initialize regrets in the first iteration
      node.regrets = node.regrets or arguments.Tensor(actions_count, game_settings.card_count):fill(self.regret_epsilon) --[[actions_count x card_count]]
      node.positive_regrets = node.positive_regrets or arguments.Tensor(actions_count, game_settings.card_count):fill(self.regret_epsilon)

      --compute positive regrets so that we can compute the current strategy from them
      node.positive_regrets:copy(node.regrets)
      node.positive_regrets[torch.le(node.positive_regrets, self.regret_epsilon)] = self.regret_epsilon

      --compute the current strategy via regret matching
      local regrets_sum = node.positive_regrets:sum(action_dimension)
      current_strategy = node.positive_regrets:clone()
      current_strategy:cdiv(regrets_sum:expandAs(current_strategy))
    end
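
    --Example (regret matching): for a hand with cumulative regrets {3, -1, 1}
    --over three actions, the positive regrets are {3, eps, 1}, their sum is
    --~4, and the current strategy becomes {0.75, ~0, 0.25}; each action is
    --played in proportion to the regret for not having played it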

    --current cfv [[actions, players, ranges]]
    local cf_values_allactions = arguments.Tensor(actions_count, constants.players_count, game_settings.card_count):fill(0)

    local children_ranges_absolute = {}

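    --the acting player's range is scaled by the strategy (the reach
    --probability of each action), while the other player's range passes to
    --every child unchanged; at chance nodes both players' ranges are scaled
    --by the chance probabilities stored in node.strategy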
    if node.current_player == constants.players.chance then
      local ranges_mul_matrix = node.ranges_absolute[1]:repeatTensor(actions_count, 1)
      children_ranges_absolute[1] = torch.cmul(current_strategy, ranges_mul_matrix)

      ranges_mul_matrix = node.ranges_absolute[2]:repeatTensor(actions_count, 1)
      children_ranges_absolute[2] = torch.cmul(current_strategy, ranges_mul_matrix)
    else
      local ranges_mul_matrix = node.ranges_absolute[node.current_player]:repeatTensor(actions_count, 1)
      children_ranges_absolute[node.current_player] = torch.cmul(current_strategy, ranges_mul_matrix)

      children_ranges_absolute[opponent_index] = node.ranges_absolute[opponent_index]:repeatTensor(actions_count, 1):clone()
    end

    for i = 1,#node.children do
      local child_node = node.children[i]
      --set new absolute ranges (after the action) for the child
      child_node.ranges_absolute = node.ranges_absolute:clone()

      child_node.ranges_absolute[1]:copy(children_ranges_absolute[1][{i}])
      child_node.ranges_absolute[2]:copy(children_ranges_absolute[2][{i}])
      self:cfrs_iter_dfs(child_node, iter)
      cf_values_allactions[i] = child_node.cf_values
    end

    --counterfactual values at this node, one per player per private hand
    node.cf_values = arguments.Tensor(constants.players_count, game_settings.card_count):fill(0)

    if node.current_player ~= constants.players.chance then
      local strategy_mul_matrix = current_strategy:viewAs(arguments.Tensor(actions_count, game_settings.card_count))

      --the acting player's value is the strategy-weighted average of the
      --values of his actions
      node.cf_values[node.current_player] = torch.cmul(strategy_mul_matrix, cf_values_allactions[{{}, node.current_player, {}}]):sum(1)

      --for the opponent a plain sum is correct: the acting player's strategy
      --was already folded into the children's ranges, so each child value is
      --already weighted by the probability of the corresponding action
      --(taking the max instead would compute a best response, not the value
      --of the current strategy profile)
      node.cf_values[opponent_index] = (cf_values_allactions[{{}, opponent_index, {}}]):sum(1)
    else
      node.cf_values[1] = (cf_values_allactions[{{}, 1, {}}]):sum(1)
      node.cf_values[2] = (cf_values_allactions[{{}, 2, {}}]):sum(1)
    end

    if node.current_player ~= constants.players.chance then
      --computing regrets

      --reshape squeezes out the player dimension, leaving an
      --actions_count x card_count tensor
      local current_regrets = cf_values_allactions[{{}, {node.current_player}, {}}]:reshape(actions_count, game_settings.card_count):clone()

      --the instantaneous regret of an action is the difference between that
      --pure action's counterfactual value and the counterfactual value of the
      --current mixed strategy; node.cf_values is viewed as a single row and
      --expanded actions_count times so it can be subtracted from every
      --action's values at once
      current_regrets:csub(node.cf_values[node.current_player]:view(1, game_settings.card_count):expandAs(current_regrets))

      self:update_regrets(node, current_regrets)

      --accumulating average strategy
      self:update_average_strategy(node, current_strategy, iter)
    end
  end
end
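
--Example (instantaneous regrets): if a hand's per-action values are {5, 1}
--and the current strategy is {0.75, 0.25}, the node's value is
--0.75*5 + 0.25*1 = 4, so this iteration's regrets are {5-4, 1-4} = {1, -3}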

--- Update a node's total regrets with the current iteration regrets.
-- @param node the node to update
-- @param current_regrets the regrets from the current iteration of CFR
-- @local
function TreeCFR:update_regrets(node, current_regrets)
  node.regrets:add(current_regrets)
  node.regrets[torch.le(node.regrets, self.regret_epsilon)] = self.regret_epsilon
end
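
--Note: flooring the cumulative regrets at regret_epsilon clips them at
--(effectively) zero, similar in spirit to the regret clipping used by CFR+,
--and guarantees the regret sum used in regret matching is never zero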

--- Update a node's average strategy with the current iteration strategy.
-- @param node the node to update
-- @param current_strategy the CFR strategy for the current iteration
-- @param iter the iteration number of the current CFR iteration
-- @local
function TreeCFR:update_average_strategy(node, current_strategy, iter)
  if iter > arguments.cfr_skip_iters then
    node.strategy = node.strategy or arguments.Tensor(#node.children, game_settings.card_count):fill(0)

    --each iteration's strategy is weighted by the acting player's reach
    --probability for each private hand, floored at epsilon so that hands
    --with zero reach still get a well-defined weight
    local iter_weight_contribution = node.ranges_absolute[node.current_player]:clone()
    iter_weight_contribution[torch.le(iter_weight_contribution, 0)] = self.regret_epsilon

    --iter_weight_sum accumulates the total weight over all iterations so far,
    --so iter_weight is this iteration's share of the total
    node.iter_weight_sum = node.iter_weight_sum or arguments.Tensor(game_settings.card_count):fill(0)
    node.iter_weight_sum:add(iter_weight_contribution)
    local iter_weight = torch.cdiv(iter_weight_contribution, node.iter_weight_sum)

    --incremental weighted running average:
    -- new_strategy = (1 - w) .* old_strategy + w .* current_strategy
    --             == old_strategy + w .* (current_strategy - old_strategy)
    --where w is this iteration's share of the total weight
    local expanded_weight = iter_weight:view(1, game_settings.card_count):expandAs(node.strategy)
    local old_strategy_scale = expanded_weight * (-1) + 1 --same as 1 - expanded_weight
    node.strategy:cmul(old_strategy_scale)
    local strategy_addition = current_strategy:cmul(expanded_weight)
    node.strategy:add(strategy_addition)
  end
end
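
--Example (running average): with equal per-iteration weights the update
--weight w is 1, 1/2, 1/3, ... across iterations, so node.strategy ends up as
--the plain mean of the current strategies; unequal reach probabilities yield
--a reach-weighted mean instead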

--- Run CFR to solve the given game tree.
-- @param root the root node of the tree to solve.
-- @param starting_ranges probability vectors over player private hands
-- at the root node
-- @param[opt] iter_count the number of iterations to run CFR for
-- (default @{arguments.cfr_iters})
function TreeCFR:run_cfr( root, starting_ranges, iter_count )

  assert(starting_ranges)
  local iter_count = iter_count or arguments.cfr_iters

  root.ranges_absolute = starting_ranges

  for iter = 1,iter_count do
    self:cfrs_iter_dfs(root, iter)
  end
end
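
--[[ Example usage (a minimal sketch, not part of the class): assumes a
complete game tree 'tree' built elsewhere (e.g. by the tree builder in
Tree/tree_builder.lua) and uniform starting ranges from Game.card_tools:

  local tree_cfr = TreeCFR()
  local starting_ranges = arguments.Tensor(constants.players_count, game_settings.card_count)
  starting_ranges[1]:copy(card_tools:get_uniform_range(tree.board))
  starting_ranges[2]:copy(card_tools:get_uniform_range(tree.board))
  tree_cfr:run_cfr(tree, starting_ranges)
  --after solving, each node's average strategy is in node.strategy
]]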