Advertisement
Guest User

Untitled

a guest
Feb 10th, 2016
55
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.72 KB | None | 0 0
  1.  
  2. import numbers
  3. from numpy import *
  4.  
  5. class Sampler:
  6.  
  7. def __init__(self, max_entries, max_value=100, min_value=1):
  8. self.nentries = 0
  9. self.max_entries = max_entries
  10. self.max_value = max_value
  11. self.min_value = min_value
  12. self.top_level = int(ceil(log2(max_value)))
  13. self.bottom_level = int(ceil(log2(min_value)))
  14. self.nlevels = 1 + self.top_level - self.bottom_level
  15.  
  16. self.total_weight = 0
  17. self.weights = zeros(max_entries, dtype='d')
  18.  
  19.  
  20. self.level_weights = zeros(self.nlevels, dtype='d')
  21. self.level_buckets = [[] for i in range(self.nlevels)]
  22. self.level_max = [pow(2, self.top_level-i) for i in range(self.nlevels)]
  23.  
  24. def add(self, idx, weight):
  25. if weight > self.max_value or weight < self.min_value:
  26. raise Exception("Weight out of range: %1.2e" % weight)
  27.  
  28. if idx < 0 or idx >= self.max_entries or not isinstance(idx, numbers.Integral):
  29. raise Exception("Bad index: %s", idx)
  30.  
  31. self.nentries += 1
  32. self.total_weight += weight
  33.  
  34. self.weights[idx] = weight
  35.  
  36. raw_level = int(ceil(log2(weight)))
  37. level = self.top_level - raw_level
  38.  
  39. self.level_weights[level] += weight
  40. self.level_buckets[level].append(idx)
  41.  
  42. def _sample(self):
  43.  
  44. u = random.uniform(high=self.total_weight)
  45.  
  46. # Sample a level using the CDF method
  47. cumulative_weight = 0
  48. for i in range(self.nlevels):
  49. cumulative_weight += self.level_weights[i]
  50. level = i
  51. if u < cumulative_weight:
  52. break
  53.  
  54. # Now sample within the level using rejection sampling
  55. level_size = len(self.level_buckets[level])
  56. level_max = self.level_max[level]
  57. reject = True
  58. while reject:
  59. idx_in_level = random.randint(0, level_size)
  60. idx = self.level_buckets[level][idx_in_level]
  61. idx_weight = self.weights[idx]
  62. u_lvl = random.uniform(high=level_max)
  63. if u_lvl <= idx_weight:
  64. reject = False
  65.  
  66. return (idx, level, idx_in_level, idx_weight)
  67.  
  68. def sample(self):
  69. return self._sample()[0]
  70.  
  71. def sampleAndRemove(self):
  72. (idx, level, idx_in_level, weight) = self._sample()
  73.  
  74. # Remove it
  75. self.weights[idx] = 0.0
  76. self.total_weight -= weight
  77. self.level_weights[level] -= weight
  78. # Swap with last element for efficent delete
  79. swap_idx = self.level_buckets[level].pop()
  80. self.level_buckets[level][idx_in_level] = swap_idx
  81. self.nentries -= 1
  82.  
  83. return (idx, weight)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement