Advertisement
warrior98

Untitled

Feb 22nd, 2020
563
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.80 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
  4.  
  5. from architectures.models.SSDNet import DepthWiseConv
  6. from architectures.backbones.MobileNet import mobilenet_v2, ConvBNReLU
  7.  
  8.  
  9. # taken from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Detection/SSD
  10.  
  11.  
  12. class DepthWiseConv_noBN(nn.Module):
  13.     """
  14.    depth wise followed by point wise convolution
  15.    """
  16.  
  17.     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=0):
  18.         super().__init__()
  19.         self.ds_conv = ConvBNReLU(in_planes, in_planes, kernel_size=kernel_size,
  20.                                   stride=stride, groups=in_planes, padding=padding)
  21.         self.pw_conv = nn.Conv2d(in_planes, out_planes, kernel_size=1)
  22.  
  23.     def forward(self, x):
  24.         return self.pw_conv(self.ds_conv(x))
  25.  
  26.  
  27. class ResNet(nn.Module):
  28.     def __init__(self, backbone='resnet50', backbone_path=None):
  29.         super().__init__()
  30.         if backbone == 'resnet18':
  31.             backbone = resnet18(pretrained=not backbone_path)
  32.             self.out_channels = [256, 512, 512, 256, 256, 128]
  33.         elif backbone == 'resnet34':
  34.             backbone = resnet34(pretrained=not backbone_path)
  35.             self.out_channels = [256, 512, 512, 256, 256, 256]
  36.         elif backbone == 'resnet50':
  37.             backbone = resnet50(pretrained=not backbone_path)
  38.             self.out_channels = [1024, 512, 512, 256, 256, 256]
  39.         elif backbone == 'resnet101':
  40.             backbone = resnet101(pretrained=not backbone_path)
  41.             self.out_channels = [1024, 512, 512, 256, 256, 256]
  42.         else:  # backbone == 'resnet152':
  43.             backbone = resnet152(pretrained=not backbone_path)
  44.             self.out_channels = [1024, 512, 512, 256, 256, 256]
  45.         if backbone_path:
  46.             backbone.load_state_dict(torch.load(backbone_path))
  47.  
  48.         self.feature_extractor = nn.Sequential(*list(backbone.children())[:7])
  49.  
  50.         conv4_block1 = self.feature_extractor[-1][0]
  51.  
  52.         conv4_block1.conv1.stride = (1, 1)
  53.         conv4_block1.conv2.stride = (1, 1)
  54.         conv4_block1.downsample[0].stride = (1, 1)
  55.  
  56.     def forward(self, x):
  57.         x = self.feature_extractor(x)
  58.         return x
  59.  
  60.  
  61. class SSD300(nn.Module):
  62.     def __init__(self, backbone=ResNet('resnet50'), n_classes=81):
  63.         super().__init__()
  64.  
  65.         # self.feature_extractor = backbone
  66.         # self.out_channels = [1024, 512, 512, 256, 256, 256]
  67.  
  68.         self.feature_extractor = mobilenet_v2(pretrained=True, width_mult=1)
  69.         self.out_channels = [576, 1280, 512, 256, 256, 128]
  70.  
  71.         self.label_num = n_classes  # number of COCO classes
  72.         self._build_additional_features(self.out_channels[1:-1])
  73.         self.num_defaults = [4, 6, 6, 6, 6, 6]
  74.         self.loc = []
  75.         self.conf = []
  76.  
  77.         for nd, oc in zip(self.num_defaults, self.out_channels):
  78.             self.loc.append(DepthWiseConv_noBN(oc, nd * 4, kernel_size=3, padding=1))
  79.             self.conf.append(DepthWiseConv_noBN(oc, nd * self.label_num, kernel_size=3, padding=1))
  80.  
  81.         self.loc = nn.ModuleList(self.loc)
  82.         self.conf = nn.ModuleList(self.conf)
  83.         self._init_weights()
  84.  
  85.     def _build_additional_features(self, input_size):
  86.         self.additional_blocks = []
  87.         for i, (input_size, output_size, channels) in enumerate(zip(input_size[:-1], input_size[1:], [256, 256, 128, 128, 128])):
  88.             if i < 3:
  89.                 layer = nn.Sequential(
  90.                     # nn.Conv2d(input_size, channels, kernel_size=1, bias=False),
  91.                     # nn.BatchNorm2d(channels),
  92.                     # nn.ReLU(inplace=True),
  93.                     DepthWiseConv(input_size, output_size, kernel_size=3,
  94.                                   padding=1, stride=2),
  95.                     # nn.BatchNorm2d(output_size),
  96.                     # nn.ReLU(inplace=True),
  97.                 )
  98.             else:
  99.                 layer = nn.Sequential(
  100.                     # nn.Conv2d(input_size, channels, kernel_size=1, bias=False),
  101.                     # nn.BatchNorm2d(channels),
  102.                     # nn.ReLU(inplace=True),
  103.                     DepthWiseConv(input_size, output_size, kernel_size=3),
  104.                     # nn.BatchNorm2d(output_size),
  105.                     # nn.ReLU(inplace=True),
  106.                 )
  107.  
  108.             self.additional_blocks.append(layer)
  109.  
  110.         self.additional_blocks.append(DepthWiseConv(256, 128, kernel_size=2))
  111.  
  112.         self.additional_blocks = nn.ModuleList(self.additional_blocks)
  113.  
  114.     def _init_weights(self):
  115.         layers = [*self.additional_blocks, *self.loc, *self.conf]
  116.         for layer in layers:
  117.             for param in layer.parameters():
  118.                 if param.dim() > 1:
  119.                     nn.init.xavier_uniform_(param)
  120.  
  121.     # Shape the classifier to the view of bboxes
  122.     def bbox_view(self, src, loc, conf):
  123.         ret = []
  124.         for s, l, c in zip(src, loc, conf):
  125.             ret.append((l(s).view(s.size(0), 4, -1), c(s).view(s.size(0), self.label_num, -1)))
  126.  
  127.         locs, confs = list(zip(*ret))
  128.         locs, confs = torch.cat(locs, 2).contiguous(), torch.cat(confs, 2).contiguous()
  129.         return locs, confs
  130.  
  131.     def forward(self, x):
  132.         if self.out_channels[0] == 576:
  133.             inter_layer, x = self.feature_extractor(x)
  134.  
  135.             detection_feed = [inter_layer, x]
  136.         else:
  137.             x = self.feature_extractor(x)
  138.             detection_feed = [x]
  139.  
  140.         for l in self.additional_blocks:
  141.             x = l(x)
  142.             detection_feed.append(x)
  143.  
  144.         # Feature Map 38x38x4, 19x19x6, 10x10x6, 5x5x6, 3x3x4, 1x1x4
  145.         locs, confs = self.bbox_view(detection_feed, self.loc, self.conf)
  146.  
  147.         # For SSD 300, shall return nbatch x 8732 x {nlabels, nlocs} results
  148.         return locs, confs
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement