Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/python3
- import argparse, json, os, os.path, sys, time, torch
- from transformers import AutoProcessor, AutoModelForCausalLM
- from PIL import Image
# Command-line interface. `directory` is the only required argument; everything
# else has a sensible default for the PromptGen captioning model.
parser = argparse.ArgumentParser(
    description="Batch image captioner based on MiaoshouAI/Florence-2-base-PromptGen-v1.5")
parser.add_argument("directory", help="Directory of images to be captioned")
parser.add_argument("--model", help="Captioning model to use",
                    default="MiaoshouAI/Florence-2-base-PromptGen-v1.5")
parser.add_argument("--batchsize", help="Batch size to use", type=int, default=4)
parser.add_argument("--prompt", help="Prompt to use (see PromptGen docs for options)",
                    default="<MORE_DETAILED_CAPTION>")
# Help text previously had an unbalanced "(": fixed.
parser.add_argument("--outfile",
                    help="jsonl file to write captions to (default is individual "
                         ".caption files in the image directory)")
args = parser.parse_args()
imgdir = args.directory

# Probe CUDA once and derive both device and precision from the same answer
# (the original called torch.cuda.is_available() twice). fp16 on GPU for speed
# and memory; fp32 on CPU where fp16 is slow or unsupported.
use_cuda = torch.cuda.is_available()
device = "cuda:0" if use_cuda else "cpu"
torch_dtype = torch.float16 if use_cuda else torch.float32

# trust_remote_code is required: Florence-2 ships custom modeling code on the Hub.
model = AutoModelForCausalLM.from_pretrained(
    args.model, torch_dtype=torch_dtype, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)

# Accumulates (image, descfile, filename) tuples until a full batch is ready.
batch = []
def _do_batch(batch, prompt, outf):
    """Caption one batch of images and write the results.

    batch  -- list of (PIL.Image, descfile, filename) tuples
    prompt -- task prompt passed to the processor (e.g. "<MORE_DETAILED_CAPTION>")
    outf   -- open jsonl file handle, or None to write one .caption file per image

    Raises FileExistsError rather than overwriting an existing caption file.
    """
    start = time.perf_counter()
    inputs = processor(text=[prompt] * len(batch),
                       images=[s[0] for s in batch],
                       return_tensors="pt").to(device, torch_dtype)
    # Pure inference: no_grad avoids building autograd state during generate(),
    # cutting peak memory use.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
    # skip_special_tokens=False keeps any PromptGen tags intact; strip only the
    # framing tokens by hand.
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
    generated_texts = [t.replace('</s>', '').replace('<s>', '').replace('<pad>', '')
                       for t in generated_texts]
    for desc, (descfile, filename) in zip(generated_texts, (s[1:3] for s in batch)):
        if outf is None:
            # Explicit guard instead of `assert`: asserts are stripped under
            # `python -O`, which would silently allow overwrites.
            if os.path.exists(descfile):
                raise FileExistsError(f"refusing to overwrite {descfile}")
            print(f"{descfile}: {desc}")
            with open(descfile, 'w', encoding='utf-8') as f:
                f.write(desc)
        else:
            print(f"(unknown): {desc}")
            data = {"file_name": filename, "text": desc}
            outf.write(json.dumps(data, ensure_ascii=False) + "\n")
    elapsed = time.perf_counter() - start
    print(f"Processed batch in {elapsed} secs")
# Open the aggregate jsonl output if requested; otherwise captions go to
# per-image .caption files alongside the images.
if args.outfile is not None:
    outf = open(args.outfile, "w", encoding="utf-8")
else:
    outf = None
try:
    for filename in os.listdir(imgdir):
        (basename, ext) = os.path.splitext(filename)
        # Filter on extension first; no point computing caption paths for
        # non-image files.
        if ext.lower() not in {".jpg", ".jpeg", ".gif", ".png", ".webp"}:
            continue
        if outf is None:
            descfile = os.path.join(imgdir, f"{basename}.caption")
            # Already captioned on a previous run; skip.
            if os.path.exists(descfile):
                continue
        else:
            descfile = None
        image = Image.open(os.path.join(imgdir, filename)).convert("RGB")
        batch.append((image, descfile, filename))
        if len(batch) >= args.batchsize:
            _do_batch(batch, args.prompt, outf)
            batch = []
    # Flush the final partial batch.
    if batch:
        _do_batch(batch, args.prompt, outf)
finally:
    # Close the jsonl file even if captioning raised partway through
    # (the original leaked the handle on error).
    if outf is not None:
        outf.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement