import torch
import torch.nn as nn
from pytorch_transformers import RobertaModel


class CustomRobertatModel(nn.Module):
    def __init__(self, num_labels=2):
        super(CustomRobertatModel, self).__init__()
        self.num_labels = num_labels
        # Pretrained RoBERTa encoder with a dropout + linear classification head
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        self.dropout = nn.Dropout(0.05)
        self.classifier = nn.Linear(768, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        # pytorch_transformers' RobertaModel expects attention_mask before token_type_ids,
        # so pass them as keywords; the second output is the pooled <s> representation.
        _, pooled_output = self.roberta(input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # labels are accepted for API compatibility but no loss is computed here
        return logits
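
A minimal usage sketch, not part of the original paste: it assumes RobertaTokenizer from the same pytorch_transformers release, and the example sentence, batch of size 1, and all-ones attention mask are illustrative only.

from pytorch_transformers import RobertaTokenizer

# Hypothetical smoke test for the classifier above (assumed setup, not from the paste)
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = CustomRobertatModel(num_labels=2)
model.eval()

ids = tokenizer.encode("A quick smoke test sentence.", add_special_tokens=True)
input_ids = torch.tensor([ids])               # shape: (batch_size=1, seq_len)
attention_mask = torch.ones_like(input_ids)   # no padding in this single example

with torch.no_grad():
    logits = model(input_ids, attention_mask=attention_mask)
print(logits.shape)  # torch.Size([1, 2])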