Advertisement
Guest User

Untitled

a guest
Feb 23rd, 2020
164
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.59 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3.  
  4. from architectures.backbones.MobileNet import ConvBNReLU, mobilenet_v2
  5.  
  6.  
  7. class DepthWiseConv_No_ReLu(nn.Module):
  8.     """
  9.    Depth wise convolution used for the final bbox offset and class score predictions
  10.    """
  11.     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=0):
  12.         super().__init__()
  13.         self.ds_conv = nn.Conv2d(in_planes, in_planes, kernel_size, groups=in_planes, padding=padding)
  14.         self.ds_bn = nn.BatchNorm2d(in_planes)
  15.         self.pw_conv = nn.Conv2d(in_planes, out_planes, kernel_size=1)
  16.  
  17.     def forward(self, x):
  18.         return self.pw_conv(self.ds_bn(self.ds_conv(x)))
  19.  
  20.  
  21. class DepthWiseConv(nn.Module):
  22.     """
  23.    depth wise followed by point wise convolution
  24.    """
  25.  
  26.     def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=0, bias=False):
  27.         super().__init__()
  28.         self.ds_conv = ConvBNReLU(in_planes, in_planes, kernel_size=kernel_size,
  29.                                   stride=stride, groups=in_planes, padding=padding, bias=False)
  30.         self.pw_conv = ConvBNReLU(in_planes, out_planes, kernel_size=1, bias=False)
  31.  
  32.     def forward(self, x):
  33.         return self.pw_conv(self.ds_conv(x))
  34.  
  35.  
  36. class SSD_Head(nn.Module):
  37.     def __init__(self, n_classes=81, k_list=[4, 6, 6, 6, 6, 6]):
  38.         super().__init__()
  39.         self.backbone = mobilenet_v2(pretrained=True, width_mult=1)
  40.         self.out_channels = [576, 1280, 512, 256, 256, 128]
  41.  
  42.         self.label_num = n_classes
  43.         self._build_additional_features(self.out_channels[1:-1])
  44.         self.num_defaults = k_list
  45.         self.loc = []
  46.         self.conf = []
  47.  
  48.         for nd, oc in zip(self.num_defaults, self.out_channels):
  49.             self.loc.append(DepthWiseConv_No_ReLu(oc, nd * 4, kernel_size=3, padding=1))
  50.             self.conf.append(DepthWiseConv_No_ReLu(
  51.                 oc, nd * self.label_num, kernel_size=3, padding=1))
  52.  
  53.         self.loc = nn.ModuleList(self.loc)
  54.         self.conf = nn.ModuleList(self.conf)
  55.         self._init_weights()
  56.  
  57.     def _build_additional_features(self, input_size):
  58.         self.additional_blocks = []
  59.         for i, (input_size, output_size) in enumerate(zip(input_size[:-1], input_size[1:])):
  60.             layer = DepthWiseConv(input_size, output_size, kernel_size=3,
  61.                                   padding=1, stride=2)
  62.             self.additional_blocks.append(layer)
  63.  
  64.         self.additional_blocks.append(DepthWiseConv(256, 128, kernel_size=2))
  65.  
  66.         self.additional_blocks = nn.ModuleList(self.additional_blocks)
  67.  
  68.     def _init_weights(self):
  69.         layers = [*self.additional_blocks, *self.loc, *self.conf]
  70.         for layer in layers:
  71.             for param in layer.parameters():
  72.                 if param.dim() > 1:
  73.                     nn.init.xavier_uniform_(param)
  74.  
  75.     # Shape the classifier to the view of bboxes
  76.     def bbox_view(self, src, loc, conf):
  77.         ret = []
  78.         for s, l, c in zip(src, loc, conf):
  79.             ret.append((l(s).view(s.size(0), 4, -1), c(s).view(s.size(0), self.label_num, -1)))
  80.  
  81.         locs, confs = list(zip(*ret))
  82.         locs, confs = torch.cat(locs, 2).contiguous(), torch.cat(confs, 2).contiguous()
  83.         return locs, confs
  84.  
  85.     def forward(self, x):
  86.         inter_layer, x = self.backbone(x)
  87.         detection_feed = [inter_layer, x]
  88.         for l in self.additional_blocks:
  89.             x = l(x)
  90.             detection_feed.append(x)
  91.  
  92.         # Feature Maps 19x19, 10x10, 5x5, 3x3, 1x1
  93.         locs, confs = self.bbox_view(detection_feed, self.loc, self.conf)
  94.  
  95.         return locs, confs
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement