Guest User

vLLM captioning with InternVL2-8B

a guest
Aug 26th, 2024
16
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.93 KB | None | 0 0
  1. """
  2. This example shows how to use vLLM for running offline inference
  3. with the correct prompt format on vision language models.
  4.  
  5. For most models, the prompt format should follow corresponding examples
  6. on HuggingFace model repository.
  7. """
  8. import json
  9. import random
  10. from pathlib import Path
  11.  
  12. import numpy as np
  13. from PIL import Image
  14. from tqdm import tqdm
  15. from transformers import AutoTokenizer
  16. from vllm import LLM, SamplingParams
  17.  
  18. model_name = "OpenGVLab/InternVL2-8B"
  19. folder_path = r"/workspace/datasets/Belle Delphine"
  20. config_extension = '.tagging-info.internvl8.json'
  21.  
  22.  
  23. # Create a struct like for image properties holding the path and the pillow image
  24. class ImageObject:
  25.     def __init__(self, path: Path, image: Image):
  26.         self.path_str = str(path)
  27.         self.path = path
  28.         self.image = image
  29.         self.prompt = None
  30.  
  31.     def __str__(self):
  32.         return f"ImageObject(path={self.path_str}, image={self.image})"
  33.  
  34.  
  35. print("Loading the LLM")
  36.  
  37. llm = LLM(
  38.     model=model_name,
  39.     trust_remote_code=True,
  40.     max_num_seqs=8,
  41.     max_model_len=8192,
  42.     quantization="awq",
  43.     enforce_eager=True,
  44.     dtype="half",
  45. )
  46.  
  47. tokenizer = AutoTokenizer.from_pretrained(model_name,
  48.                                           trust_remote_code=True)
  49.  
  50. # Stop tokens for InternVL
  51. # models variants may have different stop tokens
  52. # please refer to the model card for the correct "stop words":
  53. # https://huggingface.co/OpenGVLab/InternVL2-2B#service
  54. stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
  55. stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
  56.  
  57. print("Looking for images in the folder: " + folder_path)
  58.  
  59. # Recursively load all images in the folder and subfolders. - Ignore files where a caption file already exists.
  60. image_paths = []
  61. for path in Path(folder_path).rglob('*'):
  62.     if path.is_file() and path.suffix.lower() in ['.jpg', '.jpeg', '.png']:
  63.         if "-masklabel" in str(path):
  64.             continue
  65.         tags_path = path.with_suffix(config_extension)
  66.         if not tags_path.is_file():
  67.             image_paths.append(path)
  68.  
  69. print(f"Found {len(image_paths)} images.")
  70.  
  71. # Randomize the image paths (should not be the same order every time)
  72. np.random.seed(random.randint(0, 1000000))
  73.  
  74. np.random.shuffle(image_paths)
  75.  
  76.  
  77. def construct_prompt(image_object: ImageObject):
  78.     prompt = "<|im_start|>system\n你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。 You are describing images of Belle Delphine. If she is visible, directly use her name. For example: Belle Delphine is taking a selfie ....<|im_end|>\n"
  79.  
  80.     prompt += f"<|im_start|>user\n<image>\nDescribe the image in detail. Talk a bit about the scene. The main focus should be the person. (What is the person doing / what is their pose, how do they look, their gender, and especially what is their outfit?). If there is text in the image, describe its position and spell it out.<|im_end|>\n"
  81.  
  82.     if "snapchat" in image_object.path_str.lower():
  83.         prompt += "<|im_start|>assistant\nSnapchat post which shows "
  84.     else:
  85.         prompt += "<|im_start|>assistant\nThe image shows "
  86.  
  87.     image_object.prompt = prompt
  88.  
  89.  
  90. # Create an array containing groups of 8 image paths
  91. grouped_paths = [image_paths[i:i + 8] for i in range(0, len(image_paths), 8)]
  92.  
  93. sampling_params = SamplingParams(temperature=0.0,
  94.                                  max_tokens=256,
  95.                                  top_k=50,
  96.                                  stop_token_ids=stop_token_ids)
  97.  
  98. for group in tqdm(grouped_paths):
  99.     try:
  100.         image_objects = []
  101.         inputs = []
  102.         for path in group:
  103.             image = Image.open(path).convert("RGB")
  104.             image_object = ImageObject(path, image)
  105.             image_objects.append(image_object)
  106.             construct_prompt(image_object)
  107.  
  108.             inputs.append({
  109.                 "prompt": image_object.prompt,
  110.                 "multi_modal_data": {
  111.                     "image": image_object.image}})
  112.  
  113.         outputs = llm.generate(inputs, sampling_params=sampling_params)
  114.  
  115.         for o, i in zip(outputs, image_objects):
  116.             generated_text = o.outputs[0].text
  117.  
  118.             # Replace the folowing things with just Belle Delphine
  119.             generated_text = generated_text.replace("“Belle Delphine”", "Belle Delphine").replace("“Belle Delphine,”",
  120.                                                                                                   "Belle Delphine,")
  121.  
  122.             print(i.path_str + ":")
  123.             print(generated_text)
  124.  
  125.             # Save the data to a file
  126.             data = {"file_path": i.path_str, "generated_text": generated_text}
  127.  
  128.             with open(i.path.with_suffix(config_extension), 'w') as f:
  129.                 json.dump(data, f)
  130.     except Exception as e:
  131.         print(f"Error processing group {group}: {e}")
  132.  
Advertisement
Add Comment
Please, Sign In to add comment