mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Merge FloatSearch() and BinarySearch() into SearchOnGrowing() (#18498)
Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
4edc8d3f81
commit
7cd37fc6dd
@ -16,42 +16,35 @@
|
||||
|
||||
namespace milvus::query {
|
||||
|
||||
Status
|
||||
FloatSearch(const segcore::SegmentGrowingImpl& segment,
|
||||
const query::SearchInfo& info,
|
||||
const float* query_data,
|
||||
int64_t num_queries,
|
||||
int64_t ins_barrier,
|
||||
const BitsetView& bitset,
|
||||
SearchResult& results) {
|
||||
// TODO: small index is disabled, however 3 unittests still call this API, consider to remove this API
|
||||
// - Query::ExecWithPredicateLoader
|
||||
// - Query::ExecWithPredicate
|
||||
// - Query::ExecWithoutPredicate
|
||||
int32_t
|
||||
FloatIndexSearch(const segcore::SegmentGrowingImpl& segment,
|
||||
const query::SearchInfo& info,
|
||||
const void* query_data,
|
||||
int64_t num_queries,
|
||||
int64_t ins_barrier,
|
||||
const BitsetView& bitset,
|
||||
SubSearchResult& results) {
|
||||
auto& schema = segment.get_schema();
|
||||
auto& indexing_record = segment.get_indexing_record();
|
||||
auto& record = segment.get_insert_record();
|
||||
|
||||
// step 1.1: get meta
|
||||
// step 1.2: get which vector field to search
|
||||
auto vecfield_id = info.field_id_;
|
||||
auto& field = schema[vecfield_id];
|
||||
|
||||
AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT");
|
||||
auto dim = field.get_dim();
|
||||
auto topk = info.topk_;
|
||||
auto total_count = topk * num_queries;
|
||||
auto metric_type = info.metric_type_;
|
||||
auto round_decimal = info.round_decimal_;
|
||||
// step 2: small indexing search
|
||||
// std::vector<int64_t> final_uids(total_count, -1);
|
||||
// std::vector<float> final_dis(total_count, std::numeric_limits<float>::max());
|
||||
SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal);
|
||||
dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};
|
||||
dataset::SearchDataset search_dataset{info.metric_type_, num_queries, info.topk_,
|
||||
info.round_decimal_, field.get_dim(), query_data};
|
||||
auto vec_ptr = record.get_field_data<FloatVector>(vecfield_id);
|
||||
|
||||
int current_chunk_id = 0;
|
||||
|
||||
if (indexing_record.is_in(vecfield_id)) {
|
||||
auto max_indexed_id = indexing_record.get_finished_ack();
|
||||
const auto& field_indexing = indexing_record.get_vec_field_indexing(vecfield_id);
|
||||
auto search_conf = field_indexing.get_search_params(topk);
|
||||
auto search_conf = field_indexing.get_search_params(info.topk_);
|
||||
AssertInfo(vec_ptr->get_size_per_chunk() == field_indexing.get_size_per_chunk(),
|
||||
"[FloatSearch]Chunk size of vector not equal to chunk size of field index");
|
||||
|
||||
@ -72,24 +65,60 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
|
||||
}
|
||||
}
|
||||
|
||||
final_qr.merge(sub_qr);
|
||||
results.merge(sub_qr);
|
||||
current_chunk_id++;
|
||||
}
|
||||
}
|
||||
return current_chunk_id;
|
||||
}
|
||||
|
||||
void
|
||||
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
int64_t ins_barrier,
|
||||
const query::SearchInfo& info,
|
||||
const void* query_data,
|
||||
int64_t num_queries,
|
||||
const BitsetView& bitset,
|
||||
SearchResult& results) {
|
||||
auto& schema = segment.get_schema();
|
||||
auto& indexing_record = segment.get_indexing_record();
|
||||
auto& record = segment.get_insert_record();
|
||||
|
||||
// step 1.1: get meta
|
||||
// step 1.2: get which vector field to search
|
||||
auto vecfield_id = info.field_id_;
|
||||
auto& field = schema[vecfield_id];
|
||||
auto data_type = field.get_data_type();
|
||||
AssertInfo(datatype_is_vector(data_type), "[SearchOnGrowing]Data type isn't vector type");
|
||||
|
||||
auto dim = field.get_dim();
|
||||
auto topk = info.topk_;
|
||||
auto metric_type = info.metric_type_;
|
||||
auto round_decimal = info.round_decimal_;
|
||||
|
||||
// step 2: small indexing search
|
||||
SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal);
|
||||
dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};
|
||||
|
||||
int32_t current_chunk_id = 0;
|
||||
if (field.get_data_type() == DataType::VECTOR_FLOAT) {
|
||||
current_chunk_id = FloatIndexSearch(segment, info, query_data, num_queries, ins_barrier, bitset, final_qr);
|
||||
}
|
||||
|
||||
// step 3: brute force search where small indexing is unavailable
|
||||
auto vec_ptr = record.get_field_data_base(vecfield_id);
|
||||
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
|
||||
auto max_chunk = upper_div(ins_barrier, vec_size_per_chunk);
|
||||
|
||||
for (int chunk_id = current_chunk_id; chunk_id < max_chunk; ++chunk_id) {
|
||||
auto& chunk = vec_ptr->get_chunk(chunk_id);
|
||||
auto chunk_data = vec_ptr->get_chunk_data(chunk_id);
|
||||
|
||||
auto element_begin = chunk_id * vec_size_per_chunk;
|
||||
auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_size_per_chunk);
|
||||
auto size_per_chunk = element_end - element_begin;
|
||||
|
||||
auto sub_view = bitset.subview(element_begin, size_per_chunk);
|
||||
auto sub_qr = BruteForceSearch(search_dataset, chunk.data(), size_per_chunk, sub_view);
|
||||
auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_qr.mutable_seg_offsets()) {
|
||||
@ -103,92 +132,6 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment,
|
||||
results.seg_offsets_ = std::move(final_qr.mutable_seg_offsets());
|
||||
results.unity_topK_ = topk;
|
||||
results.total_nq_ = num_queries;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
BinarySearch(const segcore::SegmentGrowingImpl& segment,
|
||||
const query::SearchInfo& info,
|
||||
const uint8_t* query_data,
|
||||
int64_t num_queries,
|
||||
int64_t ins_barrier,
|
||||
const BitsetView& bitset,
|
||||
SearchResult& results) {
|
||||
auto& schema = segment.get_schema();
|
||||
auto& indexing_record = segment.get_indexing_record();
|
||||
auto& record = segment.get_insert_record();
|
||||
// step 1: binary search to find the barrier of the snapshot
|
||||
// auto ins_barrier = get_barrier(record, timestamp);
|
||||
auto metric_type = info.metric_type_;
|
||||
// auto del_barrier = get_barrier(deleted_record_, timestamp);
|
||||
|
||||
// step 2.1: get meta
|
||||
// step 2.2: get which vector field to search
|
||||
auto vecfield_id = info.field_id_;
|
||||
auto& field = schema[vecfield_id];
|
||||
|
||||
AssertInfo(field.get_data_type() == DataType::VECTOR_BINARY, "[BinarySearch]Field data type isn't VECTOR_BINARY");
|
||||
auto dim = field.get_dim();
|
||||
auto topk = info.topk_;
|
||||
auto total_count = topk * num_queries;
|
||||
auto round_decimal = info.round_decimal_;
|
||||
// step 3: small indexing search
|
||||
query::dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data};
|
||||
|
||||
auto vec_ptr = record.get_field_data<BinaryVector>(vecfield_id);
|
||||
auto max_indexed_id = 0;
|
||||
|
||||
// step 4: brute force search where small indexing is unavailable
|
||||
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
|
||||
auto max_chunk = upper_div(ins_barrier, vec_size_per_chunk);
|
||||
SubSearchResult final_result(num_queries, topk, metric_type, round_decimal);
|
||||
for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) {
|
||||
auto& chunk = vec_ptr->get_chunk(chunk_id);
|
||||
auto element_begin = chunk_id * vec_size_per_chunk;
|
||||
auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_size_per_chunk);
|
||||
auto nsize = element_end - element_begin;
|
||||
|
||||
auto sub_view = bitset.subview(element_begin, nsize);
|
||||
auto sub_result = BruteForceSearch(search_dataset, chunk.data(), nsize, sub_view);
|
||||
|
||||
// convert chunk uid to segment uid
|
||||
for (auto& x : sub_result.mutable_seg_offsets()) {
|
||||
if (x != -1) {
|
||||
x += chunk_id * vec_size_per_chunk;
|
||||
}
|
||||
}
|
||||
final_result.merge(sub_result);
|
||||
}
|
||||
|
||||
final_result.round_values();
|
||||
results.distances_ = std::move(final_result.mutable_distances());
|
||||
results.seg_offsets_ = std::move(final_result.mutable_seg_offsets());
|
||||
results.unity_topK_ = topk;
|
||||
results.total_nq_ = num_queries;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO: refactor and merge this into one
|
||||
void
|
||||
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||
int64_t ins_barrier,
|
||||
const query::SearchInfo& info,
|
||||
const void* query_data,
|
||||
int64_t num_queries,
|
||||
const BitsetView& bitset,
|
||||
SearchResult& results) {
|
||||
// TODO: add data_type to info
|
||||
auto data_type = segment.get_schema()[info.field_id_].get_data_type();
|
||||
AssertInfo(datatype_is_vector(data_type), "[SearchOnGrowing]Data type isn't vector type");
|
||||
if (data_type == DataType::VECTOR_FLOAT) {
|
||||
auto typed_data = reinterpret_cast<const float*>(query_data);
|
||||
FloatSearch(segment, info, typed_data, num_queries, ins_barrier, bitset, results);
|
||||
} else {
|
||||
auto typed_data = reinterpret_cast<const uint8_t*>(query_data);
|
||||
BinarySearch(segment, info, typed_data, num_queries, ins_barrier, bitset, results);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace milvus::query
|
||||
|
||||
@ -370,15 +370,8 @@ SegmentSealedImpl::vector_search(int64_t vec_count,
|
||||
PanicInfo("Field Data is not loaded");
|
||||
}
|
||||
|
||||
query::dataset::SearchDataset dataset;
|
||||
dataset.query_data = query_data;
|
||||
dataset.num_queries = query_count;
|
||||
// if(field_meta.is)
|
||||
dataset.metric_type = search_info.metric_type_;
|
||||
dataset.topk = search_info.topk_;
|
||||
dataset.dim = field_meta.get_dim();
|
||||
dataset.round_decimal = search_info.round_decimal_;
|
||||
|
||||
query::dataset::SearchDataset dataset{search_info.metric_type_, query_count, search_info.topk_,
|
||||
search_info.round_decimal_, field_meta.get_dim(), query_data};
|
||||
AssertInfo(get_bit(field_data_ready_bitset_, field_id),
|
||||
"Can't get bitset element at " + std::to_string(field_id.get()));
|
||||
AssertInfo(row_count_opt_.has_value(), "Can't get row count value");
|
||||
@ -388,13 +381,10 @@ SegmentSealedImpl::vector_search(int64_t vec_count,
|
||||
auto chunk_data = vec_data->get_chunk_data(0);
|
||||
auto sub_qr = query::BruteForceSearch(dataset, chunk_data, row_count, bitset);
|
||||
|
||||
SearchResult results;
|
||||
results.distances_ = std::move(sub_qr.mutable_distances());
|
||||
results.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
|
||||
results.unity_topK_ = dataset.topk;
|
||||
results.total_nq_ = dataset.num_queries;
|
||||
|
||||
output = std::move(results);
|
||||
output.distances_ = std::move(sub_qr.mutable_distances());
|
||||
output.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets());
|
||||
output.unity_topK_ = dataset.topk;
|
||||
output.total_nq_ = dataset.num_queries;
|
||||
}
|
||||
|
||||
void
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user