mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
enhance: [2.5][bitset] extend op_find() to be able to search both 0 and 1 (#39237)
Cherry-pick from master pr: #39176 #39229 issue: #39124 `bitset::find_first()` and `bitset::find_next()` now accept one more parameter, which allows to search for `0` bit instead of `1` bit --------- Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com> Signed-off-by: Congqi Xia <congqi.xia@zilliz.com> Co-authored-by: Alexander Guzhva <alexanderguzhva@gmail.com>
This commit is contained in:
parent
e6ac2fe063
commit
b4ddf746c1
@ -22,8 +22,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
|
||||
detail/platform/x86/instruction_set.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
|
||||
set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")
|
||||
set_source_files_properties(detail/platform/x86/avx512-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq -mavx512cd -mbmi")
|
||||
set_source_files_properties(detail/platform/x86/avx2-inst.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma -mbmi")
|
||||
|
||||
# set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512bw -mavx512vl -mavx512dq")
|
||||
# set_source_files_properties(detail/platform/dynamic.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx -mfma")
|
||||
|
||||
@ -546,23 +546,26 @@ class BitsetBase {
|
||||
return as_derived();
|
||||
}
|
||||
|
||||
// Find the index of the first bit set to true.
|
||||
// Find the index of the first bit set to either true (default), or false.
|
||||
inline std::optional<size_t>
|
||||
find_first() const {
|
||||
find_first(const bool is_set = true) const {
|
||||
return policy_type::op_find(
|
||||
this->data(), this->offset(), this->size(), 0);
|
||||
this->data(), this->offset(), this->size(), 0, is_set);
|
||||
}
|
||||
|
||||
// Find the index of the first bit set to true, starting from a given bit index.
|
||||
// Find the index of the first bit set to either true (default), or false, starting from a given bit index.
|
||||
inline std::optional<size_t>
|
||||
find_next(const size_t starting_bit_idx) const {
|
||||
find_next(const size_t starting_bit_idx, const bool is_set = true) const {
|
||||
const size_t size_v = this->size();
|
||||
if (starting_bit_idx + 1 >= size_v) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return policy_type::op_find(
|
||||
this->data(), this->offset(), this->size(), starting_bit_idx + 1);
|
||||
return policy_type::op_find(this->data(),
|
||||
this->offset(),
|
||||
this->size(),
|
||||
starting_bit_idx + 1,
|
||||
is_set);
|
||||
}
|
||||
|
||||
// Read multiple bits starting from a given bit index.
|
||||
|
||||
@ -315,10 +315,11 @@ struct BitWiseBitsetPolicy {
|
||||
op_find(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
const size_t starting_idx,
|
||||
const bool is_set) {
|
||||
for (size_t i = starting_idx; i < size; i++) {
|
||||
const auto proxy = get_proxy(data, start + i);
|
||||
if (proxy) {
|
||||
if (proxy == is_set) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
@ -220,9 +220,10 @@ struct VectorizedElementWiseBitsetPolicy {
|
||||
op_find(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
const size_t starting_idx,
|
||||
const bool is_set) {
|
||||
return ElementWiseBitsetPolicy<ElementT>::op_find(
|
||||
data, start, size, starting_idx);
|
||||
data, start, size, starting_idx, is_set);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@ -718,10 +718,10 @@ struct ElementWiseBitsetPolicy {
|
||||
|
||||
//
|
||||
static inline std::optional<size_t>
|
||||
op_find(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
op_find_1(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
if (size == 0) {
|
||||
return std::nullopt;
|
||||
}
|
||||
@ -791,6 +791,91 @@ struct ElementWiseBitsetPolicy {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static inline std::optional<size_t>
|
||||
op_find_0(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx) {
|
||||
if (size == 0) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
//
|
||||
auto start_element = get_element(start + starting_idx);
|
||||
const auto end_element = get_element(start + size);
|
||||
|
||||
const auto start_shift = get_shift(start + starting_idx);
|
||||
const auto end_shift = get_shift(start + size);
|
||||
|
||||
// same element?
|
||||
if (start_element == end_element) {
|
||||
const data_type existing_v = ~data[start_element];
|
||||
|
||||
const data_type existing_mask = get_shift_mask_end(start_shift) &
|
||||
get_shift_mask_begin(end_shift);
|
||||
|
||||
const data_type value = existing_v & existing_mask;
|
||||
if (value != 0) {
|
||||
const auto ctz = CtzHelper<data_type>::ctz(value);
|
||||
return size_t(ctz) + start_element * data_bits - start;
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
// process the first element
|
||||
if (start_shift != 0) {
|
||||
const data_type existing_v = ~data[start_element];
|
||||
const data_type existing_mask = get_shift_mask_end(start_shift);
|
||||
|
||||
const data_type value = existing_v & existing_mask;
|
||||
if (value != 0) {
|
||||
const auto ctz = CtzHelper<data_type>::ctz(value) +
|
||||
start_element * data_bits - start;
|
||||
return size_t(ctz);
|
||||
}
|
||||
|
||||
start_element += 1;
|
||||
}
|
||||
|
||||
// process the middle
|
||||
for (size_t i = start_element; i < end_element; i++) {
|
||||
const data_type value = ~data[i];
|
||||
if (value != 0) {
|
||||
const auto ctz = CtzHelper<data_type>::ctz(value);
|
||||
return size_t(ctz) + i * data_bits - start;
|
||||
}
|
||||
}
|
||||
|
||||
// process the last element
|
||||
if (end_shift != 0) {
|
||||
const data_type existing_v = ~data[end_element];
|
||||
const data_type existing_mask = get_shift_mask_begin(end_shift);
|
||||
|
||||
const data_type value = existing_v & existing_mask;
|
||||
if (value != 0) {
|
||||
const auto ctz = CtzHelper<data_type>::ctz(value);
|
||||
return size_t(ctz) + end_element * data_bits - start;
|
||||
}
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
//
|
||||
static inline std::optional<size_t>
|
||||
op_find(const data_type* const data,
|
||||
const size_t start,
|
||||
const size_t size,
|
||||
const size_t starting_idx,
|
||||
const bool is_set) {
|
||||
if (is_set) {
|
||||
return op_find_1(data, start, size, starting_idx);
|
||||
} else {
|
||||
return op_find_0(data, start, size, starting_idx);
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
template <typename T, typename U, CompareOpType Op>
|
||||
static inline void
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace milvus {
|
||||
namespace bitset {
|
||||
namespace detail {
|
||||
|
||||
@ -1255,12 +1255,8 @@ ChunkedSegmentSealedImpl::find_first(int64_t limit,
|
||||
std::vector<int64_t> seg_offsets;
|
||||
seg_offsets.reserve(limit);
|
||||
|
||||
// flip bitset since `find_next` is used to find true.
|
||||
auto flipped = bitset.clone();
|
||||
flipped.flip();
|
||||
|
||||
int64_t offset = 0;
|
||||
std::optional<size_t> result = flipped.find_first();
|
||||
std::optional<size_t> result = bitset.find_first(false);
|
||||
while (result.has_value() && hit_num < limit) {
|
||||
hit_num++;
|
||||
seg_offsets.push_back(result.value());
|
||||
@ -1269,7 +1265,7 @@ ChunkedSegmentSealedImpl::find_first(int64_t limit,
|
||||
// In fact, this case won't happen on sealed segments.
|
||||
continue;
|
||||
}
|
||||
result = flipped.find_next(offset);
|
||||
result = bitset.find_next(offset, false);
|
||||
}
|
||||
|
||||
return {seg_offsets, more_hit_than_limit && result.has_value()};
|
||||
|
||||
@ -1723,13 +1723,8 @@ SegmentSealedImpl::find_first(int64_t limit, const BitsetType& bitset) const {
|
||||
std::vector<int64_t> seg_offsets;
|
||||
seg_offsets.reserve(limit);
|
||||
|
||||
// flip bitset since `find_first` & `find_next` is used to find true.
|
||||
// could be optimized by support find false in bitset.
|
||||
auto flipped = bitset.clone();
|
||||
flipped.flip();
|
||||
|
||||
int64_t offset = 0;
|
||||
std::optional<size_t> result = flipped.find_first();
|
||||
std::optional<size_t> result = bitset.find_first(false);
|
||||
while (result.has_value() && hit_num < limit) {
|
||||
hit_num++;
|
||||
seg_offsets.push_back(result.value());
|
||||
@ -1738,7 +1733,7 @@ SegmentSealedImpl::find_first(int64_t limit, const BitsetType& bitset) const {
|
||||
// In fact, this case won't happen on sealed segments.
|
||||
continue;
|
||||
}
|
||||
result = flipped.find_next(offset);
|
||||
result = bitset.find_next(offset, false);
|
||||
}
|
||||
|
||||
return {seg_offsets, more_hit_than_limit && result.has_value()};
|
||||
|
||||
@ -346,7 +346,7 @@ from_i32(const int32_t i) {
|
||||
//
|
||||
template <typename BitsetT>
|
||||
void
|
||||
TestFindImpl(BitsetT& bitset, const size_t max_v) {
|
||||
TestFindImpl(BitsetT& bitset, const size_t max_v, const bool is_set) {
|
||||
const size_t n = bitset.size();
|
||||
|
||||
std::default_random_engine rng(123);
|
||||
@ -361,9 +361,13 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_set) {
|
||||
bitset.flip();
|
||||
}
|
||||
|
||||
StopWatch sw;
|
||||
|
||||
auto bit_idx = bitset.find_first();
|
||||
auto bit_idx = bitset.find_first(is_set);
|
||||
if (!bit_idx.has_value()) {
|
||||
ASSERT_EQ(one_pos.size(), 0);
|
||||
return;
|
||||
@ -372,7 +376,7 @@ TestFindImpl(BitsetT& bitset, const size_t max_v) {
|
||||
for (size_t i = 0; i < one_pos.size(); i++) {
|
||||
ASSERT_TRUE(bit_idx.has_value()) << n << ", " << max_v;
|
||||
ASSERT_EQ(bit_idx.value(), one_pos[i]) << n << ", " << max_v;
|
||||
bit_idx = bitset.find_next(bit_idx.value());
|
||||
bit_idx = bitset.find_next(bit_idx.value(), is_set);
|
||||
}
|
||||
|
||||
ASSERT_FALSE(bit_idx.has_value())
|
||||
@ -387,32 +391,40 @@ template <typename BitsetT>
|
||||
void
|
||||
TestFindImpl() {
|
||||
for (const size_t n : typical_sizes) {
|
||||
for (const size_t pr : {1, 100}) {
|
||||
BitsetT bitset(n);
|
||||
bitset.reset();
|
||||
|
||||
if (print_log) {
|
||||
printf("Testing bitset, n=%zd, pr=%zd\n", n, pr);
|
||||
}
|
||||
|
||||
TestFindImpl(bitset, pr);
|
||||
|
||||
for (const size_t offset : typical_offsets) {
|
||||
if (offset >= n) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const bool is_set : {true, false}) {
|
||||
for (const size_t pr : {1, 100}) {
|
||||
BitsetT bitset(n);
|
||||
bitset.reset();
|
||||
auto view = bitset.view(offset);
|
||||
|
||||
if (print_log) {
|
||||
printf("Testing bitset view, n=%zd, offset=%zd, pr=%zd\n",
|
||||
printf("Testing bitset, n=%zd, is_set=%d, pr=%zd\n",
|
||||
n,
|
||||
offset,
|
||||
(is_set) ? 1 : 0,
|
||||
pr);
|
||||
}
|
||||
|
||||
TestFindImpl(view, pr);
|
||||
TestFindImpl(bitset, pr, is_set);
|
||||
|
||||
for (const size_t offset : typical_offsets) {
|
||||
if (offset >= n) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bitset.reset();
|
||||
auto view = bitset.view(offset);
|
||||
|
||||
if (print_log) {
|
||||
printf(
|
||||
"Testing bitset view, n=%zd, offset=%zd, "
|
||||
"is_set=%d, pr=%zd\n",
|
||||
n,
|
||||
offset,
|
||||
(is_set) ? 1 : 0,
|
||||
pr);
|
||||
}
|
||||
|
||||
TestFindImpl(view, pr, is_set);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user