Advertisement
Guest User

DCGAN on MNIST

a guest
Feb 26th, 2020
600
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 17.72 KB | None | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torchvision
  4. from torchvision import transforms, datasets
  5. import torch.nn.functional as F
  6. from torch import optim as optim
  7. from torch.utils.tensorboard import SummaryWriter
  8.  
  9. import numpy as np
  10.  
  11. import os
  12. import time
  13.  
  14.  
class Discriminator(torch.nn.Module):
    """Conditional DCGAN discriminator for 64x64 single-channel images.

    The one-hot class label (10 classes) is projected by ``condi`` to a
    64*64 map and concatenated to the image as a second input channel,
    which is why the first conv takes ``in_channels=2``.  Four stride-2
    convolutions reduce 64x64 -> 32 -> 16 -> 8 -> 4; the final 4x4 conv
    plus sigmoid yields one real/fake probability per sample, shape
    (batch, 1, 1, 1).
    """

    def __init__(self, ndf=16, dropout_value=0.5):  # ndf: base number of discriminator feature maps
        super().__init__()
        self.ndf = ndf
        # NOTE(review): attribute name keeps the original misspelling of "dropout".
        self.droupout_value = dropout_value

        # Projects the one-hot label vector (10,) to a 64*64 "label image".
        self.condi = nn.Sequential(
            nn.Linear(in_features=10, out_features=64 * 64)
        )

        # 64x64 (2 ch: image + label map) -> 32x32
        self.hidden0 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
        )
        # 32x32 -> 16x16
        self.hidden1 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        # 16x16 -> 8x8 (BatchNorm deliberately commented out here — TODO confirm why)
        self.hidden2 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        # 8x8 -> 4x4
        self.hidden3 = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ndf * 8),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        # 4x4 -> 1x1 probability map
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, x, y):
        """Classify *x* (image batch) as real/fake, conditioned on labels *y*.

        y is expected to be a one-hot encoding reshapeable to (batch, 10).
        """
        y = self.condi(y.view(-1, 10))  # label -> 64*64 values
        y = y.view(-1, 1, 64, 64)       # reshape into one image-sized channel

        x = torch.cat((x, y), dim=1)    # stack the label map as channel 2

        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x
  65.  
  66.  
class Generator(torch.nn.Module):
    """Conditional DCGAN generator producing 64x64 single-channel images.

    The caller concatenates a noise vector (n_features) with a one-hot
    label (10) along the channel dimension, so the first transposed conv
    takes ``n_features + 10`` input channels of a (batch, C, 1, 1) tensor.
    Four stride-2 transposed convolutions then upsample
    4x4 -> 8 -> 16 -> 32 -> 64, ending in Tanh (output in [-1, 1]).
    """

    def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):  # ngf: base number of generator feature maps
        super().__init__()
        self.ngf = ngf
        self.n_features = n_features
        # NOTE(review): c_channels is stored but the layers hard-code 1 output
        # channel (grayscale) — confirm before relying on this parameter.
        self.c_channels = c_channels
        # NOTE(review): attribute name keeps the original misspelling of "dropout".
        self.droupout_value = dropout_value

        # (n_features + 10) x 1x1 -> ngf*8 x 4x4
        self.hidden0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )

        # 4x4 -> 8x8 (BatchNorm deliberately commented out here — TODO confirm why)
        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            #nn.BatchNorm2d(self.ngf * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        # 8x8 -> 16x16
        self.hidden2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        # 16x16 -> 32x32
        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )

        # 32x32 -> 64x64, one channel, values squashed to [-1, 1]
        self.out = nn.Sequential(
            # "out_channels=1" because gray scale
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=1, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        """Generate images from noise *x* conditioned on one-hot labels *y*.

        x: (batch, n_features, 1, 1) noise; y: (batch, 10, 1, 1) one-hot labels.
        """
        x_cond = torch.cat((x, y), dim=1)  # Combine noise with conditional input (class labels)

        x = self.hidden0(x_cond)           # Noise+label goes into a "ConvTranspose2d" layer
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        x = self.out(x)

        return x
  123.  
  124.  
  125. class Logger:
  126.     def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
  127.         self.out_dir = "data"
  128.         self.model_name = model_name
  129.         self.train_loader = train_loader
  130.         self.model1 = model1
  131.         self.model2 = model2
  132.         self.model_parameter = model_parameter
  133.         self.m1_optimizer = m1_optimizer
  134.         self.m2_optimizer = m2_optimizer
  135.  
  136.         # Exclude Epochs of the model name. This make sense e.g. when we stop a training progress and continue later on.
  137.         self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items())\
  138.             .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")
  139.  
  140.         self.d_error = 0
  141.         self.g_error = 0
  142.  
  143.         self.tb = SummaryWriter(log_dir=str(self.out_dir + "/log/" + self.model_name + "/runs/" + self.experiment_name))
  144.  
  145.         self.path_image = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/images/{self.experiment_name}')
  146.         self.path_model = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/model/{self.experiment_name}')
  147.  
  148.         try:
  149.             os.makedirs(self.path_image)
  150.         except Exception as e:
  151.             print("WARNING: ", str(e))
  152.  
  153.         try:
  154.             os.makedirs(self.path_model)
  155.         except Exception as e:
  156.             print("WARNING: ", str(e))
  157.  
  158.     def log_graph(self, model1_input, model2_input, model1_label, model2_label):
  159.         self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
  160.         self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))
  161.  
  162.     def log(self, num_epoch, d_error, g_error):
  163.         self.d_error = d_error
  164.         self.g_error = g_error
  165.  
  166.         self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
  167.         self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)
  168.  
  169.     def log_image(self, images, epoch, batch_num):
  170.         grid = torchvision.utils.make_grid(images)
  171.         torchvision.utils.save_image(grid, f'{self.path_image}\\Epoch_{epoch}_batch_{batch_num}.png')
  172.  
  173.         self.tb.add_image("Generator Image", grid)
  174.  
  175.     def log_histogramm(self):
  176.         for name, param in self.model2.named_parameters():
  177.             self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
  178.             self.tb.add_histogram(f'gen_{name}.grad', param.grad, self.model_parameter["Epochs"])
  179.  
  180.         for name, param in self.model1.named_parameters():
  181.             self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
  182.             self.tb.add_histogram(f'dis_{name}.grad', param.grad, self.model_parameter["Epochs"])
  183.  
  184.     def log_model(self, num_epoch):
  185.         torch.save({
  186.             "epoch": num_epoch,
  187.             "model_generator_state_dict": self.model1.state_dict(),
  188.             "model_discriminator_state_dict": self.model2.state_dict(),
  189.             "optimizer_generator_state_dict":  self.m1_optimizer.state_dict(),
  190.             "optimizer_discriminator_state_dict":  self.m2_optimizer.state_dict(),
  191.         }, str(self.path_model + f'\\{time.time()}_epoch{num_epoch}.pth'))
  192.  
  193.     def close(self, logger, images, num_epoch,  d_error, g_error):
  194.         logger.log_model(num_epoch)
  195.         logger.log_histogramm()
  196.         logger.log(num_epoch, d_error, g_error)
  197.         self.tb.close()
  198.  
  199.     def display_stats(self, epoch, batch_num, dis_error, gen_error):
  200.         print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
  201.               f'Batch: [{batch_num}/{len(self.train_loader)}] '
  202.               f'Loss_D: {dis_error.data.cpu()}, '
  203.               f'Loss_G: {gen_error.data.cpu()}')
  204.  
  205.  
  206. def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
  207.     compose = transforms.Compose([
  208.         transforms.Resize((64, 64)),
  209.         transforms.CenterCrop((64, 64)),
  210.         transforms.ToTensor(),
  211.         torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
  212.     ])
  213.  
  214.     dataset = datasets.MNIST(
  215.         root=out_dir,
  216.         train=True,
  217.         download=True,
  218.         transform=compose
  219.     )
  220.  
  221.     train_loader = torch.utils.data.DataLoader(dataset,
  222.                                                batch_size=model_parameter["batch_size"],
  223.                                                num_workers=num_workers_loader,
  224.                                                shuffle=model_parameter["shuffle"])
  225.  
  226.     return dataset, train_loader
  227.  
  228.  
def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
    """Run one discriminator update: a real batch, then a detached fake batch.

    Uses the module-level ``discriminator``, ``generator``, ``loss`` and
    ``add_noise_to_image`` (script-style globals).  Targets are flipped and
    smoothed upstream, so ``p_real_target`` is near 0 and ``p_fake_target``
    near 1 (see get_real_data_target / get_fake_data_target).

    Returns the summed fake+real loss tensor.
    """
    p_optimizer.zero_grad()

    # 1.1 Train on real data
    pred_dis_real = discriminator(p_images, p_images_labels)
    error_real = loss(pred_dis_real, p_real_target)

    error_real.backward()

    # 1.2 Train on fake data
    # .detach() stops gradients from flowing into the generator during the D step.
    fake_data = generator(p_noise, p_fake_labels).detach()
    fake_data = add_noise_to_image(fake_data, device)
    pred_dis_fake = discriminator(fake_data, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_fake_target)

    error_fake.backward()

    # Gradients from both backward() calls accumulate before this single step.
    p_optimizer.step()

    return error_fake + error_real
  249.  
  250.  
def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
    """Run one generator update.

    Uses the module-level ``generator``, ``discriminator``, ``loss`` and
    ``add_noise_to_image`` (script-style globals).  The generator is scored
    against the *real* target so that it learns to make the discriminator
    classify fakes as real.

    Returns (fake_images, discriminator predictions on them, generator loss).
    """
    p_optimizer.zero_grad()

    fake_images = generator(p_noise, p_fake_labels)
    fake_images = add_noise_to_image(fake_images, device)
    pred_dis_fake = discriminator(fake_images, p_fake_labels)
    error_fake = loss(pred_dis_fake, p_real_target)  # because
    """
    We use "p_real_target" instead of "p_fake_target" because we want to
    maximize that the discriminator is wrong.
    """

    error_fake.backward()

    p_optimizer.step()

    return fake_images, pred_dis_fake, error_fake
  268.  
  269.  
  270. # TODO change to a Truncated normal distribution
  271. def get_noise(batch_size, n_features=100):
  272.     return torch.FloatTensor(batch_size, n_features, 1, 1).uniform_(-1, 1)
  273.  
  274.  
  275. # We flip label of real and fate data. Better gradient flow I have told
  276. def get_real_data_target(batch_size):
  277.     return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.0, 0.2)
  278.  
  279.  
  280. def get_fake_data_target(batch_size):
  281.     return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.8, 1.1)
  282.  
  283.  
  284. def image_to_vector(images):
  285.     return torch.flatten(images, start_dim=1, end_dim=-1)
  286.  
  287.  
  288. def vector_to_image(images):
  289.     return images.view(images.size(0), 1, 28, 28)
  290.  
  291.  
  292. def get_rand_labels(batch_size):
  293.     return torch.randint(low=0, high=9, size=(batch_size,))
  294.  
  295.  
def load_model(model_load_path):
    """Restore training state from a checkpoint, if a path is given.

    Mutates the module-level ``discriminator``, ``generator``, ``dis_opti``
    and ``gen_opti`` in place (script-style globals), so it must run after
    init_model_optimizer().  Returns the epoch to resume from, or 0 when no
    checkpoint path is set.

    NOTE(review): torch.load without map_location assumes the checkpoint's
    original device is available — confirm when loading GPU checkpoints on CPU.
    """
    if model_load_path:
        checkpoint = torch.load(model_load_path)

        discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
        generator.load_state_dict(checkpoint["model_generator_state_dict"])

        dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
        gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])

        return checkpoint["epoch"]

    else:
        return 0
  310.  
  311.  
  312. def init_model_optimizer(model_parameter, device):
  313.     # Initialize the Models
  314.     discriminator = Discriminator(ndf=model_parameter["ndf"], dropout_value=model_parameter["dropout"]).to(device)
  315.     generator = Generator(ngf=model_parameter["ngf"], dropout_value=model_parameter["dropout"]).to(device)
  316.  
  317.     # train
  318.     dis_opti = optim.Adam(discriminator.parameters(), lr=model_parameter["learning_rate_dis"], betas=(0.5, 0.999))
  319.     gen_opti = optim.Adam(generator.parameters(), lr=model_parameter["learning_rate_gen"], betas=(0.5, 0.999))
  320.  
  321.     return discriminator, generator, dis_opti, gen_opti
  322.  
  323.  
  324. def get_hot_vector_encode(labels, device):
  325.     return torch.eye(10)[labels].view(-1, 10, 1, 1).to(device)
  326.  
  327.  
  328. def add_noise_to_image(images, device, level_of_noise=0.1):
  329.     return images[0].to(device) + (level_of_noise) * torch.randn(images.shape).to(device)
  330.  
  331.  
  332. if __name__ == "__main__":
  333.     # Hyperparameter
  334.     model_parameter = {
  335.         "batch_size": 500,
  336.         "learning_rate_dis": 0.0002,
  337.         "learning_rate_gen": 0.0002,
  338.         "shuffle": False,
  339.         "Epochs": 10,
  340.         "ndf": 64,
  341.         "ngf": 64,
  342.         "dropout": 0.5
  343.     }
  344.  
  345.     # Parameter
  346.     r_frequent = 10        # How many samples we save for replay per batch (batch_size / r_frequent).
  347.     model_name = "CDCGAN"   # The name of you model e.g. "Gan"
  348.     num_workers_loader = 1  # How many workers should load the data
  349.     sample_save_size = 16   # How many numbers your saved imaged should show
  350.     device = "cuda"         # Which device should be used to train the neural network
  351.     model_load_path = ""    # If set load model instead of training from new
  352.     num_epoch_log = 1       # How frequent you want to log/
  353.     torch.manual_seed(43)   # Sets a seed for torch for reproducibility
  354.  
  355.     dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)  # Get dataset
  356.  
  357.     # Initialize the Models and optimizer
  358.     discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)  # Init model/Optimizer
  359.  
  360.     start_epoch = load_model(model_load_path)  # when we want to load a model
  361.  
  362.     # Init Logger
  363.     logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)
  364.  
  365.     loss = nn.BCELoss()
  366.  
  367.     images, labels = next(iter(train_loader))  # For logging
  368.  
  369.     # For testing
  370.     # pred = generator(get_noise(model_parameter["batch_size"]).to(device), get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device))
  371.     # dis = discriminator(images.to(device), get_hot_vector_encode(labels, device))
  372.  
  373.     logger.log_graph(get_noise(model_parameter["batch_size"]).to(device), images.to(device),
  374.                      get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device),
  375.                      get_hot_vector_encode(labels, device))
  376.  
  377.  
  378.     # Array to store
  379.     exp_replay = torch.tensor([]).to(device)
  380.  
  381.     for num_epoch in range(start_epoch, model_parameter["Epochs"]):
  382.         for batch_num, data_loader in enumerate(train_loader):
  383.             images, labels = data_loader
  384.             images = add_noise_to_image(images, device)  # Add noise to the images
  385.  
  386.             # 1. Train Discriminator
  387.             dis_error = train_discriminator(
  388.                                             dis_opti,
  389.                                             get_noise(model_parameter["batch_size"]).to(device),
  390.                                             images.to(device),
  391.                                             get_fake_data_target(model_parameter["batch_size"]).to(device),
  392.                                             get_real_data_target(model_parameter["batch_size"]).to(device),
  393.                                             get_hot_vector_encode(labels, device),
  394.                                             get_hot_vector_encode(
  395.                                                 get_rand_labels(model_parameter["batch_size"]), device),
  396.                                             device
  397.                                             )
  398.  
  399.             # 2. Train Generator
  400.             fake_image, pred_dis_fake, gen_error = train_generator(
  401.                                                                   gen_opti,
  402.                                                                   get_noise(model_parameter["batch_size"]).to(device),
  403.                                                                   get_real_data_target(model_parameter["batch_size"]).to(device),
  404.                                                                   get_hot_vector_encode(
  405.                                                                       get_rand_labels(model_parameter["batch_size"]),
  406.                                                                       device),
  407.                                                                   device
  408.                                                                   )
  409.  
  410.  
  411.             # Store a random point for experience replay
  412.             perm = torch.randperm(fake_image.size(0))
  413.             r_idx = perm[:max(1, int(model_parameter["batch_size"] / r_frequent))]
  414.             r_samples = add_noise_to_image(fake_image[r_idx], device)
  415.             exp_replay = torch.cat((exp_replay, r_samples), 0).detach()
  416.  
  417.             if exp_replay.size(0) >= model_parameter["batch_size"]:
  418.                 # Train on experienced data
  419.                 dis_opti.zero_grad()
  420.  
  421.                 r_label = get_hot_vector_encode(torch.zeros(exp_replay.size(0)).numpy(), device)
  422.                 pred_dis_real = discriminator(exp_replay, r_label)
  423.                 error_real = loss(pred_dis_real,  get_fake_data_target(exp_replay.size(0)).to(device))
  424.  
  425.                 error_real.backward()
  426.  
  427.                 dis_opti.step()
  428.  
  429.                 print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
  430.                       f'Batch: Replay/Experience batch '
  431.                       f'Loss_D: {error_real.data.cpu()}, '
  432.                       )
  433.  
  434.                 exp_replay = torch.tensor([]).to(device)
  435.  
  436.             logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)
  437.  
  438.             if batch_num % 100 == 0:
  439.                 logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)
  440.  
  441.         logger.log(num_epoch, dis_error, gen_error)
  442.         if num_epoch % num_epoch_log == 0:
  443.             logger.log_model(num_epoch)
  444.             logger.log_histogramm()
  445.     logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement