mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-02-02 01:06:41 +08:00
Dirty blacklist is left in the cache (#5272)
Concurrent access from multiple threads can leave a dirty blacklist in the cache. To fix this, knowhere::index no longer holds the blacklist; instead, each query regenerates it from DeleteDoc. Later, we will add the blacklist to the cache to improve performance. Resolves: #4897 Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
This commit is contained in:
parent
dce238e939
commit
53cacf60b2
@ -4,11 +4,15 @@ Please mark all change in change log and use the issue from GitHub
|
||||
|
||||
# Milvus 1.1.1 (TBD)
|
||||
## Bug
|
||||
- \#4897 Query results contain some deleted ids
|
||||
|
||||
## Feature
|
||||
|
||||
## Improvement
|
||||
- \#5161 Enable Gpu cache
|
||||
- \#5161 Enable Gpu cache
|
||||
- \#5204 Improve IVF query on GPU when no entity deleted
|
||||
|
||||
## Task
|
||||
|
||||
# Milvus 1.1.0 (2021-05-07)
|
||||
## Bug
|
||||
|
||||
@ -574,7 +574,7 @@ DBImpl::ReLoadSegmentsDeletedDocs(const std::string& collection_id, const std::v
|
||||
if (!initialized_.load(std::memory_order_acquire)) {
|
||||
return SHUTDOWN_ERROR;
|
||||
}
|
||||
|
||||
#if 0 // todo
|
||||
meta::FilesHolder files_holder;
|
||||
std::vector<size_t> file_ids;
|
||||
for (auto& id : segment_ids) {
|
||||
@ -624,7 +624,7 @@ DBImpl::ReLoadSegmentsDeletedDocs(const std::string& collection_id, const std::v
|
||||
blacklist->set(i);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1546,8 +1546,8 @@ DBImpl::GetVectorsByIdHelper(const IDNumbers& id_array, std::vector<engine::Vect
|
||||
if (deleted == deleted_docs.end()) {
|
||||
// Load raw vector
|
||||
std::vector<uint8_t> raw_vector;
|
||||
status =
|
||||
segment_reader.LoadVectors(offset * single_vector_bytes, single_vector_bytes, raw_vector);
|
||||
status = segment_reader.LoadsSingleVector(offset * single_vector_bytes, single_vector_bytes,
|
||||
raw_vector);
|
||||
if (!status.ok()) {
|
||||
LOG_ENGINE_ERROR_ << status.message();
|
||||
return status;
|
||||
|
||||
@ -382,12 +382,17 @@ ExecutionEngineImpl::Serialize() {
|
||||
|
||||
Status
|
||||
ExecutionEngineImpl::Load(bool to_cache) {
|
||||
std::string segment_dir;
|
||||
segment::SegmentReaderPtr segment_reader_ptr = nullptr;
|
||||
auto get_segment_reader = [&]() {
|
||||
utils::GetParentPath(location_, segment_dir);
|
||||
segment_reader_ptr = std::make_shared<segment::SegmentReader>(segment_dir);
|
||||
};
|
||||
|
||||
index_ = std::static_pointer_cast<knowhere::VecIndex>(cache::CpuCacheMgr::GetInstance()->GetItem(location_));
|
||||
if (!index_) {
|
||||
// not in the cache
|
||||
std::string segment_dir;
|
||||
utils::GetParentPath(location_, segment_dir);
|
||||
auto segment_reader_ptr = std::make_shared<segment::SegmentReader>(segment_dir);
|
||||
get_segment_reader();
|
||||
knowhere::VecIndexFactory& vec_index_factory = knowhere::VecIndexFactory::GetInstance();
|
||||
|
||||
if (utils::IsRawIndexType((int32_t)index_type_)) {
|
||||
@ -405,18 +410,14 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
||||
throw Exception(DB_ERROR, "Illegal index params");
|
||||
}
|
||||
|
||||
auto status = segment_reader_ptr->Load();
|
||||
segment::VectorsPtr vectors = nullptr;
|
||||
auto status = segment_reader_ptr->LoadsVectors(vectors);
|
||||
if (!status.ok()) {
|
||||
std::string msg = "Failed to load segment from " + location_;
|
||||
std::string msg = "Failed to load vectors from " + location_;
|
||||
LOG_ENGINE_ERROR_ << msg;
|
||||
return Status(DB_ERROR, msg);
|
||||
}
|
||||
|
||||
segment::SegmentPtr segment_ptr;
|
||||
segment_reader_ptr->GetSegment(segment_ptr);
|
||||
auto& vectors = segment_ptr->vectors_ptr_;
|
||||
auto& deleted_docs = segment_ptr->deleted_docs_ptr_->GetDeletedDocs();
|
||||
|
||||
auto& vectors_uids = vectors->GetMutableUids();
|
||||
std::shared_ptr<std::vector<int64_t>> vector_uids_ptr = std::make_shared<std::vector<int64_t>>();
|
||||
vector_uids_ptr->swap(vectors_uids);
|
||||
@ -424,28 +425,16 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
||||
LOG_ENGINE_DEBUG_ << "set uids " << vector_uids_ptr->size() << " for index " << location_;
|
||||
|
||||
auto& vectors_data = vectors->GetData();
|
||||
|
||||
auto count = vector_uids_ptr->size();
|
||||
|
||||
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = nullptr;
|
||||
if (!deleted_docs.empty()) {
|
||||
concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(count);
|
||||
for (auto& offset : deleted_docs) {
|
||||
concurrent_bitset_ptr->set(offset);
|
||||
}
|
||||
}
|
||||
|
||||
auto dataset = knowhere::GenDataset(count, this->dim_, vectors_data.data());
|
||||
if (index_type_ == EngineType::FAISS_IDMAP) {
|
||||
auto bf_index = std::static_pointer_cast<knowhere::IDMAP>(index_);
|
||||
bf_index->Train(knowhere::DatasetPtr(), conf);
|
||||
bf_index->AddWithoutIds(dataset, conf);
|
||||
bf_index->SetBlacklist(concurrent_bitset_ptr);
|
||||
} else if (index_type_ == EngineType::FAISS_BIN_IDMAP) {
|
||||
auto bin_bf_index = std::static_pointer_cast<knowhere::BinaryIDMAP>(index_);
|
||||
bin_bf_index->Train(knowhere::DatasetPtr(), conf);
|
||||
bin_bf_index->AddWithoutIds(dataset, conf);
|
||||
bin_bf_index->SetBlacklist(concurrent_bitset_ptr);
|
||||
}
|
||||
|
||||
LOG_ENGINE_DEBUG_ << "Finished loading raw data from segment " << segment_dir;
|
||||
@ -455,38 +444,18 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
||||
segment_reader_ptr->GetSegment(segment_ptr);
|
||||
auto status = segment_reader_ptr->LoadVectorIndex(location_, segment_ptr->vector_index_ptr_);
|
||||
index_ = segment_ptr->vector_index_ptr_->GetVectorIndex();
|
||||
|
||||
if (index_ == nullptr) {
|
||||
std::string msg = "Failed to load index from " + location_;
|
||||
LOG_ENGINE_ERROR_ << msg;
|
||||
return Status(DB_ERROR, msg);
|
||||
} else {
|
||||
segment::DeletedDocsPtr deleted_docs_ptr;
|
||||
auto status = segment_reader_ptr->LoadDeletedDocs(deleted_docs_ptr);
|
||||
if (!status.ok()) {
|
||||
std::string msg = "Failed to load deleted docs from " + location_;
|
||||
LOG_ENGINE_ERROR_ << msg;
|
||||
return Status(DB_ERROR, msg);
|
||||
}
|
||||
auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs();
|
||||
|
||||
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = nullptr;
|
||||
if (!deleted_docs.empty()) {
|
||||
concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(index_->Count());
|
||||
for (auto& offset : deleted_docs) {
|
||||
if (!concurrent_bitset_ptr->test(offset)) {
|
||||
concurrent_bitset_ptr->set(offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
index_->SetBlacklist(concurrent_bitset_ptr);
|
||||
segment::UidsPtr uids_ptr = nullptr;
|
||||
segment_reader_ptr->LoadUids(uids_ptr);
|
||||
index_->SetUids(uids_ptr);
|
||||
LOG_ENGINE_DEBUG_ << "set uids " << index_->GetUids()->size() << " for index " << location_;
|
||||
|
||||
LOG_ENGINE_DEBUG_ << "Finished loading index file from segment " << segment_dir;
|
||||
}
|
||||
|
||||
segment::UidsPtr uids_ptr = nullptr;
|
||||
segment_reader_ptr->LoadUids(uids_ptr);
|
||||
index_->SetUids(uids_ptr);
|
||||
LOG_ENGINE_DEBUG_ << "set uids " << index_->GetUids()->size() << " for index " << location_;
|
||||
|
||||
LOG_ENGINE_DEBUG_ << "Finished loading index file from segment " << segment_dir;
|
||||
} catch (std::exception& e) {
|
||||
LOG_ENGINE_ERROR_ << e.what();
|
||||
return Status(DB_ERROR, e.what());
|
||||
@ -498,8 +467,32 @@ ExecutionEngineImpl::Load(bool to_cache) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!blacklist_) {
|
||||
if (!segment_reader_ptr) {
|
||||
get_segment_reader();
|
||||
}
|
||||
|
||||
segment::DeletedDocsPtr deleted_docs_ptr;
|
||||
auto status = segment_reader_ptr->LoadDeletedDocs(deleted_docs_ptr);
|
||||
if (!status.ok()) {
|
||||
std::string msg = "Failed to load deleted docs from " + location_;
|
||||
LOG_ENGINE_ERROR_ << msg;
|
||||
return Status(DB_ERROR, msg);
|
||||
}
|
||||
auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs();
|
||||
|
||||
blacklist_ = std::make_shared<knowhere::Blacklist>();
|
||||
if (!deleted_docs.empty()) {
|
||||
auto concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(index_->Count());
|
||||
for (auto& offset : deleted_docs) {
|
||||
concurrent_bitset_ptr->set(offset);
|
||||
}
|
||||
blacklist_->bitset_ = concurrent_bitset_ptr;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
} // namespace engine
|
||||
}
|
||||
|
||||
Status
|
||||
ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
|
||||
@ -657,12 +650,10 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
|
||||
auto dataset = knowhere::GenDataset(Count(), Dimension(), from_index->GetRawVectors());
|
||||
to_index->BuildAll(dataset, conf);
|
||||
uids = from_index->GetUids();
|
||||
blacklist = from_index->GetBlacklist();
|
||||
} else if (bin_from_index) {
|
||||
auto dataset = knowhere::GenDataset(Count(), Dimension(), bin_from_index->GetRawVectors());
|
||||
to_index->BuildAll(dataset, conf);
|
||||
uids = bin_from_index->GetUids();
|
||||
blacklist = bin_from_index->GetBlacklist();
|
||||
}
|
||||
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
@ -675,10 +666,7 @@ ExecutionEngineImpl::BuildIndex(const std::string& location, EngineType engine_t
|
||||
|
||||
to_index->SetUids(uids);
|
||||
LOG_ENGINE_DEBUG_ << "Set " << to_index->UidsSize() << "uids for " << location;
|
||||
if (blacklist != nullptr) {
|
||||
to_index->SetBlacklist(blacklist);
|
||||
LOG_ENGINE_DEBUG_ << "Set blacklist for index " << location;
|
||||
}
|
||||
|
||||
LOG_ENGINE_DEBUG_ << "Finish build index: " << location;
|
||||
return std::make_shared<ExecutionEngineImpl>(to_index, location, engine_type, metric_type_, index_params_);
|
||||
}
|
||||
@ -721,7 +709,7 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, const milvu
|
||||
|
||||
rc.RecordSection("query prepare");
|
||||
auto dataset = knowhere::GenDataset(n, index_->Dim(), data);
|
||||
auto result = index_->Query(dataset, conf);
|
||||
auto result = index_->Query(dataset, conf, (blacklist_ ? blacklist_->bitset_ : nullptr));
|
||||
rc.RecordSection("query done");
|
||||
|
||||
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld] get %ld uids from index %s", "search", 0, index_->GetUids()->size(),
|
||||
@ -762,7 +750,7 @@ ExecutionEngineImpl::Search(int64_t n, const uint8_t* data, int64_t k, const mil
|
||||
|
||||
rc.RecordSection("query prepare");
|
||||
auto dataset = knowhere::GenDataset(n, index_->Dim(), data);
|
||||
auto result = index_->Query(dataset, conf);
|
||||
auto result = index_->Query(dataset, conf, (blacklist_ ? blacklist_->bitset_ : nullptr));
|
||||
rc.RecordSection("query done");
|
||||
|
||||
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld] get %ld uids from index %s", "search", 0, index_->GetUids()->size(),
|
||||
|
||||
@ -125,6 +125,7 @@ class ExecutionEngineImpl : public ExecutionEngine {
|
||||
HybridUnset() const;
|
||||
|
||||
protected:
|
||||
knowhere::BlacklistPtr blacklist_ = nullptr;
|
||||
knowhere::VecIndexPtr index_ = nullptr;
|
||||
#ifdef MILVUS_GPU_VERSION
|
||||
knowhere::VecIndexPtr index_reserve_ = nullptr; // reserve the cpu index before copying it to gpu
|
||||
|
||||
@ -272,30 +272,7 @@ MemTable::ApplyDeletes() {
|
||||
|
||||
segment::UidsPtr uids_ptr = nullptr;
|
||||
|
||||
// Get all index that contains blacklist in cache
|
||||
std::vector<knowhere::VecIndexPtr> indexes;
|
||||
std::vector<faiss::ConcurrentBitsetPtr> blacklists;
|
||||
milvus::engine::meta::SegmentsSchema& segment_files = segment_holder.HoldFiles();
|
||||
for (auto& segment_file : segment_files) {
|
||||
auto data_obj_ptr = cache::CpuCacheMgr::GetInstance()->GetItem(segment_file.location_);
|
||||
auto index = std::static_pointer_cast<knowhere::VecIndex>(data_obj_ptr);
|
||||
if (index != nullptr) {
|
||||
faiss::ConcurrentBitsetPtr blacklist = index->GetBlacklist();
|
||||
if (blacklist == nullptr) {
|
||||
// to update and set the blacklist
|
||||
blacklist = std::make_shared<faiss::ConcurrentBitset>(index->Count());
|
||||
indexes.emplace_back(index);
|
||||
blacklists.emplace_back(blacklist);
|
||||
} else {
|
||||
// just to update the blacklist
|
||||
indexes.emplace_back(nullptr);
|
||||
blacklists.emplace_back(blacklist);
|
||||
}
|
||||
|
||||
// load uids from cache
|
||||
uids_ptr = index->GetUids();
|
||||
}
|
||||
}
|
||||
|
||||
std::string segment_dir;
|
||||
utils::GetParentPath(file.location_, segment_dir);
|
||||
@ -333,9 +310,6 @@ MemTable::ApplyDeletes() {
|
||||
deleted_docs->AddDeletedDoc(i);
|
||||
id_bloom_filter_ptr->Remove((*uids_ptr)[i]);
|
||||
|
||||
for (auto& blacklist : blacklists) {
|
||||
blacklist->set(i);
|
||||
}
|
||||
auto set_end = std::chrono::high_resolution_clock::now();
|
||||
set_diff += (set_end - set_start);
|
||||
}
|
||||
@ -352,12 +326,6 @@ MemTable::ApplyDeletes() {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < indexes.size(); ++i) {
|
||||
if (indexes[i]) {
|
||||
indexes[i]->SetBlacklist(blacklists[i]);
|
||||
}
|
||||
}
|
||||
|
||||
segment::Segment tmp_segment;
|
||||
segment::SegmentWriter segment_writer(segment_dir);
|
||||
status = segment_writer.WriteDeletedDocs(deleted_docs);
|
||||
|
||||
@ -31,15 +31,20 @@ class Index : public milvus::cache::DataObj {
|
||||
|
||||
using IndexPtr = std::shared_ptr<Index>;
|
||||
|
||||
// todo: remove from knowhere
|
||||
class ToIndexData : public milvus::cache::DataObj {
|
||||
class Blacklist : public milvus::cache::DataObj {
|
||||
public:
|
||||
explicit ToIndexData(int64_t size) : size_(size) {
|
||||
Blacklist() {
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t size_ = 0;
|
||||
int64_t
|
||||
Size() override {
|
||||
return (bitset_ != nullptr) ? 0 : bitset_->size();
|
||||
}
|
||||
|
||||
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
|
||||
};
|
||||
|
||||
using BlacklistPtr = std::shared_ptr<Blacklist>;
|
||||
|
||||
} // namespace knowhere
|
||||
} // namespace milvus
|
||||
|
||||
@ -175,7 +175,7 @@ IVFPQConfAdapter::CheckTrain(Config& oricfg, IndexMode& mode) {
|
||||
return true;
|
||||
}
|
||||
// else try CPU Mode
|
||||
mode == IndexMode::MODE_CPU;
|
||||
mode = IndexMode::MODE_CPU;
|
||||
}
|
||||
#endif
|
||||
return IsValidForCPU(dimension, m);
|
||||
|
||||
@ -108,7 +108,7 @@ IndexAnnoy::BuildAll(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
@ -119,7 +119,6 @@ IndexAnnoy::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto all_num = rows * k;
|
||||
auto p_id = (int64_t*)malloc(all_num * sizeof(int64_t));
|
||||
auto p_dist = (float*)malloc(all_num * sizeof(float));
|
||||
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
|
||||
|
||||
#pragma omp parallel for
|
||||
for (unsigned int i = 0; i < rows; ++i) {
|
||||
|
||||
@ -48,7 +48,7 @@ class IndexAnnoy : public VecIndex {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
@ -37,7 +37,7 @@ BinaryIDMAP::Load(const BinarySet& index_binary) {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
@ -50,7 +50,7 @@ BinaryIDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
|
||||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config, blacklist);
|
||||
MapOffsetToUid(p_id, static_cast<size_t>(elems));
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
@ -107,13 +107,13 @@ BinaryIDMAP::GetRawVectors() {
|
||||
|
||||
void
|
||||
BinaryIDMAP::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
const Config& config, faiss::ConcurrentBitsetPtr blacklist) {
|
||||
auto default_type = index_->metric_type;
|
||||
if (config.contains(Metric::TYPE))
|
||||
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
|
||||
int32_t* i_distances = reinterpret_cast<int32_t*>(distances);
|
||||
index_->search(n, (uint8_t*)data, k, i_distances, labels, GetBlacklist());
|
||||
index_->search(n, (uint8_t*)data, k, i_distances, labels, blacklist);
|
||||
|
||||
// if hamming, it need transform int32 to float
|
||||
if (index_->metric_type == faiss::METRIC_Hamming) {
|
||||
|
||||
@ -44,7 +44,7 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
@ -62,7 +62,8 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex {
|
||||
|
||||
protected:
|
||||
virtual void
|
||||
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
|
||||
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist);
|
||||
};
|
||||
|
||||
using BinaryIDMAPPtr = std::shared_ptr<BinaryIDMAP>;
|
||||
|
||||
@ -41,7 +41,7 @@ BinaryIVF::Load(const BinarySet& index_binary) {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
@ -57,7 +57,7 @@ BinaryIVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config);
|
||||
QueryImpl(rows, (uint8_t*)p_data, k, p_dist, p_id, config, blacklist);
|
||||
MapOffsetToUid(p_id, static_cast<size_t>(elems));
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
@ -163,15 +163,15 @@ BinaryIVF::GenParams(const Config& config) {
|
||||
}
|
||||
|
||||
void
|
||||
BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
BinaryIVF::QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexBinaryIVF*>(index_.get());
|
||||
ivf_index->nprobe = params->nprobe;
|
||||
|
||||
stdclock::time_point before = stdclock::now();
|
||||
int32_t* i_distances = reinterpret_cast<int32_t*>(distances);
|
||||
index_->search(n, (uint8_t*)data, k, i_distances, labels, GetBlacklist());
|
||||
index_->search(n, (uint8_t*)data, k, i_distances, labels, blacklist);
|
||||
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
|
||||
@ -47,7 +47,7 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
@ -63,7 +63,8 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex {
|
||||
GenParams(const Config& config);
|
||||
|
||||
virtual void
|
||||
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config);
|
||||
QueryImpl(int64_t n, const uint8_t* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist);
|
||||
};
|
||||
|
||||
using BinaryIVFIndexPtr = std::shared_ptr<BinaryIVF>;
|
||||
|
||||
@ -110,7 +110,7 @@ IndexHNSW::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
@ -124,7 +124,6 @@ IndexHNSW::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
|
||||
index_->setEf(config[IndexParams::ef]);
|
||||
|
||||
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
|
||||
bool transform = (index_->metric_type_ == 1); // InnerProduct: 1
|
||||
|
||||
#pragma omp parallel for
|
||||
|
||||
@ -40,7 +40,7 @@ class IndexHNSW : public VecIndex {
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
@ -68,7 +68,7 @@ IDMAP::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) {
|
||||
if (!index_) {
|
||||
KNOWHERE_THROW_MSG("index not initialize");
|
||||
}
|
||||
@ -81,7 +81,7 @@ IDMAP::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config);
|
||||
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config, blacklist);
|
||||
MapOffsetToUid(p_id, static_cast<size_t>(elems));
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
@ -135,11 +135,12 @@ IDMAP::GetRawVectors() {
|
||||
}
|
||||
|
||||
void
|
||||
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
IDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist) {
|
||||
auto default_type = index_->metric_type;
|
||||
if (config.contains(Metric::TYPE))
|
||||
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
index_->search(n, (float*)data, k, distances, labels, GetBlacklist());
|
||||
index_->search(n, (float*)data, k, distances, labels, blacklist);
|
||||
index_->metric_type = default_type;
|
||||
}
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
@ -64,7 +64,8 @@ class IDMAP : public VecIndex, public FaissBaseIndex {
|
||||
|
||||
protected:
|
||||
virtual void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
|
||||
QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist);
|
||||
};
|
||||
|
||||
using IDMAPPtr = std::shared_ptr<IDMAP>;
|
||||
|
||||
@ -85,7 +85,7 @@ IVF::AddWithoutIds(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
IVF::Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
@ -103,7 +103,7 @@ IVF::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config);
|
||||
QueryImpl(rows, (float*)p_data, k, p_dist, p_id, config, blacklist);
|
||||
MapOffsetToUid(p_id, static_cast<size_t>(elems));
|
||||
|
||||
auto ret_ds = std::make_shared<Dataset>();
|
||||
@ -272,7 +272,7 @@ IVF::GenGraph(const float* data, const int64_t k, GraphType& graph, const Config
|
||||
res.resize(K * b_size);
|
||||
|
||||
auto xq = data + batch_size * dim * i;
|
||||
QueryImpl(b_size, (float*)xq, K, res_dis.data(), res.data(), config);
|
||||
QueryImpl(b_size, (float*)xq, K, res_dis.data(), res.data(), config, nullptr);
|
||||
|
||||
for (int j = 0; j < b_size; ++j) {
|
||||
auto& node = graph[batch_size * i + j];
|
||||
@ -294,7 +294,8 @@ IVF::GenParams(const Config& config) {
|
||||
}
|
||||
|
||||
void
|
||||
IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist) {
|
||||
auto params = GenParams(config);
|
||||
auto ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
|
||||
ivf_index->nprobe = std::min(params->nprobe, ivf_index->invlists->nlist);
|
||||
@ -304,7 +305,7 @@ IVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_
|
||||
} else {
|
||||
ivf_index->parallel_mode = 0;
|
||||
}
|
||||
ivf_index->search(n, (float*)data, k, distances, labels, GetBlacklist());
|
||||
ivf_index->search(n, (float*)data, k, distances, labels, blacklist);
|
||||
stdclock::time_point after = stdclock::now();
|
||||
double search_cost = (std::chrono::duration<double, std::micro>(after - before)).count();
|
||||
LOG_KNOWHERE_DEBUG_ << "IVF search cost: " << search_cost
|
||||
|
||||
@ -47,12 +47,7 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
||||
AddWithoutIds(const DatasetPtr&, const Config&) override;
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
|
||||
#if 0
|
||||
DatasetPtr
|
||||
QueryById(const DatasetPtr& dataset, const Config& config) override;
|
||||
#endif
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
@ -82,7 +77,8 @@ class IVF : public VecIndex, public FaissBaseIndex {
|
||||
GenParams(const Config&);
|
||||
|
||||
virtual void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&);
|
||||
QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist);
|
||||
|
||||
void
|
||||
SealImpl() override;
|
||||
|
||||
@ -71,7 +71,7 @@ NSG::Load(const BinarySet& index_binary) {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
NSG::Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) {
|
||||
if (!index_ || !index_->is_trained) {
|
||||
KNOWHERE_THROW_MSG("index not initialize or trained");
|
||||
}
|
||||
@ -85,8 +85,6 @@ NSG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
auto p_id = (int64_t*)malloc(p_id_size);
|
||||
auto p_dist = (float*)malloc(p_dist_size);
|
||||
|
||||
faiss::ConcurrentBitsetPtr blacklist = GetBlacklist();
|
||||
|
||||
impl::SearchParams s_params;
|
||||
s_params.search_length = config[IndexParams::search_length];
|
||||
s_params.k = config[meta::TOPK];
|
||||
|
||||
@ -54,7 +54,7 @@ class NSG : public VecIndex {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr&, const Config&) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
@ -176,7 +176,7 @@ CPUSPTAGRNG::SetParameters(const Config& config) {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config) {
|
||||
CPUSPTAGRNG::Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) {
|
||||
SetParameters(config);
|
||||
|
||||
float* p_data = (float*)dataset_ptr->Get<const void*>(meta::TENSOR);
|
||||
|
||||
@ -47,7 +47,7 @@ class CPUSPTAGRNG : public VecIndex {
|
||||
}
|
||||
|
||||
DatasetPtr
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config) override;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) override;
|
||||
|
||||
int64_t
|
||||
Count() override;
|
||||
|
||||
@ -41,7 +41,7 @@ class VecIndex : public Index {
|
||||
AddWithoutIds(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
|
||||
virtual DatasetPtr
|
||||
Query(const DatasetPtr& dataset, const Config& config) = 0;
|
||||
Query(const DatasetPtr& dataset_ptr, const Config& config, faiss::ConcurrentBitsetPtr blacklist) = 0;
|
||||
|
||||
virtual int64_t
|
||||
Dim() = 0;
|
||||
@ -59,18 +59,6 @@ class VecIndex : public Index {
|
||||
return index_mode_;
|
||||
}
|
||||
|
||||
faiss::ConcurrentBitsetPtr
|
||||
GetBlacklist() {
|
||||
std::unique_lock<std::mutex> lck(bitset_mutex_);
|
||||
return bitset_;
|
||||
}
|
||||
|
||||
void
|
||||
SetBlacklist(faiss::ConcurrentBitsetPtr bitset_ptr) {
|
||||
std::unique_lock<std::mutex> lck(bitset_mutex_);
|
||||
bitset_ = std::move(bitset_ptr);
|
||||
}
|
||||
|
||||
std::shared_ptr<std::vector<IDType>>
|
||||
GetUids() const {
|
||||
return uids_;
|
||||
@ -92,12 +80,6 @@ class VecIndex : public Index {
|
||||
}
|
||||
}
|
||||
|
||||
size_t
|
||||
BlacklistSize() {
|
||||
std::unique_lock<std::mutex> lck(bitset_mutex_);
|
||||
return bitset_ ? bitset_->size() : 0;
|
||||
}
|
||||
|
||||
size_t
|
||||
UidsSize() {
|
||||
return (uids_ == nullptr) ? 0 : (uids_->size() * sizeof(IDType));
|
||||
@ -122,7 +104,7 @@ class VecIndex : public Index {
|
||||
|
||||
int64_t
|
||||
Size() override {
|
||||
return BlacklistSize() + UidsSize() + IndexSize();
|
||||
return UidsSize() + IndexSize();
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -130,11 +112,6 @@ class VecIndex : public Index {
|
||||
IndexMode index_mode_ = IndexMode::MODE_CPU;
|
||||
std::shared_ptr<std::vector<IDType>> uids_ = nullptr;
|
||||
int64_t index_size_ = -1;
|
||||
|
||||
private:
|
||||
// multi thread may access bitset_
|
||||
std::mutex bitset_mutex_;
|
||||
faiss::ConcurrentBitsetPtr bitset_ = nullptr;
|
||||
};
|
||||
|
||||
using VecIndexPtr = std::shared_ptr<VecIndex>;
|
||||
|
||||
@ -97,13 +97,14 @@ GPUIDMAP::GetRawVectors() {
|
||||
}
|
||||
|
||||
// Runs a top-k search of n query vectors against the GPU flat index.
//   n, data           - number of queries and the query buffer passed to index_->search
//   k                 - number of results requested per query
//   distances, labels - output buffers filled by index_->search
//   config            - if it contains Metric::TYPE, that metric is applied for this call
//   blacklist         - (two-row signature only) bitset forwarded to index_->search
// NOTE(review): both signatures and both search() rows are shown because this is a
// rendered diff (hunk -97,13 +97,14). Presumably the single-row signature and the
// GetBlacklist() call are the old side, and the blacklist parameter is the new side — confirm.
void
|
||||
GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
GPUIDMAP::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist) {
|
||||
// Scoped holder constructed from this index's res_ and gpu_id_.
ResScope rs(res_, gpu_id_);
|
||||
|
||||
// index_->metric_type is overwritten for this query only; remember it so it can be restored.
auto default_type = index_->metric_type;
|
||||
if (config.contains(Metric::TYPE))
|
||||
index_->metric_type = GetMetricType(config[Metric::TYPE].get<std::string>());
|
||||
index_->search(n, (float*)data, k, distances, labels, GetBlacklist());
|
||||
index_->search(n, (float*)data, k, distances, labels, blacklist);
|
||||
// Restore the metric saved above. This row is not reached if search() throws.
index_->metric_type = default_type;
|
||||
}
|
||||
|
||||
@ -128,7 +129,7 @@ GPUIDMAP::GenGraph(const float* data, const int64_t k, GraphType& graph, const C
|
||||
res.resize(K * b_size);
|
||||
|
||||
auto xq = data + batch_size * dim * i;
|
||||
QueryImpl(b_size, (float*)xq, K, res_dis.data(), res.data(), config);
|
||||
QueryImpl(b_size, (float*)xq, K, res_dis.data(), res.data(), config, nullptr);
|
||||
|
||||
for (int j = 0; j < b_size; ++j) {
|
||||
auto& node = graph[batch_size * i + j];
|
||||
|
||||
@ -50,7 +50,8 @@ class GPUIDMAP : public IDMAP, public GPUIndex {
|
||||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
|
||||
QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist);
|
||||
};
|
||||
|
||||
using GPUIDMAPPtr = std::shared_ptr<GPUIDMAP>;
|
||||
|
||||
@ -133,7 +133,8 @@ GPUIVF::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
|
||||
}
|
||||
|
||||
void
|
||||
GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config) {
|
||||
GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist) {
|
||||
auto device_index = std::dynamic_pointer_cast<faiss::gpu::GpuIndexIVF>(index_);
|
||||
fiu_do_on("GPUIVF.search_impl.invald_index", device_index = nullptr);
|
||||
if (device_index) {
|
||||
@ -145,8 +146,7 @@ GPUIVF::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int
|
||||
int64_t dim = device_index->d;
|
||||
for (int64_t i = 0; i < n; i += block_size) {
|
||||
int64_t search_size = (n - i > block_size) ? block_size : (n - i);
|
||||
device_index->search(search_size, (float*)data + i * dim, k, distances + i * k, labels + i * k,
|
||||
GetBlacklist());
|
||||
device_index->search(search_size, (float*)data + i * dim, k, distances + i * k, labels + i * k, blacklist);
|
||||
}
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Not a GpuIndexIVF type.");
|
||||
|
||||
@ -51,7 +51,8 @@ class GPUIVF : public IVF, public GPUIndex {
|
||||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
|
||||
QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist);
|
||||
};
|
||||
|
||||
using GPUIVFPtr = std::shared_ptr<GPUIVF>;
|
||||
|
||||
@ -243,21 +243,21 @@ IVFSQHybrid::LoadImpl(const BinarySet& binary_set, const IndexType& type) {
|
||||
}
|
||||
|
||||
// Dispatches a query to the implementation selected by gpu_mode_:
//   2 -> GPUIVF::QueryImpl
//   1 -> IVF::QueryImpl, run while holding the resource of the quantizer's GPU
//   0 -> IVF::QueryImpl
// Any other gpu_mode_ value falls through without searching.
// Throws (via KNOWHERE_THROW_MSG) in mode 1 when the GPU resource cannot be obtained.
// NOTE(review): each call appears twice, with and without the blacklist argument,
// because this is a rendered diff (hunk -243,21 +243,21). Presumably the rows that
// forward blacklist are the new side — confirm.
void
|
||||
IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels,
|
||||
const Config& config) {
|
||||
IVFSQHybrid::QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist) {
|
||||
if (gpu_mode_ == 2) {
|
||||
GPUIVF::QueryImpl(n, data, k, distances, labels, config);
|
||||
GPUIVF::QueryImpl(n, data, k, distances, labels, config, blacklist);
|
||||
// index_->search(n, (float*)data, k, distances, labels);
|
||||
} else if (gpu_mode_ == 1) { // hybrid
|
||||
// The GPU id is read from the quantizer, not from this index.
auto gpu_id = quantizer_->gpu_id;
|
||||
if (auto res = FaissGpuResourceMgr::GetInstance().GetRes(gpu_id)) {
|
||||
ResScope rs(res, gpu_id, true);
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config);
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config, blacklist);
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("Hybrid Search Error, can't get gpu: " + std::to_string(gpu_id) + "resource");
|
||||
}
|
||||
} else if (gpu_mode_ == 0) {
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config);
|
||||
IVF::QueryImpl(n, data, k, distances, labels, config, blacklist);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -90,7 +90,8 @@ class IVFSQHybrid : public GPUIVFSQ {
|
||||
LoadImpl(const BinarySet&, const IndexType&) override;
|
||||
|
||||
void
|
||||
QueryImpl(int64_t, const float*, int64_t, float*, int64_t*, const Config&) override;
|
||||
QueryImpl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, const Config& config,
|
||||
faiss::ConcurrentBitsetPtr blacklist);
|
||||
|
||||
protected:
|
||||
int64_t gpu_mode_ = 0; // 0: CPU, 1: Hybrid, 2: GPU
|
||||
|
||||
@ -27,8 +27,6 @@ namespace cloner {
|
||||
// Copies per-index bookkeeping from src_index to dst_index: the uid list
// (GetUids -> SetUids) and the recorded index size (IndexSize -> SetIndexSize).
// NOTE(review): the hunk header (-27,8 +27,6) removes two rows; presumably these are
// the blank row and the SetBlacklist(...) row, so the blacklist is no longer copied — confirm.
void
|
||||
CopyIndexData(const VecIndexPtr& dst_index, const VecIndexPtr& src_index) {
|
||||
dst_index->SetUids(src_index->GetUids());
|
||||
|
||||
dst_index->SetBlacklist(src_index->GetBlacklist());
|
||||
dst_index->SetIndexSize(src_index->IndexSize());
|
||||
}
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ TEST_P(AnnoyTest, annoy_basic) {
|
||||
// null faiss index
|
||||
{
|
||||
ASSERT_ANY_THROW(index_->Train(base_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr));
|
||||
ASSERT_ANY_THROW(index_->Serialize(conf));
|
||||
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Count());
|
||||
@ -64,7 +64,7 @@ TEST_P(AnnoyTest, annoy_basic) {
|
||||
ASSERT_EQ(index_->Count(), nb);
|
||||
ASSERT_EQ(index_->Dim(), dim);
|
||||
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
@ -73,7 +73,7 @@ TEST_P(AnnoyTest, annoy_basic) {
|
||||
base_dataset->Set(milvus::knowhere::meta::ROWS, rows);
|
||||
index_ = std::make_shared<milvus::knowhere::IndexAnnoy>();
|
||||
index_->BuildAll(base_dataset, conf);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
auto result2 = index_->Query(query_dataset, conf, nullptr);
|
||||
auto res_ids = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
for (int64_t i = 0; i < nq; i++) {
|
||||
for (int64_t j = rows; j < k; j++) {
|
||||
@ -95,12 +95,11 @@ TEST_P(AnnoyTest, annoy_delete) {
|
||||
bitset->set(i);
|
||||
}
|
||||
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
auto result1 = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result1, nq, k);
|
||||
ReleaseQueryResult(result1);
|
||||
|
||||
index_->SetBlacklist(bitset);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
auto result2 = index_->Query(query_dataset, conf, bitset);
|
||||
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
ReleaseQueryResult(result2);
|
||||
|
||||
@ -193,7 +192,7 @@ TEST_P(AnnoyTest, annoy_serialize) {
|
||||
index_->Load(binaryset);
|
||||
ASSERT_EQ(index_->Count(), nb);
|
||||
ASSERT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
ReleaseQueryResult(result);
|
||||
}
|
||||
|
||||
@ -52,7 +52,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
||||
// null faiss index
|
||||
{
|
||||
ASSERT_ANY_THROW(index_->Serialize());
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr));
|
||||
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
|
||||
}
|
||||
|
||||
@ -61,7 +61,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
ASSERT_TRUE(index_->GetRawVectors() != nullptr);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -69,7 +69,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
||||
auto binaryset = index_->Serialize();
|
||||
auto new_index = std::make_shared<milvus::knowhere::BinaryIDMAP>();
|
||||
new_index->Load(binaryset);
|
||||
auto result2 = new_index->Query(query_dataset, conf);
|
||||
auto result2 = new_index->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result2, nq, k);
|
||||
// PrintResult(re_result, nq, k);
|
||||
ReleaseQueryResult(result2);
|
||||
@ -78,9 +78,8 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) {
|
||||
for (int64_t i = 0; i < nq; ++i) {
|
||||
concurrent_bitset_ptr->set(i);
|
||||
}
|
||||
index_->SetBlacklist(concurrent_bitset_ptr);
|
||||
|
||||
auto result_bs_1 = index_->Query(query_dataset, conf);
|
||||
auto result_bs_1 = index_->Query(query_dataset, conf, concurrent_bitset_ptr);
|
||||
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
ReleaseQueryResult(result_bs_1);
|
||||
|
||||
@ -108,7 +107,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
|
||||
// serialize index
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->AddWithoutIds(base_dataset, milvus::knowhere::Config());
|
||||
auto re_result = index_->Query(query_dataset, conf);
|
||||
auto re_result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(re_result, nq, k);
|
||||
// PrintResult(re_result, nq, k);
|
||||
ReleaseQueryResult(re_result);
|
||||
@ -128,7 +127,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_serialize) {
|
||||
index_->Load(binaryset);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
@ -63,7 +63,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
|
||||
// null faiss index
|
||||
{
|
||||
ASSERT_ANY_THROW(index_->Serialize());
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr));
|
||||
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
|
||||
}
|
||||
|
||||
@ -71,7 +71,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -80,9 +80,8 @@ TEST_P(BinaryIVFTest, binaryivf_basic) {
|
||||
for (int64_t i = 0; i < nq; ++i) {
|
||||
concurrent_bitset_ptr->set(i);
|
||||
}
|
||||
index_->SetBlacklist(concurrent_bitset_ptr);
|
||||
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
auto result2 = index_->Query(query_dataset, conf, concurrent_bitset_ptr);
|
||||
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
ReleaseQueryResult(result2);
|
||||
|
||||
@ -146,7 +145,7 @@ TEST_P(BinaryIVFTest, binaryivf_serialize) {
|
||||
index_->Load(binaryset);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
@ -67,7 +67,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
|
||||
{
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
auto gpu_idx = cpu_idx->CopyCpuToGpu(DEVICEID, conf);
|
||||
auto result = gpu_idx->Query(query_dataset, conf);
|
||||
auto result = gpu_idx->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -84,7 +84,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
|
||||
auto pair = cpu_idx->CopyCpuToGpuWithQuantizer(DEVICEID, conf);
|
||||
auto gpu_idx = pair.first;
|
||||
|
||||
auto result = gpu_idx->Query(query_dataset, conf);
|
||||
auto result = gpu_idx->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -95,7 +95,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
|
||||
hybrid_idx->Load(binaryset);
|
||||
auto quantization = hybrid_idx->LoadQuantizer(quantizer_conf);
|
||||
auto new_idx = hybrid_idx->LoadData(quantization, quantizer_conf);
|
||||
auto result = new_idx->Query(query_dataset, conf);
|
||||
auto result = new_idx->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -115,7 +115,7 @@ TEST_F(SingleIndexTest, IVFSQHybrid) {
|
||||
hybrid_idx->Load(binaryset);
|
||||
|
||||
hybrid_idx->SetQuantizer(quantization);
|
||||
auto result = hybrid_idx->Query(query_dataset, conf);
|
||||
auto result = hybrid_idx->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
hybrid_idx->UnsetQuantizer();
|
||||
|
||||
@ -74,7 +74,7 @@ TEST_F(GPURESTEST, copyandsearch) {
|
||||
auto conf = ParamGenerator::GetInstance().Gen(index_type_);
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->AddWithoutIds(base_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
@ -89,7 +89,7 @@ TEST_F(GPURESTEST, copyandsearch) {
|
||||
auto search_func = [&] {
|
||||
// TimeRecorder tc("search&load");
|
||||
for (int i = 0; i < search_count; ++i) {
|
||||
auto result = search_idx->Query(query_dataset, conf);
|
||||
auto result = search_idx->Query(query_dataset, conf, nullptr);
|
||||
ReleaseQueryResult(result);
|
||||
// if (i > search_count - 6 || i == 0)
|
||||
// tc.RecordSection("search once");
|
||||
@ -109,7 +109,7 @@ TEST_F(GPURESTEST, copyandsearch) {
|
||||
milvus::knowhere::TimeRecorder tc("Basic");
|
||||
milvus::knowhere::cloner::CopyCpuToGpu(cpu_idx, DEVICEID, milvus::knowhere::Config());
|
||||
tc.RecordSection("Copy to gpu once");
|
||||
auto result2 = search_idx->Query(query_dataset, conf);
|
||||
auto result2 = search_idx->Query(query_dataset, conf, nullptr);
|
||||
ReleaseQueryResult(result2);
|
||||
tc.RecordSection("Search once");
|
||||
search_func();
|
||||
@ -148,7 +148,7 @@ TEST_F(GPURESTEST, trainandsearch) {
|
||||
};
|
||||
auto search_stage = [&](milvus::knowhere::VecIndexPtr& search_idx) {
|
||||
for (int i = 0; i < search_count; ++i) {
|
||||
auto result = search_idx->Query(query_dataset, conf);
|
||||
auto result = search_idx->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
}
|
||||
|
||||
@ -50,7 +50,7 @@ TEST_P(HNSWTest, HNSW_basic) {
|
||||
// null faiss index
|
||||
{
|
||||
ASSERT_ANY_THROW(index_->Serialize());
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr));
|
||||
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
|
||||
ASSERT_ANY_THROW(index_->Count());
|
||||
ASSERT_ANY_THROW(index_->Dim());
|
||||
@ -61,7 +61,7 @@ TEST_P(HNSWTest, HNSW_basic) {
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
@ -70,7 +70,7 @@ TEST_P(HNSWTest, HNSW_basic) {
|
||||
base_dataset->Set(milvus::knowhere::meta::ROWS, rows);
|
||||
index_->Train(base_dataset, conf);
|
||||
index_->AddWithoutIds(base_dataset, conf);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
auto result2 = index_->Query(query_dataset, conf, nullptr);
|
||||
auto res_ids = result2->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
for (int64_t i = 0; i < nq; i++) {
|
||||
for (int64_t j = rows; j < k; j++) {
|
||||
@ -92,12 +92,11 @@ TEST_P(HNSWTest, HNSW_delete) {
|
||||
for (auto i = 0; i < nq; ++i) {
|
||||
bitset->set(i);
|
||||
}
|
||||
auto result1 = index_->Query(query_dataset, conf);
|
||||
auto result1 = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result1, nq, k);
|
||||
ReleaseQueryResult(result1);
|
||||
|
||||
index_->SetBlacklist(bitset);
|
||||
auto result2 = index_->Query(query_dataset, conf);
|
||||
auto result2 = index_->Query(query_dataset, conf, bitset);
|
||||
AssertAnns(result2, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
ReleaseQueryResult(result2);
|
||||
|
||||
@ -151,7 +150,7 @@ TEST_P(HNSWTest, HNSW_serialize) {
|
||||
index_->Load(binaryset);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
ReleaseQueryResult(result);
|
||||
}
|
||||
|
||||
@ -73,7 +73,7 @@ TEST_P(IDMAPTest, idmap_basic) {
|
||||
// null faiss index
|
||||
{
|
||||
ASSERT_ANY_THROW(index_->Serialize());
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf, nullptr));
|
||||
ASSERT_ANY_THROW(index_->AddWithoutIds(nullptr, conf));
|
||||
}
|
||||
|
||||
@ -82,7 +82,7 @@ TEST_P(IDMAPTest, idmap_basic) {
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
ASSERT_TRUE(index_->GetRawVectors() != nullptr);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -97,7 +97,7 @@ TEST_P(IDMAPTest, idmap_basic) {
|
||||
auto binaryset = index_->Serialize();
|
||||
auto new_index = std::make_shared<milvus::knowhere::IDMAP>();
|
||||
new_index->Load(binaryset);
|
||||
auto result2 = new_index->Query(query_dataset, conf);
|
||||
auto result2 = new_index->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result2, nq, k);
|
||||
// PrintResult(re_result, nq, k);
|
||||
ReleaseQueryResult(result2);
|
||||
@ -114,9 +114,8 @@ TEST_P(IDMAPTest, idmap_basic) {
|
||||
for (int64_t i = 0; i < nq; ++i) {
|
||||
concurrent_bitset_ptr->set(i);
|
||||
}
|
||||
index_->SetBlacklist(concurrent_bitset_ptr);
|
||||
|
||||
auto result_bs_1 = index_->Query(query_dataset, conf);
|
||||
auto result_bs_1 = index_->Query(query_dataset, conf, concurrent_bitset_ptr);
|
||||
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
ReleaseQueryResult(result_bs_1);
|
||||
|
||||
@ -154,7 +153,7 @@ TEST_P(IDMAPTest, idmap_serialize) {
|
||||
#endif
|
||||
}
|
||||
|
||||
auto re_result = index_->Query(query_dataset, conf);
|
||||
auto re_result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(re_result, nq, k);
|
||||
// PrintResult(re_result, nq, k);
|
||||
ReleaseQueryResult(re_result);
|
||||
@ -174,7 +173,7 @@ TEST_P(IDMAPTest, idmap_serialize) {
|
||||
index_->Load(binaryset);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -194,7 +193,7 @@ TEST_P(IDMAPTest, idmap_copy) {
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
ASSERT_TRUE(index_->GetRawVectors() != nullptr);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -210,7 +209,7 @@ TEST_P(IDMAPTest, idmap_copy) {
|
||||
// cpu to gpu
|
||||
ASSERT_ANY_THROW(milvus::knowhere::cloner::CopyCpuToGpu(index_, -1, conf));
|
||||
auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf);
|
||||
auto clone_result = clone_index->Query(query_dataset, conf);
|
||||
auto clone_result = clone_index->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(clone_result, nq, k);
|
||||
ReleaseQueryResult(clone_result);
|
||||
ASSERT_THROW({ std::static_pointer_cast<milvus::knowhere::GPUIDMAP>(clone_index)->GetRawVectors(); },
|
||||
@ -223,7 +222,7 @@ TEST_P(IDMAPTest, idmap_copy) {
|
||||
|
||||
auto binary = clone_index->Serialize();
|
||||
clone_index->Load(binary);
|
||||
auto new_result = clone_index->Query(query_dataset, conf);
|
||||
auto new_result = clone_index->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(new_result, nq, k);
|
||||
ReleaseQueryResult(new_result);
|
||||
|
||||
@ -233,7 +232,7 @@ TEST_P(IDMAPTest, idmap_copy) {
|
||||
|
||||
// gpu to cpu
|
||||
auto host_index = milvus::knowhere::cloner::CopyGpuToCpu(clone_index, conf);
|
||||
auto host_result = host_index->Query(query_dataset, conf);
|
||||
auto host_result = host_index->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(host_result, nq, k);
|
||||
ReleaseQueryResult(host_result);
|
||||
ASSERT_TRUE(std::static_pointer_cast<milvus::knowhere::IDMAP>(host_index)->GetRawVectors() != nullptr);
|
||||
@ -242,7 +241,7 @@ TEST_P(IDMAPTest, idmap_copy) {
|
||||
auto device_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, conf);
|
||||
auto new_device_index =
|
||||
std::static_pointer_cast<milvus::knowhere::GPUIDMAP>(device_index)->CopyGpuToGpu(DEVICEID, conf);
|
||||
auto device_result = new_device_index->Query(query_dataset, conf);
|
||||
auto device_result = new_device_index->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(device_result, nq, k);
|
||||
ReleaseQueryResult(device_result);
|
||||
}
|
||||
|
||||
@ -105,7 +105,7 @@ TEST_P(IVFTest, ivf_basic_cpu) {
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
auto result = index_->Query(query_dataset, conf_);
|
||||
auto result = index_->Query(query_dataset, conf_, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -129,9 +129,8 @@ TEST_P(IVFTest, ivf_basic_cpu) {
|
||||
for (int64_t i = 0; i < nq; ++i) {
|
||||
concurrent_bitset_ptr->set(i);
|
||||
}
|
||||
index_->SetBlacklist(concurrent_bitset_ptr);
|
||||
|
||||
auto result_bs_1 = index_->Query(query_dataset, conf_);
|
||||
auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr);
|
||||
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result_bs_1);
|
||||
@ -165,7 +164,7 @@ TEST_P(IVFTest, ivf_basic_gpu) {
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
|
||||
auto result = index_->Query(query_dataset, conf_);
|
||||
auto result = index_->Query(query_dataset, conf_, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -174,9 +173,8 @@ TEST_P(IVFTest, ivf_basic_gpu) {
|
||||
for (int64_t i = 0; i < nq; ++i) {
|
||||
concurrent_bitset_ptr->set(i);
|
||||
}
|
||||
index_->SetBlacklist(concurrent_bitset_ptr);
|
||||
|
||||
auto result_bs_1 = index_->Query(query_dataset, conf_);
|
||||
auto result_bs_1 = index_->Query(query_dataset, conf_, concurrent_bitset_ptr);
|
||||
AssertAnns(result_bs_1, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
// PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result_bs_1);
|
||||
@ -214,7 +212,7 @@ TEST_P(IVFTest, ivf_serialize) {
|
||||
index_->Load(binaryset);
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->Dim(), dim);
|
||||
auto result = index_->Query(query_dataset, conf_);
|
||||
auto result = index_->Query(query_dataset, conf_, nullptr);
|
||||
AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]);
|
||||
ReleaseQueryResult(result);
|
||||
}
|
||||
@ -233,7 +231,7 @@ TEST_P(IVFTest, clone_test) {
|
||||
/* set pseudo index size, avoid throwing an exception */
|
||||
index_->SetIndexSize(nq * dim * sizeof(float));
|
||||
|
||||
auto result = index_->Query(query_dataset, conf_);
|
||||
auto result = index_->Query(query_dataset, conf_, nullptr);
|
||||
AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]);
|
||||
// PrintResult(result, nq, k);
|
||||
|
||||
@ -273,7 +271,7 @@ TEST_P(IVFTest, clone_test) {
|
||||
if (index_mode_ == milvus::knowhere::IndexMode::MODE_GPU) {
|
||||
EXPECT_NO_THROW({
|
||||
auto clone_index = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config());
|
||||
auto clone_result = clone_index->Query(query_dataset, conf_);
|
||||
auto clone_result = clone_index->Query(query_dataset, conf_, nullptr);
|
||||
AssertEqual(result, clone_result);
|
||||
ReleaseQueryResult(clone_result);
|
||||
std::cout << "clone G <=> C [" << index_type_ << "] success" << std::endl;
|
||||
@ -293,7 +291,7 @@ TEST_P(IVFTest, clone_test) {
|
||||
if (index_type_ != milvus::knowhere::IndexEnum::INDEX_FAISS_IVFSQ8H) {
|
||||
EXPECT_NO_THROW({
|
||||
auto clone_index = milvus::knowhere::cloner::CopyCpuToGpu(index_, DEVICEID, milvus::knowhere::Config());
|
||||
auto clone_result = clone_index->Query(query_dataset, conf_);
|
||||
auto clone_result = clone_index->Query(query_dataset, conf_, nullptr);
|
||||
AssertEqual(result, clone_result);
|
||||
ReleaseQueryResult(clone_result);
|
||||
std::cout << "clone C <=> G [" << index_type_ << "] success" << std::endl;
|
||||
@ -313,7 +311,7 @@ TEST_P(IVFTest, gpu_seal_test) {
|
||||
}
|
||||
assert(!xb.empty());
|
||||
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_, nullptr));
|
||||
ASSERT_ANY_THROW(index_->Seal());
|
||||
|
||||
index_->Train(base_dataset, conf_);
|
||||
@ -324,16 +322,16 @@ TEST_P(IVFTest, gpu_seal_test) {
|
||||
/* set pseudo index size, avoid throwing an exception */
|
||||
index_->SetIndexSize(nq * dim * sizeof(float));
|
||||
|
||||
auto result = index_->Query(query_dataset, conf_);
|
||||
auto result = index_->Query(query_dataset, conf_, nullptr);
|
||||
AssertAnns(result, nq, conf_[milvus::knowhere::meta::TOPK]);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
fiu_init(0);
|
||||
fiu_enable("IVF.Search.throw_std_exception", 1, nullptr, 0);
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_, nullptr));
|
||||
fiu_disable("IVF.Search.throw_std_exception");
|
||||
fiu_enable("IVF.Search.throw_faiss_exception", 1, nullptr, 0);
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, conf_, nullptr));
|
||||
fiu_disable("IVF.Search.throw_faiss_exception");
|
||||
|
||||
auto cpu_idx = milvus::knowhere::cloner::CopyGpuToCpu(index_, milvus::knowhere::Config());
|
||||
@ -374,7 +372,7 @@ TEST_P(IVFTest, invalid_gpu_source) {
|
||||
fiu_disable("GPUIVF.SerializeImpl.throw_exception");
|
||||
|
||||
fiu_enable("GPUIVF.search_impl.invald_index", 1, nullptr, 0);
|
||||
ASSERT_ANY_THROW(index_->Query(base_dataset, invalid_conf));
|
||||
ASSERT_ANY_THROW(index_->Query(base_dataset, invalid_conf, nullptr));
|
||||
fiu_disable("GPUIVF.search_impl.invald_index");
|
||||
|
||||
auto ivf_index = std::dynamic_pointer_cast<milvus::knowhere::GPUIVF>(index_);
|
||||
|
||||
@ -81,13 +81,13 @@ TEST_F(NSGInterfaceTest, basic_test) {
|
||||
// untrained index
|
||||
{
|
||||
ASSERT_ANY_THROW(index_->Serialize());
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, search_conf));
|
||||
ASSERT_ANY_THROW(index_->Query(query_dataset, search_conf, nullptr));
|
||||
ASSERT_ANY_THROW(index_->AddWithoutIds(base_dataset, search_conf));
|
||||
}
|
||||
|
||||
train_conf[milvus::knowhere::meta::DEVICEID] = -1;
|
||||
index_->BuildAll(base_dataset, train_conf);
|
||||
auto result = index_->Query(query_dataset, search_conf);
|
||||
auto result = index_->Query(query_dataset, search_conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
@ -102,7 +102,7 @@ TEST_F(NSGInterfaceTest, basic_test) {
|
||||
auto new_index_1 = std::make_shared<milvus::knowhere::NSG>(DEVICE_GPU0);
|
||||
train_conf[milvus::knowhere::meta::DEVICEID] = DEVICE_GPU0;
|
||||
new_index_1->BuildAll(base_dataset, train_conf);
|
||||
auto new_result_1 = new_index_1->Query(query_dataset, search_conf);
|
||||
auto new_result_1 = new_index_1->Query(query_dataset, search_conf, nullptr);
|
||||
AssertAnns(new_result_1, nq, k);
|
||||
ReleaseQueryResult(new_result_1);
|
||||
|
||||
@ -115,7 +115,7 @@ TEST_F(NSGInterfaceTest, basic_test) {
|
||||
fiu_disable("NSG.Load.throw_exception");
|
||||
}
|
||||
|
||||
auto new_result_2 = new_index_2->Query(query_dataset, search_conf);
|
||||
auto new_result_2 = new_index_2->Query(query_dataset, search_conf, nullptr);
|
||||
AssertAnns(new_result_2, nq, k);
|
||||
ReleaseQueryResult(new_result_2);
|
||||
|
||||
@ -144,7 +144,7 @@ TEST_F(NSGInterfaceTest, delete_test) {
|
||||
train_conf[milvus::knowhere::meta::DEVICEID] = DEVICE_GPU0;
|
||||
index_->BuildAll(base_dataset, train_conf);
|
||||
|
||||
auto result = index_->Query(query_dataset, search_conf);
|
||||
auto result = index_->Query(query_dataset, search_conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
auto I_before = result->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
|
||||
@ -156,8 +156,7 @@ TEST_F(NSGInterfaceTest, delete_test) {
|
||||
for (int i = 0; i < nq; i++) {
|
||||
bitset->set(i);
|
||||
}
|
||||
index_->SetBlacklist(bitset);
|
||||
auto result_after = index_->Query(query_dataset, search_conf);
|
||||
auto result_after = index_->Query(query_dataset, search_conf, bitset);
|
||||
AssertAnns(result_after, nq, k, CheckMode::CHECK_NOT_EQUAL);
|
||||
auto I_after = result_after->Get<int64_t*>(milvus::knowhere::meta::IDS);
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ TEST_P(SPTAGTest, sptag_basic) {
|
||||
|
||||
index_->BuildAll(base_dataset, conf);
|
||||
// index_->Add(base_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
@ -98,7 +98,7 @@ TEST_P(SPTAGTest, sptag_serialize) {
|
||||
auto binaryset = index_->Serialize();
|
||||
auto new_index = std::make_shared<milvus::knowhere::CPUSPTAGRNG>(IndexType);
|
||||
new_index->Load(binaryset);
|
||||
auto result = new_index->Query(query_dataset, conf);
|
||||
auto result = new_index->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -135,7 +135,7 @@ TEST_P(SPTAGTest, sptag_serialize) {
|
||||
|
||||
auto new_index = std::make_shared<milvus::knowhere::CPUSPTAGRNG>(IndexType);
|
||||
new_index->Load(load_data_list);
|
||||
auto result = new_index->Query(query_dataset, conf);
|
||||
auto result = new_index->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, k);
|
||||
PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
@ -82,7 +82,7 @@ TEST_P(VecIndexTest, basic) {
|
||||
EXPECT_EQ(index_->index_type(), index_type_);
|
||||
EXPECT_EQ(index_->index_mode(), index_mode_);
|
||||
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
PrintResult(result, nq, k);
|
||||
ReleaseQueryResult(result);
|
||||
@ -94,7 +94,7 @@ TEST_P(VecIndexTest, serialize) {
|
||||
EXPECT_EQ(index_->Count(), nb);
|
||||
EXPECT_EQ(index_->index_type(), index_type_);
|
||||
EXPECT_EQ(index_->index_mode(), index_mode_);
|
||||
auto result = index_->Query(query_dataset, conf);
|
||||
auto result = index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
ReleaseQueryResult(result);
|
||||
|
||||
@ -105,7 +105,7 @@ TEST_P(VecIndexTest, serialize) {
|
||||
EXPECT_EQ(index_->Count(), new_index->Count());
|
||||
EXPECT_EQ(index_->index_type(), new_index->index_type());
|
||||
EXPECT_EQ(index_->index_mode(), new_index->index_mode());
|
||||
auto new_result = new_index_->Query(query_dataset, conf);
|
||||
auto new_result = new_index_->Query(query_dataset, conf, nullptr);
|
||||
AssertAnns(new_result, nq, conf[milvus::knowhere::meta::TOPK]);
|
||||
ReleaseQueryResult(new_result);
|
||||
}
|
||||
|
||||
@ -55,13 +55,28 @@ SegmentReader::Load() {
|
||||
}
|
||||
|
||||
Status
|
||||
SegmentReader::LoadVectors(off_t offset, size_t num_bytes, std::vector<uint8_t>& raw_vectors) {
|
||||
SegmentReader::LoadsVectors(VectorsPtr& vectors_ptr) {
|
||||
codec::DefaultCodec default_codec;
|
||||
try {
|
||||
fs_ptr_->operation_ptr_->CreateDirectory();
|
||||
vectors_ptr = std::make_shared<Vectors>();
|
||||
default_codec.GetVectorsFormat()->read(fs_ptr_, vectors_ptr);
|
||||
} catch (std::exception& e) {
|
||||
std::string err_msg = "Failed to load raw vectors: " + std::string(e.what());
|
||||
LOG_ENGINE_ERROR_ << err_msg;
|
||||
return Status(DB_ERROR, e.what());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
SegmentReader::LoadsSingleVector(off_t offset, size_t num_bytes, std::vector<uint8_t>& raw_vectors) {
|
||||
codec::DefaultCodec default_codec;
|
||||
try {
|
||||
fs_ptr_->operation_ptr_->CreateDirectory();
|
||||
default_codec.GetVectorsFormat()->read_vectors(fs_ptr_, offset, num_bytes, raw_vectors);
|
||||
} catch (std::exception& e) {
|
||||
std::string err_msg = "Failed to load raw vectors: " + std::string(e.what());
|
||||
std::string err_msg = "Failed to load single vector: " + std::string(e.what());
|
||||
LOG_ENGINE_ERROR_ << err_msg;
|
||||
return Status(DB_ERROR, err_msg);
|
||||
}
|
||||
|
||||
@ -40,7 +40,10 @@ class SegmentReader {
|
||||
Load();
|
||||
|
||||
Status
|
||||
LoadVectors(off_t offset, size_t num_bytes, std::vector<uint8_t>& raw_vectors);
|
||||
LoadsVectors(VectorsPtr& vectors_ptr);
|
||||
|
||||
Status
|
||||
LoadsSingleVector(off_t offset, size_t num_bytes, std::vector<uint8_t>& raw_vectors);
|
||||
|
||||
Status
|
||||
LoadUids(UidsPtr& uids);
|
||||
|
||||
@ -59,7 +59,8 @@ class MockVecIndex : public milvus::knowhere::VecIndex {
|
||||
|
||||
virtual milvus::knowhere::DatasetPtr
|
||||
Query(const milvus::knowhere::DatasetPtr& dataset,
|
||||
const milvus::knowhere::Config& cfg = milvus::knowhere::Config()) {
|
||||
const milvus::knowhere::Config& cfg = milvus::knowhere::Config(),
|
||||
faiss::ConcurrentBitsetPtr blacklist = nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user