未验证 提交 1676f7cf 编写于 作者: S shengjun.li 提交者: GitHub

fix GPU search (#2455)

Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>
上级 8878951b
......@@ -9,6 +9,7 @@ Please mark all change in change log and use the issue from GitHub
- \#2395 Fix large nq cudaMalloc error
- \#2399 The nlist set by the user may not take effect
- \#2403 MySQL max_idle_time is 10 by default
- \#2450 The deleted vectors may be found on GPU
## Feature
......
......@@ -75,10 +75,9 @@ pass1SelectLists(void** listIndices,
topQueryToCentroid,
opt);
if (bitsetEmpty || (!(bitset[index >> 3] & (0x1 << (index & 0x7))))) {
heap.add(distanceStart[i], start + i);
} else {
heap.add((1.0 / 0.0), start + i);
heap.addThreadQ(distanceStart[i], start + i);
}
heap.checkThreadQ();
}
// Handle warp divergence separately
......@@ -91,8 +90,6 @@ pass1SelectLists(void** listIndices,
opt);
if (bitsetEmpty || (!(bitset[index >> 3] & (0x1 << (index & 0x7))))) {
heap.addThreadQ(distanceStart[i], start + i);
} else {
heap.addThreadQ((1.0 / 0.0), start + i);
}
}
......
......@@ -156,22 +156,17 @@ __global__ void l2SelectMinK(Tensor<T, 2, true> productDistances,
if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
v = Math<T>::add(centroidDistances[i],
productDistances[row][i]);
} else {
v = (T)(1.0 / 0.0);
heap.addThreadQ(v, i);
}
heap.add(v, i);
heap.checkThreadQ();
}
if (i < productDistances.getSize(1)) {
if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
v = Math<T>::add(centroidDistances[i],
productDistances[row][i]);
} else {
v = (T)(1.0 / 0.0);
heap.addThreadQ(v, i);
}
heap.addThreadQ(v, i);
}
heap.reduce();
......
......@@ -146,11 +146,10 @@ __global__ void blockSelect(Tensor<K, 2, true> in,
for (; i < limit; i += ThreadsPerBlock) {
if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
heap.add(*inStart, (IndexType) i);
} else {
heap.add(-1.0, (IndexType) i);
heap.addThreadQ(*inStart, (IndexType) i);
}
heap.checkThreadQ();
inStart += ThreadsPerBlock;
}
......@@ -158,8 +157,6 @@ __global__ void blockSelect(Tensor<K, 2, true> in,
if (i < in.getSize(1)) {
if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
heap.addThreadQ(*inStart, (IndexType) i);
} else {
heap.addThreadQ(-1.0, (IndexType) i);
}
}
......@@ -208,10 +205,9 @@ __global__ void blockSelectPair(Tensor<K, 2, true> inK,
for (; i < limit; i += ThreadsPerBlock) {
if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
heap.add(*inKStart, *inVStart);
} else {
heap.add(-1.0, *inVStart);
heap.addThreadQ(*inKStart, *inVStart);
}
heap.checkThreadQ();
inKStart += ThreadsPerBlock;
inVStart += ThreadsPerBlock;
......@@ -221,8 +217,6 @@ __global__ void blockSelectPair(Tensor<K, 2, true> inK,
if (i < inK.getSize(1)) {
if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) {
heap.addThreadQ(*inKStart, *inVStart);
} else {
heap.addThreadQ(-1.0, *inVStart);
}
}
......
......@@ -283,15 +283,6 @@ XSearchTask::Execute() {
{
std::unique_lock<std::mutex> lock(search_job->mutex());
if (search_job->GetResultIds().size() > spec_k) {
if (search_job->GetResultIds().front() == -1) {
// initialized results set
search_job->GetResultIds().resize(spec_k * nq);
search_job->GetResultDistances().resize(spec_k * nq);
}
}
search_job->vector_count() = nq;
XSearchTask::MergeTopkToResultSet(output_ids, output_distance, spec_k, nq, topk, ascending_reduce,
search_job->GetResultIds(), search_job->GetResultDistances());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册