Guest User

Multithreaded k-fold cross-validation implementation using Weka

a guest
Apr 26th, 2015
86
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. /*******************************************************
  2.     KFoldCV.java
  3. */
  4. import java.util.ArrayList;
  5. import java.util.Collections;
  6. import java.util.List;
  7. import java.util.Random;
  8. import java.util.stream.IntStream;
  9.  
  10. import weka.classifiers.Classifier;
  11. import weka.classifiers.Evaluation;
  12. import weka.core.Instances;
  13.  
  14. public class KFoldCV {
  15.  
  16.     private final Instances instances;
  17.     private final int folds;
  18.    
  19.     /**
  20.      * Default constructor for specifying instances and number of folds.
  21.      *
  22.      * @param instances Instances to cross-validate using k-fold
  23.      * @param folds Number of folds (k)
  24.      */
  25.     public KFoldCV(Instances instances, int folds) {
  26.         this.instances = instances;
  27.         this.folds = folds;
  28.     }
  29.  
  30.     /**
  31.      * Perform k-fold CV and retrieve average error rate through all k folds.
  32.      *
  33.      * @param machineLearningAlgorithm Specify which ML algorithm to use (enum)
  34.      * @param seed Random seed to be used for random instance reordering
  35.      * @return Average error rate for all k-fold iterations
  36.      * @throws Exception
  37.      */
  38.     public double getErrorRate(MachineLearningAlgorithm machineLearningAlgorithm, Random seed) throws Exception {
  39.        
  40.         // Error rates
  41.         List<Double> errorRates = Collections.synchronizedList(new ArrayList<>());
  42.        
  43.         // Prepare instances
  44.         final Instances preparedInstances = new Instances(instances);
  45.         preparedInstances.randomize(seed);
  46.         preparedInstances.stratify(folds);
  47.        
  48.         // k-fold iterations
  49.         IntStream.range(0, folds).parallel().forEach(fold -> {
  50.             errorRates.add(
  51.                     foldIteration(
  52.                             machineLearningAlgorithm,
  53.                             folds,
  54.                             fold,
  55.                             preparedInstances
  56.                     )
  57.                 );
  58.         });
  59.        
  60.         // Average folds
  61.         double avgError = errorRates.stream().mapToDouble(i -> i).average().orElse(0);
  62.        
  63.         return avgError;
  64.        
  65.     }
  66.    
  67.     /**
  68.      * Fold iteration, for internal use only.
  69.      *
  70.      * @param machineLearningAlgorithm Specify which ML algorithm to use
  71.      * @param folds Number of total folds
  72.      * @param fold  Current fold iteration
  73.      * @param preparedInstances Stratified instances
  74.      * @return Error rate for current fold iteration
  75.      */
  76.     private double foldIteration(final MachineLearningAlgorithm machineLearningAlgorithm, final int folds, final int fold, final Instances preparedInstances) {
  77.    
  78.         try {
  79.        
  80.             // Evaluation
  81.             Evaluation evaluation = new Evaluation(instances);
  82.            
  83.             // Train
  84.             Classifier classifier = ClassifierFactory.instantiate(machineLearningAlgorithm);
  85.             classifier.buildClassifier(preparedInstances.trainCV(folds, fold));
  86.             evaluation.evaluateModel(classifier, preparedInstances.testCV(folds, fold));
  87.            
  88.             // Return error rate
  89.             return evaluation.errorRate();
  90.            
  91.         } catch(Exception e) {
  92.            
  93.             throw new RuntimeException(e);
  94.            
  95.         }
  96.        
  97.     }
  98.  
  99. }
  100.  
  101. /*******************************************************
  102.     MachineLearningAlgorithm.java
  103. */
  /**
   * Learning algorithms that can be cross-validated. Declaration order is significant: it fixes
   * each constant's ordinal.
   */
  public enum MachineLearningAlgorithm {
      /** C4.5-style decision tree (Weka {@code J48}). */
      J48,
      /** Feed-forward neural network (Weka {@code MultilayerPerceptron}). */
      MULTILAYER_PERCEPTRON,
      /** Naive Bayes classifier (Weka {@code NaiveBayes}). */
      NAIVE_BAYES,
      /** Support vector machine trained with SMO (Weka {@code SMO}). */
      SUPPORT_VECTOR_MACHINES
  }
  107.  
  108. /*******************************************************
  109.     ClassifierFactory.java
  110. */
  111. import weka.classifiers.Classifier;
  112. import weka.classifiers.bayes.NaiveBayes;
  113. import weka.classifiers.functions.MultilayerPerceptron;
  114. import weka.classifiers.functions.SMO;
  115. import weka.classifiers.trees.J48;
  116.  
  117. public class ClassifierFactory {
  118.  
  119.     public static Classifier instantiate(MachineLearningAlgorithm algorithm) {
  120.        
  121.         switch (algorithm) {
  122.        
  123.             case J48:
  124.                 return new J48();
  125.             case MULTILAYER_PERCEPTRON:
  126.                 return new MultilayerPerceptron();
  127.             case NAIVE_BAYES:
  128.                 return new NaiveBayes();
  129.             case SUPPORT_VECTOR_MACHINES:
  130.                 return new SMO();
  131.             default:
  132.                 return null;
  133.                
  134.         }
  135.        
  136.     }
  137.  
  138. }
RAW Paste Data