Commit cf3a55ce authored by Megvii Engine Team

fix(mgb/opr-mm): remove PeerDesc from RemoteSend and RemoteRecv

GitOrigin-RevId: b7a7bbd0dad4ab27d9c51c59c8011e518e79e097
Parent d53dab2f
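In short: callers no longer pack the key, a `SEND`/`RECV` enum and the `is_grad` flag into a `RemoteIOBase::PeerDesc`; the key and `is_grad` are now passed directly, and the direction enum is dropped because the operator class already implies it. A minimal before/after sketch using the signatures visible in this diff (`x`, `client`, `graph`, `host_x`, `cn1` are placeholders borrowed from the tests below):

```cpp
// Old API: key / direction / is_grad packed into a PeerDesc struct.
auto xr = opr::RemoteSend::make(
        {"x", opr::RemoteIOBase::Type::SEND, /*is_grad=*/false}, x, client);
auto y = opr::RemoteRecv::make(
        {"x", opr::RemoteIOBase::Type::RECV, false}, *graph, client,
        {cn1}, host_x->shape(), host_x->dtype());

// New API: plain key string plus an explicit is_grad flag on RemoteSend.
auto xr2 = opr::RemoteSend::make("x", x, client, /*is_grad=*/false);
auto y2 = opr::RemoteRecv::make("x", *graph, client,
                                {cn1}, host_x->shape(), host_x->dtype());
```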
......@@ -72,10 +72,10 @@ SymbolVar _Opr::remote_send(
const std::string& key, SymbolVar var,
const bool is_grad,
const OperatorNodeConfig& config) {
return RemoteSend::make({key, RemoteIOBase::Type::SEND, is_grad}, var,
return RemoteSend::make(key, var,
std::make_shared<GroupClientProxy>(ssprintf(
"%s:%d", server_addr.c_str(), port)),
config);
is_grad, config);
}
SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
......@@ -85,8 +85,7 @@ SymbolVar _Opr::remote_recv(const std::string& server_addr, const int port,
const TensorShape ishape = npy::vec2shape(shape);
const DType idtype = npy::dtype_np2mgb(dtype);
return RemoteRecv::make({key, RemoteIOBase::Type::RECV, false},
graph.get(),
return RemoteRecv::make(key, graph.get(),
std::make_shared<GroupClientProxy>(
ssprintf("%s:%d", server_addr.c_str(), port)),
config, ishape, idtype);
......
......@@ -26,27 +26,28 @@ cudaStream_t get_stream(VarNode* var) {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
RemoteSend::RemoteSend(const PeerDesc& peer, VarNode* var,
RemoteSend::RemoteSend(const std::string& key, VarNode* var,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config) :
Super(var->owner_graph(), config, "remote_send", {var}) {
m_peer = peer;
bool is_grad, const OperatorNodeConfig& config) :
Super(var->owner_graph(), config, "remote_send", {var}),
m_is_grad(is_grad) {
m_key = key;
m_group_client = group_client;
add_input({var});
auto ovar = add_output(None);
if (!peer.is_grad) {
if (!m_is_grad) {
ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::VOLATILE_CONTENT);
}
add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVar RemoteSend::make(const PeerDesc& peer, SymbolVar var,
SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config) {
return var.insert_single_output_opr<RemoteSend>(peer, var.node(),
group_client, config);
bool is_grad, const OperatorNodeConfig& config) {
return var.insert_single_output_opr<RemoteSend>(key, var.node(), group_client,
is_grad, config);
}
void RemoteSend::scn_do_execute() {
......@@ -54,11 +55,11 @@ void RemoteSend::scn_do_execute() {
auto&& comp_node = output(0)->comp_node();
// rank 0 for RemoteSend
auto reg_info = m_group_client->opr_register(m_peer.key, 2, 0, false,
auto reg_info = m_group_client->opr_register(m_key, 2, 0, false,
comp_node.get_uid());
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_peer.key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_UCX, m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
......@@ -76,7 +77,7 @@ void RemoteSend::scn_do_execute() {
auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx);
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");
if (m_peer.is_grad) {
if (m_is_grad) {
auto&& dest = output(0)->dev_tensor();
if (m_output_val.empty()) {
m_output_val.comp_node(dest.comp_node())
......@@ -92,7 +93,7 @@ void RemoteSend::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
auto do_infer = [this](TensorShape& dest, const InpVal&) {
if (peer_desc().is_grad) {
if (m_is_grad) {
dest = {1};
} else {
dest = {0};
......@@ -109,9 +110,8 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
}
MGB_IMPL_OPR_GRAD(RemoteSend) {
mgb_assert(opr.peer_desc().is_grad);
return RemoteRecv::make({opr.peer_desc().key + ":grad",
RemoteIOBase::Type::RECV, false},
mgb_assert(opr.is_grad());
return RemoteRecv::make(opr.key() + ":grad",
*opr.owner_graph(), opr.group_client(),
OperatorNodeConfig{opr.comp_node()}.name(
opr.name() + ":grad_recv"),
......@@ -123,13 +123,13 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
const TensorShape& shape, DType dtype) :
Super(&graph, config, "remote_recv", {}),
m_shape(shape), m_dtype(dtype) {
m_peer = peer;
m_key = key;
m_group_client = group_client;
add_output(None)
......@@ -139,12 +139,12 @@ RemoteRecv::RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVar RemoteRecv::make(const PeerDesc& peer, cg::ComputingGraph& graph,
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config,
const TensorShape& shape, DType dtype) {
auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
peer, graph, group_client, config, shape, dtype));
key, graph, group_client, config, shape, dtype));
return opr->output(0);
}
......@@ -154,11 +154,11 @@ void RemoteRecv::scn_do_execute() {
// rank 1 for RemoteRecv
auto reg_info = m_group_client->opr_register(
m_peer.key, 2, false, 1,
m_key, 2, false, 1,
comp_node.get_uid());
m_megray_comm = MegRayCommBuilder::get_megray_comm(
reg_info.hash, m_peer.key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_UCX, m_group_client);
m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));
......@@ -206,8 +206,8 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_send(
const OperatorNodeConfig& config) {
mgb_assert(inputs.size() == 1);
auto&& opr = opr_.cast_final_safe<RemoteSend>();
return RemoteSend::make(opr.peer_desc(), inputs[0], opr.group_client(),
config)
return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
opr.is_grad(), config)
.node()
->owner_opr();
}
......@@ -218,7 +218,7 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto&& opr = opr_.cast_final_safe<RemoteRecv>();
return RemoteRecv::make(opr.peer_desc(), *opr.owner_graph(),
return RemoteRecv::make(opr.key(), *opr.owner_graph(),
opr.group_client(), config, inputs[0]->shape(),
inputs[0]->dtype())
.node()
......
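One behavioral detail from the .cpp changes above: `MGB_IMPL_OPR_GRAD(RemoteSend)` now derives the peer key as `opr.key() + ":grad"` instead of building a new PeerDesc, so the remote side must send the gradient back under that derived key. A hedged sketch of both sides, modeled on the SendGrad test further down (variable names are placeholders):

```cpp
// Sender graph: mark the send as carrying a gradient, then request grad(x);
// taking the gradient inserts a RemoteRecv keyed "loss:grad" behind the scenes.
auto loss = opr::RemoteSend::make("loss", x, client, /*is_grad=*/true);
auto gx = cg::grad(loss, x);

// Receiver graph: receive "loss", then push the gradient back on "loss:grad".
auto rx = opr::RemoteRecv::make("loss", *graph_r, client, {cn1},
                                host_x->shape(), host_x->dtype());
auto gy = opr::RemoteSend::make("loss:grad", rx + 1, client, /*is_grad=*/false);
```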
......@@ -25,25 +25,14 @@ namespace opr {
*/
MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // {
public:
enum Type {
SEND,
RECV
};
struct PeerDesc {
std::string key;
Type type;
bool is_grad;
};
const PeerDesc& peer_desc() const { return m_peer; }
const std::string& key() const { return m_key; }
std::shared_ptr<GroupClient> group_client() const {
return m_group_client;
}
protected:
PeerDesc m_peer;
std::string m_key;
std::shared_ptr<GroupClient> m_group_client;
std::shared_ptr<MegRay::Communicator> m_megray_comm;
std::shared_ptr<MegRay::Context> m_megray_ctx;
......@@ -53,21 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // {
/*!
* \brief send a variable to remote address; a virtual output is produced
* for expressing dependency
* for expressing dependency
*/
MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
public:
RemoteSend(const PeerDesc& peer, VarNode* var,
RemoteSend(const std::string& key, VarNode* var,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config);
bool is_grad, const OperatorNodeConfig& config);
static SymbolVar make(
const PeerDesc& peer, SymbolVar var,
const std::string& key, SymbolVar var,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config = {});
bool is_grad, const OperatorNodeConfig& config = {});
bool is_grad() const { return m_is_grad; }
private:
HostTensorND m_output_val;
bool m_is_grad;
void scn_do_execute() override;
void init_output_static_infer_desc() override;
......@@ -75,19 +67,18 @@ MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
};
/*!
* \brief receive from multiple remote addresses and write to a var
*
* Target computing node of the var must be specified in config
* \brief receive a variable from remote address; target computing node
* of the var must be specified in config
*/
MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
public:
RemoteRecv(const PeerDesc& peer, cg::ComputingGraph& graph,
RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
static SymbolVar make(
const PeerDesc& peer, cg::ComputingGraph& graph,
const std::string& key, cg::ComputingGraph& graph,
std::shared_ptr<GroupClient> group_client,
const OperatorNodeConfig& config, const TensorShape& shape,
DType dtype);
......
......@@ -20,9 +20,6 @@
using namespace mgb;
const auto send_tag = opr::RemoteIOBase::Type::SEND;
const auto recv_tag = opr::RemoteIOBase::Type::RECV;
TEST(TestOprIORemote, Identity) {
REQUIRE_GPU(2);
auto cn0 = CompNode::load("gpu0");
......@@ -36,8 +33,8 @@ TEST(TestOprIORemote, Identity) {
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x, cn0);
auto xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
auto y = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
auto xr = opr::RemoteSend::make("x", x, client, false);
auto y = opr::RemoteRecv::make("x", *graph.get(),
client, {cn1}, host_x->shape(),
host_x->dtype());
......@@ -59,7 +56,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
auto graph = ComputingGraph::make();
sys::set_thread_name("sender");
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
xr = opr::RemoteSend::make("x", x, client, false);
auto func = graph->compile({{xr, {}}});
func->execute();
};
......@@ -67,7 +64,7 @@ TEST(TestOprIORemote, IdentityMultiThread) {
auto receiver = [&]() {
sys::set_thread_name("receiver");
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
auto x = opr::RemoteRecv::make("x", *graph.get(),
client, {cns[0]}, host_x->shape(),
host_x->dtype());
auto func = graph->compile({make_callback_copy(x, host_x_get)});
......@@ -92,7 +89,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
sys::set_thread_name("sender");
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x) * 2 + 1,
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client);
xr = opr::RemoteSend::make("x", x, client, false);
auto func = graph->compile({{xr, {}}});
func->execute();
};
......@@ -100,7 +97,7 @@ TEST(TestOprIORemote, IdentityWithGopt) {
auto receiver = [&]() {
sys::set_thread_name("receiver");
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
auto x = opr::RemoteRecv::make("x", *graph.get(),
client, {cns[0]}, host_x->shape(),
host_x->dtype());
auto func =
......@@ -124,14 +121,14 @@ TEST(TestOprIORemote, APlusB) {
auto sender = [&]() {
auto graph = ComputingGraph::make();
auto z = opr::RemoteRecv::make({"z", recv_tag, false}, *graph.get(),
auto z = opr::RemoteRecv::make("z", *graph.get(),
client, {cns[0]}, host_x->shape(),
host_x->dtype());
auto x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x"),
y = opr::Host2DeviceCopy::make(*graph, host_y).rename("y"),
xr = opr::RemoteSend::make({"x", send_tag, false}, x, client)
xr = opr::RemoteSend::make("x", x, client, false)
.rename("xr"),
yr = opr::RemoteSend::make({"y", send_tag, false}, y, client)
yr = opr::RemoteSend::make("y", y, client, false)
.rename("yr");
auto func = graph->compile(
{{xr, {}}, {yr, {}}, make_callback_copy(z, host_z)});
......@@ -142,14 +139,14 @@ TEST(TestOprIORemote, APlusB) {
auto receiver = [&]() {
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make({"x", recv_tag, false}, *graph.get(),
auto x = opr::RemoteRecv::make("x", *graph.get(),
client, {cns[1]}, host_x->shape(),
host_x->dtype()),
y = opr::RemoteRecv::make({"y", recv_tag, false}, *graph.get(),
y = opr::RemoteRecv::make("y", *graph.get(),
client, {cns[1]}, host_y->shape(),
host_y->dtype()),
z = x + y,
zr = opr::RemoteSend::make({"z", send_tag, false}, z, client);
zr = opr::RemoteSend::make("z", z, client, false);
auto func = graph->compile({{zr, {}}});
func->execute();
};
......@@ -177,10 +174,10 @@ TEST(TestOprIORemote, SendGrad) {
sys::set_thread_name("sender");
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
loss = opr::RemoteSend::make({"loss", send_tag, false}, x, client);
loss = opr::RemoteSend::make("loss", x, client, false);
ASSERT_TRUE(!loss.shape().ndim &&
loss.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
loss = opr::RemoteSend::make({"loss", send_tag, true}, x, client);
loss = opr::RemoteSend::make("loss", x, client, true);
auto gx = cg::grad(loss, x);
set_priority(loss, 0);
set_priority(gx, 1);
......@@ -197,10 +194,10 @@ TEST(TestOprIORemote, SendGrad) {
auto receiver = [&]() {
sys::set_thread_name("receiver");
auto graph = ComputingGraph::make();
auto x = opr::RemoteRecv::make({"loss", recv_tag, false}, *graph.get(),
auto x = opr::RemoteRecv::make("loss", *graph.get(),
client, {cns[1]}, host_x->shape(),
host_x->dtype());
auto y = opr::RemoteSend::make({"loss:grad", send_tag, false}, x + 1, client);
auto y = opr::RemoteSend::make("loss:grad", x + 1, client, false);
auto func = graph->compile({{y, {}}});
func->execute();
};
......
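As the tests illustrate, a non-grad `RemoteSend` only produces a virtual output (empty shape, `VOLATILE_CONTENT`) whose sole purpose is to express dependency; it is kept alive by compiling it as an output spec with an empty callback, while the matching `RemoteRecv` in the peer graph registers under the same key (send is rank 0, recv is rank 1 of a two-rank group). A minimal usage sketch under those assumptions:

```cpp
// Sender side: nothing is fetched back; the empty callback just forces execution.
auto xr = opr::RemoteSend::make("x", x, client, /*is_grad=*/false);
auto send_func = graph->compile({{xr, {}}});
send_func->execute();
```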