Advertisement
Guest User

Untitled

a guest
Apr 26th, 2018
78
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.47 KB | None | 0 0
  1. package MLP;
  2.  
  3. import java.io.File;
  4. import java.io.FileNotFoundException;
  5. import java.text.DecimalFormat;
  6. import java.util.ArrayList;
  7. import java.util.Random;
  8. import java.util.Scanner;
  9. import java.util.concurrent.ThreadLocalRandom;
  10. import java.util.stream.Stream;
  11.  
  12. import org.jfree.chart.*;
  13. import org.jfree.chart.plot.XYPlot;
  14. import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
  15. import org.jfree.data.xy.XYSeries;
  16. import org.jfree.data.xy.XYSeriesCollection;
  17.  
  18.  
  19. import static MLP.NetworkTools.createRandomArray;
  20.  
  21. public class Network {
  22.  
  23. private double[][] wyjscie;
  24. private double[][][] wagi;
  25. private double[][] bias;
  26. private double[][] vectorMomentum;
  27. private double[][] errorSignal;
  28. private double[][] pochodnaPoWyjsciu;
  29.  
  30. ArrayList<double[]> list = new ArrayList<>();
  31.  
  32. private final int[] rozmiarWarstwySieci;
  33. private final int rozmiarWejscia;
  34. private final int rozmiarWyjscia;
  35. private final int rozmiarSieci;
  36.  
  37. public Network(int... tablica) {
  38. rozmiarWarstwySieci = tablica;
  39. rozmiarWejscia = rozmiarWarstwySieci[0];
  40. rozmiarSieci = rozmiarWarstwySieci.length;
  41. rozmiarWyjscia = rozmiarWarstwySieci[rozmiarSieci -1];
  42.  
  43. this.wyjscie = new double[rozmiarSieci][];
  44. this.wagi = new double[rozmiarSieci][][];
  45. this.bias = new double[rozmiarSieci][];
  46. this.vectorMomentum = new double[rozmiarSieci][];
  47. this.errorSignal = new double[rozmiarSieci][];
  48. this.pochodnaPoWyjsciu = new double[rozmiarSieci][];
  49.  
  50. for(int i = 0; i < rozmiarSieci; i++)
  51. {
  52. this.wyjscie[i] = new double[tablica[i]];
  53. this.errorSignal[i] = new double[tablica[i]];
  54. this.pochodnaPoWyjsciu[i] = new double[tablica[i]];
  55. this.bias[i] = createRandomArray(rozmiarWarstwySieci[i],0.2,0.7);
  56. this.vectorMomentum[i] = new double[tablica[i]];
  57. if(i>0)
  58. {
  59. wagi[i] = createRandomArray(rozmiarWarstwySieci[i],tablica[i-1],-0.3,0.5);
  60. }
  61. }
  62. }
  63.  
  64. private void trening(double[] wejscie, double[] cel, double eta, double momentum, boolean ifBias)
  65. {
  66. if (wejscie.length != rozmiarWejscia || cel.length != rozmiarWyjscia) return;
  67. oblicz(ifBias, wejscie);
  68. bladWstecznejPropragacji(cel);
  69. zaktualizujWagi(eta, momentum);
  70. }
  71. public double[] oblicz(boolean ifBias, double... wejscie)
  72. {
  73. if (wejscie.length != this.rozmiarWejscia) return null;
  74. this.wyjscie[0] = wejscie;
  75. for (int warstwa = 1; warstwa < rozmiarSieci; warstwa++)
  76. for (int neuron = 0; neuron < rozmiarWarstwySieci[warstwa]; neuron++) {
  77. double sum;
  78. if (ifBias)
  79. sum = bias[warstwa][neuron];
  80. else
  81. sum = 0;
  82. for (int poprzedniNeuron = 0; poprzedniNeuron < rozmiarWarstwySieci[warstwa - 1]; poprzedniNeuron++) {
  83. sum += wyjscie[warstwa - 1][poprzedniNeuron] * wagi[warstwa][neuron][poprzedniNeuron];
  84. }
  85. wyjscie[warstwa][neuron] = sigmoid(sum);
  86. pochodnaPoWyjsciu[warstwa][neuron] = wyjscie[warstwa][neuron] * (1 - wyjscie[warstwa][neuron]);
  87. }
  88. return wyjscie[rozmiarSieci - 1];
  89. }
  90.  
  91. private void zaktualizujWagi(double eta, double momentum)
  92. {
  93. double delta = 0.0;
  94. for (int warstwa = 1; warstwa < rozmiarSieci; warstwa++)
  95. {
  96. //delta = -eta * errorSignal[layer][neuron];
  97. //bias[layer][neuron] += delta;
  98. for (int neuron = 0; neuron < rozmiarWarstwySieci[warstwa]; neuron++)
  99. for (int poprzedniNeuron = 0; poprzedniNeuron < rozmiarWarstwySieci[warstwa - 1]; poprzedniNeuron++) {
  100. //delta = -eta * output[layer - 1][prevNeuron] * errorSignal[layer][neuron] + momentum * prevWeight;
  101. delta = -eta * wyjscie[warstwa - 1][poprzedniNeuron] * errorSignal[warstwa][neuron] + momentum * vectorMomentum[warstwa][neuron];
  102. //prevWeight = delta;
  103. vectorMomentum[warstwa][neuron] = delta;
  104. wagi[warstwa][neuron][poprzedniNeuron] += delta;
  105. bias[warstwa][neuron] += -eta * errorSignal[warstwa][neuron];
  106. }
  107. }
  108. }
  109.  
  110. private void bladWstecznejPropragacji(double[] cel)
  111. {
  112. for (int neuron = 0; neuron < rozmiarWarstwySieci[rozmiarSieci - 1]; neuron++)
  113. errorSignal[rozmiarSieci - 1][neuron] = (wyjscie[rozmiarSieci - 1][neuron] - cel[neuron])
  114. * pochodnaPoWyjsciu[rozmiarSieci - 1][neuron];
  115. for (int warstwa = rozmiarSieci - 2; warstwa > 0; warstwa--)
  116. for (int neuron = 0; neuron < rozmiarWarstwySieci[warstwa]; neuron++) {
  117. double sum = 0;
  118. for (int nastepnyNeuron = 0; nastepnyNeuron < rozmiarWarstwySieci[warstwa + 1]; nastepnyNeuron++) {
  119. sum += wagi[warstwa + 1][nastepnyNeuron][neuron] * errorSignal[warstwa + 1][nastepnyNeuron];
  120. }
  121. errorSignal[warstwa][neuron] = sum * pochodnaPoWyjsciu[warstwa][neuron];
  122. }
  123. }
  124.  
  125. private double sigmoid(double x)
  126. {
  127. return 1d / (1 + Math.exp(-x));
  128. }
  129.  
  130. private ArrayList<double[]> trenowac(ArrayList<double[]> wejscie, ArrayList<double[]> wyjscie, double eta, double momentum, boolean ifBias, int iteracje, int rozmiarTablicy)
  131. {
  132. double[] tablicaBledu = new double[iteracje];
  133. double obliczonyBlad;
  134. ArrayList<double[]> temp = new ArrayList<>();
  135. for (int i = 0; i < iteracje; i++)
  136. {
  137. int[] tablica = new int[rozmiarTablicy];
  138. for (int r = 0; r < rozmiarTablicy; r++)
  139. tablica[r] = r;
  140. szufladkujTablice(tablica);
  141. obliczonyBlad = 0;
  142. for (int k = 0; k < rozmiarTablicy; k++)
  143. {
  144. trening(wejscie.get(tablica[k]), wyjscie.get(tablica[k]), eta, momentum, ifBias);
  145. double sum = 0;
  146. for (int t = 0; t < rozmiarTablicy; t++)
  147. {
  148. sum += (wyjscie.get(tablica[k])[t] - this.wyjscie[rozmiarSieci - 1][t]) * ((wyjscie.get(tablica[k])[t]) - this.wyjscie[rozmiarSieci - 1][t]);
  149. }
  150. obliczonyBlad += sum;
  151.  
  152. }
  153. tablicaBledu[i] = 1d / 4d * obliczonyBlad;
  154. if (i == iteracje - 1)
  155. {
  156. for (int k = 0; k < rozmiarTablicy; k++)
  157. {
  158. list.add(oblicz(ifBias, wejscie.get(k)));
  159. temp.add(list.get(k).clone());
  160. }
  161. }
  162. }
  163. rysujBlad(tablicaBledu);
  164. return temp;
  165. }
  166.  
  167. private static void szufladkujTablice(int[] tablica)
  168. {
  169. Random rnd = ThreadLocalRandom.current();
  170. for (int i = tablica.length - 1; i > 0; i--)
  171. {
  172. int index = rnd.nextInt(i + 1);
  173. int a = tablica[index];
  174. tablica[index] = tablica[i];
  175. tablica[i] = a;
  176. }
  177. }
  178. private int czytajZPliku(ArrayList<double[]> UI, ArrayList<double[]> IO)
  179. {
  180. int rozmiarTablicy = 0;
  181. try{
  182. Scanner scanner = new Scanner(new File("dane.txt"));
  183. while (scanner.hasNext())
  184. {
  185. rozmiarTablicy++;
  186. double[] arr = Stream.of(scanner.nextLine().split(","))
  187. .mapToDouble (Double::parseDouble)
  188. .toArray();
  189.  
  190. UI.add(arr);
  191. IO.add(arr);
  192. }
  193. }catch(FileNotFoundException e ) {
  194. System.out.println(e);
  195. }
  196. return rozmiarTablicy;
  197. }
  198.  
  199. private void rysujBlad(double[] error)
  200. {
  201. XYPlot plot = new XYPlot();
  202. XYSeriesCollection dataset = new XYSeriesCollection();
  203.  
  204. XYSeries series = new XYSeries("Funkcja błędu") ;
  205. for(int i = 0 ; i < error.length ; i++)
  206. {
  207. series.add(i,error[i]);
  208. }
  209. dataset.addSeries(series);
  210. plot.setDataset(dataset);
  211.  
  212. XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer(true, false);
  213. plot.setRenderer(renderer);
  214. plot.setDomainAxis(new org.jfree.chart.axis.NumberAxis("Iteracje"));
  215. plot.setRangeAxis(new org.jfree.chart.axis.NumberAxis("Błąd"));
  216.  
  217. JFreeChart chart = new JFreeChart("Perceptron wielowarstwowy", JFreeChart.DEFAULT_TITLE_FONT, plot, true);
  218.  
  219. ChartFrame frame1 = new ChartFrame("Perceptron wielowarstwowy",chart);
  220. frame1.setVisible(true);
  221.  
  222.  
  223. frame1.setSize(1150,650);
  224.  
  225. }
  226. public static void main(String[] args) {
  227.  
  228. Scanner scan = new Scanner(System.in);
  229. double eta = 0.0;
  230. double momentum = 0.0;
  231. int iteracje = 0;
  232. boolean czyBiasWlaczony = false;
  233. int liczbaNeuronowUkrytych = 0;
  234.  
  235. DecimalFormat df = new DecimalFormat("####.###########");
  236. ArrayList<double[]> wejscia = new ArrayList<double[]>();
  237. ArrayList<double[]> wyjscia= new ArrayList<double[]>();
  238. ArrayList<double[]> wyniki;
  239. /*
  240.  
  241. System.out.println("Podaj ilość neuronów w warstwie ukrytej:");
  242. liczbaNeuronowUkrytych=scan.nextInt();
  243. System.out.println("Podaj etę:");
  244. eta=scan.nextDouble();
  245. System.out.println("Podaj momentum:");
  246. momentum=scan.nextDouble();
  247. System.out.println("Podaj liczbę iteracji:");
  248. iteracje=scan.nextInt();
  249. System.out.println("Użyć biasu? T by uzyc, cokolwiek by nie uzyc:");
  250. char ans=scan.next().charAt(0);
  251. if(ans=='T') {czyBiasWlaczony = true; System.out.println("Bias aktywny.");}
  252. else {System.out.println("Bias nieaktywny.");}
  253. scan.close();
  254. */
  255. //Network net = new Network(4, numOfHiddenNeurons, 4);
  256. Network net = new Network(4, 2, 4);
  257.  
  258. // wyniki = net.Training(wejscia,wyjscia,eta ,momentum,biasEnabled,iteracje,net.readFile(wejscia,wyjscia));
  259. wyniki = net.trenowac(wejscia,wyjscia,0.1 ,0,false,10000,net.czytajZPliku(wejscia,wyjscia));
  260.  
  261. System.out.println();
  262.  
  263. for(int i = 0; i < 4; i++)
  264. {
  265. for(int j = 0; j < 4; j++)
  266. {
  267. System.out.print(df.format(wyniki.get(i)[j]) + " ");
  268. }
  269. System.out.println();
  270. }
  271. System.out.println();
  272.  
  273. }}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement