Guest User

Untitled

a guest
Nov 19th, 2017
99
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.47 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# DCN\n",
  8. "\n",
  9. "1. The sentinel vector $d_{\\emptyset}$ is not included\n",
  10. "2. The graph isn't tested with real data\n",
  11. "3. Dropout wrapper and Multi-cell are not added\n",
  12. "4. I may have gotten the affinity matrix calculation wrong - need your input on it\n",
  13. "5. The code is extremely non-pythonic - each equation from the DCN paper is computed in a single line\n",
  14. "6. **Decoder** is not added yet"
  15. ]
  16. },
  17. {
  18. "cell_type": "code",
  19. "execution_count": 21,
  20. "metadata": {},
  21. "outputs": [],
  22. "source": [
  23. "import tensorflow as tf\n",
  24. "import numpy as np"
  25. ]
  26. },
  27. {
  28. "cell_type": "markdown",
  29. "metadata": {},
  30. "source": [
  31. "## Dummy Params"
  32. ]
  33. },
  34. {
  35. "cell_type": "code",
  36. "execution_count": 2,
  37. "metadata": {},
  38. "outputs": [],
  39. "source": [
  40. "vocab_size = 100\n",
  41. "dlen = 30 # m\n",
  42. "qlen = 15 # n\n",
  43. "B = 8 # batch size\n",
  44. "hdim = 32 # l\n",
  45. "emb_dim = 32"
  46. ]
  47. },
  48. {
  49. "cell_type": "markdown",
  50. "metadata": {},
  51. "source": [
  52. "## Helpers"
  53. ]
  54. },
  55. {
  56. "cell_type": "code",
  57. "execution_count": 3,
  58. "metadata": {},
  59. "outputs": [],
  60. "source": [
  61. "def shapes(l):\n",
  62. " if type(l) == type((23, 69)) or type(l) == type([]):\n",
  63. " return [ shapes(t) for t in l ]\n",
  64. " elif type(l) == np.ndarray:\n",
  65. " return l.shape\n",
  66. " \n",
  67. "def execute_graph(t, feed_dict=None):\n",
  68. " feed_dict = {\n",
  69. " _document : np.random.randint(0, vocab_size, [B, dlen]),\n",
  70. " _query : np.random.randint(0, vocab_size, [B, qlen]),\n",
  71. " _answer : np.random.randint(0, dlen, [B, ]),\n",
  72. " } if type(feed_dict) == type(None) else feed_dict\n",
  73. " with tf.Session() as sess:\n",
  74. " sess.run(tf.global_variables_initializer())\n",
  75. " return sess.run(t, feed_dict)"
  76. ]
  77. },
  78. {
  79. "cell_type": "markdown",
  80. "metadata": {},
  81. "source": [
  82. "## Graph"
  83. ]
  84. },
  85. {
  86. "cell_type": "code",
  87. "execution_count": 38,
  88. "metadata": {},
  89. "outputs": [],
  90. "source": [
  91. "tf.reset_default_graph()\n",
  92. "_document = tf.placeholder(tf.int32, [None, None], 'document')\n",
  93. "_query = tf.placeholder(tf.int32, [None, None], 'query')\n",
  94. "_answer = tf.placeholder(tf.int32, [None, ], 'answer')\n",
  95. "batch_size_, dlen_ = tf.unstack(tf.shape(_document))\n",
  96. "qlen_ = tf.shape(_query)[1]\n",
  97. "E = tf.get_variable('embedding', [vocab_size, emb_dim], tf.float32, \n",
  98. " tf.random_uniform_initializer(-0.01, 0.01))\n",
  99. "enc_cell = tf.nn.rnn_cell.LSTMCell(hdim)\n",
  100. "with tf.variable_scope('encoder') as scope:\n",
  101. " document, _ = tf.nn.dynamic_rnn(enc_cell, \n",
  102. " tf.nn.embedding_lookup(E, _document),\n",
  103. " tf.count_nonzero(_document, axis=1), # sequence lengths\n",
  104. " dtype=tf.float32)\n",
  105. " scope.reuse_variables()\n",
  106. " query, _ = tf.nn.dynamic_rnn(enc_cell, \n",
  107. " tf.nn.embedding_lookup(E, _query),\n",
  108. " tf.count_nonzero(_query, axis=1), # sequence lengths\n",
  109. " dtype=tf.float32)\n",
  110. "with tf.variable_scope('query_projection'):\n",
  111. " query = tf.contrib.layers.fully_connected(query, hdim, \n",
  112. " activation_fn=tf.nn.tanh)\n",
  113. "with tf.variable_scope('affinity_matrix'):\n",
  114. " affinity = tf.matmul(document, # [B,Ld,d] x [B,d,Lq] = [B,Ld,Lq]\n",
  115. " tf.transpose(query, [0, 2, 1]))\n",
  116. "with tf.variable_scope('attention_weights'):\n",
  117. " Ad = tf.nn.softmax(affinity) # normalize along Lq\n",
  118. " Aq = tf.nn.softmax(tf.transpose(affinity, [0, 2, 1])) # normalize along Ld\n",
  119. "with tf.variable_scope('summary'):\n",
  120. " Cq = tf.matmul(Aq, document) # [B,Lq,Ld] x [B,Ld,d] = [B,Lq,d]\n",
  121. " Cd = tf.transpose(tf.matmul( # ([B,2*d,Lq] x [B,Lq,Ld])^T = [B,Ld,2*d]\n",
  122. " tf.transpose(tf.concat([query, Cq], axis=-1), [0, 2, 1]),\n",
  123. " tf.transpose(Ad, [0, 2, 1])\n",
  124. " ), [0, 2, 1])\n",
  125. "\n",
  126. "with tf.variable_scope('temporal_fusion'):\n",
  127. " context_states, _ = tf.nn.bidirectional_dynamic_rnn(\n",
  128. " tf.nn.rnn_cell.LSTMCell(hdim),\n",
  129. " tf.nn.rnn_cell.LSTMCell(hdim),\n",
  130. " tf.concat([document, Cd], axis=-1), # [B, Ld, 2*d + d]\n",
  131. " tf.count_nonzero(_document, axis=1), # sequence lengths\n",
  132. " dtype=tf.float32)\n",
  133. " context = tf.concat(context_states, axis=-1) # final representation"
  134. ]
  135. },
  136. {
  137. "cell_type": "markdown",
  138. "metadata": {},
  139. "source": [
  140. "## Execute"
  141. ]
  142. },
  143. {
  144. "cell_type": "code",
  145. "execution_count": 40,
  146. "metadata": {},
  147. "outputs": [
  148. {
  149. "data": {
  150. "text/plain": [
  151. "[(8, 30, 64)]"
  152. ]
  153. },
  154. "execution_count": 40,
  155. "metadata": {},
  156. "output_type": "execute_result"
  157. }
  158. ],
  159. "source": [
  160. "shapes(execute_graph([context]))"
  161. ]
  162. }
  163. ],
  164. "metadata": {
  165. "kernelspec": {
  166. "display_name": "Python 3",
  167. "language": "python",
  168. "name": "python3"
  169. },
  170. "language_info": {
  171. "codemirror_mode": {
  172. "name": "ipython",
  173. "version": 3
  174. },
  175. "file_extension": ".py",
  176. "mimetype": "text/x-python",
  177. "name": "python",
  178. "nbconvert_exporter": "python",
  179. "pygments_lexer": "ipython3",
  180. "version": "3.5.2"
  181. }
  182. },
  183. "nbformat": 4,
  184. "nbformat_minor": 2
  185. }
Add Comment
Please, Sign In to add comment