Guest User

PixArt Load T5 8Bit

a guest
Oct 22nd, 2023
229
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.52 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. import os
  3. import re
  4. import html
  5. import urllib.parse as ul
  6.  
  7. import ftfy
  8. import torch
  9. from bs4 import BeautifulSoup
  10. from transformers import T5EncoderModel, AutoTokenizer
  11. from huggingface_hub import hf_hub_download
  12.  
  13. class T5Embedder:
  14.  
  15. available_models = ['t5-v1_1-xxl']
  16. bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa
  17.  
  18. def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None, use_text_preprocessing=True,
  19. t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None):
  20. self.device = torch.device(device)
  21. self.torch_dtype = torch_dtype or torch.bfloat16
  22. if t5_model_kwargs is None:
  23. t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype, 'load_in_8bit':True, 'device_map':'auto'}
  24. self.use_text_preprocessing = use_text_preprocessing
  25. self.hf_token = hf_token
  26. self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
  27. self.dir_or_name = dir_or_name
  28. tokenizer_path, path = dir_or_name, dir_or_name
  29. if local_cache:
  30. cache_dir = os.path.join(self.cache_dir, dir_or_name)
  31. tokenizer_path, path = cache_dir, cache_dir
  32. elif dir_or_name in self.available_models:
  33. cache_dir = os.path.join(self.cache_dir, dir_or_name)
  34. for filename in [
  35. 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
  36. 'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
  37. ]:
  38. hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
  39. force_filename=filename, token=self.hf_token)
  40. tokenizer_path, path = cache_dir, cache_dir
  41. else:
  42. cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
  43. for filename in [
  44. 'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
  45. ]:
  46. hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
  47. force_filename=filename, token=self.hf_token)
  48. tokenizer_path = cache_dir
  49.  
  50. print(tokenizer_path)
  51. self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
  52. self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
  53.  
  54. def get_text_embeddings(self, texts):
  55. texts = [self.text_preprocessing(text) for text in texts]
  56.  
  57. text_tokens_and_mask = self.tokenizer(
  58. texts,
  59. max_length=120,
  60. padding='max_length',
  61. truncation=True,
  62. return_attention_mask=True,
  63. add_special_tokens=True,
  64. return_tensors='pt'
  65. )
  66.  
  67. text_tokens_and_mask['input_ids'] = text_tokens_and_mask['input_ids']
  68. text_tokens_and_mask['attention_mask'] = text_tokens_and_mask['attention_mask']
  69.  
  70. with torch.no_grad():
  71. text_encoder_embs = self.model(
  72. input_ids=text_tokens_and_mask['input_ids'].to(self.device),
  73. attention_mask=text_tokens_and_mask['attention_mask'].to(self.device),
  74. )['last_hidden_state'].detach()
  75. return text_encoder_embs, text_tokens_and_mask['attention_mask'].to(self.device)
  76.  
  77. def text_preprocessing(self, text):
  78. if self.use_text_preprocessing:
  79. # The exact text cleaning as was in the training stage:
  80. text = self.clean_caption(text)
  81. text = self.clean_caption(text)
  82. return text
  83. else:
  84. return text.lower().strip()
  85.  
  86. @staticmethod
  87. def basic_clean(text):
  88. text = ftfy.fix_text(text)
  89. text = html.unescape(html.unescape(text))
  90. return text.strip()
  91.  
  92. def clean_caption(self, caption):
  93. caption = str(caption)
  94. caption = ul.unquote_plus(caption)
  95. caption = caption.strip().lower()
  96. caption = re.sub('<person>', 'person', caption)
  97. # urls:
  98. caption = re.sub(
  99. r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
  100. '', caption) # regex for urls
  101. caption = re.sub(
  102. r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa
  103. '', caption) # regex for urls
  104. # html:
  105. caption = BeautifulSoup(caption, features='html.parser').text
  106.  
  107. # @<nickname>
  108. caption = re.sub(r'@[\w\d]+\b', '', caption)
  109.  
  110. # 31C0—31EF CJK Strokes
  111. # 31F0—31FF Katakana Phonetic Extensions
  112. # 3200—32FF Enclosed CJK Letters and Months
  113. # 3300—33FF CJK Compatibility
  114. # 3400—4DBF CJK Unified Ideographs Extension A
  115. # 4DC0—4DFF Yijing Hexagram Symbols
  116. # 4E00—9FFF CJK Unified Ideographs
  117. caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
  118. caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
  119. caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
  120. caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
  121. caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
  122. caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
  123. caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
  124. #######################################################
  125.  
  126. # все виды тире / all types of dash --> "-"
  127. caption = re.sub(
  128. r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa
  129. '-', caption)
  130.  
  131. # кавычки к одному стандарту
  132. caption = re.sub(r'[`´«»“”¨]', '"', caption)
  133. caption = re.sub(r'[‘’]', "'", caption)
  134.  
  135. # &quot;
  136. caption = re.sub(r'&quot;?', '', caption)
  137. # &amp
  138. caption = re.sub(r'&amp', '', caption)
  139.  
  140. # ip adresses:
  141. caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)
  142.  
  143. # article ids:
  144. caption = re.sub(r'\d:\d\d\s+$', '', caption)
  145.  
  146. # \n
  147. caption = re.sub(r'\\n', ' ', caption)
  148.  
  149. # "#123"
  150. caption = re.sub(r'#\d{1,3}\b', '', caption)
  151. # "#12345.."
  152. caption = re.sub(r'#\d{5,}\b', '', caption)
  153. # "123456.."
  154. caption = re.sub(r'\b\d{6,}\b', '', caption)
  155. # filenames:
  156. caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)
  157.  
  158. #
  159. caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT"""
  160. caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT"""
  161.  
  162. caption = re.sub(self.bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
  163. caption = re.sub(r'\s+\.\s+', r' ', caption) # " . "
  164.  
  165. # this-is-my-cute-cat / this_is_my_cute_cat
  166. regex2 = re.compile(r'(?:\-|\_)')
  167. if len(re.findall(regex2, caption)) > 3:
  168. caption = re.sub(regex2, ' ', caption)
  169.  
  170. caption = self.basic_clean(caption)
  171.  
  172. caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640
  173. caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc
  174. caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231
  175.  
  176. caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
  177. caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
  178. caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
  179. caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
  180. caption = re.sub(r'\bpage\s+\d+\b', '', caption)
  181.  
  182. caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a...
  183.  
  184. caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)
  185.  
  186. caption = re.sub(r'\b\s+\:\s+', r': ', caption)
  187. caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
  188. caption = re.sub(r'\s+', ' ', caption)
  189.  
  190. caption.strip()
  191.  
  192. caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
  193. caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
  194. caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
  195. caption = re.sub(r'^\.\S+$', '', caption)
  196.  
  197. return caption.strip()
  198.  
Advertisement
Add Comment
Please, Sign In to add comment