Advertisement
StoneHaos

neuro3

Feb 15th, 2022
577
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. using System;
  2.  
  3. namespace neuro {
  4.     class NeuralNetwork {
  5.         private double?[][] network;
  6.         private double?[] results;
  7.         private int inputs;
  8.         private int[] hides;
  9.         private int cnthides;
  10.         private int outputs;
  11.         private int cnt;
  12.         private Random rnd = new Random();
  13.        
  14.         public NeuralNetwork(int inputs, int outputs, params int[] hides) {
  15.             this.inputs = inputs;
  16.             this.hides = hides;
  17.             this.outputs = outputs;
  18.             int cnthides = 0;
  19.             for (int i = 0; i < hides.Length; ++ i) {
  20.                 cnthides += hides[i];
  21.             }
  22.             int n = inputs + cnthides + outputs;
  23.             this.cnt = n;
  24.             this.cnthides = cnthides;
  25.             network = new double?[n][];
  26.             results = new double?[n];
  27.             for (int i = 0; i < n; ++ i) {
  28.                 network[i] = new double?[n];
  29.                 for (int j = 0; j < n; ++ j) {
  30.                     network[i][j] = null;
  31.                 }
  32.             }
  33.             if (hides.Length == 0) {
  34.                 for (int i = 0; i < inputs; ++ i) {
  35.                     for (int j = inputs; j < cnt; ++ j) {
  36.                         network[i][j] = random();
  37.                     }
  38.                 }
  39.             }
  40.             else {
  41.                 for (int i = 0; i < inputs; ++ i) {
  42.                     for (int j = 0; j < hides[0]; ++ j) {
  43.                         network[i][inputs + j] = random();
  44.                     }
  45.                 }
  46.                 int a = inputs;
  47.                 int b = inputs + hides[0];
  48.                 int t = inputs + hides[0];
  49.                 for (int i = 1; i < hides.Length; ++ i) {
  50.                     for (int k = 0; k < hides[i]; ++ k) {
  51.                         for (int j = a; j < b; ++ j) {
  52.                             network[j][t] = random();
  53.                         }
  54.                         ++t;
  55.                     }
  56.                     a = b;
  57.                     b += hides[i];
  58.                 }
  59.                 for (int k = 0; k < outputs; ++ k) {
  60.                     for (int j = a; j < b; ++ j) {
  61.                         network[j][t] = random();
  62.                     }
  63.                     ++t;
  64.                 }
  65.             }
  66.         }
  67.         public double random1() {
  68.             return 1.0 / rnd.Next(1, 101);
  69.             //return 0.5;
  70.         }
  71.         public double random() {
  72.             return rnd.NextDouble();
  73.         }
  74.         public double F(double x) {
  75.             return 1.0 / (1.0 + Math.Exp(-x));
  76.         }
  77.         public void Init(double[] inputValues) {
  78.             if (inputValues.Length != inputs) throw new Exception();
  79.             for (int i = 0; i < cnt; ++ i) {
  80.                 results[i] = null;
  81.             }
  82.             for (int i = 0; i < inputs; ++ i) {
  83.                 results[i] = inputValues[i];
  84.             }
  85.         }
  86.         public double[] Count() {
  87.             for (int i = inputs; i < cnt; ++ i) {
  88.                 double s = 0;
  89.                 for (int j = 0; j < cnt; ++ j) {
  90.                     if (network[j][i] != null)
  91.                         s += (double)results[j] * (double)network[j][i];
  92.                 }
  93.                 results[i] = F(s);
  94.             }
  95.             double[] answers = new double[outputs];
  96.             int t = inputs + cnthides;
  97.             for (int i = 0; i < outputs; ++ i) {
  98.                 answers[i] = (double)results[t];
  99.                 ++ t;
  100.             }
  101.             return answers;
  102.         }
  103.         public double[] Count(double[] inputValues) {
  104.             Init(inputValues);
  105.             return Count();
  106.         }
  107.         private bool isOutputNeuron(int n) {
  108.             for (int i = 0; i < cnt; ++ i) {
  109.                 if (network[n][i] != null)
  110.                     return false;
  111.             }
  112.             return true;
  113.         }
  114.         private bool isInputNeuron(int n) {
  115.             for (int i = 0; i < cnt; ++ i) {
  116.                 if (network[i][n] != null)
  117.                     return false;
  118.             }
  119.             return true;
  120.         }
  121.  
  122.         public void Learn(double[] test, double[] testAnswer) {
  123.             double[] errors = new double[cnt];
  124.             Init(test);
  125.             for (int i = cnt - 1; i >= 0; -- i) {
  126.                 if (isOutputNeuron(i)) {
  127.                     double[] res = Count();
  128.                     errors[i] = testAnswer[i - (inputs + cnthides)] - res[i - (inputs + cnthides)];
  129.                     continue;
  130.                 }
  131.                 else if (isInputNeuron(i)) continue;
  132.  
  133.                 double s = 0;
  134.                 for (int j = 0; j < cnt; ++ j) {
  135.                     if (network[i][j] != null) {
  136.                         s += errors[j] * (double)network[i][j];
  137.                     }
  138.                 }
  139.                 errors[i] = (double)results[i] * (1 - (double)results[i]) * s;
  140.             }
  141.             for (int i = 0; i < cnt; ++ i) {
  142.                 if (isOutputNeuron(i)) continue;
  143.                 for (int j = 0; j < cnt; ++ j) {
  144.                     if (network[i][j] != null) {
  145.                         network[i][j] += (double)results[i] * errors[j] * (double)results[j] * (1 - (double)results[j]);
  146.                         //network[i][j] += (double)results[i] * errors[j] * 0.85;
  147.                     }
  148.                 }
  149.             }
  150.         }
  151.     }
  152. }
Advertisement
RAW Paste Data Copied
Advertisement