a guest
Nov 25th, 2015
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we load the network and check that fc7 has been separated into its own blobs, so the in-place ReLU pass does not overwrite our input."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"cellView": "both",
"colab_type": "code",
"collapsed": false,
"id": "i9hkSm1IOZNR"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data: 1505280 = (10, 3, 224, 224)\n",
"conv1_1: 32112640 = (10, 64, 224, 224)\n",
"conv1_2: 32112640 = (10, 64, 224, 224)\n",
"pool1: 8028160 = (10, 64, 112, 112)\n",
"conv2_1: 16056320 = (10, 128, 112, 112)\n",
"conv2_2: 16056320 = (10, 128, 112, 112)\n",
"pool2: 4014080 = (10, 128, 56, 56)\n",
"conv3_1: 8028160 = (10, 256, 56, 56)\n",
"conv3_2: 8028160 = (10, 256, 56, 56)\n",
"conv3_3: 8028160 = (10, 256, 56, 56)\n",
"conv3_4: 8028160 = (10, 256, 56, 56)\n",
"pool3: 2007040 = (10, 256, 28, 28)\n",
"conv4_1: 4014080 = (10, 512, 28, 28)\n",
"conv4_2: 4014080 = (10, 512, 28, 28)\n",
"conv4_3: 4014080 = (10, 512, 28, 28)\n",
"conv4_4: 4014080 = (10, 512, 28, 28)\n",
"pool4: 1003520 = (10, 512, 14, 14)\n",
"conv5_1: 1003520 = (10, 512, 14, 14)\n",
"conv5_2: 1003520 = (10, 512, 14, 14)\n",
"conv5_3: 1003520 = (10, 512, 14, 14)\n",
"conv5_4: 1003520 = (10, 512, 14, 14)\n",
"pool5: 250880 = (10, 512, 7, 7)\n",
"fc6: 40960 = (10, 4096)\n",
"fc6_fc6_0_split_0: 40960 = (10, 4096)\n",
"fc6_fc6_0_split_1: 40960 = (10, 4096)\n",
"fc6_fc6_0_split_2: 40960 = (10, 4096)\n",
"relu6: 40960 = (10, 4096)\n",
"drop6: 40960 = (10, 4096)\n",
"fc7: 40960 = (10, 4096)\n",
"relu7: 40960 = (10, 4096)\n",
"drop7: 40960 = (10, 4096)\n",
"fc8: 10000 = (10, 1000)\n",
"prob: 10000 = (10, 1000)\n"
]
}
],
"source": [
"import numpy as np\n",
"from google.protobuf import text_format\n",
"import caffe\n",
"\n",
"# load the network\n",
"# the prototxt has force_backward: true, and fc7 is separated into multiple blobs\n",
"model_name = 'vgg_ilsvrc_19'\n",
"model_path = '../caffe/models/' + model_name + '/'\n",
"net_fn = model_path + 'deploy-expanded.prototxt'\n",
"param_fn = model_path + 'net.caffemodel'\n",
"net = caffe.Classifier(net_fn, param_fn)\n",
"\n",
"# print blob names and sizes\n",
"for end in net.blobs.keys():\n",
"    cur = net.blobs[end]\n",
"    print end + ': {} = {}'.format(cur.count, cur.data.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we define a function that optimizes one layer (fc7) to produce a one-hot vector at the output (prob), starting the forward pass from the layer immediately after fc7 (relu7). We use Nesterov momentum, though it is probably overkill: the optimization generally converges quickly even without it."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def optimize(net,\n",
"             hot=0,\n",
"             step_size=.01,\n",
"             iter_n=100,\n",
"             mu=.9,\n",
"             basename='fc7',\n",
"             start='relu7',\n",
"             end='prob'):\n",
"    base = net.blobs[basename]\n",
"    first = net.blobs[start]\n",
"    last = net.blobs[end]\n",
"    base.data[0] = np.random.normal(.5, .1, base.data[0].shape)\n",
"    base.diff[0] = 0.\n",
"    velocity = np.zeros_like(base.data[0])\n",
"    velocity_previous = np.zeros_like(base.data[0])\n",
"    for i in range(iter_n):\n",
"        net.forward(start=start, end=end)\n",
"        target = np.zeros_like(last.data[0])\n",
"        target.flat[hot] = 1.\n",
"        error = target - last.data[0]\n",
"        last.diff[0] = error\n",
"        net.backward(start=end, end=start)\n",
"        grad = base.diff[0]\n",
"        learning_rate = step_size / np.abs(grad).mean()\n",
"        velocity_previous = velocity\n",
"        velocity = mu * velocity + learning_rate * grad\n",
"        base.data[0] += -mu * velocity_previous + (1 + mu) * velocity\n",
"        base.data[0] = np.clip(base.data[0], 0, 1)\n",
"    return base.data[0]"
]
},
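{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside, the momentum update above can be sketched in isolation. The following toy cell (illustrative only; it assumes nothing about the network) applies the same Nesterov-style update to a 1-D quadratic and converges toward the target value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# illustrative sketch: the same update rule as in optimize(), on a toy 1-D problem\n",
"mu = .9\n",
"step_size = .1\n",
"x = 0.\n",
"velocity = 0.\n",
"velocity_previous = 0.\n",
"for i in range(100):\n",
"    grad = 3. - x  # points toward the target, like base.diff after net.backward()\n",
"    velocity_previous = velocity\n",
"    velocity = mu * velocity + step_size * grad\n",
"    x += -mu * velocity_previous + (1 + mu) * velocity\n",
"print(x)  # approaches the target value 3"
]
},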
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We check that we get different vectors for different \"hot\" choices, and that the `optimize()` function is actually doing what we expect to the net."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"in: [ 1. 0. 1. 0. 0. 0. 0. 0.]\n",
"out: [ 9.91814017e-01 3.98620468e-05 9.90041463e-06 9.02536340e-06\n",
" 1.13274482e-05 1.72698910e-05 1.00160096e-05 3.25881274e-06]\n",
"in: [ 0. 0. 1. 0. 1. 1. 0. 1.]\n",
"out: [ 2.45179963e-05 9.93825078e-01 4.71601061e-06 8.41968267e-06\n",
" 8.20327932e-06 1.34181846e-05 5.93410823e-06 7.48563616e-06]\n"
]
}
],
"source": [
"print 'in:', optimize(net, hot=0)[0:8]\n",
"print 'out:', net.blobs['prob'].data[0, 0:8]\n",
"print 'in:', optimize(net, hot=1)[0:8]\n",
"print 'out:', net.blobs['prob'].data[0, 0:8]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run `optimize()` for every class and save the vectors to disk in a format `bh_tsne` will be able to parse."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100% (1000 of 1000) |#####################| Elapsed Time: 0:07:12 Time: 0:07:12\n"
]
}
],
"source": [
"from progressbar import ProgressBar\n",
"vectors = []\n",
"pbar = ProgressBar()\n",
"for i in pbar(range(1000)):\n",
"    vectors.append(optimize(net, hot=i).copy())\n",
"np.savetxt('vectors', vectors, fmt='%.2f', delimiter='\\t')"
]
},
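{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a sketch, assuming the `vectors` file written by the cell above), the saved tab-separated matrix can be read back with `np.loadtxt`; it should come back with shape (1000, 4096)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# sketch: reload the tab-separated matrix written above\n",
"loaded = np.loadtxt('vectors', delimiter='\\t')\n",
"print(loaded.shape)"
]
},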
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load a list of labels and print them alongside their associated vectors."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1. 0. 1. ..., 1. 0. 1.] tench\n",
"[ 0. 0. 1. ..., 1. 1. 1.] goldfish\n",
"[ 0. 0. 0. ..., 0. 0. 0.] great white shark\n",
"[ 1. 1. 0. ..., 0. 0. 0.] tiger shark\n",
"[ 0. 0. 0. ..., 1. 0. 0.] hammerhead\n",
"[ 1. 1. 1. ..., 1. 0. 0.] electric ray\n",
"[ 0. 0. 0. ..., 0. 0. 1.] stingray\n",
"[ 0. 0. 0. ..., 1. 1. 0.] cock\n",
"[ 0. 0. 1. ..., 0. 1. 0.] hen\n",
"[ 0. 1. 0. ..., 1. 0. 0.] ostrich\n"
]
}
],
"source": [
"labels = []\n",
"with open('words') as f:\n",
"    for line in f:\n",
"        labels.append(line.strip())\n",
"for i in range(10):\n",
"    print vectors[i], labels[i]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To double-check that the vectors capture some notion of similarity, we set up a nearest neighbor search."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_neighbors=10, p=2, radius=100)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.neighbors import NearestNeighbors\n",
"neigh = NearestNeighbors(n_neighbors=10, radius=100)\n",
"neigh.fit(vectors)"
]
},
  261. },
  262. {
  263. "cell_type": "markdown",
  264. "metadata": {},
  265. "source": [
  266. "And run it on one category to get similar results, and indeed \"electric ray\" is similar to \"stingray\" and the others."
  267. ]
  268. },
  269. {
  270. "cell_type": "code",
  271. "execution_count": 21,
  272. "metadata": {
  273. "collapsed": false
  274. },
  275. "outputs": [
  276. {
  277. "name": "stdout",
  278. "output_type": "stream",
  279. "text": [
  280. "0.0 electric ray\n",
  281. "39.6456756784 stingray\n",
  282. "40.1148800322 dugong\n",
  283. "40.6637639674 jellyfish\n",
  284. "40.8964570593 tiger shark\n",
  285. "41.2111950809 hammerhead\n",
  286. "41.2288285063 flatworm\n",
  287. "41.3787372934 sea slug\n",
  288. "41.5759257263 loggerhead\n",
  289. "41.6082491821 grey whale\n"
  290. ]
  291. }
  292. ],
  293. "source": [
  294. "neighbors = neigh.kneighbors([vectors[5]], n_neighbors=10, return_distance=True)\n",
  295. "for distance, i in zip(neighbors[0][0], neighbors[1][0]):\n",
  296. " print distance, labels[i]"
  297. ]
  298. },
  299. {
  300. "cell_type": "code",
  301. "execution_count": null,
  302. "metadata": {
  303. "collapsed": true
  304. },
  305. "outputs": [],
  306. "source": []
  307. }
  308. ],
  309. "metadata": {
  310. "colabVersion": "0.3.1",
  311. "default_view": {},
  312. "kernelspec": {
  313. "display_name": "Python 2",
  314. "language": "python",
  315. "name": "python2"
  316. },
  317. "language_info": {
  318. "codemirror_mode": {
  319. "name": "ipython",
  320. "version": 2
  321. },
  322. "file_extension": ".py",
  323. "mimetype": "text/x-python",
  324. "name": "python",
  325. "nbconvert_exporter": "python",
  326. "pygments_lexer": "ipython2",
  327. "version": "2.7.10"
  328. },
  329. "views": {}
  330. },
  331. "nbformat": 4,
  332. "nbformat_minor": 0
  333. }