let bitset be nullptr if no deletion and no filter (#3872)

* fix bitset

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>

* check bitset

Signed-off-by: shengjun.li <shengjun.li@zilliz.com>
This commit is contained in:
shengjun.li 2020-09-25 16:42:33 +08:00 committed by GitHub
parent 2d20795839
commit 2eec5607f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 94 additions and 94 deletions

View File

@ -300,9 +300,9 @@ Status
ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
TimeRecorder rc(LogOut("[%s][%ld] ExecutionEngineImpl::Search", "search", 0));
try {
faiss::ConcurrentBitsetPtr bitset;
faiss::ConcurrentBitsetPtr bitset = nullptr;
std::string vector_placeholder;
faiss::ConcurrentBitsetPtr list;
faiss::ConcurrentBitsetPtr list = nullptr;
SegmentPtr segment_ptr;
segment_reader_->GetSegment(segment_ptr);
@ -325,19 +325,27 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
}
}
list = vec_index->GetBlacklist();
entity_count_ = list->capacity();
entity_count_ = vec_index->Count();
// Parse general query
auto status = ExecBinaryQuery(context.query_ptr_->root, bitset, attr_type, vector_placeholder);
if (!status.ok()) {
return status;
}
if (bitset != nullptr) {
bitset->negate();
}
rc.RecordSection("Scalar field filtering");
// Do And
for (int64_t i = 0; i < entity_count_; i++) {
if (!list->test(i) && !bitset->test(i)) {
list->set(i);
// combine filter and deletion
list = vec_index->GetBlacklist();
if (list != nullptr) {
if (bitset != nullptr) {
list = (*list) | (*bitset);
}
} else {
if (bitset != nullptr) {
list = bitset;
}
}
@ -380,27 +388,25 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener
}
}
if (left_bitset == nullptr || right_bitset == nullptr) {
bitset = left_bitset != nullptr ? left_bitset : right_bitset;
if (left_bitset == nullptr) {
bitset = right_bitset;
} else if (right_bitset == nullptr) {
bitset = left_bitset;
} else {
switch (general_query->bin->relation) {
case milvus::query::QueryRelation::AND:
case milvus::query::QueryRelation::R1: {
bitset = (*left_bitset) & right_bitset;
bitset = (*left_bitset) & (*right_bitset);
break;
}
case milvus::query::QueryRelation::OR:
case milvus::query::QueryRelation::R2:
case milvus::query::QueryRelation::R3: {
bitset = (*left_bitset) | right_bitset;
bitset = (*left_bitset) | (*right_bitset);
break;
}
case milvus::query::QueryRelation::R4: {
for (uint64_t i = 0; i < entity_count_; ++i) {
if (left_bitset->test(i) && !right_bitset->test(i)) {
bitset->set(i);
}
}
bitset = (*left_bitset) & (right_bitset->negate());
break;
}
default: {
@ -410,13 +416,7 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener
}
// TODO(yukun): optimize
if (general_query->bin->is_not) {
for (uint64_t i = 0; i < entity_count_; ++i) {
if (bitset->test(i)) {
bitset->clear(i);
} else {
bitset->set(i);
}
}
bitset->negate();
}
}
return status;
@ -431,7 +431,6 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener
}
if (!general_query->leaf->vector_placeholder.empty()) {
// skip vector query
bitset = std::make_shared<faiss::ConcurrentBitset>(entity_count_, 255);
vector_placeholder = general_query->leaf->vector_placeholder;
}
}
@ -540,10 +539,10 @@ ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr&
const std::string& comp_op = range_value_it.key();
T value = range_value_it.value();
if (not flag) {
bitset = (*bitset) | T_index->Range(value, knowhere::s_map_operator_type.at(comp_op));
bitset = (*bitset) | (*T_index->Range(value, knowhere::s_map_operator_type.at(comp_op)));
flag = true;
} else {
bitset = (*bitset) & T_index->Range(value, knowhere::s_map_operator_type.at(comp_op));
bitset = (*bitset) & (*T_index->Range(value, knowhere::s_map_operator_type.at(comp_op)));
}
}
} catch (std::exception& exception) {
@ -852,9 +851,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col
#endif
new_index->SetUids(uids);
if (blacklist != nullptr) {
new_index->SetBlacklist(blacklist);
}
new_index->SetBlacklist(blacklist);
return Status::OK();
}

View File

@ -26,21 +26,12 @@ ConcurrentBitset::ConcurrentBitset(id_type_t capacity, uint8_t init_value) : cap
}
}
std::vector<std::atomic<uint8_t>>&
ConcurrentBitset::bitset() {
return bitset_;
}
ConcurrentBitset&
ConcurrentBitset::operator&=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_and(bitset.bitset()[i].load());
// }
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset.data());
ConcurrentBitset::operator&=(const ConcurrentBitset& bitset) {
auto u8_1 = mutable_data();
auto u8_2 = bitset.data();
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
auto u64_2 = reinterpret_cast<const uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
@ -60,16 +51,16 @@ ConcurrentBitset::operator&=(ConcurrentBitset& bitset) {
}
std::shared_ptr<ConcurrentBitset>
ConcurrentBitset::operator&(const std::shared_ptr<ConcurrentBitset>& bitset) {
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset->capacity());
ConcurrentBitset::operator&(const ConcurrentBitset& bitset) const {
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset.capacity());
auto result_8 = const_cast<uint8_t*>(result_bitset->data());
auto result_8 = result_bitset->mutable_data();
auto result_64 = reinterpret_cast<uint64_t*>(result_8);
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset->data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
auto u8_1 = data();
auto u8_2 = bitset.data();
auto u64_1 = reinterpret_cast<const uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<const uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
@ -91,15 +82,11 @@ ConcurrentBitset::operator&(const std::shared_ptr<ConcurrentBitset>& bitset) {
}
ConcurrentBitset&
ConcurrentBitset::operator|=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_or(bitset.bitset()[i].load());
// }
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset.data());
ConcurrentBitset::operator|=(const ConcurrentBitset& bitset) {
auto u8_1 = mutable_data();
auto u8_2 = bitset.data();
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
auto u64_2 = reinterpret_cast<const uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
@ -119,16 +106,16 @@ ConcurrentBitset::operator|=(ConcurrentBitset& bitset) {
}
std::shared_ptr<ConcurrentBitset>
ConcurrentBitset::operator|(const std::shared_ptr<ConcurrentBitset>& bitset) {
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset->capacity());
ConcurrentBitset::operator|(const ConcurrentBitset& bitset) const {
auto result_bitset = std::make_shared<ConcurrentBitset>(bitset.capacity());
auto result_8 = const_cast<uint8_t*>(result_bitset->data());
auto result_8 = result_bitset->mutable_data();
auto result_64 = reinterpret_cast<uint64_t*>(result_8);
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset->data());
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
auto u8_1 = data();
auto u8_2 = bitset.data();
auto u64_1 = reinterpret_cast<const uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<const uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
@ -149,15 +136,11 @@ ConcurrentBitset::operator|(const std::shared_ptr<ConcurrentBitset>& bitset) {
}
ConcurrentBitset&
ConcurrentBitset::operator^=(ConcurrentBitset& bitset) {
// for (id_type_t i = 0; i < ((capacity_ + 8 -1) >> 3); ++i) {
// bitset_[i].fetch_xor(bitset.bitset()[i].load());
// }
auto u8_1 = const_cast<uint8_t*>(data());
auto u8_2 = const_cast<uint8_t*>(bitset.data());
ConcurrentBitset::operator^=(const ConcurrentBitset& bitset) {
auto u8_1 = mutable_data();
auto u8_2 = bitset.data();
auto u64_1 = reinterpret_cast<uint64_t*>(u8_1);
auto u64_2 = reinterpret_cast<uint64_t*>(u8_2);
auto u64_2 = reinterpret_cast<const uint64_t*>(u8_2);
size_t n8 = bitset_.size();
size_t n64 = n8 / 8;
@ -176,6 +159,27 @@ ConcurrentBitset::operator^=(ConcurrentBitset& bitset) {
return *this;
}
// Inverts every bit held in the backing storage, in place, and returns *this
// so that calls can be chained.
//
// The storage is walked in two passes: first as 64-bit words, then byte by
// byte for whatever is left over (fewer than eight bytes).
//
// NOTE(review): the writes go through the raw pointer from mutable_data(),
// not through per-byte atomic operations — presumably callers must not negate
// a bitset that other threads are concurrently modifying; confirm.
// NOTE(review): every byte of bitset_ is inverted, so any unused bits in the
// final byte (when capacity is not a multiple of 8) are flipped as well —
// confirm that nothing reads bits beyond capacity.
ConcurrentBitset&
ConcurrentBitset::negate() {
    auto* bytes = mutable_data();
    const size_t total_bytes = bitset_.size();
    const size_t word_count = total_bytes / 8;

    // Bulk pass: flip eight bytes per iteration.
    auto* words = reinterpret_cast<uint64_t*>(bytes);
    for (size_t w = 0; w < word_count; ++w) {
        words[w] = ~words[w];
    }

    // Tail pass: flip the trailing bytes that did not fill a whole word.
    for (size_t b = word_count * 8; b < total_bytes; ++b) {
        bytes[b] = ~bytes[b];
    }

    return *this;
}
bool
ConcurrentBitset::test(id_type_t id) {
return bitset_[id >> 3].load() & (0x1 << (id & 0x7));
@ -192,17 +196,17 @@ ConcurrentBitset::clear(id_type_t id) {
}
size_t
ConcurrentBitset::capacity() {
ConcurrentBitset::capacity() const {
return capacity_;
}
size_t
ConcurrentBitset::size() {
ConcurrentBitset::size() const {
return ((capacity_ + 8 - 1) >> 3);
}
const uint8_t*
ConcurrentBitset::data() {
ConcurrentBitset::data() const {
return reinterpret_cast<const uint8_t*>(bitset_.data());
}

View File

@ -29,27 +29,23 @@ class ConcurrentBitset {
explicit ConcurrentBitset(id_type_t size, uint8_t init_value = 0);
// ConcurrentBitset(const ConcurrentBitset&) = delete;
// ConcurrentBitset&
// operator=(const ConcurrentBitset&) = delete;
std::vector<std::atomic<uint8_t>>&
bitset();
ConcurrentBitset&
operator&=(ConcurrentBitset& bitset);
operator&=(const ConcurrentBitset& bitset);
std::shared_ptr<ConcurrentBitset>
operator&(const std::shared_ptr<ConcurrentBitset>& bitset);
operator&(const ConcurrentBitset& bitset) const;
ConcurrentBitset&
operator|=(ConcurrentBitset& bitset);
operator|=(const ConcurrentBitset& bitset);
std::shared_ptr<ConcurrentBitset>
operator|(const std::shared_ptr<ConcurrentBitset>& bitset);
operator|(const ConcurrentBitset& bitset) const;
ConcurrentBitset&
operator^=(ConcurrentBitset& bitset);
operator^=(const ConcurrentBitset& bitset);
ConcurrentBitset&
negate();
bool
test(id_type_t id);
@ -61,13 +57,13 @@ class ConcurrentBitset {
clear(id_type_t id);
size_t
capacity();
capacity() const;
size_t
size();
size() const;
const uint8_t*
data();
data() const;
uint8_t*
mutable_data();

View File

@ -273,13 +273,16 @@ SegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndex
// load deleted doc
int64_t row_count = GetRowCount();
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(row_count);
faiss::ConcurrentBitsetPtr concurrent_bitset_ptr = nullptr;
segment::DeletedDocsPtr deleted_docs_ptr;
LoadDeletedDocs(deleted_docs_ptr);
if (deleted_docs_ptr != nullptr) {
auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs();
for (auto& offset : deleted_docs) {
concurrent_bitset_ptr->set(offset);
if (!deleted_docs.empty()) {
concurrent_bitset_ptr = std::make_shared<faiss::ConcurrentBitset>(row_count);
for (auto& offset : deleted_docs) {
concurrent_bitset_ptr->set(offset);
}
}
}
recorder.RecordSection("prepare");