/*!
 * Copyright (c) 2017 by Contributors
 * \file graph_fuse.cc
 * \brief Fuse the operators together.
 */
#include <nnvm/graph.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <nnvm/pass.h>
#include <nnvm/pass_functions.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/lowered_func.h>
#include <dmlc/parameter.h>
#include "./compile_engine.h"
#include "./graph_runtime.h"
#include "./pattern_util.h"

namespace nnvm {
namespace compiler {
using namespace tvm;

// The fusion rule for a single node.
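// Roughly, as these values are used below:
//   kUknown       - not decided yet for this node;
//   kFuseToMaster - fuse the node's output into its consumer's group;
//   kRealize      - the node's output must be materialized (group boundary).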
enum class FuseRule {
  kUknown,
  kFuseToMaster,
  kRealize
};

/*!
 * \brief Get DLDataType from dtype flag.
 *
 * \param type_flag The data type flag
 * \return corresponding DLDataType
 */
DLDataType GetDLType(int type_flag) {
  return Type2TVMType(GetTVMType(type_flag));
}

// Partition the graph into segments.
// Each segment will be compiled into one operator.
// We also need to mark the properties of each segment.
nnvm::Graph GraphFusePartition(nnvm::Graph g) {
  // setup ref counter
  const IndexedGraph& idx = g.indexed_graph();
  int opt_level = 2;
  if (g.attrs.count("opt_level") != 0) {
    opt_level = g.MoveCopyAttr<int>("opt_level");
  }

  // Get attributes from the graph
  const ShapeVector& shape_vec = g.GetAttr<ShapeVector>("shape");

  // Reference counter of each op node
  // For now, always store the result when an op is referred to more than once.
  std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
  for (const auto& e : idx.outputs()) {
    // this line will realize all the outputs
    ref_count[e.node_id] += 1;
  }
  // Pattern for the subgraph
  std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kOpaque);
  // Whether a node can be fused to its parent.
  std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
  // Master node id of the fusion segment.
  std::vector<int> master_vec(idx.num_nodes(), -1);
  // Operator pattern
  static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");

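  // Note: the comparisons below (pt <= kBroadcast, ipt <= kInjective, ...)
  // rely on the TOpPattern enum being ordered from most to least fusible:
  // kElemWise < kBroadcast < kInjective < kCommReduce < kOutEWiseFusable < kOpaque.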
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) {
      fuse_vec[nid] = FuseRule::kRealize; continue;
    }
    TOpPattern pt = op_pattern.get(inode.source->op(), kOpaque);

    if (pt <= kBroadcast) {
      // Try to check if we can fuse to the master.
      int chosen_master = -1;
      bool ewise = inode.source->num_outputs() == 1;
      LOG(INFO) << "Broadcast: " << inode.source->op()->name
                << ", inputs.size=" << inode.inputs.size();
      for (const auto& e : inode.inputs) {
        const auto& p = idx[e.node_id];
        if (p.source->op()) {
          LOG(INFO) << "input: " << p.source->op()->name;
          LOG(INFO) << "fuse_vec: " << static_cast<int>(fuse_vec[e.node_id]);
        } else {
          LOG(INFO) << "not an op node";
        }

        if (fuse_vec[e.node_id] == FuseRule::kUknown) {
          TOpPattern ipt = pattern_vec[e.node_id];
          if (ipt != kElemWise) ewise = false;
          if (ipt <= kInjective) {
            fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
            LOG(INFO) << "fuse to master " << p.source->op()->name;
          } else if (ipt == kOutEWiseFusable &&
                     chosen_master == -1 &&
                     shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
            chosen_master = master_vec[e.node_id];
            fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
          } else {
            fuse_vec[e.node_id] = FuseRule::kRealize;
          }
        }
        if (ewise) {
          if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
            ewise = false;
          }
        }
      }
      master_vec[nid] = chosen_master;
      if (chosen_master != -1) {
        pt = kOutEWiseFusable;
      } else {
        pt = ewise ? kElemWise : kBroadcast;
      }
    } else if (pt == kInjective || pt == kCommReduce) {
      // fuse to the comm reduce or injective
      for (const auto& e : inode.inputs) {
        if (fuse_vec[e.node_id] == FuseRule::kUknown) {
          TOpPattern ipt = pattern_vec[e.node_id];
          if (ipt <= kInjective) {
            fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
          } else {
            fuse_vec[e.node_id] = FuseRule::kRealize;
          }
        }
      }
      if (pt == kCommReduce) {
        master_vec[nid] = nid;
      }
    } else {
      // realize
      master_vec[nid] = nid;
      LOG(INFO) << "Realize: " << inode.source->op()->name
                << ", inputs.size=" << inode.inputs.size();
      for (const auto& e : inode.inputs) {
        const auto& p = idx[e.node_id];
        if (p.source->op()) {
          LOG(INFO) << "input: " << p.source->op()->name;
          LOG(INFO) << "fuse_vec: " << static_cast<int>(fuse_vec[e.node_id]);
        } else {
          LOG(INFO) << "not an op node";
        }
        if (fuse_vec[e.node_id] == FuseRule::kUknown) {
          fuse_vec[e.node_id] = FuseRule::kRealize;
          if (master_vec[e.node_id] == -1) {
            master_vec[e.node_id] = e.node_id;
          }
        }
      }
    }

    pattern_vec[nid] = pt;
    if (ref_count[nid] > 1 || opt_level < 1) {
      fuse_vec[nid] = FuseRule::kRealize;
      if (master_vec[nid] == -1) {
        master_vec[nid] = nid;
      }
    }
  }

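  // Note: nodes in the IndexedGraph are stored in topological order, so the
  // loop below visits them in reverse topological order; a node's own group id
  // is therefore fixed before the group ids of its inputs are assigned.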
  // point to the group root id of each node
  std::vector<int> group_vec(idx.num_nodes(), -1);
  for (uint32_t i = idx.num_nodes(); i != 0; --i) {
    uint32_t nid = i - 1;
    const auto& inode = idx[nid];
    if (group_vec[nid] == -1) {
      group_vec[nid] = nid;
    }
    // propagate the group id.
    for (const auto& e : inode.inputs) {
      if (fuse_vec[e.node_id] == FuseRule::kFuseToMaster) {
        CHECK(group_vec[e.node_id] == -1 ||
              group_vec[e.node_id] == group_vec[nid]);
        group_vec[e.node_id] = group_vec[nid];
      }
    }
  }

  /*
    The algorithm above will not fuse a node whose output is fed to more than
    one child node. This is because, in general, it does not make sense to fuse
    multiple children branches with their parent, as in the following example.

          conv2d
          /  |  \
         /   |   \
       op   op   op
        |    |    |
        |    |    |

    However, when all children branches meet at a certain node, there is a
    possibility for further operator fusion. For example, all nodes in the
    following subgraph can be fused into a single node, if the three
    "in-between" nodes and the bottom node are all elementwise operations.

          conv2d
          /  |  \
         /   |   \
       op   op   op
        \    |    /
         \   |   /
       elemwise add
            |

    This pattern is not uncommon. For example, it arises when a conv2d op is
    followed by an exponential linear unit. If bias add and batch normalization
    are also present, they can be fused as well.

    In fact, the fusion algorithm above already fuses the three in-between
    nodes and the elementwise add node in the figure above. The code below
    additionally fuses the conv2d node with its already-fused children.
    The following patterns are supported:

    * Any number of child nodes from the top node
    * The path from the top node to the bottom node can contain any number of
      elementwise ops.

    The only restriction is that in-between nodes cannot have more than one
    child.

    The algorithm below works as follows:

    1. Check whether all child nodes have been fused into a single op by the
       existing fusion algorithm.
    2. Fuse the parent node into the children's group, i.e. update its group id
       to the children's group id.
    3. If the parent node originally belonged to another group (for example,
       conv + batch norm), propagate the new group id to the grandparent and
       upward.
  */
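  // (For instance, an exponential linear unit is typically decomposed into
  //  elementwise ops such as exp, relu, scalar multiply and add, which branch
  //  off the conv2d output and re-join, forming the diamond shown above.)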
  if (opt_level >= 1) {
    std::vector<std::vector<uint32_t> > children_group_ids(idx.num_nodes());
    std::vector<std::vector<uint32_t> > node_ids_per_group(idx.num_nodes());
    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
      const auto& inode = idx[nid];
      if (inode.source->is_variable()) continue;
      CHECK_NE(group_vec[nid], -1);
      node_ids_per_group[group_vec[nid]].push_back(nid);
      const uint32_t parent_nid = inode.inputs[0].node_id;
      // if parent node has more than one child, record each child's group id.
      if (ref_count[parent_nid] > 1) children_group_ids[parent_nid].push_back(group_vec[nid]);
    }

    std::vector<int> new_group_id(idx.num_nodes(), -1);
    for (uint32_t nid = idx.num_nodes() - 1; nid != 0; --nid) {
      if (new_group_id[group_vec[nid]] != -1) {
        // propagate new group id from child
        group_vec[nid] = new_group_id[group_vec[nid]];
      }
      TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque);
      if (pt == kOpaque) continue;
      const auto& group_ids = children_group_ids[nid];
      if (group_ids.size() <= 1) continue;
      const uint32_t child_group_id = group_ids[0];
      const auto& children_node_ids = node_ids_per_group[child_group_id];

      auto is_same_group_id = [child_group_id](uint32_t id) {
        return id == child_group_id;
      };
      auto is_fusible_pattern = [&idx](uint32_t child_nid) {
        TOpPattern child_pt = op_pattern.get(idx[child_nid].source->op(), kOpaque);
        return child_pt <= kBroadcast;
      };
      // fuse this node with children if
      // all children belong to the same group and
      // all nodes in the group are element wise or broadcast ops.
      const bool can_be_fused =
          std::all_of(group_ids.begin(), group_ids.end(), is_same_group_id) &&
          std::all_of(children_node_ids.begin(), children_node_ids.end(), is_fusible_pattern);

      if (can_be_fused) {
        LOG(INFO) << "fusing";
        new_group_id[group_vec[nid]] = child_group_id;
        group_vec[nid] = child_group_id;
        for (uint32_t nid2 : node_ids_per_group[child_group_id]) {
          pattern_vec[nid2] = pattern_vec[nid];
          master_vec[nid2] = master_vec[nid];
        }
      }
    }
  }

  g.attrs["group_root"] = std::make_shared<any>(std::move(group_vec));
  g.attrs["group_master"] = std::make_shared<any>(std::move(master_vec));
  g.attrs["pattern"] = std::make_shared<any>(std::move(pattern_vec));
  return g;
}


NNVM_REGISTER_PASS(GraphFusePartition)
.set_body(GraphFusePartition)
.depend_graph_attr("shape")
.depend_graph_attr("dtype");

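// A minimal sketch of applying this pass from C++, assuming the graph already
// carries full "shape" and "dtype" attributes (e.g. produced by the InferShape
// and InferType passes); `net`, `shape_vec` and `dtype_vec` below are
// illustrative placeholders only:
//
//   nnvm::Graph g;
//   g.outputs = net.outputs;  // net is an nnvm::Symbol
//   g.attrs["shape"] = std::make_shared<nnvm::any>(std::move(shape_vec));
//   g.attrs["dtype"] = std::make_shared<nnvm::any>(std::move(dtype_vec));
//   g = nnvm::ApplyPass(std::move(g), "GraphFusePartition");
//   const auto& group_root = g.GetAttr<std::vector<int> >("group_root");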

// Decorate the result of PlanMemory.
// This function does two things:
// - Give separate memory to each variable.
// - Tie the memory of output/lhs in an assign node together properly,
//   so that executing the assign has the intended side effect.
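// assign_flag values (as produced by GraphFuseCompile below):
//   0 - not an assign node;
//   1 - assign node: the output shares storage with the lhs variable;
//   2 - assign node whose rhs is computed in place: the rhs entry also shares
//       storage with the lhs variable.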
nnvm::Graph DecorateMemoryPlan(
    nnvm::Graph g,
    const std::vector<int>& assign_flag) {
  const IndexedGraph& idx = g.indexed_graph();
  StorageVector storage_vec = g.MoveCopyAttr<StorageVector>("storage_id");
  g.attrs.erase("storage_allocated_bytes");
  g.attrs.erase("storage_inplace_index");
  size_t num_not_allocated = g.MoveCopyAttr<size_t>(
      "storage_num_not_allocated");
  CHECK_EQ(num_not_allocated, 0U)
      << "Can only build inference graph with all statically allocated memory";

  // reassign variable ids so that they are all different.
  int max_id = 0;
  for (size_t i = 0; i < storage_vec.size(); ++i) {
    max_id = std::max(storage_vec[i] + 1, max_id);
  }
  for (uint32_t nid : idx.input_nodes()) {
    storage_vec[idx.entry_id(nid, 0)] = max_id++;
  }
  // tie up the assign node storage properly
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    if (assign_flag[nid] == 0) continue;
    const auto& inode = idx[nid];
    int var_storage_id = storage_vec[idx.entry_id(inode.inputs[0])];
    storage_vec[idx.entry_id(nid, 0)] = var_storage_id;

    if (assign_flag[nid] == 2) {
      storage_vec[idx.entry_id(inode.inputs[1])] = var_storage_id;
    }
  }
  g.attrs["storage_id"] = std::make_shared<any>(std::move(storage_vec));
  return g;
}


struct INodeEntryHash {
  size_t operator()(const IndexedGraph::NodeEntry& e) const {
    return e.node_id;
  }
};

struct INodeEntryEqual {
  bool operator()(const IndexedGraph::NodeEntry& a,
                  const IndexedGraph::NodeEntry& b) const {
    return a.node_id == b.node_id && a.index == b.index;
  }
};
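// (Hashing by node_id alone is sufficient: entries of the same node fall into
//  the same bucket and INodeEntryEqual then distinguishes them by index.)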

// Auxiliary data structure for representing a fused op.
struct FuseEntry {
  // subgraph of the fragment
  Graph subgraph;
  // The input map
  std::unordered_map<IndexedGraph::NodeEntry, nnvm::NodeEntry,
                     INodeEntryHash, INodeEntryEqual> imap;
  // reverse map to the old input entry
  std::unordered_map<const Node*, IndexedGraph::NodeEntry> reverse_imap;
  // TVM placeholders for the inputs
  std::unordered_map<const Node*, Tensor> input_info;
  // Whether we can flatten the data
  bool flatten_data;
  // The corresponding compiled function.
  GraphFunc compiled_func;
};

// Fuse the partitioned graph into segments.
// Create a new graph with fused nodes.
// Also inherit the shape and dltype attributes from the previous graph.
nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
  // setup ref counter
  const IndexedGraph& idx = g.indexed_graph();
  // Get attributes from the graph
  const ShapeVector& shape_vec = g.GetAttr<ShapeVector>("shape");
  const DTypeVector& dtype_vec = g.GetAttr<DTypeVector>("dtype");
  const std::vector<int>& group_vec = g.GetAttr<std::vector<int> >("group_root");
  const std::vector<int>& master_vec = g.GetAttr<std::vector<int> >("group_master");
  const std::vector<TOpPattern>& pattern_vec =
      g.GetAttr<std::vector<TOpPattern> >("pattern");
  std::string target = g.GetAttr<std::string>("target");
  std::string target_host;

  if (g.HasAttr("target_host")) {
    target_host = g.GetAttr<std::string>("target_host");
  }
  // specially handle assign
  const nnvm::Op* assign_op = nnvm::Op::Get("_assign");

  std::vector<FuseEntry> fuse_vec(idx.num_nodes());
  // setup inputs and placeholders.
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    CHECK_GE(group_vec[nid], 0);
    int root_id = group_vec[nid];
    FuseEntry& fe = fuse_vec[root_id];
    fe.flatten_data = (pattern_vec[root_id] == kElemWise ||
                       inode.source->op() == assign_op);
    for (const auto& e : inode.inputs) {
      if (group_vec[e.node_id] != root_id && fe.imap.count(e) == 0) {
        Array<Expr> shape;
        if (fe.flatten_data) {
          // elementwise ops support flattened input
          int64_t prod = 1;
          for (int64_t x : shape_vec[idx.entry_id(e)]) {
            prod *= x;
          }
          CHECK_LE(prod, static_cast<int64_t>(std::numeric_limits<int>::max()));
          shape.push_back(make_const(Int(32), prod));
        } else {
          for (int64_t x : shape_vec[idx.entry_id(e)]) {
            CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
            shape.push_back(make_const(Int(32), x));
          }
        }
        std::ostringstream os_name;
        os_name << "input" << fe.imap.size();
        Tensor data = placeholder(
            shape, TVMType2Type(GetDLType(dtype_vec[idx.entry_id(e)])),
            os_name.str());
        NodeEntry garg = Symbol::CreateVariable(os_name.str()).outputs[0];
        fe.imap[e] = garg;
        fe.reverse_imap[garg.node.get()] = e;
        fe.input_info[garg.node.get()] = std::move(data);
      }
    }
  }
  // Setup the subgraphs
  std::vector<NodeEntry> subgraph_vec(idx.num_node_entries());
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    int root_id = group_vec[nid];
    FuseEntry& fe = fuse_vec[root_id];
    // copy attributes and create the subgraph node.
    NodePtr gnode = Node::Create();
    gnode->attrs = inode.source->attrs;
    // input loading
    for (const auto& e : inode.inputs) {
      if (group_vec[e.node_id] != root_id) {
        auto it = fe.imap.find(e);
        CHECK(it != fe.imap.end());
        gnode->inputs.push_back(it->second);
      } else {
        const NodeEntry& ne = subgraph_vec[idx.entry_id(e)];
        CHECK(!idx[e.node_id].source->is_variable());
        CHECK(ne.node != nullptr);
        gnode->inputs.push_back(ne);
      }
    }
    // For interior nodes, record the new entries so later nodes in the group
    // can find them; for the group root, its outputs become the subgraph outputs.
    if (static_cast<int>(nid) != root_id) {
      for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
        uint32_t eid = idx.entry_id(nid, index);
        subgraph_vec[eid] = NodeEntry{gnode, index, 0};
      }
    } else {
      for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
        fe.subgraph.outputs.push_back(NodeEntry{gnode, index, 0});
      }
    }
  }
  // Start lowering
  Array<tvm::LoweredFunc> func_list;
  std::unordered_set<const tvm::Node*> func_set;

  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    int root_id = group_vec[nid];
    if (static_cast<int>(nid) != root_id) continue;
    int master = master_vec[root_id];
    FuseEntry& fe = fuse_vec[root_id];

    const IndexedGraph& subidx = fe.subgraph.indexed_graph();
    CHECK_EQ(subidx.input_nodes().size(), fe.imap.size());
    CHECK_EQ(subidx.input_nodes().size(), fe.input_info.size());

    Array<Tensor> inputs;
    for (uint32_t sub_input_id : subidx.input_nodes()) {
      auto it = fe.input_info.find(subidx[sub_input_id].source);
      inputs.push_back(it->second);
    }
    // find master idx in subgraph
    int sub_master_idx = 0;
    for (uint32_t i = 0; i < subidx.num_nodes(); i++) {
      if (subidx[i].source->op() == idx[master].source->op()) {
        sub_master_idx = i;
        break;
      }
    }
    fe.compiled_func = GraphLower(fe.subgraph, inputs, target, sub_master_idx);
    for (LoweredFunc f : fe.compiled_func->funcs) {
      if (!func_set.count(f.get())) {
        func_set.insert(f.get());
        func_list.push_back(f);
      }
    }
  }

  const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");

  std::unordered_map<uint32_t, nnvm::NodePtr> old_new;
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) {
      // only copy over the name since that is sufficient.
      nnvm::NodePtr np = nnvm::Node::Create();
      np->attrs.name = inode.source->attrs.name;
      old_new[nid] = np;
      continue;
    }
    int root_id = group_vec[nid];
    if (static_cast<int>(nid) != root_id) continue;

    // Handle a normal op
    FuseEntry& fe = fuse_vec[root_id];
    const IndexedGraph& subidx = fe.subgraph.indexed_graph();
    nnvm::NodePtr np = nnvm::Node::Create();
    np->attrs.op = tvm_op;
    np->attrs.name = inode.source->attrs.name;
    TVMOpParam param;
    param.func_name = fe.compiled_func->func_name;
    param.num_inputs = static_cast<uint32_t>(fe.imap.size());
    param.num_outputs = static_cast<uint32_t>(fe.subgraph.outputs.size());
    param.flatten_data = fe.flatten_data;
    param.UpdateDict(&(np->attrs.dict));
    np->attrs.parsed = std::move(param);

    for (uint32_t sub_input_id : subidx.input_nodes()) {
      // Make sure the subgraph input order matches the order of the graph inputs
      auto rit = fe.reverse_imap.find(subidx[sub_input_id].source);
      CHECK(rit != fe.reverse_imap.end());
      const IndexedGraph::NodeEntry& e = rit->second;
      auto it = old_new.find(e.node_id);
      CHECK(it != old_new.end())
          << "cannot find node_id=" << e.node_id;
      np->inputs.emplace_back(
          nnvm::NodeEntry{it->second, e.index, e.version});
    }
    for (const uint32_t node_id : inode.control_deps) {
      auto it = old_new.find(node_id);
      CHECK(it != old_new.end());
      np->control_deps.emplace_back(it->second);
    }
    old_new[nid] = np;
  }
  nnvm::Graph ret;
  for (const auto& e : idx.outputs()) {
    auto it = old_new.find(group_vec[e.node_id]);
    CHECK(it != old_new.end())
        << "cannot find node_id=" << e.node_id;
    ret.outputs.emplace_back(
        nnvm::NodeEntry{it->second, e.index, e.version});
  }

  // Reference counter of each op node
  // For now, always store the result when an op is referred to more than once.
  std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
  for (const auto& e : idx.outputs()) {
    // this line will realize all the outputs
    ref_count[e.node_id] += 1;
  }

  const IndexedGraph& new_idx = ret.indexed_graph();

  // Handling assign:
  //
  // assign is a special operator that mutates the variable.
  // Currently assign is implemented as output = copy(input[1]).
  // Then we run DecorateMemoryPlan to force
  //   output.storage = input[0].storage
  //
  std::vector<int> assign_flag(new_idx.num_nodes(), 0);
  ShapeVector new_shape_vec = ShapeVector(new_idx.num_node_entries(), TShape());
  DTypeVector new_dtype_vec = DTypeVector(new_idx.num_node_entries());
  std::vector<std::string> new_dltype_vec(new_idx.num_node_entries());

  for (const auto& kv : old_new) {
    uint32_t nid = kv.first;
    const auto& inode = idx[nid];
    uint32_t new_nid = new_idx.node_id(kv.second.get());
    if (inode.source->op() == assign_op) {
      // Check if the rhs of the assign can be computed in place.
      // If yes, we can simply set that memory to be the assign target
      // and change the assign to a nop.
      const IndexedGraph::NodeEntry& rhs = inode.inputs[1];
      if (ref_count[rhs.node_id] <= 1 &&
          !(idx[rhs.node_id].source->is_variable()) &&
          pattern_vec[group_vec[rhs.node_id]] <= kBroadcast) {
        assign_flag[new_nid] = 2;
        TVMOpParam& param = dmlc::get<TVMOpParam>(kv.second->attrs.parsed);
        param.func_name = "__nop";
        param.UpdateDict(&(kv.second->attrs.dict));
      } else {
        assign_flag[new_nid] = 1;
      }
    }
    for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
      uint32_t new_eid = new_idx.entry_id(new_idx.node_id(kv.second.get()), i);
      uint32_t old_eid = idx.entry_id(nid, i);
      new_shape_vec[new_eid] = shape_vec[old_eid];
      new_dtype_vec[new_eid] = dtype_vec[old_eid];
      new_dltype_vec[new_eid] = tvm::runtime::TVMType2String(
          GetDLType(dtype_vec[old_eid]));
    }
  }
  ret.attrs["shape"] = std::make_shared<any>(std::move(new_shape_vec));
  ret.attrs["dtype"] = std::make_shared<any>(std::move(new_dtype_vec));
  ret.attrs["dltype"] = std::make_shared<any>(std::move(new_dltype_vec));
  // Setup module
  static const PackedFunc& fbuild = GetPackedFunc("nnvm.compiler.build_target");
  tvm::runtime::Module module = fbuild(func_list, target, target_host);
  ret.attrs["module"] = std::make_shared<any>(std::move(module));
  ret = nnvm::ApplyPass(ret, "PlanMemory");
  ret = DecorateMemoryPlan(ret, assign_flag);
  return ret;
}

NNVM_REGISTER_PASS(GraphFuseCompile)
.set_body(GraphFuseCompile);
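// Note on the overall flow (a sketch of the intended pass ordering, not
// enforced here): after shape/dtype inference, GraphFusePartition assigns the
// "group_root", "group_master" and "pattern" attributes; GraphFuseCompile then
// lowers each group into a tvm_op node, builds the "module", and runs
// PlanMemory followed by DecorateMemoryPlan on the fused graph.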

}  // namespace compiler
}  // namespace nnvm