import re
import imghdr
import os

import torch
from PIL import PngImagePlugin, TiffImagePlugin, JpegImagePlugin
from tqdm import tqdm
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Donut fine-tuned for document classification on RVL-CDIP
processor = DonutProcessor.from_pretrained(
    "naver-clova-ix/donut-base-finetuned-rvlcdip"
)
model = VisionEncoderDecoderModel.from_pretrained(
    "naver-clova-ix/donut-base-finetuned-rvlcdip"
)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {device}")
model.to(device)

def load_image(filepath):
    """Open a jpeg/png/tiff file and return it as an RGB PIL image."""
    imgtype = imghdr.what(filepath)
    if imgtype == 'jpeg':
        img = JpegImagePlugin.JpegImageFile(filepath)
    elif imgtype == 'png':
        img = PngImagePlugin.PngImageFile(filepath)
    elif imgtype == 'tiff':
        img = TiffImagePlugin.TiffImageFile(filepath)
    else:
        # fail loudly instead of falling through with `img` unbound
        raise ValueError(f"Unsupported image type {imgtype!r}: {filepath}")
    return img.convert('RGB'), imgtype, filepath

def scale_image_size(sizetuple, scalar):
    """Multiply a (width, height) tuple by a scalar, truncating to ints."""
    return (int(sizetuple[0] * scalar), int(sizetuple[1] * scalar))

def absolute_scale(sizetuple, output_max_dim_size, output_min_dim_size):
    """
    Scales the image so that the larger dimension is output_max_dim_size
    and the smaller dimension is no smaller than output_min_dim_size,
    with the latter taking precedence if the image shape isn't compatible
    with both.
    """
    scalar = output_max_dim_size / max(sizetuple)
    new_shape = scale_image_size(sizetuple, scalar)
    if min(new_shape) < output_min_dim_size:
        scalar = output_min_dim_size / min(sizetuple)
        return scale_image_size(sizetuple, scalar)
    return new_shape

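# Worked example (not in the original paste): absolute_scale((2000, 1000), 1052, 357)
# gives scalar 1052/2000 = 0.526 and shape (1052, 526); since 526 >= 357, that shape
# is returned. For an extreme 4000x500 strip the first pass gives (1052, 131), so the
# minimum-dimension rule takes over: scalar = 357/500 = 0.714, result (2856, 357).
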
def get_folder_and_filename_from_filepath(filepath):
    """Return the immediate parent folder name and the file name."""
    splitfilepath = filepath.split('/')
    return splitfilepath[-2], splitfilepath[-1]

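# `run_prediction` is called by classify_and_save_thumbnail below but was never
# defined in this paste. A minimal sketch, assuming the standard Donut inference
# recipe: encode the image, decode from the task prompt, then parse the token
# string with processor.token2json so that pred['class'] holds the label.
def run_prediction(image, task_prompt="<s_rvlcdip>"):
    pixel_values = processor(image, return_tensors="pt").pixel_values
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "")
    sequence = sequence.replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # drop task prompt
    return processor.token2json(sequence)
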
def classify_and_save_thumbnail(image, sorted_folder, new_file_name):
    pred = run_prediction(image)
    thumbnail = image.resize(scale_image_size(image.size, 0.2))
    # display(thumbnail)
    # save the thumbnail into a per-class subfolder of sorted_folder
    class_folder = sorted_folder + '/' + pred['class']
    os.makedirs(class_folder, exist_ok=True)
    thumbnail.save(class_folder + '/' + new_file_name)

img_folder = "/home/frans/Vaults/sovag-docs-bucket-sample/"
sorted_folder = "/home/frans/Vaults/sorted_sovag_docs/"
max_n_images = 10

n_images = 0
classified = []
errored = []

# sanity check: an empty folder usually means the vault isn't mounted
assert os.listdir(img_folder), "Open the vault folder"

for subdir, dirs, files in tqdm(os.walk(img_folder, topdown=False)):
    for file in tqdm(files, leave=False):
        filepath = subdir + os.sep + file
        if n_images > max_n_images:
            break  # NOTE: only leaves the current directory's file loop
        elif "Trash" in filepath or filepath.endswith("trashinfo"):
            continue
        else:
            image, imgtype, _ = load_image(filepath)
            try:
                new_file_name = '_'.join(get_folder_and_filename_from_filepath(filepath)) + '.png'
                # classify_and_save_thumbnail(image, sorted_folder, new_file_name)
                sample = image

                # prepare decoder inputs
                task_prompt = "<s_rvlcdip>"
                # task_prompt = "<s_rvlcdip><s_class>"
                processor.feature_extractor.size = list(absolute_scale(sample.size, 1052, 357))
                decoder_input_ids = processor.tokenizer(
                    task_prompt, add_special_tokens=False, return_tensors="pt"
                ).input_ids
                pixel_values = processor(sample, return_tensors="pt").pixel_values

                # transpose the last two dims if the image is in landscape mode,
                # so the encoder always sees a portrait-oriented tensor
                is_landscape = pixel_values.shape[-1] >= pixel_values.shape[-2]
                if is_landscape:
                    pixel_values = torch.transpose(pixel_values, -1, -2)

                # Force beam search to emit the class markup: with
                # min_length = max_length = 4 the output should always have the form
                # <s_rvlcdip><s_class><class_label/></s_class>
                force_class_start = ["<s_class>"]
                force_class_end = ["</s_class>"]
                force_one_of = [
                    "<invoice/>", "<budget/>", "<news_article/>", "<specification/>",
                    "<scientific_report/>", "<scientific_publication/>",
                    "<questionnaire/>", "<letter/>", "<advertisement/>", "<form/>",
                    "<handwritten/>", "<file_folder/>", "<email/>", "<memo/>",
                    "<resume/>", "<presentation/>",
                ]
                # the first two entries are phrasal constraints (each must appear);
                # the nested list is a disjunctive constraint (at least one of the
                # sixteen class tokens must appear)
                force_words_ids = [
                    *processor.tokenizer(force_class_start, add_special_tokens=False).input_ids,
                    *processor.tokenizer(force_class_end, add_special_tokens=False).input_ids,
                    processor.tokenizer(force_one_of, add_special_tokens=False).input_ids,
                ]
                outputs = model.generate(
                    pixel_values.to(device),
                    decoder_input_ids=decoder_input_ids.to(device),
                    force_words_ids=force_words_ids,
                    max_length=4,
                    min_length=4,
                    early_stopping=True,
                    pad_token_id=processor.tokenizer.pad_token_id,
                    eos_token_id=processor.tokenizer.eos_token_id,
                    use_cache=True,
                    num_beams=16,
                    bad_words_ids=[[processor.tokenizer.unk_token_id]],
                    num_return_sequences=16,
                    return_dict_in_generate=True,
                    output_scores=True,
                    no_repeat_ngram_size=1,  # prevents repeats like <s_class><s_class>
                )
                sequence = processor.batch_decode(outputs.sequences)
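                # A minimal sketch (not in the original paste): pull the class
                # label out of the highest-scoring beam, e.g.
                # "<s_rvlcdip><s_class><invoice/></s_class>" -> "invoice"
                match = re.search(r"<s_class><(\w+)/></s_class>", sequence[0])
                if match:
                    print(filepath, "->", match.group(1))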
                classified.append(filepath)
            except RuntimeError as e:
                errored.append(filepath)
                print(e)
            n_images += 1
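
# Not in the original paste: a one-line summary of the run.
print(f"Classified {len(classified)} images, {len(errored)} errored")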