import numpy as np
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.examples.tutorials.mnist import input_data

# ---- Hyperparameters
batch_size = 64
height = 28
width = 28
channel = 1
f_channel = 100  # dimensionality of the latent noise vector z
num_epochs = 5
learning_rate = 0.003
g_summary = []
d_summary = []
# ---- Placeholders
inputs = tf.placeholder(
    tf.float32, [batch_size, height, width, channel], name='inputs')
z = tf.placeholder(tf.float32, [None, f_channel], name='z')
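# `inputs` receives a batch of real MNIST images; `z` receives the latent
# noise that the generator maps to a fake image.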
# ---- Network (Generator)
def generator(target):
    with tf.variable_scope('generator'):
        h1, w1 = height // 4, width // 4
        h2, w2 = height // 2, width // 2
        d1, d2 = 256, 128
        proj = layers.linear(target, h1 * w1 * d1, scope='projection')
        reshape_proj = tf.reshape(proj, [batch_size, h1, w1, d1])
        deconv1 = layers.conv2d_transpose(reshape_proj, d2, [3, 3], [2, 2], 'SAME')
        # sigmoid (instead of the layer's default ReLU) keeps the output
        # in [0, 1], matching the pixel range of the real images
        deconv2 = layers.conv2d_transpose(deconv1, 1, [5, 5], [2, 2], 'SAME',
                                          activation_fn=tf.nn.sigmoid)
        g_summary.append(tf.summary.image('generated', deconv2))
        return deconv2
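# The generator projects z to a 7x7x256 feature map, then upsamples it with
# two stride-2 transposed convolutions: 7x7 -> 14x14 -> 28x28x1.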
# ---- Network (Discriminator)
def discriminator(inputs, reuse):
    with tf.variable_scope('discriminator') as sc:
        if reuse:
            sc.reuse_variables()
        conv1 = layers.conv2d(inputs, 64, [3, 3], 2, scope='conv1')
        conv2 = layers.conv2d(conv1, 128, [3, 3], 2, scope='conv2')
        conv3 = layers.conv2d(conv2, 256, [3, 3], 2, scope='conv3')
        proj = layers.linear(
            tf.reshape(conv3, [batch_size, -1]), 1, scope='proj')
        # output in (0, 1): close to 0 means fake-likely, close to 1 real-likely
        final = tf.sigmoid(proj, name='final')
        return final
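# Three stride-2 convolutions shrink the 28x28 input to 4x4 before the
# linear projection to a single probability.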
# ---- Network
G = generator(z)
D1 = discriminator(G, reuse=False)      # D(G(z)): discriminator on fakes
D2 = discriminator(inputs, reuse=True)  # D(x): same weights on real images
# ---- Loss
eps = 1e-8  # avoids log(0) when the discriminator saturates (added for stability)
with tf.name_scope('discriminator'):
    d_loss = - tf.reduce_mean(
        tf.log(1 - D1 + eps) + tf.log(D2 + eps), name='d_loss')
    d_summary.append(tf.summary.scalar('d_loss', d_loss))
with tf.name_scope('generator'):
    g_loss = - tf.reduce_mean(tf.log(D1 + eps), name='g_loss')
    g_summary.append(tf.summary.scalar('g_loss', g_loss))
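# Together these implement the standard (non-saturating) GAN objective:
#   d_loss = -E[log D(x)] - E[log(1 - D(G(z)))]
#   g_loss = -E[log D(G(z))]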
# ---- Optimization
with tf.variable_scope('optimizer'):
    d_step = tf.Variable(0, name='d_step', trainable=False)
    g_step = tf.Variable(0, name='g_step', trainable=False)
    # Variable names are prefixed by their variable_scope, so filtering on the
    # full scope name splits the weights between the two networks.
    d_vars = [v for v in tf.trainable_variables() if v.name.startswith('discriminator')]
    g_vars = [v for v in tf.trainable_variables() if v.name.startswith('generator')]
    d_opt = tf.train.AdamOptimizer(learning_rate, name='d_opt').minimize(
        d_loss, global_step=d_step, var_list=d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate, name='g_opt').minimize(
        g_loss, global_step=g_step, var_list=g_vars)
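# Passing var_list ensures each optimizer updates only its own network: the
# discriminator step leaves the generator's weights untouched, and vice versa.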
# ---- Summary
d_train_summary = tf.summary.merge(d_summary)
g_train_summary = tf.summary.merge(g_summary)

def get_log_directory_name():
    from datetime import datetime
    return './logs/%s_%s/' % (
        int(datetime.now().timestamp()), learning_rate)
# ---- Start!
session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True
with tf.Session(config=session_config) as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    summ_writer = tf.summary.FileWriter(get_log_directory_name(), sess.graph)
    # read_data_sets already scales pixels to [0, 1], so no further division
    # is needed (the original /255.0 nearly zeroed out the inputs).
    mnist = input_data.read_data_sets("MNIST_data/")
    imgs = np.reshape(mnist.train.images, [-1, 28, 28, 1]).astype(np.float32)
    for epoch in range(num_epochs):
        batch_counts = len(imgs) // batch_size
        np.random.shuffle(imgs)
        for idx in range(batch_counts):
            feed_dict = {
                inputs: imgs[idx*batch_size:(idx+1)*batch_size],
                z: np.random.uniform(0, 1, size=(batch_size, f_channel))
            }
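            # Noise is drawn uniformly from [0, 1); the same noise batch feeds
            # both the discriminator and generator updates below.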
            # ---- Discriminator Training
            fetch = {
                'loss': d_loss,
                'optim': d_opt,
                'step': d_step,
                'summary': d_train_summary
            }
            d_result = sess.run(fetch, feed_dict=feed_dict)
            summ_writer.add_summary(d_result['summary'], d_result['step'])
            summ_writer.flush()
            print('[%02d _ %4d/%04d] d_loss: %.8f' % (epoch, idx + 1, batch_counts, d_result['loss']))
            # ---- Generator Training
            fetch = {
                'loss': g_loss,
                'optim': g_opt,
                'step': g_step,
                'summary': g_train_summary
            }
            # Heuristic: run several generator steps per discriminator step,
            # tapering from num_epochs down to 1 as training progresses.
            for _ in range(num_epochs - epoch):
                g_result = sess.run(fetch, feed_dict=feed_dict)
                summ_writer.add_summary(g_result['summary'], g_result['step'])
                summ_writer.flush()
                print('[%02d _ %4d/%04d] g_loss: %.8f' % (epoch, idx + 1, batch_counts, g_result['loss']))
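    # Save the trained weights; the checkpoint path is an assumed choice,
    # since the original script creates `saver` but never calls it.
    saver.save(sess, './logs/gan.ckpt')
    # Minimal sketch of sampling from the trained generator with fresh noise:
    samples = sess.run(G, feed_dict={
        z: np.random.uniform(0, 1, size=(batch_size, f_channel))})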