import argparse
import os
import threading
import time
import torch
import torch.distributed as dist

# Number of times each tensor size is transferred.
NUM_TRIALS = 20

def receive_tensor_helper(tensors, src_rank, groups, num_iterations,
                          intra_server_broadcast):
    # Receive (or participate in a broadcast of) each tensor
    # `num_iterations` times; the list index doubles as the message tag.
    for tag, tensor in enumerate(tensors):
        for i in range(num_iterations):
            if intra_server_broadcast:
                dist.broadcast(tensor=tensor, group=groups[tag], src=src_rank)
            else:
                dist.recv(tensor=tensor, src=src_rank, tag=tag)
        print("Done with tensor size %s" % tensor.size())

def send_tensor_helper(tensors, dst_rank, groups, num_iterations,
                       intra_server_broadcast):
    # Send (or broadcast from this rank) each tensor `num_iterations` times.
    for tag, tensor in enumerate(tensors):
        for i in range(num_iterations):
            if intra_server_broadcast:
                dist.broadcast(tensor=tensor, group=groups[tag], src=1-dst_rank)
            else:
                dist.send(tensor=tensor, dst=dst_rank, tag=tag)
        print("Done with tensor size %s" % tensor.size())

def start_helper_thread(func, args):
    # Run `func` on a background thread; the caller is responsible for
    # joining the returned thread.
    helper_thread = threading.Thread(target=func, args=tuple(args))
    helper_thread.start()
    return helper_thread


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Test lightweight communication library')
    parser.add_argument("--backend", type=str, default='gloo',
                        help="Backend")
    parser.add_argument('-s', "--send", action='store_true',
                        help="Send tensor (if not specified, will receive tensor)")
    parser.add_argument("--master_addr", required=True, type=str,
                        help="IP address of master")
    parser.add_argument("--use_helper_threads", action='store_true',
                        help="Use multiple threads")
    parser.add_argument("--rank", required=True, type=int,
                        help="Rank of current worker")
    parser.add_argument('-p', "--master_port", default=12345,
                        help="Port used to communicate tensors")
    parser.add_argument("--intra_server_broadcast", action='store_true',
                        help="Broadcast within a server")

    args = parser.parse_args()

    num_ranks_in_server = 1
    if args.intra_server_broadcast:
        num_ranks_in_server = 2
    local_rank = args.rank % num_ranks_in_server
    torch.cuda.set_device(local_rank)

    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = str(args.master_port)
    world_size = 2
    dist.init_process_group(args.backend, rank=args.rank, world_size=world_size)

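    # gloo supports point-to-point send/recv on CPU tensors; NCCL (the
    # natural backend when broadcasting between GPUs on one server) lacked
    # send/recv support at the time, which is presumably why the
    # intra-server path uses dist.broadcast over a two-rank group instead.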
    tensor_sizes = [10, 100, 1000, 10000, 100000, 1000000, 10000000]

    tensors = []
    groups = []
    for tag in range(len(tensor_sizes)):
        tensor_size = tensor_sizes[tag]
        if args.intra_server_broadcast:
            group = dist.new_group([0, 1])
            groups.append(group)
        # Senders carry a recognizable pattern (0 .. tensor_size-1);
        # receivers start from zeros so the payload can be verified.
        if args.send:
            if args.intra_server_broadcast:
                tensor = torch.tensor(range(tensor_size),
                                      dtype=torch.float32).cuda(local_rank)
            else:
                tensor = torch.tensor(range(tensor_size),
                                      dtype=torch.float32).cpu()
        else:
            if args.intra_server_broadcast:
                tensor = torch.zeros((tensor_size,),
                                     dtype=torch.float32).cuda(local_rank)
            else:
                tensor = torch.zeros((tensor_size,),
                                     dtype=torch.float32).cpu()

        tensors.append(tensor)

    # Should be all zeros.
    # if not args.send:
    #     print(tensors)

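    # `time` is imported above but never used; a minimal sketch of how one
    # might time the un-threaded send path (an assumption, not part of the
    # original script):
    #   start = time.time()
    #   send_tensor_helper(tensors, 1 - args.rank, groups, NUM_TRIALS,
    #                      args.intra_server_broadcast)
    #   print("%d trials: %.3f seconds" % (NUM_TRIALS, time.time() - start))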
    if args.send:
        if args.use_helper_threads:
            helper_thread = start_helper_thread(send_tensor_helper,
                                                [tensors, 1-args.rank,
                                                 groups, NUM_TRIALS,
                                                 args.intra_server_broadcast])
        else:
            send_tensor_helper(tensors, 1-args.rank, groups,
                               NUM_TRIALS, args.intra_server_broadcast)
    else:
        if args.use_helper_threads:
            helper_thread = start_helper_thread(receive_tensor_helper,
                                                [tensors, 1-args.rank,
                                                 groups, NUM_TRIALS,
                                                 args.intra_server_broadcast])
        else:
            receive_tensor_helper(tensors, 1-args.rank, groups,
                                  NUM_TRIALS, args.intra_server_broadcast)
    # Join only if a helper thread was actually started.
    if args.use_helper_threads:
        helper_thread.join()
    # Should be range(tensor_size).
    # if not args.send:
    #     print(tensors)
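
# Example invocations (script name and addresses are hypothetical; rank 0
# must run on the host given as --master_addr):
#   python comm_test.py --master_addr 10.0.0.1 --rank 0
#   python comm_test.py --master_addr 10.0.0.1 --rank 1 --send
# Two GPUs on one server, broadcasting instead of send/recv (NCCL assumed):
#   python comm_test.py --master_addr 127.0.0.1 --backend nccl \
#       --intra_server_broadcast --rank 0
#   python comm_test.py --master_addr 127.0.0.1 --backend nccl \
#       --intra_server_broadcast --rank 1 --send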