Advertisement
Guest User

Untitled

a guest
Aug 22nd, 2019
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.73 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4.  
  5. class TransitionLayer(nn.Module):
  6.  
  7. def __init__(self, in_channels=3):
  8. super(TransitionLayer, self).__init__()
  9.  
  10. self.conv_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, bias=False)
  11. self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
  12.  
  13. def forward(self, x):
  14. x = F.relu(self.conv_1x1(x), inplace=True) # 224x224x3
  15. x = self.pool(x) # 112x112x3
  16.  
  17. return x
  18.  
  19.  
  20.  
  21. if __name__ == '__main__':
  22. # test transition layer
  23. transition = TransitionLayer()
  24.  
  25. # random input of size: 224x224x3
  26. input_var = torch.randn((1,3,224,224))
  27.  
  28. output = transition(input_var)
  29. print(output.shape)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement