daily pastebin goal
10%
SHARE
TWEET

Untitled

a guest Feb 13th, 2018 56 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import torchvision
  3. import torchvision.transforms as transforms
  4. import torch.nn as nn
  5. from torch.autograd import Variable
  6. from torch.utils.data.sampler import SubsetRandomSampler
  7. import numpy as np
  8. import matplotlib.pyplot as plt
  9. %matplotlib inline
  10.  
  11. from torchvision.models import vgg16_bn # bn = batch normalization
  12. import copy
  13.  
  14. vgg_base = vgg16_bn(pretrained=True)
  15. # freeze old layers
  16. for param in vgg_base.parameters():
  17.     param.requires_grad = False
  18.  
  19. #print(vgg_base)
  20. for m in vgg_base.modules():
  21.     if isinstance(m, nn.Conv2d):
  22.         first_weights = m.weight.data
  23.         #print(first_weights)
  24.         break
  25.  
  26. fig = plt.figure()
  27. plt.figure(figsize=(10,10))
  28. for index, filter in enumerate(first_weights):
  29.     print(type(filter))
  30.     #print(type(np.abs(filter)))
  31.     plt.subplot(8, 8, index+1)
  32.     plt.imshow(filter[:,:,:].numpy())
  33.     plt.axis('off')
  34.  
  35. fig.show()
RAW Paste Data
Top