Advertisement
Guest User

Untitled

a guest
Jan 20th, 2017
164
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.80 KB | None | 0 0
  1. package org.deeplearning4j.nn.modelimport.keras;
  2.  
  3. import org.deeplearning4j.nn.api.Layer;
  4. import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
  5. import org.deeplearning4j.nn.graph.ComputationGraph;
  6. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  7. import org.nd4j.linalg.api.ndarray.INDArray;
  8. import org.slf4j.Logger;
  9. import org.slf4j.LoggerFactory;
  10.  
  11. import java.util.HashMap;
  12. import java.util.List;
  13. import java.util.Map;
  14.  
  15. public class KerasImportVgg16 {
  16.  
  17. private static final Logger log = LoggerFactory.getLogger(KerasImportVgg16.class);
  18.  
  19. public static void main(String[] args) throws Exception {
  20. String modelJsonFilename = "PATH TO EXPORTED JSON FILE";
  21. String weightsHdf5Filename = "PATH TO EXPORTED WEIGHTS HDF5 ARCHIVE";
  22. String modelHdf5Filename = "PATH TO EXPORTED FULL MODEL HDF5 ARCHIVE";
  23. boolean enforceTrainingConfig = false; //Controls whether unsupported training-related configs
  24. //will throw an exception or just generate a warning.
  25.  
  26. /* Import VGG 16 model from separate model config JSON and weights HDF5 files.
  27. * Will not include loss layer or training configuration.
  28. */
  29. // Static helper from KerasModelImport
  30. ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelJsonFilename, weightsHdf5Filename, enforceTrainingConfig);
  31.  
  32. // KerasModel builder pattern
  33. model = new KerasModel.ModelBuilder()
  34. .modelJsonFilename(modelJsonFilename)
  35. .weightsHdf5Filename(weightsHdf5Filename)
  36. .enforceTrainingConfig(enforceTrainingConfig)
  37. .buildModel()
  38. .getComputationGraph();
  39.  
  40. /* Import VGG 16 model from full model HDF5 file. Includes loss layer, if any. */
  41. // Static helper from KerasModelImport
  42. model = KerasModelImport.importKerasModelAndWeights(modelHdf5Filename, enforceTrainingConfig);
  43.  
  44. // KerasModel builder pattern
  45. model = new KerasModel.ModelBuilder()
  46. .modelHdf5Filename(modelHdf5Filename)
  47. .enforceTrainingConfig(enforceTrainingConfig)
  48. .buildModel()
  49. .getComputationGraph();
  50.  
  51. /* Import VGG 16 model config from model config JSON. Will not include loss
  52. * layer or training configuration.
  53. */
  54. // Static helper from KerasModelImport
  55. ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration(modelJsonFilename, enforceTrainingConfig);
  56.  
  57. // KerasModel builder pattern
  58. config = new KerasModel.ModelBuilder()
  59. .modelJsonFilename(modelJsonFilename)
  60. .enforceTrainingConfig(enforceTrainingConfig)
  61. .buildModel()
  62. .getComputationGraphConfiguration();
  63. }
  64. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement