提交 763c56f3 编写于 作者: M Megvii Engine Team

feat(imperative): add traced module

GitOrigin-RevId: 28c3503f2eaca979242c19c7d7495358daa0a8c4
上级 9279104b
......@@ -130,3 +130,4 @@ import megengine.optimizer
import megengine.quantization
import megengine.random
import megengine.utils
import megengine.experimental
......@@ -6,4 +6,5 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import traced_module
from .weight_scaler import get_scaled_model
......@@ -5,3 +5,15 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core._imperative_rt.core2 import set_cpp_apply_module_trace
from .traced_module import (
TracedModule,
_register_all_builtin_module,
cpp_apply_module_trace,
register_as_builtin,
trace_module,
)
# Import-time side effect: register the Module classes collected by
# ``_register_all_builtin_module`` as builtin (opaque) modules for the tracer.
_register_all_builtin_module()
# Import-time side effect: hand the Python-side trace hook to the C++ runtime.
# NOTE(review): presumably the runtime calls this hook for every op applied
# while module tracing is enabled -- confirm against core2.
set_cpp_apply_module_trace(cpp_apply_module_trace)
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import List
from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.special import Const
from ...tensor import Tensor
from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
class Expr:
    """
    Base class of all recorded operations on ``Node`` objects
    (``Call``, ``Apply``, ``GetAttr``, ``Input`` and ``Constant``).

    Subclasses fill ``inputs`` with the nodes they consume and ``outputs``
    with the nodes they produce.
    """

    # nodes consumed by the expression
    inputs = None  # type: List[Node]
    # nodes produced by the expression
    outputs = None  # type: List[Node]
# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
    """
    Placeholder expression that marks an input of a graph.

    It consumes no nodes and produces exactly one output node.

    param name: name given to the produced node
    param type: node class used for the output; falls back to ``Node``
    """

    name = None

    def __init__(self, name=None, type=None):
        # ``type`` shadows the builtin, but it is part of the keyword interface.
        self.inputs = []
        node_factory = type or Node
        self.outputs = [node_factory(self, name=name)]
        self.name = name

    @classmethod
    def make(cls, *args, **kwargs):
        """Create the expression, register its node as an input of the current scope and return the node."""
        created = cls(*args, **kwargs)
        scope = active_module_tracer().current_scope()
        scope.add_input(created.outputs[0])
        return created.outputs[0]

    def __repr__(self):
        produced = self.outputs[0]
        return "{} = Input({})".format(produced, self.name)
# expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
    """
    Expression equivalent to ``outputs[0] = getattr(inputs[0], name)``.

    param module: the ``ModuleNode`` whose attribute is read
    param name: name of the attribute
    param type: node class used for the output; falls back to ``Node``
    """

    name = None

    def __init__(self, module, name, type=None):
        # ``type`` shadows the builtin, but it is part of the keyword interface.
        assert isinstance(module, ModuleNode)
        self.inputs = [module]
        self.name = name
        node_factory = type or Node
        self.outputs = [node_factory(self)]

    @classmethod
    def make(cls, *args, **kwargs):
        """Create the expression, append it to the current scope and return its output node."""
        created = cls(*args, **kwargs)
        scope = active_module_tracer().current_scope()
        scope.insert(created)
        # the produced node is named after the attribute it reads
        created.outputs[0]._name = created.name
        return created.outputs[0]

    def interpret(self, *inputs):
        """Read the attribute from the first input and return it as a 1-tuple."""
        owner = inputs[0]
        return (getattr(owner, self.name),)

    def __repr__(self):
        template = '{} = GetAttr({}, "{}")'
        return template.format(self.outputs[0], self.inputs[0], self.name)
# expr: outputs = inputs[0].__call__(*inputs[1:])
class Call(Expr):
    """
    Expression equivalent to ``outputs = inputs[0].__call__(*inputs[1:])``.

    ``inputs[0]`` is always the ``ModuleNode`` being called; the argument
    nodes are appended afterwards with :meth:`add_input`.

    param module: the ``ModuleNode`` of the module being called
    """

    def __init__(self, module):
        assert isinstance(module, ModuleNode)
        self.inputs = [
            module,
        ]

    def add_input(self, node):
        """Append ``node`` to the arguments of the call."""
        self.inputs.append(node)

    def add_outputs(self, references):
        """
        Create one output node per value in ``references``.

        param references: a single value or a sequence of values returned by
            the call; the node class of each output is chosen from the value
            via ``NodeMixin.get_wrapped_type``
        """
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``. Imported locally so the block does not
        # depend on the submodule having been loaded elsewhere.
        from collections.abc import Sequence

        if not isinstance(references, Sequence):
            references = (references,)
        self.outputs = []
        for i in references:
            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

    @classmethod
    def make(cls, *args, **kwargs):
        """Create the expression, append it to the current scope and return the expression itself."""
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        return expr

    def interpret(self, *inputs):
        """
        Call ``inputs[0]`` with the remaining inputs.

        A single raw tensor result is wrapped into a 1-tuple; any other
        result is returned unchanged.
        """
        mod = inputs[0]
        args = inputs[1:]
        outputs = mod(*args)
        if isinstance(outputs, RawTensor):
            outputs = (outputs,)
        return outputs

    def __repr__(self):
        return "{} = Call({})({})".format(
            ", ".join(str(i) for i in self.outputs),
            self.inputs[0],
            ", ".join(str(i) for i in self.inputs[1:]),
        )
# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    """
    Expression equivalent to ``outputs = apply(self.opdef, *inputs)``.

    param opdef: the ``OpDef`` applied by this expression
    """

    opdef = None

    def __init__(self, opdef):
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    def add_input(self, node):
        """Append ``node`` to the operands of the op."""
        self.inputs.append(node)

    def add_outputs(self, references):
        """
        Create one output node per value in ``references``.

        param references: a single value or a sequence of values produced by
            the op; the node class of each output is chosen from the value
            via ``NodeMixin.get_wrapped_type``
        """
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives
        # in ``collections.abc``. Imported locally so the block does not
        # depend on the submodule having been loaded elsewhere.
        from collections.abc import Sequence

        if not isinstance(references, Sequence):
            references = (references,)
        self.outputs = []
        for i in references:
            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

    @classmethod
    def make(cls, *args, **kwargs):
        """Create the expression, append it to the current scope and return the expression itself."""
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        return expr

    def interpret(self, *inputs):
        """Apply the stored op to ``inputs`` and return whatever ``apply`` returns."""
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "{} = {}({})".format(
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        """
        Record the application of ``opdef`` to ``inputs`` and execute it.

        Inputs that carry no node yet are captured as ``Constant``
        expressions. The op is then executed with module tracing switched
        off, and the results are bound to the output nodes of the new
        ``Apply`` expression.

        return: the outputs of the op as a list
        """
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
        apply_node = cls.make(opdef)
        for i in inputs:
            apply_node.add_input(NodeMixin.get(i))

        # Tracing is switched off around the real execution and restored in
        # ``finally`` so that an exception raised by the op cannot leave
        # tracing disabled.
        unset_module_tracing()
        try:
            outputs = apply(opdef, *inputs)
        finally:
            set_module_tracing()

        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        return list(outputs)
# expr: outputs = self.value
class Constant(Expr):
    """
    Expression that yields a value captured during tracing.

    param c: the captured value; the class of the output node is chosen from
        it via ``NodeMixin.get_wrapped_type``
    """

    value = None
    # TODO: constant cache to reduce the size of dumped model
    _constant_cache = {}

    def __init__(self, c):
        # TODO: type check, since not all types should be captured as constant
        self.value = c
        self.inputs = []
        output_cls = NodeMixin.get_wrapped_type(c)
        self.outputs = [output_cls(self)]

    @classmethod
    def make(cls, *args, **kwargs):
        """Create the expression, append it to the current scope and return its output node."""
        created = cls(*args, **kwargs)
        scope = active_module_tracer().current_scope()
        scope.insert(created)
        return created.outputs[0]

    def interpret(self, *inputs):
        """Return the captured value; a raw tensor is re-created through ``Const`` from its numpy data."""
        captured = self.value
        if not isinstance(captured, RawTensor):
            return (captured,)
        return Const(captured.numpy())()

    def __repr__(self):
        produced = self.outputs[0]
        return "{} = Constant({})".format(produced, self.value)

    def __getstate__(self):
        """Return a copy of the instance dict in which a raw tensor value is replaced by ``Tensor(value)``."""
        state = dict(self.__dict__)
        if isinstance(self.value, RawTensor):
            state["value"] = Tensor(self.value)
        return state
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...module import Module
# The tracer currently in use; ``None`` while no tracer is set.
_active_module_tracer = None


def active_module_tracer():
    """Return the tracer stored by ``set_active_module_tracer`` (``None`` if unset)."""
    return _active_module_tracer


def set_active_module_tracer(tracer):
    """Store ``tracer`` as the module-wide active tracer; pass ``None`` to clear it."""
    global _active_module_tracer
    _active_module_tracer = tracer
class module_tracer:
    """
    Holds the tracing state: a class-wide registry of builtin module types
    and a per-instance stack of active scopes.
    """

    # Module classes registered as builtin, shared by all tracer instances
    _opaque_types = set()
    # stack of scopes; the last element is the current one
    _active_scopes = None

    def __init__(self):
        self._active_scopes = []

    @classmethod
    def register_as_builtin(cls, mod):
        """Add the Module subclass ``mod`` to the builtin registry and return it."""
        assert issubclass(mod, Module)
        cls._opaque_types.add(mod)
        return mod

    @classmethod
    def is_builtin(cls, mod):
        """Tell whether the exact type of ``mod`` was registered as builtin."""
        mod_type = type(mod)
        return mod_type in cls._opaque_types

    def push_scope(self, scope):
        """Make ``scope`` the current scope."""
        self._active_scopes.append(scope)

    def pop_scope(self):
        """Discard the current scope."""
        self._active_scopes.pop()

    def current_scope(self):
        """Return the innermost scope, or ``None`` when the stack is empty."""
        return self._active_scopes[-1] if self._active_scopes else None
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Any, Dict, Tuple, Type
import numpy
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module
from ...tensor import Tensor
class Node:
    """
    ``Node`` stands for a variable (Tensor/Module/other python object) used in
    a Module's forward method. Nodes are the inputs and outputs of ``Expr``.

    param expr: the Expr which produces the node
    param name: the name of the node
    """

    expr = None
    # counter shared by all nodes; the next id to hand out
    __total_id = 0
    _id = None
    _name = None

    def __init__(self, expr: "Expr", name: str = None):
        self.expr = expr
        # take the next id and advance the shared counter
        self._id = Node.__total_id
        Node.__total_id = self._id + 1
        self._name = name

    def __repr__(self):
        # nodes without a name are shown by their id
        label = self._id if self._name is None else self._name
        return "%{}".format(label)
class ModuleNode(Node):
    """
    ``ModuleNode`` stands for a Module object.

    Attributes:
        module_type: type of the Module corresponding to the ModuleNode
        graph: the InternalGraph which will be interpreted when the Module's forward method is called
        attr_type_map: records the type of the Module's attributes
    """

    module_type = Module  # type: Type[Module]
    graph = None
    attr_type_map = None  # type: Dict[str, Type[Any]]

    def __repr__(self):
        # nodes without a name are shown by their id
        label = self._id if self._name is None else self._name
        return "%{}({})".format(label, self.module_type.__name__)
class TensorNode(Node):
    """
    ``TensorNode`` stands for a Tensor object.
    """

    shape = None  # type: Tuple[int]
    dtype = None  # type: numpy.dtype

    def __repr__(self):
        # nodes without a name are shown by their id
        label = self._id if self._name is None else self._name
        return "%{}(Tensor)".format(label)
class NodeMixin:
    """
    Helpers that attach a ``Node`` to a value and read it back.

    The node is stored on the value under the attribute ``_NodeMixin__node``.
    """

    __node = None

    @classmethod
    def wrap(cls, value, node):
        """
        Attach a node to ``value`` if it is a ``NodeMixin`` or a raw tensor;
        any other value is left untouched.

        param node: either a ``Node`` or a callable that creates one; the
            callable is only invoked when ``value`` can carry a node
        """
        if not isinstance(value, (NodeMixin, RawTensor)):
            return
        if not isinstance(node, Node):
            assert callable(node)
            node = node()
        if isinstance(value, RawTensor):
            node.dtype = value.dtype
            if isinstance(value, Tensor):
                node.shape = value._tuple_shape
            else:
                node.shape = value.shape
        setattr(value, "_NodeMixin__node", node)

    @classmethod
    def wrap_safe(cls, value, node):
        """
        Attach ``node`` to ``value``, which must be a ``NodeMixin`` or a raw
        tensor. For tensors the dtype and shape are copied onto the node.
        """
        assert isinstance(value, (NodeMixin, RawTensor))
        if isinstance(value, RawTensor):
            node.dtype = value.dtype
            if isinstance(value, Tensor):
                node.shape = value._tuple_shape
            else:
                node.shape = value.shape
        setattr(value, "_NodeMixin__node", node)

    @classmethod
    def get(cls, value, *default):
        """Return the node attached to ``value``; ``default`` behaves as in ``getattr``."""
        return getattr(value, "_NodeMixin__node", *default)

    @classmethod
    def get_wrapped_type(cls, value):
        """Pick the node class matching ``value``: tensor, module-like, or plain ``Node``."""
        if isinstance(value, RawTensor):
            return TensorNode
        return ModuleNode if isinstance(value, (Module, NodeMixin)) else Node
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
from typing import List, Type
from ... import module as M
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing
from ...module import Module
from ...tensor import Tensor
from .expr import Apply, Call, Constant, Expr, GetAttr, Input
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
class InternalGraph:
    """
    ``InternalGraph`` is a graph made of ``Node`` and ``Expr`` objects; it
    represents the execution procedure of a Module's forward method.

    Attributes:
        _exprs: List of Exprs in order of execution
        _inputs: Input Nodes of InternalGraph
        _outputs: Output Nodes of InternalGraph
    """

    _exprs = None  # type: List[Expr]
    _inputs = None  # type: List[Node]
    _outputs = None  # type: List[Node]

    def __init__(self):
        self._exprs = []
        self._inputs = []
        self._outputs = []

    def insert(self, expr):
        """Append ``expr`` to the end of the execution order."""
        self._exprs.append(expr)

    def add_input(self, i):
        """Register ``i`` as the next input node."""
        self._inputs.append(i)

    def add_output(self, o):
        """Register ``o`` as the next output node."""
        self._outputs.append(o)

    def interpret(self, *inputs):
        """
        Execute the recorded expressions in order.

        param inputs: values bound positionally to the input nodes
        return: list with the values of the output nodes
        """
        # TODO: support kwargs ?
        # TODO: skip expressions which are independent and have no side effect
        env = dict(zip(self._inputs, inputs))
        for step in self._exprs:
            operands = [env[node] for node in step.inputs]
            results = step.interpret(*operands)
            env.update(zip(step.outputs, results))
        return [env[node] for node in self._outputs]

    def __repr__(self):
        header = ", ".join(str(node) for node in self._inputs)
        body = "\n\t".join(str(step) for step in self._exprs)
        footer = ", ".join(str(node) for node in self._outputs)
        return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
            header, body, footer
        )
class TracedModuleBuilder(NodeMixin):
    """
    Stand-in for a ``Module`` while it is being traced.

    The builder is passed as ``self`` to the forward method of the wrapped
    module. Attribute reads and calls made through it are recorded as
    ``GetAttr`` / ``Call`` expressions in the current scope, and
    :meth:`build` turns the recorded information into the final module.
    """

    # the wrapped module
    _mod = None  # type: Module
    # graph recording the body of the wrapped module's forward method
    _body = None  # type: InternalGraph
    # whether the type of the wrapped module is registered as builtin
    _is_builtin = None  # type: bool

    # Names that ``__getattribute__`` resolves without any tracing and that
    # ``build`` does not copy onto the resulting module.
    __builder_attributes__ = [
        "_mod",
        "_body",
        "_NodeMixin__node",
        "_is_builtin",
        "_is_traced",
        "build",
    ]

    def __init__(self, mod):
        """
        param mod: the module to wrap
        """
        super(TracedModuleBuilder, self).__init__()
        self._mod = mod
        self._body = InternalGraph()
        self._is_traced = False
        self._is_builtin = module_tracer.is_builtin(mod)

    def build(self):
        """
        Convert the builder into its final module.

        A builtin module is returned unchanged after its type has been stored
        on the builder's node. Otherwise a ``TracedModule`` is created around
        the node, and every instance attribute of the builder that is not
        listed in ``__builder_attributes__`` is copied onto it; nested
        builders are built recursively.
        """
        if self._is_builtin:
            node = NodeMixin.get(self)
            node.module_type = type(self._mod)
            return self._mod
        else:
            node = NodeMixin.get(self)
            node.graph = self._body
            node.attr_type_map = {}
            traced_module = TracedModule(node)
            for k, v in self.__dict__.items():
                if k not in TracedModuleBuilder.__builder_attributes__:
                    if isinstance(v, TracedModuleBuilder):
                        v = v.build()
                    setattr(traced_module, k, v)
                    # remember the type of every attribute that was copied
                    traced_module.m_node.attr_type_map[k] = type(v)
            return traced_module

    def __call__(self, *inputs, **kwargs):
        """
        Record a ``Call`` expression for this module in the current scope and
        produce its outputs.

        Builtin or already traced modules are executed directly with tracing
        switched off. Any other module has its forward method run with the
        builder as ``self`` inside a new scope (``self._body``), so that the
        operations it performs are recorded there.
        """
        assert isinstance(self._mod, Module)
        # prepare args and kwargs for inner graph
        def mark_constant(x):
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                # NodeMixin.wrap only calls the lambda for values that can
                # carry a node (NodeMixin instances or raw tensors)
                NodeMixin.wrap(x, lambda: Constant.make(x))

        for i in inputs:
            mark_constant(i)
        for k, v in kwargs.items():
            mark_constant(v)
        callnode = Call.make(NodeMixin.get(self))

        def add_input(x):
            callnode.add_input(NodeMixin.get(x))

        # keyword arguments are recorded positionally, after the positional ones
        for i in inputs:
            add_input(i)
        for k, v in kwargs.items():
            add_input(v)
        if self._is_builtin or self._is_traced:
            # NOTE(review): tracing is not re-enabled if the call raises --
            # consider try/finally.
            unset_module_tracing()
            outputs = self._mod(*inputs, **kwargs)
            set_module_tracing()
            if self._is_builtin:
                # a builtin module is opaque, so it has no inner graph
                self._body = None
        else:
            active_module_tracer().push_scope(self._body)
            # rebind self to new input node
            orig_self = NodeMixin.get(self)
            NodeMixin.wrap_safe(
                self, Input.make("self", NodeMixin.get_wrapped_type(self))
            )
            # prepare args and kwargs for inner graph
            def wrap(x):
                # shallow copy, so that the caller's value keeps its own node
                # while the copy is bound to an Input node of the inner graph
                wrapped = copy.copy(x)  # FIXME
                NodeMixin.wrap(
                    wrapped,
                    lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
                )
                return wrapped

            args = []
            for i in inputs:
                args.append(wrap(i))
            # only existing keys are reassigned, so the dict size does not
            # change during iteration
            for k, v in kwargs.items():
                kwargs[k] = wrap(v)
            # run the unbound forward with the builder standing in for ``self``
            outputs = type(self._mod).forward(self, *args, **kwargs)
            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                active_module_tracer().current_scope().add_output(NodeMixin.get(i))
            # restore the node the builder had in the outer graph
            NodeMixin.wrap_safe(self, orig_self)
            self._is_traced = True
            active_module_tracer().pop_scope()
        # rebind output to outer graph
        callnode.add_outputs(outputs)
        for i, node in zip(
            outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,),
            callnode.outputs,
        ):
            NodeMixin.wrap_safe(i, node)
        return outputs

    def __getattr__(self, name):
        """
        Resolve an attribute that is not yet present on the builder.

        Names missing from the wrapped module's instance dict are looked up
        on its class and bound to the builder. Instance attributes are read
        from the wrapped module (sub-modules are wrapped in a new builder),
        cached on the builder and bound to a ``GetAttr`` node.
        """
        if name not in self._mod.__dict__:
            # class-level attribute: bind it to the builder instead of the module
            attr = getattr(type(self._mod), name).__get__(self, type(self))
        else:
            attr = getattr(self._mod, name)
            if isinstance(attr, Module):
                attr = TracedModuleBuilder(attr)
            # cache on the builder, so later reads no longer reach __getattr__
            setattr(self, name, attr)
            NodeMixin.wrap(
                attr,
                lambda: GetAttr.make(
                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
                ),
            )
        return attr

    def __getattribute__(self, name):
        """
        Intercept every attribute read on the builder.

        Names in ``__builder_attributes__`` are returned as they are. For
        other names, if the attribute also exists in the wrapped module's
        instance dict and the value has no node attached yet, a ``GetAttr``
        expression is recorded for it.
        """
        if name in TracedModuleBuilder.__builder_attributes__:
            return super().__getattribute__(name)
        else:
            wrapped = super().__getattribute__(name)
            if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
                assert not self._is_builtin
                NodeMixin.wrap(
                    wrapped,
                    lambda: GetAttr.make(
                        NodeMixin.get(self),
                        name,
                        type=NodeMixin.get_wrapped_type(wrapped),
                    ),
                )
            return wrapped
class TracedModule(Module):
    """
    `TracedModule` is the Module created by tracing a normal module. It owns a
    ModuleNode (``m_node``) and interprets ``m_node.graph`` when it is called.

    param node: the ModuleNode describing the traced module
    """

    m_node = None  # type: ModuleNode

    def __init__(self, node):
        super(TracedModule, self).__init__()
        self.m_node = node

    def forward(self, *inputs):
        """
        Interpret the recorded graph, passing the module itself as the first
        graph input followed by ``inputs``.

        return: the single output if the graph has exactly one output,
            otherwise the list of outputs
        """
        rst = self.m_node.graph.interpret(self, *inputs)
        if len(rst) == 1:
            rst = rst[0]
        return rst

    def __getstate__(self):
        """
        Return the state to pickle: the instance dict without the keys that
        are also names in ``Module.__dict__``.
        """
        # Work on a copy: popping from ``self.__dict__`` itself would strip
        # attributes from the live instance every time it is pickled.
        d = self.__dict__.copy()
        for k in Module.__dict__:
            d.pop(k, None)
        return d
def cpp_apply_module_trace(opdef, *args):
    """
    Module-level entry point that forwards to ``Apply.apply_module_trace_hook``.

    param opdef: the op being applied
    param args: the operands of the op
    return: whatever ``Apply.apply_module_trace_hook`` returns
    """
    traced_outputs = Apply.apply_module_trace_hook(opdef, *args)
    return traced_outputs
def register_as_builtin(mod_cls: Type[Module]) -> Type[Module]:
    """
    Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.

    param mod_cls: the Module class which will be treated as builtin module in tracing
    return: ``mod_cls`` itself, so the function can also be used as a class decorator
    """
    module_tracer.register_as_builtin(mod_cls)
    # Returning the class keeps ``@register_as_builtin`` from rebinding the
    # decorated class to None.
    return mod_cls
def _register_all_builtin_module():
    """
    Register every Module subclass found in ``M``, ``M.qat`` and
    ``M.quantized`` as builtin, with the exception of ``M.Sequential``.
    """
    from inspect import getmembers, isclass

    for namespace in [M, M.qat, M.quantized]:
        for _, member in getmembers(namespace):
            if not isclass(member):
                continue
            if not issubclass(member, M.Module):
                continue
            if member is M.Sequential:
                continue
            module_tracer.register_as_builtin(member)
def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
    """
    Traces module ``mod`` and returns the corresponding TracedModule.

    param mod: the module that will be converted to a TracedModule
    param inputs: the positional arguments passed to the forward method of ``mod``
    param kwargs: the keyword arguments passed to the forward method of ``mod``
    """
    assert active_module_tracer() is None
    try:
        set_module_tracing()
        set_active_module_tracer(module_tracer())

        # outermost scope that receives the top-level inputs and the call
        global_scope = InternalGraph()
        active_module_tracer().push_scope(global_scope)

        builder = TracedModuleBuilder(mod)
        top_node = Input.make("TopModule", ModuleNode)
        NodeMixin.wrap_safe(builder, top_node)

        # bind every argument to an Input node of the outermost scope
        for position, arg in enumerate(inputs):
            NodeMixin.wrap_safe(arg, Input.make("arg_{}".format(position)))
        for key, value in kwargs.items():
            NodeMixin.wrap_safe(value, Input.make("kwarg_{}".format(key)))

        builder(*inputs, **kwargs)
        active_module_tracer().pop_scope()
        return builder.build()
    finally:
        # always leave tracing switched off and the tracer cleared
        set_active_module_tracer(None)
        unset_module_tracing()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册