Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- diff --git a/main.py b/main.py
- index 84607cb..91c9066 100644
- --- a/main.py
- +++ b/main.py
- @@ -486,6 +486,13 @@ if __name__ == '__main__':
- # Forward pass W
- decoder.train()
- normalize_kernels(decoder, args.filters_norm) # Project each kernel in module to the sphere
- + if batch_id == 0:
- + import torch.nn.functional as F
- + import torchvision
- +
- + w = decoder.backbone[0][0].weight.permute(1, 0, 2, 3)
- + w = F.interpolate(w, scale_factor=8)
- + torchvision.utils.save_image(w, f'foo/w_{epoch:02d}.png', nrow=4, normalize=True)
- rec = decoder(Zs)
- rec_loss = MSE(X, rec)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement