Advertisement
Guest User

test_with_alpha_channel

a guest
Jan 21st, 2019
1,438
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.75 KB | None | 0 0
  1. '''
  2. test images with an alpha channel (png)
  3. directly reuse the pretrained network - e.g., RRDB_ESRGAN_x4
  4. '''
  5.  
  6. import sys
  7. import os.path
  8. import glob
  9. import cv2
  10. import numpy as np
  11. import torch
  12. import architecture as arch
  13.  
  14. model_path = 'models/interp_500.pth'  # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
  15. device = torch.device('cuda')  # if you want to run on CPU, change 'cuda' -> cpu
  16. # device = torch.device('cpu')
  17.  
  18. test_img_folder = 'LR/*'
  19.  
  20. model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
  21.                         mode='CNA', res_scale=1, upsample_mode='upconv')
  22. model.load_state_dict(torch.load(model_path), strict=True)
  23. model.eval()
  24. model = model.to(device)
  25.  
  26. print('Model path {:s}. \nTesting...'.format(model_path))
  27.  
  28. idx = 0
  29. for path in glob.glob(test_img_folder):
  30.     idx += 1
  31.     base = os.path.splitext(os.path.basename(path))[0]
  32.     print(idx, base)
  33.     # read image
  34.     img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
  35.     H, W, C = img.shape
  36.     if C != 4:
  37.         raise ValueError('You are not processing images with an alpha channel. Try use test.py')
  38.     img = img * 1.0 / 255
  39.  
  40.     img_rgb = img[:, :, 0:3]
  41.     img_alpha = img[:, :, 3]
  42.  
  43.     #########################
  44.     # process RGB channels
  45.     #########################
  46.     img_rgb = torch.from_numpy(np.transpose(img_rgb[:, :, [2, 1, 0]], (2, 0, 1))).float()
  47.     img_LR = img_rgb.unsqueeze(0)
  48.     img_LR = img_LR.to(device)
  49.     with torch.no_grad():
  50.         output_rgb = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
  51.     output_rgb = np.transpose(output_rgb[[2, 1, 0], :, :], (1, 2, 0))
  52.     output_rgb = (output_rgb * 255.0).round()
  53.  
  54.     #########################
  55.     # process alpha channel
  56.     #########################
  57.     # directly upsampling the alpha channel - it is too blur
  58.     img_alpha_cubic = cv2.resize(img_alpha, (W * 4, H * 4), interpolation=cv2.INTER_CUBIC)
  59.     img_alpha_cubic = np.clip(img_alpha_cubic, 0, 1)
  60.     img_alpha_cubic = (img_alpha_cubic * 255.0).round()
  61.  
  62.     # use the ESRGAN model to enlarge the alpha channel - with artifacts!
  63.     img_alpha = torch.from_numpy(img_alpha).float()
  64.     H, W = img_alpha.size()
  65.     # treat as colorful image
  66.     img_alpha = torch.stack([img_alpha] * 3, dim=0).unsqueeze(0)
  67.     with torch.no_grad():
  68.         output_alpha = model(img_alpha.to(device)).data.squeeze().float().cpu().clamp_(0, 1).numpy()
  69.     # turn into a gray image
  70.     output_alpha = 0.298936 * output_alpha[0, :, :] + 0.587043 * output_alpha[
  71.         1, :, :] + 0.114021 * output_alpha[2, :, :]
  72.     output_alpha *= 255
  73.  
  74.     # use erosion and dilation to obtain a sharp alpha channel without many artifacts
  75.     kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
  76.     kernel_2 = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
  77.     erosion = cv2.erode(img_alpha_cubic, kernel, iterations=1)
  78.     dilation = cv2.dilate(img_alpha_cubic, kernel, iterations=1)
  79.     # cv2.imwrite('results/{:s}_rlt_alpha_ori.png'.format(base), output_alpha)
  80.     output_alpha = np.minimum(np.maximum(output_alpha, erosion), dilation)
  81.     output_alpha = cv2.erode(output_alpha, kernel_2, iterations=1)
  82.     # cv2.imwrite('results/{:s}_rlt_alpha_bicubic.png'.format(base), img_alpha_cubic)
  83.     # cv2.imwrite('results/{:s}_rlt_alpha_erosion.png'.format(base), erosion)
  84.     # cv2.imwrite('results/{:s}_rlt_alpha_dilation.png'.format(base), dilation)
  85.     # cv2.imwrite('results/{:s}_rlt_alpha_ESRGAN_900.png'.format(base), output_alpha)
  86.  
  87.     #########################
  88.     # merge
  89.     #########################
  90.     output = np.stack((output_rgb[:, :, 0], output_rgb[:, :, 1], output_rgb[:, :, 2], output_alpha),
  91.                       axis=2)
  92.     cv2.imwrite('results/{:s}_rlt_ESRGAN_500.png'.format(base), output)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement