Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from io import StringIO
- import textwrap
# Seed transcript for the chat loop: a persona preamble, one scripted exchange,
# and a trailing open "AI:" turn for the model to complete. The buffer is
# created with its stream position at 0.
conversation = StringIO("\n".join([
    "The following is a conversation with an AI assistant. "
    "The assistant is helpful, creative, clever, and very friendly. "
    "The assistant's name is James. He was created by AC Clash.",
    "AI: I am an AI created by AC Clash. How can I help you today?",
    "Human: Hi, who are you?",
    "AI:",
]))
def ac_gpt(model_dir='ac-gpt', max_new_tokens=50):
    """Run an interactive console chat with the causal language model in ``model_dir``.

    The module-level ``conversation`` buffer holds the running transcript. The
    model first completes the seeded transcript. After that, each message typed
    at the prompt is appended as a ``Human:`` turn, and the model's reply is
    appended after the following ``AI:`` marker. The whole transcript (truncated
    from the left if it exceeds the model's context) is the prompt for each reply.
    The session ends on end-of-input or Ctrl-C at the prompt.

    Args:
        model_dir: Path (or Hub id) passed to ``from_pretrained`` for both the
            model and the tokenizer.
        max_new_tokens: Maximum number of tokens generated per reply. Unlike
            the previous ``max_length=80``, this does not count prompt tokens,
            so replies do not shrink as the transcript grows.
    """
    from io import SEEK_END

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load once up front; previously the weights were reloaded on every turn.
    model = AutoModelForCausalLM.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    # GPT-2 tokenizers define no pad token; fall back to EOS explicitly instead
    # of passing None and letting generate() substitute it with a warning.
    pad_token_id = (tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None
                    else tokenizer.eos_token_id)
    # Maximum sequence length supported by the position embeddings, if known.
    context_limit = getattr(model.config, "max_position_embeddings", None)

    def append(text):
        # StringIO writes at the current position, which starts at 0 for an
        # initial value; seek to the end so the transcript is never overwritten.
        conversation.seek(0, SEEK_END)
        conversation.write(text)

    def get_response(input_text):
        """Generate, print, and record the model's reply to ``input_text``."""
        encoded = tokenizer(input_text, return_tensors="pt")
        input_ids = encoded["input_ids"]
        attention_mask = encoded["attention_mask"]
        if context_limit is not None and context_limit > max_new_tokens:
            # Keep only the most recent tokens so prompt + reply fit the model.
            keep = context_limit - max_new_tokens
            input_ids = input_ids[:, -keep:]
            attention_mask = attention_mask[:, -keep:]
        generated_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            # No temperature: it only affects sampling, and do_sample is not
            # enabled here, so this is plain beam search.
            num_beams=5,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )
        # The output starts with the prompt tokens; decode only the continuation.
        new_ids = generated_ids[0, input_ids.shape[-1]:]
        continuation = tokenizer.decode(new_ids, skip_special_tokens=True)
        # The transcript holds one line per turn, and the model tends to keep
        # writing further turns, so keep only the first line as the reply.
        reply = continuation.strip().split("\n", 1)[0].strip()
        print(reply)
        append(" " + reply)
        return reply

    def start_conversation():
        get_response(conversation.getvalue())

    def add_message(prompt):
        # Same "Human: " spacing as the seeded transcript.
        append("\nHuman: " + prompt + "\nAI:")

    def loop():
        while True:
            try:
                prompt = input("Send message: ")
            except (EOFError, KeyboardInterrupt):
                print()
                return
            add_message(prompt)
            # getvalue() returns the whole transcript; read() would return ""
            # here because the stream position is already at the end.
            get_response(conversation.getvalue())

    start_conversation()
    loop()
# To run the script, execute this file directly (see the `if __name__ == '__main__'` guard below).
def save_gpt(model_dir="ac-gpt"):
    """Download the pretrained ``gpt2`` checkpoint and save a local copy.

    Writes three things into ``model_dir``:

    * the PyTorch weights;
    * a TensorFlow copy converted from the PyTorch weights (``from_pt=True``,
      which needs TensorFlow installed);
    * the tokenizer files.

    Args:
        model_dir: Target directory; defaults to ``"ac-gpt"``, the directory
            the chat loop in this file loads.
    """
    from transformers import GPT2LMHeadModel, GPT2Tokenizer, TFGPT2LMHeadModel

    # Use the language-modelling-head classes. The directory is loaded with
    # AutoModelForCausalLM. The bare GPT2Model/TFGPT2Model save no LM head and
    # record architectures=["GPT2Model"] in config.json; loading them only
    # worked because GPT-2 ties the head to the token embeddings.
    pt_model = GPT2LMHeadModel.from_pretrained('gpt2')
    tf_model = TFGPT2LMHeadModel.from_pretrained('gpt2', from_pt=True)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    # Both models write config.json to the same directory; the TF save runs
    # last, so its config is the one that remains.
    pt_model.save_pretrained(model_dir)
    tf_model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
def train_gpt():
    """Placeholder for training the local model; currently does nothing."""
    return None
if __name__ == "__main__":
    # Executing this file directly starts the interactive chat.
    ac_gpt()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement