Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# TensorBoard logger for this run; TensorBoardLogger writes its event files
# under <save_dir>/<name>/version_<n>/.
# NOTE(review): the run name mentions "graphs" but log_graph is not enabled
# here (it stays at TensorBoardLogger's default) — confirm intent.
logger = TensorBoardLogger(
    save_dir='./tb_logs/segmentation',
    name='segnet small v2 (--) 53k with graphs',
)
def bce_loss(y_pred, y_true):
    """Return the element-wise binary cross-entropy computed from logits.

    Uses the numerically stable rearrangement
    ``max(x, 0) - x * y + log(1 + exp(-|x|))``. With this form ``exp`` is
    only ever evaluated on a non-positive argument and cannot overflow.

    Args:
        y_pred: Raw (pre-sigmoid) model outputs, i.e. logits.
        y_true: Targets broadcastable against ``y_pred``; presumably 0/1
            mask values — confirm against the data pipeline.

    Returns:
        A tensor of per-element losses with the broadcast shape of the
        inputs. No reduction is applied; the caller decides how to reduce.
    """
    # log1p instead of log(1 + ...): when exp(-|x|) is below float eps,
    # 1 + exp(-|x|) rounds to exactly 1.0 and the log term (the whole loss
    # for a confident, correct prediction) would collapse to 0.
    return (
        torch.clamp(y_pred, min=0)
        - y_pred * y_true
        + torch.log1p(torch.exp(-torch.abs(y_pred)))
    )
class SegNet(L.LightningModule):
    """LightningModule wrapping ``SegNet_small`` with logit-space BCE training.

    Predictions are obtained by thresholding the model's logits at 0. A
    two-class Jaccard index (IoU) is logged once per epoch for the training
    split and the validation split.
    """

    def __init__(self, *args, **kwargs):
        # *args / **kwargs are accepted for call-site compatibility but unused.
        super().__init__()
        self.model = SegNet_small()
        self.loss_fn = bce_loss
        # One metric per stage: a torchmetrics Metric accumulates state across
        # update() calls, so train and val must not share an instance.
        # `IoU` keeps its original attribute name and serves the training split.
        self.IoU = torchmetrics.classification.MulticlassJaccardIndex(num_classes=2)
        self.val_IoU = torchmetrics.classification.MulticlassJaccardIndex(num_classes=2)

    def _shared_step(self, batch, stage, metric):
        """Forward one batch, log ``<stage>_loss``/``<stage>_IoU``, return (X, y, y_pred, loss)."""
        X, y = batch
        outputs = self.model(X)
        loss = self.loss_fn(outputs, y).mean()
        # A logit >= 0 corresponds to sigmoid probability >= 0.5 -> class 1.
        y_pred = torch.where(outputs < 0, 0, 1)
        # MulticlassJaccardIndex documents integer targets; .long() is a no-op
        # if y already is int64. Assumes y holds exact 0/1 values — TODO confirm.
        metric.update(y_pred, y.long())
        # Logging the Metric object (not a per-batch value) lets Lightning call
        # compute() over the whole epoch and reset the metric afterwards.
        self.log_dict(
            {f'{stage}_loss': loss, f'{stage}_IoU': metric},
            prog_bar=True,
            on_step=False,
            on_epoch=True,
        )
        return X, y, y_pred, loss

    def training_step(self, batch, batch_ind):
        """Log training loss/IoU for one batch and return the mean loss."""
        _, _, _, loss = self._shared_step(batch, 'train', self.IoU)
        return loss

    def validation_step(self, batch, batch_ind):
        """Log validation loss/IoU for one batch and return the mean loss."""
        X, y, y_pred, loss = self._shared_step(batch, 'val', self.val_IoU)
        # Only every 100th validation batch is rendered to TensorBoard.
        if batch_ind % 100 == 0:
            self.visualise_metrics(X, y, y_pred)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def visualise_metrics(self, X, y, y_pred, max_images=6):
        """Render inputs, predictions and targets as one image and log it.

        Row 1 shows the input images, row 2 the thresholded predictions and
        row 3 the ground-truth masks, for up to ``max_images`` samples. The
        figure is rasterised to PNG and written with ``add_image`` at the
        current ``global_step``.

        Assumes X is (B, C, H, W) and y / y_pred are (B, 1, H, W) — TODO
        confirm against the data pipeline.
        """
        import io

        from PIL import Image
        from torchvision.transforms import ToTensor

        if self.logger is None:
            return  # nothing to write to (e.g. Trainer(logger=False))

        X, y, y_pred = X.detach().cpu(), y.detach().cpu(), y_pred.detach().cpu()
        # The last batch of an epoch may be smaller than max_images.
        n_images = min(max_images, X.shape[0])
        if n_images == 0:
            return

        # Draw on a dedicated figure. Plotting through plt.subplot on pyplot's
        # implicit current figure would stack new images onto the same axes on
        # every call and never release the figure.
        fig = plt.figure()
        try:
            for k in range(n_images):
                panels = (
                    (np.rollaxis(X[k].numpy(), 0, 3), 'Real'),
                    (y_pred[k, 0].numpy(), 'Output'),
                    (y[k, 0].numpy(), 'Ground truth'),
                )
                for row, (img, title) in enumerate(panels):
                    ax = fig.add_subplot(3, n_images, row * n_images + k + 1)
                    ax.imshow(img, cmap='gray')
                    ax.set_title(title)
                    ax.axis('off')
            with io.BytesIO() as buffer:
                fig.savefig(buffer, format='png')
                buffer.seek(0)
                with Image.open(buffer) as image:
                    image_tensor = ToTensor()(image)
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)

        self.logger.experiment.add_image("Predictions at val_loader", image_tensor, self.global_step)
# Train for up to 100 epochs, writing metrics at every step to the
# TensorBoard logger defined at the top of the script.
trainer = L.Trainer(
    max_epochs=100,
    logger=logger,
    log_every_n_steps=1,
)
trainer.fit(
    SegNet(),
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement