Advertisement
Bunich

EBM - Langevin dynamics (2nd order)

Apr 19th, 2021
876
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.45 KB | None | 0 0
  1. ## Standard libraries
  2. import os
  3. import numpy as np
  4. from tqdm.notebook import tqdm
  5. from IPython.display import clear_output
  6. import shutil
  7.  
  8. ## Imports for plotting
  9. import matplotlib.pyplot as plt
  10. from matplotlib import cm
  11.  
  12. ## PyTorch
  13. import torch
  14. import torch.nn as nn
  15. import torch.optim as optim
  16. import torch.autograd as autograd
  17. import torch.utils.data as data
  18. from torch.utils.tensorboard import SummaryWriter
  19.  
  20. # Torchvision
  21. import torchvision
  22. from torchvision.utils import make_grid
  23.  
  24. # Custom modules
  25. from ebm.config import *
  26. from ebm.models import CNNModel
  27.  
  28.  
  29. class DeepEnergyModel:
  30.     """
  31.    model_name (str) - Any name to visually recognize the model, like the #run.
  32.    model_description (str) - Will be logged by tensorboard as "text"
  33.    model_family (str) - When running multiple experiments, it may be useful to divide
  34.        the models and their logged results in families (subdirs of checkpoint path).
  35.        This param can have the form of a path/to/subfolder.
  36.    overwrite (bool) - If the logs folder already exists, if True "overwrite" it (namely,
  37.        add also the new logs, without removing the onld ones).
  38.    """
  39.     def __init__(self,
  40.                  img_shape,
  41.                  batch_size,
  42.                  alpha=1,
  43.                  lr=1e-4,
  44.                  weight_decay=1e-4,
  45.                  mcmc_step_size=1e-5,
  46.                  mcmc_steps=250,
  47.                  model_name="unnamed",
  48.                  model_description="",
  49.                  model_family="Langevin_vanilla",
  50.                  device="cuda:1",
  51.                  overwrite=False,
  52.                  **CNN_args):
  53.         super().__init__()
  54.  
  55.         # Model
  56.         self.img_shape = img_shape
  57.         self.device = torch.device(device) if torch.cuda.is_available() else torch.device("cpu")
  58.         print("Running on device:", self.device)
  59.         # Use CNNModel by default
  60.         self.cnn = CNNModel(**CNN_args).to(self.device)
  61.  
  62.         # Optimizers
  63.         self.lr = lr
  64.         self.weight_decay = weight_decay
  65.  
  66.         # Reg loss weigth
  67.         self.alpha = alpha
  68.  
  69.         # Dataset
  70.         self.batch_size = batch_size
  71.  
  72.         # MCMC
  73.         self.mcmc_step_size = mcmc_step_size
  74.         self.mcmc_steps = mcmc_steps
  75.  
  76.  
  77.     ######################################################
  78.     ################ Training section ####################
  79.     ######################################################
  80.  
  81.     def training_step(self, batch):
  82.  
  83.         # Train mode
  84.         self.cnn.train()
  85.  
  86.         # We add minimal noise to the original images to prevent the model from focusing on purely "clean" inputs
  87.         real_imgs, _ = batch
  88.         real_imgs = real_imgs.to(self.device)
  89.  
  90.         #small_noise = torch.randn_like(real_imgs) * 0.005
  91.         #real_imgs.add_(small_noise).clamp_(min=-1.0, max=1.0)
  92.  
  93.         # Obtain samples
  94.         fake_imgs = self.generate_samples()
  95.  
  96.         # Predict energy score for all images
  97.         inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)
  98.         real_out, fake_out = self.cnn(inp_imgs).chunk(2, dim=0)
  99.  
  100.         # Calculate losses
  101.         cdiv_loss = real_out.mean() - fake_out.mean()
  102.         if self.alpha > 0:
  103.             reg_loss = self.alpha * (real_out**2 + fake_out**2).mean()
  104.             loss = reg_loss + cdiv_loss
  105.         else:
  106.             reg_loss = torch.tensor(0)
  107.             loss = cdiv_loss
  108.  
  109.         # Optimize
  110.         self.optimizer.zero_grad()
  111.         loss.backward()
  112.         self.optimizer.step()
  113.  
  114.  
  115.     def fit(self, n_epochs=None):
  116.        
  117.         assert self.is_setup, "Model is not properly setup. Call .setup() before running!"
  118.  
  119.         if self.train_loader is None:
  120.             print("Train data not loaded")
  121.             return
  122.  
  123.         # Epochs
  124.         self.tot_batches = len(self.train_loader)
  125.         for self.epoch_n in range(n_epochs):
  126.             clear_output()
  127.             print("Epoch #" + str(self.epoch_n + 1))
  128.  
  129.             # Iterations
  130.             self.log_active = True
  131.             for self.iter_n, batch in tqdm(enumerate(self.train_loader),
  132.                                            total=self.tot_batches,
  133.                                            position=0,
  134.                                            leave=True):
  135.  
  136.                 self.training_step(batch)
  137.  
  138.     ######################################################
  139.     ############ 1st order Langevin dynamics #############
  140.     ######################################################
  141.    
  142.  
  143.     def generate_samples(self,
  144.                          evaluation=False,
  145.                          batch_size=None,
  146.                          mcmc_steps=None):
  147.      
  148.         is_training = self.cnn.training
  149.         self.cnn.eval()
  150.  
  151.         # Init images with RND normal noise: x_i ~ N(0,1)
  152.         x = torch.randn((batch_size, ) + self.img_shape, device=self.device)
  153.         x.requires_grad = True
  154.        
  155.         noise_scale = np.sqrt(self.mcmc_step_size * 2)
  156.        
  157.         # Pre-allocate additive noise (for Langevin step)
  158.         noise = torch.randn_like(x, device=self.device)
  159.  
  160.         for _ in range(mcmc_steps):
  161.             # Re-init noise tensor
  162.             noise.normal_(mean=0.0, std=noise_scale)
  163.             out = self.cnn(x)
  164.             grad = autograd.grad(out.sum(), x, only_inputs=True)[0]
  165.  
  166.             dynamics = self.mcmc_step_size * grad + noise
  167.             x = x - dynamics
  168.  
  169.         self.cnn.train(is_training)
  170.  
  171.         return x.detach()
  172.    
  173.  
  174.  
  175.  
  176. ###########################################################################################################################
  177.  
  178. class EBMLangVanilla(DeepEnergyModel):
  179.     """"Vanilla Langevin Dynamics"""
  180.     def __init__(self, **kwargs):
  181.         super().__init__(**kwargs)
  182.  
  183.  
  184.  
  185. ###########################################################################################################################
  186. class EBMLang2Ord(DeepEnergyModel):
  187.     """Second order Langevin Dynamics, with leapfrog"""
  188.     def __init__(self, C=2, mass=1, **kwargs):
  189.         super().__init__(**kwargs)
  190.         self.C = C
  191.         self.hparams_dict['C'] = C
  192.         self.mass = mass
  193.         self.hparams_dict['mass'] = mass
  194.    
  195.     def generate_samples(self,
  196.                          evaluation=False,
  197.                          batch_size=None,
  198.                          mcmc_steps=None):
  199.      
  200.        
  201.         is_training = self.cnn.training
  202.         self.cnn.eval()
  203.  
  204.         # Init images with RND normal noise: x_i ~ N(0,1)
  205.         x = torch.randn((batch_size, ) + self.img_shape, device=self.device)
  206.         original_x = x.clone().detach()
  207.         x.requires_grad = True
  208.        
  209.         # Init momentum
  210.         #momentum = torch.randn((batch_size, ) + self.img_shape, device=self.device)
  211.         momentum = torch.zeros_like(x, device=self.device)
  212.         noise_scale = np.sqrt(self.mcmc_step_size * 2 * self.C)
  213.        
  214.         # Pre-allocate additive noise (for Langevin step)
  215.         noise = torch.randn_like(x, device=self.device)
  216.  
  217.        
  218.         for _ in range(mcmc_steps):
  219.  
  220.             # Re-init noise tensor
  221.             noise.normal_(mean=0.0, std=noise_scale)
  222.             out = self.cnn(x)
  223.             grad = autograd.grad(out.sum(), x, only_inputs=True)[0]
  224.             momentum = momentum - self.mass * momentum * self.mcmc_step_size * self.C - self.mcmc_step_size * grad + noise
  225.             x = x + self.mcmc_step_size * self.mass * momentum
  226.  
  227.         self.cnn.train(is_training)
  228.  
  229.         return x.detach()
  230.    
  231.    
  232.    
  233.        
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement