/**
 * \file src/opr-mm/impl/io_remote.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/io_remote.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/serialization/sereg.h"

using namespace mgb;
using namespace opr;

//! fetch the CUDA stream bound to the comp node that \p var lives on
cudaStream_t get_stream(VarNode* var) {
    auto&& env = CompNodeEnv::from_comp_node(var->comp_node());
    return env.cuda_env().stream;
}

/* ===================== RemoteSend ===================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);

RemoteSend::RemoteSend(const std::string& key, VarNode* var,
                       std::shared_ptr<GroupClient> group_client,
                       bool is_grad, const OperatorNodeConfig& config) :
        Super(var->owner_graph(), config, "remote_send", {var}),
        m_is_grad(is_grad) {
    m_key = key;
    m_group_client = group_client;

    add_input({var});
    auto ovar = add_output(None);
    if (!m_is_grad) {
        // a pure (non-grad) send produces no meaningful value: allow the
        // output to stay empty and never cache/reuse its content
        ovar->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
    }
    // hash on the object address so two sends with identical inputs are
    // never deduplicated by the graph
    add_equivalence_component<ScalarHash<void*>>(this);
}

/*!
 * \brief insert a RemoteSend opr that transmits \p var to the peer
 *        registered under \p key
 */
SymbolVar RemoteSend::make(const std::string& key, SymbolVar var,
                           std::shared_ptr<GroupClient> group_client,
                           bool is_grad, const OperatorNodeConfig& config) {
    return var.insert_single_output_opr<RemoteSend>(
            key, var.node(), group_client, is_grad, config);
}

void RemoteSend::scn_do_execute() {
    if (!m_init) {
55
        auto&& comp_node = output(0)->comp_node();
56 57

        // rank 0 for RemoteSend
58
        auto reg_info = m_group_client->opr_register(m_key, 2, 0, false,
59
                comp_node.get_uid());
60

61
        m_megray_comm = MegRayCommBuilder::get_megray_comm(
62
                reg_info.hash, m_key, 2, 0, MegRay::MEGRAY_NCCL, m_group_client);
63 64 65

        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));

66 67 68 69 70 71 72 73 74 75 76 77 78 79
        m_init = true;
    }

    mgb_assert(m_init);
    size_t data_size = 1;
    auto&& tensor = input(0)->dev_tensor();
    auto&& ishp = tensor.shape();
    for (size_t i = 0; i < ishp.ndim; i++) {
        data_size *= ishp[i];
    }
    data_size *= tensor.dtype().size();
    auto status = m_megray_comm->send(tensor.raw_ptr(), data_size, 1, m_megray_ctx);
    mgb_assert(status == MegRay::MEGRAY_OK, "MegRay send failed");

80
    if (m_is_grad) {
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
        auto&& dest = output(0)->dev_tensor();
        if (m_output_val.empty()) {
            m_output_val.comp_node(dest.comp_node())
                    .dtype(dest.dtype())
                    .resize({1});
            memset(m_output_val.raw_ptr(), 0, m_output_val.dtype().size());
        }
        dest.copy_from_fixlayout(m_output_val);
    }
}

void RemoteSend::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto do_infer = [this](TensorShape& dest, const InpVal&) {
        // grad mode yields a scalar placeholder; otherwise the output is
        // empty (the send has no meaningful value)
        if (m_is_grad) {
            dest = {1};
        } else {
            dest = {0};
        }
        return true;
    };
    mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, do_infer});
}

cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
    // the data leaves this comp node, so mark the cross-device dependency
    auto* prop = RemoteIOBase::do_make_node_prop();
    prop->add_flag(NodeProp::Flag::CROSS_COMP_NODE_MEMORY);
    return prop;
}

MGB_IMPL_OPR_GRAD(RemoteSend) {
    mgb_assert(opr.is_grad());
    // gradient of a send is a receive of a tensor with the same shape and
    // dtype, coming back from the peer under the derived ":grad" key
    return RemoteRecv::make(opr.key() + ":grad",
                            *opr.owner_graph(), opr.group_client(),
                            OperatorNodeConfig{opr.comp_node()}.name(
                                    opr.name() + ":grad_recv"),
                            opr.input(0)->shape(), opr.input(0)->dtype())
            .node();
}

/* ===================== RemoteRecv ===================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);

RemoteRecv::RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
                       std::shared_ptr<GroupClient> group_client,
                       const OperatorNodeConfig& config,
                       const TensorShape& shape, DType dtype) :
        Super(&graph, config, "remote_recv", {}),
        m_shape(shape), m_dtype(dtype) {
    m_key = key;
    m_group_client = group_client;

    // the output buffer is written asynchronously by the transport, so its
    // memory must stay stable: forbid reclaim and forced dynamic allocation
    add_output(None)
            ->dtype(dtype)
            .add_flag(VarNode::Flag::NO_MEM_RECLAIM)
            .add_flag(VarNode::Flag::DISALLOW_RT_FORCE_DYNAMIC_MEM_ALLOC);
    // hash on the object address so two recvs are never deduplicated
    add_equivalence_component<ScalarHash<void*>>(this);
}

/*!
 * \brief insert a RemoteRecv opr that receives a tensor of \p shape and
 *        \p dtype from the peer registered under \p key
 */
SymbolVar RemoteRecv::make(const std::string& key, cg::ComputingGraph& graph,
                           std::shared_ptr<GroupClient> group_client,
                           const OperatorNodeConfig& config,
                           const TensorShape& shape, DType dtype) {
    auto opr = graph.insert_opr(std::make_unique<RemoteRecv>(
            key, graph, group_client, config, shape, dtype));
    return opr->output(0);
}

void RemoteRecv::scn_do_execute() {
    if (!m_init) {
153
        auto&& comp_node = output(0)->comp_node();
154 155

        // rank 1 for RemoteRecv
156
        auto reg_info = m_group_client->opr_register(
157
                m_key, 2, false, 1,
158
                comp_node.get_uid());
159

160
        m_megray_comm = MegRayCommBuilder::get_megray_comm(
161
                reg_info.hash, m_key, 2, 1, MegRay::MEGRAY_NCCL, m_group_client);
162 163 164

        m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));

165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
        m_init = true;
    }

    mgb_assert(m_init);
    size_t data_size = 1;
    auto&& tensor = output(0)->dev_tensor();
    auto&& ishp = tensor.shape();
    for (size_t i = 0; i < ishp.ndim; i++) {
        data_size *= ishp[i];
    }
    data_size *= tensor.dtype().size();
    auto status = m_megray_comm->recv(tensor.raw_ptr(), data_size, 0, m_megray_ctx);
    mgb_assert(status == MegRay::MEGRAY_OK, "MegRay recv failed");
}

void RemoteRecv::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    // the output shape is fixed at construction time (m_shape)
    auto infer_shape = [this](TensorShape& dest, const InpVal&) {
        dest = m_shape;
        return true;
    };
    mgr.register_shape_infer(output(0),
                             {SourceType::CONSTANT, {}, infer_shape});
}

cg::OperatorNodeBase::NodeProp* RemoteRecv::do_make_node_prop() const {
    auto* prop = RemoteIOBase::do_make_node_prop();
    prop->add_flag(NodeProp::Flag::IMPURE_FUNC);
    if (input().size() == 1) {
        // the optional input only enforces execution order; its value is
        // never read, so downgrade the dependency to comp-order only
        prop->reset_dep_type(input(), {NodeProp::DepType::DEV_COMP_ORDER});
    }
    return prop;
}

/* ===================== shallow copy ===================== */

namespace mgb {
namespace opr {

//! shallow copy: rebuild a RemoteSend on new inputs with the same key/client
cg::OperatorNodeBase* opr_shallow_copy_remote_send(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    mgb_assert(inputs.size() == 1);
    auto&& opr = opr_.cast_final_safe<RemoteSend>();
    return RemoteSend::make(opr.key(), inputs[0], opr.group_client(),
                            opr.is_grad(), config)
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(RemoteSend, opr_shallow_copy_remote_send);

//! shallow copy: rebuild a RemoteRecv with the same key/client; shape and
//! dtype are taken from the (single) reference input
cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    // inputs[0] is dereferenced below; fail loudly instead of crashing,
    // matching the check in opr_shallow_copy_remote_send
    mgb_assert(inputs.size() == 1);
    auto&& opr = opr_.cast_final_safe<RemoteRecv>();
    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
                            opr.group_client(), config, inputs[0]->shape(),
                            inputs[0]->dtype())
            .node()
            ->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);

}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}