CGC_Codes

NN Function

Jun 4th, 2017
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 8.24 KB | None | 0 0
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Data;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Text.RegularExpressions;
  8.  
  9. namespace NN_SigmoidFunction
  10. {
  11.     class Program
  12.     {
  13.         static void Main(string[] args)
  14.         {
  15.             string TrainfilePath = "", TestFilePath = "";
  16.             double learningRate = 0.0;
  17.             int iterationCount = 0;
  18.             DataTable table = new DataTable("data");
  19.             try
  20.             {
  21.                 TrainfilePath = args[0];
  22.                 TestFilePath = args[1];
  23.                 learningRate = Convert.ToDouble(args[2]);
  24.                 iterationCount = Convert.ToInt32(args[3]);
  25.             }
  26.             catch (Exception ex)
  27.             {
  28.                 Console.WriteLine("Error in the input Please verify your input :\n" + ex.Message);
  29.             }
  30.  
  31.            
  32.             table = ReadFiles(TrainfilePath);
  33.  
  34.             double[] weights = new double[(table.Columns.Count - 1)];
  35.             for (int i = 0; i < weights.Length; i++)
  36.             {
  37.                 weights[i] = 0;
  38.             }
  39.  
  40.             TrainData(table, weights, learningRate, iterationCount);
  41.  
  42.             DataTable testTable = new DataTable("TestData");
  43.             testTable = ReadFiles(TestFilePath);
  44.  
  45.            
  46.             int trainFailures = TestData(table, weights);
  47.            
  48.             int testFailures = TestData(testTable, weights);
  49.            
  50.             int totalRecordsTrain = table.Rows.Count;
  51.            
  52.             int totalRecordsTest = testTable.Rows.Count;
  53.  
  54.            
  55.             Console.WriteLine("\n\nResult");
  56.             Console.WriteLine("\n\nTrain Data:");
  57.             Console.WriteLine("------------------------------------------");
  58.             Console.WriteLine("Total Records of Train: " + totalRecordsTrain);
  59.             Console.WriteLine("Total Failure Records of Train : " + trainFailures);
  60.             double accuracyTrain = (double)(totalRecordsTrain - trainFailures) * 100 / totalRecordsTrain;
  61.             Console.WriteLine("Accuracy of Train :  " + accuracyTrain + "%");
  62.  
  63.             Console.WriteLine("\nTest Data :");
  64.             Console.WriteLine("------------------------------------------");
  65.             Console.WriteLine("Total Records of Test : " + totalRecordsTest.ToString());
  66.             Console.WriteLine("Total Failure Records of Test : " + testFailures);
  67.             double accuracyTest = (double)(totalRecordsTest - testFailures) * 100 / totalRecordsTest;
  68.             Console.Write("Accuracy of Test :" + accuracyTest + "%");
  69.  
  70.             Console.ReadLine();
  71.  
  72.  
  73.         }
  74.  
  75.         private static int TestData(DataTable _table, double[] _weights)
  76.         {
  77.  
  78.             DataTable Table = _table;
  79.  
  80.             double summation = 0.0;
  81.             double sigmoid = 0.0;
  82.             int failCount = 0;
  83.             int result = 0;
  84.             double[] weights = _weights;
  85.             foreach (DataRow row in Table.Rows)
  86.             {
  87.                 result = 0;
  88.                 summation = 0;
  89.                 for (int i = 0; i < Table.Columns.Count - 1; i++)
  90.                 {
  91.                     double rowValue = Convert.ToDouble(row[i]);
  92.                     summation += (weights[i] * rowValue);
  93.                 }
  94.                
  95.                 sigmoid = 1 / (1 + Math.Exp(-(summation)));
  96.                 if (sigmoid >= 0.5)
  97.                 {
  98.                     result = 1;
  99.                 }
  100.  
  101.                 int res = Convert.ToInt32(row["Result"]);
  102.                 if (result != Convert.ToInt32(row["Result"]))
  103.                 {
  104.                     failCount++;
  105.                 }
  106.             }
  107.             return failCount;
  108.         }
  109.        
  110.         private static void TrainData(DataTable _table, double[] _weights, double learnRate, int _iteration)
  111.         {
  112.             int rowCount = _iteration;
  113.             double learningRate = learnRate;
  114.             double summation = 0.0;
  115.             DataTable table = _table;
  116.             double[] weights = _weights;
  117.  
  118.             int nTimes = rowCount / table.Rows.Count;
  119.             int balance = rowCount % table.Rows.Count;
  120.  
  121.             double rowValue = 0.0;
  122.             while (nTimes != 0)
  123.             {
  124.                 for (int i = 0; i < table.Rows.Count; i++)
  125.                 {
  126.                     summation = 0;
  127.                     for (int j = 0; j < table.Columns.Count - 1; j++)
  128.                     {
  129.                         rowValue = Convert.ToDouble(table.Rows[i][j]);
  130.                         summation += (weights[j] * rowValue);
  131.                     }
  132.                    
  133.                     double Sigmoid = 1 / (1 + Math.Exp(-(summation)));
  134.                    
  135.                    
  136.                     double Err = Convert.ToInt32(table.Rows[i][table.Columns["Result"]]) - Sigmoid;
  137.                    
  138.                     double sigMoidPrime = Math.Exp(summation) / Math.Pow((1 + Math.Exp(summation)), 2);
  139.                    
  140.                     for (int k = 0; k < table.Columns.Count - 1; k++)
  141.                     {
  142.                         weights[k] = weights[k] + learningRate * Err * sigMoidPrime * Convert.ToDouble(table.Rows[i][k]);
  143.                     }
  144.                 }
  145.                 nTimes--;
  146.  
  147.             }
  148.             for (int i = 0; i < balance; i++)
  149.             {
  150.                 summation = 0;
  151.                 for (int j = 0; j < table.Columns.Count - 1; j++)
  152.                 {
  153.                     rowValue = Convert.ToDouble(table.Rows[i][j]);
  154.                     summation += (weights[j] * rowValue);
  155.                 }
  156.                
  157.                 double Sigmoid = 1 / (1 + Math.Exp(-(summation)));
  158.                
  159.                 double Err = Convert.ToInt32(table.Rows[i][table.Columns["Result"]]) - Sigmoid;
  160.                
  161.                 double sigMoidPrime = Math.Exp(summation) / Math.Pow((1 + Math.Exp(summation)), 2);
  162.                
  163.                 for (int k = 0; k < table.Columns.Count - 1; k++)
  164.                 {
  165.                     weights[k] = weights[k] + learningRate * Err * sigMoidPrime * Convert.ToDouble(table.Rows[i][k]);
  166.                 }
  167.             }
  168.  
  169.         }
  170.  
  171.        
  172.         private static DataTable ReadFiles(string filePath)
  173.         {
  174.             DataTable table = new DataTable("Data");
  175.             TextReader reader = new StreamReader(filePath);
  176.             bool colAdded = false;
  177.             try
  178.             {
  179.                 while (reader.Peek() != -1)
  180.                 {
  181.                     string[] tokens = Regex.Split(reader.ReadLine(), "[\t\r\n]");
  182.  
  183.                    
  184.                     Array.Resize(ref tokens, tokens.Length + 1);
  185.                     tokens[tokens.Length - 1] = "1";
  186.  
  187.                    
  188.                     if (!colAdded)
  189.                     {
  190.                         foreach (string token in tokens)
  191.                         {
  192.                             if (token == "1")
  193.                             {
  194.                                 table.Columns.Add("X0");
  195.                                 break;
  196.                             }
  197.                             else if (token == "")
  198.                             {
  199.                                 table.Columns.Add("Result");
  200.                             }
  201.                             else
  202.                             {
  203.                                 table.Columns.Add(token);
  204.                             }
  205.  
  206.                         }
  207.                         colAdded = true;
  208.  
  209.                     }
  210.  
  211.                    
  212.                     else
  213.                     {
  214.                         DataRow row = table.NewRow();
  215.                         for (int i = 0; i < table.Columns.Count; i++)
  216.                         {
  217.                             row[i] = tokens[i];
  218.                         }
  219.                         table.Rows.Add(row);
  220.  
  221.                     }
  222.                 }
  223.             }
  224.             catch (IndexOutOfRangeException)
  225.             {
  226.                
  227.             }
  228.             finally
  229.             {
  230.                
  231.                 if (reader != null)
  232.                 {
  233.                     reader.Close();
  234.                 }
  235.             }
  236.            
  237.             table.Columns["X0"].SetOrdinal(0);
  238.             return table;
  239.         }
  240.     }
  241. }
Advertisement
Add Comment
Please, Sign In to add comment