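"""Micro-benchmark for torch.distributed point-to-point transfers.

Runs with a world size of 2. The process launched with --send transmits
tensors of increasing size to the other rank, using CPU send/recv by default
or GPU broadcasts within a single server when --intra_server_broadcast is
set. Each size is transferred NUM_TRIALS times, optionally from a helper
thread (--use_helper_threads)."""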
import argparse
import os
import threading
import time

import torch
import torch.distributed as dist

NUM_TRIALS = 20

# Receive each tensor num_iterations times, either as the destination of a
# broadcast (intra-server) or via a tagged point-to-point recv.
def receive_tensor_helper(tensors, src_rank, groups, num_iterations,
                          intra_server_broadcast):
    for tag, tensor in enumerate(tensors):
        for i in range(num_iterations):
            if intra_server_broadcast:
                dist.broadcast(tensor=tensor, group=groups[tag], src=src_rank)
            else:
                dist.recv(tensor=tensor, src=src_rank, tag=tag)
        print("Done with tensor size %s" % tensor.size())

# Send each tensor num_iterations times, either by broadcasting it from this
# rank (intra-server) or via a tagged point-to-point send.
def send_tensor_helper(tensors, dst_rank, groups, num_iterations,
                       intra_server_broadcast):
    for tag, tensor in enumerate(tensors):
        for i in range(num_iterations):
            if intra_server_broadcast:
                dist.broadcast(tensor=tensor, group=groups[tag],
                               src=1-dst_rank)
            else:
                dist.send(tensor=tensor, dst=dst_rank, tag=tag)
        print("Done with tensor size %s" % tensor.size())

# Run func(*args) on a background thread so communication can overlap with
# work on the main thread.
def start_helper_thread(func, args):
    helper_thread = threading.Thread(target=func, args=tuple(args))
    helper_thread.start()
    return helper_thread

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Test lightweight communication library')
    parser.add_argument("--backend", type=str, default='gloo',
                        help="Backend")
    parser.add_argument('-s', "--send", action='store_true',
                        help="Send tensor (if not specified, will receive tensor)")
    parser.add_argument("--master_addr", required=True, type=str,
                        help="IP address of master")
    parser.add_argument("--use_helper_threads", action='store_true',
                        help="Use multiple threads")
    parser.add_argument("--rank", required=True, type=int,
                        help="Rank of current worker")
    parser.add_argument('-p', "--master_port", default=12345,
                        help="Port used to communicate tensors")
    parser.add_argument("--intra_server_broadcast", action='store_true',
                        help="Broadcast within a server")
    args = parser.parse_args()
    # With --intra_server_broadcast, both ranks live on the same server and
    # each rank is pinned to its own local GPU.
    num_ranks_in_server = 1
    if args.intra_server_broadcast:
        num_ranks_in_server = 2
    local_rank = args.rank % num_ranks_in_server
    torch.cuda.set_device(local_rank)

    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = str(args.master_port)
    world_size = 2
    dist.init_process_group(args.backend, rank=args.rank, world_size=world_size)
    # Build one tensor per size (and, for broadcasts, one process group per
    # size). The sender initializes its tensors with range(tensor_size); the
    # receiver starts from zeros.
    tensor_sizes = [10, 100, 1000, 10000, 100000, 1000000, 10000000]
    tensors = []
    groups = []
    for tag in range(len(tensor_sizes)):
        tensor_size = tensor_sizes[tag]
        if args.intra_server_broadcast:
            group = dist.new_group([0, 1])
            groups.append(group)
        if args.send:
            if args.intra_server_broadcast:
                tensor = torch.tensor(range(tensor_size),
                                      dtype=torch.float32).cuda(local_rank)
            else:
                tensor = torch.tensor(range(tensor_size),
                                      dtype=torch.float32).cpu()
        else:
            if args.intra_server_broadcast:
                tensor = torch.zeros((tensor_size,),
                                     dtype=torch.float32).cuda(local_rank)
            else:
                tensor = torch.zeros((tensor_size,), dtype=torch.float32).cpu()
        tensors.append(tensor)
    # Should be all zeros.
    # if not args.send:
    #     print(tensors)
    helper_threads = []
    if args.send:
        if args.use_helper_threads:
            helper_thread = start_helper_thread(send_tensor_helper,
                                                [tensors, 1-args.rank,
                                                 groups, NUM_TRIALS,
                                                 args.intra_server_broadcast])
        else:
            send_tensor_helper(tensors, 1-args.rank, groups,
                               NUM_TRIALS, args.intra_server_broadcast)
    else:
        if args.use_helper_threads:
            helper_thread = start_helper_thread(receive_tensor_helper,
                                                [tensors, 1-args.rank,
                                                 groups, NUM_TRIALS,
                                                 args.intra_server_broadcast])
        else:
            receive_tensor_helper(tensors, 1-args.rank, groups,
                                  NUM_TRIALS, args.intra_server_broadcast)
    # helper_thread only exists when --use_helper_threads is set; joining it
    # unconditionally would raise a NameError in the synchronous case.
    if args.use_helper_threads:
        helper_thread.join()
    # Should be range(tensor_size).
    # if not args.send:
    #     print(tensors)
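
# Example invocation (a sketch; the filename p2p_test.py is hypothetical, and
# both processes must be able to reach --master_addr):
#   receiver: python p2p_test.py --master_addr 10.0.0.1 --rank 0
#   sender:   python p2p_test.py --master_addr 10.0.0.1 --rank 1 --send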