Advertisement
Guest User

Untitled

a guest
Mar 16th, 2023
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 4.65 KB | Source Code | 0 0
  1. /*
  2. export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$PWD/nccl_2.17.1-1+cuda11.0_x86_64/lib/:$LD_LIBRARY_PATH
  3. export LIBRARY_PATH=/usr/local/cuda/lib64:$PWD/nccl_2.17.1-1+cuda11.0_x86_64/lib/:$LIBRARY_PATH
  4. export C_INCLUDE_PATH=/usr/local/cuda/include/:$PWD/nccl_2.17.1-1+cuda11.0_x86_64/include/:$C_INCLUDE_PATH
  5. export CPLUS_INCLUDE_PATH=/usr/local/cuda/include/:$PWD/nccl_2.17.1-1+cuda11.0_x86_64/include/:$CPLUS_INCLUDE_PATH
  6. g++ send_recv.cc -lpthread -lcudart -lnccl
  7. */
  8.  
  #include <unistd.h>

  #include <cassert>
  #include <chrono>
  #include <cstdio>
  #include <cstdlib>
  #include <functional>
  #include <iostream>
  #include <memory>
  #include <mutex>
  #include <string>
  #include <thread>
  #include <vector>

  #include "cuda_runtime.h"
  #include "nccl.h"
  23.  
  // Unique id identifying the NCCL communicator shared by all ranks. main()
  // fills it in with ncclGetUniqueId() before starting any thread, and every
  // rank passes it to its communicator-init call.
  24. ncclUniqueId ncclId;
  25.  
  // Evaluates a CUDA runtime call exactly once. If the result is anything other
  // than cudaSuccess, prints the file, line and cudaGetErrorString() text to
  // stderr and aborts the process. std::abort() is used rather than
  // assert(false) so the failure is not compiled out when NDEBUG is defined.
  #define CUDACHECK(cmd)                                                    \
    do {                                                                    \
      cudaError_t cuda_check_err_ = (cmd);                                  \
      if (cuda_check_err_ != cudaSuccess) {                                 \
        std::fprintf(stderr, "Failed: Cuda error %s:%d '%s'\n", __FILE__,   \
                     __LINE__, cudaGetErrorString(cuda_check_err_));        \
        std::abort();                                                       \
      }                                                                     \
    } while (0)
  // Evaluates an NCCL call exactly once. If the result is anything other than
  // ncclSuccess, prints the file, line and ncclGetErrorString() text to stderr
  // and aborts the process. std::abort() is used rather than assert(false) so
  // the failure is not compiled out when NDEBUG is defined.
  #define NCCLCHECK(cmd)                                                    \
    do {                                                                    \
      ncclResult_t nccl_check_res_ = (cmd);                                 \
      if (nccl_check_res_ != ncclSuccess) {                                 \
        std::fprintf(stderr, "Failed, NCCL error %s:%d '%s'\n", __FILE__,   \
                     __LINE__, ncclGetErrorString(nccl_check_res_));        \
        std::abort();                                                       \
      }                                                                     \
    } while (0)
  44.  
  // Selects how custom_recv() copies the received value back to the host:
  // non-zero issues the copy on the same stream as ncclRecv, zero issues it on
  // a separate copy stream.
  45. #define ASYNC 1
  46.  
  // Number of NCCL ranks in the test: rank 0 sends, and ranks
  // 1..device_count-1 each receive on the CUDA device with the same index.
  47. const int device_count = 4;
  48.  
  49. void custom_recv(int dev_id, int device_count) {
  50.     cudaSetDevice(dev_id);
  51.     ncclComm_t comm;
  52.     #if NCCL_VERSION >= 21700
  53.     ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  54.     NCCLCHECK(ncclCommInitRankConfig(&comm, device_count, ncclId, dev_id, &config));
  55.     #else
  56.     NCCLCHECK(ncclCommInitRank(&comm, device_count, ncclId, dev_id));
  57.     #endif
  58.     cudaStream_t stream;
  59.     CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  60.     void* buffer;
  61.     CUDACHECK(cudaMalloc(&buffer, sizeof(float)));
  62.  
  63.     ncclGroupStart();
  64.     NCCLCHECK(ncclRecv(buffer, 1, ncclFloat32, 0, comm, stream));
  65.     ncclGroupEnd();
  66.  
  67.     float host_buf;
  68.     #if ASYNC
  69.     CUDACHECK(cudaMemcpyAsync(&host_buf, buffer, sizeof(float), cudaMemcpyDeviceToHost, stream));
  70.     // void* buffer2;
  71.     // CUDACHECK(cudaMallocAsync(&buffer2, sizeof(float)*1024, stream));
  72.     // CUDACHECK(cudaStreamSynchronize(stream));
  73.     #else
  74.     cudaStream_t copy_stream;
  75.     CUDACHECK(cudaStreamCreateWithFlags(&copy_stream, cudaStreamNonBlocking));
  76.     CUDACHECK(cudaMemcpyAsync(&host_buf, buffer, sizeof(float), cudaMemcpyDeviceToHost, copy_stream));
  77.     CUDACHECK(cudaStreamSynchronize(stream));
  78.     #endif
  79.     assert(host_buf == 1);
  80. end:
  81.     std::cout << "device " << dev_id << " recv done" << std::endl;
  82. }
  83.  
  84. int main(int argc, char** argv) {
  85.   {
  86.     NCCLCHECK(ncclGetUniqueId(&ncclId));
  87.  
  88.     // send thread, 0 -> 1 2 3
  89.     std::thread thr0([]() {
  90.       cudaSetDevice(0);
  91.       ncclComm_t comm;
  92.       #if NCCL_VERSION >= 21700
  93.       ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
  94.       NCCLCHECK(ncclCommInitRankConfig(&comm, device_count, ncclId, 0, &config));
  95.       #else
  96.       NCCLCHECK(ncclCommInitRank(&comm, device_count, ncclId, 0));
  97.       #endif
  98.       cudaStream_t stream;
  99.       CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  100.       void* buffer[device_count];
  101.       for (int i = 1; i < device_count; i++) {
  102.         CUDACHECK(cudaMalloc(&buffer[i], sizeof(float)));
  103.       }
  104.       float r = 1;
  105.       for (int i = 1; i < device_count; i++) {
  106.         CUDACHECK(cudaMemcpy(buffer[i], &r, sizeof(float), cudaMemcpyHostToDevice));
  107.       }
  108.  
  109.       NCCLCHECK(ncclGroupStart());
  110.       for (int i = 1; i < device_count; i++) {
  111.         NCCLCHECK(ncclSend(buffer[i], 1, ncclFloat32, i, comm, stream));
  112.       }
  113.       NCCLCHECK(ncclGroupEnd());
  114.  
  115.       CUDACHECK(cudaStreamSynchronize(stream));
  116.       std::cout << "device 0 send done" << std::endl;
  117.     });
  118.  
  119.     // recv thread: i <- 0
  120.     std::vector<std::thread> threads;
  121.     for (int i = 1; i < device_count; i++) {
  122.       threads.push_back(std::thread(custom_recv, i, device_count));
  123.     }
  124.  
  125.     thr0.join();
  126.     for (auto& thr : threads) {
  127.       thr.join();
  128.     }
  129.   }
  130. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement