提交 cf3f58cb 编写于 作者: M Megvii Engine Team

fix(mge/autodiff): fix segfault when grad is nullptr

GitOrigin-RevId: 6139212bfdc75ac7af5275436f5557ba487673e7
上级 288c2e08
......@@ -59,6 +59,9 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(2);
if (!grad) {
return ret;
}
for (size_t i = 0; i < 2; ++i) {
if (shapes[i]) {
ret[i] = reduce_to(grad, shapes[i].get());
......@@ -84,6 +87,9 @@ apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(2);
if (!grad) {
return ret;
}
for (size_t i = 0; i < 2; ++i) {
if (shapes[i]) {
ret[i] = reshape_to(grad, shapes[i].get());
......@@ -107,10 +113,10 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
maker.output_size(1).output_captured(0, false);
maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
if (inputs[0]) {
if (grad && inputs[0]) {
SmallVector<Tensor*> args_(inputs.size()+1);
Tensor* grad = grads[0];
auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
args_[0] = zeros.get();
args_[1] = grad;
......@@ -137,10 +143,10 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
maker.output_size(1).output_captured(0, false);
maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
if (inputs[0]) {
if (grad && inputs[0]) {
SmallVector<Tensor*> args_(inputs.size()+1);
Tensor* grad = grads[0];
auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
args_[0] = zeros.get();
args_[1] = grad;
......@@ -167,7 +173,7 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
if (shapes[0]) {
if (grad && shapes[0]) {
ret[0] = broadcast_to(grad, shapes[0].get());
}
return ret;
......@@ -180,14 +186,17 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<AddAxis>();
mgb_assert(ctx.nargs == 1);
bool flag = input_requires_grad(ctx, 0);
auto&& grad_op = RemoveAxis::make(op.axis);
std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
maker.output_size(1).output_captured(0, false);
maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
ret[0] = python::apply(grad_op_, grad)[0];
if (grad && flag_) {
ret[0] = python::apply(grad_op_, grad)[0];
}
return ret;
});
return apply(ctx);
......@@ -196,14 +205,17 @@ apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
mgb_assert(ctx.nargs == 1);
bool flag = input_requires_grad(ctx, 0);
auto&& grad_op = AddAxis::make(op.axis);
std::sort(grad_op->axis.begin(), grad_op->axis.end());
maker.output_size(1).output_captured(0, false);
maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
mgb_assert(ngrads == 1);
Tensor* grad = grads[0];
apply_result_t ret(1);
ret[0] = python::apply(grad_op_, grad)[0];
if (grad && flag_) {
ret[0] = python::apply(grad_op_, grad)[0];
}
return ret;
});
return apply(ctx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册