Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn as nn
- import torchvision
- from torchvision import transforms, datasets
- import torch.nn.functional as F
- from torch import optim as optim
- from torch.utils.tensorboard import SummaryWriter
- import numpy as np
- import os
- import time
class Discriminator(torch.nn.Module):
    """Conditional DCGAN discriminator for 64x64 single-channel images.

    The one-hot class vector (10 entries) is projected by a linear layer onto a
    64x64 plane and stacked onto the image as a second input channel. Four
    stride-2 convolutions shrink 64 -> 32 -> 16 -> 8 -> 4, and a final 4x4
    convolution plus sigmoid gives one probability per sample, shape (N, 1, 1, 1).

    Args:
        ndf: base number of discriminator feature maps.
        dropout_value: dropout probability used after hidden layers 1-3.
    """

    def __init__(self, ndf=16, dropout_value=0.5):
        super().__init__()
        self.ndf = ndf
        self.droupout_value = dropout_value  # (sic) attribute name kept for existing readers

        def down(in_ch, out_ch, batch_norm, dropout):
            # Conv(k=4, s=2, p=1) halves the spatial size. The submodule indices
            # produced here match the checkpoint layout (e.g. hidden1.1 is BatchNorm).
            layers = [nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=4,
                                stride=2, padding=1, bias=False)]
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2))
            if dropout:
                layers.append(nn.Dropout(self.droupout_value))
            return nn.Sequential(*layers)

        # Registration order (condi, hidden0..3, out) also fixes the seeded weight init order.
        self.condi = nn.Sequential(nn.Linear(in_features=10, out_features=64 * 64))
        self.hidden0 = down(2, ndf, batch_norm=False, dropout=False)           # 64 -> 32
        self.hidden1 = down(ndf, ndf * 2, batch_norm=True, dropout=True)       # 32 -> 16
        self.hidden2 = down(ndf * 2, ndf * 4, batch_norm=False, dropout=True)  # 16 -> 8
        self.hidden3 = down(ndf * 4, ndf * 8, batch_norm=True, dropout=True)   # 8 -> 4
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=ndf * 8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False),
            torch.nn.Sigmoid(),
        )

    def forward(self, x, y):
        """Score images ``x`` (N, 1, 64, 64) given one-hot labels ``y`` (N, 10, ...)."""
        # One 64x64 label plane per sample, stacked onto the image along channels.
        label_plane = self.condi(y.view(-1, 10)).view(-1, 1, 64, 64)
        h = torch.cat((x, label_plane), dim=1)
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.hidden3, self.out):
            h = stage(h)
        return h
class Generator(torch.nn.Module):
    """Conditional DCGAN generator.

    Expects noise ``x`` of shape (N, n_features, 1, 1) and a one-hot label ``y``
    of shape (N, 10, 1, 1). The two are concatenated along the channel axis and
    upsampled 1 -> 4 -> 8 -> 16 -> 32 -> 64 by transposed convolutions. A final
    tanh bounds the output pixels to [-1, 1].

    Args:
        n_features: length of the noise vector.
        ngf: base number of generator feature maps.
        c_channels: number of output image channels (1 for grayscale MNIST).
        dropout_value: dropout probability used after hidden layers 1-3.
    """

    def __init__(self, n_features=100, ngf=16, c_channels=1, dropout_value=0.5):
        super().__init__()
        self.ngf = ngf
        self.n_features = n_features
        self.c_channels = c_channels
        self.dropout_value = dropout_value
        self.droupout_value = dropout_value  # misspelled legacy name, kept for existing readers
        # 1x1 -> 4x4; the 10 extra input channels carry the one-hot class label.
        self.hidden0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.n_features + 10, out_channels=self.ngf * 8,
                               kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.ngf * 8),
            nn.LeakyReLU(0.2)
        )
        # 4x4 -> 8x8 (this stage deliberately has no BatchNorm).
        self.hidden1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 8, out_channels=self.ngf * 4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        # 8x8 -> 16x16
        self.hidden2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 4, out_channels=self.ngf * 2,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        # 16x16 -> 32x32
        self.hidden3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf * 2, out_channels=self.ngf,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(self.ngf),
            nn.LeakyReLU(0.2),
            nn.Dropout(self.droupout_value)
        )
        # 32x32 -> 64x64 with c_channels output planes (previously hard-coded to 1,
        # silently ignoring the c_channels argument).
        self.out = nn.Sequential(
            nn.ConvTranspose2d(in_channels=self.ngf, out_channels=self.c_channels, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        """Generate images from noise ``x`` conditioned on one-hot labels ``y``."""
        x_cond = torch.cat((x, y), dim=1)  # stack noise and label along the channel axis
        x = self.hidden0(x_cond)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        return self.out(x)
class Logger:
    """Write TensorBoard summaries, sample image grids and checkpoints for a GAN run.

    ``log_model`` saves ``model1``/``m1_optimizer`` under the ``*_generator_*``
    checkpoint keys and ``model2``/``m2_optimizer`` under the ``*_discriminator_*``
    keys, so the generator must be passed first. Everything is written below
    ``data/log/<model_name>/`` relative to the current working directory.
    """

    def __init__(self, model_name, model1, model2, m1_optimizer, m2_optimizer, model_parameter, train_loader):
        self.out_dir = "data"
        self.model_name = model_name
        self.train_loader = train_loader
        self.model1 = model1
        self.model2 = model2
        self.model_parameter = model_parameter
        self.m1_optimizer = m1_optimizer
        self.m2_optimizer = m2_optimizer
        # Exclude Epochs from the experiment name, so a run that is stopped and
        # continued later with another epoch budget reuses the same directories.
        self.experiment_name = '_'.join("{!s}={!r}".format(k, v) for (k, v) in model_parameter.items()) \
            .replace("Epochs" + "=" + str(model_parameter["Epochs"]), "")
        self.d_error = 0
        self.g_error = 0
        log_root = os.path.join(self.out_dir, "log", self.model_name)
        self.tb = SummaryWriter(log_dir=os.path.join(log_root, "runs", self.experiment_name))
        self.path_image = os.path.join(os.getcwd(), log_root, "images", self.experiment_name)
        self.path_model = os.path.join(os.getcwd(), log_root, "model", self.experiment_name)
        for path in (self.path_image, self.path_model):
            try:
                # exist_ok: reusing the directories of a resumed run is expected.
                os.makedirs(path, exist_ok=True)
            except OSError as e:
                # Best effort: report and continue; TensorBoard logging still works.
                print("WARNING: ", str(e))

    def log_graph(self, model1_input, model2_input, model1_label, model2_label):
        """Trace both models once so TensorBoard can display their graphs."""
        self.tb.add_graph(self.model1, input_to_model=(model1_input, model1_label))
        self.tb.add_graph(self.model2, input_to_model=(model2_input, model2_label))

    def log(self, num_epoch, d_error, g_error):
        """Record the latest losses and add them as scalars at step ``num_epoch``."""
        self.d_error = d_error
        self.g_error = g_error
        self.tb.add_scalar("Discriminator Train Error", self.d_error, num_epoch)
        self.tb.add_scalar("Generator Train Error", self.g_error, num_epoch)

    def log_image(self, images, epoch, batch_num):
        """Save ``images`` as a grid PNG in ``path_image`` and add the grid to TensorBoard."""
        grid = torchvision.utils.make_grid(images)
        # os.path.join rather than a hard-coded "\\": on POSIX that produced a file
        # literally named "<dir>\Epoch_..." beside, not inside, the image directory.
        torchvision.utils.save_image(grid, os.path.join(self.path_image, f'Epoch_{epoch}_batch_{batch_num}.png'))
        # A distinct step per grid, so successive samples can be told apart in TensorBoard.
        self.tb.add_image("Generator Image", grid, global_step=epoch * len(self.train_loader) + batch_num)

    def log_histogramm(self, num_epoch=None):
        """Add weight and gradient histograms of both models.

        Tags are prefixed ``gen_`` for model1 and ``dis_`` for model2 (matching the
        roles used by ``log_model``). Without a prefix, the identically named layers
        of the two networks (``hidden0.0.weight``, ...) collided under one tag.

        Args:
            num_epoch: TensorBoard step. Defaults to ``model_parameter["Epochs"]``,
                the fixed step used previously.
        """
        step = self.model_parameter["Epochs"] if num_epoch is None else num_epoch
        for prefix, model in (("gen", self.model1), ("dis", self.model2)):
            for name, param in model.named_parameters():
                self.tb.add_histogram(f'{prefix}_{name}', param, step)
                # Parameters that have not been through a backward pass have no gradient.
                if param.grad is not None:
                    self.tb.add_histogram(f'{prefix}_{name}.grad', param.grad, step)

    def log_model(self, num_epoch):
        """Save both models and optimizers, plus ``num_epoch``, to a timestamped ``.pth`` file."""
        torch.save({
            "epoch": num_epoch,
            "model_generator_state_dict": self.model1.state_dict(),
            "model_discriminator_state_dict": self.model2.state_dict(),
            "optimizer_generator_state_dict": self.m1_optimizer.state_dict(),
            "optimizer_discriminator_state_dict": self.m2_optimizer.state_dict(),
        }, os.path.join(self.path_model, f'{time.time()}_epoch{num_epoch}.pth'))

    def close(self, logger, images, num_epoch, d_error, g_error):
        """Write a final checkpoint, histograms and scalars through ``logger``, then close the writer.

        ``logger`` is the instance to finalise (normally ``self``). ``images`` is
        accepted for interface compatibility but is not used.
        """
        logger.log_model(num_epoch)
        logger.log_histogramm(num_epoch)
        logger.log(num_epoch, d_error, g_error)
        self.tb.close()

    def display_stats(self, epoch, batch_num, dis_error, gen_error):
        """Print a one-line progress report for the current batch."""
        print(f'Epoch: [{epoch}/{self.model_parameter["Epochs"]}] '
              f'Batch: [{batch_num}/{len(self.train_loader)}] '
              f'Loss_D: {dis_error.data.cpu()}, '
              f'Loss_G: {gen_error.data.cpu()}')
def get_MNIST_dataset(num_workers_loader, model_parameter, out_dir="data"):
    """Download (if needed) the MNIST training split and wrap it in a DataLoader.

    Images are resized to 64x64 and normalised to [-1, 1], matching the range of
    the generator's tanh output.

    Args:
        num_workers_loader: number of DataLoader worker processes.
        model_parameter: dict providing "batch_size" and "shuffle".
        out_dir: root directory for the downloaded dataset.

    Returns:
        ``(dataset, train_loader)``.
    """
    compose = transforms.Compose([
        # An exact (H, W) size, so the former CenterCrop((64, 64)) afterwards was a no-op.
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1]
    ])
    dataset = datasets.MNIST(
        root=out_dir,
        train=True,
        download=True,
        transform=compose,
    )
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=model_parameter["batch_size"],
                                               num_workers=num_workers_loader,
                                               shuffle=model_parameter["shuffle"])
    return dataset, train_loader
def train_discriminator(p_optimizer, p_noise, p_images, p_fake_target, p_real_target, p_images_labels, p_fake_labels, device):
    """Run one optimisation step of the module-level ``discriminator``.

    Gradients from a real batch and from a detached, noise-perturbed generated
    batch are accumulated before a single ``p_optimizer.step()``. Uses the
    module-level ``discriminator``, ``generator`` and ``loss``.

    Returns:
        The summed (fake + real) loss tensor.
    """
    p_optimizer.zero_grad()

    # 1.1 Real samples, conditioned on their true labels.
    real_loss = loss(discriminator(p_images, p_images_labels), p_real_target)
    real_loss.backward()

    # 1.2 Generated samples; detach() keeps this step from updating the generator.
    generated = add_noise_to_image(generator(p_noise, p_fake_labels).detach(), device)
    fake_loss = loss(discriminator(generated, p_fake_labels), p_fake_target)
    fake_loss.backward()

    p_optimizer.step()
    return fake_loss + real_loss
def train_generator(p_optimizer, p_noise, p_real_target, p_fake_labels, device):
    """Run one optimisation step of the module-level ``generator``.

    The generated (noise-perturbed) images are scored by the module-level
    ``discriminator`` against ``p_real_target`` rather than the fake target. The
    generator is therefore rewarded when the discriminator is wrong about its samples.

    Returns:
        ``(fake_images, pred_dis_fake, error_fake)``. ``fake_images`` are the
        noise-perturbed generated images and are not detached.
    """
    p_optimizer.zero_grad()
    generated = add_noise_to_image(generator(p_noise, p_fake_labels), device)
    verdict = discriminator(generated, p_fake_labels)
    # Real target on purpose: maximise the discriminator's mistake.
    gen_loss = loss(verdict, p_real_target)
    gen_loss.backward()
    p_optimizer.step()
    return generated, verdict, gen_loss
# TODO: change to a truncated normal distribution
def get_noise(batch_size, n_features=100):
    """Return latent noise of shape (batch_size, n_features, 1, 1) drawn from U(-1, 1)."""
    shape = (batch_size, n_features, 1, 1)
    return torch.empty(shape, dtype=torch.float32).uniform_(-1, 1)
# Labels are flipped: real samples are trained towards ~0 and generated ones
# towards ~1 (reportedly gives better gradient flow). Targets are also randomly softened.
def get_real_data_target(batch_size):
    """Return soft targets for real samples, drawn from U(0.0, 0.2), shape (batch_size, 1, 1, 1)."""
    return torch.empty((batch_size, 1, 1, 1), dtype=torch.float32).uniform_(0.0, 0.2)
def get_fake_data_target(batch_size, low=0.8, high=1.0):
    """Return soft targets for generated samples, drawn from U(low, high), shape (batch_size, 1, 1, 1).

    Labels are flipped in this script (real samples get targets near 0), so
    generated samples are pushed towards ~1. ``high`` is capped at 1.0 because
    ``nn.BCELoss`` expects targets in [0, 1]. With a target t > 1, the loss keeps
    decreasing as the prediction approaches 1. That rewards a saturated sigmoid
    instead of the intended soft label (the old upper bound was 1.1).

    Raises:
        ValueError: if not ``0 <= low <= high <= 1``.
    """
    if not 0.0 <= low <= high <= 1.0:
        raise ValueError(f"expected 0 <= low <= high <= 1, got low={low}, high={high}")
    return torch.empty((batch_size, 1, 1, 1), dtype=torch.float32).uniform_(low, high)
def image_to_vector(images):
    """Flatten every sample of a batch: (N, *dims) -> (N, prod(dims))."""
    return images.flatten(start_dim=1)
def vector_to_image(images, height=28, width=28):
    """Reshape flattened single-channel images (N, height*width) to (N, 1, height, width).

    The 28x28 default is MNIST's native size. Pass ``height=64, width=64`` for the
    resized images this script trains on. ``reshape`` (rather than ``view``) also
    accepts non-contiguous input.
    """
    return images.reshape(images.size(0), 1, height, width)
def get_rand_labels(batch_size, n_classes=10):
    """Draw ``batch_size`` uniformly random class labels in ``[0, n_classes)``.

    ``torch.randint``'s ``high`` bound is exclusive, so it must be ``n_classes``
    (10). The former ``high=9`` meant digit 9 was never generated.
    """
    return torch.randint(low=0, high=n_classes, size=(batch_size,))
def load_model(model_load_path, map_location=None):
    """Restore models and optimizers from a checkpoint and return the epoch to resume at.

    Loads into the module-level ``discriminator``, ``generator``, ``dis_opti`` and
    ``gen_opti``, which must already exist. The training script writes a checkpoint
    only after an epoch's batch loop has finished, so the stored ``"epoch"`` is the
    last *completed* epoch. Training therefore resumes at ``epoch + 1``; returning
    ``epoch`` would repeat it.

    Args:
        model_load_path: checkpoint path. A falsy value ("" or None) means start fresh.
        map_location: forwarded to ``torch.load``, e.g. "cpu" to open a GPU
            checkpoint on a machine without CUDA.

    Returns:
        The first epoch index to train (0 when nothing is loaded).
    """
    if not model_load_path:
        return 0
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(model_load_path, map_location=map_location)
    discriminator.load_state_dict(checkpoint["model_discriminator_state_dict"])
    generator.load_state_dict(checkpoint["model_generator_state_dict"])
    dis_opti.load_state_dict(checkpoint["optimizer_discriminator_state_dict"])
    gen_opti.load_state_dict(checkpoint["optimizer_generator_state_dict"])
    return checkpoint["epoch"] + 1
def init_model_optimizer(model_parameter, device):
    """Build the discriminator/generator pair on ``device`` and an Adam optimizer for each.

    Reads "ndf", "ngf", "dropout", "learning_rate_dis" and "learning_rate_gen"
    from ``model_parameter``.

    Returns:
        ``(discriminator, generator, dis_opti, gen_opti)``.
    """
    dropout = model_parameter["dropout"]
    adam_betas = (0.5, 0.999)  # beta1 lowered from Adam's default 0.9

    # The discriminator is created first; the order matters for seeded weight init.
    dis_net = Discriminator(ndf=model_parameter["ndf"], dropout_value=dropout).to(device)
    gen_net = Generator(ngf=model_parameter["ngf"], dropout_value=dropout).to(device)

    dis_opt = optim.Adam(dis_net.parameters(), lr=model_parameter["learning_rate_dis"], betas=adam_betas)
    gen_opt = optim.Adam(gen_net.parameters(), lr=model_parameter["learning_rate_gen"], betas=adam_betas)
    return dis_net, gen_net, dis_opt, gen_opt
def get_hot_vector_encode(labels, device, num_classes=10):
    """One-hot encode class labels as a float tensor of shape (N, num_classes, 1, 1) on ``device``.

    ``labels`` may be a tensor (on any device), a NumPy array or a list of class
    indices. Float-typed indices such as ``torch.zeros(n)`` are cast to ``long``
    first; indexing with them directly raises ``IndexError``.
    """
    index = torch.as_tensor(labels).long().cpu()
    return torch.eye(num_classes)[index].view(-1, num_classes, 1, 1).to(device)
def add_noise_to_image(images, device, level_of_noise=0.1):
    """Return ``images`` moved to ``device`` plus i.i.d. Gaussian noise with std ``level_of_noise``.

    The noise is added to the whole batch element-wise, so the output has the
    input's shape. Indexing ``images[0]`` here would broadcast the first sample
    over the batch, and would cut gradients to all other samples.
    """
    images = images.to(device)
    return images + level_of_noise * torch.randn_like(images)
if __name__ == "__main__":
    # Hyperparameters (every key except "Epochs" becomes part of the Logger's experiment name)
    model_parameter = {
        "batch_size": 500,
        "learning_rate_dis": 0.0002,
        "learning_rate_gen": 0.0002,
        "shuffle": False,
        "Epochs": 10,
        "ndf": 64,
        "ngf": 64,
        "dropout": 0.5,
    }
    # Run parameters
    r_frequent = 10  # each batch, batch_size / r_frequent generated samples are kept for replay
    model_name = "CDCGAN"  # model name, e.g. "Gan"; used in the log directory
    num_workers_loader = 1  # number of DataLoader worker processes
    sample_save_size = 16  # how many generated digits each saved image grid shows
    device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU without a GPU
    model_load_path = ""  # if set, resume from this checkpoint instead of training from scratch
    num_epoch_log = 1  # save a checkpoint and histograms every num_epoch_log epochs
    torch.manual_seed(43)  # seed torch for reproducibility

    dataset_train, train_loader = get_MNIST_dataset(num_workers_loader, model_parameter)

    # discriminator, generator, dis_opti, gen_opti and loss are module-level names
    # read as globals by train_discriminator, train_generator and load_model.
    discriminator, generator, dis_opti, gen_opti = init_model_optimizer(model_parameter, device)
    start_epoch = load_model(model_load_path)

    logger = Logger(model_name, generator, discriminator, gen_opti, dis_opti, model_parameter, train_loader)
    loss = nn.BCELoss()

    # Trace both networks once for the TensorBoard graph view.
    images, labels = next(iter(train_loader))
    graph_batch = images.size(0)
    logger.log_graph(get_noise(graph_batch).to(device), images.to(device),
                     get_hot_vector_encode(get_rand_labels(graph_batch), device),
                     get_hot_vector_encode(labels, device))

    # Experience replay: stored generated samples together with the labels they
    # were generated for, so they can be replayed with the correct conditioning.
    replay_images, replay_labels = [], []
    replay_count = 0
    replay_per_batch = max(1, int(model_parameter["batch_size"] / r_frequent))

    fake_image = None
    dis_error = gen_error = None
    num_epoch = start_epoch
    for num_epoch in range(start_epoch, model_parameter["Epochs"]):
        for batch_num, (images, labels) in enumerate(train_loader):
            # Size everything from the actual batch; the last batch of an epoch can
            # be smaller than model_parameter["batch_size"].
            batch_size = images.size(0)
            images = add_noise_to_image(images, device)  # instance noise on real images

            # 1. Train discriminator
            dis_error = train_discriminator(
                dis_opti,
                get_noise(batch_size).to(device),
                images.to(device),
                get_fake_data_target(batch_size).to(device),
                get_real_data_target(batch_size).to(device),
                get_hot_vector_encode(labels, device),
                get_hot_vector_encode(get_rand_labels(batch_size), device),
                device,
            )

            # 2. Train generator (its labels are kept for the replay buffer)
            gen_noise = get_noise(batch_size).to(device)
            gen_target = get_real_data_target(batch_size).to(device)
            gen_labels = get_rand_labels(batch_size)
            fake_image, _pred_dis_fake, gen_error = train_generator(
                gen_opti, gen_noise, gen_target, get_hot_vector_encode(gen_labels, device), device)

            # Store a random subset of the generated batch for experience replay
            r_idx = torch.randperm(fake_image.size(0))[:replay_per_batch]
            replay_images.append(add_noise_to_image(fake_image[r_idx], device).detach())
            replay_labels.append(gen_labels[r_idx])
            replay_count += r_idx.numel()

            if replay_count >= model_parameter["batch_size"]:
                # Train the discriminator to reject previously generated samples
                replay_x = torch.cat(replay_images)
                replay_y = torch.cat(replay_labels)
                dis_opti.zero_grad()
                pred_dis_replay = discriminator(replay_x, get_hot_vector_encode(replay_y, device))
                error_replay = loss(pred_dis_replay, get_fake_data_target(replay_x.size(0)).to(device))
                error_replay.backward()
                dis_opti.step()
                print(f'Epoch: [{num_epoch}/{model_parameter["Epochs"]}] '
                      f'Batch: Replay/Experience batch '
                      f'Loss_D: {error_replay.data.cpu()}, '
                      )
                replay_images.clear()
                replay_labels.clear()
                replay_count = 0

            logger.display_stats(epoch=num_epoch, batch_num=batch_num, dis_error=dis_error, gen_error=gen_error)
            if batch_num % 100 == 0:
                logger.log_image(fake_image[:sample_save_size], num_epoch, batch_num)

        # Logger.log uses num_epoch as the TensorBoard step, so log once per epoch
        # (logging every batch at the same step piled all points onto one x value).
        logger.log(num_epoch, dis_error, gen_error)
        if num_epoch % num_epoch_log == 0:
            logger.log_model(num_epoch)
            logger.log_histogramm()

    if fake_image is not None:
        logger.close(logger, fake_image[:sample_save_size], num_epoch, dis_error, gen_error)
    else:
        # No epoch was trained (e.g. the checkpoint already covers every epoch).
        logger.tb.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement