Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- const csv = require('csv/lib/sync');
- const fs = require('fs');
- // Naive single decision tree implementation in JavaScript.
- // Train.csv can be obtained from Kaggle's "Blue Book for Bulldozers" competition.
/**
 * Naive single regression tree.
 *
 * A split is scored by its "branch diffusion": for each of the two branches,
 * the number of rows times the standard deviation of the natural log of the
 * dependent variable, summed over both branches. Lower is better.
 */
class DecisionTree {
  /**
   * @param {Array<Array<*>>} csvData - Parsed rows; the first row holds the column names.
   *   The array is not modified.
   * @param {string[]} independentVariables - Names of the columns that may be split on.
   *   Names that are not present in the header row are ignored.
   * @param {string} dependentVariable - Name of the column to predict.
   * @param {{sampleSize?: number, depth?: number}} [options] - `sampleSize` keeps only the
   *   first N data rows when it is greater than -1; `depth` is the number of tree levels
   *   whose splits get child branches (falsy means the root split has no children).
   * @throws {Error} If `csvData` has no header row or the dependent column is missing.
   */
  constructor(csvData, independentVariables, dependentVariable, options = {}) {
    if (!Array.isArray(csvData) || csvData.length === 0) {
      throw new Error('csvData must be a non-empty array whose first row is the header');
    }
    this.options = options;
    // Kept for compatibility with existing readers of the instance; fit() does not use it.
    this.depth = 0;
    // Read the header by index instead of shift() so the caller's array stays intact.
    this.columns = csvData[0];
    this.data = csvData.slice(1);

    const dependentIndexes = DecisionTree.getVariablesIndexes(this.columns, [dependentVariable]);
    if (dependentIndexes.length === 0) {
      throw new Error(`Dependent variable "${dependentVariable}" not found in the header row`);
    }
    // A plain column number, not a one-element array.
    this.dependentVariableIndex = dependentIndexes[0];
    this.independentVariablesIndexes = DecisionTree.getVariablesIndexes(this.columns, independentVariables);

    if (options.sampleSize > -1 && this.data.length > options.sampleSize) {
      this.data = this.data.slice(0, options.sampleSize);
    }
  }

  /**
   * Builds the tree over all rows held by the instance.
   *
   * @returns {Object} The root node: `{branchDiffusion, splitValue, splitVariable,
   *   leftBranch?, rightBranch?, sampleSize}`. A node for which no valid split exists
   *   is `{branchDiffusion: Infinity}`.
   */
  fit() {
    const maxDepth = this.options.depth;

    // `level` is the real depth of the node being built (root = 1). It is passed down
    // the recursion rather than shared, so sibling subtrees cannot exhaust each
    // other's depth budget.
    const growNode = (subtreeIndexes, level) => {
      let bestSplit = {branchDiffusion: Infinity};
      let bestVariableIndex = null;
      this.independentVariablesIndexes.forEach(independentVariableIndex => {
        const split = this.getBestSplitForVariable(subtreeIndexes, independentVariableIndex);
        if (split.branchDiffusion < bestSplit.branchDiffusion) {
          bestSplit = split;
          bestVariableIndex = independentVariableIndex;
        }
      });
      if (bestVariableIndex === null) {
        return {branchDiffusion: Infinity};
      }

      const node = {
        branchDiffusion: bestSplit.branchDiffusion,
        splitValue: bestSplit.splitValue,
        splitVariable: this.columns[bestVariableIndex],
      };
      // Recurse only once the winning variable is known, so no subtree is built
      // for a candidate that is later discarded.
      if (maxDepth && level <= maxDepth) {
        node.leftBranch = growNode(bestSplit.leftBranchIndexes, level + 1);
        node.rightBranch = growNode(bestSplit.rightBranchIndexes, level + 1);
      }
      node.sampleSize = subtreeIndexes.length;
      return node;
    };

    return growNode(Array.from({length: this.data.length}, (_, i) => i), 1);
  }

  /**
   * Finds the threshold on one column that minimises branch diffusion.
   *
   * Rows whose parsed value is <= the parsed threshold go left, all others
   * (including values that parse to NaN) go right.
   *
   * @param {number[]} subtreeIndexes - Indexes into `this.data` of the rows to split.
   * @param {number} independentVariableIndex - Column to split on.
   * @returns {{branchDiffusion: number, splitValue: *, leftBranchIndexes: number[],
   *   rightBranchIndexes: number[]}} The best split; `branchDiffusion` is `Infinity`
   *   and `splitValue` is `null` when no threshold separates the rows.
   */
  getBestSplitForVariable(subtreeIndexes, independentVariableIndex) {
    // Duplicates cannot change the winner (ties keep the first occurrence), so each
    // distinct value is tried once; a Set iterates in first-insertion order.
    const candidateValues = new Set(this.getValuesFromColumn(subtreeIndexes, independentVariableIndex));
    const logTarget = index => Math.log(this.data[index][this.dependentVariableIndex]);
    let split = {branchDiffusion: Infinity, splitValue: null, leftBranchIndexes: [], rightBranchIndexes: []};

    candidateValues.forEach(splitValue => {
      const threshold = parseFloat(splitValue);
      const leftBranchIndexes = [];
      const rightBranchIndexes = [];
      subtreeIndexes.forEach(index => {
        if (parseFloat(this.data[index][independentVariableIndex]) <= threshold) {
          leftBranchIndexes.push(index);
        } else {
          rightBranchIndexes.push(index);
        }
      });
      // A threshold that leaves one side empty is not a split.
      if (leftBranchIndexes.length === 0 || rightBranchIndexes.length === 0) {
        return;
      }
      const branchDiffusion =
        leftBranchIndexes.length * DecisionTree.standardDeviation(leftBranchIndexes.map(logTarget)) +
        rightBranchIndexes.length * DecisionTree.standardDeviation(rightBranchIndexes.map(logTarget));
      if (branchDiffusion < split.branchDiffusion) {
        split = {branchDiffusion, leftBranchIndexes, rightBranchIndexes, splitValue};
      }
    });
    return split;
  }

  /**
   * Maps column names to their positions in the header row.
   *
   * @param {string[]} columns - Header row.
   * @param {string[]} variables - Column names to look up.
   * @returns {number[]} Positions in the order of `variables`; unknown names are skipped.
   */
  static getVariablesIndexes(columns, variables) {
    return variables
      .map(name => columns.indexOf(name))
      .filter(columnIndex => columnIndex !== -1);
  }

  /**
   * @param {number[]} subtreeIndexes - Indexes into `this.data`.
   * @param {number} columnIndex - Column to read.
   * @returns {Array<*>} The raw cell values of that column, one per requested row.
   */
  getValuesFromColumn(subtreeIndexes, columnIndex) {
    return subtreeIndexes.map(index => this.data[index][columnIndex]);
  }

  /**
   * Population standard deviation (divides by N, not N - 1).
   *
   * @param {Array<number|string>} values - Numbers or numeric strings; each is read with parseFloat.
   * @returns {number} The standard deviation, or NaN for an empty array.
   */
  static standardDeviation(values) {
    // Parse once so the mean and the deviations are computed from the same numbers.
    const numbers = values.map(value => parseFloat(value));
    const mean = numbers.reduce((sum, value) => sum + value, 0) / numbers.length;
    const variance = numbers.reduce((sum, value) => sum + Math.pow(value - mean, 2), 0) / numbers.length;
    return Math.sqrt(variance);
  }
}
// Columns used for the demo run: two predictors and the price to be modelled.
const INDEPENDENT_VARIABLES = ['MachineHoursCurrentMeter', 'YearMade'];
const DEPENDENT_VARIABLE = 'SalePrice';

// Load and parse the training set; the first parsed row is the header.
const csvData = csv.parse(fs.readFileSync('Train.csv'));

// Fit on the first 1000 rows with child branches for the root split only, then print the tree.
const tree = new DecisionTree(
  csvData,
  INDEPENDENT_VARIABLES,
  DEPENDENT_VARIABLE,
  {sampleSize: 1000, depth: 1}
);
console.log(tree.fit());
- // Result:
- //
- // { branchDiffusion: 672.0239238184622,
- // splitValue: '2178',
- // splitVariable: 'MachineHoursCurrentMeter',
- // leftBranch:
- // { branchDiffusion: 299.86590963795163,
- // splitValue: '2003',
- // splitVariable: 'YearMade',
- // sampleSize: 470 },
- // rightBranch:
- // { branchDiffusion: 349.137024345939,
- // splitValue: '1997',
- // splitVariable: 'YearMade',
- // sampleSize: 530 },
- // sampleSize: 1000 }
- // Jeremy's (fast.ai) Python implementation's result:
- // ml1 course – lesson3 notebook
- // n: 1000; val:10.160352993311724; score:672.0239238184623; split:2178.0; var:MachineHoursCurrentMeter
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement