Advertisement
Guest User

Untitled

a guest
Jun 18th, 2018
102
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.62 KB | None | 0 0
  1. public class Network {
  2.  
  3. /**
  4. * Глобальная ошибка для обучения.
  5. */
  6. protected double globalError;
  7.  
  8. /**
  9. * количество нейронов ввода
  10. */
  11. protected int inputCount;
  12.  
  13. /**
  14. * количество скрытых нейронов
  15. */
  16. protected int hiddenCount;
  17.  
  18. /**
  19. * количество нейронов выхода
  20. */
  21. protected int outputCount;
  22.  
  23. /**
  24. * количество нейронов
  25. */
  26. protected int neuronCount;
  27.  
  28. /**
  29. * The number of weights in the network.
  30. */
  31. protected int weightCount;
  32.  
  33. /**
  34. * The learning rate.
  35. */
  36. protected double learnRate;
  37.  
  38. /**
  39. * массив сигмоид нейронов.
  40. */
  41. protected double fire[];
  42.  
  43. /**
  44. * Весовая матрица, считается «памятью» нейронной сети.
  45. */
  46. protected double matrixWeights[];
  47.  
  48. /**
  49. * The errors from the last calculation.
  50. */
  51. protected double error[];
  52.  
  53. /**
  54. * Accumulates matrixWeights delta's for training.
  55. */
  56. protected double accMatrixDelta[];
  57.  
  58. /**
  59. * The thresholds, this value, along with the weight matrixWeights
  60. * can be thought of as the memory of the neural network.
  61. */
  62. protected double thresholds[];
  63.  
  64. /**
  65. * The changes that should be applied to the weight
  66. * matrixWeights.
  67. */
  68. protected double matrixDelta[];
  69.  
  70. /**
  71. * The accumulation of the threshold deltas.
  72. */
  73. protected double accThresholdDelta[];
  74.  
  75. /**
  76. * The threshold deltas.
  77. */
  78. protected double thresholdDelta[];
  79.  
  80. /**
  81. * The momentum for training.
  82. */
  83. protected double momentum;
  84.  
  85. /**
  86. * The changes in the errors.
  87. */
  88. protected double errorDelta[];
  89.  
  90.  
  91. /**
  92. * @param inputCount Количество входных нейронов.
  93. * @param hiddenCount Количество скрытых нейронов
  94. * @param outputCount Количество выходных нейронов
  95. * @param learnRate Уровень обучения, который будет использоваться при обучении.
  96. * @param momentum Импульс, который будет использоваться при обучении.
  97. */
  98. public Network(int inputCount,
  99. int hiddenCount,
  100. int outputCount,
  101. double learnRate,
  102. double momentum) {
  103.  
  104. this.learnRate = learnRate;
  105. this.momentum = momentum;
  106.  
  107. this.inputCount = inputCount;
  108. this.hiddenCount = hiddenCount;
  109. this.outputCount = outputCount;
  110. neuronCount = inputCount + hiddenCount + outputCount;
  111. weightCount = (inputCount * hiddenCount) + (hiddenCount * outputCount);
  112.  
  113. fire = new double[neuronCount];
  114. matrixWeights = new double[weightCount];
  115. matrixDelta = new double[weightCount];
  116. thresholds = new double[neuronCount];
  117. errorDelta = new double[neuronCount];
  118. error = new double[neuronCount];
  119. accThresholdDelta = new double[neuronCount];
  120. accMatrixDelta = new double[weightCount];
  121. thresholdDelta = new double[neuronCount];
  122.  
  123. reset();
  124. }
  125.  
  126.  
  127. /**
  128. * Сигмоид, вычисляем экспоненту
  129. */
  130. public double threshold(double sum) {
  131. return 1.0 / (1 + Math.exp(-1.0 * sum));
  132. }
  133.  
  134. /**
  135. * Вычислить результат
  136. * @param input обеспечивает нейронную сеть входными данными.
  137. * @return Результаты выходных нейронов.
  138. */
  139. public double[] computeOutputs(double input[]) {
  140. final int hiddenIndex = inputCount;
  141. final int outIndex = inputCount + hiddenCount;
  142.  
  143. for (int i = 0; i < inputCount; i++) {// вытаскиваем входные данные и записываем в fire
  144. fire[i] = input[i];
  145. }
  146.  
  147. //Из формулы видно, что входная информация — это сумма всех входных данных, умноженных на соответствующие им веса.
  148. // Тогда дадим на вход 1 и 0. Пусть w1=0.4 и w2 = 0.7 Входные данные нейрона Н1 будут следующими: 1*0.4+0*0.7=0.4.
  149. // Теперь когда у нас есть входные данные, мы можем получить выходные данные,
  150. // подставив входное значение в функцию активации
  151.  
  152. // first layer
  153. for (int i = hiddenIndex; i < outIndex; i++) {
  154. double sum = thresholds[i];
  155.  
  156. for (int j = 0; j < inputCount; j++) {
  157. sum += fire[j] * matrixWeights[j];
  158. }
  159. fire[i] = threshold(sum);//функция активации
  160. }
  161.  
  162. // hidden layer
  163.  
  164. double result[] = new double[outputCount];// сигмоид выходного нейрона
  165.  
  166. for (int i = outIndex; i < neuronCount; i++) {
  167. double sum = thresholds[i];
  168.  
  169. for (int j = hiddenIndex; j < outIndex; j++) {
  170. sum += fire[j] * matrixWeights[j];
  171. }
  172. fire[i] = threshold(sum);
  173. result[i - outIndex] = fire[i];
  174. }
  175.  
  176. return result;
  177. }
  178.  
  179.  
  180. /**
  181. * Вычислите ошибку для только что достигнутого распознавания.
  182. *
  183. * @param ideal Что должны были дать выходные нейроны.
  184. */
  185. public void calcError(double ideal[]) {
  186. int i, j;
  187. final int hiddenIndex = inputCount;
  188. final int outputIndex = inputCount + hiddenCount;
  189.  
  190. // очистить ошибки скрытого слоя
  191. for (i = inputCount; i < neuronCount; i++) {
  192. error[i] = 0;
  193. }
  194.  
  195. // ошибки слоя и дельта для уровня вывода
  196. for (i = outputIndex; i < neuronCount; i++) {
  197. error[i] = ideal[i - outputIndex] - fire[i];
  198. globalError += error[i] * error[i];
  199. errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
  200. }
  201.  
  202. // ошибки скрытого слоя
  203. int winx = inputCount * hiddenCount;
  204.  
  205. for (i = outputIndex; i < neuronCount; i++) {
  206. for (j = hiddenIndex; j < outputIndex; j++) {
  207. accMatrixDelta[winx] += errorDelta[i] * fire[j];
  208. error[j] += matrixWeights[winx] * errorDelta[i];
  209. winx++;
  210. }
  211. accThresholdDelta[i] += errorDelta[i];
  212. }
  213.  
  214. // скрытый слой дельты
  215. for (i = hiddenIndex; i < outputIndex; i++) {
  216. errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
  217. }
  218.  
  219. // ошибки входного слоя
  220. winx = 0; // смещение массива весов
  221. for (i = hiddenIndex; i < outputIndex; i++) {
  222. for (j = 0; j < hiddenIndex; j++) {
  223. accMatrixDelta[winx] += errorDelta[i] * fire[j];
  224. error[j] += matrixWeights[winx] * errorDelta[i];
  225. winx++;
  226. }
  227. accThresholdDelta[i] += errorDelta[i];
  228. }
  229. }
  230.  
  231. /**
  232. * Modify the weight matrixWeights and thresholds based on the last call to
  233. * calcError.
  234. */
  235. public void learn() {
  236. int i;
  237.  
  238. // process the matrixWeights
  239. for (i = 0; i < matrixWeights.length; i++) {
  240. matrixDelta[i] = (learnRate * accMatrixDelta[i]) + (momentum * matrixDelta[i]);
  241. matrixWeights[i] += matrixDelta[i];
  242. accMatrixDelta[i] = 0;
  243. }
  244.  
  245. // process the thresholds
  246. for (i = inputCount; i < neuronCount; i++) {
  247. thresholdDelta[i] = learnRate * accThresholdDelta[i] + (momentum * thresholdDelta[i]);
  248. thresholds[i] += thresholdDelta[i];
  249. accThresholdDelta[i] = 0;
  250. }
  251. }
  252.  
  253. /**
  254. * Сброс весов
  255. */
  256. public void reset() {
  257. int i;
  258.  
  259. for (i = 0; i < neuronCount; i++) {
  260. thresholds[i] = 0.5 - (Math.random());
  261. thresholdDelta[i] = 0;
  262. accThresholdDelta[i] = 0;
  263. }
  264. for (i = 0; i < matrixWeights.length; i++) {
  265. matrixWeights[i] = 0.5 - (Math.random());
  266. matrixDelta[i] = 0;
  267. accMatrixDelta[i] = 0;
  268. }
  269. }
  270. }
  271.  
  272. public class TestNeuralNetwork {
  273. public static void main(String args[])
  274. {
  275. double xorInput[][] = {
  276. {1.0,0.0},
  277. {0.0,0.0},
  278. {0.0,1.0},
  279. {1.0,1.0}
  280. };
  281.  
  282. double xorIdeal[][] = { {1.0},{0.0},{0.0},{1.0}};
  283. Network network = new Network(2,2,1,0.7,0.9);
  284.  
  285.  
  286. for (int i=0;i<10000;i++) {
  287. for (int j=0;j<xorInput.length;j++) {
  288. network.computeOutputs(xorInput[j]);
  289. network.calcError(xorIdeal[j]);
  290. network.learn();
  291. }
  292. }
  293.  
  294. for (int i=0;i<xorInput.length;i++) {
  295. for (int j=0;j<xorInput[0].length;j++) {
  296. System.out.print( xorInput[i][j] +":" );
  297. }
  298. double out[] = network.computeOutputs(xorInput[i]);
  299. System.out.println("="+out[0]);
  300. }
  301. }
  302. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement