Advertisement
Guest User

Untitled

a guest
Mar 24th, 2017
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.50 KB | None | 0 0
  1. package com.nikhilgopal.spark;
  2.  
  3. /**
  4. * Created by nikhilgopal on 3/24/17.
  5. */
  6. public class SGDLinReg {
  7. public static void main(String[] args) {
  8. double[] coefficients = {0.4, 0.8};
  9. double[][] dataset = {
  10. {1.0, 1.0},
  11. {2.0, 3.0},
  12. {4.0, 3.0},
  13. {3.0, 2.0},
  14. {5.0, 5.0}
  15. };
  16. double[] newcoef = coef_sgd(dataset, 0.001, 500);
  17. System.out.println(newcoef[0] + " " + newcoef[1]);
  18. }
  19.  
  20. private static double[] coef_sgd(double[][] train, double l_rate, double n_epoch) {
  21. double[] coefficients = {0.0, 0.0};
  22. for (int e = 0; e < n_epoch; e++) {
  23. double sum_error = 0.0;
  24. for (int w = 0; w < train.length; w++) {
  25. double yhat = predict(train[w], coefficients);
  26. double error = yhat - train[w][1];
  27. sum_error += error*error;
  28. coefficients[0] = coefficients[0] = l_rate*error;
  29. for (int k = 0; k < (train[w].length-1); k++) {
  30. coefficients[k+1] = coefficients[k+1] - l_rate * error * train[w][k];
  31. }
  32. }
  33. System.out.println("EPOCH " + e + " LRATE " + l_rate + " ERR " + sum_error);
  34. }
  35. return coefficients;
  36. }
  37.  
  38. private static double predict(double[] row, double[] coef) {
  39. double yhat = coef[0];
  40. for (int i = 0; i < (coef.length-1); i++) {
  41. yhat += coef[i+1] * row[i];
  42. }
  43. return yhat;
  44. }
  45. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement