Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package org.nd4j.cli;
- import lombok.val;
- import org.nd4j.autodiff.execution.NativeGraphExecutioner;
- import org.nd4j.autodiff.execution.conf.ExecutionMode;
- import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
- import org.nd4j.autodiff.execution.conf.OutputMode;
- import org.nd4j.autodiff.samediff.SameDiff;
- import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
- import org.nd4j.linalg.api.ndarray.INDArray;
- import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
- import org.nd4j.linalg.factory.Nd4j;
- import org.nd4j.linalg.io.ClassPathResource;
- import org.nd4j.nativeblas.Nd4jCpu;
- import java.io.File;
- import java.io.IOException;
- import java.util.Collections;
- /**
- * Created by Yves Quemener on 12/7/18.
- */
- public class TFtoFlatFileConverter {
- public static void convert(String inFile, String outFile, String inputName, int[] inputShape) throws IOException {
- SameDiff tg = TFGraphMapper.getInstance().importGraph(new File(inFile));
- //tg.asFlatFile(new File(outFile), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build());
- /*INDArray input = Nd4j.ones(1,128,128,3);
- tg.associateArrayWithVariable(input, "input");*/
- //INDArray input = Nd4j.ones(inputShape);
- //tg.associateArrayWithVariable(input, inputName);
- tg.asFlatFile(new File(outFile));
- }
- public static void main(String [] args) throws IOException {
- if(args.length<2){
- //mobileNetFlatBufferSanity();
- mobileNetFlatBufferSanityLibNd4j();
- //testDenseNet();
- //testMobileNet();
- //testMobileNetOutput();
- //testTfFiles();
- }
- else{
- convert(args[0], args[1], "input", new int[]{1,224,224,3});
- }
- }
- public static void testTfFiles() throws IOException {
- String inDir = "//home/yves/dl4j/models/";
- String outDir = "/home/yves/tmp/fbmodels/";
- String[] filenames = {"mobilenet_v1_0.5_128_frozen.pb",
- "mobilenet_v2_1.0_224_frozen.pb",
- "nasnet_mobile.pb",
- "resnetv2_imagenet_frozen_graph.pb",
- "squeezenet.pb"};
- String[] inputNames={"input",
- "input",
- "input",
- "input_tensor",
- "Placeholder"
- };
- int[][] shapes = {new int[]{1,128,128,3},
- new int[]{1,224,224,3},
- new int[]{1,224,224,3},
- new int[]{1,224,224,3},
- new int[]{1,224,224,3}
- };
- //String[] filenames = {"lenet_cnn.pb", "lenet_frozen.pb", "max_lstm.pb", "tensorflow_inception_graph.pb", "train_iris.pb"};
- //String[] filenames = {"lenet_cnn.pb"/*, "lenet_frozen.pb", max_lstm.pb"*/, "tensorflow_inception_graph.pb", /*"train_iris.pb"*/};
- for(int i=0;i<filenames.length;i++){
- System.out.println("Testing model "+filenames[i]);
- String outFilename = filenames[i].substring(0, filenames[i].length()-3) + ".fb";
- convert(inDir+filenames[i], outDir+outFilename, inputNames[i], shapes[i]);
- }
- }
- public static void testMobileNet() throws IOException {
- //convert("/home/yves/tmp/mobilenet_v1_0.5_128_frozen.pb", "/home/yves/tmp/mobilenet.fb");
- }
- public static void testDenseNet() throws IOException {
- //convert("/home/yves/tmp/densenet/densenet.pb", "/home/yves/tmp/densenet.fb");
- }
- public static void testMobileNetOutput(){
- SameDiff tg = TFGraphMapper.getInstance().importGraph(new File("/home/yves/dl4j/models/mobilenet_v1_0.5_128_frozen.pb"));
- INDArray array = Nd4j.ones(1,128,128,3);
- tg.associateArrayWithVariable(array, "input");
- //INDArray result = tg.execAndEndResult();
- INDArray result = tg.execSingle(Collections.singletonMap("input",array), tg.outputs().get(0));
- System.out.println("Result = "+result.toString());
- System.out.println("Result ind max = "+result.argMax().toString());
- System.out.println("Result arg max = "+result.get(result.argMax()).toString());
- }
- public static void mobileNetFlatBufferSanity() throws IOException {
- SameDiff tg = SameDiff.fromFlatFile(new File("/home/yves/dl4j/models/flatBufferModels/master_version/mobilenet_v1_0.5_128_frozen.fb"));
- INDArray array = Nd4j.zeros(1,128,128,3);
- tg.associateArrayWithVariable(array, "input");
- INDArray result = tg.execSingle(Collections.singletonMap("input",array), tg.outputs().get(0));
- System.out.println("Result = "+result.toString());
- System.out.println("Result ind max = "+result.argMax().toString());
- System.out.println("Result arg max = "+result.get(result.argMax()).toString());
- }
- public static void mobileNetFlatBufferSanityLibNd4j() throws IOException {
- //SameDiff tg = SameDiff.fromFlatFile(new File("/home/yves/dl4j/models/flatBufferModels/master_version/mobilenet_v1_0.5_128_frozen.fb"));
- //SameDiff tg = SameDiff.fromFlatFile(new File("/home/yves/tmp/fbmodels/mobilenet_v1_0.5_128_frozen.fb"));
- SameDiff tg = TFGraphMapper.getInstance().importGraph(new File("/home/yves/dl4j/models/mobilenet_v1_0.5_128_frozen.pb"));
- INDArray array = Nd4j.zeros(1,128,128,3);
- tg.associateArrayWithVariable(array, "input");
- //INDArray result = tg.execSingle(Collections.singletonMap("input",array), tg.outputs().get(0));
- val executioner = new NativeGraphExecutioner();
- ExecutorConfiguration configuration = ExecutorConfiguration.builder()
- .executionMode(ExecutionMode.SEQUENTIAL)
- .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
- .gatherTimings(true)
- .outputMode(OutputMode.VARIABLE_SPACE)
- .build();
- executioner.executeGraph(tg, configuration);
- INDArray result = tg.getVariable("MobilenetV1/Predictions/Reshape_1").getArr();
- System.out.println("Result = "+result.toString());
- System.out.println("Result ind max = "+result.argMax().toString());
- System.out.println("Result arg max = "+result.get(result.argMax()).toString());
- }
- }
Add Comment
Please, Sign In to add comment