Advertisement
jack06215

[dlib] Linear SVM

Jul 17th, 2020
116
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 4.74 KB | None | 0 0
  #include <algorithm>
  #include <cerrno>
  #include <cstddef>
  #include <fstream>
  #include <iostream>
  #include <string>
  #include <utility>
  #include <vector>

  #include <dlib/svm_threaded.h>
  #include <dlib/rand.h>
  #include <dlib/time_this.h>
  #include <dlib/algs.h>
  9.  
  10.  
  11. using namespace dlib;
  12. using namespace std;
  13.  
  14. typedef matrix<double, 0, 1> sample_type;
  15. typedef matrix<double, 0, 1> sample_type;                   // variable column vector
  16.  
  17.  
  18. typedef one_vs_one_trainer<any_trainer<sample_type>, double> ovo_trainer;
  19. typedef radial_basis_kernel<sample_type> rbf_kernel_type;
  20.  
  21. class Learn {
  22. public:
  23.     Learn() { }
  24.     virtual ~Learn() { }
  25.     virtual void train() { }
  26.     virtual void save(std::string pathname) { }
  27.     virtual void load(std::string pathname) { }
  28.     inline sample_type vectorToSample(std::vector<double> sample_);
  29. };
  30.  
  31. class LearnSupervised : public Learn {
  32. public:
  33.     LearnSupervised() : Learn() { }
  34.     void addSample(std::vector<double> sample, double label);
  35.     void addSample(sample_type sample, double label);
  36.     virtual double predict(std::vector<double>& sample) { return 0; }
  37.     virtual double predict(sample_type& sample) { return 0; }
  38.     void clearTrainingInstances();
  39. protected:
  40.     std::vector<sample_type> samples;
  41.     std::vector<double> labels;
  42. };
  43.  
  44. class LearnSVM : public LearnSupervised {
  45. public:
  46.     LearnSVM();
  47.     ~LearnSVM();
  48.     void train();
  49.     void trainWithGridParameterSearch();
  50.     double predict(std::vector<double>& sample);
  51.     double predict(sample_type& sample);
  52.     void save(string path);
  53.     void load(string path);
  54. private:
  55.     ovo_trainer trainer;
  56.     krr_trainer<rbf_kernel_type> rbf_trainer;
  57.     one_vs_one_decision_function<ovo_trainer> df;
  58. };
  59. // END_OF_CLASS_DEFINITION
  60.  
  61. //////////////////////////////////////////////////////////////////////
  62. ////                            Learn                             ////
  63. //////////////////////////////////////////////////////////////////////
  64. inline sample_type Learn::vectorToSample(std::vector<double> sample_) {
  65.     sample_type sample(sample_.size());
  66.     for (int i = 0; i < sample_.size(); i++) {
  67.         sample(i) = sample_.at(i);
  68.     }
  69.     return sample;
  70. }
  71.  
  72. //////////////////////////////////////////////////////////////////////
  73. ////                         Supervised                           ////
  74. //////////////////////////////////////////////////////////////////////
  75. void LearnSupervised::addSample(sample_type sample, double label) {
  76.     if (label < 0.0 || label > 1.0) {
  77.         std::cerr << "label should be between 0.0 and 1.0" << std::endl;
  78.     }
  79.     samples.push_back(sample);
  80.     labels.push_back(label);
  81. }
  82.  
  83. void LearnSupervised::addSample(std::vector<double> sample, double label) {
  84.     if (label < 0.0 || label > 1.0) {
  85.         std::cerr << "label should be between 0.0 and 1.0" << std::endl;
  86.     }
  87.     sample_type tmp(sample.size());
  88.     for (int i = 0; i < sample.size(); i++) {
  89.         tmp(i) = sample.at(i);
  90.     }
  91.     samples.push_back(tmp);
  92.     labels.push_back(label);
  93. }
  94.  
  95. void LearnSupervised::clearTrainingInstances() {
  96.     samples.clear();
  97.     labels.clear();
  98. }
  99.  
  100. LearnSVM::LearnSVM() : LearnSupervised() {
  101.  
  102. }
  103.  
  104. LearnSVM::~LearnSVM() {
  105.  
  106. }
  107.  
  108. void LearnSVM::train() {
  109.     rbf_trainer.set_kernel(rbf_kernel_type(0.1));
  110.     rbf_trainer.set_lambda(0.01);
  111.     trainer.set_trainer(rbf_trainer);
  112.  
  113.     randomize_samples(samples, labels);
  114.     df = trainer.train(samples, labels);
  115. }
  116.  
  117.  
  118. void LearnSVM::trainWithGridParameterSearch() {
  119.  
  120. }
  121.  
  122. double LearnSVM::predict(sample_type& sample) {
  123.     return df(sample);
  124. }
  125.  
  126. double LearnSVM::predict(std::vector<double>& sample) {
  127.     return df(vectorToSample(sample));
  128. }
  129.  
  130. void LearnSVM::save(string path) {
  131.     const char* filepath = path.c_str();
  132.     ofstream fout(filepath, ios::binary);
  133.     one_vs_one_decision_function<ovo_trainer, decision_function<rbf_kernel_type>> df2;
  134.     df2 = df;
  135.     serialize(df2, fout);
  136. }
  137.  
  138. void LearnSVM::load(string path) {
  139.     const char* filepath = path.c_str();
  140.     ifstream fin(filepath, ios::binary);
  141.     one_vs_one_decision_function<ovo_trainer, decision_function<rbf_kernel_type>> df2;
  142.     deserialize(df2, fin);
  143.     df = df2;
  144. }
  145.  
  // Linearly map `inValue` from [minInRange, maxInRange] onto [minOutRange, maxOutRange].
  // Values outside the input range extrapolate linearly; combine with clamp()
  // to restrict the result. If the input range is empty
  // (maxInRange == minInRange), minOutRange is returned instead of dividing by zero.
  // Parameter names avoid the reserved `__` prefix used previously.
  double interpolate(double inValue, double minInRange, double maxInRange, double minOutRange, double maxOutRange) {
      const double inSpan = maxInRange - minInRange;
      if (inSpan == 0.0) {  // exact zero is the only case that would divide by zero
          return minOutRange;
      }
      // Offset by minInRange first; omitting it was only correct when minInRange == 0.
      const double t = (inValue - minInRange) / inSpan;
      return minOutRange + (maxOutRange - minOutRange) * t;
  }
  150.  
  // Restrict `value` to [lower, upper]: first cap it at `upper`, then raise it
  // to at least `lower`. If lower > upper, `lower` wins.
  double clamp(double value, double lower, double upper) {
      const double capped = (upper < value) ? upper : value;
      return (lower < capped) ? capped : lower;
  }
  154.  
  155. int main(void) {
  156.     LearnSVM classifier;
  157.     dlib::rand rng;
  158.     double test;
  159.  
  160.     for (int i = 0; i < 3000; i++) {
  161.         double x = rng.get_double_in_range(0, 3000);
  162.         double y = 0.00074 * (x * x) + 0.0095 * x + rng.get_double_in_range(-80, 80);
  163.         x = clamp(interpolate(x, 0, 3000, 0, 1), 0, 1);
  164.         y = clamp(interpolate(y, 0, 3000, 0, 1), 0, 1);
  165.  
  166.         //std::cout << x << "\t" << y << std::endl;
  167.        
  168.         std::vector<double> sample;
  169.         sample.push_back(x);
  170.         classifier.addSample(sample, y);
  171.  
  172.  
  173.     }
  174.  
  175.     TIME_THIS(classifier.train());
  176.     return 0;
  177. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement