Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def warm_start_model(checkpoint_path, model, ignore_layers):
    """Initialise *model* from the weights stored in a checkpoint file.

    The checkpoint is loaded onto the CPU and its ``'state_dict'`` entry is
    copied into ``model`` via ``model.load_state_dict``.

    Args:
        checkpoint_path: Path to a checkpoint file readable by ``torch.load``
            whose top-level object has a ``'state_dict'`` key.
        model: Object exposing ``state_dict()`` and ``load_state_dict()``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        ignore_layers: Collection of state-dict key names that must NOT be
            taken from the checkpoint. Those entries keep the values
            ``model`` already holds. Empty or ``None`` loads everything.

    Returns:
        The same ``model`` object, after its weights have been updated.

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
    """
    assert os.path.isfile(checkpoint_path), (
        "Checkpoint file not found: '{}'".format(checkpoint_path))
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))

    # NOTE(security): torch.load unpickles the file; only use checkpoints
    # from a trusted source.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict = checkpoint_dict['state_dict']

    if ignore_layers:
        skipped = set(ignore_layers)
        # Start from the model's current weights so that every ignored key is
        # still present (with its existing value) when load_state_dict runs.
        merged_dict = model.state_dict()
        merged_dict.update(
            {key: value for key, value in model_dict.items()
             if key not in skipped})
        model_dict = merged_dict

    model.load_state_dict(model_dict)
    return model
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement