hack of stable diffusion
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)

# Derived from source code carrying the following copyrights
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
# hacked by h4krm

import torch
import numpy as np
import random
import os
import time
import re
import sys
import traceback
import transformers
import copy
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from torch import nn
from pytorch_lightning import seed_everything

from ldm.util                      import instantiate_from_config
from ldm.models.diffusion.ddim     import DDIMSampler
from ldm.models.diffusion.plms     import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter           import PngWriter
from ldm.dream.image_util          import InitImageResizer
from ldm.dream.devices             import choose_torch_device
from ldm.dream.conditioning        import get_uc_and_c

g_cpu = torch.Generator()
g_cpu.manual_seed(2)

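# NOTE: g_cpu is currently referenced only by the commented-out permutation
# experiments inside prompt2image(); the active mutation path draws its noise
# from torch.randn_like() directly.
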
  35. """Simplified text to image API for stable diffusion/latent diffusion
  36.  
  37. Example Usage:
  38.  
  39. from ldm.generate import Generate
  40.  
  41. # Create an object with default values
  42. gr = Generate()
  43.  
  44. # do the slow model initialization
  45. gr.load_model()
  46.  
  47. # Do the fast inference & image generation. Any options passed here
  48. # override the default values assigned during class initialization
  49. # Will call load_model() if the model was not previously loaded and so
  50. # may be slow at first.
  51. # The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
  52. results = gr.prompt2png(prompt     = "an astronaut riding a horse",
  53.                         outdir     = "./outputs/samples",
  54.                         iterations = 3)
  55.  
  56. for row in results:
  57.    print(f'filename={row[0]}')
  58.    print(f'seed    ={row[1]}')
  59.  
  60. # Same thing, but using an initial image.
  61. results = gr.prompt2png(prompt   = "an astronaut riding a horse",
  62.                         outdir   = "./outputs/,
  63.                         iterations = 3,
  64.                         init_img = "./sketches/horse+rider.png")
  65.  
  66. for row in results:
  67.    print(f'filename={row[0]}')
  68.    print(f'seed    ={row[1]}')
  69.  
  70. # Same thing, but we return a series of Image objects, which lets you manipulate them,
  71. # combine them, and save them under arbitrary names
  72.  
  73. results = gr.prompt2image(prompt   = "an astronaut riding a horse"
  74.                           outdir   = "./outputs/")
  75. for row in results:
  76.    im   = row[0]
  77.    seed = row[1]
  78.    im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png')
  79.    im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg')
  80.  
  81. Note that the old txt2img() and img2img() calls are deprecated but will
  82. still work.
  83.  
  84. The full list of arguments to Generate() are:
  85. gr = Generate(
  86.          weights     = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
  87.          config     = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
  88.          iterations  = <integer>     // how many times to run the sampling (1)
  89.          steps       = <integer>     // 50
  90.          seed        = <integer>     // current system time
  91.          sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
  92.          grid        = <boolean>     // false
  93.          width       = <integer>     // image width, multiple of 64 (512)
  94.          height      = <integer>     // image height, multiple of 64 (512)
  95.          cfg_scale   = <float>       // condition-free guidance scale (7.5)
  96.          )
  97.  
  98. """


class Generate:
    """Generate class
    Stores default values for multiple configuration items
    """

    def __init__(
            self,
            iterations            = 1,
            steps                 = 50,
            cfg_scale             = 7.5,
            weights               = 'models/ldm/stable-diffusion-v1/model.ckpt',
            config                = 'configs/stable-diffusion/v1-inference.yaml',
            grid                  = False,
            width                 = 512,
            height                = 512,
            sampler_name          = 'k_lms',
            ddim_eta              = 0.0,  # deterministic
            precision             = 'autocast',
            full_precision        = True,
            strength              = 0.75,  # default in scripts/img2img.py
            seamless              = False,
            embedding_path        = None,
            device_type           = 'cuda',
            ignore_ctrl_c         = False,
    ):
        self.iterations               = iterations
        self.width                    = width
        self.height                   = height
        self.steps                    = steps
        self.cfg_scale                = cfg_scale
        self.weights                  = weights
        self.config                   = config
        self.sampler_name             = sampler_name
        self.grid                     = grid
        self.ddim_eta                 = ddim_eta
        self.precision                = precision
        self.full_precision           = True if choose_torch_device() == 'mps' else full_precision
        self.strength                 = strength
        self.seamless                 = seamless
        self.embedding_path           = embedding_path
        self.device_type              = device_type
        self.ignore_ctrl_c            = ignore_ctrl_c    # note, this logic probably doesn't belong here...
        self.model                    = None     # empty for now
        self.sampler                  = None
        self.device                   = None
        self.generators               = {}
        self.base_generator           = None
        self.seed                     = None

        if device_type == 'cuda' and not torch.cuda.is_available():
            device_type = choose_torch_device()
            print('>> cuda not available, using device', device_type)
        self.device = torch.device(device_type)

        # for VRAM usage statistics
        device_type          = choose_torch_device()
        self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
        transformers.logging.set_verbosity_error()

    def prompt2png(self, prompt, outdir, **kwargs):
        """
        Takes a prompt and an output directory, writes out the requested number
        of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
        Optional named arguments are the same as those passed to Generate and prompt2image()
        """
        results = self.prompt2image(prompt, **kwargs)
        pngwriter = PngWriter(outdir)
        prefix = pngwriter.unique_prefix()
        outputs = []
        for image, seed in results:
            name = f'{prefix}.{seed}.png'
            path = pngwriter.save_image_and_prompt_to_png(
                image, f'{prompt} -S{seed}', name)
            outputs.append([path, seed])
        return outputs

    def txt2img(self, prompt, **kwargs):
        outdir = kwargs.pop('outdir', 'outputs/img-samples')
        return self.prompt2png(prompt, outdir, **kwargs)

    def img2img(self, prompt, **kwargs):
        outdir = kwargs.pop('outdir', 'outputs/img-samples')
        assert (
            'init_img' in kwargs
        ), 'call to img2img() must include the init_img argument'
        return self.prompt2png(prompt, outdir, **kwargs)

    def prompt2image(
            self,
            # these are common
            prompt,
            iterations     =    None,
            steps          =    None,
            seed           =    None,
            cfg_scale      =    None,
            ddim_eta       =    None,
            skip_normalize =    False,
            image_callback =    None,
            step_callback  =    None,
            width          =    None,
            height         =    None,
            sampler_name   =    None,
            seamless       =    False,
            log_tokenization = False,
            with_variations  = None,
            variation_amount = 0.0,
            # these are specific to img2img and inpaint
            init_img       =    None,
            init_mask      =    None,
            fit            =    False,
            strength       =    None,
            # these are specific to GFPGAN/ESRGAN
            gfpgan_strength=    0,
            save_original  =    False,
            upscale        =    None,
            # these are specific to the weight-mutation hack
            mutation       =    0,
            mutations      =    None,  # None instead of a mutable [] default
            **args,
    ):   # eat up additional cruft
  221.         """
  222.        ldm.generate.prompt2image() is the common entry point for txt2img() and img2img()
  223.        It takes the following arguments:
  224.           prompt                          // prompt string (no default)
  225.           iterations                      // iterations (1); image count=iterations
  226.           steps                           // refinement steps per iteration
  227.           seed                            // seed for random number generator
  228.           width                           // width of image, in multiples of 64 (512)
  229.           height                          // height of image, in multiples of 64 (512)
  230.           cfg_scale                       // how strongly the prompt influences the image (7.5) (must be >1)
  231.           seamless                        // whether the generated image should tile
  232.           init_img                        // path to an initial image
  233.           strength                        // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
  234.           gfpgan_strength                 // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
  235.           ddim_eta                        // image randomness (eta=0.0 means the same seed always produces the same image)
  236.           step_callback                   // a function or method that will be called each step
  237.           image_callback                  // a function or method that will be called each time an image is generated
  238.           with_variations                 // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation
  239.           variation_amount                // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)
  240.  
  241.        To use the step callback, define a function that receives two arguments:
  242.        - Image GPU data
  243.        - The step number
  244.  
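        For example, a minimal step callback that just logs progress (the
        argument names here are illustrative):

            def process_step(gpu_samples, step):
                print(f'completed step {step}')
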
        To use the image callback, define a function or method that receives two arguments, an Image object
        and the seed. You can then do whatever you like with the image, including converting it to
        different formats and manipulating it. For example:

            def process_image(image,seed):
                image.save(f'images/{seed}.png')

        The callback used by prompt2png() can be found in ldm/dream_util.py. It contains code
        to create the requested output directory, select a unique informative name for each image, and
        write the prompt into the PNG metadata.
        """
        # TODO: convert this into a getattr() loop
        steps                 = steps      or self.steps
        width                 = width      or self.width
        height                = height     or self.height
        seamless              = seamless   or self.seamless
        cfg_scale             = cfg_scale  or self.cfg_scale
        ddim_eta              = ddim_eta   or self.ddim_eta
        iterations            = iterations or self.iterations
        strength              = strength   or self.strength
        self.seed             = seed
        self.log_tokenization = log_tokenization
        with_variations = [] if with_variations is None else with_variations
        mutations       = [] if mutations is None else mutations

        model = (
            self.load_model()
        )  # will instantiate the model or return it from cache

        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                m.padding_mode = 'circular' if seamless else m._orig_padding_mode

        assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
        assert (
            0.0 < strength < 1.0
        ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
        assert (
            0.0 <= variation_amount <= 1.0
        ), '-v --variation_amount must be in [0.0, 1.0]'
        modifications = []
        eparams = []
        eparam_names = []
        for name, param in model.named_parameters():
            eparam_names.append(name)
            eparams.append(param)

        print(f'>> model has {len(eparams)} named parameter tensors')

        backups = []

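        # Core of the hack: back up each selected parameter tensor, then add
        # Gaussian noise scaled by `mutation`; the scale grows by 0.001 per
        # mutated tensor. The commented-out lines below are earlier
        # experiments (permutation / uniform scaling) left by the author.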
        for mm in mutations:
            #x = mm #random.randrange(len(eparams))
            #pdb.set_trace()
            # take a detached host-side copy; .copy() matters on CPU devices,
            # where .cpu() is a no-op and .numpy() would otherwise alias the
            # very tensor we are about to modify in place
            vect = eparams[mm].data.detach().cpu().numpy().copy()
            eparams[mm].data += (mutation * torch.randn_like(eparams[mm]))
            mutation = mutation + 0.001
            #idx = torch.randperm(t.shape[0]) #,generator=g_cpu)
            #t = t[idx]
            #eparams[x] = t
            #t[idx].view(t.size())
            #w_matrix = weights_dict[key].cpu().detach().numpy()
            #layer_weights_shape = w_matrix.shape
            #layer_weights_size = w_matrix.size
            #net_weights = copy.deepcopy(vect)
            #net_weights = numpy.array(net_weights) * numpy.random.uniform(low=-1.0*mutation, high=mutation)
            #newdata = numpy.random.permutation(vect)
            #eparams[x] = torch.from_numpy(newdata)

            #backups.append([x, net_weights])
            backups.append([mm, vect])
            #print(f"modified {x} at {mutation}")
            modifications.append([mm,
                                  mutation,
                                  eparam_names[mm],
                                  str(eparams[mm].size()),
                                  #eparams[x].cpu().numpy().tolist()
                                  len(eparams)])

        # trigger the variations path either when explicit variations were
        # passed or when any variation_amount was requested (the original
        # `> 1.0` comparison was unreachable, since variation_amount is
        # asserted to be <= 1.0 above)
        if len(with_variations) > 0 or variation_amount > 0.0:
            assert seed is not None,\
                'seed must be specified when using with_variations'
            if variation_amount == 0.0:
                assert iterations == 1,\
                    'when using --with_variations, multiple iterations are only possible when using --variation_amount'
            assert all(0 <= weight <= 1 for _, weight in with_variations),\
                f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'

        width, height, _ = self._resolution_check(width, height, log=True)

        if sampler_name and (sampler_name != self.sampler_name):
            self.sampler_name = sampler_name
            self._set_sampler()

        tic = time.time()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        results          = list()
        init_image       = None
        mask_image       = None

        try:
            uc, c = get_uc_and_c(
                prompt, model=self.model,
                skip_normalize=skip_normalize,
                log_tokens=self.log_tokenization
            )

            (init_image, mask_image) = self._make_images(init_img, init_mask, width, height, fit)

            if (init_image is not None) and (mask_image is not None):
                generator = self._make_inpaint()
            elif init_image is not None:
                generator = self._make_img2img()
            else:
                generator = self._make_txt2img()

            #generator.set_variation(self.seed, variation_amount, with_variations)
            results = generator.generate(
                prompt,
                iterations     = iterations,
                seed           = self.seed,
                sampler        = self.sampler,
                steps          = steps,
                cfg_scale      = cfg_scale,
                conditioning   = (uc, c),
                ddim_eta       = ddim_eta,
                image_callback = image_callback,  # called after the final image is generated
                step_callback  = step_callback,   # called after each intermediate image is generated
                width          = width,
                height         = height,
                init_image     = init_image,      # notice that init_image is different from init_img
                mask_image     = mask_image,
                strength       = strength,
                mutations      = mutations,
                mods           = modifications,
            )

            if upscale is not None or gfpgan_strength > 0:
                self.upscale_and_reconstruct(results,
                                             upscale        = upscale,
                                             strength       = gfpgan_strength,
                                             save_original  = save_original,
                                             image_callback = image_callback)

        except KeyboardInterrupt:
            print('*interrupted*')
            if not self.ignore_ctrl_c:
                raise KeyboardInterrupt
            print(
                '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
            )
        except RuntimeError:
            print(traceback.format_exc(), file=sys.stderr)
            print('>> Could not generate image.')

        toc = time.time()
        print('>> Usage stats:')
        print(
            f'>>   {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
        )
        if torch.cuda.is_available() and self.device.type == 'cuda':
            print(
                '>>   Max VRAM used for this generation:',
                '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
                'Current VRAM utilization: '
                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
            )

            self.session_peakmem = max(
                self.session_peakmem, torch.cuda.max_memory_allocated()
            )
            print(
                '>>   Max VRAM used since script start: ',
                '%4.2fG' % (self.session_peakmem / 1e9),
            )

        # restore the original (pre-mutation) weights from the host-side
        # backups; use self.device rather than assuming CUDA is present
        for p in backups:
            eparams[p[0]].data = torch.from_numpy(p[1]).to(self.device)
            #eparams[x].data[0][0] = p[1]

        return results

    def _make_images(self, img_path, mask_path, width, height, fit=False):
        init_image      = None
        init_mask       = None
        if not img_path:
            return None, None

        image        = self._load_img(img_path, width, height, fit=fit) # this returns an Image
        init_image   = self._create_init_image(image)                   # this returns a torch tensor

        # if image has a transparent area and no mask was provided, then try to generate mask
        if self._has_transparency(image) and not mask_path:
            print('>> Initial image has transparent areas. Will inpaint in these regions.')
            if self._check_for_erasure(image):
                print(
                    '>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
                    '>>          Inpainting will be suboptimal. Please preserve the colors when making\n',
                    '>>          a transparency mask, or provide mask explicitly using --init_mask (-M).'
                )
            init_mask = self._create_init_mask(image)                   # this returns a torch tensor

        if mask_path:
            mask_image  = self._load_img(mask_path, width, height, fit=fit) # this returns an Image
            init_mask   = self._create_init_mask(mask_image)

        return init_image, init_mask

    def _make_img2img(self):
        if not self.generators.get('img2img'):
            from ldm.dream.generator.img2img import Img2Img
            self.generators['img2img'] = Img2Img(self.model)
        return self.generators['img2img']

    def _make_txt2img(self):
        if not self.generators.get('txt2img'):
            from ldm.dream.generator.txt2img import Txt2Img
            self.generators['txt2img'] = Txt2Img(self.model)
        return self.generators['txt2img']

    def _make_inpaint(self):
        if not self.generators.get('inpaint'):
            from ldm.dream.generator.inpaint import Inpaint
            self.generators['inpaint'] = Inpaint(self.model)
        return self.generators['inpaint']

    def load_model(self):
        """Load and initialize the model from configuration variables passed at object creation time"""
        if self.model is None:
            #seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
            seed_everything(1)
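            # NOTE: upstream seeds from the system RNG (the commented-out line
            # above); pinning seed_everything(1) is part of this hack and
            # appears intended to keep runs reproducible, so that differences
            # between runs come from the `mutations` weight perturbations.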
            try:
                config = OmegaConf.load(self.config)
                model = self._load_model_from_config(config, self.weights)
                if self.embedding_path is not None:
                    model.embedding_manager.load(
                        self.embedding_path, self.full_precision
                    )
                self.model = model.to(self.device)
                # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
                self.model.cond_stage_model.device = self.device
            except AttributeError as e:
                print(f'>> Error loading model. {str(e)}', file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                raise SystemExit from e

            self._set_sampler()

            for m in self.model.modules():
                if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                    m._orig_padding_mode = m.padding_mode

        return self.model

    def upscale_and_reconstruct(self,
                                image_list,
                                upscale       = None,
                                strength      = 0.0,
                                save_original = False,
                                image_callback = None):
        try:
            if upscale is not None:
                from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
            if strength > 0:
                from ldm.gfpgan.gfpgan_tools import run_gfpgan
        except (ModuleNotFoundError, ImportError):
            print(traceback.format_exc(), file=sys.stderr)
            print('>> You may need to install the ESRGAN and/or GFPGAN modules')
            return

        for r in image_list:
            image, seed1 = r
            try:
                if upscale is not None:
                    if len(upscale) < 2:
                        upscale.append(0.75)   # default upscaling strength
                    image = real_esrgan_upscale(
                        image,
                        upscale[1],
                        int(upscale[0]),
                        seed1,
                    )
                if strength > 0:
                    image = run_gfpgan(
                        image, strength, seed1, 1
                    )
            except Exception as e:
                print(
                    f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
                )

            if image_callback is not None:
                image_callback(image, seed1, upscaled=True)
            else:
                r[0] = image

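    # Example (illustrative, based on how `upscale` is unpacked above):
    # upscale=[2, 0.75] requests 2x RealESRGAN upscaling at strength 0.75,
    # while a single-element list such as upscale=[2] gets 0.75 appended
    # for it inside upscale_and_reconstruct().
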
    # to help WebGUI - front end to generator util function
    def sample_to_image(self, samples):
        return self._sample_to_image(samples)

    def _sample_to_image(self, samples):
        if not self.base_generator:
            from ldm.dream.generator import Generator
            self.base_generator = Generator(self.model)
        return self.base_generator.sample_to_image(samples)

    def _set_sampler(self):
        msg = f'>> Setting Sampler to {self.sampler_name}'
        if self.sampler_name == 'plms':
            self.sampler = PLMSSampler(self.model, device=self.device)
        elif self.sampler_name == 'ddim':
            self.sampler = DDIMSampler(self.model, device=self.device)
        elif self.sampler_name == 'k_dpm_2_a':
            self.sampler = KSampler(
                self.model, 'dpm_2_ancestral', device=self.device
            )
        elif self.sampler_name == 'k_dpm_2':
            self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
        elif self.sampler_name == 'k_euler_a':
            self.sampler = KSampler(
                self.model, 'euler_ancestral', device=self.device
            )
        elif self.sampler_name == 'k_euler':
            self.sampler = KSampler(self.model, 'euler', device=self.device)
        elif self.sampler_name == 'k_heun':
            self.sampler = KSampler(self.model, 'heun', device=self.device)
        elif self.sampler_name == 'k_lms':
            self.sampler = KSampler(self.model, 'lms', device=self.device)
        else:
            msg = f'>> Unsupported Sampler: {self.sampler_name}, defaulting to plms'
            self.sampler = PLMSSampler(self.model, device=self.device)

        print(msg)

    def _load_model_from_config(self, config, ckpt):
        print(f'>> Loading model from {ckpt}')

        # for usage statistics
        device_type = choose_torch_device()
        if device_type == 'cuda':
            torch.cuda.reset_peak_memory_stats()
        tic = time.time()

        # this does the work
        pl_sd = torch.load(ckpt, map_location='cpu')
        sd = pl_sd['state_dict']
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)

        if self.full_precision:
            print(
                '>> Using slower but more accurate full-precision math (--full_precision)'
            )
        else:
            print(
                '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.'
            )
            model.half()
        model.to(self.device)
        model.eval()

        # usage statistics
        toc = time.time()
        print(
            '>> Model loaded in', '%4.2fs' % (toc - tic)
        )
        if device_type == 'cuda':
            print(
                '>> Max VRAM used to load the model:',
                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
                '\n>> Current VRAM usage: '
                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
            )

        return model

    def _load_img(self, path, width, height, fit=False):
        assert os.path.exists(path), f'>> {path}: File not found'

        image = Image.open(path)
        print(
            f'>> loaded input image of size {image.width}x{image.height} from {path}'
        )
        if fit:
            image = self._fit_image(image, (width, height))
        else:
            image = self._squeeze_image(image)
        return image

    def _create_init_image(self, image):
        image = image.convert('RGB')
        image = np.array(image).astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)   # HWC -> NCHW
        image = torch.from_numpy(image)
        image = 2.0 * image - 1.0                   # scale to [-1, 1]
        return image.to(self.device)

    def _create_init_mask(self, image):
        # convert into a black/white mask
        image = self._image_to_mask(image)
        image = image.convert('RGB')
        # BUG: We need to use the model's downsample factor rather than hardcoding "8"
        from ldm.dream.generator.base import downsampling
        image = image.resize((image.width//downsampling, image.height//downsampling), resample=Image.Resampling.LANCZOS)
        image = np.array(image)
        image = image.astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        return image.to(self.device)

    # The mask is expected to have the region to be inpainted
    # with alpha transparency. It converts it into a black/white
    # image with the transparent part black.
    def _image_to_mask(self, mask_image, invert=False) -> Image.Image:
        # Obtain the mask from the transparency channel
        mask = Image.new(mode='L', size=mask_image.size, color=255)
        mask.putdata(mask_image.getdata(band=3))
        if invert:
            mask = ImageOps.invert(mask)
        return mask

    def _has_transparency(self, image):
        if image.info.get('transparency', None) is not None:
            return True
        if image.mode == 'P':
            transparent = image.info.get('transparency', -1)
            for _, index in image.getcolors():
                if index == transparent:
                    return True
        elif image.mode == 'RGBA':
            extrema = image.getextrema()
            if extrema[3][0] < 255:
                return True
        return False

    def _check_for_erasure(self, image):
        # returns True when every fully-transparent pixel is pure black or
        # pure white, i.e. the original colors appear to have been erased
        width, height = image.size
        pixdata       = image.load()
        colored       = 0
        for y in range(height):
            for x in range(width):
                if pixdata[x, y][3] == 0:
                    r, g, b, _ = pixdata[x, y]
                    if (r, g, b) != (0, 0, 0) and \
                       (r, g, b) != (255, 255, 255):
                        colored += 1
        return colored == 0

    def _squeeze_image(self, image):
        x, y, resize_needed = self._resolution_check(image.width, image.height)
        if resize_needed:
            return InitImageResizer(image).resize(x, y)
        return image

    def _fit_image(self, image, max_dimensions):
        w, h = max_dimensions
        print(
            f'>> image will be resized to fit inside a box {w}x{h} in size.'
        )
        if image.width > image.height:
            h = None   # by setting h to None, we tell InitImageResizer to fit into the width and calculate height
        elif image.height > image.width:
            w = None   # ditto for w
        image = InitImageResizer(image).resize(w, h)   # note that InitImageResizer does the multiple-of-64 truncation internally
        print(
            f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}'
        )
        return image

    def _resolution_check(self, width, height, log=False):
        resize_needed = False
        w, h = map(
            lambda x: x - x % 64, (width, height)
        )  # resize to integer multiple of 64
        if h != height or w != width:
            if log:
                print(
                    f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}'
                )
            height = h
            width  = w
            resize_needed = True

        if (width * height) > (self.width * self.height):
            print('>> This input is larger than your defaults. If you run out of memory, please use a smaller image.')

        return width, height, resize_needed
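
    # Example: _resolution_check(513, 700) rounds both dimensions down to the
    # nearest multiple of 64 and returns (512, 640, True), while an
    # already-aligned (512, 768) comes back unchanged with resize_needed False.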