Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using System.Collections;
- using System.Collections.Generic;
- using UnityEngine;
- using System.Linq;
/// <summary>
/// Monte Carlo Tree Search over a nested 3x3 board. A board is a grid of
/// <c>Section</c>s; an occupied (X/O) section can be "challenged", which turns it
/// into a <c>States.Board</c> section holding a fresh 3x3 grid of its own.
/// Bookkeeping for the whole position (last move, challenge flag) lives in the
/// <c>metaData</c> of the top-level cell [0, 0].
/// </summary>
public class MCTS : MonoBehaviour {
    // Scores returned by analyzeState. They are whole numbers on purpose:
    // input() switches on them after an int cast.
    const float WIN_POINTS = 2.0f;
    const float TIE_POINTS = -1.0f;
    const int BOARD_SIZE = 3;

    private static readonly System.Random rnd = new System.Random();

    // Counts how many moves were placed inside a nested board (incremented by input()).
    int t = 0;

    // NOTE: the empty Start()/Update() callbacks were removed; Unity still
    // dispatches empty message methods, so they only added per-frame overhead.

    /// <summary>Returns the mark a player places on the board.</summary>
    private static States markFor(Turn turn) {
        return (turn == Turn.X) ? States.X : States.O;
    }

    /// <summary>Returns the player who moves after <paramref name="turn"/>.</summary>
    private static Turn opponentOf(Turn turn) {
        return (turn == Turn.X) ? Turn.O : Turn.X;
    }

    /// <summary>Clones every top-level section of <paramref name="source"/> into a new 3x3 grid.</summary>
    private static Section[,] cloneBoard(Section[,] source) {
        Section[,] copy = new Section[BOARD_SIZE, BOARD_SIZE];
        for(int row = 0; row < BOARD_SIZE; row++) {
            for(int col = 0; col < BOARD_SIZE; col++) {
                copy[row, col] = GameManager.cloneSection(source[row, col]);
            }
        }
        return copy;
    }

    /// <summary>Builds a 3x3 grid filled with default-constructed sections.</summary>
    private static Section[,] createEmptyBoard() {
        Section[,] grid = new Section[BOARD_SIZE, BOARD_SIZE];
        for(int row = 0; row < BOARD_SIZE; row++) {
            for(int col = 0; col < BOARD_SIZE; col++) {
                grid[row, col] = new Section();
            }
        }
        return grid;
    }

    /// <summary>
    /// Applies <paramref name="move"/> for <paramref name="turn"/> to a copy of
    /// <paramref name="original"/> and returns the copy.
    /// Playing on an Empty cell places the player's mark (and may resolve the
    /// enclosing nested board); playing on an X/O cell challenges it, replacing
    /// it with a fresh nested board. Any other target state leaves the copy as is.
    /// </summary>
    /// <param name="original">Position to start from; its sections are cloned, not edited.</param>
    /// <param name="move">Path of cell coordinates from the top-level board down to the target cell.</param>
    /// <param name="turn">Player making the move.</param>
    /// <returns>The new top-level board with updated metadata on cell [0, 0].</returns>
    Section[,] input(Section[,] original, Vector2[] move, Turn turn) {
        Section[,] newBoard = cloneBoard(original);

        // Same lazy initialisation generateMoves() performs, so a clone without
        // metadata cannot cause a NullReferenceException here.
        if(newBoard[0, 0].metaData == null)
            newBoard[0, 0].metaData = new MetaData();
        MetaData meta = newBoard[0, 0].metaData;
        meta.isBeingChallenged = false;
        meta.lastMove = move;

        Section target = getSection(newBoard, move);
        switch(getSection(original, move).state) {
            case States.Empty:
                target.state = markFor(turn);
                if(move.Length > 1) {
                    t++;
                    resolveParentSection(newBoard, move, turn);
                }
                break;
            case States.X:
            case States.O:
                // Challenge: the occupied cell becomes a nested board to be played out.
                target.state = States.Board;
                target.sectionsInsideSection = createEmptyBoard();
                meta.isBeingChallenged = true;
                break;
        }
        return newBoard;
    }

    /// <summary>
    /// After a move inside a nested board, collapses the enclosing section if that
    /// nested board is finished: won boards take the winner's mark, tied boards go
    /// back to Empty. Unfinished boards are left untouched.
    /// </summary>
    void resolveParentSection(Section[,] board, Vector2[] move, Turn turn) {
        Section parent = getSection(board, GameManager.dropASection(move));
        States resolved;
        switch((int)analyzeState(parent.sectionsInsideSection, turn)) {
            case (int)WIN_POINTS:
                resolved = markFor(turn);
                break;
            case (int)-WIN_POINTS:
                resolved = markFor(opponentOf(turn));
                break;
            case (int)TIE_POINTS:
                resolved = States.Empty;
                break;
            default:
                // Nested board still in progress.
                return;
        }
        parent.state = resolved;
        parent.sectionsInsideSection = null;
    }

    /// <summary>
    /// Enumerates the positions reachable in one move by <paramref name="turn"/>,
    /// searching the section addressed by <paramref name="baseVector"/> and any
    /// nested boards below it. While a challenge is pending, a non-recursive call
    /// is redirected to the challenged section (the last move).
    /// If a move wins the top-level board, the list returned by that call contains
    /// only that move.
    /// </summary>
    /// <param name="board">Top-level board; cell [0, 0] receives a MetaData instance if it has none.</param>
    /// <param name="baseVector">Path to the section whose cells are examined (empty = top level).</param>
    /// <param name="turn">Player to move.</param>
    /// <param name="recursive">True when the call should search <paramref name="baseVector"/> even though a challenge is pending.</param>
    /// <returns>One resulting top-level board per legal move.</returns>
    public List<Section[,]> generateMoves(Section[,] board, Vector2[] baseVector, Turn turn, bool recursive = false) {
        if(board[0, 0].metaData == null)
            board[0, 0].metaData = new MetaData();
        MetaData meta = board[0, 0].metaData;

        // A pending challenge restricts play to the challenged section.
        if(meta.isBeingChallenged && !recursive)
            return generateMoves(board, meta.lastMove, turn, true);

        List<Section[,]> moves = new List<Section[,]>();
        Section baseSection = getSection(board, baseVector);
        bool canNestDeeper = baseVector.Length + 1 < GameManager.MAXLAYER;

        for(int i = 0; i < BOARD_SIZE; i++) {
            for(int j = 0; j < BOARD_SIZE; j++) {
                States state = baseSection.sectionsInsideSection[i, j].state;
                Vector2[] cellPath = GameManager.addASection(baseVector, new Vector2(i, j));

                if(state == States.Board) {
                    // Forward the flag: without it, a nested call made during a
                    // challenge would be redirected back here and never terminate.
                    moves.AddRange(generateMoves(board, cellPath, turn, recursive));
                    continue;
                }

                bool playable = state == States.Empty;
                // Only the cell that was just played may be challenged, and only
                // while another layer of nesting is allowed.
                bool challengeable = (state == States.X || state == States.O)
                    && canNestDeeper
                    && meta.lastMove != null
                    && meta.lastMove.SequenceEqual(cellPath);
                if(!playable && !challengeable)
                    continue;

                Section[,] move = input(board, cellPath, turn);
                //If there is a move that will win the player the game, only return that move
                if(analyzeState(move, turn) == WIN_POINTS) {
                    moves.Clear();
                    moves.Add(move);
                    return moves;
                }
                moves.Add(move);
            }
        }
        return moves;
    }

    /// <summary>
    /// Scores one line of three cells: +WIN_POINTS / -WIN_POINTS when X or O owns
    /// the whole line (sign depends on <paramref name="perspective"/>), otherwise 0.
    /// X is tested before O.
    /// </summary>
    private static float scoreLine(States a, States b, States c, Turn perspective) {
        if(GameManager.threequal(a, b, c, States.X))
            return (perspective == Turn.X) ? WIN_POINTS : -WIN_POINTS;
        if(GameManager.threequal(a, b, c, States.O))
            return (perspective == Turn.O) ? WIN_POINTS : -WIN_POINTS;
        return 0;
    }

    /// <summary>
    /// Evaluates a single 3x3 grid (nested boards are not looked into).
    /// </summary>
    /// <param name="board">Grid to evaluate.</param>
    /// <param name="perspective">Player whose point of view the score is expressed in.</param>
    /// <returns>
    /// WIN_POINTS if <paramref name="perspective"/> has three in a row, -WIN_POINTS if
    /// the other mark does, 0 if any cell is still Empty or a Board, and -1 for a tie.
    /// </returns>
    public float analyzeState(Section[,] board, Turn perspective) {
        float score;

        // Row i, then column i, for each index.
        for(int i = 0; i < BOARD_SIZE; i++) {
            score = scoreLine(board[i, 0].state, board[i, 1].state, board[i, 2].state, perspective);
            if(score != 0)
                return score;
            score = scoreLine(board[0, i].state, board[1, i].state, board[2, i].state, perspective);
            if(score != 0)
                return score;
        }

        // Both diagonals.
        score = scoreLine(board[0, 0].state, board[1, 1].state, board[2, 2].state, perspective);
        if(score != 0)
            return score;
        score = scoreLine(board[0, 2].state, board[1, 1].state, board[2, 0].state, perspective);
        if(score != 0)
            return score;

        //If not tie, return 0 points
        foreach(Section cell in board) {
            if(cell.state == States.Empty || cell.state == States.Board)
                return 0;
        }
        //If tie, return -1 points
        return TIE_POINTS;
    }

    /// <summary>
    /// Runs <paramref name="iterations"/> rounds of select / expand / simulate /
    /// back-propagate from <paramref name="tree"/> and returns the board of the root
    /// child with the highest average value.
    /// </summary>
    /// <param name="tree">Root node; its <c>turn</c> is the player being maximised.</param>
    /// <param name="iterations">Number of search rounds to run.</param>
    /// <returns>The chosen child's board, or null when the root has no children.</returns>
    public Section[,] mcts(State tree, int iterations) {
        for(int i = 0; i < iterations; i++) {
            //Selection
            State selection = select(tree);
            //Expansion
            State node = expand(selection);
            //Simulation
            float value = simulate(node, tree.turn);
            //Back-propagation
            backPropogate(node, value);
        }

        //Find move with highest score
        float best = float.MinValue;
        Section[,] bestMove = null;
        foreach(State child in tree.getChildren()) {
            float average = child.totalVal / child.visits;
            if(average > best) {
                best = average;
                bestMove = child.state;
            }
        }
        return bestMove;
    }

    /// <summary>
    /// Walks down from <paramref name="tree"/>, at each level following the child
    /// with the highest UCB1 (first one wins ties), until a node without children
    /// is reached.
    /// </summary>
    State select(State tree) {
        State current = tree;
        while(current.getChildren().Count > 0) {
            float best = float.MinValue;
            State favourite = null;
            foreach(State child in current.getChildren()) {
                float ucb = child.getUCB1();
                if(ucb > best) {
                    best = ucb;
                    favourite = child;
                }
            }
            current = favourite;
        }
        return current;
    }

    /// <summary>
    /// Adds one child per legal move to <paramref name="node"/> and returns a
    /// uniformly random child. Finished positions, and positions for which no move
    /// is generated, are returned unchanged.
    /// </summary>
    State expand(State node) {
        //If is an end state just return it
        if(analyzeState(node.state, node.turn) != 0)
            return node;

        //Add every possible move as a child to the node
        List<Section[,]> moves = generateMoves(node.state, new Vector2[0], node.turn);
        if(moves.Count == 0)
            return node;

        Turn nextTurn = opponentOf(node.turn);
        foreach(Section[,] move in moves) {
            node.addChild(new State(move, node, nextTurn));
        }

        // Random.Next(n) is always < n. The previous float-scaled NextDouble()
        // could round up to n and index past the end of the list.
        return node.getChild(rnd.Next(node.getChildren().Count));
    }

    /// <summary>
    /// Plays uniformly random moves from <paramref name="node"/> until the top-level
    /// board is decided (or no move is available) and returns the final score from
    /// <paramref name="turnMax"/>'s perspective.
    /// </summary>
    float simulate(State node, Turn turnMax) {
        Section[,] currentState = cloneBoard(node.state);
        Turn currentTurn = node.turn;

        float value = analyzeState(currentState, turnMax);
        //While the simulation is not in an end state
        while(value == 0) {
            //Pick a random move from possible moves and go from there
            List<Section[,]> moves = generateMoves(currentState, new Vector2[0], currentTurn);
            if(moves.Count == 0)
                break;
            currentState = moves[rnd.Next(moves.Count)];
            currentTurn = opponentOf(currentTurn);
            value = analyzeState(currentState, turnMax);
        }
        return value;
    }

    /// <summary>
    /// Adds <paramref name="value"/> to <paramref name="node"/> and every ancestor up
    /// to the root, incrementing each node's visit count.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the same value (root player's perspective) is credited at every
    /// depth, so select() also maximises the root player's score at opponent nodes.
    /// </remarks>
    void backPropogate(State node, float value) {
        for(State current = node; current != null; current = current.parent) {
            current.totalVal += value;
            current.visits++;
        }
    }

    /// <summary>
    /// Resolves a path of cell coordinates to the section it addresses.
    /// </summary>
    /// <param name="sections">Top-level grid the path starts from.</param>
    /// <param name="sects">Cell coordinates, outermost first. Components are truncated to int.</param>
    /// <returns>
    /// The addressed section. For an empty path, a new wrapper section whose
    /// <c>sectionsInsideSection</c> is <paramref name="sections"/>.
    /// </returns>
    public Section getSection(Section[,] sections, Vector2[] sects) {
        if(sects.Length == 0) {
            return new Section { sectionsInsideSection = sections };
        }

        Section current = sections[(int)sects[0].x, (int)sects[0].y];
        foreach(Vector2 step in sects.Skip(1)) {
            current = current.sectionsInsideSection[(int)step.x, (int)step.y];
        }
        return current;
    }
}
/// <summary>
/// A node of the search tree: a board position, the player to move, the value
/// accumulated through it and links to its parent and children.
/// </summary>
public class State {
    /// <summary>Board position this node represents.</summary>
    public Section[,] state;

    /// <summary>Node this one was expanded from; null for the root.</summary>
    public State parent;

    /// <summary>Turn value stored for this node.</summary>
    public Turn turn;

    /// <summary>Sum of all values added to this node.</summary>
    public float totalVal = 0;

    /// <summary>Visit counter. Starts at a tiny non-zero value, so the divisions below never divide by zero.</summary>
    public float visits = 0.0000001f;

    private readonly List<State> children = new List<State>();

    public State(Section[,] state, State parent, Turn turn) {
        this.turn = turn;
        this.parent = parent;
        this.state = state;
    }

    /// <summary>Returns the live list of child nodes (not a copy).</summary>
    public List<State> getChildren() {
        return children;
    }

    /// <summary>Returns the child at <paramref name="index"/>.</summary>
    public State getChild(int index) {
        return children[index];
    }

    /// <summary>Appends <paramref name="child"/> to this node's children.</summary>
    public void addChild(State child) {
        children.Add(child);
    }

    /// <summary>
    /// UCB1 score: average value plus an exploration bonus based on the parent's
    /// visit count. Reads <c>parent.visits</c>, so the node must have a parent.
    /// </summary>
    public float getUCB1() {
        float averageValue = totalVal / visits;
        float explorationBonus = Mathf.Sqrt((2 * Mathf.Log(parent.visits)) / visits);
        return averageValue + explorationBonus;
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement