import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class NgramModule(nn.Module):
    """One convolutional branch: detects n-grams of a single width, then
    max-pools each feature map over time down to a single value."""

    def __init__(self, seq_len, kernel, channels):
        super().__init__()
        # kernel = (ksize, embedding_dim): spans ksize tokens and the full embedding
        self.conv = nn.Conv2d(1, channels, kernel)
        # pool over the remaining time dimension (seq_len - ksize + 1 positions)
        self.pool = nn.MaxPool1d(seq_len - kernel[0] + 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: [B, 1, T, E] -> conv: [B, C, T-k+1, 1] -> squeeze: [B, C, T-k+1]
        x = self.relu(self.conv(x)).squeeze(-1)
        # pool: [B, C, 1] -> squeeze: [B, C]
        return self.pool(x).squeeze(-1)


class CNN(nn.Module):
    """Text CNN: parallel n-gram branches over pretrained embeddings,
    concatenated and fed through dropout into a linear classifier."""

    def __init__(self, ksizes, n_channels, embedding_dim, seq_len,
                 dropout, n_classes, pretrained):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(pretrained, freeze=False)
        # nn.ModuleList (rather than a plain Python list) registers the branches
        # as submodules, so model.to(device) and the optimizer see their parameters
        self.blocks = nn.ModuleList([
            NgramModule(seq_len, (ksize, embedding_dim), n_channels)
            for ksize in ksizes
        ])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(n_channels * len(ksizes), n_classes)

    def forward(self, x):
        x = self.embedding(x)[:, None, :, :]  # [B, 1, T, E]
        # each branch yields [B, n_channels]; concatenate to [B, n_channels*len(ksizes)]
        x = torch.cat([block(x) for block in self.blocks], dim=-1)
        return self.fc(self.dropout(x))
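For reference, a minimal usage sketch follows. It is not part of the original paste: the vocabulary size, sequence length, hyperparameter values, and the random matrix standing in for real pretrained word vectors are all illustrative assumptions.

# --- Usage sketch (illustrative; values below are assumptions, not from the paste) ---
vocab_size, embedding_dim, seq_len = 1000, 50, 64
pretrained = torch.randn(vocab_size, embedding_dim)  # stand-in for real word vectors

model = CNN(ksizes=[3, 4, 5], n_channels=100, embedding_dim=embedding_dim,
            seq_len=seq_len, dropout=0.5, n_classes=2,
            pretrained=pretrained).to(device)

batch = torch.randint(0, vocab_size, (8, seq_len), device=device)  # [B, T] token ids
logits = model(batch)
print(logits.shape)  # torch.Size([8, 2])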