Advertisement
Arham-4

Decision Trees

Nov 27th, 2021
917
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 6.23 KB | None | 0 0
  1. import math
  2. import sys
  3. #the two system arguments
  4. train_file = sys.argv[1]
  5. test_file = sys.argv[2]
  6.  
  7. #Nodes used to the create the tree
  8. class Node:
  9.     def __init__(self, entropy, unqualified_indices, identifier, lists, most):
  10.         self.entropy = entropy
  11.         self.unqualified_indices = unqualified_indices
  12.         self.lists = lists
  13.         self.children = []
  14.         self.identifier = identifier
  15.         self.value = -1
  16.         self.child_index = -1  
  17.         self.most = most
  18.  
  19. #used for the entropy of the current node
  20. def entropy(list_of_lists, index):
  21.     unique_value_count = [0,0,0]
  22.     total = 0
  23.     for list in list_of_lists:
  24.         unique_value_count[list[index]] += 1
  25.         total += 1
  26.     entropy = 0
  27.     if total == 0:
  28.         return entropy
  29.     for count in unique_value_count:
  30.         pi = float(count) / float(total)
  31.         if pi != 0:
  32.             entropy += -pi * math.log(pi, 2)
  33.  
  34.    
  35.     return entropy
  36.  
  37. # split the list and find the individual entropies
  38. def child_entropy(list_of_lists, index,comp):
  39.     entropy_list = [[],[],[]]
  40.     entropy_values = [0,0,0]
  41.     temp = list_of_lists
  42.     ig = entropy(temp,len(temp[0])-1)
  43.     for i in temp:
  44.         if entropy_list[i[index]] != []:
  45.             entropy_list[i[index]].append(i)
  46.         else:
  47.             entropy_list[i[index]].extend([i])
  48.     count = 0
  49.    
  50.     for lists in entropy_list:
  51.         entropy_values[count] = entropy(lists, comp)
  52.         ig -= float(len(entropy_list[count]))/len(temp) * entropy_values[count]
  53.         count +=1
  54.     return (ig, entropy_list)
  55.  
  56. #information gain
  57. def information_gain(parent_entropy, list_of_lists, unqualified_indices,comp):
  58.     entropy_values = [None] * (comp +1)
  59.     entropy_list = [None] * (comp + 1)
  60.     best_index = -1
  61.     ig = 0
  62.     for i in range(comp):
  63.        
  64.         entropy_values[i] = 0
  65.         if i in unqualified_indices:
  66.             continue
  67.         entropy_values[i] = child_entropy(list_of_lists, i,comp)[0]
  68.         if best_index == -1 or entropy_values[i] > entropy_values[best_index]:
  69.             best_index = i
  70.  
  71.     return best_index
#creates the tree of Nodes (ID3-style: split on the best-gain attribute)
def tree(root, titles):
    """Recursively grow the decision tree below `root`.

    titles is the header row; len(titles)-1 is used both as the class
    column index and as the number of attribute columns.  Leaves get
    their predicted class stored in `root.value`.
    """
    # A pure node (entropy 0) becomes a leaf labelled with its class.
    if root.entropy == 0:
        root.value = get_class(root,len(titles)-1)
        return
    index =  information_gain(root.entropy, root.lists ,root.unqualified_indices,len(titles)-1)
    # No attribute left to split on: majority-class leaf.
    if index == -1:
        root.value = get_class(root,len(titles)-1)
        return
    # NOTE: the same list object is shared with every child Node and
    # un-done after the loop, so this attribute is excluded only while
    # this subtree is being built.  Statement order matters here.
    root.unqualified_indices.append(index)
    temp = child_entropy(root.lists, index, len(titles)-1)[1]
    root.children = []
    x = 0
    # One child per attribute value (0, 1, 2), in value order.
    for i in temp:
        c_entropy = entropy(i,len(titles)-1)
        id = (titles[index] + " = " + str(x))
        root.child_index = index
        # Suppress the label when the split column is literally named "class".
        if id == "class = " + str(x):
            id = None
        root.children.append(Node(c_entropy, root.unqualified_indices, id, i,root.most))
        tree(root.children[x],titles)
        x += 1
    root.unqualified_indices.remove(index)
  95.  
  96. #The class of a leaf node
  97. def get_class(root, index):
  98.     unique_value_count = [0,0,0]
  99.     for list in root.lists:
  100.         unique_value_count[list[index]] += 1
  101.     max_index = -1
  102.     max_value = -1
  103.     for i in range(3):
  104.         if unique_value_count[i] > max_value and unique_value_count[i] != 0:
  105.             max_index = i
  106.             max_value = unique_value_count[i]
  107.         elif unique_value_count[i] == max_value and root.most == i:
  108.             max_index = i
  109.             max_value = unique_value_count[i]
  110.    
  111.     if max_index == -1:
  112.         return root.most
  113.    
  114.     return max_index
  115.  
#print the tree of the function
def showTree(root,tab):
    """Recursively print the tree; each level indents children with "| ".

    NOTE(review): `root_entropy` below is NOT an attribute access -- it
    resolves to the module-level global of the same name set when the
    root node is built, so the branch fires only when the whole training
    set is pure.  It was likely meant to be `root.entropy`, but changing
    it would re-print every leaf's label; confirm intent before fixing.
    """
    if root_entropy == 0:
        if root.identifier != None:
            print(tab + root.identifier + " :")

    for i in root.children:
        # Leaves carry a class in `value`; internal nodes print only a label.
        if i.value != -1:
            print(tab + i.identifier + " : " + str(i.value))
        elif i.identifier != None:
            print(tab + i.identifier + " :")
        showTree(i,tab + "| ")
  128.  
  129. #accuracy function
  130. def accuracy(root,titles, listoflists):
  131.     right = 0
  132.     wrong = 0
  133.     index = -1
  134.     temp = root
  135.     for i in list_of_lists:
  136.         value = root.value
  137.         temp = root
  138.         while value == -1:
  139.             if temp.child_index != -1:
  140.                 temp = temp.children[i[temp.child_index]]
  141.                 value = temp.value
  142.    
  143.         if value == i[len(root.lists[0])-1]:
  144.             right +=1
  145.             value = -2
  146.         else:
  147.             wrong += 1
  148.             value = -2
  149.  
  150.     return round(float(right) / float(right + wrong),3)
  151.  
  152. #load the file into a list of a list
  153. def load_file(train_f):
  154.  
  155.     list_of_lists = []
  156.     for line in train_f:
  157.         stripped_line = line.strip()
  158.         line_list = [int(x) for x in stripped_line.split()]
  159.         if len(line_list) > 0:
  160.             list_of_lists.append(line_list)
  161.     return list_of_lists
  162.  
  163. #load the headings
  164. def load_features(train_f):
  165.     titles = []
  166.     line = train_f.readline()
  167.     stripped_line = line.strip()
  168.     titles = stripped_line.split()
  169.     return titles
  170.  
  171. #make sure their are two arguments
  172. if len(sys.argv) != 3:
  173.     print('You must specify only a training data file and test data file in the program parameters; nothing more or less.')
  174. else:
  175.     #inputting training data
  176.     train_f = open(train_file, 'r')
  177.     titles =  load_features(train_f)
  178.     list_of_lists = load_file(train_f)
  179.     train_f.close()
  180.  
  181.     #create root node
  182.     root_entropy = entropy(list_of_lists, len(list_of_lists[0]) - 1)
  183.     lister = [len(list_of_lists[0]) - 1]
  184.     root = Node (root_entropy, lister, "", list_of_lists,-1)
  185.     root.most = get_class(root,len(list_of_lists[0]) - 1)
  186.     #create the decision tree
  187.     tree(root,titles)
  188.     #show decision tree
  189.     showTree(root , '')
  190.     #show accuracy of training data
  191.     print
  192.     print("Accuracy on training set (" + str(len(list_of_lists)) + " instances): "+ str(100 * accuracy(root,titles, list_of_lists)) + "%")
  193.     #inputting test data
  194.     test_f = open(test_file, 'r')
  195.     line = test_f.readline()
  196.     list_of_lists = load_file(test_f)
  197.     test_f.close()
  198.     #accuracy of test data
  199.     print
  200.     print("Accuracy on test set (" + str(len(list_of_lists)) + " instances): "+ str(100 * accuracy(root,titles, list_of_lists)) + "%")
  201.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement