import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class NgramModule(nn.Module):
    """One n-gram branch: conv across the full embedding width, then max-over-time pooling."""
    def __init__(self, seq_len, kernel, channels):
        super().__init__()
        # kernel is (ksize, embedding_dim), so each filter spans the whole embedding
        self.conv = nn.Conv2d(1, channels, kernel)
        # pool over all seq_len - ksize + 1 positions (max-over-time)
        self.pool = nn.MaxPool1d(seq_len - kernel[0] + 1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: [B, 1, T, E] -> conv -> [B, C, T-k+1, 1] -> squeeze -> [B, C, T-k+1]
        x = self.relu(self.conv(x)).squeeze(-1)
        # max-over-time: [B, C, 1] -> [B, C]
        return self.pool(x).squeeze(-1)

class CNN(nn.Module):
    def __init__(self, ksizes, n_channels, embedding_dim, seq_len,
                 dropout, n_classes, pretrained):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(pretrained, freeze=False)
        # nn.ModuleList (not a plain Python list) so the branches are registered
        # as submodules and their parameters move with the model on .to(device)
        self.blocks = nn.ModuleList(
            NgramModule(seq_len, (ksize, embedding_dim), n_channels)
            for ksize in ksizes
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(n_channels * len(ksizes), n_classes)

    def forward(self, x):
        x = self.embedding(x)[:, None, :, :]  # [B, T] -> [B, 1, T, E]
        # concatenate the pooled features from every n-gram branch: [B, C*len(ksizes)]
        x = torch.cat([block(x) for block in self.blocks], dim=-1)
        return self.fc(self.dropout(x))
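For reference, a minimal smoke test of the model above. The hyperparameters (vocabulary size, sequence length, kernel sizes, batch size) are arbitrary placeholders, and the "pretrained" matrix is random noise standing in for real word vectors.

# Made-up hyperparameters; a random matrix stands in for pretrained embeddings.
vocab_size, embedding_dim, seq_len = 1000, 50, 20
pretrained = torch.randn(vocab_size, embedding_dim)

model = CNN(ksizes=[3, 4, 5], n_channels=100, embedding_dim=embedding_dim,
            seq_len=seq_len, dropout=0.5, n_classes=2,
            pretrained=pretrained).to(device)

# Batch of 8 token-id sequences of length seq_len.
tokens = torch.randint(0, vocab_size, (8, seq_len), device=device)  # [B, T]
logits = model(tokens)
print(logits.shape)  # torch.Size([8, 2])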