Advertisement
rajath_pai

Entropy

Mar 15th, 2021 (edited)
181
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python | 3.96 KB | None | 0 0
  1. import pandas as pd
  2. import math
  3. import copy
  4.  
# Load the training table. iloc[:, 1:] drops the CSV's first column
# (presumably a row/day identifier -- TODO confirm against tennis.csv); per the
# sample OUTPUT below, the remaining columns are four attributes followed by
# the 'Yes'/'No' class label in the last column.
dataset = pd.read_csv('tennis.csv')
X = dataset.iloc[:, 1:].values
print(X)
# Display names for X's attribute columns, indexed by column position;
# buildTree labels internal nodes with attribute[idx].
attribute = ['outlook', 'temp', 'humidity', 'wind']
  9.  
  10.  
  11. class Node(object):
  12.     def __init__(self):
  13.         self.value = None
  14.         self.decision = None
  15.         self.childs = None
  16.  
  17.  
  18. def findEntropy(data, rows):
  19.     yes = 0
  20.     no = 0
  21.     ans = -1
  22.     idx = len(data[0]) - 1
  23.     entropy = 0
  24.     for i in rows:
  25.         if data[i][idx] == 'Yes':
  26.             yes = yes + 1
  27.         else:
  28.             no = no + 1
  29.  
  30.     x = yes/(yes+no)
  31.     y = no/(yes+no)
  32.     if x != 0 and y != 0:
  33.         entropy = -1 * (x*math.log2(x) + y*math.log2(y))
  34.     if x == 1:
  35.         ans = 1
  36.     if y == 1:
  37.         ans = 0
  38.     return entropy, ans
  39.  
  40.  
  41. def findMaxGain(data, rows, columns):
  42.     maxGain = 0
  43.     retidx = -1
  44.     entropy, ans = findEntropy(data, rows)
  45.     if entropy == 0:
  46.         """if ans == 1:
  47.            print("Yes")
  48.        else:
  49.            print("No")"""
  50.         return maxGain, retidx, ans
  51.  
  52.     for j in columns:
  53.         mydict = {}
  54.         idx = j
  55.         for i in rows:
  56.             key = data[i][idx]
  57.             if key not in mydict:
  58.                 mydict[key] = 1
  59.             else:
  60.                 mydict[key] = mydict[key] + 1
  61.         gain = entropy
  62.  
  63.         # print(mydict)
  64.         for key in mydict:
  65.             yes = 0
  66.             no = 0
  67.             for k in rows:
  68.                 if data[k][j] == key:
  69.                     if data[k][-1] == 'Yes':
  70.                         yes = yes + 1
  71.                     else:
  72.                         no = no + 1
  73.             # print(yes, no)
  74.             x = yes/(yes+no)
  75.             y = no/(yes+no)
  76.             # print(x, y)
  77.             if x != 0 and y != 0:
  78.                 gain += (mydict[key] * (x*math.log2(x) + y*math.log2(y)))/14
  79.         # print(gain)
  80.         if gain > maxGain:
  81.             # print("hello")
  82.             maxGain = gain
  83.             retidx = j
  84.  
  85.     return maxGain, retidx, ans
  86.  
  87.  
  88. def buildTree(data, rows, columns):
  89.  
  90.     maxGain, idx, ans = findMaxGain(X, rows, columns)
  91.     root = Node()
  92.     root.childs = []
  93.     # print(maxGain
  94.     #
  95.     # )
  96.     if maxGain == 0:
  97.         if ans == 1:
  98.             root.value = 'Yes'
  99.         else:
  100.             root.value = 'No'
  101.         return root
  102.  
  103.     root.value = attribute[idx]
  104.     mydict = {}
  105.     for i in rows:
  106.         key = data[i][idx]
  107.         if key not in mydict:
  108.             mydict[key] = 1
  109.         else:
  110.             mydict[key] += 1
  111.  
  112.     newcolumns = copy.deepcopy(columns)
  113.     newcolumns.remove(idx)
  114.     for key in mydict:
  115.         newrows = []
  116.         for i in rows:
  117.             if data[i][idx] == key:
  118.                 newrows.append(i)
  119.         # print(newrows)
  120.         temp = buildTree(data, newrows, newcolumns)
  121.         temp.decision = key
  122.         root.childs.append(temp)
  123.     return root
  124.  
  125.  
  126. def traverse(root):
  127.     print(root.decision)
  128.     print(root.value)
  129.  
  130.     n = len(root.childs)
  131.     if n > 0:
  132.         for i in range(0, n):
  133.             traverse(root.childs[i])
  134.  
  135.                  
  136. def calculate():
  137.     rows = [i for i in range(0, 14)]
  138.     columns = [i for i in range(0, 4)]
  139.     root = buildTree(X, rows, columns)
  140.     root.decision = 'Start'
  141.     traverse(root)
  142.  
  143.  
  144. calculate()                  
  145.  
  146. """
  147. OUTPUT
  148.  
  149. [['Sunny' 'Hot' 'High' 'Weak' 'No']
  150. ['Sunny' 'Hot' 'High' 'Strong' 'No']
  151. ['Overcast' 'Hot' 'High' 'Weak' 'Yes']
  152. ['Rain' 'Mild' 'High' 'Weak' 'Yes']
  153. ['Rain' 'Cool' 'Normal' 'Weak' 'Yes']
  154. ['Rain' 'Cool' 'Normal' 'Strong' 'No']
  155. ['Overcast' 'Cool' 'Normal' 'Strong' 'Yes']
  156. ['Sunny' 'Mild' 'High' 'Weak' 'No']
  157. ['Sunny' 'Cool' 'Normal' 'Weak' 'Yes']
  158. ['Rain' 'Mild' 'Normal' 'Weak' 'Yes']
  159. ['Sunny' 'Mild' 'Normal' 'Strong' 'Yes']
  160. ['Overcast' 'Mild' 'High' 'Strong' 'Yes']
  161. ['Overcast' 'Hot' 'Normal' 'Weak' 'Yes']
  162. ['Rain' 'Mild' 'High' 'Strong' 'No']]
  163. Start
  164. outlook
  165. Sunny
  166. humidity
  167. High
  168. No
  169. Normal
  170. Yes
  171. Overcast
  172. Yes
  173. Rain
  174. wind
  175. Weak
  176. Yes
  177. Strong
  178. No
  179.  
  180. """
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement