Advertisement
Guest User

Untitled

a guest
Feb 13th, 2013
121
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 4.97 KB | None | 0 0
  1. #include <shogun/features/streaming/StreamingDenseFeatures.h>
  2. #include <shogun/io/streaming/StreamingAsciiFile.h>
  3. #include <shogun/labels/MulticlassLabels.h>
  4. #include <shogun/kernel/GaussianKernel.h>
  5. #include <shogun/kernel/LinearKernel.h>
  6. #include <shogun/kernel/PolyKernel.h>
  7. #include <shogun/kernel/CombinedKernel.h>
  8. #include <shogun/classifier/mkl/MKLMulticlass.h>
  9. #include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
  10. #include <shogun/evaluation/CrossValidation.h>
  11. #include <shogun/evaluation/MulticlassAccuracy.h>
  12. #include <shogun/io/SerializableAsciiFile.h>
  13. #include <iostream>
  14.  
  15. using namespace shogun;
  16.  
  17. void test_multiclass_mkl_cv()
  18. {
  19.    /* stream data from a file */
  20.    int32_t num_vectors=50;
  21.    int32_t num_feats=2;
  22.  
  23.    /* file data */
  24.    char fname_feats[]="../data/fm_train_real.dat";
  25.    char fname_labels[]="../data/label_train_multiclass.dat";
  26.    CStreamingAsciiFile* ffeats_train=new CStreamingAsciiFile(fname_feats);
  27.    CStreamingAsciiFile* flabels_train=new CStreamingAsciiFile(fname_labels);
  28.    SG_REF(ffeats_train);
  29.    SG_REF(flabels_train);
  30.  
  31.    /* streaming data */
  32.    CStreamingDenseFeatures<float64_t>* stream_features=
  33.          new CStreamingDenseFeatures<float64_t>(ffeats_train, false, 1024);
  34.    CStreamingDenseFeatures<float64_t>* stream_labels=
  35.          new CStreamingDenseFeatures<float64_t>(flabels_train, true, 1024);
  36.    SG_REF(stream_features);
  37.    SG_REF(stream_labels);
  38.  
  39.    /* matrix data */
  40.    SGMatrix<float64_t> mat=SGMatrix<float64_t>(num_feats, num_vectors);
  41.    SGVector<float64_t> vec;
  42.    stream_features->start_parser();
  43.  
  44.    index_t count=0;
  45.    while (stream_features->get_next_example() && count<num_vectors)
  46.    {
  47.       vec=stream_features->get_vector();
  48.       for (int32_t i=0; i<num_feats; ++i)
  49.          mat(i,count)=vec[i];
  50.  
  51.       stream_features->release_example();
  52.       count++;
  53.    }
  54.    stream_features->end_parser();
  55.    mat.num_cols=num_vectors;
  56.  
  57.    /* dense features from streamed matrix */
  58.    CDenseFeatures<float64_t>* features=new CDenseFeatures<float64_t>(mat);
  59.    CMulticlassLabels* labels=new CMulticlassLabels(num_vectors);
  60.    SG_REF(features);
  61.    SG_REF(labels);
  62.  
  63.    /* read labels from file */
  64.    int32_t idx=0;
  65.    stream_labels->start_parser();
  66.    while (stream_labels->get_next_example())
  67.    {
  68.       labels->set_int_label(idx++, (int32_t)stream_labels->get_label());
  69.       stream_labels->release_example();
  70.    }
  71.    stream_labels->end_parser();
  72.  
  73.    /* combined features and kernel */
  74.    CCombinedFeatures *cfeats=new CCombinedFeatures();
  75.    CCombinedKernel *cker=new CCombinedKernel();
  76.    SG_REF(cfeats);
  77.    SG_REF(cker);
  78.  
  79.    /** 1st kernel: gaussian */
  80.    cfeats->append_feature_obj(features);
  81.    cker->append_kernel(new CGaussianKernel(features, features, 1.2, 10));
  82.  
  83.    /** 2nd kernel: linear */
  84.    cfeats->append_feature_obj(features);
  85.    cker->append_kernel(new CLinearKernel(features, features));
  86.  
  87.    /** 3rd kernel: poly */
  88.    cfeats->append_feature_obj(features);
  89.    cker->append_kernel(new CPolyKernel(features, features, 2, true, 10));
  90.  
  91.    cker->init(cfeats, cfeats);
  92.  
  93.    /* create mkl instance */
  94.    CMKLMulticlass* mkl=new CMKLMulticlass(1.2, cker, labels);
  95.    SG_REF(mkl);
  96.    mkl->set_epsilon(0.00001);
  97.    mkl->parallel->set_num_threads(1);
  98.    mkl->set_mkl_epsilon(0.001);
  99.    mkl->set_mkl_norm(1.5);
  100.  
  101.    /* train to see weights */
  102.    mkl->train();
  103.    cker->get_subkernel_weights().display_vector("weights");
  104.  
  105.    CSerializableAsciiFile *mkl_file_w = new CSerializableAsciiFile("test.mkl",'w');
  106.    mkl->save_serializable(mkl_file_w);
  107.    mkl_file_w->close();
  108.  
  109.    CMKLMulticlass* mkl2=new CMKLMulticlass();
  110.    CCombinedFeatures *cfeats2=new CCombinedFeatures();
  111.    CCombinedKernel *cker2=new CCombinedKernel();
  112.    SG_REF(cfeats2);
  113.    SG_REF(cker2);
  114.  
  115.    CSerializableAsciiFile *mkl_file_r = new CSerializableAsciiFile("test.mkl",'r');
  116.    mkl2->load_serializable(mkl_file_r);
  117.    mkl_file_r->close();
  118.  
  119.    /* print */
  120.    int numweights;
  121.    mkl2->getsubkernelweights(numweights);
  122.    std::cout << numweights << std::endl;
  123.    std::cout << mkl2->get_num_machines() << std::endl;
  124.  
  125.    cfeats2->append_feature_obj(features);
  126.    cfeats2->append_feature_obj(features);
  127.    cfeats2->append_feature_obj(features);
  128.  
  129.    /* test */
  130.    CMulticlassAccuracy* eval_crit = new CMulticlassAccuracy();
  131.    CMulticlassLabels *test_res = new CMulticlassLabels();
  132.    test_res = mkl2->apply_multiclass(cfeats2);
  133.    std::cout << eval_crit->evaluate(labels,test_res) << std::endl;
  134.  
  135.    /* clean up */
  136.    SG_UNREF(ffeats_train);
  137.    SG_UNREF(flabels_train);
  138.    SG_UNREF(stream_features);
  139.    SG_UNREF(stream_labels);
  140.    SG_UNREF(features);
  141.    SG_UNREF(labels);
  142.    SG_UNREF(cfeats);
  143.    SG_UNREF(cker);
  144.    SG_UNREF(mkl);
  145.    SG_UNREF(mkl2);
  146.    SG_UNREF(cfeats2);
  147.  
  148. }
  149.  
  150. int main(int argc, char** argv){
  151.    shogun::init_shogun_with_defaults();
  152.  
  153. //  sg_io->set_loglevel(MSG_DEBUG);
  154.  
  155.    /* performs cross-validation on a multi-class mkl machine */
  156.    test_multiclass_mkl_cv();
  157.  
  158.    exit_shogun();
  159. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement