Guest User

Untitled

a guest
Dec 13th, 2017
91
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.22 KB | None | 0 0
  1. %matplotlib notebook
  2. from matplotlib import pyplot
  3. import torch
  4. import math
  5. import numpy
  6.  
  7. from torch.nn import Parameter, Module
  8.  
  9. def dirichlet_log_pdf(pi, alpha):
  10. numel = torch.lgamma(alpha.sum(0)) + torch.sum(torch.log(pi) * (alpha - 1.0))
  11. denom = torch.sum(torch.lgamma(alpha))
  12. return numel - denom
  13.  
  14.  
  15. def normal_log_pdf(xs, means, cov):
  16. n_batch, n_dim = xs.size()
  17. n_component = means.size(0)
  18. assert isinstance(cov, float)
  19. xs_ms = xs.unsqueeze(1) - means.unsqueeze(0)
  20. coeff = - n_dim * math.log(2 * math.pi) - math.log(cov)
  21. xms = xs_ms.view(n_batch * n_component, n_dim, 1)
  22. pdfs = coeff + (-0.5 * xms.transpose(1, 2).bmm(xms) / cov)
  23. return pdfs.view(n_batch, n_component)
  24.  
  25.  
  26. class GMMSampler(Module):
  27. def __init__(self, n_dim, n_component):
  28. super().__init__()
  29. self.n_dim = n_dim
  30. self.n_component = n_component
  31. self.means = Parameter(torch.randn(n_component, n_dim))
  32. self.cov_of_mean = 1.0
  33. self.mean_of_mean = Parameter(torch.zeros(n_dim))
  34. self.log_prior = Parameter(torch.log(torch.ones(n_component) / n_component))
  35.  
  36. def select_k(self, xs, ids, k):
  37. mask = ids == k
  38. mask = mask.expand(self.n_dim, xs.size(0)).transpose(0, 1)
  39. return torch.masked_select(xs, mask).view(-1, self.n_dim)
  40.  
  41. def joint_prob(self, xs, ids):
  42. px = normal_log_pdf(xs, self.means.data, 1.0).exp()
  43. px_k = px[torch.arange(0, xs.size(0), out=xs.new().long()), ids]
  44. log_pxz = torch.sum(px_k.log() + self.log_prior.data.index_select(0, ids), dim=0)
  45. p_mean = normal_log_pdf(self.means.data,
  46. self.mean_of_mean.data.unsqueeze(0),
  47. self.cov_of_mean)
  48. log_p_mean = torch.sum(p_mean, dim=0)
  49. return (log_pxz + log_p_mean)[0]
  50.  
  51. def sample(self, xs, n_iter=1):
  52. assert xs.size(1) == self.n_dim
  53. for i in range(n_iter):
  54. pdfs = normal_log_pdf(xs, self.means.data, 1.0).exp()
  55. pdfs /= pdfs.sum(dim=1, keepdim=True)
  56. component_ids = torch.multinomial(pdfs, 1).squeeze(1)
  57. for k in range(self.n_component):
  58. x_k = self.select_k(xs, component_ids, k)
  59. n_k = 0 if x_k.dim() == 0 else x_k.size(0)
  60. x_mean_k = torch.mean(x_k, dim=0)
  61. self.means.data[k] = torch.normal(n_k / (n_k + 1) * x_mean_k, 1.0 / (n_k + 1))
  62. print(self.joint_prob(xs, component_ids))
  63.  
  64. def test_select_k_th():
  65. n_dim = 2
  66. gmm = GMMSampler(n_dim, 3)
  67. n_batch = 10
  68. xs = torch.randn(n_batch, n_dim)
  69. ids = torch.zeros(n_batch)
  70. ids[0] = 1
  71. ids[2] = 1
  72. ids[-1] = 1
  73. ys = gmm.select_k(xs, ids, 1)
  74. assert torch.equal(ys[0], xs[0])
  75. assert torch.equal(ys[1], xs[2])
  76. assert torch.equal(ys[2], xs[-1])
  77.  
  78. test_select_k_th()
  79.  
  80. use_cuda = False
  81. n = 10
  82. x1 = torch.randn(n, 2) + torch.FloatTensor([[0.0, 5.0]])
  83. x2 = torch.randn(n, 2) + torch.FloatTensor([[5.0, 0.0]])
  84. x3 = torch.randn(n, 2) + torch.FloatTensor([[0.0, -5.0]])
  85. for x in [x1, x2, x3]:
  86. pyplot.scatter(x[:, 0].numpy(), x[:, 1].numpy())
  87.  
  88. xs = torch.cat([x1, x2, x3], dim=0)
  89.  
  90. gmm = GMMSampler(2, 3)
  91. if use_cuda:
  92. gmm.cuda()
  93. xs = xs.cuda()
  94. gmm.sample(xs, 10)
  95. for m in gmm.means:
  96. pyplot.scatter(m[0], m[1])
Add Comment
Please, Sign In to add comment