Advertisement
Guest User

Untitled

a guest
Mar 22nd, 2019
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.66 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import cv2
  4. import torch
  5. from torch import utils
  6. from torchvision import datasets, transforms
  7. from torch.autograd import Variable
  8. import sys
  9. sys.path.append('../')
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. import torch.optim as optim
  13. import sklearn
  14.  
  15.  
  16. %matplotlib inline
  17.  
  18. n = 15
  19. goal = 5
  20.  
  21.  
  22. let_to_dig = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, \
  23. 'h': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15}
  24.  
  25. dig_to_let = {1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', \
  26. 8: 'h', 9: 'j', 10: 'k', 11: 'l', 12: 'm', 13: 'n', 14: 'o', 15: 'p'}
  27.  
  28.  
  29. def format_move(old_move):
  30. y, x = let_to_dig[old_move[0]], int(old_move[1:])
  31. new_move = (x, y)
  32. return new_move
  33.  
  34.  
  35. def format_from_move(move):
  36. res = ''
  37. res += dig_to_let[move[1]]
  38. res += str(move[0])
  39. return res
  40.  
  41.  
  42. def get_moves(log):
  43. if log[-1] == '\n':
  44. log = log[:-1]
  45. log = log.split()
  46. winner = log[0]
  47. log = log[1:]
  48. for i in range(len(log)):
  49. log[i] = format_move(log[i])
  50. return winner, log
  51.  
  52.  
  53. def get_log(moves, winner):
  54. log = winner
  55. for move in moves:
  56. log += ' '
  57. log += format_from_move(move)
  58. log += '\n'
  59. return log
  60.  
  61.  
  62. def moves_to_grid(moves):
  63. grid = np.zeros((n + 1, n + 1), dtype=torch.int8)
  64. player = 1
  65. for move in moves:
  66. grid[move] = player
  67. player *= -1
  68. return grid
  69.  
  70.  
  71. def process_log(log):
  72. winner, moves = get_moves(log)
  73. grid_black = torch.zeros((n + 1, n + 1), dtype=torch.int8)
  74. grid_white = torch.zeros((n + 1, n + 1), dtype=torch.int8)
  75. grid_player = torch.ones((n + 1, n + 1), dtype=torch.int8)
  76. grids = []
  77. labels = []
  78. if winner == 'white':
  79. grid_player *= -1
  80. for idx in range(1, len(moves), 2):
  81. grid_black[moves[idx - 1]] = 1
  82. lbl = (moves[idx][0] - 1) * n + moves[idx][1] - 1
  83. t = torch.stack([grid_black, grid_white, grid_player])
  84. grids.append(t)
  85. labels.append(lbl)
  86. grid_white[moves[idx]] = -1
  87. else:
  88. grid_black[moves[0]] = 1
  89. for idx in range(2, len(moves), 2):
  90. grid_white[moves[idx - 1]] = -1
  91. lbl = (moves[idx][0] - 1) * n + moves[idx][1] - 1
  92. t = torch.stack([grid_black, grid_white, grid_player])
  93. grids.append(t)
  94. labels.append(lbl)
  95. grid_black[moves[idx]] = 1
  96. return grids, labels
  97.  
  98. def get_data_n_labels(logs):
  99. data = []
  100. labels = []
  101. for log in logs:
  102. d, l = process_log(log)
  103. data += d
  104. labels += l
  105. return data, labels
  106.  
  107. def find_shifts(moves, x, y, dir_x, dir_y, winner):
  108. new_logs = []
  109. for y_shift in range(y):
  110. for x_shift in range(x):
  111. new_moves = []
  112. for move in moves:
  113. new_moves.append((move[0] + dir_x * x_shift, move[1] + dir_y * y_shift))
  114. new_logs.append(get_log(new_moves, winner))
  115. return new_logs
  116.  
  117. def flip_log(log):
  118. winner, moves = get_moves(log)
  119. flipped = []
  120. for move in moves:
  121. flipped.append((move[0], n - move[1] + 1))
  122. flipped_log = get_log(flipped, winner)
  123. return flipped_log
  124.  
  125. def rotate_log(log):
  126. winner, moves = get_moves(log)
  127. rotated = []
  128. for move in moves:
  129. rotated.append((move[1], n - move[0] + 1))
  130. rotated_log = get_log(rotated, winner)
  131. return rotated_log
  132.  
  133. def get_rotated_logs(log):
  134. rotated = log
  135. rotated_logs = []
  136. for i in range(3):
  137. rotated = rotate_log(rotated)
  138. rotated_logs.append(rotated)
  139. return rotated_logs
  140.  
  141. def get_shifted_logs(log):
  142. new_logs = []
  143. min_x, min_y, max_x, max_y = n - 1, n - 1, 0, 0
  144. winner, moves = get_moves(log)
  145. for move in moves:
  146. x, y = move
  147. max_x = max(max_x, x)
  148. max_y = max(max_y, y)
  149. min_x = min(min_x, x)
  150. min_y = min(min_y, y)
  151. new_logs += find_shifts(moves, min_x, min_y, -1, -1, winner)
  152. new_logs += find_shifts(moves, min_x, n + 1 - max_y, -1, 1, winner)
  153. new_logs += find_shifts(moves, n + 1 - max_x, n + 1- max_y, 1, 1, winner)
  154. new_logs += find_shifts(moves, n + 1 - max_x, min_y, 1, -1, winner)
  155. return list(set(new_logs))
  156.  
  157. def log_to_data(log):
  158. shifted_logs = get_shifted_logs(log)
  159. rotated = get_rotated_logs(log)
  160. for r in rotated:
  161. shifted_logs += get_shifted_logs(r)
  162. flipped = flip_log(log)
  163. shifted_logs += get_shifted_logs(flipped)
  164. #shifted_logs = rotated + [log] + [flipped]
  165. d, l = get_data_n_labels(shifted_logs)
  166. return d, l
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement