未验证 提交 80def0d0 编写于 作者: C cqy123456 提交者: GitHub

Allow more choices for the parameter ‘m’ of CPU IVF_PQ #3254 (#3508)

* ivfpq fixed m problem
Signed-off-by: Ncqy <yaya645@126.com>

* ivfpq fixed m problem
Signed-off-by: Ncqy <yaya645@126.com>
上级 524f1024
...@@ -69,15 +69,8 @@ ExecutionEngineImpl::ExecutionEngineImpl(const std::string& dir_root, const Segm ...@@ -69,15 +69,8 @@ ExecutionEngineImpl::ExecutionEngineImpl(const std::string& dir_root, const Segm
} }
knowhere::VecIndexPtr knowhere::VecIndexPtr
ExecutionEngineImpl::CreateVecIndex(const std::string& index_name) { ExecutionEngineImpl::CreateVecIndex(const std::string& index_name, knowhere::IndexMode mode) {
knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance(); knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance();
knowhere::IndexMode mode = knowhere::IndexMode::MODE_CPU;
#ifdef MILVUS_GPU_VERSION
if (gpu_enable_) {
mode = knowhere::IndexMode::MODE_GPU;
}
#endif
knowhere::VecIndexPtr index = vec_index_factory.CreateVecIndex(index_name, mode); knowhere::VecIndexPtr index = vec_index_factory.CreateVecIndex(index_name, mode);
if (index == nullptr) { if (index == nullptr) {
std::string err_msg = std::string err_msg =
...@@ -216,9 +209,13 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id) { ...@@ -216,9 +209,13 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id) {
for (auto& pair : indice) { for (auto& pair : indice) {
if (pair.second != nullptr) { if (pair.second != nullptr) {
auto gpu_index = knowhere::cloner::CopyCpuToGpu(pair.second, device_id, knowhere::Config()); auto gpu_index = knowhere::cloner::CopyCpuToGpu(pair.second, device_id, knowhere::Config());
if (gpu_index == nullptr) {
new_map.insert(pair);
} else {
new_map.insert(std::make_pair(pair.first, gpu_index)); new_map.insert(std::make_pair(pair.first, gpu_index));
} }
} }
}
indice.swap(new_map); indice.swap(new_map);
gpu_num_ = device_id; gpu_num_ = device_id;
...@@ -771,12 +768,6 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col ...@@ -771,12 +768,6 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col
throw Exception(DB_ERROR, "ExecutionEngineImpl: from_index is not IDMAP"); throw Exception(DB_ERROR, "ExecutionEngineImpl: from_index is not IDMAP");
} }
// build index by knowhere
new_index = CreateVecIndex(index_info.index_type_);
if (!new_index) {
throw Exception(DB_ERROR, "Unsupported index type");
}
auto segment_visitor = segment_reader_->GetSegmentVisitor(); auto segment_visitor = segment_reader_->GetSegmentVisitor();
auto& snapshot = segment_visitor->GetSnapshot(); auto& snapshot = segment_visitor->GetSnapshot();
auto& segment = segment_visitor->GetSegment(); auto& segment = segment_visitor->GetSegment();
...@@ -795,10 +786,26 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col ...@@ -795,10 +786,26 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col
conf[knowhere::meta::DEVICEID] = gpu_num_; conf[knowhere::meta::DEVICEID] = gpu_num_;
conf[knowhere::Metric::TYPE] = index_info.metric_name_; conf[knowhere::Metric::TYPE] = index_info.metric_name_;
LOG_ENGINE_DEBUG_ << "Index params: " << conf.dump(); LOG_ENGINE_DEBUG_ << "Index params: " << conf.dump();
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(new_index->index_type());
if (!adapter->CheckTrain(conf, new_index->index_mode())) { knowhere::IndexMode mode = knowhere::IndexMode::MODE_CPU;
#ifdef MILVUS_GPU_VERSION
if (gpu_enable_) {
mode = knowhere::IndexMode::MODE_GPU;
}
if (index_info.index_type_ == milvus::knowhere::IndexEnum::INDEX_FAISS_IVFPQ) {
auto m = conf[knowhere::IndexParams::m].get<int64_t>();
knowhere::IVFPQConfAdapter::GetValidM(dimension, m, mode);
}
#endif
auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(index_info.index_type_);
if (!adapter->CheckTrain(conf, mode)) {
throw Exception(DB_ERROR, "Illegal index params"); throw Exception(DB_ERROR, "Illegal index params");
} }
// build index by knowhere
new_index = CreateVecIndex(index_info.index_type_, mode);
if (!new_index) {
throw Exception(DB_ERROR, "Unsupported index type");
}
LOG_ENGINE_DEBUG_ << "Index config: " << conf.dump(); LOG_ENGINE_DEBUG_ << "Index config: " << conf.dump();
std::vector<idx_t> uids; std::vector<idx_t> uids;
......
...@@ -46,7 +46,7 @@ class ExecutionEngineImpl : public ExecutionEngine { ...@@ -46,7 +46,7 @@ class ExecutionEngineImpl : public ExecutionEngine {
knowhere::VecIndexPtr& vec_index, bool hybrid = false); knowhere::VecIndexPtr& vec_index, bool hybrid = false);
knowhere::VecIndexPtr knowhere::VecIndexPtr
CreateVecIndex(const std::string& index_name); CreateVecIndex(const std::string& index_name, knowhere::IndexMode mode);
Status Status
CreateStructuredIndex(const engine::DataType field_type, engine::BinaryDataPtr& raw_data, CreateStructuredIndex(const engine::DataType field_type, engine::BinaryDataPtr& raw_data,
......
...@@ -146,24 +146,34 @@ IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { ...@@ -146,24 +146,34 @@ IVFPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
// auto tune params // auto tune params
oricfg[knowhere::IndexParams::nlist] = oricfg[knowhere::IndexParams::nlist] =
MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), oricfg[knowhere::IndexParams::nlist].get<int64_t>()); MatchNlist(oricfg[knowhere::meta::ROWS].get<int64_t>(), oricfg[knowhere::IndexParams::nlist].get<int64_t>());
auto m = oricfg[knowhere::IndexParams::m].get<int64_t>();
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
// Best Practice // Best Practice
// static int64_t MIN_POINTS_PER_CENTROID = 40; // static int64_t MIN_POINTS_PER_CENTROID = 40;
// static int64_t MAX_POINTS_PER_CENTROID = 256; // static int64_t MAX_POINTS_PER_CENTROID = 256;
// CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist); // CheckIntByRange(knowhere::meta::ROWS, MIN_POINTS_PER_CENTROID * nlist, MAX_POINTS_PER_CENTROID * nlist);
std::vector<int64_t> resset; /*std::vector<int64_t> resset;
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>(); IVFPQConfAdapter::GetValidCPUM(dimension, resset);*/
IVFPQConfAdapter::GetValidMList(dimension, resset); IndexMode ivfpq_mode = mode;
return GetValidM(dimension, m, ivfpq_mode);
CheckIntByValues(knowhere::IndexParams::m, resset); }
bool
IVFPQConfAdapter::GetValidM(int64_t dimension, int64_t m, IndexMode& mode) {
#ifdef MILVUS_GPU_VERSION
if (mode == knowhere::IndexMode::MODE_GPU && !IVFPQConfAdapter::GetValidGPUM(dimension, m)) {
mode = knowhere::IndexMode::MODE_CPU;
}
#endif
if (mode == knowhere::IndexMode::MODE_CPU && !IVFPQConfAdapter::GetValidCPUM(dimension, m)) {
return false;
}
return true; return true;
} }
void bool
IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset) { IVFPQConfAdapter::GetValidGPUM(int64_t dimension, int64_t m) {
resset.clear();
/* /*
* Faiss 1.6 * Faiss 1.6
* Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with * Only 1, 2, 3, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32 dims per sub-quantizer are currently supported with
...@@ -172,6 +182,13 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset) ...@@ -172,6 +182,13 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset)
static const std::vector<int64_t> support_dim_per_subquantizer{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1}; static const std::vector<int64_t> support_dim_per_subquantizer{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1};
static const std::vector<int64_t> support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1}; static const std::vector<int64_t> support_subquantizer{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1};
int64_t sub_dim = dimension / m;
return (std::find(std::begin(support_subquantizer), std::end(support_subquantizer), m) !=
support_subquantizer.end()) &&
(std::find(std::begin(support_dim_per_subquantizer), std::end(support_dim_per_subquantizer), sub_dim) !=
support_dim_per_subquantizer.end());
/*resset.clear();
for (const auto& dimperquantizer : support_dim_per_subquantizer) { for (const auto& dimperquantizer : support_dim_per_subquantizer) {
if (!(dimension % dimperquantizer)) { if (!(dimension % dimperquantizer)) {
auto subquantzier_num = dimension / dimperquantizer; auto subquantzier_num = dimension / dimperquantizer;
...@@ -180,7 +197,12 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset) ...@@ -180,7 +197,12 @@ IVFPQConfAdapter::GetValidMList(int64_t dimension, std::vector<int64_t>& resset)
resset.push_back(subquantzier_num); resset.push_back(subquantzier_num);
} }
} }
} }*/
}
bool
IVFPQConfAdapter::GetValidCPUM(int64_t dimension, int64_t m) {
return (dimension % m == 0);
} }
bool bool
...@@ -277,11 +299,10 @@ RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) { ...@@ -277,11 +299,10 @@ RHNSWPQConfAdapter::CheckTrain(Config& oricfg, const IndexMode mode) {
CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION); CheckIntByRange(knowhere::IndexParams::efConstruction, MIN_EFCONSTRUCTION, MAX_EFCONSTRUCTION);
CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M); CheckIntByRange(knowhere::IndexParams::M, MIN_M, MAX_M);
std::vector<int64_t> resset;
auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>(); auto dimension = oricfg[knowhere::meta::DIM].get<int64_t>();
IVFPQConfAdapter::GetValidMList(dimension, resset);
CheckIntByValues(knowhere::IndexParams::PQM, resset); IVFPQConfAdapter::GetValidCPUM(dimension, oricfg[knowhere::IndexParams::PQM].get<int64_t>());
return ConfAdapter::CheckTrain(oricfg, mode); return ConfAdapter::CheckTrain(oricfg, mode);
} }
......
...@@ -51,8 +51,14 @@ class IVFPQConfAdapter : public IVFConfAdapter { ...@@ -51,8 +51,14 @@ class IVFPQConfAdapter : public IVFConfAdapter {
bool bool
CheckTrain(Config& oricfg, const IndexMode mode) override; CheckTrain(Config& oricfg, const IndexMode mode) override;
static void static bool
GetValidMList(int64_t dimension, std::vector<int64_t>& resset); GetValidM(int64_t dimension, int64_t m, IndexMode& mode);
static bool
GetValidGPUM(int64_t dimension, int64_t m);
static bool
GetValidCPUM(int64_t dimension, int64_t m);
}; };
class NSGConfAdapter : public IVFConfAdapter { class NSGConfAdapter : public IVFConfAdapter {
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "knowhere/index/vector_index/adapter/VectorAdapter.h" #include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h" #include "knowhere/index/vector_index/helpers/IndexParameter.h"
#ifdef MILVUS_GPU_VERSION #ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/ConfAdapter.h"
#include "knowhere/index/vector_index/gpu/IndexGPUIVF.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVF.h"
#include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h" #include "knowhere/index/vector_index/gpu/IndexGPUIVFPQ.h"
#endif #endif
...@@ -47,6 +48,12 @@ IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) { ...@@ -47,6 +48,12 @@ IVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
VecIndexPtr VecIndexPtr
IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) { IVFPQ::CopyCpuToGpu(const int64_t device_id, const Config& config) {
#ifdef MILVUS_GPU_VERSION #ifdef MILVUS_GPU_VERSION
auto ivfpq_index = dynamic_cast<faiss::IndexIVFPQ*>(index_.get());
int64_t dim = ivfpq_index->d;
int64_t m = ivfpq_index->pq.M;
if (!IVFPQConfAdapter::GetValidGPUM(dim, m)) {
return nullptr;
}
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) { if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(device_id)) {
ResScope rs(res, device_id, false); ResScope rs(res, device_id, false);
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get()); auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), device_id, index_.get());
......
...@@ -65,8 +65,9 @@ CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& co ...@@ -65,8 +65,9 @@ CopyCpuToGpu(const VecIndexPtr& index, const int64_t device_id, const Config& co
} else { } else {
KNOWHERE_THROW_MSG("this index type not support transfer to gpu"); KNOWHERE_THROW_MSG("this index type not support transfer to gpu");
} }
if (result != nullptr) {
CopyIndexData(result, index); CopyIndexData(result, index);
}
return result; return result;
} }
......
...@@ -67,6 +67,7 @@ set(faiss_srcs ...@@ -67,6 +67,7 @@ set(faiss_srcs
) )
if (MILVUS_GPU_VERSION) if (MILVUS_GPU_VERSION)
set(faiss_srcs ${faiss_srcs} set(faiss_srcs ${faiss_srcs}
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/ConfAdapter.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/helpers/Cloner.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIDMAP.cpp
${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp ${INDEX_SOURCE_DIR}/knowhere/knowhere/index/vector_index/gpu/IndexGPUIVF.cpp
......
...@@ -248,7 +248,13 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s ...@@ -248,7 +248,13 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
} }
// special check for 'm' parameter // special check for 'm' parameter
std::vector<int64_t> resset; int64_t m_value = index_params[knowhere::IndexParams::m];
if (!milvus::knowhere::IVFPQConfAdapter::GetValidCPUM(dimension, m_value)) {
std::string msg = "Invalid m, dimension can't not be divided by m ";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
}
/*std::vector<int64_t> resset;
milvus::knowhere::IVFPQConfAdapter::GetValidMList(dimension, resset); milvus::knowhere::IVFPQConfAdapter::GetValidMList(dimension, resset);
int64_t m_value = index_params[knowhere::IndexParams::m]; int64_t m_value = index_params[knowhere::IndexParams::m];
if (resset.empty()) { if (resset.empty()) {
...@@ -270,7 +276,7 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s ...@@ -270,7 +276,7 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
LOG_SERVER_ERROR_ << msg; LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg); return Status(SERVER_INVALID_ARGUMENT, msg);
} }*/
} else if (index_type == knowhere::IndexEnum::INDEX_NSG) { } else if (index_type == knowhere::IndexEnum::INDEX_NSG) {
auto status = CheckParameterRange(index_params, knowhere::IndexParams::search_length, 10, 300); auto status = CheckParameterRange(index_params, knowhere::IndexParams::search_length, 10, 300);
if (!status.ok()) { if (!status.ok()) {
...@@ -307,9 +313,13 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s ...@@ -307,9 +313,13 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
} }
// special check for 'PQM' parameter // special check for 'PQM' parameter
std::vector<int64_t> resset;
milvus::knowhere::IVFPQConfAdapter::GetValidMList(dimension, resset);
int64_t pqm_value = index_params[knowhere::IndexParams::PQM]; int64_t pqm_value = index_params[knowhere::IndexParams::PQM];
if (!milvus::knowhere::IVFPQConfAdapter::GetValidCPUM(dimension, pqm_value)) {
std::string msg = "Invalid m, dimension can't not be divided by m ";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg);
}
/*int64_t pqm_value = index_params[knowhere::IndexParams::PQM];
if (resset.empty()) { if (resset.empty()) {
std::string msg = "Invalid collection dimension, unable to get reasonable values for 'PQM'"; std::string msg = "Invalid collection dimension, unable to get reasonable values for 'PQM'";
LOG_SERVER_ERROR_ << msg; LOG_SERVER_ERROR_ << msg;
...@@ -329,7 +339,7 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s ...@@ -329,7 +339,7 @@ ValidateIndexParams(const milvus::json& index_params, int64_t dimension, const s
LOG_SERVER_ERROR_ << msg; LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_ARGUMENT, msg); return Status(SERVER_INVALID_ARGUMENT, msg);
} }*/
} }
} else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) { } else if (index_type == knowhere::IndexEnum::INDEX_ANNOY) {
auto status = CheckParameterRange(index_params, knowhere::IndexParams::n_trees, 1, 1024); auto status = CheckParameterRange(index_params, knowhere::IndexParams::n_trees, 1, 1024);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册