/* Debug helper: dump a node's name, kind, and the shape of its first input and output. */
static void printNode(ONNX_NAMESPACE::Node *n) {
  if (n->has_name())
    std::cout << "Node's name: " << n->name() << std::endl;
  else
    std::cout << "Node has no name" << std::endl;

  std::cout << "Node's type: " << n->kind() << std::endl;
  std::cout << "input number: " << n->inputs().size() << std::endl;
  std::cout << "output number: " << n->outputs().size() << std::endl;

  std::cout << "input dim size: " << n->inputs()[0]->sizes().size() << std::endl;
  for (size_t i = 0; i < n->inputs()[0]->sizes().size(); i++) {
    std::cout << n->inputs()[0]->sizes()[i].dim << ", ";
  }
  std::cout << std::endl;

  std::cout << "output dim size: " << n->outputs()[0]->sizes().size() << std::endl;
  for (size_t i = 0; i < n->outputs()[0]->sizes().size(); i++) {
    std::cout << n->outputs()[0]->sizes()[i].dim << ", ";
  }
  std::cout << std::endl;
}

void OnnxNetParser::shuffleChannelPass(Module* module) {
  ONNX_NAMESPACE::Graph* graph = module->getGraphIR().get();
  for (auto it = graph->begin(), ie = graph->end(); it != ie; ++it) {
    auto it_temp = it;
    auto it_reshape1 = it;
    auto* n = *it;
    ONNX_NAMESPACE::Node* curNode = n;

    /* The first node must be a Reshape from 4D to 5D. */
    if (  n->kind() != ONNX_NAMESPACE::kReshape
       || n->inputs()[0]->sizes().size() != 4
       || n->outputs()[0]->sizes().size() != 5)
      continue;
    printNode(n);

    /* The second node must be a Transpose. */
    ++it_temp;
    auto it_transpose = it_temp;
    n = *it_temp;
    if (n->kind() != ONNX_NAMESPACE::kTranspose)
      continue;
    printNode(n);

    /* The third node must be a Reshape from 5D back to 4D. */
    ++it_temp;
    auto it_reshape2 = it_temp;
    n = *it_temp;
    if (  n->kind() != ONNX_NAMESPACE::kReshape
       || n->inputs()[0]->sizes().size() != 5
       || n->outputs()[0]->sizes().size() != 4)
      continue;
    printNode(n);

    /* A shuffle-channel pattern has been matched and needs to be handled. */
    ONNX_NAMESPACE::Value *inputData = curNode->inputs()[0];
    /* The group count is the second dimension of the 5D reshape output.
       It is captured here but not yet attached to the new node. */
    int group = curNode->outputs()[0]->sizes()[1].dim;

    /* Create the replacement ShuffleChannel node. Note that it is not inserted
       into the graph yet: insertBefore is still commented out. */
    ONNX_NAMESPACE::Node *shuffle = graph->create(ONNX_NAMESPACE::Symbol("ShuffleChannel"));
    shuffle->setName("OC2_DUMMY_" + curNode->name());
    shuffle->addInput(inputData);
    shuffle->output()->setElemType(inputData->elemType());
    shuffle->output()->setUniqueName("OC2_DUMMY_" + n->outputs()[0]->uniqueName());
    //shuffle->insertBefore(curNode);
#if 0
    /* Rewire every user of the pattern's final output to consume the
       ShuffleChannel output instead. */
    std::vector<ONNX_NAMESPACE::Node *> user_nodes;
    for (size_t j = 0; j < n->outputs()[0]->uses().size(); j++) {
      auto user_node = n->outputs()[0]->uses()[j].user;
      user_nodes.push_back(user_node);
    }
    for (size_t j = 0; j < user_nodes.size(); j++) {
      auto user_node = user_nodes[j];
      user_node->replaceInput(0, shuffle->outputs()[0]);
    }
#endif

    std::cout << "Here we start to delete three nodes in a row" << std::endl;
    std::cout << "===============================================" << std::endl;

    /* Destroy the three matched nodes, last to first, wiring each node's
       output users straight to its input before the node goes away. */
    n = *it_reshape2;
    printNode(n);
    n->output()->replaceAllUsesWith(n->inputs()[0]);
    it_reshape2.destroyCurrent();

    n = *it_transpose;
    printNode(n);
    n->output()->replaceAllUsesWith(n->inputs()[0]);
    it_transpose.destroyCurrent();

    n = *it_reshape1;
    printNode(n);
    n->output()->replaceAllUsesWith(n->inputs()[0]);
    it_reshape1.destroyCurrent();

    std::cout << "Here we finished deleting three nodes in a row" << std::endl;
    std::cout << "===============================================" << std::endl;
  }
}
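
/*
 * Illustrative sketch (not part of the parser): a standalone check of the
 * identity the pass above relies on, namely that Reshape(N,C,H,W -> N,g,C/g,H,W),
 * a Transpose of axes 1 and 2, and a Reshape back to 4D is exactly a channel
 * shuffle with g groups. The function and variable names below are made up
 * for this example and do not exist in the parser.
 */
#include <cassert>
#include <vector>

static std::vector<int> shuffleChannelsReference(const std::vector<int>& src,
                                                 int N, int C, int H, int W, int g) {
  // Walk the data as if it were reshaped to (N, g, C/g, H, W), transposed to
  // (N, C/g, g, H, W), and flattened back to (N, C, H, W).
  std::vector<int> dst(src.size());
  int cpg = C / g;  // channels per group
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c) {
      // Output channel c comes from input channel (c % g) * cpg + c / g.
      int src_c = (c % g) * cpg + c / g;
      for (int hw = 0; hw < H * W; ++hw)
        dst[(n * C + c) * H * W + hw] = src[(n * C + src_c) * H * W + hw];
    }
  return dst;
}

/* Minimal self-check: with C = 4 and g = 2, channels (0,1,2,3) shuffle to (0,2,1,3). */
static void shuffleChannelsSelfTest() {
  std::vector<int> x = {0, 1, 2, 3};  // N = 1, C = 4, H = W = 1
  std::vector<int> y = shuffleChannelsReference(x, 1, 4, 1, 1, /*g=*/2);
  assert((y == std::vector<int>{0, 2, 1, 3}));
}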


/* Disabled fragment kept for reference: the tail of a batch-normalization
   handling loop from another pass. The two closing braces before the #endif
   belong to that pass's for loop and function, which are not shown here. */
#if 0
    assert(n->inputs().size() == 5);  // TODO: support size smaller than 5

    auto end_iter = graph->initializers().end();
    auto scale_iter = graph->getInitializer(n->inputs()[1]->uniqueName());
    auto bias_iter = graph->getInitializer(n->inputs()[2]->uniqueName());
    auto mean_iter = graph->getInitializer(n->inputs()[3]->uniqueName());
    auto var_iter = graph->getInitializer(n->inputs()[4]->uniqueName());

    assert(scale_iter != end_iter);
    assert(bias_iter != end_iter);
    assert(mean_iter != end_iter);
    assert(var_iter != end_iter);
    assert(scale_iter->sizes().size() == 1);
    assert(bias_iter->sizes().size() == 1 && bias_iter->sizes()[0] == scale_iter->sizes()[0]);
    assert(mean_iter->sizes().size() == 1 && mean_iter->sizes()[0] == scale_iter->sizes()[0]);
    assert(var_iter->sizes().size() == 1 && var_iter->sizes()[0] == scale_iter->sizes()[0]);
    // TODO(wwcai): other datafmt
    assert(scale_iter->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
    assert(bias_iter->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
    assert(mean_iter->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
    assert(var_iter->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);

    ONNX_NAMESPACE::Node* bn = n;
    ONNX_NAMESPACE::Node* conv = n->inputs()[0]->node();

    // fuse_bn_into_conv changes the yolov2 result, so disable it temporarily
#define ENABLE_FUSE_BN_INTO_CONV 0
    if (ENABLE_FUSE_BN_INTO_CONV && conv->kind() == ONNX_NAMESPACE::kConv) {
      // Before:
      //   conv = Conv()
      //   bn = BatchNormalization()
      //
      // After:
      //   bn is deleted
      //   new inputs/initializers to conv are added to graph
      //   any no longer used inputs/initializers are erased from graph
      //
      // This pass handles the case satisfying all of the following conditions:
      //   condition 1: Run in testing mode
      //   condition 2: Inputs 1 - 4 of bn are all initializers
      //   condition 3: Output of initial conv has no other uses
      //   condition 4: Currently works for only DOUBLE, FLOAT32 tensor types
      //
      // Formula for the transformation:
      // $$ X_{bn} = \frac{s(X - m)}{\sqrt{\sigma + \epsilon}} + b_{bn} $$
      // $$ X_{conv} = X * W + b_{conv} $$
      // Substituting $X$ with $X_{conv}$ in the BN equation gives
      // $$ X_{bn} = X * \frac{sW}{\sqrt{\sigma + \epsilon}} + \frac{s(b_{conv} - m)}{\sqrt{\sigma + \epsilon}} + b_{bn} $$
      // or, equivalently,
      // $$ W' = W\frac{s}{\sqrt{\sigma + \epsilon}} $$
      // $$ b' = (b_{conv} - m)\frac{s}{\sqrt{\sigma + \epsilon}} + b_{bn} $$

      auto origInput = bn->inputs()[0];
      if (origInput->uses().size() > 1 || bn->outputs().size() > 1 ||
          !modify_conv(conv, bn, graph)) {
        continue;
      }
      bn->output()->replaceAllUsesWith(origInput);
      for (int i = 4; i >= 1; --i) {
        if (bn->inputs()[i]->uses().size() == 1) {
          auto input = bn->inputs()[i];
          bn->removeInput(i);
          graph->eraseInitializerAndInput(input);
        }
      }
      it.destroyCurrent();
    } else {
      bool isInt8Model = ctx_->getDataTypeSize() == 1;
      // pre-calculate the bn params as a scale
      bn2scale(bn, graph, isInt8Model);
    }
  }
}
#endif
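
/*
 * Illustrative sketch (not part of the parser): the W'/b' folding formulas
 * from the disabled fragment above, written out for plain float buffers.
 * The function name and the layout assumption (weights as OC x ICHW, one
 * scale/bias/mean/var entry per output channel) are made up for this example;
 * the real pass would read these tensors from the graph initializers instead.
 * After folding, X * W' + b' equals BatchNormalization(Conv(X)).
 */
#include <cmath>
#include <vector>

static void foldBnIntoConvReference(std::vector<float>& W,       // OC x ICHW conv weights, modified in place
                                    std::vector<float>& b_conv,  // OC conv bias, modified in place
                                    const std::vector<float>& s,      // BN scale, per output channel
                                    const std::vector<float>& b_bn,   // BN bias
                                    const std::vector<float>& mean,   // BN running mean
                                    const std::vector<float>& var,    // BN running variance
                                    float eps) {
  size_t oc = s.size();
  size_t ichw = W.size() / oc;  // weight elements per output channel
  for (size_t c = 0; c < oc; ++c) {
    // factor = s / sqrt(sigma + eps), applied per output channel
    float factor = s[c] / std::sqrt(var[c] + eps);
    // W' = W * factor
    for (size_t k = 0; k < ichw; ++k)
      W[c * ichw + k] *= factor;
    // b' = (b_conv - m) * factor + b_bn
    b_conv[c] = (b_conv[c] - mean[c]) * factor + b_bn[c];
  }
}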