Advertisement
Guest User

Untitled

a guest
Apr 27th, 2023
126
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.33 KB | None | 0 0
  1. import os
  2. import sys
  3. import shutil
  4. import tarfile
  5. import argparse
  6. import boto3
  7. import torch
  8. from transformers import AutoTokenizer, AutoModelForCausalLM
  9.  
  10.  
# Short model-size tag; reused in the HF repo name and output filenames.
SAVE_NAME = "6b"
# Hugging Face Hub repo to convert, e.g. "PygmalionAI/pygmalion-6b".
MODEL_NAME = f"PygmalionAI/pygmalion-{SAVE_NAME}"
  13.  
  14.  
  15. def compress(tar_dir=None, output_file=f"model_{SAVE_NAME}.tar.gz"):
  16.     with tarfile.open(output_file, "w:gz") as tar:
  17.         tar.add(tar_dir, arcname=os.path.sep)
  18.  
  19.  
  20. def upload_file_to_s3(bucket_name=None, file_name="model.tar.gz", key_prefix=""):
  21.     s3 = boto3.resource("s3")
  22.     key_prefix_with_file_name = os.path.join(key_prefix, file_name)
  23.     s3.Bucket(bucket_name).upload_file(file_name, key_prefix_with_file_name)
  24.     return f"s3://{bucket_name}/{key_prefix_with_file_name}"
  25.  
  26.  
  27. def convert(bucket_name="hf-sagemaker-inference"):
  28.     model_save_dir = "./tmp"
  29.     key_prefix = "pyg"
  30.     src_inference_script = os.path.join("code", "inference.py")
  31.     dst_inference_script = os.path.join(model_save_dir, "code")
  32.  
  33.     os.makedirs(model_save_dir, exist_ok=True)
  34.     os.makedirs(dst_inference_script, exist_ok=True)
  35.  
  36.  
  37.     # load fp 16 model
  38.     print(f"Loading model from {MODEL_NAME}")
  39.     model = AutoModelForCausalLM.from_pretrained(
  40.         MODEL_NAME)
  41.     print("saving model with `torch.save`")
  42.     torch.save(model.eval().half(), os.path.join(model_save_dir, f"pygmalion.pt"))
  43.  
  44.     print("saving tokenizer")
  45.     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
  46.     tokenizer.save_pretrained(model_save_dir)
  47.  
  48.     # copy inference script
  49.     print("copying inference.py script")
  50.     shutil.copy(src_inference_script, dst_inference_script)
  51.  
  52.     # create archive
  53.     # print("creating `model.tar.gz` archive")
  54.     # compress(model_save_dir)
  55.  
  56.  
  57.     model_uri = "DONE"
  58.     # Manually uploading to S3
  59.     # upload to s3
  60.     # print(
  61.     #     f"uploading `model.tar.gz` archive to s3://{bucket_name}/{key_prefix}/model.tar.gz"
  62.     # )
  63.     # model_uri = upload_file_to_s3(bucket_name=bucket_name, key_prefix=key_prefix)
  64.     # print(f"Successfully uploaded to {model_uri}")
  65.    
  66.    
  67.     sys.stdout.write(model_uri)
  68.     return model_uri
  69.  
  70.  
  71. def parse_args():
  72.     parser = argparse.ArgumentParser()
  73.     parser.add_argument("--bucket_name", type=str, default=None)
  74.     return parser.parse_args()
  75.  
  76.  
  77. if __name__ == "__main__":
  78.     # read config file
  79.     convert(BUCKET_NAME)
  80.    
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement