lucasamparo

ANN_MLP with OpenCV

May 12th, 2016
164
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  #include <cfloat>
  #include <iostream>

  #include <opencv2/highgui.hpp>
  #include <opencv2/ml.hpp>

  using namespace cv;
  using namespace cv::ml;
  using namespace std;

  8. int main( int argc, char** argv ){
  9.     Mat matTrainFeatures(100,1,CV_32F);
  10.     randu(matTrainFeatures,0,100);
  11.    
  12.     Mat matTrainLabels(100,1,CV_32F);
  13.     randu(matTrainLabels,0,100);
  14.    
  15.     Mat matSample(4,1,CV_32F);
  16.     //randu(matSample,0,100);
  17.     matSample.at<float>(0,0) = matTrainFeatures.at<float>(0,0);
  18.     matSample.at<float>(1,0) = matTrainFeatures.at<float>(1,0);
  19.     matSample.at<float>(2,0) = matTrainFeatures.at<float>(2,0);
  20.     matSample.at<float>(3,0) = matTrainFeatures.at<float>(3,0);
  21.    
  22.     Mat matSampleLabels(1,1,CV_32F);
  23.  
  24.     Mat matResults(4,1,CV_32F);
  25.     //Mat matResults(5,1,CV_32F);
  26.  
  27.     Ptr<TrainData> trainingData;
  28.     trainingData=TrainData::create(matTrainFeatures,ROW_SAMPLE,matTrainLabels);
  29.  
  30.     Ptr<ANN_MLP> lr = ANN_MLP::create();
  31.     Mat layers(3,1,CV_32FC1);
  32.     layers.row(0) = 1;
  33.     layers.row(1) = 100;
  34.     layers.row(2) = 1;
  35.     lr->setBackpropMomentumScale(0.05f);
  36.     lr->setLayerSizes(layers);
  37.     lr->setBackpropWeightScale(0.05f);
  38.     lr->setActivationFunction(ANN_MLP::SIGMOID_SYM, 1, 1);
  39.     lr->setTrainMethod(ANN_MLP::BACKPROP, 0.001);
  40.     lr->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON));
  41.    
  42.     //lr->train(trainingData);
  43.     for(int j = 0; j < 12; j++){
  44.         for(int i = 0; i < matTrainFeatures.rows; i++){
  45.             lr->train(matTrainFeatures.row(i),ROW_SAMPLE,matTrainLabels.row(i));   
  46.         }  
  47.     }  
  48.     lr->predict(matSample,matResults);
  49.    
  50.     //Just checking the settings
  51.     cout<<"Training data: "<<endl
  52.         <<"getNSample\t"<<trainingData->getNSamples()<<endl
  53.         <<"getSamples\n"<<trainingData->getSamples()<<endl
  54.         <<"getResponses\n"<<trainingData->getTrainResponses()<<endl
  55.         <<endl;
  56.  
  57.     //confirming sample order
  58.     cout<<"matSample: "<<endl
  59.         <<matSample<<endl
  60.         <<endl;
  61.  
  62.     //displaying the results
  63.     cout<<"matResults: "<<endl
  64.         <<matResults<<endl
  65.         <<endl;
  66.    
  67.     return 0;
  68. }
RAW Paste Data