Advertisement
Guest User

Untitled

a guest
Oct 17th, 2019
136
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.15 KB | None | 0 0
  1. class Transpose:
  2. """Transpose image tensor.
  3.  
  4. Args:
  5. input_order (str): The order of the axes of the input image tensor.
  6. Describe by a combination of batch (B), height (H), width (W), and
  7. channel (C) like ``'BCHW'``.
  8. output_order (str): The order of the axes of the output image tensor.
  9. """
  10. def __init__(self, input_order, output_order):
  11. self.input_order = input_order
  12. self.output_order = output_order
  13.  
  14. axis_add = ''.join(set(output_order) - set(input_order))
  15. axis_remove = ''.join(set(input_order) - set(output_order))
  16.  
  17. self._axes = [(axis_add + input_order).index(a)
  18. for a in (axis_remove + output_order)]
  19. self._axis_add = tuple([None for i in range(len(axis_add))])
  20. self._axis_remove = tuple([0 for i in range(len(axis_remove))])
  21.  
  22. def __call__(self, a):
  23. if self._axis_add:
  24. a = a[self._axis_add]
  25.  
  26. out = a.transpose(self._axes)
  27.  
  28. if self._axis_remove:
  29. out = out[self._axis_remove]
  30. return out
  31.  
  32.  
  33. def transpose(a, input_order, output_order):
  34. return Transpose(input_order, output_order)(a)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement