Advertisement
Guest User

Untitled

a guest
May 8th, 2012
53
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.66 KB | None | 0 0
  1. void test_cross_validation ()
  2. {
  3. int32_t num_vectors = 0;
  4. int32_t num_feats = 2;
  5.  
  6. init_shogun_with_defaults();
  7.  
  8. // Prepare to read a file for the training data
  9. char fname_feats[] = "../data/fm_train_real.dat";
  10. char fname_labels[] = "../data/label_train_multiclass.dat";
  11. CStreamingAsciiFile* ffeats_train = new CStreamingAsciiFile(fname_feats);
  12. CStreamingAsciiFile* flabels_train = new CStreamingAsciiFile(fname_labels);
  13. SG_REF(ffeats_train);
  14. SG_REF(flabels_train);
  15.  
  16. CStreamingDenseFeatures< float64_t >* stream_features =
  17. new CStreamingDenseFeatures< float64_t >(ffeats_train, false, 1024);
  18.  
  19. CStreamingDenseFeatures< float64_t >* stream_labels =
  20. new CStreamingDenseFeatures< float64_t >(flabels_train, true, 1024);
  21.  
  22. SG_REF(stream_features);
  23. SG_REF(stream_labels);
  24.  
  25. // Create a matrix with enough space to read all the feature vectors
  26. SGMatrix< float64_t > mat = SGMatrix< float64_t >(num_feats, 1000);
  27.  
  28. // Read the values from the file and store them in mat
  29. SGVector< float64_t > vec;
  30. stream_features->start_parser();
  31. while (stream_features->get_next_example())
  32. {
  33. vec = stream_features->get_vector();
  34.  
  35. for ( int32_t i = 0 ; i < num_feats ; ++i )
  36. mat[num_vectors*num_feats + i] = vec[i];
  37.  
  38. num_vectors++;
  39. stream_features->release_example();
  40. }
  41. stream_features->end_parser();
  42.  
  43. // Create features with the useful values from mat
  44. CDenseFeatures< float64_t >* features = new CDenseFeatures< float64_t >(mat.matrix, num_feats, num_vectors);
  45.  
  46. CLabels* labels = new CLabels(num_vectors);
  47. SG_REF(features);
  48. SG_REF(labels);
  49.  
  50. // Read the labels from the file
  51. int32_t idx = 0;
  52. stream_labels->start_parser();
  53. while (stream_labels->get_next_example())
  54. {
  55. labels->set_int_label( idx++, (int32_t)stream_labels->get_label() );
  56. stream_labels->release_example();
  57. }
  58. stream_labels->end_parser();
  59.  
  60. /* gaussian kernel */
  61. int32_t kernel_cache=100;
  62. int32_t width=10;
  63. CGaussianKernel* kernel=new CGaussianKernel(kernel_cache, width);
  64. kernel->init(features, features);
  65.  
  66. /* create svm via libsvm */
  67. float64_t svm_C=10;
  68. float64_t svm_eps=0.0001;
  69. CMulticlassLibLinear* svm=new CMulticlassLibLinear(svm_C, features, labels);
  70. svm->set_epsilon(svm_eps);
  71.  
  72. /* train and output */
  73. svm->train(features);
  74. CLabels* output=svm->apply(features);
  75. for (index_t i=0; i<num_vectors; ++i)
  76. SG_SPRINT("i=%d, class=%f,\n", i, output->get_label(i));
  77.  
  78. /* evaluation criterion */
  79. CMulticlassAccuracy* eval_crit = new CMulticlassAccuracy();
  80.  
  81. /* evaluate training error */
  82. float64_t eval_result=eval_crit->evaluate(output, labels);
  83. SG_SPRINT("training accuracy: %f\n", eval_result);
  84. SG_UNREF(output);
  85.  
  86. /* assert that regression "works". this is not guaranteed to always work
  87. * but should be a really coarse check to see if everything is going
  88. * approx. right */
  89. ASSERT(eval_result<2);
  90.  
  91. /* splitting strategy */
  92. index_t n_folds=5;
  93. CStratifiedCrossValidationSplitting* splitting=
  94. new CStratifiedCrossValidationSplitting(labels, n_folds);
  95.  
  96. /* cross validation instance, 10 runs, 95% confidence interval */
  97. CCrossValidation* cross=new CCrossValidation(svm, features, labels, splitting, eval_crit);
  98.  
  99. cross->set_num_runs(10);
  100. cross->set_conf_int_alpha(0.05);
  101. cross->set_autolock (false);
  102.  
  103. /* actual evaluation */
  104. CrossValidationResult result=cross->evaluate();
  105. result.print_result();
  106.  
  107. /* clean up */
  108. SG_UNREF(stream_features);
  109. SG_UNREF(stream_labels);
  110. SG_UNREF(cross);
  111. SG_UNREF(features);
  112. SG_UNREF(labels);
  113. SG_UNREF(kernel);
  114. SG_UNREF(flabels_train);
  115. SG_UNREF(ffeats_train);
  116. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement