未验证 提交 5e59056a 编写于 作者: Y yukun 提交者: GitHub

Fix DBImpl and scheduler for search (#3062)

* Add search implementation in ExecutionEngine
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Remove attr_type from query_ptr
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Add CreateStructuredIndex in ExecutionEngine
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Add index fields in QueryPtr
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix SearchReq bugs
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* code format
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Add Search in scheduler
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix SearchJob and SearchTask
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Add MergeTopkResultSet
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Remove nq in VectorQuery
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Change segment_size to segment_row_count in C++ sdk
Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* Fix row_count in Search
Signed-off-by: fishpenguin <kun.yu@zilliz.com>
Co-authored-by: Wang XiangYu <xy.wang@zilliz.com>
上级 7cc4862d
......@@ -573,7 +573,52 @@ DBImpl::Query(const server::ContextPtr& context, const query::QueryPtr& query_pt
TimeRecorder rc("DBImpl::Query");
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(nullptr, options_, query_ptr);
snapshot::ScopedSnapshotT ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, query_ptr->collection_id));
auto ss_id = ss->GetID();
/* collect all valid segment */
std::vector<SegmentVisitor::Ptr> segment_visitors;
auto exec = [&](const snapshot::Segment::Ptr& segment, snapshot::SegmentIterator* handler) -> Status {
auto p_id = segment->GetPartitionId();
auto p_ptr = ss->GetResource<snapshot::Partition>(p_id);
auto& p_name = p_ptr->GetName();
/* check partition match pattern */
bool match = false;
if (query_ptr->partitions.empty()) {
match = true;
} else {
for (auto& pattern : query_ptr->partitions) {
if (StringHelpFunctions::IsRegexMatch(p_name, pattern)) {
match = true;
break;
}
}
}
if (match) {
auto visitor = SegmentVisitor::Build(ss, segment->GetID());
if (!visitor) {
return Status(milvus::SS_ERROR, "Cannot build segment visitor");
}
segment_visitors.push_back(visitor);
}
return Status::OK();
};
auto segment_iter = std::make_shared<snapshot::SegmentIterator>(ss, exec);
segment_iter->Iterate();
STATUS_CHECK(segment_iter->GetStatus());
LOG_ENGINE_DEBUG_ << LogOut("Engine query begin, segment count: %ld", segment_visitors.size());
engine::snapshot::IDS_TYPE segment_ids;
for (auto& sv : segment_visitors) {
segment_ids.emplace_back(sv->GetSegment()->GetID());
}
scheduler::SearchJobPtr job = std::make_shared<scheduler::SearchJob>(nullptr, ss, options_, query_ptr, segment_ids);
/* put search job to scheduler and wait job finish */
scheduler::JobMgrInst::GetInstance()->Put(job);
......@@ -583,61 +628,8 @@ DBImpl::Query(const server::ContextPtr& context, const query::QueryPtr& query_pt
return job->status();
}
// snapshot::ScopedSnapshotT ss;
// STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name));
//
// /* collect all valid segment */
// std::vector<SegmentVisitor::Ptr> segment_visitors;
// auto exec = [&] (const snapshot::Segment::Ptr& segment, snapshot::SegmentIterator* handler) -> Status {
// auto p_id = segment->GetPartitionId();
// auto p_ptr = ss->GetResource<snapshot::Partition>(p_id);
// auto& p_name = p_ptr->GetName();
//
// /* check partition match pattern */
// bool match = false;
// if (partition_patterns.empty()) {
// match = true;
// } else {
// for (auto &pattern : partition_patterns) {
// if (StringHelpFunctions::IsRegexMatch(p_name, pattern)) {
// match = true;
// break;
// }
// }
// }
//
// if (match) {
// auto visitor = SegmentVisitor::Build(ss, segment->GetID());
// if (!visitor) {
// return Status(milvus::SS_ERROR, "Cannot build segment visitor");
// }
// segment_visitors.push_back(visitor);
// }
// return Status::OK();
// };
//
// auto segment_iter = std::make_shared<snapshot::SegmentIterator>(ss, exec);
// segment_iter->Iterate();
// STATUS_CHECK(segment_iter->GetStatus());
//
// LOG_ENGINE_DEBUG_ << LogOut("Engine query begin, segment count: %ld", segment_visitors.size());
//
// VectorsData vectors;
// scheduler::SearchJobPtr job =
// std::make_shared<scheduler::SSSearchJob>(tracer.Context(), general_query, query_ptr, attr_type, vectors);
// for (auto& sv : segment_visitors) {
// job->AddSegmentVisitor(sv);
// }
//
// // step 2: put search job to scheduler and wait result
// scheduler::JobMgrInst::GetInstance()->Put(job);
// job->WaitResult();
//
// if (!job->GetStatus().ok()) {
// return job->GetStatus();
// }
//
// // step 3: construct results
result = job->query_result();
// step 3: construct results
// result.row_num_ = job->vector_count();
// result.result_ids_ = job->GetResultIds();
// result.result_distances_ = job->GetResultDistances();
......
......@@ -191,6 +191,84 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id) {
return Status::OK();
}
/// Copy raw knowhere query output into the caller-provided buffers,
/// translating each segment-local offset into a user-visible doc id
/// via the uid table. Entries marked -1 (no hit) stay -1.
/// NOTE(review): frees the id/distance arrays fetched from the dataset —
/// assumes the dataset hands over malloc'd ownership; confirm with knowhere.
void
MapAndCopyResult(const knowhere::DatasetPtr& dataset, const std::vector<milvus::segment::doc_id_t>& uids, int64_t nq,
                 int64_t k, float* distances, int64_t* labels) {
    auto* raw_ids = dataset->Get<int64_t*>(knowhere::meta::IDS);
    auto* raw_dist = dataset->Get<float*>(knowhere::meta::DISTANCE);

    const int64_t total = nq * k;
    memcpy(distances, raw_dist, sizeof(float) * total);

    /* translate offsets into ids; -1 marks an empty result slot */
    for (int64_t pos = 0; pos < total; ++pos) {
        const int64_t offset = raw_ids[pos];
        labels[pos] = (offset != -1) ? uids[offset] : -1;
    }

    free(raw_ids);
    free(raw_dist);
}
/// Run a vector similarity search against `vec_index` and store ids/distances
/// into `context.query_result_`.
/// @param context       receives a freshly allocated QueryResult (nq * topk slots)
/// @param vector_param  query vectors (float or binary), topk and extra params
/// @param vec_index     index to query; must be non-null
/// @param hybrid        reserved for hybrid load/unset (currently disabled)
/// @return Status — all failures are reported via Status; this function no
///         longer throws (the old `throw Exception` on illegal params was
///         inconsistent with the Status returned for a null index).
Status
ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context,
                               const query::VectorQueryPtr& vector_param, knowhere::VecIndexPtr& vec_index,
                               bool hybrid) {
    TimeRecorder rc(LogOut("[%s][%ld] ExecutionEngineImpl::Search", "search", 0));

    if (vec_index == nullptr) {
        LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ExecutionEngineImpl: index is null, failed to search", "search", 0);
        return Status(DB_ERROR, "index is null");
    }

    // Derive nq from the flattened query data: float data holds one dimension
    // per element, binary data packs 8 dimensions per byte.
    // (const ref: avoid deep-copying the whole query vector record)
    const auto& query_vector = vector_param->query_vector;
    uint64_t nq = 0;
    if (!query_vector.float_data.empty()) {
        nq = query_vector.float_data.size() / vec_index->Dim();
    } else if (!query_vector.binary_data.empty()) {
        nq = query_vector.binary_data.size() * 8 / vec_index->Dim();
    }
    if (nq == 0) {
        LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ExecutionEngineImpl: empty query vector", "search", 0);
        return Status(DB_ERROR, "query vector is empty");
    }
    uint64_t topk = vector_param->topk;

    context.query_result_ = std::make_shared<QueryResult>();
    context.query_result_->result_ids_.resize(topk * nq);
    context.query_result_->result_distances_.resize(topk * nq);

    milvus::json conf = vector_param->extra_params;
    conf[knowhere::meta::TOPK] = topk;
    auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(vec_index->index_type());
    // Null adapter or rejected parameters both surface as a Status error so
    // the caller's `if (!status.ok())` path handles them uniformly.
    if (adapter == nullptr || !adapter->CheckSearch(conf, vec_index->index_type(), vec_index->index_mode())) {
        LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] Illegal search params", "search", 0);
        return Status(DB_ERROR, "Illegal search params");
    }

    if (hybrid) {
        // HybridLoad();
    }

    rc.RecordSection("query prepare");
    knowhere::DatasetPtr dataset;
    if (!query_vector.float_data.empty()) {
        dataset = knowhere::GenDataset(nq, vec_index->Dim(), query_vector.float_data.data());
    } else {
        dataset = knowhere::GenDataset(nq, vec_index->Dim(), query_vector.binary_data.data());
    }
    auto result = vec_index->Query(dataset, conf);

    // Translate per-segment offsets to doc ids and fill the result buffers.
    MapAndCopyResult(result, vec_index->GetUids(), nq, topk, context.query_result_->result_distances_.data(),
                     context.query_result_->result_ids_.data());

    if (hybrid) {
        // HybridUnset();
    }
    return Status::OK();
}
Status
ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
try {
......@@ -212,7 +290,6 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
if (field->GetFtype() == (int)engine::meta::DataType::VECTOR_FLOAT ||
field->GetFtype() == (int)engine::meta::DataType::VECTOR_BINARY) {
segment_ptr->GetVectorIndex(field->GetName(), vec_index);
break;
} else if (type == (int)engine::meta::DataType::UID) {
continue;
} else {
......@@ -236,8 +313,14 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
}
vec_index->SetBlacklist(list);
auto vector_query = context.query_ptr_->vectors.at(vector_placeholder);
auto& vector_param = context.query_ptr_->vectors.at(vector_placeholder);
if (!vector_param->query_vector.float_data.empty()) {
vector_param->nq = vector_param->query_vector.float_data.size() / vec_index->Dim();
} else if (!vector_param->query_vector.binary_data.empty()) {
vector_param->nq = vector_param->query_vector.binary_data.size() * 8 / vec_index->Dim();
}
status = VecSearch(context, context.query_ptr_->vectors.at(vector_placeholder), vec_index);
if (!status.ok()) {
return status;
}
......
......@@ -41,6 +41,10 @@ class ExecutionEngineImpl : public ExecutionEngine {
BuildIndex() override;
private:
Status
VecSearch(ExecutionEngineContext& context, const query::VectorQueryPtr& vector_param,
knowhere::VecIndexPtr& vec_index, bool hybrid = false);
knowhere::VecIndexPtr
CreateVecIndex(const std::string& index_name);
......
......@@ -77,6 +77,7 @@ struct VectorQuery {
std::string field_name;
milvus::json extra_params = {};
int64_t topk;
int64_t nq;
float boost;
VectorRecord query_vector;
};
......
......@@ -16,15 +16,21 @@
namespace milvus {
namespace scheduler {
SearchJob::SearchJob(const server::ContextPtr& context, engine::DBOptions options, const query::QueryPtr& query_ptr)
: Job(JobType::SEARCH), context_(context), options_(options), query_ptr_(query_ptr) {
GetSegmentsFromQuery(query_ptr, segment_ids_);
// Construct a search job pinned to one snapshot. The job fans out into one
// SearchTask per id in `segment_ids` (see OnCreateTasks), all sharing the
// same snapshot so every task sees a consistent view of the collection.
SearchJob::SearchJob(const server::ContextPtr& context, const engine::snapshot::ScopedSnapshotT& snapshot,
engine::DBOptions options, const query::QueryPtr& query_ptr,
const engine::snapshot::IDS_TYPE& segment_ids)
: Job(JobType::SEARCH),
context_(context),
snapshot_(snapshot),
options_(options),
query_ptr_(query_ptr),
segment_ids_(segment_ids) {
}
void
SearchJob::OnCreateTasks(JobTasks& tasks) {
for (auto& id : segment_ids_) {
auto task = std::make_shared<SearchTask>(context_, options_, query_ptr_, id, nullptr);
auto task = std::make_shared<SearchTask>(context_, snapshot_, options_, query_ptr_, id, nullptr);
task->job_ = this;
tasks.emplace_back(task);
}
......@@ -40,10 +46,5 @@ SearchJob::Dump() const {
return ret;
}
void
SearchJob::GetSegmentsFromQuery(const query::QueryPtr& query_ptr, engine::snapshot::IDS_TYPE& segment_ids) {
// TODO: derive segment ids from the query itself. Currently a no-op stub —
// callers pass the segment id list to the SearchJob constructor instead.
}
} // namespace scheduler
} // namespace milvus
......@@ -39,7 +39,9 @@ namespace scheduler {
class SearchJob : public Job {
public:
SearchJob(const server::ContextPtr& context, engine::DBOptions options, const query::QueryPtr& query_ptr);
SearchJob(const server::ContextPtr& context, const engine::snapshot::ScopedSnapshotT& snapshot,
engine::DBOptions options, const query::QueryPtr& query_ptr,
const engine::snapshot::IDS_TYPE& segment_ids);
public:
json
......@@ -74,13 +76,9 @@ class SearchJob : public Job {
void
OnCreateTasks(JobTasks& tasks) override;
private:
void
GetSegmentsFromQuery(const query::QueryPtr& query_ptr, engine::snapshot::IDS_TYPE& segment_ids);
private:
const server::ContextPtr context_;
engine::snapshot::ScopedSnapshotT snapshot_;
engine::DBOptions options_;
query::QueryPtr query_ptr_;
......
......@@ -29,10 +29,12 @@
namespace milvus {
namespace scheduler {
SearchTask::SearchTask(const server::ContextPtr& context, const engine::DBOptions& options,
const query::QueryPtr& query_ptr, engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label)
SearchTask::SearchTask(const server::ContextPtr& context, engine::snapshot::ScopedSnapshotT snapshot,
const engine::DBOptions& options, const query::QueryPtr& query_ptr,
engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label)
: Task(TaskType::SearchTask, std::move(label)),
context_(context),
snapshot_(snapshot),
options_(options),
query_ptr_(query_ptr),
segment_id_(segment_id) {
......@@ -42,9 +44,7 @@ SearchTask::SearchTask(const server::ContextPtr& context, const engine::DBOption
/// Lazily build the execution engine for this task's segment.
/// Uses the snapshot pinned by the owning SearchJob (snapshot_) so all tasks
/// of one job operate on the same collection view; the previous code built an
/// engine from a freshly fetched "latest" snapshot and then immediately
/// overwrote it with the snapshot_-based engine — the first Build was dead
/// (and inconsistent) work, removed here.
void
SearchTask::CreateExecEngine() {
    if (execution_engine_ == nullptr && query_ptr_ != nullptr) {
        execution_engine_ = engine::EngineFactory::Build(snapshot_, options_.meta_.path_, segment_id_);
    }
}
......@@ -106,29 +106,37 @@ SearchTask::OnExecute() {
return Status(DB_ERROR, "execution engine is null");
}
// auto search_job = std::static_pointer_cast<scheduler::SearchJob>(std::shared_ptr<scheduler::Job>(job_));
auto search_job = static_cast<scheduler::SearchJob*>(job_);
try {
/* step 2: search */
engine::ExecutionEngineContext context;
context.query_ptr_ = query_ptr_;
context.query_result_ = std::make_shared<engine::QueryResult>();
auto status = execution_engine_->Search(context);
if (!status.ok()) {
return status;
}
STATUS_CHECK(execution_engine_->Search(context));
rc.RecordSection("search done");
/* step 3: pick up topk result */
// auto spec_k = file_->row_count_ < topk ? file_->row_count_ : topk;
// if (spec_k == 0) {
// LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty file. file location = %s",
// "search", 0,
// file_->location_.c_str());
// } else {
// std::unique_lock<std::mutex> lock(search_job->mutex());
// XSearchTask::MergeTopkToResultSet(result, spec_k, nq, topk, ascending_, search_job->GetQueryResult());
// }
// TODO(yukun): Remove hardcode here
auto vector_param = context.query_ptr_->vectors.begin()->second;
auto topk = vector_param->topk;
auto segment_ptr = snapshot_->GetSegmentCommitBySegmentId(segment_id_);
auto spec_k = segment_ptr->GetRowCount() < topk ? segment_ptr->GetRowCount() : topk;
int64_t nq = vector_param->nq;
if (spec_k == 0) {
LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty segment. segment id = %d", "search", 0,
segment_ptr->GetID());
} else {
// std::unique_lock<std::mutex> lock(search_job->mutex());
if (!search_job->query_result()) {
search_job->query_result() = std::make_shared<engine::QueryResult>();
search_job->query_result()->row_num_ = nq;
}
SearchTask::MergeTopkToResultSet(context.query_result_->result_ids_,
context.query_result_->result_distances_, spec_k, nq, topk,
ascending_reduce_, search_job->query_result());
}
rc.RecordSection("reduce topk done");
} catch (std::exception& ex) {
......@@ -140,6 +148,72 @@ SearchTask::OnExecute() {
return Status::OK();
}
// Merge one segment's per-query top-k hits (src) into the job-wide
// accumulated result (result, "tar"), keeping at most `topk` best entries
// per query.
//
// Layout assumptions (flattened row-major, one row per query):
//   - src_ids/src_distances have stride `topk` per query, with only the
//     first `src_k` entries of each row valid (VecSearch sizes the buffers
//     to nq * topk).
//   - result buffers have stride `tar_k` per query, derived below.
// `ascending` selects the comparison direction: true for distance metrics
// (smaller is better, e.g. L2), false for similarity metrics (larger is
// better, e.g. IP) — see ascending_reduce_ in the header.
// An id of -1 marks an unfilled slot and always loses the comparison.
void
SearchTask::MergeTopkToResultSet(const engine::ResultIds& src_ids, const engine::ResultDistances& src_distances,
size_t src_k, size_t nq, size_t topk, bool ascending, engine::QueryResultPtr& result) {
if (src_ids.empty()) {
LOG_ENGINE_DEBUG_ << LogOut("[%s][%d] Search result is empty.", "search", 0);
return;
}
// Current per-query width of the accumulated result; 0-size result => tar_k 0.
size_t tar_k = result->result_ids_.size() / nq;
// Output width: never more than topk, never more than all available entries.
size_t buf_k = std::min(topk, src_k + tar_k);
engine::ResultIds buf_ids(nq * buf_k, -1);
engine::ResultDistances buf_distances(nq * buf_k, 0.0);
// Classic two-way merge per query row, writing the best buf_k entries.
for (uint64_t i = 0; i < nq; i++) {
size_t buf_k_j = 0, src_k_j = 0, tar_k_j = 0;
size_t buf_idx, src_idx, tar_idx;
// Row base offsets; note src rows are spaced by `topk`, not src_k.
size_t buf_k_multi_i = buf_k * i;
size_t src_k_multi_i = topk * i;
size_t tar_k_multi_i = tar_k * i;
// Phase 1: both inputs still have candidates — pick the better each step.
while (buf_k_j < buf_k && src_k_j < src_k && tar_k_j < tar_k) {
src_idx = src_k_multi_i + src_k_j;
tar_idx = tar_k_multi_i + tar_k_j;
buf_idx = buf_k_multi_i + buf_k_j;
if ((result->result_ids_[tar_idx] == -1) || // initialized value
(ascending && src_distances[src_idx] < result->result_distances_[tar_idx]) ||
(!ascending && src_distances[src_idx] > result->result_distances_[tar_idx])) {
buf_ids[buf_idx] = src_ids[src_idx];
buf_distances[buf_idx] = src_distances[src_idx];
src_k_j++;
} else {
buf_ids[buf_idx] = result->result_ids_[tar_idx];
buf_distances[buf_idx] = result->result_distances_[tar_idx];
tar_k_j++;
}
buf_k_j++;
}
// Phase 2: one input exhausted — drain the remainder of the other.
if (buf_k_j < buf_k) {
if (src_k_j < src_k) {
while (buf_k_j < buf_k && src_k_j < src_k) {
buf_idx = buf_k_multi_i + buf_k_j;
src_idx = src_k_multi_i + src_k_j;
buf_ids[buf_idx] = src_ids[src_idx];
buf_distances[buf_idx] = src_distances[src_idx];
src_k_j++;
buf_k_j++;
}
} else {
while (buf_k_j < buf_k && tar_k_j < tar_k) {
buf_idx = buf_k_multi_i + buf_k_j;
tar_idx = tar_k_multi_i + tar_k_j;
buf_ids[buf_idx] = result->result_ids_[tar_idx];
buf_distances[buf_idx] = result->result_distances_[tar_idx];
tar_k_j++;
buf_k_j++;
}
}
}
}
// Publish the merged row set (swap avoids copying the buffers).
result->result_ids_.swap(buf_ids);
result->result_distances_.swap(buf_distances);
}
int64_t
SearchTask::nq() {
return 0;
......
......@@ -26,8 +26,9 @@ namespace scheduler {
class SearchTask : public Task {
public:
explicit SearchTask(const server::ContextPtr& context, const engine::DBOptions& options,
const query::QueryPtr& query_ptr, engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label);
explicit SearchTask(const server::ContextPtr& context, engine::snapshot::ScopedSnapshotT snapshot,
const engine::DBOptions& options, const query::QueryPtr& query_ptr,
engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label);
inline json
Dump() const override {
......@@ -44,6 +45,10 @@ class SearchTask : public Task {
Status
OnExecute() override;
static void
MergeTopkToResultSet(const engine::ResultIds& src_ids, const engine::ResultDistances& src_distances, size_t src_k,
size_t nq, size_t topk, bool ascending, engine::QueryResultPtr& result);
int64_t
nq();
......@@ -53,12 +58,17 @@ class SearchTask : public Task {
public:
const std::shared_ptr<server::Context> context_;
engine::snapshot::ScopedSnapshotT snapshot_;
const engine::DBOptions& options_;
query::QueryPtr query_ptr_;
engine::snapshot::ID_TYPE segment_id_;
engine::ExecutionEnginePtr execution_engine_;
// distance -- value 0 means two vectors equal, ascending reduce, L2/HAMMING/JACCARD/TONIMOTO ...
// similarity -- infinity value means two vectors equal, descending reduce, IP
bool ascending_reduce_ = true;
};
} // namespace scheduler
......
......@@ -643,7 +643,9 @@ GrpcRequestHandler::CreateCollection(::grpc::ServerContext* context, const ::mil
// Currently only one extra_param
if (field.extra_params_size() != 0) {
field_schema.field_params_ = json::parse(field.extra_params(0).value());
if (!field.extra_params(0).value().empty()) {
field_schema.field_params_ = json::parse(field.extra_params(0).value());
}
}
for (int j = 0; j < field.index_params_size(); j++) {
......
......@@ -114,7 +114,7 @@ ClientTest::CreateCollection(const std::string& collection_name) {
field_ptr4->extra_params = extra_params_4.dump();
JSON extra_params;
extra_params["segment_size"] = 1024;
extra_params["segment_row_count"] = 1024;
milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr3, field_ptr4}};
milvus::Status stat = conn_->CreateCollection(mapping, extra_params.dump());
......@@ -352,5 +352,5 @@ ClientTest::Test() {
// entities
//
// DropIndex(collection_name, "field_vec", "index_3");
DropCollection(collection_name);
// DropCollection(collection_name);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册