// NOTE: Pastebin page chrome ("Not a member of Pastebin yet? Sign Up, it unlocks
// many cool features!") was captured with this paste; kept here as a comment so
// the file remains valid Java.
package org.deeplearning4j.examples.quickstart.modeling.feedforward.classification;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncShieldDataSetIterator;
import org.deeplearning4j.examples.utils.DownloaderUtility;
import org.deeplearning4j.examples.utils.PlotUtil;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

import java.io.File;
import java.util.concurrent.TimeUnit;
- /**
- * "Linear" Data Classification Example
- * <p>
- * Based on the data from Jason Baldridge:
- * https://github.com/jasonbaldridge/try-tf/tree/master/simdata
- *
- * @author Josh Patterson
- * @author Alex Black (added plots)
- */
- @SuppressWarnings("DuplicatedCode")
- public class scs {
- public static boolean visualize = true;
- public static String dataLocalPath;
- public static void main(String[] args) throws Exception {
- int seed = 123;
- double learningRate = 0.01;
- int batchSize = 20;
- int nEpochs = 7;
- int numInputs = 7;
- int numOutputs = 5;
- int numHiddenNodes = 20;
- //dataLocalPath = ''
- //Load the training data:
- RecordReader rr = new CSVRecordReader();
- rr.initialize(new FileSplit(new File( "scs_TRAIN.csv")));
- DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 7,5);
- DataSetIterator trainI = new AsyncShieldDataSetIterator(trainIter);
- //Load the test/evaluation data:
- RecordReader rrTest = new CSVRecordReader();
- rrTest.initialize(new FileSplit(new File("scs_TEST.csv")));
- DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 7,5);
- DataSetIterator testI = new AsyncShieldDataSetIterator(testIter);
- MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
- .seed(seed)
- .weightInit(WeightInit.XAVIER)
- .updater(new Nesterovs(learningRate, 0.9))
- .list()
- .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
- .activation(Activation.RELU)
- .dropOut(0.75)
- .build())
- //.layer(new DropoutLayer(0.5))
- .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
- .activation(Activation.SOFTMAX)
- .nIn(numHiddenNodes).nOut(numOutputs).build())
- .build();
- MultiLayerNetwork model = new MultiLayerNetwork(conf);
- model.init();
- model.setListeners(new ScoreIterationListener(100)); //Print score every 100 parameter updates
- model.fit(trainI, nEpochs);
- System.out.println("Evaluate model....");
- Evaluation eval = new Evaluation(numOutputs);
- while (testI.hasNext()) {
- DataSet t = testI.next();
- INDArray features = t.getFeatures();
- INDArray labels = t.getLabels();
- INDArray predicted = model.output(features, false);
- eval.eval(labels, predicted);
- }
- //An alternate way to do the above loop
- //Evaluation evalResults = model.evaluate(testIter);
- //Print the evaluation statistics
- System.out.println(eval.stats());
- System.out.println("\n****************Example finished********************");
- //Training is complete. Code that follows is for plotting the data & predictions only
- }
- }
// NOTE: Pastebin page footer ("Add Comment / Please, Sign In to add comment")
// captured with the paste; kept as a comment so the file remains valid Java.