Untitled

a guest, Oct 17th, 2019
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## theory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's derive the gradient of the loss $J$ with respect to the weight matrix $W$ in the simplest case: a single one-hot encoded input."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Suppose $x$ is a one-hot encoded row vector of length $V$ with $x_s = 1$, and $W$ is a $V \\times D$ matrix. Then $z = xW = [W_{s1} \\ldots W_{sD}]$, i.e. $z$ is simply the $s^{th}$ row of $W$. By the chain rule, $\\frac{\\partial J}{\\partial W_{ij}} = \\sum_{k=1}^{D}{\\frac{\\partial J}{\\partial z_k} \\frac{\\partial z_k}{\\partial W_{ij}}}$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since $z_k = W_{sk}$ does not depend on any entry outside row $s$, every term of this sum is $0$ when $i \\neq s$. For $i = s$, $\\frac{\\partial z_k}{\\partial W_{sj}} = 1$ if $k = j$ and $0$ otherwise, so only one term survives: $\\frac{\\partial J}{\\partial W_{sj}} = \\frac{\\partial J}{\\partial z_j}$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In other words, $\\frac{\\partial J}{\\partial W}$ is zero everywhere except row $s$, which holds the upstream gradient: $\\frac{\\partial J}{\\partial W} = \\begin{bmatrix}0 & \\cdots & 0 \\\\ \\vdots & & \\vdots \\\\ \\frac{\\partial J}{\\partial z_1} & \\cdots & \\frac{\\partial J}{\\partial z_D} \\\\ \\vdots & & \\vdots \\\\ 0 & \\cdots & 0 \\end{bmatrix}$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So the backward pass only has to copy the upstream gradient $\\frac{\\partial J}{\\partial z}$ into the $s^{th}$ row of $\\frac{\\partial J}{\\partial W}$; all other rows stay zero."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)  # make the random draws reproducible\n",
    "V, D = 10, 2  # W has V rows (one per one-hot position) and D columns\n",
    "x = [3]  # index of the single one-hot input, i.e. s = 3\n",
    "W = np.random.randn(V, D)  # weight matrix, shape (V, D)\n",
    "dW = np.zeros_like(W)  # buffer for dJ/dW, same shape as W\n",
    "dout = np.random.randn(1, D)  # upstream gradient dJ/dz, shape (1, D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1.46564877, -0.2257763 ]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add dout into the rows of dW indexed by x (here only row 3)\n",
    "np.add.at(dW, x, dout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.        ,  0.        ],\n",
       "       [ 0.        ,  0.        ],\n",
       "       [ 0.        ,  0.        ],\n",
       "       [ 1.46564877, -0.2257763 ],\n",
       "       [ 0.        ,  0.        ],\n",
       "       [ 0.        ,  0.        ],\n",
       "       [ 0.        ,  0.        ],\n",
       "       [ 0.        ,  0.        ],\n",
       "       [ 0.        ,  0.        ],\n",
       "       [ 0.        ,  0.        ]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dW"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, only row 3 of `dW` (the index stored in `x`) is nonzero, and it equals `dout`. This concludes our short analysis of this simple case."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
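The same chain-rule argument also goes through without assuming $x$ is one-hot, which shows where the row-filling rule comes from. Writing out $z_k = \sum_{m=1}^{V} x_m W_{mk}$ and letting $\delta_{jk}$ denote the Kronecker delta:

```latex
\frac{\partial z_k}{\partial W_{ij}} = x_i\,\delta_{jk}
\quad\Longrightarrow\quad
\frac{\partial J}{\partial W_{ij}}
  = \sum_{k=1}^{D} \frac{\partial J}{\partial z_k}\, x_i\,\delta_{jk}
  = x_i\,\frac{\partial J}{\partial z_j},
\qquad\text{i.e.}\qquad
\frac{\partial J}{\partial W} = x^{\top}\,\frac{\partial J}{\partial z}.
```

For the one-hot case, $x_i = 1$ only when $i = s$, so this outer product has a single nonzero row, row $s$, equal to $\frac{\partial J}{\partial z}$, in agreement with the matrix derived in the notebook.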
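A quick NumPy check that the row-filling rule, the outer product $x^{\top}\frac{\partial J}{\partial z}$ (the usual gradient of $z = xW$ with respect to $W$) and the notebook's `np.add.at` call all agree. It is a sketch that reuses the notebook's setup (seed 42, `V, D = 10, 2`, index 3); the explicit one-hot vector `x_onehot` and the `dW_*` names are introduced here purely for illustration, since the notebook only stores the index list `x = [3]`.

```python
import numpy as np

# Same setup as the notebook: seed 42, V = 10, D = 2, one-hot index s = 3.
np.random.seed(42)
V, D = 10, 2
s = 3
W = np.random.randn(V, D)
dout = np.random.randn(1, D)          # upstream gradient dJ/dz, shape (1, D)

# Explicit one-hot row vector x with x_s = 1 (illustrative name).
x_onehot = np.zeros((1, V))
x_onehot[0, s] = 1.0

# Forward pass: multiplying by a one-hot vector selects row s of W.
z = x_onehot @ W                      # shape (1, D)
assert np.allclose(z, W[s])

# Backward pass, three equivalent ways.
dW_outer = x_onehot.T @ dout          # outer product x^T dJ/dz, shape (V, D)
dW_row = np.zeros_like(W)
dW_row[s] = dout[0]                   # copy dJ/dz into row s
dW_addat = np.zeros_like(W)
np.add.at(dW_addat, [s], dout)        # the notebook's approach

print(np.allclose(dW_outer, dW_row), np.allclose(dW_row, dW_addat))  # True True
```

The matrix product is the general formula; indexing (or `np.add.at`) avoids materializing the one-hot vector, which matters when $V$ is large.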
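The notebook uses `np.add.at` rather than `dW[x] += dout`. With the single index `x = [3]` both give the same result, but if the same index appeared more than once (for example a word repeated within an input sequence, which goes beyond the single-input case treated here), each occurrence should add its own upstream gradient to that row. `np.add.at` performs unbuffered in-place addition, so repeated indices accumulate; buffered fancy-index `+=` does not. The index array and gradients below are made up for illustration.

```python
import numpy as np

V, D = 5, 2
idx = np.array([1, 3, 1])             # hypothetical input indices; 1 appears twice
dout = np.array([[1.0, 2.0],
                 [3.0, 4.0],
                 [5.0, 6.0]])         # one upstream gradient row per input position

# Buffered fancy indexing: the two updates to row 1 do not accumulate,
# only one of them ends up in dW_buffered[1].
dW_buffered = np.zeros((V, D))
dW_buffered[idx] += dout

# Unbuffered accumulation: both updates to row 1 are summed.
dW_accum = np.zeros((V, D))
np.add.at(dW_accum, idx, dout)

print(dW_accum[1])      # [6. 8.]
print(dW_buffered[1])   # one of the two rows, not their sum
```

In the accumulated version the column sums of `dW_accum` equal those of `dout`, i.e. no gradient contribution is lost.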
 
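The derived gradient can also be checked numerically with centered finite differences. A concrete loss is needed for that; this sketch assumes the hypothetical linear loss $J = \sum_k z_k \cdot \mathrm{dout}_k$, chosen only because its gradient $\frac{\partial J}{\partial z}$ is exactly `dout`. Any other differentiable loss would work, with `dout` replaced by its gradient.

```python
import numpy as np

# Same setup as the notebook: seed 42, V = 10, D = 2, index s = 3.
np.random.seed(42)
V, D = 10, 2
s = 3
W = np.random.randn(V, D)
dout = np.random.randn(1, D)


def loss(weights):
    """Hypothetical loss J = sum_k z_k * dout_k, chosen so that dJ/dz = dout."""
    z = weights[s]  # forward pass: z = xW with one-hot x selects row s
    return float(np.sum(z * dout[0]))


# Analytic gradient from the derivation: dout copied into row s.
dW = np.zeros_like(W)
np.add.at(dW, [s], dout)

# Centered finite differences, perturbing one entry of W at a time.
h = 1e-5
num_dW = np.zeros_like(W)
for i in range(V):
    for j in range(D):
        W_plus, W_minus = W.copy(), W.copy()
        W_plus[i, j] += h
        W_minus[i, j] -= h
        num_dW[i, j] = (loss(W_plus) - loss(W_minus)) / (2 * h)

# Largest discrepancy; expected to be at the level of floating-point rounding.
print(np.max(np.abs(dW - num_dW)))
```

Because this loss is linear in $W$, the finite differences are exact up to rounding; for a nonlinear loss one would compare with a relative-error tolerance instead.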