diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 2725fcf98a..eca42e9a6f 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -42,7 +42,8 @@ CheckBruteForceSearchParam(const FieldMeta& field, } knowhere::Json -PrepareBFSearchParams(const SearchInfo& search_info) { +PrepareBFSearchParams(const SearchInfo& search_info, + const std::map& index_info) { knowhere::Json search_cfg = search_info.search_params_; search_cfg[knowhere::meta::METRIC_TYPE] = search_info.metric_type_; @@ -62,6 +63,10 @@ PrepareBFSearchParams(const SearchInfo& search_info) { if (search_info.metric_type_ == knowhere::metric::BM25) { search_cfg[knowhere::meta::BM25_AVGDL] = search_info.search_params_[knowhere::meta::BM25_AVGDL]; + search_cfg[knowhere::meta::BM25_K1] = + std::stof(index_info.at(knowhere::meta::BM25_K1)); + search_cfg[knowhere::meta::BM25_B] = + std::stof(index_info.at(knowhere::meta::BM25_B)); } return search_cfg; } @@ -71,6 +76,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset, const void* chunk_data_raw, int64_t chunk_rows, const SearchInfo& search_info, + const std::map& index_info, const BitsetView& bitset, DataType data_type) { SubSearchResult sub_result(dataset.num_queries, @@ -87,12 +93,11 @@ BruteForceSearch(const dataset::SearchDataset& dataset, base_dataset->SetIsSparse(true); query_dataset->SetIsSparse(true); } - auto search_cfg = PrepareBFSearchParams(search_info); + auto search_cfg = PrepareBFSearchParams(search_info, index_info); // `range_search_k` is only used as one of the conditions for iterator early termination. // not gurantee to return exactly `range_search_k` results, which may be more or less. // set it to -1 will return all results in the range. search_cfg[knowhere::meta::RANGE_SEARCH_K] = topk; - sub_result.mutable_seg_offsets().resize(nq * topk); sub_result.mutable_distances().resize(nq * topk); @@ -201,6 +206,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset, const void* chunk_data_raw, int64_t chunk_rows, const SearchInfo& search_info, + const std::map& index_info, const BitsetView& bitset, DataType data_type) { auto nq = dataset.num_queries; @@ -211,7 +217,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset, base_dataset->SetIsSparse(true); query_dataset->SetIsSparse(true); } - auto search_cfg = PrepareBFSearchParams(search_info); + auto search_cfg = PrepareBFSearchParams(search_info, index_info); knowhere::expected> iterators_val; diff --git a/internal/core/src/query/SearchBruteForce.h b/internal/core/src/query/SearchBruteForce.h index b7cad461b1..3cf6863b91 100644 --- a/internal/core/src/query/SearchBruteForce.h +++ b/internal/core/src/query/SearchBruteForce.h @@ -28,6 +28,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset, const void* chunk_data_raw, int64_t chunk_rows, const SearchInfo& search_info, + const std::map& index_info, const BitsetView& bitset, DataType data_type); @@ -36,6 +37,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset, const void* chunk_data_raw, int64_t chunk_rows, const SearchInfo& search_info, + const std::map& index_info, const BitsetView& bitset, DataType data_type); diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index f71efb7562..7e6606261f 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -14,6 +14,10 @@ #include "common/Tracer.h" #include "common/Types.h" #include "SearchOnGrowing.h" +#include +#include "knowhere/comp/index_param.h" +#include "knowhere/config.h" +#include "log/Log.h" #include "query/SearchBruteForce.h" #include "query/SearchOnIndex.h" @@ -109,6 +113,15 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, dataset::SearchDataset search_dataset{ metric_type, num_queries, topk, round_decimal, dim, query_data}; int32_t current_chunk_id = 0; + + // get K1 and B from index for bm25 brute force + std::map index_info; + if (metric_type == knowhere::metric::BM25) { + index_info = segment.get_indexing_record() + .get_field_index_meta(vecfield_id) + .GetIndexParams(); + } + // step 3: brute force search where small indexing is unavailable auto vec_ptr = record.get_data_base(vecfield_id); auto vec_size_per_chunk = vec_ptr->get_size_per_chunk(); @@ -129,6 +142,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, chunk_data, size_per_chunk, info, + index_info, sub_view, data_type); final_qr.merge(sub_qr); @@ -137,6 +151,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, chunk_data, size_per_chunk, info, + index_info, sub_view, data_type); diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index 2bd7e8edb8..2a0dc5f7b0 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -81,6 +81,7 @@ void SearchOnSealed(const Schema& schema, std::shared_ptr column, const SearchInfo& search_info, + const std::map& index_info, const void* query_data, int64_t num_queries, int64_t row_count, @@ -137,6 +138,7 @@ SearchOnSealed(const Schema& schema, vec_data, chunk_size, search_info, + index_info, bitset_view, data_type); final_qr.merge(sub_qr); @@ -145,6 +147,7 @@ SearchOnSealed(const Schema& schema, vec_data, chunk_size, search_info, + index_info, bitset_view, data_type); for (auto& o : sub_qr.mutable_seg_offsets()) { @@ -177,6 +180,7 @@ void SearchOnSealed(const Schema& schema, const void* vec_data, const SearchInfo& search_info, + const std::map& index_info, const void* query_data, int64_t num_queries, int64_t row_count, @@ -200,13 +204,23 @@ SearchOnSealed(const Schema& schema, auto data_type = field.get_data_type(); CheckBruteForceSearchParam(field, search_info); if (search_info.group_by_field_id_.has_value()) { - auto sub_qr = BruteForceSearchIterators( - dataset, vec_data, row_count, search_info, bitset, data_type); + auto sub_qr = BruteForceSearchIterators(dataset, + vec_data, + row_count, + search_info, + index_info, + bitset, + data_type); result.AssembleChunkVectorIterators( num_queries, 1, {0}, sub_qr.chunk_iterators()); } else { - auto sub_qr = BruteForceSearch( - dataset, vec_data, row_count, search_info, bitset, data_type); + auto sub_qr = BruteForceSearch(dataset, + vec_data, + row_count, + search_info, + index_info, + bitset, + data_type); result.distances_ = std::move(sub_qr.mutable_distances()); result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets()); } diff --git a/internal/core/src/query/SearchOnSealed.h b/internal/core/src/query/SearchOnSealed.h index a9261c793f..b3254b1c14 100644 --- a/internal/core/src/query/SearchOnSealed.h +++ b/internal/core/src/query/SearchOnSealed.h @@ -31,6 +31,7 @@ void SearchOnSealed(const Schema& schema, std::shared_ptr column, const SearchInfo& search_info, + const std::map& index_info, const void* query_data, int64_t num_queries, int64_t row_count, @@ -41,6 +42,7 @@ void SearchOnSealed(const Schema& schema, const void* vec_data, const SearchInfo& search_info, + const std::map& index_info, const void* query_data, int64_t num_queries, int64_t row_count, diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index 3a6f660844..817f98caf3 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -940,9 +940,18 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, AssertInfo(num_rows_.has_value(), "Can't get row count value"); auto row_count = num_rows_.value(); auto vec_data = fields_.at(field_id); + + // get index params for bm25 brute force + std::map index_info; + if (search_info.metric_type_ == knowhere::metric::BM25) { + auto index_info = + col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams(); + } + query::SearchOnSealed(*schema_, vec_data, search_info, + index_info, query_data, query_count, row_count, diff --git a/internal/core/src/segcore/FieldIndexing.h b/internal/core/src/segcore/FieldIndexing.h index 339ed5612b..8752e63e2c 100644 --- a/internal/core/src/segcore/FieldIndexing.h +++ b/internal/core/src/segcore/FieldIndexing.h @@ -22,9 +22,11 @@ #include "AckResponder.h" #include "InsertRecord.h" +#include "common/FieldMeta.h" #include "common/Schema.h" #include "common/IndexMeta.h" #include "IndexConfigGenerator.h" +#include "knowhere/config.h" #include "log/Log.h" #include "segcore/SegcoreConfig.h" #include "index/VectorIndex.h" @@ -429,6 +431,11 @@ class IndexingRecord { return *ptr; } + const FieldIndexMeta& + get_field_index_meta(FieldId fieldId) const { + return index_meta_->GetFieldIndexMeta(fieldId); + } + bool is_in(FieldId field_id) const { return field_indexings_.count(field_id); diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index ad5c77e26a..bfd847df1f 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -985,9 +985,18 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info, AssertInfo(num_rows_.has_value(), "Can't get row count value"); auto row_count = num_rows_.value(); auto vec_data = fields_.at(field_id); + + // get index params for bm25 brute force + std::map index_info; + if (search_info.metric_type_ == knowhere::metric::BM25) { + auto index_info = + col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams(); + } + query::SearchOnSealed(*schema_, vec_data->Data(), search_info, + index_info, query_data, query_count, row_count, diff --git a/internal/core/unittest/test_bf.cpp b/internal/core/unittest/test_bf.cpp index 94db431d53..75c4566145 100644 --- a/internal/core/unittest/test_bf.cpp +++ b/internal/core/unittest/test_bf.cpp @@ -122,6 +122,7 @@ class TestFloatSearchBruteForce : public ::testing::Test { auto base = GenFloatVecs(dim, nb, metric_type); auto query = GenFloatVecs(dim, nq, metric_type); + auto index_info = std::map{}; dataset::SearchDataset dataset{ metric_type, nq, topk, -1, dim, query.data()}; @@ -137,6 +138,7 @@ class TestFloatSearchBruteForce : public ::testing::Test { base.data(), nb, search_info, + index_info, bitset_view, DataType::VECTOR_FLOAT); for (int i = 0; i < nq; i++) { diff --git a/internal/core/unittest/test_bf_sparse.cpp b/internal/core/unittest/test_bf_sparse.cpp index 7c9e466208..0c5ce6d1a6 100644 --- a/internal/core/unittest/test_bf_sparse.cpp +++ b/internal/core/unittest/test_bf_sparse.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include +#include #include #include "common/Utils.h" @@ -98,6 +99,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { auto base = milvus::segcore::GenerateRandomSparseFloatVector(nb); auto query = milvus::segcore::GenerateRandomSparseFloatVector(nq); + auto index_info = std::map{}; SearchInfo search_info; search_info.topk_ = topk; search_info.metric_type_ = metric_type; @@ -108,6 +110,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { base.get(), nb, search_info, + index_info, bitset_view, DataType::VECTOR_SPARSE_FLOAT)); return; @@ -116,6 +119,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { base.get(), nb, search_info, + index_info, bitset_view, DataType::VECTOR_SPARSE_FLOAT); for (int i = 0; i < nq; i++) { @@ -130,6 +134,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { base.get(), nb, search_info, + index_info, bitset_view, DataType::VECTOR_SPARSE_FLOAT); for (int i = 0; i < nq; i++) { @@ -143,6 +148,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { base.get(), nb, search_info, + index_info, bitset_view, DataType::VECTOR_SPARSE_FLOAT); auto iterators = result3.chunk_iterators(); diff --git a/internal/core/unittest/test_chunked_segment.cpp b/internal/core/unittest/test_chunked_segment.cpp index 0995dfb842..e27deddf59 100644 --- a/internal/core/unittest/test_chunked_segment.cpp +++ b/internal/core/unittest/test_chunked_segment.cpp @@ -106,11 +106,13 @@ TEST(test_chunk_segment, TestSearchOnSealed) { auto query_ds = segcore::DataGen(schema, 1); auto col_query_data = query_ds.get_col(fakevec_id); auto query_data = col_query_data.data(); + auto index_info = std::map{}; SearchResult search_result; query::SearchOnSealed(*schema, column, search_info, + index_info, query_data, 1, total_row_count, @@ -135,6 +137,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) { query::SearchOnSealed(*schema, column, search_info, + index_info, query_data, 1, total_row_count, diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 6f228f7f58..32ead43755 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -173,6 +173,7 @@ TEST(Indexing, BinaryBruteForce) { }; SearchInfo search_info; + auto index_info = std::map{}; search_info.topk_ = topk; search_info.round_decimal_ = round_decimal; search_info.metric_type_ = metric_type; @@ -180,6 +181,7 @@ TEST(Indexing, BinaryBruteForce) { bin_vec.data(), N, search_info, + index_info, nullptr, DataType::VECTOR_BINARY); diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index cb4ccf4131..b8e1e4c090 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -1236,6 +1236,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) { CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + auto index_info = std::map{}; std::vector ph_group_arr = {ph_group.get()}; @@ -1257,6 +1258,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) { vec_col.data(), N, search_info, + index_info, nullptr, DataType::VECTOR_FLOAT);