daily pastebin goal
23%
SHARE
TWEET

HOG + SVM

TuanAnhVu Jun 14th, 2016 (edited) 230 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #include "stdafx.h"
  2. #include <stdio.h>  
  3. #include <string.h>
  4. #include <fstream>
  5. #include <iterator>
  6. #include <vector>
  7. #include <opencv2/core.hpp>
  8. #include <opencv2/imgproc.hpp>
  9. #include "opencv2/imgcodecs.hpp"
  10. #include <opencv2/highgui.hpp>
  11. #include <opencv2/ml.hpp>
  12. #include "opencv2/objdetect.hpp"
  13.  
  14. using namespace cv;
  15. using namespace cv::ml;
  16. using namespace std;
  17.  
  18.  
  19. void load_image(vector < Mat > & storeInput, vector< int > & labels){
  20.     char filename[256];
  21.    
  22.     for (int i = 1; i <= 6; i++){
  23.         for (int j = 1; j <= 50; j++) {
  24.             Mat img, img_gray;
  25.             sprintf(filename, "./train_data/%02d/%04d.ppm", i, j);
  26.             img = imread(filename);
  27.             resize(img, img, Size(32, 32));
  28.             cvtColor(img, img_gray, CV_BGR2GRAY);
  29.             storeInput.push_back(img_gray);
  30.             labels.push_back(i*1000 + j);
  31.             //show image  
  32.             imshow("load image", img_gray);
  33.  
  34.             waitKey(5);
  35.         }
  36.     }
  37. }
  38.  
  39. void convert_to_ml(const std::vector< cv::Mat > & train_samples, cv::Mat& trainData)
  40. {
  41.     //--Convert data
  42.     const int rows = (int)train_samples.size();
  43.     const int cols = (int)std::max(train_samples[0].cols, train_samples[0].rows);
  44.     cv::Mat tmp(1, cols, CV_32FC1); //< used for transposition if needed
  45.     trainData = cv::Mat(rows, cols, CV_32FC1);
  46.     vector< Mat >::const_iterator itr = train_samples.begin();
  47.     vector< Mat >::const_iterator end = train_samples.end();
  48.     for (int i = 0; itr != end; ++itr, ++i)
  49.     {
  50.         if (itr->cols == 1)
  51.         {
  52.             transpose(*(itr), tmp);
  53.             tmp.copyTo(trainData.row(i));
  54.         }
  55.         else if (itr->rows == 1)
  56.         {
  57.             itr->copyTo(trainData.row(i));
  58.         }
  59.     }
  60. }
  61.  
  62. void compute_hog(const vector< Mat > & img_lst, vector< Mat > & gradient_lst)
  63. {
  64.     cout << "Start hog...";
  65.     Size size = Size(32, 32);
  66.     Size block_size = Size(size.width / 4, size.height / 4);
  67.     Size block_stride = Size(size.width / 8, size.height / 8);
  68.     Size cell_size = block_stride;
  69.     int num_bins = 9;
  70.     HOGDescriptor hog(size, block_size, block_stride, cell_size, num_bins);
  71.     Mat inputHOG;
  72.     vector< Point > location;
  73.     vector< float > descriptors;
  74.  
  75.     for (int i = 0; i < img_lst.size(); i++)
  76.     {
  77.         inputHOG = img_lst[i];
  78.         hog.compute(inputHOG, descriptors, Size(0, 0), Size(0, 0), location);
  79.         //cout << "descriptors size: " << descriptors.size() << endl;
  80.  
  81.         gradient_lst.push_back(Mat(descriptors).clone());
  82.  
  83. }
  84. cout << "...[done]" << endl;
  85. }
  86.  
  87.  
  88. void train_svm(const vector< Mat > & storeInput, const vector< int > & labels){
  89.     Mat train_data;
  90.     convert_to_ml(storeInput, train_data);
  91.     cout << "Start training...";
  92.  
  93.  
  94.     Ptr<SVM> svm = SVM::create();
  95.     svm->setType(SVM::C_SVC);                                          
  96.     svm->setKernel(SVM::LINEAR);                                       
  97.     svm->setC(10);
  98.     svm->setTermCriteria(cvTermCriteria(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS, 1000, 1e-6));
  99.  
  100.     svm->train(train_data, ROW_SAMPLE, Mat(labels));
  101.  
  102.     cout << "...[done]" << endl;
  103.  
  104.     svm->save("./genfiles/trainSVM_Recog_MoreImage.xml");
  105.     cout << "Save XML ..... done!!!" << endl;
  106. }
  107.  
  108. //test trainning data + HOG
  109. float test_svm(Mat img, string nameSVM){
  110.  
  111.     //read image file
  112.     Mat img_gray, feature;
  113.     imshow("input", img);
  114.  
  115.  
  116.     //resizing
  117.     resize(img, img, Size(32, 32));
  118.     //gray
  119.     cvtColor(img, img_gray, COLOR_BGR2GRAY);
  120.  
  121.     //Extract HogFeature
  122.     Size size = Size(32, 32);
  123.     Size block_size = Size(size.width / 4, size.height / 4);
  124.     Size block_stride = Size(size.width / 8, size.height / 8);
  125.     Size cell_size = block_stride;
  126.     int num_bins = 9;
  127.     HOGDescriptor hog(size, block_size, block_stride, cell_size, num_bins);
  128.     vector< float> descriptorsValues;
  129.     vector< Point> locations;
  130.  
  131.     hog.compute(img_gray, descriptorsValues, Size(0, 0), Size(0, 0), locations);
  132.     //vector to Mat
  133.     Mat fm = Mat(descriptorsValues);
  134.  
  135.     //Classification data
  136.     Ptr<SVM> svm = Algorithm::load<SVM>(nameSVM);
  137.     std::cout << "Model Loaded" << std::endl;
  138.  
  139.     Mat image1d(1, fm.rows, CV_32FC1);
  140.     imshow("gradient 1", image1d);
  141.  
  142.     float result = svm->predict(image1d);
  143.     //std::cout << "Predict value: " << result << std::endl;
  144.     return result;
  145. }
  146.  
  147.  
  148. int main(){
  149.  
  150.  
  151.     //trainning data
  152.     vector < Mat > storeInput;
  153.     vector < int > labels;
  154.     vector < Mat > gradient_lst;
  155.  
  156.     load_image(storeInput, labels);
  157.  
  158.     compute_hog(storeInput, gradient_lst);
  159.  
  160.     train_svm(gradient_lst, labels);
  161.  
  162.     Mat img = imread("./train_data/01/0001.ppm");
  163.     string fileSVM = "./genfiles/trainSVM_Recog_MoreImage.xml";
  164.     float result = test_svm(img, fileSVM);
  165.     std::cout << "Predict value: " << result << std::endl;
  166.  
  167.     waitKey(0);
  168.     return 0;
  169. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top