Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import org.bytedeco.opencv.opencv_face.FacemarkAAM;
- import org.datavec.api.records.reader.SequenceRecordReader;
- import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
- import org.datavec.api.split.NumberedFileInputSplit;
- import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
- import org.deeplearning4j.eval.ROC;
- import org.deeplearning4j.nn.api.OptimizationAlgorithm;
- import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
- import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
- import org.deeplearning4j.nn.conf.Updater;
- import org.deeplearning4j.nn.conf.inputs.InputType;
- import org.deeplearning4j.nn.conf.layers.LSTM;
- import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
- import org.deeplearning4j.nn.graph.ComputationGraph;
- import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
- import org.nd4j.linalg.api.ndarray.INDArray;
- import org.deeplearning4j.nn.weights.WeightInit;
- import org.nd4j.linalg.activations.Activation;
- import org.nd4j.linalg.dataset.api.DataSet;
- import org.nd4j.linalg.lossfunctions.LossFunctions;
- import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
- import org.nd4j.linalg.learning.config.Adam;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- import java.io.File;
- import org.apache.commons.io.FileUtils;
- import org.apache.commons.io.FilenameUtils;
- import java.io.IOException;
- import java.util.HashMap;
- import java.util.Arrays;
- import java.net.URL;
- import java.io.BufferedInputStream;
- import java.io.FileInputStream;
- import java.io.FileNotFoundException;
- import java.io.BufferedOutputStream;
- import java.io.FileOutputStream;
- import java.lang.Byte;
- import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
- import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
- import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
- import lombok.val;
- //import lombok.var;
- public class lstmTest {
- void doLstm() throws FileNotFoundException, IOException, InterruptedException {
- val NB_TRAIN_EXAMPLES = 9500 ;// number of training examples
- val NB_TEST_EXAMPLES = 499 ;// number of testing examples
- var featureBaseDir = "featureDir";
- var targetBaseDir = "targetDir";
- //Generate Data
- var gen = new RandomGenerator();
- gen.createCSV(featureBaseDir,targetBaseDir,10000,40);
- // Load training data
- val trainFeatures = new CSVSequenceRecordReader(0, ",");
- trainFeatures.initialize( new NumberedFileInputSplit(featureBaseDir + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
- val trainLabels = new CSVSequenceRecordReader();
- trainLabels.initialize(new NumberedFileInputSplit(targetBaseDir + "/%d.csv", 0, NB_TRAIN_EXAMPLES - 1));
- val trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
- 32, 1, true, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
- // Load testing data
- val testFeatures = new CSVSequenceRecordReader(0, ",");
- testFeatures.initialize(new NumberedFileInputSplit(featureBaseDir + "/%d.csv", NB_TRAIN_EXAMPLES, NB_TRAIN_EXAMPLES + NB_TEST_EXAMPLES - 1));
- val testLabels = new CSVSequenceRecordReader();
- testLabels.initialize(new NumberedFileInputSplit(targetBaseDir + "/%d.csv", NB_TRAIN_EXAMPLES, NB_TRAIN_EXAMPLES + NB_TEST_EXAMPLES - 1));
- val testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels,
- 32, 1, true,SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
- // Set neural network parameters
- val NB_EPOCHS = 10;
- val RANDOM_SEED = 1234;
- val LEARNING_RATE = 0.005;
- val BATCH_SIZE = 32;
- val LSTM_LAYER_SIZE = 200;
- val NUM_LABEL_CLASSES = 1;
- val conf = new NeuralNetConfiguration.Builder()
- .seed(RANDOM_SEED)
- .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- .updater(new Adam(LEARNING_RATE))
- .weightInit(WeightInit.XAVIER)
- .dropOut(0.25)
- .graphBuilder()
- .addInputs("randomSequence")
- .setInputTypes(InputType.recurrent(38))
- .setOutputs("newRandom")
- .addLayer("L1", new LSTM.Builder()
- //.nIn(NB_INPUTS)
- .nOut(LSTM_LAYER_SIZE)
- .forgetGateBiasInit(1)
- .activation(Activation.TANH)
- .build(),
- "randomSequence")
- .addLayer("newRandom", new RnnOutputLayer.Builder(LossFunctions.LossFunction.XENT)
- .activation(Activation.SIGMOID)
- .nIn(LSTM_LAYER_SIZE).nOut(NUM_LABEL_CLASSES).build(),"L1")
- .build();
- val model = new ComputationGraph(conf) ;
- model.fit(trainData, 100);
- val evaluation = model.evaluateRegression(testData);
- System.out.println(evaluation.stats());
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement