%matplotlib notebook
from matplotlib import pyplot
import torch
import math
import numpy
from torch.nn import Parameter, Module

def dirichlet_log_pdf(pi, alpha):
    # log Dirichlet(pi | alpha) = lgamma(sum alpha) + sum (alpha - 1) log pi - sum lgamma(alpha)
    numer = torch.lgamma(alpha.sum(0)) + torch.sum(torch.log(pi) * (alpha - 1.0))
    denom = torch.sum(torch.lgamma(alpha))
    return numer - denom
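
# Quick sanity check: with alpha = (1, 1, 1) the Dirichlet is uniform over the
# simplex, so the log density equals lgamma(3) = log(2) for any valid pi.
_pi = torch.FloatTensor([0.2, 0.3, 0.5])
assert abs(float(dirichlet_log_pdf(_pi, torch.ones(3))) - math.log(2.0)) < 1e-4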

def normal_log_pdf(xs, means, cov):
    # Log density of isotropic Gaussians with variance `cov` per dimension.
    # Returns an (n_batch, n_component) matrix of log N(x_i | mean_k, cov * I).
    n_batch, n_dim = xs.size()
    n_component = means.size(0)
    assert isinstance(cov, float)
    xs_ms = xs.unsqueeze(1) - means.unsqueeze(0)  # (n_batch, n_component, n_dim)
    coeff = -0.5 * n_dim * (math.log(2 * math.pi) + math.log(cov))
    xms = xs_ms.view(n_batch * n_component, n_dim, 1)
    pdfs = coeff + (-0.5 * xms.transpose(1, 2).bmm(xms) / cov)
    return pdfs.view(n_batch, n_component)
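
# Quick sanity check: a standard 2-D normal evaluated at its mean has
# log pdf = -log(2 * pi) ~= -1.8379.
_check = normal_log_pdf(torch.zeros(1, 2), torch.zeros(1, 2), 1.0)
assert abs(float(_check[0, 0]) + math.log(2 * math.pi)) < 1e-4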

class GMMSampler(Module):
    def __init__(self, n_dim, n_component):
        super().__init__()
        self.n_dim = n_dim
        self.n_component = n_component
        self.means = Parameter(torch.randn(n_component, n_dim))
        self.cov_of_mean = 1.0  # prior variance of each component mean
        self.mean_of_mean = Parameter(torch.zeros(n_dim))
        # Fixed uniform mixing weights, stored in log space.
        self.log_prior = Parameter(torch.log(torch.ones(n_component) / n_component))

    def select_k(self, xs, ids, k):
        # Return the rows of xs whose component id equals k.
        mask = ids == k
        mask = mask.expand(self.n_dim, xs.size(0)).transpose(0, 1)
        return torch.masked_select(xs, mask).view(-1, self.n_dim)

    def joint_prob(self, xs, ids):
        # Unnormalised log joint of the data, the assignments and the means,
        # computed in log space to avoid the exp/log round trip.
        log_px = normal_log_pdf(xs, self.means.data, 1.0)
        rows = torch.arange(0, xs.size(0), out=xs.new().long())
        log_px_k = log_px[rows, ids]
        log_pxz = torch.sum(log_px_k + self.log_prior.data.index_select(0, ids), dim=0)
        log_p_mean = torch.sum(normal_log_pdf(self.means.data,
                                              self.mean_of_mean.data.unsqueeze(0),
                                              self.cov_of_mean), dim=0)
        return (log_pxz + log_p_mean)[0]
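
    # The value printed during sampling is the (unnormalised) log joint
    #   log p(X, z, mu) = sum_i [ log N(x_i | mu_{z_i}, I) + log pi_{z_i} ]
    #                   + sum_k log N(mu_k | mean_of_mean, cov_of_mean * I),
    # which should tend to increase as the Gibbs chain moves toward a good clustering.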

    def sample(self, xs, n_iter=1):
        assert xs.size(1) == self.n_dim
        for i in range(n_iter):
            # Sample assignments z_i from their posterior given the current means.
            pdfs = normal_log_pdf(xs, self.means.data, 1.0).exp()
            pdfs /= pdfs.sum(dim=1, keepdim=True)
            component_ids = torch.multinomial(pdfs, 1).squeeze(1)
            # Resample each mean from its conjugate posterior given the assignments.
            for k in range(self.n_component):
                x_k = self.select_k(xs, component_ids, k)
                n_k = x_k.size(0)
                if n_k == 0:
                    # No points assigned to k: resample its mean from the prior N(0, I).
                    self.means.data[k].normal_()
                    continue
                x_mean_k = torch.mean(x_k, dim=0)
                # torch.normal takes a standard deviation, hence the sqrt
                # (see the note after the class for the posterior being sampled).
                std_k = math.sqrt(1.0 / (n_k + 1))
                self.means.data[k] = torch.normal(n_k / (n_k + 1) * x_mean_k, std_k)
            print(self.joint_prob(xs, component_ids))
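
# Note on the mean update above (a sketch, assuming unit-covariance components
# and a N(0, I) prior on each mean): conjugacy gives
#   mu_k | X, z ~ N( n_k / (n_k + 1) * x_bar_k,  1 / (n_k + 1) * I ),
# where x_bar_k is the empirical mean of the points assigned to component k.
# For example, with n_k = 3 points whose empirical mean is (4, 0), the posterior
# is N((3, 0), 0.25 * I), i.e. a standard deviation of 0.5 per axis.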

def test_select_k_th():
    n_dim = 2
    gmm = GMMSampler(n_dim, 3)
    n_batch = 10
    xs = torch.randn(n_batch, n_dim)
    ids = torch.zeros(n_batch).long()
    ids[0] = 1
    ids[2] = 1
    ids[-1] = 1
    ys = gmm.select_k(xs, ids, 1)
    assert torch.equal(ys[0], xs[0])
    assert torch.equal(ys[1], xs[2])
    assert torch.equal(ys[2], xs[-1])

test_select_k_th()

use_cuda = False
n = 10
# Three well-separated clusters for a quick visual check.
x1 = torch.randn(n, 2) + torch.FloatTensor([[0.0, 5.0]])
x2 = torch.randn(n, 2) + torch.FloatTensor([[5.0, 0.0]])
x3 = torch.randn(n, 2) + torch.FloatTensor([[0.0, -5.0]])
for x in [x1, x2, x3]:
    pyplot.scatter(x[:, 0].numpy(), x[:, 1].numpy())

xs = torch.cat([x1, x2, x3], dim=0)
gmm = GMMSampler(2, 3)
if use_cuda:
    gmm.cuda()
    xs = xs.cuda()
gmm.sample(xs, 10)
# Detach from the graph (and move off the GPU if needed) before plotting.
for m in gmm.means.data.cpu():
    pyplot.scatter(float(m[0]), float(m[1]))