fix: bm25 brute force search need index params k1 and b (#37721)

relate: https://github.com/milvus-io/milvus/issues/35853

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
aoiasd 2024-11-18 15:44:31 +08:00 committed by GitHub
parent 12ed40e125
commit e9391acf80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 87 additions and 8 deletions

View File

@ -42,7 +42,8 @@ CheckBruteForceSearchParam(const FieldMeta& field,
}
knowhere::Json
PrepareBFSearchParams(const SearchInfo& search_info) {
PrepareBFSearchParams(const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info) {
knowhere::Json search_cfg = search_info.search_params_;
search_cfg[knowhere::meta::METRIC_TYPE] = search_info.metric_type_;
@ -62,6 +63,10 @@ PrepareBFSearchParams(const SearchInfo& search_info) {
if (search_info.metric_type_ == knowhere::metric::BM25) {
search_cfg[knowhere::meta::BM25_AVGDL] =
search_info.search_params_[knowhere::meta::BM25_AVGDL];
search_cfg[knowhere::meta::BM25_K1] =
std::stof(index_info.at(knowhere::meta::BM25_K1));
search_cfg[knowhere::meta::BM25_B] =
std::stof(index_info.at(knowhere::meta::BM25_B));
}
return search_cfg;
}
@ -71,6 +76,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
SubSearchResult sub_result(dataset.num_queries,
@ -87,12 +93,11 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto search_cfg = PrepareBFSearchParams(search_info);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
// `range_search_k` is only used as one of the conditions for iterator early termination.
// it is not guaranteed to return exactly `range_search_k` results; there may be more or fewer.
// set it to -1 will return all results in the range.
search_cfg[knowhere::meta::RANGE_SEARCH_K] = topk;
sub_result.mutable_seg_offsets().resize(nq * topk);
sub_result.mutable_distances().resize(nq * topk);
@ -201,6 +206,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type) {
auto nq = dataset.num_queries;
@ -211,7 +217,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
base_dataset->SetIsSparse(true);
query_dataset->SetIsSparse(true);
}
auto search_cfg = PrepareBFSearchParams(search_info);
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
iterators_val;

View File

@ -28,6 +28,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);
@ -36,6 +37,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
const void* chunk_data_raw,
int64_t chunk_rows,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const BitsetView& bitset,
DataType data_type);

View File

@ -14,6 +14,10 @@
#include "common/Tracer.h"
#include "common/Types.h"
#include "SearchOnGrowing.h"
#include <cstddef>
#include "knowhere/comp/index_param.h"
#include "knowhere/config.h"
#include "log/Log.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnIndex.h"
@ -109,6 +113,15 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
dataset::SearchDataset search_dataset{
metric_type, num_queries, topk, round_decimal, dim, query_data};
int32_t current_chunk_id = 0;
// get K1 and B from index for bm25 brute force
std::map<std::string, std::string> index_info;
if (metric_type == knowhere::metric::BM25) {
index_info = segment.get_indexing_record()
.get_field_index_meta(vecfield_id)
.GetIndexParams();
}
// step 3: brute force search where small indexing is unavailable
auto vec_ptr = record.get_data_base(vecfield_id);
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
@ -129,6 +142,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
chunk_data,
size_per_chunk,
info,
index_info,
sub_view,
data_type);
final_qr.merge(sub_qr);
@ -137,6 +151,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
chunk_data,
size_per_chunk,
info,
index_info,
sub_view,
data_type);

View File

@ -81,6 +81,7 @@ void
SearchOnSealed(const Schema& schema,
std::shared_ptr<ChunkedColumnBase> column,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
@ -137,6 +138,7 @@ SearchOnSealed(const Schema& schema,
vec_data,
chunk_size,
search_info,
index_info,
bitset_view,
data_type);
final_qr.merge(sub_qr);
@ -145,6 +147,7 @@ SearchOnSealed(const Schema& schema,
vec_data,
chunk_size,
search_info,
index_info,
bitset_view,
data_type);
for (auto& o : sub_qr.mutable_seg_offsets()) {
@ -177,6 +180,7 @@ void
SearchOnSealed(const Schema& schema,
const void* vec_data,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
@ -200,13 +204,23 @@ SearchOnSealed(const Schema& schema,
auto data_type = field.get_data_type();
CheckBruteForceSearchParam(field, search_info);
if (search_info.group_by_field_id_.has_value()) {
auto sub_qr = BruteForceSearchIterators(
dataset, vec_data, row_count, search_info, bitset, data_type);
auto sub_qr = BruteForceSearchIterators(dataset,
vec_data,
row_count,
search_info,
index_info,
bitset,
data_type);
result.AssembleChunkVectorIterators(
num_queries, 1, {0}, sub_qr.chunk_iterators());
} else {
auto sub_qr = BruteForceSearch(
dataset, vec_data, row_count, search_info, bitset, data_type);
auto sub_qr = BruteForceSearch(dataset,
vec_data,
row_count,
search_info,
index_info,
bitset,
data_type);
result.distances_ = std::move(sub_qr.mutable_distances());
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
}

View File

@ -31,6 +31,7 @@ void
SearchOnSealed(const Schema& schema,
std::shared_ptr<ChunkedColumnBase> column,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,
@ -41,6 +42,7 @@ void
SearchOnSealed(const Schema& schema,
const void* vec_data,
const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info,
const void* query_data,
int64_t num_queries,
int64_t row_count,

View File

@ -940,9 +940,18 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
AssertInfo(num_rows_.has_value(), "Can't get row count value");
auto row_count = num_rows_.value();
auto vec_data = fields_.at(field_id);
// Fetch the field's index params for BM25 brute force search; the
// brute-force path reads the BM25 k1 and b values out of this map.
// The result must be assigned to the outer `index_info`: declaring a new
// `auto index_info` inside the `if` would shadow it and leave the map that
// is passed to SearchOnSealed empty.
std::map<std::string, std::string> index_info;
if (search_info.metric_type_ == knowhere::metric::BM25) {
    index_info =
        col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams();
}
query::SearchOnSealed(*schema_,
vec_data,
search_info,
index_info,
query_data,
query_count,
row_count,

View File

@ -22,9 +22,11 @@
#include "AckResponder.h"
#include "InsertRecord.h"
#include "common/FieldMeta.h"
#include "common/Schema.h"
#include "common/IndexMeta.h"
#include "IndexConfigGenerator.h"
#include "knowhere/config.h"
#include "log/Log.h"
#include "segcore/SegcoreConfig.h"
#include "index/VectorIndex.h"
@ -429,6 +431,11 @@ class IndexingRecord {
return *ptr;
}
// Returns the index metadata for `fieldId` by forwarding the lookup to
// the index meta object held in `index_meta_`.
const FieldIndexMeta&
get_field_index_meta(FieldId fieldId) const {
    const auto& field_meta = index_meta_->GetFieldIndexMeta(fieldId);
    return field_meta;
}
bool
is_in(FieldId field_id) const {
return field_indexings_.count(field_id);

View File

@ -985,9 +985,18 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
AssertInfo(num_rows_.has_value(), "Can't get row count value");
auto row_count = num_rows_.value();
auto vec_data = fields_.at(field_id);
// Fetch the field's index params for BM25 brute force search; the
// brute-force path reads the BM25 k1 and b values out of this map.
// The result must be assigned to the outer `index_info`: declaring a new
// `auto index_info` inside the `if` would shadow it and leave the map that
// is passed to SearchOnSealed empty.
std::map<std::string, std::string> index_info;
if (search_info.metric_type_ == knowhere::metric::BM25) {
    index_info =
        col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams();
}
query::SearchOnSealed(*schema_,
vec_data->Data(),
search_info,
index_info,
query_data,
query_count,
row_count,

View File

@ -122,6 +122,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
auto base = GenFloatVecs(dim, nb, metric_type);
auto query = GenFloatVecs(dim, nq, metric_type);
auto index_info = std::map<std::string, std::string>{};
dataset::SearchDataset dataset{
metric_type, nq, topk, -1, dim, query.data()};
@ -137,6 +138,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
base.data(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_FLOAT);
for (int i = 0; i < nq; i++) {

View File

@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <gtest/gtest.h>
#include <map>
#include <random>
#include "common/Utils.h"
@ -98,6 +99,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
auto base = milvus::segcore::GenerateRandomSparseFloatVector(nb);
auto query = milvus::segcore::GenerateRandomSparseFloatVector(nq);
auto index_info = std::map<std::string, std::string>{};
SearchInfo search_info;
search_info.topk_ = topk;
search_info.metric_type_ = metric_type;
@ -108,6 +110,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT));
return;
@ -116,6 +119,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
for (int i = 0; i < nq; i++) {
@ -130,6 +134,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
for (int i = 0; i < nq; i++) {
@ -143,6 +148,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
base.get(),
nb,
search_info,
index_info,
bitset_view,
DataType::VECTOR_SPARSE_FLOAT);
auto iterators = result3.chunk_iterators();

View File

@ -106,11 +106,13 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
auto query_ds = segcore::DataGen(schema, 1);
auto col_query_data = query_ds.get_col<float>(fakevec_id);
auto query_data = col_query_data.data();
auto index_info = std::map<std::string, std::string>{};
SearchResult search_result;
query::SearchOnSealed(*schema,
column,
search_info,
index_info,
query_data,
1,
total_row_count,
@ -135,6 +137,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
query::SearchOnSealed(*schema,
column,
search_info,
index_info,
query_data,
1,
total_row_count,

View File

@ -173,6 +173,7 @@ TEST(Indexing, BinaryBruteForce) {
};
SearchInfo search_info;
auto index_info = std::map<std::string, std::string>{};
search_info.topk_ = topk;
search_info.round_decimal_ = round_decimal;
search_info.metric_type_ = metric_type;
@ -180,6 +181,7 @@ TEST(Indexing, BinaryBruteForce) {
bin_vec.data(),
N,
search_info,
index_info,
nullptr,
DataType::VECTOR_BINARY);

View File

@ -1236,6 +1236,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
auto index_info = std::map<std::string, std::string>{};
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
@ -1257,6 +1258,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
vec_col.data(),
N,
search_info,
index_info,
nullptr,
DataType::VECTOR_FLOAT);