Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- import torch
- import torchvision.transforms as transforms
- import torchvision.utils as vutils
- from PIL import Image
- import matplotlib.pyplot as plt
def windowed_tensor(tensor, size, stride):
    """Cut a (C, H, W) tensor into square sliding-window patches.

    Windows of ``size`` x ``size`` are taken every ``stride`` elements along
    both spatial dimensions. Trailing rows/columns that do not fill a whole
    window are dropped, which is how ``Tensor.unfold`` behaves. The spatial
    dimensions do not need to be equal, and any channel count is supported.

    Args:
        tensor: A 3-D tensor laid out as (channels, height, width).
        size: Side length of each square patch.
        stride: Step between consecutive patch origins.

    Returns:
        A contiguous tensor of shape (num_patches, channels, size, size).
        Patches are ordered row by row: left to right, then top to bottom.

    Raises:
        ValueError: If ``tensor`` is not 3-dimensional.
    """
    if tensor.dim() != 3:
        raise ValueError(
            "windowed_tensor expects a (C, H, W) tensor, got shape {}".format(
                tuple(tensor.shape)
            )
        )
    channels = tensor.size(0)
    # (C, H, W) -> (C, nH, W, size) -> (C, nH, nW, size, size)
    patches = tensor.unfold(1, size, stride).unfold(2, size, stride)
    # Put the patch-grid axes in front of the channel axis:
    # (C, nH, nW, size, size) -> (nH, nW, C, size, size)
    patches = patches.permute(1, 2, 0, 3, 4)
    # Flatten the grid into one patch axis. The channel count is read from the
    # input instead of being hard-coded to 3, so 1- or 4-channel inputs are
    # split into correct patches rather than scrambled.
    return patches.contiguous().view(-1, channels, size, size)
if __name__ == '__main__':
    import sys

    image_size = 512
    patch_size = 64
    patch_stride = 32
    # Optional command-line override; the placeholder keeps the original default.
    image_path = sys.argv[1] if len(sys.argv) > 1 else 'some image path'

    # transforms.Scale was deprecated and later removed from torchvision.
    # Resize(int) has the same meaning: match the shorter edge to image_size.
    loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    to_image = transforms.ToPILImage()

    # convert('RGB') guarantees three channels. Grayscale, palette or RGBA
    # files would otherwise produce 1- or 4-channel tensors. The context
    # manager closes the underlying file handle.
    with Image.open(image_path) as pil_image:
        image = loader(pil_image.convert('RGB'))

    # Resize keeps the aspect ratio, so the width is 512 only for square
    # inputs. Derive the number of patches per grid row from the actual width
    # (image is (C, H, W)) instead of hard-coding (512 - 64) // 32 + 1 == 15.
    columns = (image.size(2) - patch_size) // patch_stride + 1

    patches = windowed_tensor(image, patch_size, patch_stride)
    print(patches.size())

    grid = vutils.make_grid(patches, nrow=columns, padding=0, normalize=True)
    # Convert once and reuse for display and saving; PIL writes the PNG directly.
    grid_image = to_image(grid)
    plt.imshow(grid_image)
    plt.axis('off')
    grid_image.save('glitch.png')
    plt.show()
Add Comment
Please sign in to add a comment.