Guest User

Untitled

a guest
Jul 21st, 2018
73
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.38 KB | None | 0 0
  1. #!/usr/bin/env python
  2. import os
  3. import sys
  4. import numpy as np
  5.  
  6. verbose = False
  7.  
  8. def newBoard():
  9. return np.zeros((3,3), np.uint8)
  10.  
  11. # helper functions
  12. def getState(board):
  13. board = board.reshape((9))
  14. ret = 0
  15. for i in range(9):
  16. ret *= 3
  17. ret += board[i]
  18. return ret
  19.  
  20. # value = p(x winning | state,action)
  21. value = np.zeros(((3**9), 9))
  22. value += 0.5
  23.  
  24. if os.path.isfile("values.npy"):
  25. value = np.load("values.npy")
  26.  
  27. def isGameOver(board):
  28. for x in range(3):
  29. if np.all(board[x] == 1) or np.all(board[x] == 2):
  30. return board[x, 0]
  31. if np.all(board[:, x] == 1) or np.all(board[:, x] == 2):
  32. return board[0, x]
  33. d1 = np.array([board[0,0], board[1,1], board[2,2]])
  34. d2 = np.array([board[0,2], board[1,1], board[2,0]])
  35. if np.all(d1 == 1) or np.all(d1 == 2):
  36. return board[1,1]
  37. if np.all(d2 == 1) or np.all(d2 == 2):
  38. return board[1,1]
  39. if np.all(board != 0):
  40. return -1
  41. return 0
  42.  
  43. def sample(a, temperature=1.0):
  44. a = np.array(a)**(1/temperature)
  45. p_sum = a.sum()
  46. sample_temp = a/p_sum
  47. return np.argmax(np.random.multinomial(1, sample_temp, 1))
  48.  
  49. temp = 10.0
  50. def makeMove(board, turn, argmax=False):
  51. state = getState(board)
  52.  
  53. # remove illegal moves
  54. mask = np.zeros(9) + 1.0
  55. board9 = board.reshape(9)
  56. for j in range(9):
  57. if board9[j] != 0:
  58. mask[j] = 0
  59.  
  60. value_norm = np.copy(value[state])
  61. if turn == 2:
  62. value_norm = 1.0 - value_norm
  63. else:
  64. value_norm = value_norm
  65. value_norm *= mask
  66. value_norm /= np.sum(value_norm)
  67. if argmax: #or random.randint(0,4) != 0:
  68. move_choice = np.argmax(value_norm)
  69. else:
  70. move_choice = sample(value_norm, temp)
  71. return move_choice
  72.  
  73. import random
  74. def agentPlay():
  75. board = newBoard()
  76. turn = 1
  77.  
  78. moves = []
  79. while isGameOver(board) == 0:
  80. state = getState(board)
  81. move_choice = makeMove(board, turn)
  82. board.reshape(9)[move_choice] = turn
  83. moves.append((state, move_choice, turn))
  84.  
  85. turn = 2 if turn == 1 else 1
  86.  
  87. gg = isGameOver(board)
  88.  
  89. adj = 0.01
  90. #for move in moves[::-1]:
  91. for move in [random.choice(moves)]:
  92. if gg == -1:
  93. if value[move[0], move[1]] > 0.5:
  94. value[move[0], move[1]] -= adj
  95. elif value[move[0], move[1]] < 0.5:
  96. value[move[0], move[1]] += adj
  97. if gg == 1:
  98. value[move[0], move[1]] += adj
  99. if gg == 2:
  100. value[move[0], move[1]] -= adj
  101. value[move[0], move[1]] = np.clip(value[move[0], move[1]], 0.0001, 0.9999)
  102. #adj *= 0.9
  103. return gg
  104.  
  105.  
  106. def player():
  107. board = newBoard()
  108. while isGameOver(board) == 0:
  109. print board
  110. try:
  111. x,y = raw_input("Move? ").split(",")
  112. x,y = int(x), int(y)
  113. except Exception:
  114. continue
  115.  
  116. if board[y,x] == 0:
  117. board[y,x] = 1
  118. else:
  119. print "illegal move"
  120. continue
  121.  
  122. # computer is o
  123. move_choice = makeMove(board, 2, argmax=True)
  124. board.reshape(9)[move_choice] = 2
  125. print board
  126. print "gg", isGameOver(board)
  127.  
  128. test = [None] * (3**9)
  129. possible_boards = []
  130. def testValue(board, turn):
  131. global test, possible_boards
  132. state = getState(board)
  133.  
  134. # memoize
  135. if test[state] is not None:
  136. return test[state]
  137.  
  138. possible_boards.append(board)
  139.  
  140. if isGameOver(board) != 0:
  141. # no more moves
  142. #print board
  143. gg = isGameOver(board)
  144. test[state] = gg
  145. return test[state]
  146.  
  147. # moves
  148. next_turn = 2 if turn == 1 else 1
  149. possible = []
  150. for move_choice in range(9):
  151. # if legal move
  152. if board.reshape(9)[move_choice] == 0:
  153. tboard = board.copy()
  154. tboard.reshape(9)[move_choice] = turn
  155. possible.append(testValue(tboard, next_turn))
  156.  
  157. if state == 0:
  158. print possible
  159.  
  160. if turn == 1:
  161. if 1 in possible:
  162. test[state] = 1
  163. elif -1 in possible:
  164. test[state] = -1
  165. else:
  166. test[state] = 2
  167.  
  168. if turn == 2:
  169. if 2 in possible:
  170. test[state] = 2
  171. elif -1 in possible:
  172. test[state] = -1
  173. else:
  174. test[state] = 1
  175.  
  176. return test[state]
  177.  
  178. def pval(xx):
  179. ret = []
  180. for x in xx:
  181. ret.append("%.2f" % x)
  182. return ' '.join(ret)
  183.  
  184. from tqdm import tqdm
  185. def runTest():
  186. if test[0] is None:
  187. testValue(newBoard(), 1)
  188. wrong = 0
  189. for nn, board in tqdm(enumerate(possible_boards)):
  190. tboard = np.copy(board)
  191. #print getState(tboard)
  192. turn = 2
  193. if np.sum(tboard == 1) == np.sum(tboard == 2):
  194. turn = 1
  195. while isGameOver(tboard) == 0:
  196. move_choice = makeMove(tboard, turn, argmax=True)
  197. tboard.reshape(9)[move_choice] = turn
  198. turn = 2 if turn == 1 else 1
  199. if test[getState(board)] != isGameOver(tboard):
  200. wrong += 1
  201. if verbose:
  202. print "BAD STATE at", nn
  203. print board
  204. print getState(board)
  205. print value[getState(board)]
  206. print "it should be:", test[getState(board)]
  207. print tboard
  208. print "it is w argmax policy:", isGameOver(tboard)
  209. print "wrong: %d/%d" % (wrong, len(possible_boards))
  210.  
  211. def train():
  212. global temp
  213. try:
  214. games = []
  215. while 1:
  216. games.append(agentPlay())
  217. tg = games[-1000:]
  218. if len(games) % 100 == 0:
  219. state = getState(np.array(([[1,2,0],[0,0,0],[0,0,1]])))
  220. print "running: %d/%d played %d with temp %f" % (np.sum(np.array(tg) == -1), len(tg), len(games), temp), pval(value[state])
  221. temp *= 0.995
  222. if len(games) % 10000 == 0:
  223. runTest()
  224. except KeyboardInterrupt:
  225. print "saving"
  226. np.save("values.npy", value)
  227.  
  228. if len(sys.argv) > 1 and sys.argv[1] == "play":
  229. player()
  230. elif len(sys.argv) > 1 and sys.argv[1] == "test":
  231. verbose = True
  232. runTest()
  233. else:
  234. train()
Add Comment
Please, Sign In to add comment