Unverified commit 539d7185, authored by danleifeng, committed by GitHub

【HETERPS】edit cuda remote_streams (#34276)

* psgpu:edit cuda remote_streams; test=develop
Parent: ca174025
@@ -23,8 +23,9 @@ limitations under the License. */
 #ifdef PADDLE_WITH_PSCORE
 #include "paddle/fluid/distributed/table/depends/large_scale_kv.h"
 #endif
+#include "paddle/fluid/framework/rw_lock.h"
 #include "thrust/pair.h"
-//#include "cudf/concurrent_unordered_map.cuh.h"
+// #include "cudf/concurrent_unordered_map.cuh.h"
 #include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
 #ifdef PADDLE_WITH_HETERPS
 #include "paddle/fluid/platform/type_defs.h"
@@ -63,6 +64,8 @@ class HashTable {
   int size() { return container_->size(); }
 
+  std::unique_ptr<RWLock> rwlock_{nullptr};
+
  private:
   TableContainer<KeyType, ValType>* container_;
   int BLOCK_SIZE_{256};
......
@@ -73,6 +73,7 @@ __global__ void update_kernel(Table* table,
 template <typename KeyType, typename ValType>
 HashTable<KeyType, ValType>::HashTable(size_t capacity) {
   container_ = new TableContainer<KeyType, ValType>(capacity);
+  rwlock_.reset(new RWLock);
 }
 
 template <typename KeyType, typename ValType>
......
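Taken together, the two hunks above give each HashTable a reader-writer lock: pull_sparse readers can run get() concurrently, while push_sparse updates take the lock exclusively. The following is an illustration of that pattern only, not Paddle's API: std::shared_mutex stands in for paddle::framework::RWLock, and std::unordered_map stands in for the GPU-resident TableContainer.

#include <mutex>
#include <shared_mutex>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for a HashTable guarded by a reader-writer lock.
template <typename KeyType, typename ValType>
class GuardedTable {
 public:
  // Pull path: many readers may call get() concurrently (RDLock in the diff).
  void get(const std::vector<KeyType>& keys, std::vector<ValType>* vals) const {
    std::shared_lock<std::shared_mutex> lock(mutex_);
    vals->clear();
    for (const auto& k : keys) {
      auto it = table_.find(k);
      vals->push_back(it == table_.end() ? ValType{} : it->second);
    }
  }

  // Push path: update() excludes all readers and writers (WRLock in the diff).
  void update(const std::vector<KeyType>& keys,
              const std::vector<ValType>& grads) {
    std::unique_lock<std::shared_mutex> lock(mutex_);
    for (size_t i = 0; i < keys.size(); ++i) table_[keys[i]] += grads[i];
  }

 private:
  mutable std::shared_mutex mutex_;
  std::unordered_map<KeyType, ValType> table_;
};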
@@ -525,12 +525,15 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
     auto& node = path_[num][i].nodes_.back();
     cudaStreamSynchronize(node.in_stream);
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
+    tables_[i]->rwlock_->RDLock();
     tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
                     reinterpret_cast<ValType*>(node.val_storage),
-                    h_right[i] - h_left[i] + 1, resource_->remote_stream(i));
+                    h_right[i] - h_left[i] + 1,
+                    resource_->remote_stream(i, num));
   }
   for (int i = 0; i < total_gpu; ++i) {
-    cudaStreamSynchronize(resource_->remote_stream(i));
+    cudaStreamSynchronize(resource_->remote_stream(i, num));
+    tables_[i]->rwlock_->UNLock();
   }
 
   walk_to_src(num, total_gpu, h_left, h_right, d_shard_vals_ptr);
@@ -621,13 +624,15 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int gpu_num,
     cudaStreamSynchronize(node.in_stream);
     platform::CUDADeviceGuard guard(resource_->dev_id(i));
+    tables_[i]->rwlock_->WRLock();
     tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
                        reinterpret_cast<GradType*>(node.val_storage),
                        h_right[i] - h_left[i] + 1, sgd,
-                       resource_->remote_stream(i));
+                       resource_->remote_stream(i, gpu_num));
   }
   for (int i = 0; i < total_gpu; ++i) {
-    cudaStreamSynchronize(resource_->remote_stream(i));
+    cudaStreamSynchronize(resource_->remote_stream(i, gpu_num));
+    tables_[i]->rwlock_->UNLock();
   }
 }
@@ -641,9 +646,11 @@ void HeterComm<KeyType, ValType, GradType>::update_one_table(
   int dev_id = resource_->dev_id(gpu_num);
   platform::CUDADeviceGuard guard(dev_id);
+  tables_[gpu_num]->rwlock_->WRLock();
   tables_[gpu_num]->update(d_keys, d_grads, len, sgd,
-                           resource_->remote_stream(gpu_num));
-  cudaStreamSynchronize(resource_->remote_stream(gpu_num));
+                           resource_->remote_stream(gpu_num, gpu_num));
+  tables_[gpu_num]->rwlock_->UNLock();
+  cudaStreamSynchronize(resource_->remote_stream(gpu_num, gpu_num));
 }
 
 template <typename KeyType, typename ValType, typename GradType>
......
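Both pull_sparse and push_sparse now follow the same ordering: take the table lock, launch the table kernel asynchronously on the destination GPU's remote stream dedicated to the requesting GPU (remote_stream(i, num) / remote_stream(i, gpu_num)), and release the lock only after that stream has been synchronized, because the launch returns while the kernel may still be running. Below is a stripped-down, self-contained sketch of just that ordering, with hypothetical names; a std::mutex per device stands in for the table's rwlock_.

#include <cuda_runtime.h>

#include <mutex>
#include <vector>

// Stand-in kernel for the HashTable::get / HashTable::update work.
__global__ void touch_table(float* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1.0f;
}

// Hypothetical fan-out: one lock and one per-requester stream per destination
// device. Pass 1 locks and launches; pass 2 synchronizes and unlocks, so the
// lock is held for as long as the asynchronous kernel can still be running.
void fan_out(std::vector<float*>& dev_bufs, const std::vector<int>& lens,
             std::vector<std::mutex>& table_locks,
             std::vector<cudaStream_t>& remote_streams) {
  const int total_gpu = static_cast<int>(dev_bufs.size());
  for (int i = 0; i < total_gpu; ++i) {
    table_locks[i].lock();   // like tables_[i]->rwlock_->WRLock()
    cudaSetDevice(i);        // assumes device ids are 0..total_gpu-1
    const int threads = 1024;
    const int blocks = (lens[i] + threads - 1) / threads;
    touch_table<<<blocks, threads, 0, remote_streams[i]>>>(dev_bufs[i], lens[i]);
  }
  for (int i = 0; i < total_gpu; ++i) {
    cudaStreamSynchronize(remote_streams[i]);  // wait for the async launch
    table_locks[i].unlock();                   // like rwlock_->UNLock()
  }
}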
@@ -27,16 +27,16 @@ GPUResource::GPUResource(std::vector<int>& dev_ids, int index) {
   platform::CUDADeviceGuard guard(dev_id_);
   local_streams_.resize(dev_ids_.size());
   comm_streams_.resize(dev_ids_.size());
+  remote_streams_.resize(dev_ids_.size());
   for (size_t i = 0; i < dev_ids_.size(); ++i) {
     PADDLE_ENFORCE_CUDA_SUCCESS(
         cudaStreamCreateWithFlags(&local_streams_[i], cudaStreamNonBlocking));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         cudaStreamCreateWithFlags(&comm_streams_[i], cudaStreamNonBlocking));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        cudaStreamCreateWithFlags(&remote_streams_[i], cudaStreamNonBlocking));
   }
-  PADDLE_ENFORCE_CUDA_SUCCESS(
-      cudaStreamCreateWithFlags(&remote_stream_, cudaStreamNonBlocking));
 }
 
 GPUResource::~GPUResource() {
@@ -47,7 +47,9 @@ GPUResource::~GPUResource() {
   for (size_t i = 0; i < comm_streams_.size(); ++i) {
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(comm_streams_[i]));
   }
-  PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(remote_stream_));
+  for (size_t i = 0; i < remote_streams_.size(); ++i) {
+    PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(remote_streams_[i]));
+  }
 }
 
 void HeterPsResource::enable_p2p() {
@@ -90,8 +92,8 @@ cudaStream_t HeterPsResource::local_stream(int gpu_num, int stream_num) {
   return resources_[gpu_num]->local_stream(stream_num);
 }
 
-cudaStream_t HeterPsResource::remote_stream(int gpu_num) {
-  return resources_[gpu_num]->remote_stream();
+cudaStream_t HeterPsResource::remote_stream(int gpu_num, int stream_num) {
+  return resources_[gpu_num]->remote_stream(stream_num);
 }
 
 int HeterPsResource::dev_id(int num) { return dev_ids_[num]; }
......
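The single remote_stream_ per GPUResource becomes a vector with one non-blocking stream per peer GPU, so requests arriving from different source GPUs no longer serialize on one stream. A minimal, self-contained sketch of such a per-peer stream pool follows; the class name is hypothetical, and plain CUDA error handling replaces PADDLE_ENFORCE_CUDA_SUCCESS.

#include <cuda_runtime.h>

#include <stdexcept>
#include <vector>

// Hypothetical stand-in for GPUResource: one non-blocking stream per peer GPU.
class RemoteStreamPool {
 public:
  RemoteStreamPool(int dev_id, int num_devices) : dev_id_(dev_id) {
    cudaSetDevice(dev_id_);
    remote_streams_.resize(num_devices);
    for (int i = 0; i < num_devices; ++i) {
      // Non-blocking: work on these streams never implicitly synchronizes
      // with the legacy default stream.
      cudaError_t err = cudaStreamCreateWithFlags(&remote_streams_[i],
                                                  cudaStreamNonBlocking);
      if (err != cudaSuccess) throw std::runtime_error(cudaGetErrorString(err));
    }
  }

  ~RemoteStreamPool() {
    for (auto& s : remote_streams_) cudaStreamDestroy(s);
  }

  // stream_num indexes the requesting (source) GPU, mirroring
  // resource_->remote_stream(gpu_num, stream_num) in the diff.
  cudaStream_t remote_stream(int stream_num) { return remote_streams_[stream_num]; }

 private:
  int dev_id_;
  std::vector<cudaStream_t> remote_streams_;
};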
@@ -35,13 +35,13 @@ class GPUResource {
   int dev_id() const { return dev_id_; }
   int index() const { return index_; }
   gpuStream_t local_stream(int num) { return local_streams_[num]; }
-  gpuStream_t remote_stream() { return remote_stream_; }
+  gpuStream_t remote_stream(int num) { return remote_streams_[num]; }
   gpuStream_t comm_stream(int num) { return comm_streams_[num]; }
 
   int dev_id_;
   int index_;
   std::vector<int> dev_ids_;
-  gpuStream_t remote_stream_;
+  std::vector<gpuStream_t> remote_streams_;
   std::vector<gpuStream_t> local_streams_;
   std::vector<gpuStream_t> comm_streams_;
 };
@@ -57,7 +57,7 @@ class HeterPsResource {
   int get_index_by_devid(int devid);
   int dev_id(int num);
   gpuStream_t local_stream(int gpu_num, int stream_num);
-  gpuStream_t remote_stream(int gpu_num);
+  gpuStream_t remote_stream(int gpu_num, int stream_num);
   gpuStream_t comm_stream(int gpu_num, int stream_num);
 
   std::vector<std::shared_ptr<GPUResource>> resources_;
......
@@ -121,7 +121,7 @@ void PSGPUWrapper::CopyForPull(const paddle::platform::Place& place,
   cudaMemcpy(gpu_values, values.data(), values.size() * sizeof(float*),
              cudaMemcpyHostToDevice);
-  PullCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
+  PullCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
       gpu_values, total_values_gpu, gpu_len, hidden_size, slot_num,
       total_length, gpu_keys);
   cudaStreamSynchronize(stream);
@@ -135,7 +135,7 @@ void PSGPUWrapper::CopyKeys(const paddle::platform::Place& place,
       platform::DeviceContextPool::Instance().Get(
           BOOST_GET_CONST(platform::CUDAPlace, place)))
           ->stream();
-  CopyKeysKernel<<<(total_len + 512 - 1) / 512, 512, 0, stream>>>(
+  CopyKeysKernel<<<(total_len + 1024 - 1) / 1024, 1024, 0, stream>>>(
       origin_keys, total_keys, gpu_len, slot_num, total_len);
   cudaStreamSynchronize(stream);
 }
@@ -173,7 +173,7 @@ void PSGPUWrapper::CopyForPush(const paddle::platform::Place& place,
   cudaMemcpy(d_slot_vector, slot_vector_.data(),
              slot_lengths_lod.size() * sizeof(int), cudaMemcpyHostToDevice);
-  PushCopy<<<(total_length + 512 - 1) / 512, 512, 0, stream>>>(
+  PushCopy<<<(total_length + 1024 - 1) / 1024, 1024, 0, stream>>>(
       total_grad_values_gpu, gpu_values, gpu_len, hidden_size,
       slot_lengths.size(), total_length, batch_size, d_slot_vector);
   cudaStreamSynchronize(stream);
......
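The three kernel launches above change only their launch configuration: the block size goes from 512 to 1024 threads, and the grid size keeps the usual ceil-division so each element is still covered by exactly one thread. A small, self-contained sketch of that launch pattern, using a hypothetical kernel rather than the real PullCopy/CopyKeysKernel/PushCopy signatures:

#include <cuda_runtime.h>

#include <cstdio>

// Hypothetical element-wise kernel; the guard `i < n` is what makes the
// rounded-up grid size safe when n is not a multiple of the block size.
__global__ void scale(float* data, int n, float factor) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

int main() {
  const int n = 3000;
  float* d_data = nullptr;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMemset(d_data, 0, n * sizeof(float));

  // Same ceil-division as in the diff: (n + 1024 - 1) / 1024 blocks of 1024.
  const int threads = 1024;
  const int blocks = (n + threads - 1) / threads;  // 3 blocks for n = 3000
  scale<<<blocks, threads, 0, /*stream=*/0>>>(d_data, n, 2.0f);
  cudaDeviceSynchronize();

  printf("launched %d blocks of %d threads for %d elements\n", blocks, threads, n);
  cudaFree(d_data);
  return 0;
}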