# %matplotlib inline
from matplotlib import pyplot as plt
from IPython import display
from torch.autograd import Variable
import torch
import numpy as np
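# Target data: a mixture of k small Gaussian blobs centered on the unit
# circle via (sin t, cos t). Note np.linspace includes both endpoints,
# so t=0 and t=2*pi coincide and k=8 actually gives 7 distinct modes.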
def sample_real(n=128, k=8, std=0.01):
    t = np.linspace(0, 2*np.pi, k)
    m = np.vstack((np.sin(t), np.cos(t))).T
    i = np.random.randint(m.shape[0], size=n)
    return Variable(torch.Tensor(np.random.randn(n, 2)*std + m[i]))
def sample_noise(n=128, d=2):
    con = torch.rand(n, d//2)  # integer division: d/2 fails on Python 3
    cat = torch.Tensor(np.random.multinomial(1, np.ones(d//2)/(d//2), n))
    # return Variable(torch.rand(n, d))  # un-comment for unimodal noise
    return Variable(torch.cat((con, cat), 1))
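# The default noise concatenates a continuous uniform part with a one-hot
# categorical part, giving the generator a discrete latent variable it can
# use to cover separate modes; the commented line above is the plain
# unimodal alternative.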
def sample_batch(bs, netG, who):
    b_fake = netG(sample_noise(bs, netG.input_dim))
    if who == 'D':
        batch = torch.cat((sample_real(bs), b_fake.detach()))
        # labels shaped (2*bs, 1) to match netD's output for BCELoss
        label = Variable(torch.cat((torch.ones(bs, 1), torch.zeros(bs, 1))))
    else:
        batch = b_fake
        label = Variable(torch.ones(bs, 1))
    return batch, label
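# For the discriminator ('D') the batch is half real (label 1) and half
# fake (label 0); the fakes are detach()-ed so no gradient flows back into
# netG during the discriminator update. For the generator ('G') the fakes
# are labeled 1: netG is trained to make netD classify them as real.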
class MyNet(torch.nn.Module):
    def __init__(self, n_input, n_hidden, n_output, act, out_act):
        super(MyNet, self).__init__()
        self.linear_in = torch.nn.Linear(n_input, n_hidden)
        self.linear_out = torch.nn.Linear(n_hidden, n_output)
        self.act = act
        self.out_act = out_act
        self.input_dim = n_input

    def forward(self, x):
        y = self.linear_in(x)
        y = self.act(y)
        y = self.linear_out(y)
        return self.out_act(y)
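# MyNet is a one-hidden-layer MLP with pluggable activations, so the same
# class serves as discriminator (Sigmoid output, a probability) and as
# generator (identity output, a 2-D point). Illustrative shape check only:
#
#   net = MyNet(2, 16, 1, torch.nn.ReLU(), torch.nn.Sigmoid())
#   out = net(Variable(torch.rand(4, 2)))  # out.size() == (4, 1)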
def precondition_gan(net, max_iter=1000, n_real=10000, n_fake=1000):
    def cov(x):
        n = x.size(0)
        m = x.mean(0)
        c = x - m.repeat(n, 1)
        return torch.mm(c.t(), c)/(n - 1)

    def stats(x):
        # first two moments as a single row: mean (1x2) + covariance (1x4)
        return torch.cat((x.mean(0).view(1, -1), cov(x).view(1, -1)), 1)

    opt = torch.optim.Adam(net.parameters())
    sss = stats(sample_real(n_real))
    mse = torch.nn.MSELoss()
    for i in range(max_iter):
        stats_fake = stats(net(sample_noise(n_fake, net.input_dim)))
        net.zero_grad()
        mse(stats_fake, sss).backward()
        opt.step()
    return net
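# Pre-conditioning is a moment-matching warm start: before adversarial
# training, the generator is fit with an MSE loss until the mean and
# covariance of its samples match those of the real data. Only the first
# two moments are aligned; the adversarial loss must shape the rest.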
def plot_gan(netD, netG, n=1000):
    xmin, xmax, ymin, ymax = -1.5, 1.5, -1.5, 1.5
    pfake = netG(sample_noise(n, netG.input_dim)).data.numpy()
    preal = sample_real(n).data.numpy()
    step = 20
    plt.plot(preal[:, 0], preal[:, 1], '.', label='real', alpha=0.5)
    plt.plot(pfake[:, 0], pfake[:, 1], '.', label='fake', alpha=0.5)
    grid = torch.zeros((step, step))
    elem = torch.Tensor(1, 2)
    for ki, vi in enumerate(torch.linspace(xmin, xmax, step)):
        for kj, vj in enumerate(torch.linspace(ymin, ymax, step)):
            elem[0][0] = vi
            elem[0][1] = vj
            grid[ki][kj] = netD(Variable(elem)).data[0][0]
    plt.imshow(np.flipud(grid.numpy().T), extent=[xmin, xmax, ymin, ymax],
               vmin=0, vmax=1, aspect='auto', cmap="gray")
    plt.ylim(ymin, ymax)
    plt.xlim(xmin, xmax)
    plt.axis('off')
    plt.show()
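# plot_gan scatters n real and n fake samples over a 20x20 grayscale map
# of the discriminator's output on [-1.5, 1.5]^2, so brighter background
# marks regions netD currently scores as 'real'.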
### MAIN ##########################################################
G_epochs = 1000      # number of training iterations
G_batchsize = 128    # batch size
G_hiddens = 128      # number of hidden neurons
G_d = 20             # dimensionality of the noise
G_extra_discr = 10   # discriminator updates per generator update
netD = MyNet(2, G_hiddens, 1, torch.nn.ReLU(), torch.nn.Sigmoid())
netG = MyNet(G_d, G_hiddens, 2, torch.nn.ReLU(), lambda x: x)
plot_gan(netD, netG)
netG = precondition_gan(netG)  # comment out to remove pre-conditioning
plot_gan(netD, netG)
optD = torch.optim.Adam(netD.parameters())
optG = torch.optim.Adam(netG.parameters())
logD = np.zeros(G_epochs)
logG = np.zeros(G_epochs)
criterion = torch.nn.BCELoss()
for epoch in range(G_epochs):
    # several discriminator updates per generator update
    for k in range(G_extra_discr):
        batch_d, label_d = sample_batch(G_batchsize, netG, 'D')
        netD.zero_grad()
        errD = criterion(netD(batch_d), label_d)
        errD.backward()
        optD.step()
    logD[epoch] = errD.item()  # .data[0] breaks on 0-dim tensors in newer PyTorch
    # one generator update
    batch_g, label_g = sample_batch(G_batchsize, netG, 'G')
    netG.zero_grad()
    errG = criterion(netD(batch_g), label_g)
    errG.backward()
    optG.step()
    logG[epoch] = errG.item()
plot_gan(netD, netG)
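# A minimal follow-up sketch (not in the original paste): the loop fills
# logD and logG but never displays them; with matplotlib already imported
# above, the loss curves can be inspected like this:
plt.figure()
plt.plot(logD, label='D loss')
plt.plot(logG, label='G loss')
plt.xlabel('epoch')
plt.legend()
plt.show()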