from PIL import Image
from transformers import AutoModelForCausalLM
from transformers import AutoProcessor
import torch

model_path = "my_model_path"

# Note: set _attn_implementation='eager' if you don't have flash_attn installed
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype="auto",
    _attn_implementation='flash_attention_2'
)

# for best performance, use num_crops=4 for multi-frame, num_crops=16 for single-frame
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True,
    num_crops=16
)

image = Image.open("image_path")

messages = [
    {"role": "user", "content": "<|image_1|>\nExtract all the text you see. The language is Romanian."},
]

prompt = processor.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# the processor expects a list of images; move the tensors to the model's device
inputs = processor(prompt, [image], return_tensors="pt").to(model.device)

generation_args = {
    "max_new_tokens": 1000,
    "temperature": 0.0,
    "do_sample": False,  # greedy decoding; temperature is ignored when do_sample=False
}

with torch.no_grad():
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args
    )

# remove input tokens so only the newly generated text is decoded
generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]

response = processor.batch_decode(
    generate_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)[0]

print(response)
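
The num_crops comment above mentions a multi-frame setting, but the snippet only demonstrates a single image. As a minimal sketch of what multi-frame input could look like (assuming the same <|image_n|> placeholder convention carries over to multiple images; frame1.png and frame2.png are hypothetical paths):

# multi-frame sketch: reload the processor with num_crops=4 and pass several images
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, num_crops=4)

images = [Image.open("frame1.png"), Image.open("frame2.png")]  # hypothetical paths
placeholders = "".join(f"<|image_{i}|>\n" for i in range(1, len(images) + 1))

messages = [
    {"role": "user", "content": placeholders + "Describe what changes between these frames."},
]
prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(prompt, images, return_tensors="pt").to(model.device)

The rest of the pipeline (generate, slice off the input tokens, batch_decode) is unchanged from the single-image case.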