未验证 提交 0074a3c9 编写于 作者: L LiYuRio 提交者: GitHub

[Fleet Executor] Refine runtime graph (#37703)

上级 bfb85779
......@@ -12,7 +12,7 @@ endif()
cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
${BRPC_DEPS})
if(WITH_DISTRIBUTE)
......
......@@ -31,9 +31,7 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
"Error occurs while parsing string to proto"));
}
FleetExecutor::~FleetExecutor() {
// Destroy Executor
}
FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); }
void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
framework::Scope* scope,
......@@ -113,8 +111,6 @@ void FleetExecutor::Run() {
carrier_instance.Start();
}
void FleetExecutor::Release() { root_scope_->DropKids(); }
void FleetExecutor::CopyParameters(int microbatch_id,
const framework::ProgramDesc& program) {
auto& global_block = program.Block(0);
......
......@@ -39,7 +39,6 @@ class FleetExecutor final {
void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
const platform::Place& place);
void Run();
void Release();
private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
......
......@@ -21,7 +21,7 @@ message RankInfo {
}
message FleetExecutorDesc {
optional string grain = 1 [ default = "coarse" ];
optional string strategy = 1 [ default = "Origin" ];
optional int64 cur_rank = 2 [ default = 0 ]; // Rank id of current processor
repeated RankInfo cluster_info = 3;
optional int32 dp_degree = 4 [ default = 1 ];
......
......@@ -40,34 +40,9 @@ void Interceptor::Join() {
void Interceptor::RegisterMsgHandle(MsgHandle handle) { handle_ = handle; }
void Interceptor::Handle(const InterceptorMessage& msg) {
if (handle_) {
handle_(msg);
} else {
VLOG(3) << "Interceptor is using default message handler. This handler is "
"only used for test purpose. Check whether you init interceptor "
"in the proper way.";
if (msg.message_type() == DATA_IS_READY) {
if (node_->role() != 2) {
VLOG(3) << "Fake handler is sending DATA_IS_READY message to: "
<< interceptor_id_ + 1 << ".";
InterceptorMessage data_is_ready_msg;
data_is_ready_msg.set_message_type(DATA_IS_READY);
Send(interceptor_id_ + 1, data_is_ready_msg);
} else {
// NOTE: max run time is reach for last interceptor
StopCarrier();
}
} else if (msg.message_type() == STOP) {
stop_ = true;
if (node_->role() != 2) {
VLOG(3) << "Fake handler is sending STOP message to: "
<< interceptor_id_ + 1 << ".";
InterceptorMessage stop_msg;
stop_msg.set_message_type(STOP);
Send(interceptor_id_ + 1, stop_msg);
}
}
}
PADDLE_ENFORCE_NOT_NULL(handle_, platform::errors::PreconditionNotMet(
"Message handle is not registered."));
handle_(msg);
}
void Interceptor::StopCarrier() {
......
......@@ -100,11 +100,25 @@ std::vector<OpRole> RuntimeGraph::functionality_order = {
RuntimeGraph::RuntimeGraph(const ProgramDesc& program,
const FleetExecutorDesc& exe_desc)
: exe_desc_(exe_desc) {
if (exe_desc.grain() == "coarse") {
if (exe_desc.strategy() == "1F1B") {
SplitProgramBasedFunctionality(program);
AssignTaskToIntercepter();
FakeDependence();
FakeRuntimeInfo();
} else if (exe_desc.strategy() == "Origin") {
int64_t cur_rank = exe_desc_.cur_rank();
int64_t max_run_times = exe_desc_.num_micro_batches();
int64_t max_slot_nums = exe_desc_.num_slots();
auto task_node = std::make_unique<TaskNode>(program, cur_rank,
max_run_times, max_slot_nums);
task_node->SetType("Compute");
task_nodes_.emplace_back(std::move(task_node));
int64_t task_id = task_nodes_[0]->task_id();
intercepter_id_to_rank_.insert({task_id, cur_rank});
intercepter_id_to_node_.insert({task_id, task_nodes_[0].get()});
} else {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Strategy %s is None of 1F1B or Origin.", exe_desc.strategy()));
}
}
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
......@@ -30,6 +31,12 @@ TaskNode::TaskNode(const framework::ProgramDesc& program, int64_t rank,
// Should be serially invoked, not thread-safe
static int64_t task_node_cnt = 0;
task_id_ = task_node_cnt++;
for (const auto& op_desc : program.Block(0).AllOps()) {
ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc));
}
for (const auto& op : ops_vec_) {
ops_.emplace_back(op.get());
}
}
TaskNode::TaskNode(int32_t role, const std::vector<OperatorBase*>& ops,
......
......@@ -76,10 +76,12 @@ class TaskNode final {
private:
DISABLE_COPY_AND_ASSIGN(TaskNode);
TaskNode() = default;
// ops_ will be removed in the future
std::vector<OperatorBase*> ops_;
std::unordered_set<int64_t> upstream_;
std::unordered_set<int64_t> downstream_;
framework::ProgramDesc program_;
std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
int32_t role_;
int64_t rank_;
int64_t task_id_;
......
......@@ -16,6 +16,7 @@
#include <pybind11/stl.h>
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
......@@ -32,8 +33,7 @@ void BindFleetExecutor(py::module* m) {
py::class_<FleetExecutor>(*m, "FleetExecutor")
.def(py::init<const std::string&>())
.def("init", &FleetExecutor::Init)
.def("run", &FleetExecutor::Run)
.def("release", &FleetExecutor::Release);
.def("run", &FleetExecutor::Run);
py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<const framework::ProgramDesc&, int64_t, int64_t, int64_t>())
......
......@@ -682,6 +682,8 @@ class Executor(object):
self._enable_interpreter_core = _is_enable_standalone_executor()
self._executor_cache = _ExecutorCache(self.place)
self._fleet_executor_cache = None
def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None)
......@@ -1960,49 +1962,52 @@ class Executor(object):
print_period=100,
fetch_handler=None,
use_program_cache=False):
scope, real_fetch_list, trainer_instance = \
self._prepare_pipeline_ctx(program, dataset, scope, thread,
is_infer, debug, fetch_list, fetch_info,
print_period, fetch_handler,
use_program_cache)
from ..distributed.fleet.proto import fleet_executor_desc_pb2
from google.protobuf import text_format
cur_rank = os.getenv("PADDLE_TRAINER_ID")
trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS")
fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
nrank = 1
if cur_rank and trainer_endpoints_str:
fleet_exe_desc.cur_rank = int(cur_rank)
trainer_endpoints = trainer_endpoints_str.split(',')
for rank, endpoint in enumerate(trainer_endpoints):
if self._fleet_executor_cache is None:
from ..distributed.fleet.proto import fleet_executor_desc_pb2
from google.protobuf import text_format
cur_rank = os.getenv("PADDLE_TRAINER_ID")
trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS")
fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc()
nrank = 1
if cur_rank and trainer_endpoints_str:
fleet_exe_desc.cur_rank = int(cur_rank)
trainer_endpoints = trainer_endpoints_str.split(',')
for rank, endpoint in enumerate(trainer_endpoints):
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = rank
rank_info.ip_port = endpoint
fleet_exe_desc.cluster_info.append(rank_info)
nrank = len(trainer_endpoints)
else:
fleet_exe_desc.cur_rank = 0
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = rank
rank_info.ip_port = endpoint
rank_info.rank = 0
rank_info.ip_port = ''
fleet_exe_desc.cluster_info.append(rank_info)
nrank = len(trainer_endpoints)
else:
fleet_exe_desc.cur_rank = 0
rank_info = fleet_executor_desc_pb2.RankInfo()
rank_info.rank = 0
rank_info.ip_port = ''
fleet_exe_desc.cluster_info.append(rank_info)
logging.warning("Fleet Executor will run on single device only.")
fleet_opt = program._pipeline_opt["fleet_opt"]
if "dist_strategy" in fleet_opt:
fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"]["dp_degree"]
fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"]["mp_degree"]
fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"]["pp_degree"]
if "num_micro_batches" in fleet_opt:
fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
place = core.Place()
place.set_place(self.place)
fleet_exe.init(program._pipeline_opt["section_program"].desc, scope,
place)
fleet_exe.run()
fleet_exe.release()
logging.warning(
"Fleet Executor will run on single device only.")
fleet_opt = program._pipeline_opt["fleet_opt"]
if "dist_strategy" in fleet_opt:
fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"][
"dp_degree"]
fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"][
"mp_degree"]
fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"][
"pp_degree"]
if "num_micro_batches" in fleet_opt:
fleet_exe_desc.num_micro_batches = fleet_opt[
"num_micro_batches"]
num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
place = core.Place()
place.set_place(self.place)
if scope is None:
scope = global_scope()
fleet_exe.init(program._pipeline_opt["section_program"].desc, scope,
place)
self._fleet_executor_cache = fleet_exe
self._fleet_executor_cache.run()
return None
def _run_pipeline(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册