Knaapje

Deep Q-Learning with Experience Replay for Go 9x9

Jan 19th, 2018
# main.py
import gym, time
import numpy as np
from keras.models import Sequential
from keras.optimizers import Adam
from keras.layers import Convolution2D, Dense, Dropout, Flatten
from collections import deque
from random import shuffle, sample
from utils import is_valid_move

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

episodes = 1000
batch_size = 32
epsilon = 1.0
epsilon_decay = 0.995
epsilon_min = 0.01
gamma = 0.995
learning_rate = 1e-6
render = False

env = gym.make('Go9x9-v0')
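# Assumption: the Go env's observation is a (3, 9, 9) array of binary planes, where
# plane `player` holds that player's stones and plane 2 marks the empty points.
# The reshape calls and the board utilities in utils.py rely on this layout.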

model = Sequential()
model.add(Convolution2D(32, (3, 3), input_shape=(3, 9, 9), data_format='channels_first'))
# Keep channels_first on every convolution so the (planes, rows, cols) layout is preserved.
model.add(Convolution2D(32, (3, 3), activation='relu', data_format='channels_first'))
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(env.action_space.n, activation='linear'))
model.compile(loss='mse', optimizer=Adam(lr=learning_rate), metrics=['accuracy'])

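# Replay buffer of the most recent (state, action, reward, next_state, done) transitions.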
memory = deque(maxlen=2000)

def learnt_valid_move(state, player):
  # Greedy policy: return the valid move with the highest predicted Q-value.
  best_action = None
  best_value = -np.inf
  for a, v in enumerate(model.predict(state.reshape(1, 3, 9, 9))[0]):
    if v > best_value and is_valid_move(state, player, a):
      best_action = a
      best_value = v
  return best_action

def random_valid_move(state, player):
  # Exploration: return a uniformly random valid move.
  actions = list(range(env.action_space.n))
  shuffle(actions)
  for a in actions:
    if is_valid_move(state, player, a):
      return a
  raise ValueError('no valid move available')

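# Training loop: choose moves epsilon-greedily and store every transition for replay.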
for e in range(episodes):
  state = env.reset()
  while True:
    if render:
      env.render()
      time.sleep(0.2)

    # The agent always plays as player 0; explore with probability epsilon.
    if np.random.rand() <= epsilon:
      action = random_valid_move(state, 0)
    else:
      action = learnt_valid_move(state, 0)

    next_state, reward, done, _ = env.step(action)
    memory.append((state, action, reward, next_state, done))
    state = next_state

    if done:
      print("episode: {}/{}, reward: {}".format(e+1, episodes, reward))
      break
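  # Experience replay: fit on a random minibatch, regressing each Q(s, a) towards
  # the one-step target r + gamma * max_a' Q(s', a') (just r at terminal states).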
  if len(memory) > batch_size:
    minibatch = sample(memory, batch_size)
    for state, action, reward, next_state, done in minibatch:
      target = reward
      if not done:
        target = reward + gamma * np.amax(model.predict(next_state.reshape(1, 3, 9, 9))[0])
      target_f = model.predict(state.reshape(1, 3, 9, 9))
      target_f[0][action] = target
      model.fit(state.reshape(1, 3, 9, 9), target_f, epochs=1, verbose=0)
    if epsilon > epsilon_min:
      epsilon *= epsilon_decay
      epsilon = max(epsilon, epsilon_min)
env.close()


# utils.py
import numpy as np

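# The helpers below assume the same plane layout as main.py: plane `player` holds
# that player's stones and plane 2 marks the empty points of the 9x9 board.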
def find_group(state, player, action):
  # Convert the flat action index to (x, y) board coordinates and collect the
  # connected group of the player's stones containing that point.
  x = action % 9
  y = action // 9
  return _find_group(state.copy(), player, (x, y))

def _find_group(c_state, player, position):
  # Recursive flood fill; visited stones are zeroed out in the copied state.
  x, y = position
  if c_state[player][y][x] == 1:
    c_state[player][y][x] = 0
    group = [position]
    if x > 0:
      group += _find_group(c_state, player, (x-1, y))
    if x < 8:
      group += _find_group(c_state, player, (x+1, y))
    if y > 0:
      group += _find_group(c_state, player, (x, y-1))
    if y < 8:
      group += _find_group(c_state, player, (x, y+1))
    return group
  else:
    return []

def find_liberties(state, group):
  # A liberty is an empty point (plane 2) adjacent to any stone of the group.
  liberties = set()
  for x, y in group:
    if x > 0 and state[2][y][x-1] == 1:
      liberties.add((x-1, y))
    if x < 8 and state[2][y][x+1] == 1:
      liberties.add((x+1, y))
    if y > 0 and state[2][y-1][x] == 1:
      liberties.add((x, y-1))
    if y < 8 and state[2][y+1][x] == 1:
      liberties.add((x, y+1))
  return list(liberties)

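# Simplified legality check used to filter moves: the point must be empty, and the
# move must either leave its own group with a liberty or capture an adjacent
# opponent group.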
# NOTE: ko is not taken into account.
def is_valid_move(state, player, action):
  # Actions above 80 (pass/resign) are always allowed.
  if action > 80:
    return True

  # If the position is already occupied, the move is invalid.
  x = action % 9
  y = action // 9
  if state[2][y][x] == 0:
    return False

  # Simulate the move; if the placed stone's group keeps at least one liberty,
  # the move is allowed.
  c_state = state.copy()
  c_state[player][y][x] = 1
  c_state[2][y][x] = 0
  if find_liberties(c_state, find_group(c_state, player, action)) != []:
    return True

  # A move that leaves its own group without liberties is only allowed if it
  # captures an adjacent opponent group (i.e. removes that group's last liberty).
  if x > 0 and c_state[1-player][y][x-1] == 1 and find_liberties(c_state, find_group(c_state, 1-player, action-1)) == []:
    return True
  if x < 8 and c_state[1-player][y][x+1] == 1 and find_liberties(c_state, find_group(c_state, 1-player, action+1)) == []:
    return True
  if y > 0 and c_state[1-player][y-1][x] == 1 and find_liberties(c_state, find_group(c_state, 1-player, action-9)) == []:
    return True
  if y < 8 and c_state[1-player][y+1][x] == 1 and find_liberties(c_state, find_group(c_state, 1-player, action+9)) == []:
    return True
  return False