/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    Evaluation.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers;

import weka.classifiers.evaluation.NominalPrediction;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.classifiers.pmml.consumer.PMMLClassifier;
import weka.classifiers.xml.XMLClassifier;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Range;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Summarizable;
import weka.core.Utils;
import weka.core.Version;
import weka.core.converters.ConverterUtils.DataSink;
import weka.core.converters.ConverterUtils.DataSource;
import weka.core.pmml.PMMLFactory;
import weka.core.pmml.PMMLModel;
import weka.core.xml.KOML;
import weka.core.xml.XMLOptions;
import weka.core.xml.XMLSerialization;
import weka.estimators.Estimator;
import weka.estimators.KernelEstimator;

import java.beans.BeanInfo;
import java.beans.Introspector;
import java.beans.MethodDescriptor;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Reader;
import java.lang.reflect.Method;
import java.util.Date;
import java.util.Enumeration;
import java.util.Random;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 * Class for evaluating machine learning models. <p/>
 *
 * ------------------------------------------------------------------- <p/>
 *
 * General options when evaluating a learning scheme from the command line: <p/>
 *
 * -t filename <br/>
 * Name of the file with the training data. (required) <p/>
 *
 * -T filename <br/>
 * Name of the file with the test data. If missing, a cross-validation
 * is performed. <p/>
 *
 * -c index <br/>
 * Index of the class attribute (1, 2, ...; default: last). <p/>
 *
 * -x number <br/>
 * The number of folds for the cross-validation (default: 10). <p/>
 *
 * -no-cv <br/>
 * No cross-validation. If no test file is provided, no evaluation
 * is done. <p/>
 *
 * -split-percentage percentage <br/>
 * Sets the percentage for the train/test set split, e.g., 66. <p/>
 *
 * -preserve-order <br/>
 * Preserves the order in the percentage split instead of randomizing
 * the data first with the seed value ('-s'). <p/>
 *
 * -s seed <br/>
 * Random number seed for the cross-validation and percentage split
 * (default: 1). <p/>
 *
 * -m filename <br/>
 * The name of a file containing a cost matrix. <p/>
 *
 * -l filename <br/>
 * Loads the classifier from the given file. In case the filename ends with
 * ".xml", a PMML file is loaded or, if that fails, options are loaded from XML. <p/>
 *
 * -d filename <br/>
 * Saves the classifier built from the training data into the given file. In case
 * the filename ends with ".xml", the options are saved as XML, not the model. <p/>
 *
 * -v <br/>
 * Outputs no statistics for the training data. <p/>
 *
 * -o <br/>
 * Outputs statistics only, not the classifier. <p/>
 *
 * -i <br/>
 * Outputs information-retrieval statistics per class. <p/>
 *
 * -k <br/>
 * Outputs information-theoretic statistics. <p/>
 *
 * -p range <br/>
 * Outputs predictions for test instances (or the train instances if no test
 * instances are provided and -no-cv is used), along with the attributes in the
 * specified range (and nothing else). Use '-p 0' if no attributes are desired. <p/>
 *
 * -distribution <br/>
 * Outputs the distribution instead of only the prediction
 * in conjunction with the '-p' option (only nominal classes). <p/>
 *
 * -r <br/>
 * Outputs the cumulative margin distribution (and nothing else). <p/>
 *
 * -g <br/>
 * Only for classifiers that implement "Drawable". Outputs
 * the graph representation of the classifier (and nothing
 * else). <p/>
 *
 * -xml filename | xml-string <br/>
 * Retrieves the options from the XML data instead of the command line. <p/>
 *
 * -threshold-file file <br/>
 * The file to save the threshold data to.
 * The format is determined by the extension, e.g., '.arff' for ARFF
 * format or '.csv' for CSV. <p/>
 *
 * -threshold-label label <br/>
 * The class label to determine the threshold data for
 * (default is the first label). <p/>
 *
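 * A typical command line, shown here only as a sketch (it assumes weka.jar is
 * on the classpath, the files train.arff and test.arff exist, and uses
 * weka.classifiers.trees.J48 merely as an example scheme): <p/>
 * <code> <pre>
 * java weka.classifiers.trees.J48 -t train.arff -T test.arff -i
 * </pre> </code>
 * <p/>
 *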
 * ------------------------------------------------------------------- <p/>
 *
 * Example usage as the main of a classifier (called FunkyClassifier):
 * <code> <pre>
 * public static void main(String [] args) {
 *   runClassifier(new FunkyClassifier(), args);
 * }
 * </pre> </code>
 * <p/>
 *
 * ------------------------------------------------------------------- <p/>
 *
 * Example usage from within an application:
 * <code> <pre>
 * Instances trainInstances = ... instances got from somewhere
 * Instances testInstances = ... instances got from somewhere
 * Classifier scheme = ... scheme got from somewhere
 *
 * Evaluation evaluation = new Evaluation(trainInstances);
 * evaluation.evaluateModel(scheme, testInstances);
 * System.out.println(evaluation.toSummaryString());
 * </pre> </code>
 *
 *
 * @author   Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author   Len Trigg (trigg@cs.waikato.ac.nz)
 * @version  $Revision: 6346 $
 */
public class Evaluation
  implements Summarizable, RevisionHandler {

  /** The number of classes. */
  protected int m_NumClasses;

  /** The number of folds for a cross-validation. */
  protected int m_NumFolds;

  /** The weight of all incorrectly classified instances. */
  protected double m_Incorrect;

  /** The weight of all correctly classified instances. */
  protected double m_Correct;

  /** The weight of all unclassified instances. */
  protected double m_Unclassified;

  /** The weight of all instances that had no class assigned to them. */
  protected double m_MissingClass;

  /** The weight of all instances that had a class assigned to them. */
  protected double m_WithClass;

  /** Array for storing the confusion matrix. */
  protected double [][] m_ConfusionMatrix;

  /** The names of the classes. */
  protected String [] m_ClassNames;

  /** Is the class nominal or numeric? */
  protected boolean m_ClassIsNominal;

  /** The prior probabilities of the classes. */
  protected double [] m_ClassPriors;

  /** The sum of counts for priors. */
  protected double m_ClassPriorsSum;

  /** The cost matrix (if given). */
  protected CostMatrix m_CostMatrix;

  /** The total cost of predictions (includes instance weights). */
  protected double m_TotalCost;

  /** Sum of errors. */
  protected double m_SumErr;

  /** Sum of absolute errors. */
  protected double m_SumAbsErr;

  /** Sum of squared errors. */
  protected double m_SumSqrErr;

  /** Sum of class values. */
  protected double m_SumClass;

  /** Sum of squared class values. */
  protected double m_SumSqrClass;

  /** Sum of predicted values. */
  protected double m_SumPredicted;

  /** Sum of squared predicted values. */
  protected double m_SumSqrPredicted;

  /** Sum of predicted * class values. */
  protected double m_SumClassPredicted;

  /** Sum of absolute errors of the prior. */
  protected double m_SumPriorAbsErr;

  /** Sum of squared errors of the prior. */
  protected double m_SumPriorSqrErr;

  /** Total Kononenko & Bratko information. */
  protected double m_SumKBInfo;

  /** Resolution of the margin histogram. */
  protected static int k_MarginResolution = 500;

  /** Cumulative margin distribution. */
  protected double m_MarginCounts [];

  /** Number of non-missing class training instances seen. */
  protected int m_NumTrainClassVals;

  /** Array containing all numeric training class values seen. */
  protected double [] m_TrainClassVals;

  /** Array containing all numeric training class weights. */
  protected double [] m_TrainClassWeights;

  /** Numeric class error estimator for the prior. */
  protected Estimator m_PriorErrorEstimator;

  /** Numeric class error estimator for the scheme. */
  protected Estimator m_ErrorEstimator;

  /**
   * The minimum probability accepted from an estimator to avoid
   * taking log(0) in Sf calculations.
   */
  protected static final double MIN_SF_PROB = Double.MIN_VALUE;

  /** Total entropy of prior predictions. */
  protected double m_SumPriorEntropy;

  /** Total entropy of scheme predictions. */
  protected double m_SumSchemeEntropy;

  /** The list of predictions that have been generated (for computing AUC). */
  private FastVector m_Predictions;

  /** Enables/disables the use of priors, e.g., if no training set is
   * present in the case of de-serialized schemes. */
  protected boolean m_NoPriors = false;

  /**
   * Initializes all the counters for the evaluation.
   * Use <code>useNoPriors()</code> if the dataset is the test set and you
   * can't initialize with the priors from the training set via
   * <code>setPriors(Instances)</code>.
   *
   * @param data       set of training instances, to get some header
   *                   information and prior class distribution information
   * @throws Exception if the class is not defined
   * @see              #useNoPriors()
   * @see              #setPriors(Instances)
   */
  public Evaluation(Instances data) throws Exception {

    this(data, null);
  }

  /**
   * Initializes all the counters for the evaluation and also takes a
   * cost matrix as parameter.
   * Use <code>useNoPriors()</code> if the dataset is the test set and you
   * can't initialize with the priors from the training set via
   * <code>setPriors(Instances)</code>.
   *
   * @param data       set of training instances, to get some header
   *                   information and prior class distribution information
   * @param costMatrix the cost matrix---if null, default costs will be used
   * @throws Exception if cost matrix is not compatible with
   *                   data, the class is not defined or the class is numeric
   * @see              #useNoPriors()
   * @see              #setPriors(Instances)
   */
  public Evaluation(Instances data, CostMatrix costMatrix)
    throws Exception {

    m_NumClasses = data.numClasses();
    m_NumFolds = 1;
    m_ClassIsNominal = data.classAttribute().isNominal();

    if (m_ClassIsNominal) {
      m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses];
      m_ClassNames = new String [m_NumClasses];
      for (int i = 0; i < m_NumClasses; i++) {
        m_ClassNames[i] = data.classAttribute().value(i);
      }
    }
    m_CostMatrix = costMatrix;
    if (m_CostMatrix != null) {
      if (!m_ClassIsNominal) {
        throw new Exception("Class has to be nominal if cost matrix given!");
      }
      if (m_CostMatrix.size() != m_NumClasses) {
        throw new Exception("Cost matrix not compatible with data!");
      }
    }
    m_ClassPriors = new double [m_NumClasses];
    setPriors(data);
    m_MarginCounts = new double [k_MarginResolution + 1];
  }
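
  /*
   * A minimal sketch of constructing an Evaluation with a cost matrix
   * (the file name "costs.cost" is hypothetical; the Reader-based
   * CostMatrix constructor is the same one used in handleCostOption below):
   *
   *   Instances train = ...;   // training instances from somewhere
   *   CostMatrix costs = new CostMatrix(
   *       new java.io.BufferedReader(new java.io.FileReader("costs.cost")));
   *   Evaluation eval = new Evaluation(train, costs);
   */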

  /**
   * Returns the area under ROC for those predictions that have been collected
   * in the evaluateModel(Classifier, Instances) method. Returns
   * Instance.missingValue() if the area is not available.
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the area under the ROC curve or not a number
   */
  public double areaUnderROC(int classIndex) {

    // Check if any predictions have been collected
    if (m_Predictions == null) {
      return Instance.missingValue();
    } else {
      ThresholdCurve tc = new ThresholdCurve();
      Instances result = tc.getCurve(m_Predictions, classIndex);
      return ThresholdCurve.getROCArea(result);
    }
  }
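
  /*
   * A minimal usage sketch (train and test Instances obtained elsewhere,
   * as in the class javadoc example above):
   *
   *   Evaluation eval = new Evaluation(train);
   *   eval.evaluateModel(scheme, test);
   *   double auc = eval.areaUnderROC(0);   // class 0 treated as "positive"
   */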

  /**
   * Calculates the weighted (by class size) AUC.
   *
   * @return the weighted AUC.
   */
  public double weightedAreaUnderROC() {
    double[] classCounts = new double[m_NumClasses];
    double classCountSum = 0;

    for (int i = 0; i < m_NumClasses; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        classCounts[i] += m_ConfusionMatrix[i][j];
      }
      classCountSum += classCounts[i];
    }

    double aucTotal = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      double temp = areaUnderROC(i);
      if (!Instance.isMissingValue(temp)) {
        aucTotal += (temp * classCounts[i]);
      }
    }

    return aucTotal / classCountSum;
  }

  /**
   * Returns a copy of the confusion matrix.
   *
   * @return a copy of the confusion matrix as a two-dimensional array
   */
  public double[][] confusionMatrix() {

    double[][] newMatrix = new double[m_ConfusionMatrix.length][0];

    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
      newMatrix[i] = new double[m_ConfusionMatrix[i].length];
      System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,
                       m_ConfusionMatrix[i].length);
    }
    return newMatrix;
  }

  /**
   * Performs a (stratified if class is nominal) cross-validation
   * for a classifier on a set of instances. Now performs
   * a deep copy of the classifier before each call to
   * buildClassifier() (just in case the classifier is not
   * initialized properly).
   *
   * @param classifier the classifier with any options set.
   * @param data the data on which the cross-validation is to be
   * performed
   * @param numFolds the number of folds for the cross-validation
   * @param random random number generator for randomization
   * @param forPredictionsPrinting varargs parameter that, if supplied, is
   * expected to hold a StringBuffer to print predictions to,
   * a Range of attributes to output and a Boolean (true if the distribution
   * is to be printed)
   * @throws Exception if a classifier could not be generated
   * successfully or the class is not defined
   */
  public void crossValidateModel(Classifier classifier,
                                 Instances data, int numFolds, Random random,
                                 Object... forPredictionsPrinting)
    throws Exception {

    // Make a copy of the data we can reorder
    data = new Instances(data);
    data.randomize(random);
    if (data.classAttribute().isNominal()) {
      data.stratify(numFolds);
    }

    // We assume that the first element is a StringBuffer, the second a Range (attributes
    // to output) and the third a Boolean (whether or not to output a distribution instead
    // of just a classification)
    if (forPredictionsPrinting.length > 0) {
      // print the header first
      StringBuffer buff = (StringBuffer) forPredictionsPrinting[0];
      Range attsToOutput = (Range) forPredictionsPrinting[1];
      boolean printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue();
      printClassificationsHeader(data, attsToOutput, printDist, buff);
    }

    // Do the folds
    for (int i = 0; i < numFolds; i++) {
      Instances train = data.trainCV(numFolds, i, random);
      setPriors(train);
      Classifier copiedClassifier = Classifier.makeCopy(classifier);
      copiedClassifier.buildClassifier(train);
      Instances test = data.testCV(numFolds, i);
      evaluateModel(copiedClassifier, test, forPredictionsPrinting);
    }
    m_NumFolds = numFolds;
  }
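
  /*
   * A minimal sketch of a 10-fold cross-validation (data and scheme obtained
   * elsewhere; the seed value 1 matches the command-line default):
   *
   *   Evaluation eval = new Evaluation(data);
   *   eval.crossValidateModel(scheme, data, 10, new Random(1));
   *   System.out.println(eval.toSummaryString());
   */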

  /**
   * Performs a (stratified if class is nominal) cross-validation
   * for a classifier on a set of instances.
   *
   * @param classifierString a string naming the class of the classifier
   * @param data the data on which the cross-validation is to be
   * performed
   * @param numFolds the number of folds for the cross-validation
   * @param options the options to the classifier. Any options
   * accepted by the classifier will be removed from this array.
   * @param random the random number generator for randomizing the data
   * @throws Exception if a classifier could not be generated
   * successfully or the class is not defined
   */
  public void crossValidateModel(String classifierString,
                                 Instances data, int numFolds,
                                 String[] options, Random random)
    throws Exception {

    crossValidateModel(Classifier.forName(classifierString, options),
                       data, numFolds, random);
  }

  /**
   * Evaluates a classifier with the options given in an array of
   * strings. <p/>
   *
   * Valid options are: <p/>
   *
   * -t filename <br/>
   * Name of the file with the training data. (required) <p/>
   *
   * -T filename <br/>
   * Name of the file with the test data. If missing, a cross-validation
   * is performed. <p/>
   *
   * -c index <br/>
   * Index of the class attribute (1, 2, ...; default: last). <p/>
   *
   * -x number <br/>
   * The number of folds for the cross-validation (default: 10). <p/>
   *
   * -no-cv <br/>
   * No cross-validation. If no test file is provided, no evaluation
   * is done. <p/>
   *
   * -split-percentage percentage <br/>
   * Sets the percentage for the train/test set split, e.g., 66. <p/>
   *
   * -preserve-order <br/>
   * Preserves the order in the percentage split instead of randomizing
   * the data first with the seed value ('-s'). <p/>
   *
   * -s seed <br/>
   * Random number seed for the cross-validation and percentage split
   * (default: 1). <p/>
   *
   * -m filename <br/>
   * The name of a file containing a cost matrix. <p/>
   *
   * -l filename <br/>
   * Loads the classifier from the given file. In case the filename ends with
   * ".xml", a PMML file is loaded or, if that fails, options are loaded from XML. <p/>
   *
   * -d filename <br/>
   * Saves the classifier built from the training data into the given file. In case
   * the filename ends with ".xml", the options are saved as XML, not the model. <p/>
   *
   * -v <br/>
   * Outputs no statistics for the training data. <p/>
   *
   * -o <br/>
   * Outputs statistics only, not the classifier. <p/>
   *
   * -i <br/>
   * Outputs detailed information-retrieval statistics per class. <p/>
   *
   * -k <br/>
   * Outputs information-theoretic statistics. <p/>
   *
   * -p range <br/>
   * Outputs predictions for test instances (or the train instances if no test
   * instances are provided and -no-cv is used), along with the attributes in the
   * specified range (and nothing else). Use '-p 0' if no attributes are desired. <p/>
   *
   * -distribution <br/>
   * Outputs the distribution instead of only the prediction
   * in conjunction with the '-p' option (only nominal classes). <p/>
   *
   * -r <br/>
   * Outputs the cumulative margin distribution (and nothing else). <p/>
   *
   * -g <br/>
   * Only for classifiers that implement "Drawable". Outputs
   * the graph representation of the classifier (and nothing
   * else). <p/>
   *
   * -xml filename | xml-string <br/>
   * Retrieves the options from the XML data instead of the command line. <p/>
   *
   * -threshold-file file <br/>
   * The file to save the threshold data to.
   * The format is determined by the extension, e.g., '.arff' for ARFF
   * format or '.csv' for CSV. <p/>
   *
   * -threshold-label label <br/>
   * The class label to determine the threshold data for
   * (default is the first label). <p/>
   *
   * @param classifierString class of machine learning classifier as a string
   * @param options the array of strings containing the options
   * @throws Exception if model could not be evaluated successfully
   * @return a string describing the results
   */
  public static String evaluateModel(String classifierString,
                                     String [] options) throws Exception {

    Classifier classifier;

    // Create classifier
    try {
      classifier =
        (Classifier) Class.forName(classifierString).newInstance();
    } catch (Exception e) {
      throw new Exception("Can't find class with name "
                          + classifierString + '.');
    }
    return evaluateModel(classifier, options);
  }

  /**
   * A test method for this class. Just extracts the first command line
   * argument as a classifier class name and calls evaluateModel.
   * @param args an array of command line arguments, the first of which
   * must be the class name of a classifier.
   */
  public static void main(String [] args) {

    try {
      if (args.length == 0) {
        throw new Exception("The first argument must be the class name"
                            + " of a classifier");
      }
      String classifier = args[0];
      args[0] = "";
      System.out.println(evaluateModel(classifier, args));
    } catch (Exception ex) {
      ex.printStackTrace();
      System.err.println(ex.getMessage());
    }
  }
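
  /*
   * A sketch of invoking this entry point from a shell (the classifier class
   * name comes first, followed by any of the options documented above;
   * train.arff is a hypothetical file):
   *
   *   java weka.classifiers.Evaluation weka.classifiers.trees.J48 -t train.arff
   */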

  /**
   * Evaluates a classifier with the options given in an array of
   * strings. <p/>
   *
   * Valid options are: <p/>
   *
   * -t name of training file <br/>
   * Name of the file with the training data. (required) <p/>
   *
   * -T name of test file <br/>
   * Name of the file with the test data. If missing, a cross-validation
   * is performed. <p/>
   *
   * -c class index <br/>
   * Index of the class attribute (1, 2, ...; default: last). <p/>
   *
   * -x number of folds <br/>
   * The number of folds for the cross-validation (default: 10). <p/>
   *
   * -no-cv <br/>
   * No cross-validation. If no test file is provided, no evaluation
   * is done. <p/>
   *
   * -split-percentage percentage <br/>
   * Sets the percentage for the train/test set split, e.g., 66. <p/>
   *
   * -preserve-order <br/>
   * Preserves the order in the percentage split instead of randomizing
   * the data first with the seed value ('-s'). <p/>
   *
   * -s seed <br/>
   * Random number seed for the cross-validation and percentage split
   * (default: 1). <p/>
   *
   * -m file with cost matrix <br/>
   * The name of a file containing a cost matrix. <p/>
   *
   * -l filename <br/>
   * Loads the classifier from the given file. In case the filename ends with
   * ".xml", a PMML file is loaded or, if that fails, options are loaded from XML. <p/>
   *
   * -d filename <br/>
   * Saves the classifier built from the training data into the given file. In case
   * the filename ends with ".xml", the options are saved as XML, not the model. <p/>
   *
   * -v <br/>
   * Outputs no statistics for the training data. <p/>
   *
   * -o <br/>
   * Outputs statistics only, not the classifier. <p/>
   *
   * -i <br/>
   * Outputs detailed information-retrieval statistics per class. <p/>
   *
   * -k <br/>
   * Outputs information-theoretic statistics. <p/>
   *
   * -p range <br/>
   * Outputs predictions for test instances (or the train instances if no test
   * instances are provided and -no-cv is used), along with the attributes in the
   * specified range (and nothing else). Use '-p 0' if no attributes are desired. <p/>
   *
   * -distribution <br/>
   * Outputs the distribution instead of only the prediction
   * in conjunction with the '-p' option (only nominal classes). <p/>
   *
   * -r <br/>
   * Outputs the cumulative margin distribution (and nothing else). <p/>
   *
   * -g <br/>
   * Only for classifiers that implement "Drawable". Outputs
   * the graph representation of the classifier (and nothing
   * else). <p/>
   *
   * -xml filename | xml-string <br/>
   * Retrieves the options from the XML data instead of the command line. <p/>
   *
   * @param classifier machine learning classifier
   * @param options the array of strings containing the options
   * @throws Exception if model could not be evaluated successfully
   * @return a string describing the results
   */
  public static String evaluateModel(Classifier classifier,
                                     String [] options) throws Exception {

    Instances train = null, tempTrain, test = null, template = null;
    int seed = 1, folds = 10, classIndex = -1;
    boolean noCrossValidation = false;
    String trainFileName, testFileName, sourceClass,
      classIndexString, seedString, foldsString, objectInputFileName,
      objectOutputFileName, attributeRangeString;
    boolean noOutput = false,
      printClassifications = false, trainStatistics = true,
      printMargins = false, printComplexityStatistics = false,
      printGraph = false, classStatistics = false, printSource = false;
    StringBuffer text = new StringBuffer();
    DataSource trainSource = null, testSource = null;
    ObjectInputStream objectInputStream = null;
    BufferedInputStream xmlInputStream = null;
    CostMatrix costMatrix = null;
    StringBuffer schemeOptionsText = null;
    Range attributesToOutput = null;
    long trainTimeStart = 0, trainTimeElapsed = 0,
      testTimeStart = 0, testTimeElapsed = 0;
    String xml = "";
    String[] optionsTmp = null;
    Classifier classifierBackup;
    Classifier classifierClassifications = null;
    boolean printDistribution = false;
    int actualClassIndex = -1;  // 0-based class index
    String splitPercentageString = "";
    double splitPercentage = -1;
    boolean preserveOrder = false;
    boolean trainSetPresent = false;
    boolean testSetPresent = false;
    String thresholdFile;
    String thresholdLabel;
    StringBuffer predsBuff = null; // predictions from cross-validation

    // help requested?
    if (Utils.getFlag("h", options) || Utils.getFlag("help", options)) {

      // global info requested as well?
      boolean globalInfo = Utils.getFlag("synopsis", options) ||
        Utils.getFlag("info", options);

      throw new Exception("\nHelp requested."
                          + makeOptionString(classifier, globalInfo));
    }

    try {
      // do we get the input from XML instead of normal parameters?
      xml = Utils.getOption("xml", options);
      if (!xml.equals(""))
        options = new XMLOptions(xml).toArray();

      // is the input model only the XML-Options, i.e. w/o built model?
      optionsTmp = new String[options.length];
      for (int i = 0; i < options.length; i++)
        optionsTmp[i] = options[i];

      String tmpO = Utils.getOption('l', optionsTmp);
      //if (Utils.getOption('l', optionsTmp).toLowerCase().endsWith(".xml")) {
      if (tmpO.endsWith(".xml")) {
        // try to load file as PMML first
        boolean success = false;
        try {
          PMMLModel pmmlModel = PMMLFactory.getPMMLModel(tmpO);
          if (pmmlModel instanceof PMMLClassifier) {
            classifier = ((PMMLClassifier) pmmlModel);
            success = true;
          }
        } catch (IllegalArgumentException ex) {
          success = false;
        }
        if (!success) {
          // load options from serialized data ('-l' is automatically erased!)
          XMLClassifier xmlserial = new XMLClassifier();
          Classifier cl = (Classifier) xmlserial.read(Utils.getOption('l', options));

          // merge options
          optionsTmp = new String[options.length + cl.getOptions().length];
          System.arraycopy(cl.getOptions(), 0, optionsTmp, 0, cl.getOptions().length);
          System.arraycopy(options, 0, optionsTmp, cl.getOptions().length, options.length);
          options = optionsTmp;
        }
      }

      noCrossValidation = Utils.getFlag("no-cv", options);
      // Get basic options (options the same for all schemes)
      classIndexString = Utils.getOption('c', options);
      if (classIndexString.length() != 0) {
        if (classIndexString.equals("first"))
          classIndex = 1;
        else if (classIndexString.equals("last"))
          classIndex = -1;
        else
          classIndex = Integer.parseInt(classIndexString);
      }
      trainFileName = Utils.getOption('t', options);
      objectInputFileName = Utils.getOption('l', options);
      objectOutputFileName = Utils.getOption('d', options);
      testFileName = Utils.getOption('T', options);
      foldsString = Utils.getOption('x', options);
      if (foldsString.length() != 0) {
        folds = Integer.parseInt(foldsString);
      }
      seedString = Utils.getOption('s', options);
      if (seedString.length() != 0) {
        seed = Integer.parseInt(seedString);
      }
      if (trainFileName.length() == 0) {
        if (objectInputFileName.length() == 0) {
          throw new Exception("No training file and no object "
                              + "input file given.");
        }
        if (testFileName.length() == 0) {
          throw new Exception("No training file and no test "
                              + "file given.");
        }
      } else if ((objectInputFileName.length() != 0) &&
                 ((!(classifier instanceof UpdateableClassifier)) ||
                  (testFileName.length() == 0))) {
        throw new Exception("Classifier not incremental, or no "
                            + "test file provided: can't "
                            + "use both train and model file.");
      }
      try {
        if (trainFileName.length() != 0) {
          trainSetPresent = true;
          trainSource = new DataSource(trainFileName);
        }
        if (testFileName.length() != 0) {
          testSetPresent = true;
          testSource = new DataSource(testFileName);
        }
        if (objectInputFileName.length() != 0) {
          if (objectInputFileName.endsWith(".xml")) {
            // if this is the case then it means that a PMML classifier was
            // successfully loaded earlier in the code
            objectInputStream = null;
            xmlInputStream = null;
          } else {
            InputStream is = new FileInputStream(objectInputFileName);
            if (objectInputFileName.endsWith(".gz")) {
              is = new GZIPInputStream(is);
            }
            // load from KOML?
            if (!(objectInputFileName.endsWith(".koml") && KOML.isPresent())) {
              objectInputStream = new ObjectInputStream(is);
              xmlInputStream    = null;
            } else {
              objectInputStream = null;
              xmlInputStream    = new BufferedInputStream(is);
            }
          }
        }
      } catch (Exception e) {
        throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      if (testSetPresent) {
        template = test = testSource.getStructure();
        if (classIndex != -1) {
          test.setClassIndex(classIndex - 1);
        } else {
          if ((test.classIndex() == -1) || (classIndexString.length() != 0))
            test.setClassIndex(test.numAttributes() - 1);
        }
        actualClassIndex = test.classIndex();
      } else {
        // percentage split
        splitPercentageString = Utils.getOption("split-percentage", options);
        if (splitPercentageString.length() != 0) {
          if (foldsString.length() != 0)
            throw new Exception(
                "Percentage split cannot be used in conjunction with "
                + "cross-validation ('-x').");
          splitPercentage = Double.parseDouble(splitPercentageString);
          if ((splitPercentage <= 0) || (splitPercentage >= 100))
            throw new Exception("Percentage split value needs to be >0 and <100.");
        } else {
          splitPercentage = -1;
        }
        preserveOrder = Utils.getFlag("preserve-order", options);
        if (preserveOrder) {
          if (splitPercentage == -1)
            throw new Exception("Percentage split ('-split-percentage') is missing.");
        }
        // create new train/test sources
        if (splitPercentage > 0) {
          testSetPresent = true;
          Instances tmpInst = trainSource.getDataSet(actualClassIndex);
          if (!preserveOrder)
            tmpInst.randomize(new Random(seed));
          int trainSize =
            (int) Math.round(tmpInst.numInstances() * splitPercentage / 100);
          int testSize  = tmpInst.numInstances() - trainSize;
          Instances trainInst = new Instances(tmpInst, 0, trainSize);
          Instances testInst  = new Instances(tmpInst, trainSize, testSize);
          trainSource = new DataSource(trainInst);
          testSource  = new DataSource(testInst);
          template = test = testSource.getStructure();
          if (classIndex != -1) {
            test.setClassIndex(classIndex - 1);
          } else {
            if ((test.classIndex() == -1) || (classIndexString.length() != 0))
              test.setClassIndex(test.numAttributes() - 1);
          }
          actualClassIndex = test.classIndex();
        }
      }
      if (trainSetPresent) {
        template = train = trainSource.getStructure();
        if (classIndex != -1) {
          train.setClassIndex(classIndex - 1);
        } else {
          if ((train.classIndex() == -1) || (classIndexString.length() != 0))
            train.setClassIndex(train.numAttributes() - 1);
        }
        actualClassIndex = train.classIndex();
        if ((testSetPresent) && !test.equalHeaders(train)) {
          throw new IllegalArgumentException("Train and test file not compatible!");
        }
      }
      if (template == null) {
        throw new Exception("No actual dataset provided to use as template");
      }
      costMatrix = handleCostOption(
          Utils.getOption('m', options), template.numClasses());

      classStatistics = Utils.getFlag('i', options);
      noOutput = Utils.getFlag('o', options);
      trainStatistics = !Utils.getFlag('v', options);
      printComplexityStatistics = Utils.getFlag('k', options);
      printMargins = Utils.getFlag('r', options);
      printGraph = Utils.getFlag('g', options);
      sourceClass = Utils.getOption('z', options);
      printSource = (sourceClass.length() != 0);
      printDistribution = Utils.getFlag("distribution", options);
      thresholdFile = Utils.getOption("threshold-file", options);
      thresholdLabel = Utils.getOption("threshold-label", options);

      // Check -p option
      try {
        attributeRangeString = Utils.getOption('p', options);
      } catch (Exception e) {
        throw new Exception(e.getMessage() + "\nNOTE: the -p option has changed. "
            + "It now expects a parameter specifying a range of attributes "
            + "to list with the predictions. Use '-p 0' for none.");
      }
      if (attributeRangeString.length() != 0) {
        printClassifications = true;
        noOutput = true;
        if (!attributeRangeString.equals("0"))
          attributesToOutput = new Range(attributeRangeString);
      }

      if (!printClassifications && printDistribution)
        throw new Exception("Cannot print distribution without '-p' option!");

      // if no training file given, we don't have any priors
      if ((!trainSetPresent) && (printComplexityStatistics))
        throw new Exception("Cannot print complexity statistics ('-k') without training file ('-t')!");

      // If a model file is given, we can't process
      // scheme-specific options
      if (objectInputFileName.length() != 0) {
        Utils.checkForRemainingOptions(options);
      } else {

        // Set options for classifier
        if (classifier instanceof OptionHandler) {
          for (int i = 0; i < options.length; i++) {
            if (options[i].length() != 0) {
              if (schemeOptionsText == null) {
                schemeOptionsText = new StringBuffer();
              }
              if (options[i].indexOf(' ') != -1) {
                schemeOptionsText.append('"' + options[i] + "\" ");
              } else {
                schemeOptionsText.append(options[i] + " ");
              }
            }
          }
          ((OptionHandler) classifier).setOptions(options);
        }
      }
      Utils.checkForRemainingOptions(options);
    } catch (Exception e) {
      throw new Exception("\nWeka exception: " + e.getMessage()
                          + makeOptionString(classifier, false));
    }

    // Set up evaluation objects
    Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
    Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);

    // disable use of priors if no training file given
    if (!trainSetPresent)
      testingEvaluation.useNoPriors();

    if (objectInputFileName.length() != 0) {
      // Load classifier from file
      if (objectInputStream != null) {
        classifier = (Classifier) objectInputStream.readObject();
        // try and read a header (if present)
        Instances savedStructure = null;
        try {
          savedStructure = (Instances) objectInputStream.readObject();
        } catch (Exception ex) {
          // don't make a fuss
        }
        if (savedStructure != null) {
          // test for compatibility with template
          if (!template.equalHeaders(savedStructure)) {
            throw new Exception("training and test set are not compatible");
          }
        }
        objectInputStream.close();
      } else if (xmlInputStream != null) {
        // whether KOML is available has already been checked
        // (objectInputStream would be null otherwise)!
        classifier = (Classifier) KOML.read(xmlInputStream);
        xmlInputStream.close();
      }
    }

    // backup of fully setup classifier for cross-validation
    classifierBackup = Classifier.makeCopy(classifier);

    // Build the classifier if no object file provided
    if ((classifier instanceof UpdateableClassifier) &&
        (testSetPresent || noCrossValidation) &&
        (costMatrix == null) &&
        (trainSetPresent)) {
      // Build classifier incrementally
      trainingEvaluation.setPriors(train);
      testingEvaluation.setPriors(train);
      trainTimeStart = System.currentTimeMillis();
      if (objectInputFileName.length() == 0) {
        classifier.buildClassifier(train);
      }
      Instance trainInst;
      while (trainSource.hasMoreElements(train)) {
        trainInst = trainSource.nextElement(train);
        trainingEvaluation.updatePriors(trainInst);
        testingEvaluation.updatePriors(trainInst);
        ((UpdateableClassifier) classifier).updateClassifier(trainInst);
      }
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    } else if (objectInputFileName.length() == 0) {
      // Build classifier in one go
      tempTrain = trainSource.getDataSet(actualClassIndex);
      trainingEvaluation.setPriors(tempTrain);
      testingEvaluation.setPriors(tempTrain);
      trainTimeStart = System.currentTimeMillis();
      classifier.buildClassifier(tempTrain);
      trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
    }

    // backup of fully trained classifier for printing the classifications
    if (printClassifications)
      classifierClassifications = Classifier.makeCopy(classifier);

    // Save the classifier if an object output file is provided
    if (objectOutputFileName.length() != 0) {
      OutputStream os = new FileOutputStream(objectOutputFileName);
      // binary
      if (!(objectOutputFileName.endsWith(".xml")
            || (objectOutputFileName.endsWith(".koml") && KOML.isPresent()))) {
        if (objectOutputFileName.endsWith(".gz")) {
          os = new GZIPOutputStream(os);
        }
        ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
        objectOutputStream.writeObject(classifier);
        if (template != null) {
          objectOutputStream.writeObject(template);
        }
        objectOutputStream.flush();
        objectOutputStream.close();
      }
      // KOML/XML
      else {
        BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
        if (objectOutputFileName.endsWith(".xml")) {
          XMLSerialization xmlSerial = new XMLClassifier();
          xmlSerial.write(xmlOutputStream, classifier);
        }
        else
          // whether KOML is present has already been checked
          // if not present -> ".koml" is interpreted as binary - see above
          if (objectOutputFileName.endsWith(".koml")) {
            KOML.write(xmlOutputStream, classifier);
          }
        xmlOutputStream.close();
      }
    }

    // If classifier is drawable output string describing graph
    if ((classifier instanceof Drawable) && (printGraph)) {
      return ((Drawable) classifier).graph();
    }

    // Output the classifier as equivalent source
    if ((classifier instanceof Sourcable) && (printSource)) {
      return wekaStaticWrapper((Sourcable) classifier, sourceClass);
    }

    // Output model
    if (!(noOutput || printMargins)) {
      if (classifier instanceof OptionHandler) {
        if (schemeOptionsText != null) {
          text.append("\nOptions: " + schemeOptionsText);
          text.append("\n");
        }
      }
      text.append("\n" + classifier.toString() + "\n");
    }

    if (!printMargins && (costMatrix != null)) {
      text.append("\n=== Evaluation Cost Matrix ===\n\n");
      text.append(costMatrix.toString());
    }

    // Output test instance predictions only
    if (printClassifications) {
      DataSource source = testSource;
      predsBuff = new StringBuffer();
      // no test set -> use train set
      if (source == null && noCrossValidation) {
        source = trainSource;
        predsBuff.append("\n=== Predictions on training data ===\n\n");
      } else {
        predsBuff.append("\n=== Predictions on test data ===\n\n");
      }
      if (source != null) {
        /*      return printClassifications(classifierClassifications, new Instances(template, 0),
                source, actualClassIndex + 1, attributesToOutput,
                printDistribution); */
        printClassifications(classifierClassifications, new Instances(template, 0),
                             source, actualClassIndex + 1, attributesToOutput,
                             printDistribution, predsBuff);
        //        return predsText.toString();
      }
    }

    // Compute error estimate from training data
    if ((trainStatistics) && (trainSetPresent)) {

      if ((classifier instanceof UpdateableClassifier) &&
          (testSetPresent) &&
          (costMatrix == null)) {

        // Classifier was trained incrementally, so we have to
        // reset the source.
        trainSource.reset();

        // Incremental testing
        train = trainSource.getStructure(actualClassIndex);
        testTimeStart = System.currentTimeMillis();
        Instance trainInst;
        while (trainSource.hasMoreElements(train)) {
          trainInst = trainSource.nextElement(train);
          trainingEvaluation.evaluateModelOnce((Classifier) classifier, trainInst);
        }
        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
      } else {
        testTimeStart = System.currentTimeMillis();
        trainingEvaluation.evaluateModel(
            classifier, trainSource.getDataSet(actualClassIndex));
        testTimeElapsed = System.currentTimeMillis() - testTimeStart;
      }

      // Print the results of the training evaluation
      if (printMargins) {
        return trainingEvaluation.toCumulativeMarginDistributionString();
      } else {
        if (!printClassifications) {
          text.append("\nTime taken to build model: "
              + Utils.doubleToString(trainTimeElapsed / 1000.0, 2)
              + " seconds");

          if (splitPercentage > 0)
            text.append("\nTime taken to test model on training split: ");
          else
            text.append("\nTime taken to test model on training data: ");
          text.append(Utils.doubleToString(testTimeElapsed / 1000.0, 2) + " seconds");

          if (splitPercentage > 0)
            text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
                + " split ===\n", printComplexityStatistics));
          else
            text.append(trainingEvaluation.toSummaryString("\n\n=== Error on training"
                + " data ===\n", printComplexityStatistics));

          if (template.classAttribute().isNominal()) {
            if (classStatistics) {
              text.append("\n\n" + trainingEvaluation.toClassDetailsString());
            }
            if (!noCrossValidation)
              text.append("\n\n" + trainingEvaluation.toMatrixString());
          }
        }
      }
    }
  1224.  
  1225.     // Compute proper error estimates
  1226.     if (testSource != null) {
  1227.       // Testing is on the supplied test data
  1228.       testSource.reset();
  1229.       test = testSource.getStructure(test.classIndex());
  1230.       Instance testInst;
  1231.       while (testSource.hasMoreElements(test)) {
  1232.     testInst = testSource.nextElement(test);
  1233.     testingEvaluation.evaluateModelOnceAndRecordPrediction(
  1234.             (Classifier)classifier, testInst);
  1235.       }
  1236.  
  1237.       if (splitPercentage > 0) {
  1238.         if (!printClassifications) {
  1239.           text.append("\n\n" + testingEvaluation.
  1240.               toSummaryString("=== Error on test split ===\n",
  1241.                   printComplexityStatistics));
        }
      } else {
        if (!printClassifications) {
          text.append("\n\n" + testingEvaluation.
              toSummaryString("=== Error on test data ===\n",
                  printComplexityStatistics));
        }
      }

    } else if (trainSource != null) {
      if (!noCrossValidation) {
        // Testing is via cross-validation on training data
        Random random = new Random(seed);
        // use untrained (!) classifier for cross-validation
        classifier = Classifier.makeCopy(classifierBackup);
        if (!printClassifications) {
          testingEvaluation.crossValidateModel(classifier,
              trainSource.getDataSet(actualClassIndex),
              folds, random);
          if (template.classAttribute().isNumeric()) {
            text.append("\n\n\n" + testingEvaluation.
                toSummaryString("=== Cross-validation ===\n",
                    printComplexityStatistics));
          } else {
            text.append("\n\n\n" + testingEvaluation.
                toSummaryString("=== Stratified cross-validation ===\n",
                    printComplexityStatistics));
          }
        } else {
          predsBuff = new StringBuffer();
          predsBuff.append("\n=== Predictions under cross-validation ===\n\n");
          testingEvaluation.crossValidateModel(classifier,
              trainSource.getDataSet(actualClassIndex),
              folds, random, predsBuff, attributesToOutput,
              Boolean.valueOf(printDistribution));
        }
      }
    }

    if (template.classAttribute().isNominal()) {
      if (classStatistics && !noCrossValidation && !printClassifications) {
        text.append("\n\n" + testingEvaluation.toClassDetailsString());
      }
      if (!noCrossValidation && !printClassifications) {
        text.append("\n\n" + testingEvaluation.toMatrixString());
      }
    }

    // predictions from cross-validation?
    if (predsBuff != null) {
      text.append("\n" + predsBuff);
    }

    if ((thresholdFile.length() != 0) && template.classAttribute().isNominal()) {
      int labelIndex = 0;
      if (thresholdLabel.length() != 0) {
        labelIndex = template.classAttribute().indexOfValue(thresholdLabel);
      }
      if (labelIndex == -1) {
        throw new IllegalArgumentException(
            "Class label '" + thresholdLabel + "' is unknown!");
      }
      ThresholdCurve tc = new ThresholdCurve();
      Instances result = tc.getCurve(testingEvaluation.predictions(), labelIndex);
      DataSink.write(thresholdFile, result);
    }

    return text.toString();
  }
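
  /**
   * Minimal usage sketch (added for documentation; not part of the original
   * Weka API): the programmatic equivalent of the command-line
   * cross-validation path above. Assumes "data" already has its class
   * attribute set; the method name is a hypothetical helper.
   */
  private static String sketchCrossValidation(Classifier untrained,
      Instances data, int folds, int seed) throws Exception {
    Evaluation eval = new Evaluation(data);
    // note: crossValidateModel trains fold models internally,
    // so an untrained classifier is passed in
    eval.crossValidateModel(untrained, data, folds, new Random(seed));
    return eval.toSummaryString("=== Stratified cross-validation ===\n", false);
  }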

  /**
   * Attempts to load a cost matrix.
   *
   * @param costFileName the filename of the cost matrix
   * @param numClasses the number of classes that should be in the cost matrix
   * (only used if the cost file is in the old format).
   * @return a <code>CostMatrix</code> value, or null if costFileName is empty
   * @throws Exception if an error occurs.
   */
  protected static CostMatrix handleCostOption(String costFileName,
      int numClasses) throws Exception {

    if ((costFileName != null) && (costFileName.length() != 0)) {
      System.out.println(
          "NOTE: The behaviour of the -m option has changed between WEKA 3.0"
          + " and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
          + " only. For cost-sensitive *prediction*, use one of the"
          + " cost-sensitive metaschemes such as"
          + " weka.classifiers.meta.CostSensitiveClassifier or"
          + " weka.classifiers.meta.MetaCost");

      Reader costReader = null;
      try {
        costReader = new BufferedReader(new FileReader(costFileName));
      } catch (Exception e) {
        throw new Exception("Can't open file " + e.getMessage() + '.');
      }
      try {
        // First try the new (matrix) cost file format
        return new CostMatrix(costReader);
      } catch (Exception ex) {
        try {
          // Fall back to the old cost file format
          try {
            costReader.close(); // Close the old reader
            costReader = new BufferedReader(new FileReader(costFileName));
          } catch (Exception e) {
            throw new Exception("Can't open file " + e.getMessage() + '.');
          }
          CostMatrix costMatrix = new CostMatrix(numClasses);
          costMatrix.readOldFormat(costReader);
          return costMatrix;
        } catch (Exception e2) {
          // re-throw the original exception
          throw ex;
        }
      }
    } else {
      return null;
    }
  }
  /**
   * Evaluates the classifier on a given set of instances. Note that
   * the data must have exactly the same format (e.g. order of
   * attributes) as the data used to train the classifier! Otherwise
   * the results will generally be meaningless.
   *
   * @param classifier machine learning classifier
   * @param data set of test instances for evaluation
   * @param forPredictionsPrinting varargs parameter that, if supplied, is
   * expected to hold a StringBuffer to print predictions to,
   * a Range of attributes to output and a Boolean (true if the distribution
   * is to be printed)
   * @return the predictions
   * @throws Exception if model could not be evaluated
   * successfully
   */
  public double[] evaluateModel(Classifier classifier,
                                Instances data,
                                Object... forPredictionsPrinting) throws Exception {
    // for predictions printing
    StringBuffer buff = null;
    Range attsToOutput = null;
    boolean printDist = false;

    double[] predictions = new double[data.numInstances()];

    if (forPredictionsPrinting.length > 0) {
      buff = (StringBuffer) forPredictionsPrinting[0];
      attsToOutput = (Range) forPredictionsPrinting[1];
      printDist = ((Boolean) forPredictionsPrinting[2]).booleanValue();
    }

    // Need to be able to collect predictions if appropriate (for AUC)
    for (int i = 0; i < data.numInstances(); i++) {
      predictions[i] = evaluateModelOnceAndRecordPrediction(classifier,
          data.instance(i));
      if (buff != null) {
        buff.append(predictionText(classifier, data.instance(i), i,
            attsToOutput, printDist));
      }
    }

    return predictions;
  }
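
  /**
   * Minimal usage sketch (added for documentation; not part of the original
   * Weka API): exercises the varargs contract of evaluateModel above and
   * returns the rendered predictions. The method name is a hypothetical
   * helper; the range string "first-last" simply echoes all attribute
   * values with each prediction.
   */
  private String sketchEvaluateWithPredictions(Classifier classifier,
      Instances test) throws Exception {
    StringBuffer preds = new StringBuffer();
    Range attsToOutput = new Range("first-last");
    evaluateModel(classifier, test, preds, attsToOutput, Boolean.TRUE);
    return preds.toString();
  }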

  /**
   * Evaluates the classifier on a single instance and records the
   * prediction (if the class is nominal).
   *
   * @param classifier machine learning classifier
   * @param instance the test instance to be classified
   * @return the prediction made by the classifier
   * @throws Exception if model could not be evaluated
   * successfully or the data contains string attributes
   */
  public double evaluateModelOnceAndRecordPrediction(Classifier classifier,
      Instance instance) throws Exception {

    // Hide the class value from the classifier so it cannot peek at the label
    Instance classMissing = (Instance) instance.copy();
    double pred = 0;
    classMissing.setDataset(instance.dataset());
    classMissing.setClassMissing();
    if (m_ClassIsNominal) {
      if (m_Predictions == null) {
        m_Predictions = new FastVector();
      }
      double[] dist = classifier.distributionForInstance(classMissing);
      pred = Utils.maxIndex(dist);
      if (dist[(int) pred] <= 0) {
        pred = Instance.missingValue();
      }
      updateStatsForClassifier(dist, instance);
      m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist,
          instance.weight()));
    } else {
      pred = classifier.classifyInstance(classMissing);
      updateStatsForPredictor(pred, instance);
    }
    return pred;
  }

  /**
   * Evaluates the classifier on a single instance.
   *
   * @param classifier machine learning classifier
   * @param instance the test instance to be classified
   * @return the prediction made by the classifier
   * @throws Exception if model could not be evaluated
   * successfully or the data contains string attributes
   */
  public double evaluateModelOnce(Classifier classifier,
      Instance instance) throws Exception {

    // Hide the class value from the classifier
    Instance classMissing = (Instance) instance.copy();
    double pred = 0;
    classMissing.setDataset(instance.dataset());
    classMissing.setClassMissing();
    if (m_ClassIsNominal) {
      double[] dist = classifier.distributionForInstance(classMissing);
      pred = Utils.maxIndex(dist);
      if (dist[(int) pred] <= 0) {
        pred = Instance.missingValue();
      }
      updateStatsForClassifier(dist, instance);
    } else {
      pred = classifier.classifyInstance(classMissing);
      updateStatsForPredictor(pred, instance);
    }
    return pred;
  }

  /**
   * Evaluates the supplied distribution on a single instance.
   *
   * @param dist the supplied distribution
   * @param instance the test instance to be classified
   * @return the prediction
   * @throws Exception if model could not be evaluated
   * successfully
   */
  public double evaluateModelOnce(double[] dist,
      Instance instance) throws Exception {
    double pred;
    if (m_ClassIsNominal) {
      pred = Utils.maxIndex(dist);
      if (dist[(int) pred] <= 0) {
        pred = Instance.missingValue();
      }
      updateStatsForClassifier(dist, instance);
    } else {
      pred = dist[0];
      updateStatsForPredictor(pred, instance);
    }
    return pred;
  }

  /**
   * Evaluates the supplied distribution on a single instance and records
   * the prediction (if the class is nominal).
   *
   * @param dist the supplied distribution
   * @param instance the test instance to be classified
   * @return the prediction
   * @throws Exception if model could not be evaluated
   * successfully
   */
  public double evaluateModelOnceAndRecordPrediction(double[] dist,
      Instance instance) throws Exception {
    double pred;
    if (m_ClassIsNominal) {
      if (m_Predictions == null) {
        m_Predictions = new FastVector();
      }
      pred = Utils.maxIndex(dist);
      if (dist[(int) pred] <= 0) {
        pred = Instance.missingValue();
      }
      updateStatsForClassifier(dist, instance);
      m_Predictions.addElement(new NominalPrediction(instance.classValue(), dist,
          instance.weight()));
    } else {
      pred = dist[0];
      updateStatsForPredictor(pred, instance);
    }
    return pred;
  }

  /**
   * Evaluates the supplied prediction on a single instance.
   *
   * @param prediction the supplied prediction
   * @param instance the test instance to be classified
   * @throws Exception if model could not be evaluated
   * successfully
   */
  public void evaluateModelOnce(double prediction,
      Instance instance) throws Exception {

    if (m_ClassIsNominal) {
      updateStatsForClassifier(makeDistribution(prediction), instance);
    } else {
      updateStatsForPredictor(prediction, instance);
    }
  }

  /**
   * Returns the predictions that have been collected.
   *
   * @return a reference to the FastVector containing the predictions
   * that have been collected. This should be null if no predictions
   * have been collected (e.g. if the class is numeric).
   */
  public FastVector predictions() {

    return m_Predictions;
  }

  /**
   * Wraps a static classifier in enough source to test using the weka
   * class libraries.
   *
   * @param classifier a Sourcable Classifier
   * @param className the name to give to the source code class
   * @return the source for a static classifier that can be tested with
   * weka libraries.
   * @throws Exception if code-generation fails
   */
  public static String wekaStaticWrapper(Sourcable classifier, String className)
      throws Exception {

    StringBuffer result = new StringBuffer();
    String staticClassifier = classifier.toSource(className);

    result.append("// Generated with Weka " + Version.VERSION + "\n");
    result.append("//\n");
    result.append("// This code is public domain and comes with no warranty.\n");
    result.append("//\n");
    result.append("// Timestamp: " + new Date() + "\n");
    result.append("\n");
    result.append("package weka.classifiers;\n");
    result.append("\n");
    result.append("import weka.core.Attribute;\n");
    result.append("import weka.core.Capabilities;\n");
    result.append("import weka.core.Capabilities.Capability;\n");
    result.append("import weka.core.Instance;\n");
    result.append("import weka.core.Instances;\n");
    result.append("import weka.core.RevisionUtils;\n");
    result.append("import weka.classifiers.Classifier;\n");
    result.append("\n");
    result.append("public class WekaWrapper\n");
    result.append("  extends Classifier {\n");

    // globalInfo
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Returns only the toString() method.\n");
    result.append("   *\n");
    result.append("   * @return a string describing the classifier\n");
    result.append("   */\n");
    result.append("  public String globalInfo() {\n");
    result.append("    return toString();\n");
    result.append("  }\n");

    // getCapabilities
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Returns the capabilities of this classifier.\n");
    result.append("   *\n");
    result.append("   * @return the capabilities\n");
    result.append("   */\n");
    result.append("  public Capabilities getCapabilities() {\n");
    result.append(((Classifier) classifier).getCapabilities().toSource("result", 4));
    result.append("    return result;\n");
    result.append("  }\n");

    // buildClassifier
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Only checks the data against its capabilities.\n");
    result.append("   *\n");
    result.append("   * @param i the training data\n");
    result.append("   */\n");
    result.append("  public void buildClassifier(Instances i) throws Exception {\n");
    result.append("    // can classifier handle the data?\n");
    result.append("    getCapabilities().testWithFail(i);\n");
    result.append("  }\n");

    // classifyInstance
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Classifies the given instance.\n");
    result.append("   *\n");
    result.append("   * @param i the instance to classify\n");
    result.append("   * @return the classification result\n");
    result.append("   */\n");
    result.append("  public double classifyInstance(Instance i) throws Exception {\n");
    result.append("    Object[] s = new Object[i.numAttributes()];\n");
    result.append("    \n");
    result.append("    for (int j = 0; j < s.length; j++) {\n");
    result.append("      if (!i.isMissing(j)) {\n");
    result.append("        if (i.attribute(j).isNominal())\n");
    result.append("          s[j] = new String(i.stringValue(j));\n");
    result.append("        else if (i.attribute(j).isNumeric())\n");
    result.append("          s[j] = new Double(i.value(j));\n");
    result.append("      }\n");
    result.append("    }\n");
    result.append("    \n");
    result.append("    // set class value to missing\n");
    result.append("    s[i.classIndex()] = null;\n");
    result.append("    \n");
    result.append("    return " + className + ".classify(s);\n");
    result.append("  }\n");

    // getRevision
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Returns the revision string.\n");
    result.append("   * \n");
    result.append("   * @return        the revision\n");
    result.append("   */\n");
    result.append("  public String getRevision() {\n");
    result.append("    return RevisionUtils.extract(\"1.0\");\n");
    result.append("  }\n");

    // toString
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Returns only the classnames and what classifier it is based on.\n");
    result.append("   *\n");
    result.append("   * @return a short description\n");
    result.append("   */\n");
    result.append("  public String toString() {\n");
    result.append("    return \"Auto-generated classifier wrapper, based on "
        + classifier.getClass().getName() + " (generated with Weka " + Version.VERSION + ").\\n"
        + "\" + this.getClass().getName() + \"/" + className + "\";\n");
    result.append("  }\n");

    // main
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Runs the classifier from the command line.\n");
    result.append("   *\n");
    result.append("   * @param args the command-line arguments\n");
    result.append("   */\n");
    result.append("  public static void main(String[] args) {\n");
    result.append("    runClassifier(new WekaWrapper(), args);\n");
    result.append("  }\n");
    result.append("}\n");

    // actual classifier code
    result.append("\n");
    result.append(staticClassifier);

    return result.toString();
  }
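
  /**
   * Minimal usage sketch (added for documentation; not part of the original
   * Weka API): generates wrapper source for a classifier that implements
   * Sourcable and prints it to standard out. "WekaClassifier" is an
   * arbitrary class name chosen for the generated code, and the method
   * name is a hypothetical helper.
   */
  private static void sketchPrintWrapperSource(Sourcable classifier) throws Exception {
    System.out.println(wekaStaticWrapper(classifier, "WekaClassifier"));
  }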

  /**
   * Gets the number of test instances that had a known class value
   * (actually the sum of the weights of test instances with known
   * class value).
   *
   * @return the number of test instances with known class
   */
  public final double numInstances() {

    return m_WithClass;
  }

  /**
   * Gets the number of instances incorrectly classified (that is, for
   * which an incorrect prediction was made). (Actually the sum of the weights
   * of these instances.)
   *
   * @return the number of incorrectly classified instances
   */
  public final double incorrect() {

    return m_Incorrect;
  }

  /**
   * Gets the percentage of instances incorrectly classified (that is, for
   * which an incorrect prediction was made).
   *
   * @return the percent of incorrectly classified instances
   * (between 0 and 100)
   */
  public final double pctIncorrect() {

    return 100 * m_Incorrect / m_WithClass;
  }

  /**
   * Gets the total cost, that is, the cost of each prediction times the
   * weight of the instance, summed over all instances.
   *
   * @return the total cost
   */
  public final double totalCost() {

    return m_TotalCost;
  }

  /**
   * Gets the average cost, that is, total cost of misclassifications
   * (incorrect plus unclassified) over the total number of instances.
   *
   * @return the average cost.
   */
  public final double avgCost() {

    return m_TotalCost / m_WithClass;
  }

  /**
   * Gets the number of instances correctly classified (that is, for
   * which a correct prediction was made). (Actually the sum of the weights
   * of these instances.)
   *
   * @return the number of correctly classified instances
   */
  public final double correct() {

    return m_Correct;
  }

  /**
   * Gets the percentage of instances correctly classified (that is, for
   * which a correct prediction was made).
   *
   * @return the percent of correctly classified instances (between 0 and 100)
   */
  public final double pctCorrect() {

    return 100 * m_Correct / m_WithClass;
  }

  /**
   * Gets the number of instances not classified (that is, for
   * which no prediction was made by the classifier). (Actually the sum
   * of the weights of these instances.)
   *
   * @return the number of unclassified instances
   */
  public final double unclassified() {

    return m_Unclassified;
  }

  /**
   * Gets the percentage of instances not classified (that is, for
   * which no prediction was made by the classifier).
   *
   * @return the percent of unclassified instances (between 0 and 100)
   */
  public final double pctUnclassified() {

    return 100 * m_Unclassified / m_WithClass;
  }

  /**
   * Returns the estimated error rate or the root mean squared error
   * (if the class is numeric). If a cost matrix was given, this
   * error rate gives the average cost.
   *
   * @return the estimated error rate (between 0 and 1, or between 0 and
   * maximum cost)
   */
  public final double errorRate() {

    if (!m_ClassIsNominal) {
      return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
    }
    if (m_CostMatrix == null) {
      return m_Incorrect / m_WithClass;
    } else {
      return avgCost();
    }
  }

  /**
   * Returns the value of the kappa statistic if the class is nominal.
   *
   * @return the value of the kappa statistic
   */
  public final double kappa() {

    double[] sumRows = new double[m_ConfusionMatrix.length];
    double[] sumColumns = new double[m_ConfusionMatrix.length];
    double sumOfWeights = 0;
    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
      for (int j = 0; j < m_ConfusionMatrix.length; j++) {
        sumRows[i] += m_ConfusionMatrix[i][j];
        sumColumns[j] += m_ConfusionMatrix[i][j];
        sumOfWeights += m_ConfusionMatrix[i][j];
      }
    }
    double correct = 0, chanceAgreement = 0;
    for (int i = 0; i < m_ConfusionMatrix.length; i++) {
      chanceAgreement += (sumRows[i] * sumColumns[i]);
      correct += m_ConfusionMatrix[i][i];
    }
    chanceAgreement /= (sumOfWeights * sumOfWeights);
    correct /= sumOfWeights;

    if (chanceAgreement < 1) {
      return (correct - chanceAgreement) / (1 - chanceAgreement);
    } else {
      return 1;
    }
  }
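
  // Worked illustration of kappa() (documentation only): for the 2x2
  // confusion matrix {{40, 10}, {20, 30}}, observed agreement is
  // (40 + 30) / 100 = 0.70 and chance agreement is
  // (50*60 + 50*40) / 100^2 = 0.50, so kappa = (0.70 - 0.50) / (1 - 0.50) = 0.40.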

  /**
   * Returns the correlation coefficient if the class is numeric.
   *
   * @return the correlation coefficient
   * @throws Exception if class is not numeric
   */
  public final double correlationCoefficient() throws Exception {

    if (m_ClassIsNominal) {
      throw new Exception("Can't compute correlation coefficient: "
          + "class is nominal!");
    }

    double correlation = 0;
    double varActual = m_SumSqrClass - m_SumClass * m_SumClass
        / (m_WithClass - m_Unclassified);
    double varPredicted = m_SumSqrPredicted - m_SumPredicted * m_SumPredicted
        / (m_WithClass - m_Unclassified);
    double varProd = m_SumClassPredicted - m_SumClass * m_SumPredicted
        / (m_WithClass - m_Unclassified);

    if (varActual * varPredicted <= 0) {
      correlation = 0.0;
    } else {
      correlation = varProd / Math.sqrt(varActual * varPredicted);
    }

    return correlation;
  }
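
  // Documentation note: varActual, varPredicted and varProd above are the
  // centred sums of squared deviations and cross-deviations, so the value
  // returned is the sample correlation r = S_ap / sqrt(S_aa * S_pp) between
  // actual and predicted class values.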

  /**
   * Returns the mean absolute error. Refers to the error of the
   * predicted values for numeric classes, and the error of the
   * predicted probability distribution for nominal classes.
   *
   * @return the mean absolute error
   */
  public final double meanAbsoluteError() {

    return m_SumAbsErr / (m_WithClass - m_Unclassified);
  }

  /**
   * Returns the mean absolute error of the prior.
   *
   * @return the mean absolute error
   */
  public final double meanPriorAbsoluteError() {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return m_SumPriorAbsErr / m_WithClass;
  }

  /**
   * Returns the relative absolute error.
   *
   * @return the relative absolute error
   * @throws Exception if it can't be computed
   */
  public final double relativeAbsoluteError() throws Exception {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return 100 * meanAbsoluteError() / meanPriorAbsoluteError();
  }

  /**
   * Returns the root mean squared error.
   *
   * @return the root mean squared error
   */
  public final double rootMeanSquaredError() {

    return Math.sqrt(m_SumSqrErr / (m_WithClass - m_Unclassified));
  }

  /**
   * Returns the root mean prior squared error.
   *
   * @return the root mean prior squared error
   */
  public final double rootMeanPriorSquaredError() {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return Math.sqrt(m_SumPriorSqrErr / m_WithClass);
  }

  /**
   * Returns the root relative squared error if the class is numeric.
   *
   * @return the root relative squared error
   */
  public final double rootRelativeSquaredError() {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return 100.0 * rootMeanSquaredError() / rootMeanPriorSquaredError();
  }

  /**
   * Calculate the entropy of the prior distribution.
   *
   * @return the entropy of the prior distribution
   * @throws Exception if the class is not nominal
   */
  public final double priorEntropy() throws Exception {

    if (!m_ClassIsNominal) {
      throw new Exception("Can't compute entropy of class prior: "
          + "class numeric!");
    }

    if (m_NoPriors) {
      return Double.NaN;
    }

    double entropy = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      entropy -= m_ClassPriors[i] / m_ClassPriorsSum
          * Utils.log2(m_ClassPriors[i] / m_ClassPriorsSum);
    }
    return entropy;
  }
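
  // Illustration (documentation only): with two equally likely classes the
  // prior entropy above is -(0.5 * log2(0.5) + 0.5 * log2(0.5)) = 1 bit.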

  /**
   * Return the total Kononenko & Bratko Information score in bits.
   *
   * @return the K&B information score
   * @throws Exception if the class is not nominal
   */
  public final double KBInformation() throws Exception {

    if (!m_ClassIsNominal) {
      throw new Exception("Can't compute K&B Info score: "
          + "class numeric!");
    }

    if (m_NoPriors) {
      return Double.NaN;
    }

    return m_SumKBInfo;
  }

  /**
   * Return the Kononenko & Bratko Information score in bits per
   * instance.
   *
   * @return the K&B information score
   * @throws Exception if the class is not nominal
   */
  public final double KBMeanInformation() throws Exception {

    if (!m_ClassIsNominal) {
      throw new Exception("Can't compute K&B Info score: "
          + "class numeric!");
    }

    if (m_NoPriors) {
      return Double.NaN;
    }

    return m_SumKBInfo / (m_WithClass - m_Unclassified);
  }

  /**
   * Return the Kononenko & Bratko Relative Information score.
   *
   * @return the K&B relative information score
   * @throws Exception if the class is not nominal
   */
  public final double KBRelativeInformation() throws Exception {

    if (!m_ClassIsNominal) {
      throw new Exception("Can't compute K&B Info score: "
          + "class numeric!");
    }

    if (m_NoPriors) {
      return Double.NaN;
    }

    return 100.0 * KBInformation() / priorEntropy();
  }

  /**
   * Returns the total entropy for the null model.
   *
   * @return the total null model entropy
   */
  public final double SFPriorEntropy() {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return m_SumPriorEntropy;
  }

  /**
   * Returns the entropy per instance for the null model.
   *
   * @return the null model entropy per instance
   */
  public final double SFMeanPriorEntropy() {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return m_SumPriorEntropy / m_WithClass;
  }

  /**
   * Returns the total entropy for the scheme.
   *
   * @return the total scheme entropy
   */
  public final double SFSchemeEntropy() {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return m_SumSchemeEntropy;
  }

  /**
   * Returns the entropy per instance for the scheme.
   *
   * @return the scheme entropy per instance
   */
  public final double SFMeanSchemeEntropy() {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return m_SumSchemeEntropy / (m_WithClass - m_Unclassified);
  }

  /**
   * Returns the total SF, which is the null model entropy minus
   * the scheme entropy.
   *
   * @return the total SF
   */
  public final double SFEntropyGain() {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return m_SumPriorEntropy - m_SumSchemeEntropy;
  }

  /**
   * Returns the SF per instance, which is the null model entropy
   * minus the scheme entropy, per instance.
   *
   * @return the SF per instance
   */
  public final double SFMeanEntropyGain() {

    if (m_NoPriors) {
      return Double.NaN;
    }

    return (m_SumPriorEntropy - m_SumSchemeEntropy)
        / (m_WithClass - m_Unclassified);
  }

  /**
   * Output the cumulative margin distribution as a string suitable
   * for input to gnuplot or a similar package.
   *
   * @return the cumulative margin distribution
   * @throws Exception if the class attribute is not nominal
   */
  public String toCumulativeMarginDistributionString() throws Exception {

    if (!m_ClassIsNominal) {
      throw new Exception("Class must be nominal for margin distributions");
    }
    String result = "";
    double cumulativeCount = 0;
    double margin;
    for (int i = 0; i <= k_MarginResolution; i++) {
      if (m_MarginCounts[i] != 0) {
        cumulativeCount += m_MarginCounts[i];
        margin = (double) i * 2.0 / k_MarginResolution - 1.0;
        result = result + Utils.doubleToString(margin, 7, 3) + ' '
            + Utils.doubleToString(cumulativeCount * 100
                / m_WithClass, 7, 3) + '\n';
      } else if (i == 0) {
        result = Utils.doubleToString(-1.0, 7, 3) + ' '
            + Utils.doubleToString(0, 7, 3) + '\n';
      }
    }
    return result;
  }

  /**
   * Calls toSummaryString() with no title and no complexity stats.
   *
   * @return a summary description of the classifier evaluation
   */
  public String toSummaryString() {

    return toSummaryString("", false);
  }

  /**
   * Calls toSummaryString() with a default title.
   *
   * @param printComplexityStatistics if true, complexity statistics are
   * returned as well
   * @return the summary string
   */
  public String toSummaryString(boolean printComplexityStatistics) {

    return toSummaryString("=== Summary ===\n", printComplexityStatistics);
  }

  /**
   * Outputs the performance statistics in summary form. Lists
   * the number (and percentage) of instances classified correctly,
   * incorrectly and unclassified. Outputs the total number of
   * instances classified, and the number of instances (if any)
   * that had no class value provided.
   *
   * @param title the title for the statistics
   * @param printComplexityStatistics if true, complexity statistics are
   * returned as well
   * @return the summary as a String
   */
  public String toSummaryString(String title,
      boolean printComplexityStatistics) {

    StringBuffer text = new StringBuffer();

    if (printComplexityStatistics && m_NoPriors) {
      printComplexityStatistics = false;
      System.err.println("Priors disabled, cannot print complexity statistics!");
    }

    text.append(title + "\n");
    try {
      if (m_WithClass > 0) {
        if (m_ClassIsNominal) {

          text.append("Correctly Classified Instances     ");
          text.append(Utils.doubleToString(correct(), 12, 4) + "     "
              + Utils.doubleToString(pctCorrect(), 12, 4) + " %\n");
          text.append("Incorrectly Classified Instances   ");
          text.append(Utils.doubleToString(incorrect(), 12, 4) + "     "
              + Utils.doubleToString(pctIncorrect(), 12, 4) + " %\n");
          text.append("Kappa statistic                    ");
          text.append(Utils.doubleToString(kappa(), 12, 4) + "\n");

          if (m_CostMatrix != null) {
            text.append("Total Cost                         ");
            text.append(Utils.doubleToString(totalCost(), 12, 4) + "\n");
            text.append("Average Cost                       ");
            text.append(Utils.doubleToString(avgCost(), 12, 4) + "\n");
          }
          if (printComplexityStatistics) {
            text.append("K&B Relative Info Score            ");
            text.append(Utils.doubleToString(KBRelativeInformation(), 12, 4)
                + " %\n");
            text.append("K&B Information Score              ");
            text.append(Utils.doubleToString(KBInformation(), 12, 4)
                + " bits");
            text.append(Utils.doubleToString(KBMeanInformation(), 12, 4)
                + " bits/instance\n");
          }
        } else {
          text.append("Correlation coefficient            ");
          text.append(Utils.doubleToString(correlationCoefficient(), 12, 4)
              + "\n");
        }
        if (printComplexityStatistics) {
          text.append("Class complexity | order 0         ");
          text.append(Utils.doubleToString(SFPriorEntropy(), 12, 4)
              + " bits");
          text.append(Utils.doubleToString(SFMeanPriorEntropy(), 12, 4)
              + " bits/instance\n");
          text.append("Class complexity | scheme          ");
          text.append(Utils.doubleToString(SFSchemeEntropy(), 12, 4)
              + " bits");
          text.append(Utils.doubleToString(SFMeanSchemeEntropy(), 12, 4)
              + " bits/instance\n");
          text.append("Complexity improvement     (Sf)    ");
          text.append(Utils.doubleToString(SFEntropyGain(), 12, 4) + " bits");
          text.append(Utils.doubleToString(SFMeanEntropyGain(), 12, 4)
              + " bits/instance\n");
        }

        text.append("Mean absolute error                ");
        text.append(Utils.doubleToString(meanAbsoluteError(), 12, 4)
            + "\n");
        text.append("Root mean squared error            ");
        text.append(Utils.doubleToString(rootMeanSquaredError(), 12, 4)
            + "\n");
        if (!m_NoPriors) {
          text.append("Relative absolute error            ");
          text.append(Utils.doubleToString(relativeAbsoluteError(),
              12, 4) + " %\n");
          text.append("Root relative squared error        ");
          text.append(Utils.doubleToString(rootRelativeSquaredError(),
              12, 4) + " %\n");
        }
      }
      if (Utils.gr(unclassified(), 0)) {
        text.append("UnClassified Instances             ");
        text.append(Utils.doubleToString(unclassified(), 12, 4) + "     "
            + Utils.doubleToString(pctUnclassified(), 12, 4) + " %\n");
      }
      text.append("Total Number of Instances          ");
      text.append(Utils.doubleToString(m_WithClass, 12, 4) + "\n");
      if (m_MissingClass > 0) {
        text.append("Ignored Class Unknown Instances            ");
        text.append(Utils.doubleToString(m_MissingClass, 12, 4) + "\n");
      }
    } catch (Exception ex) {
      // Should never occur since the class is known to be nominal here
      System.err.println("Arggh - Must be a bug in Evaluation class");
    }

    return text.toString();
  }

  /**
   * Calls toMatrixString() with a default title.
   *
   * @return the confusion matrix as a string
   * @throws Exception if the class is numeric
   */
  public String toMatrixString() throws Exception {

    return toMatrixString("=== Confusion Matrix ===\n");
  }

  /**
   * Outputs the performance statistics as a classification confusion
   * matrix. For each class value, shows the distribution of
   * predicted class values.
   *
   * @param title the title for the confusion matrix
   * @return the confusion matrix as a String
   * @throws Exception if the class is numeric
   */
  public String toMatrixString(String title) throws Exception {

    StringBuffer text = new StringBuffer();
    char[] IDChars = {'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
        'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',
        'u', 'v', 'w', 'x', 'y', 'z'};
    int IDWidth;
    boolean fractional = false;

    if (!m_ClassIsNominal) {
      throw new Exception("Evaluation: No confusion matrix possible!");
    }

    // Find the maximum value in the matrix
    // and check for fractional display requirement
    double maxval = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        double current = m_ConfusionMatrix[i][j];
        if (current < 0) {
          current *= -10;
        }
        if (current > maxval) {
          maxval = current;
        }
        double fract = current - Math.rint(current);
        if (!fractional
            && ((Math.log(fract) / Math.log(10)) >= -2)) {
          fractional = true;
        }
      }
    }

    IDWidth = 1 + Math.max((int) (Math.log(maxval) / Math.log(10)
        + (fractional ? 3 : 0)),
        (int) (Math.log(m_NumClasses) / Math.log(IDChars.length)));
    text.append(title).append("\n");
    for (int i = 0; i < m_NumClasses; i++) {
      if (fractional) {
        text.append(" ").append(num2ShortID(i, IDChars, IDWidth - 3))
            .append("   ");
      } else {
        text.append(" ").append(num2ShortID(i, IDChars, IDWidth));
      }
    }
    text.append("   <-- classified as\n");
    for (int i = 0; i < m_NumClasses; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        text.append(" ").append(
            Utils.doubleToString(m_ConfusionMatrix[i][j],
                IDWidth,
                (fractional ? 2 : 0)));
      }
      text.append(" | ").append(num2ShortID(i, IDChars, IDWidth))
          .append(" = ").append(m_ClassNames[i]).append("\n");
    }
    return text.toString();
  }
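
  // Example of the layout produced above (illustrative counts and class
  // names only):
  //     a  b   <-- classified as
  //    40 10 |  a = yes
  //    20 30 |  b = no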

  /**
   * Generates a breakdown of the accuracy for each class (with default title),
   * incorporating various information-retrieval statistics, such as
   * true/false positive rate, precision/recall/F-Measure. Should be
   * useful for ROC curves, recall/precision curves.
   *
   * @return the statistics presented as a string
   * @throws Exception if class is not nominal
   */
  public String toClassDetailsString() throws Exception {

    return toClassDetailsString("=== Detailed Accuracy By Class ===\n");
  }

  /**
   * Generates a breakdown of the accuracy for each class,
   * incorporating various information-retrieval statistics, such as
   * true/false positive rate, precision/recall/F-Measure. Should be
   * useful for ROC curves, recall/precision curves.
   *
   * @param title the title to prepend the stats string with
   * @return the statistics presented as a string
   * @throws Exception if class is not nominal
   */
  public String toClassDetailsString(String title) throws Exception {

    if (!m_ClassIsNominal) {
      throw new Exception("Evaluation: No confusion matrix possible!");
    }

    StringBuffer text = new StringBuffer(title
        + "\n               TP Rate   FP Rate"
        + "   Precision   Recall"
        + "  F-Measure   ROC Area  Class\n");
    for (int i = 0; i < m_NumClasses; i++) {
      text.append("               " + Utils.doubleToString(truePositiveRate(i), 7, 3))
          .append("   ");
      text.append(Utils.doubleToString(falsePositiveRate(i), 7, 3))
          .append("    ");
      text.append(Utils.doubleToString(precision(i), 7, 3))
          .append("   ");
      text.append(Utils.doubleToString(recall(i), 7, 3))
          .append("   ");
      text.append(Utils.doubleToString(fMeasure(i), 7, 3))
          .append("    ");

      double rocVal = areaUnderROC(i);
      if (Instance.isMissingValue(rocVal)) {
        text.append("  ?    ")
            .append("    ");
      } else {
        text.append(Utils.doubleToString(rocVal, 7, 3))
            .append("    ");
      }
      text.append(m_ClassNames[i]).append('\n');
    }

    text.append("Weighted Avg.  " + Utils.doubleToString(weightedTruePositiveRate(), 7, 3));
    text.append("   " + Utils.doubleToString(weightedFalsePositiveRate(), 7, 3));
    text.append("    " + Utils.doubleToString(weightedPrecision(), 7, 3));
    text.append("   " + Utils.doubleToString(weightedRecall(), 7, 3));
    text.append("   " + Utils.doubleToString(weightedFMeasure(), 7, 3));
    text.append("    " + Utils.doubleToString(weightedAreaUnderROC(), 7, 3));
    text.append("\n");

    return text.toString();
  }

  /**
   * Calculate the number of true positives with respect to a particular class.
   * This is defined as<p/>
   * <pre>
   * correctly classified positives
   * </pre>
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the number of true positives
   */
  public double numTruePositives(int classIndex) {

    double correct = 0;
    for (int j = 0; j < m_NumClasses; j++) {
      if (j == classIndex) {
        correct += m_ConfusionMatrix[classIndex][j];
      }
    }
    return correct;
  }

  /**
   * Calculate the true positive rate with respect to a particular class.
   * This is defined as<p/>
   * <pre>
   * correctly classified positives
   * ------------------------------
   *       total positives
   * </pre>
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the true positive rate
   */
  public double truePositiveRate(int classIndex) {

    double correct = 0, total = 0;
    for (int j = 0; j < m_NumClasses; j++) {
      if (j == classIndex) {
        correct += m_ConfusionMatrix[classIndex][j];
      }
      total += m_ConfusionMatrix[classIndex][j];
    }
    if (total == 0) {
      return 0;
    }
    return correct / total;
  }

  /**
   * Calculates the weighted (by class size) true positive rate.
   *
   * @return the weighted true positive rate.
   */
  public double weightedTruePositiveRate() {
    double[] classCounts = new double[m_NumClasses];
    double classCountSum = 0;

    for (int i = 0; i < m_NumClasses; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        classCounts[i] += m_ConfusionMatrix[i][j];
      }
      classCountSum += classCounts[i];
    }

    double truePosTotal = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      double temp = truePositiveRate(i);
      truePosTotal += (temp * classCounts[i]);
    }

    return truePosTotal / classCountSum;
  }
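
  // Worked illustration (documentation only), using the 2x2 matrix
  // {{40, 10}, {20, 30}} with class 0 as "positive": truePositiveRate(0)
  // = 40/50 = 0.8, truePositiveRate(1) = 30/50 = 0.6, and the class-size
  // weighted rate is (0.8*50 + 0.6*50) / 100 = 0.7.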

  /**
   * Calculate the number of true negatives with respect to a particular class.
   * This is defined as<p/>
   * <pre>
   * correctly classified negatives
   * </pre>
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the number of true negatives
   */
  public double numTrueNegatives(int classIndex) {

    double correct = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      if (i != classIndex) {
        for (int j = 0; j < m_NumClasses; j++) {
          if (j != classIndex) {
            correct += m_ConfusionMatrix[i][j];
          }
        }
      }
    }
    return correct;
  }

  /**
   * Calculate the true negative rate with respect to a particular class.
   * This is defined as<p/>
   * <pre>
   * correctly classified negatives
   * ------------------------------
   *       total negatives
   * </pre>
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the true negative rate
   */
  public double trueNegativeRate(int classIndex) {

    double correct = 0, total = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      if (i != classIndex) {
        for (int j = 0; j < m_NumClasses; j++) {
          if (j != classIndex) {
            correct += m_ConfusionMatrix[i][j];
          }
          total += m_ConfusionMatrix[i][j];
        }
      }
    }
    if (total == 0) {
      return 0;
    }
    return correct / total;
  }

  /**
   * Calculates the weighted (by class size) true negative rate.
   *
   * @return the weighted true negative rate.
   */
  public double weightedTrueNegativeRate() {
    double[] classCounts = new double[m_NumClasses];
    double classCountSum = 0;

    for (int i = 0; i < m_NumClasses; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        classCounts[i] += m_ConfusionMatrix[i][j];
      }
      classCountSum += classCounts[i];
    }

    double trueNegTotal = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      double temp = trueNegativeRate(i);
      trueNegTotal += (temp * classCounts[i]);
    }

    return trueNegTotal / classCountSum;
  }

  /**
   * Calculate the number of false positives with respect to a particular class.
   * This is defined as<p/>
   * <pre>
   * incorrectly classified negatives
   * </pre>
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the number of false positives
   */
  public double numFalsePositives(int classIndex) {

    double incorrect = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      if (i != classIndex) {
        for (int j = 0; j < m_NumClasses; j++) {
          if (j == classIndex) {
            incorrect += m_ConfusionMatrix[i][j];
          }
        }
      }
    }
    return incorrect;
  }

  /**
   * Calculate the false positive rate with respect to a particular class.
   * This is defined as<p/>
   * <pre>
   * incorrectly classified negatives
   * --------------------------------
   *        total negatives
   * </pre>
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the false positive rate
   */
  public double falsePositiveRate(int classIndex) {

    double incorrect = 0, total = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      if (i != classIndex) {
        for (int j = 0; j < m_NumClasses; j++) {
          if (j == classIndex) {
            incorrect += m_ConfusionMatrix[i][j];
          }
          total += m_ConfusionMatrix[i][j];
        }
      }
    }
    if (total == 0) {
      return 0;
    }
    return incorrect / total;
  }

  /**
   * Calculates the weighted (by class size) false positive rate.
   *
   * @return the weighted false positive rate.
   */
  public double weightedFalsePositiveRate() {
    double[] classCounts = new double[m_NumClasses];
    double classCountSum = 0;

    for (int i = 0; i < m_NumClasses; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        classCounts[i] += m_ConfusionMatrix[i][j];
      }
      classCountSum += classCounts[i];
    }

    double falsePosTotal = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      double temp = falsePositiveRate(i);
      falsePosTotal += (temp * classCounts[i]);
    }

    return falsePosTotal / classCountSum;
  }

  /**
   * Calculate the number of false negatives with respect to a particular class.
   * This is defined as<p/>
   * <pre>
   * incorrectly classified positives
   * </pre>
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the number of false negatives
   */
  public double numFalseNegatives(int classIndex) {

    double incorrect = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      if (i == classIndex) {
        for (int j = 0; j < m_NumClasses; j++) {
          if (j != classIndex) {
            incorrect += m_ConfusionMatrix[i][j];
          }
        }
      }
    }
    return incorrect;
  }

  /**
   * Calculate the false negative rate with respect to a particular class.
   * This is defined as<p/>
   * <pre>
   * incorrectly classified positives
   * --------------------------------
   *        total positives
   * </pre>
   *
   * @param classIndex the index of the class to consider as "positive"
   * @return the false negative rate
   */
  public double falseNegativeRate(int classIndex) {

    double incorrect = 0, total = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      if (i == classIndex) {
        for (int j = 0; j < m_NumClasses; j++) {
          if (j != classIndex) {
            incorrect += m_ConfusionMatrix[i][j];
          }
          total += m_ConfusionMatrix[i][j];
        }
      }
    }
    if (total == 0) {
      return 0;
    }
    return incorrect / total;
  }

  /**
   * Calculates the weighted (by class size) false negative rate.
   *
   * @return the weighted false negative rate.
   */
  public double weightedFalseNegativeRate() {
    double[] classCounts = new double[m_NumClasses];
    double classCountSum = 0;

    for (int i = 0; i < m_NumClasses; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        classCounts[i] += m_ConfusionMatrix[i][j];
      }
      classCountSum += classCounts[i];
    }

    double falseNegTotal = 0;
    for (int i = 0; i < m_NumClasses; i++) {
      double temp = falseNegativeRate(i);
      falseNegTotal += (temp * classCounts[i]);
    }

    return falseNegTotal / classCountSum;
  }
  2779.  
  2780.   /**
  2781.    * Calculate the recall with respect to a particular class.
  2782.    * This is defined as<p/>
  2783.    * <pre>
  2784.    * correctly classified positives
  2785.    * ------------------------------
  2786.    *       total positives
  2787.    * </pre><p/>
  2788.    * (Which is also the same as the truePositiveRate.)
  2789.    *
  2790.    * @param classIndex the index of the class to consider as "positive"
  2791.    * @return the recall
  2792.    */
  2793.   public double recall(int classIndex) {
  2794.  
  2795.     return truePositiveRate(classIndex);
  2796.   }
  2797.  
  2798.   /**
  2799.    * Calculates the weighted (by class size) recall.
  2800.    *
  2801.    * @return the weighted recall.
  2802.    */
  2803.   public double weightedRecall() {
  2804.     return weightedTruePositiveRate();
  2805.   }
  2806.  
  2807.   /**
  2808.    * Calculate the precision with respect to a particular class.
  2809.    * This is defined as<p/>
  2810.    * <pre>
  2811.    * correctly classified positives
  2812.    * ------------------------------
  2813.    *  total predicted as positive
  2814.    * </pre>
  2815.    *
  2816.    * @param classIndex the index of the class to consider as "positive"
  2817.    * @return the precision
  2818.    */
  2819.   public double precision(int classIndex) {
  2820.  
  2821.     double correct = 0, total = 0;
  2822.     for (int i = 0; i < m_NumClasses; i++) {
  2823.       if (i == classIndex) {
  2824.         correct += m_ConfusionMatrix[i][classIndex];
  2825.       }
  2826.       total += m_ConfusionMatrix[i][classIndex];
  2827.     }
  2828.     if (total == 0) {
  2829.       return 0;
  2830.     }
  2831.     return correct / total;
  2832.   }
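  // For the hypothetical matrix above, precision(0) uses column a: 5 correctly
  // classified positives out of 5 + 0 + 1 = 6 instances predicted as class a,
  // so precision(0) = 5/6, roughly 0.833.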
  2833.  
  2834.   /**
  2835.    * Calculates the weighted (by class size) precision.
  2836.    *
  2837.    * @return the weighted precision.
  2838.    */
  2839.   public double weightedPrecision() {
  2840.     double[] classCounts = new double[m_NumClasses];
  2841.     double classCountSum = 0;
  2842.    
  2843.     for (int i = 0; i < m_NumClasses; i++) {
  2844.       for (int j = 0; j < m_NumClasses; j++) {
  2845.         classCounts[i] += m_ConfusionMatrix[i][j];
  2846.       }
  2847.       classCountSum += classCounts[i];
  2848.     }
  2849.  
  2850.     double precisionTotal = 0;
  2851.     for(int i = 0; i < m_NumClasses; i++) {
  2852.       double temp = precision(i);
  2853.       precisionTotal += (temp * classCounts[i]);
  2854.     }
  2855.  
  2856.     return precisionTotal / classCountSum;
  2857.   }
  2858.  
  2859.   /**
  2860.    * Calculate the F-Measure with respect to a particular class.
  2861.    * This is defined as<p/>
  2862.    * <pre>
  2863.    * 2 * recall * precision
  2864.    * ----------------------
  2865.    *   recall + precision
  2866.    * </pre>
  2867.    *
  2868.    * @param classIndex the index of the class to consider as "positive"
  2869.    * @return the F-Measure
  2870.    */
  2871.   public double fMeasure(int classIndex) {
  2872.  
  2873.     double precision = precision(classIndex);
  2874.     double recall = recall(classIndex);
  2875.     if ((precision + recall) == 0) {
  2876.       return 0;
  2877.     }
  2878.     return 2 * precision * recall / (precision + recall);
  2879.   }
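  // Worked example with the hypothetical matrix: recall(0) = 5/8 = 0.625 and
  // precision(0) = 5/6 (about 0.833), so fMeasure(0) =
  // 2 * 0.833 * 0.625 / (0.833 + 0.625), roughly 0.714 -- the harmonic mean,
  // pulled towards the smaller of the two values.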
  2880.  
  2881.   /**
  2882.    * Calculates the weighted (by class size) F-Measure.
  2883.    *
  2884.    * @return the weighted F-Measure.
  2885.    */
  2886.   public double weightedFMeasure() {
  2887.     double[] classCounts = new double[m_NumClasses];
  2888.     double classCountSum = 0;
  2889.    
  2890.     for (int i = 0; i < m_NumClasses; i++) {
  2891.       for (int j = 0; j < m_NumClasses; j++) {
  2892.         classCounts[i] += m_ConfusionMatrix[i][j];
  2893.       }
  2894.       classCountSum += classCounts[i];
  2895.     }
  2896.  
  2897.     double fMeasureTotal = 0;
  2898.     for(int i = 0; i < m_NumClasses; i++) {
  2899.       double temp = fMeasure(i);
  2900.       fMeasureTotal += (temp * classCounts[i]);
  2901.     }
  2902.  
  2903.     return fMeasureTotal / classCountSum;
  2904.   }
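  // Minimal usage sketch for the per-class and weighted metrics above. This is
  // an illustration, not part of the class: it assumes ARFF files "train.arff"
  // and "test.arff" exist with the class as the last attribute, and omits
  // error handling (the calls throw Exception).
  //
  //   Instances train = DataSource.read("train.arff");   // hypothetical file
  //   Instances test = DataSource.read("test.arff");     // hypothetical file
  //   train.setClassIndex(train.numAttributes() - 1);
  //   test.setClassIndex(test.numAttributes() - 1);
  //   Classifier cls = new weka.classifiers.trees.J48();
  //   cls.buildClassifier(train);
  //   Evaluation eval = new Evaluation(train);
  //   eval.evaluateModel(cls, test);
  //   System.out.println(eval.fMeasure(0) + " / " + eval.weightedFMeasure());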
  2905.  
  2906.   /**
  2907.    * Sets the class prior probabilities
  2908.    *
  2909.    * @param train the training instances used to determine
  2910.    * the prior probabilities
  2911.    * @throws Exception if the class attribute of the instances is not
  2912.    * set
  2913.    */
  2914.   public void setPriors(Instances train) throws Exception {
  2915.     m_NoPriors = false;
  2916.  
  2917.     if (!m_ClassIsNominal) {
  2918.  
  2919.       m_NumTrainClassVals = 0;
  2920.       m_TrainClassVals = null;
  2921.       m_TrainClassWeights = null;
  2922.       m_PriorErrorEstimator = null;
  2923.       m_ErrorEstimator = null;
  2924.  
  2925.       for (int i = 0; i < train.numInstances(); i++) {
  2926.         Instance currentInst = train.instance(i);
  2927.         if (!currentInst.classIsMissing()) {
  2928.           addNumericTrainClass(currentInst.classValue(),
  2929.               currentInst.weight());
  2930.         }
  2931.       }
  2932.
  2933.     } else {
  2934.       for (int i = 0; i < m_NumClasses; i++) {
  2935.         m_ClassPriors[i] = 1;
  2936.       }
  2937.       m_ClassPriorsSum = m_NumClasses;
  2938.       for (int i = 0; i < train.numInstances(); i++) {
  2939.         if (!train.instance(i).classIsMissing()) {
  2940.           m_ClassPriors[(int)train.instance(i).classValue()] +=
  2941.               train.instance(i).weight();
  2942.           m_ClassPriorsSum += train.instance(i).weight();
  2943.         }
  2944.       }
  2945.     }
  2946.   }
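  // Note on the nominal branch above: each prior starts at 1 (a Laplace-style
  // correction), so with 3 classes and weighted class counts {10, 5, 0} the
  // stored priors become {11, 6, 1} and m_ClassPriorsSum = 18.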
  2947.  
  2948.   /**
  2949.    * Get the current weighted class counts
  2950.    *
  2951.    * @return the weighted class counts
  2952.    */
  2953.   public double [] getClassPriors() {
  2954.     return m_ClassPriors;
  2955.   }
  2956.  
  2957.   /**
  2958.    * Updates the class prior probabilities (when incrementally
  2959.    * training)
  2960.    *
  2961.    * @param instance the new training instance seen
  2962.    * @throws Exception if the class of the instance is not
  2963.    * set
  2964.    */
  2965.   public void updatePriors(Instance instance) throws Exception {
  2966.     if (!instance.classIsMissing()) {
  2967.       if (!m_ClassIsNominal) {
  2968.         addNumericTrainClass(instance.classValue(),
  2969.             instance.weight());
  2970.       } else {
  2971.         m_ClassPriors[(int)instance.classValue()] +=
  2972.           instance.weight();
  2973.         m_ClassPriorsSum += instance.weight();
  2974.       }
  2975.     }
  2976.   }
  2979.  
  2980.   /**
  2981.    * disables the use of priors, e.g., in case of de-serialized schemes
  2982.    * that have no access to the original training set, but are evaluated
  2983.    * on a test set.
  2984.    */
  2985.   public void useNoPriors() {
  2986.     m_NoPriors = true;
  2987.   }
  2988.  
  2989.   /**
  2990.    * Tests whether the current evaluation object is equal to another
  2991.    * evaluation object
  2992.    *
  2993.    * @param obj the object to compare against
  2994.    * @return true if the two objects are equal
  2995.    */
  2996.   public boolean equals(Object obj) {
  2997.  
  2998.     if ((obj == null) || !(obj.getClass().equals(this.getClass()))) {
  2999.       return false;
  3000.     }
  3001.     Evaluation cmp = (Evaluation) obj;
  3002.     if (m_ClassIsNominal != cmp.m_ClassIsNominal) return false;
  3003.     if (m_NumClasses != cmp.m_NumClasses) return false;
  3004.  
  3005.     if (m_Incorrect != cmp.m_Incorrect) return false;
  3006.     if (m_Correct != cmp.m_Correct) return false;
  3007.     if (m_Unclassified != cmp.m_Unclassified) return false;
  3008.     if (m_MissingClass != cmp.m_MissingClass) return false;
  3009.     if (m_WithClass != cmp.m_WithClass) return false;
  3010.  
  3011.     if (m_SumErr != cmp.m_SumErr) return false;
  3012.     if (m_SumAbsErr != cmp.m_SumAbsErr) return false;
  3013.     if (m_SumSqrErr != cmp.m_SumSqrErr) return false;
  3014.     if (m_SumClass != cmp.m_SumClass) return false;
  3015.     if (m_SumSqrClass != cmp.m_SumSqrClass) return false;
  3016.     if (m_SumPredicted != cmp.m_SumPredicted) return false;
  3017.     if (m_SumSqrPredicted != cmp.m_SumSqrPredicted) return false;
  3018.     if (m_SumClassPredicted != cmp.m_SumClassPredicted) return false;
  3019.  
  3020.     if (m_ClassIsNominal) {
  3021.       for (int i = 0; i < m_NumClasses; i++) {
  3022.         for (int j = 0; j < m_NumClasses; j++) {
  3023.           if (m_ConfusionMatrix[i][j] != cmp.m_ConfusionMatrix[i][j]) {
  3024.             return false;
  3025.           }
  3026.         }
  3027.       }
  3028.     }
  3029.  
  3030.     return true;
  3031.   }
  3032.  
  3033.   /**
  3034.    * Prints the predictions for the given dataset into a supplied StringBuffer.
  3035.    *
  3036.    * @param classifier      the classifier to use
  3037.    * @param train           the training data
  3038.    * @param testSource      the test set
  3039.    * @param classIndex      the class index (1-based); if -1, the class
  3040.    *                        index stored in the data file (the last
  3041.    *                        attribute) is used
  3042.    * @param attributesToOutput  the indices of the attributes to output
  3043.    * @param predsText       the StringBuffer to store the generated predictions in
  3044.    * @throws Exception      if test file cannot be opened
  3045.    */
  3046.   public static void printClassifications(Classifier classifier,
  3047.                                             Instances train,
  3048.                                             DataSource testSource,
  3049.                                             int classIndex,
  3050.                                             Range attributesToOutput,
  3051.                                             StringBuffer predsText) throws Exception {
  3052.    
  3053.     printClassifications(classifier, train,
  3054.                          testSource, classIndex,
  3055.                          attributesToOutput, false, predsText);
  3056.   }
  3057.  
  3058.   /**
  3059.    * Prints the header for the predictions output into a supplied StringBuffer
  3060.    *
  3061.    * @param test structure of the test set to print predictions for
  3062.    * @param attributesToOutput indices of the attributes to output
  3063.    * @param printDistribution prints the complete distribution for nominal
  3064.    * attributes, not just the predicted value
  3065.    * @param text the StringBuffer to print to
  3066.    */
  3067.   protected static void printClassificationsHeader(Instances test,
  3068.                                                    Range attributesToOutput,
  3069.                                                    boolean printDistribution,
  3070.                                                    StringBuffer text) {
  3071.     // print header
  3072.     if (test.classAttribute().isNominal())
  3073.       if (printDistribution)
  3074.         text.append(" inst#     actual  predicted error distribution");
  3075.       else
  3076.         text.append(" inst#     actual  predicted error prediction");
  3077.     else
  3078.       text.append(" inst#     actual  predicted      error");
  3079.     if (attributesToOutput != null) {
  3080.       attributesToOutput.setUpper(test.numAttributes() - 1);
  3081.       text.append(" (");
  3082.       boolean first = true;
  3083.       for (int i = 0; i < test.numAttributes(); i++) {
  3084.         if (i == test.classIndex())
  3085.           continue;
  3086.  
  3087.         if (attributesToOutput.isInRange(i)) {
  3088.           if (!first)
  3089.             text.append(",");
  3090.           text.append(test.attribute(i).name());
  3091.           first = false;
  3092.         }
  3093.       }
  3094.       text.append(")");
  3095.     }
  3096.     text.append("\n");
  3097.   }
  3098.  
  3099.   /**
  3100.    * Prints the predictions for the given dataset into a supplied StringBuffer
  3101.    *
  3102.    * @param classifier      the classifier to use
  3103.    * @param train       the training data
  3104.    * @param testSource      the test set
  3105.    * @param classIndex      the class index (1-based); if -1, the class
  3106.    *                        index stored in the data file (the last
  3107.    *                        attribute) is used
  3108.    * @param attributesToOutput  the indices of the attributes to output
  3109.    * @param printDistribution   prints the complete distribution for nominal
  3110.    *                classes, not just the predicted value
  3111.    * @param text                StringBuffer to hold the printed predictions
  3112.    * @throws Exception      if test file cannot be opened
  3113.    */
  3114.   public static void printClassifications(Classifier classifier,
  3115.                                           Instances train,
  3116.                                           DataSource testSource,
  3117.                                           int classIndex,
  3118.                                           Range attributesToOutput,
  3119.                                           boolean printDistribution,
  3120.                                           StringBuffer text) throws Exception {
  3121.  
  3122.     if (testSource != null) {
  3123.       Instances test = testSource.getStructure();
  3124.       if (classIndex != -1) {
  3125.         test.setClassIndex(classIndex - 1);
  3126.       } else {
  3127.         if (test.classIndex() == -1)
  3128.           test.setClassIndex(test.numAttributes() - 1);
  3129.       }
  3130.
  3131.       // print the header
  3132.       printClassificationsHeader(test, attributesToOutput, printDistribution, text);
  3133.
  3134.       // print predictions
  3135.       int i = 0;
  3136.       testSource.reset();
  3137.       test = testSource.getStructure(test.classIndex());
  3138.       while (testSource.hasMoreElements(test)) {
  3139.         Instance inst = testSource.nextElement(test);
  3140.         text.append(predictionText(classifier, inst, i,
  3141.                                    attributesToOutput, printDistribution));
  3142.         i++;
  3143.       }
  3144.     }
  3146.   }
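  // Hedged usage sketch for the method above (assumes a built classifier
  // `cls`, compatible training data `train`, and a hypothetical "test.arff";
  // passing -1 keeps the class index stored in the test data):
  //
  //   DataSource source = new DataSource("test.arff");
  //   StringBuffer preds = new StringBuffer();
  //   Evaluation.printClassifications(cls, train, source, -1,
  //                                   new Range("first-last"), false, preds);
  //   System.out.println(preds);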
  3147.  
  3148.   /**
  3149.    * Stores the prediction made by the classifier as a string.
  3150.    *
  3151.    * @param classifier      the classifier to use
  3152.    * @param inst        the instance to generate text from
  3153.    * @param instNum     the index in the dataset
  3154.    * @param attributesToOutput  the indices of the attributes to output
  3155.    * @param printDistribution   prints the complete distribution for nominal
  3156.    *                classes, not just the predicted value
  3157.    * @return                    the prediction as a String
  3158.    * @throws Exception      if something goes wrong
  3159.    * @see #printClassifications(Classifier, Instances, DataSource, int, Range, boolean, StringBuffer)
  3160.    */
  3161.   protected static String predictionText(Classifier classifier,
  3162.                                          Instance inst,
  3163.                                          int instNum,
  3164.                                          Range attributesToOutput,
  3165.                                          boolean printDistribution)
  3166.    
  3167.     throws Exception {
  3168.  
  3169.     StringBuffer result = new StringBuffer();
  3170.     int width = 10;
  3171.     int prec = 3;
  3172.  
  3173.     Instance withMissing = (Instance)inst.copy();
  3174.     withMissing.setDataset(inst.dataset());
  3175.     withMissing.setMissing(withMissing.classIndex());
  3176.     double predValue = classifier.classifyInstance(withMissing);
  3177.  
  3178.     // index
  3179.     result.append(Utils.padLeft("" + (instNum+1), 6));
  3180.  
  3181.     if (inst.dataset().classAttribute().isNumeric()) {
  3182.       // actual
  3183.       if (inst.classIsMissing())
  3184.         result.append(" " + Utils.padLeft("?", width));
  3185.       else
  3186.         result.append(" " + Utils.doubleToString(inst.classValue(), width, prec));
  3187.       // predicted
  3188.       if (Instance.isMissingValue(predValue))
  3189.         result.append(" " + Utils.padLeft("?", width));
  3190.       else
  3191.         result.append(" " + Utils.doubleToString(predValue, width, prec));
  3192.       // error
  3193.       if (Instance.isMissingValue(predValue) || inst.classIsMissing())
  3194.         result.append(" " + Utils.padLeft("?", width));
  3195.       else
  3196.         result.append(" " + Utils.doubleToString(predValue - inst.classValue(), width, prec));
  3197.     } else {
  3198.       // actual
  3199.       result.append(" " + Utils.padLeft(((int) inst.classValue()+1) + ":" + inst.toString(inst.classIndex()), width));
  3200.       // predicted
  3201.       if (Instance.isMissingValue(predValue))
  3202.         result.append(" " + Utils.padLeft("?", width));
  3203.       else
  3204.         result.append(" " + Utils.padLeft(((int) predValue+1) + ":" + inst.dataset().classAttribute().value((int)predValue), width));
  3205.       // error?
  3206.       if (!Instance.isMissingValue(predValue) && !inst.classIsMissing() && ((int) predValue != (int) inst.classValue()))
  3207.         result.append(" " + "  +  ");
  3208.       else
  3209.         result.append(" " + "     ");
  3210.       // prediction/distribution
  3211.       if (printDistribution) {
  3212.         if (Instance.isMissingValue(predValue)) {
  3213.           result.append(" " + "?");
  3214.         }
  3215.         else {
  3216.           result.append(" ");
  3217.           double[] dist = classifier.distributionForInstance(withMissing);
  3218.           for (int n = 0; n < dist.length; n++) {
  3219.             if (n > 0)
  3220.               result.append(",");
  3221.             if (n == (int) predValue)
  3222.               result.append("*");
  3223.             result.append(Utils.doubleToString(dist[n], prec));
  3224.           }
  3225.         }
  3226.       }
  3227.       else {
  3228.         if (Instance.isMissingValue(predValue))
  3229.           result.append(" " + "?");
  3230.         else
  3231.           result.append(" " + Utils.doubleToString(classifier.distributionForInstance(withMissing)[(int)predValue], prec));
  3232.       }
  3233.     }
  3234.  
  3235.     // attributes
  3236.     result.append(" " + attributeValuesString(withMissing, attributesToOutput) + "\n");
  3237.  
  3238.     return result.toString();
  3239.   }
  3240.  
  3241.   /**
  3242.    * Builds a string listing the attribute values in a specified range of indices,
  3243.    * separated by commas and enclosed in brackets.
  3244.    *
  3245.    * @param instance the instance to print the values from
  3246.    * @param attRange the range of the attributes to list
  3247.    * @return a string listing values of the attributes in the range
  3248.    */
  3249.   protected static String attributeValuesString(Instance instance, Range attRange) {
  3250.     StringBuffer text = new StringBuffer();
  3251.     if (attRange != null) {
  3252.       boolean firstOutput = true;
  3253.       attRange.setUpper(instance.numAttributes() - 1);
  3254.       for (int i = 0; i < instance.numAttributes(); i++)
  3255.         if (attRange.isInRange(i) && i != instance.classIndex()) {
  3256.           if (firstOutput) text.append("(");
  3257.           else text.append(",");
  3258.           text.append(instance.toString(i));
  3259.           firstOutput = false;
  3260.         }
  3261.       if (!firstOutput) text.append(")");
  3262.     }
  3263.     return text.toString();
  3264.   }
  3265.  
  3266.   /**
  3267.    * Make up the help string giving all the command line options
  3268.    *
  3269.    * @param classifier the classifier to include options for
  3270.    * @param globalInfo include the global information string
  3271.    * for the classifier (if available).
  3272.    * @return a string detailing the valid command line options
  3273.    */
  3274.   protected static String makeOptionString(Classifier classifier,
  3275.                                            boolean globalInfo) {
  3276.  
  3277.     StringBuffer optionsText = new StringBuffer("");
  3278.  
  3279.     // General options
  3280.     optionsText.append("\n\nGeneral options:\n\n");
  3281.     optionsText.append("-h or -help\n");
  3282.     optionsText.append("\tOutput help information.\n");
  3283.     optionsText.append("-synopsis or -info\n");
  3284.     optionsText.append("\tOutput synopsis for classifier (use in conjunction "
  3285.         + "with -h)\n");
  3286.     optionsText.append("-t <name of training file>\n");
  3287.     optionsText.append("\tSets training file.\n");
  3288.     optionsText.append("-T <name of test file>\n");
  3289.     optionsText.append("\tSets test file. If missing, a cross-validation will be performed\n");
  3290.     optionsText.append("\ton the training data.\n");
  3291.     optionsText.append("-c <class index>\n");
  3292.     optionsText.append("\tSets index of class attribute (default: last).\n");
  3293.     optionsText.append("-x <number of folds>\n");
  3294.     optionsText.append("\tSets number of folds for cross-validation (default: 10).\n");
  3295.     optionsText.append("-no-cv\n");
  3296.     optionsText.append("\tDo not perform any cross validation.\n");
  3297.     optionsText.append("-split-percentage <percentage>\n");
  3298.     optionsText.append("\tSets the percentage for the train/test set split, e.g., 66.\n");
  3299.     optionsText.append("-preserve-order\n");
  3300.     optionsText.append("\tPreserves the order in the percentage split.\n");
  3301.     optionsText.append("-s <random number seed>\n");
  3302.     optionsText.append("\tSets random number seed for cross-validation or percentage split\n");
  3303.     optionsText.append("\t(default: 1).\n");
  3304.     optionsText.append("-m <name of file with cost matrix>\n");
  3305.     optionsText.append("\tSets file with cost matrix.\n");
  3306.     optionsText.append("-l <name of input file>\n");
  3307.     optionsText.append("\tSets model input file. In case the filename ends with '.xml',\n");
  3308.     optionsText.append("\ta PMML file is loaded or, if that fails, options are loaded\n");
  3309.     optionsText.append("\tfrom the XML file.\n");
  3310.     optionsText.append("-d <name of output file>\n");
  3311.     optionsText.append("\tSets model output file. In case the filename ends with '.xml',\n");
  3312.     optionsText.append("\tonly the options are saved to the XML file, not the model.\n");
  3313.     optionsText.append("-v\n");
  3314.     optionsText.append("\tOutputs no statistics for training data.\n");
  3315.     optionsText.append("-o\n");
  3316.     optionsText.append("\tOutputs statistics only, not the classifier.\n");
  3317.     optionsText.append("-i\n");
  3318.     optionsText.append("\tOutputs detailed information-retrieval");
  3319.     optionsText.append(" statistics for each class.\n");
  3320.     optionsText.append("-k\n");
  3321.     optionsText.append("\tOutputs information-theoretic statistics.\n");
  3322.     optionsText.append("-p <attribute range>\n");
  3323.     optionsText.append("\tOnly outputs predictions for test instances (or the train\n"
  3324.     + "\tinstances if no test instances provided and -no-cv is used),\n"
  3325.     + "\talong with attributes (0 for none).\n");
  3326.     optionsText.append("-distribution\n");
  3327.     optionsText.append("\tOutputs the distribution instead of only the prediction\n");
  3328.     optionsText.append("\tin conjunction with the '-p' option (only nominal classes).\n");
  3329.     optionsText.append("-r\n");
  3330.     optionsText.append("\tOnly outputs cumulative margin distribution.\n");
  3331.     if (classifier instanceof Sourcable) {
  3332.       optionsText.append("-z <class name>\n");
  3333.       optionsText.append("\tOnly outputs the source representation"
  3334.       + " of the classifier,\n\tgiving it the supplied"
  3335.       + " name.\n");
  3336.     }
  3337.     if (classifier instanceof Drawable) {
  3338.       optionsText.append("-g\n");
  3339.       optionsText.append("\tOnly outputs the graph representation"
  3340.       + " of the classifier.\n");
  3341.     }
  3342.     optionsText.append("-xml filename | xml-string\n");
  3343.     optionsText.append("\tRetrieves the options from the XML-data instead of the "
  3344.     + "command line.\n");
  3345.     optionsText.append("-threshold-file <file>\n");
  3346.     optionsText.append("\tThe file to save the threshold data to.\n"
  3347.     + "\tThe format is determined by the extensions, e.g., '.arff' for ARFF \n"
  3348.     + "\tformat or '.csv' for CSV.\n");
  3349.     optionsText.append("-threshold-label <label>\n");
  3350.     optionsText.append("\tThe class label to determine the threshold data for\n"
  3351.     + "\t(default is the first label)\n");
  3352.  
  3353.     // Get scheme-specific options
  3354.     if (classifier instanceof OptionHandler) {
  3355.       optionsText.append("\nOptions specific to "
  3356.       + classifier.getClass().getName()
  3357.       + ":\n\n");
  3358.       Enumeration enu = ((OptionHandler)classifier).listOptions();
  3359.       while (enu.hasMoreElements()) {
  3360.     Option option = (Option) enu.nextElement();
  3361.     optionsText.append(option.synopsis() + '\n');
  3362.     optionsText.append(option.description() + "\n");
  3363.       }
  3364.     }
  3365.    
  3366.     // Get global information (if available)
  3367.     if (globalInfo) {
  3368.       try {
  3369.         String gi = getGlobalInfo(classifier);
  3370.         optionsText.append(gi);
  3371.       } catch (Exception ex) {
  3372.         // quietly ignore
  3373.       }
  3374.     }
  3375.     return optionsText.toString();
  3376.   }
  3377.  
  3378.   /**
  3379.    * Return the global info (if it exists) for the supplied classifier
  3380.    *
  3381.    * @param classifier the classifier to get the global info for
  3382.    * @return the global info (synopsis) for the classifier
  3383.    * @throws Exception if there is a problem reflecting on the classifier
  3384.    */
  3385.   protected static String getGlobalInfo(Classifier classifier) throws Exception {
  3386.     BeanInfo bi = Introspector.getBeanInfo(classifier.getClass());
  3387.     MethodDescriptor[] methods;
  3388.     methods = bi.getMethodDescriptors();
  3389.     Object[] args = {};
  3390.     String result = "\nSynopsis for " + classifier.getClass().getName()
  3391.       + ":\n\n";
  3392.    
  3393.     for (int i = 0; i < methods.length; i++) {
  3394.       String name = methods[i].getDisplayName();
  3395.       Method meth = methods[i].getMethod();
  3396.       if (name.equals("globalInfo")) {
  3397.         String globalInfo = (String)(meth.invoke(classifier, args));
  3398.         result += globalInfo;
  3399.         break;
  3400.       }
  3401.     }
  3402.    
  3403.     return result;
  3404.   }
  3405.  
  3406.   /**
  3407.    * Method for generating indices for the confusion matrix.
  3408.    *
  3409.    * @param num     integer to format
  3410.    * @param IDChars the characters to use
  3411.    * @param IDWidth the width of the entry
  3412.    * @return        the formatted integer as a string
  3413.    */
  3414.   protected String num2ShortID(int num, char[] IDChars, int IDWidth) {
  3415.  
  3416.     char[] ID = new char[IDWidth];
  3417.     int i;
  3418.
  3419.     for (i = IDWidth - 1; i >= 0; i--) {
  3420.       ID[i] = IDChars[num % IDChars.length];
  3421.       num = num / IDChars.length - 1;
  3422.       if (num < 0) {
  3423.         break;
  3424.       }
  3425.     }
  3426.     for (i--; i >= 0; i--) {
  3427.       ID[i] = ' ';
  3428.     }
  3429.
  3430.     return new String(ID);
  3431.   }
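  // Example: with IDChars = {'a', ..., 'z'} and IDWidth = 2 this produces the
  // labels used in the confusion matrix: 0 -> " a", 25 -> " z", 26 -> "aa",
  // 27 -> "ab".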
  3432.  
  3433.   /**
  3434.    * Convert a single prediction into a probability distribution
  3435.    * with all zero probabilities except the predicted value which
  3436.    * has probability 1.0.
  3437.    *
  3438.    * @param predictedClass the index of the predicted class
  3439.    * @return the probability distribution
  3440.    */
  3441.   protected double [] makeDistribution(double predictedClass) {
  3442.  
  3443.     double [] result = new double [m_NumClasses];
  3444.     if (Instance.isMissingValue(predictedClass)) {
  3445.       return result;
  3446.     }
  3447.     if (m_ClassIsNominal) {
  3448.       result[(int)predictedClass] = 1.0;
  3449.     } else {
  3450.       result[0] = predictedClass;
  3451.     }
  3452.     return result;
  3453.   }
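  // Example: with m_NumClasses = 3 and a nominal class, makeDistribution(1.0)
  // returns {0.0, 1.0, 0.0}, and a missing prediction yields all zeros; for a
  // numeric class the predicted value itself is stored in element 0.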
  3454.  
  3455.   /**
  3456.    * Updates all the statistics about a classifier's performance for
  3457.    * the current test instance.
  3458.    *
  3459.    * @param predictedDistribution the probabilities assigned to
  3460.    * each class
  3461.    * @param instance the instance to be classified
  3462.    * @throws Exception if the class of the instance is not
  3463.    * set
  3464.    */
  3465.   protected void updateStatsForClassifier(double [] predictedDistribution,
  3466.       Instance instance)
  3467.   throws Exception {
  3468.  
  3469.     int actualClass = (int)instance.classValue();
  3470.  
  3471.     if (!instance.classIsMissing()) {
  3472.       updateMargins(predictedDistribution, actualClass, instance.weight());
  3473.  
  3474.       // Determine the predicted class (doesn't detect multiple
  3475.       // classifications)
  3476.       int predictedClass = -1;
  3477.       double bestProb = 0.0;
  3478.       for (int i = 0; i < m_NumClasses; i++) {
  3479.         if (predictedDistribution[i] > bestProb) {
  3480.           predictedClass = i;
  3481.           bestProb = predictedDistribution[i];
  3482.         }
  3483.       }
  3484.
  3485.       m_WithClass += instance.weight();
  3486.
  3487.       // Determine misclassification cost
  3488.       if (m_CostMatrix != null) {
  3489.         if (predictedClass < 0) {
  3490.           // For missing predictions, we assume the worst possible cost.
  3491.           // This is pretty harsh.
  3492.           // Perhaps we could take the negative of the cost of a correct
  3493.           // prediction (-m_CostMatrix.getElement(actualClass,actualClass)),
  3494.           // although often this will be zero
  3495.           m_TotalCost += instance.weight()
  3496.               * m_CostMatrix.getMaxCost(actualClass, instance);
  3497.         } else {
  3498.           m_TotalCost += instance.weight()
  3499.               * m_CostMatrix.getElement(actualClass, predictedClass,
  3500.                   instance);
  3501.         }
  3502.       }
  3503.
  3504.       // Update counts when no class was predicted
  3505.       if (predictedClass < 0) {
  3506.         m_Unclassified += instance.weight();
  3507.         return;
  3508.       }
  3509.
  3510.       double predictedProb = Math.max(MIN_SF_PROB,
  3511.           predictedDistribution[actualClass]);
  3512.       double priorProb = Math.max(MIN_SF_PROB,
  3513.           m_ClassPriors[actualClass]
  3514.               / m_ClassPriorsSum);
  3515.       if (predictedProb >= priorProb) {
  3516.         m_SumKBInfo += (Utils.log2(predictedProb) -
  3517.             Utils.log2(priorProb))
  3518.             * instance.weight();
  3519.       } else {
  3520.         m_SumKBInfo -= (Utils.log2(1.0-predictedProb) -
  3521.             Utils.log2(1.0-priorProb))
  3522.             * instance.weight();
  3523.       }
  3524.
  3525.       m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
  3526.       m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
  3527.
  3528.       updateNumericScores(predictedDistribution,
  3529.           makeDistribution(instance.classValue()),
  3530.           instance.weight());
  3531.
  3532.       // Update other stats
  3533.       m_ConfusionMatrix[actualClass][predictedClass] += instance.weight();
  3534.       if (predictedClass != actualClass) {
  3535.         m_Incorrect += instance.weight();
  3536.       } else {
  3537.         m_Correct += instance.weight();
  3538.       }
  3539.     } else {
  3540.       m_MissingClass += instance.weight();
  3541.     }
  3542.   }
  3543.  
  3544.   /**
  3545.    * Updates all the statistics about a predictor's performance for
  3546.    * the current test instance.
  3547.    *
  3548.    * @param predictedValue the numeric value the classifier predicts
  3549.    * @param instance the instance to be classified
  3550.    * @throws Exception if the class of the instance is not
  3551.    * set
  3552.    */
  3553.   protected void updateStatsForPredictor(double predictedValue,
  3554.       Instance instance)
  3555.   throws Exception {
  3556.  
  3557.     if (!instance.classIsMissing()) {
  3558.
  3559.       // Update stats
  3560.       m_WithClass += instance.weight();
  3561.       if (Instance.isMissingValue(predictedValue)) {
  3562.         m_Unclassified += instance.weight();
  3563.         return;
  3564.       }
  3565.       m_SumClass += instance.weight() * instance.classValue();
  3566.       m_SumSqrClass += instance.weight() * instance.classValue()
  3567.           * instance.classValue();
  3568.       m_SumClassPredicted += instance.weight()
  3569.           * instance.classValue() * predictedValue;
  3570.       m_SumPredicted += instance.weight() * predictedValue;
  3571.       m_SumSqrPredicted += instance.weight() * predictedValue * predictedValue;
  3572.
  3573.       if (m_ErrorEstimator == null) {
  3574.         setNumericPriorsFromBuffer();
  3575.       }
  3576.       double predictedProb = Math.max(m_ErrorEstimator.getProbability(
  3577.           predictedValue
  3578.           - instance.classValue()),
  3579.           MIN_SF_PROB);
  3580.       double priorProb = Math.max(m_PriorErrorEstimator.getProbability(
  3581.           instance.classValue()),
  3582.           MIN_SF_PROB);
  3583.
  3584.       m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
  3585.       m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
  3586.       m_ErrorEstimator.addValue(predictedValue - instance.classValue(),
  3587.           instance.weight());
  3588.
  3589.       updateNumericScores(makeDistribution(predictedValue),
  3590.           makeDistribution(instance.classValue()),
  3591.           instance.weight());
  3592.
  3593.     } else
  3594.       m_MissingClass += instance.weight();
  3595.   }
  3596.  
  3597.   /**
  3598.    * Update the cumulative record of classification margins
  3599.    *
  3600.    * @param predictedDistribution the probability distribution predicted for
  3601.    * the current instance
  3602.    * @param actualClass the index of the actual instance class
  3603.    * @param weight the weight assigned to the instance
  3604.    */
  3605.   protected void updateMargins(double [] predictedDistribution,
  3606.       int actualClass, double weight) {
  3607.  
  3608.     double probActual = predictedDistribution[actualClass];
  3609.     double probNext = 0;
  3610.
  3611.     for (int i = 0; i < m_NumClasses; i++)
  3612.       if ((i != actualClass) &&
  3613.           (predictedDistribution[i] > probNext))
  3614.         probNext = predictedDistribution[i];
  3615.
  3616.     double margin = probActual - probNext;
  3617.     int bin = (int)((margin + 1.0) / 2.0 * k_MarginResolution);
  3618.     m_MarginCounts[bin] += weight;
  3619.   }
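  // Worked example: if the actual class receives probability 0.7 and the
  // strongest other class 0.2, the margin is 0.5 and the instance weight is
  // added to bin (0.5 + 1.0) / 2.0 * k_MarginResolution, i.e. three quarters
  // of the way along the [-1, 1] margin histogram.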
  3620.  
  3621.   /**
  3622.    * Update the numeric accuracy measures. For numeric classes, the
  3623.    * accuracy is between the actual and predicted class values. For
  3624.    * nominal classes, the accuracy is between the actual and
  3625.    * predicted class probabilities.
  3626.    *
  3627.    * @param predicted the predicted values
  3628.    * @param actual the actual value
  3629.    * @param weight the weight associated with this prediction
  3630.    */
  3631.   protected void updateNumericScores(double [] predicted,
  3632.       double [] actual, double weight) {
  3633.  
  3634.     double diff;
  3635.     double sumErr = 0, sumAbsErr = 0, sumSqrErr = 0;
  3636.     double sumPriorAbsErr = 0, sumPriorSqrErr = 0;
  3637.     for(int i = 0; i < m_NumClasses; i++) {
  3638.       diff = predicted[i] - actual[i];
  3639.       sumErr += diff;
  3640.       sumAbsErr += Math.abs(diff);
  3641.       sumSqrErr += diff * diff;
  3642.       diff = (m_ClassPriors[i] / m_ClassPriorsSum) - actual[i];
  3643.       sumPriorAbsErr += Math.abs(diff);
  3644.       sumPriorSqrErr += diff * diff;
  3645.     }
  3646.     m_SumErr += weight * sumErr / m_NumClasses;
  3647.     m_SumAbsErr += weight * sumAbsErr / m_NumClasses;
  3648.     m_SumSqrErr += weight * sumSqrErr / m_NumClasses;
  3649.     m_SumPriorAbsErr += weight * sumPriorAbsErr / m_NumClasses;
  3650.     m_SumPriorSqrErr += weight * sumPriorSqrErr / m_NumClasses;
  3651.   }
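  // Example for a two-class nominal problem: predicted = {0.8, 0.2} against
  // actual = {1, 0} gives per-class absolute errors of 0.2 and 0.2, so the
  // instance contributes weight * 0.4 / 2 = weight * 0.2 to m_SumAbsErr, the
  // accumulator behind the reported mean absolute error.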
  3652.  
  3653.   /**
  3654.    * Adds a numeric (non-missing) training class value and weight to
  3655.    * the buffer of stored values.
  3656.    *
  3657.    * @param classValue the class value
  3658.    * @param weight the instance weight
  3659.    */
  3660.   protected void addNumericTrainClass(double classValue, double weight) {
  3661.  
  3662.     if (m_TrainClassVals == null) {
  3663.       m_TrainClassVals = new double[100];
  3664.       m_TrainClassWeights = new double[100];
  3665.     }
  3666.     if (m_NumTrainClassVals == m_TrainClassVals.length) {
  3667.       double[] temp = new double[m_TrainClassVals.length * 2];
  3668.       System.arraycopy(m_TrainClassVals, 0,
  3669.           temp, 0, m_TrainClassVals.length);
  3670.       m_TrainClassVals = temp;
  3671.
  3672.       temp = new double[m_TrainClassWeights.length * 2];
  3673.       System.arraycopy(m_TrainClassWeights, 0,
  3674.           temp, 0, m_TrainClassWeights.length);
  3675.       m_TrainClassWeights = temp;
  3676.     }
  3677.     m_TrainClassVals[m_NumTrainClassVals] = classValue;
  3678.     m_TrainClassWeights[m_NumTrainClassVals] = weight;
  3679.     m_NumTrainClassVals++;
  3680.   }
  3681.  
  3682.   /**
  3683.    * Sets up the priors for numeric class attributes from the
  3684.    * training class values that have been seen so far.
  3685.    */
  3686.   protected void setNumericPriorsFromBuffer() {
  3687.  
  3688.     double numPrecision = 0.01; // Default value
  3689.     if (m_NumTrainClassVals > 1) {
  3690.       double[] temp = new double[m_NumTrainClassVals];
  3691.       System.arraycopy(m_TrainClassVals, 0, temp, 0, m_NumTrainClassVals);
  3692.       int[] index = Utils.sort(temp);
  3693.       double lastVal = temp[index[0]];
  3694.       double deltaSum = 0;
  3695.       int distinct = 0;
  3696.       for (int i = 1; i < temp.length; i++) {
  3697.         double current = temp[index[i]];
  3698.         if (current != lastVal) {
  3699.           deltaSum += current - lastVal;
  3700.           lastVal = current;
  3701.           distinct++;
  3702.         }
  3703.       }
  3704.       if (distinct > 0) {
  3705.         numPrecision = deltaSum / distinct;
  3706.       }
  3707.     }
  3708.     m_PriorErrorEstimator = new KernelEstimator(numPrecision);
  3709.     m_ErrorEstimator = new KernelEstimator(numPrecision);
  3710.     m_ClassPriors[0] = m_ClassPriorsSum = 0;
  3711.     for (int i = 0; i < m_NumTrainClassVals; i++) {
  3712.       m_ClassPriors[0] += m_TrainClassVals[i] * m_TrainClassWeights[i];
  3713.       m_ClassPriorsSum += m_TrainClassWeights[i];
  3714.       m_PriorErrorEstimator.addValue(m_TrainClassVals[i],
  3715.           m_TrainClassWeights[i]);
  3716.     }
  3717.   }
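  // Precision example: for buffered class values {1.0, 2.0, 2.0, 4.0} the
  // distinct sorted gaps are 1.0 (1->2) and 2.0 (2->4), so numPrecision =
  // (1.0 + 2.0) / 2 = 1.5 -- the average spacing handed to both
  // KernelEstimators.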
  3718.  
  3719.   /**
  3720.    * Returns the revision string.
  3721.    *
  3722.    * @return        the revision
  3723.    */
  3724.   public String getRevision() {
  3725.     return RevisionUtils.extract("$Revision: 6346 $");
  3726.   }
  3727. }