// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "common.h" #include "detail/maybe_vector.h" namespace milvus { namespace bitset { namespace { // A supporting facility for checking out of range. // It is needed to add a capability to verify that we won't go out of // range even for the Release build. template struct RangeChecker {}; // disabled. template <> struct RangeChecker { // Check if a < max template static inline void lt(const SizeT a, const SizeT max) { } // Check if a <= max template static inline void le(const SizeT a, const SizeT max) { } // Check if a == b template static inline void eq(const SizeT a, const SizeT b) { } }; // enabled. template <> struct RangeChecker { // Check if a < max template static inline void lt(const SizeT a, const SizeT max) { // todo: replace assert(a < max); } // Check if a <= max template static inline void le(const SizeT a, const SizeT max) { // todo: replace assert(a <= max); } // Check if a == b template static inline void eq(const SizeT a, const SizeT b) { // todo: replace assert(a == b); } }; } // namespace // CRTP // Bitset view, which does not own the data. template class BitsetView; // Bitset, which owns the data. template class Bitset; // This is the base CRTP class. template class BitsetBase { template friend class BitsetView; template friend class Bitset; public: using policy_type = PolicyT; using data_type = typename policy_type::data_type; using proxy_type = typename policy_type::proxy_type; using const_proxy_type = typename policy_type::const_proxy_type; using range_checker = RangeChecker; // inline data_type* data() { return as_derived().data_impl(); } // inline const data_type* data() const { return as_derived().data_impl(); } // Return the number of bits we're working with. inline size_t size() const { return as_derived().size_impl(); } // Return the number of bytes which is needed to // contain all our bits. inline size_t size_in_bytes() const { return policy_type::get_required_size_in_bytes(this->size()); } // Return the number of elements which is needed to // contain all our bits. inline size_t size_in_elements() const { return policy_type::get_required_size_in_elements(this->size()); } // inline bool empty() const { return (this->size() == 0); } // inline proxy_type operator[](const size_t bit_idx) { range_checker::lt(bit_idx, this->size()); const size_t idx_v = bit_idx + this->offset(); return policy_type::get_proxy(this->data(), idx_v); } // inline bool operator[](const size_t bit_idx) const { range_checker::lt(bit_idx, this->size()); const size_t idx_v = bit_idx + this->offset(); const auto proxy = policy_type::get_proxy(this->data(), idx_v); return proxy.operator bool(); } // Set all bits to true. inline void set() { policy_type::op_set(this->data(), this->offset(), this->size()); } // Set a given bit to a given value. inline void set(const size_t bit_idx, const bool value = true) { this->operator[](bit_idx) = value; } // Set a given range of [a, b) bits to a given value. inline void set(const size_t bit_idx_start, const size_t size, const bool value = true) { range_checker::le(bit_idx_start + size, this->size()); policy_type::op_fill( this->data(), this->offset() + bit_idx_start, size, value); } // Set all bits to false. inline void reset() { policy_type::op_reset(this->data(), this->offset(), this->size()); } // Set a given bit to false. inline void reset(const size_t bit_idx) { this->operator[](bit_idx) = false; } // Set a given range of [a, b) bits to false. inline void reset(const size_t bit_idx_start, const size_t size) { this->set(bit_idx_start, size, false); } // Return whether all bits are set to true. inline bool all() const { return policy_type::op_all(this->data(), this->offset(), this->size()); } // Return whether any of the bits is set to true. inline bool any() const { return (!this->none()); } // Return whether all bits are set to false. inline bool none() const { return policy_type::op_none(this->data(), this->offset(), this->size()); } // Inplace and. template inline void inplace_and(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); policy_type::op_and( this->data(), other.data(), this->offset(), other.offset(), size); } template inline void inplace_and(const BitsetView* const others, const size_t n_others, const size_t size) { range_checker::le(size, this->size()); for (size_t i = 0; i < n_others; i++) { range_checker::le(size, others[i].size()); } // pick buffers detail::MaybeVector tmp_data(n_others); detail::MaybeVector tmp_offset(n_others); for (size_t i = 0; i < n_others; i++) { tmp_data[i] = others[i].data(); tmp_offset[i] = others[i].offset(); } policy_type::op_and_multiple(this->data(), tmp_data.data(), this->offset(), tmp_offset.data(), n_others, size); } template inline void inplace_and(const BitsetView* const others, const size_t n_others) { this->inplace_and(others, n_others, this->size()); } template inline void inplace_and(const Bitset* const others, const size_t n_others, const size_t size) { range_checker::le(size, this->size()); for (size_t i = 0; i < n_others; i++) { range_checker::le(size, others[i].size()); } // pick buffers detail::MaybeVector tmp_data(n_others); detail::MaybeVector tmp_offset(n_others); for (size_t i = 0; i < n_others; i++) { tmp_data[i] = others[i].data(); tmp_offset[i] = others[i].offset(); } policy_type::op_and_multiple(this->data(), tmp_data.data(), this->offset(), tmp_offset.data(), n_others, size); } template inline void inplace_and(const Bitset* const others, const size_t n_others) { this->inplace_and(others, n_others, this->size()); } // Inplace and. A given bitset / bitset view is expected to have the same size. template inline ImplT& operator&=(const BitsetBase& other) { range_checker::eq(other.size(), this->size()); this->inplace_and(other, this->size()); return as_derived(); } // Inplace or. template inline void inplace_or(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); policy_type::op_or( this->data(), other.data(), this->offset(), other.offset(), size); } template inline void inplace_or(const BitsetView* const others, const size_t n_others, const size_t size) { range_checker::le(size, this->size()); for (size_t i = 0; i < n_others; i++) { range_checker::le(size, others[i].size()); } // pick buffers detail::MaybeVector tmp_data(n_others); detail::MaybeVector tmp_offset(n_others); for (size_t i = 0; i < n_others; i++) { tmp_data[i] = others[i].data(); tmp_offset[i] = others[i].offset(); } policy_type::op_or_multiple(this->data(), tmp_data.data(), this->offset(), tmp_offset.data(), n_others, size); } template inline void inplace_or(const BitsetView* const others, const size_t n_others) { this->inplace_or(others, n_others, this->size()); } template inline void inplace_or(const Bitset* const others, const size_t n_others, const size_t size) { range_checker::le(size, this->size()); for (size_t i = 0; i < n_others; i++) { range_checker::le(size, others[i].size()); } // pick buffers detail::MaybeVector tmp_data(n_others); detail::MaybeVector tmp_offset(n_others); for (size_t i = 0; i < n_others; i++) { tmp_data[i] = others[i].data(); tmp_offset[i] = others[i].offset(); } policy_type::op_or_multiple(this->data(), tmp_data.data(), this->offset(), tmp_offset.data(), n_others, size); } template inline void inplace_or(const Bitset* const others, const size_t n_others) { this->inplace_or(others, n_others, this->size()); } // Inplace or. A given bitset / bitset view is expected to have the same size. template inline ImplT& operator|=(const BitsetBase& other) { range_checker::eq(other.size(), this->size()); this->inplace_or(other, this->size()); return as_derived(); } // Revert all bits. inline void flip() { policy_type::op_flip(this->data(), this->offset(), this->size()); } // inline BitsetView operator+(const size_t offset) { return this->view(offset); } // Create a view of a given size from the given position. inline BitsetView view(const size_t offset, const size_t size) { range_checker::le(offset, this->size()); range_checker::le(offset + size, this->size()); return BitsetView( this->data(), this->offset() + offset, size); } // Create a const view of a given size from the given position. inline BitsetView view(const size_t offset, const size_t size) const { range_checker::le(offset, this->size()); range_checker::le(offset + size, this->size()); return BitsetView( const_cast(this->data()), this->offset() + offset, size); } // Create a view from the given position, which uses all available size. inline BitsetView view(const size_t offset) { range_checker::le(offset, this->size()); return BitsetView( this->data(), this->offset() + offset, this->size() - offset); } // Create a const view from the given position, which uses all available size. inline const BitsetView view(const size_t offset) const { range_checker::le(offset, this->size()); return BitsetView( const_cast(this->data()), this->offset() + offset, this->size() - offset); } // Create a view. inline BitsetView view() { return this->view(0); } // Create a const view. inline const BitsetView view() const { return this->view(0); } // Return the number of bits which are set to true. inline size_t count() const { return policy_type::op_count( this->data(), this->offset(), this->size()); } // Compare the current bitset with another bitset / bitset view. template inline bool operator==(const BitsetBase& other) { if (this->size() != other.size()) { return false; } return policy_type::op_eq(this->data(), other.data(), this->offset(), other.offset(), this->size()); } // Compare the current bitset with another bitset / bitset view. template inline bool operator!=(const BitsetBase& other) { return (!(*this == other)); } // Inplace xor. template inline void inplace_xor(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); policy_type::op_xor( this->data(), other.data(), this->offset(), other.offset(), size); } // Inplace xor. A given bitset / bitset view is expected to have the same size. template inline ImplT& operator^=(const BitsetBase& other) { range_checker::eq(other.size(), this->size()); this->inplace_xor(other, this->size()); return as_derived(); } // Inplace sub. template inline void inplace_sub(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); policy_type::op_sub( this->data(), other.data(), this->offset(), other.offset(), size); } // Inplace sub. A given bitset / bitset view is expected to have the same size. template inline ImplT& operator-=(const BitsetBase& other) { range_checker::eq(other.size(), this->size()); this->inplace_sub(other, this->size()); return as_derived(); } // Find the index of the first bit set to either true (default), or false. inline std::optional find_first(const bool is_set = true) const { return policy_type::op_find( this->data(), this->offset(), this->size(), 0, is_set); } // Find the index of the first bit set to either true (default), or false, starting from a given bit index. inline std::optional find_next(const size_t starting_bit_idx, const bool is_set = true) const { const size_t size_v = this->size(); if (starting_bit_idx + 1 >= size_v) { return std::nullopt; } return policy_type::op_find(this->data(), this->offset(), this->size(), starting_bit_idx + 1, is_set); } // Read multiple bits starting from a given bit index. inline data_type read(const size_t starting_bit_idx, const size_t nbits) { range_checker::le(nbits, sizeof(data_type)); return policy_type::op_read( this->data(), this->offset() + starting_bit_idx, nbits); } // Write multiple bits starting from a given bit index. inline void write(const size_t starting_bit_idx, const data_type value, const size_t nbits) { range_checker::le(nbits, sizeof(data_type)); policy_type::op_write( this->data(), this->offset() + starting_bit_idx, nbits, value); } // Compare two arrays element-wise template void inplace_compare_column(const T* const __restrict t, const U* const __restrict u, const size_t size, CompareOpType op) { if (op == CompareOpType::EQ) { this->inplace_compare_column(t, u, size); } else if (op == CompareOpType::GE) { this->inplace_compare_column(t, u, size); } else if (op == CompareOpType::GT) { this->inplace_compare_column(t, u, size); } else if (op == CompareOpType::LE) { this->inplace_compare_column(t, u, size); } else if (op == CompareOpType::LT) { this->inplace_compare_column(t, u, size); } else if (op == CompareOpType::NE) { this->inplace_compare_column(t, u, size); } else { // unimplemented } } template void inplace_compare_column(const T* const __restrict t, const U* const __restrict u, const size_t size) { range_checker::le(size, this->size()); policy_type::template op_compare_column( this->data(), this->offset(), t, u, size); } // Compare elements of an given array with a given value template void inplace_compare_val(const T* const __restrict t, const size_t size, const T& value, CompareOpType op) { if (op == CompareOpType::EQ) { this->inplace_compare_val(t, size, value); } else if (op == CompareOpType::GE) { this->inplace_compare_val(t, size, value); } else if (op == CompareOpType::GT) { this->inplace_compare_val(t, size, value); } else if (op == CompareOpType::LE) { this->inplace_compare_val(t, size, value); } else if (op == CompareOpType::LT) { this->inplace_compare_val(t, size, value); } else if (op == CompareOpType::NE) { this->inplace_compare_val(t, size, value); } else { // unimplemented } } template void inplace_compare_val(const T* const __restrict t, const size_t size, const T& value) { range_checker::le(size, this->size()); policy_type::template op_compare_val( this->data(), this->offset(), t, size, value); } // template void inplace_within_range_column(const T* const __restrict lower, const T* const __restrict upper, const T* const __restrict values, const size_t size, const RangeType op) { if (op == RangeType::IncInc) { this->inplace_within_range_column( lower, upper, values, size); } else if (op == RangeType::IncExc) { this->inplace_within_range_column( lower, upper, values, size); } else if (op == RangeType::ExcInc) { this->inplace_within_range_column( lower, upper, values, size); } else if (op == RangeType::ExcExc) { this->inplace_within_range_column( lower, upper, values, size); } else { // unimplemented } } template void inplace_within_range_column(const T* const __restrict lower, const T* const __restrict upper, const T* const __restrict values, const size_t size) { range_checker::le(size, this->size()); policy_type::template op_within_range_column( this->data(), this->offset(), lower, upper, values, size); } // template void inplace_within_range_val(const T& lower, const T& upper, const T* const __restrict values, const size_t size, const RangeType op) { if (op == RangeType::IncInc) { this->inplace_within_range_val( lower, upper, values, size); } else if (op == RangeType::IncExc) { this->inplace_within_range_val( lower, upper, values, size); } else if (op == RangeType::ExcInc) { this->inplace_within_range_val( lower, upper, values, size); } else if (op == RangeType::ExcExc) { this->inplace_within_range_val( lower, upper, values, size); } else { // unimplemented } } template void inplace_within_range_val(const T& lower, const T& upper, const T* const __restrict values, const size_t size) { range_checker::le(size, this->size()); policy_type::template op_within_range_val( this->data(), this->offset(), lower, upper, values, size); } // template void inplace_arith_compare(const T* const __restrict src, const ArithHighPrecisionType& right_operand, const ArithHighPrecisionType& value, const size_t size, const ArithOpType a_op, const CompareOpType cmp_op) { if (a_op == ArithOpType::Add) { if (cmp_op == CompareOpType::EQ) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::NE) { this->inplace_arith_compare( src, right_operand, value, size); } else { // unimplemented } } else if (a_op == ArithOpType::Sub) { if (cmp_op == CompareOpType::EQ) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::NE) { this->inplace_arith_compare( src, right_operand, value, size); } else { // unimplemented } } else if (a_op == ArithOpType::Mul) { if (cmp_op == CompareOpType::EQ) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::NE) { this->inplace_arith_compare( src, right_operand, value, size); } else { // unimplemented } } else if (a_op == ArithOpType::Div) { if (cmp_op == CompareOpType::EQ) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::NE) { this->inplace_arith_compare( src, right_operand, value, size); } else { // unimplemented } } else if (a_op == ArithOpType::Mod) { if (cmp_op == CompareOpType::EQ) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::GT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LE) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::LT) { this->inplace_arith_compare( src, right_operand, value, size); } else if (cmp_op == CompareOpType::NE) { this->inplace_arith_compare( src, right_operand, value, size); } else { // unimplemented } } else { // unimplemented } } template void inplace_arith_compare(const T* const __restrict src, const ArithHighPrecisionType& right_operand, const ArithHighPrecisionType& value, const size_t size) { range_checker::le(size, this->size()); policy_type::template op_arith_compare( this->data(), this->offset(), src, right_operand, value, size); } // // Inplace and. Also, counts the number of active bits. template inline size_t inplace_and_with_count(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); return policy_type::op_and_with_count( this->data(), other.data(), this->offset(), other.offset(), size); } // Inplace or. Also, counts the number of inactive bits. template inline size_t inplace_or_with_count(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); return policy_type::op_or_with_count( this->data(), other.data(), this->offset(), other.offset(), size); } // Return the starting bit offset in our container. inline size_t offset() const { return as_derived().offset_impl(); } private: // CRTP inline ImplT& as_derived() { return static_cast(*this); } // CRTP inline const ImplT& as_derived() const { return static_cast(*this); } }; // Bitset view template class BitsetView : public BitsetBase, IsRangeCheckEnabled> { friend class BitsetBase, IsRangeCheckEnabled>; public: using policy_type = PolicyT; using data_type = typename policy_type::data_type; using proxy_type = typename policy_type::proxy_type; using const_proxy_type = typename policy_type::const_proxy_type; using range_checker = RangeChecker; BitsetView() = default; BitsetView(const BitsetView&) = default; BitsetView(BitsetView&&) = default; BitsetView& operator=(const BitsetView&) = default; BitsetView& operator=(BitsetView&&) = default; template explicit BitsetView(BitsetBase& bitset) : Data{bitset.data()}, Size{bitset.size()}, Offset{bitset.offset()} { } BitsetView(void* data, const size_t size) : Data{reinterpret_cast(data)}, Size{size} { } BitsetView(void* data, const size_t offset, const size_t size) : Data{reinterpret_cast(data)}, Size{size}, Offset{offset} { } private: // the referenced bits are [Offset, Offset + Size) data_type* Data = nullptr; // measured in bits size_t Size = 0; // measured in bits size_t Offset = 0; inline data_type* data_impl() { return Data; } inline const data_type* data_impl() const { return Data; } inline size_t size_impl() const { return Size; } inline size_t offset_impl() const { return Offset; } }; // Bitset template class Bitset : public BitsetBase, IsRangeCheckEnabled> { friend class BitsetBase, IsRangeCheckEnabled>; public: using policy_type = PolicyT; using data_type = typename policy_type::data_type; using proxy_type = typename policy_type::proxy_type; using const_proxy_type = typename policy_type::const_proxy_type; using view_type = BitsetView; // This is the container type. using container_type = ContainerT; // This is how the data is stored. For example, we may operate using // uint64_t values, but store the data in std::vector container. // This is useful if we need to convert a bitset into a container // using move operator. using container_data_type = typename container_type::value_type; using range_checker = RangeChecker; // Allocate an empty one. Bitset() = default; // Allocate the given number of bits. explicit Bitset(const size_t size) : Data(get_required_size_in_container_elements(size)), Size{size} { } // Allocate the given number of bits, initialize with a given value. Bitset(const size_t size, const bool init) : Data(get_required_size_in_container_elements(size), init ? data_type(-1) : 0), Size{size} { } // Do not allow implicit copies (Rust style). Bitset(const Bitset&) = delete; // Allow default move. Bitset(Bitset&&) = default; // Do not allow implicit copies (Rust style). Bitset& operator=(const Bitset&) = delete; // Allow default move. Bitset& operator=(Bitset&&) = default; template explicit Bitset(const BitsetBase& other) { Data = container_type( get_required_size_in_container_elements(other.size())); Size = other.size(); policy_type::op_copy(other.data(), other.offset(), this->data(), this->offset(), other.size()); } // Clone a current bitset (Rust style). Bitset clone() const { Bitset cloned; cloned.Data = Data; cloned.Size = Size; return cloned; } // Rust style. inline container_type into() && { return std::move(this->Data); } // Resize. void resize(const size_t new_size) { const size_t new_size_in_container_elements = get_required_size_in_container_elements(new_size); Data.resize(new_size_in_container_elements); Size = new_size; } // Resize and initialize new bits with a given value if grown. void resize(const size_t new_size, const bool init) { const size_t old_size = this->size(); this->resize(new_size); if (new_size > old_size) { policy_type::op_fill( this->data(), old_size, new_size - old_size, init); } } // Append data from another bitset / bitset view in // [starting_bit_idx, starting_bit_idx + count) range // to the end of this bitset. template void append(const BitsetBase& other, const size_t starting_bit_idx, const size_t count) { range_checker::le(starting_bit_idx, other.size()); const size_t old_size = this->size(); this->resize(this->size() + count); policy_type::op_copy(other.data(), other.offset() + starting_bit_idx, this->data(), this->offset() + old_size, count); } // Append data from another bitset / bitset view // to the end of this bitset. template void append(const BitsetBase& other) { this->append(other, 0, other.size()); } // Make bitset empty. inline void clear() { Data.clear(); Size = 0; } // Reserve inline void reserve(const size_t capacity) { const size_t capacity_in_container_elements = get_required_size_in_container_elements(capacity); Data.reserve(capacity_in_container_elements); } // Return a new bitset, equal to a | b template friend Bitset operator|(const BitsetBase& a, const BitsetBase& b) { Bitset clone(a); return std::move(clone |= b); } // Return a new bitset, equal to a - b template friend Bitset operator-(const BitsetBase& a, const BitsetBase& b) { Bitset clone(a); return std::move(clone -= b); } protected: // the container container_type Data; // the actual number of bits size_t Size = 0; inline data_type* data_impl() { return reinterpret_cast(Data.data()); } inline const data_type* data_impl() const { return reinterpret_cast(Data.data()); } inline size_t size_impl() const { return Size; } inline size_t offset_impl() const { return 0; } // static inline size_t get_required_size_in_container_elements(const size_t size) { const size_t size_in_bytes = policy_type::get_required_size_in_bytes(size); return (size_in_bytes + sizeof(container_data_type) - 1) / sizeof(container_data_type); } }; } // namespace bitset } // namespace milvus