fix: fix an incorrect bitset for the division comparison when the right operand is < 0 (#43179)

issue: https://github.com/milvus-io/milvus/issues/42900
@sunby Unfortunately, this is not as easy to fix as was thought in #43177.

Update: this also handles `Inf` and `NaN` values, as well as the division-by-zero case for `fp32` and `fp64`.
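
For context, a minimal scalar sketch (illustrative only, not the library code) of why the multiplication shortcut needs a flipped comparison operator once the divisor is negative:

#include <cassert>
#include <cmath>

// Reference semantics: "left / right < value".
bool div_lt_reference(float left, float right, float value) {
    return (left / right) < value;
}

// Multiplication-based shortcut: valid as-is only for a finite right > 0;
// for a finite right < 0 the inequality direction has to be flipped.
bool div_lt_via_mul(float left, float right, float value) {
    if (std::isfinite(value) && std::isfinite(right) && right > 0) {
        return left < right * value;
    } else if (std::isfinite(value) && std::isfinite(right) && right < 0) {
        return left > right * value;   // flipped operator
    }
    return (left / right) < value;     // zero, Inf and NaN go through the real division
}

int main() {
    assert(div_lt_reference(6.0f, -2.0f, 1.0f) == div_lt_via_mul(6.0f, -2.0f, 1.0f));
    assert(div_lt_reference(-6.0f, -2.0f, 1.0f) == div_lt_via_mul(-6.0f, -2.0f, 1.0f));
    return 0;
}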

Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
Alexander Guzhva committed on 2025-07-11 11:04:49 +00:00 (via GitHub)
parent 15a6631147
commit a848c4a8c5
7 changed files with 926 additions and 253 deletions

View File

@@ -157,5 +157,19 @@ struct ArithCompareOperator {
}
};
// This is related to the special handling of the A/B vs C comparison.
// A multiplication is used instead of a division, so the signs must be
// inverted and the comparison operator flipped when the denominator
// is negative.
template <CompareOpType CmpOp>
struct CompareOpDivFlip {
static constexpr CompareOpType op =
(CmpOp == CompareOpType::LE) ? CompareOpType::GE
: (CmpOp == CompareOpType::LT) ? CompareOpType::GT
: (CmpOp == CompareOpType::GE) ? CompareOpType::LE
: (CmpOp == CompareOpType::GT) ? CompareOpType::LT
: CmpOp;
};
} // namespace bitset
} // namespace milvus
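
A minimal compile-time sanity check of the mapping above (placed inside the milvus::bitset namespace; EQ and NE are assumed to be the other CompareOpType enumerators used elsewhere in this patch):

static_assert(CompareOpDivFlip<CompareOpType::LT>::op == CompareOpType::GT);
static_assert(CompareOpDivFlip<CompareOpType::GE>::op == CompareOpType::LE);
// EQ and NE are direction-independent, so they pass through unchanged.
static_assert(CompareOpDivFlip<CompareOpType::EQ>::op == CompareOpType::EQ);
static_assert(CompareOpDivFlip<CompareOpType::NE>::op == CompareOpType::NE);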

View File

@@ -21,6 +21,7 @@
#include <arm_neon.h>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>
@@ -1459,14 +1460,25 @@ struct ArithHelperF32<ArithOpType::Mul, CmpOp> {
template <CompareOpType CmpOp>
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
static inline uint32x4x2_t
op(const float32x4x2_t left,
const float32x4x2_t right,
const float32x4x2_t value) {
op_special(const float32x4x2_t left,
const float32x4x2_t right,
const float32x4x2_t value) {
// this is valid for a positive denominator, and for the == and != cases.
// left == right * value
const float32x4x2_t rv = {vmulq_f32(right.val[0], value.val[0]),
vmulq_f32(right.val[1], value.val[1])};
return CmpHelper<CmpOp>::compare(left, rv);
}
static inline uint32x4x2_t
op(const float32x4x2_t left,
const float32x4x2_t right,
const float32x4x2_t value) {
// left / right == value
const float32x4x2_t rv = {vdivq_f32(left.val[0], right.val[0]),
vdivq_f32(left.val[1], right.val[1])};
return CmpHelper<CmpOp>::compare(rv, value);
}
};
//
@@ -1521,9 +1533,9 @@ struct ArithHelperF64<ArithOpType::Mul, CmpOp> {
template <CompareOpType CmpOp>
struct ArithHelperF64<ArithOpType::Div, CmpOp> {
static inline uint64x2x4_t
op(const float64x2x4_t left,
const float64x2x4_t right,
const float64x2x4_t value) {
op_special(const float64x2x4_t left,
const float64x2x4_t right,
const float64x2x4_t value) {
// left == right * value
const float64x2x4_t rv = {vmulq_f64(right.val[0], value.val[0]),
vmulq_f64(right.val[1], value.val[1]),
@@ -1531,6 +1543,18 @@ struct ArithHelperF64<ArithOpType::Div, CmpOp> {
vmulq_f64(right.val[3], value.val[3])};
return CmpHelper<CmpOp>::compare(left, rv);
}
static inline uint64x2x4_t
op(const float64x2x4_t left,
const float64x2x4_t right,
const float64x2x4_t value) {
// left / right == value
const float64x2x4_t rv = {vdivq_f64(left.val[0], right.val[0]),
vdivq_f64(left.val[1], right.val[1]),
vdivq_f64(left.val[2], right.val[2]),
vdivq_f64(left.val[3], right.val[3])};
return CmpHelper<CmpOp>::compare(rv, value);
}
};
} // namespace
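
The loops below store one byte per 8 lanes via a `movemask`-style helper; a scalar model of that packing, assuming the helper takes the most significant bit of each 32-bit lane (an assumption about the milvus helper, not its verbatim code):

#include <cstddef>
#include <cstdint>

// Pack 8 per-lane compare results (all-ones or all-zeros 32-bit masks) into a
// byte, bit i corresponding to lane i, mirroring `res_u8[i / 8] = movemask(cmp)`.
inline uint8_t pack_mask8(const uint32_t lanes[8]) {
    uint8_t out = 0;
    for (size_t i = 0; i < 8; i++) {
        out |= static_cast<uint8_t>(((lanes[i] >> 31) & 1u) << i);
    }
    return out;
}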
@@ -1743,28 +1767,74 @@ OpArithCompareImpl<float, AOp, CmpOp>::op_arith_compare(
if constexpr (AOp == ArithOpType::Mod) {
return false;
} else {
// the restriction of the API
assert((size % 8) == 0);
if constexpr (AOp == ArithOpType::Div) {
if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand > 0) {
// a special case that allows faster processing by using the multiplication
// operation instead of the division one.
//
const float32x4x2_t right_v = {vdupq_n_f32(right_operand),
vdupq_n_f32(right_operand)};
const float32x4x2_t value_v = {vdupq_n_f32(value), vdupq_n_f32(value)};
// the restriction of the API
assert((size % 8) == 0);
// todo: aligned reads & writes
//
const float32x4x2_t right_v = {vdupq_n_f32(right_operand),
vdupq_n_f32(right_operand)};
const float32x4x2_t value_v = {vdupq_n_f32(value),
vdupq_n_f32(value)};
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const float32x4x2_t v0v = {vld1q_f32(src + i),
vld1q_f32(src + i + 4)};
const uint32x4x2_t cmp =
ArithHelperF32<AOp, CmpOp>::op(v0v, right_v, value_v);
// todo: aligned reads & writes
const uint8_t mmask = movemask(cmp);
res_u8[i / 8] = mmask;
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const float32x4x2_t v0v = {vld1q_f32(src + i),
vld1q_f32(src + i + 4)};
const uint32x4x2_t cmp =
ArithHelperF32<AOp, CmpOp>::op_special(
v0v, right_v, value_v);
const uint8_t mmask = movemask(cmp);
res_u8[i / 8] = mmask;
}
return true;
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand < 0) {
// flip signs and go for the multiplication case
return OpArithCompareImpl<float,
AOp,
CompareOpDivFlip<CmpOp>::op>::
op_arith_compare(res_u8, src, -right_operand, -value, size);
}
// go with the default case
}
return true;
// a default case
{
// the restriction of the API
assert((size % 8) == 0);
//
const float32x4x2_t right_v = {vdupq_n_f32(right_operand),
vdupq_n_f32(right_operand)};
const float32x4x2_t value_v = {vdupq_n_f32(value),
vdupq_n_f32(value)};
// todo: aligned reads & writes
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const float32x4x2_t v0v = {vld1q_f32(src + i),
vld1q_f32(src + i + 4)};
const uint32x4x2_t cmp =
ArithHelperF32<AOp, CmpOp>::op(v0v, right_v, value_v);
const uint8_t mmask = movemask(cmp);
res_u8[i / 8] = mmask;
}
return true;
}
}
}
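
The negative-divisor branch above recurses with negated operands and the CompareOpDivFlip'ed operator; a short scalar check of that identity (a sketch, not milvus code):

#include <cassert>

// For a finite right < 0:
//   left / right < value
//   <=>  -(left / (-right)) < value      // since right == -(-right)
//   <=>    left / (-right)  > -value     // multiplying both sides by -1 flips the operator
// which is exactly the recursive call with (-right_operand, -value) and the
// flipped CmpOp, now hitting the positive-divisor fast path.
int main() {
    const float left = 6.0f, right = -2.0f, value = 1.0f;
    assert(((left / right) < value) == ((left / -right) > -value));
    return 0;
}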
@@ -1779,35 +1849,86 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
if constexpr (AOp == ArithOpType::Mod) {
return false;
} else {
// the restriction of the API
assert((size % 8) == 0);
if constexpr (AOp == ArithOpType::Div) {
if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand > 0) {
// a special case that allows faster processing by using the multiplication
// operation instead of the division one.
//
const float64x2x4_t right_v = {vdupq_n_f64(right_operand),
vdupq_n_f64(right_operand),
vdupq_n_f64(right_operand),
vdupq_n_f64(right_operand)};
const float64x2x4_t value_v = {vdupq_n_f64(value),
vdupq_n_f64(value),
vdupq_n_f64(value),
vdupq_n_f64(value)};
// the restriction of the API
assert((size % 8) == 0);
// todo: aligned reads & writes
//
const float64x2x4_t right_v = {vdupq_n_f64(right_operand),
vdupq_n_f64(right_operand),
vdupq_n_f64(right_operand),
vdupq_n_f64(right_operand)};
const float64x2x4_t value_v = {vdupq_n_f64(value),
vdupq_n_f64(value),
vdupq_n_f64(value),
vdupq_n_f64(value)};
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const float64x2x4_t v0v = {vld1q_f64(src + i),
vld1q_f64(src + i + 2),
vld1q_f64(src + i + 4),
vld1q_f64(src + i + 6)};
const uint64x2x4_t cmp =
ArithHelperF64<AOp, CmpOp>::op(v0v, right_v, value_v);
// todo: aligned reads & writes
const uint8_t mmask = movemask(cmp);
res_u8[i / 8] = mmask;
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const float64x2x4_t v0v = {vld1q_f64(src + i),
vld1q_f64(src + i + 2),
vld1q_f64(src + i + 4),
vld1q_f64(src + i + 6)};
const uint64x2x4_t cmp =
ArithHelperF64<AOp, CmpOp>::op_special(
v0v, right_v, value_v);
const uint8_t mmask = movemask(cmp);
res_u8[i / 8] = mmask;
}
return true;
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand < 0) {
// flip signs and go for the multiplication case
return OpArithCompareImpl<double,
AOp,
CompareOpDivFlip<CmpOp>::op>::
op_arith_compare(res_u8, src, -right_operand, -value, size);
}
// go with the default case
}
return true;
// a default case
{
// the restriction of the API
assert((size % 8) == 0);
//
const float64x2x4_t right_v = {vdupq_n_f64(right_operand),
vdupq_n_f64(right_operand),
vdupq_n_f64(right_operand),
vdupq_n_f64(right_operand)};
const float64x2x4_t value_v = {vdupq_n_f64(value),
vdupq_n_f64(value),
vdupq_n_f64(value),
vdupq_n_f64(value)};
// todo: aligned reads & writes
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const float64x2x4_t v0v = {vld1q_f64(src + i),
vld1q_f64(src + i + 2),
vld1q_f64(src + i + 4),
vld1q_f64(src + i + 6)};
const uint64x2x4_t cmp =
ArithHelperF64<AOp, CmpOp>::op(v0v, right_v, value_v);
const uint8_t mmask = movemask(cmp);
res_u8[i / 8] = mmask;
}
return true;
}
}
}

View File

@@ -21,6 +21,7 @@
#include <arm_sve.h>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>
@@ -1314,15 +1315,26 @@ struct ArithHelperI64<ArithOpType::Mul, CmpOp> {
};
template <CompareOpType CmpOp>
struct ArithHelperI64<ArithOpType::Div, CmpOp> {
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
static inline svbool_t
op_special(const svbool_t pred,
const svfloat32_t left,
const svfloat32_t right,
const svfloat32_t value) {
// this is valid for a positive denominator, and for the == and != cases.
// left == right * value
return CmpHelper<CmpOp>::compare(
pred, left, svmul_f32_z(pred, right, value));
}
static inline svbool_t
op(const svbool_t pred,
const svint64_t left,
const svint64_t right,
const svint64_t value) {
const svfloat32_t left,
const svfloat32_t right,
const svfloat32_t value) {
// left / right == value
return CmpHelper<CmpOp>::compare(
pred, svdiv_s64_z(pred, left, right), value);
pred, svdiv_f32_z(pred, left, right), value);
}
};
@@ -1371,14 +1383,25 @@ struct ArithHelperF32<ArithOpType::Mul, CmpOp> {
template <CompareOpType CmpOp>
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
static inline svbool_t
op_special(const svbool_t pred,
const svfloat32_t left,
const svfloat32_t right,
const svfloat32_t value) {
// this is valid for a positive denominator, and for the == and != cases.
// left == right * value
return CmpHelper<CmpOp>::compare(
pred, left, svmul_f32_z(pred, right, value));
}
static inline svbool_t
op(const svbool_t pred,
const svfloat32_t left,
const svfloat32_t right,
const svfloat32_t value) {
// left == right * value
// left / right == value
return CmpHelper<CmpOp>::compare(
pred, left, svmul_f32_z(pred, right, value));
pred, svdiv_f32_z(pred, left, right), value);
}
};
@@ -1427,14 +1450,25 @@ struct ArithHelperF64<ArithOpType::Mul, CmpOp> {
template <CompareOpType CmpOp>
struct ArithHelperF64<ArithOpType::Div, CmpOp> {
static inline svbool_t
op_special(const svbool_t pred,
const svfloat64_t left,
const svfloat64_t right,
const svfloat64_t value) {
// this is valid for a positive denominator, and for the == and != cases.
// left == right * value
return CmpHelper<CmpOp>::compare(
pred, left, svmul_f64_z(pred, right, value));
}
static inline svbool_t
op(const svbool_t pred,
const svfloat64_t left,
const svfloat64_t right,
const svfloat64_t value) {
// left == right * value
// left / right == value
return CmpHelper<CmpOp>::compare(
pred, left, svmul_f64_z(pred, right, value));
pred, svdiv_f64_z(pred, left, right), value);
}
};
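
All the SVE helpers above work under a lane predicate; a minimal sketch of a predicated load-and-compare using standard ACLE intrinsics (illustrative, not the milvus code):

#include <arm_sve.h>

// Load the lanes selected by `pred` and compare them against a broadcast
// threshold; the result is a predicate with one bit per active lane.
// (The helpers above additionally use the zeroing `_z` forms for the
// arithmetic, so inactive lanes are forced to 0 before the compare.)
inline svbool_t cmp_lt(svbool_t pred, const float* src, float threshold) {
    const svfloat32_t v = svld1_f32(pred, src);
    const svfloat32_t t = svdup_n_f32(threshold);
    return svcmplt_f32(pred, v, t);
}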
@@ -1573,22 +1607,60 @@ OpArithCompareImpl<float, AOp, CmpOp>::op_arith_compare(
if constexpr (AOp == ArithOpType::Mod) {
return false;
} else {
using T = float;
if constexpr (AOp == ArithOpType::Div) {
if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand > 0) {
// a special case that allows faster processing by using the multiplication
// operation instead of the division one.
auto handler = [src, right_operand, value](const svbool_t pred,
const size_t idx) {
using sve_t = SVEVector<T>;
using T = float;
const auto right_v = svdup_n_f32(right_operand);
const auto value_v = svdup_n_f32(value);
const svfloat32_t src_v = svld1_f32(pred, src + idx);
auto handler = [src, right_operand, value](const svbool_t pred,
const size_t idx) {
using sve_t = SVEVector<T>;
const svbool_t cmp =
ArithHelperF32<AOp, CmpOp>::op(pred, src_v, right_v, value_v);
return cmp;
};
const auto right_v = svdup_n_f32(right_operand);
const auto value_v = svdup_n_f32(value);
const svfloat32_t src_v = svld1_f32(pred, src + idx);
return op_mask_helper<T, decltype(handler)>(res_u8, size, handler);
const svbool_t cmp = ArithHelperF32<AOp, CmpOp>::op_special(
pred, src_v, right_v, value_v);
return cmp;
};
return op_mask_helper<T, decltype(handler)>(
res_u8, size, handler);
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand < 0) {
// flip signs and go for the multiplication case
return OpArithCompareImpl<float,
AOp,
CompareOpDivFlip<CmpOp>::op>::
op_arith_compare(res_u8, src, -right_operand, -value, size);
}
// go with the default case
}
// a default case
{
using T = float;
auto handler = [src, right_operand, value](const svbool_t pred,
const size_t idx) {
using sve_t = SVEVector<T>;
const auto right_v = svdup_n_f32(right_operand);
const auto value_v = svdup_n_f32(value);
const svfloat32_t src_v = svld1_f32(pred, src + idx);
const svbool_t cmp = ArithHelperF32<AOp, CmpOp>::op(
pred, src_v, right_v, value_v);
return cmp;
};
return op_mask_helper<T, decltype(handler)>(res_u8, size, handler);
}
}
}
@@ -1603,22 +1675,60 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
if constexpr (AOp == ArithOpType::Mod) {
return false;
} else {
using T = double;
if constexpr (AOp == ArithOpType::Div) {
if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand > 0) {
// a special case that allows faster processing by using the multiplication
// operation instead of the division one.
auto handler = [src, right_operand, value](const svbool_t pred,
const size_t idx) {
using sve_t = SVEVector<T>;
using T = double;
const auto right_v = svdup_n_f64(right_operand);
const auto value_v = svdup_n_f64(value);
const svfloat64_t src_v = svld1_f64(pred, src + idx);
auto handler = [src, right_operand, value](const svbool_t pred,
const size_t idx) {
using sve_t = SVEVector<T>;
const svbool_t cmp =
ArithHelperF64<AOp, CmpOp>::op(pred, src_v, right_v, value_v);
return cmp;
};
const auto right_v = svdup_n_f64(right_operand);
const auto value_v = svdup_n_f64(value);
const svfloat64_t src_v = svld1_f64(pred, src + idx);
return op_mask_helper<T, decltype(handler)>(res_u8, size, handler);
const svbool_t cmp = ArithHelperF64<AOp, CmpOp>::op(
pred, src_v, right_v, value_v);
return cmp;
};
return op_mask_helper<T, decltype(handler)>(
res_u8, size, handler);
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand < 0) {
// flip signs and go for the multiplication case
return OpArithCompareImpl<double,
AOp,
CompareOpDivFlip<CmpOp>::op>::
op_arith_compare(res_u8, src, -right_operand, -value, size);
}
// go with the default case
}
// a default case
{
using T = double;
auto handler = [src, right_operand, value](const svbool_t pred,
const size_t idx) {
using sve_t = SVEVector<T>;
const auto right_v = svdup_n_f64(right_operand);
const auto value_v = svdup_n_f64(value);
const svfloat64_t src_v = svld1_f64(pred, src + idx);
const svbool_t cmp = ArithHelperF64<AOp, CmpOp>::op(
pred, src_v, right_v, value_v);
return cmp;
};
return op_mask_helper<T, decltype(handler)>(res_u8, size, handler);
}
}
}

View File

@@ -21,6 +21,7 @@
#include <immintrin.h>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>
@@ -1347,11 +1348,19 @@ struct ArithHelperF32<ArithOpType::Mul, CmpOp> {
template <CompareOpType CmpOp>
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
static inline __m256
op(const __m256 left, const __m256 right, const __m256 value) {
op_special(const __m256 left, const __m256 right, const __m256 value) {
// this is valid for a positive denominator, and for the == and != cases.
// left == right * value
constexpr auto pred = ComparePredicate<float, CmpOp>::value;
return _mm256_cmp_ps(left, _mm256_mul_ps(right, value), pred);
}
static inline __m256
op(const __m256 left, const __m256 right, const __m256 value) {
// left / right == value
constexpr auto pred = ComparePredicate<float, CmpOp>::value;
return _mm256_cmp_ps(_mm256_div_ps(left, right), value, pred);
}
};
// todo: Mod
@@ -1393,11 +1402,19 @@ struct ArithHelperF64<ArithOpType::Mul, CmpOp> {
template <CompareOpType CmpOp>
struct ArithHelperF64<ArithOpType::Div, CmpOp> {
static inline __m256d
op(const __m256d left, const __m256d right, const __m256d value) {
op_special(const __m256d left, const __m256d right, const __m256d value) {
// this is valid for a positive denominator, and for the == and != cases.
// left == right * value
constexpr auto pred = ComparePredicate<double, CmpOp>::value;
return _mm256_cmp_pd(left, _mm256_mul_pd(right, value), pred);
}
static inline __m256d
op(const __m256d left, const __m256d right, const __m256d value) {
// left / right == value
constexpr auto pred = ComparePredicate<double, CmpOp>::value;
return _mm256_cmp_pd(_mm256_div_pd(left, right), value, pred);
}
};
} // namespace
@@ -1589,26 +1606,67 @@ OpArithCompareImpl<float, AOp, CmpOp>::op_arith_compare(
if constexpr (AOp == ArithOpType::Mod) {
return false;
} else {
// the restriction of the API
assert((size % 8) == 0);
if constexpr (AOp == ArithOpType::Div) {
if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand > 0) {
// a special case that allows faster processing by using the multiplication
// operation instead of the division one.
//
const __m256 right_v = _mm256_set1_ps(right_operand);
const __m256 value_v = _mm256_set1_ps(value);
// the restriction of the API
assert((size % 8) == 0);
// todo: aligned reads & writes
//
const __m256 right_v = _mm256_set1_ps(right_operand);
const __m256 value_v = _mm256_set1_ps(value);
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const __m256 v0s = _mm256_loadu_ps(src + i);
const __m256 cmp =
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
const uint8_t mmask = _mm256_movemask_ps(cmp);
// todo: aligned reads & writes
res_u8[i / 8] = mmask;
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const __m256 v0s = _mm256_loadu_ps(src + i);
const __m256 cmp = ArithHelperF32<AOp, CmpOp>::op_special(
v0s, right_v, value_v);
const uint8_t mmask = _mm256_movemask_ps(cmp);
res_u8[i / 8] = mmask;
}
return true;
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand < 0) {
// flip signs and go for the multiplication case
return OpArithCompareImpl<float,
AOp,
CompareOpDivFlip<CmpOp>::op>::
op_arith_compare(res_u8, src, -right_operand, -value, size);
}
// go with the default case
}
return true;
// a default case
{
// the restriction of the API
assert((size % 8) == 0);
//
const __m256 right_v = _mm256_set1_ps(right_operand);
const __m256 value_v = _mm256_set1_ps(value);
// todo: aligned reads & writes
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const __m256 v0s = _mm256_loadu_ps(src + i);
const __m256 cmp =
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
const uint8_t mmask = _mm256_movemask_ps(cmp);
res_u8[i / 8] = mmask;
}
return true;
}
}
}
@@ -1623,30 +1681,75 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
if constexpr (AOp == ArithOpType::Mod) {
return false;
} else {
// the restriction of the API
assert((size % 8) == 0);
if constexpr (AOp == ArithOpType::Div) {
if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand > 0) {
// a special case that allows faster processing by using the multiplication
// operation instead of the division one.
//
const __m256d right_v = _mm256_set1_pd(right_operand);
const __m256d value_v = _mm256_set1_pd(value);
// the restriction of the API
assert((size % 8) == 0);
// todo: aligned reads & writes
//
const __m256d right_v = _mm256_set1_pd(right_operand);
const __m256d value_v = _mm256_set1_pd(value);
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const __m256d v0s = _mm256_loadu_pd(src + i);
const __m256d v1s = _mm256_loadu_pd(src + i + 4);
const __m256d cmp0 =
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
const __m256d cmp1 =
ArithHelperF64<AOp, CmpOp>::op(v1s, right_v, value_v);
const uint8_t mmask0 = _mm256_movemask_pd(cmp0);
const uint8_t mmask1 = _mm256_movemask_pd(cmp1);
// todo: aligned reads & writes
res_u8[i / 8] = mmask0 + mmask1 * 16;
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const __m256d v0s = _mm256_loadu_pd(src + i);
const __m256d v1s = _mm256_loadu_pd(src + i + 4);
const __m256d cmp0 = ArithHelperF64<AOp, CmpOp>::op_special(
v0s, right_v, value_v);
const __m256d cmp1 = ArithHelperF64<AOp, CmpOp>::op_special(
v1s, right_v, value_v);
const uint8_t mmask0 = _mm256_movemask_pd(cmp0);
const uint8_t mmask1 = _mm256_movemask_pd(cmp1);
res_u8[i / 8] = mmask0 + mmask1 * 16;
}
return true;
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand < 0) {
// flip signs and go for the multiplication case
return OpArithCompareImpl<double,
AOp,
CompareOpDivFlip<CmpOp>::op>::
op_arith_compare(res_u8, src, -right_operand, -value, size);
}
// go with the default case
}
return true;
// a default case
{
// the restriction of the API
assert((size % 8) == 0);
//
const __m256d right_v = _mm256_set1_pd(right_operand);
const __m256d value_v = _mm256_set1_pd(value);
// todo: aligned reads & writes
const size_t size8 = (size / 8) * 8;
for (size_t i = 0; i < size8; i += 8) {
const __m256d v0s = _mm256_loadu_pd(src + i);
const __m256d v1s = _mm256_loadu_pd(src + i + 4);
const __m256d cmp0 =
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
const __m256d cmp1 =
ArithHelperF64<AOp, CmpOp>::op(v1s, right_v, value_v);
const uint8_t mmask0 = _mm256_movemask_pd(cmp0);
const uint8_t mmask1 = _mm256_movemask_pd(cmp1);
res_u8[i / 8] = mmask0 + mmask1 * 16;
}
return true;
}
}
}

View File

@@ -21,6 +21,7 @@
#include <immintrin.h>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>
@@ -1391,11 +1392,19 @@ struct ArithHelperF32<ArithOpType::Mul, CmpOp> {
template <CompareOpType CmpOp>
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
static inline __mmask16
op(const __m512 left, const __m512 right, const __m512 value) {
op_special(const __m512 left, const __m512 right, const __m512 value) {
// this is valid for a positive denominator, and for the == and != cases.
// left == right * value
constexpr auto pred = ComparePredicate<float, CmpOp>::value;
return _mm512_cmp_ps_mask(left, _mm512_mul_ps(right, value), pred);
}
static inline __mmask16
op(const __m512 left, const __m512 right, const __m512 value) {
// left / right == value
constexpr auto pred = ComparePredicate<float, CmpOp>::value;
return _mm512_cmp_ps_mask(_mm512_div_ps(left, right), value, pred);
}
};
//
@@ -1435,11 +1444,19 @@ struct ArithHelperF64<ArithOpType::Mul, CmpOp> {
template <CompareOpType CmpOp>
struct ArithHelperF64<ArithOpType::Div, CmpOp> {
static inline __mmask8
op(const __m512d left, const __m512d right, const __m512d value) {
op_special(const __m512d left, const __m512d right, const __m512d value) {
// this is valid for a positive denominator, and for the == and != cases.
// left == right * value
constexpr auto pred = ComparePredicate<double, CmpOp>::value;
return _mm512_cmp_pd_mask(left, _mm512_mul_pd(right, value), pred);
}
static inline __mmask8
op(const __m512d left, const __m512d right, const __m512d value) {
// left / right == value
constexpr auto pred = ComparePredicate<double, CmpOp>::value;
return _mm512_cmp_pd_mask(_mm512_div_pd(left, right), value, pred);
}
};
} // namespace
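
Unlike the AVX2 path, the AVX-512 compares above return a k-mask directly, so no separate movemask step is needed before storing into the bitset; a minimal sketch (illustrative, not the milvus code):

#include <immintrin.h>
#include <cstdint>

// Compare 16 floats against a broadcast threshold and store the 16 result
// bits straight into the output, one uint16_t per 16 elements.
inline void cmp16_lt(const float* src, float threshold, uint16_t* out) {
    const __m512 v = _mm512_loadu_ps(src);
    const __m512 t = _mm512_set1_ps(threshold);
    const __mmask16 m = _mm512_cmp_ps_mask(v, t, _CMP_LT_OQ);
    *out = static_cast<uint16_t>(m);
}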
@@ -1762,58 +1779,137 @@ OpArithCompareImpl<float, AOp, CmpOp>::op_arith_compare(
if constexpr (AOp == ArithOpType::Mod) {
return false;
} else {
// the restriction of the API
assert((size % 8) == 0);
if constexpr (AOp == ArithOpType::Div) {
if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand > 0) {
// a special case that allows faster processing by using the multiplication
// operation instead of the division one.
//
const __m512 right_v = _mm512_set1_ps(right_operand);
const __m512 value_v = _mm512_set1_ps(value);
uint16_t* const __restrict res_u16 =
reinterpret_cast<uint16_t*>(res_u8);
// the restriction of the API
assert((size % 8) == 0);
// todo: aligned reads & writes
//
const __m512 right_v = _mm512_set1_ps(right_operand);
const __m512 value_v = _mm512_set1_ps(value);
uint16_t* const __restrict res_u16 =
reinterpret_cast<uint16_t*>(res_u8);
// interleaved pages
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(float));
const size_t size_8p =
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
for (size_t p = 0; p < BLOCK_COUNT; p += 16) {
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
const __m512 v0s =
_mm512_loadu_ps(src + i + p + ip * BLOCK_COUNT);
// todo: aligned reads & writes
// interleaved pages
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(float));
const size_t size_8p =
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
for (size_t p = 0; p < BLOCK_COUNT; p += 16) {
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
const __m512 v0s =
_mm512_loadu_ps(src + i + p + ip * BLOCK_COUNT);
const __mmask16 cmp_mask =
ArithHelperF32<AOp, CmpOp>::op_special(
v0s, right_v, value_v);
res_u16[(i + p + ip * BLOCK_COUNT) / 16] = cmp_mask;
_mm_prefetch(
(const char*)(src + i + p + ip * BLOCK_COUNT) +
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
_MM_HINT_T0);
}
}
}
// process big blocks
const size_t size16 = (size / 16) * 16;
for (size_t i = size_8p; i < size16; i += 16) {
const __m512 v0s = _mm512_loadu_ps(src + i);
const __mmask16 cmp_mask =
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
ArithHelperF32<AOp, CmpOp>::op_special(
v0s, right_v, value_v);
res_u16[i / 16] = cmp_mask;
}
res_u16[(i + p + ip * BLOCK_COUNT) / 16] = cmp_mask;
// process leftovers
if (size16 != size) {
// process 8 elements
const __m256 vs = _mm256_loadu_ps(src + size16);
const __m512 v0s = _mm512_castps256_ps512(vs);
const __mmask16 cmp_mask =
ArithHelperF32<AOp, CmpOp>::op_special(
v0s, right_v, value_v);
res_u8[size16 / 8] = uint8_t(cmp_mask);
}
_mm_prefetch((const char*)(src + i + p + ip * BLOCK_COUNT) +
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
_MM_HINT_T0);
return true;
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand < 0) {
// flip signs and go for the multiplication case
return OpArithCompareImpl<float,
AOp,
CompareOpDivFlip<CmpOp>::op>::
op_arith_compare(res_u8, src, -right_operand, -value, size);
}
// go with the default case
}
// a default case
{
// the restriction of the API
assert((size % 8) == 0);
//
const __m512 right_v = _mm512_set1_ps(right_operand);
const __m512 value_v = _mm512_set1_ps(value);
uint16_t* const __restrict res_u16 =
reinterpret_cast<uint16_t*>(res_u8);
// todo: aligned reads & writes
// interleaved pages
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(float));
const size_t size_8p =
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
for (size_t p = 0; p < BLOCK_COUNT; p += 16) {
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
const __m512 v0s =
_mm512_loadu_ps(src + i + p + ip * BLOCK_COUNT);
const __mmask16 cmp_mask =
ArithHelperF32<AOp, CmpOp>::op(
v0s, right_v, value_v);
res_u16[(i + p + ip * BLOCK_COUNT) / 16] = cmp_mask;
_mm_prefetch(
(const char*)(src + i + p + ip * BLOCK_COUNT) +
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
_MM_HINT_T0);
}
}
}
}
// process big blocks
const size_t size16 = (size / 16) * 16;
for (size_t i = size_8p; i < size16; i += 16) {
const __m512 v0s = _mm512_loadu_ps(src + i);
const __mmask16 cmp_mask =
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
res_u16[i / 16] = cmp_mask;
}
// process big blocks
const size_t size16 = (size / 16) * 16;
for (size_t i = size_8p; i < size16; i += 16) {
const __m512 v0s = _mm512_loadu_ps(src + i);
const __mmask16 cmp_mask =
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
res_u16[i / 16] = cmp_mask;
}
// process leftovers
if (size16 != size) {
// process 8 elements
const __m256 vs = _mm256_loadu_ps(src + size16);
const __m512 v0s = _mm512_castps256_ps512(vs);
const __mmask16 cmp_mask =
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
res_u8[size16 / 8] = uint8_t(cmp_mask);
}
// process leftovers
if (size16 != size) {
// process 8 elements
const __m256 vs = _mm256_loadu_ps(src + size16);
const __m512 v0s = _mm512_castps256_ps512(vs);
const __mmask16 cmp_mask =
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
res_u8[size16 / 8] = uint8_t(cmp_mask);
}
return true;
return true;
}
}
}
@@ -1828,47 +1924,114 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
if constexpr (AOp == ArithOpType::Mod) {
return false;
} else {
// the restriction of the API
assert((size % 8) == 0);
if constexpr (AOp == ArithOpType::Div) {
if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand > 0) {
// a special case that allows faster processing by using the multiplication
// operation instead of the division one.
//
const __m512d right_v = _mm512_set1_pd(right_operand);
const __m512d value_v = _mm512_set1_pd(value);
// the restriction of the API
assert((size % 8) == 0);
// todo: aligned reads & writes
//
const __m512d right_v = _mm512_set1_pd(right_operand);
const __m512d value_v = _mm512_set1_pd(value);
// interleaved pages
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(int64_t));
const size_t size_8p =
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
for (size_t p = 0; p < BLOCK_COUNT; p += 8) {
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
const __m512d v0s =
_mm512_loadu_pd(src + i + p + ip * BLOCK_COUNT);
// todo: aligned reads & writes
// interleaved pages
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(int64_t));
const size_t size_8p =
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
for (size_t p = 0; p < BLOCK_COUNT; p += 8) {
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
const __m512d v0s =
_mm512_loadu_pd(src + i + p + ip * BLOCK_COUNT);
const __mmask8 cmp_mask =
ArithHelperF64<AOp, CmpOp>::op_special(
v0s, right_v, value_v);
res_u8[(i + p + ip * BLOCK_COUNT) / 8] = cmp_mask;
_mm_prefetch(
(const char*)(src + i + p + ip * BLOCK_COUNT) +
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
_MM_HINT_T0);
}
}
}
// process big blocks
const size_t size8 = (size / 8) * 8;
for (size_t i = size_8p; i < size8; i += 8) {
const __m512d v0s = _mm512_loadu_pd(src + i);
const __mmask8 cmp_mask =
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
ArithHelperF64<AOp, CmpOp>::op_special(
v0s, right_v, value_v);
res_u8[(i + p + ip * BLOCK_COUNT) / 8] = cmp_mask;
res_u8[i / 8] = cmp_mask;
}
_mm_prefetch((const char*)(src + i + p + ip * BLOCK_COUNT) +
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
_MM_HINT_T0);
return true;
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
right_operand < 0) {
// flip signs and go for the multiplication case
return OpArithCompareImpl<double,
AOp,
CompareOpDivFlip<CmpOp>::op>::
op_arith_compare(res_u8, src, -right_operand, -value, size);
}
// go with the default case
}
// a default case
{
// the restriction of the API
assert((size % 8) == 0);
//
const __m512d right_v = _mm512_set1_pd(right_operand);
const __m512d value_v = _mm512_set1_pd(value);
// todo: aligned reads & writes
// interleaved pages
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(int64_t));
const size_t size_8p =
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
for (size_t p = 0; p < BLOCK_COUNT; p += 8) {
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
const __m512d v0s =
_mm512_loadu_pd(src + i + p + ip * BLOCK_COUNT);
const __mmask8 cmp_mask =
ArithHelperF64<AOp, CmpOp>::op(
v0s, right_v, value_v);
res_u8[(i + p + ip * BLOCK_COUNT) / 8] = cmp_mask;
_mm_prefetch(
(const char*)(src + i + p + ip * BLOCK_COUNT) +
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
_MM_HINT_T0);
}
}
}
// process big blocks
const size_t size8 = (size / 8) * 8;
for (size_t i = size_8p; i < size8; i += 8) {
const __m512d v0s = _mm512_loadu_pd(src + i);
const __mmask8 cmp_mask =
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
res_u8[i / 8] = cmp_mask;
}
return true;
}
// process big blocks
const size_t size8 = (size / 8) * 8;
for (size_t i = size_8p; i < size8; i += 8) {
const __m512d v0s = _mm512_loadu_pd(src + i);
const __mmask8 cmp_mask =
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
res_u8[i / 8] = cmp_mask;
}
return true;
}
}

View File

@@ -64,7 +64,7 @@ struct ComparePredicate<T, CompareOpType::GE> {
template <typename T>
struct ComparePredicate<T, CompareOpType::NE> {
static inline constexpr int value =
std::is_floating_point_v<T> ? _CMP_NEQ_OQ : _MM_CMPINT_NE;
std::is_floating_point_v<T> ? _CMP_NEQ_UQ : _MM_CMPINT_NE;
};
} // namespace x86
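
The one-character predicate change above matters only when NaN is involved; a small illustration with the standard AVX intrinsics (a sketch, not milvus code):

#include <immintrin.h>
#include <cmath>
#include <cstdio>

// _CMP_NEQ_OQ (ordered, non-signaling) yields false whenever either operand is
// NaN, so NaN != x would wrongly come out as false. _CMP_NEQ_UQ (unordered,
// non-signaling) yields true for NaN operands, matching scalar operator!=.
int main() {
    const __m256 a = _mm256_set1_ps(std::nanf(""));
    const __m256 b = _mm256_set1_ps(1.0f);
    const int oq = _mm256_movemask_ps(_mm256_cmp_ps(a, b, _CMP_NEQ_OQ));
    const int uq = _mm256_movemask_ps(_mm256_cmp_ps(a, b, _CMP_NEQ_UQ));
    std::printf("NEQ_OQ mask = 0x%02x, NEQ_UQ mask = 0x%02x\n", oq, uq);  // 0x00 vs 0xff
    return 0;
}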

View File

@@ -16,6 +16,7 @@
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <random>
#include <string>
@@ -308,6 +309,18 @@ FillRandom(std::vector<T>& t,
}
}
template <typename T>
void
FillRandomRange(std::vector<T>& t,
std::default_random_engine& rng,
const int32_t min_v,
const int32_t max_v) {
std::uniform_int_distribution<int32_t> tt(min_v, max_v);
for (size_t i = 0; i < t.size(); i++) {
t[i] = static_cast<T>(tt(rng));
}
}
template <>
void
FillRandom<std::string>(std::vector<std::string>& t,
@@ -1296,18 +1309,24 @@ INSTANTIATE_TYPED_TEST_SUITE_P(InplaceWithinRangeValTest,
template <typename BitsetT, typename T>
struct TestInplaceArithCompareImplS {
static void
process(BitsetT& bitset, ArithOpType a_op, CompareOpType cmp_op) {
process(BitsetT& bitset,
ArithOpType a_op,
CompareOpType cmp_op,
const int32_t right_operand_in,
const int32_t value_in) {
using HT = ArithHighPrecisionType<T>;
const size_t n = bitset.size();
constexpr size_t max_v = 10;
constexpr int32_t max_v = 10;
std::vector<T> left(n, 0);
const HT right_operand = from_i32<HT>(2);
const HT value = from_i32<HT>(5);
const HT right_operand = from_i32<HT>(right_operand_in);
const HT value = from_i32<HT>(value_in);
std::default_random_engine rng(123);
FillRandom(left, rng, max_v);
// Generate values in the [-max_v, max_v] range.
// This is fine, because we are operating on signed integers.
FillRandomRange(left, rng, -max_v, max_v);
StopWatch sw;
bitset.inplace_arith_compare(
@@ -1321,110 +1340,140 @@ struct TestInplaceArithCompareImplS {
if (a_op == ArithOpType::Add) {
if (cmp_op == CompareOpType::EQ) {
ASSERT_EQ((left[i] + right_operand) == value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GE) {
ASSERT_EQ((left[i] + right_operand) >= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GT) {
ASSERT_EQ((left[i] + right_operand) > value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LE) {
ASSERT_EQ((left[i] + right_operand) <= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LT) {
ASSERT_EQ((left[i] + right_operand) < value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::NE) {
ASSERT_EQ((left[i] + right_operand) != value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else {
ASSERT_TRUE(false) << "Not implemented";
}
} else if (a_op == ArithOpType::Sub) {
if (cmp_op == CompareOpType::EQ) {
ASSERT_EQ((left[i] - right_operand) == value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GE) {
ASSERT_EQ((left[i] - right_operand) >= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GT) {
ASSERT_EQ((left[i] - right_operand) > value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LE) {
ASSERT_EQ((left[i] - right_operand) <= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LT) {
ASSERT_EQ((left[i] - right_operand) < value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::NE) {
ASSERT_EQ((left[i] - right_operand) != value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else {
ASSERT_TRUE(false) << "Not implemented";
}
} else if (a_op == ArithOpType::Mul) {
if (cmp_op == CompareOpType::EQ) {
ASSERT_EQ((left[i] * right_operand) == value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GE) {
ASSERT_EQ((left[i] * right_operand) >= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GT) {
ASSERT_EQ((left[i] * right_operand) > value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LE) {
ASSERT_EQ((left[i] * right_operand) <= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LT) {
ASSERT_EQ((left[i] * right_operand) < value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::NE) {
ASSERT_EQ((left[i] * right_operand) != value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else {
ASSERT_TRUE(false) << "Not implemented";
}
} else if (a_op == ArithOpType::Div) {
if (cmp_op == CompareOpType::EQ) {
ASSERT_EQ((left[i] / right_operand) == value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GE) {
ASSERT_EQ((left[i] / right_operand) >= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GT) {
ASSERT_EQ((left[i] / right_operand) > value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LE) {
ASSERT_EQ((left[i] / right_operand) <= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LT) {
ASSERT_EQ((left[i] / right_operand) < value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::NE) {
ASSERT_EQ((left[i] / right_operand) != value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else {
ASSERT_TRUE(false) << "Not implemented";
}
} else if (a_op == ArithOpType::Mod) {
if (cmp_op == CompareOpType::EQ) {
ASSERT_EQ(fmod(left[i], right_operand) == value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GE) {
ASSERT_EQ(fmod(left[i], right_operand) >= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GT) {
ASSERT_EQ(fmod(left[i], right_operand) > value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LE) {
ASSERT_EQ(fmod(left[i], right_operand) <= value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LT) {
ASSERT_EQ(fmod(left[i], right_operand) < value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::NE) {
ASSERT_EQ(fmod(left[i], right_operand) != value, bitset[i])
<< i;
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else {
ASSERT_TRUE(false) << "Not implemented";
}
@@ -1433,12 +1482,63 @@ struct TestInplaceArithCompareImplS {
}
}
}
static void
process_div_special(BitsetT& bitset,
CompareOpType cmp_op,
const T left_v,
const T right_v,
const T value_v) {
// test a single special point for the division
using HT = ArithHighPrecisionType<T>;
const size_t n = bitset.size();
std::vector<T> left(n, left_v);
const HT right_operand = right_v;
const HT value = value_v;
bitset.inplace_arith_compare(
left.data(), right_operand, value, n, ArithOpType::Div, cmp_op);
for (size_t i = 0; i < n; i++) {
if (cmp_op == CompareOpType::EQ) {
ASSERT_EQ((left[i] / right_operand) == value, bitset[i])
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GE) {
ASSERT_EQ((left[i] / right_operand) >= value, bitset[i])
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::GT) {
ASSERT_EQ((left[i] / right_operand) > value, bitset[i])
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LE) {
ASSERT_EQ((left[i] / right_operand) <= value, bitset[i])
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::LT) {
ASSERT_EQ((left[i] / right_operand) < value, bitset[i])
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else if (cmp_op == CompareOpType::NE) {
ASSERT_EQ((left[i] / right_operand) != value, bitset[i])
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
<< right_operand << " " << value;
} else {
ASSERT_TRUE(false) << "Not implemented";
}
}
}
};
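
process_div_special compares against plain scalar division, which relies on standard IEEE-754 behavior for the special values fed in below; a compact reminder of those semantics (illustrative):

#include <cassert>
#include <cmath>
#include <limits>

int main() {
    volatile float zero = 0.0f;  // volatile keeps the divisions at run time
    assert(std::isinf(1.0f / zero));                                //  1 / 0   -> +Inf
    assert(std::isnan(zero / zero));                                //  0 / 0   -> NaN
    assert(1.0f / std::numeric_limits<float>::infinity() == 0.0f);  //  1 / Inf -> 0
    // Every ordered comparison involving NaN is false; only != is true.
    const float nan_v = std::numeric_limits<float>::quiet_NaN();
    assert(!(nan_v < 1.0f) && !(nan_v >= 1.0f) && !(nan_v == nan_v));
    assert(nan_v != 1.0f);
    return 0;
}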
template <typename BitsetT>
struct TestInplaceArithCompareImplS<BitsetT, std::string> {
static void
process(BitsetT&, ArithOpType, CompareOpType) {
process(
BitsetT&, ArithOpType, CompareOpType, const int32_t, const int32_t) {
// does nothing
}
};
@@ -1446,40 +1546,102 @@ struct TestInplaceArithCompareImplS<BitsetT, std::string> {
template <typename BitsetT, typename T>
void
TestInplaceArithCompareImpl() {
for (const size_t n : typical_sizes) {
for (const auto a_op : typical_arith_ops) {
for (const auto cmp_op : typical_compare_ops) {
BitsetT bitset(n);
bitset.reset();
if constexpr (std::is_floating_point_v<T>)
for (const size_t n : typical_sizes) {
for (const auto a_op : typical_arith_ops) {
for (const auto cmp_op : typical_compare_ops) {
// test positive, zero, and negative values
for (const int32_t right_operand : {2, 0, -2}) {
if ((!std::is_floating_point_v<T> ||
a_op == milvus::bitset::ArithOpType::Mod) &&
right_operand == 0) {
continue;
}
if (print_log) {
printf(
"Testing bitset, n=%zd, a_op=%zd\n", n, (size_t)a_op);
// test positive, zero, and negative values
for (const int32_t value : {2, 0, -2}) {
BitsetT bitset(n);
bitset.reset();
if (print_log) {
printf(
"Testing bitset, n=%zd, a_op=%zd, "
"cmp_op=%zd, right_operand=%d\n",
n,
(size_t)a_op,
(size_t)cmp_op,
right_operand);
}
TestInplaceArithCompareImplS<BitsetT, T>::process(
bitset, a_op, cmp_op, right_operand, value);
for (const size_t offset : typical_offsets) {
if (offset >= n) {
continue;
}
bitset.reset();
auto view = bitset.view(offset);
if (print_log) {
printf(
"Testing bitset view, n=%zd, "
"offset=%zd, a_op=%zd, cmp_op=%zd, "
"right_operand=%d\n",
n,
offset,
(size_t)a_op,
(size_t)cmp_op,
right_operand);
}
TestInplaceArithCompareImplS<
decltype(view),
T>::process(view,
a_op,
cmp_op,
right_operand,
value);
}
}
}
}
}
}
TestInplaceArithCompareImplS<BitsetT, T>::process(
bitset, a_op, cmp_op);
if constexpr (std::is_floating_point_v<T>) {
// test various IEEE-754 special cases for the division operation.
std::vector<T> variety = {0,
1,
-1,
std::numeric_limits<T>::quiet_NaN(),
-std::numeric_limits<T>::quiet_NaN(),
std::numeric_limits<T>::infinity(),
-std::numeric_limits<T>::infinity()};
for (const size_t offset : typical_offsets) {
if (offset >= n) {
continue;
for (const auto cmp_op : typical_compare_ops) {
for (const T left_v : variety) {
for (const T right_v : variety) {
for (const T value_v : variety) {
// 40 should be sufficient to test avx512
BitsetT bitset(40);
bitset.reset();
if (print_log) {
printf(
"Testing bitset div special case, cmp_op=%zd, "
"left_v=%f, right_v=%f, value_v=%f\n",
(size_t)cmp_op,
left_v,
right_v,
value_v);
}
TestInplaceArithCompareImplS<BitsetT, T>::
process_div_special(
bitset, cmp_op, left_v, right_v, value_v);
}
bitset.reset();
auto view = bitset.view(offset);
if (print_log) {
printf(
"Testing bitset view, n=%zd, offset=%zd, a_op=%zd, "
"cmp_op=%zd\n",
n,
offset,
(size_t)a_op,
(size_t)cmp_op);
}
TestInplaceArithCompareImplS<decltype(view), T>::process(
view, a_op, cmp_op);
}
}
}