Advertisement
Fabio_LaF

Classes dos classificadores

Aug 17th, 2022
619
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 10.39 KB | None | 0 0
  1. # file name: model.py
  2.  
  3. import pgConnector
  4. import math
  5. import dataParser
  6. import attribute
  7. import sys
  8. import prim
  9. from typing import TypeVar, Generic, List
  10.  
  11. T = TypeVar('T') # Generics
  12.  
  13. # ---------------------------------------------------------------------------------------------------------
  14. # ---------------------------------------------------------------------------------------------------------
  15. # NB
  16. class NaiveBayes(Generic[T]):
  17.   # constructor
  18.   def __init__(self, classes: List[T], db_connector: pgConnector.PgConnector, table_name: str, class_field:str = "class", excluded_fields:List[str] = []):
  19.     self.classes         = classes                          # List of all classes
  20.     self.class_probs     = {}                               # Dictionary whose elements are as such: [class_name: 0]
  21.     self.attr_probs      = {}                               # Dictionary containing each p(x_j | C_i), one of the main outputs of the model
  22.     self.db_connector    = db_connector                     # Object that creates the interface between the model and the database
  23.     self.table_name      = table_name                       # Name of the table where data is located
  24.     self.excluded_fields = [class_field] + excluded_fields  # Those fields will not be considered by the model in the training process
  25.     self.attr_amt        = {}                               # Dict containing amount of each attribute
  26.     self.c_i_attr_amt    = {}                               # Dict containing amount of each attribute classified as c_i on the training data
  27.  
  28.     self.table_len = self.db_connector.do_query("select count(*) from " + self.table_name)[0][0]
  29.  
  30.     for c_i in self.classes:
  31.       self.class_probs[c_i] = 0
  32.  
  33.   def _calculate_class_probs(self):
  34.     for c_i in self.classes:
  35.       c_i_amt = self.db_connector.do_query("select count(*) from " + self.table_name + " where " + self.excluded_fields[0] + " = " + str(c_i))[0][0]
  36.       self.class_probs[c_i] = c_i_amt/self.table_len
  37.  
  38.       if (self.class_probs[c_i] == 0):
  39.         raise Exception("Bad training sample, no instances of class \"" + str(c_i) + "\" found!")
  40.  
  41.   # Using the Laplace Smoothing to estimate both probabilities
  42.   def _smoothed_p_c_i_attr(self, c_i_attr_amt: int, attr_amt: int):
  43.     pseudocount = 1
  44.     amt_classes = len(self.attr_probs.keys())
  45.    
  46.     return (c_i_attr_amt + pseudocount)/(attr_amt + amt_classes * pseudocount)
  47.  
  48.   def _smoothed_p_attr(self, attr_amt: int, amt_rows: int):
  49.     pseudocount = 1
  50.     amt_classes = len(self.attr_probs.keys())
  51.  
  52.     return (attr_amt + pseudocount * amt_classes)/(amt_rows + pseudocount * amt_classes)
  53.  
  54.   # calculates the probability p(x_j | C_i) for a single attribute x_j and a single class C_i
  55.   def _attr_given_c(self, attr: attribute.Attribute, c_i:T, is_training:bool):
  56.     if (attr in self.attr_probs and c_i in self.attr_probs[attr]):
  57.       return self.attr_probs[attr][c_i]
  58.  
  59.     if is_training and attr not in self.attr_probs:
  60.       self.attr_probs[attr] = {}
  61.  
  62.     c_i_attr_amt = self.c_i_attr_amt[attr, c_i] if (attr, c_i) in self.c_i_attr_amt else 0
  63.     attr_amt     = self.attr_amt[attr] if attr in self.attr_amt else 0
  64.  
  65.     # p(c_i | x_j)
  66.     p_c_i_attr = self._smoothed_p_c_i_attr(c_i_attr_amt, attr_amt)
  67.     # p(x_j)
  68.     p_attr     = self._smoothed_p_attr(attr_amt, self.table_len)
  69.  
  70.     # calculating p(x_j | C_i) itself
  71.     prob = (p_c_i_attr * p_attr)/self.class_probs[c_i]
  72.  
  73.     if is_training:
  74.       self.attr_probs[attr][c_i] = prob
  75.    
  76.     return prob
  77.  
  78.   # Here, @self.table_name must be a table containing labeled data.
  79.   # The labels must coincide in name with the ones in @self.class_probs
  80.   def train(self):
  81.     print("Naive Bayes Classifier training started...")
  82.  
  83.     self._calculate_class_probs()
  84.  
  85.     dp = dataParser.DataParser(self.db_connector)
  86.     _, parsed_data = dp.parse_count(self.table_name, self.excluded_fields[0], self.excluded_fields)
  87.  
  88.     amt_attrs = 1
  89.  
  90.     for (attr, c_i, c_i_attr_amt, attr_amt) in parsed_data:
  91.       self.attr_amt[attr]          = attr_amt
  92.       self.c_i_attr_amt[attr, c_i] = c_i_attr_amt
  93.       self._attr_given_c(attr, c_i, True)
  94.  
  95.       print_str = "\r Currently working on attribute number " + str(amt_attrs)
  96.       print_str += " of " + str(len(parsed_data))
  97.  
  98.       print(print_str, end='')
  99.       sys.stdout.flush()
  100.      
  101.       amt_attrs+=1
  102.  
  103.     print("")
  104.     print("Naive Bayes Classifier training finished!")
  105.  
  106.   def classify(self, obj:List[attribute.Attribute]):
  107.     probs = {}
  108.     for c_i in self.classes:
  109.       probs[c_i] = self.class_probs[c_i]
  110.      
  111.       for attr in obj:
  112.         probs[c_i] *= self._attr_given_c(attr, c_i, False)
  113.  
  114.     return probs
  115.  
  116. # ---------------------------------------------------------------------------------------------------------
  117. # ---------------------------------------------------------------------------------------------------------
  118. # TAN
  119. class TreeAugmentedNB(NaiveBayes, Generic[T]):
  120.   def __init__(self, classes: List[T], db_connector: pgConnector.PgConnector, table_name: str, class_field:str = "class", excluded_fields:List[str] = []):
  121.     super().__init__(classes, db_connector, table_name, class_field, excluded_fields)
  122.  
  123.     self.temp_pairs  = {}   # Dict containing how pairs appear how many times, e.g. self.temp_pairs[n] = x means that x pairs appear n times
  124.     self.pair_amt    = {}   # Dict containing the amount of each pair
  125.     self.pair_ck_amt = {}   # Dict containing the amount of each pair on each class
  126.     self.mutInfo     = {}   # Dict containing each I(X; Y | Z)
  127.     self.trees       = {}   # Dict containing the TAN for each class
  128.  
  129.   def __update_multual_info(self, pair_data: tuple[attribute.Attribute, attribute.Attribute, T, int, int], dict_attr_ck: dict):
  130.     xi          = pair_data[0]
  131.     yj          = pair_data[1]
  132.     ck          = pair_data[2]
  133.     pair_ck_amt = pair_data[3]
  134.     pair_amt    = pair_data[4]
  135.  
  136.     xi_field = xi.field_name
  137.     yj_field = yj.field_name
  138.  
  139.     xi_amt = dict_attr_ck[(xi, ck)][1] if (xi, ck) in dict_attr_ck.keys() else 0
  140.     yj_amt = dict_attr_ck[(yj, ck)][1] if (yj, ck) in dict_attr_ck.keys() else 0
  141.  
  142.     if (xi_field, yj_field) not in self.mutInfo.keys():
  143.       self.mutInfo[(yj_field, xi_field)] = {}
  144.       self.mutInfo[(xi_field, yj_field)] = {}
  145.     if ck not in self.mutInfo[(xi_field, yj_field)].keys():
  146.       self.mutInfo[(xi_field, yj_field)][ck] = 0
  147.       self.mutInfo[(yj_field, xi_field)][ck] = 0
  148.  
  149.     # P(x_i, y_j | c_k)
  150.     prob_xi_yj_given_ck = (self._smoothed_p_c_i_attr(pair_ck_amt, pair_amt) * self._smoothed_p_attr(xi_amt, self.table_len) * self._smoothed_p_attr(yj_amt, self.table_len))/self.class_probs[ck]
  151.  
  152.     # P(x_i, y_j, c_k)
  153.     p_x_y_z             = prob_xi_yj_given_ck * self.class_probs[ck]
  154.  
  155.     # P(x_i | c_k)
  156.     prob_xi_given_ck    = self._attr_given_c(xi, ck, False)
  157.  
  158.     # P(y_j | c_k)
  159.     prob_yj_given_ck    = self._attr_given_c(yj, ck, False)
  160.  
  161.     self.mutInfo[(xi_field, yj_field)][ck] += p_x_y_z * math.log(prob_xi_yj_given_ck/(prob_xi_given_ck*prob_yj_given_ck))
  162.     self.mutInfo[(yj_field, xi_field)][ck] += p_x_y_z * math.log(prob_xi_yj_given_ck/(prob_xi_given_ck*prob_yj_given_ck))
  163.  
  164.   def train(self):
  165.     print("Tree Augmented Naive Bayes Classifier training started...")
  166.     super()._calculate_class_probs()
  167.     super().train()
  168.  
  169.     dp = dataParser.DataParser(self.db_connector)
  170.  
  171.     # (xi, yj, ck, #(xi, yj) that are ck, #(xi, yj))
  172.     pair_parsed_data = dp.parse_count_tan(self.table_name, self.excluded_fields[0], self.excluded_fields)
  173.  
  174.     dict_attr_ck = dp.parse_count(self.table_name, self.excluded_fields[0], self.excluded_fields, True)[1]
  175.  
  176.     # Calculating Mutual infor between each pair of attributes
  177.     for (i, pair) in enumerate(pair_parsed_data):
  178.       xi          = pair[0]
  179.       yj          = pair[1]
  180.       ck          = pair[2]
  181.       ck_pair_amt = pair[3]
  182.       pair_amt    = pair[4]
  183.  
  184.       self.pair_amt[(xi, yj)] = pair_amt
  185.       self.pair_amt[(yj, xi)] = pair_amt
  186.  
  187.       if (xi, yj) not in self.pair_ck_amt:
  188.         self.pair_ck_amt[(xi, yj)] = {}
  189.         self.pair_ck_amt[(yj, xi)] = {}
  190.  
  191.       self.pair_ck_amt[(xi, yj)][ck] = ck_pair_amt
  192.       self.pair_ck_amt[(yj, xi)][ck] = ck_pair_amt
  193.  
  194.       self.__update_multual_info(pair, dict_attr_ck)
  195.  
  196.       print_str = "\r Currently working on pair #{}".format(i+1)
  197.       print_str += " of " + str(len(pair_parsed_data))
  198.  
  199.       print(print_str, end='')
  200.       sys.stdout.flush()
  201.  
  202.     print("")
  203.  
  204.     # Building a maximum weight tree for each class
  205.     nodes = list(set([key for sublist in self.mutInfo.keys() for key in sublist]))
  206.  
  207.     weights_by_class = {}
  208.     for c in self.classes:
  209.       weights_by_class[c] = {}
  210.  
  211.       for key in self.mutInfo.keys():
  212.         weights_by_class[c][key] = self.mutInfo[key][c]
  213.  
  214.     for c in self.classes:
  215.       self.trees[c] = prim.MaxTree(nodes, weights_by_class[c])
  216.    
  217.     print("Tree Augmented Naive Bayes Classifier training finished!\n")
  218.  
  219.   def classify(self, obj: List[attribute.Attribute]):
  220.     probs = {}
  221.  
  222.     for c_k in self.classes:
  223.       probs[c_k] = self.class_probs[c_k]
  224.  
  225.       for attr in obj:
  226.         if attr.field_name == self.trees[c_k].tree_root:
  227.           probs[c_k] *= self._attr_given_c(attr, c_k, False)
  228.         else:
  229.           parent = self.trees[c_k].attr_parent_of(attr, obj)
  230.  
  231.           attr_amt    = self.attr_amt[attr] if attr in self.attr_amt else 0
  232.           parent_amt  = self.attr_amt[parent] if parent in self.attr_amt else 0
  233.           pair_amt    = self.pair_amt[(attr, parent)] if (attr, parent) in self.pair_amt else 0
  234.           pair_ck_amt = self.pair_ck_amt[(attr, parent)][c_k] if (attr, parent) in self.pair_ck_amt and c_k in self.pair_ck_amt[(attr, parent)] else 0
  235.  
  236.           if pair_amt not in self.temp_pairs:
  237.             self.temp_pairs[pair_amt] = 0
  238.  
  239.           self.temp_pairs[pair_amt] += 1
  240.  
  241.           prob_c_given_pair = self._smoothed_p_c_i_attr(pair_ck_amt, pair_amt)
  242.           prob_pair         = self._smoothed_p_attr(attr_amt, self.table_len) * self._smoothed_p_attr(parent_amt, self.table_len)
  243.  
  244.           numerator   = prob_c_given_pair * prob_pair
  245.           denominator = self.class_probs[c_k] * self._attr_given_c(parent, c_k, False)
  246.  
  247.           probs[c_k] *= numerator/denominator
  248.  
  249.     return probs
  250.  
  251.   # returns the NB output for the given obj
  252.   def classify_super(self, obj: List[attribute.Attribute]):
  253.     return super().classify(obj)
  254.  
  255.   def print_temp_pairs(self):
  256.     print(self.temp_pairs)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement