slush

Untitled

Jan 18th, 2023
#!/usr/bin/python
import sys
import time
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration


def read_prompt():
    # Read lines from stdin; an empty line (double newline) ends the prompt.
    prompt = ''
    while True:
        prompt += input("> ") + "\n"
        if prompt.endswith("\n\n"):
            break

    return prompt


dev = "cuda"

# (top_k, top_p) sampling settings to compare.
PARAMS = [(50, 0.4), (100, 0.4), (200, 0.4),
          (50, 0.6), (100, 0.6), (200, 0.6),
          (50, 0.9), (100, 0.9), (200, 0.9),
          ]

start = time.time()

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto")

# Model load time in seconds.
print(time.time() - start)

while True:
    prompt = read_prompt()
    print("-" * 10)
    start = time.time()

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(dev)
    #attention_mask = torch.ones_like(input_ids)

    for i in range(len(PARAMS)):
        print("Iter %d (top_k=%d, top_p=%f): -----" % (i, PARAMS[i][0], PARAMS[i][1]))

        # Sample three completions per setting to see the variance.
        for _ in range(3):
            outputs = model.generate(input_ids,
                                     max_new_tokens=500,
                                     do_sample=True,
                                     top_k=PARAMS[i][0],
                                     top_p=PARAMS[i][1],
                                     #attention_mask=attention_mask
                                     )
            gen_text = tokenizer.decode(outputs[0])
            print(gen_text)
            print("-")

        print("-" * 10)

    # Total generation time for this prompt across all settings.
    print(time.time() - start)
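Side note (not part of the original paste): because the script calls tokenizer.decode(outputs[0]) without skip_special_tokens, the printed text will include T5's special tokens such as <pad> and </s>. A minimal sketch of a cleaner decode step, using the standard transformers argument, would be:

    # Hypothetical variant of the decode line above: strip special tokens
    # (<pad>, </s>) from the generated sequence before printing.
    gen_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(gen_text.strip())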