Guest User

Untitled

a guest
Jan 20th, 2018
97
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 17.21 KB | None | 0 0
  1. """PGGAN based on ResNet."""
  2.  
  3. import numpy as np
  4. import tensorflow as tf
  5. import functools
  6. import locale
  7. import os
  8. import sys
  9. import math
  10.  
  11. from misc import custom_ops
  12.  
# Add the current working directory to the module search path.
# NOTE(review): this runs after `from misc import custom_ops`, so it cannot
# affect how that import is resolved -- confirm the ordering is intended.
sys.path.append(os.getcwd())
# Switch every locale category to the environment's default locale
# (the empty string asks for the user's preferred locale).
locale.setlocale(locale.LC_ALL, '')
  15.  
  16.  
  17. def optimistic_restore(session, save_file):
  18. """
  19. Args:
  20. session:
  21. save_file:
  22.  
  23. Returns:
  24. """
  25. reader = tf.train.NewCheckpointReader(save_file)
  26. saved_shapes = reader.get_variable_to_shape_map()
  27. var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables()
  28. if var.name.split(':')[0] in saved_shapes])
  29. restore_vars = []
  30.  
  31. name2var = dict(zip(map(lambda x: x.name.split(':')[0], tf.global_variables()), tf.global_variables()))
  32.  
  33. with tf.variable_scope('', reuse=True):
  34. for var_name, saved_var_name in var_names:
  35. curr_var = name2var[saved_var_name]
  36. var_shape = curr_var.get_shape().as_list()
  37. if var_shape == saved_shapes[saved_var_name]:
  38. restore_vars.append(curr_var)
  39. saver = tf.train.Saver(restore_vars)
  40. saver.restore(session, save_file)
  41.  
  42. # print('\n--------variables stored:--------')
  43. # for var_name, saved_var_name in var_names:
  44. # print(var_name)
  45.  
  46. print('\n--------variables to restore:--------')
  47. for var in restore_vars:
  48. print(var)
  49.  
  50.  
  51. def ConvMeanPool(inputs, output_dim, spectral_normed=False, update_collection=None, reuse=False,
  52. k_h=5, k_w=5, d_h=1, d_w=1, in_dim=None, name=None):
  53. output = custom_ops.custom_conv2d(inputs=inputs, output_dim=output_dim,
  54. spectral_normed=spectral_normed,
  55. update_collection=update_collection,
  56. reuse=reuse,
  57. k_h=k_h, k_w=k_w, d_h=d_h, d_w=d_w, name=name)
  58. output = tf.transpose(output, [0, 3, 1, 2], name='NHWC_to_NCHW')
  59. output = tf.add_n(
  60. [output[:, :, ::2, ::2], output[:, :, 1::2, ::2], output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
  61. output = tf.transpose(output, [0, 2, 3, 1], name='NCHW_to_NHWC')
  62.  
  63. return output
  64.  
  65.  
  66. def MeanPoolConv(inputs, output_dim, spectral_normed=False, update_collection=None, reuse=False,
  67. k_h=5, k_w=5, d_h=1, d_w=1, in_dim=None, name=None):
  68. output = inputs
  69. output = tf.transpose(output, [0, 3, 1, 2], name='NHWC_to_NCHW')
  70. output = tf.add_n(
  71. [output[:, :, ::2, ::2], output[:, :, 1::2, ::2], output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
  72. output = tf.transpose(output, [0, 2, 3, 1], name='NCHW_to_NHWC')
  73. output = custom_ops.custom_conv2d(inputs=output, output_dim=output_dim,
  74. spectral_normed=spectral_normed,
  75. update_collection=update_collection,
  76. reuse=reuse,
  77. k_h=k_h, k_w=k_w, d_h=d_h, d_w=d_w, name=name)
  78. return output
  79.  
  80.  
  81. def UpsampleConv(inputs, output_dim, spectral_normed=False, update_collection=None, reuse=False,
  82. k_h=5, k_w=5, d_h=1, d_w=1, in_dim=None, name=None):
  83. output = inputs
  84. output = tf.concat([output, output, output, output], axis=3)
  85. output = tf.depth_to_space(output, 2)
  86. output = custom_ops.custom_conv2d(inputs=output, output_dim=output_dim,
  87. spectral_normed=spectral_normed,
  88. update_collection=update_collection,
  89. reuse=reuse,
  90. k_h=k_h, k_w=k_w, d_h=d_h, d_w=d_w, name=name)
  91. return output
  92.  
  93.  
  94. def ResidualBlock(name, input_dim, output_dim, filter_size, inputs, pixel_norm=False,
  95. spectral_normed=False, update_collection=None, reuse=False,
  96. resample=None, labels=None):
  97. """resample: None, 'down', or 'up'.
  98. """
  99. with tf.variable_scope(name):
  100. if resample == 'down':
  101. conv_1 = functools.partial(custom_ops.custom_conv2d, output_dim=input_dim, d_h=1, d_w=1)
  102. conv_2 = functools.partial(ConvMeanPool, output_dim=output_dim, d_h=1, d_w=1)
  103. conv_shortcut = ConvMeanPool
  104. elif resample == 'up':
  105. conv_1 = functools.partial(UpsampleConv, output_dim=output_dim, d_h=1, d_w=1)
  106. conv_shortcut = UpsampleConv
  107. conv_2 = functools.partial(custom_ops.custom_conv2d, output_dim=output_dim, d_h=1, d_w=1)
  108. elif resample is None:
  109. conv_shortcut = custom_ops.custom_conv2d
  110. conv_1 = functools.partial(custom_ops.custom_conv2d, output_dim=output_dim, d_h=1, d_w=1)
  111. conv_2 = functools.partial(custom_ops.custom_conv2d, output_dim=output_dim, d_h=1, d_w=1)
  112. else:
  113. raise Exception('invalid resample value')
  114.  
  115. if output_dim == input_dim and resample is None:
  116. shortcut = inputs # Identity skip-connection
  117. else:
  118. shortcut = conv_shortcut(inputs=inputs, output_dim=output_dim,
  119. spectral_normed=spectral_normed,
  120. update_collection=update_collection,
  121. reuse=reuse,
  122. k_h=1, k_w=1, d_h=1, d_w=1, name='.Shortcut')
  123.  
  124. output = inputs
  125. if 'D' not in name:
  126. output = custom_ops.Normalize('.N1', output, pixel_norm=pixel_norm, labels=labels, training=True)
  127. output = tf.nn.relu(output)
  128. output = conv_1(inputs=output, k_h=filter_size, k_w=filter_size, name='.Conv1',
  129. spectral_normed=spectral_normed,
  130. update_collection=update_collection,
  131. reuse=reuse)
  132. if 'D' not in name:
  133. output = custom_ops.Normalize('.N2', output, pixel_norm=pixel_norm, labels=labels, training=True)
  134. output = tf.nn.relu(output)
  135. output = conv_2(inputs=output, k_h=filter_size, k_w=filter_size, name='.Conv2',
  136. spectral_normed=spectral_normed,
  137. update_collection=update_collection,
  138. reuse=reuse)
  139.  
  140. return shortcut + output
  141.  
  142.  
  143. def OptimizedResBlockDisc1(inputs, DIM_D=128, spectral_normed=False, update_collection=None, reuse=False):
  144. with tf.variable_scope("D.1"):
  145. conv_1 = functools.partial(custom_ops.custom_conv2d, output_dim=DIM_D, d_h=1, d_w=1)
  146. conv_2 = functools.partial(ConvMeanPool, output_dim=DIM_D, d_h=1, d_w=1)
  147. conv_shortcut = MeanPoolConv
  148. shortcut = conv_shortcut(inputs=inputs, output_dim=DIM_D,
  149. spectral_normed=spectral_normed,
  150. update_collection=update_collection,
  151. reuse=reuse,
  152. k_h=1, k_w=1, d_h=1, d_w=1, name='.Shortcut')
  153.  
  154. output = inputs
  155. output = conv_1(inputs=output, k_h=3, k_w=3, name='.Conv1',
  156. spectral_normed=spectral_normed,
  157. update_collection=update_collection,
  158. reuse=reuse)
  159. output = tf.nn.relu(output)
  160. output = conv_2(inputs=output, k_h=3, k_w=3, name='.Conv2',
  161. spectral_normed=spectral_normed,
  162. update_collection=update_collection,
  163. reuse=reuse)
  164. return shortcut + output
  165.  
  166.  
  167. def Generator(noise, labels=None, imsize=128, training=True):
  168. output = custom_ops.custom_fully_connected(noise, 4 * 4 * 1024, scope='G.Input')
  169. output = tf.reshape(output, [-1, 4, 4, 1024])
  170.  
  171. output = ResidualBlock('G.1', 1024, 1024, 3, output, pixel_norm=False, resample='up', labels=labels)
  172. output = ResidualBlock('G.2', 1024, 512, 3, output, pixel_norm=False, resample='up', labels=labels)
  173. output = ResidualBlock('G.3', 512, 256, 3, output, pixel_norm=False, resample='up', labels=labels)
  174. output = ResidualBlock('G.4', 256, 128, 3, output, pixel_norm=False, resample='up', labels=labels)
  175. output = ResidualBlock('G.5', 128, 64, 3, output, pixel_norm=False, resample='up', labels=labels)
  176. output = custom_ops.Normalize('G.Output_N', output, pixel_norm=False)
  177. output = tf.nn.relu(output)
  178.  
  179. output = custom_ops.custom_conv2d(inputs=output, output_dim=3, k_h=3, k_w=3, d_h=1, d_w=1,
  180. name='G.Output')
  181. output = tf.tanh(output)
  182.  
  183. return output
  184.  
  185.  
  186. def Discriminator(x_var, c_var, labels=None, imsize=128, update_collection=None, reuse=False):
  187. x_code = OptimizedResBlockDisc1(x_var, DIM_D=64,
  188. spectral_normed=True,
  189. update_collection=update_collection,
  190. reuse=reuse)
  191. x_code = ResidualBlock('D.2', 64, 128, 3, x_code,
  192. spectral_normed=True,
  193. update_collection=update_collection,
  194. reuse=reuse,
  195. resample='down', labels=labels)
  196. x_code = ResidualBlock('D.3', 128, 256, 3, x_code,
  197. spectral_normed=True,
  198. update_collection=update_collection,
  199. reuse=reuse,
  200. resample='down', labels=labels)
  201.  
  202. c_code = tf.expand_dims(tf.expand_dims(c_var, 1), 1)
  203. c_code = tf.tile(c_code, [1, imsize // 8, imsize // 8, 1])
  204. x_c_code = tf.concat(axis=3, values=[x_code, c_code])
  205.  
  206. output = ResidualBlock('D.4', 256, 512, 3, x_c_code,
  207. spectral_normed=True,
  208. update_collection=update_collection,
  209. reuse=reuse,
  210. resample='down', labels=labels)
  211.  
  212. output = ResidualBlock('D.5', 512, 1024, 3, output,
  213. spectral_normed=True,
  214. update_collection=update_collection,
  215. reuse=reuse,
  216. resample='down', labels=labels)
  217. output = ResidualBlock('D.6', 1024, 1024, 3, output,
  218. spectral_normed=True,
  219. update_collection=update_collection,
  220. reuse=reuse,
  221. resample=None, labels=labels)
  222.  
  223. output = tf.nn.relu(output)
  224. output = tf.reduce_mean(output, axis=[1, 2])
  225. output_wgan = custom_ops.custom_fully_connected(output, 1,
  226. spectral_normed=True,
  227. update_collection=update_collection,
  228. reuse=reuse,
  229. scope='D.Output')
  230. output_wgan = tf.reshape(output_wgan, [-1])
  231.  
  232. return output_wgan
  233.  
  234.  
  235. # ######## ######## PGGAN ######## ######## #
  236. def get_dim(stage):
  237. return min(1024 / (2 ** stage), 512)
  238.  
  239.  
  240. def Generator_PGGAN(noise, pg, trans=False, alpha=0.01, pixel_norm=True, labels=None, training=True):
  241. """
  242. Args:
  243. noise:
  244. pg: Count of ResidualBlock.
  245. trans:
  246. alpha:
  247. pixel_norm:
  248. labels:
  249. training:
  250. Return:
  251. """
  252. # pg_ = pg
  253.  
  254. # 4 * 4 * 1024
  255. output = custom_ops.custom_fully_connected(noise, 4 * 4 * 1024, scope='G.Input')
  256. output = tf.reshape(output, [-1, 4, 4, 1024])
  257. # 8 * 8 * 1024
  258. output = ResidualBlock('G.Block.1', 1024, 1024, 3, output, pixel_norm=pixel_norm, resample='up', labels=labels)
  259. print('G.Block.1: {}'.format(output.shape.as_list()))
  260.  
  261. for i in range(pg - 2):
  262. output = ResidualBlock('G.Block.{}'.format(i + 2), output.shape.as_list()[-1],
  263. get_dim(i), 3, output, pixel_norm=pixel_norm, resample='up', labels=labels)
  264. print('G.Block.{}: {}'.format(i + 2, output.shape.as_list()))
  265.  
  266. if trans:
  267. toRGB1 = ResidualBlock('G.Block.{}'.format(pg), output.shape.as_list()[-1],
  268. get_dim(pg - 2), 3, output, pixel_norm=pixel_norm, resample='up', labels=labels)
  269. print('G.Block.{}: {}'.format(pg, toRGB1.shape.as_list()))
  270.  
  271. toRGB2 = \
  272. tf.image.resize_nearest_neighbor(output, [toRGB1.shape.as_list()[1], toRGB1.shape.as_list()[2]])
  273. toRGB2 = ResidualBlock('G.{}_toRGB'.format(pg), toRGB2.shape.as_list()[-1],
  274. get_dim(pg - 2), 1, toRGB2, pixel_norm=pixel_norm, resample=None, labels=labels)
  275. output = (1.0 - alpha) * toRGB2 + alpha * toRGB1
  276. print('G.{}_toRGB: {}'.format(pg, toRGB2.shape.as_list()))
  277. else:
  278. output = ResidualBlock('G.Block.{}'.format(pg), output.shape.as_list()[-1],
  279. get_dim(pg - 2), 3, output, pixel_norm=pixel_norm, resample='up', labels=labels)
  280. print('G.Block.{}: {}'.format(pg, output.shape.as_list()))
  281. output = custom_ops.Normalize('G.Output_Normalize', output, pixel_norm=pixel_norm)
  282. output = tf.nn.relu(output)
  283.  
  284. output = custom_ops.custom_conv2d(inputs=output, output_dim=3, k_h=3, k_w=3, d_h=1, d_w=1,
  285. name='G.Output')
  286. print('G.Output: {}'.format(output.shape.as_list()))
  287.  
  288. output = tf.tanh(output)
  289.  
  290. return output
  291.  
  292.  
  293. def Discriminator_PGGAN(x_var, c_var, pg, trans=False, alpha=0.01, labels=None,
  294. update_collection=None, reuse=False):
  295. """
  296. Args:
  297. x_var:
  298. c_var:
  299. pg:
  300. trans:
  301. alpha:
  302. labels:
  303. reuse:
  304. update_collection:
  305. Return:
  306. """
  307. # imsize = 4 * pow(2, pg)
  308.  
  309. if trans:
  310. x_code = ResidualBlock('D.Block.{}'.format(pg), 3, get_dim(pg - 2), 3, x_var,
  311. spectral_normed=True,
  312. update_collection=update_collection,
  313. reuse=reuse,
  314. resample='down',
  315. labels=labels)
  316. print('D.Block.{}: {}'.format(pg, x_code.shape.as_list()))
  317.  
  318. fromRGB = ResidualBlock('D.{}_fromRGB'.format(pg), 3, get_dim(pg - 2), 1, x_var,
  319. spectral_normed=True,
  320. update_collection=update_collection,
  321. reuse=reuse,
  322. resample=None,
  323. labels=labels)
  324. print('D.{}_fromRGB: {}'.format(pg, fromRGB.shape.as_list()))
  325. fromRGB = \
  326. tf.image.resize_nearest_neighbor(fromRGB, [x_code.shape.as_list()[1], x_code.shape.as_list()[2]])
  327. x_code = (1.0 - alpha) * fromRGB + alpha * x_code
  328. else:
  329. x_code = ResidualBlock('D.Block.{}'.format(pg), 3, get_dim(pg - 2), 3, x_var,
  330. spectral_normed=True,
  331. update_collection=update_collection,
  332. reuse=reuse,
  333. resample='down',
  334. labels=labels)
  335. print('D.Block.{}: {}'.format(pg, x_code.shape.as_list()))
  336.  
  337. step = int(math.ceil((pg - 2) / 2.))
  338. print('----setp----: {}'.format(step))
  339. for i in range(1, step + 1):
  340. x_code = ResidualBlock('D.Block.{}'.format(pg - i), x_code.shape.as_list()[-1],
  341. get_dim(pg - 2 - i), 3, x_code,
  342. spectral_normed=True,
  343. update_collection=update_collection,
  344. reuse=reuse,
  345. resample='down',
  346. labels=labels)
  347. print('D.Block.{}: {}'.format(pg - i, x_code.shape.as_list()))
  348.  
  349. print('---- concat ----')
  350. c_code = tf.expand_dims(tf.expand_dims(c_var, 1), 1)
  351. c_code = tf.tile(c_code, [1, x_code.shape.as_list()[1], x_code.shape.as_list()[2], 1])
  352. x_c_code = tf.concat(axis=3, values=[x_code, c_code])
  353. output = x_c_code
  354.  
  355. for i in range(step + 1, pg - 1):
  356. output = ResidualBlock('D.Block.{}'.format(pg - i), output.shape.as_list()[-1],
  357. get_dim(pg - 2 - i), 3, output,
  358. spectral_normed=True,
  359. update_collection=update_collection,
  360. reuse=reuse,
  361. resample='down',
  362. labels=labels)
  363. print('D.Block.{}: {}'.format(pg - i, output.shape.as_list()))
  364.  
  365. output = ResidualBlock('D.Block.1', output.shape.as_list()[-1], 1024, 3, output,
  366. spectral_normed=True,
  367. update_collection=update_collection,
  368. reuse=reuse,
  369. resample='down',
  370. labels=labels)
  371. print('D.Block.1: {}'.format(output.shape.as_list()))
  372.  
  373. output = ResidualBlock('D.0', 1024, 1024, 3, output,
  374. spectral_normed=True,
  375. update_collection=update_collection,
  376. reuse=reuse,
  377. resample=None,
  378. labels=labels)
  379. print('D.0: {}'.format(output.shape.as_list()))
  380.  
  381. output = tf.nn.relu(output)
  382. output = tf.reduce_mean(output, axis=[1, 2])
  383. logits = custom_ops.custom_fully_connected(output, 1,
  384. spectral_normed=True,
  385. update_collection=update_collection,
  386. reuse=reuse,
  387. scope='D.Output')
  388.  
  389. output_wgan = tf.reshape(logits, [-1])
  390.  
  391. return output_wgan
Add Comment
Please, Sign In to add comment