未验证 提交 fb276f23 编写于 作者: J jameszhang 提交者: GitHub

[kunlun] prevent overflow in collective softmax_with_ce (#52356)

* [kunlun] prevent numerical overflow in collective softmax_with_ce

* add fix in another branch
上级 4c6ad5c0
......@@ -131,6 +131,13 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::XPUContext, T> {
};
phi::XPUElementwise<T, XPUType>(
dev_ctx, logits_2d, logits_max, axis, &softmax_2d, f);
ret = xpu::clip<XPUType>(dev_ctx.x_context(),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
N * D,
-64.,
0.);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "clip");
}
// step 3, obtain predict target
......@@ -322,6 +329,13 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::XPUContext, T> {
};
phi::XPUElementwise<T, XPUType>(
dev_ctx, logits_2d, logits_max, axis, &softmax_2d, f);
ret = xpu::clip<XPUType>(dev_ctx.x_context(),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
reinterpret_cast<XPUType*>(softmax_2d.data<T>()),
N * D,
-64.,
0.);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "clip");
}
// step 3, obtain predict target
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册