Advertisement
Guest User

Untitled

a guest
Feb 17th, 2020
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.80 KB | None | 0 0
  import chainer
  import chainer.functions as F
  import chainer.links as L
  import torch
  import torch.nn as nn
  class ResidualBlock(nn.Module):
      """Gated two-stream (vertical / horizontal) convolution block.

      Each stream computes a ``tanh`` branch (``*_t``) and a ``sigmoid`` branch
      (``*_s``) and multiplies them. Before the horizontal stream is gated, the
      vertical branches are passed through 1x1 convolutions and added to the
      matching horizontal branches.

      Args:
          include_center: Forwarded unchanged as the first argument of every
              ``MaskedCNN`` in the horizontal stream.
          in_channels: Channel count of both inputs ``v`` and ``h``.
          out_channels: Channel count of both outputs.
          filter_size: Kernel width. Vertical kernels are
              ``filter_size // 2 + 1`` rows tall. Assumed to be odd; with an
              even value the padded convolutions no longer preserve the width.
          num_classes: Number of rows in the label embedding table.
      """

      def __init__(self, include_center, in_channels, out_channels, filter_size,
                   num_classes=10):
          super().__init__()
          half = filter_size // 2

          # Vertical stream: the kernel covers `half + 1` rows and the input is
          # padded by the same amount, so the raw output is taller than the
          # input. `forward` crops it back to the input height.
          vertical_kwargs = dict(
              kernel_size=(half + 1, filter_size),
              padding=(half + 1, half),
          )
          self.vertical_conv_t = nn.Conv2d(in_channels, out_channels, **vertical_kwargs)
          self.vertical_conv_s = nn.Conv2d(in_channels, out_channels, **vertical_kwargs)

          # 1x1 projections that feed the vertical stream into the horizontal one.
          self.v_to_h_conv_t = nn.Conv2d(out_channels, out_channels, 1)
          self.v_to_h_conv_s = nn.Conv2d(out_channels, out_channels, 1)

          # Horizontal stream: single-row masked convolutions.
          horizontal_kwargs = dict(
              in_channels=in_channels,
              out_channels=out_channels,
              kernel_size=[1, filter_size],
              padding=[0, half],
          )
          self.horizontal_conv_t = MaskedCNN(include_center, **horizontal_kwargs)
          self.horizontal_conv_s = MaskedCNN(include_center, **horizontal_kwargs)
          self.horizontal_output = MaskedCNN(
              include_center,
              in_channels=out_channels,
              out_channels=out_channels,
              kernel_size=1,
          )

          # Embedding used for optional label conditioning in `forward`.
          # nn.Embedding (rather than a Chainer link) so that it is registered
          # as a submodule of this nn.Module.
          self.label = nn.Embedding(num_classes, out_channels)

      def forward(self, v, h, label=None):
          """Run both streams once.

          Args:
              v: Vertical-stream input, indexed as (batch, channels, rows, cols).
              h: Horizontal-stream input with the same layout as ``v``.
              label: Optional tensor of integer indices, one per batch element.
                  When given, its embedding is added to both vertical branches
                  before gating. ``None`` (the default) disables conditioning.

          Returns:
              Tuple ``(v, h)`` with the updated vertical and horizontal streams.
          """
          rows = v.size(2)

          # Keep only the first `rows` rows of the over-padded vertical
          # convolutions: this restores the input height so the result can be
          # added to the horizontal stream, and leaves every output row
          # depending only on input rows strictly above it.
          v_t = self.vertical_conv_t(v)[:, :, :rows, :]
          v_s = self.vertical_conv_s(v)[:, :, :rows, :]

          # Projections are taken before any label conditioning is applied.
          to_horizontal_t = self.v_to_h_conv_t(v_t)
          to_horizontal_s = self.v_to_h_conv_s(v_s)

          if label is not None:
              # (batch, channels) -> (batch, channels, 1, 1) so it broadcasts
              # over the spatial dimensions.
              cond = self.label(label)[:, :, None, None]
              v_t = v_t + cond
              v_s = v_s + cond

          v = torch.tanh(v_t) * torch.sigmoid(v_s)

          h_t = self.horizontal_conv_t(h) + to_horizontal_t
          h_s = self.horizontal_conv_s(h) + to_horizontal_s
          h = self.horizontal_output(torch.tanh(h_t) * torch.sigmoid(h_s))

          return v, h
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement