提交 e027dcbf 编写于 作者: M Megvii Engine Team

chore(mge): improve symbolic tracing value/shape inference

GitOrigin-RevId: d1a6baac741726604c799752b19d2ed90e399639
上级 e6e29748
...@@ -186,6 +186,9 @@ class trace: ...@@ -186,6 +186,9 @@ class trace:
self._seq.append((op, tuple(ihandles), tuple(ohandles))) self._seq.append((op, tuple(ihandles), tuple(ohandles)))
self._active_tensors.update(outputs) self._active_tensors.update(outputs)
def _record_const(self, op, outputs):
pass
@contextlib.contextmanager @contextlib.contextmanager
def _setup(self): def _setup(self):
global active_trace global active_trace
...@@ -195,8 +198,10 @@ class trace: ...@@ -195,8 +198,10 @@ class trace:
if self._untraced: if self._untraced:
apply.enable(apply_with_tracing) apply.enable(apply_with_tracing)
apply.enable(apply_const_with_tracing)
if self._symbolic: if self._symbolic:
apply.enable(apply_symbolic_mode) apply.enable(apply_symbolic_mode)
apply.enable(apply_const_symbolic_mode)
self._lazy_eval_graph = G.Graph() self._lazy_eval_graph = G.Graph()
else: else:
apply.enable(apply_compiled_mode) apply.enable(apply_compiled_mode)
...@@ -239,7 +244,9 @@ class trace: ...@@ -239,7 +244,9 @@ class trace:
self._pc = 0 self._pc = 0
apply.disable(apply_with_tracing) apply.disable(apply_with_tracing)
apply.disable(apply_const_with_tracing)
apply.disable(apply_symbolic_mode) apply.disable(apply_symbolic_mode)
apply.disable(apply_const_symbolic_mode)
apply.disable(apply_compiled_mode) apply.disable(apply_compiled_mode)
active_trace = None active_trace = None
...@@ -477,6 +484,16 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): ...@@ -477,6 +484,16 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
apply.disable(apply_symbolic_mode) apply.disable(apply_symbolic_mode)
@apply.register()
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
return (ret,)
apply.disable(apply_const_symbolic_mode)
@apply.register() @apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor): def apply_compiled_mode(op: OpDef, *args: RawTensor):
if skip_tracing: if skip_tracing:
...@@ -502,9 +519,14 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): ...@@ -502,9 +519,14 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
apply.disable(apply_with_tracing) apply.disable(apply_with_tracing)
# @apply.register() @apply.register()
# def _(op: Const, *args: RawTensor): def apply_const_with_tracing(op: Const, *args: RawTensor):
# return active_trace._apply_const(op, args) outputs = apply.super(op, *args)
active_trace._record_const(op, outputs)
return outputs
apply.disable(apply_const_with_tracing)
class BrokenRawTensor(RawTensor): class BrokenRawTensor(RawTensor):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册