Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import numpy as np
- import matplotlib.pyplot as plt
- import cv2
- import torch
- from torch import utils
- from torchvision import datasets, transforms
- from torch.autograd import Variable
- import sys
- sys.path.append('../')
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- import sklearn
- %matplotlib inline
# Side length of the board; coordinates are 1-based, i.e. 1..n.
n = 15
# Presumably the winning line length (five in a row); not referenced elsewhere
# in this chunk.
goal = 5

# Column letters a..p with 'i' omitted (presumably the usual board-notation
# convention that skips 'i'); position k in this string maps to number k.
_COLUMN_LETTERS = 'abcdefghjklmnop'
# Letter -> 1-based column number, e.g. 'a' -> 1, 'j' -> 9, 'p' -> 15.
let_to_dig = {
    letter: number for number, letter in enumerate(_COLUMN_LETTERS, start=1)
}
# Inverse mapping: 1-based column number -> letter.
dig_to_let = {number: letter for letter, number in let_to_dig.items()}
def format_move(old_move):
    """Convert a textual move such as ``'h8'`` into a coordinate tuple.

    The first character is a column letter (looked up in ``let_to_dig``) and
    the rest is a decimal number.

    Returns:
        ``(number, letter_index)``: the numeric part comes FIRST, so
        ``'h8'`` -> ``(8, 8)`` and ``'a12'`` -> ``(12, 1)``.
    """
    # Look the letter up before parsing the number, matching the original
    # evaluation order (KeyError takes precedence over ValueError).
    letter_index = let_to_dig[old_move[0]]
    number = int(old_move[1:])
    return number, letter_index
def format_from_move(move):
    """Inverse of ``format_move``: render ``(number, letter_index)`` as text.

    For example ``(8, 8)`` -> ``'h8'`` and ``(12, 1)`` -> ``'a12'``.
    """
    letter = dig_to_let[move[1]]
    return letter + str(move[0])
def get_moves(log):
    """Parse one game-log line into ``(winner, moves)``.

    A log line is whitespace-separated. The first token is the winner
    (compared against ``'white'`` elsewhere in this file). Every following
    token is a textual move converted with ``format_move``.

    Returns:
        ``(winner, moves)`` where ``moves`` is a list of
        ``(number, letter_index)`` tuples in play order.
    """
    # str.split() with no arguments already discards a trailing newline.
    tokens = log.split()
    winner = tokens[0]
    return winner, [format_move(token) for token in tokens[1:]]
def get_log(moves, winner):
    """Serialize a game back into a log line (inverse of ``get_moves``).

    Produces ``"<winner> <move> <move> ...\\n"``, with each move rendered by
    ``format_from_move``.
    """
    fields = [winner]
    fields.extend(format_from_move(move) for move in moves)
    return ' '.join(fields) + '\n'
def moves_to_grid(moves, size=None):
    """Render a move sequence as a single board array.

    Stones alternate sign starting with the first mover: the first move is
    written as ``1``, the second as ``-1``, the third as ``1``, and so on.

    Args:
        moves: iterable of ``(number, letter_index)`` tuples as produced by
            ``format_move`` (1-based coordinates).
        size: board side length; defaults to the module-level ``n``.

    Returns:
        ``np.ndarray`` of shape ``(size + 1, size + 1)`` and dtype ``int8``.
        Row and column 0 stay empty because coordinates are 1-based.
    """
    if size is None:
        size = n
    # NumPy needs its own dtype here; passing a torch dtype such as
    # torch.int8 to np.zeros raises TypeError.
    grid = np.zeros((size + 1, size + 1), dtype=np.int8)
    player = 1
    for move in moves:
        grid[move] = player  # tuple index -> grid[move[0], move[1]]
        player = -player
    return grid
def process_log(log):
    """Turn one game log into (board tensor, next-move label) training pairs.

    Only the winner's moves become labels. For each such move, the position
    *before* that move is stored as an int8 tensor of shape
    ``(3, n + 1, n + 1)``:

    * channel 0: black stones (value ``1``); moves 0, 2, 4, ...
    * channel 1: white stones (value ``-1``); moves 1, 3, 5, ...
    * channel 2: all ``-1`` if ``winner == 'white'``, otherwise all ``1``.

    The label is the flattened 0-based index
    ``(number - 1) * n + (letter_index - 1)`` of the winner's move.
    Black's first move (index 0) is never used as a label. Any winner value
    other than ``'white'`` is treated as black.

    Returns:
        ``(grids, labels)``: a list of tensors and a list of ints of equal
        length. A game with no moves yields ``([], [])``; the previous
        version raised IndexError in that case when black was the winner.
    """
    winner, moves = get_moves(log)
    side = n + 1
    black = torch.zeros((side, side), dtype=torch.int8)
    white = torch.zeros((side, side), dtype=torch.int8)
    to_move = torch.ones((side, side), dtype=torch.int8)

    # Move indices with this parity belong to the winner:
    # odd indices for white, even indices for black.
    winner_parity = 1 if winner == 'white' else 0
    if winner_parity:
        to_move *= -1

    grids = []
    labels = []
    for idx, move in enumerate(moves):
        # Snapshot the board (moves 0..idx-1) before each winner move.
        # idx > 0 excludes black's opening move, which is never a label.
        if idx % 2 == winner_parity and idx > 0:
            # torch.stack copies, so later placements don't alter the snapshot.
            grids.append(torch.stack([black, white, to_move]))
            labels.append((move[0] - 1) * n + move[1] - 1)
        if idx % 2 == 0:
            black[move] = 1
        else:
            white[move] = -1
    return grids, labels
def get_data_n_labels(logs):
    """Run ``process_log`` on every log and concatenate the results.

    Returns:
        ``(data, labels)``: flat lists of board tensors and their integer
        labels, in the order the logs were given.
    """
    all_data, all_labels = [], []
    for single_log in logs:
        grids, targets = process_log(single_log)
        all_data.extend(grids)
        all_labels.extend(targets)
    return all_data, all_labels
def find_shifts(moves, x, y, dir_x, dir_y, winner):
    """Return log lines for a grid of translations of ``moves``.

    For every ``dy`` in ``range(y)`` (outer loop) and ``dx`` in ``range(x)``
    (inner loop), each move ``(a, b)`` becomes
    ``(a + dir_x * dx, b + dir_y * dy)``. The shifted game is serialized
    with ``get_log``.

    No bounds checking happens here. Callers are expected to choose
    ``x``/``y`` so that every shifted move stays on the board.
    """
    return [
        get_log(
            [(m[0] + dir_x * dx, m[1] + dir_y * dy) for m in moves],
            winner,
        )
        for dy in range(y)
        for dx in range(x)
    ]
def flip_log(log):
    """Return ``log`` mirrored along its second (letter) coordinate.

    Each move ``(a, b)`` becomes ``(a, n + 1 - b)``; the winner is kept.
    """
    winner, moves = get_moves(log)
    mirrored = [(a, n + 1 - b) for a, b in moves]
    return get_log(mirrored, winner)
def rotate_log(log):
    """Return ``log`` with every move rotated a quarter turn on the board.

    Each move ``(a, b)`` becomes ``(b, n + 1 - a)``; the winner is kept.
    Four applications give back the original coordinates.
    """
    winner, moves = get_moves(log)
    return get_log([(b, n + 1 - a) for a, b in moves], winner)
def get_rotated_logs(log):
    """Return the three non-trivial quarter-turn rotations of ``log``.

    Element ``k`` is ``log`` passed through ``rotate_log`` ``k + 1`` times.
    The unrotated ``log`` itself is not included.
    """
    results = []
    current = log
    while len(results) < 3:
        current = rotate_log(current)
        results.append(current)
    return results
def get_shifted_logs(log):
    """Return every distinct translation of a game that stays on the board.

    The game is shifted toward each of the four corners by every offset that
    keeps all coordinates within ``1..n``. The unshifted game (offset 0) is
    included.

    Duplicates are removed while keeping first-seen order, so the result is
    reproducible across runs. ``list(set(...))`` would reorder by string hash,
    which changes between interpreter runs.
    """
    winner, moves = get_moves(log)
    if not moves:
        # Every "translation" of an empty game is the same log line.
        return [get_log(moves, winner)]

    # Use the true extremes. Seeding the minimum with n - 1 (as before)
    # under-reports it when every stone sits on coordinate n, which drops
    # the largest valid shift.
    xs = [move[0] for move in moves]
    ys = [move[1] for move in moves]
    min_x, max_x = min(xs), max(xs)
    min_y, max_y = min(ys), max(ys)

    # Toward a low edge, offsets 0..min-1 keep coordinates >= 1.
    # Toward a high edge, offsets 0..n-max keep coordinates <= n.
    new_logs = []
    new_logs += find_shifts(moves, min_x, min_y, -1, -1, winner)
    new_logs += find_shifts(moves, min_x, n + 1 - max_y, -1, 1, winner)
    new_logs += find_shifts(moves, n + 1 - max_x, n + 1 - max_y, 1, 1, winner)
    new_logs += find_shifts(moves, n + 1 - max_x, min_y, 1, -1, winner)
    return list(dict.fromkeys(new_logs))
def log_to_data(log):
    """Build augmented training data from a single game log.

    The game is expanded into the original, its three quarter-turn rotations
    (``get_rotated_logs``), and one mirror image (``flip_log``). Each of
    those is expanded to every valid translation with ``get_shifted_logs``.
    The combined logs are converted with ``get_data_n_labels``.

    Returns:
        ``(data, labels)`` as produced by ``get_data_n_labels``.
    """
    # NOTE(review): rotations of the mirrored game are not generated, so only
    # 5 of the 8 square-board symmetries are used. Confirm that is intended.
    base_variants = [log]
    base_variants.extend(get_rotated_logs(log))
    base_variants.append(flip_log(log))

    augmented_logs = []
    for variant in base_variants:
        augmented_logs.extend(get_shifted_logs(variant))

    data, labels = get_data_n_labels(augmented_logs)
    return data, labels
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement