Advertisement
Guest User

Untitled

a guest
Dec 11th, 2019
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.71 KB | None | 0 0
  1. void bayesian_classifier() {
  2. Mat img;
  3. const int featureSize = 28 * 28;
  4. const int noOfInstances = 60000;
  5. const int noOfClasses = 10;
  6. const int noOfTestInstances = 10000;
  7. Mat X = Mat(noOfInstances, featureSize, CV_8UC1);
  8. Mat XTest = Mat(noOfTestInstances, featureSize, CV_8UC1);
  9. int rowX = 0;
  10. int rowXTest = 0;
  11. Mat y(noOfInstances, 1, CV_8UC1);
  12. Mat yTest(noOfTestInstances, 1, CV_8UC1);
  13.  
  14. char classes[noOfClasses][10] =
  15. { "0", "1", "2", "3", "4", "5", "6", "7", "8", "9" };
  16.  
  17. //to hold a-priori probability
  18. double priors[5];
  19. int elementsOfClassTrain[5];
  20. int priorNr;
  21. //load the train instances
  22. char fname[256];
  23. for (int i = 0; i < noOfClasses; i++) {
  24. priorNr = 0;
  25. for(;;)
  26. {
  27. sprintf(fname, "images_bayes/train/%s/%06d.png", classes[i], priorNr);
  28. priorNr++;
  29. img = imread(fname, CV_8UC1);
  30. if (0 == img.cols )
  31. {
  32. std::cout << "\n All photos " << i << " read from train folder " << priorNr;
  33. break;
  34. }
  35. Mat binary = grayToBinary(img, 128);
  36.  
  37. int d = 0;
  38. for (int r = 0; r < binary.rows; r++) {
  39. for (int c = 0; c < binary.cols; c++) {
  40. X.at<uchar>(rowX, d) = binary.at<uchar>(r, c);
  41. d++;
  42. }
  43. }
  44. y.at<uchar>(rowX) = i;
  45. rowX++;
  46. }
  47. priors[i] = priorNr / (double)noOfInstances;
  48. elementsOfClassTrain[i] = priorNr;
  49. }
  50. char classesT[noOfClasses][10] =
  51. { "0", "1", "2", "3", "4", "5", "6", "7", "8", "9" };
  52. //loatd the test instances
  53. char fnameTest[256];
  54. for (int i = 0; i < noOfClasses; i++) {
  55. priorNr = 0;
  56. while (1) {
  57. sprintf(fnameTest, "images_bayes/test/%s/%06d.png", classesT[i], priorNr);
  58. priorNr++;
  59. Mat img = imread(fnameTest, CV_8UC1);
  60. if (img.cols == 0) {
  61. std::cout << "\n All photos from class " << i << " read from test folder : " << priorNr;
  62. break;
  63. }
  64. Mat binary = grayToBinary(img, 128);
  65.  
  66. int d = 0;
  67. for (int r = 0; r < binary.rows; r++) {
  68. for (int c = 0; c < binary.cols; c++) {
  69. XTest.at<uchar>(rowXTest, d) = binary.at<uchar>(r, c);
  70. d++;
  71. }
  72. }
  73. yTest.at<uchar>(rowXTest) = i;
  74. rowXTest++;
  75. }
  76. }
  77.  
  78. //compute likelihood
  79. //w/ laplace smoothing
  80. Mat likelihood = Mat::zeros(noOfClasses, featureSize, CV_64FC1);
  81. for (int k = 0; k < noOfInstances; k++) {
  82. for (int d = 0; d < featureSize; d++) {
  83. if (X.at<uchar>(k, d) == 255) {
  84. likelihood.at<double>((int)y.at<uchar>(k), d) += 1.0;
  85. }
  86. }
  87. }
  88. //laplace smoothing
  89. for (int r = 0; r < likelihood.rows; r++) {
  90. for (int c = 0; c < likelihood.cols; c++) {
  91. double value = likelihood.at<double>(r, c) + 1.0;
  92. likelihood.at<double>(r, c) = (value / (double)(noOfClasses + elementsOfClassTrain[r]));
  93. }
  94. }
  95.  
  96. //classify test images
  97. Mat C = Mat::zeros(noOfClasses, noOfClasses, CV_32F);
  98. for (int count = 0; count < noOfTestInstances; count++) {
  99. Mat randImg = XTest.row(count);
  100. double classProbs[noOfClasses];
  101. for (int c = 0; c < noOfClasses; c++) {
  102. classProbs[c] = log(priors[c]);
  103. for (int j = 0; j < featureSize; j++) {
  104. if (randImg.at<uchar>(0, j) == 255) {
  105. classProbs[c] += log(likelihood.at<double>(c, j));
  106. }
  107. else {
  108. classProbs[c] += log(1.0f - likelihood.at<double>(c, j));
  109. }
  110. }
  111. }
  112. double max = *std::max_element(classProbs, classProbs + noOfClasses);
  113. int predictedClass = -1;
  114. for (int i = 0; i < noOfClasses; i++) {
  115. if (max == classProbs[i]) {
  116. predictedClass = i;
  117. }
  118. }
  119. C.at<float>(predictedClass, yTest.at<uchar>(count)) += 1.0;
  120. }
  121. double accuracy = getAccuracyFromConfusionMatrix(C);
  122. std::cout << "\nAccuracy: " << accuracy << std::endl;
  123. cout << endl;
  124. for (int i = 0; i < C.rows; i++) {
  125. for (int j = 0; j < C.cols; j++)
  126. cout << C.at<float>(i, j) << " ";
  127.  
  128. cout << endl;
  129. }
  130. waitKey(0);
  131. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement