Guest User

Untitled

a guest
Nov 17th, 2018
110
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.58 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3.  
  4. class UNet_down_block(torch.nn.Module):
  5. def __init__(self, input_channel, output_channel, down_size):
  6. super(UNet_down_block, self).__init__()
  7. self.conv1 = torch.nn.Conv2d(input_channel, output_channel, 3, padding=1)
  8. self.bn1 = torch.nn.BatchNorm2d(output_channel)
  9. self.conv2 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
  10. self.bn2 = torch.nn.BatchNorm2d(output_channel)
  11. self.conv3 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
  12. self.bn3 = torch.nn.BatchNorm2d(output_channel)
  13. self.max_pool = torch.nn.MaxPool2d(2, 2)
  14. self.relu = torch.nn.ReLU()
  15. self.down_size = down_size
  16.  
  17. def forward(self, x):
  18. if self.down_size:
  19. x = self.max_pool(x)
  20. x = self.relu(self.bn1(self.conv1(x)))
  21. x = self.relu(self.bn2(self.conv2(x)))
  22. x = self.relu(self.bn3(self.conv3(x)))
  23. return x
  24.  
  25. class UNet_up_block(torch.nn.Module):
  26. def __init__(self, prev_channel, input_channel, output_channel):
  27. super(UNet_up_block, self).__init__()
  28. self.up_sampling = torch.nn.Upsample(scale_factor=2, mode='bilinear')
  29. self.conv1 = torch.nn.Conv2d(input_channel + input_channel, output_channel, 3, padding=1)
  30. self.bn1 = torch.nn.BatchNorm2d(output_channel)
  31. self.conv2 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
  32. self.bn2 = torch.nn.BatchNorm2d(output_channel)
  33. self.conv3 = torch.nn.Conv2d(output_channel, output_channel, 3, padding=1)
  34. self.bn3 = torch.nn.BatchNorm2d(output_channel)
  35. self.relu = torch.nn.ReLU()
  36.  
  37. # self.up1=torch.nn.ConvTranspose2d(12,25,3,stride=2,padding=1)
  38.  
  39. def forward(self, prev_feature_map, x,k):
  40. # print('before up',x.size())
  41. if k!=0:
  42. x = self.up_sampling(x)
  43. x = torch.cat((x, prev_feature_map), dim=1)
  44. x = self.relu(self.bn1(self.conv1(x)))
  45. x = self.relu(self.bn2(self.conv2(x)))
  46. x = self.relu(self.bn3(self.conv3(x)))
  47. return x
  48.  
  49.  
  50. class UNet(torch.nn.Module):
  51. def __init__(self):
  52. super(UNet, self).__init__()
  53.  
  54. self.down_block1 = UNet_down_block(3, 16, False)
  55. self.down_block2 = UNet_down_block(16, 32, True)
  56. self.down_block3 = UNet_down_block(32, 64, True)
  57.  
  58. self.mid_conv1 = torch.nn.Conv2d(64, 64, 3, padding=1)
  59. self.bn1 = torch.nn.BatchNorm2d(64)
  60. self.mid_conv2 = torch.nn.Conv2d(64, 64, 3, padding=1)
  61. self.bn2 = torch.nn.BatchNorm2d(64)
  62. self.mid_conv3 = torch.nn.Conv2d(64, 64, 3, padding=1)
  63. self.bn3 = torch.nn.BatchNorm2d(64)
  64.  
  65. self.up_block5 = UNet_up_block(32, 64, 32)
  66. self.up_block6 = UNet_up_block(16, 32, 16)
  67. self.up_block7 = UNet_up_block(3, 16, 16)
  68.  
  69. self.last_conv1 = torch.nn.Conv2d(16, 3, 3, padding=1)
  70. self.last_bn = torch.nn.BatchNorm2d(3)
  71. self.last_conv2 = torch.nn.Conv2d(3, 1, 1, padding=0)
  72. self.relu = torch.nn.ReLU()
  73.  
  74. self.max_pool = torch.nn.MaxPool2d(2, 2)
  75.  
  76. def forward(self, x):
  77. # ins=x.clone()
  78. self.x1 = self.down_block1(x)
  79. # print('self.x1',self.x1.size())
  80. self.x2 = self.down_block2(self.x1)
  81. # print('self.x2',self.x2.size())
  82. self.x3 = self.down_block3(self.x2)
  83. # print('self.x3',self.x3.size())
  84.  
  85. # self.mid=self.max_pool(self.x3)
  86.  
  87.  
  88. self.x7 = self.relu(self.bn1(self.mid_conv1(self.x3)))
  89. self.x7 = self.relu(self.bn2(self.mid_conv2(self.x7)))
  90. self.x7 = self.relu(self.bn3(self.mid_conv3(self.x7)))
  91.  
  92. # print('prev,x',self.x7.size(),self.x3.size())
  93.  
  94. x = self.up_block5(self.x3, self.x7,k=0)
  95. x = self.up_block6(self.x2, x,k=1)
  96. x=self.up_block7(self.x1,x,k=1)
  97. x = self.relu(self.last_bn(self.last_conv1(x)))
  98. x = self.last_conv2(x)
  99. return x
  100.  
  101.  
  102. def dice(input, taget):
  103. smooth=.001
  104. input=input.view(-1)
  105. target=taget.view(-1)
  106.  
  107. return(1-2*(input*target).sum()/(input.sum()+taget.sum()+smooth))
  108.  
  109. net = UNet()
  110.  
  111. x = torch.randn(1, 3, 100, 100)
  112. target = torch.randint(0, 2, (1, 1, 100, 100), dtype=torch.float32)
  113. optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
  114. criterion = nn.BCEWithLogitsLoss()
  115.  
  116. for epoch in range(20):
  117. optimizer.zero_grad()
  118. output = net(x)
  119. bce_loss = criterion(output, target)
  120. dice_loss = dice(output, target)
  121. loss = bce_loss + dice_loss
  122. loss.backward()
  123. optimizer.step()
  124. print('Epoch {}, loss {}, bce {}, dice {}'.format(
  125. epoch, loss.item(), bce_loss.item(), dice_loss.item()))
Add Comment
Please, Sign In to add comment