Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class DenseBlock(nn.Module):
    """Two-branch convolutional block whose branch outputs are concatenated.

    Both branches start from the same 1x1 bottleneck followed by the same
    3x3 convolution (the layers, and therefore the weights, are shared):

    * short branch: ``conv_1x1 -> conv_3x3_1``
    * long branch:  ``conv_1x1 -> conv_3x3_1 -> conv_3x3_2``

    Every convolution uses stride 1 and "same" padding, so the spatial size
    of the input is preserved. Each branch yields ``k // 2`` channels and the
    result has ``2 * (k // 2)`` channels (i.e. ``k`` when ``k`` is even).

    NOTE(review): if the two branches were intended to have independent
    weights, separate conv layers would be required -- confirm with the
    author. This implementation keeps the shared-weight behaviour.

    Args:
        in_channels: Number of channels of the input tensor.
        k: Growth rate; controls the bottleneck width (``2 * k``) and the
            per-branch output width (``k // 2``).
    """

    def __init__(self, in_channels=3, k=32):
        super(DenseBlock, self).__init__()
        # 1x1 bottleneck: in_channels -> 2k channels.
        self.conv_1x1 = nn.Conv2d(
            in_channels, 2 * k, kernel_size=1, stride=1, padding=0, bias=False
        )
        # First 3x3 conv: 2k -> k//2 channels.
        self.conv_3x3_1 = nn.Conv2d(
            2 * k, k // 2, kernel_size=3, stride=1, padding=1, bias=False
        )
        # Second 3x3 conv (long branch only): k//2 -> k//2 channels.
        self.conv_3x3_2 = nn.Conv2d(
            k // 2, k // 2, kernel_size=3, stride=1, padding=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both branches and concatenate them along the channel axis.

        Args:
            x: Input of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, 2 * (k // 2), H, W)``: the short-branch
            features followed by the long-branch features.
        """
        # The in-place ReLU acts on the fresh conv output, never on ``x``.
        bottleneck = F.relu(self.conv_1x1(x), inplace=True)            # (N, 2k, H, W)
        short = F.relu(self.conv_3x3_1(bottleneck), inplace=True)      # (N, k//2, H, W)
        # The long branch shares its first two layers with the short branch,
        # so their activations are identical; reuse ``short`` rather than
        # recomputing the 1x1 and first 3x3 convolutions a second time.
        long = F.relu(self.conv_3x3_2(short), inplace=True)            # (N, k//2, H, W)
        return torch.cat([short, long], 1)
if __name__ == '__main__':
    # Smoke test: build a DenseBlock with its default arguments, feed it a
    # single random 3-channel 224x224 image and report the output shape.
    dense_block = DenseBlock()
    input_var = torch.randn(1, 3, 224, 224)
    print(dense_block(input_var).shape)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement