Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class Transpose:
    """Transpose image tensor.

    Axes that appear only in ``output_order`` are inserted as new leading
    axes (by indexing with ``None``) before the permutation. Axes that
    appear only in ``input_order`` are moved to the front by the permutation
    and then dropped by taking index ``0`` along each of them, so any other
    entries along a removed axis are discarded.

    Args:
        input_order (str): The order of the axes of the input image tensor.
            Described by a combination of batch (B), height (H), width (W),
            and channel (C) like ``'BCHW'``.
        output_order (str): The order of the axes of the output image tensor.

    Raises:
        ValueError: If either order names the same axis more than once.
    """

    def __init__(self, input_order, output_order):
        # A repeated letter cannot describe a valid permutation; fail here
        # with a clear message instead of later inside ``transpose``.
        for label, order in (('input_order', input_order),
                             ('output_order', output_order)):
            if len(set(order)) != len(order):
                raise ValueError(
                    '{} must not repeat an axis: {!r}'.format(label, order))

        self.input_order = input_order
        self.output_order = output_order

        # Keep first-appearance order (rather than set iteration order) so
        # the computed permutation is identical on every interpreter run.
        axis_add = ''.join(
            axis for axis in output_order if axis not in input_order)
        axis_remove = ''.join(
            axis for axis in input_order if axis not in output_order)

        # Layout of the array once the new leading axes have been inserted.
        expanded_order = axis_add + input_order
        # Removed axes go first so one leading index tuple can drop them.
        self._axes = [expanded_order.index(axis)
                      for axis in axis_remove + output_order]
        self._axis_add = (None,) * len(axis_add)
        self._axis_remove = (0,) * len(axis_remove)

    def __call__(self, a):
        """Return ``a`` with its axes rearranged into ``output_order``.

        Args:
            a: Array laid out as ``input_order``. It must support indexing
                with a tuple of ``None``/``0`` and a ``transpose(axes)``
                method (presumably a NumPy-style array; confirm with
                callers).

        Returns:
            The rearranged array, as returned by ``a.transpose`` and the
            subsequent indexing.
        """
        if self._axis_add:
            # Each ``None`` inserts one new leading axis.
            a = a[self._axis_add]
        out = a.transpose(self._axes)
        if self._axis_remove:
            # Select index 0 along every leading axis that must disappear.
            out = out[self._axis_remove]
        return out
def transpose(a, input_order, output_order):
    """Rearrange the axes of ``a`` in a single call.

    Convenience wrapper that builds a ``Transpose`` from the two axis
    orders and immediately applies it to ``a``.

    Args:
        a: The array to rearrange, laid out as ``input_order``.
        input_order (str): Axis order of ``a``, e.g. ``'BCHW'``.
        output_order (str): Axis order wanted for the result.

    Returns:
        Whatever the constructed ``Transpose`` returns for ``a``.
    """
    reorder = Transpose(input_order, output_order)
    return reorder(a)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement