Guest User

Untitled

a guest
Jan 17th, 2019
108
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.06 KB | None | 0 0
  1. package org.nd4j.cli;
  2.  
  3. import lombok.val;
  4. import org.nd4j.autodiff.execution.NativeGraphExecutioner;
  5. import org.nd4j.autodiff.execution.conf.ExecutionMode;
  6. import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
  7. import org.nd4j.autodiff.execution.conf.OutputMode;
  8. import org.nd4j.autodiff.samediff.SameDiff;
  9. import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
  10. import org.nd4j.linalg.api.ndarray.INDArray;
  11. import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
  12. import org.nd4j.linalg.factory.Nd4j;
  13. import org.nd4j.linalg.io.ClassPathResource;
  14. import org.nd4j.nativeblas.Nd4jCpu;
  15.  
  16. import java.io.File;
  17. import java.io.IOException;
  18. import java.util.Collections;
  19.  
  20. /**
  21. * Created by Yves Quemener on 12/7/18.
  22. */
  23. public class TFtoFlatFileConverter {
  24.  
  25. public static void convert(String inFile, String outFile, String inputName, int[] inputShape) throws IOException {
  26. SameDiff tg = TFGraphMapper.getInstance().importGraph(new File(inFile));
  27. //tg.asFlatFile(new File(outFile), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build());
  28. /*INDArray input = Nd4j.ones(1,128,128,3);
  29. tg.associateArrayWithVariable(input, "input");*/
  30. //INDArray input = Nd4j.ones(inputShape);
  31. //tg.associateArrayWithVariable(input, inputName);
  32. tg.asFlatFile(new File(outFile));
  33. }
  34.  
  35. public static void main(String [] args) throws IOException {
  36. if(args.length<2){
  37. //mobileNetFlatBufferSanity();
  38. mobileNetFlatBufferSanityLibNd4j();
  39. //testDenseNet();
  40. //testMobileNet();
  41. //testMobileNetOutput();
  42. //testTfFiles();
  43. }
  44. else{
  45. convert(args[0], args[1], "input", new int[]{1,224,224,3});
  46. }
  47. }
  48.  
  49. public static void testTfFiles() throws IOException {
  50. String inDir = "//home/yves/dl4j/models/";
  51. String outDir = "/home/yves/tmp/fbmodels/";
  52. String[] filenames = {"mobilenet_v1_0.5_128_frozen.pb",
  53. "mobilenet_v2_1.0_224_frozen.pb",
  54. "nasnet_mobile.pb",
  55. "resnetv2_imagenet_frozen_graph.pb",
  56. "squeezenet.pb"};
  57.  
  58. String[] inputNames={"input",
  59. "input",
  60. "input",
  61. "input_tensor",
  62. "Placeholder"
  63. };
  64.  
  65. int[][] shapes = {new int[]{1,128,128,3},
  66. new int[]{1,224,224,3},
  67. new int[]{1,224,224,3},
  68. new int[]{1,224,224,3},
  69. new int[]{1,224,224,3}
  70. };
  71. //String[] filenames = {"lenet_cnn.pb", "lenet_frozen.pb", "max_lstm.pb", "tensorflow_inception_graph.pb", "train_iris.pb"};
  72. //String[] filenames = {"lenet_cnn.pb"/*, "lenet_frozen.pb", max_lstm.pb"*/, "tensorflow_inception_graph.pb", /*"train_iris.pb"*/};
  73.  
  74. for(int i=0;i<filenames.length;i++){
  75. System.out.println("Testing model "+filenames[i]);
  76. String outFilename = filenames[i].substring(0, filenames[i].length()-3) + ".fb";
  77. convert(inDir+filenames[i], outDir+outFilename, inputNames[i], shapes[i]);
  78. }
  79. }
  80.  
  81. public static void testMobileNet() throws IOException {
  82. //convert("/home/yves/tmp/mobilenet_v1_0.5_128_frozen.pb", "/home/yves/tmp/mobilenet.fb");
  83. }
  84.  
  85. public static void testDenseNet() throws IOException {
  86. //convert("/home/yves/tmp/densenet/densenet.pb", "/home/yves/tmp/densenet.fb");
  87. }
  88.  
  89. public static void testMobileNetOutput(){
  90. SameDiff tg = TFGraphMapper.getInstance().importGraph(new File("/home/yves/dl4j/models/mobilenet_v1_0.5_128_frozen.pb"));
  91. INDArray array = Nd4j.ones(1,128,128,3);
  92. tg.associateArrayWithVariable(array, "input");
  93. //INDArray result = tg.execAndEndResult();
  94. INDArray result = tg.execSingle(Collections.singletonMap("input",array), tg.outputs().get(0));
  95.  
  96. System.out.println("Result = "+result.toString());
  97. System.out.println("Result ind max = "+result.argMax().toString());
  98. System.out.println("Result arg max = "+result.get(result.argMax()).toString());
  99. }
  100.  
  101. public static void mobileNetFlatBufferSanity() throws IOException {
  102. SameDiff tg = SameDiff.fromFlatFile(new File("/home/yves/dl4j/models/flatBufferModels/master_version/mobilenet_v1_0.5_128_frozen.fb"));
  103. INDArray array = Nd4j.zeros(1,128,128,3);
  104. tg.associateArrayWithVariable(array, "input");
  105.  
  106. INDArray result = tg.execSingle(Collections.singletonMap("input",array), tg.outputs().get(0));
  107.  
  108.  
  109. System.out.println("Result = "+result.toString());
  110. System.out.println("Result ind max = "+result.argMax().toString());
  111. System.out.println("Result arg max = "+result.get(result.argMax()).toString());
  112. }
  113.  
  114. public static void mobileNetFlatBufferSanityLibNd4j() throws IOException {
  115. //SameDiff tg = SameDiff.fromFlatFile(new File("/home/yves/dl4j/models/flatBufferModels/master_version/mobilenet_v1_0.5_128_frozen.fb"));
  116. //SameDiff tg = SameDiff.fromFlatFile(new File("/home/yves/tmp/fbmodels/mobilenet_v1_0.5_128_frozen.fb"));
  117. SameDiff tg = TFGraphMapper.getInstance().importGraph(new File("/home/yves/dl4j/models/mobilenet_v1_0.5_128_frozen.pb"));
  118. INDArray array = Nd4j.zeros(1,128,128,3);
  119. tg.associateArrayWithVariable(array, "input");
  120.  
  121. //INDArray result = tg.execSingle(Collections.singletonMap("input",array), tg.outputs().get(0));
  122.  
  123. val executioner = new NativeGraphExecutioner();
  124. ExecutorConfiguration configuration = ExecutorConfiguration.builder()
  125. .executionMode(ExecutionMode.SEQUENTIAL)
  126. .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
  127. .gatherTimings(true)
  128. .outputMode(OutputMode.VARIABLE_SPACE)
  129. .build();
  130.  
  131.  
  132. executioner.executeGraph(tg, configuration);
  133. INDArray result = tg.getVariable("MobilenetV1/Predictions/Reshape_1").getArr();
  134.  
  135. System.out.println("Result = "+result.toString());
  136. System.out.println("Result ind max = "+result.argMax().toString());
  137. System.out.println("Result arg max = "+result.get(result.argMax()).toString());
  138. }
  139.  
  140.  
  141. }
Add Comment
Please, Sign In to add comment