Commit aed681d3 authored by: Megvii Engine Team

feat(imperative/utils): optimize the naming rules

GitOrigin-RevId: 329bac640aa6e2e3c981aa294a361684b982892e
Parent: c6bbc478
@@ -40,7 +40,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.utils import setscalar
-from ..utils.naming import auto_naming
+from ..utils.naming import AutoNaming
 from .sublinear_memory_config import SublinearMemoryConfig
@@ -297,9 +297,7 @@ class trace:
         h = getattr(x, "_mixin_handle", -1)
         if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
             h, info = self._new_handle()
-            name = (
-                auto_naming.get_scope() + "." + (x.c_name if x.c_name else x._name)
-            )
+            name = AutoNaming.gen_name(x)
             info.name = name
             info.external = True
             info.device = x.device
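
The inline scope-plus-name concatenation is folded into the AutoNaming.gen_name helper defined later in this commit. A minimal standalone sketch of the rule (the scope stack and tensor names here are illustrative, not from a real trace):

    # Mirrors AutoNaming.gen_name: prefer the user-assigned c_name, fall
    # back to the auto-generated _name, and prefix the dotted scope only
    # when one is active.
    scopes = ["net", "conv0"]
    scope = ".".join(s for s in scopes if s is not None)

    def gen_name(c_name, fallback_name):
        name = c_name if c_name else fallback_name
        return scope + "." + name if len(scope) else name

    assert gen_name("weight", "t0") == "net.conv0.weight"
    assert gen_name(None, "t1") == "net.conv0.t1"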
@@ -845,17 +843,17 @@ class trace:
             ivars.append(h2v[h])
         ovars = G.apply_normal_varnode(op, *ivars)

-        auto_naming.record_opnode(ovars[0].op)
+        AutoNaming.record_opnode(ovars[0].op)

         assert len(ovars) == len(ohandles)
         h2v.update(zip(ohandles, ovars))

         for i in ohandles:
-            name = auto_naming.get_var_name(i)
+            name = AutoNaming.get_var_name(i)
             if name is not None:
                 h2v[i].name = name

-        auto_naming.remove_duplicate_names()
+        AutoNaming.remove_duplicate_names()

         dest_vars = []
         for i, h in enumerate(self._output_bindings):
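
The three class-level calls above split the bookkeeping cleanly: record_opnode groups ops by name so duplicates can be renamed later, get_var_name pops a user-recorded name exactly once per handle, and remove_duplicate_names runs after all ops are emitted. A toy reimplementation of the first two (the handle and op values are hypothetical):

    name2ops, handle2names = {}, {}

    def record_opnode(op):
        ops = name2ops.setdefault(op["name"], [])
        if op not in ops:                    # dedup guard also added by this commit
            ops.append(op)

    def get_var_name(handle):
        return handle2names.pop(handle, None)

    handle2names[7] = "net.conv0.weight"
    record_opnode({"name": "ADD", "id": 0})
    record_opnode({"name": "ADD", "id": 1})  # distinct op, same name
    assert get_var_name(7) == "net.conv0.weight"
    assert get_var_name(7) is None           # popped: a handle is named once
    assert len(name2ops["ADD"]) == 2         # remove_duplicate_names renames these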
@@ -1173,7 +1171,7 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
 def apply_with_tracing(op: OpDef, *args: RawTensor):
     if hasattr(op, "scope"):
-        op.scope = auto_naming.get_scope()
+        op.scope = AutoNaming.get_scope()
     if active_trace._symbolic:
         outputs = apply_symbolic_mode(op, *args)
     else:
......
@@ -16,7 +16,7 @@ from ..logger import get_logger
 from ..tensor import Parameter, Tensor
 from ..utils.deprecation import deprecated
 from ..utils.hook import HookHandler
-from ..utils.naming import auto_naming
+from ..utils.naming import AutoNaming

 logger = get_logger(__name__)
@@ -111,7 +111,7 @@ class Module(metaclass=ABCMeta):
         self._forward_hooks = OrderedDict()

         # used for profiler and automatic naming
-        self._name = "{anonymous}"
+        self._name = None

     @abstractmethod
     def forward(self, inputs):
@@ -137,7 +137,7 @@ class Module(metaclass=ABCMeta):
         return HookHandler(self._forward_hooks, hook)

     def __call__(self, *inputs, **kwargs):
-        auto_naming.push_scope(self.name if self.name is not None else self._name)
+        AutoNaming.push_scope(self.name if self.name is not None else self._name)
         for hook in self._forward_pre_hooks.values():
             modified_inputs = hook(self, inputs)
             if modified_inputs is not None:
@@ -151,7 +151,7 @@ class Module(metaclass=ABCMeta):
             modified_outputs = hook(self, inputs, outputs)
             if modified_outputs is not None:
                 outputs = modified_outputs
-        auto_naming.pop_scope()
+        AutoNaming.pop_scope()
         return outputs

     def _flatten(
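
Each Module.__call__ now brackets its forward pass between push_scope and pop_scope, so nested calls build a dotted path. A tiny stand-in for the stack (module names here are made up):

    scopes = []

    def call_module(name, body):
        scopes.append(name)                  # push_scope; None is skipped when joining
        try:
            return body()
        finally:
            scopes.pop()                     # pop_scope runs even on error

    def current_scope():
        return ".".join(s for s in scopes if s is not None)

    assert call_module("net", lambda: call_module("block0", current_scope)) == "net.block0"
    assert call_module(None, current_scope) == ""   # anonymous top-level module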
......
@@ -20,7 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
 from .logger import get_logger
 from .utils.deprecation import deprecated
-from .utils.naming import auto_naming
+from .utils.naming import AutoNaming

 logger = get_logger(__name__)
@@ -168,7 +168,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
     @name.setter
     def name(self, name):
         self.c_name = name
-        auto_naming.record_var_name(self._mixin_handle, name)
+        AutoNaming.record_var_name(self._mixin_handle, name)

     @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
     def set_value(self, value):
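
From the user's side nothing changes: assigning to Tensor.name still records the name, only now through the class-level API. A small usage sketch (requires MegEngine; the tensor value and name are arbitrary):

    import megengine as mge

    x = mge.tensor([1.0, 2.0])
    x.name = "my_input"   # stores c_name and calls AutoNaming.record_var_name;
                          # under trace, the recorded name is applied to the graph var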
......
@@ -15,40 +15,57 @@ class AutoNaming:
     renamed by the user.
     """

-    def __init__(self):
-        self.scopes = []
-        self.c_ops = []
-        self.name2ops = {}
-        self.handle2names = {}
+    scopes = []
+    c_ops = []
+    name2ops = {}
+    handle2names = {}
+    __cls_attributes__ = {"scopes", "c_ops", "name2ops", "handle2names"}

-    def clear(self):
-        for var in vars(self).values():
-            var.clear()
+    @classmethod
+    def clear(cls):
+        for attr in cls.__cls_attributes__:
+            getattr(cls, attr).clear()

-    def push_scope(self, scope):
-        push_scope(scope)
-        self.scopes.append(scope)
+    @classmethod
+    def push_scope(cls, scope):
+        if scope is not None:
+            push_scope(scope)
+        cls.scopes.append(scope)

-    def pop_scope(self):
-        scope = self.scopes.pop()
-        pop_scope(scope)
+    @classmethod
+    def pop_scope(cls):
+        scope = cls.scopes.pop()
+        if scope is not None:
+            pop_scope(scope)

-    def get_scope(self):
-        return ".".join(self.scopes)
+    @classmethod
+    def get_scope(cls):
+        return ".".join(s for s in cls.scopes if s is not None)

-    def record_var_name(self, handle, name):
-        self.handle2names[handle] = name
+    @classmethod
+    def gen_name(cls, x) -> str:
+        scope = cls.get_scope()
+        name = x.c_name if x.c_name else x._name
+        return scope + "." + name if len(scope) else name

-    def get_var_name(self, handle):
-        return self.handle2names.pop(handle, None)
+    @classmethod
+    def record_var_name(cls, handle, name):
+        cls.handle2names[handle] = name

-    def record_opnode(self, op):
-        ops = self.name2ops.get(op.name, [])
-        ops.append(op)
-        self.name2ops[op.name] = ops
+    @classmethod
+    def get_var_name(cls, handle):
+        return cls.handle2names.pop(handle, None)

-    def remove_duplicate_names(self):
-        for key, ops in self.name2ops.items():
+    @classmethod
+    def record_opnode(cls, op):
+        ops = cls.name2ops.get(op.name, [])
+        if op not in ops:
+            ops.append(op)
+        cls.name2ops[op.name] = ops
+
+    @classmethod
+    def remove_duplicate_names(cls):
+        for key, ops in cls.name2ops.items():
             if len(ops) == 1:
                 continue
             for i, op in enumerate(ops):
@@ -57,7 +74,4 @@ class AutoNaming:
                 continue
             for var in op.outputs:
                 var.name = var.name.replace(key, op.name)
-        self.name2ops.clear()
-
-
-auto_naming = AutoNaming()
+        cls.name2ops.clear()
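
With the module-level singleton auto_naming removed, all state lives on the class itself and every call site uses classmethods. A quick sketch of the new call pattern (stand-alone, no trace running; the scope strings are arbitrary):

    from megengine.utils.naming import AutoNaming

    AutoNaming.clear()                       # reset the shared class-level state
    AutoNaming.push_scope("net")
    AutoNaming.push_scope(None)              # None stays on the stack...
    assert AutoNaming.get_scope() == "net"   # ...but is skipped when joining
    AutoNaming.pop_scope()
    AutoNaming.pop_scope()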
@@ -28,7 +28,7 @@ from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
 from megengine.module import Module
 from megengine.random import normal, uniform
-from megengine.utils.naming import auto_naming
+from megengine.utils.naming import AutoNaming


 @pytest.mark.parametrize("trace_mode", [False, True])
@@ -141,7 +141,7 @@ def test_dump():
         return a + b

     # prevent from remaining scope from exception test
-    auto_naming.clear()
+    AutoNaming.clear()

     a = tensor([2])
     b = tensor([4])
     y = f(a, b).numpy()
......
@@ -18,11 +18,11 @@ from megengine import Parameter, Tensor
 from megengine.core.tensor import megbrain_graph as G
 from megengine.jit.tracing import trace
 from megengine.quantization.quantize import quantize, quantize_qat
-from megengine.utils.naming import auto_naming
+from megengine.utils.naming import AutoNaming


 def _dump_and_load(func, symbolic, keep_opr_name=True):
-    auto_naming.clear()
+    AutoNaming.clear()
     func = trace(func, symbolic=symbolic, capture_as_const=True)
     x = Tensor(np.ones(shape=(2, 3)))
     func(x).numpy()
@@ -103,6 +103,18 @@ def test_without_module(symbolic):
     assert op.name == "MUL"


+@pytest.mark.parametrize("symbolic", [False, True])
+def test_ignore_top_module(symbolic):
+    class Simple(M.Module):
+        def forward(self, x):
+            return x + x
+
+    m = Simple()
+    op = _dump_and_load(m, symbolic)[-1]
+    assert op.name == "ADD"
+    assert op.outputs[0].name == "ADD"
+
+
 @pytest.mark.parametrize("symbolic", [False, True])
 def test_with_submodule(symbolic):
     class Simple(M.Module):
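
The new test pins down the behavior change: the top-level module's default _name is now None instead of "{anonymous}", and None scopes are filtered out of the dotted path, so a bare module no longer prefixes its ops. The arithmetic of the expectation:

    # With the old default the op would have been named roughly
    # "{anonymous}.ADD"; with _name = None the scope contributes nothing.
    scopes = [None]                 # pushed by the top-level Module.__call__
    scope = ".".join(s for s in scopes if s is not None)
    name = "ADD"
    assert (scope + "." + name if len(scope) else name) == "ADD"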
@@ -196,7 +208,7 @@ def test_not_keep_opr_name():
         return 2 * x

     op = _dump_and_load(f, True, False)[-1]
-    assert op.name == "MUL(x,2[2])[4]"
+    assert op.name == "MUL(x,const<2>[2])[4]"


 @pytest.mark.parametrize("symbolic", [False, True])
......
@@ -419,7 +419,7 @@ void ImmutableTensor::Value::setup(CompNode cn, const HostTensorND &val) {
     if (one_elem(val.shape())) {
         float v;
         static_cast_dtype(&v, val.dtype(), val.raw_ptr());
-        m_summary = ssprintf("%.3g", v);
+        m_summary = ssprintf("const<%.3g>", v);
         if (val.shape().ndim != 1) {
             m_summary += val.shape().to_string();
         }
......
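
This C++ change is the source of the updated expectation in test_not_keep_opr_name above: a one-element immutable tensor now renders its summary as const<value>, which flows into generated operator names.

    # Dumped op name for f(x) = 2 * x without keep_opr_name (from the test):
    # before: MUL(x,2[2])[4]
    # after:  MUL(x,const<2>[2])[4]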