Advertisement
Guest User

dl4j addition failure

a guest
Jan 27th, 2017
242
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 5.27 KB | None | 0 0
  1. import java.util.Random;
  2.  
  3. import org.deeplearning4j.nn.api.Layer;
  4. import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  5. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  6. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  7. import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder;
  8. import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
  9. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  10. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  11. import org.deeplearning4j.nn.conf.layers.OutputLayer.Builder;
  12. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  13. import org.deeplearning4j.nn.weights.WeightInit;
  14. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  15. import org.nd4j.linalg.activations.Activation;
  16. import org.nd4j.linalg.api.ndarray.INDArray;
  17. import org.nd4j.linalg.dataset.DataSet;
  18. import org.nd4j.linalg.factory.Nd4j;
  19. import org.nd4j.linalg.lossfunctions.LossFunctions;
  20.  
  21. /**
  22.  * This basic example shows how to manually create a DataSet and train it to an
  23.  * basic Network.
  24.  * <p>
  25.  * The network consists in 2 input-neurons, 1 hidden-layer with 4
  26.  * hidden-neurons, and 2 output-neurons.
  27.  * <p>
  28.  * I choose 2 output neurons, (the first fires for false, the second fires for
  29.  * true) because the Evaluation class needs one neuron per classification.
  30.  *
  31.  * @author Peter Großmann
  32.  */
  33. public class FFMultiply
  34. {
  35.    private static final int SAMPLES = 500;
  36.    private static final int LAYER_WIDTH = 2;
  37.    private static final int HIDDEN_LAYERS = 0;
  38.  
  39.    private static final Random r = new Random(9385);
  40.  
  41.    public static void main( String[] args )
  42.    {
  43.       INDArray inputData = Nd4j.zeros(SAMPLES, 2);
  44.       INDArray outputData = Nd4j.zeros(SAMPLES, 1);
  45.  
  46.       for (int i = 0; i < SAMPLES; i++)
  47.       {
  48.          double a = r.nextDouble() * 0.5;
  49.          double b = r.nextDouble() * 0.5;
  50.          double c = a + b;
  51.          inputData.putScalar(new int[] { i, 0 }, a);
  52.          inputData.putScalar(new int[] { i, 1 }, b);
  53.          outputData.putScalar(new int[] { i, 0 }, c);
  54.       }
  55.       DataSet ds = new DataSet(inputData, outputData);
  56.  
  57.       NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
  58.       builder.iterations(100);
  59.       builder.learningRate(0.01);
  60.       builder.seed(123);
  61.       builder.useDropConnect(false);
  62.       builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
  63.       builder.biasInit(0);
  64.       // TODO Enable?
  65.       builder.miniBatch(true);
  66.  
  67.       ListBuilder listBuilder = builder.list();
  68.       listBuilder.layer(0, new DenseLayer.Builder().nIn(2).nOut(LAYER_WIDTH).activation(Activation.IDENTITY).weightInit(WeightInit.DISTRIBUTION)
  69.          .dist(new UniformDistribution(0, 1)).build());
  70.       DenseLayer.Builder hiddenLayerBuilder = new DenseLayer.Builder().nIn(LAYER_WIDTH).nOut(LAYER_WIDTH).activation(Activation.IDENTITY)
  71.          .weightInit(WeightInit.DISTRIBUTION).dist(new UniformDistribution(0, 1));
  72.       for (int i = 0; i < HIDDEN_LAYERS; i++)
  73.       {
  74.          listBuilder.layer(i + 1, hiddenLayerBuilder.build());
  75.       }
  76.  
  77.       Builder outputLayerBuilder = new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD);
  78.       outputLayerBuilder.nIn(LAYER_WIDTH);
  79.       outputLayerBuilder.nOut(1);
  80.       outputLayerBuilder.activation(Activation.IDENTITY);
  81.       outputLayerBuilder.weightInit(WeightInit.DISTRIBUTION);
  82.       outputLayerBuilder.dist(new UniformDistribution(0, 1));
  83.       listBuilder.layer(HIDDEN_LAYERS + 1, outputLayerBuilder.build());
  84.  
  85.       // TODO What is this?
  86.       listBuilder.pretrain(false);
  87.  
  88.       // seems to be mandatory
  89.       // according to agibsonccc: You typically only use that with
  90.       // pretrain(true) when you want to do pretrain/finetune without changing
  91.       // the previous layers finetuned weights that's for autoencoders and
  92.       // rbms
  93.       // TODO Huh?
  94.       listBuilder.backprop(true);
  95.  
  96.       // build and init the network, will check if everything is configured
  97.       // correct
  98.       MultiLayerConfiguration conf = listBuilder.build();
  99.       MultiLayerNetwork net = new MultiLayerNetwork(conf);
  100.       net.init();
  101.  
  102.       // add an listener which outputs the error every 100 parameter updates
  103.       net.setListeners(new ScoreIterationListener(1));
  104.  
  105.       // C&P from GravesLSTMCharModellingExample
  106.       // Print the number of parameters in the network (and for each layer)
  107.       Layer[] layers = net.getLayers();
  108.       int totalNumParams = 0;
  109.       for (int i = 0; i < layers.length; i++)
  110.       {
  111.          int nParams = layers[i].numParams();
  112.          System.out.println("Number of parameters in layer " + i + ": " + nParams);
  113.          totalNumParams += nParams;
  114.       }
  115.       System.out.println("Total number of network parameters: " + totalNumParams);
  116.  
  117.       // here the actual learning takes place
  118.  
  119.       // net.fit(ds);
  120.  
  121.       // create output for every training sample
  122.       INDArray output = net.output(ds.getFeatureMatrix());
  123.       System.out.println(output);
  124.       System.out.println(outputData);
  125.  
  126.       // let Evaluation prints stats how often the right output had the
  127.       // highest value
  128.       // Evaluation eval = new Evaluation(2);
  129.       // eval.eval(ds.getLabels(), output);
  130.       // System.out.println(eval.stats());
  131.  
  132.    }
  133. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement