未验证 提交 a73064f2 编写于 作者: C chentianyu03 提交者: GitHub

pylayer support tuple/list type args (#37727)

上级 6ff19d66
...@@ -101,6 +101,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, ...@@ -101,6 +101,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
"`%s` type argument can not be cast into `Tensor`.", "`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name)); ptr->ptr()->ob_type->tp_name));
} }
} else if (py::isinstance<py::tuple>(*ptr) ||
py::isinstance<py::list>(*ptr)) {
try {
auto tuple_arg = ptr->cast<py::tuple>();
for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) {
try {
auto t = iter->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(t);
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, "
"the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name));
}
}
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name));
}
} }
} }
} }
...@@ -119,6 +141,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls, ...@@ -119,6 +141,28 @@ py::object PyLayerApply(const platform::Place& place, const py::handle& cls,
"`%s` type argument can not be cast into `Tensor`.", "`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name)); ptr->second.ptr()->ob_type->tp_name));
} }
} else if (py::isinstance<py::tuple>(*ptr->second) ||
py::isinstance<py::list>(*ptr->second)) {
try {
auto tuple_arg = ptr->second.cast<py::tuple>();
for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) {
try {
auto t = iter->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(t);
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, "
"the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name));
}
}
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name));
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册