未验证 提交 0adc2006 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] auto STOP msg and auto notify carrier (#37742)

上级 79095918
......@@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/scope.h"
......@@ -24,14 +25,14 @@ namespace distributed {
USE_INTERCEPTOR(Compute);
void Carrier::Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope,
framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init."));
interceptor_id_to_node_ = interceptor_id_to_node;
runtime_graph_ = runtime_graph;
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
......@@ -41,15 +42,34 @@ void Carrier::Init(
is_init_ = true;
}
Carrier::~Carrier() {
// Shuts the carrier down: sends a STOP message to every source interceptor
// (the stop signal then propagates downstream through the task graph) and
// joins all interceptor threads. Requires the message bus to still be
// alive, since STOP messages are delivered through it.
void Carrier::Release() {
// NOTE(wangxi): must join threads before any `Derived Interceptor` is
// destructed, otherwise the derived object may be destroyed while its
// worker thread is still running.
// Send the STOP msg to each source interceptor.
MessageBus& msg_bus = MessageBus::Instance();
PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
platform::errors::PreconditionNotMet(
"Message bus has not been initialized."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
InterceptorMessage stop_msg;
// A source node's STOP is sent by the carrier itself, so mark the
// sender as -1 (i.e. not from any interceptor).
stop_msg.set_src_id(-1);
stop_msg.set_dst_id(id);
stop_msg.set_message_type(STOP);
msg_bus.Send(stop_msg);
}
// TODO(wangxi): maybe need a better way to manage interceptor threads.
for (auto& interceptor : interceptor_idx_to_interceptor_) {
interceptor.second->Join();
}
}
// Destructor only logs; the actual teardown (stopping and joining
// interceptors) is done in Release().
Carrier::~Carrier() {
  VLOG(3) << "Carrier's destructor.";
}
bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor
......@@ -139,6 +159,17 @@ void Carrier::SetCreatingFlag(bool flag) {
creating_interceptors_ = flag;
creating_flag_mutex_.unlock();
if (!flag) {
for (auto& pair : interceptor_idx_to_interceptor_) {
// update the source interceptor id
if (std::find(source_interceptor_ids_.begin(),
source_interceptor_ids_.end(),
pair.first) == source_interceptor_ids_.end()) {
auto task = pair.second->GetTaskNode();
if (task != nullptr && task->upstream().empty()) {
source_interceptor_ids_.emplace_back(pair.first);
}
}
}
// finished creating interceptors outside, handle tmp messages
HandleTmpMessages();
}
......@@ -161,9 +192,9 @@ void Carrier::HandleTmpMessages() {
void Carrier::CreateInterceptors() {
// create each Interceptor
if (!interceptor_id_to_node_.empty()) {
if (!(runtime_graph_->intercepter_id_to_node().empty())) {
// no auto init since there is no config
for (const auto& item : interceptor_id_to_node_) {
for (const auto& item : runtime_graph_->intercepter_id_to_node()) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;
......
......@@ -39,6 +39,7 @@ namespace distributed {
class TaskNode;
class InterceptorMessageServiceImpl;
class RuntimeGraph;
// A singleton MessageBus
class Carrier final {
......@@ -48,13 +49,13 @@ class Carrier final {
return carrier;
}
void Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
~Carrier();
void Release();
// Enqueue a message to corresponding interceptor id
bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
......@@ -84,9 +85,6 @@ class Carrier final {
void HandleTmpMessages();
// interceptor logic id to the Nodes info
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
// interceptor logic id to actually interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
......@@ -105,7 +103,8 @@ class Carrier final {
framework::Scope* root_scope_;
framework::Scope* minibatch_scope_;
paddle::platform::Place place_;
paddle::platform::DeviceContext* dev_ctx_ = nullptr;
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
};
} // namespace distributed
......
......@@ -51,6 +51,11 @@ void ComputeInterceptor::PrepareDeps() {
"times, but now max_run_times=%ld",
node_->max_run_times()));
}
// If there is no downstream, or every downstream is in a different rank,
// then this interceptor is the last one for the current rank.
// This can be determined during init and cached for later use.
is_last_ = downstream.empty();
}
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
......@@ -129,7 +134,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor Send data_is_ready msg to " << down_id;
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id;
Send(down_id, ready_msg);
}
}
......@@ -148,7 +154,8 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
InterceptorMessage reply_msg;
reply_msg.set_message_type(DATE_IS_USELESS);
VLOG(3) << "ComputeInterceptor Reply data_is_useless msg to " << up_id;
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id;
Send(up_id, reply_msg);
}
}
......@@ -159,7 +166,7 @@ void ComputeInterceptor::Run() {
// step_ %= node_->max_run_times();
for (auto op : node_->ops()) {
auto* scope = microbatch_scopes_[step_ % node_->max_slot_nums()];
auto* scope = microbatch_scopes_[step_ % node_->max_run_times()];
op->Run(*scope, place_);
}
++step_;
......@@ -168,6 +175,10 @@ void ComputeInterceptor::Run() {
SendDataReadyToDownStream();
// reply to upstream and decrease ready data
ReplyCompletedToUpStream();
// Try to stop Carrier
if (step_ % node_->max_run_times() == 0 && is_last_) {
StopCarrier();
}
}
// If there is no limit, source interceptor can be executed
......@@ -221,11 +232,6 @@ void ComputeInterceptor::TryStop() {
Send(down_id, stop);
}
stop_ = true;
if (out_buffs_.size() == 0) {
// TODO(fleet executor dev) need a better place to notify
StopCarrier();
}
}
void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
......
......@@ -44,6 +44,7 @@ class ComputeInterceptor : public Interceptor {
private:
bool is_source_{false};
bool is_last_{false};
int64_t step_{0};
// upstream_id-->(max_ready_size, ready_size)
......
......@@ -38,7 +38,7 @@ FleetExecutor::~FleetExecutor() {
void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
framework::Scope* scope,
const platform::Place& place) {
runtime_graph_ = std::make_unique<RuntimeGraph>(program_desc, exe_desc_);
runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
root_scope_ = scope;
place_ = place;
PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
......@@ -58,8 +58,8 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
void FleetExecutor::InitCarrier() {
Carrier& carrier_instance = Carrier::Instance();
if (!carrier_instance.IsInit()) {
carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
carrier_instance.Init(runtime_graph_, root_scope_, minibatch_scope_,
microbatch_scopes_, place_);
}
}
......
......@@ -47,7 +47,7 @@ class FleetExecutor final {
void InitCarrier();
void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
FleetExecutorDesc exe_desc_;
std::unique_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<RuntimeGraph> runtime_graph_;
framework::Scope* root_scope_;
framework::Scope* minibatch_scope_;
platform::Place place_;
......
......@@ -46,7 +46,6 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
VLOG(3) << "Interceptor is using default message handler. This handler is "
"only used for test purpose. Check whether you init interceptor "
"in the proper way.";
if (msg.message_type() == DATA_IS_READY) {
if (node_->role() != 2) {
VLOG(3) << "Fake handler is sending DATA_IS_READY message to: "
......@@ -54,14 +53,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
InterceptorMessage data_is_ready_msg;
data_is_ready_msg.set_message_type(DATA_IS_READY);
Send(interceptor_id_ + 1, data_is_ready_msg);
} else {
// NOTE: max run times is reached for the last interceptor
StopCarrier();
}
VLOG(3) << "Fake handler is sending stop message to it self.";
InterceptorMessage stop_msg;
stop_msg.set_message_type(STOP);
Send(interceptor_id_, stop_msg);
} else if (msg.message_type() == STOP) {
stop_ = true;
StopCarrier();
if (node_->role() != 2) {
VLOG(3) << "Fake handler is sending STOP message to: "
<< interceptor_id_ + 1 << ".";
InterceptorMessage stop_msg;
stop_msg.set_message_type(STOP);
Send(interceptor_id_ + 1, stop_msg);
}
}
}
}
......
......@@ -57,6 +57,10 @@ bool MessageBus::IsInit() const { return is_init_; }
MessageBus::~MessageBus() {
VLOG(3) << "Message bus releases resource.";
// NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first
Carrier& carrier = Carrier::Instance();
carrier.Release();
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000);
......
......@@ -61,15 +61,15 @@ TEST(ComputeInterceptor, Compute) {
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();
Carrier& carrier = Carrier::Instance();
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
Carrier& carrier = Carrier::Instance();
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 2); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
// a->b
node_a->AddDownstreamTask(1);
......@@ -90,13 +90,6 @@ TEST(ComputeInterceptor, Compute) {
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);
// stop
InterceptorMessage stop;
stop.set_message_type(STOP);
stop.set_src_id(-1);
stop.set_dst_id(0);
carrier.EnqueueInterceptorMessage(stop);
}
} // namespace distributed
......
......@@ -35,31 +35,25 @@ class StartInterceptor : public Interceptor {
void NOP(const InterceptorMessage& msg) {
if (msg.message_type() == STOP) {
stop_ = true;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(1, stop); // stop 1, compute
return;
}
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
++count_;
if (count_ == 3) {
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(msg.dst_id(), stop); // stop 0, this
Send(msg.src_id(), stop); // stop 1, compute
}
}
int count_{0};
};
TEST(ComputeInterceptor, Compute) {
Carrier& carrier = Carrier::Instance();
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
Carrier& carrier = Carrier::Instance();
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
// a->b->c
node_a->AddDownstreamTask(1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册