Advertisement
Guest User

kikko StyleGAN Colab

a guest
Feb 20th, 2019
128
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.19 KB | None | 0 0
  1. import os
  2. import pickle
  3. import numpy as np
  4. import PIL.Image
  5. import dnnlib
  6. import dnnlib.tflib as tflib
  7. import config
  8. import scipy
  9.  
  10. def main():
  11.  
  12.     tflib.init_tf()
  13.  
  14.     # Load pre-trained network.
  15.     # url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
  16.     # with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
  17.     # _G, _D, Gs = pickle.load(open("results/00015-sgan-holo-2gpu/network-snapshot-011370.pkl", "rb"))
  18.     # _G, _D, Gs = pickle.load(open("results/00035-sgan-faces-2gpu/network-snapshot-009493.pkl", "rb"))
  19.     # _G, _D, Gs = pickle.load(open("results/02029-sgan-ffhq-2gpu/network-snapshot-011098.pkl", "rb"))
  20.     # _G, _D, Gs = pickle.load(open("results/02036-sgan-ffhqdanbooru-2gpu/network-snapshot-012101.pkl", "rb"))
  21.     # _G, _D, Gs = pickle.load(open("results/02037-sgan-danbooru2018-2gpu/network-snapshot-012572.pkl", "rb"))
  22.     _G, _D, Gs = pickle.load(open("results/02043-sgan-faces-2gpu/network-snapshot-011293.pkl", "rb"))
  23.     # _G, _D, Gs = pickle.load(open("results/02021-sgan-faces-2gpu/network-snapshot-010483.pkl", "rb"))
  24.     # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
  25.     # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
  26.     # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
  27.  
  28.     grid_size = [2,2]
  29.     image_shrink = 1
  30.     image_zoom = 1
  31.     duration_sec = 60.0
  32.     smoothing_sec = 1.0
  33.     mp4_fps = 20
  34.     mp4_codec = 'libx264'
  35.     mp4_bitrate = '5M'
  36.     random_seed = 404
  37.     mp4_file = 'results/random_grid_%s.mp4' % random_seed
  38.     minibatch_size = 8
  39.  
  40.     num_frames = int(np.rint(duration_sec * mp4_fps))
  41.     random_state = np.random.RandomState(None)
  42.  
  43.     # Generate latent vectors
  44.     shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:] # [frame, image, channel, component]
  45.     all_latents = random_state.randn(*shape).astype(np.float32)
  46.     import scipy
  47.     all_latents = scipy.ndimage.gaussian_filter(all_latents, [smoothing_sec * mp4_fps] + [0] * len(Gs.input_shape), mode='wrap')
  48.     all_latents /= np.sqrt(np.mean(np.square(all_latents)))
  49.  
  50.  
  51.     def create_image_grid(images, grid_size=None):
  52.         assert images.ndim == 3 or images.ndim == 4
  53.         num, img_h, img_w, channels = images.shape
  54.  
  55.         if grid_size is not None:
  56.             grid_w, grid_h = tuple(grid_size)
  57.         else:
  58.             grid_w = max(int(np.ceil(np.sqrt(num))), 1)
  59.             grid_h = max((num - 1) // grid_w + 1, 1)
  60.  
  61.         grid = np.zeros([grid_h * img_h, grid_w * img_w, channels], dtype=images.dtype)
  62.         for idx in range(num):
  63.             x = (idx % grid_w) * img_w
  64.             y = (idx // grid_w) * img_h
  65.             grid[y : y + img_h, x : x + img_w] = images[idx]
  66.         return grid
  67.  
  68.     # Frame generation func for moviepy.
  69.     def make_frame(t):
  70.         frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
  71.         latents = all_latents[frame_idx]
  72.         fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
  73.         images = Gs.run(latents, None, truncation_psi=0.7,
  74.                               randomize_noise=False, output_transform=fmt)
  75.  
  76.         grid = create_image_grid(images, grid_size)
  77.         if image_zoom > 1:
  78.             grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1], order=0)
  79.         if grid.shape[2] == 1:
  80.             grid = grid.repeat(3, 2) # grayscale => RGB
  81.         return grid
  82.  
  83.     # Generate video.
  84.     import moviepy.editor
  85.     video_clip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
  86.     video_clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)
  87.  
  88.     # import scipy
  89.     # coarse
  90.     duration_sec = 60.0
  91.     smoothing_sec = 1.0
  92.     mp4_fps = 20
  93.  
  94.     num_frames = int(np.rint(duration_sec * mp4_fps))
  95.     random_seed = 500
  96.     random_state = np.random.RandomState(random_seed)
  97.  
  98.  
  99.     w = 512
  100.     h = 512
  101.     #src_seeds = [601]
  102.     dst_seeds = [700]
  103.     style_ranges = ([0] * 7 + [range(8,16)]) * len(dst_seeds)
  104.  
  105.     fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
  106.     synthesis_kwargs = dict(output_transform=fmt, truncation_psi=0.7, minibatch_size=8)
  107.  
  108.     shape = [num_frames] + Gs.input_shape[1:] # [frame, image, channel, component]
  109.     src_latents = random_state.randn(*shape).astype(np.float32)
  110.     src_latents = scipy.ndimage.gaussian_filter(src_latents,
  111.                                                 smoothing_sec * mp4_fps,
  112.                                                 mode='wrap')
  113.     src_latents /= np.sqrt(np.mean(np.square(src_latents)))
  114.  
  115.     dst_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in dst_seeds)
  116.  
  117.  
  118.     src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
  119.     dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]
  120.     src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
  121.     dst_images = Gs.components.synthesis.run(dst_dlatents, randomize_noise=False, **synthesis_kwargs)
  122.  
  123.  
  124.     canvas = PIL.Image.new('RGB', (w * (len(dst_seeds) + 1), h * 2), 'white')
  125.  
  126.     for col, dst_image in enumerate(list(dst_images)):
  127.         canvas.paste(PIL.Image.fromarray(dst_image, 'RGB'), ((col + 1) * h, 0))
  128.  
  129.     def make_frame(t):
  130.         frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
  131.         src_image = src_images[frame_idx]
  132.         canvas.paste(PIL.Image.fromarray(src_image, 'RGB'), (0, h))
  133.  
  134.         for col, dst_image in enumerate(list(dst_images)):
  135.             col_dlatents = np.stack([dst_dlatents[col]])
  136.             col_dlatents[:, style_ranges[col]] = src_dlatents[frame_idx, style_ranges[col]]
  137.             col_images = Gs.components.synthesis.run(col_dlatents, randomize_noise=False, **synthesis_kwargs)
  138.             for row, image in enumerate(list(col_images)):
  139.                 canvas.paste(PIL.Image.fromarray(image, 'RGB'), ((col + 1) * h, (row + 1) * w))
  140.         return np.array(canvas)
  141.  
  142.     # Generate video.
  143.     import moviepy.editor
  144.     mp4_file = 'results/interpolate.mp4'
  145.     mp4_codec = 'libx264'
  146.     mp4_bitrate = '5M'
  147.  
  148.     video_clip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
  149.     video_clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)
  150.  
  151.     import scipy
  152.  
  153.     duration_sec = 60.0
  154.     smoothing_sec = 1.0
  155.     mp4_fps = 20
  156.  
  157.     num_frames = int(np.rint(duration_sec * mp4_fps))
  158.     random_seed = 503
  159.     random_state = np.random.RandomState(random_seed)
  160.  
  161.  
  162.     w = 512
  163.     h = 512
  164.     style_ranges = [range(6,16)]
  165.  
  166.     fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
  167.     synthesis_kwargs = dict(output_transform=fmt, truncation_psi=0.7, minibatch_size=8)
  168.  
  169.     shape = [num_frames] + Gs.input_shape[1:] # [frame, image, channel, component]
  170.     src_latents = random_state.randn(*shape).astype(np.float32)
  171.     src_latents = scipy.ndimage.gaussian_filter(src_latents,
  172.                                                 smoothing_sec * mp4_fps,
  173.                                                 mode='wrap')
  174.     src_latents /= np.sqrt(np.mean(np.square(src_latents)))
  175.  
  176.     dst_latents = np.stack([random_state.randn(Gs.input_shape[1])])
  177.  
  178.  
  179.     src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
  180.     dst_dlatents = Gs.components.mapping.run(dst_latents, None) # [seed, layer, component]
  181.  
  182.  
  183.     def make_frame(t):
  184.         frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
  185.         col_dlatents = np.stack([dst_dlatents[0]])
  186.         col_dlatents[:, style_ranges[0]] = src_dlatents[frame_idx, style_ranges[0]]
  187.         col_images = Gs.components.synthesis.run(col_dlatents, randomize_noise=False, **synthesis_kwargs)
  188.         return col_images[0]
  189.  
  190.     # Generate video.
  191.     import moviepy.editor
  192.     mp4_file = 'results/fine_%s.mp4' % (random_seed)
  193.     mp4_codec = 'libx264'
  194.     mp4_bitrate = '5M'
  195.  
  196.     video_clip = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
  197.     video_clip.write_videofile(mp4_file, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)
  198.  
  199. if __name__ == "__main__":
  200.     main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement