未验证 提交 5d0ce171 编写于 作者: Y Yuang Liu 提交者: GitHub

add time wait for message bus (#37809)

上级 075a02d2
......@@ -46,6 +46,8 @@ void ComputeInterceptor::PrepareDeps() {
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld",
node_->max_run_times()));
in_readys_.emplace(-1,
std::make_pair(std::numeric_limits<int64_t>::max(), 0));
}
// If there is no downstream or every downstream is in different rank,
......@@ -55,14 +57,17 @@ void ComputeInterceptor::PrepareDeps() {
}
void ComputeInterceptor::IncreaseReady(int64_t up_id) {
// source node has no upstream, data_is_ready is sent by carrier or others
if (is_source_ && up_id == -1) return;
auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, in_readys_.end(),
platform::errors::NotFound(
"Cannot find upstream=%lld in in_readys.", up_id));
// source node has no upstream, data_is_ready is sent by carrier or others
if (is_source_ && up_id == -1) {
it->second.second = GetTaskNode()->max_run_times();
return;
}
auto max_ready_size = it->second.first;
auto ready_size = it->second.second;
ready_size += 1;
......@@ -93,7 +98,11 @@ bool ComputeInterceptor::IsInputReady() {
for (auto& ins : in_readys_) {
auto ready_size = ins.second.second;
// not ready, return false
if (ready_size == 0) return false;
if (ready_size == 0) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< "'s upstreams aren't all ready.";
return false;
}
}
return true;
}
......@@ -103,14 +112,23 @@ bool ComputeInterceptor::CanWriteOutput() {
auto max_buffer_size = outs.second.first;
auto used_size = outs.second.second;
// full, return false
if (used_size == max_buffer_size) return false;
if (used_size == max_buffer_size) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< "'s out buffer is full.";
return false;
}
}
return true;
}
// Only a source node needs a reset.
// Returns true only when this interceptor is a source node (is_source_) and
// step_ equals node_->max_run_times(); returns false otherwise.
bool ComputeInterceptor::ShouldReset() {
// NOTE(review): this listing appears to be a diff rendered without +/-
// markers. The one-line return below is presumably the pre-change body and
// the if-block after it its replacement, which evaluates the same condition
// and additionally emits a VLOG(3) trace -- confirm against the repository.
return is_source_ && (step_ == node_->max_run_times());
if (is_source_ && step_ == node_->max_run_times()) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " should reset for step: " << step_ << ".";
return true;
}
return false;
}
void ComputeInterceptor::SendDataReadyToDownStream() {
......@@ -130,7 +148,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id;
<< " Send data_is_ready msg to " << down_id
<< " for step: " << step_;
Send(down_id, ready_msg);
}
}
......@@ -147,23 +166,43 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
ready_size));
ins.second.second = ready_size;
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id
<< " for step: " << step_;
if (up_id == -1) return;
InterceptorMessage reply_msg;
reply_msg.set_message_type(DATE_IS_USELESS);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Reply data_is_useless msg to " << up_id;
Send(up_id, reply_msg);
}
}
// Runs every op returned by node_->ops() once, in iteration order. Each op is
// run with place_ and the scope selected from microbatch_scopes_ by
// step_ % node_->max_run_times().
void ComputeInterceptor::RunOps() {
// NOTE(review): this listing appears to be a diff rendered without +/-
// markers. The two `<<` continuation lines below are presumably the old
// (step_) and new (step_ + 1) versions of the same log statement; only one
// of them belongs to the real source -- confirm against the repository.
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the "
<< step_ << " time.";
<< step_ + 1 << " time.";
for (auto op : node_->ops()) {
// The modulo keeps the scope index below node_->max_run_times().
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
}
}
void ComputeInterceptor::Run() {
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run max_run_times.
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is using
if (out_buff.second.second != 0) {
VLOG(3) << "Interceptor " << GetInterceptorId()
<< " out buffer for downstream: " << out_buff.first
<< "'s counter is: " << out_buff.second.second
<< ". Cannot be reset.";
return;
}
}
step_ = 0; // reset
}
while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
......@@ -181,18 +220,6 @@ void ComputeInterceptor::Run() {
StopCarrier();
}
}
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run max_run_times.
if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is using
if (out_buff.second.second != 0) return;
}
step_ = 0; // reset
return;
}
}
void ComputeInterceptor::ReceivedStop(int64_t up_id) {
......
......@@ -109,6 +109,15 @@ void FleetExecutor::Run() {
message_bus_instance.IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
carrier_instance.Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after running the executor,
// because some operators may create a local scope when running, such as
// while_op. But when while_op also creates a local executor to run its sub
// block, the sub scopes it created should not be dropped immediately, because
// while_grad_op will use some variables created during the while_op run, so
// we need to keep the kids and wait for the outer executor to drop them.
micro_scop->DropKids();
}
}
void FleetExecutor::CopyParameters(int microbatch_id,
......
......@@ -14,6 +14,7 @@
#include <chrono>
#include <memory>
#include <set>
#include <thread>
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
......@@ -56,11 +57,11 @@ void MessageBus::Init(
// Returns the current value of the is_init_ member; does not modify the bus.
bool MessageBus::IsInit() const {
  return is_init_;
}
MessageBus::~MessageBus() {
VLOG(3) << "Message bus releases resource.";
// NOTE: fleet_executor inits carrier before message bus,
// therefore the message bus's destructor will be called first
Carrier& carrier = Carrier::Instance();
carrier.Release();
VLOG(3) << "Message bus releases resource.";
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
!defined(PADDLE_WITH_ASCEND_CL)
server_.Stop(1000);
......@@ -90,6 +91,8 @@ bool MessageBus::Send(const InterceptorMessage& interceptor_message) {
<< retry_time << " times retries.";
return true;
}
VLOG(3) << "Message bus sends failed, retry after 1 seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus sends inter rank fail after 10 times retries.";
return false;
......@@ -121,16 +124,40 @@ void MessageBus::ListenPort() {
brpc::ServerOptions options;
options.idle_timeout_sec = -1;
int retry_times = 0;
int interval = 1000;
int interval = 100;
while (server_.Start(ip_for_brpc, &options) != 0) {
++retry_times;
LOG(INFO) << "Message bus is retring for starting brpc for " << retry_times
<< " times. And will retry after " << interval / 1000
<< " seconds.";
std::this_thread::sleep_for(std::chrono::milliseconds(interval));
interval += 2000;
interval += 500;
}
LOG(INFO) << "Message bus's listen port thread starts successful.";
std::set<int64_t> visit;
InterceptorMessage tmp_msg;
tmp_msg.set_ctrl_message(true);
for (auto pair : interceptor_id_to_rank_) {
if (rank_to_addr_.at(pair.second) == addr_) {
tmp_msg.set_src_id(pair.first);
}
}
for (auto pair : interceptor_id_to_rank_) {
int64_t rank = pair.second;
if (rank_to_addr_.at(rank) == addr_) {
continue;
}
tmp_msg.set_dst_id(pair.first);
if (visit.find(rank) == visit.end()) {
VLOG(3) << "Message bus is testing connection for rank: " << rank << ".";
visit.insert(rank);
while (!Send(tmp_msg)) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
VLOG(3) << "Message bus has connected to rank: " << rank << ".";
}
}
#else
LOG(WARNING)
<< "Fleet executor's ListenPort() is a fake function when Paddle is "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册