# Mixture density network (MDN) toy regression in PyTorch, after
# https://blog.otoro.net/2015/11/24/mixture-density-networks-with-tensorflow/
import numpy as np
import torch
import torch.distributions as D
from matplotlib import pyplot as plt
from torch import nn
from torch.distributions.mixture_same_family import MixtureSameFamily

# Noisy sinusoid: y = 7*sin(0.75x) + 0.5x + unit Gaussian noise.
NSAMPLE = 1000
x_data = np.float32(np.random.uniform(-10.5, 10.5, (1, NSAMPLE))).T
r_data = np.float32(np.random.normal(size=(NSAMPLE, 1)))
y_data = np.float32(np.sin(0.75 * x_data) * 7.0 + x_data * 0.5 + r_data * 1.0)

# plt.figure(figsize=(8, 8))
# plt.title("Original")
# plot_out = plt.plot(x_data, y_data, 'ro', alpha=0.3)

# Invert the problem by swapping x and y: the inverted mapping is
# one-to-many, which a single-Gaussian regression cannot represent but a
# mixture of Gaussians can.
x_data, y_data = y_data, x_data

# The network predicts, per input, three numbers for each of the
# num_mixtures components: a categorical logit, a Gaussian mean, and a
# Gaussian log-stddev.
num_mixtures = 128
hidden_dim = 64
output_dim = num_mixtures * 3  # categorical_logit, gaussian_mean, gaussian_log_stddev

model = nn.Sequential(
    nn.Linear(1, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, output_dim),
)
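

# Optional shape check (a minimal sketch, not needed for training): one
# forward pass on a dummy batch should yield three (batch, num_mixtures)
# chunks.
_pi_logits, _mu, _log_sigma = model(torch.zeros(5, 1)).chunk(3, dim=-1)
assert _pi_logits.shape == _mu.shape == _log_sigma.shape == (5, num_mixtures)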

optimizer = torch.optim.Adam(params=model.parameters())
# Validate distribution parameters eagerly so invalid mixture weights or
# non-positive scales raise immediately.
torch.distributions.Distribution.set_default_validate_args(True)


def process_categorical_logits(cl):
    # get_mixture_coef: https://blog.otoro.net/2015/11/24/mixture-density-networks-with-tensorflow/
    # A numerically stable softmax: subtract the row-wise max before
    # exponentiating so exp() cannot overflow, then normalize by the sum so
    # the mixture weights form a proper probability distribution.
    out_pi = torch.exp(cl - cl.max(dim=-1, keepdim=True)[0])
    out_pi = out_pi / out_pi.sum(dim=-1, keepdim=True)
    return out_pi
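
# Optional check (a minimal sketch): the function above is a numerically
# stable softmax, so it should agree with torch.softmax.
_cl = torch.randn(4, num_mixtures)
assert torch.allclose(process_categorical_logits(_cl), torch.softmax(_cl, dim=-1), atol=1e-6)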

# Training loop: fit the mixture by minimizing the negative log-likelihood
# of the targets under the predicted Gaussian mixture.
for i in range(1000):
    categorical_logits, gaussian_means, gaussian_logstddevs = model(torch.as_tensor(x_data)).chunk(3, dim=-1)

    # Debug: the range of the mixture weights.
    mixture_weights = process_categorical_logits(categorical_logits)
    print(mixture_weights.max())
    print(mixture_weights.min())

    mix = D.Categorical(probs=mixture_weights)
    comp = D.Normal(loc=gaussian_means, scale=torch.exp(gaussian_logstddevs))
    gmm = MixtureSameFamily(mix, comp)

    loss = -torch.mean(gmm.log_prob(torch.as_tensor(y_data.squeeze())))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Loss {loss.item()}")
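

# For reference, gmm.log_prob above computes the standard mixture
# log-likelihood: log p(y) = logsumexp_k(log pi_k + log N(y; mu_k, sigma_k)).
# A minimal hand-rolled equivalent of that computation (a sketch, not used
# by the script):
def mixture_log_prob(categorical_logits, means, log_stddevs, y):
    log_pi = torch.log_softmax(categorical_logits, dim=-1)  # (N, K)
    log_comp = D.Normal(means, torch.exp(log_stddevs)).log_prob(y.unsqueeze(-1))  # (N, K)
    return torch.logsumexp(log_pi + log_comp, dim=-1)  # (N,)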

x_test = np.float32(np.arange(-15, 15, 0.01))
NTEST = x_test.size
x_test = x_test.reshape(NTEST, 1)  # needs to be a matrix, not a vector


def generate_y_data(x_data, model):
    categorical_logits, gaussian_means, gaussian_logstddevs = model(torch.as_tensor(x_data)).chunk(3, dim=-1)

    # D.Categorical accepts the raw logits directly and applies a stable
    # softmax internally, so no manual normalization is needed here.
    mix = D.Categorical(logits=categorical_logits)
    comp = D.Normal(loc=gaussian_means, scale=torch.exp(gaussian_logstddevs))
    gmm = MixtureSameFamily(mix, comp)
    return gmm.sample()
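

# Sampling from the mixture is equivalent to a two-step draw: pick a
# component index from the Categorical, then sample that component's
# Gaussian. A minimal by-hand version (a sketch, not used by the script):
def sample_mixture_by_hand(categorical_logits, means, log_stddevs):
    k = D.Categorical(logits=categorical_logits).sample()  # (N,)
    mu = means.gather(-1, k.unsqueeze(-1))  # (N, 1)
    sigma = torch.exp(log_stddevs).gather(-1, k.unsqueeze(-1))  # (N, 1)
    return torch.normal(mu, sigma).squeeze(-1)  # (N,)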

plt.figure(figsize=(8, 8))
plt.title("Inverted")
plot_out = plt.plot(x_data, y_data, 'ro', x_test, generate_y_data(x_test, model), 'bo', alpha=0.3)
plt.show()