Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import sys
- import shutil
- import tarfile
- import argparse
- import boto3
- import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
# Model size tag; also used to name the tarball artifact.
SAVE_NAME = "6b"
# Hugging Face Hub repo ID of the Pygmalion checkpoint to convert.
MODEL_NAME = f"PygmalionAI/pygmalion-{SAVE_NAME}"


def compress(tar_dir=None, output_file=f"model_{SAVE_NAME}.tar.gz"):
    """Pack the contents of *tar_dir* into a gzipped tarball at *output_file*.

    Directory contents are rooted at the top of the archive
    (``arcname=os.path.sep``), the layout SageMaker expects for
    ``model.tar.gz``.

    Raises:
        ValueError: if *tar_dir* is not given (the original default of
            ``None`` crashed inside tarfile with an opaque TypeError).
    """
    if tar_dir is None:
        raise ValueError("tar_dir must be an existing directory")
    with tarfile.open(output_file, "w:gz") as tar:
        tar.add(tar_dir, arcname=os.path.sep)
def upload_file_to_s3(bucket_name=None, file_name="model.tar.gz", key_prefix=""):
    """Upload *file_name* to ``s3://bucket_name/key_prefix/file_name``.

    Returns the resulting ``s3://`` URI string.

    Bug fix: the key was previously built with ``os.path.join``, which
    emits backslashes on Windows — S3 object keys always use ``/``.
    The key is now joined explicitly with a forward slash.
    """
    s3 = boto3.resource("s3")
    # Empty prefix must not produce a leading "/" in the key.
    key = f"{key_prefix}/{file_name}" if key_prefix else file_name
    s3.Bucket(bucket_name).upload_file(file_name, key)
    return f"s3://{bucket_name}/{key}"
def convert(bucket_name="hf-sagemaker-inference"):
    """Export the Pygmalion model into a SageMaker-ready layout under ``./tmp``.

    Loads the checkpoint from the Hugging Face Hub, saves an eval-mode fp16
    snapshot with ``torch.save`` plus the tokenizer files, and copies the
    ``code/inference.py`` handler next to them.

    Args:
        bucket_name: target S3 bucket. Currently unused because the
            archive/upload steps are performed manually (see note below).

    Returns:
        The model URI string — ``"DONE"`` while the upload stays manual.
    """
    model_save_dir = "./tmp"
    key_prefix = "pyg"  # S3 prefix used by the (currently manual) upload
    src_inference_script = os.path.join("code", "inference.py")
    dst_inference_script = os.path.join(model_save_dir, "code")
    os.makedirs(model_save_dir, exist_ok=True)
    os.makedirs(dst_inference_script, exist_ok=True)

    # Load the checkpoint and persist an eval-mode fp16 copy.
    print(f"Loading model from {MODEL_NAME}")
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    print("saving model with `torch.save`")
    # Note: f-string removed — the filename contains no placeholder.
    torch.save(model.eval().half(), os.path.join(model_save_dir, "pygmalion.pt"))

    print("saving tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.save_pretrained(model_save_dir)

    # Copy the SageMaker entry point alongside the model artifacts.
    print("copying inference.py script")
    shutil.copy(src_inference_script, dst_inference_script)

    # Archiving and uploading to S3 are done manually for now; when
    # re-enabling the automated path, call:
    #   compress(model_save_dir)
    #   model_uri = upload_file_to_s3(bucket_name=bucket_name, key_prefix=key_prefix)
    model_uri = "DONE"

    sys.stdout.write(model_uri)
    return model_uri
def parse_args():
    """Parse the command-line options for this conversion script."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--bucket_name", type=str, default=None)
    parsed = arg_parser.parse_args()
    return parsed
if __name__ == "__main__":
    # Bug fix: the original called convert(BUCKET_NAME), but BUCKET_NAME was
    # never defined anywhere — running the script raised NameError. Take the
    # bucket from the CLI via parse_args() instead.
    args = parse_args()
    if args.bucket_name is not None:
        convert(bucket_name=args.bucket_name)
    else:
        convert()  # fall back to convert()'s default bucket
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement