Commit cf3f58cb authored by Megvii Engine Team

fix(mge/autodiff): fix segfault when grad is nullptr

GitOrigin-RevId: 6139212bfdc75ac7af5275436f5557ba487673e7
Parent 288c2e08
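The fix is the same defensive pattern in every hunk below: the incoming gradient pointer handed to a backward rule can be nullptr (presumably when that output received no gradient), so each rule now checks `grad` before dereferencing it or feeding it to a grad op. A minimal standalone sketch of that pattern, using placeholder types rather than MegEngine's real Tensor/apply_result_t:

// Minimal sketch of the null-guard pattern applied in this commit.
// Tensor and apply_result_t are placeholders here, not MegEngine's types.
#include <cstddef>
#include <cstdio>
#include <memory>
#include <vector>

struct Tensor {};
using TensorPtr = std::shared_ptr<Tensor>;
using apply_result_t = std::vector<TensorPtr>;

// Backward rule of a two-input elementwise op (illustrative only).
apply_result_t elemwise_backward(Tensor* grad) {
    apply_result_t ret(2);
    if (!grad) {
        // No gradient flowed into this output: return empty slots instead
        // of dereferencing a null pointer further down.
        return ret;
    }
    for (std::size_t i = 0; i < 2; ++i) {
        // Stand-in for ret[i] = reduce_to(grad, shapes[i].get());
        ret[i] = std::make_shared<Tensor>(*grad);
    }
    return ret;
}

static std::size_t filled(const apply_result_t& r) {
    std::size_t n = 0;
    for (auto& t : r) n += (t != nullptr);
    return n;
}

int main() {
    Tensor g;
    std::printf("non-null grad -> %zu slots filled\n", filled(elemwise_backward(&g)));
    std::printf("null grad     -> %zu slots filled\n", filled(elemwise_backward(nullptr)));
}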
@@ -59,6 +59,9 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(2);
+        if (!grad) {
+            return ret;
+        }
         for (size_t i = 0; i < 2; ++i) {
             if (shapes[i]) {
                 ret[i] = reduce_to(grad, shapes[i].get());
@@ -84,6 +87,9 @@ apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(2);
+        if (!grad) {
+            return ret;
+        }
         for (size_t i = 0; i < 2; ++i) {
             if (shapes[i]) {
                 ret[i] = reshape_to(grad, shapes[i].get());
@@ -107,10 +113,10 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
     maker.output_size(1).output_captured(0, false);
     maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
         apply_result_t ret(1);
-        if (inputs[0]) {
+        if (grad && inputs[0]) {
             SmallVector<Tensor*> args_(inputs.size()+1);
-            Tensor* grad = grads[0];
             auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
             args_[0] = zeros.get();
             args_[1] = grad;
@@ -137,10 +143,10 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
     maker.output_size(1).output_captured(0, false);
     maker.backward([inputs=std::move(inputs), grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
+        Tensor* grad = grads[0];
         apply_result_t ret(1);
-        if (inputs[0]) {
+        if (grad && inputs[0]) {
             SmallVector<Tensor*> args_(inputs.size()+1);
-            Tensor* grad = grads[0];
             auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
             args_[0] = zeros.get();
             args_[1] = grad;
@@ -167,7 +173,7 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(1);
-        if (shapes[0]) {
+        if (grad && shapes[0]) {
             ret[0] = broadcast_to(grad, shapes[0].get());
         }
         return ret;
@@ -180,14 +186,17 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
 apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<AddAxis>();
     mgb_assert(ctx.nargs == 1);
+    bool flag = input_requires_grad(ctx, 0);
     auto&& grad_op = RemoveAxis::make(op.axis);
     std::sort(grad_op->axis.begin(), grad_op->axis.end(), std::greater<int32_t>());
     maker.output_size(1).output_captured(0, false);
-    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(1);
-        ret[0] = python::apply(grad_op_, grad)[0];
+        if (grad && flag_) {
+            ret[0] = python::apply(grad_op_, grad)[0];
+        }
         return ret;
     });
     return apply(ctx);
@@ -196,14 +205,17 @@ apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
 apply_result_t removeAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto&& op = ctx.op->cast_final_safe<RemoveAxis>();
     mgb_assert(ctx.nargs == 1);
+    bool flag = input_requires_grad(ctx, 0);
     auto&& grad_op = AddAxis::make(op.axis);
     std::sort(grad_op->axis.begin(), grad_op->axis.end());
     maker.output_size(1).output_captured(0, false);
-    maker.backward([grad_op_=std::move(grad_op)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+    maker.backward([grad_op_=std::move(grad_op), flag_=flag](BackwardContext&, Tensor*const* grads, size_t ngrads) {
         mgb_assert(ngrads == 1);
         Tensor* grad = grads[0];
         apply_result_t ret(1);
-        ret[0] = python::apply(grad_op_, grad)[0];
+        if (grad && flag_) {
+            ret[0] = python::apply(grad_op_, grad)[0];
+        }
         return ret;
     });
     return apply(ctx);
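For the addAxis/removeAxis rules the diff goes one step further: whether the input needs a gradient at all is decided up front with input_requires_grad and captured into the backward closure, so the lambda skips applying grad_op entirely when either the flag is false or grad is null. A rough sketch of that closure shape, with illustrative names and placeholder types:

// Sketch of the flag-capture variant (names and types are illustrative,
// not MegEngine's actual API).
#include <functional>
#include <memory>
#include <vector>

struct Tensor {};
using TensorPtr  = std::shared_ptr<Tensor>;
using BackwardFn = std::function<std::vector<TensorPtr>(Tensor*)>;

BackwardFn make_axis_backward(bool input_requires_grad) {
    bool flag = input_requires_grad;       // decided once, when the rule is built
    return [flag_ = flag](Tensor* grad) {
        std::vector<TensorPtr> ret(1);
        if (grad && flag_) {
            // Stand-in for ret[0] = python::apply(grad_op_, grad)[0];
            ret[0] = std::make_shared<Tensor>(*grad);
        }
        return ret;                        // empty slot when nothing to propagate
    };
}

Capturing the flag at rule-construction time keeps the per-step backward call cheap: when no gradient needs to be produced, the nullptr check and the bool test are the only work done.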