mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
related: #37482 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>
This commit is contained in:
parent
21b68029a0
commit
2d29dcd30c
@ -27,7 +27,7 @@ namespace milvus {
|
|||||||
struct SearchInfo {
|
struct SearchInfo {
|
||||||
int64_t topk_{0};
|
int64_t topk_{0};
|
||||||
int64_t group_size_{1};
|
int64_t group_size_{1};
|
||||||
bool group_strict_size_{false};
|
bool strict_group_size_{false};
|
||||||
int64_t round_decimal_{0};
|
int64_t round_decimal_{0};
|
||||||
FieldId field_id_;
|
FieldId field_id_;
|
||||||
MetricType metric_type_;
|
MetricType metric_type_;
|
||||||
|
|||||||
@ -44,7 +44,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<int8_t>(iterators,
|
GroupIteratorsByType<int8_t>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
search_info.group_strict_size_,
|
search_info.strict_group_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -59,7 +59,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<int16_t>(iterators,
|
GroupIteratorsByType<int16_t>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
search_info.group_strict_size_,
|
search_info.strict_group_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -74,7 +74,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<int32_t>(iterators,
|
GroupIteratorsByType<int32_t>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
search_info.group_strict_size_,
|
search_info.strict_group_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -89,7 +89,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<int64_t>(iterators,
|
GroupIteratorsByType<int64_t>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
search_info.group_strict_size_,
|
search_info.strict_group_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -103,7 +103,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<bool>(iterators,
|
GroupIteratorsByType<bool>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
search_info.group_strict_size_,
|
search_info.strict_group_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -118,7 +118,7 @@ SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
|||||||
GroupIteratorsByType<std::string>(iterators,
|
GroupIteratorsByType<std::string>(iterators,
|
||||||
search_info.topk_,
|
search_info.topk_,
|
||||||
search_info.group_size_,
|
search_info.group_size_,
|
||||||
search_info.group_strict_size_,
|
search_info.strict_group_size_,
|
||||||
*dataGetter,
|
*dataGetter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -142,7 +142,7 @@ GroupIteratorsByType(
|
|||||||
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
||||||
int64_t topK,
|
int64_t topK,
|
||||||
int64_t group_size,
|
int64_t group_size,
|
||||||
bool group_strict_size,
|
bool strict_group_size,
|
||||||
const DataGetter<T>& data_getter,
|
const DataGetter<T>& data_getter,
|
||||||
std::vector<GroupByValueType>& group_by_values,
|
std::vector<GroupByValueType>& group_by_values,
|
||||||
std::vector<int64_t>& seg_offsets,
|
std::vector<int64_t>& seg_offsets,
|
||||||
@ -154,7 +154,7 @@ GroupIteratorsByType(
|
|||||||
GroupIteratorResult<T>(iterator,
|
GroupIteratorResult<T>(iterator,
|
||||||
topK,
|
topK,
|
||||||
group_size,
|
group_size,
|
||||||
group_strict_size,
|
strict_group_size,
|
||||||
data_getter,
|
data_getter,
|
||||||
group_by_values,
|
group_by_values,
|
||||||
seg_offsets,
|
seg_offsets,
|
||||||
@ -169,14 +169,14 @@ void
|
|||||||
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
||||||
int64_t topK,
|
int64_t topK,
|
||||||
int64_t group_size,
|
int64_t group_size,
|
||||||
bool group_strict_size,
|
bool strict_group_size,
|
||||||
const DataGetter<T>& data_getter,
|
const DataGetter<T>& data_getter,
|
||||||
std::vector<GroupByValueType>& group_by_values,
|
std::vector<GroupByValueType>& group_by_values,
|
||||||
std::vector<int64_t>& offsets,
|
std::vector<int64_t>& offsets,
|
||||||
std::vector<float>& distances,
|
std::vector<float>& distances,
|
||||||
const knowhere::MetricType& metrics_type) {
|
const knowhere::MetricType& metrics_type) {
|
||||||
//1.
|
//1.
|
||||||
GroupByMap<T> groupMap(topK, group_size, group_strict_size);
|
GroupByMap<T> groupMap(topK, group_size, strict_group_size);
|
||||||
|
|
||||||
//2. do iteration until fill the whole map or run out of all data
|
//2. do iteration until fill the whole map or run out of all data
|
||||||
//note it may enumerate all data inside a segment and can block following
|
//note it may enumerate all data inside a segment and can block following
|
||||||
|
|||||||
@ -183,7 +183,7 @@ GroupIteratorsByType(
|
|||||||
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
const std::vector<std::shared_ptr<VectorIterator>>& iterators,
|
||||||
int64_t topK,
|
int64_t topK,
|
||||||
int64_t group_size,
|
int64_t group_size,
|
||||||
bool group_strict_size,
|
bool strict_group_size,
|
||||||
const DataGetter<T>& data_getter,
|
const DataGetter<T>& data_getter,
|
||||||
std::vector<GroupByValueType>& group_by_values,
|
std::vector<GroupByValueType>& group_by_values,
|
||||||
std::vector<int64_t>& seg_offsets,
|
std::vector<int64_t>& seg_offsets,
|
||||||
@ -243,7 +243,7 @@ void
|
|||||||
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
|
||||||
int64_t topK,
|
int64_t topK,
|
||||||
int64_t group_size,
|
int64_t group_size,
|
||||||
bool group_strict_size,
|
bool strict_group_size,
|
||||||
const DataGetter<T>& data_getter,
|
const DataGetter<T>& data_getter,
|
||||||
std::vector<GroupByValueType>& group_by_values,
|
std::vector<GroupByValueType>& group_by_values,
|
||||||
std::vector<int64_t>& offsets,
|
std::vector<int64_t>& offsets,
|
||||||
|
|||||||
@ -66,8 +66,8 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
|
|||||||
search_info.group_size_ = query_info_proto.group_size() > 0
|
search_info.group_size_ = query_info_proto.group_size() > 0
|
||||||
? query_info_proto.group_size()
|
? query_info_proto.group_size()
|
||||||
: 1;
|
: 1;
|
||||||
search_info.group_strict_size_ =
|
search_info.strict_group_size_ =
|
||||||
query_info_proto.group_strict_size();
|
query_info_proto.strict_group_size();
|
||||||
}
|
}
|
||||||
return search_info;
|
return search_info;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -474,7 +474,7 @@ TEST(GroupBY, SealedData) {
|
|||||||
search_params: "{\"ef\": 10}"
|
search_params: "{\"ef\": 10}"
|
||||||
group_by_field_id: 101,
|
group_by_field_id: 101,
|
||||||
group_size: 5,
|
group_size: 5,
|
||||||
group_strict_size: true,
|
strict_group_size: true,
|
||||||
>
|
>
|
||||||
placeholder_tag: "$0"
|
placeholder_tag: "$0"
|
||||||
|
|
||||||
@ -797,7 +797,7 @@ TEST(GroupBY, GrowingIndex) {
|
|||||||
search_params: "{\"ef\": 10}"
|
search_params: "{\"ef\": 10}"
|
||||||
group_by_field_id: 101
|
group_by_field_id: 101
|
||||||
group_size: 3
|
group_size: 3
|
||||||
group_strict_size: true
|
strict_group_size: true
|
||||||
>
|
>
|
||||||
placeholder_tag: "$0"
|
placeholder_tag: "$0"
|
||||||
|
|
||||||
|
|||||||
@ -154,6 +154,6 @@ const (
|
|||||||
ParamRangeFilter = "range_filter"
|
ParamRangeFilter = "range_filter"
|
||||||
ParamGroupByField = "group_by_field"
|
ParamGroupByField = "group_by_field"
|
||||||
ParamGroupSize = "group_size"
|
ParamGroupSize = "group_size"
|
||||||
ParamGroupStrictSize = "group_strict_size"
|
ParamStrictGroupSize = "strict_group_size"
|
||||||
BoundedTimestamp = 2
|
BoundedTimestamp = 2
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1002,7 +1002,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN
|
|||||||
}
|
}
|
||||||
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
|
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
|
||||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
|
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
|
||||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)})
|
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: ParamStrictGroupSize, Value: strconv.FormatBool(httpReq.StrictGroupSize)})
|
||||||
}
|
}
|
||||||
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
|
searchParams = append(searchParams, &commonpb.KeyValuePair{Key: proxy.AnnsFieldKey, Value: httpReq.AnnsField})
|
||||||
body, _ := c.Get(gin.BodyBytesKey)
|
body, _ := c.Get(gin.BodyBytesKey)
|
||||||
@ -1107,7 +1107,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq
|
|||||||
}
|
}
|
||||||
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
|
if httpReq.GroupByField != "" && httpReq.GroupSize > 0 {
|
||||||
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
|
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupSize, Value: strconv.FormatInt(int64(httpReq.GroupSize), 10)})
|
||||||
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamGroupStrictSize, Value: strconv.FormatBool(httpReq.GroupStrictSize)})
|
req.RankParams = append(req.RankParams, &commonpb.KeyValuePair{Key: ParamStrictGroupSize, Value: strconv.FormatBool(httpReq.StrictGroupSize)})
|
||||||
}
|
}
|
||||||
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/HybridSearch", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) {
|
||||||
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))
|
return h.proxy.HybridSearch(reqCtx, req.(*milvuspb.HybridSearchRequest))
|
||||||
|
|||||||
@ -157,7 +157,7 @@ type SearchReqV2 struct {
|
|||||||
Filter string `json:"filter"`
|
Filter string `json:"filter"`
|
||||||
GroupByField string `json:"groupingField"`
|
GroupByField string `json:"groupingField"`
|
||||||
GroupSize int32 `json:"groupSize"`
|
GroupSize int32 `json:"groupSize"`
|
||||||
GroupStrictSize bool `json:"groupStrictSize"`
|
StrictGroupSize bool `json:"strictGroupSize"`
|
||||||
Limit int32 `json:"limit"`
|
Limit int32 `json:"limit"`
|
||||||
Offset int32 `json:"offset"`
|
Offset int32 `json:"offset"`
|
||||||
OutputFields []string `json:"outputFields"`
|
OutputFields []string `json:"outputFields"`
|
||||||
@ -194,7 +194,7 @@ type HybridSearchReq struct {
|
|||||||
Limit int32 `json:"limit"`
|
Limit int32 `json:"limit"`
|
||||||
GroupByField string `json:"groupingField"`
|
GroupByField string `json:"groupingField"`
|
||||||
GroupSize int32 `json:"groupSize"`
|
GroupSize int32 `json:"groupSize"`
|
||||||
GroupStrictSize bool `json:"groupStrictSize"`
|
StrictGroupSize bool `json:"strictGroupSize"`
|
||||||
OutputFields []string `json:"outputFields"`
|
OutputFields []string `json:"outputFields"`
|
||||||
ConsistencyLevel string `json:"consistencyLevel"`
|
ConsistencyLevel string `json:"consistencyLevel"`
|
||||||
}
|
}
|
||||||
|
|||||||
@ -63,7 +63,7 @@ message QueryInfo {
|
|||||||
int64 group_by_field_id = 6;
|
int64 group_by_field_id = 6;
|
||||||
bool materialized_view_involved = 7;
|
bool materialized_view_involved = 7;
|
||||||
int64 group_size = 8;
|
int64 group_size = 8;
|
||||||
bool group_strict_size = 9;
|
bool strict_group_size = 9;
|
||||||
double bm25_avgdl = 10;
|
double bm25_avgdl = 10;
|
||||||
int64 query_field_id =11;
|
int64 query_field_id =11;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -26,7 +26,7 @@ type rankParams struct {
|
|||||||
roundDecimal int64
|
roundDecimal int64
|
||||||
groupByFieldId int64
|
groupByFieldId int64
|
||||||
groupSize int64
|
groupSize int64
|
||||||
groupStrictSize bool
|
strictGroupSize bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rankParams) GetLimit() int64 {
|
func (r *rankParams) GetLimit() int64 {
|
||||||
@ -64,9 +64,9 @@ func (r *rankParams) GetGroupSize() int64 {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rankParams) GetGroupStrictSize() bool {
|
func (r *rankParams) GetStrictGroupSize() bool {
|
||||||
if r != nil {
|
if r != nil {
|
||||||
return r.groupStrictSize
|
return r.strictGroupSize
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -170,15 +170,15 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||||||
|
|
||||||
// 5. parse group by field and group by size
|
// 5. parse group by field and group by size
|
||||||
var groupByFieldId, groupSize int64
|
var groupByFieldId, groupSize int64
|
||||||
var groupStrictSize bool
|
var strictGroupSize bool
|
||||||
if isAdvanced {
|
if isAdvanced {
|
||||||
groupByFieldId, groupSize, groupStrictSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetGroupStrictSize()
|
groupByFieldId, groupSize, strictGroupSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetStrictGroupSize()
|
||||||
} else {
|
} else {
|
||||||
groupByInfo := parseGroupByInfo(searchParamsPair, schema)
|
groupByInfo := parseGroupByInfo(searchParamsPair, schema)
|
||||||
if groupByInfo.err != nil {
|
if groupByInfo.err != nil {
|
||||||
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err}
|
return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err}
|
||||||
}
|
}
|
||||||
groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize()
|
groupByFieldId, groupSize, strictGroupSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetStrictGroupSize()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
|
||||||
@ -199,7 +199,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||||||
RoundDecimal: roundDecimal,
|
RoundDecimal: roundDecimal,
|
||||||
GroupByFieldId: groupByFieldId,
|
GroupByFieldId: groupByFieldId,
|
||||||
GroupSize: groupSize,
|
GroupSize: groupSize,
|
||||||
GroupStrictSize: groupStrictSize,
|
StrictGroupSize: strictGroupSize,
|
||||||
},
|
},
|
||||||
offset: offset,
|
offset: offset,
|
||||||
isIterator: isIterator,
|
isIterator: isIterator,
|
||||||
@ -303,7 +303,7 @@ func getPartitionIDs(ctx context.Context, dbName string, collectionName string,
|
|||||||
type groupByInfo struct {
|
type groupByInfo struct {
|
||||||
groupByFieldId int64
|
groupByFieldId int64
|
||||||
groupSize int64
|
groupSize int64
|
||||||
groupStrictSize bool
|
strictGroupSize bool
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -321,9 +321,9 @@ func (g *groupByInfo) GetGroupSize() int64 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *groupByInfo) GetGroupStrictSize() bool {
|
func (g *groupByInfo) GetStrictGroupSize() bool {
|
||||||
if g != nil {
|
if g != nil {
|
||||||
return g.groupStrictSize
|
return g.strictGroupSize
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -389,17 +389,17 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap
|
|||||||
ret.groupSize = groupSize
|
ret.groupSize = groupSize
|
||||||
|
|
||||||
// 3. parse group strict size
|
// 3. parse group strict size
|
||||||
var groupStrictSize bool
|
var strictGroupSize bool
|
||||||
groupStrictSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupStrictSize, searchParamsPair)
|
strictGroupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(StrictGroupSize, searchParamsPair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
groupStrictSize = false
|
strictGroupSize = false
|
||||||
} else {
|
} else {
|
||||||
groupStrictSize, err = strconv.ParseBool(groupStrictSizeStr)
|
strictGroupSize, err = strconv.ParseBool(strictGroupSizeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
groupStrictSize = false
|
strictGroupSize = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ret.groupStrictSize = groupStrictSize
|
ret.strictGroupSize = strictGroupSize
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -460,7 +460,7 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.C
|
|||||||
roundDecimal: roundDecimal,
|
roundDecimal: roundDecimal,
|
||||||
groupByFieldId: groupByInfo.GetGroupByFieldId(),
|
groupByFieldId: groupByInfo.GetGroupByFieldId(),
|
||||||
groupSize: groupByInfo.GetGroupSize(),
|
groupSize: groupByInfo.GetGroupSize(),
|
||||||
groupStrictSize: groupByInfo.GetGroupStrictSize(),
|
strictGroupSize: groupByInfo.GetStrictGroupSize(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -55,7 +55,7 @@ const (
|
|||||||
IteratorField = "iterator"
|
IteratorField = "iterator"
|
||||||
GroupByFieldKey = "group_by_field"
|
GroupByFieldKey = "group_by_field"
|
||||||
GroupSizeKey = "group_size"
|
GroupSizeKey = "group_size"
|
||||||
GroupStrictSize = "group_strict_size"
|
StrictGroupSize = "strict_group_size"
|
||||||
RankGroupScorer = "rank_group_scorer"
|
RankGroupScorer = "rank_group_scorer"
|
||||||
AnnsFieldKey = "anns_field"
|
AnnsFieldKey = "anns_field"
|
||||||
TopKKey = "topk"
|
TopKKey = "topk"
|
||||||
|
|||||||
@ -2325,7 +2325,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
// 1. first parse rank params
|
// 1. first parse rank params
|
||||||
// outer params require to group by field 101 and groupSize=3 and groupStrictSize=false
|
// outer params require to group by field 101 and groupSize=3 and strictGroupSize=false
|
||||||
testRankParamsPairs := getValidSearchParams()
|
testRankParamsPairs := getValidSearchParams()
|
||||||
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||||
Key: GroupByFieldKey,
|
Key: GroupByFieldKey,
|
||||||
@ -2336,7 +2336,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||||||
Value: strconv.FormatInt(3, 10),
|
Value: strconv.FormatInt(3, 10),
|
||||||
})
|
})
|
||||||
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||||
Key: GroupStrictSize,
|
Key: StrictGroupSize,
|
||||||
Value: "false",
|
Value: "false",
|
||||||
})
|
})
|
||||||
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
testRankParamsPairs = append(testRankParamsPairs, &commonpb.KeyValuePair{
|
||||||
@ -2348,7 +2348,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||||||
|
|
||||||
// 2. parse search params for sub request in hybridsearch
|
// 2. parse search params for sub request in hybridsearch
|
||||||
params := getValidSearchParams()
|
params := getValidSearchParams()
|
||||||
// inner params require to group by field 103 and groupSize=10 and groupStrictSize=true
|
// inner params require to group by field 103 and groupSize=10 and strictGroupSize=true
|
||||||
params = append(params, &commonpb.KeyValuePair{
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
Key: GroupByFieldKey,
|
Key: GroupByFieldKey,
|
||||||
Value: "c3",
|
Value: "c3",
|
||||||
@ -2358,7 +2358,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||||||
Value: strconv.FormatInt(10, 10),
|
Value: strconv.FormatInt(10, 10),
|
||||||
})
|
})
|
||||||
params = append(params, &commonpb.KeyValuePair{
|
params = append(params, &commonpb.KeyValuePair{
|
||||||
Key: GroupStrictSize,
|
Key: StrictGroupSize,
|
||||||
Value: "true",
|
Value: "true",
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -2370,7 +2370,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
|
|||||||
// set by main request rather than inner sub request
|
// set by main request rather than inner sub request
|
||||||
assert.Equal(t, int64(101), searchInfo.planInfo.GetGroupByFieldId())
|
assert.Equal(t, int64(101), searchInfo.planInfo.GetGroupByFieldId())
|
||||||
assert.Equal(t, int64(3), searchInfo.planInfo.GetGroupSize())
|
assert.Equal(t, int64(3), searchInfo.planInfo.GetGroupSize())
|
||||||
assert.False(t, searchInfo.planInfo.GetGroupStrictSize())
|
assert.False(t, searchInfo.planInfo.GetStrictGroupSize())
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("parseSearchInfo error", func(t *testing.T) {
|
t.Run("parseSearchInfo error", func(t *testing.T) {
|
||||||
|
|||||||
@ -27,8 +27,8 @@ pytest-parallel
|
|||||||
pytest-random-order
|
pytest-random-order
|
||||||
|
|
||||||
# pymilvus
|
# pymilvus
|
||||||
pymilvus==2.5.0rc108
|
pymilvus==2.5.0rc111
|
||||||
pymilvus[bulk_writer]==2.5.0rc108
|
pymilvus[bulk_writer]==2.5.0rc111
|
||||||
|
|
||||||
# for customize config test
|
# for customize config test
|
||||||
python-benedict==0.24.3
|
python-benedict==0.24.3
|
||||||
|
|||||||
@ -2203,8 +2203,8 @@ class TestGroupSearch(TestCaseClassBase):
|
|||||||
"""
|
"""
|
||||||
target:
|
target:
|
||||||
1. search on 4 different float vector fields with group by varchar field with group size
|
1. search on 4 different float vector fields with group by varchar field with group size
|
||||||
verify results entity = limit * group_size and group size is full if group_strict_size is True
|
verify results entity = limit * group_size and group size is full if strict_group_size is True
|
||||||
verify results group counts = limit if group_strict_size is False
|
verify results group counts = limit if strict_group_size is False
|
||||||
"""
|
"""
|
||||||
nq = 2
|
nq = 2
|
||||||
limit = 50
|
limit = 50
|
||||||
@ -2212,11 +2212,11 @@ class TestGroupSearch(TestCaseClassBase):
|
|||||||
for j in range(len(self.vector_fields)):
|
for j in range(len(self.vector_fields)):
|
||||||
search_vectors = cf.gen_vectors(nq, dim=self.dims[j], vector_data_type=self.vector_fields[j])
|
search_vectors = cf.gen_vectors(nq, dim=self.dims[j], vector_data_type=self.vector_fields[j])
|
||||||
search_params = {"params": cf.get_search_params_params(self.index_types[j])}
|
search_params = {"params": cf.get_search_params_params(self.index_types[j])}
|
||||||
# when group_strict_size=true, it shall return results with entities = limit * group_size
|
# when strict_group_size=true, it shall return results with entities = limit * group_size
|
||||||
res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
|
res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
|
||||||
param=search_params, limit=limit,
|
param=search_params, limit=limit,
|
||||||
group_by_field=group_by_field,
|
group_by_field=group_by_field,
|
||||||
group_size=group_size, group_strict_size=True,
|
group_size=group_size, strict_group_size=True,
|
||||||
output_fields=[group_by_field])[0]
|
output_fields=[group_by_field])[0]
|
||||||
for i in range(nq):
|
for i in range(nq):
|
||||||
assert len(res1[i]) == limit * group_size
|
assert len(res1[i]) == limit * group_size
|
||||||
@ -2226,11 +2226,11 @@ class TestGroupSearch(TestCaseClassBase):
|
|||||||
group_values.append(res1[i][l*group_size+k].fields.get(group_by_field))
|
group_values.append(res1[i][l*group_size+k].fields.get(group_by_field))
|
||||||
assert len(set(group_values)) == 1
|
assert len(set(group_values)) == 1
|
||||||
|
|
||||||
# when group_strict_size=false, it shall return results with group counts = limit
|
# when strict_group_size=false, it shall return results with group counts = limit
|
||||||
res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
|
res1 = self.collection_wrap.search(data=search_vectors, anns_field=self.vector_fields[j],
|
||||||
param=search_params, limit=limit,
|
param=search_params, limit=limit,
|
||||||
group_by_field=group_by_field,
|
group_by_field=group_by_field,
|
||||||
group_size=group_size, group_strict_size=False,
|
group_size=group_size, strict_group_size=False,
|
||||||
output_fields=[group_by_field])[0]
|
output_fields=[group_by_field])[0]
|
||||||
for i in range(nq):
|
for i in range(nq):
|
||||||
group_values = []
|
group_values = []
|
||||||
@ -2438,7 +2438,7 @@ class TestGroupSearch(TestCaseClassBase):
|
|||||||
param=search_param, limit=limit, offset=limit * r,
|
param=search_param, limit=limit, offset=limit * r,
|
||||||
expr=default_search_exp,
|
expr=default_search_exp,
|
||||||
group_by_field=grpby_field, group_size=group_size,
|
group_by_field=grpby_field, group_size=group_size,
|
||||||
group_strict_size=True,
|
strict_group_size=True,
|
||||||
output_fields=[grpby_field],
|
output_fields=[grpby_field],
|
||||||
check_task=CheckTasks.check_search_results,
|
check_task=CheckTasks.check_search_results,
|
||||||
check_items={"nq": 1, "limit": res_count},
|
check_items={"nq": 1, "limit": res_count},
|
||||||
@ -2459,7 +2459,7 @@ class TestGroupSearch(TestCaseClassBase):
|
|||||||
param=search_param, limit=limit * page_rounds,
|
param=search_param, limit=limit * page_rounds,
|
||||||
expr=default_search_exp,
|
expr=default_search_exp,
|
||||||
group_by_field=grpby_field, group_size=group_size,
|
group_by_field=grpby_field, group_size=group_size,
|
||||||
group_strict_size=True,
|
strict_group_size=True,
|
||||||
output_fields=[grpby_field],
|
output_fields=[grpby_field],
|
||||||
check_task=CheckTasks.check_search_results,
|
check_task=CheckTasks.check_search_results,
|
||||||
check_items={"nq": 1, "limit": total_count}
|
check_items={"nq": 1, "limit": total_count}
|
||||||
@ -2488,7 +2488,7 @@ class TestGroupSearch(TestCaseClassBase):
|
|||||||
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
|
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
|
||||||
param=search_params, limit=limit,
|
param=search_params, limit=limit,
|
||||||
group_by_field=group_by_field,
|
group_by_field=group_by_field,
|
||||||
group_size=max_group_size, group_strict_size=True,
|
group_size=max_group_size, strict_group_size=True,
|
||||||
output_fields=[group_by_field])
|
output_fields=[group_by_field])
|
||||||
exceed_max_group_size = max_group_size + 1
|
exceed_max_group_size = max_group_size + 1
|
||||||
error = {ct.err_code: 999,
|
error = {ct.err_code: 999,
|
||||||
@ -2497,7 +2497,7 @@ class TestGroupSearch(TestCaseClassBase):
|
|||||||
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
|
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
|
||||||
param=search_params, limit=limit,
|
param=search_params, limit=limit,
|
||||||
group_by_field=group_by_field,
|
group_by_field=group_by_field,
|
||||||
group_size=exceed_max_group_size, group_strict_size=True,
|
group_size=exceed_max_group_size, strict_group_size=True,
|
||||||
output_fields=[group_by_field],
|
output_fields=[group_by_field],
|
||||||
check_task=CheckTasks.err_res, check_items=error)
|
check_task=CheckTasks.err_res, check_items=error)
|
||||||
|
|
||||||
@ -2505,7 +2505,7 @@ class TestGroupSearch(TestCaseClassBase):
|
|||||||
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
|
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
|
||||||
param=search_params, limit=limit,
|
param=search_params, limit=limit,
|
||||||
group_by_field=group_by_field,
|
group_by_field=group_by_field,
|
||||||
group_size=max_group_size, group_strict_size=True,
|
group_size=max_group_size, strict_group_size=True,
|
||||||
output_fields=[group_by_field])
|
output_fields=[group_by_field])
|
||||||
below_min_group_size = min_group_size - 1
|
below_min_group_size = min_group_size - 1
|
||||||
error = {ct.err_code: 999,
|
error = {ct.err_code: 999,
|
||||||
@ -2513,6 +2513,6 @@ class TestGroupSearch(TestCaseClassBase):
|
|||||||
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
|
self.collection_wrap.search(data=search_vectors, anns_field=default_search_field,
|
||||||
param=search_params, limit=limit,
|
param=search_params, limit=limit,
|
||||||
group_by_field=group_by_field,
|
group_by_field=group_by_field,
|
||||||
group_size=below_min_group_size, group_strict_size=True,
|
group_size=below_min_group_size, strict_group_size=True,
|
||||||
output_fields=[group_by_field],
|
output_fields=[group_by_field],
|
||||||
check_task=CheckTasks.err_res, check_items=error)
|
check_task=CheckTasks.err_res, check_items=error)
|
||||||
Loading…
x
Reference in New Issue
Block a user