mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 01:58:34 +08:00
enhance: make search groupby stop when reaching topk groups (#35814)
related: #33544

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
This commit is contained in:
parent
57422cb2ed
commit
4641fd9195
@ -27,6 +27,7 @@ namespace milvus {
|
|||||||
struct SearchInfo {
|
struct SearchInfo {
|
||||||
int64_t topk_{0};
|
int64_t topk_{0};
|
||||||
int64_t group_size_{1};
|
int64_t group_size_{1};
|
||||||
|
bool group_strict_size_{false};
|
||||||
int64_t round_decimal_{0};
|
int64_t round_decimal_{0};
|
||||||
FieldId field_id_;
|
FieldId field_id_;
|
||||||
MetricType metric_type_;
|
MetricType metric_type_;
|
||||||
|
|||||||
@ -212,6 +212,7 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
|||||||
search_info.group_size_ = query_info_proto.group_size() > 0
|
search_info.group_size_ = query_info_proto.group_size() > 0
|
||||||
? query_info_proto.group_size()
|
? query_info_proto.group_size()
|
||||||
: 1;
|
: 1;
|
||||||
|
search_info.group_strict_size_ = query_info_proto.group_strict_size();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
|
auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
|
||||||
|
|||||||
@ -44,6 +44,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<int8_t>(iterators,
|
GroupIteratorsByType<int8_t>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
|
search_info.group_strict_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -58,6 +59,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<int16_t>(iterators,
|
GroupIteratorsByType<int16_t>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
|
search_info.group_strict_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -72,6 +74,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<int32_t>(iterators,
|
GroupIteratorsByType<int32_t>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
|
search_info.group_strict_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -86,6 +89,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<int64_t>(iterators,
|
GroupIteratorsByType<int64_t>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
|
search_info.group_strict_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -99,6 +103,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<bool>(iterators,
|
GroupIteratorsByType<bool>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
|
search_info.group_strict_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -113,6 +118,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<std::string>(iterators,
|
GroupIteratorsByType<std::string>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
|
search_info.group_strict_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -136,6 +142,7 @@ GroupIteratorsByType(
|
|||||||
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
||||||
int64_t topK,
|
int64_t topK,
|
||||||
int64_t group_size,
|
int64_t group_size,
|
||||||
|
bool group_strict_size,
|
||||||
const DataGetter<T>& data_getter,
|
const DataGetter<T>& data_getter,
|
||||||
std::vector<GroupByValueType>& group_by_values,
|
std::vector<GroupByValueType>& group_by_values,
|
||||||
std::vector<int64_t>& seg_offsets,
|
std::vector<int64_t>& seg_offsets,
|
||||||
@ -147,6 +154,7 @@ GroupIteratorsByType(
|
|||||||
GroupIteratorResult<T>(iterator,
|
GroupIteratorResult<T>(iterator,
|
||||||
topK,
|
topK,
|
||||||
group_size,
|
group_size,
|
||||||
|
group_strict_size,
|
||||||
data_getter,
|
data_getter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -161,13 +169,14 @@ void
|
|||||||
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
||||||
int64_t topK,
|
int64_t topK,
|
||||||
int64_t group_size,
|
int64_t group_size,
|
||||||
|
bool group_strict_size,
|
||||||
const DataGetter<T>& data_getter,
|
const DataGetter<T>& data_getter,
|
||||||
std::vector<GroupByValueType>& group_by_values,
|
std::vector<GroupByValueType>& group_by_values,
|
||||||
std::vector<int64_t>& offsets,
|
std::vector<int64_t>& offsets,
|
||||||
std::vector<float>& distances,
|
std::vector<float>& distances,
|
||||||
const knowhere::MetricType& metrics_type) {
|
const knowhere::MetricType& metrics_type) {
|
||||||
//1.
|
//1.
|
||||||
GroupByMap<T> groupMap(topK, group_size);
|
GroupByMap<T> groupMap(topK, group_size, group_strict_size);
|
||||||
|
|
||||||
//2. do iteration until fill the whole map or run out of all data
|
//2. do iteration until fill the whole map or run out of all data
|
||||||
//note it may enumerate all data inside a segment and can block following
|
//note it may enumerate all data inside a segment and can block following
|
||||||
@ -195,8 +204,8 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
|||||||
|
|
||||||
//4. save groupBy results
|
//4. save groupBy results
|
||||||
for (auto iter = res.cbegin(); iter != res.cend(); iter++) {
|
for (auto iter = res.cbegin(); iter != res.cend(); iter++) {
|
||||||
offsets.push_back(std::get<0>(*iter));
|
offsets.emplace_back(std::get<0>(*iter));
|
||||||
distances.push_back(std::get<1>(*iter));
|
distances.emplace_back(std::get<1>(*iter));
|
||||||
group_by_values.emplace_back(std::move(std::get<2>(*iter)));
|
group_by_values.emplace_back(std::move(std::get<2>(*iter)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -182,6 +182,7 @@ GroupIteratorsByType(
|
|||||||
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
||||||
int64_t topK,
|
int64_t topK,
|
||||||
int64_t group_size,
|
int64_t group_size,
|
||||||
|
bool group_strict_size,
|
||||||
const DataGetter<T>& data_getter,
|
const DataGetter<T>& data_getter,
|
||||||
std::vector<GroupByValueType>& group_by_values,
|
std::vector<GroupByValueType>& group_by_values,
|
||||||
std::vector<int64_t>& seg_offsets,
|
std::vector<int64_t>& seg_offsets,
|
||||||
@ -195,19 +196,31 @@ struct GroupByMap {
|
|||||||
std::unordered_map<T, int> group_map_{};
|
std::unordered_map<T, int> group_map_{};
|
||||||
int group_capacity_{0};
|
int group_capacity_{0};
|
||||||
int group_size_{0};
|
int group_size_{0};
|
||||||
int enough_group_count{0};
|
int enough_group_count_{0};
|
||||||
|
bool strict_group_size_{false};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
GroupByMap(int group_capacity, int group_size)
|
GroupByMap(int group_capacity,
|
||||||
: group_capacity_(group_capacity), group_size_(group_size){};
|
int group_size,
|
||||||
|
bool strict_group_size = false)
|
||||||
|
: group_capacity_(group_capacity),
|
||||||
|
group_size_(group_size),
|
||||||
|
strict_group_size_(strict_group_size){};
|
||||||
bool
|
bool
|
||||||
IsGroupResEnough() {
|
IsGroupResEnough() {
|
||||||
return group_map_.size() == group_capacity_ &&
|
bool enough = false;
|
||||||
enough_group_count == group_capacity_;
|
if (strict_group_size_) {
|
||||||
|
enough = group_map_.size() == group_capacity_ &&
|
||||||
|
enough_group_count_ == group_capacity_;
|
||||||
|
} else {
|
||||||
|
enough = group_map_.size() == group_capacity_;
|
||||||
|
}
|
||||||
|
return enough;
|
||||||
}
|
}
|
||||||
bool
|
bool
|
||||||
Push(const T& t) {
|
Push(const T& t) {
|
||||||
if (group_map_.size() >= group_capacity_ && group_map_[t] == 0) {
|
if (group_map_.size() >= group_capacity_ &&
|
||||||
|
group_map_.find(t) == group_map_.end()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (group_map_[t] >= group_size_) {
|
if (group_map_[t] >= group_size_) {
|
||||||
@ -218,7 +231,7 @@ struct GroupByMap {
|
|||||||
}
|
}
|
||||||
group_map_[t] += 1;
|
group_map_[t] += 1;
|
||||||
if (group_map_[t] >= group_size_) {
|
if (group_map_[t] >= group_size_) {
|
||||||
enough_group_count += 1;
|
enough_group_count_ += 1;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -229,6 +242,7 @@ void
|
|||||||
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
||||||
int64_t topK,
|
int64_t topK,
|
||||||
int64_t group_size,
|
int64_t group_size,
|
||||||
|
bool group_strict_size,
|
||||||
const DataGetter<T>& data_getter,
|
const DataGetter<T>& data_getter,
|
||||||
std::vector<GroupByValueType>& group_by_values,
|
std::vector<GroupByValueType>& group_by_values,
|
||||||
std::vector<int64_t>& offsets,
|
std::vector<int64_t>& offsets,
|
||||||
|
|||||||
@ -474,6 +474,7 @@ TEST(GroupBY, SealedData) {
|
|||||||
search_params: "{\"ef\": 10}"
|
search_params: "{\"ef\": 10}"
|
||||||
group_by_field_id: 101,
|
group_by_field_id: 101,
|
||||||
group_size: 5,
|
group_size: 5,
|
||||||
|
group_strict_size: true,
|
||||||
>
|
>
|
||||||
placeholder_tag: "$0"
|
placeholder_tag: "$0"
|
||||||
|
|
||||||
@ -796,6 +797,7 @@ TEST(GroupBY, GrowingIndex) {
|
|||||||
search_params: "{\"ef\": 10}"
|
search_params: "{\"ef\": 10}"
|
||||||
group_by_field_id: 101
|
group_by_field_id: 101
|
||||||
group_size: 3
|
group_size: 3
|
||||||
|
group_strict_size: true
|
||||||
>
|
>
|
||||||
placeholder_tag: "$0"
|
placeholder_tag: "$0"
|
||||||
|
|
||||||
|
|||||||
@ -62,6 +62,7 @@ message QueryInfo {
|
|||||||
int64 group_by_field_id = 6;
|
int64 group_by_field_id = 6;
|
||||||
bool materialized_view_involved = 7;
|
bool materialized_view_involved = 7;
|
||||||
int64 group_size = 8;
|
int64 group_size = 8;
|
||||||
|
bool group_strict_size = 9;
|
||||||
}
|
}
|
||||||
|
|
||||||
message ColumnInfo {
|
message ColumnInfo {
|
||||||
|
|||||||
@ -129,6 +129,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var groupStrictSize bool
|
||||||
|
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
|
||||||
|
if err != nil {
|
||||||
|
groupStrictSize = false
|
||||||
|
} else {
|
||||||
|
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
|
||||||
|
if err != nil {
|
||||||
|
groupStrictSize = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||||
if isIterator == "True" && groupByFieldId > 0 {
|
if isIterator == "True" && groupByFieldId > 0 {
|
||||||
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
return nil, 0, merr.WrapErrParameterInvalid("", "",
|
||||||
@ -146,6 +157,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||||||
RoundDecimal: roundDecimal,
|
RoundDecimal: roundDecimal,
|
||||||
GroupByFieldId: groupByFieldId,
|
GroupByFieldId: groupByFieldId,
|
||||||
GroupSize: groupSize,
|
GroupSize: groupSize,
|
||||||
|
GroupStrictSize: groupStrictSize,
|
||||||
}, offset, nil
|
}, offset, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -48,6 +48,7 @@ const (
|
|||||||
IteratorField = "iterator"
|
IteratorField = "iterator"
|
||||||
GroupByFieldKey = "group_by_field"
|
GroupByFieldKey = "group_by_field"
|
||||||
GroupSizeKey = "group_size"
|
GroupSizeKey = "group_size"
|
||||||
|
GroupStrictSize = "group_strict_size"
|
||||||
AnnsFieldKey = "anns_field"
|
AnnsFieldKey = "anns_field"
|
||||||
TopKKey = "topk"
|
TopKKey = "topk"
|
||||||
NQKey = "nq"
|
NQKey = "nq"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user