import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel


class BertEncoder(object):
    _BERT_AVAILABLE_MODELS_NAMES = ['bert-base-uncased']
    _CLS_TOKEN = '[CLS]'
    _SEP_TOKEN = '[SEP]'
    _MAX_SEQ_LENGTH = 32
    _BATCH_SIZE = 32

    def __init__(self, bert_model_name='bert-base-uncased', device_type='cuda:0'):
        if bert_model_name not in self._BERT_AVAILABLE_MODELS_NAMES:
            raise ValueError('{} model name is not supported. Use one of the available models: {}'.format(
                bert_model_name, self._BERT_AVAILABLE_MODELS_NAMES))

        self._tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=True)

        self._bert_model = BertModel.from_pretrained(bert_model_name)
        self._device = torch.device(device_type)
        self._bert_model.to(self._device)
        # Encoding is inference-only, so disable dropout for deterministic outputs.
        self._bert_model.eval()

    def _get_indexed_texts(self, texts, max_seq_length):
        texts_ids = []
        texts_masks = []
        for text in texts:
            tokenized_text = self._tokenizer.tokenize(text)

            # Account for [CLS] and [SEP]
            if len(tokenized_text) > max_seq_length - 2:
                tokenized_text = tokenized_text[0:(max_seq_length - 2)]

            tokenized_text = [self._CLS_TOKEN] + tokenized_text + [self._SEP_TOKEN]
            tokens_ids = self._tokenizer.convert_tokens_to_ids(tokenized_text)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
            input_mask = [1] * len(tokens_ids)

            # Zero-pad up to the sequence length.
            while len(tokens_ids) < max_seq_length:
                tokens_ids.append(0)
                input_mask.append(0)

            texts_ids.append(tokens_ids)
            texts_masks.append(input_mask)

        return texts_ids, texts_masks

    @staticmethod
    def get_batches_iter(iterable, batch_size, return_incomplete_batch=True):
        # Yield the input in chunks of batch_size; the last, possibly smaller, chunk is kept by default.
        batch = []
        for sample in iterable:
            batch.append(sample)

            if len(batch) == batch_size:
                yield batch
                batch = []

        if batch and return_incomplete_batch:
            yield batch

    def encode(self, texts, max_seq_length=_MAX_SEQ_LENGTH, batch_size=_BATCH_SIZE):
        """
        :param texts: list of strings for encoding
        :param max_seq_length: The maximum total input sequence length after WordPiece tokenization. Sequences
            longer than this will be truncated, and sequences shorter than this will be padded.
        :param batch_size: number of texts sent through the model per forward pass
        :return: numpy array of shape (len(texts), 4 * hidden_size), where each row is the concatenation of the
            [CLS] hidden states from the last four encoder layers
        """
        encodings = []
        for text_batch in self.get_batches_iter(texts, batch_size):
            texts_ids, texts_masks = self._get_indexed_texts(text_batch, max_seq_length)
            texts_tensor = torch.LongTensor(texts_ids).to(self._device)
            masks_tensor = torch.LongTensor(texts_masks).to(self._device)

            # No gradients are needed for feature extraction.
            with torch.no_grad():
                all_encoder_layers, _ = self._bert_model(
                    texts_tensor, token_type_ids=None, attention_mask=masks_tensor, output_all_encoded_layers=True)

            # Concatenate the [CLS] hidden state of each of the last four encoder layers.
            text_batch_encodings = None
            for layer in [-1, -2, -3, -4]:
                encoder_layer = all_encoder_layers[layer].cpu().detach().numpy()
                encoder_layer = encoder_layer[:, 0, :]
                if text_batch_encodings is None:
                    text_batch_encodings = encoder_layer
                else:
                    # text_batch_encodings = np.sum([text_batch_encodings, encoder_layer], axis=0)
                    text_batch_encodings = np.append(text_batch_encodings, encoder_layer, axis=1)

            encodings.extend(text_batch_encodings)

        return np.array(encodings)
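
A minimal usage sketch, assuming pytorch_pretrained_bert is installed and a CUDA device is available (pass device_type='cpu' otherwise). The example sentences and the __main__ block are illustrative additions, not part of the original paste:

if __name__ == '__main__':
    # Hypothetical example inputs, added for illustration only.
    sentences = ['The cat sat on the mat.', 'BERT produces contextual sentence embeddings.']

    encoder = BertEncoder(bert_model_name='bert-base-uncased', device_type='cuda:0')
    vectors = encoder.encode(sentences, max_seq_length=32, batch_size=32)

    # Each row concatenates the [CLS] vectors of the last four layers:
    # 4 * 768 = 3072 dimensions for bert-base-uncased.
    print(vectors.shape)  # expected: (2, 3072)

Concatenating the last four layers echoes the feature-based approach discussed in the BERT paper; the commented-out np.sum line in encode() is the element-wise summation alternative.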