
training

import os
import time

import numpy as np
import torch
from tensorboardX import SummaryWriter

# This function relies on module-level names defined elsewhere in the script:
# resume, the saved_* checkpoint fields, last_data_path, device, init_lr,
# steps, start_time, ablation, config, lr_schedule_cosdecay, and test.
def train(model, train_loader, test_loader, optimizer, loss_fn):
  losses = []
  start_step = 0
  max_ssim = 0
  max_psnr = 0
  ssims = []
  psnrs = []
  model_name = config['model_name']  # hoisted so the final np.save calls never hit an undefined name

  # Create output directories up front (the original assumes they exist).
  os.makedirs(f'./data/trained_models_{ablation}', exist_ok=True)
  os.makedirs(f'./data/{ablation}_numpy_files', exist_ok=True)

  if resume:
    # Restore training state from the last checkpoint.
    losses = saved_losses
    start_step = saved_step
    max_ssim = saved_max_ssim
    max_psnr = saved_max_psnr
    ssims = saved_ssims
    psnrs = saved_psnrs
    optimizer.load_state_dict(saved_optimizer)
    model.load_state_dict(saved_model)
    print(f'model loaded from {last_data_path}')

  # Open one SummaryWriter for the whole run instead of reopening it
  # every step, as the original did.
  writer = SummaryWriter(logdir=f'./data/{ablation}-logs')

  for step in range(start_step + 1, steps + 1):
    model.train()
    # Cosine-decay the learning rate at every step.
    lr = lr_schedule_cosdecay(step, steps, init_lr)
    for param_group in optimizer.param_groups:
      param_group['lr'] = lr

    # Draw one batch per step; re-creating the loader iterator each time
    # yields a fresh shuffled batch, at the cost of some overhead.
    x, y = next(iter(train_loader))
    x = x.to(device)
    y = y.to(device)

    out = model(x)

    # L1 reconstruction loss; kept under its own name so it can be logged.
    l1_loss = loss_fn[0](out, y)
    loss = l1_loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    losses.append(loss.item())

    print(f'\rloss: {loss.item():.5f}, L1_loss: {l1_loss.item():.5f} | step: {step}/{steps} | lr: {lr:.7f} | time_used: {(time.time() - start_time) / 60:.1f} min', end='', flush=True)

    writer.add_scalar('runs-loss' + ablation, loss.item(), step)
    writer.add_scalar('runs-loss_l1' + ablation, l1_loss.item(), step)

    if step % config['eval_step'] == 0:
      epoch = step // config['eval_step']

      save_model_dir = f'./data/trained_models_{ablation}/{epoch}.ok'
      best_model_dir = f'./data/trained_models_{ablation}/trained_model.best'

      with torch.no_grad():
        ssim_eval, psnr_eval = test(model, test_loader)
      log = f'\nstep: {step} | epoch: {epoch} | ssim: {ssim_eval:.4f} | psnr: {psnr_eval:.4f}'
      print(log)
      with open(f'./data/{ablation}-logs/{ablation}_{model_name}.txt', 'a') as f:
        f.write(log + '\n')

      ssims.append(ssim_eval)
      psnrs.append(psnr_eval)

      # Save a "best" checkpoint whenever PSNR improves.
      if psnr_eval > max_psnr:
        max_ssim = max(max_ssim, ssim_eval)
        max_psnr = max(max_psnr, psnr_eval)
        print(f'\nmodel saved at step: {step} | epoch: {epoch} | max_psnr: {max_psnr:.4f} | max_ssim: {max_ssim:.4f}')
        torch.save({
          'epoch': epoch,
          'step': step,
          'max_psnr': max_psnr,
          'max_ssim': max_ssim,
          'ssims': ssims,
          'psnrs': psnrs,
          'losses': losses,
          'model': model.state_dict(),
          'optimizer': optimizer.state_dict(),
        }, best_model_dir)

      # Also save a checkpoint at every eval step, regardless of improvement.
      torch.save({
        'epoch': epoch,
        'step': step,
        'max_psnr': max_psnr,
        'max_ssim': max_ssim,
        'ssims': ssims,
        'psnrs': psnrs,
        'losses': losses,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
      }, save_model_dir)

  # Persist the full training curves once the loop finishes.
  np.save(f'./data/{ablation}_numpy_files/{ablation + model_name}_{steps}_losses.npy', losses)
  np.save(f'./data/{ablation}_numpy_files/{ablation + model_name}_{steps}_ssims.npy', ssims)
  np.save(f'./data/{ablation}_numpy_files/{ablation + model_name}_{steps}_psnrs.npy', psnrs)
  writer.close()
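
The loop calls lr_schedule_cosdecay without defining it. A minimal sketch, assuming the usual half-cosine decay from init_lr at step 0 down to 0 at the final step; if the original script used a warmup phase or a learning-rate floor, this omits it:

import math

def lr_schedule_cosdecay(t, T, init_lr):
  # Half-cosine decay: init_lr at t = 0, 0 at t = T.
  return 0.5 * (1 + math.cos(t * math.pi / T)) * init_lr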
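
test is also undefined in the paste. A sketch of a plausible implementation, assuming outputs and targets are images in [0, 1], the pytorch_msssim package for SSIM, and a plain average over the test set; all of those choices are assumptions, not the author's code:

import numpy as np
import torch
from pytorch_msssim import ssim

def test(model, test_loader):
  model.eval()
  ssims, psnrs = [], []
  for x, y in test_loader:
    x, y = x.to(device), y.to(device)
    out = model(x).clamp(0, 1)                          # assume images live in [0, 1]
    mse = torch.mean((out - y) ** 2)
    psnrs.append((10 * torch.log10(1.0 / mse)).item())  # PSNR with peak value 1.0
    ssims.append(ssim(out, y, data_range=1.0).item())
  return np.mean(ssims), np.mean(psnrs)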
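
Finally, a hypothetical wiring of the module-level names train reads, using a stand-in one-layer model and random tensors just to exercise the loop end to end; none of these values come from the paste:

import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'
steps = 100                                   # placeholder step budget
init_lr = 1e-4
resume = False                                # the saved_* globals are only read when True
ablation = 'baseline'
config = {'eval_step': 50, 'model_name': 'demo'}
start_time = time.time()

model = nn.Conv2d(3, 3, 3, padding=1).to(device)   # stand-in "model"
data = TensorDataset(torch.rand(8, 3, 64, 64), torch.rand(8, 3, 64, 64))
train_loader = DataLoader(data, batch_size=2, shuffle=True)
test_loader = DataLoader(data, batch_size=2)
optimizer = torch.optim.Adam(model.parameters(), lr=init_lr)
loss_fn = [nn.L1Loss()]                       # train() indexes loss_fn[0]

train(model, train_loader, test_loader, optimizer, loss_fn)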