import torch
import torch.nn as nn
from torch.utils._pytree import tree_map, tree_flatten
from typing import List, Any
from numbers import Number
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten

def get_shape(i):
    return i.shape

def prod(x):
    res = 1
    for i in x:
        res *= i
    return res

def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    flop = prod(input_shapes[0]) * input_shapes[-1][-1]
    return flop

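# Quick sanity check (illustrative addition, not in the original paste): an (M, K) @ (K, N)
# matmul is counted as M * K * N flops; meta tensors carry shapes without doing real compute.
_a = torch.empty(8, 64, device="meta")
_b = torch.empty(64, 32, device="meta")
assert matmul_flop([_a, _b], []) == 8 * 64 * 32
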
def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
    # Count flops for nn.Linear.
    # inputs is a list of length 3.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops

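# Illustrative check (assumed shapes, not in the original paste): a linear layer mapping
# 64 -> 32 features on a batch of 8 is counted as 8 * 64 * 32 flops.
_bias = torch.empty(32, device="meta")
_x = torch.empty(8, 64, device="meta")
_w = torch.empty(64, 32, device="meta")
assert addmm_flop([_bias, _x, _w], []) == 8 * 64 * 32
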
def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [get_shape(v) for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flop = n * c * t * d
    return flop

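# Illustrative check (assumed shapes): a batched matmul of (4, 16, 8) @ (4, 8, 32)
# is counted as 4 * 16 * 8 * 32 flops.
_p = torch.empty(4, 16, 8, device="meta")
_q = torch.empty(4, 8, 32, device="meta")
assert bmm_flop([_p, _q], []) == 4 * 16 * 8 * 32
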
def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Number:
    """
    Count flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = batch_size * prod(w_shape) * prod(x_shape[2:]).
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flop = batch_size * prod(w_shape) * prod(conv_shape)
    return flop

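# Illustrative check (assumed shapes): a 3x3 conv with 3 input / 16 output channels
# producing a 2x16x32x32 output costs batch * prod(w_shape) * prod(out_spatial) flops.
assert conv_flop_count([2, 3, 34, 34], [16, 3, 3, 3], [2, 16, 32, 32]) == 2 * (16 * 3 * 3 * 3) * (32 * 32)
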
def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    transposed = inputs[6]

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)

def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])

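# Quick check (illustrative): transpose_shape only swaps the first two dimensions.
assert transpose_shape([8, 3, 224, 224]) == [3, 8, 224, 224]
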
def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0

    if output_mask[0]:
        # The gradient w.r.t. the input is itself a convolution of grad_out with the weights.
        grad_input_shape = get_shape(outputs[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
    if output_mask[1]:
        # The gradient w.r.t. the weights is a convolution of the (channel-transposed) input with grad_out.
        grad_weight_shape = get_shape(outputs[1])
        flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)

    return flop_count

flop_mapping = {
    aten.mm: matmul_flop,
    aten.matmul: matmul_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
}

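# Hedged extension example (not in the original paste): further ops can be mapped to a
# counting rule in the same way, e.g. baddbmm reusing bmm_flop on its two batched operands
# (the elementwise add is ignored, consistent with the rules above).
flop_mapping[aten.baddbmm] = lambda inputs, outputs: bmm_flop(inputs[1:3], outputs)
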
def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x

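# Illustrative check: single values are wrapped into a 1-tuple, tuples pass through unchanged.
assert normalize_tuple(1) == (1,)
assert normalize_tuple((1, 2)) == (1, 2)
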
class FlopCounterMode(TorchDispatchMode):
    def __init__(self, module=None):
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ['Global']
        if module is not None:
            # Hook every direct child so flops can be attributed per submodule.
            for name, child in dict(module.named_children()).items():
                child.register_forward_pre_hook(self.enter_module(name))
                child.register_forward_hook(self.exit_module(name))

    def enter_module(self, name):
        def f(module, inputs):
            self.parents.append(name)
            inputs = normalize_tuple(inputs)
            out = self.create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(self, name):
        def f(module, inputs, outputs):
            assert self.parents[-1] == name
            self.parents.pop()
            outputs = normalize_tuple(outputs)
            return self.create_backwards_push(name)(*outputs)
        return f

    def create_backwards_push(self, name):
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(self, name):
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()
        return self

    def __exit__(self, *args):
        print(f"Total: {sum(self.flop_counts['Global'].values()) / 1e9} GFLOPS")
        for mod in self.flop_counts.keys():
            print(f"Module: {mod}")
            for k, v in self.flop_counts[mod].items():
                print(f"{k}: {v / 1e9} GFLOPS")
            print()
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in flop_mapping:
            flop_count = flop_mapping[func_packet](args, normalize_tuple(out))
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count

        return out

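# Hedged usage note (assumption, not in the original paste): besides the printout in
# __exit__, the aggregated counts remain available on the instance after the context exits,
# keyed first by parent name ('Global' or a child module name) and then by overload packet:
#   total = sum(flop_counter.flop_counts['Global'].values())
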
import torchvision.models as models

inp = torch.randn(8, 3, 224, 224, device='cuda')
mod = models.resnet18().cuda()
flop_counter = FlopCounterMode(mod)
with flop_counter:
    mod(inp).sum().backward()

with flop_counter:
    mod(inp).sum().backward()
exit(0)

from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._subclasses import FakeTensorMode
shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)

with fake_mode:
    inp = fake_mode.from_tensor(inp)
    assert inp.shape[0] == 1
    mod = models.resnet18()
    flop_counter = FlopCounterMode(mod)
    with flop_counter:
        with torch.no_grad():
            mod(inp)