Advertisement
Guest User

model.py

a guest
Oct 15th, 2020
61
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.36 KB | None | 0 0
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import math, copy, time
  6. import torchvision.models as models
  7. from torch.nn import TransformerDecoderLayer, TransformerDecoder
  8. from torch.nn.utils.rnn import pack_padded_sequence
  9. from torch.autograd import Variable
  10.  
  11. class EncoderCNN(nn.Module):
  12.     def __init__(self, embed_size):
  13.         super(EncoderCNN, self).__init__()
  14.         resnet = models.resnet152(pretrained=True)
  15.         self.resnet = nn.Sequential(*list(resnet.children())[:-2])
  16.         self.conv1 = nn.Conv2d(2048, embed_size, 1)
  17.         self.embed_size = embed_size
  18.  
  19.         self.fine_tune()
  20.        
  21.     def forward(self, images):
  22.         features = self.resnet(images)
  23.         batch_size, _,_,_ = features.shape
  24.         features = self.conv1(features)
  25.         features = features.view(batch_size, self.embed_size, -1)
  26.         features = features.permute(2, 0, 1)
  27.  
  28.         return features
  29.  
  30.     def fine_tune(self, fine_tune=True):
  31.         for p in self.resnet.parameters():
  32.             p.requires_grad = False
  33.         for c in list(self.resnet.children())[5:]:
  34.             for p in c.parameters():
  35.                 p.requires_grad = fine_tune
  36.  
  37. class PositionEncoder(nn.Module):
  38.     def __init__(self, d_model, dropout, max_len=5000):
  39.         super(PositionEncoder, self).__init__()
  40.         self.dropout = nn.Dropout(p=dropout)
  41.  
  42.         pe = torch.zeros(max_len, d_model)
  43.         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
  44.         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
  45.         pe[:, 0::2] = torch.sin(position * div_term)
  46.         pe[:, 1::2] = torch.cos(position * div_term)
  47.         pe = pe.unsqueeze(0).transpose(0, 1)
  48.         self.register_buffer('pe', pe)
  49.  
  50.     def forward(self, x):
  51.         x = x + self.pe[:x.size(0), :]
  52.         return self.dropout(x)
  53.    
  54. class Embedder(nn.Module):
  55.     def __init__(self, vocab_size, d_model):
  56.         super().__init__()
  57.         self.embed = nn.Embedding(vocab_size, d_model)
  58.     def forward(self, x):
  59.         return self.embed(x)
  60.  
  61.  
  62. class Transformer(nn.Module):
  63.     def __init__(self, vocab_size, d_model, h, num_hidden, N, device, dropout_dec=0.1, dropout_pos=0.1):
  64.         super(Transformer, self).__init__()
  65.         decoder_layers = TransformerDecoderLayer(d_model, h, num_hidden, dropout_dec)
  66.         self.source_mask = None
  67.         self.device = device
  68.         self.d_model = d_model
  69.         self.pos_decoder = PositionalEncoder(d_model, dropout_pos)
  70.         self.decoder = TransformerDecoder(decoder_layers, N)
  71.         self.embed = Embedder(vocab_size, d_model)
  72.         self.linear = nn.Linear(d_model, vocab_size)
  73.  
  74.         self.init_weights()
  75.  
  76.     def forward(self, source, mem):
  77.         source = source.permute(1,0)
  78.         if self.source_mask is None or self.source_mask.size(0) != len(source):
  79.             self.source_mask = nn.Transformer.generate_square_subsequent_mask(self=self, sz=len(source)).to(self.device)
  80.  
  81.         source = self.embed(source)
  82.         source = source*math.sqrt(self.d_model)  
  83.         source = self.pos_decoder(source)
  84.         output = self.decoder(source, mem, self.source_mask)
  85.         output = self.linear(output)
  86.         return output
  87.  
  88.     def init_weights(self):
  89.         initrange = 0.1
  90.         self.linear.bias.data.zero_()
  91.         self.linear.weight.data.uniform_(-initrange, initrange)
  92.  
  93.     def pred(self, memory, pred_len):
  94.         batch_size = memory.size(1)
  95.         src = torch.ones((pred_len, batch_size), dtype=int) * 2
  96.         if self.source_mask is None or self.source_mask.size(0) != len(src):
  97.             self.source_mask = nn.Transformer.generate_square_subsequent_mask(self=self, sz=len(src)).to(self.device)
  98.         output = torch.ones((pred_len, batch_size), dtype=int)
  99.         src, output = src.cuda(), output.cuda()
  100.         for i in range(pred_len):
  101.             src_emb = self.embed(src) # src_len * batch size * embed size
  102.             src_emb = src_emb*math.sqrt(self.d_model)
  103.             src_emb = self.pos_decoder(src_emb)
  104.             out = self.decoder(src_emb, memory, self.source_mask)
  105.             out = out[i]
  106.             out = self.linear(out) # batch_size * vocab_size
  107.             out = out.argmax(dim=1)
  108.             if i < pred_len-1:
  109.                 src[i+1] = out
  110.             output[i] = out
  111.         return output
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement