Advertisement
Guest User

Multiprocessing Mandelbrot

a guest
Sep 28th, 2011
305
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.21 KB | None | 0 0
  1. import time
  2. import logging
  3.  
  4. import multiprocessing
  5. from Queue import Empty  # awkward
  6.  
  7. import numpy
  8. from numpy import multiply, add  # Frequently used in loop
  9.  
  10. # Set to logging.INFO for debugging
  11. logging.basicConfig(level=logging.ERROR)
  12.  
  13. class Chunk(object):
  14.     def __init__(self, size, base, r_range, i_range):
  15.         self.size = size
  16.         self.base = base
  17.         self.r_range = r_range
  18.         self.i_range = i_range
  19.  
  20.         self.image = None
  21.         self.compute_time = 0.0
  22.         self.iterations = 100
  23.         self.escape = 2.0
  24.  
  25. # I was using numpy to quickly calculate the Mandelbrot set but this version is
  26. # much, much slicker:
  27. # http://thesamovar.wordpress.com/2009/03/22/fast-fractals-with-python-and-numpy/
  28. # ...So I'm using it instead (with a few minor changes)
  29.  
  30. def mandelbrot(size,
  31.                r_start=-2, r_stop=0.5,
  32.                i_start=-1.25, i_stop=1.25,
  33.                iterations=256, escape=2.0):
  34.  
  35.     ix, iy = numpy.mgrid[0:size[0], 0:size[1]]
  36.     x = numpy.linspace(r_start, r_stop, size[0])[ix]
  37.     y = numpy.linspace(i_start, i_stop, size[1])[iy]
  38.     c = x + numpy.complex(0, 1) * y
  39.     del x, y  # Seems a little excessive
  40.  
  41.     results = numpy.zeros(c.shape, numpy.uint8)
  42.  
  43.     # This is the really cool part.  
  44.     ix.shape = size[0] * size[1]
  45.     iy.shape = size[0] * size[1]
  46.     c.shape  = size[0] * size[1]
  47.  
  48.     z = numpy.copy(c)
  49.     for i in xrange(iterations):
  50.         if not len(z): break
  51.         multiply(z, z, z)  # z * z -> z [fast square]
  52.         add(z, c, z)       # z + c -> z [fast add]
  53.  
  54.         # tally escaped points
  55.         rem = abs(z) > escape
  56.         results[ix[rem], iy[rem]] = i + 1
  57.  
  58.         # prune points already escaped
  59.         rem = -rem
  60.         z = z[rem]
  61.         ix, iy = ix[rem], iy[rem]
  62.         c = c[rem]
  63.  
  64.     return results
  65.  
  66. def compute(input_q, output_q):
  67.     name = multiprocessing.current_process().name
  68.     logging.info("Starting compute thread: %r" % name)
  69.  
  70.     while True:
  71.         try:
  72.             logging.info("Worker %r reading input queue." % name)
  73.             chunk = input_q.get(True, 1.0)
  74.         except Empty:
  75.             logging.info("Worker %r queue empty." % name)
  76.             break
  77.         if chunk is None: break
  78.  
  79.         logging.info("Worker %r starting computation." % name)
  80.         start = time.time()
  81.         chunk.image = mandelbrot(chunk.size,
  82.                                  chunk.r_range[0], chunk.r_range[1],
  83.                                  chunk.i_range[0], chunk.i_range[1],
  84.                                  chunk.iterations, chunk.escape)
  85.         chunk.compute_time = time.time() - start
  86.         logging.info("Worker %r finished computation in %.3f seconds." % (name, chunk.compute_time))
  87.         output_q.put(chunk)
  88.         logging.info("Worker %r pushed chunk %s into output queue." % (name, id(chunk)))
  89.     logging.info("Shutting down compute thread: %r" % name)
  90.  
  91. def sticher(image_size, input_q, output_q):
  92.     name = multiprocessing.current_process().name
  93.     logging.info("Starting stiching thread: %r" % name)
  94.  
  95.     image = numpy.zeros(image_size, numpy.uint8)
  96.  
  97.     fragments = 0
  98.     total_time = 0.0
  99.     while True:
  100.         try:
  101.             logging.info("Sticher reading data from queue")
  102.             record = input_q.get(True, 2.0)
  103.         except Empty:
  104.             logging.info("Sticher found no data in queue")
  105.             continue
  106.  
  107.         if record is None:
  108.             logging.info("Sticher got a shutdown command")
  109.             break
  110.         elif record == 'image':
  111.             logging.info("Sticher returning an image")
  112.             output_q.put(image)
  113.         elif isinstance(record, Chunk):
  114.             logging.info("Sticher processing chunk: %s" % id(record))
  115.             x, y = record.image.shape
  116.             image[record.base[0]:record.base[0]+x, record.base[1]:record.base[1]+y] = record.image
  117.             fragments += 1
  118.             total_time += record.compute_time
  119.  
  120.     logging.info("Total chunks processed: %d" % fragments)
  121.     logging.info("Total chunk compute time: %.3f seconds" % total_time)
  122.     logging.info("Shutting down stiching thread: %r" % name)
  123.  
  124. def multi_mandel(image_size, r_range, i_range, chunks=10, threads=8):
  125.     if any([i % chunks for i in image_size]):
  126.         raise ValueError("Both image_size dimensions must be divisable by chunks")
  127.     if threads < 2:
  128.         raise ValueError("Threads must be >= 2")
  129.  
  130.     logging.info("Initializing queues")
  131.     # Work units start out here.  Read by the workers.
  132.     chunk_q = multiprocessing.Queue()
  133.     # Workers write to the work_q and it is read by the sticher.
  134.     work_q = multiprocessing.Queue()
  135.     # Control communication with the sticher
  136.     cmd_q = multiprocessing.Queue()
  137.  
  138.     r = numpy.linspace(r_range[0], r_range[1], chunks+1)
  139.     i = numpy.linspace(i_range[0], i_range[1], chunks+1)
  140.     chunk_size = tuple(i / chunks for i in image_size)
  141.  
  142.     logging.info("Building work units")
  143.     # Push data into the chunk_q
  144.     for y in range(chunks):
  145.         for x in range(chunks):
  146.             base = x * chunk_size[0], y * chunk_size[1]
  147.             r_range = r[x], r[x+1]
  148.             i_range = i[y], i[y+1]
  149.             chunk_q.put(Chunk(chunk_size, base, r_range, i_range))
  150.     # Add a bunch of Nones to chunk_q to force shutdown without
  151.     # waiting for timeouts.
  152.     for i in range(threads-1):
  153.         chunk_q.put(None)
  154.  
  155.     logging.info("Starting processing")
  156.     # This thread will read all of the computed chunks and stich them
  157.     # into a large image.
  158.     s = multiprocessing.Process(target=sticher, args=(image_size, work_q, cmd_q))
  159.     s.start()
  160.  
  161.     # Start all of the computing threads.
  162.     workers = [multiprocessing.Process(target=compute, args=(chunk_q, work_q)) for i in range(threads-1)]
  163.     for w in workers:
  164.         w.start()
  165.  
  166.     # Wait for the workers to finish
  167.     for w in workers:
  168.         w.join()
  169.     logging.info("All computing processes terminated")
  170.  
  171.     # Request an image
  172.     logging.info("Requesting an image")
  173.     work_q.put('image')
  174.     image = cmd_q.get(True)
  175.     logging.info("Got an image from the sticher")
  176.  
  177.     # Shut down the sticher
  178.     logging.info("Terminating the sticher")
  179.     work_q.put(None)
  180.     s.join()
  181.     logging.info("Done")
  182.  
  183.     return image
  184.  
  185. def generate_palette(p):
  186.     no_palettes = 3
  187.     if not (0 <= p < no_palettes):
  188.         raise ValueError("Must specify a palette from 0 to %d" % (palettes - 1))
  189.  
  190.     palette = []
  191.     if p == 0:
  192.         for i in range(256):
  193.             palette.extend((i, i, i))
  194.     elif p == 1:
  195.         for i in range(255, -1, -1):
  196.             palette.extend((i, i, i))
  197.     elif p == 2:
  198.         for i in range(256):
  199.             c = 255 if i else 0
  200.             palette.extend((c, c, c))
  201.  
  202.     return palette
  203.  
  204. if __name__=='__main__':
  205.     import Image
  206.  
  207.     start = time.time()
  208.     I = multi_mandel((7000, 4500), (-2, 0.5), (-1.25, 1.25), chunks=10, threads=8)
  209.     print 'Time taken:', time.time()-start
  210.  
  211.     img = Image.fromarray(I.astype('uint8'), 'L')
  212.     palette = generate_palette(1)
  213.     img.putpalette(palette)
  214.     img.save('mandel.png')
  215.  
  216.     #from pylab import *
  217.     #I[I==0] = 101
  218.     #img = imshow(I.T, origin='lower left')
  219.     #img.write_png('mandel.png', noscale=True)
  220.     #show()
  221.  
  222.  
  223.  
  224.  
  225.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement