Norod78

p2s2p_torchscript_inference.py

Nov 10th, 2021 (edited)
809
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.14 KB | None | 0 0
  1. import torch
  2. import PIL
  3. from PIL import Image, ImageOps
  4. import numpy as np
  5.  
  6. def tensor2im(var):
  7.     var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
  8.     var = ((var + 1) / 2)
  9.     var[var < 0] = 0
  10.     var[var > 1] = 1
  11.     var = var * 255
  12.     return Image.fromarray(var.astype('uint8'))
  13.  
  14. def load_image_as_array(path):
  15.     image_in = path
  16.     im = Image.open(image_in)
  17.     try:
  18.         im = ImageOps.exif_transpose(im)
  19.     except:
  20.         print("exif problem, not rotating")
  21.         im = im.convert("RGB")
  22.  
  23.     im = im.resize((256, 256))
  24.     im_array = np.array(im, np.float32)
  25.     im_array = (im_array/255)*2 - 1
  26.     im_array = np.transpose(im_array, (2, 0, 1))
  27.     im_array = np.expand_dims(im_array, 0)
  28.  
  29.     return im_array
  30.  
  31. im_array = load_image_as_array('test_data/face-ok.jpg')
  32. tensor_in = torch.Tensor(im_array)
  33.  
  34. test_image = tensor2im(tensor_in[0])
  35. test_image.show()
  36.  
  37. net = torch.jit.load('p2s2p_torchscript.pt')
  38. net.eval()
  39. result = net(tensor_in)
  40.  
  41. #traced_model = torch.jit.trace(net, tensor_in)
  42. #result = traced_model(tensor_in)
  43.  
  44. output_image = tensor2im(result[0])
  45. output_image.save('face-toon.jpg')
  46. output_image.show()
Add Comment
Please, Sign In to add comment