Advertisement
Guest User

Untitled

a guest
Jul 19th, 2018
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 13.51 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Tue Jul  3 07:10:41 2018
  4.  
  5. @author: Mario
  6. """
  7. import os
  8.  
  9. import tensorflow as tf
  10. import numpy as np
  11. import cv2
  12. import random
  13. import scipy.misc
  14. from utils import *
  15.  
  16. slim = tf.contrib.slim #interface to contrib models
  17.  
  18. HEIGHT, WIDTH, CHANNEL = 128, 128, 3
  19. BATCH_SIZE = 64
  20. EPOCH = 5000
  21. version = 'newPokemon'
  22. newPoke_path = './' + version
  23.  
  24. def lrelu(x, n, leak=0.2): #leaky Relu Activation function
  25.     return tf.maximum(x, leak * x, name=n)
  26.  
  27. def process_data():  
  28.     current_dir = os.getcwd()
  29.     # parent = os.path.dirname(current_dir)
  30.     pokemon_dir = os.path.join(current_dir, 'data')
  31.     images = []
  32.     for each in os.listdir(pokemon_dir):
  33.         images.append(os.path.join(pokemon_dir,each))
  34.     # print images    
  35.     all_images = tf.convert_to_tensor(images, dtype = tf.string)
  36.    
  37.     images_queue = tf.train.slice_input_producer(
  38.                                         [all_images])#slices tensor and returns list of tensor
  39.                                        
  40.     content = tf.read_file(images_queue[0]) #read imgs
  41.     image = tf.image.decode_jpeg(content, channels = CHANNEL) #decode
  42.     # sess1 = tf.Session()
  43.     # print sess1.run(image)
  44.     image = tf.image.random_flip_left_right(image) #for better training
  45.     image = tf.image.random_brightness(image, max_delta = 0.1)
  46.     image = tf.image.random_contrast(image, lower = 0.9, upper = 1.1)
  47.     # noise = tf.Variable(tf.truncated_normal(shape = [HEIGHT,WIDTH,CHANNEL], dtype = tf.float32, stddev = 1e-3, name = 'noise'))
  48.     # print image.get_shape()
  49.     size = [HEIGHT, WIDTH]
  50.     image = tf.image.resize_images(image, size)
  51.     image.set_shape([HEIGHT,WIDTH,CHANNEL])
  52.     # image = image + noise
  53.     # image = tf.transpose(image, perm=[2, 0, 1])
  54.     # print image.get_shape()
  55.    
  56.     image = tf.cast(image, tf.float32) #make image floats
  57.     image = image / 255.0 #normalize float vals from 0-1
  58.    
  59.     iamges_batch = tf.train.shuffle_batch(
  60.                                     [image], batch_size = BATCH_SIZE,
  61.                                     num_threads = 4, capacity = 200 + 3* BATCH_SIZE,
  62.                                     min_after_dequeue = 200)
  63.     num_images = len(images)
  64.  
  65.     return iamges_batch, num_images
  66.  
  67. def generator(input_, random_dim, is_train, reuse=False):
  68.     c4, c8, c16, c32, c64 = 512, 256, 128, 64, 32 # channel num
  69.     s4 = 4
  70.     output_dim = CHANNEL  # RGB image
  71.     with tf.variable_scope('gen') as scope:
  72.         if reuse:
  73.             scope.reuse_variables()
  74.         w1 = tf.get_variable('w1', shape=[random_dim, s4 * s4 * c4], dtype=tf.float32,
  75.                              initializer=tf.truncated_normal_initializer(stddev=0.02))
  76.         b1 = tf.get_variable('b1', shape=[c4 * s4 * s4], dtype=tf.float32,
  77.                              initializer=tf.constant_initializer(0.0))
  78.         flat_conv1 = tf.add(tf.matmul(input_, w1), b1, name='flat_conv1')
  79.          #Convolution, bias, activation, repeat!
  80.         conv1 = tf.reshape(flat_conv1, shape=[-1, s4, s4, c4], name='conv1')
  81.         bn1 = tf.contrib.layers.batch_norm(conv1, is_training=is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope='bn1')
  82.         act1 = tf.nn.relu(bn1, name='act1')
  83.         # 8*8*256
  84.         #Convolution, bias, activation, repeat! pass pevrios activation
  85.         conv2 = tf.layers.conv2d_transpose(act1, c8, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
  86.                                            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
  87.                                            name='conv2')
  88.         bn2 = tf.contrib.layers.batch_norm(conv2, is_training=is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope='bn2')
  89.         act2 = tf.nn.relu(bn2, name='act2')
  90.         # 16*16*128
  91.         conv3 = tf.layers.conv2d_transpose(act2, c16, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
  92.                                            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
  93.                                            name='conv3')
  94.         bn3 = tf.contrib.layers.batch_norm(conv3, is_training=is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope='bn3')
  95.         act3 = tf.nn.relu(bn3, name='act3')
  96.         # 32*32*64
  97.         conv4 = tf.layers.conv2d_transpose(act3, c32, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
  98.                                            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
  99.                                            name='conv4')
  100.         bn4 = tf.contrib.layers.batch_norm(conv4, is_training=is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope='bn4')
  101.         act4 = tf.nn.relu(bn4, name='act4')
  102.         # 64*64*32
  103.         conv5 = tf.layers.conv2d_transpose(act4, c64, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
  104.                                            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
  105.                                            name='conv5')
  106.         bn5 = tf.contrib.layers.batch_norm(conv5, is_training=is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope='bn5')
  107.         act5 = tf.nn.relu(bn5, name='act5')
  108.        
  109.         #128*128*3
  110.         conv6 = tf.layers.conv2d_transpose(act5, output_dim, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
  111.                                            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
  112.                                            name='conv6')
  113.         # bn6 = tf.contrib.layers.batch_norm(conv6, is_training=is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope='bn6')
  114.         act6 = tf.nn.tanh(conv6, name='act6')
  115.         return act6
  116. def discriminator(input, is_train, reuse=False):
  117.     c2, c4, c8, c16 = 64, 128, 256, 512  # channel num: 64, 128, 256, 512
  118.     with tf.variable_scope('dis') as scope:
  119.         if reuse:
  120.             scope.reuse_variables()
  121.  
  122.         #Convolution, activation, bias, repeat!
  123.         conv1 = tf.layers.conv2d(input, c2, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
  124.                                  kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
  125.                                  name='conv1')
  126.         bn1 = tf.contrib.layers.batch_norm(conv1, is_training = is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope = 'bn1')
  127.         act1 = lrelu(conv1, n='act1')
  128.          #Convolution, activation, bias, repeat!
  129.         conv2 = tf.layers.conv2d(act1, c4, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
  130.                                  kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
  131.                                  name='conv2')
  132.         bn2 = tf.contrib.layers.batch_norm(conv2, is_training=is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope='bn2')
  133.         act2 = lrelu(bn2, n='act2')
  134.         #Convolution, activation, bias, repeat!
  135.         conv3 = tf.layers.conv2d(act2, c8, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
  136.                                  kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
  137.                                  name='conv3')
  138.         bn3 = tf.contrib.layers.batch_norm(conv3, is_training=is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope='bn3')
  139.         act3 = lrelu(bn3, n='act3')
  140.          #Convolution, activation, bias, repeat!
  141.         conv4 = tf.layers.conv2d(act3, c16, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
  142.                                  kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
  143.                                  name='conv4')
  144.         bn4 = tf.contrib.layers.batch_norm(conv4, is_training=is_train, epsilon=1e-5, decay = 0.9,  updates_collections=None, scope='bn4')
  145.         act4 = lrelu(bn4, n='act4')
  146.        
  147.         # start from act4
  148.         dim = int(np.prod(act4.get_shape()[1:]))
  149.         fc1 = tf.reshape(act4, shape=[-1, dim], name='fc1')
  150.      
  151.        
  152.         w2 = tf.get_variable('w2', shape=[fc1.shape[-1], 1], dtype=tf.float32,
  153.                              initializer=tf.truncated_normal_initializer(stddev=0.02))
  154.         b2 = tf.get_variable('b2', shape=[1], dtype=tf.float32,
  155.                              initializer=tf.constant_initializer(0.0))
  156.  
  157.         # wgan just get rid of the sigmoid
  158.         logits = tf.add(tf.matmul(fc1, w2), b2, name='logits')
  159.         # dcgan
  160.         acted_out = tf.nn.sigmoid(logits)
  161.         return logits #, acted_out
  162.  
  163. def train():
  164.     random_dim = 100 #input to generator
  165.    
  166.     with tf.variable_scope('input'):
  167.         #real and fake image placholders
  168.         real_image = tf.placeholder(tf.float32, shape = [None, HEIGHT, WIDTH, CHANNEL], name='real_image')
  169.         random_input = tf.placeholder(tf.float32, shape=[None, random_dim], name='rand_input')
  170.         is_train = tf.placeholder(tf.bool, name='is_train')
  171.    
  172.    
  173.     fake_image = generator(random_input, random_dim, is_train)#returns fake img
  174.    
  175.     real_result = discriminator(real_image, is_train) #result of real img
  176.     fake_result = discriminator(fake_image, is_train, reuse=True)#result of fake img
  177.    
  178.     d_loss = tf.reduce_mean(fake_result) - tf.reduce_mean(real_result)  # This optimizes the discriminator.
  179.     g_loss = -tf.reduce_mean(fake_result)  # This optimizes the generator.
  180.            
  181.  
  182.     t_vars = tf.trainable_variables() #collect all training variables
  183.     d_vars = [var for var in t_vars if 'dis' in var.name]#disc vars
  184.     g_vars = [var for var in t_vars if 'gen' in var.name]#gen vars
  185.     #lets do some learning
  186.     trainer_d = tf.train.RMSPropOptimizer(learning_rate=2e-4).minimize(d_loss, var_list=d_vars)
  187.     trainer_g = tf.train.RMSPropOptimizer(learning_rate=2e-4).minimize(g_loss, var_list=g_vars)
  188.     # clip discriminator weights
  189.     d_clip = [v.assign(tf.clip_by_value(v, -0.01, 0.01)) for v in d_vars]
  190.  
  191.    
  192.     batch_size = BATCH_SIZE
  193.     image_batch, samples_num = process_data()
  194.    
  195.     batch_num = int(samples_num / batch_size)
  196.     total_batch = 0
  197.     sess = tf.Session()
  198.     saver = tf.train.Saver()
  199.     sess.run(tf.global_variables_initializer())
  200.     sess.run(tf.local_variables_initializer())
  201.     # continue training
  202.     save_path = saver.save(sess, "/tmp/model.ckpt")
  203.     ckpt = tf.train.latest_checkpoint('./model/' + version)
  204.     saver.restore(sess, save_path)
  205.     coord = tf.train.Coordinator()
  206.     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  207.  
  208.     print('total training sample num:%d' % samples_num)
  209.     print('batch size: %d, batch num per epoch: %d, epoch num: %d' % (batch_size, batch_num, EPOCH))
  210.     print('start training...')
  211.     for i in range(EPOCH):
  212.         print("Running epoch {}/{}...".format(i, EPOCH))
  213.         for j in range(batch_num):
  214.             print(j)
  215.             d_iters = 5
  216.             g_iters = 1
  217.  
  218.             train_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, random_dim]).astype(np.float32)
  219.             for k in range(d_iters):
  220.                 print(k)
  221.                 train_image = sess.run(image_batch)
  222.                 #wgan clip weights
  223.                 sess.run(d_clip)
  224.                
  225.                 # Update the discriminator
  226.                 _, dLoss = sess.run([trainer_d, d_loss],
  227.                                     feed_dict={random_input: train_noise, real_image: train_image, is_train: True})
  228.  
  229.             # Update the generator
  230.             for k in range(g_iters):
  231.                 # train_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, random_dim]).astype(np.float32)
  232.                 _, gLoss = sess.run([trainer_g, g_loss],
  233.                                     feed_dict={random_input: train_noise, is_train: True})
  234.  
  235.             # print 'train:[%d/%d],d_loss:%f,g_loss:%f' % (i, j, dLoss, gLoss)
  236.            
  237.         # save check point every 500 epoch
  238.         if i%500 == 0:
  239.             if not os.path.exists('./model/' + version):
  240.                 os.makedirs('./model/' + version)
  241.             saver.save(sess, './model/' +version + '/' + str(i))  
  242.         if i%50 == 0:
  243.             # save images
  244.             if not os.path.exists(newPoke_path):
  245.                 os.makedirs(newPoke_path)
  246.             sample_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, random_dim]).astype(np.float32)
  247.             imgtest = sess.run(fake_image, feed_dict={random_input: sample_noise, is_train: False})
  248.             # imgtest = imgtest * 255.0
  249.             # imgtest.astype(np.uint8)
  250.             save_images(imgtest, [8,8] ,newPoke_path + '/epoch' + str(i) + '.jpg')
  251.            
  252.             print('train:[%d],d_loss:%f,g_loss:%f' % (i, dLoss, gLoss))
  253.     coord.request_stop()
  254.     coord.join(threads)
  255.  
  256.  
  257. # def test():
  258.     # random_dim = 100
  259.     # with tf.variable_scope('input'):
  260.         # real_image = tf.placeholder(tf.float32, shape = [None, HEIGHT, WIDTH, CHANNEL], name='real_image')
  261.         # random_input = tf.placeholder(tf.float32, shape=[None, random_dim], name='rand_input')
  262.         # is_train = tf.placeholder(tf.bool, name='is_train')
  263.    
  264.     # # wgan
  265.     # fake_image = generator(random_input, random_dim, is_train)
  266.     # real_result = discriminator(real_image, is_train)
  267.     # fake_result = discriminator(fake_image, is_train, reuse=True)
  268.     # sess = tf.InteractiveSession()
  269.     # sess.run(tf.global_variables_initializer())
  270.     # variables_to_restore = slim.get_variables_to_restore(include=['gen'])
  271.     # print(variables_to_restore)
  272.     # saver = tf.train.Saver(variables_to_restore)
  273.     # ckpt = tf.train.latest_checkpoint('./model/' + version)
  274.     # saver.restore(sess, ckpt)
  275.  
  276.  
  277. if __name__ == "__main__":
  278.     train()
  279.     # test()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement