Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # main.py
- import gym, time;
- import numpy as np;
- from keras.models import Sequential;
- from keras.optimizers import Adam;
- from keras.layers import Input, Convolution2D, Dense, Dropout, Activation, Flatten;
- from collections import deque;
- from random import shuffle, sample;
- from utils import is_valid_move;
- from tensorflow.python.client import device_lib
# Print the compute devices TensorFlow reports, e.g. to check for a GPU.
print(device_lib.list_local_devices())

# --- Training hyperparameters -------------------------------------------
episodes = 1000        # number of games played by the outer training loop
batch_size = 32        # transitions sampled from the replay memory per round
epsilon = 1.0          # probability of choosing a random legal move
epsilon_decay = 0.995  # factor epsilon is multiplied by when it decays
epsilon_min = 0.01     # epsilon is never decayed below this value
gamma = 0.995          # discount applied to the best next-state Q-value
learning_rate = 1e-6   # step size handed to the Adam optimizer
render = False         # if True, the board is rendered with a 0.2 s pause per move

env = gym.make('Go9x9-v0')

# --- Q-network ------------------------------------------------------------
# Maps a (3, 9, 9) board tensor to one value per action in the environment.
# NOTE(review): only the first convolution passes data_format='channels_first'
# and it has no activation; the second convolution relies on the Keras
# default data format. Confirm both are intentional.
model = Sequential([
    Convolution2D(32, (3, 3), input_shape=(3, 9, 9), data_format='channels_first'),
    Convolution2D(32, (3, 3), activation='relu'),
    Dropout(0.2),
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(env.action_space.n, activation='linear'),
])
model.compile(
    loss='mse',
    optimizer=Adam(lr=learning_rate),
    metrics=['accuracy'],
)

# Replay memory of (state, action, reward, next_state, done) tuples; once
# 2000 entries are stored the oldest ones are dropped automatically.
memory = deque(maxlen=2000)
def learnt_valid_move(state, player):
    """Return the legal action that the Q-network currently values most.

    The network is evaluated once on ``state`` (reshaped to a single
    ``(1, 3, 9, 9)`` sample). Actions are then scanned in index order and
    the first legal action with the strictly greatest predicted value wins.

    Returns ``None`` when no action is both legal and better than ``-inf``.
    """
    q_values = model.predict(state.reshape(1, 3, 9, 9))[0]

    chosen = None
    chosen_q = -np.inf
    for candidate, q in zip(range(env.action_space.n), q_values):
        # Cheap comparison first, so the legality check only runs for
        # candidates that would actually replace the current choice.
        if not q > chosen_q:
            continue
        if not is_valid_move(state, player, candidate):
            continue
        chosen = candidate
        chosen_q = q
    return chosen
def random_valid_move(state, player):
    """Return a uniformly random legal action for ``player`` in ``state``.

    All actions of the environment are shuffled and the first one accepted
    by ``is_valid_move`` is returned.

    Raises:
        ValueError: if no action at all is legal.
    """
    candidates = list(range(env.action_space.n))
    shuffle(candidates)

    pick = next(
        (a for a in candidates if is_valid_move(state, player, a)),
        None,
    )
    if pick is None:
        raise ValueError
    return pick
# Main training loop: play `episodes` games with an epsilon-greedy policy,
# store every transition in the replay memory and train the Q-network on a
# random minibatch of past transitions.
#
# NOTE(review): the indentation of this script was lost in the paste. The
# nesting below (experience replay and epsilon decay once per finished
# episode) is a reconstruction -- confirm it against the original file.
try:
    for e in range(episodes):
        state = env.reset()

        # --- Play one game to the end -----------------------------------
        while True:
            if render:
                env.render()
                time.sleep(0.2)

            # Epsilon-greedy: explore with probability epsilon, otherwise
            # play the legal move the network rates highest.
            if np.random.rand() <= epsilon:
                action = random_valid_move(state, 0)
            else:
                action = learnt_valid_move(state, 0)

            next_state, reward, done, _ = env.step(action)
            memory.append((state, action, reward, next_state, done))
            state = next_state

            if done:
                print("episode: {}/{}, reward: {}".format(e+1, episodes, reward))
                break

        # --- Experience replay ------------------------------------------
        if len(memory) > batch_size:
            # Distinct names are used for the sampled transition so that the
            # live episode variables (state, action, reward, done) are never
            # overwritten by replayed data.
            for s, a, r, s_next, terminal in sample(memory, batch_size):
                target = r
                if not terminal:
                    # Bootstrapped Q-learning target: immediate reward plus
                    # the discounted best predicted value of the next state.
                    target = (r + gamma * np.amax(model.predict(s_next.reshape(1, 3, 9, 9))[0]))
                # Only the entry of the taken action is moved towards the
                # target; the other outputs keep their current predictions.
                target_f = model.predict(s.reshape(1, 3, 9, 9))
                target_f[0][a] = target
                model.fit(s.reshape(1, 3, 9, 9), target_f, batch_size=batch_size, epochs=1, verbose=0)

            # Decay exploration, never dropping below epsilon_min.
            if epsilon > epsilon_min:
                epsilon = max(epsilon * epsilon_decay, epsilon_min)
finally:
    # Release the environment even if training fails or is interrupted.
    env.close()
- # utils.py
- import numpy as np;
def find_group(state, player, action):
    """Return the connected group of ``player`` stones containing ``action``.

    ``action`` is a flat index into the 9x9 board (``row * 9 + column``).
    The search runs on a copy of ``state``, so the caller's array is left
    untouched. The result is a list of ``(x, y)`` positions; it is empty
    when ``player`` has no stone on that point.
    """
    row, column = divmod(action, 9)
    scratch = state.copy()
    return _find_group(scratch, player, (column, row))
- def _find_group(c_state, player, position):
- x = position[0];
- y = position[1];
- if c_state[player][y][x] == 1:
- c_state[player][y][x] = 0;
- group = [position];
- if x > 0:
- group += _find_group(c_state, player, (x-1, y));
- if x < 8:
- group += _find_group(c_state, player, (x+1, y));
- if y > 0:
- group += _find_group(c_state, player, (x, y-1));
- if y < 8:
- group += _find_group(c_state, player, (x, y+1));
- return group;
- else:
- return [];
def find_liberties(state, group):
    """Return the empty points directly adjacent to any stone in ``group``.

    Args:
        state: board planes; ``state[2][y][x] == 1`` marks an empty point.
        group: iterable of ``(x, y)`` stone positions.

    Returns:
        A list of distinct ``(x, y)`` liberties, empty if the group has none
        (or if ``group`` itself is empty).
    """
    empty = state[2]
    # Board height is read from the plane instead of assuming a 9x9 board.
    height = len(empty)

    liberties = set()  # a set, because neighbouring stones share liberties
    for x, y in group:
        for nx, ny in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
            if 0 <= ny < height and 0 <= nx < len(empty[ny]) and empty[ny][nx] == 1:
                liberties.add((nx, ny))
    return list(liberties)
# NOTE: DOES NOT TAKE KO INTO ACCOUNT
def is_valid_move(state, player, action):
    """Return True if ``player`` may legally play ``action`` in ``state``.

    Args:
        state: board planes; ``state[p][y][x] == 1`` marks a stone of player
            ``p`` (0 or 1) and ``state[2][y][x] == 1`` marks an empty point.
        player: index (0 or 1) of the player to move.
        action: flat board index ``y * 9 + x``; any value above 80 is not a
            board point and is always accepted.

    Returns:
        True when the point is empty and the move either leaves the played
        group with at least one liberty or captures an adjacent opponent
        group. Ko is not checked.
    """
    if action > 80:
        return True

    # If the position is already occupied, then the move is invalid
    y, x = divmod(action, 9)
    if state[2][y][x] == 0:
        return False

    # Place the stone on a copy so the caller's board is untouched.
    c_state = state.copy()
    c_state[player][y][x] = 1
    c_state[2][y][x] = 0

    # Moves that leave the played group with liberties are allowed
    if find_liberties(c_state, find_group(c_state, player, action)):
        return True

    # Without liberties the move is only valid if it captures something.
    neighbours = []
    if x > 0:
        neighbours.append(action - 1)
    if x < 8:
        neighbours.append(action + 1)
    if y > 0:
        neighbours.append(action - 9)
    if y < 8:
        neighbours.append(action + 9)

    opponent = 1 - player
    for neighbour in neighbours:
        enemy_group = find_group(c_state, opponent, neighbour)
        # The group must be non-empty: a neighbour that is not an opponent
        # stone yields an empty group, which trivially has no liberties and
        # must not be mistaken for a capture (that would permit suicide).
        if enemy_group and not find_liberties(c_state, enemy_group):
            return True
    return False
Add Comment
Please, Sign In to add comment