Advertisement
Guest User

PyTorch implementation of a Keras segmentation model

a guest
Dec 11th, 2021
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.10 KB | None | 0 0
  1. #segmentation model on keras
  2. puts = Input(shape=(256, 256, 1))
  3.  
  4. # encoding
  5. net = Conv2D(32, kernel_size=3, activation='relu', padding='same')(inputs)
  6. net = MaxPooling2D(pool_size=2, padding='same')(net)
  7. net = Conv2D(64, kernel_size=3, activation='relu', padding='same')(net)
  8. net = MaxPooling2D(pool_size=2, padding='same')(net)
  9. net = Conv2D(128, kernel_size=3, activation='relu', padding='same')(net)
  10. net = MaxPooling2D(pool_size=2, padding='same')(net)
  11.  
  12. net = Dense(128, activation='relu')(net)
  13.  
  14. # decoding
  15. net = UpSampling2D(size=2)(net)
  16. net = Conv2D(128, kernel_size=3, activation='sigmoid', padding='same')(net)
  17. net = UpSampling2D(size=2)(net)
  18. net = Conv2D(64, kernel_size=3, activation='sigmoid', padding='same')(net)
  19. net = UpSampling2D(size=2)(net)
  20.  
  21. # output with 1 channel for gray scale segmenation
  22. outputs = Conv2D(1, kernel_size=3, activation='sigmoid', padding='same')(net)
  23.  
  24. model = Model(inputs=inputs, outputs=outputs)
  25.  
  26. # use binary cross entropy with sigmoid function
  27. model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc', 'mse'])
  28.  
  29. model.summary()
  30.  
  31.  
  32.  
  33. #############################################################################################################################################################################################################################################################################################################################################################################################################
  34.  
  35. #model implementation on pytorch
  36.  
  37. import os
  38. import torch
  39. import torchvision
  40. import tarfile
  41. from torchvision.datasets.utils import download_url
  42. from torch.utils.data import random_split
  43. from torchsummary import summary
  44.  
  45.  
  46.  
  47. # Implementation of CNN/ConvNet Model
  48.  
  49. class build_unet(torch.nn.Module):
  50.  
  51.     def __init__(self):
  52.         super(build_unet, self).__init__()
  53.         # L1 ImgIn shape=(?, 28, 28, 1)
  54.         # Conv -> (?, 28, 28, 32)
  55.         # Pool -> (?, 14, 14, 32)
  56.         keep_prob = 0.5
  57.         self.layer1 = torch.nn.Sequential(
  58.             torch.nn.Conv2d(3, 32, kernel_size=3),
  59.             torch.nn.ReLU(),
  60.             torch.nn.MaxPool2d(kernel_size=2, padding=1))
  61.         # L2 ImgIn shape=(?, 14, 14, 32)
  62.         # Conv      ->(?, 14, 14, 64)
  63.         # Pool      ->(?, 7, 7, 64)
  64.         self.layer2 = torch.nn.Sequential(
  65.             torch.nn.Conv2d(32, 64, kernel_size=3),
  66.             torch.nn.ReLU(),
  67.             torch.nn.MaxPool2d(kernel_size=2, padding=1))
  68.         # L3 ImgIn shape=(?, 7, 7, 64)
  69.         # Conv ->(?, 7, 7, 128)
  70.         # Pool ->(?, 4, 4, 128)
  71.         self.layer3 = torch.nn.Sequential(
  72.             torch.nn.Conv2d(64, 128, kernel_size=3),
  73.             torch.nn.ReLU(),
  74.             torch.nn.MaxPool2d(kernel_size=2, padding=1))
  75.  
  76.  
  77.  
  78.         self.dense = torch.nn.Linear(128, 128, bias=True)
  79.         torch.nn.init.xavier_uniform_(self.dense.weight)
  80.         self.layer4 = torch.nn.Sequential(
  81.             self.dense,
  82.             torch.nn.ReLU(),
  83.             torch.nn.Upsample()
  84.         )
  85.        
  86.         self.layer5 = torch.nn.Sequential(
  87.             torch.nn.Conv2d(128, 128, kernel_size=3),
  88.             torch.nn.Sigmoid(),
  89.             torch.nn.Upsample()
  90.         )
  91.        
  92.         self.layer6 = torch.nn.Sequential(
  93.             torch.nn.Conv2d(128, 64, kernel_size=3),
  94.             torch.nn.Sigmoid(),
  95.             torch.nn.Upsample()
  96.         )
  97.        
  98.         self.layer7 = torch.nn.Sequential(
  99.             torch.nn.Conv2d(64, 1, kernel_size=3),
  100.             torch.nn.Sigmoid()
  101.         )
  102.        
  103.        
  104.        
  105.     def forward(self, x):
  106.         out = self.layer1(x)
  107.         out = self.layer2(out)
  108.         out = self.layer3(out)
  109.         out = self.layer4(out)
  110.         out = self.layer5(out)
  111.         out = self.layer6(out)
  112.         out = self.layer7(out)
  113.        
  114.        
  115.         #         out = out.view(out.size(0), -1)   # Flatten them for FC
  116. #         out = self.fc1(out)
  117. #         out = self.fc2(out)
  118.        
  119.         return out
  120.  
  121. if __name__ == "__main__":
  122.     x = torch.randn((2, 3, 512, 512))
  123.     f = build_unet()
  124.     y = f(x)
  125.     print(y.shape)
  126.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement