Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class UNet(nn.Module):
    """U-Net style encoder/decoder for dense (per-pixel) prediction.

    The encoder runs four conv blocks, each followed by 2x2 max pooling whose
    argmax indices are kept. The decoder undoes each pooling step with
    ``nn.MaxUnpool2d`` using those indices, concatenates the result with the
    matching encoder block's output taken *before* pooling (the skip
    connection) along the channel axis, and applies a conv block.

    Args:
        in_channels: number of channels of the input tensor (default 3).
        out_channels: number of channels of the output map (default 1).

    ``forward`` takes an ``(N, in_channels, H, W)`` tensor and returns raw,
    un-activated scores of shape ``(N, out_channels, H, W)``. H and W must be
    at least 16 so that four rounds of 2x2 pooling leave a non-empty map.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1):
        super().__init__()
        # ---- encoder: conv blocks keep H/W (padding=1), each pool halves them
        # H -> H/2
        self.enc_conv0 = self._conv_block(in_channels, 64, 64)
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # H/2 -> H/4
        self.enc_conv1 = self._conv_block(64, 128, 128)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # H/4 -> H/8
        self.enc_conv2 = self._conv_block(128, 256, 256)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # H/8 -> H/16
        self.enc_conv3 = self._conv_block(256, 512, 512)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)

        # ---- bottleneck
        self.bottleneck_conv = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # ---- decoder: each unpooled tensor is concatenated with the skip of
        # the same width, so every decoder block takes twice the channels of
        # the tensor it upsamples.
        # H/16 -> H/8
        self.upsample0 = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.dec_conv0 = self._conv_block(512 + 512, 512, 256)
        # H/8 -> H/4
        self.upsample1 = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.dec_conv1 = self._conv_block(256 + 256, 256, 128)
        # H/4 -> H/2
        self.upsample2 = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.dec_conv2 = self._conv_block(128 + 128, 128, 64)
        # H/2 -> H
        self.upsample3 = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.dec_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # Final projection: no BatchNorm / activation, output is raw scores.
            nn.Conv2d(in_channels=64, out_channels=out_channels, kernel_size=3, padding=1),
        )

    @staticmethod
    def _conv_block(in_ch: int, mid_ch: int, out_ch: int) -> nn.Sequential:
        """Return (3x3 conv -> BatchNorm -> ReLU) twice; padding=1 keeps H/W."""
        return nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=mid_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(),
            nn.Conv2d(in_channels=mid_ch, out_channels=out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(N, in_channels, H, W)`` to ``(N, out_channels, H, W)`` scores."""
        # Encoder. Skips must be each block's output *before* pooling: the
        # pooled tensor has half the resolution of the unpooled decoder
        # features it gets concatenated with.
        s0 = self.enc_conv0(x)
        e0, ei0 = self.pool0(s0)
        s1 = self.enc_conv1(e0)
        e1, ei1 = self.pool1(s1)
        s2 = self.enc_conv2(e1)
        e2, ei2 = self.pool2(s2)
        s3 = self.enc_conv3(e2)
        e3, ei3 = self.pool3(s3)

        # Bottleneck
        b = self.bottleneck_conv(e3)

        # Decoder. output_size makes each unpool reproduce the skip's exact
        # H/W even when floor-mode pooling dropped an odd row/column.
        u0 = self.upsample0(b, ei3, output_size=s3.shape[-2:])
        d0 = self.dec_conv0(torch.cat((u0, s3), dim=1))
        u1 = self.upsample1(d0, ei2, output_size=s2.shape[-2:])
        d1 = self.dec_conv1(torch.cat((u1, s2), dim=1))
        u2 = self.upsample2(d1, ei1, output_size=s1.shape[-2:])
        d2 = self.dec_conv2(torch.cat((u2, s1), dim=1))
        u3 = self.upsample3(d2, ei0, output_size=s0.shape[-2:])
        d3 = self.dec_conv3(torch.cat((u3, s0), dim=1))
        return d3  # no activation
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement