Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- // Init -----------------
- cudaStream_t computation_stream, all_reduce_stream;
- // Map: grad_name -> events
- map<string, cudaEvent_t> computation_event;
- map<string, cudaEvent_t> all_reduce_event;
- for (atuo& argu : param_grad) {
- cudaEventCreateWithFlags(&computation_event[argu], cudaEventDisableTiming)
- cudaEventCreateWithFlags(&all_reduce_event[argu], cudaEventDisableTiming)
- }
- // ...
- // Run -----------------
- for (auto& op : ctx->ops_) {
- // sgd should wait for allreduce to be finished
- for (auto& argu : op->InputArguments()) {
- if (param_grad_set.find(argu)) {
- cudaStreamWaitEvent(computation_stream, all_reduce_event[argu])
- }
- }
- op->Run(..., computation_stream);
- // allreduce should wait for fc_grad to be finished.
- for (auto& argu : op->OutputArguments()) {
- if (param_grad_set.find(argu)) {
- cudaEventRecord(&computation_event[argu], computation_stream);
- cudaStreamWaitEvent(all_reduce_stream, computation_event[argu])
- allreduce(argu, all_reduce_stream)
- cudaEventRecord(&all_reduce_event[argu], all_reduce_stream);
- }
- }
- }
Add Comment
Please, Sign In to add comment