Advertisement
cryisis

Untitled

May 25th, 2020
1,757
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.33 KB | None | 0 0
  1. class UNet(nn.Module):
  2.     def __init__(self):
  3.         super().__init__()
  4.  
  5.         # 256 -> 128
  6.         self.enc_conv0 = nn.Sequential(
  7.             nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
  8.             nn.BatchNorm2d(64),
  9.             nn.ReLU(),
  10.             nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
  11.             nn.BatchNorm2d(64),
  12.             nn.ReLU()
  13.         )
  14.         self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
  15.        
  16.         # 128 -> 64
  17.         self.enc_conv1 = nn.Sequential(
  18.             nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
  19.             nn.BatchNorm2d(128),
  20.             nn.ReLU(),
  21.             nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
  22.             nn.BatchNorm2d(128),
  23.             nn.ReLU()
  24.         )
  25.         self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
  26.        
  27.         # 64 -> 32
  28.         self.enc_conv2 = nn.Sequential(
  29.             nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
  30.             nn.BatchNorm2d(256),
  31.             nn.ReLU(),
  32.             nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
  33.             nn.BatchNorm2d(256),
  34.             nn.ReLU()
  35.         )
  36.         self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
  37.        
  38.         # 32 -> 16
  39.         self.enc_conv3 = nn.Sequential(
  40.             nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1),
  41.             nn.BatchNorm2d(512),
  42.             nn.ReLU(),
  43.             nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
  44.             nn.BatchNorm2d(512),
  45.             nn.ReLU()
  46.         )
  47.         self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
  48.  
  49.        
  50.         # bottleneck
  51.         self.bottleneck_conv = nn.Sequential(
  52.             nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
  53.             nn.ReLU()
  54.         )
  55.        
  56.         # decoder (upsampling)
  57.         # 16 -> 32
  58.         self.upsample0 = nn.MaxUnpool2d(kernel_size=2, stride=2)
  59.         self.dec_conv0 = nn.Sequential(
  60.             nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
  61.             nn.BatchNorm2d(512),
  62.             nn.ReLU(),
  63.             nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
  64.             nn.BatchNorm2d(256),
  65.             nn.ReLU()
  66.         )
  67.        
  68.         # 32 -> 64
  69.         self.upsample1 = nn.MaxUnpool2d(kernel_size=2, stride=2)
  70.         self.dec_conv1 = nn.Sequential(
  71.             nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
  72.             nn.BatchNorm2d(256),
  73.             nn.ReLU(),
  74.             nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
  75.             nn.BatchNorm2d(128),
  76.             nn.ReLU()
  77.         )
  78.        
  79.         # 64 -> 128
  80.         self.upsample2 = nn.MaxUnpool2d(kernel_size=2, stride=2)
  81.         self.dec_conv2 = nn.Sequential(
  82.             nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
  83.             nn.BatchNorm2d(128),
  84.             nn.ReLU(),
  85.             nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
  86.             nn.BatchNorm2d(64),
  87.             nn.ReLU()
  88.         )
  89.        
  90.         # 128 -> 256
  91.         self.upsample3 = nn.MaxUnpool2d(kernel_size=2, stride=2)
  92.         self.dec_conv3 = nn.Sequential(
  93.             nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
  94.             nn.BatchNorm2d(64),
  95.             nn.ReLU(),
  96.             nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, padding=1),
  97.         )
  98.  
  99.     def forward(self, x):
  100.         # encoder
  101.         e0, ei0 = self.pool0(self.enc_conv0(x))
  102.         e1, ei1 = self.pool1(self.enc_conv1(e0))
  103.         e2, ei2 = self.pool2(self.enc_conv2(e1))
  104.         e3, ei3 = self.pool3(self.enc_conv3(e2))
  105.  
  106.         # bottleneck
  107.         b = self.bottleneck_conv(e3)
  108.  
  109.         # decoder
  110.         d0 = self.dec_conv0(torch.cat((self.upsample0(b, ei3), e3), dim=1))
  111.         d1 = self.dec_conv1(torch.cat((self.upsample1(d0, ei2), e2), dim=1))
  112.         d2 = self.dec_conv2(torch.cat((self.upsample2(d1, ei1), e1), dim=1))
  113.         d3 = self.dec_conv3(torch.cat((self.upsample3(d2, ei0), e0), dim=1)) # no activation
  114.         return d3
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement