TheGhastModding

EBGAN.py

Aug 27th, 2019
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

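# EBGAN on MNIST: the discriminator is an autoencoder, and its pixel-wise
# reconstruction error serves as the energy assigned to a sample
# (Zhao, Mathieu & LeCun, "Energy-based Generative Adversarial Network", 2016).
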
# The Generator: maps a 24-dim latent vector to a 28x28 image.
class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.fc = nn.Linear(24, 49)
        self.main = nn.Sequential(
            # 7x7x1
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Upsample(scale_factor=2),
            # BatchNorm2d's second positional argument is eps, not momentum,
            # so the intended momentum=0.4 must be passed by keyword.
            nn.BatchNorm2d(48, momentum=0.4),
            nn.LeakyReLU(0.2),
            # 14x14x48
            nn.Conv2d(48, 24, 3, padding=1),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(24, momentum=0.4),
            nn.LeakyReLU(0.2),
            # 28x28x24
            nn.Conv2d(24, 1, 3, padding=1, bias=False),
            # 28x28x1
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, 1, 7, 7)
        x = self.main(x)
        return x

    def generate(self, gdevice):
        x = torch.randn(1, 24, dtype=torch.float, device=gdevice)
        return self.forward(x)

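# A minimal shape sanity check (my addition, not part of the original paste):
# a (N, 24) latent batch should map to (N, 1, 28, 28) images.
_g_check = Generator()(torch.randn(2, 24))
assert _g_check.shape == (2, 1, 28, 28)
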
# The Discriminator / Autoencoder: its reconstruction error is the energy.
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.encoder = nn.Sequential(
            # 28x28x1
            nn.Conv2d(1, 48, 4, padding=1, stride=2),
            nn.BatchNorm2d(48, momentum=0.4),
            nn.LeakyReLU(0.2),
            # 14x14x48
            nn.Conv2d(48, 48, 4, padding=1, stride=2),
            nn.BatchNorm2d(48, momentum=0.4),
            nn.LeakyReLU(0.2),
            # 7x7x48
            nn.Conv2d(48, 4, 3, padding=1, bias=False),
            # 7x7x4
        )
        self.fc = nn.Linear(196, 24)  # 196 = 7*7*4; 24-unit bottleneck
        self.n = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(24, 49)
        self.decoder = nn.Sequential(
            # 7x7x1
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(48, momentum=0.4),
            nn.LeakyReLU(0.2),
            # 14x14x48
            nn.Conv2d(48, 24, 3, padding=1),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(24, momentum=0.4),
            nn.LeakyReLU(0.2),
            # 28x28x24
            nn.Conv2d(24, 1, 3, padding=1, bias=False),
            # 28x28x1
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 196)
        x = self.fc(x)
        x = self.n(x)
        x = self.fc2(x)
        x = x.view(-1, 1, 7, 7)
        x = self.decoder(x)
        return x

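# Corresponding check for the autoencoder (my addition): it should reproduce
# the input shape, since the EBGAN energy is the reconstruction error between
# its output and its input.
_d_check = Discriminator()(torch.randn(2, 1, 28, 28))
assert _d_check.shape == (2, 1, 28, 28)
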
gen = Generator()
print(gen)

print("\n\n\n")

disc = Discriminator()
print(disc)

# Fall back to CPU when no GPU is available (the original hardcoded "cuda:1");
# .to(device) already moves the modules, so no extra .cuda() call is needed.
devn = "cuda:1" if torch.cuda.device_count() > 1 else ("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device(devn)
print(device)

gen = gen.to(device)
disc = disc.to(device)

def show_samples(gdevice):
    # Draw 16 samples from the generator and display them in a 4x4 grid.
    with torch.no_grad():
        fig = plt.figure(figsize=(8, 8))
        for i in range(1, 17):
            img = gen.generate(gdevice).cpu().view(28, 28)
            img = img / 2 + 0.5  # undo the [-1, 1] normalization for display
            npimg = img.numpy()
            fig.add_subplot(4, 4, i)
            plt.imshow(npimg, cmap="gray")
        plt.show()

batch_size = 32
#margin = max(1, batch_size / 64)
margin = 20
learning_rate = 1e-3
epochs = 16

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate, betas=(0.5, 0.999))

pixel_loss = nn.MSELoss()

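# EBGAN objectives, written out for reference (with D(x) the autoencoder's
# reconstruction error, computed here via pixel_loss):
#   generator:     L_G = D(G(z))
#   discriminator: L_D = D(x_real) + max(0, margin - D(G(z)))
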
#torch.autograd.set_detect_anomaly(True)
for epoch in range(epochs):
    print("Epoch " + str(epoch + 1) + "/" + str(epochs))

    running_loss_disc = 0.0
    running_loss_gen = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs = data[0].to(device)

        # Generator step: minimize the energy of generated samples.
        gen.zero_grad()
        fakeOut = gen(torch.randn(batch_size, 24, dtype=torch.float, device=device))
        fake_energy = disc(fakeOut)
        loss = pixel_loss(fake_energy, fakeOut.detach())
        running_loss_gen += loss.item()
        loss.backward()
        optimizer_gen.step()

        # Discriminator step: lower the energy of real samples, and raise the
        # energy of fakes only while it is still below the margin (hinge loss).
        disc.zero_grad()
        fake_energy = disc(fakeOut.detach())
        real_energy = disc(inputs)
        fake_loss = pixel_loss(fake_energy, fakeOut.detach())
        real_loss = pixel_loss(real_energy, inputs)
        if (margin - fake_loss.data).item() > 0:
            loss = (margin - fake_loss) + real_loss
        else:
            loss = real_loss
        running_loss_disc += loss.item()
        loss.backward()
        optimizer_disc.step()

        # MSELoss already averages over the batch, so the running totals only
        # need dividing by the number of iterations when reporting.
        if i % 500 == 499:
            print('[%d, %5d] loss: %.3f %.3f' % (epoch + 1, i + 1, running_loss_disc / 500, running_loss_gen / 500))
            running_loss_disc = 0.0
            running_loss_gen = 0.0

gen.eval()
show_samples(device)
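
# Optionally persist the trained weights (my addition; the file names are
# arbitrary placeholders, not from the original script):
torch.save(gen.state_dict(), "ebgan_gen.pth")
torch.save(disc.state_dict(), "ebgan_disc.pth")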