Guest User

Untitled

a guest
Sep 22nd, 2020
78
0
Never
Not a member of Pastebin yet? Sign Up — it unlocks many cool features!
text 4.35 KB | None | 0 0
  1. package org.deeplearning4j.examples.quickstart.modeling.feedforward.classification;
  2.  
  3. import org.datavec.api.records.reader.RecordReader;
  4. import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
  5. import org.datavec.api.split.FileSplit;
  6. import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
  7. import org.deeplearning4j.datasets.iterator.AsyncShieldDataSetIterator;
  8. import org.deeplearning4j.examples.utils.DownloaderUtility;
  9. import org.deeplearning4j.examples.utils.PlotUtil;
  10. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  11. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  12. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  13. import org.deeplearning4j.nn.conf.layers.DropoutLayer;
  14. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  15. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  16. import org.deeplearning4j.nn.weights.WeightInit;
  17. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  18. import org.nd4j.evaluation.classification.Evaluation;
  19. import org.nd4j.linalg.activations.Activation;
  20. import org.nd4j.linalg.api.ndarray.INDArray;
  21. import org.nd4j.linalg.dataset.DataSet;
  22. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  23. import org.nd4j.linalg.learning.config.Nesterovs;
  24. import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
  25.  
  26. import java.io.File;
  27. import java.util.concurrent.TimeUnit;
  28.  
  29. /**
  30. * "Linear" Data Classification Example
  31. * <p>
  32. * Based on the data from Jason Baldridge:
  33. * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
  34. *
  35. * @author Josh Patterson
  36. * @author Alex Black (added plots)
  37. */
  38. @SuppressWarnings("DuplicatedCode")
  39. public class scs {
  40.  
  41. public static boolean visualize = true;
  42. public static String dataLocalPath;
  43.  
  44. public static void main(String[] args) throws Exception {
  45. int seed = 123;
  46. double learningRate = 0.01;
  47. int batchSize = 20;
  48. int nEpochs = 7;
  49.  
  50. int numInputs = 7;
  51. int numOutputs = 5;
  52. int numHiddenNodes = 20;
  53.  
  54. //dataLocalPath = ''
  55. //Load the training data:
  56. RecordReader rr = new CSVRecordReader();
  57. rr.initialize(new FileSplit(new File( "scs_TRAIN.csv")));
  58. DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 7,5);
  59. DataSetIterator trainI = new AsyncShieldDataSetIterator(trainIter);
  60.  
  61. //Load the test/evaluation data:
  62. RecordReader rrTest = new CSVRecordReader();
  63. rrTest.initialize(new FileSplit(new File("scs_TEST.csv")));
  64. DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 7,5);
  65. DataSetIterator testI = new AsyncShieldDataSetIterator(testIter);
  66.  
  67. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
  68. .seed(seed)
  69. .weightInit(WeightInit.XAVIER)
  70. .updater(new Nesterovs(learningRate, 0.9))
  71. .list()
  72. .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
  73. .activation(Activation.RELU)
  74. .dropOut(0.75)
  75. .build())
  76. //.layer(new DropoutLayer(0.5))
  77. .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
  78. .activation(Activation.SOFTMAX)
  79. .nIn(numHiddenNodes).nOut(numOutputs).build())
  80. .build();
  81.  
  82.  
  83. MultiLayerNetwork model = new MultiLayerNetwork(conf);
  84. model.init();
  85. model.setListeners(new ScoreIterationListener(100)); //Print score every 100 parameter updates
  86.  
  87. model.fit(trainI, nEpochs);
  88.  
  89. System.out.println("Evaluate model....");
  90. Evaluation eval = new Evaluation(numOutputs);
  91. while (testI.hasNext()) {
  92. DataSet t = testI.next();
  93. INDArray features = t.getFeatures();
  94. INDArray labels = t.getLabels();
  95. INDArray predicted = model.output(features, false);
  96. eval.eval(labels, predicted);
  97. }
  98. //An alternate way to do the above loop
  99. //Evaluation evalResults = model.evaluate(testIter);
  100.  
  101. //Print the evaluation statistics
  102. System.out.println(eval.stats());
  103.  
  104. System.out.println("\n****************Example finished********************");
  105. //Training is complete. Code that follows is for plotting the data & predictions only
  106.  
  107. }
  108.  
  109. }
  110.  
Add Comment
Please, Sign In to add comment