StoneHaos

myneuro2

Feb 8th, 2022 (edited)
241
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 3.73 KB | None | 0 0
  1. using System;
  2.  
  3. namespace neuro {
  4.     class NeuralNetwork {
  5.         private double?[][] network;
  6.         private double?[] results;
  7.         private int inputs;
  8.         private int[] hides;
  9.         private int cnthides;
  10.         private int outputs;
  11.         private int cnt;
  12.  
  13.         private Random rnd = new Random();
  14.         public NeuralNetwork(int inputs, int outputs, params int[] hides) {
  15.             this.inputs = inputs;
  16.             this.hides = hides;
  17.             this.outputs = outputs;
  18.             int cnthides = 0;
  19.             for (int i = 0; i < hides.Length; ++ i) {
  20.                 cnthides += hides[i];
  21.             }
  22.             int n = inputs + cnthides + outputs;
  23.             this.cnt = n;
  24.             this.cnthides = cnthides;
  25.             network = new double?[n][];
  26.             results = new double?[n];
  27.             for (int i = 0; i < n; ++ i) {
  28.                 network[i] = new double?[n];
  29.                 for (int j = 0; j < n; ++ j) {
  30.                     network[i][j] = null;
  31.                 }
  32.             }
  33.             if (hides.Length == 0) {
  34.                 for (int i = 0; i < inputs; ++ i) {
  35.                     for (int j = inputs; j < cnt; ++ j) {
  36.                         network[i][j] = random();
  37.                     }
  38.                 }
  39.             }
  40.             else {
  41.                 for (int i = 0; i < inputs; ++ i) {
  42.                     for (int j = 0; j < hides[0]; ++ j) {
  43.                         network[i][inputs + j] = random();
  44.                     }
  45.                 }
  46.                 int a = inputs;
  47.                 int b = inputs + hides[0];
  48.                 int t = inputs + hides[0];
  49.                 for (int i = 1; i < hides.Length; ++ i) {
  50.                     for (int k = 0; k < hides[i]; ++ k) {
  51.                         for (int j = a; j < b; ++ j) {
  52.                             network[j][t] = random();
  53.                         }
  54.                         ++t;
  55.                     }
  56.                     a = b;
  57.                     b += hides[i];
  58.                 }
  59.                 for (int k = 0; k < outputs; ++ k) {
  60.                     for (int j = a; j < b; ++ j) {
  61.                         network[j][t] = random();
  62.                     }
  63.                     ++t;
  64.                 }
  65.             }
  66.         }
  67.         public double random() {
  68.             return 1.0 / rnd.Next(1, 101);
  69.         }
  70.         public double F(double x) {
  71.             return 1.0 / (1.0 + Math.Exp(-x));
  72.         }
  73.         public void Init(double[] inputValues) {
  74.             if (inputValues.Length != inputs) throw new Exception();
  75.             for (int i = 0; i < cnt; ++ i) {
  76.                 results[i] = null;
  77.             }
  78.             for (int i = 0; i < inputs; ++ i) {
  79.                 results[i] = inputValues[i];
  80.             }
  81.         }
  82.         public double[] Count() {
  83.             for (int i = inputs; i < cnt; ++ i) {
  84.                 double s = 0;
  85.                 for (int j = 0; j < cnt; ++ j) {
  86.                     if (network[j][i] != null)
  87.                         s += (double)results[j] * (double)network[j][i];
  88.                 }
  89.                 results[i] = F(s);
  90.             }
  91.             double[] answers = new double[outputs];
  92.             int t = inputs + cnthides;
  93.             for (int i = 0; i < outputs; ++ i) {
  94.                 answers[i] = (double)results[t];
  95.                 ++ t;
  96.             }
  97.             return answers;
  98.         }
  99.         public double[] Count(double[] inputValues) {
  100.             Init(inputValues);
  101.             return Count();
  102.         }
  103.         private bool isOutputNeuron(int n) {
  104.             for (int i = 0; i < cnt; ++ i) {
  105.                 if (network[n][i] != null)
  106.                     return false;
  107.             }
  108.             return true;
  109.         }
  110.         private bool isInputNeuron(int n) {
  111.             for (int i = 0; i < cnt; ++ i) {
  112.                 if (network[i][n] != null)
  113.                     return false;
  114.             }
  115.             return true;
  116.         }
  117.  
  118.         public void Learn(double[] test, double[] testAnswer) {
  119.             double[] errors = new double[cnt];
  120.             Init(test);
  121.             for (int i = cnt - 1; i >= 0; -- i) {
  122.                 if (isOutputNeuron(i)) {
  123.                     double[] res = Count();
  124.                     errors[i] = testAnswer[i - (inputs + cnthides)] - res[i - (inputs + cnthides)];
  125.                     continue;
  126.                 }
  127.                 else if (isInputNeuron(i)) continue;
  128.  
  129.                 double s = 0;
  130.                 for (int j = 0; j < cnt; ++ j) {
  131.                     if (network[i][j] != null) {
  132.                         s += errors[j] * (double)network[i][j];
  133.                     }
  134.                 }
  135.                 errors[i] = (double)results[i] * (1 - (double)results[i]) * s;
  136.             }
  137.             for (int i = 0; i < cnt; ++ i) {
  138.                 if (isOutputNeuron(i)) continue;
  139.                 for (int j = 0; j < cnt; ++ j) {
  140.                     if (network[i][j] != null) {
  141.                         network[i][j] += (double)results[i] * errors[j] * (double)results[j] * (1 - (double)results[j]);
  142.                         //network[i][j] += (double)results[i] * errors[j] * 0.85;
  143.                     }
  144.                 }
  145.             }
  146.         }
  147.     }
  148. }
Add Comment
Please, Sign In to add comment