Advertisement
Kafke

txt2mask-refined.py

Dec 23rd, 2022
1,608
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.51 KB | None | 0 0
  1. # Author: Therefore Games
  2. # https://github.com/ThereforeGames/txt2img2img
  3.  
  4. import modules.scripts as scripts
  5. import gradio as gr
  6.  
  7. from modules import processing, images, shared, sd_samplers
  8. from modules.processing import process_images, Processed
  9. from modules.shared import opts, cmd_opts, state, Options
  10.  
  11. import torch
  12. import cv2
  13. import requests
  14. import os.path
  15.  
  16. from repositories.clipseg.models.clipseg import CLIPDensePredT
  17. from PIL import ImageChops, Image, ImageOps
  18. from torchvision import transforms
  19. from matplotlib import pyplot as plt
  20. import numpy
  21.  
  22. debug = False
  23.  
  24. class Script(scripts.Script):
  25.     def title(self):
  26.         return "txt2mask v0.1.1 - Refined"
  27.  
  28.     def show(self, is_img2img):
  29.         return is_img2img
  30.  
  31.     def ui(self, is_img2img):
  32.         if not is_img2img:
  33.             return None
  34.  
  35.         mask_prompt = gr.Textbox(label="Mask prompt", lines=1)
  36.         negative_mask_prompt = gr.Textbox(label="Negative mask prompt", lines=1)
  37.         mask_precision = gr.Slider(label="Mask precision", minimum=0.0, maximum=255.0, step=1.0, value=100.0)
  38.         mask_padding = gr.Slider(label="Mask padding", minimum=0.0, maximum=500.0, step=1.0, value=0.0)
  39.         smoothing = gr.Slider(label="Smoothing", minimum=0.0, maximum=100.0, step=1.0, value=20.0) #added smoothing
  40.         brush_mask_mode = gr.Radio(label="Brush mask mode", choices=['discard','add','subtract'], value='discard', type="index", visible=False)
  41.         smoothing_enabled = gr.Checkbox(label="Enable Smoothing?",value=True)
  42.         mask_output = gr.Checkbox(label="Show mask in output?",value=True)
  43.  
  44.         plug = gr.HTML(label="plug",value='<div class="gr-block gr-box relative w-full overflow-hidden border-solid border border-gray-200 gr-panel"><p>If you like my work, please consider showing your support on <strong><a href="https://patreon.com/thereforegames" target="_blank">Patreon</a></strong>. Thank you! &#10084;</p></div>')
  45.  
  46.         return [mask_prompt,negative_mask_prompt, mask_precision, mask_padding, smoothing, brush_mask_mode, smoothing_enabled, mask_output, plug]
  47.  
  48.     def run(self, p, mask_prompt, negative_mask_prompt, mask_precision, mask_padding, smoothing, brush_mask_mode, smoothing_enabled, mask_output, plug):
  49.         def download_file(filename, url):
  50.             with open(filename, 'wb') as fout:
  51.                 response = requests.get(url, stream=True)
  52.                 response.raise_for_status()
  53.                 # Write response data to file
  54.                 for block in response.iter_content(4096):
  55.                     fout.write(block)
  56.         def pil_to_cv2(img):
  57.             return (cv2.cvtColor(numpy.array(img), cv2.COLOR_RGB2BGR))
  58.         def gray_to_pil(img):
  59.             return (Image.fromarray(cv2.cvtColor(img,cv2.COLOR_GRAY2RGBA)))
  60.        
  61.         def center_crop(img,new_width,new_height):
  62.             width, height = img.size   # Get dimensions
  63.  
  64.             left = (width - new_width)/2
  65.             top = (height - new_height)/2
  66.             right = (width + new_width)/2
  67.             bottom = (height + new_height)/2
  68.  
  69.             # Crop the center of the image
  70.             return(img.crop((left, top, right, bottom)))
  71.  
  72.         def overlay_mask_part(img_a,img_b,mode):
  73.             if (mode == 0):
  74.                 img_a = ImageChops.darker(img_a, img_b)
  75.             else: img_a = ImageChops.lighter(img_a, img_b)
  76.             return(img_a)
  77.  
  78.         def process_mask_parts(these_preds,these_prompt_parts,mode,final_img = None):
  79.             for i in range(these_prompt_parts):
  80.                 filename = f"mask_{mode}_{i}.png"
  81.                 plt.imsave(filename,torch.sigmoid(these_preds[i][0]))
  82.  
  83.                 # TODO: Figure out how to convert the plot above to numpy instead of re-loading image
  84.                 img = cv2.imread(filename)
  85.  
  86.                 #New smoothing function
  87.                 if smoothing_enabled:
  88.                     radius = int(smoothing)
  89.                     smoothing_kernel = numpy.ones((radius,radius),numpy.float32)/(radius*radius)
  90.                     img = cv2.filter2D(img,-1,smoothing_kernel)
  91.                 #------------------------
  92.  
  93.                 gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  94.                 (thresh, bw_image) = cv2.threshold(gray_image, mask_precision, 255, cv2.THRESH_BINARY)
  95.  
  96.                 if (mode == 0): bw_image = numpy.invert(bw_image)
  97.  
  98.                 if (debug):
  99.                     print(f"bw_image: {bw_image}")
  100.                     print(f"final_img: {final_img}")
  101.  
  102.                 # overlay mask parts
  103.                 bw_image = gray_to_pil(bw_image)
  104.                 if (i > 0 or final_img is not None):
  105.                     bw_image = overlay_mask_part(bw_image,final_img,mode)
  106.  
  107.                 # For debugging only:
  108.                 if (debug): bw_image.save(f"processed_{filename}")
  109.  
  110.                 final_img = bw_image
  111.  
  112.             return(final_img)
  113.  
  114.         def get_mask():
  115.             # load model
  116.             model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
  117.             model.eval();
  118.             model_dir = "./repositories/clipseg/weights"
  119.             os.makedirs(model_dir, exist_ok=True)
  120.             d64_file = f"{model_dir}/rd64-uni-refined.pth"
  121.             d16_file = f"{model_dir}/rd16-uni.pth"
  122.             delimiter_string = "|"
  123.            
  124.             # Download model weights if we don't have them yet
  125.             if not os.path.exists(d64_file):
  126.                 print("Downloading clipseg model weights...")
  127.                 download_file(d64_file,"https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download?path=%2F&files=rd64-uni-refined.pth")
  128.                 download_file(d16_file,"https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download?path=%2F&files=rd16-uni.pth")
  129.                 # Mirror:
  130.                 # https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth
  131.                 # https://github.com/timojl/clipseg/raw/master/weights/rd16-uni.pth
  132.            
  133.             # non-strict, because we only stored decoder weights (not CLIP weights)
  134.             model.load_state_dict(torch.load(d64_file, map_location=torch.device('cuda')), strict=False);          
  135.  
  136.             transform = transforms.Compose([
  137.                 transforms.ToTensor(),
  138.                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  139.                 transforms.Resize((512, 512)),
  140.             ])
  141.             img = transform(p.init_images[0]).unsqueeze(0)
  142.  
  143.             prompts = mask_prompt.split(delimiter_string)
  144.             prompt_parts = len(prompts)
  145.             negative_prompts = negative_mask_prompt.split(delimiter_string)
  146.             negative_prompt_parts = len(negative_prompts)
  147.  
  148.             # predict
  149.             with torch.no_grad():
  150.                 preds = model(img.repeat(prompt_parts,1,1,1), prompts)[0]
  151.                 negative_preds = model(img.repeat(negative_prompt_parts,1,1,1), negative_prompts)[0]
  152.  
  153.             #tests
  154.             if (debug):
  155.                 print("Check initial mask vars before processing...")
  156.                 print(f"p.image_mask: {p.image_mask}")
  157.                 print(f"p.latent_mask: {p.latent_mask}")
  158.                 print(f"p.mask_for_overlay: {p.mask_for_overlay}")
  159.  
  160.             if (brush_mask_mode == 1 and p.image_mask is not None):
  161.                 final_img = p.image_mask.convert("RGBA")
  162.             else: final_img = None
  163.  
  164.             # process masking
  165.             final_img = process_mask_parts(preds,prompt_parts,1,final_img)
  166.  
  167.             # process negative masking
  168.             if (brush_mask_mode == 2 and p.image_mask is not None):
  169.                 p.image_mask = ImageOps.invert(p.image_mask)
  170.                 p.image_mask = p.image_mask.convert("RGBA")
  171.                 final_img = overlay_mask_part(final_img,p.image_mask,0)
  172.             if (negative_mask_prompt): final_img = process_mask_parts(negative_preds,negative_prompt_parts,0,final_img)
  173.  
  174.             # Increase mask size with padding
  175.             if (mask_padding > 0):
  176.                 aspect_ratio = p.init_images[0].width / p.init_images[0].height
  177.                 new_width = p.init_images[0].width+mask_padding*2
  178.                 new_height = round(new_width / aspect_ratio)
  179.                 final_img = final_img.resize((new_width,new_height))
  180.                 final_img = center_crop(final_img,p.init_images[0].width,p.init_images[0].height)
  181.        
  182.             return (final_img)
  183.                        
  184.  
  185.         # Set up processor parameters correctly
  186.         p.mode = 1
  187.         p.mask_mode = 1
  188.         p.image_mask =  get_mask().resize((p.init_images[0].width,p.init_images[0].height))
  189.         p.mask_for_overlay = p.image_mask
  190.         p.latent_mask = None # fixes inpainting full resolution
  191.  
  192.  
  193.         processed = processing.process_images(p)
  194.  
  195.         if (mask_output):
  196.             processed.images.append(p.image_mask)
  197.  
  198.         return processed
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement