Advertisement
Guest User

Untitled

a guest
Mar 26th, 2019
66
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.70 KB | None | 0 0
  1. arams = []
  2.  
  3. # callback for storing the params of the model after each epoch
  4. class SaveModelParams(Callback):
  5. def __init__(self, model):
  6. self.model = model
  7.  
  8. def on_epoch_end(self, metrics):
  9. params.append([p.data.cpu().numpy() for p in self.model.parameters()])
  10.  
  11. # basic setup and training of the model
  12. net2 = SimpleNet([32*32*3, 40, 10])
  13. learn2 = ConvLearner.from_model_data(net2, data)
  14. lr = 2e-2
  15. learn2.fit(lr, 3, use_swa=True, callbacks=[SaveModelParams(learn2.model)])
  16.  
  17. # grab the params from the SWA model
  18. swa_model_params = [p.data.cpu().numpy() for p in learn2.swa_model.parameters()]
  19.  
  20. for p_model1, p_model2, p_model3, p_swa_model in zip(*params, swa_model_params):
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement