Guest User

Untitled

a guest
Mar 18th, 2018
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.37 KB | None | 0 0
  1. .ShapeInferenceFunction([](GraphProto* g, NodeProto* n) {
  2. std::vector<int64_t> perm;
  3. for (int i = 0; i < n->attribute_size(); i++) {
  4. auto attr = n->attribute(i);
  5. if (attr.has_name() && attr.name() == "perm") {
  6. for (int j = 0; j < attr.ints_size(); j++) {
  7. perm.push_back(attr.ints(j));
  8. }
  9. }
  10. }
  11.  
  12. TensorProto::DataType dataType;
  13. std::vector<int64_t> dims;
  14.  
  15. {
  16. auto inp = n->input(0);
  17. for (int j = 0; j < g->value_info_size(); j++) {
  18. auto vi = g->value_info(j);
  19. if (vi.has_name() && vi.name() == inp) {
  20. auto tt = vi.type().tensor_type();
  21. dataType = tt.elem_type();
  22. auto shape = tt.shape();
  23. for (int k = 0; k < shape.dim_size(); k++) {
  24. auto dim = shape.dim(k);
  25. if (dim.has_dim_value()) {
  26. dims.push_back(dim.dim_value());
  27. } else {
  28. return;
  29. }
  30. }
  31. }
  32. }
  33. for (int j = 0; j < g->input_size(); j++) {
  34. auto vi = g->input(j);
  35. if (vi.has_name() && vi.name() == inp) {
  36. auto tt = vi.type().tensor_type();
  37. dataType = tt.elem_type();
  38. auto shape = tt.shape();
  39. for (int k = 0; k < shape.dim_size(); k++) {
  40. auto dim = shape.dim(k);
  41. if (dim.has_dim_value()) {
  42. dims.push_back(dim.dim_value());
  43. } else {
  44. return;
  45. }
  46. }
  47. }
  48. }
  49. for (int j = 0; j < g->output_size(); j++) {
  50. auto vi = g->output(j);
  51. if (vi.has_name() && vi.name() == inp) {
  52. auto tt = vi.type().tensor_type();
  53. dataType = tt.elem_type();
  54. auto shape = tt.shape();
  55. for (int k = 0; k < shape.dim_size(); k++) {
  56. auto dim = shape.dim(k);
  57. if (dim.has_dim_value()) {
  58. dims.push_back(dim.dim_value());
  59. } else {
  60. return;
  61. }
  62. }
  63. }
  64. }
  65. }
  66.  
  67. for (int i = 0; i < n->output_size(); i++) {
  68. auto out = n->output(i);
  69. auto vi = g->add_value_info();
  70. vi->set_name(out);
  71. auto tt = vi->mutable_type()->mutable_tensor_type();
  72. tt->set_elem_type(dataType);
  73. auto shape = tt->mutable_shape();
  74. for (int j = 0; j < perm.size(); j++) {
  75. auto dim = shape->add_dim();
  76. dim->set_dim_value(dims[perm[j]]);
  77. }
  78. }
  79.  
  80. })
Add Comment
Please, Sign In to add comment