Advertisement
Guest User

Untitled

a guest
Mar 24th, 2017
59
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.06 KB | None | 0 0
  1. def radixsort(values):
  2. radix = 2
  3. maxLength = False
  4. temp = -1
  5. placement = 1
  6.  
  7. indices = np.array([i for i in range(len(values))])
  8.  
  9. while not maxLength:
  10. maxLength = True
  11. buckets = [[] for i in range(radix)]
  12.  
  13. for i in indices:
  14. temp = int(values[i] / placement)
  15. buckets[temp % radix] += [i]
  16. if maxLength and temp > 0:
  17. maxLength = False
  18.  
  19. a = 0
  20. for b in range(radix):
  21. bucket = buckets[b]
  22. for i in bucket:
  23. indices[a] = i
  24. a += 1
  25.  
  26. placement *= radix
  27. return indices
  28.  
  29. class DecisionTree:
  30.  
  31. def __init__(self, attrBagging=1, prune=False, overSample=False, splitFine=False):
  32. self.attrBagging = attrBagging
  33. self.prune = prune
  34. self.overSample = overSample
  35. self.splitFine = splitFine
  36.  
  37. def train(self, data, labels):
  38. # balancing
  39. if self.overSample:
  40. ys, counts = np.unique(labels, return_counts=True)
  41. extraData = []
  42. extraLabels = []
  43. maxCounts = max(counts)
  44. for i in range(len(ys)):
  45. if counts[i] < maxCounts:
  46. added = 0
  47. index = 0
  48. while(counts[i] + added < maxCounts):
  49. if(labels[index] == ys[i]):
  50. extraData += [data[index]]
  51. extraLabels += [labels[index]]
  52. added += 1
  53. index = (index + 1) % len(labels)
  54. data = np.array(list(data) + extraData)
  55. labels = np.array(list(labels) + extraLabels)
  56.  
  57. self.numSamples = data.shape[0]
  58. if self.prune:
  59. indices = [i for i in range(data.shape[0])]
  60. random.shuffle(indices)
  61. xtrain = data[indices[:int(.7 * data.shape[0])]]
  62. ytrain = labels[indices[:int(.7 * data.shape[0])]]
  63. xvalid = data[indices[int(.7 * data.shape[0]):]]
  64. yvalid = labels[indices[int(.7 * data.shape[0]):]]
  65. self.root = self.growTree(xtrain, ytrain, xvalid, yvalid, 0)
  66. else:
  67. self.root = self.growTree(data, labels, None, None, 0)
  68.  
  69. def predict(self, data):
  70. return self.root.predict(data)
  71.  
  72. def predictPath(self, sample, features):
  73. self.root.predictPath(sample, features)
  74.  
  75. def growTree(self, xtrain, ytrain, xvalid, yvalid, depth):
  76. ys, yCounts = np.unique(ytrain, return_counts=True)
  77. if self.prune:
  78. tooDeep = depth > 2 * np.log2(self.numSamples)
  79. else:
  80. tooDeep = depth > np.log2(self.numSamples)
  81. notEnoughSamples = ytrain.shape[0] < self.numSamples / 100
  82. allSame = max(yCounts) == ytrain.shape[0]
  83. if tooDeep or notEnoughSamples or allSame:
  84. return Node(None, None, None, ys[np.argmax(yCounts)])
  85.  
  86. rule = self.segmenter(xtrain, ytrain)
  87. if rule == None:
  88. return Node(None, None, None, ys[np.argmax(yCounts)])
  89.  
  90. # print(rule)
  91.  
  92. leftIndicesTrain = []
  93. rightIndicesTrain = []
  94. for i in range(xtrain.shape[0]):
  95. if xtrain[i, rule[0]] > rule[1]:
  96. rightIndicesTrain += [i]
  97. else:
  98. leftIndicesTrain += [i]
  99.  
  100. leftIndicesValid = []
  101. rightIndicesValid = []
  102.  
  103. if self.prune:
  104. for i in range(xvalid.shape[0]):
  105. if xvalid[i, rule[0]] > rule[1]:
  106. rightIndicesValid += [i]
  107. else:
  108. leftIndicesValid += [i]
  109.  
  110. if self.prune:
  111. leftNode = self.growTree(xtrain[leftIndicesTrain], ytrain[leftIndicesTrain], xvalid[leftIndicesValid], yvalid[leftIndicesValid], depth + 1)
  112. rightNode = self.growTree(xtrain[rightIndicesTrain], ytrain[rightIndicesTrain], xvalid[rightIndicesValid], yvalid[rightIndicesValid], depth + 1)
  113. else:
  114. leftNode = self.growTree(xtrain[leftIndicesTrain], ytrain[leftIndicesTrain], None, None, depth + 1)
  115. rightNode = self.growTree(xtrain[rightIndicesTrain], ytrain[rightIndicesTrain], None, None, depth + 1)
  116.  
  117. split = Node(rule, leftNode, rightNode, None)
  118.  
  119. if not self.prune:
  120. return split
  121.  
  122. noSplit = Node(None, None, None, ys[np.argmax(yCounts)])
  123.  
  124. def errs(pred, actual):
  125. errs = 0
  126. for i in range(len(pred)):
  127. if pred[i] != actual[i]:
  128. errs += 1
  129. return errs
  130.  
  131. # pruning
  132. if leftNode.label != None or rightNode.label != None:
  133. noSplitErr = errs(noSplit.predict(xvalid), yvalid)
  134. splitErr = errs(split.predict(xvalid), yvalid)
  135. if(splitErr >= noSplitErr):
  136. return noSplit
  137.  
  138. return split
  139.  
  140. def segmenter(self, data, labels):
  141. maxEntropyDec = 0
  142. rule = None
  143. # check split for some or all feature depending on attribute bagging
  144. features = [i for i in range(data.shape[1])]
  145. if self.attrBagging != 1:
  146. random.shuffle(features)
  147. features = features[:int(data.shape[1] * self.attrBagging)]
  148. for i in features:
  149. sortedIndices = radixsort(data[:, i])
  150.  
  151. ys, rightCounts = np.unique(labels, return_counts=True)
  152. labelToIndex = {}
  153. for j in range(len(ys)):
  154. labelToIndex[ys[j]] = j
  155. rightCounts = np.append(rightCounts, [sum(rightCounts)]) + 0.0
  156. leftCounts = np.zeros(rightCounts.shape) + 0.0
  157.  
  158. baseEntropy = self.entropy(rightCounts)
  159. curEntropy = baseEntropy
  160. lowestEntropy = baseEntropy
  161. split = None
  162. j = 0
  163.  
  164. end = False
  165.  
  166. if not self.splitFine:
  167. fVals = np.unique(data[:, i])
  168. numSplitLocs = min(7, len(fVals))
  169. splitLoc = 0
  170. while j < len(sortedIndices) - 1:
  171. changed = set()
  172. changes = np.zeros(rightCounts.shape) + 0.0
  173. if self.splitFine:
  174. fVal = data[sortedIndices[j], i]
  175. else:
  176. fVal = min(fVals) + ((max(fVals) - min(fVals)) * splitLoc / numSplitLocs)
  177. while data[sortedIndices[j], i] <= fVal:
  178. changed.add(labelToIndex[labels[sortedIndices[j]]])
  179. changes[labelToIndex[labels[sortedIndices[j]]]] += 1
  180. changes[-1] += 1
  181. j += 1
  182. if j >= len(sortedIndices):
  183. end = True
  184. break
  185. if end:
  186. break
  187.  
  188. if not self.splitFine:
  189. splitLoc += 1
  190.  
  191. # update entropy
  192. curEntropy *= (leftCounts[-1] + rightCounts[-1])
  193. changed = np.array(list(changed) + [len(ys)])
  194. curEntropy -= leftCounts[-1] * self.entropy(leftCounts[changed])
  195. curEntropy -= rightCounts[-1] * self.entropy(rightCounts[changed])
  196. leftCounts += changes
  197. rightCounts -= changes
  198. curEntropy += leftCounts[-1] * self.entropy(leftCounts[changed])
  199. curEntropy += rightCounts[-1] * self.entropy(rightCounts[changed])
  200. curEntropy /= (leftCounts[-1] + rightCounts[-1])
  201. if curEntropy < lowestEntropy:
  202. lowestEntropy = curEntropy
  203. split = (i, data[sortedIndices[j - 1], i])
  204.  
  205. if baseEntropy - lowestEntropy > maxEntropyDec:
  206. maxEntropyDec = baseEntropy - lowestEntropy
  207. rule = split
  208.  
  209. return rule
  210.  
  211.  
  212.  
  213. def impurity(self, leftLabelHist, rightLabelHist):
  214. leftEntropy = leftLabelHist[-1] * self.entropy(leftLabelHist)
  215. rightEntropy = rightLabelHist[-1] * self.entropy(rightLabelHist)
  216. return (leftEntropy + rightEntropy) / (leftLabelHist[-1] + rightLabelHist[-1])
  217.  
  218.  
  219. def entropy(self, counts):
  220. if counts[-1] == 0:
  221. return 0
  222. p = counts[:-1] / counts[-1]
  223. return np.nan_to_num(-np.dot(p, np.log2(p)))
  224.  
  225. def printNodes(self):
  226. self.root.print(0)
  227.  
  228. class Node:
  229.  
  230. def __init__(self, splitRule, left, right, label):
  231. self.splitRule = splitRule
  232. self.left = left
  233. self.right = right
  234. self.label = label
  235.  
  236. def predict(self, X):
  237. pred = np.zeros(X.shape[0])
  238. for i in range(X.shape[0]):
  239. pred[i] = self.predictSingle(X[i])
  240. return pred
  241.  
  242. def predictSingle(self, x):
  243. if self.label != None:
  244. return self.label
  245. if x[self.splitRule[0]] > self.splitRule[1]:
  246. return self.right.predictSingle(x)
  247. else:
  248. return self.left.predictSingle(x)
  249.  
  250. def predictPath(self, x, features):
  251. if self.label != None:
  252. print("final label:", self.label)
  253. return
  254. if x[self.splitRule[0]] > self.splitRule[0]:
  255. print(features[self.splitRule[0]], ">", self.splitRule[1])
  256. self.right.predictPath(x, features)
  257. else:
  258. print(features[self.splitRule[0]], "<=", self.splitRule[1])
  259. self.left.predictPath(x, features)
  260.  
  261. def print(self, depth):
  262. if self.splitRule == None:
  263. print("depth:", depth, "label:", self.label)
  264. else:
  265. print("depth:", depth, "rule:", self.splitRule)
  266. if self.left != None:
  267. self.left.print(depth + 1)
  268. if self.right != None:
  269. self.right.print(depth + 1)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement