Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.deeplearning4j.nn.modelimport.keras;
- import org.deeplearning4j.nn.api.Layer;
- import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
- import org.deeplearning4j.nn.graph.ComputationGraph;
- import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
- import org.nd4j.linalg.api.ndarray.INDArray;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- import java.util.HashMap;
- import java.util.List;
- import java.util.Map;
- public class KerasImportVgg16 {
- private static final Logger log = LoggerFactory.getLogger(KerasImportVgg16.class);
- public static void main(String[] args) throws Exception {
- String modelJsonFilename = "PATH TO EXPORTED JSON FILE";
- String weightsHdf5Filename = "PATH TO EXPORTED WEIGHTS HDF5 ARCHIVE";
- String modelHdf5Filename = "PATH TO EXPORTED FULL MODEL HDF5 ARCHIVE";
- boolean enforceTrainingConfig = false; //Controls whether unsupported training-related configs
- //will throw an exception or just generate a warning.
- /* Import VGG 16 model from separate model config JSON and weights HDF5 files.
- * Will not include loss layer or training configuration.
- */
- // Static helper from KerasModelImport
- ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelJsonFilename, weightsHdf5Filename, enforceTrainingConfig);
- // KerasModel builder pattern
- model = new KerasModel.ModelBuilder()
- .modelJsonFilename(modelJsonFilename)
- .weightsHdf5Filename(weightsHdf5Filename)
- .enforceTrainingConfig(enforceTrainingConfig)
- .buildModel()
- .getComputationGraph();
- /* Import VGG 16 model from full model HDF5 file. Includes loss layer, if any. */
- // Static helper from KerasModelImport
- model = KerasModelImport.importKerasModelAndWeights(modelHdf5Filename, enforceTrainingConfig);
- // KerasModel builder pattern
- model = new KerasModel.ModelBuilder()
- .modelHdf5Filename(modelHdf5Filename)
- .enforceTrainingConfig(enforceTrainingConfig)
- .buildModel()
- .getComputationGraph();
- /* Import VGG 16 model config from model config JSON. Will not include loss
- * layer or training configuration.
- */
- // Static helper from KerasModelImport
- ComputationGraphConfiguration config = KerasModelImport.importKerasModelConfiguration(modelJsonFilename, enforceTrainingConfig);
- // KerasModel builder pattern
- config = new KerasModel.ModelBuilder()
- .modelJsonFilename(modelJsonFilename)
- .enforceTrainingConfig(enforceTrainingConfig)
- .buildModel()
- .getComputationGraphConfiguration();
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement