Advertisement
Guest User

modelselection fail ?

a guest
Jul 20th, 2012
38
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.05 KB | None | 0 0
  1. /*
  2.  * This program is free software; you can redistribute it and/or modify
  3.  * it under the terms of the GNU General Public License as published by
  4.  * the Free Software Foundation; either version 3 of the License, or
  5.  * (at your option) any later version.
  6.  *
  7.  * Written (W) 2012 yoo, thereisnoknife@gmail.com
  8.  * Written (W) 2012 Heiko Strathmann
  9.  */
  10.  
  11. #include <shogun/features/StreamingDenseFeatures.h>
  12. #include <shogun/io/StreamingAsciiFile.h>
  13. #include <shogun/labels/MulticlassLabels.h>
  14. #include <shogun/kernel/GaussianKernel.h>
  15. #include <shogun/kernel/LinearKernel.h>
  16. #include <shogun/kernel/PolyKernel.h>
  17. #include <shogun/kernel/CombinedKernel.h>
  18. #include <shogun/classifier/mkl/MKLMulticlass.h>
  19. #include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
  20. #include <shogun/evaluation/CrossValidation.h>
  21. #include <shogun/evaluation/MulticlassAccuracy.h>
  22.  
  23. using namespace shogun;
  24.  
  25. void test_multiclass_mkl_cv()
  26. {
  27.    /* stream data from a file */
  28.    int32_t num_vectors=50;
  29.    int32_t num_feats=2;
  30.  
  31.    /* file data */
  32.    char fname_feats[]="/home/eric/buildsys/shogun.data/toy/fm_train_real.dat";
  33.    char fname_labels[]="/home/eric/buildsys/shogun.data/toy/label_train_multiclass.dat";
  34.    CStreamingAsciiFile* ffeats_train=new CStreamingAsciiFile(fname_feats);
  35.    CStreamingAsciiFile* flabels_train=new CStreamingAsciiFile(fname_labels);
  36.    SG_REF(ffeats_train);
  37.    SG_REF(flabels_train);
  38.  
  39.    /* streaming data */
  40.    CStreamingDenseFeatures<float64_t>* stream_features=
  41.          new CStreamingDenseFeatures<float64_t>(ffeats_train, false, 1024);
  42.    CStreamingDenseFeatures<float64_t>* stream_labels=
  43.          new CStreamingDenseFeatures<float64_t>(flabels_train, true, 1024);
  44.    SG_REF(stream_features);
  45.    SG_REF(stream_labels);
  46.  
  47.    /* matrix data */
  48.    SGMatrix<float64_t> mat=SGMatrix<float64_t>(num_feats, num_vectors);
  49.    SGVector<float64_t> vec;
  50.    stream_features->start_parser();
  51.  
  52.    index_t count=0;
  53.    while (stream_features->get_next_example() && count<num_vectors)
  54.    {
  55.       vec=stream_features->get_vector();
  56.       for (int32_t i=0; i<num_feats; ++i)
  57.          mat(i,count)=vec[i];
  58.  
  59.       stream_features->release_example();
  60.       count++;
  61.    }
  62.    stream_features->end_parser();
  63.    mat.num_cols=num_vectors;
  64.  
  65.    /* dense features from streamed matrix */
  66.    CDenseFeatures<float64_t>* features=new CDenseFeatures<float64_t>(mat);
  67.    CMulticlassLabels* labels=new CMulticlassLabels(num_vectors);
  68.    SG_REF(features);
  69.    SG_REF(labels);
  70.  
  71.    /* read labels from file */
  72.    int32_t idx=0;
  73.    stream_labels->start_parser();
  74.    while (stream_labels->get_next_example())
  75.    {
  76.       labels->set_int_label(idx++, (int32_t)stream_labels->get_label());
  77.       stream_labels->release_example();
  78.    }
  79.    stream_labels->end_parser();
  80.  
  81.    /* combined features and kernel */
  82.    CCombinedFeatures *cfeats=new CCombinedFeatures();
  83.    CCombinedKernel *cker=new CCombinedKernel();
  84.    SG_REF(cfeats);
  85.    SG_REF(cker);
  86.  
  87.    /** 1st kernel: gaussian */
  88.    cfeats->append_feature_obj(features);
  89.    cker->append_kernel(new CGaussianKernel(features, features, 1.2, 10));
  90.  
  91.    /** 2nd kernel: linear */
  92.    cfeats->append_feature_obj(features);
  93.    cker->append_kernel(new CLinearKernel(features, features));
  94.  
  95.    /** 3rd kernel: poly */
  96.    cfeats->append_feature_obj(features);
  97.    cker->append_kernel(new CPolyKernel(features, features, 2, true, 10));
  98.  
  99.    cker->init(cfeats, cfeats);
  100.  
  101.    /* create mkl instance */
  102.    CMKLMulticlass* mkl=new CMKLMulticlass(1.2, cker, labels);
  103.    SG_REF(mkl);
  104.    mkl->set_epsilon(0.00001);
  105.    mkl->parallel->set_num_threads(1);
  106.    mkl->set_mkl_epsilon(0.001);
  107.    mkl->set_mkl_norm(1.5);
  108.  
  109.    /* train to see weights */
  110.    mkl->train();
  111.    cker->get_subkernel_weights().display_vector("weights");
  112.  
  113.    /* cross-validation instances */
  114.    index_t n_folds=3;
  115.    index_t n_runs=5;
  116.    CMulticlassAccuracy* eval_crit=new CMulticlassAccuracy();
  117.    CStratifiedCrossValidationSplitting* splitting=
  118.          new CStratifiedCrossValidationSplitting(labels, n_folds);
  119.    CCrossValidation *cross=new CCrossValidation(mkl, cfeats, labels, splitting,
  120.          eval_crit);
  121.    cross->set_autolock(false);
  122.    cross->set_num_runs(n_runs);
  123.    cross->set_conf_int_alpha(0.05);
  124.  
  125.    /* perform x-val and print result */
  126.    CModelSelectionOutput *ms_output = new CModelSelectionOutput();
  127.    CrossValidationResult* result=(CrossValidationResult*)cross->evaluate(ms_output);
  128.    SG_SPRINT("mean of %d %d-fold x-val runs: %f\n", n_runs, n_folds,
  129.          result->mean);
  130.  
  131.    /* assert high accuracy */
  132.    ASSERT(result->mean>0.9);
  133.  
  134.    /* clean up */
  135.    SG_UNREF(ffeats_train);
  136.    SG_UNREF(flabels_train);
  137.    SG_UNREF(stream_features);
  138.    SG_UNREF(stream_labels);
  139.    SG_UNREF(features);
  140.    SG_UNREF(labels);
  141.    SG_UNREF(cfeats);
  142.    SG_UNREF(cker);
  143.    SG_UNREF(mkl);
  144.    SG_UNREF(cross);
  145.    SG_UNREF(result);
  146. }
  147.  
  148. int main(int argc, char** argv){
  149.    shogun::init_shogun_with_defaults();
  150.  
  151. //  sg_io->set_loglevel(MSG_DEBUG);
  152.  
  153.    /* performs cross-validation on a multi-class mkl machine */
  154.    test_multiclass_mkl_cv();
  155.  
  156.    exit_shogun();
  157. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement