Untitled (a guest paste, Feb 17th, 2022)

import torch
import torch.nn as nn
from torch.utils._pytree import tree_map, tree_flatten
from typing import List, Any
from numbers import Number
from collections import defaultdict

aten = torch.ops.aten

def get_shape(i):
    return i.shape

def prod(x):
    res = 1
    for i in x:
        res *= i
    return res

def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2, containing the two matrices
    # being multiplied (their shapes are read off below).
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    flop = prod(input_shapes[0]) * input_shapes[-1][-1]
    return flop

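# Worked example (a sanity check, not part of the original paste): an
# (8, 64) @ (64, 32) mm gives prod((8, 64)) * 32 = 8 * 64 * 32 = 16,384
# multiply-adds.
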
def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
    # Count flops for nn.Linear: aten.addmm(bias, mat1, mat2), so
    # inputs is a list of length 3 and only mat1 and mat2 are needed.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops

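# Worked example (a sanity check, not part of the original paste): nn.Linear(128, 64)
# applied to a batch of 32 dispatches an addmm with mat1 of shape (32, 128) and
# mat2 of shape (128, 64), i.e. 32 * 128 * 64 = 262,144 multiply-adds.
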
def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2, containing the two batched
    # tensors being multiplied.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [get_shape(v) for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flop = n * c * t * d
    return flop

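# Worked example (a sanity check, not part of the original paste): bmm of
# (16, 128, 64) with (16, 64, 32) gives 16 * 128 * 64 * 32 = 4,194,304
# multiply-adds.
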
def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Number:
    """
    Count flops for convolution. Note that only multiplications are
    counted; additions and the bias term are ignored.
    Flops for a transposed convolution are calculated as
    flops = batch_size * prod(w_shape) * prod(x_shape[2:]).
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): Whether the convolution is transposed.
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flop = batch_size * prod(w_shape) * prod(conv_shape)
    return flop

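# Worked example (a sanity check, not part of the original paste): ResNet-18's
# stem conv has x (1, 3, 224, 224), w (64, 3, 7, 7) and out (1, 64, 112, 112),
# so 1 * prod((64, 3, 7, 7)) * prod((112, 112)) = 9,408 * 12,544 = 118,013,952
# multiply-adds, i.e. about 0.118 in the 1e9 units printed by display_flops().
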
def conv_flop_jit(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    # aten.convolution(input, weight, bias, stride, padding, dilation,
    #                  transposed, output_padding, groups)
    transposed = inputs[6]

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)

def transpose_shape(shape):
    # Swap the first two dimensions (e.g. batch and channels).
    return [shape[1], shape[0]] + list(shape[2:])

def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]):
    # aten.convolution_backward(grad_output, input, weight, bias_sizes, stride,
    #                           padding, dilation, transposed, output_padding,
    #                           groups, output_mask)
    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0

    if output_mask[0]:
        grad_input_shape = get_shape(outputs[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
    if output_mask[1]:
        grad_weight_shape = get_shape(outputs[1])
        flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)

    return flop_count

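# Note (not part of the original paste): grad_input is itself a convolution of
# grad_output with the weights (transposed relative to the forward pass), and
# grad_weight is a convolution of the batch-transposed input with grad_output,
# which is why both terms reuse conv_flop_count. For a standard (non-transposed)
# convolution, each backward term works out to the same number of multiply-adds
# as the forward pass.
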
flop_mapping = {
    aten.mm: matmul_flop_jit,
    aten.matmul: matmul_flop_jit,
    aten.addmm: addmm_flop_jit,
    aten.bmm: bmm_flop_jit,
    aten.convolution: conv_flop_jit,
    aten._convolution: conv_flop_jit,
    aten.convolution_backward: conv_backward_flop_jit,
}

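# Extending the mapping (a sketch, not part of the original paste): any other
# aten op can be counted by registering a handler with the same
# (inputs, outputs) signature. For example, a matrix-vector product
# (M, N) @ (N,) costs M * N multiply-adds:
def mv_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    return prod(input_shapes[0])
# flop_mapping[aten.mv] = mv_flop_jit  # uncomment to also count aten.mv
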
flop_counts = defaultdict(lambda: defaultdict(int))
parents = ['Global']

def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x

class FlopTensor(torch.Tensor):
    elem: torch.Tensor

    __slots__ = ['elem']

    @staticmethod
    def __new__(cls, elem):
        # The wrapping tensor (FlopTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=elem.requires_grad
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem
        return r

    def __repr__(self):
        if self.grad_fn:
            return f"FlopTensor({self.elem}, grad_fn={self.grad_fn})"
        return f"FlopTensor({self.elem})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            return e.elem if isinstance(e, FlopTensor) else e

        # no_dispatch is only needed if you use enable_python_mode.
        # It prevents infinite recursion.
        rs = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        outs = normalize_tuple(rs)

        if func in flop_mapping:
            global flop_counts
            flop_count = flop_mapping[func](args, outs)
            for par in parents:
                flop_counts[par][func.__name__] += flop_count

        def wrap(e):
            return FlopTensor(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, rs)
        return rs

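# Minimal sanity check (a sketch, not in the original paste): wrapping plain
# tensors routes every aten call through __torch_dispatch__ above, so a single
# (4, 8) @ (8, 16) mm should be recorded as 4 * 8 * 16 = 512 flops under 'Global':
#
#     a = FlopTensor(torch.randn(4, 8))
#     b = FlopTensor(torch.randn(8, 16))
#     torch.mm(a, b)
#     print(dict(flop_counts['Global']))  # expect something like {'mm': 512}
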
def create_backwards_push(name):
    class PushState(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
            if len(args) == 1:
                return args[0]
            return args

        @staticmethod
        def backward(ctx, *grad_outs):
            global parents
            parents.append(name)
            return grad_outs

    return PushState.apply

def create_backwards_pop(name):
    class PopState(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
            if len(args) == 1:
                return args[0]
            return args

        @staticmethod
        def backward(ctx, *grad_outs):
            global parents
            assert parents[-1] == name
            parents.pop()
            return grad_outs

    return PopState.apply

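# How the scoping works (comment added for clarity, not in the original paste):
# enter_module/exit_module below maintain the `parents` stack during the forward
# pass, while PushState/PopState are autograd Functions spliced into the graph
# so the same stack is rebuilt in reverse during .backward(). Entering a module
# in forward plants a PopState on its inputs (backward leaves the scope when
# gradients flow out through the inputs), and exiting it plants a PushState on
# its outputs (backward re-enters the scope when gradients reach the outputs).
# Backward flops are therefore attributed to the same module names as forward flops.
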
def enter_module(name):
    def f(module, inputs):
        global parents
        parents.append(name)
        inputs = normalize_tuple(inputs)
        out = create_backwards_pop(name)(*inputs)
        return out

    return f

def exit_module(name):
    def f(module, inputs, outputs):
        global parents
        assert parents[-1] == name
        parents.pop()
        outputs = normalize_tuple(outputs)
        return create_backwards_push(name)(*outputs)
    return f

def instrument_module(mod):
    # Only direct children of `mod` are instrumented, so flops are attributed
    # one level deep (e.g. 'layer1' rather than 'layer1.0.conv1').
    for name, module in dict(mod.named_children()).items():
        module.register_forward_pre_hook(enter_module(name))
        module.register_forward_hook(exit_module(name))

def start_counting():
    global parents, flop_counts
    parents = ['Global']
    flop_counts.clear()

def display_flops():
    for mod in flop_counts.keys():
        print(f"Module: {mod}")
        for k, v in flop_counts[mod].items():
            # counts are accumulated as multiply-adds; print in units of 1e9
            print(k, v / 1e9)
        print()

# Example: count forward + backward flops for ResNet-18 on a single 224x224 image.
import torchvision.models as models

mod = models.resnet18().cuda()
instrument_module(mod)

inp = torch.randn(1, 3, 224, 224, device='cuda')
mod(FlopTensor(inp)).sum().backward()

display_flops()
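
# With the instrumentation above, display_flops() should print one block per
# scope that executed a counted op: 'Global' plus children such as 'conv1',
# 'layer1' ... 'layer4' and 'fc', each listing per-aten-op counts (in units of
# 1e9) covering both the forward and the backward pass.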