Guest User

Untitled

a guest
Nov 22nd, 2017
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 17.04 KB | None | 0 0
  1. import math
  2. import torch
  3. import torch.nn as nn
  4. import torchvision
  5.  
  6. REGISTERED_OUTPUT_SHAPE_TYPES = []
  7.  
  8.  
  9. def compute_type(type):
  10. def _wrap(func):
  11. REGISTERED_OUTPUT_SHAPE_TYPES.append((type, func))
  12. return func
  13. return _wrap
  14.  
  15.  
  16. class OutputShapeFor(object):
  17.  
  18. math = math # for hacking in sympy
  19.  
  20. def __init__(self, module):
  21. self.module = module
  22. self._func = getattr(module, 'output_shape_for', None)
  23. if self._func is None:
  24. # Lookup shape func if we can't find it
  25. for type, _func in REGISTERED_OUTPUT_SHAPE_TYPES:
  26. try:
  27. if module is type or isinstance(module, type):
  28. self._func = _func
  29. except TypeError:
  30. pass
  31. if not self._func:
  32. raise TypeError('Unknown module type {}'.format(module))
  33.  
  34. def __call__(self, *args, **kwargs):
  35. if isinstance(self.module, nn.Module):
  36. # bound methods dont need module
  37. is_bound = hasattr(self._func, '__func__') and getattr(self._func, '__func__', None) is not None
  38. is_bound |= hasattr(self._func, 'im_func') and getattr(self._func, 'im_func', None) is not None
  39. if is_bound:
  40. output_shape = self._func(*args, **kwargs)
  41. else:
  42. # nn.Module with state
  43. output_shape = self._func(self.module, *args, **kwargs)
  44. else:
  45. # a simple pytorch func
  46. output_shape = self._func(*args, **kwargs)
  47. return output_shape
  48.  
  49. @staticmethod
  50. @compute_type(nn.UpsamplingBilinear2d)
  51. def UpsamplingBilinear2d(module, input_shape):
  52. """
  53. - Input: :math:`(N, C, H_{in}, W_{in})`
  54. - Output: :math:`(N, C, H_{out}, W_{out})` where
  55. :math:`H_{out} = floor(H_{in} * scale\_factor)`
  56. :math:`W_{out} = floor(W_{in} * scale\_factor)`
  57.  
  58. Example:
  59. >>> from pysseg.torch.models.output_shape_for import *
  60. >>> input_shape = (1, 3, 256, 256)
  61. >>> module = nn.UpsamplingBilinear2d(scale_factor=2)
  62. >>> output_shape = OutputShapeFor(module)(input_shape)
  63. >>> print('output_shape = {!r}'.format(output_shape))
  64. output_shape = (1, 3, 512, 512)
  65. """
  66. math = OutputShapeFor.math
  67. (N, C, H_in, W_in) = input_shape
  68. H_out = math.floor(H_in * module.scale_factor)
  69. W_out = math.floor(W_in * module.scale_factor)
  70. output_shape = (N, C, H_out, W_out)
  71. return output_shape
  72.  
  73. @staticmethod
  74. @compute_type(nn.ConvTranspose2d)
  75. def conv2dT(module, input_shape):
  76. """
  77. - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
  78. - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
  79. :math:`H_{out} = (H_{in} - 1) * stride[0] - 2 * padding[0] + kernel\_size[0] + output\_padding[0]`
  80. :math:`W_{out} = (W_{in} - 1) * stride[1] - 2 * padding[1] + kernel\_size[1] + output\_padding[1]`
  81.  
  82. Example:
  83. >>> from pysseg.torch.models.output_shape_for import *
  84. >>> input_shape = (1, 3, 256, 256)
  85. >>> module = nn.ConvTranspose2d(input_shape[1], 11, kernel_size=2, stride=2)
  86. >>> output_shape = OutputShapeFor(module)(input_shape)
  87. >>> print('output_shape = {!r}'.format(output_shape))
  88. output_shape = (1, 11, 512, 512)
  89. """
  90. (N, C_in, H_in, W_in) = input_shape
  91. C_out = module.out_channels
  92. stride = module.stride
  93. kernel_size = module.kernel_size
  94. output_padding = module.output_padding
  95. padding = module.padding
  96. H_out = (H_in - 1) * stride[0] - 2 * padding[0] + kernel_size[0] + output_padding[0]
  97. W_out = (W_in - 1) * stride[1] - 2 * padding[1] + kernel_size[1] + output_padding[1]
  98. output_shape = (N, C_out, H_out, W_out)
  99. return output_shape
  100.  
  101. @staticmethod
  102. @compute_type(nn.Conv2d)
  103. def conv2d(module, input_shape):
  104. """
  105. Notes:
  106. - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
  107. - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
  108. :math:`H_{out} = floor((H_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
  109. :math:`W_{out} = floor((W_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
  110.  
  111. Example:
  112. >>> from pysseg.torch.models.output_shape_for import *
  113. >>> input_shape = (1, 3, 256, 256)
  114. >>> module = nn.Conv2d(input_shape[1], 11, 3, 1, 0)
  115. >>> output_shape = OutputShapeFor(module)(input_shape)
  116. >>> print('output_shape = {!r}'.format(output_shape))
  117. output_shape = (1, 11, 254, 254)
  118. """
  119. math = OutputShapeFor.math
  120. N, C_in, H_in, W_in = input_shape
  121. C_out = module.out_channels
  122. padding = module.padding
  123. stride = module.stride
  124. dilation = module.dilation
  125. kernel_size = module.kernel_size
  126. H_out = math.floor((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
  127. W_out = math.floor((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
  128. output_shape = (N, C_out, H_out, W_out)
  129. return output_shape
  130.  
  131. @staticmethod
  132. @compute_type(nn.Conv3d)
  133. def conv3d(module, input_shape):
  134. """
  135. Notes:
  136. - Input: :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`
  137. - Output: :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})` where
  138. :math:`D_{out} = floor((D_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
  139. :math:`H_{out} = floor((H_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
  140. :math:`W_{out} = floor((W_{in} + 2 * padding[2] - dilation[2] * (kernel\_size[2] - 1) - 1) / stride[2] + 1)`
  141.  
  142. Example:
  143. >>> from pysseg.torch.models.output_shape_for import *
  144. >>> input_shape = (1, 3, 25, 32, 32)
  145. >>> module = nn.Conv3d(in_channels=input_shape[1], out_channels=11,
  146. >>> kernel_size=(3, 3, 3), stride=1, padding=0,
  147. >>> dilation=1, groups=1, bias=True)
  148. >>> output_shape = OutputShapeFor(module)(input_shape)
  149. >>> print('output_shape = {!r}'.format(output_shape))
  150. output_shape = (1, 11, 23, 30, 30)
  151. """
  152. math = OutputShapeFor.math
  153. N, C_in, D_in, H_in, W_in = input_shape
  154. C_out = module.out_channels
  155. padding = module.padding
  156. stride = module.stride
  157. dilation = module.dilation
  158. kernel_size = module.kernel_size
  159.  
  160. D_out = math.floor((D_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
  161. H_out = math.floor((H_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
  162. W_out = math.floor((W_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
  163.  
  164. output_shape = (N, C_out, D_out, H_out, W_out)
  165. return output_shape
  166.  
  167. @staticmethod
  168. @compute_type(nn.Conv3d)
  169. def max_pool3d(module, input_shape):
  170. """
  171. - Input: :math:`(N, C, D_{in}, H_{in}, W_{in})`
  172. - Output: :math:`(N, C, D_{out}, H_{out}, W_{out})` where
  173. :math:`D_{out} = floor((D_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
  174. :math:`H_{out} = floor((H_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
  175. :math:`W_{out} = floor((W_{in} + 2 * padding[2] - dilation[2] * (kernel\_size[2] - 1) - 1) / stride[2] + 1)`
  176. """
  177.  
  178. math = OutputShapeFor.math
  179. N, C_in, D_in, H_in, W_in = input_shape
  180. padding = module.padding
  181. stride = module.stride
  182. dilation = module.dilation
  183. kernel_size = module.kernel_size
  184.  
  185. D_out = math.floor((D_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
  186. H_out = math.floor((H_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
  187. W_out = math.floor((W_in + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / stride[2] + 1)
  188.  
  189. output_shape = (N, C_in, D_out, H_out, W_out)
  190. return output_shape
  191.  
  192. @staticmethod
  193. @compute_type(torch.cat)
  194. def cat(input_shapes, dim=0):
  195. """
  196.  
  197. Example:
  198. >>> from pysseg.torch.models.output_shape_for import *
  199. >>> input_shape1 = (1, 3, 256, 256)
  200. >>> input_shape2 = (1, 4, 256, 256)
  201. >>> input_shapes = [input_shape1, input_shape2]
  202. >>> output_shape = OutputShapeFor(torch.cat)(input_shapes, dim=1)
  203. >>> print('output_shape = {!r}'.format(output_shape))
  204. output_shape = [1, 7, 256, 256]
  205. """
  206. n_dims = max(map(len, input_shapes))
  207. assert n_dims == min(map(len, input_shapes))
  208. output_shape = [None] * n_dims
  209. for shape in input_shapes:
  210. for i, v in enumerate(shape):
  211. if output_shape[i] is None:
  212. output_shape[i] = v
  213. else:
  214. if i == dim:
  215. output_shape[i] += v
  216. else:
  217. assert output_shape[i] == v, 'inconsistent dims'
  218. return output_shape
  219.  
  220. @staticmethod
  221. @compute_type(nn.MaxPool2d)
  222. def maxpool2(module, input_shape):
  223. """
  224.  
  225. Example:
  226. >>> from pysseg.torch.models.output_shape_for import *
  227. >>> input_shape = (1, 3, 256, 256)
  228. >>> module = nn.MaxPool2d(kernel_size=2)
  229. >>> output_shape = OutputShapeFor(module)(input_shape)
  230. >>> print('output_shape = {!r}'.format(output_shape))
  231. output_shape = [1, 7, 256, 256]
  232.  
  233. Shape:
  234. Same as conv2 forumla except C2 = C1
  235. - Input: :math:`(N, C, H_{in}, W_{in})`
  236. - Output: :math:`(N, C, H_{out}, W_{out})` where
  237. :math:`H_{out} = floor((H_{in} + 2 * padding[0] - dilation[0] * (kernel\_size[0] - 1) - 1) / stride[0] + 1)`
  238. :math:`W_{out} = floor((W_{in} + 2 * padding[1] - dilation[1] * (kernel\_size[1] - 1) - 1) / stride[1] + 1)`
  239. """
  240. math = OutputShapeFor.math
  241. N, C, H_in, W_in = input_shape
  242.  
  243. def ensure_iterable2(scalar):
  244. try:
  245. iter(scalar)
  246. except TypeError:
  247. return [scalar] * 2
  248. return scalar
  249.  
  250. padding = ensure_iterable2(module.padding)
  251. stride = ensure_iterable2(module.stride)
  252. dilation = ensure_iterable2(module.dilation)
  253. kernel_size = ensure_iterable2(module.kernel_size)
  254.  
  255. H_out = math.floor((H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
  256. W_out = math.floor((W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1)
  257. output_shape = (N, C, H_out, W_out)
  258. return output_shape
  259.  
  260. @staticmethod
  261. @compute_type(nn.AvgPool2d)
  262. def avepool2d(module, input_shape):
  263. """
  264. Shape:
  265. - Input: :math:`(N, C, H_{in}, W_{in})`
  266. - Output: :math:`(N, C, H_{out}, W_{out})` where
  267. :math:`H_{out} = floor((H_{in} + 2 * padding[0] - kernel\_size[0]) / stride[0] + 1)`
  268. :math:`W_{out} = floor((W_{in} + 2 * padding[1] - kernel\_size[1]) / stride[1] + 1)`
  269. """
  270. math = OutputShapeFor.math
  271. N, C, H_in, W_in = input_shape
  272.  
  273. def ensure_iterable2(scalar):
  274. try:
  275. iter(scalar)
  276. except TypeError:
  277. return [scalar] * 2
  278. return scalar
  279.  
  280. padding = ensure_iterable2(module.padding)
  281. stride = ensure_iterable2(module.stride)
  282. kernel_size = ensure_iterable2(module.kernel_size)
  283.  
  284. H_out = math.floor((H_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
  285. W_out = math.floor((W_in + 2 * padding[1] - kernel_size[1]) / stride[1] + 1)
  286. output_shape = (N, C, H_out, W_out)
  287. return output_shape
  288.  
  289. @staticmethod
  290. @compute_type(nn.Linear)
  291. def linear(module, input_shape):
  292. """
  293. Shape:
  294. - Input: :math:`(N, *, in\_features)` where `*` means any number of
  295. additional dimensions
  296. - Output: :math:`(N, *, out\_features)` where all but the last dimension
  297. are the same shape as the input.
  298. """
  299. N, *other, in_feat = input_shape
  300. output_shape = [N] + other + [module.out_features]
  301. return output_shape
  302.  
  303. @staticmethod
  304. @compute_type(nn.BatchNorm2d)
  305. def batchnorm(module, input_shape):
  306. return input_shape
  307.  
  308. @staticmethod
  309. @compute_type(nn.ReLU)
  310. def relu(module, input_shape):
  311. return input_shape
  312.  
  313. @staticmethod
  314. @compute_type(nn.LeakyReLU)
  315. def leaky_relu(module, input_shape):
  316. return input_shape
  317.  
  318. @staticmethod
  319. @compute_type(nn.Sequential)
  320. def sequential(module, input_shape):
  321. shape = input_shape
  322. for child in module._modules.values():
  323. shape = OutputShapeFor(child)(shape)
  324. return shape
  325.  
  326. @staticmethod
  327. @compute_type(torchvision.models.resnet.BasicBlock)
  328. def resent_basic_block(module, input_shape):
  329. residual_shape = input_shape
  330. shape = input_shape
  331.  
  332. shape = OutputShapeFor(module.conv1)(shape)
  333. shape = OutputShapeFor(module.bn1)(shape)
  334. shape = OutputShapeFor(module.relu)(shape)
  335.  
  336. shape = OutputShapeFor(module.conv2)(shape)
  337. shape = OutputShapeFor(module.bn2)(shape)
  338. shape = OutputShapeFor(module.relu)(shape)
  339.  
  340. if module.downsample is not None:
  341. residual_shape = OutputShapeFor(module.downsample)(residual_shape)
  342.  
  343. # assert residual_shape[-2:] == shape[-2:], 'cannot add residual {} {}'.format(residual_shape, shape)
  344. # out += residual
  345. shape = OutputShapeFor(module.relu)(shape)
  346. # print('BASIC residual_shape = {!r}'.format(residual_shape[-2:]))
  347. # print('BASIC shape = {!r}'.format(shape[-2:]))
  348. # print('---')
  349. return shape
  350.  
  351. @staticmethod
  352. @compute_type(torchvision.models.resnet.Bottleneck)
  353. def resent_bottleneck(module, input_shape):
  354. residual_shape = input_shape
  355. shape = input_shape
  356.  
  357. shape = OutputShapeFor(module.conv1)(shape)
  358. shape = OutputShapeFor(module.bn1)(shape)
  359. shape = OutputShapeFor(module.relu)(shape)
  360.  
  361. shape = OutputShapeFor(module.conv2)(shape)
  362. shape = OutputShapeFor(module.bn2)(shape)
  363. shape = OutputShapeFor(module.relu)(shape)
  364.  
  365. shape = OutputShapeFor(module.conv3)(shape)
  366. shape = OutputShapeFor(module.bn3)(shape)
  367.  
  368. if module.downsample is not None:
  369. residual_shape = OutputShapeFor(module.downsample)(input_shape)
  370.  
  371. assert residual_shape[-2:] == shape[-2:], 'cannot add residual {} {}'.format(residual_shape, shape)
  372. # out += residual
  373. shape = OutputShapeFor(module.relu)(shape)
  374. # print('bottle downsample = {!r}'.format(module.downsample))
  375. # print('bottle input_shape = {!r}'.format(input_shape[-2:]))
  376. # print('bottle residual_shape = {!r}'.format(residual_shape[-2:]))
  377. # print('bottle shape = {!r}'.format(shape[-2:]))
  378. # print('---')
  379. return shape
  380.  
  381. @staticmethod
  382. @compute_type(torchvision.models.resnet.ResNet)
  383. def resnet_model(module, input_shape):
  384. shape = input_shape
  385. shape = OutputShapeFor(module.conv1)(shape)
  386. shape = OutputShapeFor(module.bn1)(shape)
  387. shape = OutputShapeFor(module.relu)(shape)
  388. shape = OutputShapeFor(module.maxpool)(shape)
  389.  
  390. shape = OutputShapeFor(module.layer1)(shape)
  391. shape = OutputShapeFor(module.layer2)(shape)
  392. shape = OutputShapeFor(module.layer3)(shape)
  393. shape = OutputShapeFor(module.layer4)(shape)
  394.  
  395. shape = OutputShapeFor(module.avgpool)(shape)
  396. print('pre-flatten-shape = {!r}'.format(shape))
  397.  
  398. def prod(args):
  399. result = args[0]
  400. for arg in args[1:]:
  401. result = result * arg
  402. return result
  403. shape = (shape[0], prod(shape[1:]))
  404. # shape = shape.view(shape.size(0), -1)
  405.  
  406. shape = OutputShapeFor(module.fc)(shape)
  407.  
  408. @staticmethod
  409. def resnet_conv_part(module, input_shape):
  410. shape = input_shape
  411. shape = OutputShapeFor(module.conv1)(shape)
  412. shape = OutputShapeFor(module.bn1)(shape)
  413. shape = OutputShapeFor(module.relu)(shape)
  414. shape = OutputShapeFor(module.maxpool)(shape)
  415.  
  416. shape = OutputShapeFor(module.layer1)(shape)
  417. shape = OutputShapeFor(module.layer2)(shape)
  418. shape = OutputShapeFor(module.layer3)(shape)
  419. shape = OutputShapeFor(module.layer4)(shape)
  420.  
  421. shape = OutputShapeFor(module.avgpool)(shape)
  422. # print('pre-flatten-shape = {!r}'.format(shape))
  423.  
  424. def prod(args):
  425. result = args[0]
  426. for arg in args[1:]:
  427. result = result * arg
  428. return result
  429. shape = (shape[0], prod(shape[1:]))
  430. # shape = shape.view(shape.size(0), -1)
  431. return shape
Add Comment
Please, Sign In to add comment