未验证 提交 876aa717 编写于 作者: S seemingwang 提交者: GitHub

support distributed graph_split load and query. (#37740)

上级 a710abee
......@@ -514,6 +514,42 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
return fut;
}
std::future<int32_t> GraphBrpcClient::load_graph_split_config(
uint32_t table_id, std::string path) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG) != 0) {
++fail_num;
break;
}
}
ret = fail_num == 0 ? 0 : -1;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < server_size; i++) {
int server_index = i;
closure->request(server_index)
->set_cmd_id(PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)->add_params(path);
GraphPsService_Stub rpc_stub =
getServiceStub(get_cmd_channel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
closure->response(server_index), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
uint32_t table_id, size_t total_size_limit, size_t ttl) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
......
......@@ -93,6 +93,8 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
size_t size_limit,
size_t ttl);
virtual std::future<int32_t> load_graph_split_config(uint32_t table_id,
std::string path);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<uint64_t>& node_id_list);
virtual int32_t initialize();
......
......@@ -204,6 +204,8 @@ int32_t GraphBrpcService::initialize() {
&GraphBrpcService::sample_neighbors_across_multi_servers;
_service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
&GraphBrpcService::use_neighbors_sample_cache;
_service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
&GraphBrpcService::load_graph_split_config;
// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info();
......@@ -658,5 +660,20 @@ int32_t GraphBrpcService::use_neighbors_sample_cache(
((GraphTable *)table)->make_neighbor_sample_cache(size_limit, ttl);
return 0;
}
int32_t GraphBrpcService::load_graph_split_config(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(response, -1,
"load_graph_split_configrequest requires at least 1 "
"argument1[file_path]");
return 0;
}
((GraphTable *)table)->load_graph_split_config(request.params(0));
return 0;
}
} // namespace distributed
} // namespace paddle
......@@ -126,6 +126,10 @@ class GraphBrpcService : public PsBaseService {
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t load_graph_split_config(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
private:
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
......
......@@ -58,6 +58,7 @@ enum PsCmdID {
PS_GRAPH_SET_NODE_FEAT = 37;
PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER = 38;
PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE = 39;
PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG = 40;
}
message PsRequestMessage {
......
......@@ -56,7 +56,7 @@ int32_t GraphTable::add_graph_node(std::vector<uint64_t> &id_list,
tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p.first % this->shard_num - this->shard_start;
this->shards[index].add_graph_node(p.first)->build_edges(p.second);
this->shards[index]->add_graph_node(p.first)->build_edges(p.second);
}
return 0;
}));
......@@ -79,7 +79,7 @@ int32_t GraphTable::remove_graph_node(std::vector<uint64_t> &id_list) {
tasks.push_back(_shards_task_pool[i]->enqueue([&batch, i, this]() -> int {
for (auto &p : batch[i]) {
size_t index = p % this->shard_num - this->shard_start;
this->shards[index].delete_node(p);
this->shards[index]->delete_node(p);
}
return 0;
}));
......@@ -97,6 +97,7 @@ void GraphShard::clear() {
}
GraphShard::~GraphShard() { clear(); }
void GraphShard::delete_node(uint64_t id) {
auto iter = node_location.find(id);
if (iter == node_location.end()) return;
......@@ -117,6 +118,14 @@ GraphNode *GraphShard::add_graph_node(uint64_t id) {
return (GraphNode *)bucket[node_location[id]];
}
GraphNode *GraphShard::add_graph_node(Node *node) {
auto id = node->get_id();
if (node_location.find(id) == node_location.end()) {
node_location[id] = bucket.size();
bucket.push_back(node);
}
return (GraphNode *)bucket[node_location[id]];
}
FeatureNode *GraphShard::add_feature_node(uint64_t id) {
if (node_location.find(id) == node_location.end()) {
node_location[id] = bucket.size();
......@@ -134,6 +143,33 @@ Node *GraphShard::find_node(uint64_t id) {
return iter == node_location.end() ? nullptr : bucket[iter->second];
}
GraphTable::~GraphTable() {
for (auto p : shards) {
delete p;
}
for (auto p : extra_shards) {
delete p;
}
shards.clear();
extra_shards.clear();
}
int32_t GraphTable::load_graph_split_config(const std::string &path) {
VLOG(4) << "in server side load graph split config\n";
std::ifstream file(path);
std::string line;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
if (values.size() < 2) continue;
size_t index = (size_t)std::stoi(values[0]);
if (index != _shard_idx) continue;
auto dst_id = std::stoull(values[1]);
extra_nodes.insert(dst_id);
}
if (extra_nodes.size() != 0) use_duplicate_nodes = true;
return 0;
}
int32_t GraphTable::load(const std::string &path, const std::string &param) {
bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n');
......@@ -154,7 +190,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
res.clear();
std::vector<std::future<std::vector<uint64_t>>> tasks;
for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) {
end = total_size + shards[i].get_size();
end = total_size + shards[i]->get_size();
start = total_size;
while (start < end && index < ranges.size()) {
if (ranges[index].second <= start)
......@@ -169,11 +205,11 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
second -= total_size;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, first, second, i]() -> std::vector<uint64_t> {
return shards[i].get_ids_by_range(first, second);
return shards[i]->get_ids_by_range(first, second);
}));
}
}
total_size += shards[i].get_size();
total_size += shards[i]->get_size();
}
for (size_t i = 0; i < tasks.size(); i++) {
auto vec = tasks[i].get();
......@@ -217,7 +253,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
size_t index = shard_id - shard_start;
auto node = shards[index].add_feature_node(id);
auto node = shards[index]->add_feature_node(id);
node->set_feature_size(feat_name.size());
......@@ -245,7 +281,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
std::string sample_type = "random";
bool is_weighted = false;
int valid_count = 0;
int extra_alloc_index = 0;
for (auto path : paths) {
std::ifstream file(path);
std::string line;
......@@ -268,8 +304,24 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
size_t src_shard_id = src_id % shard_num;
if (src_shard_id >= shard_end || src_shard_id < shard_start) {
VLOG(4) << "will not load " << src_id << " from " << path
<< ", please check id distribution";
if (use_duplicate_nodes == false ||
extra_nodes.find(src_id) == extra_nodes.end()) {
VLOG(4) << "will not load " << src_id << " from " << path
<< ", please check id distribution";
continue;
}
int index;
if (extra_nodes_to_thread_index.find(src_id) !=
extra_nodes_to_thread_index.end()) {
index = extra_nodes_to_thread_index[src_id];
} else {
index = extra_alloc_index++;
extra_alloc_index %= task_pool_size_;
extra_nodes_to_thread_index[src_id] = index;
}
extra_shards[index]->add_graph_node(src_id)->build_edges(is_weighted);
extra_shards[index]->add_neighbor(src_id, dst_id, weight);
valid_count++;
continue;
}
if (count % 1000000 == 0) {
......@@ -278,36 +330,130 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
}
size_t index = src_shard_id - shard_start;
shards[index].add_graph_node(src_id)->build_edges(is_weighted);
shards[index].add_neighbor(src_id, dst_id, weight);
shards[index]->add_graph_node(src_id)->build_edges(is_weighted);
shards[index]->add_neighbor(src_id, dst_id, weight);
valid_count++;
}
}
VLOG(0) << valid_count << "/" << count << " edges are loaded successfully in "
<< path;
std::vector<int> used(task_pool_size_, 0);
// Build Sampler j
for (auto &shard : shards) {
auto bucket = shard.get_bucket();
auto bucket = shard->get_bucket();
for (size_t i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
used[get_thread_pool_index(bucket[i]->get_id())]++;
}
}
/*-----------------------
relocate the duplicate nodes to make them distributed evenly among threads.
*/
for (auto &shard : extra_shards) {
auto bucket = shard->get_bucket();
for (size_t i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type);
}
}
int size = extra_nodes_to_thread_index.size();
if (size == 0) return 0;
std::vector<int> index;
for (int i = 0; i < used.size(); i++) index.push_back(i);
sort(index.begin(), index.end(),
[&](int &a, int &b) { return used[a] < used[b]; });
std::vector<int> alloc(index.size(), 0), has_alloc(index.size(), 0);
int t = 1, aim = 0, mod = 0;
for (; t < used.size(); t++) {
if ((used[index[t]] - used[index[t - 1]]) * t >= size) {
break;
} else {
size -= (used[index[t]] - used[index[t - 1]]) * t;
}
}
aim = used[index[t - 1]] + size / t;
mod = size % t;
for (int x = t - 1; x >= 0; x--) {
alloc[index[x]] = aim;
if (t - x <= mod) alloc[index[x]]++;
alloc[index[x]] -= used[index[x]];
}
std::vector<uint64_t> vec[index.size()];
for (auto p : extra_nodes_to_thread_index) {
has_alloc[p.second]++;
vec[p.second].push_back(p.first);
}
sort(index.begin(), index.end(), [&](int &a, int &b) {
return has_alloc[a] - alloc[a] < has_alloc[b] - alloc[b];
});
int left = 0, right = index.size() - 1;
while (left < right) {
if (has_alloc[index[right]] - alloc[index[right]] == 0) break;
int x = std::min(alloc[index[left]] - has_alloc[index[left]],
has_alloc[index[right]] - alloc[index[right]]);
has_alloc[index[left]] += x;
has_alloc[index[right]] -= x;
uint64_t id;
while (x--) {
id = vec[index[right]].back();
vec[index[right]].pop_back();
extra_nodes_to_thread_index[id] = index[left];
vec[index[left]].push_back(id);
}
if (has_alloc[index[right]] - alloc[index[right]] == 0) right--;
if (alloc[index[left]] - has_alloc[index[left]] == 0) left++;
}
std::vector<GraphShard *> extra_shards_copy;
for (int i = 0; i < task_pool_size_; ++i) {
extra_shards_copy.push_back(new GraphShard());
}
for (auto &shard : extra_shards) {
auto &bucket = shard->get_bucket();
auto &node_location = shard->get_node_location();
while (bucket.size()) {
Node *temp = bucket.back();
bucket.pop_back();
node_location.erase(temp->get_id());
extra_shards_copy[extra_nodes_to_thread_index[temp->get_id()]]
->add_graph_node(temp);
}
}
for (int i = 0; i < task_pool_size_; ++i) {
delete extra_shards[i];
extra_shards[i] = extra_shards_copy[i];
}
return 0;
}
Node *GraphTable::find_node(uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return nullptr;
if (use_duplicate_nodes == false || extra_nodes_to_thread_index.size() == 0)
return nullptr;
auto iter = extra_nodes_to_thread_index.find(id);
if (iter == extra_nodes_to_thread_index.end())
return nullptr;
else {
return extra_shards[iter->second]->find_node(id);
}
}
size_t index = shard_id - shard_start;
Node *node = shards[index].find_node(id);
Node *node = shards[index]->find_node(id);
return node;
}
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
return node_id % shard_num % shard_num_per_server % task_pool_size_;
if (use_duplicate_nodes == false || extra_nodes_to_thread_index.size() == 0)
return node_id % shard_num % shard_num_per_server % task_pool_size_;
size_t src_shard_id = node_id % shard_num;
if (src_shard_id >= shard_end || src_shard_id < shard_start) {
auto iter = extra_nodes_to_thread_index.find(node_id);
if (iter != extra_nodes_to_thread_index.end()) {
return iter->second;
}
}
return src_shard_id % shard_num_per_server % task_pool_size_;
}
uint32_t GraphTable::get_thread_pool_index_by_shard_index(
......@@ -319,11 +465,16 @@ int32_t GraphTable::clear_nodes() {
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < shards.size(); i++) {
tasks.push_back(
_shards_task_pool[get_thread_pool_index_by_shard_index(i)]->enqueue(
[this, i]() -> int {
this->shards[i].clear();
return 0;
}));
_shards_task_pool[i % task_pool_size_]->enqueue([this, i]() -> int {
this->shards[i]->clear();
return 0;
}));
}
for (size_t i = 0; i < extra_shards.size(); i++) {
tasks.push_back(_shards_task_pool[i]->enqueue([this, i]() -> int {
this->extra_shards[i]->clear();
return 0;
}));
}
for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
return 0;
......@@ -334,7 +485,7 @@ int32_t GraphTable::random_sample_nodes(int sample_size,
int &actual_size) {
int total_size = 0;
for (int i = 0; i < shards.size(); i++) {
total_size += shards[i].get_size();
total_size += shards[i]->get_size();
}
if (sample_size > total_size) sample_size = total_size;
int range_num = random_sample_nodes_ranges;
......@@ -401,8 +552,8 @@ int32_t GraphTable::random_sample_neighbors(
size_t node_num = buffers.size();
std::function<void(char *)> char_del = [](char *c) { delete[] c; };
std::vector<std::future<int>> tasks;
std::vector<std::vector<uint32_t>> seq_id(shard_end - shard_start);
std::vector<std::vector<SampleKey>> id_list(shard_end - shard_start);
std::vector<std::vector<uint32_t>> seq_id(task_pool_size_);
std::vector<std::vector<SampleKey>> id_list(task_pool_size_);
size_t index;
for (size_t idx = 0; idx < node_num; ++idx) {
index = get_thread_pool_index(node_ids[idx]);
......@@ -524,7 +675,7 @@ int32_t GraphTable::set_node_feat(
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&, idx, node_id]() -> int {
size_t index = node_id % this->shard_num - this->shard_start;
auto node = shards[index].add_feature_node(node_id);
auto node = shards[index]->add_feature_node(node_id);
node->set_feature_size(this->feat_name.size());
for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
const std::string &feature_name = feature_names[feat_idx];
......@@ -581,7 +732,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int size = 0, cur_size;
std::vector<std::future<std::vector<Node *>>> tasks;
for (size_t i = 0; i < shards.size() && total_size > 0; i++) {
cur_size = shards[i].get_size();
cur_size = shards[i]->get_size();
if (size + cur_size <= start) {
size += cur_size;
continue;
......@@ -590,7 +741,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int end = start + (count - 1) * step + 1;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, end, step, size]() -> std::vector<Node *> {
return this->shards[i].get_batch(start - size, end - size, step);
return this->shards[i]->get_batch(start - size, end - size, step);
}));
start += count * step;
total_size -= count;
......@@ -665,7 +816,14 @@ int32_t GraphTable::initialize() {
shard_end = shard_start + shard_num_per_server;
VLOG(0) << "in init graph table shard idx = " << _shard_idx << " shard_start "
<< shard_start << " shard_end " << shard_end;
shards = std::vector<GraphShard>(shard_num_per_server, GraphShard(shard_num));
for (int i = 0; i < shard_num_per_server; i++) {
shards.push_back(new GraphShard());
}
use_duplicate_nodes = false;
for (int i = 0; i < task_pool_size_; i++) {
extra_shards.push_back(new GraphShard());
}
return 0;
}
} // namespace distributed
......
......@@ -47,7 +47,6 @@ class GraphShard {
public:
size_t get_size();
GraphShard() {}
GraphShard(int shard_num) { this->shard_num = shard_num; }
~GraphShard();
std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step);
......@@ -60,18 +59,18 @@ class GraphShard {
}
GraphNode *add_graph_node(uint64_t id);
GraphNode *add_graph_node(Node *node);
FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id);
void delete_node(uint64_t id);
void clear();
void add_neighbor(uint64_t id, uint64_t dst_id, float weight);
std::unordered_map<uint64_t, int> get_node_location() {
std::unordered_map<uint64_t, int> &get_node_location() {
return node_location;
}
private:
std::unordered_map<uint64_t, int> node_location;
int shard_num;
std::vector<Node *> bucket;
};
......@@ -355,7 +354,7 @@ class ScaledLRU {
class GraphTable : public SparseTable {
public:
GraphTable() { use_cache = false; }
virtual ~GraphTable() {}
virtual ~GraphTable();
virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature,
......@@ -374,6 +373,7 @@ class GraphTable : public SparseTable {
virtual int32_t initialize();
int32_t load(const std::string &path, const std::string &param);
int32_t load_graph_split_config(const std::string &path);
int32_t load_edges(const std::string &path, bool reverse);
......@@ -434,7 +434,7 @@ class GraphTable : public SparseTable {
}
protected:
std::vector<GraphShard> shards;
std::vector<GraphShard *> shards, extra_shards;
size_t shard_start, shard_end, server_num, shard_num_per_server, shard_num;
const int task_pool_size_ = 24;
const int random_sample_nodes_ranges = 3;
......@@ -449,7 +449,9 @@ class GraphTable : public SparseTable {
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
bool use_cache;
std::unordered_set<uint64_t> extra_nodes;
std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
bool use_cache, use_duplicate_nodes;
mutable std::mutex mutex_;
};
} // namespace distributed
......
......@@ -65,6 +65,9 @@ void GraphNode::build_edges(bool is_weighted) {
}
}
void GraphNode::build_sampler(std::string sample_type) {
if (sampler != nullptr) {
return;
}
if (sample_type == "random") {
sampler = new RandomSampler();
} else if (sample_type == "weighted") {
......
......@@ -21,6 +21,9 @@ cc_test(brpc_utils_test SRCS brpc_utils_test.cc DEPS brpc_utils scope math_funct
set_source_files_properties(graph_node_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_node_test SRCS graph_node_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(graph_node_split_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(graph_node_split_test SRCS graph_node_split_test.cc DEPS graph_py_service scope server client communicator ps_service boost table ps_framework_proto ${COMMON_DEPS})
set_source_files_properties(feature_value_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(feature_value_test SRCS feature_value_test.cc DEPS ${COMMON_DEPS} boost table)
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <fstream>
#include <iomanip>
#include <string>
#include <thread> // NOLINT
#include <unordered_set>
#include <vector>
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/service/env.h"
#include "paddle/fluid/distributed/service/graph_brpc_client.h"
#include "paddle/fluid/distributed/service/graph_brpc_server.h"
#include "paddle/fluid/distributed/service/graph_py_service.h"
#include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/service/service.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;
namespace distributed = paddle::distributed;
std::vector<std::string> edges = {
std::string("37\t45\t0.34"), std::string("37\t145\t0.31"),
std::string("37\t112\t0.21"), std::string("96\t48\t1.4"),
std::string("96\t247\t0.31"), std::string("96\t111\t1.21"),
std::string("59\t45\t0.34"), std::string("59\t145\t0.31"),
std::string("59\t122\t0.21"), std::string("97\t48\t0.34"),
std::string("97\t247\t0.31"), std::string("97\t111\t0.21")};
char edge_file_name[] = "edges.txt";
std::vector<std::string> nodes = {
std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"),
std::string("user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd"),
std::string("user\t59\ta 0.11\tb 11 14"),
std::string("user\t97\ta 0.11\tb 12 11"),
std::string("item\t45\ta 0.21"),
std::string("item\t145\ta 0.21"),
std::string("item\t112\ta 0.21"),
std::string("item\t48\ta 0.21"),
std::string("item\t247\ta 0.21"),
std::string("item\t111\ta 0.21"),
std::string("item\t46\ta 0.21"),
std::string("item\t146\ta 0.21"),
std::string("item\t122\ta 0.21"),
std::string("item\t49\ta 0.21"),
std::string("item\t248\ta 0.21"),
std::string("item\t113\ta 0.21")};
char node_file_name[] = "nodes.txt";
std::vector<std::string> graph_split = {std::string("0\t97")};
char graph_split_file_name[] = "graph_split.txt";
void prepare_file(char file_name[], std::vector<std::string> data) {
std::ofstream ofile;
ofile.open(file_name);
for (auto x : data) {
ofile << x << std::endl;
}
ofile.close();
}
void GetDownpourSparseTableProto(
::paddle::distributed::TableParameter* sparse_table_proto) {
sparse_table_proto->set_table_id(0);
sparse_table_proto->set_table_class("GraphTable");
sparse_table_proto->set_shard_num(127);
sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
::paddle::distributed::TableAccessorParameter* accessor_proto =
sparse_table_proto->mutable_accessor();
accessor_proto->set_accessor_class("CommMergeAccessor");
}
::paddle::distributed::PSParameter GetServerProto() {
// Generate server proto desc
::paddle::distributed::PSParameter server_fleet_desc;
::paddle::distributed::ServerParameter* server_proto =
server_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("GraphBrpcService");
server_service_proto->set_server_class("GraphBrpcServer");
server_service_proto->set_client_class("GraphBrpcClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
::paddle::distributed::TableParameter* sparse_table_proto =
downpour_server_proto->add_downpour_table_param();
GetDownpourSparseTableProto(sparse_table_proto);
return server_fleet_desc;
}
::paddle::distributed::PSParameter GetWorkerProto() {
::paddle::distributed::PSParameter worker_fleet_desc;
::paddle::distributed::WorkerParameter* worker_proto =
worker_fleet_desc.mutable_worker_param();
::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
worker_proto->mutable_downpour_worker_param();
::paddle::distributed::TableParameter* worker_sparse_table_proto =
downpour_worker_proto->add_downpour_table_param();
GetDownpourSparseTableProto(worker_sparse_table_proto);
::paddle::distributed::ServerParameter* server_proto =
worker_fleet_desc.mutable_server_param();
::paddle::distributed::DownpourServerParameter* downpour_server_proto =
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("GraphBrpcService");
server_service_proto->set_server_class("GraphBrpcServer");
server_service_proto->set_client_class("GraphBrpcClient");
server_service_proto->set_start_server_port(0);
server_service_proto->set_server_thread_num(12);
::paddle::distributed::TableParameter* server_sparse_table_proto =
downpour_server_proto->add_downpour_table_param();
GetDownpourSparseTableProto(server_sparse_table_proto);
return worker_fleet_desc;
}
/*-------------------------------------------------------------------------*/
std::string ip_ = "127.0.0.1", ip2 = "127.0.0.1";
uint32_t port_ = 5209, port2 = 5210;
std::vector<std::string> host_sign_list_;
std::shared_ptr<paddle::distributed::GraphBrpcServer> pserver_ptr_,
pserver_ptr2;
std::shared_ptr<paddle::distributed::GraphBrpcClient> worker_ptr_;
void RunServer() {
LOG(INFO) << "init first server";
::paddle::distributed::PSParameter server_proto = GetServerProto();
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, 2); // test
pserver_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto));
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr_->configure(server_proto, _ps_env, 0, empty_vec);
LOG(INFO) << "first server, run start(ip,port)";
pserver_ptr_->start(ip_, port_);
pserver_ptr_->build_peer2peer_connection(0);
LOG(INFO) << "init first server Done";
}
void RunServer2() {
LOG(INFO) << "init second server";
::paddle::distributed::PSParameter server_proto2 = GetServerProto();
auto _ps_env2 = paddle::distributed::PaddlePSEnvironment();
_ps_env2.set_ps_servers(&host_sign_list_, 2); // test
pserver_ptr2 = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto2));
std::vector<framework::ProgramDesc> empty_vec2;
framework::ProgramDesc empty_prog2;
empty_vec2.push_back(empty_prog2);
pserver_ptr2->configure(server_proto2, _ps_env2, 1, empty_vec2);
pserver_ptr2->start(ip2, port2);
pserver_ptr2->build_peer2peer_connection(1);
}
void RunClient(
std::map<uint64_t, std::vector<paddle::distributed::Region>>& dense_regions,
int index, paddle::distributed::PsBaseService* service) {
::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list_, servers_);
worker_ptr_ = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
(paddle::distributed::GraphBrpcClient*)
paddle::distributed::PSClientFactory::create(worker_proto));
worker_ptr_->configure(worker_proto, dense_regions, _ps_env, 0);
worker_ptr_->set_shard_num(127);
worker_ptr_->set_local_channel(index);
worker_ptr_->set_local_graph_service(
(paddle::distributed::GraphBrpcService*)service);
}
void RunGraphSplit() {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
prepare_file(edge_file_name, edges);
prepare_file(node_file_name, nodes);
prepare_file(graph_split_file_name, graph_split);
auto ph_host = paddle::distributed::PSHost(ip_, port_, 0);
host_sign_list_.push_back(ph_host.serialize_to_string());
// test-start
auto ph_host2 = paddle::distributed::PSHost(ip2, port2, 1);
host_sign_list_.push_back(ph_host2.serialize_to_string());
// test-end
// Srart Server
std::thread* server_thread = new std::thread(RunServer);
std::thread* server_thread2 = new std::thread(RunServer2);
sleep(2);
std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
dense_regions.insert(
std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
auto regions = dense_regions[0];
RunClient(dense_regions, 0, pserver_ptr_->get_service());
/*-----------------------Test Server Init----------------------------------*/
auto pull_status = worker_ptr_->load_graph_split_config(
0, std::string(graph_split_file_name));
pull_status.wait();
pull_status =
worker_ptr_->load(0, std::string(edge_file_name), std::string("e>"));
srand(time(0));
pull_status.wait();
std::vector<std::vector<uint64_t>> _vs;
std::vector<std::vector<float>> vs;
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 10240001024), 4, _vs, vs, true);
pull_status.wait();
ASSERT_EQ(0, _vs[0].size());
_vs.clear();
vs.clear();
pull_status = worker_ptr_->batch_sample_neighbors(
0, std::vector<uint64_t>(1, 97), 4, _vs, vs, true);
pull_status.wait();
ASSERT_EQ(3, _vs[0].size());
std::remove(edge_file_name);
std::remove(node_file_name);
std::remove(graph_split_file_name);
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker();
}
TEST(RunGraphSplit, Run) { RunGraphSplit(); }
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册