Advertisement
makispaiktis

ML - Naive Bayes Classifier

Oct 10th, 2022 (edited)
1,081
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
MatLab 5.66 KB | None | 0 0
  1. clear all
  2. clc
  3.  
  4.  
  5. % *************************************************************************
  6. % *************************************************************************
  7. % Theory
  8. % *************************************************************************
  9. % *************************************************************************
  10. % max{ P(C | A1, A2, A3, A4) }      <-- Bayes -->
  11. % max{ P(A1, A2, A3, A4 | C) * P(C) }
  12. % P(A1, A2, A3, A4 | C) = P(A1|C) * P(A2|C) * P(A3|C) * P(A4|C)
  13. % Y = Yes, N = No, S = Sometimes, M = Mammals, NM = Non-Mammals
  14.  
  15.  
  16.  
  17. % *************************************************************************
  18. % *************************************************************************
  19. % Database
  20. % *************************************************************************
  21. % *************************************************************************
  22. chars = ["Name", "Give Birth", "Can Fly", "Live in water", "Have legs", "Class"];
  23. names = ["Human", "Python", "Salmon", "Whale", "Frog", "Komodo", "Bat", "Pigeon", "Cat", "Leopard Shark", "Turtle", "Penguin", "Porcupine", "Eel", "Salamander", "Gila Monster", "Platypus", "Owl", "Dolphin", "Eagle"]';
  24. births = ["Y", "N", "N", "Y", "N", "N", "Y", "N", "Y", "Y", "N", "N", "Y", "N", "N", "N", "N", "N", "Y", "N"]';
  25. flys = ["N", "N", "N", "N", "N", "N", "Y", "Y", "N", "N", "N", "N", "N", "N", "N", "N", "N", "Y", "N", "Y"]';
  26. waters = ["N", "N", "Y", "Y", "S", "N", "N", "N", "N", "Y", "S", "S", "N", "Y", "S", "N", "N", "N", "Y", "N"]';
  27. legss = ["Y", "N", "N", "N", "Y", "Y", "Y", "Y", "Y", "N", "Y", "Y", "Y", "N", "Y", "Y", "Y", "Y", "N", "Y"]';
  28. classes = ["M", "NM", "NM", "M", "NM", "NM", "M", "NM", "M", "NM", "NM", "NM", "M", "NM", "NM", "NM", "M", "NM", "M", "NM"]';
  29. matrix = [names births flys waters legss classes]
  30. ROWS = size(matrix, 1);
  31. COLS = size(matrix, 2);
  32.  
  33. un_births = unique(births);
  34. un_flys = unique(flys);
  35. un_waters = unique(waters);
  36. un_legss = unique(legss);
  37.  
  38. % Random input
  39. name = "Random Animal";
  40. birth = un_births(randi([1 length(un_births)]));
  41. fly = un_flys(randi([1 length(un_flys)]));
  42. water = un_waters(randi([1 length(un_waters)]));
  43. legs = un_legss(randi([1 length(un_legss)]));
  44. input = [name birth fly water legs];
  45. % input = [name "Y" "N" "Y" "N"]
  46. disp(mat2str(input));
  47.  
  48. % First, I have to count how many mammals and non-mammals are there
  49. MAM_VEC = matrix(:, COLS);
  50. indices_M = find(MAM_VEC == "M");
  51. indices_NM = find(MAM_VEC == "NM");
  52. num_M = length(indices_M);
  53. num_NM = length(indices_NM);
  54. prob_M = num_M / ROWS;
  55. prob_NM = num_NM / ROWS;
  56.  
  57.  
  58.  
  59.  
  60.  
  61.  
  62. % *************************************************************************
  63. % *************************************************************************
  64. % Algorithm
  65. % *************************************************************************
  66. % *************************************************************************
  67.  
  68.  
  69. % Let class 1 be the correct class ("M" = "Mammals")
  70. % I will check in the indices of "indices_M"
  71. for i = 1 : num_M
  72.     index = indices_M(i);
  73.     MAM_BIRTHS_VEC(i) = births(index);
  74. end
  75. num1_M = length(find(MAM_BIRTHS_VEC == input(2)));
  76.  
  77. for i = 1 : num_M
  78.     index = indices_M(i);
  79.     MAM_FLYS_VEC(i) = flys(index);
  80. end
  81. num2_M = length(find(MAM_FLYS_VEC == input(3)));
  82.  
  83. for i = 1 : num_M
  84.     index = indices_M(i);
  85.     MAM_WATERS_VEC(i) = waters(index);
  86. end
  87. num3_M = length(find(MAM_WATERS_VEC == input(4)));
  88.  
  89. for i = 1 : num_M
  90.     index = indices_M(i);
  91.     MAM_LEGSS_VEC(i) =legss(index);
  92. end
  93. num4_M = length(find(MAM_LEGSS_VEC == input(5)));
  94. P1 = PROBABILITY(ROWS, prob_M, num_M, num1_M, num2_M, num3_M, num4_M, "Mammal");
  95.  
  96.  
  97. % Let class 2 be the correct class ("NM" = "Non-Mammals")
  98. for i = 1 : num_NM
  99.     index = indices_NM(i);
  100.     NMAM_BIRTHS_VEC(i) = births(index);
  101. end
  102. num1_NM = length(find(NMAM_BIRTHS_VEC == input(2)));
  103.  
  104. for i = 1 : num_NM
  105.     index = indices_NM(i);
  106.     NMAM_FLYS_VEC(i) = flys(index);
  107. end
  108. num2_NM = length(find(NMAM_FLYS_VEC == input(3)));
  109.  
  110. for i = 1 : num_NM
  111.     index = indices_NM(i);
  112.     NMAM_WATERS_VEC(i) = waters(index);
  113. end
  114. num3_NM = length(find(NMAM_WATERS_VEC == input(4)));
  115.  
  116. for i = 1 : num_NM
  117.     index = indices_NM(i);
  118.     NMAM_LEGSS_VEC(i) =legss(index);
  119. end
  120. num4_NM = length(find(NMAM_LEGSS_VEC == input(5)));
  121. P2 = PROBABILITY(ROWS, prob_NM, num_NM, num1_NM, num2_NM, num3_NM, num4_NM, "Non-Mammal");
  122. compare(P1, P2, input);
  123.  
  124.  
  125. % AUXILIARY FUNCTIONS
  126. function prob = PROBABILITY(ROWS, prob_M, num_M, num1_M, num2_M, num3_M, num4_M, STRING)
  127.     if(STRING == "Mammal")
  128.         abbr = " M";
  129.     else
  130.         abbr = "NM";
  131.     end
  132.     prob1 = num1_M / num_M;
  133.     prob2 = num2_M / num_M;
  134.     prob3 = num3_M / num_M;
  135.     prob4 = num4_M / num_M;
  136.     disp("**************** " + STRING + " ****************");
  137.     P_A_M = prob1 * prob2 * prob3 * prob4;
  138.     disp("Pr(A|" + abbr + ") = " + num2str(num1_M) + "/" + num2str(num_M) + " * " + num2str(num2_M) + "/" + num2str(num_M) + " * " + num2str(num3_M) + "/" + num2str(num_M) + " * " + num2str(num4_M) + "/" + num2str(num_M) + " = ");
  139.     disp("         = " + num2str(prob1) + " * " + num2str(prob2) + " * " + num2str(prob3) + " * " + num2str(prob4) + " = " + P_A_M);
  140.     disp("Pr(" + abbr + ") = " + num2str(num_M) + "/" + num2str(ROWS) + " = " + num2str(prob_M));
  141.     prob = P_A_M * prob_M;
  142.     display(' ');
  143.     disp("Pr(A|" + abbr + ")Pr(" + abbr + ") = " + num2str(prob));
  144.     display(' ');
  145.     display(' ');
  146. end
  147.  
  148. function compare(P1, P2, input)
  149.     if P1 >= P2
  150.         disp(mat2str(input) + " = Mammal");
  151.         disp("because: " + num2str(P1) + " >= " + num2str(P2));
  152.     else
  153.         disp(mat2str(input) + " = Non-Mammal");
  154.         disp("because: " + num2str(P1) + " < " + num2str(P2));
  155.     end
  156. end
  157.  
  158.  
  159.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement