未验证 提交 887b7f7f 编写于 作者: S shengjun.li 提交者: GitHub

fix too many copies (#2661)

Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>
上级 ecc7d839
......@@ -40,14 +40,11 @@ GPUIVF::Train(const DatasetPtr& dataset_ptr, const Config& config) {
idx_config.device = gpu_id_;
int32_t nlist = config[IndexParams::nlist];
faiss::MetricType metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
faiss::gpu::GpuIndexIVFFlat device_index(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
device_index.train(rows, (float*)p_data);
auto device_index =
new faiss::gpu::GpuIndexIVFFlat(gpu_res->faiss_res.get(), dim, nlist, metric_type, idx_config);
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(&device_index));
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
index_.reset(device_index1);
index_.reset(device_index);
res_ = gpu_res;
} else {
KNOWHERE_THROW_MSG("Build IVF can't get gpu resource");
......
......@@ -38,11 +38,8 @@ GPUIVFPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
config[IndexParams::m], config[IndexParams::nbits],
GetMetricType(config[Metric::TYPE].get<std::string>())); // IP not support
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
auto device_index1 = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
index_.reset(device_index1);
index_.reset(device_index);
res_ = gpu_res;
} else {
KNOWHERE_THROW_MSG("Build IVFPQ can't get gpu resource");
......
......@@ -45,19 +45,14 @@ IVFSQHybrid::Train(const DatasetPtr& dataset_ptr, const Config& config) {
auto device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, build_index);
device_index->train(rows, (float*)p_data);
std::shared_ptr<faiss::Index> host_index = nullptr;
host_index.reset(faiss::gpu::index_gpu_to_cpu(device_index));
delete device_index;
delete build_index;
device_index = faiss::gpu::index_cpu_to_gpu(gpu_res->faiss_res.get(), gpu_id_, host_index.get());
index_.reset(device_index);
res_ = gpu_res;
gpu_mode_ = 2;
} else {
KNOWHERE_THROW_MSG("Build IVFSQHybrid can't get gpu resource");
}
delete build_index;
}
VecIndexPtr
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册