# Code for "Multimodal Noise and Covering Initializations"

# Pastebin by lopezpaz, Feb 17th, 2017 (edited) -- 136 views, never expires.
# %matplotlib inline
import numpy as np
import torch
from IPython import display
from matplotlib import pyplot as plt
from torch.autograd import Variable
7.
def sample_real(n=128, k=8, std=0.01):
    """Draw n points from a mixture of k isotropic Gaussians on the unit circle.

    Args:
        n: number of samples.
        k: number of mixture components, evenly spaced in angle.
        std: standard deviation of every component.

    Returns:
        A Variable wrapping a float tensor of shape (n, 2).
    """
    # endpoint=False: including the endpoint makes angles 0 and 2*pi map to
    # the same centre, leaving only k-1 distinct modes (one double-weighted).
    t = np.linspace(0, 2 * np.pi, k, endpoint=False)
    centers = np.vstack((np.sin(t), np.cos(t))).T
    idx = np.random.randint(centers.shape[0], size=n)
    return Variable(torch.Tensor(np.random.randn(n, 2) * std + centers[idx]))
13.
def sample_noise(n=128, d=2, unimodal=False):
    """Draw n generator-input noise vectors of total dimension d.

    By default the noise is multimodal: the first d - d // 2 coordinates are
    uniform on [0, 1), and the last d // 2 coordinates form a one-hot vector
    selecting one of d // 2 equally likely categories.

    Args:
        n: number of samples.
        d: total noise dimensionality.
        unimodal: if True, return plain uniform noise on [0, 1)^d instead.

    Returns:
        A Variable wrapping a float tensor of shape (n, d).
    """
    if unimodal:
        return Variable(torch.rand(n, d))
    # Integer division: d / 2 is a float on Python 3, which torch.rand and
    # np.ones reject. Giving the odd leftover coordinate to the continuous
    # part keeps the output width exactly d.
    n_cat = d // 2
    n_con = d - n_cat
    con = torch.rand(n, n_con)
    if n_cat == 0:
        # No room for a categorical part (d < 2); avoid dividing by zero.
        return Variable(con)
    cat = torch.Tensor(np.random.multinomial(1, np.ones(n_cat) / n_cat, n))
    return Variable(torch.cat((con, cat), 1))
19.
def sample_batch(bs, netG, who):
    """Build one training batch for the discriminator or for the generator.

    Args:
        bs: number of real samples and of fake samples to draw.
        netG: generator; must expose ``input_dim`` (its noise width).
        who: 'D' for a discriminator batch; any other value yields a
            generator batch.

    Returns:
        (batch, label). For 'D': bs real points followed by bs detached fake
        points, labelled 1 and 0 respectively. Otherwise: bs fake points still
        attached to netG's graph, all labelled 1.
    """
    b_fake = netG(sample_noise(bs, netG.input_dim))
    # Labels are (N, 1) so they match the output of a one-output MyNet
    # discriminator; BCELoss requires input and target to have equal sizes.
    if who == 'D':
        # Detach so the discriminator update does not backprop into netG.
        batch = torch.cat((sample_real(bs), b_fake.detach()))
        label = Variable(torch.cat((torch.ones(bs, 1), torch.zeros(bs, 1))))
    else:
        batch = b_fake
        label = Variable(torch.ones(bs, 1))
    return batch, label
29.
class MyNet(torch.nn.Module):
    """One-hidden-layer perceptron: out_act(linear_out(act(linear_in(x))))."""

    def __init__(self, n_input, n_hidden, n_output, act, out_act):
        """Create the two affine layers and store the activations.

        Args:
            n_input: width of the input features.
            n_hidden: width of the hidden layer.
            n_output: width of the output.
            act: activation applied after the first layer.
            out_act: callable applied to the final affine output.
        """
        super(MyNet, self).__init__()
        self.linear_in, self.linear_out = (
            torch.nn.Linear(n_input, n_hidden),
            torch.nn.Linear(n_hidden, n_output),
        )
        self.act, self.out_act = act, out_act
        # Read by callers to size the noise they feed into this network.
        self.input_dim = n_input

    def forward(self, x):
        """Apply the hidden layer, its activation, and the output layer."""
        hidden = self.act(self.linear_in(x))
        return self.out_act(self.linear_out(hidden))
44.
def precondition_gan(net, max_iter=1000, n_real=10000, n_fake=1000):
    """Pre-train a generator so its samples match the real data's mean and covariance.

    Minimizes the squared error between the (mean, covariance) statistics of
    n_fake generated samples and those of n_real real samples.

    Args:
        net: generator exposing ``input_dim``; its parameters are updated in place.
        max_iter: number of optimizer steps.
        n_real: real samples used (once) to estimate the target statistics.
        n_fake: generated samples drawn at every step.

    Returns:
        net (the same object, after training).
    """
    def cov(x):
        # Biased (1/n) sample covariance of the rows of x.
        n = x.size(0)
        m = x.mean(0)
        c = x - m.repeat(n, 1)
        return torch.mm(c.t(), c) / n

    def stats(x):
        # Flattened mean followed by flattened covariance.
        return torch.cat((x.mean(0).view(-1), cov(x).view(-1)))

    # NOTE(review): cov's return, stats' body, the optimizer construction and
    # the zero_grad call were missing from the pasted source; reconstructed
    # here as mean + covariance matching with Adam -- confirm against original.
    sss = stats(sample_real(n_real))
    mse = torch.nn.MSELoss()
    opt = torch.optim.Adam(net.parameters())

    for _ in range(max_iter):
        stats_fake = stats(net(sample_noise(n_fake, net.input_dim)))
        # Clear gradients from the previous step; otherwise they accumulate.
        opt.zero_grad()
        mse(stats_fake, sss).backward()
        opt.step()
    return net
65.
def plot_gan(netD, netG, n=1000, step=20):
    """Scatter real and generated points over a map of the discriminator output.

    Args:
        netD: discriminator taking (N, 2) points; column 0 of its output is
            drawn as a grayscale background with vmin=0, vmax=1.
        netG: generator exposing ``input_dim``.
        n: number of real points and of generated points to scatter.
        step: resolution of the discriminator grid along each axis.
    """
    xmin, xmax, ymin, ymax = -1.5, 1.5, -1.5, 1.5
    pfake = netG(sample_noise(n, netG.input_dim)).data.numpy()
    preal = sample_real(n).data.numpy()

    plt.plot(preal[:, 0], preal[:, 1], '.', label='real', alpha=0.5)
    plt.plot(pfake[:, 0], pfake[:, 1], '.', label='fake', alpha=0.5)

    # Evaluate netD on the whole step x step grid in one batched forward pass
    # instead of step**2 single-point calls; the second axis spans ymin..ymax.
    gx, gy = np.meshgrid(np.linspace(xmin, xmax, step),
                         np.linspace(ymin, ymax, step), indexing='ij')
    points = np.stack((gx.ravel(), gy.ravel()), axis=1)
    scores = netD(Variable(torch.Tensor(points))).data.numpy()
    grid = scores[:, 0].reshape(step, step)  # grid[i, j] = D(x_i, y_j)

    # Transpose so rows index y, then flip so y increases upward.
    plt.imshow(np.flipud(grid.T), extent=[xmin, xmax, ymin, ymax],
               vmin=0, vmax=1, aspect='auto', cmap="gray")
    plt.ylim(ymin, ymax)
    plt.xlim(xmin, xmax)
    plt.axis('off')
    plt.show()
89.
### MAIN ##########################################################

G_epochs      = 1000 # number of iterations
G_batchsize   = 128  # batchsize
G_hiddens     = 128  # number of hidden neurons
G_d           = 20   # dimensionality of noise
G_extra_discr = 10   # number of iterations for discriminator

# Discriminator: 2-D point -> score in (0, 1); trained with label 1 for real.
netD = MyNet(2, G_hiddens, 1, torch.nn.ReLU(), torch.nn.Sigmoid())
# Generator: G_d-dimensional noise -> 2-D point, identity output activation.
netG = MyNet(G_d, G_hiddens, 2, torch.nn.ReLU(), lambda x: x)

plot_gan(netD, netG)
netG = precondition_gan(netG) # comment to remove pre-conditioning
plot_gan(netD, netG)

# NOTE(review): optD is used by the training loop below, but its definition
# (and presumably optG's) was missing from the pasted source. Adam with
# default settings is an assumption -- confirm against the original.
optD = torch.optim.Adam(netD.parameters())
optG = torch.optim.Adam(netG.parameters())

logD = np.zeros(G_epochs)  # per-epoch discriminator loss
logG = np.zeros(G_epochs)  # presumably per-epoch generator loss

criterion = torch.nn.BCELoss()
113. for epoch in range(G_epochs):
114.     for k in range(G_extra_discr):
115.         batch_d, label_d = sample_batch(G_batchsize, netG, 'D')
117.         errD = criterion(netD(batch_d), label_d)
118.         errD.backward()
119.         optD.step()
120.     logD[epoch] = errD.data[0]
121.
122.     batch_g, label_g = sample_batch(G_batchsize, netG, 'G')