Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
import chainer
import chainer.functions as F
import chainer.links as L
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Gated-PixelCNN-style residual block with a vertical and a horizontal stack.

    Both inputs are NCHW feature maps with the same spatial size. This is
    required by the ``horizontal + vertical`` sums in :meth:`forward`.

    * Vertical stack: a ``(filter_size // 2 + 1, filter_size)`` convolution
      whose output is cropped back to the input height. Output row ``i``
      then depends only on input rows strictly above ``i``.
    * Horizontal stack: ``MaskedCNN`` convolutions along the current row plus
      a 1x1 projection of the vertical features. These are followed by a
      gated ``tanh * sigmoid`` activation and a 1x1 ``MaskedCNN`` output layer.

    Args:
        include_center: Passed unchanged to every ``MaskedCNN``. Presumably it
            chooses whether the centre pixel is visible (mask "B") or not
            (mask "A"). TODO: confirm against ``MaskedCNN``.
        in_channels: Channels of the incoming ``v`` and ``h`` maps.
        out_channels: Channels of the returned ``v`` and ``h`` maps.
        filter_size: Kernel width. It should be odd so that the
            ``filter_size // 2`` padding preserves the width of the maps.
        num_classes: Size of the label embedding used by the optional class
            conditioning in :meth:`forward`. Defaults to 10, the value that was
            previously hard-coded.
    """

    def __init__(self, include_center, in_channels, out_channels, filter_size,
                 num_classes=10):
        super().__init__()
        # The kernel is filter_size//2 + 1 rows tall and the input is padded by
        # filter_size//2 + 1 zero rows. After cropping the output to the first H
        # rows (done in forward), row i sees only input rows
        # i - filter_size//2 - 1 .. i - 1.
        v_kernel = [filter_size // 2 + 1, filter_size]
        v_padding = [filter_size // 2 + 1, filter_size // 2]
        self.vertical_conv_t = nn.Conv2d(
            in_channels, out_channels, kernel_size=v_kernel, padding=v_padding)
        self.vertical_conv_s = nn.Conv2d(
            in_channels, out_channels, kernel_size=v_kernel, padding=v_padding)
        # 1x1 projections that feed vertical-stack features into the horizontal stack.
        self.v_to_h_conv_t = nn.Conv2d(out_channels, out_channels, 1)
        self.v_to_h_conv_s = nn.Conv2d(out_channels, out_channels, 1)
        self.horizontal_conv_t = MaskedCNN(
            include_center, in_channels=in_channels, out_channels=out_channels,
            kernel_size=[1, filter_size], padding=[0, filter_size // 2])
        self.horizontal_conv_s = MaskedCNN(
            include_center, in_channels=in_channels, out_channels=out_channels,
            kernel_size=[1, filter_size], padding=[0, filter_size // 2])
        self.horizontal_output = MaskedCNN(
            include_center, in_channels=out_channels, out_channels=out_channels,
            kernel_size=1)
        # A Chainer link (L.EmbedID) is not registered by nn.Module, so its
        # weights would be invisible to .parameters(), .to() and state_dict().
        self.label = nn.Embedding(num_classes, out_channels)

    def forward(self, v, h, label=None):
        """Apply the block to the vertical and horizontal stacks.

        Args:
            v: Vertical-stack input of shape ``(N, in_channels, H, W)``.
            h: Horizontal-stack input with the same ``H`` and ``W`` as ``v``.
            label: Optional integer tensor of shape ``(N,)`` with class ids in
                ``[0, num_classes)``. When given, its embedding is added as a
                per-channel bias to the vertical-stack pre-activations.

        Returns:
            Tuple ``(v, h)`` of the new vertical and horizontal feature maps.
        """
        height = v.shape[2]
        # Keep only the first H rows of the padded convolution. This restores
        # the input height, so the sums with the horizontal stack line up. It
        # also keeps the vertical stack strictly above the current row.
        v_t = self.vertical_conv_t(v)[:, :, :height]
        v_s = self.vertical_conv_s(v)[:, :, :height]
        to_vertical_t = self.v_to_h_conv_t(v_t)
        to_vertical_s = self.v_to_h_conv_s(v_s)

        if label is not None:
            # Reshape (N, C) to (N, C, 1, 1) so the bias broadcasts over every pixel.
            bias = self.label(label)[:, :, None, None]
            v_t, v_s = v_t + bias, v_s + bias

        v = torch.tanh(v_t) * torch.sigmoid(v_s)

        h_t = self.horizontal_conv_t(h) + to_vertical_t
        h_s = self.horizontal_conv_s(h) + to_vertical_s
        h = self.horizontal_output(torch.tanh(h_t) * torch.sigmoid(h_s))
        return v, h
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement