Guest User

Untitled

a guest
Jan 23rd, 2019
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.51 KB | None | 0 0
  1. /** decision_tree_train - a stored procedure that trains a decision tree, and stores the model in the 'ml_model_runs' table
  2. * Parameters:
  3. * TABLE_NAME - the name of the table containing the training data
  4. * TARGET - the name of the column containing the target variable (the one to predict)
  5. * COLS - a comma separated list of the table columns to include as variables in the model
  6. * TRAINING_PARAMS - an object containing training parameters, which can be:
  7. * cv_limit (default 10) - Coefficient of Variation limit (stddev/avg*100); branching stops when a node's value is below this limit
  8. * total_count_limit (default 1) - Total record count limit; branching stops when a node has this many records or fewer
  9. * cv_decimal_places (default 5) - The number of decimal places to round the Coefficient of Variation calculation to
  10. * average_decimal_places (default 2) - The number of decimal places to round the average calculation to (where multiple records exist at a leaf)
  11. * maxDepth (default 15) - the maximum depth of the tree
  12. * maxFeatures (default 8) - the maximum number of features to evaluate at a time
  13. * debugMessages (default false) - set to true to include extra information in the output model about the state of each node
  14. **/
  15. create or replace procedure decision_tree_train(TABLE_NAME VARCHAR, TARGET VARCHAR, COLS VARCHAR,TRAINING_PARAMS VARIANT)
  16. returns string not null
  17. language javascript
  18. as
  19. $$
  /**
   * leafCalc - recursively builds one node of the decision tree.
   *
   * Queries the records selected by whereClause, decides whether this node is a
   * leaf (stop condition met) or an internal node, and for internal nodes picks
   * the column with the highest standard deviation reduction and recurses once
   * per distinct value of that column.
   *
   * NOTE(review): table and column identifiers are concatenated into the SQL
   * text (identifiers cannot be supplied as binds), so callers must pass
   * trusted names only.
   *
   * @param {string} tableName - table holding the training data
   * @param {string} whereClause - SQL predicate selecting the records that reach
   *   this node; uses numbered placeholders (:1, :2, ...)
   * @param {Array} whereClauseBindings - values for the placeholders in
   *   whereClause; this array is never modified
   * @param {string} target - column containing the value to predict
   * @param {string[]} remainingCols - columns not yet used for a split on this path
   * @param {number} depth - depth of this node (the root is 0)
   * @param {Object} trainingParameters - cv_limit, total_count_limit,
   *   cv_decimal_places, average_decimal_places, maxDepth, maxFeatures, debugMessages
   * @returns {Object|null} null when the target average for the selected records
   *   is null (the caller drops such a branch); a leaf object carrying
   *   `prediction` (a string produced by toFixed); or an internal node carrying
   *   `children`
   * @throws {string} when an internal consistency check fails
   */
  function leafCalc(tableName, whereClause, whereClauseBindings, target, remainingCols, depth, trainingParameters) {
    const debug = Boolean(trainingParameters.debugMessages);
    const node = {};
    if (debug) {
      node.cumulative_where_clause = whereClause;
      // Store a snapshot so the debug output can never alias an array that is
      // later changed elsewhere.
      node.cumulative_where_clause_bindings = whereClauseBindings.slice();
    }

    // Finalises this node as a leaf; the stop reason is only recorded in debug mode.
    const asLeaf = function (prediction, stoppedOn) {
      node.prediction = prediction;
      if (debug) {
        node.stopped_on = stoppedOn;
      }
      return node;
    };

    let results = snowflake.execute({
      sqlText: "select stddev(" + target + ") as target_stddev," +
        "avg(" + target + ") as target_avg," +
        "case when target_avg is not null and target_avg!=0 then target_stddev/target_avg*100 else 0 end as coef_of_variation," +
        "count(*) as target_count " +
        "from " + tableName + " where " + whereClause,
      binds: whereClauseBindings
    });
    results.next();
    const stddevBeforeSplit = results.getColumnValue(1);
    const rawAverage = results.getColumnValue(2);
    const rawCoefficient = results.getColumnValue(3);
    const countBeforeSplit = results.getColumnValue(4);

    if (rawAverage == null) {
      // Nothing to average below this node: return null so the caller can remove it.
      return null;
    }
    const averageBelow = rawAverage.toFixed(trainingParameters.average_decimal_places);

    if (depth >= trainingParameters.maxDepth) {
      return asLeaf(averageBelow, "max_depth_reached (limit " + trainingParameters.maxDepth + ", value " + depth + ")");
    }
    if (remainingCols.length < 1) {
      return asLeaf(averageBelow, "last_attribute");
    }
    if (countBeforeSplit <= trainingParameters.total_count_limit) {
      return asLeaf(averageBelow, "below_child_record_count_limit (limit " + trainingParameters.total_count_limit + ", value " + countBeforeSplit + ")");
    }

    // The coefficient comes back null when the standard deviation is null
    // (e.g. only one non-null target value); treat that as "no variation"
    // instead of failing on null.toFixed().
    const coefficientOfVariation = (rawCoefficient == null ? 0 : rawCoefficient)
      .toFixed(trainingParameters.cv_decimal_places);
    if (Number(coefficientOfVariation) < trainingParameters.cv_limit) {
      return asLeaf(averageBelow, "below_cv_threshold (limit " + trainingParameters.cv_limit + ", value " + coefficientOfVariation + ")");
    }

    if (countBeforeSplit === 0) {
      throw "The number of records during leaf node calculation was zero, this should not happen and means there's a bug in the stored proc";
    }
    if (stddevBeforeSplit === 0) {
      throw "The standard deviation during leaf node calculation was zero, this should not happen and means there's a bug in the stored proc";
    }

    // One sub-query per candidate column (capped at maxFeatures): for every
    // value of the column, the branch's share of records times its stddev.
    const columnQueries = [];
    for (let i = 0; i < remainingCols.length && i < trainingParameters.maxFeatures; i++) {
      const col = remainingCols[i];
      columnQueries.push("select '" + col + "' as col," +
        col + " as column_value," +
        "stddev(" + target + ") as sd_branch, " +
        "count(" + col + ") as count_branch, " +
        "count(" + col + ")/" + countBeforeSplit + "*stddev(" + target + ") as p_times_s " +
        "from " + tableName + " where " + whereClause + " group by " + col);
    }
    if (columnQueries.length === 0) {
      throw "No subqueries were generated, this should not happen and means there's a bug in the stored proc";
    }

    // Standard deviation reduction per column; the first row is the best split.
    results = snowflake.execute({
      sqlText: "select col," + stddevBeforeSplit + "-sum(p_times_s) as sdr from (" +
        columnQueries.join(" union ") +
        ") group by col order by sdr desc",
      binds: whereClauseBindings
    });
    results.next();
    const selectedCol = results.getColumnValue(1);
    const withSelectedColRemoved = remainingCols.filter(function (value) {
      return value != selectedCol;
    });

    // Only enumerate values that occur among this node's records; values that
    // exist elsewhere in the table would produce empty branches that get
    // discarded anyway.
    results = snowflake.execute({
      sqlText: "select distinct(" + selectedCol + ") from " + tableName + " where " + whereClause,
      binds: whereClauseBindings
    });

    const thisNode = {};
    if (debug) {
      thisNode.nextAttribute = selectedCol;
      thisNode.coefficientOfVariation = coefficientOfVariation;
    }
    thisNode.children = [];
    while (results.next()) {
      const columnValue = results.getColumnValue(1);
      // Each child gets its own copy of the bindings so siblings and deeper
      // levels never see each other's values, and placeholder numbers stay
      // contiguous.
      const childBindings = whereClauseBindings.concat([columnValue]);
      const childWhereClause = whereClause + " and " + selectedCol + "= :" + childBindings.length;
      const branchesBelow = leafCalc(tableName, childWhereClause, childBindings, target, withSelectedColRemoved, depth + 1, trainingParameters);
      if (branchesBelow != null) {
        branchesBelow.selectionCriteriaAttribute = selectedCol;
        branchesBelow.selectionCriteriaPredicate = '=';
        branchesBelow.selectionCriteriaValue = columnValue;
        thisNode.children.push(branchesBelow);
      }
    }
    return thisNode;
  }
  128. var columns=COLS.split(',');
  129. var results = snowflake.execute({
  130. sqlText: "select ml_model_runs_sequence.nextval"
  131. });
  132. results.next();
  133. var default_training_parameters={};
  134. default_training_parameters.cv_limit=10;
  135. default_training_parameters.total_count_limit=1;
  136. default_training_parameters.cv_decimal_places=5;
  137. default_training_parameters.average_decimal_places=2;
  138. default_training_parameters.maxDepth=15;
  139. default_training_parameters.maxFeatures=8;
  140. default_training_parameters.debugMessages=false;
  141.  
  142. var training_parameters={...default_training_parameters,...TRAINING_PARAMS};
  143.  
  144. var runId=results.getColumnValue(1);
  145. results = snowflake.execute({
  146. sqlText: "insert into ml_model_runs(run_id,table_name,algorithm,training_parameters,start_time) select :1,:2,:3,parse_json('"+JSON.stringify(training_parameters)+"'),current_timestamp::TIMESTAMP_NTZ",
  147. binds: [runId, TABLE_NAME,'decision_tree']
  148. });
  149.  
  150. var model=leafCalc(TABLE_NAME,'1=1',[],TARGET,columns,0,training_parameters);
  151. results = snowflake.execute({
  152. sqlText: "update ml_model_runs set end_time=current_timestamp::TIMESTAMP_NTZ, model_object=parse_json('"+JSON.stringify(model)+"') where run_id=?",
  153. binds: [runId]
  154. });
  155. return runId;
  156. $$
  157. ;
Add Comment
Please, Sign In to add comment