Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Checkpoint pruning script: loads a PyTorch checkpoint, drops the
# "optimizer_states" entry, and writes the remainder to a new file.
import argparse
import os
from pathlib import Path

import torch

# File name the original script always wrote to; kept as the default.
DEFAULT_OUTPUT = "model-pre-pruned.ckpt"


def prune_checkpoint(sd):
    """Return a copy of checkpoint mapping *sd* without ``"optimizer_states"``.

    Args:
        sd: The mapping returned by ``torch.load`` for a checkpoint file.

    Returns:
        A new dict. If *sd* already has a top-level ``"state_dict"`` key,
        every top-level entry except ``"optimizer_states"`` is copied as-is.
        Otherwise *sd* is presumably a bare state dict (TODO confirm against
        the checkpoints this is used on), so the remaining entries are
        nested under a new ``"state_dict"`` key.
    """
    kept = {key: value for key, value in sd.items() if key != "optimizer_states"}
    # The layout decision depends only on the input, so make it once
    # instead of re-checking it for every key.
    if "state_dict" in sd:
        return kept
    return {"state_dict": kept}


def main(argv=None):
    """Parse command-line arguments, prune the checkpoint, and save it.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', '-I', type=str, help='Input file to prune', required=True)
    parser.add_argument(
        '--output', '-O', type=str, default=DEFAULT_OUTPUT,
        help='Where to write the pruned checkpoint (default: %(default)s)',
    )
    args = parser.parse_args(argv)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this on checkpoints from a trusted source.
    sd = torch.load(args.input, map_location="cpu")
    torch.save(prune_checkpoint(sd), args.output)


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment