diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 446246ce9d..cb53937c85 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -12,105 +12,26 @@ #include #include -#include -#include - #include "SearchBruteForce.h" -#include "SubSearchResult.h" -#include "common/Types.h" -#include "segcore/Utils.h" +#include "knowhere/archive/BruteForce.h" namespace milvus::query { -// copy from faiss/IndexBinaryFlat.cpp::IndexBinaryFlat::search() -// disable lint to make further migration easier -static void -binary_search(const knowhere::MetricType& metric_type, - const uint8_t* xb, - int64_t ntotal, - int code_size, - idx_t n, // num_queries - const uint8_t* x, - idx_t k, // topk - float* D, - idx_t* labels, - const BitsetView bitset) { - using namespace faiss; // NOLINT - if (metric_type == knowhere::metric::JACCARD || metric_type == knowhere::metric::TANIMOTO) { - float_maxheap_array_t res = {size_t(n), size_t(k), labels, D}; - binary_distance_knn_hc(METRIC_Jaccard, &res, x, xb, ntotal, code_size, bitset); - - if (metric_type == knowhere::metric::TANIMOTO) { - for (int i = 0; i < k * n; i++) { - D[i] = Jaccard_2_Tanimoto(D[i]); - } - } - } else if (metric_type == knowhere::metric::HAMMING) { - std::vector int_distances(n * k); - int_maxheap_array_t res = {size_t(n), size_t(k), labels, int_distances.data()}; - binary_distance_knn_hc(METRIC_Hamming, &res, x, xb, ntotal, code_size, bitset); - for (int i = 0; i < n * k; ++i) { - D[i] = int_distances[i]; - } - } else if (metric_type == knowhere::metric::SUBSTRUCTURE || metric_type == knowhere::metric::SUPERSTRUCTURE) { - // only matched ids will be chosen, not to use heap - auto faiss_metric_type = knowhere::GetFaissMetricType(metric_type); - binary_distance_knn_mc(faiss_metric_type, x, xb, n, ntotal, k, code_size, D, labels, bitset); - } else { - std::string msg = "binary search not support metric type: " + metric_type; - PanicInfo(msg); - } -} - SubSearchResult -BinarySearchBruteForce(const dataset::SearchDataset& dataset, - const void* chunk_data_raw, - int64_t size_per_chunk, - const BitsetView& bitset) { - // TODO: refactor the internal function - auto metric_type = dataset.metric_type; - auto num_queries = dataset.num_queries; - auto topk = dataset.topk; - auto dim = dataset.dim; - auto round_decimal = dataset.round_decimal; - SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal); - auto query_data = reinterpret_cast(dataset.query_data); - auto chunk_data = reinterpret_cast(chunk_data_raw); - - int64_t code_size = dim / 8; - binary_search(metric_type, chunk_data, size_per_chunk, code_size, num_queries, query_data, topk, - sub_result.get_distances(), sub_result.get_seg_offsets(), bitset); +BruteForceSearch(const dataset::SearchDataset& dataset, + const void* chunk_data_raw, + int64_t chunk_rows, + const BitsetView& bitset) { + SubSearchResult sub_result(dataset.num_queries, dataset.topk, dataset.metric_type, dataset.round_decimal); + try { + knowhere::BruteForceSearch(dataset.metric_type, chunk_data_raw, dataset.query_data, dataset.dim, chunk_rows, + dataset.num_queries, dataset.topk, sub_result.get_seg_offsets(), + sub_result.get_distances(), bitset); + } catch (std::exception& e) { + PanicInfo(e.what()); + } sub_result.round_values(); return sub_result; } -SubSearchResult -FloatSearchBruteForce(const dataset::SearchDataset& dataset, - const void* chunk_data_raw, - int64_t size_per_chunk, - const BitsetView& bitset) { - auto metric_type = dataset.metric_type; - auto num_queries = dataset.num_queries; - auto topk = dataset.topk; - auto dim = dataset.dim; - auto round_decimal = dataset.round_decimal; - SubSearchResult sub_qr(num_queries, topk, metric_type, round_decimal); - auto query_data = reinterpret_cast(dataset.query_data); - auto chunk_data = reinterpret_cast(chunk_data_raw); - if (metric_type == knowhere::metric::L2) { - faiss::float_maxheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(), - sub_qr.get_distances()}; - faiss::knn_L2sqr(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, nullptr, bitset); - } else if (metric_type == knowhere::metric::IP) { - faiss::float_minheap_array_t buf{(size_t)num_queries, (size_t)topk, sub_qr.get_seg_offsets(), - sub_qr.get_distances()}; - faiss::knn_inner_product(query_data, chunk_data, dim, num_queries, size_per_chunk, &buf, bitset); - } else { - std::string msg = "search not support metric type: " + metric_type; - PanicInfo(msg); - } - sub_qr.round_values(); - return sub_qr; -} - } // namespace milvus::query diff --git a/internal/core/src/query/SearchBruteForce.h b/internal/core/src/query/SearchBruteForce.h index e2e31dd1ce..d4cbaaccfc 100644 --- a/internal/core/src/query/SearchBruteForce.h +++ b/internal/core/src/query/SearchBruteForce.h @@ -11,24 +11,16 @@ #pragma once -#include "common/Schema.h" #include "common/BitsetView.h" #include "query/SubSearchResult.h" #include "query/helper.h" -#include "segcore/ConcurrentVector.h" namespace milvus::query { SubSearchResult -BinarySearchBruteForce(const dataset::SearchDataset& dataset, - const void* chunk_data_raw, - int64_t size_per_chunk, - const BitsetView& bitset); - -SubSearchResult -FloatSearchBruteForce(const dataset::SearchDataset& dataset, - const void* chunk_data_raw, - int64_t size_per_chunk, - const BitsetView& bitset); +BruteForceSearch(const dataset::SearchDataset& dataset, + const void* chunk_data_raw, + int64_t chunk_rows, + const BitsetView& bitset); } // namespace milvus::query diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index b8530f6e7d..79f425b2d0 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -89,7 +89,7 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment, auto size_per_chunk = element_end - element_begin; auto sub_view = bitset.subview(element_begin, size_per_chunk); - auto sub_qr = FloatSearchBruteForce(search_dataset, chunk.data(), size_per_chunk, sub_view); + auto sub_qr = BruteForceSearch(search_dataset, chunk.data(), size_per_chunk, sub_view); // convert chunk uid to segment uid for (auto& x : sub_qr.mutable_seg_offsets()) { @@ -150,7 +150,7 @@ BinarySearch(const segcore::SegmentGrowingImpl& segment, auto nsize = element_end - element_begin; auto sub_view = bitset.subview(element_begin, nsize); - auto sub_result = BinarySearchBruteForce(search_dataset, chunk.data(), nsize, sub_view); + auto sub_result = BruteForceSearch(search_dataset, chunk.data(), nsize, sub_view); // convert chunk uid to segment uid for (auto& x : sub_result.mutable_seg_offsets()) { diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 4d94330e25..2d09cb8c34 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -381,14 +381,7 @@ SegmentSealedImpl::vector_search(int64_t vec_count, auto vec_data = insert_record_.get_field_data_base(field_id); AssertInfo(vec_data->num_chunk() == 1, "num chunk not equal to 1 for sealed segment"); auto chunk_data = vec_data->get_chunk_data(0); - - auto sub_qr = [&] { - if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) { - return query::FloatSearchBruteForce(dataset, chunk_data, row_count, bitset); - } else { - return query::BinarySearchBruteForce(dataset, chunk_data, row_count, bitset); - } - }(); + auto sub_qr = query::BruteForceSearch(dataset, chunk_data, row_count, bitset); SearchResult results; results.distances_ = std::move(sub_qr.mutable_distances()); diff --git a/internal/core/thirdparty/knowhere/CMakeLists.txt b/internal/core/thirdparty/knowhere/CMakeLists.txt index 5d47257992..8f49df5ccc 100644 --- a/internal/core/thirdparty/knowhere/CMakeLists.txt +++ b/internal/core/thirdparty/knowhere/CMakeLists.txt @@ -11,8 +11,8 @@ # or implied. See the License for the specific language governing permissions and limitations under the License. #------------------------------------------------------------------------------- -set( KNOWHERE_VERSION v1.1.13 ) -set( KNOWHERE_SOURCE_MD5 "5ea7ce8ae71b4aa496ee3c66ccf56d5a") +set( KNOWHERE_VERSION v1.1.14 ) +set( KNOWHERE_SOURCE_MD5 "de9303c3f667662aa92f3676a1f6ef96") if ( DEFINED ENV{MILVUS_KNOWHERE_URL} ) set( KNOWHERE_SOURCE_URL "$ENV{MILVUS_KNOWHERE_URL}" ) diff --git a/internal/core/unittest/test_bf.cpp b/internal/core/unittest/test_bf.cpp index d9a5c419b3..181fd00fdc 100644 --- a/internal/core/unittest/test_bf.cpp +++ b/internal/core/unittest/test_bf.cpp @@ -114,10 +114,10 @@ class TestFloatSearchBruteForce : public ::testing::Test { dataset::SearchDataset dataset{metric_type, nq, topk, -1, dim, query.data()}; if (!is_supported_float_metric(metric_type)) { - ASSERT_ANY_THROW(FloatSearchBruteForce(dataset, base.data(), nb, bitset_view)); + ASSERT_ANY_THROW(BruteForceSearch(dataset, base.data(), nb, bitset_view)); return; } - auto result = FloatSearchBruteForce(dataset, base.data(), nb, bitset_view); + auto result = BruteForceSearch(dataset, base.data(), nb, bitset_view); for (int i = 0; i < nq; i++) { auto ref = Ref(base.data(), query.data() + i * dim, nb, dim, topk, metric_type); auto ans = result.get_seg_offsets() + i * topk; diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 370bcf39d8..1e73c392a5 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -324,7 +324,7 @@ TEST(Indexing, BinaryBruteForce) { query_data // }; - auto sub_result = query::BinarySearchBruteForce(search_dataset, bin_vec.data(), N, nullptr); + auto sub_result = query::BruteForceSearch(search_dataset, bin_vec.data(), N, nullptr); SearchResult sr; sr.total_nq_ = num_queries; diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index d02ebce3a5..a398f1dd78 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -538,7 +538,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) { dim, // query_ptr // }; - auto sub_result = FloatSearchBruteForce(search_dataset, vec_col.data(), N, nullptr); + auto sub_result = BruteForceSearch(search_dataset, vec_col.data(), N, nullptr); auto sr = segment->Search(plan.get(), ph_group.get(), time); segment->FillPrimaryKeys(plan.get(), *sr);