Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package MOEA_GP;
- import java.io.File;
- import java.io.FileNotFoundException;
- import java.io.FileOutputStream;
- import java.io.IOException;
- import java.io.ObjectOutputStream;
- import java.io.OutputStream;
- import java.util.Properties;
- import org.jfree.data.xy.XYSeries;
- import org.jfree.data.xy.XYSeriesCollection;
- import org.jfree.chart.ChartFactory;
- import org.jfree.chart.ChartUtilities;
- import org.jfree.chart.JFreeChart;
- import org.jfree.chart.plot.PlotOrientation;
- import org.moeaframework.core.Solution;
- import org.moeaframework.util.tree.Node;
- import core.ConfusionMatrix;
- import core.DataFrame;
- import core.GraphViz;
- public class Test_Single_Model {
- private DataFrame test_data ;
- Run_MOGP train_mogp ;
- Object[] true_test_labels ;
- public Test_Single_Model(DataFrame test_data,Run_MOGP train_mogp)
- {
- this.test_data = test_data ;
- this.train_mogp = train_mogp ;
- this.true_test_labels = test_data.get_column_data(this.train_mogp.topredict) ;
- }
- public Boolean [] predict(Solution model,DataFrame data)
- {
- Boolean [] test_predictions = Gp_try.compute_predictions(model,data.getData(),train_mogp.features.keySet()) ;
- return test_predictions ;
- }
- public void save(String path)
- {
- int i = 0 ;
- String model_path ;
- OutputStream train_confusion_matrix ;
- OutputStream test_confusion_matrix ;
- ConfusionMatrix train_matrix ;
- ConfusionMatrix test_matrix ;
- XYSeries train_series = new XYSeries("model train perforamnce Data");
- XYSeries test_series = new XYSeries("model test perforamnce Data");
- XYSeriesCollection data = new XYSeriesCollection() ;
- try
- {
- OutputStream chart_stream = new FileOutputStream(path+"/tpr_tnr.png");
- ObjectOutputStream parameters = new ObjectOutputStream(new FileOutputStream(path+"/parameters.txt")) ;
- parameters.writeChars(train_mogp.run_configs.toString());
- parameters.close();
- //System.out.println(train_mogp.solutions.size());
- for (Solution sol : train_mogp.solutions)
- {
- Properties Train_props = new Properties() ;
- Properties Test_props = new Properties() ;
- model_path = path+"/model" + ++i ;
- File dir = new File(model_path);
- dir.mkdir();
- train_confusion_matrix = new FileOutputStream(model_path+"/train_confusion_matrix.xml") ;
- test_confusion_matrix = new FileOutputStream(model_path+"/test_confusion_matrix.xml") ;
- Boolean[] train_predictions = predict(sol,train_mogp.train_data) ;
- Boolean[] test_predictions = predict(sol,this.test_data) ;
- train_matrix = new ConfusionMatrix(train_mogp.true_train_labels,train_predictions) ;
- test_matrix = new ConfusionMatrix(true_test_labels,test_predictions) ;
- train_series.add(train_matrix.sensitivity(),train_matrix.specificity());
- test_series.add(test_matrix.sensitivity(),test_matrix.specificity());
- //train_confusion_matrix.writeChars(train_matrix.get_statistics().toString());
- //test_confusion_matrix.writeChars(test_matrix.get_statistics().toString());
- GraphViz.createDotGraph(((Node)sol.getVariable(0)).getNodeAt(1).todot(1), model_path+"/tree");
- Train_props.putAll(train_matrix.get_statistics());
- Test_props.putAll(test_matrix.get_statistics());
- //System.out.println(Test_props) ;
- //System.out.println(Train_props) ;
- Train_props.storeToXML(train_confusion_matrix, "train performance");
- Test_props.storeToXML(test_confusion_matrix, "test performance");
- train_confusion_matrix.close();
- test_confusion_matrix.close();
- }
- data.addSeries(train_series);
- data.addSeries(test_series);
- JFreeChart chart = ChartFactory.createScatterPlot(
- "Algorithm : "+this.train_mogp.run_configs.get("algorithm") + " population size:" + this.train_mogp.run_configs.get("populationsize") + " max generation number: " + this.train_mogp.run_configs.get("maxgeneration") + " crossover rate: "+this.train_mogp.run_configs.get("crossover_rate") + " mutation rate: "+this.train_mogp.run_configs.get("mutation_rate") ,
- "TPR",
- "TNR",
- data,
- PlotOrientation.VERTICAL,
- true,
- true,
- false
- );
- ChartUtilities.writeChartAsPNG(chart_stream,
- chart,
- 500,
- 300);
- }
- catch (FileNotFoundException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- } catch (IOException e) {
- // TODO Auto-generated catch block
- e.printStackTrace();
- }
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement