Advertisement
Guest User

Untitled

a guest
Jun 30th, 2016
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.34 KB | None | 0 0
  1. package org.deeplearning4j.examples.convolution;
  2.  
  3. import org.canova.api.records.reader.RecordReader;
  4. import org.canova.api.records.reader.impl.CSVRecordReader;
  5. import org.canova.api.split.FileSplit;
  6. import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
  7. import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
  8. import org.deeplearning4j.eval.Evaluation;
  9. import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  10. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  11. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  12. import org.deeplearning4j.nn.conf.Updater;
  13. import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
  14. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  15. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  16. import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
  17. import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
  18. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  19. import org.deeplearning4j.nn.weights.WeightInit;
  20. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  21. import org.nd4j.linalg.api.ndarray.INDArray;
  22. import org.nd4j.linalg.dataset.DataSet;
  23. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  24. import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
  25. import org.nd4j.linalg.lossfunctions.LossFunctions;
  26. import org.slf4j.Logger;
  27. import org.slf4j.LoggerFactory;
  28.  
  29. //import org.deeplearning4j.nn.conf.LearningRatePolicy;
  30. import java.io.File;
  31.  
  32. /**
  33. * Created by agibsonccc on 9/16/15.
  34. * modified by fornarat on 06/30/16
  35. */
  36. public class LenetMnistExampleCustom {
  37. private static final Logger log = LoggerFactory.getLogger(LenetMnistExampleCustom.class);
  38.  
  39. public static void main(String[] args) throws Exception {
  40.  
  41. /*
  42. int nChannels = 1;
  43. int outputNum = 10;
  44. int batchSize = 64;
  45. int nEpochs = 10;
  46. int iterations = 1;
  47. int seed = 123;
  48. */
  49.  
  50. int iterations = 1;
  51. int nChannels = 1;
  52.  
  53. int seed = 123;
  54. double learningRate = 0.01;
  55. int batchSize = 3500;
  56. int nEpochs = 30;
  57.  
  58. // int numInputs = 2;
  59. int outputNum = 2;
  60. // int numHiddenNodes = 20;
  61.  
  62.  
  63. log.info("Load data....");
  64.  
  65. RecordReader rrTrain = new CSVRecordReader();
  66. rrTrain.initialize(new FileSplit(new File("src/main/resources/classification/train.csv")));
  67. org.deeplearning4j.datasets.iterator.DataSetIterator dataSetIteratorTrain = new RecordReaderDataSetIterator(rrTrain,batchSize,0,2);
  68.  
  69.  
  70. NormalizerMinMaxScaler preProcessor = new NormalizerMinMaxScaler();
  71. log.info("During 'fit' the preprocessor calculates the metrics (std dev and mean for the standardizer, min and max for minmaxscaler) from the data given");
  72. log.info("Fit can take a dataset or a dataset iterator\n");
  73.  
  74. //Fitting a preprocessor with a dataset
  75. log.info("Fitting with a dataset...............");
  76. // DataSet trainDataset = dataSetIteratorTrain.next();
  77. preProcessor.fit(dataSetIteratorTrain);
  78. preProcessor.transform(dataSetIteratorTrain);
  79.  
  80. //Load the test/evaluation data:
  81. RecordReader rrTest = new CSVRecordReader();
  82. rrTest.initialize(new FileSplit(new File("src/main/resources/classification/test.csv")));
  83. org.deeplearning4j.datasets.iterator.DataSetIterator dataSetIteratorTest = new RecordReaderDataSetIterator(rrTest,batchSize,0,2);
  84.  
  85. preProcessor.fit(dataSetIteratorTest);
  86. preProcessor.transform(dataSetIteratorTest);
  87.  
  88.  
  89. log.info("Build model....");
  90. MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
  91. .seed(seed)
  92. .iterations(iterations)
  93. .regularization(true).l2(0.0005)
  94. .learningRate(learningRate)//.biasLearningRate(0.02)
  95. //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
  96. .weightInit(WeightInit.XAVIER)
  97. .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
  98. .updater(Updater.NESTEROVS).momentum(0.9)
  99. .list()
  100. .layer(0, new ConvolutionLayer.Builder(5, 5)
  101. .nIn(nChannels)
  102. .stride(1, 1)
  103. .nOut(20)
  104. // .nOut(outputNum)
  105. .activation("identity")
  106. .build())
  107. .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
  108. .kernelSize(2,2)
  109. .stride(2,2)
  110. .build())
  111. .layer(2, new ConvolutionLayer.Builder(5, 5)
  112. .nIn(nChannels)
  113. .stride(1, 1)
  114. .nOut(50)
  115. .activation("identity")
  116. .build())
  117. .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
  118. .kernelSize(2,2)
  119. .stride(2,2)
  120. .build())
  121. .layer(4, new DenseLayer.Builder().activation("relu")
  122. .nOut(500).build())
  123. .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
  124. .nOut(outputNum)
  125. .activation("softmax")
  126. .build())
  127. .backprop(true).pretrain(false);
  128. new ConvolutionLayerSetup(builder,100, 90,1);
  129.  
  130. MultiLayerConfiguration conf = builder.build();
  131. MultiLayerNetwork model = new MultiLayerNetwork(conf);
  132. model.init();
  133.  
  134.  
  135. log.info("Train model....");
  136. model.setListeners(new ScoreIterationListener(1));
  137. for( int i=0; i<nEpochs; i++ ) {
  138. model.fit(dataSetIteratorTrain);
  139.  
  140. log.info("*** Completed epoch {} ***", i);
  141.  
  142. log.info("Evaluate model....");
  143. Evaluation eval = new Evaluation(outputNum);
  144. while(dataSetIteratorTest.hasNext()){
  145. DataSet ds = dataSetIteratorTest.next();
  146. INDArray output = model.output(ds.getFeatureMatrix(), false);
  147. eval.eval(ds.getLabels(), output);
  148. }
  149. log.info(eval.stats());
  150. dataSetIteratorTest.reset();
  151. }
  152. log.info("****************Example finished********************");
  153. }
  154. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement