提交 7252825c 编写于 作者: M Megvii Engine Team

fix(functional): broadcast_to supports mutable target shape

GitOrigin-RevId: ff79456d5d2d669d20112d57fdeb255ae837e868
上级 2484cd27
......@@ -924,78 +924,67 @@ bool enable_fastpath(py::handle inp) {
return true;
}
py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
py::object shape_hdl = _expand_args(args);
bool auto_infer = false;
py::list lis;
py::list new_shape;
if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
for (size_t i = 0; i < lis.size(); ++i) {
if (lis[i].is_none()) {
auto_infer = true;
size_t right = lis.size() - i;
py::object tshp = getattr(inp_hdl, "_tuple_shape");
if (tshp.is_none()) {
throw py::index_error("does not support `None` with unknown shape");
py::object _broadcast_cpp(py::handle input, py::handle args) {
py::object shape = _expand_args(args);
py::list dims;
bool all_imm;
if (PyList_Check(shape.ptr()) || PyTuple_Check(shape.ptr())) {
dims = py::reinterpret_steal<py::list>(PySequence_List(shape.ptr()));
mgb_assert(!dims.is_none());
all_imm = true;
py::object inp_shape = py::none();
size_t inp_ndim;
for (size_t i = 0; i < dims.size(); ++i) {
py::object dim = dims[i];
if (dim.is_none()) {
ptrdiff_t right = (ptrdiff_t)i - dims.size();
if (inp_shape.is_none()) {
inp_shape = input.attr("shape");
mgb_assert(!inp_shape.is_none());
inp_ndim = py::len(inp_shape);
}
py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
if (inp_shape.size() >= right) {
if (enable_fastpath(inp_hdl)) {
lis[i] = inp_shape[inp_shape.size() - right];
}
new_shape.append(inp_shape[inp_shape.size() - right]);
} else {
throw py::value_error("invalid broadcast shape");
if ((ptrdiff_t)inp_ndim + right < 0) {
throw py::value_error("size connot be `None` for new axis");
}
} else {
new_shape.append(lis[i]);
if (PyLong_Check(lis[i].ptr())) {
int32_t s = lis[i].cast<int32_t>();
if (s < 0) {
throw py::value_error(
"expect shape[" + std::to_string(i) +
"] >= 0 or use `None` to auto infer, got " +
std::to_string(s));
}
dim = inp_shape.attr("__getitem__")(right);
dims[i] = dim;
}
if (py::int_::check_(dim)) {
if (dim.cast<long>() < 0) {
throw py::value_error(ssprintf(
"expect shape[%zu] >= 0 or use `None` to auto infer, got "
"%s",
i, py::repr(dims[i]).cast<std::string>().c_str()));
}
} else {
all_imm = false;
}
}
shape = dims;
} else {
all_imm = false;
}
if (auto_infer) {
if (enable_fastpath(inp_hdl)) {
shape_hdl = py::reinterpret_borrow<py::tuple>(lis);
} else {
shape_hdl = _astensor1d_cpp(
new_shape, py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device"), inp_hdl);
}
bool fastpath = all_imm && enable_fastpath(input);
if ((!fastpath) && (!is_tensor(shape))) {
shape = _astensor1d_cpp(
shape, py::cast((mgb::DType)dtype::Int32()), input.attr("device"),
input);
}
py::object shape_tuple;
try {
shape_tuple = _make_shape_tuple(shape_hdl);
} catch (py::error_already_set& err) {
shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
}
auto [shape, fastpath] = tuple2vector(shape_tuple);
fastpath &= enable_fastpath(inp_hdl);
std::shared_ptr<OpDef> op;
std::vector<PyObject*> p;
py::object shape_tensor;
SmallVector<PyObject*> p(2);
if (fastpath) {
op = Broadcast::make(shape);
p.resize(2);
std::vector<int32_t> shape_vec;
for (auto&& dim : dims) {
shape_vec.push_back(dim.cast<long>());
}
op = Broadcast::make(shape_vec);
} else {
op = Broadcast::make();
shape_tensor = _astensor1d_cpp(
shape_hdl, py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device"), inp_hdl);
p.resize(3);
p[2] = shape_tensor.ptr();
p.push_back(shape.ptr());
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
py::object py_op = py::cast(op);
p[0] = py_op.ptr();
p[1] = input.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
......@@ -1675,4 +1664,4 @@ PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python
\ No newline at end of file
} // namespace mgb::imperative::python
......@@ -753,6 +753,40 @@ def test_broadcast_on_empty_tensor(is_trace):
test(func, inp, comp, target_shp)
@pytest.mark.parametrize(
"input_shape, target_shapes",
[
((3,), [(2, 1, 3), (1, 2, 3), (2, 2, 3)]),
((1, 3, 1), [(2, None, 3), (3, None, 3), (1, None, 1)]),
],
)
@pytest.mark.parametrize("is_symbolic", [True, False])
def test_broadcast_on_trace(is_symbolic, input_shape, target_shapes):
x = F.ones(input_shape)
@trace(symbolic=is_symbolic)
def broadcast(inp, shape):
return F.broadcast_to(inp, shape)
for target_shape in target_shapes:
if None in target_shape:
symbolic_target_shape = tuple(
map(lambda x: None if x is None else Tensor(x), target_shape)
)
output = broadcast(x, symbolic_target_shape)
for i in range(len(target_shape)):
if target_shape[i] is not None:
assert output._tuple_shape[i] == target_shape[i]
else:
assert (
output._tuple_shape[i] == x._tuple_shape[i - len(target_shape)]
)
else:
symbolic_target_shape = Tensor(target_shape)
output = broadcast(x, symbolic_target_shape)
assert output._tuple_shape == target_shape
@pytest.mark.parametrize("is_varnode", [True, False])
def test_utils_astensor1d(is_varnode):
if is_varnode:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册