Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- /*******************************************************
- KFoldCV.java
- */
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.stream.IntStream;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;
- public class KFoldCV {
- private final Instances instances;
- private final int folds;
- /**
- * Default constructor for specifying instances and number of folds.
- *
- * @param instances Instances to cross-validate using k-fold
- * @param folds Number of folds (k)
- */
- public KFoldCV(Instances instances, int folds) {
- this.instances = instances;
- this.folds = folds;
- }
- /**
- * Perform k-fold CV and retrieve average error rate through all k folds.
- *
- * @param machineLearningAlgorithm Specify which ML algorithm to use (enum)
- * @param seed Random seed to be used for random instance reordering
- * @return Average error rate for all k-fold iterations
- * @throws Exception
- */
- public double getErrorRate(MachineLearningAlgorithm machineLearningAlgorithm, Random seed) throws Exception {
- // Error rates
- List<Double> errorRates = Collections.synchronizedList(new ArrayList<>());
- // Prepare instances
- final Instances preparedInstances = new Instances(instances);
- preparedInstances.randomize(seed);
- preparedInstances.stratify(folds);
- // k-fold iterations
- IntStream.range(0, folds).parallel().forEach(fold -> {
- errorRates.add(
- foldIteration(
- machineLearningAlgorithm,
- folds,
- fold,
- preparedInstances
- )
- );
- });
- // Average folds
- double avgError = errorRates.stream().mapToDouble(i -> i).average().orElse(0);
- return avgError;
- }
- /**
- * Fold iteration, for internal use only.
- *
- * @param machineLearningAlgorithm Specify which ML algorithm to use
- * @param folds Number of total folds
- * @param fold Current fold iteration
- * @param preparedInstances Stratified instances
- * @return Error rate for current fold iteration
- */
- private double foldIteration(final MachineLearningAlgorithm machineLearningAlgorithm, final int folds, final int fold, final Instances preparedInstances) {
- try {
- // Evaluation
- Evaluation evaluation = new Evaluation(instances);
- // Train
- Classifier classifier = ClassifierFactory.instantiate(machineLearningAlgorithm);
- classifier.buildClassifier(preparedInstances.trainCV(folds, fold));
- evaluation.evaluateModel(classifier, preparedInstances.testCV(folds, fold));
- // Return error rate
- return evaluation.errorRate();
- } catch(Exception e) {
- throw new RuntimeException(e);
- }
- }
- }
- /*******************************************************
- MachineLearningAlgorithm.java
- */
/**
 * The learning algorithms that can be cross-validated.
 *
 * <p>Each constant is mapped to a concrete Weka classifier by {@code ClassifierFactory}.
 * The declaration order below is significant for {@code values()} and {@code ordinal()}.
 */
public enum MachineLearningAlgorithm {
    /** Decision tree, instantiated as {@code weka.classifiers.trees.J48}. */
    J48,
    /** Multilayer perceptron, instantiated as {@code weka.classifiers.functions.MultilayerPerceptron}. */
    MULTILAYER_PERCEPTRON,
    /** Naive Bayes, instantiated as {@code weka.classifiers.bayes.NaiveBayes}. */
    NAIVE_BAYES,
    /** Support vector machine, instantiated as {@code weka.classifiers.functions.SMO}. */
    SUPPORT_VECTOR_MACHINES
}
- /*******************************************************
- ClassifierFactory.java
- */
- import weka.classifiers.Classifier;
- import weka.classifiers.bayes.NaiveBayes;
- import weka.classifiers.functions.MultilayerPerceptron;
- import weka.classifiers.functions.SMO;
- import weka.classifiers.trees.J48;
- public class ClassifierFactory {
- public static Classifier instantiate(MachineLearningAlgorithm algorithm) {
- switch (algorithm) {
- case J48:
- return new J48();
- case MULTILAYER_PERCEPTRON:
- return new MultilayerPerceptron();
- case NAIVE_BAYES:
- return new NaiveBayes();
- case SUPPORT_VECTOR_MACHINES:
- return new SMO();
- default:
- return null;
- }
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement