Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torchvision
- import torchvision.transforms as transforms
- import matplotlib.pyplot as plt
- import numpy as np
#The Generator
class Generator(nn.Module):
    """EBGAN-style generator: latent noise vector -> 28x28x1 image.

    A linear layer maps the latent code to a 7x7 map; two
    conv + 2x-upsample stages then grow it to 28x28.
    """

    def __init__(self, latent_dim=24):
        """latent_dim: size of the input noise vector (default 24,
        matching the original hard-coded value)."""
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.fc = nn.Linear(latent_dim, 49)
        self.main = nn.Sequential(
            # 7x7x1
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Upsample(scale_factor=2),
            # NOTE(review): the second positional arg of BatchNorm2d is
            # eps; 0.4 is abnormally large — was momentum intended?
            # TODO confirm (value kept unchanged).
            nn.BatchNorm2d(48, 0.4),
            nn.LeakyReLU(0.2),
            # 14x14x48
            nn.Conv2d(48, 24, 3, padding=1),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(24, 0.4),
            nn.LeakyReLU(0.2),
            # 28x28x24
            nn.Conv2d(24, 1, 3, padding=1, bias=False),
            #nn.LeakyReLU(0.2, inplace=True),
            # 28x28x1
        )

    def forward(self, x):
        """Map a (batch, latent_dim) noise tensor to (batch, 1, 28, 28)."""
        x = self.fc(x)
        x = x.view(-1, 1, 7, 7)
        x = self.main(x)
        return x

    def generate(self, gdevice, n=1):
        """Sample n images (default 1, as before) from random noise
        drawn on the given device."""
        z = torch.randn(n, self.latent_dim, dtype=torch.float, device=gdevice)
        return self.forward(z)
#The Discriminator / Autoencoder
class Discriminator(nn.Module):
    """Energy-based discriminator: an autoencoder over 28x28x1 images.

    The image is convolved down to a 7x7x4 code, squeezed through a
    24-unit fully-connected bottleneck, then decoded back to a
    28x28x1 reconstruction of the input.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # 28x28x1 -> 7x7x4 feature map (two stride-2 convs)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 48, 4, padding=1, stride=2),
            nn.BatchNorm2d(48, 0.4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(48, 48, 4, padding=1, stride=2),
            nn.BatchNorm2d(48, 0.4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(48, 4, 3, padding=1, bias=False),
        )
        # 196 = 7*7*4 flattened code -> 24-d bottleneck -> 49 = 7*7 map
        self.fc = nn.Linear(196, 24)
        self.n = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(24, 49)
        # 7x7x1 -> 28x28x1 reconstruction (mirrors the generator's stack)
        self.decoder = nn.Sequential(
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(48, 0.4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(48, 24, 3, padding=1),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(24, 0.4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(24, 1, 3, padding=1, bias=False),
            #nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        """Encode x, pass it through the bottleneck, decode; returns a
        reconstruction with the same (batch, 1, 28, 28) shape."""
        code = self.encoder(x).view(-1, 196)
        code = self.fc2(self.n(self.fc(code)))
        return self.decoder(code.view(-1, 1, 7, 7))
# Instantiate the two networks and print their layer summaries.
gen = Generator()
print(gen)
print("\r\n\r\n\r\n")
disc = Discriminator()
print(disc)

# Pick the training device. The original hard-coded "cuda:1" and
# crashed on any machine without a second GPU; fall back gracefully.
devn = "cuda:1"
if not torch.cuda.is_available():
    devn = "cpu"
elif torch.cuda.device_count() < 2:
    devn = "cuda:0"
device = torch.device(devn)
print(device)

# .to(device) already moves all parameters/buffers; the previous extra
# .cuda(device=device) calls were redundant and have been dropped.
gen = gen.to(device)
disc = disc.to(device)
def show_samples(gdevice):
    """Display a 4x4 grid of images sampled from the global `gen`."""
    with torch.no_grad():
        fig = plt.figure(figsize=(8, 8))
        for cell in range(16):
            sample = gen.generate(gdevice).cpu().view(28, 28)
            # undo the Normalize([0.5], [0.5]) transform: [-1, 1] -> [0, 1]
            pixels = (sample / 2 + 0.5).numpy()
            fig.add_subplot(4, 4, cell + 1)
            plt.imshow(pixels, cmap="gray")
        plt.show()
# ---- Training hyper-parameters ----
batch_size = 32
#margin = max(1, batch_size / 64)
margin = 20  # EBGAN margin m in L_D = D(x) + max(0, m - D(G(z)))
learning_rate = 1e-3
epochs = 16

# MNIST scaled to [-1, 1] via Normalize([0.5], [0.5]).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate, betas=(0.5, 0.999))
# Reconstruction MSE doubles as the discriminator "energy".
pixel_loss = nn.MSELoss()
#torch.autograd.set_detect_anomaly(True)

for epoch in range(epochs):
    print("Epoch " + str(epoch + 1) + "/" + str(epochs))
    running_loss_disc = 0.0
    running_loss_gen = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs = data[0].to(device)

        # --- Generator step: minimize reconstruction energy of fakes ---
        gen.zero_grad()
        fakeOut = gen.forward(torch.randn(batch_size, 24, dtype=torch.float, device=device))
        fake_energy = disc.forward(fakeOut)
        # MSE between the fake image (detached) and its reconstruction;
        # gradients reach the generator through fake_energy = disc(gen(z)).
        loss = pixel_loss(fakeOut.detach(), fake_energy)
        running_loss_gen += loss.item()
        loss.backward()
        optimizer_gen.step()

        # --- Discriminator step: EBGAN hinge on the fake energy ---
        # zero_grad here also clears the disc grads accumulated by the
        # generator's backward pass above.
        disc.zero_grad()
        fake_energy = disc.forward(fakeOut.detach())
        real_energy = disc.forward(inputs)
        fake_loss = pixel_loss(fake_energy, fakeOut.detach())
        real_loss = pixel_loss(real_energy, inputs)
        # Hinge: push fake energy up only while it is below the margin.
        if (margin - fake_loss.data).item() > 0:
            loss = (margin - fake_loss) + real_loss
        else:
            loss = real_loss
        running_loss_disc += loss.item()
        loss.backward()
        optimizer_disc.step()

        if i % 500 == 499:
            # NOTE(review): loss.item() is already a batch mean (MSELoss
            # default reduction='mean'); dividing by batch_size again may
            # not be intended — TODO confirm the reporting scale.
            print('[%d, %5d] loss: %.3f %.3f' % (epoch + 1, i + 1, running_loss_disc / batch_size / 500, running_loss_gen / batch_size / 500))
            running_loss_disc = 0.0
            running_loss_gen = 0.0

# After training, show a grid of generated samples.
gen.eval()
show_samples(device)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement