kdiff.py
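# Single-file Gradio web UI for Stable Diffusion: a "Dream" (txt2img) tab and an
# "Image Translation" (img2img) tab, both sampled with k-diffusion's LMS sampler.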
# I hate command line so much it's unreal
# hacked together k_diffusion courtesy of /vg/ /g/

import PIL
import gradio as gr
import argparse, os, sys, glob
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import accelerate
import mimetypes

mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

import k_diffusion as K
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


def load_img_pil(img_pil):
    image = img_pil.convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    print(f"cropped image to size ({w}, {h})")
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.*image - 1.


def load_img(path):
    return load_img_pil(Image.open(path))

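# Classifier-free guidance wrapper: runs the unconditional and conditional prompts
# through the denoiser in a single batch and blends the two predictions as
# uncond + (cond - uncond) * cond_scale.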
class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond + (cond - uncond) * cond_scale

config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.half().to(device)

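# txt2img callback for the "Dream" tab; each Gradio control maps to one parameter.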
def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
    torch.cuda.empty_cache()
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--laion400m",
        action='store_true',
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=height,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=width,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/stable-diffusion-v1/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    opt = parser.parse_args()

    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    accelerator = accelerate.Accelerator()
    device = accelerator.device
    rng_seed = seed_everything(seed)
    seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes])
    torch.manual_seed(seeds[accelerator.process_index].item())

    if plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    model_wrap = K.external.CompVisDenoiser(model)
    sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1
    seedit = 0

    start_code = None
    if fixed_code:
        start_code = torch.randn([n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    output_images = []
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                for n in trange(n_iter, desc="Sampling", disable=not accelerator.is_main_process):
                    for prompts in tqdm(data, desc="data", disable=not accelerator.is_main_process):
                        uc = None
                        if cfg_scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                        # samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                        #                                  conditioning=c,
                        #                                  batch_size=opt.n_samples,
                        #                                  shape=shape,
                        #                                  verbose=False,
                        #                                  unconditional_guidance_scale=opt.scale,
                        #                                  unconditional_conditioning=uc,
                        #                                  eta=opt.ddim_eta,
                        #                                  x_T=start_code)

                        sigmas = model_wrap.get_sigmas(ddim_steps)
                        torch.manual_seed(rng_seed + seedit)  # changes manual seeding procedure
                        # sigmas = K.sampling.get_sigmas_karras(opt.ddim_steps, sigma_min, sigma_max, device=device)
                        x = torch.randn([n_samples, *shape], device=device) * sigmas[0]  # for GPU draw
                        # x = torch.randn([opt.n_samples, *shape]).to(device) * sigmas[0]  # for CPU draw
                        model_wrap_cfg = CFGDenoiser(model_wrap)
                        extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
                        samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_samples_ddim = accelerator.gather(x_samples_ddim)

                        if accelerator.is_main_process and not opt.skip_save:
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(sample_path, f"{base_count:05}-{rng_seed + seedit}_{prompt.replace(' ', '_')[:128]}.png"))
                                output_images.append(Image.fromarray(x_sample.astype(np.uint8)))
                                base_count += 1
                                seedit += 1

                        if accelerator.is_main_process and not opt.skip_grid:
                            all_samples.append(x_samples_ddim)

                if accelerator.is_main_process and not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                    grid_count += 1

                toc = time.time()

    del sampler
    return output_images, rng_seed

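# img2img callback for the "Image Translation" tab: the init image is encoded to latents,
# noised to a depth set by Denoising Strength, then denoised with the LMS sampler.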
def translation(prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int):
    torch.cuda.empty_cache()
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/img2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor, most often 8 or 16",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/ldm/stable-diffusion-v1/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )

    opt = parser.parse_args()

    accelerator = accelerate.Accelerator()
    device = accelerator.device
    rng_seed = seed_everything(seed)
    seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes])
    torch.manual_seed(seeds[accelerator.process_index].item())

    sampler = DDIMSampler(model)

    model_wrap = K.external.CompVisDenoiser(model)
    sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1
    seedit = 0

    image = init_img.convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    w, h = map(lambda x: x - x % 32, (width, height))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    print(f"cropped image to size ({w}, {h})")
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)

    output_images = []
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            init_image = 2.*image - 1.
            init_image = init_image.to(device)
            init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
            init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space
            x0 = init_latent

            sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)

            assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
            t_enc = int(denoising_strength * ddim_steps)
            print(f"target t_enc is {t_enc} steps")
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                for n in trange(n_iter, desc="Sampling", disable=not accelerator.is_main_process):
                    for prompts in tqdm(data, desc="data", disable=not accelerator.is_main_process):
                        uc = None
                        if cfg_scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)

                        sigmas = model_wrap.get_sigmas(ddim_steps)
                        torch.manual_seed(rng_seed + seedit)  # changes manual seeding procedure
                        # sigmas = K.sampling.get_sigmas_karras(ddim_steps, sigma_min, sigma_max, device=device)
                        noise = torch.randn_like(x0) * sigmas[ddim_steps - t_enc - 1]  # for GPU draw
                        xi = x0 + noise
                        sigma_sched = sigmas[ddim_steps - t_enc - 1:]
                        # x = torch.randn([n_samples, *shape]).to(device) * sigmas[0]  # for CPU draw
                        model_wrap_cfg = CFGDenoiser(model_wrap)
                        extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
                        samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args=extra_args, disable=not accelerator.is_main_process)
                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_samples_ddim = accelerator.gather(x_samples_ddim)

                        if accelerator.is_main_process and not opt.skip_save:
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                Image.fromarray(x_sample.astype(np.uint8)).save(
                                    os.path.join(sample_path, f"{base_count:05}-{rng_seed + seedit}_{prompt.replace(' ', '_')[:128]}.png"))
                                output_images.append(Image.fromarray(x_sample.astype(np.uint8)))
                                base_count += 1
                                seedit += 1

                        if accelerator.is_main_process and not opt.skip_grid:
                            all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                    grid_count += 1

                toc = time.time()

    del sampler
    return output_images, rng_seed

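# Gradio wiring: one Interface per task, combined into a two-tab app. The input list
# order must match the parameter order of dream() and translation() above.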
dream_interface = gr.Interface(
    dream,
    inputs=[
        gr.Textbox(placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        gr.Checkbox(label='Enable PLMS sampling', value=False),
        gr.Checkbox(label='Enable Fixed Code sampling', value=False),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
        gr.Slider(minimum=1, maximum=50, step=1, label='Sampling iterations', value=2),
        gr.Slider(minimum=1, maximum=8, step=1, label='Samples per iteration', value=1),
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
        gr.Number(label='Seed', value=-1),
        gr.Slider(minimum=32, maximum=2048, step=32, label="Height", value=512),
        gr.Slider(minimum=32, maximum=2048, step=32, label="Width", value=512),
    ],
    outputs=[
        gr.Gallery(),
        gr.Number(label='Seed')
    ],
    title="Stable Diffusion Text-to-Image K",
    description="Generate images from text with Stable Diffusion (using K-LMS)",
)

# prompt, init_img, ddim_steps, plms, ddim_eta, n_iter, n_samples, cfg_scale, denoising_strength, seed

img2img_interface = gr.Interface(
    translation,
    inputs=[
        gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
        gr.Image(value="https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg", source="upload", interactive=True, type="pil"),
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
        gr.Slider(minimum=1, maximum=50, step=1, label='Sampling iterations', value=2),
        gr.Slider(minimum=1, maximum=8, step=1, label='Samples per iteration', value=2),
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
        gr.Number(label='Seed', value=-1),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Resize Height", value=512),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Resize Width", value=512),
    ],
    outputs=[
        gr.Gallery(),
        gr.Number(label='Seed')
    ],
    title="Stable Diffusion Image-to-Image",
    description="Generate images from images with Stable Diffusion",
)

demo = gr.TabbedInterface(interface_list=[dream_interface, img2img_interface], tab_names=["Dream", "Image Translation"])

demo.launch()
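# demo.launch() serves the UI locally; with Gradio's defaults that is http://127.0.0.1:7860.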