Advertisement
Guest User

Untitled

a guest
Jul 20th, 2019
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.87 KB | None | 0 0
  1. const csv = require('csv/lib/sync');
  2. const fs = require('fs');
  3.  
  4. // Naive single decision tree implementation in javascript
  5. // Train.csv can be acquired in kaggle's blue book for bulldozers competition
  6. class DecisionTree {
  7. constructor(csvData, independentVariables, dependentVariable, options) {
  8. this.options = options;
  9. this.depth = 0;
  10. this.columns = csvData.shift();
  11. this.data = csvData;
  12. this.dependentVariableIndex = DecisionTree.getVariablesIndexes(this.columns, [dependentVariable]);
  13. this.independentVariablesIndexes = DecisionTree.getVariablesIndexes(this.columns, independentVariables);
  14.  
  15. if (options.sampleSize > -1 && this.data.length > options.sampleSize) {
  16. this.data = this.data.slice(0, options.sampleSize);
  17. }
  18. }
  19.  
  20. fit() {
  21. let depth = 0;
  22.  
  23. const getBestSplitForBranch = (subtreeIndexes) => {
  24. let bestSplit = {branchDiffusion: Infinity};
  25.  
  26. depth += 1;
  27.  
  28. this.independentVariablesIndexes.forEach(independentVariableIndex => {
  29. const split = this.getBestSplitForVariable(subtreeIndexes, independentVariableIndex);
  30. if (split.branchDiffusion < bestSplit.branchDiffusion) {
  31. bestSplit = split;
  32. bestSplit.splitVariable = this.columns[independentVariableIndex];
  33.  
  34. if (this.options.depth && depth <= this.options.depth) {
  35. bestSplit.leftBranch = getBestSplitForBranch(split.leftBranchIndexes);
  36. bestSplit.rightBranch = getBestSplitForBranch(split.rightBranchIndexes);
  37. }
  38.  
  39. bestSplit.sampleSize = subtreeIndexes.length;
  40.  
  41. delete bestSplit.leftBranchIndexes;
  42. delete bestSplit.rightBranchIndexes;
  43. }
  44. });
  45.  
  46. return bestSplit;
  47. };
  48.  
  49. return getBestSplitForBranch(Array(this.data.length).fill().map((_, i) => i));
  50. };
  51.  
  52. getBestSplitForVariable(subtreeIndexes, independentVariableIndex) {
  53. const possibleValues = this.getValuesFromColumn(subtreeIndexes, independentVariableIndex);
  54.  
  55. let split = {branchDiffusion: Infinity, splitValue: null, leftBranchIndexes: [], rightBranchIndexes: []};
  56.  
  57. possibleValues.forEach(splitValue => {
  58. const leftBranchIndexes = [];
  59. const rightBranchIndexes = [];
  60.  
  61. subtreeIndexes.forEach(index => {
  62. parseFloat(this.data[index][independentVariableIndex]) <= parseFloat(splitValue) ? leftBranchIndexes.push(index) : rightBranchIndexes.push(index);
  63. });
  64.  
  65. let branchDiffusion = leftBranchIndexes.length * DecisionTree.standardDeviation(
  66. leftBranchIndexes.map(index => Math.log(this.data[index][this.dependentVariableIndex])));
  67.  
  68.  
  69. branchDiffusion += rightBranchIndexes.length * DecisionTree.standardDeviation(
  70. rightBranchIndexes.map(index => Math.log(this.data[index][this.dependentVariableIndex])));
  71.  
  72.  
  73. if (branchDiffusion < split.branchDiffusion) {
  74. split = {branchDiffusion, leftBranchIndexes, rightBranchIndexes, splitValue};
  75. }
  76. });
  77.  
  78. return split;
  79. };
  80.  
  81. static getVariablesIndexes(columns, variables) {
  82. const indexes = [];
  83.  
  84. variables.forEach(column => {
  85. const columnIndex = columns.indexOf(column);
  86. if (columnIndex !== -1) {
  87. indexes.push(columnIndex)
  88. }
  89. });
  90.  
  91. return indexes;
  92. };
  93.  
  94. getValuesFromColumn(subtreeIndexes, columnIndex) {
  95. return subtreeIndexes.map(index => this.data[index][columnIndex]);
  96. };
  97.  
  98. static standardDeviation(values) {
  99. const mean = values.reduce((sum, value) => sum + parseFloat(value), 0) / values.length;
  100. const squaredMean = values.reduce((sum, value) => sum + Math.pow(value - mean, 2), 0) / values.length;
  101. return Math.sqrt(squaredMean);
  102. };
  103. }
  104.  
  105. const csvData = csv.parse(fs.readFileSync('Train.csv'));
  106.  
  107. const INDEPENDENT_VARIABLES = ['MachineHoursCurrentMeter', 'YearMade'];
  108. const DEPENDENT_VARIABLE = 'SalePrice';
  109.  
  110. console.log((new DecisionTree(csvData, INDEPENDENT_VARIABLES, DEPENDENT_VARIABLE, {sampleSize: 1000, depth: 1})).fit());
  111.  
  112. // Result:
  113. //
  114. // { branchDiffusion: 672.0239238184622,
  115. // splitValue: '2178',
  116. // splitVariable: 'MachineHoursCurrentMeter',
  117. // leftBranch:
  118. // { branchDiffusion: 299.86590963795163,
  119. // splitValue: '2003',
  120. // splitVariable: 'YearMade',
  121. // sampleSize: 470 },
  122. // rightBranch:
  123. // { branchDiffusion: 349.137024345939,
  124. // splitValue: '1997',
  125. // splitVariable: 'YearMade',
  126. // sampleSize: 530 },
  127. // sampleSize: 1000 }
  128.  
  129. // Jeremy's (fast.ai) Python implementation's result:
  130. // ml1 course – lesson3 notebook
  131.  
  132. // n: 1000; val:10.160352993311724; score:672.0239238184623; split:2178.0; var:MachineHoursCurrentMeter
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement