enhance:refine group_strict_size parameter(#37482) (#37483)

related: #37482

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
This commit is contained in:
Chun Han 2024-11-12 09:56:28 +08:00 committed by GitHub
parent 21b68029a0
commit 2d29dcd30c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 60 additions and 60 deletions

View File

@ -27,7 +27,7 @@ namespace milvus {
struct SearchInfo { struct SearchInfo {
int64_t topk_{0}; int64_t topk_{0};
int64_t group_size_{1}; int64_t group_size_{1};
bool group_strict_size_{false}; bool strict_group_size_{false};
int64_t round_decimal_{0}; int64_t round_decimal_{0};
FieldId field_id_; FieldId field_id_;
MetricType metric_type_; MetricType metric_type_;

View File

@ -44,7 +44,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int8_t>(iterators, GroupIteratorsByType<int8_t>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_, search_info.group_size_,
search_info.group_strict_size_, search_info.strict_group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
@ -59,7 +59,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int16_t>(iterators, GroupIteratorsByType<int16_t>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_, search_info.group_size_,
search_info.group_strict_size_, search_info.strict_group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
@ -74,7 +74,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int32_t>(iterators, GroupIteratorsByType<int32_t>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_, search_info.group_size_,
search_info.group_strict_size_, search_info.strict_group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
@ -89,7 +89,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<int64_t>(iterators, GroupIteratorsByType<int64_t>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_, search_info.group_size_,
search_info.group_strict_size_, search_info.strict_group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
@ -103,7 +103,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<bool>(iterators, GroupIteratorsByType<bool>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_, search_info.group_size_,
search_info.group_strict_size_, search_info.strict_group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
@ -118,7 +118,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GroupIteratorsByType<std::string>(iterators, GroupIteratorsByType<std::string>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_, search_info.group_size_,
search_info.group_strict_size_, search_info.strict_group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
@ -142,7 +142,7 @@ GroupIteratorsByType(
const std::vector<std::shared_ptr<VectorIterator>>& iterators, const std::vector<std::shared_ptr<VectorIterator>>& iterators,
int64_t topK, int64_t topK,
int64_t group_size, int64_t group_size,
bool group_strict_size, bool strict_group_size,
const DataGetter<T>& data_getter, const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets, std::vector<int64_t>& seg_offsets,
@ -154,7 +154,7 @@ GroupIteratorsByType(
GroupIteratorResult<T>(iterator, GroupIteratorResult<T>(iterator,
topK, topK,
group_size, group_size,
group_strict_size, strict_group_size,
data_getter, data_getter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
@ -169,14 +169,14 @@ void
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator, GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t topK, int64_t topK,
int64_t group_size, int64_t group_size,
bool group_strict_size, bool strict_group_size,
const DataGetter<T>& data_getter, const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets, std::vector<int64_t>& offsets,
std::vector<float>& distances, std::vector<float>& distances,
const knowhere::MetricType& metrics_type) { const knowhere::MetricType& metrics_type) {
//1. //1.
GroupByMap<T> groupMap(topK, group_size, group_strict_size); GroupByMap<T> groupMap(topK, group_size, strict_group_size);
//2. do iteration until fill the whole map or run out of all data //2. do iteration until fill the whole map or run out of all data
//note it may enumerate all data inside a segment and can block following //note it may enumerate all data inside a segment and can block following

View File

@ -183,7 +183,7 @@ GroupIteratorsByType(
const std::vector<std::shared_ptr<VectorIterator>>& iterators, const std::vector<std::shared_ptr<VectorIterator>>& iterators,
int64_t topK, int64_t topK,
int64_t group_size, int64_t group_size,
bool group_strict_size, bool strict_group_size,
const DataGetter<T>& data_getter, const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets, std::vector<int64_t>& seg_offsets,
@ -243,7 +243,7 @@ void
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator, GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t topK, int64_t topK,
int64_t group_size, int64_t group_size,
bool group_strict_size, bool strict_group_size,
const DataGetter<T>& data_getter, const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets, std::vector<int64_t>& offsets,

View File

@ -66,8 +66,8 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
search_info.group_size_ = query_info_proto.group_size() > 0 search_info.group_size_ = query_info_proto.group_size() > 0
? query_info_proto.group_size() ? query_info_proto.group_size()
: 1; : 1;
search_info.group_strict_size_ = search_info.strict_group_size_ =
query_info_proto.group_strict_size(); query_info_proto.strict_group_size();
} }
return search_info; return search_info;
}; };

View File

@ -474,7 +474,7 @@ TEST(GroupBY, SealedData) {
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 101, group_by_field_id: 101,
group_size: 5, group_size: 5,
group_strict_size: true, strict_group_size: true,
> >
placeholder_tag: "$0" placeholder_tag: "$0"
@ -797,7 +797,7 @@ TEST(GroupBY, GrowingIndex) {
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 101 group_by_field_id: 101
group_size: 3 group_size: 3
group_strict_size: true strict_group_size: true
> >
placeholder_tag: "$0" placeholder_tag: "$0"

View File

@ -154,6 +154,6 @@ const (
ParamRangeFilter = "range_filter" ParamRangeFilter = "range_filter"
ParamGroupByField = "group_by_field" ParamGroupByField = "group_by_field"
ParamGroupSize = "group_size" ParamGroupSize = "group_size"
ParamGroupStrictSize = "group_strict_size" ParamStrictGroupSize = "strict_group_size"
BoundedTimestamp = 2 BoundedTimestamp = 2
) )

View File

@ -1002,7 +1002,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
} }
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 { if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamStrictGroupSize, Value: strconv.FormatBool(httpReq.StrictGroupSize)})
} }
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField}) searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
body, _ := c.Get(gin.BodyBytesKey) body, _ := c.Get(gin.BodyBytesKey)
@ -1107,7 +1107,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
} }
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 { if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)}) req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)}) req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamStrictGroupSize, Value: strconv.FormatBool(httpReq.StrictGroupSize)})
} }
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest)) return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))

View File

@ -157,7 +157,7 @@ type SearchReqV2 struct {
Filter string `json:"filter"` Filter string `json:"filter"`
GroupByField string `json:"groupingField"` GroupByField string `json:"groupingField"`
GroupSize int32 `json:"groupSize"` GroupSize int32 `json:"groupSize"`
GroupStrictSize bool `json:"groupStrictSize"` StrictGroupSize bool `json:"strictGroupSize"`
Limit int32 `json:"limit"` Limit int32 `json:"limit"`
Offset int32 `json:"offset"` Offset int32 `json:"offset"`
OutputFields []string `json:"outputFields"` OutputFields []string `json:"outputFields"`
@ -194,7 +194,7 @@ type HybridSearchReq struct {
Limit int32 `json:"limit"` Limit int32 `json:"limit"`
GroupByField string `json:"groupingField"` GroupByField string `json:"groupingField"`
GroupSize int32 `json:"groupSize"` GroupSize int32 `json:"groupSize"`
GroupStrictSize bool `json:"groupStrictSize"` StrictGroupSize bool `json:"strictGroupSize"`
OutputFields []string `json:"outputFields"` OutputFields []string `json:"outputFields"`
ConsistencyLevel string `json:"consistencyLevel"` ConsistencyLevel string `json:"consistencyLevel"`
} }

View File

@ -63,7 +63,7 @@ message QueryInfo {
int64 group_by_field_id = 6; int64 group_by_field_id = 6;
bool materialized_view_involved = 7; bool materialized_view_involved = 7;
int64 group_size = 8; int64 group_size = 8;
bool group_strict_size = 9; bool strict_group_size = 9;
double bm25_avgdl = 10; double bm25_avgdl = 10;
int64 query_field_id =11; int64 query_field_id =11;
} }

View File

@ -26,7 +26,7 @@ type rankParams struct {
roundDecimal int64 roundDecimal int64
groupByFieldId int64 groupByFieldId int64
groupSize int64 groupSize int64
groupStrictSize bool strictGroupSize bool
} }
func (r *rankParams) GetLimit() int64 { func (r *rankParams) GetLimit() int64 {
@ -64,9 +64,9 @@ func (r *rankParams) GetGroupSize() int64 {
return 1 return 1
} }
func (r *rankParams) GetGroupStrictSize() bool { func (r *rankParams) GetStrictGroupSize() bool {
if r != nil { if r != nil {
return r.groupStrictSize return r.strictGroupSize
} }
return false return false
} }
@ -170,15 +170,15 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
// 5. parse group by field and group by size // 5. parse group by field and group by size
var groupByFieldId, groupSize int64 var groupByFieldId, groupSize int64
var groupStrictSize bool var strictGroupSize bool
if isAdvanced { if isAdvanced {
groupByFieldId, groupSize, groupStrictSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetGroupStrictSize() groupByFieldId, groupSize, strictGroupSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetStrictGroupSize()
} else { } else {
groupByInfo := parseGroupByInfo(searchParamsPair, schema) groupByInfo := parseGroupByInfo(searchParamsPair, schema)
if groupByInfo.err != nil { if groupByInfo.err != nil {
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err} return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err}
} }
groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize() groupByFieldId, groupSize, strictGroupSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetStrictGroupSize()
} }
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
@ -199,7 +199,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
RoundDecimal: roundDecimal, RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId, GroupByFieldId: groupByFieldId,
GroupSize: groupSize, GroupSize: groupSize,
GroupStrictSize: groupStrictSize, StrictGroupSize: strictGroupSize,
}, },
offset: offset, offset: offset,
isIterator: isIterator, isIterator: isIterator,
@ -303,7 +303,7 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string,
type groupByInfo struct { type groupByInfo struct {
groupByFieldId int64 groupByFieldId int64
groupSize int64 groupSize int64
groupStrictSize bool strictGroupSize bool
err error err error
} }
@ -321,9 +321,9 @@ func (g *groupByInfo) GetGroupSize() int64 {
return 0 return 0
} }
func (g *groupByInfo) GetGroupStrictSize() bool { func (g *groupByInfo) GetStrictGroupSize() bool {
if g != nil { if g != nil {
return g.groupStrictSize return g.strictGroupSize
} }
return false return false
} }
@ -389,17 +389,17 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap
ret.groupSize = groupSize ret.groupSize = groupSize
// 3. parse group strict size // 3. parse group strict size
var groupStrictSize bool var strictGroupSize bool
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair) strictGroupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(StrictGroupSize, searchParamsPair)
if err != nil { if err != nil {
groupStrictSize = false strictGroupSize = false
} else { } else {
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr) strictGroupSize, err = strconv.ParseBool(strictGroupSizeStr)
if err != nil { if err != nil {
groupStrictSize = false strictGroupSize = false
} }
} }
ret.groupStrictSize = groupStrictSize ret.strictGroupSize = strictGroupSize
return ret return ret
} }
@ -460,7 +460,7 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.C
roundDecimal: roundDecimal, roundDecimal: roundDecimal,
groupByFieldId: groupByInfo.GetGroupByFieldId(), groupByFieldId: groupByInfo.GetGroupByFieldId(),
groupSize: groupByInfo.GetGroupSize(), groupSize: groupByInfo.GetGroupSize(),
groupStrictSize: groupByInfo.GetGroupStrictSize(), strictGroupSize: groupByInfo.GetStrictGroupSize(),
}, nil }, nil
} }

View File

@ -55,7 +55,7 @@ const (
IteratorField = "iterator" IteratorField = "iterator"
GroupByFieldKey = "group_by_field" GroupByFieldKey = "group_by_field"
GroupSizeKey = "group_size" GroupSizeKey = "group_size"
GroupStrictSize = "group_strict_size" StrictGroupSize = "strict_group_size"
RankGroupScorer = "rank_group_scorer" RankGroupScorer = "rank_group_scorer"
AnnsFieldKey = "anns_field" AnnsFieldKey = "anns_field"
TopKKey = "topk" TopKKey = "topk"

View File

@ -2325,7 +2325,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
}, },
} }
// 1. first parse rank params // 1. first parse rank params
// outer params require to group by field 101 and groupSize=3 and groupStrictSize=false // outer params require to group by field 101 and groupSize=3 and strictGroupSize=false
testRankParamsPairs := getValidSearchParams() testRankParamsPairs := getValidSearchParams()
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{ testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
Key: GroupByFieldKey, Key: GroupByFieldKey,
@ -2336,7 +2336,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
Value: strconv.FormatInt(3, 10), Value: strconv.FormatInt(3, 10),
}) })
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{ testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
Key: GroupStrictSize, Key: StrictGroupSize,
Value: "false", Value: "false",
}) })
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{ testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
@ -2348,7 +2348,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
// 2. parse search params for sub request in hybridsearch // 2. parse search params for sub request in hybridsearch
params := getValidSearchParams() params := getValidSearchParams()
// inner params require to group by field 103 and groupSize=10 and groupStrictSize=true // inner params require to group by field 103 and groupSize=10 and strictGroupSize=true
params = append(params, &commonpb.KeyValuePair{ params = append(params, &commonpb.KeyValuePair{
Key: GroupByFieldKey, Key: GroupByFieldKey,
Value: "c3", Value: "c3",
@ -2358,7 +2358,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
Value: strconv.FormatInt(10, 10), Value: strconv.FormatInt(10, 10),
}) })
params = append(params, &commonpb.KeyValuePair{ params = append(params, &commonpb.KeyValuePair{
Key: GroupStrictSize, Key: StrictGroupSize,
Value: "true", Value: "true",
}) })
@ -2370,7 +2370,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
// set by main request rather than inner sub request // set by main request rather than inner sub request
assert.Equal(t, int64(101), searchInfo.planInfo.GetGroupByFieldId()) assert.Equal(t, int64(101), searchInfo.planInfo.GetGroupByFieldId())
assert.Equal(t, int64(3), searchInfo.planInfo.GetGroupSize()) assert.Equal(t, int64(3), searchInfo.planInfo.GetGroupSize())
assert.False(t, searchInfo.planInfo.GetGroupStrictSize()) assert.False(t, searchInfo.planInfo.GetStrictGroupSize())
}) })
t.Run("parseSearchInfo error", func(t *testing.T) { t.Run("parseSearchInfo error", func(t *testing.T) {

View File

@ -27,8 +27,8 @@ pytest-parallel
pytest-random-order pytest-random-order
# pymilvus # pymilvus
pymilvus==2.5.0rc108 pymilvus==2.5.0rc111
pymilvus[bulk_writer]==2.5.0rc108 pymilvus[bulk_writer]==2.5.0rc111
# for customize config test # for customize config test
python-benedict==0.24.3 python-benedict==0.24.3

View File

@ -2203,8 +2203,8 @@ class TestGroupSearch(TestCaseClassBase):
""" """
target: target:
1. search on 4 different float vector fields with group by varchar field with group size 1. search on 4 different float vector fields with group by varchar field with group size
verify results entity = limit * group_size and group size is full if group_strict_size is True verify results entity = limit * group_size and group size is full if strict_group_size is True
verify results group counts = limit if group_strict_size is False verify results group counts = limit if strict_group_size is False
""" """
nq = 2 nq = 2
limit = 50 limit = 50
@ -2212,11 +2212,11 @@ class TestGroupSearch(TestCaseClassBase):
for j in range(len(self.vector_fields)): for j in range(len(self.vector_fields)):
search_vectors = cf.gen_vectors(nq, dim=self.dims[j], vector_data_type=self.vector_fields[j]) search_vectors = cf.gen_vectors(nq, dim=self.dims[j], vector_data_type=self.vector_fields[j])
search_params = {"params": cf.get_search_params_params(self.index_types[j])} search_params = {"params": cf.get_search_params_params(self.index_types[j])}
# when group_strict_size=true, it shall return results with entities = limit * group_size # when strict_group_size=true, it shall return results with entities = limit * group_size
res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j], res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
param=search_params, limit=limit, param=search_params, limit=limit,
group_by_field=group_by_field, group_by_field=group_by_field,
group_size=group_size, group_strict_size=True, group_size=group_size, strict_group_size=True,
output_fields=[group_by_field])[0] output_fields=[group_by_field])[0]
for i in range(nq): for i in range(nq):
assert len(res1[i]) == limit * group_size assert len(res1[i]) == limit * group_size
@ -2226,11 +2226,11 @@ class TestGroupSearch(TestCaseClassBase):
group_values.append(res1[i][l*group_size+k].fields.get(group_by_field)) group_values.append(res1[i][l*group_size+k].fields.get(group_by_field))
assert len(set(group_values)) == 1 assert len(set(group_values)) == 1
# when group_strict_size=false, it shall return results with group counts = limit # when strict_group_size=false, it shall return results with group counts = limit
res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j], res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
param=search_params, limit=limit, param=search_params, limit=limit,
group_by_field=group_by_field, group_by_field=group_by_field,
group_size=group_size, group_strict_size=False, group_size=group_size, strict_group_size=False,
output_fields=[group_by_field])[0] output_fields=[group_by_field])[0]
for i in range(nq): for i in range(nq):
group_values = [] group_values = []
@ -2438,7 +2438,7 @@ class TestGroupSearch(TestCaseClassBase):
param=search_param, limit=limit, offset=limit * r, param=search_param, limit=limit, offset=limit * r,
expr=default_search_exp, expr=default_search_exp,
group_by_field=grpby_field, group_size=group_size, group_by_field=grpby_field, group_size=group_size,
group_strict_size=True, strict_group_size=True,
output_fields=[grpby_field], output_fields=[grpby_field],
check_task=CheckTasks.check_search_results, check_task=CheckTasks.check_search_results,
check_items={"nq": 1, "limit": res_count}, check_items={"nq": 1, "limit": res_count},
@ -2459,7 +2459,7 @@ class TestGroupSearch(TestCaseClassBase):
param=search_param, limit=limit * page_rounds, param=search_param, limit=limit * page_rounds,
expr=default_search_exp, expr=default_search_exp,
group_by_field=grpby_field, group_size=group_size, group_by_field=grpby_field, group_size=group_size,
group_strict_size=True, strict_group_size=True,
output_fields=[grpby_field], output_fields=[grpby_field],
check_task=CheckTasks.check_search_results, check_task=CheckTasks.check_search_results,
check_items={"nq": 1, "limit": total_count} check_items={"nq": 1, "limit": total_count}
@ -2488,7 +2488,7 @@ class TestGroupSearch(TestCaseClassBase):
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field, self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
param=search_params, limit=limit, param=search_params, limit=limit,
group_by_field=group_by_field, group_by_field=group_by_field,
group_size=max_group_size, group_strict_size=True, group_size=max_group_size, strict_group_size=True,
output_fields=[group_by_field]) output_fields=[group_by_field])
exceed_max_group_size = max_group_size + 1 exceed_max_group_size = max_group_size + 1
error = {ct.err_code: 999, error = {ct.err_code: 999,
@ -2497,7 +2497,7 @@ class TestGroupSearch(TestCaseClassBase):
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field, self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
param=search_params, limit=limit, param=search_params, limit=limit,
group_by_field=group_by_field, group_by_field=group_by_field,
group_size=exceed_max_group_size, group_strict_size=True, group_size=exceed_max_group_size, strict_group_size=True,
output_fields=[group_by_field], output_fields=[group_by_field],
check_task=CheckTasks.err_res, check_items=error) check_task=CheckTasks.err_res, check_items=error)
@ -2505,7 +2505,7 @@ class TestGroupSearch(TestCaseClassBase):
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field, self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
param=search_params, limit=limit, param=search_params, limit=limit,
group_by_field=group_by_field, group_by_field=group_by_field,
group_size=max_group_size, group_strict_size=True, group_size=max_group_size, strict_group_size=True,
output_fields=[group_by_field]) output_fields=[group_by_field])
below_min_group_size = min_group_size - 1 below_min_group_size = min_group_size - 1
error = {ct.err_code: 999, error = {ct.err_code: 999,
@ -2513,6 +2513,6 @@ class TestGroupSearch(TestCaseClassBase):
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field, self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
param=search_params, limit=limit, param=search_params, limit=limit,
group_by_field=group_by_field, group_by_field=group_by_field,
group_size=below_min_group_size, group_strict_size=True, group_size=below_min_group_size, strict_group_size=True,
output_fields=[group_by_field], output_fields=[group_by_field],
check_task=CheckTasks.err_res, check_items=error) check_task=CheckTasks.err_res, check_items=error)