提交 06ba043a 编写于 作者: M Megvii Engine Team

fix(imperative): fix subtensor some error

GitOrigin-RevId: bcc0307d67c66b4b7c9237775cdbe3b4360fdef5
上级 a60ad267
......@@ -435,7 +435,8 @@ def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
"""
op = builtin.FillLike(value=value)
(rst,) = apply(op, inp)
rst.format = inp.format
# rst.format = inp.format
# see jira:MGE-4505
return rst
......
......@@ -1208,6 +1208,11 @@ py::object _fastpath_getitem_cpp(py::handle inp_hdl, py::tuple tuple_val) {
ax += 1;
} else if (PyBool_Check(t.ptr())) {
expand_items.push_back(ax);
if (t.ptr() == Py_False) {
cpp_items.push_back({ax, true, true, true, false});
slice_items.push_back({0, 0, 1, INT_MAX});
}
ax += 1;
} else if (t.ptr() == Py_None) {
expand_items.push_back(ax);
ax += 1;
......
......@@ -342,6 +342,46 @@ def test_subtensor():
np.array([[0, 0, 0], [1, 1, 0], [0, 0, 0]], dtype=np.float32), x.grad.numpy()
)
x_np = np.random.rand(3, 2).astype("float32")
x = mge.Tensor(x_np)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}
def f(x):
x = x * 1
y = x[True, 0:1]
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))
np.testing.assert_equal(
np.array([[1, 1], [0, 0], [0, 0]], dtype=np.float32), x.grad.numpy()
)
with Grad() as grad:
grad.wrt(x, callback=save_to(x))
refs = {}
def f(x):
x = x * 1
y = x[False, 0:1]
refs["x"] = TensorWeakRef(x)
return y
y = f(x)
for _, r in refs.items():
assert r() is None
grad(y, F.ones_like(y))
np.testing.assert_equal(
np.array([[0, 0], [0, 0], [0, 0]], dtype=np.float32), x.grad.numpy()
)
def test_IndexingMultiAxisVec():
x_np = np.random.rand(3, 3).astype("float32")
......
......@@ -84,8 +84,8 @@ TensorLayout deduce_layout(
return 0;
return v < 0 ? v + size_ax : v;
};
#define CHECK(cond) \
mgb_assert(cond, "index out of bound: layout=%s", src.to_string().c_str())
auto tostr = [](int v) -> std::string { return std::to_string(v); };
for (int i = items.size() - 1; i >= 0; i--) {
auto&& [axis, b_flag, e_flag, s_flag, idx_flag] = items[i];
......@@ -99,16 +99,28 @@ TensorLayout deduce_layout(
slice_stop = mod_size(slice_stop, shape_axis);
slice_stop = std::min(slice_stop, shape_axis);
slice_start = std::min(slice_start, slice_stop);
CHECK(slice_start >= 0 && slice_stop >= slice_start &&
slice_stop <= shape_axis);
mgb_assert(
(slice_start >= 0 && slice_stop >= slice_start &&
slice_stop <= shape_axis),
"index out of bound: layout=%s; request begin=%s end=%s step=%s "
"axis=%s",
src.to_string().c_str(), tostr(slice_start).c_str(),
tostr(slice_stop).c_str(), tostr(slice_step).c_str(),
tostr(axis).c_str());
} else {
slice_start = s_val == INT_MIN ? shape_axis - 1 : b_val;
slice_start = mod_size(slice_start, shape_axis);
slice_stop = e_val == INT_MAX ? -1 : mod_size(e_val, shape_axis);
slice_start = std::min(slice_start, std::max(shape_axis - 1, 0));
slice_stop = std::min(slice_stop, slice_start);
CHECK(slice_step < 0 && slice_start >= 0 && slice_stop <= slice_start &&
slice_start < shape_axis && slice_stop >= -1);
mgb_assert(
(slice_step < 0 && slice_start >= 0 && slice_stop <= slice_start &&
slice_start < shape_axis && slice_stop >= -1),
"index out of bound: layout=%s; request begin=%s end=%s step=%s "
"axis=%s",
src.to_string().c_str(), tostr(slice_start).c_str(),
tostr(slice_stop).c_str(), tostr(slice_step).c_str(),
tostr(axis).c_str());
}
int abs_step = std::abs(slice_step);
if (axis < 0) {
......@@ -205,7 +217,7 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
return layout_checker;
}
OP_TRAIT_REG(Subtensor, Subtensor, opr::Subtensor)
OP_TRAIT_REG(Subtensor, Subtensor)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
......
......@@ -369,6 +369,27 @@ ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
return imperative::apply(op, inputs);
}
ValueRefList group_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
SmallVector<DType> dtypes = get_value_dtypes(inputs);
ValueRefList converted(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
mgb::DType target_dtype = DTypePromoteCfg::amp_high_prec_dtype;
if (dtypes[i] != target_dtype) {
converted[i] = imperative::apply(
ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
} else {
converted[i] = inputs[i];
}
}
return imperative::apply(op, converted);
}
return imperative::apply(op, inputs);
}
ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
SmallVector<DType> dtypes = get_value_dtypes(inputs);
mgb::DType target_dtype = get_promoted_dtype(dtypes);
......@@ -402,6 +423,7 @@ struct DTypePromoteRuleRegistry {
register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule);
register_dtype_promote_rule<LayerNorm>(layer_norm_rule);
register_dtype_promote_rule<GroupNorm>(group_norm_rule);
}
} register_helper;
......
......@@ -549,6 +549,7 @@ ValueRefList adaptive_pooling_rule(
cb(FastpathCopy) \
cb(TypeCvt) \
cb(Dropout) \
cb(FillLike) \
cb(Identity)
#define FOREACH_FORMAT_OP(cb) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册