diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc
index 8a4f10473e3d279757fe429bb122aef1064c287a..73f22592dc3a7580c5ff139d3a52502c327680a9 100644
--- a/paddle/fluid/distributed/fleet_executor/carrier.cc
+++ b/paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -16,6 +16,7 @@
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
 #include "paddle/fluid/distributed/fleet_executor/interceptor_message_service.h"
 #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
+#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
 #include "paddle/fluid/framework/scope.h"
 
@@ -24,14 +25,14 @@ namespace distributed {
 
 USE_INTERCEPTOR(Compute);
 
-void Carrier::Init(
-    const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
-    framework::Scope* root_scope, framework::Scope* minibatch_scope,
-    const std::vector<framework::Scope*>& microbatch_scopes,
-    const platform::Place& place) {
+void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
+                   framework::Scope* root_scope,
+                   framework::Scope* minibatch_scope,
+                   const std::vector<framework::Scope*>& microbatch_scopes,
+                   const platform::Place& place) {
   PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
                                          "Carrier is already init."));
-  interceptor_id_to_node_ = interceptor_id_to_node;
+  runtime_graph_ = runtime_graph;
   minibatch_scope_ = minibatch_scope;
   microbatch_scopes_ = microbatch_scopes;
   place_ = place;
@@ -41,15 +42,34 @@ void Carrier::Init(
   is_init_ = true;
 }
 
-Carrier::~Carrier() {
+void Carrier::Release() {
   // NOTE(wangxi): must join before `Derived Interceptor` destruct,
   // otherwise Derived object will be destructed before thread complete.
+
+  // Sending STOP msg to the source interceptor
+  MessageBus& msg_bus = MessageBus::Instance();
+  PADDLE_ENFORCE_EQ(msg_bus.IsInit(), true,
+                    platform::errors::PreconditionNotMet(
+                        "Message bus has not been initialized."));
+  for (int64_t id : source_interceptor_ids_) {
+    VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
+            << ".";
+    InterceptorMessage stop_msg;
+    // source node STOP is sent by carrier, so set src_id=-1
+    stop_msg.set_src_id(-1);
+    stop_msg.set_dst_id(id);
+    stop_msg.set_message_type(STOP);
+    msg_bus.Send(stop_msg);
+  }
+
+  // TODO(wangxi): Maybe need a better way to use threads.
   for (auto& interceptor : interceptor_idx_to_interceptor_) {
     interceptor.second->Join();
   }
 }
 
+Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }
+
 bool Carrier::EnqueueInterceptorMessage(
     const InterceptorMessage& interceptor_message) {
   // enqueue message to interceptor
@@ -139,6 +159,17 @@ void Carrier::SetCreatingFlag(bool flag) {
   creating_interceptors_ = flag;
   creating_flag_mutex_.unlock();
   if (!flag) {
+    for (auto& pair : interceptor_idx_to_interceptor_) {
+      // update the source interceptor id
+      if (std::find(source_interceptor_ids_.begin(),
+                    source_interceptor_ids_.end(),
+                    pair.first) == source_interceptor_ids_.end()) {
+        auto task = pair.second->GetTaskNode();
+        if (task != nullptr && task->upstream().empty()) {
+          source_interceptor_ids_.emplace_back(pair.first);
+        }
+      }
+    }
     // finish create interceptors outside, handle tmp messsages
     HandleTmpMessages();
   }
@@ -161,9 +192,9 @@ void Carrier::HandleTmpMessages() {
 
 void Carrier::CreateInterceptors() {
   // create each Interceptor
-  if (!interceptor_id_to_node_.empty()) {
+  if (!(runtime_graph_->intercepter_id_to_node().empty())) {
     // no auto init since there is no config
-    for (const auto& item : interceptor_id_to_node_) {
+    for (const auto& item : runtime_graph_->intercepter_id_to_node()) {
       int64_t interceptor_id = item.first;
       TaskNode* task_node = item.second;
 
diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h
index b5976b297f91394f8317293c58907548ba47b08f..0c54201c94034f4ceaca1ba720dce22a81fe417d 100644
--- a/paddle/fluid/distributed/fleet_executor/carrier.h
+++ b/paddle/fluid/distributed/fleet_executor/carrier.h
@@ -39,6 +39,7 @@ namespace distributed {
 
 class TaskNode;
 class InterceptorMessageServiceImpl;
+class RuntimeGraph;
 
 // A singleton MessageBus
 class Carrier final {
@@ -48,13 +49,13 @@ class Carrier final {
     return carrier;
   }
 
-  void Init(
-      const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node,
-      framework::Scope* root_scope, framework::Scope* minibatch_scope,
-      const std::vector<framework::Scope*>& microbatch_scopes,
-      const platform::Place& place);
+  void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
+            framework::Scope* root_scope, framework::Scope* minibatch_scope,
+            const std::vector<framework::Scope*>& microbatch_scopes,
+            const platform::Place& place);
 
   ~Carrier();
+  void Release();
 
   // Enqueue a message to corresponding interceptor id
   bool EnqueueInterceptorMessage(const InterceptorMessage& interceptor_message);
@@ -84,9 +85,6 @@ class Carrier final {
 
   void HandleTmpMessages();
 
-  // interceptor logic id to the Nodes info
-  std::unordered_map<int64_t, TaskNode*> interceptor_id_to_node_;
-
   // interceptor logic id to actually interceptor
   std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
       interceptor_idx_to_interceptor_;
 
@@ -105,7 +103,8 @@ class Carrier final {
   framework::Scope* root_scope_;
   framework::Scope* minibatch_scope_;
   paddle::platform::Place place_;
-  paddle::platform::DeviceContext* dev_ctx_ = nullptr;
+  paddle::platform::DeviceContext* dev_ctx_{nullptr};
+  std::shared_ptr<RuntimeGraph> runtime_graph_;
 };
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
index fd55aa2aa1c4656a5b14a3eb936f18e77076076e..3d4078c932f7021ef4e5dfd69a6fb834ecbe9001 100644
--- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
+++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
@@ -51,6 +51,11 @@ void ComputeInterceptor::PrepareDeps() {
                           "times, but now max_run_times=%ld",
                           node_->max_run_times()));
   }
+
+  // If there is no downstream or every downstream is in a different rank,
+  // then this interceptor is the last one for current rank.
+  // This can be determined during init and cached for later use.
+  is_last_ = downstream.empty();
 }
 
 void ComputeInterceptor::IncreaseReady(int64_t up_id) {
@@ -129,7 +134,8 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
 
     InterceptorMessage ready_msg;
     ready_msg.set_message_type(DATA_IS_READY);
-    VLOG(3) << "ComputeInterceptor Send data_is_ready msg to " << down_id;
+    VLOG(3) << "ComputeInterceptor " << interceptor_id_
+            << " Send data_is_ready msg to " << down_id;
     Send(down_id, ready_msg);
   }
 }
@@ -148,7 +154,8 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
 
     InterceptorMessage reply_msg;
     reply_msg.set_message_type(DATE_IS_USELESS);
-    VLOG(3) << "ComputeInterceptor Reply data_is_useless msg to " << up_id;
+    VLOG(3) << "ComputeInterceptor " << interceptor_id_
+            << " Reply data_is_useless msg to " << up_id;
     Send(up_id, reply_msg);
   }
 }
@@ -159,7 +166,7 @@ void ComputeInterceptor::Run() {
 
   // step_ %= node_->max_run_times();
   for (auto op : node_->ops()) {
-    auto* scope = microbatch_scopes_[step_ % node_->max_slot_nums()];
+    auto* scope = microbatch_scopes_[step_ % node_->max_run_times()];
     op->Run(*scope, place_);
   }
   ++step_;
@@ -168,6 +175,10 @@ void ComputeInterceptor::Run() {
   SendDataReadyToDownStream();
   // reply to upstream and decrease ready data
   ReplyCompletedToUpStream();
+  // Try to stop Carrier
+  if (step_ % node_->max_run_times() == 0 && is_last_) {
+    StopCarrier();
+  }
 }
 
 // If there is no limit, source interceptor can be executed
@@ -221,11 +232,6 @@ void ComputeInterceptor::TryStop() {
     Send(down_id, stop);
   }
   stop_ = true;
-
-  if (out_buffs_.size() == 0) {
-    // TODO(fleet executor dev) need a better place to notify
-    StopCarrier();
-  }
 }
 
 void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h
index 97e6da2f00eaead14a075de86ce552b87c30633f..8ed443ca971fb1355798869c79bdb5923aba6201 100644
--- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h
+++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h
@@ -44,6 +44,7 @@ class ComputeInterceptor : public Interceptor {
 
  private:
   bool is_source_{false};
+  bool is_last_{false};
   int64_t step_{0};
 
   // upstream_id-->(max_ready_size, ready_size)
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
index ec60ec5fd5901a9eff154aadca487ab449494b30..e84e37a58eb5cbc7657c61326a473ca17f1292f7 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc
@@ -38,7 +38,7 @@ FleetExecutor::~FleetExecutor() {
 
 void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
                          framework::Scope* scope, const platform::Place& place) {
-  runtime_graph_ = std::make_unique<RuntimeGraph>(program_desc, exe_desc_);
+  runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
   root_scope_ = scope;
   place_ = place;
   PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
@@ -58,8 +58,8 @@ void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
 
 void FleetExecutor::InitCarrier() {
   Carrier& carrier_instance = Carrier::Instance();
   if (!carrier_instance.IsInit()) {
-    carrier_instance.Init(runtime_graph_->intercepter_id_to_node(), root_scope_,
-                          minibatch_scope_, microbatch_scopes_, place_);
+    carrier_instance.Init(runtime_graph_, root_scope_, minibatch_scope_,
+                          microbatch_scopes_, place_);
   }
 }
diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h
index 7be18772e9ec9f85f334609529730f9b7867baab..cee739506b7e62564561d3d16ff0d90b8b3d953b 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h
@@ -47,7 +47,7 @@ class FleetExecutor final {
   void InitCarrier();
   void CopyParameters(int microbatch_id, const framework::ProgramDesc& program);
   FleetExecutorDesc exe_desc_;
-  std::unique_ptr<RuntimeGraph> runtime_graph_;
+  std::shared_ptr<RuntimeGraph> runtime_graph_;
   framework::Scope* root_scope_;
   framework::Scope* minibatch_scope_;
   platform::Place place_;
diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc
index 63c2bb3fc6eecb3f79896e9116806fb5dc494028..26927f34c6879b1db65a4f242cf8aa1aae6e41d2 100644
--- a/paddle/fluid/distributed/fleet_executor/interceptor.cc
+++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc
@@ -46,7 +46,6 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
     VLOG(3) << "Interceptor is using default message handler. This handler is "
                "only used for test purpose. Check whether you init interceptor "
                "in the proper way.";
-
     if (msg.message_type() == DATA_IS_READY) {
       if (node_->role() != 2) {
         VLOG(3) << "Fake handler is sending DATA_IS_READY message to: "
@@ -54,14 +53,19 @@ void Interceptor::Handle(const InterceptorMessage& msg) {
         InterceptorMessage data_is_ready_msg;
         data_is_ready_msg.set_message_type(DATA_IS_READY);
         Send(interceptor_id_ + 1, data_is_ready_msg);
+      } else {
+        // NOTE: max run times is reached for the last interceptor
+        StopCarrier();
       }
-      VLOG(3) << "Fake handler is sending stop message to it self.";
-      InterceptorMessage stop_msg;
-      stop_msg.set_message_type(STOP);
-      Send(interceptor_id_, stop_msg);
     } else if (msg.message_type() == STOP) {
       stop_ = true;
-      StopCarrier();
+      if (node_->role() != 2) {
+        VLOG(3) << "Fake handler is sending STOP message to: "
+                << interceptor_id_ + 1 << ".";
+        InterceptorMessage stop_msg;
+        stop_msg.set_message_type(STOP);
+        Send(interceptor_id_ + 1, stop_msg);
+      }
     }
   }
 }
diff --git a/paddle/fluid/distributed/fleet_executor/message_bus.cc b/paddle/fluid/distributed/fleet_executor/message_bus.cc
index de2171e68e19e20f0661856916a67189dabb5630..688a6f3a3882183c324d2b05062bf97b0b76602b 100644
--- a/paddle/fluid/distributed/fleet_executor/message_bus.cc
+++ b/paddle/fluid/distributed/fleet_executor/message_bus.cc
@@ -57,6 +57,10 @@ bool MessageBus::IsInit() const { return is_init_; }
 
 MessageBus::~MessageBus() {
   VLOG(3) << "Message bus releases resource.";
+  // NOTE: fleet_executor inits carrier before message bus,
+  // therefore the message bus's destructor will be called first
+  Carrier& carrier = Carrier::Instance();
+  carrier.Release();
 #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) && \
     !defined(PADDLE_WITH_ASCEND_CL)
   server_.Stop(1000);
diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
index 2d9776738f83184bfaadade01dd231712b3b6241..c5348db83e0298db1c25c7424fa0e37c5724c24b 100644
--- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
+++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
@@ -61,15 +61,15 @@ TEST(ComputeInterceptor, Compute) {
   std::vector<framework::Scope*> scopes = {scope, scope};
   platform::Place place = platform::CPUPlace();
 
+  Carrier& carrier = Carrier::Instance();
+
   MessageBus& msg_bus = MessageBus::Instance();
   msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
 
-  Carrier& carrier = Carrier::Instance();
-
   // FIXME: don't delete, otherwise interceptor will use undefined node
   TaskNode* node_a =
-      new TaskNode(0, ops, 0, 0, 2, 2);  // role, ops, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
+      new TaskNode(0, ops, 0, 0, 2, 0);  // role, ops, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
 
   // a->b
   node_a->AddDownstreamTask(1);
@@ -90,13 +90,6 @@ TEST(ComputeInterceptor, Compute) {
   msg.set_src_id(-1);
   msg.set_dst_id(0);
   carrier.EnqueueInterceptorMessage(msg);
-
-  // stop
-  InterceptorMessage stop;
-  stop.set_message_type(STOP);
-  stop.set_src_id(-1);
-  stop.set_dst_id(0);
-  carrier.EnqueueInterceptorMessage(stop);
 }
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
index 3cfd3073c8cb9c1b8537ee9b3c2dc00acab0b192..8f44b2035aea02f36205f6c9d1af0d490979370e 100644
--- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
+++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
@@ -35,31 +35,25 @@ class StartInterceptor : public Interceptor {
   void NOP(const InterceptorMessage& msg) {
     if (msg.message_type() == STOP) {
       stop_ = true;
+      InterceptorMessage stop;
+      stop.set_message_type(STOP);
+      Send(1, stop);  // stop 1, compute
       return;
     }
     std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
               << std::endl;
-    ++count_;
-    if (count_ == 3) {
-      InterceptorMessage stop;
-      stop.set_message_type(STOP);
-      Send(msg.dst_id(), stop);  // stop 0, this
-      Send(msg.src_id(), stop);  // stop 1, compute
-    }
   }
-  int count_{0};
 };
 
 TEST(ComputeInterceptor, Compute) {
+  Carrier& carrier = Carrier::Instance();
   MessageBus& msg_bus = MessageBus::Instance();
   msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");
 
-  Carrier& carrier = Carrier::Instance();
-
   // NOTE: don't delete, otherwise interceptor will use undefined node
-  TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0);  // role, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
-  TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);
+  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
+  TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
 
   // a->b->c
   node_a->AddDownstreamTask(1);
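The patch above reworks the fleet executor's shutdown path: MessageBus's destructor calls Carrier::Release(), Release() sends STOP to every source interceptor, each interceptor forwards STOP to its downstream tasks, and the interceptor with no downstream stops the carrier once it has run max_run_times micro-batches. The sketch below is only a simplified, single-threaded illustration of that STOP propagation; MiniNode, StopMsg, and the queue-based dispatch loop are hypothetical stand-ins, not Paddle's actual Carrier/Interceptor classes.

// Hypothetical, self-contained sketch of the STOP-propagation shutdown flow.
#include <cstdint>
#include <cstdio>
#include <map>
#include <queue>
#include <vector>

struct MiniNode {
  int64_t id;
  std::vector<int64_t> downstreams;  // empty => last interceptor on this rank
};

struct StopMsg {
  int64_t dst;
};

int main() {
  // Assumed pipeline 0 -> 1 -> 2: node 0 is the source, node 2 is the last.
  std::map<int64_t, MiniNode> nodes = {
      {0, {0, {1}}}, {1, {1, {2}}}, {2, {2, {}}}};
  bool carrier_running = true;

  // Carrier::Release(): enqueue a STOP message for every source interceptor.
  std::queue<StopMsg> bus;
  bus.push({0});

  // Interceptors forward STOP downstream; the one without downstream tasks
  // stops the carrier, analogous to StopCarrier() in ComputeInterceptor.
  while (!bus.empty()) {
    StopMsg msg = bus.front();
    bus.pop();
    const MiniNode& node = nodes.at(msg.dst);
    std::printf("interceptor %lld received STOP\n",
                static_cast<long long>(node.id));
    if (node.downstreams.empty()) {
      carrier_running = false;
    }
    for (int64_t down : node.downstreams) {
      bus.push({down});
    }
  }
  std::printf("carrier running: %s\n", carrier_running ? "yes" : "no");
  return 0;
}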