Guest User

Untitled

a guest
Mar 23rd, 2018
85
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.13 KB | None | 0 0
  1. // Init -----------------
  2. cudaStream_t computation_stream, all_reduce_stream;
  3.  
  4. // Map: grad_name -> events
  5. map<string, cudaEvent_t> computation_event;
  6. map<string, cudaEvent_t> all_reduce_event;
  7. for (atuo& argu : param_grad) {
  8. cudaEventCreateWithFlags(&computation_event[argu], cudaEventDisableTiming)
  9. cudaEventCreateWithFlags(&all_reduce_event[argu], cudaEventDisableTiming)
  10. }
  11.  
  12. // ...
  13.  
  14.  
  15. // Run -----------------
  16. for (auto& op : ctx->ops_) {
  17. // sgd should wait for allreduce to be finished
  18. for (auto& argu : op->InputArguments()) {
  19. if (param_grad_set.find(argu)) {
  20. cudaStreamWaitEvent(computation_stream, all_reduce_event[argu])
  21. }
  22. }
  23.  
  24. op->Run(..., computation_stream);
  25.  
  26. // allreduce should wait for fc_grad to be finished.
  27. for (auto& argu : op->OutputArguments()) {
  28. if (param_grad_set.find(argu)) {
  29. cudaEventRecord(&computation_event[argu], computation_stream);
  30. cudaStreamWaitEvent(all_reduce_stream, computation_event[argu])
  31. allreduce(argu, all_reduce_stream)
  32. cudaEventRecord(&all_reduce_event[argu], all_reduce_stream);
  33. }
  34. }
  35. }
Add Comment
Please, Sign In to add comment