from pytorch_transformers import BertConfig, BertModel
import torch
import torch.nn as nn

# Standalone config used only to size the classifier head below;
# the encoder weights themselves come from from_pretrained.
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

class BertForSequenceClassification(nn.Module):
    def __init__(self, num_labels=2):
        super(BertForSequenceClassification, self).__init__()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        nn.init.xavier_normal_(self.classifier.weight)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        # BertModel returns (sequence_output, pooled_output); keyword arguments
        # keep the call robust to argument-order differences between versions.
        _, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids,
                                     attention_mask=attention_mask)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # `labels` is accepted for interface compatibility; the loss is computed by the caller.
        return logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
num_labels = 11
model = BertForSequenceClassification(num_labels)
model = torch.nn.DataParallel(model)
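
# --- Usage sketch (an assumption, not part of the original paste) ---
# Hypothetical end-to-end check: tokenize one sentence, run a forward pass, and
# take a single training step with the encoder frozen. The sentence, the
# BCEWithLogitsLoss choice (treating the 11 labels as multi-label targets), and
# the Adam learning rate are all illustrative assumptions.
from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokens = tokenizer.tokenize('[CLS] an example sentence to classify [SEP]')
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

# Forward pass: logits have shape (1, num_labels); a sigmoid gives per-label scores.
model.eval()
with torch.no_grad():
    probs = torch.sigmoid(model(input_ids))

# One training step with BERT frozen, updating only the classifier head.
# DataParallel wraps the network, so the freeze helpers live on model.module.
model.module.freeze_bert_encoder()
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)  # frozen params get no grads and are skipped
criterion = torch.nn.BCEWithLogitsLoss()
targets = torch.zeros(1, num_labels)  # dummy multi-label target vector
loss = criterion(model(input_ids), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.module.unfreeze_bert_encoder()  # unfreeze later for full fine-tuning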