Commit 813be72c authored by Megvii Engine Team

feat(mge/utils): refactor GraphInference and add more options

GitOrigin-RevId: 44b96dbf3dbad8abad7900b2b4e0e82bf1c8314d
Parent d970b85d
@@ -9,10 +9,9 @@
import collections
import json
import os
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List, Union
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple, Union
import numpy as np
@@ -22,7 +21,7 @@ from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
from ..ops.builtin import OpDef
from .core import OpBase, TensorBase
from .core import TensorBase
def set_priority_to_id(dest_vars):
@@ -284,9 +283,9 @@ def optimize_for_inference(dest_vars, **kwargs):
if kwargs:
raise ValueError("unknown options: %s" % list(kwargs))
dest_vars = [var._node for var in dest_vars]
dest_vars = _unwrap(dest_vars)
res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options)
return [VarNode(i) for i in res_vars]
return _wrap(res_vars)
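A minimal usage sketch of the rewrapped API (not part of this diff); ``net_outputs`` is an assumed list of output vars of a loaded or traced graph:

# Hypothetical sketch: optimize graph endpoints before dumping; keyword
# options are omitted since they depend on the target configuration.
optimized = optimize_for_inference(net_outputs)  # returns wrapped VarNodes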
CompGraphDumpResult = collections.namedtuple(
@@ -312,7 +311,7 @@ def dump_graph(
keep_opr_priority: bool = False,
strip_info_file=None,
append_json=False
):
) -> Tuple[bytes, CompGraphDumpResult]:
"""
Serializes the computing graph of ``output_vars`` and returns the result as bytes.
@@ -347,22 +346,20 @@ def dump_graph(
* ``params`` list of names of dumped params
* ``outputs`` names of output vars
"""
ov = []
if isinstance(output_vars, dict):
used_vars = set()
for name, var in output_vars.items():
assert isinstance(var, VarNode), "bad output var: {!r}".format(var)
assert var.id not in used_vars, (
"var name is associated with a var object, so we can not have "
"two names given to the same var: {}".format(var)
)
used_vars.add(var.id)
var.name = name
ov.append(var._node)
output_vars = list(output_vars.values())
else:
for var in output_vars:
assert isinstance(var, VarNode), "bad output var: {!r}".format(var)
ov.append(var._node)
output_vars = list(output_vars)
ov = _unwrap(output_vars)
stat = []
inputs = []
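A hedged round-trip sketch of ``dump_graph``/``load_graph`` with the return types annotated above; the file path is an assumption:

# Sketch: serialize the graph, persist it, then load it back.
content, dump_info = dump_graph(output_vars)        # Tuple[bytes, CompGraphDumpResult]
with open("model.mge", "wb") as f:                  # "model.mge" is an assumed path
    f.write(content)
graph, _, output_nodes = load_graph("model.mge")    # unpacks CompGraphLoadResult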
@@ -413,7 +410,7 @@ CompGraphLoadResult = collections.namedtuple(
)
def load_graph(fpath):
def load_graph(fpath) -> CompGraphLoadResult:
"""
Load a serialized computing graph from file.
@@ -471,8 +468,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
graph._make_const_for_backward,
args,
)
outputs = [o._node if hasattr(o, "_node") else o for o in outputs]
return outputs
return _unwrap(outputs)
set_cpp_apply_backward_varnode(apply_backward_varnode)
@@ -7,12 +7,14 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from collections import OrderedDict
from typing import Dict, List, Optional
from typing import Dict, List, Tuple, Union
import numpy
import numpy as np
from ..core import _imperative_rt
from ..core._imperative_rt import OperatorNode, VarNode
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt import OperatorNode as _OpNode
from ..core._imperative_rt import VarNode as _VarNode
from ..core.tensor import megbrain_graph as G
from ..core.tensor.megbrain_graph import set_priority_to_id
from ..tensor import Tensor
@@ -31,7 +33,9 @@ __all__ = [
]
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
def get_dep_vars(
var: Union[_VarNode, List[_VarNode]], var_type: Union[str, List[str]] = None
) -> List[_VarNode]:
"""
Returns the :class:`.tensor.core.megbrain_graph.VarNode` list of type ``var_type`` that the
input ``var`` depends on. If ``var_type`` is None, returns vars of all types.
@@ -39,7 +43,7 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
outputs = []
memo = set()
if isinstance(var, VarNode):
if isinstance(var, _VarNode):
var = [var]
if isinstance(var_type, str):
@@ -61,14 +65,14 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
return outputs
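A short usage sketch of the widened signature, using only operator type names that appear elsewhere in this file:

# Sketch: `output_nodes` is assumed to come from load_graph.
data_inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")  # graph inputs only
all_deps = get_dep_vars(output_nodes)                        # every dependency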
def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
def get_owner_opr_inputs(var: _VarNode) -> List[_VarNode]:
"""
Gets the inputs of owner opr of a variable.
"""
return var.owner.inputs
def get_owner_opr_type(var: VarNode) -> str:
def get_owner_opr_type(var: _VarNode) -> str:
"""
Gets the type of owner opr of a variable.
@@ -76,15 +80,15 @@ def get_owner_opr_type(var: VarNode) -> str:
return var.owner.type
def get_opr_type(opr: OperatorNode) -> str:
def get_opr_type(opr: _OpNode) -> str:
"""
Gets the type of an opr.
"""
assert isinstance(opr, OperatorNode)
assert isinstance(opr, _OpNode)
return opr.type
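A combined sketch of the three small accessors above, on an assumed VarNode ``v``:

producer_inputs = get_owner_opr_inputs(v)  # vars consumed by the opr that produced v
print(get_owner_opr_type(v))               # type string of that opr
print(get_opr_type(v.owner))               # the same string, via the opr object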
def graph_traversal(outputs: VarNode):
def graph_traversal(outputs: _VarNode):
"""
Helper function to traverse the computing graph and return enough useful information.
@@ -142,8 +146,8 @@ def graph_traversal(outputs: VarNode):
def get_oprs_seq(
outputs: List[VarNode], prune_reshape=False, prune_immtensor=True
) -> List[OperatorNode]:
outputs: List[_VarNode], prune_reshape=False, prune_immtensor=True
) -> List[_OpNode]:
"""
Gets oprs in some topological order for a dumped model.
@@ -218,7 +222,9 @@ def get_oprs_seq(
return oprs_seq
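A usage sketch of the topological walk; ``output_nodes`` is again assumed to come from ``load_graph``:

# Sketch: list oprs in topological order, keeping reshapes,
# pruning immutable-tensor oprs.
for opr in get_oprs_seq(output_nodes, prune_reshape=False, prune_immtensor=True):
    print(opr.type, [v.name for v in opr.inputs])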
def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
def replace_vars(
dst: List[_VarNode], varmap: Dict[_VarNode, _VarNode]
) -> List[_VarNode]:
"""
Replaces vars in the graph.
@@ -232,21 +238,19 @@ def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
repl_src_vec = []
repl_dst_vec = []
for i in dst:
assert isinstance(i, VarNode)
assert isinstance(i, _VarNode)
dst_vec.append(i)
for i, j in getattr(varmap, "items", lambda: varmap)():
assert isinstance(i, VarNode)
assert isinstance(j, VarNode)
assert isinstance(i, _VarNode)
assert isinstance(j, _VarNode)
repl_src_vec.append(i)
repl_dst_vec.append(j)
return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
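A minimal sketch; ``old_var`` and ``new_var`` are assumed vars living in the same graph as ``dst``:

# Sketch: substitute one var for another everywhere `dst` depends on it.
new_dst = replace_vars(dst, {old_var: new_var})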
def replace_oprs(
dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
) -> List[VarNode]:
def replace_oprs(dst: List[_VarNode], oprmap: Dict[_OpNode, _OpNode]) -> List[_VarNode]:
"""
Replaces operators in the graph.
@@ -260,65 +264,154 @@ def replace_oprs(
repl_src_vec = []
repl_dst_vec = []
for i in dst:
assert isinstance(i, VarNode)
assert isinstance(i, _VarNode)
dst_vec.append(i)
for i, j in getattr(oprmap, "items", lambda: oprmap)():
assert isinstance(i, OperatorNode)
assert isinstance(j, OperatorNode)
assert isinstance(i, _OpNode)
assert isinstance(j, _OpNode)
repl_src_vec.append(i)
repl_dst_vec.append(j)
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
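The analogous sketch at operator level; ``old_opr`` and ``new_opr`` are assumed OperatorNode objects with matching outputs:

new_dst = replace_oprs(dst, {old_opr: new_opr})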
def find_vars_by_name(dst: List[_VarNode], names: List[str]) -> List[_VarNode]:
"""
Gets a list of VarNodes by name from the graph.
:param dst: target vars representing the graph.
:param names: names of the target VarNodes.
:return: the VarNodes found, in the same order as ``names``.
"""
output_names = names.copy()
all_vars = get_dep_vars(dst) + dst
# use a dict to keep the output order the same as `names`.
output_dict = {}
for i in all_vars:
if i.name in output_names:
output_dict[i.name] = i
output_names.remove(i.name)
assert len(output_names) == 0, "Cannot find varnode {} in this model".format(
output_names
)
return [output_dict[i] for i in names]
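A sketch of pulling intermediate endpoints out by name; the names are illustrative and must exist in the model:

feature_vars = find_vars_by_name(output_nodes, ["conv1", "fc"])  # assumed names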
def convert_inputs(
dst: List[_VarNode], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
"""
Replaces ``Host2DeviceCopy`` oprs with :class:`~.InputNode` in the graph, so that
input values can be fed via :meth:`~.InputNode.set_value` before running.
:param dst: target vars representing the graph.
:param inputs: indicates which inputs are to be replaced. All
inputs (``Host2DeviceCopy``) will be replaced if not specified.
:return: new vars that correspond to ``dst`` with all inputs
replaced, and new inputs dict.
"""
if inputs is None:
inputs = get_dep_vars(dst, "Host2DeviceCopy")
input_dict = OrderedDict()
replace_dict = {}
for inp in inputs:
inp_node = G.InputNode(
device=inp.comp_node, dtype=inp.dtype, shape=inp.shape, graph=inp.graph,
)
inp_node.name = inp.name
input_dict[inp.name] = inp_node
replace_dict[inp] = inp_node.outputs[0]
new_output_nodes = replace_vars(dst, replace_dict)
for old, new in zip(dst, new_output_nodes):
new.name = old.name
return new_output_nodes, input_dict
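A sketch of making a loaded graph feedable, mirroring what ``GraphInference.__init__`` below does; the input name "data" and the shape are assumptions:

new_outputs, inp_dict = convert_inputs(output_nodes)
inp_dict["data"].set_value(Tensor(np.zeros((1, 3, 224, 224), "float32"))._dev_tensor())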
def convert_outputs(dst: List[_VarNode]) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
"""
Wraps ``dst`` with :class:`~.OutputNode` in the graph so that outputs can be
fetched with :meth:`~.OutputNode.get_value`.
:param dst: target vars representing the graph.
:return: new vars that correspond to ``dst`` with all outputs
wrapped, and the outputs dict.
"""
output_dict = OrderedDict([(i.name, G.OutputNode(i)) for i in dst])
new_output_nodes = [i.outputs[0] for i in output_dict.values()]
return new_output_nodes, output_dict
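A sketch pairing ``convert_inputs`` with ``convert_outputs`` to compile and read results, matching the flow ``GraphInference`` uses below:

new_outputs, oup_dict = convert_outputs(new_outputs)
func = new_outputs[0].graph.compile(new_outputs)
func.execute()
func.wait()
results = {name: node.get_value().numpy() for name, node in oup_dict.items()}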
def embed_inputs(
dst: List[_VarNode], data: List[np.ndarray], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
"""
Embeds ``data`` into the graph's inputs of ``dst``.
:param dst: target vars representing the graph.
:param data: data to be embedded.
:param inputs: indicates which inputs are to be replaced. All
inputs (``Host2DeviceCopy``) will be replaced if not specified.
:return: new vars that correspond to ``dst`` with all inputs
replaced, and new inputs dict.
"""
if inputs is None:
inputs = get_dep_vars(dst, "Host2DeviceCopy")
assert len(data) == len(inputs)
input_dict = OrderedDict()
replace_dict = {}
for inp, d in zip(inputs, data):
new_inp = _imperative_rt.make_shared(inp.graph, Tensor(d)._dev_tensor())
new_inp.name = inp.name
input_dict[inp.name] = new_inp
replace_dict[inp] = new_inp
new_output_nodes = replace_vars(dst, replace_dict)
for old, new in zip(dst, new_output_nodes):
new.name = old.name
return new_output_nodes, input_dict
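A sketch of baking fixed data into the graph, e.g. to dump a model with constant inputs; shape and dtype are assumptions:

fixed = np.ones((1, 3, 32, 32), dtype="float32")
new_outputs, _ = embed_inputs(output_nodes, [fixed])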
class GraphInference:
"""
Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph.
The `GraphInference.run()` accepts a list `inp_args` or a dict `inp_dict` {input_name: input_value} as input and returns a dict {output_name: output_value}.
Loads a serialized computing graph as a GraphInference object which can be used
to execute the computing graph.
:param file: could be file object or filename.
:param outputs: only compile the subgraph with outputs as its endpoints.
"""
def __init__(self, file, outputs: Optional[List[str]] = None):
*_, output_nodes = G.load_graph(file)
def __init__(
self,
file,
outputs: List[str] = None,
profiling: bool = False,
optimize_for_inference: bool = False,
**kwargs
):
self._graph, _, output_nodes = G.load_graph(file)
if outputs is not None:
output_name = outputs.copy()
all_vars = get_dep_vars(output_nodes) + output_nodes
new_outputs = {}
for i in all_vars:
if i.name in output_name:
new_outputs[i.name] = i
output_name.remove(i.name)
assert (
len(output_name) == 0
), "Can not find varnode {} in this model".format(output_name)
output_nodes = [new_outputs[i] for i in outputs]
inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
self._inp_dict = OrderedDict()
replace_dict = {}
for idx, i in enumerate(inputs):
inp_node = G.InputNode(
device="xpux", dtype=inputs[idx].dtype, graph=inputs[0].graph
)
self._inp_dict[i.name] = inp_node
replace_dict[i] = inp_node.outputs[0]
new_output_nodes = replace_vars(output_nodes, replace_dict)
for old, new in zip(output_nodes, new_output_nodes):
new.name = old.name
self._out_dict = OrderedDict(
[(i.name, G.OutputNode(i)) for i in new_output_nodes]
)
new_out_list = [i.outputs[0] for i in self._out_dict.values()]
cg = new_out_list[0].graph
self._func = cg.compile(new_out_list)
output_nodes = find_vars_by_name(output_nodes, outputs)
self._origin_outputs = output_nodes
# replace inputs with `InputNode`
output_nodes, self._inp_dict = convert_inputs(output_nodes)
# replace outputs with `OutputNode`
output_nodes, self._oup_dict = convert_outputs(output_nodes)
self._func = self._graph.compile(output_nodes)
def run(
self,
*inp_args: numpy.ndarray,
inp_dict: Optional[Dict[str, numpy.ndarray]] = None
):
self, *inp_args: np.ndarray, inp_dict: Dict[str, np.ndarray] = None
) -> Dict[str, np.ndarray]:
"""
:param inp_args: positional input data.
:param inp_dict: dict of named input data.
:return: a dict {output_name: output_value}.
"""
assert len(inp_args) <= len(
self._inp_dict
), "This model expects {} inputs".format(len(self._inp_dict))
@@ -335,8 +428,11 @@ class GraphInference:
)
for key in self._inp_dict:
self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
self._func.execute()
self._func.wait()
result = OrderedDict()
for key in self._out_dict:
result[key] = self._out_dict[key].get_value().numpy()
for key in self._oup_dict:
result[key] = self._oup_dict[key].get_value().numpy()
return result
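An end-to-end usage sketch of the refactored class; the file name, input name, and shapes are assumptions:

inference = GraphInference("model.mge")                  # assumed path
out = inference.run(np.random.rand(1, 3, 224, 224).astype("float32"))
# or by name (the key "data" is an assumption):
out = inference.run(inp_dict={"data": np.zeros((1, 3, 224, 224), "float32")})
for name, value in out.items():
    print(name, value.shape)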