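# The script below uses train_dicts / valid_dicts and train_transforms / val_transforms
# without defining them. The block here is only a minimal sketch of what they might look
# like for NIfTI image/label pairs; the directory layout, dictionary keys and transform
# choices are illustrative assumptions, not part of the original paste.
import glob
import os

from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    ScaleIntensityd,
    ToTensord,
)

data_dir = "./data"  # hypothetical location of the image/label volumes
images = sorted(glob.glob(os.path.join(data_dir, "images", "*.nii.gz")))
labels = sorted(glob.glob(os.path.join(data_dir, "labels", "*.nii.gz")))
data_dicts = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
train_dicts, valid_dicts = data_dicts[:-10], data_dicts[-10:]  # simple holdout split

train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    AddChanneld(keys=["image", "label"]),
    ScaleIntensityd(keys=["image"]),
    # random patches matching the (128, 128, 32) ROI used for validation below
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label",
        spatial_size=(128, 128, 32), pos=1, neg=1, num_samples=4,
    ),
    ToTensord(keys=["image", "label"]),
])
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    AddChanneld(keys=["image", "label"]),
    ScaleIntensityd(keys=["image"]),
    ToTensord(keys=["image", "label"]),
])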
import monai
import torch
from monai.data import DataLoader, Dataset
from monai.inferers import sliding_window_inference
from monai.metrics import compute_meandice
from monai.transforms import Activations, AsDiscrete, Compose

# Uses the older MONAI API (compute_meandice, AsDiscrete(..., n_classes=...)),
# so a pre-1.0 MONAI release is assumed.

# datasets and loaders built from the dictionaries and transforms defined above
train_ds = Dataset(data=train_dicts, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4)
val_ds = Dataset(data=valid_dicts, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

# sigmoid + binarisation post-processing; defined here but not used later in the script
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

device = torch.device("cuda")
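# `Unet` is not defined in the paste. Because the training loop calls
# model.sample_elbo(...), the model is presumably built with the blitz library
# (blitz-bayesian-pytorch), whose @variational_estimator decorator adds sample_elbo.
# The tiny two-layer network below, using blitz's BayesianConv3d (assumed available),
# is only an illustrative stand-in for the author's actual Bayesian UNet architecture.
import torch.nn as nn
from blitz.modules import BayesianConv3d
from blitz.utils import variational_estimator

@variational_estimator
class Unet(nn.Module):
    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()
        self.block = nn.Sequential(
            BayesianConv3d(in_channels, 16, kernel_size=(3, 3, 3), padding=1),
            nn.ReLU(),
            BayesianConv3d(16, out_channels, kernel_size=(3, 3, 3), padding=1),
        )

    def forward(self, x):
        return self.block(x)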
model = Unet().to(device)
loss_function = monai.losses.DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

val_interval = 2  # run validation every 2 epochs
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []

# validation post-processing: argmax the logits, then one-hot encode predictions and labels
post_pred = AsDiscrete(argmax=True, to_onehot=True, n_classes=2)
post_label = AsDiscrete(to_onehot=True, n_classes=2)
for epoch in range(100):
    print("-" * 10)
    print(f"epoch {epoch + 1}/100")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        # standard (non-Bayesian) training would be:
        # outputs = model(inputs)
        # loss = loss_function(outputs, labels)
        # here blitz's ELBO is used instead: the criterion is averaged over
        # sample_nbr stochastic forward passes and the weighted KL term is added
        loss = model.sample_elbo(
            inputs=inputs,
            labels=labels,
            criterion=loss_function,
            sample_nbr=3,
            complexity_cost_weight=1 / 50000,
        )
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        # print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    # validation every val_interval epochs
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                # sliding-window inference over the full volume in (128, 128, 32) patches
                roi_size = (128, 128, 32)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = post_pred(val_outputs)
                val_labels = post_label(val_labels)
                value = compute_meandice(
                    y_pred=val_outputs,
                    y=val_labels,
                    include_background=False,
                )
                metric_count += len(value)
                metric_sum += value.sum().item()
            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                # torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            )

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")