Advertisement
Guest User

Untitled

a guest
Oct 19th, 2017
58
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 21.79 KB | None | 0 0
  1. def _step_left_endpoint(obj_func, y, L, J, w, m):
  2. while y < obj_func(L):
  3. L -= w
  4. if m:
  5. J -= 1
  6. if J <= 0:
  7. break
  8. return L
  9. def _step_right_endpoint(obj_func, y, R, K, w, m):
  10. while y < obj_func(R):
  11. R += w
  12. if m:
  13. K -= 1
  14. if K <= 0:
  15. break
  16. return R
  17.  
  class slice_sampler(object):
      """ Generic slice-sampler.
      Base class for both univariate and multivariate methods

      See http://www.cs.toronto.edu/~radford/ftp/slc-samp.pdf for theory
      and details
      """

      def __init__(self, log_f):
          """ Instantiates a slice sampler

          Parameters
          ----------
          log_f : callable
              A function that returns the log of the function you want to
              sample and accepts a numpy array as an argument (the x)

          Raises
          ------
          TypeError
              If log_f is not callable
          """
          if not callable(log_f):
              raise TypeError("log_f is not callable")
          ## The function to sample
          self._g = log_f

      def set_function(self, log_f):
          """ Sets the function being sampled

          Parameters
          ----------
          log_f : callable
              A function that returns the log of the function you want to
              sample and accepts a numpy array (or scalar if univariate function)
              as an argument (the x)

          Raises
          ------
          TypeError
              If log_f is not callable. The previously set function is kept.
          """
          # Same validation as the constructor, so a bad replacement is
          # rejected here instead of failing later in the middle of sampling.
          if not callable(log_f):
              raise TypeError("log_f is not callable")
          self._g = log_f
  49.  
  class univariate_slice_sampler(slice_sampler):
      """ Does slice sampling on univariate distributions

      The interval-finding ("doubling" and "stepping out"), acceptance and
      shrinkage procedures follow the Radford paper at
      http://www.cs.toronto.edu/~radford/ftp/slc-samp.pdf
      """

      def __init__(self, log_f, adapt_w=False, start_w=0.1):
          """ Instantiates a slice sampler

          Parameters
          ----------
          log_f : callable
              A function that returns the log of the function you want to
              sample and accepts a scalar as an argument (the x)
          adapt_w : boolean
              Whether to adapt w during sampling. Will work in between samplings
          start_w : float
              Starting value for w, necessary if adapting w during sampling
          """
          super(univariate_slice_sampler, self).__init__(log_f)
          self._w = start_w
          self._adapt_w = adapt_w
          ## Parameter for number of data points going into w
          self._w_n = 1.

      def accept_doubling(self, x0, x1, y, w, interval):
          """ Test for whether a new point, x1, is an acceptable next state,
          when the interval was found by the doubling procedure
          (See Radford paper at http://www.cs.toronto.edu/~radford/ftp/slc-samp.pdf)

          Parameters
          ----------
          x0 : float
              The current point
          x1 : float
              The possible next point
          y : float
              The vertical level defining the slice
          w : float
              Estimate of typical slice size
          interval : Ranger.Range
              The interval found by the doubling procedure

          Returns
          -------
          Whether or not x1 is an acceptable next state
          """
          L_hat = interval.lowerEndpoint()
          R_hat = interval.upperEndpoint()
          # D records whether x0 and x1 have ever fallen in different halves
          # of the interval while it is being bisected towards x1.
          D = False
          while (R_hat - L_hat) > 1.1*w:
              M = (L_hat + R_hat)/2.
              if (x0 < M and x1 >= M) or (x0 >= M and x1 < M):
                  D = True
              # Keep the half that contains x1
              if x1 < M:
                  R_hat = M
              else:
                  L_hat = M
              if D and (y > self._g(L_hat) and y >= self._g(R_hat)):
                  return False
          return True

      def find_interval_doubling(self, x0, y, w, p=None):
          """ Finds an interval around a current point, x0, using a doubling procedure
          (See Radford paper at http://www.cs.toronto.edu/~radford/ftp/slc-samp.pdf)

          Parameters
          ----------
          x0 : float
              The current point
          y : float
              The vertical level defining the slice
          w : float
              Estimate of typical slice size
          p : int, optional
              Integer limiting the size of a slice to (2^p)w. If None,
              then interval can grow without bound

          Returns
          -------
          Range containing the interval found
          """
          # Sample initial interval of width w around x0
          U = np.random.random()
          L = x0 - w*U
          R = L + w
          # Number of doublings still allowed (only tracked when p is given)
          K = p if p else None
          # Double until both ends are outside the slice
          while y < self._g(L) or y < self._g(R):
              V = np.random.random()
              if V < 0.5:
                  L -= (R - L)
              else:
                  R += (R - L)
              if p:
                  K -= 1
                  if K <= 0:
                      break
          # Return the interval
          return Range.closed(L, R)

      def find_interval_step_out(self, x0, y, w, m=None):
          """ Finds an interval around a current point, x0, using the stepping-out
          procedure (See Radford paper at http://www.cs.toronto.edu/~radford/ftp/slc-samp.pdf)

          Parameters
          ----------
          x0 : float
              The current point
          y : float
              The vertical level defining the slice
          w : float
              Estimate of typical slice size
          m : int, optional
              Integer, where maximum size of slice should be mw. If None,
              then interval can grow without bound

          Returns
          -------
          Range containing the interval found
          """
          # Sample initial interval of width w around x0
          U = np.random.random()
          L = x0 - w*U
          R = L + w
          # When a limit is given, split the m-1 allowed steps at random
          # between the left (J) and right (K) sides
          J = None
          K = None
          if m:
              V = np.random.random()
              J = np.floor(m*V)
              K = (m - 1) - J
          # Get left endpoint
          L = _step_left_endpoint(self._g, y, L, J, w, m)
          # Get right endpoint
          R = _step_right_endpoint(self._g, y, R, K, w, m)
          # Return interval
          return Range.closed(L, R)

      def run_sampler(self, x0_start=0., n_samp=10000, interval_method='doubling',
                      w=0.1, m=None, p=None):
          """ Runs the slice sampler

          Parameters
          ----------
          x0_start : float
              An initial value for x
          n_samp : int
              The number of samples to take
          interval_method : str
              The method for determining the interval at each stage of sampling. Possible values
              are 'doubling', 'stepping'.
          w : float
              Estimate of typical slice size. If adapt_w is true, then this is overridden
          m : int, optional (Only relevant for stepping interval procedure)
              Integer, where maximum size of slice should be mw. If None,
              then interval can grow without bound.
          p : int, optional (Only relevant for doubling interval procedure)
              Integer limiting the size of a slice to (2^p)w. If None,
              then interval can grow without bound

          Returns
          -------
          Generator of samples from the distribution

          Raises
          ------
          ValueError
              When the first sample is requested, if interval_method is not
              one of the possible values

          Examples
          --------
          >>> from quantgen.stats_utils.slice_sampler import univariate_slice_sampler
          >>> from scipy.stats import norm
          >>> sampler = univariate_slice_sampler(lambda x: norm.logpdf(x, loc=0, scale=5.))
          >>> samples = [x for x in sampler.run_sampler(n_samp=1000, w=2.5)]
          """
          x0 = x0_start
          # The doubling acceptance test is only needed for doubling intervals
          doubling_used = (interval_method == 'doubling')
          if self._adapt_w:
              w = self._w
          for _ in range(n_samp):
              # Draw vertical value, y, that defines the horizontal slice
              y = self._g(x0) - np.random.exponential()
              # Find interval around x0 that contains at least a big part of the slice
              if interval_method == 'doubling':
                  interval = self.find_interval_doubling(x0, y, w=w, p=p)
              elif interval_method == 'stepping':
                  interval = self.find_interval_step_out(x0, y, w=w, m=m)
              else:
                  raise ValueError("%s is not an interval method" % interval_method)
              # Draw new point
              x0 = self.sample_by_shrinkage(x0, y, w, interval, doubling_used=doubling_used)
              # Update w if necessary: weighted geometric mean of the previous
              # w and half of the latest interval length
              if self._adapt_w:
                  interval_length = interval.upperEndpoint() - interval.lowerEndpoint()
                  n = self._w_n
                  self._w = (np.power(self._w, n/(n + 1.)) *
                             np.power(interval_length/2., 1./(n + 1.)))
                  self._w_n += 1.
                  w = self._w
              yield float(x0)

      def sample_by_shrinkage(self, x0, y, w, interval, doubling_used=False):
          """ Samples a point from an interval using the shrinkage procedure
          (See Radford paper at http://www.cs.toronto.edu/~radford/ftp/slc-samp.pdf)

          Parameters
          ----------
          x0 : float
              The current point
          y : float
              The vertical level defining the slice
          w : float
              Estimate of typical slice size
          interval : Ranger.Range
              The interval found by the doubling or stepping-out procedure
          doubling_used : boolean
              Whether the doubling procedure was used when defining the interval

          Returns
          -------
          The new point
          """
          # Set initial interval
          L_bar = interval.lowerEndpoint()
          R_bar = interval.upperEndpoint()
          # Run through shrinkage
          while True:
              U = np.random.random()
              x1 = L_bar + U*(R_bar - L_bar)
              # The extra acceptance test only applies to doubling intervals
              # and is only evaluated for points already inside the slice
              if y < self._g(x1) and (
                      not doubling_used
                      or self.accept_doubling(x0, x1, y, w, interval)):
                  break
              # Rejected: shrink the interval towards x0
              if x1 < x0:
                  L_bar = x1
              else:
                  R_bar = x1
          return x1
  290.  
  291. #cython: wraparound=False
  292. #cython: boundscheck=False
  293. #cython: cdivision=True
  294. #cython: nonecheck=False
  295. from cpython cimport array
  296. import cython
  297. import numpy as np
  298. import ctypes
  299. cimport numpy as np
  300. cimport cython
  301. cimport python_unicode
  302. from libc.stdlib cimport malloc, free
  303. from libcpp.vector cimport vector
  # C math routines used by the sampler
  cdef extern from "<math.h>":
      double floor(double)
      double log(double)
      double pow(double, double)
      double tgamma(double)

  cdef extern from "stdint.h":
      ctypedef unsigned long long uint64_t

  # GSL random number generator: opaque types, the MT19937 generator type
  # and the allocator
  cdef extern from "gsl/gsl_rng.h":#nogil:
      ctypedef struct gsl_rng_type:
          pass
      ctypedef struct gsl_rng:
          pass
      gsl_rng_type *gsl_rng_mt19937
      gsl_rng *gsl_rng_alloc(gsl_rng_type * T)

  # Module-wide generator shared by every sampling routine in this file
  cdef gsl_rng *r = gsl_rng_alloc(gsl_rng_mt19937)

  # GSL random variates, exposed under shorter names
  cdef extern from "gsl/gsl_randist.h" nogil:
      double unif "gsl_rng_uniform"(gsl_rng * r)
      double unif_interval "gsl_ran_flat"(gsl_rng * r,double,double) ## syntax; (seed, lower, upper)
      double exponential "gsl_ran_exponential"(gsl_rng * r,double) ## syntax; (seed, mean) ... mean is 1/rate

  # Signature of the log-density callbacks: double g(double x)
  ctypedef double (*func_t)(double)
  328.  
  cdef class wrapper:
      """Holds a C function pointer of type ``func_t`` (``double f(double)``)
      so that it can be passed through Python-level function arguments.
      """
      cdef func_t wrapped

      def __call__(self, value):
          """Evaluate the wrapped C function at ``value``.

          Raises ValueError if no function pointer has been stored yet.
          """
          # Calling through a NULL pointer would crash the interpreter, so
          # reject an unset wrapper explicitly.
          if self.wrapped == NULL:
              raise ValueError("wrapper has no C function attached")
          return self.wrapped(value)

      def __unsafe_set(self, ptr):
          """Store ``ptr``, an integer address, as the wrapped function pointer.

          The address is cast without any validation; the caller must
          guarantee it points to a function with the ``func_t`` signature.
          """
          self.wrapped = <func_t><void *><size_t>ptr
  335.  
  # Storage for the interval returned by stepping_out(). The function used to
  # malloc() the result and free() it in a ``finally`` clause while returning
  # it, handing the caller a dangling pointer.
  cdef double _stepping_out_interv[2]

  cdef double* stepping_out(double x0, double y, double w, int m, func_t f):
      """
      Function for finding an interval around the current point
      using the "stepping out" procedure (Figure 3, Neal (2003))
      Parameters of stepping_out subroutine:
      Input:
      x0 ------------ the current point
      y ------------ logarithm of the vertical level defining the slice
      w ------------ estimate of the typical size of a slice
      m ------------ integer limiting the size of a slice to "m*w"
                     (no limit when m <= 0)
      (*func_t) ------------ routine to compute g(x) = log(f(x))
      Output:
      interv[2] ------------ the left and right sides of found interval

      The returned pointer refers to module-level storage: it must not be
      freed by the caller and its contents are overwritten by the next call.
      """
      cdef double *interv = &_stepping_out_interv[0]
      cdef double u
      cdef int J = 0
      cdef int K = 0
      cdef bint limited = m > 0
      cdef double g_interv[2]

      #Initial guess for the interval
      u = unif_interval(r,0,1)
      interv[0] = x0 - w*u
      interv[1] = interv[0] + w

      #Get numbers of steps tried to left and to right
      if limited:
          u = unif_interval(r,0,1)
          J = <int>floor(m*u)
          K = (m-1)-J

      #Initial evaluation of g in the left and right limits of the interval
      g_interv[0]=f(interv[0])
      g_interv[1]=f(interv[1])

      #Step to left until leaving the slice. The step budget is tested
      #before stepping, so J == 0 takes no step and the interval cannot
      #exceed m*w.
      while (not limited or J > 0) and (g_interv[0] > y):
          interv[0] -= w
          g_interv[0]=f(interv[0])
          if limited:
              J-=1

      #Step to right until leaving the slice (same test as the left side)
      while (not limited or K > 0) and (g_interv[1] > y):
          interv[1] += w
          g_interv[1]=f(interv[1])
          if limited:
              K-=1

      return interv
  395.  
  # Storage for the interval returned by doubling(). The function used to
  # malloc() the result and free() it in a ``finally`` clause while returning
  # it, handing the caller a dangling pointer.
  cdef double _doubling_interv[2]

  cdef double* doubling(double x0, double y, double w, int p, func_t f):
      """
      Function for finding an interval around the current point
      using the "doubling" procedure (Figure 4, Neal (2003))
      Input:
      x0 ------------ the current point
      y ------------ logarithm of the vertical level defining the slice
      w ------------ estimate of the typical size of a slice
      p ------------ integer limiting the size of a slice to "2^p*w"
                     (no limit when p <= 0)
      (*func_t) ------------ routine to compute g(x) = log(f(x))
      Output:
      interv[2] ------------ the left and right sides of found interval

      The returned pointer refers to module-level storage: it must not be
      freed by the caller and its contents are overwritten by the next call.
      """
      cdef double *interv = &_doubling_interv[0]
      cdef double u
      cdef int K = p
      cdef bint limited = p > 0
      cdef double g_interv[2]

      #Initial guess for the interval
      u = unif_interval(r,0,1)
      interv[0] = x0 - w*u
      interv[1] = interv[0] + w

      # Initial evaluation of g in the left and right limits of the interval
      g_interv[0]= f(interv[0])
      g_interv[1]= f(interv[1])

      # Perform doubling until both ends are outside the slice, or until
      # the allowed number of doublings has been used up
      while (not limited or K > 0) and ((g_interv[0] > y) or (g_interv[1] > y)):
          u = unif_interval(r,0,1)
          if u < 0.5:
              # Extend to the left by the current width
              interv[0] -= (interv[1] - interv[0])
              g_interv[0]=f(interv[0])
          else:
              # Extend to the right by the current width
              interv[1] += (interv[1] - interv[0])
              g_interv[1]=f(interv[1])
          if limited:
              K-=1

      return interv
  446.  
  cdef bint accept_doubling(double x0, double x1, double y, double w, np.ndarray[ndim=1, dtype=np.float64_t] interv, func_t f):
      """
      Acceptance test of newly sampled point when the "doubling" procedure has been used to find an
      interval to sample from (Figure 6, Neal (2003))
      Parameters
      Input:
      x0 ------------ the current point
      x1 ------------- the possible next candidate point
      y ------------ logarithm of the vertical level defining the slice
      w ------------ estimate of the typical size of a slice
      interv[2] ------------ the left and right sides of found interval
      (*func_t) ------------ routine to compute g(x) = log(f(x))
      Output:
      accept ------------ True/False indicating whether the point is acceptable or not
      """
      cdef double interv1[2]
      cdef double g_interv1[2]
      cdef bint D
      cdef double w11, mid
      w11 = 1.1*w
      interv1[0] = interv[0]
      interv1[1] = interv[1]
      # Evaluate g at both ends up front: each bisection step refreshes only
      # the end that moved, so without this the rejection test below would
      # read a value that was never written.
      g_interv1[0] = f(interv1[0])
      g_interv1[1] = f(interv1[1])
      # D records whether x0 and x1 have ever fallen in different halves
      D = False
      while ( (interv1[1] - interv1[0]) > w11):
          mid = 0.5*(interv1[0] + interv1[1])
          if ((x0 < mid) and (x1 >= mid)) or ((x0 >= mid) and (x1 < mid)):
              D = True
          # Keep the half that contains x1 and refresh g at the moved end
          if (x1 < mid):
              interv1[1] = mid
              g_interv1[1] = f(interv1[1])
          else:
              interv1[0] = mid
              g_interv1[0] = f(interv1[0])
          if (D and (g_interv1[0] < y) and (g_interv1[1] <= y)):
              return False
      return True
  483.  
  484.  
  cdef double shrinkage(double x0, double y, double w, np.ndarray[ndim=1, dtype=np.float64_t] interv, bint doubling, func_t f):
      """
      Draw a point from the interval, shrinking the interval towards the
      current point each time a draw is rejected (Figure 5, Neal (2003))
      Input:
      x0 ------------ the current point
      y ------------ logarithm of the vertical level defining the slice
      w ------------ estimate of the typical size of a slice
      interv[2] ------------ the left and right sides of found interval
      doubling ------------ 0/1 indicating whether doubling was used to find an interval
      (*func_t) ------------ routine to compute g(x) = log(f(x))
      Output:
      x1 ------------- newly sampled point
      """
      cdef double lower = interv[0]
      cdef double upper = interv[1]
      cdef double draw, candidate, g_candidate
      cdef bint passes_doubling_test
      while True:
          draw = unif_interval(r,0,1)
          candidate = lower + draw*(upper - lower)
          g_candidate = f(candidate)
          # In doubling mode the acceptance test is run for every candidate;
          # otherwise being inside the slice is sufficient.
          passes_doubling_test = True
          if doubling:
              passes_doubling_test = accept_doubling(x0, candidate, y, w, interv, f)
          if (g_candidate > y) and passes_doubling_test:
              break
          # Rejected: move the bound on the candidate's side of x0 inwards
          if candidate < x0:
              lower = candidate
          else:
              upper = candidate
      return candidate
  cdef double log_beta(double x):
      """Log-density of the Beta(a, 1) distribution with a = 5.

      Returns log(a) + (a-1)*log(x) for x in (0, 1] and -infinity outside
      that range.
      """
      #Log of beta distribution with second argument b=1
      cdef double a=5.
      # Outside the support the density is zero. Without this check the
      # expression keeps increasing for x > 1 (and is NaN for x < 0), so an
      # interval search with no size limit never finds a right edge.
      if x <= 0. or x > 1.:
          return log(0.)  # -infinity
      return log(a)+(a-1.)*log(x)
  529.  
  cdef wrapper make_wrapper(func_t f):
      """Return a new ``wrapper`` instance whose ``wrapped`` pointer is ``f``."""
      cdef wrapper boxed = wrapper()
      boxed.wrapped = f
      return boxed
  534.  
  def slice_sampler(int n_sample,
                    wrapper f,
                    int m = 0,
                    int p = 0,
                    double x0_start=0.0,
                    bint adapt_w=False,
                    double w_start=0.1,
                    char* interval_method ='doubling'):
      """
      Draws samples from a univariate distribution by slice sampling.

      Inputs:
      n_sample ------------ Number of sample points from the given distribution
      f ------------ wrapper around the log of the function you want to sample; accepts a scalar as an argument (the x)
      m ------------ Integer, where maximum size of slice should be mw. If 0, then interval can grow without bound
      p ------------ Integer limiting the size of a slice to (2^p)w. If 0, then interval can grow without bound
      x0_start ------------ An initial value for x
      adapt_w ------------ Whether to adapt w during sampling
      w_start ------------ Starting value for w (estimate of the typical slice size); must be positive
      interval_method ------------ The method for determining the interval at each stage of sampling. Possible values are 'doubling', 'stepping'.
      Output:
      samples ------------ the n_sample sampled points, in the order drawn
      Raises:
      TypeError ------------ f is None
      ValueError ------------ f holds no function, w_start is not positive, or interval_method is not recognised
      """
      cdef unicode s = interval_method.decode('UTF-8', 'strict')
      cdef bint doubling_used = (s == u'doubling')
      cdef double x0 = x0_start
      cdef double w = w_start
      cdef double w_n = 1.
      cdef double vertical
      cdef double interval_length
      cdef double *bounds
      cdef func_t g
      cdef vector[double] samples #http://cython.readthedocs.io/en/latest/src/userguide/wrapping_CPlusPlus.html
      cdef Py_ssize_t i
      cdef np.ndarray[ndim=1, dtype=np.float64_t] interv = np.empty(2, dtype=np.float64)

      # Reject unusable arguments before any sampling is done
      if not doubling_used and s != u'stepping':
          raise ValueError("%s is not an acceptable interval method for slice sampler"%s )
      if f is None:
          raise TypeError("f must be a wrapper instance, not None")
      if f.wrapped == NULL:
          raise ValueError("f does not wrap a C function")
      if not (w_start > 0.):
          # A zero-width starting interval can never grow
          raise ValueError("w_start must be positive")
      g = f.wrapped

      if n_sample > 0:
          samples.reserve(n_sample)
      for i in range(n_sample):
          # Log of the vertical level defining the slice
          vertical = g(x0) - exponential(r, 1)
          # Find an interval around x0
          if doubling_used:
              bounds = doubling(x0, vertical, w, p, g)
          else:
              bounds = stepping_out(x0, vertical, w, m, g)
          # Copy the two bounds out straight away rather than keeping a
          # view onto memory that belongs to the interval finder
          interv[0] = bounds[0]
          interv[1] = bounds[1]
          # Draw the new point from the interval
          x0=shrinkage(x0, vertical, w, interv, doubling_used, g)
          samples.push_back(x0)

          if adapt_w:
              # Weighted geometric mean of the previous w and half of the
              # latest interval length
              interval_length=interv[1]-interv[0]
              w=pow(w,w_n/(w_n+1))*pow(interval_length/2.,1./(w_n+1.))
              w_n+=1.
      return samples
  590.  
  def run(int n_sample,
          double x0_start=0.01,
          double w_start=2.5):
      """Draw ``n_sample`` points from the ``log_beta`` density.

      Calls ``slice_sampler`` with its default interval method, passing
      ``x0_start`` as the initial point and ``w_start`` as the initial slice
      size, and returns its result.
      """
      return slice_sampler(n_sample, make_wrapper(log_beta),
                           x0_start=x0_start, w_start=w_start)
  596.  
  # Build script for the SliceSampler extension (C++, linked against GSL).
  from distutils.core import setup
  from distutils.extension import Extension

  import numpy
  from Cython.Build import cythonize
  from Cython.Distutils import build_ext

  # -Wall is a compiler diagnostic flag, so it belongs with the compile
  # arguments rather than the link arguments.
  extra_compile_args = ['-std=c++11', '-Wall']
  extra_link_args = []

  extensions = [
      Extension("SliceSampler",
                sources=["SliceSampler.pyx"],
                language="c++",
                libraries=["stdc++", "gsl", "gslcblas"],
                include_dirs=[numpy.get_include()],
                extra_compile_args=extra_compile_args,
                extra_link_args=extra_link_args),
  ]

  setup(
      cmdclass={'build_ext': build_ext},
      # gdb_debug is an option of cythonize(); passed to setup() it is not a
      # recognised distribution option.
      ext_modules=cythonize(extensions, gdb_debug=True),
  )
  616.  
  # Draw samples with the compiled SliceSampler module and plot their
  # normalised histogram.
  import numpy as np
  import pylab as plt
  from SliceSampler import run

  x = run(10000)
  # ``density=True`` replaces the ``normed`` keyword, which newer Matplotlib
  # releases no longer accept.
  plt.hist(x, 50, density=True, histtype='step', lw=2, color='r', label="Beta")
  plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement