• API
• FAQ
• Tools
• Archive
daily pastebin goal
9%
SHARE
TWEET

# Untitled

a guest Mar 22nd, 2019 69 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. import numpy as np
2. import matplotlib.pyplot as plt
3. import cv2
4. import torch
5. from torch import utils
6. from torchvision import datasets, transforms
8. import sys
9. sys.path.append('../')
10. import torch.nn as nn
11. import torch.nn.functional as F
12. import torch.optim as optim
13. import sklearn
14.
15.
16. %matplotlib inline
17.
# Board side length; coordinates in the logs run from 1 to n.
n = 15
# Not referenced in this chunk; presumably the number of stones in a row needed to win.
goal = 5

# Column letters used by the game logs. The letter 'i' is deliberately absent,
# so 'j' maps to 9.
let_to_dig = {letter: idx for idx, letter in enumerate('abcdefghjklmnop', start=1)}

# Exact inverse of let_to_dig (1 -> 'a', ..., 15 -> 'p').
dig_to_let = {idx: letter for letter, idx in let_to_dig.items()}
27.
28.
def format_move(old_move):
    """Convert a log token such as ``'h8'`` into an ``(x, y)`` tuple.

    The digits after the first character become ``x`` (as an int); the
    leading letter, looked up in ``let_to_dig``, becomes ``y``.
    """
    letter, digits = old_move[0], old_move[1:]
    column = let_to_dig[letter]
    return int(digits), column
33.
34.
def format_from_move(move):
    """Inverse of ``format_move``: turn an ``(x, y)`` tuple back into a token.

    ``y`` is mapped to its column letter via ``dig_to_let`` and ``x`` is
    appended as decimal digits, e.g. ``(8, 8) -> 'h8'``.
    """
    letter = dig_to_let[move[1]]
    return letter + str(move[0])
40.
41.
def get_moves(log):
    """Parse one game-log line into ``(winner, moves)``.

    The line is whitespace-separated: the first token is the winner and each
    remaining token is a move that is converted with ``format_move``.
    Surrounding whitespace, including a trailing newline, is ignored.

    Returns:
        ``(winner, moves)`` where ``winner`` is the first token (a string)
        and ``moves`` is a list of ``(x, y)`` tuples.
    """
    tokens = log.split()
    winner = tokens[0]
    return winner, [format_move(token) for token in tokens[1:]]
51.
52.
def get_log(moves, winner):
    """Serialise a game back into a log line (inverse of ``get_moves``).

    Produces ``"<winner> <move> <move> ...\\n"`` with each move rendered by
    ``format_from_move``; with no moves the result is ``winner + '\\n'``.
    """
    tokens = [winner]
    tokens.extend(format_from_move(move) for move in moves)
    return ' '.join(tokens) + '\n'
60.
61.
def moves_to_grid(moves, size=None):
    """Render a move sequence onto a single signed board.

    Moves alternate sign: the first move is written as ``1``, the second as
    ``-1``, the third as ``1`` again, and so on. The board is
    ``(size + 1) x (size + 1)`` so that 1-based coordinates can index it
    directly; row 0 and column 0 stay empty.

    Args:
        moves: iterable of ``(x, y)`` coordinate pairs, as produced by
            ``format_move``.
        size: board side length; defaults to the module-level ``n``.

    Returns:
        A NumPy array of dtype ``int8``.
    """
    if size is None:
        size = n
    # np.zeros needs a NumPy dtype; the previous torch.int8 raised TypeError.
    grid = np.zeros((size + 1, size + 1), dtype=np.int8)
    player = 1
    for move in moves:
        # tuple() ensures element indexing even if a move arrives as a list
        # (a list index would select whole rows instead).
        grid[tuple(move)] = player
        player = -player
    return grid
69.
70.
def process_log(log):
    """Turn one game log into (position, next-move) training pairs for the winner.

    Positions are encoded as three ``int8`` planes of shape ``(n + 1, n + 1)``,
    stacked into a ``(3, n + 1, n + 1)`` tensor:

    * plane 0: ``1`` wherever a black stone (even-indexed move) has been played,
    * plane 1: ``-1`` wherever a white stone (odd-indexed move) has been played,
    * plane 2: all ``-1`` if the winner is ``'white'``, otherwise all ``1``.

    A sample is taken immediately before each of the winner's moves, and its
    label is that move flattened as ``(x - 1) * n + (y - 1)``. When the winner
    is not ``'white'``, black's first move is placed on the board but never
    used as a label.

    Returns:
        ``(grids, labels)`` -- parallel lists of stacked tensors and ints.
    """
    winner, moves = get_moves(log)
    shape = (n + 1, n + 1)
    black = torch.zeros(shape, dtype=torch.int8)
    white = torch.zeros(shape, dtype=torch.int8)
    side = torch.ones(shape, dtype=torch.int8)

    if winner == 'white':
        side *= -1
        learner, learner_value = white, -1
        opponent, opponent_value = black, 1
        first = 1
    else:
        learner, learner_value = black, 1
        opponent, opponent_value = white, -1
        # Black's opening move is on the board but is not a training target.
        black[moves[0]] = 1
        first = 2

    samples = []
    targets = []
    for idx in range(first, len(moves), 2):
        opponent[moves[idx - 1]] = opponent_value
        target = moves[idx]
        # torch.stack copies, so later in-place updates don't alter the sample.
        samples.append(torch.stack([black, white, side]))
        targets.append((target[0] - 1) * n + target[1] - 1)
        learner[target] = learner_value
    return samples, targets
97.
def get_data_n_labels(logs):
    """Convert a collection of game logs into training samples and labels.

    Each log is expanded with ``process_log`` and the per-log results are
    concatenated in input order.

    Args:
        logs: iterable of game-log lines.

    Returns:
        ``(data, labels)`` -- parallel lists of board tensors and integer
        move labels.
    """
    data = []
    labels = []
    for log in logs:
        grids, move_labels = process_log(log)
        data.extend(grids)
        labels.extend(move_labels)
    return data, labels
106.
def find_shifts(moves, x, y, dir_x, dir_y, winner):
    """Return logs for a grid of translations of ``moves`` in one direction.

    For every ``dy`` in ``range(y)`` (outer) and ``dx`` in ``range(x)``
    (inner), each move ``(a, b)`` is moved to
    ``(a + dir_x * dx, b + dir_y * dy)`` and the result is serialised with
    ``get_log``. The zero shift is always included when ``x`` and ``y`` are
    positive.
    """
    return [
        get_log(
            [(m[0] + dir_x * dx, m[1] + dir_y * dy) for m in moves],
            winner,
        )
        for dy in range(y)
        for dx in range(x)
    ]
116.
def flip_log(log):
    """Mirror a game log: every move ``(x, y)`` becomes ``(x, n + 1 - y)``.

    The winner token is kept unchanged.
    """
    winner, moves = get_moves(log)
    mirrored = [(m[0], n + 1 - m[1]) for m in moves]
    return get_log(mirrored, winner)
124.
def rotate_log(log):
    """Rotate a game log a quarter turn: every move ``(x, y)`` becomes ``(y, n + 1 - x)``.

    The winner token is kept unchanged; applying this four times returns the
    original coordinates.
    """
    winner, moves = get_moves(log)
    return get_log([(m[1], n + 1 - m[0]) for m in moves], winner)
132.
def get_rotated_logs(log):
    """Return the three non-identity quarter-turn rotations of ``log``.

    Element ``k`` of the result is ``log`` passed through ``rotate_log``
    ``k + 1`` times; the unrotated log itself is not included.
    """
    rotations = []
    current = log
    while len(rotations) < 3:
        current = rotate_log(current)
        rotations.append(current)
    return rotations
140.
def get_shifted_logs(log):
    """Return every translation of a game log that keeps all moves on the board.

    The game's bounding box is computed, then the game is slid by every
    offset ``(dx, dy)`` for which each coordinate stays within ``1..n``
    (assuming the input coordinates already lie in that range). The four
    ``find_shifts`` calls cover the four sign combinations of ``(dx, dy)``.
    Their union is the full rectangle of valid offsets, including the
    unshifted game.

    Duplicates, such as the zero shift produced by every quadrant, are removed.
    The result keeps first-seen order, so it is reproducible across runs,
    unlike iterating a ``set`` of strings.

    Returns:
        list of log lines.
    """
    winner, moves = get_moves(log)
    if not moves:
        # A game with no moves has exactly one translation: itself.
        return [get_log(moves, winner)]

    xs = [move[0] for move in moves]
    ys = [move[1] for move in moves]
    # Use the true extremes. The old ``n - 1`` starting value for the minima
    # under-reported them when every move sat in row/column n.
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    # (number of x shifts, number of y shifts, x direction, y direction)
    quadrants = (
        (min_x, min_y, -1, -1),
        (min_x, n + 1 - max_y, -1, 1),
        (n + 1 - max_x, n + 1 - max_y, 1, 1),
        (n + 1 - max_x, min_y, 1, -1),
    )
    new_logs = []
    for x_count, y_count, dir_x, dir_y in quadrants:
        new_logs += find_shifts(moves, x_count, y_count, dir_x, dir_y, winner)
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(new_logs))
156.
def log_to_data(log, *, full_symmetry=False):
    """Build augmented training data from a single game log.

    By default the log is augmented with its three quarter-turn rotations and
    one mirror image, giving five orientations. Each orientation is then
    expanded with every on-board translation via ``get_shifted_logs``. The
    combined list of logs is converted with ``get_data_n_labels``.

    Args:
        log: one game-log line.
        full_symmetry: when True, also mirror each of the three rotations.
            This covers all eight rotations/reflections of the square board
            instead of five. Defaults to False, which reproduces the original
            output exactly.

    Returns:
        ``(data, labels)`` as returned by ``get_data_n_labels``.
    """
    orientations = [log] + get_rotated_logs(log)
    mirrored = [flip_log(log)]
    if full_symmetry:
        # flip(rotate^k(log)) for k = 1..3 supplies the missing reflections.
        mirrored += [flip_log(rotated) for rotated in orientations[1:]]
    orientations += mirrored

    shifted_logs = []
    for orientation in orientations:
        shifted_logs += get_shifted_logs(orientation)
    data, labels = get_data_n_labels(shifted_logs)
    return data, labels
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top