From 7f6b7998db48c1b65377f90fd774b0955a7fe22a Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Thu, 30 Dec 2021 12:09:46 +0800 Subject: [PATCH] Let FAISS support all CPU SIMD (#14587) Signed-off-by: yudong.cai --- .../knowhere/archive/KnowhereConfig.cpp | 9 ++--- .../src/index/thirdparty/faiss/FaissHook.cpp | 36 ++++++++++--------- .../src/index/thirdparty/faiss/FaissHook.h | 8 ++--- .../thirdparty/faiss/utils/BinaryDistance.cpp | 8 ++--- 4 files changed, 30 insertions(+), 31 deletions(-) diff --git a/internal/core/src/index/knowhere/knowhere/archive/KnowhereConfig.cpp b/internal/core/src/index/knowhere/knowhere/archive/KnowhereConfig.cpp index 1fbbacb1a8..a696e2ce1c 100644 --- a/internal/core/src/index/knowhere/knowhere/archive/KnowhereConfig.cpp +++ b/internal/core/src/index/knowhere/knowhere/archive/KnowhereConfig.cpp @@ -55,12 +55,9 @@ KnowhereConfig::SetSimdType(const SimdType simd_type) { } std::string cpu_flag; - if (faiss::hook_init(cpu_flag)) { - LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag; - return cpu_flag; - } - - KNOWHERE_THROW_MSG("FAISS hook fail, CPU not supported!"); + faiss::hook_init(cpu_flag); + LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag; + return cpu_flag; } void diff --git a/internal/core/src/index/thirdparty/faiss/FaissHook.cpp b/internal/core/src/index/thirdparty/faiss/FaissHook.cpp index 1f68ffb805..363559269d 100644 --- a/internal/core/src/index/thirdparty/faiss/FaissHook.cpp +++ b/internal/core/src/index/thirdparty/faiss/FaissHook.cpp @@ -32,34 +32,28 @@ sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner = sq_select_inverted_li /*****************************************************************************/ -bool support_avx512() { - if (!faiss_use_avx512) return false; - +bool cpu_support_avx512() { InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); return (instruction_set_inst.AVX512F() && instruction_set_inst.AVX512DQ() && instruction_set_inst.AVX512BW()); } -bool support_avx2() { - if (!faiss_use_avx2) return false; - +bool cpu_support_avx2() { InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); return (instruction_set_inst.AVX2()); } -bool support_sse4_2() { - if (!faiss_use_sse4_2) return false; - +bool cpu_support_sse4_2() { InstructionSet& instruction_set_inst = InstructionSet::GetInstance(); return (instruction_set_inst.SSE42()); } -bool hook_init(std::string& cpu_flag) { +void hook_init(std::string& cpu_flag) { static std::mutex hook_mutex; std::lock_guard lock(hook_mutex); - if (support_avx512()) { + if (faiss_use_avx512 && cpu_support_avx512()) { /* for IVFFLAT */ fvec_inner_product = fvec_inner_product_avx512; fvec_L2sqr = fvec_L2sqr_avx512; @@ -72,7 +66,7 @@ bool hook_init(std::string& cpu_flag) { sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_avx512; cpu_flag = "AVX512"; - } else if (support_avx2()) { + } else if (faiss_use_avx2 && cpu_support_avx2()) { /* for IVFFLAT */ fvec_inner_product = fvec_inner_product_avx; fvec_L2sqr = fvec_L2sqr_avx; @@ -85,7 +79,7 @@ bool hook_init(std::string& cpu_flag) { sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_avx; cpu_flag = "AVX2"; - } else if (support_sse4_2()) { + } else if (faiss_use_sse4_2 && cpu_support_sse4_2()) { /* for IVFFLAT */ fvec_inner_product = fvec_inner_product_sse; fvec_L2sqr = fvec_L2sqr_sse; @@ -99,11 +93,19 @@ bool hook_init(std::string& cpu_flag) { cpu_flag = "SSE4_2"; } else { - cpu_flag = "UNSUPPORTED"; - return false; - } + /* for IVFFLAT */ + fvec_inner_product = fvec_inner_product_ref; + fvec_L2sqr = fvec_L2sqr_ref; + fvec_L1 = fvec_L1_ref; + fvec_Linf = fvec_Linf_ref; - return true; + /* for IVFSQ */ + sq_get_distance_computer = sq_get_distance_computer_ref; + sq_sel_quantizer = sq_select_quantizer_ref; + sq_sel_inv_list_scanner = sq_select_inverted_list_scanner_ref; + + cpu_flag = "REF"; + } } } // namespace faiss diff --git a/internal/core/src/index/thirdparty/faiss/FaissHook.h b/internal/core/src/index/thirdparty/faiss/FaissHook.h index 4f1513e99b..940923d28c 100644 --- a/internal/core/src/index/thirdparty/faiss/FaissHook.h +++ b/internal/core/src/index/thirdparty/faiss/FaissHook.h @@ -31,10 +31,10 @@ extern sq_get_distance_computer_func_ptr sq_get_distance_computer; extern sq_sel_quantizer_func_ptr sq_sel_quantizer; extern sq_sel_inv_list_scanner_func_ptr sq_sel_inv_list_scanner; -bool support_avx512(); -bool support_avx2(); -bool support_sse4_2(); +bool cpu_support_avx512(); +bool cpu_support_avx2(); +bool cpu_support_sse4_2(); -bool hook_init(std::string& cpu_flag); +void hook_init(std::string& cpu_flag); } // namespace faiss diff --git a/internal/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp b/internal/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp index 4269ba9426..e6f69dbd43 100644 --- a/internal/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp +++ b/internal/core/src/index/thirdparty/faiss/utils/BinaryDistance.cpp @@ -421,7 +421,7 @@ void binary_distance_knn_hc ( { switch (metric_type) { case METRIC_Jaccard: { - if (support_avx2() && ncodes > 64) { + if (cpu_support_avx2() && ncodes > 64) { binary_distance_knn_hc (ncodes, ha, a, b, nb, bitset); } else { @@ -449,7 +449,7 @@ void binary_distance_knn_hc ( } case METRIC_Hamming: { - if (support_avx2() && ncodes > 64) { + if (cpu_support_avx2() && ncodes > 64) { binary_distance_knn_hc (ncodes, ha, a, b, nb, bitset); } else { @@ -554,7 +554,7 @@ void binary_range_search( case METRIC_Tanimoto: radius = Tanimoto_2_Jaccard(radius); case METRIC_Jaccard: { - if (support_avx2() && ncodes > 64) { + if (cpu_support_avx2() && ncodes > 64) { binary_range_search (a, b, na, nb, ncodes, radius, result, buffer_size, bitset); } else { @@ -592,7 +592,7 @@ void binary_range_search( } case METRIC_Hamming: { - if (support_avx2() && ncodes > 64) { + if (cpu_support_avx2() && ncodes > 64) { binary_range_search (a, b, na, nb, ncodes, radius, result, buffer_size, bitset); } else {