Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package com.nikhilgopal.spark;
- /**
- * Created by nikhilgopal on 3/24/17.
- */
/**
 * Fits a linear regression model with stochastic gradient descent (SGD).
 *
 * <p>Each training row holds the feature values followed by the target value in
 * its last column. The fitted coefficient vector stores the intercept at index 0
 * and the weight of feature {@code k} at index {@code k + 1}.
 */
public class SGDLinReg {

    /**
     * Trains on a small built-in dataset and prints the fitted intercept and
     * slope, separated by a space, after the per-epoch progress lines.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        double[][] dataset = {
            {1.0, 1.0},
            {2.0, 3.0},
            {4.0, 3.0},
            {3.0, 2.0},
            {5.0, 5.0}
        };
        double[] fitted = fitCoefficientsSgd(dataset, 0.001, 500);
        System.out.println(fitted[0] + " " + fitted[1]);
    }

    /**
     * Estimates linear regression coefficients by stochastic gradient descent.
     *
     * <p>All coefficients start at zero and are updated once per training row,
     * for {@code epochs} passes over {@code train}. One progress line with the
     * epoch index, learning rate and summed squared error is printed per epoch.
     *
     * @param train        training rows; the last column of each row is the target,
     *                     the preceding columns are the features
     * @param learningRate step size applied to every coefficient update
     * @param epochs       number of full passes over {@code train}
     * @return the coefficients: intercept first, then one weight per feature
     */
    private static double[] fitCoefficientsSgd(double[][] train, double learningRate, int epochs) {
        // One intercept plus one weight per feature column. With no rows there is
        // nothing to infer the width from, so keep the intercept + slope shape.
        int coefficientCount = (train.length == 0) ? 2 : train[0].length;
        double[] coefficients = new double[coefficientCount];

        for (int epoch = 0; epoch < epochs; epoch++) {
            double sumSquaredError = 0.0;
            for (double[] row : train) {
                int targetIndex = row.length - 1;
                // The prediction for this row uses the coefficients as they were
                // before any of this row's updates.
                double error = predict(row, coefficients) - row[targetIndex];
                sumSquaredError += error * error;

                // Intercept step: subtract the scaled error. (The gradient of the
                // squared error w.r.t. the intercept is proportional to the error.)
                coefficients[0] = coefficients[0] - learningRate * error;
                for (int k = 0; k < targetIndex; k++) {
                    coefficients[k + 1] = coefficients[k + 1] - learningRate * error * row[k];
                }
            }
            System.out.println("EPOCH " + epoch + " LRATE " + learningRate + " ERR " + sumSquaredError);
        }
        return coefficients;
    }

    /**
     * Computes the model output for one row.
     *
     * @param row  feature values; only the first {@code coef.length - 1} entries are read,
     *             so a trailing target column is ignored
     * @param coef coefficients with the intercept at index 0
     * @return the intercept plus the weighted sum of the features
     */
    private static double predict(double[] row, double[] coef) {
        double yhat = coef[0];
        for (int i = 0; i < coef.length - 1; i++) {
            yhat += coef[i + 1] * row[i];
        }
        return yhat;
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement