Not a member of Pastebin yet? Sign up — it unlocks many cool features!
- from __future__ import print_function, absolute_import
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torchvision.models as models
class FPN(nn.Module):
    """Feature Pyramid Network built on a torchvision ResNet-family backbone.

    The backbone stage outputs C5, C4, C3, C2 are projected to
    ``feature_size`` channels by 1x1 lateral convolutions (followed by ReLU),
    merged top-down by bilinear upsampling and addition, and then passed
    through one ``kernel_size`` x ``kernel_size`` "top" convolution per level.

    Args:
        backbone: name of a constructor in ``torchvision.models``. ``forward``
            relies on the ResNet attribute layout (``conv1``, ``bn1``,
            ``relu``, ``maxpool``, ``layer1``..``layer4``), and ``__init__``
            reads the stage width from ``fc.in_features``.
        feature_size: number of channels of every pyramid level.
        kernel_size: kernel size of the top convolutions. Padding is
            ``kernel_size // 2``, so odd sizes preserve the spatial size.
        use_bias: when True, the biases of the lateral and top convolutions
            are zero-initialized. NOTE(review): the convolutions are always
            created with a bias term; with False the bias keeps PyTorch's
            default initialization instead of being removed — confirm intent.
        pretrained: forwarded to the torchvision constructor (True downloads
            pretrained weights). Defaults to True, the previous behavior.
    """

    def __init__(self,
                 backbone='resnet50',
                 feature_size=256,
                 kernel_size=3,
                 use_bias=True,
                 pretrained=True):
        super(FPN, self).__init__()
        self.fs = feature_size
        self.ks = kernel_size
        print('Building %s model' % backbone)
        self.backbone = models.__dict__[backbone](pretrained=pretrained)
        # Output width of the last stage (layer4): 2048 for resnet50, but not
        # for every ResNet, so read it instead of hard-coding it. Each earlier
        # stage has half the channels of the next one. Integer division is
        # required because nn.Conv2d rejects float channel counts.
        self.backbone_c5_dim = self.backbone.fc.in_features
        self.backbone_dims = [self.backbone_c5_dim // (2 ** i) for i in range(4)]
        # Lateral connections of FPN: 1x1 projections of C5, C4, C3, C2.
        self.lateral = nn.ModuleList(
            nn.Conv2d(d, self.fs, kernel_size=1, stride=1, padding=0)
            for d in self.backbone_dims
        )
        # Top connections of FPN: one smoothing conv per pyramid level.
        self.top = nn.ModuleList(
            nn.Conv2d(self.fs, self.fs, kernel_size=self.ks, stride=1,
                      padding=self.ks // 2)
            for _ in self.backbone_dims
        )
        self.relu = nn.ReLU()
        self._initialize(self.top, bias=use_bias)
        self._initialize(self.lateral, bias=use_bias)

    def _initialize(self, modules, bias=True):
        """Xavier-normal init every ``nn.Conv2d`` in ``modules``.

        Non-conv modules are skipped. When ``bias`` is True, conv biases that
        exist are set to zero.
        """
        for module in modules:
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_normal_(module.weight)
                # Convs built with bias=False have module.bias == None.
                if bias and module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    def _upsample(self, x, y):
        """Bilinearly resize ``x`` to the spatial size (H, W) of ``y``."""
        _, _, H, W = y.size()
        # align_corners=False is what F.upsample used implicitly; being
        # explicit keeps the same result without the deprecation warning.
        return F.interpolate(x, size=(H, W), mode='bilinear',
                             align_corners=False)

    def forward(self, input):
        """Return the pyramid ``[P5, P4, P3, P2]``, ordered coarse to fine.

        Every level has ``feature_size`` channels. Level i has the spatial
        size of the corresponding backbone stage output.
        """
        x = self.backbone.conv1(input)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        c1 = self.backbone.maxpool(x)
        c2 = self.backbone.layer1(c1)
        c3 = self.backbone.layer2(c2)
        c4 = self.backbone.layer3(c3)
        c5 = self.backbone.layer4(c4)
        # Backbone stage outputs, coarsest first (nominal strides 32/16/8/4
        # for a standard torchvision ResNet).
        c = [c5, c4, c3, c2]
        # FPN P-layers
        p = []
        p_up = None  # upsampled merged map from the previous (coarser) level
        for i, (lateral, top) in enumerate(zip(self.lateral, self.top)):
            _p = self.relu(lateral(c[i]))
            if i > 0:
                _p = p_up + _p
            if i < len(c) - 1:
                # Resize the merged (pre-top-conv) map to the next finer level.
                p_up = self._upsample(_p, c[i + 1])
            p.append(top(_p))
        return p
def fpn(weights, **kwargs):
    """Build an ``FPN`` and optionally load a saved checkpoint into it.

    Args:
        weights: path to a checkpoint saved with ``torch.save`` whose
            ``'state_dict'`` entry holds the model parameters. Pass a falsy
            value (e.g. ``None``) to skip loading.
        **kwargs: forwarded unchanged to ``FPN``.

    Returns:
        The constructed ``FPN`` module.

    Raises:
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
    """
    model = FPN(**kwargs)
    if weights:
        # Load onto CPU so checkpoints saved on a GPU also work on CPU-only
        # hosts; load_state_dict copies the tensors into the model's params.
        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints.
        checkpoint = torch.load(weights, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
    return model
if __name__ == '__main__':
    # Smoke test: push a random batch through the default FPN and report the
    # shape of each pyramid level. Building the model downloads the
    # pretrained backbone weights on first use.
    model = FPN()
    model.eval()
    # randn instead of torch.Tensor(...): the latter is uninitialized memory
    # and can contain NaN/inf. Variable wrapping is no longer needed.
    x = torch.randn(4, 3, 256, 256)
    with torch.no_grad():
        out = model(x)
    for level, feat in zip(('P5', 'P4', 'P3', 'P2'), out):
        print(level, tuple(feat.shape))
Add Comment
Please sign in to add a comment.