import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel


class BertEncoder(object):
    _BERT_AVAILABLE_MODELS_NAMES = ['bert-base-uncased']
    _CLS_TOKEN = '[CLS]'
    _SEP_TOKEN = '[SEP]'
    _MAX_SEQ_LENGTH = 32
    _BATCH_SIZE = 32

    def __init__(self, bert_model_name='bert-base-uncased', device_type='cuda:0'):
        if bert_model_name not in self._BERT_AVAILABLE_MODELS_NAMES:
            raise ValueError('{} model name is not supported. Use one of the available models: {}'.format(
                bert_model_name, self._BERT_AVAILABLE_MODELS_NAMES))
        self._tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=True)
        self._bert_model = BertModel.from_pretrained(bert_model_name)
        # Inference only: disable dropout so repeated calls return deterministic encodings.
        self._bert_model.eval()
        self._device = torch.device(device_type)
        self._bert_model.to(self._device)

    def _get_indexed_texts(self, texts, max_seq_length):
        texts_ids = []
        texts_masks = []
        for text in texts:
            tokenized_text = self._tokenizer.tokenize(text)
            # Account for [CLS] and [SEP].
            if len(tokenized_text) > max_seq_length - 2:
                tokenized_text = tokenized_text[:max_seq_length - 2]
            tokenized_text = [self._CLS_TOKEN] + tokenized_text + [self._SEP_TOKEN]
            tokens_ids = self._tokenizer.convert_tokens_to_ids(tokenized_text)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
            input_mask = [1] * len(tokens_ids)
            # Zero-pad up to the sequence length.
            while len(tokens_ids) < max_seq_length:
                tokens_ids.append(0)
                input_mask.append(0)
            texts_ids.append(tokens_ids)
            texts_masks.append(input_mask)
        return texts_ids, texts_masks

    @staticmethod
    def get_batches_iter(iterable, batch_size, return_incomplete_batch=True):
        # Yield consecutive batches of batch_size items; the trailing partial batch
        # is yielded only if return_incomplete_batch is True.
        batch = []
        for sample in iterable:
            batch.append(sample)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch and return_incomplete_batch:
            yield batch

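    # Illustrative behaviour of the batching helper (the input values are an
    # assumption for illustration, not output from the original paste):
    #   list(BertEncoder.get_batches_iter(range(5), 2)) -> [[0, 1], [2, 3], [4]]
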
    def encode(self, texts, max_seq_length=_MAX_SEQ_LENGTH, batch_size=_BATCH_SIZE):
        """
        :param texts: list of strings for encoding
        :param max_seq_length: the maximum total input sequence length after WordPiece tokenization. Sequences
            longer than this will be truncated, and sequences shorter than this will be padded.
        :param batch_size: number of texts encoded per forward pass
        :return: numpy array of shape (len(texts), 4 * hidden_size) holding, for each text, the [CLS] vectors
            of the last four encoder layers concatenated
        """
        encodings = []
        for text_batch in self.get_batches_iter(texts, batch_size):
            texts_ids, texts_masks = self._get_indexed_texts(text_batch, max_seq_length)
            texts_tensor = torch.LongTensor(texts_ids).to(self._device)
            masks_tensor = torch.LongTensor(texts_masks).to(self._device)
            with torch.no_grad():
                all_encoder_layers, _ = self._bert_model(
                    texts_tensor, token_type_ids=None, attention_mask=masks_tensor, output_all_encoded_layers=True)
            text_batch_encodings = None
            for layer in [-1, -2, -3, -4]:
                # Take the [CLS] (first) token representation from each of the last four layers.
                encoder_layer = all_encoder_layers[layer].cpu().detach().numpy()
                encoder_layer = encoder_layer[:, 0, :]
                if text_batch_encodings is None:
                    text_batch_encodings = encoder_layer
                else:
                    # Concatenate the layer vectors; summing them (np.sum([...], axis=0)) is an alternative.
                    text_batch_encodings = np.append(text_batch_encodings, encoder_layer, axis=1)
            encodings.extend(text_batch_encodings)
        return np.array(encodings)
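

# A minimal usage sketch, not part of the original paste: the sentences and the
# CPU device below are illustrative assumptions. For bert-base-uncased the output
# width is 4 layers * 768 hidden units = 3072.
if __name__ == '__main__':
    encoder = BertEncoder(bert_model_name='bert-base-uncased', device_type='cpu')
    sentences = ['BERT turns each text into a single fixed-size vector.',
                 'The [CLS] vectors of the last four layers are concatenated.']
    vectors = encoder.encode(sentences, max_seq_length=32, batch_size=2)
    print(vectors.shape)  # expected: (2, 3072)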