Advertisement
Guest User

Untitled

a guest
Feb 22nd, 2019
100
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.33 KB | None | 0 0
  1. import sys
  2. import os.path
  3. import glob
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import architecture as arch
  8.  
  9. model_path = sys.argv[1]  # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
  10. input_path = sys.argv [2]
  11. #output_path = sys.argv [3]
  12.  
  13. test_img_folder = input_path+'/*'
  14.  
  15. print (test_img_folder)
  16. print (input_path)
  17.  
  18. model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
  19.                         mode='CNA', res_scale=1, upsample_mode='upconv')
  20. model.load_state_dict(torch.load(model_path), strict=True)
  21. model.eval()
  22. for k, v in model.named_parameters():
  23.     v.requires_grad = False
  24. model = model.cuda()
  25.  
  26. print('Model path {:s}. \nTesting...'.format(model_path))
  27.  
  28. idx = 0
  29. for path in glob.glob(test_img_folder):
  30.     idx += 1
  31.     base = os.path.splitext(os.path.basename(path))[0]
  32.     print(idx, base)
  33.     # read image
  34.     img = cv2.imread(path, cv2.IMREAD_COLOR)
  35.     img = img * 1.0 / 255
  36.     img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
  37.     img_LR = img.unsqueeze(0)
  38.     img_LR = img_LR.cuda()
  39.  
  40.     output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
  41.     output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
  42.     output = (output * 255.0).round()
  43.     cv2.imwrite('results/{:s}srgan4.png'.format(base), output)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement