Advertisement
Not a member of Pastebin yet?
Sign Up —
it unlocks many cool features!
- import torch
- import torch.nn as nn
- from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
- from architectures.models.SSDNet import DepthWiseConv
- from architectures.backbones.MobileNet import mobilenet_v2, ConvBNReLU
- # taken from https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Detection/SSD
class DepthWiseConv_noBN(nn.Module):
    """Depthwise-separable convolution whose pointwise stage has no BN/ReLU.

    A depthwise convolution (``ConvBNReLU`` with ``groups=in_planes``, so one
    filter per input channel) is followed by a plain 1x1 pointwise ``Conv2d``.
    Because the pointwise stage is un-normalized and un-activated, the module
    is suitable as a raw prediction head (e.g. box/class logits).
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, padding=0):
        super().__init__()
        # Depthwise stage: spatial filtering per channel (groups == in_planes).
        self.ds_conv = ConvBNReLU(in_planes, in_planes, kernel_size=kernel_size,
                                  stride=stride, groups=in_planes, padding=padding)
        # Pointwise stage: channel mixing only; intentionally bare so the
        # output can carry unbounded logits.
        self.pw_conv = nn.Conv2d(in_planes, out_planes, kernel_size=1)

    def forward(self, x):
        depthwise_features = self.ds_conv(x)
        return self.pw_conv(depthwise_features)
class ResNet(nn.Module):
    """ResNet trunk truncated after layer3 (conv4) for use as an SSD backbone.

    Taken from NVIDIA's SSD reference implementation: the first block of
    layer3 has every stride reset to 1 so the final feature map keeps a
    higher spatial resolution than the stock classifier would produce.
    ``out_channels`` lists the channel widths the SSD head expects per
    feature map for the chosen variant.
    """

    def __init__(self, backbone='resnet50', backbone_path=None):
        super().__init__()
        # Variant -> (constructor, SSD head channel widths). Any name not in
        # the table falls through to resnet152, matching the original
        # if/elif chain's trailing else.
        variants = {
            'resnet18': (resnet18, [256, 512, 512, 256, 256, 128]),
            'resnet34': (resnet34, [256, 512, 512, 256, 256, 256]),
            'resnet50': (resnet50, [1024, 512, 512, 256, 256, 256]),
            'resnet101': (resnet101, [1024, 512, 512, 256, 256, 256]),
        }
        ctor, channels = variants.get(
            backbone, (resnet152, [1024, 512, 512, 256, 256, 256]))

        # Download torchvision's pretrained weights only when no local
        # checkpoint path was supplied.
        net = ctor(pretrained=not backbone_path)
        self.out_channels = channels
        if backbone_path:
            net.load_state_dict(torch.load(backbone_path))

        # Keep children 0..6: stem + layer1 + layer2 + layer3.
        self.feature_extractor = nn.Sequential(*list(net.children())[:7])

        # Undo the downsampling performed by the first block of layer3.
        conv4_block1 = self.feature_extractor[-1][0]
        conv4_block1.conv1.stride = (1, 1)
        conv4_block1.conv2.stride = (1, 1)
        conv4_block1.downsample[0].stride = (1, 1)

    def forward(self, x):
        return self.feature_extractor(x)
class SSD300(nn.Module):
    """SSD300 detector with depthwise-separable prediction heads.

    The feature extractor is currently hard-wired to MobileNetV2 (the ResNet
    path is kept commented out below). Six feature maps feed ``loc``/``conf``
    heads; ``forward`` returns (locs, confs) as
    ``nbatch x 4 x ndefaults`` and ``nbatch x n_classes x ndefaults`` tensors.

    Args:
        backbone: kept for interface compatibility; it is NOT used while the
            extractor is hard-wired to MobileNetV2. Defaults to ``None``.
            (Previously the default was ``ResNet('resnet50')``, which eagerly
            built — and downloaded weights for — a ResNet at class-definition
            time even though the value was never read. ``None`` avoids that.)
        n_classes: number of classes including background (81 for COCO).
    """

    def __init__(self, backbone=None, n_classes=81):
        super().__init__()
        # self.feature_extractor = backbone
        # self.out_channels = [1024, 512, 512, 256, 256, 256]
        self.feature_extractor = mobilenet_v2(pretrained=True, width_mult=1)
        self.out_channels = [576, 1280, 512, 256, 256, 128]

        self.label_num = n_classes  # number of COCO classes (incl. background)
        self._build_additional_features(self.out_channels[1:-1])
        self.num_defaults = [4, 6, 6, 6, 6, 6]

        # One depthwise head per feature map: nd*4 box offsets and
        # nd*n_classes class scores per spatial location.
        loc_heads, conf_heads = [], []
        for nd, oc in zip(self.num_defaults, self.out_channels):
            loc_heads.append(DepthWiseConv_noBN(oc, nd * 4, kernel_size=3, padding=1))
            conf_heads.append(DepthWiseConv_noBN(oc, nd * self.label_num, kernel_size=3, padding=1))
        self.loc = nn.ModuleList(loc_heads)
        self.conf = nn.ModuleList(conf_heads)

        self._init_weights()

    def _build_additional_features(self, input_size):
        """Build the extra downsampling blocks between consecutive widths.

        One block is created per consecutive (in, out) pair of ``input_size``,
        plus a trailing fixed 256->128 block that collapses the final map.
        The unused ``channels`` list from the original zip has been dropped,
        and the loop variable no longer shadows the ``input_size`` parameter.
        """
        blocks = []
        for i, (in_ch, out_ch) in enumerate(zip(input_size[:-1], input_size[1:])):
            if i < 3:
                # Stride-2 depthwise block: halves the spatial resolution.
                blocks.append(nn.Sequential(
                    DepthWiseConv(in_ch, out_ch, kernel_size=3, padding=1, stride=2),
                ))
            else:
                # NOTE(review): unreachable with the current 4-entry input
                # (only 3 pairs); kept for parity with the original code.
                blocks.append(nn.Sequential(
                    DepthWiseConv(in_ch, out_ch, kernel_size=3),
                ))
        # Final block shrinks the last map (kernel 2, no padding).
        blocks.append(DepthWiseConv(256, 128, kernel_size=2))
        self.additional_blocks = nn.ModuleList(blocks)

    def _init_weights(self):
        """Xavier-initialize all weight tensors of the extra blocks and heads."""
        layers = [*self.additional_blocks, *self.loc, *self.conf]
        for layer in layers:
            for param in layer.parameters():
                if param.dim() > 1:  # skip biases / 1-D params
                    nn.init.xavier_uniform_(param)

    # Shape the classifier outputs to the view of bboxes.
    def bbox_view(self, src, loc, conf):
        """Apply each loc/conf head to its feature map and flatten.

        Returns (locs, confs) concatenated over all default boxes:
        ``nbatch x 4 x total_defaults`` and
        ``nbatch x label_num x total_defaults``.
        """
        ret = []
        for s, l, c in zip(src, loc, conf):
            ret.append((l(s).view(s.size(0), 4, -1),
                        c(s).view(s.size(0), self.label_num, -1)))
        locs, confs = list(zip(*ret))
        locs = torch.cat(locs, 2).contiguous()
        confs = torch.cat(confs, 2).contiguous()
        return locs, confs

    def forward(self, x):
        # out_channels[0] == 576 identifies the MobileNetV2 extractor, which
        # returns both an intermediate feature map and the final one.
        if self.out_channels[0] == 576:
            inter_layer, x = self.feature_extractor(x)
            detection_feed = [inter_layer, x]
        else:
            x = self.feature_extractor(x)
            detection_feed = [x]
        for block in self.additional_blocks:
            x = block(x)
            detection_feed.append(x)
        # Feature maps: 38x38x4, 19x19x6, 10x10x6, 5x5x6, 3x3x4, 1x1x4
        locs, confs = self.bbox_view(detection_feed, self.loc, self.conf)
        # For SSD300, returns nbatch x 8732 x {nlabels, nlocs} results.
        return locs, confs
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement