Advertisement
cesarsouza

Cross-Validation for Hidden Markov Model classifiers

Apr 11th, 2014
314
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 4.00 KB | None | 0 0
  // Sample code showing how to use k-fold cross-validation to assess
  // the performance of a Hidden Markov Model sequence classifier.

  // Declare some testing data: 20 sequences over the symbols 0 and 1.
  int[][] inputs =
  {
      new[] { 0,1,1,0 },   // Class 0
      new[] { 0,0,1,0 },   // Class 0
      new[] { 0,1,1,1,0 }, // Class 0
      new[] { 0,1,1,1,0 }, // Class 0
      new[] { 0,1,1,0 },   // Class 0
      new[] { 0,1,1,1,0 }, // Class 0
      new[] { 0,1,1,1,0 }, // Class 0
      new[] { 0,1,0,1,0 }, // Class 0
      new[] { 0,1,0 },     // Class 0
      new[] { 0,1,1,0 },   // Class 0

      new[] { 1,0,0,1 },   // Class 1
      new[] { 1,1,0,1 },   // Class 1
      new[] { 1,0,0,0,1 }, // Class 1
      new[] { 1,0,1 },     // Class 1
      new[] { 1,1,0,1 },   // Class 1
      new[] { 1,0,1 },     // Class 1
      new[] { 1,0,0,1 },   // Class 1
      new[] { 1,0,0,0,1 }, // Class 1
      new[] { 1,0,1 },     // Class 1
      new[] { 1,0,0,0,1 }, // Class 1
  };

  // Class label of each sequence, in the same order as `inputs`.
  int[] outputs =
  {
      0,0,0,0,0,0,0,0,0,0, // First 10 sequences are of class 0
      1,1,1,1,1,1,1,1,1,1, // Last 10 sequences are of class 1
  };

  // Create a new cross-validation algorithm, passing the data set size and the number of folds.
  var crossvalidation = new CrossValidation<HiddenMarkovClassifier>(size: inputs.Length, folds: 3);

  // Define the fitting function. For each fold it trains a fresh Hidden Markov Model
  // classifier on the training subset selected by cross-validation and measures its
  // error on both the training subset and the held-out validation subset.
  crossvalidation.Fitting = (k, indicesTrain, indicesValidation) =>
  {
      // `indicesTrain` and `indicesValidation` are indices into the original data set
      // (`inputs` / `outputs`); `k` identifies the fold and is not needed here.

      // Grab the training data for this fold.
      var trainingInputs = inputs.Submatrix(indicesTrain);
      var trainingOutputs = outputs.Submatrix(indicesTrain);

      // And the validation data for this fold.
      var validationInputs = inputs.Submatrix(indicesValidation);
      var validationOutputs = outputs.Submatrix(indicesValidation);

      // We are trying to predict two different classes.
      int classes = 2;

      // Size of the observation alphabet: every sequence element is either 0 or 1.
      int symbols = 2;

      // One nested model per class, each with two hidden states.
      int[] states = { 2, 2 };

      // Create a new Hidden Markov Model classifier with the given parameters.
      var classifier = new HiddenMarkovClassifier(classes, states, symbols);

      // Create the learning algorithm that trains the sequence classifier. The factory
      // below is handed the index of a nested model and supplies its Baum-Welch teacher.
      var teacher = new HiddenMarkovClassifierLearning(classifier,
          modelIndex => new BaumWelchLearning(classifier.Models[modelIndex])
          {
              // Stop once the log-likelihood changes by less than 0.001 ...
              Tolerance = 0.001,

              // ... with no cap on the number of iterations (0 disables the limit).
              Iterations = 0
          });

      // Train the sequence classifier on this fold's training subset. Run returns a
      // likelihood figure for the training data, which this sample does not use.
      teacher.Run(trainingInputs, trainingOutputs);

      // Measure the error on the data the model was trained on ...
      double trainingError = teacher.ComputeError(trainingInputs, trainingOutputs);

      // ... and on the held-out validation data.
      double validationError = teacher.ComputeError(validationInputs, validationOutputs);

      // Return an information structure containing the model and the errors achieved.
      return new CrossValidationValues<HiddenMarkovClassifier>(classifier, trainingError, validationError);
  };

  // Compute the cross-validation (invokes the fitting function for the folds).
  var result = crossvalidation.Compute();

  // Finally, access the measured performance: mean error across the folds.
  double trainingErrors = result.Training.Mean;
  double validationErrors = result.Validation.Mean;
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement