// IndexRHNSWPQ.cpp (3.3 KB) — captured from a repository blame view; committed by op-hunter.
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#include "knowhere/index/vector_index/IndexRHNSWPQ.h"

#include <algorithm>
#include <cassert>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>

#include "faiss/BuilderSuspend.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"
#include "knowhere/index/vector_index/adapter/VectorAdapter.h"
#include "knowhere/index/vector_index/helpers/FaissIO.h"

namespace milvus {
namespace knowhere {

/**
 * @brief Create the wrapper with a freshly constructed faiss::IndexRHNSWPQ.
 *
 * @param d    vector dimension, forwarded to faiss::IndexRHNSWPQ.
 * @param pq_m forwarded to faiss::IndexRHNSWPQ (Train() sources it from IndexParams::PQM).
 * @param M    forwarded to faiss::IndexRHNSWPQ (Train() sources it from IndexParams::M).
 */
IndexRHNSWPQ::IndexRHNSWPQ(int d, int pq_m, int M) {
    auto* const faiss_index = new faiss::IndexRHNSWPQ(d, pq_m, M);
    // index_ becomes the sole owner; it deletes the object when the last reference drops.
    index_.reset(faiss_index);
}

/**
 * @brief Serialize the RHNSW graph together with its PQ storage index.
 *
 * The base part comes from IndexRHNSW::Serialize. The storage index
 * (real_idx->storage) is then written with faiss::write_index into an
 * in-memory buffer. That buffer is appended to the result under the
 * QUANTIZATION_DATA key.
 *
 * @param config forwarded unchanged to IndexRHNSW::Serialize.
 * @return the base BinarySet plus a QUANTIZATION_DATA entry.
 * @throws (via KNOWHERE_THROW_MSG) if the index is missing, is not a
 *         faiss::IndexRHNSWPQ, or any serialization step throws.
 */
BinarySet
IndexRHNSWPQ::Serialize(const Config& config) {
    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize or trained");
    }

    try {
        auto res_set = IndexRHNSW::Serialize(config);

        // Validate the concrete type before doing any further work.
        auto real_idx = dynamic_cast<faiss::IndexRHNSWPQ*>(index_.get());
        if (real_idx == nullptr) {
            KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWPQ*>(index_) failed during Serialize!");
        }

        MemoryIOWriter writer;
        writer.name = QUANTIZATION_DATA;
        faiss::write_index(real_idx->storage, &writer);

        // Take ownership of the writer's buffer; shared_ptr<uint8_t[]> frees it with delete[]
        // (assumes MemoryIOWriter allocates data_ with new[] — confirm in FaissIO.h).
        std::shared_ptr<uint8_t[]> data(writer.data_);

        // writer.rp is the number of bytes written into data_.
        res_set.Append(writer.name, data, writer.rp);
        return res_set;
    } catch (const std::exception& e) {
        KNOWHERE_THROW_MSG(e.what());
    }
}

void
IndexRHNSWPQ::Load(const BinarySet& index_binary) {
    try {
        IndexRHNSW::Load(index_binary);
        MemoryIOReader reader;
62
        reader.name = QUANTIZATION_DATA;
O
op-hunter 已提交
63 64
        auto binary = index_binary.GetByName(reader.name);

C
cqy123456 已提交
65
        reader.total = static_cast<size_t>(binary->size);
O
op-hunter 已提交
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
        reader.data_ = binary->data.get();

        auto real_idx = dynamic_cast<faiss::IndexRHNSWPQ*>(index_.get());
        if (real_idx == nullptr) {
            KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWPQ*>(index_) failed during Load!");
        }
        real_idx->storage = faiss::read_index(&reader);
        real_idx->init_hnsw();
    } catch (std::exception& e) {
        KNOWHERE_THROW_MSG(e.what());
    }
}

void
IndexRHNSWPQ::Train(const DatasetPtr& dataset_ptr, const Config& config) {
    try {
        GET_TENSOR_DATA_DIM(dataset_ptr)

        auto idx = new faiss::IndexRHNSWPQ(int(dim), config[IndexParams::PQM], config[IndexParams::M]);
        idx->hnsw.efConstruction = config[IndexParams::efConstruction];
        index_ = std::shared_ptr<faiss::Index>(idx);
C
cqy123456 已提交
87
        index_->train(rows, reinterpret_cast<const float*>(p_data));
O
op-hunter 已提交
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
    } catch (std::exception& e) {
        KNOWHERE_THROW_MSG(e.what());
    }
}

/**
 * @brief Refresh index_size_ from the underlying faiss::IndexRHNSWPQ.
 *
 * @throws (via KNOWHERE_THROW_MSG) if index_ is unset or is not a
 *         faiss::IndexRHNSWPQ.
 */
void
IndexRHNSWPQ::UpdateIndexSize() {
    if (!index_) {
        KNOWHERE_THROW_MSG("index not initialize");
    }
    // Check the cast like Serialize/Load do; dereferencing a failed dynamic_cast is UB.
    auto real_idx = dynamic_cast<faiss::IndexRHNSWPQ*>(index_.get());
    if (real_idx == nullptr) {
        KNOWHERE_THROW_MSG("dynamic_cast<faiss::IndexRHNSWPQ*>(index_) failed during UpdateIndexSize!");
    }
    index_size_ = real_idx->cal_size();
}

}  // namespace knowhere
}  // namespace milvus