Advertisement
Guest User

create new

a guest
Sep 20th, 2019
160
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.48 KB | None | 0 0
  1. package NeuralNetwork;
  2.  
  3. import Miscellaneous.RandomGenerator;
  4.  
  5. import java.io.*;
  6. import java.util.*;
  7. import java.util.stream.Collectors;
  8. import java.util.stream.Stream;
  9.  
  10. public class Network {
  11. public final RandomGenerator rg = new RandomGenerator();
  12. public final ArrayList<Layer> layers = new ArrayList<>();
  13. public int epoch;
  14. private boolean debug = false;
  15. LinkedList<ArrayList<Double>> listOfErrors;
  16.  
  17. public Network(ArrayList<Integer> dimensions) {
  18. listOfErrors = new LinkedList<>();
  19. for(Integer i : dimensions)
  20. addLayer(i);
  21. }
  22.  
  23. public Layer getInputLayer() {
  24. return layers.get(0);
  25. }
  26.  
  27. public Layer getOutputLayer() {
  28. return layers.get(layers.size() - 1);
  29. }
  30.  
  31. public void batch(Model m, ArrayList<Double> x, ArrayList<Double> y) {
  32. this.batchFullFeedForward(m, x, y);
  33. if(m.batchSize == listOfErrors.size()) {
  34. while(!listOfErrors.isEmpty()) {
  35. this.batchfullBackWardPropagation(m.functionPair.derivativeFunction, y);
  36. this.fullUpdateWeights(m.alphaLearningRate);
  37. }
  38. }
  39. }
  40.  
  41. private void trainEpoch(Model m, TrainingSet ts) {
  42. for (var td : ts.set) {
  43. for (int i = 0; i < m.batchSize; ++i ) {
  44. if (debug)
  45. System.out.println("key: " + td.input.toString() + " val: "
  46. + td.label.toString());
  47. ArrayList<Double> x = td.input;
  48. ArrayList<Double> y = td.label;
  49.  
  50. //fullFeedForward(m.functionPair.activationFunction, x);
  51. //var curError = getOutputLayer().computeInitialErrors(m.functionPair.derivativeFunction, y);
  52. batch(m, x, y);
  53. if (debug)
  54. System.out.println(this);
  55. }
  56. }
  57. }
  58.  
  59. /* train() uses the model to set up weights and bias */
  60. public void train(Model m, TrainingSet ts) {
  61. assert m != null;
  62. assert ts != null;
  63. randomizeWeights();
  64. for (epoch = 0; !m.verifyModel.Verify(this); ++epoch)
  65. trainEpoch(m, ts);
  66. }
  67.  
  68. public void save(FileWriter fileToSave) throws IOException {
  69. final BufferedWriter writer = new BufferedWriter(fileToSave);
  70. // Dimensions
  71. for (Layer l : layers)
  72. writer.write(l.nbNeuron + " ");
  73. writer.newLine();
  74. // Weights and bias
  75. for (Layer l : layers) {
  76. for (Double bia : l.bias)
  77. writer.write(bia.toString() + " ");
  78. if (l == getOutputLayer())
  79. continue;
  80. writer.newLine();
  81. for (var ws : l.nextLink.weights)
  82. for (Double w : ws)
  83. writer.write(w.toString() + " ");
  84. writer.newLine();
  85. }
  86. writer.close();
  87. }
  88.  
  89. public static Network fromFile(FileReader fr) {
  90. final Network n;
  91. final var dim = new ArrayList<Integer>();
  92. final Scanner sc = new Scanner(fr).useLocale(Locale.US);
  93.  
  94. // Get dimensions
  95. while (sc.hasNextInt())
  96. dim.add(sc.nextInt());
  97.  
  98. n = new Network(dim);
  99.  
  100. for (Layer l : n.layers) {
  101. l.bias = Stream.generate(sc::nextDouble)
  102. .limit(l.nbNeuron)
  103. .collect(Collectors.toCollection(ArrayList::new));
  104. // Get Link with next Layer
  105. // Does not exist for last Layer
  106. if (l == n.getOutputLayer())
  107. continue;
  108. for (int j = 0; j < l.nextLink.weights.size(); ++j)
  109. l.nextLink.weights.set(j, Stream.generate(sc::nextDouble)
  110. .limit(l.nextLayer().nbNeuron)
  111. .collect(Collectors.toCollection(ArrayList::new)));
  112. }
  113. sc.close();
  114. return n;
  115. }
  116.  
  117. public void addLayer(int size) {
  118. // Create a layer and add it
  119. Layer l = new Layer(size);
  120. layers.add(l);
  121. // Can we link this one to a previous ?
  122. if (layers.size() == 1)
  123. return;
  124. // Retrieve last one
  125. Layer prev = layers.get(layers.size() - 2);
  126. // Create a link
  127. new Link(prev, l);
  128. }
  129.  
  130. public void randomizeWeights() {
  131.  
  132. for(int il = 1; il < layers.size(); ++il)
  133. {
  134. Layer l = layers.get(il);
  135. ArrayList<ArrayList<Double>> weights = l.prevLink.weights;
  136. for (ArrayList<Double> weight : weights)
  137. for (int j = 0; j < weight.size(); ++j)
  138. weight.set(j, rg.generateValue());
  139. // SHOULD WE DELETE THIS ?
  140. //for(int i = 0; i < l.bias.size(); ++i)
  141. // l.bias.set(i, rg.generateValue());
  142. }
  143. }
  144.  
  145. public void setInput(ArrayList<Double> input) {
  146. assert (layers.size() != 0);
  147. Layer inputLayer = getInputLayer();
  148. assert(input.size() == inputLayer.nbNeuron);
  149.  
  150. for(int i = 0; i < inputLayer.nbNeuron; ++i)
  151. inputLayer.values.set(i,input.get(i));
  152. }
  153.  
  154. public ArrayList<Double> batchFullFeedForward(Model model, ArrayList<Double> input, ArrayList<Double> output)
  155. {
  156. var res = fullFeedForward(model.functionPair.activationFunction, input);
  157. listOfErrors.add(getOutputLayer().computeInitialErrors(model.functionPair.derivativeFunction, output));
  158. return res;
  159. }
  160.  
  161. public ArrayList<Double> fullFeedForward(ActivationFunction af, ArrayList<Double> input) {
  162. setInput(input);
  163. for(int i = 1; i < layers.size(); ++i) {
  164. Layer l = layers.get(i);
  165. l.feedForward(af);
  166. }
  167. return getOutputLayer().values;
  168. }
  169.  
  170. public void setInitilaErrors()
  171. {
  172. ArrayList<Double> errors = this.listOfErrors.pop();
  173. Layer outputLayer = getOutputLayer();
  174. for(int i = 0; i < errors.size(); ++i)
  175. outputLayer.errors.set(i, errors.get(i));
  176. }
  177.  
  178. public void batchfullBackWardPropagation(DerivativeFunction df, ArrayList<Double> expectedValues)
  179. {
  180. setInitilaErrors();
  181. fullBackWardPropagation(df,expectedValues);
  182. }
  183.  
  184. public void fullBackWardPropagation(DerivativeFunction df, ArrayList<Double> expectedValues) {
  185. Layer l = getOutputLayer();
  186. l.computeInitialErrors(df, expectedValues);
  187. // trying to get the last layer which we have not compute the error
  188. for(l = l.prevLayer(); l.prevLink != null ; l = l.prevLayer()) {
  189. l.computeErrors(df);
  190. }
  191. }
  192.  
  193. // the error need to be compute when this function is used
  194. public void fullUpdateWeights(Double alphaLearningRate) {
  195. for(Layer l = getOutputLayer(); l.prevLink != null; l = l.prevLayer()) {
  196. l.prevLink.UpdateWeights(alphaLearningRate);
  197. }
  198. }
  199.  
  200. public void activateDebug()
  201. {debug = true;}
  202.  
  203. @Override
  204. public boolean equals(Object other) {
  205. if (this == other)
  206. return true;
  207. if (other == null || other.getClass() != Network.class)
  208. return false;
  209.  
  210. Network otherN = (Network) other;
  211. if (this.layers.size() != otherN.layers.size())
  212. return false;
  213. for (int i = 0; i < this.layers.size(); ++i) {
  214. Layer curL = layers.get(i);
  215. Layer othL = otherN.layers.get(i);
  216. if (!curL.equals(othL))
  217. return false;
  218. Link nextLink = curL.nextLink;
  219. Link nextOthL = othL.nextLink;
  220. if (nextLink == null) {
  221. if (nextOthL != null)
  222. return false;
  223. continue;
  224. }
  225. if (!nextLink.equal(nextOthL))
  226. return false;
  227. }
  228. return true;
  229. }
  230.  
  231. @Override
  232. public String toString() {
  233. final var sb = new StringBuilder();
  234.  
  235. Layer l = getInputLayer();
  236. while (l != getOutputLayer()) {
  237. sb.append(" -BIAS- ").append(l.bias).append("\n");
  238. for (ArrayList<Double> weight : l.nextLink.weights)
  239. sb.append(weight).append("\n");
  240. l = l.nextLayer();
  241. //System.out.println("_____________________________________________________________");
  242. }
  243. sb.append(getOutputLayer().bias);
  244. sb.append("===================================================");
  245. sb.append("===================================================");
  246. return sb.toString();
  247. }
  248. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement