lamiastella

basic DCGAN on CelebA

Nov 16th, 2018
from __future__ import print_function
import random
import os
import glob
import scipy.misc   # scipy.misc.imsave is used below; a plain "import scipy" does not expose it

import tensorflow as tf
import numpy as np
from PIL import Image
import skimage.io as io
import matplotlib.pyplot as plt

class Arguments(object):

    data_path = 'results_celebA/preprocessed/'
    save_path = 'results_celebA'             # path to save the preprocessed image folder
    preproc_foldername = 'preprocessed'      # folder name for preprocessed images
    image_size = 64                          # images are resized to image_size x image_size
    num_images = 202590                      # the number of training images
    batch_size = 64                          # batch size
    dim_z = 100                              # the dimension of z (the generator input)
    n_g_filters = 64                         # the number of generator filters (gets multiplied between layers)
    n_f_filters = 64                         # the number of discriminator filters (gets multiplied between layers)
    n_epoch = 25                             # the number of epochs
    lr = 0.0002                              # learning rate
    beta1 = 0.5                              # beta_1 parameter of the Adam optimizer
    beta2 = 0.99                             # beta_2 parameter of the Adam optimizer

args = Arguments()


# Contains functions that load, preprocess and visualize images.


class Dataset(object):
    def __init__(self, data_path, num_imgs, target_imgsize):
        self.data_path = data_path
        self.num_imgs = num_imgs
        self.target_imgsize = target_imgsize

    def normalize_np_image(self, image):
        # map uint8 pixel values in [0, 255] to floats in [-1, 1] (the range of tanh)
        return (image / 255.0 - 0.5) / 0.5

    def denormalize_np_image(self, image):
        # inverse of normalize_np_image: map [-1, 1] back to [0, 255]
        return (image * 0.5 + 0.5) * 255

    def get_input(self, image_path):
        image = np.array(Image.open(image_path)).astype(np.float32)
        return self.normalize_np_image(image)

    def get_imagelist(self, data_path, celebA=False):
        if celebA:
            imgs_path = os.path.join(data_path, 'img_align_celeba/*.jpg')
        else:
            imgs_path = os.path.join(data_path, '*.jpg')
        all_namelist = glob.glob(imgs_path, recursive=True)
        return all_namelist[:self.num_imgs]

    def load_and_preprocess_image(self, image_path):
        # center-crop a 100x100 patch, then resize to target_imgsize and normalize
        image = Image.open(image_path)
        j = (image.size[0] - 100) // 2
        i = (image.size[1] - 100) // 2
        image = image.crop([j, i, j + 100, i + 100])
        image = image.resize([self.target_imgsize, self.target_imgsize], Image.BILINEAR)
        image = np.array(image.convert('RGB')).astype(np.float32)
        image = self.normalize_np_image(image)
        return image

    # Reads the raw data, preprocesses it and saves it to another folder under the given path.
    def preprocess_and_save_images(self, dir_name, save_path=''):
        preproc_folder_path = os.path.join(save_path, dir_name)
        if not os.path.exists(preproc_folder_path):
            os.makedirs(preproc_folder_path)
            imgs_path = os.path.join(self.data_path, 'img_align_celeba/*.jpg')
            print('Saving and preprocessing images ...')
            for num, imgname in enumerate(glob.iglob(imgs_path, recursive=True)):
                cur_image = self.load_and_preprocess_image(imgname)
                cur_image = Image.fromarray(np.uint8(self.denormalize_np_image(cur_image)))
                cur_image.save(preproc_folder_path + '/preprocessed_image_%d.jpg' % (num))
        self.data_path = preproc_folder_path

    def get_nextbatch(self, batch_size):
        assert batch_size > 0, "Give a valid batch size"
        cur_idx = 0
        image_namelist = self.get_imagelist(self.data_path)
        while cur_idx + batch_size <= self.num_imgs:
            cur_namelist = image_namelist[cur_idx:cur_idx + batch_size]
            cur_batch = [self.get_input(image_path) for image_path in cur_namelist]
            cur_batch = np.array(cur_batch).astype(np.float32)
            cur_idx += batch_size
            yield cur_batch

    def show_image(self, image, normalized=True):
        if not type(image).__module__ == np.__name__:
            image = image.numpy()
        # undo the [-1, 1] normalization before displaying
        npimg = image * 0.5 + 0.5 if normalized else image
        plt.imshow(npimg, interpolation='nearest')
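
# The following usage sketch is an editor's addition, not part of the original paste: it
# shows how the Dataset class above is meant to be driven, assuming the raw CelebA jpgs
# were extracted to results_celebA/img_align_celeba/. It is defined but never called.
def _dataset_usage_example(args):
    loader = Dataset('results_celebA', args.num_images, args.image_size)
    # one-off preprocessing pass; skipped automatically if the folder already exists
    loader.preprocess_and_save_images('preprocessed', 'results_celebA')
    # get_nextbatch is a generator of normalized float32 batches in [-1, 1]
    first_batch = next(loader.get_nextbatch(args.batch_size))
    print(first_batch.shape)  # expected: (args.batch_size, args.image_size, args.image_size, 3)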


def generator(x, args, reuse=False):
    with tf.device('/gpu:0'):
        with tf.variable_scope("generator", reuse=reuse):
            # Layer Block 1: project the 1x1 noise input up to a 4x4 feature map
            with tf.variable_scope("layer1"):
                deconv1 = tf.layers.conv2d_transpose(inputs=x,
                                             filters=args.n_g_filters*8,
                                             kernel_size=4,
                                             strides=1,
                                             padding='valid',
                                             use_bias=False,
                                             name='deconv')
                # training=True so batch statistics are used during training
                batch_norm1 = tf.layers.batch_normalization(deconv1,
                                             training=True,
                                             name='batch_norm')
                relu1 = tf.nn.relu(batch_norm1, name='relu')
            # Layer Block 2: 4x4 -> 8x8
            with tf.variable_scope("layer2"):
                deconv2 = tf.layers.conv2d_transpose(inputs=relu1,
                                             filters=args.n_g_filters*4,
                                             kernel_size=4,
                                             strides=2,
                                             padding='same',
                                             use_bias=False,
                                             name='deconv')
                batch_norm2 = tf.layers.batch_normalization(deconv2,
                                             training=True,
                                             name='batch_norm')
                relu2 = tf.nn.relu(batch_norm2, name='relu')
            # Layer Block 3: 8x8 -> 16x16
            with tf.variable_scope("layer3"):
                deconv3 = tf.layers.conv2d_transpose(inputs=relu2,
                                             filters=args.n_g_filters*2,
                                             kernel_size=4,
                                             strides=2,
                                             padding='same',
                                             use_bias=False,
                                             name='deconv')
                batch_norm3 = tf.layers.batch_normalization(deconv3,
                                             training=True,
                                             name='batch_norm')
                relu3 = tf.nn.relu(batch_norm3, name='relu')
            # Layer Block 4: 16x16 -> 32x32
            with tf.variable_scope("layer4"):
                deconv4 = tf.layers.conv2d_transpose(inputs=relu3,
                                             filters=args.n_g_filters,
                                             kernel_size=4,
                                             strides=2,
                                             padding='same',
                                             use_bias=False,
                                             name='deconv')
                batch_norm4 = tf.layers.batch_normalization(deconv4,
                                             training=True,
                                             name='batch_norm')
                relu4 = tf.nn.relu(batch_norm4, name='relu')
            # Output Layer: 32x32 -> 64x64, 3 channels, tanh to match the [-1, 1] data range
            with tf.variable_scope("last_layer"):
                logit = tf.layers.conv2d_transpose(inputs=relu4,
                                             filters=3,
                                             kernel_size=4,
                                             strides=2,
                                             padding='same',
                                             use_bias=False,
                                             name='logit')
                output = tf.nn.tanh(logit)
    return output, logit
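
# Editor's shape sketch (not in the original paste): with the strides above, a
# (batch, 1, 1, dim_z) noise tensor is upsampled 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.
# The helper below only builds the graph and prints the static shape; it is never called.
def _generator_shape_check(args):
    z = tf.placeholder(tf.float32, [args.batch_size, 1, 1, args.dim_z])
    g_out, _ = generator(z, args, reuse=tf.AUTO_REUSE)  # AUTO_REUSE in case the scope already exists
    print(g_out.get_shape().as_list())  # expected: [64, 64, 64, 3]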


def discriminator(x, args, reuse=False):
    with tf.device('/gpu:0'):
        with tf.variable_scope("discriminator", reuse=reuse):
            # Layer Block 1: 64x64 -> 32x32 (no batch norm on the first layer, per DCGAN)
            with tf.variable_scope("layer1"):
                conv1 = tf.layers.conv2d(inputs=x,
                                         filters=args.n_f_filters,
                                         kernel_size=4,
                                         strides=2,
                                         padding='same',
                                         use_bias=False,
                                         name='conv')
                relu1 = tf.nn.leaky_relu(conv1, alpha=0.2, name='relu')
            # Layer Block 2: 32x32 -> 16x16
            with tf.variable_scope("layer2"):
                conv2 = tf.layers.conv2d(inputs=relu1,
                                         filters=args.n_f_filters*2,
                                         kernel_size=4,
                                         strides=2,
                                         padding='same',
                                         use_bias=False,
                                         name='conv')
                batch_norm2 = tf.layers.batch_normalization(conv2, training=True, name='batch_norm')
                relu2 = tf.nn.leaky_relu(batch_norm2, alpha=0.2, name='relu')
            # Layer Block 3: 16x16 -> 8x8
            with tf.variable_scope("layer3"):
                conv3 = tf.layers.conv2d(inputs=relu2,
                                         filters=args.n_f_filters*4,
                                         kernel_size=4,
                                         strides=2,
                                         padding='same',
                                         use_bias=False,
                                         name='conv')
                batch_norm3 = tf.layers.batch_normalization(conv3, training=True, name='batch_norm')
                relu3 = tf.nn.leaky_relu(batch_norm3, alpha=0.2, name='relu')
            # Layer Block 4: 8x8 -> 4x4
            with tf.variable_scope("layer4"):
                conv4 = tf.layers.conv2d(inputs=relu3,
                                         filters=args.n_f_filters*8,
                                         kernel_size=4,
                                         strides=2,
                                         padding='same',
                                         use_bias=False,
                                         name='conv')
                batch_norm4 = tf.layers.batch_normalization(conv4, training=True, name='batch_norm')
                relu4 = tf.nn.leaky_relu(batch_norm4, alpha=0.2, name='relu')
            # Output Layer: 4x4 -> a single 1x1 real/fake logit
            with tf.variable_scope("last_layer"):
                logit = tf.layers.conv2d(inputs=relu4,
                                         filters=1,
                                         kernel_size=4,
                                         strides=1,
                                         padding='valid',
                                         use_bias=False,
                                         name='conv')
                output = tf.nn.sigmoid(logit)
    return output, logit
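
# Editor's shape sketch (not in the original paste): the discriminator mirrors the generator,
# downsampling 64x64 -> 32 -> 16 -> 8 -> 4 and collapsing to a single 1x1 logit per image.
# Defined but never called.
def _discriminator_shape_check(args):
    images = tf.placeholder(tf.float32, [args.batch_size, args.image_size, args.image_size, 3])
    _, d_logit = discriminator(images, args, reuse=tf.AUTO_REUSE)
    print(d_logit.get_shape().as_list())  # expected: [64, 1, 1, 1]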

def sample_z(dim_z, num_batch):
    # draw num_batch standard-normal noise vectors, shaped for the generator input
    mu = 0
    sigma = 1
    s = np.random.normal(mu, sigma, num_batch*dim_z)
    samples = s.reshape(num_batch, 1, 1, dim_z)
    ##dist = tf.distributions.Normal(0.0, 1.0)
    ##samples = dist.sample([num_batch, 1, 1, dim_z])
    return samples

# quick sanity check: returns an array of shape (64, 1, 1, 100)
sample_z(100, 64)


def get_losses(d_real_logits, d_fake_logits):
    # add new loss function here
    ###d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_real_logits, labels=tf.ones_like(d_real_logits)))
    ###d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits)))
    ###d_loss = d_loss_real + d_loss_fake
    ###g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake_logits, labels=tf.ones_like(d_fake_logits)))
    ###return d_loss, g_loss
    # discriminator: real images should be classified as 1, generated images as 0
    d_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_real_logits, labels=tf.ones_like(d_real_logits))
        + tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits)))
    # generator: non-saturating loss, fool the discriminator into outputting 1 for fakes
    g_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake_logits, labels=tf.ones_like(d_fake_logits)))
    return d_loss, g_loss
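
# Editor's sanity-check sketch (not in the original paste): with all logits at zero both
# cross-entropy terms equal log(2) ~= 0.693, so d_loss ~= 1.386 and g_loss ~= 0.693.
# Defined but never called.
def _loss_sanity_check():
    zero_logits = tf.zeros([4, 1, 1, 1])
    d_l, g_l = get_losses(zero_logits, zero_logits)
    with tf.Session() as s:
        print(s.run([d_l, g_l]))  # expected: approximately [1.386, 0.693]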


def get_optimizers(learning_rate, beta1, beta2):
    d_optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2)
    g_optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2)
    return d_optimizer, g_optimizer


def optimize(d_optimizer, g_optimizer, d_loss, g_loss):
    # Restrict each optimizer to its own network's variables: without var_list, the
    # discriminator step would also update generator weights (and vice versa), since
    # d_loss depends on the generator through the fake batch.
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    # Make sure the batch-norm moving averages are updated alongside the train steps.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        d_step = d_optimizer.minimize(d_loss, var_list=d_vars)
        g_step = g_optimizer.minimize(g_loss, var_list=g_vars)
    return d_step, g_step


LOGDIR = "logs_basic_dcgan"

def merge_images(image_batch, size):
    # tile a batch of images into a size[0] x size[1] grid on a single canvas
    h, w = image_batch.shape[1], image_batch.shape[2]
    c = image_batch.shape[3]
    img = np.zeros((int(h*size[0]), w*size[1], c))
    for idx, im in enumerate(image_batch):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = im
    return img
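
# Editor's example (not in the original paste): tiling sixteen 64x64 RGB images in a 4x4
# grid yields a single (256, 256, 3) canvas, which matches the call inside train() below.
# Defined but never called.
def _merge_images_example():
    dummy_batch = np.zeros((16, 64, 64, 3), dtype=np.float32)
    canvas = merge_images(dummy_batch, [4, 4])
    print(canvas.shape)  # expected: (256, 256, 3)
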
itr_fh = open('basic_gan_itr.txt', 'a+')

def train(args):
    tf.reset_default_graph()
    data_loader = Dataset(args.data_path, args.num_images, args.image_size)
    #data_loader.preprocess_and_save_images('preprocessed', 'results_celebA') #preprocess the images once
    if not os.path.exists('sample_gan_images'):   # output folder for the sample grids saved below
        os.makedirs('sample_gan_images')
    X = tf.placeholder(tf.float32, shape=[args.batch_size, args.image_size, args.image_size, 3])
    Z = tf.placeholder(tf.float32, shape=[args.batch_size, 1, 1, args.dim_z])

    G_sample, _ = generator(Z, args)
    D_real, D_real_logits = discriminator(X, args)
    D_fake, D_fake_logits = discriminator(G_sample, args, reuse=True)
    d_loss, g_loss = get_losses(D_real_logits, D_fake_logits)
    d_optimizer, g_optimizer = get_optimizers(args.lr, args.beta1, args.beta2)
    d_step, g_step = optimize(d_optimizer, g_optimizer, d_loss, g_loss)
    ###z_sum = tf.summary.histogram('z', Z)
    ###d_sum = tf.summary.histogram('d', D_real)
    ###G_sum = tf.summary.histogram('g', G_sample)
    ###d_loss_sum = tf.summary.scalar('d_loss', d_loss)
    ###g_loss_sum = tf.summary.scalar('g_loss', g_loss)
    ###d_sum = tf.summary.merge([z_sum, d_sum, d_loss_sum])
    ###g_sum = tf.summary.merge([z_sum, G_sum, g_loss_sum])
    ###saver = tf.train.Saver()
    ###merged_summary = tf.summary.merge_all()

    ###d_loss_summary = tf.summary.scalar("Discriminator_Total_Loss", d_loss)
    ###g_loss_summary = tf.summary.scalar("Generator_Total_Loss", g_loss)
    ###merged_summary = tf.summary.merge_all()

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())
        for epoch in range(args.n_epoch):
            for itr, real_batch in enumerate(data_loader.get_nextbatch(args.batch_size)):
                print('itr is %d, and epoch is %d' % (itr, epoch))
                itr_fh.write("epoch: " + str(epoch) + " itr: " + str(itr) + "\n")

                Z_sample = sample_z(args.dim_z, args.batch_size)
                # one discriminator step and one generator step on the same batch
                _, _ = sess.run([d_step, g_step], feed_dict={X: real_batch, Z: Z_sample})

                if itr == 3164:   # last iteration of the epoch (num_images // batch_size - 1)
                    sample = sess.run(G_sample, feed_dict={Z: Z_sample})
                    print("sample size is: ", sample.shape)
                    # samples come out of tanh in [-1, 1]; shift to [0, 1] before saving
                    im_merged = merge_images(sample[:16] * 0.5 + 0.5, [4, 4])
                    plt.imsave('sample_gan_images/im_merged_epoch_%d.png' % (epoch), im_merged)
                    scipy.misc.imsave('sample_gan_images/im_epoch_%d_itr_%d.png' % (epoch, itr), sample[1])
                    ##merged_summary = sess.run(merged_summary, feed_dict={X:real_batch , Z:Z_sample})
                    ###writer = tf.summary.FileWriter(LOGDIR)
                    ###writer.add_summary(merged_summary, itr)
                    ###d_loss_summary = tf.summary.scalar("Discriminator_Total_Loss", d_loss)
                    ###g_loss_summary = tf.summary.scalar("Generator_Total_Loss", g_loss)
                    ###merged_summary = tf.summary.merge_all()
                    ###writer.add_graph(sess.graph)
                    ###saver.save(sess, save_path='logs_basic_dcgan/gan.ckpt')


train(args)
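
# Editor's note (not part of the original paste): train() assumes the preprocessed CelebA
# jpgs already sit under args.data_path. If they do not, run the one-off preprocessing step
# first, e.g. by uncommenting the preprocess_and_save_images line inside train(), or by
# calling something like the sketch below before train(args):
#
#   Dataset('results_celebA', args.num_images, args.image_size).preprocess_and_save_images(
#       'preprocessed', 'results_celebA')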