Guest User

replace

a guest
Jun 29th, 2021
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.62 KB | None | 0 0
  1. import argparse
  2. import cv2
  3. import glob
  4. import numpy as np
  5. import os
  6. import torch
  7. from facexlib.utils.face_restoration_helper import FaceRestoreHelper
  8. from torchvision.transforms.functional import normalize
  9.  
  10. from archs.gfpganv1_arch import GFPGANv1
  11. from basicsr.utils import img2tensor, imwrite, tensor2img
  12.  
  13.  
  14. def restoration(gfpgan,
  15. face_helper,
  16. img_path,
  17. save_root,
  18. has_aligned=False,
  19. only_center_face=True,
  20. suffix=None,
  21. paste_back=False):
  22. # read image
  23. img_name = os.path.basename(img_path)
  24. print(f'Processing {img_name} ...')
  25. basename, _ = os.path.splitext(img_name)
  26. input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
  27. face_helper.clean_all()
  28.  
  29. if has_aligned:
  30. input_img = cv2.resize(input_img, (512, 512))
  31. face_helper.cropped_faces = [input_img]
  32. else:
  33. input_img = cv2.resize(input_img, (512, 512))
  34. face_helper.cropped_faces = [input_img]
  35.  
  36. # face restoration
  37. for idx, cropped_face in enumerate(face_helper.cropped_faces):
  38. # prepare data
  39. cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
  40. normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
  41. cropped_face_t = cropped_face_t.unsqueeze(0).to('cuda')
  42.  
  43. try:
  44. with torch.no_grad():
  45. output = gfpgan(cropped_face_t, return_rgb=False)[0]
  46. # convert to image
  47. restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
  48. except RuntimeError as error:
  49. print(f'\tFailed inference for GFPGAN: {error}.')
  50. restored_face = cropped_face
  51.  
  52. restored_face = restored_face.astype('uint8')
  53. face_helper.add_restored_face(restored_face)
  54.  
  55. if suffix is not None:
  56. save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
  57. else:
  58. save_face_name = f'{basename}_{idx:02d}.png'
  59. save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
  60. imwrite(restored_face, save_restore_path)
  61.  
  62. # save cmp image
  63. cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
  64. imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))
  65.  
  66. if not has_aligned and paste_back:
  67. face_helper.get_inverse_affine(None)
  68. save_restore_path = os.path.join(save_root, 'restored_imgs', img_name)
  69. # paste each restored face to the input image
  70. face_helper.paste_faces_to_input_image(save_restore_path)
  71.  
  72.  
  73. if __name__ == '__main__':
  74. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  75. parser = argparse.ArgumentParser()
  76.  
  77. parser.add_argument('--upscale_factor', type=int, default=1)
  78. parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANv1.pth')
  79. parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
  80. parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
  81. parser.add_argument('--only_center_face', action='store_true')
  82. parser.add_argument('--aligned', action='store_true')
  83. parser.add_argument('--paste_back', action='store_true')
  84.  
  85. args = parser.parse_args()
  86. if args.test_path.endswith('/'):
  87. args.test_path = args.test_path[:-1]
  88. save_root = 'results/'
  89. os.makedirs(save_root, exist_ok=True)
  90.  
  91. # initialize the GFP-GAN
  92. gfpgan = GFPGANv1(
  93. out_size=512,
  94. num_style_feat=512,
  95. channel_multiplier=1,
  96. decoder_load_path=None,
  97. fix_decoder=True,
  98. # for stylegan decoder
  99. num_mlp=8,
  100. input_is_latent=True,
  101. different_w=True,
  102. narrow=1,
  103. sft_half=True)
  104.  
  105. gfpgan.to(device)
  106. checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage)
  107. gfpgan.load_state_dict(checkpoint['params_ema'])
  108. gfpgan.eval()
  109.  
  110. # initialize face helper
  111. face_helper = FaceRestoreHelper(
  112. upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')
  113.  
  114. img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
  115. for img_path in img_list:
  116. restoration(
  117. gfpgan,
  118. face_helper,
  119. img_path,
  120. save_root,
  121. has_aligned=args.aligned,
  122. only_center_face=args.only_center_face,
  123. suffix=args.suffix,
  124. paste_back=args.paste_back)
  125.  
  126. print('Results are in the <results> folder.')
  127.  
Advertisement
Add Comment
Please, Sign In to add comment