Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # for colab
- !gdown 1DXRgzcH89hgc7cs60cENM6C-fU3FdMUX
- import os
- import numpy as np
- import cv2
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import torch.nn.functional as F
- #from google.colab.patches import cv2_imshow
- import torchvision.models.segmentation
- import torchvision
- import torchvision.transforms as T
- from torchvision.transforms.functional import InterpolationMode
- from PIL import Image
- import numpy as np
- import matplotlib.pyplot as plt
- import matplotlib
- from torch.utils.data import Dataset, DataLoader
- !unzip /content/matting.zip -d matting
# Target spatial size for both images and masks.
width = 400
height = 400

# ImageNet channel statistics for the standard pretrained-style normalization.
_MEAN = (0.485, 0.456, 0.406)
_STD = (0.229, 0.224, 0.225)

# Image pipeline: PIL conversion -> bilinear resize -> tensor -> normalize.
transformImg = T.Compose([
    T.ToPILImage(),
    T.Resize((height, width)),
    T.ToTensor(),
    T.Normalize(_MEAN, _STD),
])

# Mask pipeline: nearest-neighbour resize so label values are never blended.
transformAnn = T.Compose([
    T.ToPILImage(),
    T.Resize((height, width), interpolation=InterpolationMode.NEAREST),
    T.ToTensor(),
])
class CustomDataset(Dataset):
    """Matting dataset of RGB images plus binary foreground masks.

    Images live in /content/matting/images/images; ground-truth masks in
    /content/matting/segmentation/images-gt under the same base name with a
    .png extension.  __getitem__ returns (image_tensor, mask_tensor) with
    the mask binarized to {0, 1}.
    """

    IMG_DIR = "/content/matting/images/images"
    MASK_DIR = "/content/matting/segmentation/images-gt"

    def __init__(self):
        # BUGFIX: sort so a given index always maps to the same file;
        # os.listdir order is arbitrary and filesystem-dependent.
        self.listImages = sorted(os.listdir(self.IMG_DIR))
        self.listMasks = sorted(os.listdir(self.MASK_DIR))
        self.len = len(self.listImages)
        self.lenm = len(self.listMasks)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        img = cv2.imread(os.path.join(self.IMG_DIR, self.listImages[idx]))
        # BUGFIX: the mask shares the image's base name with a .png
        # extension; splitext is safer than str.replace("jpg", "png"),
        # which would also rewrite a "jpg" substring elsewhere in the name.
        mask_name = os.path.splitext(self.listImages[idx])[0] + ".png"
        mask = cv2.imread(os.path.join(self.MASK_DIR, mask_name), 0)  # grayscale
        ann = np.zeros(img.shape[0:2], np.float32)
        if mask is not None:
            ann[mask > 0] = 1  # binarize: any nonzero pixel is foreground
        img = transformImg(img)
        ann = transformAnn(ann)
        return img, ann
class BaseConv(nn.Module):
    """Two stacked convolutions, each followed by BatchNorm + ReLU.

    NOTE(review): both convolutions reuse the *same* BatchNorm2d instance,
    so its running statistics mix conv1's and conv2's activations.  Kept
    as-is to preserve the module layout (state_dict keys) — confirm this
    sharing is intentional.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding,
                 stride):
        super(BaseConv, self).__init__()
        self.act = nn.ReLU()
        # BUGFIX: nn.Conv2d's positional order is (in, out, kernel_size,
        # stride, padding); the original passed `padding, stride` into the
        # wrong slots.  It only worked here because both happen to be 1.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size,
                               stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1,
                                 affine=True, track_running_stats=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size,
                               stride=stride, padding=padding)

    def forward(self, x):
        x = self.act(self.bn(self.conv1(x)))
        x = self.act(self.bn(self.conv2(x)))
        return x
class DownConv(nn.Module):
    """Encoder stage: stride-2 conv halves resolution, then a BaseConv block."""

    def __init__(self, in_channels, out_channels, kernel_size, padding,
                 stride):
        super(DownConv, self).__init__()
        self.act = nn.ReLU()
        # Learned downsampling instead of max-pooling: a stride-2 conv
        # halves H and W while keeping the channel count.
        self.conv_downsize = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, padding=padding, stride=2)
        self.bn1 = nn.BatchNorm2d(in_channels, eps=1e-05, momentum=0.1,
                                  affine=True, track_running_stats=True)
        self.bn2 = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1,
                                  affine=True, track_running_stats=True)
        self.conv_block = BaseConv(in_channels, out_channels,
                                   kernel_size, padding, stride)

    def forward(self, x):
        down = self.act(self.bn1(self.conv_downsize(x)))
        # NOTE(review): conv_block already ends in BN+ReLU, so bn2/act here
        # normalize a second time; preserved to keep behavior identical.
        return self.act(self.bn2(self.conv_block(down)))
class UpConv(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, concat skip, BaseConv block."""

    def __init__(self, in_channels, in_channels_skip, out_channels,
                 kernel_size, padding, stride):
        super(UpConv, self).__init__()
        self.act = nn.ReLU()
        # A 2x2 transposed conv with stride 2 exactly doubles H and W.
        self.conv_trans1 = nn.ConvTranspose2d(
            in_channels, in_channels, kernel_size=2, padding=0, stride=2)
        self.bn1 = nn.BatchNorm2d(in_channels, eps=1e-05, momentum=0.1,
                                  affine=True, track_running_stats=True)
        self.bn2 = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1,
                                  affine=True, track_running_stats=True)
        self.conv_block = BaseConv(
            in_channels=in_channels + in_channels_skip,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride)

    def forward(self, x, x_skip):
        up = self.act(self.bn1(self.conv_trans1(x)))
        # Crop the skip tensor to the upsampled size before concatenating
        # along channels (guards against odd input dimensions).
        skip = x_skip[:, :, :up.shape[2], :up.shape[3]]
        merged = torch.cat((up, skip), dim=1)
        return self.act(self.bn2(self.conv_block(merged)))
class UNet(nn.Module):
    """Three-level U-Net returning per-pixel log-probabilities over n_class classes."""

    def __init__(self, in_channels, out_channels, n_class, kernel_size,
                 padding, stride):
        super(UNet, self).__init__()
        self.init_conv = BaseConv(in_channels, out_channels, kernel_size,
                                  padding, stride)
        # Encoder: each stage halves resolution and doubles channels.
        self.down1 = DownConv(out_channels, 2 * out_channels, kernel_size,
                              padding, stride)
        self.down2 = DownConv(2 * out_channels, 4 * out_channels, kernel_size,
                              padding, stride)
        self.down3 = DownConv(4 * out_channels, 8 * out_channels, kernel_size,
                              padding, stride)
        # Decoder: each stage doubles resolution and merges the matching skip.
        self.up3 = UpConv(8 * out_channels, 4 * out_channels, 4 * out_channels,
                          kernel_size, padding, stride)
        self.up2 = UpConv(4 * out_channels, 2 * out_channels, 2 * out_channels,
                          kernel_size, padding, stride)
        self.up1 = UpConv(2 * out_channels, out_channels, out_channels,
                          kernel_size, padding, stride)
        # BUGFIX: nn.Conv2d's positional order is (in, out, kernel_size,
        # stride, padding); the original passed `padding, stride` swapped.
        # Keyword arguments make the intent explicit and correct.
        self.out = nn.Conv2d(out_channels, n_class, kernel_size,
                             stride=stride, padding=padding)

    def forward(self, x):
        # Encoder
        x = self.init_conv(x)
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        # Decoder with skip connections
        x_up = self.up3(x3, x2)
        x_up = self.up2(x_up, x1)
        x_up = self.up1(x_up, x)
        # log_softmax is redundant-but-harmless before the training loop's
        # CrossEntropyLoss: log_softmax is idempotent, so the loss value is
        # identical to feeding raw logits.
        x_out = F.log_softmax(self.out(x_up), 1)
        return x_out
# make data loader
# BUGFIX: the dataset instance was never constructed before being handed to
# DataLoader ("cds" was undefined, crashing with NameError).
cds = CustomDataset()
dl_train = DataLoader(cds, batch_size=3, shuffle=True, drop_last=False)

# Train on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

unet = UNet(in_channels=3,
            out_channels=64,
            n_class=2,        # background / foreground
            kernel_size=3,
            padding=1,
            stride=1)
unet = unet.to(device)
unet.train()
print("ready")
# training loop
epochs = 1
unet.train()
# CrossEntropyLoss expects class scores (or log-probabilities, which the
# network emits; log_softmax is idempotent, so the loss is equivalent).
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=unet.parameters(), lr=1e-5)  # Adam optimizer
batch = 0
for epoch in range(epochs):
    for xb, yb in dl_train:
        batch = batch + 1
        yb = torch.squeeze(yb, 1)   # (B,1,H,W) mask -> (B,H,W) class map
        # torch.autograd.Variable is deprecated; move tensors directly.
        xb = xb.to(device)          # image batch
        yb = yb.to(device)          # annotation batch
        y_hat = unet(xb)            # forward pass on one batch
        # Clearing grads via the optimizer is equivalent to unet.zero_grad()
        # here (it owns all of unet's parameters); set_to_none saves memory.
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(y_hat, yb.long())
        loss.backward()
        optimizer.step()
        # loss.item() replaces the deprecated loss.data access.
        print(batch, ") Loss=", loss.item())
        if batch % 30 == 0:  # checkpoint every 30 batches
            print("Saving Model" + str(batch) + ".torch")
            torch.save(unet.state_dict(), str(batch) + ".torch")
# test: run inference on one sample and visualize the predicted mask
i, m = cds[101]
i = i.unsqueeze(0)   # add batch dimension -> (1, 3, H, W)
print(i.shape)
# BUGFIX: use the device chosen earlier instead of hard-coding 'cuda', so
# this cell also runs on CPU-only machines (the original crashed there).
i = i.to(device)
unet.eval()
with torch.no_grad():   # no gradients needed for evaluation
    pred = unet(i)
print(pred.shape)
result = torch.argmax(pred, dim=1)   # (1, H, W) predicted class indices
print(result.shape)
result = result.float()
# Sanity check: background (0) vs foreground (1) pixel counts sum to H*W.
a, b = torch.numel(result[result == 0]), torch.numel(result[result == 1])
print(a, b, a + b)
imgx = T.ToPILImage()(result)
print(imgx.mode)
plt.imshow(imgx)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement