Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
/// <summary>
/// Runs a stratified k-fold cross-validation of a multi-class support vector
/// machine and keeps the model of the fold with the lowest validation error.
/// </summary>
public class CrossValidation
{
    // Model of the best-scoring fold; assigned at the end of the constructor.
    MulticlassSupportVectorMachine classifierMC;

    // Not referenced anywhere in this class; kept so the declared members are unchanged.
    SupportVectorMachine classifier;

    /// <summary>
    /// Splits the samples into <paramref name="folds"/> groups, trains one
    /// multi-class SVM per fold with SMO, saves the cross-validation result to
    /// disk and stores the model of the fold with the lowest validation error.
    /// </summary>
    /// <param name="inputs">Feature vectors, one row per sample.</param>
    /// <param name="outputs">
    /// Class label of each sample. Labels are used directly as bucket indices
    /// by <see cref="RandomGroups"/>, so they must lie in
    /// [0, number of distinct labels).
    /// </param>
    /// <param name="trainingErrors">Receives the training statistics of the cross-validation result.</param>
    /// <param name="validationErrors">Receives the validation statistics of the cross-validation result.</param>
    /// <param name="path">
    /// Prefix of the file the result is saved to. "_", the mean validation
    /// error formatted as "0.000", and ".cvc" are appended to it.
    /// </param>
    /// <param name="tolerance">Tolerance handed to each SequentialMinimalOptimization instance.</param>
    /// <param name="folds">Number of cross-validation folds.</param>
    /// <param name="kernelN">
    /// Kernel selector: 0 = Linear, 1 = Gaussian estimated from the training
    /// inputs, 2 = Quadratic. Any other value falls back to Linear.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="inputs"/> or <paramref name="outputs"/> is null.
    /// </exception>
    public CrossValidation(double[][] inputs, int[] outputs, out CrossValidationStatistics trainingErrors, out CrossValidationStatistics validationErrors, string path, double tolerance, int folds, int kernelN)
    {
        if (inputs == null)
        {
            throw new ArgumentNullException("inputs");
        }

        if (outputs == null)
        {
            throw new ArgumentNullException("outputs");
        }

        // Stratified assignment of every sample to one of the folds.
        int[] foldOfSample = RandomGroups(outputs, outputs.Distinct().Length, folds);

        var crossvalidation = new CrossValidation<MulticlassSupportVectorMachine>(foldOfSample, folds: folds);

        crossvalidation.Fitting = (index, indicesTrain, indicesValidation) =>
        {
            // Training portion of this fold.
            double[][] trainingInputs = inputs.Submatrix(indicesTrain);
            int[] trainingOutputs = outputs.Submatrix(indicesTrain);

            // Held-out portion of this fold.
            double[][] validationInputs = inputs.Submatrix(indicesValidation);
            int[] validationOutputs = outputs.Submatrix(indicesValidation);

            IKernel kernel = CreateKernel(kernelN, trainingInputs);

            // Complexity value estimated from the kernel and the training inputs.
            var complexity = SequentialMinimalOptimization.EstimateComplexity(kernel, trainingInputs);

            // The machine is sized by the number of distinct labels present
            // in this fold's training data.
            var model = new MulticlassSupportVectorMachine(trainingInputs[0].Length, kernel, trainingOutputs.Distinct().Length);

            var teacher = new MulticlassSupportVectorLearning(model, trainingInputs, trainingOutputs);

            // Every binary sub-problem is trained with SMO using the same
            // tolerance and complexity.
            teacher.Algorithm = (svm, classInputs, classOutputs, i, j) =>
                new SequentialMinimalOptimization(svm, classInputs, classOutputs)
                {
                    Tolerance = tolerance,
                    Complexity = complexity
                };

            // The value returned by Run() is recorded as the training error.
            double trainingError = teacher.Run();
            double validationError = teacher.ComputeError(validationInputs, validationOutputs);

            return new CrossValidationValues<MulticlassSupportVectorMachine>(model, trainingError, validationError);
        };

        var result = crossvalidation.Compute();

        result.Training.Tag = "Training results";
        result.Validation.Tag = "Validation results";
        trainingErrors = result.Training;
        validationErrors = result.Validation;

        result.Save(path + "_" + validationErrors.Mean.ToString("0.000") + ".cvc");

        // Keep the model of the fold with the smallest validation error.
        // (Comparing validation errors for exact equality with the minimum
        // training error practically never matches and would silently fall
        // back to fold 0.)
        var foldValidationErrors = result.Validation.Values;
        int bestFold = 0;
        for (int i = 1; i < foldValidationErrors.Length; i++)
        {
            if (foldValidationErrors[i] < foldValidationErrors[bestFold])
            {
                bestFold = i;
            }
        }

        this.classifierMC = result.Models[bestFold].Model;
    }

    /// <summary>
    /// Maps the numeric kernel selector to a kernel instance.
    /// </summary>
    private static IKernel CreateKernel(int kernelN, double[][] trainingInputs)
    {
        switch (kernelN)
        {
            case 1:
                return Gaussian.Estimate(trainingInputs);
            case 2:
                return new Quadratic();
            default:
                // 0 and every unrecognised selector use the linear kernel.
                return new Linear();
        }
    }

    /// <summary>
    /// Assigns every sample to one of <paramref name="groups"/> groups so that
    /// the samples of each class are spread evenly over the groups.
    /// </summary>
    /// <param name="labels">Class label of each sample; every label must lie in [0, <paramref name="classes"/>).</param>
    /// <param name="classes">Number of classes (number of buckets).</param>
    /// <param name="groups">Number of groups to distribute the samples into.</param>
    /// <returns>
    /// An array with one entry per sample holding the zero-based index of the
    /// group the sample was assigned to. When <paramref name="groups"/> is 0
    /// no assignment takes place and every entry is 0.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="labels"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="groups"/> is negative, or a label lies outside
    /// [0, <paramref name="classes"/>).
    /// </exception>
    public static int[] RandomGroups(int[] labels, int classes, int groups)
    {
        if (labels == null)
        {
            throw new ArgumentNullException("labels");
        }

        if (groups < 0)
        {
            throw new ArgumentOutOfRangeException("groups", "The number of groups cannot be negative.");
        }

        // One bucket of sample indices per class.
        var buckets = new List<int>[classes];
        for (int c = 0; c < buckets.Length; c++)
        {
            buckets[c] = new List<int>();
        }

        for (int sample = 0; sample < labels.Length; sample++)
        {
            int label = labels[sample];
            if (label < 0 || label >= classes)
            {
                throw new ArgumentOutOfRangeException("labels",
                    "Label " + label + " at position " + sample + " is outside the range [0, " + classes + ").");
            }

            buckets[label].Add(sample);
        }

        // Randomise the order of the samples inside every bucket. Shuffling
        // the array of buckets instead would leave the per-class order, and
        // therefore the group assignment, fixed by the input order.
        for (int c = 0; c < buckets.Length; c++)
        {
            int[] shuffled = buckets[c].ToArray();
            Accord.Statistics.Tools.Shuffle(shuffled);
            buckets[c] = new List<int>(shuffled);
        }

        // Visit the groups round-robin; on every visit a group receives one
        // sample from each bucket that still has samples, so all classes are
        // drawn from equally.
        int[] splittings = new int[labels.Length];
        bool assignedAny;
        do
        {
            assignedAny = false;

            for (int group = 0; group < groups; group++)
            {
                for (int c = 0; c < buckets.Length; c++)
                {
                    List<int> bucket = buckets[c];
                    if (bucket.Count == 0)
                    {
                        continue;
                    }

                    int last = bucket.Count - 1;
                    splittings[bucket[last]] = group;
                    bucket.RemoveAt(last);
                    assignedAny = true;
                }
            }
        } while (assignedAny);

        return splittings;
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement