a guest, Oct 17th, 2019
class Transpose:
    """Transpose an image tensor.

    Axes that appear only in ``output_order`` are added with length 1; for
    axes that appear only in ``input_order``, the first element is kept.

    Args:
        input_order (str): The order of the axes of the input image tensor,
            described by a combination of batch (B), height (H), width (W),
            and channel (C), e.g. ``'BCHW'``.
        output_order (str): The order of the axes of the output image tensor.
    """
    def __init__(self, input_order, output_order):
        self.input_order = input_order
        self.output_order = output_order

        # Axes to create (output only) and axes to drop (input only).
        # Sorting only makes the stored permutation reproducible.
        axis_add = ''.join(sorted(set(output_order) - set(input_order)))
        axis_remove = ''.join(sorted(set(input_order) - set(output_order)))

        # New axes are prepended to the input, and axes to drop are moved to
        # the front of the output, so both can be handled by leading indices.
        self._axes = [(axis_add + input_order).index(a)
                      for a in (axis_remove + output_order)]
        self._axis_add = (None,) * len(axis_add)
        self._axis_remove = (0,) * len(axis_remove)

    def __call__(self, a):
        if self._axis_add:
            # Indexing with None prepends one length-1 axis per added label.
            a = a[self._axis_add]

        out = a.transpose(self._axes)

        if self._axis_remove:
            # Keep index 0 along each leading axis that is being dropped.
            out = out[self._axis_remove]
        return out


def transpose(a, input_order, output_order):
    """Functional shortcut for ``Transpose(input_order, output_order)(a)``."""
    return Transpose(input_order, output_order)(a)
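
A minimal usage sketch, assuming the input is a NumPy array (the paste does not name an array library; anything with an ndarray-style `transpose` and `None` indexing should behave similarly). The function below is a condensed restatement of the `Transpose` logic so the example runs on its own, and the 4x6 RGB image is a made-up test input.

```python
import numpy as np


def transpose(a, input_order, output_order):
    """Condensed restatement of the Transpose logic, for a runnable demo."""
    axis_add = ''.join(sorted(set(output_order) - set(input_order)))
    axis_remove = ''.join(sorted(set(input_order) - set(output_order)))
    axes = tuple((axis_add + input_order).index(x)
                 for x in (axis_remove + output_order))
    if axis_add:
        a = a[(None,) * len(axis_add)]      # prepend length-1 axes
    out = a.transpose(axes)
    if axis_remove:
        out = out[(0,) * len(axis_remove)]  # keep index 0 of each dropped axis
    return out


# Hypothetical 4x6 RGB image in height-width-channel layout.
image = np.arange(4 * 6 * 3).reshape(4, 6, 3)

# HWC -> BCHW: the missing batch axis is added with length 1.
batched = transpose(image, 'HWC', 'BCHW')
print(batched.shape)  # (1, 3, 4, 6)

# BCHW -> HWC: the batch axis is dropped by taking its first element.
restored = transpose(batched, 'BCHW', 'HWC')
print(restored.shape)  # (4, 6, 3)
print(np.array_equal(restored, image))  # True
```

Dropping an axis keeps only its first element, so converting a batch of several images from `'BCHW'` to `'HWC'` silently discards all but the first image.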