Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# %%
import torch
import os
import numpy as np
import cv2
import torch.nn.functional as F
from torch.nn import Linear, LazyConv2d, ReLU
# Single shared (stateless) ReLU instance used by several modules below.
relu = ReLU()
# %%
# Dataset layout: parallel input/mask directories with matching filenames.
train_inp_dir = "./cityscapes_data/new_train/input/"
train_mask_dir = "./cityscapes_data/new_train/masks/"
test_inp_dir = "./cityscapes_data/new_test/input/"
test_mask_dir = "./cityscapes_data/new_test/masks/"
# Per-batch training telemetry, exported to CSV at the end of the script.
# (idx, epoch, batch, running loss, memory used [CUDA])
proc_state_snapshot = []
# Start recording CUDA allocations for the snapshot dumped after training.
# NOTE(review): private torch API; presumably needs a CUDA build — verify.
torch.cuda.memory._record_memory_history()
# %% [markdown]
# ## Hyperparams.
# %%
patch_size = 16                 # square ViT patch edge (pixels)
inp_H, inp_W = (256, 256)       # model input resolution
num_classes = 8                 # segmentation classes
in_channels = 3                 # image channels (cv2 loads BGR)
num_patches = (inp_H * inp_W)//(patch_size ** 2)  # tokens per image (= 256)
batch_size = 32
filters = [16, 32, 64, 128]     # CNN encoder channel widths, shallow -> deep
num_layers = 12                 # transformer encoder iterations
num_skip_conn = len(filters)    # one skip connection per encoder stage
drop_rate = 0.0                 # attention dropout probability
hidden_dim = filters[0]         # transformer embedding size (= 16)
num_heads = hidden_dim // 2     # attention heads (= 8)
learning_rate = 1e-2
num_epochs = 15
smooth = 1e-7                   # Dice-loss smoothing term
- # %% [markdown]
- # ## Patch generator
- #
- # - `in_dims` : `(batch_size, in_channels, inp_H, inp_W)`
- # - `out_dims` : `(batch_size, patch_size_H * patch_size_W * in_channels, num_patches)`
- # %%
- from torch.nn import Unfold
class patchGenerator(torch.nn.Module):
    """Split an image batch into flattened, non-overlapping patches.

    Input : (B, in_channels, H, W)
    Output: (B, num_patches, patch_size * patch_size * in_channels)

    `device` is accepted for interface symmetry with the other modules;
    Unfold itself has no parameters.  `patch_size` defaults to the
    module-level hyperparameter value (16).
    """
    def __init__(self, device, patch_size=16) -> None:
        super(patchGenerator, self).__init__()
        # Non-overlapping patches: stride == kernel size.
        self.unfold = Unfold(kernel_size=(patch_size, patch_size), dilation=(1, 1),
                             padding=0, stride=(patch_size, patch_size))

    def forward(self, X):
        # Unfold -> (B, C * p * p, num_patches).
        X = self.unfold(X)
        # BUG FIX: the original swapped the last two axes with reshape(),
        # which scrambles values instead of transposing them.  transpose()
        # yields (B, num_patches, C * p * p) with each row a real patch.
        return X.transpose(1, 2)
- # %% [markdown]
- # ## Patch encoder
- # - `in_dims` : `(batch_size, patch_size_H * patch_size_W * in_channels, num_patches)`
- # - `out_dims` : `(batch_size, hidden_dim, num_patches)`
- # %%
- from torch.nn import Embedding
class patchEncoder(torch.nn.Module):
    """Linearly project flattened patches and add learned position embeddings.

    Input : (B, num_patches, patch_dim)   where patch_dim = p * p * C
    Output: (B, num_patches, hidden_dim)

    Defaults mirror the module-level hyperparameters (256 tokens,
    patch_dim 16*16*3 = 768, hidden_dim 16).
    """
    def __init__(self, device, num_patches=256, patch_dim=768, hidden_dim=16) -> None:
        super(patchEncoder, self).__init__()
        self.projLayer = Linear(in_features=patch_dim, out_features=hidden_dim, device=device)
        self.pos_emb = Embedding(num_embeddings=num_patches, embedding_dim=hidden_dim, device=device)
        # Fixed token indices; a buffer so .to(device) moves them with the
        # model instead of rebuilding + re-transferring every forward pass.
        self.register_buffer("positions", torch.arange(num_patches, device=device))
        self.device = device

    def forward(self, X):
        # Broadcast add: (B, N, hidden) + (N, hidden).
        return self.projLayer(X) + self.pos_emb(self.positions)
- # %% [markdown]
- # ## Transformer Encoder
- # %%
- from torch.nn import LayerNorm, MultiheadAttention, Linear
class transformerEncoder(torch.nn.Module):
    """Pre-norm Transformer encoder applied to patch embeddings.

    Input/Output: (B, num_patches, hidden_dim).

    NOTE: one attention/LayerNorm/MLP set of weights is reused for all
    `num_layers` iterations (weight tying), matching the original design.
    Defaults mirror the module-level hyperparameters.
    """
    def __init__(self, device, hidden_dim=16, num_heads=8, num_layers=12, drop_rate=0.0) -> None:
        super(transformerEncoder, self).__init__()
        # BUG FIX: batch_first=True so dim 0 is the batch.  The default
        # (seq, batch, embed) layout made attention mix across the batch.
        self.mha = MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
                                      dropout=drop_rate, batch_first=True, device=device)
        self.num_layers = num_layers
        self.lnorm = LayerNorm(hidden_dim, device=device)
        self.dense1 = Linear(hidden_dim, hidden_dim * 2, device=device)
        self.dense2 = Linear(hidden_dim * 2, hidden_dim, device=device)

    def forward(self, z):
        for _ in range(self.num_layers):
            # Attention sub-block with residual (added to the normed input,
            # preserving the original wiring).
            normed = self.lnorm(z)
            attn = self.mha(normed, normed, normed)[0] + normed
            # MLP sub-block.  BUG FIX: activation between the two linear
            # layers — without it they collapse into a single linear map.
            normed = self.lnorm(attn)
            mlp = self.dense2(F.gelu(self.dense1(normed)))
            z = attn + mlp
        return z
- # %% [markdown]
- # ## Casacaded UpSampler
- # Here, the entire upsampling operation is split into multiple blocks. Each block performs upsampling, concat, convolution (in that order). Achieved by using nested classes where each instance of the class `CUS_block` represents an upsampling block.
- # %%
- from torch.nn import Upsample
class CUS(torch.nn.Module):
    """Cascaded UpSampler.

    The upsampling path is split into blocks; each block upsamples 2x,
    concatenates the matching encoder skip connection (deepest first),
    then applies conv + ReLU.  A final 2x upsample restores full
    resolution.  `filters` defaults to the module-level hyperparameter.
    """
    class CUS_block(torch.nn.Module):
        """One upsample -> concat-skip -> conv -> ReLU stage."""
        def __init__(self, out_channels, mode, device) -> None:
            super().__init__()
            # Lazy conv: input channels depend on the concatenated skip.
            self.conv = LazyConv2d(out_channels=out_channels, kernel_size=3,
                                   stride=(1, 1), padding='same', device=device)
            self.upsamp = Upsample(scale_factor=2, mode=mode)

        def forward(self, X, conv_op):
            X = self.upsamp(X)
            # Channel-wise concat with the matching encoder activation.
            X = torch.concat([X, conv_op], dim=1)
            return torch.relu(self.conv(X))

    def __init__(self, device, filters=(16, 32, 64, 128)) -> None:
        super(CUS, self).__init__()
        # BUG FIX: ModuleList (not a plain Python list) so the blocks'
        # parameters are registered and actually reach the optimizer.
        self.CUS_block_list = torch.nn.ModuleList(
            self.CUS_block(ch, 'nearest', device) for ch in reversed(filters))
        self.upsamp = Upsample(scale_factor=2, mode='nearest')
        self.num_blocks = len(filters)

    def forward(self, X, skip_conn):
        # Consume skip connections deepest-first.
        for i, block in enumerate(self.CUS_block_list):
            X = block(X, skip_conn[self.num_blocks - i - 1])
        return self.upsamp(X)
- # %% [markdown]
- # ## Convolutional encoder
- # %%
class convolutionalEncoder(torch.nn.Module):
    """CNN encoder producing one skip connection per stage.

    Each stage halves the spatial resolution with a stride-2 conv and
    widens the channels to filters[i].  forward() returns the list of
    per-stage activations (shallowest first) for the decoder's skips.
    """
    class conv_block(torch.nn.Module):
        """Stride-2 conv + ReLU downsampling stage."""
        def __init__(self, out_channels, device) -> None:
            super().__init__()
            self.conv = LazyConv2d(out_channels=out_channels, kernel_size=2,
                                   stride=(2, 2), padding='valid', device=device)

        def forward(self, X):
            return torch.relu(self.conv(X))

    def __init__(self, device, filters=(16, 32, 64, 128)) -> None:
        super(convolutionalEncoder, self).__init__()
        # BUG FIX: ModuleList so the stages' parameters are registered
        # (a plain list hides them from model.parameters()).
        self.conv_block_list = torch.nn.ModuleList(
            self.conv_block(ch, device) for ch in filters)

    def forward(self, X):
        # BUG FIX: collect skips in a local list; the original appended to
        # an instance attribute that kept growing across forward calls.
        skip_conn = []
        for block in self.conv_block_list:
            X = block(X)
            skip_conn.append(X)
        return skip_conn
- # %% [markdown]
- # ## Dataset creation
- # Transforming the mask into a tensor with `num_classes` channels
- # %%
def boolMaskGen(mask):
    # One-hot encode an integer class mask along a new trailing axis:
    # (...) -> (..., num_classes).  `mask` must be an integer tensor.
    bool_tensor = F.one_hot(mask, num_classes= num_classes)
    return bool_tensor
# Batched version (maps over dim 0).  NOTE(review): F.one_hot already
# handles batched input directly, so the vmap wrapper is presumably
# redundant — verify before removing.
v_boolMaskGen = torch.vmap(boolMaskGen, in_dims= 0, out_dims= 0)
- # %%
- from torch.utils.data import Dataset
class imageDataset(Dataset):
    """Loads (image, one-hot mask) pairs from parallel directories.

    Assumes the mask file shares its name with the input image and that
    the class id is encoded in channel 0 of the mask image.
    Returns: image FloatTensor (H, W, 3), mask one-hot (H, W, num_classes).
    """
    def __init__(self, inp_dir, mask_dir) -> None:
        self.inp_dir = inp_dir
        self.mask_dir = mask_dir
        self.list_of_images = os.listdir(inp_dir)

    def __len__(self):
        return len(self.list_of_images)

    def __getitem__(self, index):
        name = self.list_of_images[index]
        img_path = os.path.join(self.inp_dir, name)
        image = cv2.imread(img_path)
        # BUG FIX: cv2.imread silently returns None on a missing/unreadable
        # file, which would surface later as a cryptic tensor error.
        if image is None:
            raise FileNotFoundError(f"could not read image: {img_path}")
        image = torch.FloatTensor(image)
        mask_path = os.path.join(self.mask_dir, name)
        mask = cv2.imread(mask_path)
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {mask_path}")
        # Class ids live in channel 0; one_hot requires int64 indices.
        mask = v_boolMaskGen(torch.tensor(mask[:, :, 0], dtype=torch.long))
        return image, mask
- # %% [markdown]
- # ## Building the model
- # %%
- from torch.nn import Softmax
class transUNet(torch.nn.Module):
    """TransUNet: parallel CNN encoder + ViT branch fused by the cascaded
    upsampler, ending in a per-pixel class softmax.

    Input : (B, in_channels, inp_H, inp_W)
    Output: (B, num_classes, inp_H, inp_W) class probabilities (softmax
            over dim 1).
    """
    def __init__(self, device) -> None:
        super(transUNet, self).__init__()
        self.conv_enc = convolutionalEncoder(device)
        self.patch_gen = patchGenerator(device)
        self.patch_enc = patchEncoder(device)
        self.trans_enc = transformerEncoder(device)
        self.cus = CUS(device)
        # Halves the 16x16 token grid to 8x8 before the upsampler.
        self.conv_after_transenc = LazyConv2d(filters[-1], 2, (2, 2), padding='valid', device=device)
        self.conv_prefinal = LazyConv2d(16, 3, (1, 1), padding='same', device=device)
        self.conv_final = LazyConv2d(num_classes, 3, (1, 1), padding='same', device=device)
        self.smax = Softmax(dim=1)
        self.device = device

    def forward(self, X):
        # CNN branch: multi-scale skip connections.
        skip_conn = self.conv_enc(X)
        # ViT branch.  (The original cloned+detached X here; X is input
        # data with no gradient, so that was a no-op for training.)
        tokens = self.patch_gen(X)
        tokens = self.patch_enc(tokens)
        tokens = self.trans_enc(tokens)
        # BUG FIX: (B, num_patches, hidden) must be transposed to
        # (B, hidden, num_patches) before reshaping into a feature map; a
        # bare reshape scrambles channels against spatial positions.  Also
        # use the actual batch size instead of the global constant.
        feat = tokens.transpose(1, 2).reshape(
            X.size(0), hidden_dim, inp_H // patch_size, inp_W // patch_size)
        feat = torch.relu(self.conv_after_transenc(feat))
        y_pred = self.cus(feat, skip_conn)
        y_pred = torch.relu(self.conv_prefinal(y_pred))
        y_pred = self.conv_final(y_pred)
        return self.smax(y_pred)
- # %% [markdown]
- # ## Dice loss
- # %%
def dice_loss(output, target, smooth=1e-7):
    """Soft Dice loss over all classes and pixels.

    output: (B, C, H, W) predicted class probabilities.
    target: (B, C, H, W) one-hot (or soft) ground truth; it is re-hardened
            via argmax -> one_hot so soft targets become strict one-hot.
    smooth: numerical-stability term (defaults to the module constant).
    Returns 1 - dice as a scalar tensor.
    """
    num_classes = target.shape[1]
    hard = F.one_hot(torch.argmax(target, dim=1), num_classes=num_classes)
    # BUG FIX: one_hot puts the class axis last; permute (not reshape,
    # which scrambles data) restores (B, C, H, W) alignment with `output`.
    target = hard.permute(0, 3, 1, 2).to(output.dtype)
    intersection = torch.sum(output * target) + smooth
    dice = (2 * intersection) / (torch.sum(output) + torch.sum(target) + smooth)
    return 1 - dice
- # %% [markdown]
- # ## Training the model
- # %%
- import torch.optim as optim
- from torch.utils.data import DataLoader
def model_train():
    """Train transUNet on the Cityscapes training split with Dice loss.

    Appends per-batch running loss and CUDA memory usage to the
    module-level `proc_state_snapshot` list for later CSV export.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = transUNet(device)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = dice_loss
    train_dataset = imageDataset(train_inp_dir, train_mask_dir)
    # drop_last: the loop hard-reshapes every batch to `batch_size`, which
    # would crash on a smaller final batch.
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=1, drop_last=True)
    for epoch in range(num_epochs):
        # NOTE(review): private API whose return value is discarded here —
        # presumably kept for its side effects; verify it is needed.
        torch.cuda.memory._snapshot(device=device)
        running_loss = 0.
        last_loss = 0.
        for i, (images, mask) in enumerate(train_dataloader):
            optimizer.zero_grad(set_to_none=True)
            mask = torch.reshape(mask, (batch_size, num_classes, inp_H, inp_W)).to(device=device)
            # Renamed from `input` (shadowed the builtin).
            images = torch.reshape(images, (batch_size, in_channels, inp_H, inp_W)).to(device)
            pred_mask = model(images)
            loss = loss_fn(pred_mask, mask)
            # BUG FIX: retain_graph=True kept every batch's autograd graph
            # alive, steadily growing CUDA memory (the growth this script
            # instruments); a plain backward() is all that is needed.
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            proc_state_snapshot.append({'epoch': epoch,
                                        'batch': i,
                                        'running_loss': running_loss/(i+1),
                                        'CUDA mem used': torch.cuda.memory_allocated(device) / (1024 ** 3)})
            if i % 50 == 49:
                last_loss = running_loss/50
                print(f"Epoch: {epoch} Step: {i+1} Loss: {last_loss:.5f} Memory used: {(torch.cuda.memory_allocated(device) / (1024 ** 3)): .2f}G")
                running_loss = 0.
        torch.cuda.empty_cache()
- # %%
- from datetime import datetime
- from csv import DictWriter
# BUG FIX: the original str(datetime.today())[:-10] produced a broken
# filename whenever microsecond == 0 (str() then has no ".%f" suffix to
# strip); strftime always yields "YYYY-MM-DD HH:MM".
os.makedirs('./outputs', exist_ok=True)  # avoid crashing on a missing dir
filename = './outputs/' + datetime.today().strftime('%Y-%m-%d %H:%M')
try:
    model_train()
finally:
    # Always persist the memory snapshot and per-batch stats, even when
    # training crashes partway through.
    torch.cuda.memory._dump_snapshot(filename + '.pickle')
    with open(filename + '.csv', 'w', newline='') as f:  # newline='' per csv docs
        fieldnames = ['epoch', 'batch', 'running_loss', 'CUDA mem used']
        csvwriter = DictWriter(f, fieldnames)
        csvwriter.writeheader()
        csvwriter.writerows(proc_state_snapshot)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement