# Decision Trees

Nov 27th, 2021
806
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. import math
2. import sys
# The two command-line arguments: the training data file and the test data
# file.  They are looked up defensively so that a missing argument reaches the
# argument-count check (which prints a usage message) instead of crashing here
# with an IndexError.
train_file = sys.argv[1] if len(sys.argv) > 1 else None
test_file = sys.argv[2] if len(sys.argv) > 2 else None
6.
class Node:
    """A single node of the decision tree.

    Attributes:
        entropy: entropy of the class column over ``lists``.
        unqualified_indices: columns that may not be split on at this node.
            Stored by reference, not copied (``tree`` hands the same list
            object to every child).
        identifier: label printed for the node, e.g. ``"attr = 1"``; may be
            None.
        lists: the rows that reach this node.
        children: child nodes, one per value of the split column, indexed by
            that value.
        value: predicted class for a leaf; -1 while no class has been
            assigned.
        child_index: column this node splits on; -1 when it does not split.
        most: class label that ``get_class`` prefers on ties and returns when
            the node has no rows.
    """

    def __init__(self, entropy, unqualified_indices, identifier, lists, most):
        # Caller-supplied state; list arguments are kept by reference.
        self.entropy, self.identifier, self.most = entropy, identifier, most
        self.unqualified_indices, self.lists = unqualified_indices, lists
        # Filled in later while the tree is grown; -1 means "not set yet".
        self.value, self.child_index = -1, -1
        self.children = []
18.
def entropy(list_of_lists, index):
    """Return the Shannon entropy (in bits) of column ``index`` over the rows.

    The original used a fixed three-slot counter, so only labels 0, 1 and 2
    worked: a label of 3 raised IndexError and a label of -1 was silently
    counted as 2.  Any hashable, orderable label is accepted now.  For labels
    0-2 the result is identical, because the terms are summed in the same
    ascending-label order.

    Args:
        list_of_lists: iterable of rows; each row is indexable.
        index: column whose value distribution is measured.

    Returns:
        0 for an empty input, otherwise ``-sum(p * log2(p))`` over the
        observed labels.
    """
    counts = {}
    for row in list_of_lists:
        label = row[index]
        counts[label] = counts.get(label, 0) + 1
    total = sum(counts.values())
    if total == 0:
        return 0
    result = 0
    # Sorted iteration keeps the floating-point summation order of the
    # original 0, 1, 2 scan; every observed label has p > 0.
    for label in sorted(counts):
        p = float(counts[label]) / float(total)
        result += -p * math.log(p, 2)
    return result
36.
def child_entropy(list_of_lists, index, comp):
    """Split the rows on column ``index`` and return the information gain.

    Args:
        list_of_lists: the rows to split.
        index: attribute column to split on; its values are used directly as
            bucket indices, so they must be non-negative ints.
        comp: the class column, used for both the parent entropy and the
            child entropies.  The original computed the parent entropy from
            ``len(rows[0]) - 1`` instead, which crashed on an empty input and
            disagreed with ``comp`` whenever the two differ.

    Returns:
        ``(gain, buckets)``.  ``buckets[v]`` lists the rows whose ``index``
        column equals ``v``.  There are always at least three buckets, as in
        the original, and more when larger attribute values occur.  ``gain``
        is ``H(parent) - sum(|bucket| / |rows| * H(bucket))``.
    """
    n_buckets = 3
    for row in list_of_lists:
        n_buckets = max(n_buckets, row[index] + 1)
    buckets = [[] for _ in range(n_buckets)]
    for row in list_of_lists:
        buckets[row[index]].append(row)

    gain = entropy(list_of_lists, comp)
    total = len(list_of_lists)
    if total == 0:
        # Nothing to split: no weighted child entropy to subtract.
        return (gain, buckets)
    for bucket in buckets:
        gain -= float(len(bucket)) / total * entropy(bucket, comp)
    return (gain, buckets)
55.
def information_gain(parent_entropy, list_of_lists, unqualified_indices, comp):
    """Return the attribute column with the highest information gain.

    Args:
        parent_entropy: unused; kept for signature compatibility
            (``child_entropy`` recomputes the parent entropy itself).
        list_of_lists: the rows at the current node.
        unqualified_indices: columns that must not be considered.
        comp: the class column; candidate columns are ``range(comp)``.

    Returns:
        The column with the largest gain.  On ties the lowest column wins,
        because only a strictly larger gain replaces the current best.
        Returns -1 if every candidate column is excluded.
    """
    best_index = -1
    best_gain = None
    for i in range(comp):
        if i in unqualified_indices:
            continue
        gain = child_entropy(list_of_lists, i, comp)[0]
        if best_index == -1 or gain > best_gain:
            best_index = i
            best_gain = gain
    return best_index
def tree(root, titles):
    """Recursively grow the decision tree below ``root``.

    A node becomes a leaf (its ``value`` is set by ``get_class``) when its
    entropy is 0 or when no column is left to split on.  Otherwise it splits
    on the column chosen by ``information_gain``.  One child is created per
    value bucket, each labelled ``"<title> = <value>"``, and every child is
    grown recursively.

    Args:
        root: the node to expand.
        titles: column names; the last one is treated as the class column.
    """
    class_col = len(titles) - 1

    if root.entropy == 0:
        root.value = get_class(root, class_col)
        return

    split = information_gain(root.entropy, root.lists, root.unqualified_indices, class_col)
    if split == -1:
        root.value = get_class(root, class_col)
        return

    # The same list object is handed to every child, so it works as a stack
    # of columns already used on the current path: pushed here, popped after
    # the subtree is built.
    root.unqualified_indices.append(split)
    partitions = child_entropy(root.lists, split, class_col)[1]
    root.children = []
    for value, rows in enumerate(partitions):
        child_entropy_value = entropy(rows, class_col)
        label = titles[split] + " = " + str(value)
        root.child_index = split
        if label == "class = " + str(value):
            label = None
        child = Node(child_entropy_value, root.unqualified_indices, label, rows, root.most)
        root.children.append(child)
        tree(child, titles)
    root.unqualified_indices.remove(split)
95.
def get_class(root, index):
    """Return the majority class label among the rows at ``root``.

    Counts the values of column ``index`` over ``root.lists`` and returns the
    most frequent one.  Ties go to ``root.most`` when it is among the tied
    labels; otherwise the smallest tied label wins.  A node with no rows
    returns ``root.most``.

    The original used a fixed three-slot counter, which only accepted labels
    0-2: 3 raised IndexError and -1 was miscounted as 2.  For labels 0-2 the
    result is unchanged.

    Args:
        root: node whose ``lists`` (rows) and ``most`` (fallback label) are
            read.
        index: the class column.

    Returns:
        The chosen class label.
    """
    counts = {}
    for row in root.lists:
        label = row[index]
        counts[label] = counts.get(label, 0) + 1
    if not counts:
        return root.most
    top = max(counts.values())
    tied = sorted(label for label, n in counts.items() if n == top)
    # Same tie-break as the original scan: root.most first, then lowest label.
    if root.most in tied:
        return root.most
    return tied[0]
115.
def showTree(root, tab):
    """Print the decision tree below ``root``, one line per child node.

    A leaf child is printed as ``<tab><identifier> : <class>`` and an
    internal child as ``<tab><identifier> :``.  Each child's own children
    follow, indented by one extra ``"| "``.  Children whose identifier is
    None are not printed, but their subtrees still are.

    Args:
        root: node to print.  Reads ``entropy``, ``identifier``,
            ``children`` and ``value``.
        tab: prefix for lines printed at this depth; the outermost call
            passes ``''``.
    """
    # A pure data set gives a root with no children, so the root line itself
    # is printed.  The original tested the module-level ``root_entropy``
    # here, which raised NameError whenever that global was undefined.  That
    # global only ever holds the root's entropy, so test the node at the
    # outermost call instead; recursive calls always pass a non-empty tab.
    if not tab and root.entropy == 0 and root.identifier is not None:
        print(tab + root.identifier + " :")

    for child in root.children:
        if child.identifier is None:
            # No label to print; the original raised TypeError here for leaves.
            pass
        elif child.value != -1:
            print(tab + child.identifier + " : " + str(child.value))
        else:
            print(tab + child.identifier + " :")
        showTree(child, tab + "| ")
128.
def accuracy(root, titles, listoflists):
    """Return the fraction of rows that the tree classifies correctly.

    Each row is routed from ``root`` by following
    ``node.children[row[node.child_index]]`` until a node with a class value
    (``value`` not -1) is reached.  That value is then compared with the
    row's class column.

    Fixes over the original:
    - rows are read from the ``listoflists`` parameter; the original iterated
      the module-level ``list_of_lists`` and ignored its argument;
    - an undecided node with no split column raises ValueError instead of
      looping forever;
    - an empty input returns 0.0 instead of raising ZeroDivisionError.

    Args:
        root: tree root.  ``root.lists[0]`` fixes the class column as the
            last column.
        titles: column names; unused, kept for interface compatibility.
        listoflists: the rows to evaluate.

    Returns:
        correct / total, rounded to 3 decimal places.

    Raises:
        ValueError: if a node with ``value == -1`` has ``child_index == -1``.
    """
    if not listoflists:
        return 0.0
    class_col = len(root.lists[0]) - 1
    right = 0
    for row in listoflists:
        node = root
        while node.value == -1:
            if node.child_index == -1:
                raise ValueError("tree node has neither a class value nor a split column")
            node = node.children[row[node.child_index]]
        if node.value == row[class_col]:
            right += 1
    return round(float(right) / len(listoflists), 3)
151.
152. # load the file into a list of lists (one list of ints per non-empty line)
154.
155.     list_of_lists = []
156.     for line in train_f:
157.         stripped_line = line.strip()
158.         line_list = [int(x) for x in stripped_line.split()]
159.         if len(line_list) > 0:
160.             list_of_lists.append(line_list)
161.     return list_of_lists
162.
165.     titles = []
167.     stripped_line = line.strip()
168.     titles = stripped_line.split()
169.     return titles
170.
171. # make sure there are exactly two arguments
172. if len(sys.argv) != 3:
173.     print('You must specify only a training data file and test data file in the program parameters; nothing more or less.')
174. else:
175.     #inputting training data
176.     train_f = open(train_file, 'r')
179.     train_f.close()
180.
181.     #create root node
182.     root_entropy = entropy(list_of_lists, len(list_of_lists[0]) - 1)
183.     lister = [len(list_of_lists[0]) - 1]
184.     root = Node (root_entropy, lister, "", list_of_lists,-1)
185.     root.most = get_class(root,len(list_of_lists[0]) - 1)
186.     #create the decision tree
187.     tree(root,titles)
188.     #show decision tree
189.     showTree(root , '')
190.     #show accuracy of training data
191.     print
192.     print("Accuracy on training set (" + str(len(list_of_lists)) + " instances): "+ str(100 * accuracy(root,titles, list_of_lists)) + "%")
193.     #inputting test data
194.     test_f = open(test_file, 'r')