kowshik0808

Untitled

Jun 10th, 2018
120
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
Python 1.94 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pylab as pl
  3. import ot
  4. # necessary for 3d plot even if not used
  5. from mpl_toolkits.mplot3d import Axes3D  # noqa
  6. from matplotlib.collections import PolyCollection
  7.  
  8. import os
  9.  
  10. ##############################################################################
  11. # Generate data
  12. # -------------
  13.  
  14. #%% parameters
  15.  
  16. n = 200  # nb bins
  17.  
  18. # bin positions
  19. x = np.arange(n, dtype=np.float64)
  20.  
  21. # Gaussian distributions
  22. a1 = ot.datasets.get_1D_gauss(n, m=10, s=5)  # m= mean, s= std
  23.  
  24.  
  25. a2 = ot.datasets.get_1D_gauss(n, m=60, s=5)
  26.  
  27. # creating matrix A - a bimodal distriution
  28. A = a1+0.5*a2
  29.  
  30. b1 = ot.datasets.get_1D_gauss(n, m=100, s=8)  # m= mean, s= std
  31.  
  32.  
  33. b2 = ot.datasets.get_1D_gauss(n, m=150, s=8)
  34.  
  35. # creating matrix B another bimodal
  36. B=0.5*b1+b2
  37. distributions=np.vstack((A,B)).T
  38.  
  39. n_distributions = distributions.shape[1]
  40.  
  41. # loss matrix + normalization
  42. M =ot.dist(x.reshape(n,1),x.reshape(n,1),metric= 'sqeuclidean')
  43. M /= M.max()
  44.  
  45. ##############################################################################
  46. # Plot data
  47. # ---------
  48.  
  49. #%% plot the distributions
  50.  
  51. pl.figure(1, figsize=(6.4, 3))
  52. for i in range(n_distributions):
  53.     pl.plot(x, distributions[:,i])
  54. pl.title('Distributions')
  55. pl.tight_layout()
  56.  
  57.  
  58.  
  59.  
  60. ##############################################################################
  61. # Barycenter computation
  62. # ----------------------
  63.  
  64. #%% barycenter computation
  65.  
  66. alpha = 0.5  # 0<=alpha<=1
  67. weights = np.array([1 - alpha, alpha])
  68.  
  69. # l2bary
  70. bary_l2 = distributions.dot(weights)
  71.  
  72. # wasserstein
  73. reg = 0.001
  74. bary_wass = ot.bregman.barycenter(distributions, M, reg, weights)
  75.  
  76. pl.figure(2)
  77. pl.clf()
  78. pl.subplot(2, 1, 1)
  79. for i in range(n_distributions):
  80.     pl.plot(x, distributions[:, i])
  81. pl.title('Distributions')
  82.  
  83. pl.subplot(2, 1, 2)
  84. pl.plot(x, bary_l2, 'r', label='l2')
  85. pl.plot(x, bary_wass, 'g', label='Wasserstein')
  86. pl.legend()
  87. pl.title('Barycenters')
  88. pl.tight_layout()
Add Comment
Please sign in to add a comment.