未验证 提交 585c8ea0 编写于 作者: C cqy123456 提交者: GitHub

fix rnsg ip (#2827)

Signed-off-by: Ncqy <yaya645@126.com>
上级 d666dc8c
......@@ -21,6 +21,7 @@ Please mark all change in change log and use the issue from GitHub
- \#2767 Fix a bug of getting wrong nprobe limitation in knowhere on GPU version
- \#2768 After building the index,the number of vectors increases
- \#2776 Fix too many data copies during creating IVF index
- \#2813 To implemente RNSG IP
## Feature
......
......@@ -140,7 +140,18 @@ NSG::Train(const DatasetPtr& dataset_ptr, const Config& config) {
b_params.search_length = config[IndexParams::search_length];
GETTENSORWITHIDS(dataset_ptr)
index_ = std::make_shared<impl::NsgIndex>(dim, rows, config[Metric::TYPE].get<std::string>());
impl::NsgIndex::Metric_Type metric;
auto metric_str = config[Metric::TYPE].get<std::string>();
if (metric_str == knowhere::Metric::IP) {
metric = impl::NsgIndex::Metric_Type::Metric_Type_IP;
} else if (metric_str == knowhere::Metric::L2) {
metric = impl::NsgIndex::Metric_Type::Metric_Type_L2;
} else {
KNOWHERE_THROW_MSG("Metric is not supported");
}
index_ = std::make_shared<impl::NsgIndex>(dim, rows, metric);
index_->SetKnnGraph(knng);
index_->Build_with_ids(rows, (float*)p_data, (int64_t*)p_ids, b_params);
}
......
......@@ -237,7 +237,7 @@ DistanceL2::Compare(const float* a, const float* b, unsigned size) const {
float
DistanceIP::Compare(const float* a, const float* b, unsigned size) const {
return faiss::fvec_inner_product(a, b, (size_t)size);
return -(faiss::fvec_inner_product(a, b, (size_t)size));
}
#endif
......
......@@ -31,11 +31,11 @@ namespace impl {
unsigned int seed = 100;
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, std::string metric)
NsgIndex::NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric)
: dimension(dimension), ntotal(n), metric_type(metric) {
if (metric == knowhere::Metric::L2) {
if (metric == Metric_Type::Metric_Type_L2) {
distance_ = new DistanceL2;
} else if (metric == knowhere::Metric::IP) {
} else if (metric == Metric_Type::Metric_Type_IP) {
distance_ = new DistanceIP;
}
}
......@@ -406,7 +406,6 @@ NsgIndex::GetNeighbors(const float* query, std::vector<Neighbor>& resset, Graph&
// std::cout << "pos: " << pos << ", nn: " << nn.id << ":" << nn.distance << ", nup: " <<
// nearest_updated_pos << std::endl;
/////
// trick: avoid search query search_length < init_ids.size() ...
if (buffer_size + 1 < resset.size())
++buffer_size;
......@@ -846,6 +845,8 @@ NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, co
}
}
rc.RecordSection("search");
bool is_ip = (metric_type == Metric_Type::Metric_Type_IP);
for (unsigned int i = 0; i < nq; ++i) {
unsigned int pos = 0;
for (unsigned int j = 0; j < resset[i].size(); ++j) {
......@@ -853,7 +854,7 @@ NsgIndex::Search(const float* query, const unsigned& nq, const unsigned& dim, co
break; // already top k
if (!bitset || !bitset->test((faiss::ConcurrentBitset::id_type_t)resset[i][j].id)) {
ids[i * k + pos] = ids_[resset[i][j].id];
dist[i * k + pos] = resset[i][j].distance;
dist[i * k + pos] = is_ip ? -resset[i][j].distance : resset[i][j].distance;
++pos;
}
}
......
......@@ -43,9 +43,14 @@ using Graph = std::vector<std::vector<node_t>>;
class NsgIndex {
public:
enum Metric_Type {
Metric_Type_L2,
Metric_Type_IP,
};
size_t dimension;
size_t ntotal; // totabl nb of indexed vectors
std::string metric_type; // L2 | IP
size_t ntotal; // totabl nb of indexed vectors
int32_t metric_type; // enum Metric_Type
Distance* distance_;
float* ori_data_;
......@@ -65,7 +70,7 @@ class NsgIndex {
size_t out_degree;
public:
explicit NsgIndex(const size_t& dimension, const size_t& n, std::string metric = knowhere::Metric::L2);
explicit NsgIndex(const size_t& dimension, const size_t& n, Metric_Type metric);
NsgIndex() = default;
......
......@@ -19,6 +19,7 @@ namespace impl {
void
write_index(NsgIndex* index, MemoryIOWriter& writer) {
writer(&index->metric_type, sizeof(int32_t), 1);
writer(&index->ntotal, sizeof(index->ntotal), 1);
writer(&index->dimension, sizeof(index->dimension), 1);
writer(&index->navigation_point, sizeof(index->navigation_point), 1);
......@@ -36,9 +37,11 @@ NsgIndex*
read_index(MemoryIOReader& reader) {
size_t ntotal;
size_t dimension;
int32_t metric;
reader(&metric, sizeof(int32_t), 1);
reader(&ntotal, sizeof(size_t), 1);
reader(&dimension, sizeof(size_t), 1);
auto index = new NsgIndex(dimension, ntotal);
auto index = new NsgIndex(dimension, ntotal, (impl::NsgIndex::Metric_Type)metric);
reader(&index->navigation_point, sizeof(index->navigation_point), 1);
index->ori_data_ = new float[index->ntotal * index->dimension];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册