VssA | bert | Jul 19th, 2023 (edited) | Python
from sentence_transformers import SentenceTransformer, models
from transformers import BertTokenizer
import torch

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


class BertForSTS(torch.nn.Module):
    """BERT encoder with pooling, wrapped as a SentenceTransformer,
    for semantic textual similarity (STS)."""

    def __init__(self):
        super().__init__()
        self.bert = models.Transformer('bert-base-uncased', max_seq_length=128)
        self.pooling_layer = models.Pooling(self.bert.get_word_embedding_dimension())
        self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer])

    def forward(self, input_data):
        # Return one pooled sentence embedding per input sentence.
        return self.sts_bert(input_data)['sentence_embedding']


def predict_similarity(sentence_pair):
    """Encode a pair of sentences and return their cosine similarity."""
    test_input = tokenizer(sentence_pair, padding='max_length', max_length=128,
                           truncation=True, return_tensors="pt").to(device)
    # The Transformer module only needs input_ids and attention_mask.
    del test_input['token_type_ids']
    with torch.no_grad():
        output = model(test_input)
    return torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item()


if __name__ == '__main__':
    PATH = 'bert-sts.pt'
    model = BertForSTS()
    model.load_state_dict(torch.load(PATH, map_location=device))
    model.to(device)
    model.eval()
    # Test pair in Russian: "I want to tell this Obama" / "I want to tell this Biden".
    first_text = ['хочу сказать этому обэме', 'хочу сказать этому байдену']
    print(predict_similarity(first_text))
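
The script above only loads 'bert-sts.pt'; the paste does not show how that checkpoint was produced. Below is a minimal, hypothetical sketch of one way such a state_dict could be fine-tuned, assuming pairs of sentences with gold similarity scores in [0, 1]. The train_sts name, hyperparameters, and MSE-on-cosine loss are illustrative assumptions, not the original author's setup; it reuses BertForSTS, tokenizer, and device from the script above.

# Hypothetical fine-tuning sketch (assumed, not from the original paste).
# pairs: list of (sentence_a, sentence_b); scores: gold similarities in [0, 1].
def train_sts(model, pairs, scores, epochs=1, batch_size=16, lr=2e-5):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    model.to(device)
    model.train()
    for _ in range(epochs):
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i + batch_size]
            gold = torch.tensor(scores[i:i + batch_size], dtype=torch.float, device=device)
            # Tokenize each side of the pair separately, as in predict_similarity.
            left = tokenizer([a for a, _ in batch], padding='max_length', max_length=128,
                             truncation=True, return_tensors='pt').to(device)
            right = tokenizer([b for _, b in batch], padding='max_length', max_length=128,
                              truncation=True, return_tensors='pt').to(device)
            del left['token_type_ids'], right['token_type_ids']
            # Predicted similarity = cosine of the two pooled embeddings.
            pred = torch.nn.functional.cosine_similarity(model(left), model(right), dim=1)
            loss = loss_fn(pred, gold)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    torch.save(model.state_dict(), 'bert-sts.pt')

An off-the-shelf alternative would be to fine-tune the inner SentenceTransformer directly with sentence_transformers' fit() and CosineSimilarityLoss, then save the wrapper's state_dict the same way.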