mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
enhance: add compare simd function (#29432)
#26137 Signed-off-by: luzhang <luzhang@zilliz.com> Co-authored-by: luzhang <luzhang@zilliz.com>
This commit is contained in:
parent
a3bae80b59
commit
d07197ab1a
@ -24,6 +24,7 @@
|
||||
#include "common/Vector.h"
|
||||
#include "exec/expression/Expr.h"
|
||||
#include "segcore/SegmentInterface.h"
|
||||
#include "simd/interface.h"
|
||||
|
||||
namespace milvus {
|
||||
namespace exec {
|
||||
@ -41,7 +42,7 @@ using ChunkDataAccessor = std::function<const number(int)>;
|
||||
template <typename T, typename U, proto::plan::OpType op>
|
||||
struct CompareElementFunc {
|
||||
void
|
||||
operator()(const T* left, const U* right, size_t size, bool* res) {
|
||||
operator_base(const T* left, const U* right, size_t size, bool* res) {
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if constexpr (op == proto::plan::OpType::Equal) {
|
||||
res[i] = left[i] == right[i];
|
||||
@ -63,6 +64,24 @@ struct CompareElementFunc {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
operator()(const T* left, const U* right, size_t size, bool* res) {
|
||||
#if defined(USE_DYNAMIC_SIMD)
|
||||
if constexpr (std::is_same_v<T, U>) {
|
||||
milvus::simd::compare_col_func<T>(
|
||||
static_cast<milvus::simd::CompareType>(op),
|
||||
left,
|
||||
right,
|
||||
size,
|
||||
res);
|
||||
} else {
|
||||
operator_base(left, right, size, res);
|
||||
}
|
||||
#else
|
||||
operator_base(left, right, size, res);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
class PhyCompareFilterExpr : public Expr {
|
||||
|
||||
@ -2632,30 +2632,9 @@ ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType {
|
||||
return index->In(n, terms.data());
|
||||
};
|
||||
|
||||
#if defined(USE_DYNAMIC_SIMD)
|
||||
std::function<bool(MayConstRef<T> x)> elem_func;
|
||||
if (n <= milvus::simd::TERM_EXPR_IN_SIZE_THREAD) {
|
||||
elem_func = [&terms, &term_set, n](MayConstRef<T> x) {
|
||||
if constexpr (std::is_integral<T>::value ||
|
||||
std::is_floating_point<T>::value) {
|
||||
return milvus::simd::find_term_func<T>(terms.data(), n, x);
|
||||
} else {
|
||||
// For string type, simd performance not better than set mode
|
||||
static_assert(std::is_same<T, std::string>::value ||
|
||||
std::is_same<T, std::string_view>::value);
|
||||
return term_set.find(x) != term_set.end();
|
||||
}
|
||||
};
|
||||
} else {
|
||||
elem_func = [&term_set, n](MayConstRef<T> x) {
|
||||
return term_set.find(x) != term_set.end();
|
||||
};
|
||||
}
|
||||
#else
|
||||
auto elem_func = [&term_set](MayConstRef<T> x) {
|
||||
return term_set.find(x) != term_set.end();
|
||||
};
|
||||
#endif
|
||||
|
||||
auto default_skip_index_func = [&](const SkipIndex& skipIndex,
|
||||
FieldId fieldId,
|
||||
|
||||
@ -25,7 +25,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64")
|
||||
)
|
||||
set_source_files_properties(sse4.cpp PROPERTIES COMPILE_FLAGS "-msse4.2")
|
||||
set_source_files_properties(avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2")
|
||||
set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx512bw")
|
||||
set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512vl -mavx512dq -mavx512bw")
|
||||
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm*")
|
||||
# TODO: add arm cpu simd
|
||||
message ("simd using arm mode")
|
||||
@ -37,4 +38,4 @@ endif()
|
||||
add_library(milvus_simd ${MILVUS_SIMD_SRCS})
|
||||
|
||||
# Link the milvus_simd library with other libraries as needed
|
||||
target_link_libraries(milvus_simd milvus_log)
|
||||
target_link_libraries(milvus_simd milvus_log)
|
||||
|
||||
@ -29,11 +29,12 @@ GetBitsetBlockAVX2(const bool* src) {
|
||||
// BitsetBlockType has 64 bits
|
||||
__m256i highbit = _mm256_set1_epi8(0x7F);
|
||||
uint32_t tmp[8];
|
||||
for (size_t i = 0; i < 2; i += 1) {
|
||||
__m256i boolvec = _mm256_loadu_si256((__m256i*)&src[i * 32]);
|
||||
__m256i highbits = _mm256_add_epi8(boolvec, highbit);
|
||||
tmp[i] = _mm256_movemask_epi8(highbits);
|
||||
}
|
||||
__m256i boolvec = _mm256_loadu_si256((__m256i*)(src));
|
||||
__m256i highbits = _mm256_add_epi8(boolvec, highbit);
|
||||
tmp[0] = _mm256_movemask_epi8(highbits);
|
||||
boolvec = _mm256_loadu_si256((__m256i*)(src + 32));
|
||||
highbits = _mm256_add_epi8(boolvec, highbit);
|
||||
tmp[1] = _mm256_movemask_epi8(highbits);
|
||||
|
||||
__m256i tmpvec = _mm256_loadu_si256((__m256i*)tmp);
|
||||
BitsetBlockType res[4];
|
||||
@ -65,9 +66,9 @@ FindTermAVX2(const bool* src, size_t vec_size, bool val) {
|
||||
__m256i ymm_data;
|
||||
size_t num_chunks = vec_size / 32;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 32 * num_chunks; i += 32) {
|
||||
ymm_data =
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 32 * i));
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
|
||||
__m256i ymm_match = _mm256_cmpeq_epi8(ymm_data, ymm_target);
|
||||
int mask = _mm256_movemask_epi8(ymm_match);
|
||||
if (mask != 0) {
|
||||
@ -90,9 +91,9 @@ FindTermAVX2(const int8_t* src, size_t vec_size, int8_t val) {
|
||||
__m256i ymm_data;
|
||||
size_t num_chunks = vec_size / 32;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 32 * num_chunks; i += 32) {
|
||||
ymm_data =
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 32 * i));
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
|
||||
__m256i ymm_match = _mm256_cmpeq_epi8(ymm_data, ymm_target);
|
||||
int mask = _mm256_movemask_epi8(ymm_match);
|
||||
if (mask != 0) {
|
||||
@ -114,10 +115,9 @@ FindTermAVX2(const int16_t* src, size_t vec_size, int16_t val) {
|
||||
__m256i ymm_target = _mm256_set1_epi16(val);
|
||||
__m256i ymm_data;
|
||||
size_t num_chunks = vec_size / 16;
|
||||
size_t remaining_size = vec_size % 16;
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 16 * num_chunks; i += 16) {
|
||||
ymm_data =
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 16 * i));
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
|
||||
__m256i ymm_match = _mm256_cmpeq_epi16(ymm_data, ymm_target);
|
||||
int mask = _mm256_movemask_epi8(ymm_match);
|
||||
if (mask != 0) {
|
||||
@ -141,9 +141,9 @@ FindTermAVX2(const int32_t* src, size_t vec_size, int32_t val) {
|
||||
size_t num_chunks = vec_size / 8;
|
||||
size_t remaining_size = vec_size % 8;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 8 * num_chunks; i += 8) {
|
||||
ymm_data =
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 8 * i));
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
|
||||
__m256i ymm_match = _mm256_cmpeq_epi32(ymm_data, ymm_target);
|
||||
int mask = _mm256_movemask_epi8(ymm_match);
|
||||
if (mask != 0) {
|
||||
@ -163,11 +163,10 @@ FindTermAVX2(const int64_t* src, size_t vec_size, int64_t val) {
|
||||
__m256i ymm_target = _mm256_set1_epi64x(val);
|
||||
__m256i ymm_data;
|
||||
size_t num_chunks = vec_size / 4;
|
||||
size_t remaining_size = vec_size % 4;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 4 * num_chunks; i += 4) {
|
||||
ymm_data =
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + 4 * i));
|
||||
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
|
||||
__m256i ymm_match = _mm256_cmpeq_epi64(ymm_data, ymm_target);
|
||||
int mask = _mm256_movemask_epi8(ymm_match);
|
||||
if (mask != 0) {
|
||||
@ -190,8 +189,8 @@ FindTermAVX2(const float* src, size_t vec_size, float val) {
|
||||
__m256 ymm_data;
|
||||
size_t num_chunks = vec_size / 8;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
ymm_data = _mm256_loadu_ps(src + 8 * i);
|
||||
for (size_t i = 0; i < 8 * num_chunks; i += 8) {
|
||||
ymm_data = _mm256_loadu_ps(src + i);
|
||||
__m256 ymm_match = _mm256_cmp_ps(ymm_data, ymm_target, _CMP_EQ_OQ);
|
||||
int mask = _mm256_movemask_ps(ymm_match);
|
||||
if (mask != 0) {
|
||||
@ -214,8 +213,8 @@ FindTermAVX2(const double* src, size_t vec_size, double val) {
|
||||
__m256d ymm_data;
|
||||
size_t num_chunks = vec_size / 4;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
ymm_data = _mm256_loadu_pd(src + 8 * i);
|
||||
for (size_t i = 0; i < 4 * num_chunks; i += 4) {
|
||||
ymm_data = _mm256_loadu_pd(src + i);
|
||||
__m256d ymm_match = _mm256_cmp_pd(ymm_data, ymm_target, _CMP_EQ_OQ);
|
||||
int mask = _mm256_movemask_pd(ymm_match);
|
||||
if (mask != 0) {
|
||||
|
||||
@ -25,9 +25,9 @@ FindTermAVX512(const bool* src, size_t vec_size, bool val) {
|
||||
__m512i zmm_data;
|
||||
size_t num_chunks = vec_size / 64;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 64 * num_chunks; i += 64) {
|
||||
zmm_data =
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 64 * i));
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i));
|
||||
__mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target);
|
||||
if (mask != 0) {
|
||||
return true;
|
||||
@ -49,9 +49,9 @@ FindTermAVX512(const int8_t* src, size_t vec_size, int8_t val) {
|
||||
__m512i zmm_data;
|
||||
size_t num_chunks = vec_size / 64;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 64 * num_chunks; i += 64) {
|
||||
zmm_data =
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 64 * i));
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i));
|
||||
__mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target);
|
||||
if (mask != 0) {
|
||||
return true;
|
||||
@ -73,9 +73,9 @@ FindTermAVX512(const int16_t* src, size_t vec_size, int16_t val) {
|
||||
__m512i zmm_data;
|
||||
size_t num_chunks = vec_size / 32;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 32 * num_chunks; i += 32) {
|
||||
zmm_data =
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 32 * i));
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i));
|
||||
__mmask32 mask = _mm512_cmpeq_epi16_mask(zmm_data, zmm_target);
|
||||
if (mask != 0) {
|
||||
return true;
|
||||
@ -97,9 +97,9 @@ FindTermAVX512(const int32_t* src, size_t vec_size, int32_t val) {
|
||||
__m512i zmm_data;
|
||||
size_t num_chunks = vec_size / 16;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 16 * num_chunks; i += 16) {
|
||||
zmm_data =
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 16 * i));
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i));
|
||||
__mmask16 mask = _mm512_cmpeq_epi32_mask(zmm_data, zmm_target);
|
||||
if (mask != 0) {
|
||||
return true;
|
||||
@ -121,9 +121,9 @@ FindTermAVX512(const int64_t* src, size_t vec_size, int64_t val) {
|
||||
__m512i zmm_data;
|
||||
size_t num_chunks = vec_size / 8;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
for (size_t i = 0; i < 8 * num_chunks; i += 8) {
|
||||
zmm_data =
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + 8 * i));
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i));
|
||||
__mmask8 mask = _mm512_cmpeq_epi64_mask(zmm_data, zmm_target);
|
||||
if (mask != 0) {
|
||||
return true;
|
||||
@ -145,8 +145,8 @@ FindTermAVX512(const float* src, size_t vec_size, float val) {
|
||||
__m512 zmm_data;
|
||||
size_t num_chunks = vec_size / 16;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
zmm_data = _mm512_loadu_ps(src + 16 * i);
|
||||
for (size_t i = 0; i < 16 * num_chunks; i += 16) {
|
||||
zmm_data = _mm512_loadu_ps(src + i);
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(zmm_data, zmm_target, _CMP_EQ_OQ);
|
||||
if (mask != 0) {
|
||||
return true;
|
||||
@ -168,8 +168,8 @@ FindTermAVX512(const double* src, size_t vec_size, double val) {
|
||||
__m512d zmm_data;
|
||||
size_t num_chunks = vec_size / 8;
|
||||
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
zmm_data = _mm512_loadu_pd(src + 8 * i);
|
||||
for (size_t i = 0; i < 8 * num_chunks; i += 8) {
|
||||
zmm_data = _mm512_loadu_pd(src + i);
|
||||
__mmask8 mask = _mm512_cmp_pd_mask(zmm_data, zmm_target, _CMP_EQ_OQ);
|
||||
if (mask != 0) {
|
||||
return true;
|
||||
@ -216,6 +216,703 @@ OrBoolAVX512(bool* left, bool* right, int64_t size) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, CompareType type>
|
||||
struct CompareOperator;
|
||||
|
||||
template <typename T>
|
||||
struct CompareOperator<T, CompareType::EQ> {
|
||||
static constexpr int ComparePredicate =
|
||||
std::is_floating_point_v<T> ? _CMP_EQ_OQ : _MM_CMPINT_EQ;
|
||||
static constexpr bool
|
||||
Op(T a, T b) {
|
||||
return a == b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CompareOperator<T, CompareType::NEQ> {
|
||||
static constexpr int ComparePredicate =
|
||||
std::is_floating_point_v<T> ? _CMP_NEQ_OQ : _MM_CMPINT_NE;
|
||||
static constexpr bool
|
||||
Op(T a, T b) {
|
||||
return a != b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CompareOperator<T, CompareType::LT> {
|
||||
static constexpr int ComparePredicate =
|
||||
std::is_floating_point_v<T> ? _CMP_LT_OQ : _MM_CMPINT_LT;
|
||||
static constexpr bool
|
||||
Op(T a, T b) {
|
||||
return a < b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CompareOperator<T, CompareType::LE> {
|
||||
static constexpr int ComparePredicate =
|
||||
std::is_floating_point_v<T> ? _CMP_LE_OQ : _MM_CMPINT_LE;
|
||||
static constexpr bool
|
||||
Op(T a, T b) {
|
||||
return a <= b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CompareOperator<T, CompareType::GT> {
|
||||
static constexpr int ComparePredicate =
|
||||
std::is_floating_point_v<T> ? _CMP_GT_OQ : _MM_CMPINT_NLE;
|
||||
static constexpr bool
|
||||
Op(T a, T b) {
|
||||
return a > b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct CompareOperator<T, CompareType::GE> {
|
||||
static constexpr int ComparePredicate =
|
||||
std::is_floating_point_v<T> ? _CMP_GE_OQ : _MM_CMPINT_NLT;
|
||||
static constexpr bool
|
||||
Op(T a, T b) {
|
||||
return a >= b;
|
||||
}
|
||||
};
|
||||
|
||||
// Generic column-vs-constant comparison: res[i] = (src[i] <type> val).
//
// This primary template is selected for every element type that has no
// dedicated AVX-512 specialization. It previously had an empty body, so such
// types compiled but silently left `res` unwritten; it now computes the
// result with the scalar operation of CompareOperator.
template <typename T, CompareType type>
struct CompareValAVX512Impl {
    static void
    Compare(const T* src, size_t size, T val, bool* res) {
        static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
                      "T must be integral or float/double type");
        for (size_t i = 0; i < size; ++i) {
            res[i] = CompareOperator<T, type>::Op(src[i], val);
        }
    }
};
|
||||
|
||||
template <CompareType type>
|
||||
struct CompareValAVX512Impl<int8_t, type> {
|
||||
static void
|
||||
Compare(const int8_t* src, size_t size, int8_t val, bool* res) {
|
||||
__m512i target = _mm512_set1_epi8(val);
|
||||
|
||||
int middle = size / 64 * 64;
|
||||
|
||||
for (size_t i = 0; i < middle; i += 64) {
|
||||
__m512i data =
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i));
|
||||
|
||||
__mmask64 cmp_res_mask = _mm512_cmp_epi8_mask(
|
||||
data,
|
||||
target,
|
||||
(CompareOperator<int8_t, type>::ComparePredicate));
|
||||
__m512i cmp_res = _mm512_maskz_set1_epi8(cmp_res_mask, 0x01);
|
||||
_mm512_storeu_si512(res + i, cmp_res);
|
||||
}
|
||||
|
||||
for (size_t i = middle; i < size; ++i) {
|
||||
res[i] = CompareOperator<int8_t, type>::Op(src[i], val);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <CompareType type>
|
||||
struct CompareValAVX512Impl<int16_t, type> {
|
||||
static void
|
||||
Compare(const int16_t* src, size_t size, int16_t val, bool* res) {
|
||||
__m512i target = _mm512_set1_epi16(val);
|
||||
|
||||
int middle = size / 32 * 32;
|
||||
|
||||
for (size_t i = 0; i < middle; i += 32) {
|
||||
__m512i data =
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i));
|
||||
|
||||
__mmask32 cmp_res_mask = _mm512_cmp_epi16_mask(
|
||||
data,
|
||||
target,
|
||||
(CompareOperator<int16_t, type>::ComparePredicate));
|
||||
__m256i cmp_res = _mm256_maskz_set1_epi8(cmp_res_mask, 0x01);
|
||||
_mm256_storeu_si256((__m256i*)(res + i), cmp_res);
|
||||
}
|
||||
|
||||
for (size_t i = middle; i < size; ++i) {
|
||||
res[i] = CompareOperator<int16_t, type>::Op(src[i], val);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <CompareType type>
|
||||
struct CompareValAVX512Impl<int32_t, type> {
|
||||
static void
|
||||
Compare(const int32_t* src, size_t size, int32_t val, bool* res) {
|
||||
__m512i target = _mm512_set1_epi32(val);
|
||||
|
||||
int middle = size / 16 * 16;
|
||||
|
||||
for (size_t i = 0; i < middle; i += 16) {
|
||||
__m512i data =
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i));
|
||||
|
||||
__mmask16 cmp_res_mask = _mm512_cmp_epi32_mask(
|
||||
data,
|
||||
target,
|
||||
(CompareOperator<int32_t, type>::ComparePredicate));
|
||||
__m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01);
|
||||
_mm_storeu_si128((__m128i*)(res + i), cmp_res);
|
||||
}
|
||||
|
||||
for (size_t i = middle; i < size; ++i) {
|
||||
res[i] = CompareOperator<int32_t, type>::Op(src[i], val);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <CompareType type>
|
||||
struct CompareValAVX512Impl<int64_t, type> {
|
||||
static void
|
||||
Compare(const int64_t* src, size_t size, int64_t val, bool* res) {
|
||||
__m512i target = _mm512_set1_epi64(val);
|
||||
int middle = size / 8 * 8;
|
||||
int index = 0;
|
||||
for (size_t i = 0; i < middle; i += 8) {
|
||||
__m512i data =
|
||||
_mm512_loadu_si512(reinterpret_cast<const __m512i*>(src + i));
|
||||
__mmask8 mask = _mm512_cmp_epi64_mask(
|
||||
data,
|
||||
target,
|
||||
(CompareOperator<int64_t, type>::ComparePredicate));
|
||||
__m128i cmp_res = _mm_maskz_set1_epi8(mask, 0x01);
|
||||
_mm_storeu_si64((__m128i*)(res + i), cmp_res);
|
||||
}
|
||||
|
||||
for (size_t i = middle; i < size; ++i) {
|
||||
res[i] = CompareOperator<int64_t, type>::Op(src[i], val);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <CompareType type>
|
||||
struct CompareValAVX512Impl<float, type> {
|
||||
static void
|
||||
Compare(const float* src, size_t size, float val, bool* res) {
|
||||
__m512 target = _mm512_set1_ps(val);
|
||||
|
||||
int middle = size / 16 * 16;
|
||||
|
||||
for (size_t i = 0; i < middle; i += 16) {
|
||||
__m512 data = _mm512_loadu_ps(src + i);
|
||||
|
||||
__mmask16 cmp_res_mask = _mm512_cmp_ps_mask(
|
||||
data, target, (CompareOperator<float, type>::ComparePredicate));
|
||||
__m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01);
|
||||
_mm_storeu_si128((__m128i*)(res + i), cmp_res);
|
||||
}
|
||||
|
||||
for (size_t i = middle; i < size; ++i) {
|
||||
res[i] = CompareOperator<float, type>::Op(src[i], val);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <CompareType type>
|
||||
struct CompareValAVX512Impl<double, type> {
|
||||
static void
|
||||
Compare(const double* src, size_t size, double val, bool* res) {
|
||||
__m512d target = _mm512_set1_pd(val);
|
||||
|
||||
int middle = size / 8 * 8;
|
||||
|
||||
for (size_t i = 0; i < middle; i += 8) {
|
||||
__m512d data = _mm512_loadu_pd(src + i);
|
||||
|
||||
__mmask8 cmp_res_mask = _mm512_cmp_pd_mask(
|
||||
data,
|
||||
target,
|
||||
(CompareOperator<double, type>::ComparePredicate));
|
||||
__m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01);
|
||||
_mm_storeu_si64((res + i), cmp_res);
|
||||
}
|
||||
|
||||
for (size_t i = middle; i < size; ++i) {
|
||||
res[i] = CompareOperator<double, type>::Op(src[i], val);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
EqualValAVX512(const T* src, size_t size, T val, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareValAVX512Impl<T, CompareType::EQ>::Compare(src, size, val, res);
|
||||
};
|
||||
template void
|
||||
EqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res);
|
||||
template void
|
||||
EqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res);
|
||||
template void
|
||||
EqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res);
|
||||
template void
|
||||
EqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res);
|
||||
template void
|
||||
EqualValAVX512(const float* src, size_t size, float val, bool* res);
|
||||
template void
|
||||
EqualValAVX512(const double* src, size_t size, double val, bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
LessValAVX512(const T* src, size_t size, T val, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareValAVX512Impl<T, CompareType::LT>::Compare(src, size, val, res);
|
||||
};
|
||||
template void
|
||||
LessValAVX512(const int8_t* src, size_t size, int8_t val, bool* res);
|
||||
template void
|
||||
LessValAVX512(const int16_t* src, size_t size, int16_t val, bool* res);
|
||||
template void
|
||||
LessValAVX512(const int32_t* src, size_t size, int32_t val, bool* res);
|
||||
template void
|
||||
LessValAVX512(const int64_t* src, size_t size, int64_t val, bool* res);
|
||||
template void
|
||||
LessValAVX512(const float* src, size_t size, float val, bool* res);
|
||||
template void
|
||||
LessValAVX512(const double* src, size_t size, double val, bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
GreaterValAVX512(const T* src, size_t size, T val, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareValAVX512Impl<T, CompareType::GT>::Compare(src, size, val, res);
|
||||
};
|
||||
template void
|
||||
GreaterValAVX512(const int8_t* src, size_t size, int8_t val, bool* res);
|
||||
template void
|
||||
GreaterValAVX512(const int16_t* src, size_t size, int16_t val, bool* res);
|
||||
template void
|
||||
GreaterValAVX512(const int32_t* src, size_t size, int32_t val, bool* res);
|
||||
template void
|
||||
GreaterValAVX512(const int64_t* src, size_t size, int64_t val, bool* res);
|
||||
template void
|
||||
GreaterValAVX512(const float* src, size_t size, float val, bool* res);
|
||||
template void
|
||||
GreaterValAVX512(const double* src, size_t size, double val, bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
NotEqualValAVX512(const T* src, size_t size, T val, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareValAVX512Impl<T, CompareType::NEQ>::Compare(src, size, val, res);
|
||||
};
|
||||
template void
|
||||
NotEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res);
|
||||
template void
|
||||
NotEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res);
|
||||
template void
|
||||
NotEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res);
|
||||
template void
|
||||
NotEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res);
|
||||
template void
|
||||
NotEqualValAVX512(const float* src, size_t size, float val, bool* res);
|
||||
template void
|
||||
NotEqualValAVX512(const double* src, size_t size, double val, bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
LessEqualValAVX512(const T* src, size_t size, T val, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareValAVX512Impl<T, CompareType::LE>::Compare(src, size, val, res);
|
||||
};
|
||||
template void
|
||||
LessEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res);
|
||||
template void
|
||||
LessEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res);
|
||||
template void
|
||||
LessEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res);
|
||||
template void
|
||||
LessEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res);
|
||||
template void
|
||||
LessEqualValAVX512(const float* src, size_t size, float val, bool* res);
|
||||
template void
|
||||
LessEqualValAVX512(const double* src, size_t size, double val, bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
GreaterEqualValAVX512(const T* src, size_t size, T val, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareValAVX512Impl<T, CompareType::GE>::Compare(src, size, val, res);
|
||||
};
|
||||
template void
|
||||
GreaterEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res);
|
||||
template void
|
||||
GreaterEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res);
|
||||
template void
|
||||
GreaterEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res);
|
||||
template void
|
||||
GreaterEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res);
|
||||
template void
|
||||
GreaterEqualValAVX512(const float* src, size_t size, float val, bool* res);
|
||||
template void
|
||||
GreaterEqualValAVX512(const double* src, size_t size, double val, bool* res);
|
||||
|
||||
// Placeholder: only validates the element type at compile time. The body
// performs no comparison and never writes to res.
template <typename T>
void
CompareColumnAVX512(const T* left, const T* right, size_t size, bool* res) {
    // is_arithmetic is exactly "integral or floating point".
    static_assert(std::is_arithmetic_v<T>,
                  "T must be integral or float/double type");
}
|
||||
|
||||
// Column-vs-column comparison for integral element types:
// res[i] = (left[i] <type> right[i]) for i in [0, size).
//
// int8_t / int16_t / int32_t / int64_t are processed one 512-bit register at
// a time with a scalar tail. Any other integral type has no vector kernel
// and is handled entirely by the scalar loop (previously the vector loop ran
// for such types without storing anything, leaving res partly unwritten).
template <typename T, CompareType type>
struct CompareColumnAVX512Impl {
    static void
    Compare(const T* left, const T* right, size_t size, bool* res) {
        static_assert(std::is_integral_v<T>, "T must be integral type");

        constexpr bool kHasKernel =
            std::is_same_v<T, int8_t> || std::is_same_v<T, int16_t> ||
            std::is_same_v<T, int32_t> || std::is_same_v<T, int64_t>;

        // Index of the first element handled by the scalar loop. Kept as
        // size_t: an int bound would truncate large sizes and mix signedness
        // with the size_t loop index.
        size_t middle = 0;

        if constexpr (kHasKernel) {
            // Number of T lanes in one 512-bit register.
            constexpr size_t batch_size = 512 / (sizeof(T) * 8);
            middle = size / batch_size * batch_size;

            for (size_t i = 0; i < middle; i += batch_size) {
                const __m512i left_reg = _mm512_loadu_si512(
                    reinterpret_cast<const __m512i*>(left + i));
                const __m512i right_reg = _mm512_loadu_si512(
                    reinterpret_cast<const __m512i*>(right + i));

                // Each branch turns the compare mask into one 0x01 / 0x00
                // byte per lane and stores exactly batch_size bytes.
                if constexpr (std::is_same_v<T, int8_t>) {
                    const __mmask64 cmp_res_mask = _mm512_cmp_epi8_mask(
                        left_reg,
                        right_reg,
                        (CompareOperator<T, type>::ComparePredicate));
                    const __m512i cmp_res =
                        _mm512_maskz_set1_epi8(cmp_res_mask, 0x01);
                    _mm512_storeu_si512(res + i, cmp_res);
                } else if constexpr (std::is_same_v<T, int16_t>) {
                    const __mmask32 cmp_res_mask = _mm512_cmp_epi16_mask(
                        left_reg,
                        right_reg,
                        (CompareOperator<T, type>::ComparePredicate));
                    const __m256i cmp_res =
                        _mm256_maskz_set1_epi8(cmp_res_mask, 0x01);
                    _mm256_storeu_si256(reinterpret_cast<__m256i*>(res + i),
                                        cmp_res);
                } else if constexpr (std::is_same_v<T, int32_t>) {
                    const __mmask16 cmp_res_mask = _mm512_cmp_epi32_mask(
                        left_reg,
                        right_reg,
                        (CompareOperator<T, type>::ComparePredicate));
                    const __m128i cmp_res =
                        _mm_maskz_set1_epi8(cmp_res_mask, 0x01);
                    _mm_storeu_si128(reinterpret_cast<__m128i*>(res + i),
                                     cmp_res);
                } else if constexpr (std::is_same_v<T, int64_t>) {
                    const __mmask8 mask = _mm512_cmp_epi64_mask(
                        left_reg,
                        right_reg,
                        (CompareOperator<T, type>::ComparePredicate));
                    const __m128i cmp_res = _mm_maskz_set1_epi8(mask, 0x01);
                    _mm_storeu_si64(res + i, cmp_res);
                }
            }
        }

        // Scalar tail (or the whole range when there is no vector kernel).
        for (size_t i = middle; i < size; ++i) {
            res[i] = CompareOperator<T, type>::Op(left[i], right[i]);
        }
    }
};
|
||||
|
||||
template <CompareType type>
|
||||
struct CompareColumnAVX512Impl<float, type> {
|
||||
static void
|
||||
Compare(const float* left, const float* right, size_t size, bool* res) {
|
||||
int batch_size = 512 / (sizeof(float) * 8);
|
||||
int middle = size / batch_size * batch_size;
|
||||
|
||||
for (size_t i = 0; i < middle; i += batch_size) {
|
||||
__m512 left_reg =
|
||||
_mm512_loadu_ps(reinterpret_cast<const __m512*>(left + i));
|
||||
__m512 right_reg =
|
||||
_mm512_loadu_ps(reinterpret_cast<const __m512*>(right + i));
|
||||
|
||||
__mmask16 cmp_res_mask = _mm512_cmp_ps_mask(
|
||||
left_reg,
|
||||
right_reg,
|
||||
(CompareOperator<float, type>::ComparePredicate));
|
||||
|
||||
__m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01);
|
||||
_mm_storeu_si128((__m128i*)(res + i), cmp_res);
|
||||
}
|
||||
|
||||
for (size_t i = middle; i < size; ++i) {
|
||||
res[i] = CompareOperator<float, type>::Op(left[i], right[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <CompareType type>
|
||||
struct CompareColumnAVX512Impl<double, type> {
|
||||
static void
|
||||
Compare(const double* left, const double* right, size_t size, bool* res) {
|
||||
int batch_size = 512 / (sizeof(double) * 8);
|
||||
int middle = size / batch_size * batch_size;
|
||||
|
||||
for (size_t i = 0; i < middle; i += batch_size) {
|
||||
__m512d left_reg =
|
||||
_mm512_loadu_pd(reinterpret_cast<const __m512d*>(left + i));
|
||||
__m512d right_reg =
|
||||
_mm512_loadu_pd(reinterpret_cast<const __m512d*>(right + i));
|
||||
|
||||
__mmask8 cmp_res_mask = _mm512_cmp_pd_mask(
|
||||
left_reg,
|
||||
right_reg,
|
||||
(CompareOperator<double, type>::ComparePredicate));
|
||||
|
||||
__m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01);
|
||||
_mm_storeu_si64((res + i), cmp_res);
|
||||
}
|
||||
|
||||
for (size_t i = middle; i < size; ++i) {
|
||||
res[i] = CompareOperator<double, type>::Op(left[i], right[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
EqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareColumnAVX512Impl<T, CompareType::EQ>::Compare(
|
||||
left, right, size, res);
|
||||
};
|
||||
|
||||
template void
|
||||
EqualColumnAVX512(const int8_t* left,
|
||||
const int8_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
EqualColumnAVX512(const int16_t* left,
|
||||
const int16_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
EqualColumnAVX512(const int32_t* left,
|
||||
const int32_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
EqualColumnAVX512(const int64_t* left,
|
||||
const int64_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
EqualColumnAVX512(const float* left,
|
||||
const float* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
EqualColumnAVX512(const double* left,
|
||||
const double* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
LessColumnAVX512(const T* left, const T* right, size_t size, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareColumnAVX512Impl<T, CompareType::LT>::Compare(
|
||||
left, right, size, res);
|
||||
};
|
||||
template void
|
||||
LessColumnAVX512(const int8_t* left,
|
||||
const int8_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
LessColumnAVX512(const int16_t* left,
|
||||
const int16_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
LessColumnAVX512(const int32_t* left,
|
||||
const int32_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
LessColumnAVX512(const int64_t* left,
|
||||
const int64_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
LessColumnAVX512(const float* left, const float* right, size_t size, bool* res);
|
||||
template void
|
||||
LessColumnAVX512(const double* left,
|
||||
const double* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
GreaterColumnAVX512(const T* left, const T* right, size_t size, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareColumnAVX512Impl<T, CompareType::GT>::Compare(
|
||||
left, right, size, res);
|
||||
};
|
||||
template void
|
||||
GreaterColumnAVX512(const int8_t* left,
|
||||
const int8_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterColumnAVX512(const int16_t* left,
|
||||
const int16_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterColumnAVX512(const int32_t* left,
|
||||
const int32_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterColumnAVX512(const int64_t* left,
|
||||
const int64_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterColumnAVX512(const float* left,
|
||||
const float* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterColumnAVX512(const double* left,
|
||||
const double* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
LessEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareColumnAVX512Impl<T, CompareType::LE>::Compare(
|
||||
left, right, size, res);
|
||||
};
|
||||
template void
|
||||
LessEqualColumnAVX512(const int8_t* left,
|
||||
const int8_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
LessEqualColumnAVX512(const int16_t* left,
|
||||
const int16_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
LessEqualColumnAVX512(const int32_t* left,
|
||||
const int32_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
LessEqualColumnAVX512(const int64_t* left,
|
||||
const int64_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
LessEqualColumnAVX512(const float* left,
|
||||
const float* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
LessEqualColumnAVX512(const double* left,
|
||||
const double* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
GreaterEqualColumnAVX512(const T* left,
|
||||
const T* right,
|
||||
size_t size,
|
||||
bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareColumnAVX512Impl<T, CompareType::GE>::Compare(
|
||||
left, right, size, res);
|
||||
};
|
||||
template void
|
||||
GreaterEqualColumnAVX512(const int8_t* left,
|
||||
const int8_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterEqualColumnAVX512(const int16_t* left,
|
||||
const int16_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterEqualColumnAVX512(const int32_t* left,
|
||||
const int32_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterEqualColumnAVX512(const int64_t* left,
|
||||
const int64_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterEqualColumnAVX512(const float* left,
|
||||
const float* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
GreaterEqualColumnAVX512(const double* left,
|
||||
const double* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
NotEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) {
|
||||
static_assert(std::is_integral_v<T> || std::is_floating_point_v<T>,
|
||||
"T must be integral or float/double type");
|
||||
CompareColumnAVX512Impl<T, CompareType::NEQ>::Compare(
|
||||
left, right, size, res);
|
||||
};
|
||||
|
||||
template void
|
||||
NotEqualColumnAVX512(const int8_t* left,
|
||||
const int8_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
NotEqualColumnAVX512(const int16_t* left,
|
||||
const int16_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
NotEqualColumnAVX512(const int32_t* left,
|
||||
const int32_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
NotEqualColumnAVX512(const int64_t* left,
|
||||
const int64_t* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
NotEqualColumnAVX512(const float* left,
|
||||
const float* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
template void
|
||||
NotEqualColumnAVX512(const double* left,
|
||||
const double* right,
|
||||
size_t size,
|
||||
bool* res);
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
#endif
|
||||
|
||||
@ -61,5 +61,53 @@ AndBoolAVX512(bool* left, bool* right, int64_t size);
|
||||
void
|
||||
OrBoolAVX512(bool* left, bool* right, int64_t size);
|
||||
|
||||
// AVX512 column-vs-scalar compare kernels (declarations only).
// NOTE(review): presumably each writes res[i] = (src[i] <op> val) for
// i in [0, size) — the definitions are not visible from here; confirm.

// <op> is ==
template <class T>
void EqualValAVX512(const T* src, size_t size, T val, bool* res);

// <op> is <
template <class T>
void LessValAVX512(const T* src, size_t size, T val, bool* res);

// <op> is >
template <class T>
void GreaterValAVX512(const T* src, size_t size, T val, bool* res);

// <op> is !=
template <class T>
void NotEqualValAVX512(const T* src, size_t size, T val, bool* res);

// <op> is <=
template <class T>
void LessEqualValAVX512(const T* src, size_t size, T val, bool* res);

// <op> is >=
template <class T>
void GreaterEqualValAVX512(const T* src, size_t size, T val, bool* res);
|
||||
|
||||
// AVX512 column-vs-column compare kernels (declarations only).
// NOTE(review): presumably each writes res[i] = (left[i] <op> right[i]) for
// i in [0, size) — confirm against the kernel implementation.

// <op> is ==
template <class T>
void EqualColumnAVX512(const T* left, const T* right, size_t size, bool* res);

// <op> is <
template <class T>
void LessColumnAVX512(const T* left, const T* right, size_t size, bool* res);

// <op> is <=
template <class T>
void LessEqualColumnAVX512(const T* left,
                           const T* right,
                           size_t size,
                           bool* res);

// <op> is >
template <class T>
void GreaterColumnAVX512(const T* left,
                         const T* right,
                         size_t size,
                         bool* res);

// <op> is >=
template <class T>
void GreaterEqualColumnAVX512(const T* left,
                              const T* right,
                              size_t size,
                              bool* res);

// <op> is !=
template <class T>
void NotEqualColumnAVX512(const T* left,
                          const T* right,
                          size_t size,
                          bool* res);
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
||||
@ -40,5 +40,14 @@ const int TERM_EXPR_IN_SIZE_THREAD = 50;
|
||||
std::is_same<T, float>::value || std::is_same<T, double>::value, \
|
||||
Message);
|
||||
|
||||
// Selects which comparison a SIMD compare kernel evaluates.
// NOTE(review): CompareElementFunc converts proto::plan::OpType to this enum
// with a plain static_cast, so these numeric values presumably have to mirror
// the OpType values — confirm against the proto before renumbering.
enum class CompareType {
|
||||
GT = 1,  // greater than
|
||||
GE = 2,  // greater than or equal
|
||||
LT = 3,  // less than
|
||||
LE = 4,  // less than or equal
|
||||
EQ = 5,  // equal
|
||||
NEQ = 6,  // not equal
|
||||
};
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
||||
@ -32,19 +32,6 @@
|
||||
namespace milvus {
|
||||
namespace simd {
|
||||
|
||||
#if defined(__x86_64__)
|
||||
bool use_avx512 = true;
|
||||
bool use_avx2 = true;
|
||||
bool use_sse4_2 = true;
|
||||
bool use_sse2 = true;
|
||||
|
||||
bool use_bitset_sse2;
|
||||
bool use_find_term_sse2;
|
||||
bool use_find_term_sse4_2;
|
||||
bool use_find_term_avx2;
|
||||
bool use_find_term_avx512;
|
||||
#endif
|
||||
|
||||
decltype(get_bitset_block) get_bitset_block = GetBitsetBlockRef;
|
||||
decltype(all_false) all_false = AllFalseRef;
|
||||
decltype(all_true) all_true = AllTrueRef;
|
||||
@ -52,20 +39,124 @@ decltype(invert_bool) invert_bool = InvertBoolRef;
|
||||
decltype(and_bool) and_bool = AndBoolRef;
|
||||
decltype(or_bool) or_bool = OrBoolRef;
|
||||
|
||||
// Defines the global dispatch pointer find_term_<type>, initialised to the
// reference implementation FindTermRef<type>. The pointer name carries the
// full type spelling (e.g. find_term_int8_t).
//
// The hand-written definitions that used to sit here are dropped: they
// defined find_term_bool / find_term_float / find_term_double a second time.
#define DECLARE_FIND_TERM_PTR(type) \
    FindTermPtr<type> find_term_##type = FindTermRef<type>;

DECLARE_FIND_TERM_PTR(bool)
DECLARE_FIND_TERM_PTR(int8_t)
DECLARE_FIND_TERM_PTR(int16_t)
DECLARE_FIND_TERM_PTR(int32_t)
DECLARE_FIND_TERM_PTR(int64_t)
DECLARE_FIND_TERM_PTR(float)
DECLARE_FIND_TERM_PTR(double)
|
||||
|
||||
#define DECLARE_COMPARE_VAL_PTR(prefix, RefFunc, type) \
|
||||
CompareValPtr<type> prefix##_##type = RefFunc<type>;
|
||||
|
||||
DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, bool)
|
||||
DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int8_t)
|
||||
DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int16_t)
|
||||
DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int32_t)
|
||||
DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int64_t)
|
||||
DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, float)
|
||||
DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, double)
|
||||
|
||||
DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, bool)
|
||||
DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int8_t)
|
||||
DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int16_t)
|
||||
DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int32_t)
|
||||
DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int64_t)
|
||||
DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, float)
|
||||
DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, double)
|
||||
|
||||
DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, bool)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int8_t)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int16_t)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int32_t)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int64_t)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, float)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, double)
|
||||
|
||||
DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, bool)
|
||||
DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int8_t)
|
||||
DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int16_t)
|
||||
DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int32_t)
|
||||
DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int64_t)
|
||||
DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, float)
|
||||
DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, double)
|
||||
|
||||
DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, bool)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int8_t)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int16_t)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int32_t)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int64_t)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, float)
|
||||
DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, double)
|
||||
|
||||
DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, bool)
|
||||
DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int8_t)
|
||||
DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int16_t)
|
||||
DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int32_t)
|
||||
DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int64_t)
|
||||
DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, float)
|
||||
DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, double)
|
||||
|
||||
#define DECLARE_COMPARE_COL_PTR(prefix, RefFunc, type) \
|
||||
CompareColPtr<type> prefix##_##type = RefFunc<type>;
|
||||
|
||||
DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, bool)
|
||||
DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int8_t)
|
||||
DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int16_t)
|
||||
DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int32_t)
|
||||
DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int64_t)
|
||||
DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, float)
|
||||
DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, double)
|
||||
|
||||
DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, bool)
|
||||
DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int8_t)
|
||||
DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int16_t)
|
||||
DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int32_t)
|
||||
DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int64_t)
|
||||
DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, float)
|
||||
DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, double)
|
||||
|
||||
DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, bool)
|
||||
DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int8_t)
|
||||
DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int16_t)
|
||||
DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int32_t)
|
||||
DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int64_t)
|
||||
DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, float)
|
||||
DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, double)
|
||||
|
||||
DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, bool)
|
||||
DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int8_t)
|
||||
DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int16_t)
|
||||
DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int32_t)
|
||||
DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int64_t)
|
||||
DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, float)
|
||||
DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, double)
|
||||
|
||||
DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, bool)
|
||||
DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int8_t)
|
||||
DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int16_t)
|
||||
DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int32_t)
|
||||
DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int64_t)
|
||||
DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, float)
|
||||
DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, double)
|
||||
|
||||
DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, bool)
|
||||
DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int8_t)
|
||||
DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int16_t)
|
||||
DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int32_t)
|
||||
DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int64_t)
|
||||
DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, float)
|
||||
DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, double)
|
||||
|
||||
#if defined(__x86_64__)
|
||||
bool
|
||||
cpu_support_avx512() {
|
||||
InstructionSet& instruction_set_inst = InstructionSet::GetInstance();
|
||||
return (instruction_set_inst.AVX512F() && instruction_set_inst.AVX512DQ() &&
|
||||
instruction_set_inst.AVX512BW());
|
||||
instruction_set_inst.AVX512BW() && instruction_set_inst.AVX512VL());
|
||||
}
|
||||
|
||||
bool
|
||||
@ -87,95 +178,77 @@ cpu_support_sse2() {
|
||||
}
|
||||
#endif
|
||||
|
||||
void
|
||||
static void
|
||||
bitset_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
if (use_avx512 && cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
// For now, sse2 has best performance
|
||||
get_bitset_block = GetBitsetBlockSSE2;
|
||||
use_bitset_sse2 = true;
|
||||
} else if (use_avx2 && cpu_support_avx2()) {
|
||||
simd_type = "AVX2";
|
||||
// For now, sse2 has best performance
|
||||
get_bitset_block = GetBitsetBlockSSE2;
|
||||
use_bitset_sse2 = true;
|
||||
} else if (use_sse4_2 && cpu_support_sse4_2()) {
|
||||
simd_type = "SSE4";
|
||||
get_bitset_block = GetBitsetBlockSSE2;
|
||||
use_bitset_sse2 = true;
|
||||
} else if (use_sse2 && cpu_support_sse2()) {
|
||||
// SSE2 have best performance in test.
|
||||
if (cpu_support_sse2()) {
|
||||
simd_type = "SSE2";
|
||||
get_bitset_block = GetBitsetBlockSSE2;
|
||||
use_bitset_sse2 = true;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("bitset hook simd type: {}", simd_type);
|
||||
}
|
||||
|
||||
void
|
||||
static void
|
||||
find_term_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
if (use_avx512 && cpu_support_avx512()) {
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
find_term_bool = FindTermAVX512<bool>;
|
||||
find_term_int8 = FindTermAVX512<int8_t>;
|
||||
find_term_int16 = FindTermAVX512<int16_t>;
|
||||
find_term_int32 = FindTermAVX512<int32_t>;
|
||||
find_term_int64 = FindTermAVX512<int64_t>;
|
||||
find_term_int8_t = FindTermAVX512<int8_t>;
|
||||
find_term_int16_t = FindTermAVX512<int16_t>;
|
||||
find_term_int32_t = FindTermAVX512<int32_t>;
|
||||
find_term_int64_t = FindTermAVX512<int64_t>;
|
||||
find_term_float = FindTermAVX512<float>;
|
||||
find_term_double = FindTermAVX512<double>;
|
||||
use_find_term_avx512 = true;
|
||||
} else if (use_avx2 && cpu_support_avx2()) {
|
||||
} else if (cpu_support_avx2()) {
|
||||
simd_type = "AVX2";
|
||||
find_term_bool = FindTermAVX2<bool>;
|
||||
find_term_int8 = FindTermAVX2<int8_t>;
|
||||
find_term_int16 = FindTermAVX2<int16_t>;
|
||||
find_term_int32 = FindTermAVX2<int32_t>;
|
||||
find_term_int64 = FindTermAVX2<int64_t>;
|
||||
find_term_int8_t = FindTermAVX2<int8_t>;
|
||||
find_term_int16_t = FindTermAVX2<int16_t>;
|
||||
find_term_int32_t = FindTermAVX2<int32_t>;
|
||||
find_term_int64_t = FindTermAVX2<int64_t>;
|
||||
find_term_float = FindTermAVX2<float>;
|
||||
find_term_double = FindTermAVX2<double>;
|
||||
use_find_term_avx2 = true;
|
||||
} else if (use_sse4_2 && cpu_support_sse4_2()) {
|
||||
} else if (cpu_support_sse4_2()) {
|
||||
simd_type = "SSE4";
|
||||
find_term_bool = FindTermSSE4<bool>;
|
||||
find_term_int8 = FindTermSSE4<int8_t>;
|
||||
find_term_int16 = FindTermSSE4<int16_t>;
|
||||
find_term_int32 = FindTermSSE4<int32_t>;
|
||||
find_term_int64 = FindTermSSE4<int64_t>;
|
||||
find_term_int8_t = FindTermSSE4<int8_t>;
|
||||
find_term_int16_t = FindTermSSE4<int16_t>;
|
||||
find_term_int32_t = FindTermSSE4<int32_t>;
|
||||
find_term_int64_t = FindTermSSE4<int64_t>;
|
||||
find_term_float = FindTermSSE4<float>;
|
||||
find_term_double = FindTermSSE4<double>;
|
||||
use_find_term_sse4_2 = true;
|
||||
} else if (use_sse2 && cpu_support_sse2()) {
|
||||
} else if (cpu_support_sse2()) {
|
||||
simd_type = "SSE2";
|
||||
find_term_bool = FindTermSSE2<bool>;
|
||||
find_term_int8 = FindTermSSE2<int8_t>;
|
||||
find_term_int16 = FindTermSSE2<int16_t>;
|
||||
find_term_int32 = FindTermSSE2<int32_t>;
|
||||
find_term_int64 = FindTermSSE2<int64_t>;
|
||||
find_term_int8_t = FindTermSSE2<int8_t>;
|
||||
find_term_int16_t = FindTermSSE2<int16_t>;
|
||||
find_term_int32_t = FindTermSSE2<int32_t>;
|
||||
find_term_int64_t = FindTermSSE2<int64_t>;
|
||||
find_term_float = FindTermSSE2<float>;
|
||||
find_term_double = FindTermSSE2<double>;
|
||||
use_find_term_sse2 = true;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("find term hook simd type: {}", simd_type);
|
||||
}
|
||||
|
||||
void
|
||||
static void
|
||||
all_boolean_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
if (use_sse2 && cpu_support_sse2()) {
|
||||
if (cpu_support_sse2()) {
|
||||
simd_type = "SSE2";
|
||||
all_false = AllFalseSSE2;
|
||||
all_true = AllTrueSSE2;
|
||||
@ -189,13 +262,13 @@ all_boolean_hook() {
|
||||
LOG_INFO("AllFalse/AllTrue hook simd type: {}", simd_type);
|
||||
}
|
||||
|
||||
void
|
||||
static void
|
||||
invert_boolean_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
if (use_sse2 && cpu_support_sse2()) {
|
||||
if (cpu_support_sse2()) {
|
||||
simd_type = "SSE2";
|
||||
invert_bool = InvertBoolSSE2;
|
||||
}
|
||||
@ -207,21 +280,21 @@ invert_boolean_hook() {
|
||||
LOG_INFO("InvertBoolean hook simd type: {}", simd_type);
|
||||
}
|
||||
|
||||
void
|
||||
static void
|
||||
logical_boolean_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
if (use_avx512 && cpu_support_avx512()) {
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
and_bool = AndBoolAVX512;
|
||||
or_bool = OrBoolAVX512;
|
||||
} else if (use_avx2 && cpu_support_avx2()) {
|
||||
} else if (cpu_support_avx2()) {
|
||||
simd_type = "AVX2";
|
||||
and_bool = AndBoolAVX2;
|
||||
or_bool = OrBoolAVX2;
|
||||
} else if (use_sse2 && cpu_support_sse2()) {
|
||||
} else if (cpu_support_sse2()) {
|
||||
simd_type = "SSE2";
|
||||
and_bool = AndBoolSSE2;
|
||||
or_bool = OrBoolSSE2;
|
||||
@ -234,17 +307,287 @@ logical_boolean_hook() {
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("InvertBoolean hook simd type: {}", simd_type);
|
||||
}
|
||||
void
|
||||
|
||||
static void
|
||||
boolean_hook() {
|
||||
all_boolean_hook();
|
||||
invert_boolean_hook();
|
||||
logical_boolean_hook();
|
||||
}
|
||||
|
||||
static void
|
||||
equal_val_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
equal_val_int8_t = EqualValAVX512<int8_t>;
|
||||
equal_val_int16_t = EqualValAVX512<int16_t>;
|
||||
equal_val_int32_t = EqualValAVX512<int32_t>;
|
||||
equal_val_int64_t = EqualValAVX512<int64_t>;
|
||||
equal_val_float = EqualValAVX512<float>;
|
||||
equal_val_double = EqualValAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("equal val hook simd type: {} ", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
less_val_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
less_val_int8_t = LessValAVX512<int8_t>;
|
||||
less_val_int16_t = LessValAVX512<int16_t>;
|
||||
less_val_int32_t = LessValAVX512<int32_t>;
|
||||
less_val_int64_t = LessValAVX512<int64_t>;
|
||||
less_val_float = LessValAVX512<float>;
|
||||
less_val_double = LessValAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("less than val hook simd type:{} ", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
greater_val_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
greater_val_int8_t = GreaterValAVX512<int8_t>;
|
||||
greater_val_int16_t = GreaterValAVX512<int16_t>;
|
||||
greater_val_int32_t = GreaterValAVX512<int32_t>;
|
||||
greater_val_int64_t = GreaterValAVX512<int64_t>;
|
||||
greater_val_float = GreaterValAVX512<float>;
|
||||
greater_val_double = GreaterValAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("greater than val hook simd type: {} ", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
less_equal_val_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
less_equal_val_int8_t = LessEqualValAVX512<int8_t>;
|
||||
less_equal_val_int16_t = LessEqualValAVX512<int16_t>;
|
||||
less_equal_val_int32_t = LessEqualValAVX512<int32_t>;
|
||||
less_equal_val_int64_t = LessEqualValAVX512<int64_t>;
|
||||
less_equal_val_float = LessEqualValAVX512<float>;
|
||||
less_equal_val_double = LessEqualValAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("less equal than val hook simd type: {} ", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
greater_equal_val_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
greater_equal_val_int8_t = GreaterEqualValAVX512<int8_t>;
|
||||
greater_equal_val_int16_t = GreaterEqualValAVX512<int16_t>;
|
||||
greater_equal_val_int32_t = GreaterEqualValAVX512<int32_t>;
|
||||
greater_equal_val_int64_t = GreaterEqualValAVX512<int64_t>;
|
||||
greater_equal_val_float = GreaterEqualValAVX512<float>;
|
||||
greater_equal_val_double = GreaterEqualValAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("greater equal than val hook simd type: {} ", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
not_equal_val_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
not_equal_val_int8_t = NotEqualValAVX512<int8_t>;
|
||||
not_equal_val_int16_t = NotEqualValAVX512<int16_t>;
|
||||
not_equal_val_int32_t = NotEqualValAVX512<int32_t>;
|
||||
not_equal_val_int64_t = NotEqualValAVX512<int64_t>;
|
||||
not_equal_val_float = NotEqualValAVX512<float>;
|
||||
not_equal_val_double = NotEqualValAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("not equal val hook simd type: {}", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
equal_col_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
equal_col_int8_t = EqualColumnAVX512<int8_t>;
|
||||
equal_col_int16_t = EqualColumnAVX512<int16_t>;
|
||||
equal_col_int32_t = EqualColumnAVX512<int32_t>;
|
||||
equal_col_int64_t = EqualColumnAVX512<int64_t>;
|
||||
equal_col_float = EqualColumnAVX512<float>;
|
||||
equal_col_double = EqualColumnAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("equal column hook simd type:{} ", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
less_col_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
less_col_int8_t = LessColumnAVX512<int8_t>;
|
||||
less_col_int16_t = LessColumnAVX512<int16_t>;
|
||||
less_col_int32_t = LessColumnAVX512<int32_t>;
|
||||
less_col_int64_t = LessColumnAVX512<int64_t>;
|
||||
less_col_float = LessColumnAVX512<float>;
|
||||
less_col_double = LessColumnAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("less than column hook simd type:{} ", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
greater_col_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
greater_col_int8_t = GreaterColumnAVX512<int8_t>;
|
||||
greater_col_int16_t = GreaterColumnAVX512<int16_t>;
|
||||
greater_col_int32_t = GreaterColumnAVX512<int32_t>;
|
||||
greater_col_int64_t = GreaterColumnAVX512<int64_t>;
|
||||
greater_col_float = GreaterColumnAVX512<float>;
|
||||
greater_col_double = GreaterColumnAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("greater than column hook simd type:{} ", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
less_equal_col_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
less_equal_col_int8_t = LessEqualColumnAVX512<int8_t>;
|
||||
less_equal_col_int16_t = LessEqualColumnAVX512<int16_t>;
|
||||
less_equal_col_int32_t = LessEqualColumnAVX512<int32_t>;
|
||||
less_equal_col_int64_t = LessEqualColumnAVX512<int64_t>;
|
||||
less_equal_col_float = LessEqualColumnAVX512<float>;
|
||||
less_equal_col_double = LessEqualColumnAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("less equal than column hook simd type: {}", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
greater_equal_col_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
greater_equal_col_int8_t = GreaterEqualColumnAVX512<int8_t>;
|
||||
greater_equal_col_int16_t = GreaterEqualColumnAVX512<int16_t>;
|
||||
greater_equal_col_int32_t = GreaterEqualColumnAVX512<int32_t>;
|
||||
greater_equal_col_int64_t = GreaterEqualColumnAVX512<int64_t>;
|
||||
greater_equal_col_float = GreaterEqualColumnAVX512<float>;
|
||||
greater_equal_col_double = GreaterEqualColumnAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("greater equal than column hook simd type:{} ", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
not_equal_col_hook() {
|
||||
static std::mutex hook_mutex;
|
||||
std::lock_guard<std::mutex> lock(hook_mutex);
|
||||
std::string simd_type = "REF";
|
||||
#if defined(__x86_64__)
|
||||
// Only support avx512 for now
|
||||
if (cpu_support_avx512()) {
|
||||
simd_type = "AVX512";
|
||||
not_equal_col_int8_t = NotEqualColumnAVX512<int8_t>;
|
||||
not_equal_col_int16_t = NotEqualColumnAVX512<int16_t>;
|
||||
not_equal_col_int32_t = NotEqualColumnAVX512<int32_t>;
|
||||
not_equal_col_int64_t = NotEqualColumnAVX512<int64_t>;
|
||||
not_equal_col_float = NotEqualColumnAVX512<float>;
|
||||
not_equal_col_double = NotEqualColumnAVX512<double>;
|
||||
}
|
||||
#endif
|
||||
// TODO: support arm cpu
|
||||
LOG_INFO("not equal column hook simd type: {}", simd_type);
|
||||
}
|
||||
|
||||
static void
|
||||
compare_hook() {
|
||||
equal_val_hook();
|
||||
less_val_hook();
|
||||
greater_val_hook();
|
||||
less_equal_val_hook();
|
||||
greater_equal_val_hook();
|
||||
not_equal_val_hook();
|
||||
equal_col_hook();
|
||||
less_col_hook();
|
||||
greater_col_hook();
|
||||
less_equal_col_hook();
|
||||
greater_equal_col_hook();
|
||||
not_equal_col_hook();
|
||||
}
|
||||
|
||||
// Runs each SIMD dispatch hook once while this translation unit's statics
// are initialised; the stored int only exists to trigger the lambda.
// find_term_hook() used to be invoked twice here; the redundant call is
// removed so the hook runs (and logs) a single time.
static int init_hook_ = []() {
    bitset_hook();
    boolean_hook();
    find_term_hook();
    compare_hook();
    return 0;
}();
|
||||
|
||||
|
||||
@ -18,41 +18,6 @@
|
||||
namespace milvus {
|
||||
namespace simd {
|
||||
|
||||
extern BitsetBlockType (*get_bitset_block)(const bool* src);
|
||||
extern bool (*all_false)(const bool* src, int64_t size);
|
||||
extern bool (*all_true)(const bool* src, int64_t size);
|
||||
extern void (*invert_bool)(bool* src, int64_t size);
|
||||
extern void (*and_bool)(bool* left, bool* right, int64_t size);
|
||||
extern void (*or_bool)(bool* left, bool* right, int64_t size);
|
||||
|
||||
template <typename T>
|
||||
using FindTermPtr = bool (*)(const T* src, size_t size, T val);
|
||||
|
||||
extern FindTermPtr<bool> find_term_bool;
|
||||
extern FindTermPtr<int8_t> find_term_int8;
|
||||
extern FindTermPtr<int16_t> find_term_int16;
|
||||
extern FindTermPtr<int32_t> find_term_int32;
|
||||
extern FindTermPtr<int64_t> find_term_int64;
|
||||
extern FindTermPtr<float> find_term_float;
|
||||
extern FindTermPtr<double> find_term_double;
|
||||
|
||||
#if defined(__x86_64__)
|
||||
// Flags that indicate whether the runtime may choose
|
||||
// each of these SIMD types when the hook starts.
|
||||
extern bool use_avx512;
|
||||
extern bool use_avx2;
|
||||
extern bool use_sse4_2;
|
||||
extern bool use_sse2;
|
||||
|
||||
// Flags that indicate which kind of SIMD was selected for
|
||||
// each function once the hook has finished.
|
||||
extern bool use_bitset_sse2;
|
||||
extern bool use_find_term_sse2;
|
||||
extern bool use_find_term_sse4_2;
|
||||
extern bool use_find_term_avx2;
|
||||
extern bool use_find_term_avx512;
|
||||
#endif
|
||||
|
||||
#if defined(__x86_64__)
|
||||
bool
|
||||
cpu_support_avx512();
|
||||
@ -62,53 +27,135 @@ bool
|
||||
cpu_support_sse4_2();
|
||||
#endif
|
||||
|
||||
void
|
||||
bitset_hook();
|
||||
|
||||
void
|
||||
find_term_hook();
|
||||
|
||||
void
|
||||
boolean_hook();
|
||||
|
||||
void
|
||||
all_boolean_hook();
|
||||
|
||||
void
|
||||
invert_boolean_hook();
|
||||
|
||||
void
|
||||
logical_boolean_hook();
|
||||
extern BitsetBlockType (*get_bitset_block)(const bool* src);
|
||||
extern bool (*all_false)(const bool* src, int64_t size);
|
||||
extern bool (*all_true)(const bool* src, int64_t size);
|
||||
extern void (*invert_bool)(bool* src, int64_t size);
|
||||
extern void (*and_bool)(bool* left, bool* right, int64_t size);
|
||||
extern void (*or_bool)(bool* left, bool* right, int64_t size);
|
||||
|
||||
template <typename T>
|
||||
bool
|
||||
find_term_func(const T* data, size_t size, T val) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
using FindTermPtr = bool (*)(const T* src, size_t size, T val);
|
||||
#define EXTERN_FIND_TERM_PTR(type) extern FindTermPtr<type> find_term_##type;
|
||||
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
return milvus::simd::find_term_bool(data, size, val);
|
||||
}
|
||||
if constexpr (std::is_same_v<T, int8_t>) {
|
||||
return milvus::simd::find_term_int8(data, size, val);
|
||||
}
|
||||
if constexpr (std::is_same_v<T, int16_t>) {
|
||||
return milvus::simd::find_term_int16(data, size, val);
|
||||
}
|
||||
if constexpr (std::is_same_v<T, int32_t>) {
|
||||
return milvus::simd::find_term_int32(data, size, val);
|
||||
}
|
||||
if constexpr (std::is_same_v<T, int64_t>) {
|
||||
return milvus::simd::find_term_int64(data, size, val);
|
||||
}
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
return milvus::simd::find_term_float(data, size, val);
|
||||
}
|
||||
if constexpr (std::is_same_v<T, double>) {
|
||||
return milvus::simd::find_term_double(data, size, val);
|
||||
}
|
||||
}
|
||||
EXTERN_FIND_TERM_PTR(bool)
|
||||
EXTERN_FIND_TERM_PTR(int8_t)
|
||||
EXTERN_FIND_TERM_PTR(int16_t)
|
||||
EXTERN_FIND_TERM_PTR(int32_t)
|
||||
EXTERN_FIND_TERM_PTR(int64_t)
|
||||
EXTERN_FIND_TERM_PTR(float)
|
||||
EXTERN_FIND_TERM_PTR(double)
|
||||
|
||||
// Compare val function register
|
||||
// Such as A == 10, A < 10...
|
||||
template <typename T>
|
||||
using CompareValPtr = void (*)(const T* src, size_t size, T val, bool* res);
|
||||
#define EXTERN_COMPARE_VAL_PTR(prefix, type) \
|
||||
extern CompareValPtr<type> prefix##_##type;
|
||||
|
||||
// Compare column function register
|
||||
// Such as A == B, A < B...
|
||||
template <typename T>
|
||||
using CompareColPtr =
|
||||
void (*)(const T* left, const T* right, size_t size, bool* res);
|
||||
#define EXTERN_COMPARE_COL_PTR(prefix, type) \
|
||||
extern CompareColPtr<type> prefix##_##type;
|
||||
|
||||
EXTERN_COMPARE_VAL_PTR(equal_val, bool)
|
||||
EXTERN_COMPARE_VAL_PTR(equal_val, int8_t)
|
||||
EXTERN_COMPARE_VAL_PTR(equal_val, int16_t)
|
||||
EXTERN_COMPARE_VAL_PTR(equal_val, int32_t)
|
||||
EXTERN_COMPARE_VAL_PTR(equal_val, int64_t)
|
||||
EXTERN_COMPARE_VAL_PTR(equal_val, float)
|
||||
EXTERN_COMPARE_VAL_PTR(equal_val, double)
|
||||
|
||||
EXTERN_COMPARE_VAL_PTR(less_val, bool)
|
||||
EXTERN_COMPARE_VAL_PTR(less_val, int8_t)
|
||||
EXTERN_COMPARE_VAL_PTR(less_val, int16_t)
|
||||
EXTERN_COMPARE_VAL_PTR(less_val, int32_t)
|
||||
EXTERN_COMPARE_VAL_PTR(less_val, int64_t)
|
||||
EXTERN_COMPARE_VAL_PTR(less_val, float)
|
||||
EXTERN_COMPARE_VAL_PTR(less_val, double)
|
||||
|
||||
EXTERN_COMPARE_VAL_PTR(greater_val, bool)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_val, int8_t)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_val, int16_t)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_val, int32_t)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_val, int64_t)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_val, float)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_val, double)
|
||||
|
||||
EXTERN_COMPARE_VAL_PTR(less_equal_val, bool)
|
||||
EXTERN_COMPARE_VAL_PTR(less_equal_val, int8_t)
|
||||
EXTERN_COMPARE_VAL_PTR(less_equal_val, int16_t)
|
||||
EXTERN_COMPARE_VAL_PTR(less_equal_val, int32_t)
|
||||
EXTERN_COMPARE_VAL_PTR(less_equal_val, int64_t)
|
||||
EXTERN_COMPARE_VAL_PTR(less_equal_val, float)
|
||||
EXTERN_COMPARE_VAL_PTR(less_equal_val, double)
|
||||
|
||||
EXTERN_COMPARE_VAL_PTR(greater_equal_val, bool)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_equal_val, int8_t)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_equal_val, int16_t)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_equal_val, int32_t)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_equal_val, int64_t)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_equal_val, float)
|
||||
EXTERN_COMPARE_VAL_PTR(greater_equal_val, double)
|
||||
|
||||
EXTERN_COMPARE_VAL_PTR(not_equal_val, bool)
|
||||
EXTERN_COMPARE_VAL_PTR(not_equal_val, int8_t)
|
||||
EXTERN_COMPARE_VAL_PTR(not_equal_val, int16_t)
|
||||
EXTERN_COMPARE_VAL_PTR(not_equal_val, int32_t)
|
||||
EXTERN_COMPARE_VAL_PTR(not_equal_val, int64_t)
|
||||
EXTERN_COMPARE_VAL_PTR(not_equal_val, float)
|
||||
EXTERN_COMPARE_VAL_PTR(not_equal_val, double)
|
||||
|
||||
EXTERN_COMPARE_COL_PTR(equal_col, bool)
|
||||
EXTERN_COMPARE_COL_PTR(equal_col, int8_t)
|
||||
EXTERN_COMPARE_COL_PTR(equal_col, int16_t)
|
||||
EXTERN_COMPARE_COL_PTR(equal_col, int32_t)
|
||||
EXTERN_COMPARE_COL_PTR(equal_col, int64_t)
|
||||
EXTERN_COMPARE_COL_PTR(equal_col, float)
|
||||
EXTERN_COMPARE_COL_PTR(equal_col, double)
|
||||
|
||||
EXTERN_COMPARE_COL_PTR(less_col, bool)
|
||||
EXTERN_COMPARE_COL_PTR(less_col, int8_t)
|
||||
EXTERN_COMPARE_COL_PTR(less_col, int16_t)
|
||||
EXTERN_COMPARE_COL_PTR(less_col, int32_t)
|
||||
EXTERN_COMPARE_COL_PTR(less_col, int64_t)
|
||||
EXTERN_COMPARE_COL_PTR(less_col, float)
|
||||
EXTERN_COMPARE_COL_PTR(less_col, double)
|
||||
|
||||
EXTERN_COMPARE_COL_PTR(greater_col, bool)
|
||||
EXTERN_COMPARE_COL_PTR(greater_col, int8_t)
|
||||
EXTERN_COMPARE_COL_PTR(greater_col, int16_t)
|
||||
EXTERN_COMPARE_COL_PTR(greater_col, int32_t)
|
||||
EXTERN_COMPARE_COL_PTR(greater_col, int64_t)
|
||||
EXTERN_COMPARE_COL_PTR(greater_col, float)
|
||||
EXTERN_COMPARE_COL_PTR(greater_col, double)
|
||||
|
||||
EXTERN_COMPARE_COL_PTR(less_equal_col, bool)
|
||||
EXTERN_COMPARE_COL_PTR(less_equal_col, int8_t)
|
||||
EXTERN_COMPARE_COL_PTR(less_equal_col, int16_t)
|
||||
EXTERN_COMPARE_COL_PTR(less_equal_col, int32_t)
|
||||
EXTERN_COMPARE_COL_PTR(less_equal_col, int64_t)
|
||||
EXTERN_COMPARE_COL_PTR(less_equal_col, float)
|
||||
EXTERN_COMPARE_COL_PTR(less_equal_col, double)
|
||||
|
||||
EXTERN_COMPARE_COL_PTR(greater_equal_col, bool)
|
||||
EXTERN_COMPARE_COL_PTR(greater_equal_col, int8_t)
|
||||
EXTERN_COMPARE_COL_PTR(greater_equal_col, int16_t)
|
||||
EXTERN_COMPARE_COL_PTR(greater_equal_col, int32_t)
|
||||
EXTERN_COMPARE_COL_PTR(greater_equal_col, int64_t)
|
||||
EXTERN_COMPARE_COL_PTR(greater_equal_col, float)
|
||||
EXTERN_COMPARE_COL_PTR(greater_equal_col, double)
|
||||
|
||||
EXTERN_COMPARE_COL_PTR(not_equal_col, bool)
|
||||
EXTERN_COMPARE_COL_PTR(not_equal_col, int8_t)
|
||||
EXTERN_COMPARE_COL_PTR(not_equal_col, int16_t)
|
||||
EXTERN_COMPARE_COL_PTR(not_equal_col, int32_t)
|
||||
EXTERN_COMPARE_COL_PTR(not_equal_col, int64_t)
|
||||
EXTERN_COMPARE_COL_PTR(not_equal_col, float)
|
||||
EXTERN_COMPARE_COL_PTR(not_equal_col, double)
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
||||
264
internal/core/src/simd/interface.h
Normal file
264
internal/core/src/simd/interface.h
Normal file
@ -0,0 +1,264 @@
|
||||
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "hook.h"
|
||||
namespace milvus {
|
||||
namespace simd {
|
||||
|
||||
// Helper macros for the dispatch templates in this header. Each one expands to
// an `if constexpr` branch that is taken when the enclosing template parameter
// T is exactly `type`, and that returns the result of calling the hook pointer
// whose name ends in `_<type>`. They rely on the enclosing function using the
// parameter names that appear in the expansion.
//
// Find-term flavour: expects `data`, `size` and `val` to be in scope.
#define DISPATCH_FIND_TERM_SIMD_FUNC(type) \
|
||||
if constexpr (std::is_same_v<T, type>) { \
|
||||
return milvus::simd::find_term_##type(data, size, val); \
|
||||
}
|
||||
|
||||
// Compare-with-value flavour: expects `data`, `size`, `val` and `res`.
#define DISPATCH_COMPARE_VAL_SIMD_FUNC(prefix, type) \
|
||||
if constexpr (std::is_same_v<T, type>) { \
|
||||
return milvus::simd::prefix##_##type(data, size, val, res); \
|
||||
}
|
||||
|
||||
// Compare-two-columns flavour: expects `left`, `right`, `size` and `res`.
#define DISPATCH_COMPARE_COL_SIMD_FUNC(prefix, type) \
|
||||
if constexpr (std::is_same_v<T, type>) { \
|
||||
return milvus::simd::prefix##_##type(left, right, size, res); \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool
|
||||
find_term_func(const T* data, size_t size, T val) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_FIND_TERM_SIMD_FUNC(bool)
|
||||
DISPATCH_FIND_TERM_SIMD_FUNC(int8_t)
|
||||
DISPATCH_FIND_TERM_SIMD_FUNC(int16_t)
|
||||
DISPATCH_FIND_TERM_SIMD_FUNC(int32_t)
|
||||
DISPATCH_FIND_TERM_SIMD_FUNC(int64_t)
|
||||
DISPATCH_FIND_TERM_SIMD_FUNC(float)
|
||||
DISPATCH_FIND_TERM_SIMD_FUNC(double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
equal_val_func(const T* data, int64_t size, T val, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, bool)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int8_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int16_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int32_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int64_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, float)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
less_val_func(const T* data, int64_t size, T val, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, bool)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int8_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int16_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int32_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int64_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, float)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
greater_val_func(const T* data, int64_t size, T val, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, bool)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int8_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int16_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int32_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int64_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, float)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
less_equal_val_func(const T* data, int64_t size, T val, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, bool)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int8_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int16_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int32_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int64_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, float)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
greater_equal_val_func(const T* data, int64_t size, T val, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, bool)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int8_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int16_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int32_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int64_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, float)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
not_equal_val_func(const T* data, int64_t size, T val, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, bool)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int8_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int16_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int32_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int64_t)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, float)
|
||||
DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
equal_col_func(const T* left, const T* right, int64_t size, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, bool)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int8_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int16_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int32_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int64_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, float)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
less_col_func(const T* left, const T* right, int64_t size, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, bool)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int8_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int16_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int32_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int64_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, float)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
greater_col_func(const T* left, const T* right, int64_t size, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, bool)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int8_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int16_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int32_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int64_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, float)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
less_equal_col_func(const T* left, const T* right, int64_t size, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, bool)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int8_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int16_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int32_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int64_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, float)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
greater_equal_col_func(const T* left, const T* right, int64_t size, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, bool)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int8_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int16_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int32_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int64_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, float)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
not_equal_col_func(const T* left, const T* right, int64_t size, bool* res) {
|
||||
static_assert(
|
||||
std::is_integral<T>::value || std::is_floating_point<T>::value,
|
||||
"T must be integral or float/double type");
|
||||
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, bool)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int8_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int16_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int32_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int64_t)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, float)
|
||||
DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, double)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
compare_col_func(CompareType cmp_type,
|
||||
const T* left,
|
||||
const T* right,
|
||||
int64_t size,
|
||||
bool* res) {
|
||||
if (cmp_type == CompareType::EQ) {
|
||||
equal_col_func(left, right, size, res);
|
||||
} else if (cmp_type == CompareType::NEQ) {
|
||||
not_equal_col_func(left, right, size, res);
|
||||
} else if (cmp_type == CompareType::GE) {
|
||||
greater_equal_col_func(left, right, size, res);
|
||||
} else if (cmp_type == CompareType::GT) {
|
||||
greater_col_func(left, right, size, res);
|
||||
} else if (cmp_type == CompareType::LE) {
|
||||
less_equal_col_func(left, right, size, res);
|
||||
} else if (cmp_type == CompareType::LT) {
|
||||
less_col_func(left, right, size, res);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
@ -45,5 +45,99 @@ FindTermRef(const T* src, size_t size, T val) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Scalar reference kernel: stores (src[i] == val) into res[i] for every i in
// [0, size). Nothing is written when size is 0.
template <typename T>
void
EqualValRef(const T* src, size_t size, T val, bool* res) {
    const T* const last = src + size;
    for (const T* cur = src; cur != last; ++cur, ++res) {
        *res = (*cur == val);
    }
}
|
||||
|
||||
// Scalar reference kernel: stores (src[i] < val) into res[i] for every i in
// [0, size). Nothing is written when size is 0.
template <typename T>
void
LessValRef(const T* src, size_t size, T val, bool* res) {
    const T* const last = src + size;
    for (const T* cur = src; cur != last; ++cur, ++res) {
        *res = (*cur < val);
    }
}
|
||||
|
||||
// Scalar reference kernel: stores (src[i] > val) into res[i] for every i in
// [0, size). Nothing is written when size is 0.
template <typename T>
void
GreaterValRef(const T* src, size_t size, T val, bool* res) {
    const T* const last = src + size;
    for (const T* cur = src; cur != last; ++cur, ++res) {
        *res = (*cur > val);
    }
}
|
||||
|
||||
// Scalar reference kernel: stores (src[i] <= val) into res[i] for every i in
// [0, size). Nothing is written when size is 0.
template <typename T>
void
LessEqualValRef(const T* src, size_t size, T val, bool* res) {
    const T* const last = src + size;
    for (const T* cur = src; cur != last; ++cur, ++res) {
        *res = (*cur <= val);
    }
}
|
||||
// Scalar reference kernel: stores (src[i] >= val) into res[i] for every i in
// [0, size). Nothing is written when size is 0.
template <typename T>
void
GreaterEqualValRef(const T* src, size_t size, T val, bool* res) {
    const T* const last = src + size;
    for (const T* cur = src; cur != last; ++cur, ++res) {
        *res = (*cur >= val);
    }
}
|
||||
// Scalar reference kernel: stores (src[i] != val) into res[i] for every i in
// [0, size). Nothing is written when size is 0.
template <typename T>
void
NotEqualValRef(const T* src, size_t size, T val, bool* res) {
    const T* const last = src + size;
    for (const T* cur = src; cur != last; ++cur, ++res) {
        *res = (*cur != val);
    }
}
|
||||
|
||||
// Scalar reference kernel: stores (left[i] == right[i]) into res[i] for every
// i in [0, size). Nothing is written when size is 0.
template <typename T>
void
EqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
    size_t idx = 0;
    while (idx != size) {
        res[idx] = (left[idx] == right[idx]);
        ++idx;
    }
}
|
||||
|
||||
// Scalar reference kernel: stores (left[i] < right[i]) into res[i] for every
// i in [0, size). Nothing is written when size is 0.
template <typename T>
void
LessColumnRef(const T* left, const T* right, size_t size, bool* res) {
    size_t idx = 0;
    while (idx != size) {
        res[idx] = (left[idx] < right[idx]);
        ++idx;
    }
}
|
||||
|
||||
// Scalar reference kernel: stores (left[i] <= right[i]) into res[i] for every
// i in [0, size). Nothing is written when size is 0.
template <typename T>
void
LessEqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
    size_t idx = 0;
    while (idx != size) {
        res[idx] = (left[idx] <= right[idx]);
        ++idx;
    }
}
|
||||
|
||||
// Scalar reference kernel: stores (left[i] > right[i]) into res[i] for every
// i in [0, size). Nothing is written when size is 0.
template <typename T>
void
GreaterColumnRef(const T* left, const T* right, size_t size, bool* res) {
    size_t idx = 0;
    while (idx != size) {
        res[idx] = (left[idx] > right[idx]);
        ++idx;
    }
}
|
||||
|
||||
// Scalar reference kernel: stores (left[i] >= right[i]) into res[i] for every
// i in [0, size). Nothing is written when size is 0.
template <typename T>
void
GreaterEqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
    size_t idx = 0;
    while (idx != size) {
        res[idx] = (left[idx] >= right[idx]);
        ++idx;
    }
}
|
||||
|
||||
// Scalar reference kernel: stores (left[i] != right[i]) into res[i] for every
// i in [0, size). Nothing is written when size is 0.
template <typename T>
void
NotEqualColumnRef(const T* left, const T* right, size_t size, bool* res) {
    size_t idx = 0;
    while (idx != size) {
        res[idx] = (left[idx] != right[idx]);
        ++idx;
    }
}
|
||||
|
||||
} // namespace simd
|
||||
} // namespace milvus
|
||||
|
||||
@ -61,9 +61,8 @@ FindTermSSE2(const bool* src, size_t vec_size, bool val) {
|
||||
__m128i xmm_target = _mm_set1_epi8(val);
|
||||
__m128i xmm_data;
|
||||
size_t num_chunks = vec_size / 16;
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
xmm_data =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + 16 * i));
|
||||
for (size_t i = 0; i < num_chunks * 16; i += 16) {
|
||||
xmm_data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
|
||||
__m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target);
|
||||
int mask = _mm_movemask_epi8(xmm_match);
|
||||
if (mask != 0) {
|
||||
@ -71,7 +70,7 @@ FindTermSSE2(const bool* src, size_t vec_size, bool val) {
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 16 * num_chunks; i < vec_size; ++i) {
|
||||
for (size_t i = num_chunks * 16; i < vec_size; ++i) {
|
||||
if (src[i] == val) {
|
||||
return true;
|
||||
}
|
||||
@ -86,9 +85,8 @@ FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val) {
|
||||
__m128i xmm_target = _mm_set1_epi8(val);
|
||||
__m128i xmm_data;
|
||||
size_t num_chunks = vec_size / 16;
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
xmm_data =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + 16 * i));
|
||||
for (size_t i = 0; i < num_chunks * 16; i += 16) {
|
||||
xmm_data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
|
||||
__m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target);
|
||||
int mask = _mm_movemask_epi8(xmm_match);
|
||||
if (mask != 0) {
|
||||
@ -96,7 +94,7 @@ FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val) {
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 16 * num_chunks; i < vec_size; ++i) {
|
||||
for (size_t i = num_chunks * 16; i < vec_size; ++i) {
|
||||
if (src[i] == val) {
|
||||
return true;
|
||||
}
|
||||
@ -111,9 +109,8 @@ FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val) {
|
||||
__m128i xmm_target = _mm_set1_epi16(val);
|
||||
__m128i xmm_data;
|
||||
size_t num_chunks = vec_size / 8;
|
||||
for (size_t i = 0; i < num_chunks; i++) {
|
||||
xmm_data =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * 8));
|
||||
for (size_t i = 0; i < num_chunks * 8; i += 8) {
|
||||
xmm_data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
|
||||
__m128i xmm_match = _mm_cmpeq_epi16(xmm_data, xmm_target);
|
||||
int mask = _mm_movemask_epi8(xmm_match);
|
||||
if (mask != 0) {
|
||||
@ -121,7 +118,7 @@ FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val) {
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 8 * num_chunks; i < vec_size; ++i) {
|
||||
for (size_t i = num_chunks * 8; i < vec_size; ++i) {
|
||||
if (src[i] == val) {
|
||||
return true;
|
||||
}
|
||||
@ -136,9 +133,9 @@ FindTermSSE2(const int32_t* src, size_t vec_size, int32_t val) {
|
||||
size_t remaining_size = vec_size % 4;
|
||||
|
||||
__m128i xmm_target = _mm_set1_epi32(val);
|
||||
for (size_t i = 0; i < num_chunk; ++i) {
|
||||
for (size_t i = 0; i < num_chunk * 4; i += 4) {
|
||||
__m128i xmm_data =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * 4));
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
|
||||
__m128i xmm_match = _mm_cmpeq_epi32(xmm_data, xmm_target);
|
||||
int mask = _mm_movemask_epi8(xmm_match);
|
||||
if (mask != 0) {
|
||||
@ -180,9 +177,9 @@ FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val) {
|
||||
size_t num_chunk = vec_size / 2;
|
||||
size_t remaining_size = vec_size % 2;
|
||||
|
||||
for (int64_t i = 0; i < num_chunk; i++) {
|
||||
for (int64_t i = 0; i < num_chunk * 2; i += 2) {
|
||||
__m128i xmm_vec =
|
||||
_mm_load_si128(reinterpret_cast<const __m128i*>(src + i * 2));
|
||||
_mm_load_si128(reinterpret_cast<const __m128i*>(src + i));
|
||||
|
||||
__m128i xmm_low = _mm_set1_epi32(low);
|
||||
__m128i xmm_high = _mm_set1_epi32(high);
|
||||
@ -203,13 +200,6 @@ FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val) {
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
||||
// for (size_t i = 0; i < vec_size; ++i) {
|
||||
// if (src[i] == val) {
|
||||
// return true;
|
||||
// }
|
||||
// }
|
||||
// return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -217,8 +207,8 @@ bool
|
||||
FindTermSSE2(const float* src, size_t vec_size, float val) {
|
||||
size_t num_chunks = vec_size / 4;
|
||||
__m128 xmm_target = _mm_set1_ps(val);
|
||||
for (int i = 0; i < num_chunks; ++i) {
|
||||
__m128 xmm_data = _mm_loadu_ps(src + 4 * i);
|
||||
for (int i = 0; i < 4 * num_chunks; i += 4) {
|
||||
__m128 xmm_data = _mm_loadu_ps(src + i);
|
||||
__m128 xmm_match = _mm_cmpeq_ps(xmm_data, xmm_target);
|
||||
int mask = _mm_movemask_ps(xmm_match);
|
||||
if (mask != 0) {
|
||||
@ -239,8 +229,8 @@ bool
|
||||
FindTermSSE2(const double* src, size_t vec_size, double val) {
|
||||
size_t num_chunks = vec_size / 2;
|
||||
__m128d xmm_target = _mm_set1_pd(val);
|
||||
for (int i = 0; i < num_chunks; ++i) {
|
||||
__m128d xmm_data = _mm_loadu_pd(src + 2 * i);
|
||||
for (int i = 0; i < 2 * num_chunks; i += 2) {
|
||||
__m128d xmm_data = _mm_loadu_pd(src + i);
|
||||
__m128d xmm_match = _mm_cmpeq_pd(xmm_data, xmm_target);
|
||||
int mask = _mm_movemask_pd(xmm_match);
|
||||
if (mask != 0) {
|
||||
|
||||
@ -32,9 +32,9 @@ FindTermSSE4(const int64_t* src, size_t vec_size, int64_t val) {
|
||||
size_t remaining_size = vec_size % 2;
|
||||
|
||||
__m128i xmm_target = _mm_set1_epi64x(val);
|
||||
for (size_t i = 0; i < num_chunk; ++i) {
|
||||
for (size_t i = 0; i < num_chunk * 2; i += 2) {
|
||||
__m128i xmm_data =
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i * 2));
|
||||
_mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
|
||||
__m128i xmm_match = _mm_cmpeq_epi64(xmm_data, xmm_target);
|
||||
int mask = _mm_movemask_epi8(xmm_match);
|
||||
if (mask != 0) {
|
||||
|
||||
@ -38,6 +38,7 @@ using FixedVector = boost::container::vector<Type>;
|
||||
#include "simd/sse4.h"
|
||||
#include "simd/avx2.h"
|
||||
#include "simd/avx512.h"
|
||||
#include "simd/ref.h"
|
||||
|
||||
using namespace milvus::simd;
|
||||
TEST(GetBitSetBlock, base_test_sse) {
|
||||
@ -107,6 +108,30 @@ TEST(GetBitSetBlock, base_test_sse) {
|
||||
ASSERT_EQ(res, 0x1084210842108421);
|
||||
}
|
||||
|
||||
// Timing-only test: fills a 100,000,000-element bool vector with an
// alternating true/false pattern, then times 10,000,000 calls each of
// GetBitsetBlockSSE2 and GetBitsetBlockAVX2 on a window that slides one
// element per call. Results are discarded; only elapsed microseconds are
// printed, and nothing is asserted.
TEST(GetBitsetBlockPerf, bitset) {
    FixedVector<bool> srcs;
    for (size_t pos = 0; pos < 100000000; ++pos) {
        srcs.push_back(pos % 2 == 0);
    }

    // Prints the time elapsed since `begin` in the "cost: <n>us" format.
    const auto report_cost = [](std::chrono::steady_clock::time_point begin) {
        std::cout << "cost: "
                  << std::chrono::duration_cast<std::chrono::microseconds>(
                         std::chrono::steady_clock::now() - begin)
                         .count()
                  << "us" << std::endl;
    };

    std::cout << "start test" << std::endl;

    auto begin = std::chrono::steady_clock::now();
    for (int round = 0; round < 10000000; ++round) {
        auto block = GetBitsetBlockSSE2(srcs.data() + round);
    }
    report_cost(begin);

    begin = std::chrono::steady_clock::now();
    for (int round = 0; round < 10000000; ++round) {
        auto block = GetBitsetBlockAVX2(srcs.data() + round);
    }
    report_cost(begin);
}
|
||||
|
||||
TEST(GetBitSetBlock, base_test_avx2) {
|
||||
FixedVector<bool> src;
|
||||
for (int i = 0; i < 64; ++i) {
|
||||
@ -1214,10 +1239,298 @@ TEST(AllBooleanNeon, performance) {
|
||||
}
|
||||
}
|
||||
|
||||
// Timing-only test for int8_t: runs EqualValRef and then EqualValAVX512 over
// 1,000,000 values (i % 128) against the target 10 and prints each elapsed
// time in microseconds. Returns early after PRINT_SKPI_TEST when AVX512 is not
// available. Nothing is asserted.
TEST(EqualVal, perf_int8) {
    if (!cpu_support_avx512()) {
        PRINT_SKPI_TEST
        return;
    }

    std::vector<int8_t> srcs(1000000);
    for (int pos = 0; pos < 1000000; ++pos) {
        srcs[pos] = pos % 128;
    }
    FixedVector<bool> res(1000000);

    // Prints the microseconds elapsed since `begin` on its own line.
    const auto print_elapsed_us =
        [](std::chrono::steady_clock::time_point begin) {
            std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
                             std::chrono::steady_clock::now() - begin)
                             .count()
                      << std::endl;
        };

    auto begin = std::chrono::steady_clock::now();
    EqualValRef(srcs.data(), 1000000, (int8_t)10, res.data());
    print_elapsed_us(begin);

    begin = std::chrono::steady_clock::now();
    EqualValAVX512(srcs.data(), 1000000, (int8_t)10, res.data());
    print_elapsed_us(begin);
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
TestCompareValAVX512Perf() {
|
||||
if (!cpu_support_avx512()) {
|
||||
PRINT_SKPI_TEST
|
||||
return;
|
||||
}
|
||||
std::vector<T> srcs(1000000);
|
||||
for (int i = 0; i < 1000000; ++i) {
|
||||
srcs[i] = i;
|
||||
}
|
||||
FixedVector<bool> res(1000000);
|
||||
T target = 10;
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
EqualValRef(srcs.data(), 1000000, target, res.data());
|
||||
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::steady_clock::now() - start)
|
||||
.count()
|
||||
<< std::endl;
|
||||
start = std::chrono::steady_clock::now();
|
||||
EqualValAVX512(srcs.data(), 1000000, target, res.data());
|
||||
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::steady_clock::now() - start)
|
||||
.count()
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Runs the EqualValRef / EqualValAVX512 timing helper for int16_t.
TEST(EqualVal, perf_int16) {
|
||||
TestCompareValAVX512Perf<int16_t>();
|
||||
}
|
||||
|
||||
// Runs the EqualValRef / EqualValAVX512 timing helper for int32_t.
// NOTE(review): the test name is spelled `pref_int32` while its siblings use
// the `perf_` prefix; presumably a typo, confirm before renaming.
TEST(EqualVal, pref_int32) {
|
||||
TestCompareValAVX512Perf<int32_t>();
|
||||
}
|
||||
|
||||
// Runs the EqualValRef / EqualValAVX512 timing helper for int64_t.
TEST(EqualVal, perf_int64) {
|
||||
TestCompareValAVX512Perf<int64_t>();
|
||||
}
|
||||
|
||||
// Runs the EqualValRef / EqualValAVX512 timing helper for float.
TEST(EqualVal, perf_float) {
|
||||
TestCompareValAVX512Perf<float>();
|
||||
}
|
||||
|
||||
// Runs the EqualValRef / EqualValAVX512 timing helper for double.
TEST(EqualVal, perf_double) {
|
||||
TestCompareValAVX512Perf<double>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
TestCompareValAVX512(int size, T target) {
|
||||
if (!cpu_support_avx512()) {
|
||||
PRINT_SKPI_TEST
|
||||
return;
|
||||
}
|
||||
std::vector<T> vecs;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
if constexpr (std::is_same_v<T, int8_t>) {
|
||||
vecs.push_back(i % 127);
|
||||
} else if constexpr (std::is_floating_point_v<T>) {
|
||||
vecs.push_back(i + 0.01);
|
||||
} else {
|
||||
vecs.push_back(i);
|
||||
}
|
||||
}
|
||||
FixedVector<bool> res(size);
|
||||
|
||||
EqualValAVX512(vecs.data(), size, target, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], vecs[i] == target) << i;
|
||||
}
|
||||
LessValAVX512(vecs.data(), size, target, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], vecs[i] < target) << i;
|
||||
}
|
||||
LessEqualValAVX512(vecs.data(), size, target, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], vecs[i] <= target) << i;
|
||||
}
|
||||
GreaterEqualValAVX512(vecs.data(), size, target, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], vecs[i] >= target) << i;
|
||||
}
|
||||
GreaterValAVX512(vecs.data(), size, target, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], vecs[i] > target) << i;
|
||||
}
|
||||
NotEqualValAVX512(vecs.data(), size, target, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], vecs[i] != target) << i;
|
||||
}
|
||||
}
|
||||
|
||||
// Correctness tests for the AVX512 "column vs. constant" kernels, one test
// per element type. Each test runs three (size, target) combinations; the
// last uses an odd size (1001), presumably to exercise a non-vector-width
// tail -- TODO confirm against the kernel implementation.

TEST(CompareVal, avx512_int8) {
    TestCompareValAVX512<int8_t>(1000, 9);
    TestCompareValAVX512<int8_t>(1000, 99);
    TestCompareValAVX512<int8_t>(1001, 127);
}

TEST(CompareVal, avx512_int16) {
    TestCompareValAVX512<int16_t>(1000, 99);
    TestCompareValAVX512<int16_t>(1000, 999);
    TestCompareValAVX512<int16_t>(1001, 1000);
}

TEST(CompareVal, avx512_int32) {
    TestCompareValAVX512<int32_t>(1000, 99);
    TestCompareValAVX512<int32_t>(1000, 999);
    TestCompareValAVX512<int32_t>(1001, 1000);
}

TEST(CompareVal, avx512_int64) {
    TestCompareValAVX512<int64_t>(1000, 99);
    TestCompareValAVX512<int64_t>(1000, 999);
    TestCompareValAVX512<int64_t>(1001, 1000);
}

TEST(CompareVal, avx512_float) {
    TestCompareValAVX512<float>(1000, 99.01);
    TestCompareValAVX512<float>(1000, 999.01);
    TestCompareValAVX512<float>(1001, 1000.01);
}

TEST(CompareVal, avx512_double) {
    TestCompareValAVX512<double>(1000, 99.01);
    TestCompareValAVX512<double>(1000, 999.01);
    TestCompareValAVX512<double>(1001, 1000.01);
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
TestCompareColumnAVX512Perf() {
|
||||
if (!cpu_support_avx512()) {
|
||||
PRINT_SKPI_TEST
|
||||
return;
|
||||
}
|
||||
std::vector<T> lefts(1000000);
|
||||
for (int i = 0; i < 1000000; ++i) {
|
||||
lefts[i] = i;
|
||||
}
|
||||
std::vector<T> rights(1000000);
|
||||
for (int i = 0; i < 1000000; ++i) {
|
||||
rights[i] = i;
|
||||
}
|
||||
FixedVector<bool> res(1000000);
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
LessColumnRef(lefts.data(), rights.data(), 1000000, res.data());
|
||||
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::steady_clock::now() - start)
|
||||
.count()
|
||||
<< std::endl;
|
||||
start = std::chrono::steady_clock::now();
|
||||
LessColumnAVX512(lefts.data(), rights.data(), 1000000, res.data());
|
||||
std::cout << std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::steady_clock::now() - start)
|
||||
.count()
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Timing runs of TestCompareColumnAVX512Perf, one test per element type.
// These only print elapsed times; they do not check results.

// NOTE: "pref" is spelled this way in the registered test name; it is kept
// as-is so existing --gtest_filter patterns keep matching.
TEST(LessColumn, pref_int32) {
    TestCompareColumnAVX512Perf<int32_t>();
}

TEST(LessColumn, perf_int64) {
    TestCompareColumnAVX512Perf<int64_t>();
}

TEST(LessColumn, perf_float) {
    TestCompareColumnAVX512Perf<float>();
}

TEST(LessColumn, perf_double) {
    TestCompareColumnAVX512Perf<double>();
}
|
||||
|
||||
template <typename T>
|
||||
void
|
||||
TestCompareColumnAVX512(int size, T min_val, T max_val) {
|
||||
if (!cpu_support_avx512()) {
|
||||
PRINT_SKPI_TEST
|
||||
return;
|
||||
}
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
|
||||
std::vector<T> left;
|
||||
std::vector<T> right;
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
std::uniform_real_distribution<float> dis(min_val, max_val);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
left.push_back(dis(gen));
|
||||
right.push_back(dis(gen));
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
std::uniform_real_distribution<double> dis(min_val, max_val);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
left.push_back(dis(gen));
|
||||
right.push_back(dis(gen));
|
||||
}
|
||||
} else {
|
||||
std::uniform_int_distribution<> dis(min_val, max_val);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
left.push_back(dis(gen));
|
||||
right.push_back(dis(gen));
|
||||
}
|
||||
}
|
||||
|
||||
FixedVector<bool> res(size);
|
||||
|
||||
EqualColumnAVX512(left.data(), right.data(), size, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], left[i] == right[i]) << i;
|
||||
}
|
||||
LessColumnAVX512(left.data(), right.data(), size, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], left[i] < right[i]) << i;
|
||||
}
|
||||
GreaterColumnAVX512(left.data(), right.data(), size, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], left[i] > right[i]) << i;
|
||||
}
|
||||
LessEqualColumnAVX512(left.data(), right.data(), size, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], left[i] <= right[i]) << i;
|
||||
}
|
||||
GreaterEqualColumnAVX512(left.data(), right.data(), size, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], left[i] >= right[i]) << i;
|
||||
}
|
||||
NotEqualColumnAVX512(left.data(), right.data(), size, res.data());
|
||||
for (int i = 0; i < size; i++) {
|
||||
ASSERT_EQ(res[i], left[i] != right[i]) << i;
|
||||
}
|
||||
}
|
||||
|
||||
// Correctness tests for the AVX512 "column vs. column" kernels, one test per
// element type. Each test runs an even size (1000) and an odd size (1001);
// the odd size presumably exercises a non-vector-width tail -- TODO confirm
// against the kernel implementation.

TEST(CompareColumn, avx512_int8) {
    TestCompareColumnAVX512<int8_t>(1000, -128, 127);
    TestCompareColumnAVX512<int8_t>(1001, -128, 127);
}

TEST(CompareColumn, avx512_int16) {
    TestCompareColumnAVX512<int16_t>(1000, -1000, 1000);
    TestCompareColumnAVX512<int16_t>(1001, -1000, 1000);
}

TEST(CompareColumn, avx512_int32) {
    TestCompareColumnAVX512<int32_t>(1000, -1000, 1000);
    TestCompareColumnAVX512<int32_t>(1001, -1000, 1000);
}

TEST(CompareColumn, avx512_int64) {
    TestCompareColumnAVX512<int64_t>(1000, -1000, 1000);
    TestCompareColumnAVX512<int64_t>(1001, -1000, 1000);
}

TEST(CompareColumn, avx512_float) {
    TestCompareColumnAVX512<float>(1000, -1.0, 1.0);
    TestCompareColumnAVX512<float>(1001, -1.0, 1.0);
}

TEST(CompareColumn, avx512_double) {
    TestCompareColumnAVX512<double>(1000, -1.0, 1.0);
    TestCompareColumnAVX512<double>(1001, -1.0, 1.0);
}
|
||||
|
||||
#endif
|
||||
|
||||
int
|
||||
main(int argc, char* argv[]) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user