Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import tensorflow as tf
- import tensorrt as trt
- import graphsurgeon as gs
# --- Conversion settings -----------------------------------------------------

# Input tensor dimensions without the batch axis.
dims = [3, 480, 480]
# NOTE(review): presumably the number of values per detection record in the
# NMS output buffer — TODO confirm against the code that parses the output.
layout = 7
# Graph output node name; matches the "NMS" plugin node built in add_plugin.
output_name = ["NMS"]

# Frozen TensorFlow Object Detection graph to be converted.
path = "/mnt/Data2/reps/tf_models/object_detection/checkpoints/autov/frozen_inference_graph.pb"
# NOTE(review): presumably where the built TensorRT engine is serialized —
# confirm against the code that writes this file.
TRTbin = "/home/undead/inception_autov.bin"
def add_plugin(graph):
    """Rewrite an SSD TensorFlow graph into TensorRT-plugin form for UFF export.

    Removes ``Assert`` nodes, forwards ``Identity`` nodes, then collapses the
    TF preprocessing, anchor-generation, box-predictor concat and
    post-processing namespaces into a ``Placeholder`` input, a
    ``GridAnchor_TRT`` plugin, ``FlattenConcat_TRT`` / ``ConcatV2`` concat
    nodes and an ``NMS_TRT`` plugin.

    Args:
        graph: graphsurgeon graph loaded from the frozen ``.pb`` (presumably a
            ``gs.DynamicGraph`` — it is used through ``find_nodes_by_op``,
            ``remove``, ``forward_inputs`` and ``collapse_namespaces``).
            It is modified in place.

    Returns:
        The same ``graph`` object after modification.
    """
    # Assert nodes are not needed for inference; drop them together with
    # anything that only they depend on.
    all_assert_nodes = graph.find_nodes_by_op("Assert")
    graph.remove(all_assert_nodes, remove_exclusive_dependencies=True)

    # Identity nodes are pass-throughs; reconnect their consumers directly
    # to their inputs.
    all_identity_nodes = graph.find_nodes_by_op("Identity")
    graph.forward_inputs(all_identity_nodes)

    # Network input: batch of 1 followed by the module-level ``dims``, so the
    # placeholder cannot drift out of sync with the configured dimensions.
    Input = gs.create_plugin_node(
        name="Input",
        dtype=tf.float32,
        op="Placeholder",
        shape=[1] + list(dims),
    )

    # NOTE(review): featureMapShapes [19, 10, 5, 3, 2, 1] are the values
    # commonly used for 300x300 SSD inputs, while this graph's input is
    # 480x480 (see ``dims``). If the anchor count does not match the box
    # predictor outputs, these likely need recomputing — TODO confirm.
    PriorBox = gs.create_plugin_node(
        name="GridAnchor",
        op="GridAnchor_TRT",
        minSize=0.1,
        maxSize=0.95,
        aspectRatios=[0.1, 0.3, 0.6, 0.8, 1.0, 1.5, 2.0, 3.0, 5.0],
        variance=[0.1, 0.1, 0.2, 0.2],
        featureMapShapes=[19, 10, 5, 3, 2, 1],
        numLayers=6,
    )

    # Replaces the TF Postprocessor; its name "NMS" is the graph output name.
    # inputOrder has three entries: this plugin is set up for exactly three
    # inputs (box locations, box confidences, priors — order per inputOrder).
    NMS = gs.create_plugin_node(
        name="NMS",
        op="NMS_TRT",
        shareLocation=1,
        varianceEncodedInTarget=0,
        backgroundLabelId=0,
        confidenceThreshold=1e-8,
        nmsThreshold=0.6,
        topK=100,
        keepTopK=100,
        numClasses=24,
        inputOrder=[0, 2, 1],
        confSigmoid=1,
        isNormalized=1,
        scoreConverter="SIGMOID",
    )

    concat_priorbox = gs.create_node(
        "concat_priorbox",
        op="ConcatV2",
        dtype=tf.float32,
        axis=2,
    )

    concat_box_loc = gs.create_plugin_node(
        "concat_box_loc",
        op="FlattenConcat_TRT",
        dtype=tf.float32,
        axis=1,
        ignoreBatch=0,
    )

    concat_box_conf = gs.create_plugin_node(
        "concat_box_conf",
        op="FlattenConcat_TRT",
        dtype=tf.float32,
        axis=1,
        ignoreBatch=0,
    )

    # TF namespace -> replacement node. Several preprocessing namespaces all
    # collapse into the single Input placeholder.
    namespace_plugin_map = {
        "MultipleGridAnchorGenerator": PriorBox,
        "Postprocessor": NMS,
        "Preprocessor": Input,
        "ToFloat": Input,
        "image_tensor": Input,
        "MultipleGridAnchorGenerator/Concatenate": concat_priorbox,
        "MultipleGridAnchorGenerator/Identity": concat_priorbox,
        "concat": concat_box_loc,
        "concat_1": concat_box_conf,
    }

    graph.collapse_namespaces(namespace_plugin_map)

    # Drop the original TF graph outputs only (not their producers), leaving
    # the collapsed NMS node as the output.
    graph.remove(graph.graph_outputs, remove_exclusive_dependencies=False)

    # After collapsing, the NMS node can inherit an extra "Input" edge (the
    # TF Postprocessor may consume tensors from the Preprocessor namespace,
    # which is now Input). NMS_TRT is configured for three inputs, so
    # disconnect that edge when it is present; graphs without it are
    # left untouched.
    for nms_node in graph.find_nodes_by_op("NMS_TRT"):
        if "Input" in nms_node.input:
            nms_node.input.remove("Input")

    return graph
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement