Advertisement
Guest User

Untitled

a guest
Aug 22nd, 2019
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.15 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4.  
  5. class DenseBlock(nn.Module):
  6.  
  7. def __init__(self, in_channels=3, k= 32):
  8. super(DenseBlock, self).__init__()
  9.  
  10. self.conv_1x1 = nn.Conv2d(in_channels, 2*k, kernel_size=1, stride=1, padding=0, bias=False)
  11. self.conv_3x3_1 = nn.Conv2d(2*k, k//2, kernel_size=3, stride=1, padding=1, bias=False)
  12. self.conv_3x3_2 = nn.Conv2d(k//2, k//2, kernel_size=3, stride=1, padding=1, bias=False)
  13.  
  14. def forward(self, x):
  15. x1 = F.relu(self.conv_1x1(x), inplace=True) # 224x224x2k
  16. x2 = F.relu(self.conv_3x3_1(x1), inplace=True) # 224x224xk/2
  17.  
  18. x3 = F.relu(self.conv_1x1(x), inplace=True) # 224x224x2k
  19. x4 = F.relu(self.conv_3x3_1(x3), inplace=True) # 224x224xk/2
  20. x5 = F.relu(self.conv_3x3_2(x4), inplace=True) # 224x224xk/2
  21.  
  22. x_concat = torch.cat([x2, x5], 1) # 224x224xk
  23.  
  24. return x_concat
  25.  
  26.  
  27.  
  28. if __name__ == '__main__':
  29. # test dense_block
  30. dense_block = DenseBlock()
  31.  
  32. # random input of size: 224x224x3
  33. input_var = torch.randn((1,3,224,224))
  34.  
  35. output = dense_block(input_var)
  36. print(output.shape)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement