未验证 提交 1432e3d2 编写于 作者: R ronnywang 提交者: GitHub

Fix bug (#37868)

上级 6a1e4de2
......@@ -27,8 +27,8 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
......@@ -145,7 +145,7 @@ void NCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
void *src_ptr = src_tensor->data<void>();
auto nccl_dtype = platform::ToNCCLDataType(src_tensor->type());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast(
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
src_ptr, src_tensor->numel(), nccl_dtype, 0, comm->comm(), stream));
}
......
if(WIN32)
cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS device_context)
else()
if (WITH_NCCL OR WITH_RCCL)
if (WITH_GLOO AND (WITH_NCCL OR WITH_RCCL))
cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context)
cc_test(heter_ccl_context_test SRCS heter_ccl_context_test.cc DEPS heter_ccl_context nccl_context imperative_gloo_context gloo_context gloo_wrapper gloo fs shell)
#set_tests_properties(heter_ccl_context_test PROPERTIES LABELS "RUN_TYPE=DIST")
......
......@@ -79,7 +79,7 @@ void AllReduceByStream(int local_rank, int device_id) {
}
TEST(AllReduceByStream, Run) {
if (platform::GetCUDADeviceCount() >= 2) {
if (platform::GetGPUDeviceCount() >= 2) {
std::thread t0(AllReduceByStream, 0, 0);
std::thread t1(AllReduceByStream, 1, 1);
t0.join();
......
......@@ -111,7 +111,7 @@ void Broadcast(int local_rank, int device_id) {
}
TEST(Broadcast, Run) {
if (platform::GetCUDADeviceCount() >= 2) {
if (platform::GetGPUDeviceCount() >= 2) {
std::thread t0(Broadcast, 0, 0);
std::thread t1(Broadcast, 1, 1);
t0.join();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册