import PIL
import gradio as gr
import argparse, os, sys, glob
import torch
import torch.nn as nn
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import accelerate
import mimetypes
import gc

mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

import k_diffusion as K
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model

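# Image preprocessing helpers for img2img: convert to RGB, snap the size down to a
# multiple of 64, and rescale pixel values from [0, 255] to the [-1, 1] range used
# by the first-stage autoencoder.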
def load_img_pil(img_pil):
    image = img_pil.convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    print(f"cropped image to size ({w}, {h})")
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.*image - 1.


def load_img(path):
    return load_img_pil(Image.open(path))

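# Classifier-free guidance wrapper around the k-diffusion denoiser: the batch is
# doubled (unconditional + conditional), both predictions are computed in one pass,
# and the result is blended as uncond + (cond - uncond) * cond_scale.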
class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond + (cond - uncond) * cond_scale


config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, "/gdrive/My Drive/model.ckpt")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.half().to(device)

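# txt2img entry point for the Gradio "Dream" tab. The argparse block below only holds
# fixed defaults (output dir, latent channels C, downsampling factor f, precision);
# everything that varies per request arrives through the function arguments.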
def dream(prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scales: str, seed: int, height: int, width: int, same_seed: bool = False):
    torch.cuda.empty_cache()
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--laion400m",
        action='store_true',
        help="uses the LAION400M model",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=height,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=width,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="/gdrive/My Drive/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    opt = parser.parse_args()

    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    accelerator = accelerate.Accelerator()
    device = accelerator.device
    rng_seed = seed_everything(seed)
    seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes])
    torch.manual_seed(seeds[accelerator.process_index].item())

    if plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    model_wrap = K.external.CompVisDenoiser(model)
    sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1
    seedit = 0

    start_code = None
    if fixed_code:
        start_code = torch.randn([n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    # parse the space-separated CFG scales into a list of floats (also handles a single value)
    cfg_scales = [float(x) for x in str(cfg_scales).split()]
    output_images = []
    with torch.no_grad():
        gc.collect()
        torch.cuda.empty_cache()
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                for n in trange(n_iter, desc="Sampling", disable=not accelerator.is_main_process):
                    for prompts in tqdm(data, desc="data", disable=not accelerator.is_main_process):
                        if n_iter > 1: seedit += 1
                        for cfg in tqdm(cfg_scales, desc="cfg_scales", disable=not accelerator.is_main_process):
                            cfg_scale = cfg
                            uc = None
                            if cfg_scale != 1.0:
                                uc = model.get_learned_conditioning(batch_size * [""])
                            if isinstance(prompts, tuple):
                                prompts = list(prompts)
                            c = model.get_learned_conditioning(prompts)
                            shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                            # samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                            #                                  conditioning=c,
                            #                                  batch_size=opt.n_samples,
                            #                                  shape=shape,
                            #                                  verbose=False,
                            #                                  unconditional_guidance_scale=opt.scale,
                            #                                  unconditional_conditioning=uc,
                            #                                  eta=opt.ddim_eta,
                            #                                  x_T=start_code)

                            sigmas = model_wrap.get_sigmas(ddim_steps)
                            torch.manual_seed(rng_seed + seedit)  # changes manual seeding procedure
                            # sigmas = K.sampling.get_sigmas_karras(opt.ddim_steps, sigma_min, sigma_max, device=device)
                            x = torch.randn([n_samples, *shape], device=device) * sigmas[0]  # for GPU draw
                            # x = torch.randn([opt.n_samples, *shape]).to(device) * sigmas[0]  # for CPU draw
                            model_wrap_cfg = CFGDenoiser(model_wrap)
                            extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process)
                            x_samples_ddim = model.decode_first_stage(samples_ddim)
                            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                            x_samples_ddim = accelerator.gather(x_samples_ddim)

                            if accelerator.is_main_process and not opt.skip_save:
                                for x_sample in x_samples_ddim:
                                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                    Image.fromarray(x_sample.astype(np.uint8)).save(
                                        os.path.join(sample_path, f"{base_count:05}-{rng_seed + seedit}-{cfg_scale}_{prompt.replace(' ', '_')[:128]}.png"))
                                    output_images.append(Image.fromarray(x_sample.astype(np.uint8)))
                                    base_count += 1
                                if not same_seed: seedit += 1

                            if accelerator.is_main_process and not opt.skip_grid:
                                all_samples.append(x_samples_ddim)

                if accelerator.is_main_process and not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                    grid_count += 1

                toc = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    del sampler
    return output_images, rng_seed

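# img2img entry point for the "Image Translation" tab: the input image is encoded to
# a latent, noised up to a level set by denoising_strength, then denoised with K-LMS
# along the truncated sigma schedule.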
def translation(prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scales: str, denoising_strength: float, seed: int, height: int, width: int, same_seed: bool = False):
    torch.cuda.empty_cache()
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/img2img-samples"
    )

    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )

    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor, most often 8 or 16",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--from-file",
        type=str,
        help="if specified, load prompts from this file",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/v1-inference.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="/gdrive/My Drive/model.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )

    opt = parser.parse_args()

    accelerator = accelerate.Accelerator()
    device = accelerator.device
    rng_seed = seed_everything(seed)
    seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes])
    torch.manual_seed(seeds[accelerator.process_index].item())

    sampler = DDIMSampler(model)

    model_wrap = K.external.CompVisDenoiser(model)
    sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1
    seedit = 0

    image = init_img.convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    w, h = map(lambda x: x - x % 32, (width, height))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    print(f"cropped image to size ({w}, {h})")
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    # parse the space-separated CFG scales into a list of floats (also handles a single value)
    cfg_scales = [float(x) for x in str(cfg_scales).split()]
    output_images = []
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad():
        gc.collect()
        torch.cuda.empty_cache()
        with precision_scope("cuda"):
            init_image = 2.*image - 1.
            init_image = init_image.to(device)
            init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
            init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space
            x0 = init_latent

            sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)

            assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
            t_enc = int(denoising_strength * ddim_steps)
            print(f"target t_enc is {t_enc} steps")
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                for n in trange(n_iter, desc="Sampling", disable=not accelerator.is_main_process):
                    for prompts in tqdm(data, desc="data", disable=not accelerator.is_main_process):
                        for cfg in tqdm(cfg_scales, desc="cfg_scales", disable=not accelerator.is_main_process):
                            cfg_scale = cfg
                            uc = None
                            if cfg_scale != 1.0:
                                uc = model.get_learned_conditioning(batch_size * [""])
                            if isinstance(prompts, tuple):
                                prompts = list(prompts)
                            c = model.get_learned_conditioning(prompts)

                            sigmas = model_wrap.get_sigmas(ddim_steps)
                            torch.manual_seed(rng_seed + seedit)  # changes manual seeding procedure
                            # sigmas = K.sampling.get_sigmas_karras(ddim_steps, sigma_min, sigma_max, device=device)
                            noise = torch.randn_like(x0) * sigmas[ddim_steps - t_enc - 1]  # for GPU draw
                            xi = x0 + noise
                            sigma_sched = sigmas[ddim_steps - t_enc - 1:]
                            # x = torch.randn([n_samples, *shape]).to(device) * sigmas[0]  # for CPU draw
                            model_wrap_cfg = CFGDenoiser(model_wrap)
                            extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
                            samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args=extra_args, disable=not accelerator.is_main_process)
                            x_samples_ddim = model.decode_first_stage(samples_ddim)
                            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                            x_samples_ddim = accelerator.gather(x_samples_ddim)

                            if accelerator.is_main_process and not opt.skip_save:
                                for x_sample in x_samples_ddim:
                                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                    Image.fromarray(x_sample.astype(np.uint8)).save(
                                        os.path.join(sample_path, f"{base_count:05}-{rng_seed + seedit}-{cfg_scale}_{prompt.replace(' ', '_')[:128]}.png"))
                                    output_images.append(Image.fromarray(x_sample.astype(np.uint8)))
                                    base_count += 1
                                if not same_seed: seedit += 1

                            if accelerator.is_main_process and not opt.skip_grid:
                                all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
                    grid_count += 1

                toc = time.time()
    gc.collect()
    torch.cuda.empty_cache()
    del sampler
    return output_images, rng_seed

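# Gradio front-end: a text-to-image tab backed by dream() and an image-to-image tab
# backed by translation(), combined into a TabbedInterface below.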
dream_interface = gr.Interface(
    dream,
    inputs=[
        gr.Textbox(placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
        gr.Slider(minimum=1, maximum=200, step=1, label="Diffusion steps; 100 is ideal.", value=50),
        gr.Checkbox(label='Enable PLMS', value=False),
        gr.Checkbox(label='Sample from a single fixed starting point', value=False),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
        gr.Slider(minimum=1, maximum=50, step=1, label='How many times to generate per prompt (sequentially)', value=1),
        gr.Slider(minimum=1, maximum=8, step=1, label='How many images at once (in parallel). USES A LOT OF MEMORY', value=1),
        gr.Textbox(placeholder="7.0", label='Classifier-Free Guidance scales, space-separated or a single value. If more than one is given, the same prompt is sampled with each cfg value. Must be a decimal number, e.g. 7.0 or 15.0', lines=1, value="9.0"),
        gr.Number(label='Seed', value=-1),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
        gr.Checkbox(label='Use the same seed every time, so that the same prompt keeps producing the same image.', value=False)
    ],
    outputs=[
        gr.Gallery(),
        gr.Number(label='Seed')
    ],
    title="Stable Diffusion 1.4 text-to-image",
    description="Create an image from text; K-LMS is used as the default sampler",
)

# prompt, init_img, ddim_steps, plms, ddim_eta, n_iter, n_samples, cfg_scale, denoising_strength, seed

img2img_interface = gr.Interface(
    translation,
    inputs=[
        gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
        gr.Image(value="https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg", source="upload", interactive=True, type="pil"),
        gr.Slider(minimum=1, maximum=200, step=1, label="Diffusion steps; 100 is ideal.", value=50),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
        gr.Slider(minimum=1, maximum=50, step=1, label='How many times to generate per prompt (sequentially)', value=1),
        gr.Slider(minimum=1, maximum=8, step=1, label='How many images at once (in parallel). USES A LOT OF MEMORY', value=1),
        gr.Textbox(placeholder="7.0", label='Classifier-Free Guidance scales, space-separated or a single value. If more than one is given, the same prompt is sampled with each cfg value. Must be a decimal number, e.g. 7.0 or 15.0', lines=1, value="7.0"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Fraction of the steps above to run over the image. Can be thought of as the "strength"', value=0.75),
        gr.Number(label='Seed', value=-1),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Resize Height", value=512),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Resize Width", value=512),
        gr.Checkbox(label='Use the same seed every time, so that the same prompt keeps producing the same image.', value=False)
    ],
    outputs=[
        gr.Gallery(),
        gr.Number(label='Seed')
    ],
    title="Stable Diffusion Image-to-Image",
    description="Generate an image from another image",
)

demo = gr.TabbedInterface(interface_list=[dream_interface, img2img_interface], tab_names=["Dream", "Image Translation"])

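# share=True asks Gradio for a temporary public link in addition to the local URL,
# which is convenient when the script is run from Colab.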
demo.launch(share=True)