未验证 提交 02cc3c5e 编写于 作者: G gongweibao 提交者: GitHub

Fix allreduce_sum potential bugs on NPU. (#34462)

上级 b56dbe08
......@@ -164,6 +164,7 @@ void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {
while (fw_step < startup_steps) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
VLOG(2) << "micro steps fw_step:" << fw_step;
}
// 1f1b phase
......@@ -180,6 +181,7 @@ void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {
fw_step += 1;
bw_step += 1;
VLOG(2) << "micro steps fw_step:" << fw_step << ", bw_step:" << bw_step;
}
int reserve_bw_send_step = bw_step - 2;
......@@ -187,8 +189,10 @@ void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {
while (bw_step < num_microbatches_) {
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
VLOG(2) << "micro steps bw_step:" << bw_step;
}
VLOG(2) << "run update";
RunUpdate(gc, unused_vars_);
if (gc) {
......@@ -203,6 +207,7 @@ void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {
void SectionWorker::TrainFiles() {
VLOG(5) << "begin section_worker TrainFiles";
VLOG(2) << "mini batch steps:" << batch_id_;
int64_t max_memory_size = GetEagerDeletionThreshold();
std::unique_ptr<GarbageCollector> gc;
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU_BKCL)
......@@ -119,13 +120,45 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
}
};
#if defined(PADDLE_WITH_ASCEND_CL)
// return true if found_inf_or_nan or return false;
template <typename T>
bool CheckNumerics(const framework::ExecutionContext& exe_ctx,
aclrtStream stream, const paddle::framework::Tensor* in) {
auto& dev_ctx =
exe_ctx.template device_context<paddle::platform::NPUDeviceContext>();
using Tensor = paddle::framework::Tensor;
Tensor out(in->type());
out.Resize(in->dims());
out.mutable_data<T>(dev_ctx.GetPlace());
bool found_inf_data = false;
try {
const auto& runner =
NpuOpRunner("CheckNumerics", {*in}, {out},
{{"message", std::string("check_numberics")}});
runner.Run(stream);
dev_ctx.Wait();
} catch (platform::EnforceNotMet& exception) {
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
found_inf_data = true;
} catch (...) {
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
found_inf_data = true;
}
return found_inf_data;
}
#endif
template <ReduceType red_type, typename T>
class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
auto in = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto place = ctx.GetPlace();
HcclDataType dtype = platform::ToHCCLDataType(in->type());
int64_t numel = in->numel();
......@@ -141,9 +174,10 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);
aclrtStream stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto dev_ctx = static_cast<platform::NPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
if (ctx.Attr<bool>("use_calc_stream")) {
stream = static_cast<platform::NPUDeviceContext*>(dev_ctx)->stream();
stream = dev_ctx->stream();
} else {
stream = comm->stream();
}
......@@ -171,9 +205,46 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
"Invalid reduce type: %d", red_type));
}
VLOG(3) << "begin hccl allreduce, parameter is: "
VLOG(3) << "hccl allreduce, parameter is: "
<< "input num: " << in->dims() << "dtype: " << dtype
<< "hccl_red_type: " << hccl_red_type << ", group is: " << group
<< ", sendbuff:" << sendbuff << ", recvbuff:" << recvbuff
<< ", out_size:" << out->memory_size()
<< ", use_calc_stream:" << ctx.Attr<bool>("use_calc_stream")
<< ", stream:" << stream;
framework::Tensor tmp;
tmp.mutable_data<float>({8}, ctx.GetPlace());
bool check_numerics = false;
auto d_type = in->type();
switch (d_type) {
case framework::proto::VarType::FP16:
case framework::proto::VarType::FP32: {
VLOG(4) << "prepare to FoundNanInf";
check_numerics = CheckNumerics<T>(ctx, dev_ctx->stream(), in);
VLOG(4) << "check_numerics:" << check_numerics;
break;
}
default:
break;
}
if (check_numerics) {
T inf = static_cast<T>(std::numeric_limits<float>::infinity());
VLOG(4) << "fill input data constant inf";
auto dims = in->dims();
auto mutable_in = const_cast<framework::Tensor*>(in);
FillNpuTensorWithConstant<T>(mutable_in, inf);
mutable_in->Resize(dims);
}
VLOG(3) << "hccl allreduce, parameter is: "
<< "input num: " << numel << "dtype: " << dtype
<< "hccl_red_type: " << hccl_red_type << ", group is: " << group;
<< "hccl_red_type: " << hccl_red_type << ", group is: " << group
<< ", sendbuff:" << sendbuff << ", recvbuff:" << recvbuff
<< ", out_size:" << out->memory_size();
PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce(
sendbuff, recvbuff, numel, dtype, hccl_red_type, comm->comm(),
......@@ -198,7 +269,7 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
BKCLDataType dtype = platform::ToBKCLDataType(in->type());
int64_t numel = in->numel();
const void* sendbuff = in->data<void>();
const void* sendbuff = in->data<T>();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
......@@ -260,7 +331,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
const void* sendbuff = in->data<void>();
const void* sendbuff = in->data<T>();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册