Advertisement
Guest User

Batch captioner based on MiaoshouAI/Florence-2-base-PromptGen-v1.5

a guest
Sep 5th, 2024
124
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.14 KB | None | 0 0
  1. #!/usr/bin/python3
  2. import argparse, json, os, os.path, sys, time, torch
  3. from transformers import AutoProcessor, AutoModelForCausalLM
  4. from PIL import Image
  5.  
  6. parser = argparse.ArgumentParser(
  7.                     description="Batch image captioner based on MiaoshouAI/Florence-2-base-PromptGen-v1.5")
  8. parser.add_argument("directory", help="Directory of images to be captioned")
  9. parser.add_argument("--model", help="Captioning model to use", default="MiaoshouAI/Florence-2-base-PromptGen-v1.5")
  10. parser.add_argument("--batchsize", help="Batch size to use", type=int, default=4)
  11. parser.add_argument("--prompt", help="Prompt to use (see PromptGen docs for options)", default="<MORE_DETAILED_CAPTION>")
  12. parser.add_argument("--outfile", help="jsonl file to write captions to (default is individual .caption files in the image directory")
  13. args = parser.parse_args()
  14.  
  15. imgdir = args.directory
  16.  
  17.  
  18. device = "cuda:0" if torch.cuda.is_available() else "cpu"
  19. torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
  20.  
  21. model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch_dtype, trust_remote_code=True).to(device)
  22. processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
  23.  
  24. batch = []
  25.  
  26. def _do_batch(batch, prompt, outf):
  27.     #assert len(batch) <= BATCH_SIZE
  28.  
  29.     start = time.perf_counter()
  30.  
  31.     inputs = processor(text=[prompt]*len(batch), images=[s[0] for s in batch], return_tensors="pt").to(device, torch_dtype)
  32.  
  33.     generated_ids = model.generate(
  34.         **inputs,
  35.         max_new_tokens=1024,
  36.         do_sample=False,
  37.         num_beams=3
  38.     )
  39.  
  40.     generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
  41.  
  42.     generated_texts = [text.replace('</s>', '').replace('<s>', '').replace('<pad>', '') for text in generated_texts]
  43.  
  44.     for desc, (descfile, filename) in zip(generated_texts, (s[1:3] for s in batch)):
  45.         if outf is None:
  46.             assert not os.path.exists(descfile)
  47.             print(f"{descfile}: {desc}")
  48.             with open(descfile, 'w', encoding='utf-8') as f:
  49.                 f.write(desc)
  50.         else:
  51.             print(f"{filename}: {desc}")
  52.             data = { "file_name": filename, "text": desc }
  53.             outf.write(json.dumps(data, ensure_ascii=False)+"\n")
  54.  
  55.     elapsed = time.perf_counter() - start
  56.     print(f"Processed batch in {elapsed} secs")
  57.  
  58.  
  59. if args.outfile is not None:
  60.     outf = open(args.outfile, "w", encoding="utf-8")
  61. else:
  62.     outf = None
  63.  
  64. for filename in os.listdir(imgdir):
  65.     (basename, ext) = os.path.splitext(filename)
  66.     if outf is None:
  67.         descfile = f"{imgdir}/{basename}.caption"
  68.     else:
  69.         descfile = None
  70.  
  71.     if ext.lower() not in {".jpg", ".jpeg", ".gif", ".png", ".webp"}:
  72.         continue
  73.  
  74.     if (outf is None) and os.path.exists(descfile):
  75.         continue
  76.  
  77.     image = Image.open(imgdir + "/" + filename).convert("RGB")
  78.     batch.append((image, descfile, filename))
  79.  
  80.     if len(batch) >= args.batchsize:
  81.         _do_batch(batch, args.prompt, outf)
  82.         batch = []
  83.  
  84. if len(batch) > 0:
  85.     _do_batch(batch, args.prompt, outf)
  86.  
  87. if outf is not None:
  88.     outf.close()
  89.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement