From 802d62eb24a70b5fc209ec59136777c4eb369e52 Mon Sep 17 00:00:00 2001 From: groot Date: Mon, 16 Mar 2020 16:26:12 +0800 Subject: [PATCH] #1651 Check validity of dimension when collection metric type is binary one (#1666) * #1648 The cache cannot be used all when the type is binary Signed-off-by: groot * #1646 The cache cannot be used all when the type is binary Signed-off-by: groot * #1646 The cache cannot be used all when the type is binary Signed-off-by: groot * #1651 Check validity of dimension when collection metric type is binary one Signed-off-by: groot * typo Signed-off-by: yhmo --- CHANGELOG.md | 3 +- core/src/db/engine/ExecutionEngineImpl.cpp | 60 ------------------- core/src/db/engine/ExecutionEngineImpl.h | 6 -- .../delivery/request/CreateTableRequest.cpp | 2 +- core/src/utils/ValidationUtil.cpp | 16 ++++- core/src/utils/ValidationUtil.h | 2 +- core/unittest/server/test_util.cpp | 57 +++++++++++++++--- sdk/examples/binary_vector/src/ClientTest.cpp | 21 ++++--- 8 files changed, 78 insertions(+), 89 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d534d67..61e425d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,8 @@ Please mark all change in change log and use the issue from GitHub - \#1301 Data in WAL may be accidentally inserted into a new table with the same name. - \#1634 Fix search demo bug in HTTP doc - \#1635 Vectors can be returned by searching after vectors deleted if `cache_insert_data` set true -- \#1648 The cache cannot be used all when the type is binary +- \#1648 The cache cannot be used all when the vector type is binary +- \#1651 Check validity of dimension when collection metric type is binary one ## Feature - \#1603 BinaryFlat add 2 Metric: Substructure and Superstructure diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index fc0cdda3..921c9206 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -622,66 +622,6 @@ ExecutionEngineImpl::CopyToCpu() { return Status::OK(); } -// ExecutionEnginePtr -// ExecutionEngineImpl::Clone() { -// if (index_ == nullptr) { -// ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to clone"; -// return nullptr; -// } -// -// auto ret = std::make_shared(dim_, location_, index_type_, metric_type_, nlist_); -// ret->Init(); -// ret->index_ = index_->Clone(); -// return ret; -//} - -/* -Status -ExecutionEngineImpl::Merge(const std::string& location) { - if (location == location_) { - return Status(DB_ERROR, "Cannot Merge Self"); - } - ENGINE_LOG_DEBUG << "Merge index file: " << location << " to: " << location_; - - auto to_merge = cache::CpuCacheMgr::GetInstance()->GetIndex(location); - if (!to_merge) { - try { - double physical_size = server::CommonUtil::GetFileSize(location); - server::CollectExecutionEngineMetrics metrics(physical_size); - to_merge = read_index(location); - } catch (std::exception& e) { - ENGINE_LOG_ERROR << e.what(); - return Status(DB_ERROR, e.what()); - } - } - - if (index_ == nullptr) { - ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to merge"; - return Status(DB_ERROR, "index is null"); - } - - if (auto file_index = std::dynamic_pointer_cast(to_merge)) { - auto status = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds()); - if (!status.ok()) { - ENGINE_LOG_ERROR << "Failed to merge: " << location << " to: " << location_; - } else { - ENGINE_LOG_DEBUG << "Finish merge index file: " << location; - } - return status; - } else if (auto bin_index = std::dynamic_pointer_cast(to_merge)) { - auto status = index_->Add(bin_index->Count(), bin_index->GetRawVectors(), bin_index->GetRawIds()); - if (!status.ok()) { - ENGINE_LOG_ERROR << "Failed to merge: " << location << " to: " << location_; - } else { - ENGINE_LOG_DEBUG << "Finish merge index file: " << location; - } - return status; - } else { - return Status(DB_ERROR, "file index type is not idmap"); - } -} -*/ - ExecutionEnginePtr ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_type) { ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_; diff --git a/core/src/db/engine/ExecutionEngineImpl.h b/core/src/db/engine/ExecutionEngineImpl.h index 2ab781f4..10a3968c 100644 --- a/core/src/db/engine/ExecutionEngineImpl.h +++ b/core/src/db/engine/ExecutionEngineImpl.h @@ -62,12 +62,6 @@ class ExecutionEngineImpl : public ExecutionEngine { Status CopyToCpu() override; - // ExecutionEnginePtr - // Clone() override; - - // Status - // Merge(const std::string& location) override; - Status GetVectorByID(const int64_t& id, float* vector, bool hybrid) override; diff --git a/core/src/server/delivery/request/CreateTableRequest.cpp b/core/src/server/delivery/request/CreateTableRequest.cpp index d92db5a9..9f4a3878 100644 --- a/core/src/server/delivery/request/CreateTableRequest.cpp +++ b/core/src/server/delivery/request/CreateTableRequest.cpp @@ -52,7 +52,7 @@ CreateTableRequest::OnExecute() { return status; } - status = ValidationUtil::ValidateTableDimension(dimension_); + status = ValidationUtil::ValidateTableDimension(dimension_, metric_type_); if (!status.ok()) { return status; } diff --git a/core/src/utils/ValidationUtil.cpp b/core/src/utils/ValidationUtil.cpp index 50a6536f..be71ef0a 100644 --- a/core/src/utils/ValidationUtil.cpp +++ b/core/src/utils/ValidationUtil.cpp @@ -11,6 +11,7 @@ #include "utils/ValidationUtil.h" #include "Log.h" +#include "db/Utils.h" #include "db/engine/ExecutionEngine.h" #include "index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h" #include "utils/StringHelpFunctions.h" @@ -128,16 +129,25 @@ ValidationUtil::ValidateTableName(const std::string& table_name) { } Status -ValidationUtil::ValidateTableDimension(int64_t dimension) { +ValidationUtil::ValidateTableDimension(int64_t dimension, int64_t metric_type) { if (dimension <= 0 || dimension > TABLE_DIMENSION_LIMIT) { std::string msg = "Invalid table dimension: " + std::to_string(dimension) + ". " + "The table dimension must be within the range of 1 ~ " + std::to_string(TABLE_DIMENSION_LIMIT) + "."; SERVER_LOG_ERROR << msg; return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); - } else { - return Status::OK(); } + + if (milvus::engine::utils::IsBinaryMetricType(metric_type)) { + if ((dimension % 8) != 0) { + std::string msg = "Invalid table dimension: " + std::to_string(dimension) + ". " + + "The table dimension must be a multiple of 8"; + SERVER_LOG_ERROR << msg; + return Status(SERVER_INVALID_VECTOR_DIMENSION, msg); + } + } + + return Status::OK(); } Status diff --git a/core/src/utils/ValidationUtil.h b/core/src/utils/ValidationUtil.h index 481cb31f..787adc12 100644 --- a/core/src/utils/ValidationUtil.h +++ b/core/src/utils/ValidationUtil.h @@ -30,7 +30,7 @@ class ValidationUtil { ValidateTableName(const std::string& table_name); static Status - ValidateTableDimension(int64_t dimension); + ValidateTableDimension(int64_t dimension, int64_t metric_type); static Status ValidateTableIndexType(int32_t index_type); diff --git a/core/unittest/server/test_util.cpp b/core/unittest/server/test_util.cpp index 9fab0815..7b9f48d1 100644 --- a/core/unittest/server/test_util.cpp +++ b/core/unittest/server/test_util.cpp @@ -358,14 +358,55 @@ TEST(ValidationUtilTest, VALIDATE_TABLENAME_TEST) { } TEST(ValidationUtilTest, VALIDATE_DIMENSION_TEST) { - ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(-1).code(), - milvus::SERVER_INVALID_VECTOR_DIMENSION); - ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(0).code(), - milvus::SERVER_INVALID_VECTOR_DIMENSION); - ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(32769).code(), - milvus::SERVER_INVALID_VECTOR_DIMENSION); - ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(32768).code(), milvus::SERVER_SUCCESS); - ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(1).code(), milvus::SERVER_SUCCESS); + std::vector + float_metric_types = {(int64_t)milvus::engine::MetricType::L2, (int64_t)milvus::engine::MetricType::IP}; + + std::vector + binary_metric_types = { + (int64_t)milvus::engine::MetricType::JACCARD, + (int64_t)milvus::engine::MetricType::TANIMOTO, + (int64_t)milvus::engine::MetricType::HAMMING, + (int64_t)milvus::engine::MetricType::SUBSTRUCTURE, + (int64_t)milvus::engine::MetricType::SUPERSTRUCTURE + }; + + std::vector valid_float_dimensions = {1, 512, 32768}; + std::vector invalid_float_dimensions = {-1, 0, 32769}; + + std::vector valid_binary_dimensions = {8, 1024, 32768}; + std::vector invalid_binary_dimensions = {-1, 0, 32769, 1, 15, 999}; + + // valid float dimensions + for (auto dim : valid_float_dimensions) { + for (auto metric : float_metric_types) { + ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(dim, metric).code(), + milvus::SERVER_SUCCESS); + } + } + + // invalid float dimensions + for (auto dim : invalid_float_dimensions) { + for (auto metric : float_metric_types) { + ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(dim, metric).code(), + milvus::SERVER_INVALID_VECTOR_DIMENSION); + } + } + + // valid binary dimensions + for (auto dim : valid_binary_dimensions) { + for (auto metric : binary_metric_types) { + ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(dim, metric).code(), + milvus::SERVER_SUCCESS); + } + } + + // invalid binary dimensions + for (auto dim : invalid_binary_dimensions) { + for (auto metric : binary_metric_types) { + ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(dim, metric).code(), + milvus::SERVER_INVALID_VECTOR_DIMENSION); + } + } } TEST(ValidationUtilTest, VALIDATE_INDEX_TEST) { diff --git a/sdk/examples/binary_vector/src/ClientTest.cpp b/sdk/examples/binary_vector/src/ClientTest.cpp index e96acb5a..6fce2ab0 100644 --- a/sdk/examples/binary_vector/src/ClientTest.cpp +++ b/sdk/examples/binary_vector/src/ClientTest.cpp @@ -39,7 +39,10 @@ BuildBinaryVectors(int64_t from, int64_t to, std::vector& entity entity_array.clear(); entity_ids.clear(); - int64_t dim_byte = dimension / 8; + int64_t dim_byte = ceil(dimension / 8); + if ((dimension % 8) > 0) { + dim_byte++; + } for (int64_t k = from; k < to; k++) { milvus::Entity entity; entity.binary_data.resize(dim_byte); @@ -151,8 +154,8 @@ ClientTest::Test(const std::string& address, const std::string& port) { { milvus::CollectionParam collection_param = { "collection_1", - 512, - 256, + 512, // dimension + 256, // index file size milvus::MetricType::TANIMOTO }; @@ -169,12 +172,12 @@ ClientTest::Test(const std::string& address, const std::string& port) { { milvus::CollectionParam collection_param = { "collection_2", - 256, - 512, + 512, // dimension + 512, // index file size milvus::MetricType::SUBSTRUCTURE }; - JSON json_params = {{"nlist", 2048}}; + JSON json_params = {}; milvus::IndexParam index_param = { collection_param.collection_name, milvus::IndexType::FLAT, @@ -187,12 +190,12 @@ ClientTest::Test(const std::string& address, const std::string& port) { { milvus::CollectionParam collection_param = { "collection_3", - 128, - 1024, + 128, // dimension + 1024, // index file size milvus::MetricType::SUPERSTRUCTURE }; - JSON json_params = {{"nlist", 4092}}; + JSON json_params = {}; milvus::IndexParam index_param = { collection_param.collection_name, milvus::IndexType::FLAT, -- GitLab