Guest User

Untitled

a guest
Jun 19th, 2018
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.79 KB | None | 0 0
  1. from __future__ import print_function, absolute_import
  2.  
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. import torchvision.models as models
  7.  
  8. class FPN(nn.Module):
  9. def __init__(self,
  10. backbone='resnet50',
  11. feature_size=256,
  12. kernel_size=3,
  13. use_bias=True):
  14.  
  15. super(FPN, self).__init__()
  16. self.fs = feature_size
  17. self.ks = kernel_size
  18. self.backbone_c5_dim = 2048
  19. self.backbone_dims = [self.backbone_c5_dim,
  20. self.backbone_c5_dim / 2,
  21. self.backbone_c5_dim / 4,
  22. self.backbone_c5_dim / 8]
  23.  
  24. print('Building %s model'%backbone)
  25.  
  26. self.backbone = models.__dict__[backbone](pretrained=True)
  27.  
  28. self.lateral = nn.ModuleList()
  29. # Lateral connections of FPN
  30. for d in self.backbone_dims:
  31. self.lateral.append(
  32. nn.Conv2d(d, self.fs, kernel_size=1, stride=1, padding=0)
  33. )
  34.  
  35. # Top connections of FPN
  36. self.top = nn.ModuleList()
  37. for _ in self.backbone_dims:
  38. self.top.append(
  39. nn.Conv2d(self.fs, self.fs, kernel_size=3, stride=1, padding=1)
  40. )
  41. self.relu = nn.ReLU()
  42. self._initialize(self.top, bias=use_bias)
  43. self._initialize(self.lateral, bias=use_bias)
  44.  
  45. def _initialize(self, modules, bias=True):
  46. for param in modules:
  47. if isinstance(param, nn.Conv2d):
  48. nn.init.xavier_normal(param.weight)
  49. if bias:
  50. nn.init.constant(param.bias, 0.0)
  51.  
  52. def _upsample(self, x, y):
  53. _, _, H, W = y.size()
  54. return F.upsample(x, size=(H, W), mode='bilinear')
  55.  
  56.  
  57. def forward(self, input):
  58.  
  59. x = self.backbone.conv1(input)
  60. x = self.backbone.bn1(x)
  61. x = self.backbone.relu(x)
  62. c1 = self.backbone.maxpool(x)
  63.  
  64. c2 = self.backbone.layer1(c1)
  65. c3 = self.backbone.layer2(c2)
  66. c4 = self.backbone.layer3(c3)
  67. c5 = self.backbone.layer4(c4)
  68.  
  69. c = [c5, c4, c3, c2] # These are the intermediate outputs of backbone => stride 2^n
  70.  
  71. # FPN P-layers
  72. p = []
  73. p_up = None
  74. for i in range(4):
  75. _p = self.lateral[i](c[i])
  76. _p = self.relu(_p)
  77. if i > 0:
  78. _p = p_up + _p
  79. if i < len(c) - 1:
  80. p_up = self._upsample(_p, c[i + 1])
  81. _p = self.top[i](_p)
  82. p.append(_p)
  83.  
  84. return p
  85.  
  86.  
  87. def fpn(weights, **kwargs):
  88. model = FPN(**kwargs)
  89. if weights:
  90. model.load_state_dict(torch.load(weights)['state_dict'])
  91. return model
  92.  
  93.  
  94. if __name__ == '__main__':
  95. model = FPN()
  96. x = torch.autograd.Variable(torch.Tensor(4, 3, 256, 256))
  97. out = model(x)
Add Comment
Please, Sign In to add comment