From 7cd37fc6dd76d9b574f9e054f8c13a1d64e48fea Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Wed, 3 Aug 2022 11:56:36 +0800 Subject: [PATCH] Merge FloatSearch() and BinarySearch() into SearchOnGrowing() (#18498) Signed-off-by: yudong.cai --- internal/core/src/query/SearchOnGrowing.cpp | 165 ++++++------------ .../core/src/segcore/SegmentSealedImpl.cpp | 22 +-- 2 files changed, 60 insertions(+), 127 deletions(-) diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index 79f425b2d0..4a9fa0364d 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -16,42 +16,35 @@ namespace milvus::query { -Status -FloatSearch(const segcore::SegmentGrowingImpl& segment, - const query::SearchInfo& info, - const float* query_data, - int64_t num_queries, - int64_t ins_barrier, - const BitsetView& bitset, - SearchResult& results) { +// TODO: small index is disabled, however 3 unittests still call this API, consider to remove this API +// - Query::ExecWithPredicateLoader +// - Query::ExecWithPredicate +// - Query::ExecWithoutPredicate +int32_t +FloatIndexSearch(const segcore::SegmentGrowingImpl& segment, + const query::SearchInfo& info, + const void* query_data, + int64_t num_queries, + int64_t ins_barrier, + const BitsetView& bitset, + SubSearchResult& results) { auto& schema = segment.get_schema(); auto& indexing_record = segment.get_indexing_record(); auto& record = segment.get_insert_record(); - // step 1.1: get meta - // step 1.2: get which vector field to search auto vecfield_id = info.field_id_; auto& field = schema[vecfield_id]; AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT"); - auto dim = field.get_dim(); - auto topk = info.topk_; - auto total_count = topk * num_queries; - auto metric_type = info.metric_type_; - auto round_decimal = info.round_decimal_; - // step 2: small indexing search - // std::vector 
final_uids(total_count, -1); - // std::vector final_dis(total_count, std::numeric_limits::max()); - SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal); - dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data}; + dataset::SearchDataset search_dataset{info.metric_type_, num_queries, info.topk_, + info.round_decimal_, field.get_dim(), query_data}; auto vec_ptr = record.get_field_data(vecfield_id); int current_chunk_id = 0; - if (indexing_record.is_in(vecfield_id)) { auto max_indexed_id = indexing_record.get_finished_ack(); const auto& field_indexing = indexing_record.get_vec_field_indexing(vecfield_id); - auto search_conf = field_indexing.get_search_params(topk); + auto search_conf = field_indexing.get_search_params(info.topk_); AssertInfo(vec_ptr->get_size_per_chunk() == field_indexing.get_size_per_chunk(), "[FloatSearch]Chunk size of vector not equal to chunk size of field index"); @@ -72,24 +65,60 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment, } } - final_qr.merge(sub_qr); + results.merge(sub_qr); current_chunk_id++; } } + return current_chunk_id; +} + +void +SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, + int64_t ins_barrier, + const query::SearchInfo& info, + const void* query_data, + int64_t num_queries, + const BitsetView& bitset, + SearchResult& results) { + auto& schema = segment.get_schema(); + auto& indexing_record = segment.get_indexing_record(); + auto& record = segment.get_insert_record(); + + // step 1.1: get meta + // step 1.2: get which vector field to search + auto vecfield_id = info.field_id_; + auto& field = schema[vecfield_id]; + auto data_type = field.get_data_type(); + AssertInfo(datatype_is_vector(data_type), "[SearchOnGrowing]Data type isn't vector type"); + + auto dim = field.get_dim(); + auto topk = info.topk_; + auto metric_type = info.metric_type_; + auto round_decimal = info.round_decimal_; + + // step 2: small indexing search + SubSearchResult 
final_qr(num_queries, topk, metric_type, round_decimal); + dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data}; + + int32_t current_chunk_id = 0; + if (field.get_data_type() == DataType::VECTOR_FLOAT) { + current_chunk_id = FloatIndexSearch(segment, info, query_data, num_queries, ins_barrier, bitset, final_qr); + } // step 3: brute force search where small indexing is unavailable + auto vec_ptr = record.get_field_data_base(vecfield_id); auto vec_size_per_chunk = vec_ptr->get_size_per_chunk(); auto max_chunk = upper_div(ins_barrier, vec_size_per_chunk); for (int chunk_id = current_chunk_id; chunk_id < max_chunk; ++chunk_id) { - auto& chunk = vec_ptr->get_chunk(chunk_id); + auto chunk_data = vec_ptr->get_chunk_data(chunk_id); auto element_begin = chunk_id * vec_size_per_chunk; auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_size_per_chunk); auto size_per_chunk = element_end - element_begin; auto sub_view = bitset.subview(element_begin, size_per_chunk); - auto sub_qr = BruteForceSearch(search_dataset, chunk.data(), size_per_chunk, sub_view); + auto sub_qr = BruteForceSearch(search_dataset, chunk_data, size_per_chunk, sub_view); // convert chunk uid to segment uid for (auto& x : sub_qr.mutable_seg_offsets()) { @@ -103,92 +132,6 @@ FloatSearch(const segcore::SegmentGrowingImpl& segment, results.seg_offsets_ = std::move(final_qr.mutable_seg_offsets()); results.unity_topK_ = topk; results.total_nq_ = num_queries; - - return Status::OK(); -} - -Status -BinarySearch(const segcore::SegmentGrowingImpl& segment, - const query::SearchInfo& info, - const uint8_t* query_data, - int64_t num_queries, - int64_t ins_barrier, - const BitsetView& bitset, - SearchResult& results) { - auto& schema = segment.get_schema(); - auto& indexing_record = segment.get_indexing_record(); - auto& record = segment.get_insert_record(); - // step 1: binary search to find the barrier of the snapshot - // auto ins_barrier = 
get_barrier(record, timestamp); - auto metric_type = info.metric_type_; - // auto del_barrier = get_barrier(deleted_record_, timestamp); - - // step 2.1: get meta - // step 2.2: get which vector field to search - auto vecfield_id = info.field_id_; - auto& field = schema[vecfield_id]; - - AssertInfo(field.get_data_type() == DataType::VECTOR_BINARY, "[BinarySearch]Field data type isn't VECTOR_BINARY"); - auto dim = field.get_dim(); - auto topk = info.topk_; - auto total_count = topk * num_queries; - auto round_decimal = info.round_decimal_; - // step 3: small indexing search - query::dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data}; - - auto vec_ptr = record.get_field_data(vecfield_id); - auto max_indexed_id = 0; - - // step 4: brute force search where small indexing is unavailable - auto vec_size_per_chunk = vec_ptr->get_size_per_chunk(); - auto max_chunk = upper_div(ins_barrier, vec_size_per_chunk); - SubSearchResult final_result(num_queries, topk, metric_type, round_decimal); - for (int chunk_id = max_indexed_id; chunk_id < max_chunk; ++chunk_id) { - auto& chunk = vec_ptr->get_chunk(chunk_id); - auto element_begin = chunk_id * vec_size_per_chunk; - auto element_end = std::min(ins_barrier, (chunk_id + 1) * vec_size_per_chunk); - auto nsize = element_end - element_begin; - - auto sub_view = bitset.subview(element_begin, nsize); - auto sub_result = BruteForceSearch(search_dataset, chunk.data(), nsize, sub_view); - - // convert chunk uid to segment uid - for (auto& x : sub_result.mutable_seg_offsets()) { - if (x != -1) { - x += chunk_id * vec_size_per_chunk; - } - } - final_result.merge(sub_result); - } - - final_result.round_values(); - results.distances_ = std::move(final_result.mutable_distances()); - results.seg_offsets_ = std::move(final_result.mutable_seg_offsets()); - results.unity_topK_ = topk; - results.total_nq_ = num_queries; - - return Status::OK(); -} - -// TODO: refactor and merge this into one -void 
-SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, - int64_t ins_barrier, - const query::SearchInfo& info, - const void* query_data, - int64_t num_queries, - const BitsetView& bitset, - SearchResult& results) { - // TODO: add data_type to info - auto data_type = segment.get_schema()[info.field_id_].get_data_type(); - AssertInfo(datatype_is_vector(data_type), "[SearchOnGrowing]Data type isn't vector type"); - if (data_type == DataType::VECTOR_FLOAT) { - auto typed_data = reinterpret_cast<const float*>(query_data); - FloatSearch(segment, info, typed_data, num_queries, ins_barrier, bitset, results); - } else { - auto typed_data = reinterpret_cast<const uint8_t*>(query_data); - BinarySearch(segment, info, typed_data, num_queries, ins_barrier, bitset, results); - } } } // namespace milvus::query diff --git a/internal/core/src/segcore/SegmentSealedImpl.cpp b/internal/core/src/segcore/SegmentSealedImpl.cpp index 226924aa17..c98e39eee8 100644 --- a/internal/core/src/segcore/SegmentSealedImpl.cpp +++ b/internal/core/src/segcore/SegmentSealedImpl.cpp @@ -370,15 +370,8 @@ SegmentSealedImpl::vector_search(int64_t vec_count, PanicInfo("Field Data is not loaded"); } - query::dataset::SearchDataset dataset; - dataset.query_data = query_data; - dataset.num_queries = query_count; - // if(field_meta.is) - dataset.metric_type = search_info.metric_type_; - dataset.topk = search_info.topk_; - dataset.dim = field_meta.get_dim(); - dataset.round_decimal = search_info.round_decimal_; - + query::dataset::SearchDataset dataset{search_info.metric_type_, query_count, search_info.topk_, + search_info.round_decimal_, field_meta.get_dim(), query_data}; AssertInfo(get_bit(field_data_ready_bitset_, field_id), "Can't get bitset element at " + std::to_string(field_id.get())); AssertInfo(row_count_opt_.has_value(), "Can't get row count value"); @@ -388,13 +381,10 @@ SegmentSealedImpl::vector_search(int64_t vec_count, auto chunk_data = vec_data->get_chunk_data(0); auto sub_qr = query::BruteForceSearch(dataset, 
chunk_data, row_count, bitset); - SearchResult results; - results.distances_ = std::move(sub_qr.mutable_distances()); - results.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets()); - results.unity_topK_ = dataset.topk; - results.total_nq_ = dataset.num_queries; - - output = std::move(results); + output.distances_ = std::move(sub_qr.mutable_distances()); + output.seg_offsets_ = std::move(sub_qr.mutable_seg_offsets()); + output.unity_topK_ = dataset.topk; + output.total_nq_ = dataset.num_queries; } void