Guest User

Untitled

a guest
Apr 24th, 2018
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.13 KB | None | 0 0
  1. {
  2.  "cells": [
  3.   {
  4.    "cell_type": "markdown",
  5.    "metadata": {},
  6.    "source": [
  7.     "# `torch.trtrs` regression test (CUDA, float64)\n",
  8.     "\n",
  9.     "Solves 5x5 triangular systems (the upper and lower triangles of one fixed matrix) with three right-hand sides through every `upper`/`transpose` combination of `torch.trtrs`, checks the residuals, and checks that `out=` buffers can be reused. Needs a CUDA device.\n",
  10.     "\n",
  11.     "**NOTE:** later PyTorch releases renamed `torch.trtrs` to `torch.triangular_solve`; port the calls before running on a newer version."
  12.    ]
  13.   },
  14.   {
  15.    "cell_type": "code",
  16.    "execution_count": null,
  17.    "metadata": {},
  18.    "outputs": [],
  19.    "source": [
  20.     "import unittest\n",
  21.     "\n",
  22.     "import torch\n",
  23.     "\n",
  24.     "# Default to CUDA double tensors: the test's 1e-12 tolerances need float64.\n",
  25.     "torch.set_default_tensor_type(torch.cuda.DoubleTensor)"
  26.    ]
  27.   },
  28.   {
  29.    "cell_type": "code",
  30.    "execution_count": null,
  31.    "metadata": {},
  32.    "outputs": [],
  33.    "source": [
  34.     "class TEST(unittest.TestCase):\n",
  35.     "    \"\"\"Regression test for torch.trtrs on the default tensor type set above.\"\"\"\n",
  36.     "\n",
  37.     "    # Absolute residual tolerance; only meaningful in double precision.\n",
  38.     "    TOL = 1e-12\n",
  39.     "\n",
  40.     "    def _assert_solves(self, A, x, b):\n",
  41.     "        \"\"\"Assert that x solves A x = b, i.e. ||b - A x|| <= TOL.\"\"\"\n",
  42.     "        self.assertLessEqual(b.dist(torch.mm(A, x)), self.TOL)\n",
  43.     "\n",
  44.     "    def test_trtrs(self):\n",
  45.     "        \"\"\"Solve Ux = b, Lx = b, U'x = b and L'x = b, then check out= reuse.\"\"\"\n",
  46.     "        a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),\n",
  47.     "                          (-6.05, -3.30, 5.36, -4.44, 1.08),\n",
  48.     "                          (-0.45, 2.58, -2.70, 0.27, 9.04),\n",
  49.     "                          (8.32, 2.71, 4.35, -7.17, 2.14),\n",
  50.     "                          (-9.67, -5.14, -7.26, 6.08, -6.87))).t()\n",
  51.     "        b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),\n",
  52.     "                          (-1.56, 4.00, -8.67, 1.75, 2.86),\n",
  53.     "                          (9.81, -4.09, -4.57, -8.61, 8.99))).t()\n",
  54.     "\n",
  55.     "        # TOL assumes 8-byte (float64) elements; fail loudly otherwise.\n",
  56.     "        self.assertEqual(a.element_size(), 8,\n",
  57.     "                         \"expected a float64 default tensor type\")\n",
  58.     "\n",
  59.     "        U = torch.triu(a)\n",
  60.     "        L = torch.tril(a)\n",
  61.     "\n",
  62.     "        # Positional flags are trtrs(b, A, upper, transpose, unitriangular);\n",
  63.     "        # element [0] of the result is the solution x.\n",
  64.     "\n",
  65.     "        # solve Ux = b\n",
  66.     "        x = torch.trtrs(b, U)[0]\n",
  67.     "        self._assert_solves(U, x, b)\n",
  68.     "        x = torch.trtrs(b, U, True, False, False)[0]\n",
  69.     "        self._assert_solves(U, x, b)\n",
  70.     "\n",
  71.     "        # solve Lx = b\n",
  72.     "        x = torch.trtrs(b, L, False)[0]\n",
  73.     "        self._assert_solves(L, x, b)\n",
  74.     "        x = torch.trtrs(b, L, False, False, False)[0]\n",
  75.     "        self._assert_solves(L, x, b)\n",
  76.     "\n",
  77.     "        # solve U'x = b\n",
  78.     "        x = torch.trtrs(b, U, True, True)[0]\n",
  79.     "        self._assert_solves(U.t(), x, b)\n",
  80.     "        x = torch.trtrs(b, U, True, True, False)[0]\n",
  81.     "        self._assert_solves(U.t(), x, b)\n",
  82.     "\n",
  83.     "        # solve U'x = b by manual transposition (U' is lower triangular)\n",
  84.     "        y = torch.trtrs(b, U.t(), False, False)[0]\n",
  85.     "        self.assertLessEqual(x.dist(y), self.TOL)\n",
  86.     "\n",
  87.     "        # solve L'x = b\n",
  88.     "        x = torch.trtrs(b, L, False, True)[0]\n",
  89.     "        self._assert_solves(L.t(), x, b)\n",
  90.     "        x = torch.trtrs(b, L, False, True, False)[0]\n",
  91.     "        self._assert_solves(L.t(), x, b)\n",
  92.     "\n",
  93.     "        # solve L'x = b by manual transposition (L' is upper triangular)\n",
  94.     "        y = torch.trtrs(b, L.t(), True, False)[0]\n",
  95.     "        self.assertLessEqual(x.dist(y), self.TOL)\n",
  96.     "\n",
  97.     "        # test reuse: preallocated out=(tb, ta) buffers must reproduce the\n",
  98.     "        # plain call, both on first use and after zeroing tb\n",
  99.     "        res1 = torch.trtrs(b, a)[0]\n",
  100.     "        ta = torch.Tensor()\n",
  101.     "        tb = torch.Tensor()\n",
  102.     "        torch.trtrs(b, a, out=(tb, ta))\n",
  103.     "        self.assertEqual(res1.dist(tb), 0)\n",
  104.     "        tb.zero_()\n",
  105.     "        torch.trtrs(b, a, out=(tb, ta))\n",
  106.     "        self.assertEqual(res1.dist(tb), 0)"
  107.    ]
  108.   },
  109.   {
  110.    "cell_type": "code",
  111.    "execution_count": null,
  112.    "metadata": {},
  113.    "outputs": [],
  114.    "source": [
  115.     "# argv: unittest.main would otherwise parse sys.argv, which inside Jupyter holds\n",
  116.     "# the kernel's own arguments (argv[0] is skipped as the program name).\n",
  117.     "# exit=False: don't call sys.exit() once the tests finish.\n",
  118.     "# The trailing ';' hides the returned TestProgram repr.\n",
  119.     "unittest.main(argv=['first-arg-is-ignored'], exit=False);"
  120.    ]
  121.   }
  122.  ],
  123.  "metadata": {
  124.   "kernelspec": {
  125.    "display_name": "Python (pyro)",
  126.    "language": "python",
  127.    "name": "pyro"
  128.   },
  129.   "language_info": {
  130.    "codemirror_mode": {
  131.     "name": "ipython",
  132.     "version": 3
  133.    },
  134.    "file_extension": ".py",
  135.    "mimetype": "text/x-python",
  136.    "name": "python",
  137.    "nbconvert_exporter": "python",
  138.    "pygments_lexer": "ipython3",
  139.    "version": "3.5.4"
  140.   }
  141.  },
  142.  "nbformat": 4,
  143.  "nbformat_minor": 2
  144. }
Add Comment
Please, Sign In to add comment