Advertisement
Guest User

loop

a guest
Aug 17th, 2021
35
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.48 KB | None | 0 0
  1. train_ds = Dataset(data=train_dicts, transform=train_transforms)
  2. train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4)
  3.  
  4. val_ds = Dataset(data=valid_dicts, transform=val_transforms)
  5. val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
  6. post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
  7.  
  8.  
  9. device = torch.device("cuda")
  10.  
  11. model = Unet().to(device)
  12. loss_function = monai.losses.DiceCELoss(to_onehot_y=True, softmax=True)
  13. optimizer = torch.optim.Adam(model.parameters(), 1e-4)
  14.  
  15. val_interval = 2
  16. best_metric = -1
  17. best_metric_epoch = -1
  18. epoch_loss_values = list()
  19. metric_values = list()
  20. post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
  21. post_label = AsDiscrete(to_onehot=True, n_classes=2)
  22. for epoch in range(100):
  23.     print("-" * 10)
  24.     print(f"epoch {epoch + 1}/{100}")
  25.     model.train()
  26.     epoch_loss = 0
  27.     step = 0
  28.     for batch_data in train_loader:
  29.         step += 1
  30.         inputs, labels = (batch_data["image"].to(device), batch_data["label"].to(device),)
  31.         #print("inputs: ", inputs.shape)
  32.         #print("labels: ", labels.shape)
  33.         optimizer.zero_grad()
  34.         outputs = model(inputs)
  35.         #loss = loss_function(outputs, labels)
  36.         loss = model.sample_elbo(inputs=inputs,
  37.                            labels=labels,
  38.                            criterion=loss_function,
  39.                            sample_nbr=3,
  40.                            complexity_cost_weight=1/50000)
  41.         loss.backward()
  42.         optimizer.step()
  43.         epoch_loss += loss.item()
  44.         epoch_len = len(train_ds) // train_loader.batch_size
  45.         #print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
  46.     epoch_loss /= step
  47.     epoch_loss_values.append(epoch_loss)
  48.     print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
  49.  
  50.     if (epoch + 1) % val_interval == 0:
  51.         model.eval()
  52.         with torch.no_grad():
  53.             metric_sum = 0.0
  54.             metric_count = 0
  55.             val_images = None
  56.             val_labels = None
  57.             val_outputs = None
  58.             for val_data in val_loader:
  59.                 val_images, val_labels = (val_data["image"].to(device), val_data["label"].to(device),)
  60.                 roi_size = (128, 128, 32)
  61.                 sw_batch_size = 4
  62.                 val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
  63.                 #print(val_outputs)
  64.                 #print(val_outputs.shape)
  65.                 val_outputs = post_pred(val_outputs)
  66.                 val_labels = post_label(val_labels)
  67.                 value = compute_meandice(
  68.                     y_pred=val_outputs,
  69.                     y=val_labels,
  70.                     include_background=False,
  71.                 )
  72.                 metric_count += len(value)
  73.                 metric_sum += value.sum().item()
  74.             metric = metric_sum / metric_count
  75.             metric_values.append(metric)
  76.             if metric > best_metric:
  77.                 best_metric = metric
  78.                 best_metric_epoch = epoch + 1
  79.                 #torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
  80.                 print("saved new best metric model")
  81.             print(
  82.                 f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
  83.                 f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
  84.             )
  85.  
  86. print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement