Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package NeuralNetwork;
- import Miscellaneous.RandomGenerator;
- import java.io.*;
- import java.util.*;
- import java.util.stream.Collectors;
- import java.util.stream.Stream;
- public class Network {
- public final RandomGenerator rg = new RandomGenerator();
- public final ArrayList<Layer> layers = new ArrayList<>();
- public int epoch;
- private boolean debug = false;
- LinkedList<ArrayList<Double>> listOfErrors;
- public Network(ArrayList<Integer> dimensions) {
- listOfErrors = new LinkedList<>();
- for(Integer i : dimensions)
- addLayer(i);
- }
- public Layer getInputLayer() {
- return layers.get(0);
- }
- public Layer getOutputLayer() {
- return layers.get(layers.size() - 1);
- }
- public void batch(Model m, ArrayList<Double> x, ArrayList<Double> y) {
- this.batchFullFeedForward(m, x, y);
- if(m.batchSize == listOfErrors.size()) {
- while(!listOfErrors.isEmpty()) {
- this.batchfullBackWardPropagation(m.functionPair.derivativeFunction, y);
- this.fullUpdateWeights(m.alphaLearningRate);
- }
- }
- }
- private void trainEpoch(Model m, TrainingSet ts) {
- for (var td : ts.set) {
- for (int i = 0; i < m.batchSize; ++i ) {
- if (debug)
- System.out.println("key: " + td.input.toString() + " val: "
- + td.label.toString());
- ArrayList<Double> x = td.input;
- ArrayList<Double> y = td.label;
- //fullFeedForward(m.functionPair.activationFunction, x);
- //var curError = getOutputLayer().computeInitialErrors(m.functionPair.derivativeFunction, y);
- batch(m, x, y);
- if (debug)
- System.out.println(this);
- }
- }
- }
- /* train() uses the model to set up weights and bias */
- public void train(Model m, TrainingSet ts) {
- assert m != null;
- assert ts != null;
- randomizeWeights();
- for (epoch = 0; !m.verifyModel.Verify(this); ++epoch)
- trainEpoch(m, ts);
- }
- public void save(FileWriter fileToSave) throws IOException {
- final BufferedWriter writer = new BufferedWriter(fileToSave);
- // Dimensions
- for (Layer l : layers)
- writer.write(l.nbNeuron + " ");
- writer.newLine();
- // Weights and bias
- for (Layer l : layers) {
- for (Double bia : l.bias)
- writer.write(bia.toString() + " ");
- if (l == getOutputLayer())
- continue;
- writer.newLine();
- for (var ws : l.nextLink.weights)
- for (Double w : ws)
- writer.write(w.toString() + " ");
- writer.newLine();
- }
- writer.close();
- }
- public static Network fromFile(FileReader fr) {
- final Network n;
- final var dim = new ArrayList<Integer>();
- final Scanner sc = new Scanner(fr).useLocale(Locale.US);
- // Get dimensions
- while (sc.hasNextInt())
- dim.add(sc.nextInt());
- n = new Network(dim);
- for (Layer l : n.layers) {
- l.bias = Stream.generate(sc::nextDouble)
- .limit(l.nbNeuron)
- .collect(Collectors.toCollection(ArrayList::new));
- // Get Link with next Layer
- // Does not exist for last Layer
- if (l == n.getOutputLayer())
- continue;
- for (int j = 0; j < l.nextLink.weights.size(); ++j)
- l.nextLink.weights.set(j, Stream.generate(sc::nextDouble)
- .limit(l.nextLayer().nbNeuron)
- .collect(Collectors.toCollection(ArrayList::new)));
- }
- sc.close();
- return n;
- }
- public void addLayer(int size) {
- // Create a layer and add it
- Layer l = new Layer(size);
- layers.add(l);
- // Can we link this one to a previous ?
- if (layers.size() == 1)
- return;
- // Retrieve last one
- Layer prev = layers.get(layers.size() - 2);
- // Create a link
- new Link(prev, l);
- }
- public void randomizeWeights() {
- for(int il = 1; il < layers.size(); ++il)
- {
- Layer l = layers.get(il);
- ArrayList<ArrayList<Double>> weights = l.prevLink.weights;
- for (ArrayList<Double> weight : weights)
- for (int j = 0; j < weight.size(); ++j)
- weight.set(j, rg.generateValue());
- // SHOULD WE DELETE THIS ?
- //for(int i = 0; i < l.bias.size(); ++i)
- // l.bias.set(i, rg.generateValue());
- }
- }
- public void setInput(ArrayList<Double> input) {
- assert (layers.size() != 0);
- Layer inputLayer = getInputLayer();
- assert(input.size() == inputLayer.nbNeuron);
- for(int i = 0; i < inputLayer.nbNeuron; ++i)
- inputLayer.values.set(i,input.get(i));
- }
- public ArrayList<Double> batchFullFeedForward(Model model, ArrayList<Double> input, ArrayList<Double> output)
- {
- var res = fullFeedForward(model.functionPair.activationFunction, input);
- listOfErrors.add(getOutputLayer().computeInitialErrors(model.functionPair.derivativeFunction, output));
- return res;
- }
- public ArrayList<Double> fullFeedForward(ActivationFunction af, ArrayList<Double> input) {
- setInput(input);
- for(int i = 1; i < layers.size(); ++i) {
- Layer l = layers.get(i);
- l.feedForward(af);
- }
- return getOutputLayer().values;
- }
- public void setInitilaErrors()
- {
- ArrayList<Double> errors = this.listOfErrors.pop();
- Layer outputLayer = getOutputLayer();
- for(int i = 0; i < errors.size(); ++i)
- outputLayer.errors.set(i, errors.get(i));
- }
- public void batchfullBackWardPropagation(DerivativeFunction df, ArrayList<Double> expectedValues)
- {
- setInitilaErrors();
- fullBackWardPropagation(df,expectedValues);
- }
- public void fullBackWardPropagation(DerivativeFunction df, ArrayList<Double> expectedValues) {
- Layer l = getOutputLayer();
- l.computeInitialErrors(df, expectedValues);
- // trying to get the last layer which we have not compute the error
- for(l = l.prevLayer(); l.prevLink != null ; l = l.prevLayer()) {
- l.computeErrors(df);
- }
- }
- // the error need to be compute when this function is used
- public void fullUpdateWeights(Double alphaLearningRate) {
- for(Layer l = getOutputLayer(); l.prevLink != null; l = l.prevLayer()) {
- l.prevLink.UpdateWeights(alphaLearningRate);
- }
- }
- public void activateDebug()
- {debug = true;}
- @Override
- public boolean equals(Object other) {
- if (this == other)
- return true;
- if (other == null || other.getClass() != Network.class)
- return false;
- Network otherN = (Network) other;
- if (this.layers.size() != otherN.layers.size())
- return false;
- for (int i = 0; i < this.layers.size(); ++i) {
- Layer curL = layers.get(i);
- Layer othL = otherN.layers.get(i);
- if (!curL.equals(othL))
- return false;
- Link nextLink = curL.nextLink;
- Link nextOthL = othL.nextLink;
- if (nextLink == null) {
- if (nextOthL != null)
- return false;
- continue;
- }
- if (!nextLink.equal(nextOthL))
- return false;
- }
- return true;
- }
- @Override
- public String toString() {
- final var sb = new StringBuilder();
- Layer l = getInputLayer();
- while (l != getOutputLayer()) {
- sb.append(" -BIAS- ").append(l.bias).append("\n");
- for (ArrayList<Double> weight : l.nextLink.weights)
- sb.append(weight).append("\n");
- l = l.nextLayer();
- //System.out.println("_____________________________________________________________");
- }
- sb.append(getOutputLayer().bias);
- sb.append("===================================================");
- sb.append("===================================================");
- return sb.toString();
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement