Vageyser

Modified output path and filenames in optimized_img2img.py to match the original img2img.py

Aug 26th, 2022
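Usage (a minimal sketch, not part of the paste itself: it assumes the script is saved as optimizedSD/optimized_img2img.py inside the optimizedSD fork, with the checkpoint at the models/ldm/stable-diffusion-v1/model.ckpt path hardcoded below; all flags shown are defined by the script's argparse section):

    python optimizedSD/optimized_img2img.py --prompt "a fantasy landscape, trending on artstation" --init-img inputs/sketch.jpg --strength 0.75 --ddim_steps 50 --n_samples 5

Outputs land in outputs/img2img-samples/samples as sequentially numbered PNGs, matching the original img2img.py layout.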
import argparse, os
import time
import torch
import numpy as np
from random import randint
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from ldm.util import instantiate_from_config
from split_subprompts import split_weighted_subprompts
from transformers import logging
logging.set_verbosity_error()

def chunk(it, size):
    # split an iterable into successive tuples of length `size` (the last may be shorter)
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
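# For example (illustrative):
#   list(chunk(["a", "b", "c", "d", "e"], 2)) -> [("a", "b"), ("c", "d"), ("e",)]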

def load_model_from_config(ckpt, verbose=False):
    # load the checkpoint on the CPU and return its raw state_dict
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    return sd

def load_img(path, h0, w0):
    # load an RGB image, optionally override its size, and normalize it to [-1, 1]
    image = Image.open(path).convert("RGB")
    w, h = image.size

    print(f"loaded input image of size ({w}, {h}) from {path}")
    if h0 is not None and w0 is not None:
        h, w = h0, w0

    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64

    print(f"New image size ({w}, {h})")
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW
    image = torch.from_numpy(image)
    return 2. * image - 1.

config = "optimizedSD/v1-inference.yaml"
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"

parser = argparse.ArgumentParser()

parser.add_argument(
    "--prompt",
    type=str,
    nargs="?",
    default="a painting of a virus monster playing guitar",
    help="the prompt to render"
)
parser.add_argument(
    "--outdir",
    type=str,
    nargs="?",
    help="dir to write results to",
    default="outputs/img2img-samples"
)
parser.add_argument(
    "--init-img",
    type=str,
    nargs="?",
    help="path to the input image"
)
parser.add_argument(
    "--skip_grid",
    action='store_true',
    help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
    "--skip_save",
    action='store_true',
    help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
    "--ddim_steps",
    type=int,
    default=50,
    help="number of ddim sampling steps",
)
parser.add_argument(
    "--ddim_eta",
    type=float,
    default=0.0,
    help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
)
parser.add_argument(
    "--n_iter",
    type=int,
    default=1,
    help="number of times to repeat sampling",
)
parser.add_argument(
    "--H",
    type=int,
    default=None,
    help="image height, in pixel space",
)
parser.add_argument(
    "--W",
    type=int,
    default=None,
    help="image width, in pixel space",
)
parser.add_argument(
    "--strength",
    type=float,
    default=0.75,
    help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
)
parser.add_argument(
    "--C",
    type=int,
    default=4,
    help="latent channels",
)
parser.add_argument(
    "--f",
    type=int,
    default=8,
    help="downsampling factor",
)
parser.add_argument(
    "--n_samples",
    type=int,
    default=5,
    help="how many samples to produce for each given prompt. A.k.a. batch size",
)
parser.add_argument(
    "--n_rows",
    type=int,
    default=0,
    help="rows in the grid (default: n_samples)",
)
parser.add_argument(
    "--scale",
    type=float,
    default=7.5,
    help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
    "--from-file",
    type=str,
    help="if specified, load prompts from this file",
)
parser.add_argument(
    "--seed",
    type=int,
    default=None,
    help="the seed (for reproducible sampling)",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="device to run on (cpu/cuda/cuda:0/cuda:1/...)",
)
parser.add_argument(
    "--small_batch",
    action='store_true',
    help="reduce inference time when generating a smaller batch of images",
)
parser.add_argument(
    "--precision",
    type=str,
    help="evaluate at this precision",
    choices=["full", "autocast"],
    default="autocast"
)
opt = parser.parse_args()

tic = time.time()
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir

sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1

if opt.seed is None:
    opt.seed = randint(0, 1000000)
seed_everything(opt.seed)

sd = load_model_from_config(f"{ckpt}")
li = []
lo = []
for key, value in sd.items():
    sp = key.split('.')
    if sp[0] == 'model':
        if 'input_blocks' in sp or 'middle_block' in sp or 'time_embed' in sp:
            li.append(key)
        else:
            lo.append(key)
for key in li:
    sd['model1.' + key[6:]] = sd.pop(key)
for key in lo:
    sd['model2.' + key[6:]] = sd.pop(key)

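# note (an assumption about the optimizedSD fork's design): the 'model.' keys are
# split between 'model1.' (input_blocks, middle_block, time_embed) and 'model2.'
# (the remaining output-side weights) because the fork divides the UNet into two
# sub-models that can be loaded and run one at a time, keeping peak VRAM low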
config = OmegaConf.load(f"{config}")

config.modelUNet.params.small_batch = opt.small_batch
config.modelCondStage.params.cond_stage_config.params.device = opt.device

assert os.path.isfile(opt.init_img), f"cannot find init image at {opt.init_img}"
init_image = load_img(opt.init_img, opt.H, opt.W).to(opt.device)

# three separate wrappers: the UNet, the conditioning (text) stage, and the
# first (autoencoder) stage, each loaded from the same remapped state_dict
model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()
model.cdevice = opt.device

modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()

modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
del sd

if opt.device != 'cpu' and opt.precision == "autocast":
    model.half()
    modelCS.half()
    modelFS.half()
    init_image = init_image.half()

batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
    prompt = opt.prompt
    assert prompt is not None
    data = [batch_size * [prompt]]

else:
    print(f"reading prompts from {opt.from_file}")
    with open(opt.from_file, "r") as f:
        data = f.read().splitlines()
        # replicate the prompt list batch_size times, then split into batch-sized chunks
        data = batch_size * list(data)
        data = list(chunk(data, batch_size))

modelFS.to(opt.device)

init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image))  # move to latent space
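# with the default latent channels C=4 and downsampling factor f=8, init_latent
# has shape (batch_size, 4, H/8, W/8)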

if opt.device != 'cpu':
    # wait until the first-stage weights have actually left GPU memory before proceeding
    mem = torch.cuda.memory_allocated() / 1e6
    modelFS.to("cpu")
    while torch.cuda.memory_allocated() / 1e6 >= mem:
        time.sleep(1)


assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(opt.strength * opt.ddim_steps)
print(f"target t_enc is {t_enc} steps")
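# e.g. the defaults strength=0.75 and ddim_steps=50 give t_enc = int(0.75 * 50) = 37,
# so the init image is noised for 37 of the 50 steps and then denoised back under the prompt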

if opt.precision == "autocast" and opt.device != "cpu":
    precision_scope = autocast
else:
    precision_scope = nullcontext

with torch.no_grad():

    all_samples = list()
    for n in trange(opt.n_iter, desc="Sampling"):
        for prompts in tqdm(data, desc="data"):
            with precision_scope("cuda"):
                modelCS.to(opt.device)
                uc = None
                if opt.scale != 1.0:
                    uc = modelCS.get_learned_conditioning(batch_size * [""])
                if isinstance(prompts, tuple):
                    prompts = list(prompts)

                subprompts, weights = split_weighted_subprompts(prompts[0])
                if len(subprompts) > 1:
                    c = torch.zeros_like(uc)
                    totalWeight = sum(weights)
                    # normalize each "sub prompt" and add its conditioning, weighted
                    for i in range(len(subprompts)):
                        weight = weights[i] / totalWeight
                        c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
                else:
                    c = modelCS.get_learned_conditioning(prompts)

                if opt.device != 'cpu':
                    # wait for the conditioning-stage weights to leave GPU memory
                    mem = torch.cuda.memory_allocated() / 1e6
                    modelCS.to("cpu")
                    while torch.cuda.memory_allocated() / 1e6 >= mem:
                        time.sleep(1)

                # encode (scaled latent)
                z_enc = model.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(opt.device),
                                                opt.seed, opt.ddim_eta, opt.ddim_steps)
                # decode it
                samples_ddim = model.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
                                            unconditional_conditioning=uc,)

                modelFS.to(opt.device)
                print("saving images")
                for i in range(batch_size):

                    x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
                    x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c')
                    Image.fromarray(x_sample.astype(np.uint8)).save(
                        os.path.join(sample_path, f"{base_count:05}.png"))
                    opt.seed += 1
                    base_count += 1

                if opt.device != 'cpu':
                    # wait for the first-stage weights to leave GPU memory
                    mem = torch.cuda.memory_allocated() / 1e6
                    modelFS.to("cpu")
                    while torch.cuda.memory_allocated() / 1e6 >= mem:
                        time.sleep(1)

                del samples_ddim
                if opt.device != 'cpu':
                    print("memory_final = ", torch.cuda.memory_allocated() / 1e6)

toc = time.time()

time_taken = (toc - tic) / 60.0

print(f"Your samples are ready in {time_taken:.2f} minutes and waiting for you here \n{sample_path}")