daily pastebin goal
70%
SHARE
TWEET

Untitled

a guest Mar 24th, 2019 73 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. {
  2.  "cells": [
  3.   {
  4.    "cell_type": "code",
  5.    "execution_count": 1,
  6.    "metadata": {},
  7.    "outputs": [],
  8.    "source": [
  9.     "import numpy as np"
  10.    ]
  11.   },
  12.   {
  13.    "cell_type": "markdown",
  14.    "metadata": {},
  15.    "source": [
  16.     "### Helper functions:"
  17.    ]
  18.   },
  19.   {
  20.    "cell_type": "code",
  21.    "execution_count": 2,
  22.    "metadata": {},
  23.    "outputs": [],
  24.    "source": [
  25.     "def ohe(X,idx2word):\n",
  26.     "    ncol = len(idx2word.keys())\n",
  27.     "    nrow = len(X)\n",
  28.     "    OHE_X = np.zeros((nrow,ncol))\n",
  29.     "    for r in range(len(X)):\n",
  30.     "        if not isinstance(X[r],list):\n",
  31.     "            OHE_X[r,X[r]] = 1\n",
  32.     "        else:\n",
  33.     "            row_val = X[r]\n",
  34.     "            for c in row_val:\n",
  35.     "                OHE_X[r,c] = 1\n",
  36.     "                \n",
  37.     "    return OHE_X\n",
  38.     "        \n",
  39.     "\n",
  40.     "    \n",
  41.     "def tokenize(x_list):\n",
  42.     "    #unique tokens:\n",
  43.     "    unique_x = list(set([j for i in data for j in i]))\n",
  44.     "    idx2word = dict(enumerate(unique_x))\n",
  45.     "    word2idx = {i[1]:i[0] for i in idx2word.items()}\n",
  46.     "    # Encode:\n",
  47.     "    tokened_x_list = []\n",
  48.     "    for sentence in x_list:\n",
  49.     "        temp_sent = []\n",
  50.     "        for word in sentence:\n",
  51.     "            token = word2idx.get(word,-1)\n",
  52.     "            temp_sent.append(token)\n",
  53.     "        \n",
  54.     "        tokened_x_list.append(temp_sent)\n",
  55.     "    return tokened_x_list,idx2word,word2idx\n",
  56.     "            \n",
  57.     "\n",
  58.     "    \n",
  59.     "def skipgram_prep(x_list,context_window=2):\n",
  60.     "    \"\"\"\n",
  61.     "    Use Skipgram method to prepare the data.\n",
  62.     "    \n",
  63.     "    Arguments:\n",
  64.     "        x_list(list): tokenized training data\n",
  65.     "        \n",
  66.     "        context_window: the context window on each side. \n",
  67.     "        For example, if context_window=2, we will be looking at 2 tokens on the left and \n",
  68.     "        2 tokens on the right\n",
  69.     "    \n",
  70.     "    Returns:\n",
  71.     "        processd_data(list):  a list of tuples represents the processed data. Each pair of tuple is a (x,y) pair\n",
  72.     "        \n",
  73.     "    \"\"\"\n",
  74.     "    processed_data = []\n",
  75.     "    \n",
  76.     "    for row in x_list:\n",
  77.     "        row_len = len(row)\n",
  78.     "        for i in range(row_len):\n",
  79.     "            x = row[i]\n",
  80.     "            start_idx = max(i-context_window,0)\n",
  81.     "            end_idx = min(row_len,i+context_window+1)\n",
  82.     "            y = row[start_idx:i] + row[i+1:end_idx] # skip the self\n",
  83.     "            \n",
  84.     "            \n",
  85.     "            temp_xy_pair = zip([x]*len(y),y)\n",
  86.     "            processed_data.extend(temp_xy_pair)\n",
  87.     "    \n",
  88.     "    return processed_data\n",
  89.     "            \n",
  90.     "            \n",
  91.     "\n",
  92.     "        \n",
  93.     "def softmax(x):\n",
  94.     "    e_x = np.exp(x)\n",
  95.     "    return np.divide(e_x,e_x.sum(axis=1).reshape(-1,1))\n",
  96.     "            \n",
  97.     "            \n",
  98.     "    \n",
  99.     "    \n",
  100.     "            \n",
  101.     "    \n",
  102.     "    "
  103.    ]
  104.   },
  105.   {
  106.    "cell_type": "code",
  107.    "execution_count": 3,
  108.    "metadata": {},
  109.    "outputs": [],
  110.    "source": [
  111.     "\n",
  112.     "data = [\n",
  113.     "    'apple banana are delicious food',\n",
  114.     "    'video game go play in game studio',\n",
  115.     "    'lunch food is fruit apple banana icecream',\n",
  116.     "    'warcraft or starcraft or overwatch best game',\n",
  117.     "    'chocolate or banana or icecream the most delicious food',\n",
  118.     "    'banana apple smoothie is the best for lunch or dinner',\n",
  119.     "    'video game is good for geeks',\n",
  120.     "    'what to eat for dinner banana or chocolate',\n",
  121.     "    'which game company is better ubisoft or blizzard',\n",
  122.     "    'play game on ps4 or xbox',\n",
  123.     "    'banana is less sweet icecream is more sweet',\n",
  124.     "    'chocolate icecream taste more delicious than banana'\n",
  125.     "    \n",
  126.     "]\n",
  127.     "\n",
  128.     "data = [i.split(\" \") for i in data]\n"
  129.    ]
  130.   },
  131.   {
  132.    "cell_type": "markdown",
  133.    "metadata": {},
  134.    "source": [
  135.     "#### Tokenization:"
  136.    ]
  137.   },
  138.   {
  139.    "cell_type": "code",
  140.    "execution_count": 4,
  141.    "metadata": {},
  142.    "outputs": [],
  143.    "source": [
  144.     "tokenized_data_list,idx2word,word2idx = tokenize(data)"
  145.    ]
  146.   },
  147.   {
  148.    "cell_type": "markdown",
  149.    "metadata": {},
  150.    "source": [
  151.     "#### Use Skipgram to Prepare the Training Data:"
  152.    ]
  153.   },
  154.   {
  155.    "cell_type": "code",
  156.    "execution_count": 5,
  157.    "metadata": {},
  158.    "outputs": [],
  159.    "source": [
  160.     "prep_data = skipgram_prep(tokenized_data_list,context_window=2)"
  161.    ]
  162.   },
  163.   {
  164.    "cell_type": "code",
  165.    "execution_count": 6,
  166.    "metadata": {},
  167.    "outputs": [],
  168.    "source": [
  169.     "X = [i[0] for i in prep_data]\n",
  170.     "Y = [i[1] for i in prep_data]"
  171.    ]
  172.   },
  173.   {
  174.    "cell_type": "markdown",
  175.    "metadata": {},
  176.    "source": [
  177.     "#### One-Hot Encoding (OHE):"
  178.    ]
  179.   },
  180.   {
  181.    "cell_type": "code",
  182.    "execution_count": 7,
  183.    "metadata": {},
  184.    "outputs": [],
  185.    "source": [
  186.     "ohe_X = ohe(X,idx2word)\n",
  187.     "ohe_Y = ohe(Y,idx2word)"
  188.    ]
  189.   },
  190.   {
  191.    "cell_type": "markdown",
  192.    "metadata": {},
  193.    "source": [
  194.     "#### Naive Word2vec Model:\n",
  195.     "\n",
  196.     "First, let's build a naive Word2vec model, which means we compute the softmax over the entire vocabulary."
  197.    ]
  198.   },
  199.   {
  200.    "cell_type": "code",
  201.    "execution_count": 8,
  202.    "metadata": {},
  203.    "outputs": [],
  204.    "source": [
  205.     "# Hyper Parameters:\n",
  206.     "N_NEGATIVE = 3\n",
  207.     "LEARNING_RATE = 0.01\n",
  208.     "N_VOCAB = len(idx2word)\n",
  209.     "N_DIM = 16\n",
  210.     "BATCH_SIZE = len(ohe_X)\n",
  211.     "\n",
  212.     "# Weights Initialization:\n",
  213.     "embedding_mat = np.random.normal(size=(N_VOCAB,N_DIM)) \n",
  214.     "dense_w = np.random.normal(size=(N_DIM,N_VOCAB))"
  215.    ]
  216.   },
  217.   {
  218.    "cell_type": "code",
  219.    "execution_count": 9,
  220.    "metadata": {},
  221.    "outputs": [
  222.     {
  223.      "name": "stdout",
  224.      "output_type": "stream",
  225.      "text": [
  226.       "Loss: 9.378688108625873\n",
  227.       "Loss: 2.005298530982752\n",
  228.       "Loss: 1.9282251934894516\n",
  229.       "Loss: 1.915202004986637\n",
  230.       "Loss: 1.9103938902718987\n",
  231.       "Loss: 1.9079612745844892\n",
  232.       "Loss: 1.9067267756355368\n",
  233.       "Loss: 1.9116364898998748\n",
  234.       "Loss: 1.9093111376997192\n",
  235.       "Loss: 1.9079385961499742\n",
  236.       "Loss: 1.9070468916300318\n",
  237.       "Loss: 1.9064129969572752\n",
  238.       "Loss: 1.9059291852573528\n",
  239.       "Loss: 1.9055409406255481\n",
  240.       "Loss: 1.9052187142658792\n"
  241.      ]
  242.     }
  243.    ],
  244.    "source": [
  245.     "all_loss = []\n",
  246.     "for i in range(1500):\n",
  247.     "    \n",
  248.     "\n",
  249.     "    # forward pass:\n",
  250.     "    input_x = ohe_X\n",
  251.     "    input_y = ohe_Y\n",
  252.     "    x_embedding_layer = input_x.dot(embedding_mat)# query word embedding X\n",
  253.     "#     print(x_embedding_layer.shape)\n",
  254.     "    dense_layer = x_embedding_layer.dot(dense_w)\n",
  255.     "#     print(dense_layer.shape)\n",
  256.     "    output_layer = softmax(dense_layer)\n",
  257.     "\n",
  258.     "    # cross entropy loss:\n",
  259.     "    loss = -np.sum(input_y*np.log(output_layer+1e-9))/BATCH_SIZE # adding smooth term\n",
  260.     "    if i%100==0:\n",
  261.     "        print(f\"Loss: {loss}\")\n",
  262.     "    all_loss.append(loss)\n",
  263.     "#     print('---')\n",
  264.     "    \n",
  265.     "\n",
  266.     "\n",
  267.     "    # Backward Pass\n",
  268.     "\n",
  269.     "    # d_loss/d_dense_layer = d_loss/d_op_layer * d_op_layer/d_dense_layer\n",
  270.     "    d_dense = output_layer - input_y\n",
  271.     "#     print(d_dense.shape)\n",
  272.     "\n",
  273.     "    # d_loss/d_dense_w = d_loss/d_dense_layer * d_dense_layer/d_dense_w\n",
  274.     "    d_dense_w =  d_dense.T.dot(x_embedding_layer).T\n",
  275.     "#     print(d_dense_w.shape)\n",
  276.     "\n",
  277.     "    # d_loss/x_embedding_layer = d_loss/d_dense_layer * d_dense_layer/x_embedding_layer\n",
  278.     "    d_emb_layer =  d_dense.dot(dense_w.T)\n",
  279.     "#     print(d_emb_layer.shape)\n",
  280.     "    # d_loss/d_embedding_mat = d_loss/x_embedding_layer * x_embedding_layer/d_embedding_mat\n",
  281.     "    d_embedding_mat = d_emb_layer.T.dot(input_x)\n",
  282.     "#     print(d_embedding_mat.shape)\n",
  283.     "#     print('~')\n",
  284.     "\n",
  285.     "\n",
  286.     "    \n",
  287.     "    embedding_mat -= LEARNING_RATE*d_embedding_mat.T\n",
  288.     "    dense_w -= LEARNING_RATE*d_dense_w\n",
  289.     "    "
  290.    ]
  291.   },
  292.   {
  293.    "cell_type": "markdown",
  294.    "metadata": {},
  295.    "source": [
  296.     "Now that the mini word2vec model is trained, let's build a query function to check the learned word vectors:"
  297.    ]
  298.   },
  299.   {
  300.    "cell_type": "code",
  301.    "execution_count": 10,
  302.    "metadata": {},
  303.    "outputs": [],
  304.    "source": [
  305.     "def get_word_vector(word ,embedding = embedding_mat,word2idx=word2idx,vector_dim=N_DIM):\n",
  306.     "    \n",
  307.     "    query_id = word2idx.get(word,-1)\n",
  308.     "    if query_id>=0:\n",
  309.     "        return embedding_mat[query_id,:]\n",
  310.     "    else:\n",
  311.     "        return np.zeros((N_DIM,))-999.\n",
  312.     "        \n",
  313.     "    "
  314.    ]
  315.   },
  316.   {
  317.    "cell_type": "code",
  318.    "execution_count": 11,
  319.    "metadata": {},
  320.    "outputs": [
  321.     {
  322.      "data": {
  323.       "text/plain": [
  324.        "array([-1.10479476,  0.69152043,  0.0222014 ,  0.47398416,  0.85253254,\n",
  325.        "        1.29816081,  0.46473506, -0.17165976,  0.02458933, -0.58116457,\n",
  326.        "       -0.40560783,  2.78396632, -0.96417779,  2.04935327,  0.82896536,\n",
  327.        "       -0.92053599])"
  328.       ]
  329.      },
  330.      "execution_count": 11,
  331.      "metadata": {},
  332.      "output_type": "execute_result"
  333.     }
  334.    ],
  335.    "source": [
  336.     "# Try with in vocab word:\n",
  337.     "query_word = 'xbox'\n",
  338.     "get_word_vector(query_word)"
  339.    ]
  340.   },
  341.   {
  342.    "cell_type": "code",
  343.    "execution_count": 12,
  344.    "metadata": {},
  345.    "outputs": [
  346.     {
  347.      "data": {
  348.       "text/plain": [
  349.        "array([-999., -999., -999., -999., -999., -999., -999., -999., -999.,\n",
  350.        "       -999., -999., -999., -999., -999., -999., -999.])"
  351.       ]
  352.      },
  353.      "execution_count": 12,
  354.      "metadata": {},
  355.      "output_type": "execute_result"
  356.     }
  357.    ],
  358.    "source": [
  359.     "# Try with in Out-of_vocabulary word:\n",
  360.     "query_word = 'lol'\n",
  361.     "get_word_vector(query_word)"
  362.    ]
  363.   },
  364.   {
  365.    "cell_type": "markdown",
  366.    "metadata": {},
  367.    "source": [
  368.     "#### Now find the most similar word to our query word:"
  369.    ]
  370.   },
  371.   {
  372.    "cell_type": "code",
  373.    "execution_count": 13,
  374.    "metadata": {},
  375.    "outputs": [],
  376.    "source": [
  377.     "from numpy import dot\n",
  378.     "from numpy.linalg import norm\n",
  379.     "\n",
  380.     "\n",
  381.     "def cosine_sim(vx,vy):\n",
  382.     "    return dot(vx, vy)/(norm(vx)*norm(vy))"
  383.    ]
  384.   },
  385.   {
  386.    "cell_type": "code",
  387.    "execution_count": 14,
  388.    "metadata": {},
  389.    "outputs": [],
  390.    "source": [
  391.     "def find_similar(query_word,word2idx=word2idx):\n",
  392.     "    query_vector = get_word_vector(query_word)\n",
  393.     "    \n",
  394.     "    result = {}\n",
  395.     "    for word in word2idx:\n",
  396.     "        temp_vector = get_word_vector(word)\n",
  397.     "#         print(word)\n",
  398.     "#         print(temp_vector)\n",
  399.     "        sim = cosine_sim(query_vector,temp_vector)\n",
  400.     "        result[word] = sim\n",
  401.     "    \n",
  402.     "    return result\n",
  403.     "        \n"
  404.    ]
  405.   },
  406.   {
  407.    "cell_type": "code",
  408.    "execution_count": 15,
  409.    "metadata": {},
  410.    "outputs": [
  411.     {
  412.      "data": {
  413.       "text/plain": [
  414.        "[('food', 1.0), ('apple', 0.5053182732410727), ('less', 0.34573572640324446)]"
  415.       ]
  416.      },
  417.      "execution_count": 15,
  418.      "metadata": {},
  419.      "output_type": "execute_result"
  420.     }
  421.    ],
  422.    "source": [
  423.     "result = find_similar('food')\n",
  424.     "\n",
  425.     "\n",
  426.     "sorted(list(result.items()),key=lambda x: x[1],reverse=True)[:3]"
  427.    ]
  428.   },
  429.   {
  430.    "cell_type": "code",
  431.    "execution_count": 16,
  432.    "metadata": {},
  433.    "outputs": [
  434.     {
  435.      "data": {
  436.       "text/plain": [
  437.        "[('warcraft', 0.9999999999999999),\n",
  438.        " ('ubisoft', 0.4935534265673933),\n",
  439.        " ('overwatch', 0.467640114458129)]"
  440.       ]
  441.      },
  442.      "execution_count": 16,
  443.      "metadata": {},
  444.      "output_type": "execute_result"
  445.     }
  446.    ],
  447.    "source": [
  448.     "result = find_similar('warcraft')\n",
  449.     "\n",
  450.     "\n",
  451.     "sorted(list(result.items()),key=lambda x: x[1],reverse=True)[:3]"
  452.    ]
  453.   },
  454.   {
  455.    "cell_type": "markdown",
  456.    "metadata": {},
  457.    "source": [
  458.     "### So what's the problem? The algorithm above won't scale to real-world problems!\n",
  459.     "\n",
  460.     "We only have 44 words in the vocabulary of this toy example. What if the vocabulary had millions of words? The softmax normalizes over every word in the vocabulary, so the cost of each training step grows with the vocabulary size and quickly becomes very expensive. To tackle this issue, several algorithms have been proposed to approximate the softmax, such as Hierarchical Softmax, Negative Sampling, and Noise-Contrastive Estimation (NCE).\n",
  461.     "\n"
  462.    ]
  463.   },
  464.   {
  465.    "cell_type": "markdown",
  466.    "metadata": {},
  467.    "source": [
  468.     " Coming Soon...."
  469.    ]
  470.   },
  471.   {
  472.    "cell_type": "markdown",
  473.    "metadata": {},
  474.    "source": [
  475.     "#### Negative Sampling:"
  476.    ]
  477.   },
  478.   {
  479.    "cell_type": "code",
  480.    "execution_count": 17,
  481.    "metadata": {},
  482.    "outputs": [],
  483.    "source": [
  484.     "# def sample_negative(xy_pairs,n_negative,idx2word):\n",
  485.     "\n",
  486.     "#     pos_context = {}\n",
  487.     "#     grand_negative_samples = []\n",
  488.     "#     for x,y in xy_pairs:\n",
  489.     "\n",
  490.     "#         if x not in pos_context:\n",
  491.     "#             good_pair = [i[1] for i in xy_pairs if i[0]==x]\n",
  492.     "#             pos_context[x] = good_pair\n",
  493.     "\n",
  494.     "#         ## Sample:\n",
  495.     "#         temp_neg_samples = []\n",
  496.     "#         while len(temp_neg_samples)< n_negative:\n",
  497.     "#             temp_idx = np.random.choice(list(idx2word.keys()))\n",
  498.     "#             if temp_idx!=x and temp_idx not in pos_context[x]:\n",
  499.     "#                 temp_neg_samples.append(temp_idx)\n",
  500.     "\n",
  501.     "#         grand_negative_samples.append(temp_neg_samples)\n",
  502.     "#     return np.array(grand_negative_samples)\n",
  503.     "    "
  504.    ]
  505.   }
  506.  ],
  507.  "metadata": {
  508.   "kernelspec": {
  509.    "display_name": "Python 3",
  510.    "language": "python",
  511.    "name": "python3"
  512.   },
  513.   "language_info": {
  514.    "codemirror_mode": {
  515.     "name": "ipython",
  516.     "version": 3
  517.    },
  518.    "file_extension": ".py",
  519.    "mimetype": "text/x-python",
  520.    "name": "python",
  521.    "nbconvert_exporter": "python",
  522.    "pygments_lexer": "ipython3",
  523.    "version": "3.6.8"
  524.   }
  525.  },
  526.  "nbformat": 4,
  527.  "nbformat_minor": 2
  528. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top