Advertisement
paradox64ce

ClusterGAN(OG Imp.)

Sep 22nd, 2020
1,736
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 13.60 KB | None | 0 0
  1. import os
  2. import time
  3. import dateutil.tz
  4. import datetime
  5. import argparse
  6. import importlib
  7. import tensorflow as tf
  8. from scipy.misc import imsave
  9. import numpy as np
  10. from sklearn.cluster import KMeans
  11. from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score
  12.  
  13. import metric
  14. from visualize import *
  15. import util
  16.  
  17. tf.set_random_seed(0)
  18.  
class clusGAN(object):
    """ClusterGAN trainer (original implementation, TF1 graph mode).

    Builds a WGAN-GP generator/critic pair plus an encoder that inverts the
    generator, with cycle-consistency penalties on the latent code so that
    the categorical part of z becomes a cluster indicator.  The latent
    vector z is the concatenation of a continuous part (first `dim_gen`
    dims) and a categorical/one-hot part (remaining dims).
    """

    def __init__(self, g_net, d_net, enc_net, x_sampler, z_sampler, data, model, sampler,
                 num_classes, dim_gen, n_cat, batch_size, beta_cycle_gen, beta_cycle_label):
        # g_net / d_net / enc_net: generator, critic and encoder networks.
        # x_sampler yields real data batches; z_sampler draws latent codes.
        # `data`, `model`, `sampler` are names used for logging/checkpoint paths.
        # beta_cycle_gen / beta_cycle_label weight the two cycle-consistency terms.
        self.model = model
        self.data = data
        self.sampler = sampler
        self.g_net = g_net
        self.d_net = d_net
        self.enc_net = enc_net
        self.x_sampler = x_sampler
        self.z_sampler = z_sampler
        self.num_classes = num_classes
        self.dim_gen = dim_gen
        self.n_cat = n_cat
        self.batch_size = batch_size
        scale = 10.0  # WGAN-GP gradient-penalty coefficient (lambda)
        self.beta_cycle_gen = beta_cycle_gen
        self.beta_cycle_label = beta_cycle_label


        self.x_dim = self.d_net.x_dim
        self.z_dim = self.g_net.z_dim

        # Graph inputs: flattened data batch x and latent batch z.
        self.x = tf.placeholder(tf.float32, [None, self.x_dim], name='x')
        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')

        # Split z into its continuous part and its categorical (one-hot) part.
        self.z_gen = self.z[:,0:self.dim_gen]
        self.z_hot = self.z[:,self.dim_gen:]

        # Generator output, and encoder applied to both fake and real data.
        # reuse=False on the first enc_net/d_net call creates the variables;
        # subsequent calls share them.
        self.x_ = self.g_net(self.z)
        self.z_enc_gen, self.z_enc_label, self.z_enc_logits = self.enc_net(self.x_, reuse=False)
        self.z_infer_gen, self.z_infer_label, self.z_infer_logits = self.enc_net(self.x)


        self.d = self.d_net(self.x, reuse=False)
        self.d_ = self.d_net(self.x_)


        # Generator/encoder loss: adversarial term + L2 cycle loss on the
        # continuous latent + cross-entropy cycle loss on the one-hot latent.
        self.g_loss = tf.reduce_mean(self.d_) + \
                      self.beta_cycle_gen * tf.reduce_mean(tf.square(self.z_gen - self.z_enc_gen)) +\
                      self.beta_cycle_label * tf.reduce_mean(
                          tf.nn.softmax_cross_entropy_with_logits(logits=self.z_enc_logits,labels=self.z_hot))

        # Wasserstein critic loss (sign convention matches g_loss above).
        self.d_loss = tf.reduce_mean(self.d) - tf.reduce_mean(self.d_)

        # WGAN-GP gradient penalty on random interpolates between real and fake.
        # NOTE(review): epsilon is a scalar, so one interpolation weight is
        # shared across the whole batch rather than drawn per-sample.
        epsilon = tf.random_uniform([], 0.0, 1.0)
        x_hat = epsilon * self.x + (1 - epsilon) * self.x_
        d_hat = self.d_net(x_hat)

        ddx = tf.gradients(d_hat, x_hat)[0]
        ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=1))
        ddx = tf.reduce_mean(tf.square(ddx - 1.0) * scale)

        self.d_loss = self.d_loss + ddx

        # Separate Adam optimizers; generator and encoder are updated jointly.
        self.d_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \
                .minimize(self.d_loss, var_list=self.d_net.vars)
        self.g_adam = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9) \
                .minimize(self.g_loss, var_list=[self.g_net.vars, self.enc_net.vars])

        # Reconstruction Nodes (per-sample L1 error and its gradient w.r.t. z;
        # built here but not used by the methods visible in this file).
        self.recon_loss = tf.reduce_mean(tf.abs(self.x - self.x_), 1)
        self.compute_grad = tf.gradients(self.recon_loss, self.z)

        self.saver = tf.train.Saver()

        run_config = tf.ConfigProto()
        run_config.gpu_options.per_process_gpu_memory_fraction = 1.0
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

    def train(self, num_batches=500000):
        """Run adversarial training, saving sample grids every 5000 iters and
        the model checkpoint (plus a validation clustering eval) at the end."""

        now = datetime.datetime.now(dateutil.tz.tzlocal())
        timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')

        batch_size = self.batch_size
        # NOTE(review): `plt` presumably comes from `from visualize import *`
        # (matplotlib) — confirm; interactive mode is enabled but no figure is
        # drawn in this method.
        plt.ion()
        self.sess.run(tf.global_variables_initializer())
        start_time = time.time()
        print(
        'Training {} on {}, sampler = {}, z = {} dimension, beta_n = {}, beta_c = {}'.
            format(self.model, self.data, self.sampler, self.z_dim, self.beta_cycle_gen, self.beta_cycle_label))


        im_save_dir = 'logs/{}/{}/{}_z{}_cyc{}_gen{}'.format(self.data, self.model, self.sampler, self.z_dim,
                                                                 self.beta_cycle_label, self.beta_cycle_gen)
        if not os.path.exists(im_save_dir):
            os.makedirs(im_save_dir)

        for t in range(0, num_batches):
            d_iters = 5  # critic updates per generator update (WGAN-GP convention)

            for _ in range(0, d_iters):
                bx = self.x_sampler.train(batch_size)
                bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat)
                self.sess.run(self.d_adam, feed_dict={self.x: bx, self.z: bz})

            bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat)
            self.sess.run(self.g_adam, feed_dict={self.z: bz})

            # Periodic loss logging.  NOTE(review): d_loss includes the
            # gradient penalty with a fresh random epsilon, so logged values
            # are stochastic even for a fixed batch.
            if (t+1) % 100 == 0:
                bx = self.x_sampler.train(batch_size)
                bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat)


                d_loss = self.sess.run(
                    self.d_loss, feed_dict={self.x: bx, self.z: bz}
                )
                g_loss = self.sess.run(
                    self.g_loss, feed_dict={self.z: bz}
                )
                print('Iter [%8d] Time [%5.4f] d_loss [%.4f] g_loss [%.4f]' %
                      (t+1, time.time() - start_time, d_loss, g_loss))


            # Periodically dump a grid of generated samples.
            # NOTE(review): `xs` and `grid_transform` are module-level globals
            # (set up in the __main__ block / `from visualize import *`), not
            # attributes of self.
            if (t+1) % 5000 == 0:
                bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat)
                bx = self.sess.run(self.x_, feed_dict={self.z: bz})
                bx = xs.data2img(bx)
                bx = grid_transform(bx, xs.shape)

                # NOTE(review): under Python 3, (t+1) / 100 is a float, so the
                # saved filename looks like "50.0.png" — confirm if intended.
                imsave('logs/{}/{}/{}_z{}_cyc{}_gen{}/{}.png'.format(self.data, self.model, self.sampler,
                          self.z_dim, self.beta_cycle_label, self.beta_cycle_gen, (t+1) / 100), bx)

        self.recon_enc(timestamp, val = True)
        self.save(timestamp)

    def save(self, timestamp):
        """Save all session variables under a timestamped checkpoint directory."""

        checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}_cyc{}_gen{}'.format(self.data, timestamp, self.model, self.sampler,
                                                                            self.z_dim, self.beta_cycle_label,
                                                                             self.beta_cycle_gen)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, 'model.ckpt'))

    def load(self, pre_trained = False, timestamp = ''):
        """Restore weights either from the bundled pre-trained models or from a
        timestamped checkpoint produced by save()."""

        if pre_trained == True:
            print('Loading Pre-trained Model...')
            checkpoint_dir = 'pre_trained_models/{}/{}_{}_z{}_cyc{}_gen{}'.format(self.data, self.model, self.sampler,
                                                                            self.z_dim, self.beta_cycle_label, self.beta_cycle_gen)
        else:
            if timestamp == '':
                print('Best Timestamp not provided. Abort !')
                # NOTE(review): despite the "Abort" message, execution falls
                # through and restore is still attempted with an empty path,
                # which will raise inside saver.restore.
                checkpoint_dir = ''
            else:
                checkpoint_dir = 'checkpoint_dir/{}/{}_{}_{}_z{}_cyc{}_gen{}'.format(self.data, timestamp, self.model, self.sampler,
                                                                                     self.z_dim, self.beta_cycle_label,
                                                                                     self.beta_cycle_gen)


        self.saver.restore(self.sess, os.path.join(checkpoint_dir, 'model.ckpt'))
        print('Restored model weights.')



    def _gen_samples(self, num_images):
        """Generate roughly `num_images` fake samples and dump them to an .npy file.

        NOTE(review): one batch is generated before the loop and the loop runs
        num_images // batch_size more times, so the output holds one extra
        batch beyond num_images when num_images is a multiple of batch_size.
        """

        batch_size = self.batch_size
        bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat)
        fake_im = self.sess.run(self.x_, feed_dict = {self.z : bz})
        for t in range(num_images // batch_size):
            bz = self.z_sampler(batch_size, self.z_dim, self.sampler, self.num_classes, self.n_cat)
            im = self.sess.run(self.x_, feed_dict = {self.z : bz})
            fake_im = np.vstack((fake_im, im))

        print(' Generated {} images .'.format(fake_im.shape[0]))
        np.save('./Image_samples/{}/{}_{}_K_{}_gen_images.npy'.format(self.data, self.model, self.sampler, self.num_classes), fake_im)


    def gen_from_all_modes(self):
        """Generate samples conditioned on each categorical mode and save one
        image grid per mode.  Only meaningful for the 'one_hot' sampler."""

        if self.sampler == 'one_hot':
            batch_size = 1000
            # Cycle the class labels so each mode gets ~batch_size/num_classes samples.
            label_index = np.tile(np.arange(self.num_classes), int(np.ceil(batch_size * 1.0 / self.num_classes)))

            bz = self.z_sampler(batch_size, self.z_dim, self.sampler, num_class=self.num_classes,
                                    n_cat= self.n_cat, label_index=label_index)
            bx = self.sess.run(self.x_, feed_dict={self.z: bz})

            for m in range(self.num_classes):
                print('Generating samples from mode {} ...'.format(m))
                mode_index = np.where(label_index == m)[0]
                mode_bx = bx[mode_index, :]
                # NOTE(review): `xs` / `grid_transform` are module-level globals.
                mode_bx = xs.data2img(mode_bx)
                mode_bx = grid_transform(mode_bx, xs.shape)

                imsave('logs/{}/{}/{}_z{}_cyc{}_gen{}/mode{}_samples.png'.format(self.data, self.model, self.sampler,
                        self.z_dim, self.beta_cycle_label, self.beta_cycle_gen, m), mode_bx)

    def recon_enc(self, timestamp, val = True):
        """Encode the validation (or test) split into latent space and run the
        clustering evaluation on the encoded representations."""

        if val:
            data_recon, label_recon = self.x_sampler.validation()
        else:
            data_recon, label_recon = self.x_sampler.test()
            #data_recon, label_recon = self.x_sampler.load_all()

        num_pts_to_plot = data_recon.shape[0]
        recon_batch_size = self.batch_size
        latent = np.zeros(shape=(num_pts_to_plot, self.z_dim))

        print('Data Shape = {}, Labels Shape = {}'.format(data_recon.shape, label_recon.shape))
        # Encode in batches; the last batch may be short.
        for b in range(int(np.ceil(num_pts_to_plot * 1.0 / recon_batch_size))):
            if (b+1)*recon_batch_size > num_pts_to_plot:
               pt_indx = np.arange(b*recon_batch_size, num_pts_to_plot)
            else:
               pt_indx = np.arange(b*recon_batch_size, (b+1)*recon_batch_size)
            xtrue = data_recon[pt_indx, :]

            zhats_gen, zhats_label = self.sess.run([self.z_infer_gen, self.z_infer_label], feed_dict={self.x : xtrue})

            latent[pt_indx, :] = np.concatenate((zhats_gen, zhats_label), axis=1)


        # With beta_cycle_gen == 0 the continuous latent carries no cycle
        # signal, so cluster on the categorical part only.
        if self.beta_cycle_gen == 0:
            self._eval_cluster(latent[:, self.dim_gen:], label_recon, timestamp, val)
        else:
            self._eval_cluster(latent, label_recon, timestamp, val)


    def _eval_cluster(self, latent_rep, labels_true, timestamp, val):
        """K-means the latent representations and report Purity / NMI / ARI,
        both to stdout and appended to a per-dataset results file."""

        # Special case: collapse the 10 Fashion-MNIST labels into 5 coarse
        # groups when evaluating with K=5.
        if self.data == 'fashion' and self.num_classes == 5:
             map_labels = {0 : 0, 1 : 1, 2 : 2, 3 : 0, 4 : 2, 5 : 3, 6 : 2, 7 : 3, 8 : 4, 9 : 3}
             labels_true = np.array([map_labels[i] for i in labels_true])

        km = KMeans(n_clusters=max(self.num_classes, len(np.unique(labels_true))), random_state=0).fit(latent_rep)
        labels_pred = km.labels_

        purity = metric.compute_purity(labels_pred, labels_true)
        ari = adjusted_rand_score(labels_true, labels_pred)
        nmi = normalized_mutual_info_score(labels_true, labels_pred)


        if val:
            data_split = 'Validation'
        else:
            data_split = 'Test'
            #data_split = 'All'

        print('Data = {}, Model = {}, sampler = {}, z_dim = {}, beta_label = {}, beta_gen = {} '
              .format(self.data, self.model, self.sampler, self.z_dim, self.beta_cycle_label, self.beta_cycle_gen))
        print(' #Points = {}, K = {}, Purity = {},  NMI = {}, ARI = {},  '
              .format(latent_rep.shape[0], self.num_classes, purity, nmi, ari))

        # Append results so repeated runs accumulate in one log file.
        with open('logs/Res_{}_{}.txt'.format(self.data, self.model), 'a+') as f:
                f.write('{}, {} : K = {}, z_dim = {}, beta_label = {}, beta_gen = {}, sampler = {}, Purity = {}, NMI = {}, ARI = {}\n'
                        .format(timestamp, data_split, self.num_classes, self.z_dim, self.beta_cycle_label, self.beta_cycle_gen,
                                self.sampler, purity, nmi, ari))
                f.flush()
  274.  
  275.  
if __name__ == '__main__':
    # CLI entry point.  The dataset package (e.g. `mnist`) and the model
    # module inside it (e.g. `mnist.clus_wgan`) are imported dynamically by
    # name, so `--data`/`--model` must match importable module names.
    parser = argparse.ArgumentParser('')
    parser.add_argument('--data', type=str, default='mnist')
    parser.add_argument('--model', type=str, default='clus_wgan')
    parser.add_argument('--sampler', type=str, default='one_hot')
    parser.add_argument('--K', type=int, default=10)        # number of clusters / categorical modes
    parser.add_argument('--dz', type=int, default=30)       # dims of the continuous latent part
    parser.add_argument('--bs', type=int, default=64)       # batch size
    parser.add_argument('--beta_n', type=float, default=10.0)   # weight of continuous cycle loss
    parser.add_argument('--beta_c', type=float, default=10.0)   # weight of categorical cycle loss
    parser.add_argument('--timestamp', type=str, default='')
    # NOTE(review): --train is a string compared against the literal 'True'
    # below, so e.g. --train true does NOT trigger training.
    parser.add_argument('--train', type=str, default='False')

    args = parser.parse_args()
    data = importlib.import_module(args.data)
    model = importlib.import_module(args.data + '.' + args.model)

    num_classes = args.K
    dim_gen = args.dz
    n_cat = 1
    batch_size = args.bs
    beta_cycle_gen = args.beta_n
    beta_cycle_label = args.beta_c
    timestamp = args.timestamp

    # Total latent size = continuous dims + one one-hot block per categorical variable.
    z_dim = dim_gen + num_classes * n_cat
    d_net = model.Discriminator()
    g_net = model.Generator(z_dim=z_dim)
    enc_net = model.Encoder(z_dim=z_dim, dim_gen = dim_gen)
    # NOTE: `xs` is also read as a module-level global inside clusGAN.train /
    # gen_from_all_modes, not only passed in here.
    xs = data.DataSampler()
    zs = util.sample_Z


    cl_gan = clusGAN(g_net, d_net, enc_net, xs, zs, args.data, args.model, args.sampler,
                     num_classes, dim_gen, n_cat, batch_size, beta_cycle_gen, beta_cycle_label)
    if args.train == 'True':
        cl_gan.train()
    else:
        # Evaluation path: restore weights (pre-trained if no timestamp was
        # given) and run the clustering evaluation on the test split.
        print('Attempting to Restore Model ...')
        if timestamp == '':
            cl_gan.load(pre_trained=True)
            timestamp = 'pre-trained'
        else:
            cl_gan.load(pre_trained=False, timestamp = timestamp)

        cl_gan.recon_enc(timestamp, val=False)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement