CookiePPP

Untitled

Dec 22nd, 2019
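import os
import torch

def warm_start_model(checkpoint_path, model, ignore_layers):
    assert os.path.isfile(checkpoint_path), checkpoint_path
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict = checkpoint_dict['state_dict']
    # Skip the keys listed in ignore_layers; those keep the model's current weights
    model_dict = {k: v for k, v in model_dict.items() if k not in ignore_layers}
    full_dict = model.state_dict()
    full_dict.update(model_dict)
    model.load_state_dict(full_dict)
    return model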
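The ignore_layers argument implies that some checkpoint entries should not be copied into the new model. Below is a minimal sketch of that filter-and-merge step. It uses plain dicts as hypothetical stand-ins for PyTorch state dicts so the logic can be run without a checkpoint. The helper name merge_state_dicts and the example keys are illustrative assumptions, not part of the original paste.

```python
def merge_state_dicts(model_state, checkpoint_state, ignore_layers):
    """Return a new dict holding every checkpoint entry except the keys in
    ignore_layers, which keep their values from model_state (hypothetical helper)."""
    ignored = set(ignore_layers)
    merged = dict(model_state)  # start from a copy of the model's current parameters
    for key, value in checkpoint_state.items():
        if key not in ignored:
            merged[key] = value  # overwrite with the pretrained value
    return merged


# Hypothetical example: the embedding is skipped, so the model keeps its own
# value for it, while the decoder weight is taken from the checkpoint.
model_state = {"embedding.weight": "fresh", "decoder.weight": "random"}
checkpoint_state = {"embedding.weight": "old", "decoder.weight": "pretrained"}
merged = merge_state_dicts(model_state, checkpoint_state, ["embedding.weight"])
print(merged)  # {'embedding.weight': 'fresh', 'decoder.weight': 'pretrained'}
```

With real tensors, the merged dict would then be passed to model.load_state_dict. The entries in ignore_layers must match state-dict keys exactly, because the filter compares whole key strings rather than prefixes.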