Advertisement
Guest User

Untitled

a guest
Jan 3rd, 2021
271
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 5.15 KB | None | 0 0
  1.  
  2. import org.bytedeco.opencv.opencv_face.FacemarkAAM;
  3. import org.datavec.api.records.reader.SequenceRecordReader;
  4. import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
  5. import org.datavec.api.split.NumberedFileInputSplit;
  6. import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
  7. import org.deeplearning4j.eval.ROC;
  8. import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  9. import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
  10. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  11. import org.deeplearning4j.nn.conf.Updater;
  12. import org.deeplearning4j.nn.conf.inputs.InputType;
  13. import org.deeplearning4j.nn.conf.layers.LSTM;
  14. import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
  15. import org.deeplearning4j.nn.graph.ComputationGraph;
  16. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  17. import org.nd4j.linalg.api.ndarray.INDArray;
  18. import org.deeplearning4j.nn.weights.WeightInit;
  19. import org.nd4j.linalg.activations.Activation;
  20. import org.nd4j.linalg.dataset.api.DataSet;
  21. import org.nd4j.linalg.lossfunctions.LossFunctions;
  22. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  23. import org.nd4j.linalg.learning.config.Adam;
  24.  
  25. import org.slf4j.Logger;
  26. import org.slf4j.LoggerFactory;
  27.  
  28. import java.io.File;
  29. import org.apache.commons.io.FileUtils;
  30. import org.apache.commons.io.FilenameUtils;
  31. import java.io.IOException;
  32. import java.util.HashMap;
  33. import java.util.Arrays;
  34. import java.net.URL;
  35. import java.io.BufferedInputStream;
  36. import java.io.FileInputStream;
  37. import java.io.FileNotFoundException;
  38. import java.io.BufferedOutputStream;
  39. import java.io.FileOutputStream;
  40. import java.lang.Byte;
  41.  
  42.  
  43.  
  44.  
  45. import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
  46. import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
  47. import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
  48. import lombok.val;
  49. //import lombok.var;
  50.  
  51. public class lstmTest {
  52.  
  53.     void doLstm() throws FileNotFoundException, IOException, InterruptedException {
  54.  
  55.  
  56.  
  57.         val NB_TRAIN_EXAMPLES = 9500 ;// number of training examples
  58.         val NB_TEST_EXAMPLES = 499 ;// number of testing examples
  59.  
  60.         var featureBaseDir = "featureDir";
  61.         var targetBaseDir = "targetDir";
  62.         //Generate Data
  63.         var gen = new RandomGenerator();
  64.         gen.createCSV(featureBaseDir,targetBaseDir,10000,40);
  65.  
  66.         // Load training data
  67.  
  68.  
  69.         val trainFeatures = new CSVSequenceRecordReader(0, ",");
  70.         trainFeatures.initialize( new NumberedFileInputSplit(featureBaseDir + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
  71.  
  72.         val trainLabels = new CSVSequenceRecordReader();
  73.         trainLabels.initialize(new NumberedFileInputSplit(targetBaseDir + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
  74.  
  75.         val trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
  76.                 32, 1, true, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
  77.  
  78.  
  79.         // Load testing data
  80.         val testFeatures = new CSVSequenceRecordReader(0, ",");
  81.         testFeatures.initialize(new NumberedFileInputSplit(featureBaseDir + "/%d.csv", NB_TRAIN_EXAMPLES, NB_TRAIN_EXAMPLES + NB_TEST_EXAMPLES - 1));
  82.  
  83.         val testLabels = new CSVSequenceRecordReader();
  84.         testLabels.initialize(new NumberedFileInputSplit(targetBaseDir + "/%d.csv", NB_TRAIN_EXAMPLES, NB_TRAIN_EXAMPLES  + NB_TEST_EXAMPLES - 1));
  85.  
  86.         val testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels,
  87.                 32, 1, true,SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
  88.  
  89.  
  90.         // Set neural network parameters
  91.  
  92.         val NB_EPOCHS = 10;
  93.         val RANDOM_SEED = 1234;
  94.         val LEARNING_RATE = 0.005;
  95.         val BATCH_SIZE = 32;
  96.         val LSTM_LAYER_SIZE = 200;
  97.         val NUM_LABEL_CLASSES = 1;
  98.  
  99.         val conf = new NeuralNetConfiguration.Builder()
  100.                 .seed(RANDOM_SEED)
  101.                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
  102.                 .updater(new Adam(LEARNING_RATE))
  103.                 .weightInit(WeightInit.XAVIER)
  104.                 .dropOut(0.25)
  105.                 .graphBuilder()
  106.                 .addInputs("randomSequence")
  107.                 .setInputTypes(InputType.recurrent(38))
  108.                 .setOutputs("newRandom")
  109.                 .addLayer("L1", new LSTM.Builder()
  110.                                 //.nIn(NB_INPUTS)
  111.                                 .nOut(LSTM_LAYER_SIZE)
  112.                                 .forgetGateBiasInit(1)
  113.                                 .activation(Activation.TANH)
  114.                                 .build(),
  115.                         "randomSequence")
  116.  
  117.  
  118.                 .addLayer("newRandom", new RnnOutputLayer.Builder(LossFunctions.LossFunction.XENT)
  119.                         .activation(Activation.SIGMOID)
  120.                         .nIn(LSTM_LAYER_SIZE).nOut(NUM_LABEL_CLASSES).build(),"L1")
  121.                 .build();
  122.  
  123.  
  124.         val model = new ComputationGraph(conf)  ;
  125.  
  126.         model.fit(trainData, 100);
  127.  
  128.  
  129.  
  130.  
  131.         val evaluation = model.evaluateRegression(testData);
  132.  
  133.  
  134.  
  135.  
  136.  
  137.         System.out.println(evaluation.stats());
  138.     }
  139.  
  140. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement