lopezpaz

Code for "Multimodal Noise and Covering Initializations"

Feb 17th, 2017
# %matplotlib inline
# NB: written against the early-2017 PyTorch API (Variable, .data[0] indexing).
from matplotlib import pyplot as plt
from IPython import display
from torch.autograd import Variable
import torch
import numpy as np

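# Sample n real points from a mixture of k Gaussian blobs (stddev 'std')
# centered on the unit circle.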
def sample_real(n=128, k=8, std=0.01):
    t = np.linspace(0, 2*np.pi, k, endpoint=False)  # k distinct angles; with the endpoint included, 0 and 2*pi would coincide
    m = np.vstack((np.sin(t), np.cos(t))).T
    i = np.random.randint(m.shape[0], size=n)
    return Variable(torch.Tensor(np.random.randn(n, 2)*std + m[i]))

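# Multimodal noise: the first d/2 dims are uniform on [0, 1), the last d/2 dims
# are a one-hot draw from a uniform categorical, so the noise itself has d/2
# discrete modes.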
def sample_noise(n=128, d=2):
    con = torch.rand(n, d // 2)  # integer division so this also runs under Python 3
    cat = torch.Tensor(np.random.multinomial(1, np.ones(d // 2) / (d // 2), n))
    # return Variable(torch.rand(n, d)) # un-comment for unimodal noise
    return Variable(torch.cat((con, cat), 1))

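# Build a labeled minibatch: for the discriminator ('D'), real samples labeled 1
# stacked with detached fakes labeled 0; for the generator, fakes labeled 1.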
def sample_batch(bs, netG, who):
    b_fake = netG(sample_noise(bs, netG.input_dim))
    if who == 'D':
        batch = torch.cat((sample_real(bs), b_fake.detach()))
        label = Variable(torch.cat((torch.ones(bs), torch.zeros(bs))))
    else:
        batch = b_fake
        label = Variable(torch.ones(bs))
    return batch, label

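# One-hidden-layer MLP with configurable hidden and output activations; used
# both as discriminator (sigmoid output) and generator (identity output).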
class MyNet(torch.nn.Module):
    def __init__(self, n_input, n_hidden, n_output, act, out_act):
        super(MyNet, self).__init__()
        self.linear_in = torch.nn.Linear(n_input, n_hidden)
        self.linear_out = torch.nn.Linear(n_hidden, n_output)
        self.act = act
        self.out_act = out_act
        self.input_dim = n_input

    def forward(self, x):
        y = self.linear_in(x)
        y = self.act(y)
        y = self.linear_out(y)
        return self.out_act(y)

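# Covering initialization: before adversarial training, fit the generator so
# that the mean and covariance of its samples match those of the real data
# (MSE on the stacked first- and second-moment statistics).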
def precondition_gan(net, max_iter=1000, n_real=10000, n_fake=1000):
    def cov(x):
        n = x.size(0)
        m = x.mean(0)
        c = x - m.repeat(n, 1)
        return torch.mm(c.t(), c) / (n - 1)

    def stats(x):
        return torch.cat((x.mean(0), cov(x).view(1, -1)), 1)

    opt = torch.optim.Adam(net.parameters())
    sss = stats(sample_real(n_real))
    mse = torch.nn.MSELoss()

    for i in range(max_iter):
        stats_fake = stats(net(sample_noise(n_fake, net.input_dim)))
        net.zero_grad()
        mse(stats_fake, sss).backward()
        opt.step()
    return net

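# Scatter real vs. generated samples over a grayscale heatmap of the
# discriminator's output, evaluated on a step-by-step grid of the plane.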
def plot_gan(netD, netG, n=1000):
    xmin, xmax, ymin, ymax = -1.5, 1.5, -1.5, 1.5
    pfake = netG(sample_noise(n, netG.input_dim)).data.numpy()
    preal = sample_real(n).data.numpy()
    step = 20

    plt.plot(preal[:, 0], preal[:, 1], '.', label='real', alpha=0.5)
    plt.plot(pfake[:, 0], pfake[:, 1], '.', label='fake', alpha=0.5)

    grid = torch.zeros((step, step))
    elem = torch.Tensor(1, 2)
    for ki, vi in enumerate(torch.linspace(xmin, xmax, step)):
        for kj, vj in enumerate(torch.linspace(ymin, ymax, step)):
            elem[0][0] = vi
            elem[0][1] = vj
            grid[ki][kj] = netD(Variable(elem)).data[0][0]

    plt.imshow(np.flipud(grid.numpy().T), extent=[xmin, xmax, ymin, ymax],
               vmin=0, vmax=1, aspect='auto', cmap="gray")
    plt.ylim(ymin, ymax)
    plt.xlim(xmin, xmax)
    plt.axis('off')
    plt.show()

### MAIN ##########################################################

G_epochs      = 1000 # number of generator iterations
G_batchsize   = 128  # minibatch size
G_hiddens     = 128  # number of hidden neurons
G_d           = 20   # dimensionality of noise
G_extra_discr = 10   # discriminator updates per generator update

netD = MyNet(2, G_hiddens, 1, torch.nn.ReLU(), torch.nn.Sigmoid())
netG = MyNet(G_d, G_hiddens, 2, torch.nn.ReLU(), lambda x: x)

plot_gan(netD, netG)
netG = precondition_gan(netG) # comment out to remove pre-conditioning
plot_gan(netD, netG)

optD = torch.optim.Adam(netD.parameters())
optG = torch.optim.Adam(netG.parameters())

logD = np.zeros(G_epochs)
logG = np.zeros(G_epochs)

criterion = torch.nn.BCELoss()

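# Alternating GAN training: G_extra_discr discriminator steps per generator
# step, logging the last discriminator loss and the generator loss per epoch.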
for epoch in range(G_epochs):
    for k in range(G_extra_discr):
        batch_d, label_d = sample_batch(G_batchsize, netG, 'D')
        netD.zero_grad()
        errD = criterion(netD(batch_d), label_d)
        errD.backward()
        optD.step()
    logD[epoch] = errD.data[0]

    batch_g, label_g = sample_batch(G_batchsize, netG, 'G')
    netG.zero_grad()
    errG = criterion(netD(batch_g), label_g)
    errG.backward()
    optG.step()
    logG[epoch] = errG.data[0]

plot_gan(netD, netG)
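
# The loss curves in logD and logG are recorded above but never visualized;
# a minimal sketch (not part of the original script, assumes the same
# matplotlib session):
plt.plot(logD, label='discriminator loss')
plt.plot(logG, label='generator loss')
plt.xlabel('epoch')
plt.legend()
plt.show()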