# originally from https://gist.github.com/younesbelkada/382016361580b939a87edcddc94c6593
# I made a single change to fix it outputting an extra file and winding up with pytorch_model_00011-of-10.bin
import torch
import os
import json
import argparse

parser = argparse.ArgumentParser(description='Sharding Hugging Face models')
parser.add_argument('--sharding_factor', default=4, type=int, help='Sharding factor - aka how many shards to create')
parser.add_argument('--source_model_path', default="t5-v1_1-xl", type=str, help='Relative path to the source model folder')
parser.add_argument('--sharded_model_path', default="t5-v1_1-xl-sharded", type=str, help='Relative path to the target sharded model folder')
args = parser.parse_args()

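# Example invocation (the script file name here is hypothetical; the script
# expects <source_model_path>/pytorch_model.bin to exist relative to the cwd):
#
#   python shard_model.py --sharding_factor 4 \
#       --source_model_path t5-v1_1-xl \
#       --sharded_model_path t5-v1_1-xl-sharded
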
def get_memory_footprint_param(param):
    r"""
    Get the memory footprint of a parameter in bytes
    (number of elements times the size of each element).
    """
    return param.nelement() * param.element_size()


def get_index_json_file():
    r"""
    Get the default index.json dictionary. This has to contain
    the metadata of the model as well as the weight_map.
    """
    index_dict = {
        "metadata": {
            "total_size": 0
        },
        "weight_map": {}
    }
    return index_dict


def save_index_file(path_sharded, index_dict):
    r"""
    Save the index.json file.
    """
    with open(os.path.join(path_sharded, "pytorch_model.bin.index.json"), "w", encoding="utf-8") as f:
        json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
        f.write(json_config)

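# For reference, the index file written above follows exactly the structure
# built in get_index_json_file; the weight name and shard file below are
# illustrative placeholders, not taken from a real t5-v1_1-xl checkpoint:
#
# {
#   "metadata": {
#     "total_size": <sum of all parameter sizes in bytes>
#   },
#   "weight_map": {
#     "encoder.block.0.layer.0.SelfAttention.q.weight": "pytorch_model_00001-of-00004.bin",
#     ...
#   }
# }
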
if __name__ == "__main__":
    # Get the args
    ROOT_PATH = os.getcwd()
    source_model = args.source_model_path
    target_model = args.sharded_model_path

    path_model = os.path.join(ROOT_PATH, source_model, "pytorch_model.bin")
    path_sharded = os.path.join(ROOT_PATH, target_model)
    sharding_factor = args.sharding_factor

    # Make sure the target folder exists before writing the shards
    os.makedirs(path_sharded, exist_ok=True)

    # Initialize the variables
    index_dict = get_index_json_file()
    state_dict = torch.load(path_model)

    sharded_state_dict = {}
    total_keys = []

    current_file_name = f"pytorch_model_00001-of-{str(sharding_factor).zfill(5)}.bin"
    checking_step = len(state_dict.keys())//(sharding_factor-1)
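    # e.g. with 200 keys and sharding_factor=4, checking_step = 200 // 3 = 66,
    # so a shard is flushed after keys 66, 132 and 198, and the remaining
    # 2 keys end up in the 4th and final file.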

    # Loop over the params and shard them if necessary
    for i, key in enumerate(state_dict.keys()):
        # Get the current param
        param = state_dict[key]
        index_dict["metadata"]["total_size"] += get_memory_footprint_param(param)
        index_dict["weight_map"][key] = current_file_name
        sharded_state_dict[key] = param
        total_keys.append(key)
        # Check if we need to create a new file
        if (i+1) % checking_step == 0:
            torch.save(sharded_state_dict, os.path.join(path_sharded, current_file_name))
            sharded_state_dict = {}
            new_index = ((i+1)//checking_step) + 1
            current_file_name = f"pytorch_model_{str(new_index).zfill(5)}-of-{str(sharding_factor).zfill(5)}.bin"

    # Save the last sharded file if necessary
    if len(sharded_state_dict) > 0:
        torch.save(sharded_state_dict, os.path.join(path_sharded, current_file_name))

    # Last sanity check
    if total_keys != list(state_dict.keys()):
        raise ValueError("The keys in the index.json file are not the same as the keys in the model")

    save_index_file(path_sharded, index_dict)

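    # Optional follow-up check (a minimal sketch, not part of the original gist):
    # reload the index and the shards that were just written and confirm that the
    # union of the shard keys matches the weight_map exactly.
    with open(os.path.join(path_sharded, "pytorch_model.bin.index.json"), "r", encoding="utf-8") as f:
        written_index = json.load(f)
    reloaded_keys = set()
    for shard_file in set(written_index["weight_map"].values()):
        reloaded_keys.update(torch.load(os.path.join(path_sharded, shard_file)).keys())
    if reloaded_keys != set(written_index["weight_map"].keys()):
        raise ValueError("Reloaded shard keys do not match the weight_map in index.json")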