Advertisement
Guest User

Untitled

a guest
Aug 22nd, 2012
23
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.42 KB | None | 0 0
  1. #include <shogun/classifier/svm/LibLinear.h>
  2. #include <shogun/evaluation/MulticlassAccuracy.h>
  3. #include <shogun/evaluation/StructuredAccuracy.h>
  4. #include <shogun/features/DenseFeatures.h>
  5. #include <shogun/io/SGIO.h>
  6. #include <shogun/labels/MulticlassLabels.h>
  7. #include <shogun/labels/StructuredLabels.h>
  8. #include <shogun/lib/common.h>
  9. #include <shogun/loss/HingeLoss.h>
  10. #include <shogun/machine/LinearMulticlassMachine.h>
  11. #include <shogun/mathematics/Math.h>
  12. #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
  13. #include <shogun/structure/MulticlassSOLabels.h>
  14. #include <shogun/structure/MulticlassModel.h>
  15. #include <shogun/structure/PrimalMosekSOSVM.h>
  16. #include <shogun/structure/DualLibQPBMSOSVM.h>
  17.  
  18. using namespace shogun;
  19.  
  // Dimensionality of every feature vector.
  #define DIMS 2
  // Stopping tolerance passed to CLibLinear::set_epsilon(); note 10e-5 == 1e-4.
  #define EPSILON 10e-5
  // Number of samples generated per class.
  #define NUM_SAMPLES 100
  // Number of classes in the synthetic data set.
  #define NUM_CLASSES 10

  // Text file the generated data set is dumped to; only ever read, hence const.
  const char FNAME[] = "data.out";
  26.  
  27. void gen_rand_data(SGVector< float64_t > labs, SGMatrix< float64_t > feats)
  28. {
  29. float64_t means[DIMS];
  30. float64_t stds[DIMS];
  31.  
  32. FILE* pfile = fopen(FNAME, "w");
  33.  
  34. for ( int32_t c = 0 ; c < NUM_CLASSES ; ++c )
  35. {
  36. for ( int32_t j = 0 ; j < DIMS ; ++j )
  37. {
  38. means[j] = CMath::random(-100, 100);
  39. stds[j] = CMath::random( 1, 5);
  40. }
  41.  
  42. for ( int32_t i = 0 ; i < NUM_SAMPLES ; ++i )
  43. {
  44. labs[c*NUM_SAMPLES+i] = c;
  45.  
  46. fprintf(pfile, "%d", c);
  47.  
  48. for ( int32_t j = 0 ; j < DIMS ; ++j )
  49. {
  50. feats[(c*NUM_SAMPLES+i)*DIMS + j] =
  51. CMath::normal_random(means[j], stds[j]);
  52.  
  53. fprintf(pfile, " %f", feats[(c*NUM_SAMPLES+i)*DIMS + j]);
  54. }
  55.  
  56. fprintf(pfile, "\n");
  57. }
  58. }
  59.  
  60. fclose(pfile);
  61. }
  62.  
  63. int main(int argc, char ** argv)
  64. {
  65. init_shogun_with_defaults();
  66.  
  67. SGVector< float64_t > labs(NUM_CLASSES*NUM_SAMPLES);
  68. SGMatrix< float64_t > feats(DIMS, NUM_CLASSES*NUM_SAMPLES);
  69.  
  70. gen_rand_data(labs, feats);
  71.  
  72. // Create train labels
  73. CMulticlassSOLabels* labels = new CMulticlassSOLabels(labs);
  74. CMulticlassLabels* mlabels = new CMulticlassLabels(labs);
  75.  
  76. // Create train features
  77. CDenseFeatures< float64_t >* features = new CDenseFeatures< float64_t >(feats);
  78.  
  79. // Create structured model
  80. CMulticlassModel* model = new CMulticlassModel(features, labels);
  81.  
  82. // Create loss function
  83. CHingeLoss* loss = new CHingeLoss();
  84.  
  85. // Create SO-SVM
  86. CPrimalMosekSOSVM* sosvm = new CPrimalMosekSOSVM(model, loss, labels);
  87. CDualLibQPBMSOSVM* bundle = new CDualLibQPBMSOSVM(model, loss, labels, 0.01);
  88. SG_REF(sosvm);
  89. SG_REF(bundle);
  90.  
  91. sosvm->train();
  92. bundle->train();
  93. CStructuredLabels* out = CStructuredLabels::obtain_from_generic(sosvm->apply());
  94. CStructuredLabels* bout = CStructuredLabels::obtain_from_generic(bundle->apply());
  95.  
  96. // Create liblinear svm classifier with L2-regularized L2-loss
  97. CLibLinear* svm = new CLibLinear(L2R_L2LOSS_SVC);
  98.  
  99. // Add some configuration to the svm
  100. svm->set_epsilon(EPSILON);
  101. svm->set_bias_enabled(false);
  102.  
  103. // Create a multiclass svm classifier that consists of several of the previous one
  104. CLinearMulticlassMachine* mc_svm =
  105. new CLinearMulticlassMachine( new CMulticlassOneVsRestStrategy(),
  106. (CDotFeatures*) features, svm, mlabels);
  107. SG_REF(mc_svm);
  108.  
  109. // Train the multiclass machine using the data passed in the constructor
  110. mc_svm->train();
  111. CMulticlassLabels* mout = CMulticlassLabels::obtain_from_generic(mc_svm->apply());
  112.  
  113. SGVector< float64_t > w = sosvm->get_w();
  114. for ( int32_t i = 0 ; i < w.vlen ; ++i )
  115. SG_SPRINT("%10f ", w[i]);
  116. SG_SPRINT("\n\n");
  117.  
  118. for ( int32_t i = 0 ; i < NUM_CLASSES ; ++i )
  119. {
  120. CLinearMachine* lm = (CLinearMachine*) mc_svm->get_machine(i);
  121. SGVector< float64_t > mw = lm->get_w();
  122. for ( int32_t j = 0 ; j < mw.vlen ; ++j )
  123. SG_SPRINT("%10f ", mw[j]);
  124.  
  125. SG_UNREF(lm); // because of CLinearMulticlassMachine::get_machine()
  126. }
  127. SG_SPRINT("\n");
  128.  
  129. CStructuredAccuracy* structured_evaluator = new CStructuredAccuracy();
  130. CMulticlassAccuracy* multiclass_evaluator = new CMulticlassAccuracy();
  131. SG_REF(structured_evaluator);
  132. SG_REF(multiclass_evaluator);
  133.  
  134. SG_SPRINT("SO-SVM: %5.2f%\n", 100.0*structured_evaluator->evaluate(out, labels));
  135. SG_SPRINT("BMRM: %5.2f%\n", 100.0*structured_evaluator->evaluate(bout, labels));
  136. SG_SPRINT("MC: %5.2f%\n", 100.0*multiclass_evaluator->evaluate(mout, mlabels));
  137.  
  138. // Free memory
  139. SG_UNREF(multiclass_evaluator);
  140. SG_UNREF(structured_evaluator);
  141. SG_UNREF(mout);
  142. SG_UNREF(mc_svm);
  143. SG_UNREF(bundle);
  144. SG_UNREF(sosvm);
  145. SG_UNREF(bout);
  146. SG_UNREF(out);
  147.  
  148. exit_shogun();
  149.  
  150. return 0;
  151. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement