提交 8dd69e87 编写于 作者: M Megvii Engine Team

perf(imperative): speed up the hot code of reshape

GitOrigin-RevId: 97cc37198240b284b106f9614280801351853ce2
上级 022dbea8
......@@ -1747,6 +1747,54 @@ PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
mgb_assert(nargs == 2, "reshape should have 2 args but give %zu", nargs);
PyObject* pyinp = args[0];
PyObject* pyshp = args[1];
if ((PyList_CheckExact(pyshp) || PyTuple_CheckExact(pyshp) ||
PyLong_Check(pyshp)) &&
enable_fastpath(pyinp)) {
using OptionalAxisV1 = ::megdnn::param::OptionalAxisV1;
int32_t unspec_axis = OptionalAxisV1::INVALID_AXIS;
std::vector<int32_t> shape;
bool is_int_arrays = true;
// judge whether the target shape is int arrays like (int,), [int], int
if (PyLong_Check(pyshp)) {
shape.push_back(static_cast<int32_t>(PyLong_AsLong(pyshp)));
if (shape[0] == -1) {
unspec_axis = 0;
}
} else {
size_t len = PySequence_Fast_GET_SIZE(pyshp);
shape.resize(len);
for (size_t i = 0; i < len; ++i) {
auto obj = PySequence_Fast_GET_ITEM(pyshp, i);
if (!PyLong_Check(obj)) {
is_int_arrays = false;
break;
}
shape[i] = PyLong_AsLong(obj);
if (shape[i] == -1) {
mgb_assert(unspec_axis == OptionalAxisV1::INVALID_AXIS);
unspec_axis = i;
}
}
}
// if the target shape is the int arrays, we dispatch directly
if (is_int_arrays) {
std::shared_ptr<OpDef> op = Reshape::make(unspec_axis, shape);
std::vector<PyObject*> packed_param(2);
py::object Op = py::cast(op);
packed_param[0] = Op.ptr();
packed_param[1] = pyinp;
py::tuple ret = py::reinterpret_steal<py::object>(
py_apply(NULL, packed_param.data(), packed_param.size()));
return py::object(ret[0]).release().ptr();
}
}
// fallback
return _reshape_cpp(args[0], args[1]).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册