未验证 提交 750c6f42 编写于 作者: Y yaoxuefeng 提交者: GitHub

multi-loss optimization by adding a DownpourOpt worker (#22025) (#22638)

* update

* update test=develop

* update compile set test=develop

* update compile set test=develop

* update test=develop

* update test=develop

* update test=develop

* update compile setting test=develop

* update compile setting test=develop

* update run demo test=develop

* update test=develop

* update test=develop

* fix test=develop

* update test=develop

* update test=develop

* update test=develop

* update test=develop

* update test=develop

* update test=develop

* update test=develop

* update test=develop

* update test=develop

* update format test=develop

* update format test=develop

* update style test=develop

* update style test=develop

* change style test=develop

* change style test=develop

* change style test=develop

* add dataset unittest test=develop

* update test=develop

* update for record test=develop

* udpate style for record test=develop

* update for record test=develop

* update for record test=develop

* update for record test=develop

* fix format test=develop

* update test=develop

* update test=develop

* update test=develop

* update test=develop

* update test=develop
上级 c35413bf
......@@ -156,6 +156,11 @@ copy(inference_lib_dist
SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES}
DSTS ${dst_dir} ${dst_dir}/lib)
set(dst_dir "${FLUID_INFERENCE_INSTALL_DIR}/third_party/threadpool")
copy(inference_lib_dist
SRCS ${THREADPOOL_INCLUDE_DIR}/ThreadPool.h
DSTS ${dst_dir})
copy(inference_lib_dist
SRCS ${CMAKE_CURRENT_BINARY_DIR}/CMakeCache.txt
DSTS ${FLUID_INFERENCE_INSTALL_DIR})
......
......@@ -189,7 +189,7 @@ cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc o
if(WITH_DISTRIBUTE)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
......@@ -199,7 +199,7 @@ set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_CO
else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
......
......@@ -123,6 +123,12 @@ void DatasetImpl<T>::SetMergeByInsId(int merge_size) {
merge_size_ = merge_size;
}
template <typename T>
void DatasetImpl<T>::SetGenerateUniqueFeasign(bool gen_uni_feasigns) {
gen_uni_feasigns_ = gen_uni_feasigns;
VLOG(3) << "Set generate unique feasigns: " << gen_uni_feasigns;
}
template <typename T>
void DatasetImpl<T>::SetFeaEval(bool fea_eval, int record_candidate_size) {
slots_shuffle_fea_eval_ = fea_eval;
......@@ -640,6 +646,85 @@ int DatasetImpl<T>::ReceiveFromClient(int msg_type, int client_id,
// explicit instantiation
template class DatasetImpl<Record>;
void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
int shard_num) {
VLOG(3) << "MultiSlotDataset::GenerateUniqueFeasign begin";
if (!gen_uni_feasigns_) {
VLOG(3) << "generate_unique_feasign_=false, will not GenerateUniqueFeasign";
return;
}
CHECK(multi_output_channel_.size() != 0); // NOLINT
auto fleet_ptr_ = FleetWrapper::GetInstance();
std::vector<std::unordered_map<uint64_t, std::vector<float>>>&
local_map_tables = fleet_ptr_->GetLocalTable();
local_map_tables.resize(shard_num);
// read thread
int channel_num = multi_output_channel_.size();
if (read_thread_num < channel_num) {
read_thread_num = channel_num;
}
std::vector<std::thread> threads(read_thread_num);
consume_task_pool_.resize(consume_thread_num);
for (size_t i = 0; i < consume_task_pool_.size(); i++) {
consume_task_pool_[i].reset(new ::ThreadPool(1));
}
auto consume_func = [&local_map_tables](int shard_id, int feadim,
std::vector<uint64_t>& keys) {
for (auto k : keys) {
if (local_map_tables[shard_id].find(k) ==
local_map_tables[shard_id].end()) {
local_map_tables[shard_id][k] = std::vector<float>(feadim, 0);
}
}
};
auto gen_func = [this, &shard_num, &feadim, &local_map_tables,
&consume_func](int i) {
std::vector<Record> vec_data;
std::vector<std::vector<uint64_t>> task_keys(shard_num);
std::vector<std::future<void>> task_futures;
this->multi_output_channel_[i]->Close();
this->multi_output_channel_[i]->ReadAll(vec_data);
for (size_t j = 0; j < vec_data.size(); j++) {
for (auto& feature : vec_data[j].uint64_feasigns_) {
int shard = feature.sign().uint64_feasign_ % shard_num;
task_keys[shard].push_back(feature.sign().uint64_feasign_);
}
}
for (int shard_id = 0; shard_id < shard_num; shard_id++) {
task_futures.emplace_back(consume_task_pool_[shard_id]->enqueue(
consume_func, shard_id, feadim, task_keys[shard_id]));
}
multi_output_channel_[i]->Open();
multi_output_channel_[i]->Write(std::move(vec_data));
vec_data.clear();
vec_data.shrink_to_fit();
for (auto& tk : task_keys) {
tk.clear();
std::vector<uint64_t>().swap(tk);
}
task_keys.clear();
std::vector<std::vector<uint64_t>>().swap(task_keys);
for (auto& tf : task_futures) {
tf.wait();
}
};
for (size_t i = 0; i < threads.size(); i++) {
threads[i] = std::thread(gen_func, i);
}
for (std::thread& t : threads) {
t.join();
}
for (size_t i = 0; i < consume_task_pool_.size(); i++) {
consume_task_pool_[i].reset();
}
consume_task_pool_.clear();
fleet_ptr_->PullSparseToLocal(table_id, feadim);
}
void MultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId begin";
if (!merge_by_insid_) {
......
......@@ -14,12 +14,14 @@
#pragma once
#include <ThreadPool.h>
#include <fstream>
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <string>
#include <thread> // NOLINT
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -63,6 +65,7 @@ class Dataset {
virtual void SetParseContent(bool parse_content) = 0;
// set merge by ins id
virtual void SetMergeByInsId(int merge_size) = 0;
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
// set fea eval mode
virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
// get file list
......@@ -112,6 +115,11 @@ class Dataset {
virtual int64_t GetShuffleDataSize() = 0;
// merge by ins id
virtual void MergeByInsId() = 0;
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
int shard_num) = 0;
virtual void ClearLocalTables() = 0;
// create preload readers
virtual void CreatePreLoadReaders() = 0;
// destroy preload readers after prelaod done
......@@ -148,7 +156,7 @@ class DatasetImpl : public Dataset {
virtual void SetParseInsId(bool parse_ins_id);
virtual void SetParseContent(bool parse_content);
virtual void SetMergeByInsId(int merge_size);
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
virtual int GetThreadNum() { return thread_num_; }
......@@ -179,6 +187,11 @@ class DatasetImpl : public Dataset {
virtual int64_t GetMemoryDataSize();
virtual int64_t GetShuffleDataSize();
virtual void MergeByInsId() {}
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num,
int shard_num) {}
virtual void ClearLocalTables() {}
virtual void CreatePreLoadReaders();
virtual void DestroyPreLoadReaders();
virtual void SetPreLoadThreadNum(int thread_num);
......@@ -195,6 +208,7 @@ class DatasetImpl : public Dataset {
int channel_num_;
std::vector<paddle::framework::Channel<T>> multi_output_channel_;
std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
std::vector<std::unordered_set<uint64_t>> local_tables_;
// when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in output_channel, else consume_channel
......@@ -202,6 +216,7 @@ class DatasetImpl : public Dataset {
std::vector<T> slots_shuffle_original_data_;
RecordCandidateList slots_shuffle_rclist_;
int thread_num_;
int pull_sparse_to_local_thread_num_;
paddle::framework::DataFeedDesc data_feed_desc_;
int trainer_num_;
std::vector<std::string> filelist_;
......@@ -217,9 +232,11 @@ class DatasetImpl : public Dataset {
bool parse_content_;
size_t merge_size_;
bool slots_shuffle_fea_eval_ = false;
bool gen_uni_feasigns_ = false;
int preload_thread_num_;
std::mutex global_index_mutex_;
int64_t global_index_ = 0;
std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
};
// use std::vector<MultiSlotType> or Record as data type
......@@ -227,6 +244,16 @@ class MultiSlotDataset : public DatasetImpl<Record> {
public:
MultiSlotDataset() {}
virtual void MergeByInsId();
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
int read_thread_num,
int consume_thread_num, int shard_num);
virtual void ClearLocalTables() {
for (auto& t : local_tables_) {
t.clear();
std::unordered_set<uint64_t>().swap(t);
}
std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
}
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
std::vector<Record>* result);
......
......@@ -207,54 +207,80 @@ class DownpourWorker : public HogwildWorker {
void CopySparseTable();
void CopyDenseTable();
void CopyDenseVars();
private:
bool need_dump_param_;
std::vector<std::string> dump_param_;
bool need_to_push_dense_;
bool need_dump_field_;
bool dump_slot_;
bool need_to_push_sparse_;
std::vector<std::string> dump_fields_;
ChannelWriter<std::string> writer_;
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end);
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
DownpourWorkerParameter param_;
float scale_datanorm_;
// just save the value in param_ for easy access
std::map<uint64_t, std::string> label_var_name_;
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
// copy table
CopyTableConfig copy_table_config_;
std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
// actually pushed feasign of each table
std::map<uint64_t, std::vector<uint64_t>> sparse_push_keys_;
std::map<uint64_t, std::vector<std::string>> sparse_key_names_;
// feasign
std::map<uint64_t, std::vector<uint64_t>> features_;
// feasign stats
std::map<uint64_t, std::vector<float>> feature_labels_;
// feasign embedding
std::map<uint64_t, std::vector<std::vector<float>>> feature_values_;
std::map<uint64_t, std::vector<std::string>> sparse_value_names_;
// adjust ins weight
AdjustInsWeightConfig adjust_ins_weight_config_;
// check nan and inf during training
std::vector<std::string> check_nan_var_names_;
bool need_to_push_sparse_;
// feasign stats
std::map<uint64_t, std::vector<float>> feature_labels_;
std::map<uint64_t, std::vector<std::string>> sparse_grad_names_;
// feasign embedding gradient
std::map<uint64_t, std::vector<std::vector<float>>> feature_grads_;
std::vector<::std::future<int32_t>> push_sparse_status_;
bool dump_slot_;
bool need_to_push_dense_;
bool need_dump_field_;
bool need_dump_param_;
std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
float scale_datanorm_;
std::vector<::std::future<int32_t>> push_dense_status_;
std::vector<std::string> dump_fields_;
ChannelWriter<std::string> writer_;
// skipped ops
std::vector<std::string> skip_ops_;
std::vector<std::string> dump_param_;
// just save the value in param_ for easy access
std::map<uint64_t, std::string> label_var_name_;
std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::map<uint64_t, uint64_t> table_dependency_;
std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
private:
// std::vector<std::string> dump_param_;
// just save the value in param_ for easy access
// std::map<uint64_t, std::string> label_var_name_;
// std::map<uint64_t, std::vector<std::string>> dense_value_names_;
std::shared_ptr<PullDenseWorker> _pull_dense_worker;
std::vector<::std::future<int32_t>> push_sparse_status_;
std::vector<::std::future<int32_t>> push_dense_status_;
// adjust ins weight
AdjustInsWeightConfig adjust_ins_weight_config_;
std::vector<float> nid_show_;
// check nan and inf during training
std::vector<std::string> check_nan_var_names_;
// copy table
CopyTableConfig copy_table_config_;
std::map<uint64_t, uint64_t> table_dependency_;
std::vector<std::pair<uint64_t, uint64_t>> copy_sparse_tables_;
std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
// std::map<uint64_t, uint64_t> table_dependency_;
// std::vector<std::pair<uint64_t, uint64_t>> copy_dense_tables_;
};
class DownpourWorkerOpt : public DownpourWorker {
public:
DownpourWorkerOpt() {}
virtual ~DownpourWorkerOpt() {}
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
protected:
void CreateThreadOperatorsWithRerank(const ProgramDesc& program);
std::vector<std::vector<OperatorBase*>> loss_ops_;
std::vector<std::vector<std::string>> loss_op_names_;
std::vector<std::string> loss_names_;
std::string async_wait_name_;
int async_index_ = -1;
uint64_t async_tid_ = 0;
};
#if defined(PADDLE_WITH_NCCL)
......
......@@ -61,6 +61,7 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
REGISTER_DEVICE_WORKER_CLASS(HogwildWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorker);
REGISTER_DEVICE_WORKER_CLASS(DownpourWorkerOpt);
#if defined(PADDLE_WITH_NCCL)
REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
#endif
......
......@@ -157,7 +157,8 @@ std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start,
return os.str();
}
std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
std::string DownpourWorker::PrintLodTensor(LoDTensor* tensor, int64_t start,
int64_t end) {
std::string out_val;
if (tensor->type() == proto::VarType::FP32) {
out_val = PrintLodTensorType<float>(tensor, start, end);
......@@ -171,7 +172,8 @@ std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
return out_val;
}
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
std::pair<int64_t, int64_t> DownpourWorker::GetTensorBound(LoDTensor* tensor,
int index) {
auto& dims = tensor->dims();
if (tensor->lod().size() != 0) {
auto& lod = tensor->lod()[0];
......@@ -181,7 +183,7 @@ std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
}
}
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
bool DownpourWorker::CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
auto& dims = tensor->dims();
if (dims.size() != 2) return false;
if (tensor->lod().size() != 0) {
......
此差异已折叠。
......@@ -29,9 +29,12 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <algorithm>
#include <utility>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
......@@ -151,6 +154,151 @@ void FleetWrapper::CreateClient2ClientConnection() {
#endif
}
void FleetWrapper::PullSparseToLocal(const uint64_t table_id,
int fea_value_dim) {
#ifdef PADDLE_WITH_PSLIB
size_t fea_keys_size = local_tables_.size();
if (fea_keys_size == 0) {
return;
}
local_table_shard_num_ = fea_keys_size;
platform::Timer timeline;
std::vector<std::thread> threads(fea_keys_size);
auto ptl_func = [this, &table_id](int i) {
size_t key_size = this->local_tables_[i].size();
std::vector<uint64_t> keys;
keys.reserve(key_size);
std::vector<float*> pull_result_ptr;
pull_result_ptr.reserve(key_size);
for (auto& kv : this->local_tables_[i]) {
keys.emplace_back(kv.first);
pull_result_ptr.emplace_back(kv.second.data());
}
auto tt = pslib_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, keys.data(), key_size);
tt.wait();
auto status = tt.get();
if (status != 0) {
LOG(ERROR) << "fleet pull sparse failed, status[" << status << "]";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
} else {
VLOG(3) << "FleetWrapper Pull sparse to local done with table size: "
<< pull_result_ptr.size();
}
};
timeline.Start();
for (size_t i = 0; i < threads.size(); i++) {
threads[i] = std::thread(ptl_func, i);
}
for (std::thread& t : threads) {
t.join();
}
local_pull_pool_.reset(new ::ThreadPool(pull_local_thread_num_));
timeline.Pause();
#endif
}
void FleetWrapper::PullSparseVarsFromLocal(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
#ifdef PADDLE_WITH_PSLIB
fea_keys->clear();
fea_keys->resize(0);
fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) {
continue;
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
size_t key_length = fea_keys->size();
int local_step = key_length / pull_local_thread_num_;
std::vector<std::future<void>> task_futures;
task_futures.reserve(key_length / local_step + 1);
for (size_t i = 0; i < key_length; i += local_step) {
size_t end = i + local_step < key_length ? i + local_step : key_length;
auto pull_local_task = [this, i, end, &fea_values, &fea_keys,
&fea_value_dim] {
for (size_t j = i; j < end; j++) {
std::memcpy((*fea_values)[j].data(),
local_tables_[(*fea_keys)[j] % local_table_shard_num_]
[(*fea_keys)[j]]
.data(),
fea_value_dim * sizeof(float));
}
};
task_futures.emplace_back(
local_pull_pool_->enqueue(std::move(pull_local_task)));
}
for (auto& tf : task_futures) {
tf.wait();
}
#endif
}
void FleetWrapper::ClearLocalTable() {
#ifdef PADDLE_WITH_PSLIB
for (auto& t : local_tables_) {
t.clear();
}
#endif
}
std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
#ifdef PADDLE_WITH_PSLIB
fea_keys->clear();
fea_keys->resize(0);
fea_keys->reserve(MAX_FEASIGN_NUM);
for (auto name : var_names) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
continue;
}
LoDTensor* tensor = var->GetMutable<LoDTensor>();
CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) {
continue;
}
fea_keys->push_back(static_cast<uint64_t>(ids[i]));
}
}
fea_values->resize(fea_keys->size() + 1);
for (auto& t : *fea_values) {
t.resize(fea_value_dim);
}
std::vector<float*> pull_result_ptr;
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
return pslib_ptr_->_worker_ptr->pull_sparse(
pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
#endif
return std::future<int32_t>();
}
void FleetWrapper::PullSparseVarsSync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
......
......@@ -19,12 +19,15 @@ limitations under the License. */
#include <archive.h>
#include <pslib.h>
#endif
#include <ThreadPool.h>
#include <atomic>
#include <ctime>
#include <map>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
......@@ -65,12 +68,16 @@ class FleetWrapper {
client2client_connect_timeout_ms_ = 10000;
// pslib request max retry
client2client_max_retry_ = 3;
pull_local_thread_num_ = 25;
}
// set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry);
void SetPullLocalThreadNum(int thread_num) {
pull_local_thread_num_ = thread_num;
}
// Pull sparse variables from server in sync mode
// Param<in>: scope, table_id, var_names, fea_keys, fea_dim
// Param<out>: fea_values
......@@ -80,7 +87,11 @@ class FleetWrapper {
std::vector<std::vector<float>>* fea_values,
int fea_dim,
const std::vector<std::string>& var_emb_names);
std::future<int32_t> PullSparseVarsAsync(
const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values, int fea_dim);
// pull dense variables from server in sync mod
void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names);
......@@ -111,6 +122,18 @@ class FleetWrapper {
const std::vector<std::string>& var_names);
// Push sparse variables with labels to server in async mode
std::vector<std::unordered_map<uint64_t, std::vector<float>>> local_tables_;
void PullSparseToLocal(const uint64_t table_id, int fea_value_dim);
void PullSparseVarsFromLocal(const Scope& scope, const uint64_t table_id,
const std::vector<std::string>& var_names,
std::vector<uint64_t>* fea_keys,
std::vector<std::vector<float>>* fea_values,
int fea_value_dim);
void ClearLocalTable();
std::vector<std::unordered_map<uint64_t, std::vector<float>>>&
GetLocalTable() {
return local_tables_;
}
// This is specially designed for click/show stats in server
// Param<in>: scope, table_id, fea_keys, fea_labels, sparse_key_names,
// sparse_grad_names, batch_size, use_cvm, dump_slot
......@@ -237,6 +260,10 @@ class FleetWrapper {
int client2client_request_timeout_ms_;
int client2client_connect_timeout_ms_;
int client2client_max_retry_;
std::unique_ptr<::ThreadPool> local_pull_pool_{nullptr};
int pull_local_thread_num_;
std::unique_ptr<::ThreadPool> pull_to_local_pool_{nullptr};
int local_table_shard_num_;
DISABLE_COPY_AND_ASSIGN(FleetWrapper);
};
......
......@@ -48,6 +48,7 @@ message TrainerDesc {
optional AdjustInsWeightConfig adjust_ins_weight_config = 20;
optional bool no_cvm = 21 [ default = false ];
optional bool thread_barrier = 22;
repeated string loss_names = 23;
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
......@@ -164,4 +165,9 @@ message TableParameter {
optional int32 emb_dim = 10;
optional int32 fea_dim = 11;
optional string label_var_name = 12;
// if table will pull sparse to local first
optional bool is_local = 13 [ default = false ];
// if table will pull sparse asynchronously in worker
optional bool is_async = 14 [ default = false ];
optional string async_wait_op_name = 15;
}
......@@ -247,6 +247,12 @@ void BindDataset(py::module *m) {
py::call_guard<py::gil_scoped_release>())
.def("merge_by_lineid", &framework::Dataset::MergeByInsId,
py::call_guard<py::gil_scoped_release>())
.def("set_generate_unique_feasigns",
&framework::Dataset::SetGenerateUniqueFeasign,
py::call_guard<py::gil_scoped_release>())
.def("generate_local_tables_unlock",
&framework::Dataset::GenerateLocalTablesUnlock,
py::call_guard<py::gil_scoped_release>())
.def("slots_shuffle", &framework::Dataset::SlotsShuffle,
py::call_guard<py::gil_scoped_release>())
.def("set_fea_eval", &framework::Dataset::SetFeaEval,
......
......@@ -75,6 +75,8 @@ void BindFleetWrapper(py::module* m) {
.def("load_model_one_table", &framework::FleetWrapper::LoadModelOneTable)
.def("set_client2client_config",
&framework::FleetWrapper::SetClient2ClientConfig)
.def("set_pull_local_thread_num",
&framework::FleetWrapper::SetPullLocalThreadNum)
.def("copy_table", &framework::FleetWrapper::CopyTable)
.def("copy_table_by_feasign",
&framework::FleetWrapper::CopyTableByFeasign);
......
......@@ -20,6 +20,7 @@ include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
include_directories("${PADDLE_LIB}/third_party/boost")
include_directories("${PADDLE_LIB}/third_party/eigen3")
include_directories("${PADDLE_LIB}/third_party/threadpool")
include_directories("${PADDLE_LIB}/third_party/dlpack")
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
......
......@@ -20,6 +20,7 @@ include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
include_directories("${PADDLE_LIB}/third_party/boost")
include_directories("${PADDLE_LIB}/third_party/eigen3")
include_directories("${PADDLE_LIB}/third_party/threadpool")
include_directories("${PADDLE_LIB}/third_party/dlpack")
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
......
......@@ -428,6 +428,16 @@ class InMemoryDataset(DatasetBase):
self.merge_by_lineid = True
self.parse_ins_id = True
def set_generate_unique_feasigns(self, generate_uni_feasigns, shard_num):
self.dataset.set_generate_unique_feasigns(generate_uni_feasigns)
self.gen_uni_feasigns = generate_uni_feasigns
self.local_shard_num = shard_num
def generate_local_tables_unlock(self, table_id, fea_dim, read_thread_num,
consume_thread_num, shard_num):
self.dataset.generate_local_tables_unlock(
table_id, fea_dim, read_thread_num, consume_thread_num, shard_num)
def load_into_memory(self):
"""
Load data into memory
......
......@@ -13,7 +13,9 @@
# limitations under the License.
"""Defination of device workers."""
__all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section']
__all__ = [
'DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section', 'DownpourSGDOPT'
]
class DeviceWorker(object):
......@@ -190,6 +192,112 @@ class DownpourSGD(DeviceWorker):
downpour.push_sparse = False
class DownpourSGDOPT(DeviceWorker):
"""
DownpourSGDOPT is a kind of distributed SGD algorithm.
"""
def __init__(self):
"""
Init.
initialize downpourSGDOPT device worker
"""
super(DownpourSGDOPT, self).__init__()
def _gen_worker_desc(self, trainer_desc):
"""
Generator worker desc, which device worker is DownpourWorker.
Args:
trainer_desc(TrainerDesc): a TrainerDesc object
"""
dense_table_set = set()
program_id = str(id(self._program))
if self._program == None:
print("program of current device worker is not configured")
exit(-1)
opt_info = self._program._fleet_opt
program_configs = opt_info["program_configs"]
downpour = trainer_desc.downpour_param
for pid in program_configs:
if pid == program_id:
pc = downpour.program_config.add()
pc.program_id = program_id
for i in program_configs[program_id]["push_sparse"]:
pc.push_sparse_table_id.extend([i])
for i in program_configs[program_id]["push_dense"]:
pc.push_dense_table_id.extend([i])
dense_table_set.add(i)
for i in program_configs[program_id]["pull_sparse"]:
pc.pull_sparse_table_id.extend([i])
for i in program_configs[program_id]["pull_dense"]:
pc.pull_dense_table_id.extend([i])
dense_table_set.add(i)
break
trainer_desc.device_worker_name = "DownpourWorkerOpt"
pull_thread = trainer_desc.pull_dense_param
pull_thread.device_num = trainer_desc.thread_num
if opt_info.get("program_id_to_worker") is None:
raise ValueError("opt_info must have program_id_to_worker")
prog_id_to_worker = opt_info["program_id_to_worker"]
if prog_id_to_worker.get(program_id) is None:
raise ValueError("%s not found in program_id_to_worker" %
program_id)
worker = opt_info["program_id_to_worker"][program_id]
for i in worker.get_desc().dense_table:
if i.table_id in dense_table_set:
dense_table = pull_thread.dense_table.add()
dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.table_id = \
i.table_id
sparse_len = len(worker.get_desc().sparse_table)
for i in range(sparse_len):
sparse_table = downpour.sparse_table.add()
sparse_table.table_id = worker.get_desc().sparse_table[i].table_id
sparse_table.sparse_key_name.extend(worker.get_desc().sparse_table[
i].slot_key)
sparse_table.sparse_value_name.extend(worker.get_desc()
.sparse_table[i].slot_value)
sparse_table.sparse_grad_name.extend(worker.get_desc().sparse_table[
i].slot_gradient)
if opt_info["use_cvm"] or "no_cvm" in opt_info and opt_info[
"no_cvm"] == True:
sparse_table.emb_dim = \
self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
i].accessor.fea_dim
sparse_table.fea_dim = sparse_table.emb_dim
else:
sparse_table.emb_dim = \
self._fleet_desc.server_param.downpour_server_param.downpour_table_param[
i].accessor.fea_dim - 2
sparse_table.fea_dim = sparse_table.emb_dim + 2
# TODO(guru4elephant): hard code here, need to improve
sparse_table.label_var_name = "click"
if "local_tables" in opt_info and sparse_table.table_id in opt_info[
"local_tables"]:
sparse_table.is_local = True
if "async_tables" in opt_info and sparse_table.table_id in opt_info[
"async_tables"]:
sparse_table.is_async = True
if opt_info["stat_var_names"]:
for i in opt_info["stat_var_names"]:
downpour.stat_var_names.extend([i])
for i in worker.get_desc().dense_table:
if i.table_id in dense_table_set:
dense_table = downpour.dense_table.add()
dense_table.table_id = i.table_id
dense_table.dense_value_name.extend(i.dense_variable_name)
dense_table.dense_grad_name.extend(
i.dense_gradient_variable_name)
downpour.skip_ops.extend(worker.get_desc().skip_op)
if self._infer:
downpour.push_dense = False
downpour.push_sparse = False
class Section(DeviceWorker):
"""SectionWorker."""
......
......@@ -51,6 +51,9 @@ class PSLib(Fleet):
self._client2client_connect_timeout_ms = connect_timeout_ms
self._client2client_max_retry = max_retry
def set_pull_local_thread_num(self, thread_num):
self._fleet_ptr.set_pull_local_thread_num(thread_num)
def init_worker(self):
"""
init_worker(): will be called by user. When a user knows current process is_server(), he/she
......
......@@ -182,41 +182,49 @@ class DistributedAdam(DistributedOptimizerImplBase):
prog_id_to_param_grads = OrderedDict()
# sparse_grads of each program
prog_id_to_sparse_grads = OrderedDict()
# unique program set
program_id_set = set()
sparse_table_to_index = OrderedDict()
sparse_table_index = 0
for loss in losses:
sparse_table = self._find_multi_distributed_lookup_table([loss])
prog_id = str(id(loss.block.program))
prog_id_to_sparse_table[prog_id] = sparse_table
# get sparse_table_to_index
for tn in sparse_table:
if sparse_table_to_index.get(tn) is None:
sparse_table_to_index[tn] = sparse_table_index
sparse_table_index += 1
# get inputs_dict
inputs_dict = self._find_distributed_lookup_table_inputs(
loss.block.program, sparse_table)
prog_id_to_inputs_dict[prog_id] = inputs_dict
# get outputs_dict
outputs_dict = self._find_distributed_lookup_table_outputs(
loss.block.program, sparse_table)
prog_id_to_outputs_dict[prog_id] = outputs_dict
prog_id_to_worker[prog_id] = DownpourWorker(self._window)
if prog_id not in program_id_set:
program_id_set.add(prog_id)
sparse_table = self._find_multi_distributed_lookup_table([loss])
prog_id_to_sparse_table[prog_id] = sparse_table
# get sparse_table_to_index
for tn in sparse_table:
if sparse_table_to_index.get(tn) is None:
sparse_table_to_index[tn] = sparse_table_index
sparse_table_index += 1
# get inputs_dict
inputs_dict = self._find_distributed_lookup_table_inputs(
loss.block.program, sparse_table)
prog_id_to_inputs_dict[prog_id] = inputs_dict
# get outputs_dict
outputs_dict = self._find_distributed_lookup_table_outputs(
loss.block.program, sparse_table)
prog_id_to_outputs_dict[prog_id] = outputs_dict
prog_id_to_worker[prog_id] = DownpourWorker(self._window)
grads_dict = self._find_distributed_lookup_table_grads(
loss.block.program, sparse_table)
prog_id_to_sparse_grads[prog_id] = grads_dict
# param_grads of program
params_grads = sorted(
fluid.backward.append_backward(loss, parameter_list,
no_grad_set),
key=lambda x: x[0].name)
prog_id_to_param_grads[prog_id] = params_grads
if prog_id not in prog_id_to_param_grads:
prog_id_to_param_grads[prog_id] = []
prog_id_to_param_grads[prog_id].append(params_grads)
grads_dict = self._find_distributed_lookup_table_grads(
loss.block.program, sparse_table)
prog_id_to_sparse_grads[prog_id] = grads_dict
#if strategy.get("parallel_compute")
# if user specify a fleet_desc.prototxt file, then load the file
# instead of creating default fleet_desc.prototxt.
......@@ -251,90 +259,109 @@ class DistributedAdam(DistributedOptimizerImplBase):
server.add_sparse_table(sparse_table_index, None)
# each DownpourTrainerParameter add its own sparse tables
program_id_set.clear()
for loss in losses:
prog_id = str(id(loss.block.program))
worker = prog_id_to_worker[prog_id]
inputs_dict = prog_id_to_inputs_dict[prog_id]
outputs_dict = prog_id_to_outputs_dict[prog_id]
for tn in prog_id_to_sparse_table[prog_id]:
sparse_table_index = sparse_table_to_index[tn]
grads_dict = prog_id_to_sparse_grads[prog_id]
worker.add_sparse_table(sparse_table_index, inputs_dict[tn],
outputs_dict[tn], grads_dict[tn])
if prog_id not in program_id_set:
program_id_set.add(prog_id)
worker = prog_id_to_worker[prog_id]
inputs_dict = prog_id_to_inputs_dict[prog_id]
outputs_dict = prog_id_to_outputs_dict[prog_id]
for tn in prog_id_to_sparse_table[prog_id]:
sparse_table_index = sparse_table_to_index[tn]
grads_dict = prog_id_to_sparse_grads[prog_id]
worker.add_sparse_table(sparse_table_index, inputs_dict[tn],
outputs_dict[tn], grads_dict[tn])
dense_start_table_id = len(sparse_table_to_index)
dense_table_index = len(sparse_table_to_index)
program_configs = {}
# ServerParameter add all dense tables
# each DownpourTrainerParameter add its own dense tables
program_id_set.clear()
for loss_index in range(len(losses)):
program_id = str(id(losses[loss_index].block.program))
worker = prog_id_to_worker[program_id]
sparse_table_names = prog_id_to_sparse_table[program_id]
sparse_table_index = \
[sparse_table_to_index[i] for i in sparse_table_names]
program_configs[program_id] = {
"pull_sparse": [t_index for t_index in sparse_table_index],
"push_sparse": [t_index for t_index in sparse_table_index]
}
params_grads = prog_id_to_param_grads[program_id]
params = []
grads = []
data_norm_params = []
data_norm_grads = []
for i in params_grads:
is_data_norm_data = False
for data_norm_name in self.data_norm_name:
if i[0].name.endswith(data_norm_name):
is_data_norm_data = True
data_norm_params.append(i[0])
if not is_data_norm_data:
params.append(i[0])
for i in params_grads:
is_data_norm_data = False
for data_norm_grad in self.data_norm_name:
if i[0].name.endswith(data_norm_grad):
is_data_norm_data = True
data_norm_grads.append(i[1])
if not is_data_norm_data:
grads.append(i[1])
if strategy.get('dense_table') is not None:
server.add_dense_table(dense_table_index, params, grads,
strategy['dense_table'],
sparse_table_names)
else:
server.add_dense_table(dense_table_index, params, grads, None,
sparse_table_names)
worker.add_dense_table(dense_table_index, self._learning_rate,
params, grads, dense_start_table_id,
sparse_table_names)
program_configs[program_id]["pull_dense"] = [dense_table_index]
program_configs[program_id]["push_dense"] = [dense_table_index]
if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
dense_table_index += 1
if strategy.get('datanorm_table') is not None:
server.add_data_norm_table(
dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads,
strategy['datanorm_table'], sparse_table_names)
else:
server.add_data_norm_table(
dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads, None,
sparse_table_names)
worker.add_dense_table(dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads,
dense_start_table_id, sparse_table_names)
program_configs[program_id]["pull_dense"].extend(
[dense_table_index])
program_configs[program_id]["push_dense"].extend(
[dense_table_index])
dense_table_index += 1
if program_id not in program_id_set:
program_id_set.add(program_id)
worker = prog_id_to_worker[program_id]
sparse_table_names = prog_id_to_sparse_table[program_id]
sparse_table_index = \
[sparse_table_to_index[i] for i in sparse_table_names]
program_configs[program_id] = {
"pull_sparse": [t_index for t_index in sparse_table_index],
"push_sparse": [t_index for t_index in sparse_table_index]
}
params_grads = prog_id_to_param_grads[program_id]
for pg in params_grads:
params = []
grads = []
data_norm_params = []
data_norm_grads = []
for i in pg:
is_data_norm_data = False
for data_norm_name in self.data_norm_name:
if i[0].name.endswith(data_norm_name):
is_data_norm_data = True
data_norm_params.append(i[0])
if not is_data_norm_data:
params.append(i[0])
for i in pg:
is_data_norm_data = False
for data_norm_grad in self.data_norm_name:
if i[0].name.endswith(data_norm_grad):
is_data_norm_data = True
data_norm_grads.append(i[1])
if not is_data_norm_data:
grads.append(i[1])
if strategy.get('dense_table') is not None:
server.add_dense_table(dense_table_index, params, grads,
strategy['dense_table'],
sparse_table_names)
else:
server.add_dense_table(dense_table_index, params, grads,
None, sparse_table_names)
worker.add_dense_table(
dense_table_index, self._learning_rate, params, grads,
dense_start_table_id, sparse_table_names)
if "pull_dense" in program_configs[
program_id] and "push_dense" in program_configs[
program_id] and len(program_configs[program_id][
"pull_dense"]) > 0:
program_configs[program_id]["pull_dense"].extend(
[dense_table_index])
program_configs[program_id]["push_dense"].extend(
[dense_table_index])
else:
program_configs[program_id][
"pull_dense"] = [dense_table_index]
program_configs[program_id][
"push_dense"] = [dense_table_index]
if len(data_norm_params) != 0 and len(data_norm_grads) != 0:
dense_table_index += 1
if strategy.get('datanorm_table') is not None:
server.add_data_norm_table(
dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads,
strategy['datanorm_table'], sparse_table_names)
else:
server.add_data_norm_table(
dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads, None,
sparse_table_names)
worker.add_dense_table(
dense_table_index, self._learning_rate,
data_norm_params, data_norm_grads,
dense_start_table_id, sparse_table_names)
program_configs[program_id]["pull_dense"].extend(
[dense_table_index])
program_configs[program_id]["push_dense"].extend(
[dense_table_index])
dense_table_index += 1
# Todo(guru4elephant): figure out how to support more sparse parameters
# currently only support lookup_table
......@@ -370,13 +397,16 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["program_id_to_worker"] = prog_id_to_worker
opt_info["program_configs"] = program_configs
opt_info["trainer"] = "DistMultiTrainer"
opt_info["device_worker"] = "DownpourSGD"
opt_info["device_worker"] = strategy.get("device_worker", "DownpourSGD")
opt_info["optimizer"] = "DownpourSGD"
opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops
opt_info["use_cvm"] = strategy.get("use_cvm", False)
opt_info["no_cvm"] = strategy.get("no_cvm", False)
opt_info["stat_var_names"] = strategy.get("stat_var_names", [])
opt_info["local_tables"] = strategy.get("local_tables", [])
opt_info["async_tables"] = strategy.get("async_tables", [])
opt_info["async_tables"] = strategy.get("async_tables", [])
opt_info["scale_datanorm"] = strategy.get("scale_datanorm", -1)
opt_info["check_nan_var_names"] = strategy.get("check_nan_var_names",
[])
......@@ -391,6 +421,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_slot"] = True
opt_info["adjust_ins_weight"] = strategy.get("adjust_ins_weight", {})
opt_info["copy_table"] = strategy.get("copy_table", {})
opt_info["loss_names"] = strategy.get("loss_names", [])
for loss in losses:
loss.block.program._fleet_opt = opt_info
......
......@@ -177,7 +177,8 @@ class TestDataset(unittest.TestCase):
dataset.set_fea_eval(10000, True)
dataset.slots_shuffle(["slot1"])
dataset.local_shuffle()
dataset.set_generate_unique_feasigns(True, 15)
dataset.generate_local_tables_unlock(0, 11, 1, 25, 15)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
if self.use_data_loader:
......
......@@ -25,7 +25,7 @@ import unittest
import sys
from op_test import OpTest
from paddle.fluid.trainer_desc import DistMultiTrainer
from paddle.fluid.device_worker import DownpourSGD
from paddle.fluid.device_worker import DownpourSGD, DownpourSGDOPT
from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker
from google.protobuf import text_format
import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib
......@@ -157,6 +157,66 @@ class TestListenAndServOp(unittest.TestCase):
cmd = "rm fleet_desc.prototxt*"
os.system(cmd)
def test_downpour_opt_work(self):
"""test devicve worker."""
if sys.platform == 'win32' or sys.platform == 'sys.platform':
pass
else:
print(sys.platform)
cmd = "wget --no-check-certificate https://pslib.bj.bcebos.com/fleet_desc.prototxt"
os.system(cmd)
x = fluid.layers.data(name='x', shape=[1], dtype='int64')
x_emb = fluid.layers.embedding(
input=x, size=[1, 2], is_distributed=True)
y_predict = fluid.layers.fc(input=x_emb, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
ps_param = pslib.PSParameter()
with open("fleet_desc.prototxt") as f:
text_format.Merge(f.read(), ps_param)
fleet_desc = ps_param
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
opt_info = {}
main_program = fluid.default_main_program()
program_id = str(id(avg_cost.block.program))
program_configs = {}
program_configs[program_id] = {
"pull_sparse": [0],
"push_sparse": [0]
}
program_configs[program_id]["pull_dense"] = [1]
program_configs[program_id]["push_dense"] = [1]
worker_skipped_ops = ["lookup_table", "lookup_table_grad"]
opt_info["program_configs"] = program_configs
opt_info["trainer"] = "DistMultiTrainer"
opt_info["device_worker"] = "DownpourSGDOPT"
opt_info["optimizer"] = "DownpourSGD"
opt_info["fleet_desc"] = ps_param
opt_info["worker_skipped_ops"] = worker_skipped_ops
opt_info["use_cvm"] = False
opt_info["scale_datanorm"] = -1
opt_info["dump_slot"] = False
opt_info["stat_var_names"] = []
worker = DownpourWorker(None)
worker.get_desc().CopyFrom(ps_param.trainer_param[0])
opt_info["program_id_to_worker"] = {program_id: worker}
main_program._fleet_opt = opt_info
trainer = DistMultiTrainer()
trainer._set_program(main_program)
device_worker = DownpourSGDOPT()
device_worker._set_fleet_desc(fleet_desc)
trainer._set_device_worker(device_worker)
trainer._set_fleet_desc(fleet_desc)
trainer._gen_trainer_desc()
cmd = "rm fleet_desc.prototxt*"
os.system(cmd)
if __name__ == "__main__":
unittest.main()
......@@ -115,6 +115,10 @@ class TrainerDesc(object):
for var in check_nan_var_names:
self.proto_desc.check_nan_var_names.append(var)
def _set_loss_names(self, loss_names):
for loss in loss_names:
self.proto_desc.loss_names.append(loss)
def _set_adjust_ins_weight(self, config_dict):
self.proto_desc.adjust_ins_weight_config.need_adjust = \
config_dict.get("need_adjust", False)
......
......@@ -23,7 +23,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
local_logger = logging.getLogger(__name__)
from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
from .device_worker import Hogwild, DownpourSGD, Section
from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT
from .framework import Variable
from multiprocessing import Process, Manager
......@@ -86,6 +86,8 @@ class TrainerFactory(object):
"check_nan_var_names"])
if opt_info.get("dump_param") is not None:
trainer._set_dump_param(opt_info["dump_param"])
if opt_info.get("loss_names") is not None:
trainer._set_loss_names(opt_info["loss_names"])
trainer._set_device_worker(device_worker)
return trainer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册