SHARE
TWEET

Untitled

a guest Jun 25th, 2019 55 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # @Time    :  2019/4/15 9:56
  2. # @FileName:  train_gan.py
  3.  
  4. import os
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from torch.autograd import Variable
  9. from torch.utils.data import DataLoader
  10. import matplotlib.pyplot as plt
  11. import numpy as np
  12. import time
  13. import torch.optim as optim
  14. from PIL import Image
  15. from models import U_Net, AttU_Net, discriminator
  16. from DataProcess import myDataset, my_Testdata
  17. from metric import scores
  18. from metric import TVLoss
  19. from tensorboardX import SummaryWriter
  20. from utils import make_trainable, calc_gradient_penalty
  21.  
  22.  
  23. os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  24. DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  25. VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0]]
  26.  
  27.  
  28. def label2image(pred):
  29.     colormap = np.array(VOC_COLORMAP, dtype='uint8')
  30.     X = pred.astype('int32')
  31.     return colormap[X, :]
  32.  
  33.  
  34. def create_dataloader(data_txt, batch=1, mode=0):
  35.     """
  36.     [description]
  37.     Arguments:
  38.         data_txt {[txt path]} -- [description] The train / valid / test txt
  39.     Keyword Arguments:
  40.         batch {number} -- [description] (default: {1})
  41.         mode {number} -- [description] (default: {0 : train, 1 : valid})
  42.     Returns:
  43.         [object] -- [description]
  44.     """
  45.     file_txt = data_txt
  46.     data = myDataset(txt_path=file_txt, transformer=True, one_hot=True)
  47.     if mode == 0:
  48.         loader = DataLoader(data, num_workers=8, batch_size=batch, shuffle=True)
  49.     else:
  50.         loader = DataLoader(data, num_workers=8, batch_size=batch, shuffle=False)
  51.     return loader
  52.  
  53.  
  54. def test_loader(test_path, batch=1):
  55.     data = my_Testdata(txt_path=test_path)
  56.     loader = DataLoader(data, num_workers=8, batch_size=batch, shuffle=False)
  57.     return loader
  58.  
  59.  
  60. if __name__ == '__main__':
  61.     batch_size = 4
  62.     maxepochs = 100
  63.     gan_loss_percent = 0.03
  64.     epoch_lapse = 10
  65.     # threshold = 0.5
  66.     learning_rate = 0.001
  67.     # mutil_gpu = True
  68.     # device_ids = [0, 2]
  69.  
  70.     train_file_path = "./data/dataset/train.txt"
  71.     valid_file_path = "./data/dataset/valid.txt"
  72.     test_file_path = "./data/test/"
  73.  
  74.     # 构建MyDataset实例,此处得到的数据为tensor形式,然后结合batch返回train和vaild的数据加载器
  75.     train_data = create_dataloader(train_file_path, batch_size)
  76.     valid_data = create_dataloader(valid_file_path, batch_size, mode=1)
  77.  
  78.     # net.load_state_dict(torch.load('{0}{1}_{2}.pkl'.format("./save_model/", "unet", 99)))
  79.     # criterion = nn.CrossEntropyLoss().cuda()
  80.  
  81.     # loss function
  82.     adversarial_loss = torch.nn.MSELoss().cuda()
  83.     generator_loss = nn.CrossEntropyLoss().cuda()
  84.  
  85.     # Initialize generator and discriminator
  86.     generator = AttU_Net().cuda()
  87.     discriminator = discriminator().cuda()
  88.  
  89.     # Adversarial ground truths
  90.     one = torch.ones(1).cuda()
  91.     neg_one = one * -1
  92.     gan_one = one * -1 * gan_loss_percent
  93.     # valid = torch.ones(batch_size, 1).cuda()
  94.     # fake = torch.zeros(batch_size, 1).cuda()
  95.     # dataloader is above
  96.  
  97.     # optimizers
  98.     optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
  99.     optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))
  100.  
  101.     # Training
  102.     for epoch in range(maxepochs):
  103.         discriminator.train()
  104.         generator.train()
  105.         # train discriminator
  106.         make_trainable(discriminator, val=True)
  107.         make_trainable(generator, val=False)
  108.         Wasserstein_D, g_loss = 0, 0
  109.         for batch_idx, (data, target) in enumerate(train_data):
  110.             # configure input
  111.             target = torch.argmax(target, dim=1).unsqueeze(1).float()
  112.             # target = target.float()
  113.  
  114.             data, target = data.cuda(), target.cuda()
  115.             # print(target.type())
  116.             # 判别器网络与优化器梯度清零
  117.             discriminator.zero_grad()
  118.             optimizer_D.zero_grad()
  119.  
  120.             # real pair
  121.             real_pair = torch.cat((data, target), dim=1)
  122.             d_real = discriminator(real_pair)
  123.             d_real = d_real.mean()
  124.             d_real.backward(neg_one)
  125.  
  126.             # fake pair
  127.             fake_label = torch.argmax(F.softmax(generator(data), dim=1), dim=1).unsqueeze(1).float()
  128.             # print(fake_label.shape)
  129.             fake_pair = torch.cat((data, fake_label), dim=1)
  130.             # print(fake_pair.shape)
  131.             d_fake = discriminator(fake_pair)
  132.             d_fake = d_fake.mean()
  133.             d_fake.backward(one)
  134.  
  135.             gradient_penalty = calc_gradient_penalty(discriminator, real_pair, fake_pair)
  136.             gradient_penalty.backward()
  137.  
  138.             Wasserstein_D = d_real - d_fake
  139.             optimizer_D.step()
  140.             break
  141.             # print("Batch is {0},D loss:{1}".format(batch_idx + 1, Wasserstein_D.item()))
  142.  
  143.         make_trainable(discriminator, False)
  144.         make_trainable(generator, True)
  145.         for batch_idx, (data, target) in enumerate(train_data):
  146.             generator.zero_grad()
  147.             optimizer_G.zero_grad()
  148.  
  149.             # configure input
  150.             target = torch.argmax(target, dim=1).unsqueeze(1)
  151.             data, target = data.cuda(), target.cuda()
  152.  
  153.             # pred_labels = generator(data)
  154.             pred_labels = generator(data)
  155.             # pred_labels = torch.argmax(F.softmax(generator(data), dim=1), dim=1).float()
  156.  
  157.             # print(target.squeeze(1).shape)
  158.             # g_loss.backward(retain_graph=True)
  159.             g_loss = generator_loss(pred_labels, target.squeeze(1))
  160.             # g_loss = generator_loss(pred_labels, torch.argmax(target, dim=1))
  161.             # print(target.shape)
  162.             print(g_loss.grad)
  163.             g_loss.backward(retain_graph=True)
  164.             pred_lab = torch.argmax(F.softmax(generator(data), dim=1), dim=1).unsqueeze(1).float()
  165.             fa_pair = torch.cat((data, pred_lab), dim=1)
  166.             # print(data.shape, pred_labels.shape,pred_lab.shape)
  167.  
  168.             gd_fake = discriminator(fa_pair)
  169.  
  170.             gd_fake.mean()
  171.             # gan_loss = gan_loss*gan_loss_percent + g_loss
  172.             gd_fake.backward(gan_one)
  173.             # gan_loss.backward()
  174.             optimizer_G.step()
  175.  
  176.             print("Batch is {0},G loss:{0}".format(batch_idx + 1, g_loss.item()))
  177.             # print(g_loss.item())
  178.         print("{0}/{1},Wasserstein_D_loss:{2},generator_loss:{3}".format(epoch, maxepochs, Wasserstein_D, g_loss))
  179.  
  180. Traceback (most recent call last):
  181.   File "/home/wangzhao/project/master/train_gan.py", line 327, in <module>
  182.     gd_fake.backward(gan_one)
  183.   File "/home/zking/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/tensor.py", line 102, in backward
  184.     torch.autograd.backward(self, gradient, retain_graph, create_graph)
  185.   File "/home/zking/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
  186.     allow_unreachable=True)  # allow_unreachable flag
  187. RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
  188.  
  189. Process finished with exit code 1
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top