import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel


class BertEncoder(object):
    _BERT_AVAILABLE_MODELS_NAMES = ['bert-base-uncased']
    _CLS_TOKEN = '[CLS]'
    _SEP_TOKEN = '[SEP]'
    _MAX_SEQ_LENGTH = 32
    _BATCH_SIZE = 32

    def __init__(self, bert_model_name='bert-base-uncased', device_type='cuda:0'):
        if bert_model_name not in self._BERT_AVAILABLE_MODELS_NAMES:
            raise ValueError('Model name {} is not supported. Use one of the available models: {}'.format(
                bert_model_name, self._BERT_AVAILABLE_MODELS_NAMES))

        self._tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=True)

        self._bert_model = BertModel.from_pretrained(bert_model_name)
        self._bert_model.eval()  # disable dropout so the encodings are deterministic at inference time
        self._device = torch.device(device_type)
        self._bert_model.to(self._device)

    def _get_indexed_texts(self, texts, max_seq_length):
        texts_ids = []
        texts_masks = []
        for text in texts:
            tokenized_text = self._tokenizer.tokenize(text)

            # Account for [CLS] and [SEP]
            if len(tokenized_text) > max_seq_length - 2:
                tokenized_text = tokenized_text[0:(max_seq_length - 2)]

            tokenized_text = [self._CLS_TOKEN] + tokenized_text + [self._SEP_TOKEN]
            tokens_ids = self._tokenizer.convert_tokens_to_ids(tokenized_text)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
            input_mask = [1] * len(tokens_ids)

            # Zero-pad up to the sequence length.
            while len(tokens_ids) < max_seq_length:
                tokens_ids.append(0)
                input_mask.append(0)

            texts_ids.append(tokens_ids)
            texts_masks.append(input_mask)

        return texts_ids, texts_masks

    @staticmethod
    def get_batches_iter(iterable, batch_size, return_incomplete_batch=True):
        batch = []
        for sample in iterable:
            batch.append(sample)

            if len(batch) == batch_size:
                yield batch
                batch = []

        if batch and return_incomplete_batch:
            yield batch

    def encode(self, texts, max_seq_length=_MAX_SEQ_LENGTH, batch_size=_BATCH_SIZE):
        """
        :param texts: list of strings to encode
        :param max_seq_length: the maximum total input sequence length after WordPiece tokenization. Sequences longer
            than this will be truncated, and sequences shorter than this will be padded.
        :param batch_size: number of texts encoded per forward pass
        :return: numpy array with one encoding row per input text
        """
        encodings = []
        for text_batch in self.get_batches_iter(texts, batch_size):
            texts_ids, texts_masks = self._get_indexed_texts(text_batch, max_seq_length)
            texts_tensor = torch.LongTensor(texts_ids).to(self._device)
            masks_tensor = torch.LongTensor(texts_masks).to(self._device)

            with torch.no_grad():
                all_encoder_layers, _ = self._bert_model(
                    texts_tensor, token_type_ids=None, attention_mask=masks_tensor, output_all_encoded_layers=True)

            # Concatenate the [CLS] token vectors of the last four encoder layers for each text.
            text_batch_encodings = None
            for layer in [-1, -2, -3, -4]:
                encoder_layer = all_encoder_layers[layer].cpu().detach().numpy()
                encoder_layer = encoder_layer[:, 0, :]
                if text_batch_encodings is None:
                    text_batch_encodings = encoder_layer
                else:
                    # Alternative: sum the layers instead of concatenating them.
                    # text_batch_encodings = np.sum([text_batch_encodings, encoder_layer], axis=0)
                    text_batch_encodings = np.append(text_batch_encodings, encoder_layer, axis=1)

            encodings.extend(text_batch_encodings)

        return np.array(encodings)
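

# Example usage -- a minimal sketch, not part of the original paste. It assumes
# 'bert-base-uncased' can be downloaded and runs on CPU (pass device_type='cuda:0'
# if a GPU is available); the sentences are illustrative only.
if __name__ == '__main__':
    encoder = BertEncoder(bert_model_name='bert-base-uncased', device_type='cpu')
    sentences = [
        'The quick brown fox jumps over the lazy dog.',
        'BERT produces contextual sentence embeddings.',
    ]
    embeddings = encoder.encode(sentences)
    # Each row concatenates the [CLS] vectors of the last four layers: 4 * 768 = 3072 dims.
    print(embeddings.shape)  # expected: (2, 3072)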