Commit fb20cb36 authored by Megvii Engine Team

docs(mge/traced_module): update traced_module api doc

GitOrigin-RevId: 19a95d26c71e672376c5fda00a4e7dc6050e1c6a
Parent c7a8d945
......@@ -130,3 +130,4 @@ import megengine.optimizer
import megengine.quantization
import megengine.random
import megengine.utils
import megengine.traced_module
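For context, a minimal sketch of how the newly exported megengine.traced_module package is typically entered; the toy module and input below are illustrative only and not part of this commit:

import megengine
import megengine.functional as F
import megengine.module as M
from megengine.traced_module import trace_module

class Net(M.Module):
    def forward(self, x):
        return F.relu(x) + 1

# trace_module records the forward computation as a graph of Exprs and Nodes
tm = trace_module(Net(), megengine.Tensor([1.0, -1.0]))
graph = tm.graph  # the traced graph object referenced in the sketches below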
......@@ -33,15 +33,22 @@ def rstrip(s: str, __chars: str):
class Expr:
"""``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``."""
__total_id = 0
r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``.
"""
inputs = None # type: List[Node]
r"""The input Nodes of this Expr."""
outputs = None # type: List[Node]
r"""The output Nodes of this Expr."""
const_val = None # type: List[Any]
r"""The non-tensor object in the input of the operation."""
arg_def = None # type: TreeDef
r"""The :class:`TreeDef` used to reconstruct the input of the operation."""
out_def = None # type: TreeDef
r"""The :class:`TreeDef` used to reconstruct the output of the operation."""
_top_graph = None # type: weakref.ReferenceType
__total_id = 0
def __init__(self) -> None:
self._id = Expr.__total_id
......@@ -125,6 +132,11 @@ class Expr:
return inputs, {}
def replace_inputs(self, repl_dict: Dict[Node, Node]):
r"""Replace the input Nodes of this Expr.
Args:
repl_dict: the map {old_Node: new_Node} that specifies how to replace the input Nodes.
"""
while repl_dict:
node, repl_node = repl_dict.popitem()
assert type(node) == type(repl_node)
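A hypothetical sketch of replace_inputs; old_node and new_node are placeholder TensorNodes of the same type taken from one traced graph:

# rewire every Expr that consumes old_node to read new_node instead
for expr in list(old_node.users):
    expr.replace_inputs({old_node: new_node})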
......@@ -147,16 +159,19 @@ class Expr:
@property
def kwargs(self):
r"""Get the the keyword arguments of the operation corresponding to this Expr."""
_, kwargs = self.unflatten_args(self.inputs)
return kwargs
@property
def args(self):
r"""Get the the positional arguments of the operation corresponding to this Expr."""
args, _ = self.unflatten_args(self.inputs)
return args
@property
def top_graph(self):
r"""Get the parent graph of this Expr."""
if self._top_graph:
return self._top_graph()
return None
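A short sketch of inspecting Exprs through the properties documented above, assuming tm is the TracedModule from the earlier sketch and that its graph exposes an exprs() iterator:

for expr in tm.graph.exprs():            # iterate the recorded operations (exprs() is assumed here)
    print(expr, expr.args, expr.kwargs)  # positional/keyword arguments rebuilt from expr.inputs
    parent = expr.top_graph              # the graph that owns this Expr (None if it has been released)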
......@@ -168,17 +183,18 @@ class Expr:
return state
@classmethod
def get_total_id(cls):
def _get_next_id(cls):
return cls.__total_id
@classmethod
def set_total_id(cls, id: int = 0):
def _set_next_id(cls, id: int = 0):
assert isinstance(id, int)
cls.__total_id = id
# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
r"""A fake Expr which is used to mark the input of graph."""
name = None
def __init__(self, name=None, type=None, orig_name=None):
......@@ -204,13 +220,15 @@ class Input(Expr):
return expr.outputs[0]
def __repr__(self):
return "%{}:\t{} = Input({})".format(self._id, self.outputs[0], self.name)
return "%{}:\t{} = Input()".format(self._id, self.outputs[0])
# expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
name = None
r"""``Getattr`` represents the fetch of an attribute from the ``Module`` hierarchy."""
name = None
r"""name: the qualified name of the attribute to be retrieved."""
def __init__(self, module, name, type=None, orig_name=None):
super().__init__()
assert isinstance(module, ModuleNode)
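An illustrative way to pick out the GetAttr Exprs of a traced graph; the import path below is assumed from the file edited in this commit:

from megengine.traced_module.expr import GetAttr  # assumed import path

attr_exprs = [e for e in tm.graph.exprs() if isinstance(e, GetAttr)]
for e in attr_exprs:
    print(e.name, e.outputs[0])  # qualified attribute name and the Node it produces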
......@@ -251,6 +269,13 @@ class GetAttr(Expr):
# expr: outputs = inputs[0].__call__(*inputs[1:])
class CallMethod(Expr):
r"""``CallMethod`` represents a call to the ``__call__`` method of ``Module`` or a method of ``Tensor``.
Args:
node: the Node to be called.
method: the method name.
Default: "__call__"
"""
def __init__(self, node, method="__call__"):
super().__init__()
if isinstance(node, type):
......@@ -320,8 +345,12 @@ class CallMethod(Expr):
# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
opdef = None
r"""``Apply`` represents a call to :func:`apply`.
Args:
opdef: the applied :class:`OpDef`.
"""
opdef = None
def __init__(self, opdef):
super().__init__()
assert isinstance(opdef, OpDef)
......@@ -388,6 +417,11 @@ class Apply(Expr):
class CallFunction(Expr):
r"""``CallFunction`` represents a call to a built-in function.
Args:
func: a built-in function.
"""
def __init__(self, func):
super().__init__()
assert isinstance(func, Callable)
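Similarly, a sketch of listing the CallFunction Exprs and the callables they wrap; the func attribute is assumed to hold the recorded function:

from megengine.traced_module.expr import CallFunction  # assumed import path

func_exprs = [e for e in tm.graph.exprs() if isinstance(e, CallFunction)]
print([e.func for e in func_exprs])  # e.g. megengine.functional.relu from the earlier toy Net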
......@@ -425,7 +459,14 @@ class CallFunction(Expr):
# expr outputs = self.value
class Constant(Expr):
r"""``Constant`` represents a ``Tensor`` or "Module" which is not the attribute of a Module.
Args:
c: a const Tensor or Module.
name: the name of the output Node.
"""
value = None
r"""The const Tensor or Module"""
# TODO: constant cache to reduce the size of dumped model
_constant_cache = {}
......
......@@ -15,6 +15,8 @@ from ..quantization.utils import QParams, QuantMode, fake_quant_tensor
class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
r"""A module to do quant and dequant according to :attr:`~.FakeQuantize.qparams`."""
def __init__(
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
):
......@@ -35,9 +37,10 @@ class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
return self.qparams
def set_qparams(self, qparams: QParams):
r"""
r"""Initialize :attr:`~.FakeQuantize.qparams`.
Args:
qparams: used to set initial scale.
qparams: used to set initial ``scale`` and ``zero_point``.
"""
if qparams.scale is None:
raise AssertionError("Can not get an initialized scale")
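A hedged sketch of initializing FakeQuantize from an observer; the observer import path, its constructor arguments, and the need to feed it data before get_qparams() returns an initialized scale are assumptions, not verified against this commit:

from megengine import Tensor
from megengine.quantization.fake_quant import FakeQuantize  # assumed import path
from megengine.quantization.observer import MinMaxObserver  # assumed import path

obs = MinMaxObserver(dtype="qint8")
obs(Tensor([-1.0, 0.5, 2.0]))      # observe some data so that a scale is computed
fq = FakeQuantize("qint8")
fq.set_qparams(obs.get_qparams())  # copies scale (and zero_point, if any) into fq.qparams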
......
......@@ -11,29 +11,29 @@ from typing import Any, Dict, List, Tuple, Type
import numpy
from .. import get_logger
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module
from ..tensor import Tensor
logger = get_logger(__name__)
class Node:
r"""``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method.
They are inputs/outputs of Expr(the operations on variables).
Args:
expr: the Expr which produces the node
name: the name of the node
class Node:
r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
They are inputs/outputs of Expr (the operations on variables).
"""
expr = None
__total_id = 0
_id = None
expr = None # type: Expr
r"""The Expr which produces the Node."""
__total_id = 0 # type: int
_id = None # type: int
_top_graph = None # type: weakref.ReferenceType
_name = None
_orig_name = None
_format_spec = ""
_name = None # type: str
_orig_name = None # type: str
_format_spec = "" # type: str
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
def __init__(self, expr: "Expr", name: str, orig_name: str):
self.expr = expr
self.users = [] # List[Expr]
self._id = Node.__total_id
......@@ -73,32 +73,51 @@ class Node:
else:
return name if name else ("%d" % self._id)
@property
def name(self):
r"""Return the name of this Node."""
return self._name
@name.setter
def name(self, new_name: str):
graph = self.top_graph
assert graph is not None, "The parent graph of this Node cannot be None."
assert new_name not in graph._used_names, (
"The name(%s) is already in use. Please try a different one again."
% (new_name)
)
new_name = graph._create_unique_name(new_name)
self._name = new_name
self._orig_name = new_name
@property
def top_graph(self):
r"""Get the parent graph of this Node."""
if self._top_graph:
return self._top_graph()
return None
@classmethod
def set_format_spec(cls, str):
def _set_format_spec(cls, str):
old_format_spec = cls._format_spec
cls._format_spec = str
return old_format_spec
@classmethod
def get_total_id(cls):
def _get_next_id(cls):
return cls.__total_id
@classmethod
def set_total_id(cls, id: int = 0):
def _set_next_id(cls, id: int = 0):
assert isinstance(id, int)
cls.__total_id = id
class ModuleNode(Node):
r"""``ModuleNode`` represents the Module objects."""
module_type = Module # type: Type[Module]
r"""The type of the Module correspending to the ModuleNode."""
_owner = None # type: weakref.ReferenceType
def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
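A sketch of renaming a Node through the name setter documented above; graph.outputs exposing the result Nodes is an assumption:

out_node = tm.graph.outputs[0]  # assumed: the traced graph lists its result Nodes
print(out_node.name)            # auto-generated name
out_node.name = "logits"        # checked against graph._used_names and made unique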
......@@ -116,6 +135,11 @@ class ModuleNode(Node):
@property
def owner(self):
r"""Get the ``Module`` corresponding to this ``ModuleNode``.
Returns:
A :class:`~.Module`.
"""
if self._owner:
return self._owner()
return None
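A sketch of going from a recorded call back to the live Module through ModuleNode.owner; expr is assumed to be a CallMethod Expr whose first input is a ModuleNode:

module_node = expr.inputs[0]  # the ModuleNode on which __call__ was recorded
module = module_node.owner    # the live Module object, or None if it has been collected
print(module_node.module_type, type(module))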
......@@ -145,6 +169,7 @@ class TensorNode(Node):
@property
def shape(self):
r"""Get the shape of this Node."""
return self._shape
@shape.setter
......@@ -153,6 +178,7 @@ class TensorNode(Node):
@property
def dtype(self):
r"""Get the dtype of this Node."""
return self._dtype
@dtype.setter
......@@ -161,6 +187,7 @@ class TensorNode(Node):
@property
def device(self):
r"""Get the device of this Node pointed Tensor."""
return self._device
@device.setter
......@@ -169,6 +196,7 @@ class TensorNode(Node):
@property
def qparams(self):
r"""Get the :calss:`QParams` of this Node."""
return self._qparams
@qparams.setter
......@@ -177,10 +205,16 @@ class TensorNode(Node):
@property
def value(self):
r"""Get the bound Tensor of this Node."""
return self._value
@value.setter
def value(self, value):
r"""Bind a Tensor to this Node.
Args:
value: A :class:`Tensor`.
"""
if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
setattr(value, "_NodeMixin__node", None)
self._value = value
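A sketch of reading the metadata documented above from a TensorNode named node:

print(node.shape, node.dtype, node.device)  # static metadata recorded while tracing
if node.value is not None:                  # a Tensor may be bound to the Node (e.g. a constant)
    print(node.value.numpy())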
......
......@@ -150,6 +150,9 @@ def tree_flatten(
is_leaf: Callable = _is_leaf,
is_const_leaf: Callable = _is_const_leaf,
):
r"""Flattens a object into a list of values and a :calss:`TreeDef` that can be used
to reconstruct the object.
"""
if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values), values
node = LeafDef(leaf_type(values))
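A minimal round-trip sketch of tree_flatten and TreeDef.unflatten; the import path is assumed from the file edited in this commit:

from megengine import Tensor
from megengine.traced_module.pytree import tree_flatten  # assumed import path

values = {"x": Tensor([1.0, 2.0]), "y": Tensor([3.0])}
leaves, treedef = tree_flatten(values)  # flat list of Tensors plus the recorded structure
rebuilt = treedef.unflatten(leaves)     # inverse operation, rebuilds the original dict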
......@@ -169,6 +172,15 @@ def tree_flatten(
class TreeDef:
r"""A ``TreeDef`` represents the structure of a pytree.
Args:
type: the type of the root Node of the pytree.
aux_data: some const data that is useful in unflattening the pytree.
children_defs: ``TreeDef`` for each child of the root Node.
num_leaves: the number of leaves.
"""
def __init__(self, type, aux_data, children_defs):
self.type = type
self.aux_data = aux_data
......@@ -176,6 +188,9 @@ class TreeDef:
self.num_leaves = sum(ch.num_leaves for ch in children_defs)
def unflatten(self, leaves):
r"""Given a list of values and a ``TreeDef``, builds a object.
This is the inverse operation of ``tree_flatten``.
"""
assert len(leaves) == self.num_leaves
start = 0
children = []
......@@ -196,13 +211,10 @@ class TreeDef:
)
)
def __lt__(self, other):
return self.__hash__() < other.__hash__()
def __gt__(self, other):
return self.__hash__() > other.__hash__()
def __ne__(self, other) -> bool:
return not self.__eq__(other)
def __eq__(self, other):
def __eq__(self, other) -> bool:
return (
self.type == other.type
and self.aux_data == other.aux_data
......@@ -227,6 +239,9 @@ class LeafDef(TreeDef):
assert isinstance(leaves[0], self.type), self.type
return leaves[0]
def __ne__(self, other) -> bool:
return not self.__eq__(other)
def __eq__(self, other):
if isinstance(self.const_val, np.ndarray):
return self.type == other.type and (self.const_val == other.const_val).all()
......