Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package learn;
- import smcmdp.*;
- //import smcmdp.SatResult;
- //import smcmdp.Satisfaction$;
- //import smcmdp.Trace;
- //import smcmdp.Parser;
- import java.io.*;
- import java.util.concurrent.*;
- import java.util.concurrent.locks.*;
- import java.util.*;
- import modelchecking.*;
- import parser.*;
- import parser.ast.*;
- import prism.*;
- import umontreal.iro.lecuyer.probdist.*;
- public class LearnMDP {
- public static double DEFAULT_ALPHA = 0.5;
- public static int MODELCHECK_BLOCK_SIZE = 100;
- private Prism prism; // One Prism to rule them all.
- private ModulesFile modulesFile;
- private Formula formula;
- private int numJobs;
- public int numWorkersWorking;
- private List<TraceGeneratorThread> workers;
- private BlockingQueue<SatResult> results;
- private boolean deterministic;
- private Policy policy;
- private DeterministicPolicy deterministicPolicy;
- private Rewards rewards;
- private boolean done;
- private Lock jobLock;
- private Condition jobsAvailable;
- private Condition jobsDone;
- public LearnMDP(Prism prism, ModulesFile modulesFile, Formula formula) throws Exception {
- this.prism = prism; this.modulesFile = modulesFile; this.formula = formula;
- this.workers = new LinkedList<TraceGeneratorThread>();
- this.jobLock = new ReentrantLock();
- this.jobsAvailable = this.jobLock.newCondition();
- this.jobsDone = this.jobLock.newCondition();
- this.deterministic = false;
- this.done = false;
- }
- public void startThreads(int numThreads, boolean reward) {
- numWorkersWorking = numThreads;
- for(int i = 0; i < numThreads; i++){
- TraceGeneratorThread t = new TraceGeneratorThread(this);
- workers.add(t);
- new Thread(t).start();
- }
- }
- public void setThreadMode(boolean reward) {
- if(reward) results = null;
- else results = new LinkedBlockingDeque<SatResult>();
- for(TraceGeneratorThread t: workers)
- t.setResultQueue(results);
- }
- public void stopThreads(){
- jobLock.lock();
- done = true;
- jobsAvailable.signalAll();
- workers.clear();
- jobLock.unlock();
- }
- public void learn(int numTraces, int numBlocks) throws Exception {
- this.policy = new Policy();
- this.rewards = new Rewards();
- for(int i = 0; i < numBlocks; i++){ // Run nBlocks blocks...
- jobLock.lock();
- addJobs(numTraces);
- jobsDone.await();
- while(numWorkersWorking > 0)
- jobsDone.await();
- assert(numJobs == 0);
- policy.update(rewards, DEFAULT_ALPHA);
- this.rewards = new Rewards();
- jobLock.unlock();
- }
- }
- public EstimationResult IntervalEstimation(double alpha, double beta,
- double delta, double coefficient) throws Exception {
- EstimationResult r = new EstimationResult();
- int nTraces = 0;
- int nSatisfied = 0;
- double postProb = 0;
- do {
- //System.out.println("!"+numWorkersWorking+"!");
- jobLock.lock();
- if(numWorkersWorking == 0 && results.size() == 0){
- assert(numJobs == 0);
- System.out.print("a");
- addJobs(MODELCHECK_BLOCK_SIZE);
- }
- jobLock.unlock();
- SatResult sat = results.take(); // Check whether it is satisfied
- nTraces++;
- if(sat.getSat())
- nSatisfied++;
- r.p = (nSatisfied + alpha) / (nTraces + alpha + beta);
- r.t0 = r.p - delta;
- r.t1 = r.p + delta;
- if (r.t1 > 1) {
- r.t0 = 1 - 2 * delta;
- r.t1 = 1;
- } else if (r.t0 < 0) {
- r.t0 = 0;
- r.t1 = 2 * delta;
- }
- BetaDist bd = new BetaDist(nSatisfied + alpha, nTraces - nSatisfied + beta);
- postProb = bd.cdf(r.t1) - bd.cdf(r.t0);
- } while (postProb < coefficient);
- jobLock.lock();
- numJobs = 0;
- jobLock.unlock();
- r.n = nTraces;
- r.nSat = nSatisfied;
- return r;
- }
- public void lock() { jobLock.lock(); }
- public void unlock() { jobLock.unlock(); }
- public int requestJobs(int limit) throws Exception {
- jobLock.lock();
- numWorkersWorking--;
- jobsDone.signal();
- while(numJobs == 0 && !done){
- jobsAvailable.await();
- System.out.print("o");
- }
- if(done){
- System.out.print("!");
- jobsAvailable.signalAll();
- jobLock.unlock();
- return -1;
- }
- int result = Math.min(numJobs, limit);
- numJobs = numJobs - result;
- numWorkersWorking++;
- //if(numJobs > 0) jobsAvailable.signal(); // TODO Do we need this if we have signalAll in addJobs?
- jobLock.unlock();
- return result;
- }
- /**
- * Call only with lock already acquired!
- * @param numJobs
- */
- public void addJobs(int numJobs) {
- jobLock.lock();
- assert(this.numJobs == 0);
- this.numJobs = numJobs;
- jobsAvailable.signalAll();
- jobLock.unlock();
- }
- public void calculateDeterministicPolicy() {
- deterministicPolicy = new DeterministicPolicy();
- for(Map.Entry<State, List<Double>> e: policy.getPolicy().entrySet()){
- // Find index of max choice
- int maxIndex = -1;
- double max = Double.MIN_VALUE;
- List<Double> l = e.getValue();
- for(int i = 0; i < l.size(); i++){
- if(l.get(i) < max) continue;
- maxIndex = i;
- max = l.get(i);
- }
- // Use found max
- deterministicPolicy.addDeterministicChoice(e.getKey(), maxIndex);
- }
- }
- public void setDeterministic(boolean deterministic) {
- this.deterministic = deterministic;
- }
- public boolean isDone() { return done; }
- public Prism getPrism() { return prism; }
- public ModulesFile getModulesFile() { return modulesFile; }
- public Formula getFormula() { return formula; }
- public Rewards getRewards() { return rewards; }
- public Policy getPolicy() { return (deterministic) ? deterministicPolicy : policy; }
- /**
- * @param path File with the formula in it.
- * @return A string with the actual formula.
- * @throws Exception As usual.... BAD STUFF MAY HAPPEN.
- */
- public static String readFile(File path) throws Exception {
- Scanner s = new Scanner(new FileInputStream(path));
- String result = "";
- while(s.hasNextLine())
- result = result + (s.nextLine().trim());
- return result;
- }
- public static void main(String[] args) throws Exception {
- if(args.length < 5) {
- System.err.println("Usage:");
- System.err.println("sh run.sh modules_file formula_file num_threads number_traces_per_block number_of_blocks");
- System.exit(1);
- }
- File modulesF = new File(args[0]);
- if(!modulesF.exists()){
- System.err.println("Modules file \""+ args[0] +"\" does not exist.");
- System.exit(1);
- }
- File formulaFile = new File(args[1]);
- if(!formulaFile.exists()){
- System.err.println("Formula file \""+ args[0] +"\" does not exist.");
- System.exit(1);
- }
- int numThreads = Integer.parseInt(args[2]);
- int numTraces = Integer.parseInt(args[3]);
- int numBlocks = Integer.parseInt(args[4]);
- if(numTraces < 1 || numBlocks < 1 || numThreads < 1) {
- System.err.println("The number of threads, number traces per block and number of blocks must be a positive integer.");
- System.exit(1);
- }
- // Ignore Prism output
- PrismLog ll = new PrismPrintStreamLog(new PrintStream(new OutputStream(){ public void write(int b) {} }));
- // Create and initialise Prism.
- Prism p = new Prism(ll, ll);
- p.initialise();
- ModulesFile modulesFile = p.parseModelFile(modulesF); // PRISM reads modules file
- Formula formula = Parser$.MODULE$.parseFormula(readFile(formulaFile)); // Read and parse formula
- if(formula == null){
- System.err.println("Could not parse formula.");
- System.exit(1);
- }
- LearnMDP lmdp = new LearnMDP(p, modulesFile, formula);
- lmdp.startThreads(numThreads, true);
- lmdp.learn(numTraces, numBlocks);
- lmdp.calculateDeterministicPolicy();
- lmdp.setThreadMode(false);
- lmdp.setDeterministic(true);
- System.out.println("Interval Estimation!");
- EstimationResult r = lmdp.IntervalEstimation(2, 3, 0.01, 0.95);
- lmdp.stopThreads();
- System.out.println(r);
- }
- }
Add Comment
Please, Sign In to add comment