import gc
import json
import os
import sys
from os.path import isdir, join

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from torchvision.transforms import RandAugment, ToPILImage
from torchvision.transforms.functional import pil_to_tensor
from tqdm import tqdm
from efficientnet_pytorch import EfficientNet

# Make the project-local nets_parts package importable regardless of the
# current working directory.
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

from nets_parts.nets_train_part import (
    run_valid,
    save_confusion_matrix,
    get_data_iterator,
    get_data_iterator_pl,
)
from nets_parts.RandAugment import RandAugment as MyRandAugment
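# Semi-supervised (FixMatch-style, judging by the structure below) training of
# a 5-class patch classifier on the WSS2 dataset: confident predictions on
# unlabeled patches are hardened into one-hot pseudo-labels and mixed into the
# loss. The nets_parts.* imports are project-local modules not shown here.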
THRESHOLD = 0.95  # softmax confidence required to accept a pseudo-label

def train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    iterations_per_epoch,
    n_epochs,
    valid_loader,
    patch_size,
    layer,
    batch_size_train,
    valid_image_num,
    checkpoint_path,
    cur_statistic_path,
    rand_aug_strange,
    rand_aug_max_strange,
    n_ops,
    epoch_for_start_pl,
    use_all_data,
    writer,
    use_my_rand_aug,
    images_to_use
):
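    """Semi-supervised training loop with pseudo-labelling.

    Argument semantics, inferred from how they are used below:
      - iterations_per_epoch: labeled batches drawn per epoch
      - epoch_for_start_pl: first epoch at which pseudo-labels are mixed in
      - rand_aug_strange / rand_aug_max_strange / n_ops: RandAugment
        magnitude, magnitude upper bound, and number of ops per image
      - images_to_use: the labeled iterator is rebuilt every `images_to_use`
        epochs from a matching train_valid_{images_to_use} directory
      - use_all_data: accepted from the config but not read inside the loop
    """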
    global_train_loss = []
    global_train_acc = []
    global_valid_loss = []
    global_valid_acc = []
    best_train_acc = None
    best_valid_acc = None
    train_data_pseudo_label = None
    train_ds = None

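    # Unlabeled stream used only for pseudo-labelling. get_data_iterator_pl
    # is a project-local helper assumed to yield image-only batches from the
    # given directory.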
    data_pl_iterator, train_data_pseudo_label = get_data_iterator_pl(
        path_to_data="/home/n.yakovlev/datasets/symblink/WSS2/train_valid",
        train_ds=train_data_pseudo_label,
        layer=layer,
        patch_size=patch_size,
        batch_size=64
    )
    generated_images: int = 0
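    # Two augmentation back-ends: the project-local MyRandAugment is applied
    # to PIL images, torchvision's RandAugment to uint8 tensors (matching how
    # each is called in the augmentation loop below).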
    if use_my_rand_aug:
        RandAugmentator = MyRandAugment(ops_num=n_ops, cur_value=rand_aug_strange, max_value=rand_aug_max_strange)
    else:
        RandAugmentator = RandAugment(num_ops=n_ops, magnitude=rand_aug_strange, num_magnitude_bins=rand_aug_max_strange, fill=255)
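    # Per-epoch counts of how many pseudo-labels of each class were used.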
    using_labels_in_cur_epoch = []
    for epoch in range(n_epochs):
        # Periodically rebuild the labeled iterator from a directory that
        # contains only `images_to_use` training images.
        if epoch % images_to_use == 0:
            data_iterator, train_ds = get_data_iterator(
                path_to_data=f"/home/n.yakovlev/datasets/symblink/WSS2/train_valid_{images_to_use}",
                train_ds=train_ds,
                layer=layer,
                patch_size=patch_size,
                batch_size=64
            )
        print(f"Epoch {epoch}/{n_epochs - 1}")
        print("-" * 10)

        model.train()  # Set model to training mode
        torch.set_grad_enabled(True)

        running_loss = 0.0
        running_corrects = 0
        using_labels_in_cur_epoch.append({k: 0 for k in range(5)})

        # Iterate over data.
        for _ in tqdm(
            range(iterations_per_epoch), f"running epoch {epoch + 1}"
        ):
            # Pseudo-label collection runs in eval mode without gradients.
            model.eval()
            torch.set_grad_enabled(False)
            inputs_pl = None
            labels_pl = None
            cycle_try_to_get_pl = 0
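            # Collect up to 7 * batch_size_train confident unlabeled samples
            # per step (the 1/7 weight on loss_u below matches this 7:1
            # ratio); give up if more than 10 attempts yield nothing at all.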
            while epoch >= epoch_for_start_pl and (inputs_pl is None or inputs_pl.size(dim=0) < batch_size_train * 7):
                cycle_try_to_get_pl += 1
                gc.collect()
                torch.cuda.empty_cache()
                if inputs_pl is None and cycle_try_to_get_pl > 10:
                    break
                # Restart the unlabeled stream when it is nearly exhausted.
                if generated_images > len(train_data_pseudo_label) - 10 * batch_size_train:
                    generated_images = 0
                    data_pl_iterator, train_data_pseudo_label = get_data_iterator_pl(
                        path_to_data="/home/n.yakovlev/datasets/symblink/WSS2/train_valid",
                        train_ds=train_data_pseudo_label,
                        layer=layer,
                        patch_size=patch_size,
                        batch_size=64
                    )

                images_cur = next(data_pl_iterator).to(device)
                generated_images += images_cur.size(dim=0)
                # The model outputs logits; convert to probabilities so the
                # confidence threshold is meaningful.
                labels_cur = torch.softmax(model(images_cur), dim=1)
                confident = labels_cur.max(dim=1).values > THRESHOLD
                if confident.sum().item() < 10:
                    continue
                # Keep the confident samples and harden their predictions
                # into one-hot pseudo-labels.
                images_cur = images_cur[confident]
                labels_cur = labels_cur[confident]
                labels_cur[labels_cur > THRESHOLD] = 1.0
                labels_cur[labels_cur < 1 - 1e-5] = 0
                if inputs_pl is None:
                    inputs_pl = images_cur
                    labels_pl = labels_cur
                else:
                    inputs_pl = torch.cat((inputs_pl, images_cur), dim=0)
                    labels_pl = torch.cat((labels_pl, labels_cur), dim=0)
            model.train()  # back to training mode for the update step
            torch.set_grad_enabled(True)
            if not (inputs_pl is None or labels_pl is None):
                inputs_pl, labels_pl = inputs_pl[:7 * batch_size_train], labels_pl[:7 * batch_size_train]
            if inputs_pl is None or labels_pl is None or inputs_pl.size(0) == 0:
                # Supervised-only step: no pseudo-labels were collected.
                inputs, labels = next(data_iterator)
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                # statistics
                preds = torch.argmax(outputs, dim=1)
                gts = torch.argmax(labels, dim=1)
                running_loss += loss.item()
                running_corrects += torch.sum(preds == gts)
            else:
                # Count how many pseudo-labels of each class are used.
                uniq_values, uniq_counts = np.unique(labels_pl.argmax(1).cpu().numpy(), return_counts=True)
                for ind, i in zip(uniq_values, uniq_counts):
                    using_labels_in_cur_epoch[epoch][ind] += int(i)
                # Strong augmentation of the pseudo-labeled batch. Write each
                # augmented image back by index: assigning to a bare loop
                # variable would leave inputs_pl unchanged.
                for i in range(inputs_pl.size(0)):
                    if use_my_rand_aug:
                        temp_img = (inputs_pl[i] * 255).to(torch.uint8).cpu()
                        temp_img = ToPILImage()(temp_img)
                        temp_img = RandAugmentator(temp_img)
                        inputs_pl[i] = (pil_to_tensor(temp_img) / 255).to(inputs_pl.device)
                    else:
                        image = RandAugmentator((inputs_pl[i] * 255).to(torch.uint8)).to(torch.float32)
                        if image.max() > 1.1:
                            image /= 255
                        inputs_pl[i] = image
                # Joint step: one labeled batch plus the pseudo-labeled batch.
                inputs, labels = next(data_iterator)
                labels_pl = labels_pl.to(labels.dtype)
                inputs = inputs.to(device)
                labels = labels.to(device)
                inputs_pl = inputs_pl.to(device)
                labels_pl = labels_pl.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = model(inputs)
                loss_l = criterion(outputs, labels)

                outputs_pl = model(inputs_pl)
                loss_u = criterion(outputs_pl, labels_pl)

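                # The unlabeled loss is down-weighted by 1/7, matching the
                # 7:1 unlabeled-to-labeled batch ratio collected above.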
                loss = loss_l + 1. / 7 * loss_u
                loss.backward()
                optimizer.step()

                # statistics
                preds = torch.argmax(outputs, dim=1)
                gts = torch.argmax(labels, dim=1)
                running_loss += loss.item()
                running_corrects += torch.sum(preds == gts)
            gc.collect()
            torch.cuda.empty_cache()

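        # Keep the learning rate flat for the first 10 epochs, then decay
        # once per epoch.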
        if epoch >= 10:
            scheduler.step()

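        # Training accuracy is computed over the labeled batches only;
        # `inputs` still holds the last labeled batch, so inputs.size(0) is
        # the labeled batch size.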
        epoch_loss = running_loss / iterations_per_epoch
        epoch_acc = running_corrects.float() / (
            iterations_per_epoch * inputs.size(0)
        )
        global_train_acc.append(float(epoch_acc))
        global_train_loss.append(float(epoch_loss))
        if best_train_acc is None or best_train_acc < epoch_acc:
            torch.save(model.state_dict(), f"{checkpoint_path}/best_train_acc_{round(float(epoch_acc) * 100)}.pth")
            best_train_acc = epoch_acc

        print(f"Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
        writer.add_scalar("Loss", epoch_loss, epoch)
        writer.add_scalar("Acc", epoch_acc, epoch)

        matrix_conf, epoch_loss, epoch_acc = run_valid(
            model=model,
            valid_loader=valid_loader,
            valid_images_num=valid_image_num,
            criterion=criterion,
            device=device
        )
        global_valid_acc.append(float(epoch_acc))
        global_valid_loss.append(float(epoch_loss))
        if best_valid_acc is None or best_valid_acc < epoch_acc:
            torch.save(model.state_dict(), f"{checkpoint_path}/best_valid_acc_{round(float(epoch_acc) * 100)}.pth")
            best_valid_acc = epoch_acc
        print(f"Val Loss: {epoch_loss:.4f} Val Acc: {epoch_acc:.4f}")
        writer.add_scalar("Loss valid", epoch_loss, epoch)
        writer.add_scalar("Accuracy valid", epoch_acc, epoch)

        # Dump statistics and the confusion matrix after every epoch.
        with open(f"{cur_statistic_path}/statistic.json", "w") as f:
            json.dump(
                {
                    "train_loss": global_train_loss,
                    "train_acc": global_train_acc,
                    "valid_loss": global_valid_loss,
                    "valid_acc": global_valid_acc,
                },
                f
            )
        with open(f"{cur_statistic_path}/using_pl.json", "w") as f:
            json.dump(using_labels_in_cur_epoch, f)
        path_to_save_conf_matrix = f"{cur_statistic_path}/confusion_matrix_best_epoch.png"
        save_confusion_matrix(matrix_conf, path_to_save_conf_matrix)
    if train_data_pseudo_label is not None:
        train_data_pseudo_label.close()
    if train_ds is not None:
        train_ds.close()


if __name__ == "__main__":
    # Usage: python this_script.py path/to/params.json
    # Example params file. The keys are taken from the parsing code below;
    # the values are only illustrative.
    """
    {
        "batch_size_train": 192,
        "batch_size_valid": 512,
        "lr": 0.01,
        "exprement_path": "/home/n.yakovlev/my_best_program/diplom_8sem/experiments/classifier",

        "patch_size": 224,
        "layer": 2,
        "scheduler_param": 0.99,
        "pretrain": 1,

        "device": 1,

        "is_efficientnet": 1,
        "nn_name": "efficientnet-b4",
        "iter_per_epoch": 26,

        "rand_aug_strange": 4,
        "rand_aug_max_strange": 31,
        "n_ops": 2,
        "epoch_for_start_pl": 10,
        "use_all_data": 0,
        "use_my_rand_aug": 0,
        "images_to_use": 5
    }
    """
    path_to_params = sys.argv[1]
    with open(path_to_params) as f:
        parsed_json = json.loads(f.read())
        BATCH_SIZE_TRAIN = int(parsed_json["batch_size_train"])
        BATCH_SIZE_VALID = int(parsed_json["batch_size_valid"])
        LEARNING_RATE = float(parsed_json["lr"])
        LAYER = int(parsed_json["layer"])

        PRETRAIN = bool(parsed_json["pretrain"])
        EXP_PATH = str(parsed_json["exprement_path"])
        os.makedirs(EXP_PATH, exist_ok=True)
        PATCH_SIZE = int(parsed_json["patch_size"])
        SCHEDULER_PARAM = float(parsed_json["scheduler_param"])

        IS_EFFICIENTNET = bool(parsed_json["is_efficientnet"])
        NN_NAME = str(parsed_json["nn_name"])
        ITER_PER_EPOCH = int(parsed_json["iter_per_epoch"])
        rand_aug_strange = int(parsed_json["rand_aug_strange"])
        rand_aug_max_strange = int(parsed_json["rand_aug_max_strange"])
        n_ops = int(parsed_json["n_ops"])
        device = f"cuda:{parsed_json['device']}"
        epoch_for_start_pl = int(parsed_json["epoch_for_start_pl"])
        use_all_data = bool(parsed_json["use_all_data"])
        use_my_rand_aug = bool(parsed_json["use_my_rand_aug"])
        images_to_use = int(parsed_json["images_to_use"])

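    # Pre-extracted validation patches, loaded fully into memory as tensors.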
    data = np.load("/home/n.yakovlev/datasets/test_files_WSS2.npz")
    images = data["images"]
    valid_image_num = images.shape[0]
    labels = data["labels"]

    del data
    gc.collect()
    valid_dataset = TensorDataset(torch.tensor(images), torch.tensor(labels))
    del images, labels
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE_VALID, num_workers=8)

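    # Grid sweep over the RandAugment magnitude: params.json is re-read for
    # every run, the swept values are patched in, and a copy is saved next to
    # the run's outputs.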
    for use_my_rand_aug in [False]:
        for rand_aug_strange in [4, 5, 6, 7, 8, 9, 10, 11, 12, 13]:
            rand_aug_max_strange = 31
            with open(path_to_params) as f:
                parsed_json = json.loads(f.read())
                parsed_json["rand_aug_strange"] = rand_aug_strange
                parsed_json["use_my_rand_aug"] = use_my_rand_aug

                EXP_PATH = str(parsed_json["exprement_path"])
                os.makedirs(EXP_PATH, exist_ok=True)
                # Number the run directory after the existing ones.
                onlydirs = [d for d in os.listdir(EXP_PATH) if isdir(join(EXP_PATH, d))]
                cur_exp_path = f"{EXP_PATH}/{len(onlydirs)}"
                os.makedirs(cur_exp_path)
                cur_statistic_path = f"{cur_exp_path}/statistic"
                os.makedirs(cur_statistic_path)
                checkpoint_path = cur_exp_path + "/checkpoints"
                os.makedirs(checkpoint_path)
                # Note: all runs write TensorBoard events into the shared EXP_PATH.
                writer = SummaryWriter(EXP_PATH)
                with open(f"{cur_exp_path}/params.json", "w") as f_params:
                    json.dump(parsed_json, f_params)

            if IS_EFFICIENTNET:
                if PRETRAIN:
                    model_ft = EfficientNet.from_pretrained(NN_NAME, num_classes=5)
                else:
                    model_ft = EfficientNet.from_name(NN_NAME, num_classes=5)
            else:
                model_ft = models.resnet50(weights="IMAGENET1K_V1")
                num_ftrs = model_ft.fc.in_features
                model_ft.fc = nn.Linear(num_ftrs, 5)

            # Optionally warm-start from a convolutional autoencoder: copy
            # every checkpoint entry whose key matches a "model."-prefixed
            # parameter of the classifier.
            try:
                pretrained_dict = torch.load("/home/n.yakovlev/conv_autoencoder_256neuron_best_valid.pth")
                model_dict = model_ft.state_dict()

                processed_dict = {}

                for k in model_dict.keys():
                    decomposed_key = k.split(".")
                    if "model" in decomposed_key:
                        pretrained_key = ".".join(decomposed_key[1:])
                        processed_dict[k] = pretrained_dict[pretrained_key]

                model_ft.load_state_dict(processed_dict, strict=False)

            except Exception:
                print("Model weights were not loaded!")


            model_ft = model_ft.to(device)

            # Slightly up-weight classes 2 and 4 in the cross-entropy loss.
            criterion = nn.CrossEntropyLoss(weight=torch.tensor([1, 1, 1.2, 1, 1.2], device=device))

            # Observe that all parameters are being optimized
            optimizer_ft = optim.Adam(model_ft.parameters(), lr=LEARNING_RATE)

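            # StepLR with step_size=1 amounts to exponential decay by
            # SCHEDULER_PARAM per epoch (applied only after epoch 10, see
            # train_model).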
            exp_lr_scheduler = lr_scheduler.StepLR(
                optimizer_ft, step_size=1, gamma=SCHEDULER_PARAM
            )

            train_model(
                model_ft,
                criterion,
                optimizer_ft,
                exp_lr_scheduler,
                iterations_per_epoch=ITER_PER_EPOCH,  # len(train_ds) // batch_size,
                n_epochs=50,
                valid_loader=valid_loader,
                patch_size=PATCH_SIZE,
                layer=LAYER,
                batch_size_train=BATCH_SIZE_TRAIN,
                valid_image_num=valid_image_num,
                checkpoint_path=checkpoint_path,
                cur_statistic_path=cur_statistic_path,
                rand_aug_strange=rand_aug_strange,
                rand_aug_max_strange=rand_aug_max_strange,
                n_ops=n_ops,
                epoch_for_start_pl=epoch_for_start_pl,
                use_all_data=use_all_data,
                writer=writer,
                use_my_rand_aug=use_my_rand_aug,
                images_to_use=images_to_use
            )
            # Free the GPU before the next sweep configuration.
            del model_ft, optimizer_ft
            gc.collect()
            torch.cuda.empty_cache()