提交 9fb8444d 编写于 作者: M Megvii Engine Team

fix(imperative): catch python exception in c++

GitOrigin-RevId: 16a2abfdad35c52d50f34783d29c2d503ab29568
上级 e3a3e0cd
......@@ -240,10 +240,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
pyf = cpp_apply_const_with_tracing;
}
auto ret = py::reinterpret_steal<py::object>(
PyObject_Call(pyf, tup.ptr(), nullptr));
auto py_ret = py::reinterpret_borrow<py::list>(ret);
if (auto* t = try_cast(py_ret[0].ptr())) {
auto py_ret = PyObject_Call(pyf, tup.ptr(), nullptr);
if (!py_ret) throw py::error_already_set();
auto py_list = py::reinterpret_steal<py::list>(py_ret);
if (auto* t = try_cast(py_list[0].ptr())) {
m_tensor = t->m_tensor;
}
return;
......@@ -389,6 +389,7 @@ PyObject* TensorWrapper::device() {
PyObject* TensorWrapper::numpy() {
if (m_tensor->m_trace_info.compiled_info != nullptr) {
PyObject* np_val = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
if (!np_val) throw py::error_already_set();
if (np_val == Py_None) {
throw TraceReadError("value of this tensor is not read in trace");
}
......@@ -478,6 +479,7 @@ PyObject* TensorWrapper::detach() {
PyObject* TensorWrapper::_dev_tensor(){
if (m_tensor->m_trace_info.compiled_info != nullptr) {
auto *dev_tensor = PyObject_CallMethod(m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
if (!dev_tensor) throw py::error_already_set();
if (dev_tensor == Py_None) {
throw TraceReadError("raw data of this tensor is not read in trace");
}
......
......@@ -31,9 +31,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
}
py::object ret = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
if (!ret) {
throw py::value_error("invalid py object call");
}
if (!ret) throw py::error_already_set();
// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
......@@ -58,8 +56,9 @@ apply_result_t apply_trace(ApplyContext& ctx) {
for (size_t i = 0; i < ctx.nargs; i++) {
args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
}
auto ret = py::reinterpret_steal<py::object>(
PyObject_Call(pyf, args.ptr(), nullptr));
auto pyout = PyObject_Call(pyf, args.ptr(), nullptr);
if (!pyout) throw py::error_already_set();
auto ret = py::reinterpret_steal<py::object>(pyout);
// assumption: python function always returns PyList
auto tup = py::reinterpret_borrow<py::list>(ret);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册