Advertisement
Guest User

SVM training

a guest
Apr 29th, 2014
271
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 4.83 KB | None | 0 0
  1. //Start the contour detection
  2. void contourDetection(){
  3.    
  4.     //Vector that represents the hierarchy of the contours
  5.     vector< Vec4i > hierarchy;
  6.    
  7.     ///Find contours and store in a contours, then store the contour hierarchies in hierarchy
  8.     cv::findContours(dst, contours, hierarchy, CV_RETR_TREE, CV_CHAIN_APPROX_NONE);
  9.  
  10.     double area;
  11.     int count = 0, all = 0;
  12.     CvSVM detector;
  13.     bool exit = false;
  14.     vector<vector<double>> allData;
  15.  
  16.     vector<float> labels;
  17.  
  18.     //Load the training data
  19.     detector.load("Muskrat_Counting_SVM.xml");
  20.  
  21.     //Size of the whole picture
  22.     int totSize = dst.total();
  23.  
  24.     //Iterate through all parent contours
  25.     for (int i = 0; i < contours.size() && exit == false; i = hierarchy[i][0])
  26.     {
  27.  
  28.         //Find the size of the outer contours (lakes)
  29.         area = contourArea(contours[i]);
  30.  
  31.         //Find the biggest contours (presumed to be the lakes)**Needs optimization
  32.         if (area > totSize*.20)// && area < totSize*.5)
  33.         {
  34.             //Draw the lake
  35.             draw(contours[i], "red");                                              
  36.  
  37.             //Using hierarchy array to determine hierarchy, this loop is for the child contours
  38.             for (int j = hierarchy[i][2]; hierarchy[j][0] > 0 && exit == false; j = hierarchy[j][0]){
  39.  
  40.                 //Find contour size
  41.                 double contArea = contourArea(contours[j]);
  42.  
  43.                 if (contArea > 20 && contArea < 300){                               //When size of the contour is inbetween those area
  44.  
  45.  
  46.                     //Get the location of the center of the contour for zooming
  47.                     Moments mu = moments(contours[j], false);
  48.                     pCen = Point2f(mu.m10 / mu.m00, mu.m01 / mu.m00);
  49.  
  50.                     all++;
  51.  
  52.                     //Draw current contour
  53.                     draw(contours[j], "orange");
  54.  
  55.                     //Evaluate next contour
  56.                     zoom(Mat(frame));
  57.  
  58.                     //Output question
  59.                     cout << "Is this a contour? (Y or N) ";
  60.                     string input;
  61.  
  62.                     while (exit == false){
  63.  
  64.                         //Retrieve user input
  65.                         getline(cin, input);
  66.                         cout << endl;
  67.                         cout << "\n\nGetting next contour...";
  68.                         //If user enters Y
  69.                         if (strUpper(input) == "Y" || strUpper(input) == "N")
  70.                         {
  71.                             if (strUpper(input) == "Y"){
  72.                                 //Draw the correct contour green
  73.                                 draw(contours[j], "green");
  74.                                 //Count the drawn contours
  75.                                 count++;
  76.  
  77.                                 labels.push_back(1);
  78.                             }
  79.                             else{
  80.                                 draw(contours[j], "red");
  81.                                 labels.push_back(-1);
  82.                             }
  83.  
  84.                             ///////////////////////###################Get training data############################////////////////////////////
  85.                             //Get aspect ratio
  86.                             Rect rect = boundingRect(contours[j]);
  87.                             double aspect_ratio = float(rect.width) / rect.height;
  88.  
  89.                             //Get extent
  90.                             double rect_area = float(rect.width)*rect.height;
  91.                             double extent = float(contArea) / rect_area;
  92.  
  93.                             //Get solidity
  94.                             vector<Point> hull;
  95.                             convexHull(contours[j], hull);
  96.                             double hull_area = contourArea(hull);
  97.                             double solidity = float(contArea) / hull_area;
  98.  
  99.                             //Get equivalent diameter
  100.                             double equi_diameter = sqrt(4 * contArea / 3.14159265);
  101.  
  102.                             //Perimeter
  103.                             double perimeter = arcLength(contours[j], true);
  104.  
  105.                             //Compile data
  106.                             vector<double> data = { aspect_ratio, extent, solidity, equi_diameter, perimeter, contArea };
  107.  
  108.                             allData.push_back(data);
  109.  
  110.                             //For prediction
  111.                             /*float response = detector.predict(trainingData);
  112.  
  113.                             if (response == 1)
  114.                             {
  115.                                 draw(contours[j], "green");
  116.                                 count++;
  117.                                 all++;
  118.                             }
  119.                             else if (response == -1)
  120.                             {
  121.                                 draw(contours[j], "red");
  122.                                 all++;
  123.                             }*/
  124.                             ////////////////////////////////////////////.........
  125.  
  126.                             break;
  127.                         }
  128.  
  129.                         //Identified as not a push-up (False detection)
  130.                         /*else if (strUpper(input) == "N")
  131.                         {
  132.                             draw(contours[j], "red");
  133.                             break;
  134.                         }*/
  135.                         //Exits training process
  136.                         else if (strUpper(input) == "QUIT")
  137.                             exit = true;
  138.                         //Incorrect entry
  139.                         else
  140.                             cout << endl << "Enter Y or N..." << endl;
  141.                        
  142.                     }
  143.                 }
  144.             }
  145.         }
  146.     }
  147.     int amount = allData.size();
  148.  
  149.     //Produce matrix containing training data
  150.     Mat trainingData(amount, 6, CV_32FC1, &allData);
  151.  
  152.     //Array to label data
  153.     Mat labelsMat(amount, 1, CV_32FC1, &labels);
  154.  
  155.     CvSVMParams params;                                                 //
  156.     params.svm_type = CvSVM::C_SVC;                                     //Required data for training
  157.     params.kernel_type = CvSVM::LINEAR;                                 //
  158.     //params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);       //
  159.  
  160.     //Append training data to SVM
  161.     detector.train(trainingData, labelsMat, Mat(), Mat(), params);
  162.     //detector.train_auto(trainingData, labelsMat, Mat(), Mat(), params);
  163.  
  164.     //Save the machine learning data
  165.     detector.save("Muskrat_Counting_SVM.xml");
  166.     system("cls");
  167.     info();
  168.     cout << "\n\nTotal Circled: " << count << endl;     //Output how many contours that were detected
  169.     cout << "Total: " << all << endl;               //Total amount of contours before refinement
  170. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement