Advertisement
Guest User

Untitled

a guest
Sep 29th, 2016
90
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.26 KB | None | 0 0
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "## 加载包"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": 1,
  13. "metadata": {
  14. "collapsed": false
  15. },
  16. "outputs": [
  17. {
  18. "name": "stdout",
  19. "output_type": "stream",
  20. "text": [
  21. "Packages loaded\n"
  22. ]
  23. }
  24. ],
  25. "source": [
  26. "import scipy.io\n",
  27. "import numpy as np \n",
  28. "import os \n",
  29. "import scipy.misc \n",
  30. "import matplotlib.pyplot as plt \n",
  31. "import tensorflow as tf\n",
  32. "%matplotlib inline \n",
  33. "print (\"Packages loaded\")"
  34. ]
  35. },
  36. {
  37. "cell_type": "markdown",
  38. "metadata": {},
  39. "source": [
  40. "## 定义网络结构"
  41. ]
  42. },
  43. {
  44. "cell_type": "code",
  45. "execution_count": 2,
  46. "metadata": {
  47. "collapsed": true
  48. },
  49. "outputs": [],
  50. "source": [
# Width and height (in pixels) of the network input; read_image crops to
# this size and build_vgg19 allocates its input variable with it.
IMAGE_W = 800
IMAGE_H = 600
cwd = os.getcwd()
# Content image file.
CONTENT_IMG = cwd + "/images/Taipei101.jpg"
# Style image file.
STYLE_IMG = cwd + "/images/StarryNight.jpg"
# Output directory and file name for the result.
OUTOUT_DIR = './images'
OUTPUT_IMG = 'results.png'
# VGG model file.
VGG_MODEL = cwd + "/data/imagenet-vgg-verydeep-19.mat"
# Fraction of random noise blended into the content image to form the
# initial input (see main).
INI_NOISE_RATIO = 0.7
# Multiplier applied to the style loss when forming the total loss.
STYLE_STRENGTH = 500
# NOTE(review): not referenced by the code visible in this notebook.
ITERATION = 5000

# (layer name, weight) pairs used for the content and style losses.
CONTENT_LAYERS =[('conv4_2',1.)]
STYLE_LAYERS=[('conv1_1',1.),('conv2_1',1.5),('conv3_1',2.),('conv4_1',2.5),('conv5_1',3.)]


# Per-channel values subtracted by read_image and added back by
# write_image; shaped (1, 1, 1, 3) so they broadcast over an image batch.
MEAN_VALUES = np.array([123, 117, 104]).reshape((1,1,1,3))
  72. ]
  73. },
  74. {
  75. "cell_type": "code",
  76. "execution_count": 3,
  77. "metadata": {
  78. "collapsed": true
  79. },
  80. "outputs": [],
  81. "source": [
  82. "# 定义前向计算函数,如果是conv层则计算卷积,如果是pool则进行池化\n",
  83. "def build_net(ntype, nin, nwb=None):\n",
  84. " if ntype == 'conv':\n",
  85. " return tf.nn.relu(tf.nn.conv2d(nin, nwb[0], strides=[1, 1, 1, 1], padding='SAME')+ nwb[1])\n",
  86. " elif ntype == 'pool':\n",
  87. " return tf.nn.avg_pool(nin, ksize=[1, 2, 2, 1],\n",
  88. " strides=[1, 2, 2, 1], padding='SAME')\n",
  89. "\n",
  90. "# 从VGG模型中提取参数\n",
  91. "def get_weight_bias(vgg_layers, i,):\n",
  92. " weights = vgg_layers[i][0][0][0][0][0]\n",
  93. " weights = tf.constant(weights)\n",
  94. " bias = vgg_layers[i][0][0][0][0][1]\n",
  95. " bias = tf.constant(np.reshape(bias, (bias.size)))\n",
  96. " return weights, bias\n",
  97. "\n",
  98. "# 构建VGG模型网络结构,从现成的VGG模型文档中读取参数\n",
  99. "# 以conv1_1层参数为例,长下面这个样子\n",
  100. "# (<tf.Tensor 'Const_83:0' shape=(3, 3, 3, 64) dtype=float32>,\n",
  101. "# <tf.Tensor 'Const_84:0' shape=(64,) dtype=float32>)\n",
  102. "# conv1_1层输出长下面这个样子\n",
  103. "# <tf.Tensor 'Relu_32:0' shape=(1, 600, 800, 64) dtype=float32>\n",
  104. "\n",
  105. "def build_vgg19(path):\n",
  106. " net = {}\n",
  107. " vgg_rawnet = scipy.io.loadmat(path)\n",
  108. " vgg_layers = vgg_rawnet['layers'][0]\n",
  109. " net['input'] = tf.Variable(np.zeros((1, IMAGE_H, IMAGE_W, 3)).astype('float32'))\n",
  110. " net['conv1_1'] = build_net('conv',net['input'],get_weight_bias(vgg_layers,0))\n",
  111. " net['conv1_2'] = build_net('conv',net['conv1_1'],get_weight_bias(vgg_layers,2))\n",
  112. " net['pool1'] = build_net('pool',net['conv1_2'])\n",
  113. " net['conv2_1'] = build_net('conv',net['pool1'],get_weight_bias(vgg_layers,5))\n",
  114. " net['conv2_2'] = build_net('conv',net['conv2_1'],get_weight_bias(vgg_layers,7))\n",
  115. " net['pool2'] = build_net('pool',net['conv2_2'])\n",
  116. " net['conv3_1'] = build_net('conv',net['pool2'],get_weight_bias(vgg_layers,10))\n",
  117. " net['conv3_2'] = build_net('conv',net['conv3_1'],get_weight_bias(vgg_layers,12))\n",
  118. " net['conv3_3'] = build_net('conv',net['conv3_2'],get_weight_bias(vgg_layers,14))\n",
  119. " net['conv3_4'] = build_net('conv',net['conv3_3'],get_weight_bias(vgg_layers,16))\n",
  120. " net['pool3'] = build_net('pool',net['conv3_4'])\n",
  121. " net['conv4_1'] = build_net('conv',net['pool3'],get_weight_bias(vgg_layers,19))\n",
  122. " net['conv4_2'] = build_net('conv',net['conv4_1'],get_weight_bias(vgg_layers,21))\n",
  123. " net['conv4_3'] = build_net('conv',net['conv4_2'],get_weight_bias(vgg_layers,23))\n",
  124. " net['conv4_4'] = build_net('conv',net['conv4_3'],get_weight_bias(vgg_layers,25))\n",
  125. " net['pool4'] = build_net('pool',net['conv4_4'])\n",
  126. " net['conv5_1'] = build_net('conv',net['pool4'],get_weight_bias(vgg_layers,28))\n",
  127. " net['conv5_2'] = build_net('conv',net['conv5_1'],get_weight_bias(vgg_layers,30))\n",
  128. " net['conv5_3'] = build_net('conv',net['conv5_2'],get_weight_bias(vgg_layers,32))\n",
  129. " net['conv5_4'] = build_net('conv',net['conv5_3'],get_weight_bias(vgg_layers,34))\n",
  130. " net['pool5'] = build_net('pool',net['conv5_4'])\n",
  131. " return net\n",
  132. "\n",
  133. "# 内容损失函数\n",
  134. "def build_content_loss(p, x):\n",
  135. " M = p.shape[1]*p.shape[2]\n",
  136. " N = p.shape[3]\n",
  137. " loss = (1./(2* N**0.5 * M**0.5 )) * tf.reduce_sum(tf.pow((x - p),2)) \n",
  138. " return loss\n",
  139. "\n",
  140. "\n",
  141. "def gram_matrix(x, area, depth):\n",
  142. " x1 = tf.reshape(x,(area,depth))\n",
  143. " g = tf.matmul(tf.transpose(x1), x1)\n",
  144. " return g\n",
  145. "\n",
  146. "def gram_matrix_val(x, area, depth):\n",
  147. " x1 = x.reshape(area,depth)\n",
  148. " g = np.dot(x1.T, x1)\n",
  149. " return g\n",
  150. "\n",
  151. "# 风格损失函数,A为风格标准图片,G为训练后的结果图片\n",
  152. "def build_style_loss(a, x):\n",
  153. " M = a.shape[1]*a.shape[2]\n",
  154. " N = a.shape[3]\n",
  155. " A = gram_matrix_val(a, M, N )\n",
  156. " G = gram_matrix(x, M, N )\n",
  157. " loss = (1./(4 * N**2 * M**2)) * tf.reduce_sum(tf.pow((G - A),2))\n",
  158. " return loss\n",
  159. "\n",
  160. "\n",
  161. "# 读取图片函数,同时做白化\n",
  162. "def read_image(path):\n",
  163. " image = scipy.misc.imread(path)\n",
  164. " image = image[np.newaxis,:IMAGE_H,:IMAGE_W,:] \n",
  165. " image = image - MEAN_VALUES\n",
  166. " return image\n",
  167. "\n",
  168. "# 写图片函数\n",
  169. "def write_image(path, image):\n",
  170. " image = image + MEAN_VALUES\n",
  171. " image = image[0]\n",
  172. " image = np.clip(image, 0, 255).astype('uint8')\n",
  173. " scipy.misc.imsave(path, image)\n"
  174. ]
  175. },
  176. {
  177. "cell_type": "markdown",
  178. "metadata": {},
  179. "source": [
  180. "## 定义主函数"
  181. ]
  182. },
  183. {
  184. "cell_type": "code",
  185. "execution_count": 18,
  186. "metadata": {
  187. "collapsed": true
  188. },
  189. "outputs": [],
  190. "source": [
  191. "def main():\n",
  192. " net = build_vgg19(VGG_MODEL)\n",
  193. " sess = tf.Session()\n",
  194. " sess.run(tf.initialize_all_variables())\n",
  195. "# 建立一个纯噪音图片做为训练参数,使内容符合内容图片,而风格符合风格图片\n",
  196. " noise_img = np.random.uniform(-20, 20, (1, IMAGE_H, IMAGE_W, 3)).astype('float32')\n",
  197. " content_img = read_image(CONTENT_IMG)\n",
  198. " style_img = read_image(STYLE_IMG)\n",
  199. "# 将内容图片输入到VGG网络中,取出conv4_2层输出结果,计算内容损失\n",
  200. " sess.run([net['input'].assign(content_img)])\n",
  201. " cost_content = sum(map(lambda l,: l[1]*build_content_loss(sess.run(net[l[0]]) , net[l[0]])\n",
  202. " , CONTENT_LAYERS))\n",
  203. "# 将风格图片输入到VGG网络中,取出conv1_1-conv5_1五个层的输出结果,计算风格损失\n",
  204. " sess.run([net['input'].assign(style_img)])\n",
  205. " cost_style = sum(map(lambda l: l[1]*build_style_loss(sess.run(net[l[0]]) , net[l[0]])\n",
  206. " , STYLE_LAYERS))\n",
  207. "# 加总两种损失做为最小化训练目标,用cost_style做为调整系数\n",
  208. " cost_total = cost_content + STYLE_STRENGTH * cost_style\n",
  209. " optimizer = tf.train.AdamOptimizer(2.0)\n",
  210. "\n",
  211. " train = optimizer.minimize(cost_total)\n",
  212. " sess.run(tf.initialize_all_variables())\n",
  213. "# 把内容图片加噪音后,做为VGG网络输入层,算法将学习去调整这个输入层,来使得训练目标最小\n",
  214. " sess.run(net['input'].assign( INI_NOISE_RATIO* noise_img + (1.-INI_NOISE_RATIO) * content_img))\n",
  215. "\n",
  216. " if not os.path.exists(OUTOUT_DIR):\n",
  217. " os.mkdir(OUTOUT_DIR)\n",
  218. "\n",
  219. " for i in range(500):\n",
  220. " sess.run(train)\n",
  221. " print i\n",
  222. " if i%100 ==0:\n",
  223. " result_img = sess.run(net['input'])\n",
  224. " print sess.run(cost_total)\n",
  225. " write_image(os.path.join(OUTOUT_DIR,'%s.png'%(str(i).zfill(4))),result_img)\n",
  226. " \n",
  227. " write_image(os.path.join(OUTOUT_DIR,OUTPUT_IMG),result_img)\n"
  228. ]
  229. },
  230. {
  231. "cell_type": "code",
  232. "execution_count": null,
  233. "metadata": {
  234. "collapsed": true
  235. },
  236. "outputs": [],
  237. "source": [
# Run the style-transfer optimisation defined in main().
main()
  239. ]
  240. },
  241. {
  242. "cell_type": "code",
  243. "execution_count": null,
  244. "metadata": {
  245. "collapsed": true
  246. },
  247. "outputs": [],
  248. "source": []
  249. }
  250. ],
  251. "metadata": {
  252. "kernelspec": {
  253. "display_name": "Python 2",
  254. "language": "python",
  255. "name": "python2"
  256. },
  257. "language_info": {
  258. "codemirror_mode": {
  259. "name": "ipython",
  260. "version": 2
  261. },
  262. "file_extension": ".py",
  263. "mimetype": "text/x-python",
  264. "name": "python",
  265. "nbconvert_exporter": "python",
  266. "pygments_lexer": "ipython2",
  267. "version": "2.7.11"
  268. }
  269. },
  270. "nbformat": 4,
  271. "nbformat_minor": 0
  272. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement