# seti

a guest
Aug 8th, 2021
114
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 40.09 KB | None | 0 0


!pip install -q git+https://github.com/rwightman/pytorch-image-models.git
!pip install -q torchsummary
!pip install -q -U git+https://github.com/albu/albumentations --no-cache-dir
!pip install -q neptune-client
from IPython.display import clear_output
clear_output()


#import torch.nn as nn
import torch.nn.init as init

import sys
import numpy as np
import torch
from torch.nn.parameter import Parameter


import math
import os
from torchsummary import summary
import warnings
import random
from matplotlib import pyplot as plt
import seaborn as sns
from typing import *
import albumentations
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold
import cv2
import neptune.new as neptune
import pandas as pd
import timm
import torch.nn.functional as F
from albumentations.pytorch.transforms import ToTensorV2
from sklearn.preprocessing import LabelEncoder
from torch import nn
from torch.autograd import Variable
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torchvision import models
from tqdm.notebook import tqdm
warnings.filterwarnings("ignore")
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
clear_output()

torch.cuda.empty_cache()

import albumentations as A
size = 512
bs = 96
# efficientnet_b0 nfnet_l0
CONFIG = {
    "COMPETITION_NAME": "SETI",
    "MODEL": {"MODEL_FACTORY": "timm", "MODEL_NAME": "efficientnet_b0"},
    "WORKSPACE": "home",
    "DATA": {
        "TARGET_COL_NAME": "target",
        "IMAGE_COL_NAME": "id",
        "NUM_CLASSES": 1,
        "CLASS_LIST": [0, 1],
        "IMAGE_SIZE": size,
        "CHANNEL_MODE": "spatial_3ch",
        "USE_MIXUP": True
        #"USE_CUTMIX": False
    },
    "CROSS_VALIDATION": {"SCHEMA": 'StratifiedKFold', "NUM_FOLDS": 5},
    "TRAIN": {
        "DATALOADER": {
            "batch_size": bs,
            "shuffle": True,  # using random sampler
            "num_workers": 4,
            "drop_last": False,
        },
        "SETTINGS": {
            "IMAGE_SIZE": size,
            "NUM_EPOCHS": 60,
            "USE_AMP": True,
            "USE_GRAD_ACCUM": False,
            "ACCUMULATION_STEP": 1,
            "DEBUG": False,
            "VERBOSE": True,
            "VERBOSE_STEP": 10,
        },
    },
    "VALIDATION": {
        "DATALOADER": {
            "batch_size": 16,
            "shuffle": False,
            "num_workers": 4,
            "drop_last": False,
        }
    },
    "TEST": {
        "DATALOADER": {
            "batch_size": 16,
            "shuffle": False,
            "num_workers": 4,
            "drop_last": False,
        }
    },
    "OPTIMIZER": {
        "NAME": "AdamW",
        "OPTIMIZER_PARAMS": {"lr": 1e-4, "eps": 1.0e-8, "weight_decay": 1.0e-3},
    },
    "SCHEDULER": {
        "NAME": "CosineAnnealingWarmRestarts",
        "SCHEDULER_PARAMS": {
            "T_0": 4,
            "T_mult": 1,
            "eta_min": 1.0e-7,
            "last_epoch": -1,
            "verbose": True,
        },
        "CUSTOM": "GradualWarmupSchedulerV2",
        "CUSTOM_PARAMS": {"multiplier": 7, "total_epoch": 1},
        "VAL_STEP": False,
    },
    "CRITERION_TRAIN": {
        "NAME": "BCEWithLogitsLoss",
        "LOSS_PARAMS": {
            "weight": None,
            "size_average": None,
            "reduce": None,
            "reduction": "mean",
            "pos_weight": None
        },
    },
    "CRITERION_VALIDATION": {
        "NAME": "BCEWithLogitsLoss",
        "LOSS_PARAMS": {
            "weight": None,
            "size_average": None,
            "reduce": None,
            "reduction": "mean",
            "pos_weight": None
        },
    },
    "TRAIN_TRANSFORMS": {
        "VerticalFlip": {"p": 0.5},
        "HorizontalFlip": {"p": 0.5},
        "Resize": {"height": size, "width": size, "p": 1},
    },
    "VALID_TRANSFORMS": {
        "Resize": {"height": size, "width": size, "p": 1},
    },
    "TEST_TRANSFORMS": {
        "Resize": {"height": size, "width": size, "p": 1},
    },
    "PATH": {
        "DATA_DIR": "/home/apsisdev/data/seti/",
        "TRAIN_CSV": "/home/apsisdev/data/seti/seti-breakthrough-listen/train_labels.csv",
        "TRAIN_PATH": "/home/apsisdev/data/seti/seti-breakthrough-listen/train",

        "TEST_CSV": "/home/apsisdev/data/seti/seti-breakthrough-listen/sample_submission.csv",
        "TEST_PATH": "/home/apsisdev/data/seti/seti-breakthrough-listen/test",
        "SAVE_WEIGHT_PATH": "./",
        "OOF_PATH": "/home/apsisdev/data/seti/seti-breakthrough-listen/",
        "LOG_PATH": "/home/apsisdev/data/seti/seti-breakthrough-listen/log.txt"
    },
    "SEED": 63,
    "DEVICE": "cuda",
    "GPU": "rtx3090",
}

config = CONFIG

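# The blocks above are consumed later via getattr lookups on torch.optim /
# torch.optim.lr_scheduler with the *_PARAMS dicts splatted in (see the Trainer below).
# A minimal sketch of that wiring using a throwaway parameter; illustrative only and
# not part of the original pipeline.
def _show_config_wiring(cfg=CONFIG):
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    optimizer = getattr(torch.optim, cfg["OPTIMIZER"]["NAME"])(
        [dummy_param], **cfg["OPTIMIZER"]["OPTIMIZER_PARAMS"]
    )
    scheduler = getattr(torch.optim.lr_scheduler, cfg["SCHEDULER"]["NAME"])(
        optimizer=optimizer, **cfg["SCHEDULER"]["SCHEDULER_PARAMS"]
    )
    print(type(optimizer).__name__, type(scheduler).__name__)  # AdamW CosineAnnealingWarmRestarts
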
def seed_all(seed: int = 63):
    """Seed all random number generators."""
    print("Using Seed Number {}".format(seed))

    os.environ["PYTHONHASHSEED"] = str(
        seed
    )  # set PYTHONHASHSEED env var at a fixed value
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)  # pytorch (both CPU and CUDA)
    np.random.seed(seed)  # for the numpy pseudo-random generator
    random.seed(seed)  # set a fixed value for the python built-in pseudo-random generator
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False


def seed_worker(_worker_id):
    """Seed a worker with the given ID."""
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

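# seed_worker above is never actually passed to a DataLoader in this paste. A minimal
# sketch (assumption, not the original author's call site) of how it would typically be
# wired in, together with a seeded torch.Generator, for reproducible worker streams.
def make_reproducible_loader(dataset, **loader_kwargs):
    generator = torch.Generator()
    generator.manual_seed(config["SEED"])
    return DataLoader(dataset, worker_init_fn=seed_worker, generator=generator, **loader_kwargs)
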
seed_all(config['SEED'])

train = pd.read_csv(CONFIG['PATH']['TRAIN_CSV'])

def get_train_file_path(image_id):
    if config['WORKSPACE'] == 'home':
        return "/home/apsisdev/data/seti/seti-breakthrough-listen/train/{}/{}.npy".format(image_id[0], image_id)
    elif config['WORKSPACE'] == 'Colab':
        return "/home/apsisdev/data/seti/seti-breakthrough-listen/{}/{}.npy".format(image_id[0], image_id)


train['file_path'] = train['id'].apply(get_train_file_path)


display(train.head())

def make_folds(train_csv: pd.DataFrame, config) -> pd.DataFrame:
    """Split the given dataframe into training folds."""
    # TODO: add options for the cv scheme as it is cumbersome here.
    if config['CROSS_VALIDATION']['SCHEMA'] == "StratifiedKFold":
        df_folds = train_csv.copy()
        skf = StratifiedKFold(
            n_splits=config['CROSS_VALIDATION']['NUM_FOLDS'], shuffle=True, random_state=config['SEED']
        )

        for fold, (train_idx, val_idx) in enumerate(
            skf.split(
                X=df_folds[config['DATA']['IMAGE_COL_NAME']], y=df_folds[config['DATA']['TARGET_COL_NAME']]
            )
        ):
            df_folds.loc[val_idx, "fold"] = int(fold + 1)
        df_folds["fold"] = df_folds["fold"].astype(int)
        print(df_folds.groupby(["fold", config['DATA']['TARGET_COL_NAME']]).size())

    elif config['CROSS_VALIDATION']['SCHEMA'] == "GroupKFold":
        df_folds = train_csv.copy()
        gkf = GroupKFold(n_splits=config['CROSS_VALIDATION']['NUM_FOLDS'])
        # assumes the config carries the name of the grouping column when this schema is used
        groups = df_folds[config['CROSS_VALIDATION']['GROUP_COL_NAME']].values
        for fold, (train_index, val_index) in enumerate(
            gkf.split(X=df_folds, y=df_folds[config['DATA']['TARGET_COL_NAME']], groups=groups)
        ):
            df_folds.loc[val_index, "fold"] = int(fold + 1)
        df_folds["fold"] = df_folds["fold"].astype(int)
        try:
            print(df_folds.groupby(["fold", config['DATA']['TARGET_COL_NAME']]).size())
        except KeyError:
            display(df_folds)

    else:  # no CV schema from this file, but a custom/pre-computed "fold" column
        df_folds = train_csv.copy()
        try:
            print(df_folds.groupby(["fold", config['DATA']['TARGET_COL_NAME']]).size())
        except KeyError:
            display(df_folds)

    return df_folds

df_folds = make_folds(train, config)


def gem(x, p=3, eps=1e-6):
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p)


class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'

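# GeM above is defined but never attached to the model below (train_one_epoch only has a
# commented-out `self.model.avg_pool = GeM()`). A minimal sketch, assuming a timm
# EfficientNet-style backbone that exposes a `global_pool` attribute, of how it could be
# swapped in; pooling attribute names differ across architectures, so treat this as illustrative.
def attach_gem_pooling(backbone: nn.Module, p: float = 3.0) -> nn.Module:
    if hasattr(backbone, "global_pool"):
        backbone.global_pool = nn.Sequential(GeM(p=p), nn.Flatten(1))
    return backbone
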
class Transform:
    """Build an albumentations pipeline from the {name: kwargs} mapping in the config."""

    def __init__(self, aug_kwargs: Dict):
        albu_augs = [getattr(A, name)(**kwargs)
                     for name, kwargs in aug_kwargs.items()]
        albu_augs.append(ToTensorV2(p=1))

        self.transform = A.Compose(albu_augs)

    def __call__(self, image):
        image = self.transform(image=image)["image"]
        return image


def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Return a mixed batch, the two label sets it was blended from, and the mixing lambda."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
        # lam = max(lam, 1 - lam)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the criterion over the two label sets with the same lambda used for the inputs."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

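# A small self-contained check (illustrative, not part of the original pipeline) that
# mixup_data / mixup_criterion behave as expected on a dummy batch; use_cuda=False so it
# runs on CPU.
_demo_x = torch.randn(4, 1, 8, 8)
_demo_y = torch.tensor([0.0, 1.0, 1.0, 0.0]).view(-1, 1)
_demo_mixed, _demo_ya, _demo_yb, _demo_lam = mixup_data(_demo_x, _demo_y, alpha=1.0, use_cuda=False)
_demo_loss = mixup_criterion(nn.BCEWithLogitsLoss(), torch.zeros_like(_demo_y), _demo_ya, _demo_yb, _demo_lam)
print("mixup demo -> lambda: {:.3f}, loss: {:.3f}".format(_demo_lam, _demo_loss.item()))
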
class AlienTrainDataset(Dataset):
    def __init__(self, df, config, transform=None, mode='train'):
        self.df = df
        self.config = config
        self.file_names = df['file_path'].values
        self.labels = df[config['DATA']['TARGET_COL_NAME']].values
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image = np.load(self.file_names[idx])
        # print(image.shape) -> (6, 273, 256)
        if self.config['DATA']['CHANNEL_MODE'] == 'spatial_6ch':
            image = image.astype(np.float32)
            image = np.vstack(image)  # no transpose here (1638, 256)
            # image = np.vstack(image).transpose((1, 0))
            # print(image.shape) -> (256, 1638)
        elif self.config['DATA']['CHANNEL_MODE'] == 'spatial_3ch':
            image = image[::2].astype(np.float32)
            image = np.vstack(image).transpose((1, 0))
        elif self.config['DATA']['CHANNEL_MODE'] == '6_channel':
            image = image.astype(np.float32)
            image = np.transpose(image, (1, 2, 0))
        elif self.config['DATA']['CHANNEL_MODE'] == '3_channel':
            image = image[::2].astype(np.float32)
            image = np.transpose(image, (1, 2, 0))

        if self.transform:
            image = self.transform(image)
        else:
            image = torch.from_numpy(image).float()

        if self.mode == 'test':
            return image
        else:
            label = torch.tensor(self.labels[idx]).float()
            return image, label


train_dataset = AlienTrainDataset(train, config, transform=Transform(config["TRAIN_TRANSFORMS"]))

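# The comments in __getitem__ assume each cadence .npy has shape (6, 273, 256). A synthetic
# check (assumption: that input shape) of what each CHANNEL_MODE produces before
# augmentation, so the reshapes can be verified without touching the files on disk.
_cadence = np.random.rand(6, 273, 256).astype(np.float32)
print("spatial_6ch:", np.vstack(_cadence).shape)                         # (1638, 256)
print("spatial_3ch:", np.vstack(_cadence[::2]).transpose((1, 0)).shape)  # (256, 819)
print("6_channel:  ", np.transpose(_cadence, (1, 2, 0)).shape)           # (273, 256, 6)
print("3_channel:  ", np.transpose(_cadence[::2], (1, 2, 0)).shape)      # (273, 256, 3)
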
class AverageLossMeter:
    """Computes and stores the average and current loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.curr_batch_avg_loss = 0
        self.avg = 0
        self.running_total_loss = 0
        self.count = 0

    def update(self, curr_batch_avg_loss: float, batch_size: int):
        self.curr_batch_avg_loss = curr_batch_avg_loss
        self.running_total_loss += curr_batch_avg_loss * batch_size
        self.count += batch_size
        self.avg = self.running_total_loss / self.count

import warnings

warnings.filterwarnings("ignore")

from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

### Original Implementation ###
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier = 1.0, lr starts from 0 and ends up at the base_lr.
        total_epoch: the target learning rate is reached at total_epoch, gradually
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.0:
            raise ValueError("multiplier should be greater than or equal to 1.")
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [
                        base_lr * self.multiplier for base_lr in self.base_lrs
                    ]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [
                base_lr * (float(self.last_epoch) / self.total_epoch)
                for base_lr in self.base_lrs
            ]
        else:
            return [
                base_lr
                * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = (
            epoch if epoch != 0 else 1
        )  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [
                base_lr
                * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group["lr"] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)


### Fix Warmup Bug here, a modified version of the above.


class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super(GradualWarmupSchedulerV2, self).__init__(
            optimizer, multiplier, total_epoch, after_scheduler
        )

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [
                        base_lr * self.multiplier for base_lr in self.base_lrs
                    ]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if self.multiplier == 1.0:
            return [
                base_lr * (float(self.last_epoch) / self.total_epoch)
                for base_lr in self.base_lrs
            ]
        else:
            return [
                base_lr
                * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
                for base_lr in self.base_lrs
            ]

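# A short standalone trace (illustrative only) of the schedule the Trainer builds below:
# GradualWarmupSchedulerV2 with CUSTOM_PARAMS from the config wrapping
# CosineAnnealingWarmRestarts, stepped once per epoch on a throwaway optimizer.
def _trace_warmup_schedule(num_epochs: int = 8):
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([dummy_param], **config["OPTIMIZER"]["OPTIMIZER_PARAMS"])
    cosine = CosineAnnealingWarmRestarts(optimizer, **config["SCHEDULER"]["SCHEDULER_PARAMS"])
    warmup = GradualWarmupSchedulerV2(
        optimizer, **config["SCHEDULER"]["CUSTOM_PARAMS"], after_scheduler=cosine
    )
    for epoch in range(num_epochs):
        print("epoch {:>2d} | lr {:.2e}".format(epoch, optimizer.param_groups[0]["lr"]))
        optimizer.step()
        warmup.step()
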
sigmoid = torch.nn.Sigmoid()


class Swish(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class Swish_Module(torch.nn.Module):
    def forward(self, x):
        return Swish.apply(x)


class AlienSingleHead(torch.nn.Module):
    """A custom model."""

    def __init__(
        self,
        config: dict,
        pretrained: bool = True,
    ):
        """Construct a custom model."""
        super().__init__()
        self.config = config
        self.pretrained = pretrained
        print("Pretrained is {}".format(self.pretrained))
        self.activation = Swish_Module()
        self.architecture = {
            "backbone": None,
            "bottleneck": None,
            "classifier_head": None,
        }

        _model_factory = (
            timm.create_model
            if self.config["MODEL"]["MODEL_FACTORY"] == "timm"
            else geffnet.create_model  # requires `import geffnet` if MODEL_FACTORY is not "timm"
        )
        if config['DATA']['CHANNEL_MODE'] == 'spatial_6ch' or config['DATA']['CHANNEL_MODE'] == 'spatial_3ch':
            # single input channel: the spatial modes stack the cadence into one 2D image
            self.model = _model_factory(
                model_name=self.config["MODEL"]["MODEL_NAME"],
                pretrained=self.pretrained, in_chans=1)
        else:
            # per-channel modes keep the cadence as channels: 6 for '6_channel', 3 for '3_channel'
            in_chans = 6 if config['DATA']['CHANNEL_MODE'] == '6_channel' else 3
            self.model = _model_factory(
                model_name=self.config["MODEL"]["MODEL_NAME"],
                pretrained=self.pretrained, in_chans=in_chans)

        # reset head
        self.model.reset_classifier(num_classes=0, global_pool="avg")
        # after resetting, there is no longer any classifier head, therefore it is the backbone now.
        self.architecture["backbone"] = self.model
        # the out features of the last cnn layer of the backbone are the in features of the next layer
        self.in_features = self.architecture["backbone"].num_features

        self.single_head_fc = torch.nn.Sequential(
            torch.nn.Linear(self.in_features, self.in_features),
            self.activation,
            # torch.nn.Dropout(p=0.05),
            torch.nn.Linear(self.in_features, self.config["DATA"]["NUM_CLASSES"]),
        )
        self.architecture["classifier_head"] = self.single_head_fc

    # feature map after the cnn layers
    def extract_features(self, x):
        feature_logits = self.architecture["backbone"](x)
        # TODO: caution, if you use forward_features, then you need to reshape. See test.py
        return feature_logits

    def forward(self, x):
        feature_logits = self.extract_features(x)
        classifier_logits = self.architecture["classifier_head"](feature_logits)
        return classifier_logits


model = AlienSingleHead(config, pretrained=False)
train_dataset = AlienTrainDataset(train, config, transform=Transform(config["TRAIN_TRANSFORMS"]))
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
                          num_workers=4, pin_memory=True, drop_last=True)


def torchsummary_wrapper(model, image_size: Tuple):
    model_summary = summary(model, image_size)
    return model_summary

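# Quick shape sanity check (illustrative): with CHANNEL_MODE "spatial_3ch" the backbone is
# built with in_chans=1, so a random (B, 1, H, W) batch should come out as (B, NUM_CLASSES).
with torch.no_grad():
    _probe = torch.randn(2, 1, size, size)
    print("probe logits shape:", model(_probe).shape)  # expected: torch.Size([2, 1])
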
"""Model training."""

import datetime
import os
import random
import time

import numpy as np
import pandas as pd
import pytz
import sklearn
import torch
import torch.nn as nn
from sklearn.model_selection import GroupKFold
from torch.utils.data import DataLoader

from tqdm import tqdm
from sklearn.metrics import roc_auc_score


class Trainer:
    """A class to perform model training."""

    def __init__(self, model, config, early_stopping=None, neptune=None):
        """Construct a Trainer instance."""
        self.model = model
        self.patience = 15
        self.config = config
        self.neptune = neptune
        self.early_stopping = early_stopping
        self.epoch = 0
        self.best_auc = 0
        self.log_path = self.config["PATH"]["LOG_PATH"]
        self.best_loss = np.inf
        self.num_epochs = self.config["TRAIN"]["SETTINGS"]["NUM_EPOCHS"]
        self.save_path = self.config["PATH"]["SAVE_WEIGHT_PATH"]
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)
        self.device = self.config["DEVICE"]
        # The scaler is only used when USE_AMP is True (USE_AMP lives inside the config).
        if self.config["TRAIN"]["SETTINGS"]["USE_AMP"]:
            self.scaler = torch.cuda.amp.GradScaler()
        self.date = datetime.datetime.now(pytz.timezone("Asia/Dhaka")).strftime(
            "%Y-%m-%d"
        )
        self.log(f"Fitter prepared. Device is {self.device}")

        self.criterion_train = getattr(
            torch.nn, self.config["CRITERION_TRAIN"]["NAME"]
        )(**self.config["CRITERION_TRAIN"]["LOSS_PARAMS"]).to(self.device)

        self.criterion_val = getattr(
            torch.nn, self.config["CRITERION_VALIDATION"]["NAME"]
        )(**self.config["CRITERION_VALIDATION"]["LOSS_PARAMS"])

        self.optimizer = getattr(torch.optim, self.config["OPTIMIZER"]["NAME"])(
            self.model.parameters(), **self.config["OPTIMIZER"]["OPTIMIZER_PARAMS"]
        )
        self.scheduler = getattr(
            torch.optim.lr_scheduler, self.config["SCHEDULER"]["NAME"]
        )(optimizer=self.optimizer, **self.config["SCHEDULER"]["SCHEDULER_PARAMS"])

        self.scheduler_warmup = GradualWarmupSchedulerV2(
            self.optimizer,
            **self.config["SCHEDULER"]["CUSTOM_PARAMS"],
            after_scheduler=self.scheduler,
        )  # total_epoch = warmup epochs
        self.val_predictions = None
        self.date = datetime.datetime.now(pytz.timezone("Asia/Dhaka")).strftime(
            "%Y-%m-%d"
        )

        self.log(
            "Trainer prepared. We are using {} device.".format(self.config["DEVICE"])
        )

    def fit(self, train_loader, val_loader, fold: int):
        """Fit the model on the given fold."""
        self.log(
            "Training on Fold {} and using {}".format(
                fold, self.config["MODEL"]["MODEL_NAME"]
            )
        )

        for _epoch in range(self.num_epochs):
            # Get the learning rate at the start of each epoch.
            current_lr = self.optimizer.param_groups[0]["lr"]

            timestamp = datetime.datetime.now(pytz.timezone("Asia/Dhaka")).strftime(
                "%Y-%m-%d %H-%M-%S"
            )
            # print the lr and the timestamp for each epoch.
            self.log("\n{}\nLR: {}".format(timestamp, current_lr))

            # start time of training on the training set
            train_start_time = time.time()
            '''
            if _epoch < 6:
                print('light aug....')
                self.config['DATA']['USE_MIXUP'] = False
                self.config['DATA']['USE_CUTMIX'] = False
            elif _epoch > 5 and _epoch < 20:
                print('mixup without hesitation....')
                self.config['DATA']['USE_MIXUP'] = True
                self.config['DATA']['USE_CUTMIX'] = False
            else:
                print('cutmix....')
                self.config['DATA']['USE_MIXUP'] = False
                self.config['DATA']['USE_CUTMIX'] = True
            '''
            # train one epoch on the training set
            avg_train_loss = self.train_one_epoch(train_loader)
            # end time of training on the training set
            train_end_time = time.time()

            # format the elapsed time to make it nicer
            train_elapsed_time = time.strftime(
                "%H:%M:%S", time.gmtime(train_end_time - train_start_time)
            )
            self.log(
                "[RESULT]: Train. Epoch {} | Avg Train Summary Loss: {:.3f} | "
                "Time Elapsed: {}".format(
                    self.epoch + 1,
                    avg_train_loss,
                    train_elapsed_time,
                )
            )

            val_start_time = time.time()

            (
                avg_val_loss,
                avg_val_roc,
                val_predictions,
            ) = self.valid_one_epoch(val_loader)
            # here we get the oof preds
            self.val_predictions = val_predictions
            val_end_time = time.time()
            val_elapsed_time = time.strftime(
                "%H:%M:%S", time.gmtime(val_end_time - val_start_time)
            )
            # self.neptune["Metrics/AUC"].log(avg_val_roc)
            self.log(
                "[RESULT]: Validation. Epoch: {} | "
                "Avg Validation Summary Loss: {:.3f} | "
                "Validation ROC: {:.3f} | Time Elapsed: {}".format(
                    self.epoch + 1,
                    avg_val_loss,
                    avg_val_roc,
                    val_elapsed_time,
                )
            )

            # added this flag right before early stopping so the user
            # knows which metric is being monitored.
            self.monitored_metrics = avg_val_roc

            if self.early_stopping is not None:
                best_score, early_stop = self.early_stopping.should_stop(
                    curr_epoch_score=self.monitored_metrics
                )
                self.best_loss = best_score
                self.save(
                    "{}_best_loss_fold_{}.pt".format(
                        self.config["MODEL"]["MODEL_NAME"], fold
                    )
                )
                if early_stop:
                    break
            else:
                if avg_val_loss < self.best_loss:
                    self.best_loss = avg_val_loss

            if self.best_auc < avg_val_roc:
                self.best_auc = avg_val_roc
                self.save(
                    os.path.join(
                        self.save_path,
                        "{}_{}_best_auc_fold_{}.pt".format(
                            self.date, self.config["MODEL"]["MODEL_NAME"], fold
                        ),
                    )
                )
                self.patience = 15
            else:
                self.patience -= 1
                if self.patience == 0:
                    print("Early Stopping")
                    break

            # CosineAnnealingWarmRestarts (wrapped by GradualWarmupSchedulerV2)
            self.scheduler_warmup.step()
            if _epoch == 2:
                self.scheduler_warmup.step()  # bug workaround

            if self.config["SCHEDULER"]["VAL_STEP"]:
                if isinstance(
                    self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
                ):
                    self.scheduler.step(self.monitored_metrics)
                else:
                    self.scheduler.step()

            # end of epoch: increment so that self.epoch stays up to date.
            self.epoch += 1

        curr_fold_best_checkpoint = self.load(
            os.path.join(
                self.save_path,
                "{}_{}_best_auc_fold_{}.pt".format(
                    self.date, self.config["MODEL"]["MODEL_NAME"], fold
                ),
            )
        )
        return curr_fold_best_checkpoint

    def train_one_epoch(self, train_loader):
        """Train one epoch of the model."""
        # set to train mode
        # self.model.avg_pool = GeM()
        self.model.train()

        # log metrics
        train_summary_loss = AverageLossMeter()
        # TODO: use Alex's ROC METER?

        # timer
        start_time = time.time()
        train_bar = train_loader
        # loop through the train loader for one epoch; step is the batch index
        for step, (images, labels) in enumerate(train_bar):
            if self.config['DATA']['USE_MIXUP']:
                images, labels = (
                    images.float(),
                    labels,
                )
                images, targets_a, targets_b, lam = mixup_data(images, labels.view(-1, 1), use_cuda=True)
                images, targets_a, targets_b = images.to(self.device), targets_a.to(self.device), targets_b.to(self.device)
            else:
                images, labels = (
                    images.to(self.device).float(),
                    labels.to(self.device),
                )

            batch_size = labels.shape[0]

            if (
                self.config["TRAIN"]["SETTINGS"]["USE_AMP"] is True
                and self.config["TRAIN"]["SETTINGS"]["USE_GRAD_ACCUM"] is False
            ):
                # I would think clearing gradients here is the correct way, as opposed to calling it last.
                self.optimizer.zero_grad()
                with torch.cuda.amp.autocast():
                    logits = self.model(images)
                    if self.config["DATA"]["USE_MIXUP"]:
                        train_loss = mixup_criterion(self.criterion_train, logits, targets_a, targets_b, lam)
                        # train_loss = mixup_criterion(self.criterion_train, targets_a, targets_b, lam)
                    else:
                        train_loss = self.criterion_train(input=logits.view(-1), target=labels)  # use view here for BCEWithLogitsLoss

                loss_value = train_loss.item()
                self.scaler.scale(train_loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

            elif (
                self.config["TRAIN"]["SETTINGS"]["USE_AMP"] is True
                and self.config["TRAIN"]["SETTINGS"]["USE_GRAD_ACCUM"] is True
            ):
                with torch.cuda.amp.autocast():
                    logits = self.model(images)
                    train_loss = self.criterion_train(input=logits, target=labels)
                    train_loss = (
                        train_loss
                        / self.config["TRAIN"]["SETTINGS"]["ACCUMULATION_STEP"]
                    )
                loss_value = train_loss.item()
                self.scaler.scale(train_loss).backward()
                if (step + 1) % self.config["TRAIN"]["SETTINGS"][
                    "ACCUMULATION_STEP"
                ] == 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()
            else:
                logits = self.model(images)
                train_loss = self.criterion_train(input=logits, target=labels)
                loss_value = train_loss.item()
                self.optimizer.zero_grad()
                train_loss.backward()
                self.optimizer.step()
            train_summary_loss.update(train_loss.item(), batch_size)
            # from here on, the forward pass and backprop for this batch are complete.

            # the softmax/argmax bookkeeping below is not meaningful for a single-logit
            # head and is never used downstream, so it is left commented out
            # y_true = labels.cpu().numpy()
            # softmax_preds = torch.nn.Softmax(dim=1)(input=logits).cpu().detach().numpy()
            # y_preds = np.argmax(a=softmax_preds, axis=1)

            # measure elapsed time
            end_time = time.time()
            # train_bar.set_description(f"loss: {train_summary_loss.avg:.3f}")

            if self.config["TRAIN"]["SETTINGS"]["VERBOSE"]:
                if (step % self.config["TRAIN"]["SETTINGS"]["VERBOSE_STEP"]) == 0:
                    print(
                        f"Train Steps {step}/{len(train_loader)}, "
                        f"summary_loss: {train_summary_loss.avg:.3f}, "
                        f"time: {(end_time - start_time):.3f}",
                        end="\r",
                    )

        return train_summary_loss.avg

    # @torch.no_grad
    def valid_one_epoch(self, val_loader):
        """Validate one training epoch."""
        # set to eval mode
        self.model.eval()
        # log metrics
        valid_summary_loss = AverageLossMeter()

        # timer
        start_time = time.time()

        LOGITS = []
        Y_TRUE = []
        Y_PROBS = []

        with torch.no_grad():
            for step, (images, labels) in enumerate(val_loader):
                images, labels = (
                    images.to(self.device).float(),
                    labels.to(self.device),
                )

                batch_size = labels.shape[0]
                logits = self.model(images)
                val_loss = self.criterion_val(input=logits.view(-1), target=labels)  # use view here for BCEWithLogitsLoss
                loss_value = val_loss.item()
                valid_summary_loss.update(loss_value, batch_size)
                sigmoid_preds = torch.sigmoid(logits)

                LOGITS.append(logits.detach().cpu())
                Y_TRUE.append(labels.detach().cpu())
                Y_PROBS.append(sigmoid_preds.detach().cpu())

                end_time = time.time()

                if self.config["TRAIN"]["SETTINGS"]["VERBOSE"]:
                    if (step % self.config["TRAIN"]["SETTINGS"]["VERBOSE_STEP"]) == 0:
                        print(
                            f"Validation Steps {step}/{len(val_loader)}, "
                            + f"summary_loss: {valid_summary_loss.avg:.3f}, "
                            + f"time: {(end_time - start_time):.3f}",
                            end="\r",
                        )

            LOGITS = torch.cat(LOGITS).numpy()
            Y_TRUE = torch.cat(Y_TRUE).numpy()
            Y_PROBS = torch.cat(Y_PROBS).numpy()

            if self.config["DATA"]["NUM_CLASSES"] > 2:
                val_roc_auc_score = sklearn.metrics.roc_auc_score(
                    y_true=Y_TRUE, y_score=Y_PROBS, multi_class="ovr"
                )
            else:
                val_roc_auc_score = sklearn.metrics.roc_auc_score(
                    y_true=Y_TRUE, y_score=Y_PROBS
                )

        return (valid_summary_loss.avg, val_roc_auc_score, Y_PROBS)

    def save_model(self, path):
        """Save the trained model."""
        self.model.eval()
        torch.save(self.model.state_dict(), path)

    # saves the weights for the best validation metric and the corresponding oof preds
    def save(self, path):
        """Save the weights for the best evaluation score."""
        self.model.eval()
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "scheduler_state_dict": self.scheduler.state_dict(),
                "best_auc": self.best_auc,
                "best_loss": self.best_loss,
                "epoch": self.epoch,
                "oof_preds": self.val_predictions,
            },
            path,
        )

    def load(self, path):
        """Load a model checkpoint from the given path."""
        checkpoint = torch.load(path)
        return checkpoint

    def log(self, message):
        """Log a message."""
        if self.config["TRAIN"]["SETTINGS"]["VERBOSE"]:
            print(message)
        with open(self.config["PATH"]["LOG_PATH"], "a+") as logger:
            logger.write(f"{message}\n")
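

# train_loop below calls get_oof_roc, which is missing from this paste. A minimal
# stand-in under the assumption that the OOF prediction column is the stringified class
# index "0" (as written by train_on_fold) and the label column is DATA.TARGET_COL_NAME;
# it simply scores the out-of-fold probabilities with ROC AUC.
def get_oof_roc(config, oof_df: pd.DataFrame):
    """Return (detail dict, ROC AUC) for an out-of-fold prediction frame."""
    y_true = oof_df[config["DATA"]["TARGET_COL_NAME"]].values
    y_prob = oof_df["0"].values
    score = sklearn.metrics.roc_auc_score(y_true=y_true, y_score=y_prob)
    return {"roc_auc": score}, score
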

def train_on_fold(model, df_folds: pd.DataFrame, config, fold: int, neptune=None):
    """Train the model on the given fold."""

    model.to(config["DEVICE"])

    try:
        model_summary = torchsummary_wrapper(
            model, (1, config["DATA"]["IMAGE_SIZE"], config["DATA"]["IMAGE_SIZE"])
        )
    except RuntimeError:
        model_summary = None  # keep going; the summary is informational only
        print("Check the channel number.")

    print("Model Summary: \n{}".format(model_summary))

    if config["TRAIN"]["SETTINGS"]["DEBUG"]:
        # args.n_epochs = 5
        df_train = df_folds[df_folds["fold"] != fold].sample(
            config["TRAIN"]["DATALOADER"]["batch_size"] * 128
        )
        df_valid = df_folds[df_folds["fold"] == fold].sample(
            config["TRAIN"]["DATALOADER"]["batch_size"] * 128
        )
    else:
        df_train = df_folds[df_folds["fold"] != fold].reset_index(drop=True)
        df_valid = df_folds[df_folds["fold"] == fold].reset_index(drop=True)

    dataset_train = AlienTrainDataset(
        config=config,
        df=df_train,
        mode="train",
        transform=Transform(config["TRAIN_TRANSFORMS"]),
    )
    dataset_valid = AlienTrainDataset(
        config=config,
        df=df_valid,
        mode="valid",
        transform=Transform(config["VALID_TRANSFORMS"]),
    )

    train_loader = torch.utils.data.DataLoader(
        dataset_train,
        # sampler=RandomSampler(dataset_train),
        **config["TRAIN"]["DATALOADER"],
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset_valid, **config["VALIDATION"]["DATALOADER"]
    )

    hongnan_classifier = Trainer(model=model, config=config, neptune=neptune)

    curr_fold_best_checkpoint = hongnan_classifier.fit(train_loader, valid_loader, fold)
    # print(len(curr_fold_best_checkpoint["oof_preds"]))
    df_valid[
        [str(c) for c in range(config["DATA"]["NUM_CLASSES"])]
    ] = curr_fold_best_checkpoint["oof_preds"]
    # val_df["preds"] = curr_fold_best_checkpoint["oof_preds"].argmax(1)

    return df_valid


def train_loop(
    model,
    df_folds: pd.DataFrame,
    config,
    fold_num: int = None,
    train_one_fold=False,
    neptune=None,
):
    """Perform the training loop on all folds.

    The CV score is the average of the per-fold validation metrics, while the OOF score
    is computed on the aggregation of all validation folds.
    """
    cv_score_list = []
    oof_df = pd.DataFrame()
    if train_one_fold:
        _oof_df = train_on_fold(
            model, df_folds=df_folds, config=config, fold=fold_num, neptune=neptune
        )
        _oof_df.to_csv(os.path.join(config["PATH"]["OOF_PATH"], "_oof.csv"))
        # curr_fold_best_score = get_oof_roc(config, _oof_df)
        # print("Fold {} OOF Score is {}".format(fold_num, curr_fold_best_score))
    else:
        # The generator below guarantees that fold numbering starts from 1 and not 0.
        # https://stackoverflow.com/questions/33282444/pythonic-way-to-iterate-through-a-range-starting-at-1
        for fold in (
            number + 1 for number in range(config["CROSS_VALIDATION"]["NUM_FOLDS"])
        ):
            _oof_df = train_on_fold(
                model, df_folds=df_folds, config=config, fold=fold, neptune=neptune
            )
            oof_df = pd.concat([oof_df, _oof_df])
            curr_fold_best_score_dict, curr_fold_best_score = get_oof_roc(
                config, _oof_df
            )
            cv_score_list.append(curr_fold_best_score)
            print(
                "\n\n\nOOF Score for Fold {}: {}\n\n\n".format(
                    fold, curr_fold_best_score
                )
            )

        print("CV score", np.mean(cv_score_list))
        print("Variance", np.var(cv_score_list))
        print("Five Folds OOF", get_oof_roc(config, oof_df))
        oof_df.to_csv(os.path.join(config["PATH"]["OOF_PATH"], "oof.csv"))


model_pretrained = AlienSingleHead(config=config, pretrained=True)
# Folds are numbered 1..NUM_FOLDS by make_folds, so validate on fold 1;
# fold_num=0 would leave the validation split empty.
train_loop(
    model_pretrained, df_folds, config, fold_num=1, train_one_fold=True, neptune=None
)