Advertisement
Guest User

Untitled

a guest
Jun 22nd, 2017
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 6.70 KB | None | 0 0
  1. package ai;
  2.  
  3. import java.util.ArrayList;
  4. import java.util.Collections;
  5.  
  6. import all.continuous.*;
  7.  
  8. public class MCTS extends ModuleAlgorithm{
  9.  
  10.     //SETTINGS
  11.     private final double GREEDY_SIMULATION_CHANCE = 0.6;
  12.     private final int MAX_ITERATIONS = 10000;
  13.     private final int MINIMUM_VISITS = 20;
  14.     private final double EXPLORATION = Math.sqrt(2);
  15.     private final int SIMULATION_DEPTH = 3;
  16.  
  17.     private final boolean VERBOSE_DEBUG = false;
  18.  
  19.     ArrayList<Action> path = new ArrayList<Action>();
  20.     ArrayList<MCTSNode> nodePath = new ArrayList<MCTSNode>();
  21.  
  22.     private static int turnCounter=0;
  23.     private static int height=0;
  24.  
  25.     private boolean continueLooping = true;
  26.     private int iterationCounter = 0;
  27.  
  28.     private MCTSNode finalNode = null;
  29.  
  30.     public MCTS(Simulation sim) {
  31.         super(sim);
  32.         System.out.println("Initializing MCTS");
  33.     }
  34.  
  35.     public void mainMCTS(Simulation sim){
  36.         MCTSNode root = new MCTSNode(sim.getCurrentConfiguration());
  37.         root.addVisit();
  38.         expand(root);
  39.         for(int i = 0; i < root.getChildren().size(); i++) simulate(root.getChildren().get(i));
  40.  
  41.         //Build tree
  42.         while(continueLooping){
  43.             if(iterationCounter==MAX_ITERATIONS) continueLooping = false;
  44.  
  45.             if(iterationCounter%500==0) System.out.println("MCTS iteration: "+iterationCounter);
  46.             iterationCounter++;
  47.  
  48.             MCTSNode workingNode = root;
  49.  
  50.             while(workingNode.children.size() != 0) workingNode = select(workingNode);
  51.  
  52.             if(workingNode.getVisits() >= MINIMUM_VISITS) {
  53.                 expand(workingNode);
  54.  
  55.                 int childID = (int) (workingNode.getChildren().size() * Math.random());
  56.  
  57.                 workingNode = workingNode.getChildren().get(childID);
  58.             }
  59.  
  60.             double score = simulate(workingNode);
  61.  
  62.             backPropagate(score, workingNode);
  63.         }
  64.  
  65.         //Construct best path
  66.  
  67.         int i = 1;
  68.         if(finalNode!=null){
  69.             System.out.println(" Reconstructing path that leads to goal config");
  70.  
  71.             MCTSNode workingNode = finalNode;
  72.  
  73.             while(workingNode.getParent() != null){
  74.                 nodePath.add(workingNode);
  75.                 workingNode = workingNode.getParent();
  76.             }
  77.  
  78.             Collections.reverse(nodePath);
  79.  
  80.             for (MCTSNode node: nodePath) {
  81.                 path.add(node.getAction());
  82.  
  83.                 if(VERBOSE_DEBUG) System.out.println("  Frame " + i + ": " + estimateScore(node.getConfiguration()));
  84.                 i++;
  85.             }
  86.         } else {
  87.             System.out.println(" Reconstructing best path");
  88.             while(root.getChildren().size()>0){
  89.                 MCTSNode next = bestValueChild(root);
  90.  
  91.                 if(VERBOSE_DEBUG) System.out.println("  Frame " + i + ": " + estimateScore(next.getConfiguration()));
  92.                 i++;
  93.  
  94.                 nodePath.add(next);
  95.  
  96.                 Action a = next.getAction();
  97.                 path.add(a);
  98.  
  99.                 root = next;
  100.             }
  101.         }
  102.     }
  103.  
  104.     private void backPropagate(double score, MCTSNode workingNode) {
  105.         while(true){
  106.             workingNode.addScore(score);
  107.             if(workingNode.getParent() != null) workingNode = workingNode.getParent();
  108.             else break;
  109.         }
  110.     }
  111.  
  112.     public MCTSNode bestValueChild(MCTSNode parent){
  113.         ArrayList<MCTSNode> children = parent.getChildren();
  114.  
  115.         double min = Double.MAX_VALUE;
  116.         MCTSNode best = null;
  117.  
  118.         for (MCTSNode child: children) {
  119.             if(child.getAverageScore()<min){ // && child.getScore()!=Integer.MIN_VALUE
  120.                 min = child.getAverageScore();
  121.                 best = child;
  122.             }
  123.  
  124.         }
  125.  
  126.         return best;
  127.     }
  128.  
  129.     public double selectPolicy(MCTSNode node){
  130.         double selectScore = node.getAverageScore() - EXPLORATION*Math.sqrt(Math.log(node.getParent().getVisits())/node.getVisits());
  131.  
  132.         return selectScore;
  133.     }
  134.  
  135.     public MCTSNode select(MCTSNode origin){
  136.         double min = Double.MAX_VALUE;
  137.         MCTSNode minNode = null;
  138.  
  139.         for (MCTSNode child: origin.getChildren()) {
  140.             if(child.getVisits() < MINIMUM_VISITS) return child;
  141.  
  142.             double selectScore = selectPolicy(child);
  143.  
  144.             if(selectScore < min){
  145.                 min = selectScore;
  146.                 minNode = child;
  147.             }
  148.         }
  149.  
  150.         return minNode;
  151.     }
  152.  
  153.     public void expand(MCTSNode origin){
  154.         ArrayList<Action> validActions = origin.getConfiguration().getAllValidActions();
  155.  
  156.         if (origin.getConfiguration().equals(sim.getGoalConfiguration())) {
  157.             System.out.println("Found goal config!");
  158.             continueLooping = false;
  159.             finalNode = origin;
  160.         } else {
  161.             for (Action action:validActions) {
  162.                 Configuration configCopy = origin.getConfiguration().copy();
  163.                 configCopy.apply(action);
  164.  
  165.                 MCTSNode child = new MCTSNode(configCopy);
  166.                 child.setAction(action);
  167.  
  168.                 origin.addChild(child);
  169.  
  170.                 if(isSameAsAParent(origin)) origin.getParent().getChildren().remove(origin);
  171.             }
  172.         }
  173.     }
  174.  
  175.     public double simulate(MCTSNode origin) {
  176.         Configuration currentConfig = origin.getConfiguration();
  177.         if (currentConfig.equals(sim.getGoalConfiguration())) {
  178.             System.out.println("Found goal config!");
  179.             continueLooping = false;
  180.             finalNode = origin;
  181.         } else {
  182.  
  183.             int moveCounter = 0;
  184.  
  185.             while (moveCounter < SIMULATION_DEPTH) {
  186.                 moveCounter++;
  187.  
  188.                 Configuration nextConfig = currentConfig.copy();
  189.                 currentConfig = nextConfig;
  190.  
  191.                 ArrayList<Action> validActions = currentConfig.getAllValidActions();
  192.  
  193.                 double chance = Math.random();
  194.  
  195.                 if (chance > GREEDY_SIMULATION_CHANCE) { //random
  196.                     int size = validActions.size();
  197.                     int random = (int) (Math.random() * size);
  198.  
  199.                     currentConfig.apply(validActions.get(random));
  200.                 } else { //greedy
  201.                     double bestScore = Integer.MAX_VALUE;
  202.                     Action bestAction = null;
  203.                     for (Action action : validActions) {
  204.                         Configuration testConfig = currentConfig.copy();
  205.                         testConfig.apply(action);
  206.  
  207.                         double currentScore = estimateScore(testConfig);
  208.  
  209.                         if (currentScore < bestScore) {
  210.                             bestAction = action;
  211.                             bestScore = currentScore;
  212.                         }
  213.                     }
  214.  
  215.                     currentConfig.apply(bestAction);
  216.                 }
  217.             }
  218.         }
  219.         return estimateScore(currentConfig);
  220.     }
  221.  
  222.     public double estimateScore(Configuration config){
  223.         ArrayList<Agent> agents = config.getAgents();
  224.         ArrayList<Agent> goals = config.getSimulation().getGoalConfiguration().getAgents();
  225.         double totalManhattanDistance = 0;
  226.  
  227.         for(int i = 0; i < agents.size() ; i++) totalManhattanDistance += Math.pow(agents.get(i).getManhattanDistanceTo(goals.get(i).getLocation()),2);
  228.  
  229.         totalManhattanDistance = totalManhattanDistance/agents.size();
  230.  
  231.         return totalManhattanDistance;
  232.     }
  233.  
  234.     public boolean isSameAsAParent(MCTSNode startingNode){
  235.         if(startingNode.getParent() == null) return false;
  236.  
  237.         MCTSNode workingNode = startingNode.getParent();
  238.  
  239.         while(true){
  240.             if(workingNode.getConfiguration().equals(startingNode.getConfiguration())) return true;
  241.  
  242.             if(workingNode.getParent() == null) return false;
  243.             else workingNode = workingNode.getParent();
  244.         }
  245.     }
  246.  
  247.     @Override
  248.     public void takeTurn() {
  249.         if(turnCounter==0){
  250.             mainMCTS(sim);
  251.         }
  252.  
  253.         if(turnCounter < path.size()){
  254.             sim.apply(path.get(turnCounter));
  255.             turnCounter++;
  256.         } else sim.finish();
  257.     }
  258. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement