From d07197ab1a065ac57b2f659da9e913629b040fd8 Mon Sep 17 00:00:00 2001 From: zhagnlu <1542303831@qq.com> Date: Sun, 7 Jan 2024 20:20:57 +0800 Subject: [PATCH] enhance: add compare simd function (#29432) #26137 Signed-off-by: luzhang Co-authored-by: luzhang --- .../core/src/exec/expression/CompareExpr.h | 21 +- .../src/query/visitors/ExecExprVisitor.cpp | 21 - internal/core/src/simd/CMakeLists.txt | 5 +- internal/core/src/simd/avx2.cpp | 41 +- internal/core/src/simd/avx512.cpp | 725 +++++++++++++++++- internal/core/src/simd/avx512.h | 48 ++ internal/core/src/simd/common.h | 9 + internal/core/src/simd/hook.cpp | 489 ++++++++++-- internal/core/src/simd/hook.h | 205 +++-- internal/core/src/simd/interface.h | 264 +++++++ internal/core/src/simd/ref.h | 94 +++ internal/core/src/simd/sse2.cpp | 44 +- internal/core/src/simd/sse4.cpp | 4 +- internal/core/unittest/test_simd.cpp | 315 +++++++- 14 files changed, 2044 insertions(+), 241 deletions(-) create mode 100644 internal/core/src/simd/interface.h diff --git a/internal/core/src/exec/expression/CompareExpr.h b/internal/core/src/exec/expression/CompareExpr.h index 5b0497e0b8..c05974eb54 100644 --- a/internal/core/src/exec/expression/CompareExpr.h +++ b/internal/core/src/exec/expression/CompareExpr.h @@ -24,6 +24,7 @@ #include "common/Vector.h" #include "exec/expression/Expr.h" #include "segcore/SegmentInterface.h" +#include "simd/interface.h" namespace milvus { namespace exec { @@ -41,7 +42,7 @@ using ChunkDataAccessor = std::function; template struct CompareElementFunc { void - operator()(const T* left, const U* right, size_t size, bool* res) { + operator_base(const T* left, const U* right, size_t size, bool* res) { for (int i = 0; i < size; ++i) { if constexpr (op == proto::plan::OpType::Equal) { res[i] = left[i] == right[i]; @@ -63,6 +64,24 @@ struct CompareElementFunc { } } } + + void + operator()(const T* left, const U* right, size_t size, bool* res) { +#if defined(USE_DYNAMIC_SIMD) + if constexpr (std::is_same_v) { + milvus::simd::compare_col_func( + static_cast(op), + left, + right, + size, + res); + } else { + operator_base(left, right, size, res); + } +#else + operator_base(left, right, size, res); +#endif + } }; class PhyCompareFilterExpr : public Expr { diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index 808f1758ab..e6a8ef901c 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -2632,30 +2632,9 @@ ExecExprVisitor::ExecTermVisitorImplTemplate(TermExpr& expr_raw) -> BitsetType { return index->In(n, terms.data()); }; -#if defined(USE_DYNAMIC_SIMD) - std::function x)> elem_func; - if (n <= milvus::simd::TERM_EXPR_IN_SIZE_THREAD) { - elem_func = [&terms, &term_set, n](MayConstRef x) { - if constexpr (std::is_integral::value || - std::is_floating_point::value) { - return milvus::simd::find_term_func(terms.data(), n, x); - } else { - // For string type, simd performance not better than set mode - static_assert(std::is_same::value || - std::is_same::value); - return term_set.find(x) != term_set.end(); - } - }; - } else { - elem_func = [&term_set, n](MayConstRef x) { - return term_set.find(x) != term_set.end(); - }; - } -#else auto elem_func = [&term_set](MayConstRef x) { return term_set.find(x) != term_set.end(); }; -#endif auto default_skip_index_func = [&](const SkipIndex& skipIndex, FieldId fieldId, diff --git a/internal/core/src/simd/CMakeLists.txt b/internal/core/src/simd/CMakeLists.txt index ced8277197..632373da08 100644 --- a/internal/core/src/simd/CMakeLists.txt +++ b/internal/core/src/simd/CMakeLists.txt @@ -25,7 +25,8 @@ if (${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") ) set_source_files_properties(sse4.cpp PROPERTIES COMPILE_FLAGS "-msse4.2") set_source_files_properties(avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2") - set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx512bw") + set_source_files_properties(avx512.cpp PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512vl -mavx512dq -mavx512bw") + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm*") # TODO: add arm cpu simd message ("simd using arm mode") @@ -37,4 +38,4 @@ endif() add_library(milvus_simd ${MILVUS_SIMD_SRCS}) # Link the milvus_simd library with other libraries as needed -target_link_libraries(milvus_simd milvus_log) \ No newline at end of file +target_link_libraries(milvus_simd milvus_log) diff --git a/internal/core/src/simd/avx2.cpp b/internal/core/src/simd/avx2.cpp index 08c6a2636d..1ea51facca 100644 --- a/internal/core/src/simd/avx2.cpp +++ b/internal/core/src/simd/avx2.cpp @@ -29,11 +29,12 @@ GetBitsetBlockAVX2(const bool* src) { // BitsetBlockType has 64 bits __m256i highbit = _mm256_set1_epi8(0x7F); uint32_t tmp[8]; - for (size_t i = 0; i < 2; i += 1) { - __m256i boolvec = _mm256_loadu_si256((__m256i*)&src[i * 32]); - __m256i highbits = _mm256_add_epi8(boolvec, highbit); - tmp[i] = _mm256_movemask_epi8(highbits); - } + __m256i boolvec = _mm256_loadu_si256((__m256i*)(src)); + __m256i highbits = _mm256_add_epi8(boolvec, highbit); + tmp[0] = _mm256_movemask_epi8(highbits); + boolvec = _mm256_loadu_si256((__m256i*)(src + 32)); + highbits = _mm256_add_epi8(boolvec, highbit); + tmp[1] = _mm256_movemask_epi8(highbits); __m256i tmpvec = _mm256_loadu_si256((__m256i*)tmp); BitsetBlockType res[4]; @@ -65,9 +66,9 @@ FindTermAVX2(const bool* src, size_t vec_size, bool val) { __m256i ymm_data; size_t num_chunks = vec_size / 32; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 32 * num_chunks; i += 32) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 32 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match = _mm256_cmpeq_epi8(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -90,9 +91,9 @@ FindTermAVX2(const int8_t* src, size_t vec_size, int8_t val) { __m256i ymm_data; size_t num_chunks = vec_size / 32; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 32 * num_chunks; i += 32) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 32 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match = _mm256_cmpeq_epi8(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -114,10 +115,9 @@ FindTermAVX2(const int16_t* src, size_t vec_size, int16_t val) { __m256i ymm_target = _mm256_set1_epi16(val); __m256i ymm_data; size_t num_chunks = vec_size / 16; - size_t remaining_size = vec_size % 16; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 16 * num_chunks; i += 16) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 16 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match = _mm256_cmpeq_epi16(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -141,9 +141,9 @@ FindTermAVX2(const int32_t* src, size_t vec_size, int32_t val) { size_t num_chunks = vec_size / 8; size_t remaining_size = vec_size % 8; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 8 * num_chunks; i += 8) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 8 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match = _mm256_cmpeq_epi32(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -163,11 +163,10 @@ FindTermAVX2(const int64_t* src, size_t vec_size, int64_t val) { __m256i ymm_target = _mm256_set1_epi64x(val); __m256i ymm_data; size_t num_chunks = vec_size / 4; - size_t remaining_size = vec_size % 4; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 4 * num_chunks; i += 4) { ymm_data = - _mm256_loadu_si256(reinterpret_cast(src + 4 * i)); + _mm256_loadu_si256(reinterpret_cast(src + i)); __m256i ymm_match = _mm256_cmpeq_epi64(ymm_data, ymm_target); int mask = _mm256_movemask_epi8(ymm_match); if (mask != 0) { @@ -190,8 +189,8 @@ FindTermAVX2(const float* src, size_t vec_size, float val) { __m256 ymm_data; size_t num_chunks = vec_size / 8; - for (size_t i = 0; i < num_chunks; i++) { - ymm_data = _mm256_loadu_ps(src + 8 * i); + for (size_t i = 0; i < 8 * num_chunks; i += 8) { + ymm_data = _mm256_loadu_ps(src + i); __m256 ymm_match = _mm256_cmp_ps(ymm_data, ymm_target, _CMP_EQ_OQ); int mask = _mm256_movemask_ps(ymm_match); if (mask != 0) { @@ -214,8 +213,8 @@ FindTermAVX2(const double* src, size_t vec_size, double val) { __m256d ymm_data; size_t num_chunks = vec_size / 4; - for (size_t i = 0; i < num_chunks; i++) { - ymm_data = _mm256_loadu_pd(src + 8 * i); + for (size_t i = 0; i < 4 * num_chunks; i += 4) { + ymm_data = _mm256_loadu_pd(src + i); __m256d ymm_match = _mm256_cmp_pd(ymm_data, ymm_target, _CMP_EQ_OQ); int mask = _mm256_movemask_pd(ymm_match); if (mask != 0) { diff --git a/internal/core/src/simd/avx512.cpp b/internal/core/src/simd/avx512.cpp index 3df38319fd..e1bc4da3ff 100644 --- a/internal/core/src/simd/avx512.cpp +++ b/internal/core/src/simd/avx512.cpp @@ -25,9 +25,9 @@ FindTermAVX512(const bool* src, size_t vec_size, bool val) { __m512i zmm_data; size_t num_chunks = vec_size / 64; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 64 * num_chunks; i += 64) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 64 * i)); + _mm512_loadu_si512(reinterpret_cast(src + i)); __mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -49,9 +49,9 @@ FindTermAVX512(const int8_t* src, size_t vec_size, int8_t val) { __m512i zmm_data; size_t num_chunks = vec_size / 64; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 64 * num_chunks; i += 64) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 64 * i)); + _mm512_loadu_si512(reinterpret_cast(src + i)); __mmask64 mask = _mm512_cmpeq_epi8_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -73,9 +73,9 @@ FindTermAVX512(const int16_t* src, size_t vec_size, int16_t val) { __m512i zmm_data; size_t num_chunks = vec_size / 32; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 32 * num_chunks; i += 32) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 32 * i)); + _mm512_loadu_si512(reinterpret_cast(src + i)); __mmask32 mask = _mm512_cmpeq_epi16_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -97,9 +97,9 @@ FindTermAVX512(const int32_t* src, size_t vec_size, int32_t val) { __m512i zmm_data; size_t num_chunks = vec_size / 16; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 16 * num_chunks; i += 16) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 16 * i)); + _mm512_loadu_si512(reinterpret_cast(src + i)); __mmask16 mask = _mm512_cmpeq_epi32_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -121,9 +121,9 @@ FindTermAVX512(const int64_t* src, size_t vec_size, int64_t val) { __m512i zmm_data; size_t num_chunks = vec_size / 8; - for (size_t i = 0; i < num_chunks; i++) { + for (size_t i = 0; i < 8 * num_chunks; i += 8) { zmm_data = - _mm512_loadu_si512(reinterpret_cast(src + 8 * i)); + _mm512_loadu_si512(reinterpret_cast(src + i)); __mmask8 mask = _mm512_cmpeq_epi64_mask(zmm_data, zmm_target); if (mask != 0) { return true; @@ -145,8 +145,8 @@ FindTermAVX512(const float* src, size_t vec_size, float val) { __m512 zmm_data; size_t num_chunks = vec_size / 16; - for (size_t i = 0; i < num_chunks; i++) { - zmm_data = _mm512_loadu_ps(src + 16 * i); + for (size_t i = 0; i < 16 * num_chunks; i += 16) { + zmm_data = _mm512_loadu_ps(src + i); __mmask16 mask = _mm512_cmp_ps_mask(zmm_data, zmm_target, _CMP_EQ_OQ); if (mask != 0) { return true; @@ -168,8 +168,8 @@ FindTermAVX512(const double* src, size_t vec_size, double val) { __m512d zmm_data; size_t num_chunks = vec_size / 8; - for (size_t i = 0; i < num_chunks; i++) { - zmm_data = _mm512_loadu_pd(src + 8 * i); + for (size_t i = 0; i < 8 * num_chunks; i += 8) { + zmm_data = _mm512_loadu_pd(src + i); __mmask8 mask = _mm512_cmp_pd_mask(zmm_data, zmm_target, _CMP_EQ_OQ); if (mask != 0) { return true; @@ -216,6 +216,703 @@ OrBoolAVX512(bool* left, bool* right, int64_t size) { } } +template +struct CompareOperator; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_EQ_OQ : _MM_CMPINT_EQ; + static constexpr bool + Op(T a, T b) { + return a == b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_NEQ_OQ : _MM_CMPINT_NE; + static constexpr bool + Op(T a, T b) { + return a != b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_LT_OQ : _MM_CMPINT_LT; + static constexpr bool + Op(T a, T b) { + return a < b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_LE_OQ : _MM_CMPINT_LE; + static constexpr bool + Op(T a, T b) { + return a <= b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_GT_OQ : _MM_CMPINT_NLE; + static constexpr bool + Op(T a, T b) { + return a > b; + } +}; + +template +struct CompareOperator { + static constexpr int ComparePredicate = + std::is_floating_point_v ? _CMP_GE_OQ : _MM_CMPINT_NLT; + static constexpr bool + Op(T a, T b) { + return a >= b; + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const int8_t* src, size_t size, int8_t val, bool* res) { + __m512i target = _mm512_set1_epi8(val); + + int middle = size / 64 * 64; + + for (size_t i = 0; i < middle; i += 64) { + __m512i data = + _mm512_loadu_si512(reinterpret_cast(src + i)); + + __mmask64 cmp_res_mask = _mm512_cmp_epi8_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m512i cmp_res = _mm512_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm512_storeu_si512(res + i, cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const int16_t* src, size_t size, int16_t val, bool* res) { + __m512i target = _mm512_set1_epi16(val); + + int middle = size / 32 * 32; + + for (size_t i = 0; i < middle; i += 32) { + __m512i data = + _mm512_loadu_si512(reinterpret_cast(src + i)); + + __mmask32 cmp_res_mask = _mm512_cmp_epi16_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m256i cmp_res = _mm256_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm256_storeu_si256((__m256i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const int32_t* src, size_t size, int32_t val, bool* res) { + __m512i target = _mm512_set1_epi32(val); + + int middle = size / 16 * 16; + + for (size_t i = 0; i < middle; i += 16) { + __m512i data = + _mm512_loadu_si512(reinterpret_cast(src + i)); + + __mmask16 cmp_res_mask = _mm512_cmp_epi32_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si128((__m128i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const int64_t* src, size_t size, int64_t val, bool* res) { + __m512i target = _mm512_set1_epi64(val); + int middle = size / 8 * 8; + int index = 0; + for (size_t i = 0; i < middle; i += 8) { + __m512i data = + _mm512_loadu_si512(reinterpret_cast(src + i)); + __mmask8 mask = _mm512_cmp_epi64_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m128i cmp_res = _mm_maskz_set1_epi8(mask, 0x01); + _mm_storeu_si64((__m128i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const float* src, size_t size, float val, bool* res) { + __m512 target = _mm512_set1_ps(val); + + int middle = size / 16 * 16; + + for (size_t i = 0; i < middle; i += 16) { + __m512 data = _mm512_loadu_ps(src + i); + + __mmask16 cmp_res_mask = _mm512_cmp_ps_mask( + data, target, (CompareOperator::ComparePredicate)); + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si128((__m128i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +struct CompareValAVX512Impl { + static void + Compare(const double* src, size_t size, double val, bool* res) { + __m512d target = _mm512_set1_pd(val); + + int middle = size / 8 * 8; + + for (size_t i = 0; i < middle; i += 8) { + __m512d data = _mm512_loadu_pd(src + i); + + __mmask8 cmp_res_mask = _mm512_cmp_pd_mask( + data, + target, + (CompareOperator::ComparePredicate)); + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si64((res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(src[i], val); + } + } +}; + +template +void +EqualValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +EqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +EqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +EqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +EqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +EqualValAVX512(const float* src, size_t size, float val, bool* res); +template void +EqualValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +LessValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +LessValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +LessValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +LessValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +LessValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +LessValAVX512(const float* src, size_t size, float val, bool* res); +template void +LessValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +GreaterValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +GreaterValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +GreaterValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +GreaterValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +GreaterValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +GreaterValAVX512(const float* src, size_t size, float val, bool* res); +template void +GreaterValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +NotEqualValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +NotEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +NotEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +NotEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +NotEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +NotEqualValAVX512(const float* src, size_t size, float val, bool* res); +template void +NotEqualValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +LessEqualValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +LessEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +LessEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +LessEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +LessEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +LessEqualValAVX512(const float* src, size_t size, float val, bool* res); +template void +LessEqualValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +GreaterEqualValAVX512(const T* src, size_t size, T val, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareValAVX512Impl::Compare(src, size, val, res); +}; +template void +GreaterEqualValAVX512(const int8_t* src, size_t size, int8_t val, bool* res); +template void +GreaterEqualValAVX512(const int16_t* src, size_t size, int16_t val, bool* res); +template void +GreaterEqualValAVX512(const int32_t* src, size_t size, int32_t val, bool* res); +template void +GreaterEqualValAVX512(const int64_t* src, size_t size, int64_t val, bool* res); +template void +GreaterEqualValAVX512(const float* src, size_t size, float val, bool* res); +template void +GreaterEqualValAVX512(const double* src, size_t size, double val, bool* res); + +template +void +CompareColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); +} + +template +struct CompareColumnAVX512Impl { + static void + Compare(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v, "T must be integral type"); + + int batch_size = 512 / (sizeof(T) * 8); + int middle = size / batch_size * batch_size; + + for (size_t i = 0; i < middle; i += batch_size) { + __m512i left_reg = + _mm512_loadu_si512(reinterpret_cast(left + i)); + __m512i right_reg = + _mm512_loadu_si512(reinterpret_cast(right + i)); + + if constexpr (std::is_same_v) { + __mmask64 cmp_res_mask = _mm512_cmp_epi8_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m512i cmp_res = _mm512_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm512_storeu_si512(res + i, cmp_res); + } else if constexpr (std::is_same_v) { + __mmask32 cmp_res_mask = _mm512_cmp_epi16_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m256i cmp_res = _mm256_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm256_storeu_si256((__m256i*)(res + i), cmp_res); + } else if constexpr (std::is_same_v) { + __mmask16 cmp_res_mask = _mm512_cmp_epi32_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si128((__m128i*)(res + i), cmp_res); + } else if constexpr (std::is_same_v) { + __mmask8 mask = _mm512_cmp_epi64_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m128i cmp_res = _mm_maskz_set1_epi8(mask, 0x01); + _mm_storeu_si64((__m128i*)(res + i), cmp_res); + } + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(left[i], right[i]); + } + } +}; + +template +struct CompareColumnAVX512Impl { + static void + Compare(const float* left, const float* right, size_t size, bool* res) { + int batch_size = 512 / (sizeof(float) * 8); + int middle = size / batch_size * batch_size; + + for (size_t i = 0; i < middle; i += batch_size) { + __m512 left_reg = + _mm512_loadu_ps(reinterpret_cast(left + i)); + __m512 right_reg = + _mm512_loadu_ps(reinterpret_cast(right + i)); + + __mmask16 cmp_res_mask = _mm512_cmp_ps_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si128((__m128i*)(res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(left[i], right[i]); + } + } +}; + +template +struct CompareColumnAVX512Impl { + static void + Compare(const double* left, const double* right, size_t size, bool* res) { + int batch_size = 512 / (sizeof(double) * 8); + int middle = size / batch_size * batch_size; + + for (size_t i = 0; i < middle; i += batch_size) { + __m512d left_reg = + _mm512_loadu_pd(reinterpret_cast(left + i)); + __m512d right_reg = + _mm512_loadu_pd(reinterpret_cast(right + i)); + + __mmask8 cmp_res_mask = _mm512_cmp_pd_mask( + left_reg, + right_reg, + (CompareOperator::ComparePredicate)); + + __m128i cmp_res = _mm_maskz_set1_epi8(cmp_res_mask, 0x01); + _mm_storeu_si64((res + i), cmp_res); + } + + for (size_t i = middle; i < size; ++i) { + res[i] = CompareOperator::Op(left[i], right[i]); + } + } +}; + +template +void +EqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; + +template void +EqualColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +EqualColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +LessColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; +template void +LessColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +LessColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +LessColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +LessColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +LessColumnAVX512(const float* left, const float* right, size_t size, bool* res); +template void +LessColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +GreaterColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; +template void +GreaterColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +GreaterColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +LessEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; +template void +LessEqualColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +LessEqualColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +GreaterEqualColumnAVX512(const T* left, + const T* right, + size_t size, + bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; +template void +GreaterEqualColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +GreaterEqualColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + +template +void +NotEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res) { + static_assert(std::is_integral_v || std::is_floating_point_v, + "T must be integral or float/double type"); + CompareColumnAVX512Impl::Compare( + left, right, size, res); +}; + +template void +NotEqualColumnAVX512(const int8_t* left, + const int8_t* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const int16_t* left, + const int16_t* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const int32_t* left, + const int32_t* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const int64_t* left, + const int64_t* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const float* left, + const float* right, + size_t size, + bool* res); +template void +NotEqualColumnAVX512(const double* left, + const double* right, + size_t size, + bool* res); + } // namespace simd } // namespace milvus #endif diff --git a/internal/core/src/simd/avx512.h b/internal/core/src/simd/avx512.h index fe24b00bb6..9b5c549d3d 100644 --- a/internal/core/src/simd/avx512.h +++ b/internal/core/src/simd/avx512.h @@ -61,5 +61,53 @@ AndBoolAVX512(bool* left, bool* right, int64_t size); void OrBoolAVX512(bool* left, bool* right, int64_t size); +template +void +EqualValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +LessValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +GreaterValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +NotEqualValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +LessEqualValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +GreaterEqualValAVX512(const T* src, size_t size, T val, bool* res); + +template +void +EqualColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +LessColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +LessEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +GreaterColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +GreaterEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res); + +template +void +NotEqualColumnAVX512(const T* left, const T* right, size_t size, bool* res); + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/common.h b/internal/core/src/simd/common.h index 3cbe9c6e3e..f6e0c9e3c6 100644 --- a/internal/core/src/simd/common.h +++ b/internal/core/src/simd/common.h @@ -40,5 +40,14 @@ const int TERM_EXPR_IN_SIZE_THREAD = 50; std::is_same::value || std::is_same::value, \ Message); +enum class CompareType { + GT = 1, + GE = 2, + LT = 3, + LE = 4, + EQ = 5, + NEQ = 6, +}; + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/hook.cpp b/internal/core/src/simd/hook.cpp index 2fe688b857..89b5b30067 100644 --- a/internal/core/src/simd/hook.cpp +++ b/internal/core/src/simd/hook.cpp @@ -32,19 +32,6 @@ namespace milvus { namespace simd { -#if defined(__x86_64__) -bool use_avx512 = true; -bool use_avx2 = true; -bool use_sse4_2 = true; -bool use_sse2 = true; - -bool use_bitset_sse2; -bool use_find_term_sse2; -bool use_find_term_sse4_2; -bool use_find_term_avx2; -bool use_find_term_avx512; -#endif - decltype(get_bitset_block) get_bitset_block = GetBitsetBlockRef; decltype(all_false) all_false = AllFalseRef; decltype(all_true) all_true = AllTrueRef; @@ -52,20 +39,124 @@ decltype(invert_bool) invert_bool = InvertBoolRef; decltype(and_bool) and_bool = AndBoolRef; decltype(or_bool) or_bool = OrBoolRef; -FindTermPtr find_term_bool = FindTermRef; -FindTermPtr find_term_int8 = FindTermRef; -FindTermPtr find_term_int16 = FindTermRef; -FindTermPtr find_term_int32 = FindTermRef; -FindTermPtr find_term_int64 = FindTermRef; -FindTermPtr find_term_float = FindTermRef; -FindTermPtr find_term_double = FindTermRef; +#define DECLARE_FIND_TERM_PTR(type) \ + FindTermPtr find_term_##type = FindTermRef; +DECLARE_FIND_TERM_PTR(bool) +DECLARE_FIND_TERM_PTR(int8_t) +DECLARE_FIND_TERM_PTR(int16_t) +DECLARE_FIND_TERM_PTR(int32_t) +DECLARE_FIND_TERM_PTR(int64_t) +DECLARE_FIND_TERM_PTR(float) +DECLARE_FIND_TERM_PTR(double) + +#define DECLARE_COMPARE_VAL_PTR(prefix, RefFunc, type) \ + CompareValPtr prefix##_##type = RefFunc; + +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, bool) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, float) +DECLARE_COMPARE_VAL_PTR(equal_val, EqualValRef, double) + +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, bool) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, float) +DECLARE_COMPARE_VAL_PTR(less_val, LessValRef, double) + +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, bool) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, float) +DECLARE_COMPARE_VAL_PTR(greater_val, GreaterValRef, double) + +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, bool) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, float) +DECLARE_COMPARE_VAL_PTR(less_equal_val, LessEqualValRef, double) + +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, bool) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, float) +DECLARE_COMPARE_VAL_PTR(greater_equal_val, GreaterEqualValRef, double) + +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, bool) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int8_t) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int16_t) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int32_t) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, int64_t) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, float) +DECLARE_COMPARE_VAL_PTR(not_equal_val, NotEqualValRef, double) + +#define DECLARE_COMPARE_COL_PTR(prefix, RefFunc, type) \ + CompareColPtr prefix##_##type = RefFunc; + +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, bool) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, float) +DECLARE_COMPARE_COL_PTR(equal_col, EqualColumnRef, double) + +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, bool) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, float) +DECLARE_COMPARE_COL_PTR(less_col, LessColumnRef, double) + +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, bool) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, float) +DECLARE_COMPARE_COL_PTR(greater_col, GreaterColumnRef, double) + +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, bool) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, float) +DECLARE_COMPARE_COL_PTR(less_equal_col, LessEqualColumnRef, double) + +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, bool) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, float) +DECLARE_COMPARE_COL_PTR(greater_equal_col, GreaterEqualColumnRef, double) + +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, bool) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int8_t) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int16_t) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int32_t) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, int64_t) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, float) +DECLARE_COMPARE_COL_PTR(not_equal_col, NotEqualColumnRef, double) #if defined(__x86_64__) bool cpu_support_avx512() { InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); return (instruction_set_inst.AVX512F() && instruction_set_inst.AVX512DQ() && - instruction_set_inst.AVX512BW()); + instruction_set_inst.AVX512BW() && instruction_set_inst.AVX512VL()); } bool @@ -87,95 +178,77 @@ cpu_support_sse2() { } #endif -void +static void bitset_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_avx512 && cpu_support_avx512()) { - simd_type = "AVX512"; - // For now, sse2 has best performance - get_bitset_block = GetBitsetBlockSSE2; - use_bitset_sse2 = true; - } else if (use_avx2 && cpu_support_avx2()) { - simd_type = "AVX2"; - // For now, sse2 has best performance - get_bitset_block = GetBitsetBlockSSE2; - use_bitset_sse2 = true; - } else if (use_sse4_2 && cpu_support_sse4_2()) { - simd_type = "SSE4"; - get_bitset_block = GetBitsetBlockSSE2; - use_bitset_sse2 = true; - } else if (use_sse2 && cpu_support_sse2()) { + // SSE2 have best performance in test. + if (cpu_support_sse2()) { simd_type = "SSE2"; get_bitset_block = GetBitsetBlockSSE2; - use_bitset_sse2 = true; } #endif // TODO: support arm cpu LOG_INFO("bitset hook simd type: {}", simd_type); } -void +static void find_term_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_avx512 && cpu_support_avx512()) { + if (cpu_support_avx512()) { simd_type = "AVX512"; find_term_bool = FindTermAVX512; - find_term_int8 = FindTermAVX512; - find_term_int16 = FindTermAVX512; - find_term_int32 = FindTermAVX512; - find_term_int64 = FindTermAVX512; + find_term_int8_t = FindTermAVX512; + find_term_int16_t = FindTermAVX512; + find_term_int32_t = FindTermAVX512; + find_term_int64_t = FindTermAVX512; find_term_float = FindTermAVX512; find_term_double = FindTermAVX512; - use_find_term_avx512 = true; - } else if (use_avx2 && cpu_support_avx2()) { + } else if (cpu_support_avx2()) { simd_type = "AVX2"; find_term_bool = FindTermAVX2; - find_term_int8 = FindTermAVX2; - find_term_int16 = FindTermAVX2; - find_term_int32 = FindTermAVX2; - find_term_int64 = FindTermAVX2; + find_term_int8_t = FindTermAVX2; + find_term_int16_t = FindTermAVX2; + find_term_int32_t = FindTermAVX2; + find_term_int64_t = FindTermAVX2; find_term_float = FindTermAVX2; find_term_double = FindTermAVX2; - use_find_term_avx2 = true; - } else if (use_sse4_2 && cpu_support_sse4_2()) { + } else if (cpu_support_sse4_2()) { simd_type = "SSE4"; find_term_bool = FindTermSSE4; - find_term_int8 = FindTermSSE4; - find_term_int16 = FindTermSSE4; - find_term_int32 = FindTermSSE4; - find_term_int64 = FindTermSSE4; + find_term_int8_t = FindTermSSE4; + find_term_int16_t = FindTermSSE4; + find_term_int32_t = FindTermSSE4; + find_term_int64_t = FindTermSSE4; find_term_float = FindTermSSE4; find_term_double = FindTermSSE4; - use_find_term_sse4_2 = true; - } else if (use_sse2 && cpu_support_sse2()) { + } else if (cpu_support_sse2()) { simd_type = "SSE2"; find_term_bool = FindTermSSE2; - find_term_int8 = FindTermSSE2; - find_term_int16 = FindTermSSE2; - find_term_int32 = FindTermSSE2; - find_term_int64 = FindTermSSE2; + find_term_int8_t = FindTermSSE2; + find_term_int16_t = FindTermSSE2; + find_term_int32_t = FindTermSSE2; + find_term_int64_t = FindTermSSE2; find_term_float = FindTermSSE2; find_term_double = FindTermSSE2; - use_find_term_sse2 = true; } #endif // TODO: support arm cpu LOG_INFO("find term hook simd type: {}", simd_type); } -void +static void all_boolean_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_sse2 && cpu_support_sse2()) { + if (cpu_support_sse2()) { simd_type = "SSE2"; all_false = AllFalseSSE2; all_true = AllTrueSSE2; @@ -189,13 +262,13 @@ all_boolean_hook() { LOG_INFO("AllFalse/AllTrue hook simd type: {}", simd_type); } -void +static void invert_boolean_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_sse2 && cpu_support_sse2()) { + if (cpu_support_sse2()) { simd_type = "SSE2"; invert_bool = InvertBoolSSE2; } @@ -207,21 +280,21 @@ invert_boolean_hook() { LOG_INFO("InvertBoolean hook simd type: {}", simd_type); } -void +static void logical_boolean_hook() { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); std::string simd_type = "REF"; #if defined(__x86_64__) - if (use_avx512 && cpu_support_avx512()) { + if (cpu_support_avx512()) { simd_type = "AVX512"; and_bool = AndBoolAVX512; or_bool = OrBoolAVX512; - } else if (use_avx2 && cpu_support_avx2()) { + } else if (cpu_support_avx2()) { simd_type = "AVX2"; and_bool = AndBoolAVX2; or_bool = OrBoolAVX2; - } else if (use_sse2 && cpu_support_sse2()) { + } else if (cpu_support_sse2()) { simd_type = "SSE2"; and_bool = AndBoolSSE2; or_bool = OrBoolSSE2; @@ -234,17 +307,287 @@ logical_boolean_hook() { // TODO: support arm cpu LOG_INFO("InvertBoolean hook simd type: {}", simd_type); } -void + +static void boolean_hook() { all_boolean_hook(); invert_boolean_hook(); logical_boolean_hook(); } +static void +equal_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + equal_val_int8_t = EqualValAVX512; + equal_val_int16_t = EqualValAVX512; + equal_val_int32_t = EqualValAVX512; + equal_val_int64_t = EqualValAVX512; + equal_val_float = EqualValAVX512; + equal_val_double = EqualValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("equal val hook simd type: {} ", simd_type); +} + +static void +less_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + less_val_int8_t = LessValAVX512; + less_val_int16_t = LessValAVX512; + less_val_int32_t = LessValAVX512; + less_val_int64_t = LessValAVX512; + less_val_float = LessValAVX512; + less_val_double = LessValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("less than val hook simd type:{} ", simd_type); +} + +static void +greater_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + greater_val_int8_t = GreaterValAVX512; + greater_val_int16_t = GreaterValAVX512; + greater_val_int32_t = GreaterValAVX512; + greater_val_int64_t = GreaterValAVX512; + greater_val_float = GreaterValAVX512; + greater_val_double = GreaterValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("greater than val hook simd type: {} ", simd_type); +} + +static void +less_equal_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + less_equal_val_int8_t = LessEqualValAVX512; + less_equal_val_int16_t = LessEqualValAVX512; + less_equal_val_int32_t = LessEqualValAVX512; + less_equal_val_int64_t = LessEqualValAVX512; + less_equal_val_float = LessEqualValAVX512; + less_equal_val_double = LessEqualValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("less equal than val hook simd type: {} ", simd_type); +} + +static void +greater_equal_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + greater_equal_val_int8_t = GreaterEqualValAVX512; + greater_equal_val_int16_t = GreaterEqualValAVX512; + greater_equal_val_int32_t = GreaterEqualValAVX512; + greater_equal_val_int64_t = GreaterEqualValAVX512; + greater_equal_val_float = GreaterEqualValAVX512; + greater_equal_val_double = GreaterEqualValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("greater equal than val hook simd type: {} ", simd_type); +} + +static void +not_equal_val_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + not_equal_val_int8_t = NotEqualValAVX512; + not_equal_val_int16_t = NotEqualValAVX512; + not_equal_val_int32_t = NotEqualValAVX512; + not_equal_val_int64_t = NotEqualValAVX512; + not_equal_val_float = NotEqualValAVX512; + not_equal_val_double = NotEqualValAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("not equal val hook simd type: {}", simd_type); +} + +static void +equal_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + equal_col_int8_t = EqualColumnAVX512; + equal_col_int16_t = EqualColumnAVX512; + equal_col_int32_t = EqualColumnAVX512; + equal_col_int64_t = EqualColumnAVX512; + equal_col_float = EqualColumnAVX512; + equal_col_double = EqualColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("equal column hook simd type:{} ", simd_type); +} + +static void +less_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + less_col_int8_t = LessColumnAVX512; + less_col_int16_t = LessColumnAVX512; + less_col_int32_t = LessColumnAVX512; + less_col_int64_t = LessColumnAVX512; + less_col_float = LessColumnAVX512; + less_col_double = LessColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("less than column hook simd type:{} ", simd_type); +} + +static void +greater_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + greater_col_int8_t = GreaterColumnAVX512; + greater_col_int16_t = GreaterColumnAVX512; + greater_col_int32_t = GreaterColumnAVX512; + greater_col_int64_t = GreaterColumnAVX512; + greater_col_float = GreaterColumnAVX512; + greater_col_double = GreaterColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("greater than column hook simd type:{} ", simd_type); +} + +static void +less_equal_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + less_equal_col_int8_t = LessEqualColumnAVX512; + less_equal_col_int16_t = LessEqualColumnAVX512; + less_equal_col_int32_t = LessEqualColumnAVX512; + less_equal_col_int64_t = LessEqualColumnAVX512; + less_equal_col_float = LessEqualColumnAVX512; + less_equal_col_double = LessEqualColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("less equal than column hook simd type: {}", simd_type); +} + +static void +greater_equal_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + greater_equal_col_int8_t = GreaterEqualColumnAVX512; + greater_equal_col_int16_t = GreaterEqualColumnAVX512; + greater_equal_col_int32_t = GreaterEqualColumnAVX512; + greater_equal_col_int64_t = GreaterEqualColumnAVX512; + greater_equal_col_float = GreaterEqualColumnAVX512; + greater_equal_col_double = GreaterEqualColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("greater equal than column hook simd type:{} ", simd_type); +} + +static void +not_equal_col_hook() { + static std::mutex hook_mutex; + std::lock_guard lock(hook_mutex); + std::string simd_type = "REF"; +#if defined(__x86_64__) + // Only support avx512 for now + if (cpu_support_avx512()) { + simd_type = "AVX512"; + not_equal_col_int8_t = NotEqualColumnAVX512; + not_equal_col_int16_t = NotEqualColumnAVX512; + not_equal_col_int32_t = NotEqualColumnAVX512; + not_equal_col_int64_t = NotEqualColumnAVX512; + not_equal_col_float = NotEqualColumnAVX512; + not_equal_col_double = NotEqualColumnAVX512; + } +#endif + // TODO: support arm cpu + LOG_INFO("not equal column hook simd type: {}", simd_type); +} + +static void +compare_hook() { + equal_val_hook(); + less_val_hook(); + greater_val_hook(); + less_equal_val_hook(); + greater_equal_val_hook(); + not_equal_val_hook(); + equal_col_hook(); + less_col_hook(); + greater_col_hook(); + less_equal_col_hook(); + greater_equal_col_hook(); + not_equal_col_hook(); +} + static int init_hook_ = []() { bitset_hook(); - find_term_hook(); boolean_hook(); + find_term_hook(); + compare_hook(); return 0; }(); diff --git a/internal/core/src/simd/hook.h b/internal/core/src/simd/hook.h index 98e82853ae..2ffbbd8144 100644 --- a/internal/core/src/simd/hook.h +++ b/internal/core/src/simd/hook.h @@ -18,41 +18,6 @@ namespace milvus { namespace simd { -extern BitsetBlockType (*get_bitset_block)(const bool* src); -extern bool (*all_false)(const bool* src, int64_t size); -extern bool (*all_true)(const bool* src, int64_t size); -extern void (*invert_bool)(bool* src, int64_t size); -extern void (*and_bool)(bool* left, bool* right, int64_t size); -extern void (*or_bool)(bool* left, bool* right, int64_t size); - -template -using FindTermPtr = bool (*)(const T* src, size_t size, T val); - -extern FindTermPtr find_term_bool; -extern FindTermPtr find_term_int8; -extern FindTermPtr find_term_int16; -extern FindTermPtr find_term_int32; -extern FindTermPtr find_term_int64; -extern FindTermPtr find_term_float; -extern FindTermPtr find_term_double; - -#if defined(__x86_64__) -// Flags that indicate whether runtime can choose -// these simd type or not when hook starts. -extern bool use_avx512; -extern bool use_avx2; -extern bool use_sse4_2; -extern bool use_sse2; - -// Flags that indicate which kind of simd for -// different function when hook ends. -extern bool use_bitset_sse2; -extern bool use_find_term_sse2; -extern bool use_find_term_sse4_2; -extern bool use_find_term_avx2; -extern bool use_find_term_avx512; -#endif - #if defined(__x86_64__) bool cpu_support_avx512(); @@ -62,53 +27,135 @@ bool cpu_support_sse4_2(); #endif -void -bitset_hook(); - -void -find_term_hook(); - -void -boolean_hook(); - -void -all_boolean_hook(); - -void -invert_boolean_hook(); - -void -logical_boolean_hook(); +extern BitsetBlockType (*get_bitset_block)(const bool* src); +extern bool (*all_false)(const bool* src, int64_t size); +extern bool (*all_true)(const bool* src, int64_t size); +extern void (*invert_bool)(bool* src, int64_t size); +extern void (*and_bool)(bool* left, bool* right, int64_t size); +extern void (*or_bool)(bool* left, bool* right, int64_t size); template -bool -find_term_func(const T* data, size_t size, T val) { - static_assert( - std::is_integral::value || std::is_floating_point::value, - "T must be integral or float/double type"); +using FindTermPtr = bool (*)(const T* src, size_t size, T val); +#define EXTERN_FIND_TERM_PTR(type) extern FindTermPtr find_term_##type; - if constexpr (std::is_same_v) { - return milvus::simd::find_term_bool(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_int8(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_int16(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_int32(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_int64(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_float(data, size, val); - } - if constexpr (std::is_same_v) { - return milvus::simd::find_term_double(data, size, val); - } -} +EXTERN_FIND_TERM_PTR(bool) +EXTERN_FIND_TERM_PTR(int8_t) +EXTERN_FIND_TERM_PTR(int16_t) +EXTERN_FIND_TERM_PTR(int32_t) +EXTERN_FIND_TERM_PTR(int64_t) +EXTERN_FIND_TERM_PTR(float) +EXTERN_FIND_TERM_PTR(double) + +// Compare val function register +// Such as A == 10, A < 10... +template +using CompareValPtr = void (*)(const T* src, size_t size, T val, bool* res); +#define EXTERN_COMPARE_VAL_PTR(prefix, type) \ + extern CompareValPtr prefix##_##type; + +// Compare column function register +// Such as A == B, A < B... +template +using CompareColPtr = + void (*)(const T* left, const T* right, size_t size, bool* res); +#define EXTERN_COMPARE_COL_PTR(prefix, type) \ + extern CompareColPtr prefix##_##type; + +EXTERN_COMPARE_VAL_PTR(equal_val, bool) +EXTERN_COMPARE_VAL_PTR(equal_val, int8_t) +EXTERN_COMPARE_VAL_PTR(equal_val, int16_t) +EXTERN_COMPARE_VAL_PTR(equal_val, int32_t) +EXTERN_COMPARE_VAL_PTR(equal_val, int64_t) +EXTERN_COMPARE_VAL_PTR(equal_val, float) +EXTERN_COMPARE_VAL_PTR(equal_val, double) + +EXTERN_COMPARE_VAL_PTR(less_val, bool) +EXTERN_COMPARE_VAL_PTR(less_val, int8_t) +EXTERN_COMPARE_VAL_PTR(less_val, int16_t) +EXTERN_COMPARE_VAL_PTR(less_val, int32_t) +EXTERN_COMPARE_VAL_PTR(less_val, int64_t) +EXTERN_COMPARE_VAL_PTR(less_val, float) +EXTERN_COMPARE_VAL_PTR(less_val, double) + +EXTERN_COMPARE_VAL_PTR(greater_val, bool) +EXTERN_COMPARE_VAL_PTR(greater_val, int8_t) +EXTERN_COMPARE_VAL_PTR(greater_val, int16_t) +EXTERN_COMPARE_VAL_PTR(greater_val, int32_t) +EXTERN_COMPARE_VAL_PTR(greater_val, int64_t) +EXTERN_COMPARE_VAL_PTR(greater_val, float) +EXTERN_COMPARE_VAL_PTR(greater_val, double) + +EXTERN_COMPARE_VAL_PTR(less_equal_val, bool) +EXTERN_COMPARE_VAL_PTR(less_equal_val, int8_t) +EXTERN_COMPARE_VAL_PTR(less_equal_val, int16_t) +EXTERN_COMPARE_VAL_PTR(less_equal_val, int32_t) +EXTERN_COMPARE_VAL_PTR(less_equal_val, int64_t) +EXTERN_COMPARE_VAL_PTR(less_equal_val, float) +EXTERN_COMPARE_VAL_PTR(less_equal_val, double) + +EXTERN_COMPARE_VAL_PTR(greater_equal_val, bool) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, int8_t) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, int16_t) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, int32_t) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, int64_t) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, float) +EXTERN_COMPARE_VAL_PTR(greater_equal_val, double) + +EXTERN_COMPARE_VAL_PTR(not_equal_val, bool) +EXTERN_COMPARE_VAL_PTR(not_equal_val, int8_t) +EXTERN_COMPARE_VAL_PTR(not_equal_val, int16_t) +EXTERN_COMPARE_VAL_PTR(not_equal_val, int32_t) +EXTERN_COMPARE_VAL_PTR(not_equal_val, int64_t) +EXTERN_COMPARE_VAL_PTR(not_equal_val, float) +EXTERN_COMPARE_VAL_PTR(not_equal_val, double) + +EXTERN_COMPARE_COL_PTR(equal_col, bool) +EXTERN_COMPARE_COL_PTR(equal_col, int8_t) +EXTERN_COMPARE_COL_PTR(equal_col, int16_t) +EXTERN_COMPARE_COL_PTR(equal_col, int32_t) +EXTERN_COMPARE_COL_PTR(equal_col, int64_t) +EXTERN_COMPARE_COL_PTR(equal_col, float) +EXTERN_COMPARE_COL_PTR(equal_col, double) + +EXTERN_COMPARE_COL_PTR(less_col, bool) +EXTERN_COMPARE_COL_PTR(less_col, int8_t) +EXTERN_COMPARE_COL_PTR(less_col, int16_t) +EXTERN_COMPARE_COL_PTR(less_col, int32_t) +EXTERN_COMPARE_COL_PTR(less_col, int64_t) +EXTERN_COMPARE_COL_PTR(less_col, float) +EXTERN_COMPARE_COL_PTR(less_col, double) + +EXTERN_COMPARE_COL_PTR(greater_col, bool) +EXTERN_COMPARE_COL_PTR(greater_col, int8_t) +EXTERN_COMPARE_COL_PTR(greater_col, int16_t) +EXTERN_COMPARE_COL_PTR(greater_col, int32_t) +EXTERN_COMPARE_COL_PTR(greater_col, int64_t) +EXTERN_COMPARE_COL_PTR(greater_col, float) +EXTERN_COMPARE_COL_PTR(greater_col, double) + +EXTERN_COMPARE_COL_PTR(less_equal_col, bool) +EXTERN_COMPARE_COL_PTR(less_equal_col, int8_t) +EXTERN_COMPARE_COL_PTR(less_equal_col, int16_t) +EXTERN_COMPARE_COL_PTR(less_equal_col, int32_t) +EXTERN_COMPARE_COL_PTR(less_equal_col, int64_t) +EXTERN_COMPARE_COL_PTR(less_equal_col, float) +EXTERN_COMPARE_COL_PTR(less_equal_col, double) + +EXTERN_COMPARE_COL_PTR(greater_equal_col, bool) +EXTERN_COMPARE_COL_PTR(greater_equal_col, int8_t) +EXTERN_COMPARE_COL_PTR(greater_equal_col, int16_t) +EXTERN_COMPARE_COL_PTR(greater_equal_col, int32_t) +EXTERN_COMPARE_COL_PTR(greater_equal_col, int64_t) +EXTERN_COMPARE_COL_PTR(greater_equal_col, float) +EXTERN_COMPARE_COL_PTR(greater_equal_col, double) + +EXTERN_COMPARE_COL_PTR(not_equal_col, bool) +EXTERN_COMPARE_COL_PTR(not_equal_col, int8_t) +EXTERN_COMPARE_COL_PTR(not_equal_col, int16_t) +EXTERN_COMPARE_COL_PTR(not_equal_col, int32_t) +EXTERN_COMPARE_COL_PTR(not_equal_col, int64_t) +EXTERN_COMPARE_COL_PTR(not_equal_col, float) +EXTERN_COMPARE_COL_PTR(not_equal_col, double) } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/interface.h b/internal/core/src/simd/interface.h new file mode 100644 index 0000000000..e93a5c31dc --- /dev/null +++ b/internal/core/src/simd/interface.h @@ -0,0 +1,264 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#pragma once + +#include "hook.h" +namespace milvus { +namespace simd { + +#define DISPATCH_FIND_TERM_SIMD_FUNC(type) \ + if constexpr (std::is_same_v) { \ + return milvus::simd::find_term_##type(data, size, val); \ + } + +#define DISPATCH_COMPARE_VAL_SIMD_FUNC(prefix, type) \ + if constexpr (std::is_same_v) { \ + return milvus::simd::prefix##_##type(data, size, val, res); \ + } + +#define DISPATCH_COMPARE_COL_SIMD_FUNC(prefix, type) \ + if constexpr (std::is_same_v) { \ + return milvus::simd::prefix##_##type(left, right, size, res); \ + } + +template +bool +find_term_func(const T* data, size_t size, T val) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_FIND_TERM_SIMD_FUNC(bool) + DISPATCH_FIND_TERM_SIMD_FUNC(int8_t) + DISPATCH_FIND_TERM_SIMD_FUNC(int16_t) + DISPATCH_FIND_TERM_SIMD_FUNC(int32_t) + DISPATCH_FIND_TERM_SIMD_FUNC(int64_t) + DISPATCH_FIND_TERM_SIMD_FUNC(float) + DISPATCH_FIND_TERM_SIMD_FUNC(double) +} + +template +void +equal_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(equal_val, double) +} + +template +void +less_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_val, double) +} + +template +void +greater_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_val, double) +} + +template +void +less_equal_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(less_equal_val, double) +} + +template +void +greater_equal_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(greater_equal_val, double) +} + +template +void +not_equal_val_func(const T* data, int64_t size, T val, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, bool) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int8_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int16_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int32_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, int64_t) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, float) + DISPATCH_COMPARE_VAL_SIMD_FUNC(not_equal_val, double) +} + +template +void +equal_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(equal_col, double) +} + +template +void +less_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_col, double) +} + +template +void +greater_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_col, double) +} + +template +void +less_equal_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(less_equal_col, double) +} + +template +void +greater_equal_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(greater_equal_col, double) +} + +template +void +not_equal_col_func(const T* left, const T* right, int64_t size, bool* res) { + static_assert( + std::is_integral::value || std::is_floating_point::value, + "T must be integral or float/double type"); + + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, bool) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int8_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int16_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int32_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, int64_t) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, float) + DISPATCH_COMPARE_COL_SIMD_FUNC(not_equal_col, double) +} + +template +void +compare_col_func(CompareType cmp_type, + const T* left, + const T* right, + int64_t size, + bool* res) { + if (cmp_type == CompareType::EQ) { + equal_col_func(left, right, size, res); + } else if (cmp_type == CompareType::NEQ) { + not_equal_col_func(left, right, size, res); + } else if (cmp_type == CompareType::GE) { + greater_equal_col_func(left, right, size, res); + } else if (cmp_type == CompareType::GT) { + greater_col_func(left, right, size, res); + } else if (cmp_type == CompareType::LE) { + less_equal_col_func(left, right, size, res); + } else if (cmp_type == CompareType::LT) { + less_col_func(left, right, size, res); + } +} + +} // namespace simd +} // namespace milvus diff --git a/internal/core/src/simd/ref.h b/internal/core/src/simd/ref.h index 6e90c7215a..f3b7af1a0c 100644 --- a/internal/core/src/simd/ref.h +++ b/internal/core/src/simd/ref.h @@ -45,5 +45,99 @@ FindTermRef(const T* src, size_t size, T val) { return false; } +template +void +EqualValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] == val; + } +} + +template +void +LessValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] < val; + } +} + +template +void +GreaterValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] > val; + } +} + +template +void +LessEqualValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] <= val; + } +} +template +void +GreaterEqualValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] >= val; + } +} +template +void +NotEqualValRef(const T* src, size_t size, T val, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = src[i] != val; + } +} + +template +void +EqualColumnRef(const T* left, const T* right, size_t size, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = left[i] == right[i]; + } +} + +template +void +LessColumnRef(const T* left, const T* right, size_t size, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = left[i] < right[i]; + } +} + +template +void +LessEqualColumnRef(const T* left, const T* right, size_t size, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = left[i] <= right[i]; + } +} + +template +void +GreaterColumnRef(const T* left, const T* right, size_t size, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = left[i] > right[i]; + } +} + +template +void +GreaterEqualColumnRef(const T* left, const T* right, size_t size, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = left[i] >= right[i]; + } +} + +template +void +NotEqualColumnRef(const T* left, const T* right, size_t size, bool* res) { + for (size_t i = 0; i < size; ++i) { + res[i] = left[i] != right[i]; + } +} + } // namespace simd } // namespace milvus diff --git a/internal/core/src/simd/sse2.cpp b/internal/core/src/simd/sse2.cpp index 40542bf22b..9726aec946 100644 --- a/internal/core/src/simd/sse2.cpp +++ b/internal/core/src/simd/sse2.cpp @@ -61,9 +61,8 @@ FindTermSSE2(const bool* src, size_t vec_size, bool val) { __m128i xmm_target = _mm_set1_epi8(val); __m128i xmm_data; size_t num_chunks = vec_size / 16; - for (size_t i = 0; i < num_chunks; i++) { - xmm_data = - _mm_loadu_si128(reinterpret_cast(src + 16 * i)); + for (size_t i = 0; i < num_chunks * 16; i += 16) { + xmm_data = _mm_loadu_si128(reinterpret_cast(src + i)); __m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target); int mask = _mm_movemask_epi8(xmm_match); if (mask != 0) { @@ -71,7 +70,7 @@ FindTermSSE2(const bool* src, size_t vec_size, bool val) { } } - for (size_t i = 16 * num_chunks; i < vec_size; ++i) { + for (size_t i = num_chunks * 16; i < vec_size; ++i) { if (src[i] == val) { return true; } @@ -86,9 +85,8 @@ FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val) { __m128i xmm_target = _mm_set1_epi8(val); __m128i xmm_data; size_t num_chunks = vec_size / 16; - for (size_t i = 0; i < num_chunks; i++) { - xmm_data = - _mm_loadu_si128(reinterpret_cast(src + 16 * i)); + for (size_t i = 0; i < num_chunks * 16; i += 16) { + xmm_data = _mm_loadu_si128(reinterpret_cast(src + i)); __m128i xmm_match = _mm_cmpeq_epi8(xmm_data, xmm_target); int mask = _mm_movemask_epi8(xmm_match); if (mask != 0) { @@ -96,7 +94,7 @@ FindTermSSE2(const int8_t* src, size_t vec_size, int8_t val) { } } - for (size_t i = 16 * num_chunks; i < vec_size; ++i) { + for (size_t i = num_chunks * 16; i < vec_size; ++i) { if (src[i] == val) { return true; } @@ -111,9 +109,8 @@ FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val) { __m128i xmm_target = _mm_set1_epi16(val); __m128i xmm_data; size_t num_chunks = vec_size / 8; - for (size_t i = 0; i < num_chunks; i++) { - xmm_data = - _mm_loadu_si128(reinterpret_cast(src + i * 8)); + for (size_t i = 0; i < num_chunks * 8; i += 8) { + xmm_data = _mm_loadu_si128(reinterpret_cast(src + i)); __m128i xmm_match = _mm_cmpeq_epi16(xmm_data, xmm_target); int mask = _mm_movemask_epi8(xmm_match); if (mask != 0) { @@ -121,7 +118,7 @@ FindTermSSE2(const int16_t* src, size_t vec_size, int16_t val) { } } - for (size_t i = 8 * num_chunks; i < vec_size; ++i) { + for (size_t i = num_chunks * 8; i < vec_size; ++i) { if (src[i] == val) { return true; } @@ -136,9 +133,9 @@ FindTermSSE2(const int32_t* src, size_t vec_size, int32_t val) { size_t remaining_size = vec_size % 4; __m128i xmm_target = _mm_set1_epi32(val); - for (size_t i = 0; i < num_chunk; ++i) { + for (size_t i = 0; i < num_chunk * 4; i += 4) { __m128i xmm_data = - _mm_loadu_si128(reinterpret_cast(src + i * 4)); + _mm_loadu_si128(reinterpret_cast(src + i)); __m128i xmm_match = _mm_cmpeq_epi32(xmm_data, xmm_target); int mask = _mm_movemask_epi8(xmm_match); if (mask != 0) { @@ -180,9 +177,9 @@ FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val) { size_t num_chunk = vec_size / 2; size_t remaining_size = vec_size % 2; - for (int64_t i = 0; i < num_chunk; i++) { + for (int64_t i = 0; i < num_chunk * 2; i += 2) { __m128i xmm_vec = - _mm_load_si128(reinterpret_cast(src + i * 2)); + _mm_load_si128(reinterpret_cast(src + i)); __m128i xmm_low = _mm_set1_epi32(low); __m128i xmm_high = _mm_set1_epi32(high); @@ -203,13 +200,6 @@ FindTermSSE2(const int64_t* src, size_t vec_size, int64_t val) { } } return false; - - // for (size_t i = 0; i < vec_size; ++i) { - // if (src[i] == val) { - // return true; - // } - // } - // return false; } template <> @@ -217,8 +207,8 @@ bool FindTermSSE2(const float* src, size_t vec_size, float val) { size_t num_chunks = vec_size / 4; __m128 xmm_target = _mm_set1_ps(val); - for (int i = 0; i < num_chunks; ++i) { - __m128 xmm_data = _mm_loadu_ps(src + 4 * i); + for (int i = 0; i < 4 * num_chunks; i += 4) { + __m128 xmm_data = _mm_loadu_ps(src + i); __m128 xmm_match = _mm_cmpeq_ps(xmm_data, xmm_target); int mask = _mm_movemask_ps(xmm_match); if (mask != 0) { @@ -239,8 +229,8 @@ bool FindTermSSE2(const double* src, size_t vec_size, double val) { size_t num_chunks = vec_size / 2; __m128d xmm_target = _mm_set1_pd(val); - for (int i = 0; i < num_chunks; ++i) { - __m128d xmm_data = _mm_loadu_pd(src + 2 * i); + for (int i = 0; i < 2 * num_chunks; i += 2) { + __m128d xmm_data = _mm_loadu_pd(src + i); __m128d xmm_match = _mm_cmpeq_pd(xmm_data, xmm_target); int mask = _mm_movemask_pd(xmm_match); if (mask != 0) { diff --git a/internal/core/src/simd/sse4.cpp b/internal/core/src/simd/sse4.cpp index 8585f9c648..bf3d08c76b 100644 --- a/internal/core/src/simd/sse4.cpp +++ b/internal/core/src/simd/sse4.cpp @@ -32,9 +32,9 @@ FindTermSSE4(const int64_t* src, size_t vec_size, int64_t val) { size_t remaining_size = vec_size % 2; __m128i xmm_target = _mm_set1_epi64x(val); - for (size_t i = 0; i < num_chunk; ++i) { + for (size_t i = 0; i < num_chunk * 2; i += 2) { __m128i xmm_data = - _mm_loadu_si128(reinterpret_cast(src + i * 2)); + _mm_loadu_si128(reinterpret_cast(src + i)); __m128i xmm_match = _mm_cmpeq_epi64(xmm_data, xmm_target); int mask = _mm_movemask_epi8(xmm_match); if (mask != 0) { diff --git a/internal/core/unittest/test_simd.cpp b/internal/core/unittest/test_simd.cpp index edfc410c23..cb157436c5 100644 --- a/internal/core/unittest/test_simd.cpp +++ b/internal/core/unittest/test_simd.cpp @@ -38,6 +38,7 @@ using FixedVector = boost::container::vector; #include "simd/sse4.h" #include "simd/avx2.h" #include "simd/avx512.h" +#include "simd/ref.h" using namespace milvus::simd; TEST(GetBitSetBlock, base_test_sse) { @@ -107,6 +108,30 @@ TEST(GetBitSetBlock, base_test_sse) { ASSERT_EQ(res, 0x1084210842108421); } +TEST(GetBitsetBlockPerf, bitset) { + FixedVector srcs; + for (size_t i = 0; i < 100000000; ++i) { + srcs.push_back(i % 2 == 0); + } + std::cout << "start test" << std::endl; + auto start = std::chrono::steady_clock::now(); + for (int i = 0; i < 10000000; ++i) + auto result = GetBitsetBlockSSE2(srcs.data() + i); + std::cout << "cost: " + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << "us" << std::endl; + start = std::chrono::steady_clock::now(); + for (int i = 0; i < 10000000; ++i) + auto result = GetBitsetBlockAVX2(srcs.data() + i); + std::cout << "cost: " + << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << "us" << std::endl; +} + TEST(GetBitSetBlock, base_test_avx2) { FixedVector src; for (int i = 0; i < 64; ++i) { @@ -1214,10 +1239,298 @@ TEST(AllBooleanNeon, performance) { } } +TEST(EqualVal, perf_int8) { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector srcs(1000000); + for (int i = 0; i < 1000000; ++i) { + srcs[i] = i % 128; + } + FixedVector res(1000000); + auto start = std::chrono::steady_clock::now(); + EqualValRef(srcs.data(), 1000000, (int8_t)10, res.data()); + std::cout << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + start = std::chrono::steady_clock::now(); + EqualValAVX512(srcs.data(), 1000000, (int8_t)10, res.data()); + std::cout << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; +} + +template +void +TestCompareValAVX512Perf() { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector srcs(1000000); + for (int i = 0; i < 1000000; ++i) { + srcs[i] = i; + } + FixedVector res(1000000); + T target = 10; + auto start = std::chrono::steady_clock::now(); + EqualValRef(srcs.data(), 1000000, target, res.data()); + std::cout << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + start = std::chrono::steady_clock::now(); + EqualValAVX512(srcs.data(), 1000000, target, res.data()); + std::cout << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; +} + +TEST(EqualVal, perf_int16) { + TestCompareValAVX512Perf(); +} + +TEST(EqualVal, pref_int32) { + TestCompareValAVX512Perf(); +} + +TEST(EqualVal, perf_int64) { + TestCompareValAVX512Perf(); +} + +TEST(EqualVal, perf_float) { + TestCompareValAVX512Perf(); +} + +TEST(EqualVal, perf_double) { + TestCompareValAVX512Perf(); +} + +template +void +TestCompareValAVX512(int size, T target) { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector vecs; + for (int i = 0; i < size; ++i) { + if constexpr (std::is_same_v) { + vecs.push_back(i % 127); + } else if constexpr (std::is_floating_point_v) { + vecs.push_back(i + 0.01); + } else { + vecs.push_back(i); + } + } + FixedVector res(size); + + EqualValAVX512(vecs.data(), size, target, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], vecs[i] == target) << i; + } + LessValAVX512(vecs.data(), size, target, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], vecs[i] < target) << i; + } + LessEqualValAVX512(vecs.data(), size, target, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], vecs[i] <= target) << i; + } + GreaterEqualValAVX512(vecs.data(), size, target, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], vecs[i] >= target) << i; + } + GreaterValAVX512(vecs.data(), size, target, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], vecs[i] > target) << i; + } + NotEqualValAVX512(vecs.data(), size, target, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], vecs[i] != target) << i; + } +} + +TEST(CompareVal, avx512_int8) { + TestCompareValAVX512(1000, 9); + TestCompareValAVX512(1000, 99); + TestCompareValAVX512(1001, 127); +} + +TEST(CompareVal, avx512_int16) { + TestCompareValAVX512(1000, 99); + TestCompareValAVX512(1000, 999); + TestCompareValAVX512(1001, 1000); +} + +TEST(CompareVal, avx512_int32) { + TestCompareValAVX512(1000, 99); + TestCompareValAVX512(1000, 999); + TestCompareValAVX512(1001, 1000); +} + +TEST(CompareVal, avx512_int64) { + TestCompareValAVX512(1000, 99); + TestCompareValAVX512(1000, 999); + TestCompareValAVX512(1001, 1000); +} + +TEST(CompareVal, avx512_float) { + TestCompareValAVX512(1000, 99.01); + TestCompareValAVX512(1000, 999.01); + TestCompareValAVX512(1001, 1000.01); +} + +TEST(CompareVal, avx512_double) { + TestCompareValAVX512(1000, 99.01); + TestCompareValAVX512(1000, 999.01); + TestCompareValAVX512(1001, 1000.01); +} + +template +void +TestCompareColumnAVX512Perf() { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::vector lefts(1000000); + for (int i = 0; i < 1000000; ++i) { + lefts[i] = i; + } + std::vector rights(1000000); + for (int i = 0; i < 1000000; ++i) { + rights[i] = i; + } + FixedVector res(1000000); + auto start = std::chrono::steady_clock::now(); + LessColumnRef(lefts.data(), rights.data(), 1000000, res.data()); + std::cout << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; + start = std::chrono::steady_clock::now(); + LessColumnAVX512(lefts.data(), rights.data(), 1000000, res.data()); + std::cout << std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count() + << std::endl; +} + +TEST(LessColumn, pref_int32) { + TestCompareColumnAVX512Perf(); +} + +TEST(LessColumn, perf_int64) { + TestCompareColumnAVX512Perf(); +} + +TEST(LessColumn, perf_float) { + TestCompareColumnAVX512Perf(); +} + +TEST(LessColumn, perf_double) { + TestCompareColumnAVX512Perf(); +} + +template +void +TestCompareColumnAVX512(int size, T min_val, T max_val) { + if (!cpu_support_avx512()) { + PRINT_SKPI_TEST + return; + } + std::random_device rd; + std::mt19937 gen(rd()); + + std::vector left; + std::vector right; + if constexpr (std::is_same_v) { + std::uniform_real_distribution dis(min_val, max_val); + for (int i = 0; i < size; ++i) { + left.push_back(dis(gen)); + right.push_back(dis(gen)); + } + } else if constexpr (std::is_same_v) { + std::uniform_real_distribution dis(min_val, max_val); + for (int i = 0; i < size; ++i) { + left.push_back(dis(gen)); + right.push_back(dis(gen)); + } + } else { + std::uniform_int_distribution<> dis(min_val, max_val); + for (int i = 0; i < size; ++i) { + left.push_back(dis(gen)); + right.push_back(dis(gen)); + } + } + + FixedVector res(size); + + EqualColumnAVX512(left.data(), right.data(), size, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], left[i] == right[i]) << i; + } + LessColumnAVX512(left.data(), right.data(), size, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], left[i] < right[i]) << i; + } + GreaterColumnAVX512(left.data(), right.data(), size, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], left[i] > right[i]) << i; + } + LessEqualColumnAVX512(left.data(), right.data(), size, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], left[i] <= right[i]) << i; + } + GreaterEqualColumnAVX512(left.data(), right.data(), size, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], left[i] >= right[i]) << i; + } + NotEqualColumnAVX512(left.data(), right.data(), size, res.data()); + for (int i = 0; i < size; i++) { + ASSERT_EQ(res[i], left[i] != right[i]) << i; + } +} + +TEST(CompareColumn, avx512_int8) { + TestCompareColumnAVX512(1000, -128, 127); + TestCompareColumnAVX512(1001, -128, 127); +} + +TEST(CompareColumn, avx512_int16) { + TestCompareColumnAVX512(1000, -1000, 1000); + TestCompareColumnAVX512(1001, -1000, 1000); +} + +TEST(CompareColumn, avx512_int32) { + TestCompareColumnAVX512(1000, -1000, 1000); + TestCompareColumnAVX512(1001, -1000, 1000); +} + +TEST(CompareColumn, avx512_int64) { + TestCompareColumnAVX512(1000, -1000, 1000); + TestCompareColumnAVX512(1001, -1000, 1000); +} + +TEST(CompareColumn, avx512_float) { + TestCompareColumnAVX512(1000, -1.0, 1.0); + TestCompareColumnAVX512(1001, -1.0, 1.0); +} + +TEST(CompareColumn, avx512_double) { + TestCompareColumnAVX512(1000, -1.0, 1.0); + TestCompareColumnAVX512(1001, -1.0, 1.0); +} + #endif int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} \ No newline at end of file +}