From 79e758c61f50a636e1bb7a6134757a66e0b0f869 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 29 Jul 2021 21:10:48 +0800 Subject: [PATCH] add fix op run order pass (#34427) * add fix op run order pass * add ut for fix_op_run_order * fix ci error * improve coverage * improve coverge again and fix cpu test case * follow some comments --- paddle/fluid/framework/details/CMakeLists.txt | 3 +- .../fluid/framework/details/build_strategy.h | 3 + .../details/eager_deletion_op_handle.cc | 11 + .../details/eager_deletion_op_handle.h | 2 + .../fast_threaded_ssa_graph_executor.cc | 9 +- .../fast_threaded_ssa_graph_executor.h | 2 +- .../framework/distributed_strategy.proto | 1 + .../framework/ir/coalesce_grad_tensor_pass.cc | 12 +- paddle/fluid/framework/ir/graph_helper.cc | 26 ++ paddle/fluid/framework/ir/graph_helper.h | 3 + .../multi_devices_graph_pass/CMakeLists.txt | 1 + .../fix_op_run_order_pass.cc | 270 ++++++++++++++++++ paddle/fluid/framework/parallel_executor.cc | 15 + paddle/fluid/platform/dynload/nccl.cc | 4 + paddle/fluid/platform/dynload/nccl.h | 5 + paddle/fluid/pybind/pybind.cc | 59 ++-- paddle/fluid/string/string_helper.h | 22 +- .../fluid/tests/unittests/CMakeLists.txt | 1 + ...test_parallel_executor_fix_op_run_order.py | 92 ++++++ 19 files changed, 517 insertions(+), 24 deletions(-) create mode 100644 paddle/fluid/framework/ir/multi_devices_graph_pass/fix_op_run_order_pass.cc create mode 100644 python/paddle/fluid/tests/unittests/test_parallel_executor_fix_op_run_order.py diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 23a3ee2c58e..1546027b794 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -134,7 +134,8 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass modify_op_lock_and_record_event_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass - sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass) + sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass + fix_op_run_order_pass) if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM)) set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass) endif() diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 81d2d5e6dae..3f8a27f3d5a 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -100,6 +100,9 @@ struct BuildStrategy { // while running. bool cache_runtime_context_{false}; + // Fix the op run order. + bool fix_op_run_order_{false}; + // Operator fusion // TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have // cycle. diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc index ba076173b4a..07f7bbdb97a 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -19,6 +19,7 @@ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #endif +#include namespace paddle { namespace framework { @@ -177,6 +178,16 @@ void EagerDeletionOpHandle::ClearGarbages( #endif } +std::vector EagerDeletionOpHandle::VarsToDelete() const { + std::vector var_names; + var_names.reserve(var_infos_.size()); + for (auto &info : var_infos_) { + var_names.emplace_back(info->Name()); + } + std::sort(var_names.begin(), var_names.end()); + return var_names; +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.h b/paddle/fluid/framework/details/eager_deletion_op_handle.h index b1b8c21230e..acfc45f1818 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.h +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.h @@ -64,6 +64,8 @@ class EagerDeletionOpHandle : public OpHandleBase { size_t GetScopeIdx() const { return scope_idx_; } + std::vector VarsToDelete() const; + protected: void RunImpl() override; diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 8b41b99ac7a..120bdd2bc9f 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -40,9 +40,14 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( places_(places), graph_(graph), fetch_ctxs_(places), - pool_(strategy.num_threads_), // add one more thread for generate op_deps prepare_pool_(1) { + if (ir::IsTopologySortOperationsUnique(*graph_)) { + VLOG(10) + << "Change thread number to 1 because the toposort order is unique"; + strategy_.num_threads_ = 1; + } + pool_.reset(new ::ThreadPool(strategy.num_threads_)); for (auto &op : ir::FilterByNodeWrapper(*graph_)) { int dep = static_cast(op->NotReadyInputSize()); op_deps_.emplace(op, dep); @@ -223,7 +228,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( OpHandleBase *op, const std::shared_ptr> &complete_q) { ++remaining_; - this->pool_.enqueue([=] { + this->pool_->enqueue([=] { std::deque op_queue; op_queue.push_front(op); diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h index 72f7412602f..4477702900a 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h @@ -60,7 +60,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { atomic_op_deps_; ExceptionHolder exception_; - ::ThreadPool pool_; + std::unique_ptr<::ThreadPool> pool_; ::ThreadPool prepare_pool_; std::vector traced_ops_; diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index dabe2160689..87f4d9af02b 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -111,6 +111,7 @@ message BuildStrategy { optional bool fuse_bn_add_act_ops = 10 [ default = true ]; optional bool enable_auto_fusion = 11 [ default = false ]; optional bool enable_addto = 12 [ default = false ]; + optional bool fix_op_run_order = 13 [ default = false ]; } message ExecutionStrategy { diff --git a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc index 41372c09f4e..ffd80f0c90a 100644 --- a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc +++ b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h" +#include #include #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h" @@ -254,8 +255,15 @@ class CoalesceGradTensorPass : public ir::Pass { const std::unordered_map> &vars_info, const details::ParamsAndGrads ¶ms_grads, details::GroupParamsAndGrads *group_params_grads) const { - SetGroupAccordingToLayers(vars_info, params_grads, group_params_grads); - SetGroupAccordingToMemorySize(vars_info, group_params_grads); + if (GetFuseParameterMemorySize() == 0) { + group_params_grads->resize(1); + auto &result_param_grads = (*group_params_grads)[0]; + result_param_grads = params_grads; + std::sort(result_param_grads.begin(), result_param_grads.end()); + } else { + SetGroupAccordingToLayers(vars_info, params_grads, group_params_grads); + SetGroupAccordingToMemorySize(vars_info, group_params_grads); + } if (!IsUnifiedDtype(params_grads, vars_info)) { ReGroupByDtype(vars_info, group_params_grads); } diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 50174cfbbba..7b6002da096 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -143,6 +143,32 @@ std::vector TopologySortOperations(const Graph &graph) { return ret; } +bool IsTopologySortOperationsUnique(const Graph &graph) { + auto nodes = TopologySortOperations(graph); + size_t n = nodes.size(); + for (size_t i = 1; i < n; ++i) { + auto *prev_op = nodes[i - 1]; + auto *cur_op = nodes[i]; + + std::unordered_set prev_op_outputs; + for (auto *output : prev_op->outputs) { + prev_op_outputs.insert(output); + } + + bool found = false; + for (auto *input : cur_op->inputs) { + if (prev_op_outputs.count(input) > 0) { + found = true; + break; + } + } + if (!found) { + return false; + } + } + return true; +} + // Build operator inlink edge table. std::map, ir::NodeComp> BuildOperationAdjList(const Graph &graph) { diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h index 27a4fe25cd5..3c3ea662502 100644 --- a/paddle/fluid/framework/ir/graph_helper.h +++ b/paddle/fluid/framework/ir/graph_helper.h @@ -57,6 +57,9 @@ size_t GraphNum(const Graph &graph); // `graph` cannot contain circle. std::vector TopologySortOperations(const Graph &graph); +// Check whether the topological order of graph ops is unique +bool IsTopologySortOperationsUnique(const Graph &graph); + // Topological sort, but try to DFS. std::vector TopologyDfsSortOperations(const Graph &graph); diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt b/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt index 2f79c425e1d..f945ddbd5d6 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/CMakeLists.txt @@ -18,3 +18,4 @@ cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph gr cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS all_reduce_op_handle graph graph_helper pass) cc_library(backward_optimizer_op_deps_pass SRCS backward_optimizer_op_deps_pass.cc DEPS graph graph_helper pass) cc_library(add_reader_dependency_pass SRCS add_reader_dependency_pass.cc DEPS graph graph_helper pass) +cc_library(fix_op_run_order_pass SRCS fix_op_run_order_pass DEPS graph graph_helper multi_devices_helper pass op_handle_base eager_deletion_op_handle) diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/fix_op_run_order_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/fix_op_run_order_pass.cc new file mode 100644 index 00000000000..772b4c1c915 --- /dev/null +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/fix_op_run_order_pass.cc @@ -0,0 +1,270 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/var_handle.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/string/string_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +static std::string kSep(1, static_cast(1)); // NOLINT + +// NOTE: VariableNameMap is sorted! +static std::string VarNameMapToString(const VariableNameMap &var_map) { + std::vector tmp_strs; + tmp_strs.reserve(var_map.size()); + for (auto &pair : var_map) { + auto str = pair.first + kSep + string::join_strings(pair.second, kSep); + tmp_strs.emplace_back(std::move(str)); + } + return string::join_strings(tmp_strs, kSep); +} + +static std::string OpDescToString(const OpDesc &op) { + return "OpDesc" + kSep + op.Type() + kSep + VarNameMapToString(op.Inputs()) + + kSep + VarNameMapToString(op.Outputs()); +} + +static std::string VarHandleListToString( + const std::vector &vars) { + std::vector valid_vars; + valid_vars.reserve(vars.size()); + for (auto *v : vars) { + auto *valid_var = dynamic_cast(v); + if (valid_var != nullptr) { + valid_vars.emplace_back(valid_var->Name()); + } + } + std::sort(valid_vars.begin(), valid_vars.end()); + return string::join_strings(valid_vars, kSep); +} + +static std::string EagerDeletionOpHandleToString( + const details::EagerDeletionOpHandle &op); +static std::string OpHandleToString(const details::OpHandleBase &op); + +static std::string EagerDeletionOpHandleToString( + const details::EagerDeletionOpHandle &op) { + auto vars_to_delete = op.VarsToDelete(); + std::unordered_set prev_ops; + std::vector prev_op_strs; + prev_op_strs.reserve(op.Inputs().size()); + for (auto *var : op.Inputs()) { + auto *prev_op = var->GeneratedOp(); + if (prev_op == nullptr) continue; + prev_op_strs.push_back(OpHandleToString(*prev_op)); + } + std::sort(prev_op_strs.begin(), prev_op_strs.end()); + // NOTE: gc op does not have any valid input/output vars + return "OpHandleBase" + kSep + op.Name() + kSep + + string::join_strings(vars_to_delete, kSep) + kSep + + string::join_strings(prev_op_strs, kSep); +} + +static std::string OpHandleToString(const details::OpHandleBase &op) { + // NOTE: gc op does not have any valid input/output vars + auto gc_op = dynamic_cast(&op); + if (gc_op) { + return EagerDeletionOpHandleToString(*gc_op); + } + return "OpHandleBase" + kSep + op.Name() + kSep + + VarHandleListToString(op.Inputs()) + kSep + + VarHandleListToString(op.Outputs()); +} + +static void AddSequentialDepsForSortedOps( + Graph *graph, const std::vector &sorted_ops) { + size_t n = sorted_ops.size(); + for (size_t i = 1; i < n; ++i) { + auto *prev_op = sorted_ops[i - 1]; + auto *cur_op = sorted_ops[i]; + auto *dep_var = new details::DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(details::kGraphDepVars).emplace(dep_var); + prev_op->AddOutput(dep_var); + cur_op->AddInput(dep_var); + } +} + +class FixOpRunOrderPass : public Pass { + protected: + void ApplyImpl(Graph *graph) const override { + const auto &program = graph->OriginProgram(); + std::unordered_map op_to_idx; + size_t i = 0; + for (auto *op_desc : program.Block(0).AllOps()) { + auto op_desc_str = OpDescToString(*op_desc); + PADDLE_ENFORCE_EQ( + op_to_idx.emplace(op_desc_str, i).second, true, + platform::errors::PermissionDenied( + "FixOpRunOrderPass cannot handle OpDesc with same " + "type, inputs and outputs yet, error string repr: %s", + op_desc_str)); + ++i; + } + + // a map to record: "Node" -> "Node Index" + std::unordered_map node_to_idx; + // a map to record found "Node Index" + std::unordered_set found_node_indices; + // a map to record the new OpDesc created by other Passes. These ops does + // not exist in the origin program + std::map new_op_desc_nodes; + // a map to record the new OpHandle created by other Passes. These ops does + // not have OpDesc and does not exist in the origin program + std::map new_op_handle_nodes; + + // Step 1: handle the unchanged OpDesc, and record new OpDesc/OpHandle + auto op_handles = FilterByNodeWrapper(*graph); + for (auto *op_handle : op_handles) { + auto *node = op_handle->Node(); + if (node->Op() == nullptr) { + auto node_str = OpHandleToString(*op_handle); + PADDLE_ENFORCE_EQ(new_op_handle_nodes.emplace(node_str, node).second, + true, + platform::errors::PermissionDenied( + "FixOpRunOrderPass cannot OpHandle with same " + "inputs and outputs yet, error repr: %s", + node_str)); + continue; + } + + auto node_str = OpDescToString(*(node->Op())); + auto iter = op_to_idx.find(node_str); + if (iter != op_to_idx.end()) { + size_t idx = iter->second; + PADDLE_ENFORCE_EQ( + found_node_indices.count(idx), 0, + platform::errors::PermissionDenied( + "FixOpRunOrderPass cannot handle OpDesc with same " + "type, inputs and outputs yet, error repr: %s", + node_str)); + found_node_indices.insert(idx); + node_to_idx[node] = idx; + } else { + PADDLE_ENFORCE_EQ( + new_op_desc_nodes.emplace(node_str, node).second, true, + platform::errors::PermissionDenied( + "FixOpRunOrderPass cannot handle OpDesc with same " + "type, inputs and outputs yet, error repr: %s", + node_str)); + } + } + + VLOG(10) << "Found unchanged OpDesc " << node_to_idx.size() + << ", new OpDesc " << new_op_desc_nodes.size() << ", new OpHandle " + << new_op_handle_nodes.size(); + + // Step 2: assign node index to new OpDesc + size_t node_id_offset = op_to_idx.size(); + for (auto &pair : new_op_desc_nodes) { + node_to_idx[pair.second] = node_id_offset; + ++node_id_offset; + } + + // Step 3: assign node index to new OpHandle + for (auto &pair : new_op_handle_nodes) { + node_to_idx[pair.second] = node_id_offset; + ++node_id_offset; + } + + // Step 4: sort unchanged OpDesc/new OpDesc/new OpHandle by topological + // order and node index + OpGraphView graph_view(op_handles); + auto comp = [&node_to_idx](details::OpHandleBase *op1, + details::OpHandleBase *op2) { + auto priority1 = static_cast(op1->GetPriority()); + auto priority2 = static_cast(op2->GetPriority()); + if (priority1 != priority2) { + return priority1 < priority2; + } + return node_to_idx.at(op1->Node()) < node_to_idx.at(op2->Node()); + }; + + std::vector sorted_ops; + sorted_ops.reserve(op_handles.size()); + std::queue q; + std::vector tmp_ops; + auto op_deps = graph_view.GetPrecedingDepNum(); + // Get ready ops first + for (auto iter = op_deps.begin(); iter != op_deps.end();) { + if (iter->second != 0) { + ++iter; + continue; + } + tmp_ops.push_back(iter->first); + op_deps.erase(iter++); + } + // Sort ready ops by node index + std::sort(tmp_ops.begin(), tmp_ops.end(), comp); + for (auto *op : tmp_ops) { + q.push(op); + } + while (!q.empty()) { + auto *cur_op = q.front(); + q.pop(); + sorted_ops.push_back(cur_op); + + auto &pending_ops = graph_view.PendingOps(cur_op); + tmp_ops.clear(); + for (auto *pending_op : pending_ops) { + if (--op_deps.at(pending_op) == 0) { + op_deps.erase(pending_op); + tmp_ops.push_back(pending_op); + } + } + // sort next ready ops by node index + std::sort(tmp_ops.begin(), tmp_ops.end(), comp); + for (auto *op : tmp_ops) { + q.push(op); + } + } + + PADDLE_ENFORCE_EQ( + sorted_ops.size(), op_handles.size(), + platform::errors::PermissionDenied("There are unvisited ops")); + if (VLOG_IS_ON(10)) { + // print op order to debug + std::vector sorted_ops_indices; + sorted_ops_indices.reserve(sorted_ops.size()); + for (auto *op : sorted_ops) { + sorted_ops_indices.push_back(node_to_idx.at(op->Node())); + } + VLOG(10) << "Fix op order: " + << string::join_strings(sorted_ops_indices, ','); + } + + // Step 5: add sequential deps for ops to guarantee there is only one + // toposort order + AddSequentialDepsForSortedOps(graph, sorted_ops); + PADDLE_ENFORCE_EQ(IsTopologySortOperationsUnique(*graph), true, + platform::errors::PermissionDenied( + "The topological order must be unique " + "after FixOpRunOrderPass is applied")); + } +}; + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fix_op_run_order_pass, paddle::framework::ir::FixOpRunOrderPass); diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index badabce7b34..516a3bc63ca 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -104,6 +104,13 @@ class ParallelExecutorPrivate { inline bool HasGarbageCollectors() const { return !gcs_.empty(); } + void ApplyFixOpRunOrderPass(ir::Graph *graph) { + if (build_strategy_.fix_op_run_order_) { + auto pass = ir::PassRegistry::Instance().Get("fix_op_run_order_pass"); + pass->Apply(graph); + } + } + /** * NOTE(zengjinle): the fed variables of users should not be reused, * because users may feed them into another network. Changing the fed @@ -1462,6 +1469,10 @@ std::vector ParallelExecutor::CreateSSAGraphExecutor( auto possible_inference_graphs = details::TrySeparateToMultipleSingleDeviceGraphs(graph); if (!possible_inference_graphs.empty()) { + for (auto &g : possible_inference_graphs) { + member_->ApplyFixOpRunOrderPass(g.get()); + } + VLOG(5) << "Use ParallelSSAGraphExecutor in inference phase"; auto *pg_exe = new details::ParallelSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, @@ -1474,6 +1485,9 @@ std::vector ParallelExecutor::CreateSSAGraphExecutor( member_->executor_.reset(pg_exe); member_->inference_executor_ = pg_exe; } else { + if (member_->places_.size() == 1) { + member_->ApplyFixOpRunOrderPass(graph); + } LOG_IF(WARNING, details::HasKeepLastReadOp(*graph)) << "drop_last=False for DataLoader is not supported in training " "network. It is automatically turned to drop_last=True."; @@ -1560,3 +1574,4 @@ USE_PASS(eager_deletion_pass); USE_PASS(buffer_shared_inplace_pass); USE_PASS(buffer_shared_cross_op_memory_reuse_pass); USE_PASS(inplace_addto_op_pass); +USE_PASS(fix_op_run_order_pass); diff --git a/paddle/fluid/platform/dynload/nccl.cc b/paddle/fluid/platform/dynload/nccl.cc index cfc98561e87..24a4e5aad04 100644 --- a/paddle/fluid/platform/dynload/nccl.cc +++ b/paddle/fluid/platform/dynload/nccl.cc @@ -29,6 +29,10 @@ NCCL_RAND_ROUTINE_EACH(DEFINE_WRAP); NCCL_RAND_ROUTINE_EACH_AFTER_2212(DEFINE_WRAP) #endif +#if NCCL_VERSION_CODE >= 2304 +NCCL_RAND_ROUTINE_EACH_AFTER_2304(DEFINE_WRAP) +#endif + #if NCCL_VERSION_CODE >= 2703 NCCL_RAND_ROUTINE_EACH_AFTER_2703(DEFINE_WRAP) #endif diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h index 057636cfef8..ea6daf15b91 100644 --- a/paddle/fluid/platform/dynload/nccl.h +++ b/paddle/fluid/platform/dynload/nccl.h @@ -64,6 +64,11 @@ NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP) NCCL_RAND_ROUTINE_EACH_AFTER_2212(DECLARE_DYNAMIC_LOAD_NCCL_WRAP) #endif +#if NCCL_VERSION_CODE >= 2304 +#define NCCL_RAND_ROUTINE_EACH_AFTER_2304(__macro) __macro(ncclGetVersion); +NCCL_RAND_ROUTINE_EACH_AFTER_2304(DECLARE_DYNAMIC_LOAD_NCCL_WRAP) +#endif + #if NCCL_VERSION_CODE >= 2703 #define NCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \ __macro(ncclSend); \ diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f4976670190..2cda2095917 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -467,6 +467,19 @@ static void AssertStaticGraphAndDygraphGradMakerNoDiff() { string::join_strings(ops, ','))); } +#ifdef PADDLE_WITH_NCCL +static int GetNCCLVersion() { +#if NCCL_VERSION_CODE >= 2304 + int ver; + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetVersion(&ver)); + return ver; +#else + PADDLE_THROW(platform::errors::External( + "Cannot get NCCL version successfully when nccl version < 2.3.4")); +#endif +} +#endif + #ifdef PADDLE_WITH_AVX PYBIND11_MODULE(core_avx, m) { #else @@ -496,6 +509,14 @@ PYBIND11_MODULE(core_noavx, m) { m.def("cudnn_version", &platform::CudnnVersion); #endif +#ifdef PADDLE_WITH_NCCL + m.def("nccl_version", &GetNCCLVersion); +#endif + + m.def("wait_device", [](const platform::Place &place) { + platform::DeviceContextPool::Instance().Get(place)->Wait(); + }); + m.def("from_dlpack", [](py::capsule *dltensor) { DLManagedTensor *dmt = reinterpret_cast( PyCapsule_GetPointer(dltensor->ptr(), "dltensor")); @@ -1796,20 +1817,20 @@ All parameter, weight, gradient are variables in Paddle. .def("__str__", string::to_string); py::class_(m, "Operator") - .def_static( - "create", - [](py::bytes protobin) { - proto::OpDesc desc; - PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin), true, - platform::errors::InvalidArgument( - "Cannot parse user input to OpDesc")); - PADDLE_ENFORCE_EQ( - desc.IsInitialized(), true, - platform::errors::InvalidArgument( - "The provided OpDesc is not initialized, the reason is: %s", - desc.InitializationErrorString())); - return OpRegistry::CreateOp(desc); - }) + .def_static("create", + [](py::bytes protobin) { + proto::OpDesc desc; + PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin), + true, + platform::errors::InvalidArgument( + "Cannot parse user input to OpDesc")); + PADDLE_ENFORCE_EQ(desc.IsInitialized(), true, + platform::errors::InvalidArgument( + "The provided OpDesc is not " + "initialized, the reason is: %s", + desc.InitializationErrorString())); + return OpRegistry::CreateOp(desc); + }) .def("run", [](OperatorBase &self, const Scope &scope, const platform::CPUPlace &place) { self.Run(scope, place); }) @@ -2928,8 +2949,8 @@ All parameter, weight, gradient are variables in Paddle. self.memory_optimize_ = (py_obj == Py_True); } else { PADDLE_THROW(platform::errors::InvalidArgument( - "BuildStrategy.memory_optimize must be set to None, False or " - "True")); + "BuildStrategy.memory_optimize must be set to None, False " + "or True")); } }, R"DOC((bool, optional): memory opitimize aims to save total memory @@ -3003,6 +3024,12 @@ All parameter, weight, gradient are variables in Paddle. const std::unordered_set &mkldnn_enabled_op_types) { self.mkldnn_enabled_op_types_ = mkldnn_enabled_op_types; }) + .def_property( + "fix_op_run_order", + [](const BuildStrategy &self) { return self.fix_op_run_order_; }, + [](BuildStrategy &self, bool fix_op_run_order) { + self.fix_op_run_order_ = fix_op_run_order; + }) .def("_finalize_strategy_and_create_passes", [](BuildStrategy &self) -> std::shared_ptr { return self.CreatePassesFromStrategy(true); diff --git a/paddle/fluid/string/string_helper.h b/paddle/fluid/string/string_helper.h index 499539226bd..f7387e877af 100644 --- a/paddle/fluid/string/string_helper.h +++ b/paddle/fluid/string/string_helper.h @@ -38,7 +38,8 @@ void format_string_append(std::string& str, const char* fmt, // NOLINT CHECK_GE(len, 0); size_t oldlen = str.length(); str.resize(oldlen + len + 1); - CHECK(snprintf(&str[oldlen], (size_t)len + 1, fmt, args...) == len); + CHECK(snprintf(&str[oldlen], (size_t)len + 1, fmt, args...) == // NOLINT + len); str.resize(oldlen + len); } @@ -127,7 +128,24 @@ template std::string join_strings(const Container& strs, char delim) { std::string str; - int i = 0; + size_t i = 0; + for (auto& elem : strs) { + if (i > 0) { + str += delim; + } + + str += boost::lexical_cast(elem); + ++i; + } + + return str; +} + +template +std::string join_strings(const Container& strs, const std::string& delim) { + std::string str; + + size_t i = 0; for (auto& elem : strs) { if (i > 0) { str += delim; diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 82874be3230..e7172507696 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -688,6 +688,7 @@ add_subdirectory(ir) if (WITH_TESTING) set_property(TEST test_parallel_executor_mnist PROPERTY ENVIRONMENT GLOG_vmodule=all_reduce_deps_pass=10) + set_property(TEST test_parallel_executor_fix_op_run_order PROPERTY ENVIRONMENT GLOG_vmodule=fix_op_run_order_pass=10) endif() set_tests_properties(test_parallel_executor_test_while_train test_parallel_executor_mnist diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_fix_op_run_order.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_fix_op_run_order.py new file mode 100644 index 00000000000..f48cfbd50eb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_fix_op_run_order.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import unittest +import numpy as np +from paddle.vision.models import resnet50 +from paddle.nn import CrossEntropyLoss + + +class TestFixOpRunOrder(unittest.TestCase): + def setUp(self): + paddle.enable_static() + paddle.seed(1) + paddle.framework.random._manual_program_seed(1) + if paddle.is_compiled_with_cuda(): + fluid.set_flags({'FLAGS_cudnn_deterministic': 1}) + + def get_place(self): + return paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( + ) else paddle.CPUPlace() + + def get_feed(self): + batch_size = 32 + image = np.random.random([batch_size, 3, 224, 224]).astype('float32') + label = np.random.randint(0, 1000, [batch_size, 1]).astype('int64') + return {"image": image, "label": label} + + def create_model(self, fix_op_run_order): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + scope = paddle.static.Scope() + with paddle.static.program_guard(main_prog, startup_prog): + image = paddle.static.data( + name="image", shape=[None, 3, 224, 224], dtype="float32") + label = paddle.static.data( + name="label", shape=[None, 1], dtype="int64") + model = resnet50() + pred = model(image) + loss_fn = CrossEntropyLoss() + loss = loss_fn(pred, label) + optimizer = paddle.optimizer.SGD(learning_rate=1e-3) + optimizer.minimize(loss) + + build_strategy = paddle.static.BuildStrategy() + build_strategy.fix_op_run_order = fix_op_run_order + build_strategy.fuse_bn_act_ops = True + build_strategy.fuse_bn_add_act_ops = True + main_prog = paddle.static.CompiledProgram(main_prog).with_data_parallel( + loss_name=loss.name, + build_strategy=build_strategy, + places=[self.get_place()]) + + exe = paddle.static.Executor(self.get_place()) + with paddle.static.scope_guard(scope): + exe.run(startup_prog) + + return main_prog, scope, loss + + def run_and_fetch_loss(self, main_prog, scope, loss, feed): + with paddle.static.scope_guard(scope): + exe = paddle.static.Executor(self.get_place()) + loss_value = exe.run(main_prog, feed=feed, fetch_list=[loss])[0] + return loss_value + + def test_main(self): + if not paddle.is_compiled_with_cuda(): + return + + main1, scope1, loss1 = self.create_model(True) + main2, scope2, loss2 = self.create_model(False) + for i in range(10): + feed = self.get_feed() + loss_val1 = self.run_and_fetch_loss(main1, scope1, loss1, feed) + loss_val2 = self.run_and_fetch_loss(main2, scope2, loss2, feed) + self.assertEqual(loss_val1, loss_val2) + + +if __name__ == "__main__": + unittest.main() -- GitLab