Advertisement
BlackBB

unet

Aug 9th, 2022
36
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.81 KB | None | 0 0
  1. # for colab
  2. !gdown 1DXRgzcH89hgc7cs60cENM6C-fU3FdMUX
  3.  
  4. import os
  5. import numpy as np
  6. import cv2
  7. import torch
  8. import torch.nn as nn
  9. import torch.optim as optim
  10. import torch.nn.functional as F
  11. #from google.colab.patches import cv2_imshow
  12. import torchvision.models.segmentation
  13. import torchvision
  14. import torchvision.transforms as T
  15. from torchvision.transforms.functional import InterpolationMode
  16. from PIL import Image
  17. import numpy as np
  18. import matplotlib.pyplot as plt
  19. import matplotlib
  20. from torch.utils.data import Dataset, DataLoader
  21.  
  22. !unzip /content/matting.zip -d matting
  23.  
  24. width = 400
  25. height = 400
  26. transformImg=T.Compose([T.ToPILImage(),T.Resize((height,width)), T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
  27. transformAnn=T.Compose([T.ToPILImage(),T.Resize((height,width), interpolation=InterpolationMode.NEAREST), T.ToTensor()])
  28.  
  29. class CustomDataset(Dataset):
  30. """matting dataset """
  31.  
  32. def __init__(self):
  33. self.listImages=os.listdir("/content/matting/images/images")
  34. self.listMasks=os.listdir("/content/matting/segmentation/images-gt")
  35. self.len = len(self.listImages)
  36. self.lenm = len(self.listMasks)
  37.  
  38. def __len__(self):
  39. return self.len
  40.  
  41. def __getitem__(self, idx):
  42.  
  43. img=cv2.imread(os.path.join("/content/matting/images/images", self.listImages[idx]) )
  44. mask = cv2.imread(os.path.join("/content/matting/segmentation/images-gt", self.listImages[idx].replace("jpg","png")) , 0)
  45.  
  46. ann = np.zeros(img.shape[0:2],np.float32)
  47. if mask is not None: ann[ mask > 0 ] = 1
  48.  
  49.  
  50. img=transformImg(img)
  51. ann=transformAnn(ann)
  52. return img, ann
  53.  
  54. class BaseConv(nn.Module):
  55. def __init__(self, in_channels, out_channels, kernel_size, padding,
  56. stride):
  57. super(BaseConv, self).__init__()
  58.  
  59. self.act = nn.ReLU()
  60.  
  61. self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding,
  62. stride)
  63. self.bn = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  64.  
  65. self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,
  66. padding, stride)
  67.  
  68. def forward(self, x):
  69. x = self.act(self.bn(self.conv1(x)))
  70. x = self.act(self.bn(self.conv2(x)))
  71. return x
  72.  
  73.  
  74. class DownConv(nn.Module):
  75. def __init__(self, in_channels, out_channels, kernel_size, padding,
  76. stride):
  77. super(DownConv, self).__init__()
  78.  
  79. self.act = nn.ReLU()
  80.  
  81. #self.pool1 = nn.MaxPool2d(kernel_size=2)
  82. self.conv_downsize = nn.Conv2d(in_channels, in_channels,
  83. kernel_size=kernel_size,
  84. padding=padding,
  85. stride=2)
  86.  
  87. self.bn1 = nn.BatchNorm2d(in_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  88.  
  89. self.bn2 = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  90.  
  91. self.conv_block = BaseConv(in_channels, out_channels, kernel_size, padding, stride)
  92.  
  93. def forward(self, x):
  94. x = self.act(self.bn1(self.conv_downsize(x)))
  95. x = self.act(self.bn2(self.conv_block(x)))
  96. return x
  97.  
  98.  
  99. class UpConv(nn.Module):
  100. def __init__(self, in_channels, in_channels_skip, out_channels,
  101. kernel_size, padding, stride):
  102. super(UpConv, self).__init__()
  103.  
  104. self.act = nn.ReLU()
  105.  
  106. self.conv_trans1 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, padding=0, stride=2)
  107.  
  108. self.bn1 = nn.BatchNorm2d(in_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  109. self.bn2 = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  110.  
  111. self.conv_block = BaseConv(
  112. in_channels=in_channels + in_channels_skip,
  113. out_channels=out_channels,
  114. kernel_size=kernel_size,
  115. padding=padding,
  116. stride=stride)
  117.  
  118. def forward(self, x, x_skip):
  119. x = self.act(self.bn1(self.conv_trans1(x)))
  120.  
  121. x = torch.cat ((x, x_skip[:, :, :x.shape[2], :x.shape[3]]), dim = 1)
  122.  
  123. x = self.act(self.bn2(self.conv_block(x)))
  124. return x
  125.  
  126.  
  127. class UNet(nn.Module):
  128. def __init__(self, in_channels, out_channels, n_class, kernel_size,
  129. padding, stride):
  130. super(UNet, self).__init__()
  131.  
  132. self.init_conv = BaseConv(in_channels, out_channels, kernel_size,
  133. padding, stride)
  134.  
  135. self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size,
  136. padding, stride)
  137.  
  138. self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size,
  139. padding, stride)
  140.  
  141. self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size,
  142. padding, stride)
  143.  
  144. self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels,
  145. kernel_size, padding, stride)
  146.  
  147. self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels,
  148. kernel_size, padding, stride)
  149.  
  150. self.up1 = UpConv(2 * out_channels, out_channels, out_channels,
  151. kernel_size, padding, stride)
  152.  
  153. self.out = nn.Conv2d(out_channels, n_class, kernel_size, padding, stride)
  154.  
  155. def forward(self, x):
  156. # Encoder
  157. x = self.init_conv(x)
  158. x1 = self.down1(x)
  159. x2 = self.down2(x1)
  160. x3 = self.down3(x2)
  161. # Decoder
  162. x_up = self.up3(x3, x2)
  163. x_up = self.up2(x_up, x1)
  164. x_up = self.up1(x_up, x)
  165. x_out = F.log_softmax(self.out(x_up), 1)
  166. return x_out
  167.  
  168. # make data loader
  169. dl_train = DataLoader(cds, batch_size=3, shuffle=True, drop_last=False)
  170. device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  171. unet = UNet(in_channels=3,
  172. out_channels=64,
  173. n_class=2,
  174. kernel_size=3,
  175. padding=1,
  176. stride=1)
  177.  
  178. unet=unet.to(device)
  179. unet.train()
  180. print("ready")
  181.  
  182. # training loop
  183. epochs=1
  184. unet.train()
  185. criterion = torch.nn.CrossEntropyLoss()
  186. optimizer=torch.optim.Adam(params=unet.parameters(),lr=1e-5) # Create adam optimizer
  187. batch = 0
  188. for epoch in range(epochs):
  189.  
  190. for xb, yb in dl_train:
  191. batch=batch+1
  192. yb=torch.squeeze(yb,1)
  193. xb = torch.autograd.Variable(xb, requires_grad=False).to(device) # Load image
  194. yb = torch.autograd.Variable(yb, requires_grad=False).to(device) # Load annotation
  195. y_hat = unet(xb) # one batch
  196.  
  197.  
  198. unet.zero_grad()#set_to_none=True)
  199. loss = criterion(y_hat, yb.long())
  200. loss.backward()
  201. optimizer.step()
  202.  
  203. # seg = torch.argmax(y_hat[0], 0).cpu().detach().numpy() # Get prediction classes
  204. print(batch,") Loss=",loss.data.cpu().numpy())
  205. if batch % 30 == 0: #Save model weight once every 60k steps permenant file
  206. print("Saving Model" +str(batch) + ".torch")
  207. torch.save(unet.state_dict(), str(batch) + ".torch")
  208.  
  209. # test
  210.  
  211. i,m=cds[101]
  212. i=i.unsqueeze(0)
  213. print(i.shape)
  214. i=i.to('cuda')
  215. unet.eval()
  216. pred = unet(i.to('cuda'))
  217. print(pred.shape)
  218. result = torch.argmax(pred, dim=1)
  219. print(result.shape)
  220. result = result.float()
  221. a,b = torch.numel(result[result==0]), torch.numel(result[result==1])
  222. print(a,b, a+b)
  223. imgx = T.ToPILImage()(result)
  224. print(imgx.mode)
  225. plt.imshow(imgx)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement