Commit 59a9275c authored by Megvii Engine Team

fix(mge): fix optimize_for_inference during trace.dump

GitOrigin-RevId: e10f7c323a1832a9727211c9ee6cd9242c869c3b
Parent bd7f885a
@@ -570,7 +570,9 @@ class trace:
                 if h not in h2v:
                     assert info.external
                     assert info.bound_data
-                    h2v[h] = graph.make_const(info.bound_data._dev_tensor())
+                    h2v[h] = graph.make_const(
+                        info.bound_data.numpy(), dtype=info.dtype, device=info.device
+                    )
                 ivars.append(h2v[h])
             ovars = apply(op, *ivars)
             assert len(ovars) == len(ohandles)
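Note: the hunk above stops wrapping the captured device tensor (`info.bound_data._dev_tensor()`) and instead builds the graph constant from host data with an explicit dtype and device, so tensors captured by `capture_as_const=True` serialize as value-inferable ImmutableTensor operators rather than SharedDeviceTensor. A minimal sketch of that scenario, assuming MegEngine's public tracing API (the import paths and `F.matmul` are assumptions, not taken from this diff):

# Hedged sketch (not part of this commit): a tensor captured as a constant
# during tracing and then dumped. After this fix the captured constant is
# serialized from its host value, so it appears as ImmutableTensor in the
# dumped graph and can be folded by inference-optimization passes.
import io
import numpy as np
import megengine.functional as F
from megengine import tensor
from megengine.jit import trace

w = tensor(np.ones((3, 3), dtype=np.float32))  # captured as a graph constant

@trace(capture_as_const=True)
def f(x):
    return F.matmul(x, w)

f(tensor(np.ones((2, 3), dtype=np.float32)))   # run once to record the graph
f.dump(io.BytesIO())                           # constants now come from host data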
@@ -150,7 +150,7 @@ def test_dump_volatile():
     (out,) = outputs
     assert (
         cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
-        == "SharedDeviceTensor"
+        == "ImmutableTensor"
     )
@@ -235,6 +235,18 @@ def test_optimize_for_inference():
     assert computing_input.dtype == np.float16
 
 
+def test_optimize_for_inference_broadcast():
+    a = tensor(np.ones(1, dtype=np.float32))
+
+    @trace(capture_as_const=True, tensor_shape=True)
+    def f():
+        (b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32))
+        return b
+
+    f()
+    f.dump(io.BytesIO())
+
+
 def test_trace_cvt_bool():
     set_tensor_shape(True)
     x = tensor([0], dtype=np.int32)
@@ -561,7 +561,7 @@ void ParamFusePass::apply(OptState &state) const {
         }
         SymbolVar new_var;
-        bool is_default_format = var->layout().format.is_default();
+        bool is_default_format = var->format().is_default();
         if (cg::is_static_var_value(var) && is_default_format) {
            // use ImmutableTensor for inferable vars
            HostTensorND hv;
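The ParamFusePass tweak reads the tensor format directly from the variable; together with the tracing change it allows statically inferable constants to be fused into ImmutableTensor nodes when a dumped graph is optimized for inference. A rough usage sketch follows; the `optimize_for_inference` keyword on `dump()` and the import paths are assumptions, not shown in this diff:

# Hedged sketch (assumed API): exercising constant folding on a dumped trace.
import io
import numpy as np
from megengine import tensor
from megengine.jit import trace

k = tensor(np.full((1,), 2.0, dtype=np.float32))   # captured constant

@trace(capture_as_const=True)
def f(x):
    return x * (k + 1.0)   # (k + 1.0) has a statically inferable value

f(tensor(np.ones((4,), dtype=np.float32)))
# With the fix, the constant subexpression can be folded by ParamFusePass,
# since the captured constant has a default format and a known host value.
f.dump(io.BytesIO(), optimize_for_inference=True)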