from pytorch_transformers import BertConfig
from pytorch_transformers import BertModel
import torch
import torch.nn as nn

# Configuration matching bert-base-uncased; used below for hidden_size and
# hidden_dropout_prob when building the classification head.
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12,
                    intermediate_size=3072)

class BertForSequenceClassification(nn.Module):
    def __init__(self, num_labels=2):
        super(BertForSequenceClassification, self).__init__()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

        nn.init.xavier_normal_(self.classifier.weight)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        # BertModel returns (sequence_output, pooled_output); the pooled
        # [CLS] representation feeds the classification head.
        _, pooled_output = self.bert(input_ids, token_type_ids=token_type_ids,
                                     attention_mask=attention_mask)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

    def freeze_bert_encoder(self):
        # Stop gradients through BERT so only the classifier head trains.
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True

num_labels = 11
model = BertForSequenceClassification(num_labels)
model = torch.nn.DataParallel(model)
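
For context, a minimal sketch of how this model might be exercised end to end. The example sentence, the gold label, and the batch shapes are illustrative assumptions and not part of the original paste; the tokenizer calls are the standard pytorch_transformers API.

from pytorch_transformers import BertTokenizer
import torch.nn.functional as F

# Illustrative usage sketch (hypothetical inputs).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokens = ['[CLS]'] + tokenizer.tokenize("Example sentence to classify.") + ['[SEP]']
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])  # shape (1, seq_len)
attention_mask = torch.ones_like(input_ids)                          # single sequence, no padding

# Optionally train the classifier head alone before fine-tuning the encoder;
# .module reaches through the DataParallel wrapper.
model.module.freeze_bert_encoder()

logits = model(input_ids, attention_mask=attention_mask)   # shape (1, num_labels)
loss = F.cross_entropy(logits, torch.tensor([3]))          # hypothetical gold label 3
loss.backward()

model.module.unfreeze_bert_encoder()  # later, fine-tune the full encoder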