Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch.optim as optim
- import torch.nn.functional as F
- import torch.nn as nn
- import torch
- import torchvision
- import torchvision.transforms as transforms
- from torchvision.models import resnet18
- import pywt
- import matplotlib.pyplot as plt
- import numpy as np
- import random
- import scipy.misc
- from PIL import Image
# fixed seed
# Seed every RNG source (torch CPU, torch CUDA, numpy, python stdlib) so that
# runs are reproducible.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)
random.seed(0)
# Force deterministic cuDNN kernels (slower, but reproducible convolutions).
torch.backends.cudnn.deterministic = True
# Plain tensor conversion, values in [0, 1].
# NOTE(review): transform_data appears unused in this file -- confirm before removing.
transform_data = transforms.Compose([
    transforms.ToTensor(),
])
# Tensor conversion + per-channel normalization to [-1, 1]; this matches the
# Tanh output range of the generator's residual branches.
transform_data_G = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# for DWT
def transform_RGB2YCBCR(img):
    """Split a batch of RGB image tensors into stacked Y, Cb, Cr tensors.

    Each (C, H, W) slice is converted through PIL's RGB->YCbCr conversion;
    the three returned tensors each have shape (batch, 1, H, W).
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    channels = ([], [], [])
    for idx in range(img.size()[0]):
        ycbcr = to_pil(img[idx, :, :, :]).convert('YCbCr')
        for store, band in zip(channels, ycbcr.split()):
            store.append(to_tensor(band))
    return torch.stack(channels[0]), torch.stack(channels[1]), torch.stack(channels[2])
# for IDWT
def transform_YCBCR2RGB(y, cb, cr):
    """Merge batched Y, Cb, Cr channel tensors back into a batch of RGB tensors.

    Inverse of transform_RGB2YCBCR; returns a (batch, 3, H, W) tensor.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()
    merged = []
    for idx in range(y.size()[0]):
        planes = [to_pil(t[idx, :, :, :]) for t in (y, cb, cr)]
        rgb = Image.merge('YCbCr', planes).convert('RGB')
        merged.append(to_tensor(rgb))
    return torch.stack(merged)
# generate low-resolutional images for cifar10
def transform_LR(img):
    """Build a (low-res, high-res) training pair from one PIL image.

    The LR input is fabricated by downsampling to 8x8 and back up to 32x32;
    both outputs are normalized to [-1, 1] through transform_data_G.
    """
    low = img.resize((8, 8)).resize((32, 32))
    high = img.resize((32, 32))
    return transform_data_G(low), transform_data_G(high)
# resize image tensor for pre-trained resnet input
def transform_224(img):
    """Resize a batch of image tensors to 224x224 for the ResNet-18 input.

    NOTE(review): np.uint8 on a float tensor truncates fractional parts and
    wraps negative values, so inputs normalized to [-1, 1] lose nearly all
    detail here -- confirm whether a rescale to [0, 255] was intended.
    """
    npimg = np.uint8(img.cpu().detach().numpy())
    out = []
    for i in range(0, img.size()[0]):
        # NOTE(review): ToPILImage is fed a CxHxW numpy array; recent
        # torchvision expects HxWxC ndarrays -- verify against the installed
        # torchvision version.
        img_t = transforms.ToPILImage()(npimg[i,:,:,:])
        out_t = transforms.functional.resize(img_t, size = (224, 224))
        out.append(transforms.ToTensor()(out_t))
    return torch.stack(out)
# 2D discrete wavelet transform
def transform_DWT(img):
    """Single-level 2D Haar DWT over the last two dims of a batch tensor.

    The four sub-bands (LL, LH, HL, HH) are concatenated along the channel
    axis, so a (B, C, H, W) input becomes a (B, 4*C, H/2, W/2) float tensor.
    """
    LL, (LH, HL, HH) = pywt.dwt2(img.numpy(), 'haar')
    stacked = np.concatenate((LL, LH, HL, HH), axis=1)
    return torch.from_numpy(stacked).float()
# IDWT for input frequential images
def transform_IDWT(img):
    """Invert transform_DWT: channels 0..3 are treated as LL, LH, HL, HH."""
    bands = img.numpy()
    coeffs = (bands[:, 0, :, :],
              (bands[:, 1, :, :], bands[:, 2, :, :], bands[:, 3, :, :]))
    return torch.from_numpy(pywt.idwt2(coeffs, 'haar')).float()
# designed for calculating the MSE loss for reconstructed image by IDWT (has 3 channels : HH, HL, LH)
# input : (size of minibatch) * channel * width * height
# target : (size of minibatch) * channel * width * height
class myMSELoss(torch.nn.Module):
    """Mean squared error over all elements of a (B, C, H, W) tensor pair.

    Fixes vs. the original:
    - the sum was taken through ``.data``, which detached the result from the
      autograd graph so no gradient ever reached the generator; the loss is
      now computed on the live tensors.
    - the hand-written ``backward`` returned an undefined ``grad_input`` and
      is unnecessary: autograd differentiates ``forward`` automatically.
    """

    def __init__(self):
        super(myMSELoss, self).__init__()

    def forward(self, inp, tar):
        # torch.mean over all elements == sum of squared errors / numel,
        # matching the original's explicit division by B*C*H*W.
        return torch.mean((inp - tar) ** 2)
# save image to file
def saveImg(img, outFileName):
    """Save a [-1, 1]-normalized (C, H, W) image tensor to outFileName.

    Fix: ``scipy.misc.imsave`` was removed in SciPy 1.2, so the original
    crashes on any modern SciPy. Use ``matplotlib.pyplot.imsave`` instead,
    which this file already imports as ``plt``.
    """
    # Undo the (x - 0.5) / 0.5 normalization back to [0, 1] and clamp any
    # numerical overshoot so imsave accepts the float image.
    img = (img / 2 + 0.5).clamp(0.0, 1.0)
    npimg = img.numpy()
    # imsave expects an HxWxC array.
    plt.imsave(outFileName, np.transpose(npimg, (1, 2, 0)))
# weight initialization
def weights_init_normal(m):
    """Custom init for ``Module.apply``: N(0, 0.02) for conv weights;
    N(1, 0.02) weights and zero bias for BatchNorm2d. Other module types
    are left untouched."""
    kind = m.__class__.__name__
    if 'Conv' in kind:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm2d' in kind:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
# device check
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
#load dataset, similar to original cifar10 dataset
#train/airplane, ... train/truck
#test/airplane, ... test/truck
# NOTE: transform_LR returns an (LR, HR) tuple, so each loader batch yields
# ((imgLR, imgHR), labels).
trainset = torchvision.datasets.ImageFolder(
    root='cifar10_aug_32/train', transform=transform_LR)
# Fix: the original passed worker_init_fn=np.random.seed(0), which CALLS
# seed(0) once at construction time and hands DataLoader its return value
# (None). worker_init_fn must be a callable invoked once per worker process.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=5, shuffle=True, num_workers=0,
    worker_init_fn=lambda worker_id: np.random.seed(0))
testset = torchvision.datasets.ImageFolder(
    root='cifar10_aug_32/test', transform=transform_LR)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=True, num_workers=0,
    worker_init_fn=lambda worker_id: np.random.seed(0))
# A network for image enhancing
# input : RGB image
# output : RGB image, same size as input
class Generator(nn.Module):
    """Residual enhancement net: output = input + local correction + global context."""

    @staticmethod
    def _conv_block(in_ch, out_ch):
        # BN -> 3x3 conv -> LeakyReLU -> dropout; spatial size preserved.
        return nn.Sequential(
            nn.BatchNorm2d(in_ch),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5)
        )

    @staticmethod
    def _decode_block(in_ch):
        # BN -> 3x3 conv back to 3 RGB channels -> Tanh.
        return nn.Sequential(
            nn.BatchNorm2d(in_ch),
            nn.Conv2d(in_ch, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    @staticmethod
    def _context_block(k1, k2):
        # Collapse the whole image to a single spatial cell in two strided
        # convolutions (k1 * k2 == 32), then project to 3 channels; the 1x1
        # result is broadcast-added to the full-size image in forward().
        return nn.Sequential(
            nn.BatchNorm2d(3),
            nn.Conv2d(3, 32, kernel_size=k1, stride=k1, padding=0),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 3, kernel_size=k2, stride=k2, padding=0),
            nn.Tanh()
        )

    def __init__(self):
        super(Generator, self).__init__()
        # Attribute names and creation order are kept identical to the
        # original so state_dict checkpoints stay compatible.
        self.L1 = self._conv_block(3, 32)
        self.L2 = self._conv_block(32, 16)
        self.L3 = self._conv_block(16, 8)
        self.L4 = self._conv_block(8, 16)
        self.L5 = self._conv_block(16, 32)
        # Per-depth decoders back to RGB.
        self.L3d = self._decode_block(8)
        self.L4d = self._decode_block(16)
        self.L5d = self._decode_block(32)
        # Multi-scale global context branches (each ends at 1x1 spatially).
        self.L1_32 = self._context_block(32, 1)
        self.L1_16 = self._context_block(16, 2)
        self.L1_8 = self._context_block(8, 4)

    def forward(self, img):
        feat = self.L2(self.L1(img))
        f3 = self.L3(feat)
        f4 = self.L4(f3)
        f5 = self.L5(f4)
        # Average of the RGB decodes taken at three depths.
        local = (self.L3d(f3) + self.L4d(f4) + self.L5d(f5)) / 3
        # Average of the three 1x1 context maps, broadcast over H and W.
        context = (self.L1_32(img) + self.L1_16(img) + self.L1_8(img)) / 3
        return img + local + context
# Discriminator & Classifier
# using pre-trained resnet18
# delete last layer and get 512-dim fc output
# add a fc-256 layer
# append fc-10 layer to classifier and fc-1 layer to validate image (like GAN) to fc-256 layer
# resnet18(fc-512) - (fc-256) - (fc-10) : classifier (0~9 class label, one-hot vector)
# - (fc-1) : discriminator (real/fake, 1-dim output)
class Discriminator(nn.Module):
    """Shared ResNet-18 trunk with two heads: 10-way class probabilities and
    a scalar real/fake score."""

    def __init__(self):
        super(Discriminator, self).__init__()
        backbone = resnet18(pretrained=True)
        num_ftrs = backbone.fc.in_features
        # Fix: the original assigned the layer list to ``self.modules``, which
        # shadowed the inherited nn.Module.modules() method and left a stray
        # plain-list attribute on the module. Use a local variable instead.
        trunk_layers = list(backbone.children())[:-1]  # delete last fc layer
        self.model_ft = nn.Sequential(*trunk_layers)
        self.fc_feature = nn.Sequential(nn.Linear(num_ftrs, 256), nn.ReLU(), nn.Dropout())  # add fc-256 layer
        # dim=1 makes the (previously implicit, deprecation-warned) softmax
        # axis explicit for (B, 10) inputs.
        # NOTE(review): these probabilities are later fed to CrossEntropyLoss,
        # which applies its own log-softmax -- confirm double softmax is intended.
        self.fc_layer = nn.Sequential(nn.Linear(256, 10), nn.Softmax(dim=1))  # add fc-10
        self.val_layer = nn.Sequential(nn.Linear(256, 1), nn.Sigmoid())  # add fc-1

    def forward(self, img):
        out = self.model_ft(img)
        out = out.view(out.shape[0], -1)
        out = self.fc_feature(out)
        label = self.fc_layer(out)      # (B, 10) class probabilities
        validity = self.val_layer(out)  # (B, 1) real/fake score in (0, 1)
        return validity, label  # two outputs
# GPU env.
# Instantiate both networks on the GPU (.cuda() already moves them; the
# following .to(device) is redundant but harmless when CUDA is available).
generator = Generator().cuda()
generator.to(device)
discriminator = Discriminator().cuda()
discriminator.to(device)
# criterions for generator and discriminator
criterion_g = nn.MSELoss().cuda()            # pixel reconstruction loss
criterion_d = nn.CrossEntropyLoss().cuda()   # 10-way classification loss
criterion_val = nn.BCELoss().cuda()          # real/fake adversarial loss
# weight initialize
# NOTE(review): applying weights_init_normal to the discriminator
# re-randomizes the pretrained ResNet-18 conv/BN weights ('Conv' matches
# Conv2d inside the backbone) -- confirm that discarding the pretrained
# weights is intended.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
# optimizers
optimizer_g = optim.Adam(generator.parameters())
optimizer_d = optim.Adam(discriminator.parameters())
# log file
# NOTE(review): opened without a context manager and closed only at the very
# end of the script; a crash mid-training can lose buffered log lines.
f = open('cifar10_DG_RGB_J_ResNet18_ALL_OnlyGLoss.log', 'w')
# Training loop: each mini-batch updates the generator (pixel MSE + wavelet
# high-band MSE + adversarial loss), then the discriminator (classification
# + real/fake losses on LR, HR and generated images).
for epoch in range(100):  # loop over the dataset multiple times
    running_loss_d = 0.0
    running_loss_g = 0.0
    for i, data in enumerate(trainloader, 0):  # take mini-batches
        imgs, labels = data
        imgLR, imgHR = imgs  # (low-res, high-res) pair from transform_LR
        labels = labels.cuda()
        # variables for discriminator: adversarial targets, 1 = real, 0 = fake.
        valid = torch.autograd.Variable(torch.cuda.FloatTensor(labels.size()[0]).fill_(1.0), requires_grad=False)
        fake = torch.autograd.Variable(torch.cuda.FloatTensor(labels.size()[0]).fill_(0.0), requires_grad=False)
        # target data : High frequency factors of Y channel in High-resolutional images (HH, HL, LH)
        imgHR_Y, _, _ = transform_RGB2YCBCR(imgHR)
        imgHR_W = transform_DWT(imgHR_Y)[:, 1:, :, :]
        # ---- generator update ----
        optimizer_g.zero_grad()
        # get enhanced image
        imgSR = generator(imgLR.cuda())
        imgSR_Y, _, _ = transform_RGB2YCBCR(imgSR.cpu())
        imgSR_W = transform_DWT(imgSR_Y)[:, 1:, :, :]
        # resize for pre-trained resnet18 (32x32x3 --> 224x224x3)
        imgSRd = transform_224(imgSR)
        # get discriminator output
        validity, pred_label = discriminator(imgSRd.cuda())
        # generator loss = pixel MSE + wavelet-band MSE + adversarial loss.
        # Fix: the original called myMSELoss(imgSR_W, imgHR_W), passing the
        # tensors to the class CONSTRUCTOR (a TypeError at runtime);
        # instantiate the module, then call it on the tensors.
        loss_g = criterion_g(imgHR, imgSR.cpu()) + myMSELoss()(imgSR_W, imgHR_W) + criterion_val(validity.cpu(), valid.cpu())
        # check if weights are updated
        a = list(generator.parameters())[0].clone()
        loss_g.backward()
        optimizer_g.step()
        b = list(generator.parameters())[0].clone()
        print(list(generator.parameters())[0].grad)
        # ---- discriminator update ----
        optimizer_d.zero_grad()
        imgHRd = transform_224(imgHR)
        imgLRd = transform_224(imgLR)
        # get discriminator output for Low-resolutional images, High-resolutional images, and generated images
        val_LR, aux_LR = discriminator(imgLRd.cuda())
        val_HR, aux_HR = discriminator(imgHRd.cuda())
        val_SR, aux_SR = discriminator(imgSRd.cuda())
        # Per-source loss: classification + real/fake (LR and SR count as fake,
        # HR counts as real); averaged over the three sources.
        loss_d_LR = (criterion_d(aux_LR, labels) + criterion_val(val_LR, fake)) / 2
        loss_d_HR = (criterion_d(aux_HR, labels) + criterion_val(val_HR, valid)) / 2
        loss_d_SR = (criterion_d(aux_SR, labels) + criterion_val(val_SR, fake)) / 2
        loss_d = (loss_d_HR + loss_d_LR + loss_d_SR) / 3
        loss_d.backward()
        optimizer_d.step()
        running_loss_d += loss_d.item()
        running_loss_g += loss_g.item()
        if i % 10 == 9:  # print every 10 mini-batches
            disp_str = '[%d, %5d] d_loss: %.3f g_loss: %.3f' % (epoch + 1, i + 1, running_loss_d / 10, running_loss_g / 10)
            print(disp_str)
            f.write(disp_str)
            f.write('\n')
            running_loss_d = 0.0
            running_loss_g = 0.0
        if i % 10 == 9:  # save sample every 10 mini-batches in test images
            # Renamed loop variables: the original reused imgs/labels/data and
            # shadowed the training-loop names (harmless after break, but
            # confusing to read).
            for test_data in testloader:
                test_imgs, _ = test_data
                sampleLR, sampleHR = test_imgs
                sampleSR = generator(sampleLR.cuda())
                outImg = torch.cat(
                    (sampleLR, sampleHR, sampleSR.cpu().detach()), dim=0)
                outImg = torchvision.utils.make_grid(outImg, normalize=True, nrow=20)
                baseName = 'results/cifar10_DG_RGB_J_ResNet18_ALL_OnlyGLoss_'
                outFileName = baseName + 'epoch_' + \
                    str(epoch) + '_' + str(i+1) + '.jpg'
                saveImg(outImg, outFileName)
                break  # only the first test batch
    # save weights (one checkpoint per epoch)
    baseName = 'results/checkpoint_cifar10_DG_RGB_J_ResNet18_ALL_OnlyGLoss_'
    outFileName = baseName + 'epoch_' + str(epoch)
    # Fix: the original dict used the key 'loss_d' twice, so the generator
    # loss silently overwrote the discriminator loss in every checkpoint.
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_d_state_dict': optimizer_d.state_dict(),
        'optimizer_g_state_dict': optimizer_g.state_dict(),
        'loss_d': loss_d,
        'loss_g': loss_g,
    }, outFileName)
    running_corrects_LR = 0
    running_corrects_HR = 0
    running_corrects_SR = 0
    # test accuracy for each epoch
    for it, data in enumerate(testloader, 0):
        print(it)
        imgs, labels = data
        labels = labels.cuda()
        imgLR, imgHR = imgs
        imgSR = generator(imgLR.cuda())
        imgLRd = transform_224(imgLR)
        imgHRd = transform_224(imgHR)
        imgSRd = transform_224(imgSR)
        val_LR, aux_LR = discriminator(imgLRd.cuda())
        val_HR, aux_HR = discriminator(imgHRd.cuda())
        val_SR, aux_SR = discriminator(imgSRd.cuda())
        _, pred_LR = torch.max(aux_LR.data, 1)
        _, pred_HR = torch.max(aux_HR.data, 1)
        _, pred_SR = torch.max(aux_SR.data, 1)
        running_corrects_LR += torch.sum(pred_LR == labels.data)
        running_corrects_HR += torch.sum(pred_HR == labels.data)
        running_corrects_SR += torch.sum(pred_SR == labels.data)
    disp_str = '[%d] LR : %d, HR : %d, SR : %d' % (epoch, running_corrects_LR.data, running_corrects_HR.data, running_corrects_SR.data)
    print(disp_str)
    f.write(disp_str)
    f.write('\n')
print('Finished Training')
f.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement