提交 9833a00e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5450 Fixbugfix for server shard range computation

Merge pull request !5450 from ZPaC/r0.7-fix-local-shard-error
...@@ -149,10 +149,12 @@ void SparseOptimInfo::ComputeMean(const std::shared_ptr<std::vector<std::shared_ ...@@ -149,10 +149,12 @@ void SparseOptimInfo::ComputeMean(const std::shared_ptr<std::vector<std::shared_
size_t original_row_count = input_shapes->front(); size_t original_row_count = input_shapes->front();
if (original_row_count > 0) { if (original_row_count > 0) {
size_t offset = 0; size_t offset = 0;
if ((original_row_count % server_num) == 0) { std::map<int, int> rank_dims = Util::AllRankLocalShard(original_row_count, rank_id, server_num);
offset = original_row_count / server_num * rank_id; for (size_t i = 0; i < rank_id; i++) {
} else { if (rank_dims.count(i) == 0) {
offset = std::round((static_cast<float>(original_row_count)) / server_num) * rank_id; MS_LOG(EXCEPTION) << "No local shard number for rank " << i;
}
offset += rank_dims[i];
} }
for (size_t i = 0; i < indices_size; i++) { for (size_t i = 0; i < indices_size; i++) {
indices_data[i] -= offset; indices_data[i] -= offset;
......
...@@ -134,13 +134,33 @@ std::string Util::optimizer_node_name(int id) { ...@@ -134,13 +134,33 @@ std::string Util::optimizer_node_name(int id) {
bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; } bool Util::is_optimizer(std::string name) { return optimizer_to_ids.count(name) > 0; }
int Util::LocalShard(int first_dim, int rank_id, int server_num) { int Util::LocalShard(int first_dim, int rank_id, int server_num) {
int shard_size = std::round((static_cast<float>(first_dim)) / server_num); std::map<int, int> shard_dims = AllRankLocalShard(first_dim, rank_id, server_num);
int remain_size = first_dim % server_num; if (shard_dims.count(rank_id) == 0) {
if (remain_size == 0 || rank_id < server_num - 1) { MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id;
return shard_size; }
} else { return shard_dims[rank_id];
return first_dim - (shard_size * (server_num - 1)); }
std::map<int, int> Util::AllRankLocalShard(int first_dim, int rank_id, int server_num) {
if (rank_id >= server_num) {
MS_LOG(EXCEPTION) << "The rank ID " << rank_id << " should be less than the number of servers " << server_num;
}
std::map<int, int> shard_dims;
for (int i = 0; i < server_num; i++) {
shard_dims[i] = 0;
}
if (server_num != static_cast<int>(shard_dims.size())) {
MS_LOG(EXCEPTION) << "Inconsistent server num " << server_num << " shard dims counter size " << shard_dims.size();
}
int server_index = -1;
for (int i = 0; i < first_dim; i++) {
server_index = (server_index + 1) % server_num;
shard_dims[server_index] = shard_dims[server_index] + 1;
}
if (shard_dims.count(rank_id) == 0) {
MS_LOG(EXCEPTION) << "Invalid rank id " << rank_id << ", total server num " << server_num;
} }
return shard_dims;
} }
void Util::SetRankId(int rank_id) { rank_id_ = rank_id; } void Util::SetRankId(int rank_id) { rank_id_ = rank_id; }
......
...@@ -38,6 +38,7 @@ class Util { ...@@ -38,6 +38,7 @@ class Util {
static std::string optimizer_node_name(int id); static std::string optimizer_node_name(int id);
static bool is_optimizer(std::string name); static bool is_optimizer(std::string name);
static int LocalShard(int first_dim, int rank_id, int server_num); static int LocalShard(int first_dim, int rank_id, int server_num);
static std::map<int, int> AllRankLocalShard(int first_dim, int rank_id, int server_num);
static void SetRankId(int rank_id); static void SetRankId(int rank_id);
static int GetRankId(); static int GetRankId();
static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size, static void ReduceSparseGradient(float *gradients, int *indices, const size_t indices_size, size_t segment_size,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册