提交 556e0222 编写于 作者: M Megvii Engine Team

refactor(mge/trace): remove apply on varnode

GitOrigin-RevId: 0185244c1f957a90b4fd7b81d0e0a0a73ea552c2
上级 243a05b4
......@@ -437,7 +437,7 @@ def _unwrap(x):
return x
def apply_normal_op(op: OpDef, *args: VarNode):
def apply_normal_varnode(op: OpDef, *args: VarNode):
outputs = _imperative_rt.invoke_op(op, _unwrap(args))
return _wrap(outputs)
......@@ -447,7 +447,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
graph = args[0].graph
outputs = op.interpret(
op,
lambda op, args: apply_normal_op(op, *args),
lambda op, args: apply_normal_varnode(op, *args),
graph._make_const_for_backward,
args,
)
......
......@@ -41,7 +41,7 @@ from ..core._imperative_rt.ops import (
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.builtin import OpDef
from ..core.ops.builtin import BackwardGraph, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from .sublinear_memory_config import SublinearMemoryConfig
......@@ -372,6 +372,7 @@ class trace:
lazy_eval_graph()
for r, x in zip(readers, lazy_eval_tensors):
x()._handle = RawTensor(r.op.get_value())._handle
x()._reset_varnode()
@contextlib.contextmanager
def _setup(self):
......@@ -580,9 +581,11 @@ class trace:
ivars.append(info.varnode)
ivars = [RawTensor(ivar) for ivar in ivars]
ovars = apply(op, *ivars)
ovars = [x._varnode for x in ovars]
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
ovars = G.apply_normal_varnode(op, *ivars)
if require_links and len(ovars) > 0:
io_links = (ovars[0],)
assert len(ovars) == len(ohandles)
......@@ -768,11 +771,10 @@ class trace:
info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
)
ivars.append(h2v[h])
ivars = [RawTensor(ivar) for ivar in ivars]
ovars = apply(op, *ivars)
ovars = [x._varnode for x in ovars]
ovars = G.apply_normal_varnode(op, *ivars)
assert len(ovars) == len(ohandles)
h2v.update(zip(ohandles, ovars))
unset_tracing()
dest_vars = []
for i, h in enumerate(self._output_bindings):
......@@ -781,7 +783,6 @@ class trace:
v.name = output_names[i]
dest_vars.append(v)
dest_vars = [G.VarNode(var) for var in dest_vars]
if optimize_for_inference:
dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
......@@ -1007,7 +1008,6 @@ def assign_raw_tensor(lhs, rhs):
lhs.__init__(rhs)
# this hook turns RawTensor into LazyEvalTensor(varnode)
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ivars = []
......@@ -1038,13 +1038,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
ivars[0] = opnode.outputs[0]
active_trace._lazy_eval_links = (ivars[0],)
ivars = [
RawTensor(ivar._node) if hasattr(ivar, "_node") else RawTensor(ivar)
for ivar in ivars
]
unset_symbolic()
outputs = apply(op, *ivars)
set_symbolic()
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
ovars = G.apply_normal_varnode(op, *ivars)
outputs = [RawTensor(o) for o in ovars]
if require_links:
active_trace._lazy_eval_links = (outputs[0]._varnode,)
......
......@@ -392,6 +392,10 @@ void TensorWrapper::reset(PyObject* tensor) {
m_tensor = t->m_tensor;
}
void TensorWrapper::reset_varnode() {
m_tensor->m_var = nullptr;
}
PyObject* TensorWrapper::detach() {
PyObject* self = wrap_t::pycast(this);
PyTypeObject* pytype = self->ob_type;
......@@ -687,6 +691,7 @@ void init_tensor(py::module m) {
.def<&TensorWrapper::_swap_out>("_swap_out")
.def<&TensorWrapper::_swap_in>("_swap_in")
.def<&TensorWrapper::_drop>("_drop")
.def<&TensorWrapper::reset_varnode>("_reset_varnode")
.def_getset<&TensorWrapper::varnode>("_varnode")
.def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read")
.def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read")
......
......@@ -155,6 +155,7 @@ struct TensorWrapper {
void _swap_out();
void _drop();
PyObject* varnode();
void reset_varnode();
PyObject* handle();
void set_handle(PyObject *);
......
......@@ -17,30 +17,9 @@ namespace py = pybind11;
namespace mgb::imperative::python {
apply_result_t apply_tensor_on_var_node(ApplyContext& ctx) {
apply_result_t outputs;
cg::VarNodeArray vinputs(ctx.nargs);
for (size_t i = 0; i < ctx.nargs; i++) {
vinputs[i] = ctx.args[i]->m_var;
}
auto ovars = OpDef::apply_on_var_node(*ctx.op, vinputs);
for (size_t i = 0; i < ovars.size(); i++) {
outputs.emplace_back(std::make_shared<Tensor>(ovars[i]));
}
return outputs;
}
apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t outputs;
bool run_apply_on_var_node = false;
for (size_t i = 0; i < ctx.nargs; i++) {
run_apply_on_var_node |= ((ctx.args[i]->m_handle.get() == nullptr) & (ctx.args[i]->m_var != nullptr));
}
if (ctx.backward) {
// reach here when symbolic=True or compiled=True
// call megbrain_graph.py apply(BackwardGraph, *args)
......@@ -63,10 +42,6 @@ apply_result_t apply_trace(ApplyContext& ctx) {
return outputs;
}
if (run_apply_on_var_node && !is_symbolic) {
return apply_tensor_on_var_node(ctx);
}
py::object pyf;
if (is_compiled) {
// run apply in compiled mode, step 2, 3, etc
......
......@@ -112,7 +112,7 @@ def test_quint8_typecvt():
data = np.random.random(shape).astype(np.float32) * 5 - 1
def typecvt(x, dt=None):
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x)
(y,) = G.apply_normal_varnode(ops.TypeCvt(dtype=dt), x)
return y
# convert to quint8
......@@ -193,7 +193,7 @@ def test_quint4_typecvt():
data = np.random.random(shape).astype(np.float32) * 5 - 1
def typecvt(x, dt=None):
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x)
(y,) = G.apply_normal_varnode(ops.TypeCvt(dtype=dt), x)
return y
# convert to quint4
......
......@@ -72,7 +72,7 @@ def test_op():
lambda: x, device=x.comp_node, dtype=x.dtype, graph=g
)
neg = Elemwise(Elemwise.Mode.NEGATE)
v = mgb_graph.apply_normal_op(neg, v)[0]
v = mgb_graph.apply_normal_varnode(neg, v)[0]
y = Future()
v = mgb_graph.output_callback(y.set_result, v)
f = g.compile(v)
......@@ -90,7 +90,7 @@ def test_exception():
g = mgb_graph.Graph()
x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g)
neg = Elemwise(Elemwise.Mode.NEGATE)
y = mgb_graph.OutputNode(mgb_graph.apply_normal_op(neg, x)[0])
y = mgb_graph.OutputNode(mgb_graph.apply_normal_varnode(neg, x)[0])
f = g.compile(y.outputs[0])
try:
f.execute()
......
......@@ -16,7 +16,7 @@ import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.megbrain_graph import apply_normal_op
from megengine.core.tensor.megbrain_graph import apply_normal_varnode
from megengine.core.tensor.utils import astensor1d
from megengine.jit import trace
......@@ -34,9 +34,9 @@ def test_replace_vars():
const = g.make_const(1.234, device=device)
add_op = Elemwise(Elemwise.Mode.ADD)
mul_op = Elemwise(Elemwise.Mode.MUL)
a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0]
a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0]
rst = apply_normal_op(add_op, a_plus_a_mul_const, a.outputs[0])[0]
a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0]
a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0]
rst = apply_normal_varnode(add_op, a_plus_a_mul_const, a.outputs[0])[0]
(new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
func = g.compile(out.outputs[0])
......@@ -56,10 +56,10 @@ def test_replace_oprs():
const = g.make_const(1.25, device=device)
add_op = Elemwise(Elemwise.Mode.ADD)
mul_op = Elemwise(Elemwise.Mode.MUL)
a_plus_a = apply_normal_op(add_op, a.outputs[0], a.outputs[0])[0]
a_plus_a = apply_normal_varnode(add_op, a.outputs[0], a.outputs[0])[0]
old_opr = a_plus_a.op
a_plus_a_mul_const = apply_normal_op(mul_op, a_plus_a, const)[0]
a_mul_a = apply_normal_op(mul_op, a.outputs[0], a.outputs[0])[0]
a_plus_a_mul_const = apply_normal_varnode(mul_op, a_plus_a, const)[0]
a_mul_a = apply_normal_varnode(mul_op, a.outputs[0], a.outputs[0])[0]
new_opr = a_mul_a.op
(new,) = cgtools.replace_oprs(
[a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
......
......@@ -163,6 +163,7 @@ def test_trace_profiler():
assert out.get("profiler")
@pytest.mark.skip(reason="force opt_level=0 when building graph")
def test_goptions():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x):
......@@ -181,6 +182,7 @@ def test_goptions():
np.testing.assert_equal(g(d).numpy().item(), 1.0)
@pytest.mark.skip(reason="force opt_level=0 when building graph")
def test_goptions_log_sum_exp():
@trace(symbolic=True, opt_level=0, capture_as_const=True)
def f(x, y):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册