Advertisement
ZeroCool22

models.py

Feb 4th, 2018
15,238
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 8.58 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import math
  5.  
  6.  
  7. def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
  8.     "3x3 convolution with padding"
  9.     return nn.Conv2d(in_planes, out_planes, kernel_size=3,
  10.                      stride=strd, padding=padding, bias=bias)
  11.  
  12.  
  13. class ConvBlock(nn.Module):
  14.     def __init__(self, in_planes, out_planes):
  15.         super(ConvBlock, self).__init__()
  16.         self.bn1 = nn.BatchNorm2d(in_planes)
  17.         self.conv1 = conv3x3(in_planes, int(out_planes / 2))
  18.         self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
  19.         self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
  20.         self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
  21.         self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
  22.  
  23.         if in_planes != out_planes:
  24.             self.downsample = nn.Sequential(
  25.                 nn.BatchNorm2d(in_planes),
  26.                 nn.ReLU(True),
  27.                 nn.Conv2d(in_planes, out_planes,
  28.                           kernel_size=1, stride=1, bias=False),
  29.             )
  30.         else:
  31.             self.downsample = None
  32.  
  33.     def forward(self, x):
  34.         residual = x
  35.  
  36.         out1 = self.bn1(x)
  37.         out1 = F.relu(out1, True)
  38.         out1 = self.conv1(out1)
  39.  
  40.         out2 = self.bn2(out1)
  41.         out2 = F.relu(out2, True)
  42.         out2 = self.conv2(out2)
  43.  
  44.         out3 = self.bn3(out2)
  45.         out3 = F.relu(out3, True)
  46.         out3 = self.conv3(out3)
  47.  
  48.         out3 = torch.cat((out1, out2, out3), 1)
  49.  
  50.         if self.downsample is not None:
  51.             residual = self.downsample(residual)
  52.  
  53.         out3 += residual
  54.  
  55.         return out3
  56.  
  57.  
  58. class Bottleneck(nn.Module):
  59.  
  60.     expansion = 4
  61.  
  62.     def __init__(self, inplanes, planes, stride=1, downsample=None):
  63.         super(Bottleneck, self).__init__()
  64.         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
  65.         self.bn1 = nn.BatchNorm2d(planes)
  66.         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
  67.                                padding=1, bias=False)
  68.         self.bn2 = nn.BatchNorm2d(planes)
  69.         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
  70.         self.bn3 = nn.BatchNorm2d(planes * 4)
  71.         self.relu = nn.ReLU(inplace=True)
  72.         self.downsample = downsample
  73.         self.stride = stride
  74.  
  75.     def forward(self, x):
  76.         residual = x
  77.  
  78.         out = self.conv1(x)
  79.         out = self.bn1(out)
  80.         out = self.relu(out)
  81.  
  82.         out = self.conv2(out)
  83.         out = self.bn2(out)
  84.         out = self.relu(out)
  85.  
  86.         out = self.conv3(out)
  87.         out = self.bn3(out)
  88.  
  89.         if self.downsample is not None:
  90.             residual = self.downsample(x)
  91.  
  92.         out += residual
  93.         out = self.relu(out)
  94.  
  95.         return out
  96.  
  97.  
  98. class HourGlass(nn.Module):
  99.     def __init__(self, num_modules, depth, num_features):
  100.         super(HourGlass, self).__init__()
  101.         self.num_modules = num_modules
  102.         self.depth = depth
  103.         self.features = num_features
  104.  
  105.         self._generate_network(self.depth)
  106.  
  107.     def _generate_network(self, level):
  108.         self.add_module('b1_' + str(level), ConvBlock(256, 256))
  109.  
  110.         self.add_module('b2_' + str(level), ConvBlock(256, 256))
  111.  
  112.         if level > 1:
  113.             self._generate_network(level - 1)
  114.         else:
  115.             self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
  116.  
  117.         self.add_module('b3_' + str(level), ConvBlock(256, 256))
  118.  
  119.     def _forward(self, level, inp):
  120.         # Upper branch
  121.         up1 = inp
  122.         up1 = self._modules['b1_' + str(level)](up1)
  123.  
  124.         # Lower branch
  125.         low1 = F.max_pool2d(inp, 2, stride=2)
  126.         low1 = self._modules['b2_' + str(level)](low1)
  127.  
  128.         if level > 1:
  129.             low2 = self._forward(level - 1, low1)
  130.         else:
  131.             low2 = low1
  132.             low2 = self._modules['b2_plus_' + str(level)](low2)
  133.  
  134.         low3 = low2
  135.         low3 = self._modules['b3_' + str(level)](low3)
  136.  
  137.         up2 = F.upsample(low3, scale_factor=2, mode='nearest')
  138.  
  139.         return up1 + up2
  140.  
  141.     def forward(self, x):
  142.         return self._forward(self.depth, x)
  143.  
  144.  
  145. class FAN(nn.Module):
  146.  
  147.     def __init__(self, num_modules=1):
  148.         super(FAN, self).__init__()
  149.         self.num_modules = num_modules
  150.  
  151.         # Base part
  152.         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
  153.         self.bn1 = nn.BatchNorm2d(64)
  154.         self.conv2 = ConvBlock(64, 128)
  155.         self.conv3 = ConvBlock(128, 128)
  156.         self.conv4 = ConvBlock(128, 256)
  157.  
  158.         # Stacking part
  159.         for hg_module in range(self.num_modules):
  160.             self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
  161.             self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
  162.             self.add_module('conv_last' + str(hg_module),
  163.                             nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
  164.             self.add_module('l' + str(hg_module), nn.Conv2d(256,
  165.                                                             68, kernel_size=1, stride=1, padding=0))
  166.             self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
  167.  
  168.             if hg_module < self.num_modules - 1:
  169.                 self.add_module(
  170.                     'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
  171.                 self.add_module('al' + str(hg_module), nn.Conv2d(68,
  172.                                                                  256, kernel_size=1, stride=1, padding=0))
  173.  
  174.     def forward(self, x):
  175.         x = F.relu(self.bn1(self.conv1(x)), True)
  176.         x = F.max_pool2d(self.conv2(x), 2)
  177.         x = self.conv3(x)
  178.         x = self.conv4(x)
  179.  
  180.         previous = x
  181.  
  182.         outputs = []
  183.         for i in range(self.num_modules):
  184.             hg = self._modules['m' + str(i)](previous)
  185.  
  186.             ll = hg
  187.             ll = self._modules['top_m_' + str(i)](ll)
  188.  
  189.             ll = F.relu(self._modules['bn_end' + str(i)]
  190.                         (self._modules['conv_last' + str(i)](ll)), True)
  191.  
  192.             # Predict heatmaps
  193.             tmp_out = self._modules['l' + str(i)](ll)
  194.             outputs.append(tmp_out)
  195.  
  196.             if i < self.num_modules - 1:
  197.                 ll = self._modules['bl' + str(i)](ll)
  198.                 tmp_out_ = self._modules['al' + str(i)](tmp_out)
  199.                 previous = previous + ll + tmp_out_
  200.  
  201.         return outputs
  202.  
  203.  
  204. class ResNetDepth(nn.Module):
  205.  
  206.     def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
  207.         self.inplanes = 64
  208.         super(ResNetDepth, self).__init__()
  209.         self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
  210.                                bias=False)
  211.         self.bn1 = nn.BatchNorm2d(64)
  212.         self.relu = nn.ReLU(inplace=True)
  213.         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  214.         self.layer1 = self._make_layer(block, 64, layers[0])
  215.         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
  216.         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
  217.         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
  218.         self.avgpool = nn.AvgPool2d(7)
  219.         self.fc = nn.Linear(512 * block.expansion, num_classes)
  220.  
  221.         for m in self.modules():
  222.             if isinstance(m, nn.Conv2d):
  223.                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  224.                 m.weight.data.normal_(0, math.sqrt(2. / n))
  225.             elif isinstance(m, nn.BatchNorm2d):
  226.                 m.weight.data.fill_(1)
  227.                 m.bias.data.zero_()
  228.  
  229.     def _make_layer(self, block, planes, blocks, stride=1):
  230.         downsample = None
  231.         if stride != 1 or self.inplanes != planes * block.expansion:
  232.             downsample = nn.Sequential(
  233.                 nn.Conv2d(self.inplanes, planes * block.expansion,
  234.                           kernel_size=1, stride=stride, bias=False),
  235.                 nn.BatchNorm2d(planes * block.expansion),
  236.             )
  237.  
  238.         layers = []
  239.         layers.append(block(self.inplanes, planes, stride, downsample))
  240.         self.inplanes = planes * block.expansion
  241.         for i in range(1, blocks):
  242.             layers.append(block(self.inplanes, planes))
  243.  
  244.         return nn.Sequential(*layers)
  245.  
  246.     def forward(self, x):
  247.         x = self.conv1(x)
  248.         x = self.bn1(x)
  249.         x = self.relu(x)
  250.         x = self.maxpool(x)
  251.  
  252.         x = self.layer1(x)
  253.         x = self.layer2(x)
  254.         x = self.layer3(x)
  255.         x = self.layer4(x)
  256.  
  257.         x = self.avgpool(x)
  258.         x = x.view(x.size(0), -1)
  259.         x = self.fc(x)
  260.  
  261.         return x
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement