Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
import torch
import torch.nn as nn

from pytorch_transformers import RobertaModel
class CustomRobertatModel(nn.Module):
    """RoBERTa encoder with a linear classification head.

    Wraps a pretrained ``roberta-base`` encoder and maps its pooled
    output through dropout and a linear layer to ``num_labels`` logits.
    """

    def __init__(self, num_labels=2):
        """
        Args:
            num_labels: number of output classes for the classifier head.
        """
        super(CustomRobertatModel, self).__init__()
        self.num_labels = num_labels
        # NOTE(review): downloads pretrained weights on first use —
        # requires network access / a local cache.
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        self.dropout = nn.Dropout(.05)
        # 768 is the hidden size of the roberta-base checkpoint;
        # assumes that checkpoint is the one loaded above.
        self.classifier = nn.Linear(768, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        """Return classification logits of shape ``(batch, num_labels)``.

        Args:
            input_ids: token id tensor of shape ``(batch, seq_len)``.
            token_type_ids: optional segment ids (RoBERTa largely ignores these).
            attention_mask: optional padding mask.
            labels: accepted for signature compatibility but unused here;
                compute the loss (e.g. ``nn.CrossEntropyLoss``) outside.
        """
        # Pass by keyword: positional order of token_type_ids vs.
        # attention_mask differs between pytorch_transformers and later
        # transformers versions, so positional args can silently misbind.
        _, pooled_output = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Bug fix: dropout was defined in __init__ but never applied.
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement