Advertisement
Guest User

Untitled

a guest
Mar 22nd, 2017
117
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 6.99 KB | None | 0 0
  1. from __future__ import print_function
  2. import argparse
  3. import os
  4. from math import log10
  5.  
  6. import torch
  7. import torch.nn as nn
  8. import torch.optim as optim
  9. from torch.autograd import Variable
  10. from torch.utils.data import DataLoader
  11. from models import G, D, weights_init
  12. from data import get_training_set, get_test_set
  13. import torch.backends.cudnn as cudnn
  14. import torchvision.utils as vutils
  15.  
  16. # Training settings
  17. parser = argparse.ArgumentParser(description='pix2pix-PyTorch-implementation')
  18. parser.add_argument('--dataset', required=True, help='facades')
  19. parser.add_argument('--batchSize', type=int, default=16, help='training batch size')
  20. parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
  21. parser.add_argument('--nEpochs', type=int, default=200, help='number of epochs to train for')
  22. parser.add_argument('--input_nc', type=int, default=3, help='input image channels')
  23. parser.add_argument('--output_nc', type=int, default=3, help='output image channels')
  24. parser.add_argument('--ngf', type=int, default=64, help='generator filters in first conv layer')
  25. parser.add_argument('--ndf', type=int, default=64, help='discriminator filters in first conv layer')
  26. parser.add_argument('--lr', type=float, default=0.0002, help='Learning Rate. Default=0.002')
  27. parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
  28. parser.add_argument('--cuda', action='store_true', help='use cuda?')
  29. parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
  30. parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
  31. parser.add_argument('--lamb', type=int, default=100, help='weight on L1 term in objective')
  32. parser.add_argument('--netG', default='', help="path to netG (to continue training)")
  33. parser.add_argument('--netD', default='', help="path to netD (to continue training)")
  34. opt = parser.parse_args()
  35.  
  36. print(opt)
  37.  
  38. if opt.cuda and not torch.cuda.is_available():
  39. raise Exception("No GPU found, please run without --cuda")
  40.  
  41. cudnn.benchmark = True
  42.  
  43. torch.manual_seed(opt.seed)
  44. if opt.cuda:
  45. torch.cuda.manual_seed(opt.seed)
  46.  
  47. print('===> Loading datasets')
  48. root_path = "dataset/"
  49. train_set = get_training_set(root_path + opt.dataset)
  50. test_set = get_test_set(root_path + opt.dataset)
  51. training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
  52. testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
  53.  
  54. print('===> Building model')
  55. if opt.netG:
  56. netG = torch.load(opt.netG)
  57. print('==> Loaded model for G.')
  58. else:
  59. netG = G(opt.input_nc, opt.output_nc, opt.ngf)
  60. netG.apply(weights_init)
  61.  
  62. if opt.netD:
  63. netG = torch.load(opt.netD)
  64. print('==> Loaded model for D.')
  65. else:
  66. netD = D(opt.input_nc, opt.output_nc, opt.ndf)
  67. netD.apply(weights_init)
  68.  
  69. criterion = nn.BCELoss()
  70. criterion_l1 = nn.L1Loss()
  71. criterion_mse = nn.MSELoss()
  72.  
  73. real_A = torch.FloatTensor(opt.batchSize, opt.input_nc, 256, 256)
  74. real_B = torch.FloatTensor(opt.batchSize, opt.output_nc, 256, 256)
  75. label = torch.FloatTensor(opt.batchSize)
  76. real_label = 1
  77. fake_label = 0
  78.  
  79. if opt.cuda:
  80. netD = netD.cuda()
  81. netG = netG.cuda()
  82. criterion = criterion.cuda()
  83. criterion_l1 = criterion_l1.cuda()
  84. criterion_mse = criterion_mse.cuda()
  85. real_A = real_A.cuda()
  86. real_B = real_B.cuda()
  87. label = label.cuda()
  88.  
  89.  
  90. real_A = Variable(real_A)
  91. real_B = Variable(real_B)
  92. label = Variable(label)
  93.  
  94. # setup optimizer
  95. optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
  96. optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
  97.  
  98.  
  99. def train(epoch):
  100. for iteration, batch in enumerate(training_data_loader, 1):
  101. ############################
  102. # (1) Update D network: maximize log(D(x,y)) + log(1 - D(x,G(x)))
  103. ###########################
  104. for p in netD.parameters(): # reset requires_grad
  105. p.requires_grad = True # they are set to False below in netG update
  106.  
  107. # train with real
  108. netD.zero_grad()
  109. real_a_cpu, real_b_cpu = batch[0], batch[1]
  110. real_A.data.resize_(real_a_cpu.size()).copy_(real_a_cpu)
  111. real_B.data.resize_(real_b_cpu.size()).copy_(real_b_cpu)
  112.  
  113. output = netD(torch.cat((real_A, real_B), 1))
  114. label.data.resize_(output.size()).fill_(real_label)
  115. err_d_real = criterion(output, label)
  116. err_d_real.backward()
  117. d_x_y = output.data.mean()
  118.  
  119. # train with fake
  120. fake_b = netG(real_A)
  121. output = netD(torch.cat((real_A, fake_b.detach()), 1))
  122. label.data.resize_(output.size()).fill_(fake_label)
  123. err_d_fake = criterion(output, label)
  124. err_d_fake.backward()
  125. d_x_gx = output.data.mean()
  126.  
  127. err_d = (err_d_real + err_d_fake) / 2.0
  128. optimizerD.step()
  129.  
  130. ############################
  131. # (2) Update G network: maximize log(D(x,G(x))) + L1(y,G(x))
  132. ###########################
  133. for p in netD.parameters():
  134. p.requires_grad = False # to avoid computation
  135. netG.zero_grad()
  136. output = netD(torch.cat((real_A, fake_b), 1))
  137. label.data.resize_(output.size()).fill_(real_label)
  138. err_g = criterion(output, label) + opt.lamb * criterion_l1(fake_b, real_B)
  139. err_g.backward()
  140. d_x_gx_2 = output.data.mean()
  141. optimizerG.step()
  142.  
  143. print("===> Epoch[{}]({}/{}): Loss_D: {:.4f} Loss_G: {:.4f} D(x): {:.4f} D(G(z)): {:.4f}/{:.4f}".format(
  144. epoch, iteration, len(training_data_loader), err_d.data[0], err_g.data[0], d_x_y, d_x_gx, d_x_gx_2))
  145.  
  146. if iteration % 200 == 1:
  147. vutils.save_image(real_a_cpu,
  148. '%s/input_samples_epoch_%03d_%03d.png' % ("out", epoch,iteration))
  149. # fake = netG(fixed_noise)
  150. vutils.save_image(fake_b.data,
  151. '%s/output_samples_epoch_%03d_%03d.png' % ("out", epoch, iteration))
  152.  
  153.  
  154. def test():
  155. avg_psnr = 0
  156. for batch in testing_data_loader:
  157. input, target = Variable(batch[0]), Variable(batch[1])
  158. if opt.cuda:
  159. input = input.cuda()
  160. target = target.cuda()
  161.  
  162. prediction = netG(input)
  163. mse = criterion_mse(prediction, target)
  164. psnr = 10 * log10(1 / mse.data[0])
  165. avg_psnr += psnr
  166. print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(testing_data_loader)))
  167.  
  168.  
  169. def checkpoint(epoch):
  170. if not os.path.exists("checkpoint"):
  171. os.mkdir("checkpoint")
  172. if not os.path.exists(os.path.join("checkpoint", opt.dataset)):
  173. os.mkdir(os.path.join("checkpoint", opt.dataset))
  174. net_g_model_out_path = "checkpoint/{}/netG_model_epoch_{}.pth".format(opt.dataset, epoch)
  175. net_d_model_out_path = "checkpoint/{}/netD_model_epoch_{}.pth".format(opt.dataset, epoch)
  176. torch.save(netG, net_g_model_out_path)
  177. torch.save(netD, net_d_model_out_path)
  178. print("Checkpoint saved to {}".format("checkpoint" + opt.dataset))
  179.  
  180. for epoch in range(1, opt.nEpochs + 1):
  181. train(epoch)
  182. #test()
  183. if epoch % 5 == 0:
  184. checkpoint(epoch)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement