From 71c0f64a167287aff94d8fe4ffacacab06fc6cb5 Mon Sep 17 00:00:00 2001
From: Alexander Guzhva
Date: Wed, 16 Jul 2025 07:36:52 +0000
Subject: [PATCH] fix: [2.5] fix incorrect bitset for the division comparison
 when the right operand is < 0 (#43180)

issue: #42900
pr: #43179

Upd: also handles Inf and NaN values, as well as the division-by-zero case,
for fp32 and fp64

Signed-off-by: Alexandr Guzhva
---
 internal/core/src/bitset/common.h             |  14 +
 .../bitset/detail/platform/arm/neon-impl.h    | 211 +++++++++---
 .../src/bitset/detail/platform/arm/sve-impl.h | 176 ++++++++--
 .../bitset/detail/platform/x86/avx2-impl.h    | 171 ++++++++--
 .../bitset/detail/platform/x86/avx512-impl.h  | 313 +++++++++++++-----
 .../src/bitset/detail/platform/x86/common.h   |   2 +-
 internal/core/unittest/test_bitset.cpp        | 292 ++++++++++++----
 7 files changed, 926 insertions(+), 253 deletions(-)

diff --git a/internal/core/src/bitset/common.h b/internal/core/src/bitset/common.h
index f747f52246..5a0e6e7741 100644
--- a/internal/core/src/bitset/common.h
+++ b/internal/core/src/bitset/common.h
@@ -157,5 +157,19 @@ struct ArithCompareOperator {
     }
 };
 
+// This handles a special case of the A/B vs C comparison:
+// a multiplication is used instead of a division, so the signs
+// must be inverted and the comparison operator must be flipped
+// whenever the denominator is negative.
+template <CompareOpType CmpOp>
+struct CompareOpDivFlip {
+    static constexpr CompareOpType op =
+        (CmpOp == CompareOpType::LE)   ? CompareOpType::GE
+        : (CmpOp == CompareOpType::LT) ? CompareOpType::GT
+        : (CmpOp == CompareOpType::GE) ? CompareOpType::LE
+        : (CmpOp == CompareOpType::GT) ? CompareOpType::LT
+                                       : CmpOp;
+};
+
 }  // namespace bitset
 }  // namespace milvus
diff --git a/internal/core/src/bitset/detail/platform/arm/neon-impl.h b/internal/core/src/bitset/detail/platform/arm/neon-impl.h
index b8423272dc..a4756b7d2d 100644
--- a/internal/core/src/bitset/detail/platform/arm/neon-impl.h
+++ b/internal/core/src/bitset/detail/platform/arm/neon-impl.h
@@ -21,6 +21,7 @@
 #include
 #include
+#include <cmath>
 #include
 #include
 #include
@@ -1459,14 +1460,25 @@ struct ArithHelperF32 {
 template <CompareOpType CmpOp>
 struct ArithHelperF32<ArithOpType::Div, CmpOp> {
     static inline uint32x4x2_t
-    op(const float32x4x2_t left,
-       const float32x4x2_t right,
-       const float32x4x2_t value) {
+    op_special(const float32x4x2_t left,
+               const float32x4x2_t right,
+               const float32x4x2_t value) {
+        // this is valid for a positive denominator, and for the == and != cases.
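CompareOpDivFlip only mirrors the ordering operators (LE/GE, LT/GT) and leaves EQ and NE alone. The identity it relies on: for a finite numerator and a strictly negative finite denominator, negating the denominator and the threshold while mirroring the operator preserves the result, because a / (-b) equals -(a / b) exactly in IEEE-754. An illustrative standalone check of that identity (plain C++, not taken from the patch):

    #include <cassert>
    #include <vector>

    int main() {
        // For finite a, c and a strictly negative denominator b,
        // a / (-b) is exactly -(a / b), so the comparison can be rewritten
        // against a positive denominator by negating the threshold and
        // mirroring the operator (LT<->GT, LE<->GE); EQ and NE are unchanged.
        const std::vector<float> samples = {-7.5f, -1.0f, -0.25f, 0.0f, 0.25f, 1.0f, 7.5f};
        for (float a : samples) {
            for (float b : {-8.0f, -2.0f, -0.5f}) {
                for (float c : samples) {
                    assert((a / b < c) == (a / -b > -c));    // LT -> GT
                    assert((a / b >= c) == (a / -b <= -c));  // GE -> LE
                    assert((a / b == c) == (a / -b == -c));  // EQ unchanged
                }
            }
        }
        return 0;
    }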
// left == right * value const float32x4x2_t rv = {vmulq_f32(right.val[0], value.val[0]), vmulq_f32(right.val[1], value.val[1])}; return CmpHelper::compare(left, rv); } + + static inline uint32x4x2_t + op(const float32x4x2_t left, + const float32x4x2_t right, + const float32x4x2_t value) { + // left / right == value + const float32x4x2_t rv = {vdivq_f32(left.val[0], right.val[0]), + vdivq_f32(left.val[1], right.val[1])}; + return CmpHelper::compare(rv, value); + } }; // @@ -1521,9 +1533,9 @@ struct ArithHelperF64 { template struct ArithHelperF64 { static inline uint64x2x4_t - op(const float64x2x4_t left, - const float64x2x4_t right, - const float64x2x4_t value) { + op_special(const float64x2x4_t left, + const float64x2x4_t right, + const float64x2x4_t value) { // left == right * value const float64x2x4_t rv = {vmulq_f64(right.val[0], value.val[0]), vmulq_f64(right.val[1], value.val[1]), @@ -1531,6 +1543,18 @@ struct ArithHelperF64 { vmulq_f64(right.val[3], value.val[3])}; return CmpHelper::compare(left, rv); } + + static inline uint64x2x4_t + op(const float64x2x4_t left, + const float64x2x4_t right, + const float64x2x4_t value) { + // left / right == value + const float64x2x4_t rv = {vdivq_f64(left.val[0], right.val[0]), + vdivq_f64(left.val[1], right.val[1]), + vdivq_f64(left.val[2], right.val[2]), + vdivq_f64(left.val[3], right.val[3])}; + return CmpHelper::compare(rv, value); + } }; } // namespace @@ -1743,28 +1767,74 @@ OpArithCompareImpl::op_arith_compare( if constexpr (AOp == ArithOpType::Mod) { return false; } else { - // the restriction of the API - assert((size % 8) == 0); + if constexpr (AOp == ArithOpType::Div) { + if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand > 0) { + // a special case that allows faster processing by using the multiplication + // operation instead of the division one. 
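The guard std::isfinite(value) && std::isfinite(right_operand) && right_operand > 0 is what makes the multiply rewrite safe: with a negative denominator the operator direction must flip, and with a zero or non-finite denominator the quotient produces Inf/NaN values that the product does not reproduce. A few concrete counterexamples (standalone, illustrative C++ only):

    #include <cassert>
    #include <limits>

    int main() {
        const float inf = std::numeric_limits<float>::infinity();
        float zero = 0.0f;

        // 1) Negative denominator: keeping the operator direction after the rewrite is wrong.
        //    6 / -2 < 1 is true (-3 < 1), but 6 < -2 * 1 is false.
        assert((6.0f / -2.0f < 1.0f) && !(6.0f < -2.0f * 1.0f));

        // 2) Zero denominator: 0 / 0 is NaN, so "0 / 0 == 7" is false,
        //    while the rewritten "0 == 0 * 7" is true.
        assert(!(zero / zero == 7.0f) && (zero == zero * 7.0f));

        // 3) Infinite denominator: 1 / Inf == 0 is true,
        //    while the rewritten "1 == Inf * 0" compares against NaN and is false.
        assert((1.0f / inf == 0.0f) && !(1.0f == inf * 0.0f));
        return 0;
    }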
- // - const float32x4x2_t right_v = {vdupq_n_f32(right_operand), - vdupq_n_f32(right_operand)}; - const float32x4x2_t value_v = {vdupq_n_f32(value), vdupq_n_f32(value)}; + // the restriction of the API + assert((size % 8) == 0); - // todo: aligned reads & writes + // + const float32x4x2_t right_v = {vdupq_n_f32(right_operand), + vdupq_n_f32(right_operand)}; + const float32x4x2_t value_v = {vdupq_n_f32(value), + vdupq_n_f32(value)}; - const size_t size8 = (size / 8) * 8; - for (size_t i = 0; i < size8; i += 8) { - const float32x4x2_t v0v = {vld1q_f32(src + i), - vld1q_f32(src + i + 4)}; - const uint32x4x2_t cmp = - ArithHelperF32::op(v0v, right_v, value_v); + // todo: aligned reads & writes - const uint8_t mmask = movemask(cmp); - res_u8[i / 8] = mmask; + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float32x4x2_t v0v = {vld1q_f32(src + i), + vld1q_f32(src + i + 4)}; + const uint32x4x2_t cmp = + ArithHelperF32::op_special( + v0v, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } else if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand < 0) { + // flip signs and go for the multiplication case + return OpArithCompareImpl::op>:: + op_arith_compare(res_u8, src, -right_operand, -value, size); + } + + // go with the default case } - return true; + // a default case + { + // the restriction of the API + assert((size % 8) == 0); + + // + const float32x4x2_t right_v = {vdupq_n_f32(right_operand), + vdupq_n_f32(right_operand)}; + const float32x4x2_t value_v = {vdupq_n_f32(value), + vdupq_n_f32(value)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float32x4x2_t v0v = {vld1q_f32(src + i), + vld1q_f32(src + i + 4)}; + const uint32x4x2_t cmp = + ArithHelperF32::op(v0v, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } } } @@ -1779,35 +1849,86 @@ OpArithCompareImpl::op_arith_compare( if constexpr (AOp == ArithOpType::Mod) { return false; } else { - // the restriction of the API - assert((size % 8) == 0); + if constexpr (AOp == ArithOpType::Div) { + if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand > 0) { + // a special case that allows faster processing by using the multiplication + // operation instead of the division one. 
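Taken together, the Div branch now has three outcomes: the multiply fast path for a finite, strictly positive denominator; a sign-flip recursion for a finite negative denominator (which then lands in the fast path); and a true-division fallback otherwise. A scalar model of that dispatch can serve as a reference oracle for the SIMD kernels; the names below (Cmp, div_compare) are illustrative and not the library's API:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    enum class Cmp { EQ, NE, LT, LE, GT, GE };

    constexpr Cmp flip(Cmp op) {  // same rule as CompareOpDivFlip
        switch (op) {
            case Cmp::LT: return Cmp::GT;
            case Cmp::LE: return Cmp::GE;
            case Cmp::GT: return Cmp::LT;
            case Cmp::GE: return Cmp::LE;
            default: return op;
        }
    }

    bool compare(float x, float y, Cmp op) {
        switch (op) {
            case Cmp::EQ: return x == y;
            case Cmp::NE: return x != y;
            case Cmp::LT: return x < y;
            case Cmp::LE: return x <= y;
            case Cmp::GT: return x > y;
            case Cmp::GE: return x >= y;
        }
        return false;
    }

    // Scalar model of "src[i] / right cmp value" with the same three-way dispatch:
    // multiply shortcut for a finite, strictly positive denominator;
    // sign-flip recursion for a finite, strictly negative denominator;
    // true division for everything else (0, Inf, NaN).
    std::vector<bool>
    div_compare(const std::vector<float>& src, float right, float value, Cmp op) {
        if (std::isfinite(value) && std::isfinite(right) && right < 0) {
            return div_compare(src, -right, -value, flip(op));
        }
        const bool fast = std::isfinite(value) && std::isfinite(right) && right > 0;
        std::vector<bool> out(src.size());
        for (std::size_t i = 0; i < src.size(); i++) {
            out[i] = fast ? compare(src[i], right * value, op)   // multiply shortcut
                          : compare(src[i] / right, value, op);  // exact division semantics
        }
        return out;
    }

    int main() {
        const std::vector<float> src = {6.0f, -6.0f, 0.0f, 3.0f};
        const auto bits = div_compare(src, -2.0f, 1.0f, Cmp::LT);  // src / -2 < 1 ?
        // quotients are {-3, 3, -0, -1.5}, so the expected bits are {1, 0, 1, 1}
        return (bits[0] && !bits[1] && bits[2] && bits[3]) ? 0 : 1;
    }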
- // - const float64x2x4_t right_v = {vdupq_n_f64(right_operand), - vdupq_n_f64(right_operand), - vdupq_n_f64(right_operand), - vdupq_n_f64(right_operand)}; - const float64x2x4_t value_v = {vdupq_n_f64(value), - vdupq_n_f64(value), - vdupq_n_f64(value), - vdupq_n_f64(value)}; + // the restriction of the API + assert((size % 8) == 0); - // todo: aligned reads & writes + // + const float64x2x4_t right_v = {vdupq_n_f64(right_operand), + vdupq_n_f64(right_operand), + vdupq_n_f64(right_operand), + vdupq_n_f64(right_operand)}; + const float64x2x4_t value_v = {vdupq_n_f64(value), + vdupq_n_f64(value), + vdupq_n_f64(value), + vdupq_n_f64(value)}; - const size_t size8 = (size / 8) * 8; - for (size_t i = 0; i < size8; i += 8) { - const float64x2x4_t v0v = {vld1q_f64(src + i), - vld1q_f64(src + i + 2), - vld1q_f64(src + i + 4), - vld1q_f64(src + i + 6)}; - const uint64x2x4_t cmp = - ArithHelperF64::op(v0v, right_v, value_v); + // todo: aligned reads & writes - const uint8_t mmask = movemask(cmp); - res_u8[i / 8] = mmask; + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float64x2x4_t v0v = {vld1q_f64(src + i), + vld1q_f64(src + i + 2), + vld1q_f64(src + i + 4), + vld1q_f64(src + i + 6)}; + const uint64x2x4_t cmp = + ArithHelperF64::op_special( + v0v, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } else if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand < 0) { + // flip signs and go for the multiplication case + return OpArithCompareImpl::op>:: + op_arith_compare(res_u8, src, -right_operand, -value, size); + } + + // go with the default case } - return true; + // a default case + { + // the restriction of the API + assert((size % 8) == 0); + + // + const float64x2x4_t right_v = {vdupq_n_f64(right_operand), + vdupq_n_f64(right_operand), + vdupq_n_f64(right_operand), + vdupq_n_f64(right_operand)}; + const float64x2x4_t value_v = {vdupq_n_f64(value), + vdupq_n_f64(value), + vdupq_n_f64(value), + vdupq_n_f64(value)}; + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const float64x2x4_t v0v = {vld1q_f64(src + i), + vld1q_f64(src + i + 2), + vld1q_f64(src + i + 4), + vld1q_f64(src + i + 6)}; + const uint64x2x4_t cmp = + ArithHelperF64::op(v0v, right_v, value_v); + + const uint8_t mmask = movemask(cmp); + res_u8[i / 8] = mmask; + } + + return true; + } } } diff --git a/internal/core/src/bitset/detail/platform/arm/sve-impl.h b/internal/core/src/bitset/detail/platform/arm/sve-impl.h index c5cd456659..70cfcb0307 100644 --- a/internal/core/src/bitset/detail/platform/arm/sve-impl.h +++ b/internal/core/src/bitset/detail/platform/arm/sve-impl.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -1314,15 +1315,26 @@ struct ArithHelperI64 { }; template -struct ArithHelperI64 { +struct ArithHelperF32 { + static inline svbool_t + op_special(const svbool_t pred, + const svfloat32_t left, + const svfloat32_t right, + const svfloat32_t value) { + // this is valid for the positive denominator, == and != cases. 
+ // left == right * value + return CmpHelper::compare( + pred, left, svmul_f32_z(pred, right, value)); + } + static inline svbool_t op(const svbool_t pred, - const svint64_t left, - const svint64_t right, - const svint64_t value) { + const svfloat32_t left, + const svfloat32_t right, + const svfloat32_t value) { // left / right == value return CmpHelper::compare( - pred, svdiv_s64_z(pred, left, right), value); + pred, svdiv_f32_z(pred, left, right), value); } }; @@ -1371,14 +1383,25 @@ struct ArithHelperF32 { template struct ArithHelperF32 { + static inline svbool_t + op_special(const svbool_t pred, + const svfloat32_t left, + const svfloat32_t right, + const svfloat32_t value) { + // this is valid for the positive denominator, == and != cases. + // left == right * value + return CmpHelper::compare( + pred, left, svmul_f32_z(pred, right, value)); + } + static inline svbool_t op(const svbool_t pred, const svfloat32_t left, const svfloat32_t right, const svfloat32_t value) { - // left == right * value + // left / right == value return CmpHelper::compare( - pred, left, svmul_f32_z(pred, right, value)); + pred, svdiv_f32_z(pred, left, right), value); } }; @@ -1427,14 +1450,25 @@ struct ArithHelperF64 { template struct ArithHelperF64 { + static inline svbool_t + op_special(const svbool_t pred, + const svfloat64_t left, + const svfloat64_t right, + const svfloat64_t value) { + // this is valid for the positive denominator, == and != cases. + // left == right * value + return CmpHelper::compare( + pred, left, svmul_f64_z(pred, right, value)); + } + static inline svbool_t op(const svbool_t pred, const svfloat64_t left, const svfloat64_t right, const svfloat64_t value) { - // left == right * value + // left / right == value return CmpHelper::compare( - pred, left, svmul_f64_z(pred, right, value)); + pred, svdiv_f64_z(pred, left, right), value); } }; @@ -1573,22 +1607,60 @@ OpArithCompareImpl::op_arith_compare( if constexpr (AOp == ArithOpType::Mod) { return false; } else { - using T = float; + if constexpr (AOp == ArithOpType::Div) { + if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand > 0) { + // a special case that allows faster processing by using the multiplication + // operation instead of the division one. 
- auto handler = [src, right_operand, value](const svbool_t pred, - const size_t idx) { - using sve_t = SVEVector; + using T = float; - const auto right_v = svdup_n_f32(right_operand); - const auto value_v = svdup_n_f32(value); - const svfloat32_t src_v = svld1_f32(pred, src + idx); + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; - const svbool_t cmp = - ArithHelperF32::op(pred, src_v, right_v, value_v); - return cmp; - }; + const auto right_v = svdup_n_f32(right_operand); + const auto value_v = svdup_n_f32(value); + const svfloat32_t src_v = svld1_f32(pred, src + idx); - return op_mask_helper(res_u8, size, handler); + const svbool_t cmp = ArithHelperF32::op_special( + pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper( + res_u8, size, handler); + } else if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand < 0) { + // flip signs and go for the multiplication case + return OpArithCompareImpl::op>:: + op_arith_compare(res_u8, src, -right_operand, -value, size); + } + + // go with the default case + } + + // a default case + { + using T = float; + + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto right_v = svdup_n_f32(right_operand); + const auto value_v = svdup_n_f32(value); + const svfloat32_t src_v = svld1_f32(pred, src + idx); + + const svbool_t cmp = ArithHelperF32::op( + pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); + } } } @@ -1603,22 +1675,60 @@ OpArithCompareImpl::op_arith_compare( if constexpr (AOp == ArithOpType::Mod) { return false; } else { - using T = double; + if constexpr (AOp == ArithOpType::Div) { + if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand > 0) { + // a special case that allows faster processing by using the multiplication + // operation instead of the division one. 
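Unlike the NEON and AVX2 kernels, the SVE kernels are predicate-driven: op_mask_helper (not shown in this hunk) hands the lambda a predicate that is full for whole groups and partial for the final group, and the zeroing forms (svdiv_f32_z, svmul_f32_z) leave inactive lanes at zero, so the tail is handled by the predicate rather than by a separate scalar loop. A scalar model of that pattern, with hypothetical names and no SVE intrinsics:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Scalar stand-in for a whilelt-style predicate: lanes [0, remaining) are active.
    std::vector<bool> make_predicate(std::size_t lanes, std::size_t remaining) {
        std::vector<bool> pred(lanes, false);
        for (std::size_t i = 0; i < std::min(lanes, remaining); i++) {
            pred[i] = true;
        }
        return pred;
    }

    // Processes `size` elements in groups of `lanes`; the last group simply gets
    // a partial predicate instead of a dedicated leftover loop.
    std::vector<bool>
    div_lt(const std::vector<float>& src, float right, float value, std::size_t lanes) {
        std::vector<bool> out(src.size(), false);
        for (std::size_t i = 0; i < src.size(); i += lanes) {
            const auto pred = make_predicate(lanes, src.size() - i);
            for (std::size_t lane = 0; lane < lanes; lane++) {
                if (!pred[lane]) {
                    continue;  // inactive lanes keep their zero/false value
                }
                out[i + lane] = (src[i + lane] / right) < value;
            }
        }
        return out;
    }

    int main() {
        const std::vector<float> src = {1, 2, 3, 4, 5, 6, 7};  // 7 elements, 4 lanes
        const auto bits = div_lt(src, 2.0f, 2.0f, 4);          // src / 2 < 2 ?
        return (bits[0] && bits[1] && bits[2] && !bits[3] && !bits[6]) ? 0 : 1;
    }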
- auto handler = [src, right_operand, value](const svbool_t pred, - const size_t idx) { - using sve_t = SVEVector; + using T = double; - const auto right_v = svdup_n_f64(right_operand); - const auto value_v = svdup_n_f64(value); - const svfloat64_t src_v = svld1_f64(pred, src + idx); + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; - const svbool_t cmp = - ArithHelperF64::op(pred, src_v, right_v, value_v); - return cmp; - }; + const auto right_v = svdup_n_f64(right_operand); + const auto value_v = svdup_n_f64(value); + const svfloat64_t src_v = svld1_f64(pred, src + idx); - return op_mask_helper(res_u8, size, handler); + const svbool_t cmp = ArithHelperF64::op( + pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper( + res_u8, size, handler); + } else if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand < 0) { + // flip signs and go for the multiplication case + return OpArithCompareImpl::op>:: + op_arith_compare(res_u8, src, -right_operand, -value, size); + } + + // go with the default case + } + + // a default case + { + using T = double; + + auto handler = [src, right_operand, value](const svbool_t pred, + const size_t idx) { + using sve_t = SVEVector; + + const auto right_v = svdup_n_f64(right_operand); + const auto value_v = svdup_n_f64(value); + const svfloat64_t src_v = svld1_f64(pred, src + idx); + + const svbool_t cmp = ArithHelperF64::op( + pred, src_v, right_v, value_v); + return cmp; + }; + + return op_mask_helper(res_u8, size, handler); + } } } diff --git a/internal/core/src/bitset/detail/platform/x86/avx2-impl.h b/internal/core/src/bitset/detail/platform/x86/avx2-impl.h index 51af01047a..f1b783727a 100644 --- a/internal/core/src/bitset/detail/platform/x86/avx2-impl.h +++ b/internal/core/src/bitset/detail/platform/x86/avx2-impl.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -1347,11 +1348,19 @@ struct ArithHelperF32 { template struct ArithHelperF32 { static inline __m256 - op(const __m256 left, const __m256 right, const __m256 value) { + op_special(const __m256 left, const __m256 right, const __m256 value) { + // this is valid for the positive denominator, == and != cases. // left == right * value constexpr auto pred = ComparePredicate::value; return _mm256_cmp_ps(left, _mm256_mul_ps(right, value), pred); } + + static inline __m256 + op(const __m256 left, const __m256 right, const __m256 value) { + // left / right == value + constexpr auto pred = ComparePredicate::value; + return _mm256_cmp_ps(_mm256_div_ps(left, right), value, pred); + } }; // todo: Mod @@ -1393,11 +1402,19 @@ struct ArithHelperF64 { template struct ArithHelperF64 { static inline __m256d - op(const __m256d left, const __m256d right, const __m256d value) { + op_special(const __m256d left, const __m256d right, const __m256d value) { + // this is valid for the positive denominator, == and != cases. 
// left == right * value constexpr auto pred = ComparePredicate::value; return _mm256_cmp_pd(left, _mm256_mul_pd(right, value), pred); } + + static inline __m256d + op(const __m256d left, const __m256d right, const __m256d value) { + // left / right == value + constexpr auto pred = ComparePredicate::value; + return _mm256_cmp_pd(_mm256_div_pd(left, right), value, pred); + } }; } // namespace @@ -1589,26 +1606,67 @@ OpArithCompareImpl::op_arith_compare( if constexpr (AOp == ArithOpType::Mod) { return false; } else { - // the restriction of the API - assert((size % 8) == 0); + if constexpr (AOp == ArithOpType::Div) { + if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand > 0) { + // a special case that allows faster processing by using the multiplication + // operation instead of the division one. - // - const __m256 right_v = _mm256_set1_ps(right_operand); - const __m256 value_v = _mm256_set1_ps(value); + // the restriction of the API + assert((size % 8) == 0); - // todo: aligned reads & writes + // + const __m256 right_v = _mm256_set1_ps(right_operand); + const __m256 value_v = _mm256_set1_ps(value); - const size_t size8 = (size / 8) * 8; - for (size_t i = 0; i < size8; i += 8) { - const __m256 v0s = _mm256_loadu_ps(src + i); - const __m256 cmp = - ArithHelperF32::op(v0s, right_v, value_v); - const uint8_t mmask = _mm256_movemask_ps(cmp); + // todo: aligned reads & writes - res_u8[i / 8] = mmask; + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256 v0s = _mm256_loadu_ps(src + i); + const __m256 cmp = ArithHelperF32::op_special( + v0s, right_v, value_v); + const uint8_t mmask = _mm256_movemask_ps(cmp); + + res_u8[i / 8] = mmask; + } + + return true; + } else if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand < 0) { + // flip signs and go for the multiplication case + return OpArithCompareImpl::op>:: + op_arith_compare(res_u8, src, -right_operand, -value, size); + } + + // go with the default case } - return true; + // a default case + { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256 right_v = _mm256_set1_ps(right_operand); + const __m256 value_v = _mm256_set1_ps(value); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256 v0s = _mm256_loadu_ps(src + i); + const __m256 cmp = + ArithHelperF32::op(v0s, right_v, value_v); + const uint8_t mmask = _mm256_movemask_ps(cmp); + + res_u8[i / 8] = mmask; + } + + return true; + } } } @@ -1623,30 +1681,75 @@ OpArithCompareImpl::op_arith_compare( if constexpr (AOp == ArithOpType::Mod) { return false; } else { - // the restriction of the API - assert((size % 8) == 0); + if constexpr (AOp == ArithOpType::Div) { + if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand > 0) { + // a special case that allows faster processing by using the multiplication + // operation instead of the division one. 
- // - const __m256d right_v = _mm256_set1_pd(right_operand); - const __m256d value_v = _mm256_set1_pd(value); + // the restriction of the API + assert((size % 8) == 0); - // todo: aligned reads & writes + // + const __m256d right_v = _mm256_set1_pd(right_operand); + const __m256d value_v = _mm256_set1_pd(value); - const size_t size8 = (size / 8) * 8; - for (size_t i = 0; i < size8; i += 8) { - const __m256d v0s = _mm256_loadu_pd(src + i); - const __m256d v1s = _mm256_loadu_pd(src + i + 4); - const __m256d cmp0 = - ArithHelperF64::op(v0s, right_v, value_v); - const __m256d cmp1 = - ArithHelperF64::op(v1s, right_v, value_v); - const uint8_t mmask0 = _mm256_movemask_pd(cmp0); - const uint8_t mmask1 = _mm256_movemask_pd(cmp1); + // todo: aligned reads & writes - res_u8[i / 8] = mmask0 + mmask1 * 16; + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256d v0s = _mm256_loadu_pd(src + i); + const __m256d v1s = _mm256_loadu_pd(src + i + 4); + const __m256d cmp0 = ArithHelperF64::op_special( + v0s, right_v, value_v); + const __m256d cmp1 = ArithHelperF64::op_special( + v1s, right_v, value_v); + const uint8_t mmask0 = _mm256_movemask_pd(cmp0); + const uint8_t mmask1 = _mm256_movemask_pd(cmp1); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; + } else if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand < 0) { + // flip signs and go for the multiplication case + return OpArithCompareImpl::op>:: + op_arith_compare(res_u8, src, -right_operand, -value, size); + } + + // go with the default case } - return true; + // a default case + { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m256d right_v = _mm256_set1_pd(right_operand); + const __m256d value_v = _mm256_set1_pd(value); + + // todo: aligned reads & writes + + const size_t size8 = (size / 8) * 8; + for (size_t i = 0; i < size8; i += 8) { + const __m256d v0s = _mm256_loadu_pd(src + i); + const __m256d v1s = _mm256_loadu_pd(src + i + 4); + const __m256d cmp0 = + ArithHelperF64::op(v0s, right_v, value_v); + const __m256d cmp1 = + ArithHelperF64::op(v1s, right_v, value_v); + const uint8_t mmask0 = _mm256_movemask_pd(cmp0); + const uint8_t mmask1 = _mm256_movemask_pd(cmp1); + + res_u8[i / 8] = mmask0 + mmask1 * 16; + } + + return true; + } } } diff --git a/internal/core/src/bitset/detail/platform/x86/avx512-impl.h b/internal/core/src/bitset/detail/platform/x86/avx512-impl.h index 3ffbf209d2..5ea7ac9644 100644 --- a/internal/core/src/bitset/detail/platform/x86/avx512-impl.h +++ b/internal/core/src/bitset/detail/platform/x86/avx512-impl.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -1391,11 +1392,19 @@ struct ArithHelperF32 { template struct ArithHelperF32 { static inline __mmask16 - op(const __m512 left, const __m512 right, const __m512 value) { + op_special(const __m512 left, const __m512 right, const __m512 value) { + // this is valid for the positive denominator, == and != cases. 
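In the 8-element loops, each group of comparison results is packed into one byte of the output bitset with element i of the group stored in bit i: _mm256_movemask_ps extracts the eight sign bits of the fp32 compare result directly, while the fp64 AVX2 path gets two 4-bit masks from _mm256_movemask_pd and combines them as mmask0 + mmask1 * 16 (low and high nibble). A scalar sketch of that packing convention (illustrative helper, not library code):

    #include <cstddef>
    #include <cstdint>

    // Packs 8 boolean lane results into one bitset byte, lowest-index element in bit 0.
    // This is what a movemask over an all-ones/all-zeros compare result produces.
    uint8_t pack8(const bool lanes[8]) {
        uint8_t out = 0;
        for (std::size_t i = 0; i < 8; i++) {
            out |= static_cast<uint8_t>(lanes[i] ? 1 : 0) << i;
        }
        return out;
    }

    int main() {
        const bool lanes[8] = {true, false, true, true, false, false, false, true};
        // bits 0, 2, 3 and 7 are set: 1 + 4 + 8 + 128 = 0x8D
        return pack8(lanes) == 0x8D ? 0 : 1;
    }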
// left == right * value constexpr auto pred = ComparePredicate::value; return _mm512_cmp_ps_mask(left, _mm512_mul_ps(right, value), pred); } + + static inline __mmask16 + op(const __m512 left, const __m512 right, const __m512 value) { + // left / right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_ps_mask(_mm512_div_ps(left, right), value, pred); + } }; // @@ -1435,11 +1444,19 @@ struct ArithHelperF64 { template struct ArithHelperF64 { static inline __mmask8 - op(const __m512d left, const __m512d right, const __m512d value) { + op_special(const __m512d left, const __m512d right, const __m512d value) { + // this is valid for the positive denominator, == and != cases. // left == right * value constexpr auto pred = ComparePredicate::value; return _mm512_cmp_pd_mask(left, _mm512_mul_pd(right, value), pred); } + + static inline __mmask8 + op(const __m512d left, const __m512d right, const __m512d value) { + // left / right == value + constexpr auto pred = ComparePredicate::value; + return _mm512_cmp_pd_mask(_mm512_div_pd(left, right), value, pred); + } }; } // namespace @@ -1762,58 +1779,137 @@ OpArithCompareImpl::op_arith_compare( if constexpr (AOp == ArithOpType::Mod) { return false; } else { - // the restriction of the API - assert((size % 8) == 0); + if constexpr (AOp == ArithOpType::Div) { + if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand > 0) { + // a special case that allows faster processing by using the multiplication + // operation instead of the division one. - // - const __m512 right_v = _mm512_set1_ps(right_operand); - const __m512 value_v = _mm512_set1_ps(value); - uint16_t* const __restrict res_u16 = - reinterpret_cast(res_u8); + // the restriction of the API + assert((size % 8) == 0); - // todo: aligned reads & writes + // + const __m512 right_v = _mm512_set1_ps(right_operand); + const __m512 value_v = _mm512_set1_ps(value); + uint16_t* const __restrict res_u16 = + reinterpret_cast(res_u8); - // interleaved pages - constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(float)); - const size_t size_8p = - (size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT; - for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) { - for (size_t p = 0; p < BLOCK_COUNT; p += 16) { - for (size_t ip = 0; ip < N_BLOCKS; ip++) { - const __m512 v0s = - _mm512_loadu_ps(src + i + p + ip * BLOCK_COUNT); + // todo: aligned reads & writes + + // interleaved pages + constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(float)); + const size_t size_8p = + (size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT; + for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) { + for (size_t p = 0; p < BLOCK_COUNT; p += 16) { + for (size_t ip = 0; ip < N_BLOCKS; ip++) { + const __m512 v0s = + _mm512_loadu_ps(src + i + p + ip * BLOCK_COUNT); + const __mmask16 cmp_mask = + ArithHelperF32::op_special( + v0s, right_v, value_v); + + res_u16[(i + p + ip * BLOCK_COUNT) / 16] = cmp_mask; + + _mm_prefetch( + (const char*)(src + i + p + ip * BLOCK_COUNT) + + BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH, + _MM_HINT_T0); + } + } + } + + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = size_8p; i < size16; i += 16) { + const __m512 v0s = _mm512_loadu_ps(src + i); const __mmask16 cmp_mask = - ArithHelperF32::op(v0s, right_v, value_v); + ArithHelperF32::op_special( + v0s, right_v, value_v); + res_u16[i / 16] = cmp_mask; + } - res_u16[(i + p + ip * BLOCK_COUNT) / 16] = cmp_mask; + // process leftovers + if (size16 != size) { + // 
process 8 elements + const __m256 vs = _mm256_loadu_ps(src + size16); + const __m512 v0s = _mm512_castps256_ps512(vs); + const __mmask16 cmp_mask = + ArithHelperF32::op_special( + v0s, right_v, value_v); + res_u8[size16 / 8] = uint8_t(cmp_mask); + } - _mm_prefetch((const char*)(src + i + p + ip * BLOCK_COUNT) + - BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH, - _MM_HINT_T0); + return true; + } else if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand < 0) { + // flip signs and go for the multiplication case + return OpArithCompareImpl::op>:: + op_arith_compare(res_u8, src, -right_operand, -value, size); + } + + // go with the default case + } + + // a default case + { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512 right_v = _mm512_set1_ps(right_operand); + const __m512 value_v = _mm512_set1_ps(value); + uint16_t* const __restrict res_u16 = + reinterpret_cast(res_u8); + + // todo: aligned reads & writes + + // interleaved pages + constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(float)); + const size_t size_8p = + (size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT; + for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) { + for (size_t p = 0; p < BLOCK_COUNT; p += 16) { + for (size_t ip = 0; ip < N_BLOCKS; ip++) { + const __m512 v0s = + _mm512_loadu_ps(src + i + p + ip * BLOCK_COUNT); + const __mmask16 cmp_mask = + ArithHelperF32::op( + v0s, right_v, value_v); + + res_u16[(i + p + ip * BLOCK_COUNT) / 16] = cmp_mask; + + _mm_prefetch( + (const char*)(src + i + p + ip * BLOCK_COUNT) + + BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH, + _MM_HINT_T0); + } } } - } - // process big blocks - const size_t size16 = (size / 16) * 16; - for (size_t i = size_8p; i < size16; i += 16) { - const __m512 v0s = _mm512_loadu_ps(src + i); - const __mmask16 cmp_mask = - ArithHelperF32::op(v0s, right_v, value_v); - res_u16[i / 16] = cmp_mask; - } + // process big blocks + const size_t size16 = (size / 16) * 16; + for (size_t i = size_8p; i < size16; i += 16) { + const __m512 v0s = _mm512_loadu_ps(src + i); + const __mmask16 cmp_mask = + ArithHelperF32::op(v0s, right_v, value_v); + res_u16[i / 16] = cmp_mask; + } - // process leftovers - if (size16 != size) { - // process 8 elements - const __m256 vs = _mm256_loadu_ps(src + size16); - const __m512 v0s = _mm512_castps256_ps512(vs); - const __mmask16 cmp_mask = - ArithHelperF32::op(v0s, right_v, value_v); - res_u8[size16 / 8] = uint8_t(cmp_mask); - } + // process leftovers + if (size16 != size) { + // process 8 elements + const __m256 vs = _mm256_loadu_ps(src + size16); + const __m512 v0s = _mm512_castps256_ps512(vs); + const __mmask16 cmp_mask = + ArithHelperF32::op(v0s, right_v, value_v); + res_u8[size16 / 8] = uint8_t(cmp_mask); + } - return true; + return true; + } } } @@ -1828,47 +1924,114 @@ OpArithCompareImpl::op_arith_compare( if constexpr (AOp == ArithOpType::Mod) { return false; } else { - // the restriction of the API - assert((size % 8) == 0); + if constexpr (AOp == ArithOpType::Div) { + if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand > 0) { + // a special case that allows faster processing by using the multiplication + // operation instead of the division one. 
- // - const __m512d right_v = _mm512_set1_pd(right_operand); - const __m512d value_v = _mm512_set1_pd(value); + // the restriction of the API + assert((size % 8) == 0); - // todo: aligned reads & writes + // + const __m512d right_v = _mm512_set1_pd(right_operand); + const __m512d value_v = _mm512_set1_pd(value); - // interleaved pages - constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(int64_t)); - const size_t size_8p = - (size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT; - for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) { - for (size_t p = 0; p < BLOCK_COUNT; p += 8) { - for (size_t ip = 0; ip < N_BLOCKS; ip++) { - const __m512d v0s = - _mm512_loadu_pd(src + i + p + ip * BLOCK_COUNT); + // todo: aligned reads & writes + + // interleaved pages + constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(int64_t)); + const size_t size_8p = + (size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT; + for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) { + for (size_t p = 0; p < BLOCK_COUNT; p += 8) { + for (size_t ip = 0; ip < N_BLOCKS; ip++) { + const __m512d v0s = + _mm512_loadu_pd(src + i + p + ip * BLOCK_COUNT); + const __mmask8 cmp_mask = + ArithHelperF64::op_special( + v0s, right_v, value_v); + + res_u8[(i + p + ip * BLOCK_COUNT) / 8] = cmp_mask; + + _mm_prefetch( + (const char*)(src + i + p + ip * BLOCK_COUNT) + + BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH, + _MM_HINT_T0); + } + } + } + + // process big blocks + const size_t size8 = (size / 8) * 8; + for (size_t i = size_8p; i < size8; i += 8) { + const __m512d v0s = _mm512_loadu_pd(src + i); const __mmask8 cmp_mask = - ArithHelperF64::op(v0s, right_v, value_v); + ArithHelperF64::op_special( + v0s, right_v, value_v); - res_u8[(i + p + ip * BLOCK_COUNT) / 8] = cmp_mask; + res_u8[i / 8] = cmp_mask; + } - _mm_prefetch((const char*)(src + i + p + ip * BLOCK_COUNT) + - BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH, - _MM_HINT_T0); + return true; + } else if (std::isfinite(value) && std::isfinite(right_operand) && + right_operand < 0) { + // flip signs and go for the multiplication case + return OpArithCompareImpl::op>:: + op_arith_compare(res_u8, src, -right_operand, -value, size); + } + + // go with the default case + } + + // a default case + { + // the restriction of the API + assert((size % 8) == 0); + + // + const __m512d right_v = _mm512_set1_pd(right_operand); + const __m512d value_v = _mm512_set1_pd(value); + + // todo: aligned reads & writes + + // interleaved pages + constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(int64_t)); + const size_t size_8p = + (size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT; + for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) { + for (size_t p = 0; p < BLOCK_COUNT; p += 8) { + for (size_t ip = 0; ip < N_BLOCKS; ip++) { + const __m512d v0s = + _mm512_loadu_pd(src + i + p + ip * BLOCK_COUNT); + const __mmask8 cmp_mask = + ArithHelperF64::op( + v0s, right_v, value_v); + + res_u8[(i + p + ip * BLOCK_COUNT) / 8] = cmp_mask; + + _mm_prefetch( + (const char*)(src + i + p + ip * BLOCK_COUNT) + + BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH, + _MM_HINT_T0); + } } } + + // process big blocks + const size_t size8 = (size / 8) * 8; + for (size_t i = size_8p; i < size8; i += 8) { + const __m512d v0s = _mm512_loadu_pd(src + i); + const __mmask8 cmp_mask = + ArithHelperF64::op(v0s, right_v, value_v); + + res_u8[i / 8] = cmp_mask; + } + + return true; } - - // process big blocks - const size_t size8 = (size / 8) * 8; - for (size_t i = size_8p; i < size8; i += 8) { - const __m512d 
v0s = _mm512_loadu_pd(src + i); - const __mmask8 cmp_mask = - ArithHelperF64::op(v0s, right_v, value_v); - - res_u8[i / 8] = cmp_mask; - } - - return true; } } diff --git a/internal/core/src/bitset/detail/platform/x86/common.h b/internal/core/src/bitset/detail/platform/x86/common.h index 9bedb78c32..af1daf8784 100644 --- a/internal/core/src/bitset/detail/platform/x86/common.h +++ b/internal/core/src/bitset/detail/platform/x86/common.h @@ -64,7 +64,7 @@ struct ComparePredicate { template struct ComparePredicate { static inline constexpr int value = - std::is_floating_point_v ? _CMP_NEQ_OQ : _MM_CMPINT_NE; + std::is_floating_point_v ? _CMP_NEQ_UQ : _MM_CMPINT_NE; }; } // namespace x86 diff --git a/internal/core/unittest/test_bitset.cpp b/internal/core/unittest/test_bitset.cpp index e4decc751c..6d30c44eee 100644 --- a/internal/core/unittest/test_bitset.cpp +++ b/internal/core/unittest/test_bitset.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -308,6 +309,18 @@ FillRandom(std::vector& t, } } +template +void +FillRandomRange(std::vector& t, + std::default_random_engine& rng, + const int32_t min_v, + const int32_t max_v) { + std::uniform_int_distribution tt(0, max_v); + for (size_t i = 0; i < t.size(); i++) { + t[i] = static_cast(tt(rng)); + } +} + template <> void FillRandom(std::vector& t, @@ -1296,18 +1309,24 @@ INSTANTIATE_TYPED_TEST_SUITE_P(InplaceWithinRangeValTest, template struct TestInplaceArithCompareImplS { static void - process(BitsetT& bitset, ArithOpType a_op, CompareOpType cmp_op) { + process(BitsetT& bitset, + ArithOpType a_op, + CompareOpType cmp_op, + const int32_t right_operand_in, + const int32_t value_in) { using HT = ArithHighPrecisionType; const size_t n = bitset.size(); - constexpr size_t max_v = 10; + constexpr int32_t max_v = 10; std::vector left(n, 0); - const HT right_operand = from_i32(2); - const HT value = from_i32(5); + const HT right_operand = from_i32(right_operand_in); + const HT value = from_i32(value_in); std::default_random_engine rng(123); - FillRandom(left, rng, max_v); + // Generating values in (-x, x) range. + // This is fine, because we're operating with signed integers. 
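The one-character change in x86/common.h switches the floating-point NE predicate from ordered (_CMP_NEQ_OQ, which reports false when either lane is NaN) to unordered (_CMP_NEQ_UQ, which reports true when either lane is NaN). The unordered form matches scalar C++ semantics, where NaN != x holds; a quick standalone illustration of that scalar behavior:

    #include <cassert>
    #include <cmath>
    #include <limits>

    int main() {
        const float nan = std::numeric_limits<float>::quiet_NaN();
        const float x = 3.0f;

        // Scalar != is an unordered comparison: true whenever either operand is NaN,
        // while ==, <, <=, >, >= against NaN are all false.
        assert(nan != x);
        assert(std::isnan(nan) && !(nan == x) && !(nan < x) && !(nan >= x));
        return 0;
    }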
+ FillRandomRange(left, rng, -max_v, max_v); StopWatch sw; bitset.inplace_arith_compare( @@ -1321,110 +1340,140 @@ struct TestInplaceArithCompareImplS { if (a_op == ArithOpType::Add) { if (cmp_op == CompareOpType::EQ) { ASSERT_EQ((left[i] + right_operand) == value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GE) { ASSERT_EQ((left[i] + right_operand) >= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GT) { ASSERT_EQ((left[i] + right_operand) > value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::LE) { ASSERT_EQ((left[i] + right_operand) <= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::LT) { ASSERT_EQ((left[i] + right_operand) < value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::NE) { ASSERT_EQ((left[i] + right_operand) != value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else { ASSERT_TRUE(false) << "Not implemented"; } } else if (a_op == ArithOpType::Sub) { if (cmp_op == CompareOpType::EQ) { ASSERT_EQ((left[i] - right_operand) == value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GE) { ASSERT_EQ((left[i] - right_operand) >= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GT) { ASSERT_EQ((left[i] - right_operand) > value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::LE) { ASSERT_EQ((left[i] - right_operand) <= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::LT) { ASSERT_EQ((left[i] - right_operand) < value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::NE) { ASSERT_EQ((left[i] - right_operand) != value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else { ASSERT_TRUE(false) << "Not implemented"; } } else if (a_op == ArithOpType::Mul) { if (cmp_op == CompareOpType::EQ) { ASSERT_EQ((left[i] * right_operand) == value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GE) { ASSERT_EQ((left[i] * right_operand) >= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GT) { ASSERT_EQ((left[i] * right_operand) > value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::LE) { ASSERT_EQ((left[i] * right_operand) <= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if 
(cmp_op == CompareOpType::LT) { ASSERT_EQ((left[i] * right_operand) < value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::NE) { ASSERT_EQ((left[i] * right_operand) != value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else { ASSERT_TRUE(false) << "Not implemented"; } } else if (a_op == ArithOpType::Div) { if (cmp_op == CompareOpType::EQ) { ASSERT_EQ((left[i] / right_operand) == value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GE) { ASSERT_EQ((left[i] / right_operand) >= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GT) { ASSERT_EQ((left[i] / right_operand) > value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::LE) { ASSERT_EQ((left[i] / right_operand) <= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::LT) { ASSERT_EQ((left[i] / right_operand) < value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::NE) { ASSERT_EQ((left[i] / right_operand) != value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else { ASSERT_TRUE(false) << "Not implemented"; } } else if (a_op == ArithOpType::Mod) { if (cmp_op == CompareOpType::EQ) { ASSERT_EQ(fmod(left[i], right_operand) == value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GE) { ASSERT_EQ(fmod(left[i], right_operand) >= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::GT) { ASSERT_EQ(fmod(left[i], right_operand) > value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::LE) { ASSERT_EQ(fmod(left[i], right_operand) <= value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::LT) { ASSERT_EQ(fmod(left[i], right_operand) < value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else if (cmp_op == CompareOpType::NE) { ASSERT_EQ(fmod(left[i], right_operand) != value, bitset[i]) - << i; + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; } else { ASSERT_TRUE(false) << "Not implemented"; } @@ -1433,12 +1482,63 @@ struct TestInplaceArithCompareImplS { } } } + + static void + process_div_special(BitsetT& bitset, + CompareOpType cmp_op, + const T left_v, + const T right_v, + const T value_v) { + // test a single special point for the division + + using HT = ArithHighPrecisionType; + + const size_t n = bitset.size(); + + std::vector left(n, left_v); + const HT right_operand = right_v; + const HT value = value_v; + + bitset.inplace_arith_compare( + left.data(), right_operand, value, n, ArithOpType::Div, cmp_op); + + 
for (size_t i = 0; i < n; i++) { + if (cmp_op == CompareOpType::EQ) { + ASSERT_EQ((left[i] / right_operand) == value, bitset[i]) + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; + } else if (cmp_op == CompareOpType::GE) { + ASSERT_EQ((left[i] / right_operand) >= value, bitset[i]) + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; + } else if (cmp_op == CompareOpType::GT) { + ASSERT_EQ((left[i] / right_operand) > value, bitset[i]) + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; + } else if (cmp_op == CompareOpType::LE) { + ASSERT_EQ((left[i] / right_operand) <= value, bitset[i]) + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; + } else if (cmp_op == CompareOpType::LT) { + ASSERT_EQ((left[i] / right_operand) < value, bitset[i]) + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; + } else if (cmp_op == CompareOpType::NE) { + ASSERT_EQ((left[i] / right_operand) != value, bitset[i]) + << i << " " << size_t(cmp_op) << " " << left[i] << " " + << right_operand << " " << value; + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } + } }; template struct TestInplaceArithCompareImplS { static void - process(BitsetT&, ArithOpType, CompareOpType) { + process( + BitsetT&, ArithOpType, CompareOpType, const int32_t, const int32_t) { // does nothing } }; @@ -1446,40 +1546,102 @@ struct TestInplaceArithCompareImplS { template void TestInplaceArithCompareImpl() { - for (const size_t n : typical_sizes) { - for (const auto a_op : typical_arith_ops) { - for (const auto cmp_op : typical_compare_ops) { - BitsetT bitset(n); - bitset.reset(); + if constexpr (std::is_floating_point_v) + for (const size_t n : typical_sizes) { + for (const auto a_op : typical_arith_ops) { + for (const auto cmp_op : typical_compare_ops) { + // test both positive, zero and negative + for (const int32_t right_operand : {2, 0, -2}) { + if ((!std::is_floating_point_v || + a_op == milvus::bitset::ArithOpType::Mod) && + right_operand == 0) { + continue; + } - if (print_log) { - printf( - "Testing bitset, n=%zd, a_op=%zd\n", n, (size_t)a_op); + // test both positive, zero and negative + for (const int32_t value : {2, 0, -2}) { + BitsetT bitset(n); + bitset.reset(); + + if (print_log) { + printf( + "Testing bitset, n=%zd, a_op=%zd, " + "cmp_op=%zd, right_operand=%d\n", + n, + (size_t)a_op, + (size_t)cmp_op, + right_operand); + } + + TestInplaceArithCompareImplS::process( + bitset, a_op, cmp_op, right_operand, value); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + + if (print_log) { + printf( + "Testing bitset view, n=%zd, " + "offset=%zd, a_op=%zd, cmp_op=%zd, " + "right_operand=%d\n", + n, + offset, + (size_t)a_op, + (size_t)cmp_op, + right_operand); + } + + TestInplaceArithCompareImplS< + decltype(view), + T>::process(view, + a_op, + cmp_op, + right_operand, + value); + } + } + } } + } + } - TestInplaceArithCompareImplS::process( - bitset, a_op, cmp_op); + if constexpr (std::is_floating_point_v) { + // test various special use cases for IEEE-754 for the division operation. 
+ std::vector variety = {0, + 1, + -1, + std::numeric_limits::quiet_NaN(), + -std::numeric_limits::quiet_NaN(), + std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; - for (const size_t offset : typical_offsets) { - if (offset >= n) { - continue; + for (const auto cmp_op : typical_compare_ops) { + for (const T left_v : variety) { + for (const T right_v : variety) { + for (const T value_v : variety) { + // 40 should be sufficient to test avx512 + BitsetT bitset(40); + bitset.reset(); + + if (print_log) { + printf( + "Testing bitset div special case, cmp_op=%zd, " + "left_v=%f, right_v=%f, value_v=%f\n", + (size_t)cmp_op, + left_v, + right_v, + value_v); + } + + TestInplaceArithCompareImplS:: + process_div_special( + bitset, cmp_op, left_v, right_v, value_v); } - - bitset.reset(); - auto view = bitset.view(offset); - - if (print_log) { - printf( - "Testing bitset view, n=%zd, offset=%zd, a_op=%zd, " - "cmp_op=%zd\n", - n, - offset, - (size_t)a_op, - (size_t)cmp_op); - } - - TestInplaceArithCompareImplS::process( - view, a_op, cmp_op); } } }
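The variety sweep drives every pairing of 0, +/-1, +/-NaN and +/-Inf through the Div path, so the reference results in process_div_special are plain IEEE-754 division semantics. For quick reference, the special quotients involved (standalone C++, nothing Milvus-specific assumed):

    #include <cassert>
    #include <cmath>
    #include <limits>

    int main() {
        const float inf = std::numeric_limits<float>::infinity();
        const float nan = std::numeric_limits<float>::quiet_NaN();
        float zero = 0.0f, one = 1.0f;

        assert(one / zero == inf);        // finite / 0   -> +/-Inf
        assert(-one / zero == -inf);
        assert(std::isnan(zero / zero));  // 0 / 0        -> NaN
        assert(std::isnan(inf / inf));    // Inf / Inf    -> NaN
        assert(one / inf == zero);        // finite / Inf -> +/-0
        assert(std::isnan(nan / one));    // NaN propagates through division

        // every comparison against NaN is false, except != which is true
        assert(!(nan == nan) && (nan != nan) && !(nan < one) && !(nan >= one));
        return 0;
    }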