Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
'''
Test images with an alpha channel (PNG).
Directly reuses the pretrained network, e.g. RRDB_ESRGAN_x4.
'''
import glob
import os
import os.path
import sys

import cv2
import numpy as np
import torch

import architecture as arch
# Checkpoint to load, e.g. models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth.
model_path = 'models/interp_500.pth'

# Prefer the GPU, but fall back to the CPU so the script still runs on machines
# without CUDA. To force the CPU, use: device = torch.device('cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Glob pattern selecting the low-resolution input images.
test_img_folder = 'LR/*'

model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
                      mode='CNA', res_scale=1, upsample_mode='upconv')
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a CUDA model can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
model.eval()
model = model.to(device)

print('Model path {:s}. \nTesting...'.format(model_path))
# Every side of the image is enlarged by this factor; it has to match the
# `upscale` value the network was built with.
upscale = 4

# cv2.imwrite does not create missing directories, so make sure the output
# directory exists before the first image is written.
os.makedirs('results', exist_ok=True)

# The structuring elements do not depend on the image, so build them once.
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
kernel_2 = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))

# sorted() makes the processing order (and the printed index) deterministic.
for idx, path in enumerate(sorted(glob.glob(test_img_folder)), start=1):
    base = os.path.splitext(os.path.basename(path))[0]
    print(idx, base)

    # read image, keeping the alpha channel and the stored bit depth
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals an undecodable file by returning None, not by raising.
        print('Skipping {:s}: not a readable image.'.format(path))
        continue
    # Greyscale images are 2-D, so check the rank before looking at the channels.
    if img.ndim != 3 or img.shape[2] != 4:
        raise ValueError('You are not processing images with an alpha channel. Try use test.py')
    H, W = img.shape[:2]

    # Normalise to [0, 1] using the full range of the stored integer type, so
    # 16-bit PNGs are scaled correctly as well (for 8-bit images this is / 255).
    if np.issubdtype(img.dtype, np.integer):
        img = img.astype(np.float64) / np.iinfo(img.dtype).max
    else:
        img = img.astype(np.float64)
    img_bgr = img[:, :, 0:3]
    img_alpha = np.ascontiguousarray(img[:, :, 3])

    #########################
    # process RGB channels
    #########################
    # OpenCV stores BGR; reverse the channel order and go HWC -> CHW for the network.
    rgb_chw = np.ascontiguousarray(np.transpose(img_bgr[:, :, [2, 1, 0]], (2, 0, 1)))
    img_LR = torch.from_numpy(rgb_chw).float().unsqueeze(0).to(device)
    with torch.no_grad():
        output_rgb = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()
    # back to HWC with OpenCV's BGR channel order
    output_rgb = np.transpose(output_rgb[[2, 1, 0], :, :], (1, 2, 0))
    output_rgb = (output_rgb * 255.0).round()

    #########################
    # process alpha channel
    #########################
    # directly upsampling the alpha channel - it is too blurry on its own, but
    # it serves as the reference for the clamping step below
    img_alpha_cubic = cv2.resize(img_alpha, (W * upscale, H * upscale),
                                 interpolation=cv2.INTER_CUBIC)
    img_alpha_cubic = np.clip(img_alpha_cubic, 0, 1)
    img_alpha_cubic = (img_alpha_cubic * 255.0).round()

    # use the ESRGAN model to enlarge the alpha channel - with artifacts!
    # The alpha plane is replicated three times so it can be fed as a colour image.
    alpha_plane = torch.from_numpy(img_alpha).float()
    alpha_input = torch.stack([alpha_plane] * 3, dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        output_alpha = model(alpha_input).squeeze(0).float().cpu().clamp_(0, 1).numpy()
    # turn the three output channels back into a single gray image
    output_alpha = (0.298936 * output_alpha[0, :, :]
                    + 0.587043 * output_alpha[1, :, :]
                    + 0.114021 * output_alpha[2, :, :])
    output_alpha *= 255

    # use erosion and dilation of the bicubic result as lower/upper bounds to
    # obtain a sharp alpha channel without many artifacts
    erosion = cv2.erode(img_alpha_cubic, kernel, iterations=1)
    dilation = cv2.dilate(img_alpha_cubic, kernel, iterations=1)
    output_alpha = np.minimum(np.maximum(output_alpha, erosion), dilation)
    output_alpha = cv2.erode(output_alpha, kernel_2, iterations=1)
    # For debugging, img_alpha_cubic, erosion, dilation and output_alpha can
    # each be dumped with cv2.imwrite at this point.

    #########################
    # merge
    #########################
    output = np.dstack((output_rgb, output_alpha))
    # Convert explicitly to 8-bit instead of relying on imwrite's implicit
    # handling of a mixed float32/float64 array.
    output = np.clip(output, 0, 255).round().astype(np.uint8)
    out_path = 'results/{:s}_rlt_ESRGAN_500.png'.format(base)
    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(out_path, output):
        raise IOError('Failed to write {:s}'.format(out_path))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement