未验证 提交 5d0ce171 编写于 作者: Y Yuang Liu 提交者: GitHub

add time wait for message bus (#37809)

上级 075a02d2
...@@ -46,6 +46,8 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -46,6 +46,8 @@ void ComputeInterceptor::PrepareDeps() {
"Source ComputeInterceptor must run at least one " "Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld", "times, but now max_run_times=%ld",
node_->max_run_times())); node_->max_run_times()));
in_readys_.emplace(-1,
std::make_pair(std::numeric_limits<int64_t>::max(), 0));
} }
// If there is no downstream or every downstream is in different rank, // If there is no downstream or every downstream is in different rank,
...@@ -55,14 +57,17 @@ void ComputeInterceptor::PrepareDeps() { ...@@ -55,14 +57,17 @@ void ComputeInterceptor::PrepareDeps() {
} }
void ComputeInterceptor::IncreaseReady(int64_t up_id) { void ComputeInterceptor::IncreaseReady(int64_t up_id) {
// source node has no upstream; data_is_ready is sent by the carrier or others
if (is_source_ && up_id == -1) return;
auto it = in_readys_.find(up_id); auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, in_readys_.end(), PADDLE_ENFORCE_NE(it, in_readys_.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id)); "Cannot find upstream=%lld in in_readys.", up_id));
// source node has no upstream; data_is_ready is sent by the carrier or others
if (is_source_ && up_id == -1) {
it->second.second = GetTaskNode()->max_run_times();
return;
}
auto max_ready_size = it->second.first; auto max_ready_size = it->second.first;
auto ready_size = it->second.second; auto ready_size = it->second.second;
ready_size += 1; ready_size += 1;
...@@ -93,7 +98,11 @@ bool ComputeInterceptor::IsInputReady() { ...@@ -93,7 +98,11 @@ bool ComputeInterceptor::IsInputReady() {
for (auto& ins : in_readys_) { for (auto& ins : in_readys_) {
auto ready_size = ins.second.second; auto ready_size = ins.second.second;
// not ready, return false // not ready, return false
if (ready_size == 0) return false; if (ready_size == 0) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< "'s upstreams aren't all ready.";
return false;
}
} }
return true; return true;
} }
...@@ -103,14 +112,23 @@ bool ComputeInterceptor::CanWriteOutput() { ...@@ -103,14 +112,23 @@ bool ComputeInterceptor::CanWriteOutput() {
auto max_buffer_size = outs.second.first; auto max_buffer_size = outs.second.first;
auto used_size = outs.second.second; auto used_size = outs.second.second;
// full, return false // full, return false
if (used_size == max_buffer_size) return false; if (used_size == max_buffer_size) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< "'s out buffer is full.";
return false;
}
} }
return true; return true;
} }
// only the source node needs reset // only the source node needs reset
bool ComputeInterceptor::ShouldReset() { bool ComputeInterceptor::ShouldReset() {
return is_source_ && (step_ == node_->max_run_times()); if (is_source_ && step_ == node_->max_run_times()) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " should reset for step: " << step_ << ".";
return true;
}
return false;
} }
void ComputeInterceptor::SendDataReadyToDownStream() { void ComputeInterceptor::SendDataReadyToDownStream() {
...@@ -130,7 +148,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() { ...@@ -130,7 +148,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
InterceptorMessage ready_msg; InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY); ready_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor " << interceptor_id_ VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id; << " Send data_is_ready msg to " << down_id
<< " for step: " << step_;
Send(down_id, ready_msg); Send(down_id, ready_msg);
} }
} }
...@@ -147,23 +166,43 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { ...@@ -147,23 +166,43 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
ready_size)); ready_size));
ins.second.second = ready_size; ins.second.second = ready_size;
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id
<< " for step: " << step_;
if (up_id == -1) return;
InterceptorMessage reply_msg; InterceptorMessage reply_msg;
reply_msg.set_message_type(DATE_IS_USELESS); reply_msg.set_message_type(DATE_IS_USELESS);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id;
Send(up_id, reply_msg); Send(up_id, reply_msg);
} }
} }
void ComputeInterceptor::RunOps() { void ComputeInterceptor::RunOps() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the " VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ << " time."; << step_ + 1 << " time.";
for (auto op : node_->ops()) { for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_); op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
} }
} }
void ComputeInterceptor::Run() { void ComputeInterceptor::Run() {
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run max_run_times.
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is still in use
if (out_buff.second.second != 0) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " out buffer for downstream: " << out_buff.first
<< "'s counter is: " << out_buff.second.second
<< ". Cannot be reset.";
return;
}
}
step_ = 0; // reset
}
while (IsInputReady() && CanWriteOutput() && !ShouldReset()) { while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
...@@ -181,18 +220,6 @@ void ComputeInterceptor::Run() { ...@@ -181,18 +220,6 @@ void ComputeInterceptor::Run() {
StopCarrier(); StopCarrier();
} }
} }
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run max_run_times.
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is still in use
if (out_buff.second.second != 0) return;
}
step_ = 0; // reset
return;
}
} }
void ComputeInterceptor::ReceivedStop(int64_t up_id) { void ComputeInterceptor::ReceivedStop(int64_t up_id) {
......
...@@ -109,6 +109,15 @@ void FleetExecutor::Run() { ...@@ -109,6 +109,15 @@ void FleetExecutor::Run() {
message_bus_instance.IsInit(), true, message_bus_instance.IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet.")); platform::errors::Unavailable("MessageBus has not been init yet."));
carrier_instance.Start(); carrier_instance.Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after running the executor,
// because some operators, such as while_op, may create local scopes when
// running. But when while_op also creates a local executor to run its sub
// block, the sub scopes it created should not be dropped immediately, because
// while_grad_op will use some variables created during the while_op run, so
// we need to keep the kids and wait for the outer executor to drop them.
micro_scop->DropKids();
}
} }
void FleetExecutor::CopyParameters(int microbatch_id, void FleetExecutor::CopyParameters(int microbatch_id,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <chrono> #include <chrono>
#include <memory> #include <memory>
#include <set>
#include <thread> #include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/carrier.h"
...@@ -56,11 +57,11 @@ void MessageBus::Init( ...@@ -56,11 +57,11 @@ void MessageBus::Init(
bool MessageBus::IsInit() const { return is_init_; } bool MessageBus::IsInit() const { return is_init_; }
MessageBus::~MessageBus() { MessageBus::~MessageBus() {
VLOG(3) << "Message bus releases resource.";
// NOTE: fleet_executor inits carrier before message bus, // NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first // therefore the message bus's destructor will be called first
Carrier& carrier = Carrier::Instance(); Carrier& carrier = Carrier::Instance();
carrier.Release(); carrier.Release();
VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \ #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL) !defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000); server_.Stop(1000);
...@@ -90,6 +91,8 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) { ...@@ -90,6 +91,8 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
<< retry_time << " times retries."; << retry_time << " times retries.";
return true; return true;
} }
VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
} }
VLOG(3) << "Message bus sends inter rank fail after 10 times retries."; VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
return false; return false;
...@@ -121,16 +124,40 @@ void MessageBus::ListenPort() { ...@@ -121,16 +124,40 @@ void MessageBus::ListenPort() {
brpc::ServerOptions options; brpc::ServerOptions options;
options.idle_timeout_sec = -1; options.idle_timeout_sec = -1;
int retry_times = 0; int retry_times = 0;
int interval = 1000; int interval = 100;
while (server_.Start(ip_for_brpc, &options) != 0) { while (server_.Start(ip_for_brpc, &options) != 0) {
++retry_times; ++retry_times;
LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
<< " times. And will retry after " << interval / 1000 << " times. And will retry after " << interval / 1000
<< " seconds."; << " seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(interval)); std::this_thread::sleep_for(std::chrono::milliseconds(interval));
interval += 2000; interval += 500;
} }
LOG(INFO) << "Message bus's listen port thread starts successful."; LOG(INFO) << "Message bus's listen port thread starts successful.";
std::set<int64_t> visit;
InterceptorMessage tmp_msg;
tmp_msg.set_ctrl_message(true);
for (auto pair : interceptor_id_to_rank_) {
if (rank_to_addr_.at(pair.second) == addr_) {
tmp_msg.set_src_id(pair.first);
}
}
for (auto pair : interceptor_id_to_rank_) {
int64_t rank = pair.second;
if (rank_to_addr_.at(rank) == addr_) {
continue;
}
tmp_msg.set_dst_id(pair.first);
if (visit.find(rank) == visit.end()) {
VLOG(3) << "Message bus is testing connection for rank: " << rank << ".";
visit.insert(rank);
while (!Send(tmp_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus has connected to rank: " << rank << ".";
}
}
#else #else
LOG(WARNING) LOG(WARNING)
<< "Fleet executor's ListenPort() is a fake function when Paddle is " << "Fleet executor's ListenPort() is a fake function when Paddle is "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册