# @Time : 2019/4/15 9:56
# @FileName: train_gan.py
import os
import time

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from PIL import Image
from tensorboardX import SummaryWriter

# discriminator is aliased so the class name does not collide with the
# instance created in __main__ below
from models import U_Net, AttU_Net, discriminator as Discriminator
from DataProcess import myDataset, my_Testdata
from metric import scores, TVLoss
from utils import make_trainable, calc_gradient_penalty

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0]]
def label2image(pred):
    """Map an array of class indices to an RGB image using VOC_COLORMAP."""
    colormap = np.array(VOC_COLORMAP, dtype='uint8')
    X = pred.astype('int32')
    return colormap[X, :]
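# Usage sketch for label2image (the shapes here are illustrative assumptions,
# not part of the original script): an (H, W) integer array of class indices
# in [0, len(VOC_COLORMAP)) maps to an (H, W, 3) uint8 RGB image, e.g.
#   mask = np.array([[0, 1], [2, 0]])
#   rgb = label2image(mask)  # rgb.shape == (2, 2, 3); rgb[0, 1] == [128, 0, 0]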
def create_dataloader(data_txt, batch=1, mode=0):
    """Build a DataLoader from a txt file listing the dataset samples.

    Arguments:
        data_txt {str} -- path to the train / valid / test txt file

    Keyword Arguments:
        batch {int} -- batch size (default: 1)
        mode {int} -- 0: train (shuffled), 1: valid (unshuffled) (default: 0)

    Returns:
        DataLoader -- loader over a myDataset built from data_txt
    """
    data = myDataset(txt_path=data_txt, transformer=True, one_hot=True)
    shuffle = (mode == 0)
    return DataLoader(data, num_workers=8, batch_size=batch, shuffle=shuffle)
def test_loader(test_path, batch=1):
    data = my_Testdata(txt_path=test_path)
    return DataLoader(data, num_workers=8, batch_size=batch, shuffle=False)
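# Reference sketch of a standard WGAN-GP gradient penalty with the same call
# signature as utils.calc_gradient_penalty used below. The author's utils.py
# is not shown, so this is an assumption about its behavior, kept here purely
# as documentation (commented out so it does not shadow the real import):
#
# def calc_gradient_penalty(netD, real_data, fake_data, lambda_gp=10.0):
#     # random interpolates between real and fake pairs
#     alpha = torch.rand(real_data.size(0), 1, 1, 1, device=real_data.device)
#     interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
#     d_interpolates = netD(interpolates)
#     grads = torch.autograd.grad(outputs=d_interpolates, inputs=interpolates,
#                                 grad_outputs=torch.ones_like(d_interpolates),
#                                 create_graph=True, retain_graph=True)[0]
#     grads = grads.view(grads.size(0), -1)
#     # penalize deviation of the critic's gradient norm from 1
#     return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()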
if __name__ == '__main__':
    batch_size = 4
    maxepochs = 100
    gan_loss_percent = 0.03
    epoch_lapse = 10
    # threshold = 0.5
    learning_rate = 0.001
    # mutil_gpu = True
    # device_ids = [0, 2]
    train_file_path = "./data/dataset/train.txt"
    valid_file_path = "./data/dataset/valid.txt"
    test_file_path = "./data/test/"

    # Build myDataset instances; the samples come back as tensors, and the
    # loaders batch them for training and validation.
    train_data = create_dataloader(train_file_path, batch_size)
    valid_data = create_dataloader(valid_file_path, batch_size, mode=1)
    # net.load_state_dict(torch.load('{0}{1}_{2}.pkl'.format("./save_model/", "unet", 99)))
    # criterion = nn.CrossEntropyLoss().cuda()

    # Loss functions
    adversarial_loss = torch.nn.MSELoss().cuda()
    generator_loss = nn.CrossEntropyLoss().cuda()

    # Initialize generator and discriminator
    generator = AttU_Net().cuda()
    discriminator = Discriminator().cuda()

    # Gradient weights for the WGAN-style backward() calls: the critic score
    # is pushed up on real pairs (-1) and down on fake pairs (+1), and the
    # generator's adversarial term is scaled by gan_loss_percent.
    one = torch.ones(1).cuda()
    neg_one = one * -1
    gan_one = one * -1 * gan_loss_percent
    # valid = torch.ones(batch_size, 1).cuda()
    # fake = torch.zeros(batch_size, 1).cuda()

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))
    # Training
    for epoch in range(maxepochs):
        discriminator.train()
        generator.train()

        # --- Train discriminator (critic) ---
        make_trainable(discriminator, val=True)
        make_trainable(generator, val=False)
        Wasserstein_D, g_loss = 0, 0
        for batch_idx, (data, target) in enumerate(train_data):
            # Configure input: collapse the one-hot target to a
            # single-channel class-index map
            target = torch.argmax(target, dim=1).unsqueeze(1).float()
            data, target = data.cuda(), target.cuda()

            # Zero the gradients of the discriminator network and optimizer
            discriminator.zero_grad()
            optimizer_D.zero_grad()

            # Real pair: image + ground-truth mask; maximize D(real) by
            # backpropagating with weight -1
            real_pair = torch.cat((data, target), dim=1)
            d_real = discriminator(real_pair)
            d_real = d_real.mean()
            d_real.backward(neg_one)

            # Fake pair: image + predicted mask; minimize D(fake) by
            # backpropagating with weight +1
            fake_label = torch.argmax(F.softmax(generator(data), dim=1), dim=1).unsqueeze(1).float()
            fake_pair = torch.cat((data, fake_label), dim=1)
            d_fake = discriminator(fake_pair)
            d_fake = d_fake.mean()
            d_fake.backward(one)

            # WGAN-GP penalty on interpolates between real and fake pairs
            gradient_penalty = calc_gradient_penalty(discriminator, real_pair, fake_pair)
            gradient_penalty.backward()

            Wasserstein_D = d_real - d_fake
            optimizer_D.step()
            break  # only one critic batch per epoch
            # print("Batch is {0},D loss:{1}".format(batch_idx + 1, Wasserstein_D.item()))
        # --- Train generator ---
        make_trainable(discriminator, False)
        make_trainable(generator, True)
        for batch_idx, (data, target) in enumerate(train_data):
            generator.zero_grad()
            optimizer_G.zero_grad()

            # Configure input: class-index target for CrossEntropyLoss
            target = torch.argmax(target, dim=1).unsqueeze(1)
            data, target = data.cuda(), target.cuda()

            # Segmentation loss on the raw logits
            pred_labels = generator(data)
            g_loss = generator_loss(pred_labels, target.squeeze(1))
            # g_loss = generator_loss(pred_labels, torch.argmax(target, dim=1))
            g_loss.backward(retain_graph=True)

            # Adversarial term. Note: argmax is not differentiable, so
            # pred_lab is detached from the generator's graph; gd_fake then
            # has no grad_fn, which is exactly what the traceback below
            # complains about.
            pred_lab = torch.argmax(F.softmax(generator(data), dim=1), dim=1).unsqueeze(1).float()
            fa_pair = torch.cat((data, pred_lab), dim=1)
            gd_fake = discriminator(fa_pair)
            gd_fake = gd_fake.mean()
            # gan_loss = gan_loss*gan_loss_percent + g_loss
            gd_fake.backward(gan_one)

            optimizer_G.step()
            print("Batch is {0},G loss:{1}".format(batch_idx + 1, g_loss.item()))
        print("{0}/{1},Wasserstein_D_loss:{2},generator_loss:{3}".format(epoch, maxepochs, Wasserstein_D, g_loss))
Traceback (most recent call last):
  File "/home/wangzhao/project/master/train_gan.py", line 327, in <module>
    gd_fake.backward(gan_one)
  File "/home/zking/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/zking/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Process finished with exit code 1
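Why it fails: torch.argmax is not differentiable, so pred_lab (and therefore fa_pair and gd_fake) is detached from the generator's computation graph; gd_fake has no grad_fn, and backward() raises the RuntimeError above. One possible workaround, assuming the discriminator keeps its single-channel label input as in the original pairs, is to replace the hard argmax with a soft, probability-weighted label map so gradients can flow from the critic back into the generator. A minimal sketch, not the author's verified fix:

# Inside the generator loop, in place of the argmax-based adversarial term:
probs = F.softmax(generator(data), dim=1)                      # (N, C, H, W), differentiable
class_idx = torch.arange(probs.size(1), device=probs.device).float()
soft_label = (probs * class_idx.view(1, -1, 1, 1)).sum(dim=1, keepdim=True)  # (N, 1, H, W)
fa_pair = torch.cat((data, soft_label), dim=1)                 # same shape as before
gd_fake = discriminator(fa_pair).mean()
gd_fake.backward(gan_one)                                      # gd_fake now has a grad_fn

Alternatively, pix2pix-style segmentation GANs feed the full softmax probability map (num_classes channels) to the discriminator instead of a one-channel mask; that also restores the gradient path, but requires widening the discriminator's input channels.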