// Advertisement
// Guest User

// Untitled

// a guest
// Oct 19th, 2019
// 108
// 0
// Never
// Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
// text 5.72 KB | None | 0 0
  1. var dataB1 = [1, 1, 0];
  2. var dataB2 = [2, 1, 0];
  3. var dataB3 = [2, .5, 0];
  4. var dataB4 = [3, 1, 0];
  5.  
  6. var dataR1 = [3, 1.5, 1];
  7. var dataR2 = [3.5, .5, 1];
  8. var dataR3 = [4, 1.5, 1];
  9. var dataR4 = [5.5, 1, 1];
  10.  
  11. //unknown type (data we want to find)
  12. var dataU = [4.5, 1, "it should be 1"];
  13.  
  14. var all_points = [dataB1, dataB2, dataB3, dataB4, dataR1, dataR2, dataR3, dataR4];
  15.  
  16. function sigmoid(x) {
  17. return 1/(1+Math.exp(-x));
  18. }
  19.  
  20. // training
  21. function train() {
  22. let w1 = Math.random()*.2-.1;
  23. let w2 = Math.random()*.2-.1;
  24. let b = Math.random()*.2-.1;
  25. let learning_rate = 0.2;
  26. for (let iter = 0; iter < 50000; iter++) {
  27. // pick a random point
  28. let random_idx = Math.floor(Math.random() * all_points.length);
  29. let point = all_points[random_idx];
  30. let target = point[2]; // target stored in 3rd coord of points
  31.  
  32. // feed forward
  33. let z = w1 * point[0] + w2 * point[1] + b;
  34. let pred = sigmoid(z);
  35.  
  36. // now we compare the model prediction with the target
  37. let cost = (pred - target) ** 2;
  38.  
  39. // now we find the slope of the cost w.r.t. each parameter (w1, w2, b)
  40. // bring derivative through square function
  41. let dcost_dpred = 2 * (pred - target);
  42.  
  43. // bring derivative through sigmoid
  44. // derivative of sigmoid can be written using more sigmoids! d/dz sigmoid(z) = sigmoid(z)*(1-sigmoid(z))
  45. let dpred_dz = sigmoid(z) * (1-sigmoid(z));
  46.  
  47. // I think you forgot these in your slope calculation?
  48. let dz_dw1 = point[0];
  49. let dz_dw2 = point[1];
  50. let dz_db = 1;
  51.  
  52. // now we can get the partial derivatives using the chain rule
  53. // notice the pattern? We're bringing how the cost changes through each function, first through the square, then through the sigmoid
  54. // and finally whatever is multiplying our parameter of interest becomes the last part
  55. let dcost_dw1 = dcost_dpred * dpred_dz * dz_dw1;
  56. let dcost_dw2 = dcost_dpred * dpred_dz * dz_dw2;
  57. let dcost_db = dcost_dpred * dpred_dz * dz_db;
  58.  
  59. // now we update our parameters!
  60. w1 -= learning_rate * dcost_dw1;
  61. w2 -= learning_rate * dcost_dw2;
  62. b -= learning_rate * dcost_db;
  63. }
  64.  
  65. return {w1: w1, w2: w2, b: b};
  66. }
  67.  
  68. let canvas = document.createElement("canvas");
  69. canvas.width = 400;
  70. canvas.height = 400;
  71. document.body.appendChild(canvas);
  72. let ctx = canvas.getContext("2d");
  73. ctx.font = "Helvetica";
  74.  
  75. // map points from graph coordinates to the screen
  76. let graph_size = {width: 7, height: 7};
  77. function to_screen(x, y) {
  78. return {x: (x/graph_size.width)*canvas.width, y: -(y/graph_size.height)*canvas.height + canvas.height};
  79. }
  80.  
  81. // map points from screen coordinates to the graph
  82. function to_graph(x, y) {
  83. return {x: x/canvas.width*graph_size.width, y: graph_size.height - y/canvas.height*graph_size.height};
  84. }
  85.  
  86. // draw the graph's grid lines
  87. function draw_grid() {
  88. ctx.strokeStyle = "#AAAAAA";
  89. for (let j = 0; j <= graph_size.width; j++) {
  90.  
  91. // x lines
  92. ctx.beginPath();
  93. let p = to_screen(j, 0);
  94. ctx.moveTo(p.x, p.y);
  95. p = to_screen(j, graph_size.height);
  96. ctx.lineTo(p.x, p.y);
  97. ctx.stroke();
  98.  
  99. // y lines
  100. ctx.beginPath();
  101. p = to_screen(0, j);
  102. ctx.moveTo(p.x, p.y);
  103. p = to_screen(graph_size.width, j);
  104. ctx.lineTo(p.x, p.y);
  105. ctx.stroke();
  106. }
  107. }
  108.  
  109. // draw points
  110. function draw_points() {
  111. // unknown
  112. let p = to_screen(dataU[0], dataU[1]);
  113. ctx.fillStyle = "#555555";
  114. ctx.fillText("???", p.x-8, p.y-5);
  115. ctx.fillRect(p.x-2, p.y-2, 4, 4);
  116.  
  117. // draw points
  118. ctx.fillStyle = "#0000FF";
  119. for (let j = 0; j < all_points.length; j++) {
  120. let point = all_points[j];
  121. if (point[2] == 0) {
  122. ctx.fillStyle = "#0000FF";
  123. } else {
  124. ctx.fillStyle = "#FF0000";
  125. }
  126. p = to_screen(point[0], point[1]);
  127. ctx.fillRect(p.x-2, p.y-2, 4, 4);
  128. }
  129. }
  130.  
  131. // visualize model output on grid of points
  132. function visualize_params(params) {
  133. ctx.save();
  134. ctx.globalAlpha = 0.2;
  135. let step_size = .1;
  136. let box_size = canvas.width/(graph_size.width/step_size);
  137.  
  138. for (let xx = 0; xx < graph_size.width; xx += step_size) {
  139. for (let yy = 0; yy < graph_size.height; yy += step_size) {
  140. let model_out = sigmoid( xx * params.w1 + yy * params.w2 + params.b );
  141. if (model_out < .5) {
  142. // blue
  143. ctx.fillStyle = "#0000FF";
  144. } else {
  145. // red
  146. ctx.fillStyle = "#FF0000";
  147. }
  148. let p = to_screen(xx, yy);
  149. ctx.fillRect(p.x, p.y, box_size, box_size);
  150. }
  151. }
  152. ctx.restore();
  153. }
  154.  
  155. // find parameters
  156. var params = train();
  157.  
  158. // visualize model output
  159. ctx.clearRect(0, 0, canvas.width, canvas.height);
  160. draw_grid();
  161. draw_points();
  162. visualize_params(params);
  163.  
  164. // say what the model would say for a given mouse position
  165. window.onmousemove = function(evt) {
  166. ctx.clearRect(0, 0, 100, 50);
  167.  
  168. let p = {x: 10, y: 20};
  169.  
  170. let mouse = {x: evt.offsetX, y: evt.offsetY};
  171. let mouse_graph = to_graph(mouse.x, mouse.y);
  172.  
  173. ctx.fillText("x: " + Math.round(mouse_graph.x*100)/100, p.x, p.y);
  174. ctx.fillText("y: " + Math.round(mouse_graph.y*100)/100, p.x, p.y + 10);
  175. // model output
  176. let model_out = sigmoid( mouse_graph.x * params.w1 + mouse_graph.y * params.w2 + params.b );
  177. model_out = Math.round(model_out*100)/100;
  178. ctx.fillText("prediction: " + model_out, p.x, p.y + 20);
  179. }
// Advertisement
// Add Comment
// Please, Sign In to add comment
// Advertisement