未验证 提交 802d62eb 编写于 作者: G groot 提交者: GitHub

#1651 Check validity of dimension when collection metric type is binary one (#1666)

* #1648 The cache cannot be used all when the type is binary
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* #1646 The cache cannot be used all when the type is binary
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* #1646 The cache cannot be used all when the type is binary
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* #1651 Check validity of dimension when collection metric type is binary one
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* typo
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>
上级 6c826c13
......@@ -8,7 +8,8 @@ Please mark all change in change log and use the issue from GitHub
- \#1301 Data in WAL may be accidentally inserted into a new table with the same name.
- \#1634 Fix search demo bug in HTTP doc
- \#1635 Vectors can be returned by searching after vectors deleted if `cache_insert_data` set true
- \#1648 The cache cannot be used all when the type is binary
- \#1648 The cache cannot be fully used when the vector type is binary
- \#1651 Check validity of dimension when the collection metric type is a binary one
## Feature
- \#1603 BinaryFlat add 2 Metric: Substructure and Superstructure
......
......@@ -622,66 +622,6 @@ ExecutionEngineImpl::CopyToCpu() {
return Status::OK();
}
// ExecutionEnginePtr
// ExecutionEngineImpl::Clone() {
// if (index_ == nullptr) {
// ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to clone";
// return nullptr;
// }
//
// auto ret = std::make_shared<ExecutionEngineImpl>(dim_, location_, index_type_, metric_type_, nlist_);
// ret->Init();
// ret->index_ = index_->Clone();
// return ret;
//}
/*
Status
ExecutionEngineImpl::Merge(const std::string& location) {
if (location == location_) {
return Status(DB_ERROR, "Cannot Merge Self");
}
ENGINE_LOG_DEBUG << "Merge index file: " << location << " to: " << location_;
auto to_merge = cache::CpuCacheMgr::GetInstance()->GetIndex(location);
if (!to_merge) {
try {
double physical_size = server::CommonUtil::GetFileSize(location);
server::CollectExecutionEngineMetrics metrics(physical_size);
to_merge = read_index(location);
} catch (std::exception& e) {
ENGINE_LOG_ERROR << e.what();
return Status(DB_ERROR, e.what());
}
}
if (index_ == nullptr) {
ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to merge";
return Status(DB_ERROR, "index is null");
}
if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) {
auto status = index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
if (!status.ok()) {
ENGINE_LOG_ERROR << "Failed to merge: " << location << " to: " << location_;
} else {
ENGINE_LOG_DEBUG << "Finish merge index file: " << location;
}
return status;
} else if (auto bin_index = std::dynamic_pointer_cast<BinBFIndex>(to_merge)) {
auto status = index_->Add(bin_index->Count(), bin_index->GetRawVectors(), bin_index->GetRawIds());
if (!status.ok()) {
ENGINE_LOG_ERROR << "Failed to merge: " << location << " to: " << location_;
} else {
ENGINE_LOG_DEBUG << "Finish merge index file: " << location;
}
return status;
} else {
return Status(DB_ERROR, "file index type is not idmap");
}
}
*/
ExecutionEnginePtr
ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_type) {
ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_;
......
......@@ -62,12 +62,6 @@ class ExecutionEngineImpl : public ExecutionEngine {
Status
CopyToCpu() override;
// ExecutionEnginePtr
// Clone() override;
// Status
// Merge(const std::string& location) override;
Status
GetVectorByID(const int64_t& id, float* vector, bool hybrid) override;
......
......@@ -52,7 +52,7 @@ CreateTableRequest::OnExecute() {
return status;
}
status = ValidationUtil::ValidateTableDimension(dimension_);
status = ValidationUtil::ValidateTableDimension(dimension_, metric_type_);
if (!status.ok()) {
return status;
}
......
......@@ -11,6 +11,7 @@
#include "utils/ValidationUtil.h"
#include "Log.h"
#include "db/Utils.h"
#include "db/engine/ExecutionEngine.h"
#include "index/knowhere/knowhere/index/vector_index/helpers/IndexParameter.h"
#include "utils/StringHelpFunctions.h"
......@@ -128,16 +129,25 @@ ValidationUtil::ValidateTableName(const std::string& table_name) {
}
// Validates a requested table dimension.
//
// Returns Status(SERVER_INVALID_VECTOR_DIMENSION, msg) when the dimension is
// outside 1 ~ TABLE_DIMENSION_LIMIT, or when the metric type is reported as
// binary by IsBinaryMetricType() and the dimension is not a multiple of 8.
// Every rejection is also written to SERVER_LOG_ERROR with the same message.
//
// NOTE(review): this listing comes from a diff view with the +/- markers
// stripped, so two signature lines appear back to back and an
// `else { return Status::OK(); }` branch sits in front of the binary check.
// Presumably the one-argument signature and the `else` branch are the removed
// lines - confirm against the repository before relying on this text.
Status
ValidationUtil::ValidateTableDimension(int64_t dimension) {
ValidationUtil::ValidateTableDimension(int64_t dimension, int64_t metric_type) {
// Range check: applies to every metric type.
if (dimension <= 0 || dimension > TABLE_DIMENSION_LIMIT) {
std::string msg = "Invalid table dimension: " + std::to_string(dimension) + ". " +
"The table dimension must be within the range of 1 ~ " +
std::to_string(TABLE_DIMENSION_LIMIT) + ".";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
} else {
return Status::OK();
}
// Binary-only check: the dimension must be divisible by 8.
// Presumably because binary vectors are packed 8 dimensions per byte - TODO confirm.
if (milvus::engine::utils::IsBinaryMetricType(metric_type)) {
if ((dimension % 8) != 0) {
std::string msg = "Invalid table dimension: " + std::to_string(dimension) + ". " +
"The table dimension must be a multiple of 8";
SERVER_LOG_ERROR << msg;
return Status(SERVER_INVALID_VECTOR_DIMENSION, msg);
}
}
// Both checks passed.
return Status::OK();
}
Status
......
......@@ -30,7 +30,7 @@ class ValidationUtil {
ValidateTableName(const std::string& table_name);
static Status
ValidateTableDimension(int64_t dimension);
ValidateTableDimension(int64_t dimension, int64_t metric_type);
static Status
ValidateTableIndexType(int32_t index_type);
......
......@@ -358,14 +358,55 @@ TEST(ValidationUtilTest, VALIDATE_TABLENAME_TEST) {
}
// Exercises ValidationUtil::ValidateTableDimension over the cross product of
// metric types and dimensions: float metrics (L2, IP) and binary metrics
// (JACCARD, TANIMOTO, HAMMING, SUBSTRUCTURE, SUPERSTRUCTURE), each with a list
// of dimensions expected to succeed and a list expected to be rejected.
//
// NOTE(review): this listing comes from a diff view with the +/- markers
// stripped. The single-argument ASSERT_EQ calls directly below are presumably
// the removed assertions that the table-driven loops replace - confirm.
TEST(ValidationUtilTest, VALIDATE_DIMENSION_TEST) {
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(-1).code(),
milvus::SERVER_INVALID_VECTOR_DIMENSION);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(0).code(),
milvus::SERVER_INVALID_VECTOR_DIMENSION);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(32769).code(),
milvus::SERVER_INVALID_VECTOR_DIMENSION);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(32768).code(), milvus::SERVER_SUCCESS);
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(1).code(), milvus::SERVER_SUCCESS);
// Metric types are cast to int64_t to match the validator's metric_type argument.
std::vector<int64_t>
float_metric_types = {(int64_t)milvus::engine::MetricType::L2, (int64_t)milvus::engine::MetricType::IP};
std::vector<int64_t>
binary_metric_types = {
(int64_t)milvus::engine::MetricType::JACCARD,
(int64_t)milvus::engine::MetricType::TANIMOTO,
(int64_t)milvus::engine::MetricType::HAMMING,
(int64_t)milvus::engine::MetricType::SUBSTRUCTURE,
(int64_t)milvus::engine::MetricType::SUPERSTRUCTURE
};
// Float cases probe the range edges: 1 and 32768 accepted; -1, 0 and 32769 rejected.
std::vector<int64_t> valid_float_dimensions = {1, 512, 32768};
std::vector<int64_t> invalid_float_dimensions = {-1, 0, 32769};
// Binary cases reuse the out-of-range values and add 1, 15 and 999, which are
// not multiples of 8; every accepted binary dimension is a multiple of 8.
std::vector<int64_t> valid_binary_dimensions = {8, 1024, 32768};
std::vector<int64_t> invalid_binary_dimensions = {-1, 0, 32769, 1, 15, 999};
// valid float dimensions
for (auto dim : valid_float_dimensions) {
for (auto metric : float_metric_types) {
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(dim, metric).code(),
milvus::SERVER_SUCCESS);
}
}
// invalid float dimensions
for (auto dim : invalid_float_dimensions) {
for (auto metric : float_metric_types) {
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(dim, metric).code(),
milvus::SERVER_INVALID_VECTOR_DIMENSION);
}
}
// valid binary dimensions
for (auto dim : valid_binary_dimensions) {
for (auto metric : binary_metric_types) {
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(dim, metric).code(),
milvus::SERVER_SUCCESS);
}
}
// invalid binary dimensions
for (auto dim : invalid_binary_dimensions) {
for (auto metric : binary_metric_types) {
ASSERT_EQ(milvus::server::ValidationUtil::ValidateTableDimension(dim, metric).code(),
milvus::SERVER_INVALID_VECTOR_DIMENSION);
}
}
}
TEST(ValidationUtilTest, VALIDATE_INDEX_TEST) {
......
......@@ -39,7 +39,10 @@ BuildBinaryVectors(int64_t from, int64_t to, std::vector<milvus::Entity>& entity
entity_array.clear();
entity_ids.clear();
int64_t dim_byte = dimension / 8;
int64_t dim_byte = ceil(dimension / 8);
if ((dimension % 8) > 0) {
dim_byte++;
}
for (int64_t k = from; k < to; k++) {
milvus::Entity entity;
entity.binary_data.resize(dim_byte);
......@@ -151,8 +154,8 @@ ClientTest::Test(const std::string& address, const std::string& port) {
{
milvus::CollectionParam collection_param = {
"collection_1",
512,
256,
512, // dimension
256, // index file size
milvus::MetricType::TANIMOTO
};
......@@ -169,12 +172,12 @@ ClientTest::Test(const std::string& address, const std::string& port) {
{
milvus::CollectionParam collection_param = {
"collection_2",
256,
512,
512, // dimension
512, // index file size
milvus::MetricType::SUBSTRUCTURE
};
JSON json_params = {{"nlist", 2048}};
JSON json_params = {};
milvus::IndexParam index_param = {
collection_param.collection_name,
milvus::IndexType::FLAT,
......@@ -187,12 +190,12 @@ ClientTest::Test(const std::string& address, const std::string& port) {
{
milvus::CollectionParam collection_param = {
"collection_3",
128,
1024,
128, // dimension
1024, // index file size
milvus::MetricType::SUPERSTRUCTURE
};
JSON json_params = {{"nlist", 4092}};
JSON json_params = {};
milvus::IndexParam index_param = {
collection_param.collection_name,
milvus::IndexType::FLAT,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册