Guest User

Untitled

a guest
May 20th, 2019
78
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import torch.nn as nn
  3.  
  4. device = 'cuda' is torch.cuda.is_available() else 'cpu'
  5.  
  6. class NgramModule(nn.Module):
  7.  
  8. def __init__(self, seq_len, kernel, channels):
  9. super().__init__()
  10. self.conv = nn.Conv2d(1, channels, kernel)
  11. self.pool = nn.MaxPool1d(seq_len - kernel[0] + 1)
  12. self.relu = nn.ReLU(inplace=True)
  13.  
  14. def forward(self, x):
  15. x = self.relu(self.conv(x)).squeeze(-1)
  16. return self.pool(x).squeeze(-1)
  17.  
  18. class CNN(nn.Module):
  19.  
  20. def __init__(self, ksizes, n_channels, embedding_dim, seq_len,
  21. dropout, n_classes, pretrained):
  22. super().__init__()
  23. self.embedding = nn.Embedding.from_pretrained(pretrained, freeze=False)
  24. self.blocks = [NgramModule(seq_len, (ksize, embedding_dim), \
  25. n_channels).to(device) for ksize in ksizes]
  26. self.dropout = nn.Dropout(dropout)
  27. self.fc = nn.Linear(n_channels*len(ksizes), n_classes)
  28.  
  29. def forward(self, x):
  30. x = self.embedding(x)[:,None,:,:] # [B,1,T,E]
  31. x = torch.cat([block(x) for block in self.blocks], dim=-1).to(device)
  32. return self.fc(self.dropout(x))
RAW Paste Data