未验证 提交 edb9aff5 编写于 作者: L liu zhengxi 提交者: GitHub

update gather tree error msg (#34322)

上级 04288091
......@@ -50,6 +50,14 @@ class GatherTreeOpCUDAKernel : public framework::OpKernel<T> {
const auto *parents_data = parents->data<T>();
auto *out_data = out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(
ids_data, platform::errors::InvalidArgument(
"Input(Ids) of gather_tree should not be null."));
PADDLE_ENFORCE_NOT_NULL(
parents_data, platform::errors::InvalidArgument(
"Input(Parents) of gather_tree should not be null."));
auto &ids_dims = ids->dims();
int64_t max_length = ids_dims[0];
int64_t batch_size = ids_dims[1];
......
......@@ -38,6 +38,14 @@ class GatherTreeOpKernel : public framework::OpKernel<T> {
auto batch_size = ids_dims[1];
auto beam_size = ids_dims[2];
PADDLE_ENFORCE_NOT_NULL(
ids_data, platform::errors::InvalidArgument(
"Input(Ids) of gather_tree should not be null."));
PADDLE_ENFORCE_NOT_NULL(
parents_data, platform::errors::InvalidArgument(
"Input(Parents) of gather_tree should not be null."));
for (int batch = 0; batch < batch_size; batch++) {
for (int beam = 0; beam < beam_size; beam++) {
auto idx = (max_length - 1) * batch_size * beam_size +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册