Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package com.rbennett;
- import smile.data.DataFrame;
- import smile.data.formula.Formula;
- import smile.data.type.StructType;
- import smile.data.vector.DoubleVector;
- import smile.regression.LinearModel;
- import smile.regression.OLS;
- import tech.tablesaw.api.NumericColumn;
- import tech.tablesaw.api.Table;
- import tech.tablesaw.selection.Selection;
- import java.util.Arrays;
- public class HouseLinearRegression {
- public static void main(String[] args) throws Exception {
- Table allHousingPrices = Table.read().csv("datasets/HousePricesAll.csv").dropRowsWithMissingValues();
- // Split dependent and independent variables
- //NumericColumn<?> y = allHousingPrices.numberColumn("SalePrices"); // dependent
- //Table X = (Table)allHousingPrices.removeColumns("SalePrices"); // independent
- // Reload original table ( because the above appears to have a side effect)
- allHousingPrices = Table.read().csv("datasets/HousePricesAll.csv").dropRowsWithMissingValues();
- //Table X_train = X.where(Selection.withRange(1, 727));
- //Table X_test = X.where(Selection.withRange(728, 1127));
- //NumericColumn<?> y_train = y.where(Selection.withRange(1, 727));
- //NumericColumn<?> y_test = y.where(Selection.withRange(728, 1127));
- Table allHousingPrices_train = allHousingPrices.where(Selection.withRange(1, 727));
- Table allHousingPrices_test = allHousingPrices.where(Selection.withRange(728, 1127));
- //double[][] X_train_arr = X_train.as().doubleMatrix();
- //double[] y_train_arr = y_train.asDoubleArray();
- DataFrame dataframe = DataFrame.of(allHousingPrices_train.as().doubleMatrix(), allHousingPrices_train.columnNames().toArray(new String[0]));
- Formula formula = Formula.lhs("SalePrices");
- LinearModel model = OLS.fit(formula, dataframe); // new OLS(X_train_arr, y_train_arr);
- System.out.println(model);
- // Use the trained regression model to predict housing prices
- double[] values = {20, 64, 7406, 7, 5, 2006, 2006, 684, 515, 1199, 1220, 1220, 2, 2, 6, 2006, 2, 632, 54, 2006};
- double predictedValue = model.predict(values);
- System.out.println("The predicted value is " + predictedValue);
- System.out.println("The actual value is 194000");
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement