Advertisement
Arham-4

Naive Bayes

Nov 27th, 2021
948
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.14 KB | None | 0 0
  1. import math
  2. import sys
  3.  
  4. def load_features(train_f):
  5.     features = {}
  6.     feature_for_index = {}
  7.     index = 0
  8.     for feature in train_f.readline().replace('\n', '').split('\t'):
  9.         features[feature] = []
  10.         feature_for_index[index] = feature
  11.         index += 1
  12.  
  13.     for line in train_f:
  14.         split = line.replace('\n', '').split('\t')
  15.         for i in range(len(split)):
  16.             features[feature_for_index[i]].append(int(split[i]))
  17.  
  18.     return features
  19.  
  20. def frequency_table(values):
  21.     unique_value_count = {}
  22.     for value in values:
  23.         if value not in unique_value_count:
  24.             unique_value_count[value] = 1
  25.         else:
  26.             unique_value_count[value] += 1
  27.     return unique_value_count
  28.  
  29. def filter_feature_for_class(features, clazz, feature):
  30.     filtered_values = []
  31.     for i in range(len(features['class'])):
  32.         if features['class'][i] == clazz:
  33.             filtered_values.append(features[feature][i])
  34.     return filtered_values
  35.  
  36. def print_learning(features):
  37.     class_freq_table = frequency_table(features['class'])
  38.     total = sum(class_freq_table.values())
  39.     list = []
  40.     for clazz in sorted(class_freq_table.keys()):
  41.         class_total = class_freq_table[clazz]
  42.         class_p = class_total / total
  43.         print('P(class=' + str(clazz) + ')=%.2f ' % class_p, end='')
  44.         x =  class_p
  45.         list.append(x)
  46.         for feature in features.keys():
  47.             if feature == 'class':
  48.                 continue
  49.             values_for_feature = filter_feature_for_class(features, clazz, feature)
  50.             feature_freq_table = frequency_table(values_for_feature)
  51.            
  52.             for feature_value in sorted(feature_freq_table.keys()):
  53.                 count = feature_freq_table[feature_value]
  54.                 feature_value_p = count / class_total
  55.                 print('P(' + feature + '=' + str(feature_value) + '|' + str(clazz) + ')=%.2f ' % feature_value_p, end='')
  56.                 x =  feature_value_p
  57.                 list.append(x)
  58.        
  59.         print()
  60.     return list
  61.  
  62. def accuracy(list, file):
  63.     right = 0
  64.     wrong = 0
  65.     train_f = open(file, 'r')
  66.     train_f.readline()
  67.     list_of_lists = []
  68.     for line in train_f:
  69.         stripped_line = line.strip()
  70.         line_list = [int(x) for x in stripped_line.split()]
  71.         if len(line_list) > 0:
  72.             list_of_lists.append(line_list)
  73.            
  74.    
  75.     for lists in list_of_lists:
  76.         classzero = math.log(float(list[0]),2)    
  77.         classone = math.log(float(list[len(lists)*2-1]),2)
  78.         for val in range(len(lists)-1):
  79.            
  80.            
  81.  
  82.             classzero += math.log(float(list[val*2 + lists[val]+1]),2)
  83.             classone += math.log(float(list[len(lists)*2+ val*2 + lists[val]]),2)
  84.    
  85.         if(classzero > classone):
  86.             if(lists[len(lists)-1] == 0):
  87.                 right +=1
  88.             else:
  89.                 wrong +=1
  90.            
  91.         elif(classone > classzero):
  92.             if(lists[len(lists)-1] == 1):
  93.                 right +=1
  94.             else:
  95.                 wrong +=1
  96.         else:
  97.             if(list[0]>=.5):
  98.                 if(lists[len(lists)-1] == 0):
  99.                     right +=1
  100.                 else:
  101.                     wrong +=1
  102.             else:
  103.                 if(lists[len(lists)-1] == 1):
  104.                     right +=1
  105.                 else:
  106.                     wrong +=1
  107.     return (round(float(right) / float(right + wrong),4),(right + wrong))
  108.  
  109.  
  110. if len(sys.argv) != 3:
  111.     print('You must specify only a training data file and test data file in the program parameters; nothing more or less.')
  112. else:
  113.     train_file = sys.argv[1]
  114.     test_file = sys.argv[2]
  115.  
  116.     train_f = open(train_file, 'r')
  117.     features = load_features(train_f)
  118.    
  119.     list = print_learning(features)
  120.     print()
  121.     acc = accuracy(list,train_file,features)
  122.     print("Accuracy on training set (" + str(acc[1]) + " instances): "+ str(100 * acc[0]) + "%")
  123.     print()
  124.     acc = accuracy(list,test_file,features)
  125.     print("Accuracy on training set (" + str(acc[1]) + " instances): "+ str(100 * acc[0]) + "%")
  126.    
  127.    
  128.  
  129.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement