Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/env python2.7
- # -*- coding: utf-8 -*-
import multiprocessing
import multiprocessing.reduction
import subprocess
import sys
import time

multiprocessing.allow_connection_pickling()

from social_graph import *
# Filename template for shard files; %s is the shard index.
PATH_PATTERN = 'shard.%s'
def writer(task):
    """Drain bulks of lines from a queue into one shard file.

    *task* is a ``(queue, path)`` pair: a joinable queue that yields
    lists of lines (or ``None`` as a stop sentinel) and the output
    file path.  Each bulk is acknowledged with ``task_done()``.

    Returns True on success, False on interrupt or error.
    """
    # Explicit unpacking replaces the Py2-only tuple parameter
    # (PEP 3113); the call signature writer((q, p)) is unchanged.
    queue, path = task
    try:
        sys.stdout.write('Writer: path %s start\n' % (path))
        with open(path, 'w') as f:
            while True:
                bulk = queue.get()
                if bulk is None:
                    # Sentinel: acknowledge it so queue.join() can return.
                    queue.task_done()
                    break
                for line in bulk:
                    f.write(line)
                queue.task_done()
        sys.stdout.write('Writer: path %s done\n' % path)
        return True
    except KeyboardInterrupt:
        sys.stderr.write('Writer: path %s interrupted\n' % path)
        queue.close()
        return False
    except Exception as error:
        # Exception, not BaseException: SystemExit must not be swallowed.
        sys.stdout.write('Writer: path %s problem %s\n' % (path, error))
        return False
def spliter(task):
    """Partition one input file's lines across per-shard queues.

    *task* is ``(hash_fun, queues, path)``.  Each line must start with
    a numeric user id followed by a tab; malformed lines are skipped.
    Lines are batched into bulks of up to 1024 before being queued;
    leftover partial bulks are flushed at the end.
    """
    # Explicit unpacking replaces the Py2-only tuple parameter (PEP 3113).
    hash_fun, queues, path = task
    try:
        sys.stdout.write('Spliter: path %s start\n' % path)
        shard_count = len(queues)
        bulk = dict((index, []) for index in range(shard_count))
        with open(path, 'r') as file_in:
            for line in file_in:
                try:
                    index = line.index('\t')
                    # int() auto-promotes beyond machine word size on
                    # Py2, so this is equivalent to the old long() call.
                    user_id = int(line[:index])
                except ValueError:
                    continue  # no tab or non-numeric id: skip the line
                shard = hash_fun(user_id) % shard_count
                bulk[shard].append(line)
                if len(bulk[shard]) >= 1024:
                    queues[shard].put(bulk[shard])
                    bulk[shard] = []
        for shard in sorted(bulk):
            if bulk[shard]:
                queues[shard].put(bulk[shard])
        sys.stdout.write('Spliter: path %s done\n' % path)
    except KeyboardInterrupt:
        sys.stderr.write('Spliter: path %s interrupted\n' % path)
def inputs():
    """Yield each stdin line with its newline characters stripped."""
    for raw in sys.stdin:
        yield raw.replace('\n', '')
def shards(shard_count):
    """Yield the shard file path for each of *shard_count* shards.

    range() behaves identically to the old xrange() here on Py2 when
    consumed by the generator, and makes the function Py3-compatible.
    """
    return iter(PATH_PATTERN % index for index in range(shard_count))
def writer_tasks(queues):
    """Pair every queue with its shard output path as (queue, path) tuples."""
    paths = shards(len(queues))
    return iter((q, p) for q, p in zip(queues, paths))
def reader_tasks(settings, queues):
    """Build one (hash_fun, queues, path) spliter task per stdin input path."""
    hash_fun = settings.hash
    return iter((hash_fun, queues, path) for path in inputs())
def split(settings):
    """Fan input files out into shard files.

    Spawns settings.shard writer processes (one per shard file) and
    settings.parallel reader processes that split the input paths read
    from stdin.  Returns True on success, False when interrupted.
    """
    manager = multiprocessing.Manager()
    # Manager-backed joinable queues so they can be shared across pools.
    queues = [manager.JoinableQueue() for index in xrange(settings.shard)]
    pool_writer = multiprocessing.Pool(settings.shard)
    pool_reader = multiprocessing.Pool(settings.parallel)
    try:
        pool_writer.map_async(writer, writer_tasks(queues))
        pool_reader.map_async(spliter, reader_tasks(settings, queues))
        pool_reader.close()
        pool_reader.join()  # wait until every input file has been split
        # One None sentinel per queue tells its writer to finish.
        for q in queues:
            q.put(None)
        for q in queues:
            q.join()  # wait until every queued bulk has been consumed
        pool_writer.close()
        pool_writer.join()
        return True
    except KeyboardInterrupt:
        sys.stderr.write('Split Manager: interrupted\n')
        for q in queues:
            q.close()
        pool_reader.terminate()
        pool_writer.terminate()
        return False
def sorter(path):
    """Numerically sort one shard file in place via the external `sort` tool.

    Writes the sorted output to ``<path>.sorted`` first, then moves it
    back over the original file.  Requires the `sort` and `mv` binaries
    (POSIX environment).
    """
    try:
        path_in = path
        path_out = '%s.sorted' % path_in
        sys.stdout.write('Sorter: path %s start\n' % path)
        # Stream the file through sort(1); both handles are closed by
        # the with-block before the rename below.
        with open(path_in, 'r') as file_in:
            with open(path_out, 'w') as file_out:
                subprocess.Popen(['sort', '-n'], stdin=file_in, stdout=file_out).wait()
        sys.stdout.write('Sorter: path %s done\n' % path)
        subprocess.Popen(['mv', path_out, path_in]).wait()
    except KeyboardInterrupt:
        sys.stdout.write('Sorter: path %s interrupted\n' % path)
def sort(settings):
    """Sort every shard file in parallel.

    Returns True on success, False when interrupted.

    Bug fix: the original fell through and returned None on success,
    which main() treated as failure (`if not sort(settings): sys.exit(1)`),
    so the program always exited non-zero even after sorting correctly.
    """
    pool = multiprocessing.Pool(settings.parallel)
    try:
        pool.map_async(sorter, shards(settings.shard))
        pool.close()
        pool.join()
        return True
    except KeyboardInterrupt:
        sys.stderr.write('Sort Manager: interrupted\n')
        pool.terminate()
        return False
def parse_args():
    """Parse command-line options into a settings namespace.

    Delegates option registration (shard, parallel, hash) and hash
    post-processing to the social_graph helper functions.
    """
    import argparse
    parser = argparse.ArgumentParser(description='social graph repart')
    for configure in (argparse_group_shard,
                      argparse_options_parallel,
                      argparse_options_hash):
        parser = configure(parser)
    return argparse_analyze_hash(parser.parse_args())
def main():
    """Split the stdin-listed input files into shards, then sort each shard.

    Exits with status 1 when either phase reports failure; a Ctrl-C at
    this level is absorbed silently.
    """
    try:
        settings = parse_args()
        # Short-circuit: sort only runs if split succeeded, and either
        # failure exits non-zero — same flow as two separate checks.
        if not (split(settings) and sort(settings)):
            sys.exit(1)
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement