import torch
import torch.nn.functional as F

# ComfyUI images are typically float tensors in [0,1] with shape [B,H,W,C]


def _to_bchw(img_bhwc: torch.Tensor) -> torch.Tensor:
    return img_bhwc.permute(0, 3, 1, 2).contiguous()


def _to_bhwc(img_bchw: torch.Tensor) -> torch.Tensor:
    return img_bchw.permute(0, 2, 3, 1).contiguous()


def _gaussian_kernel1d(sigma: float, device, dtype):
    sigma = max(float(sigma), 1e-6)
    radius = int(max(1, round(3.0 * sigma)))  # 3-sigma support captures ~99.7% of the mass
    x = torch.arange(-radius, radius + 1, device=device, dtype=dtype)
    kernel = torch.exp(-(x * x) / (2.0 * sigma * sigma))
    kernel = kernel / kernel.sum()  # normalize so the blur preserves overall brightness
    return kernel  # [K]


def gaussian_blur_bchw(img: torch.Tensor, sigma: float) -> torch.Tensor:
    """
    Separable Gaussian blur on a BCHW tensor.
    """
    if sigma <= 0.0:
        return img

    c = img.shape[1]
    device, dtype = img.device, img.dtype

    k1d = _gaussian_kernel1d(sigma, device, dtype)  # [K]
    kx = k1d.view(1, 1, 1, -1).repeat(c, 1, 1, 1)   # [C,1,1,K]
    ky = k1d.view(1, 1, -1, 1).repeat(c, 1, 1, 1)   # [C,1,K,1]

    pad = k1d.numel() // 2

    # Two 1-D depthwise passes (horizontal, then vertical) instead of one
    # 2-D convolution; reflect padding avoids darkened borders.
    x = F.pad(img, (pad, pad, 0, 0), mode="reflect")
    x = F.conv2d(x, kx, groups=c)

    x = F.pad(x, (0, 0, pad, pad), mode="reflect")
    x = F.conv2d(x, ky, groups=c)

    return x
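

# Illustrative usage of the blur (hypothetical shapes, kept as a comment so
# importing this module stays side-effect free):
#
#   x = torch.rand(1, 3, 64, 64)          # BCHW in [0,1]
#   y = gaussian_blur_bchw(x, sigma=2.0)
#   # normalized kernel + reflect padding keep values inside the input range
#   assert y.shape == x.shape and y.min() >= 0.0 and y.max() <= 1.0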


def _luma_bchw(img: torch.Tensor) -> torch.Tensor:
    # img: [B,C,H,W] -> [B,1,H,W]
    if img.shape[1] >= 3:
        # Rec. 601 luma weights
        return 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
    return img.mean(dim=1, keepdim=True)


def estimate_shift_phase_corr(
    ref_bchw: torch.Tensor,
    mov_bchw: torch.Tensor,
    max_shift: int = 16,
    pre_blur_sigma: float = 1.0,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Estimate an integer (dy, dx) such that shifting mov by (dy, dx) aligns it to ref.

    Returns:
        shifts_yx: int64 tensor [B,2] with (dy, dx)
    """
    if max_shift <= 0:
        b = ref_bchw.shape[0]
        return torch.zeros((b, 2), device=ref_bchw.device, dtype=torch.int64)

    ref = _luma_bchw(ref_bchw)
    mov = _luma_bchw(mov_bchw)

    if pre_blur_sigma > 0:
        ref = gaussian_blur_bchw(ref, pre_blur_sigma)
        mov = gaussian_blur_bchw(mov, pre_blur_sigma)

    # remove the mean to reduce DC dominance
    ref = ref - ref.mean(dim=(-2, -1), keepdim=True)
    mov = mov - mov.mean(dim=(-2, -1), keepdim=True)

    # FFT-based phase correlation: whiten the cross-power spectrum so only
    # phase (i.e. translation) information remains, then look for the peak.
    R = torch.fft.rfft2(ref, dim=(-2, -1))
    M = torch.fft.rfft2(mov, dim=(-2, -1))

    CPS = R * torch.conj(M)
    CPS = CPS / (torch.abs(CPS) + eps)

    corr = torch.fft.irfft2(CPS, s=ref.shape[-2:], dim=(-2, -1))  # [B,1,H,W]
    corr = corr.squeeze(1)  # [B,H,W]

    b, h, w = corr.shape

    # Candidate coordinates around 0 with wrap-around windows.
    # The peak can appear near 0..max_shift or near H-max_shift..H-1 (wrap).
    ys = list(range(0, max_shift + 1)) + list(range(max(0, h - max_shift), h))
    xs = list(range(0, max_shift + 1)) + list(range(max(0, w - max_shift), w))

    ys_t = torch.tensor(ys, device=corr.device, dtype=torch.long)
    xs_t = torch.tensor(xs, device=corr.device, dtype=torch.long)

    shifts = []
    for i in range(b):
        sub = corr[i].index_select(0, ys_t).index_select(1, xs_t)  # [len(ys), len(xs)]
        idx = torch.argmax(sub)
        yy = int(idx // sub.shape[1])
        xx = int(idx % sub.shape[1])

        peak_y = ys[yy]
        peak_x = xs[xx]

        # convert the peak location to a signed shift
        dy = peak_y if peak_y <= h // 2 else peak_y - h
        dx = peak_x if peak_x <= w // 2 else peak_x - w

        dy = max(-max_shift, min(max_shift, dy))
        dx = max(-max_shift, min(max_shift, dx))
        shifts.append((dy, dx))

    return torch.tensor(shifts, device=ref_bchw.device, dtype=torch.int64)  # [B,2]
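

# Illustrative sanity check (hypothetical tensors; phase correlation is exact
# for a circular shift, which torch.roll produces):
#
#   ref = torch.rand(1, 3, 128, 128)
#   mov = torch.roll(ref, shifts=(3, -5), dims=(-2, -1))   # move down 3, left 5
#   est = estimate_shift_phase_corr(ref, mov, max_shift=16)
#   # est should be [[-3, 5]]: rolling mov by (-3, 5) re-aligns it with ref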


def apply_integer_shift_bchw(img: torch.Tensor, shifts_yx: torch.Tensor, wrap_edges: bool = True) -> torch.Tensor:
    """
    Apply an integer (dy, dx) shift per batch item.

    If wrap_edges=True: uses torch.roll (wrap-around).
    If wrap_edges=False: uses grid_sample (pads with zeros, no wrap).
    """
    b, _, h, w = img.shape
    out = []

    if wrap_edges:
        for i in range(b):
            dy, dx = int(shifts_yx[i, 0].item()), int(shifts_yx[i, 1].item())
            out.append(torch.roll(img[i:i + 1], shifts=(dy, dx), dims=(-2, -1)))
        return torch.cat(out, dim=0)

    # No-wrap shift using grid_sample.
    # grid_sample wants normalized coords. A positive dx should move content right,
    # which corresponds to sampling from x - dx, so shift the grid by -dx.
    yy, xx = torch.meshgrid(
        torch.linspace(-1.0, 1.0, h, device=img.device, dtype=img.dtype),
        torch.linspace(-1.0, 1.0, w, device=img.device, dtype=img.dtype),
        indexing="ij",
    )
    base_grid = torch.stack([xx, yy], dim=-1)[None, ...].repeat(b, 1, 1, 1)  # [B,H,W,2]

    # Convert the pixel shift to a normalized shift:
    # one pixel in x is 2/(w-1), one pixel in y is 2/(h-1) (align_corners=True)
    sx = 2.0 / max(1.0, float(w - 1))
    sy = 2.0 / max(1.0, float(h - 1))

    grid = base_grid.clone()
    for i in range(b):
        dy, dx = int(shifts_yx[i, 0].item()), int(shifts_yx[i, 1].item())
        grid[i, :, :, 0] = grid[i, :, :, 0] - dx * sx
        grid[i, :, :, 1] = grid[i, :, :, 1] - dy * sy

    return F.grid_sample(
        img,
        grid,
        mode="bilinear",
        padding_mode="zeros",
        align_corners=True,
    )
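

# Illustrative round trip (given some hypothetical BCHW img; wrap_edges=True
# makes an integer shift exactly invertible, since torch.roll loses no pixels):
#
#   s = torch.tensor([[2, -1]])
#   out = apply_integer_shift_bchw(img, s, wrap_edges=True)
#   assert torch.equal(apply_integer_shift_bchw(out, -s, wrap_edges=True), img)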


def sobel_edge_mask_bchw(img: torch.Tensor, sigma_pre: float = 0.8, gamma: float = 1.0) -> torch.Tensor:
    """
    Returns a 1-channel edge mask in [0,1] with shape [B,1,H,W].
    Computes the Sobel magnitude on a lightly blurred luminance.
    """
    device, dtype = img.device, img.dtype
    b = img.shape[0]

    x = gaussian_blur_bchw(img, sigma_pre)
    y = _luma_bchw(x)

    kx = torch.tensor([[-1, 0, 1],
                       [-2, 0, 2],
                       [-1, 0, 1]], device=device, dtype=dtype).view(1, 1, 3, 3)
    ky = torch.tensor([[-1, -2, -1],
                       [ 0,  0,  0],
                       [ 1,  2,  1]], device=device, dtype=dtype).view(1, 1, 3, 3)

    ypad = F.pad(y, (1, 1, 1, 1), mode="reflect")
    gx = F.conv2d(ypad, kx)
    gy = F.conv2d(ypad, ky)

    mag = torch.sqrt(gx * gx + gy * gy + 1e-12)

    # Robust per-image normalization: scale by the mean of the top 5% of
    # gradient magnitudes so a few extreme pixels cannot flatten the mask.
    mag_flat = mag.view(b, -1)
    k = max(1, int(mag_flat.shape[1] * 0.05))
    topk, _ = torch.topk(mag_flat, k=k, dim=1, largest=True, sorted=False)
    scale = topk.mean(dim=1).view(b, 1, 1, 1).clamp_min(1e-6)
    mask = (mag / scale).clamp(0.0, 1.0)

    if gamma != 1.0:
        mask = mask.pow(float(gamma))

    return mask
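

# Illustrative check (hypothetical input): the mask is single-channel and
# bounded, so it can gate per-pixel detail directly:
#
#   m = sobel_edge_mask_bchw(torch.rand(1, 3, 64, 64))
#   assert m.shape == (1, 1, 64, 64) and m.min() >= 0.0 and m.max() <= 1.0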


def soft_limit(x: torch.Tensor, limit: float) -> torch.Tensor:
    """
    Smoothly limits values to approximately [-limit, limit] using tanh.
    """
    limit = float(limit)
    if limit <= 0.0:
        return x
    return torch.tanh(x / limit) * limit
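

# Illustrative behavior: values well inside the limit pass nearly unchanged,
# since tanh(x / limit) * limit ~ x for |x| << limit, while large values
# saturate smoothly toward +/-limit instead of clipping hard:
#
#   soft_limit(torch.tensor([0.01, 1.0]), 0.1)  # ~ [0.0100, 0.1000]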


class FocusFusionTwoImages:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "oversharpened": ("IMAGE",),
                "blurry": ("IMAGE",),

                # Alignment
                "auto_align": ("BOOLEAN", {"default": True}),
                "max_shift": ("INT", {"default": 8, "min": 0, "max": 128, "step": 1}),
                "align_blur_sigma": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "wrap_edges": ("BOOLEAN", {"default": True}),

                # Base (low frequencies) from blurry
                "base_sigma": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 20.0, "step": 0.1}),

                # Detail extracted from oversharpened
                "detail_sigma": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 20.0, "step": 0.1}),
                "detail_amount": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 4.0, "step": 0.05}),

                # Halo / ringing control
                "detail_limit": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 0.5, "step": 0.005}),

                # Edge gating
                "edge_sigma": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 10.0, "step": 0.1}),
                "edge_gamma": ("FLOAT", {"default": 1.3, "min": 0.1, "max": 4.0, "step": 0.1}),
                "edge_mix": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05}),

                # Output clamp
                "clamp_output": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "fuse"
    CATEGORY = "image/focus"

    def fuse(
        self,
        oversharpened,
        blurry,
        auto_align: bool,
        max_shift: int,
        align_blur_sigma: float,
        wrap_edges: bool,
        base_sigma: float,
        detail_sigma: float,
        detail_amount: float,
        detail_limit: float,
        edge_sigma: float,
        edge_gamma: float,
        edge_mix: float,
        clamp_output: bool,
    ):
        if oversharpened.shape != blurry.shape:
            raise ValueError(f"Input images must have matching shapes. Got {oversharpened.shape} vs {blurry.shape}")

        A = _to_bchw(oversharpened)  # moving: oversharpened
        B = _to_bchw(blurry)         # reference: blurry

        # Auto-align A into B's coordinates (translation only)
        if auto_align and int(max_shift) > 0:
            shifts = estimate_shift_phase_corr(
                ref_bchw=B,
                mov_bchw=A,
                max_shift=int(max_shift),
                pre_blur_sigma=float(align_blur_sigma),
            )
            A = apply_integer_shift_bchw(A, shifts, wrap_edges=bool(wrap_edges))

        # Base from blurry (optionally low-passed)
        base = gaussian_blur_bchw(B, float(base_sigma)) if base_sigma > 0 else B

        # Detail from oversharpened: high frequencies = image minus its low-pass
        A_low = gaussian_blur_bchw(A, float(detail_sigma)) if detail_sigma > 0 else A
        detail = A - A_low

        # Limit the detail amplitude to reduce halos/ringing
        detail = soft_limit(detail, float(detail_limit))

        # Edge mask from blurry (more trustworthy edge locations)
        edge = sobel_edge_mask_bchw(B, sigma_pre=float(edge_sigma), gamma=float(edge_gamma))  # [B,1,H,W]
        edge = edge.expand_as(detail)

        # Edge-gating mix: 1.0 = fully gated by edges, 0.0 = no gating
        w = (1.0 - float(edge_mix)) + float(edge_mix) * edge

        fused = base + float(detail_amount) * detail * w

        if clamp_output:
            fused = fused.clamp(0.0, 1.0)

        return (_to_bhwc(fused),)


NODE_CLASS_MAPPINGS = {
    "FocusFusionTwoImages": FocusFusionTwoImages,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "FocusFusionTwoImages": "Focus Mix",
}
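

# Minimal smoke test, a sketch for local runs only; ComfyUI invokes the node
# through its graph executor, so nothing here runs on import. Shapes and
# parameter values below are arbitrary assumptions for illustration.
if __name__ == "__main__":
    ref = torch.rand(1, 64, 64, 3)                      # BHWC, like a ComfyUI IMAGE
    mov = torch.roll(ref, shifts=(2, -3), dims=(1, 2))  # misaligned stand-in for the sharp input

    (fused,) = FocusFusionTwoImages().fuse(
        oversharpened=mov,
        blurry=ref,
        auto_align=True,
        max_shift=8,
        align_blur_sigma=1.0,
        wrap_edges=True,
        base_sigma=0.0,
        detail_sigma=1.0,
        detail_amount=1.0,
        detail_limit=0.1,
        edge_sigma=0.8,
        edge_gamma=1.3,
        edge_mix=0.4,
        clamp_output=True,
    )
    assert fused.shape == ref.shape
    print("fused:", tuple(fused.shape), float(fused.min()), float(fused.max()))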