Not a member of Pastebin yet?
Sign up;
it unlocks many cool features!
- diff --git a/torch/csrc/autograd/functions/convolution.cpp b/torch/csrc/autograd/functions/convolution.cpp
- index 465283c..991648c 100644
- --- a/torch/csrc/autograd/functions/convolution.cpp
- +++ b/torch/csrc/autograd/functions/convolution.cpp
- @@ -578,6 +578,16 @@ auto ConvBackwardBackward::apply(const variable_list& grad_grad_inputs) -> varia
- gI = Transpose(0, 1).apply({gIt})[0];
- }
- + auto zeros_like = [](const Variable& var) -> std::shared_ptr<Variable> {
- + auto data = var.data->newTensor();
- + data->resizeAs(*var.data).zero();
- + return std::make_shared<Variable>(std::move(data), false, false);
- + };
- +
- + if (should_compute_output(0) && !ggO) ggO = zeros_like(*gO);
- + if (should_compute_output(1) && !gI) gI = zeros_like(*input);
- + if (should_compute_output(2) && !gW) gW = zeros_like(*weight);
- +
- return {ggO, gI, gW};
- }
Add Comment
Please sign in to add a comment.