Advertisement
Guest User

thereisnoknife

a guest
Jul 18th, 2012
60
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 4.15 KB | None | 0 0
  /** @author thereisnoknife@gmail.com */

  // stl
  #include <algorithm>
  #include <iostream>
  #include <vector>
  // shogun init
  #include <shogun/base/init.h>
  #include <shogun/io/SGIO.h>
  #include <shogun/lib/ShogunException.h>
  // features
  #include <shogun/features/CombinedFeatures.h>
  #include <shogun/features/DenseFeatures.h>
  #include <shogun/features/StreamingDenseFeatures.h>
  #include <shogun/io/StreamingAsciiFile.h>
  #include <shogun/labels/MulticlassLabels.h>
  // kernel
  #include <shogun/kernel/GaussianKernel.h>
  #include <shogun/kernel/LinearKernel.h>
  #include <shogun/kernel/PolyKernel.h>
  // mkl
  #include <shogun/kernel/CustomKernel.h>
  #include <shogun/kernel/CombinedKernel.h>
  #include <shogun/classifier/mkl/MKLMulticlass.h>
  // evaluation
  #include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
  #include <shogun/evaluation/CrossValidation.h>
  #include <shogun/evaluation/MulticlassAccuracy.h>

  using namespace shogun;

  27. int main(int argc, char** argv){
  28.    /** init */
  29.    shogun::init_shogun_with_defaults();
  30.  
  31.    int32_t num_vectors = 0;
  32.    int32_t num_feats   = 2;
  33.  
  34.    /** file data */
  35.    char fname_feats[]  = "/home/eric/buildsys/shogun.data/toy/fm_train_real.dat";
  36.    char fname_labels[] = "/home/eric/buildsys/shogun.data/toy/label_train_multiclass.dat";
  37.    CStreamingAsciiFile* ffeats_train  = new CStreamingAsciiFile(fname_feats);
  38.    CStreamingAsciiFile* flabels_train = new CStreamingAsciiFile(fname_labels);
  39.    SG_REF(ffeats_train);
  40.    SG_REF(flabels_train);
  41.  
  42.    /** streaming data */
  43.    CStreamingDenseFeatures< float64_t >* stream_features =
  44.          new CStreamingDenseFeatures< float64_t >(ffeats_train, false, 1024);
  45.    CStreamingDenseFeatures< float64_t >* stream_labels =
  46.          new CStreamingDenseFeatures< float64_t >(flabels_train, true, 1024);
  47.    SG_REF(stream_features);
  48.    SG_REF(stream_labels);
  49.  
  50.    /** matrix data */
  51.    SGMatrix< float64_t > mat = SGMatrix< float64_t >(num_feats, 1000);
  52.    SGVector< float64_t > vec;
  53.    stream_features->start_parser();
  54.    while ( stream_features->get_next_example() ){
  55.       vec = stream_features->get_vector();
  56.       for ( int32_t i = 0 ; i < num_feats ; ++i )
  57.          mat.matrix[num_vectors*num_feats + i] = vec[i];
  58.       num_vectors++;
  59.       stream_features->release_example();
  60.    }
  61.    stream_features->end_parser();
  62.    mat.num_cols = num_vectors;
  63.  
  64.    /** dense features */
  65.    CDenseFeatures< float64_t >* features = new CDenseFeatures<float64_t>(mat);
  66.    CMulticlassLabels* labels = new CMulticlassLabels(num_vectors);
  67.    SG_REF(features);
  68.    SG_REF(labels);
  69.  
  70.    // Read the labels from the file
  71.    int32_t idx = 0;
  72.    stream_labels->start_parser();
  73.    while ( stream_labels->get_next_example() ){
  74.       labels->set_int_label( idx++, (int32_t)stream_labels->get_label() );
  75.       stream_labels->release_example();
  76.    }
  77.    stream_labels->end_parser();
  78.  
  79.    /** combined features */
  80.    CCombinedFeatures *cfeats = new CCombinedFeatures();
  81.    CCombinedKernel *cker = new CCombinedKernel();
  82.  
  83.    /** 1st kernel: gaussian */
  84.    CGaussianKernel *gker = new CGaussianKernel(features,features,1.2,10);
  85.    cfeats->append_feature_obj(features);
  86.    cker->append_kernel(gker);
  87.  
  88.    /** 2nd kernel: linear */
  89.    CLinearKernel *lker = new CLinearKernel(features,features);
  90.    cfeats->append_feature_obj(features);
  91.    cker->append_kernel(lker);
  92.  
  93.    /** 3rd kernel: poly */
  94.    CPolyKernel *pker = new CPolyKernel(features, features, 2, true, 10);
  95.    cfeats->append_feature_obj(features);
  96.    cker->append_kernel(pker);
  97.  
  98.    cker->init(cfeats,cfeats);
  99.  
  100.    CMKLMulticlass *mkl = new CMKLMulticlass(1.2,cker,labels);
  101.    mkl->set_epsilon(0.00001);
  102.    mkl->parallel->set_num_threads(1);
  103.    mkl->set_mkl_epsilon(0.001);
  104.    mkl->set_mkl_norm(1.5);
  105.  
  106.    mkl->train();
  107.    cker->get_subkernel_weights().display_vector("weights");
  108.  
  109.    index_t n_folds=2;
  110.    CMulticlassAccuracy* eval_crit = new CMulticlassAccuracy();
  111.    CStratifiedCrossValidationSplitting* splitting = new CStratifiedCrossValidationSplitting(labels, n_folds);
  112.    CCrossValidation *cross= new CCrossValidation(mkl,cfeats,labels,splitting,eval_crit);
  113.    cross->set_autolock(false);
  114.    cross->set_num_runs(5);
  115.    cross->set_conf_int_alpha(0.05);
  116.  
  117.    CrossValidationResult* result=(CrossValidationResult*)cross->evaluate();
  118.    std::cout << "Mean= " << result->mean << std::endl;
  119. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement