diff --git a/internal/core/src/common/QueryInfo.h b/internal/core/src/common/QueryInfo.h index 31785ea365..440194d33c 100644 --- a/internal/core/src/common/QueryInfo.h +++ b/internal/core/src/common/QueryInfo.h @@ -27,6 +27,7 @@ namespace milvus { struct SearchInfo { int64_t topk_{0}; int64_t group_size_{1}; + bool group_strict_size_{false}; int64_t round_decimal_{0}; FieldId field_id_; MetricType metric_type_; diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index 170b0d120c..7964c8df9c 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -212,6 +212,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { search_info.group_size_ = query_info_proto.group_size() > 0 ? query_info_proto.group_size() : 1; + search_info.group_strict_size_ = query_info_proto.group_strict_size(); } auto plan_node = [&]() -> std::unique_ptr { diff --git a/internal/core/src/query/groupby/SearchGroupByOperator.cpp b/internal/core/src/query/groupby/SearchGroupByOperator.cpp index 7b04f9cd2f..1650e55a8e 100644 --- a/internal/core/src/query/groupby/SearchGroupByOperator.cpp +++ b/internal/core/src/query/groupby/SearchGroupByOperator.cpp @@ -44,6 +44,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -58,6 +59,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -72,6 +74,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -86,6 +89,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -99,6 +103,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -113,6 +118,7 @@ SearchGroupBy(const std::vector>& iterators, GroupIteratorsByType(iterators, search_info.topk_, search_info.group_size_, + search_info.group_strict_size_, *dataGetter, group_by_values, seg_offsets, @@ -136,6 +142,7 @@ GroupIteratorsByType( const std::vector>& iterators, int64_t topK, int64_t group_size, + bool group_strict_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& seg_offsets, @@ -147,6 +154,7 @@ GroupIteratorsByType( GroupIteratorResult(iterator, topK, group_size, + group_strict_size, data_getter, group_by_values, seg_offsets, @@ -161,13 +169,14 @@ void GroupIteratorResult(const std::shared_ptr& iterator, int64_t topK, int64_t group_size, + bool group_strict_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& offsets, std::vector& distances, const knowhere::MetricType& metrics_type) { //1. - GroupByMap groupMap(topK, group_size); + GroupByMap groupMap(topK, group_size, group_strict_size); //2. do iteration until fill the whole map or run out of all data //note it may enumerate all data inside a segment and can block following @@ -195,8 +204,8 @@ GroupIteratorResult(const std::shared_ptr& iterator, //4. save groupBy results for (auto iter = res.cbegin(); iter != res.cend(); iter++) { - offsets.push_back(std::get<0>(*iter)); - distances.push_back(std::get<1>(*iter)); + offsets.emplace_back(std::get<0>(*iter)); + distances.emplace_back(std::get<1>(*iter)); group_by_values.emplace_back(std::move(std::get<2>(*iter))); } } diff --git a/internal/core/src/query/groupby/SearchGroupByOperator.h b/internal/core/src/query/groupby/SearchGroupByOperator.h index dfc51d318e..f3513ab882 100644 --- a/internal/core/src/query/groupby/SearchGroupByOperator.h +++ b/internal/core/src/query/groupby/SearchGroupByOperator.h @@ -182,6 +182,7 @@ GroupIteratorsByType( const std::vector>& iterators, int64_t topK, int64_t group_size, + bool group_strict_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& seg_offsets, @@ -195,19 +196,31 @@ struct GroupByMap { std::unordered_map group_map_{}; int group_capacity_{0}; int group_size_{0}; - int enough_group_count{0}; + int enough_group_count_{0}; + bool strict_group_size_{false}; public: - GroupByMap(int group_capacity, int group_size) - : group_capacity_(group_capacity), group_size_(group_size){}; + GroupByMap(int group_capacity, + int group_size, + bool strict_group_size = false) + : group_capacity_(group_capacity), + group_size_(group_size), + strict_group_size_(strict_group_size){}; bool IsGroupResEnough() { - return group_map_.size() == group_capacity_ && - enough_group_count == group_capacity_; + bool enough = false; + if (strict_group_size_) { + enough = group_map_.size() == group_capacity_ && + enough_group_count_ == group_capacity_; + } else { + enough = group_map_.size() == group_capacity_; + } + return enough; } bool Push(const T& t) { - if (group_map_.size() >= group_capacity_ && group_map_[t] == 0) { + if (group_map_.size() >= group_capacity_ && + group_map_.find(t) == group_map_.end()) { return false; } if (group_map_[t] >= group_size_) { @@ -218,7 +231,7 @@ struct GroupByMap { } group_map_[t] += 1; if (group_map_[t] >= group_size_) { - enough_group_count += 1; + enough_group_count_ += 1; } return true; } @@ -229,6 +242,7 @@ void GroupIteratorResult(const std::shared_ptr& iterator, int64_t topK, int64_t group_size, + bool group_strict_size, const DataGetter& data_getter, std::vector& group_by_values, std::vector& offsets, diff --git a/internal/core/unittest/test_group_by.cpp b/internal/core/unittest/test_group_by.cpp index c06ff6e558..0d334cbe51 100644 --- a/internal/core/unittest/test_group_by.cpp +++ b/internal/core/unittest/test_group_by.cpp @@ -474,6 +474,7 @@ TEST(GroupBY, SealedData) { search_params: "{\"ef\": 10}" group_by_field_id: 101, group_size: 5, + group_strict_size: true, > placeholder_tag: "$0" @@ -796,6 +797,7 @@ TEST(GroupBY, GrowingIndex) { search_params: "{\"ef\": 10}" group_by_field_id: 101 group_size: 3 + group_strict_size: true > placeholder_tag: "$0" diff --git a/internal/proto/plan.proto b/internal/proto/plan.proto index e551a242b5..e9d19f0193 100644 --- a/internal/proto/plan.proto +++ b/internal/proto/plan.proto @@ -62,6 +62,7 @@ message QueryInfo { int64 group_by_field_id = 6; bool materialized_view_involved = 7; int64 group_size = 8; + bool group_strict_size = 9; } message ColumnInfo { diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 06f2ff4a0a..ad6a90dd08 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -129,6 +129,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } } + var groupStrictSize bool + groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair) + if err != nil { + groupStrictSize = false + } else { + groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr) + if err != nil { + groupStrictSize = false + } + } + // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search if isIterator == "True" && groupByFieldId > 0 { return nil, 0, merr.WrapErrParameterInvalid("", "", @@ -140,12 +151,13 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } return &planpb.QueryInfo{ - Topk: queryTopK, - MetricType: metricType, - SearchParams: searchParamStr, - RoundDecimal: roundDecimal, - GroupByFieldId: groupByFieldId, - GroupSize: groupSize, + Topk: queryTopK, + MetricType: metricType, + SearchParams: searchParamStr, + RoundDecimal: roundDecimal, + GroupByFieldId: groupByFieldId, + GroupSize: groupSize, + GroupStrictSize: groupStrictSize, }, offset, nil } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 71bb162789..2b4fad685a 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -48,6 +48,7 @@ const ( IteratorField = "iterator" GroupByFieldKey = "group_by_field" GroupSizeKey = "group_size" + GroupStrictSize = "group_strict_size" AnnsFieldKey = "anns_field" TopKKey = "topk" NQKey = "nq"