Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
def _infer_device(net):
    """Return the device of ``net``'s first parameter, or ``"cuda"`` if unknown."""
    parameters = getattr(net, "parameters", None)
    if parameters is None:
        # Not a module-like object: keep the historical hard-coded default.
        return "cuda"
    try:
        return next(iter(parameters())).device
    except StopIteration:
        # Module with no parameters: same fallback as above.
        return "cuda"


def UCT_search(game_state, num_reads, net, temp, device=None):
    """Run ``num_reads`` tree-search simulations from ``game_state``.

    Each simulation:

    1. selects a leaf via ``root.select_leaf()``;
    2. encodes the leaf's game with ``ed.encode_board`` and evaluates it
       with ``net``;
    3. expands the leaf with the returned priors, unless its game is over
       (``check_winner()`` is True or ``actions()`` is empty);
    4. backs the network's value estimate up the tree with ``leaf.backup``.

    Note that a finished game is still backed up with the network's value
    estimate, not the actual game outcome.

    Args:
        game_state: Game object for the root position; wrapped in a ``UCTNode``.
        num_reads: Number of simulations to run.
        net: Callable taking the encoded board tensor and returning
            ``(child_priors, value_estimate)`` tensors; ``value_estimate``
            must hold a single element (``.item()`` is called on it).
        temp: Not used by this function; kept for caller compatibility.
        device: Torch device the encoded board is moved to before calling
            ``net``. If ``None``, the device of ``net``'s first parameter is
            used, falling back to ``"cuda"`` (the previous hard-coded
            behaviour) when that cannot be determined.

    Returns:
        The root ``UCTNode`` of the searched tree.
    """
    if device is None:
        device = _infer_device(net)
    root = UCTNode(game_state, move=None, parent=DummyNode())
    # Pure inference: avoid building an autograd graph on every read.
    with torch.no_grad():
        for _ in range(num_reads):
            leaf = root.select_leaf()
            # Move the last axis to the front (channels-first for the net);
            # assumes encode_board returns an HxWxC numpy array — TODO confirm.
            encoded_s = ed.encode_board(leaf.game).transpose(2, 0, 1)
            encoded_s = torch.from_numpy(encoded_s).float().to(device)
            child_priors, value_estimate = net(encoded_s)
            child_priors = child_priors.detach().cpu().numpy().reshape(-1)
            value_estimate = value_estimate.item()
            # Explicit comparisons kept as-is to preserve the original
            # semantics for whatever check_winner()/actions() return.
            game_over = (
                leaf.game.check_winner() == True  # noqa: E712
                or leaf.game.actions() == []
            )
            if not game_over:
                leaf.expand(child_priors)  # need to make sure valid moves
            leaf.backup(value_estimate)
    return root
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement