Advertisement
Guest User

Untitled

a guest
Jul 14th, 2014
191
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 6.26 KB | None | 0 0
  /// <summary>
  /// Runs a k-fold cross-validation of a multi-class support vector machine
  /// over the supplied samples, saves the cross-validation result to disk and
  /// keeps the model of the fold that achieved the lowest validation error.
  /// </summary>
  public class CrossValidation
  {
      // Model taken from the fold with the lowest validation error.
      MulticlassSupportVectorMachine classifierMC;

      // NOTE(review): never assigned or read inside this class; kept in case
      // code outside this view relies on it.
      SupportVectorMachine classifier;

      /// <summary>
      /// Trains and validates one multi-class SVM per fold and reports the
      /// training and validation statistics.
      /// </summary>
      /// <param name="inputs">Feature vectors, one row per sample.</param>
      /// <param name="outputs">
      /// Class label of each sample. Labels are used as array indices by
      /// <see cref="RandomGroups"/>, so they must lie in [0, number of distinct labels).
      /// </param>
      /// <param name="trainingErrors">Receives the per-fold training statistics.</param>
      /// <param name="validationErrors">Receives the per-fold validation statistics.</param>
      /// <param name="path">
      /// File name prefix; the result is saved to
      /// <c>path + "_" + mean validation error ("0.000") + ".cvc"</c>.
      /// </param>
      /// <param name="tolerance">Tolerance handed to every SMO learner.</param>
      /// <param name="folds">Number of cross-validation folds.</param>
      /// <param name="kernelN">
      /// Kernel selector: 1 = Gaussian estimated from the fold's training inputs,
      /// 2 = Quadratic, 0 or any other value = Linear.
      /// </param>
      public CrossValidation(double[][] inputs, int[] outputs, out CrossValidationStatistics trainingErrors, out CrossValidationStatistics validationErrors, string path, double tolerance, int folds, int kernelN)
      {
          // Assign every sample to a fold, drawing evenly from each class.
          int[] foldOfSample = RandomGroups(outputs, outputs.Distinct().Length, folds);

          var crossvalidation = new CrossValidation<MulticlassSupportVectorMachine>(foldOfSample, folds: folds);
          crossvalidation.Fitting = delegate(int index, int[] indicesTrain, int[] indicesValidation)
          {
              // Split the data into this fold's training and validation parts.
              double[][] trainingInputs = inputs.Submatrix(indicesTrain);
              int[] trainingOutputs = outputs.Submatrix(indicesTrain);
              double[][] validationInputs = inputs.Submatrix(indicesValidation);
              int[] validationOutputs = outputs.Submatrix(indicesValidation);

              IKernel kernel = CreateKernel(kernelN, trainingInputs);

              // Complexity estimated once per fold and shared by all binary sub-problems.
              var complexity = SequentialMinimalOptimization.EstimateComplexity(kernel, trainingInputs);

              // The class count is taken from the labels present in this fold's training set.
              var model = new MulticlassSupportVectorMachine(trainingInputs[0].Length, kernel, trainingOutputs.Distinct().Length);

              var teacher = new MulticlassSupportVectorLearning(model, trainingInputs, trainingOutputs);

              // Use SMO to train the underlying SVM of each binary class sub-problem.
              teacher.Algorithm = (svm, classInputs, classOutputs, i, j) =>
                  new SequentialMinimalOptimization(svm, classInputs, classOutputs)
                  {
                      Tolerance = tolerance,
                      Complexity = complexity
                  };

              double trainingError = teacher.Run();
              double validationError = teacher.ComputeError(validationInputs, validationOutputs);

              // Hand the model and both errors back to the cross-validation driver.
              return new CrossValidationValues<MulticlassSupportVectorMachine>(model, trainingError, validationError);
          };

          var result = crossvalidation.Compute();

          result.Training.Tag = "Training results";
          result.Validation.Tag = "Validation results";
          trainingErrors = result.Training;
          validationErrors = result.Validation;

          result.Save(path + "_" + validationErrors.Mean.ToString("0.000") + ".cvc");

          // Keep the model of the fold with the lowest validation error. The
          // previous code searched the validation values for the *training*
          // minimum, which practically never matched and fell back to fold 0.
          int bestFold = IndexOfMinimum(result.Validation.Values);
          this.classifierMC = result.Models[bestFold].Model;
      }

      /// <summary>
      /// Maps the numeric kernel selector to a kernel instance.
      /// </summary>
      private static IKernel CreateKernel(int kernelN, double[][] trainingInputs)
      {
          switch (kernelN)
          {
              case 1:
                  return Gaussian.Estimate(trainingInputs);
              case 2:
                  return new Quadratic();
              default:
                  // 0 and every unrecognised value select the linear kernel.
                  return new Linear();
          }
      }

      /// <summary>
      /// Returns the position of the smallest element (first one on ties),
      /// or 0 when the array is empty.
      /// </summary>
      private static int IndexOfMinimum(double[] values)
      {
          int best = 0;
          for (int i = 1; i < values.Length; i++)
          {
              if (values[i] < values[best])
                  best = i;
          }
          return best;
      }

      /// <summary>
      /// Assigns every sample to one of <paramref name="groups"/> partitions,
      /// drawing samples class by class so each partition receives a similar
      /// share of every class.
      /// </summary>
      /// <param name="labels">Class label of each sample; every label must be in [0, <paramref name="classes"/>).</param>
      /// <param name="classes">Number of classes (one bucket is created per class).</param>
      /// <param name="groups">Number of partitions to create; must be positive.</param>
      /// <returns>An array with the partition index assigned to each sample.</returns>
      /// <exception cref="ArgumentNullException"><paramref name="labels"/> is null.</exception>
      /// <exception cref="ArgumentOutOfRangeException">
      /// <paramref name="groups"/> is not positive, or a label falls outside [0, <paramref name="classes"/>).
      /// </exception>
      public static int[] RandomGroups(int[] labels, int classes, int groups)
      {
          if (labels == null)
              throw new ArgumentNullException("labels");
          if (groups <= 0)
              throw new ArgumentOutOfRangeException("groups", "The number of groups must be positive.");

          // One bucket per class, holding (sample index, label) pairs.
          var buckets = new List<Tuple<int, int>>[classes];
          for (int i = 0; i < buckets.Length; i++)
              buckets[i] = new List<Tuple<int, int>>();

          for (int i = 0; i < labels.Length; i++)
          {
              int label = labels[i];
              if (label < 0 || label >= classes)
                  throw new ArgumentOutOfRangeException("labels", "Every label must be in the range [0, classes).");
              buckets[label].Add(Tuple.Create(i, label));
          }

          // Randomise the order of the samples inside each class. The previous
          // code shuffled the array of buckets instead, leaving the samples of
          // a class in their original order. The array overload of Shuffle is
          // used because it is the one the original code is known to compile with.
          for (int i = 0; i < buckets.Length; i++)
          {
              Tuple<int, int>[] shuffled = buckets[i].ToArray();
              Accord.Statistics.Tools.Shuffle(shuffled);
              buckets[i] = new List<Tuple<int, int>>(shuffled);
          }

          var partitions = new List<Tuple<int, int>>[groups];
          for (int i = 0; i < partitions.Length; i++)
              partitions[i] = new List<Tuple<int, int>>();

          // Walk the partitions in order; each one takes a single sample from
          // every non-empty bucket. Repeat until a full pass draws nothing.
          bool drewAny;
          do
          {
              drewAny = false;

              foreach (var partition in partitions)
              {
                  foreach (var bucket in buckets)
                  {
                      if (bucket.Count == 0)
                          continue;

                      int last = bucket.Count - 1;
                      partition.Add(bucket[last]);
                      bucket.RemoveAt(last);
                      drewAny = true;
                  }
              }
          } while (drewAny);

          // The order of samples within a partition does not influence the
          // result, so no further shuffling is needed here.
          int[] splittings = new int[labels.Length];
          for (int i = 0; i < partitions.Length; i++)
          {
              foreach (var sample in partitions[i])
                  splittings[sample.Item1] = i;
          }

          return splittings;
      }
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement