Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.deeplearning4j.examples.convolution;
- import java.awt.Toolkit;
- import java.io.File;
- import java.util.Arrays;
- import java.util.Random;
- import org.datavec.api.io.filters.BalancedPathFilter;
- import org.datavec.api.io.labels.ParentPathLabelGenerator;
- import org.datavec.api.split.FileSplit;
- import org.datavec.api.split.InputSplit;
- import org.datavec.image.loader.BaseImageLoader;
- import org.datavec.image.recordreader.ImageRecordReader;
- import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
- import org.deeplearning4j.eval.Evaluation;
- import org.deeplearning4j.nn.api.OptimizationAlgorithm;
- import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
- import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
- import org.deeplearning4j.nn.conf.Updater;
- import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
- import org.deeplearning4j.nn.conf.layers.DenseLayer;
- import org.deeplearning4j.nn.conf.layers.OutputLayer;
- import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
- import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
- import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
- import org.deeplearning4j.nn.weights.WeightInit;
- import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
- import org.deeplearning4j.ui.weights.ConvolutionalIterationListener;
- import org.deeplearning4j.ui.weights.HistogramIterationListener;
- import org.nd4j.linalg.api.buffer.DataBuffer;
- import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
- import org.nd4j.linalg.api.ndarray.INDArray;
- import org.nd4j.linalg.dataset.api.DataSet;
- import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
- import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
- import org.nd4j.linalg.lossfunctions.LossFunctions;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- public class Driver2 {
- protected static final Logger log = LoggerFactory.getLogger(Driver2.class);
- protected static final String[] allowedFormats = BaseImageLoader.ALLOWED_FORMATS;
- static final long seed = 12345;
- static final Random rand = new Random(seed);
- static int height = 128, width = 128, channels = 1, outputNum = 5, epochs = 10;
- public static void main(String[] args) throws Exception {
- File dir = new File("lbl_resizedImg");
- //Small training set
- // File dir = new File("small");
- //Specify how many different labels there is
- outputNum = dir.list().length;
- //Specify to use float instead of double
- DataTypeUtil.setDTypeForContext(DataBuffer.Type.FLOAT);
- FileSplit filesInDir = new FileSplit(dir, allowedFormats, rand);
- ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
- BalancedPathFilter pathFilter = new BalancedPathFilter(rand, allowedFormats, labelMaker);
- // RandomPathFilter pathFilter = new RandomPathFilter(rand, allowedFormats);
- /*PathFilter pathFilter = new PathFilter() {
- @Override
- public URI[] filter(URI[] paths) {
- ArrayList<URI> newpaths = new ArrayList<URI>();
- for(int i = 0; i < 10; i++){
- newpaths.add(paths[i]);
- }
- return newpaths.toArray(new URI[newpaths.size()]);
- }
- };*/
- InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter,50,50);
- InputSplit trainData = filesInDirSplit[0];
- InputSplit testData = filesInDirSplit[1];
- ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
- recordReader.initialize(trainData);
- DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, 5, 1, outputNum);
- //Start Normalization of the data
- NormalizerStandardize normalizer = new NormalizerStandardize();
- normalizer.fit(iter);
- iter.reset();
- while(iter.hasNext()){
- DataSet ds = iter.next();
- normalizer.transform(ds);
- }
- iter.reset();
- //End of normalization
- MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
- .seed(seed)
- .iterations(1)
- .regularization(true).l2(0.0005)
- .dropOut(0.5)
- .learningRate(0.0001)//.biasLearningRate(0.02)
- //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
- .weightInit(WeightInit.XAVIER)
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- .updater(Updater.NESTEROVS).momentum(0.9)
- .gradientNormalizationThreshold(0.5)
- .list()
- .layer(0, new ConvolutionLayer.Builder(5, 5)
- //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
- .nIn(channels)
- .stride(1, 1)
- .nOut(8)
- .activation("relu")
- .build())
- .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
- .kernelSize(3,3)
- .stride(2,2)
- .build())
- .layer(2, new ConvolutionLayer.Builder(3, 3)
- .stride(1, 1)
- .nOut(50)
- .activation("relu")
- .build())
- .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
- .kernelSize(2,2)
- .stride(2,2)
- .build())
- .layer(4, new DenseLayer.Builder().activation("relu")
- .nOut(500).build())
- .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
- .nOut(outputNum)
- .activation("softmax")
- .build())
- .backprop(true).pretrain(false);
- // The builder needs the dimensions of the image along with the number of channels. these are 28x28 images in one channel
- new ConvolutionLayerSetup(builder,128,128,1);
- MultiLayerConfiguration conf = builder.build();
- MultiLayerNetwork model = new MultiLayerNetwork(conf);
- model.init();
- log.info("Train model....");
- // model.setListeners(new ScoreIterationListener(10));
- // model.setListeners(new HistogramIterationListener(5));
- model.setListeners(Arrays.asList(new ScoreIterationListener(1), new ConvolutionalIterationListener(1)));
- for( int i=0; i<epochs; i++ ) {
- model.fit(iter);
- System.out.println("Epoc: "+i);
- iter.reset();
- log.info("Evaluate model....");
- Evaluation eval = new Evaluation(outputNum);
- while(iter.hasNext()){
- DataSet ds = iter.next();
- INDArray output = model.output(ds.getFeatureMatrix(), false);
- eval.eval(ds.getLabels(), output);
- }
- log.info(eval.stats());
- iter.reset();
- }
- Toolkit.getDefaultToolkit().beep();
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement