import re
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import PngImagePlugin, TiffImagePlugin, JpegImagePlugin
import imghdr  # note: imghdr was removed from the stdlib in Python 3.13
import os
import torch
from tqdm import tqdm

processor = DonutProcessor.from_pretrained(
    "naver-clova-ix/donut-base-finetuned-rvlcdip"
)
model = VisionEncoderDecoderModel.from_pretrained(
    "naver-clova-ix/donut-base-finetuned-rvlcdip"
)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {device}")
model.to(device)
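# Not in the original paste: this script only runs inference, so switching the
# model to eval mode disables dropout and the like; generate() itself already
# runs without gradient tracking.
model.eval()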

def load_image(filepath):
    imgtype = imghdr.what(filepath)
    if imgtype == 'jpeg':
        img = JpegImagePlugin.JpegImageFile(filepath)
    elif imgtype == 'png':
        img = PngImagePlugin.PngImageFile(filepath)
    elif imgtype == 'tiff':
        img = TiffImagePlugin.TiffImageFile(filepath)
    else:
        # fail loudly so the caller can catch and log unsupported types
        raise ValueError(f"unsupported image type {imgtype!r}: {filepath}")
    return img.convert('RGB'), imgtype, filepath

def scale_image_size(sizetuple, scalar):
    return (int(sizetuple[0] * scalar), int(sizetuple[1] * scalar))

def absolute_scale(sizetuple, output_max_dim_size, output_min_dim_size):
    """
    Scale the image so that the larger dimension equals output_max_dim_size
    and the smaller dimension is no smaller than output_min_dim_size, with
    the minimum taking precedence when the aspect ratio can't satisfy both.
    """
    scalar = output_max_dim_size / max(sizetuple)
    new_shape = scale_image_size(sizetuple, scalar)
    if min(new_shape) < output_min_dim_size:
        scalar = output_min_dim_size / min(sizetuple)
        return scale_image_size(sizetuple, scalar)
    return new_shape
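
# Worked example (added for illustration, not in the original paste):
# absolute_scale((2104, 1000), 1052, 357) scales by 1052/2104 = 0.5, giving
# (1052, 500); 500 >= 357, so that is returned. For a very wide page,
# absolute_scale((4208, 400), 1052, 357): scaling by 0.25 would give
# (1052, 100), so it rescales by 357/400 = 0.8925 instead, yielding roughly
# (3755, 357): the short side is pinned to the minimum.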
def get_folder_and_filename_from_filepath(filepath):
    # e.g. '/a/b/c.png' -> ('b', 'c.png'); used to build unique output names
    head, filename = os.path.split(filepath)
    return os.path.basename(head), filename

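# Hypothetical sketch (run_prediction is referenced below but defined nowhere
# in this paste): a minimal version that classifies a PIL image with the model
# loaded above and returns {'class': label}, mirroring the generate-and-parse
# logic of the main loop.
def run_prediction(image):
    decoder_input_ids = processor.tokenizer(
        "<s_rvlcdip>", add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=8,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
    )
    decoded = processor.batch_decode(outputs)[0]
    match = re.search(r"<(\w+)/>", decoded)  # class tokens look like <invoice/>
    return {'class': match.group(1) if match else 'unknown'}
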
def classify_and_save_thumbnail(image, sorted_folder, new_file_name):

    pred = run_prediction(image)

    thumbnail = image.resize(scale_image_size(image.size, 0.2))
    #display(thumbnail)

    # create one folder per predicted class under sorted_folder
    class_folder = sorted_folder + '/' + pred['class']
    os.makedirs(class_folder, exist_ok=True)
    thumbnail.save(class_folder + '/' + new_file_name)

img_folder = "/home/frans/Vaults/sovag-docs-bucket-sample/"
sorted_folder = "/home/frans/Vaults/sorted_sovag_docs/"

max_n_images = 10
n_images = 0
classified = []
errored = []
# sanity check before walking: fail early if the vault isn't mounted or is empty
assert os.listdir(img_folder), "Open the vault folder"
for subdir, dirs, files in tqdm(os.walk(img_folder, topdown=False)):
    for file in tqdm(files, leave=False):
        filepath = subdir + os.sep + file
        if n_images > max_n_images:
            break  # only exits the inner loop; the walk itself continues
        elif "Trash" in filepath or filepath.endswith("trashinfo"):
            continue
        else:
            try:
                # load inside the try so unsupported file types are caught too
                image, imgtype, _ = load_image(filepath)
                new_file_name = '_'.join(get_folder_and_filename_from_filepath(filepath)) + '.png'
                #classify_and_save_thumbnail(image, sorted_folder, new_file_name)

                sample = image
                # prepare decoder inputs
                task_prompt = "<s_rvlcdip>"
                #task_prompt = "<s_rvlcdip><s_class>"
                # resize so the long side is 1052 px and the short side stays >= 357 px
                processor.feature_extractor.size = list(absolute_scale(sample.size, 1052, 357))
                decoder_input_ids = processor.tokenizer(
                    task_prompt, add_special_tokens=False, return_tensors="pt"
                ).input_ids
                pixel_values = processor(sample, return_tensors="pt").pixel_values

                # transpose the last two dims (H, W) if the image is in
                # landscape orientation, so the model always sees portrait input
                is_landscape = pixel_values.shape[-1] >= pixel_values.shape[-2]
                if is_landscape:
                    pixel_values = torch.transpose(pixel_values, -1, -2)

                # Constrain beam search so the output always has the form
                # <s_rvlcdip><s_class><class_label/></s_class>: <s_class> and
                # </s_class> must each appear (phrasal constraints), and one of
                # the class tokens must appear (a disjunctive constraint, passed
                # as a nested list). min_length = max_length = 4 below then pins
                # the output to exactly that template.
                force_class_start = ["<s_class>"]
                force_class_end = ["</s_class>"]
                force_one_of = [
                    "<invoice/>", "<budget/>", "<news_article/>", "<specification/>",
                    "<scientific_report/>", "<scientific_publication/>", "<questionnaire/>",
                    "<letter/>", "<advertisement/>", "<form/>", "<handwritten/>",
                    "<file_folder/>", "<email/>", "<memo/>", "<resume/>", "<presentation/>",
                ]

                force_words_ids = [
                    *processor.tokenizer(force_class_start, add_special_tokens=False).input_ids,
                    *processor.tokenizer(force_class_end, add_special_tokens=False).input_ids,
                    processor.tokenizer(force_one_of, add_special_tokens=False).input_ids,
                ]

                outputs = model.generate(
                    pixel_values.to(device),
                    decoder_input_ids=decoder_input_ids.to(device),
                    force_words_ids=force_words_ids,
                    max_length=4,
                    min_length=4,
                    early_stopping=True,
                    pad_token_id=processor.tokenizer.pad_token_id,
                    eos_token_id=processor.tokenizer.eos_token_id,
                    use_cache=True,
                    num_beams=16,
                    bad_words_ids=[[processor.tokenizer.unk_token_id]],
                    num_return_sequences=16,
                    return_dict_in_generate=True,
                    output_scores=True,
                    no_repeat_ngram_size=1,  # prevents repeats like <s_class><s_class>
                )

                sequence = processor.batch_decode(outputs.sequences)
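
                # Illustrative parsing sketch (not in the original paste): the
                # decoded strings still contain the task tokens, e.g.
                # '<s_rvlcdip><s_class><invoice/></s_class>'; pull out the
                # top-ranked class label with the re module imported above.
                match = re.search(r"<s_class><(\w+)/></s_class>", sequence[0])
                if match:
                    print(filepath, "->", match.group(1))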

                classified.append(filepath)
            except (RuntimeError, ValueError) as e:
                # ValueError covers unsupported file types from load_image
                errored.append(filepath)
                print(e)
            n_images += 1
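
# Not in the original paste: a small wrap-up so each run reports what happened.
print(f"classified {len(classified)} images, {len(errored)} errors")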