Guest User

Untitled

a guest
Jun 21st, 2018
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.43 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import torch"
  10. ]
  11. },
  12. {
  13. "cell_type": "code",
  14. "execution_count": 2,
  15. "metadata": {},
  16. "outputs": [],
  17. "source": [
  18. "def _solve_lower(L, b):\n",
  19. "    \"\"\"Solve L @ z = b for z, where L is lower-triangular.\n",
  20. "\n",
  21. "    Uses torch.linalg.solve_triangular when available (PyTorch >= 1.11).\n",
  22. "    Older releases only provide torch.triangular_solve / torch.trtrs, which\n",
  23. "    take (b, A) and return a (solution, A) tuple; torch.trtrs itself was\n",
  24. "    removed from later releases, hence the fallback chain.\n",
  25. "    \"\"\"\n",
  26. "    linalg = getattr(torch, 'linalg', None)\n",
  27. "    if linalg is not None and hasattr(linalg, 'solve_triangular'):\n",
  28. "        return linalg.solve_triangular(L, b, upper=False)\n",
  29. "    solve = getattr(torch, 'triangular_solve', None) or torch.trtrs\n",
  30. "    return solve(b, L, upper=False)[0]\n",
  31. "\n",
  32. "def maha1(L, x):\n",
  33. "    \"\"\"Squared norm ||L^{-1} x||^2, i.e. x^T (L L^T)^{-1} x, via a triangular solve.\n",
  34. "\n",
  35. "    L: (n, n) lower-triangular matrix (the cells below build it with .tril()).\n",
  36. "    x: (..., n); the sum runs over the last dimension, so any leading\n",
  37. "       dimensions act as a batch (if the installed PyTorch supports batched solves).\n",
  38. "    Returns a tensor of shape x.shape[:-1].\n",
  39. "    \"\"\"\n",
  40. "    return _solve_lower(L, x.unsqueeze(-1)).squeeze(-1).pow(2).sum(-1)\n",
  41. "\n",
  42. "def maha2(L, x):\n",
  43. "    \"\"\"Same quantity as maha1, computed through the explicit inverse of L.\n",
  44. "\n",
  45. "    Baseline for the comparisons below; x is treated as (..., n) exactly as in\n",
  46. "    maha1, so both functions agree on batched inputs.\n",
  47. "    \"\"\"\n",
  48. "    return torch.inverse(L).matmul(x.unsqueeze(-1)).squeeze(-1).pow(2).sum(-1)"
  23. ]
  24. },
  25. {
  26. "cell_type": "markdown",
  27. "metadata": {},
  28. "source": [
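  29. "### Equality check\n",
  30. "\n",
  31. "`maha1` (triangular solve) and `maha2` (explicit inverse) compute the same quantity, so their difference should be at floating-point round-off level (float32 here)."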
  30. ]
  31. },
  32. {
  33. "cell_type": "code",
  34. "execution_count": 3,
  35. "metadata": {},
  36. "outputs": [
  37. {
  38. "data": {
  39. "text/plain": [
  40. "tensor(1.00000e-04 *\n",
  41. " 6.1035)"
  42. ]
  43. },
  44. "execution_count": 3,
  45. "metadata": {},
  46. "output_type": "execute_result"
  47. }
  48. ],
  49. "source": [
  50. "torch.manual_seed(0)  # fixed seed so the random test problem (and the number below) is reproducible\n",
  51. "n = 5\n",
  52. "# random lower-triangular matrix with positive entries, as a fresh leaf that requires grad\n",
  53. "L = torch.randn(n, n).exp().tril().requires_grad_()\n",
  54. "x = torch.randn(n, requires_grad=True)\n",
  55. "# absolute disagreement between the two implementations (0-dim tensor)\n",
  56. "(maha1(L, x) - maha2(L, x)).abs().sum()"
  54. ]
  55. },
  56. {
  57. "cell_type": "markdown",
  58. "metadata": {},
  59. "source": [
  60. "### Forward pass: CPU timing (n = 2000, float32)"
  61. ]
  62. },
  63. {
  64. "cell_type": "code",
  65. "execution_count": 4,
  66. "metadata": {},
  67. "outputs": [],
  68. "source": [
  69. "n = 2000\n",
  70. "# requires_grad=True: the forward timings below also include autograd graph recording\n",
  71. "L = torch.randn(n, n).exp().tril().requires_grad_()\n",
  72. "x = torch.randn(n, requires_grad=True)"
  72. ]
  73. },
  74. {
  75. "cell_type": "code",
  76. "execution_count": 5,
  77. "metadata": {},
  78. "outputs": [
  79. {
  80. "name": "stdout",
  81. "output_type": "stream",
  82. "text": [
  83. "4.19 ms ± 9.13 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
  84. ]
  85. }
  86. ],
  87. "source": [
  88. "# forward pass, triangular-solve version (CPU)\n",
  89. "%timeit maha1(L=L, x=x)"
  89. ]
  90. },
  91. {
  92. "cell_type": "code",
  93. "execution_count": 6,
  94. "metadata": {},
  95. "outputs": [
  96. {
  97. "name": "stdout",
  98. "output_type": "stream",
  99. "text": [
  100. "55.7 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
  101. ]
  102. }
  103. ],
  104. "source": [
  105. "# forward pass, explicit-inverse version (CPU)\n",
  106. "%timeit maha2(L=L, x=x)"
  106. ]
  107. },
  108. {
  109. "cell_type": "markdown",
  110. "metadata": {},
  111. "source": [
  112. "### Forward pass: GPU timing (same `L` and `x`, copied to CUDA)"
  113. ]
  114. },
  115. {
  116. "cell_type": "code",
  117. "execution_count": 7,
  118. "metadata": {},
  119. "outputs": [],
  120. "source": [
  121. "# copies of the forward-section L and x on the current CUDA device\n",
  122. "Lc, xc = L.cuda(), x.cuda()"
  123. ]
  124. },
  125. {
  126. "cell_type": "code",
  127. "execution_count": 8,
  128. "metadata": {},
  129. "outputs": [
  130. {
  131. "name": "stdout",
  132. "output_type": "stream",
  133. "text": [
  134. "1.07 ms ± 365 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
  135. ]
  136. }
  137. ],
  138. "source": [
  139. "# CUDA kernels run asynchronously: synchronize inside the timed statement so\n",
  140. "# %timeit measures the computation itself, not just the kernel launches.\n",
  141. "%timeit maha1(Lc, xc); torch.cuda.synchronize()"
  140. ]
  141. },
  142. {
  143. "cell_type": "code",
  144. "execution_count": 9,
  145. "metadata": {},
  146. "outputs": [
  147. {
  148. "name": "stdout",
  149. "output_type": "stream",
  150. "text": [
  151. "22.8 ms ± 206 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
  152. ]
  153. }
  154. ],
  155. "source": [
  156. "# CUDA kernels run asynchronously: synchronize inside the timed statement so\n",
  157. "# %timeit measures the computation itself, not just the kernel launches.\n",
  158. "%timeit maha2(Lc, xc); torch.cuda.synchronize()"
  157. ]
  158. },
  159. {
  160. "cell_type": "markdown",
  161. "metadata": {},
  162. "source": [
  163. "### Backward pass: CPU timing (n = 1000, float64)"
  164. ]
  165. },
  166. {
  167. "cell_type": "code",
  168. "execution_count": 10,
  169. "metadata": {},
  170. "outputs": [],
  171. "source": [
  172. "n = 1000\n",
  173. "# float64 copy of a float32 random lower-triangular matrix, as a fresh leaf that requires grad\n",
  174. "L = torch.randn(n, n).exp().tril().to(torch.float64).requires_grad_()\n",
  175. "x = torch.randn(n, dtype=torch.float64, requires_grad=True)\n",
  176. "# build both graphs once; the timing cells below only re-run the backward pass\n",
  177. "m1 = maha1(L, x)\n",
  178. "m2 = maha2(L, x)"
  177. ]
  178. },
  179. {
  180. "cell_type": "code",
  181. "execution_count": 11,
  182. "metadata": {},
  183. "outputs": [
  184. {
  185. "name": "stdout",
  186. "output_type": "stream",
  187. "text": [
  188. "2.94 ms ± 99.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
  189. ]
  190. }
  191. ],
  192. "source": [
  193. "# backward pass through maha1; retain_graph=True keeps the graph alive across %timeit loops\n",
  194. "%timeit torch.autograd.grad(outputs=m1, inputs=(L, x), retain_graph=True)"
  194. ]
  195. },
  196. {
  197. "cell_type": "code",
  198. "execution_count": 12,
  199. "metadata": {},
  200. "outputs": [
  201. {
  202. "name": "stdout",
  203. "output_type": "stream",
  204. "text": [
  205. "24.1 ms ± 434 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
  206. ]
  207. }
  208. ],
  209. "source": [
  210. "# backward pass through maha2; retain_graph=True keeps the graph alive across %timeit loops\n",
  211. "%timeit torch.autograd.grad(outputs=m2, inputs=(L, x), retain_graph=True)"
  211. ]
  212. }
  213. ],
  214. "metadata": {
  215. "kernelspec": {
  216. "display_name": "Python 3",
  217. "language": "python",
  218. "name": "python3"
  219. },
  220. "language_info": {
  221. "codemirror_mode": {
  222. "name": "ipython",
  223. "version": 3
  224. },
  225. "file_extension": ".py",
  226. "mimetype": "text/x-python",
  227. "name": "python",
  228. "nbconvert_exporter": "python",
  229. "pygments_lexer": "ipython3",
  230. "version": "3.5.5"
  231. }
  232. },
  233. "nbformat": 4,
  234. "nbformat_minor": 2
  235. }
Add Comment
Please, Sign In to add comment