mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
fix: bm25 brute force search need index params k1 and b (#37721)
relate: https://github.com/milvus-io/milvus/issues/35853 --------- Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
parent
12ed40e125
commit
e9391acf80
@ -42,7 +42,8 @@ CheckBruteForceSearchParam(const FieldMeta& field,
|
||||
}
|
||||
|
||||
knowhere::Json
|
||||
PrepareBFSearchParams(const SearchInfo& search_info) {
|
||||
PrepareBFSearchParams(const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info) {
|
||||
knowhere::Json search_cfg = search_info.search_params_;
|
||||
|
||||
search_cfg[knowhere::meta::METRIC_TYPE] = search_info.metric_type_;
|
||||
@ -62,6 +63,10 @@ PrepareBFSearchParams(const SearchInfo& search_info) {
|
||||
if (search_info.metric_type_ == knowhere::metric::BM25) {
|
||||
search_cfg[knowhere::meta::BM25_AVGDL] =
|
||||
search_info.search_params_[knowhere::meta::BM25_AVGDL];
|
||||
search_cfg[knowhere::meta::BM25_K1] =
|
||||
std::stof(index_info.at(knowhere::meta::BM25_K1));
|
||||
search_cfg[knowhere::meta::BM25_B] =
|
||||
std::stof(index_info.at(knowhere::meta::BM25_B));
|
||||
}
|
||||
return search_cfg;
|
||||
}
|
||||
@ -71,6 +76,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t chunk_rows,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type) {
|
||||
SubSearchResult sub_result(dataset.num_queries,
|
||||
@ -87,12 +93,11 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
base_dataset->SetIsSparse(true);
|
||||
query_dataset->SetIsSparse(true);
|
||||
}
|
||||
auto search_cfg = PrepareBFSearchParams(search_info);
|
||||
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
|
||||
// `range_search_k` is only used as one of the conditions for iterator early termination.
|
||||
// it is not guaranteed to return exactly `range_search_k` results; there may be more or fewer.
|
||||
// setting it to -1 returns all results in the range.
|
||||
search_cfg[knowhere::meta::RANGE_SEARCH_K] = topk;
|
||||
|
||||
sub_result.mutable_seg_offsets().resize(nq * topk);
|
||||
sub_result.mutable_distances().resize(nq * topk);
|
||||
|
||||
@ -201,6 +206,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t chunk_rows,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type) {
|
||||
auto nq = dataset.num_queries;
|
||||
@ -211,7 +217,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
|
||||
base_dataset->SetIsSparse(true);
|
||||
query_dataset->SetIsSparse(true);
|
||||
}
|
||||
auto search_cfg = PrepareBFSearchParams(search_info);
|
||||
auto search_cfg = PrepareBFSearchParams(search_info, index_info);
|
||||
|
||||
knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
|
||||
iterators_val;
|
||||
|
||||
@ -28,6 +28,7 @@ BruteForceSearch(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t chunk_rows,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type);
|
||||
|
||||
@ -36,6 +37,7 @@ BruteForceSearchIterators(const dataset::SearchDataset& dataset,
|
||||
const void* chunk_data_raw,
|
||||
int64_t chunk_rows,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const BitsetView& bitset,
|
||||
DataType data_type);
|
||||
|
||||
|
||||
@ -14,6 +14,10 @@
|
||||
#include "common/Tracer.h"
|
||||
#include "common/Types.h"
|
||||
#include "SearchOnGrowing.h"
|
||||
#include <cstddef>
|
||||
#include "knowhere/comp/index_param.h"
|
||||
#include "knowhere/config.h"
|
||||
#include "log/Log.h"
|
||||
#include "query/SearchBruteForce.h"
|
||||
#include "query/SearchOnIndex.h"
|
||||
|
||||
@ -109,6 +113,15 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
dataset::SearchDataset search_dataset{
|
||||
metric_type, num_queries, topk, round_decimal, dim, query_data};
|
||||
int32_t current_chunk_id = 0;
|
||||
|
||||
// get K1 and B from index for bm25 brute force
|
||||
std::map<std::string, std::string> index_info;
|
||||
if (metric_type == knowhere::metric::BM25) {
|
||||
index_info = segment.get_indexing_record()
|
||||
.get_field_index_meta(vecfield_id)
|
||||
.GetIndexParams();
|
||||
}
|
||||
|
||||
// step 3: brute force search where small indexing is unavailable
|
||||
auto vec_ptr = record.get_data_base(vecfield_id);
|
||||
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
|
||||
@ -129,6 +142,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
chunk_data,
|
||||
size_per_chunk,
|
||||
info,
|
||||
index_info,
|
||||
sub_view,
|
||||
data_type);
|
||||
final_qr.merge(sub_qr);
|
||||
@ -137,6 +151,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
chunk_data,
|
||||
size_per_chunk,
|
||||
info,
|
||||
index_info,
|
||||
sub_view,
|
||||
data_type);
|
||||
|
||||
|
||||
@ -81,6 +81,7 @@ void
|
||||
SearchOnSealed(const Schema& schema,
|
||||
std::shared_ptr<ChunkedColumnBase> column,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const void* query_data,
|
||||
int64_t num_queries,
|
||||
int64_t row_count,
|
||||
@ -137,6 +138,7 @@ SearchOnSealed(const Schema& schema,
|
||||
vec_data,
|
||||
chunk_size,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset_view,
|
||||
data_type);
|
||||
final_qr.merge(sub_qr);
|
||||
@ -145,6 +147,7 @@ SearchOnSealed(const Schema& schema,
|
||||
vec_data,
|
||||
chunk_size,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset_view,
|
||||
data_type);
|
||||
for (auto& o : sub_qr.mutable_seg_offsets()) {
|
||||
@ -177,6 +180,7 @@ void
|
||||
SearchOnSealed(const Schema& schema,
|
||||
const void* vec_data,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const void* query_data,
|
||||
int64_t num_queries,
|
||||
int64_t row_count,
|
||||
@ -200,13 +204,23 @@ SearchOnSealed(const Schema& schema,
|
||||
auto data_type = field.get_data_type();
|
||||
CheckBruteForceSearchParam(field, search_info);
|
||||
if (search_info.group_by_field_id_.has_value()) {
|
||||
auto sub_qr = BruteForceSearchIterators(
|
||||
dataset, vec_data, row_count, search_info, bitset, data_type);
|
||||
auto sub_qr = BruteForceSearchIterators(dataset,
|
||||
vec_data,
|
||||
row_count,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type);
|
||||
result.AssembleChunkVectorIterators(
|
||||
num_queries, 1, {0}, sub_qr.chunk_iterators());
|
||||
} else {
|
||||
auto sub_qr = BruteForceSearch(
|
||||
dataset, vec_data, row_count, search_info, bitset, data_type);
|
||||
auto sub_qr = BruteForceSearch(dataset,
|
||||
vec_data,
|
||||
row_count,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset,
|
||||
data_type);
|
||||
result.distances_ = std::move(sub_qr.mutable_distances());
|
||||
result.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@ void
|
||||
SearchOnSealed(const Schema& schema,
|
||||
std::shared_ptr<ChunkedColumnBase> column,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const void* query_data,
|
||||
int64_t num_queries,
|
||||
int64_t row_count,
|
||||
@ -41,6 +42,7 @@ void
|
||||
SearchOnSealed(const Schema& schema,
|
||||
const void* vec_data,
|
||||
const SearchInfo& search_info,
|
||||
const std::map<std::string, std::string>& index_info,
|
||||
const void* query_data,
|
||||
int64_t num_queries,
|
||||
int64_t row_count,
|
||||
|
||||
@ -940,9 +940,18 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
|
||||
AssertInfo(num_rows_.has_value(), "Can't get row count value");
|
||||
auto row_count = num_rows_.value();
|
||||
auto vec_data = fields_.at(field_id);
|
||||
|
||||
// get index params for bm25 brute force
|
||||
std::map<std::string, std::string> index_info;
|
||||
if (search_info.metric_type_ == knowhere::metric::BM25) {
|
||||
auto index_info =
|
||||
col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams();
|
||||
}
|
||||
|
||||
query::SearchOnSealed(*schema_,
|
||||
vec_data,
|
||||
search_info,
|
||||
index_info,
|
||||
query_data,
|
||||
query_count,
|
||||
row_count,
|
||||
|
||||
@ -22,9 +22,11 @@
|
||||
|
||||
#include "AckResponder.h"
|
||||
#include "InsertRecord.h"
|
||||
#include "common/FieldMeta.h"
|
||||
#include "common/Schema.h"
|
||||
#include "common/IndexMeta.h"
|
||||
#include "IndexConfigGenerator.h"
|
||||
#include "knowhere/config.h"
|
||||
#include "log/Log.h"
|
||||
#include "segcore/SegcoreConfig.h"
|
||||
#include "index/VectorIndex.h"
|
||||
@ -429,6 +431,11 @@ class IndexingRecord {
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
// Returns the index metadata for `fieldId` by forwarding the lookup to
// index_meta_->GetFieldIndexMeta(); the result is returned by const reference.
// NOTE(review): index_meta_ is dereferenced without a null check — confirm it
// is always set before this accessor is used.
const FieldIndexMeta&
|
||||
get_field_index_meta(FieldId fieldId) const {
|
||||
return index_meta_->GetFieldIndexMeta(fieldId);
|
||||
}
|
||||
|
||||
bool
|
||||
is_in(FieldId field_id) const {
|
||||
return field_indexings_.count(field_id);
|
||||
|
||||
@ -985,9 +985,18 @@ SegmentSealedImpl::vector_search(SearchInfo& search_info,
|
||||
AssertInfo(num_rows_.has_value(), "Can't get row count value");
|
||||
auto row_count = num_rows_.value();
|
||||
auto vec_data = fields_.at(field_id);
|
||||
|
||||
// get index params for bm25 brute force
|
||||
std::map<std::string, std::string> index_info;
|
||||
if (search_info.metric_type_ == knowhere::metric::BM25) {
|
||||
auto index_info =
|
||||
col_index_meta_->GetFieldIndexMeta(field_id).GetIndexParams();
|
||||
}
|
||||
|
||||
query::SearchOnSealed(*schema_,
|
||||
vec_data->Data(),
|
||||
search_info,
|
||||
index_info,
|
||||
query_data,
|
||||
query_count,
|
||||
row_count,
|
||||
|
||||
@ -122,6 +122,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
|
||||
|
||||
auto base = GenFloatVecs(dim, nb, metric_type);
|
||||
auto query = GenFloatVecs(dim, nq, metric_type);
|
||||
auto index_info = std::map<std::string, std::string>{};
|
||||
|
||||
dataset::SearchDataset dataset{
|
||||
metric_type, nq, topk, -1, dim, query.data()};
|
||||
@ -137,6 +138,7 @@ class TestFloatSearchBruteForce : public ::testing::Test {
|
||||
base.data(),
|
||||
nb,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset_view,
|
||||
DataType::VECTOR_FLOAT);
|
||||
for (int i = 0; i < nq; i++) {
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <map>
|
||||
#include <random>
|
||||
|
||||
#include "common/Utils.h"
|
||||
@ -98,6 +99,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
|
||||
|
||||
auto base = milvus::segcore::GenerateRandomSparseFloatVector(nb);
|
||||
auto query = milvus::segcore::GenerateRandomSparseFloatVector(nq);
|
||||
auto index_info = std::map<std::string, std::string>{};
|
||||
SearchInfo search_info;
|
||||
search_info.topk_ = topk;
|
||||
search_info.metric_type_ = metric_type;
|
||||
@ -108,6 +110,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
|
||||
base.get(),
|
||||
nb,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset_view,
|
||||
DataType::VECTOR_SPARSE_FLOAT));
|
||||
return;
|
||||
@ -116,6 +119,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
|
||||
base.get(),
|
||||
nb,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset_view,
|
||||
DataType::VECTOR_SPARSE_FLOAT);
|
||||
for (int i = 0; i < nq; i++) {
|
||||
@ -130,6 +134,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
|
||||
base.get(),
|
||||
nb,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset_view,
|
||||
DataType::VECTOR_SPARSE_FLOAT);
|
||||
for (int i = 0; i < nq; i++) {
|
||||
@ -143,6 +148,7 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
|
||||
base.get(),
|
||||
nb,
|
||||
search_info,
|
||||
index_info,
|
||||
bitset_view,
|
||||
DataType::VECTOR_SPARSE_FLOAT);
|
||||
auto iterators = result3.chunk_iterators();
|
||||
|
||||
@ -106,11 +106,13 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
|
||||
auto query_ds = segcore::DataGen(schema, 1);
|
||||
auto col_query_data = query_ds.get_col<float>(fakevec_id);
|
||||
auto query_data = col_query_data.data();
|
||||
auto index_info = std::map<std::string, std::string>{};
|
||||
SearchResult search_result;
|
||||
|
||||
query::SearchOnSealed(*schema,
|
||||
column,
|
||||
search_info,
|
||||
index_info,
|
||||
query_data,
|
||||
1,
|
||||
total_row_count,
|
||||
@ -135,6 +137,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
|
||||
query::SearchOnSealed(*schema,
|
||||
column,
|
||||
search_info,
|
||||
index_info,
|
||||
query_data,
|
||||
1,
|
||||
total_row_count,
|
||||
|
||||
@ -173,6 +173,7 @@ TEST(Indexing, BinaryBruteForce) {
|
||||
};
|
||||
|
||||
SearchInfo search_info;
|
||||
auto index_info = std::map<std::string, std::string>{};
|
||||
search_info.topk_ = topk;
|
||||
search_info.round_decimal_ = round_decimal;
|
||||
search_info.metric_type_ = metric_type;
|
||||
@ -180,6 +181,7 @@ TEST(Indexing, BinaryBruteForce) {
|
||||
bin_vec.data(),
|
||||
N,
|
||||
search_info,
|
||||
index_info,
|
||||
nullptr,
|
||||
DataType::VECTOR_BINARY);
|
||||
|
||||
|
||||
@ -1236,6 +1236,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
|
||||
CreatePlaceholderGroupFromBlob(num_queries, 16, query_ptr);
|
||||
auto ph_group =
|
||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||
auto index_info = std::map<std::string, std::string>{};
|
||||
|
||||
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
|
||||
|
||||
@ -1257,6 +1258,7 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
|
||||
vec_col.data(),
|
||||
N,
|
||||
search_info,
|
||||
index_info,
|
||||
nullptr,
|
||||
DataType::VECTOR_FLOAT);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user