From 2eec5607f71efe0ca18c96281b6279a10cf10af4 Mon Sep 17 00:00:00 2001 From: "shengjun.li" Date: Fri, 25 Sep 2020 16:42:33 +0800 Subject: [PATCH] let bitset nullptr it no deletion and no filter (#3872) * fix bitset Signed-off-by: shengjun.li * check bitset Signed-off-by: shengjun.li --- core/src/db/engine/ExecutionEngineImpl.cpp | 57 ++++++----- .../faiss/utils/ConcurrentBitset.cpp | 96 ++++++++++--------- .../thirdparty/faiss/utils/ConcurrentBitset.h | 26 +++-- core/src/segment/SegmentReader.cpp | 9 +- 4 files changed, 94 insertions(+), 94 deletions(-) diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index 9d246af1a9..5fe1d5749f 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -300,9 +300,9 @@ Status ExecutionEngineImpl::Search(ExecutionEngineContext& context) { TimeRecorder rc(LogOut("[%s][%ld] ExecutionEngineImpl::Search", "search", 0)); try { - faiss::ConcurrentBitsetPtr bitset; + faiss::ConcurrentBitsetPtr bitset = nullptr; std::string vector_placeholder; - faiss::ConcurrentBitsetPtr list; + faiss::ConcurrentBitsetPtr list = nullptr; SegmentPtr segment_ptr; segment_reader_->GetSegment(segment_ptr); @@ -325,19 +325,27 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) { } } - list = vec_index->GetBlacklist(); - entity_count_ = list->capacity(); + entity_count_ = vec_index->Count(); + // Parse general query auto status = ExecBinaryQuery(context.query_ptr_->root, bitset, attr_type, vector_placeholder); if (!status.ok()) { return status; } + if (bitset != nullptr) { + bitset->negate(); + } rc.RecordSection("Scalar field filtering"); - // Do And - for (int64_t i = 0; i < entity_count_; i++) { - if (!list->test(i) && !bitset->test(i)) { - list->set(i); + // combine filter and deletion + list = vec_index->GetBlacklist(); + if (list != nullptr) { + if (bitset != nullptr) { + list = (*list) | (*bitset); + } + } else { + if (bitset != nullptr) { + list = bitset; } } @@ -380,27 +388,25 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener } } - if (left_bitset == nullptr || right_bitset == nullptr) { - bitset = left_bitset != nullptr ? left_bitset : right_bitset; + if (left_bitset == nullptr) { + bitset = right_bitset; + } else if (right_bitset == nullptr) { + bitset = left_bitset; } else { switch (general_query->bin->relation) { case milvus::query::QueryRelation::AND: case milvus::query::QueryRelation::R1: { - bitset = (*left_bitset) & right_bitset; + bitset = (*left_bitset) & (*right_bitset); break; } case milvus::query::QueryRelation::OR: case milvus::query::QueryRelation::R2: case milvus::query::QueryRelation::R3: { - bitset = (*left_bitset) | right_bitset; + bitset = (*left_bitset) | (*right_bitset); break; } case milvus::query::QueryRelation::R4: { - for (uint64_t i = 0; i < entity_count_; ++i) { - if (left_bitset->test(i) && !right_bitset->test(i)) { - bitset->set(i); - } - } + bitset = (*left_bitset) & (right_bitset->negate()); break; } default: { @@ -410,13 +416,7 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener } // TODO(yukun): optimize if (general_query->bin->is_not) { - for (uint64_t i = 0; i < entity_count_; ++i) { - if (bitset->test(i)) { - bitset->clear(i); - } else { - bitset->set(i); - } - } + bitset->negate(); } } return status; @@ -431,7 +431,6 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener } if (!general_query->leaf->vector_placeholder.empty()) { // skip vector query - bitset = std::make_shared(entity_count_, 255); vector_placeholder = general_query->leaf->vector_placeholder; } } @@ -540,10 +539,10 @@ ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& const std::string& comp_op = range_value_it.key(); T value = range_value_it.value(); if (not flag) { - bitset = (*bitset) | T_index->Range(value, knowhere::s_map_operator_type.at(comp_op)); + bitset = (*bitset) | (*T_index->Range(value, knowhere::s_map_operator_type.at(comp_op))); flag = true; } else { - bitset = (*bitset) & T_index->Range(value, knowhere::s_map_operator_type.at(comp_op)); + bitset = (*bitset) & (*T_index->Range(value, knowhere::s_map_operator_type.at(comp_op))); } } } catch (std::exception& exception) { @@ -852,9 +851,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col #endif new_index->SetUids(uids); - if (blacklist != nullptr) { - new_index->SetBlacklist(blacklist); - } + new_index->SetBlacklist(blacklist); return Status::OK(); } diff --git a/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp b/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp index 1ba9a1e406..2f935fba93 100644 --- a/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp +++ b/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.cpp @@ -26,21 +26,12 @@ ConcurrentBitset::ConcurrentBitset(id_type_t capacity, uint8_t init_value) : cap } } -std::vector>& -ConcurrentBitset::bitset() { - return bitset_; -} - ConcurrentBitset& -ConcurrentBitset::operator&=(ConcurrentBitset& bitset) { - // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) { - // bitset_[i].fetch_and(bitset.bitset()[i].load()); - // } - - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset.data()); +ConcurrentBitset::operator&=(const ConcurrentBitset& bitset) { + auto u8_1 = mutable_data(); + auto u8_2 = bitset.data(); auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); + auto u64_2 = reinterpret_cast(u8_2); size_t n8 = bitset_.size(); size_t n64 = n8 / 8; @@ -60,16 +51,16 @@ ConcurrentBitset::operator&=(ConcurrentBitset& bitset) { } std::shared_ptr -ConcurrentBitset::operator&(const std::shared_ptr& bitset) { - auto result_bitset = std::make_shared(bitset->capacity()); +ConcurrentBitset::operator&(const ConcurrentBitset& bitset) const { + auto result_bitset = std::make_shared(bitset.capacity()); - auto result_8 = const_cast(result_bitset->data()); + auto result_8 = result_bitset->mutable_data(); auto result_64 = reinterpret_cast(result_8); - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset->data()); - auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); + auto u8_1 = data(); + auto u8_2 = bitset.data(); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); size_t n8 = bitset_.size(); size_t n64 = n8 / 8; @@ -91,15 +82,11 @@ ConcurrentBitset::operator&(const std::shared_ptr& bitset) { } ConcurrentBitset& -ConcurrentBitset::operator|=(ConcurrentBitset& bitset) { - // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) { - // bitset_[i].fetch_or(bitset.bitset()[i].load()); - // } - - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset.data()); +ConcurrentBitset::operator|=(const ConcurrentBitset& bitset) { + auto u8_1 = mutable_data(); + auto u8_2 = bitset.data(); auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); + auto u64_2 = reinterpret_cast(u8_2); size_t n8 = bitset_.size(); size_t n64 = n8 / 8; @@ -119,16 +106,16 @@ ConcurrentBitset::operator|=(ConcurrentBitset& bitset) { } std::shared_ptr -ConcurrentBitset::operator|(const std::shared_ptr& bitset) { - auto result_bitset = std::make_shared(bitset->capacity()); +ConcurrentBitset::operator|(const ConcurrentBitset& bitset) const { + auto result_bitset = std::make_shared(bitset.capacity()); - auto result_8 = const_cast(result_bitset->data()); + auto result_8 = result_bitset->mutable_data(); auto result_64 = reinterpret_cast(result_8); - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset->data()); - auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); + auto u8_1 = data(); + auto u8_2 = bitset.data(); + auto u64_1 = reinterpret_cast(u8_1); + auto u64_2 = reinterpret_cast(u8_2); size_t n8 = bitset_.size(); size_t n64 = n8 / 8; @@ -149,15 +136,11 @@ ConcurrentBitset::operator|(const std::shared_ptr& bitset) { } ConcurrentBitset& -ConcurrentBitset::operator^=(ConcurrentBitset& bitset) { - // for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) { - // bitset_[i].fetch_xor(bitset.bitset()[i].load()); - // } - - auto u8_1 = const_cast(data()); - auto u8_2 = const_cast(bitset.data()); +ConcurrentBitset::operator^=(const ConcurrentBitset& bitset) { + auto u8_1 = mutable_data(); + auto u8_2 = bitset.data(); auto u64_1 = reinterpret_cast(u8_1); - auto u64_2 = reinterpret_cast(u8_2); + auto u64_2 = reinterpret_cast(u8_2); size_t n8 = bitset_.size(); size_t n64 = n8 / 8; @@ -176,6 +159,27 @@ ConcurrentBitset::operator^=(ConcurrentBitset& bitset) { return *this; } +ConcurrentBitset& +ConcurrentBitset::negate() { + auto u8_1 = mutable_data(); + auto u64_1 = reinterpret_cast(u8_1); + + size_t n8 = bitset_.size(); + size_t n64 = n8 / 8; + + for (size_t i = 0; i < n64; i++) { + u64_1[i] = ~u64_1[i]; + } + + size_t remain = n8 % 8; + u8_1 += n64 * 8; + for (size_t i = 0; i < remain; i++) { + u8_1[i] = ~u8_1[i]; + } + + return *this; +} + bool ConcurrentBitset::test(id_type_t id) { return bitset_[id >> 3].load() & (0x1 << (id & 0x7)); @@ -192,17 +196,17 @@ ConcurrentBitset::clear(id_type_t id) { } size_t -ConcurrentBitset::capacity() { +ConcurrentBitset::capacity() const { return capacity_; } size_t -ConcurrentBitset::size() { +ConcurrentBitset::size() const { return ((capacity_ + 8 - 1) >> 3); } const uint8_t* -ConcurrentBitset::data() { +ConcurrentBitset::data() const { return reinterpret_cast(bitset_.data()); } diff --git a/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h b/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h index 121e212b03..156ecc5963 100644 --- a/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h +++ b/core/src/index/thirdparty/faiss/utils/ConcurrentBitset.h @@ -29,27 +29,23 @@ class ConcurrentBitset { explicit ConcurrentBitset(id_type_t size, uint8_t init_value = 0); - // ConcurrentBitset(const ConcurrentBitset&) = delete; - // ConcurrentBitset& - // operator=(const ConcurrentBitset&) = delete; - - std::vector>& - bitset(); - ConcurrentBitset& - operator&=(ConcurrentBitset& bitset); + operator&=(const ConcurrentBitset& bitset); std::shared_ptr - operator&(const std::shared_ptr& bitset); + operator&(const ConcurrentBitset& bitset) const; ConcurrentBitset& - operator|=(ConcurrentBitset& bitset); + operator|=(const ConcurrentBitset& bitset); std::shared_ptr - operator|(const std::shared_ptr& bitset); + operator|(const ConcurrentBitset& bitset) const; ConcurrentBitset& - operator^=(ConcurrentBitset& bitset); + operator^=(const ConcurrentBitset& bitset); + + ConcurrentBitset& + negate(); bool test(id_type_t id); @@ -61,13 +57,13 @@ class ConcurrentBitset { clear(id_type_t id); size_t - capacity(); + capacity() const; size_t - size(); + size() const; const uint8_t* - data(); + data() const; uint8_t* mutable_data(); diff --git a/core/src/segment/SegmentReader.cpp b/core/src/segment/SegmentReader.cpp index 3fc856f173..a361be98d3 100644 --- a/core/src/segment/SegmentReader.cpp +++ b/core/src/segment/SegmentReader.cpp @@ -273,13 +273,16 @@ SegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndex // load deleted doc int64_t row_count = GetRowCount(); - faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared(row_count); + faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = nullptr; segment::DeletedDocsPtr deleted_docs_ptr; LoadDeletedDocs(deleted_docs_ptr); if (deleted_docs_ptr != nullptr) { auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs(); - for (auto& offset : deleted_docs) { - concurrent_bitset_ptr->set(offset); + if (!deleted_docs.empty()) { + concurrent_bitset_ptr = std::make_shared(row_count); + for (auto& offset : deleted_docs) { + concurrent_bitset_ptr->set(offset); + } } } recorder.RecordSection("prepare");