Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn as nn
- from torchsummary import summary
- import torchvision.transforms as transforms
- import torchvision.datasets as dset
- from tqdm.autonotebook import tqdm
- from torchvision.utils import save_image
- from copy import deepcopy
- from matplotlib import pyplot as plt
- import numpy as np
- from torch.autograd import Variable
- from torch.nn.functional import adaptive_avg_pool2d
- import os
- from scipy import linalg
- from torch.nn.functional import adaptive_avg_pool2d
- from pytorch_fid.inception import InceptionV3
# Hyperparameters and global configuration for the conditional DCGAN.
num_epochs = 1000      # total training epochs
betas = (0.5, 0.999)   # NOTE(review): unused — the optimizers below pass (beta1, 0.999) directly
lr = 0.0002            # base learning rate (the discriminator uses lr/2 below)
batch_size = 100
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
z_dim = 100            # latent space dimensionality
c_dim = 1              # image channels (MNIST is grayscale)
label_dim = 10         # number of classes / one-hot label length
image_size = 32        # images are resized from 28x28 to 32x32
beta1 = 0.5            # Adam beta1 (DCGAN-recommended value)
PATH = "./generate/"   # output directory for sample grids and checkpoints
# MNIST dataset
# Resize 28x28 digits to 32x32 and convert to tensors in [0, 1].
# NOTE(review): Normalize((0.5,), (0.5,)) is commented out even though the
# generator ends in Tanh (range [-1, 1]) — confirm the intended pixel range.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    #transforms.Normalize((0.5,),(0.5,)),
])
train_set = dset.MNIST(root='./mnist_data/',
                       train=True,
                       transform=transform,
                       download=True)
# drop_last=True keeps every batch exactly batch_size, which the fixed-size
# target tensors built inside train_GAN rely on.
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)
- # Generator model
class Generator(nn.Module):
    """Conditional DCGAN generator.

    Maps a latent code of shape (N, z_dim, 1, 1) plus a one-hot label map of
    shape (N, label_dim, 1, 1) to a single-channel 32x32 image in [-1, 1]
    (tanh output).
    """

    def __init__(self, z_dim, label_dim):
        super(Generator, self).__init__()
        base = 64
        # Project the latent code up to a (base*4) x 4 x 4 feature map.
        self.input_x = nn.Sequential(
            nn.ConvTranspose2d(z_dim, base * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base * 4),
            nn.ReLU(True),
        )
        # Project the one-hot label to a matching (base*4) x 4 x 4 map.
        self.input_y = nn.Sequential(
            nn.ConvTranspose2d(label_dim, base * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base * 4),
            nn.ReLU(True),
        )
        # Fused trunk, upsampling 4x4 -> 8x8 -> 16x16 -> 32x32.
        self.concat = nn.Sequential(
            nn.ConvTranspose2d(base * 8, base * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(base * 4, base * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(base * 2, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x, y):
        """Generate images from latent codes ``x`` conditioned on labels ``y``."""
        latent_feat = self.input_x(x)
        label_feat = self.input_y(y)
        fused = torch.cat([latent_feat, label_feat], dim=1)
        return self.concat(fused)
- # Discriminator model
class Discriminator(nn.Module):
    """Conditional DCGAN discriminator.

    Scores an image of shape (N, nc, 32, 32) together with a spatial one-hot
    label map of shape (N, label_dim, 32, 32); returns a (N, 1, 1, 1) sigmoid
    probability of the image being real.
    """

    def __init__(self, nc=1, label_dim=10):
        super(Discriminator, self).__init__()
        base = 64
        # Downsample the image branch: 32x32 -> 16x16.
        self.input_x = nn.Sequential(
            nn.Conv2d(nc, base, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Downsample the label-map branch: 32x32 -> 16x16.
        self.input_y = nn.Sequential(
            nn.Conv2d(label_dim, base, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Fused trunk: 16x16 -> 8x8 -> 4x4 -> 1x1 probability.
        self.concate = nn.Sequential(
            nn.Conv2d(base * 2, base * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base * 4, base * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(base * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x, y):
        """Return realness scores for images ``x`` with label maps ``y``."""
        img_feat = self.input_x(x)
        label_feat = self.input_y(y)
        fused = torch.cat([img_feat, label_feat], dim=1)
        return self.concate(fused)
def weights_init(m):
    """DCGAN weight initialization (Radford et al., 2015).

    Conv/ConvTranspose weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02)
    with zero bias. Other layers are left untouched. Apply with
    ``model.apply(weights_init)``.
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
def train_GAN(G, D, G_opt, D_opt, dataset):
    """Run one epoch of conditional-DCGAN training.

    For each batch: (1) update D on a real batch (target 1) and a fake batch
    (target 0); (2) update G so that D classifies its fakes as real.

    Relies on module-level globals: device, batch_size, z_dim, label_dim,
    criterion, onehot (class -> one-hot vector) and fill (class -> 32x32
    one-hot label map).

    Returns:
        (err_D, err_G): mean discriminator and generator loss over the epoch
        (the original returned only the final batch's losses).
    """
    d_losses, g_losses = [], []
    for i, (data, label) in tqdm(enumerate(dataset)):
        # ---- (1) Update D: maximize log(D(x)) + log(1 - D(G(z))) ----
        D_opt.zero_grad()
        x_real = data.to(device)
        y_real = torch.ones(batch_size, device=device)    # real target = 1
        c_real = fill[label].to(device)                   # spatial label maps
        y_real_predict = D(x_real, c_real).squeeze()      # (N,1,1,1) -> (N,)
        d_real_loss = criterion(y_real_predict, y_real)
        d_real_loss.backward()

        # All-fake batch with random labels.
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        noise_label = torch.randint(0, label_dim, (batch_size,))
        noise_label_onehot = onehot[noise_label].to(device)
        x_fake = G(noise, noise_label_onehot)
        y_fake = torch.zeros(batch_size, device=device)   # fake target = 0
        c_fake = fill[noise_label].to(device)
        # BUG FIX: detach the fake batch for the D step so the backward pass
        # does not needlessly flow into G (G's grads were discarded anyway).
        y_fake_predict = D(x_fake.detach(), c_fake).squeeze()
        d_fake_loss = criterion(y_fake_predict, y_fake)
        d_fake_loss.backward()
        D_opt.step()

        # ---- (2) Update G: maximize log(D(G(z))) ----
        G_opt.zero_grad()
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        noise_label = torch.randint(0, label_dim, (batch_size,))
        noise_label_onehot = onehot[noise_label].to(device)
        x_fake = G(noise, noise_label_onehot)
        c_fake = fill[noise_label].to(device)
        y_fake_predict = D(x_fake, c_fake).squeeze()
        # BUG FIX: the generator target must be an explicit all-ones tensor;
        # the original reused the stale `y_real` left over from the D step.
        g_target = torch.ones(batch_size, device=device)
        g_loss = criterion(y_fake_predict, g_target)
        g_loss.backward()
        G_opt.step()

        d_losses.append(d_real_loss.item() + d_fake_loss.item())
        g_losses.append(g_loss.item())

    err_D = sum(d_losses) / max(len(d_losses), 1)
    err_G = sum(g_losses) / max(len(g_losses), 1)
    return err_D, err_G
# Models
# Build discriminator and generator and apply DCGAN weight initialization.
D = Discriminator(c_dim, label_dim).to(device)
D.apply(weights_init)
G = Generator(z_dim, label_dim).to(device)
G.apply(weights_init)
# Adam optimizers; D trains at half the generator's learning rate
# (a common trick to keep D from overpowering G early on).
D_opt = torch.optim.Adam(D.parameters(), lr= lr/2, betas=(beta1, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr= lr, betas=(beta1, 0.999))
# Loss function
# Binary cross-entropy on D's sigmoid output (vanilla GAN loss).
criterion = torch.nn.BCELoss()
##########
# Fixed evaluation inputs: a 10x10 grid (10 samples per digit class) rendered
# from the same latent codes every epoch to visualize training progress.
fixed_noise = torch.randn(100, 100).reshape(100, 100, 1, 1)
fixed_noise2 = torch.randn(100, 100).reshape(100, 100, 1, 1)
# Labels 0000000000 1111111111 ... 9999999999 (ten of each digit).
# BUG FIX: use .to(device) instead of .cuda() so the script also runs on
# CPU-only machines, consistent with the `device` selection above.
labels = torch.LongTensor([i for i in range(10) for _ in range(10)]).to(device)
fixed_c = labels.reshape(100, 1).float()  # NOTE(review): appears unused below
labels = labels.reshape(100, 1)
one_hot = nn.functional.one_hot(labels, num_classes=10)  # (100, 1, 10)
fixed_label = one_hot.reshape(100, 10, 1, 1).float()
# Lookup tables: onehot[c] is the (10, 1, 1) one-hot vector for class c;
# fill[c] broadcasts it over the 32x32 grid for the discriminator's branch.
onehot_before_cod = torch.LongTensor([i for i in range(10)]).to(device)  # 0..9
onehot = nn.functional.one_hot(onehot_before_cod, num_classes=10)
onehot = onehot.reshape(10, 10, 1, 1).float()
fill = onehot.repeat(1, 1, 32, 32)
# Per-epoch loss history for plotting.
D_loss = []
G_loss = []
# Training loop.
# BUG FIX: save_image/torch.save below fail if the output directory is missing.
os.makedirs(PATH, exist_ok=True)
for epoch in tqdm(range(num_epochs)):
    # Halve both learning rates at epochs 5 and 10.
    if epoch == 5 or epoch == 10:
        G_opt.param_groups[0]['lr'] /= 2
        D_opt.param_groups[0]['lr'] /= 2

    # BUG FIX: the eval() calls below would otherwise leave BatchNorm layers
    # in eval mode for every epoch after the first; restore train mode here.
    D.train()
    G.train()

    # BUG FIX: train_GAN returns two values; the original unpacked three
    # (err_D, err_G, fretchet_dist) and crashed with a ValueError — no FID
    # is actually computed in this script despite the InceptionV3 import.
    err_D, err_G = train_GAN(G, D, G_opt, D_opt, train_loader)
    D_loss.append(err_D)
    G_loss.append(err_G)

    # Sample a progress grid and checkpoint every epoch
    # (the original's `epoch % 1 == 0` condition was always true).
    # BUG FIX: switch to eval mode BEFORE sampling so BatchNorm uses its
    # running statistics for the rendered images.
    D.eval()
    G.eval()
    with torch.no_grad():
        out_imgs = G(fixed_noise.to(device), fixed_label.to(device))
        out_imgs2 = G(fixed_noise2.to(device), fixed_label.to(device))  # NOTE(review): unused
    save_image(out_imgs, f"{PATH}{epoch}.png")
    torch.save(D.state_dict(), f'{PATH}discriminator_cDCGAN_with_fid.pth')
    torch.save(G.state_dict(), f'{PATH}generator_cDCGAN_with_fid.pth')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement