Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import time
- import logging
- import multiprocessing
- from Queue import Empty # awkward
- import numpy
- from numpy import multiply, add # Frequently used in loop
# Module-wide log level. Switch to logging.INFO to trace the compute
# workers, the sticher and the queue traffic while debugging.
_LOG_LEVEL = logging.ERROR
logging.basicConfig(level=_LOG_LEVEL)
class Chunk(object):
    """One rectangular tile of the Mandelbrot image, passed between processes.

    Attributes:
        size: pixel dimensions (axis 0, axis 1) of the tile; axis 0 maps to
            the real axis, axis 1 to the imaginary axis.
        base: pixel offset of the tile within the full image.
        r_range: (start, stop) of the real axis covered by the tile.
        i_range: (start, stop) of the imaginary axis covered by the tile.
        image: rendered iteration counts, or None until a worker fills it in.
        compute_time: seconds spent rendering, 0.0 until rendered.
        iterations: maximum number of iterations used when rendering.
        escape: magnitude above which a point counts as escaped.
    """

    def __init__(self, size, base, r_range, i_range):
        # Geometry of the tile.
        self.size, self.base = size, base
        self.r_range, self.i_range = r_range, i_range
        # Results, filled in by the worker that renders the tile.
        self.image = None
        self.compute_time = 0.0
        # Fixed rendering parameters.
        self.iterations, self.escape = 100, 2.0
# Vectorised escape-time algorithm adapted (with minor changes) from
# http://thesamovar.wordpress.com/2009/03/22/fast-fractals-with-python-and-numpy/
def mandelbrot(size,
               r_start=-2, r_stop=0.5,
               i_start=-1.25, i_stop=1.25,
               iterations=256, escape=2.0):
    """Return escape-time counts for a rectangular patch of the complex plane.

    The patch is sampled on a ``size[0] x size[1]`` grid. Axis 0 spans the
    real range [r_start, r_stop] and axis 1 the imaginary range
    [i_start, i_stop]; both endpoints are included (numpy.linspace).

    For each sample c, z starts at c and is repeatedly updated as
    z <- z*z + c. The stored value is the number of updates after which
    |z| first exceeded ``escape``, or 0 if it never did within
    ``iterations`` updates.

    The result dtype is the smallest integer type that can hold
    ``iterations`` (uint8 for up to 255). This ensures late escapes are
    never wrapped around to 0.
    """
    nx, ny = size[0], size[1]
    ix, iy = numpy.mgrid[0:nx, 0:ny]
    x = numpy.linspace(r_start, r_stop, nx)[ix]
    y = numpy.linspace(i_start, i_stop, ny)[iy]
    c = x + 1j * y
    del x, y  # release the large temporaries before iterating
    results = numpy.zeros(c.shape, numpy.min_scalar_type(max(iterations, 1)))
    # Work on flat arrays holding only the still-bounded points. ix/iy
    # remember where each surviving point lives in ``results``.
    ix = ix.ravel()
    iy = iy.ravel()
    c = c.ravel()
    z = numpy.copy(c)
    for i in range(iterations):
        if not len(z):
            break
        multiply(z, z, z)  # z * z -> z, in place
        add(z, c, z)       # z + c -> z, in place
        # Tally the points that escaped on this update.
        escaped = abs(z) > escape
        results[ix[escaped], iy[escaped]] = i + 1
        # Prune escaped points so later updates only touch bounded ones.
        # (``~`` is boolean negation; unary ``-`` is rejected on bool arrays.)
        keep = ~escaped
        z = z[keep]
        ix, iy = ix[keep], iy[keep]
        c = c[keep]
    return results
def compute(input_q, output_q):
    """Worker loop: render Chunk objects from input_q and push them to output_q.

    The loop stops on a None sentinel, or when input_q yields nothing for
    one second. Each rendered chunk gets its ``image`` and
    ``compute_time`` attributes filled in before being forwarded.
    """
    worker = multiprocessing.current_process().name
    logging.info("Starting compute thread: %r", worker)
    while True:
        logging.info("Worker %r reading input queue.", worker)
        try:
            job = input_q.get(True, 1.0)
        except Empty:
            logging.info("Worker %r queue empty.", worker)
            break
        if job is None:
            break
        logging.info("Worker %r starting computation.", worker)
        began = time.time()
        job.image = mandelbrot(job.size,
                               job.r_range[0], job.r_range[1],
                               job.i_range[0], job.i_range[1],
                               job.iterations, job.escape)
        job.compute_time = time.time() - began
        logging.info("Worker %r finished computation in %.3f seconds.",
                     worker, job.compute_time)
        output_q.put(job)
        logging.info("Worker %r pushed chunk %s into output queue.",
                     worker, id(job))
    logging.info("Shutting down compute thread: %r", worker)
def sticher(image_size, input_q, output_q):
    """Assemble rendered chunks into one uint8 image, serving it on request.

    Messages read from input_q:
      * a Chunk: its ``image`` is copied into the canvas at ``base``;
      * the string 'image': the current canvas is put on output_q;
      * None: shut down.
    Any other message is ignored. When input_q stays empty for 2 seconds
    the loop simply polls again.
    """
    me = multiprocessing.current_process().name
    logging.info("Starting stiching thread: %r", me)
    canvas = numpy.zeros(image_size, numpy.uint8)
    count = 0
    busy_seconds = 0.0
    while True:
        logging.info("Sticher reading data from queue")
        try:
            msg = input_q.get(True, 2.0)
        except Empty:
            logging.info("Sticher found no data in queue")
            continue
        if msg is None:
            logging.info("Sticher got a shutdown command")
            break
        if msg == 'image':
            logging.info("Sticher returning an image")
            output_q.put(canvas)
            continue
        if not isinstance(msg, Chunk):
            continue
        logging.info("Sticher processing chunk: %s", id(msg))
        bx, by = msg.base[0], msg.base[1]
        w, h = msg.image.shape
        canvas[bx:bx + w, by:by + h] = msg.image
        count += 1
        busy_seconds += msg.compute_time
        logging.info("Total chunks processed: %d", count)
        logging.info("Total chunk compute time: %.3f seconds", busy_seconds)
    logging.info("Shutting down stiching thread: %r", me)
def multi_mandel(image_size, r_range, i_range, chunks=10, threads=8):
    """Render the Mandelbrot set with several processes and return the image.

    The image is split into a ``chunks x chunks`` grid of Chunk work units.
    ``threads - 1`` compute processes render the units. One extra process,
    the sticher, assembles them into a single uint8 array of shape
    ``image_size``.

    Args:
        image_size: pixel dimensions; axis 0 maps to the real axis, axis 1
            to the imaginary axis. Both must be divisible by ``chunks``.
        r_range: (start, stop) of the real axis.
        i_range: (start, stop) of the imaginary axis.
        chunks: number of work units along each axis (must be >= 1).
        threads: total number of processes, including the sticher (>= 2).

    Raises:
        ValueError: if ``chunks`` < 1, ``threads`` < 2, or ``image_size`` is
            not divisible by ``chunks``.
    """
    if chunks < 1:
        raise ValueError("chunks must be >= 1")
    if any([n % chunks for n in image_size]):
        raise ValueError("Both image_size dimensions must be divisable by chunks")
    if threads < 2:
        raise ValueError("Threads must be >= 2")
    n_workers = threads - 1  # one process is reserved for the sticher

    logging.info("Initializing queues")
    # Work units start out here. Read by the workers.
    chunk_q = multiprocessing.Queue()
    # Workers write finished chunks here; it is read by the sticher.
    work_q = multiprocessing.Queue()
    # Carries the assembled image back from the sticher.
    cmd_q = multiprocessing.Queue()

    r_edges = numpy.linspace(r_range[0], r_range[1], chunks + 1)
    i_edges = numpy.linspace(i_range[0], i_range[1], chunks + 1)
    # Floor division keeps chunk sizes integral under Python 3 as well.
    # True division would produce floats, which numpy.linspace rejects as a
    # sample count and which cannot serve as array indices. The divisibility
    # check above guarantees nothing is truncated.
    chunk_size = tuple(n // chunks for n in image_size)

    logging.info("Building work units")
    # NOTE: each chunk samples both of its edges, so neighbouring chunks
    # render the same boundary coordinate twice.
    for y in range(chunks):
        for x in range(chunks):
            base = x * chunk_size[0], y * chunk_size[1]
            chunk_q.put(Chunk(chunk_size, base,
                              (r_edges[x], r_edges[x + 1]),
                              (i_edges[y], i_edges[y + 1])))
    # One None per worker so each exits without waiting for a queue timeout.
    for _ in range(n_workers):
        chunk_q.put(None)

    logging.info("Starting processing")
    # This process reads all computed chunks and stitches them together.
    stitcher = multiprocessing.Process(target=sticher, args=(image_size, work_q, cmd_q))
    stitcher.start()
    workers = [multiprocessing.Process(target=compute, args=(chunk_q, work_q))
               for _ in range(n_workers)]
    for w in workers:
        w.start()
    # Wait for the workers to finish.
    for w in workers:
        w.join()
    logging.info("All computing processes terminated")

    logging.info("Requesting an image")
    work_q.put('image')
    image = cmd_q.get(True)
    logging.info("Got an image from the sticher")

    logging.info("Terminating the sticher")
    work_q.put(None)
    stitcher.join()
    logging.info("Done")
    return image
def generate_palette(p):
    """Return a flat 256-entry grey palette (768 ints, R,G,B per entry).

    The result is suitable for PIL's ``Image.putpalette``. ``p`` selects
    the ramp:
      0 -- index 0 black rising to index 255 white;
      1 -- index 0 white falling to index 255 black;
      2 -- index 0 black, every other index white.

    Raises:
        ValueError: if ``p`` is outside 0..2.
    """
    no_palettes = 3
    if not (0 <= p < no_palettes):
        raise ValueError("Must specify a palette from 0 to %d" % (no_palettes - 1))
    if p == 0:
        levels = range(256)
    elif p == 1:
        levels = range(255, -1, -1)
    elif p == 2:
        levels = [255 if i else 0 for i in range(256)]
    else:
        levels = []  # a non-integral p inside the range matches no palette
    # Repeat each level three times for the R, G and B components.
    return [v for v in levels for _ in range(3)]
if __name__ == '__main__':
    # Render a 7000x4500 image with 7 compute processes plus the sticher,
    # report the wall-clock time, then save it as a paletted PNG using the
    # inverted grey ramp.
    import Image
    t0 = time.time()
    counts = multi_mandel((7000, 4500), (-2, 0.5), (-1.25, 1.25), chunks=10, threads=8)
    print('Time taken: %s' % (time.time() - t0))
    picture = Image.fromarray(counts.astype('uint8'), 'L')
    picture.putpalette(generate_palette(1))
    picture.save('mandel.png')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement