daily pastebin goal
41%
SHARE
TWEET

Untitled

a guest Dec 10th, 2018 62 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. package com.ty.test;
  2.  
  3. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  4. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  5. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  6. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  7. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  8. import org.deeplearning4j.nn.weights.WeightInit;
  9. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  10. import org.nd4j.evaluation.classification.Evaluation;
  11. import org.nd4j.linalg.activations.Activation;
  12. import org.nd4j.linalg.api.ndarray.INDArray;
  13. import org.nd4j.linalg.dataset.DataSet;
  14. import org.nd4j.linalg.factory.Nd4j;
  15. import org.nd4j.linalg.learning.config.Nesterovs;
  16. import org.nd4j.linalg.lossfunctions.LossFunctions;
  17.  
  18. import java.io.File;
  19. import java.io.FileNotFoundException;
  20. import java.util.Scanner;
  21.  
  22. public class Test {
  23.     public static void main(String[] args) throws FileNotFoundException {
  24.  
  25.         final int numRows = 38;
  26.         final int numColumns = 50;
  27.         int outputNum = 2; // number of output classes
  28.         int rngSeed = 123; // random number seed for reproducibility
  29.         int numEpochs = 100; // number of epochs to perform
  30.  
  31.         Scanner scanner = new Scanner(new File("out.txt"));
  32.         double[][] vectors = new double[numRows][numColumns];
  33.         double[][] outputs = new double[numRows][2];
  34.         int index = 0;
  35.         while (scanner.hasNextLine()) {
  36.             scanner.nextLine();
  37.             vectors[index][index] = 1;
  38.             index++;
  39.         }
  40.  
  41.         for(int i=0;i<38;i++){
  42.             if(i==7||i==14){
  43.                 outputs[i][1] = 1;
  44.             }else{
  45.                 outputs[i][0] = 1;
  46.             }
  47.         }
  48.  
  49.         INDArray inputs = Nd4j.create(vectors);
  50.         INDArray desiredOutputs = Nd4j.create(outputs);
  51.         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
  52.                 .seed(rngSeed) //include a random seed for reproducibility
  53.                 // use stochastic gradient descent as an optimization algorithm
  54.                 .updater(new Nesterovs(0.9, 0.9))
  55.                 .l2(1e-4)
  56.                 .list()
  57.                 .layer(0, new DenseLayer.Builder() //create the first, input layer with xavier initialization
  58.                         .nIn(numColumns)
  59.                         .nOut(2)
  60.                         .activation(Activation.RELU)
  61.                         .weightInit(WeightInit.XAVIER)
  62.                         .build())
  63.                 .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer
  64.                         .nIn(2)
  65.                         .nOut(outputNum)
  66.                         .activation(Activation.SOFTMAX)
  67.                         .weightInit(WeightInit.XAVIER)
  68.                         .build())
  69.                 .build();
  70.  
  71.         MultiLayerNetwork model = new MultiLayerNetwork(conf);
  72.         model.init();
  73.         //print the score with every 1 iteration
  74.         model.setListeners(new ScoreIterationListener(1));
  75.  
  76.         System.out.println("train model");
  77.         for (int j = 0; j < numEpochs; j++) {
  78.             model.fit(inputs, desiredOutputs);
  79.         }
  80.  
  81.         Evaluation eval = new Evaluation(outputNum); //create an evaluation object with 10 possible classes
  82.  
  83.         for (int i = 0; i < numRows; i++) {
  84.             INDArray inputs2 = Nd4j.create(vectors[i]);
  85.             INDArray outputs2 = Nd4j.create(new double[]{outputs[i][0], outputs[i][1]});
  86.             DataSet ds = new DataSet(inputs2, outputs2);
  87.             INDArray output = model.output(ds.getFeatures());
  88.             System.out.println(output);
  89.             eval.eval(ds.getLabels(), output);
  90.         }
  91.  
  92.  
  93.         System.out.println(eval.stats());
  94.  
  95.     }
  96. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top