import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Linear(16, 49)
        self.main = nn.Sequential(
            # 7x7x1
            nn.Conv2d(1, 32, 3, padding=1),
            nn.Upsample(scale_factor=2),
            # the second positional argument of BatchNorm2d is eps, so the
            # original BatchNorm2d(32, 0.8) set eps=0.8; momentum=0.8 is the
            # likely intent (a common Keras-to-PyTorch slip)
            nn.BatchNorm2d(32, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14x32
            nn.Conv2d(32, 16, 3, padding=1),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(16, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # 28x28x16
            nn.Conv2d(16, 1, 3, padding=1),
            # nn.LeakyReLU(0.2, inplace=True),
            # 28x28x1
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(-1, 1, 7, 7)
        x = self.main(x)
        return x

    def generate(self, device):
        # draw a single 16-dim noise vector and run it through the network
        x = torch.randn(1, 16, dtype=torch.float, device=device)
        return self.forward(x)
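
# Quick shape sanity check (added sketch, not in the original paste): the
# 16-dim noise vector is projected to 49 features, reshaped to 1x7x7, and the
# two Upsample stages double that twice, to 1x28x28.
_g = Generator()
assert _g(torch.randn(2, 16)).shape == (2, 1, 28, 28)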
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 28x28x1
            nn.Conv2d(1, 32, 4, padding=1, stride=2),
            # as above, momentum=0.8 (the original positional 0.8 was eps)
            nn.BatchNorm2d(32, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14x32
            nn.Conv2d(32, 32, 4, padding=1, stride=2),
            nn.BatchNorm2d(32, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            # 7x7x32
            nn.Conv2d(32, 4, 3, padding=1),
            # 7x7x4
        )
        self.fc = nn.Linear(196, 1)
        self.n = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        x = self.main(x)
        x = x.view(-1, 196)
        x = self.fc(x)
        x = self.n(x)
        return x
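
# Matching check for the discriminator (added sketch): two stride-2 convs halve
# 28x28 twice, so the flattened feature size is 7 * 7 * 4 = 196, which is what
# the Linear(196, 1) layer above expects.
_d = Discriminator()
assert _d(torch.randn(2, 1, 28, 28)).shape == (2, 1)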
gen = Generator()
print(gen)
print("\r\n\r\n\r\n")
disc = Discriminator()
print(disc)

device = torch.device("cpu")
print(device)
# only query GPU names when CUDA is actually available; the original
# unconditional get_device_name(0)/get_device_name(1) calls crash on a
# CPU-only machine or one with a single GPU
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_name(i))

gen.to(device)
disc.to(device)
def show_samples():
    # plot a 4x4 grid of generated digits, undoing the [-1, 1] normalization
    with torch.no_grad():
        fig = plt.figure(figsize=(8, 8))
        for i in range(1, 17):
            img = gen.generate(device).view(28, 28)
            img = img / 2 + 0.5
            npimg = img.cpu().numpy()
            fig.add_subplot(4, 4, i)
            plt.imshow(npimg, cmap="gray")
        plt.show()
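
# An equivalent grid using torchvision's helper (optional added sketch, in case
# the manual subplot loop above is replaced later):
def show_samples_grid():
    with torch.no_grad():
        imgs = gen(torch.randn(16, 16, dtype=torch.float, device=device))
        grid = torchvision.utils.make_grid(imgs / 2 + 0.5, nrow=4)
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.show()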
batch_size = 4

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# note: with num_workers > 0 this script should run under an
# if __name__ == '__main__': guard on Windows/macOS
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

learning_rate = 0.001
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)
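
# The loop below uses a least-squares (LSGAN-style) objective with one-sided
# label smoothing: real outputs are pulled toward 0.9, fake outputs toward 0.0.
# An equivalent helper using F.mse_loss (added sketch; not called by the loop,
# which keeps the original .pow(2).sum() form):
def lsgan_disc_loss(d_real, d_fake):
    return (F.mse_loss(d_real, torch.full_like(d_real, 0.9), reduction='sum')
            + F.mse_loss(d_fake, torch.zeros_like(d_fake), reduction='sum'))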
for epoch in range(8):
    print("Epoch " + str(epoch + 1) + "/8")
    running_loss_disc = 0.0
    running_loss_gen = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs = data[0].to(device)
        optimizer_gen.zero_grad()
        optimizer_disc.zero_grad()

        # discriminator step: real images should score 0.9 (one-sided label
        # smoothing), fakes should score 0.0; detach the fakes so this step
        # does not update the generator
        discOut_real = disc(inputs)
        fakeOut = gen(torch.randn(batch_size, 16, dtype=torch.float, device=device))
        discOut_fake = disc(fakeOut.detach())
        loss = (discOut_real - 0.9).pow(2).sum() + (discOut_fake - 0.0).pow(2).sum()
        running_loss_disc += loss.item()
        loss.backward()
        optimizer_disc.step()

        # generator step: push the discriminator's score on fakes toward 0.9;
        # this backward pass also fills disc's grads, so they are cleared
        # before the generator update is applied
        discOut_fake = disc(fakeOut)
        loss = (discOut_fake - 0.9).pow(2).sum()
        running_loss_gen += loss.item()
        loss.backward()
        disc.zero_grad()
        optimizer_gen.step()
        gen.zero_grad()

        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f %.3f' % (epoch + 1, i + 1, running_loss_disc / 2000, running_loss_gen / 2000))
            running_loss_disc = 0.0
            running_loss_gen = 0.0
    # show a grid of generated samples at the end of each epoch
    show_samples()
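
# Persist the trained generator so samples can be regenerated without
# retraining (added sketch; the filename "generator.pt" is an assumption, not
# from the original paste).
torch.save(gen.state_dict(), "generator.pt")
# later: gen.load_state_dict(torch.load("generator.pt"))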