未验证 提交 fdd51400 编写于 作者: O op-hunter 提交者: GitHub

#1660 support IVFPQ CPU delete (#1695)

* support IVFPQ CPU delete
Signed-off-by: Nlichengming <chengming.li@zilliz.com>

* update CHANGELOG.md
Signed-off-by: Nlichengming <chengming.li@zilliz.com>
Co-authored-by: Nlichengming <chengming.li@zilliz.com>
上级 504a9e30
......@@ -14,6 +14,7 @@ Please mark all change in change log and use the issue from GitHub
## Feature
- \#1603 BinaryFlat add 2 Metric: Substructure and Superstructure
- \#1660 IVF PQ CPU support deleted vectors searching
## Improvement
- \#1537 Optimize raw vector and uids read/write
......
......@@ -30,7 +30,8 @@ void Index::train(idx_t /*n*/, const float* /*x*/) {
void Index::range_search (idx_t , const float *, float,
RangeSearchResult *) const
RangeSearchResult *,
ConcurrentBitsetPtr) const
{
FAISS_THROW_MSG ("range search not implemented");
}
......
......@@ -176,7 +176,8 @@ struct Index {
* @param result result table
*/
virtual void range_search (idx_t n, const float *x, float radius,
RangeSearchResult *result) const;
RangeSearchResult *result,
ConcurrentBitsetPtr bitset = nullptr) const;
/** return the indexes of the k vectors closest to the query x.
*
......
......@@ -21,7 +21,8 @@ void IndexBinary::train(idx_t, const uint8_t *) {
}
void IndexBinary::range_search(idx_t, const uint8_t *, int,
RangeSearchResult *) const {
RangeSearchResult *,
ConcurrentBitsetPtr) const {
FAISS_THROW_MSG("range search not implemented");
}
......
......@@ -134,7 +134,8 @@ struct IndexBinary {
* @param result result table
*/
virtual void range_search(idx_t n, const uint8_t *x, int radius,
RangeSearchResult *result) const;
RangeSearchResult *result,
ConcurrentBitsetPtr bitset = nullptr) const;
/** Return the indexes of the k vectors closest to the query x.
*
......
......@@ -65,7 +65,8 @@ void IndexFlat::search(idx_t n, const float* x, idx_t k, float* distances, idx_t
}
void IndexFlat::range_search (idx_t n, const float *x, float radius,
RangeSearchResult *result) const
RangeSearchResult *result,
ConcurrentBitsetPtr bitset) const
{
switch (metric_type) {
case METRIC_INNER_PRODUCT:
......
......@@ -40,7 +40,8 @@ struct IndexFlat: Index {
idx_t n,
const float* x,
float radius,
RangeSearchResult* result) const override;
RangeSearchResult* result,
ConcurrentBitsetPtr bitset = nullptr) const override;
void reconstruct(idx_t key, float* recons) const override;
......
......@@ -543,7 +543,8 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
void IndexIVF::range_search (idx_t nx, const float *x, float radius,
RangeSearchResult *result) const
RangeSearchResult *result,
ConcurrentBitsetPtr bitset) const
{
std::unique_ptr<idx_t[]> keys (new idx_t[nx * nprobe]);
std::unique_ptr<float []> coarse_dis (new float[nx * nprobe]);
......@@ -556,7 +557,7 @@ void IndexIVF::range_search (idx_t nx, const float *x, float radius,
invlists->prefetch_lists (keys.get(), nx * nprobe);
range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (),
result);
result, bitset);
indexIVF_stats.search_time += getmillisecs() - t0;
}
......@@ -564,7 +565,8 @@ void IndexIVF::range_search (idx_t nx, const float *x, float radius,
void IndexIVF::range_search_preassigned (
idx_t nx, const float *x, float radius,
const idx_t *keys, const float *coarse_dis,
RangeSearchResult *result) const
RangeSearchResult *result,
ConcurrentBitsetPtr bitset) const
{
size_t nlistv = 0, ndis = 0;
......@@ -601,7 +603,7 @@ void IndexIVF::range_search_preassigned (
nlistv++;
ndis += list_size;
scanner->scan_codes_range (list_size, scodes.get(),
ids.get(), radius, qres);
ids.get(), radius, qres, bitset);
};
if (parallel_mode == 0) {
......@@ -983,7 +985,8 @@ void InvertedListScanner::scan_codes_range (size_t ,
const uint8_t *,
const idx_t *,
float ,
RangeQueryResult &) const
RangeQueryResult &,
ConcurrentBitsetPtr) const
{
FAISS_THROW_MSG ("scan_codes_range not implemented");
}
......
......@@ -18,6 +18,7 @@
#include <faiss/InvertedLists.h>
#include <faiss/Clustering.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/ConcurrentBitset.h>
namespace faiss {
......@@ -193,11 +194,13 @@ struct IndexIVF: Index, Level1Quantizer {
ConcurrentBitsetPtr bitset = nullptr) override;
void range_search (idx_t n, const float* x, float radius,
RangeSearchResult* result) const override;
RangeSearchResult* result,
ConcurrentBitsetPtr bitset = nullptr) const override;
void range_search_preassigned(idx_t nx, const float *x, float radius,
const idx_t *keys, const float *coarse_dis,
RangeSearchResult *result) const;
RangeSearchResult *result,
ConcurrentBitsetPtr bitset = nullptr) const;
/// get a scanner for this index (store_pairs means ignore labels)
virtual InvertedListScanner *get_InvertedListScanner (
......@@ -342,7 +345,8 @@ struct InvertedListScanner {
const uint8_t *codes,
const idx_t *ids,
float radius,
RangeQueryResult &result) const;
RangeQueryResult &result,
ConcurrentBitsetPtr bitset = nullptr) const;
virtual ~InvertedListScanner () {}
......
......@@ -173,7 +173,8 @@ struct IVFFlatScanner: InvertedListScanner {
const uint8_t *codes,
const idx_t *ids,
float radius,
RangeQueryResult & res) const override
RangeQueryResult & res,
ConcurrentBitsetPtr bitset = nullptr) const override
{
const float *list_vecs = (const float*)codes;
for (size_t j = 0; j < list_size; j++) {
......@@ -483,7 +484,8 @@ void IndexIVFFlatDedup::range_search(
idx_t ,
const float* ,
float ,
RangeSearchResult* ) const
RangeSearchResult* ,
ConcurrentBitsetPtr) const
{
FAISS_THROW_MSG ("not implemented");
}
......
......@@ -97,7 +97,8 @@ struct IndexIVFFlatDedup: IndexIVFFlat {
idx_t n,
const float* x,
float radius,
RangeSearchResult* result) const override;
RangeSearchResult* result,
ConcurrentBitsetPtr bitset = nullptr) const override;
/// not implemented
void update_vectors (int nv, idx_t *idx, const float *v) override;
......
......@@ -800,10 +800,12 @@ struct KnnSearchResults {
size_t nup;
inline void add (idx_t j, float dis) {
inline void add (idx_t j, float dis, faiss::ConcurrentBitsetPtr bitset = nullptr) {
if (C::cmp (heap_sim[0], dis)) {
heap_pop<C> (k, heap_sim, heap_ids);
idx_t id = ids ? ids[j] : (key << 32 | j);
if (bitset != nullptr && bitset->test((faiss::ConcurrentBitset::id_type_t)id))
return;
heap_pop<C> (k, heap_sim, heap_ids);
heap_push<C> (k, heap_sim, heap_ids, dis, id);
nup++;
}
......@@ -820,7 +822,7 @@ struct RangeSearchResults {
float radius;
RangeQueryResult & rres;
inline void add (idx_t j, float dis) {
inline void add (idx_t j, float dis, faiss::ConcurrentBitsetPtr bitset = nullptr) {
if (C::cmp (radius, dis)) {
idx_t id = ids ? ids[j] : (key << 32 | j);
rres.add (dis, id);
......@@ -870,7 +872,8 @@ struct IVFPQScannerT: QueryTables {
/// version of the scan where we use precomputed tables
template<class SearchResultType>
void scan_list_with_table (size_t ncode, const uint8_t *codes,
SearchResultType & res) const
SearchResultType & res,
faiss::ConcurrentBitsetPtr bitset = nullptr) const
{
for (size_t j = 0; j < ncode; j++) {
......@@ -882,7 +885,7 @@ struct IVFPQScannerT: QueryTables {
tab += pq.ksub;
}
res.add(j, dis);
res.add(j, dis, bitset);
}
}
......@@ -891,7 +894,8 @@ struct IVFPQScannerT: QueryTables {
/// relevant X_c|x_r tables
template<class SearchResultType>
void scan_list_with_pointer (size_t ncode, const uint8_t *codes,
SearchResultType & res) const
SearchResultType & res,
faiss::ConcurrentBitsetPtr bitset = nullptr) const
{
for (size_t j = 0; j < ncode; j++) {
......@@ -903,7 +907,7 @@ struct IVFPQScannerT: QueryTables {
dis += sim_table_ptrs [m][ci] - 2 * tab [ci];
tab += pq.ksub;
}
res.add (j, dis);
res.add (j, dis, bitset);
}
}
......@@ -911,7 +915,8 @@ struct IVFPQScannerT: QueryTables {
/// nothing is precomputed: access residuals on-the-fly
template<class SearchResultType>
void scan_on_the_fly_dist (size_t ncode, const uint8_t *codes,
SearchResultType &res) const
SearchResultType &res,
faiss::ConcurrentBitsetPtr bitset = nullptr) const
{
const float *dvec;
float dis0 = 0;
......@@ -939,7 +944,7 @@ struct IVFPQScannerT: QueryTables {
} else {
dis = fvec_L2sqr (decoded_vec, dvec, d);
}
res.add (j, dis);
res.add (j, dis, bitset);
}
}
......@@ -950,7 +955,8 @@ struct IVFPQScannerT: QueryTables {
template <class HammingComputer, class SearchResultType>
void scan_list_polysemous_hc (
size_t ncode, const uint8_t *codes,
SearchResultType & res) const
SearchResultType & res,
faiss::ConcurrentBitsetPtr bitset = nullptr) const
{
int ht = ivfpq.polysemous_ht;
size_t n_hamming_pass = 0, nup = 0;
......@@ -973,7 +979,7 @@ struct IVFPQScannerT: QueryTables {
tab += pq.ksub;
}
res.add (j, dis);
res.add (j, dis, bitset);
}
codes += code_size;
}
......@@ -986,14 +992,15 @@ struct IVFPQScannerT: QueryTables {
template<class SearchResultType>
void scan_list_polysemous (
size_t ncode, const uint8_t *codes,
SearchResultType &res) const
SearchResultType &res,
faiss::ConcurrentBitsetPtr bitset = nullptr) const
{
switch (pq.code_size) {
#define HANDLE_CODE_SIZE(cs) \
case cs: \
scan_list_polysemous_hc \
<HammingComputer ## cs, SearchResultType> \
(ncode, codes, res); \
(ncode, codes, res, bitset); \
break
HANDLE_CODE_SIZE(4);
HANDLE_CODE_SIZE(8);
......@@ -1006,11 +1013,11 @@ struct IVFPQScannerT: QueryTables {
if (pq.code_size % 8 == 0)
scan_list_polysemous_hc
<HammingComputerM8, SearchResultType>
(ncode, codes, res);
(ncode, codes, res, bitset);
else
scan_list_polysemous_hc
<HammingComputerM4, SearchResultType>
(ncode, codes, res);
(ncode, codes, res, bitset);
break;
}
}
......@@ -1062,7 +1069,7 @@ struct IVFPQScanner:
const idx_t *ids,
float *heap_sim, idx_t *heap_ids,
size_t k,
ConcurrentBitsetPtr bitset) const override
faiss::ConcurrentBitsetPtr bitset) const override
{
KnnSearchResults<C> res = {
/* key */ this->key,
......@@ -1075,13 +1082,13 @@ struct IVFPQScanner:
if (this->polysemous_ht > 0) {
assert(precompute_mode == 2);
this->scan_list_polysemous (ncode, codes, res);
this->scan_list_polysemous (ncode, codes, res, bitset);
} else if (precompute_mode == 2) {
this->scan_list_with_table (ncode, codes, res);
this->scan_list_with_table (ncode, codes, res, bitset);
} else if (precompute_mode == 1) {
this->scan_list_with_pointer (ncode, codes, res);
this->scan_list_with_pointer (ncode, codes, res, bitset);
} else if (precompute_mode == 0) {
this->scan_on_the_fly_dist (ncode, codes, res);
this->scan_on_the_fly_dist (ncode, codes, res, bitset);
} else {
FAISS_THROW_MSG("bad precomp mode");
}
......@@ -1092,7 +1099,8 @@ struct IVFPQScanner:
const uint8_t *codes,
const idx_t *ids,
float radius,
RangeQueryResult & rres) const override
RangeQueryResult & rres,
faiss::ConcurrentBitsetPtr bitset = nullptr) const override
{
RangeSearchResults<C> res = {
/* key */ this->key,
......@@ -1103,13 +1111,13 @@ struct IVFPQScanner:
if (this->polysemous_ht > 0) {
assert(precompute_mode == 2);
this->scan_list_polysemous (ncode, codes, res);
this->scan_list_polysemous (ncode, codes, res, bitset);
} else if (precompute_mode == 2) {
this->scan_list_with_table (ncode, codes, res);
this->scan_list_with_table (ncode, codes, res, bitset);
} else if (precompute_mode == 1) {
this->scan_list_with_pointer (ncode, codes, res);
this->scan_list_with_pointer (ncode, codes, res, bitset);
} else if (precompute_mode == 0) {
this->scan_on_the_fly_dist (ncode, codes, res);
this->scan_on_the_fly_dist (ncode, codes, res, bitset);
} else {
FAISS_THROW_MSG("bad precomp mode");
}
......
......@@ -285,7 +285,8 @@ struct IVFScanner: InvertedListScanner {
const uint8_t *codes,
const idx_t *ids,
float radius,
RangeQueryResult & res) const override
RangeQueryResult & res,
ConcurrentBitsetPtr bitset = nullptr) const override
{
for (size_t j = 0; j < list_size; j++) {
float dis = hc.hamming (codes);
......
......@@ -190,7 +190,8 @@ void IndexPreTransform::search (idx_t n, const float *x, idx_t k,
}
void IndexPreTransform::range_search (idx_t n, const float* x, float radius,
RangeSearchResult* result) const
RangeSearchResult* result,
ConcurrentBitsetPtr bitset) const
{
FAISS_THROW_IF_NOT (is_trained);
const float *xt = apply_chain (n, x);
......
......@@ -57,7 +57,8 @@ struct IndexPreTransform: Index {
/* range search, no attempt is done to change the radius */
void range_search (idx_t n, const float* x, float radius,
RangeSearchResult* result) const override;
RangeSearchResult* result,
ConcurrentBitsetPtr bitset = nullptr) const override;
void reconstruct (idx_t key, float * recons) const override;
......
......@@ -121,9 +121,9 @@ void IndexIDMapTemplate<IndexT>::search_by_id (idx_t n, const idx_t *xid, idx_t
template <typename IndexT>
void IndexIDMapTemplate<IndexT>::range_search
(typename IndexT::idx_t n, const typename IndexT::component_t *x,
typename IndexT::distance_t radius, RangeSearchResult *result) const
typename IndexT::distance_t radius, RangeSearchResult *result, ConcurrentBitsetPtr bitset) const
{
index->range_search(n, x, radius, result);
index->range_search(n, x, radius, result, bitset);
#pragma omp parallel for
for (idx_t i = 0; i < result->lims[result->nq]; i++) {
result->labels[i] = result->labels[i] < 0 ?
......
......@@ -53,7 +53,8 @@ struct IndexIDMapTemplate : IndexT {
size_t remove_ids(const IDSelector& sel) override;
void range_search (idx_t n, const component_t *x, distance_t radius,
RangeSearchResult *result) const override;
RangeSearchResult *result,
ConcurrentBitsetPtr bitset = nullptr) const override;
~IndexIDMapTemplate () override;
IndexIDMapTemplate () {own_fields=false; index=nullptr; }
......
......@@ -246,7 +246,8 @@ struct IVFSQScannerIP: InvertedListScanner {
const uint8_t *codes,
const idx_t *ids,
float radius,
RangeQueryResult & res) const override
RangeQueryResult & res,
ConcurrentBitsetPtr bitset = nullptr) const override
{
for (size_t j = 0; j < list_size; j++) {
float accu = accu0 + dc.query_to_code (codes);
......@@ -333,7 +334,8 @@ struct IVFSQScannerL2: InvertedListScanner {
const uint8_t *codes,
const idx_t *ids,
float radius,
RangeQueryResult & res) const override
RangeQueryResult & res,
ConcurrentBitsetPtr bitset = nullptr) const override
{
for (size_t j = 0; j < list_size; j++) {
float dis = dc.query_to_code (codes);
......
......@@ -10,12 +10,14 @@
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include "../../utils/ConcurrentBitset.h"
int main() {
int d = 64; // dimension
int nb = 100000; // database size
int nq = 10000; // nb of queries
int nq = 10;//10000; // nb of queries
faiss::ConcurrentBitsetPtr bitset = std::make_shared<faiss::ConcurrentBitset>(nb);
float *xb = new float[d * nb];
float *xq = new float[d * nq];
......@@ -26,11 +28,17 @@ int main() {
xb[d * i] += i / 1000.;
}
srand((unsigned)time(NULL));
printf("delete ids: \n");
for(int i = 0; i < nq; i++) {
auto tmp = rand()%nb;
bitset->set(tmp);
printf("%d ", tmp);
for(int j = 0; j < d; j++)
xq[d * i + j] = drand48();
xq[d * i] += i / 1000.;
xq[d * i + j] = xb[d * tmp + j];
// xq[d * i] += i / 1000.;
}
printf("\n");
int nlist = 100;
......@@ -42,6 +50,7 @@ int main() {
index.train(nb, xb);
index.add(nb, xb);
printf("------------sanity check----------------\n");
{ // sanity check
long *I = new long[k * 5];
float *D = new float[k * 5];
......@@ -66,6 +75,7 @@ int main() {
delete [] D;
}
printf("---------------search xq-------------\n");
{ // search xq
long *I = new long[k * nq];
float *D = new float[k * nq];
......@@ -74,7 +84,26 @@ int main() {
index.search(nq, xq, k, D, I);
printf("I=\n");
for(int i = nq - 5; i < nq; i++) {
for(int i = 0; i < nq; i++) {
for(int j = 0; j < k; j++)
printf("%5ld ", I[i * k + j]);
printf("\n");
}
delete [] I;
delete [] D;
}
printf("----------------search xq with delete------------\n");
{ // search xq with delete
long *I = new long[k * nq];
float *D = new float[k * nq];
index.nprobe = 10;
index.search(nq, xq, k, D, I, bitset);
printf("I=\n");
for(int i = 0; i < nq; i++) {
for(int j = 0; j < k; j++)
printf("%5ld ", I[i * k + j]);
printf("\n");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册