Advertisement
Guest User

Untitled

a guest
Mar 24th, 2017
100
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 13.69 KB | None | 0 0
  1. package foo;
  2.  
  3. import org.apache.commons.io.FileUtils;
  4. import org.datavec.api.records.reader.SequenceRecordReader;
  5. import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
  6. import org.datavec.api.split.NumberedFileInputSplit;
  7. import org.deeplearning4j.arbiter.DL4JConfiguration;
  8. import org.deeplearning4j.arbiter.MultiLayerSpace;
  9. import org.deeplearning4j.arbiter.data.DataSetIteratorProvider;
  10. import org.deeplearning4j.arbiter.layers.GravesLSTMLayerSpace;
  11. import org.deeplearning4j.arbiter.layers.RnnOutputLayerSpace;
  12. import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
  13. import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
  14. import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
  15. import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
  16. import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
  17. import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
  18. import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
  19. import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
  20. import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
  21. import org.deeplearning4j.arbiter.optimize.candidategenerator.RandomSearchGenerator;
  22. import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
  23. import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
  24. import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
  25. import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
  26. import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
  27. import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
  28. import org.deeplearning4j.arbiter.optimize.runner.listener.runner.LoggingOptimizationRunnerStatusListener;
  29. import org.deeplearning4j.arbiter.saver.local.multilayer.LocalMultiLayerNetworkSaver;
  30. import org.deeplearning4j.arbiter.scoring.multilayer.TestSetAccuracyScoreFunction;
  31. import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
  32. import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
  33. import org.deeplearning4j.nn.api.OptimizationAlgorithm;
  34. import org.deeplearning4j.nn.conf.*;
  35. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  36. import org.deeplearning4j.nn.weights.WeightInit;
  37. import org.deeplearning4j.util.ModelSerializer;
  38. import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
  39. import org.nd4j.linalg.lossfunctions.LossFunctions;
  40.  
  41. import java.io.*;
  42. import java.nio.charset.Charset;
  43. import java.util.List;
  44. import java.util.concurrent.TimeUnit;
  45.  
  46. public class ClassificationArbiterCPU {
  47.  
  48. private static final String BASE_DIR = ".. some folder";
  49.  
  50. // data is split for each i.csv file on time axis using a 60/40 split
  51. private static final String DATA_DIR = ".. folder with train and test split";
  52.  
  53. private static final String RESOURCE_DIR = "... somewhere to store folder";
  54.  
  55. public static void main(String[] args) {
  56. String modelName = "category_D1";
  57. String modelPath = BASE_DIR + "/model/";
  58. String resultsPath = BASE_DIR + RESOURCE_DIR + "/results/";
  59.  
  60. String trainFeaturesPath = BASE_DIR + DATA_DIR + "/train/features/";
  61. String trainLabelsPath = BASE_DIR + DATA_DIR + "/train/labels/";
  62. String validationFeaturesPath = BASE_DIR + DATA_DIR + "/validation/features/";
  63. String validationLabelsPath = BASE_DIR + DATA_DIR + "/validation/labels/";
  64.  
  65. int maxFileId = 10;
  66. int numInput = 20;
  67.  
  68. ClassificationArbiterCPU arbiter = new ClassificationArbiterCPU(modelPath, modelName);
  69. arbiter.trainMultiInputNetwork(
  70. numInput,
  71. trainFeaturesPath,
  72. trainLabelsPath,
  73. validationFeaturesPath,
  74. validationLabelsPath,
  75. 0,
  76. maxFileId,
  77. resultsPath,
  78. 20);
  79. }
  80.  
  81.  
  82. private String modelPath;
  83. private String modelName;
  84.  
  85. final int BATCH_SIZE = 100;
  86.  
  87. public ClassificationArbiterCPU(String modelPath, String modelName){
  88. this.modelPath = modelPath;
  89. this.modelName = modelName;
  90. }
  91.  
  92.  
  93. public void trainMultiInputNetwork(
  94. Integer numInput,
  95. String trainFeaturesPath,
  96. String trainLabelsPath,
  97. String testFeaturesPath,
  98. String testLabelsPath,
  99. Integer minFileId,
  100. Integer maxFileId,
  101. String resultsPath,
  102. final int mintutesToTrain) {
  103.  
  104. CandidateGenerator<DL4JConfiguration> generator = createArbiterRNN(numInput);
  105.  
  106. //Load data
  107. DataSetIterator trainingData = getDataIterator(trainFeaturesPath, trainLabelsPath, minFileId, maxFileId);
  108. DataSetIterator testData = getDataIterator(testFeaturesPath, testLabelsPath, minFileId, maxFileId);
  109.  
  110. //data normalization not necessary (ONE-HOT encoding already implemented)
  111.  
  112. DataProvider<DataSetIterator> dataProvider = new DataSetIteratorProvider(trainingData, testData);
  113.  
  114. //This will result in examples being saved to arbiterExample/0/, arbiterExample/1/, arbiterExample/2/, ...
  115. String baseSaveDirectory = clearAndCreateArbiterDir(resultsPath);
  116. ResultSaver<DL4JConfiguration,MultiLayerNetwork,Object> modelSaver = new LocalMultiLayerNetworkSaver<>(baseSaveDirectory);
  117.  
  118. TerminationCondition[] terminationConditions = {new MaxTimeCondition(mintutesToTrain, TimeUnit.MINUTES)};
  119.  
  120. ScoreFunction<MultiLayerNetwork,DataSetIterator> scoreFunction = new TestSetAccuracyScoreFunction();
  121.  
  122. //Given these configuration options, let's put them all together:
  123. OptimizationConfiguration<DL4JConfiguration, MultiLayerNetwork, DataSetIterator, Object> configuration
  124. = new OptimizationConfiguration.Builder<DL4JConfiguration, MultiLayerNetwork, DataSetIterator, Object>()
  125. .candidateGenerator(generator)
  126. .dataProvider(dataProvider)
  127. .modelSaver(modelSaver)
  128. .scoreFunction(scoreFunction)
  129. .terminationConditions(terminationConditions)
  130. .build();
  131.  
  132. //And set up execution locally on this machine:
  133. IOptimizationRunner<DL4JConfiguration,MultiLayerNetwork,Object> runner
  134. = new LocalOptimizationRunner<>(configuration, new MultiLayerNetworkTaskCreator<>());
  135.  
  136.  
  137. runner.addListeners(new LoggingOptimizationRunnerStatusListener());
  138. //Start the hyperparameter optimization
  139. runner.execute();
  140.  
  141.  
  142. StringBuilder sb = new StringBuilder();
  143. sb.append("Best score: ").append(runner.bestScore()).append("\n")
  144. .append("Index of model with best score: ").append(runner.bestScoreCandidateIndex()).append("\n")
  145. .append("Number of configurations evaluated: ").append(runner.numCandidatesCompleted()).append("\n");
  146.  
  147. System.out.println(sb.toString());
  148.  
  149.  
  150. //Get all results, and print out details of the best result:
  151. int indexOfBestResult = runner.bestScoreCandidateIndex();
  152. List<ResultReference<DL4JConfiguration,MultiLayerNetwork,Object>> allResults = runner.getResults();
  153.  
  154. try
  155. {
  156. if(indexOfBestResult != -1) {
  157. OptimizationResult<DL4JConfiguration, MultiLayerNetwork, Object> bestResult = allResults.get(indexOfBestResult).getResult();
  158. MultiLayerNetwork bestModel = bestResult.getResult();
  159.  
  160. storeNetworkModel(bestModel);
  161. storeNetworkConfiguration(bestModel.getLayerWiseConfigurations().toJson());
  162.  
  163. System.out.println("\n\nConfiguration of best model:\n");
  164. System.out.println(bestModel.getLayerWiseConfigurations().toJson());
  165. }else {
  166. System.out.println("\n\nNo configuration for best model found:\n");
  167. }
  168. }
  169. catch (IOException e)
  170. {
  171. System.out.println("Error while getting best result.\n\nError: " + e.toString());
  172. }
  173. }
  174.  
  175. private void storeNetworkConfiguration(String jsonConf) {
  176. if (jsonConf == null)
  177. return;
  178.  
  179. String fileName = modelName + "_conf.json";
  180.  
  181. try {
  182. File storedJsonConfiguration = new File(modelPath, fileName);
  183. // deletes previously created configuration
  184. storedJsonConfiguration.delete();
  185.  
  186. try(OutputStream fos = new FileOutputStream(storedJsonConfiguration);
  187. Writer writer = new OutputStreamWriter(fos, Charset.forName("UTF-8"))){
  188. writer.write(jsonConf);
  189. }
  190. } catch (IOException e) {
  191. System.out.println("Failed storing model [" + modelName + "].\n\nError: " + e);
  192. }
  193. }
  194.  
  195. private void storeNetworkModel(MultiLayerNetwork bestModel) {
  196. if (bestModel == null)
  197. return;
  198.  
  199. try {
  200. File storedModel = new File(modelPath, modelName);
  201. // deletes previously created model
  202. storedModel.delete();
  203. FileOutputStream fos = new FileOutputStream(storedModel);
  204. ModelSerializer.writeModel(bestModel, fos, true);
  205. } catch (IOException e) {
  206. System.out.println("Failed storing model [" + modelName + "].\n\nError: " + e.toString());
  207. }
  208.  
  209. }
  210.  
  211. private String clearAndCreateArbiterDir(String resultsPath) {
  212. //This will result in examples being saved to arbiterExample/0/, arbiterExample/1/, arbiterExample/2/, ...
  213. String baseSaveDirectory = new File(resultsPath, "arbiter-ex-modelName/").toPath().toString();
  214.  
  215. File f = new File(baseSaveDirectory);
  216. if(f.exists()){
  217. try {
  218. FileUtils.deleteDirectory(f);
  219. } catch (IOException e) {
  220. System.out.println("Error deleting " + baseSaveDirectory + " directory.\n\nError:" + e.toString());
  221. }
  222. }
  223.  
  224. f.mkdirs();
  225.  
  226. return baseSaveDirectory;
  227. }
  228.  
  229. private DataSetIterator getDataIterator(String trainFeaturesPath, String trainLabelsPath, Integer minFileId, Integer maxFileId) {
  230. SequenceRecordReader trainFeatures = new CSVSequenceRecordReader();
  231. SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
  232.  
  233. try {
  234. File featuresDirTrain = new File(trainFeaturesPath);
  235. File labelsDirTrain = new File(trainLabelsPath);
  236.  
  237. trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", minFileId, maxFileId));
  238. trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", minFileId, maxFileId));
  239. } catch (Exception e) {
  240. System.out.println("Error when creating training data for RNN.\n\nError: " + e.toString());
  241. }
  242.  
  243. int numClasses = 20;
  244.  
  245. DataSetIterator dataIterator = null;
  246. dataIterator = new SequenceRecordReaderDataSetIterator(
  247. trainFeatures,
  248. trainLabels,
  249. BATCH_SIZE,
  250. numClasses,
  251. false,
  252. SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
  253. return dataIterator;
  254. }
  255.  
  256. private CandidateGenerator<DL4JConfiguration> createArbiterRNN(Integer nIn) {
  257.  
  258. int nOut = 20; // Number of output categories/classes (number of columns)
  259.  
  260. ParameterSpace<Double> learningRateHyperparam = new ContinuousParameterSpace(0.001, 0.1);
  261. ParameterSpace<Double> rmsDecayHyperparam = new ContinuousParameterSpace(0.1, 0.99);
  262. ParameterSpace<Double> dropoutHyperparam = new ContinuousParameterSpace(0.1, 0.9);
  263. ParameterSpace<Double> l2Hyperparameter = new ContinuousParameterSpace(0.0001, 0.1);
  264. ParameterSpace<Double> clipHyperparameter = new ContinuousParameterSpace(0.5, 100);
  265. ParameterSpace<Integer> hwHyperparameter = new IntegerParameterSpace(20, 300);
  266. ParameterSpace<Integer> tbpttHyperparameter = new IntegerParameterSpace(1, 300);
  267. ParameterSpace<Double> fgBiasHyperparameter = new ContinuousParameterSpace(0.5, 5.0);
  268. String[] actFns = new String[]{"tanh","softsign"};
  269.  
  270. MultiLayerSpace hyperparameterSpace = new MultiLayerSpace.Builder()
  271. .numEpochs(50)
  272. //These next few options: fixed values for all models
  273. .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
  274. .iterations(1)
  275. .seed(12345)
  276. .regularization(true)
  277. .l2(l2Hyperparameter)
  278. .learningRate(learningRateHyperparam)
  279. .rmsDecay(rmsDecayHyperparam)
  280. .dropOut(dropoutHyperparam)
  281. .updater(Updater.RMSPROP)
  282. .weightInit(WeightInit.XAVIER)
  283. .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
  284. .gradientNormalizationThreshold(clipHyperparameter)
  285. .addLayer( new GravesLSTMLayerSpace.Builder()
  286. .forgetGateBiasInit(fgBiasHyperparameter)
  287. .nIn(nIn)
  288. .nOut(hwHyperparameter)
  289. .activation(new DiscreteParameterSpace<>(actFns))
  290. .build())
  291. .addLayer( new RnnOutputLayerSpace.Builder()
  292. .activation("softmax")
  293. .lossFunction(LossFunctions.LossFunction.MCXENT)
  294. .nIn(hwHyperparameter)
  295. .nOut(nOut)
  296. .build()
  297. )
  298. .backpropType(BackpropType.TruncatedBPTT).tbpttFwdLength(tbpttHyperparameter).tbpttBwdLength(tbpttHyperparameter)
  299. .pretrain(false).backprop(true).build();
  300.  
  301. CandidateGenerator<DL4JConfiguration> candidateGenerator = new RandomSearchGenerator<>(hyperparameterSpace);
  302. return candidateGenerator;
  303. }
  304. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement