Advertisement
Guest User

Untitled

a guest
Aug 24th, 2016
60
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.96 KB | None | 0 0
  1. package org.deeplearning4j.examples.convolution;
  2.  
  3. import java.awt.Toolkit;
  4. import java.io.File;
  5. import java.util.Arrays;
  6. import java.util.Random;
  7.  
  8. import org.datavec.api.io.filters.BalancedPathFilter;
  9. import org.datavec.api.io.labels.ParentPathLabelGenerator;
  10. import org.datavec.api.split.FileSplit;
  11. import org.datavec.api.split.InputSplit;
  12. import org.datavec.image.loader.BaseImageLoader;
  13. import org.datavec.image.recordreader.ImageRecordReader;
  14. import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
  15. import org.deeplearning4j.eval.Evaluation;
  16. import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  17. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  18. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  19. import org.deeplearning4j.nn.conf.Updater;
  20. import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
  21. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  22. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  23. import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
  24. import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
  25. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  26. import org.deeplearning4j.nn.weights.WeightInit;
  27. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  28. import org.deeplearning4j.ui.weights.ConvolutionalIterationListener;
  29. import org.deeplearning4j.ui.weights.HistogramIterationListener;
  30. import org.nd4j.linalg.api.buffer.DataBuffer;
  31. import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
  32. import org.nd4j.linalg.api.ndarray.INDArray;
  33. import org.nd4j.linalg.dataset.api.DataSet;
  34. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  35. import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
  36. import org.nd4j.linalg.lossfunctions.LossFunctions;
  37. import org.slf4j.Logger;
  38. import org.slf4j.LoggerFactory;
  39.  
  40. public class Driver2 {
  41.  
  42. protected static final Logger log = LoggerFactory.getLogger(Driver2.class);
  43.  
  44. protected static final String[] allowedFormats = BaseImageLoader.ALLOWED_FORMATS;
  45. static final long seed = 12345;
  46. static final Random rand = new Random(seed);
  47.  
  48.  
  49. static int height = 128, width = 128, channels = 1, outputNum = 5, epochs = 10;
  50.  
  51. public static void main(String[] args) throws Exception {
  52. File dir = new File("lbl_resizedImg");
  53. //Small training set
  54. // File dir = new File("small");
  55. //Specify how many different labels there is
  56. outputNum = dir.list().length;
  57.  
  58.  
  59. //Specify to use float instead of double
  60. DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
  61.  
  62. FileSplit filesInDir = new FileSplit(dir, allowedFormats, rand);
  63.  
  64. ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
  65. BalancedPathFilter pathFilter = new BalancedPathFilter(rand, allowedFormats, labelMaker);
  66. // RandomPathFilter pathFilter = new RandomPathFilter(rand, allowedFormats);
  67. /*PathFilter pathFilter = new PathFilter() {
  68. @Override
  69. public URI[] filter(URI[] paths) {
  70. ArrayList<URI> newpaths = new ArrayList<URI>();
  71.  
  72. for(int i = 0; i < 10; i++){
  73. newpaths.add(paths[i]);
  74. }
  75.  
  76. return newpaths.toArray(new URI[newpaths.size()]);
  77. }
  78. };*/
  79.  
  80.  
  81.  
  82. InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter,50,50);
  83. InputSplit trainData = filesInDirSplit[0];
  84. InputSplit testData = filesInDirSplit[1];
  85.  
  86. ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
  87.  
  88. recordReader.initialize(trainData);
  89.  
  90. DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, 5, 1, outputNum);
  91.  
  92. //Start Normalization of the data
  93. NormalizerStandardize normalizer = new NormalizerStandardize();
  94. normalizer.fit(iter);
  95. iter.reset();
  96.  
  97. while(iter.hasNext()){
  98. DataSet ds = iter.next();
  99. normalizer.transform(ds);
  100. }
  101. iter.reset();
  102. //End of normalization
  103.  
  104. MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
  105. .seed(seed)
  106. .iterations(1)
  107. .regularization(true).l2(0.0005)
  108. .dropOut(0.5)
  109. .learningRate(0.0001)//.biasLearningRate(0.02)
  110. //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
  111. .weightInit(WeightInit.XAVIER)
  112. .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
  113. .updater(Updater.NESTEROVS).momentum(0.9)
  114. .gradientNormalizationThreshold(0.5)
  115. .list()
  116. .layer(0, new ConvolutionLayer.Builder(5, 5)
  117. //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
  118. .nIn(channels)
  119. .stride(1, 1)
  120. .nOut(8)
  121. .activation("relu")
  122. .build())
  123. .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
  124. .kernelSize(3,3)
  125. .stride(2,2)
  126. .build())
  127. .layer(2, new ConvolutionLayer.Builder(3, 3)
  128. .stride(1, 1)
  129. .nOut(50)
  130. .activation("relu")
  131. .build())
  132. .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
  133. .kernelSize(2,2)
  134. .stride(2,2)
  135. .build())
  136. .layer(4, new DenseLayer.Builder().activation("relu")
  137. .nOut(500).build())
  138. .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
  139. .nOut(outputNum)
  140. .activation("softmax")
  141. .build())
  142. .backprop(true).pretrain(false);
  143.  
  144. // The builder needs the dimensions of the image along with the number of channels. these are 28x28 images in one channel
  145. new ConvolutionLayerSetup(builder,128,128,1);
  146.  
  147. MultiLayerConfiguration conf = builder.build();
  148. MultiLayerNetwork model = new MultiLayerNetwork(conf);
  149. model.init();
  150.  
  151.  
  152. log.info("Train model....");
  153. // model.setListeners(new ScoreIterationListener(10));
  154. // model.setListeners(new HistogramIterationListener(5));
  155. model.setListeners(Arrays.asList(new ScoreIterationListener(1), new ConvolutionalIterationListener(1)));
  156. for( int i=0; i<epochs; i++ ) {
  157. model.fit(iter);
  158.  
  159. System.out.println("Epoc: "+i);
  160.  
  161. iter.reset();
  162. log.info("Evaluate model....");
  163. Evaluation eval = new Evaluation(outputNum);
  164. while(iter.hasNext()){
  165. DataSet ds = iter.next();
  166. INDArray output = model.output(ds.getFeatureMatrix(), false);
  167. eval.eval(ds.getLabels(), output);
  168. }
  169.  
  170. log.info(eval.stats());
  171. iter.reset();
  172. }
  173.  
  174. Toolkit.getDefaultToolkit().beep();
  175.  
  176. }
  177.  
  178. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement