Unverified commit 0074a3c9, authored by LiYuRio, committed by GitHub

[Fleet Executor] Refine runtime graph (#37703)

Parent: bfb85779
@@ -12,7 +12,7 @@ endif()
 cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
            interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
-           DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper
+           DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
            ${BRPC_DEPS})
 if(WITH_DISTRIBUTE)
...
@@ -31,9 +31,7 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
                         "Error occurs while parsing string to proto"));
 }
 
-FleetExecutor::~FleetExecutor() {
-  // Destroy Executor
-}
+FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); }
 
 void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
                          framework::Scope* scope,
@@ -113,8 +111,6 @@ void FleetExecutor::Run() {
     carrier_instance.Start();
 }
 
-void FleetExecutor::Release() { root_scope_->DropKids(); }
-
 void FleetExecutor::CopyParameters(int microbatch_id,
                                    const framework::ProgramDesc& program) {
   auto& global_block = program.Block(0);
...
@@ -39,7 +39,6 @@ class FleetExecutor final {
   void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
             const platform::Place& place);
   void Run();
-  void Release();
 
  private:
   DISABLE_COPY_AND_ASSIGN(FleetExecutor);
...
@@ -21,7 +21,7 @@ message RankInfo {
 }
 
 message FleetExecutorDesc {
-  optional string grain = 1 [ default = "coarse" ];
+  optional string strategy = 1 [ default = "Origin" ];
   optional int64 cur_rank = 2 [ default = 0 ]; // Rank id of current processor
   repeated RankInfo cluster_info = 3;
   optional int32 dp_degree = 4 [ default = 1 ];
...
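For reference, a minimal Python sketch of how the renamed field could be filled in from the generated proto module before the descriptor is serialized. The field names and the "Origin"/"1F1B" values come from the hunks in this commit; the import path and the endpoint value are illustrative assumptions, not part of the change.

# Sketch only: building a FleetExecutorDesc with the renamed `strategy` field.
# The module path is assumed from the Python-side relative import in this commit.
from paddle.distributed.fleet.proto import fleet_executor_desc_pb2

desc = fleet_executor_desc_pb2.FleetExecutorDesc()
desc.strategy = "Origin"  # the default; "1F1B" selects the pipeline runtime graph
desc.cur_rank = 0
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = 0
rank_info.ip_port = "127.0.0.1:6170"  # placeholder endpoint
desc.cluster_info.append(rank_info)
desc_str = desc.SerializeToString()  # what core.FleetExecutor(...) consumes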
@@ -40,34 +40,9 @@ void Interceptor::Join() {
 void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
 
 void Interceptor::Handle(const InterceptorMessage& msg) {
-  if (handle_) {
-    handle_(msg);
-  } else {
-    VLOG(3) << "Interceptor is using default message handler. This handler is "
-               "only used for test purpose. Check whether you init interceptor "
-               "in the proper way.";
-    if (msg.message_type() == DATA_IS_READY) {
-      if (node_->role() != 2) {
-        VLOG(3) << "Fake handler is sending DATA_IS_READY message to: "
-                << interceptor_id_ + 1 << ".";
-        InterceptorMessage data_is_ready_msg;
-        data_is_ready_msg.set_message_type(DATA_IS_READY);
-        Send(interceptor_id_ + 1, data_is_ready_msg);
-      } else {
-        // NOTE: max run time is reach for last interceptor
-        StopCarrier();
-      }
-    } else if (msg.message_type() == STOP) {
-      stop_ = true;
-      if (node_->role() != 2) {
-        VLOG(3) << "Fake handler is sending STOP message to: "
-                << interceptor_id_ + 1 << ".";
-        InterceptorMessage stop_msg;
-        stop_msg.set_message_type(STOP);
-        Send(interceptor_id_ + 1, stop_msg);
-      }
-    }
-  }
+  PADDLE_ENFORCE_NOT_NULL(handle_, platform::errors::PreconditionNotMet(
+                                       "Message handle is not registered."));
+  handle_(msg);
 }
 
 void Interceptor::StopCarrier() {
...
@@ -100,11 +100,25 @@ std::vector<OpRole> RuntimeGraph::functionality_order = {
 RuntimeGraph::RuntimeGraph(const ProgramDesc& program,
                            const FleetExecutorDesc& exe_desc)
     : exe_desc_(exe_desc) {
-  if (exe_desc.grain() == "coarse") {
+  if (exe_desc.strategy() == "1F1B") {
     SplitProgramBasedFunctionality(program);
     AssignTaskToIntercepter();
     FakeDependence();
     FakeRuntimeInfo();
+  } else if (exe_desc.strategy() == "Origin") {
+    int64_t cur_rank = exe_desc_.cur_rank();
+    int64_t max_run_times = exe_desc_.num_micro_batches();
+    int64_t max_slot_nums = exe_desc_.num_slots();
+    auto task_node = std::make_unique<TaskNode>(program, cur_rank,
+                                                max_run_times, max_slot_nums);
+    task_node->SetType("Compute");
+    task_nodes_.emplace_back(std::move(task_node));
+    int64_t task_id = task_nodes_[0]->task_id();
+    intercepter_id_to_rank_.insert({task_id, cur_rank});
+    intercepter_id_to_node_.insert({task_id, task_nodes_[0].get()});
+  } else {
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "Strategy %s is None of 1F1B or Origin.", exe_desc.strategy()));
   }
 }
...
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 
 namespace paddle {
@@ -30,6 +31,12 @@ TaskNode::TaskNode(const framework::ProgramDesc& program, int64_t rank,
   // Should be serially invoked, not thread-safe
   static int64_t task_node_cnt = 0;
   task_id_ = task_node_cnt++;
+  for (const auto& op_desc : program.Block(0).AllOps()) {
+    ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
+  }
+  for (const auto& op : ops_vec_) {
+    ops_.emplace_back(op.get());
+  }
 }
 
 TaskNode::TaskNode(int32_t role, const std::vector<OperatorBase*>& ops,
...
@@ -76,10 +76,12 @@ class TaskNode final {
  private:
   DISABLE_COPY_AND_ASSIGN(TaskNode);
   TaskNode() = default;
+  // ops_ will be removed in the future
   std::vector<OperatorBase*> ops_;
   std::unordered_set<int64_t> upstream_;
   std::unordered_set<int64_t> downstream_;
   framework::ProgramDesc program_;
+  std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
   int32_t role_;
   int64_t rank_;
   int64_t task_id_;
...
@@ -16,6 +16,7 @@
 #include <pybind11/stl.h>
 #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
 #include "paddle/fluid/distributed/fleet_executor/task_node.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/platform/place.h"
@@ -32,8 +33,7 @@ void BindFleetExecutor(py::module* m) {
   py::class_<FleetExecutor>(*m, "FleetExecutor")
       .def(py::init<const std::string&>())
       .def("init", &FleetExecutor::Init)
-      .def("run", &FleetExecutor::Run)
-      .def("release", &FleetExecutor::Release);
+      .def("run", &FleetExecutor::Run);
 
   py::class_<TaskNode>(*m, "TaskNode")
       .def(py::init<const framework::ProgramDesc&, int64_t, int64_t, int64_t>())
...
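With the `release` binding removed, Python callers only construct, init, and run the executor; scope cleanup now happens when the C++ object is destroyed. A rough sketch of the resulting call sequence, where desc_str and section_program are placeholders standing in for the serialized descriptor and the pipeline section program that executor.py builds (see the hunk below):

# Rough sketch of the Python-side call sequence after this change.
# `desc_str` and `section_program` are placeholders for objects built by the
# surrounding executor code; they are not defined by this commit itself.
from paddle.fluid import core
from paddle.fluid.executor import global_scope

fleet_exe = core.FleetExecutor(desc_str)   # serialized FleetExecutorDesc string
place = core.Place()
place.set_place(core.CPUPlace())           # or whatever place the Executor holds
fleet_exe.init(section_program.desc, global_scope(), place)
fleet_exe.run()                            # no fleet_exe.release() anymore;
                                           # DropKids runs in ~FleetExecutor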
@@ -682,6 +682,8 @@ class Executor(object):
         self._enable_interpreter_core = _is_enable_standalone_executor()
         self._executor_cache = _ExecutorCache(self.place)
 
+        self._fleet_executor_cache = None
+
     def _get_scope_cache(self, program_cache_key):
         return self.scope_caches.get(program_cache_key, None)
@@ -1960,49 +1962,52 @@ class Executor(object):
                           print_period=100,
                           fetch_handler=None,
                           use_program_cache=False):
-        scope, real_fetch_list, trainer_instance = \
-            self._prepare_pipeline_ctx(program, dataset, scope, thread,
-                                       is_infer, debug, fetch_list, fetch_info,
-                                       print_period, fetch_handler,
-                                       use_program_cache)
-        from ..distributed.fleet.proto import fleet_executor_desc_pb2
-        from google.protobuf import text_format
-        cur_rank = os.getenv("PADDLE_TRAINER_ID")
-        trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS")
-        fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
-        nrank = 1
-        if cur_rank and trainer_endpoints_str:
-            fleet_exe_desc.cur_rank = int(cur_rank)
-            trainer_endpoints = trainer_endpoints_str.split(',')
-            for rank, endpoint in enumerate(trainer_endpoints):
-                rank_info = fleet_executor_desc_pb2.RankInfo()
-                rank_info.rank = rank
-                rank_info.ip_port = endpoint
-                fleet_exe_desc.cluster_info.append(rank_info)
-            nrank = len(trainer_endpoints)
-        else:
-            fleet_exe_desc.cur_rank = 0
-            rank_info = fleet_executor_desc_pb2.RankInfo()
-            rank_info.rank = 0
-            rank_info.ip_port = ''
-            fleet_exe_desc.cluster_info.append(rank_info)
-            logging.warning("Fleet Executor will run on single device only.")
-        fleet_opt = program._pipeline_opt["fleet_opt"]
-        if "dist_strategy" in fleet_opt:
-            fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"]["dp_degree"]
-            fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"]["mp_degree"]
-            fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"]["pp_degree"]
-        if "num_micro_batches" in fleet_opt:
-            fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
-        num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
-        assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
-        fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
-        place = core.Place()
-        place.set_place(self.place)
-        fleet_exe.init(program._pipeline_opt["section_program"].desc, scope,
-                       place)
-        fleet_exe.run()
-        fleet_exe.release()
+        if self._fleet_executor_cache is None:
+            from ..distributed.fleet.proto import fleet_executor_desc_pb2
+            from google.protobuf import text_format
+            cur_rank = os.getenv("PADDLE_TRAINER_ID")
+            trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS")
+            fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
+            nrank = 1
+            if cur_rank and trainer_endpoints_str:
+                fleet_exe_desc.cur_rank = int(cur_rank)
+                trainer_endpoints = trainer_endpoints_str.split(',')
+                for rank, endpoint in enumerate(trainer_endpoints):
+                    rank_info = fleet_executor_desc_pb2.RankInfo()
+                    rank_info.rank = rank
+                    rank_info.ip_port = endpoint
+                    fleet_exe_desc.cluster_info.append(rank_info)
+                nrank = len(trainer_endpoints)
+            else:
+                fleet_exe_desc.cur_rank = 0
+                rank_info = fleet_executor_desc_pb2.RankInfo()
+                rank_info.rank = 0
+                rank_info.ip_port = ''
+                fleet_exe_desc.cluster_info.append(rank_info)
+                logging.warning(
+                    "Fleet Executor will run on single device only.")
+            fleet_opt = program._pipeline_opt["fleet_opt"]
+            if "dist_strategy" in fleet_opt:
+                fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"][
+                    "dp_degree"]
+                fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"][
+                    "mp_degree"]
+                fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"][
+                    "pp_degree"]
+            if "num_micro_batches" in fleet_opt:
+                fleet_exe_desc.num_micro_batches = fleet_opt[
+                    "num_micro_batches"]
+            num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
+            assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
+            fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
+            place = core.Place()
+            place.set_place(self.place)
+            if scope is None:
+                scope = global_scope()
+            fleet_exe.init(program._pipeline_opt["section_program"].desc, scope,
+                           place)
+            self._fleet_executor_cache = fleet_exe
+        self._fleet_executor_cache.run()
         return None
 
     def _run_pipeline(self,
...
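The rebuilt branch above only constructs and initializes the executor on the first call and reuses self._fleet_executor_cache afterwards. A small illustration of the environment it reads when filling cluster_info; the values are placeholders for a two-rank job, not part of the change:

# Placeholder environment for a two-rank job; with these set, the branch above
# fills cur_rank, appends one RankInfo per endpoint, and ends up with nrank = 2.
# Without them it falls back to the single-device path and logs a warning.
import os

os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:6170,127.0.0.1:6171"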