Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <shogun/base/init.h>
- #include <shogun/features/StringFeatures.h>
- #include <shogun/evaluation/CrossValidation.h>
- #include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
- #include <shogun/evaluation/CrossValidationMulticlassStorage.h>
- #include <shogun/evaluation/MulticlassAccuracy.h>
- #include <shogun/multiclass/GMNPSVM.h>
- #include <shogun/preprocessor/SortWordString.h>
- #include <shogun/kernel/string/WeightedCommWordStringKernel.h>
- using namespace shogun;
- void test()
- {
- /** generate random strings list and mc labels */
- index_t num_strings=50;
- index_t max_string_length=20;
- index_t min_string_length=max_string_length/2;
- SGStringList<uint8_t> strings(num_strings, max_string_length);
- SGVector<float64_t> lab(num_strings);
- SG_SPRINT("original string data:\n");
- for (index_t i=0; i<num_strings; ++i)
- {
- index_t len=CMath::random(min_string_length, max_string_length);
- SGString<uint8_t> current(len);
- // labels
- if (i<num_strings/3)
- lab[i]=0;
- else if (i<2*num_strings/3)
- lab[i]=1;
- else
- lab[i]=2;
- SG_SPRINT("[%i]: \"", i);
- for (index_t j=0; j<len; ++j)
- {
- current.string[j]=(uint8_t)CMath::random('A', 'Z');
- char* string=SG_MALLOC(char, 2);
- string[0]=current.string[j];
- string[1]='\0';
- SG_SPRINT("%s", string);
- SG_FREE(string);
- }
- SG_SPRINT("\"\n");
- strings.strings[i]=current;
- }
- /** word string features */
- CStringFeatures<uint8_t>* feat=new CStringFeatures<uint8_t>(strings, RAWBYTE);
- SG_REF(feat);
- CStringFeatures<uint16_t>* wfeat=new CStringFeatures<uint16_t>(RAWBYTE);
- SG_REF(wfeat);
- wfeat->obtain_from_char_features<uint8_t>(feat, 0, 1, 0, false);
- CSortWordString *preproc=new CSortWordString();
- SG_REF(preproc);
- preproc->init(wfeat);
- wfeat->add_preprocessor(preproc);
- wfeat->apply_preprocessor();
- /** string kernel */
- //CCommWordStringKernel *kmer_ker=new CCommWordStringKernel(10, false);
- CWeightedCommWordStringKernel *kmer_ker = new CWeightedCommWordStringKernel(10, false);
- SG_REF(kmer_ker);
- kmer_ker->init(wfeat, wfeat);
- /** mc svm */
- CMulticlassLabels *labels=new CMulticlassLabels(lab);
- SG_REF(labels);
- CGMNPSVM* svm=new CGMNPSVM(10, kmer_ker, labels);
- SG_REF(svm);
- /** cross-validation */
- int32_t n_folds=2, n_runs=1;
- CMulticlassAccuracy* eval_crit=new CMulticlassAccuracy();
- SG_REF(eval_crit);
- CStratifiedCrossValidationSplitting* splitting= new CStratifiedCrossValidationSplitting(labels, n_folds);
- SG_REF(splitting);
- CCrossValidation* cross=new CCrossValidation(svm, wfeat, labels, splitting, eval_crit);
- SG_REF(cross);
- cross->set_num_runs(n_runs);
- cross->set_autolock(false);
- CEvaluationResult* result=cross->evaluate();
- SG_UNREF(result);
- CCrossValidationMulticlassStorage *mc_storage= new CCrossValidationMulticlassStorage();
- SG_REF(mc_storage);
- cross->add_cross_validation_output(mc_storage);
- SG_UNREF(feat);
- SG_UNREF(wfeat);
- SG_UNREF(preproc);
- SG_UNREF(kmer_ker);
- SG_UNREF(labels);
- SG_UNREF(svm);
- SG_UNREF(eval_crit);
- SG_UNREF(splitting);
- SG_UNREF(cross);
- SG_UNREF(mc_storage);
- }
- /** main */
- int main()
- {
- init_shogun_with_defaults();
- sg_io->set_loglevel(MSG_DEBUG);
- test();
- exit_shogun();
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement