Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- """PGGAN based on ResNet."""
- import numpy as np
- import tensorflow as tf
- import functools
- import locale
- import os
- import sys
- import math
- from misc import custom_ops
- sys.path.append(os.getcwd())
- locale.setlocale(locale.LC_ALL, '')
def optimistic_restore(session, save_file):
    """Restore the global variables that the checkpoint can supply.

    A global variable is restored only when the checkpoint contains an entry
    under the same name (the variable name up to the first ':') with an
    identical shape. Every other variable is left as it is.

    Args:
        session: Session handed to ``Saver.restore``.
        save_file: Checkpoint path given to ``tf.train.NewCheckpointReader``
            and to ``Saver.restore``.

    Returns:
        None. The variables that were restored are printed to stdout.
    """

    def _bare_name(variable):
        # 'scope/w:0' -> 'scope/w', which is the key format used by the reader.
        return variable.name.split(':')[0]

    ckpt_reader = tf.train.NewCheckpointReader(save_file)
    ckpt_shapes = ckpt_reader.get_variable_to_shape_map()

    # (full name, bare name) pairs for variables present in the checkpoint,
    # ordered by full name.
    candidates = sorted(
        [(v.name, _bare_name(v)) for v in tf.global_variables() if _bare_name(v) in ckpt_shapes])
    by_bare_name = {_bare_name(v): v for v in tf.global_variables()}

    restore_vars = []
    with tf.variable_scope('', reuse=True):
        for _, bare in candidates:
            variable = by_bare_name[bare]
            # Skip variables whose shape differs from the stored tensor.
            if variable.get_shape().as_list() == ckpt_shapes[bare]:
                restore_vars.append(variable)

    tf.train.Saver(restore_vars).restore(session, save_file)

    print('\n--------variables to restore:--------')
    for variable in restore_vars:
        print(variable)
def ConvMeanPool(inputs, output_dim, spectral_normed=False, update_collection=None, reuse=False,
                 k_h=5, k_w=5, d_h=1, d_w=1, in_dim=None, name=None):
    """Apply ``custom_ops.custom_conv2d`` and then average 2x2 neighbourhoods.

    The convolution result is transposed with [0, 3, 1, 2] (the op is named
    'NHWC_to_NCHW'), the four stride-2 slices of the last two axes are
    averaged, and the result is transposed back.

    Args:
        inputs: Tensor passed to the convolution.
        output_dim: Forwarded to ``custom_conv2d``.
        spectral_normed, update_collection, reuse: Forwarded to
            ``custom_conv2d``.
        k_h, k_w, d_h, d_w: Forwarded to ``custom_conv2d``.
        in_dim: Accepted but not used.
        name: Forwarded to ``custom_conv2d``.

    Returns:
        The pooled tensor, in the same axis order as the convolution output.
    """
    convolved = custom_ops.custom_conv2d(
        inputs=inputs,
        output_dim=output_dim,
        spectral_normed=spectral_normed,
        update_collection=update_collection,
        reuse=reuse,
        k_h=k_h, k_w=k_w, d_h=d_h, d_w=d_w,
        name=name)

    channels_first = tf.transpose(convolved, [0, 3, 1, 2], name='NHWC_to_NCHW')
    # The four members of every 2x2 window, taken as strided views.
    window_parts = [
        channels_first[:, :, ::2, ::2],
        channels_first[:, :, 1::2, ::2],
        channels_first[:, :, ::2, 1::2],
        channels_first[:, :, 1::2, 1::2],
    ]
    pooled = tf.add_n(window_parts) / 4.
    return tf.transpose(pooled, [0, 2, 3, 1], name='NCHW_to_NHWC')
def MeanPoolConv(inputs, output_dim, spectral_normed=False, update_collection=None, reuse=False,
                 k_h=5, k_w=5, d_h=1, d_w=1, in_dim=None, name=None):
    """Average 2x2 neighbourhoods and then apply ``custom_ops.custom_conv2d``.

    ``inputs`` is transposed with [0, 3, 1, 2] (the op is named
    'NHWC_to_NCHW'), the four stride-2 slices of the last two axes are
    averaged, the result is transposed back and finally convolved.

    Args:
        inputs: Tensor to pool and convolve.
        output_dim: Forwarded to ``custom_conv2d``.
        spectral_normed, update_collection, reuse: Forwarded to
            ``custom_conv2d``.
        k_h, k_w, d_h, d_w: Forwarded to ``custom_conv2d``.
        in_dim: Accepted but not used.
        name: Forwarded to ``custom_conv2d``.

    Returns:
        The output of ``custom_conv2d`` applied to the pooled tensor.
    """
    channels_first = tf.transpose(inputs, [0, 3, 1, 2], name='NHWC_to_NCHW')
    # The four members of every 2x2 window, taken as strided views.
    window_parts = [
        channels_first[:, :, ::2, ::2],
        channels_first[:, :, 1::2, ::2],
        channels_first[:, :, ::2, 1::2],
        channels_first[:, :, 1::2, 1::2],
    ]
    pooled = tf.add_n(window_parts) / 4.
    pooled = tf.transpose(pooled, [0, 2, 3, 1], name='NCHW_to_NHWC')

    return custom_ops.custom_conv2d(
        inputs=pooled,
        output_dim=output_dim,
        spectral_normed=spectral_normed,
        update_collection=update_collection,
        reuse=reuse,
        k_h=k_h, k_w=k_w, d_h=d_h, d_w=d_w,
        name=name)
def UpsampleConv(inputs, output_dim, spectral_normed=False, update_collection=None, reuse=False,
                 k_h=5, k_w=5, d_h=1, d_w=1, in_dim=None, name=None):
    """Enlarge ``inputs`` with ``tf.depth_to_space`` and then convolve it.

    Four copies of ``inputs`` are concatenated along axis 3 and rearranged by
    ``tf.depth_to_space`` with block size 2 before ``custom_ops.custom_conv2d``
    is applied.

    Args:
        inputs: Tensor to enlarge and convolve.
        output_dim: Forwarded to ``custom_conv2d``.
        spectral_normed, update_collection, reuse: Forwarded to
            ``custom_conv2d``.
        k_h, k_w, d_h, d_w: Forwarded to ``custom_conv2d``.
        in_dim: Accepted but not used.
        name: Forwarded to ``custom_conv2d``.

    Returns:
        The output of ``custom_conv2d`` applied to the enlarged tensor.
    """
    replicated = tf.concat([inputs] * 4, axis=3)
    enlarged = tf.depth_to_space(replicated, 2)
    return custom_ops.custom_conv2d(
        inputs=enlarged,
        output_dim=output_dim,
        spectral_normed=spectral_normed,
        update_collection=update_collection,
        reuse=reuse,
        k_h=k_h, k_w=k_w, d_h=d_h, d_w=d_w,
        name=name)
def ResidualBlock(name, input_dim, output_dim, filter_size, inputs, pixel_norm=False,
                  spectral_normed=False, update_collection=None, reuse=False,
                  resample=None, labels=None):
    """Build a two-convolution residual block inside variable scope ``name``.

    Args:
        name: Variable scope of the block. When the name contains the letter
            'D' the two ``custom_ops.Normalize`` calls are skipped.
        input_dim: Channel count of ``inputs``. Used as the width of the
            first convolution when ``resample == 'down'`` and to decide
            whether the shortcut can be the identity.
        output_dim: Channel count produced by the block.
        filter_size: Kernel height and width of both main-path convolutions.
        inputs: Input tensor.
        pixel_norm: Forwarded to ``custom_ops.Normalize``.
        spectral_normed, update_collection, reuse: Forwarded to every
            convolution of the block.
        resample: None, 'down', or 'up'. 'down' ends the main path with
            ``ConvMeanPool``; 'up' starts it with ``UpsampleConv``; None uses
            plain convolutions.
        labels: Forwarded to ``custom_ops.Normalize``.

    Returns:
        ``shortcut + output``, the sum of the skip path and the main path.

    Raises:
        ValueError: If ``resample`` is not None, 'down' or 'up'. The check
            runs before the variable scope is opened, so a bad argument never
            touches the graph.
    """
    if resample not in (None, 'down', 'up'):
        raise ValueError('invalid resample value: {!r}'.format(resample))

    # Keyword arguments shared by every convolution in the block.
    shared = dict(spectral_normed=spectral_normed,
                  update_collection=update_collection,
                  reuse=reuse)

    with tf.variable_scope(name):
        if resample == 'down':
            conv_1 = functools.partial(custom_ops.custom_conv2d, output_dim=input_dim, d_h=1, d_w=1)
            conv_2 = functools.partial(ConvMeanPool, output_dim=output_dim, d_h=1, d_w=1)
            conv_shortcut = ConvMeanPool
        elif resample == 'up':
            conv_1 = functools.partial(UpsampleConv, output_dim=output_dim, d_h=1, d_w=1)
            conv_shortcut = UpsampleConv
            conv_2 = functools.partial(custom_ops.custom_conv2d, output_dim=output_dim, d_h=1, d_w=1)
        else:  # resample is None
            conv_shortcut = custom_ops.custom_conv2d
            conv_1 = functools.partial(custom_ops.custom_conv2d, output_dim=output_dim, d_h=1, d_w=1)
            conv_2 = functools.partial(custom_ops.custom_conv2d, output_dim=output_dim, d_h=1, d_w=1)

        if output_dim == input_dim and resample is None:
            shortcut = inputs  # Identity skip-connection
        else:
            # 1x1 projection that also performs the resampling, if any.
            shortcut = conv_shortcut(inputs=inputs, output_dim=output_dim,
                                     k_h=1, k_w=1, d_h=1, d_w=1, name='.Shortcut',
                                     **shared)

        output = inputs
        if 'D' not in name:
            output = custom_ops.Normalize('.N1', output, pixel_norm=pixel_norm, labels=labels, training=True)
        output = tf.nn.relu(output)
        output = conv_1(inputs=output, k_h=filter_size, k_w=filter_size, name='.Conv1', **shared)

        if 'D' not in name:
            output = custom_ops.Normalize('.N2', output, pixel_norm=pixel_norm, labels=labels, training=True)
        output = tf.nn.relu(output)
        output = conv_2(inputs=output, k_h=filter_size, k_w=filter_size, name='.Conv2', **shared)

        return shortcut + output
def OptimizedResBlockDisc1(inputs, DIM_D=128, spectral_normed=False, update_collection=None, reuse=False):
    """Build the first discriminator block under variable scope 'D.1'.

    The main path is a 3x3 ``custom_conv2d``, a ReLU, and a 3x3
    ``ConvMeanPool``. The skip path is a 1x1 ``MeanPoolConv``. No activation
    is applied to ``inputs`` before the first convolution.

    Args:
        inputs: Input tensor.
        DIM_D: Channel count of every convolution in the block.
        spectral_normed, update_collection, reuse: Forwarded to every
            convolution in the block.

    Returns:
        The sum of the skip path and the main path.
    """
    sn_kwargs = {
        'spectral_normed': spectral_normed,
        'update_collection': update_collection,
        'reuse': reuse,
    }
    with tf.variable_scope("D.1"):
        skip = MeanPoolConv(inputs=inputs, output_dim=DIM_D,
                            k_h=1, k_w=1, d_h=1, d_w=1, name='.Shortcut',
                            **sn_kwargs)

        residual = custom_ops.custom_conv2d(inputs=inputs, output_dim=DIM_D,
                                            d_h=1, d_w=1, k_h=3, k_w=3, name='.Conv1',
                                            **sn_kwargs)
        residual = tf.nn.relu(residual)
        residual = ConvMeanPool(inputs=residual, output_dim=DIM_D,
                                d_h=1, d_w=1, k_h=3, k_w=3, name='.Conv2',
                                **sn_kwargs)
        return skip + residual
def Generator(noise, labels=None, imsize=128, training=True):
    """Build the fixed-depth ResNet generator.

    ``noise`` goes through a fully connected layer, is reshaped to
    [-1, 4, 4, 1024], and passes through five 'up' residual blocks
    ('G.1' to 'G.5') with widths 1024, 512, 256, 128 and 64. A final
    normalize / ReLU / 3x3 convolution / tanh stage yields 3 channels.

    Args:
        noise: Input to the 'G.Input' fully connected layer.
        labels: Forwarded to every ``ResidualBlock``.
        imsize: Accepted but not used.
        training: Accepted but not used.

    Returns:
        The tanh-activated output tensor.
    """
    projected = custom_ops.custom_fully_connected(noise, 4 * 4 * 1024, scope='G.Input')
    features = tf.reshape(projected, [-1, 4, 4, 1024])

    # Consecutive entries are the (input, output) widths of blocks G.1 .. G.5.
    widths = [1024, 1024, 512, 256, 128, 64]
    for block_id, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:]), 1):
        features = ResidualBlock('G.{}'.format(block_id), width_in, width_out, 3, features,
                                 pixel_norm=False, resample='up', labels=labels)

    features = custom_ops.Normalize('G.Output_N', features, pixel_norm=False)
    features = tf.nn.relu(features)
    features = custom_ops.custom_conv2d(inputs=features, output_dim=3,
                                        k_h=3, k_w=3, d_h=1, d_w=1,
                                        name='G.Output')
    return tf.tanh(features)
def Discriminator(x_var, c_var, labels=None, imsize=128, update_collection=None, reuse=False):
    """Build the fixed-depth, spectrally normalised ResNet discriminator.

    ``x_var`` passes through ``OptimizedResBlockDisc1`` and the 'down' blocks
    'D.2' and 'D.3'. ``c_var`` is then expanded on axes 1 and 2, tiled to
    ``imsize // 8`` in both, and concatenated on axis 3. Blocks 'D.4' and
    'D.5' ('down') and 'D.6' (no resampling) follow, then ReLU, a mean over
    axes 1 and 2, and a single-output fully connected layer.

    Args:
        x_var: Image tensor.
        c_var: Conditioning tensor. It is expanded twice at axis 1, so it is
            presumably 2-D (batch, features); confirm against the caller.
        labels: Forwarded to every ``ResidualBlock``.
        imsize: Spatial size used to tile ``c_var``.
        update_collection, reuse: Forwarded to every layer.

    Returns:
        The critic output reshaped to a 1-D tensor.
    """
    sn_kwargs = dict(spectral_normed=True, update_collection=update_collection, reuse=reuse)

    x_code = OptimizedResBlockDisc1(x_var, DIM_D=64, **sn_kwargs)
    for scope, width_in, width_out in (('D.2', 64, 128), ('D.3', 128, 256)):
        x_code = ResidualBlock(scope, width_in, width_out, 3, x_code,
                               resample='down', labels=labels, **sn_kwargs)

    # Broadcast the conditioning vector over the spatial grid of x_code.
    grid = imsize // 8
    c_code = tf.expand_dims(tf.expand_dims(c_var, 1), 1)
    c_code = tf.tile(c_code, [1, grid, grid, 1])
    output = tf.concat(axis=3, values=[x_code, c_code])

    tail = (
        ('D.4', 256, 512, 'down'),
        ('D.5', 512, 1024, 'down'),
        ('D.6', 1024, 1024, None),
    )
    for scope, width_in, width_out, mode in tail:
        output = ResidualBlock(scope, width_in, width_out, 3, output,
                               resample=mode, labels=labels, **sn_kwargs)

    output = tf.nn.relu(output)
    output = tf.reduce_mean(output, axis=[1, 2])
    output_wgan = custom_ops.custom_fully_connected(output, 1, scope='D.Output', **sn_kwargs)
    return tf.reshape(output_wgan, [-1])
- # ######## ######## PGGAN ######## ######## #
def get_dim(stage):
    """Return the channel count for stage ``stage`` as an ``int``.

    The width is 1024 halved ``stage`` times and capped at 512. Floor
    division is used so that the result is always an integer; true division
    on Python 3 would return a float such as ``256.0`` for ``stage >= 1``.

    Args:
        stage: Integer stage index. Zero and negative stages give the cap.

    Returns:
        ``min(1024 // 2 ** stage, 512)`` as an ``int``.
    """
    if stage <= 0:
        # 1024 * 2 ** |stage| is never below the cap.
        return 512
    return min(1024 // (2 ** stage), 512)
def Generator_PGGAN(noise, pg, trans=False, alpha=0.01, pixel_norm=True, labels=None, training=True):
    """Build the progressively grown generator with ``pg`` residual blocks.

    ``noise`` is projected to [-1, 4, 4, 1024] and passed through 'up' blocks
    'G.Block.1' to 'G.Block.{pg}'. Block ``k`` (for ``k >= 2``) has width
    ``get_dim(k - 2)``. The shape of every stage is printed.

    Args:
        noise: Input to the 'G.Input' fully connected layer.
        pg: Count of ResidualBlock; the index of the last block.
        trans: When true, the last block is blended with a second branch: the
            previous features resized by nearest neighbour and passed through
            the 1x1 block 'G.{pg}_toRGB'.
        alpha: Blend weight of the last block in the ``trans`` case; the
            second branch gets ``1 - alpha``.
        pixel_norm: Forwarded to every block and to the output normalisation.
        labels: Forwarded to every ``ResidualBlock``.
        training: Accepted but not used.

    Return:
        The tanh-activated 3-channel output tensor.
    """

    def grow(scope, features, width):
        # One 3x3 'up' block whose input width is read from the tensor itself.
        return ResidualBlock(scope, features.shape.as_list()[-1], width, 3, features,
                             pixel_norm=pixel_norm, resample='up', labels=labels)

    # 4 * 4 * 1024
    projected = custom_ops.custom_fully_connected(noise, 4 * 4 * 1024, scope='G.Input')
    output = tf.reshape(projected, [-1, 4, 4, 1024])

    # 8 * 8 * 1024
    output = ResidualBlock('G.Block.1', 1024, 1024, 3, output,
                           pixel_norm=pixel_norm, resample='up', labels=labels)
    print('G.Block.1: {}'.format(output.shape.as_list()))

    for block_id in range(2, pg):
        output = grow('G.Block.{}'.format(block_id), output, get_dim(block_id - 2))
        print('G.Block.{}: {}'.format(block_id, output.shape.as_list()))

    last_scope = 'G.Block.{}'.format(pg)
    last_width = get_dim(pg - 2)
    if not trans:
        output = grow(last_scope, output, last_width)
        print('G.Block.{}: {}'.format(pg, output.shape.as_list()))
    else:
        toRGB1 = grow(last_scope, output, last_width)
        print('G.Block.{}: {}'.format(pg, toRGB1.shape.as_list()))

        # Second branch: match the new resolution without the new block.
        target_size = toRGB1.shape.as_list()[1:3]
        toRGB2 = tf.image.resize_nearest_neighbor(output, target_size)
        toRGB2 = ResidualBlock('G.{}_toRGB'.format(pg), toRGB2.shape.as_list()[-1],
                               last_width, 1, toRGB2,
                               pixel_norm=pixel_norm, resample=None, labels=labels)

        output = (1.0 - alpha) * toRGB2 + alpha * toRGB1
        print('G.{}_toRGB: {}'.format(pg, toRGB2.shape.as_list()))

    output = custom_ops.Normalize('G.Output_Normalize', output, pixel_norm=pixel_norm)
    output = tf.nn.relu(output)
    output = custom_ops.custom_conv2d(inputs=output, output_dim=3,
                                      k_h=3, k_w=3, d_h=1, d_w=1,
                                      name='G.Output')
    print('G.Output: {}'.format(output.shape.as_list()))
    return tf.tanh(output)
def Discriminator_PGGAN(x_var, c_var, pg, trans=False, alpha=0.01, labels=None,
                        update_collection=None, reuse=False):
    """Build the progressively grown, spectrally normalised discriminator.

    ``x_var`` enters through the 'down' block 'D.Block.{pg}'. The blocks then
    count down to 'D.Block.1', with ``c_var`` tiled and concatenated on axis 3
    after the first ``ceil((pg - 2) / 2)`` of them. 'D.0' (no resampling),
    ReLU, a mean over axes 1 and 2, and a single-output fully connected layer
    finish the network. The shape of every stage is printed.

    Args:
        x_var: Image tensor; the entry blocks are declared with 3 input
            channels.
        c_var: Conditioning tensor. It is expanded twice at axis 1, so it is
            presumably 2-D (batch, features); confirm against the caller.
        pg: Index of the entry block.
        trans: When true, the entry block is blended with a second branch:
            the 1x1 block 'D.{pg}_fromRGB' applied to ``x_var`` and resized by
            nearest neighbour to the entry block's spatial size.
        alpha: Blend weight of the entry block in the ``trans`` case; the
            second branch gets ``1 - alpha``.
        labels: Forwarded to every ``ResidualBlock``.
        reuse: Forwarded to every layer.
        update_collection: Forwarded to every layer.

    Return:
        The critic output reshaped to a 1-D tensor.
    """

    def sn_block(scope, width_in, width_out, kernel, features, mode):
        # Every block of this network shares the same spectral-norm settings.
        return ResidualBlock(scope, width_in, width_out, kernel, features,
                             spectral_normed=True,
                             update_collection=update_collection,
                             reuse=reuse,
                             resample=mode,
                             labels=labels)

    def descend(level, features):
        # 'down' block number `level`; the shape is reported after building.
        features = sn_block('D.Block.{}'.format(level), features.shape.as_list()[-1],
                            get_dim(level - 2), 3, features, 'down')
        print('D.Block.{}: {}'.format(level, features.shape.as_list()))
        return features

    # imsize = 4 * pow(2, pg)
    entry_width = get_dim(pg - 2)
    x_code = sn_block('D.Block.{}'.format(pg), 3, entry_width, 3, x_var, 'down')
    print('D.Block.{}: {}'.format(pg, x_code.shape.as_list()))

    if trans:
        fromRGB = sn_block('D.{}_fromRGB'.format(pg), 3, entry_width, 1, x_var, None)
        print('D.{}_fromRGB: {}'.format(pg, fromRGB.shape.as_list()))
        fromRGB = tf.image.resize_nearest_neighbor(fromRGB, x_code.shape.as_list()[1:3])
        x_code = (1.0 - alpha) * fromRGB + alpha * x_code

    step = int(math.ceil((pg - 2) / 2.))
    print('----setp----: {}'.format(step))

    # Blocks applied before the conditioning tensor is attached.
    for offset in range(1, step + 1):
        x_code = descend(pg - offset, x_code)

    print('---- concat ----')
    c_code = tf.expand_dims(tf.expand_dims(c_var, 1), 1)
    height, width = x_code.shape.as_list()[1:3]
    c_code = tf.tile(c_code, [1, height, width, 1])
    output = tf.concat(axis=3, values=[x_code, c_code])

    # Remaining blocks down to, but not including, D.Block.1.
    for offset in range(step + 1, pg - 1):
        output = descend(pg - offset, output)

    output = sn_block('D.Block.1', output.shape.as_list()[-1], 1024, 3, output, 'down')
    print('D.Block.1: {}'.format(output.shape.as_list()))

    output = sn_block('D.0', 1024, 1024, 3, output, None)
    print('D.0: {}'.format(output.shape.as_list()))

    output = tf.nn.relu(output)
    output = tf.reduce_mean(output, axis=[1, 2])
    logits = custom_ops.custom_fully_connected(output, 1,
                                               spectral_normed=True,
                                               update_collection=update_collection,
                                               reuse=reuse,
                                               scope='D.Output')
    return tf.reshape(logits, [-1])
Add Comment
Please, Sign In to add comment