Guest User

DCGAN on MNIST

a guest
Feb 26th, 2020
205
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import torch.nn as nn
  3. import torchvision
  4. from torchvision import transforms, datasets
  5. import torch.nn.functional as F
  6. from torch import optim as optim
  7. from torch.utils.tensorboard import SummaryWriter
  8.  
  9. import numpy as np
  10.  
  11. import os
  12. import time
  13.  
  14.  
  15. class Discriminator(torch.nn.Module):
  16.     def __init__(self, ndf=16, dropout_value=0.5):  # ndf feature map discriminator
  17.         super().__init__()
  18.         self.ndf = ndf
  19.         self.droupout_value = dropout_value
  20.  
  21.         self.condi = nn.Sequential(
  22.             nn.Linear(in_features=10, out_features=64 * 64)
  23.         )
  24.  
  25.         self.hidden0 = nn.Sequential(
  26.             nn.Conv2d(in_channels=2, out_channels=self.ndf, kernel_size=4, stride=2, padding=1, bias=False),
  27.             nn.LeakyReLU(0.2),
  28.         )
  29.         self.hidden1 = nn.Sequential(
  30.             nn.Conv2d(in_channels=self.ndf, out_channels=self.ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
  31.             nn.BatchNorm2d(self.ndf * 2),
  32.             nn.LeakyReLU(0.2),
  33.             nn.Dropout(self.droupout_value)
  34.         )
  35.         self.hidden2 = nn.Sequential(
  36.             nn.Conv2d(in_channels=self.ndf * 2, out_channels=self.ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
  37.             #nn.BatchNorm2d(self.ndf * 4),
  38.             nn.LeakyReLU(0.2),
  39.             nn.Dropout(self.droupout_value)
  40.         )
  41.         self.hidden3 = nn.Sequential(
  42.             nn.Conv2d(in_channels=self.ndf * 4, out_channels=self.ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
  43.             nn.BatchNorm2d(self.ndf * 8),
  44.             nn.LeakyReLU(0.2),
  45.             nn.Dropout(self.droupout_value)
  46.         )
  47.         self.out = nn.Sequential(
  48.             nn.Conv2d(in_channels=self.ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
  49.             torch.nn.Sigmoid()
  50.         )
  51.  
  52.     def forward(self, x, y):
  53.         y = self.condi(y.view(-1, 10))
  54.         y = y.view(-1, 1, 64, 64)
  55.  
  56.         x = torch.cat((x, y), dim=1)
  57.  
  58.         x = self.hidden0(x)
  59.         x = self.hidden1(x)
  60.         x = self.hidden2(x)
  61.         x = self.hidden3(x)
  62.         x = self.out(x)
  63.  
  64.         return x
  65.  
  66.  
  67. class Generator(torch.nn.Module):
  68.     def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):  # ngf feature map of generator
  69.         super().__init__()
  70.         self.ngf = ngf
  71.         self.n_features = n_features
  72.         self.c_channels = c_channels
  73.         self.droupout_value = dropout_value
  74.  
  75.         self.hidden0 = nn.Sequential(
  76.             nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
  77.                                kernel_size=4, stride=1, padding=0, bias=False),
  78.             nn.BatchNorm2d(self.ngf * 8),
  79.             nn.LeakyReLU(0.2)
  80.         )
  81.  
  82.         self.hidden1 = nn.Sequential(
  83.             nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
  84.                                kernel_size=4, stride=2, padding=1, bias=False),
  85.             #nn.BatchNorm2d(self.ngf * 4),
  86.             nn.LeakyReLU(0.2),
  87.             nn.Dropout(self.droupout_value)
  88.         )
  89.  
  90.         self.hidden2 = nn.Sequential(
  91.             nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
  92.                                kernel_size=4, stride=2, padding=1, bias=False),
  93.             nn.BatchNorm2d(self.ngf * 2),
  94.             nn.LeakyReLU(0.2),
  95.             nn.Dropout(self.droupout_value)
  96.         )
  97.  
  98.         self.hidden3 = nn.Sequential(
  99.             nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
  100.                                kernel_size=4, stride=2, padding=1, bias=False),
  101.             nn.BatchNorm2d(self.ngf),
  102.             nn.LeakyReLU(0.2),
  103.             nn.Dropout(self.droupout_value)
  104.         )
  105.  
  106.         self.out = nn.Sequential(
  107.             # "out_channels=1" because gray scale
  108.             nn.ConvTranspose2d(in_channels=self.ngf, out_channels=1, kernel_size=4,
  109.                                stride=2, padding=1, bias=False),
  110.             nn.Tanh()
  111.         )
  112.  
  113.     def forward(self, x, y):
  114.         x_cond = torch.cat((x, y), dim=1)  # Combine flatten image with conditional input (class labels)
  115.  
  116.         x = self.hidden0(x_cond)           # Image goes into a "ConvTranspose2d" layer
  117.         x = self.hidden1(x)
  118.         x = self.hidden2(x)
  119.         x = self.hidden3(x)
  120.         x = self.out(x)
  121.  
  122.         return x
  123.  
  124.  
  125. class Logger:
  126.     def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
  127.         self.out_dir = "data"
  128.         self.model_name = model_name
  129.         self.train_loader = train_loader
  130.         self.model1 = model1
  131.         self.model2 = model2
  132.         self.model_parameter = model_parameter
  133.         self.m1_optimizer = m1_optimizer
  134.         self.m2_optimizer = m2_optimizer
  135.  
  136.         # Exclude Epochs of the model name. This make sense e.g. when we stop a training progress and continue later on.
  137.         self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items())\
  138.             .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")
  139.  
  140.         self.d_error = 0
  141.         self.g_error = 0
  142.  
  143.         self.tb = SummaryWriter(log_dir=str(self.out_dir + "/log/" + self.model_name + "/runs/" + self.experiment_name))
  144.  
  145.         self.path_image = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/images/{self.experiment_name}')
  146.         self.path_model = os.path.join(os.getcwd(), f'{self.out_dir}/log/{self.model_name}/model/{self.experiment_name}')
  147.  
  148.         try:
  149.             os.makedirs(self.path_image)
  150.         except Exception as e:
  151.             print("WARNING: ", str(e))
  152.  
  153.         try:
  154.             os.makedirs(self.path_model)
  155.         except Exception as e:
  156.             print("WARNING: ", str(e))
  157.  
  158.     def log_graph(self, model1_input, model2_input, model1_label, model2_label):
  159.         self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
  160.         self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))
  161.  
  162.     def log(self, num_epoch, d_error, g_error):
  163.         self.d_error = d_error
  164.         self.g_error = g_error
  165.  
  166.         self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
  167.         self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)
  168.  
  169.     def log_image(self, images, epoch, batch_num):
  170.         grid = torchvision.utils.make_grid(images)
  171.         torchvision.utils.save_image(grid, f'{self.path_image}\\Epoch_{epoch}_batch_{batch_num}.png')
  172.  
  173.         self.tb.add_image("Generator Image", grid)
  174.  
  175.     def log_histogramm(self):
  176.         for name, param in self.model2.named_parameters():
  177.             self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
  178.             self.tb.add_histogram(f'gen_{name}.grad', param.grad, self.model_parameter["Epochs"])
  179.  
  180.         for name, param in self.model1.named_parameters():
  181.             self.tb.add_histogram(name, param, self.model_parameter["Epochs"])
  182.             self.tb.add_histogram(f'dis_{name}.grad', param.grad, self.model_parameter["Epochs"])
  183.  
  184.     def log_model(self, num_epoch):
  185.         torch.save({
  186.             "epoch": num_epoch,
  187.             "model_generator_state_dict": self.model1.state_dict(),
  188.             "model_discriminator_state_dict": self.model2.state_dict(),
  189.             "optimizer_generator_state_dict":  self.m1_optimizer.state_dict(),
  190.             "optimizer_discriminator_state_dict":  self.m2_optimizer.state_dict(),
  191.         }, str(self.path_model + f'\\{time.time()}_epoch{num_epoch}.pth'))
  192.  
  193.     def close(self, logger, images, num_epoch,  d_error, g_error):
  194.         logger.log_model(num_epoch)
  195.         logger.log_histogramm()
  196.         logger.log(num_epoch, d_error, g_error)
  197.         self.tb.close()
  198.  
  199.     def display_stats(self, epoch, batch_num, dis_error, gen_error):
  200.         print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
  201.               f'Batch: [{batch_num}/{len(self.train_loader)}] '
  202.               f'Loss_D: {dis_error.data.cpu()}, '
  203.               f'Loss_G: {gen_error.data.cpu()}')
  204.  
  205.  
  206. def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
  207.     compose = transforms.Compose([
  208.         transforms.Resize((64, 64)),
  209.         transforms.CenterCrop((64, 64)),
  210.         transforms.ToTensor(),
  211.         torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
  212.     ])
  213.  
  214.     dataset = datasets.MNIST(
  215.         root=out_dir,
  216.         train=True,
  217.         download=True,
  218.         transform=compose
  219.     )
  220.  
  221.     train_loader = torch.utils.data.DataLoader(dataset,
  222.                                                batch_size=model_parameter["batch_size"],
  223.                                                num_workers=num_workers_loader,
  224.                                                shuffle=model_parameter["shuffle"])
  225.  
  226.     return dataset, train_loader
  227.  
  228.  
  229. def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
  230.     p_optimizer.zero_grad()
  231.  
  232.     # 1.1 Train on real data
  233.     pred_dis_real = discriminator(p_images, p_images_labels)
  234.     error_real = loss(pred_dis_real, p_real_target)
  235.  
  236.     error_real.backward()
  237.  
  238.     # 1.2 Train on fake data
  239.     fake_data = generator(p_noise, p_fake_labels).detach()
  240.     fake_data = add_noise_to_image(fake_data, device)
  241.     pred_dis_fake = discriminator(fake_data, p_fake_labels)
  242.     error_fake = loss(pred_dis_fake, p_fake_target)
  243.  
  244.     error_fake.backward()
  245.  
  246.     p_optimizer.step()
  247.  
  248.     return error_fake + error_real
  249.  
  250.  
  251. def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
  252.     p_optimizer.zero_grad()
  253.  
  254.     fake_images = generator(p_noise, p_fake_labels)
  255.     fake_images = add_noise_to_image(fake_images, device)
  256.     pred_dis_fake = discriminator(fake_images, p_fake_labels)
  257.     error_fake = loss(pred_dis_fake, p_real_target)  # because
  258.     """
  259.    We use "p_real_target" instead of "p_fake_target" because we want to
  260.    maximize that the discriminator is wrong.
  261.    """
  262.  
  263.     error_fake.backward()
  264.  
  265.     p_optimizer.step()
  266.  
  267.     return fake_images, pred_dis_fake, error_fake
  268.  
  269.  
  270. # TODO change to a Truncated normal distribution
  271. def get_noise(batch_size, n_features=100):
  272.     return torch.FloatTensor(batch_size, n_features, 1, 1).uniform_(-1, 1)
  273.  
  274.  
  275. # We flip label of real and fate data. Better gradient flow I have told
  276. def get_real_data_target(batch_size):
  277.     return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.0, 0.2)
  278.  
  279.  
  280. def get_fake_data_target(batch_size):
  281.     return torch.FloatTensor(batch_size, 1, 1, 1).uniform_(0.8, 1.1)
  282.  
  283.  
  284. def image_to_vector(images):
  285.     return torch.flatten(images, start_dim=1, end_dim=-1)
  286.  
  287.  
  288. def vector_to_image(images):
  289.     return images.view(images.size(0), 1, 28, 28)
  290.  
  291.  
  292. def get_rand_labels(batch_size):
  293.     return torch.randint(low=0, high=9, size=(batch_size,))
  294.  
  295.  
  296. def load_model(model_load_path):
  297.     if model_load_path:
  298.         checkpoint = torch.load(model_load_path)
  299.  
  300.         discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
  301.         generator.load_state_dict(checkpoint["model_generator_state_dict"])
  302.  
  303.         dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
  304.         gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])
  305.  
  306.         return checkpoint["epoch"]
  307.  
  308.     else:
  309.         return 0
  310.  
  311.  
  312. def init_model_optimizer(model_parameter, device):
  313.     # Initialize the Models
  314.     discriminator = Discriminator(ndf=model_parameter["ndf"], dropout_value=model_parameter["dropout"]).to(device)
  315.     generator = Generator(ngf=model_parameter["ngf"], dropout_value=model_parameter["dropout"]).to(device)
  316.  
  317.     # train
  318.     dis_opti = optim.Adam(discriminator.parameters(), lr=model_parameter["learning_rate_dis"], betas=(0.5, 0.999))
  319.     gen_opti = optim.Adam(generator.parameters(), lr=model_parameter["learning_rate_gen"], betas=(0.5, 0.999))
  320.  
  321.     return discriminator, generator, dis_opti, gen_opti
  322.  
  323.  
  324. def get_hot_vector_encode(labels, device):
  325.     return torch.eye(10)[labels].view(-1, 10, 1, 1).to(device)
  326.  
  327.  
  328. def add_noise_to_image(images, device, level_of_noise=0.1):
  329.     return images[0].to(device) + (level_of_noise) * torch.randn(images.shape).to(device)
  330.  
  331.  
if __name__ == "__main__":
    # ---- Hyperparameters for the models and the training run ----
    model_parameter = {
        "batch_size": 500,
        "learning_rate_dis": 0.0002,
        "learning_rate_gen": 0.0002,
        "shuffle": False,
        "Epochs": 10,
        "ndf": 64,
        "ngf": 64,
        "dropout": 0.5
    }

    # ---- Run/bookkeeping parameters ----
    r_frequent = 10        # How many samples we keep for replay per batch (batch_size / r_frequent).
    model_name = "CDCGAN"   # The name of your model, e.g. "Gan"
    num_workers_loader = 1  # How many workers should load the data
    sample_save_size = 16   # How many digits each saved image grid should show
    device = "cuda"         # NOTE(review): hard-coded; fails on CPU-only machines — consider a cuda-availability check
    model_load_path = ""    # If set, resume from this checkpoint instead of training from scratch
    num_epoch_log = 1       # Checkpoint/histogram logging frequency, in epochs
    torch.manual_seed(43)   # Seed torch for reproducibility

    dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)  # Get dataset

    # Initialize the models and their optimizers
    discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)  # Init model/Optimizer

    start_epoch = load_model(model_load_path)  # 0 unless a checkpoint was loaded

    # Init Logger (model1=generator, model2=discriminator — this order matters for the checkpoint keys)
    logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)

    loss = nn.BCELoss()

    images, labels = next(iter(train_loader))  # One fixed batch, used only for graph logging below

    # For testing
    # pred = generator(get_noise(model_parameter["batch_size"]).to(device), get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device))
    # dis = discriminator(images.to(device), get_hot_vector_encode(labels, device))

    # Trace both models once so TensorBoard can render their graphs.
    logger.log_graph(get_noise(model_parameter["batch_size"]).to(device), images.to(device),
                     get_hot_vector_encode(get_rand_labels(model_parameter["batch_size"]), device),
                     get_hot_vector_encode(labels, device))


    # Buffer of past generator samples for experience replay.
    exp_replay = torch.tensor([]).to(device)

    for num_epoch in range(start_epoch, model_parameter["Epochs"]):
        for batch_num, data_loader in enumerate(train_loader):
            images, labels = data_loader
            images = add_noise_to_image(images, device)  # Instance noise on the real images

            # 1. Train the discriminator on this batch (real images + fresh fakes).
            dis_error = train_discriminator(
                                            dis_opti,
                                            get_noise(model_parameter["batch_size"]).to(device),
                                            images.to(device),
                                            get_fake_data_target(model_parameter["batch_size"]).to(device),
                                            get_real_data_target(model_parameter["batch_size"]).to(device),
                                            get_hot_vector_encode(labels, device),
                                            get_hot_vector_encode(
                                                get_rand_labels(model_parameter["batch_size"]), device),
                                            device
                                            )

            # 2. Train the generator against the just-updated discriminator.
            fake_image, pred_dis_fake, gen_error = train_generator(
                                                                  gen_opti,
                                                                  get_noise(model_parameter["batch_size"]).to(device),
                                                                  get_real_data_target(model_parameter["batch_size"]).to(device),
                                                                  get_hot_vector_encode(
                                                                      get_rand_labels(model_parameter["batch_size"]),
                                                                      device),
                                                                  device
                                                                  )


            # Keep a random subset of this batch's fakes for experience replay.
            perm = torch.randperm(fake_image.size(0))
            r_idx = perm[:max(1, int(model_parameter["batch_size"] / r_frequent))]
            r_samples = add_noise_to_image(fake_image[r_idx], device)
            exp_replay = torch.cat((exp_replay, r_samples), 0).detach()

            if exp_replay.size(0) >= model_parameter["batch_size"]:
                # Replay buffer is full: give the discriminator one extra update on stored fakes.
                dis_opti.zero_grad()

                # NOTE(review): these are all-zero FLOAT labels from numpy; get_hot_vector_encode
                # indexes torch.eye with them — presumably integer zeros were intended; verify.
                r_label = get_hot_vector_encode(torch.zeros(exp_replay.size(0)).numpy(), device)
                pred_dis_real = discriminator(exp_replay, r_label)
                # Flipped-label convention: stored fakes are scored against the "fake" target range.
                error_real = loss(pred_dis_real,  get_fake_data_target(exp_replay.size(0)).to(device))

                error_real.backward()

                dis_opti.step()

                print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
                      f'Batch: Replay/Experience batch '
                      f'Loss_D: {error_real.data.cpu()}, '
                      )

                exp_replay = torch.tensor([]).to(device)  # Empty the replay buffer

            logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)

            if batch_num % 100 == 0:
                logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)

        logger.log(num_epoch, dis_error, gen_error)
        if num_epoch % num_epoch_log == 0:
            logger.log_model(num_epoch)
            logger.log_histogramm()
    logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)
RAW Paste Data

Adblocker detected! Please consider disabling it...

We've detected AdBlock Plus or some other adblocking software preventing Pastebin.com from fully loading.

We don't have any obnoxious sound, or popup ads, we actively block these annoying types of ads!

Please add Pastebin.com to your ad blocker whitelist or disable your adblocking software.

×