Advertisement
Guest User

Untitled

a guest
Aug 25th, 2019
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.37 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import torch\n",
  10. "import torch.nn.functional as F"
  11. ]
  12. },
  13. {
  14. "cell_type": "code",
  15. "execution_count": 2,
  16. "metadata": {},
  17. "outputs": [],
  18. "source": [
  19. "# Input data, in NCHW layout\n",
  20. "\n",
  21. "# batch size\n",
  22. "n = 32\n",
  23. "# input channels\n",
  24. "c = 8\n",
  25. "# height\n",
  26. "h = 224\n",
  27. "# width\n",
  28. "w = 224\n",
  29. "\n",
  30. "data = torch.randn(n, c, h, w)"
  31. ]
  32. },
  33. {
  34. "cell_type": "code",
  35. "execution_count": 3,
  36. "metadata": {},
  37. "outputs": [],
  38. "source": [
  39. "# filter, r=s=1 since it's a 1x1 convolution\n",
  40. "\n",
  41. "# output channels\n",
  42. "k = 16\n",
  43. "# filter height\n",
  44. "r = 1\n",
  45. "# filter width\n",
  46. "s = 1\n",
  47. "\n",
  48. "weights = torch.randn(k, c, r, s)"
  49. ]
  50. },
  51. {
  52. "cell_type": "code",
  53. "execution_count": 4,
  54. "metadata": {},
  55. "outputs": [],
  56. "source": [
  57. "out_conv2d = F.conv2d(data, weights)"
  58. ]
  59. },
  60. {
  61. "cell_type": "code",
  62. "execution_count": 5,
  63. "metadata": {},
  64. "outputs": [
  65. {
  66. "data": {
  67. "text/plain": [
  68. "torch.Size([32, 16, 224, 224])"
  69. ]
  70. },
  71. "execution_count": 5,
  72. "metadata": {},
  73. "output_type": "execute_result"
  74. }
  75. ],
  76. "source": [
  77. "out_conv2d.shape"
  78. ]
  79. },
  80. {
  81. "cell_type": "code",
  82. "execution_count": 6,
  83. "metadata": {},
  84. "outputs": [
  85. {
  86. "data": {
  87. "text/plain": [
  88. "torch.Size([16, 8, 1, 1])"
  89. ]
  90. },
  91. "execution_count": 6,
  92. "metadata": {},
  93. "output_type": "execute_result"
  94. }
  95. ],
  96. "source": [
  97. "weights.shape"
  98. ]
  99. },
  100. {
  101. "cell_type": "code",
  102. "execution_count": 7,
  103. "metadata": {},
  104. "outputs": [],
  105. "source": [
  106. "# change layout from NCHW to NHWC\n",
  107. "# (we'll be multiplying channels)\n",
  108. "data2 = data.permute(0, 2, 3, 1)"
  109. ]
  110. },
  111. {
  112. "cell_type": "code",
  113. "execution_count": 8,
  114. "metadata": {},
  115. "outputs": [
  116. {
  117. "data": {
  118. "text/plain": [
  119. "torch.Size([32, 224, 224, 8])"
  120. ]
  121. },
  122. "execution_count": 8,
  123. "metadata": {},
  124. "output_type": "execute_result"
  125. }
  126. ],
  127. "source": [
  128. "data2.shape"
  129. ]
  130. },
  131. {
  132. "cell_type": "code",
  133. "execution_count": 9,
  134. "metadata": {},
  135. "outputs": [
  136. {
  137. "data": {
  138. "text/plain": [
  139. "torch.Size([8, 16])"
  140. ]
  141. },
  142. "execution_count": 9,
  143. "metadata": {},
  144. "output_type": "execute_result"
  145. }
  146. ],
  147. "source": [
  148. "matmul_weights = weights.squeeze().t()\n",
  149. "matmul_weights.shape"
  150. ]
  151. },
  152. {
  153. "cell_type": "code",
  154. "execution_count": 10,
  155. "metadata": {},
  156. "outputs": [],
  157. "source": [
  158. "out_matmul = torch.matmul(data2, matmul_weights)"
  159. ]
  160. },
  161. {
  162. "cell_type": "code",
  163. "execution_count": 11,
  164. "metadata": {},
  165. "outputs": [
  166. {
  167. "data": {
  168. "text/plain": [
  169. "torch.Size([32, 224, 224, 16])"
  170. ]
  171. },
  172. "execution_count": 11,
  173. "metadata": {},
  174. "output_type": "execute_result"
  175. }
  176. ],
  177. "source": [
  178. "out_matmul.shape"
  179. ]
  180. },
  181. {
  182. "cell_type": "code",
  183. "execution_count": 12,
  184. "metadata": {},
  185. "outputs": [],
  186. "source": [
  187. "# change layout from NHWC to NCHW\n",
  188. "out_matmul = out_matmul.permute(0, 3, 1, 2)"
  189. ]
  190. },
  191. {
  192. "cell_type": "code",
  193. "execution_count": 13,
  194. "metadata": {},
  195. "outputs": [
  196. {
  197. "data": {
  198. "text/plain": [
  199. "torch.Size([32, 16, 224, 224])"
  200. ]
  201. },
  202. "execution_count": 13,
  203. "metadata": {},
  204. "output_type": "execute_result"
  205. }
  206. ],
  207. "source": [
  208. "out_matmul.shape"
  209. ]
  210. },
  211. {
  212. "cell_type": "code",
  213. "execution_count": 14,
  214. "metadata": {},
  215. "outputs": [
  216. {
  217. "data": {
  218. "text/plain": [
  219. "tensor(3.8147e-06)"
  220. ]
  221. },
  222. "execution_count": 14,
  223. "metadata": {},
  224. "output_type": "execute_result"
  225. }
  226. ],
  227. "source": [
  228. "torch.abs(out_matmul - out_conv2d).max()"
  229. ]
  230. }
  231. ],
  232. "metadata": {
  233. "kernelspec": {
  234. "display_name": "Python 3",
  235. "language": "python",
  236. "name": "python3"
  237. },
  238. "language_info": {
  239. "codemirror_mode": {
  240. "name": "ipython",
  241. "version": 3
  242. },
  243. "file_extension": ".py",
  244. "mimetype": "text/x-python",
  245. "name": "python",
  246. "nbconvert_exporter": "python",
  247. "pygments_lexer": "ipython3",
  248. "version": "3.7.0"
  249. }
  250. },
  251. "nbformat": 4,
  252. "nbformat_minor": 2
  253. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement