Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 加载包"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Packages loaded\n"
- ]
- }
- ],
- "source": [
- "import scipy.io\n",
- "import numpy as np \n",
- "import os \n",
- "import scipy.misc \n",
- "import matplotlib.pyplot as plt \n",
- "import tensorflow as tf\n",
- "%matplotlib inline \n",
- "print (\"Packages loaded\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 定义网络结构"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Width and height, in pixels, of the network input tensor.\n",
- "IMAGE_W = 800\n",
- "IMAGE_H = 600\n",
- "\n",
- "cwd = os.getcwd()\n",
- "\n",
- "# Content image file.\n",
- "CONTENT_IMG = cwd + \"/images/Taipei101.jpg\"\n",
- "# Style image file.\n",
- "STYLE_IMG = cwd + \"/images/StarryNight.jpg\"\n",
- "# Directory and file name of the generated result.\n",
- "OUTOUT_DIR = './images'\n",
- "OUTPUT_IMG = 'results.png'\n",
- "# VGG model file.\n",
- "VGG_MODEL = cwd + \"/data/imagenet-vgg-verydeep-19.mat\"\n",
- "\n",
- "# Share of random noise blended into the content image to form the initial input.\n",
- "INI_NOISE_RATIO = 0.7\n",
- "# Multiplier applied to the style loss when it is added to the content loss.\n",
- "STYLE_STRENGTH = 500\n",
- "# Presumably the number of optimisation steps -- verify against the loop in main().\n",
- "ITERATION = 5000\n",
- "\n",
- "# (layer name, weight) pairs that contribute to the content and style losses.\n",
- "CONTENT_LAYERS = [('conv4_2', 1.)]\n",
- "STYLE_LAYERS = [\n",
- "    ('conv1_1', 1.),\n",
- "    ('conv2_1', 1.5),\n",
- "    ('conv3_1', 2.),\n",
- "    ('conv4_1', 2.5),\n",
- "    ('conv5_1', 3.),\n",
- "]\n",
- "\n",
- "# Per-channel offsets subtracted from images on read and added back on write.\n",
- "MEAN_VALUES = np.array([123, 117, 104]).reshape((1, 1, 1, 3))\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# 定义前向计算函数,如果是conv层则计算卷积,如果是pool则进行池化\n",
- "def build_net(ntype, nin, nwb=None):\n",
- " if ntype == 'conv':\n",
- " return tf.nn.relu(tf.nn.conv2d(nin, nwb[0], strides=[1, 1, 1, 1], padding='SAME')+ nwb[1])\n",
- " elif ntype == 'pool':\n",
- " return tf.nn.avg_pool(nin, ksize=[1, 2, 2, 1],\n",
- " strides=[1, 2, 2, 1], padding='SAME')\n",
- "\n",
- "# Extract the parameters of layer `i` from the loaded VGG model.\n",
- "def get_weight_bias(vgg_layers, i):\n",
- "    '''Return (weights, bias) of layer `i` as TensorFlow constants; the bias is flattened to 1-D.'''\n",
- "    params = vgg_layers[i][0][0][0][0]\n",
- "    kernel = tf.constant(params[0])\n",
- "    raw_bias = params[1]\n",
- "    flat_bias = tf.constant(np.reshape(raw_bias, (raw_bias.size)))\n",
- "    return kernel, flat_bias\n",
- "\n",
- "# Build the VGG network graph, reading the parameters from an existing VGG model file.\n",
- "# Taking conv1_1 as an example, its parameters look like this:\n",
- "# (<tf.Tensor 'Const_83:0' shape=(3, 3, 3, 64) dtype=float32>,\n",
- "# <tf.Tensor 'Const_84:0' shape=(64,) dtype=float32>)\n",
- "# and the conv1_1 output looks like this:\n",
- "# <tf.Tensor 'Relu_32:0' shape=(1, 600, 800, 64) dtype=float32>\n",
- "\n",
- "def build_vgg19(path):\n",
- "    '''Return a dict mapping layer names ('input', 'conv1_1', ..., 'pool5') to tensors.\n",
- "\n",
- "    path -- location of the .mat file holding the VGG parameters.\n",
- "    '''\n",
- "    vgg_layers = scipy.io.loadmat(path)['layers'][0]\n",
- "    # (layer name, index into vgg_layers); an index of None marks a pooling layer.\n",
- "    layout = (\n",
- "        ('conv1_1', 0), ('conv1_2', 2), ('pool1', None),\n",
- "        ('conv2_1', 5), ('conv2_2', 7), ('pool2', None),\n",
- "        ('conv3_1', 10), ('conv3_2', 12), ('conv3_3', 14), ('conv3_4', 16), ('pool3', None),\n",
- "        ('conv4_1', 19), ('conv4_2', 21), ('conv4_3', 23), ('conv4_4', 25), ('pool4', None),\n",
- "        ('conv5_1', 28), ('conv5_2', 30), ('conv5_3', 32), ('conv5_4', 34), ('pool5', None),\n",
- "    )\n",
- "    # The input is a Variable so that it can be assigned to and optimised later.\n",
- "    net = {'input': tf.Variable(np.zeros((1, IMAGE_H, IMAGE_W, 3)).astype('float32'))}\n",
- "    previous = net['input']\n",
- "    for name, index in layout:\n",
- "        if index is None:\n",
- "            previous = build_net('pool', previous)\n",
- "        else:\n",
- "            previous = build_net('conv', previous, get_weight_bias(vgg_layers, index))\n",
- "        net[name] = previous\n",
- "    return net\n",
- "\n",
- "# Content loss.\n",
- "def build_content_loss(p, x):\n",
- "    '''Return the scaled sum of squared differences between `x` and `p`.\n",
- "\n",
- "    p -- array of layer activations; its shape supplies the scaling factor.\n",
- "    x -- tensor producing the same layer's activations for the optimised input.\n",
- "    '''\n",
- "    area = p.shape[1] * p.shape[2]\n",
- "    channels = p.shape[3]\n",
- "    scale = 1. / (2 * channels**0.5 * area**0.5)\n",
- "    squared_error = tf.reduce_sum(tf.pow((x - p), 2))\n",
- "    return scale * squared_error\n",
- "\n",
- "\n",
- "def gram_matrix(x, area, depth):\n",
- "    '''Gram matrix of tensor `x`: reshape to (area, depth) and return F^T F as a tensor.'''\n",
- "    features = tf.reshape(x, (area, depth))\n",
- "    return tf.matmul(tf.transpose(features), features)\n",
- "\n",
- "def gram_matrix_val(x, area, depth):\n",
- " x1 = x.reshape(area,depth)\n",
- " g = np.dot(x1.T, x1)\n",
- " return g\n",
- "\n",
- "# Style loss: `a` holds the style image's activations, `x` is the tensor for the image being optimised.\n",
- "def build_style_loss(a, x):\n",
- "    '''Return the scaled squared distance between the Gram matrices of `a` and `x`.\n",
- "\n",
- "    a -- array of layer activations (Gram matrix computed with NumPy).\n",
- "    x -- tensor for the same layer (Gram matrix computed in the graph).\n",
- "    '''\n",
- "    area = a.shape[1] * a.shape[2]\n",
- "    depth = a.shape[3]\n",
- "    style_gram = gram_matrix_val(a, area, depth)\n",
- "    generated_gram = gram_matrix(x, area, depth)\n",
- "    normaliser = 1. / (4 * depth**2 * area**2)\n",
- "    return normaliser * tf.reduce_sum(tf.pow((generated_gram - style_gram), 2))\n",
- "\n",
- "\n",
- "# Read an image, crop it to the network input size and subtract the channel means.\n",
- "def read_image(path):\n",
- "    '''Load `path` and return an array of shape (1, IMAGE_H, IMAGE_W, 3).\n",
- "\n",
- "    The image is cropped (not resized) to its top-left IMAGE_H x IMAGE_W region,\n",
- "    any channel beyond the first three (e.g. alpha) is dropped, and MEAN_VALUES\n",
- "    is subtracted.\n",
- "\n",
- "    Raises ValueError when the image has fewer than three channels or is smaller\n",
- "    than the required size, instead of failing later with a shape error.\n",
- "    '''\n",
- "    # NOTE(review): scipy.misc.imread is not available in newer SciPy releases;\n",
- "    # replacing it requires an extra dependency, so it is kept here.\n",
- "    image = scipy.misc.imread(path)\n",
- "    if image.ndim != 3 or image.shape[2] < 3:\n",
- "        raise ValueError('%s is not an RGB image (shape %r)' % (path, image.shape))\n",
- "    if image.shape[0] < IMAGE_H or image.shape[1] < IMAGE_W:\n",
- "        raise ValueError('%s is %dx%d, smaller than the required %dx%d'\n",
- "                         % (path, image.shape[1], image.shape[0], IMAGE_W, IMAGE_H))\n",
- "    # Add the batch axis, crop, and keep only the first three channels.\n",
- "    image = image[np.newaxis, :IMAGE_H, :IMAGE_W, :3]\n",
- "    return image - MEAN_VALUES\n",
- "\n",
- "# Write an image to disk.\n",
- "def write_image(path, image):\n",
- "    '''Add MEAN_VALUES back, take the first batch entry, clip to 0..255 and save as uint8.'''\n",
- "    restored = (image + MEAN_VALUES)[0]\n",
- "    pixels = np.clip(restored, 0, 255).astype('uint8')\n",
- "    scipy.misc.imsave(path, pixels)\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 定义主函数"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "def main(iterations=500, save_every=100):\n",
- "    '''Run style transfer and write the intermediate and final images to OUTOUT_DIR.\n",
- "\n",
- "    iterations -- number of optimisation steps (default 500, the previously\n",
- "                  hard-coded value; pass ITERATION to use the configured one).\n",
- "    save_every -- a snapshot image and the current cost are emitted every\n",
- "                  `save_every` steps.\n",
- "    '''\n",
- "    net = build_vgg19(VGG_MODEL)\n",
- "    # The context manager closes the session even if a step fails.\n",
- "    with tf.Session() as sess:\n",
- "        sess.run(tf.initialize_all_variables())\n",
- "        # Pure-noise image; the optimised input should match the content image\n",
- "        # in content and the style image in style.\n",
- "        noise_img = np.random.uniform(-20, 20, (1, IMAGE_H, IMAGE_W, 3)).astype('float32')\n",
- "        content_img = read_image(CONTENT_IMG)\n",
- "        style_img = read_image(STYLE_IMG)\n",
- "\n",
- "        # Feed the content image through VGG and build the content loss from\n",
- "        # the activations of the layers listed in CONTENT_LAYERS.\n",
- "        sess.run([net['input'].assign(content_img)])\n",
- "        cost_content = sum(weight * build_content_loss(sess.run(net[name]), net[name])\n",
- "                           for name, weight in CONTENT_LAYERS)\n",
- "\n",
- "        # Feed the style image through VGG and build the style loss from the\n",
- "        # activations of the layers listed in STYLE_LAYERS.\n",
- "        sess.run([net['input'].assign(style_img)])\n",
- "        cost_style = sum(weight * build_style_loss(sess.run(net[name]), net[name])\n",
- "                         for name, weight in STYLE_LAYERS)\n",
- "\n",
- "        # Total objective: content loss plus the style loss scaled by STYLE_STRENGTH.\n",
- "        cost_total = cost_content + STYLE_STRENGTH * cost_style\n",
- "        optimizer = tf.train.AdamOptimizer(2.0)\n",
- "        train = optimizer.minimize(cost_total)\n",
- "        # Initialise again so the variables created by the optimizer are set up.\n",
- "        sess.run(tf.initialize_all_variables())\n",
- "\n",
- "        # Start from the content image blended with noise; the optimizer adjusts\n",
- "        # this input to minimise the total cost.\n",
- "        sess.run(net['input'].assign(INI_NOISE_RATIO * noise_img + (1. - INI_NOISE_RATIO) * content_img))\n",
- "\n",
- "        if not os.path.exists(OUTOUT_DIR):\n",
- "            os.mkdir(OUTOUT_DIR)\n",
- "\n",
- "        for i in range(iterations):\n",
- "            sess.run(train)\n",
- "            print(i)\n",
- "            if i % save_every == 0:\n",
- "                result_img = sess.run(net['input'])\n",
- "                print(sess.run(cost_total))\n",
- "                write_image(os.path.join(OUTOUT_DIR, '%s.png' % (str(i).zfill(4))), result_img)\n",
- "\n",
- "        # Fetch the input after the last step: the snapshot taken inside the loop\n",
- "        # is stale (or missing when the loop did not run).\n",
- "        result_img = sess.run(net['input'])\n",
- "        write_image(os.path.join(OUTOUT_DIR, OUTPUT_IMG), result_img)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "main()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "language": "python",
- "name": "python2"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.11"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement