Advertisement
Fabio_LaF

Arquivo principal

Aug 17th, 2022 (edited)
525
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.52 KB | None | 0 0
  1. # file main.py
  2.  
  3. import model
  4. import pgConnector
  5. import dataParser
  6. import attribute
  7. import math
  8.  
  9. # This function assumes that there are "enough" objects on the data set to build a training set with '#training_size' elements
  10. def buid_training_table(dbc, training_table_name, data_table_name, class_list, training_size, class_field = "class"):
  11.   class_amts = {}                                                            # How many of each class will be put on the training table              
  12.   row_amt    = dbc.do_query("select count(*) from " + data_table_name)[0][0] # How many rows on data table
  13.  
  14.   for c in class_list:
  15.     # How many objects of class 'c' are there in the data table
  16.     class_amt_data = dbc.do_query("select count(*) from " + data_table_name + " where " + class_field + " = " + str(c))[0][0]
  17.  
  18.     class_pct       = class_amt_data/row_amt                # What percentage of objects are of class 'c'
  19.     class_amts[c]   = math.floor(class_pct * training_size) # How many objs of class 'c' will be put on the training table
  20.  
  21.   # Creating the training table and adding the correct amt of objects of class 'class_list[0]' on it
  22.   create_query_aux = "select * from " + data_table_name + " where " + class_field + " = " + str(class_list[0]) + " limit " + str(class_amts[class_list[0]])
  23.   dbc.fetchless_query("create table " + training_table_name + " as (" + create_query_aux + ")")
  24.  
  25.   # Adding objs of the other classes on the table
  26.   for c_i in class_list[1:]:
  27.     insert_query = " select * from " + data_table_name + " where " + class_field + " = " + str(c_i) + " limit " + str(class_amts[c_i])
  28.     dbc.fetchless_query("insert into " + training_table_name + insert_query)
  29.  
  30. def run_tests(classify_func, data, verbose_print):
  31.   # Side artifacts of the program
  32.   true_positives  = 0
  33.   true_negatives  = 0
  34.   false_positives = 0
  35.   false_negatives = 0
  36.  
  37.   for obj in data:
  38.     # Classifying each obj
  39.     obj_probs = classify_func(obj[:-1])
  40.  
  41.     c_1_prob = obj_probs[True]
  42.     c_2_prob = obj_probs[False]
  43.  
  44.     real_class       = obj[len(obj)-1].value
  45.     calculated_class = True if c_1_prob > c_2_prob else False
  46.  
  47.     # Updating the appropriate count variable
  48.     if(calculated_class and real_class):
  49.       true_positives += 1
  50.     elif((not calculated_class) and (not real_class)):
  51.       true_negatives += 1
  52.     elif((not calculated_class) and real_class):
  53.       false_negatives += 1
  54.     elif(calculated_class and (not real_class)):
  55.       false_positives += 1
  56.  
  57.   # Main artifacts of the program
  58.   recall    = true_positives/(true_positives + false_negatives)
  59.   precision = true_positives/(true_positives + false_positives)
  60.   f_measure = (2*precision*recall)/(precision+recall)
  61.  
  62.   if verbose_print:
  63.     print("True Positives: " + str(true_positives))
  64.     print("True Negatives: " + str(true_negatives))
  65.     print("False Positives: " + str(false_positives))
  66.     print("False Negatives: " + str(false_negatives))
  67.  
  68.   print("Recall: " + str(recall))
  69.   print("Precision: " + str(precision))
  70.   print("F-Measure: " + str(f_measure))
  71.  
  72. ####################################################################################################################################################
  73.  
  74. ############ Consts
  75.  
  76. DBC                 = pgConnector.PgConnector("postgres", "BatatinhaFrita123", "PGC-II", 'n')
  77. DP                  = dataParser.DataParser(DBC)
  78. DATA_TABLE_NAME     = "teste1_hom_full"
  79. TRAINING_TABLE_NAME = "training_data_for_" + DATA_TABLE_NAME
  80. CLASS_LIST          = [True, False]
  81. VERBOSE             = False
  82. TRAINING_SIZE       = 5000
  83.  
  84. ############ Main
  85.  
  86. buid_training_table(DBC, TRAINING_TABLE_NAME, DATA_TABLE_NAME, CLASS_LIST, TRAINING_SIZE)
  87.  
  88. # The 'parse_objects' function returns a tuple with the field names and the parsed data, but since we're not interested
  89. # on the field names, we discard the first element
  90. _, parsed_data = DP.parse_objects(DATA_TABLE_NAME, excluded_fields=["order_item_seq_id", "id"])
  91. modelo         = model.TreeAugmentedNB[bool](CLASS_LIST, DBC, TRAINING_TABLE_NAME, excluded_fields=["order_item_seq_id", "id"])
  92.  
  93. modelo.train()
  94.  
  95. print("-----------------------------------------------------------------------------")
  96. print("Naive Bayes:")
  97. run_tests(modelo.classify_super, parsed_data, VERBOSE)
  98.  
  99. print("-----------------------------------------------------------------------------")
  100. print("Tree Augmented Naive Bayes:")
  101. run_tests(modelo.classify, parsed_data, VERBOSE)
  102.  
  103. if VERBOSE:
  104.   modelo.print_temp_pairs()
  105.  
  106. print("")
  107.  
  108. DBC.fetchless_query("drop table " + TRAINING_TABLE_NAME)
  109. DBC.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement