SHARE
TWEET

Untitled

a guest Aug 20th, 2019 69 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
import torch
from pytorch_transformers import BertConfig
from pytorch_transformers import BertModel
from torch import nn

  4. config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
  5.         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
  6.        
  7. class BertForSequenceClassification(nn.Module):
  8.     def __init__(self, num_labels=2):
  9.         super(BertForSequenceClassification, self).__init__()
  10.         self.num_labels = num_labels
  11.         self.bert = BertModel.from_pretrained('bert_based_uncased')
  12.         self.dropout = nn.Dropout(config.hidden_dropout_prob)
  13.         self.classifier = nn.Linear(config.hidden_size, num_labels)
  14.  
  15.         nn.init.xavier_normal_(self.classifier.weight)
  16.  
  17.     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
  18.         _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
  19.         pooled_output = self.dropout(pooled_output)
  20.         logits = self.classifier(pooled_output)
  21.    
  22.     def freeze_bert_encoder(self):
  23.         for param in self.bert.parameters():
  24.             param.requires_grad = False
  25.    
  26.     def unfreeze_bert_encoder(self):
  27.         for param in self.bert.parameters():
  28.             param.requires_grad = True
  29.        
  30. num_labels = 11
  31. model = BertForSequenceClassification(num_labels)
  32. model = torch.nn.DataParallel(model)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top