Advertisement
Guest User

Untitled

a guest
Jun 16th, 2019
872
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 9.22 KB | None | 0 0
  1. import simulation
  2. import helper
  3. import copy
  4. import time
  5. from collections import defaultdict
  6.  
  7. from torch.utils.data import Dataset, DataLoader
  8. from torchvision import transforms, datasets, models
  9. import torch.nn.functional as F
  10. import torch.nn as nn
  11. import torch
  12. import torch.optim as optim
  13. from torch.optim import lr_scheduler
  14. import numpy as np
  15.  
  16. from convcrf import convcrf
  17.  
  18. import ipdb
  19. import matplotlib.pyplot as plt
  20.  
  21. input_images, target_masks = simulation.generate_random_data(320, 320, count=3)
  22.  
  23.  
  24. class SimDataset(Dataset):
  25. def __init__(self, count, transform=None):
  26. self.input_images, self.target_masks = simulation.generate_random_data(320, 320, count=count)
  27. self.transform = transform
  28.  
  29. def __len__(self):
  30. return len(self.input_images)
  31.  
  32. def __getitem__(self, idx):
  33. image = self.input_images[idx]
  34. mask = self.target_masks[idx]
  35. if self.transform:
  36. image = self.transform(image)
  37.  
  38. return [image, mask]
  39.  
  40.  
  41. trans = transforms.Compose([
  42. transforms.ToTensor(),
  43. ])
  44.  
  45. train_set = SimDataset(2000, transform = trans)
  46. val_set = SimDataset(200, transform = trans)
  47.  
  48. # train_set = SimDataset(4, transform=trans)
  49. # val_set = SimDataset(2, transform=trans)
  50.  
  51. image_datasets = {
  52. 'train': train_set, 'val': val_set
  53. }
  54.  
  55. batch_size = 25
  56. batch_size = 1
  57.  
  58. dataloaders = {
  59. 'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
  60. 'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
  61. }
  62.  
  63. dataset_sizes = {
  64. x: len(image_datasets[x]) for x in image_datasets.keys()
  65. }
  66.  
  67. # Generate some random images
  68. input_images, target_masks = simulation.generate_random_data(320, 320, count=3)
  69. # target_masks = target_masks[:, :2, :, :]
  70.  
  71. for x in [input_images, target_masks]:
  72. print(x.shape)
  73. print(x.min(), x.max())
  74.  
  75. # Change channel-order and make 3 channels for matplot
  76. input_images_rgb = [x.astype(np.uint8) for x in input_images]
  77.  
  78. # Map each channel (i.e. class) to each color
  79. target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]
  80.  
  81. # Left: Input image, Right: Target mask (Ground-truth)
  82. helper.plot_side_by_side([input_images_rgb, target_masks_rgb])
  83.  
  84.  
  85. def double_conv(in_channels, out_channels):
  86. return nn.Sequential(
  87. nn.Conv2d(in_channels, out_channels, 3, padding=1),
  88. nn.ReLU(inplace=True),
  89. nn.Conv2d(out_channels, out_channels, 3, padding=1),
  90. nn.ReLU(inplace=True)
  91. )
  92.  
  93.  
  94. class UNet(nn.Module):
  95.  
  96. def __init__(self, n_class):
  97. super().__init__()
  98.  
  99. self.dconv_down1 = double_conv(3, 64)
  100. self.dconv_down2 = double_conv(64, 128)
  101. self.dconv_down3 = double_conv(128, 256)
  102. self.dconv_down4 = double_conv(256, 512)
  103.  
  104. self.maxpool = nn.MaxPool2d(2)
  105. self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
  106.  
  107. self.dconv_up3 = double_conv(256 + 512, 256)
  108. self.dconv_up2 = double_conv(128 + 256, 128)
  109. self.dconv_up1 = double_conv(128 + 64, 64)
  110.  
  111. self.conv_last = nn.Conv2d(64, n_class, 1)
  112.  
  113. shape = (320, 320)
  114. config = convcrf.default_conf
  115. config['pyinn'] = False
  116. config['trainable'] = True
  117. config['trainable_bias'] = True
  118. self.convcrf = convcrf.GaussCRF(conf=config, shape=shape, nclasses=n_class)
  119.  
  120. self.postprocessing = False
  121.  
  122. def forward(self, x):
  123. x_origin = x
  124. conv1 = self.dconv_down1(x)
  125. x = self.maxpool(conv1)
  126.  
  127. conv2 = self.dconv_down2(x)
  128. x = self.maxpool(conv2)
  129.  
  130. conv3 = self.dconv_down3(x)
  131. x = self.maxpool(conv3)
  132.  
  133. x = self.dconv_down4(x)
  134.  
  135. x = self.upsample(x)
  136. x = torch.cat([x, conv3], dim=1)
  137.  
  138. x = self.dconv_up3(x)
  139. x = self.upsample(x)
  140. x = torch.cat([x, conv2], dim=1)
  141.  
  142. x = self.dconv_up2(x)
  143. x = self.upsample(x)
  144. x = torch.cat([x, conv1], dim=1)
  145.  
  146. x = self.dconv_up1(x)
  147.  
  148. out_x = self.conv_last(x)
  149.  
  150. if self.postprocessing:
  151. # out_x = torch.clamp(out_x, 0, 1)
  152. # ipdb.set_trace()
  153. out_x = self.convcrf(out_x, x_origin)
  154.  
  155. return out_x
  156.  
  157. def postprocessing_state(self, is_enamble=False):
  158. self.postprocessing = is_enamble
  159.  
  160.  
  161. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  162.  
  163. model = UNet(2)
  164. model = model.to(device)
  165.  
  166. model.eval()
  167.  
  168. for inputs, labels in dataloaders['train']:
  169. inputs = inputs.to(device)
  170. labels = labels.to(device)
  171.  
  172. with torch.no_grad():
  173. res = model(inputs)
  174. break
  175.  
  176.  
  177. def dice_loss(pred, target, smooth=1.):
  178. pred = pred.contiguous()
  179. target = target.contiguous()
  180.  
  181. intersection = (pred * target).sum(dim=2).sum(dim=2)
  182.  
  183. loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
  184.  
  185. return loss.mean()
  186.  
  187.  
  188. def calc_loss(pred, target, metrics, bce_weight=0.5):
  189. bce = F.binary_cross_entropy_with_logits(pred, target)
  190.  
  191. pred = torch.sigmoid(pred)
  192. dice = dice_loss(pred, target)
  193.  
  194. loss = bce * bce_weight + dice * (1 - bce_weight)
  195.  
  196. metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
  197. metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
  198. metrics['loss'] += loss.data.cpu().numpy() * target.size(0)
  199.  
  200. return loss
  201.  
  202.  
  203. def print_metrics(metrics, epoch_samples, phase):
  204. outputs = []
  205. for k in metrics.keys():
  206. outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples))
  207.  
  208. print("{}: {}".format(phase, ", ".join(outputs)))
  209.  
  210.  
  211. def train_model(model, optimizer, scheduler, num_epochs=25):
  212. best_model_wts = copy.deepcopy(model.state_dict())
  213. best_loss = 1e10
  214.  
  215. for epoch in range(num_epochs):
  216. print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  217. print('-' * 10)
  218.  
  219. since = time.time()
  220.  
  221. # Each epoch has a training and validation phase
  222. for phase in ['train', 'val']:
  223. if phase == 'train':
  224. scheduler.step()
  225. for param_group in optimizer.param_groups:
  226. print("LR", param_group['lr'])
  227.  
  228. model.train() # Set model to training mode
  229. else:
  230. model.eval() # Set model to evaluate mode
  231.  
  232. metrics = defaultdict(float)
  233. epoch_samples = 0
  234.  
  235. for inputs, labels in dataloaders[phase]:
  236. inputs = inputs.to(device)
  237. labels = labels.to(device)
  238.  
  239. # zero the parameter gradients
  240. optimizer.zero_grad()
  241.  
  242. # forward
  243. # track history if only in train
  244. with torch.set_grad_enabled(phase == 'train'):
  245. outputs = model(inputs)
  246. loss = calc_loss(outputs, labels, metrics)
  247.  
  248. # backward + optimize only if in training phase
  249. if phase == 'train':
  250. loss.backward()
  251. optimizer.step()
  252.  
  253. # statistics
  254. epoch_samples += inputs.size(0)
  255.  
  256. print_metrics(metrics, epoch_samples, phase)
  257. epoch_loss = metrics['loss'] / epoch_samples
  258.  
  259. # deep copy the model
  260. if phase == 'val' and epoch_loss < best_loss:
  261. print("saving best model")
  262. best_loss = epoch_loss
  263. best_model_wts = copy.deepcopy(model.state_dict())
  264.  
  265. time_elapsed = time.time() - since
  266. print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
  267. print('Best val loss: {:4f}'.format(best_loss))
  268.  
  269. # load best model weights
  270. model.load_state_dict(best_model_wts)
  271. return model
  272.  
  273.  
  274. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  275. print(device)
  276.  
  277. num_class = 2
  278.  
  279. model = UNet(num_class).to(device)
  280.  
  281. # Observe that all parameters are being optimized
  282. optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
  283.  
  284. exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)
  285.  
  286. model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=2)
  287.  
  288. model.eval()
  289.  
  290. for inputs, labels in dataloaders['train']:
  291. inputs = inputs.to(device)
  292. labels = labels.to(device)
  293.  
  294. with torch.no_grad():
  295. res = model(inputs)
  296. break
  297.  
  298. print('TRAININ WITH CONVCRF')
  299.  
  300. model.postprocessing_state(True)
  301.  
  302. model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=2)
  303.  
  304. torch.save(model.state_dict(), 'model_artif.torch')
  305.  
  306. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  307. print(device)
  308.  
  309. num_class = 2
  310. model = UNet(num_class).to(device)
  311. model.load_state_dict(torch.load('model_artif.torch'))
  312. model.eval()
  313.  
  314.  
  315. def reverse_transform(inp):
  316. inp = inp.numpy().transpose((1, 2, 0))
  317. inp = np.clip(inp, 0, 1)
  318. inp = (inp * 255).astype(np.uint8)
  319.  
  320. return inp
  321.  
  322.  
  323. model.eval() # Set model to evaluate mode
  324.  
  325. test_dataset = SimDataset(3, transform=trans)
  326. test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)
  327.  
  328. inputs, labels = next(iter(test_loader))
  329. inputs = inputs.to(device)
  330. labels = labels.to(device)
  331.  
  332. pred = model(inputs)
  333.  
  334. pred = pred.data.cpu().numpy()
  335. print(pred.shape)
  336.  
  337. input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]
  338.  
  339. target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]
  340. pred_rgb = [helper.masks_to_colorimg(x) for x in pred]
  341.  
  342. helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement