# (Removed Pastebin page header — "Not a member of Pastebin yet? Sign Up ..." — not part of the script.)
#!/usr/bin/python
"""Interactive sampling-parameter sweep for google/flan-t5-xl.

Reads a multi-line prompt from stdin, then generates three samples for
each (top_k, top_p) combination in PARAMS, printing the generations and
timing information.
"""
import sys
import time

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
def read_prompt():
    """Read lines from stdin (prompting with "> ") until a blank line.

    A blank line only terminates input once at least one line has already
    been read; the very first line may be blank without ending the prompt.
    Returns the collected text, newline-joined, ending in "\n\n".
    """
    pieces = []
    while True:
        line = input("> ")
        pieces.append(line + "\n")
        # Same termination as checking the accumulated text for "\n\n":
        # a blank line, as long as it is not the only line so far.
        if line == "" and len(pieces) > 1:
            return "".join(pieces)
# Device for the tokenized inputs; fall back to CPU when no GPU is present
# (hard-coding "cuda" crashes on CPU-only machines even though the model
# itself is placed with device_map="auto").
dev = "cuda" if torch.cuda.is_available() else "cpu"

# (top_k, top_p) sampling combinations to sweep; three samples are drawn
# for each combination in the main loop below.
PARAMS = [(50, 0.4), (100, 0.4), (200, 0.4),
          (50, 0.6), (100, 0.6), (200, 0.6),
          (50, 0.9), (100, 0.9), (200, 0.9),
          ]

# Load tokenizer and model once up front and report how long it took.
start = time.time()
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto")
print(time.time() - start)
# Interactive REPL: read one prompt, then sweep every sampling configuration,
# printing three samples per configuration plus per-prompt timing.
while True:
    prompt = read_prompt()
    print("-"*10)
    start = time.time()
    # Tokenize once per prompt; the same ids are reused for every configuration.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(dev)
    for i, (top_k, top_p) in enumerate(PARAMS):
        print("Iter %d (top_k=%d, top_p=%f): -----" % (i, top_k, top_p))
        # Three independent samples to get a feel for the variance at
        # these sampling settings.
        for _ in range(3):
            outputs = model.generate(input_ids,
                                     max_new_tokens=500,
                                     do_sample=True,
                                     top_k=top_k,
                                     top_p=top_p,
                                     )
            gen_text = tokenizer.decode(outputs[0])
            print(gen_text)
            print("-")
        print("-"*10)
    print(time.time() - start)
# (Removed Pastebin page footer — "Advertisement / Add Comment / Please, Sign In to add comment" — not part of the script.)