Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import simulation
- import helper
- import copy
- import time
- from collections import defaultdict
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms, datasets, models
- import torch.nn.functional as F
- import torch.nn as nn
- import torch
- import torch.optim as optim
- from torch.optim import lr_scheduler
- import numpy as np
- from convcrf import convcrf
- import ipdb
- import matplotlib.pyplot as plt
# NOTE(review): this result is never used — input_images/target_masks are
# regenerated below (before any read) for the visualization step; presumably
# left over from experimentation. Kept for byte-compatibility.
input_images, target_masks = simulation.generate_random_data(320, 320, count=3)
class SimDataset(Dataset):
    """Synthetic segmentation dataset backed by simulation.generate_random_data.

    Each item is ``[image, mask]``; only the image goes through ``transform``
    (the mask is returned untouched).
    """

    def __init__(self, count, transform=None):
        images, masks = simulation.generate_random_data(320, 320, count=count)
        self.input_images = images
        self.target_masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        image, mask = self.input_images[idx], self.target_masks[idx]
        if self.transform:
            image = self.transform(image)
        return [image, mask]
# Transform: HWC ndarray -> CHW float tensor scaled to [0, 1].
trans = transforms.Compose([
    transforms.ToTensor(),
])

train_set = SimDataset(2000, transform=trans)
val_set = SimDataset(200, transform=trans)
# Tiny sets for quick smoke tests:
# train_set = SimDataset(4, transform=trans)
# val_set = SimDataset(2, transform=trans)

image_datasets = {
    'train': train_set, 'val': val_set
}

# The original assigned batch_size = 25 and immediately overwrote it with 1
# (dead assignment removed). NOTE(review): batch size 1 is presumably for the
# ConvCRF memory footprint — confirm before raising it.
batch_size = 1

dataloaders = {
    'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0),
    'val': DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
}

dataset_sizes = {
    name: len(ds) for name, ds in image_datasets.items()
}
# Generate some random images (fresh batch of 3 just for visualization)
input_images, target_masks = simulation.generate_random_data(320, 320, count=3)
# target_masks = target_masks[:, :2, :, :]

# Print shape and value range of both arrays as a sanity check.
for x in [input_images, target_masks]:
    print(x.shape)
    print(x.min(), x.max())

# Change channel-order and make 3 channels for matplot
# NOTE(review): no transpose actually happens here, only a uint8 cast —
# presumably the simulated images are already HWC; confirm against
# simulation.generate_random_data.
input_images_rgb = [x.astype(np.uint8) for x in input_images]

# Map each channel (i.e. class) to each color
target_masks_rgb = [helper.masks_to_colorimg(x) for x in target_masks]

# Left: Input image, Right: Target mask (Ground-truth)
helper.plot_side_by_side([input_images_rgb, target_masks_rgb])
def double_conv(in_channels, out_channels):
    """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
class UNet(nn.Module):
    """U-Net with an optional trainable ConvCRF post-processing head.

    Encoder: four double-conv stages separated by 2x max-pooling.
    Decoder: bilinear upsampling with skip connections from the encoder.
    ``conv_last`` maps the final features to ``n_class`` channel logits.
    When ``self.postprocessing`` is True, a GaussCRF refines the logits
    conditioned on the original input image.
    """

    def __init__(self, n_class):
        super().__init__()

        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, n_class, 1)

        shape = (320, 320)  # must match the SimDataset image size
        # BUG FIX: the original mutated convcrf.default_conf in place, so the
        # trainable/pyinn flags leaked into every other user of the convcrf
        # module. Deep-copy the default config before customizing it.
        config = copy.deepcopy(convcrf.default_conf)
        config['pyinn'] = False
        config['trainable'] = True
        config['trainable_bias'] = True
        self.convcrf = convcrf.GaussCRF(conf=config, shape=shape, nclasses=n_class)
        # CRF refinement is off by default; toggled via postprocessing_state().
        self.postprocessing = False

    def forward(self, x):
        # Keep the raw input: the CRF needs the original image as guidance.
        x_origin = x

        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)

        x = self.dconv_down4(x)

        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)

        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)

        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        out_x = self.conv_last(x)

        if self.postprocessing:
            # out_x = torch.clamp(out_x, 0, 1)
            out_x = self.convcrf(out_x, x_origin)
        return out_x

    # NOTE(review): 'is_enamble' is a typo for 'is_enable'; the parameter name
    # is kept so any keyword callers stay compatible.
    def postprocessing_state(self, is_enamble=False):
        """Enable/disable the ConvCRF refinement applied in forward()."""
        self.postprocessing = is_enamble
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet(2).to(device)
model.eval()

# Sanity check: one gradient-free forward pass of the untrained model
# on a single training batch.
inputs, labels = next(iter(dataloaders['train']))
inputs = inputs.to(device)
labels = labels.to(device)
with torch.no_grad():
    res = model(inputs)
def dice_loss(pred, target, smooth=1.):
    """Soft Dice loss, averaged over the batch and channel dimensions.

    ``smooth`` avoids division by zero and softens the loss for empty masks.
    """
    pred = pred.contiguous()
    target = target.contiguous()

    # Reduce over the spatial dimensions (H, W) only.
    intersection = (pred * target).sum(dim=(2, 3))
    denom = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))

    loss = 1 - (2. * intersection + smooth) / (denom + smooth)
    return loss.mean()
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    Accumulates sample-weighted running sums for 'bce', 'dice' and 'loss'
    into ``metrics`` (a defaultdict(float)) so callers can report per-epoch
    averages. Returns the differentiable combined loss.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)

    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # .item() detaches and converts to a Python float; the original
    # `.data.cpu().numpy()` pattern bypasses autograd's safety checks and is
    # discouraged in modern PyTorch.
    n = target.size(0)
    metrics['bce'] += bce.item() * n
    metrics['dice'] += dice.item() * n
    metrics['loss'] += loss.item() * n

    return loss
def print_metrics(metrics, epoch_samples, phase):
    """Print the per-sample average of each accumulated metric for `phase`."""
    outputs = ["{}: {:4f}".format(name, total / epoch_samples)
               for name, total in metrics.items()]
    print("{}: {}".format(phase, ", ".join(outputs)))
def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train `model` over the global `dataloaders`, keeping the best val weights.

    For each epoch runs a 'train' and a 'val' phase, optimizing with
    `calc_loss` during training. The weights with the lowest validation loss
    seen so far are deep-copied and restored into the model before returning.

    Returns the model loaded with the best-validation-loss weights.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            if phase == 'train':
                # BUG FIX: since PyTorch 1.1, scheduler.step() must be called
                # AFTER the epoch's optimizer updates. The original stepped it
                # before training, skipping the first LR value.
                scheduler.step()

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model when validation improves
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 2
model = UNet(num_class).to(device)

# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)

# Stage 1: train the plain U-Net (ConvCRF post-processing disabled).
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=2)

# Gradient-free sanity forward pass on one training batch.
model.eval()
for inputs, labels in dataloaders['train']:
    inputs = inputs.to(device)
    labels = labels.to(device)
    with torch.no_grad():
        res = model(inputs)
    break

# Stage 2: fine-tune with the trainable ConvCRF head enabled.
print('TRAINING WITH CONVCRF')  # fixed typo: original printed 'TRAININ'
model.postprocessing_state(True)
model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=2)

torch.save(model.state_dict(), 'model_artif.torch')
# Reload the saved weights into a fresh model for evaluation.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_class = 2
model = UNet(num_class).to(device)
# Load the checkpoint written by the training run above.
model.load_state_dict(torch.load('model_artif.torch'))
model.eval()
def reverse_transform(inp):
    """Convert a CHW float tensor back to an HWC uint8 ndarray for plotting."""
    arr = inp.numpy().transpose((1, 2, 0))  # CHW -> HWC
    arr = np.clip(arr, 0, 1)
    return (arr * 255).astype(np.uint8)
model.eval()  # Set model to evaluate mode

test_dataset = SimDataset(3, transform=trans)
test_loader = DataLoader(test_dataset, batch_size=3, shuffle=False, num_workers=0)

inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.to(device)

# BUG FIX: the original ran the forward pass with autograd enabled, building
# an unused computation graph. Inference only needs no_grad.
with torch.no_grad():
    pred = model(inputs)
pred = pred.cpu().numpy()  # .data access replaced with plain .cpu()
print(pred.shape)

input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]
target_masks_rgb = [helper.masks_to_colorimg(x) for x in labels.cpu().numpy()]
pred_rgb = [helper.masks_to_colorimg(x) for x in pred]

# Left: input image; middle: ground-truth mask; right: prediction.
helper.plot_side_by_side([input_images_rgb, target_masks_rgb, pred_rgb])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement