Advertisement
Guest User

TransUNet

a guest
Feb 8th, 2024
232
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 11.62 KB | Source Code | 0 0
  1. # %%
  2. import torch
  3. import os
  4. import numpy as np
  5. import cv2
  6. import torch.nn.functional as F
  7. from torch.nn import Linear, LazyConv2d, ReLU
  8.  
relu = ReLU()  # shared stateless activation, reused by several modules below

# %%
# Dataset directory layout: paired input images and segmentation masks,
# matched by file name.
train_inp_dir = "./cityscapes_data/new_train/input/"
train_mask_dir = "./cityscapes_data/new_train/masks/"

test_inp_dir = "./cityscapes_data/new_test/input/"
test_mask_dir = "./cityscapes_data/new_test/masks/"

# Per-batch training diagnostics, written to CSV at exit.
# (idx, epoch, batch, running loss, memory used [CUDA])
proc_state_snapshot = []
# Start recording CUDA allocation history for the snapshot dumped at exit.
torch.cuda.memory._record_memory_history()
  21.  
  22. # %% [markdown]
  23. # ## Hyperparams.
  24.  
  25. # %%
patch_size = 16            # side length (px) of each square ViT patch
inp_H, inp_W = (256, 256)  # model input resolution
num_classes = 8            # number of segmentation classes
in_channels = 3            # RGB input

num_patches = (inp_H * inp_W)//(patch_size ** 2)  
batch_size = 32

filters = [16, 32, 64, 128]  # channels per CNN encoder stage (one skip each)
num_layers = 12              # transformer encoder depth
num_skip_conn = len(filters)
drop_rate = 0.0              # attention dropout
hidden_dim = filters[0]      # transformer embedding width
num_heads = hidden_dim // 2  # attention heads; must divide hidden_dim

learning_rate = 1e-2
num_epochs = 15

smooth = 1e-7                # Dice-loss smoothing term
  45.  
  46. # %% [markdown]
  47. # ## Patch generator
  48. #
  49. # - `in_dims` : `(batch_size, in_channels, inp_H, inp_W)`
  50. # - `out_dims` : `(batch_size, patch_size_H * patch_size_W * in_channels, num_patches)`
  51.  
  52. # %%
  53. from torch.nn import Unfold
  54.  
  55. class patchGenerator(torch.nn.Module):
  56.    
  57.     def __init__(self, device) -> None:
  58.         super(patchGenerator, self).__init__()
  59.         self.unfold = Unfold(kernel_size= (patch_size, patch_size), dilation= (1, 1),
  60.                              padding= 0, stride= (patch_size, patch_size))
  61.  
  62.     def forward(self, X):
  63.         X = torch.reshape(X, (batch_size, in_channels, inp_H, inp_W))
  64.         X = self.unfold(X)
  65.         patch_dim_final = X.shape[-1]
  66.         X = torch.reshape(X, (batch_size, -1, patch_dim_final))
  67.         X = torch.reshape(X, (batch_size, X.shape[2], X.shape[1]))
  68.  
  69.         return X  
  70.  
  71. # %% [markdown]
  72. # ## Patch encoder
  73. # - `in_dims` : `(batch_size, patch_size_H * patch_size_W * in_channels, num_patches)`
  74. # - `out_dims` : `(batch_size, hidden_dim, num_patches)`
  75.  
  76. # %%
  77. from torch.nn import Embedding
  78.  
  79. class patchEncoder(torch.nn.Module):
  80.    
  81.     def __init__(self, device) -> None:
  82.         super(patchEncoder, self).__init__()
  83.         self.projLayer = Linear(in_features= patch_size * patch_size * in_channels,
  84.                                 out_features= hidden_dim, device= device)
  85.         self.pos_emb = Embedding(num_embeddings= num_patches, embedding_dim= hidden_dim, device= device)
  86.         self.device = device
  87.  
  88.     def forward(self, X):
  89.         positions = torch.arange(start= 0, end= num_patches).to(self.device)
  90.         X_enc = self.projLayer(X) + self.pos_emb(positions)
  91.  
  92.         return X_enc
  93.  
  94. # %% [markdown]
  95. # ## Transformer Encoder
  96.  
  97. # %%
  98. from torch.nn import LayerNorm, MultiheadAttention, Linear
  99.  
  100. class transformerEncoder(torch.nn.Module):
  101.  
  102.     def __init__(self, device) -> None:
  103.         super(transformerEncoder, self).__init__()
  104.        
  105.         self.mha = MultiheadAttention(num_heads= num_heads, dropout= drop_rate, embed_dim= hidden_dim, device= device)
  106.         self.num_layers = num_layers
  107.         self.lnorm = LayerNorm(hidden_dim, device= device)
  108.         self.dense1 = Linear(hidden_dim, hidden_dim * 2, device= device)
  109.         self.dense2 = Linear(hidden_dim * 2, hidden_dim, device= device)
  110.  
  111.     def forward(self, z):
  112.  
  113.         for _ in range(self.num_layers):
  114.            
  115.             lnorm_op = self.lnorm(z)
  116.             msa_op = self.mha(lnorm_op, lnorm_op, lnorm_op)[0]
  117.             msa_op = torch.add(input= msa_op, other= lnorm_op)
  118.  
  119.             lnorm_op = self.lnorm(msa_op)
  120.             mlp_op = self.dense1(lnorm_op)
  121.             mlp_op = self.dense2(mlp_op)
  122.  
  123.             z = torch.add(input= msa_op, other= mlp_op)
  124.        
  125.         return z
  126.  
  127.  
  128. # %% [markdown]
# ## Cascaded UpSampler
  130. # Here, the entire upsampling operation is split into multiple blocks. Each block performs upsampling, concat, convolution (in that order). Achieved by using nested classes where each instance of the class `CUS_block` represents an upsampling block.
  131.  
  132. # %%
  133. from torch.nn import Upsample
  134.  
  135. class CUS(torch.nn.Module):
  136.  
  137.     class CUS_block(torch.nn.Module):
  138.         def __init__(self, out_channels, mode, device) -> None:
  139.             super().__init__()
  140.            
  141.             self.conv = LazyConv2d(out_channels= out_channels, kernel_size= 3,
  142.                                    stride= (1, 1), padding= 'same', device= device)
  143.             self.upsamp = Upsample(scale_factor= 2, mode= mode)
  144.            
  145.  
  146.         def forward(self, X, conv_op):
  147.             global relu
  148.  
  149.             X = self.upsamp(X)
  150.             X = torch.concat([X, conv_op], dim= 1)
  151.             X = self.conv(X)
  152.             X = relu(X)
  153.  
  154.             return X
  155.  
  156.     def __init__(self, device) -> None:
  157.         super(CUS, self).__init__()
  158.        
  159.         self.CUS_block_list = []
  160.         self.upsamp = Upsample(scale_factor= 2, mode= 'nearest')
  161.        
  162.         for i in range(num_skip_conn):
  163.             self.CUS_block_list.append(self.CUS_block(filters[num_skip_conn - i - 1], 'nearest', device))
  164.  
  165.     def forward(self, X, skip_conn):
  166.        
  167.         for i in range(num_skip_conn):
  168.             X = self.CUS_block_list[i](X, skip_conn[num_skip_conn - i -1])
  169.  
  170.         X = self.upsamp(X)
  171.         return X
  172.  
  173. # %% [markdown]
  174. # ## Convolutional encoder
  175.  
  176. # %%
  177. class convolutionalEncoder(torch.nn.Module):
  178.    
  179.     class conv_block(torch.nn.Module):
  180.  
  181.         def __init__(self, out_channels, device) -> None:
  182.             super().__init__()
  183.             self.conv = LazyConv2d(out_channels= out_channels, kernel_size= 2,
  184.                                    stride= (2, 2), padding= 'valid', device= device)
  185.            
  186.  
  187.         def forward(self, X):
  188.             global relu
  189.  
  190.             X = self.conv(X)
  191.             X = relu(X)
  192.  
  193.             return X
  194.  
  195.     def __init__(self, device) -> None:
  196.         super(convolutionalEncoder, self).__init__()
  197.         self.conv_block_list = []
  198.         self.skip_conn = []
  199.         for i in range(num_skip_conn):
  200.             self.conv_block_list.append(self.conv_block(filters[i], device))
  201.  
  202.     def forward(self, X):
  203.  
  204.         for i in range(num_skip_conn):
  205.             X = self.conv_block_list[i](X)
  206.             self.skip_conn.append(X)
  207.  
  208.         return self.skip_conn
  209.  
  210. # %% [markdown]
  211. # ## Dataset creation
  212. # Transforming the mask into a tensor with `num_classes` channels
  213.  
  214. # %%
  215. def boolMaskGen(mask):
  216.    
  217.     bool_tensor = F.one_hot(mask, num_classes= num_classes)
  218.     return bool_tensor
  219.  
  220. v_boolMaskGen = torch.vmap(boolMaskGen, in_dims= 0, out_dims= 0)
  221.  
  222. # %%
  223. from torch.utils.data import Dataset
  224.  
  225. class imageDataset(Dataset):
  226.  
  227.     def __init__(self, inp_dir, mask_dir) -> None:
  228.         self.inp_dir = inp_dir
  229.         self.mask_dir = mask_dir
  230.         self.list_of_images = os.listdir(inp_dir)
  231.  
  232.     def __len__(self):
  233.         return len(self.list_of_images)
  234.    
  235.     def __getitem__(self, index):
  236.         image = cv2.imread(self.inp_dir + self.list_of_images[index])
  237.         image = torch.FloatTensor(image)
  238.  
  239.         mask = cv2.imread(self.mask_dir + self.list_of_images[index])[:,:,0]
  240.         mask = v_boolMaskGen(torch.tensor(mask, dtype= int))
  241.  
  242.         return image, mask
  243.  
  244. # %% [markdown]
  245. # ## Building the model
  246.  
  247. # %%
  248. from torch.nn import Softmax
  249.  
  250. class transUNet(torch.nn.Module):
  251.  
  252.     def __init__(self, device) -> None:
  253.         super(transUNet, self).__init__()
  254.        
  255.         self.conv_enc = convolutionalEncoder(device)
  256.         self.patch_gen = patchGenerator(device)
  257.         self.patch_enc = patchEncoder(device)
  258.         self.trans_enc = transformerEncoder(device)
  259.         self.cus = CUS(device)
  260.        
  261.         self.conv_after_transenc = LazyConv2d(filters[-1], 2, (2,2), padding= 'valid', device= device)
  262.         self.conv_prefinal = LazyConv2d(16, 3, (1, 1), padding= 'same', device= device)
  263.         self.conv_final = LazyConv2d(num_classes, 3, (1, 1), padding= 'same', device= device)
  264.        
  265.         self.smax = Softmax(dim= 1)
  266.  
  267.         self.device = device
  268.  
  269.     def forward(self, X):
  270.         global relu
  271.  
  272.         X_copy = X.clone().detach()
  273.         skip_conn = self.conv_enc(X)
  274.                
  275.         X_copy = self.patch_gen(X_copy)
  276.         X_copy = self.patch_enc(X_copy)
  277.         X_copy = self.trans_enc(X_copy)
  278.  
  279.         X_copy = torch.reshape(X_copy, (batch_size, hidden_dim, inp_H//patch_size, inp_W//patch_size))
  280.         X_copy = self.conv_after_transenc(X_copy)
  281.         X_copy = relu(X_copy)
  282.  
  283.         y_pred = self.cus(X_copy, skip_conn)
  284.        
  285.         y_pred = self.conv_prefinal(y_pred)
  286.         y_pred = relu(y_pred)
  287.  
  288.         y_pred = self.conv_final(y_pred)
  289.         y_pred = self.smax(y_pred)
  290.  
  291.         return y_pred
  292.    
  293.  
  294. # %% [markdown]
  295. # ## Dice loss  
  296.  
  297. # %%
  298. def dice_loss(output, target):
  299.     target = torch.reshape(v_boolMaskGen(torch.argmax(target, dim= 1)), (batch_size, num_classes, inp_H, inp_W))
  300.  
  301.     target = torch.flatten(target)
  302.     output = torch.flatten(output)
  303.  
  304.     intersection = torch.sum(output * target) + smooth
  305.     dice = (2 * intersection)/(torch.sum(output) + torch.sum(target) + smooth)
  306.  
  307.     return 1 - dice
  308.  
  309. # %% [markdown]
  310. # ## Training the model
  311.  
  312. # %%
  313. import torch.optim as optim
  314. from torch.utils.data import DataLoader
  315.  
  316. def model_train():
  317.     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  318.     model = transUNet(device)
  319.     model.to(device)
  320.  
  321.     optimizer = optim.Adam(model.parameters(), lr= learning_rate)
  322.     loss_fn = dice_loss
  323.  
  324.     train_dataset = imageDataset(train_inp_dir, train_mask_dir)
  325.     train_dataloader = DataLoader(dataset= train_dataset, batch_size= batch_size, shuffle= True, num_workers= 1)
  326.  
  327.     # test_dataset = imageDataset(test_inp_dir, test_mask_dir)
  328.     # test_dataloader = DataLoader(dataset= test_dataset, batch_size= batch_size, shuffle= True, num_workers= 4)
  329.  
  330.     for epoch in range(num_epochs):
  331.        
  332.         torch.cuda.memory._snapshot(device= device)
  333.  
  334.         running_loss = 0.
  335.         last_loss = 0.
  336.        
  337.         for i, (input, mask) in enumerate(train_dataloader):
  338.  
  339.             optimizer.zero_grad(set_to_none= True)
  340.             mask = torch.reshape(mask, (batch_size, num_classes, inp_H, inp_W)).to(device= device)
  341.             input = torch.reshape(input, (batch_size, in_channels, inp_H, inp_W)).to(device)
  342.        
  343.             pred_mask = model(input)
  344.        
  345.             loss = loss_fn(pred_mask, mask)
  346.             loss.backward(retain_graph= True)
  347.  
  348.             optimizer.step()
  349.  
  350.             running_loss += loss.item()
  351.             # (epoch, i, running_loss, (torch.cuda.memory_allocated(device) / (1024 ** 3)))
  352.             proc_state_snapshot.append({'epoch': epoch,
  353.                                         'batch': i,
  354.                                         'running_loss': running_loss/(i+1),
  355.                                         'CUDA mem used': torch.cuda.memory_allocated(device) / (1024 ** 3)})
  356.            
  357.             if i % 50 == 49:
  358.                 last_loss = running_loss/50
  359.                 print(f"Epoch: {epoch} Step: {i+1} Loss: {last_loss:.5f} Memory used: {(torch.cuda.memory_allocated(device) / (1024 ** 3)): .2f}G")
  360.                 running_loss = 0.
  361.  
  362.             torch.cuda.empty_cache()
  363.             # with torch.no_grad():
  364.             #     torch.cuda.empty_cache()    
  365.  
  366. # %%
  367. from datetime import datetime
  368. from csv import DictWriter
  369.  
  370. filename = './outputs/' + str(datetime.today())[:-10]
  371.  
  372. try:
  373.     model_train()
  374. finally:
  375.     torch.cuda.memory._dump_snapshot(filename + '.pickle')
  376.     with open(filename + '.csv', 'w') as f:
  377.         fieldnames = ['epoch', 'batch', 'running_loss', 'CUDA mem used']
  378.         csvwriter = DictWriter(f, fieldnames)
  379.        
  380.         csvwriter.writeheader()
  381.         csvwriter.writerows(proc_state_snapshot)
  382.  
  383.  
  384.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement