import numpy as np
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.examples.tutorials.mnist import input_data


batch_size = 64
height = 28
width = 28
channel = 1
f_channel = 100   # dimensionality of the noise vector z

num_epochs = 5
learning_rate = 0.003

g_summary = []
d_summary = []

# ---- Placeholders
inputs = tf.placeholder(
    tf.float32, [batch_size, height, width, channel], name='inputs')
z = tf.placeholder(tf.float32, [None, f_channel], name='z')

# ---- Network (Generator)
def generator(target):
    with tf.variable_scope('generator'):
        h1, w1 = height // 4, width // 4   # 7 x 7 after the projection
        d1, d2 = 256, 128

        # Project z to a 7x7x256 feature map, then upsample twice:
        # 7x7x256 -> 14x14x128 -> 28x28x1.
        proj = layers.linear(target, h1 * w1 * d1, scope='projection')
        reshape_proj = tf.reshape(proj, [batch_size, h1, w1, d1])

        deconv1 = layers.conv2d_transpose(reshape_proj, d2, [3, 3], [2, 2], 'SAME')
        deconv2 = layers.conv2d_transpose(
            deconv1, 1, [5, 5], [2, 2], 'SAME',
            activation_fn=tf.nn.sigmoid)   # keep the output in [0, 1] like the data
        g_summary.append(tf.summary.image('generated', deconv2))
        return deconv2

# ---- Network (Discriminator)
def discriminator(inputs, reuse):
    with tf.variable_scope('discriminator') as sc:
        if reuse:
            sc.reuse_variables()

        conv1 = layers.conv2d(inputs, 64, [3, 3], 2, scope='conv1')
        conv2 = layers.conv2d(conv1, 128, [3, 3], 2, scope='conv2')
        conv3 = layers.conv2d(conv2, 256, [3, 3], 2, scope='conv3')
        proj = layers.linear(
            tf.reshape(conv3, [batch_size, -1]), 1, scope='proj')
        final = tf.sigmoid(proj, name='final')
        # 0 = likely fake --------------- 1 = likely real
        return final

# ---- Network
G = generator(z)

D1 = discriminator(G, reuse=False)       # D(G(z)): scores for generated images
D2 = discriminator(inputs, reuse=True)   # D(x): scores for real images (shared weights)

# ---- Loss
with tf.name_scope('discriminator'):
    # Minimize -[log D(x) + log(1 - D(G(z)))], the standard discriminator loss.
    d_loss = -tf.reduce_mean(
        tf.log(1 - D1) + tf.log(D2), name='d_loss')
    d_summary.append(tf.summary.scalar('d_loss', d_loss))

with tf.name_scope('generator'):
    # Non-saturating generator loss: minimize -log D(G(z)).
    g_loss = -tf.reduce_mean(tf.log(D1), name='g_loss')
    g_summary.append(tf.summary.scalar('g_loss', g_loss))

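# NOTE: taking log() of a sigmoid output can underflow to log(0). A more
# numerically stable variant (a sketch only; it assumes the discriminator were
# changed to also return its pre-sigmoid logits, called real_logits and
# fake_logits here) would be:
#
#   d_loss = tf.reduce_mean(
#       tf.nn.sigmoid_cross_entropy_with_logits(
#           labels=tf.ones_like(real_logits), logits=real_logits) +
#       tf.nn.sigmoid_cross_entropy_with_logits(
#           labels=tf.zeros_like(fake_logits), logits=fake_logits))
#   g_loss = tf.reduce_mean(
#       tf.nn.sigmoid_cross_entropy_with_logits(
#           labels=tf.ones_like(fake_logits), logits=fake_logits))
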
# ---- Optimization
with tf.variable_scope('optimizer'):
    d_step = tf.Variable(0, name='d_step', trainable=False)
    g_step = tf.Variable(0, name='g_step', trainable=False)

    # Split the trainable variables by the variable_scope given to each network.
    d_vars = [v for v in tf.trainable_variables() if v.name.startswith('discriminator')]
    g_vars = [v for v in tf.trainable_variables() if v.name.startswith('generator')]

    d_opt = tf.train.AdamOptimizer(learning_rate, name='d_opt').minimize(
        d_loss, global_step=d_step, var_list=d_vars)
    g_opt = tf.train.AdamOptimizer(learning_rate, name='g_opt').minimize(
        g_loss, global_step=g_step, var_list=g_vars)

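# NOTE: the same split can also be done with the graph collections API
# (a sketch, relying on the variable_scope names used above):
#
#   d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
#   g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
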
# ---- Summary
d_train_summary = tf.summary.merge(d_summary)
g_train_summary = tf.summary.merge(g_summary)

def get_log_directory_name():
    from datetime import datetime
    return './logs/%s_%s/' % (
        int(datetime.now().timestamp()), learning_rate)

# ---- Start!
session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True

with tf.Session(config=session_config) as sess:
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    summ_writer = tf.summary.FileWriter(get_log_directory_name(), sess.graph)

    # read_data_sets already returns float32 pixels scaled to [0, 1],
    # so no extra division by 255 is needed.
    mnist = input_data.read_data_sets("MNIST_data/")
    imgs = np.reshape(mnist.train.images, [-1, height, width, channel])
    for epoch in range(num_epochs):
        batch_counts = len(imgs) // batch_size
        np.random.shuffle(imgs)

        for idx in range(batch_counts):
            feed_dict = {
                inputs: imgs[idx * batch_size:(idx + 1) * batch_size],
                z: np.random.uniform(0, 1, size=(batch_size, f_channel))
            }

            # ---- Discriminator Training
            fetch = {
                'loss': d_loss,
                'optim': d_opt,
                'step': d_step,
                'summary': d_train_summary
            }

            d_result = sess.run(fetch, feed_dict=feed_dict)
            summ_writer.add_summary(d_result['summary'], d_result['step'])
            summ_writer.flush()

            print('[%02d _ %4d/%04d] d_loss: %.8f'
                  % (epoch, idx + 1, batch_counts, d_result['loss']))

            # ---- Generator Training
            fetch = {
                'loss': g_loss,
                'optim': g_opt,
                'step': g_step,
                'summary': g_train_summary
            }

            # Train the generator several times per discriminator step,
            # tapering off in later epochs.
            for _ in range(num_epochs - epoch):
                g_result = sess.run(fetch, feed_dict=feed_dict)
                summ_writer.add_summary(g_result['summary'], g_result['step'])
                summ_writer.flush()

                print('[%02d _ %4d/%04d] g_loss: %.8f'
                      % (epoch, idx + 1, batch_counts, g_result['loss']))
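
        # The Saver created above is never used in the original paste. A
        # minimal per-epoch checkpointing sketch (the './checkpoints/gan'
        # path is an assumption) would be:
        #
        #   saver.save(sess, './checkpoints/gan', global_step=epoch)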