Advertisement
Guest User

Untitled

a guest
Oct 29th, 2012
59
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.96 KB | None | 0 0
  1. #include <shogun/base/init.h>
  2. #include <shogun/features/StringFeatures.h>
  3. #include <shogun/evaluation/CrossValidation.h>
  4. #include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
  5. #include <shogun/evaluation/CrossValidationMulticlassStorage.h>
  6. #include <shogun/evaluation/MulticlassAccuracy.h>
  7. #include <shogun/multiclass/GMNPSVM.h>
  8. #include <shogun/preprocessor/SortWordString.h>
  9. #include <shogun/kernel/string/WeightedCommWordStringKernel.h>
  10.  
  11. using namespace shogun;
  12.  
  13. void test()
  14. {
  15.  
  16.         /** generate random strings list and mc labels */
  17.         index_t num_strings=50;
  18.         index_t max_string_length=20;
  19.         index_t min_string_length=max_string_length/2;
  20.         SGStringList<uint8_t> strings(num_strings, max_string_length);
  21.         SGVector<float64_t> lab(num_strings);
  22.  
  23.         SG_SPRINT("original string data:\n");
  24.         for (index_t i=0; i<num_strings; ++i)
  25.         {
  26.                 index_t len=CMath::random(min_string_length, max_string_length);
  27.                 SGString<uint8_t> current(len);
  28.                 // labels
  29.                 if (i<num_strings/3)
  30.                         lab[i]=0;
  31.                 else if (i<2*num_strings/3)
  32.                         lab[i]=1;
  33.                 else
  34.                         lab[i]=2;
  35.  
  36.                 SG_SPRINT("[%i]: \"", i);
  37.                 for (index_t j=0; j<len; ++j)
  38.                 {
  39.                         current.string[j]=(uint8_t)CMath::random('A', 'Z');
  40.                         char* string=SG_MALLOC(char, 2);
  41.                         string[0]=current.string[j];
  42.                         string[1]='\0';
  43.                         SG_SPRINT("%s", string);
  44.                         SG_FREE(string);
  45.                 }
  46.                 SG_SPRINT("\"\n");
  47.                 strings.strings[i]=current;
  48.         }
  49.  
  50.         /** word string features */
  51.         CStringFeatures<uint8_t>* feat=new CStringFeatures<uint8_t>(strings, RAWBYTE);
  52.         SG_REF(feat);
  53.         CStringFeatures<uint16_t>* wfeat=new CStringFeatures<uint16_t>(RAWBYTE);
  54.         SG_REF(wfeat);
  55.         wfeat->obtain_from_char_features<uint8_t>(feat, 0, 1, 0, false);
  56.         CSortWordString *preproc=new CSortWordString();
  57.         SG_REF(preproc);
  58.         preproc->init(wfeat);
  59.         wfeat->add_preprocessor(preproc);
  60.         wfeat->apply_preprocessor();
  61.  
  62.         /** string kernel */
  63.         //CCommWordStringKernel *kmer_ker=new CCommWordStringKernel(10, false);
  64.         CWeightedCommWordStringKernel *kmer_ker = new CWeightedCommWordStringKernel(10, false);
  65.         SG_REF(kmer_ker);
  66.         kmer_ker->init(wfeat, wfeat);
  67.  
  68.         /** mc svm */
  69.         CMulticlassLabels *labels=new CMulticlassLabels(lab);
  70.         SG_REF(labels);
  71.         CGMNPSVM* svm=new CGMNPSVM(10, kmer_ker, labels);
  72.         SG_REF(svm);
  73.  
  74.         /** cross-validation */
  75.         int32_t n_folds=2, n_runs=1;
  76.         CMulticlassAccuracy* eval_crit=new CMulticlassAccuracy();
  77.         SG_REF(eval_crit);
  78.         CStratifiedCrossValidationSplitting* splitting= new CStratifiedCrossValidationSplitting(labels, n_folds);
  79.         SG_REF(splitting);
  80.         CCrossValidation* cross=new CCrossValidation(svm, wfeat, labels, splitting, eval_crit);
  81.         SG_REF(cross);
  82.         cross->set_num_runs(n_runs);
  83.         cross->set_autolock(false);
  84.         CEvaluationResult* result=cross->evaluate();
  85.         SG_UNREF(result);
  86.  
  87.         CCrossValidationMulticlassStorage *mc_storage= new CCrossValidationMulticlassStorage();
  88.         SG_REF(mc_storage);
  89.         cross->add_cross_validation_output(mc_storage);
  90.  
  91.         SG_UNREF(feat);
  92.         SG_UNREF(wfeat);
  93.         SG_UNREF(preproc);
  94.         SG_UNREF(kmer_ker);
  95.         SG_UNREF(labels);
  96.         SG_UNREF(svm);
  97.         SG_UNREF(eval_crit);
  98.         SG_UNREF(splitting);
  99.         SG_UNREF(cross);
  100.         SG_UNREF(mc_storage);
  101. }
  102.  
  103.  
  104. /** main */
  105. int main()
  106. {
  107.         init_shogun_with_defaults();
  108.         sg_io->set_loglevel(MSG_DEBUG);
  109.         test();
  110.         exit_shogun();
  111.         return 0;
  112. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement