Guest User

Untitled

a guest
Nov 17th, 2017
106
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.04 KB | None | 0 0
  1. import torch
  2. import torchvision.transforms as transforms
  3. import torchvision.utils as vutils
  4. from PIL import Image
  5. import matplotlib.pyplot as plt
  6.  
  7.  
  8. def windowed_tensor(tensor, size, stride):
  9. """assumes a square 3D tensor"""
  10. tensor = tensor.unfold(1, size, stride)
  11. tensor = tensor.unfold(2, size, stride)
  12. tensor = torch.transpose(tensor, 0, 1)
  13. tensor = torch.transpose(tensor, 1, 2)
  14. tensor = tensor.contiguous()
  15. tensor = tensor.view(-1, 3, size, size)
  16. return tensor
  17.  
  18.  
  19.  
  20.  
  21. if __name__ == '__main__':
  22.  
  23. image_size = 512
  24. loader = transforms.Compose([transforms.Scale(image_size), transforms.ToTensor()])
  25. to_image = transforms.ToPILImage()
  26.  
  27. image = Image.open('some image path')
  28. image = loader(image)
  29.  
  30. tsize = 64
  31.  
  32. x = torch.range(1, tsize)
  33. y = x.view(8, 8)
  34.  
  35. image = windowed_tensor(image, 64, 32)
  36.  
  37. print(image.size())
  38.  
  39. grid = vutils.make_grid(image, 15, 0, normalize=True)
  40. plt.imshow(to_image(grid))
  41. plt.axis('off')
  42. plt.imsave('glitch.png', to_image(grid))
  43. plt.show()
Add Comment
Please, Sign In to add comment