# Bayesian 3D U-Net for volumetric segmentation, built from blitz's
# BayesianConv3d layers so that every convolution carries a variational
# weight posterior.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from blitz.modules import BayesianConv3d
from blitz.utils import variational_estimator


def normalization(planes, norm='gn'):
    # Pick a 3D normalization layer: batch norm, group norm (4 groups) or instance norm.
    if norm == 'bn':
        m = nn.BatchNorm3d(planes)
    elif norm == 'gn':
        m = nn.GroupNorm(4, planes)
    elif norm == 'in':
        m = nn.InstanceNorm3d(planes)
    else:
        raise ValueError('normalization type {} is not supported'.format(norm))
    return m


class ConvD(nn.Module):
    """Encoder block: optional 2x downsampling, then three Bayesian 3x3x3 convolutions
    with a residual connection between the first and the third."""

    def __init__(self, inplanes, planes, dropout=0.0, norm='gn', first=False):
        super(ConvD, self).__init__()

        self.first = first
        self.maxpool = nn.MaxPool3d(2, 2)

        self.dropout = dropout
        self.relu = nn.ReLU(inplace=True)

        # Keyword arguments keep stride/padding explicit instead of relying on
        # BayesianConv3d's positional argument order.
        self.conv1 = BayesianConv3d(inplanes, planes, (3, 3, 3), stride=1, padding=1, bias=False)
        self.bn1   = normalization(planes, norm)

        self.conv2 = BayesianConv3d(planes, planes, (3, 3, 3), stride=1, padding=1, bias=False)
        self.bn2   = normalization(planes, norm)

        self.conv3 = BayesianConv3d(planes, planes, (3, 3, 3), stride=1, padding=1, bias=False)
        self.bn3   = normalization(planes, norm)

    def forward(self, x):
        if not self.first:
            x = self.maxpool(x)
        x = self.bn1(self.conv1(x))
        y = self.relu(self.bn2(self.conv2(x)))
        if self.dropout > 0:
            # Apply dropout only while training.
            y = F.dropout3d(y, self.dropout, training=self.training)
        y = self.bn3(self.conv3(y))
        # Residual connection between the conv1 and conv3 branches.
        return self.relu(x + y)


class ConvU(nn.Module):
    """Decoder block: upsample, halve the channels, concatenate the encoder skip
    connection, then fuse with a 3x3x3 convolution."""

    def __init__(self, planes, norm='gn', first=False):
        super(ConvU, self).__init__()

        self.first = first

        if not self.first:
            self.conv1 = BayesianConv3d(2 * planes, planes, (3, 3, 3), stride=1, padding=1, bias=False)
            self.bn1   = normalization(planes, norm)

        self.conv2 = BayesianConv3d(planes, planes // 2, (1, 1, 1), stride=1, padding=0, bias=False)
        self.bn2   = normalization(planes // 2, norm)

        self.conv3 = BayesianConv3d(planes, planes, (3, 3, 3), stride=1, padding=1, bias=False)
        self.bn3   = normalization(planes, norm)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, prev):
        # `prev` is the skip connection from the matching encoder stage.
        if not self.first:
            x = self.relu(self.bn1(self.conv1(x)))

        # Upsample by 2 and halve the channel count before concatenation.
        y = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
        y = self.relu(self.bn2(self.conv2(y)))

        y = torch.cat([prev, y], 1)
        y = self.relu(self.bn3(self.conv3(y)))

        return y


@variational_estimator
class Unet(nn.Module):
    """3D U-Net with Bayesian convolutions and deep supervision from the three
    finest decoder stages."""

    def __init__(self, c=1, n=16, dropout=0.5, norm='gn', num_classes=5):
        super(Unet, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2,
                mode='trilinear', align_corners=False)

        # Encoder: channels grow from n to 16n while the volume is downsampled 4 times.
        self.convd1 = ConvD(c,     n, dropout, norm, first=True)
        self.convd2 = ConvD(n,   2*n, dropout, norm)
        self.convd3 = ConvD(2*n, 4*n, dropout, norm)
        self.convd4 = ConvD(4*n, 8*n, dropout, norm)
        self.convd5 = ConvD(8*n, 16*n, dropout, norm)

        # Decoder: each stage fuses upsampled features with the encoder skip connection.
        self.convu4 = ConvU(16*n, norm, True)
        self.convu3 = ConvU(8*n, norm)
        self.convu2 = ConvU(4*n, norm)
        self.convu1 = ConvU(2*n, norm)

        # Segmentation heads for deep supervision at three resolutions.
        self.seg3 = BayesianConv3d(8*n, num_classes, (1, 1, 1))
        self.seg2 = BayesianConv3d(4*n, num_classes, (1, 1, 1))
        self.seg1 = BayesianConv3d(2*n, num_classes, (1, 1, 1))

        # He-style initialization of the posterior means; normalization layers start as identity.
        for m in self.modules():
            if isinstance(m, BayesianConv3d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight_mu.data.normal_(0, math.sqrt(2. / fan_out))
                # Bias parameters may be absent for layers built with bias=False.
                if getattr(m, 'bias_mu', None) is not None:
                    m.bias_mu.data.zero_()
            elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Encoder path.
        x1 = self.convd1(x)
        x2 = self.convd2(x1)
        x3 = self.convd3(x2)
        x4 = self.convd4(x3)
        x5 = self.convd5(x4)

        # Decoder path with skip connections.
        y4 = self.convu4(x5, x4)
        y3 = self.convu3(y4, x3)
        y2 = self.convu2(y3, x2)
        y1 = self.convu1(y2, x1)

        # Deep supervision: coarser predictions are upsampled and summed into finer ones.
        y3 = self.seg3(y3)
        y2 = self.seg2(y2) + self.upsample(y3)
        y1 = self.seg1(y1) + self.upsample(y2)

        return y1
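

# Hedged training sketch (not part of the original paste): the
# @variational_estimator decorator from blitz adds a sample_elbo() method that
# averages the supervised loss plus a weighted KL complexity cost over several
# weight samples. `loader`, `optimizer`, `device` and the hyper-parameter values
# below are illustrative assumptions, not prescribed choices.
def train_one_epoch(model, loader, optimizer, device,
                    sample_nbr=3, complexity_cost_weight=1e-6):
    criterion = nn.CrossEntropyLoss()
    model.train()
    for volumes, labels in loader:
        volumes = volumes.to(device)        # (N, c, D, H, W) input volumes
        labels = labels.to(device).long()   # (N, D, H, W) voxel-wise class indices
        optimizer.zero_grad()
        # ELBO estimate: data term plus complexity_cost_weight * KL(posterior || prior),
        # averaged over sample_nbr stochastic forward passes.
        loss = model.sample_elbo(inputs=volumes,
                                 labels=labels,
                                 criterion=criterion,
                                 sample_nbr=sample_nbr,
                                 complexity_cost_weight=complexity_cost_weight)
        loss.backward()
        optimizer.step()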
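

# Hedged inference sketch (not part of the original paste): each forward pass
# through the blitz Bayesian layers draws a fresh set of weights, so repeating
# the forward pass yields Monte Carlo samples of the prediction. `n_samples`
# and the softmax over the class dimension are illustrative choices.
@torch.no_grad()
def predict_with_uncertainty(model, x, n_samples=10):
    model.eval()
    probs = torch.stack([F.softmax(model(x), dim=1) for _ in range(n_samples)])
    # Per-voxel predictive mean and standard deviation across the weight samples.
    return probs.mean(dim=0), probs.std(dim=0)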


if __name__ == '__main__':
    # Minimal shape check. Four input channels are assumed here (e.g. four MRI
    # modalities); set c to match your own data.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    x = torch.rand((2, 4, 32, 32, 32), device=device)
    model = Unet(c=4).to(device)
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 5, 32, 32, 32])