Advertisement
Not a member of Pastebin yet?
Sign Up — it unlocks many cool features!
- from PIL import Image
- import torchvision.transforms as T
- import torch
- img = Image.open("test.png").convert("RGB").resize((128,128))
- img = T.ToTensor()(img).unsqueeze(0).to("cuda") * 2 - 1
- img = img.permute(0,2,3,1).reshape(1,-1,3)
- steps = 5
- step_size = 0.95
- do_attention = True
- with torch.no_grad():
- for i in range(steps):
- q = k = torch.nn.functional.normalize(img, dim=-1)
- output = torch.nn.functional.scaled_dot_product_attention(q, k, img, dropout_p=0.0, is_causal=False, scale=100.0)
- img = img * (1-step_size) + output * step_size
- img = img.clamp(-1, 1)
- img = img.reshape(1,128,128,3).permute(0,3,1,2)
- img = (img + 1) / 2
- img = T.ToPILImage()(img[0])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement