Advertisement
Guest User

Untitled

a guest
Apr 26th, 2018
72
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.47 KB | None | 0 0
  1. package MLP;
  2.  
  3. import java.io.File;
  4. import java.io.FileNotFoundException;
  5. import java.text.DecimalFormat;
  6. import java.util.ArrayList;
  7. import java.util.Random;
  8. import java.util.Scanner;
  9. import java.util.concurrent.ThreadLocalRandom;
  10. import java.util.stream.Stream;
  11.  
  12. import org.jfree.chart.*;
  13. import org.jfree.chart.plot.XYPlot;
  14. import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
  15. import org.jfree.data.xy.XYSeries;
  16. import org.jfree.data.xy.XYSeriesCollection;
  17.  
  18.  
  19.  
  20. public class Network {
  21.  
  22. private double[][] wyjscie;
  23. private double[][][] wagi;
  24. private double[][] bias;
  25. private double[][] vectorMomentum;
  26. private double[][] errorSignal;
  27. private double[][] pochodnaPoWyjsciu;
  28.  
  29. ArrayList<double[]> list = new ArrayList<>();
  30.  
  31. private final int[] rozmiarWarstwySieci;
  32. private final int rozmiarWejscia;
  33. private final int rozmiarWyjscia;
  34. private final int rozmiarSieci;
  35.  
  36. public Network(int... tablica) {
  37. rozmiarWarstwySieci = tablica;
  38. rozmiarWejscia = rozmiarWarstwySieci[0];
  39. rozmiarSieci = rozmiarWarstwySieci.length;
  40. rozmiarWyjscia = rozmiarWarstwySieci[rozmiarSieci -1];
  41.  
  42. this.wyjscie = new double[rozmiarSieci][];
  43. this.wagi = new double[rozmiarSieci][][];
  44. this.bias = new double[rozmiarSieci][];
  45. this.vectorMomentum = new double[rozmiarSieci][];
  46. this.errorSignal = new double[rozmiarSieci][];
  47. this.pochodnaPoWyjsciu = new double[rozmiarSieci][];
  48.  
  49. for(int i = 0; i < rozmiarSieci; i++)
  50. {
  51. this.wyjscie[i] = new double[tablica[i]];
  52. this.errorSignal[i] = new double[tablica[i]];
  53. this.pochodnaPoWyjsciu[i] = new double[tablica[i]];
  54. this.bias[i] = NetworkTools.stworzRandomowaTablice(rozmiarWarstwySieci[i],0.2,0.7);
  55. this.vectorMomentum[i] = new double[tablica[i]];
  56. if(i>0)
  57. {
  58. wagi[i] = NetworkTools.stworzRandomowaTablice(rozmiarWarstwySieci[i],tablica[i-1],-0.3,0.5);
  59. }
  60. }
  61. }
  62.  
  63. private void trening(double[] wejscie, double[] cel, double eta, double momentum, boolean ifBias)
  64. {
  65. if (wejscie.length != rozmiarWejscia || cel.length != rozmiarWyjscia) return;
  66. oblicz(ifBias, wejscie);
  67. bladWstecznejPropragacji(cel);
  68. zaktualizujWagi(eta, momentum);
  69. }
  70. public double[] oblicz(boolean ifBias, double... wejscie)
  71. {
  72. if (wejscie.length != this.rozmiarWejscia) return null;
  73. this.wyjscie[0] = wejscie;
  74. for (int warstwa = 1; warstwa < rozmiarSieci; warstwa++)
  75. for (int neuron = 0; neuron < rozmiarWarstwySieci[warstwa]; neuron++) {
  76. double sum;
  77. if (ifBias)
  78. sum = bias[warstwa][neuron];
  79. else
  80. sum = 0;
  81. for (int poprzedniNeuron = 0; poprzedniNeuron < rozmiarWarstwySieci[warstwa - 1]; poprzedniNeuron++) {
  82. sum += wyjscie[warstwa - 1][poprzedniNeuron] * wagi[warstwa][neuron][poprzedniNeuron];
  83. }
  84. wyjscie[warstwa][neuron] = sigmoid(sum);
  85. pochodnaPoWyjsciu[warstwa][neuron] = wyjscie[warstwa][neuron] * (1 - wyjscie[warstwa][neuron]);
  86. }
  87. return wyjscie[rozmiarSieci - 1];
  88. }
  89.  
  90. private void zaktualizujWagi(double eta, double momentum)
  91. {
  92. double delta = 0.0;
  93. for (int warstwa = 1; warstwa < rozmiarSieci; warstwa++)
  94. {
  95. //delta = -eta * errorSignal[layer][neuron];
  96. //bias[layer][neuron] += delta;
  97. for (int neuron = 0; neuron < rozmiarWarstwySieci[warstwa]; neuron++)
  98. for (int poprzedniNeuron = 0; poprzedniNeuron < rozmiarWarstwySieci[warstwa - 1]; poprzedniNeuron++) {
  99. //delta = -eta * output[layer - 1][prevNeuron] * errorSignal[layer][neuron] + momentum * prevWeight;
  100. delta = -eta * wyjscie[warstwa - 1][poprzedniNeuron] * errorSignal[warstwa][neuron] + momentum * vectorMomentum[warstwa][neuron];
  101. //prevWeight = delta;
  102. vectorMomentum[warstwa][neuron] = delta;
  103. wagi[warstwa][neuron][poprzedniNeuron] += delta;
  104. bias[warstwa][neuron] += -eta * errorSignal[warstwa][neuron];
  105. }
  106. }
  107. }
  108.  
  109. private void bladWstecznejPropragacji(double[] cel)
  110. {
  111. for (int neuron = 0; neuron < rozmiarWarstwySieci[rozmiarSieci - 1]; neuron++)
  112. errorSignal[rozmiarSieci - 1][neuron] = (wyjscie[rozmiarSieci - 1][neuron] - cel[neuron])
  113. * pochodnaPoWyjsciu[rozmiarSieci - 1][neuron];
  114. for (int warstwa = rozmiarSieci - 2; warstwa > 0; warstwa--)
  115. for (int neuron = 0; neuron < rozmiarWarstwySieci[warstwa]; neuron++) {
  116. double sum = 0;
  117. for (int nastepnyNeuron = 0; nastepnyNeuron < rozmiarWarstwySieci[warstwa + 1]; nastepnyNeuron++) {
  118. sum += wagi[warstwa + 1][nastepnyNeuron][neuron] * errorSignal[warstwa + 1][nastepnyNeuron];
  119. }
  120. errorSignal[warstwa][neuron] = sum * pochodnaPoWyjsciu[warstwa][neuron];
  121. }
  122. }
  123.  
  124. private double sigmoid(double x)
  125. {
  126. return 1d / (1 + Math.exp(-x));
  127. }
  128.  
  129. private ArrayList<double[]> trenowac(ArrayList<double[]> wejscie, ArrayList<double[]> wyjscie, double eta, double momentum, boolean ifBias, int iteracje, int rozmiarTablicy)
  130. {
  131. double[] tablicaBledu = new double[iteracje];
  132. double obliczonyBlad;
  133. ArrayList<double[]> temp = new ArrayList<>();
  134. for (int i = 0; i < iteracje; i++)
  135. {
  136. int[] tablica = new int[rozmiarTablicy];
  137. for (int r = 0; r < rozmiarTablicy; r++)
  138. tablica[r] = r;
  139. szufladkujTablice(tablica);
  140. obliczonyBlad = 0;
  141. for (int k = 0; k < rozmiarTablicy; k++)
  142. {
  143. trening(wejscie.get(tablica[k]), wyjscie.get(tablica[k]), eta, momentum, ifBias);
  144. double sum = 0;
  145. for (int t = 0; t < rozmiarTablicy; t++)
  146. {
  147. sum += (wyjscie.get(tablica[k])[t] - this.wyjscie[rozmiarSieci - 1][t]) * ((wyjscie.get(tablica[k])[t]) - this.wyjscie[rozmiarSieci - 1][t]);
  148. }
  149. obliczonyBlad += sum;
  150.  
  151. }
  152. tablicaBledu[i] = 1d / 4d * obliczonyBlad;
  153. if (i == iteracje - 1)
  154. {
  155. for (int k = 0; k < rozmiarTablicy; k++)
  156. {
  157. list.add(oblicz(ifBias, wejscie.get(k)));
  158. temp.add(list.get(k).clone());
  159. }
  160. }
  161. }
  162. rysujBlad(tablicaBledu);
  163. return temp;
  164. }
  165.  
  166. private static void szufladkujTablice(int[] tablica)
  167. {
  168. Random rnd = ThreadLocalRandom.current();
  169. for (int i = tablica.length - 1; i > 0; i--)
  170. {
  171. int index = rnd.nextInt(i + 1);
  172. int a = tablica[index];
  173. tablica[index] = tablica[i];
  174. tablica[i] = a;
  175. }
  176. }
  177. private int czytajZPliku(ArrayList<double[]> UI, ArrayList<double[]> IO)
  178. {
  179. int rozmiarTablicy = 0;
  180. try{
  181. Scanner scanner = new Scanner(new File("dane.txt"));
  182. while (scanner.hasNext())
  183. {
  184. rozmiarTablicy++;
  185. double[] tablica = Stream.of(scanner.nextLine().split(","))
  186. .mapToDouble (Double::parseDouble)
  187. .toArray();
  188.  
  189. UI.add(tablica);
  190. IO.add(tablica);
  191. }
  192. }catch(FileNotFoundException e ) {
  193. System.out.println(e);
  194. }
  195. return rozmiarTablicy;
  196. }
  197.  
  198. private void rysujBlad(double[] error)
  199. {
  200. XYPlot plot = new XYPlot();
  201. XYSeriesCollection dataset = new XYSeriesCollection();
  202.  
  203. XYSeries series = new XYSeries("Funkcja błędu") ;
  204. for(int i = 0 ; i < error.length ; i++)
  205. {
  206. series.add(i,error[i]);
  207. }
  208. dataset.addSeries(series);
  209. plot.setDataset(dataset);
  210.  
  211. XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer(true, false);
  212. plot.setRenderer(renderer);
  213. plot.setDomainAxis(new org.jfree.chart.axis.NumberAxis("Iteracje"));
  214. plot.setRangeAxis(new org.jfree.chart.axis.NumberAxis("Błąd"));
  215.  
  216. JFreeChart chart = new JFreeChart("Perceptron wielowarstwowy", JFreeChart.DEFAULT_TITLE_FONT, plot, true);
  217.  
  218. ChartFrame frame1 = new ChartFrame("Perceptron wielowarstwowy",chart);
  219. frame1.setVisible(true);
  220.  
  221.  
  222. frame1.setSize(1150,650);
  223.  
  224. }
  225. public static void main(String[] args) {
  226.  
  227. Scanner scan = new Scanner(System.in);
  228. double eta = 0.0;
  229. double momentum = 0.0;
  230. int iteracje = 0;
  231. boolean czyBiasWlaczony = false;
  232. int liczbaNeuronowUkrytych = 0;
  233.  
  234. DecimalFormat df = new DecimalFormat("####.###########");
  235. ArrayList<double[]> wejscia = new ArrayList<double[]>();
  236. ArrayList<double[]> wyjscia= new ArrayList<double[]>();
  237. ArrayList<double[]> wyniki;
  238. /*
  239.  
  240. System.out.println("Podaj ilość neuronów w warstwie ukrytej:");
  241. liczbaNeuronowUkrytych=scan.nextInt();
  242. System.out.println("Podaj etę:");
  243. eta=scan.nextDouble();
  244. System.out.println("Podaj momentum:");
  245. momentum=scan.nextDouble();
  246. System.out.println("Podaj liczbę iteracji:");
  247. iteracje=scan.nextInt();
  248. System.out.println("Użyć biasu? T by uzyc, cokolwiek by nie uzyc:");
  249. char ans=scan.next().charAt(0);
  250. if(ans=='T') {czyBiasWlaczony = true; System.out.println("Bias aktywny.");}
  251. else {System.out.println("Bias nieaktywny.");}
  252. scan.close();
  253. */
  254. //Network net = new Network(4, numOfHiddenNeurons, 4);
  255. Network net = new Network(4, 2, 4);
  256.  
  257. // wyniki = net.Training(wejscia,wyjscia,eta ,momentum,biasEnabled,iteracje,net.readFile(wejscia,wyjscia));
  258. wyniki = net.trenowac(wejscia,wyjscia,0.1 ,0,false,10000,net.czytajZPliku(wejscia,wyjscia));
  259.  
  260. System.out.println();
  261.  
  262. for(int i = 0; i < 4; i++)
  263. {
  264. for(int j = 0; j < 4; j++)
  265. {
  266. System.out.print(df.format(wyniki.get(i)[j]) + " ");
  267. }
  268. System.out.println();
  269. }
  270. System.out.println();
  271.  
  272. }}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement