Guest User

Untitled

a guest
Dec 10th, 2018
103
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 3.64 KB | None | 0 0
  1. package com.ty.test;
  2.  
  3. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  4. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  5. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  6. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  7. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  8. import org.deeplearning4j.nn.weights.WeightInit;
  9. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  10. import org.nd4j.evaluation.classification.Evaluation;
  11. import org.nd4j.linalg.activations.Activation;
  12. import org.nd4j.linalg.api.ndarray.INDArray;
  13. import org.nd4j.linalg.dataset.DataSet;
  14. import org.nd4j.linalg.factory.Nd4j;
  15. import org.nd4j.linalg.learning.config.Nesterovs;
  16. import org.nd4j.linalg.lossfunctions.LossFunctions;
  17.  
  18. import java.io.File;
  19. import java.io.FileNotFoundException;
  20. import java.util.Scanner;
  21.  
  22. public class Test {
  23. public static void main(String[] args) throws FileNotFoundException {
  24.  
  25. final int numRows = 38;
  26. final int numColumns = 50;
  27. int outputNum = 2; // number of output classes
  28. int rngSeed = 123; // random number seed for reproducibility
  29. int numEpochs = 100; // number of epochs to perform
  30.  
  31. Scanner scanner = new Scanner(new File("out.txt"));
  32. double[][] vectors = new double[numRows][numColumns];
  33. double[][] outputs = new double[numRows][2];
  34. int index = 0;
  35. while (scanner.hasNextLine()) {
  36. scanner.nextLine();
  37. vectors[index][index] = 1;
  38. index++;
  39. }
  40.  
  41. for(int i=0;i<38;i++){
  42. if(i==7||i==14){
  43. outputs[i][1] = 1;
  44. }else{
  45. outputs[i][0] = 1;
  46. }
  47. }
  48.  
  49. INDArray inputs = Nd4j.create(vectors);
  50. INDArray desiredOutputs = Nd4j.create(outputs);
  51. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
  52. .seed(rngSeed) //include a random seed for reproducibility
  53. // use stochastic gradient descent as an optimization algorithm
  54. .updater(new Nesterovs(0.9, 0.9))
  55. .l2(1e-4)
  56. .list()
  57. .layer(0, new DenseLayer.Builder() //create the first, input layer with xavier initialization
  58. .nIn(numColumns)
  59. .nOut(2)
  60. .activation(Activation.RELU)
  61. .weightInit(WeightInit.XAVIER)
  62. .build())
  63. .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer
  64. .nIn(2)
  65. .nOut(outputNum)
  66. .activation(Activation.SOFTMAX)
  67. .weightInit(WeightInit.XAVIER)
  68. .build())
  69. .build();
  70.  
  71. MultiLayerNetwork model = new MultiLayerNetwork(conf);
  72. model.init();
  73. //print the score with every 1 iteration
  74. model.setListeners(new ScoreIterationListener(1));
  75.  
  76. System.out.println("train model");
  77. for (int j = 0; j < numEpochs; j++) {
  78. model.fit(inputs, desiredOutputs);
  79. }
  80.  
  81. Evaluation eval = new Evaluation(outputNum); //create an evaluation object with 10 possible classes
  82.  
  83. for (int i = 0; i < numRows; i++) {
  84. INDArray inputs2 = Nd4j.create(vectors[i]);
  85. INDArray outputs2 = Nd4j.create(new double[]{outputs[i][0], outputs[i][1]});
  86. DataSet ds = new DataSet(inputs2, outputs2);
  87. INDArray output = model.output(ds.getFeatures());
  88. System.out.println(output);
  89. eval.eval(ds.getLabels(), output);
  90. }
  91.  
  92.  
  93. System.out.println(eval.stats());
  94.  
  95. }
  96. }
Add Comment
Please sign in to add a comment.