mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 01:58:34 +08:00
feat: supporting hybrid search group_by (#35982)
related: #35096 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com>
This commit is contained in:
parent
62f4a6a112
commit
e480b103bd
@ -170,7 +170,8 @@ func TestMeta_ScalarAutoIndex(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Key: common.IndexTypeKey,
|
Key: common.IndexTypeKey,
|
||||||
Value: "HYBRID",
|
Value: "HYBRID",
|
||||||
}},
|
},
|
||||||
|
},
|
||||||
Timestamp: 0,
|
Timestamp: 0,
|
||||||
IsAutoIndex: true,
|
IsAutoIndex: true,
|
||||||
UserIndexParams: userIndexParams,
|
UserIndexParams: userIndexParams,
|
||||||
@ -205,7 +206,6 @@ func TestMeta_ScalarAutoIndex(t *testing.T) {
|
|||||||
assert.Equal(t, newIndexParams[0].Key, common.IndexTypeKey)
|
assert.Equal(t, newIndexParams[0].Key, common.IndexTypeKey)
|
||||||
assert.Equal(t, newIndexParams[0].Value, "INVERTED")
|
assert.Equal(t, newIndexParams[0].Value, "INVERTED")
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMeta_CanCreateIndex(t *testing.T) {
|
func TestMeta_CanCreateIndex(t *testing.T) {
|
||||||
|
|||||||
@ -769,7 +769,7 @@ func (s *taskSchedulerSuite) scheduler(handler Handler) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil)
|
catalog.EXPECT().AlterSegmentIndexes(mock.Anything, mock.Anything).Return(nil)
|
||||||
//catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil)
|
// catalog.EXPECT().SaveStatsTask(mock.Anything, mock.Anything).Return(nil)
|
||||||
|
|
||||||
in := mocks.NewMockIndexNodeClient(s.T())
|
in := mocks.NewMockIndexNodeClient(s.T())
|
||||||
in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
in.EXPECT().CreateJobV2(mock.Anything, mock.Anything).Return(merr.Success(), nil)
|
||||||
|
|||||||
@ -63,7 +63,7 @@ func mergeSortMultipleSegments(ctx context.Context,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
//SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID())
|
// SegmentDeserializeReaderTest(binlogPaths, t.binlogIO, writer.GetPkID())
|
||||||
segmentReaders := make([]*SegmentDeserializeReader, len(binlogs))
|
segmentReaders := make([]*SegmentDeserializeReader, len(binlogs))
|
||||||
for i, s := range binlogs {
|
for i, s := range binlogs {
|
||||||
var binlogBatchCount int
|
var binlogBatchCount int
|
||||||
|
|||||||
@ -119,7 +119,6 @@ func (s *PriorityQueueSuite) PriorityQueueMergeSort() {
|
|||||||
heap.Push(&pq, next)
|
heap.Push(&pq, next)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewPriorityQueueSuite(t *testing.T) {
|
func TestNewPriorityQueueSuite(t *testing.T) {
|
||||||
|
|||||||
@ -94,11 +94,8 @@ message SubSearchRequest {
|
|||||||
int64 topk = 7;
|
int64 topk = 7;
|
||||||
int64 offset = 8;
|
int64 offset = 8;
|
||||||
string metricType = 9;
|
string metricType = 9;
|
||||||
}
|
int64 group_by_field_id = 10;
|
||||||
|
int64 group_size = 11;
|
||||||
message ExtraSearchParam {
|
|
||||||
int64 group_by_field_id = 1;
|
|
||||||
int64 group_size = 2;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message SearchRequest {
|
message SearchRequest {
|
||||||
@ -125,7 +122,8 @@ message SearchRequest {
|
|||||||
bool is_advanced = 20;
|
bool is_advanced = 20;
|
||||||
int64 offset = 21;
|
int64 offset = 21;
|
||||||
common.ConsistencyLevel consistency_level = 22;
|
common.ConsistencyLevel consistency_level = 22;
|
||||||
ExtraSearchParam extra_search_param = 23;
|
int64 group_by_field_id = 23;
|
||||||
|
int64 group_size = 24;
|
||||||
}
|
}
|
||||||
|
|
||||||
message SubSearchResults {
|
message SubSearchResults {
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
@ -20,54 +20,137 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type reduceSearchResultInfo struct {
|
func reduceSearchResult(ctx context.Context, subSearchResultData []*schemapb.SearchResultData, reduceInfo *reduce.ResultInfo) (*milvuspb.SearchResults, error) {
|
||||||
subSearchResultData []*schemapb.SearchResultData
|
if reduceInfo.GetGroupByFieldId() > 0 {
|
||||||
nq int64
|
if reduceInfo.GetIsAdvance() {
|
||||||
topK int64
|
// for hybrid search group by, we cannot reduce result for results from one single search path,
|
||||||
metricType string
|
// because the final score has not been accumulated, also, offset cannot be applied
|
||||||
pkType schemapb.DataType
|
return reduceAdvanceGroupBY(ctx,
|
||||||
offset int64
|
subSearchResultData, reduceInfo.GetNq(), reduceInfo.GetTopK(), reduceInfo.GetPkType(), reduceInfo.GetMetricType())
|
||||||
queryInfo *planpb.QueryInfo
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func NewReduceSearchResultInfo(
|
|
||||||
subSearchResultData []*schemapb.SearchResultData,
|
|
||||||
nq int64,
|
|
||||||
topK int64,
|
|
||||||
metricType string,
|
|
||||||
pkType schemapb.DataType,
|
|
||||||
offset int64,
|
|
||||||
queryInfo *planpb.QueryInfo,
|
|
||||||
) *reduceSearchResultInfo {
|
|
||||||
return &reduceSearchResultInfo{
|
|
||||||
subSearchResultData: subSearchResultData,
|
|
||||||
nq: nq,
|
|
||||||
topK: topK,
|
|
||||||
metricType: metricType,
|
|
||||||
pkType: pkType,
|
|
||||||
offset: offset,
|
|
||||||
queryInfo: queryInfo,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func reduceSearchResult(ctx context.Context, reduceInfo *reduceSearchResultInfo) (*milvuspb.SearchResults, error) {
|
|
||||||
if reduceInfo.queryInfo.GroupByFieldId > 0 {
|
|
||||||
return reduceSearchResultDataWithGroupBy(ctx,
|
return reduceSearchResultDataWithGroupBy(ctx,
|
||||||
reduceInfo.subSearchResultData,
|
subSearchResultData,
|
||||||
reduceInfo.nq,
|
reduceInfo.GetNq(),
|
||||||
reduceInfo.topK,
|
reduceInfo.GetTopK(),
|
||||||
reduceInfo.metricType,
|
reduceInfo.GetMetricType(),
|
||||||
reduceInfo.pkType,
|
reduceInfo.GetPkType(),
|
||||||
reduceInfo.offset,
|
reduceInfo.GetOffset(),
|
||||||
reduceInfo.queryInfo.GroupSize)
|
reduceInfo.GetGroupSize())
|
||||||
}
|
}
|
||||||
return reduceSearchResultDataNoGroupBy(ctx,
|
return reduceSearchResultDataNoGroupBy(ctx,
|
||||||
reduceInfo.subSearchResultData,
|
subSearchResultData,
|
||||||
reduceInfo.nq,
|
reduceInfo.GetNq(),
|
||||||
reduceInfo.topK,
|
reduceInfo.GetTopK(),
|
||||||
reduceInfo.metricType,
|
reduceInfo.GetMetricType(),
|
||||||
reduceInfo.pkType,
|
reduceInfo.GetPkType(),
|
||||||
reduceInfo.offset)
|
reduceInfo.GetOffset())
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkResultDatas(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
|
||||||
|
nq int64, topK int64,
|
||||||
|
) (int64, int, error) {
|
||||||
|
var allSearchCount int64
|
||||||
|
var hitNum int
|
||||||
|
for i, sData := range subSearchResultData {
|
||||||
|
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
|
||||||
|
log.Ctx(ctx).Debug("subSearchResultData",
|
||||||
|
zap.Int("result No.", i),
|
||||||
|
zap.Int64("nq", sData.NumQueries),
|
||||||
|
zap.Int64("topk", sData.TopK),
|
||||||
|
zap.Int("length of pks", pkLength),
|
||||||
|
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
||||||
|
allSearchCount += sData.GetAllSearchCount()
|
||||||
|
hitNum += pkLength
|
||||||
|
if err := checkSearchResultData(sData, nq, topK, pkLength); err != nil {
|
||||||
|
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||||
|
return allSearchCount, hitNum, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return allSearchCount, hitNum, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func reduceAdvanceGroupBY(ctx context.Context, subSearchResultData []*schemapb.SearchResultData,
|
||||||
|
nq int64, topK int64, pkType schemapb.DataType, metricType string,
|
||||||
|
) (*milvuspb.SearchResults, error) {
|
||||||
|
log.Ctx(ctx).Debug("reduceAdvanceGroupBY", zap.Int("len(subSearchResultData)", len(subSearchResultData)), zap.Int64("nq", nq))
|
||||||
|
// for advance group by, offset is not applied, so just return when there's only one channel
|
||||||
|
if len(subSearchResultData) == 1 {
|
||||||
|
return &milvuspb.SearchResults{
|
||||||
|
Status: merr.Success(),
|
||||||
|
Results: subSearchResultData[0],
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := &milvuspb.SearchResults{
|
||||||
|
Status: merr.Success(),
|
||||||
|
Results: &schemapb.SearchResultData{
|
||||||
|
NumQueries: nq,
|
||||||
|
TopK: topK,
|
||||||
|
Scores: []float32{},
|
||||||
|
Ids: &schemapb.IDs{},
|
||||||
|
Topks: []int64{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var limit int64
|
||||||
|
if allSearchCount, hitNum, err := checkResultDatas(ctx, subSearchResultData, nq, topK); err != nil {
|
||||||
|
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||||
|
return ret, err
|
||||||
|
} else {
|
||||||
|
ret.GetResults().AllSearchCount = allSearchCount
|
||||||
|
limit = int64(hitNum)
|
||||||
|
ret.GetResults().FieldsData = typeutil.PrepareResultFieldData(subSearchResultData[0].GetFieldsData(), limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
subSearchNum = len(subSearchResultData)
|
||||||
|
// for results of each subSearchResultData, storing the start offset of each query of nq queries
|
||||||
|
subSearchNqOffset = make([][]int64, subSearchNum)
|
||||||
|
)
|
||||||
|
for i := 0; i < subSearchNum; i++ {
|
||||||
|
subSearchNqOffset[i] = make([]int64, subSearchResultData[i].GetNumQueries())
|
||||||
|
for j := int64(1); j < nq; j++ {
|
||||||
|
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// reducing nq * topk results
|
||||||
|
for nqIdx := int64(0); nqIdx < nq; nqIdx++ {
|
||||||
|
dataCount := int64(0)
|
||||||
|
for subIdx := 0; subIdx < subSearchNum; subIdx += 1 {
|
||||||
|
subData := subSearchResultData[subIdx]
|
||||||
|
subPks := subData.GetIds()
|
||||||
|
subScores := subData.GetScores()
|
||||||
|
subGroupByVals := subData.GetGroupByFieldValue()
|
||||||
|
|
||||||
|
nqTopK := subData.Topks[nqIdx]
|
||||||
|
for i := int64(0); i < nqTopK; i++ {
|
||||||
|
innerIdx := subSearchNqOffset[subIdx][nqIdx] + i
|
||||||
|
pk := typeutil.GetPK(subPks, innerIdx)
|
||||||
|
score := subScores[innerIdx]
|
||||||
|
groupByVal := typeutil.GetData(subData.GetGroupByFieldValue(), int(innerIdx))
|
||||||
|
typeutil.AppendPKs(ret.Results.Ids, pk)
|
||||||
|
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||||
|
if err := typeutil.AppendGroupByValue(ret.Results, groupByVal, subGroupByVals.GetType()); err != nil {
|
||||||
|
log.Ctx(ctx).Error("failed to append groupByValues", zap.Error(err))
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
dataCount += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret.Results.Topks = append(ret.Results.Topks, dataCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
ret.Results.TopK = topK // realTopK is the topK of the nq-th query
|
||||||
|
if !metric.PositivelyRelated(metricType) {
|
||||||
|
for k := range ret.Results.Scores {
|
||||||
|
ret.Results.Scores[k] *= -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MilvusPKType interface{}
|
type MilvusPKType interface{}
|
||||||
@ -109,37 +192,16 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
|||||||
Topks: []int64{},
|
Topks: []int64{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
groupBound := groupSize * limit
|
||||||
switch pkType {
|
if err := setupIdListForSearchResult(ret, pkType, groupBound); err != nil {
|
||||||
case schemapb.DataType_Int64:
|
return ret, nil
|
||||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
|
||||||
IntId: &schemapb.LongArray{
|
|
||||||
Data: make([]int64, 0, limit),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
case schemapb.DataType_VarChar:
|
|
||||||
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
|
||||||
StrId: &schemapb.StringArray{
|
|
||||||
Data: make([]string, 0, limit),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, errors.New("unsupported pk type")
|
|
||||||
}
|
}
|
||||||
for i, sData := range subSearchResultData {
|
|
||||||
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
|
if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
|
||||||
log.Ctx(ctx).Debug("subSearchResultData",
|
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||||
zap.Int("result No.", i),
|
return ret, err
|
||||||
zap.Int64("nq", sData.NumQueries),
|
} else {
|
||||||
zap.Int64("topk", sData.TopK),
|
ret.GetResults().AllSearchCount = allSearchCount
|
||||||
zap.Int("length of pks", pkLength),
|
|
||||||
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
|
||||||
ret.Results.AllSearchCount += sData.GetAllSearchCount()
|
|
||||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
|
||||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
|
||||||
return ret, err
|
|
||||||
}
|
|
||||||
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -163,7 +225,6 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
|
|||||||
|
|
||||||
var retSize int64
|
var retSize int64
|
||||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||||
groupBound := groupSize * limit
|
|
||||||
|
|
||||||
// reducing nq * topk results
|
// reducing nq * topk results
|
||||||
for i := int64(0); i < nq; i++ {
|
for i := int64(0); i < nq; i++ {
|
||||||
@ -298,36 +359,15 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
switch pkType {
|
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
||||||
case schemapb.DataType_Int64:
|
return ret, nil
|
||||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
|
||||||
IntId: &schemapb.LongArray{
|
|
||||||
Data: make([]int64, 0, limit),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
case schemapb.DataType_VarChar:
|
|
||||||
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
|
||||||
StrId: &schemapb.StringArray{
|
|
||||||
Data: make([]string, 0, limit),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return nil, errors.New("unsupported pk type")
|
|
||||||
}
|
}
|
||||||
for i, sData := range subSearchResultData {
|
|
||||||
pkLength := typeutil.GetSizeOfIDs(sData.GetIds())
|
if allSearchCount, _, err := checkResultDatas(ctx, subSearchResultData, nq, topk); err != nil {
|
||||||
log.Ctx(ctx).Debug("subSearchResultData",
|
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
||||||
zap.Int("result No.", i),
|
return ret, err
|
||||||
zap.Int64("nq", sData.NumQueries),
|
} else {
|
||||||
zap.Int64("topk", sData.TopK),
|
ret.GetResults().AllSearchCount = allSearchCount
|
||||||
zap.Int("length of pks", pkLength),
|
|
||||||
zap.Int("length of FieldsData", len(sData.FieldsData)))
|
|
||||||
ret.Results.AllSearchCount += sData.GetAllSearchCount()
|
|
||||||
if err := checkSearchResultData(sData, nq, topk); err != nil {
|
|
||||||
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
|
|
||||||
return ret, err
|
|
||||||
}
|
|
||||||
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -428,23 +468,215 @@ func rankSearchResultData(ctx context.Context,
|
|||||||
params *rankParams,
|
params *rankParams,
|
||||||
pkType schemapb.DataType,
|
pkType schemapb.DataType,
|
||||||
searchResults []*milvuspb.SearchResults,
|
searchResults []*milvuspb.SearchResults,
|
||||||
|
groupByFieldID int64,
|
||||||
|
groupSize int64,
|
||||||
|
groupScorer func(group *Group) error,
|
||||||
) (*milvuspb.SearchResults, error) {
|
) (*milvuspb.SearchResults, error) {
|
||||||
tr := timerecord.NewTimeRecorder("rankSearchResultData")
|
if groupByFieldID > 0 {
|
||||||
|
return rankSearchResultDataByGroup(ctx, nq, params, pkType, searchResults, groupScorer, groupSize)
|
||||||
|
}
|
||||||
|
return rankSearchResultDataByPk(ctx, nq, params, pkType, searchResults)
|
||||||
|
}
|
||||||
|
|
||||||
|
// compareKey reports whether keyI orders strictly before keyJ.
//
// Only int64 and string keys are ordered. It returns false for any other
// dynamic type and, instead of panicking, when the two keys do not share the
// same dynamic type.
func compareKey(keyI interface{}, keyJ interface{}) bool {
	switch left := keyI.(type) {
	case int64:
		// comma-ok keeps a mismatched right-hand key from panicking.
		right, ok := keyJ.(int64)
		return ok && left < right
	case string:
		right, ok := keyJ.(string)
		return ok && left < right
	}
	return false
}
|
||||||
|
|
||||||
|
func GetGroupScorer(scorerType string) (func(group *Group) error, error) {
|
||||||
|
switch scorerType {
|
||||||
|
case MaxScorer:
|
||||||
|
return func(group *Group) error {
|
||||||
|
group.finalScore = group.maxScore
|
||||||
|
return nil
|
||||||
|
}, nil
|
||||||
|
case SumScorer:
|
||||||
|
return func(group *Group) error {
|
||||||
|
group.finalScore = group.sumScore
|
||||||
|
return nil
|
||||||
|
}, nil
|
||||||
|
case AvgScorer:
|
||||||
|
return func(group *Group) error {
|
||||||
|
if len(group.idList) == 0 {
|
||||||
|
return merr.WrapErrParameterInvalid(1, len(group.idList),
|
||||||
|
"input group for score must have at least one id, must be sth wrong within code")
|
||||||
|
}
|
||||||
|
group.finalScore = group.sumScore / float32(len(group.idList))
|
||||||
|
return nil
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
return nil, merr.WrapErrParameterInvalidMsg("input group scorer type: %s is not supported!", scorerType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group collects the hits of one query that share the same group-by value,
// together with the score aggregates used to rank groups against each other.
type Group struct {
	// idList holds the primary keys kept for this group; scoreList holds the
	// score of each key at the same index.
	idList    []interface{}
	scoreList []float32
	// groupVal is the group-by field value shared by every id in the group.
	groupVal interface{}
	// maxScore and sumScore are aggregates over scoreList.
	maxScore float32
	sumScore float32
	// finalScore is the value groups are ordered by; a group scorer sets it.
	finalScore float32
}
|
||||||
|
|
||||||
|
func rankSearchResultDataByGroup(ctx context.Context,
|
||||||
|
nq int64,
|
||||||
|
params *rankParams,
|
||||||
|
pkType schemapb.DataType,
|
||||||
|
searchResults []*milvuspb.SearchResults,
|
||||||
|
groupScorer func(group *Group) error,
|
||||||
|
groupSize int64,
|
||||||
|
) (*milvuspb.SearchResults, error) {
|
||||||
|
tr := timerecord.NewTimeRecorder("rankSearchResultDataByGroup")
|
||||||
defer func() {
|
defer func() {
|
||||||
tr.CtxElapse(ctx, "done")
|
tr.CtxElapse(ctx, "done")
|
||||||
}()
|
}()
|
||||||
|
offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
|
||||||
offset := params.offset
|
// in the context of group by, the meaning for offset/limit/top refers to related numbers of group
|
||||||
limit := params.limit
|
groupTopK := limit + offset
|
||||||
topk := limit + offset
|
log.Ctx(ctx).Debug("rankSearchResultDataByGroup",
|
||||||
roundDecimal := params.roundDecimal
|
|
||||||
log.Ctx(ctx).Debug("rankSearchResultData",
|
|
||||||
zap.Int("len(searchResults)", len(searchResults)),
|
zap.Int("len(searchResults)", len(searchResults)),
|
||||||
zap.Int64("nq", nq),
|
zap.Int64("nq", nq),
|
||||||
zap.Int64("offset", offset),
|
zap.Int64("offset", offset),
|
||||||
zap.Int64("limit", limit))
|
zap.Int64("limit", limit))
|
||||||
|
|
||||||
ret := &milvuspb.SearchResults{
|
var ret *milvuspb.SearchResults
|
||||||
|
if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
totalCount := limit * groupSize
|
||||||
|
if err := setupIdListForSearchResult(ret, pkType, totalCount); err != nil {
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type accumulateIDGroupVal struct {
|
||||||
|
accumulatedScore float32
|
||||||
|
groupVal interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
accumulatedScores := make([]map[interface{}]*accumulateIDGroupVal, nq)
|
||||||
|
for i := int64(0); i < nq; i++ {
|
||||||
|
accumulatedScores[i] = make(map[interface{}]*accumulateIDGroupVal)
|
||||||
|
}
|
||||||
|
groupByDataType := searchResults[0].GetResults().GetGroupByFieldValue().GetType()
|
||||||
|
for _, result := range searchResults {
|
||||||
|
scores := result.GetResults().GetScores()
|
||||||
|
start := 0
|
||||||
|
// milvus has limits for the value range of nq and limit
|
||||||
|
// no matter on 32-bit and 64-bit platform, converting nq and topK into int is safe
|
||||||
|
for i := 0; i < int(nq); i++ {
|
||||||
|
realTopK := int(result.GetResults().Topks[i])
|
||||||
|
for j := start; j < start+realTopK; j++ {
|
||||||
|
id := typeutil.GetPK(result.GetResults().GetIds(), int64(j))
|
||||||
|
groupByVal := typeutil.GetData(result.GetResults().GetGroupByFieldValue(), j)
|
||||||
|
if accumulatedScores[i][id] != nil {
|
||||||
|
accumulatedScores[i][id].accumulatedScore += scores[j]
|
||||||
|
} else {
|
||||||
|
accumulatedScores[i][id] = &accumulateIDGroupVal{accumulatedScore: scores[j], groupVal: groupByVal}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start += realTopK
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := int64(0); i < nq; i++ {
|
||||||
|
idSet := accumulatedScores[i]
|
||||||
|
keys := make([]interface{}, 0)
|
||||||
|
for key := range idSet {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort id by score
|
||||||
|
big := func(i, j int) bool {
|
||||||
|
scoreItemI := idSet[keys[i]]
|
||||||
|
scoreItemJ := idSet[keys[j]]
|
||||||
|
if scoreItemI.accumulatedScore == scoreItemJ.accumulatedScore {
|
||||||
|
return compareKey(keys[i], keys[j])
|
||||||
|
}
|
||||||
|
return scoreItemI.accumulatedScore > scoreItemJ.accumulatedScore
|
||||||
|
}
|
||||||
|
sort.Slice(keys, big)
|
||||||
|
|
||||||
|
// separate keys into buckets according to groupVal
|
||||||
|
buckets := make(map[interface{}]*Group)
|
||||||
|
for _, key := range keys {
|
||||||
|
scoreItem := idSet[key]
|
||||||
|
groupVal := scoreItem.groupVal
|
||||||
|
if buckets[groupVal] == nil {
|
||||||
|
buckets[groupVal] = &Group{
|
||||||
|
idList: make([]interface{}, 0),
|
||||||
|
scoreList: make([]float32, 0),
|
||||||
|
groupVal: groupVal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if int64(len(buckets[groupVal].idList)) >= groupSize {
|
||||||
|
// only consider group size results in each group
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
buckets[groupVal].idList = append(buckets[groupVal].idList, key)
|
||||||
|
buckets[groupVal].scoreList = append(buckets[groupVal].scoreList, scoreItem.accumulatedScore)
|
||||||
|
if scoreItem.accumulatedScore > buckets[groupVal].maxScore {
|
||||||
|
buckets[groupVal].maxScore = scoreItem.accumulatedScore
|
||||||
|
}
|
||||||
|
buckets[groupVal].sumScore += scoreItem.accumulatedScore
|
||||||
|
}
|
||||||
|
if int64(len(buckets)) <= offset {
|
||||||
|
ret.Results.Topks = append(ret.Results.Topks, 0)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
groupList := make([]*Group, len(buckets))
|
||||||
|
idx := 0
|
||||||
|
for _, group := range buckets {
|
||||||
|
groupScorer(group)
|
||||||
|
groupList[idx] = group
|
||||||
|
idx += 1
|
||||||
|
}
|
||||||
|
sort.Slice(groupList, func(i, j int) bool {
|
||||||
|
if groupList[i].finalScore == groupList[j].finalScore {
|
||||||
|
if len(groupList[i].idList) == len(groupList[j].idList) {
|
||||||
|
// if final score and size of group are both equal
|
||||||
|
// choose the group with smaller first key
|
||||||
|
// here, it's guaranteed all group having at least one id in the idList
|
||||||
|
return compareKey(groupList[i].idList[0], groupList[j].idList[0])
|
||||||
|
}
|
||||||
|
// choose the larger group when scores are equal
|
||||||
|
return len(groupList[i].idList) > len(groupList[j].idList)
|
||||||
|
}
|
||||||
|
return groupList[i].finalScore > groupList[j].finalScore
|
||||||
|
})
|
||||||
|
|
||||||
|
if int64(len(groupList)) > groupTopK {
|
||||||
|
groupList = groupList[:groupTopK]
|
||||||
|
}
|
||||||
|
returnedRowNum := 0
|
||||||
|
for index := int(offset); index < len(groupList); index++ {
|
||||||
|
group := groupList[index]
|
||||||
|
for i, score := range group.scoreList {
|
||||||
|
// idList and scoreList must have same length
|
||||||
|
typeutil.AppendPKs(ret.Results.Ids, group.idList[i])
|
||||||
|
if roundDecimal != -1 {
|
||||||
|
multiplier := math.Pow(10.0, float64(roundDecimal))
|
||||||
|
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
|
||||||
|
}
|
||||||
|
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||||
|
typeutil.AppendGroupByValue(ret.Results, group.groupVal, groupByDataType)
|
||||||
|
}
|
||||||
|
returnedRowNum += len(group.idList)
|
||||||
|
}
|
||||||
|
ret.Results.Topks = append(ret.Results.Topks, int64(returnedRowNum))
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func initSearchResults(nq int64, limit int64) *milvuspb.SearchResults {
|
||||||
|
return &milvuspb.SearchResults{
|
||||||
Status: merr.Success(),
|
Status: merr.Success(),
|
||||||
Results: &schemapb.SearchResultData{
|
Results: &schemapb.SearchResultData{
|
||||||
NumQueries: nq,
|
NumQueries: nq,
|
||||||
@ -455,22 +687,54 @@ func rankSearchResultData(ctx context.Context,
|
|||||||
Topks: []int64{},
|
Topks: []int64{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupIdListForSearchResult(searchResult *milvuspb.SearchResults, pkType schemapb.DataType, capacity int64) error {
|
||||||
switch pkType {
|
switch pkType {
|
||||||
case schemapb.DataType_Int64:
|
case schemapb.DataType_Int64:
|
||||||
ret.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
searchResult.GetResults().Ids.IdField = &schemapb.IDs_IntId{
|
||||||
IntId: &schemapb.LongArray{
|
IntId: &schemapb.LongArray{
|
||||||
Data: make([]int64, 0),
|
Data: make([]int64, 0, capacity),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
case schemapb.DataType_VarChar:
|
case schemapb.DataType_VarChar:
|
||||||
ret.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
searchResult.GetResults().Ids.IdField = &schemapb.IDs_StrId{
|
||||||
StrId: &schemapb.StringArray{
|
StrId: &schemapb.StringArray{
|
||||||
Data: make([]string, 0),
|
Data: make([]string, 0, capacity),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return nil, errors.New("unsupported pk type")
|
return errors.New("unsupported pk type")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func rankSearchResultDataByPk(ctx context.Context,
|
||||||
|
nq int64,
|
||||||
|
params *rankParams,
|
||||||
|
pkType schemapb.DataType,
|
||||||
|
searchResults []*milvuspb.SearchResults,
|
||||||
|
) (*milvuspb.SearchResults, error) {
|
||||||
|
tr := timerecord.NewTimeRecorder("rankSearchResultDataByPk")
|
||||||
|
defer func() {
|
||||||
|
tr.CtxElapse(ctx, "done")
|
||||||
|
}()
|
||||||
|
|
||||||
|
offset, limit, roundDecimal := params.offset, params.limit, params.roundDecimal
|
||||||
|
topk := limit + offset
|
||||||
|
log.Ctx(ctx).Debug("rankSearchResultDataByPk",
|
||||||
|
zap.Int("len(searchResults)", len(searchResults)),
|
||||||
|
zap.Int64("nq", nq),
|
||||||
|
zap.Int64("offset", offset),
|
||||||
|
zap.Int64("limit", limit))
|
||||||
|
|
||||||
|
var ret *milvuspb.SearchResults
|
||||||
|
if ret = initSearchResults(nq, limit); len(searchResults) == 0 {
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
||||||
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// []map[id]score
|
// []map[id]score
|
||||||
@ -503,20 +767,10 @@ func rankSearchResultData(ctx context.Context,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
compareKeys := func(keyI, keyJ interface{}) bool {
|
|
||||||
switch keyI.(type) {
|
|
||||||
case int64:
|
|
||||||
return keyI.(int64) < keyJ.(int64)
|
|
||||||
case string:
|
|
||||||
return keyI.(string) < keyJ.(string)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// sort id by score
|
// sort id by score
|
||||||
big := func(i, j int) bool {
|
big := func(i, j int) bool {
|
||||||
if idSet[keys[i]] == idSet[keys[j]] {
|
if idSet[keys[i]] == idSet[keys[j]] {
|
||||||
return compareKeys(keys[i], keys[j])
|
return compareKey(keys[i], keys[j])
|
||||||
}
|
}
|
||||||
return idSet[keys[i]] > idSet[keys[j]]
|
return idSet[keys[i]] > idSet[keys[j]]
|
||||||
}
|
}
|
||||||
|
|||||||
133
internal/proxy/search_reduce_util_test.go
Normal file
133
internal/proxy/search_reduce_util_test.go
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SearchReduceUtilTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
}
|
||||||
|
|
||||||
|
func (struts *SearchReduceUtilTestSuite) TestRankByGroup() {
|
||||||
|
var searchResultData1 *schemapb.SearchResultData
|
||||||
|
var searchResultData2 *schemapb.SearchResultData
|
||||||
|
|
||||||
|
{
|
||||||
|
groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}
|
||||||
|
searchResultData1 = &schemapb.SearchResultData{
|
||||||
|
Ids: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_StrId{
|
||||||
|
StrId: &schemapb.StringArray{
|
||||||
|
Data: []string{"7", "5", "4", "2", "3", "6", "1", "9", "8"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Topks: []int64{9},
|
||||||
|
Scores: []float32{0.6, 0.53, 0.52, 0.43, 0.41, 0.33, 0.30, 0.27, 0.22},
|
||||||
|
GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
groupFieldValue := []string{"www", "aaa", "ccc", "www", "www", "ccc", "aaa", "ccc", "aaa"}
|
||||||
|
searchResultData2 = &schemapb.SearchResultData{
|
||||||
|
Ids: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_StrId{
|
||||||
|
StrId: &schemapb.StringArray{
|
||||||
|
Data: []string{"17", "15", "14", "12", "13", "16", "11", "19", "18"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Topks: []int64{9},
|
||||||
|
Scores: []float32{0.7, 0.43, 0.32, 0.32, 0.31, 0.31, 0.30, 0.30, 0.30},
|
||||||
|
GroupByFieldValue: getFieldData("string", int64(101), schemapb.DataType_VarChar, groupFieldValue, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
searchResults := []*milvuspb.SearchResults{
|
||||||
|
{Results: searchResultData1},
|
||||||
|
{Results: searchResultData2},
|
||||||
|
}
|
||||||
|
|
||||||
|
nq := int64(1)
|
||||||
|
limit := int64(3)
|
||||||
|
offset := int64(0)
|
||||||
|
roundDecimal := int64(1)
|
||||||
|
groupSize := int64(3)
|
||||||
|
groupByFieldId := int64(101)
|
||||||
|
rankParams := &rankParams{limit: limit, offset: offset, roundDecimal: roundDecimal}
|
||||||
|
|
||||||
|
{
|
||||||
|
// test for sum group scorer
|
||||||
|
scorerType := "sum"
|
||||||
|
groupScorer, _ := GetGroupScorer(scorerType)
|
||||||
|
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||||
|
struts.NoError(err)
|
||||||
|
struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||||
|
struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
|
||||||
|
struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// test for max group scorer
|
||||||
|
scorerType := "max"
|
||||||
|
groupScorer, _ := GetGroupScorer(scorerType)
|
||||||
|
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||||
|
struts.NoError(err)
|
||||||
|
struts.Equal([]string{"17", "12", "13", "7", "15", "1", "5", "2", "3"}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||||
|
struts.Equal([]float32{0.7, 0.3, 0.3, 0.6, 0.4, 0.3, 0.5, 0.4, 0.4}, rankedRes.GetResults().GetScores())
|
||||||
|
struts.Equal([]string{"www", "www", "www", "aaa", "aaa", "aaa", "bbb", "bbb", "bbb"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// test for avg group scorer
|
||||||
|
scorerType := "avg"
|
||||||
|
groupScorer, _ := GetGroupScorer(scorerType)
|
||||||
|
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||||
|
struts.NoError(err)
|
||||||
|
struts.Equal([]string{"5", "2", "3", "17", "12", "13", "7", "15", "1"}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||||
|
struts.Equal([]float32{0.5, 0.4, 0.4, 0.7, 0.3, 0.3, 0.6, 0.4, 0.3}, rankedRes.GetResults().GetScores())
|
||||||
|
struts.Equal([]string{"bbb", "bbb", "bbb", "www", "www", "www", "aaa", "aaa", "aaa"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// test for offset for ranking group
|
||||||
|
scorerType := "avg"
|
||||||
|
groupScorer, _ := GetGroupScorer(scorerType)
|
||||||
|
rankParams.offset = 2
|
||||||
|
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||||
|
struts.NoError(err)
|
||||||
|
struts.Equal([]string{"7", "15", "1", "4", "6", "14"}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||||
|
struts.Equal([]float32{0.6, 0.4, 0.3, 0.5, 0.3, 0.3}, rankedRes.GetResults().GetScores())
|
||||||
|
struts.Equal([]string{"aaa", "aaa", "aaa", "ccc", "ccc", "ccc"}, rankedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// test for offset exceeding the count of final groups
|
||||||
|
scorerType := "avg"
|
||||||
|
groupScorer, _ := GetGroupScorer(scorerType)
|
||||||
|
rankParams.offset = 4
|
||||||
|
rankedRes, err := rankSearchResultData(context.Background(), nq, rankParams, schemapb.DataType_VarChar, searchResults, groupByFieldId, groupSize, groupScorer)
|
||||||
|
struts.NoError(err)
|
||||||
|
struts.Equal([]string{}, rankedRes.GetResults().GetIds().GetStrId().Data)
|
||||||
|
struts.Equal([]float32{}, rankedRes.GetResults().GetScores())
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// test for invalid group scorer
|
||||||
|
scorerType := "xxx"
|
||||||
|
groupScorer, err := GetGroupScorer(scorerType)
|
||||||
|
struts.Error(err)
|
||||||
|
struts.Nil(groupScorer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchReduceUtilTestSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(SearchReduceUtilTestSuite))
|
||||||
|
}
|
||||||
@ -310,13 +310,15 @@ func parseRankParams(rankParamsPair []*commonpb.KeyValuePair) (*rankParams, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
|
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
|
||||||
|
searchParams := make([]*commonpb.KeyValuePair, len(req.GetRankParams()))
|
||||||
|
copy(searchParams, req.GetRankParams())
|
||||||
ret := &milvuspb.SearchRequest{
|
ret := &milvuspb.SearchRequest{
|
||||||
Base: req.GetBase(),
|
Base: req.GetBase(),
|
||||||
DbName: req.GetDbName(),
|
DbName: req.GetDbName(),
|
||||||
CollectionName: req.GetCollectionName(),
|
CollectionName: req.GetCollectionName(),
|
||||||
PartitionNames: req.GetPartitionNames(),
|
PartitionNames: req.GetPartitionNames(),
|
||||||
OutputFields: req.GetOutputFields(),
|
OutputFields: req.GetOutputFields(),
|
||||||
SearchParams: req.GetRankParams(),
|
SearchParams: searchParams,
|
||||||
TravelTimestamp: req.GetTravelTimestamp(),
|
TravelTimestamp: req.GetTravelTimestamp(),
|
||||||
GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
|
GuaranteeTimestamp: req.GetGuaranteeTimestamp(),
|
||||||
Nq: 0,
|
Nq: 0,
|
||||||
|
|||||||
@ -42,6 +42,12 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Names of the group scorers that can be requested for hybrid search with
// group-by. They are the string values looked up through GetGroupScorer.
const (
	// SumScorer selects the "sum" group scorer.
	SumScorer string = "sum"
	// MaxScorer selects the "max" group scorer.
	MaxScorer string = "max"
	// AvgScorer selects the "avg" group scorer.
	AvgScorer string = "avg"
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
IgnoreGrowingKey = "ignore_growing"
|
IgnoreGrowingKey = "ignore_growing"
|
||||||
ReduceStopForBestKey = "reduce_stop_for_best"
|
ReduceStopForBestKey = "reduce_stop_for_best"
|
||||||
@ -49,6 +55,7 @@ const (
|
|||||||
GroupByFieldKey = "group_by_field"
|
GroupByFieldKey = "group_by_field"
|
||||||
GroupSizeKey = "group_size"
|
GroupSizeKey = "group_size"
|
||||||
GroupStrictSize = "group_strict_size"
|
GroupStrictSize = "group_strict_size"
|
||||||
|
RankGroupScorer = "rank_group_scorer"
|
||||||
AnnsFieldKey = "anns_field"
|
AnnsFieldKey = "anns_field"
|
||||||
TopKKey = "topk"
|
TopKKey = "topk"
|
||||||
NQKey = "nq"
|
NQKey = "nq"
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
"github.com/milvus-io/milvus/internal/util/exprutil"
|
"github.com/milvus-io/milvus/internal/util/exprutil"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/metrics"
|
"github.com/milvus-io/milvus/pkg/metrics"
|
||||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||||
@ -76,8 +77,9 @@ type searchTask struct {
|
|||||||
queryInfos []*planpb.QueryInfo
|
queryInfos []*planpb.QueryInfo
|
||||||
relatedDataSize int64
|
relatedDataSize int64
|
||||||
|
|
||||||
reScorers []reScorer
|
reScorers []reScorer
|
||||||
rankParams *rankParams
|
rankParams *rankParams
|
||||||
|
groupScorer func(group *Group) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *searchTask) CanSkipAllocTimestamp() bool {
|
func (t *searchTask) CanSkipAllocTimestamp() bool {
|
||||||
@ -339,10 +341,9 @@ func setQueryInfoIfMvEnable(queryInfo *planpb.QueryInfo, t *searchTask, plan *pl
|
|||||||
func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
||||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init advanced search request")
|
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init advanced search request")
|
||||||
defer sp.End()
|
defer sp.End()
|
||||||
|
|
||||||
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
|
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
|
||||||
|
|
||||||
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))
|
||||||
|
|
||||||
// fetch search_growing from search param
|
// fetch search_growing from search param
|
||||||
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
|
t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
|
||||||
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
|
t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
|
||||||
@ -351,9 +352,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if queryInfo.GetGroupByFieldId() != -1 {
|
|
||||||
return errors.New("not support search_group_by operation in the hybrid search")
|
|
||||||
}
|
|
||||||
internalSubReq := &internalpb.SubSearchRequest{
|
internalSubReq := &internalpb.SubSearchRequest{
|
||||||
Dsl: subReq.GetDsl(),
|
Dsl: subReq.GetDsl(),
|
||||||
PlaceholderGroup: subReq.GetPlaceholderGroup(),
|
PlaceholderGroup: subReq.GetPlaceholderGroup(),
|
||||||
@ -364,6 +363,8 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||||||
Topk: queryInfo.GetTopk(),
|
Topk: queryInfo.GetTopk(),
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
MetricType: queryInfo.GetMetricType(),
|
MetricType: queryInfo.GetMetricType(),
|
||||||
|
GroupByFieldId: queryInfo.GetGroupByFieldId(),
|
||||||
|
GroupSize: queryInfo.GetGroupSize(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// set PartitionIDs for sub search
|
// set PartitionIDs for sub search
|
||||||
@ -403,6 +404,11 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||||
}
|
}
|
||||||
|
if len(t.queryInfos) > 0 {
|
||||||
|
t.SearchRequest.GroupByFieldId = t.queryInfos[0].GetGroupByFieldId()
|
||||||
|
t.SearchRequest.GroupSize = t.queryInfos[0].GetGroupSize()
|
||||||
|
}
|
||||||
|
|
||||||
// used for requery
|
// used for requery
|
||||||
if t.partitionKeyMode {
|
if t.partitionKeyMode {
|
||||||
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
|
t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
|
||||||
@ -413,6 +419,18 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
|
|||||||
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
|
log.Info("generate reScorer failed", zap.Any("params", t.request.GetSearchParams()), zap.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set up groupScorer for hybridsearch+groupBy
|
||||||
|
groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, t.request.GetSearchParams())
|
||||||
|
if err != nil {
|
||||||
|
groupScorerStr = MaxScorer
|
||||||
|
}
|
||||||
|
groupScorer, err := GetGroupScorer(groupScorerStr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
t.groupScorer = groupScorer
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -461,7 +479,8 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
|||||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||||
t.queryInfos = append(t.queryInfos, queryInfo)
|
t.queryInfos = append(t.queryInfos, queryInfo)
|
||||||
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
|
||||||
t.SearchRequest.ExtraSearchParam = &internalpb.ExtraSearchParam{GroupByFieldId: queryInfo.GroupByFieldId, GroupSize: queryInfo.GroupSize}
|
t.SearchRequest.GroupByFieldId = queryInfo.GroupByFieldId
|
||||||
|
t.SearchRequest.GroupSize = queryInfo.GroupSize
|
||||||
log.Debug("proxy init search request",
|
log.Debug("proxy init search request",
|
||||||
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
|
||||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||||
@ -554,7 +573,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo) (*milvuspb.SearchResults, error) {
|
func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*internalpb.SearchResults, nq, topK int64, offset int64, queryInfo *planpb.QueryInfo, isAdvance bool) (*milvuspb.SearchResults, error) {
|
||||||
metricType := ""
|
metricType := ""
|
||||||
if len(toReduceResults) >= 1 {
|
if len(toReduceResults) >= 1 {
|
||||||
metricType = toReduceResults[0].GetMetricType()
|
metricType = toReduceResults[0].GetMetricType()
|
||||||
@ -585,8 +604,8 @@ func (t *searchTask) reduceResults(ctx context.Context, toReduceResults []*inter
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var result *milvuspb.SearchResults
|
var result *milvuspb.SearchResults
|
||||||
result, err = reduceSearchResult(ctx, NewReduceSearchResultInfo(validSearchResults, nq, topK,
|
result, err = reduceSearchResult(ctx, validSearchResults, reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metricType).WithPkType(primaryFieldSchema.GetDataType()).
|
||||||
metricType, primaryFieldSchema.DataType, offset, queryInfo))
|
WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()).WithAdvance(isAdvance))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("failed to reduce search results", zap.Error(err))
|
log.Warn("failed to reduce search results", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -647,7 +666,6 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||||||
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
|
multipleInternalResults[reqIndex] = append(multipleInternalResults[reqIndex], internalResults)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
multipleMilvusResults := make([]*milvuspb.SearchResults, len(t.SearchRequest.GetSubReqs()))
|
||||||
for index, internalResults := range multipleInternalResults {
|
for index, internalResults := range multipleInternalResults {
|
||||||
subReq := t.SearchRequest.GetSubReqs()[index]
|
subReq := t.SearchRequest.GetSubReqs()[index]
|
||||||
@ -656,7 +674,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||||||
if len(internalResults) >= 1 {
|
if len(internalResults) >= 1 {
|
||||||
metricType = internalResults[0].GetMetricType()
|
metricType = internalResults[0].GetMetricType()
|
||||||
}
|
}
|
||||||
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index])
|
result, err := t.reduceResults(t.ctx, internalResults, subReq.GetNq(), subReq.GetTopk(), subReq.GetOffset(), t.queryInfos[index], true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -667,13 +685,16 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||||||
t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
|
t.result, err = rankSearchResultData(ctx, t.SearchRequest.GetNq(),
|
||||||
t.rankParams,
|
t.rankParams,
|
||||||
primaryFieldSchema.GetDataType(),
|
primaryFieldSchema.GetDataType(),
|
||||||
multipleMilvusResults)
|
multipleMilvusResults,
|
||||||
|
t.SearchRequest.GetGroupByFieldId(),
|
||||||
|
t.SearchRequest.GetGroupSize(),
|
||||||
|
t.groupScorer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("rank search result failed", zap.Error(err))
|
log.Warn("rank search result failed", zap.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0])
|
t.result, err = t.reduceResults(t.ctx, toReduceResults, t.SearchRequest.GetNq(), t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), t.queryInfos[0], false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -914,7 +935,7 @@ func decodeSearchResults(ctx context.Context, searchResults []*internalpb.Search
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64) error {
|
func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64, pkHitNum int) error {
|
||||||
if data.NumQueries != nq {
|
if data.NumQueries != nq {
|
||||||
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
return fmt.Errorf("search result's nq(%d) mis-match with %d", data.NumQueries, nq)
|
||||||
}
|
}
|
||||||
@ -922,7 +943,6 @@ func checkSearchResultData(data *schemapb.SearchResultData, nq int64, topk int64
|
|||||||
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
|
return fmt.Errorf("search result's topk(%d) mis-match with %d", data.TopK, topk)
|
||||||
}
|
}
|
||||||
|
|
||||||
pkHitNum := typeutil.GetSizeOfIDs(data.GetIds())
|
|
||||||
if len(data.Scores) != pkHitNum {
|
if len(data.Scores) != pkHitNum {
|
||||||
return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
|
return fmt.Errorf("search result's score length invalid, score length=%d, expectedLength=%d",
|
||||||
len(data.Scores), pkHitNum)
|
len(data.Scores), pkHitNum)
|
||||||
|
|||||||
@ -39,6 +39,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
@ -1247,7 +1248,8 @@ func Test_checkSearchResultData(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
err := checkSearchResultData(test.args.data, test.args.nq, test.args.topk)
|
pkLength := typeutil.GetSizeOfIDs(test.args.data.GetIds())
|
||||||
|
err := checkSearchResultData(test.args.data, test.args.nq, test.args.topk, pkLength)
|
||||||
|
|
||||||
if test.wantErr {
|
if test.wantErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@ -1522,8 +1524,9 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
reduced, err := reduceSearchResult(context.TODO(),
|
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||||
NewReduceSearchResultInfo(results, nq, topk, metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
|
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).
|
||||||
|
WithOffset(test.offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||||
assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
|
assert.Equal(t, []int64{test.limit, test.limit}, reduced.GetResults().GetTopks())
|
||||||
@ -1574,8 +1577,9 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, test := range lessThanLimitTests {
|
for _, test := range lessThanLimitTests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topk,
|
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||||
metric.L2, schemapb.DataType_Int64, test.offset, queryInfo))
|
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithOffset(test.offset).
|
||||||
|
WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
assert.Equal(t, test.outData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||||
assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
|
assert.Equal(t, []int64{test.outLimit, test.outLimit}, reduced.GetResults().GetTopks())
|
||||||
@ -1603,9 +1607,8 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||||||
GroupByFieldId: -1,
|
GroupByFieldId: -1,
|
||||||
}
|
}
|
||||||
|
|
||||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(
|
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||||
results, nq, topk, metric.L2, schemapb.DataType_Int64, 0, queryInfo))
|
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
|
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetIntId().GetData())
|
||||||
assert.Equal(t, []int64{5, 5}, reduced.GetResults().GetTopks())
|
assert.Equal(t, []int64{5, 5}, reduced.GetResults().GetTopks())
|
||||||
@ -1633,9 +1636,8 @@ func TestTaskSearch_reduceSearchResultData(t *testing.T) {
|
|||||||
queryInfo := &planpb.QueryInfo{
|
queryInfo := &planpb.QueryInfo{
|
||||||
GroupByFieldId: -1,
|
GroupByFieldId: -1,
|
||||||
}
|
}
|
||||||
|
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results,
|
reduce.NewReduceSearchResultInfo(nq, topk).WithMetricType(metric.L2).WithPkType(schemapb.DataType_VarChar).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||||
nq, topk, metric.L2, schemapb.DataType_VarChar, 0, queryInfo))
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
|
assert.Equal(t, resultData, reduced.GetResults().GetIds().GetStrId().GetData())
|
||||||
@ -1708,8 +1710,8 @@ func TestTaskSearch_reduceGroupBySearchResultData(t *testing.T) {
|
|||||||
GroupByFieldId: 1,
|
GroupByFieldId: 1,
|
||||||
GroupSize: 1,
|
GroupSize: 1,
|
||||||
}
|
}
|
||||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
|
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||||
schemapb.DataType_Int64, 0, queryInfo))
|
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||||
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
||||||
resultScores := reduced.GetResults().GetScores()
|
resultScores := reduced.GetResults().GetScores()
|
||||||
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
||||||
@ -1768,8 +1770,8 @@ func TestTaskSearch_reduceGroupBySearchResultDataWithOffset(t *testing.T) {
|
|||||||
GroupByFieldId: 1,
|
GroupByFieldId: 1,
|
||||||
GroupSize: 1,
|
GroupSize: 1,
|
||||||
}
|
}
|
||||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, limit+offset, metric.L2,
|
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||||
schemapb.DataType_Int64, offset, queryInfo))
|
reduce.NewReduceSearchResultInfo(nq, limit+offset).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithOffset(offset).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||||
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
||||||
resultScores := reduced.GetResults().GetScores()
|
resultScores := reduced.GetResults().GetScores()
|
||||||
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
||||||
@ -1842,8 +1844,9 @@ func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) {
|
|||||||
GroupByFieldId: 1,
|
GroupByFieldId: 1,
|
||||||
GroupSize: 2,
|
GroupSize: 2,
|
||||||
}
|
}
|
||||||
reduced, err := reduceSearchResult(context.TODO(), NewReduceSearchResultInfo(results, nq, topK, metric.L2,
|
reduced, err := reduceSearchResult(context.TODO(), results,
|
||||||
schemapb.DataType_Int64, 0, queryInfo))
|
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(queryInfo.GetGroupByFieldId()).WithGroupSize(queryInfo.GetGroupSize()))
|
||||||
|
|
||||||
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
resultIDs := reduced.GetResults().GetIds().GetIntId().Data
|
||||||
resultScores := reduced.GetResults().GetScores()
|
resultScores := reduced.GetResults().GetScores()
|
||||||
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
resultGroupByValues := reduced.GetResults().GetGroupByFieldValue().GetScalars().GetLongData().GetData()
|
||||||
@ -1855,6 +1858,188 @@ func TestTaskSearch_reduceGroupBySearchWithGroupSizeMoreThanOne(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTaskSearch_reduceAdvanceSearchGroupBy(t *testing.T) {
|
||||||
|
groupByField := int64(101)
|
||||||
|
nq := int64(1)
|
||||||
|
subSearchResultData := make([]*schemapb.SearchResultData, 0)
|
||||||
|
topK := int64(3)
|
||||||
|
{
|
||||||
|
scores := []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}
|
||||||
|
ids := []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}
|
||||||
|
tops := []int64{9}
|
||||||
|
groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}
|
||||||
|
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||||
|
result1 := &schemapb.SearchResultData{
|
||||||
|
Scores: scores,
|
||||||
|
TopK: topK,
|
||||||
|
Ids: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NumQueries: nq,
|
||||||
|
Topks: tops,
|
||||||
|
GroupByFieldValue: groupByVals,
|
||||||
|
}
|
||||||
|
subSearchResultData = append(subSearchResultData, result1)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
scores := []float32{0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51, 0.48}
|
||||||
|
ids := []int64{17, 15, 16, 21, 32, 24, 41, 33, 27}
|
||||||
|
tops := []int64{9}
|
||||||
|
groupFieldValue := []string{"xxx", "bbb", "ddd", "bbb", "bbb", "ddd", "xxx", "ddd", "xxx"}
|
||||||
|
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||||
|
result2 := &schemapb.SearchResultData{
|
||||||
|
TopK: topK,
|
||||||
|
Scores: scores,
|
||||||
|
Ids: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Topks: tops,
|
||||||
|
NumQueries: nq,
|
||||||
|
GroupByFieldValue: groupByVals,
|
||||||
|
}
|
||||||
|
subSearchResultData = append(subSearchResultData, result2)
|
||||||
|
}
|
||||||
|
groupSize := int64(3)
|
||||||
|
|
||||||
|
reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData,
|
||||||
|
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.IP).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// reduce_advance_groupby will only merge results from different delegator without reducing any result
|
||||||
|
assert.Equal(t, 18, len(reducedRes.GetResults().Ids.GetIntId().Data))
|
||||||
|
assert.Equal(t, 18, len(reducedRes.GetResults().GetScores()))
|
||||||
|
assert.Equal(t, 18, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data))
|
||||||
|
assert.Equal(t, topK, reducedRes.GetResults().GetTopK())
|
||||||
|
assert.Equal(t, []int64{18}, reducedRes.GetResults().GetTopks())
|
||||||
|
|
||||||
|
assert.Equal(t, []int64{7, 5, 6, 11, 22, 14, 31, 23, 37, 17, 15, 16, 21, 32, 24, 41, 33, 27}, reducedRes.GetResults().Ids.GetIntId().Data)
|
||||||
|
assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43, 0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51, 0.48}, reducedRes.GetResults().GetScores())
|
||||||
|
assert.Equal(t, []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa", "xxx", "bbb", "ddd", "bbb", "bbb", "ddd", "xxx", "ddd", "xxx"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskSearch_reduceAdvanceSearchGroupByShortCut(t *testing.T) {
|
||||||
|
groupByField := int64(101)
|
||||||
|
nq := int64(1)
|
||||||
|
subSearchResultData := make([]*schemapb.SearchResultData, 0)
|
||||||
|
topK := int64(3)
|
||||||
|
{
|
||||||
|
scores := []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}
|
||||||
|
ids := []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}
|
||||||
|
tops := []int64{9}
|
||||||
|
groupFieldValue := []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}
|
||||||
|
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||||
|
result1 := &schemapb.SearchResultData{
|
||||||
|
Scores: scores,
|
||||||
|
TopK: topK,
|
||||||
|
Ids: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NumQueries: nq,
|
||||||
|
Topks: tops,
|
||||||
|
GroupByFieldValue: groupByVals,
|
||||||
|
}
|
||||||
|
subSearchResultData = append(subSearchResultData, result1)
|
||||||
|
}
|
||||||
|
groupSize := int64(3)
|
||||||
|
|
||||||
|
reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData,
|
||||||
|
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.L2).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true))
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// reduce_advance_groupby will only merge results from different delegator without reducing any result
|
||||||
|
assert.Equal(t, 9, len(reducedRes.GetResults().Ids.GetIntId().Data))
|
||||||
|
assert.Equal(t, 9, len(reducedRes.GetResults().GetScores()))
|
||||||
|
assert.Equal(t, 9, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data))
|
||||||
|
assert.Equal(t, topK, reducedRes.GetResults().GetTopK())
|
||||||
|
assert.Equal(t, []int64{9}, reducedRes.GetResults().GetTopks())
|
||||||
|
|
||||||
|
assert.Equal(t, []int64{7, 5, 6, 11, 22, 14, 31, 23, 37}, reducedRes.GetResults().Ids.GetIntId().Data)
|
||||||
|
assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.52, 0.51, 0.5, 0.45, 0.43}, reducedRes.GetResults().GetScores())
|
||||||
|
assert.Equal(t, []string{"aaa", "bbb", "ccc", "bbb", "bbb", "ccc", "aaa", "ccc", "aaa"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTaskSearch_reduceAdvanceSearchGroupByMultipleNq(t *testing.T) {
|
||||||
|
groupByField := int64(101)
|
||||||
|
nq := int64(2)
|
||||||
|
subSearchResultData := make([]*schemapb.SearchResultData, 0)
|
||||||
|
topK := int64(2)
|
||||||
|
groupSize := int64(2)
|
||||||
|
{
|
||||||
|
scores := []float32{0.9, 0.7, 0.65, 0.55, 0.51, 0.5, 0.45, 0.43}
|
||||||
|
ids := []int64{7, 5, 6, 11, 14, 31, 23, 37}
|
||||||
|
tops := []int64{4, 4}
|
||||||
|
groupFieldValue := []string{"ccc", "bbb", "ccc", "bbb", "aaa", "xxx", "xxx", "aaa"}
|
||||||
|
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||||
|
result1 := &schemapb.SearchResultData{
|
||||||
|
Scores: scores,
|
||||||
|
TopK: topK,
|
||||||
|
Ids: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NumQueries: nq,
|
||||||
|
Topks: tops,
|
||||||
|
GroupByFieldValue: groupByVals,
|
||||||
|
}
|
||||||
|
subSearchResultData = append(subSearchResultData, result1)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
scores := []float32{0.83, 0.72, 0.72, 0.65, 0.63, 0.55, 0.52, 0.51}
|
||||||
|
ids := []int64{17, 15, 16, 21, 32, 24, 41, 33}
|
||||||
|
tops := []int64{4, 4}
|
||||||
|
groupFieldValue := []string{"ddd", "bbb", "ddd", "bbb", "rrr", "sss", "rrr", "sss"}
|
||||||
|
groupByVals := getFieldData("string", groupByField, schemapb.DataType_VarChar, groupFieldValue, 1)
|
||||||
|
result2 := &schemapb.SearchResultData{
|
||||||
|
TopK: topK,
|
||||||
|
Scores: scores,
|
||||||
|
Ids: &schemapb.IDs{
|
||||||
|
IdField: &schemapb.IDs_IntId{
|
||||||
|
IntId: &schemapb.LongArray{
|
||||||
|
Data: ids,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Topks: tops,
|
||||||
|
NumQueries: nq,
|
||||||
|
GroupByFieldValue: groupByVals,
|
||||||
|
}
|
||||||
|
subSearchResultData = append(subSearchResultData, result2)
|
||||||
|
}
|
||||||
|
|
||||||
|
reducedRes, err := reduceSearchResult(context.Background(), subSearchResultData,
|
||||||
|
reduce.NewReduceSearchResultInfo(nq, topK).WithMetricType(metric.IP).WithPkType(schemapb.DataType_Int64).WithGroupByField(groupByField).WithGroupSize(groupSize).WithAdvance(true))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// reduce_advance_groupby will only merge results from different delegator without reducing any result
|
||||||
|
assert.Equal(t, 16, len(reducedRes.GetResults().Ids.GetIntId().Data))
|
||||||
|
assert.Equal(t, 16, len(reducedRes.GetResults().GetScores()))
|
||||||
|
assert.Equal(t, 16, len(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data))
|
||||||
|
|
||||||
|
assert.Equal(t, topK, reducedRes.GetResults().GetTopK())
|
||||||
|
assert.Equal(t, []int64{8, 8}, reducedRes.GetResults().GetTopks())
|
||||||
|
|
||||||
|
assert.Equal(t, []int64{7, 5, 6, 11, 17, 15, 16, 21, 14, 31, 23, 37, 32, 24, 41, 33}, reducedRes.GetResults().Ids.GetIntId().Data)
|
||||||
|
assert.Equal(t, []float32{0.9, 0.7, 0.65, 0.55, 0.83, 0.72, 0.72, 0.65, 0.51, 0.5, 0.45, 0.43, 0.63, 0.55, 0.52, 0.51}, reducedRes.GetResults().GetScores())
|
||||||
|
assert.Equal(t, []string{"ccc", "bbb", "ccc", "bbb", "ddd", "bbb", "ddd", "bbb", "aaa", "xxx", "xxx", "aaa", "rrr", "sss", "rrr", "sss"}, reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||||
|
|
||||||
|
fmt.Println(reducedRes.GetResults().Ids.GetIntId().Data)
|
||||||
|
fmt.Println(reducedRes.GetResults().GetScores())
|
||||||
|
fmt.Println(reducedRes.GetResults().GetGroupByFieldValue().GetScalars().GetStringData().Data)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSearchTask_ErrExecute(t *testing.T) {
|
func TestSearchTask_ErrExecute(t *testing.T) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
|
|||||||
@ -43,6 +43,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||||
"github.com/milvus-io/milvus/internal/storage"
|
"github.com/milvus-io/milvus/internal/storage"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
@ -332,6 +333,8 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
|||||||
IgnoreGrowing: req.GetReq().GetIgnoreGrowing(),
|
IgnoreGrowing: req.GetReq().GetIgnoreGrowing(),
|
||||||
Username: req.GetReq().GetUsername(),
|
Username: req.GetReq().GetUsername(),
|
||||||
IsAdvanced: false,
|
IsAdvanced: false,
|
||||||
|
GroupByFieldId: subReq.GetGroupByFieldId(),
|
||||||
|
GroupSize: subReq.GetGroupSize(),
|
||||||
}
|
}
|
||||||
future := conc.Go(func() (*internalpb.SearchResults, error) {
|
future := conc.Go(func() (*internalpb.SearchResults, error) {
|
||||||
searchReq := &querypb.SearchRequest{
|
searchReq := &querypb.SearchRequest{
|
||||||
@ -350,14 +353,12 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return segments.ReduceSearchResults(ctx,
|
return segments.ReduceSearchOnQueryNode(ctx,
|
||||||
results,
|
results,
|
||||||
segments.NewReduceInfo(searchReq.Req.GetNq(),
|
reduce.NewReduceSearchResultInfo(searchReq.GetReq().GetNq(),
|
||||||
searchReq.Req.GetTopk(),
|
searchReq.GetReq().GetTopk()).WithMetricType(searchReq.GetReq().GetMetricType()).
|
||||||
searchReq.Req.GetExtraSearchParam().GetGroupByFieldId(),
|
WithGroupByField(searchReq.GetReq().GetGroupByFieldId()).
|
||||||
searchReq.Req.GetExtraSearchParam().GetGroupSize(),
|
WithGroupSize(searchReq.GetReq().GetGroupSize()))
|
||||||
searchReq.Req.GetMetricType()),
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
futures[index] = future
|
futures[index] = future
|
||||||
}
|
}
|
||||||
@ -376,12 +377,7 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
|||||||
}
|
}
|
||||||
results[i] = result
|
results[i] = result
|
||||||
}
|
}
|
||||||
var ret *internalpb.SearchResults
|
return results, nil
|
||||||
ret, err = segments.MergeToAdvancedResults(ctx, results)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return []*internalpb.SearchResults{ret}, nil
|
|
||||||
}
|
}
|
||||||
return sd.search(ctx, req, sealed, growing)
|
return sd.search(ctx, req, sealed, growing)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||||
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
|
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/metrics"
|
"github.com/milvus-io/milvus/pkg/metrics"
|
||||||
@ -384,16 +385,10 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
|
|||||||
req.GetSegmentIDs(),
|
req.GetSegmentIDs(),
|
||||||
))
|
))
|
||||||
|
|
||||||
var resp *internalpb.SearchResults
|
resp, err := segments.ReduceSearchOnQueryNode(ctx, results,
|
||||||
if req.GetReq().GetIsAdvanced() {
|
reduce.NewReduceSearchResultInfo(req.GetReq().GetNq(),
|
||||||
resp, err = segments.ReduceAdvancedSearchResults(ctx, results, req.Req.GetNq())
|
req.GetReq().GetTopk()).WithMetricType(req.GetReq().GetMetricType()).WithGroupByField(req.GetReq().GetGroupByFieldId()).
|
||||||
} else {
|
// The group size must come from the request's GroupSize; the group-by field id is already handed to WithGroupByField.
WithGroupSize(req.GetReq().GetGroupSize()).WithAdvance(req.GetReq().GetIsAdvanced()))
|
||||||
resp, err = segments.ReduceSearchResults(ctx, results, segments.NewReduceInfo(req.Req.GetNq(),
|
|
||||||
req.Req.GetTopk(),
|
|
||||||
req.Req.GetExtraSearchParam().GetGroupByFieldId(),
|
|
||||||
req.Req.GetExtraSearchParam().GetGroupSize(),
|
|
||||||
req.Req.GetMetricType()))
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -29,6 +29,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||||
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
@ -42,7 +43,14 @@ var _ typeutil.ResultWithID = &internalpb.RetrieveResults{}
|
|||||||
|
|
||||||
var _ typeutil.ResultWithID = &segcorepb.RetrieveResults{}
|
var _ typeutil.ResultWithID = &segcorepb.RetrieveResults{}
|
||||||
|
|
||||||
func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *ReduceInfo) (*internalpb.SearchResults, error) {
|
func ReduceSearchOnQueryNode(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) {
|
||||||
|
if info.GetIsAdvance() {
|
||||||
|
return ReduceAdvancedSearchResults(ctx, results)
|
||||||
|
}
|
||||||
|
return ReduceSearchResults(ctx, results, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResults, info *reduce.ResultInfo) (*internalpb.SearchResults, error) {
|
||||||
results = lo.Filter(results, func(result *internalpb.SearchResults, _ int) bool {
|
results = lo.Filter(results, func(result *internalpb.SearchResults, _ int) bool {
|
||||||
return result != nil && result.GetSlicedBlob() != nil
|
return result != nil && result.GetSlicedBlob() != nil
|
||||||
})
|
})
|
||||||
@ -60,8 +68,8 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
|
|||||||
channelsMvcc[ch] = ts
|
channelsMvcc[ch] = ts
|
||||||
}
|
}
|
||||||
// shouldn't let new SearchResults.MetricType to be empty, though the req.MetricType is empty
|
// shouldn't let new SearchResults.MetricType to be empty, though the req.MetricType is empty
|
||||||
if info.metricType == "" {
|
if info.GetMetricType() == "" {
|
||||||
info.metricType = r.MetricType
|
info.SetMetricType(r.MetricType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log := log.Ctx(ctx)
|
log := log.Ctx(ctx)
|
||||||
@ -86,7 +94,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
|
|||||||
log.Warn("shard leader reduce errors", zap.Error(err))
|
log.Warn("shard leader reduce errors", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.nq, info.topK, info.metricType)
|
searchResults, err := EncodeSearchResultData(ctx, reducedResultData, info.GetNq(), info.GetTopK(), info.GetMetricType())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("shard leader encode search result errors", zap.Error(err))
|
log.Warn("shard leader encode search result errors", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -115,7 +123,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
|
|||||||
return searchResults, nil
|
return searchResults, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults, nq int64) (*internalpb.SearchResults, error) {
|
func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) {
|
||||||
_, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceAdvancedSearchResults")
|
_, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceAdvancedSearchResults")
|
||||||
defer sp.End()
|
defer sp.End()
|
||||||
|
|
||||||
@ -129,53 +137,14 @@ func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.Sear
|
|||||||
IsAdvanced: true,
|
IsAdvanced: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, result := range results {
|
|
||||||
relatedDataSize += result.GetCostAggregation().GetTotalRelatedDataSize()
|
|
||||||
for ch, ts := range result.GetChannelsMvcc() {
|
|
||||||
channelsMvcc[ch] = ts
|
|
||||||
}
|
|
||||||
if !result.GetIsAdvanced() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// we just append here, no need to split subResult and reduce
|
|
||||||
// defer this reduce to proxy
|
|
||||||
searchResults.SubResults = append(searchResults.SubResults, result.GetSubResults()...)
|
|
||||||
searchResults.NumQueries = result.GetNumQueries()
|
|
||||||
}
|
|
||||||
searchResults.ChannelsMvcc = channelsMvcc
|
|
||||||
requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) {
|
|
||||||
if paramtable.Get().QueryNodeCfg.EnableWorkerSQCostMetrics.GetAsBool() {
|
|
||||||
return result.GetCostAggregation(), true
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.GetBase().GetSourceID() == paramtable.GetNodeID() {
|
|
||||||
return result.GetCostAggregation(), true
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, false
|
|
||||||
})
|
|
||||||
searchResults.CostAggregation = mergeRequestCost(requestCosts)
|
|
||||||
if searchResults.CostAggregation == nil {
|
|
||||||
searchResults.CostAggregation = &internalpb.CostAggregation{}
|
|
||||||
}
|
|
||||||
searchResults.CostAggregation.TotalRelatedDataSize = relatedDataSize
|
|
||||||
return searchResults, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchResults) (*internalpb.SearchResults, error) {
|
|
||||||
searchResults := &internalpb.SearchResults{
|
|
||||||
IsAdvanced: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
channelsMvcc := make(map[string]uint64)
|
|
||||||
relatedDataSize := int64(0)
|
|
||||||
for index, result := range results {
|
for index, result := range results {
|
||||||
relatedDataSize += result.GetCostAggregation().GetTotalRelatedDataSize()
|
relatedDataSize += result.GetCostAggregation().GetTotalRelatedDataSize()
|
||||||
for ch, ts := range result.GetChannelsMvcc() {
|
for ch, ts := range result.GetChannelsMvcc() {
|
||||||
channelsMvcc[ch] = ts
|
channelsMvcc[ch] = ts
|
||||||
}
|
}
|
||||||
|
searchResults.NumQueries = result.GetNumQueries()
|
||||||
// we just append here, no need to split subResult and reduce
|
// we just append here, no need to split subResult and reduce
|
||||||
// defer this reduce to proxy
|
// defer this reduction to proxy
|
||||||
subResult := &internalpb.SubSearchResults{
|
subResult := &internalpb.SubSearchResults{
|
||||||
MetricType: result.GetMetricType(),
|
MetricType: result.GetMetricType(),
|
||||||
NumQueries: result.GetNumQueries(),
|
NumQueries: result.GetNumQueries(),
|
||||||
@ -185,7 +154,6 @@ func MergeToAdvancedResults(ctx context.Context, results []*internalpb.SearchRes
|
|||||||
SlicedOffset: result.GetSlicedOffset(),
|
SlicedOffset: result.GetSlicedOffset(),
|
||||||
ReqIndex: int64(index),
|
ReqIndex: int64(index),
|
||||||
}
|
}
|
||||||
searchResults.NumQueries = result.GetNumQueries()
|
|
||||||
searchResults.SubResults = append(searchResults.SubResults, subResult)
|
searchResults.SubResults = append(searchResults.SubResults, subResult)
|
||||||
}
|
}
|
||||||
searchResults.ChannelsMvcc = channelsMvcc
|
searchResults.ChannelsMvcc = channelsMvcc
|
||||||
|
|||||||
@ -29,7 +29,9 @@ import (
|
|||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
"github.com/milvus-io/milvus/internal/proto/segcorepb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
@ -886,6 +888,42 @@ func (suite *ResultSuite) TestSort() {
|
|||||||
}, result.FieldsData[9].GetScalars().GetArrayData().GetData())
|
}, result.FieldsData[9].GetScalars().GetArrayData().GetData())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (suite *ResultSuite) TestReduceSearchOnQueryNode() {
|
||||||
|
results := make([]*internalpb.SearchResults, 0)
|
||||||
|
metricType := metric.IP
|
||||||
|
nq := int64(1)
|
||||||
|
topK := int64(1)
|
||||||
|
mockBlob := []byte{65, 66, 67, 65, 66, 67}
|
||||||
|
{
|
||||||
|
subRes1 := &internalpb.SearchResults{
|
||||||
|
MetricType: metricType,
|
||||||
|
NumQueries: nq,
|
||||||
|
TopK: topK,
|
||||||
|
SlicedBlob: mockBlob,
|
||||||
|
}
|
||||||
|
results = append(results, subRes1)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
subRes2 := &internalpb.SearchResults{
|
||||||
|
MetricType: metricType,
|
||||||
|
NumQueries: nq,
|
||||||
|
TopK: topK,
|
||||||
|
SlicedBlob: mockBlob,
|
||||||
|
}
|
||||||
|
results = append(results, subRes2)
|
||||||
|
}
|
||||||
|
reducedRes, err := ReduceSearchOnQueryNode(context.Background(), results, reduce.NewReduceSearchResultInfo(nq, topK).
|
||||||
|
WithMetricType(metricType).WithPkType(schemapb.DataType_Int8).WithAdvance(true))
|
||||||
|
suite.NoError(err)
|
||||||
|
suite.Equal(2, len(reducedRes.GetSubResults()))
|
||||||
|
|
||||||
|
subRes1 := reducedRes.GetSubResults()[0]
|
||||||
|
suite.Equal(metricType, subRes1.GetMetricType())
|
||||||
|
suite.Equal(nq, subRes1.GetNumQueries())
|
||||||
|
suite.Equal(topK, subRes1.GetTopK())
|
||||||
|
suite.Equal(mockBlob, subRes1.GetSlicedBlob())
|
||||||
|
}
|
||||||
|
|
||||||
func TestResult_MergeRequestCost(t *testing.T) {
|
func TestResult_MergeRequestCost(t *testing.T) {
|
||||||
costs := []*internalpb.CostAggregation{
|
costs := []*internalpb.CostAggregation{
|
||||||
{
|
{
|
||||||
|
|||||||
@ -8,39 +8,28 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ReduceInfo struct {
|
|
||||||
nq int64
|
|
||||||
topK int64
|
|
||||||
groupByFieldID int64
|
|
||||||
groupSize int64
|
|
||||||
metricType string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewReduceInfo(nq int64, topK int64, groupByFieldID int64, groupSize int64, metric string) *ReduceInfo {
|
|
||||||
return &ReduceInfo{nq, topK, groupByFieldID, groupSize, metric}
|
|
||||||
}
|
|
||||||
|
|
||||||
type SearchReduce interface {
|
type SearchReduce interface {
|
||||||
ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error)
|
ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type SearchCommonReduce struct{}
|
type SearchCommonReduce struct{}
|
||||||
|
|
||||||
func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) {
|
func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) {
|
||||||
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
||||||
defer sp.End()
|
defer sp.End()
|
||||||
log := log.Ctx(ctx)
|
log := log.Ctx(ctx)
|
||||||
|
|
||||||
if len(searchResultData) == 0 {
|
if len(searchResultData) == 0 {
|
||||||
return &schemapb.SearchResultData{
|
return &schemapb.SearchResultData{
|
||||||
NumQueries: info.nq,
|
NumQueries: info.GetNq(),
|
||||||
TopK: info.topK,
|
TopK: info.GetTopK(),
|
||||||
FieldsData: make([]*schemapb.FieldData, 0),
|
FieldsData: make([]*schemapb.FieldData, 0),
|
||||||
Scores: make([]float32, 0),
|
Scores: make([]float32, 0),
|
||||||
Ids: &schemapb.IDs{},
|
Ids: &schemapb.IDs{},
|
||||||
@ -48,8 +37,8 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
ret := &schemapb.SearchResultData{
|
ret := &schemapb.SearchResultData{
|
||||||
NumQueries: info.nq,
|
NumQueries: info.GetNq(),
|
||||||
TopK: info.topK,
|
TopK: info.GetTopK(),
|
||||||
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
||||||
Scores: make([]float32, 0),
|
Scores: make([]float32, 0),
|
||||||
Ids: &schemapb.IDs{},
|
Ids: &schemapb.IDs{},
|
||||||
@ -59,7 +48,7 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
|||||||
resultOffsets := make([][]int64, len(searchResultData))
|
resultOffsets := make([][]int64, len(searchResultData))
|
||||||
for i := 0; i < len(searchResultData); i++ {
|
for i := 0; i < len(searchResultData); i++ {
|
||||||
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
||||||
for j := int64(1); j < info.nq; j++ {
|
for j := int64(1); j < info.GetNq(); j++ {
|
||||||
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
||||||
}
|
}
|
||||||
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
||||||
@ -68,11 +57,11 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
|||||||
var skipDupCnt int64
|
var skipDupCnt int64
|
||||||
var retSize int64
|
var retSize int64
|
||||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||||
for i := int64(0); i < info.nq; i++ {
|
for i := int64(0); i < info.GetNq(); i++ {
|
||||||
offsets := make([]int64, len(searchResultData))
|
offsets := make([]int64, len(searchResultData))
|
||||||
idSet := make(map[interface{}]struct{})
|
idSet := make(map[interface{}]struct{})
|
||||||
var j int64
|
var j int64
|
||||||
for j = 0; j < info.topK; {
|
for j = 0; j < info.GetTopK(); {
|
||||||
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
sel := SelectSearchResultData(searchResultData, resultOffsets, offsets, i)
|
||||||
if sel == -1 {
|
if sel == -1 {
|
||||||
break
|
break
|
||||||
@ -113,15 +102,15 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
|
|||||||
|
|
||||||
type SearchGroupByReduce struct{}
|
type SearchGroupByReduce struct{}
|
||||||
|
|
||||||
func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *ReduceInfo) (*schemapb.SearchResultData, error) {
|
func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, searchResultData []*schemapb.SearchResultData, info *reduce.ResultInfo) (*schemapb.SearchResultData, error) {
|
||||||
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceSearchResultData")
|
||||||
defer sp.End()
|
defer sp.End()
|
||||||
log := log.Ctx(ctx)
|
log := log.Ctx(ctx)
|
||||||
|
|
||||||
if len(searchResultData) == 0 {
|
if len(searchResultData) == 0 {
|
||||||
return &schemapb.SearchResultData{
|
return &schemapb.SearchResultData{
|
||||||
NumQueries: info.nq,
|
NumQueries: info.GetNq(),
|
||||||
TopK: info.topK,
|
TopK: info.GetTopK(),
|
||||||
FieldsData: make([]*schemapb.FieldData, 0),
|
FieldsData: make([]*schemapb.FieldData, 0),
|
||||||
Scores: make([]float32, 0),
|
Scores: make([]float32, 0),
|
||||||
Ids: &schemapb.IDs{},
|
Ids: &schemapb.IDs{},
|
||||||
@ -129,8 +118,8 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
ret := &schemapb.SearchResultData{
|
ret := &schemapb.SearchResultData{
|
||||||
NumQueries: info.nq,
|
NumQueries: info.GetNq(),
|
||||||
TopK: info.topK,
|
TopK: info.GetTopK(),
|
||||||
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
FieldsData: make([]*schemapb.FieldData, len(searchResultData[0].FieldsData)),
|
||||||
Scores: make([]float32, 0),
|
Scores: make([]float32, 0),
|
||||||
Ids: &schemapb.IDs{},
|
Ids: &schemapb.IDs{},
|
||||||
@ -140,7 +129,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||||||
resultOffsets := make([][]int64, len(searchResultData))
|
resultOffsets := make([][]int64, len(searchResultData))
|
||||||
for i := 0; i < len(searchResultData); i++ {
|
for i := 0; i < len(searchResultData); i++ {
|
||||||
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
resultOffsets[i] = make([]int64, len(searchResultData[i].Topks))
|
||||||
for j := int64(1); j < info.nq; j++ {
|
for j := int64(1); j < info.GetNq(); j++ {
|
||||||
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
resultOffsets[i][j] = resultOffsets[i][j-1] + searchResultData[i].Topks[j-1]
|
||||||
}
|
}
|
||||||
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
|
||||||
@ -149,13 +138,13 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||||||
var filteredCount int64
|
var filteredCount int64
|
||||||
var retSize int64
|
var retSize int64
|
||||||
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
|
||||||
groupSize := info.groupSize
|
groupSize := info.GetGroupSize()
|
||||||
if groupSize <= 0 {
|
if groupSize <= 0 {
|
||||||
groupSize = 1
|
groupSize = 1
|
||||||
}
|
}
|
||||||
groupBound := info.topK * groupSize
|
groupBound := info.GetTopK() * groupSize
|
||||||
|
|
||||||
for i := int64(0); i < info.nq; i++ {
|
for i := int64(0); i < info.GetNq(); i++ {
|
||||||
offsets := make([]int64, len(searchResultData))
|
offsets := make([]int64, len(searchResultData))
|
||||||
|
|
||||||
idSet := make(map[interface{}]struct{})
|
idSet := make(map[interface{}]struct{})
|
||||||
@ -178,7 +167,7 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||||||
}
|
}
|
||||||
|
|
||||||
groupCount := groupByValueMap[groupByVal]
|
groupCount := groupByValueMap[groupByVal]
|
||||||
if groupCount == 0 && int64(len(groupByValueMap)) >= info.topK {
|
if groupCount == 0 && int64(len(groupByValueMap)) >= info.GetTopK() {
|
||||||
// exceed the limit for group count, filter this entity
|
// exceed the limit for group count, filter this entity
|
||||||
filteredCount++
|
filteredCount++
|
||||||
} else if groupCount >= groupSize {
|
} else if groupCount >= groupSize {
|
||||||
@ -219,8 +208,8 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
|
|||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitSearchReducer(info *ReduceInfo) SearchReduce {
|
func InitSearchReducer(info *reduce.ResultInfo) SearchReduce {
|
||||||
if info.groupByFieldID > 0 {
|
if info.GetGroupByFieldId() > 0 {
|
||||||
return &SearchGroupByReduce{}
|
return &SearchGroupByReduce{}
|
||||||
}
|
}
|
||||||
return &SearchCommonReduce{}
|
return &SearchCommonReduce{}
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,7 +29,7 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() {
|
|||||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||||
dataArray = append(dataArray, data1)
|
dataArray = append(dataArray, data1)
|
||||||
dataArray = append(dataArray, data2)
|
dataArray = append(dataArray, data2)
|
||||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk}
|
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1)
|
||||||
searchReduce := InitSearchReducer(reduceInfo)
|
searchReduce := InitSearchReducer(reduceInfo)
|
||||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||||
suite.Nil(err)
|
suite.Nil(err)
|
||||||
@ -47,7 +48,7 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() {
|
|||||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||||
dataArray = append(dataArray, data1)
|
dataArray = append(dataArray, data1)
|
||||||
dataArray = append(dataArray, data2)
|
dataArray = append(dataArray, data2)
|
||||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk}
|
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1)
|
||||||
searchReduce := InitSearchReducer(reduceInfo)
|
searchReduce := InitSearchReducer(reduceInfo)
|
||||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||||
suite.Nil(err)
|
suite.Nil(err)
|
||||||
@ -96,7 +97,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||||
dataArray = append(dataArray, data1)
|
dataArray = append(dataArray, data1)
|
||||||
dataArray = append(dataArray, data2)
|
dataArray = append(dataArray, data2)
|
||||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
|
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101)
|
||||||
searchReduce := InitSearchReducer(reduceInfo)
|
searchReduce := InitSearchReducer(reduceInfo)
|
||||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||||
suite.Nil(err)
|
suite.Nil(err)
|
||||||
@ -140,7 +141,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||||
dataArray = append(dataArray, data1)
|
dataArray = append(dataArray, data1)
|
||||||
dataArray = append(dataArray, data2)
|
dataArray = append(dataArray, data2)
|
||||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
|
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101)
|
||||||
searchReduce := InitSearchReducer(reduceInfo)
|
searchReduce := InitSearchReducer(reduceInfo)
|
||||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||||
suite.Nil(err)
|
suite.Nil(err)
|
||||||
@ -184,7 +185,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||||
dataArray = append(dataArray, data1)
|
dataArray = append(dataArray, data1)
|
||||||
dataArray = append(dataArray, data2)
|
dataArray = append(dataArray, data2)
|
||||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101}
|
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(1).WithGroupByField(101)
|
||||||
searchReduce := InitSearchReducer(reduceInfo)
|
searchReduce := InitSearchReducer(reduceInfo)
|
||||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||||
suite.Nil(err)
|
suite.Nil(err)
|
||||||
@ -228,7 +229,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||||
dataArray = append(dataArray, data1)
|
dataArray = append(dataArray, data1)
|
||||||
dataArray = append(dataArray, data2)
|
dataArray = append(dataArray, data2)
|
||||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
|
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(3).WithGroupByField(101)
|
||||||
searchReduce := InitSearchReducer(reduceInfo)
|
searchReduce := InitSearchReducer(reduceInfo)
|
||||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||||
suite.Nil(err)
|
suite.Nil(err)
|
||||||
@ -239,7 +240,7 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
|
|||||||
|
|
||||||
suite.Run("reduce_group_by_empty_input", func() {
|
suite.Run("reduce_group_by_empty_input", func() {
|
||||||
dataArray := make([]*schemapb.SearchResultData, 0)
|
dataArray := make([]*schemapb.SearchResultData, 0)
|
||||||
reduceInfo := &ReduceInfo{nq: nq, topK: topk, groupByFieldID: 101, groupSize: 3}
|
reduceInfo := reduce.NewReduceSearchResultInfo(nq, topk).WithGroupSize(3).WithGroupByField(101)
|
||||||
searchReduce := InitSearchReducer(reduceInfo)
|
searchReduce := InitSearchReducer(reduceInfo)
|
||||||
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
res, err := searchReduce.ReduceSearchResultData(context.TODO(), dataArray, reduceInfo)
|
||||||
suite.Nil(err)
|
suite.Nil(err)
|
||||||
|
|||||||
@ -753,69 +753,41 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels()))
|
if len(req.GetDmlChannels()) != 1 {
|
||||||
runningGp, runningCtx := errgroup.WithContext(ctx)
|
err := merr.WrapErrParameterInvalid(1, len(req.GetDmlChannels()), "count of channel to be searched should only be 1, wrong code")
|
||||||
|
resp.Status = merr.Status(err)
|
||||||
for i, ch := range req.GetDmlChannels() {
|
log.Warn("got wrong number of channels to be searched", zap.Error(err))
|
||||||
ch := ch
|
return resp, nil
|
||||||
req := &querypb.SearchRequest{
|
|
||||||
Req: req.Req,
|
|
||||||
DmlChannels: []string{ch},
|
|
||||||
SegmentIDs: req.SegmentIDs,
|
|
||||||
Scope: req.Scope,
|
|
||||||
TotalChannelNum: req.TotalChannelNum,
|
|
||||||
}
|
|
||||||
|
|
||||||
i := i
|
|
||||||
runningGp.Go(func() error {
|
|
||||||
ret, err := node.searchChannel(runningCtx, req, ch)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := merr.Error(ret.GetStatus()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
toReduceResults[i] = ret
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
if err := runningGp.Wait(); err != nil {
|
|
||||||
|
ch := req.GetDmlChannels()[0]
|
||||||
|
channelReq := &querypb.SearchRequest{
|
||||||
|
Req: req.Req,
|
||||||
|
DmlChannels: []string{ch},
|
||||||
|
SegmentIDs: req.SegmentIDs,
|
||||||
|
Scope: req.Scope,
|
||||||
|
TotalChannelNum: req.TotalChannelNum,
|
||||||
|
}
|
||||||
|
ret, err := node.searchChannel(ctx, channelReq, ch)
|
||||||
|
if err != nil {
|
||||||
resp.Status = merr.Status(err)
|
resp.Status = merr.Status(err)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
tr.RecordSpan()
|
tr.RecordSpan()
|
||||||
var result *internalpb.SearchResults
|
ret.Status = merr.Success()
|
||||||
var err2 error
|
|
||||||
if req.GetReq().GetIsAdvanced() {
|
|
||||||
result, err2 = segments.ReduceAdvancedSearchResults(ctx, toReduceResults, req.Req.GetNq())
|
|
||||||
} else {
|
|
||||||
result, err2 = segments.ReduceSearchResults(ctx, toReduceResults, segments.NewReduceInfo(req.Req.GetNq(),
|
|
||||||
req.Req.GetTopk(),
|
|
||||||
req.Req.GetExtraSearchParam().GetGroupByFieldId(),
|
|
||||||
req.Req.GetExtraSearchParam().GetGroupSize(),
|
|
||||||
req.Req.GetMetricType()))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err2 != nil {
|
|
||||||
log.Warn("failed to reduce search results", zap.Error(err2))
|
|
||||||
resp.Status = merr.Status(err2)
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
result.Status = merr.Success()
|
|
||||||
|
|
||||||
reduceLatency := tr.RecordSpan()
|
reduceLatency := tr.RecordSpan()
|
||||||
metrics.QueryNodeReduceLatency.
|
metrics.QueryNodeReduceLatency.
|
||||||
WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards, metrics.BatchReduce).
|
WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards, metrics.BatchReduce).
|
||||||
Observe(float64(reduceLatency.Milliseconds()))
|
Observe(float64(reduceLatency.Milliseconds()))
|
||||||
|
|
||||||
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.SearchLabel).
|
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(node.GetNodeID(), 10), metrics.SearchLabel).
|
||||||
Add(float64(proto.Size(req)))
|
Add(float64(proto.Size(req)))
|
||||||
|
|
||||||
if result.GetCostAggregation() != nil {
|
if ret.GetCostAggregation() != nil {
|
||||||
result.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
|
ret.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
|
||||||
}
|
}
|
||||||
return result, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// only used for delegator query segments from worker
|
// only used for delegator query segments from worker
|
||||||
|
|||||||
92
internal/util/reduce/reduce_info.go
Normal file
92
internal/util/reduce/reduce_info.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
package reduce
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResultInfo struct {
|
||||||
|
nq int64
|
||||||
|
topK int64
|
||||||
|
metricType string
|
||||||
|
pkType schemapb.DataType
|
||||||
|
offset int64
|
||||||
|
groupByFieldId int64
|
||||||
|
groupSize int64
|
||||||
|
isAdvance bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReduceSearchResultInfo(
|
||||||
|
nq int64,
|
||||||
|
topK int64,
|
||||||
|
) *ResultInfo {
|
||||||
|
return &ResultInfo{
|
||||||
|
nq: nq,
|
||||||
|
topK: topK,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) WithMetricType(metricType string) *ResultInfo {
|
||||||
|
r.metricType = metricType
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) WithPkType(pkType schemapb.DataType) *ResultInfo {
|
||||||
|
r.pkType = pkType
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) WithOffset(offset int64) *ResultInfo {
|
||||||
|
r.offset = offset
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) WithGroupByField(groupByField int64) *ResultInfo {
|
||||||
|
r.groupByFieldId = groupByField
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) WithGroupSize(groupSize int64) *ResultInfo {
|
||||||
|
r.groupSize = groupSize
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) WithAdvance(advance bool) *ResultInfo {
|
||||||
|
r.isAdvance = advance
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) GetNq() int64 {
|
||||||
|
return r.nq
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) GetTopK() int64 {
|
||||||
|
return r.topK
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) GetMetricType() string {
|
||||||
|
return r.metricType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) GetPkType() schemapb.DataType {
|
||||||
|
return r.pkType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) GetOffset() int64 {
|
||||||
|
return r.offset
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) GetGroupByFieldId() int64 {
|
||||||
|
return r.groupByFieldId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) GetGroupSize() int64 {
|
||||||
|
return r.groupSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) GetIsAdvance() bool {
|
||||||
|
return r.isAdvance
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ResultInfo) SetMetricType(metricType string) {
|
||||||
|
r.metricType = metricType
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user