Guest User

Untitled

a guest
Sep 4th, 2012
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 4.15 KB | None | 0 0
  1. /*
  2.  * This program is free software; you can redistribute it and/or modify
  3.  * it under the terms of the GNU General Public License as published by
  4.  * the Free Software Foundation; either version 3 of the License, or
  5.  * (at your option) any later version.
  6.  *
  7.  * Written (W) 2012 [email protected]
  8.  * Written (W) 2012 Heiko Strathmann
  9.  */
  10.  
  11. #include <shogun/features/streaming/StreamingDenseFeatures.h>
  12. #include <shogun/io/streaming/StreamingAsciiFile.h>
  13. #include <shogun/labels/MulticlassLabels.h>
  14. #include <shogun/kernel/GaussianKernel.h>
  15. #include <shogun/kernel/LinearKernel.h>
  16. #include <shogun/kernel/PolyKernel.h>
  17. #include <shogun/kernel/CombinedKernel.h>
  18. #include <shogun/classifier/mkl/MKLMulticlass.h>
  19. #include <shogun/io/SerializableAsciiFile.h>
  20. #include <shogun/kernel/normalizer/SqrtDiagKernelNormalizer.h>
  21. #include <shogun/kernel/normalizer/AvgDiagKernelNormalizer.h>
  22.  
  23. #include <iostream>
  24.  
  25. using namespace shogun;
  26.  
  27. void test_multiclass_mkl()
  28. {
  29.    /* stream data from a file */
  30.    int32_t num_vectors=50;
  31.    int32_t num_feats=2;
  32.  
  33.    /* file data */
  34.    char fname_feats[]="../data/fm_train_real.dat";
  35.    char fname_labels[]="../data/label_train_multiclass.dat";
  36.    CStreamingAsciiFile* ffeats_train=new CStreamingAsciiFile(fname_feats);
  37.    CStreamingAsciiFile* flabels_train=new CStreamingAsciiFile(fname_labels);
  38.    SG_REF(ffeats_train);
  39.    SG_REF(flabels_train);
  40.  
  41.    /* streaming data */
  42.    CStreamingDenseFeatures<float64_t>* stream_features=
  43.          new CStreamingDenseFeatures<float64_t>(ffeats_train, false, 1024);
  44.    CStreamingDenseFeatures<float64_t>* stream_labels=
  45.          new CStreamingDenseFeatures<float64_t>(flabels_train, true, 1024);
  46.    SG_REF(stream_features);
  47.    SG_REF(stream_labels);
  48.  
  49.    /* matrix data */
  50.    SGMatrix<float64_t> mat=SGMatrix<float64_t>(num_feats, num_vectors);
  51.    SGVector<float64_t> vec;
  52.    stream_features->start_parser();
  53.  
  54.    index_t count=0;
  55.    while (stream_features->get_next_example() && count<num_vectors)
  56.    {
  57.       vec=stream_features->get_vector();
  58.       for (int32_t i=0; i<num_feats; ++i)
  59.          mat(i,count)=vec[i];
  60.  
  61.       stream_features->release_example();
  62.       count++;
  63.    }
  64.    stream_features->end_parser();
  65.    mat.num_cols=num_vectors;
  66.  
  67.    /* dense features from streamed matrix */
  68.    CDenseFeatures<float64_t>* features=new CDenseFeatures<float64_t>(mat);
  69.    CMulticlassLabels* labels=new CMulticlassLabels(num_vectors);
  70.    SG_REF(features);
  71.    SG_REF(labels);
  72.  
  73.    /* read labels from file */
  74.    int32_t idx=0;
  75.    stream_labels->start_parser();
  76.    while (stream_labels->get_next_example())
  77.    {
  78.       labels->set_int_label(idx++, (int32_t)stream_labels->get_label());
  79.       stream_labels->release_example();
  80.    }
  81.    stream_labels->end_parser();
  82.  
  83.    /* combined features and kernel */
  84.    CCombinedFeatures *cfeats=new CCombinedFeatures();
  85.    CCombinedKernel *cker=new CCombinedKernel();
  86.    SG_REF(cfeats);
  87.    SG_REF(cker);
  88.  
  89.    /** 1st kernel: gaussian */
  90.    cfeats->append_feature_obj(features);
  91.    cker->append_kernel(new CGaussianKernel(10,1.2));
  92.  
  93.    /** 2nd kernel: linear */
  94.    cfeats->append_feature_obj(features);
  95.    cker->append_kernel(new CLinearKernel());
  96.  
  97.    /** 3rd kernel: poly */
  98.    cfeats->append_feature_obj(features);
  99.    cker->append_kernel(new CPolyKernel(10, 2, true));
  100.  
  101.    cker->init(cfeats, cfeats);
  102.  
  103.    /* create mkl instance */
  104.    CMKLMulticlass* mkl=new CMKLMulticlass(1.2, cker, labels);
  105.    SG_REF(mkl);
  106.    mkl->set_epsilon(0.00001);
  107.    mkl->set_mkl_epsilon(0.001);
  108.    mkl->set_mkl_norm(1.5);
  109.  
  110.    /* train to see weights */
  111.    mkl->train();
  112.    cker->get_subkernel_weights().display_vector("weights");
  113.  
  114.    /* save mkl */
  115.    //CSerializableAsciiFile *asciifile = new CSerializableAsciiFile("../filename.txt",'w');
  116.    //mkl->save_serializable(asciifile);
  117.  
  118.  
  119.    /* clean up */
  120.    SG_UNREF(ffeats_train);
  121.    SG_UNREF(flabels_train);
  122.    SG_UNREF(stream_features);
  123.    SG_UNREF(stream_labels);
  124.    SG_UNREF(features);
  125.    SG_UNREF(labels);
  126.    SG_UNREF(cfeats);
  127.    SG_UNREF(cker);
  128.    SG_UNREF(mkl);
  129. }
  130.  
  131.  
  132. int main(int argc, char** argv){
  133.  
  134.    shogun::init_shogun_with_defaults();
  135.    test_multiclass_mkl();
  136.  
  137.    exit_shogun();
  138. }
Advertisement
Add Comment
Please, Sign In to add comment