提交 9a79342d 编写于 作者: M Megvii Engine Team

refactor(mge/virtualdep): remove virtualdep opdef

GitOrigin-RevId: bbe3ae3fa3d635e7adc1a7a5080e060699ea611f
上级 88b402ef
......@@ -569,3 +569,9 @@ class AttrOutputNode(OpNode):
def reset(self):
self._rendezvous.reset()
class VirtualDepNode(OpNode):
def __init__(self, vars, device=""):
out = _imperative_rt.virtual_dep(_unwrap(vars), device)
super().__init__(out)
......@@ -25,7 +25,6 @@ from ..core._imperative_rt.ops import (
RemoteRecv,
RemoteSend,
UniformRNG,
VirtualDep,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
......@@ -548,9 +547,10 @@ class trace:
need_reset_nodes.append(opnode)
info.varnode, *in_out_links = opnode.outputs
if require_links and i == 0 and len(io_links) > 0:
info.varnode = apply(
VirtualDep(str(io_links[0].device)), info.varnode, *io_links
)[0]
opnode = G.VirtualDepNode(
[info.varnode, *io_links], str(io_links[0].device)
)
info.varnode = opnode.outputs[0]
io_links = (info.varnode,)
ivars.append(info.varnode)
......@@ -1112,11 +1112,11 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
if require_links and active_trace._lazy_eval_links:
assert len(ivars) > 0, "op should has at least one input"
ivars[0] = apply(
VirtualDep(str(active_trace._lazy_eval_links[0].device)),
ivars[0],
*active_trace._lazy_eval_links,
)[0]
opnode = G.VirtualDepNode(
[ivars[0], *active_trace._lazy_eval_links],
str(active_trace._lazy_eval_links[0].device),
)
ivars[0] = opnode.outputs[0]
active_trace._lazy_eval_links = (ivars[0],)
ovars = apply(op, *ivars)
......
......@@ -15,6 +15,7 @@
#include "megbrain/serialization/serializer.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative.h"
#include "./helper.h"
......@@ -562,4 +563,16 @@ void init_graph_rt(py::module m) {
};
return output_callback(std::move(f), std::move(inputs), p, true);
});
m.def("virtual_dep", [](std::vector<cg::VarNode*> inputs, std::string device) {
auto&& graph = inputs[0]->owner_graph();
VarNodeArray inps(inputs.begin(), inputs.end());
cg::OperatorNodeConfig config;
if (device.length() > 0) {
config.comp_node(CompNode::load(device));
}
cg::OperatorNodeBase* opr = graph->insert_opr(
std::make_unique<mgb::opr::VirtualDep>(inps, config));
return opr;
});
}
......@@ -10,12 +10,10 @@
*/
#include "./ops.h"
#include <string>
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/autogen.h"
namespace py = pybind11;
......@@ -45,9 +43,5 @@ void init_ops(py::module m) {
return self.graph().interpret<py::object>(f, c, inputs);
});
py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep")
.def(py::init<>())
.def(py::init<std::string>());
#include "opdef.py.inl"
}
/**
* \file imperative/src/impl/ops/utility.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/utility.h"
#include <string>
#include "megbrain/comp_node.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/utility.h"
#include "../op_trait.h"
namespace mgb::imperative {
namespace {
cg::OperatorNodeBase* virtual_dep_apply_on_var_node(
const OpDef& def, const VarNodeArray& inputs) {
auto&& graph = inputs[0]->owner_graph();
auto&& op = def.cast_final_safe<VirtualDep>();
VarNodeArray inps(inputs.begin(), inputs.end());
cg::OperatorNodeConfig config;
if (op.device.length() > 0) {
config.comp_node(CompNode::load(op.device));
}
cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::VirtualDep>(
inps, config));
return opr;
}
OP_TRAIT_REG(VirtualDep, VirtualDep, mgb::opr::VirtualDep)
.apply_on_var_node(virtual_dep_apply_on_var_node)
.fallback();
} // namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(VirtualDep);
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/ops/utility.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <string>
#include "megbrain/graph/operator_node.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
namespace mgb::imperative {
class VirtualDep : public OpDefImplBase<VirtualDep> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
VirtualDep() = default;
VirtualDep(std::string dev) : device(dev) {}
std::string device;
size_t hash() const override {
return reinterpret_cast<size_t>(dyn_typeinfo());
}
bool is_same_st(const Hashable& rhs) const override {
return true;
}
};
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册