Untitled

a guest Feb 22nd, 2019 74 Never
import glob
import os
import sys

import cv2
import numpy as np
import torch

import architecture as arch

# Usage: python <this_script>.py <model_path> <input_folder>
model_path = sys.argv[1]  # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
input_path = sys.argv[2]
# output_path = sys.argv[3]

test_img_folder = os.path.join(input_path, '*')

print(test_img_folder)
print(input_path)

# Build the RRDB network for 4x upscaling and load the pretrained weights.
model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
                      mode='CNA', res_scale=1, upsample_mode='upconv')
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
for k, v in model.named_parameters():
    v.requires_grad = False
model = model.cuda()

print('Model path {:s}. \nTesting...'.format(model_path))

# cv2.imwrite does not create missing directories, so make sure the target exists.
os.makedirs('results', exist_ok=True)

idx = 0
for path in glob.glob(test_img_folder):
    idx += 1
    base = os.path.splitext(os.path.basename(path))[0]
    print(idx, base)
    # read image (OpenCV returns BGR, H x W x C); imread returns None on failure
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        print('Skipping unreadable file: {:s}'.format(path))
        continue
    img = img * 1.0 / 255
    # BGR -> RGB, then H x W x C -> C x H x W
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img_LR = img.unsqueeze(0)
    img_LR = img_LR.cuda()

    # run inference without building an autograd graph
    with torch.no_grad():
        output = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()
    # RGB -> BGR, then C x H x W -> H x W x C
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)
    cv2.imwrite('results/{:s}srgan4.png'.format(base), output)
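
The channel and layout shuffling in the loop above (BGR to RGB, H x W x C to C x H x W, and back again) is easy to get wrong, so here is a minimal CPU-only sketch of just that round trip. The helper names `bgr_hwc_to_rgb_chw` and `rgb_chw_to_bgr_hwc` are hypothetical and not part of the original script; a random array stands in for an image loaded with `cv2.imread`, and no model is involved.

```python
import numpy as np


def bgr_hwc_to_rgb_chw(img_bgr):
    """uint8 H x W x 3 BGR image -> float32 3 x H x W RGB array in [0, 1]."""
    img = img_bgr.astype(np.float32) / 255.0
    # reverse the channel axis (BGR -> RGB), then move channels first
    return np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))


def rgb_chw_to_bgr_hwc(arr):
    """float 3 x H x W RGB array -> uint8 H x W x 3 BGR image."""
    arr = np.clip(arr, 0.0, 1.0)
    # reverse the channel axis (RGB -> BGR), then move channels last
    out = np.transpose(arr[[2, 1, 0], :, :], (1, 2, 0))
    return (out * 255.0).round().astype(np.uint8)


# Random stand-in for an image (assumption: 4 x 5 pixels, 3 channels).
rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(4, 5, 3), dtype=np.uint8)

chw = bgr_hwc_to_rgb_chw(fake)
restored = rgb_chw_to_bgr_hwc(chw)
```

If the two conversions mirror each other, `restored` matches `fake` pixel for pixel, which is a quick sanity check to run before putting a network between them.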