未验证 提交 1432e3d2 编写于 作者: R ronnywang 提交者: GitHub

Fix bug (#37868)

上级 6a1e4de2
...@@ -27,8 +27,8 @@ ...@@ -27,8 +27,8 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
...@@ -145,7 +145,7 @@ void NCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) { ...@@ -145,7 +145,7 @@ void NCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) {
void *src_ptr = src_tensor->data<void>(); void *src_ptr = src_tensor->data<void>();
auto nccl_dtype = platform::ToNCCLDataType(src_tensor->type()); auto nccl_dtype = platform::ToNCCLDataType(src_tensor->type());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBcast( PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclBcast(
src_ptr, src_tensor->numel(), nccl_dtype, 0, comm->comm(), stream)); src_ptr, src_tensor->numel(), nccl_dtype, 0, comm->comm(), stream));
} }
......
if(WIN32) if(WIN32)
cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS device_context) cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS device_context)
else() else()
if (WITH_NCCL OR WITH_RCCL) if (WITH_GLOO AND (WITH_NCCL OR WITH_RCCL))
cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context) cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context)
cc_test(heter_ccl_context_test SRCS heter_ccl_context_test.cc DEPS heter_ccl_context nccl_context imperative_gloo_context gloo_context gloo_wrapper gloo fs shell) cc_test(heter_ccl_context_test SRCS heter_ccl_context_test.cc DEPS heter_ccl_context nccl_context imperative_gloo_context gloo_context gloo_wrapper gloo fs shell)
#set_tests_properties(heter_ccl_context_test PROPERTIES LABELS "RUN_TYPE=DIST") #set_tests_properties(heter_ccl_context_test PROPERTIES LABELS "RUN_TYPE=DIST")
......
...@@ -79,7 +79,7 @@ void AllReduceByStream(int local_rank, int device_id) { ...@@ -79,7 +79,7 @@ void AllReduceByStream(int local_rank, int device_id) {
} }
TEST(AllReduceByStream, Run) { TEST(AllReduceByStream, Run) {
if (platform::GetCUDADeviceCount() >= 2) { if (platform::GetGPUDeviceCount() >= 2) {
std::thread t0(AllReduceByStream, 0, 0); std::thread t0(AllReduceByStream, 0, 0);
std::thread t1(AllReduceByStream, 1, 1); std::thread t1(AllReduceByStream, 1, 1);
t0.join(); t0.join();
......
...@@ -111,7 +111,7 @@ void Broadcast(int local_rank, int device_id) { ...@@ -111,7 +111,7 @@ void Broadcast(int local_rank, int device_id) {
} }
TEST(Broadcast, Run) { TEST(Broadcast, Run) {
if (platform::GetCUDADeviceCount() >= 2) { if (platform::GetGPUDeviceCount() >= 2) {
std::thread t0(Broadcast, 0, 0); std::thread t0(Broadcast, 0, 0);
std::thread t1(Broadcast, 1, 1); std::thread t1(Broadcast, 1, 1);
t0.join(); t0.join();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册