mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 09:38:39 +08:00
fix: fix incorrect bitset for the division comparison when the right is < 0 (#43179)
issue: https://github.com/milvus-io/milvus/issues/42900 @sunby Unfortunately, this is not as easy to fix as was thought in #43177. Upd: also handles `Inf` and `NaN` values, and the division-by-zero case for `fp32` and `fp64`. Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
This commit is contained in:
parent
15a6631147
commit
a848c4a8c5
@ -157,5 +157,19 @@ struct ArithCompareOperator {
|
||||
}
|
||||
};
|
||||
|
||||
// This is related for a special handling of A/B vs C comparison.
|
||||
// A multiplication operation is used instead of a division,
|
||||
// and it is needed to invert signs and change comparison operators
|
||||
// in case if the denominator is negative.
|
||||
template <CompareOpType CmpOp>
|
||||
struct CompareOpDivFlip {
|
||||
static constexpr CompareOpType op =
|
||||
(CmpOp == CompareOpType::LE) ? CompareOpType::GE
|
||||
: (CmpOp == CompareOpType::LT) ? CompareOpType::GT
|
||||
: (CmpOp == CompareOpType::GE) ? CompareOpType::LE
|
||||
: (CmpOp == CompareOpType::GT) ? CompareOpType::LT
|
||||
: CmpOp;
|
||||
};
|
||||
|
||||
} // namespace bitset
|
||||
} // namespace milvus
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include <arm_neon.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
@ -1459,14 +1460,25 @@ struct ArithHelperF32<ArithOpType::Mul, CmpOp> {
|
||||
template <CompareOpType CmpOp>
|
||||
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
|
||||
static inline uint32x4x2_t
|
||||
op(const float32x4x2_t left,
|
||||
const float32x4x2_t right,
|
||||
const float32x4x2_t value) {
|
||||
op_special(const float32x4x2_t left,
|
||||
const float32x4x2_t right,
|
||||
const float32x4x2_t value) {
|
||||
// this is valid for the positive denominator, == and != cases.
|
||||
// left == right * value
|
||||
const float32x4x2_t rv = {vmulq_f32(right.val[0], value.val[0]),
|
||||
vmulq_f32(right.val[1], value.val[1])};
|
||||
return CmpHelper<CmpOp>::compare(left, rv);
|
||||
}
|
||||
|
||||
static inline uint32x4x2_t
|
||||
op(const float32x4x2_t left,
|
||||
const float32x4x2_t right,
|
||||
const float32x4x2_t value) {
|
||||
// left / right == value
|
||||
const float32x4x2_t rv = {vdivq_f32(left.val[0], right.val[0]),
|
||||
vdivq_f32(left.val[1], right.val[1])};
|
||||
return CmpHelper<CmpOp>::compare(rv, value);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
@ -1521,9 +1533,9 @@ struct ArithHelperF64<ArithOpType::Mul, CmpOp> {
|
||||
template <CompareOpType CmpOp>
|
||||
struct ArithHelperF64<ArithOpType::Div, CmpOp> {
|
||||
static inline uint64x2x4_t
|
||||
op(const float64x2x4_t left,
|
||||
const float64x2x4_t right,
|
||||
const float64x2x4_t value) {
|
||||
op_special(const float64x2x4_t left,
|
||||
const float64x2x4_t right,
|
||||
const float64x2x4_t value) {
|
||||
// left == right * value
|
||||
const float64x2x4_t rv = {vmulq_f64(right.val[0], value.val[0]),
|
||||
vmulq_f64(right.val[1], value.val[1]),
|
||||
@ -1531,6 +1543,18 @@ struct ArithHelperF64<ArithOpType::Div, CmpOp> {
|
||||
vmulq_f64(right.val[3], value.val[3])};
|
||||
return CmpHelper<CmpOp>::compare(left, rv);
|
||||
}
|
||||
|
||||
static inline uint64x2x4_t
|
||||
op(const float64x2x4_t left,
|
||||
const float64x2x4_t right,
|
||||
const float64x2x4_t value) {
|
||||
// left / right == value
|
||||
const float64x2x4_t rv = {vdivq_f64(left.val[0], right.val[0]),
|
||||
vdivq_f64(left.val[1], right.val[1]),
|
||||
vdivq_f64(left.val[2], right.val[2]),
|
||||
vdivq_f64(left.val[3], right.val[3])};
|
||||
return CmpHelper<CmpOp>::compare(rv, value);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -1743,28 +1767,74 @@ OpArithCompareImpl<float, AOp, CmpOp>::op_arith_compare(
|
||||
if constexpr (AOp == ArithOpType::Mod) {
|
||||
return false;
|
||||
} else {
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
if constexpr (AOp == ArithOpType::Div) {
|
||||
if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand > 0) {
|
||||
// a special case that allows faster processing by using the multiplication
|
||||
// operation instead of the division one.
|
||||
|
||||
//
|
||||
const float32x4x2_t right_v = {vdupq_n_f32(right_operand),
|
||||
vdupq_n_f32(right_operand)};
|
||||
const float32x4x2_t value_v = {vdupq_n_f32(value), vdupq_n_f32(value)};
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
//
|
||||
const float32x4x2_t right_v = {vdupq_n_f32(right_operand),
|
||||
vdupq_n_f32(right_operand)};
|
||||
const float32x4x2_t value_v = {vdupq_n_f32(value),
|
||||
vdupq_n_f32(value)};
|
||||
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const float32x4x2_t v0v = {vld1q_f32(src + i),
|
||||
vld1q_f32(src + i + 4)};
|
||||
const uint32x4x2_t cmp =
|
||||
ArithHelperF32<AOp, CmpOp>::op(v0v, right_v, value_v);
|
||||
// todo: aligned reads & writes
|
||||
|
||||
const uint8_t mmask = movemask(cmp);
|
||||
res_u8[i / 8] = mmask;
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const float32x4x2_t v0v = {vld1q_f32(src + i),
|
||||
vld1q_f32(src + i + 4)};
|
||||
const uint32x4x2_t cmp =
|
||||
ArithHelperF32<AOp, CmpOp>::op_special(
|
||||
v0v, right_v, value_v);
|
||||
|
||||
const uint8_t mmask = movemask(cmp);
|
||||
res_u8[i / 8] = mmask;
|
||||
}
|
||||
|
||||
return true;
|
||||
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand < 0) {
|
||||
// flip signs and go for the multiplication case
|
||||
return OpArithCompareImpl<float,
|
||||
AOp,
|
||||
CompareOpDivFlip<CmpOp>::op>::
|
||||
op_arith_compare(res_u8, src, -right_operand, -value, size);
|
||||
}
|
||||
|
||||
// go with the default case
|
||||
}
|
||||
|
||||
return true;
|
||||
// a default case
|
||||
{
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
//
|
||||
const float32x4x2_t right_v = {vdupq_n_f32(right_operand),
|
||||
vdupq_n_f32(right_operand)};
|
||||
const float32x4x2_t value_v = {vdupq_n_f32(value),
|
||||
vdupq_n_f32(value)};
|
||||
|
||||
// todo: aligned reads & writes
|
||||
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const float32x4x2_t v0v = {vld1q_f32(src + i),
|
||||
vld1q_f32(src + i + 4)};
|
||||
const uint32x4x2_t cmp =
|
||||
ArithHelperF32<AOp, CmpOp>::op(v0v, right_v, value_v);
|
||||
|
||||
const uint8_t mmask = movemask(cmp);
|
||||
res_u8[i / 8] = mmask;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1779,35 +1849,86 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
|
||||
if constexpr (AOp == ArithOpType::Mod) {
|
||||
return false;
|
||||
} else {
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
if constexpr (AOp == ArithOpType::Div) {
|
||||
if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand > 0) {
|
||||
// a special case that allows faster processing by using the multiplication
|
||||
// operation instead of the division one.
|
||||
|
||||
//
|
||||
const float64x2x4_t right_v = {vdupq_n_f64(right_operand),
|
||||
vdupq_n_f64(right_operand),
|
||||
vdupq_n_f64(right_operand),
|
||||
vdupq_n_f64(right_operand)};
|
||||
const float64x2x4_t value_v = {vdupq_n_f64(value),
|
||||
vdupq_n_f64(value),
|
||||
vdupq_n_f64(value),
|
||||
vdupq_n_f64(value)};
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
//
|
||||
const float64x2x4_t right_v = {vdupq_n_f64(right_operand),
|
||||
vdupq_n_f64(right_operand),
|
||||
vdupq_n_f64(right_operand),
|
||||
vdupq_n_f64(right_operand)};
|
||||
const float64x2x4_t value_v = {vdupq_n_f64(value),
|
||||
vdupq_n_f64(value),
|
||||
vdupq_n_f64(value),
|
||||
vdupq_n_f64(value)};
|
||||
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const float64x2x4_t v0v = {vld1q_f64(src + i),
|
||||
vld1q_f64(src + i + 2),
|
||||
vld1q_f64(src + i + 4),
|
||||
vld1q_f64(src + i + 6)};
|
||||
const uint64x2x4_t cmp =
|
||||
ArithHelperF64<AOp, CmpOp>::op(v0v, right_v, value_v);
|
||||
// todo: aligned reads & writes
|
||||
|
||||
const uint8_t mmask = movemask(cmp);
|
||||
res_u8[i / 8] = mmask;
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const float64x2x4_t v0v = {vld1q_f64(src + i),
|
||||
vld1q_f64(src + i + 2),
|
||||
vld1q_f64(src + i + 4),
|
||||
vld1q_f64(src + i + 6)};
|
||||
const uint64x2x4_t cmp =
|
||||
ArithHelperF64<AOp, CmpOp>::op_special(
|
||||
v0v, right_v, value_v);
|
||||
|
||||
const uint8_t mmask = movemask(cmp);
|
||||
res_u8[i / 8] = mmask;
|
||||
}
|
||||
|
||||
return true;
|
||||
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand < 0) {
|
||||
// flip signs and go for the multiplication case
|
||||
return OpArithCompareImpl<double,
|
||||
AOp,
|
||||
CompareOpDivFlip<CmpOp>::op>::
|
||||
op_arith_compare(res_u8, src, -right_operand, -value, size);
|
||||
}
|
||||
|
||||
// go with the default case
|
||||
}
|
||||
|
||||
return true;
|
||||
// a default case
|
||||
{
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
//
|
||||
const float64x2x4_t right_v = {vdupq_n_f64(right_operand),
|
||||
vdupq_n_f64(right_operand),
|
||||
vdupq_n_f64(right_operand),
|
||||
vdupq_n_f64(right_operand)};
|
||||
const float64x2x4_t value_v = {vdupq_n_f64(value),
|
||||
vdupq_n_f64(value),
|
||||
vdupq_n_f64(value),
|
||||
vdupq_n_f64(value)};
|
||||
|
||||
// todo: aligned reads & writes
|
||||
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const float64x2x4_t v0v = {vld1q_f64(src + i),
|
||||
vld1q_f64(src + i + 2),
|
||||
vld1q_f64(src + i + 4),
|
||||
vld1q_f64(src + i + 6)};
|
||||
const uint64x2x4_t cmp =
|
||||
ArithHelperF64<AOp, CmpOp>::op(v0v, right_v, value_v);
|
||||
|
||||
const uint8_t mmask = movemask(cmp);
|
||||
res_u8[i / 8] = mmask;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include <arm_sve.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
@ -1314,15 +1315,26 @@ struct ArithHelperI64<ArithOpType::Mul, CmpOp> {
|
||||
};
|
||||
|
||||
template <CompareOpType CmpOp>
|
||||
struct ArithHelperI64<ArithOpType::Div, CmpOp> {
|
||||
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
|
||||
static inline svbool_t
|
||||
op_special(const svbool_t pred,
|
||||
const svfloat32_t left,
|
||||
const svfloat32_t right,
|
||||
const svfloat32_t value) {
|
||||
// this is valid for the positive denominator, == and != cases.
|
||||
// left == right * value
|
||||
return CmpHelper<CmpOp>::compare(
|
||||
pred, left, svmul_f32_z(pred, right, value));
|
||||
}
|
||||
|
||||
static inline svbool_t
|
||||
op(const svbool_t pred,
|
||||
const svint64_t left,
|
||||
const svint64_t right,
|
||||
const svint64_t value) {
|
||||
const svfloat32_t left,
|
||||
const svfloat32_t right,
|
||||
const svfloat32_t value) {
|
||||
// left / right == value
|
||||
return CmpHelper<CmpOp>::compare(
|
||||
pred, svdiv_s64_z(pred, left, right), value);
|
||||
pred, svdiv_f32_z(pred, left, right), value);
|
||||
}
|
||||
};
|
||||
|
||||
@ -1371,14 +1383,25 @@ struct ArithHelperF32<ArithOpType::Mul, CmpOp> {
|
||||
|
||||
template <CompareOpType CmpOp>
|
||||
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
|
||||
static inline svbool_t
|
||||
op_special(const svbool_t pred,
|
||||
const svfloat32_t left,
|
||||
const svfloat32_t right,
|
||||
const svfloat32_t value) {
|
||||
// this is valid for the positive denominator, == and != cases.
|
||||
// left == right * value
|
||||
return CmpHelper<CmpOp>::compare(
|
||||
pred, left, svmul_f32_z(pred, right, value));
|
||||
}
|
||||
|
||||
static inline svbool_t
|
||||
op(const svbool_t pred,
|
||||
const svfloat32_t left,
|
||||
const svfloat32_t right,
|
||||
const svfloat32_t value) {
|
||||
// left == right * value
|
||||
// left / right == value
|
||||
return CmpHelper<CmpOp>::compare(
|
||||
pred, left, svmul_f32_z(pred, right, value));
|
||||
pred, svdiv_f32_z(pred, left, right), value);
|
||||
}
|
||||
};
|
||||
|
||||
@ -1427,14 +1450,25 @@ struct ArithHelperF64<ArithOpType::Mul, CmpOp> {
|
||||
|
||||
template <CompareOpType CmpOp>
|
||||
struct ArithHelperF64<ArithOpType::Div, CmpOp> {
|
||||
static inline svbool_t
|
||||
op_special(const svbool_t pred,
|
||||
const svfloat64_t left,
|
||||
const svfloat64_t right,
|
||||
const svfloat64_t value) {
|
||||
// this is valid for the positive denominator, == and != cases.
|
||||
// left == right * value
|
||||
return CmpHelper<CmpOp>::compare(
|
||||
pred, left, svmul_f64_z(pred, right, value));
|
||||
}
|
||||
|
||||
static inline svbool_t
|
||||
op(const svbool_t pred,
|
||||
const svfloat64_t left,
|
||||
const svfloat64_t right,
|
||||
const svfloat64_t value) {
|
||||
// left == right * value
|
||||
// left / right == value
|
||||
return CmpHelper<CmpOp>::compare(
|
||||
pred, left, svmul_f64_z(pred, right, value));
|
||||
pred, svdiv_f64_z(pred, left, right), value);
|
||||
}
|
||||
};
|
||||
|
||||
@ -1573,22 +1607,60 @@ OpArithCompareImpl<float, AOp, CmpOp>::op_arith_compare(
|
||||
if constexpr (AOp == ArithOpType::Mod) {
|
||||
return false;
|
||||
} else {
|
||||
using T = float;
|
||||
if constexpr (AOp == ArithOpType::Div) {
|
||||
if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand > 0) {
|
||||
// a special case that allows faster processing by using the multiplication
|
||||
// operation instead of the division one.
|
||||
|
||||
auto handler = [src, right_operand, value](const svbool_t pred,
|
||||
const size_t idx) {
|
||||
using sve_t = SVEVector<T>;
|
||||
using T = float;
|
||||
|
||||
const auto right_v = svdup_n_f32(right_operand);
|
||||
const auto value_v = svdup_n_f32(value);
|
||||
const svfloat32_t src_v = svld1_f32(pred, src + idx);
|
||||
auto handler = [src, right_operand, value](const svbool_t pred,
|
||||
const size_t idx) {
|
||||
using sve_t = SVEVector<T>;
|
||||
|
||||
const svbool_t cmp =
|
||||
ArithHelperF32<AOp, CmpOp>::op(pred, src_v, right_v, value_v);
|
||||
return cmp;
|
||||
};
|
||||
const auto right_v = svdup_n_f32(right_operand);
|
||||
const auto value_v = svdup_n_f32(value);
|
||||
const svfloat32_t src_v = svld1_f32(pred, src + idx);
|
||||
|
||||
return op_mask_helper<T, decltype(handler)>(res_u8, size, handler);
|
||||
const svbool_t cmp = ArithHelperF32<AOp, CmpOp>::op_special(
|
||||
pred, src_v, right_v, value_v);
|
||||
return cmp;
|
||||
};
|
||||
|
||||
return op_mask_helper<T, decltype(handler)>(
|
||||
res_u8, size, handler);
|
||||
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand < 0) {
|
||||
// flip signs and go for the multiplication case
|
||||
return OpArithCompareImpl<float,
|
||||
AOp,
|
||||
CompareOpDivFlip<CmpOp>::op>::
|
||||
op_arith_compare(res_u8, src, -right_operand, -value, size);
|
||||
}
|
||||
|
||||
// go with the default case
|
||||
}
|
||||
|
||||
// a default case
|
||||
{
|
||||
using T = float;
|
||||
|
||||
auto handler = [src, right_operand, value](const svbool_t pred,
|
||||
const size_t idx) {
|
||||
using sve_t = SVEVector<T>;
|
||||
|
||||
const auto right_v = svdup_n_f32(right_operand);
|
||||
const auto value_v = svdup_n_f32(value);
|
||||
const svfloat32_t src_v = svld1_f32(pred, src + idx);
|
||||
|
||||
const svbool_t cmp = ArithHelperF32<AOp, CmpOp>::op(
|
||||
pred, src_v, right_v, value_v);
|
||||
return cmp;
|
||||
};
|
||||
|
||||
return op_mask_helper<T, decltype(handler)>(res_u8, size, handler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1603,22 +1675,60 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
|
||||
if constexpr (AOp == ArithOpType::Mod) {
|
||||
return false;
|
||||
} else {
|
||||
using T = double;
|
||||
if constexpr (AOp == ArithOpType::Div) {
|
||||
if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand > 0) {
|
||||
// a special case that allows faster processing by using the multiplication
|
||||
// operation instead of the division one.
|
||||
|
||||
auto handler = [src, right_operand, value](const svbool_t pred,
|
||||
const size_t idx) {
|
||||
using sve_t = SVEVector<T>;
|
||||
using T = double;
|
||||
|
||||
const auto right_v = svdup_n_f64(right_operand);
|
||||
const auto value_v = svdup_n_f64(value);
|
||||
const svfloat64_t src_v = svld1_f64(pred, src + idx);
|
||||
auto handler = [src, right_operand, value](const svbool_t pred,
|
||||
const size_t idx) {
|
||||
using sve_t = SVEVector<T>;
|
||||
|
||||
const svbool_t cmp =
|
||||
ArithHelperF64<AOp, CmpOp>::op(pred, src_v, right_v, value_v);
|
||||
return cmp;
|
||||
};
|
||||
const auto right_v = svdup_n_f64(right_operand);
|
||||
const auto value_v = svdup_n_f64(value);
|
||||
const svfloat64_t src_v = svld1_f64(pred, src + idx);
|
||||
|
||||
return op_mask_helper<T, decltype(handler)>(res_u8, size, handler);
|
||||
const svbool_t cmp = ArithHelperF64<AOp, CmpOp>::op(
|
||||
pred, src_v, right_v, value_v);
|
||||
return cmp;
|
||||
};
|
||||
|
||||
return op_mask_helper<T, decltype(handler)>(
|
||||
res_u8, size, handler);
|
||||
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand < 0) {
|
||||
// flip signs and go for the multiplication case
|
||||
return OpArithCompareImpl<double,
|
||||
AOp,
|
||||
CompareOpDivFlip<CmpOp>::op>::
|
||||
op_arith_compare(res_u8, src, -right_operand, -value, size);
|
||||
}
|
||||
|
||||
// go with the default case
|
||||
}
|
||||
|
||||
// a default case
|
||||
{
|
||||
using T = double;
|
||||
|
||||
auto handler = [src, right_operand, value](const svbool_t pred,
|
||||
const size_t idx) {
|
||||
using sve_t = SVEVector<T>;
|
||||
|
||||
const auto right_v = svdup_n_f64(right_operand);
|
||||
const auto value_v = svdup_n_f64(value);
|
||||
const svfloat64_t src_v = svld1_f64(pred, src + idx);
|
||||
|
||||
const svbool_t cmp = ArithHelperF64<AOp, CmpOp>::op(
|
||||
pred, src_v, right_v, value_v);
|
||||
return cmp;
|
||||
};
|
||||
|
||||
return op_mask_helper<T, decltype(handler)>(res_u8, size, handler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include <immintrin.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
@ -1347,11 +1348,19 @@ struct ArithHelperF32<ArithOpType::Mul, CmpOp> {
|
||||
template <CompareOpType CmpOp>
|
||||
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
|
||||
static inline __m256
|
||||
op(const __m256 left, const __m256 right, const __m256 value) {
|
||||
op_special(const __m256 left, const __m256 right, const __m256 value) {
|
||||
// this is valid for the positive denominator, == and != cases.
|
||||
// left == right * value
|
||||
constexpr auto pred = ComparePredicate<float, CmpOp>::value;
|
||||
return _mm256_cmp_ps(left, _mm256_mul_ps(right, value), pred);
|
||||
}
|
||||
|
||||
static inline __m256
|
||||
op(const __m256 left, const __m256 right, const __m256 value) {
|
||||
// left / right == value
|
||||
constexpr auto pred = ComparePredicate<float, CmpOp>::value;
|
||||
return _mm256_cmp_ps(_mm256_div_ps(left, right), value, pred);
|
||||
}
|
||||
};
|
||||
|
||||
// todo: Mod
|
||||
@ -1393,11 +1402,19 @@ struct ArithHelperF64<ArithOpType::Mul, CmpOp> {
|
||||
template <CompareOpType CmpOp>
|
||||
struct ArithHelperF64<ArithOpType::Div, CmpOp> {
|
||||
static inline __m256d
|
||||
op(const __m256d left, const __m256d right, const __m256d value) {
|
||||
op_special(const __m256d left, const __m256d right, const __m256d value) {
|
||||
// this is valid for the positive denominator, == and != cases.
|
||||
// left == right * value
|
||||
constexpr auto pred = ComparePredicate<double, CmpOp>::value;
|
||||
return _mm256_cmp_pd(left, _mm256_mul_pd(right, value), pred);
|
||||
}
|
||||
|
||||
static inline __m256d
|
||||
op(const __m256d left, const __m256d right, const __m256d value) {
|
||||
// left / right == value
|
||||
constexpr auto pred = ComparePredicate<double, CmpOp>::value;
|
||||
return _mm256_cmp_pd(_mm256_div_pd(left, right), value, pred);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -1589,26 +1606,67 @@ OpArithCompareImpl<float, AOp, CmpOp>::op_arith_compare(
|
||||
if constexpr (AOp == ArithOpType::Mod) {
|
||||
return false;
|
||||
} else {
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
if constexpr (AOp == ArithOpType::Div) {
|
||||
if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand > 0) {
|
||||
// a special case that allows faster processing by using the multiplication
|
||||
// operation instead of the division one.
|
||||
|
||||
//
|
||||
const __m256 right_v = _mm256_set1_ps(right_operand);
|
||||
const __m256 value_v = _mm256_set1_ps(value);
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
//
|
||||
const __m256 right_v = _mm256_set1_ps(right_operand);
|
||||
const __m256 value_v = _mm256_set1_ps(value);
|
||||
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const __m256 v0s = _mm256_loadu_ps(src + i);
|
||||
const __m256 cmp =
|
||||
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
const uint8_t mmask = _mm256_movemask_ps(cmp);
|
||||
// todo: aligned reads & writes
|
||||
|
||||
res_u8[i / 8] = mmask;
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const __m256 v0s = _mm256_loadu_ps(src + i);
|
||||
const __m256 cmp = ArithHelperF32<AOp, CmpOp>::op_special(
|
||||
v0s, right_v, value_v);
|
||||
const uint8_t mmask = _mm256_movemask_ps(cmp);
|
||||
|
||||
res_u8[i / 8] = mmask;
|
||||
}
|
||||
|
||||
return true;
|
||||
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand < 0) {
|
||||
// flip signs and go for the multiplication case
|
||||
return OpArithCompareImpl<float,
|
||||
AOp,
|
||||
CompareOpDivFlip<CmpOp>::op>::
|
||||
op_arith_compare(res_u8, src, -right_operand, -value, size);
|
||||
}
|
||||
|
||||
// go with the default case
|
||||
}
|
||||
|
||||
return true;
|
||||
// a default case
|
||||
{
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
//
|
||||
const __m256 right_v = _mm256_set1_ps(right_operand);
|
||||
const __m256 value_v = _mm256_set1_ps(value);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const __m256 v0s = _mm256_loadu_ps(src + i);
|
||||
const __m256 cmp =
|
||||
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
const uint8_t mmask = _mm256_movemask_ps(cmp);
|
||||
|
||||
res_u8[i / 8] = mmask;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1623,30 +1681,75 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
|
||||
if constexpr (AOp == ArithOpType::Mod) {
|
||||
return false;
|
||||
} else {
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
if constexpr (AOp == ArithOpType::Div) {
|
||||
if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand > 0) {
|
||||
// a special case that allows faster processing by using the multiplication
|
||||
// operation instead of the division one.
|
||||
|
||||
//
|
||||
const __m256d right_v = _mm256_set1_pd(right_operand);
|
||||
const __m256d value_v = _mm256_set1_pd(value);
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
//
|
||||
const __m256d right_v = _mm256_set1_pd(right_operand);
|
||||
const __m256d value_v = _mm256_set1_pd(value);
|
||||
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const __m256d v0s = _mm256_loadu_pd(src + i);
|
||||
const __m256d v1s = _mm256_loadu_pd(src + i + 4);
|
||||
const __m256d cmp0 =
|
||||
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
const __m256d cmp1 =
|
||||
ArithHelperF64<AOp, CmpOp>::op(v1s, right_v, value_v);
|
||||
const uint8_t mmask0 = _mm256_movemask_pd(cmp0);
|
||||
const uint8_t mmask1 = _mm256_movemask_pd(cmp1);
|
||||
// todo: aligned reads & writes
|
||||
|
||||
res_u8[i / 8] = mmask0 + mmask1 * 16;
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const __m256d v0s = _mm256_loadu_pd(src + i);
|
||||
const __m256d v1s = _mm256_loadu_pd(src + i + 4);
|
||||
const __m256d cmp0 = ArithHelperF64<AOp, CmpOp>::op_special(
|
||||
v0s, right_v, value_v);
|
||||
const __m256d cmp1 = ArithHelperF64<AOp, CmpOp>::op_special(
|
||||
v1s, right_v, value_v);
|
||||
const uint8_t mmask0 = _mm256_movemask_pd(cmp0);
|
||||
const uint8_t mmask1 = _mm256_movemask_pd(cmp1);
|
||||
|
||||
res_u8[i / 8] = mmask0 + mmask1 * 16;
|
||||
}
|
||||
|
||||
return true;
|
||||
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand < 0) {
|
||||
// flip signs and go for the multiplication case
|
||||
return OpArithCompareImpl<double,
|
||||
AOp,
|
||||
CompareOpDivFlip<CmpOp>::op>::
|
||||
op_arith_compare(res_u8, src, -right_operand, -value, size);
|
||||
}
|
||||
|
||||
// go with the default case
|
||||
}
|
||||
|
||||
return true;
|
||||
// a default case
|
||||
{
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
//
|
||||
const __m256d right_v = _mm256_set1_pd(right_operand);
|
||||
const __m256d value_v = _mm256_set1_pd(value);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = 0; i < size8; i += 8) {
|
||||
const __m256d v0s = _mm256_loadu_pd(src + i);
|
||||
const __m256d v1s = _mm256_loadu_pd(src + i + 4);
|
||||
const __m256d cmp0 =
|
||||
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
const __m256d cmp1 =
|
||||
ArithHelperF64<AOp, CmpOp>::op(v1s, right_v, value_v);
|
||||
const uint8_t mmask0 = _mm256_movemask_pd(cmp0);
|
||||
const uint8_t mmask1 = _mm256_movemask_pd(cmp1);
|
||||
|
||||
res_u8[i / 8] = mmask0 + mmask1 * 16;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include <immintrin.h>
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
@ -1391,11 +1392,19 @@ struct ArithHelperF32<ArithOpType::Mul, CmpOp> {
|
||||
template <CompareOpType CmpOp>
|
||||
struct ArithHelperF32<ArithOpType::Div, CmpOp> {
|
||||
static inline __mmask16
|
||||
op(const __m512 left, const __m512 right, const __m512 value) {
|
||||
op_special(const __m512 left, const __m512 right, const __m512 value) {
|
||||
// this is valid for the positive denominator, == and != cases.
|
||||
// left == right * value
|
||||
constexpr auto pred = ComparePredicate<float, CmpOp>::value;
|
||||
return _mm512_cmp_ps_mask(left, _mm512_mul_ps(right, value), pred);
|
||||
}
|
||||
|
||||
static inline __mmask16
|
||||
op(const __m512 left, const __m512 right, const __m512 value) {
|
||||
// left / right == value
|
||||
constexpr auto pred = ComparePredicate<float, CmpOp>::value;
|
||||
return _mm512_cmp_ps_mask(_mm512_div_ps(left, right), value, pred);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
@ -1435,11 +1444,19 @@ struct ArithHelperF64<ArithOpType::Mul, CmpOp> {
|
||||
template <CompareOpType CmpOp>
|
||||
struct ArithHelperF64<ArithOpType::Div, CmpOp> {
|
||||
static inline __mmask8
|
||||
op(const __m512d left, const __m512d right, const __m512d value) {
|
||||
op_special(const __m512d left, const __m512d right, const __m512d value) {
|
||||
// this is valid for the positive denominator, == and != cases.
|
||||
// left == right * value
|
||||
constexpr auto pred = ComparePredicate<double, CmpOp>::value;
|
||||
return _mm512_cmp_pd_mask(left, _mm512_mul_pd(right, value), pred);
|
||||
}
|
||||
|
||||
static inline __mmask8
|
||||
op(const __m512d left, const __m512d right, const __m512d value) {
|
||||
// left / right == value
|
||||
constexpr auto pred = ComparePredicate<double, CmpOp>::value;
|
||||
return _mm512_cmp_pd_mask(_mm512_div_pd(left, right), value, pred);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -1762,58 +1779,137 @@ OpArithCompareImpl<float, AOp, CmpOp>::op_arith_compare(
|
||||
if constexpr (AOp == ArithOpType::Mod) {
|
||||
return false;
|
||||
} else {
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
if constexpr (AOp == ArithOpType::Div) {
|
||||
if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand > 0) {
|
||||
// a special case that allows faster processing by using the multiplication
|
||||
// operation instead of the division one.
|
||||
|
||||
//
|
||||
const __m512 right_v = _mm512_set1_ps(right_operand);
|
||||
const __m512 value_v = _mm512_set1_ps(value);
|
||||
uint16_t* const __restrict res_u16 =
|
||||
reinterpret_cast<uint16_t*>(res_u8);
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
//
|
||||
const __m512 right_v = _mm512_set1_ps(right_operand);
|
||||
const __m512 value_v = _mm512_set1_ps(value);
|
||||
uint16_t* const __restrict res_u16 =
|
||||
reinterpret_cast<uint16_t*>(res_u8);
|
||||
|
||||
// interleaved pages
|
||||
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(float));
|
||||
const size_t size_8p =
|
||||
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
|
||||
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
|
||||
for (size_t p = 0; p < BLOCK_COUNT; p += 16) {
|
||||
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
|
||||
const __m512 v0s =
|
||||
_mm512_loadu_ps(src + i + p + ip * BLOCK_COUNT);
|
||||
// todo: aligned reads & writes
|
||||
|
||||
// interleaved pages
|
||||
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(float));
|
||||
const size_t size_8p =
|
||||
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
|
||||
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
|
||||
for (size_t p = 0; p < BLOCK_COUNT; p += 16) {
|
||||
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
|
||||
const __m512 v0s =
|
||||
_mm512_loadu_ps(src + i + p + ip * BLOCK_COUNT);
|
||||
const __mmask16 cmp_mask =
|
||||
ArithHelperF32<AOp, CmpOp>::op_special(
|
||||
v0s, right_v, value_v);
|
||||
|
||||
res_u16[(i + p + ip * BLOCK_COUNT) / 16] = cmp_mask;
|
||||
|
||||
_mm_prefetch(
|
||||
(const char*)(src + i + p + ip * BLOCK_COUNT) +
|
||||
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
|
||||
_MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process big blocks
|
||||
const size_t size16 = (size / 16) * 16;
|
||||
for (size_t i = size_8p; i < size16; i += 16) {
|
||||
const __m512 v0s = _mm512_loadu_ps(src + i);
|
||||
const __mmask16 cmp_mask =
|
||||
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
ArithHelperF32<AOp, CmpOp>::op_special(
|
||||
v0s, right_v, value_v);
|
||||
res_u16[i / 16] = cmp_mask;
|
||||
}
|
||||
|
||||
res_u16[(i + p + ip * BLOCK_COUNT) / 16] = cmp_mask;
|
||||
// process leftovers
|
||||
if (size16 != size) {
|
||||
// process 8 elements
|
||||
const __m256 vs = _mm256_loadu_ps(src + size16);
|
||||
const __m512 v0s = _mm512_castps256_ps512(vs);
|
||||
const __mmask16 cmp_mask =
|
||||
ArithHelperF32<AOp, CmpOp>::op_special(
|
||||
v0s, right_v, value_v);
|
||||
res_u8[size16 / 8] = uint8_t(cmp_mask);
|
||||
}
|
||||
|
||||
_mm_prefetch((const char*)(src + i + p + ip * BLOCK_COUNT) +
|
||||
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
|
||||
_MM_HINT_T0);
|
||||
return true;
|
||||
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand < 0) {
|
||||
// flip signs and go for the multiplication case
|
||||
return OpArithCompareImpl<float,
|
||||
AOp,
|
||||
CompareOpDivFlip<CmpOp>::op>::
|
||||
op_arith_compare(res_u8, src, -right_operand, -value, size);
|
||||
}
|
||||
|
||||
// go with the default case
|
||||
}
|
||||
|
||||
// a default case
|
||||
{
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
//
|
||||
const __m512 right_v = _mm512_set1_ps(right_operand);
|
||||
const __m512 value_v = _mm512_set1_ps(value);
|
||||
uint16_t* const __restrict res_u16 =
|
||||
reinterpret_cast<uint16_t*>(res_u8);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
|
||||
// interleaved pages
|
||||
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(float));
|
||||
const size_t size_8p =
|
||||
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
|
||||
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
|
||||
for (size_t p = 0; p < BLOCK_COUNT; p += 16) {
|
||||
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
|
||||
const __m512 v0s =
|
||||
_mm512_loadu_ps(src + i + p + ip * BLOCK_COUNT);
|
||||
const __mmask16 cmp_mask =
|
||||
ArithHelperF32<AOp, CmpOp>::op(
|
||||
v0s, right_v, value_v);
|
||||
|
||||
res_u16[(i + p + ip * BLOCK_COUNT) / 16] = cmp_mask;
|
||||
|
||||
_mm_prefetch(
|
||||
(const char*)(src + i + p + ip * BLOCK_COUNT) +
|
||||
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
|
||||
_MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process big blocks
|
||||
const size_t size16 = (size / 16) * 16;
|
||||
for (size_t i = size_8p; i < size16; i += 16) {
|
||||
const __m512 v0s = _mm512_loadu_ps(src + i);
|
||||
const __mmask16 cmp_mask =
|
||||
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
res_u16[i / 16] = cmp_mask;
|
||||
}
|
||||
// process big blocks
|
||||
const size_t size16 = (size / 16) * 16;
|
||||
for (size_t i = size_8p; i < size16; i += 16) {
|
||||
const __m512 v0s = _mm512_loadu_ps(src + i);
|
||||
const __mmask16 cmp_mask =
|
||||
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
res_u16[i / 16] = cmp_mask;
|
||||
}
|
||||
|
||||
// process leftovers
|
||||
if (size16 != size) {
|
||||
// process 8 elements
|
||||
const __m256 vs = _mm256_loadu_ps(src + size16);
|
||||
const __m512 v0s = _mm512_castps256_ps512(vs);
|
||||
const __mmask16 cmp_mask =
|
||||
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
res_u8[size16 / 8] = uint8_t(cmp_mask);
|
||||
}
|
||||
// process leftovers
|
||||
if (size16 != size) {
|
||||
// process 8 elements
|
||||
const __m256 vs = _mm256_loadu_ps(src + size16);
|
||||
const __m512 v0s = _mm512_castps256_ps512(vs);
|
||||
const __mmask16 cmp_mask =
|
||||
ArithHelperF32<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
res_u8[size16 / 8] = uint8_t(cmp_mask);
|
||||
}
|
||||
|
||||
return true;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1828,47 +1924,114 @@ OpArithCompareImpl<double, AOp, CmpOp>::op_arith_compare(
|
||||
if constexpr (AOp == ArithOpType::Mod) {
|
||||
return false;
|
||||
} else {
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
if constexpr (AOp == ArithOpType::Div) {
|
||||
if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand > 0) {
|
||||
// a special case that allows faster processing by using the multiplication
|
||||
// operation instead of the division one.
|
||||
|
||||
//
|
||||
const __m512d right_v = _mm512_set1_pd(right_operand);
|
||||
const __m512d value_v = _mm512_set1_pd(value);
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
//
|
||||
const __m512d right_v = _mm512_set1_pd(right_operand);
|
||||
const __m512d value_v = _mm512_set1_pd(value);
|
||||
|
||||
// interleaved pages
|
||||
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(int64_t));
|
||||
const size_t size_8p =
|
||||
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
|
||||
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
|
||||
for (size_t p = 0; p < BLOCK_COUNT; p += 8) {
|
||||
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
|
||||
const __m512d v0s =
|
||||
_mm512_loadu_pd(src + i + p + ip * BLOCK_COUNT);
|
||||
// todo: aligned reads & writes
|
||||
|
||||
// interleaved pages
|
||||
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(int64_t));
|
||||
const size_t size_8p =
|
||||
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
|
||||
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
|
||||
for (size_t p = 0; p < BLOCK_COUNT; p += 8) {
|
||||
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
|
||||
const __m512d v0s =
|
||||
_mm512_loadu_pd(src + i + p + ip * BLOCK_COUNT);
|
||||
const __mmask8 cmp_mask =
|
||||
ArithHelperF64<AOp, CmpOp>::op_special(
|
||||
v0s, right_v, value_v);
|
||||
|
||||
res_u8[(i + p + ip * BLOCK_COUNT) / 8] = cmp_mask;
|
||||
|
||||
_mm_prefetch(
|
||||
(const char*)(src + i + p + ip * BLOCK_COUNT) +
|
||||
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
|
||||
_MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process big blocks
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = size_8p; i < size8; i += 8) {
|
||||
const __m512d v0s = _mm512_loadu_pd(src + i);
|
||||
const __mmask8 cmp_mask =
|
||||
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
ArithHelperF64<AOp, CmpOp>::op_special(
|
||||
v0s, right_v, value_v);
|
||||
|
||||
res_u8[(i + p + ip * BLOCK_COUNT) / 8] = cmp_mask;
|
||||
res_u8[i / 8] = cmp_mask;
|
||||
}
|
||||
|
||||
_mm_prefetch((const char*)(src + i + p + ip * BLOCK_COUNT) +
|
||||
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
|
||||
_MM_HINT_T0);
|
||||
return true;
|
||||
} else if (std::isfinite(value) && std::isfinite(right_operand) &&
|
||||
right_operand < 0) {
|
||||
// flip signs and go for the multiplication case
|
||||
return OpArithCompareImpl<double,
|
||||
AOp,
|
||||
CompareOpDivFlip<CmpOp>::op>::
|
||||
op_arith_compare(res_u8, src, -right_operand, -value, size);
|
||||
}
|
||||
|
||||
// go with the default case
|
||||
}
|
||||
|
||||
// a default case
|
||||
{
|
||||
// the restriction of the API
|
||||
assert((size % 8) == 0);
|
||||
|
||||
//
|
||||
const __m512d right_v = _mm512_set1_pd(right_operand);
|
||||
const __m512d value_v = _mm512_set1_pd(value);
|
||||
|
||||
// todo: aligned reads & writes
|
||||
|
||||
// interleaved pages
|
||||
constexpr size_t BLOCK_COUNT = PAGE_SIZE / (sizeof(int64_t));
|
||||
const size_t size_8p =
|
||||
(size / (N_BLOCKS * BLOCK_COUNT)) * N_BLOCKS * BLOCK_COUNT;
|
||||
for (size_t i = 0; i < size_8p; i += N_BLOCKS * BLOCK_COUNT) {
|
||||
for (size_t p = 0; p < BLOCK_COUNT; p += 8) {
|
||||
for (size_t ip = 0; ip < N_BLOCKS; ip++) {
|
||||
const __m512d v0s =
|
||||
_mm512_loadu_pd(src + i + p + ip * BLOCK_COUNT);
|
||||
const __mmask8 cmp_mask =
|
||||
ArithHelperF64<AOp, CmpOp>::op(
|
||||
v0s, right_v, value_v);
|
||||
|
||||
res_u8[(i + p + ip * BLOCK_COUNT) / 8] = cmp_mask;
|
||||
|
||||
_mm_prefetch(
|
||||
(const char*)(src + i + p + ip * BLOCK_COUNT) +
|
||||
BLOCKS_PREFETCH_AHEAD * CACHELINE_WIDTH,
|
||||
_MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// process big blocks
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = size_8p; i < size8; i += 8) {
|
||||
const __m512d v0s = _mm512_loadu_pd(src + i);
|
||||
const __mmask8 cmp_mask =
|
||||
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
|
||||
res_u8[i / 8] = cmp_mask;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// process big blocks
|
||||
const size_t size8 = (size / 8) * 8;
|
||||
for (size_t i = size_8p; i < size8; i += 8) {
|
||||
const __m512d v0s = _mm512_loadu_pd(src + i);
|
||||
const __mmask8 cmp_mask =
|
||||
ArithHelperF64<AOp, CmpOp>::op(v0s, right_v, value_v);
|
||||
|
||||
res_u8[i / 8] = cmp_mask;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -64,7 +64,7 @@ struct ComparePredicate<T, CompareOpType::GE> {
|
||||
template <typename T>
|
||||
struct ComparePredicate<T, CompareOpType::NE> {
|
||||
static inline constexpr int value =
|
||||
std::is_floating_point_v<T> ? _CMP_NEQ_OQ : _MM_CMPINT_NE;
|
||||
std::is_floating_point_v<T> ? _CMP_NEQ_UQ : _MM_CMPINT_NE;
|
||||
};
|
||||
|
||||
} // namespace x86
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <string>
|
||||
@ -308,6 +309,18 @@ FillRandom(std::vector<T>& t,
|
||||
}
|
||||
}
|
||||
|
||||
// Fills `t` with pseudo-random integers drawn uniformly from the closed
// interval [min_v, max_v]; every drawn value is converted to T with
// static_cast.
//
// t      - destination vector; its current size decides how many values
//          are generated, the size itself is not changed.
// rng    - random engine used as the source of randomness.
// min_v  - inclusive lower bound, may be negative.
// max_v  - inclusive upper bound; min_v <= max_v is required by
//          std::uniform_int_distribution.
template <typename T>
void
FillRandomRange(std::vector<T>& t,
                std::default_random_engine& rng,
                const int32_t min_v,
                const int32_t max_v) {
    // The lower bound must be min_v: a hard-coded 0 ignores the parameter
    // and never produces negative values.
    std::uniform_int_distribution<int32_t> dist(min_v, max_v);
    for (auto&& elem : t) {
        elem = static_cast<T>(dist(rng));
    }
}
|
||||
|
||||
template <>
|
||||
void
|
||||
FillRandom<std::string>(std::vector<std::string>& t,
|
||||
@ -1296,18 +1309,24 @@ INSTANTIATE_TYPED_TEST_SUITE_P(InplaceWithinRangeValTest,
|
||||
template <typename BitsetT, typename T>
|
||||
struct TestInplaceArithCompareImplS {
|
||||
static void
|
||||
process(BitsetT& bitset, ArithOpType a_op, CompareOpType cmp_op) {
|
||||
process(BitsetT& bitset,
|
||||
ArithOpType a_op,
|
||||
CompareOpType cmp_op,
|
||||
const int32_t right_operand_in,
|
||||
const int32_t value_in) {
|
||||
using HT = ArithHighPrecisionType<T>;
|
||||
|
||||
const size_t n = bitset.size();
|
||||
constexpr size_t max_v = 10;
|
||||
constexpr int32_t max_v = 10;
|
||||
|
||||
std::vector<T> left(n, 0);
|
||||
const HT right_operand = from_i32<HT>(2);
|
||||
const HT value = from_i32<HT>(5);
|
||||
const HT right_operand = from_i32<HT>(right_operand_in);
|
||||
const HT value = from_i32<HT>(value_in);
|
||||
|
||||
std::default_random_engine rng(123);
|
||||
FillRandom(left, rng, max_v);
|
||||
// Generating values in (-x, x) range.
|
||||
// This is fine, because we're operating with signed integers.
|
||||
FillRandomRange(left, rng, -max_v, max_v);
|
||||
|
||||
StopWatch sw;
|
||||
bitset.inplace_arith_compare(
|
||||
@ -1321,110 +1340,140 @@ struct TestInplaceArithCompareImplS {
|
||||
if (a_op == ArithOpType::Add) {
|
||||
if (cmp_op == CompareOpType::EQ) {
|
||||
ASSERT_EQ((left[i] + right_operand) == value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GE) {
|
||||
ASSERT_EQ((left[i] + right_operand) >= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GT) {
|
||||
ASSERT_EQ((left[i] + right_operand) > value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LE) {
|
||||
ASSERT_EQ((left[i] + right_operand) <= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LT) {
|
||||
ASSERT_EQ((left[i] + right_operand) < value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::NE) {
|
||||
ASSERT_EQ((left[i] + right_operand) != value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else {
|
||||
ASSERT_TRUE(false) << "Not implemented";
|
||||
}
|
||||
} else if (a_op == ArithOpType::Sub) {
|
||||
if (cmp_op == CompareOpType::EQ) {
|
||||
ASSERT_EQ((left[i] - right_operand) == value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GE) {
|
||||
ASSERT_EQ((left[i] - right_operand) >= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GT) {
|
||||
ASSERT_EQ((left[i] - right_operand) > value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LE) {
|
||||
ASSERT_EQ((left[i] - right_operand) <= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LT) {
|
||||
ASSERT_EQ((left[i] - right_operand) < value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::NE) {
|
||||
ASSERT_EQ((left[i] - right_operand) != value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else {
|
||||
ASSERT_TRUE(false) << "Not implemented";
|
||||
}
|
||||
} else if (a_op == ArithOpType::Mul) {
|
||||
if (cmp_op == CompareOpType::EQ) {
|
||||
ASSERT_EQ((left[i] * right_operand) == value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GE) {
|
||||
ASSERT_EQ((left[i] * right_operand) >= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GT) {
|
||||
ASSERT_EQ((left[i] * right_operand) > value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LE) {
|
||||
ASSERT_EQ((left[i] * right_operand) <= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LT) {
|
||||
ASSERT_EQ((left[i] * right_operand) < value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::NE) {
|
||||
ASSERT_EQ((left[i] * right_operand) != value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else {
|
||||
ASSERT_TRUE(false) << "Not implemented";
|
||||
}
|
||||
} else if (a_op == ArithOpType::Div) {
|
||||
if (cmp_op == CompareOpType::EQ) {
|
||||
ASSERT_EQ((left[i] / right_operand) == value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GE) {
|
||||
ASSERT_EQ((left[i] / right_operand) >= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GT) {
|
||||
ASSERT_EQ((left[i] / right_operand) > value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LE) {
|
||||
ASSERT_EQ((left[i] / right_operand) <= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LT) {
|
||||
ASSERT_EQ((left[i] / right_operand) < value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::NE) {
|
||||
ASSERT_EQ((left[i] / right_operand) != value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else {
|
||||
ASSERT_TRUE(false) << "Not implemented";
|
||||
}
|
||||
} else if (a_op == ArithOpType::Mod) {
|
||||
if (cmp_op == CompareOpType::EQ) {
|
||||
ASSERT_EQ(fmod(left[i], right_operand) == value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GE) {
|
||||
ASSERT_EQ(fmod(left[i], right_operand) >= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GT) {
|
||||
ASSERT_EQ(fmod(left[i], right_operand) > value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LE) {
|
||||
ASSERT_EQ(fmod(left[i], right_operand) <= value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LT) {
|
||||
ASSERT_EQ(fmod(left[i], right_operand) < value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::NE) {
|
||||
ASSERT_EQ(fmod(left[i], right_operand) != value, bitset[i])
|
||||
<< i;
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else {
|
||||
ASSERT_TRUE(false) << "Not implemented";
|
||||
}
|
||||
@ -1433,12 +1482,63 @@ struct TestInplaceArithCompareImplS {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
process_div_special(BitsetT& bitset,
|
||||
CompareOpType cmp_op,
|
||||
const T left_v,
|
||||
const T right_v,
|
||||
const T value_v) {
|
||||
// test a single special point for the division
|
||||
|
||||
using HT = ArithHighPrecisionType<T>;
|
||||
|
||||
const size_t n = bitset.size();
|
||||
|
||||
std::vector<T> left(n, left_v);
|
||||
const HT right_operand = right_v;
|
||||
const HT value = value_v;
|
||||
|
||||
bitset.inplace_arith_compare(
|
||||
left.data(), right_operand, value, n, ArithOpType::Div, cmp_op);
|
||||
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
if (cmp_op == CompareOpType::EQ) {
|
||||
ASSERT_EQ((left[i] / right_operand) == value, bitset[i])
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GE) {
|
||||
ASSERT_EQ((left[i] / right_operand) >= value, bitset[i])
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::GT) {
|
||||
ASSERT_EQ((left[i] / right_operand) > value, bitset[i])
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LE) {
|
||||
ASSERT_EQ((left[i] / right_operand) <= value, bitset[i])
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::LT) {
|
||||
ASSERT_EQ((left[i] / right_operand) < value, bitset[i])
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else if (cmp_op == CompareOpType::NE) {
|
||||
ASSERT_EQ((left[i] / right_operand) != value, bitset[i])
|
||||
<< i << " " << size_t(cmp_op) << " " << left[i] << " "
|
||||
<< right_operand << " " << value;
|
||||
} else {
|
||||
ASSERT_TRUE(false) << "Not implemented";
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BitsetT>
|
||||
struct TestInplaceArithCompareImplS<BitsetT, std::string> {
|
||||
static void
|
||||
process(BitsetT&, ArithOpType, CompareOpType) {
|
||||
process(
|
||||
BitsetT&, ArithOpType, CompareOpType, const int32_t, const int32_t) {
|
||||
// does nothing
|
||||
}
|
||||
};
|
||||
@ -1446,40 +1546,102 @@ struct TestInplaceArithCompareImplS<BitsetT, std::string> {
|
||||
template <typename BitsetT, typename T>
|
||||
void
|
||||
TestInplaceArithCompareImpl() {
|
||||
for (const size_t n : typical_sizes) {
|
||||
for (const auto a_op : typical_arith_ops) {
|
||||
for (const auto cmp_op : typical_compare_ops) {
|
||||
BitsetT bitset(n);
|
||||
bitset.reset();
|
||||
if constexpr (std::is_floating_point_v<T>)
|
||||
for (const size_t n : typical_sizes) {
|
||||
for (const auto a_op : typical_arith_ops) {
|
||||
for (const auto cmp_op : typical_compare_ops) {
|
||||
// test both positive, zero and negative
|
||||
for (const int32_t right_operand : {2, 0, -2}) {
|
||||
if ((!std::is_floating_point_v<T> ||
|
||||
a_op == milvus::bitset::ArithOpType::Mod) &&
|
||||
right_operand == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (print_log) {
|
||||
printf(
|
||||
"Testing bitset, n=%zd, a_op=%zd\n", n, (size_t)a_op);
|
||||
// test both positive, zero and negative
|
||||
for (const int32_t value : {2, 0, -2}) {
|
||||
BitsetT bitset(n);
|
||||
bitset.reset();
|
||||
|
||||
if (print_log) {
|
||||
printf(
|
||||
"Testing bitset, n=%zd, a_op=%zd, "
|
||||
"cmp_op=%zd, right_operand=%d\n",
|
||||
n,
|
||||
(size_t)a_op,
|
||||
(size_t)cmp_op,
|
||||
right_operand);
|
||||
}
|
||||
|
||||
TestInplaceArithCompareImplS<BitsetT, T>::process(
|
||||
bitset, a_op, cmp_op, right_operand, value);
|
||||
|
||||
for (const size_t offset : typical_offsets) {
|
||||
if (offset >= n) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bitset.reset();
|
||||
auto view = bitset.view(offset);
|
||||
|
||||
if (print_log) {
|
||||
printf(
|
||||
"Testing bitset view, n=%zd, "
|
||||
"offset=%zd, a_op=%zd, cmp_op=%zd, "
|
||||
"right_operand=%d\n",
|
||||
n,
|
||||
offset,
|
||||
(size_t)a_op,
|
||||
(size_t)cmp_op,
|
||||
right_operand);
|
||||
}
|
||||
|
||||
TestInplaceArithCompareImplS<
|
||||
decltype(view),
|
||||
T>::process(view,
|
||||
a_op,
|
||||
cmp_op,
|
||||
right_operand,
|
||||
value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TestInplaceArithCompareImplS<BitsetT, T>::process(
|
||||
bitset, a_op, cmp_op);
|
||||
if constexpr (std::is_floating_point_v<T>) {
|
||||
// test various special use cases for IEEE-754 for the division operation.
|
||||
std::vector<T> variety = {0,
|
||||
1,
|
||||
-1,
|
||||
std::numeric_limits<T>::quiet_NaN(),
|
||||
-std::numeric_limits<T>::quiet_NaN(),
|
||||
std::numeric_limits<T>::infinity(),
|
||||
-std::numeric_limits<T>::infinity()};
|
||||
|
||||
for (const size_t offset : typical_offsets) {
|
||||
if (offset >= n) {
|
||||
continue;
|
||||
for (const auto cmp_op : typical_compare_ops) {
|
||||
for (const T left_v : variety) {
|
||||
for (const T right_v : variety) {
|
||||
for (const T value_v : variety) {
|
||||
// 40 should be sufficient to test avx512
|
||||
BitsetT bitset(40);
|
||||
bitset.reset();
|
||||
|
||||
if (print_log) {
|
||||
printf(
|
||||
"Testing bitset div special case, cmp_op=%zd, "
|
||||
"left_v=%f, right_v=%f, value_v=%f\n",
|
||||
(size_t)cmp_op,
|
||||
left_v,
|
||||
right_v,
|
||||
value_v);
|
||||
}
|
||||
|
||||
TestInplaceArithCompareImplS<BitsetT, T>::
|
||||
process_div_special(
|
||||
bitset, cmp_op, left_v, right_v, value_v);
|
||||
}
|
||||
|
||||
bitset.reset();
|
||||
auto view = bitset.view(offset);
|
||||
|
||||
if (print_log) {
|
||||
printf(
|
||||
"Testing bitset view, n=%zd, offset=%zd, a_op=%zd, "
|
||||
"cmp_op=%zd\n",
|
||||
n,
|
||||
offset,
|
||||
(size_t)a_op,
|
||||
(size_t)cmp_op);
|
||||
}
|
||||
|
||||
TestInplaceArithCompareImplS<decltype(view), T>::process(
|
||||
view, a_op, cmp_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user