Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
class TransitionLayer(nn.Module):
    """Downsampling block: 1x1 convolution, ReLU, then 2x2 max-pooling.

    The bias-free 1x1 convolution re-projects the channels.
    ``nn.MaxPool2d(kernel_size=2, stride=2)`` then halves the spatial size:
    (H, W) -> (H // 2, W // 2). Odd sizes are floored because ``ceil_mode``
    is left at its default of ``False``.

    Args:
        in_channels: Number of channels in the input tensor. Defaults to 3.
        out_channels: Number of channels produced by the 1x1 convolution.
            Defaults to ``in_channels``, which keeps the original
            channel-preserving behaviour. Pass another value to expand or
            compress the channel dimension.
    """

    def __init__(self, in_channels=3, out_channels=None):
        super().__init__()
        if out_channels is None:
            # Backward-compatible default: same number of channels in and out.
            out_channels = in_channels
        self.conv_1x1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False
        )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply conv -> ReLU -> max-pool to ``x``.

        Args:
            x: Input tensor, expected as (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H // 2, W // 2).
        """
        # In-place ReLU is safe here: the conv output is a fresh intermediate
        # tensor that nothing else holds a reference to.
        x = F.relu(self.conv_1x1(x), inplace=True)  # (N, out_channels, H, W)
        return self.pool(x)  # (N, out_channels, H // 2, W // 2)
if __name__ == '__main__':
    # Smoke test: build a default TransitionLayer, push one random NCHW batch
    # (N=1, C=3, H=W=224) through it, and print the resulting tensor shape.
    layer = TransitionLayer()
    sample = torch.randn(1, 3, 224, 224)
    print(layer(sample).shape)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement