Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Weight snapshots collected during training: SaveModelParams appends one
# list of numpy arrays (one per model parameter tensor) at every epoch end,
# and the comparison loop below zips these snapshots with the SWA weights.
params = []
- # callback for storing the params of the model after each epoch
- class SaveModelParams(Callback):
- def __init__(self, model):
- self.model = model
- def on_epoch_end(self, metrics):
- params.append([p.data.cpu().numpy() for p in self.model.parameters()])
- # basic setup and training of the model
- net2 = SimpleNet([32*32*3, 40, 10])
- learn2 = ConvLearner.from_model_data(net2, data)
- lr = 2e-2
- learn2.fit(lr, 3, use_swa=True, callbacks=[SaveModelParams(learn2.model)])
- # grab the params from the SWA model
- swa_model_params = [p.data.cpu().numpy() for p in learn2.swa_model.parameters()]
- for p_model1, p_model2, p_model3, p_swa_model in zip(*params, swa_model_params):
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement