提交 787f187e 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(imperative/src): fix dot backward error

GitOrigin-RevId: 02ba44a0e6d8cd2ca863ae0058542b260e0b755d
上级 f35687ca
......@@ -442,3 +442,18 @@ def test_removeAxis():
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())
def test_dot():
x = np.random.rand(2, 2).astype("float32")
x = mge.Tensor(x)
u = F.ones((2,))
v = F.ones((2,))
grad = Grad().wrt(x, callback=save_to(x))
def f(x):
return F.dot(u, F.matmul(x, v))
y = f(x)
grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy())
......@@ -33,7 +33,7 @@ DispatchMode decide_dispatch_mode(
const SmallVector<LogicalTensorDesc>& inputs) {
bool host_computable = true;
for (auto&& inp : inputs) {
// FIXME(czh): remove value chech after proxy graph's
// FIXME(czh): remove value check after proxy graph's
// apply_on_device_tensornd is supported and output Tensor
// is made before add_task.
// then if layout is valid, ptr->layout must be ready
......@@ -50,9 +50,18 @@ void apply_on_device_tensornd(
const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) {
auto&& op_def = def.cast_final_safe<GetVarShape>();
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
auto&& inp = inputs[0];
auto&& shp = inp.layout();
TensorShape shp;
if (inputs.size() == 1) {
shp = inputs[0].layout();
} else {
TensorShapeArray src(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
src[i] = inputs[i].layout();
}
megdnn::Elemwise::deduce_shape(src, shp);
}
mgb_assert(shp.ndim != 0, "input shape invalid");
mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(),
"GetVarShape's apply_on_device_tensornd should receive default_cpu outputs.");
......@@ -99,27 +108,36 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<GetVarShape>();
mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
auto&& desc = inputs[0];
if (!desc.layout.ndim) {
TensorShape shp;
if (inputs.size() == 1) {
shp = desc.layout;
} else {
TensorShapeArray src(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
src[i] = inputs[i].layout;
}
megdnn::Elemwise::deduce_shape(src, shp);
}
if (!shp.ndim) {
return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
}
DeviceTensorND value;
if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
value = DeviceTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
auto* ptr = value.ptr<dt_int32>();
for (size_t i = 0; i < desc.layout.ndim; ++i) {
ptr[i] = desc.layout[i];
for (size_t i = 0; i < shp.ndim; ++i) {
ptr[i] = shp[i];
}
}else{
int32_t axis = op_def.axis;
if (axis < 0) {
axis += desc.layout.ndim;
axis += shp.ndim;
}
mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim);
mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
auto* ptr = value.ptr<dt_int32>();
ptr[0] = desc.layout[axis];
ptr[0] = shp[axis];
}
return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册