Advertisement
Guest User

Untitled

a guest
Jul 17th, 2019
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.67 KB | None | 0 0
  1. Point gradient_descent(double dx, double dy, double error, double gamma, unsigned int max_iters, double moment) {
  2.  
  3. double cur_error_x = error;
  4. double cur_error_y = error;
  5. double cur_error_moment_x = 0;
  6. double cur_error_moment_y = 0;
  7.  
  8. unsigned int iters = 0;
  9. double p_error_dx;
  10. double p_error_dy;
  11. double v_x = 0, v_y = 0;
  12. double v_x_pre = 0, v_y_pre = 0;
  13.  
  14. std::string savePath = "D:\GradientDescent.csv";
  15.  
  16. std::ofstream writeFile(savePath);
  17. writeFile << "dx,dy,E(dx;dy)n";
  18.  
  19. double dxMin = 1000, dyMin = 1000;
  20. do {
  21. p_error_dx = dx;
  22. p_error_dy = dy;
  23.  
  24. //NAG
  25. v_x = moment * v_x_pre + gamma * dfx(p_error_dx - moment * v_x_pre , p_error_dy - moment * v_x_pre);
  26. v_y = moment * v_y_pre + gamma * dfy(p_error_dx - moment * v_x_pre , p_error_dy - moment * v_x_pre);
  27. dx -= v_x;
  28. dy -= v_y;
  29.  
  30. cur_error_moment_x = (p_error_dx - dx);
  31. cur_error_moment_y = (p_error_dy - dy);
  32. cur_error_x = abs_val(p_error_dx - dx);
  33. cur_error_y = abs_val(p_error_dy - dy);
  34.  
  35. printf("ni= %i n", iters);
  36. printf("nc_error x= %fn", cur_error_x);
  37. printf("nc_error y= %fn", cur_error_y);
  38. printf("n==================================n");
  39. printf("ndx = %f , dy= %f n", dx, dy);
  40. printf("n E(dx,dy)= %fn", cost_func(dx, dy));
  41. printf("n==================================n");
  42. writeFile << dx << "," << dy << "," << cost_func(dx, dy) << std::endl;
  43. iters++;
  44. v_x_pre = v_x;
  45. v_y_pre = v_y;
  46. if (cost_func(dx, dy) < cost_func(dxMin, dyMin)) {
  47. dxMin = dx;
  48. dyMin = dy;
  49. }
  50. } while ((error < cur_error_x || error < cur_error_y) && iters < max_iters);
  51.  
  52. v_x = 0; v_y = 0;
  53. v_x_pre = 0; v_y_pre = 0;
  54.  
  55. writeFile.close();
  56. return Point(dxMin, dyMin);
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement