Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.deeplearning4j.examples.convolution;
- import org.canova.api.records.reader.RecordReader;
- import org.canova.api.records.reader.impl.CSVRecordReader;
- import org.canova.api.split.FileSplit;
- import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
- import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
- import org.deeplearning4j.eval.Evaluation;
- import org.deeplearning4j.nn.api.OptimizationAlgorithm;
- import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
- import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
- import org.deeplearning4j.nn.conf.Updater;
- import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
- import org.deeplearning4j.nn.conf.layers.DenseLayer;
- import org.deeplearning4j.nn.conf.layers.OutputLayer;
- import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
- import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
- import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
- import org.deeplearning4j.nn.weights.WeightInit;
- import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
- import org.nd4j.linalg.api.ndarray.INDArray;
- import org.nd4j.linalg.dataset.DataSet;
- import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
- import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
- import org.nd4j.linalg.lossfunctions.LossFunctions;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- //import org.deeplearning4j.nn.conf.LearningRatePolicy;
- import java.io.File;
- /**
- * Created by agibsonccc on 9/16/15.
- * modified by fornarat on 06/30/16
- */
- public class LenetMnistExampleCustom {
- private static final Logger log = LoggerFactory.getLogger(LenetMnistExampleCustom.class);
- public static void main(String[] args) throws Exception {
- /*
- int nChannels = 1;
- int outputNum = 10;
- int batchSize = 64;
- int nEpochs = 10;
- int iterations = 1;
- int seed = 123;
- */
- int iterations = 1;
- int nChannels = 1;
- int seed = 123;
- double learningRate = 0.01;
- int batchSize = 3500;
- int nEpochs = 30;
- // int numInputs = 2;
- int outputNum = 2;
- // int numHiddenNodes = 20;
- log.info("Load data....");
- RecordReader rrTrain = new CSVRecordReader();
- rrTrain.initialize(new FileSplit(new File("src/main/resources/classification/train.csv")));
- org.deeplearning4j.datasets.iterator.DataSetIterator dataSetIteratorTrain = new RecordReaderDataSetIterator(rrTrain,batchSize,0,2);
- NormalizerMinMaxScaler preProcessor = new NormalizerMinMaxScaler();
- log.info("During 'fit' the preprocessor calculates the metrics (std dev and mean for the standardizer, min and max for minmaxscaler) from the data given");
- log.info("Fit can take a dataset or a dataset iterator\n");
- //Fitting a preprocessor with a dataset
- log.info("Fitting with a dataset...............");
- // DataSet trainDataset = dataSetIteratorTrain.next();
- preProcessor.fit(dataSetIteratorTrain);
- preProcessor.transform(dataSetIteratorTrain);
- //Load the test/evaluation data:
- RecordReader rrTest = new CSVRecordReader();
- rrTest.initialize(new FileSplit(new File("src/main/resources/classification/test.csv")));
- org.deeplearning4j.datasets.iterator.DataSetIterator dataSetIteratorTest = new RecordReaderDataSetIterator(rrTest,batchSize,0,2);
- preProcessor.fit(dataSetIteratorTest);
- preProcessor.transform(dataSetIteratorTest);
- log.info("Build model....");
- MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
- .seed(seed)
- .iterations(iterations)
- .regularization(true).l2(0.0005)
- .learningRate(learningRate)//.biasLearningRate(0.02)
- //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
- .weightInit(WeightInit.XAVIER)
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- .updater(Updater.NESTEROVS).momentum(0.9)
- .list()
- .layer(0, new ConvolutionLayer.Builder(5, 5)
- .nIn(nChannels)
- .stride(1, 1)
- .nOut(20)
- // .nOut(outputNum)
- .activation("identity")
- .build())
- .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
- .kernelSize(2,2)
- .stride(2,2)
- .build())
- .layer(2, new ConvolutionLayer.Builder(5, 5)
- .nIn(nChannels)
- .stride(1, 1)
- .nOut(50)
- .activation("identity")
- .build())
- .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
- .kernelSize(2,2)
- .stride(2,2)
- .build())
- .layer(4, new DenseLayer.Builder().activation("relu")
- .nOut(500).build())
- .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
- .nOut(outputNum)
- .activation("softmax")
- .build())
- .backprop(true).pretrain(false);
- new ConvolutionLayerSetup(builder,100, 90,1);
- MultiLayerConfiguration conf = builder.build();
- MultiLayerNetwork model = new MultiLayerNetwork(conf);
- model.init();
- log.info("Train model....");
- model.setListeners(new ScoreIterationListener(1));
- for( int i=0; i<nEpochs; i++ ) {
- model.fit(dataSetIteratorTrain);
- log.info("*** Completed epoch {} ***", i);
- log.info("Evaluate model....");
- Evaluation eval = new Evaluation(outputNum);
- while(dataSetIteratorTest.hasNext()){
- DataSet ds = dataSetIteratorTest.next();
- INDArray output = model.output(ds.getFeatureMatrix(), false);
- eval.eval(ds.getLabels(), output);
- }
- log.info(eval.stats());
- dataSetIteratorTest.reset();
- }
- log.info("****************Example finished********************");
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement