Advertisement
Guest User

Untitled

a guest
Apr 10th, 2020
200
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.46 KB | None | 0 0
  1. import tensorflow as tf
  2. import tensorrt as trt
  3. import graphsurgeon as gs
  4.  
  5. path = '/mnt/Data2/reps/tf_models/object_detection/checkpoints/autov/frozen_inference_graph.pb'
  6. TRTbin = '/home/undead/inception_autov.bin'
  7. output_name = ['NMS']
  8. dims = [3, 480, 480]
  9. layout = 7
  10.  
  11. def add_plugin(graph):
  12.     all_assert_nodes = graph.find_nodes_by_op("Assert")
  13.     graph.remove(all_assert_nodes, remove_exclusive_dependencies=True)
  14.  
  15.     all_identity_nodes = graph.find_nodes_by_op("Identity")
  16.     graph.forward_inputs(all_identity_nodes)
  17.  
  18.     Input = gs.create_plugin_node(
  19.         name="Input",
  20.         dtype=tf.float32,
  21.         op="Placeholder",
  22.         shape=[1, 3, 480, 480]
  23.     )
  24.  
  25.     PriorBox = gs.create_plugin_node(
  26.         name="GridAnchor",
  27.         op="GridAnchor_TRT",
  28.         minSize=0.1,
  29.         maxSize=0.95,
  30.         aspectRatios=[0.1, 0.3, 0.6, 0.8, 1.0, 1.5, 2.0, 3.0, 5.0],
  31.         variance=[0.1, 0.1, 0.2, 0.2],
  32.         featureMapShapes=[19, 10, 5, 3, 2, 1],
  33.         numLayers=6
  34.     )
  35.  
  36.     NMS = gs.create_plugin_node(
  37.         name="NMS",
  38.         op="NMS_TRT",
  39.         shareLocation=1,
  40.         varianceEncodedInTarget=0,
  41.         backgroundLabelId=0,
  42.         confidenceThreshold=1e-8,
  43.         nmsThreshold=0.6,
  44.         topK=100,
  45.         keepTopK=100,
  46.         numClasses=24,
  47.         inputOrder=[0, 2, 1],
  48.         confSigmoid=1,
  49.         isNormalized=1,
  50.         scoreConverter="SIGMOID"
  51.     )
  52.  
  53.     concat_priorbox = gs.create_node(
  54.         "concat_priorbox",
  55.         op="ConcatV2",
  56.         dtype=tf.float32,
  57.         axis=2
  58.     )
  59.  
  60.     concat_box_loc = gs.create_plugin_node(
  61.         "concat_box_loc",
  62.         op="FlattenConcat_TRT",
  63.         dtype=tf.float32,
  64.         axis=1,
  65.         ignoreBatch=0
  66.     )
  67.  
  68.     concat_box_conf = gs.create_plugin_node(
  69.         "concat_box_conf",
  70.         op="FlattenConcat_TRT",
  71.         dtype=tf.float32,
  72.         axis=1,
  73.         ignoreBatch=0
  74.     )
  75.  
  76.     namespace_plugin_map = {
  77.         "MultipleGridAnchorGenerator": PriorBox,
  78.         "Postprocessor": NMS,
  79.         "Preprocessor": Input,
  80.         "ToFloat": Input,
  81.         "image_tensor": Input,
  82.         "MultipleGridAnchorGenerator/Concatenate": concat_priorbox,
  83.         "MultipleGridAnchorGenerator/Identity": concat_priorbox,
  84.         "concat": concat_box_loc,
  85.         "concat_1": concat_box_conf
  86.     }
  87.  
  88.     graph.collapse_namespaces(namespace_plugin_map)
  89.     graph.remove(graph.graph_outputs, remove_exclusive_dependencies=False)
  90.  
  91.     return graph
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement