提交 ac11c38a 编写于 作者: M Megvii Engine Team

feat(mge/imperative): add graph load and cgtools for imperative

GitOrigin-RevId: ba251f452ae8c6cc9c3dae99d1be92711cbeff5e
上级 76f36796
......@@ -76,6 +76,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save
from .tensor import Parameter, Tensor, tensor
from .version import __version__
from .core import cgtools
_set_fork_exec_path_for_timed_func(
sys.executable,
......
......@@ -10,3 +10,5 @@ import os
import sys
from .tensor import Tensor
from .tensor.megbrain_graph import Graph
from .utils import comp_graph_tools as cgtools
......@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import json
import threading
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
......@@ -162,14 +163,42 @@ def optimize_for_inference(dest_vars, **kwargs):
return [VarNode(i) for i in res_vars]
def dump(*args):
def dump_graph(*args):
return _imperative_rt.dump_graph([i._node for i in args])
CompGraphLoadResult = collections.namedtuple(
"CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
)
def load_graph(fpath):
"""Load a serialized computing graph from file.
:parma fpath: Path or Handle for the output file
:return: An instance of namedtuple :class:`CompGraphLoadResult`,
whose fields are:
* ``graph`` loaded CompGraph
* ``output_vars_dict`` A Python dict, mapping name to output SymbolVar
* ``output_vars_list`` A Python list, containing output vars in the
order passed to serialize_comp_graph_to_file
"""
output_vars_map = []
output_vars_list = []
if isinstance(fpath, str):
buf = open(fpath, "rb").read()
else:
buf = fpath.read()
cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
class VarNode(TensorBase):
def __init__(self, node: _imperative_rt.VarNode):
self._node = node
self.graph._var_cache[node] = self
if hasattr(self.graph, "_var_cache"):
self.graph._var_cache[node] = self
@property
def graph(self) -> Graph:
......@@ -177,12 +206,19 @@ class VarNode(TensorBase):
@property
def op(self):
return self.graph._wrap(self._node.owner)
if hasattr(self.graph, "_wrap"):
return self.graph._wrap(self._node.owner)
else:
return self._node.owner
@property
def name(self):
return self._node.name
@property
def id(self):
return self._node.id
@name.setter
def name(self, name):
self._node.name = name
......@@ -207,7 +243,8 @@ class VarNode(TensorBase):
class OpNode:
def __init__(self, node: _imperative_rt.OperatorNode):
self._node = node
self.graph._op_cache[node] = self
if hasattr(self.graph, "_op_cache"):
self.graph._op_cache[node] = self
@property
def graph(self) -> Graph:
......@@ -217,29 +254,53 @@ class OpNode:
def name(self):
return self._node.name
@property
def id(self):
return self._node.id
@name.setter
def name(self, name):
self._node.name = name
@property
def inputs(self):
return tuple(map(self.graph._wrap, self._node.inputs))
if hasattr(self.graph, "_wrap"):
return tuple(map(self.graph._wrap, self._node.inputs))
else:
return self._node.inputs
@property
def outputs(self):
return tuple(map(self.graph._wrap, self._node.outputs))
if hasattr(self.graph, "_wrap"):
return tuple(map(self.graph._wrap, self._node.outputs))
else:
return self._node.outputs
@property
def params(self):
return json.loads(self._node.params)
@property
def type(self):
return self._node.type
def _wrap(x):
if isinstance(x, collections.abc.Sequence):
return type(x)(map(_wrap, x))
return x.graph._wrap(x)
if hasattr(x.graph, "_wrap"):
return x.graph._wrap(x)
else:
return x
def _unwrap(x):
if isinstance(x, collections.abc.Sequence):
return type(x)(map(_unwrap, x))
return x._node
if isinstance(x, VarNode):
return x._node
else:
return x
@apply.register()
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Dict, List
from .. import _imperative_rt
from .._imperative_rt import OperatorNode, VarNode
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
"""return :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
depands on. If ``var_type`` is None, return all types.
"""
outputs = []
memo = set()
if isinstance(var, VarNode):
var = [var]
if isinstance(var_type, str):
var_type = [var_type]
q = list(var)
while q:
v = q.pop()
if v in memo:
continue
memo.add(v)
q.extend(get_owner_opr_inputs(v))
if var_type is not None:
if get_owner_opr_type(v) in var_type:
outputs.append(v)
else:
outputs.append(v)
return outputs
def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
"""get the inputs of owner opr of a variable
"""
assert isinstance(var, VarNode)
return var.owner.inputs
def get_owner_opr_type(var: VarNode) -> str:
"""get the type of owner opr of a variable
"""
assert isinstance(var, VarNode)
return var.owner.type
def get_opr_type(opr: OperatorNode) -> str:
"""get the type of a opr
"""
assert isinstance(opr, OperatorNode)
return opr.type
def graph_traversal(outputs: VarNode):
"""helper function to traverse the computing graph and return enough useful information
:param outputs: model outputs
:return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
WHERE
map_oprs is dict from opr_id to actual opr
map_vars is dict from var_id to actual var
var2oprs is dict from var to dest oprs along with index
opr2receivers is dict from current opr to next opr
indegree2opr is dict from in_degree to opr in computing graph
opr2indegree is dict from opr in computing graph to in_degree
(indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
"""
# meta information for comp graph
map_oprs = collections.defaultdict(set)
map_vars = collections.defaultdict(set)
var2oprs = collections.defaultdict(list)
opr2receivers = collections.defaultdict(list)
queue = list(map(lambda x: x.owner, outputs))
visited = set(map(lambda x: x.id, queue))
# iterate through whole comp_graph, fill in meta information
indegree2opr = collections.defaultdict(set)
opr2indegree = {}
idx = 0
while idx < len(queue):
cur_opr = queue[idx]
map_oprs[cur_opr.id] = cur_opr
idx += 1
indegree = 0
for var_idx, var in enumerate(cur_opr.inputs):
map_vars[var.id] = var
var2oprs[var.id].append((cur_opr.id, var_idx))
pre_opr = var.owner
if pre_opr.id not in visited:
visited.add(pre_opr.id)
queue.append(pre_opr)
indegree += 1
opr2receivers[pre_opr.id].append(cur_opr.id)
indegree2opr[indegree].add(cur_opr.id)
opr2indegree[cur_opr.id] = indegree
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
"""get oprs in some topological order for a dumped model
:param outputs: model outputs
:param prune_reshape: whether to prune the operators useless during inference
:return: opr list with some correct execution order
"""
def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
# generate an execution order with topological sort algorithm
oprs_seq = []
nr_remain = len(map_oprs)
while indegree2opr[0]:
opr_id = indegree2opr[0].pop()
opr = map_oprs[opr_id]
nr_remain -= 1
# skip const value generation operator
if get_opr_type(opr) != "ImmutableTensor":
oprs_seq.append(opr)
for post_id in opr2receivers[opr_id]:
indegree = opr2indegree[post_id]
indegree2opr[indegree].remove(post_id)
indegree -= 1
indegree2opr[indegree].add(post_id)
opr2indegree[post_id] = indegree
assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
nr_remain
)
return oprs_seq
# reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
# when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
useless = True
for oup in cur_opr.outputs:
if "workspace" not in oup.name:
var_idx = post_opr.inputs.index(oup)
var2oprs[oup.id].remove((post_opr.id, var_idx))
useless = useless and (len(var2oprs[oup.id]) == 0)
if useless:
marked_opr_ids.append(cur_opr.id)
for inp in cur_opr.inputs:
iterative_pruning(inp.owner, cur_opr, marked_opr_ids)
reshape_vars = get_dep_vars(outputs, "Reshape")
reshape_oprs = [var.owner for var in reshape_vars]
marked_opr_ids = []
for reshape_opr in reshape_oprs:
iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids)
# filter out all marked oprs
return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))
map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
outputs
)
oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
if prune_reshape is True:
oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
return oprs_seq
def replace_vars(dst: VarNode, varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
"""replace vars in the graph
:param dst: target vars representing the graph
:param varmap: the map that specifies how to replace the vars
:return: new vars that correspond to ``dst`` with all the dependencies
replaced
"""
dst_vec = []
repl_src_vec = []
repl_dst_vec = []
for i in dst:
assert isinstance(i, VarNode)
dst_vec.append(i)
for i, j in getattr(varmap, "items", lambda: varmap)():
assert isinstance(i, VarNode)
assert isinstance(j, VarNode)
repl_src_vec.append(i)
repl_dst_vec.append(j)
return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
def replace_oprs(
dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
) -> List[VarNode]:
"""Replace operators in the graph.
:param dst: target vars representing the graph
:param oprmap: the map that specifies how to replace the operators
:return: new vars that correspond to ``dst`` with all the dependencies
replaced
"""
dst_vec = []
repl_src_vec = []
repl_dst_vec = []
for i in dst:
assert isinstance(i, VarNode)
dst_vec.append(i)
for i, j in getattr(oprmap, "items", lambda: oprmap)():
assert isinstance(i, OperatorNode)
assert isinstance(j, OperatorNode)
repl_src_vec.append(i)
repl_dst_vec.append(j)
return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
def set_priority_to_id(dest_vars):
"""For all oprs in the subgraph constructed by dest_vars
set its priority to id if its original priority is zero
:param dest_vars: target vars representing the graph
"""
dest_vec = []
for i in dest_vars:
assert isinstance(i, VarNode)
dest_vec.append(i)
_imperative_rt.graph._set_priority_to_id(dest_vec)
......@@ -569,7 +569,7 @@ class trace:
if isinstance(file, str):
permission = "wb" if append == False else "ab"
file = open(file, permission)
file.write(G.dump(*dest_vars))
file.write(G.dump_graph(*dest_vars))
def _process_inputs(self, *args, **kwargs):
if self._untraced:
......
......@@ -64,7 +64,60 @@ auto def_rendezvous(py::object m, const char* name) {
using TensorAttr = LogicalTensorDesc;
using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;
std::vector<mgb::cg::VarNode*> _replace_vars(const std::vector<mgb::cg::VarNode*>& repl_src,
const std::vector<mgb::cg::VarNode*>& repl_dst,
const std::vector<mgb::cg::VarNode*>& vars) {
mgb::ThinHashMap<SymbolVar, SymbolVar> varmap;
for (size_t i = 0; i < repl_src.size(); ++i) {
varmap[SymbolVar(repl_src[i])] = SymbolVar(repl_dst[i]);
}
SymbolVarArray symvars(vars.begin(), vars.end());
auto sym_result = mgb::cg::replace_vars(symvars, varmap);
std::vector<mgb::cg::VarNode*> result;
for (auto symvar : sym_result){
result.push_back(symvar.node());
}
return result;
}
typedef std::vector<mgb::cg::OperatorNodeBase*> OperatorArray;
std::vector<mgb::cg::VarNode*> _replace_oprs(const OperatorArray& repl_src,
const OperatorArray& repl_dst,
const std::vector<mgb::cg::VarNode*>& vars) {
mgb::ThinHashMap<mgb::cg::OperatorNodeBase*, mgb::cg::OperatorNodeBase*>
oprmap;
for (size_t i = 0; i < repl_src.size(); ++i) {
oprmap[repl_src[i]] = repl_dst[i];
}
const SymbolVarArray symvars(vars.begin(), vars.end());
auto sym_result = mgb::cg::replace_oprs(symvars, oprmap);
std::vector<mgb::cg::VarNode*> result;
for (auto symvar : sym_result){
result.push_back(symvar.node());
}
return result;
}
void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
if (opr->node_prop().attribute().priority == 0) {
opr->node_prop().attribute().priority = opr->id();
}
};
mgb::cg::DepOprIter dep_iter{on_opr};
for (const auto& var : dest_vars) {
dep_iter.add(SymbolVar(var));
}
}
void init_graph_rt(py::module m) {
static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{std::make_unique<mgb::OprFootprint>()};
def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
......@@ -99,7 +152,10 @@ void init_graph_rt(py::module m) {
return py::none();
}
return py::cast(*val).attr("numpy")();
});
})
.def_property_readonly("id",[](cg::VarNode* v){
return (v->id());
});
py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
.def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
......@@ -110,7 +166,17 @@ void init_graph_rt(py::module m) {
})
.def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
return to_tuple(opr->usable_output());
});
})
.def_property_readonly("id",[](cg::OperatorNodeBase* opr){
return opr->id();
})
.def_property_readonly("params",[](cg::OperatorNodeBase* opr){
return _imperative_sm_opr_footprint_ptr->calc_footprint(opr).param->to_string();
})
.def_property_readonly("type",[](cg::OperatorNodeBase* opr){
return opr->dyn_typeinfo()->name;
});
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
.def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())
......@@ -174,6 +240,44 @@ void init_graph_rt(py::module m) {
});
m.def("load_graph", [](std::string& buf, py::list& _output_var_map, py::list& _output_var_list) {
using namespace mgb::serialization;
auto file = InputFile::make_mem_proxy(buf.c_str(), buf.length());
auto format = GraphLoader::identify_graph_dump_format(*file);
auto loader = GraphLoader::make(std::move(file), format.val());
GraphLoader::LoadConfig config;
auto rst = loader->load(config);
std::vector<std::pair<std::string, SymbolVar>> output_var_map;
SymbolVarArray output_var_list;
output_var_map = {rst.output_var_map.begin(), rst.output_var_map.end()};
output_var_list = std::move(rst.output_var_list);
for (auto i : output_var_list){
_output_var_list.append(i.node());
}
for (auto i : output_var_map){
_output_var_map.append(py::make_tuple(i.first,i.second.node()));
}
std::unordered_map<HostTensorND*, const std::string*> tensor2name;
for (const auto& pair : rst.tensor_map) {
tensor2name[pair.second.get()] = &pair.first;
}
auto cb = [&tensor2name, graph=rst.graph](cg::OperatorNodeBase* opr) {
if (!opr->same_type<opr::Host2DeviceCopy>())
return;
auto& h2d = opr->cast_final_safe<opr::Host2DeviceCopy>();
auto it = tensor2name.find(h2d.host_data().get());
mgb_throw_if(it == tensor2name.end(), GraphError,
"unbound Host2DeviceCopy in loaded graph");
h2d.output(0)->name(*it->second);
};
cg::DepOprIter iter{cb};
for (const auto& var : output_var_list) {
iter.add(var.node()->owner_opr());
}
return rst.graph;
});
#define CURRENT_CLASS cg::ComputingGraph::Options
auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
......@@ -287,6 +391,10 @@ void init_graph_rt(py::module m) {
return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
}, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());
m.def("_replace_vars", &_replace_vars,py::arg(),py::arg(),py::arg());
m.def("_replace_oprs", &_replace_oprs,py::arg(),py::arg(),py::arg());
m.def("_set_priority_to_id",&_set_priority_to_id,py::arg());
m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
const CompNode& comp_node,
const DType& dtype,
......
......@@ -16,7 +16,7 @@
#include <memory>
#include <mutex>
#include <future>
#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/graph.h"
template<typename T>
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
import numpy as np
import megengine
import megengine.functional as F
import megengine.module as M
from megengine import cgtools
from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import trace
def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
def test_replace_vars():
g = mgb_graph.Graph()
g.options.async_exec_level = 0b100
device = "xpux"
dtype = np.float32
a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
const = g.make_const(1.234)
a_plus_a = F.add(a.outputs[0], a.outputs[0])
a_plus_a_mul_const = F.mul(a_plus_a, const)
rst = F.add(a_plus_a_mul_const, a.outputs[0])
(new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
func = g.compile(out.outputs[0])
func.execute()
x = make_dev_tensor(5.0, device=device)
a.set_value(x)
res = out.get_value().numpy()
np.testing.assert_equal(res, np.array([105.0]))
def test_replace_oprs():
g = mgb_graph.Graph()
g.options.async_exec_level = 0b100
device = "xpux"
dtype = np.float32
a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
const = g.make_const(1.25)
a_plus_a = F.add(a.outputs[0], a.outputs[0])
old_opr = a_plus_a.op
a_plus_a_mul_const = F.mul(a_plus_a, const)
a_mul_a = F.mul(a.outputs[0], a.outputs[0])
new_opr = a_mul_a.op
(new,) = cgtools.replace_oprs(
[a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
)
out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
func = g.compile(out.outputs[0])
func.execute()
x = make_dev_tensor(5.0, device=device)
a.set_value(x)
res = out.get_value().numpy()
np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
def test_graph_traversal():
net = M.Conv2d(3, 32, 3)
@trace(symbolic=True, capture_as_const=True)
def fun(data):
x = net(data)
return x
data = np.random.random([1, 3, 224, 224]).astype(np.float32)
for i in range(3):
fun(megengine.tensor(data))
file = io.BytesIO()
fun.dump(file)
file.seek(0)
cg, _, outputs = mgb_graph.load_graph(file)
_, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)
input_var = map_vars[1]
_, var_idx = var2oprs[input_var.id][0]
assert var_idx == 0
......@@ -13,6 +13,10 @@ import numpy as np
import pytest
from megengine import tensor
import megengine
import megengine.core.tensor.megbrain_graph as mgb_graph
import megengine.module as M
from megengine import cgtools
from megengine.core.ops import builtin as ops
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.core import apply
......@@ -21,6 +25,29 @@ from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
def load_and_inference(file, inp_data):
cg, _, out_list = mgb_graph.load_graph(file)
inputs = cgtools.get_dep_vars(out_list, "Host2DeviceCopy")
replace_dict = {}
inp_node_list = []
for i in inputs:
inp_node = mgb_graph.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
replace_dict[i] = inp_node.outputs[0]
inp_node_list.append(inp_node)
new_out = cgtools.replace_vars(out_list, replace_dict)
out_node_list = [mgb_graph.OutputNode(i) for i in new_out]
new_out_list = [i.outputs[0] for i in out_node_list]
new_cg = new_out_list[0].graph
func = new_cg.compile(new_out_list)
for node, value in zip(inp_node_list, inp_data):
node.set_value(as_raw_tensor(value)._dev_tensor())
func.execute()
out_data_list = [o.get_value().numpy() for o in out_node_list]
return out_data_list
def test_trace():
for symbolic in [False, True]:
......@@ -81,13 +108,58 @@ def test_print_in_trace():
def test_dump():
@trace(symbolic=True, capture_as_const=True)
def f(a, b):
op = ops.Elemwise(mode="add")
(y,) = apply(op, a, b)
return y
a = as_raw_tensor([2]).numpy()
b = as_raw_tensor([4]).numpy()
y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy()
for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y)
file = io.BytesIO()
f.dump(file)
file.seek(0)
result = load_and_inference(file, [a, b])
np.testing.assert_equal(result[0], y)
def test_capture_dump():
a = as_raw_tensor([2])
@trace(symbolic=True, capture_as_const=True)
def f(x):
op = ops.Elemwise(mode="mul")
(y,) = apply(op, x, a)
return y
x = as_raw_tensor([3]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
file = io.BytesIO()
f.dump(file)
file.seek(0)
result = load_and_inference(file, [x])
np.testing.assert_equal(result[0], y)
def test_dump_volatile():
p = as_raw_tensor([2])
@trace(symbolic=True, capture_as_const=True)
def f(x):
op = ops.Elemwise(mode="negate")
(y,) = apply(op, x)
op = ops.Elemwise(mode="mul")
(y,) = apply(op, x, p)
return y
x = as_raw_tensor([1]).numpy()
x = as_raw_tensor([3]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()
for i in range(3):
......@@ -95,6 +167,13 @@ def test_dump():
file = io.BytesIO()
f.dump(file)
file.seek(0)
cg, _, outputs = mgb_graph.load_graph(file)
(out,) = outputs
assert (
cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
== "SharedDeviceTensor"
)
def test_trace_profiler():
......
......@@ -471,11 +471,9 @@ def main():
assert not testcase, 'extra inputs provided in testcase: {}'.format(
testcase.keys()
)
mgb.serialize_comp_graph_to_file(
args.output,
output_mgbvars,
append=True,
output_strip_info=args.output_strip_info)
with open(args.output, "ab") as fout:
fout.write(G.dump_graph(*output_mgbvars))
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册