Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package com.ukc.deeplearning;
- /**
- * Created by Jon Baker on 20/08/2017. <Part of Socialsense> Copyright University of Kent
- */
- import org.datavec.api.records.reader.RecordReader;
- import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
- import org.datavec.api.split.FileSplit;
- import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
- import org.deeplearning4j.eval.Evaluation;
- import org.deeplearning4j.nn.api.OptimizationAlgorithm;
- import org.deeplearning4j.nn.conf.GradientNormalization;
- import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
- import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
- import org.deeplearning4j.nn.conf.Updater;
- import org.deeplearning4j.nn.conf.layers.OutputLayer;
- import org.deeplearning4j.nn.conf.layers.RBM;
- import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
- import org.deeplearning4j.nn.weights.WeightInit;
- import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
- import org.nd4j.linalg.activations.Activation;
- import org.nd4j.linalg.api.ndarray.INDArray;
- import org.nd4j.linalg.dataset.DataSet;
- import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
- import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
- import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
- import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
- import java.io.File;
- import java.io.IOException;
- public class DeepLearning {
- public static void main(String[] args) throws Exception {
- int labelIndex = 0;
- int numClasses = 28;
- int batchSizeTraining = 2828; //entire training size
- DataSet trainingData = readCSVDataset(
- "csv/train.csv",
- batchSizeTraining, labelIndex, numClasses);
- // this is the data we want to classify
- int batchSizeTest = 11605;
- DataSet testData = readCSVDataset("csv/eval.csv",
- batchSizeTest, labelIndex, numClasses);
- //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
- DataNormalization normalizer = new NormalizerStandardize();
- normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
- normalizer.transform(trainingData); //Apply normalization to the training data
- normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set
- trainingData.scale();
- testData.scale();
- //run the model
- MultiLayerNetwork model = buildModel();
- model.fit(trainingData);
- //evaluate the model on the test set
- Evaluation eval = new Evaluation(numClasses);
- INDArray output = model.output(testData.getFeatureMatrix());
- eval.eval(testData.getLabels(), output);
- System.out.println(eval.stats());
- }
- public static MultiLayerNetwork buildModel() {
- MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
- .seed(123)
- .miniBatch(false)
- .weightInit(WeightInit.RELU)
- .iterations(10)
- .learningRate(0.2)
- .updater(Updater.ADAGRAD).gradientNormalization(GradientNormalization.ClipL2PerLayer)
- .regularization(true).l2(1e-1).l1(1e-3)
- .optimizationAlgo(OptimizationAlgorithm.LBFGS)
- .list()
- .layer(0, new RBM.Builder() //RBM is apparently 2-layer (1 visible 1 hidden)
- .nIn(19) // Input nodes
- .nOut(1024) // Output nodes
- //.activation(Activation.RELU) // Activation function type
- .weightInit(WeightInit.RELU) // Weight initialization
- .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
- .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
- .build())
- .layer(1, new RBM.Builder()
- .nIn(1024) // Input nodes
- .nOut(1024) // Output nodes
- .activation(Activation.RELU) // Activation function type
- .weightInit(WeightInit.RELU) // Weight initialization
- .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
- .build())
- .layer(2, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
- .weightInit(WeightInit.RELU)
- .activation(Activation.SOFTMAX).nIn(1024).nOut(28).build())
- .backprop(true).pretrain(false)
- .build();
- MultiLayerNetwork net = new MultiLayerNetwork(conf);
- net.init();
- net.setListeners(new ScoreIterationListener(10));
- return net;
- }
- /**
- * used for testing and training
- *
- * @param csvFileClasspath
- * @param batchSize
- * @param labelIndex
- * @param numClasses
- * @return
- * @throws IOException
- * @throws InterruptedException
- */
- private static DataSet readCSVDataset(
- String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
- throws IOException, InterruptedException {
- RecordReader rr = new CSVRecordReader();
- rr.initialize(new FileSplit(new File(csvFileClasspath)));
- DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
- return iterator.next();
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement