Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- ## Standard libraries
- import os
- import numpy as np
- from tqdm.notebook import tqdm
- from IPython.display import clear_output
- import shutil
- ## Imports for plotting
- import matplotlib.pyplot as plt
- from matplotlib import cm
- ## PyTorch
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import torch.autograd as autograd
- import torch.utils.data as data
- from torch.utils.tensorboard import SummaryWriter
- # Torchvision
- import torchvision
- from torchvision.utils import make_grid
- # Custom modules
- from ebm.config import *
- from ebm.models import CNNModel
class DeepEnergyModel:
    """Energy-based model trained with contrastive divergence and Langevin MCMC.

    Args:
        img_shape (tuple): Shape of a single image, e.g. ``(C, H, W)``.
        batch_size (int): Training batch size; also the default number of MCMC
            samples produced per call to :meth:`generate_samples`.
        alpha (float): Weight of the energy-magnitude regularization loss.
            Set to 0 to disable the regularizer entirely.
        lr (float): Adam learning rate.
        weight_decay (float): Adam L2 weight decay.
        mcmc_step_size (float): Step size of the Langevin update.
        mcmc_steps (int): Default number of Langevin steps per sampling run.
        model_name (str): Any name to visually recognize the model, like the #run.
        model_description (str): Will be logged by tensorboard as "text".
        model_family (str): When running multiple experiments, it may be useful to
            divide the models and their logged results in families (subdirs of the
            checkpoint path). This param can have the form of a path/to/subfolder.
        device (str): Torch device spec; silently falls back to CPU when CUDA is
            unavailable.
        overwrite (bool): If the logs folder already exists and this is True,
            "overwrite" it (namely, add the new logs without removing the old ones).
        **CNN_args: Forwarded verbatim to ``CNNModel``.
    """

    def __init__(self,
                 img_shape,
                 batch_size,
                 alpha=1,
                 lr=1e-4,
                 weight_decay=1e-4,
                 mcmc_step_size=1e-5,
                 mcmc_steps=250,
                 model_name="unnamed",
                 model_description="",
                 model_family="Langevin_vanilla",
                 device="cuda:1",
                 overwrite=False,
                 **CNN_args):
        super().__init__()
        # Model
        self.img_shape = img_shape
        self.device = torch.device(device) if torch.cuda.is_available() else torch.device("cpu")
        print("Running on device:", self.device)
        # Use CNNModel by default
        self.cnn = CNNModel(**CNN_args).to(self.device)
        # Optimizer hyper-parameters. The optimizer itself is created here so that
        # training_step() cannot hit an AttributeError; a setup() defined elsewhere
        # (presumably — fit() asserts self.is_setup, which nothing here sets) may
        # rebuild/overwrite it.
        self.lr = lr
        self.weight_decay = weight_decay
        self.optimizer = optim.Adam(self.cnn.parameters(), lr=lr, weight_decay=weight_decay)
        # Regularization loss weight
        self.alpha = alpha
        # Dataset
        self.batch_size = batch_size
        # MCMC
        self.mcmc_step_size = mcmc_step_size
        self.mcmc_steps = mcmc_steps
        # Bookkeeping / logging metadata (previously accepted but silently dropped).
        self.model_name = model_name
        self.model_description = model_description
        self.model_family = model_family
        self.overwrite = overwrite
        # Hyper-parameter registry; subclasses (e.g. EBMLang2Ord) append to it.
        self.hparams_dict = {
            "alpha": alpha,
            "lr": lr,
            "weight_decay": weight_decay,
            "mcmc_step_size": mcmc_step_size,
            "mcmc_steps": mcmc_steps,
            "batch_size": batch_size,
        }

    ######################################################
    ################ Training section ####################
    ######################################################
    def training_step(self, batch):
        """Run one contrastive-divergence update on a single (imgs, labels) batch."""
        # Train mode
        self.cnn.train()
        real_imgs, _ = batch
        real_imgs = real_imgs.to(self.device)
        # Obtain negative samples from the model's own MCMC sampler.
        fake_imgs = self.generate_samples()
        # Predict energy score for real and fake images in a single forward pass.
        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)
        real_out, fake_out = self.cnn(inp_imgs).chunk(2, dim=0)
        # Contrastive divergence: push real/fake outputs apart (sign convention
        # matches the sampler below, which moves samples along -grad).
        cdiv_loss = real_out.mean() - fake_out.mean()
        if self.alpha > 0:
            # Keep output magnitudes bounded so energies don't diverge.
            reg_loss = self.alpha * (real_out**2 + fake_out**2).mean()
            loss = reg_loss + cdiv_loss
        else:
            reg_loss = torch.tensor(0)
            loss = cdiv_loss
        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def fit(self, n_epochs=None):
        """Train for ``n_epochs`` epochs over ``self.train_loader``.

        Raises:
            ValueError: if ``n_epochs`` is not given (the old code crashed with
                an opaque ``TypeError`` from ``range(None)``).
        """
        assert self.is_setup, "Model is not properly setup. Call .setup() before running!"
        if n_epochs is None:
            raise ValueError("fit() requires an explicit number of epochs (n_epochs).")
        if self.train_loader is None:
            print("Train data not loaded")
            return
        # Epochs
        self.tot_batches = len(self.train_loader)
        for self.epoch_n in range(n_epochs):
            clear_output()
            print("Epoch #" + str(self.epoch_n + 1))
            # Iterations
            self.log_active = True
            for self.iter_n, batch in tqdm(enumerate(self.train_loader),
                                           total=self.tot_batches,
                                           position=0,
                                           leave=True):
                self.training_step(batch)

    ######################################################
    ############ 1st order Langevin dynamics #############
    ######################################################
    def generate_samples(self,
                         evaluation=False,
                         batch_size=None,
                         mcmc_steps=None):
        """Draw samples via first-order (overdamped) Langevin dynamics.

        Args:
            evaluation (bool): Kept for interface compatibility (unused here).
            batch_size (int | None): Number of samples; defaults to self.batch_size.
            mcmc_steps (int | None): Number of Langevin steps; defaults to
                self.mcmc_steps. (The old code crashed when these were left None.)

        Returns:
            torch.Tensor: Detached samples of shape ``(batch_size, *img_shape)``.
        """
        batch_size = self.batch_size if batch_size is None else batch_size
        mcmc_steps = self.mcmc_steps if mcmc_steps is None else mcmc_steps
        is_training = self.cnn.training
        self.cnn.eval()
        # Init images with random normal noise: x_i ~ N(0, 1)
        x = torch.randn((batch_size,) + self.img_shape, device=self.device)
        x.requires_grad = True
        noise_scale = np.sqrt(self.mcmc_step_size * 2)
        # Pre-allocate additive noise (re-filled in place every Langevin step).
        noise = torch.randn_like(x)
        for _ in range(mcmc_steps):
            # Re-init noise tensor
            noise.normal_(mean=0.0, std=noise_scale)
            out = self.cnn(x)
            grad = autograd.grad(out.sum(), x, only_inputs=True)[0]
            # Detach each step: the old code chained every iteration into one
            # autograd graph, growing memory linearly with mcmc_steps.
            x = (x - (self.mcmc_step_size * grad + noise)).detach()
            x.requires_grad = True
        self.cnn.train(is_training)
        return x.detach()
- ###########################################################################################################################
class EBMLangVanilla(DeepEnergyModel):
    """Vanilla (first-order, overdamped) Langevin dynamics EBM.

    Thin alias over DeepEnergyModel: inherits its sampler unchanged. The original
    docstring opened with four quote characters, leaving a stray '"' in its text.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
- ###########################################################################################################################
class EBMLang2Ord(DeepEnergyModel):
    """Second-order (underdamped) Langevin dynamics, with leapfrog-style updates.

    Args:
        C (float): Friction coefficient of the underdamped dynamics; also scales
            the injected noise via sqrt(2 * step_size * C).
        mass (float): Mass term scaling both momentum damping and the position update.
        **kwargs: Forwarded to DeepEnergyModel.
    """

    def __init__(self, C=2, mass=1, **kwargs):
        super().__init__(**kwargs)
        # The parent class (as visible in this file) never creates hparams_dict,
        # so the original assignments below raised AttributeError — create it
        # defensively without clobbering one a richer parent may have set.
        if not hasattr(self, "hparams_dict"):
            self.hparams_dict = {}
        self.C = C
        self.hparams_dict["C"] = C
        self.mass = mass
        self.hparams_dict["mass"] = mass

    def generate_samples(self,
                         evaluation=False,
                         batch_size=None,
                         mcmc_steps=None):
        """Draw samples via underdamped Langevin dynamics with a momentum variable.

        Args:
            evaluation (bool): Kept for interface compatibility (unused here).
            batch_size (int | None): Number of samples; defaults to self.batch_size.
            mcmc_steps (int | None): Number of steps; defaults to self.mcmc_steps.
                (The old code crashed when these were left None.)

        Returns:
            torch.Tensor: Detached samples of shape ``(batch_size, *img_shape)``.
        """
        batch_size = self.batch_size if batch_size is None else batch_size
        mcmc_steps = self.mcmc_steps if mcmc_steps is None else mcmc_steps
        is_training = self.cnn.training
        self.cnn.eval()
        # Init images with random normal noise: x_i ~ N(0, 1)
        x = torch.randn((batch_size,) + self.img_shape, device=self.device)
        x.requires_grad = True
        # Momentum starts at rest (a randn init was tried and commented out upstream).
        momentum = torch.zeros_like(x)
        noise_scale = np.sqrt(self.mcmc_step_size * 2 * self.C)
        # Pre-allocate additive noise (re-filled in place every step).
        noise = torch.randn_like(x)
        for _ in range(mcmc_steps):
            # Re-init noise tensor
            noise.normal_(mean=0.0, std=noise_scale)
            out = self.cnn(x)
            grad = autograd.grad(out.sum(), x, only_inputs=True)[0]
            # Friction-damped momentum update plus energy gradient and noise.
            momentum = (momentum
                        - self.mass * momentum * self.mcmc_step_size * self.C
                        - self.mcmc_step_size * grad
                        + noise)
            # Detach each step: the old code chained all iterations into one
            # autograd graph, growing memory linearly with mcmc_steps.
            x = (x + self.mcmc_step_size * self.mass * momentum).detach()
            momentum = momentum.detach()
            x.requires_grad = True
        self.cnn.train(is_training)
        return x.detach()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement