Advertisement
Guest User

Untitled

a guest
Aug 22nd, 2017
71
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.42 KB | None | 0 0
  1. package com.ukc.deeplearning;
  2.  
  3. /**
  4. * Created by Jon Baker on 20/08/2017. <Part of Socialsense> Copyright University of Kent
  5. */
  6.  
  7. import org.datavec.api.records.reader.RecordReader;
  8. import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
  9. import org.datavec.api.split.FileSplit;
  10. import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
  11. import org.deeplearning4j.eval.Evaluation;
  12. import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  13. import org.deeplearning4j.nn.conf.GradientNormalization;
  14. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  15. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  16. import org.deeplearning4j.nn.conf.Updater;
  17. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  18. import org.deeplearning4j.nn.conf.layers.RBM;
  19. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  20. import org.deeplearning4j.nn.weights.WeightInit;
  21. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  22. import org.nd4j.linalg.activations.Activation;
  23. import org.nd4j.linalg.api.ndarray.INDArray;
  24. import org.nd4j.linalg.dataset.DataSet;
  25. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  26. import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
  27. import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
  28. import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
  29.  
  30. import java.io.File;
  31. import java.io.IOException;
  32.  
  33. public class DeepLearning {
  34.  
  35. public static void main(String[] args) throws Exception {
  36. int labelIndex = 0;
  37. int numClasses = 28;
  38.  
  39. int batchSizeTraining = 2828; //entire training size
  40. DataSet trainingData = readCSVDataset(
  41. "csv/train.csv",
  42. batchSizeTraining, labelIndex, numClasses);
  43.  
  44. // this is the data we want to classify
  45. int batchSizeTest = 11605;
  46. DataSet testData = readCSVDataset("csv/eval.csv",
  47. batchSizeTest, labelIndex, numClasses);
  48.  
  49.  
  50. //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
  51. DataNormalization normalizer = new NormalizerStandardize();
  52. normalizer.fit(trainingData); //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
  53. normalizer.transform(trainingData); //Apply normalization to the training data
  54. normalizer.transform(testData); //Apply normalization to the test data. This is using statistics calculated from the *training* set
  55.  
  56. trainingData.scale();
  57. testData.scale();
  58.  
  59. //run the model
  60. MultiLayerNetwork model = buildModel();
  61.  
  62. model.fit(trainingData);
  63. //evaluate the model on the test set
  64. Evaluation eval = new Evaluation(numClasses);
  65. INDArray output = model.output(testData.getFeatureMatrix());
  66.  
  67. eval.eval(testData.getLabels(), output);
  68. System.out.println(eval.stats());
  69. }
  70.  
  71. public static MultiLayerNetwork buildModel() {
  72. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
  73. .seed(123)
  74. .miniBatch(false)
  75. .weightInit(WeightInit.RELU)
  76. .iterations(10)
  77. .learningRate(0.2)
  78. .updater(Updater.ADAGRAD).gradientNormalization(GradientNormalization.ClipL2PerLayer)
  79. .regularization(true).l2(1e-1).l1(1e-3)
  80. .optimizationAlgo(OptimizationAlgorithm.LBFGS)
  81. .list()
  82. .layer(0, new RBM.Builder() //RBM is apparently 2-layer (1 visible 1 hidden)
  83. .nIn(19) // Input nodes
  84. .nOut(1024) // Output nodes
  85. //.activation(Activation.RELU) // Activation function type
  86. .weightInit(WeightInit.RELU) // Weight initialization
  87. .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
  88. .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
  89. .build())
  90. .layer(1, new RBM.Builder()
  91. .nIn(1024) // Input nodes
  92. .nOut(1024) // Output nodes
  93. .activation(Activation.RELU) // Activation function type
  94. .weightInit(WeightInit.RELU) // Weight initialization
  95. .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
  96. .build())
  97. .layer(2, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
  98. .weightInit(WeightInit.RELU)
  99. .activation(Activation.SOFTMAX).nIn(1024).nOut(28).build())
  100. .backprop(true).pretrain(false)
  101. .build();
  102. MultiLayerNetwork net = new MultiLayerNetwork(conf);
  103. net.init();
  104. net.setListeners(new ScoreIterationListener(10));
  105. return net;
  106. }
  107.  
  108. /**
  109. * used for testing and training
  110. *
  111. * @param csvFileClasspath
  112. * @param batchSize
  113. * @param labelIndex
  114. * @param numClasses
  115. * @return
  116. * @throws IOException
  117. * @throws InterruptedException
  118. */
  119. private static DataSet readCSVDataset(
  120. String csvFileClasspath, int batchSize, int labelIndex, int numClasses)
  121. throws IOException, InterruptedException {
  122.  
  123. RecordReader rr = new CSVRecordReader();
  124. rr.initialize(new FileSplit(new File(csvFileClasspath)));
  125. DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
  126. return iterator.next();
  127. }
  128. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement