Guest User

HousePrices-SimpleNetwork

a guest
Jan 23rd, 2019
326
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. namespace HousePrices {
  2.     using System;
  3.     using System.Collections.Generic;
  4.     using System.Data;
  5.     using System.Diagnostics;
  6.     using System.IO;
  7.     using System.Linq;
  8.  
  9.     using CsvHelper;
  10.  
  11.     using numpy;
  12.     using SharPy.Runtime;
  13.     using tensorflow;
  14.     using tensorflow.keras;
  15.     using tensorflow.keras.callbacks;
  16.     using tensorflow.keras.layers;
  17.     using tensorflow.train;
  18.  
  19.     static class HousePricesProgram {
  20.         static void Main() {
  21.             const string TrainFile = "train.csv";
  22.             DataTable trainData = LoadData(TrainFile);
  23.             trainData.Columns.Remove("Id");
  24.             var columns = trainData.Columns.Cast<DataColumn>();
  25.             var types = columns.Select(column => column.DataType).Counts();
  26.             var rows = trainData.Rows.Cast<DataRow>();
  27.             var (trainRows, testRows) = rows.RandomSplit(primaryChance: 1);
  28.             var columnTypes = columns.Select(column => {
  29.                 var values = rows.Select(row => (string)row[column]);
  30.                 var trainValues = trainRows.Select(row => (string)row[column]);
  31.                 var testValues = testRows.Select(row => (string)row[column]);
  32.                 double floats = values.Percentage(v => double.TryParse(v, out _));
  33.                 double ints = values.Percentage(v => int.TryParse(v, out _));
  34.                 int distincts = values.Distinct().Count();
  35.                 var normalizer = ValueNormalizer(floats, values);
  36.                 return new { column, values, distincts, ints, floats, normalizer, trainValues, testValues };
  37.             }).OrderBy(c => c.distincts).ThenBy(c => c.column.ColumnName)
  38.             .ToArray();
  39.  
  40.             const string predict = "SalePrice";
  41.  
  42.             ndarray GetInputs(IEnumerable<DataRow> rowSeq) {
  43.                 return np.array(rowSeq.Select(row => np.array(
  44.                     columnTypes.Where(c => c.column.ColumnName != predict)
  45.                     .SelectMany(column => column.normalizer(row.Table.Columns.Contains(column.column.ColumnName) ? (string)row[column.column.ColumnName] : "-1")).ToArray()))
  46.                 .ToArray());
  47.             }
  48.  
  49.             var predictColumn = columnTypes.Single(c => c.column.ColumnName == predict);
  50.             ndarray trainOutputs = np.array(predictColumn.trainValues.AsDouble().Select(v => v ?? -1).ToArray());
  51.             ndarray trainInputs = GetInputs(trainRows);
  52.  
  53.             //ndarray testOutputs = np.array(predictColumn.testValues.AsDouble().Select(v => v ?? -1).ToArray());
  54.             //ndarray testInputs = GetInputs(testRows);
  55.  
  56.             //Debug.Assert(testOutputs.Length == testInputs.Length);
  57.             //Debug.Assert(testOutputs.Length > 20);
  58.             Debug.Assert(trainOutputs.Length == trainInputs.Length);
  59.             Debug.Assert(trainOutputs.Length > 20);
  60.             //Debug.Assert(trainOutputs.Length != testOutputs.Length);
  61.  
  62.             var model = new Sequential(new Layer[] {
  63.                 new Dense(units: 16, activation: tf.nn.relu_fn),
  64.                 new Dropout(rate: 0.1),
  65.                 new Dense(units: 10, activation: tf.nn.relu_fn),
  66.                 new Dense(units: 1, activation: tf.nn.relu_fn),
  67.             });
  68.  
  69.             model.compile(optimizer: new AdamOptimizer(), loss: "mean_squared_error");
  70.  
  71.             //var tensorboard = new TensorBoard(log_dir: $"./logs/{DateTime.Now.ToString("s").Replace(':','-')}");
  72.  
  73.             model.fit(trainInputs, trainOutputs, epochs: 20000, validation_split: 0.075, verbose: 2);
  74.  
  75.             const string SubmissionInputFile = "test.csv";
  76.             DataTable submissionData = LoadData(SubmissionInputFile);
  77.             var submissionRows = submissionData.Rows.Cast<DataRow>();
  78.             ndarray submissionInputs = GetInputs(submissionRows);
  79.             ndarray sumissionOutputs = model.predict(submissionInputs);
  80.             Console.WriteLine("guesses:");
  81.             var random = new Random();
  82.             using (var writer = new StreamWriter("submit.csv")) {
  83.                 writer.WriteLine("Id,SalePrice");
  84.                 foreach (var (id, prediction) in submissionRows.Select(row => int.Parse((string)row["Id"]))
  85.                                                                .Pair(sumissionOutputs.Cast<ndarray>())) {
  86.                     string guess = $"{id},{prediction[0]}";
  87.                     writer.WriteLine(guess);
  88.                     if (random.Next(100) > 99)
  89.                         Console.WriteLine(guess);
  90.                 }
  91.                 writer.Flush();
  92.             }
  93.  
  94.             //float64 trainLoss = model.evaluate(trainInputs, trainOutputs);
  95.             //float64 testLoss = model.evaluate(testInputs, testOutputs);
  96.             //Console.WriteLine($"Test loss: {(int)Math.Sqrt(testLoss)}; Train loss: {(int)Math.Sqrt(trainLoss)}");
  97.             //Console.WriteLine();
  98.  
  99.             //foreach (var column in columnTypes)
  100.             //    Console.WriteLine($"{column.column.ColumnName}: {column.distincts} values, ints: {column.ints:P2}, floats: {column.floats:P2}");
  101.             //Console.WriteLine();
  102.  
  103.             //Console.WriteLine("Many value columns:");
  104.             //foreach (var column in columnTypes.Where(ct => ct.distincts > 10 && ct.floats < 0.01)) {
  105.             //    Console.Write(column.column.ColumnName + ": ");
  106.             //    Console.WriteLine(string.Join(", ", column.values.Distinct().OrderBy(n => n)));
  107.             //}
  108.             //Console.WriteLine();
  109.  
  110.             //Console.WriteLine("non-parsable floats");
  111.             //foreach (var column in columnTypes.Where(ct => ct.floats > 0 && ct.floats < 1)) {
  112.             //    Console.Write(column.column.ColumnName + ": ");
  113.             //    Console.WriteLine(string.Join(", ", column.values.Where(v => !double.TryParse(v, out _)).Distinct().OrderBy(n => n)));
  114.             //}
  115.             //Console.WriteLine();
  116.  
  117.             //Console.WriteLine("float ranges:");
  118.             //foreach (var column in columnTypes.Where(ct => ct.floats > 0.01)) {
  119.             //    Console.Write(column.column.ColumnName + ": ");
  120.             //    var validValues = column.values.AsDouble().Where(v => v != null).Select(v => v.Value);
  121.             //    Console.WriteLine($"{validValues.Min()}...{validValues.Max()}");
  122.             //}
  123.             //Console.WriteLine();
  124.         }
  125.  
  126.         static IEnumerable<(T1, T2)> Pair<T1, T2>(this IEnumerable<T1> seq1, IEnumerable<T2> seq2)
  127.             => seq1.Zip(seq2, (v1, v2) => (v1, v2));
  128.  
  129.         static (List<T>, List<T>) RandomSplit<T>(this IEnumerable<T> seq, double primaryChance) {
  130.             var random = new Random();
  131.             var primary = new List<T>();
  132.             var secondary = new List<T>();
  133.             foreach(var item in seq) {
  134.                 if (random.NextDouble() < primaryChance)
  135.                     primary.Add(item);
  136.                 else
  137.                     secondary.Add(item);
  138.             }
  139.             return (primary, secondary);
  140.         }
  141.  
  142.         static DataTable LoadData(string csvFilePath) {
  143.             var result = new DataTable();
  144.             using (var reader = new CsvDataReader(new CsvReader(new StreamReader(csvFilePath)))) {
  145.                 result.Load(reader);
  146.             }
  147.             return result;
  148.         }
  149.  
  150.         static IEnumerable<double?> AsDouble(this IEnumerable<string> seq) {
  151.             foreach(var value in seq)
  152.                 yield return double.TryParse(value, out var result) ? result : (double?)null;
  153.         }
  154.  
  155.         static Func<string, double[]> ValueNormalizer(double floats, IEnumerable<string> values) {
  156.             if (floats > 0.01) {
  157.                 double max = values.AsDouble().Max().Value;
  158.                 return s => new[] { double.TryParse(s, out double v) ? v / max : -1 };
  159.             } else {
  160.                 string[] domain = values.Distinct().OrderBy(v => v).ToArray();
  161.                 return s => new double[domain.Length+1].Set(Array.IndexOf(domain, s)+1, 1);
  162.             }
  163.         }
  164.  
  165.         static double Percentage<T>(this IEnumerable<T> seq, Func<T, bool> predicate) {
  166.             int total = 0;
  167.             int matching = 0;
  168.             foreach(var item in seq) {
  169.                 total++;
  170.                 matching += predicate(item) ? 1 : 0;
  171.             }
  172.             return matching * 1.0 / total;
  173.         }
  174.         static T[] Set<T>(this T[] array, int index, T value) {
  175.             array[index] = value;
  176.             return array;
  177.         }
  178.         static void PrettyPrint<TKey, TValue>(TextWriter writer, IReadOnlyDictionary<TKey, TValue> dict) {
  179.             bool multiline = dict.Count > 10;
  180.             string separator = multiline ? "," + Environment.NewLine : ",";
  181.             string prefix = multiline ? "  " : " ";
  182.             writer.Write('{');
  183.             if (multiline)
  184.                 writer.WriteLine();
  185.             foreach(var entry in dict) {
  186.                 writer.Write(prefix);
  187.                 writer.Write(entry.Key);
  188.                 writer.Write(": ");
  189.                 writer.Write(entry.Value);
  190.                 writer.Write(separator);
  191.             }
  192.             writer.Write('}');
  193.         }
  194.  
  195.         static IReadOnlyDictionary<T, int> Counts<T>(this IEnumerable<T> seq) {
  196.             var result = new Dictionary<T, int>();
  197.             foreach(var item in seq) {
  198.                 result.TryGetValue(item, out int count);
  199.                 result[item] = count + 1;
  200.             }
  201.             return result;
  202.         }
  203.     }
  204. }
RAW Paste Data

Adblocker detected! Please consider disabling it...

We've detected AdBlock Plus or some other adblocking software preventing Pastebin.com from fully loading.

We don't have any obnoxious sound, or popup ads, we actively block these annoying types of ads!

Please add Pastebin.com to your ad blocker whitelist or disable your adblocking software.

×