Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- """
- This example shows how to use vLLM for running offline inference
- with the correct prompt format on vision language models.
- For most models, the prompt format should follow corresponding examples
- on HuggingFace model repository.
- """
- import json
- import random
- from pathlib import Path
- import numpy as np
- from PIL import Image
- from tqdm import tqdm
- from transformers import AutoTokenizer
- from vllm import LLM, SamplingParams
# Model identifier passed to vLLM's LLM and to AutoTokenizer.from_pretrained.
model_name = "OpenGVLab/InternVL2-8B"

# Root directory that is searched recursively for images to caption.
folder_path = "/workspace/datasets/Belle Delphine"

# Suffix that replaces an image's own extension to name its caption JSON file.
config_extension = ".tagging-info.internvl8.json"
class ImageObject:
    """Struct-like holder for one image: its path, pixels, and prompt.

    Attributes:
        path: Location of the image file.
        path_str: ``str(path)``, kept alongside ``path`` for printing and
            substring checks.
        image: The loaded Pillow image.
        prompt: Prompt text for this image; ``None`` until one is assigned.
    """

    # ``Image`` is the PIL *module* in this file, so the image class is
    # ``Image.Image``. The annotation is quoted so it is not evaluated when
    # the class is defined.
    def __init__(self, path: Path, image: "Image.Image") -> None:
        self.path_str = str(path)
        self.path = path
        self.image = image
        self.prompt = None

    def __str__(self) -> str:
        return f"ImageObject(path={self.path_str}, image={self.image})"
print("Loading the LLM")
# NOTE(review): quantization="awq" presumably requires AWQ-quantized weights;
# confirm that model_name points at such a checkpoint.
llm = LLM(
    model=model_name,
    trust_remote_code=True,
    max_num_seqs=8,
    max_model_len=8192,
    quantization="awq",
    enforce_eager=True,
    dtype="half",
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Stop tokens for InternVL.
# Model variants may have different stop tokens; please refer to the model
# card for the correct "stop words":
# https://huggingface.co/OpenGVLab/InternVL2-2B#service
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
# A token that is missing from this tokenizer's vocabulary may come back as
# None; drop those so only real ids are used as stop ids.
stop_token_ids = [
    token_id
    for token_id in map(tokenizer.convert_tokens_to_ids, stop_tokens)
    if token_id is not None
]
# Lower-case file extensions that are treated as images.
_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})


def _find_uncaptioned_images(root, caption_suffix):
    """Return the image files under *root* that have no caption file yet.

    Args:
        root: Directory to search recursively.
        caption_suffix: Suffix that replaces an image's extension to form the
            path of its caption file.

    Returns:
        A list of ``Path`` objects, in the order ``rglob`` yields them.
    """
    pending = []
    for candidate in Path(root).rglob("*"):
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() not in _IMAGE_SUFFIXES:
            continue
        # Anything with "-masklabel" anywhere in its full path is skipped.
        if "-masklabel" in str(candidate):
            continue
        # Ignore images whose caption file already exists.
        if not candidate.with_suffix(caption_suffix).is_file():
            pending.append(candidate)
    return pending


print("Looking for images in the folder: " + folder_path)
image_paths = _find_uncaptioned_images(folder_path, config_extension)
print(f"Found {len(image_paths)} images.")

# Randomize the processing order so it differs between runs. The random
# module seeds itself at import, so no explicit seeding is needed.
random.shuffle(image_paths)
def construct_prompt(image_object: "ImageObject"):
    """Build the chat-formatted prompt for one image and store it on the object.

    The prompt consists of a system turn, a user turn containing the
    ``<image>`` placeholder, and the opening words of the assistant turn. If
    the image path contains "snapchat" (case-insensitive) the assistant turn
    opens with "Snapchat post which shows ", otherwise with
    "The image shows ".

    Args:
        image_object: Object exposing ``path_str``; its ``prompt`` attribute
            is overwritten with the result.

    Returns:
        The prompt string that was assigned to ``image_object.prompt``.
    """
    system_turn = "<|im_start|>system\n你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。 You are describing images of Belle Delphine. If she is visible, directly use her name. For example: Belle Delphine is taking a selfie ....<|im_end|>\n"
    user_turn = "<|im_start|>user\n<image>\nDescribe the image in detail. Talk a bit about the scene. The main focus should be the person. (What is the person doing / what is their pose, how do they look, their gender, and especially what is their outfit?). If there is text in the image, describe its position and spell it out.<|im_end|>\n"

    # Pre-fill the start of the answer so the completion continues from it.
    if "snapchat" in image_object.path_str.lower():
        assistant_opening = "<|im_start|>assistant\nSnapchat post which shows "
    else:
        assistant_opening = "<|im_start|>assistant\nThe image shows "

    prompt = system_turn + user_turn + assistant_opening
    image_object.prompt = prompt
    return prompt
# Decoding settings shared by every generation request.
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=256,
    top_k=50,
    stop_token_ids=stop_token_ids,
)

# Split the image paths into consecutive groups of 8 (the last may be shorter).
grouped_paths = []
for start in range(0, len(image_paths), 8):
    grouped_paths.append(image_paths[start:start + 8])
# Caption each group of images and write one JSON file per image next to it.
for group in tqdm(grouped_paths):
    try:
        image_objects = []
        inputs = []
        for path in group:
            try:
                # convert() returns a separate image, so the file opened by
                # Image.open can be closed as soon as the conversion is done.
                with Image.open(path) as source_image:
                    image = source_image.convert("RGB")
            except (OSError, ValueError) as e:
                # Skip only the unreadable image instead of the whole group.
                print(f"Error loading image {path}: {e}")
                continue
            image_object = ImageObject(path, image)
            construct_prompt(image_object)
            image_objects.append(image_object)
            inputs.append({
                "prompt": image_object.prompt,
                "multi_modal_data": {"image": image_object.image},
            })

        # Nothing in this group could be loaded; there is nothing to generate.
        if not inputs:
            continue

        outputs = llm.generate(inputs, sampling_params=sampling_params)

        # Outputs are paired with their image objects by position.
        for output, image_object in zip(outputs, image_objects):
            generated_text = output.outputs[0].text
            # Replace the following quoted forms with just Belle Delphine
            generated_text = (
                generated_text
                .replace("“Belle Delphine”", "Belle Delphine")
                .replace("“Belle Delphine,”", "Belle Delphine,")
            )
            print(image_object.path_str + ":")
            print(generated_text)

            # Save the data to a file
            data = {"file_path": image_object.path_str, "generated_text": generated_text}
            caption_path = image_object.path.with_suffix(config_extension)
            with open(caption_path, "w", encoding="utf-8") as f:
                json.dump(data, f)
    except Exception as e:
        # Best-effort batch job: report the failed group and move on.
        print(f"Error processing group {group}: {e}")
Advertisement
Add Comment
Please, Sign In to add comment