cDCGAN

a guest | Dec 14th, 2020
import torch
import torch.nn as nn
from torchsummary import summary
import torchvision.transforms as transforms
import torchvision.datasets as dset
from tqdm.autonotebook import tqdm
from torchvision.utils import save_image
from copy import deepcopy
from matplotlib import pyplot as plt
import numpy as np
from torch.autograd import Variable
import os

# The imports below are only needed for the optional FID evaluation sketched further down
from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d
from pytorch_fid.inception import InceptionV3


num_epochs = 1000
betas = (0.5, 0.999)
lr = 0.0002  # 1e-5

batch_size = 100
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
z_dim = 100        # latent space dimensionality
c_dim = 1          # image channels
label_dim = 10     # number of class labels
image_size = 32
beta1 = 0.5
PATH = "./generate/"
os.makedirs(PATH, exist_ok=True)   # save_image / torch.save below write into this folder

# MNIST dataset
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    # transforms.Normalize((0.5,), (0.5,)),  # would rescale inputs to [-1, 1], matching the generator's Tanh output
    ])

train_set = dset.MNIST(root='./mnist_data/',
                       train=True,
                       transform=transform,
                       download=True)

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)

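# A quick, optional sanity check on the loader: each batch should be a (batch_size, 1, 32, 32)
# float image tensor in [0, 1] plus a (batch_size,) tensor of integer digit labels.
sample_imgs, sample_labels = next(iter(train_loader))
print(sample_imgs.shape, sample_labels.shape, float(sample_imgs.min()), float(sample_imgs.max()))
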
# Generator model
class Generator(nn.Module):
    def __init__(self, z_dim, label_dim):
        super(Generator, self).__init__()
        self.input_x = nn.Sequential(
            # latent vector z: (z_dim, 1, 1) -> (256, 4, 4)
            nn.ConvTranspose2d(z_dim, 64*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64*4),
            nn.ReLU(True),
        )
        self.input_y = nn.Sequential(
            # one-hot label y: (label_dim, 1, 1) -> (256, 4, 4)
            nn.ConvTranspose2d(label_dim, 64*4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64*4),
            nn.ReLU(True),
        )
        self.concat = nn.Sequential(
            # concatenated (512, 4, 4) -> (256, 8, 8)
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4),
            nn.ReLU(True),

            # (256, 8, 8) -> (128, 16, 16)
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2),
            nn.ReLU(True),

            # (128, 16, 16) -> (1, 32, 32)
            nn.ConvTranspose2d(64*2, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x, y):
        # x: latent code, y: one-hot label map (both with 1x1 spatial size)
        x = self.input_x(x)
        y = self.input_y(y)
        out = torch.cat([x, y], dim=1)
        out = self.concat(out)
        return out


# Discriminator model
class Discriminator(nn.Module):

    def __init__(self, nc=1, label_dim=10):
        super(Discriminator, self).__init__()

        self.input_x = nn.Sequential(
            # image: (nc, 32, 32) -> (64, 16, 16)
            nn.Conv2d(nc, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.input_y = nn.Sequential(
            # label map: (label_dim, 32, 32) -> (64, 16, 16)
            nn.Conv2d(label_dim, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.concat = nn.Sequential(
            # concatenated (128, 16, 16) -> (256, 8, 8)
            nn.Conv2d(64*2, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4),
            nn.LeakyReLU(0.2, inplace=True),

            # (256, 8, 8) -> (512, 4, 4)
            nn.Conv2d(64*4, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8),
            nn.LeakyReLU(0.2, inplace=True),

            # (512, 4, 4) -> (1, 1, 1) real/fake probability
            nn.Conv2d(64*8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x, y):
        # x: image, y: one-hot label broadcast over the 32x32 plane
        x = self.input_x(x)
        y = self.input_y(y)
        out = torch.cat([x, y], dim=1)
        out = self.concat(out)
        return out

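# A minimal shape check (not part of the training flow): with a 100-dim latent and a 10-channel
# one-hot condition, the generator should emit 1x32x32 images and the discriminator one
# probability per sample. torchsummary's summary() is imported above and could be used instead,
# but a plain forward pass keeps the check simple.
def check_shapes():
    g, d = Generator(z_dim, label_dim), Discriminator(c_dim, label_dim)
    z = torch.randn(2, z_dim, 1, 1)                          # latent codes
    y_g = torch.randn(2, label_dim, 1, 1)                    # condition for G (1x1 spatial size)
    y_d = torch.randn(2, label_dim, image_size, image_size)  # condition for D (full resolution)
    fake = g(z, y_g)
    score = d(fake, y_d)
    print(fake.shape, score.shape)   # expected: (2, 1, 32, 32) and (2, 1, 1, 1)
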
# DCGAN-style weight initialisation: N(0, 0.02) for conv weights, N(1, 0.02) for batch-norm scales
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


def train_GAN(G, D, G_opt, D_opt, dataset):
    for data, label in tqdm(dataset):

        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ## Train with all-real batch
        D_opt.zero_grad()

        x_real = data.to(device)
        y_real = torch.ones(batch_size, ).to(device)
        c_real = fill[label].to(device)               # per-pixel one-hot condition for D

        y_real_predict = D(x_real, c_real).squeeze()  # (-1, 1, 1, 1) -> (-1, )
        d_real_loss = criterion(y_real_predict, y_real)
        d_real_loss.backward()

        ## Train with all-fake batch
        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        noise_label = (torch.rand(batch_size, 1) * label_dim).type(torch.LongTensor).squeeze()
        noise_label_onehot = onehot[noise_label].to(device)

        x_fake = G(noise, noise_label_onehot)
        y_fake = torch.zeros(batch_size, ).to(device)
        c_fake = fill[noise_label].to(device)

        y_fake_predict = D(x_fake.detach(), c_fake).squeeze()  # detach: the D update must not backprop into G
        d_fake_loss = criterion(y_fake_predict, y_fake)
        d_fake_loss.backward()
        D_opt.step()

        # (2) Update G network: maximize log(D(G(z)))
        G_opt.zero_grad()

        noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
        noise_label = (torch.rand(batch_size, 1) * label_dim).type(torch.LongTensor).squeeze()
        noise_label_onehot = onehot[noise_label].to(device)

        x_fake = G(noise, noise_label_onehot)
        c_fake = fill[noise_label].to(device)

        y_fake_predict = D(x_fake, c_fake).squeeze()
        g_loss = criterion(y_fake_predict, y_real)    # generator target: all-ones (reuse y_real)
        g_loss.backward()
        G_opt.step()

        err_D = d_fake_loss.item() + d_real_loss.item()
        err_G = g_loss.item()

    # losses from the last batch of the epoch
    return err_D, err_G


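# Sketch of the FID evaluation that the imports at the top (scipy.linalg, adaptive_avg_pool2d,
# pytorch_fid's InceptionV3) point to; it is not wired into train_GAN above. Assumptions:
# pytorch_fid exposes InceptionV3.BLOCK_INDEX_BY_DIM as in recent releases, images are passed
# in [0, 1] (rescale Tanh output with (x + 1) / 2 first), and the single MNIST channel is
# repeated to the 3 channels Inception expects.
inception = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).to(device).eval()  # downloads FID Inception weights on first use

def inception_activations(imgs):
    # imgs: (N, 1, H, W) in [0, 1]  ->  (N, 2048) pooled Inception features
    with torch.no_grad():
        feats = inception(imgs.repeat(1, 3, 1, 1).to(device))[0]
    if feats.size(2) != 1 or feats.size(3) != 1:
        feats = adaptive_avg_pool2d(feats, output_size=(1, 1))
    return feats.squeeze(3).squeeze(2).cpu().numpy()

def frechet_distance(real_imgs, fake_imgs):
    # FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrt(S_r @ S_f))
    act_r, act_f = inception_activations(real_imgs), inception_activations(fake_imgs)
    mu_r, mu_f = act_r.mean(axis=0), act_f.mean(axis=0)
    sigma_r, sigma_f = np.cov(act_r, rowvar=False), np.cov(act_f, rowvar=False)
    covmean, _ = linalg.sqrtm(sigma_r.dot(sigma_f), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    diff = mu_r - mu_f
    return float(diff.dot(diff) + np.trace(sigma_r + sigma_f - 2.0 * covmean))
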
# Models
D = Discriminator(c_dim, label_dim).to(device)
D.apply(weights_init)

G = Generator(z_dim, label_dim).to(device)
G.apply(weights_init)

# Discriminator uses half the generator's learning rate
D_opt = torch.optim.Adam(D.parameters(), lr=lr / 2, betas=(beta1, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, 0.999))

# Loss function
criterion = torch.nn.BCELoss()

# Fixed noise and labels used to monitor samples during training
fixed_noise = torch.randn(100, z_dim, 1, 1)
fixed_noise2 = torch.randn(100, z_dim, 1, 1)

labels = torch.LongTensor([i for i in range(10) for _ in range(10)]).to(device)  # ten of each digit: 0...0 1...1 ... 9...9
fixed_c = labels.reshape(100, 1).float()

one_hot = nn.functional.one_hot(labels, num_classes=10)  # fixed_c encoded as one-hot
fixed_label = one_hot.reshape(100, 10, 1, 1).float()

# Lookup tables for conditioning: onehot[k] feeds the generator, fill[k] feeds the discriminator
onehot_before_cod = torch.LongTensor([i for i in range(10)]).to(device)  # 0123456789
onehot = nn.functional.one_hot(onehot_before_cod, num_classes=10)
onehot = onehot.reshape(10, 10, 1, 1).float()
fill = onehot.repeat(1, 1, image_size, image_size)

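# How the two lookup tables above are used: onehot[k] is a (10, 1, 1) one-hot vector fed to the
# generator next to the latent code, while fill[k] spreads the same one-hot vector over a full
# 32x32 plane so it can be concatenated with an image inside the discriminator. For a batch of
# labels drawn as in train_GAN:
example_labels = (torch.rand(4, 1) * label_dim).type(torch.LongTensor).squeeze()
print(onehot[example_labels].shape)  # torch.Size([4, 10, 1, 1])   -> generator condition
print(fill[example_labels].shape)    # torch.Size([4, 10, 32, 32]) -> discriminator condition
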
D_loss = []
G_loss = []

for epoch in tqdm(range(num_epochs)):
    # halve both learning rates at epochs 5 and 10
    if epoch == 5 or epoch == 10:
        G_opt.param_groups[0]['lr'] /= 2
        D_opt.param_groups[0]['lr'] /= 2

    # training
    err_D, err_G = train_GAN(G, D, G_opt, D_opt, train_loader)

    D_loss.append(err_D)
    G_loss.append(err_G)

    # test: save a grid of samples from the fixed noise (every epoch here)
    if epoch % 1 == 0 or epoch + 1 == num_epochs:
        with torch.no_grad():
            out_imgs = G(fixed_noise.to(device), fixed_label.to(device))
            out_imgs2 = G(fixed_noise2.to(device), fixed_label.to(device))

        save_image(out_imgs, f"{PATH}{epoch}.png")


D.eval()
G.eval()
torch.save(D.state_dict(), f'{PATH}discriminator_cDCGAN_with_fid.pth')
torch.save(G.state_dict(), f'{PATH}generator_cDCGAN_with_fid.pth')

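# A short usage sketch for the saved weights: rebuild the generator, load the state dict that
# torch.save wrote above, and sample a grid of one chosen digit. The checkpoint path matches
# the generator file saved above.
G_loaded = Generator(z_dim, label_dim).to(device)
G_loaded.load_state_dict(torch.load(f'{PATH}generator_cDCGAN_with_fid.pth', map_location=device))
G_loaded.eval()

digit = 3
z = torch.randn(16, z_dim, 1, 1, device=device)
cond = onehot[torch.full((16,), digit, dtype=torch.long)]   # (16, 10, 1, 1) one-hot condition for the digit
with torch.no_grad():
    samples = G_loaded(z, cond)
save_image(samples, f"{PATH}digit_{digit}.png", nrow=4)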