Mirror of https://gitee.com/milvus-io/milvus.git, synced 2025-12-07 01:28:27 +08:00
fix: support group by with nullable grouping keys (#41797)
See #36264. In this PR:
- Enhanced error handling when parsing the grouping field.
- Fixed null handling in reduce tasks on proxy nodes.
- Updated tests to reflect the changes in error handling and data processing logic.

Signed-off-by: Ted Xu <ted.xu@zilliz.com>
This commit is contained in:
parent b8d7045539
commit ae32203d3a
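The core idea of the fix: a nullable group-by column travels through reduce as a value array plus a parallel ValidData bitmap in schemapb.FieldData, with invalid rows zero-filled, so the reduce code must consult validity instead of assuming every row carries a value. A minimal sketch of reading such a column (an illustrative helper only, not part of this commit):

// groupByValueAt reads one row of a nullable Int64 group-by column in the
// zero-filled layout this PR produces: one slot per row, invalid rows hold
// the zero value, and ValidData marks which rows are real.
func groupByValueAt(values []int64, valid []bool, row int) (int64, bool) {
    if len(valid) > 0 && !valid[row] {
        return 0, false // NULL group-by key
    }
    return values[row], true
}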
@@ -25,7 +25,7 @@ func reduceSearchResult(ctx context.Context, subSearchResultData []*schemapb.Sea
     if reduceInfo.GetIsAdvance() {
         // for hybrid search group by, we cannot reduce result for results from one single search path,
         // because the final score has not been accumulated, also, offset cannot be applied
-        return reduceAdvanceGroupBY(ctx,
+        return reduceAdvanceGroupBy(ctx,
             subSearchResultData, reduceInfo.GetNq(), reduceInfo.GetTopK(), reduceInfo.GetPkType(), reduceInfo.GetMetricType())
     }
     return reduceSearchResultDataWithGroupBy(ctx,
@@ -69,7 +69,7 @@ func checkResultDatas(ctx context.Context, subSearchResultData []*schemapb.Searc
     return allSearchCount, hitNum, nil
 }

-func reduceAdvanceGroupBY(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
+func reduceAdvanceGroupBy(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
     nq int64, topK int64, pkType schemapb.DataType, metricType string,
 ) (*milvuspb.SearchResults, error) {
     log.Ctx(ctx).Debug("reduceAdvanceGroupBY", zap.Int("len(subSearchResultData)", len(subSearchResultData)), zap.Int64("nq", nq))
@@ -117,6 +117,11 @@ func reduceAdvanceGroupBY(ctx context.Context, subSearchResultData []*schemapb.S
             subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
         }
     }

+    gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
+    if err != nil {
+        return ret, err
+    }
     // reducing nq * topk results
     for nqIdx := int64(0); nqIdx < nq; nqIdx++ {
         dataCount := int64(0)
@@ -127,23 +132,23 @@ func reduceAdvanceGroupBY(ctx context.Context, subSearchResultData []*schemapb.S
             subGroupByVals := subData.GetGroupByFieldValue()

             nqTopK := subData.Topks[nqIdx]
+            groupByValIterator := typeutil.GetDataIterator(subGroupByVals)

             for i := int64(0); i < nqTopK; i++ {
                 innerIdx := subSearchNqOffset[subIdx][nqIdx] + i
                 pk := typeutil.GetPK(subPks, innerIdx)
                 score := subScores[innerIdx]
-                groupByVal := typeutil.GetData(subData.GetGroupByFieldValue(), int(innerIdx))
+                groupByVal := groupByValIterator(int(innerIdx))
+                gpFieldBuilder.Add(groupByVal)
                 typeutil.AppendPKs(ret.Results.Ids, pk)
                 ret.Results.Scores = append(ret.Results.Scores, score)
-                if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subGroupByVals.GetType()); err != nil {
-                    log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
-                    return ret, err
-                }
                 dataCount += 1
             }
         }
         ret.Results.Topks = append(ret.Results.Topks, dataCount)
     }

+    ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
     ret.Results.TopK = topK // realTopK is the topK of the nq-th query
     if !metric.PositivelyRelated(metricType) {
         for k := range ret.Results.Scores {
@@ -194,7 +199,7 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
     }
     groupBound := groupSize * limit
     if err := setupIdListForSearchResult(ret, pkType, groupBound); err != nil {
-        return ret, nil
+        return ret, err
     }

     if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
@@ -209,6 +214,7 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
         // for results of each subSearchResultData, storing the start offset of each query of nq queries
         subSearchNqOffset = make([][]int64, subSearchNum)
         totalResCount int64 = 0
+        subSearchGroupByValIterator = make([]func(int) any, subSearchNum)
     )
     for i := 0; i < subSearchNum; i++ {
         subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
@@ -216,6 +222,11 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
             subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
         }
         totalResCount += subSearchNqOffset[i][nq-1]
+        subSearchGroupByValIterator[i] = typeutil.GetDataIterator(subSearchResultData[i].GetGroupByFieldValue())
+    }
+    gpFieldBuilder, err := typeutil.NewFieldDataBuilder(subSearchResultData[0].GetGroupByFieldValue().GetType(), true, int(limit))
+    if err != nil {
+        return ret, err
     }

     var realTopK int64 = -1
@@ -245,11 +256,7 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData

             id := typeutil.GetPK(subSearchRes.GetIds(), resultDataIdx)
             score := subSearchRes.GetScores()[resultDataIdx]
-            groupByVal := typeutil.GetData(subSearchRes.GetGroupByFieldValue(), int(resultDataIdx))
-            if groupByVal == nil {
-                return nil, errors.New("get nil groupByVal from subSearchRes, wrong states, as milvus doesn't support nil value," +
-                    "there must be sth wrong on queryNode side")
-            }
+            groupByVal := subSearchGroupByValIterator[subSearchIdx](int(resultDataIdx))

             if int64(len(skipOffsetMap)) < offset || skipOffsetMap[groupByVal] {
                 skipOffsetMap[groupByVal] = true
@@ -276,18 +283,13 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
         // assemble all eligible values in group
         // values in groupByValList is sorted by the highest score in each group
         for _, groupVal := range groupByValList {
-            if groupVal != nil {
             groupEntities := groupByValMap[groupVal]
             for _, groupEntity := range groupEntities {
                 subResData := subSearchResultData[groupEntity.subSearchIdx]
                 retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
                 typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
                 ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
-                if err := typeutil.AppendGroupByValue(ret.Results, groupVal, subResData.GetGroupByFieldValue().GetType()); err != nil {
-                    log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
-                    return ret, err
-                }
-            }
+                gpFieldBuilder.Add(groupVal)
             }
         }
@@ -296,6 +298,7 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
         }
         realTopK = j
         ret.Results.Topks = append(ret.Results.Topks, realTopK)
+    ret.Results.GroupByFieldValue = gpFieldBuilder.Build()

     // limit search result to avoid oom
     if retSize > maxOutputSize {
@@ -535,11 +538,12 @@ func rankSearchResultDataByGroup(ctx context.Context,
     start := 0
     // milvus has limits for the value range of nq and limit
     // no matter on 32-bit and 64-bit platform, converting nq and topK into int is safe
+    groupByValIterator := typeutil.GetDataIterator(result.GetResults().GetGroupByFieldValue())
     for i := 0; i < int(nq); i++ {
         realTopK := int(result.GetResults().Topks[i])
         for j := start; j < start+realTopK; j++ {
             id := typeutil.GetPK(result.GetResults().GetIds(), int64(j))
-            groupByVal := typeutil.GetData(result.GetResults().GetGroupByFieldValue(), j)
+            groupByVal := groupByValIterator(j)
             if accumulatedScores[i][id] != nil {
                 accumulatedScores[i][id].accumulatedScore += scores[j]
             } else {
@@ -550,6 +554,10 @@ func rankSearchResultDataByGroup(ctx context.Context,
         }
     }

+    gpFieldBuilder, err := typeutil.NewFieldDataBuilder(groupByDataType, true, int(limit))
+    if err != nil {
+        return ret, err
+    }
     for i := int64(0); i < nq; i++ {
         idSet := accumulatedScores[i]
         keys := make([]interface{}, 0)
@@ -631,13 +639,14 @@ func rankSearchResultDataByGroup(ctx context.Context,
                 score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
             }
             ret.Results.Scores = append(ret.Results.Scores, score)
-            typeutil.AppendGroupByValue(ret.Results, group.groupVal, groupByDataType)
+            gpFieldBuilder.Add(group.groupVal)
         }
         returnedRowNum += len(group.idList)
     }
     ret.Results.Topks = append(ret.Results.Topks, int64(returnedRowNum))
 }

+    ret.Results.GroupByFieldValue = gpFieldBuilder.Build()
     return ret, nil
 }
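A pattern repeated throughout these hunks: typeutil.GetDataIterator is resolved once per sub-result outside the hot loop, replacing a typeutil.GetData call per row. A sketch of the closure shape such an iterator can take (an illustration of the pattern, not the actual typeutil implementation):

// int64Iterator captures the type dispatch once; each call is then a plain
// index. Returning nil stands in for a NULL row where ValidData marks it
// invalid, which is how nil group-by keys reach gpFieldBuilder.Add.
func int64Iterator(data []int64, valid []bool) func(int) any {
    return func(i int) any {
        if len(valid) > 0 && !valid[i] {
            return nil
        }
        return data[i]
    }
}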
@@ -79,7 +79,6 @@ func (r *rankParams) String() string {
 type SearchInfo struct {
     planInfo     *planpb.QueryInfo
     offset       int64
-    parseError   error
     isIterator   bool
     collectionID int64
 }
@@ -160,21 +159,21 @@ func parseSearchIteratorV2Info(searchParamsPair []*commonpb.KeyValuePair, groupB
 }

 // parseSearchInfo returns QueryInfo and offset
-func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo {
+func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*SearchInfo, error) {
     var topK int64
     isAdvanced := rankParams != nil
     externalLimit := rankParams.GetLimit() + rankParams.GetOffset()
     topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
     if err != nil {
         if externalLimit <= 0 {
-            return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s is required", TopKKey)}
+            return nil, fmt.Errorf("%s is required", TopKKey)
         }
         topK = externalLimit
     } else {
         topKInParam, err := strconv.ParseInt(topKStr, 0, 64)
         if err != nil {
             if externalLimit <= 0 {
-                return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)}
+                return nil, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
             }
             topK = externalLimit
         } else {
@@ -194,7 +193,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
             // 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
             topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
         } else {
-            return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)}
+            return nil, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
         }
     }

@@ -205,12 +204,12 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
     if err == nil {
         offset, err = strconv.ParseInt(offsetStr, 0, 64)
         if err != nil {
-            return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)}
+            return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
         }

         if offset != 0 {
             if err := validateLimit(offset); err != nil {
-                return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)}
+                return nil, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
             }
         }
     }
@@ -218,7 +217,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb

     queryTopK := topK + offset
     if err := validateLimit(queryTopK); err != nil {
-        return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)}
+        return nil, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
     }

     // 2. parse metrics type
@@ -240,11 +239,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb

     roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
     if err != nil {
-        return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)}
+        return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
     }

     if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
-        return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)}
+        return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
     }

     // 4. parse search param str
@@ -259,26 +258,26 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
     if isAdvanced {
         groupByFieldId, groupSize, strictGroupSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetStrictGroupSize()
     } else {
-        groupByInfo := parseGroupByInfo(searchParamsPair, schema)
-        if groupByInfo.err != nil {
-            return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err}
+        groupByInfo, err := parseGroupByInfo(searchParamsPair, schema)
+        if err != nil {
+            return nil, err
         }
         groupByFieldId, groupSize, strictGroupSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetStrictGroupSize()
     }

     // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
     if isIterator && groupByFieldId > 0 {
-        return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "",
-            "Not allowed to do groupBy when doing iteration")}
+        return nil, merr.WrapErrParameterInvalid("", "",
+            "Not allowed to do groupBy when doing iteration")
     }
     if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
-        return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "",
-            "Not allowed to do range-search when doing search-group-by")}
+        return nil, merr.WrapErrParameterInvalid("", "",
+            "Not allowed to do range-search when doing search-group-by")
     }

     planSearchIteratorV2Info, err := parseSearchIteratorV2Info(searchParamsPair, groupByFieldId, isIterator, offset, &queryTopK)
     if err != nil {
-        return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("parse iterator v2 info failed: %w", err)}
+        return nil, fmt.Errorf("parse iterator v2 info failed: %w", err)
     }

     return &SearchInfo{
@@ -295,9 +294,8 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
         },
         offset:       offset,
         isIterator:   isIterator,
-        parseError:   nil,
         collectionID: collectionId,
-    }
+    }, nil
 }

 func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
@@ -395,7 +393,6 @@ type groupByInfo struct {
     groupByFieldId  int64
     groupSize       int64
     strictGroupSize bool
-    err             error
 }

 func (g *groupByInfo) GetGroupByFieldId() int64 {
@@ -419,14 +416,7 @@ func (g *groupByInfo) GetStrictGroupSize() bool {
     return false
 }

-func (g *groupByInfo) GetError() error {
-    if g != nil {
-        return g.err
-    }
-    return nil
-}
-
-func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) *groupByInfo {
+func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*groupByInfo, error) {
     ret := &groupByInfo{}

     // 1. parse group_by_field
@@ -444,8 +434,7 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap
             }
         }
         if groupByFieldId == -1 {
-            ret.err = merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
-            return ret
+            return nil, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
         }
     }
     ret.groupByFieldId = groupByFieldId
@@ -458,20 +447,17 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap
     } else {
         groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
         if err != nil {
-            ret.err = merr.WrapErrParameterInvalidMsg(
+            return nil, merr.WrapErrParameterInvalidMsg(
                 fmt.Sprintf("failed to parse input group size:%s", groupSizeStr))
-            return ret
         }
         if groupSize <= 0 {
-            ret.err = merr.WrapErrParameterInvalidMsg(
+            return nil, merr.WrapErrParameterInvalidMsg(
                 fmt.Sprintf("input group size:%d is negative, failed to do search_groupby", groupSize))
-            return ret
         }
     }
     if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
-        ret.err = merr.WrapErrParameterInvalidMsg(
+        return nil, merr.WrapErrParameterInvalidMsg(
             fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
-        return ret
     }
     ret.groupSize = groupSize

@@ -487,7 +473,7 @@ func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemap
         }
     }
     ret.strictGroupSize = strictGroupSize
-    return ret
+    return ret, nil
 }

 // parseRankParams get limit and offset from rankParams, both are optional.
@@ -536,9 +522,9 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.C
     }

     // parse group_by parameters from main request body for hybrid search
-    groupByInfo := parseGroupByInfo(rankParamsPair, schema)
-    if groupByInfo.err != nil {
-        return nil, groupByInfo.err
+    groupByInfo, err := parseGroupByInfo(rankParamsPair, schema)
+    if err != nil {
+        return nil, err
     }

     return &rankParams{
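The thread running through this file is replacing the parseError field that used to ride inside SearchInfo and groupByInfo with Go's conventional (value, error) returns, so a caller can no longer use a half-built struct without the error being visible in the signature. A minimal sketch of the before/after shape (hypothetical parse/Info names, assuming fmt and strconv are imported; the real functions are parseSearchInfo and parseGroupByInfo above):

type Info struct {
    limit int64
    // before the refactor, an `err error` field lived here too
}

// After the refactor: errors travel in the second return value, and failure
// paths return nil instead of a struct with an embedded error.
func parse(params map[string]string) (*Info, error) {
    v, ok := params["limit"]
    if !ok {
        return nil, fmt.Errorf("limit is required")
    }
    n, err := strconv.ParseInt(v, 0, 64)
    if err != nil {
        return nil, fmt.Errorf("limit [%s] is invalid: %w", v, err)
    }
    return &Info{limit: n}, nil
}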
@@ -865,9 +865,9 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
         }
         annsFieldName = vecFields[0].Name
     }
-    searchInfo := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
-    if searchInfo.parseError != nil {
-        return nil, nil, 0, false, searchInfo.parseError
+    searchInfo, err := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
+    if err != nil {
+        return nil, nil, 0, false, err
     }
     if searchInfo.collectionID > 0 && searchInfo.collectionID != t.GetCollectionID() {
         return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("collection id:%d in the request is not consistent to that in the search context,"+
@@ -1254,8 +1254,9 @@ func (t *searchTask) reorganizeRequeryResults(ctx context.Context, fields []*sch
         return nil, err
     }
     offsets := make(map[any]int)
+    pkItr := typeutil.GetDataIterator(pkFieldData)
     for i := 0; i < typeutil.GetPKSize(pkFieldData); i++ {
-        pk := typeutil.GetData(pkFieldData, i)
+        pk := pkItr(i)
         offsets[pk] = i
     }

@@ -2505,36 +2505,10 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
         {9, 7, 5, 3, 1, 9, 7, 5, 3, 1},
     }

-    groupByValuesArr := [][][]int64{
-        {
-            {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
-            {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
-        }, // result2 has completely same group_by values, no result from result2 can be selected
-        {
-            {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
-            {6, 8, 3, 4, 5, 6, 8, 3, 4, 5},
-        }, // result2 will contribute group_by values 6 and 8
-    }
-    expectedIDs := [][]int64{
-        {1, 3, 5, 7, 9, 1, 3, 5, 7, 9},
-        {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
-    }
-    expectedScores := [][]float32{
-        {-10, -8, -6, -4, -2, -10, -8, -6, -4, -2},
-        {-10, -9, -8, -7, -6, -10, -9, -8, -7, -6},
-    }
-    expectedGroupByValues := [][]int64{
-        {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
-        {1, 6, 2, 8, 3, 1, 6, 2, 8, 3},
-    }
-
-    for i, groupByValues := range groupByValuesArr {
-        t.Run("Group By correctness", func(t *testing.T) {
-            var results []*schemapb.SearchResultData
-            for j := range ids {
+    makePartialResult := func(ids []int64, scores []float32, groupByValues []int64, valids []bool) *schemapb.SearchResultData {
         result := getSearchResultData(nq, topK)
-        result.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: ids[j]}}
-        result.Scores = scores[j]
+        result.Ids.IdField = &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: ids}}
+        result.Scores = scores
         result.Topks = []int64{topK, topK}
         result.GroupByFieldValue = &schemapb.FieldData{
             Type: schemapb.DataType_Int64,
@@ -2542,26 +2516,97 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
             Scalars: &schemapb.ScalarField{
                 Data: &schemapb.ScalarField_LongData{
                     LongData: &schemapb.LongArray{
-                        Data: groupByValues[j],
+                        Data: groupByValues,
                     },
                 },
             },
         },
+            ValidData: valids,
+        }
+        return result
+    }
+
+    tests := []struct {
+        name                  string
+        inputs                []*schemapb.SearchResultData
+        expectedIDs           []int64
+        expectedScores        []float32
+        expectedGroupByValues *schemapb.FieldData
+    }{
+        {
+            name: "same group_by values",
+            inputs: []*schemapb.SearchResultData{
+                makePartialResult(ids[0], scores[0], []int64{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, nil),
+                makePartialResult(ids[1], scores[1], []int64{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, nil),
+            },
+            expectedIDs:    []int64{1, 3, 5, 7, 9, 1, 3, 5, 7, 9},
+            expectedScores: []float32{-10, -8, -6, -4, -2, -10, -8, -6, -4, -2},
+            expectedGroupByValues: &schemapb.FieldData{
+                Type: schemapb.DataType_Int64,
+                Field: &schemapb.FieldData_Scalars{
+                    Scalars: &schemapb.ScalarField{
+                        Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}}},
+                    },
+                },
+            },
+        },
+        {
+            name: "different group_by values",
+            inputs: []*schemapb.SearchResultData{
+                makePartialResult(ids[0], scores[0], []int64{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, nil),
+                makePartialResult(ids[1], scores[1], []int64{6, 8, 3, 4, 5, 6, 8, 3, 4, 5}, nil),
+            },
+            expectedIDs:    []int64{1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
+            expectedScores: []float32{-10, -9, -8, -7, -6, -10, -9, -8, -7, -6},
+            expectedGroupByValues: &schemapb.FieldData{
+                Type: schemapb.DataType_Int64,
+                Field: &schemapb.FieldData_Scalars{
+                    Scalars: &schemapb.ScalarField{
+                        Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 6, 2, 8, 3, 1, 6, 2, 8, 3}}},
+                    },
+                },
+            },
+        },
+        {
+            name: "nullable group_by values",
+            inputs: []*schemapb.SearchResultData{
+                makePartialResult(ids[0], scores[0], []int64{1, 2, 3, 4, 1, 2, 3, 4}, []bool{true, true, true, true, false, true, true, true, true, false}),
+                makePartialResult(ids[1], scores[1], []int64{1, 2, 3, 4, 1, 2, 3, 4}, []bool{true, true, true, true, false, true, true, true, true, false}),
+            },
+            expectedIDs:    []int64{1, 3, 5, 7, 9, 1, 3, 5, 7, 9},
+            expectedScores: []float32{-10, -8, -6, -4, -2, -10, -8, -6, -4, -2},
+            expectedGroupByValues: &schemapb.FieldData{
+                Type: schemapb.DataType_Int64,
+                Field: &schemapb.FieldData_Scalars{
+                    Scalars: &schemapb.ScalarField{
+                        Data: &schemapb.ScalarField_LongData{
+                            LongData: &schemapb.LongArray{Data: []int64{1, 2, 3, 4, 0, 1, 2, 3, 4, 0}},
+                        },
+                    },
+                },
+                ValidData: []bool{true, true, true, true, false, true, true, true, true, false},
+            },
+        },
     }
-            results = append(results, result)
-        }
+    for _, tt := range tests {
+        t.Run(tt.name, func(t *testing.T) {
             queryInfo := &planpb.QueryInfo{
                 GroupByFieldId: 1,
                 GroupSize:      1,
             }
-            reduced, err := reduceSearchResult(context.TODO(), results,
-                reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
+            reduced, err := reduceSearchResult(context.TODO(), tt.inputs,
+                reduce.NewReduceSearchResultInfo(nq, topK).
+                    WithMetricType(metric.L2).
+                    WithPkType(schemapb.DataType_Int64).
+                    WithGroupByField(queryInfo.GetGroupByFieldId()).
+                    WithGroupSize(queryInfo.GetGroupSize()))
             resultIDs := reduced.GetResults().GetIds().GetIntId().Data
             resultScores := reduced.GetResults().GetScores()
-            resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
-            assert.EqualValues(t, expectedIDs[i], resultIDs)
-            assert.EqualValues(t, expectedScores[i], resultScores)
-            assert.EqualValues(t, expectedGroupByValues[i], resultGroupByValues)
+            resultGroupByValues := reduced.GetResults().GetGroupByFieldValue()
+            assert.EqualValues(t, tt.expectedIDs, resultIDs)
+            assert.EqualValues(t, tt.expectedScores, resultScores)
+            assert.EqualValues(t, tt.expectedGroupByValues, resultGroupByValues)
             assert.NoError(t, err)
         })
     }
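The new "nullable group_by values" case above also pins down the wire format the reduce path must produce. Restating it for one sub-result: makePartialResult receives eight values ({1, 2, 3, 4, 1, 2, 3, 4}) for ten rows, with valids marking rows 4 and 9 as NULL; the reduced column is then expected to come back zero-filled, Data {1, 2, 3, 4, 0, 1, 2, 3, 4, 0} alongside the same ValidData bitmap. The zero fill keeps the group-by column positionally aligned with the reduced IDs and scores, so a NULL key costs a slot but never shifts the rows that follow it.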
@@ -3078,8 +3123,8 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

     for _, test := range tests {
         t.Run(test.description, func(t *testing.T) {
-            searchInfo := parseSearchInfo(test.validParams, nil, nil)
-            assert.NoError(t, searchInfo.parseError)
+            searchInfo, err := parseSearchInfo(test.validParams, nil, nil)
+            assert.NoError(t, err)
             assert.NotNil(t, searchInfo.planInfo)
             if test.description == "offsetParam" {
                 assert.Equal(t, targetOffset, searchInfo.offset)
@@ -3099,8 +3144,8 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
             limit: externalLimit,
         }

-        searchInfo := parseSearchInfo(offsetParam, nil, rank)
-        assert.NoError(t, searchInfo.parseError)
+        searchInfo, err := parseSearchInfo(offsetParam, nil, rank)
+        assert.NoError(t, err)
         assert.NotNil(t, searchInfo.planInfo)
         assert.Equal(t, int64(10), searchInfo.planInfo.GetTopk())
         assert.Equal(t, int64(0), searchInfo.offset)
@@ -3152,8 +3197,8 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
             Value: "true",
         })

-        searchInfo := parseSearchInfo(params, schema, testRankParams)
-        assert.NoError(t, searchInfo.parseError)
+        searchInfo, err := parseSearchInfo(params, schema, testRankParams)
+        assert.NoError(t, err)
         assert.NotNil(t, searchInfo.planInfo)

         // all group_by related parameters should be aligned to parameters
@@ -3242,12 +3287,11 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

     for _, test := range tests {
         t.Run(test.description, func(t *testing.T) {
-            searchInfo := parseSearchInfo(test.invalidParams, nil, nil)
-            assert.Error(t, searchInfo.parseError)
-            assert.Nil(t, searchInfo.planInfo)
-            assert.Zero(t, searchInfo.offset)
+            searchInfo, err := parseSearchInfo(test.invalidParams, nil, nil)
+            assert.Error(t, err)
+            assert.Nil(t, searchInfo)

-            t.Logf("err=%s", searchInfo.parseError)
+            t.Logf("err=%s", err)
         })
     }
 })
@@ -3269,9 +3313,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
     schema := &schemapb.CollectionSchema{
         Fields: fields,
     }
-    searchInfo := parseSearchInfo(normalParam, schema, nil)
-    assert.Nil(t, searchInfo.planInfo)
-    assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
+    searchInfo, err := parseSearchInfo(normalParam, schema, nil)
+    assert.Nil(t, searchInfo)
+    assert.ErrorIs(t, err, merr.ErrParameterInvalid)
 })
 t.Run("check range-search and groupBy", func(t *testing.T) {
     normalParam := getValidSearchParams()
@@ -3288,9 +3332,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
     schema := &schemapb.CollectionSchema{
         Fields: fields,
     }
-    searchInfo := parseSearchInfo(normalParam, schema, nil)
-    assert.Nil(t, searchInfo.planInfo)
-    assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid)
+    searchInfo, err := parseSearchInfo(normalParam, schema, nil)
+    assert.Nil(t, searchInfo)
+    assert.ErrorIs(t, err, merr.ErrParameterInvalid)
 })
 t.Run("check nullable and groupBy", func(t *testing.T) {
     normalParam := getValidSearchParams()
@@ -3307,9 +3351,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
     schema := &schemapb.CollectionSchema{
         Fields: fields,
     }
-    searchInfo := parseSearchInfo(normalParam, schema, nil)
-    assert.NotNil(t, searchInfo.planInfo)
-    assert.NoError(t, searchInfo.parseError)
+    searchInfo, err := parseSearchInfo(normalParam, schema, nil)
+    assert.NotNil(t, searchInfo)
+    assert.NoError(t, err)
 })
 t.Run("check iterator and topK", func(t *testing.T) {
     normalParam := getValidSearchParams()
@@ -3326,9 +3370,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
     schema := &schemapb.CollectionSchema{
         Fields: fields,
     }
-    searchInfo := parseSearchInfo(normalParam, schema, nil)
-    assert.NotNil(t, searchInfo.planInfo)
-    assert.NoError(t, searchInfo.parseError)
+    searchInfo, err := parseSearchInfo(normalParam, schema, nil)
+    assert.NotNil(t, searchInfo)
+    assert.NoError(t, err)
     assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), searchInfo.planInfo.GetTopk())
 })

@@ -3346,26 +3390,26 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
     schema := &schemapb.CollectionSchema{
         Fields: fields,
     }
-    searchInfo := parseSearchInfo(normalParam, schema, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.True(t, strings.Contains(searchInfo.parseError.Error(), "exceeds configured max group size"))
+    _, err := parseSearchInfo(normalParam, schema, nil)
+    assert.Error(t, err)
+    assert.True(t, strings.Contains(err.Error(), "exceeds configured max group size"))
     {
         resetSearchParamsValue(normalParam, GroupSizeKey, `10`)
-        searchInfo = parseSearchInfo(normalParam, schema, nil)
-        assert.NoError(t, searchInfo.parseError)
+        searchInfo, err := parseSearchInfo(normalParam, schema, nil)
+        assert.NoError(t, err)
         assert.Equal(t, int64(10), searchInfo.planInfo.GroupSize)
     }
     {
         resetSearchParamsValue(normalParam, GroupSizeKey, `-1`)
-        searchInfo = parseSearchInfo(normalParam, schema, nil)
-        assert.Error(t, searchInfo.parseError)
-        assert.True(t, strings.Contains(searchInfo.parseError.Error(), "is negative"))
+        _, err := parseSearchInfo(normalParam, schema, nil)
+        assert.Error(t, err)
+        assert.True(t, strings.Contains(err.Error(), "is negative"))
     }
     {
         resetSearchParamsValue(normalParam, GroupSizeKey, `xxx`)
-        searchInfo = parseSearchInfo(normalParam, schema, nil)
-        assert.Error(t, searchInfo.parseError)
-        assert.True(t, strings.Contains(searchInfo.parseError.Error(), "failed to parse input group size"))
+        _, err := parseSearchInfo(normalParam, schema, nil)
+        assert.Error(t, err)
+        assert.True(t, strings.Contains(err.Error(), "failed to parse input group size"))
     }
 })

@@ -3391,8 +3435,8 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

 t.Run("iteratorV2 normal", func(t *testing.T) {
     param := generateValidParamsForSearchIteratorV2()
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.NoError(t, searchInfo.parseError)
+    searchInfo, err := parseSearchInfo(param, nil, nil)
+    assert.NoError(t, err)
     assert.NotNil(t, searchInfo.planInfo)
     assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token)
     assert.Equal(t, kBatchSize, searchInfo.planInfo.SearchIteratorV2Info.BatchSize)
@@ -3403,9 +3447,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
 t.Run("iteratorV2 without isIterator", func(t *testing.T) {
     param := generateValidParamsForSearchIteratorV2()
     resetSearchParamsValue(param, IteratorField, "False")
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.ErrorContains(t, searchInfo.parseError, "both")
+    _, err := parseSearchInfo(param, nil, nil)
+    assert.Error(t, err)
+    assert.ErrorContains(t, err, "both")
 })

 t.Run("iteratorV2 with groupBy", func(t *testing.T) {
@@ -3422,9 +3466,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
     schema := &schemapb.CollectionSchema{
         Fields: fields,
     }
-    searchInfo := parseSearchInfo(param, schema, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.ErrorContains(t, searchInfo.parseError, "roupBy")
+    _, err := parseSearchInfo(param, schema, nil)
+    assert.Error(t, err)
+    assert.ErrorContains(t, err, "roupBy")
 })

 t.Run("iteratorV2 with offset", func(t *testing.T) {
@@ -3433,9 +3477,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
         Key:   OffsetKey,
         Value: "10",
     })
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.ErrorContains(t, searchInfo.parseError, "offset")
+    _, err := parseSearchInfo(param, nil, nil)
+    assert.Error(t, err)
+    assert.ErrorContains(t, err, "offset")
 })

 t.Run("iteratorV2 invalid token", func(t *testing.T) {
@@ -3444,9 +3488,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
         Key:   SearchIterIdKey,
         Value: "invalid_token",
     })
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.ErrorContains(t, searchInfo.parseError, "invalid token format")
+    _, err := parseSearchInfo(param, nil, nil)
+    assert.Error(t, err)
+    assert.ErrorContains(t, err, "invalid token format")
 })

 t.Run("iteratorV2 passed token must be same", func(t *testing.T) {
@@ -3457,8 +3501,8 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
         Key:   SearchIterIdKey,
         Value: token.String(),
     })
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.NoError(t, searchInfo.parseError)
+    searchInfo, err := parseSearchInfo(param, nil, nil)
+    assert.NoError(t, err)
     assert.NotEmpty(t, searchInfo.planInfo.SearchIteratorV2Info.Token)
     assert.Equal(t, token.String(), searchInfo.planInfo.SearchIteratorV2Info.Token)
 })
@@ -3466,33 +3510,33 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
 t.Run("iteratorV2 batch size", func(t *testing.T) {
     param := generateValidParamsForSearchIteratorV2()
     resetSearchParamsValue(param, SearchIterBatchSizeKey, "1.123")
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
+    _, err := parseSearchInfo(param, nil, nil)
+    assert.Error(t, err)
+    assert.ErrorContains(t, err, "batch size is invalid")
 })

 t.Run("iteratorV2 batch size", func(t *testing.T) {
     param := generateValidParamsForSearchIteratorV2()
     resetSearchParamsValue(param, SearchIterBatchSizeKey, "")
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.ErrorContains(t, searchInfo.parseError, "batch size is required")
+    _, err := parseSearchInfo(param, nil, nil)
+    assert.Error(t, err)
+    assert.ErrorContains(t, err, "batch size is required")
 })

 t.Run("iteratorV2 batch size negative", func(t *testing.T) {
     param := generateValidParamsForSearchIteratorV2()
     resetSearchParamsValue(param, SearchIterBatchSizeKey, "-1")
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
+    _, err := parseSearchInfo(param, nil, nil)
+    assert.Error(t, err)
+    assert.ErrorContains(t, err, "batch size is invalid")
 })

 t.Run("iteratorV2 batch size too large", func(t *testing.T) {
     param := generateValidParamsForSearchIteratorV2()
     resetSearchParamsValue(param, SearchIterBatchSizeKey, fmt.Sprintf("%d", Params.QuotaConfig.TopKLimit.GetAsInt64()+1))
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.ErrorContains(t, searchInfo.parseError, "batch size is invalid")
+    _, err := parseSearchInfo(param, nil, nil)
+    assert.Error(t, err)
+    assert.ErrorContains(t, err, "batch size is invalid")
 })

 t.Run("iteratorV2 last bound", func(t *testing.T) {
@@ -3502,8 +3546,8 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
         Key:   SearchIterLastBoundKey,
         Value: fmt.Sprintf("%f", kLastBound),
     })
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.NoError(t, searchInfo.parseError)
+    searchInfo, err := parseSearchInfo(param, nil, nil)
+    assert.NoError(t, err)
     assert.NotNil(t, searchInfo.planInfo)
     assert.Equal(t, kLastBound, *searchInfo.planInfo.SearchIteratorV2Info.LastBound)
 })
@@ -3514,9 +3558,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
         Key:   SearchIterLastBoundKey,
         Value: "xxx",
     })
-    searchInfo := parseSearchInfo(param, nil, nil)
-    assert.Error(t, searchInfo.parseError)
-    assert.ErrorContains(t, searchInfo.parseError, "failed to parse input last bound")
+    _, err := parseSearchInfo(param, nil, nil)
+    assert.Error(t, err)
+    assert.ErrorContains(t, err, "failed to parse input last bound")
 })
 })
 }
@@ -10,7 +10,6 @@ import (
     "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
     "github.com/milvus-io/milvus/internal/util/reduce"
     "github.com/milvus-io/milvus/pkg/v2/log"
-    "github.com/milvus-io/milvus/pkg/v2/util/merr"
    "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
     "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )
@@ -127,12 +126,18 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
     }

     resultOffsets := make([][]int64, len(searchResultData))
-    for i := 0; i < len(searchResultData); i++ {
+    groupByValIterator := make([]func(int) any, len(searchResultData))
+    for i := range searchResultData {
         resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
         for j := int64(1); j < info.GetNq(); j++ {
             resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
         }
         ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
+        groupByValIterator[i] = typeutil.GetDataIterator(searchResultData[i].GetGroupByFieldValue())
+    }
+    gpFieldBuilder, err := typeutil.NewFieldDataBuilder(searchResultData[0].GetGroupByFieldValue().GetType(), true, int(info.GetTopK()))
+    if err != nil {
+        return nil, err
     }

     var filteredCount int64
@@ -159,13 +164,9 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
             idx := resultOffsets[sel][i] + offsets[sel]

             id := typeutil.GetPK(searchResultData[sel].GetIds(), idx)
-            groupByVal := typeutil.GetData(searchResultData[sel].GetGroupByFieldValue(), int(idx))
+            groupByVal := groupByValIterator[sel](int(idx))
             score := searchResultData[sel].Scores[idx]
             if _, ok := idSet[id]; !ok {
-                if groupByVal == nil {
-                    return ret, merr.WrapErrParameterMissing("GroupByVal returned from segment cannot be null")
-                }
-
                 groupCount := groupByValueMap[groupByVal]
                 if groupCount == 0 && int64(len(groupByValueMap)) >= info.GetTopK() {
                     // exceed the limit for group count, filter this entity
@@ -177,10 +178,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
                 retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
                 typeutil.AppendPKs(ret.Ids, id)
                 ret.Scores = append(ret.Scores, score)
-                if err := typeutil.AppendGroupByValue(ret, groupByVal, searchResultData[sel].GetGroupByFieldValue().GetType()); err != nil {
-                    log.Error("Failed to append groupByValues", zap.Error(err))
-                    return ret, err
-                }
+                gpFieldBuilder.Add(groupByVal)
                 groupByValueMap[groupByVal] += 1
                 idSet[id] = struct{}{}
                 j++
@@ -198,6 +196,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
             return nil, fmt.Errorf("search results exceed the maxOutputSize Limit %d", maxOutputSize)
         }
     }
+    ret.GroupByFieldValue = gpFieldBuilder.Build()
     if float64(filteredCount) >= 0.3*float64(groupBound) {
         log.Warn("GroupBy reduce filtered too many results, "+
             "this may influence the final result seriously",
|
|||||||
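Taken together, the proxy-side reduce path now tolerates null grouping keys instead of rejecting them: a null-aware iterator reads per-row group-by values (yielding nil for invalid rows), and a FieldDataBuilder collects the selected values and emits the final GroupByFieldValue once, including its ValidData mask. A condensed sketch of that flow, with illustrative variable names rather than the exact proxy code:

	// Sketch only: how the pieces introduced in this PR fit together.
	itr := typeutil.GetDataIterator(subResult.GetGroupByFieldValue()) // returns nil for null rows
	builder, err := typeutil.NewFieldDataBuilder(subResult.GetGroupByFieldValue().GetType(), true, int(topK))
	if err != nil {
		return nil, err
	}
	for _, idx := range selectedOffsets { // hypothetical selection loop
		builder.Add(itr(int(idx))) // nil is recorded as an invalid row, not an error
	}
	ret.GroupByFieldValue = builder.Build() // carries ValidData when any row was null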
138	pkg/util/typeutil/field_data.go	Normal file
@ -0,0 +1,138 @@
package typeutil

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

type FieldDataBuilder struct {
	dt         schemapb.DataType
	data       []any
	valid      []bool
	hasInvalid bool

	fillZero bool // if true, fill zero value in returned field data for invalid rows
}

func NewFieldDataBuilder(dt schemapb.DataType, fillZero bool, capacity int) (*FieldDataBuilder, error) {
	switch dt {
	case schemapb.DataType_Bool,
		schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64,
		schemapb.DataType_VarChar:
		return &FieldDataBuilder{
			dt:       dt,
			data:     make([]any, 0, capacity),
			valid:    make([]bool, 0, capacity),
			fillZero: fillZero,
		}, nil
	default:
		return nil, fmt.Errorf("not supported field type: %s", dt.String())
	}
}

func (b *FieldDataBuilder) Add(data any) *FieldDataBuilder {
	if data == nil {
		b.hasInvalid = true
		b.valid = append(b.valid, false)
	} else {
		b.data = append(b.data, data)
		b.valid = append(b.valid, true)
	}
	return b
}

func (b *FieldDataBuilder) Build() *schemapb.FieldData {
	field := &schemapb.FieldData{
		Type: b.dt,
	}
	if b.hasInvalid {
		field.ValidData = b.valid
	}

	switch b.dt {
	case schemapb.DataType_Bool:
		val := make([]bool, 0, len(b.valid))
		validIdx := 0
		for _, v := range b.valid {
			if v {
				val = append(val, b.data[validIdx].(bool))
				validIdx++
			} else if b.fillZero {
				val = append(val, false)
			}
		}
		field.Field = &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_BoolData{
					BoolData: &schemapb.BoolArray{
						Data: val,
					},
				},
			},
		}
	case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
		val := make([]int32, 0, len(b.valid))
		validIdx := 0
		for _, v := range b.valid {
			if v {
				val = append(val, b.data[validIdx].(int32))
				validIdx++
			} else if b.fillZero {
				val = append(val, 0)
			}
		}
		field.Field = &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_IntData{
					IntData: &schemapb.IntArray{
						Data: val,
					},
				},
			},
		}
	case schemapb.DataType_Int64:
		val := make([]int64, 0, len(b.valid))
		validIdx := 0
		for _, v := range b.valid {
			if v {
				val = append(val, b.data[validIdx].(int64))
				validIdx++
			} else if b.fillZero {
				val = append(val, 0)
			}
		}
		field.Field = &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{
					LongData: &schemapb.LongArray{
						Data: val,
					},
				},
			},
		}
	case schemapb.DataType_VarChar:
		val := make([]string, 0, len(b.valid))
		validIdx := 0
		for _, v := range b.valid {
			if v {
				val = append(val, b.data[validIdx].(string))
				validIdx++
			} else if b.fillZero {
				val = append(val, "")
			}
		}
		field.Field = &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_StringData{
					StringData: &schemapb.StringArray{
						Data: val,
					},
				},
			},
		}
	default:
		return nil
	}
	return field
}
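For reference, a minimal sketch of the builder in isolation (the values are made up): nil inputs become invalid rows, and with fillZero the output data array keeps one slot per row:

	// Within pkg/util/typeutil; toy values for illustration.
	b, err := NewFieldDataBuilder(schemapb.DataType_Int64, true /*fillZero*/, 4)
	if err != nil {
		panic(err)
	}
	fd := b.Add(int64(7)).Add(nil).Add(int64(9)).Build()
	// fd.ValidData == []bool{true, false, true}
	// fd.GetScalars().GetLongData().GetData() == []int64{7, 0, 9} (zero filled for the null row)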
211	pkg/util/typeutil/field_data_test.go	Normal file
@ -0,0 +1,211 @@
package typeutil

import (
	"testing"

	"github.com/stretchr/testify/assert"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

func TestNewFieldDataBuilder(t *testing.T) {
	tests := []struct {
		name     string
		dt       schemapb.DataType
		fillZero bool
		capacity int
		wantErr  bool
	}{
		{
			name:     "valid bool type",
			dt:       schemapb.DataType_Bool,
			fillZero: true,
			capacity: 10,
			wantErr:  false,
		},
		{
			name:     "valid int32 type",
			dt:       schemapb.DataType_Int32,
			fillZero: false,
			capacity: 5,
			wantErr:  false,
		},
		{
			name:     "valid varchar type",
			dt:       schemapb.DataType_VarChar,
			fillZero: true,
			capacity: 3,
			wantErr:  false,
		},
		{
			name:     "invalid type",
			dt:       schemapb.DataType_FloatVector,
			fillZero: true,
			capacity: 10,
			wantErr:  true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			builder, err := NewFieldDataBuilder(tt.dt, tt.fillZero, tt.capacity)
			if tt.wantErr {
				assert.Error(t, err)
				assert.Nil(t, builder)
			} else {
				assert.NoError(t, err)
				assert.NotNil(t, builder)
				assert.Equal(t, tt.dt, builder.dt)
				assert.Equal(t, tt.fillZero, builder.fillZero)
				assert.Equal(t, 0, len(builder.data))
				assert.Equal(t, 0, len(builder.valid))
			}
		})
	}
}

func TestFieldDataBuilder_Add(t *testing.T) {
	tests := []struct {
		name     string
		dt       schemapb.DataType
		fillZero bool
		inputs   []any
		want     *FieldDataBuilder
	}{
		{
			name:     "add bool values",
			dt:       schemapb.DataType_Bool,
			fillZero: true,
			inputs:   []any{true, nil, false},
			want: &FieldDataBuilder{
				dt:         schemapb.DataType_Bool,
				data:       []any{true, false},
				valid:      []bool{true, false, true},
				hasInvalid: true,
				fillZero:   true,
			},
		},
		{
			name:     "add int32 values",
			dt:       schemapb.DataType_Int32,
			fillZero: false,
			inputs:   []any{int32(1), int32(2), nil},
			want: &FieldDataBuilder{
				dt:         schemapb.DataType_Int32,
				data:       []any{int32(1), int32(2)},
				valid:      []bool{true, true, false},
				hasInvalid: true,
				fillZero:   false,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			builder, err := NewFieldDataBuilder(tt.dt, tt.fillZero, len(tt.inputs))
			assert.NoError(t, err)

			for _, input := range tt.inputs {
				builder = builder.Add(input)
			}

			assert.Equal(t, tt.want.dt, builder.dt)
			assert.Equal(t, tt.want.data, builder.data)
			assert.Equal(t, tt.want.valid, builder.valid)
			assert.Equal(t, tt.want.hasInvalid, builder.hasInvalid)
			assert.Equal(t, tt.want.fillZero, builder.fillZero)
		})
	}
}

func TestFieldDataBuilder_Build(t *testing.T) {
	tests := []struct {
		name     string
		dt       schemapb.DataType
		fillZero bool
		inputs   []any
		want     *schemapb.FieldData
	}{
		{
			name:     "build bool field with fillZero",
			dt:       schemapb.DataType_Bool,
			fillZero: true,
			inputs:   []any{true, nil, false},
			want: &schemapb.FieldData{
				Type:      schemapb.DataType_Bool,
				ValidData: []bool{true, false, true},
				Field: &schemapb.FieldData_Scalars{
					Scalars: &schemapb.ScalarField{
						Data: &schemapb.ScalarField_BoolData{
							BoolData: &schemapb.BoolArray{
								Data: []bool{true, false, false},
							},
						},
					},
				},
			},
		},
		{
			name:     "build int32 field without fillZero",
			dt:       schemapb.DataType_Int32,
			fillZero: false,
			inputs:   []any{int32(1), int32(2), nil},
			want: &schemapb.FieldData{
				Type:      schemapb.DataType_Int32,
				ValidData: []bool{true, true, false},
				Field: &schemapb.FieldData_Scalars{
					Scalars: &schemapb.ScalarField{
						Data: &schemapb.ScalarField_IntData{
							IntData: &schemapb.IntArray{
								Data: []int32{1, 2},
							},
						},
					},
				},
			},
		},
		{
			name:     "build varchar field with fillZero",
			dt:       schemapb.DataType_VarChar,
			fillZero: true,
			inputs:   []any{"hello", nil, "world"},
			want: &schemapb.FieldData{
				Type:      schemapb.DataType_VarChar,
				ValidData: []bool{true, false, true},
				Field: &schemapb.FieldData_Scalars{
					Scalars: &schemapb.ScalarField{
						Data: &schemapb.ScalarField_StringData{
							StringData: &schemapb.StringArray{
								Data: []string{"hello", "", "world"},
							},
						},
					},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			builder, err := NewFieldDataBuilder(tt.dt, tt.fillZero, len(tt.inputs))
			assert.NoError(t, err)

			for _, input := range tt.inputs {
				builder = builder.Add(input)
			}

			got := builder.Build()
			assert.Equal(t, tt.want.Type, got.Type)
			assert.Equal(t, tt.want.ValidData, got.ValidData)

			switch tt.dt {
			case schemapb.DataType_Bool:
				assert.Equal(t, tt.want.GetScalars().GetBoolData().GetData(), got.GetScalars().GetBoolData().GetData())
			case schemapb.DataType_Int32:
				assert.Equal(t, tt.want.GetScalars().GetIntData().GetData(), got.GetScalars().GetIntData().GetData())
			case schemapb.DataType_VarChar:
				assert.Equal(t, tt.want.GetScalars().GetStringData().GetData(), got.GetScalars().GetStringData().GetData())
			}
		})
	}
}
@ -724,21 +724,28 @@ func PrepareResultFieldData(sample []*schemapb.FieldData, topK int64) []*schemap
 // AppendFieldData appends fields data of specified index from src to dst
 func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int64) {
 	for i, fieldData := range src {
-		switch fieldType := fieldData.Field.(type) {
-		case *schemapb.FieldData_Scalars:
-			if dst[i] == nil || dst[i].GetScalars() == nil {
+		if dst[i] == nil {
 			dst[i] = &schemapb.FieldData{
 				Type:      fieldData.Type,
 				FieldName: fieldData.FieldName,
 				FieldId:   fieldData.FieldId,
 				IsDynamic: fieldData.IsDynamic,
-				Field: &schemapb.FieldData_Scalars{
-					Scalars: &schemapb.ScalarField{},
-				},
 			}
 		}
+		// assign null data
 		if len(fieldData.GetValidData()) != 0 {
-			dst[i].ValidData = append(dst[i].ValidData, fieldData.ValidData[idx])
+			if dst[i].ValidData == nil {
+				dst[i].ValidData = make([]bool, 0)
+			}
+			valid := fieldData.ValidData[idx]
+			dst[i].ValidData = append(dst[i].ValidData, valid)
+		}
+		switch fieldType := fieldData.Field.(type) {
+		case *schemapb.FieldData_Scalars:
+			if dst[i] == nil || dst[i].GetScalars() == nil {
+				dst[i].Field = &schemapb.FieldData_Scalars{
+					Scalars: &schemapb.ScalarField{},
+				}
 			}
 			dstScalar := dst[i].GetScalars()
 			switch srcScalar := fieldType.Scalars.Data.(type) {
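In effect, AppendFieldData now copies the per-row validity bit before copying the scalar payload, so nullable columns survive proxy-side merging. A hedged sketch of the caller-visible behavior, using toy, zero-padded data so that indexing by row is self-consistent within the example:

	// Assumes full-length, zero-padded data arrays; toy values only.
	src := []*schemapb.FieldData{{
		Type:      schemapb.DataType_Int64,
		ValidData: []bool{true, false},
		Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{
			Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{42, 0}}},
		}},
	}}
	dst := make([]*schemapb.FieldData, 1)
	typeutil.AppendFieldData(dst, src, 1) // row 1 is null: dst[0].ValidData gains a false entry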
@ -1471,7 +1478,32 @@ func GetPK(data *schemapb.IDs, idx int64) interface{} {
 	return nil
 }

-func GetData(field *schemapb.FieldData, idx int) interface{} {
+func GetDataIterator(field *schemapb.FieldData) func(int) any {
+	if field.ValidData != nil {
+		// unpack valid data
+		idxs := make([]int, len(field.ValidData))
+		validCnt := 0
+		for i, valid := range field.ValidData {
+			if valid {
+				idxs[i] = validCnt
+				validCnt++
+			} else {
+				idxs[i] = -1
+			}
+		}
+		return func(idx int) any {
+			if idxs[idx] == -1 {
+				return nil
+			}
+			return getData(field, idxs[idx])
+		}
+	}
+	return func(idx int) any {
+		return getData(field, idx)
+	}
+}
+
+func getData(field *schemapb.FieldData, idx int) any {
 	switch field.GetType() {
 	case schemapb.DataType_Bool:
 		return field.GetScalars().GetBoolData().GetData()[idx]
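GetDataIterator front-loads the valid-bitmap decoding: it builds a row-to-value index once, so each per-row lookup is O(1) and returns nil for null rows. A small sketch against a compacted nullable column, matching the layout exercised by TestGetDataIterator below:

	field := &schemapb.FieldData{
		Type:      schemapb.DataType_Int64,
		ValidData: []bool{true, false, true},
		Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{
			Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2}}},
		}},
	}
	itr := typeutil.GetDataIterator(field)
	_ = itr(0) // int64(1)
	_ = itr(1) // nil: row 1 is null
	_ = itr(2) // int64(2): the data array stores only the valid values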
@ -1663,64 +1695,6 @@ func SelectMinPKWithTimestamp[T interface {
 	return sel, drainResult
 }
-
-func AppendGroupByValue(dstResData *schemapb.SearchResultData,
-	groupByVal interface{}, srcDataType schemapb.DataType,
-) error {
-	if dstResData.GroupByFieldValue == nil {
-		dstResData.GroupByFieldValue = &schemapb.FieldData{
-			Type: srcDataType,
-			Field: &schemapb.FieldData_Scalars{
-				Scalars: &schemapb.ScalarField{},
-			},
-		}
-	}
-	dstScalarField := dstResData.GroupByFieldValue.GetScalars()
-	switch srcDataType {
-	case schemapb.DataType_Bool:
-		if dstScalarField.GetBoolData() == nil {
-			dstScalarField.Data = &schemapb.ScalarField_BoolData{
-				BoolData: &schemapb.BoolArray{
-					Data: []bool{},
-				},
-			}
-		}
-		dstScalarField.GetBoolData().Data = append(dstScalarField.GetBoolData().Data, groupByVal.(bool))
-	case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
-		if dstScalarField.GetIntData() == nil {
-			dstScalarField.Data = &schemapb.ScalarField_IntData{
-				IntData: &schemapb.IntArray{
-					Data: []int32{},
-				},
-			}
-		}
-		dstScalarField.GetIntData().Data = append(dstScalarField.GetIntData().Data, groupByVal.(int32))
-	case schemapb.DataType_Int64:
-		if dstScalarField.GetLongData() == nil {
-			dstScalarField.Data = &schemapb.ScalarField_LongData{
-				LongData: &schemapb.LongArray{
-					Data: []int64{},
-				},
-			}
-		}
-		dstScalarField.GetLongData().Data = append(dstScalarField.GetLongData().Data, groupByVal.(int64))
-	case schemapb.DataType_VarChar:
-		if dstScalarField.GetStringData() == nil {
-			dstScalarField.Data = &schemapb.ScalarField_StringData{
-				StringData: &schemapb.StringArray{
-					Data: []string{},
-				},
-			}
-		}
-		dstScalarField.GetStringData().Data = append(dstScalarField.GetStringData().Data, groupByVal.(string))
-	default:
-		log.Error("Not supported field type from group_by value field", zap.String("field type",
-			srcDataType.String()))
-		return fmt.Errorf("not supported field type from group_by value field: %s",
-			srcDataType.String())
-	}
-	return nil
-}

 func appendSparseFloatArray(dst, src *schemapb.SparseFloatArray) {
 	if len(src.Contents) == 0 {
 		return
@ -1724,21 +1724,21 @@ func TestGetDataAndGetDataSize(t *testing.T) {
 	})

 	t.Run("test GetData", func(t *testing.T) {
-		boolDataRes := GetData(boolData, 0)
+		boolDataRes := getData(boolData, 0)
-		int8DataRes := GetData(int8Data, 0)
+		int8DataRes := getData(int8Data, 0)
-		int16DataRes := GetData(int16Data, 0)
+		int16DataRes := getData(int16Data, 0)
-		int32DataRes := GetData(int32Data, 0)
+		int32DataRes := getData(int32Data, 0)
-		int64DataRes := GetData(int64Data, 0)
+		int64DataRes := getData(int64Data, 0)
-		floatDataRes := GetData(floatData, 0)
+		floatDataRes := getData(floatData, 0)
-		doubleDataRes := GetData(doubleData, 0)
+		doubleDataRes := getData(doubleData, 0)
-		varCharDataRes := GetData(varCharData, 0)
+		varCharDataRes := getData(varCharData, 0)
-		binVecDataRes := GetData(binVecData, 0)
+		binVecDataRes := getData(binVecData, 0)
-		floatVecDataRes := GetData(floatVecData, 0)
+		floatVecDataRes := getData(floatVecData, 0)
-		float16VecDataRes := GetData(float16VecData, 0)
+		float16VecDataRes := getData(float16VecData, 0)
-		bfloat16VecDataRes := GetData(bfloat16VecData, 0)
+		bfloat16VecDataRes := getData(bfloat16VecData, 0)
-		sparseFloatDataRes := GetData(sparseFloatData, 0)
+		sparseFloatDataRes := getData(sparseFloatData, 0)
-		int8VecDataRes := GetData(int8VecData, 0)
+		int8VecDataRes := getData(int8VecData, 0)
-		invalidDataRes := GetData(invalidData, 0)
+		invalidDataRes := getData(invalidData, 0)

 		assert.Equal(t, BoolArray[0], boolDataRes)
 		assert.Equal(t, int32(Int8Array[0]), int8DataRes)
@ -2872,3 +2872,68 @@ func TestSparsePlaceholderGroupSize(t *testing.T) {
 	// no more than 2% cases have large error ratio.
 	assert.Less(t, largeErrorRatio, 2.0)
 }
+
+func TestGetDataIterator(t *testing.T) {
+	tests := []struct {
+		name  string
+		field *schemapb.FieldData
+		want  []any
+	}{
+		{
+			name: "empty field",
+			field: &schemapb.FieldData{
+				Type: schemapb.DataType_Int64,
+				Field: &schemapb.FieldData_Scalars{
+					Scalars: &schemapb.ScalarField{
+						Data: &schemapb.ScalarField_LongData{
+							LongData: &schemapb.LongArray{},
+						},
+					},
+				},
+			},
+			want: []any{},
+		},
+		{
+			name: "ints",
+			field: &schemapb.FieldData{
+				Type: schemapb.DataType_Int64,
+				Field: &schemapb.FieldData_Scalars{
+					Scalars: &schemapb.ScalarField{
+						Data: &schemapb.ScalarField_LongData{
+							LongData: &schemapb.LongArray{
+								Data: []int64{1, 2, 3},
+							},
+						},
+					},
+				},
+			},
+			want: []any{int64(1), int64(2), int64(3)},
+		},
+		{
+			name: "ints with nulls",
+			field: &schemapb.FieldData{
+				Type: schemapb.DataType_Int64,
+				Field: &schemapb.FieldData_Scalars{
+					Scalars: &schemapb.ScalarField{
+						Data: &schemapb.ScalarField_LongData{
+							LongData: &schemapb.LongArray{
+								Data: []int64{1, 2, 3},
+							},
+						},
+					},
+				},
+				ValidData: []bool{true, false, true, true},
+			},
+			want: []any{int64(1), nil, int64(2), int64(3)},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			itr := GetDataIterator(tt.field)
+			for i, want := range tt.want {
+				got := itr(i)
+				assert.Equal(t, want, got)
+			}
+		})
+	}
+}
321	tests/integration/search/search_test.go	Normal file
@ -0,0 +1,321 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package search

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/suite"
	"go.uber.org/zap"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/proxy"
	"github.com/milvus-io/milvus/internal/util/hookutil"
	"github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
	"github.com/milvus-io/milvus/tests/integration"
)

type SearchSuite struct {
	integration.MiniClusterSuite

	indexType  string
	metricType string
	vecType    schemapb.DataType
}

func (s *SearchSuite) run() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	c := s.Cluster

	const (
		dim    = 128
		dbName = ""
		rowNum = 3000
	)

	collectionName := "TestSearch" + funcutil.GenRandomStr()
	groupByField := integration.VarCharField

	schema := integration.ConstructSchema(collectionName, dim, true,
		&schemapb.FieldSchema{
			FieldID:      100,
			Name:         integration.Int64Field,
			IsPrimaryKey: true,
			Description:  "",
			DataType:     schemapb.DataType_Int64,
			TypeParams:   nil,
			IndexParams:  nil,
			AutoID:       true,
		},
		&schemapb.FieldSchema{
			FieldID:      101,
			Name:         groupByField,
			IsPrimaryKey: false,
			Description:  "",
			DataType:     schemapb.DataType_VarChar,
			TypeParams: []*commonpb.KeyValuePair{
				{
					Key:   common.MaxLengthKey,
					Value: fmt.Sprintf("%d", 256),
				},
			},
			IndexParams: nil,
			Nullable:    true,
		},
		&schemapb.FieldSchema{
			FieldID:      102,
			Name:         integration.FloatVecField,
			IsPrimaryKey: false,
			Description:  "",
			DataType:     schemapb.DataType_FloatVector,
			TypeParams: []*commonpb.KeyValuePair{
				{
					Key:   common.DimKey,
					Value: fmt.Sprintf("%d", dim),
				},
			},
			IndexParams: nil,
		},
	)
	marshaledSchema, err := proto.Marshal(schema)
	s.NoError(err)

	createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      common.DefaultShardsNum,
	})
	s.NoError(err)
	if createCollectionStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("createCollectionStatus fail reason", zap.String("reason", createCollectionStatus.GetReason()))
	}
	s.Require().Equal(createCollectionStatus.GetErrorCode(), commonpb.ErrorCode_Success)

	log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
	showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
	s.NoError(err)
	s.Equal(showCollectionsResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)
	log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))

	var fVecColumn *schemapb.FieldData
	if s.vecType == schemapb.DataType_SparseFloatVector {
		fVecColumn = integration.NewSparseFloatVectorFieldData(integration.SparseFloatVecField, rowNum)
	} else {
		fVecColumn = integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
	}
	fVarCharColumn := integration.NewVarCharFieldData(integration.VarCharField, rowNum, true)
	hashKeys := integration.GenerateHashKeys(rowNum)
	insertCheckReport := func() {
		timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
		defer cancelFunc()

		for {
			select {
			case <-timeoutCtx.Done():
				s.Fail("insert check timeout")
			case report := <-c.Extension.GetReportChan():
				reportInfo := report.(map[string]any)
				log.Info("insert report info", zap.Any("reportInfo", reportInfo))
				s.Equal(hookutil.OpTypeInsert, reportInfo[hookutil.OpTypeKey])
				s.NotEqualValues(0, reportInfo[hookutil.RequestDataSizeKey])
				return
			}
		}
	}
	go insertCheckReport()
	insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		FieldsData:     []*schemapb.FieldData{fVarCharColumn, fVecColumn},
		HashKeys:       hashKeys,
		NumRows:        uint32(rowNum),
	})
	s.NoError(err)
	s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)

	// flush
	flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
		DbName:          dbName,
		CollectionNames: []string{collectionName},
	})
	s.NoError(err)
	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
	ids := segmentIDs.GetData()
	s.Require().NotEmpty(segmentIDs)
	s.Require().True(has)
	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
	s.True(has)

	segments, err := c.MetaWatcher.ShowSegments()
	s.NoError(err)
	s.NotEmpty(segments)
	for _, segment := range segments {
		log.Info("ShowSegments result", zap.String("segment", segment.String()))
	}
	s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)

	// create index
	createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      fVecColumn.FieldName,
		IndexName:      "_default",
		ExtraParams:    integration.ConstructIndexParam(dim, s.indexType, s.metricType),
	})
	if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
	}
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())

	s.WaitForIndexBuilt(ctx, collectionName, fVecColumn.FieldName)

	// load
	loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
	})
	s.NoError(err)
	if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
	}
	s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
	s.WaitForLoad(ctx, collectionName)

	// search
	expr := fmt.Sprintf("%s > 0", integration.Int64Field)
	nq := 10
	topk := 10
	roundDecimal := -1

	params := integration.GetSearchParams(s.indexType, s.metricType)
	searchReq := integration.ConstructSearchRequest("", collectionName, expr,
		fVecColumn.FieldName, s.vecType, nil, s.metricType, params, nq, dim, topk, roundDecimal)
	searchReq.SearchParams = append(searchReq.SearchParams, &commonpb.KeyValuePair{
		Key:   proxy.GroupByFieldKey,
		Value: groupByField,
	})

	searchCheckReport := func() {
		timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
		defer cancelFunc()

		for {
			select {
			case <-timeoutCtx.Done():
				s.Fail("search check timeout")
			case report := <-c.Extension.GetReportChan():
				reportInfo := report.(map[string]any)
				log.Info("search report info", zap.Any("reportInfo", reportInfo))
				s.Equal(hookutil.OpTypeSearch, reportInfo[hookutil.OpTypeKey])
				s.NotEqualValues(0, reportInfo[hookutil.ResultDataSizeKey])
				s.NotEqualValues(0, reportInfo[hookutil.RelatedDataSizeKey])
				s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
				return
			}
		}
	}
	go searchCheckReport()
	searchResult, err := c.Proxy.Search(ctx, searchReq)
	err = merr.CheckRPCCall(searchResult, err)
	s.NoError(err)

	results := searchResult.GetResults()
	offset := 0
	// verify group by field corresponds to fVarCharColumn
	for i := range results.NumQueries {
		k := int(results.Topks[i])
		itr := typeutil.GetDataIterator(results.GroupByFieldValue)
		m := make(map[any]any, k) // test if the group by field values are unique
		for j := 0; j < k; j++ {
			gpbVal := itr(offset + j)
			s.NotContains(m, gpbVal)
			m[gpbVal] = struct{}{}
		}
		offset += k
		s.Equal(len(m), k)
	}

	queryCheckReport := func() {
		timeoutCtx, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
		defer cancelFunc()

		for {
			select {
			case <-timeoutCtx.Done():
				s.Fail("query check timeout")
			case report := <-c.Extension.GetReportChan():
				reportInfo := report.(map[string]any)
				log.Info("query report info", zap.Any("reportInfo", reportInfo))
				s.Equal(hookutil.OpTypeQuery, reportInfo[hookutil.OpTypeKey])
				s.NotEqualValues(0, reportInfo[hookutil.ResultDataSizeKey])
				s.NotEqualValues(0, reportInfo[hookutil.RelatedDataSizeKey])
				s.EqualValues(rowNum, reportInfo[hookutil.RelatedCntKey])
				return
			}
		}
	}
	go queryCheckReport()
	queryResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Expr:           "",
		OutputFields:   []string{"count(*)"},
	})
	if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
	}
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())

	status, err := c.Proxy.ReleaseCollection(ctx, &milvuspb.ReleaseCollectionRequest{
		CollectionName: collectionName,
	})
	err = merr.CheckRPCCall(status, err)
	s.NoError(err)

	status, err = c.Proxy.DropCollection(ctx, &milvuspb.DropCollectionRequest{
		CollectionName: collectionName,
	})
	err = merr.CheckRPCCall(status, err)
	s.NoError(err)

	log.Info("TestSearch succeed")
}

func (s *SearchSuite) TestSearch() {
	s.indexType = integration.IndexFaissIvfFlat
	s.metricType = metric.L2
	s.vecType = schemapb.DataType_FloatVector
	s.run()
}

func TestSearch(t *testing.T) {
	suite.Run(t, new(SearchSuite))
}
@ -131,6 +131,28 @@ func NewVarCharSameFieldData(fieldName string, numRows int, value string) *schem
 	}
 }
+
+func NewVarCharFieldData(fieldName string, numRows int, nullable bool) *schemapb.FieldData {
+	numValid := numRows
+	if nullable {
+		numValid = numRows / 2
+	}
+	return &schemapb.FieldData{
+		Type:      schemapb.DataType_String,
+		FieldName: fieldName,
+		Field: &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_StringData{
+					StringData: &schemapb.StringArray{
+						Data: testutils.GenerateStringArray(numValid),
+						// Data: testutils.GenerateStringArray(numRows),
+					},
+				},
+			},
+		},
+		ValidData: testutils.GenerateBoolArray(numRows),
+	}
+}

 func NewStringFieldData(fieldName string, numRows int) *schemapb.FieldData {
 	return testutils.NewStringFieldData(fieldName, numRows)
 }
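Note the helper's invariant: with nullable set it generates numRows/2 string values against a full-length validity mask, i.e. the data array is compacted to valid rows only, which presumably relies on testutils.GenerateBoolArray marking half of the rows valid. A hypothetical call from a test:

	// Hypothetical usage; the field name and row count are illustrative.
	col := integration.NewVarCharFieldData("group_key", 3000, true)
	// len(col.GetScalars().GetStringData().GetData()) == 1500 (valid rows only)
	// len(col.GetValidData()) == 3000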