Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "def _solve_lower(L, b):\n",
- "    \"\"\"Solve ``L @ X = b`` for ``X`` where ``L`` is lower-triangular.\n",
- "\n",
- "    ``torch.trtrs`` was renamed to ``torch.triangular_solve`` and later superseded by\n",
- "    ``torch.linalg.solve_triangular``; use whichever this PyTorch build provides.\n",
- "    \"\"\"\n",
- "    linalg = getattr(torch, \"linalg\", None)\n",
- "    if linalg is not None and hasattr(linalg, \"solve_triangular\"):\n",
- "        return linalg.solve_triangular(L, b, upper=False)\n",
- "    if hasattr(torch, \"triangular_solve\"):\n",
- "        return torch.triangular_solve(b, L, upper=False)[0]\n",
- "    return torch.trtrs(b, L, upper=False)[0]\n",
- "\n",
- "\n",
- "def maha1(L, x):\n",
- "    \"\"\"Return ``||L^{-1} x||^2`` (summed over the last dim) using a triangular solve.\n",
- "\n",
- "    ``L`` must be lower-triangular (the solver is called with ``upper=False``);\n",
- "    presumably a Cholesky-style factor of a covariance, which makes this a squared\n",
- "    Mahalanobis distance (TODO confirm). ``x`` is ``(n,)``, or ``(..., n)`` when the\n",
- "    available solver supports batching.\n",
- "    \"\"\"\n",
- "    return _solve_lower(L, x.unsqueeze(-1)).squeeze(-1).pow(2).sum(-1)\n",
- "\n",
- "\n",
- "def maha2(L, x):\n",
- "    \"\"\"Same quantity as ``maha1`` but via an explicit ``torch.inverse`` (baseline).\n",
- "\n",
- "    ``x`` is turned into a column vector so batched ``(..., n)`` inputs broadcast the\n",
- "    same way as in ``maha1``; for a 1-D ``x`` the result is unchanged.\n",
- "    \"\"\"\n",
- "    return torch.inverse(L).matmul(x.unsqueeze(-1)).squeeze(-1).pow(2).sum(-1)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Equality: `maha1` (triangular solve) vs. `maha2` (explicit inverse)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor(1.00000e-04 *\n",
- " 6.1035)"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Seed so the reported discrepancy is reproducible under Restart & Run All.\n",
- "torch.manual_seed(0)\n",
- "n = 5\n",
- "# Build the leaf directly; torch.tensor(<Tensor>, ...) copy-construction is discouraged\n",
- "# by the torch.tensor docs (newer releases emit a UserWarning).\n",
- "L = torch.randn(n, n).exp().tril().requires_grad_()\n",
- "x = torch.randn(n, requires_grad=True)\n",
- "(maha1(L, x) - maha2(L, x)).abs().sum()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Forward-pass timing (CPU, n = 2000)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "n = 2000\n",
- "# Build the leaf directly instead of copy-constructing it with torch.tensor(<Tensor>).\n",
- "L = torch.randn(n, n).exp().tril().requires_grad_()\n",
- "x = torch.randn(n, requires_grad=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "4.19 ms ± 9.13 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
- ]
- }
- ],
- "source": [
- "# Forward pass only; L and x require grad, so the autograd graph is recorded as well.\n",
- "%timeit maha1(L=L, x=x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "55.7 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
- ]
- }
- ],
- "source": [
- "# Forward pass of the explicit-inverse baseline on the same inputs.\n",
- "%timeit maha2(L=L, x=x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Forward-pass timing (GPU)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Copy both inputs to the current CUDA device (requires a CUDA-capable GPU).\n",
- "Lc, xc = L.to(\"cuda\"), x.to(\"cuda\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "1.07 ms ± 365 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
- ]
- }
- ],
- "source": [
- "# CUDA kernels launch asynchronously; synchronize so each loop times the finished work.\n",
- "%timeit maha1(Lc, xc); torch.cuda.synchronize()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "22.8 ms ± 206 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
- ]
- }
- ],
- "source": [
- "# CUDA kernels launch asynchronously; synchronize so each loop times the finished work.\n",
- "%timeit maha2(Lc, xc); torch.cuda.synchronize()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Backward-pass timing (CPU, float64, n = 1000)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "n = 1000\n",
- "# float64 leaf built directly; torch.tensor(<Tensor>, dtype=...) copy-construction is discouraged.\n",
- "L = torch.randn(n, n).exp().tril().double().requires_grad_()\n",
- "x = torch.randn(n, dtype=torch.float64, requires_grad=True)\n",
- "m1 = maha1(L, x)\n",
- "m2 = maha2(L, x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "2.94 ms ± 99.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
- ]
- }
- ],
- "source": [
- "# retain_graph=True keeps m1's graph alive so it can be differentiated on every loop.\n",
- "%timeit torch.autograd.grad(m1, [L, x], retain_graph=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "24.1 ms ± 434 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
- ]
- }
- ],
- "source": [
- "# retain_graph=True keeps m2's graph alive so it can be differentiated on every loop.\n",
- "%timeit torch.autograd.grad(m2, [L, x], retain_graph=True)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.5.5"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
Add Comment
Please, Sign In to add comment