Guest User

Untitled

a guest
May 12th, 2018
97
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 5 8.36 KB | None | 0 0
  1. package learn;
  2.  
  3. import smcmdp.*;
  4. //import smcmdp.SatResult;
  5. //import smcmdp.Satisfaction$;
  6. //import smcmdp.Trace;
  7. //import smcmdp.Parser;
  8.  
  9. import java.io.*;
  10. import java.util.concurrent.*;
  11. import java.util.concurrent.locks.*;
  12. import java.util.*;
  13.  
  14. import modelchecking.*;
  15.  
  16. import parser.*;
  17. import parser.ast.*;
  18. import prism.*;
  19. import umontreal.iro.lecuyer.probdist.*;
  20.  
  21. public class LearnMDP {
  22.   public static double DEFAULT_ALPHA = 0.5;
  23.   public static int MODELCHECK_BLOCK_SIZE = 100;
  24.  
  25.   private Prism prism; // One Prism to rule them all.
  26.  
  27.   private ModulesFile modulesFile;
  28.   private Formula formula;
  29.  
  30.   private int numJobs;
  31.   public int numWorkersWorking;
  32.   private List<TraceGeneratorThread> workers;
  33.  
  34.   private BlockingQueue<SatResult> results;
  35.  
  36.   private boolean deterministic;
  37.   private Policy policy;
  38.   private DeterministicPolicy deterministicPolicy;
  39.  
  40.   private Rewards rewards;
  41.  
  42.   private boolean done;
  43.  
  44.   private Lock jobLock;
  45.   private Condition jobsAvailable;
  46.   private Condition jobsDone;
  47.  
  48.  
  49.   public LearnMDP(Prism prism, ModulesFile modulesFile, Formula formula) throws Exception {
  50.     this.prism = prism; this.modulesFile = modulesFile; this.formula = formula;
  51.    
  52.     this.workers = new LinkedList<TraceGeneratorThread>();
  53.    
  54.     this.jobLock = new ReentrantLock();
  55.     this.jobsAvailable = this.jobLock.newCondition();
  56.     this.jobsDone      = this.jobLock.newCondition();
  57.    
  58.     this.deterministic = false;
  59.     this.done = false;
  60.   }
  61.    
  62.   public void startThreads(int numThreads, boolean reward) {
  63.     numWorkersWorking = numThreads;
  64.     for(int i = 0; i < numThreads; i++){
  65.       TraceGeneratorThread t = new TraceGeneratorThread(this);
  66.       workers.add(t);
  67.       new Thread(t).start();
  68.     }
  69.   }
  70.  
  71.   public void setThreadMode(boolean reward) {
  72.     if(reward) results = null;
  73.     else       results = new LinkedBlockingDeque<SatResult>();
  74.      
  75.     for(TraceGeneratorThread t: workers)
  76.       t.setResultQueue(results);
  77.   }
  78.  
  79.   public void stopThreads(){
  80.     jobLock.lock();
  81.     done = true;
  82.     jobsAvailable.signalAll();
  83.     workers.clear();
  84.     jobLock.unlock();
  85.   }
  86.  
  87.   public void learn(int numTraces, int numBlocks) throws Exception {
  88.     this.policy = new Policy();
  89.     this.rewards = new Rewards();
  90.  
  91.     for(int i = 0; i < numBlocks; i++){               // Run nBlocks blocks...
  92.       jobLock.lock();
  93.       addJobs(numTraces);
  94.       jobsDone.await();
  95.       while(numWorkersWorking > 0)
  96.         jobsDone.await();
  97.      
  98.       assert(numJobs == 0);
  99.       policy.update(rewards, DEFAULT_ALPHA);
  100.       this.rewards = new Rewards();
  101.      
  102.       jobLock.unlock();
  103.     }
  104.   }
  105.  
  106.   public EstimationResult IntervalEstimation(double alpha, double beta,
  107.       double delta, double coefficient) throws Exception {
  108.     EstimationResult r = new EstimationResult();
  109.    
  110.     int nTraces = 0;
  111.     int nSatisfied = 0;
  112.     double postProb = 0;
  113.     do {
  114.       //System.out.println("!"+numWorkersWorking+"!");
  115.       jobLock.lock();
  116.       if(numWorkersWorking == 0 && results.size() == 0){
  117.         assert(numJobs == 0);
  118.         System.out.print("a");
  119.         addJobs(MODELCHECK_BLOCK_SIZE);
  120.       }
  121.       jobLock.unlock();
  122.      
  123.       SatResult sat = results.take(); // Check whether it is satisfied
  124.       nTraces++;
  125.      
  126.       if(sat.getSat())
  127.         nSatisfied++;
  128.      
  129.       r.p = (nSatisfied + alpha) / (nTraces + alpha + beta);
  130.       r.t0 = r.p - delta;
  131.       r.t1 = r.p + delta;
  132.      
  133.       if (r.t1 > 1) {
  134.         r.t0 = 1 - 2 * delta;
  135.         r.t1 = 1;
  136.       } else if (r.t0 < 0) {
  137.         r.t0 = 0;
  138.         r.t1 = 2 * delta;
  139.       }
  140.      
  141.       BetaDist bd = new BetaDist(nSatisfied + alpha, nTraces - nSatisfied + beta);
  142.      
  143.       postProb = bd.cdf(r.t1) - bd.cdf(r.t0);
  144.      
  145.     } while (postProb < coefficient);
  146.     jobLock.lock();
  147.     numJobs = 0;
  148.     jobLock.unlock();
  149.    
  150.     r.n = nTraces;
  151.     r.nSat = nSatisfied;
  152.    
  153.     return r;
  154.   }
  155.  
  156.   public void lock() { jobLock.lock(); }
  157.  
  158.   public void unlock() { jobLock.unlock(); }
  159.  
  160.   public int requestJobs(int limit) throws Exception {
  161.       jobLock.lock();
  162.       numWorkersWorking--;
  163.       jobsDone.signal();
  164.       while(numJobs == 0 && !done){
  165.         jobsAvailable.await();
  166.         System.out.print("o");
  167.       }
  168.       if(done){
  169.         System.out.print("!");
  170.         jobsAvailable.signalAll();
  171.         jobLock.unlock();
  172.         return -1;
  173.       }
  174.  
  175.       int result = Math.min(numJobs, limit);
  176.       numJobs = numJobs - result;
  177.       numWorkersWorking++;
  178.       //if(numJobs > 0) jobsAvailable.signal(); // TODO Do we need this if we have signalAll in addJobs?
  179.       jobLock.unlock();
  180.       return result;
  181.   }
  182.  
  183.   /**
  184.    * Call only with lock already acquired!
  185.    * @param numJobs
  186.    */
  187.   public void addJobs(int numJobs) {
  188.     jobLock.lock();
  189.     assert(this.numJobs == 0);
  190.     this.numJobs = numJobs;
  191.     jobsAvailable.signalAll();
  192.     jobLock.unlock();
  193.   }
  194.  
  195.   public void calculateDeterministicPolicy() {
  196.     deterministicPolicy = new DeterministicPolicy();
  197.     for(Map.Entry<State, List<Double>> e: policy.getPolicy().entrySet()){
  198.       // Find index of max choice
  199.       int maxIndex = -1;
  200.       double max = Double.MIN_VALUE;
  201.       List<Double> l = e.getValue();
  202.       for(int i = 0; i < l.size(); i++){
  203.         if(l.get(i) < max) continue;
  204.         maxIndex = i;
  205.         max = l.get(i);
  206.       }
  207.      
  208.       // Use found max
  209.       deterministicPolicy.addDeterministicChoice(e.getKey(), maxIndex);
  210.     }
  211.   }
  212.  
  213.   public void setDeterministic(boolean deterministic) {
  214.     this.deterministic = deterministic;
  215.   }
  216.  
  217.   public boolean isDone() { return done; }
  218.  
  219.   public Prism getPrism() { return prism; }
  220.   public ModulesFile getModulesFile() { return modulesFile; }
  221.   public Formula getFormula() { return formula; }
  222.  
  223.   public Rewards getRewards() { return rewards; }
  224.   public Policy getPolicy() { return (deterministic) ? deterministicPolicy : policy; }
  225.  
  226.  
  227.   /**
  228.    * @param path File with the formula in it.
  229.    * @return A string with the actual formula.
  230.    * @throws Exception As usual.... BAD STUFF MAY HAPPEN.
  231.    */
  232.   public static String readFile(File path) throws Exception {
  233.     Scanner s = new Scanner(new FileInputStream(path));
  234.     String result = "";
  235.     while(s.hasNextLine())
  236.       result = result + (s.nextLine().trim());
  237.     return result;
  238.   }
  239.  
  240.   public static void main(String[] args) throws Exception {
  241.     if(args.length < 5) {
  242.       System.err.println("Usage:");
  243.       System.err.println("sh run.sh modules_file formula_file num_threads number_traces_per_block number_of_blocks");
  244.       System.exit(1);
  245.     }
  246.    
  247.     File modulesF = new File(args[0]);
  248.     if(!modulesF.exists()){
  249.       System.err.println("Modules file \""+ args[0] +"\" does not exist.");
  250.       System.exit(1);
  251.     }
  252.    
  253.     File formulaFile = new File(args[1]);
  254.     if(!formulaFile.exists()){
  255.       System.err.println("Formula file \""+ args[0] +"\" does not exist.");
  256.       System.exit(1);
  257.     }
  258.    
  259.     int numThreads = Integer.parseInt(args[2]);
  260.     int numTraces = Integer.parseInt(args[3]);
  261.     int numBlocks = Integer.parseInt(args[4]);
  262.    
  263.     if(numTraces < 1 || numBlocks < 1 || numThreads < 1) {
  264.       System.err.println("The number of threads, number traces per block and number of blocks must be a positive integer.");
  265.       System.exit(1);
  266.     }
  267.    
  268.     // Ignore Prism output
  269.     PrismLog ll = new PrismPrintStreamLog(new PrintStream(new OutputStream(){ public void write(int b) {} }));
  270.  
  271.     // Create and initialise Prism.
  272.     Prism p = new Prism(ll, ll);
  273.     p.initialise();
  274.  
  275.     ModulesFile modulesFile = p.parseModelFile(modulesF);                      // PRISM reads modules file
  276.    
  277.     Formula formula = Parser$.MODULE$.parseFormula(readFile(formulaFile)); // Read and parse formula
  278.     if(formula == null){
  279.       System.err.println("Could not parse formula.");
  280.       System.exit(1);
  281.     }
  282.    
  283.     LearnMDP lmdp = new LearnMDP(p, modulesFile, formula);
  284.    
  285.     lmdp.startThreads(numThreads, true);
  286.     lmdp.learn(numTraces, numBlocks);
  287.     lmdp.calculateDeterministicPolicy();
  288.     lmdp.setThreadMode(false);
  289.     lmdp.setDeterministic(true);
  290.     System.out.println("Interval Estimation!");
  291.     EstimationResult r = lmdp.IntervalEstimation(2, 3, 0.01, 0.95);
  292.     lmdp.stopThreads();
  293.    
  294.     System.out.println(r);
  295.   }
  296. }
Add Comment
Please, Sign In to add comment