Guest User

Untitled

a guest
Oct 21st, 2018
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.12 KB | None | 0 0
  1. """Sample random numbers according to a uniformly sampled 2d probability
  2. density function.
  3.  
  4. The method used here is rejection sampling with a uniform proposal distribution.
  5. """
  6.  
  7. import numpy as np
  8. import sys
  9. from scipy.interpolate import interp2d, RectBivariateSpline, SmoothBivariateSpline
  10.  
  11. class sampler(object):
  12.  
  13. def __init__(self, x, y, z, m=0.95, cond=None):
  14. """Create a sampler object from data.
  15.  
  16. x,y : arrays
  17. 1d arrays for x and y data.
  18. z : array
  19. Samples PDF of shape [len(x), len(y)]. Does not need to be
  20. normalized correctly.
  21. m : float, optional
  22. Number in [0; 1). Used in renormalization of PDF. Random samples
  23. (x,y) will be accepted if PDF_renormalized(x, y) >= Random[0; 1).
  24. Low m values will create more values regions of low PDF.
  25. cond : function, optional
  26. A boolean function of x and y. True if the value in the x,y plane
  27. is of interest.
  28.  
  29. Notes
  30. -----
  31. To restrict x and y to the unit circle, use
  32. cond=lambda x,y: x**2 + y**2 <= 1.
  33.  
  34. For more information on the format of x, y, z see the docstring of
  35. scipy.interpolate.interp2d().
  36.  
  37. Note that interpolation can be very, very slow for larger z matrices
  38. (say > 100x100).
  39. """
  40.  
  41. # check validity of input:
  42. if(np.any(z < 0.0)):
  43. print >> sys.stderr("z has negative values and thus is not a density!")
  44. return
  45.  
  46. if(not 0.0 < m < 1.0):
  47. print >> sys.stderr("m has to be in (0; 1)!")
  48. return
  49.  
  50. maxVal = np.max(z)
  51. z *= m/maxVal # normalize maximum value in z to m
  52.  
  53. print("Preparing interpolating function")
  54. self._interp = RectBivariateSpline(x, y, z.transpose()) # TODO FIXME: why .transpose()?
  55. print("Interpolation done")
  56. self._xRange = (x[0], x[-1]) # set x and y ranges
  57. self._yRange = (y[0], y[-1])
  58.  
  59. self._cond = cond
  60.  
  61. def sample(self, size=1):
  62. """Sample a given number of random numbers with following given PDF.
  63.  
  64. Parameters
  65. ----------
  66. size : int
  67. Create this many random variates.
  68.  
  69. Returns
  70. -------
  71. vals : list
  72. List of tuples (x_i, y_i) of samples.
  73. """
  74.  
  75. vals = []
  76.  
  77. while(len(vals) < size):
  78.  
  79. # first create x and y samples in the allowed ranges (shift from [0, 1)
  80. # to [min, max))
  81. while(True):
  82. x, y = np.random.rand(2)
  83. x = (self._xRange[1]-self._xRange[0])*x + self._xRange[0]
  84. y = (self._yRange[1]-self._yRange[0])*y + self._yRange[0]
  85.  
  86. # additional condition true? --> use these values
  87. if(self._cond is not None):
  88. if(self._cond(x, y)):
  89. break
  90. else:
  91. continue
  92. else: # no condition -> use values immediately
  93. break
  94.  
  95. # to decide if the values are to be kept, sample the PDF there and
  96. # decide about rejection
  97. chance = np.random.ranf()
  98. PDFsample = self._interp(x, y)
  99.  
  100. # keep or reject sample? if at (x,y) the renormalized PDF is >= than
  101. # the random number generated, keep the sample
  102. if(PDFsample >= chance):
  103. vals.append((x, y))
  104.  
  105. return vals
  106.  
  107. if(__name__ == '__main__'): # test with an illustrative plot
  108.  
  109. # create a sin^2*Gaussian PDF on the unit square and create random variates
  110. # inside the upper half of a disk centered on the middle of the square
  111.  
  112. import matplotlib.pyplot as plt
  113. gridSamples = 1024
  114. x = np.linspace(0, 1., gridSamples)
  115. y = np.linspace(0, 1., gridSamples)
  116. XX, YY = np.meshgrid(x, y)
  117. # sample a sin^2*cos^2*Gaussian PDF (not normalized)
  118. z = np.exp(-(XX-0.5)**2/(2*0.2**2) -(YY-0.5)**2/(2*0.1**2) )*np.sin(2*np.pi*XX)**2*np.cos(4*np.pi*(YY+XX))**2
  119.  
  120. s = sampler(x, y, z, cond=lambda x,y: (x-0.5)**2 + (y-0.5)**2 <= 0.4**2 and y > 0.5)
  121.  
  122. vals = s.sample(5000); xVals = []; yVals = [];
  123.  
  124. # plot sampled random variates over PDF
  125. plt.imshow(z, cmap=plt.cm.Blues, origin="lower",
  126. extent=(s._xRange[0], s._xRange[1], s._yRange[0], s._yRange[1]),
  127. aspect="equal")
  128. for item in vals: # plot point by point
  129. xVals.append(item[0])
  130. yVals.append(item[1])
  131. plt.scatter(item[0], item[1], marker="x", c="red")
  132. plt.show()
  133.  
  134. # create a histogram/density plot for random variates and plot over PDF
  135. hist, bla, blubb = np.histogram2d(xVals, yVals, bins=100, normed=True, range=((s._xRange[0], s._xRange[1]), (s._yRange[0], s._yRange[1])))
  136. plt.imshow(z, cmap=plt.cm.Blues, extent=(s._xRange[0], s._xRange[1], s._yRange[0], s._yRange[1]), aspect="equal", origin="lower")
  137. plt.title("'PDF'")
  138. plt.show()
  139. # have to plot transpose of hist because of NumPy's convention for histogram2d
  140. plt.imshow(hist.transpose(), cmap=plt.cm.Reds, extent=(s._xRange[0], s._xRange[1], s._yRange[0], s._yRange[1]), aspect="equal", origin="lower")
  141. plt.title("Density of random variates")
  142. plt.show()
Add Comment
Please, Sign In to add comment