Advertisement
Guest User

SVM_OpenCV_tryout

a guest
Jul 25th, 2014
492
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 15.46 KB | None | 0 0
  1. /***************************************************************************************************
  2. Copyright (c) 2013 EAVISE, KU Leuven, Campus De Nayer
  3. Contact: steven.puttemans[at]kuleuven.be
  4.  
  5. Permission is hereby granted, free of charge, to any person obtaining a copy
  6. of this software and associated documentation files (the "Software"), to deal
  7. in the Software without restriction, including without limitation the rights
  8. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. copies of the Software, and to permit persons to whom the Software is
  10. furnished to do so, subject to the following conditions:
  11.  
  12. The above copyright notice and this permission notice shall be included in
  13. all copies or substantial portions of the Software.
  14.  
  15. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  18. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  20. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  21. THE SOFTWARE.
  22.  
  23. *****************************************************************************************************
  24.  
  25. Software for creating an SVM model based on training data
  26. INPUT: positive_training_samples.txt negative_training_samples.txt resulting_model.xml
  27. OUTPUT: SVM model, stored as XML at the path given as the third argument
  28.  
  29. Extra info
  30. - Cookies average dimensions : w = 122px | h = 117px
  31. - All positive and negative images should have this size for SVM model training to work properly
  32. - If positives have different sizes, resize them using the batch_resize_segmentations utility
  33.  
  34. *****************************************************************************************************/
  35. #include "stdafx.h"
  36.  
  37. // OpenCV include all functionality
  38. #include "opencv2/opencv.hpp";
  39.  
  40. // Extra includes for file processing
  41. #include <vector>
  42. #include <sstream>
  43. #include <fstream>
  44.  
  45. // Open correct namespaces
  46. using namespace std;
  47. using namespace cv;
  48.  
  49. // HOGDescriptor visual_image analyzing
  50. // Adapted from http://www.juergenwiki.de/work/wiki/doku.php?id=public%3ahog_descriptor_computation_and_visualization
  51. // ONLY PRECAUSIONS ARE
  52. // --> Image size width/heigth needs to be a multiple of block width/heigth
  53. // --> Block size width/heigth (multiple cells) needs to be a multiple of cell size (histogram region) width/heigth
  54. // --> Block stride needs to be a multiple of a cell size, however current code only allows to use a block stride = cell size!
  55. // --> ScaleFactor enlarges the image patch to make it visible (e.g. a patch of 50x50 could have a factor 10 to be visible at scale 500x500 for inspection)
  56. // --> viz_factor enlarges the maximum size of the maximal gradient length for normalization. At viz_factor = 1 it results in a length = half the cell width
  57. Mat get_hogdescriptor_visual_image(Mat& origImg, vector<float>& descriptorValues, Size winSize, Size cellSize, int scaleFactor, double viz_factor)
  58. {  
  59.     Mat visual_image;
  60.     resize(origImg, visual_image, Size(origImg.cols*scaleFactor, origImg.rows*scaleFactor));
  61.  
  62.     int gradientBinSize = 9;
  63.     float radRangeForOneBin = 3.14/(float)gradientBinSize; // dividing 180� into 9 bins, how large (in rad) is one bin?
  64.  
  65.     // prepare data structure: 9 orientation / gradient strenghts for each cell
  66.     int cells_in_x_dir = winSize.width / cellSize.width;
  67.     int cells_in_y_dir = winSize.height / cellSize.height;
  68.     int totalnrofcells = cells_in_x_dir * cells_in_y_dir;
  69.     float*** gradientStrengths = new float**[cells_in_y_dir];
  70.     int** cellUpdateCounter   = new int*[cells_in_y_dir];
  71.     for (int y=0; y<cells_in_y_dir; y++)
  72.     {
  73.         gradientStrengths[y] = new float*[cells_in_x_dir];
  74.         cellUpdateCounter[y] = new int[cells_in_x_dir];
  75.         for (int x=0; x<cells_in_x_dir; x++)
  76.         {
  77.             gradientStrengths[y][x] = new float[gradientBinSize];
  78.             cellUpdateCounter[y][x] = 0;
  79.  
  80.             for (int bin=0; bin<gradientBinSize; bin++)
  81.                 gradientStrengths[y][x][bin] = 0.0;
  82.         }
  83.     }
  84.  
  85.     // nr of blocks = nr of cells - 1
  86.     // since there is a new block on each cell (overlapping blocks!) but the last one
  87.     int blocks_in_x_dir = cells_in_x_dir - 1;
  88.     int blocks_in_y_dir = cells_in_y_dir - 1;
  89.  
  90.     // compute gradient strengths per cell
  91.     int descriptorDataIdx = 0;
  92.     int cellx = 0;
  93.     int celly = 0;
  94.  
  95.     for (int blockx=0; blockx<blocks_in_x_dir; blockx++)
  96.     {
  97.         for (int blocky=0; blocky<blocks_in_y_dir; blocky++)            
  98.         {
  99.             // 4 cells per block ...
  100.             for (int cellNr=0; cellNr<4; cellNr++)
  101.             {
  102.                 // compute corresponding cell nr
  103.                 int cellx = blockx;
  104.                 int celly = blocky;
  105.                 if (cellNr==1) celly++;
  106.                 if (cellNr==2) cellx++;
  107.                 if (cellNr==3)
  108.                 {
  109.                     cellx++;
  110.                     celly++;
  111.                 }
  112.  
  113.                 for (int bin=0; bin<gradientBinSize; bin++)
  114.                 {
  115.                     float gradientStrength = descriptorValues[ descriptorDataIdx ];
  116.                     descriptorDataIdx++;
  117.  
  118.                     gradientStrengths[celly][cellx][bin] += gradientStrength;
  119.  
  120.                 } // for (all bins)
  121.                 // note: overlapping blocks lead to multiple updates of this sum!
  122.                 // we therefore keep track how often a cell was updated,
  123.                 // to compute average gradient strengths
  124.                 cellUpdateCounter[celly][cellx]++;
  125.             } // for (all cells)
  126.         } // for (all block x pos)
  127.     } // for (all block y pos)
  128.  
  129.  
  130.     // compute average gradient strengths
  131.     for (int celly=0; celly<cells_in_y_dir; celly++)
  132.     {
  133.         for (int cellx=0; cellx<cells_in_x_dir; cellx++)
  134.         {
  135.             float NrUpdatesForThisCell = (float)cellUpdateCounter[celly][cellx];
  136.             // compute average gradient strenghts for each gradient bin direction
  137.             for (int bin=0; bin<gradientBinSize; bin++)
  138.             {
  139.                 gradientStrengths[celly][cellx][bin] /= NrUpdatesForThisCell;
  140.             }
  141.         }
  142.     }
  143.  
  144.     // draw cells
  145.     for (int celly=0; celly<cells_in_y_dir; celly++)
  146.     {
  147.         for (int cellx=0; cellx<cells_in_x_dir; cellx++)
  148.         {
  149.             int drawX = cellx * cellSize.width;
  150.             int drawY = celly * cellSize.height;
  151.  
  152.             int mx = drawX + cellSize.width/2;
  153.             int my = drawY + cellSize.height/2;
  154.  
  155.             rectangle(visual_image, Point(drawX*scaleFactor,drawY*scaleFactor), Point((drawX+cellSize.width)*scaleFactor,(drawY+cellSize.height)*scaleFactor), CV_RGB(100,100,100), 1);
  156.  
  157.             // draw in each cell all 9 gradient strengths
  158.             for (int bin=0; bin<gradientBinSize; bin++)
  159.             {
  160.                 float currentGradStrength = gradientStrengths[celly][cellx][bin];
  161.  
  162.                 // no line to draw?
  163.                 if (currentGradStrength==0)
  164.                     continue;
  165.  
  166.                 float currRad = bin * radRangeForOneBin + radRangeForOneBin/2;
  167.  
  168.                 float dirVecX = cos( currRad );
  169.                 float dirVecY = sin( currRad );
  170.                 float maxVecLen = cellSize.width/2;
  171.                 float scale = viz_factor; // just a visual_imagealization scale, to see the lines better
  172.  
  173.                 // compute line coordinates
  174.                 float x1 = mx - dirVecX * currentGradStrength * maxVecLen * scale;
  175.                 float y1 = my - dirVecY * currentGradStrength * maxVecLen * scale;
  176.                 float x2 = mx + dirVecX * currentGradStrength * maxVecLen * scale;
  177.                 float y2 = my + dirVecY * currentGradStrength * maxVecLen * scale;
  178.  
  179.                 // draw gradient visual_imagealization
  180.                 line(visual_image, Point(x1*scaleFactor,y1*scaleFactor), Point(x2*scaleFactor,y2*scaleFactor), CV_RGB(0,0,255), 1);
  181.             } // for (all bins)
  182.         } // for (cellx)
  183.     } // for (celly)
  184.  
  185.     // don't forget to free memory allocated by helper data structures!
  186.     for (int y=0; y<cells_in_y_dir; y++)
  187.     {
  188.       for (int x=0; x<cells_in_x_dir; x++)
  189.       {
  190.            delete[] gradientStrengths[y][x];            
  191.       }
  192.       delete[] gradientStrengths[y];
  193.       delete[] cellUpdateCounter[y];
  194.     }
  195.     delete[] gradientStrengths;
  196.     delete[] cellUpdateCounter;
  197.  
  198.     return visual_image;
  199. }
  200.  
  201. int _tmain(int argc, _TCHAR* argv[])
  202. {
  203.     // Check if arguments are given correct
  204.     if( argc == 1 || argc != 3){
  205.         printf( "Usage of SVM training software: \n"  
  206.                 "svm_train.exe <positive_training_samples.txt> <negative_training_samples.txt> <resulting_model.xml>\n");
  207.         return 0;
  208.     }
  209.  
  210.     // ****************************************************************************************************************************************
  211.     // PREPROCESSING
  212.     // ****************************************************************************************************************************************
  213.  
  214.     // Retrieve data from input arguments
  215.     string positive_file = argv[1];
  216.     string negative_file = argv[2];
  217.     string model_file = argv[3];
  218.  
  219.     // Create the HOG descriptor initialisation - configuration
  220.     HOGDescriptor hog;
  221.     Size window_size = Size(48,96); hog.winSize = window_size;
  222.     Size cell_size = Size(8,8); hog.cellSize = cell_size; hog.blockStride = cell_size;
  223.     Size block_size = Size(16,16); hog.blockSize = block_size;
  224.     int scale_factor = 2, viz_factor = 3;
  225.  
  226.     // ****************************************************************************************************************************************
  227.     // POSITIVE DATA - DESCRIPTORS TO RIGHT FORMAT
  228.     // ****************************************************************************************************************************************
  229.  
  230.     // Retrieve a list of positive file names
  231.     ifstream input (positive_file);
  232.     string current_line;
  233.     vector<string> filenames_positive;
  234.     while ( getline(input, current_line) ){
  235.         vector<string> line_elements;
  236.         stringstream temp (current_line);
  237.         string first_element;
  238.         getline(temp, first_element, ' ');
  239.         filenames_positive.push_back(first_element);
  240.     }
  241.     int number_pos_samples = filenames_positive.size();
  242.     input.close();
  243.  
  244.     // For each positive file, compute the descriptor, visualise it and store the descriptor
  245.     vector< vector<float> > all_positive_descriptors;
  246.     for(int i = 0; i < filenames_positive.size(); i++){
  247.         // Read and compute descriptors
  248.         Mat original = imread(filenames_positive[i]);
  249.    
  250.         vector<float> single_image_descriptor;
  251.         hog.compute(original, single_image_descriptor);
  252.  
  253.         // Visualise it
  254.         Mat image_with_descriptors = get_hogdescriptor_visual_image(original, single_image_descriptor, window_size, cell_size, scale_factor, viz_factor);
  255.         imshow("Visualize descriptors", image_with_descriptors);
  256.         int key = waitKey(15);
  257.         if( key == 27 ){
  258.             cout << "Processing aborted by pressing the ESC key!" << endl;
  259.             break;
  260.         }
  261.  
  262.         // Store it
  263.         all_positive_descriptors.push_back(single_image_descriptor);
  264.     }
  265.  
  266.     // Convert the vector of vectors into the correct format for the positive samples
  267.     Mat all_positive_descriptors_matrix (all_positive_descriptors.size(), all_positive_descriptors[0].size(), CV_32FC1);
  268.     for (size_t i = 0; i < all_positive_descriptors.size(); i++) {
  269.         for (size_t j = 0; j < all_positive_descriptors[i].size(); j++) {
  270.             all_positive_descriptors_matrix.at<float>(i, j) = all_positive_descriptors[i][j];
  271.         }
  272.     }
  273.  
  274.     // Output descriptor size for debug purposes
  275.     cout << "Descriptor size using preset parameters on this image = " << all_positive_descriptors[0].size() << endl;
  276.  
  277.     // ****************************************************************************************************************************************
  278.     // NEGATIVE DATA - DESCRIPTORS TO RIGHT FORMAT
  279.     // ****************************************************************************************************************************************
  280.  
  281.     // Retrieve a list of negative file names
  282.     input.open(negative_file);
  283.     vector<string> filenames_negative;
  284.     while ( getline(input, current_line) ){
  285.         vector<string> line_elements;
  286.         stringstream temp (current_line);
  287.         string first_element;
  288.         getline(temp, first_element, ' ');
  289.         filenames_negative.push_back(first_element);
  290.     }
  291.     int number_neg_samples = filenames_negative.size();
  292.     input.close();
  293.  
  294.     // For each negative file, compute the descriptor, visualise it and store the descriptor
  295.     vector< vector<float> > all_negative_descriptors;
  296.     for(int i = 0; i < filenames_negative.size(); i++){
  297.         // Read and compute descriptors
  298.         Mat original = imread(filenames_negative[i]);
  299.        
  300.         vector<float> single_image_descriptor;
  301.         hog.compute(original, single_image_descriptor);
  302.  
  303.         // Visualise it
  304.         Mat image_with_descriptors = get_hogdescriptor_visual_image(original, single_image_descriptor, window_size, cell_size, scale_factor, viz_factor);
  305.         imshow("Visualize descriptors", image_with_descriptors);
  306.         int key = waitKey(15);
  307.         if( key == 27 ){
  308.             cout << "Processing aborted by pressing the ESC key!" << endl;
  309.             break;
  310.         }
  311.  
  312.         // Store it
  313.         all_negative_descriptors.push_back(single_image_descriptor);
  314.     }
  315.  
  316.     // Convert the vector of vectors into the correct format for the negative samples
  317.     Mat all_negative_descriptors_matrix (all_negative_descriptors.size(), all_negative_descriptors[0].size(), CV_32FC1);
  318.     for (size_t i = 0; i < all_negative_descriptors.size(); i++) {
  319.         for (size_t j = 0; j < all_negative_descriptors[i].size(); j++) {
  320.             all_negative_descriptors_matrix.at<float>(i, j) = all_negative_descriptors[i][j];
  321.         }
  322.     }
  323.  
  324.     // ****************************************************************************************************************************************
  325.     // COMBINE BOTH SETS AND PROVIDE CORRECT LABELS + WEIGHTS WHEN PREFERRED
  326.     // Only adapt weights if you know what the influence of this parameter is!
  327.     // ****************************************************************************************************************************************
  328.    
  329.     Mat inputs, labels;
  330.     Mat labels_pos = Mat::ones(number_pos_samples, 1, CV_32FC1);   
  331.     Mat labels_neg = Mat::ones(number_neg_samples, 1, CV_32FC1) * -1.0;
  332.  
  333.     vconcat(all_positive_descriptors_matrix, all_negative_descriptors_matrix, inputs);
  334.     vconcat(labels_pos, labels_neg, labels);
  335.  
  336.     cv::Mat1f weights(1,2); weights(0,0) = 1; weights(0,1) = 1;
  337.  
  338.     // ****************************************************************************************************************************************
  339.     // TRAIN A SVM WITH THIS DATA
  340.     // ****************************************************************************************************************************************
  341.    
  342.     // Configuring the SVM for training purposes
  343.     CvSVMParams params;
  344.     params.svm_type = CvSVM::C_SVC;
  345.     params.kernel_type = CvSVM::LINEAR;
  346.     params.gamma = 20;
  347.     params.degree = 0;
  348.     params.coef0 = 0;
  349.     params.C = 1000; //Take a large punishment for misclassification
  350.     params.nu = 0.0;
  351.     params.p = 0.0;
  352.     params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, 1e-6);
  353.     CvMat old_weights = weights; params.class_weights = &old_weights;
  354.  
  355.     // Train the SVM
  356.     CvSVM SVM_model;
  357.     SVM_model.train(inputs, labels, Mat(), Mat(), params);
  358.  
  359.     cout << "Training done!" << endl;
  360.     cout << "Saving the SVM model!" << endl;
  361.  
  362.     stringstream store_location;
  363.     store_location << model_file;
  364.  
  365.     SVM_model.save(store_location.str().c_str());
  366.  
  367.     return 0;
  368. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement