gtest
TheGhastModding, Aug 26th, 2019
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()

        self.fc = nn.Linear(24, 49)
        self.main = nn.Sequential(
            # 7x7x1
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(48, 0.8),  # second positional arg is eps, not momentum
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14x48
            nn.Conv2d(48, 24, 3, padding=1),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(24, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # 28x28x24
            nn.Conv2d(24, 1, 3, padding=1, bias=False),
            # 28x28x1, no output activation, so pixel values are unbounded
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, 1, 7, 7)
        x = self.main(x)
        return x

    def generate(self, device):
        # sample one image from a 24-dim standard-normal latent vector
        x = torch.randn(1, 24, dtype=torch.float, device=device)
        return self.forward(x)

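# A quick shape check, added here as a sanity test (not in the original paste):
# the 24-dim noise vector becomes 49 features, is reshaped to 7x7, and is
# upsampled twice, so the generator should emit 1x28x28 images.
if __debug__:
    _g = Generator()
    assert _g(torch.randn(1, 24)).shape == (1, 1, 28, 28)
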
class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.main = nn.Sequential(
            # 28x28x1
            nn.Conv2d(1, 48, 4, padding=1, stride=2),
            nn.BatchNorm2d(48, 0.8),  # second positional arg is eps, not momentum
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14x48
            nn.Conv2d(48, 48, 4, padding=1, stride=2),
            nn.BatchNorm2d(48, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # 7x7x48
            nn.Conv2d(48, 4, 3, padding=1, bias=False),
            # 7x7x4
        )
        self.fc = nn.Linear(196, 1)  # 7 * 7 * 4 = 196 flattened features
        self.n = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x = self.main(x)
        x = x.view(-1, 196)
        x = self.fc(x)
        x = self.n(x)
        return x

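# The matching check for the discriminator (again an added sanity test, not in
# the original paste): a 28x28 input is strided down to 7x7 with 4 channels
# (196 features), then mapped to a single realness score per image.
if __debug__:
    _d = Discriminator()
    assert _d(torch.randn(2, 1, 28, 28)).shape == (2, 1)
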
gen = Generator()
print(gen)

print("\n\n\n")

disc = Discriminator()
print(disc)

# prefer a second GPU when present, otherwise fall back to cuda:0 or the CPU
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(device))

gen.to(device)
disc.to(device)

def show_samples():
    gen.eval()  # use BatchNorm running stats when sampling single images
    with torch.no_grad():
        fig = plt.figure(figsize=(8, 8))
        for i in range(1, 17):
            img = gen.generate(device).cpu().view(28, 28)
            # undo the Normalize([0.5], [0.5]) transform; without a final tanh
            # the generator's outputs are not strictly confined to [-1, 1]
            img = img / 2 + 0.5
            npimg = img.numpy()
            fig.add_subplot(4, 4, i)
            plt.imshow(npimg, cmap="gray")
        plt.show()
    gen.train()

batch_size = 32
# ToTensor yields values in [0, 1]; Normalize([0.5], [0.5]) maps them to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

learning_rate = 1e-4
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate, betas=(0.5, 0.999))
# the discriminator gets a slightly higher learning rate than the generator
optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate * 1.25, betas=(0.5, 0.999))

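# The updates below implement a least-squares GAN objective with smoothed
# targets: the discriminator is pushed toward 0.9 on real images and -0.1 on
# fakes, while the generator drives the fake score toward 0.9. For a scalar
# score d on a real image, the term (d - 0.9)^2 is minimized exactly at d == 0.9.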
epochs = 8
for epoch in range(epochs):
    print("Epoch " + str(epoch + 1) + "/" + str(epochs))

    running_loss_disc = 0.0
    running_loss_gen = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs = data[0].to(device)
        n = inputs.size(0)  # the last batch can be smaller than batch_size

        # discriminator update
        optimizer_disc.zero_grad()
        discOut_real = disc(inputs)
        fakeOut = gen(torch.randn(n, 24, dtype=torch.float, device=device))
        discOut_fake = disc(fakeOut.detach())  # detach: don't backprop into the generator here
        loss = (discOut_real - 0.9).pow(2).sum() + (discOut_fake + 0.1).pow(2).sum()
        running_loss_disc += loss.item()
        loss.backward()
        optimizer_disc.step()

        # generator update
        optimizer_gen.zero_grad()
        discOut_fake = disc(fakeOut)
        loss = (discOut_fake - 0.9).pow(2).sum()
        running_loss_gen += loss.item()
        loss.backward()  # this also fills disc grads, but they are zeroed next iteration
        optimizer_gen.step()

        if i % 500 == 499:
            # report the average per-sample loss over the last 500 batches
            print('[%d, %5d] loss: %.3f %.3f' % (epoch + 1, i + 1,
                  running_loss_disc / (500 * batch_size), running_loss_gen / (500 * batch_size)))
            running_loss_disc = 0.0
            running_loss_gen = 0.0

show_samples()
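
# Optional follow-up (an addition, not in the original script): persist the
# trained weights so samples can be regenerated without retraining. The
# filenames here are hypothetical.
torch.save(gen.state_dict(), "gen_mnist.pt")
torch.save(disc.state_dict(), "disc_mnist.pt")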