Advertisement
Guest User

Untitled

a guest
Jul 19th, 2019
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.46 KB | None | 0 0
  1. def set_parameter_requires_grad(model, featureExtracting):
  2. if featureExtracting:
  3. for param in model.parameters():
  4. param.requires_grad = False
  5.  
  6. def VGG16_pretrained_model(numClasses, featureExtract=True, usePretrained=True):
  7. model = models.vgg16(pretrained=True)
  8. set_parameter_requires_grad(model, featureExtract)
  9. numFtrs = model.classifier[6].in_features
  10. model.classifier[6] = nn.Linear(numFtrs, numClasses)
  11. return model
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement