Advertisement
Guest User

Untitled

a guest
Jan 18th, 2020
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.58 KB | None | 0 0
  1. var _ = require('lodash');
  2.  
  3. /**
  4. * ID3 Decision Tree Algorithm
  5. * @module DecisionTreeID3
  6. */
  7.  
  8. module.exports = (function() {
  9.  
  /**
   * Map of valid tree node types.
   *
   * The same object is exposed as the static property
   * `DecisionTreeID3.NODE_TYPES`. It is frozen because `predict` decides when
   * to stop walking the tree by comparing node types against these values, so
   * they must not change after trees have been built.
   * @constant
   * @static
   */
  const NODE_TYPES = DecisionTreeID3.NODE_TYPES = Object.freeze({
    RESULT: 'result',
    FEATURE: 'feature',
    FEATURE_VALUE: 'feature_value'
  });
  20.  
  /**
   * Trained trees, keyed by the instance that owns them.
   *
   * A WeakMap keeps each tree private to this module while giving every
   * DecisionTreeID3 instance its own model; a single shared variable would be
   * overwritten whenever another instance is constructed.
   * @private
   */
  const models = new WeakMap();

  /**
   * Trains an ID3 decision tree on the given rows.
   *
   * @constructor
   * @param {Object[]} data - Training rows; each row is read by property name.
   * @param {string} target - Name of the property holding the class label.
   * @param {string[]} features - Names of the properties the tree may split on.
   * @return {DecisionTreeID3}
   */
  function DecisionTreeID3(data, target, features) {
    this.data = data;
    this.target = target;
    this.features = features;
    models.set(this, createTree(data, target, features));
  }

  /**
   * @public API
   *
   * Methods are added to the existing prototype (rather than replacing it) so
   * that `instance.constructor` keeps pointing at DecisionTreeID3.
   */
  Object.assign(DecisionTreeID3.prototype, {

    /**
     * Predicts the class for a sample by walking the trained tree.
     *
     * @param {Object} sample - Row whose feature properties are read by name.
     * @return {*} The `val` of the result node that was reached.
     */
    predict: function(sample) {
      let node = models.get(this);

      while (node.type !== NODE_TYPES.RESULT) {
        const sampleVal = sample[node.name];

        // Loose equality is kept on purpose of compatibility: it lets e.g. a
        // numeric string in the sample match a numeric training value.
        const branch = node.vals.find(function(candidate) {
          return candidate.name == sampleVal;
        });

        // A value never seen during training follows the first branch.
        node = (branch || node.vals[0]).child;
      }

      return node.val;
    },

    /**
     * Evaluates prediction accuracy on samples.
     *
     * @param {Object[]} samples - Rows that include the target property.
     * @return {number} Fraction of samples predicted correctly; NaN when
     *   there are no samples (0 / 0).
     */
    evaluate: function(samples) {
      const instance = this;
      const target = this.target;

      let total = 0;
      let correct = 0;

      _.each(samples, function(s) {
        total++;
        // Same loose comparison as in predict.
        if (instance.predict(s) == s[target]) {
          correct++;
        }
      });

      return correct / total;
    },

    /**
     * Returns the trained model for JSON serialisation.
     *
     * @return {Object} The root node of this instance's tree (the live
     *   object, not a copy).
     */
    toJSON: function() {
      return models.get(this);
    }
  });
  93.  
  /**
   * Builds a result (leaf) node that predicts a fixed class.
   * @private
   * @param {*} value - Class label stored on the node.
   * @return {Object} Node of type NODE_TYPES.RESULT.
   */
  function resultNode(value) {
    return {
      type: NODE_TYPES.RESULT,
      val: value,
      name: value,
      alias: value + randomUUID()
    };
  }

  /**
   * Creates a new tree by recursively splitting on the feature chosen by
   * maxGain.
   * @private
   * @param {Object[]} data - Rows to build the (sub)tree from.
   * @param {string} target - Name of the property holding the class label.
   * @param {string[]} features - Feature names still available for splitting.
   * @return {Object} A result node, or a feature node whose `vals` array
   *   holds one feature-value node (with a `child` subtree) per distinct
   *   value of the chosen feature.
   */
  function createTree(data, target, features) {
    const labels = data.map(function(row) {
      return row[target];
    });
    const distinctLabels = Array.from(new Set(labels));

    // Every row agrees on the label: nothing left to split.
    if (distinctLabels.length === 1) {
      return resultNode(distinctLabels[0]);
    }

    // No rows, or no features left to split on: fall back to the majority
    // label. The full label list (not the de-duplicated one) must be counted,
    // otherwise every label has frequency 1. With no rows this yields null.
    if (distinctLabels.length === 0 || features.length === 0) {
      return resultNode(mostCommon(labels));
    }

    const bestFeature = maxGain(data, target, features);
    const remainingFeatures = features.filter(function(feature) {
      return feature !== bestFeature;
    });
    const possibleValues = Array.from(new Set(data.map(function(row) {
      return row[bestFeature];
    })));

    return {
      name: bestFeature,
      alias: bestFeature + randomUUID(),
      type: NODE_TYPES.FEATURE,
      vals: possibleValues.map(function(value) {
        // Strict equality, matching how the distinct values were collected
        // and how gain partitions the rows.
        const subset = data.filter(function(row) {
          return row[bestFeature] === value;
        });

        return {
          name: value,
          alias: value + randomUUID(),
          type: NODE_TYPES.FEATURE_VALUE,
          child: createTree(subset, target, remainingFeatures)
        };
      })
    };
  }
  147.  
  /**
   * Computes the Shannon entropy (in bits) of a list of values.
   * @private
   * @param {Array} vals - Values whose distribution is measured.
   * @return {number} Sum of -p * log2(p) over the distinct values; 0 for an
   *   empty list.
   */
  function entropy(vals) {
    const total = vals.length;

    // Count occurrences in one pass. A Map keeps values of different types
    // (e.g. 1 and '1') apart.
    const counts = new Map();
    vals.forEach(function(value) {
      counts.set(value, (counts.get(value) || 0) + 1);
    });

    let bits = 0;
    counts.forEach(function(count) {
      const p = count / total;
      bits -= p * Math.log2(p);
    });

    return bits;
  }
  166.  
  /**
   * Computes the information gain of splitting `data` on `feature`: the
   * entropy of the target labels minus the size-weighted entropy of the
   * labels within each distinct feature value.
   * @private
   * @param {Object[]} data - Rows to evaluate.
   * @param {string} target - Name of the property holding the class label.
   * @param {string} feature - Name of the property to split on.
   * @return {number} Information gain in bits (0 for empty data).
   */
  function gain(data, target, feature) {
    const setSize = data.length;
    const allLabels = [];

    // Group the target labels by feature value in a single pass instead of
    // re-filtering the whole data set once per distinct value.
    const labelsByValue = new Map();
    data.forEach(function(row) {
      const label = row[target];
      const value = row[feature];
      allLabels.push(label);
      if (!labelsByValue.has(value)) {
        labelsByValue.set(value, []);
      }
      labelsByValue.get(value).push(label);
    });

    let sumOfEntropies = 0;
    labelsByValue.forEach(function(subsetLabels) {
      sumOfEntropies += (subsetLabels.length / setSize) * entropy(subsetLabels);
    });

    return entropy(allLabels) - sumOfEntropies;
  }
  190.  
  /**
   * Computes the max gain across features to determine the best split.
   *
   * Implemented as an explicit scan rather than `_.max(features, iteratee)`,
   * so the result does not depend on whether the installed lodash honours an
   * iteratee argument to `_.max`.
   * @private
   * @param {Object[]} data - Rows to evaluate.
   * @param {string} target - Name of the property holding the class label.
   * @param {string[]} features - Candidate feature names.
   * @return {string|undefined} The feature with the highest gain (the
   *   earliest one on ties), or undefined when `features` is empty.
   */
  function maxGain(data, target, features) {
    let bestFeature = features[0];
    let bestScore = -Infinity;

    features.forEach(function(feature) {
      const score = gain(data, target, feature);
      // Strict ">" keeps the earliest feature when scores tie.
      if (score > bestScore) {
        bestFeature = feature;
        bestScore = score;
      }
    });

    return bestFeature;
  }
  200.  
  /**
   * Computes the probability of a given value occurring in a given list,
   * i.e. the fraction of elements strictly equal to `value`.
   * @private
   * @param {*} value - Value to look for.
   * @param {Array} list - List to search.
   * @return {number} Occurrences divided by list length; 0 for an empty list
   *   (avoids the NaN that 0 / 0 would produce).
   */
  function prob(value, list) {
    const numElements = list.length;
    if (numElements === 0) {
      return 0;
    }

    let numOccurrences = 0;
    for (const element of list) {
      if (element === value) {
        numOccurrences += 1;
      }
    }

    return numOccurrences / numElements;
  }
  214.  
  /**
   * Computes the base-2 logarithm.
   *
   * Delegates to the built-in Math.log2 instead of dividing two natural
   * logarithms.
   * @private
   * @param {number} n - Input value.
   * @return {number} log2(n); -Infinity for 0 and NaN for negative input.
   */
  function log2(n) {
    return Math.log2(n);
  }
  222.  
  /**
   * Finds the element with the highest occurrence in a list.
   *
   * Frequencies are kept in a Map rather than a plain object so that values
   * of different types (1 vs '1') are counted separately and labels that
   * happen to match inherited object keys (e.g. 'toString') are counted
   * correctly.
   * @private
   * @param {Array} list - Elements to count.
   * @return {*} The most frequent element (the first to reach the highest
   *   count on ties), or null for an empty list.
   */
  function mostCommon(list) {
    const frequencies = new Map();
    let largestFrequency = 0;
    let mostCommonElement = null;

    list.forEach(function(element) {
      const frequency = (frequencies.get(element) || 0) + 1;
      frequencies.set(element, frequency);

      if (frequency > largestFrequency) {
        mostCommonElement = element;
        largestFrequency = frequency;
      }
    });

    return mostCommonElement;
  }
  244.  
  /**
   * Generates a random identifier suffix.
   *
   * The result is "_r" followed by the base-32 fraction digits of a
   * Math.random() value; it is not an RFC 4122 UUID.
   * @private
   * @return {string} Random string starting with "_r".
   */
  function randomUUID() {
    const base32 = Math.random().toString(32);
    // Drop the leading "0." so only the fraction digits remain.
    const fractionDigits = base32.substring(2);
    return "_r" + fractionDigits;
  }
  252.  
  253. /**
  254. * @class DecisionTreeID3
  255. */
  256. return DecisionTreeID3;
  257. })();
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement