Advertisement
Guest User

Untitled

a guest
Feb 7th, 2019
499
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 15.05 KB | None | 0 0
  1. import torch.optim as optim
  2. import torch.nn.functional as F
  3. import torch.nn as nn
  4. import torch
  5. import torchvision
  6. import torchvision.transforms as transforms
  7. from torchvision.models import resnet18
  8. import pywt
  9.  
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import random
  13. import scipy.misc
  14.  
  15. from PIL import Image
  16.  
# Fixed RNG seeds for all sources of randomness (torch CPU/GPU, numpy,
# stdlib random) plus deterministic cuDNN, so runs are reproducible.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
random.seed(0)
torch.backends.cudnn.deterministic = True

# Plain tensor conversion, no normalization (not referenced later in this file).
transform_data = transforms.Compose([
    transforms.ToTensor(),
])

# Tensor conversion + per-channel normalization mapping [0, 1] -> [-1, 1];
# matches the tanh output range of the generator below.
transform_data_G = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
  32. # for DWT
  33. def transform_RGB2YCBCR(img):        
  34.     out_y = []
  35.     out_cb = []
  36.     out_cr = []
  37.     for i in range(0, img.size()[0]):
  38.         img_t = transforms.ToPILImage()(img[i,:,:,:])
  39.         out = img_t.convert('YCbCr')
  40.         out_yt, out_cbt, out_crt = out.split()        
  41.         out_y.append(transforms.ToTensor()(out_yt))
  42.         out_cb.append(transforms.ToTensor()(out_cbt))
  43.         out_cr.append(transforms.ToTensor()(out_crt))    
  44.     return torch.stack(out_y), torch.stack(out_cb), torch.stack(out_cr)
  45.  
  46. # for IDWT
  47. def transform_YCBCR2RGB(y, cb, cr):
  48.     out = []
  49.     for i in range(0, y.size()[0]):
  50.         img_yt = transforms.ToPILImage()(y[i,:,:,:])
  51.         img_cbt = transforms.ToPILImage()(cb[i,:,:,:])
  52.         img_crt = transforms.ToPILImage()(cr[i,:,:,:])
  53.         out_t = Image.merge('YCbCr',[img_yt, img_cbt, img_crt]).convert('RGB')
  54.         out.append(transforms.ToTensor()(out_t))
  55.     return torch.stack(out)
  56.  
  57. # generate low-resolutional images for cifar10
  58. def transform_LR(img):          
  59.     out = img.resize((8, 8))      
  60.     out = out.resize((32,32))
  61.     img = img.resize((32,32))    
  62.     return transform_data_G(out), transform_data_G(img)
  63.  
# resize image tensor batch for pre-trained resnet input (32x32 -> 224x224)
def transform_224(img):
    # NOTE(review): np.uint8 on a float array truncates and wraps; the tensors
    # fed in here are tanh/Normalize outputs roughly in [-1, 1], so negative
    # values wrap to ~255 and positives collapse to 0 — confirm this cast is
    # intended. Also, detach() cuts the autograd graph: losses computed on
    # the resized copies cannot backpropagate into whatever produced `img`.
    npimg = np.uint8(img.cpu().detach().numpy())
    out = []
    for i in range(0, img.size()[0]):
        # NOTE(review): ToPILImage expects numpy input as H x W x C, but
        # npimg[i] is C x H x W — verify against the torchvision version in use.
        img_t = transforms.ToPILImage()(npimg[i,:,:,:])
        out_t = transforms.functional.resize(img_t, size = (224, 224))
        out.append(transforms.ToTensor()(out_t))

    return torch.stack(out)
  74.  
  75. # 2D discrete wavelet transform
  76. def transform_DWT(img):
  77.     npimg = img.numpy()
  78.     coeffs2 = pywt.dwt2(npimg, 'haar')
  79.     LL, (LH, HL, HH) = coeffs2
  80.     out = np.concatenate((LL, LH, HL, HH), axis=1)
  81.     return torch.from_numpy(out).float()
  82.  
  83. # IDWT for input frequential images
  84. def transform_IDWT(img):
  85.     npimg = img.numpy()
  86.     npimg = (npimg[:, 0, :, :], (npimg[:, 1, :, :],
  87.                                  npimg[:, 2, :, :], npimg[:, 3, :, :]))
  88.     out = pywt.idwt2(npimg, 'haar')
  89.     return torch.from_numpy(out).float()
  90.  
  91. # designed for calculating the MSE loss for reconstructed image by IDWT (has 3 channels : HH, HL, LH)
  92. # input : (size of minibatch) * channel * width * height
  93. # target : (size of minibatch) * channel * width * height
  94. class myMSELoss(torch.nn.Module):
  95.     def __init__(self):
  96.         super(myMSELoss,self).__init__()      
  97.     def forward(self, inp, tar):
  98.         loss = torch.sum((inp-tar) ** 2).data / (inp.size()[0] * inp.size()[1] * inp.size()[2] * inp.size()[3])
  99.         return loss
  100.     def backward(self, grad_output):
  101.         return grad_input, None
  102.  
  103. # save image to file
  104. def saveImg(img, outFileName):
  105.     img = img / 2 + 0.5
  106.     npimg = img.numpy()
  107.     scipy.misc.imsave(outFileName, np.transpose(npimg, (1, 2, 0)))
  108.  
  109. # weight initialization
  110. def weights_init_normal(m):
  111.     classname = m.__class__.__name__
  112.     if classname.find('Conv') != -1:
  113.         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
  114.     elif classname.find('BatchNorm2d') != -1:
  115.         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
  116.         torch.nn.init.constant_(m.bias.data, 0.0)
  117.  
  118. # device check
  119. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  120. print(device)
  121.  
  122. #load dataset, similar to original cifar10 dataset
  123. #train/airplane, ... train/truck
  124. #test/airplane, ... test/truck
  125.  
  126. trainset = torchvision.datasets.ImageFolder(
  127.     root='cifar10_aug_32/train', transform=transform_LR)
  128. trainloader = torch.utils.data.DataLoader(
  129.     trainset, batch_size=5, shuffle=True, num_workers=0, worker_init_fn=np.random.seed(0))
  130.  
  131. testset = torchvision.datasets.ImageFolder(
  132.     root='cifar10_aug_32/test', transform=transform_LR)
  133. testloader = torch.utils.data.DataLoader(
  134.     testset, batch_size=100, shuffle=True, num_workers=0, worker_init_fn=np.random.seed(0))
  135.  
  136. # A network for image enhancing
  137. # input : RGB image
  138. # output : RGB image, same size as input
  139. class Generator(nn.Module):
  140.     def __init__(self):
  141.         super(Generator, self).__init__()
  142.         self.L1 = nn.Sequential(
  143.             nn.BatchNorm2d(3),
  144.             nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
  145.             nn.LeakyReLU(0.2, inplace=True),
  146.             nn.Dropout2d(0.5)
  147.         )
  148.         self.L2 = nn.Sequential(
  149.             nn.BatchNorm2d(32),
  150.             nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
  151.             nn.LeakyReLU(0.2, inplace=True),
  152.             nn.Dropout2d(0.5)
  153.         )
  154.         self.L3 = nn.Sequential(
  155.             nn.BatchNorm2d(16),
  156.             nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
  157.             nn.LeakyReLU(0.2, inplace=True),
  158.             nn.Dropout2d(0.5)
  159.         )
  160.         self.L4 = nn.Sequential(
  161.             nn.BatchNorm2d(8),
  162.             nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
  163.             nn.LeakyReLU(0.2, inplace=True),
  164.             nn.Dropout2d(0.5)
  165.         )        
  166.         self.L5 = nn.Sequential(
  167.             nn.BatchNorm2d(16),
  168.             nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
  169.             nn.LeakyReLU(0.2, inplace=True),
  170.             nn.Dropout2d(0.5)
  171.         )
  172.         self.L3d = nn.Sequential(
  173.             nn.BatchNorm2d(8),
  174.             nn.Conv2d(8, 3, kernel_size=3, stride=1, padding=1),
  175.             nn.Tanh()
  176.         )
  177.         self.L4d = nn.Sequential(
  178.             nn.BatchNorm2d(16),
  179.             nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1),
  180.             nn.Tanh()
  181.         )
  182.         self.L5d = nn.Sequential(
  183.             nn.BatchNorm2d(32),
  184.             nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
  185.             nn.Tanh()
  186.         )
  187.         self.L1_32 = nn.Sequential(
  188.             nn.BatchNorm2d(3),
  189.             nn.Conv2d(3, 32, kernel_size=32, stride=32, padding=0),
  190.             nn.LeakyReLU(0.2, inplace=True),
  191.             nn.Dropout2d(0.5),
  192.             nn.BatchNorm2d(32),
  193.             nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
  194.             nn.Tanh()
  195.         )
  196.         self.L1_16 = nn.Sequential(
  197.             nn.BatchNorm2d(3),
  198.             nn.Conv2d(3, 32, kernel_size=16, stride=16, padding=0),
  199.             nn.LeakyReLU(0.2, inplace=True),
  200.             nn.Dropout2d(0.5),
  201.             nn.BatchNorm2d(32),
  202.             nn.Conv2d(32, 3, kernel_size=2, stride=2, padding=0),
  203.             nn.Tanh()
  204.         )
  205.         self.L1_8 = nn.Sequential(
  206.             nn.BatchNorm2d(3),
  207.             nn.Conv2d(3, 32, kernel_size=8, stride=8, padding=0),
  208.             nn.LeakyReLU(0.2, inplace=True),
  209.             nn.Dropout2d(0.5),
  210.             nn.BatchNorm2d(32),        
  211.             nn.Conv2d(32, 3, kernel_size=4, stride=4, padding=0),
  212.             nn.Tanh()
  213.         )      
  214.        
  215.     def forward(self, img):
  216.         out = self.L1(img)
  217.         out = self.L2(out)
  218.         out_L3 = self.L3(out)
  219.         out_L4 = self.L4(out_L3)
  220.         out_L5 = self.L5(out_L4)
  221.        
  222.         out_L3d = self.L3d(out_L3)
  223.         out_L4d = self.L4d(out_L4)
  224.         out_L5d = self.L5d(out_L5)
  225.  
  226.         out_L1_32 = self.L1_32(img)
  227.         out_L1_16 = self.L1_16(img)
  228.         out_L1_8 = self.L1_8(img)
  229.        
  230.         out_L6 = (out_L3d + out_L4d + out_L5d) / 3
  231.         out_context = (out_L1_32 + out_L1_16 + out_L1_8) / 3
  232.         return img + out_L6 + out_context
  233.  
  234. # Discriminator & Classifier
  235. # using pre-trained resnet18
  236. # delete last layer and get 512-dim fc output
  237. # add a fc-256 layer
  238. # append fc-10 layer to classifier and fc-1 layer to validate image (like GAN) to fc-256 layer
  239. # resnet18(fc-512) - (fc-256) - (fc-10) : classifier (0~9 class label, one-hot vector)
  240. #                             - (fc-1)  : discriminator (real/fake, 1-dim output)
  241. class Discriminator(nn.Module):
  242.     def __init__(self):
  243.         super(Discriminator, self).__init__()
  244.  
  245.         self.model_ft = resnet18(pretrained=True)
  246.         num_ftrs = self.model_ft.fc.in_features
  247.         self.modules = list(self.model_ft.children())[:-1] # delete last layer
  248.         self.model_ft = nn.Sequential(*self.modules)
  249.         self.fc_feature = nn.Sequential(nn.Linear(num_ftrs, 256), nn.ReLU(), nn.Dropout()) # add fc-256 layer
  250.         self.fc_layer = nn.Sequential(nn.Linear(256, 10), nn.Softmax()) # add fc-10
  251.         self.val_layer = nn.Sequential(nn.Linear(256, 1), nn.Sigmoid()) # add fc-1
  252.  
  253.     def forward(self, img):
  254.         out = self.model_ft(img)      
  255.         out = out.view(out.shape[0],-1)
  256.         out = self.fc_feature(out)
  257.         label = self.fc_layer(out)
  258.         validity = self.val_layer(out)
  259.         return validity, label # two outputs
  260.  
# GPU env.
# NOTE(review): .cuda() hard-requires a CUDA device even though `device`
# above falls back to CPU; the following .to(device) calls are then redundant.
generator = Generator().cuda()
generator.to(device)
discriminator = Discriminator().cuda()
discriminator.to(device)

# criterions: pixel-space MSE for the generator; cross-entropy over class
# labels and BCE over real/fake validity for the discriminator
criterion_g = nn.MSELoss().cuda()
criterion_d = nn.CrossEntropyLoss().cuda()
criterion_val = nn.BCELoss().cuda()

# weight initialize (DCGAN-style normal init; note apply() recurses, so this
# also overwrites the pre-trained resnet18 conv/batchnorm weights)
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# optimizers (Adam with default hyper-parameters)
optimizer_g = optim.Adam(generator.parameters())
optimizer_d = optim.Adam(discriminator.parameters())

# log file
f = open('cifar10_DG_RGB_J_ResNet18_ALL_OnlyGLoss.log', 'w')
  282.  
for epoch in range(100):  # loop over the dataset multiple times
    running_loss_d = 0.0
    running_loss_g = 0.0
    for i, data in enumerate(trainloader, 0): # take mini-batches
        imgs, labels = data
        imgLR, imgHR = imgs # (low-res, high-res) pairs built by transform_LR
        labels = labels.cuda()
        # real/fake targets for the discriminator's BCE loss
        # NOTE(review): torch.autograd.Variable is deprecated; tensors carry
        # autograd state directly in modern PyTorch.
        valid = torch.autograd.Variable(torch.cuda.FloatTensor(labels.size()[0]).fill_(1.0), requires_grad=False)
        fake = torch.autograd.Variable(torch.cuda.FloatTensor(labels.size()[0]).fill_(0.0), requires_grad=False)
        # target data: high-frequency DWT sub-bands (LH, HL, HH — channel 0
        # is LL, sliced off) of the Y channel of the high-res images
        imgHR_Y, _, _ = transform_RGB2YCBCR(imgHR)
        imgHR_W = transform_DWT(imgHR_Y)[:,1:,:,:]
        # zero_grad for generator
        optimizer_g.zero_grad()
        # get enhanced image
        imgSR = generator(imgLR.cuda())
        # NOTE(review): the PIL round-trip in transform_RGB2YCBCR detaches
        # imgSR_W from the graph, so the frequency-domain term below cannot
        # backpropagate into the generator.
        imgSR_Y, _, _ = transform_RGB2YCBCR(imgSR.cpu())
        imgSR_W = transform_DWT(imgSR_Y)[:,1:,:,:]
        # resize for pre-trained resnet18 (32x32x3 --> 224x224x3)
        # NOTE(review): transform_224 calls detach(), so `validity` below has
        # no gradient path back to the generator either — only the pixel MSE
        # term actually trains it.
        imgSRd = transform_224(imgSR)
        # get discriminator output
        validity, pred_label = discriminator(imgSRd.cuda())
        # generator loss = pixel MSE(imgSR, imgHR)
        #                + MSE over high-frequency sub-bands (imgSR_W, imgHR_W)
        #                + adversarial loss to fool the discriminator
        # BUG: myMSELoss is a Module class; this calls its constructor with
        # two tensors (TypeError). Intended: myMSELoss()(imgSR_W, imgHR_W).
        loss_g = criterion_g(imgHR, imgSR.cpu()) + myMSELoss(imgSR_W, imgHR_W) + criterion_val(validity.cpu(), valid.cpu())
        # debug scaffolding: snapshot a parameter before/after the step and
        # print its gradient to check that weights actually update
        a = list(generator.parameters())[0].clone()
        loss_g.backward()
        optimizer_g.step()
        b = list(generator.parameters())[0].clone()
        print(list(generator.parameters())[0].grad)
        optimizer_d.zero_grad()

        imgHRd = transform_224(imgHR)
        imgLRd = transform_224(imgLR)
        # discriminator outputs for low-res, high-res, and generated images
        val_LR, aux_LR = discriminator(imgLRd.cuda())
        val_HR, aux_HR = discriminator(imgHRd.cuda())
        val_SR, aux_SR = discriminator(imgSRd.cuda())
        # discriminator loss: classification + real/fake terms per image kind;
        # low-res and generated images are both labeled "fake", high-res "real"
        loss_d_LR = (criterion_d(aux_LR, labels) + criterion_val(val_LR, fake)) / 2
        loss_d_HR = (criterion_d(aux_HR, labels) + criterion_val(val_HR, valid)) / 2
        loss_d_SR = (criterion_d(aux_SR, labels) + criterion_val(val_SR, fake)) / 2
        loss_d = (loss_d_HR + loss_d_LR + loss_d_SR) / 3
        loss_d.backward()
        optimizer_d.step()

        running_loss_d += loss_d.item()
        running_loss_g += loss_g.item()
        if i % 10 == 9:    # print every 10 mini-batches
            disp_str = '[%d, %5d] d_loss: %.3f g_loss: %.3f' % (epoch + 1, i + 1, running_loss_d / 10, running_loss_g / 10)
            print(disp_str)
            f.write(disp_str)
            f.write('\n')
            running_loss_d = 0.0
            running_loss_g = 0.0

        if i % 10 == 9: # save a sample grid from one test batch every 10 mini-batches
            for it, data in enumerate(testloader, 0):
                imgs, labels = data
                imgLR, imgHR = imgs
                imgSR = generator(imgLR.cuda())

                # rows: low-res inputs, high-res targets, generator outputs
                outImg = torch.cat(
                    (imgLR, imgHR, imgSR.cpu().detach()), dim=0)
                outImg = torchvision.utils.make_grid(outImg, normalize=True, nrow=20)

                baseName = 'results/cifar10_DG_RGB_J_ResNet18_ALL_OnlyGLoss_'
                outFileName = baseName + 'epoch_' + \
                    str(epoch) + '_' + str(i+1) + '.jpg'
                saveImg(outImg, outFileName)

                break  # only the first test batch is sampled
    # save weights (checkpoint per epoch)
    baseName = 'results/checkpoint_cifar10_DG_RGB_J_ResNet18_ALL_OnlyGLoss_'
    outFileName = baseName + 'epoch_' + str(epoch)

    torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_d_state_dict': optimizer_d.state_dict(),
            'optimizer_g_state_dict': optimizer_g.state_dict(),
            # BUG: duplicate key — the second 'loss_d' overwrites the first,
            # so only loss_g is saved; the second key was surely meant to be
            # 'loss_g'.
            'loss_d': loss_d,
            'loss_d': loss_g,
    }, outFileName)

    running_corrects_LR = 0
    running_corrects_HR = 0
    running_corrects_SR = 0

    # test classification accuracy after each epoch
    # NOTE(review): the models are never switched to eval() here, so dropout
    # and batchnorm stay in training mode during evaluation — confirm intended.
    for it, data in enumerate(testloader, 0):
        print(it)
        imgs, labels = data
        labels = labels.cuda()
        imgLR, imgHR = imgs
        imgSR = generator(imgLR.cuda())

        imgLRd = transform_224(imgLR)
        imgHRd = transform_224(imgHR)
        imgSRd = transform_224(imgSR)

        val_LR, aux_LR = discriminator(imgLRd.cuda())
        val_HR, aux_HR = discriminator(imgHRd.cuda())
        val_SR, aux_SR = discriminator(imgSRd.cuda())
        # predicted class = argmax over the 10 class scores
        _, pred_LR = torch.max(aux_LR.data, 1)
        _, pred_HR = torch.max(aux_HR.data, 1)
        _, pred_SR = torch.max(aux_SR.data, 1)

        running_corrects_LR += torch.sum(pred_LR == labels.data)
        running_corrects_HR += torch.sum(pred_HR == labels.data)
        running_corrects_SR += torch.sum(pred_SR == labels.data)

    # correct-prediction counts (not percentages) over the whole test set
    disp_str = '[%d] LR : %d, HR : %d, SR : %d' % (epoch, running_corrects_LR.data, running_corrects_HR.data, running_corrects_SR.data)

    print(disp_str)
    f.write(disp_str)
    f.write('\n')

print('Finished Training')
f.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement