提交 f88bd3ae 编写于 作者: M Megvii Engine Team

refactor(traced_module): let TracedModule own argdef_graph_map

GitOrigin-RevId: 80d685b9a395c7ee2d84742aad0b0f507efec7dd
上级 b1c46ba4
......@@ -9,6 +9,7 @@
import builtins
import collections
import copy
import inspect
from typing import Callable, List
......@@ -46,7 +47,7 @@ class Expr:
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))
def add_outputs(self, outputs, check_inplace=True):
def add_outputs(self, outputs):
self.outputs = []
if outputs is not None:
if not isinstance(outputs, collections.Sequence):
......@@ -54,10 +55,7 @@ class Expr:
for i in outputs:
assert isinstance(i, RawTensor)
node = NodeMixin.get(i, None) if check_inplace else None
self.outputs.append(
node if node else NodeMixin.get_wrapped_type(i)(self)
)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)
......@@ -165,9 +163,12 @@ class CallMethod(Expr):
def graph(self):
if isinstance(self.inputs[0], ModuleNode):
m_node = self.inputs[0]
if m_node.argdef_graph_map:
assert self.arg_def in m_node.argdef_graph_map
return m_node.argdef_graph_map[self.arg_def]
if (
hasattr(m_node.owner, "argdef_graph_map")
and m_node.owner.argdef_graph_map
):
assert self.arg_def in m_node.owner.argdef_graph_map
return m_node.owner.argdef_graph_map[self.arg_def]
return None
def interpret(self, *inputs):
......
......@@ -184,6 +184,9 @@ class Patcher:
if id(i) not in self.visited_frames_ids:
self.patch_function(i, j, self.wrap_fn)
for m in module_tracer._opaque_types:
self.auto_patch(getattr(getattr(m, "forward", m), "__globals__", {}))
def patch_function(self, frame_dict, fn, wrap_fn):
patched_fn = PatchedFn(frame_dict, fn)
self.patched_fn_ids.add(id(patched_fn.origin_fn))
......
......@@ -6,6 +6,8 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
import weakref
from typing import Any, Dict, List, Tuple, Type
import numpy
......@@ -58,15 +60,10 @@ class ModuleNode(Node):
"""
module_type = Module # type: Type[Module]
attr_type_map = None # type: Dict[str, Type[Any]]
argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
argdef_outdef_map = None # type: Dict[Treedef, Treedef]
_owner = None # type: weakref.ReferenceType
def __init__(self, expr: "Expr", name: str = None):
super().__init__(expr, name)
self.attr_type_map = {}
self.argdef_graph_map = {}
self.argdef_outdef_map = {}
def __repr__(self):
if self._name is None:
......@@ -74,6 +71,15 @@ class ModuleNode(Node):
else:
return "%{}({})".format(self._name, self.module_type.__name__)
def __getstate__(self):
    """Return the state used for pickling, without the ``_owner`` entry.

    ``_owner`` is declared as a ``weakref.ReferenceType``; weak references
    cannot be pickled, so the entry is left out of the serialized state.

    The state is built from a shallow copy of ``self.__dict__``. Popping
    from the live ``__dict__`` would silently remove ``_owner`` from this
    node as a side effect of serializing it.

    Returns:
        dict: a copy of the instance attributes with ``_owner`` removed.
    """
    state = self.__dict__.copy()
    state.pop("_owner", None)
    return state
@property
def owner(self):
    """Return the object this node belongs to, or ``None``.

    ``_owner`` is either ``None`` (the class-level default, i.e. no owner
    has been bound yet) or a weak reference to the owning object.

    Returns:
        The referent of ``_owner``; ``None`` when no owner is bound or the
        referent has already been garbage collected.
    """
    ref = self._owner
    if ref is None:
        # Calling None would raise TypeError; report "no owner" instead.
        return None
    return ref()
class TensorNode(Node):
"""
......@@ -90,9 +96,14 @@ class TensorNode(Node):
return "%{}(Tensor)".format(self._name)
class NodeMixin:
class NodeMixin(abc.ABC):
__node = None
@abc.abstractmethod
def _record_wrapped_nodes(self, node):
# Hook called when `node` is bound to this NodeMixin value (see
# NodeMixin.wrap), so that subclasses can record every node that has been
# bound to them. The base implementation does nothing.
pass
@classmethod
def wrap(cls, value, node):
if isinstance(value, (NodeMixin, RawTensor)):
......@@ -102,15 +113,20 @@ class NodeMixin:
node.shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)
setattr(value, "_NodeMixin__node", node)
else:
assert callable(node)
n = node()
assert isinstance(n, Node)
if isinstance(value, RawTensor):
n.dtype = value.dtype
n.shape = (
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(n)
setattr(value, "_NodeMixin__node", n)
@classmethod
......@@ -122,6 +138,8 @@ class NodeMixin:
value._tuple_shape if isinstance(value, Tensor) else value.shape
)
setattr(value, "_NodeMixin__node", node)
if isinstance(value, NodeMixin):
value._record_wrapped_nodes(node)
@classmethod
def get(cls, value, *default):
......
......@@ -9,6 +9,7 @@
import collections
import copy
import functools
import weakref
from inspect import getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Sequence, Type
......@@ -51,7 +52,9 @@ def _leaf_type(node):
def _is_leaf(node):
assert isinstance(node, RawTensor), type(node)
assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
type(node)
)
return isinstance(node, RawTensor)
......@@ -107,6 +110,32 @@ class InternalGraph:
def add_output(self, o):
self._outputs.append(o)
def _replace_inputs_outputs(self, repl_dict):
    """Replace input/output nodes of this graph with other nodes.

    Args:
        repl_dict: mapping from a node currently listed in ``self._inputs``
            or ``self._outputs`` to the node that takes its place.

    Raises:
        AssertionError: if a key of ``repl_dict`` is neither an input nor an
            output of this graph.

    The replacement is applied to ``self._inputs``, ``self._outputs`` and to
    the ``inputs``/``outputs`` lists of every expr in ``self._exprs``.
    """
    # The replacement node inherits the users of the node it replaces.
    for node, repl_node in repl_dict.items():
        assert node in self._inputs or node in self._outputs
        for user in node.users:
            if user not in repl_node.users:
                repl_node.users.append(user)

    for idx, inp in enumerate(self._inputs):
        if inp in repl_dict:
            self._inputs[idx] = repl_dict[inp]

    for idx, out in enumerate(self._outputs):
        if out in repl_dict:
            repl_node = repl_dict[out]
            # The replacement is now produced by the expr that produced the
            # replaced output. Read it from `out`, not from a variable left
            # over from the loop above, which would give every output the
            # expr of the last key in repl_dict.
            repl_node.expr = out.expr
            self._outputs[idx] = repl_node

    for expr in self._exprs:
        for idx, inp in enumerate(expr.inputs):
            if inp in repl_dict:
                expr.inputs[idx] = repl_dict[inp]
        for idx, out in enumerate(expr.outputs):
            if out in repl_dict:
                expr.outputs[idx] = repl_dict[out]
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence):
nodes = (nodes,)
......@@ -117,6 +146,7 @@ class InternalGraph:
expr = node.expr
if expr not in ret:
ret.append(expr)
for i in expr.inputs:
if i not in queue:
queue.append(i)
......@@ -287,10 +317,7 @@ def _wrapped_function(orig_func):
call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
if meth_name == "__new__":
call_node.add_outputs(outputs, False)
else:
call_node.add_outputs(outputs)
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
return orig_func(*args, **kwargs)
......@@ -303,12 +330,19 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
_body = None # type: InternalGraph
_is_builtin = None # type: bool
_argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
_argdef_outdef_map = None # type: Dict[Treedef, Treedef]
nodes = None
__builder_attributes__ = [
"_mod",
"_body",
"_NodeMixin__node",
"_is_builtin",
"build",
"_argdef_graph_map",
"_argdef_outdef_map",
"nodes",
]
def __init__(self, mod, is_top_module=False):
......@@ -316,23 +350,36 @@ class TracedModuleBuilder(NodeMixin):
self._mod = mod
self._body = None
self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
self.nodes = set()
def build(self):
if self._is_builtin:
node = NodeMixin.get(self)
node.module_type = type(self._mod)
for node in self.nodes:
node.module_type = type(self._mod)
# node._owner = weakref.ref(self._mod)
return self._mod
else:
node = NodeMixin.get(self)
traced_module = TracedModule(node)
traced_module = TracedModule(
self._argdef_graph_map, self._argdef_outdef_map
)
for _, g in self._argdef_graph_map.items():
g.compile()
# for node in self.nodes:
# node._owner = weakref.ref(traced_module)
for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
if isinstance(v, TracedModuleBuilder):
v = v.build()
setattr(traced_module, k, v)
traced_module.m_node.attr_type_map[k] = type(v)
return traced_module
def _record_wrapped_nodes(self, node):
# Remember every node that has been bound to this builder in the `nodes`
# set; build() later iterates over self.nodes to set module_type on each.
self.nodes.add(node)
def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
# prepare args and kwargs for inner graph
......@@ -360,19 +407,30 @@ class TracedModuleBuilder(NodeMixin):
if self._is_builtin:
self._body = None
else:
self_node = None
if self._body:
self_node = self._body.inputs[0]
self._body = InternalGraph()
active_module_tracer().push_scope(self._body)
# rebind self to new input node
orig_self = NodeMixin.get(self)
NodeMixin.wrap_safe(
self, Input.make("self", NodeMixin.get_wrapped_type(self))
)
if self_node:
NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope().add_input(self_node)
else:
NodeMixin.wrap_safe(
self,
self_node
if self_node
else Input.make("self", NodeMixin.get_wrapped_type(self)),
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph
def wrap(x):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
if isinstance(x, (RawTensor, NodeMixin)):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
return x
args = [self]
......@@ -397,9 +455,8 @@ class TracedModuleBuilder(NodeMixin):
# rebind output to outer graph
callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
self_node.argdef_outdef_map[callnode.arg_def] = out_def
self._argdef_graph_map[callnode.arg_def] = self._body
self._argdef_outdef_map[callnode.arg_def] = out_def
return rst
def __getattr__(self, name):
......@@ -424,8 +481,8 @@ class TracedModuleBuilder(NodeMixin):
else:
wrapped = super().__getattribute__(name)
if name in self._mod.__dict__:
if not NodeMixin.get(wrapped, None):
assert not self._is_builtin
assert not self._is_builtin
if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap(
wrapped,
lambda: GetAttr.make(
......@@ -434,14 +491,15 @@ class TracedModuleBuilder(NodeMixin):
type=NodeMixin.get_wrapped_type(wrapped),
),
)
"""
else:
node = NodeMixin.get(wrapped)
expr = GetAttr.make(
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(wrapped),
).expr
expr.outputs[0] = node
expr = node.expr
assert isinstance(expr, GetAttr)
if expr not in active_module_tracer().current_scope()._exprs:
active_module_tracer().current_scope().insert(expr)
"""
return wrapped
......@@ -514,33 +572,51 @@ class ExprFilterCallMethod(ExprFilter):
class TracedModule(Module):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
interpreted by CallMethod Expr.
`TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it.
"""
m_node = None # type: ModuleNode
# m_node = None # type: ModuleNode
argdef_graph_map = None
argdef_outdef_map = None
def __init__(self, node):
def __init__(self, argdef_graph_map, argdef_outdef_map):
super(TracedModule, self).__init__()
self.m_node = node
self.argdef_graph_map = argdef_graph_map
self.argdef_outdef_map = argdef_outdef_map
def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
)
assert treedef in self.m_node.argdef_graph_map
assert treedef in self.argdef_graph_map
inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace.
outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs)
out_def = self.m_node.argdef_outdef_map[treedef]
outputs = self.argdef_graph_map[treedef].interpret(*inputs)
out_def = self.argdef_outdef_map[treedef]
outputs = out_def.unflatten(outputs)
return outputs
@property
def graph(self):
assert len(self.m_node.argdef_graph_map) == 1
return list(self.m_node.argdef_graph_map.values())[0]
self._update_modulenode_ref()
assert len(self.argdef_graph_map) == 1
return list(self.argdef_graph_map.values())[0]
def _update_modulenode_ref(self):
    """Re-bind the ``_owner`` weak references of the module nodes in the graphs.

    For every graph in ``self.argdef_graph_map`` the first input node is
    pointed at ``self``. Each ``GetAttr`` expr whose first output is a
    ``ModuleNode`` is then resolved to the attribute it names, and that
    output node is pointed at the resolved object. Resolved objects that
    are ``TracedModule`` instances are updated recursively.
    """
    for graph in self.argdef_graph_map.values():
        self_node = graph._inputs[0]
        self_node._owner = weakref.ref(self)
        # Maps each module node seen so far to the object it stands for.
        node2obj = {self_node: self}
        for expr in graph._exprs:
            if not isinstance(expr, GetAttr):
                continue
            out_node = expr.outputs[0]
            if not isinstance(out_node, ModuleNode):
                continue
            obj = getattr(node2obj[expr.inputs[0]], expr.name)
            out_node._owner = weakref.ref(obj)
            node2obj[out_node] = obj
            if isinstance(obj, TracedModule):
                obj._update_modulenode_ref()
@property
def exprs(self):
......@@ -561,39 +637,49 @@ class TracedModule(Module):
const.outputs[0] = call.inputs[0]
const.outputs[0].expr = const
return [const, call]
if call is not None:
graph = copy.deepcopy(graph)
exprs = []
node2obj = {}
node2obj[graph._inputs[0]] = module
if call:
node2obj[call.inputs[0]] = module
for expr in graph._exprs:
# replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs:
inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
call.inputs[inp_idx].users.append(expr)
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]
call.outputs[oup_idx].expr = expr
# replace inputs and outputs for submodule's expr
if call:
repl_dict = dict(
zip(graph._inputs + graph._outputs, call.inputs + call.outputs)
)
graph._replace_inputs_outputs(repl_dict)
if isinstance(expr, GetAttr):
# replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode):
const = Constant(getattr(module, expr.name))
const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
const.outputs = expr.outputs
const.outputs[0].expr = const
exprs.append(const)
elif isinstance(expr.outputs[0], ModuleNode):
node2obj[expr.outputs[0]] = getattr(
node2obj[expr.inputs[0]], expr.name
)
elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
pre_expr = expr.inputs[0].expr
if isinstance(pre_expr, GetAttr):
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
(obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
expr_graph = (
obj.argdef_graph_map[expr.arg_def]
if hasattr(obj, "argdef_graph_map")
else None
)
exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
else:
# module has been replaced.
assert isinstance(pre_expr, Constant)
exprs.append(expr)
else:
exprs.append(expr)
else:
......
......@@ -9,6 +9,7 @@
import itertools
import numpy as np
import pytest
import megengine as mge
import megengine.autodiff as ad
......
......@@ -9,6 +9,7 @@
import itertools
import numpy as np
import pytest
import megengine as mge
import megengine.autodiff as ad
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册