Guest User

Untitled

a guest
Jan 5th, 2021
72
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. package com.rbennett;
  2.  
  3. import smile.data.DataFrame;
  4. import smile.data.formula.Formula;
  5. import smile.data.type.StructType;
  6. import smile.data.vector.DoubleVector;
  7. import smile.regression.LinearModel;
  8. import smile.regression.OLS;
  9. import tech.tablesaw.api.NumericColumn;
  10. import tech.tablesaw.api.Table;
  11. import tech.tablesaw.selection.Selection;
  12.  
  13. import java.util.Arrays;
  14.  
  15. public class HouseLinearRegression {
  16.  
  17. public static void main(String[] args) throws Exception {
  18. Table allHousingPrices = Table.read().csv("datasets/HousePricesAll.csv").dropRowsWithMissingValues();
  19.  
  20. // Split dependent and independent variables
  21. //NumericColumn<?> y = allHousingPrices.numberColumn("SalePrices"); // dependent
  22. //Table X = (Table)allHousingPrices.removeColumns("SalePrices"); // independent
  23.  
  24. // Reload original table ( because the above appears to have a side effect)
  25. allHousingPrices = Table.read().csv("datasets/HousePricesAll.csv").dropRowsWithMissingValues();
  26.  
  27. //Table X_train = X.where(Selection.withRange(1, 727));
  28. //Table X_test = X.where(Selection.withRange(728, 1127));
  29.  
  30. //NumericColumn<?> y_train = y.where(Selection.withRange(1, 727));
  31. //NumericColumn<?> y_test = y.where(Selection.withRange(728, 1127));
  32.  
  33. Table allHousingPrices_train = allHousingPrices.where(Selection.withRange(1, 727));
  34. Table allHousingPrices_test = allHousingPrices.where(Selection.withRange(728, 1127));
  35.  
  36. //double[][] X_train_arr = X_train.as().doubleMatrix();
  37. //double[] y_train_arr = y_train.asDoubleArray();
  38.  
  39. DataFrame dataframe = DataFrame.of(allHousingPrices_train.as().doubleMatrix(), allHousingPrices_train.columnNames().toArray(new String[0]));
  40.  
  41. Formula formula = Formula.lhs("SalePrices");
  42.  
  43. LinearModel model = OLS.fit(formula, dataframe); // new OLS(X_train_arr, y_train_arr);
  44.  
  45. System.out.println(model);
  46.  
  47. // Use the trained regression model to predict housing prices
  48. double[] values = {20, 64, 7406, 7, 5, 2006, 2006, 684, 515, 1199, 1220, 1220, 2, 2, 6, 2006, 2, 632, 54, 2006};
  49. double predictedValue = model.predict(values);
  50. System.out.println("The predicted value is " + predictedValue);
  51. System.out.println("The actual value is 194000");
  52.  
  53. }
  54. }
  55.  
RAW Paste Data