package foo;

import org.apache.commons.io.FileUtils;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.arbiter.DL4JConfiguration;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.data.DataSetIteratorProvider;
import org.deeplearning4j.arbiter.layers.GravesLSTMLayerSpace;
import org.deeplearning4j.arbiter.layers.RnnOutputLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.candidategenerator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.runner.LoggingOptimizationRunnerStatusListener;
import org.deeplearning4j.arbiter.saver.local.multilayer.LocalMultiLayerNetworkSaver;
import org.deeplearning4j.arbiter.scoring.multilayer.TestSetAccuracyScoreFunction;
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.*;
import java.nio.charset.Charset;
import java.util.List;
import java.util.concurrent.TimeUnit;
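// Note: these are the old 0.x-era DL4J/Arbiter APIs (generic
// CandidateGenerator<DL4JConfiguration>, MultiLayerSpace.Builder with
// learningRate()/iterations(), etc.); later releases reworked these builders,
// so the code as written is tied to that API generation.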
public class ClassificationArbiterCPU {
    private static final String BASE_DIR = ".. some folder";
    // data for each i.csv file is split 60/40 along the time axis into train and test sets
    private static final String DATA_DIR = ".. folder with train and test split";
    private static final String RESOURCE_DIR = "... somewhere to store folder";
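    // Expected on-disk layout (as used below): numbered CSV sequence files
    // <dir>/train/features/0.csv ... <maxFileId>.csv, with matching files under
    // train/labels/, validation/features/ and validation/labels/.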
    public static void main(String[] args) {
        String modelName = "category_D1";
        String modelPath = BASE_DIR + "/model/";
        String resultsPath = BASE_DIR + RESOURCE_DIR + "/results/";
        String trainFeaturesPath = BASE_DIR + DATA_DIR + "/train/features/";
        String trainLabelsPath = BASE_DIR + DATA_DIR + "/train/labels/";
        String validationFeaturesPath = BASE_DIR + DATA_DIR + "/validation/features/";
        String validationLabelsPath = BASE_DIR + DATA_DIR + "/validation/labels/";
        int maxFileId = 10;
        int numInput = 20;

        ClassificationArbiterCPU arbiter = new ClassificationArbiterCPU(modelPath, modelName);
        arbiter.trainMultiInputNetwork(
                numInput,
                trainFeaturesPath,
                trainLabelsPath,
                validationFeaturesPath,
                validationLabelsPath,
                0,
                maxFileId,
                resultsPath,
                20);
    }
    private String modelPath;
    private String modelName;
    private final int BATCH_SIZE = 100;

    public ClassificationArbiterCPU(String modelPath, String modelName) {
        this.modelPath = modelPath;
        this.modelName = modelName;
    }
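    /**
     * Runs a random-search hyperparameter optimization over the RNN space defined in
     * createArbiterRNN, training each candidate on the train set and scoring it on the
     * test/validation set until the time budget (minutesToTrain) is exhausted, then
     * stores the best model and its JSON configuration.
     */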
    public void trainMultiInputNetwork(
            Integer numInput,
            String trainFeaturesPath,
            String trainLabelsPath,
            String testFeaturesPath,
            String testLabelsPath,
            Integer minFileId,
            Integer maxFileId,
            String resultsPath,
            final int minutesToTrain) {
        CandidateGenerator<DL4JConfiguration> generator = createArbiterRNN(numInput);

        //Load data
        DataSetIterator trainingData = getDataIterator(trainFeaturesPath, trainLabelsPath, minFileId, maxFileId);
        DataSetIterator testData = getDataIterator(testFeaturesPath, testLabelsPath, minFileId, maxFileId);

        //Data normalization is not necessary (inputs are already one-hot encoded)
        DataProvider<DataSetIterator> dataProvider = new DataSetIteratorProvider(trainingData, testData);

        //Candidate models are saved to <baseSaveDirectory>/0/, <baseSaveDirectory>/1/, <baseSaveDirectory>/2/, ...
        String baseSaveDirectory = clearAndCreateArbiterDir(resultsPath);
        ResultSaver<DL4JConfiguration, MultiLayerNetwork, Object> modelSaver = new LocalMultiLayerNetworkSaver<>(baseSaveDirectory);

        TerminationCondition[] terminationConditions = {new MaxTimeCondition(minutesToTrain, TimeUnit.MINUTES)};
        ScoreFunction<MultiLayerNetwork, DataSetIterator> scoreFunction = new TestSetAccuracyScoreFunction();

        //Given these configuration options, put them all together:
        OptimizationConfiguration<DL4JConfiguration, MultiLayerNetwork, DataSetIterator, Object> configuration
                = new OptimizationConfiguration.Builder<DL4JConfiguration, MultiLayerNetwork, DataSetIterator, Object>()
                .candidateGenerator(generator)
                .dataProvider(dataProvider)
                .modelSaver(modelSaver)
                .scoreFunction(scoreFunction)
                .terminationConditions(terminationConditions)
                .build();

        //Set up execution locally on this machine:
        IOptimizationRunner<DL4JConfiguration, MultiLayerNetwork, Object> runner
                = new LocalOptimizationRunner<>(configuration, new MultiLayerNetworkTaskCreator<>());
        runner.addListeners(new LoggingOptimizationRunnerStatusListener());

        //Start the hyperparameter optimization
        runner.execute();

        StringBuilder sb = new StringBuilder();
        sb.append("Best score: ").append(runner.bestScore()).append("\n")
                .append("Index of model with best score: ").append(runner.bestScoreCandidateIndex()).append("\n")
                .append("Number of configurations evaluated: ").append(runner.numCandidatesCompleted()).append("\n");
        System.out.println(sb.toString());

        //Get all results, and print out details of the best result:
        int indexOfBestResult = runner.bestScoreCandidateIndex();
        List<ResultReference<DL4JConfiguration, MultiLayerNetwork, Object>> allResults = runner.getResults();
        try {
            if (indexOfBestResult != -1) {
                OptimizationResult<DL4JConfiguration, MultiLayerNetwork, Object> bestResult = allResults.get(indexOfBestResult).getResult();
                MultiLayerNetwork bestModel = bestResult.getResult();
                storeNetworkModel(bestModel);
                storeNetworkConfiguration(bestModel.getLayerWiseConfigurations().toJson());
                System.out.println("\n\nConfiguration of best model:\n");
                System.out.println(bestModel.getLayerWiseConfigurations().toJson());
            } else {
                System.out.println("\n\nNo configuration for best model found.\n");
            }
        } catch (IOException e) {
            System.out.println("Error while getting best result.\n\nError: " + e.toString());
        }
    }
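    /**
     * Writes the network configuration JSON to <modelPath>/<modelName>_conf.json,
     * replacing any previously stored configuration.
     */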
    private void storeNetworkConfiguration(String jsonConf) {
        if (jsonConf == null)
            return;
        String fileName = modelName + "_conf.json";
        try {
            File storedJsonConfiguration = new File(modelPath, fileName);
            // delete any previously created configuration
            storedJsonConfiguration.delete();
            try (OutputStream fos = new FileOutputStream(storedJsonConfiguration);
                 Writer writer = new OutputStreamWriter(fos, Charset.forName("UTF-8"))) {
                writer.write(jsonConf);
            }
        } catch (IOException e) {
            System.out.println("Failed storing configuration [" + fileName + "].\n\nError: " + e);
        }
    }
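    /**
     * Serializes the best model (including updater state) to <modelPath>/<modelName>,
     * replacing any previously stored model. A model saved this way can later be
     * restored with ModelSerializer.restoreMultiLayerNetwork(File) (not shown here).
     */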
    private void storeNetworkModel(MultiLayerNetwork bestModel) {
        if (bestModel == null)
            return;
        try {
            File storedModel = new File(modelPath, modelName);
            // delete any previously created model
            storedModel.delete();
            // try-with-resources so the stream is closed even if writeModel throws;
            // 'true' also saves the updater state, allowing further training later
            try (FileOutputStream fos = new FileOutputStream(storedModel)) {
                ModelSerializer.writeModel(bestModel, fos, true);
            }
        } catch (IOException e) {
            System.out.println("Failed storing model [" + modelName + "].\n\nError: " + e.toString());
        }
    }
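    /**
     * Deletes and recreates the directory Arbiter saves candidate models into,
     * so each run starts from an empty results directory.
     */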
    private String clearAndCreateArbiterDir(String resultsPath) {
        //Candidate models are saved to <baseSaveDirectory>/0/, <baseSaveDirectory>/1/, <baseSaveDirectory>/2/, ...
        String baseSaveDirectory = new File(resultsPath, "arbiter-ex-" + modelName + "/").toPath().toString();
        File f = new File(baseSaveDirectory);
        if (f.exists()) {
            try {
                FileUtils.deleteDirectory(f);
            } catch (IOException e) {
                System.out.println("Error deleting " + baseSaveDirectory + " directory.\n\nError: " + e.toString());
            }
        }
        f.mkdirs();
        return baseSaveDirectory;
    }
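    /**
     * Builds a DataSetIterator over numbered CSV sequence files (<dir>/minFileId.csv ...
     * <dir>/maxFileId.csv), pairing each features file with the labels file of the same
     * index. ALIGN_END aligns feature and label sequences at their final time step,
     * which is useful when the label sequences are shorter than the feature sequences.
     */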
    private DataSetIterator getDataIterator(String featuresPath, String labelsPath, Integer minFileId, Integer maxFileId) {
        SequenceRecordReader features = new CSVSequenceRecordReader();
        SequenceRecordReader labels = new CSVSequenceRecordReader();
        try {
            File featuresDir = new File(featuresPath);
            File labelsDir = new File(labelsPath);
            features.initialize(new NumberedFileInputSplit(featuresDir.getAbsolutePath() + "/%d.csv", minFileId, maxFileId));
            labels.initialize(new NumberedFileInputSplit(labelsDir.getAbsolutePath() + "/%d.csv", minFileId, maxFileId));
        } catch (Exception e) {
            System.out.println("Error when creating data iterator for RNN.\n\nError: " + e.toString());
        }
        int numClasses = 20;
        return new SequenceRecordReaderDataSetIterator(
                features,
                labels,
                BATCH_SIZE,
                numClasses,
                false,   // regression = false: classification with one-hot labels
                SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
    }
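    /**
     * Defines the hyperparameter search space for a single-hidden-layer Graves LSTM
     * classifier: learning rate, RMS decay, dropout, L2, gradient clipping threshold,
     * hidden layer width, TBPTT length, forget gate bias and activation function are
     * all searched; a RandomSearchGenerator samples candidates uniformly from this space.
     */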
    private CandidateGenerator<DL4JConfiguration> createArbiterRNN(Integer nIn) {
        int nOut = 20; // Number of output categories/classes (number of label columns)

        ParameterSpace<Double> learningRateHyperparam = new ContinuousParameterSpace(0.001, 0.1);
        ParameterSpace<Double> rmsDecayHyperparam = new ContinuousParameterSpace(0.1, 0.99);
        ParameterSpace<Double> dropoutHyperparam = new ContinuousParameterSpace(0.1, 0.9);
        ParameterSpace<Double> l2Hyperparameter = new ContinuousParameterSpace(0.0001, 0.1);
        ParameterSpace<Double> clipHyperparameter = new ContinuousParameterSpace(0.5, 100.0);
        ParameterSpace<Integer> hwHyperparameter = new IntegerParameterSpace(20, 300);        // hidden layer width
        ParameterSpace<Integer> tbpttHyperparameter = new IntegerParameterSpace(1, 300);      // truncated BPTT length
        ParameterSpace<Double> fgBiasHyperparameter = new ContinuousParameterSpace(0.5, 5.0); // forget gate bias init
        String[] actFns = new String[]{"tanh", "softsign"};

        MultiLayerSpace hyperparameterSpace = new MultiLayerSpace.Builder()
                .numEpochs(50)
                //These next few options: fixed values for all models
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1)
                .seed(12345)
                .regularization(true)
                .l2(l2Hyperparameter)
                .learningRate(learningRateHyperparam)
                .rmsDecay(rmsDecayHyperparam)
                .dropOut(dropoutHyperparam)
                .updater(Updater.RMSPROP)
                .weightInit(WeightInit.XAVIER)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(clipHyperparameter)
                .addLayer(new GravesLSTMLayerSpace.Builder()
                        .forgetGateBiasInit(fgBiasHyperparameter)
                        .nIn(nIn)
                        .nOut(hwHyperparameter)
                        .activation(new DiscreteParameterSpace<>(actFns))
                        .build())
                .addLayer(new RnnOutputLayerSpace.Builder()
                        .activation("softmax")
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .nIn(hwHyperparameter)
                        .nOut(nOut)
                        .build())
                .backpropType(BackpropType.TruncatedBPTT)
                .tbpttFwdLength(tbpttHyperparameter)
                .tbpttBwdLength(tbpttHyperparameter)
                .pretrain(false).backprop(true)
                .build();

        return new RandomSearchGenerator<>(hyperparameterSpace);
    }
}