未验证 提交 964e20e0 编写于 作者: W WangXi 提交者: GitHub

[fleet_executor] Add compute interceptor (#37376)

上级 9d3e1896
......@@ -11,12 +11,13 @@ else()
endif()
# Core fleet-executor library. compute_interceptor.cc is part of the SRCS
# list; the pre-change line without it was diff residue and is removed here.
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
           interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc
           DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS})
if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
......@@ -21,6 +21,8 @@
namespace paddle {
namespace distributed {
USE_INTERCEPTOR(Compute);
void Carrier::Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
// Builds a compute-node interceptor: snapshots the upstream dependency set,
// then installs the message handler that routes every incoming
// InterceptorMessage to Compute().
ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node) {
  // Populate upstream_deps_ BEFORE registering the handler so that a message
  // dispatched immediately after registration cannot observe an empty
  // (i.e. "all ready") dependency set by mistake.
  PrepareDeps();
  RegisterMsgHandle([this](const InterceptorMessage& msg) { Compute(msg); });
}
// (Re-)arms the pending-dependency set: every upstream task id of this
// node must report DATA_IS_READY before the next compute step may run.
void ComputeInterceptor::PrepareDeps() {
  for (auto up_id : GetTaskNode()->upstream()) {
    upstream_deps_.insert(up_id);
  }
}
void ComputeInterceptor::SendDataReadyToDownStream() {
auto& downstream = GetTaskNode()->downstream();
for (auto dst_id : downstream) {
InterceptorMessage dst_msg;
dst_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor Send msg to " << dst_id;
Send(dst_id, dst_msg);
}
}
// Message handler. A DATA_IS_READY message removes its sender from the
// pending-dependency set; once the set drains, the node "runs" (op execution
// still TODO), readiness is propagated downstream, and the dependency set is
// re-armed for the next round. All other message types are ignored.
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
  if (msg.message_type() != DATA_IS_READY) {
    return;
  }
  upstream_deps_.erase(msg.src_id());
  if (!upstream_deps_.empty()) {
    // Still waiting on at least one upstream input.
    return;
  }
  // all input is ready
  // TODO(wangxi): op run
  VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
  SendDataReadyToDownStream();
  PrepareDeps();
}
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
// Interceptor for a compute task node: waits until every upstream task has
// reported DATA_IS_READY, then triggers its own step and forwards readiness
// to all downstream tasks.
class ComputeInterceptor : public Interceptor {
 public:
  ComputeInterceptor(int64_t interceptor_id, TaskNode* node);

  // Refill upstream_deps_ from the task node's upstream id set.
  void PrepareDeps();
  // Send a DATA_IS_READY message to every downstream interceptor.
  void SendDataReadyToDownStream();
  // Message handler registered by the constructor; consumes DATA_IS_READY
  // notifications and runs once all upstream dependencies are satisfied.
  void Compute(const InterceptorMessage& msg);

 private:
  // Upstream interceptor ids still expected to report DATA_IS_READY
  // before the next compute step.
  std::unordered_set<int64_t> upstream_deps_;
};
} // namespace distributed
} // namespace paddle
......@@ -76,7 +76,7 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) {
void Interceptor::PoolTheMailbox() {
// pool the local mailbox, parse the Message
while (true) {
for (;;) {
if (local_mailbox_.empty()) {
// local mailbox is empty, fetch the remote mailbox
VLOG(3) << interceptor_id_ << "'s local mailbox is empty. "
......
......@@ -62,6 +62,9 @@ class Interceptor {
DISABLE_COPY_AND_ASSIGN(Interceptor);
protected:
TaskNode* GetTaskNode() const { return node_; }
private:
// pool the local mailbox, parse the Message
void PoolTheMailbox();
......@@ -114,19 +117,30 @@ class InterceptorFactory {
int64_t id, TaskNode* node);
};
// Generic factory function used by REGISTER_INTERCEPTOR: constructs an
// InterceptorClass (which must derive from Interceptor) for the given
// interceptor id and task node, returned through the base-class pointer.
template <typename InterceptorClass>
std::unique_ptr<Interceptor> CreatorInterceptor(int64_t id, TaskNode* node) {
  return std::make_unique<InterceptorClass>(id, node);
}
// Registers `interceptor_class` with the InterceptorFactory under the string
// name `interceptor_type`. A file-scope registrar object performs the
// registration at static-initialization time, and
// TouchRegisterInterceptor_* gives USE_INTERCEPTOR an external symbol to
// reference so the linker cannot drop this translation unit.
// NOTE: the pre-change per-type CreatorInterceptor_##type function was diff
// residue; the template CreatorInterceptor<interceptor_class> replaces it.
#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class)          \
  class __RegisterInterceptor_##interceptor_type {                         \
   public:                                                                 \
    __RegisterInterceptor_##interceptor_type() {                           \
      InterceptorFactory::Register(#interceptor_type,                      \
                                   CreatorInterceptor<interceptor_class>); \
    }                                                                      \
    void Touch() {}                                                        \
  };                                                                       \
  __RegisterInterceptor_##interceptor_type g_register_##interceptor_type;  \
  int TouchRegisterInterceptor_##interceptor_type() {                      \
    g_register_##interceptor_type.Touch();                                 \
    return 0;                                                              \
  }
// Forces the registration of `interceptor_type` to be linked in: referencing
// TouchRegisterInterceptor_##interceptor_type from the using translation unit
// prevents the linker from discarding the object file that holds the
// corresponding REGISTER_INTERCEPTOR registrar.
#define USE_INTERCEPTOR(interceptor_type)                   \
  extern int TouchRegisterInterceptor_##interceptor_type(); \
  UNUSED static int use_interceptor_##interceptor_type =    \
      TouchRegisterInterceptor_##interceptor_type();
} // namespace distributed
} // namespace paddle
......@@ -15,8 +15,10 @@
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/platform/macros.h"
namespace paddle {
......@@ -33,6 +35,7 @@ class TaskNode final {
TaskNode(int32_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
~TaskNode() = default;
int64_t rank() const { return rank_; }
int64_t task_id() const { return task_id_; }
int32_t role() const { return role_; }
......@@ -40,9 +43,12 @@ class TaskNode final {
int64_t max_slot_nums() const { return max_slot_nums_; }
const std::unordered_set<int64_t>& upstream() const { return upstream_; }
const std::unordered_set<int64_t>& downstream() const { return downstream_; }
const std::string& type() const { return type_; }
void AddUpstreamTask(int64_t task_id);
void AddDownstreamTask(int64_t task_id);
std::string DebugString() const;
static std::unique_ptr<TaskNode> CreateEmptyTaskNode(int32_t role,
int64_t rank,
int64_t task_id,
......@@ -63,6 +69,8 @@ class TaskNode final {
int64_t task_id_;
int64_t max_run_times_;
int64_t max_slot_nums_;
std::string type_;
};
} // namespace distributed
......
# Test sources need the same relaxed dtor warning flags as the library
# (DISTRIBUTE_COMPILE_FLAGS is defined by the enclosing WITH_DISTRIBUTE block).
set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS})
cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
namespace paddle {
namespace distributed {
class StopInterceptor : public Interceptor {
public:
StopInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { Stop(msg); });
}
void Stop(const InterceptorMessage& msg) {
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
Send(2, stop);
}
};
// End-to-end check of a three-node pipeline a->b->c: a DATA_IS_READY message
// injected toward interceptor 1 must drive b's Compute() (its only upstream
// is 0), which forwards readiness to c, whose StopInterceptor broadcasts STOP.
TEST(ComputeInterceptor, Compute) {
  MessageBus& msg_bus = MessageBus::Instance();
  // All three interceptor ids map to rank 0, so every message stays
  // in-process; the address string is a placeholder, never dialed.
  msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");

  Carrier& carrier = Carrier::Instance();

  // NOTE: don't delete, otherwise interceptor will use undefined node
  // Trailing two zeros presumably are max_run_times / max_slot_nums —
  // TODO confirm against the TaskNode constructor.
  TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0);  // role, rank, task_id
  TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
  TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);

  // a->b->c (node_c needs no upstream entry: StopInterceptor ignores deps)
  node_a->AddDownstreamTask(1);
  node_b->AddUpstreamTask(0);
  node_b->AddDownstreamTask(2);

  Interceptor* a = carrier.SetInterceptor(
      0, InterceptorFactory::Create("Compute", 0, node_a));
  carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
  carrier.SetInterceptor(2, std::make_unique<StopInterceptor>(2, node_c));

  // Mark interceptor creation finished before any message is sent.
  carrier.SetCreatingFlag(false);

  // Kick off the pipeline: tell b (id 1) that its upstream (a, id 0) is ready.
  InterceptorMessage msg;
  msg.set_message_type(DATA_IS_READY);
  a->Send(1, msg);
}
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册