Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import torch
import torch.nn as nn
class UNet_down_block(torch.nn.Module):
    """Encoder stage: optional 2x2 max-pool, then three conv-BN-ReLU layers.

    Args:
        input_channel: number of channels of the incoming feature map.
        output_channel: number of channels produced by every convolution.
        down_size: when truthy, the input is max-pooled (kernel 2, stride 2)
            before the convolutions; otherwise the spatial size is kept.
    """

    def __init__(self, input_channel, output_channel, down_size):
        super().__init__()
        # 3x3 convolutions with padding 1 leave height and width unchanged.
        self.conv1 = torch.nn.Conv2d(
            input_channel, output_channel, kernel_size=3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(output_channel)
        self.conv2 = torch.nn.Conv2d(
            output_channel, output_channel, kernel_size=3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(output_channel)
        self.conv3 = torch.nn.Conv2d(
            output_channel, output_channel, kernel_size=3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(output_channel)
        self.max_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = torch.nn.ReLU()
        self.down_size = down_size

    def forward(self, x):
        """Return the feature map after optional pooling and three conv stages."""
        out = self.max_pool(x) if self.down_size else x
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            out = self.relu(norm(conv(out)))
        return out
class UNet_up_block(torch.nn.Module):
    """Decoder stage: optional upsampling, skip concatenation, three conv-BN-ReLU layers.

    Args:
        prev_channel: accepted for interface compatibility but not used.
            NOTE(review): the first convolution is sized as
            ``input_channel + input_channel``, i.e. it assumes the skip
            feature map has the same channel count as ``x``.
        input_channel: number of channels of ``x`` (and, by the assumption
            above, of the skip feature map).
        output_channel: number of channels produced by every convolution.
    """

    def __init__(self, prev_channel, input_channel, output_channel):
        super(UNet_up_block, self).__init__()
        # align_corners=False is what PyTorch uses when the argument is
        # omitted; spelling it out keeps the interpolation mode explicit.
        self.up_sampling = torch.nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False)
        self.conv1 = torch.nn.Conv2d(input_channel + input_channel, output_channel, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(output_channel)
        self.conv2 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(output_channel)
        self.conv3 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(output_channel)
        self.relu = torch.nn.ReLU()

    def forward(self, prev_feature_map, x, k=1):
        """Fuse ``x`` with the skip connection ``prev_feature_map``.

        Args:
            prev_feature_map: skip feature map, concatenated after ``x``
                along the channel dimension.
            x: feature map coming from the previous (deeper) stage.
            k: ``0`` skips upsampling (``x`` must already match the skip
                map spatially); any other value upsamples ``x`` first.

        Returns:
            The fused feature map with ``output_channel`` channels and the
            spatial size of ``prev_feature_map``.
        """
        if k != 0:
            skip_size = tuple(prev_feature_map.shape[2:])
            if tuple(2 * s for s in x.shape[2:]) == skip_size:
                x = self.up_sampling(x)
            else:
                # A max-pool on an odd-sized map drops a row/column, so a
                # fixed 2x upsample would not match the skip map and the
                # concatenation below would fail; resize to its exact size.
                x = torch.nn.functional.interpolate(
                    x, size=skip_size, mode='bilinear', align_corners=False)
        x = torch.cat((x, prev_feature_map), dim=1)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return x
class UNet(torch.nn.Module):
    """Small three-level U-Net producing a per-pixel map of raw scores.

    The encoder widens the features to 16, 32 and 64 channels while pooling
    twice; a 64-channel bottleneck follows; the decoder fuses each encoder
    output back in and a two-layer head reduces to ``out_channels``.
    No activation is applied to the final output.

    Spatial sizes divisible by 4 are always safe: both 2x poolings halve
    cleanly and 2x upsampling restores the skip-connection sizes.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the output map (default 1).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()
        self.down_block1 = UNet_down_block(in_channels, 16, False)
        self.down_block2 = UNet_down_block(16, 32, True)
        self.down_block3 = UNet_down_block(32, 64, True)
        self.mid_conv1 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.mid_conv2 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(64)
        self.mid_conv3 = torch.nn.Conv2d(64, 64, 3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(64)
        # First argument is the channel count of the skip map each block
        # receives in forward(): 64 (down_block3), 32 (down_block2),
        # 16 (down_block1).
        self.up_block5 = UNet_up_block(64, 64, 32)
        self.up_block6 = UNet_up_block(32, 32, 16)
        self.up_block7 = UNet_up_block(16, 16, 16)
        self.last_conv1 = torch.nn.Conv2d(16, 3, 3, padding=1)
        self.last_bn = torch.nn.BatchNorm2d(3)
        self.last_conv2 = torch.nn.Conv2d(3, out_channels, 1, padding=0)
        self.relu = torch.nn.ReLU()
        # Not used by forward(); kept so the attribute remains available.
        self.max_pool = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Run the encoder, bottleneck, decoder and output head on ``x``.

        Intermediate activations are kept in local variables rather than
        on the module, so they (and their autograd graph) are released as
        soon as the caller drops the result.
        """
        skip1 = self.down_block1(x)
        skip2 = self.down_block2(skip1)
        skip3 = self.down_block3(skip2)

        mid = self.relu(self.bn1(self.mid_conv1(skip3)))
        mid = self.relu(self.bn2(self.mid_conv2(mid)))
        mid = self.relu(self.bn3(self.mid_conv3(mid)))

        # The bottleneck has the same resolution as skip3, so the first
        # decoder stage concatenates without upsampling (k=0).
        x = self.up_block5(skip3, mid, k=0)
        x = self.up_block6(skip2, x, k=1)
        x = self.up_block7(skip1, x, k=1)

        x = self.relu(self.last_bn(self.last_conv1(x)))
        x = self.last_conv2(x)
        return x
def dice(input, taget):
    """Return the soft Dice loss ``1 - Dice`` between two tensors.

    Both tensors are flattened and compared element-wise, so they must
    hold the same number of elements.

    Args:
        input: predicted scores. For the loss to stay within [0, 1] these
            should be probabilities in [0, 1], not raw logits.
        taget: ground-truth mask (parameter name kept for compatibility
            with existing keyword callers).

    Returns:
        A scalar tensor; 0 means perfect overlap.
    """
    smooth = .001
    # reshape() also handles non-contiguous tensors, where view() raises.
    pred = input.reshape(-1)
    truth = taget.reshape(-1)
    intersection = (pred * truth).sum()
    # Smoothing in numerator and denominator: an empty prediction against
    # an empty target scores as a perfect match (loss 0) instead of 1.
    return 1 - (2 * intersection + smooth) / (pred.sum() + truth.sum() + smooth)
# Smoke-test training loop: fit the network to one random image/mask pair
# using the sum of a binary cross-entropy loss and a Dice loss.
net = UNet()
x = torch.randn(1, 3, 100, 100)
target = torch.randint(0, 2, (1, 1, 100, 100), dtype=torch.float32)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# BCEWithLogitsLoss applies the sigmoid itself, so it takes the raw output.
criterion = nn.BCEWithLogitsLoss()
for epoch in range(20):
    optimizer.zero_grad()
    output = net(x)
    bce_loss = criterion(output, target)
    # The network emits unbounded logits; the Dice overlap is only
    # meaningful on probabilities, so squash them with a sigmoid first.
    dice_loss = dice(torch.sigmoid(output), target)
    loss = bce_loss + dice_loss
    loss.backward()
    optimizer.step()
    print('Epoch {}, loss {}, bce {}, dice {}'.format(
        epoch, loss.item(), bce_loss.item(), dice_loss.item()))
Add Comment
Please, Sign In to add comment