fake_world

ml3

Dec 3rd, 2020
477
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import math
  2. import csv
  3.  
  4. def load_csv(filename):  
  5.     lines=csv.reader(open(filename,"r"));
  6.     dataset = list(lines)
  7.     headers = dataset.pop(0)
  8.     return dataset,headers
  9.  
  10.  
  11. class Node:
  12.     def __init__(self,attribute):
  13.         self.attribute=attribute
  14.         self.children=[]
  15.         self.answer=""
  16.  
def subtables(data,col,delete):  
    """Partition *data* by the distinct values found in column *col*.

    Returns (attr, dic): *attr* is the list of distinct values and
    *dic* maps each value to the rows carrying it.  When *delete* is
    true, column *col* is removed IN PLACE from each matching row, and
    the rows stored in *dic* are the same list objects as in *data*.
    """
    dic={}
    coldata=[row[col] for row in data]
    attr=list(set(coldata))

    # First pass: count rows per distinct value so each per-value row
    # list can be pre-sized below.
    counts=[0]*len(attr)
    r=len(data)
    c=len(data[0])
    for x in range(len(attr)):
        for y in range(r):
            if data[y][col]==attr[x]:
                counts[x]+=1

    # Second pass: collect the rows belonging to each value.
    # NOTE(review): with delete=True the column is removed from a row as
    # soon as it matches, so for later attr values the data[y][col] test
    # reads the *next* column of already-processed rows.  This relies on
    # those shifted values never colliding with a remaining attr value —
    # fragile, but preserved as-is.
    for x in range(len(attr)):
        dic[attr[x]]=[[0 for i in range(c)] for j in range(counts[x])]
        pos=0
        for y in range(r):
            if data[y][col]==attr[x]:
                if delete:
                    del data[y][col]
                dic[attr[x]][pos]=data[y]
                pos+=1
    return attr,dic
  40. def entropy(S):  
  41.     attr=list(set(S))
  42.     if len(attr)==1:
  43.         return 0
  44.  
  45.     counts=[0,0]
  46.     for i in range(2):
  47.         counts[i]=sum([1 for x in S if attr[i]==x])/(len(S)*1.0)
  48.  
  49.     sums=0
  50.     for cnt in counts:
  51.         sums+=-1*cnt*math.log(cnt,2)
  52.     return sums
  53.  
  54.  
  55. def compute_gain(data,col):  
  56.     attr,dic = subtables(data,col,delete=False)
  57.     total_size=len(data)
  58.     entropies=[0]*len(attr)
  59.     ratio=[0]*len(attr)
  60.     total_entropy=entropy([row[-1] for row in data])
  61.     for x in range(len(attr)):
  62.         ratio[x]=len(dic[attr[x]])/(total_size*1.0)
  63.         entropies[x]=entropy([row[-1] for row in dic[attr[x]]])
  64.         total_entropy-=ratio[x]*entropies[x]
  65.     return total_entropy
  66.  
  67.  
  68. def build_tree(data,features):   
  69.     lastcol=[row[-1] for row in data]
  70.     if(len(set(lastcol)))==1:
  71.         node=Node("")
  72.         node.answer=lastcol[0]
  73.         return node
  74.  
  75.     n=len(data[0])-1
  76.     gains=[0]*n
  77.     for col in range(n):
  78.         gains[col]=compute_gain(data,col)
  79.     split=gains.index(max(gains))
  80.     node=Node(features[split])
  81.     fea = features[:split]+features[split+1:]
  82.     attr,dic=subtables(data,split,delete=True)
  83.  
  84.     for x in range(len(attr)):
  85.         child=build_tree(dic[attr[x]],fea)
  86.         node.children.append((attr[x],child))
  87.     return node
  88.  
  89.  
  90. def print_tree(node,level):  
  91.     if node.answer!="":
  92.         print(" "*level,node.answer)
  93.         return
  94.  
  95.     print(" "*level,node.attribute)
  96.     for value,n in node.children:
  97.         print(" "*(level+1),value)
  98.         print_tree(n,level+2)
  99.  
  100.  
  101. def classify(node,x_test,features):  
  102.     if node.answer!="":
  103.         print(node.answer)
  104.         return
  105.     pos=features.index(node.attribute)
  106.     for value, n in node.children:
  107.         if x_test[pos]==value:
  108.             classify(n,x_test,features)
  109.  
  110. '''Main program''' 
  111. dataset,features=load_csv("data3.csv")
  112. node1=build_tree(dataset,features)
  113. print("The decision tree for the dataset using ID3 algorithm is")
  114. print_tree(node1,0)
  115. testdata,features=load_csv("data3_test.csv")
  116. for xtest in testdata:
  117.     print("The test instance:",xtest)
  118.     print("The label for test instance:",end="  ")
  119.     classify(node1,xtest,features)
RAW Paste Data