import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Standard two-conv residual block (expansion factor 1)."""

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super(BasicBlock, self).__init__()
        self.expansion = 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=0.01)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=0.01)

        # A projection shortcut is needed whenever the residual and the main
        # path disagree in spatial size (stride != 1) or channel count.
        # Parentheses make the intended operator precedence explicit.
        if downsample and (stride != 1 or inplanes != planes * self.expansion):
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes, momentum=0.01))
        else:
            self.downsample = nn.Identity()
        self.stride = stride

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Three-conv bottleneck residual block (expansion factor 4)."""

    def __init__(self, inplanes, planes, stride=1, downsample=False):
        super(Bottleneck, self).__init__()
        self.expansion = 4

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=0.01)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=0.01)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=0.01)
        self.relu = nn.ReLU(inplace=True)

        # As in BasicBlock: project the residual when size or channels change.
        if downsample and (stride != 1 or inplanes != planes * self.expansion):
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes * self.expansion, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(planes * self.expansion, momentum=0.01))
        else:
            self.downsample = nn.Identity()
        self.stride = stride

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
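

# Illustrative shape check for the two residual blocks (the channel/stride
# values below are arbitrary assumptions, not part of the network config):
#
#   blk = BasicBlock(32, 64, stride=2, downsample=True)
#   blk(torch.randn(1, 32, 16, 16)).shape   # torch.Size([1, 64, 8, 8])
#
#   btl = Bottleneck(64, 64, downsample=True)
#   btl(torch.randn(1, 64, 16, 16)).shape   # torch.Size([1, 256, 16, 16])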


def expansion(block):
    """Return the channel expansion factor of a residual block class."""
    if block.__name__ == 'Bottleneck':
        return 4
    elif block.__name__ == 'BasicBlock':
        return 1
    # Fail loudly instead of silently returning None for unknown blocks.
    raise ValueError('Unknown block type: %s' % block.__name__)


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample=True))
        self.num_inchannels[branch_index] = num_channels[branch_index] * expansion(block)
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution branch j feeding higher-resolution output i:
                    # a 1x1 conv aligns channels; upsampling happens in forward().
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                        nn.BatchNorm2d(num_inchannels[i], momentum=0.01)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Higher-resolution branch j feeding lower-resolution output i:
                    # a chain of strided 3x3 convs halves the resolution per step.
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3, momentum=0.01)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                nn.BatchNorm2d(num_outchannels_conv3x3, momentum=0.01),
                                nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear', align_corners=False)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse
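

# Sketch of a two-branch module in isolation (all values assumed; branch
# inputs must already be at the two resolutions with matching channels):
#
#   m = HighResolutionModule(num_branches=2, blocks=BasicBlock,
#                            num_blocks=[4, 4], num_inchannels=[48, 96],
#                            num_channels=[48, 96], fuse_method='SUM')
#   outs = m([torch.randn(1, 48, 64, 64), torch.randn(1, 96, 32, 32)])
#   # outs[0]: (1, 48, 64, 64), outs[1]: (1, 96, 32, 32)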


def get_block(block_type):
    if block_type == 'BASIC':
        return BasicBlock
    elif block_type == 'BOTTLENECK':
        return Bottleneck
    # Fail loudly instead of silently returning None for unknown types.
    raise ValueError('Unknown block type: %s' % block_type)


class HighResolutionNet(nn.Module):
    def __init__(self, config, num_classes=20, **kwargs):
        super(HighResolutionNet, self).__init__()

        # Stem: two stride-2 3x3 convs reduce the input to 1/4 resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=0.01)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=0.01)
        self.relu = nn.ReLU(inplace=True)

        self.stage1_cfg = config['STAGE1']
        num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
        block = get_block(self.stage1_cfg['BLOCK'])
        num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]

        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        stage1_out_channel = expansion(block) * num_channels

        self.stage2_cfg = config['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = get_block(self.stage2_cfg['BLOCK'])
        num_channels = [num_channels[i] * expansion(block) for i in range(len(num_channels))]
        self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels)
        self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)

        self.stage3_cfg = config['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = get_block(self.stage3_cfg['BLOCK'])
        num_channels = [num_channels[i] * expansion(block) for i in range(len(num_channels))]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)

        self.stage4_cfg = config['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = get_block(self.stage4_cfg['BLOCK'])
        num_channels = [num_channels[i] * expansion(block) for i in range(len(num_channels))]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True)

        # np.int was removed in NumPy 1.24; the built-in int is equivalent here.
        last_inp_channels = int(np.sum(pre_stage_channels))

        self.last_layer = nn.Sequential(
            nn.Conv2d(in_channels=last_inp_channels, out_channels=last_inp_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(last_inp_channels, momentum=0.01),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=last_inp_channels, out_channels=num_classes,
                      kernel_size=config['FINAL_CONV_KERNEL'],
                      stride=1, padding=1 if config['FINAL_CONV_KERNEL'] == 3 else 0)
        )

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)
        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                # Existing branch: adjust channels only if they differ.
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
                        nn.BatchNorm2d(num_channels_cur_layer[i], momentum=0.01),
                        nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                # New branch: downsample from the last existing branch with
                # one strided 3x3 conv per halving of resolution.
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                        nn.BatchNorm2d(outchannels, momentum=0.01),
                        nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)
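
    # For example (assumed channel lists), going from the single 256-channel
    # stage-1 branch to two branches of 48 and 96 channels builds
    # [3x3 conv 256->48 (stride 1), 3x3 conv 256->96 (stride 2)]:
    #
    #   self._make_transition_layer([256], [48, 96])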

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        layers = []
        layers.append(block(inplanes, planes, stride, downsample=True))
        inplanes = planes * expansion(block)
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True):
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output only matters for the last module of the stage.
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(HighResolutionModule(num_branches, get_block(layer_config['BLOCK']), num_blocks,
                                                num_inchannels, num_channels, fuse_method, reset_multi_scale_output))
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        # Upsample every branch to the highest resolution and concatenate.
        # (F.upsample is deprecated; F.interpolate is its replacement.)
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
        x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
        x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False)

        x = torch.cat([x[0], x1, x2, x3], 1)

        x = self.last_layer(x)
        return x


class HRN(nn.Module):
    def __init__(self, num_classes):
        super(HRN, self).__init__()
        config = {'FINAL_CONV_KERNEL': 1,
                  'STAGE1': {'NUM_MODULES': 1,
                             'NUM_BRANCHES': 1,
                             'BLOCK': 'BOTTLENECK',
                             'NUM_BLOCKS': [4],
                             'NUM_CHANNELS': [64],
                             'FUSE_METHOD': 'SUM'},
                  'STAGE2': {'NUM_MODULES': 1,
                             'NUM_BRANCHES': 2,
                             'BLOCK': 'BASIC',
                             'NUM_BLOCKS': [4, 4],
                             'NUM_CHANNELS': [48, 96],
                             'FUSE_METHOD': 'SUM'},
                  'STAGE3': {'NUM_MODULES': 4,
                             'NUM_BRANCHES': 3,
                             'BLOCK': 'BASIC',
                             'NUM_BLOCKS': [4, 4, 4],
                             'NUM_CHANNELS': [48, 96, 192],
                             'FUSE_METHOD': 'SUM'},
                  'STAGE4': {'NUM_MODULES': 3,
                             'NUM_BRANCHES': 4,
                             'BLOCK': 'BASIC',
                             'NUM_BLOCKS': [4, 4, 4, 4],
                             'NUM_CHANNELS': [48, 96, 192, 384],
                             'FUSE_METHOD': 'SUM'}}

        self.model = HighResolutionNet(config, num_classes=num_classes)

    def forward(self, x):
        b, c, h, w = x.shape
        out = self.model(x)
        # Resize the logits back to the input resolution. Bilinear is the
        # usual choice for segmentation logits (the default mode is nearest).
        return F.interpolate(out, size=(h, w), mode='bilinear', align_corners=False)
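

# Minimal smoke test: a sketch under assumed settings (batch of 2, 19 classes
# as in Cityscapes, 512x512 inputs); adjust num_classes and shape to your data.
if __name__ == '__main__':
    model = HRN(num_classes=19)
    dummy = torch.randn(2, 3, 512, 512)
    with torch.no_grad():
        logits = model(dummy)
    # HRN.forward resizes the logits back to the input resolution,
    # so this prints torch.Size([2, 19, 512, 512]).
    print(logits.shape)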