Guest User

Untitled

a guest
Aug 15th, 2018
109
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.98 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "I'd have liked to use a VGG network as the example, but it trains far too slowly on my computer.\n",
  8. "\n",
  9. "I only have a Mac, so I can only run TensorFlow on the CPU :(\n",
  10. "\n",
  11. "So I'm turning to our good old friend instead: handwritten-digit recognition with MNIST.\n",
  12. "\n",
  13. "Bear with me, please. I just want to show the full process of converting a trained graph into a single, reusable file."
  14. ]
  15. },
  16. {
  17. "cell_type": "code",
  18. "execution_count": 1,
  19. "metadata": {
  20. "collapsed": true
  21. },
  22. "outputs": [],
  23. "source": [
  24. "from math import sqrt\n",
  25. "import numpy as np\n",
  26. "import tensorflow as tf\n",
  27. "from tensorflow.examples.tutorials.mnist.input_data import read_data_sets\n",
  28. "\n",
  29. "%matplotlib inline\n",
  30. "import matplotlib.pyplot as plt"
  31. ]
  32. },
  33. {
  34. "cell_type": "code",
  35. "execution_count": 2,
  36. "metadata": {
  37. "collapsed": true
  38. },
  39. "outputs": [],
  40. "source": [
  41. "def accuracy(y, y_):\n",
  42. "    \"\"\"Fraction of positions where the label arrays y and y_ agree.\"\"\"\n",
  43. "    matches = (y == y_)\n",
  44. "    return matches.mean()"
  44. ]
  45. },
  46. {
  47. "cell_type": "code",
  48. "execution_count": 3,
  49. "metadata": {},
  50. "outputs": [
  51. {
  52. "name": "stdout",
  53. "output_type": "stream",
  54. "text": [
  55. "Extracting mnist_data/train-images-idx3-ubyte.gz\n",
  56. "Extracting mnist_data/train-labels-idx1-ubyte.gz\n",
  57. "Extracting mnist_data/t10k-images-idx3-ubyte.gz\n",
  58. "Extracting mnist_data/t10k-labels-idx1-ubyte.gz\n"
  59. ]
  60. }
  61. ],
  62. "source": [
  63. "mnist = read_data_sets(train_dir=\"mnist_data\", one_hot=True)  # labels come back as one-hot vectors"
  64. ]
  65. },
  66. {
  67. "cell_type": "code",
  68. "execution_count": 4,
  69. "metadata": {},
  70. "outputs": [],
  71. "source": [
  72. "# build a trainable graph: 784 -> 100 -> 100 -> 10, sigmoid hidden units, softmax output\n",
  73. "train_graph = tf.Graph()\n",
  74. "hidden_shape = [100, 100]\n",
  75. "learning_rate = 0.05\n",
  76. "np.random.seed(42)  # reproducible NumPy weight initialisation\n",
  77. "\n",
  78. "with train_graph.as_default():\n",
  79. "    with tf.name_scope(\"input\"):\n",
  80. "        image_batch = tf.placeholder(tf.float32,\n",
  81. "                                     shape=[None, 784],\n",
  82. "                                     name=\"image_batch\")\n",
  83. "    layer_shapes = [784] + hidden_shape + [10]\n",
  84. "    last_layer = image_batch\n",
  85. "    with tf.name_scope(\"hidden\"):\n",
  86. "        for in_shape, out_shape in zip(layer_shapes[:-1], layer_shapes[1:]):\n",
  87. "            # Glorot uniform: W ~ U(-sig, sig) with sig = sqrt(6 / (fan_in + fan_out))\n",
  88. "            sig = sqrt(6 / (in_shape + out_shape))\n",
  89. "            W_init = 2*sig*np.random.rand(in_shape, out_shape) - sig\n",
  90. "            W = tf.Variable(W_init, dtype=tf.float32, name=\"weight\")\n",
  91. "            # zero-initialised biases, the usual pairing with Glorot weights\n",
  92. "            bias = tf.Variable(np.zeros(out_shape), dtype=tf.float32, name=\"bias\")\n",
  93. "            zscore = tf.matmul(last_layer, W) + bias\n",
  94. "            last_layer = tf.nn.sigmoid(zscore)\n",
  95. "\n",
  96. "    # zscore now holds the logits of the final 10-unit layer (its sigmoid is unused)\n",
  97. "    with tf.name_scope(\"output\"):\n",
  98. "        prob = tf.nn.softmax(zscore, name=\"probability\")\n",
  99. "        # tf.arg_max is deprecated; tf.argmax builds the same ArgMax op\n",
  100. "        predict = tf.argmax(prob, axis=1, name=\"prediction\")\n",
  101. "\n",
  102. "    target = tf.placeholder(tf.float32, shape=[None, 10],\n",
  103. "                            name=\"target\")\n",
  104. "    # same value as -mean(target * log(prob)), but log_softmax works on the logits\n",
  105. "    # directly and cannot hit log(0) = -inf when a probability underflows\n",
  106. "    loss = -tf.reduce_mean(target*tf.nn.log_softmax(zscore))\n",
  107. "    train_op = tf.train.AdagradOptimizer(learning_rate).minimize(loss)\n",
  108. "    saver = tf.train.Saver()"
  109. ]
  110. },
  111. {
  112. "cell_type": "code",
  113. "execution_count": 5,
  114. "metadata": {},
  115. "outputs": [
  116. {
  117. "name": "stdout",
  118. "output_type": "stream",
  119. "text": [
  120. "Initialized\n",
  121. "Iteration 1000: 0.16 47.45%\n",
  122. "Iteration 2000: 0.12 62.55%\n",
  123. "Iteration 3000: 0.10 69.13%\n",
  124. "Iteration 4000: 0.08 73.12%\n",
  125. "Iteration 5000: 0.10 76.01%\n",
  126. "Iteration 6000: 0.06 77.92%\n",
  127. "Iteration 7000: 0.07 79.24%\n",
  128. "Iteration 8000: 0.06 80.53%\n",
  129. "Iteration 9000: 0.06 81.49%\n",
  130. "Iteration 10000: 0.07 82.52%\n",
  131. "Iteration 11000: 0.05 83.25%\n",
  132. "Iteration 12000: 0.06 83.75%\n",
  133. "Iteration 13000: 0.04 84.27%\n",
  134. "Iteration 14000: 0.06 84.71%\n",
  135. "Iteration 15000: 0.06 84.95%\n",
  136. "Iteration 16000: 0.05 85.33%\n",
  137. "Iteration 17000: 0.03 85.62%\n",
  138. "Iteration 18000: 0.05 86.05%\n",
  139. "Iteration 19000: 0.04 86.28%\n",
  140. "Iteration 20000: 0.05 86.50%\n"
  141. ]
  142. }
  143. ],
  144. "source": [
  145. "# Training and save\n",
  146. "num_iters = 20000\n",
  147. "batch_size = 300\n",
  148. "test_labels = np.argmax(mnist.test.labels, axis=1)  # loop-invariant: compute once\n",
  149. "tf.gfile.MakeDirs(\"model\")  # Saver.save does not create a missing parent directory\n",
  150. "\n",
  151. "with tf.Session(graph=train_graph) as sess:\n",
  152. "    tf.global_variables_initializer().run()\n",
  153. "    print(\"Initialized\")\n",
  154. "    for step in range(num_iters):\n",
  155. "        images, labels = mnist.train.next_batch(batch_size)\n",
  156. "        feed_dict = {image_batch: images,\n",
  157. "                     target: labels}\n",
  158. "        batch_loss, _ = sess.run([loss, train_op], feed_dict=feed_dict)\n",
  159. "\n",
  160. "        if (step+1) % 1000 == 0:\n",
  161. "            pred = sess.run(predict, feed_dict={image_batch: mnist.test.images})\n",
  162. "            acc = accuracy(pred, test_labels)\n",
  163. "            print(\"Iteration {}: {:.2f} {:.2f}%\".format(step+1, batch_loss, acc*100))\n",
  164. "    saver.save(sess, save_path=\"model/mnist_example.chkp\", global_step=step)"
  165. ]
  166. },
  167. {
  168. "cell_type": "markdown",
  169. "metadata": {},
  170. "source": [
  171. "The test accuracy looks good enough, so we're ready to save the graph."
  172. ]
  173. },
  174. {
  175. "cell_type": "code",
  176. "execution_count": 6,
  177. "metadata": {},
  178. "outputs": [
  179. {
  180. "name": "stdout",
  181. "output_type": "stream",
  182. "text": [
  183. "checkpoint\r\n",
  184. "mnist_example.chkp-19999.data-00000-of-00001\r\n",
  185. "mnist_example.chkp-19999.index\r\n",
  186. "mnist_example.chkp-19999.meta\r\n"
  187. ]
  188. }
  189. ],
  190. "source": [
  191. "# list what saver.save wrote: the `checkpoint` state file plus .data/.index/.meta files for the saved step\n",
  192. "!ls model/"
  193. ]
  194. },
  195. {
  196. "cell_type": "code",
  197. "execution_count": 7,
  198. "metadata": {
  199. "collapsed": true
  200. },
  201. "outputs": [],
  202. "source": [
  203. "# import the graph_util\n",
  204. "# We'll use it to convert all variables in graph to constants\n",
  205. "from tensorflow.python.framework import graph_util"
  206. ]
  207. },
  208. {
  209. "cell_type": "code",
  210. "execution_count": 8,
  211. "metadata": {},
  212. "outputs": [
  213. {
  214. "data": {
  215. "text/plain": [
  216. "'model/mnist_example.chkp-19999'"
  217. ]
  218. },
  219. "execution_count": 8,
  220. "metadata": {},
  221. "output_type": "execute_result"
  222. }
  223. ],
  224. "source": [
  225. "# tf.train.get_checkpoint_state returns a CheckpointState protobuf message\n",
  226. "# (tensorflow.python.training.checkpoint_state_pb2), or None if model/ holds no checkpoint.\n",
  227. "checkpoint = tf.train.get_checkpoint_state(\"model/\")\n",
  228. "assert checkpoint is not None, \"no checkpoint found in model/ - run the training cell first\"\n",
  229. "checkpoint.model_checkpoint_path"
  229. ]
  230. },
  231. {
  232. "cell_type": "code",
  233. "execution_count": 9,
  234. "metadata": {},
  235. "outputs": [
  236. {
  237. "data": {
  238. "text/plain": [
  239. "'model/mnist_example.chkp-19999.meta'"
  240. ]
  241. },
  242. "execution_count": 9,
  243. "metadata": {},
  244. "output_type": "execute_result"
  245. }
  246. ],
  247. "source": [
  248. "# The .meta file holds the exported graph structure; the weights live in the .data/.index files\n",
  249. "meta_file_path = \"{}.meta\".format(checkpoint.model_checkpoint_path)\n",
  250. "meta_file_path"
  251. ]
  252. },
  253. {
  254. "cell_type": "code",
  255. "execution_count": 10,
  256. "metadata": {
  257. "collapsed": true
  258. },
  259. "outputs": [],
  260. "source": [
  261. "# Rebuild the saved graph from the .meta file; import_meta_graph hands back a Saver wired to\n",
  262. "# that graph's variables. clear_devices=True drops device placements recorded at save time.\n",
  263. "with tf.Graph().as_default() as restore_graph:\n",
  264. "    saver = tf.train.import_meta_graph(meta_file_path, clear_devices=True)\n",
  265. "restore_graph_def = restore_graph.as_graph_def()"
  266. ]
  267. },
  268. {
  269. "cell_type": "code",
  270. "execution_count": 11,
  271. "metadata": {},
  272. "outputs": [
  273. {
  274. "name": "stdout",
  275. "output_type": "stream",
  276. "text": [
  277. "INFO:tensorflow:Restoring parameters from model/mnist_example.chkp-19999\n",
  278. "INFO:tensorflow:Froze 6 variables.\n",
  279. "Converted 6 variables to const ops.\n"
  280. ]
  281. }
  282. ],
  283. "source": [
  284. "# Restore the trained weights, then freeze the graph: convert_variables_to_constants\n",
  285. "# replaces each variable with a constant holding its restored (\"frozen\") value.\n",
  286. "# You must name the output node(s); TensorFlow keeps only the subgraph they depend on.\n",
  287. "# The name is spelled out so this cell no longer reaches back into train_graph via `predict`.\n",
  288. "OUTPUT_NODE = \"output/prediction\"  # op name given to `predict` when the graph was built\n",
  289. "with tf.Session(graph=restore_graph) as sess:\n",
  290. "    saver.restore(sess, checkpoint.model_checkpoint_path)\n",
  291. "    out_graph_def = graph_util.convert_variables_to_constants(\n",
  292. "        sess, restore_graph_def, [OUTPUT_NODE])"
  293. ]
  294. },
  295. {
  296. "cell_type": "code",
  297. "execution_count": 12,
  298. "metadata": {
  299. "collapsed": true
  300. },
  301. "outputs": [],
  302. "source": [
  303. "# A frozen GraphDef is an ordinary protobuf message: serialize it and write the raw bytes\n",
  304. "with tf.gfile.GFile(\"my_mnist.pb\", mode=\"wb\") as pb_file:\n",
  305. "    pb_file.write(out_graph_def.SerializeToString())"
  306. ]
  307. },
  308. {
  309. "cell_type": "markdown",
  310. "metadata": {},
  311. "source": [
  312. "------\n",
  313. "\n",
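  314. "Testing time! Load the frozen graph back from `my_mnist.pb` and check that it reproduces the test accuracy of the trained model."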
  315. ]
  316. },
  317. {
  318. "cell_type": "code",
  319. "execution_count": 13,
  320. "metadata": {
  321. "collapsed": true
  322. },
  323. "outputs": [],
  324. "source": [
  325. "# an empty graph to import into, and an empty GraphDef message to parse the .pb file into\n",
  326. "graph = tf.Graph()\n",
  327. "graph_def = tf.GraphDef()"
  327. ]
  328. },
  329. {
  330. "cell_type": "code",
  331. "execution_count": 14,
  332. "metadata": {
  333. "collapsed": true
  334. },
  335. "outputs": [],
  336. "source": [
  337. "# Read the frozen model's bytes, parse them into graph_def, then load it into `graph`.\n",
  338. "with tf.gfile.GFile(\"my_mnist.pb\", \"rb\") as pb_file:\n",
  339. "    serialized = pb_file.read()\n",
  340. "graph_def.ParseFromString(serialized)\n",
  341. "with graph.as_default():\n",
  342. "    # name=\"\" keeps the original op names (the default would prefix them with \"import/\")\n",
  343. "    tf.import_graph_def(graph_def, name=\"\")"
  341. ]
  342. },
  343. {
  344. "cell_type": "code",
  345. "execution_count": 15,
  346. "metadata": {},
  347. "outputs": [],
  348. "source": [
  349. "# \"<op name>:0\" is the first output tensor of that op in the imported graph\n",
  350. "images_tensor, predict_tensor = (graph.get_tensor_by_name(tensor_name) for tensor_name in\n",
  351. "                                 (\"input/image_batch:0\", \"output/prediction:0\"))"
  351. ]
  352. },
  353. {
  354. "cell_type": "code",
  355. "execution_count": 16,
  356. "metadata": {},
  357. "outputs": [
  358. {
  359. "name": "stdout",
  360. "output_type": "stream",
  361. "text": [
  362. "86.5\n"
  363. ]
  364. }
  365. ],
  366. "source": [
  367. "# The frozen graph holds only constants, so no variable initializer is needed.\n",
  368. "with tf.Session(graph=graph) as sess:\n",
  369. "    pred = sess.run(predict_tensor,\n",
  370. "                    feed_dict={images_tensor: mnist.test.images})\n",
  371. "print(accuracy(np.argmax(mnist.test.labels, axis=1), pred)*100)"
  372. ]
  373. },
  374. {
  375. "cell_type": "code",
  376. "execution_count": null,
  377. "metadata": {
  378. "collapsed": true
  379. },
  380. "outputs": [],
  381. "source": []
  382. }
  383. ],
  384. "metadata": {
  385. "kernelspec": {
  386. "display_name": "Python 3",
  387. "language": "python",
  388. "name": "python3"
  389. },
  390. "language_info": {
  391. "codemirror_mode": {
  392. "name": "ipython",
  393. "version": 3
  394. },
  395. "file_extension": ".py",
  396. "mimetype": "text/x-python",
  397. "name": "python",
  398. "nbconvert_exporter": "python",
  399. "pygments_lexer": "ipython3",
  400. "version": "3.6.1"
  401. }
  402. },
  403. "nbformat": 4,
  404. "nbformat_minor": 2
  405. }
Add Comment
Please, Sign In to add comment