From b4ddf746c18728e125ccc3d9cb221e4c9232aafa Mon Sep 17 00:00:00 2001 From: congqixia Date: Tue, 14 Jan 2025 17:59:00 +0800 Subject: [PATCH] enhance: [2.5][bitset] extend op_find() to be able to search both 0 and 1 (#39237) Cherry-pick from master pr: #39176 #39229 issue: #39124 `bitset::find_first()` and `bitset::find_next()` now accept one more parameter, which allows to search for `0` bit instead of `1` bit --------- Signed-off-by: Alexandr Guzhva Signed-off-by: Congqi Xia Co-authored-by: Alexander Guzhva --- internal/core/src/bitset/CMakeLists.txt | 4 +- internal/core/src/bitset/bitset.h | 17 ++-- internal/core/src/bitset/detail/bit_wise.h | 5 +- .../src/bitset/detail/element_vectorized.h | 5 +- .../core/src/bitset/detail/element_wise.h | 93 ++++++++++++++++++- internal/core/src/bitset/detail/proxy.h | 2 + .../src/segcore/ChunkedSegmentSealedImpl.cpp | 8 +- .../core/src/segcore/SegmentSealedImpl.cpp | 9 +- internal/core/unittest/test_bitset.cpp | 56 ++++++----- 9 files changed, 147 insertions(+), 52 deletions(-) diff --git a/internal/core/src/bitset/CMakeLists.txt b/internal/core/src/bitset/CMakeLists.txt index 3f7c6ae24d..7bec771da9 100644 --- a/internal/core/src/bitset/CMakeLists.txt +++ b/internal/core/src/bitset/CMakeLists.txt @@ -22,8 +22,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") detail/platform/x86/instruction_set.cpp ) - set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq") - set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma") + set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq -mavx512cd -mbmi") + set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma -mbmi") # set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq") # set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma") diff --git a/internal/core/src/bitset/bitset.h b/internal/core/src/bitset/bitset.h index e5a82ac263..87dbe5b65c 100644 --- a/internal/core/src/bitset/bitset.h +++ b/internal/core/src/bitset/bitset.h @@ -546,23 +546,26 @@ class BitsetBase { return as_derived(); } - // Find the index of the first bit set to true. + // Find the index of the first bit set to either true (default), or false. inline std::optional - find_first() const { + find_first(const bool is_set = true) const { return policy_type::op_find( - this->data(), this->offset(), this->size(), 0); + this->data(), this->offset(), this->size(), 0, is_set); } - // Find the index of the first bit set to true, starting from a given bit index. + // Find the index of the first bit set to either true (default), or false, starting from a given bit index. inline std::optional - find_next(const size_t starting_bit_idx) const { + find_next(const size_t starting_bit_idx, const bool is_set = true) const { const size_t size_v = this->size(); if (starting_bit_idx + 1 >= size_v) { return std::nullopt; } - return policy_type::op_find( - this->data(), this->offset(), this->size(), starting_bit_idx + 1); + return policy_type::op_find(this->data(), + this->offset(), + this->size(), + starting_bit_idx + 1, + is_set); } // Read multiple bits starting from a given bit index. diff --git a/internal/core/src/bitset/detail/bit_wise.h b/internal/core/src/bitset/detail/bit_wise.h index f3d08dc5be..0103a58209 100644 --- a/internal/core/src/bitset/detail/bit_wise.h +++ b/internal/core/src/bitset/detail/bit_wise.h @@ -315,10 +315,11 @@ struct BitWiseBitsetPolicy { op_find(const data_type* const data, const size_t start, const size_t size, - const size_t starting_idx) { + const size_t starting_idx, + const bool is_set) { for (size_t i = starting_idx; i < size; i++) { const auto proxy = get_proxy(data, start + i); - if (proxy) { + if (proxy == is_set) { return i; } } diff --git a/internal/core/src/bitset/detail/element_vectorized.h b/internal/core/src/bitset/detail/element_vectorized.h index 93668904ab..f490741ac0 100644 --- a/internal/core/src/bitset/detail/element_vectorized.h +++ b/internal/core/src/bitset/detail/element_vectorized.h @@ -220,9 +220,10 @@ struct VectorizedElementWiseBitsetPolicy { op_find(const data_type* const data, const size_t start, const size_t size, - const size_t starting_idx) { + const size_t starting_idx, + const bool is_set) { return ElementWiseBitsetPolicy::op_find( - data, start, size, starting_idx); + data, start, size, starting_idx, is_set); } // diff --git a/internal/core/src/bitset/detail/element_wise.h b/internal/core/src/bitset/detail/element_wise.h index 3baf9c45af..771ca4178b 100644 --- a/internal/core/src/bitset/detail/element_wise.h +++ b/internal/core/src/bitset/detail/element_wise.h @@ -718,10 +718,10 @@ struct ElementWiseBitsetPolicy { // static inline std::optional - op_find(const data_type* const data, - const size_t start, - const size_t size, - const size_t starting_idx) { + op_find_1(const data_type* const data, + const size_t start, + const size_t size, + const size_t starting_idx) { if (size == 0) { return std::nullopt; } @@ -791,6 +791,91 @@ struct ElementWiseBitsetPolicy { return std::nullopt; } + static inline std::optional + op_find_0(const data_type* const data, + const size_t start, + const size_t size, + const size_t starting_idx) { + if (size == 0) { + return std::nullopt; + } + + // + auto start_element = get_element(start + starting_idx); + const auto end_element = get_element(start + size); + + const auto start_shift = get_shift(start + starting_idx); + const auto end_shift = get_shift(start + size); + + // same element? + if (start_element == end_element) { + const data_type existing_v = ~data[start_element]; + + const data_type existing_mask = get_shift_mask_end(start_shift) & + get_shift_mask_begin(end_shift); + + const data_type value = existing_v & existing_mask; + if (value != 0) { + const auto ctz = CtzHelper::ctz(value); + return size_t(ctz) + start_element * data_bits - start; + } else { + return std::nullopt; + } + } + + // process the first element + if (start_shift != 0) { + const data_type existing_v = ~data[start_element]; + const data_type existing_mask = get_shift_mask_end(start_shift); + + const data_type value = existing_v & existing_mask; + if (value != 0) { + const auto ctz = CtzHelper::ctz(value) + + start_element * data_bits - start; + return size_t(ctz); + } + + start_element += 1; + } + + // process the middle + for (size_t i = start_element; i < end_element; i++) { + const data_type value = ~data[i]; + if (value != 0) { + const auto ctz = CtzHelper::ctz(value); + return size_t(ctz) + i * data_bits - start; + } + } + + // process the last element + if (end_shift != 0) { + const data_type existing_v = ~data[end_element]; + const data_type existing_mask = get_shift_mask_begin(end_shift); + + const data_type value = existing_v & existing_mask; + if (value != 0) { + const auto ctz = CtzHelper::ctz(value); + return size_t(ctz) + end_element * data_bits - start; + } + } + + return std::nullopt; + } + + // + static inline std::optional + op_find(const data_type* const data, + const size_t start, + const size_t size, + const size_t starting_idx, + const bool is_set) { + if (is_set) { + return op_find_1(data, start, size, starting_idx); + } else { + return op_find_0(data, start, size, starting_idx); + } + } + // template static inline void diff --git a/internal/core/src/bitset/detail/proxy.h b/internal/core/src/bitset/detail/proxy.h index b29eaec7bb..2b2d6613f7 100644 --- a/internal/core/src/bitset/detail/proxy.h +++ b/internal/core/src/bitset/detail/proxy.h @@ -16,6 +16,8 @@ #pragma once +#include + namespace milvus { namespace bitset { namespace detail { diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index 1a79fc8cff..0c7c0c7a91 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -1255,12 +1255,8 @@ ChunkedSegmentSealedImpl::find_first(int64_t limit, std::vector seg_offsets; seg_offsets.reserve(limit); - // flip bitset since `find_next` is used to find true. - auto flipped = bitset.clone(); - flipped.flip(); - int64_t offset = 0; - std::optional result = flipped.find_first(); + std::optional result = bitset.find_first(false); while (result.has_value() && hit_num < limit) { hit_num++; seg_offsets.push_back(result.value()); @@ -1269,7 +1265,7 @@ ChunkedSegmentSealedImpl::find_first(int64_t limit, // In fact, this case won't happen on sealed segments. continue; } - result = flipped.find_next(offset); + result = bitset.find_next(offset, false); } return {seg_offsets, more_hit_than_limit && result.has_value()}; diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 1b7b9ccffa..68fc4e9646 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -1723,13 +1723,8 @@ SegmentSealedImpl::find_first(int64_t limit, const BitsetType& bitset) const { std::vector seg_offsets; seg_offsets.reserve(limit); - // flip bitset since `find_first` & `find_next` is used to find true. - // could be optimized by support find false in bitset. - auto flipped = bitset.clone(); - flipped.flip(); - int64_t offset = 0; - std::optional result = flipped.find_first(); + std::optional result = bitset.find_first(false); while (result.has_value() && hit_num < limit) { hit_num++; seg_offsets.push_back(result.value()); @@ -1738,7 +1733,7 @@ SegmentSealedImpl::find_first(int64_t limit, const BitsetType& bitset) const { // In fact, this case won't happen on sealed segments. continue; } - result = flipped.find_next(offset); + result = bitset.find_next(offset, false); } return {seg_offsets, more_hit_than_limit && result.has_value()}; diff --git a/internal/core/unittest/test_bitset.cpp b/internal/core/unittest/test_bitset.cpp index c307fc33df..e4decc751c 100644 --- a/internal/core/unittest/test_bitset.cpp +++ b/internal/core/unittest/test_bitset.cpp @@ -346,7 +346,7 @@ from_i32(const int32_t i) { // template void -TestFindImpl(BitsetT& bitset, const size_t max_v) { +TestFindImpl(BitsetT& bitset, const size_t max_v, const bool is_set) { const size_t n = bitset.size(); std::default_random_engine rng(123); @@ -361,9 +361,13 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) { } } + if (!is_set) { + bitset.flip(); + } + StopWatch sw; - auto bit_idx = bitset.find_first(); + auto bit_idx = bitset.find_first(is_set); if (!bit_idx.has_value()) { ASSERT_EQ(one_pos.size(), 0); return; @@ -372,7 +376,7 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) { for (size_t i = 0; i < one_pos.size(); i++) { ASSERT_TRUE(bit_idx.has_value()) << n << ", " << max_v; ASSERT_EQ(bit_idx.value(), one_pos[i]) << n << ", " << max_v; - bit_idx = bitset.find_next(bit_idx.value()); + bit_idx = bitset.find_next(bit_idx.value(), is_set); } ASSERT_FALSE(bit_idx.has_value()) @@ -387,32 +391,40 @@ template void TestFindImpl() { for (const size_t n : typical_sizes) { - for (const size_t pr : {1, 100}) { - BitsetT bitset(n); - bitset.reset(); - - if (print_log) { - printf("Testing bitset, n=%zd, pr=%zd\n", n, pr); - } - - TestFindImpl(bitset, pr); - - for (const size_t offset : typical_offsets) { - if (offset >= n) { - continue; - } - + for (const bool is_set : {true, false}) { + for (const size_t pr : {1, 100}) { + BitsetT bitset(n); bitset.reset(); - auto view = bitset.view(offset); if (print_log) { - printf("Testing bitset view, n=%zd, offset=%zd, pr=%zd\n", + printf("Testing bitset, n=%zd, is_set=%d, pr=%zd\n", n, - offset, + (is_set) ? 1 : 0, pr); } - TestFindImpl(view, pr); + TestFindImpl(bitset, pr, is_set); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + + if (print_log) { + printf( + "Testing bitset view, n=%zd, offset=%zd, " + "is_set=%d, pr=%zd\n", + n, + offset, + (is_set) ? 1 : 0, + pr); + } + + TestFindImpl(view, pr, is_set); + } } } }