Advertisement
creamygoat

Dartboard vs Resampling Wheel

Jun 13th, 2012
557
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 12.32 KB | None | 0 0
  1. #!/usr/bin/python
  2. #-------------------------------------------------------------------------------
  3. # dbvsrsw.py
  4. # http://pastebin.com/F4JZtDTB
  5. #-------------------------------------------------------------------------------
  6.  
  7.  
  8. '''DBvsRSW, a Python script for testing the fairness of a resampling wheel.
  9.  
  10. Description:
  11.  
  12.  A resampling wheel is a much more efficient method for taking weighted
  13.  samples from a population according to their weights than a dartboard.
  14.  The resampling wheel is not completely unbiased, though it's likely
  15.  to be adequate for genetic algorithms like particle filtering where
  16.  the problem of over-filtering is a more pressing concern.
  17.  
  18.  For a list of weights in increasing or decreasing order, at least,
  19.  the resampling wheel often performs very poorly if the starting index
  20.  is not randomised and the angle step factor is small.
  21.  
  22.  Included in this testing suite is the brilliant O(N) resampling wheel
  23.  presented by Erik Colban on the Udacity CS373 forum.
  24.  
  25. Author:
  26.  Daniel Neville (Blancmange), creamygoat@gmail.com
  27.  (Resampling Wheel adapted from Prof. Sebastian Thrun's CS373 lecture notes)
  28.  (SmartWheel adapted from Erik Colban's O(N) weighted resampler)
  29.  
  30. Copyright:
  31.  None
  32.  
  33. Licence:
  34.  Public domain
  35.  
  36.  
  37. INDEX
  38.  
  39.  
  40. Imports
  41.  
  42. Output formatting functions:
  43.  
  44.  GFListStr(L)
  45.  
  46. Dartboard functions:
  47.  
  48.  DartboardFromWeights(Weights)
  49.  ThrowDart(Dartboard)
  50.  
  51. Weighted sampler test functions:
  52.  
  53.  TestDartboard(Weights, NumRounds, [Variations])
  54.  TestRSWheel(Weights, NumRounds, [Variations])
  55.  TestSmartWheel(Weights, NumRounds, [Variations])
  56.  DisplayWeightsHistogram(Weights, Histogram)
  57.  
  58. Main:
  59.  
  60.  Main()
  61.  
  62. Command line trigger
  63.  
  64. '''
  65.  
  66.  
  67. #-------------------------------------------------------------------------------
  68. # Imports
  69. #-------------------------------------------------------------------------------
  70.  
  71. import math
  72. from math import (log, sqrt, frexp)
  73. import random
  74.  
  75.  
  76. #-------------------------------------------------------------------------------
  77. # Output formatting functions
  78. #-------------------------------------------------------------------------------
  79.  
  80.  
  81. def GFListStr(L):
  82.   '''Return as a string, a list (or tuple) in general precision format.'''
  83.   return '[' + (', '.join('%g' % (x) for x in L)) + ']'
  84.  
  85.  
  86. #-------------------------------------------------------------------------------
  87. # Dartboard functions
  88. #-------------------------------------------------------------------------------
  89.  
  90.  
  91. def DartboardFromWeights(Weights):
  92.  
  93.   '''Returns a dartboard with its n+1 fences spaced by the n weights.'''
  94.  
  95.   Result = []
  96.   ZoneLowerLimit = 0.0
  97.   Result.append(ZoneLowerLimit)
  98.  
  99.   for w in Weights:
  100.     ZoneLowerLimit += w
  101.     Result.append(ZoneLowerLimit)
  102.  
  103.   return Result
  104.  
  105.  
  106. #-------------------------------------------------------------------------------
  107.  
  108.  
  109. def ThrowDart(Dartboard):
  110.  
  111.   '''Returns the index of a random dart thrown at a dartboard.
  112.  
  113.  The board is defined by lower (inclusive) bounds of each zone
  114.  followed by the (exclusive) upper bound.
  115.  
  116.  '''
  117.  
  118.   LowIx = 0
  119.   Result = -1
  120.   HighIx = len(Dartboard) - 2
  121.  
  122.   if HighIx >= 0:
  123.  
  124.     InclusiveLowerLimit = Dartboard[0]
  125.     ExclusiveUpperLimit = Dartboard[HighIx + 1]
  126.     r = random.random()
  127.     Dart = (1.0 - r) * InclusiveLowerLimit + r * ExclusiveUpperLimit
  128.  
  129.     # Binary search, choosing last among equals.
  130.     while LowIx < HighIx:
  131.       MidIx = HighIx - ((HighIx - LowIx) / 2)
  132.       ZoneLL = Dartboard[MidIx]
  133.       if Dart < ZoneLL:
  134.         HighIx = MidIx - 1
  135.       else:
  136.         LowIx = MidIx
  137.     # Whatever happens, LowIx = HighIx
  138.     Result = LowIx
  139.  
  140.   return Result
  141.  
  142.  
  143. #-------------------------------------------------------------------------------
  144. # Weighted sampler test functions
  145. #-------------------------------------------------------------------------------
  146.  
  147.  
  148. def TestDartboard(Weights, NumRounds, Variations=None):
  149.  
  150.   '''Test the Dartboard method for selecting items according to their weights.
  151.  
  152.  This method is not especially efficient but is straightforward and robust.
  153.  It serves as a standard by which the other resamplers may be compared.
  154.  
  155.  No variations are offered.
  156.  
  157.  Returned is a histogram with one entry for each item in Weights.
  158.  
  159.  '''
  160.  
  161.   Dartboard = DartboardFromWeights(Weights)
  162.   Histogram = [0] * len(Weights)
  163.   NumDartsToThrow = len(Weights) * NumRounds
  164.  
  165.   print "Throwing %d darts..." % (NumDartsToThrow)
  166.  
  167.   while NumDartsToThrow > 0:
  168.     Ix = ThrowDart(Dartboard)
  169.     Histogram[Ix] += 1
  170.     NumDartsToThrow -= 1
  171.  
  172.   return Histogram
  173.  
  174.  
  175. #-------------------------------------------------------------------------------
  176.  
  177.  
  178. def TestRSWheel(Weights, NumRounds, Variations=None):
  179.  
  180.   '''Test the Resampling Wheel method for weighted resampling.
  181.  
  182.  The resampling wheel, presented by Pref. Sebastian Thrun in the Udacity
  183.  course CS373: Programming a Robotic Car, is quick and dirty and prone to
  184.  bias but adequate for applications such as particle filtering.
  185.  
  186.  Defaults for Variations:
  187.    MaxStepFactor: 2.0
  188.    RandomStartIndex: True
  189.  
  190.  Returned is a histogram with one entry for each item in Weights.
  191.  
  192.  '''
  193.  
  194.   #-----------------------------------------------------------------------------
  195.  
  196.   def Variation(Key, Default):
  197.     Result = Default
  198.     if Variations is not None:
  199.       if Key in Variations:
  200.         Result = Variations[Key]
  201.     return Result
  202.  
  203.   #-----------------------------------------------------------------------------
  204.  
  205.   MaxStepFactor = Variation('MaxStepFactor', 2.0)
  206.   DoRandomiseStartIx = Variation('RandomStartIndex', True)
  207.  
  208.   N = len(Weights)
  209.   Histogram = [0] * N
  210.  
  211.   print 'Resampling %d items from %d successive wheels...' % (N, NumRounds)
  212.   print 'Max step factor:', MaxStepFactor
  213.   print 'Start index:', (['0', 'Random'][DoRandomiseStartIx])
  214.  
  215.   NumRoundsToGo = NumRounds
  216.  
  217.   while NumRoundsToGo > 0:
  218.  
  219.     WIx = 0
  220.     Phase = 0.0
  221.     MaxStep = MaxStepFactor * max(Weights)
  222.  
  223.     if DoRandomiseStartIx:
  224.       WIx = random.randint(0, N - 1)
  225.  
  226.     NumSamplesToGo = N
  227.  
  228.     while NumSamplesToGo > 0:
  229.       Step = MaxStep * random.random()
  230.       Phase += Step
  231.       while Weights[WIx] <= Phase:
  232.         Phase -= Weights[WIx]
  233.         WIx = (WIx + 1) % N
  234.       Histogram[WIx] += 1
  235.       NumSamplesToGo -= 1
  236.  
  237.     NumRoundsToGo -= 1
  238.  
  239.   return Histogram
  240.  
  241.  
  242. #-------------------------------------------------------------------------------
  243.  
  244.  
  245. def TestSmartWheel(Weights, NumRounds, Variations=None):
  246.  
  247.   '''Test Erik Colban's weighted resampler.
  248.  
  249.  This resampler is like a resampling wheel except that it only goes around
  250.  once and has a complexity of O(N).
  251.  
  252.  Defaults for Variations:
  253.    UseApproxLog2: False
  254.  
  255.  Returned is a histogram with one entry for each item in Weights.
  256.  
  257.  '''
  258.  
  259.   # From Erik's notes:
  260.   #
  261.   # The algorithm is based on the fact that, after sorting N uniformly
  262.   # distributed samples, the distances between two consecutive samples
  263.   # is exponentially distributed. The algorithm is similar to the resampling
  264.   # wheel algorithm, except that it makes exactly one revolution around the
  265.   # resampling wheel. This resampling algorithm is O(N)
  266.   #
  267.   # http://forums.udacity.com/cs373-april2012/questions/1328/an-on-unbiased-resampler
  268.  
  269.   #-----------------------------------------------------------------------------
  270.  
  271.   def Variation(Key, Default):
  272.     Result = Default
  273.     if Variations is not None:
  274.       if Key in Variations:
  275.         Result = Variations[Key]
  276.     return Result
  277.  
  278.   #-----------------------------------------------------------------------------
  279.  
  280.   UseApproxLog2 = Variation('ApproxLog2', False)
  281.  
  282.   N = len(Weights)
  283.   Histogram = [0] * N
  284.  
  285.   print 'Resampling %d items from %d successive wheels...' % (N, NumRounds)
  286.   print 'Logarithm:', (['Natural', 'Approx. base 2'][UseApproxLog2])
  287.  
  288.   NumRoundsToGo = NumRounds
  289.  
  290.   while NumRoundsToGo > 0:
  291.  
  292.     # Select N + 1 numbers exponentially distributed with parameter lambda = 1.
  293.     Diffs = [0] * (N + 1)
  294.  
  295.     if UseApproxLog2:
  296.       # Use an approximate base-2 logarithim to avoid repeatedly
  297.       # calling a transcendental function.
  298.  
  299.       # The trick relies on the way IEEE 754 floats are stored.
  300.       # The average error is 0.001276 and the maximum error is 0.001915.
  301.       a = 1.0 / sqrt(2.0)
  302.       b = 1.0 / (1.0 + a)
  303.       c = 1.0 / (1.0 / (0.5 + a) - b)
  304.       for i in range(N + 1):
  305.         x = 1.0 - random.random()
  306.         m, e = frexp(x)
  307.         al2 = e - c * (1.0 / (m + a) - b)
  308.         Diffs[i] = -al2
  309.  
  310.     else:
  311.       # Use a real logarithm function. The natural logarithm is fine.
  312.  
  313.       for i in range(N + 1):
  314.         Diffs[i] = -log(1.0 - random.random())
  315.  
  316.     # Stretch to fit the circumference of the resampling wheel.
  317.     Scale = sum(Weights) / sum(Diffs)
  318.     Diffs = [Scale * x for x in Diffs]
  319.  
  320.     WIx = 0
  321.     Phase = 0
  322.  
  323.     try:
  324.       # Go around the resampling wheel exactly once.
  325.       for i in range(N):
  326.         Phase += Diffs[i]
  327.         # The number of step-seek iterations is random for each sampling
  328.         # but the total number of such iterations when all N samplings
  329.         # are performed is N - 1. The wheel sample index never exceeds
  330.         # N - 1 except in an extremely unlucky case of rounding.
  331.         while Phase > Weights[WIx]:
  332.           Phase -= Weights[WIx]
  333.           WIx += 1
  334.         Histogram[WIx] += 1
  335.     except (IndexError):
  336.       # This can only happen in the extremely unlucky case
  337.       # of accumulated rounding errors.
  338.       pass
  339.  
  340.     NumRoundsToGo -= 1
  341.  
  342.   return Histogram
  343.  
  344.  
  345. #-------------------------------------------------------------------------------
  346.  
  347.  
  348. def DisplayWeightsHistogram(Weights, Histogram):
  349.  
  350.   '''Display a histogram using ASCII art, each bar labelled with weights.
  351.  
  352.  The width of the histogram is automatically scaled to the largest bar.
  353.  
  354.  '''
  355.  
  356.   #-----------------------------------------------------------------------------
  357.  
  358.   MAX_BAR_WIDTH = 40
  359.  
  360.   #-----------------------------------------------------------------------------
  361.  
  362.   IxWidth = len(str(len(Histogram)))
  363.   MaxH = max(1, max(Histogram))
  364.   TotalH = sum(Histogram)
  365.   SafeHPCMult = 100.0 / max(1, TotalH)
  366.  
  367.   for Ix, NumHits in enumerate(Histogram):
  368.     BarLength = int(round(MAX_BAR_WIDTH * NumHits / float(MaxH)))
  369.     ItemStr = "W[%*d]" % (IxWidth, Ix)
  370.     if Weights is not None:
  371.       ValueStr = " = %6.3f " % (Weights[Ix])
  372.     BarStr = "#" * BarLength
  373.     PercentStr = "(%.2f%%)" % (SafeHPCMult * Histogram[Ix])
  374.     print "%-*s%s: %s %s" % (
  375.         IxWidth + 2, ItemStr, ValueStr, BarStr, PercentStr)
  376.  
  377.  
  378. #-------------------------------------------------------------------------------
  379. # Main
  380. #-------------------------------------------------------------------------------
  381.  
  382.  
  383. def Main():
  384.  
  385.   '''Run the test suite of weighted resamplers.'''
  386.  
  387.   Weights = [15 + x for x in range(1, 21)]
  388.   NumRounds = 1000
  389.  
  390.   WeightedSamplers = [
  391.     (TestDartboard, 'Dartboard', None, None),
  392.     (TestRSWheel, 'Resampling Wheel', '(Random start index as standard)', None),
  393.     (TestRSWheel, 'Resampling Wheel', '(Starts at first index)',
  394.         {'RandomStartIndex': False}),
  395.     (TestRSWheel, 'Resampling Wheel', '(SFI, Step factor of 5)',
  396.         {'RandomStartIndex': False, 'MaxStepFactor': 5.0}),
  397.     (TestRSWheel, 'Resampling Wheel', '(RSI, step factor of 5)',
  398.         {'RandomStartIndex': True, 'MaxStepFactor': 5.0}),
  399.     (TestSmartWheel, 'Erik Colban\'s Smart Wheel', None, None),
  400.     (TestSmartWheel, 'Erik Colban\'s Smart Wheel', '(Approx. log2)',
  401.         {'ApproxLog2': True})
  402.   ]
  403.  
  404.   print 'Test of Weighted Resampling Functions\n'
  405.  
  406.   print "Weights:"
  407.   print GFListStr(Weights)
  408.   print
  409.  
  410.   for TIx, TestRec in enumerate(WeightedSamplers):
  411.     TestFn, TestName, VariantName, Variations = TestRec
  412.     print 'Test %d/%d: %s' % (TIx + 1, len(WeightedSamplers) , TestName)
  413.     if VariantName is not None:
  414.       print VariantName
  415.     print
  416.     Histogram = TestFn(Weights, NumRounds, Variations)
  417.     DisplayWeightsHistogram(Weights, Histogram)
  418.     print
  419.  
  420.  
  421. #-------------------------------------------------------------------------------
  422. # Command line trigger
  423. #-------------------------------------------------------------------------------
  424.  
  425.  
  426. if __name__ == '__main__':
  427.   Main()
  428.  
  429.  
  430. #-------------------------------------------------------------------------------
  431. # End
  432. #-------------------------------------------------------------------------------
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement