SHARE
TWEET

create new

a guest Sep 20th, 2019 110 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. package NeuralNetwork;
  2.  
  3. import Miscellaneous.RandomGenerator;
  4.  
  5. import java.io.*;
  6. import java.util.*;
  7. import java.util.stream.Collectors;
  8. import java.util.stream.Stream;
  9.  
  10. public class Network {
  11.     public final RandomGenerator rg = new RandomGenerator();
  12.     public final ArrayList<Layer> layers = new ArrayList<>();
  13.     public int epoch;
  14.     private boolean debug = false;
  15.     LinkedList<ArrayList<Double>> listOfErrors;
  16.  
  17.     public Network(ArrayList<Integer> dimensions) {
  18.         listOfErrors = new LinkedList<>();
  19.         for(Integer i : dimensions)
  20.             addLayer(i);
  21.     }
  22.  
  23.     public Layer getInputLayer() {
  24.         return layers.get(0);
  25.     }
  26.  
  27.     public Layer getOutputLayer() {
  28.         return layers.get(layers.size() - 1);
  29.     }
  30.  
  31.     public void batch(Model m, ArrayList<Double> x, ArrayList<Double> y) {
  32.         this.batchFullFeedForward(m, x, y);
  33.         if(m.batchSize == listOfErrors.size()) {
  34.             while(!listOfErrors.isEmpty()) {
  35.                 this.batchfullBackWardPropagation(m.functionPair.derivativeFunction, y);
  36.                 this.fullUpdateWeights(m.alphaLearningRate);
  37.             }
  38.         }
  39.     }
  40.  
  41.     private void trainEpoch(Model m, TrainingSet ts) {
  42.         for (var td : ts.set) {
  43.             for (int i = 0; i < m.batchSize; ++i ) {
  44.                 if (debug)
  45.                     System.out.println("key: " + td.input.toString() + " val: "
  46.                             + td.label.toString());
  47.                 ArrayList<Double> x = td.input;
  48.                 ArrayList<Double> y = td.label;
  49.  
  50.                 //fullFeedForward(m.functionPair.activationFunction, x);
  51.                 //var curError = getOutputLayer().computeInitialErrors(m.functionPair.derivativeFunction, y);
  52.                 batch(m, x, y);
  53.                 if (debug)
  54.                     System.out.println(this);
  55.             }
  56.         }
  57.     }
  58.  
  59.     /* train() uses the model to set up weights and bias */
  60.     public void train(Model m, TrainingSet ts) {
  61.         assert m != null;
  62.         assert ts != null;
  63.         randomizeWeights();
  64.         for (epoch = 0; !m.verifyModel.Verify(this); ++epoch)
  65.             trainEpoch(m, ts);
  66.     }
  67.  
  68.     public void save(FileWriter fileToSave) throws IOException {
  69.         final BufferedWriter writer = new BufferedWriter(fileToSave);
  70.         // Dimensions
  71.         for (Layer l : layers)
  72.             writer.write(l.nbNeuron + " ");
  73.         writer.newLine();
  74.         // Weights and bias
  75.         for (Layer l : layers) {
  76.             for (Double bia : l.bias)
  77.                 writer.write(bia.toString() + " ");
  78.             if (l == getOutputLayer())
  79.                 continue;
  80.             writer.newLine();
  81.             for (var ws : l.nextLink.weights)
  82.                 for (Double w : ws)
  83.                     writer.write(w.toString() + " ");
  84.             writer.newLine();
  85.         }
  86.         writer.close();
  87.     }
  88.  
  89.     public static Network fromFile(FileReader fr) {
  90.         final Network n;
  91.         final var dim = new ArrayList<Integer>();
  92.         final Scanner sc = new Scanner(fr).useLocale(Locale.US);
  93.  
  94.         // Get dimensions
  95.         while (sc.hasNextInt())
  96.             dim.add(sc.nextInt());
  97.  
  98.         n = new Network(dim);
  99.  
  100.         for (Layer l : n.layers) {
  101.             l.bias = Stream.generate(sc::nextDouble)
  102.                     .limit(l.nbNeuron)
  103.                     .collect(Collectors.toCollection(ArrayList::new));
  104.             // Get Link with next Layer
  105.             // Does not exist for last Layer
  106.             if (l == n.getOutputLayer())
  107.                 continue;
  108.             for (int j = 0; j < l.nextLink.weights.size(); ++j)
  109.                 l.nextLink.weights.set(j, Stream.generate(sc::nextDouble)
  110.                         .limit(l.nextLayer().nbNeuron)
  111.                         .collect(Collectors.toCollection(ArrayList::new)));
  112.         }
  113.         sc.close();
  114.         return n;
  115.     }
  116.  
  117.     public void addLayer(int size) {
  118.         // Create a layer and add it
  119.         Layer l = new Layer(size);
  120.         layers.add(l);
  121.         // Can we link this one to a previous ?
  122.         if (layers.size() == 1)
  123.             return;
  124.         // Retrieve last one
  125.         Layer prev = layers.get(layers.size() - 2);
  126.         // Create a link
  127.         new Link(prev, l);
  128.     }
  129.  
  130.     public void randomizeWeights() {
  131.  
  132.         for(int il = 1; il < layers.size(); ++il)
  133.         {
  134.             Layer l = layers.get(il);
  135.             ArrayList<ArrayList<Double>> weights = l.prevLink.weights;
  136.             for (ArrayList<Double> weight : weights)
  137.                 for (int j = 0; j < weight.size(); ++j)
  138.                     weight.set(j, rg.generateValue());
  139.             // SHOULD WE DELETE THIS ?
  140.             //for(int i = 0; i < l.bias.size(); ++i)
  141.             //    l.bias.set(i, rg.generateValue());
  142.         }
  143.     }
  144.  
  145.     public void setInput(ArrayList<Double> input) {
  146.         assert (layers.size() != 0);
  147.         Layer inputLayer = getInputLayer();
  148.         assert(input.size() == inputLayer.nbNeuron);
  149.  
  150.         for(int i = 0; i < inputLayer.nbNeuron; ++i)
  151.             inputLayer.values.set(i,input.get(i));
  152.     }
  153.  
  154.     public ArrayList<Double> batchFullFeedForward(Model model, ArrayList<Double> input, ArrayList<Double> output)
  155.     {
  156.         var res = fullFeedForward(model.functionPair.activationFunction, input);
  157.         listOfErrors.add(getOutputLayer().computeInitialErrors(model.functionPair.derivativeFunction, output));
  158.         return res;
  159.     }
  160.  
  161.     public ArrayList<Double> fullFeedForward(ActivationFunction af, ArrayList<Double> input) {
  162.         setInput(input);
  163.         for(int i = 1; i < layers.size(); ++i) {
  164.             Layer l = layers.get(i);
  165.             l.feedForward(af);
  166.         }
  167.         return getOutputLayer().values;
  168.     }
  169.  
  170.     public void setInitilaErrors()
  171.     {
  172.         ArrayList<Double> errors = this.listOfErrors.pop();
  173.         Layer outputLayer = getOutputLayer();
  174.         for(int i = 0; i < errors.size(); ++i)
  175.             outputLayer.errors.set(i, errors.get(i));
  176.     }
  177.  
  178.     public void batchfullBackWardPropagation(DerivativeFunction df, ArrayList<Double> expectedValues)
  179.     {
  180.         setInitilaErrors();
  181.         fullBackWardPropagation(df,expectedValues);
  182.     }
  183.  
  184.     public void fullBackWardPropagation(DerivativeFunction df, ArrayList<Double> expectedValues) {
  185.         Layer l = getOutputLayer();
  186.         l.computeInitialErrors(df, expectedValues);
  187.         // trying to get the last layer which we have not compute the error
  188.         for(l = l.prevLayer(); l.prevLink != null ; l = l.prevLayer()) {
  189.             l.computeErrors(df);
  190.         }
  191.     }
  192.  
  193.     // the error need to be compute when this function is used
  194.     public void fullUpdateWeights(Double alphaLearningRate) {
  195.       for(Layer l = getOutputLayer(); l.prevLink != null; l = l.prevLayer()) {
  196.           l.prevLink.UpdateWeights(alphaLearningRate);
  197.       }
  198.     }
  199.  
  200.     public void activateDebug()
  201.     {debug = true;}
  202.  
  203.     @Override
  204.     public boolean equals(Object other) {
  205.         if (this == other)
  206.             return true;
  207.         if (other == null || other.getClass() != Network.class)
  208.             return false;
  209.  
  210.         Network otherN = (Network) other;
  211.         if (this.layers.size() != otherN.layers.size())
  212.             return false;
  213.         for (int i = 0; i < this.layers.size(); ++i) {
  214.             Layer curL = layers.get(i);
  215.             Layer othL = otherN.layers.get(i);
  216.             if (!curL.equals(othL))
  217.                 return false;
  218.             Link nextLink = curL.nextLink;
  219.             Link nextOthL = othL.nextLink;
  220.             if (nextLink == null) {
  221.                 if (nextOthL != null)
  222.                     return false;
  223.                 continue;
  224.             }
  225.             if (!nextLink.equal(nextOthL))
  226.                 return false;
  227.         }
  228.         return true;
  229.     }
  230.  
  231.     @Override
  232.     public String toString() {
  233.         final var sb = new StringBuilder();
  234.  
  235.         Layer l = getInputLayer();
  236.         while (l != getOutputLayer()) {
  237.             sb.append(" -BIAS- ").append(l.bias).append("\n");
  238.             for (ArrayList<Double> weight : l.nextLink.weights)
  239.                 sb.append(weight).append("\n");
  240.             l = l.nextLayer();
  241.             //System.out.println("_____________________________________________________________");
  242.         }
  243.         sb.append(getOutputLayer().bias);
  244.         sb.append("===================================================");
  245.         sb.append("===================================================");
  246.         return sb.toString();
  247.     }
  248. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top