From f00c529aeae39d68716cbd0f85c31b30a80c70c7 Mon Sep 17 00:00:00 2001
From: Chun Han <116052805+MrPresent-Han@users.noreply.github.com>
Date: Fri, 12 Jul 2024 10:17:36 +0800
Subject: [PATCH] feat: support group_size for search_group_by(#33544) (#33720)

related: #33544
main changes in three aspects:
1. enable setting group_size for the group-by function
2. separate normal reduce and group-by reduce
3. eliminate unnecessary padding in search results for reducing

Signed-off-by: MrPresent-Han
Co-authored-by: MrPresent-Han
---
 internal/core/src/common/QueryInfo.h          |   5 +-
 internal/core/src/common/QueryResult.h        |   3 +-
 internal/core/src/query/CMakeLists.txt        |   2 +-
 internal/core/src/query/Plan.cpp              |   6 +
 internal/core/src/query/Plan.h                |   4 +
 internal/core/src/query/PlanProto.cpp         |   4 +
 internal/core/src/query/SearchOnIndex.cpp     |   2 +-
 internal/core/src/query/SearchOnSealed.cpp    |   2 +-
 internal/core/src/query/Utils.h               |   7 +
 .../SearchGroupByOperator.cpp}                |  98 +++---
 .../SearchGroupByOperator.h}                  |  54 +++-
 .../query/visitors/ExecPlanNodeVisitor.cpp    |  20 +-
 internal/core/src/segcore/CMakeLists.txt      |   7 +-
 .../core/src/segcore/reduce/GroupReduce.cpp   | 193 ++++++++++++
 .../core/src/segcore/reduce/GroupReduce.h     |  58 ++++
 .../core/src/segcore/{ => reduce}/Reduce.cpp  | 194 ++++--------
 .../core/src/segcore/{ => reduce}/Reduce.h    |  39 +--
 .../src/segcore/{ => reduce}/StreamReduce.cpp |   4 +-
 .../src/segcore/{ => reduce}/StreamReduce.h   |   0
 internal/core/src/segcore/reduce_c.cpp        |  36 ++-
 internal/core/unittest/test_c_api.cpp         |   2 +-
 internal/core/unittest/test_float16.cpp       |   2 +-
 internal/core/unittest/test_group_by.cpp      | 290 ++++++++++--------
 internal/core/unittest/test_indexing.cpp      |   2 +-
 internal/core/unittest/test_utils.cpp         |  10 +
 internal/core/unittest/test_utils/DataGen.h   |  34 +-
 .../unittest/test_utils/c_api_test_utils.h    |  14 +-
 internal/proto/plan.proto                     |   1 +
 28 files changed, 708 insertions(+), 385 deletions(-)
 rename internal/core/src/query/{GroupByOperator.cpp => groupby/SearchGroupByOperator.cpp} (71%)
 rename internal/core/src/query/{GroupByOperator.h => groupby/SearchGroupByOperator.h} (81%)
 create mode 100644 internal/core/src/segcore/reduce/GroupReduce.cpp
 create mode 100644 internal/core/src/segcore/reduce/GroupReduce.h
 rename internal/core/src/segcore/{ => reduce}/Reduce.cpp (71%)
 rename internal/core/src/segcore/{ => reduce}/Reduce.h (84%)
 rename internal/core/src/segcore/{ => reduce}/StreamReduce.cpp (99%)
 rename internal/core/src/segcore/{ => reduce}/StreamReduce.h (100%)

diff --git a/internal/core/src/common/QueryInfo.h b/internal/core/src/common/QueryInfo.h
index 03d86f9965..31785ea365 100644
--- a/internal/core/src/common/QueryInfo.h
+++ b/internal/core/src/common/QueryInfo.h
@@ -25,8 +25,9 @@ namespace milvus {
 struct SearchInfo {
-    int64_t topk_;
-    int64_t round_decimal_;
+    int64_t topk_{0};
+    int64_t group_size_{1};
+    int64_t round_decimal_{0};
     FieldId field_id_;
     MetricType metric_type_;
     knowhere::Json search_params_;
diff --git a/internal/core/src/common/QueryResult.h b/internal/core/src/common/QueryResult.h
index e8aa3a2fdf..97bc418d47 100644
--- a/internal/core/src/common/QueryResult.h
+++ b/internal/core/src/common/QueryResult.h
@@ -195,6 +195,7 @@ struct SearchResult {
     std::vector<float> distances_;
     std::vector<int64_t> seg_offsets_;
     std::optional<std::vector<GroupByValueType>> group_by_values_;
+    std::optional<int64_t> group_size_;
     // first fill data during fillPrimaryKey, and then update data after reducing search results
     std::vector<PkType> primary_keys_;
@@ -209,7 +210,7 @@ struct SearchResult {
     std::map<FieldId, std::unique_ptr<milvus::DataArray>> output_fields_data_;
     // used for reduce, filter
invalid pk, get real topks count - std::vector topk_per_nq_prefix_sum_; + std::vector topk_per_nq_prefix_sum_{}; //Vector iterators, used for group by std::optional>> diff --git a/internal/core/src/query/CMakeLists.txt b/internal/core/src/query/CMakeLists.txt index 6e9c2e4b93..51ae991deb 100644 --- a/internal/core/src/query/CMakeLists.txt +++ b/internal/core/src/query/CMakeLists.txt @@ -26,7 +26,7 @@ set(MILVUS_QUERY_SRCS SearchOnIndex.cpp SearchBruteForce.cpp SubSearchResult.cpp - GroupByOperator.cpp + groupby/SearchGroupByOperator.cpp PlanProto.cpp ) add_library(milvus_query ${MILVUS_QUERY_SRCS}) diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index a91c1bb47d..f12043b31f 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -90,6 +90,12 @@ CreateSearchPlanByExpr(const Schema& schema, return ProtoParser(schema).CreatePlan(plan_node); } +std::unique_ptr +CreateSearchPlanFromPlanNode(const Schema& schema, + const proto::plan::PlanNode& plan_node) { + return ProtoParser(schema).CreatePlan(plan_node); +} + std::unique_ptr CreateRetrievePlanByExpr(const Schema& schema, const void* serialized_expr_plan, diff --git a/internal/core/src/query/Plan.h b/internal/core/src/query/Plan.h index 6b908fd5f7..88f10ceb8b 100644 --- a/internal/core/src/query/Plan.h +++ b/internal/core/src/query/Plan.h @@ -32,6 +32,10 @@ CreateSearchPlanByExpr(const Schema& schema, const void* serialized_expr_plan, const int64_t size); +std::unique_ptr +CreateSearchPlanFromPlanNode(const Schema& schema, + const proto::plan::PlanNode& plan_node); + std::unique_ptr ParsePlaceholderGroup(const Plan* plan, const uint8_t* blob, diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 7e719f47dd..1b9c011515 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -209,7 +209,11 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { if (query_info_proto.group_by_field_id() > 0) { auto group_by_field_id = FieldId(query_info_proto.group_by_field_id()); search_info.group_by_field_id_ = group_by_field_id; + search_info.group_size_ = query_info_proto.group_size() > 0 + ? query_info_proto.group_size() + : 1; } + auto plan_node = [&]() -> std::unique_ptr { if (anns_proto.vector_type() == milvus::proto::plan::VectorType::BinaryVector) { diff --git a/internal/core/src/query/SearchOnIndex.cpp b/internal/core/src/query/SearchOnIndex.cpp index 45de711f6c..2eb7cf9f3a 100644 --- a/internal/core/src/query/SearchOnIndex.cpp +++ b/internal/core/src/query/SearchOnIndex.cpp @@ -10,7 +10,7 @@ // or implied. 
See the License for the specific language governing permissions and limitations under the License #include "SearchOnIndex.h" -#include "query/GroupByOperator.h" +#include "query/groupby/SearchGroupByOperator.h" namespace milvus::query { void diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index 8bc806062a..db524c6a98 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -17,7 +17,7 @@ #include "query/SearchBruteForce.h" #include "query/SearchOnSealed.h" #include "query/helper.h" -#include "query/GroupByOperator.h" +#include "query/groupby/SearchGroupByOperator.h" namespace milvus::query { diff --git a/internal/core/src/query/Utils.h b/internal/core/src/query/Utils.h index 2a42e894c7..830744da99 100644 --- a/internal/core/src/query/Utils.h +++ b/internal/core/src/query/Utils.h @@ -71,4 +71,11 @@ out_of_range(int64_t t) { return gt_ub(t) || lt_lb(t); } +inline bool +dis_closer(float dis1, float dis2, const MetricType& metric_type) { + if (PositivelyRelated(metric_type)) + return dis1 > dis2; + return dis1 < dis2; +} + } // namespace milvus::query diff --git a/internal/core/src/query/GroupByOperator.cpp b/internal/core/src/query/groupby/SearchGroupByOperator.cpp similarity index 71% rename from internal/core/src/query/GroupByOperator.cpp rename to internal/core/src/query/groupby/SearchGroupByOperator.cpp index 7ad7b5c08c..7b04f9cd2f 100644 --- a/internal/core/src/query/GroupByOperator.cpp +++ b/internal/core/src/query/groupby/SearchGroupByOperator.cpp @@ -13,35 +13,43 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "GroupByOperator.h" +#include "SearchGroupByOperator.h" #include "common/Consts.h" #include "segcore/SegmentSealedImpl.h" -#include "Utils.h" +#include "query/Utils.h" namespace milvus { namespace query { void -GroupBy(const std::vector>& iterators, - const SearchInfo& search_info, - std::vector& group_by_values, - const segcore::SegmentInternalInterface& segment, - std::vector& seg_offsets, - std::vector& distances) { +SearchGroupBy(const std::vector>& iterators, + const SearchInfo& search_info, + std::vector& group_by_values, + const segcore::SegmentInternalInterface& segment, + std::vector& seg_offsets, + std::vector& distances, + std::vector& topk_per_nq_prefix_sum) { //1. 
get search meta FieldId group_by_field_id = search_info.group_by_field_id_.value(); auto data_type = segment.GetFieldDataType(group_by_field_id); - + int max_total_size = + search_info.topk_ * search_info.group_size_ * iterators.size(); + seg_offsets.reserve(max_total_size); + distances.reserve(max_total_size); + group_by_values.reserve(max_total_size); + topk_per_nq_prefix_sum.reserve(iterators.size() + 1); switch (data_type) { case DataType::INT8: { auto dataGetter = GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, + search_info.group_size_, *dataGetter, group_by_values, seg_offsets, distances, - search_info.metric_type_); + search_info.metric_type_, + topk_per_nq_prefix_sum); break; } case DataType::INT16: { @@ -49,11 +57,13 @@ GroupBy(const std::vector>& iterators, GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, + search_info.group_size_, *dataGetter, group_by_values, seg_offsets, distances, - search_info.metric_type_); + search_info.metric_type_, + topk_per_nq_prefix_sum); break; } case DataType::INT32: { @@ -61,11 +71,13 @@ GroupBy(const std::vector>& iterators, GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, + search_info.group_size_, *dataGetter, group_by_values, seg_offsets, distances, - search_info.metric_type_); + search_info.metric_type_, + topk_per_nq_prefix_sum); break; } case DataType::INT64: { @@ -73,22 +85,26 @@ GroupBy(const std::vector>& iterators, GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, + search_info.group_size_, *dataGetter, group_by_values, seg_offsets, distances, - search_info.metric_type_); + search_info.metric_type_, + topk_per_nq_prefix_sum); break; } case DataType::BOOL: { auto dataGetter = GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, + search_info.group_size_, *dataGetter, group_by_values, seg_offsets, distances, - search_info.metric_type_); + search_info.metric_type_, + topk_per_nq_prefix_sum); break; } case DataType::VARCHAR: { @@ -96,11 +112,13 @@ GroupBy(const std::vector>& iterators, GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, + search_info.group_size_, *dataGetter, group_by_values, seg_offsets, distances, - search_info.metric_type_); + search_info.metric_type_, + topk_per_nq_prefix_sum); break; } default: { @@ -117,19 +135,24 @@ void GroupIteratorsByType( const std::vector>& iterators, int64_t topK, + int64_t group_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& seg_offsets, std::vector& distances, - const knowhere::MetricType& metrics_type) { + const knowhere::MetricType& metrics_type, + std::vector& topk_per_nq_prefix_sum) { + topk_per_nq_prefix_sum.push_back(0); for (auto& iterator : iterators) { GroupIteratorResult(iterator, topK, + group_size, data_getter, group_by_values, seg_offsets, distances, metrics_type); + topk_per_nq_prefix_sum.push_back(seg_offsets.size()); } } @@ -137,23 +160,20 @@ template void GroupIteratorResult(const std::shared_ptr& iterator, int64_t topK, + int64_t group_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& offsets, std::vector& distances, const knowhere::MetricType& metrics_type) { //1. - std::unordered_map> groupMap; + GroupByMap groupMap(topK, group_size); //2. 
do iteration until fill the whole map or run out of all data //note it may enumerate all data inside a segment and can block following //query and search possibly - auto dis_closer = [&](float l, float r) { - if (PositivelyRelated(metrics_type)) - return l > r; - return l < r; - }; - while (iterator->HasNext() && groupMap.size() < topK) { + std::vector> res; + while (iterator->HasNext() && !groupMap.IsGroupResEnough()) { auto offset_dis_pair = iterator->Next(); AssertInfo( offset_dis_pair.has_value(), @@ -162,38 +182,22 @@ GroupIteratorResult(const std::shared_ptr& iterator, auto offset = offset_dis_pair.value().first; auto dis = offset_dis_pair.value().second; T row_data = data_getter.Get(offset); - auto it = groupMap.find(row_data); - if (it == groupMap.end()) { - groupMap.emplace(row_data, std::make_pair(offset, dis)); - } else if (dis_closer(dis, it->second.second)) { - it->second = {offset, dis}; + if (groupMap.Push(row_data)) { + res.emplace_back(offset, dis, row_data); } } //3. sorted based on distances and metrics - std::vector>> sortedGroupVals( - groupMap.begin(), groupMap.end()); auto customComparator = [&](const auto& lhs, const auto& rhs) { - return dis_closer(lhs.second.second, rhs.second.second); + return dis_closer(std::get<1>(lhs), std::get<1>(rhs), metrics_type); }; - std::sort(sortedGroupVals.begin(), sortedGroupVals.end(), customComparator); + std::sort(res.begin(), res.end(), customComparator); //4. save groupBy results - group_by_values.reserve(sortedGroupVals.size()); - offsets.reserve(sortedGroupVals.size()); - distances.reserve(sortedGroupVals.size()); - for (auto iter = sortedGroupVals.cbegin(); iter != sortedGroupVals.cend(); - iter++) { - group_by_values.emplace_back(iter->first); - offsets.push_back(iter->second.first); - distances.push_back(iter->second.second); - } - - //5. 
padding topK results, extra memory consumed will be removed when reducing - for (std::size_t idx = groupMap.size(); idx < topK; idx++) { - offsets.push_back(INVALID_SEG_OFFSET); - distances.push_back(0.0); - group_by_values.emplace_back(std::monostate{}); + for (auto iter = res.cbegin(); iter != res.cend(); iter++) { + offsets.push_back(std::get<0>(*iter)); + distances.push_back(std::get<1>(*iter)); + group_by_values.emplace_back(std::move(std::get<2>(*iter))); } } diff --git a/internal/core/src/query/GroupByOperator.h b/internal/core/src/query/groupby/SearchGroupByOperator.h similarity index 81% rename from internal/core/src/query/GroupByOperator.h rename to internal/core/src/query/groupby/SearchGroupByOperator.h index f34316f0e1..41e3d2299d 100644 --- a/internal/core/src/query/GroupByOperator.h +++ b/internal/core/src/query/groupby/SearchGroupByOperator.h @@ -23,6 +23,7 @@ #include "segcore/SegmentSealedImpl.h" #include "segcore/ConcurrentVector.h" #include "common/Span.h" +#include "query/Utils.h" namespace milvus { namespace query { @@ -167,28 +168,67 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info, } void -GroupBy(const std::vector>& iterators, - const SearchInfo& searchInfo, - std::vector& group_by_values, - const segcore::SegmentInternalInterface& segment, - std::vector& seg_offsets, - std::vector& distances); +SearchGroupBy(const std::vector>& iterators, + const SearchInfo& searchInfo, + std::vector& group_by_values, + const segcore::SegmentInternalInterface& segment, + std::vector& seg_offsets, + std::vector& distances, + std::vector& topk_per_nq_prefix_sum); template void GroupIteratorsByType( const std::vector>& iterators, int64_t topK, + int64_t group_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& seg_offsets, std::vector& distances, - const knowhere::MetricType& metrics_type); + const knowhere::MetricType& metrics_type, + std::vector& topk_per_nq_prefix_sum); + +template +struct GroupByMap { + private: + std::unordered_map group_map_{}; + int group_capacity_{0}; + int group_size_{0}; + int enough_group_count{0}; + + public: + GroupByMap(int group_capacity, int group_size) + : group_capacity_(group_capacity), group_size_(group_size){}; + bool + IsGroupResEnough() { + return group_map_.size() == group_capacity_ && + enough_group_count == group_capacity_; + } + bool + Push(const T& t) { + if (group_map_.size() >= group_capacity_ && group_map_[t] == 0){ + return false; + } + if (group_map_[t] >= group_size_) { + //we ignore following input no matter the distance as knowhere::iterator doesn't guarantee + //strictly increase/decreasing distance output + //but this should not be a very serious influence to overall recall rate + return false; + } + group_map_[t] += 1; + if (group_map_[t] >= group_size_) { + enough_group_count += 1; + } + return true; + } +}; template void GroupIteratorResult(const std::shared_ptr& iterator, int64_t topK, + int64_t group_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& offsets, diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index d9e8a6c125..0892f7b385 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -25,7 +25,7 @@ #include "plan/PlanNode.h" #include "exec/Task.h" #include "segcore/SegmentInterface.h" -#include "query/GroupByOperator.h" +#include "query/groupby/SearchGroupByOperator.h" namespace 
milvus::query { namespace impl { @@ -193,14 +193,20 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) { search_result); search_result.total_data_cnt_ = final_view.size(); if (search_result.vector_iterators_.has_value()) { + AssertInfo(search_result.vector_iterators_.value().size() == + search_result.total_nq_, + "Vector Iterators' count must be equal to total_nq_, Check " + "your code"); std::vector group_by_values; - GroupBy(search_result.vector_iterators_.value(), - node.search_info_, - group_by_values, - *segment, - search_result.seg_offsets_, - search_result.distances_); + SearchGroupBy(search_result.vector_iterators_.value(), + node.search_info_, + group_by_values, + *segment, + search_result.seg_offsets_, + search_result.distances_, + search_result.topk_per_nq_prefix_sum_); search_result.group_by_values_ = std::move(group_by_values); + search_result.group_size_ = node.search_info_.group_size_; AssertInfo(search_result.seg_offsets_.size() == search_result.group_by_values_.value().size(), "Wrong state! search_result group_by_values_ size:{} is not " diff --git a/internal/core/src/segcore/CMakeLists.txt b/internal/core/src/segcore/CMakeLists.txt index 7d7944eda9..b783afb361 100644 --- a/internal/core/src/segcore/CMakeLists.txt +++ b/internal/core/src/segcore/CMakeLists.txt @@ -24,8 +24,6 @@ set(SEGCORE_FILES SegmentGrowingImpl.cpp SegmentSealedImpl.cpp FieldIndexing.cpp - Reduce.cpp - StreamReduce.cpp metrics_c.cpp plan_c.cpp reduce_c.cpp @@ -39,7 +37,10 @@ set(SEGCORE_FILES Utils.cpp ConcurrentVector.cpp ReduceUtils.cpp - check_vec_index_c.cpp) + check_vec_index_c.cpp + reduce/Reduce.cpp + reduce/StreamReduce.cpp + reduce/GroupReduce.cpp) add_library(milvus_segcore SHARED ${SEGCORE_FILES}) target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage milvus_futures) diff --git a/internal/core/src/segcore/reduce/GroupReduce.cpp b/internal/core/src/segcore/reduce/GroupReduce.cpp new file mode 100644 index 0000000000..005956dba3 --- /dev/null +++ b/internal/core/src/segcore/reduce/GroupReduce.cpp @@ -0,0 +1,193 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License + +#include "GroupReduce.h" +#include "log/Log.h" +#include "segcore/SegmentInterface.h" +#include "segcore/ReduceUtils.h" + +namespace milvus::segcore { + +void +GroupReduceHelper::FillOtherData( + int result_count, + int64_t nq_begin, + int64_t nq_end, + std::unique_ptr& search_res_data) { + std::vector group_by_values; + group_by_values.resize(result_count); + for (auto qi = nq_begin; qi < nq_end; qi++) { + for (auto search_result : search_results_) { + AssertInfo(search_result != nullptr, + "null search result when reorganize"); + if (search_result->result_offsets_.size() == 0) { + continue; + } + + auto topk_start = search_result->topk_per_nq_prefix_sum_[qi]; + auto topk_end = search_result->topk_per_nq_prefix_sum_[qi + 1]; + for (auto ki = topk_start; ki < topk_end; ki++) { + auto loc = search_result->result_offsets_[ki]; + group_by_values[loc] = + search_result->group_by_values_.value()[ki]; + } + } + } + AssembleGroupByValues(search_res_data, group_by_values, plan_); +} + +void +GroupReduceHelper::RefreshSingleSearchResult(SearchResult* search_result, + int seg_res_idx, + std::vector& real_topks) { + AssertInfo(search_result->group_by_values_.has_value(), + "no group by values for search result, group reducer should not " + "be called, wrong code"); + AssertInfo(search_result->primary_keys_.size() == + search_result->group_by_values_.value().size(), + "Wrong size for group_by_values size before refresh:{}, " + "not equal to " + "primary_keys_.size:{}", + search_result->group_by_values_.value().size(), + search_result->primary_keys_.size()); + + uint32_t size = 0; + for (int j = 0; j < total_nq_; j++) { + size += final_search_records_[seg_res_idx][j].size(); + } + std::vector primary_keys(size); + std::vector distances(size); + std::vector seg_offsets(size); + std::vector group_by_values(size); + + uint32_t index = 0; + for (int j = 0; j < total_nq_; j++) { + for (auto offset : final_search_records_[seg_res_idx][j]) { + primary_keys[index] = search_result->primary_keys_[offset]; + distances[index] = search_result->distances_[offset]; + seg_offsets[index] = search_result->seg_offsets_[offset]; + group_by_values[index] = + search_result->group_by_values_.value()[offset]; + index++; + real_topks[j]++; + } + } + search_result->primary_keys_.swap(primary_keys); + search_result->distances_.swap(distances); + search_result->seg_offsets_.swap(seg_offsets); + search_result->group_by_values_.value().swap(group_by_values); + AssertInfo(search_result->primary_keys_.size() == + search_result->group_by_values_.value().size(), + "Wrong size for group_by_values size after refresh:{}, " + "not equal to " + "primary_keys_.size:{}", + search_result->group_by_values_.value().size(), + search_result->primary_keys_.size()); +} + +void +GroupReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) { + //do nothing, for group-by reduce, as we calculate prefix_sum for nq when doing group by and no padding invalid results + //so there's no need to filter search_result +} + +int64_t +GroupReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, + int64_t topk, + int64_t& offset) { + std::priority_queue, + SearchResultPairComparator> + heap; + pk_set_.clear(); + pairs_.clear(); + pairs_.reserve(num_segments_); + for (int i = 0; i < num_segments_; i++) { + auto search_result = search_results_[i]; + auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi]; + auto offset_end = 
search_result->topk_per_nq_prefix_sum_[qi + 1]; + if (offset_beg == offset_end) { + continue; + } + auto primary_key = search_result->primary_keys_[offset_beg]; + auto distance = search_result->distances_[offset_beg]; + AssertInfo(search_result->group_by_values_.has_value(), + "Wrong state, search_result has no group_by_vales for " + "group_by_reduce, must be sth wrong!"); + AssertInfo(search_result->group_by_values_.value().size() == + search_result->primary_keys_.size(), + "Wrong state, search_result's group_by_values's length is " + "not equal to pks' size!"); + auto group_by_val = search_result->group_by_values_.value()[offset_beg]; + pairs_.emplace_back(primary_key, + distance, + search_result, + i, + offset_beg, + offset_end, + std::move(group_by_val)); + heap.push(&pairs_.back()); + } + + // nq has no results for all segments + if (heap.size() == 0) { + return 0; + } + + int64_t group_size = search_results_[0]->group_size_.value(); + int64_t group_by_total_size = group_size * topk; + int64_t filtered_count = 0; + auto start = offset; + std::unordered_map group_by_map; + + auto should_filtered = [&](const PkType& pk, + const GroupByValueType& group_by_val) { + if (pk_set_.count(pk) != 0) + return true; + if (group_by_map.size() >= topk && + group_by_map.count(group_by_val) == 0) + return true; + if (group_by_map[group_by_val] >= group_size) + return true; + return false; + }; + + while (offset - start < group_by_total_size && !heap.empty()) { + //fetch value + auto pilot = heap.top(); + heap.pop(); + auto index = pilot->segment_index_; + auto pk = pilot->primary_key_; + AssertInfo(pk != INVALID_PK, + "Wrong, search results should have been filtered and " + "invalid_pk should not be existed"); + auto group_by_val = pilot->group_by_value_.value(); + + //judge filter + if (!should_filtered(pk, group_by_val)) { + pilot->search_result_->result_offsets_.push_back(offset++); + final_search_records_[index][qi].push_back(pilot->offset_); + pk_set_.insert(pk); + group_by_map[group_by_val] += 1; + } else { + filtered_count++; + } + + //move pilot forward + pilot->advance(); + if (pilot->primary_key_ != INVALID_PK) { + heap.push(pilot); + } + } + return filtered_count; +} + +} // namespace milvus::segcore diff --git a/internal/core/src/segcore/reduce/GroupReduce.h b/internal/core/src/segcore/reduce/GroupReduce.h new file mode 100644 index 0000000000..35e3780607 --- /dev/null +++ b/internal/core/src/segcore/reduce/GroupReduce.h @@ -0,0 +1,58 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. 
See the License for the specific language governing permissions and limitations under the License +#pragma once +#include "Reduce.h" +#include "common/QueryResult.h" +#include "query/PlanImpl.h" + +namespace milvus::segcore { +class GroupReduceHelper : public ReduceHelper { + public: + explicit GroupReduceHelper(std::vector& search_results, + milvus::query::Plan* plan, + int64_t* slice_nqs, + int64_t* slice_topKs, + int64_t slice_num, + tracer::TraceContext* trace_ctx) + : ReduceHelper(search_results, + plan, + slice_nqs, + slice_topKs, + slice_num, + trace_ctx) { + } + + protected: + void + FilterInvalidSearchResult(SearchResult* search_result) override; + + int64_t + ReduceSearchResultForOneNQ(int64_t qi, + int64_t topk, + int64_t& result_offset) override; + + void + RefreshSingleSearchResult(SearchResult* search_result, + int seg_res_idx, + std::vector& real_topks) override; + + void + FillOtherData(int result_count, + int64_t nq_begin, + int64_t nq_end, + std::unique_ptr& + search_res_data) override; + + private: + std::unordered_set group_by_val_set_{}; +}; + +} // namespace milvus::segcore diff --git a/internal/core/src/segcore/Reduce.cpp b/internal/core/src/segcore/reduce/Reduce.cpp similarity index 71% rename from internal/core/src/segcore/Reduce.cpp rename to internal/core/src/segcore/reduce/Reduce.cpp index 1d52a460b8..b0086d6901 100644 --- a/internal/core/src/segcore/Reduce.cpp +++ b/internal/core/src/segcore/reduce/Reduce.cpp @@ -11,15 +11,15 @@ #include "Reduce.h" -#include +#include "log/Log.h" #include #include -#include "SegmentInterface.h" -#include "Utils.h" +#include "segcore/SegmentInterface.h" +#include "segcore/Utils.h" #include "common/EasyAssert.h" -#include "pkVisitor.h" -#include "ReduceUtils.h" +#include "segcore/pkVisitor.h" +#include "segcore/ReduceUtils.h" namespace milvus::segcore { @@ -56,7 +56,7 @@ void ReduceHelper::Reduce() { FillPrimaryKey(); ReduceResultData(); - RefreshSearchResult(); + RefreshSearchResults(); FillEntryData(); } @@ -90,13 +90,6 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) { auto segment = static_cast(search_result->segment_); auto& offsets = search_result->seg_offsets_; auto& distances = search_result->distances_; - if (search_result->group_by_values_.has_value()) { - AssertInfo(search_result->distances_.size() == - search_result->group_by_values_.value().size(), - "wrong group_by_values size, size:{}, expected size:{} ", - search_result->group_by_values_.value().size(), - search_result->distances_.size()); - } for (auto i = 0; i < nq; ++i) { for (auto j = 0; j < topK; ++j) { @@ -112,18 +105,12 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) { real_topks[i]++; offsets[valid_index] = offsets[index]; distances[valid_index] = distances[index]; - if (search_result->group_by_values_.has_value()) - search_result->group_by_values_.value()[valid_index] = - search_result->group_by_values_.value()[index]; valid_index++; } } } offsets.resize(valid_index); distances.resize(valid_index); - if (search_result->group_by_values_.has_value()) - search_result->group_by_values_.value().resize(valid_index); - search_result->topk_per_nq_prefix_sum_.resize(nq + 1); std::partial_sum(real_topks.begin(), real_topks.end(), @@ -154,59 +141,14 @@ ReduceHelper::FillPrimaryKey() { } void -ReduceHelper::RefreshSearchResult() { +ReduceHelper::RefreshSearchResults() { tracer::AutoSpan span( - "ReduceHelper::RefreshSearchResult", trace_ctx_, false); + "ReduceHelper::RefreshSearchResults", trace_ctx_, false); for (int i 
= 0; i < num_segments_; i++) { std::vector real_topks(total_nq_, 0); auto search_result = search_results_[i]; - if (search_result->group_by_values_.has_value()) { - AssertInfo(search_result->primary_keys_.size() == - search_result->group_by_values_.value().size(), - "Wrong size for group_by_values size before refresh:{}, " - "not equal to " - "primary_keys_.size:{}", - search_result->group_by_values_.value().size(), - search_result->primary_keys_.size()); - } if (search_result->result_offsets_.size() != 0) { - uint32_t size = 0; - for (int j = 0; j < total_nq_; j++) { - size += final_search_records_[i][j].size(); - } - std::vector primary_keys(size); - std::vector distances(size); - std::vector seg_offsets(size); - std::vector group_by_values(size); - - uint32_t index = 0; - for (int j = 0; j < total_nq_; j++) { - for (auto offset : final_search_records_[i][j]) { - primary_keys[index] = search_result->primary_keys_[offset]; - distances[index] = search_result->distances_[offset]; - seg_offsets[index] = search_result->seg_offsets_[offset]; - if (search_result->group_by_values_.has_value()) - group_by_values[index] = - search_result->group_by_values_.value()[offset]; - index++; - real_topks[j]++; - } - } - search_result->primary_keys_.swap(primary_keys); - search_result->distances_.swap(distances); - search_result->seg_offsets_.swap(seg_offsets); - if (search_result->group_by_values_.has_value()) { - search_result->group_by_values_.value().swap(group_by_values); - } - } - if (search_result->group_by_values_.has_value()) { - AssertInfo(search_result->primary_keys_.size() == - search_result->group_by_values_.value().size(), - "Wrong size for group_by_values size after refresh:{}, " - "not equal to " - "primary_keys_.size:{}", - search_result->group_by_values_.value().size(), - search_result->primary_keys_.size()); + RefreshSingleSearchResult(search_result, i, real_topks); } std::partial_sum(real_topks.begin(), real_topks.end(), @@ -214,6 +156,33 @@ ReduceHelper::RefreshSearchResult() { } } +void +ReduceHelper::RefreshSingleSearchResult(SearchResult* search_result, + int seg_res_idx, + std::vector& real_topks) { + uint32_t size = 0; + for (int j = 0; j < total_nq_; j++) { + size += final_search_records_[seg_res_idx][j].size(); + } + std::vector primary_keys(size); + std::vector distances(size); + std::vector seg_offsets(size); + + uint32_t index = 0; + for (int j = 0; j < total_nq_; j++) { + for (auto offset : final_search_records_[seg_res_idx][j]) { + primary_keys[index] = search_result->primary_keys_[offset]; + distances[index] = search_result->distances_[offset]; + seg_offsets[index] = search_result->seg_offsets_[offset]; + index++; + real_topks[j]++; + } + } + search_result->primary_keys_.swap(primary_keys); + search_result->distances_.swap(distances); + search_result->seg_offsets_.swap(seg_offsets); +} + void ReduceHelper::FillEntryData() { tracer::AutoSpan span("ReduceHelper::FillEntryData", trace_ctx_, false); @@ -228,12 +197,12 @@ int64_t ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, int64_t topk, int64_t& offset) { - while (!heap_.empty()) { - heap_.pop(); - } + std::priority_queue, + SearchResultPairComparator> + heap; pk_set_.clear(); pairs_.clear(); - group_by_val_set_.clear(); pairs_.reserve(num_segments_); for (int i = 0; i < num_segments_; i++) { @@ -245,41 +214,21 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, } auto primary_key = search_result->primary_keys_[offset_beg]; auto distance = search_result->distances_[offset_beg]; - if 
(search_result->group_by_values_.has_value()) { - AssertInfo( - search_result->group_by_values_.value().size() > offset_beg, - "Wrong size for group_by_values size to " - "ReduceSearchResultForOneNQ:{}, not enough for" - "required offset_beg:{}", - search_result->group_by_values_.value().size(), - offset_beg); - } - pairs_.emplace_back( - primary_key, - distance, - search_result, - i, - offset_beg, - offset_end, - search_result->group_by_values_.has_value() && - search_result->group_by_values_.value().size() > offset_beg - ? std::make_optional( - search_result->group_by_values_.value().at(offset_beg)) - : std::nullopt); - heap_.push(&pairs_.back()); + primary_key, distance, search_result, i, offset_beg, offset_end); + heap.push(&pairs_.back()); } // nq has no results for all segments - if (heap_.size() == 0) { + if (heap.size() == 0) { return 0; } int64_t dup_cnt = 0; auto start = offset; - while (offset - start < topk && !heap_.empty()) { - auto pilot = heap_.top(); - heap_.pop(); + while (offset - start < topk && !heap.empty()) { + auto pilot = heap.top(); + heap.pop(); auto index = pilot->segment_index_; auto pk = pilot->primary_key_; @@ -289,27 +238,16 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, } // remove duplicates if (pk_set_.count(pk) == 0) { - bool skip_for_group_by = false; - if (pilot->group_by_value_.has_value()) { - if (group_by_val_set_.count(pilot->group_by_value_.value()) > - 0) { - skip_for_group_by = true; - } - } - if (!skip_for_group_by) { - pilot->search_result_->result_offsets_.push_back(offset++); - final_search_records_[index][qi].push_back(pilot->offset_); - pk_set_.insert(pk); - if (pilot->group_by_value_.has_value()) - group_by_val_set_.insert(pilot->group_by_value_.value()); - } + pilot->search_result_->result_offsets_.push_back(offset++); + final_search_records_[index][qi].push_back(pilot->offset_); + pk_set_.insert(pk); } else { // skip entity with same primary key dup_cnt++; } pilot->advance(); if (pilot->primary_key_ != INVALID_PK) { - heap_.push(pilot); + heap.push(pilot); } } return dup_cnt; @@ -331,7 +269,7 @@ ReduceHelper::ReduceResultData() { "incorrect search result primary key size"); } - int64_t skip_dup_cnt = 0; + int64_t filtered_count = 0; for (int64_t slice_index = 0; slice_index < num_slices_; slice_index++) { auto nq_begin = slice_nqs_prefix_sum_[slice_index]; auto nq_end = slice_nqs_prefix_sum_[slice_index + 1]; @@ -339,15 +277,24 @@ ReduceHelper::ReduceResultData() { // reduce search results int64_t offset = 0; for (int64_t qi = nq_begin; qi < nq_end; qi++) { - skip_dup_cnt += ReduceSearchResultForOneNQ( + filtered_count += ReduceSearchResultForOneNQ( qi, slice_topKs_[slice_index], offset); } } - if (skip_dup_cnt > 0) { - LOG_DEBUG("skip duplicated search result, count = {}", skip_dup_cnt); + if (filtered_count > 0) { + LOG_DEBUG("skip duplicated search result, count = {}", filtered_count); } } +void +ReduceHelper::FillOtherData( + int result_count, + int64_t nq_begin, + int64_t nq_end, + std::unique_ptr& search_res_data) { + //simple batch reduce do nothing for other data +} + std::vector ReduceHelper::GetSearchResultDataSlice(int slice_index) { auto nq_begin = slice_nqs_prefix_sum_[slice_index]; @@ -370,7 +317,6 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) { search_result_data->set_top_k(slice_topKs_[slice_index]); search_result_data->set_num_queries(nq_end - nq_begin); search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0); - search_result_data->set_all_search_count(all_search_count); // `result_pairs` 
contains the SearchResult and result_offset info, used for filling output fields @@ -407,12 +353,6 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) { // reserve space for distances search_result_data->mutable_scores()->Resize(result_count, 0); - //reserve space for group_by_values - std::vector group_by_values; - if (plan_->plan_node_->search_info_.group_by_field_id_.has_value()) { - group_by_values.resize(result_count); - } - // fill pks and distances for (auto qi = nq_begin; qi < nq_end; qi++) { int64_t topk_count = 0; @@ -461,11 +401,6 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) { search_result_data->mutable_scores()->Set( loc, search_result->distances_[ki]); - // set group by values - if (search_result->group_by_values_.has_value() && - ki < search_result->group_by_values_.value().size()) - group_by_values[loc] = - search_result->group_by_values_.value()[ki]; // set result offset to fill output fields data result_pairs[loc] = {&search_result->output_fields_data_, ki}; } @@ -474,12 +409,13 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) { // update result topKs search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count); } - AssembleGroupByValues(search_result_data, group_by_values, plan_); AssertInfo(search_result_data->scores_size() == result_count, "wrong scores size, size = " + std::to_string(search_result_data->scores_size()) + ", expected size = " + std::to_string(result_count)); + // fill other wanted data + FillOtherData(result_count, nq_begin, nq_end, search_result_data); // set output fields for (auto field_id : plan_->target_entries_) { diff --git a/internal/core/src/segcore/Reduce.h b/internal/core/src/segcore/reduce/Reduce.h similarity index 84% rename from internal/core/src/segcore/Reduce.h rename to internal/core/src/segcore/reduce/Reduce.h index 8e60b1b9f3..0aa16222f6 100644 --- a/internal/core/src/segcore/Reduce.h +++ b/internal/core/src/segcore/reduce/Reduce.h @@ -21,9 +21,9 @@ #include "common/type_c.h" #include "common/QueryResult.h" #include "query/PlanImpl.h" -#include "ReduceStructure.h" +#include "segcore/ReduceStructure.h" #include "common/Tracer.h" -#include "segment_c.h" +#include "segcore/segment_c.h" namespace milvus::segcore { @@ -60,11 +60,16 @@ class ReduceHelper { } protected: - void + virtual void FilterInvalidSearchResult(SearchResult* search_result); void - RefreshSearchResult(); + RefreshSearchResults(); + + virtual void + RefreshSingleSearchResult(SearchResult* search_result, + int seg_res_idx, + std::vector& real_topks); void FillPrimaryKey(); @@ -72,6 +77,18 @@ class ReduceHelper { void ReduceResultData(); + virtual int64_t + ReduceSearchResultForOneNQ(int64_t qi, + int64_t topk, + int64_t& result_offset); + + virtual void + FillOtherData(int result_count, + int64_t nq_begin, + int64_t nq_end, + std::unique_ptr& + search_res_data); + private: void Initialize(); @@ -79,11 +96,6 @@ class ReduceHelper { void FillEntryData(); - int64_t - ReduceSearchResultForOneNQ(int64_t qi, - int64_t topk, - int64_t& result_offset); - std::vector GetSearchResultDataSlice(int slice_index_); @@ -94,25 +106,16 @@ class ReduceHelper { std::vector slice_nqs_prefix_sum_; int64_t num_segments_; std::vector slice_topKs_; - std::priority_queue, - SearchResultPairComparator> - heap_; // Used for merge results, // define these here to avoid allocating them for each query std::vector pairs_; std::unordered_set pk_set_; - std::unordered_set group_by_val_set_; // dim0: num_segments_; dim1: total_nq_; dim2: offset std::vector>> 
final_search_records_; - - private: std::vector slice_nqs_; int64_t total_nq_; - // output std::unique_ptr search_result_data_blobs_; - tracer::TraceContext* trace_ctx_; }; diff --git a/internal/core/src/segcore/StreamReduce.cpp b/internal/core/src/segcore/reduce/StreamReduce.cpp similarity index 99% rename from internal/core/src/segcore/StreamReduce.cpp rename to internal/core/src/segcore/reduce/StreamReduce.cpp index 2ca14379a3..d7fdf22035 100644 --- a/internal/core/src/segcore/StreamReduce.cpp +++ b/internal/core/src/segcore/reduce/StreamReduce.cpp @@ -10,9 +10,9 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include "StreamReduce.h" -#include "SegmentInterface.h" +#include "segcore/SegmentInterface.h" #include "segcore/Utils.h" -#include "Reduce.h" +#include "segcore/reduce/Reduce.h" #include "segcore/pkVisitor.h" #include "segcore/ReduceUtils.h" diff --git a/internal/core/src/segcore/StreamReduce.h b/internal/core/src/segcore/reduce/StreamReduce.h similarity index 100% rename from internal/core/src/segcore/StreamReduce.h rename to internal/core/src/segcore/reduce/StreamReduce.h diff --git a/internal/core/src/segcore/reduce_c.cpp b/internal/core/src/segcore/reduce_c.cpp index 628f8ffcdf..973fa95a67 100644 --- a/internal/core/src/segcore/reduce_c.cpp +++ b/internal/core/src/segcore/reduce_c.cpp @@ -10,12 +10,13 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include -#include "Reduce.h" +#include "segcore/reduce/Reduce.h" +#include "segcore/reduce/GroupReduce.h" #include "common/QueryResult.h" #include "common/EasyAssert.h" #include "query/Plan.h" #include "segcore/reduce_c.h" -#include "segcore/StreamReduce.h" +#include "segcore/reduce/StreamReduce.h" #include "segcore/Utils.h" using SearchResult = milvus::SearchResult; @@ -95,17 +96,30 @@ ReduceSearchResultsAndFillData(CTraceContext c_trace, search_results[i] = static_cast(c_search_results[i]); } - auto reduce_helper = milvus::segcore::ReduceHelper(search_results, - plan, - slice_nqs, - slice_topKs, - num_slices, - &trace_ctx); - reduce_helper.Reduce(); - reduce_helper.Marshal(); + std::shared_ptr reduce_helper; + if (plan->plan_node_->search_info_.group_by_field_id_.has_value()) { + reduce_helper = + std::make_shared( + search_results, + plan, + slice_nqs, + slice_topKs, + num_slices, + &trace_ctx); + } else { + reduce_helper = + std::make_shared(search_results, + plan, + slice_nqs, + slice_topKs, + num_slices, + &trace_ctx); + } + reduce_helper->Reduce(); + reduce_helper->Marshal(); // set final result ptr - *cSearchResultDataBlobs = reduce_helper.GetSearchResultDataBlobs(); + *cSearchResultDataBlobs = reduce_helper->GetSearchResultDataBlobs(); return milvus::SuccessCStatus(); } catch (std::exception& e) { return milvus::FailureCStatus(&e); diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index eec24d8670..8c177c0787 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -31,7 +31,7 @@ #include "pb/plan.pb.h" #include "query/ExprImpl.h" #include "segcore/Collection.h" -#include "segcore/Reduce.h" +#include "segcore/reduce/Reduce.h" #include "segcore/reduce_c.h" #include "segcore/segment_c.h" #include "futures/Future.h" diff --git a/internal/core/unittest/test_float16.cpp b/internal/core/unittest/test_float16.cpp index bf172a5d47..38da5af555 100644 --- a/internal/core/unittest/test_float16.cpp +++ 
b/internal/core/unittest/test_float16.cpp @@ -16,7 +16,7 @@ #include "index/IndexFactory.h" #include "knowhere/comp/index_param.h" #include "query/ExprImpl.h" -#include "segcore/Reduce.h" +#include "segcore/reduce/Reduce.h" #include "segcore/reduce_c.h" #include "test_utils/DataGen.h" #include "test_utils/PbHelper.h" diff --git a/internal/core/unittest/test_group_by.cpp b/internal/core/unittest/test_group_by.cpp index 1295230f25..1f7fe70a31 100644 --- a/internal/core/unittest/test_group_by.cpp +++ b/internal/core/unittest/test_group_by.cpp @@ -64,20 +64,13 @@ CheckGroupBySearchResult(const SearchResult& search_result, int topK, int nq, bool strict) { - int total = topK * nq; - ASSERT_EQ(search_result.group_by_values_.value().size(), total); - ASSERT_EQ(search_result.seg_offsets_.size(), total); - ASSERT_EQ(search_result.distances_.size(), total); + int size = search_result.group_by_values_.value().size(); + ASSERT_EQ(search_result.seg_offsets_.size(), size); + ASSERT_EQ(search_result.distances_.size(), size); ASSERT_TRUE(search_result.seg_offsets_[0] != INVALID_SEG_OFFSET); - int res_bound = GetSearchResultBound(search_result); - ASSERT_TRUE(res_bound > 0); - if (strict) { - ASSERT_TRUE(res_bound == total - 1); - } else { - ASSERT_TRUE(res_bound == total - 1 || - search_result.seg_offsets_[res_bound + 1] == - INVALID_SEG_OFFSET); - } + ASSERT_TRUE(search_result.seg_offsets_[size - 1] != INVALID_SEG_OFFSET); + ASSERT_EQ(search_result.topk_per_nq_prefix_sum_.size(), nq + 1); + ASSERT_EQ(size, search_result.topk_per_nq_prefix_sum_[nq]); } TEST(GroupBY, SealedIndex) { @@ -98,10 +91,10 @@ TEST(GroupBY, SealedIndex) { auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); schema->set_primary_field_id(str_fid); auto segment = CreateSealedSegment(schema); - size_t N = 100; + size_t N = 50; //2. load raw data - auto raw_data = DataGen(schema, N); + auto raw_data = DataGen(schema, N, 42, 0, 8, 10, false, false); auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { int64_t field_id = field_data.field_id(); @@ -125,24 +118,27 @@ TEST(GroupBY, SealedIndex) { load_index_info.index = std::move(indexing); load_index_info.index_params[METRICS_TYPE] = knowhere::metric::L2; segment->LoadIndex(load_index_info); - int topK = 100; + int topK = 15; + int group_size = 3; //4. 
search group by int8 { const char* raw_plan = R"(vector_anns: < field_id: 100 query_info: < - topk: 100 + topk: 15 metric_type: "L2" search_params: "{\"ef\": 10}" group_by_field_id: 101 + group_size: 3 > placeholder_tag: "$0" >)"; - auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + proto::plan::PlanNode plan_node; + auto ok = + google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node); + auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); @@ -153,29 +149,31 @@ TEST(GroupBY, SealedIndex) { CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); + ASSERT_EQ(20, group_by_values.size()); + //as the total data is 0,0,....6,6, so there will be 7 buckets with [3,3,3,3,3,3,2] items respectively + //so there will be 20 items returned + int size = group_by_values.size(); - std::unordered_set i8_set; + std::unordered_map i8_map; float lastDistance = 0.0; for (size_t i = 0; i < size; i++) { if (std::holds_alternative(group_by_values[i])) { int8_t g_val = std::get(group_by_values[i]); - ASSERT_FALSE(i8_set.count(g_val) > - 0); //no repetition on groupBy field - i8_set.insert(g_val); + i8_map[g_val] += 1; + ASSERT_TRUE(i8_map[g_val] <= group_size); + //for every group, the number of hits should not exceed group_size auto distance = search_result->distances_.at(i); ASSERT_TRUE( lastDistance <= distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; - } else { - //check padding - ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET); - ASSERT_EQ(search_result->distances_[i], 0.0); } } + ASSERT_TRUE(i8_map.size() <= topK); + ASSERT_TRUE(i8_map.size() == 7); } - //4. search group by int16 + //5. 
search group by int16 { const char* raw_plan = R"(vector_anns: < field_id: 100 @@ -184,14 +182,16 @@ TEST(GroupBY, SealedIndex) { metric_type: "L2" search_params: "{\"ef\": 10}" group_by_field_id: 102 + group_size: 3 > placeholder_tag: "$0" >)"; - auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + proto::plan::PlanNode plan_node; + auto ok = + google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node); + auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); @@ -203,28 +203,27 @@ TEST(GroupBY, SealedIndex) { auto& group_by_values = search_result->group_by_values_.value(); int size = group_by_values.size(); - std::unordered_set i16_set; + ASSERT_EQ(20, size); + //as the total data is 0,0,....6,6, so there will be 7 buckets with [3,3,3,3,3,3,2] items respectively + //so there will be 20 items returned + + std::unordered_map i16_map; float lastDistance = 0.0; for (size_t i = 0; i < size; i++) { if (std::holds_alternative(group_by_values[i])) { int16_t g_val = std::get(group_by_values[i]); - ASSERT_FALSE(i16_set.count(g_val) > - 0); //no repetition on groupBy field - i16_set.insert(g_val); + i16_map[g_val] += 1; + ASSERT_TRUE(i16_map[g_val] <= group_size); auto distance = search_result->distances_.at(i); ASSERT_TRUE( lastDistance <= distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; - } else { - //check padding - ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET); - ASSERT_EQ(search_result->distances_[i], 0.0); } } + ASSERT_TRUE(i16_map.size() == 7); } - - //4. search group by int32 + //6. 
search group by int32 { const char* raw_plan = R"(vector_anns: < field_id: 100 @@ -233,14 +232,16 @@ TEST(GroupBY, SealedIndex) { metric_type: "L2" search_params: "{\"ef\": 10}" group_by_field_id: 103 + group_size: 3 > placeholder_tag: "$0" >)"; - auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + proto::plan::PlanNode plan_node; + auto ok = + google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node); + auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); @@ -249,31 +250,31 @@ TEST(GroupBY, SealedIndex) { auto search_result = segment->Search(plan.get(), ph_group.get(), 1L << 63); CheckGroupBySearchResult(*search_result, topK, num_queries, false); + auto& group_by_values = search_result->group_by_values_.value(); int size = group_by_values.size(); + ASSERT_EQ(20, size); + //as the total data is 0,0,....6,6, so there will be 7 buckets with [3,3,3,3,3,3,2] items respectively + //so there will be 20 items returned - std::unordered_set i32_set; + std::unordered_map i32_map; float lastDistance = 0.0; for (size_t i = 0; i < size; i++) { if (std::holds_alternative(group_by_values[i])) { int16_t g_val = std::get(group_by_values[i]); - ASSERT_FALSE(i32_set.count(g_val) > - 0); //no repetition on groupBy field - i32_set.insert(g_val); + i32_map[g_val] += 1; + ASSERT_TRUE(i32_map[g_val] <= group_size); auto distance = search_result->distances_.at(i); ASSERT_TRUE( lastDistance <= distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; - } else { - //check padding - ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET); - ASSERT_EQ(search_result->distances_[i], 0.0); } } + ASSERT_TRUE(i32_map.size() == 7); } - //4. search group by int64 + //7. 
search group by int64 { const char* raw_plan = R"(vector_anns: < field_id: 100 @@ -282,14 +283,16 @@ TEST(GroupBY, SealedIndex) { metric_type: "L2" search_params: "{\"ef\": 10}" group_by_field_id: 104 + group_size: 3 > placeholder_tag: "$0" >)"; - auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + proto::plan::PlanNode plan_node; + auto ok = + google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node); + auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); @@ -299,30 +302,29 @@ TEST(GroupBY, SealedIndex) { segment->Search(plan.get(), ph_group.get(), 1L << 63); CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); - int size = group_by_values.size(); - std::unordered_set i64_set; + ASSERT_EQ(20, size); + //as the total data is 0,0,....6,6, so there will be 7 buckets with [3,3,3,3,3,3,2] items respectively + //so there will be 20 items returned + + std::unordered_map i64_map; float lastDistance = 0.0; for (size_t i = 0; i < size; i++) { if (std::holds_alternative(group_by_values[i])) { int16_t g_val = std::get(group_by_values[i]); - ASSERT_FALSE(i64_set.count(g_val) > - 0); //no repetition on groupBy field - i64_set.insert(g_val); + i64_map[g_val] += 1; + ASSERT_TRUE(i64_map[g_val] <= group_size); auto distance = search_result->distances_.at(i); ASSERT_TRUE( lastDistance <= distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; - } else { - //check padding - ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET); - ASSERT_EQ(search_result->distances_[i], 0.0); } } + ASSERT_TRUE(i64_map.size() == 7); } - //4. search group by string + //8. 
search group by string { const char* raw_plan = R"(vector_anns: < field_id: 100 @@ -331,14 +333,16 @@ TEST(GroupBY, SealedIndex) { metric_type: "L2" search_params: "{\"ef\": 10}" group_by_field_id: 105 + group_size: 3 > placeholder_tag: "$0" >)"; - auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + proto::plan::PlanNode plan_node; + auto ok = + google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node); + auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); @@ -348,30 +352,28 @@ TEST(GroupBY, SealedIndex) { segment->Search(plan.get(), ph_group.get(), 1L << 63); CheckGroupBySearchResult(*search_result, topK, num_queries, false); auto& group_by_values = search_result->group_by_values_.value(); + ASSERT_EQ(20, group_by_values.size()); int size = group_by_values.size(); - std::unordered_set strs_set; + + std::unordered_map strs_map; float lastDistance = 0.0; for (size_t i = 0; i < size; i++) { if (std::holds_alternative(group_by_values[i])) { std::string g_val = std::move(std::get(group_by_values[i])); - ASSERT_FALSE(strs_set.count(g_val) > - 0); //no repetition on groupBy field - strs_set.insert(g_val); + strs_map[g_val] += 1; + ASSERT_TRUE(strs_map[g_val] <= group_size); auto distance = search_result->distances_.at(i); ASSERT_TRUE( lastDistance <= distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; - } else { - //check padding - ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET); - ASSERT_EQ(search_result->distances_[i], 0.0); } } + ASSERT_TRUE(strs_map.size() == 7); } - //4. search group by bool + //9. 
search group by bool { const char* raw_plan = R"(vector_anns: < field_id: 100 @@ -380,14 +382,16 @@ TEST(GroupBY, SealedIndex) { metric_type: "L2" search_params: "{\"ef\": 10}" group_by_field_id: 106 + group_size: 3 > placeholder_tag: "$0" >)"; - auto plan_str = translate_text_plan_to_binary_plan(raw_plan); - auto plan = - CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); + proto::plan::PlanNode plan_node; + auto ok = + google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node); + auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node); auto num_queries = 1; auto seed = 1024; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); @@ -396,30 +400,28 @@ TEST(GroupBY, SealedIndex) { auto search_result = segment->Search(plan.get(), ph_group.get(), 1L << 63); CheckGroupBySearchResult(*search_result, topK, num_queries, false); + auto& group_by_values = search_result->group_by_values_.value(); int size = group_by_values.size(); - std::unordered_set bools_set; - int boolValCount = 0; + ASSERT_EQ(size, 6); + //as there are only two possible values: true, false + //for each group, there are at most 3 items, so the final size of group_by_vals is 3 * 2 = 6 + + std::unordered_map bools_map; float lastDistance = 0.0; for (size_t i = 0; i < size; i++) { if (std::holds_alternative(group_by_values[i])) { bool g_val = std::get(group_by_values[i]); - ASSERT_FALSE(bools_set.count(g_val) > - 0); //no repetition on groupBy field - bools_set.insert(g_val); - boolValCount += 1; + bools_map[g_val] += 1; + ASSERT_TRUE(bools_map[g_val] <= group_size); auto distance = search_result->distances_.at(i); ASSERT_TRUE( lastDistance <= distance); //distance should be decreased as metrics_type is L2 lastDistance = distance; - } else { - //check padding - ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET); - ASSERT_EQ(search_result->distances_[i], 0.0); } - ASSERT_TRUE(boolValCount <= 2); //bool values cannot exceed two } + ASSERT_TRUE(bools_map.size() == 2); //bool values cannot exceed two } } @@ -444,7 +446,7 @@ TEST(GroupBY, SealedData) { size_t N = 100; //2. load raw data - auto raw_data = DataGen(schema, N); + auto raw_data = DataGen(schema, N, 42, 0, 8, 10, false, false); auto fields = schema->get_fields(); for (auto field_data : raw_data.raw_->fields_data()) { int64_t field_id = field_data.field_id(); @@ -459,16 +461,18 @@ TEST(GroupBY, SealedData) { } prepareSegmentSystemFieldData(segment, N, raw_data); - int topK = 100; + int topK = 10; + int group_size = 5; //3. 
search group by int8 { const char* raw_plan = R"(vector_anns: < field_id: 100 query_info: < - topk: 100 + topk: 10 metric_type: "L2" search_params: "{\"ef\": 10}" - group_by_field_id: 101 + group_by_field_id: 101, + group_size: 5, > placeholder_tag: "$0" @@ -487,25 +491,27 @@ auto& group_by_values = search_result->group_by_values_.value(); int size = group_by_values.size(); - std::unordered_set<int8_t> i8_set; + //as repeat_count is 8 and N is 100, there are 13 distinct groups, enough to fill topK * group_size = 10 * 5 = 50 results + ASSERT_EQ(50, size); + + std::unordered_map<int8_t, int> i8_map; float lastDistance = 0.0; for (size_t i = 0; i < size; i++) { if (std::holds_alternative<int8_t>(group_by_values[i])) { int8_t g_val = std::get<int8_t>(group_by_values[i]); - ASSERT_FALSE(i8_set.count(g_val) > - 0); //no repetition on groupBy field - i8_set.insert(g_val); + i8_map[g_val] += 1; + ASSERT_TRUE(i8_map[g_val] <= group_size); auto distance = search_result->distances_.at(i); ASSERT_TRUE( lastDistance <= distance); //distance should be non-decreasing as metric_type is L2 lastDistance = distance; - } else { - //check padding - ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET); - ASSERT_EQ(search_result->distances_[i], 0.0); } } + ASSERT_TRUE(i8_map.size() == topK); + for (const auto& it : i8_map) { + ASSERT_TRUE(it.second == group_size); + } } } @@ -534,8 +540,10 @@ TEST(GroupBY, Reduce) { uint64_t ts_offset = 0; int repeat_count_1 = 2; int repeat_count_2 = 5; - auto raw_data1 = DataGen(schema, N, seed, ts_offset, repeat_count_1); - auto raw_data2 = DataGen(schema, N, seed, ts_offset, repeat_count_2); + auto raw_data1 = + DataGen(schema, N, seed, ts_offset, repeat_count_1, false, false); + auto raw_data2 = + DataGen(schema, N, seed, ts_offset, repeat_count_2, false, false); auto fields = schema->get_fields(); //load segment1 raw data @@ -582,13 +590,17 @@ segment2->LoadIndex(load_index_info_2); //4. search group by respectively + auto num_queries = 10; + auto topK = 10; + int group_size = 3; const char* raw_plan = R"(vector_anns: < field_id: 100 query_info: < - topk: 100 + topk: 10 metric_type: "L2" search_params: "{\"ef\": 10}" group_by_field_id: 101 + group_size: 3 > placeholder_tag: "$0" @@ -596,8 +608,6 @@ auto plan_str = translate_text_plan_to_binary_plan(raw_plan); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto num_queries = 10; - auto topK = 100; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); @@ -629,7 +639,7 @@ slice_nqs.data(), slice_topKs.data(), slice_nqs.size()); - CheckSearchResultDuplicate(results); + CheckSearchResultDuplicate(results, group_size); DeleteSearchResult(c_search_res_1); DeleteSearchResult(c_search_res_2); DeleteSearchResultDataBlobs(cSearchResultData); @@ -664,7 +674,8 @@ TEST(GroupBY, GrowingRawData) { int64_t rows_per_batch = 512; int n_batch = 3; for (int i = 0; i < n_batch; i++) { - auto data_set = DataGen(schema, rows_per_batch); + auto data_set = + DataGen(schema, rows_per_batch, 42, 0, 8, 10, false, false); auto offset = segment_growing_impl->PreInsert(rows_per_batch); segment_growing_impl->Insert(offset, rows_per_batch, @@ -674,6 +685,9 @@ } //2.
Search group by + auto num_queries = 10; + auto topK = 100; + int group_size = 1; const char* raw_plan = R"(vector_anns: < field_id: 102 query_info: < @@ -681,6 +695,7 @@ metric_type: "L2" search_params: "{\"ef\": 10}" group_by_field_id: 101 + group_size: 1 > placeholder_tag: "$0" @@ -688,8 +703,6 @@ auto plan_str = translate_text_plan_to_binary_plan(raw_plan); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto num_queries = 10; - auto topK = 100; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); @@ -698,25 +711,25 @@ CheckGroupBySearchResult(*search_result, topK, num_queries, true); auto& group_by_values = search_result->group_by_values_.value(); + int size = group_by_values.size(); + ASSERT_EQ(size, 640); + //each batch has 512 rows and repeat_count is 8, so every query sees 64 distinct groups + //with group_size 1 and 10 queries, the total number of results is 64 * 10 = 640 + int expected_group_count = 64; int idx = 0; for (int i = 0; i < num_queries; i++) { std::unordered_set<int32_t> i32_set; float lastDistance = 0.0; - for (int j = 0; j < topK; j++) { + for (int j = 0; j < expected_group_count; j++) { if (std::holds_alternative<int32_t>(group_by_values[idx])) { int32_t g_val = std::get<int32_t>(group_by_values[idx]); - ASSERT_FALSE(i32_set.count(g_val) > - 0); //no repetition on groupBy field + ASSERT_FALSE( + i32_set.count(g_val) > + 0); //as group_size is 1, there should be no duplicate group_by values i32_set.insert(g_val); auto distance = search_result->distances_.at(idx); - ASSERT_TRUE( - lastDistance <= - distance); //distance should be decreased as metrics_type is L2 + ASSERT_TRUE(lastDistance <= distance); lastDistance = distance; - } else { - //check padding - ASSERT_EQ(search_result->seg_offsets_[idx], INVALID_SEG_OFFSET); - ASSERT_EQ(search_result->distances_[idx], 0.0); } idx++; } @@ -760,7 +773,8 @@ TEST(GroupBY, GrowingIndex) { int64_t rows_per_batch = 1024; int n_batch = 10; for (int i = 0; i < n_batch; i++) { - auto data_set = DataGen(schema, rows_per_batch); + auto data_set = + DataGen(schema, rows_per_batch, 42, 0, 8, 10, false, false); auto offset = segment_growing_impl->PreInsert(rows_per_batch); segment_growing_impl->Insert(offset, rows_per_batch, @@ -770,6 +784,9 @@ } //2.
Search group by int32 + auto num_queries = 10; + auto topK = 100; + int group_size = 3; const char* raw_plan = R"(vector_anns: < field_id: 102 query_info: < @@ -777,6 +794,7 @@ metric_type: "L2" search_params: "{\"ef\": 10}" group_by_field_id: 101 + group_size: 3 > placeholder_tag: "$0" @@ -784,8 +802,6 @@ auto plan_str = translate_text_plan_to_binary_plan(raw_plan); auto plan = CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); - auto num_queries = 10; - auto topK = 100; auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); @@ -794,27 +810,29 @@ CheckGroupBySearchResult(*search_result, topK, num_queries, true); auto& group_by_values = search_result->group_by_values_.value(); + auto size = group_by_values.size(); + int expected_group_count = 100; + ASSERT_EQ(size, expected_group_count * group_size * num_queries); int idx = 0; for (int i = 0; i < num_queries; i++) { - std::unordered_set<int32_t> i32_set; + std::unordered_map<int32_t, int> i32_map; float lastDistance = 0.0; - for (int j = 0; j < topK; j++) { + for (int j = 0; j < expected_group_count * group_size; j++) { if (std::holds_alternative<int32_t>(group_by_values[idx])) { int32_t g_val = std::get<int32_t>(group_by_values[idx]); - ASSERT_FALSE(i32_set.count(g_val) > - 0); //no repetition on groupBy field - i32_set.insert(g_val); + i32_map[g_val] += 1; + ASSERT_TRUE(i32_map[g_val] <= group_size); auto distance = search_result->distances_.at(idx); ASSERT_TRUE( lastDistance <= distance); //distance should be non-decreasing as metric_type is L2 lastDistance = distance; - } else { - //check padding - ASSERT_EQ(search_result->seg_offsets_[idx], INVALID_SEG_OFFSET); - ASSERT_EQ(search_result->distances_[idx], 0.0); } idx++; } + ASSERT_EQ(i32_map.size(), expected_group_count); + for (const auto& map_pair : i32_map) { + ASSERT_EQ(group_size, map_pair.second); + } } } \ No newline at end of file diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 0b1c9f6e4c..9d4afc53ae 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -28,7 +28,7 @@ #include "knowhere/comp/index_param.h" #include "nlohmann/json.hpp" #include "query/SearchBruteForce.h" -#include "segcore/Reduce.h" +#include "segcore/reduce/Reduce.h" #include "index/IndexFactory.h" #include "common/QueryResult.h" #include "segcore/Types.h" diff --git a/internal/core/unittest/test_utils.cpp b/internal/core/unittest/test_utils.cpp index 2c8e30877b..f8d3cc59e8 100644 --- a/internal/core/unittest/test_utils.cpp +++ b/internal/core/unittest/test_utils.cpp @@ -212,3 +212,13 @@ TEST(Util, get_common_prefix) { common_prefix = milvus::GetCommonPrefix(str1, str2); EXPECT_STREQ(common_prefix.c_str(), ""); } + +TEST(Util, dis_closer){ + EXPECT_TRUE(milvus::query::dis_closer(0.1, 0.2, "L2")); + EXPECT_FALSE(milvus::query::dis_closer(0.2, 0.1, "L2")); + EXPECT_FALSE(milvus::query::dis_closer(0.1, 0.1, "L2")); + + EXPECT_TRUE(milvus::query::dis_closer(0.2, 0.1, "IP")); + EXPECT_FALSE(milvus::query::dis_closer(0.1, 0.2, "IP")); + EXPECT_FALSE(milvus::query::dis_closer(0.1, 0.1, "IP")); +} \ No newline at end of file diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 441d7e1914..cf48ecdc57 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h
@@ -236,7 +236,8 @@ struct GeneratedData { uint64_t ts_offset, int repeat_count, int array_len, - bool random_pk); + bool random_pk, + bool random_val); friend GeneratedData DataGenForJsonArray(SchemaPtr schema, int64_t N, @@ -307,7 +308,8 @@ inline GeneratedData DataGen(SchemaPtr schema, uint64_t ts_offset = 0, int repeat_count = 1, int array_len = 10, - bool random_pk = false) { + bool random_pk = false, + bool random_val = true) { using std::vector; std::default_random_engine random(seed); std::normal_distribution<> distr(0, 1); @@ -414,24 +416,39 @@ inline GeneratedData DataGen(SchemaPtr schema, } case DataType::INT32: { vector<int32_t> data(N); - for (auto& x : data) { - x = random() % (2 * N); + for (int i = 0; i < N; i++) { + int x = 0; + if (random_val) + x = random() % (2 * N); + else + x = i / repeat_count; + data[i] = x; } insert_cols(data, N, field_meta); break; } case DataType::INT16: { vector<int16_t> data(N); - for (auto& x : data) { - x = random() % (2 * N); + for (int i = 0; i < N; i++) { + int16_t x = 0; + if (random_val) + x = random() % (2 * N); + else + x = i / repeat_count; + data[i] = x; } insert_cols(data, N, field_meta); break; } case DataType::INT8: { vector<int8_t> data(N); - for (auto& x : data) { - x = random() % (2 * N); + for (int i = 0; i < N; i++) { + int8_t x = 0; + if (random_val) + x = random() % (2 * N); + else + x = i / repeat_count; + data[i] = x; } insert_cols(data, N, field_meta); break; @@ -1175,7 +1192,6 @@ translate_text_plan_to_binary_plan(const char* text_plan) { auto ok = google::protobuf::TextFormat::ParseFromString(text_plan, &plan_node); AssertInfo(ok, "Failed to parse"); - std::string binary_plan; plan_node.SerializeToString(&binary_plan); diff --git a/internal/core/unittest/test_utils/c_api_test_utils.h b/internal/core/unittest/test_utils/c_api_test_utils.h index cabf6ec432..cf5eb02eb8 100644 --- a/internal/core/unittest/test_utils/c_api_test_utils.h +++ b/internal/core/unittest/test_utils/c_api_test_utils.h @@ -25,7 +25,7 @@ #include "common/type_c.h" #include "pb/plan.pb.h" #include "segcore/Collection.h" -#include "segcore/Reduce.h" +#include "segcore/reduce/Reduce.h" #include "segcore/reduce_c.h" #include "segcore/segment_c.h" #include "futures/Future.h" @@ -86,14 +86,14 @@ generate_query_data(int nq) { return blob; } void -CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) { +CheckSearchResultDuplicate(const std::vector<CSearchResult>& results, + int group_size = 1) { auto nq = ((SearchResult*)results[0])->total_nq_; - std::unordered_set<PkType> pk_set; - std::unordered_set<GroupByValueType> group_by_val_set; + std::unordered_map<GroupByValueType, int> group_by_map; for (int qi = 0; qi < nq; qi++) { pk_set.clear(); - group_by_val_set.clear(); + group_by_map.clear(); for (size_t i = 0; i < results.size(); i++) { auto search_result = (SearchResult*)results[i]; ASSERT_EQ(nq, search_result->total_nq_); @@ -108,8 +108,8 @@ CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) { search_result->group_by_values_.value().size() > ki) { auto group_by_val = search_result->group_by_values_.value()[ki]; - ASSERT_TRUE(group_by_val_set.count(group_by_val) == 0); - group_by_val_set.insert(group_by_val); + group_by_map[group_by_val] += 1; + ASSERT_TRUE(group_by_map[group_by_val] <= group_size); } } } diff --git a/internal/proto/plan.proto b/internal/proto/plan.proto index 62ccae20b5..7bc830172f 100644 --- a/internal/proto/plan.proto +++ b/internal/proto/plan.proto @@ -61,6 +61,7 @@ message QueryInfo { int64 round_decimal = 5; int64 group_by_field_id = 6; bool materialized_view_involved = 7; + int64 group_size = 8; } message
ColumnInfo {
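Note on the invariant checked by the tests above: with the new group_size field in QueryInfo, a query may return up to topk * group_size hits, and no single value of the group_by field may account for more than group_size of them; for L2 the hits stay sorted by ascending distance. The following standalone sketch (not part of the patch; Hit and CheckGroupByInvariant are illustrative names only) mirrors that check with plain std types.

// Minimal sketch of the group-by invariant asserted by the updated tests.
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct Hit {
    std::string group_val;  // value of the group_by field for this hit
    float distance;         // L2 distance, smaller is better
};

bool
CheckGroupByInvariant(const std::vector<Hit>& hits, int group_size) {
    std::unordered_map<std::string, int> counts;
    float last_distance = 0.0f;
    for (const auto& hit : hits) {
        counts[hit.group_val] += 1;
        if (counts[hit.group_val] > group_size) {
            return false;  // one group returned more rows than group_size allows
        }
        if (hit.distance < last_distance) {
            return false;  // L2 results must stay sorted by ascending distance
        }
        last_distance = hit.distance;
    }
    return true;
}

int
main() {
    // With topk = 2 and group_size = 2, at most 2 * 2 = 4 hits come back per query.
    std::vector<Hit> hits = {
        {"a", 0.1f}, {"a", 0.2f}, {"b", 0.3f}, {"b", 0.5f}};
    assert(CheckGroupByInvariant(hits, /*group_size=*/2));
    return 0;
}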