Not a member of Pastebin yet?
Sign up: it unlocks many cool features!
- .ShapeInferenceFunction([](GraphProto* g, NodeProto* n) {
- std::vector<int64_t> perm;
- for (int i = 0; i < n->attribute_size(); i++) {
- auto attr = n->attribute(i);
- if (attr.has_name() && attr.name() == "perm") {
- for (int j = 0; j < attr.ints_size(); j++) {
- perm.push_back(attr.ints(j));
- }
- }
- }
- TensorProto::DataType dataType;
- std::vector<int64_t> dims;
- {
- auto inp = n->input(0);
- for (int j = 0; j < g->value_info_size(); j++) {
- auto vi = g->value_info(j);
- if (vi.has_name() && vi.name() == inp) {
- auto tt = vi.type().tensor_type();
- dataType = tt.elem_type();
- auto shape = tt.shape();
- for (int k = 0; k < shape.dim_size(); k++) {
- auto dim = shape.dim(k);
- if (dim.has_dim_value()) {
- dims.push_back(dim.dim_value());
- } else {
- return;
- }
- }
- }
- }
- for (int j = 0; j < g->input_size(); j++) {
- auto vi = g->input(j);
- if (vi.has_name() && vi.name() == inp) {
- auto tt = vi.type().tensor_type();
- dataType = tt.elem_type();
- auto shape = tt.shape();
- for (int k = 0; k < shape.dim_size(); k++) {
- auto dim = shape.dim(k);
- if (dim.has_dim_value()) {
- dims.push_back(dim.dim_value());
- } else {
- return;
- }
- }
- }
- }
- for (int j = 0; j < g->output_size(); j++) {
- auto vi = g->output(j);
- if (vi.has_name() && vi.name() == inp) {
- auto tt = vi.type().tensor_type();
- dataType = tt.elem_type();
- auto shape = tt.shape();
- for (int k = 0; k < shape.dim_size(); k++) {
- auto dim = shape.dim(k);
- if (dim.has_dim_value()) {
- dims.push_back(dim.dim_value());
- } else {
- return;
- }
- }
- }
- }
- }
- for (int i = 0; i < n->output_size(); i++) {
- auto out = n->output(i);
- auto vi = g->add_value_info();
- vi->set_name(out);
- auto tt = vi->mutable_type()->mutable_tensor_type();
- tt->set_elem_type(dataType);
- auto shape = tt->mutable_shape();
- for (int j = 0; j < perm.size(); j++) {
- auto dim = shape->add_dim();
- dim->set_dim_value(dims[perm[j]]);
- }
- }
- })
Add Comment
Please sign in to add a comment.