Advertisement
ptrelford

Decision Trees

Jul 7th, 2013
318
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.70 KB | None | 0 0
  1. from math import log
  2.  
  3. def calcShannonEnt(dataSet):
  4.     numEntries = len(dataSet)
  5.     labelCounts = {}
  6.     for featVec in dataSet:
  7.         currentLabel = featVec[-1]
  8.         if currentLabel not in labelCounts.keys():
  9.             labelCounts[currentLabel] = 0
  10.         labelCounts[currentLabel] += 1
  11.     shannonEnt = 0.0
  12.     for key in labelCounts:
  13.         prob = float(labelCounts[key])/numEntries
  14.         shannonEnt -= prob * log(prob,2)
  15.     return shannonEnt
  16.  
  17. def splitDataSet(dataSet, axis, value):
  18.     retDataSet = []
  19.     for featVec in dataSet:
  20.         if featVec[axis] == value:
  21.             reducedFeatVec = featVec[:axis]
  22.             reducedFeatVec.extend(featVec[axis+1:])
  23.             retDataSet.append(reducedFeatVec)
  24.     return retDataSet
  25.  
  26.  
  27. def chooseBestFeatureToSplit(dataSet):
  28.     numFeatures = len(dataSet[0]) - 1
  29.     baseEntropy = calcShannonEnt(dataSet)
  30.     bestInfoGain = 0.0; bestFeature = -1
  31.     for i in range(numFeatures):
  32.         featList = [example[i] for example in dataSet]
  33.         uniqueVals = set(featList)
  34.         newEntropy = 0.0
  35.         for value in uniqueVals:
  36.             subDataSet = splitDataSet(dataSet, i, value)
  37.             prob = len(subDataSet)/float(len(dataSet))
  38.             newEntropy += prob * calcShannonEnt(subDataSet)
  39.         infoGain = baseEntropy - newEntropy
  40.         if (infoGain > bestInfoGain):
  41.             bestInfoGain = infoGain
  42.             bestFeature = i
  43.     return bestFeature
  44.  
  45. def majorityCnt(classList):
  46.     classCount={}
  47.     for vote in classList:
  48.         if vote not in classCount.keys(): classCount[vote] = 0
  49.         classCount[vote] += 1
  50.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
  51.     return sortedClassCount[0][0]
  52.  
  53. def createTree(dataSet,labels):
  54.     classList = [example[-1] for example in dataSet]
  55.     if classList.count(classList[0]) == len(classList):
  56.         return classList[0]
  57.     if len(dataSet[0]) == 1:
  58.         return majorityCnt(classList)
  59.     bestFeat = chooseBestFeatureToSplit(dataSet)
  60.     bestFeatLabel = labels[bestFeat]
  61.     myTree = {bestFeatLabel:{}}
  62.     del(labels[bestFeat])
  63.     featValues = [example[bestFeat] for example in dataSet]
  64.     uniqueVals = set(featValues)
  65.     for value in uniqueVals:
  66.         subLabels = labels[:]
  67.         myTree[bestFeatLabel][value] = createTree(splitDataSet\
  68.             (dataSet, bestFeat, value),subLabels)
  69.     return myTree
  70.  
  71. labels = ['no surfacing','flippers']
  72. myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
  73. print(calcShannonEnt(myDat))
  74. feature = chooseBestFeatureToSplit(myDat)
  75. print(feature)
  76. myTree = createTree(myDat,labels)
  77. print(myTree)
  78. print('Done')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement