# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

"""Minimal script for generating an image using a pre-trained StyleGAN generator,
extended with an experimental loop that fits the latent vector to a target image
by gradient descent on an L2 pixel loss."""

import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config
import tensorflow as tf

def main():
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
    with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
        _G, _D, Gs = pickle.load(f)
        # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
        # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
        # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.

    print(type(Gs))
    # Print network details.
    Gs.print_layers()

    # Pick latent vector.
    rnd = np.random.RandomState(6)
    latents = rnd.randn(1, Gs.input_shape[1])
    print("input shape", Gs.input_shape)

    # Generate image. output_transform converts the NCHW float output to HWC uint8.
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=False, output_transform=fmt)

    # Save image.
    os.makedirs(config.result_dir, exist_ok=True)
    png_filename = os.path.join(config.result_dir, 'example.png')
    PIL.Image.fromarray(images[0], 'RGB').save(png_filename)

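    # ------------------------------------------------------------------
    # Experimental addition: project a target image back into the latent
    # space by gradient descent on a pixel-wise L2 loss. The input/output
    # expressions cached inside Gs by the Gs.run() call above are reused
    # as feed targets for the optimization loop below.
    # ------------------------------------------------------------------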
    print("\n\ncache:", Gs._run_cache)
    keys = list(Gs._run_cache.keys())

    in_expr, out_expr = Gs._run_cache[keys[0]]
    print(in_expr, out_expr)

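    # Target image to reconstruct. results/base.png is assumed to be a
    # 1024x1024 RGB image (e.g. a sample generated earlier by this script);
    # it is loaded here as float32 in HWC layout.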
    target = PIL.Image.open("results/base.png")
    target_expr = tf.constant(np.float32(target))

    learning_rate = .1
    opt = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
    #train = opt.minimize(total_loss, var_list=[in_expr[0]])
    out_gpu = Gs.get_output_for(*in_expr, return_as_list=True)
    print("\n\n\n$$$$$$$$$$$$$$$$\n>", out_gpu)
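    # Note: the RMSProp optimizer above is constructed but never applied; the
    # latent vector is updated manually with NumPy in the loop further down.
    # out_gpu is the symbolic generator output for the cached placeholders: a
    # list holding one NCHW float tensor with values roughly in [-1, 1].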

    #grads,_ = tf.gradients(pixelloss, in_expr)
    #print("grads", grads)
    #exit()
    #writer = tf.summary.FileWriter("output", sess.graph)
    #rnd = np.random.RandomState(5)
    #latents2 = rnd.randn(1, Gs.input_shape[1])
    #dx2 = latents2-latents
    tf.keras.backend.set_image_data_format('channels_first')

    # Leftover experiment with a perceptual loss on MobileNet features; kept
    # commented out for reference, it is not used below.
    #pretrained_resnet = tf.keras.applications.MobileNet(
    #    input_tensor = tf.stack([tf.reshape(out_gpu,(3,1024,1024)),target_expr],0),
    #    #input_tensor = tf.reshape(target_expr,(1,3,1024,1024)),
    #    weights="imagenet",
    #    include_top=False,
    #    #input_shape=(1024,1024,3)
    #)

    #l1,l2 = tf.split(pretrained_resnet.output, 2)
    #contx_loss = tf.reduce_mean(l1-l2)
    #loss,resout = sess.run([contx_loss, pretrained_resnet.output],
    #    {in_expr[0]: latents, in_expr[1] : np.zeros((1,0))})
    #print(loss)
    #grads,_ = tf.gradients(0.01*pixelloss+0.1*contx_loss, in_expr)
    # Uint8 views used only for dumping debug images; target_expr_im keeps the
    # target in its original HWC uint8-range values.
    out_gpu_im = tflib.convert_images_to_uint8(out_gpu)
    target_expr_im = target_expr
    # Convert the target into the generator's NCHW float range so it can be
    # compared directly against out_gpu.
    target_expr = tflib.convert_images_from_uint8(tf.expand_dims(target_expr, 0), nhwc_to_nchw=True)
    #target_expr = tflib.convert_images_from_uint8(out_gpu_im)

    print("\n\n\n", out_gpu, target_expr, "\n\n\n")
    # Pixel-wise L2 loss between generated image and target, differentiated
    # with respect to the latent placeholder in_expr[0].
    pixelloss = tf.nn.l2_loss(out_gpu - target_expr)
    grads, _ = tf.gradients(pixelloss, in_expr)
    #grads,_ = tf.gradients(contx_loss, in_expr)
    #exit()
    print("output tensors", out_gpu, target_expr)
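
    # Plain gradient descent on the latent vector: evaluate the loss and its
    # gradient for the current latents, take a small fixed-size step, and dump
    # intermediate images every 100 iterations.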
    sess = tf.get_default_session()
    for i in range(1000):
        print("gradient step", i)
        dx, ploss, out, t, og, te = sess.run(
            [grads, pixelloss, out_gpu_im, target_expr_im, out_gpu, target_expr],
            {in_expr[0]: latents, in_expr[1]: np.zeros((1, 0))})

        print("pixel loss", ploss)
        #print(dx[0].shape, np.max(dx[0]))
        #latents = latents - ((0.00001/(1+0.1*i))*+dx[0])
        latents = latents - 0.00001 * dx
        #np.save("og.np", og[0])
        #np.save("te.np", te[0])

        #writer.close()
        if i % 100 == 0:
            # Save the current reconstruction alongside the target for visual comparison.
            images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=False, output_transform=fmt)

            print("output tensors", len(out), out[0].shape, t.shape)
            PIL.Image.fromarray(images[0], 'RGB').save("results/{}_out.png".format(i))
            PIL.Image.fromarray(np.uint8(t), 'RGB').save("results/ground_target.png")
            PIL.Image.fromarray(out[0][0].transpose((1, 2, 0)), 'RGB').save("results/ground_out.png")

            print(out[0][0].shape)
            print(out[0][0].transpose((2, 1, 0)).shape)
            print(out[0][0].transpose((1, 2, 0)).shape)

if __name__ == "__main__":
    main()
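
# Usage note (an assumption, not part of NVIDIA's original example script): run the
# generation step once to produce a sample, place the image you want to reconstruct
# at results/base.png, then rerun the script so the loop above fits a latent to it.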