Advertisement
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- static void printNode(bmnet_onnx::Node *n) {
- if (n->has_name())
- std::cout << "Node's name: " << n->name() << std::endl;
- else
- std::cout << "Node has got no name" << std::endl;
- std::cout << "Node's type: " << n->kind() << std::endl;
- std::cout << "input number: " << n->inputs().size() << std::endl;
- std::cout << "output number: " << n->outputs().size() << std::endl;
- std::cout << "input dim size: " << n->inputs()[0]->sizes().size() << std::endl;
- for (int i = 0; i < n->inputs()[0]->sizes().size(); i++) {
- std::cout << n->inputs()[0]->sizes()[i].dim << ", ";
- }
- std::cout << std::endl;
- std::cout << "output dim size: " << n->outputs()[0]->sizes().size() << std::endl;
- for (int i = 0; i < n->outputs()[0]->sizes().size(); i++) {
- std::cout << n->outputs()[0]->sizes()[i].dim << ", ";
- }
- std::cout << std::endl;
- }
- void OnnxNetParser::shuffleChannelPass(Module* module) {
- ONNX_NAMESPACE::Graph* graph = module->getGraphIR().get();
- for (auto it = graph->begin(), ie = graph->end(); it != ie; ++it) {
- auto it_temp = it;
- auto it_reshape1 = it;
- auto* n = *it;
- ONNX_NAMESPACE::Node* curNode = n;
- /* the first node is reshape, from 4D to 5D */
- if ( n->kind() != ONNX_NAMESPACE::kReshape
- || n->inputs()[0]->sizes().size() != 4
- || n->outputs()[0]->sizes().size() != 5)
- continue;
- printNode(n);
- /* the second node is transpose */
- ++it_temp;
- auto it_transpose = it_temp;
- n = *it_temp;
- if (n->kind() != ONNX_NAMESPACE::kTranspose)
- continue;
- printNode(n);
- /* the third node is reshape, from 5D to 4D */
- ++it_temp;
- auto it_reshape2 = it_temp;
- n = *it_temp;
- if ( n->kind() != ONNX_NAMESPACE::kReshape
- || n->inputs()[0]->sizes().size() != 5
- || n->outputs()[0]->sizes().size() != 4)
- continue;
- printNode(n);
- /* here is one shuffle channel needs to be handled */
- ONNX_NAMESPACE::Value *inputData = curNode->inputs()[0];
- int group = curNode->outputs()[0]->sizes()[1].dim;
- ONNX_NAMESPACE::Node *shuffle = graph->create(ONNX_NAMESPACE::Symbol("ShuffleChannel"));
- shuffle->setName("OC2_DUMMY_" + curNode->name());
- shuffle->addInput(inputData);
- shuffle->output()->setElemType(inputData->elemType());
- shuffle->output()->setUniqueName("OC2_DUMMY_" + n->outputs()[0]->uniqueName());
- //shuffle->insertBefore(curNode);
- #if 0
- /* replace users' input */
- std::vector<ONNX_NAMESPACE::Node *> user_nodes;
- for (int j = 0; j < n->outputs()[0]->uses().size(); j++) {
- auto user_node = n->outputs()[0]->uses()[j].user;
- user_nodes.push_back(user_node);
- }
- for (int j = 0; j < user_nodes.size(); j++) {
- auto user_node = user_nodes[j];
- user_node->replaceInput(0, shuffle->outputs()[0]);
- }
- #endif
- std::cout << "Here we start to delete three nodes in a row" << std::endl;
- std::cout << "===============================================" << std::endl;
- /* destroy 3 nodes */
- n = *it_reshape2;
- printNode(n);
- n->output()->replaceAllUsesWith(n->inputs()[0]);
- it_reshape2.destroyCurrent();
- n = *it_transpose;
- printNode(n);
- n->output()->replaceAllUsesWith(n->inputs()[0]);
- it_transpose.destroyCurrent();
- n = *it_reshape1;
- printNode(n);
- n->output()->replaceAllUsesWith(n->inputs()[0]);
- it_reshape1.destroyCurrent();
- std::cout << "Here we finished delete three node in a row" << std::endl;
- std::cout << "===============================================" << std::endl;
- }
- }
// NOTE(review): everything below is dead code, disabled by `#if 0`.
// It is a fragment of a BatchNormalization-folding pass (fold BN into a
// preceding Conv by rescaling W and b; otherwise lower BN to a Scale op
// via bn2scale). The fragment is pasted from a larger function: it
// references `n`, `it`, `graph` and `ctx_`, which are not declared here,
// and its closing braces belong to an enclosing loop/function outside
// this paste. Kept verbatim for reference; do not enable as-is.
- #if 0
- assert(n->inputs().size() == 5); // TODO: support size smaller than 5
- auto end_iter = graph->initializers().end();
- auto scale_iter = graph->getInitializer(n->inputs()[1]->uniqueName());
- auto bias_iter = graph->getInitializer(n->inputs()[2]->uniqueName());
- auto mean_iter = graph->getInitializer(n->inputs()[3]->uniqueName());
- auto var_iter = graph->getInitializer(n->inputs()[4]->uniqueName());
- assert(scale_iter != end_iter);
- assert(bias_iter != end_iter);
- assert(mean_iter != end_iter);
- assert(var_iter != end_iter);
- assert(scale_iter->sizes().size() == 1);
- assert(bias_iter->sizes().size() == 1 && bias_iter->sizes()[0] == scale_iter->sizes()[0]);
- assert(mean_iter->sizes().size() == 1 && mean_iter->sizes()[0] == scale_iter->sizes()[0]);
- assert(var_iter->sizes().size() == 1 && var_iter->sizes()[0] == scale_iter->sizes()[0]);
- // TODO(wwcai): other datafmt
- assert(scale_iter->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
- assert(bias_iter->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
- assert(mean_iter->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
- assert(var_iter->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
- ONNX_NAMESPACE::Node* bn = n;
- ONNX_NAMESPACE::Node* conv = n->inputs()[0]->node();
- // fuse_bn_into_conv will change the yolov2 result, disable it temporary
- #define ENABLE_FUSE_BN_INTO_CONV 0
- if (ENABLE_FUSE_BN_INTO_CONV && conv->kind() == ONNX_NAMESPACE::kConv) {
- // Before:
- // conv = Conv()
- // bn = BatchNormalization()
- //
- // After:
- // bn is deleted
- // new inputs/initializers to conv are added to graph
- // any no longer used inputs/initializers are erased from graph
- //
- // this pass can handle the case satisfy all following conditions:
- // condition 1: Run in testing mode
- // condition 2: Inputs 1 - 4 of bn are all initializer_size
- // condition 3: Output of initial conv has no other uses
- // condition 3: Currently works for only DOUBLE, FLOAT32 tensor types
- //
- // Formula for transformation
- // $$ X_{bn} = \frac{s(X - m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
- // $$ X_{conv} = X * W + b_{conv} $$
- // thus, substituting $X$ with $X_{conv}$ in the BN equation we get:
- // $$X_{bn} = X * \frac{sW}{\sqrt{\sigma + \epsilon}} + \frac{s(b_{conv} - m)}{\sqrt{\sigma +
- // \epsilon}} + b_{bn}$$ or
- // $$ W' = W\frac{s}{\sqrt{\sigma + \epsilon}}$$
- // $$ b' = (b_{conv} - m)\frac{s}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
- auto origInput = bn->inputs()[0];
- if (origInput->uses().size() > 1 || bn->outputs().size() > 1 ||
- !modify_conv(conv, bn, graph)) {
- continue;
- }
- bn->output()->replaceAllUsesWith(origInput);
- for (int i = 4; i >= 1; --i) {
- if (bn->inputs()[i]->uses().size() == 1) {
- auto input = bn->inputs()[i];
- bn->removeInput(i);
- graph->eraseInitializerAndInput(input);
- }
- }
- it.destroyCurrent();
- } else {
- bool isInt8Model = ctx_->getDataTypeSize() == 1;
- // pre-calculate bn param as scale
- bn2scale(bn, graph, isInt8Model);
- }
- }
- }
- #endif
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement