Advertisement
Guest User

Untitled

a guest
Jun 26th, 2019
125
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 4.52 KB | None | 0 0
  1. package MOEA_GP;
  2.  
  3. import java.io.File;
  4. import java.io.FileNotFoundException;
  5. import java.io.FileOutputStream;
  6. import java.io.IOException;
  7. import java.io.ObjectOutputStream;
  8. import java.io.OutputStream;
  9. import java.util.Properties;
  10.  
  11. import org.jfree.data.xy.XYSeries;
  12. import org.jfree.data.xy.XYSeriesCollection;
  13. import org.jfree.chart.ChartFactory;
  14. import org.jfree.chart.ChartUtilities;
  15. import org.jfree.chart.JFreeChart;
  16. import org.jfree.chart.plot.PlotOrientation;
  17. import org.moeaframework.core.Solution;
  18. import org.moeaframework.util.tree.Node;
  19.  
  20. import core.ConfusionMatrix;
  21. import core.DataFrame;
  22. import core.GraphViz;
  23.  
  24. public class Test_Single_Model {
  25.  
  26.    
  27.     private DataFrame test_data ;
  28.     Run_MOGP train_mogp ;
  29.     Object[] true_test_labels ;
  30.    
  31.     public Test_Single_Model(DataFrame test_data,Run_MOGP train_mogp)
  32.     {
  33.         this.test_data = test_data ;
  34.         this.train_mogp = train_mogp ;
  35.         this.true_test_labels = test_data.get_column_data(this.train_mogp.topredict) ;
  36.     }
  37.    
  38.     public Boolean [] predict(Solution model,DataFrame data)
  39.     {
  40.        
  41.         Boolean [] test_predictions = Gp_try.compute_predictions(model,data.getData(),train_mogp.features.keySet()) ;
  42.         return test_predictions ;
  43.     }
  44.     public void save(String path)
  45.     {
  46.         int i = 0 ;
  47.         String model_path ;
  48.         OutputStream train_confusion_matrix ;
  49.         OutputStream test_confusion_matrix ;
  50.    
  51.         ConfusionMatrix train_matrix ;
  52.         ConfusionMatrix test_matrix ;
  53.         XYSeries train_series = new XYSeries("model train perforamnce Data");
  54.         XYSeries test_series = new XYSeries("model test perforamnce Data");
  55.         XYSeriesCollection data = new XYSeriesCollection() ;
  56.        
  57.         try
  58.         {
  59.             OutputStream chart_stream = new FileOutputStream(path+"/tpr_tnr.png");
  60.             ObjectOutputStream parameters = new ObjectOutputStream(new FileOutputStream(path+"/parameters.txt")) ;
  61.             parameters.writeChars(train_mogp.run_configs.toString());
  62.             parameters.close();
  63.             //System.out.println(train_mogp.solutions.size());
  64.             for (Solution sol : train_mogp.solutions)
  65.             {
  66.                 Properties Train_props = new Properties() ;
  67.                 Properties Test_props = new Properties() ;
  68.                 model_path = path+"/model" + ++i ;
  69.                 File dir = new File(model_path);
  70.                 dir.mkdir();
  71.                 train_confusion_matrix = new FileOutputStream(model_path+"/train_confusion_matrix.xml") ;
  72.                 test_confusion_matrix = new FileOutputStream(model_path+"/test_confusion_matrix.xml") ;
  73.                 Boolean[] train_predictions =  predict(sol,train_mogp.train_data) ;
  74.                 Boolean[] test_predictions =  predict(sol,this.test_data) ;
  75.                
  76.                 train_matrix = new ConfusionMatrix(train_mogp.true_train_labels,train_predictions) ;
  77.                 test_matrix = new ConfusionMatrix(true_test_labels,test_predictions) ;  
  78.                
  79.                 train_series.add(train_matrix.sensitivity(),train_matrix.specificity());
  80.                 test_series.add(test_matrix.sensitivity(),test_matrix.specificity());
  81.                
  82.                
  83.                 //train_confusion_matrix.writeChars(train_matrix.get_statistics().toString());
  84.                 //test_confusion_matrix.writeChars(test_matrix.get_statistics().toString());
  85.                
  86.                 GraphViz.createDotGraph(((Node)sol.getVariable(0)).getNodeAt(1).todot(1), model_path+"/tree");
  87.                
  88.  
  89.                 Train_props.putAll(train_matrix.get_statistics());
  90.                 Test_props.putAll(test_matrix.get_statistics());
  91.  
  92.                 //System.out.println(Test_props) ;
  93.                 //System.out.println(Train_props) ;
  94.                
  95.                 Train_props.storeToXML(train_confusion_matrix, "train performance");
  96.                 Test_props.storeToXML(test_confusion_matrix, "test performance");
  97.                
  98.                 train_confusion_matrix.close();
  99.                 test_confusion_matrix.close();
  100.                
  101.             }      
  102.             data.addSeries(train_series);
  103.             data.addSeries(test_series);
  104.             JFreeChart chart = ChartFactory.createScatterPlot(
  105.                     "Algorithm : "+this.train_mogp.run_configs.get("algorithm") +  " population size:" + this.train_mogp.run_configs.get("populationsize") + " max generation number: " + this.train_mogp.run_configs.get("maxgeneration") + " crossover rate: "+this.train_mogp.run_configs.get("crossover_rate") + " mutation rate: "+this.train_mogp.run_configs.get("mutation_rate") ,
  106.                     "TPR",
  107.                     "TNR",
  108.                     data,
  109.                     PlotOrientation.VERTICAL,
  110.                     true,
  111.                     true,
  112.                     false
  113.                      );
  114.             ChartUtilities.writeChartAsPNG(chart_stream,
  115.                     chart,
  116.                     500,
  117.                     300);
  118.  
  119.         }
  120.         catch (FileNotFoundException e) {
  121.             // TODO Auto-generated catch block
  122.             e.printStackTrace();
  123.         } catch (IOException e) {
  124.             // TODO Auto-generated catch block
  125.             e.printStackTrace();
  126.         }
  127.     }
  128. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement