Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn.functional as F
# ComfyUI IMAGE tensors are typically float in [0, 1], laid out channels-last as
# [B,H,W,C]; the conv-based helpers in this file work on channels-first [B,C,H,W].
def _to_bchw(img_bhwc: torch.Tensor) -> torch.Tensor:
    """Reorder a [B,H,W,C] tensor to [B,C,H,W] and return it contiguous."""
    channels_first = torch.permute(img_bhwc, (0, 3, 1, 2))
    return channels_first.contiguous()
def _to_bhwc(img_bchw: torch.Tensor) -> torch.Tensor:
    """Reorder a [B,C,H,W] tensor back to channels-last [B,H,W,C] and return it contiguous."""
    channels_last = torch.permute(img_bchw, (0, 2, 3, 1))
    return channels_last.contiguous()
def _gaussian_kernel1d(sigma: float, device, dtype):
    """
    Build a normalized 1-D Gaussian kernel.

    The radius is round(3 * sigma), at least 1, so the kernel has 2 * radius + 1
    taps and sums to 1. sigma is floored at 1e-6 to avoid dividing by zero; a
    vanishing sigma therefore yields the delta kernel [0, 1, 0].

    Args:
        sigma: standard deviation in pixels.
        device: device of the returned tensor.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape [2 * radius + 1].
    """
    sigma = max(float(sigma), 1e-6)
    radius = int(max(1, round(3.0 * sigma)))
    # Evaluate and normalize in float32 when the caller works in half/bfloat16,
    # so the taps still sum to ~1; cast to the requested dtype at the end.
    work_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    x = torch.arange(-radius, radius + 1, device=device, dtype=work_dtype)
    kernel = torch.exp(-(x * x) / (2.0 * sigma * sigma))
    kernel = kernel / kernel.sum()
    return kernel.to(dtype)  # [K]
def gaussian_blur_bchw(img: torch.Tensor, sigma: float) -> torch.Tensor:
    """
    Separable Gaussian blur of a [B,C,H,W] tensor, applied to each channel independently.

    Borders use reflect padding. Reflect padding requires the pad to be smaller than
    the padded dimension. When the kernel radius is >= W (or >= H), for example with
    a large sigma on a small or thin image, that axis falls back to replicate padding
    instead of raising.

    Args:
        img: [B,C,H,W] tensor.
        sigma: standard deviation in pixels; <= 0 returns img itself unchanged.

    Returns:
        Blurred tensor with the same shape, dtype and device as img.
    """
    if sigma <= 0.0:
        return img
    _, c, h, w = img.shape
    k1d = _gaussian_kernel1d(sigma, img.device, img.dtype)  # [K], K = 2 * pad + 1
    pad = k1d.numel() // 2
    # Depthwise weights (used with groups=c): one copy of the 1-D kernel per channel.
    kx = k1d.view(1, 1, 1, -1).repeat(c, 1, 1, 1)  # [C,1,1,K]
    ky = k1d.view(1, 1, -1, 1).repeat(c, 1, 1, 1)  # [C,1,K,1]

    # Horizontal pass.
    x = F.pad(img, (pad, pad, 0, 0), mode="reflect" if pad < w else "replicate")
    x = F.conv2d(x, kx, groups=c)
    # Vertical pass.
    x = F.pad(x, (0, 0, pad, pad), mode="reflect" if pad < h else "replicate")
    return F.conv2d(x, ky, groups=c)
def _luma_bchw(img: torch.Tensor) -> torch.Tensor:
    """
    Collapse a [B,C,H,W] tensor to a single luminance channel [B,1,H,W].

    With at least 3 channels, the first three are weighted 0.299 / 0.587 / 0.114
    (any further channels are ignored). Otherwise the channels are averaged.
    """
    _, channels, _, _ = img.shape
    if channels < 3:
        return img.mean(dim=1, keepdim=True)
    red, green, blue = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    return 0.299 * red + 0.587 * green + 0.114 * blue
def _wrapped_offsets(n: int, max_shift: int, device) -> torch.Tensor:
    """Sorted, de-duplicated indices {0..m} union {n-m..n-1}, with m = min(max_shift, n - 1).

    These are the correlation-map positions for displacements of at most max_shift
    along an axis of length n (negative displacements wrap to the end of the axis).
    Clamping m to n - 1 keeps every index in range for axes shorter than max_shift.
    """
    m = min(int(max_shift), n - 1)
    idx = sorted(set(range(m + 1)) | set(range(n - m, n)))
    return torch.tensor(idx, device=device, dtype=torch.long)


def estimate_shift_phase_corr(
    ref_bchw: torch.Tensor,
    mov_bchw: torch.Tensor,
    max_shift: int = 16,
    pre_blur_sigma: float = 1.0,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Estimate integer (dy, dx) such that shifting mov by (dy, dx) aligns it to ref.

    Uses phase correlation on the (optionally blurred, mean-removed) luminance of
    both images. The correlation peak is searched only among displacements with
    |dy|, |dx| <= max_shift (circularly wrapped), and the result is clamped to that
    range. Axes shorter than 2 * max_shift + 1 are searched over every valid
    position exactly once.

    Args:
        ref_bchw: reference image [B,C,H,W].
        mov_bchw: image to be moved, same shape as ref_bchw.
        max_shift: maximum |dy| / |dx| in pixels; <= 0 returns all zeros.
        pre_blur_sigma: Gaussian sigma applied to the luminance first; <= 0 disables it.
        eps: added to the cross-power magnitude before normalizing it.

    Returns:
        shifts_yx: int64 tensor [B,2] with (dy, dx), on ref_bchw's device.
    """
    b = ref_bchw.shape[0]
    if max_shift <= 0:
        return torch.zeros((b, 2), device=ref_bchw.device, dtype=torch.int64)

    ref = _luma_bchw(ref_bchw)
    mov = _luma_bchw(mov_bchw)
    if pre_blur_sigma > 0:
        ref = gaussian_blur_bchw(ref, pre_blur_sigma)
        mov = gaussian_blur_bchw(mov, pre_blur_sigma)

    # torch.fft's half/bfloat16 support is limited (e.g. cuFFT half precision needs
    # power-of-two sizes), so correlate in float32 for low-precision inputs.
    low_precision = (torch.float16, torch.bfloat16)
    ref = ref.float() if ref.dtype in low_precision else ref
    mov = mov.float() if mov.dtype in low_precision else mov

    # Remove the mean to reduce DC dominance.
    ref = ref - ref.mean(dim=(-2, -1), keepdim=True)
    mov = mov - mov.mean(dim=(-2, -1), keepdim=True)

    # FFT-based phase correlation: keep only the phase of the cross-power spectrum.
    h, w = ref.shape[-2:]
    R = torch.fft.rfft2(ref, dim=(-2, -1))
    M = torch.fft.rfft2(mov, dim=(-2, -1))
    cps = R * torch.conj(M)
    cps = cps / (torch.abs(cps) + eps)
    corr = torch.fft.irfft2(cps, s=(h, w), dim=(-2, -1)).squeeze(1)  # [B,H,W]

    # Search the peak only near zero displacement (with wrap-around), all items at once.
    ys = _wrapped_offsets(h, max_shift, corr.device)
    xs = _wrapped_offsets(w, max_shift, corr.device)
    window = corr.index_select(1, ys).index_select(2, xs)  # [B, len(ys), len(xs)]
    flat_idx = window.reshape(b, -1).argmax(dim=1)  # first maximum per item
    nx = xs.numel()
    peak_y = ys[torch.div(flat_idx, nx, rounding_mode="floor")]
    peak_x = xs[flat_idx % nx]

    # Convert circular peak positions to signed shifts.
    dy = torch.where(peak_y <= h // 2, peak_y, peak_y - h)
    dx = torch.where(peak_x <= w // 2, peak_x, peak_x - w)
    shifts = torch.stack((dy, dx), dim=1).clamp(-max_shift, max_shift)
    return shifts.to(device=ref_bchw.device, dtype=torch.int64)  # [B,2]
def _shift_slices(d: int, n: int):
    """Return (dst, src) slices that move content by d along an axis of length n (|d| < n)."""
    if d >= 0:
        return slice(d, n), slice(0, n - d)
    return slice(0, n + d), slice(-d, n)


def apply_integer_shift_bchw(img: torch.Tensor, shifts_yx: torch.Tensor, wrap_edges: bool = True) -> torch.Tensor:
    """
    Translate each batch item of a [B,C,H,W] tensor by an integer (dy, dx).

    Positive dy moves content down and positive dx moves content right, i.e.
    out[i, :, y, x] = img[i, :, y - dy, x - dx].

    Args:
        img: [B,C,H,W] tensor.
        shifts_yx: integer tensor [B,2] with (dy, dx) per batch item.
        wrap_edges: if True, shift circularly with torch.roll (content leaving one
            edge re-enters on the opposite edge). If False, pixels shifted in from
            outside the frame are 0; a shift with |dy| >= H or |dx| >= W yields an
            all-zero item.

    Returns:
        Tensor with the same shape, dtype and device as img.
    """
    b, _, h, w = img.shape
    # Fetch every shift with one host transfer instead of two .item() syncs per item.
    offsets = [(int(dy), int(dx)) for dy, dx in shifts_yx.tolist()]

    if wrap_edges:
        return torch.cat(
            [torch.roll(img[i:i + 1], shifts=offsets[i], dims=(-2, -1)) for i in range(b)],
            dim=0,
        )

    # No wrap: copy the overlapping region into a zero canvas. For integer shifts
    # this is exact, whereas interpolating resamplers (e.g. grid_sample) only
    # approximate it up to float rounding of normalized coordinates.
    out = torch.zeros_like(img)
    for i in range(b):
        dy, dx = offsets[i]
        if abs(dy) >= h or abs(dx) >= w:
            continue  # content shifted entirely out of frame
        dst_y, src_y = _shift_slices(dy, h)
        dst_x, src_x = _shift_slices(dx, w)
        out[i, :, dst_y, dst_x] = img[i, :, src_y, src_x]
    return out
def sobel_edge_mask_bchw(img: torch.Tensor, sigma_pre: float = 0.8, gamma: float = 1.0) -> torch.Tensor:
    """
    Return a 1-channel edge-strength mask in [0, 1] with shape [B,1,H,W].

    Computes the Sobel gradient magnitude of the lightly blurred luminance. Each
    image is normalized by the mean of its strongest 5% of magnitudes, clipped to
    [0, 1], and then raised to `gamma`.

    Args:
        img: [B,C,H,W] image.
        sigma_pre: Gaussian pre-blur sigma; <= 0 disables the blur.
        gamma: exponent applied to the normalized mask (values > 1 suppress weak edges).

    Returns:
        Tensor [B,1,H,W] with img's dtype and device.
    """
    b = img.shape[0]
    device, dtype = img.device, img.dtype

    # Luma conversion and the Gaussian blur are both linear, so converting to luma
    # first matches blur-then-luma while blurring 1 channel instead of C.
    y = gaussian_blur_bchw(_luma_bchw(img), sigma_pre)  # [B,1,H,W]

    kx = torch.tensor([[-1, 0, 1],
                       [-2, 0, 2],
                       [-1, 0, 1]], device=device, dtype=dtype).view(1, 1, 3, 3)
    ky = torch.tensor([[-1, -2, -1],
                       [ 0,  0,  0],
                       [ 1,  2,  1]], device=device, dtype=dtype).view(1, 1, 3, 3)

    # Reflect padding needs >= 2 pixels per axis; replicate also handles 1-pixel images.
    pad_mode = "reflect" if min(y.shape[-2:]) > 1 else "replicate"
    ypad = F.pad(y, (1, 1, 1, 1), mode=pad_mode)
    gx = F.conv2d(ypad, kx)
    gy = F.conv2d(ypad, ky)
    # Small epsilon inside the sqrt (presumably to keep its gradient finite at 0).
    mag = torch.sqrt(gx * gx + gy * gy + 1e-12)

    # Robust per-image normalization: the scale is the mean of the top 5% magnitudes.
    mag_flat = mag.reshape(b, -1)
    k = max(1, int(mag_flat.shape[1] * 0.05))
    topk, _ = torch.topk(mag_flat, k=k, dim=1, largest=True, sorted=False)
    scale = topk.mean(dim=1).view(b, 1, 1, 1).clamp_min(1e-6)
    mask = (mag / scale).clamp(0.0, 1.0)
    if gamma != 1.0:
        mask = mask.pow(float(gamma))
    return mask
- def soft_limit(x: torch.Tensor, limit: float) -> torch.Tensor:
- """
- Smoothly limits values to approx [-limit, limit] using tanh.
- """
- limit = float(limit)
- if limit <= 0.0:
- return x
- return torch.tanh(x / limit) * limit
class FocusFusionTwoImages:
    """
    ComfyUI node that blends an over-sharpened and a blurry version of an image.

    The low frequencies come from `blurry`, optionally blurred further by base_sigma.
    The high-frequency detail is `oversharpened` minus its own Gaussian blur. That
    detail is soft-limited with tanh, weighted by a Sobel edge mask computed on
    `blurry`, and added on top. Optionally, `oversharpened` is first translated onto
    `blurry` by an integer shift estimated with phase correlation.
    """

    @classmethod
    def INPUT_TYPES(cls):
        def float_input(default, lo, hi, step):
            return ("FLOAT", {"default": default, "min": lo, "max": hi, "step": step})

        required = {
            "oversharpened": ("IMAGE",),
            "blurry": ("IMAGE",),
            # Alignment of oversharpened onto blurry
            "auto_align": ("BOOLEAN", {"default": True}),
            "max_shift": ("INT", {"default": 8, "min": 0, "max": 128, "step": 1}),
            "align_blur_sigma": float_input(1.0, 0.0, 10.0, 0.1),
            "wrap_edges": ("BOOLEAN", {"default": True}),
            # Base (low frequencies) taken from blurry
            "base_sigma": float_input(0.0, 0.0, 20.0, 0.1),
            # Detail (high frequencies) taken from oversharpened
            "detail_sigma": float_input(1.0, 0.0, 20.0, 0.1),
            "detail_amount": float_input(1.0, 0.0, 4.0, 0.05),
            # Halo / ringing control: tanh limit on detail amplitude
            "detail_limit": float_input(0.1, 0.0, 0.5, 0.005),
            # Edge gating
            "edge_sigma": float_input(0.8, 0.0, 10.0, 0.1),
            "edge_gamma": float_input(1.3, 0.1, 4.0, 0.1),
            "edge_mix": float_input(0.4, 0.0, 1.0, 0.05),
            # Clamp the result to [0, 1]
            "clamp_output": ("BOOLEAN", {"default": True}),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "fuse"
    CATEGORY = "image/focus"

    def fuse(
        self,
        oversharpened,
        blurry,
        auto_align: bool,
        max_shift: int,
        align_blur_sigma: float,
        wrap_edges: bool,
        base_sigma: float,
        detail_sigma: float,
        detail_amount: float,
        detail_limit: float,
        edge_sigma: float,
        edge_gamma: float,
        edge_mix: float,
        clamp_output: bool,
    ):
        """Fuse the two [B,H,W,C] images; returns a 1-tuple holding the fused [B,H,W,C] image.

        Raises ValueError if the two inputs do not have identical shapes.
        """
        if oversharpened.shape != blurry.shape:
            raise ValueError(f"Input images must match shape. Got {oversharpened.shape} vs {blurry.shape}")

        moving = _to_bchw(oversharpened)   # detail source, gets aligned
        reference = _to_bchw(blurry)       # base source, fixed frame

        # Translation-only alignment of the oversharpened image onto the blurry one.
        shift_limit = int(max_shift)
        if auto_align and shift_limit > 0:
            offsets = estimate_shift_phase_corr(
                ref_bchw=reference,
                mov_bchw=moving,
                max_shift=shift_limit,
                pre_blur_sigma=float(align_blur_sigma),
            )
            moving = apply_integer_shift_bchw(moving, offsets, wrap_edges=bool(wrap_edges))

        # Low frequencies from the blurry image.
        low_freq = reference
        if base_sigma > 0:
            low_freq = gaussian_blur_bchw(reference, float(base_sigma))

        # High frequencies from the oversharpened image, amplitude-limited against halos.
        smoothed = moving
        if detail_sigma > 0:
            smoothed = gaussian_blur_bchw(moving, float(detail_sigma))
        high_freq = soft_limit(moving - smoothed, float(detail_limit))

        # Edge gating from the blurry image: edge_mix=1 fully gated, 0 no gating.
        gate_strength = float(edge_mix)
        edges = sobel_edge_mask_bchw(reference, sigma_pre=float(edge_sigma), gamma=float(edge_gamma))
        edges = edges.expand_as(high_freq)
        weight = (1.0 - gate_strength) + gate_strength * edges

        result = low_freq + float(detail_amount) * high_freq * weight
        if clamp_output:
            result = result.clamp(0.0, 1.0)
        return (_to_bhwc(result),)
# ComfyUI custom-node registration tables: node id -> implementing class, and
# node id -> label shown in the UI (presumably collected by ComfyUI's loader from
# this module or its package __init__ — confirm against the package layout).
NODE_CLASS_MAPPINGS = dict(FocusFusionTwoImages=FocusFusionTwoImages)
NODE_DISPLAY_NAME_MAPPINGS = dict(FocusFusionTwoImages="Focus Mix")
Advertisement
Add Comment
Please, Sign In to add comment