Advertisement
Guest User

HousePrices-SimpleNetwork

a guest
Jan 23rd, 2019
859
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 9.35 KB | None | 0 0
  1. namespace HousePrices {
  2.     using System;
  3.     using System.Collections.Generic;
  4.     using System.Data;
  5.     using System.Diagnostics;
  6.     using System.IO;
  7.     using System.Linq;
  8.  
  9.     using CsvHelper;
  10.  
  11.     using numpy;
  12.     using SharPy.Runtime;
  13.     using tensorflow;
  14.     using tensorflow.keras;
  15.     using tensorflow.keras.callbacks;
  16.     using tensorflow.keras.layers;
  17.     using tensorflow.train;
  18.  
  19.     static class HousePricesProgram {
  20.         static void Main() {
  21.             const string TrainFile = "train.csv";
  22.             DataTable trainData = LoadData(TrainFile);
  23.             trainData.Columns.Remove("Id");
  24.             var columns = trainData.Columns.Cast<DataColumn>();
  25.             var types = columns.Select(column => column.DataType).Counts();
  26.             var rows = trainData.Rows.Cast<DataRow>();
  27.             var (trainRows, testRows) = rows.RandomSplit(primaryChance: 1);
  28.             var columnTypes = columns.Select(column => {
  29.                 var values = rows.Select(row => (string)row[column]);
  30.                 var trainValues = trainRows.Select(row => (string)row[column]);
  31.                 var testValues = testRows.Select(row => (string)row[column]);
  32.                 double floats = values.Percentage(v => double.TryParse(v, out _));
  33.                 double ints = values.Percentage(v => int.TryParse(v, out _));
  34.                 int distincts = values.Distinct().Count();
  35.                 var normalizer = ValueNormalizer(floats, values);
  36.                 return new { column, values, distincts, ints, floats, normalizer, trainValues, testValues };
  37.             }).OrderBy(c => c.distincts).ThenBy(c => c.column.ColumnName)
  38.             .ToArray();
  39.  
  40.             const string predict = "SalePrice";
  41.  
  42.             ndarray GetInputs(IEnumerable<DataRow> rowSeq) {
  43.                 return np.array(rowSeq.Select(row => np.array(
  44.                     columnTypes.Where(c => c.column.ColumnName != predict)
  45.                     .SelectMany(column => column.normalizer(row.Table.Columns.Contains(column.column.ColumnName) ? (string)row[column.column.ColumnName] : "-1")).ToArray()))
  46.                 .ToArray());
  47.             }
  48.  
  49.             var predictColumn = columnTypes.Single(c => c.column.ColumnName == predict);
  50.             ndarray trainOutputs = np.array(predictColumn.trainValues.AsDouble().Select(v => v ?? -1).ToArray());
  51.             ndarray trainInputs = GetInputs(trainRows);
  52.  
  53.             //ndarray testOutputs = np.array(predictColumn.testValues.AsDouble().Select(v => v ?? -1).ToArray());
  54.             //ndarray testInputs = GetInputs(testRows);
  55.  
  56.             //Debug.Assert(testOutputs.Length == testInputs.Length);
  57.             //Debug.Assert(testOutputs.Length > 20);
  58.             Debug.Assert(trainOutputs.Length == trainInputs.Length);
  59.             Debug.Assert(trainOutputs.Length > 20);
  60.             //Debug.Assert(trainOutputs.Length != testOutputs.Length);
  61.  
  62.             var model = new Sequential(new Layer[] {
  63.                 new Dense(units: 16, activation: tf.nn.relu_fn),
  64.                 new Dropout(rate: 0.1),
  65.                 new Dense(units: 10, activation: tf.nn.relu_fn),
  66.                 new Dense(units: 1, activation: tf.nn.relu_fn),
  67.             });
  68.  
  69.             model.compile(optimizer: new AdamOptimizer(), loss: "mean_squared_error");
  70.  
  71.             //var tensorboard = new TensorBoard(log_dir: $"./logs/{DateTime.Now.ToString("s").Replace(':','-')}");
  72.  
  73.             model.fit(trainInputs, trainOutputs, epochs: 20000, validation_split: 0.075, verbose: 2);
  74.  
  75.             const string SubmissionInputFile = "test.csv";
  76.             DataTable submissionData = LoadData(SubmissionInputFile);
  77.             var submissionRows = submissionData.Rows.Cast<DataRow>();
  78.             ndarray submissionInputs = GetInputs(submissionRows);
  79.             ndarray sumissionOutputs = model.predict(submissionInputs);
  80.             Console.WriteLine("guesses:");
  81.             var random = new Random();
  82.             using (var writer = new StreamWriter("submit.csv")) {
  83.                 writer.WriteLine("Id,SalePrice");
  84.                 foreach (var (id, prediction) in submissionRows.Select(row => int.Parse((string)row["Id"]))
  85.                                                                .Pair(sumissionOutputs.Cast<ndarray>())) {
  86.                     string guess = $"{id},{prediction[0]}";
  87.                     writer.WriteLine(guess);
  88.                     if (random.Next(100) > 99)
  89.                         Console.WriteLine(guess);
  90.                 }
  91.                 writer.Flush();
  92.             }
  93.  
  94.             //float64 trainLoss = model.evaluate(trainInputs, trainOutputs);
  95.             //float64 testLoss = model.evaluate(testInputs, testOutputs);
  96.             //Console.WriteLine($"Test loss: {(int)Math.Sqrt(testLoss)}; Train loss: {(int)Math.Sqrt(trainLoss)}");
  97.             //Console.WriteLine();
  98.  
  99.             //foreach (var column in columnTypes)
  100.             //    Console.WriteLine($"{column.column.ColumnName}: {column.distincts} values, ints: {column.ints:P2}, floats: {column.floats:P2}");
  101.             //Console.WriteLine();
  102.  
  103.             //Console.WriteLine("Many value columns:");
  104.             //foreach (var column in columnTypes.Where(ct => ct.distincts > 10 && ct.floats < 0.01)) {
  105.             //    Console.Write(column.column.ColumnName + ": ");
  106.             //    Console.WriteLine(string.Join(", ", column.values.Distinct().OrderBy(n => n)));
  107.             //}
  108.             //Console.WriteLine();
  109.  
  110.             //Console.WriteLine("non-parsable floats");
  111.             //foreach (var column in columnTypes.Where(ct => ct.floats > 0 && ct.floats < 1)) {
  112.             //    Console.Write(column.column.ColumnName + ": ");
  113.             //    Console.WriteLine(string.Join(", ", column.values.Where(v => !double.TryParse(v, out _)).Distinct().OrderBy(n => n)));
  114.             //}
  115.             //Console.WriteLine();
  116.  
  117.             //Console.WriteLine("float ranges:");
  118.             //foreach (var column in columnTypes.Where(ct => ct.floats > 0.01)) {
  119.             //    Console.Write(column.column.ColumnName + ": ");
  120.             //    var validValues = column.values.AsDouble().Where(v => v != null).Select(v => v.Value);
  121.             //    Console.WriteLine($"{validValues.Min()}...{validValues.Max()}");
  122.             //}
  123.             //Console.WriteLine();
  124.         }
  125.  
  126.         static IEnumerable<(T1, T2)> Pair<T1, T2>(this IEnumerable<T1> seq1, IEnumerable<T2> seq2)
  127.             => seq1.Zip(seq2, (v1, v2) => (v1, v2));
  128.  
  129.         static (List<T>, List<T>) RandomSplit<T>(this IEnumerable<T> seq, double primaryChance) {
  130.             var random = new Random();
  131.             var primary = new List<T>();
  132.             var secondary = new List<T>();
  133.             foreach(var item in seq) {
  134.                 if (random.NextDouble() < primaryChance)
  135.                     primary.Add(item);
  136.                 else
  137.                     secondary.Add(item);
  138.             }
  139.             return (primary, secondary);
  140.         }
  141.  
  142.         static DataTable LoadData(string csvFilePath) {
  143.             var result = new DataTable();
  144.             using (var reader = new CsvDataReader(new CsvReader(new StreamReader(csvFilePath)))) {
  145.                 result.Load(reader);
  146.             }
  147.             return result;
  148.         }
  149.  
  150.         static IEnumerable<double?> AsDouble(this IEnumerable<string> seq) {
  151.             foreach(var value in seq)
  152.                 yield return double.TryParse(value, out var result) ? result : (double?)null;
  153.         }
  154.  
  155.         static Func<string, double[]> ValueNormalizer(double floats, IEnumerable<string> values) {
  156.             if (floats > 0.01) {
  157.                 double max = values.AsDouble().Max().Value;
  158.                 return s => new[] { double.TryParse(s, out double v) ? v / max : -1 };
  159.             } else {
  160.                 string[] domain = values.Distinct().OrderBy(v => v).ToArray();
  161.                 return s => new double[domain.Length+1].Set(Array.IndexOf(domain, s)+1, 1);
  162.             }
  163.         }
  164.  
  165.         static double Percentage<T>(this IEnumerable<T> seq, Func<T, bool> predicate) {
  166.             int total = 0;
  167.             int matching = 0;
  168.             foreach(var item in seq) {
  169.                 total++;
  170.                 matching += predicate(item) ? 1 : 0;
  171.             }
  172.             return matching * 1.0 / total;
  173.         }
  174.         static T[] Set<T>(this T[] array, int index, T value) {
  175.             array[index] = value;
  176.             return array;
  177.         }
  178.         static void PrettyPrint<TKey, TValue>(TextWriter writer, IReadOnlyDictionary<TKey, TValue> dict) {
  179.             bool multiline = dict.Count > 10;
  180.             string separator = multiline ? "," + Environment.NewLine : ",";
  181.             string prefix = multiline ? "  " : " ";
  182.             writer.Write('{');
  183.             if (multiline)
  184.                 writer.WriteLine();
  185.             foreach(var entry in dict) {
  186.                 writer.Write(prefix);
  187.                 writer.Write(entry.Key);
  188.                 writer.Write(": ");
  189.                 writer.Write(entry.Value);
  190.                 writer.Write(separator);
  191.             }
  192.             writer.Write('}');
  193.         }
  194.  
  195.         static IReadOnlyDictionary<T, int> Counts<T>(this IEnumerable<T> seq) {
  196.             var result = new Dictionary<T, int>();
  197.             foreach(var item in seq) {
  198.                 result.TryGetValue(item, out int count);
  199.                 result[item] = count + 1;
  200.             }
  201.             return result;
  202.         }
  203.     }
  204. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement