Guest User

Game of Amazon (NN learning code).

a guest
Jan 16th, 2016
159
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 20.51 KB | None | 0 0
  1. #include <iostream>
  2. #include <fstream>
  3. #include <vector>
  4. #include <deque>
  5. #include <queue>
  6. #include <random>
  7. #include <cassert>
  8. #include <ctime>
  9.  
  10. constexpr int SIZE = 10;
  11. constexpr int AMAZONS_CNT = 4;
  12. constexpr int INF = 1000*1000*1000;
  13. constexpr int WIN = 10000;
  14.  
  15. using namespace std;
  16.  
  17. typedef pair<int, int> pos_t;
  18. typedef pair<pair<pos_t, pos_t>, pos_t> move_t;
  19. typedef double score_t;
  20.  
  21. class GameState {
  22.     int board[SIZE][SIZE];
  23.     int turn;
  24.     pos_t amazons[2][AMAZONS_CNT];
  25. public:
  26.     GameState(const int B[SIZE][SIZE], int t,
  27.               vector<pos_t>& fst, vector<pos_t>& snd) {
  28.         for (int i = 0; i < SIZE; ++i) {
  29.             for (int j = 0; j < SIZE; ++j)
  30.                 board[i][j] = B[i][j];
  31.         }
  32.         turn = t - 1;
  33.         assert(fst.size() == AMAZONS_CNT);
  34.         assert(snd.size() == AMAZONS_CNT);
  35.         for (int i = 0; i < AMAZONS_CNT; ++i) {
  36.             amazons[0][i] = fst[i];
  37.             int x = amazons[0][i].first;
  38.             int y = amazons[0][i].second;
  39.             assert(board[x][y] == 1);
  40.         }
  41.         for (int i = 0; i < AMAZONS_CNT; ++i) {
  42.             amazons[1][i] = snd[i];
  43.             int x = amazons[1][i].first;
  44.             int y = amazons[1][i].second;
  45.             assert(board[x][y] == 2);
  46.         }
  47.     }
  48.    
  49.     GameState(const GameState& state) {
  50.         for (int i = 0; i < SIZE; ++i) {
  51.             for (int j = 0; j < SIZE; ++j)
  52.                 board[i][j] = state.board[i][j];
  53.         }
  54.         turn = state.turn;
  55.         for (int i = 0; i < 2; ++i) {
  56.             for (int j = 0; j < AMAZONS_CNT; ++j) {
  57.                 amazons[i][j] = state.amazons[i][j];
  58.             }
  59.         }
  60.     }
  61.    
  62.     inline const int* operator[](int index) { return board[index]; }
  63.     inline const int get_turn() { return turn; }
  64.     inline const pos_t* get_amazons(int player) { return amazons[player]; }
  65.    
  66.     void make_move(move_t move) {
  67.         int amazon_x = move.first.first.first;
  68.         int amazon_y = move.first.first.second;
  69.         int target_x = move.first.second.first;
  70.         int target_y = move.first.second.second;
  71.         board[target_x][target_y] = board[amazon_x][amazon_y];
  72.         board[amazon_x][amazon_y] = 0;
  73.         for (int i = 0; i < AMAZONS_CNT; ++i) {
  74.             if (amazons[turn][i] == make_pair(amazon_x, amazon_y)) {
  75.                 amazons[turn][i] = make_pair(target_x, target_y);
  76.                 break;
  77.             }
  78.         }
  79.         int arrow_x = move.second.first;
  80.         int arrow_y = move.second.second;
  81.         board[arrow_x][arrow_y] = -1;
  82.         turn = 1 - turn;
  83.     }
  84.    
  85.     deque<move_t> get_moves() {
  86.         constexpr int DIR_CNT = 8;
  87.         int dx[DIR_CNT] = {+1, +1, +1,  0, -1, -1, -1,  0};
  88.         int dy[DIR_CNT] = {+1,  0, -1, -1, -1,  0, +1, +1};
  89.         deque<move_t> res;
  90.         for (int piece = 0; piece < AMAZONS_CNT; ++piece) {
  91.             int x = amazons[turn][piece].first;
  92.             int y = amazons[turn][piece].second;
  93.             pos_t init_pos = make_pair(x, y);
  94.             for (int dir = 0; dir < DIR_CNT; ++dir) {
  95.                 int cur_x = x;
  96.                 int cur_y = y;
  97.                 for (int len = 0; ; ++len) {
  98.                     cur_x += dx[dir];
  99.                     cur_y += dy[dir];
  100.                     if (cur_x < 0 || cur_x >= SIZE) break;
  101.                     else if (cur_y < 0 || cur_y >= SIZE) break;
  102.                     else if (board[cur_x][cur_y] != 0) break;
  103.                     else {
  104.                         pos_t target = make_pair(cur_x, cur_y);
  105.                         int tmp = board[x][y];
  106.                         board[x][y] = 0;
  107.                         for (int arrow_dir = 0;
  108.                              arrow_dir < DIR_CNT; ++arrow_dir) {
  109.                             int cur_arr_x = cur_x;
  110.                             int cur_arr_y = cur_y;
  111.                             for (int arrow_len = 0; ; ++len) {
  112.                                 cur_arr_x += dx[arrow_dir];
  113.                                 cur_arr_y += dy[arrow_dir];
  114.                                 if (cur_arr_x < 0 || cur_arr_x >= SIZE)
  115.                                     break;
  116.                                 else if (cur_arr_y < 0 || cur_arr_y >= SIZE)
  117.                                     break;
  118.                                 else if (board[cur_arr_x][cur_arr_y] != 0)
  119.                                     break;
  120.                                 else {
  121.                                     pos_t arrow =
  122.                                     make_pair(cur_arr_x, cur_arr_y);
  123.                                     res.push_back(
  124.                                                   make_pair(
  125.                                                             make_pair(init_pos, target),
  126.                                                             arrow));
  127.                                 }
  128.                             }
  129.                         }
  130.                         board[x][y] = tmp;
  131.                     }
  132.                 }
  133.             }
  134.         }
  135.         return res;
  136.     }
  137.    
  138.     int get_winner() {
  139.         int result = 1 - turn;
  140.         constexpr int DIR_CNT = 8;
  141.         int dx[DIR_CNT] = {+1, +1, +1,  0, -1, -1, -1,  0};
  142.         int dy[DIR_CNT] = {+1,  0, -1, -1, -1,  0, +1, +1};
  143.         for (int i = 0; i < AMAZONS_CNT; ++i) {
  144.             int x = amazons[turn][i].first;
  145.             int y = amazons[turn][i].second;
  146.             for (int dir = 0; dir < DIR_CNT; ++dir) {
  147.                 int new_x = x + dx[dir];
  148.                 int new_y = y + dy[dir];
  149.                 if (new_x < 0 || new_x >= SIZE) continue;
  150.                 if (new_y < 0 || new_y >= SIZE) continue;
  151.                 if (board[new_x][new_y] == 0) {
  152.                     result = -1;
  153.                     break;
  154.                 }
  155.             }
  156.             if (result == -1) break;
  157.         }
  158.         return result;
  159.     }
  160.    
  161.     void print(ostream& out) {
  162.         out << "Board:\n";
  163.         out << "  | ";
  164.         for (int i = 0; i < SIZE; ++i) {
  165.             out << i << " ";
  166.         }
  167.         out << "\n";
  168.         out << "----";
  169.         for (int i = 0; i < SIZE; ++i) {
  170.             out << "--";
  171.         }
  172.         out << "\n";
  173.         for (int i = 0; i < SIZE; ++i) {
  174.             out << i << " | ";
  175.             for (int j = 0; j < SIZE; ++j) {
  176.                 if (board[i][j] == -1) out << "* ";
  177.                 else out << board[i][j] << " ";
  178.             }
  179.             out << "\n";
  180.         }
  181.         out << "Turn: " << turn+1 << "\n";
  182.         out << "Amazon positions:\n";
  183.         out << "White (1): ";
  184.         for (int i = 0; i < AMAZONS_CNT; ++i) {
  185.             out << "(" << amazons[0][i].first << ", "
  186.             << amazons[0][i].second << ") ";
  187.         }
  188.         out << "\n";
  189.         out << "Black (2): ";
  190.         for (int i = 0; i < AMAZONS_CNT; ++i) {
  191.             out << "(" << amazons[1][i].first << ", "
  192.             << amazons[1][i].second << ") ";
  193.         }
  194.         out << "\n";
  195.     }
  196. };
  197.  
  198. typedef vector<vector<double>> Matrix;
  199.  
  200. class NeuralNetwork {
  201.     vector<Matrix> weights, traces;
  202.     double learning_rate, decay_rate;
  203.     int cnt = 0;
  204. public:
  205.     NeuralNetwork() {
  206.         learning_rate = 0.1;
  207.         decay_rate = 0.0;
  208.     }
  209.     NeuralNetwork(vector<int> input_sizes, double eta, double lambda) {
  210.         assert(input_sizes[input_sizes.size() - 1] == 1);
  211.         learning_rate = eta;
  212.         decay_rate = lambda;
  213.         std::uniform_real_distribution<double> unif{-1, +1};
  214.         std::default_random_engine re(std::random_device{}());
  215.         for (int i = 0; i < input_sizes.size() - 1; ++i) {
  216.             Matrix w; Matrix t;
  217.             int rows = input_sizes[i + 1];
  218.             int cols = input_sizes[i] + 1;
  219.             for (int j = 0; j < rows; ++j) {
  220.                 vector<double> line;
  221.                 t.push_back(line);
  222.                 for (int k = 0; k < cols; ++k) {
  223.                     line.push_back(unif(re));
  224.                     t[j].push_back(0);
  225.                 }
  226.                 w.push_back(line);
  227.             }
  228.             weights.push_back(w);
  229.             traces.push_back(t);
  230.         }
  231.     }
  232.     NeuralNetwork(vector<Matrix> w, double eta, double lambda) {
  233.         weights = w;
  234.         learning_rate = eta;
  235.         decay_rate = lambda;
  236.         for (int i = 0; i < weights.size(); ++i) {
  237.             Matrix t;
  238.             int rows = weights[i].size();
  239.             int cols = weights[i][0].size();
  240.             for (int j = 0; j < rows; ++j) {
  241.                 vector<double> line;
  242.                 t.push_back(line);
  243.                 for (int k = 0; k < cols; ++k) {
  244.                     t[j].push_back(0);
  245.                 }
  246.             }
  247.             traces.push_back(t);
  248.         }
  249.     }
  250.     inline vector<double> feedforward(const vector<double>& inputs) const {
  251.         auto signals = calc_outputs(inputs);
  252.         return signals[signals.size() - 1];
  253.     }
  254.     void train(const vector<double>& inputs,
  255.                const vector<double>& target) {
  256.         vector<vector<double>> outputs = calc_outputs(inputs);
  257.         vector<double> prev_deltas, deltas;
  258.         for (int layer = outputs.size() - 1; layer >= 0; --layer) {
  259.             for (int i = 1; i < outputs[layer].size(); ++i) {
  260.                 double o = outputs[layer][i];
  261.                 double delta = 0;
  262.                 if (layer == outputs.size() - 1) {
  263.                     delta = o * (1 - o) * (target[i-1] - o);
  264.                 }
  265.                 else {
  266.                     for (int k = 0;
  267.                          k < prev_deltas.size(); k++) {
  268.                         delta += prev_deltas[k] * weights[layer+1][k][i];
  269.                     }
  270.                     delta *= o * (1 - o);
  271.                 }
  272.                 deltas.push_back(delta);
  273.                 for (int j = 0; j < weights[layer][i-1].size();++j) {
  274.                     double grad = 0;
  275.                     if (layer > 0) {
  276.                         grad = delta * outputs[layer-1][j];
  277.                     }
  278.                     else {
  279.                         if (j > 0) {
  280.                             grad = delta * inputs[j-1];
  281.                         }
  282.                         else {
  283.                             grad = delta;
  284.                         }
  285.                     }
  286.                     traces[layer][i-1][j] *= decay_rate;
  287.                     traces[layer][i-1][j] += grad;
  288.                     weights[layer][i-1][j] +=
  289.                     learning_rate * traces[layer][i-1][j];
  290.                 }
  291.             }
  292.             prev_deltas = deltas;
  293.             deltas.clear();
  294.         }
  295.         cnt += 1;
  296.         if (cnt % 5000000 == 0) {
  297.             learning_rate *= 0.99;
  298.             learning_rate = max(0.05, learning_rate);
  299.             cnt = 0;
  300.             cerr << "*** ATTENTION ***\n";
  301.             cerr << "New learning rate = " << learning_rate << "\n";
  302.         }
  303.     }
  304.     inline vector<double> operator()(const vector<double>& inputs) const {
  305.         return feedforward(inputs);
  306.     }
  307.     void print(ostream& out) {
  308.         out << "{\n";
  309.         for (int i = 0; i < weights.size(); ++i) {
  310.             out << "{" << "\n";
  311.             for (int j = 0; j < weights[i].size(); ++j) {
  312.                 auto v = weights[i][j];
  313.                 out << "{";
  314.                 for (int k = 0; k < weights[i][j].size(); ++k) {
  315.                     auto l = weights[i][j][k];
  316.                     out << l;
  317.                     if (k < weights[i][j].size() - 1) out << ", ";
  318.                 }
  319.                 out << "}";
  320.                 if (j < weights[i].size() - 1) out << ", ";
  321.                 out << "\n";
  322.             }
  323.             out << "}";
  324.             if (i < weights.size() - 1) out << ", ";
  325.             out << "\n";
  326.         }
  327.         out << "};\n";
  328.     }
  329. private:
  330.     inline double sigmoid(double x) const {
  331.         return 1.0 / (1.0 + exp(-x));
  332.     }
  333.     inline double sigmoid_prime(double x) const {
  334.         return sigmoid(x) * (1 - sigmoid(x));
  335.     }
  336.     vector<vector<double>> calc_outputs(const vector<double>& inputs) const {
  337.         vector<vector<double>> res;
  338.         vector<double> signals{1};
  339.         for (auto i : inputs) {
  340.             signals.push_back(i);
  341.         }
  342.         for (Matrix w : weights) {
  343.             assert(w[0].size() == signals.size());
  344.             vector<double> new_signals(w.size() + 1, 0);
  345.             new_signals[0] = 1;
  346.             for (int i = 1; i <= w.size(); ++i) {
  347.                 double c0, c1, c2, c3;
  348.                 size_t sz = w[i-1].size();
  349.                 for (int j = 0; j < sz; j += 4) {
  350.                     c0 = c1 = c2 = c3 = 0.0;
  351.                     c0 = signals[j] * w[i - 1][j];
  352.                     if (j+1 < sz)
  353.                         c1 = signals[j+1] * w[i - 1][j+1];
  354.                     if (j+2 < sz)
  355.                         c2 = signals[j+2] * w[i - 1][j+2];
  356.                     if (j+3 < sz)
  357.                         c3 = signals[j+3] * w[i - 1][j+3];
  358.                     new_signals[i] += c0+c1+c2+c3;
  359.                 }
  360.                 new_signals[i] = sigmoid(new_signals[i]);
  361.             }
  362.             signals = new_signals;
  363.             res.push_back(signals);
  364.         }
  365.         return res;
  366.     }
  367. };
  368.  
  369.  
  370. inline vector<double> get_vector(GameState& state) {
  371.     constexpr int DIR_CNT = 8;
  372.     constexpr int dx[DIR_CNT] = {+1, +1, +1,  0, -1, -1, -1,  0};
  373.     constexpr int dy[DIR_CNT] = {+1,  0, -1, -1, -1,  0, +1, +1};
  374.     vector<double> inputs{0};
  375.     for (int i = 0; i < SIZE; ++i) {
  376.         for (int j = 0; j < SIZE; ++j) {
  377.             if (state[i][j] == -1)
  378.                 inputs[0]++;
  379.         }
  380.     }
  381.     inputs[0] /= 100;
  382.     int turn = state.get_turn();
  383.     // vector<vector<pos_t>> amazons(2, vector<pos_t>(AMAZONS_CNT));
  384.     pos_t amazons[2][AMAZONS_CNT];
  385.     constexpr int FEATURES_CNT = 13;
  386.     constexpr int X_INDEX = 0, Y_INDEX = 1, DIR_MIN_INDEX = 2;
  387.     constexpr int ONE_PLY_INDEX = DIR_MIN_INDEX + DIR_CNT;
  388.     constexpr int TWO_PLY_INDEX = 1 + ONE_PLY_INDEX;
  389.     constexpr int FAR_INDEX = 1 + TWO_PLY_INDEX;
  390.     static_assert(FAR_INDEX + 1 == FEATURES_CNT, "Indices goin\' berzerk :(");
  391.     double amazons_features_c[2*AMAZONS_CNT][FEATURES_CNT];
  392.     for (int stage = 0; stage < 2; ++stage) {
  393.         for (int i = 0; i < AMAZONS_CNT; ++i) {
  394.             amazons[stage][i] = state.get_amazons(turn)[i];
  395.             int x = amazons[stage][i].first;
  396.             int y = amazons[stage][i].second;
  397.             amazons_features_c[stage*AMAZONS_CNT+i][X_INDEX] = 1.0*x/10;
  398.             amazons_features_c[stage*AMAZONS_CNT+i][Y_INDEX] = 1.0*y/10;
  399.             for (int dir = 0; dir < DIR_CNT; ++dir) {
  400.                 int cur_x = x;
  401.                 int cur_y = y;
  402.                 int cnt = 0;
  403.                 while (true) {
  404.                     cur_x += dx[dir];
  405.                     cur_y += dy[dir];
  406.                     if (cur_x >= SIZE || cur_x < 0) break;
  407.                     if (cur_y >= SIZE || cur_y < 0) break;
  408.                     if (state[cur_x][cur_y] != 0) break;
  409.                     else ++cnt;
  410.                 }
  411.                 amazons_features_c[stage*AMAZONS_CNT+i][DIR_MIN_INDEX+dir] = 1.0*cnt/9;
  412.             }
  413.         }
  414.         turn = 1 - turn;
  415.     }
  416.     int dist_map[2*AMAZONS_CNT][SIZE][SIZE];
  417.     for (int i = 0; i < 2*AMAZONS_CNT; ++i)
  418.         for (int j = 0; j < SIZE; ++j)
  419.             for (int k = 0; k < SIZE; ++k)
  420.                 dist_map[i][j][k] = +INF;
  421.     int control_map[2][SIZE][SIZE];
  422.     for (int i = 0; i < 2; ++i)
  423.         for (int j = 0; j < SIZE; ++j)
  424.             for (int k = 0; k < SIZE; ++k)
  425.                 control_map[i][j][k] = +INF;
  426.     int index = 0;
  427.     for (int turn = 0; turn <= 1; ++turn) {
  428.         queue<pos_t> Q;
  429.         for (pos_t a : amazons[turn]) {
  430.             dist_map[index][a.first][a.second] = 0;
  431.             control_map[turn][a.first][a.second] = 0;
  432.             Q.push(a);
  433.             while (!Q.empty()) {
  434.                 pos_t a = Q.front(); Q.pop();
  435.                 int x = a.first, y = a.second;
  436.                 for (int dir = 0; dir < DIR_CNT; ++dir) {
  437.                     int cur_x = x;
  438.                     int cur_y = y;
  439.                     while (true) {
  440.                         cur_x += dx[dir];
  441.                         cur_y += dy[dir];
  442.                         if (cur_x >= SIZE || cur_x < 0) break;
  443.                         if (cur_y >= SIZE || cur_y < 0) break;
  444.                         if (state[cur_x][cur_y] != 0) break;
  445.                         if (dist_map[index][cur_x][cur_y] == INF) {
  446.                             dist_map[index][cur_x][cur_y] =
  447.                             dist_map[index][x][y] + 1;
  448.                             control_map[turn][cur_x][cur_y] =
  449.                             min(control_map[turn][cur_x][cur_y],
  450.                                 dist_map[index][cur_x][cur_y]);
  451.                             Q.push(make_pair(cur_x, cur_y));
  452.                         }
  453.                     }
  454.                 }
  455.             }
  456.             ++index;
  457.         }
  458.     }
  459.     int w_amazon_1ply[AMAZONS_CNT] = {0};
  460.     int w_amazon_2ply[AMAZONS_CNT] = {0};
  461.     int w_amazon_far[AMAZONS_CNT] = {0};
  462.     for (int num = 0; num < AMAZONS_CNT; ++num) {
  463.         for (int i = 0; i < SIZE; ++i) {
  464.             for (int j = 0; j < SIZE; ++j) {
  465.                 if (dist_map[num][i][j] < control_map[1][i][j] &&
  466.                     dist_map[num][i][j] > 0) {
  467.                     if (dist_map[num][i][j] == 1)
  468.                         w_amazon_1ply[num]++;
  469.                     else if (dist_map[num][i][j] == 2)
  470.                         w_amazon_2ply[num]++;
  471.                     else
  472.                         w_amazon_far[num]++;
  473.                 }
  474.             }
  475.         }
  476.         amazons_features_c[num][ONE_PLY_INDEX] = (w_amazon_1ply[num] * 0.05);
  477.         amazons_features_c[num][TWO_PLY_INDEX] = (w_amazon_2ply[num] * 0.05);
  478.         amazons_features_c[num][FAR_INDEX] = (w_amazon_far[num] * 0.05);
  479.     }
  480.     int b_amazon_1ply[AMAZONS_CNT] = {0};
  481.     int b_amazon_2ply[AMAZONS_CNT] = {0};
  482.     int b_amazon_far[AMAZONS_CNT] = {0};
  483.     for (int num = 0; num < AMAZONS_CNT; ++num) {
  484.         for (int i = 0; i < SIZE; ++i) {
  485.             for (int j = 0; j < SIZE; ++j) {
  486.                 if (dist_map[num+AMAZONS_CNT][i][j] <
  487.                     control_map[0][i][j] &&
  488.                     dist_map[num+AMAZONS_CNT][i][j] > 0) {
  489.                     if (dist_map[num+AMAZONS_CNT][i][j] == 1)
  490.                         b_amazon_1ply[num]++;
  491.                     else if (dist_map[num+AMAZONS_CNT][i][j] == 2)
  492.                         b_amazon_2ply[num]++;
  493.                     else
  494.                         b_amazon_far[num]++;
  495.                 }
  496.             }
  497.         }
  498.         amazons_features_c[num+AMAZONS_CNT][ONE_PLY_INDEX] = (w_amazon_1ply[num] * 0.05);
  499.         amazons_features_c[num+AMAZONS_CNT][TWO_PLY_INDEX] = (w_amazon_2ply[num] * 0.05);
  500.         amazons_features_c[num+AMAZONS_CNT][FAR_INDEX] = (w_amazon_far[num] * 0.05);
  501.     }
  502.     int w_control = 0, b_control = 0;
  503.     for (int i = 0; i < SIZE; ++i) {
  504.         for (int j = 0; j < SIZE; ++j) {
  505.             if (control_map[0][i][j] < control_map[1][i][j] &&
  506.                 control_map[0][i][j] > 0)
  507.                 w_control++;
  508.             if (control_map[0][i][j] > control_map[1][i][j] &&
  509.                 control_map[1][i][j] > 0)
  510.                 b_control++;
  511.         }
  512.     }
  513.     inputs.push_back(w_control * 0.01);
  514.     inputs.push_back(b_control * 0.01);
  515.     for (int i = 0; i < 2*AMAZONS_CNT; ++i) {
  516.         for (int j = 0; j < FEATURES_CNT; ++j) {
  517.             inputs.push_back(amazons_features_c[i][j]);
  518.         }
  519.     }
  520.     return inputs;
  521. }
  522.  
  523.  
  524. class NNGameBot {
  525.     clock_t start_time, max_ticks;
  526.     bool terminating;
  527.     NeuralNetwork predictor;
  528. public:
  529.     NNGameBot(double time, NeuralNetwork& N) {
  530.         start_time = clock();
  531.         max_ticks = clock_t(time * CLOCKS_PER_SEC);
  532.         terminating = false;
  533.         predictor = N;
  534.     }
  535.     move_t make_move(GameState& state) {
  536.         double max_score = +INF;
  537.         move_t res;
  538.         auto moves = state.get_moves();
  539.         for (auto move : moves) {
  540.             GameState new_state(state);
  541.             new_state.make_move(move);
  542.             double est = estimate(new_state);
  543.             if (est < max_score) {
  544.                 res = move;
  545.                 max_score = est;
  546.             }
  547.         }
  548.         return res;
  549.     }
  550.     bool time_check() {
  551.         if (clock() - start_time >= max_ticks) {
  552.             terminating = true;
  553.         }
  554.         return terminating;
  555.     }
  556.     double estimate(GameState& state) {
  557.         if (state.get_winner() != -1) return 0.0;
  558.         auto inputs = get_vector(state);
  559.         return predictor(inputs)[1];
  560.     }
  561.     void train(vector<double>& input, vector<double>& output) {
  562.         predictor.train(input, output);
  563.     }
  564.     void print_net(ostream& out) {
  565.         predictor.print(out);
  566.     }
  567. };
  568.  
  569. int main() {
  570.     constexpr int INPUT_SIZE = 3+8*13;
  571.     vector<int> sizes{INPUT_SIZE, 60, 40, 1};
  572.     vector<Matrix> w(3);
  573.     for (int i = 0; i < sizes[1]; ++i) {
  574.         std::uniform_real_distribution<double> unif{1e-6, 1e-5};
  575.         std::default_random_engine re(std::random_device{}());
  576.         if (i == 0) {
  577.             vector<double> t;
  578.             for (int j = 0; j < sizes[0]+1; ++j) {
  579.                 if (j == 2) t.push_back(+2.5);
  580.                 else if (j == 3) t.push_back(-2.5);
  581.                 else t.push_back(unif(re));
  582.             }
  583.             w[0].push_back(t);
  584.         }
  585.         else {
  586.             vector<double> t(sizes[0]+1, 0);
  587.             for (int i = 0; i < sizes[0]+1; ++i) {
  588.                 t[i] = unif(re);
  589.             }
  590.             w[0].push_back(t);
  591.         }
  592.     }
  593.     {
  594.         std::uniform_real_distribution<double> unif{1e-6, 1e-5};
  595.         std::default_random_engine re(std::random_device{}());
  596.         Matrix t1(sizes[2], vector<double>(sizes[1]+1, 0));
  597.         for (int i = 0; i < sizes[2]; ++i) {
  598.             for (int j = 0; j < sizes[1]+1; ++j) {
  599.                 t1[i][j] = unif(re);
  600.             }
  601.         }
  602.         t1[0][0] = -0.5; t1[0][1] = +1;
  603.         w[1] = t1;
  604.         assert(sizes[3] == 1);
  605.         Matrix t2(sizes[3], vector<double>(sizes[2]+1, 0));
  606.         for (int i = 0; i < sizes[2]+1; ++i) {
  607.             t2[0][i] = unif(re);
  608.         }
  609.         t2[0][0] = -0.5; t2[0][1] = +1;
  610.         w[2] = t2;
  611.     }
  612.     NNGameBot bot(0.96, N);
  613.     bot.print_net(cout);
  614.     for (int epoch = 0; epoch < 1000001; ++epoch) {
  615.         int board[SIZE][SIZE];
  616.         for (int i = 0; i < SIZE; ++i) {
  617.             for (int j = 0; j < SIZE; ++j) {
  618.                 board[i][j] = 0;
  619.             }
  620.         }
  621.         board[0][3] = board[0][6] = 2;
  622.         board[3][0] = board[3][9] = 2;
  623.         board[6][0] = board[6][9] = 1;
  624.         board[9][3] = board[9][6] = 1;
  625.         int turn = 1;
  626.         vector<pos_t> snd {
  627.             make_pair(0, 3),
  628.             make_pair(0, 6),
  629.             make_pair(3, 0),
  630.             make_pair(3, 9)
  631.         };
  632.         vector<pos_t> fst {
  633.             make_pair(6, 0),
  634.             make_pair(6, 9),
  635.             make_pair(9, 3),
  636.             make_pair(9, 6)
  637.         };
  638.         GameState state(board, turn, fst, snd);
  639.         std::uniform_real_distribution<double> unif{0, 1};
  640.         std::default_random_engine re(std::random_device{}());
  641.         double eps = 0.075 - epoch * 0.000001;
  642.         eps = max(0.0, eps);
  643.         if (epoch % 100 == 0) cout << "Epoch " << epoch << "\n";
  644.         while (state.get_winner() == -1) {
  645.             srand(time(NULL));
  646.             if (epoch % 100 == 0) {
  647.                 state.print(cout);
  648.                 cout << bot.estimate(state) << "\n";
  649.             }
  650.             move_t move;
  651.             bool rnd = unif(re) <= eps;
  652.             if (rnd) {
  653.                 auto moves = state.get_moves();
  654.                 move = moves[rand() % moves.size()];
  655.             }
  656.             else {
  657.                 move = bot.make_move(state);
  658.             }
  659.             GameState new_state(state);
  660.             new_state.make_move(move);
  661.             if (!rnd) {
  662.                 vector<double> new_est{1-bot.estimate(new_state)};
  663.                 vector<double> state_vector = get_vector(state);
  664.                 bot.train(state_vector, new_est);
  665.             }
  666.             state.make_move(move);
  667.         }
  668.         cout << "Final " << epoch << ":\n";
  669.         state.print(cout);
  670.         if (epoch % 100 == 0) {
  671.             ofstream fout("nn.txt");
  672.             bot.print_net(fout);
  673.             fout.flush();
  674.         }
  675.     }
  676.     return 0;
  677. }
Add Comment
Please, Sign In to add comment