Advertisement
Guest User

Untitled

a guest
Apr 18th, 2024
16
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.66 KB | None | 0 0
  1. logger = TensorBoardLogger('./tb_logs/segmentation', name='segnet small v2 (--) 53k with graphs')
  2.  
  3. def bce_loss(y_pred, y_true):
  4. return torch.where(y_pred < 0, 0, y_pred) - y_pred * y_true + torch.log(1 + torch.exp(-torch.abs(y_pred)))
  5.  
  6. class SegNet(L.LightningModule):
  7. def __init__(self, *args, **kwargs):
  8. super().__init__()
  9. self.model = SegNet_small()
  10. self.loss_fn = bce_loss
  11. self.IoU = torchmetrics.classification.MulticlassJaccardIndex(num_classes=2)
  12.  
  13. def training_step(self, batch, batch_ind):
  14. X, y = batch
  15. outputs = self.model(X)
  16. loss = self.loss_fn(outputs, y)
  17.  
  18. y_pred = torch.where(outputs < 0, 0, 1)
  19. self.log_dict({'train_loss': loss.mean(), 'train_IoU': self.IoU(y_pred, y)}, prog_bar=True, on_step=False, on_epoch=True)
  20.  
  21. return loss.mean()
  22.  
  23. def validation_step(self, batch, batch_ind):
  24. X, y = batch
  25. outputs = self.model(X)
  26. loss = self.loss_fn(outputs, y)
  27.  
  28. y_pred = torch.where(outputs < 0, 0, 1)
  29. self.log_dict({'val_loss': loss.mean(), 'val_IoU': self.IoU(y_pred, y)}, prog_bar=True, on_step=False, on_epoch=True)
  30.  
  31. if batch_ind % 100 == 0:
  32. self.visualise_metrics(X, y, y_pred)
  33.  
  34. return loss.mean()
  35.  
  36. def configure_optimizers(self):
  37. return torch.optim.Adam(self.parameters(), lr=1e-3)
  38.  
  39. def visualise_metrics(self, X, y, y_pred):
  40. # Visualize tools
  41. # clear_output(wait=True)
  42. X, y, y_pred = X.cpu(), y.cpu(), y_pred.cpu()
  43. for k in range(6):
  44. plt.subplot(3, 6, k+1)
  45. plt.imshow(np.rollaxis(X[k].numpy(), 0, 3), cmap='gray')
  46. plt.title('Real')
  47. plt.axis('off')
  48.  
  49. plt.subplot(3, 6, k+7)
  50. plt.imshow(y_pred[k, 0], cmap='gray')
  51. plt.title('Output')
  52. plt.axis('off')
  53.  
  54. plt.subplot(3, 6, k+13)
  55. plt.imshow(y[k, 0], cmap='gray')
  56. plt.title('Output')
  57. plt.axis('off')
  58.  
  59. import io
  60. from torchvision.transforms import ToTensor
  61.  
  62. buffer = io.BytesIO()
  63. plt.savefig(buffer, format='png')
  64. buffer.seek(0)
  65. from PIL import Image
  66. image = Image.open(buffer)
  67. image_tensor = ToTensor()(image)
  68.  
  69. self.logger.experiment.add_image("Predictions at val_loader", image_tensor, self.global_step)
  70.  
  71. buffer.close()
  72.  
  73. trainer = L.Trainer(max_epochs=100, logger=logger, log_every_n_steps=1)
  74. trainer.fit(model=SegNet(), train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement