Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import unittest\n",
- "import torch\n",
- "torch.set_default_tensor_type(torch.cuda.float64)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "class TEST(unittest.TestCase):\n",
- " def test_trtrs(self):\n",
- " a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),\n",
- " (-6.05, -3.30, 5.36, -4.44, 1.08),\n",
- " (-0.45, 2.58, -2.70, 0.27, 9.04),\n",
- " (8.32, 2.71, 4.35, -7.17, 2.14),\n",
- " (-9.67, -5.14, -7.26, 6.08, -6.87))).t()\n",
- " b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03),\n",
- " (-1.56, 4.00, -8.67, 1.75, 2.86),\n",
- " (9.81, -4.09, -4.57, -8.61, 8.99))).t()\n",
- " \n",
- " print(a.dtype)\n",
- "\n",
- " U = torch.triu(a)\n",
- " L = torch.tril(a)\n",
- "\n",
- " # solve Ux = b\n",
- " x = torch.trtrs(b, U)[0]\n",
- " self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)\n",
- " x = torch.trtrs(b, U, True, False, False)[0]\n",
- " self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)\n",
- "\n",
- " # solve Lx = b\n",
- " x = torch.trtrs(b, L, False)[0]\n",
- " self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)\n",
- " x = torch.trtrs(b, L, False, False, False)[0]\n",
- " self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)\n",
- "\n",
- " # solve U'x = b\n",
- " x = torch.trtrs(b, U, True, True)[0]\n",
- " self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)\n",
- " x = torch.trtrs(b, U, True, True, False)[0]\n",
- " self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)\n",
- "\n",
- " # solve U'x = b by manual transposition\n",
- " y = torch.trtrs(b, U.t(), False, False)[0]\n",
- " self.assertLessEqual(x.dist(y), 1e-12)\n",
- "\n",
- " # solve L'x = b\n",
- " x = torch.trtrs(b, L, False, True)[0]\n",
- " self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)\n",
- " x = torch.trtrs(b, L, False, True, False)[0]\n",
- " self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)\n",
- "\n",
- " # solve L'x = b by manual transposition\n",
- " y = torch.trtrs(b, L.t(), True, False)[0]\n",
- " self.assertLessEqual(x.dist(y), 1e-12)\n",
- "\n",
- " # test reuse\n",
- " res1 = torch.trtrs(b, a)[0]\n",
- " ta = torch.Tensor()\n",
- " tb = torch.Tensor()\n",
- " torch.trtrs(b, a, out=(tb, ta))\n",
- " self.assertEqual(res1.dist(tb), 0)\n",
- " tb.zero_()\n",
- " torch.trtrs(b, a, out=(tb, ta))\n",
- " self.assertEqual(res1.dist(tb), 0)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "torch.cuda.float64\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- ".\n",
- "----------------------------------------------------------------------\n",
- "Ran 1 test in 1.827s\n",
- "\n",
- "OK\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "<unittest.main.TestProgram at 0x7f230e5e8f98>"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "unittest.main(argv=['first-arg-is-ignored'], exit=False)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python (pyro)",
- "language": "python",
- "name": "pyro"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.4"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Add Comment
Please, Sign In to add comment