未验证 提交 b04b3e94 编写于 作者: S shengjun.li 提交者: GitHub

#1897 add heap_swap_top (#1898)

* add heap_swap_top
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

* fix wrong code
Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
上级 3bc17d8c
......@@ -10,11 +10,10 @@ Please mark all change in change log and use the issue from GitHub
- \#1789 Fix multi-client search cause server crash
- \#1832 Fix crash in tracing module
- \#1873 Fix index file serialize to incorrect path
- \#1881 Fix Annoy index search failure
- \#1881 Fix bad alloc when index files lost
## Feature
- \#261 Integrate ANNOY into Milvus
- \#1603 BinaryFlat add 2 Metric: Substructure and Superstructure
- \#1655 GPU index support delete vectors
- \#1660 IVF PQ CPU support deleted vectors searching
- \#1661 HNSW support deleted vectors searching
......@@ -29,6 +28,7 @@ Please mark all change in change log and use the issue from GitHub
- \#1882 Add index annoy into http module
- \#1885 Optimize knowhere unittest
- \#1886 Refactor log on search and insert request
- \#1897 Heap pop and push can be realized by heap_swap_top
## Task
......
......@@ -420,9 +420,8 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
uint32_t dis = hc.hamming (codes);
if (dis < simi[0]) {
heap_pop<C> (k, simi, idxi);
idx_t id = store_pairs ? (list_no << 32 | j) : ids[j];
heap_push<C> (k, simi, idxi, dis, id);
heap_swap_top<C> (k, simi, idxi, dis, id);
nup++;
}
}
......@@ -470,9 +469,8 @@ struct IVFBinaryScannerJaccard: BinaryInvertedListScanner {
float dis = hc.compute (codes);
if (dis < psimi[0]) {
heap_pop<C> (k, psimi, idxi);
idx_t id = store_pairs ? (list_no << 32 | j) : ids[j];
heap_push<C> (k, psimi, idxi, dis, id);
heap_swap_top<C> (k, psimi, idxi, dis, id);
nup++;
}
}
......
......@@ -159,9 +159,8 @@ struct IVFFlatScanner: InvertedListScanner {
float dis = metric == METRIC_INNER_PRODUCT ?
fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
if (C::cmp (simi[0], dis)) {
heap_pop<C> (k, simi, idxi);
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
heap_push<C> (k, simi, idxi, dis, id);
heap_swap_top<C> (k, simi, idxi, dis, id);
nup++;
}
}
......
......@@ -805,8 +805,7 @@ struct KnnSearchResults {
idx_t id = ids ? ids[j] : (key << 32 | j);
if (bitset != nullptr && bitset->test((faiss::ConcurrentBitset::id_type_t)id))
return;
heap_pop<C> (k, heap_sim, heap_ids);
heap_push<C> (k, heap_sim, heap_ids, dis, id);
heap_swap_top<C> (k, heap_sim, heap_ids, dis, id);
nup++;
}
}
......
......@@ -171,9 +171,8 @@ void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
float dis = fvec_L2sqr (residual_1, residual_2, d);
if (dis < heap_sim[0]) {
maxheap_pop (k, heap_sim, heap_ids);
idx_t id_or_pair = store_pairs ? sl : id;
maxheap_push (k, heap_sim, heap_ids, dis, id_or_pair);
maxheap_swap_top (k, heap_sim, heap_ids, dis, id_or_pair);
}
n_refine ++;
}
......
......@@ -270,9 +270,8 @@ struct IVFScanner: InvertedListScanner {
float dis = hc.hamming (codes);
if (dis < simi [0]) {
maxheap_pop (k, simi, idxi);
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
maxheap_push (k, simi, idxi, dis, id);
maxheap_swap_top (k, simi, idxi, dis, id);
nup++;
}
}
......
......@@ -330,8 +330,7 @@ static size_t polysemous_inner_loop (
}
if (dis < heap_dis[0]) {
maxheap_pop (k, heap_dis, heap_ids);
maxheap_push (k, heap_dis, heap_ids, dis, bi);
maxheap_swap_top (k, heap_dis, heap_ids, dis, bi);
}
}
b_code += code_size;
......
......@@ -63,8 +63,7 @@ void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
}
if (C::cmp (heap_dis[0], dis)) {
heap_pop<C> (k, heap_dis, heap_ids);
heap_push<C> (k, heap_dis, heap_ids, dis, j);
heap_swap_top<C> (k, heap_dis, heap_ids, dis, j);
}
}
}
......@@ -89,8 +88,7 @@ void pq_estimators_from_tables_M4 (const CT * codes,
dis += dt[*codes++];
if (C::cmp (heap_dis[0], dis)) {
heap_pop<C> (k, heap_dis, heap_ids);
heap_push<C> (k, heap_dis, heap_ids, dis, j);
heap_swap_top<C> (k, heap_dis, heap_ids, dis, j);
}
}
}
......@@ -132,8 +130,7 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
dt += ksub;
}
if (C::cmp (heap_dis[0], dis)) {
heap_pop<C> (k, heap_dis, heap_ids);
heap_push<C> (k, heap_dis, heap_ids, dis, j);
heap_swap_top<C> (k, heap_dis, heap_ids, dis, j);
}
}
}
......@@ -163,8 +160,7 @@ static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
}
if (C::cmp(heap_dis[0], dis)) {
heap_pop<C>(k, heap_dis, heap_ids);
heap_push<C>(k, heap_dis, heap_ids, dis, j);
heap_swap_top<C>(k, heap_dis, heap_ids, dis, j);
}
}
}
......@@ -747,8 +743,7 @@ void ProductQuantizer::search_sdc (const uint8_t * qcodes,
tab += ksub * ksub;
}
if (dis < heap_dis[0]) {
maxheap_pop (k, heap_dis, heap_ids);
maxheap_push (k, heap_dis, heap_ids, dis, j);
maxheap_swap_top (k, heap_dis, heap_ids, dis, j);
}
bcode += code_size;
}
......
......@@ -231,9 +231,8 @@ struct IVFSQScannerIP: InvertedListScanner {
float accu = accu0 + dc.query_to_code (codes);
if (accu > simi [0]) {
minheap_pop (k, simi, idxi);
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
minheap_push (k, simi, idxi, accu, id);
minheap_swap_top (k, simi, idxi, accu, id);
nup++;
}
}
......@@ -319,9 +318,8 @@ struct IVFSQScannerL2: InvertedListScanner {
float dis = dc.query_to_code (codes);
if (dis < simi [0]) {
maxheap_pop (k, simi, idxi);
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
maxheap_push (k, simi, idxi, dis, id);
maxheap_swap_top (k, simi, idxi, dis, id);
nup++;
}
}
......
......@@ -34,12 +34,12 @@ void binary_distence_knn_hc(
if ((bytes_per_code + k * (sizeof(float) + sizeof(int64_t))) * ha->nh < size_1M) {
int thread_max_num = omp_get_max_threads();
// init hash
size_t thread_hash_size = ha->nh * k;
size_t all_hash_size = thread_hash_size * thread_max_num;
float *value = new float[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];
for (int i = 0; i < all_hash_size; i++) {
// init heap
size_t thread_heap_size = ha->nh * k;
size_t all_heap_size = thread_heap_size * thread_max_num;
float *value = new float[all_heap_size];
int64_t *labels = new int64_t[all_heap_size];
for (int i = 0; i < all_heap_size; i++) {
value[i] = 1.0 / 0.0;
labels[i] = -1;
}
......@@ -58,35 +58,33 @@ void binary_distence_knn_hc(
for (size_t i = 0; i < ha->nh; i++) {
tadis_t dis = hc[i].compute (bs2_);
float * val_ = value + thread_no * thread_hash_size + i * k;
int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
float * val_ = value + thread_no * thread_heap_size + i * k;
int64_t * ids_ = labels + thread_no * thread_heap_size + i * k;
if (dis < val_[0]) {
faiss::maxheap_pop<tadis_t> (k, val_, ids_);
faiss::maxheap_push<tadis_t> (k, val_, ids_, dis, j);
faiss::maxheap_swap_top<tadis_t> (k, val_, ids_, dis, j);
}
}
}
}
for (size_t t = 1; t < thread_max_num; t++) {
// merge hash
// merge heap
for (size_t i = 0; i < ha->nh; i++) {
float * __restrict value_x = value + i * k;
int64_t * __restrict labels_x = labels + i * k;
float *value_x_t = value_x + t * thread_hash_size;
int64_t *labels_x_t = labels_x + t * thread_hash_size;
float *value_x_t = value_x + t * thread_heap_size;
int64_t *labels_x_t = labels_x + t * thread_heap_size;
for (size_t j = 0; j < k; j++) {
if (value_x_t[j] < value_x[0]) {
faiss::maxheap_pop<tadis_t> (k, value_x, labels_x);
faiss::maxheap_push<tadis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
faiss::maxheap_swap_top<tadis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
}
}
}
}
// copy result
memcpy(ha->val, value, thread_hash_size * sizeof(float));
memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t));
memcpy(ha->val, value, thread_heap_size * sizeof(float));
memcpy(ha->ids, labels, thread_heap_size * sizeof(int64_t));
delete[] hc;
delete[] value;
......@@ -111,8 +109,7 @@ void binary_distence_knn_hc(
if(!bitset || !bitset->test(j)){
dis = hc.compute (bs2_);
if (dis < bh_val_[0]) {
faiss::maxheap_pop<tadis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<tadis_t> (k, bh_val_, bh_ids_, dis, j);
faiss::maxheap_swap_top<tadis_t> (k, bh_val_, bh_ids_, dis, j);
}
}
}
......
......@@ -46,8 +46,7 @@ void HeapArray<C>::addn (size_t nj, const T *vin, TI j0,
for (size_t j = 0; j < nj; j++) {
T ip = ip_line [j];
if (C::cmp(simi[0], ip)) {
heap_pop<C> (k, simi, idxi);
heap_push<C> (k, simi, idxi, ip, j + j0);
heap_swap_top<C> (k, simi, idxi, ip, j + j0);
}
}
}
......@@ -74,8 +73,7 @@ void HeapArray<C>::addn_with_ids (
for (size_t j = 0; j < nj; j++) {
T ip = ip_line [j];
if (C::cmp(simi[0], ip)) {
heap_pop<C> (k, simi, idxi);
heap_push<C> (k, simi, idxi, ip, id_line [j]);
heap_swap_top<C> (k, simi, idxi, ip, id_line [j]);
}
}
}
......
......@@ -83,6 +83,42 @@ struct CMax {
* Basic heap ops: push and pop
*******************************************************************/
/** Replaces the top (root) element of the heap stored in bh_val[0..k-1] and
 * bh_ids[0..k-1] with the pair (val, ids), then sifts the new element down
 * until the heap ordering defined by C::cmp holds again.
 *
 * This does in a single pass what a heap_pop followed by a heap_push of
 * (val, ids) would do; the heap keeps exactly k elements and no slot is
 * left undefined.
 *
 * @param k       number of elements in the heap; when 0 the call is a no-op
 * @param bh_val  heap values, modified in place
 * @param bh_ids  ids associated with the values, permuted alongside bh_val
 * @param val     value that replaces the current top
 * @param ids     id that replaces the current top id
 */
template <class C> inline
void heap_swap_top (size_t k,
                    typename C::T * bh_val, typename C::TI * bh_ids,
                    typename C::T val, typename C::TI ids)
{
    /* An empty heap has no top to replace; writing slot 0 would be out of
     * bounds. */
    if (k == 0)
        return;

    /* 0-based indexing: the children of node i are 2*i+1 and 2*i+2. This
     * avoids forming a pointer one element before the arrays. */
    size_t hole = 0;
    for (;;) {
        const size_t left = 2 * hole + 1;
        if (left >= k)
            break; /* the hole is a leaf */

        /* Take the left child when it is the only child or when it wins
         * the comparison against the right child; otherwise the right. */
        const size_t right = left + 1;
        size_t child = left;
        if (right < k && !C::cmp(bh_val[left], bh_val[right]))
            child = right;

        /* The new value may stay here once it dominates that child. */
        if (C::cmp(val, bh_val[child]))
            break;

        /* Otherwise move the child up and continue one level below. */
        bh_val[hole] = bh_val[child];
        bh_ids[hole] = bh_ids[child];
        hole = child;
    }
    bh_val[hole] = val;
    bh_ids[hole] = ids;
}
/** Pops the top element from the heap defined by bh_val[0..k-1] and
* bh_ids[0..k-1]. on output the element at k-1 is undefined.
*/
......@@ -146,6 +182,13 @@ void heap_push (size_t k,
/* Partial instantiation for heaps with TI = int64_t */
template <typename T> inline
void minheap_swap_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
{
heap_swap_top<CMin<T, int64_t> > (k, bh_val, bh_ids, val, ids);
}
template <typename T> inline
void minheap_pop (size_t k, T * bh_val, int64_t * bh_ids)
{
......@@ -160,6 +203,13 @@ void minheap_push (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
}
template <typename T> inline
void maxheap_swap_top (size_t k, T * bh_val, int64_t * bh_ids, T val, int64_t ids)
{
heap_swap_top<CMax<T, int64_t> > (k, bh_val, bh_ids, val, ids);
}
template <typename T> inline
void maxheap_pop (size_t k, T * bh_val, int64_t * bh_ids)
{
......@@ -251,15 +301,13 @@ void heap_addn (size_t k,
if (ids)
for (i = 0; i < n; i++) {
if (C::cmp (bh_val[0], x[i])) {
heap_pop<C> (k, bh_val, bh_ids);
heap_push<C> (k, bh_val, bh_ids, x[i], ids[i]);
heap_swap_top<C> (k, bh_val, bh_ids, x[i], ids[i]);
}
}
else
for (i = 0; i < n; i++) {
if (C::cmp (bh_val[0], x[i])) {
heap_pop<C> (k, bh_val, bh_ids);
heap_push<C> (k, bh_val, bh_ids, x[i], i);
heap_swap_top<C> (k, bh_val, bh_ids, x[i], i);
}
}
}
......
......@@ -155,13 +155,13 @@ static void knn_inner_product_sse (const float * x,
size_t thread_max_num = omp_get_max_threads();
size_t thread_hash_size = nx * k;
size_t all_hash_size = thread_hash_size * thread_max_num;
float *value = new float[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];
size_t thread_heap_size = nx * k;
size_t all_heap_size = thread_heap_size * thread_max_num;
float *value = new float[all_heap_size];
int64_t *labels = new int64_t[all_heap_size];
// init hash
for (size_t i = 0; i < all_hash_size; i++) {
// init heap
for (size_t i = 0; i < all_heap_size; i++) {
value[i] = -1.0 / 0.0;
labels[i] = -1;
}
......@@ -175,27 +175,25 @@ static void knn_inner_product_sse (const float * x,
const float *x_i = x + i * d;
float ip = fvec_inner_product (x_i, y_j, d);
float * val_ = value + thread_no * thread_hash_size + i * k;
int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
float * val_ = value + thread_no * thread_heap_size + i * k;
int64_t * ids_ = labels + thread_no * thread_heap_size + i * k;
if (ip > val_[0]) {
minheap_pop (k, val_, ids_);
minheap_push (k, val_, ids_, ip, j);
minheap_swap_top (k, val_, ids_, ip, j);
}
}
}
}
for (size_t t = 1; t < thread_max_num; t++) {
// merge hash
// merge heap
for (size_t i = 0; i < nx; i++) {
float * __restrict value_x = value + i * k;
int64_t * __restrict labels_x = labels + i * k;
float *value_x_t = value_x + t * thread_hash_size;
int64_t *labels_x_t = labels_x + t * thread_hash_size;
float *value_x_t = value_x + t * thread_heap_size;
int64_t *labels_x_t = labels_x + t * thread_heap_size;
for (size_t j = 0; j < k; j++) {
if (value_x_t[j] > value_x[0]) {
minheap_pop (k, value_x, labels_x);
minheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
minheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
}
}
}
......@@ -208,8 +206,8 @@ static void knn_inner_product_sse (const float * x,
}
// copy result
memcpy(res->val, value, thread_hash_size * sizeof(float));
memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t));
memcpy(res->val, value, thread_heap_size * sizeof(float));
memcpy(res->ids, labels, thread_heap_size * sizeof(int64_t));
delete[] value;
delete[] labels;
......@@ -262,13 +260,13 @@ static void knn_L2sqr_sse (
size_t thread_max_num = omp_get_max_threads();
size_t thread_hash_size = nx * k;
size_t all_hash_size = thread_hash_size * thread_max_num;
float *value = new float[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];
size_t thread_heap_size = nx * k;
size_t all_heap_size = thread_heap_size * thread_max_num;
float *value = new float[all_heap_size];
int64_t *labels = new int64_t[all_heap_size];
// init hash
for (size_t i = 0; i < all_hash_size; i++) {
// init heap
for (size_t i = 0; i < all_heap_size; i++) {
value[i] = 1.0 / 0.0;
labels[i] = -1;
}
......@@ -282,27 +280,25 @@ static void knn_L2sqr_sse (
const float *x_i = x + i * d;
float disij = fvec_L2sqr (x_i, y_j, d);
float * val_ = value + thread_no * thread_hash_size + i * k;
int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
float * val_ = value + thread_no * thread_heap_size + i * k;
int64_t * ids_ = labels + thread_no * thread_heap_size + i * k;
if (disij < val_[0]) {
maxheap_pop (k, val_, ids_);
maxheap_push (k, val_, ids_, disij, j);
maxheap_swap_top (k, val_, ids_, disij, j);
}
}
}
}
for (size_t t = 1; t < thread_max_num; t++) {
// merge hash
// merge heap
for (size_t i = 0; i < nx; i++) {
float * __restrict value_x = value + i * k;
int64_t * __restrict labels_x = labels + i * k;
float *value_x_t = value_x + t * thread_hash_size;
int64_t *labels_x_t = labels_x + t * thread_hash_size;
float *value_x_t = value_x + t * thread_heap_size;
int64_t *labels_x_t = labels_x + t * thread_heap_size;
for (size_t j = 0; j < k; j++) {
if (value_x_t[j] < value_x[0]) {
maxheap_pop (k, value_x, labels_x);
maxheap_push (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
maxheap_swap_top (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
}
}
}
......@@ -315,8 +311,8 @@ static void knn_L2sqr_sse (
}
// copy result
memcpy(res->val, value, thread_hash_size * sizeof(float));
memcpy(res->ids, labels, thread_hash_size * sizeof(int64_t));
memcpy(res->val, value, thread_heap_size * sizeof(float));
memcpy(res->ids, labels, thread_heap_size * sizeof(int64_t));
delete[] value;
delete[] labels;
......@@ -408,8 +404,7 @@ static void knn_inner_product_blas (
float dis = *ip_line;
if(dis > simi[0]){
minheap_pop(k, simi, idxi);
minheap_push(k, simi, idxi, dis, j);
minheap_swap_top(k, simi, idxi, dis, j);
}
}
ip_line++;
......@@ -486,8 +481,7 @@ static void knn_L2sqr_blas (const float * x,
dis = corr (dis, i, j);
if (dis < simi[0]) {
maxheap_pop (k, simi, idxi);
maxheap_push (k, simi, idxi, dis, j);
maxheap_swap_top (k, simi, idxi, dis, j);
}
}
ip_line++;
......@@ -563,8 +557,7 @@ static void knn_jaccard_blas (const float * x,
dis = corr (dis, i, j);
if (dis < simi[0]) {
maxheap_pop (k, simi, idxi);
maxheap_push (k, simi, idxi, dis, j);
maxheap_swap_top (k, simi, idxi, dis, j);
}
}
ip_line++;
......@@ -638,20 +631,6 @@ void knn_jaccard (const float * x,
}
}
/** Jaccard k-nearest-neighbour search entry point.
 *
 * All arguments are forwarded unchanged to knn_jaccard_blas together with a
 * NopDistanceCorrection.
 *
 * There is no SSE kernel for this metric. The previous code selected the
 * (missing) SSE path whenever d % 4 == 0 and nx was below
 * distance_compute_blas_threshold, printed a message and returned without
 * computing anything. The BLAS routine was already used for every other
 * input, so it is now used unconditionally.
 *
 * @param x    first set of vectors (presumably nx queries of dimension d,
 *             as in the sibling knn_* routines -- confirm against callers)
 * @param y    second set of vectors (presumably ny database vectors)
 * @param d    vector dimension
 * @param nx   number of vectors in x
 * @param ny   number of vectors in y
 * @param res  result heap array, written by knn_jaccard_blas
 */
void knn_jaccard (const float * x,
                  const float * y,
                  size_t d, size_t nx, size_t ny,
                  float_maxheap_array_t * res)
{
    // TODO: dispatch to knn_jaccard_sse for small nx once it is implemented.
    NopDistanceCorrection nop;
    knn_jaccard_blas (x, y, d, nx, ny, res, nop);
}
struct BaseShiftDistanceCorrection {
const float *base_shift;
float operator()(float dis, size_t /*qno*/, size_t bno) const {
......@@ -773,8 +752,7 @@ void knn_inner_products_by_idx (const float * x,
float ip = fvec_inner_product (x_, y + d * idsi[j], d);
if (ip > simi[0]) {
minheap_pop (k, simi, idxi);
minheap_push (k, simi, idxi, ip, idsi[j]);
minheap_swap_top (k, simi, idxi, ip, idsi[j]);
}
}
minheap_reorder (k, simi, idxi);
......@@ -801,8 +779,7 @@ void knn_L2sqr_by_idx (const float * x,
float disij = fvec_L2sqr (x_, y + d * idsi[j], d);
if (disij < simi[0]) {
maxheap_pop (k, simi, idxi);
maxheap_push (k, simi, idxi, disij, idsi[j]);
maxheap_swap_top (k, simi, idxi, disij, idsi[j]);
}
}
maxheap_reorder (res->k, simi, idxi);
......
......@@ -281,12 +281,12 @@ void hammings_knn_hc (
if ((bytes_per_code + k * (sizeof(hamdis_t) + sizeof(int64_t))) * ha->nh < size_1M) {
int thread_max_num = omp_get_max_threads();
// init hash
size_t thread_hash_size = ha->nh * k;
size_t all_hash_size = thread_hash_size * thread_max_num;
hamdis_t *value = new hamdis_t[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];
for (int i = 0; i < all_hash_size; i++) {
// init heap
size_t thread_heap_size = ha->nh * k;
size_t all_heap_size = thread_heap_size * thread_max_num;
hamdis_t *value = new hamdis_t[all_heap_size];
int64_t *labels = new int64_t[all_heap_size];
for (int i = 0; i < all_heap_size; i++) {
value[i] = 0x7fffffff;
labels[i] = -1;
}
......@@ -305,35 +305,33 @@ void hammings_knn_hc (
for (size_t i = 0; i < ha->nh; i++) {
hamdis_t dis = hc[i].hamming (bs2_);
hamdis_t * val_ = value + thread_no * thread_hash_size + i * k;
int64_t * ids_ = labels + thread_no * thread_hash_size + i * k;
hamdis_t * val_ = value + thread_no * thread_heap_size + i * k;
int64_t * ids_ = labels + thread_no * thread_heap_size + i * k;
if (dis < val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, val_, ids_);
faiss::maxheap_push<hamdis_t> (k, val_, ids_, dis, j);
faiss::maxheap_swap_top<hamdis_t> (k, val_, ids_, dis, j);
}
}
}
}
for (size_t t = 1; t < thread_max_num; t++) {
// merge hash
// merge heap
for (size_t i = 0; i < ha->nh; i++) {
hamdis_t * __restrict value_x = value + i * k;
int64_t * __restrict labels_x = labels + i * k;
hamdis_t *value_x_t = value_x + t * thread_hash_size;
int64_t *labels_x_t = labels_x + t * thread_hash_size;
hamdis_t *value_x_t = value_x + t * thread_heap_size;
int64_t *labels_x_t = labels_x + t * thread_heap_size;
for (size_t j = 0; j < k; j++) {
if (value_x_t[j] < value_x[0]) {
faiss::maxheap_pop<hamdis_t> (k, value_x, labels_x);
faiss::maxheap_push<hamdis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
faiss::maxheap_swap_top<hamdis_t> (k, value_x, labels_x, value_x_t[j], labels_x_t[j]);
}
}
}
}
// copy result
memcpy(ha->val, value, thread_hash_size * sizeof(hamdis_t));
memcpy(ha->ids, labels, thread_hash_size * sizeof(int64_t));
memcpy(ha->val, value, thread_heap_size * sizeof(hamdis_t));
memcpy(ha->ids, labels, thread_heap_size * sizeof(int64_t));
delete[] hc;
delete[] value;
......@@ -357,8 +355,7 @@ void hammings_knn_hc (
if(!bitset || !bitset->test(j)){
dis = hc.hamming (bs2_);
if (dis < bh_val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
faiss::maxheap_swap_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
}
}
}
......@@ -452,12 +449,12 @@ void hammings_knn_hc_1 (
int thread_max_num = omp_get_max_threads();
if (ha->nh == 1) {
// omp for n2
int all_hash_size = thread_max_num * k;
hamdis_t *value = new hamdis_t[all_hash_size];
int64_t *labels = new int64_t[all_hash_size];
int all_heap_size = thread_max_num * k;
hamdis_t *value = new hamdis_t[all_heap_size];
int64_t *labels = new int64_t[all_heap_size];
// init hash
for (int i = 0; i < all_hash_size; i++) {
// init heap
for (int i = 0; i < all_heap_size; i++) {
value[i] = 0x7fffffff;
}
const uint64_t bs1_ = bs1[0];
......@@ -470,18 +467,16 @@ void hammings_knn_hc_1 (
hamdis_t * __restrict val_ = value + thread_no * k;
int64_t * __restrict ids_ = labels + thread_no * k;
if (dis < val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, val_, ids_);
faiss::maxheap_push<hamdis_t> (k, val_, ids_, dis, j);
faiss::maxheap_swap_top<hamdis_t> (k, val_, ids_, dis, j);
}
}
}
// merge hash
// merge heap
hamdis_t * __restrict bh_val_ = ha->val;
int64_t * __restrict bh_ids_ = ha->ids;
for (int i = 0; i < all_hash_size; i++) {
for (int i = 0; i < all_heap_size; i++) {
if (value[i] < bh_val_[0]) {
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
faiss::maxheap_swap_top<hamdis_t> (k, bh_val_, bh_ids_, value[i], labels[i]);
}
}
......@@ -502,8 +497,7 @@ void hammings_knn_hc_1 (
if(!bitset || !bitset->test(j)){
dis = popcount64 (bs1_ ^ *bs2_);
if (dis < bh_val_0) {
faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
faiss::maxheap_swap_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
bh_val_0 = bh_val_[0];
}
}
......@@ -849,8 +843,7 @@ static void hamming_dis_inner_loop (
int ndiff = hc.hamming (cb);
cb += code_size;
if (ndiff < bh_val_[0]) {
maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
maxheap_swap_top<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册