enhance: support recall estimation (#38017)
issue: #37899

Only the `search` API will be supported.

Signed-off-by: chasingegg <chao.gao@zilliz.com>
parent dc85d8e968
commit 8977454311
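The recall reported by this change is, per query, the fraction of IDs returned by the (possibly optimized) search that also appear in a ground-truth result set obtained by re-running the same request without optimization. A minimal, self-contained sketch of that per-query computation follows; the helper name and the map-based lookup are illustrative, while the actual implementation is the generic `recallCal`/`computeRecall` pair added to the proxy further down in this diff.

```go
package main

import "fmt"

// recallOf returns |results ∩ groundTruth| / |results| for a single query,
// mirroring the semantics of the recallCal helper introduced in this commit.
// The name and the map-based membership test are illustrative only.
func recallOf(results, groundTruth []int64) float32 {
	gt := make(map[int64]struct{}, len(groundTruth))
	for _, id := range groundTruth {
		gt[id] = struct{}{}
	}
	hit := 0
	for _, id := range results {
		if _, ok := gt[id]; ok {
			hit++
		}
	}
	return float32(hit) / float32(len(results))
}

func main() {
	// Approximate top-2 versus exhaustive top-2 for one query.
	fmt.Println(recallOf([]int64{11, 9}, []int64{11, 10})) // 0.5
}
```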
@@ -128,6 +128,7 @@ message SearchRequest {
  int64 group_size = 24;
  int64 field_id = 25;
  bool is_topk_reduce = 26;
+ bool is_recall_evaluation = 27;
}

message SubSearchResults {
@@ -164,6 +165,7 @@ message SearchResults {
  bool is_advanced = 16;
  int64 all_search_count = 17;
  bool is_topk_reduce = 18;
+ bool is_recall_evaluation = 19;
}

message CostAggregation {
@@ -3000,12 +3000,13 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
    optimizedSearch := true
    resultSizeInsufficient := false
    isTopkReduce := false
+   isRecallEvaluation := false
    err2 := retry.Handle(ctx, func() (bool, error) {
-       rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch)
+       rsp, resultSizeInsufficient, isTopkReduce, isRecallEvaluation, err = node.search(ctx, request, optimizedSearch, false)
        if merr.Ok(rsp.GetStatus()) && optimizedSearch && resultSizeInsufficient && isTopkReduce && paramtable.Get().AutoIndexConfig.EnableResultLimitCheck.GetAsBool() {
            // without optimize search
            optimizedSearch = false
-           rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch)
+           rsp, resultSizeInsufficient, isTopkReduce, isRecallEvaluation, err = node.search(ctx, request, optimizedSearch, false)
            metrics.ProxyRetrySearchCount.WithLabelValues(
                strconv.FormatInt(paramtable.GetNodeID(), 10),
                metrics.SearchLabel,
@@ -3023,6 +3024,23 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
        if errors.Is(merr.Error(rsp.GetStatus()), merr.ErrInconsistentRequery) {
            return true, merr.Error(rsp.GetStatus())
        }
+       // search for ground truth and compute recall
+       if isRecallEvaluation && merr.Ok(rsp.GetStatus()) {
+           var rspGT *milvuspb.SearchResults
+           rspGT, _, _, _, err = node.search(ctx, request, false, true)
+           metrics.ProxyRecallSearchCount.WithLabelValues(
+               strconv.FormatInt(paramtable.GetNodeID(), 10),
+               metrics.SearchLabel,
+               request.GetCollectionName(),
+           ).Inc()
+           if merr.Ok(rspGT.GetStatus()) {
+               return false, computeRecall(rsp.GetResults(), rspGT.GetResults())
+           }
+           if errors.Is(merr.Error(rspGT.GetStatus()), merr.ErrInconsistentRequery) {
+               return true, merr.Error(rspGT.GetStatus())
+           }
+           return false, merr.Error(rspGT.GetStatus())
+       }
        return false, nil
    })
    if err2 != nil {
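When `isRecallEvaluation` comes back true and the primary search succeeded, the proxy reissues the identical request as a ground-truth pass (optimization off, recall flag on), increments `ProxyRecallSearchCount`, and lets `computeRecall` attach per-query recall values to the original results before they are returned; failures of the ground-truth pass flow through the same retry handling as the primary search.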
@@ -3031,13 +3049,11 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
    return rsp, err
}

-func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool) (*milvuspb.SearchResults, bool, bool, error) {
-   receiveSize := proto.Size(request)
-   metrics.ProxyReceiveBytes.WithLabelValues(
-       strconv.FormatInt(paramtable.GetNodeID(), 10),
-       metrics.SearchLabel,
-       request.GetCollectionName(),
-   ).Add(float64(receiveSize))
+func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool, isRecallEvaluation bool) (*milvuspb.SearchResults, bool, bool, bool, error) {
+   metrics.GetStats(ctx).
+       SetNodeID(paramtable.GetNodeID()).
+       SetInboundLabel(metrics.SearchLabel).
+       SetCollectionName(request.GetCollectionName())

    metrics.ProxyReceivedNQ.WithLabelValues(
        strconv.FormatInt(paramtable.GetNodeID(), 10),
@@ -3048,7 +3064,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
    if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
        return &milvuspb.SearchResults{
            Status: merr.Status(err),
-       }, false, false, nil
+       }, false, false, false, nil
    }

    method := "Search"
@@ -3069,7 +3085,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
    if err != nil {
        return &milvuspb.SearchResults{
            Status: merr.Status(err),
-       }, false, false, nil
+       }, false, false, false, nil
    }

    request.PlaceholderGroup = placeholderGroupBytes
@@ -3083,8 +3099,9 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
            commonpbutil.WithMsgType(commonpb.MsgType_Search),
            commonpbutil.WithSourceID(paramtable.GetNodeID()),
        ),
-       ReqID:        paramtable.GetNodeID(),
-       IsTopkReduce: optimizedSearch,
+       ReqID:              paramtable.GetNodeID(),
+       IsTopkReduce:       optimizedSearch,
+       IsRecallEvaluation: isRecallEvaluation,
    },
    request: request,
    tr:      timerecord.NewTimeRecorder("search"),
@@ -3146,7 +3163,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,

        return &milvuspb.SearchResults{
            Status: merr.Status(err),
-       }, false, false, nil
+       }, false, false, false, nil
    }
    tr.CtxRecord(ctx, "search request enqueue")
@@ -3172,7 +3189,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,

        return &milvuspb.SearchResults{
            Status: merr.Status(err),
-       }, false, false, nil
+       }, false, false, false, nil
    }

    span := tr.CtxRecord(ctx, "wait search result")
@@ -3229,7 +3246,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
            metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeSearch, dbName, username).Add(float64(v))
        }
    }
-   return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, nil
+   return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, qt.isRecallEvaluation, nil
}

func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
@@ -3272,12 +3289,10 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
}

func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest, optimizedSearch bool) (*milvuspb.SearchResults, bool, bool, error) {
-   receiveSize := proto.Size(request)
-   metrics.ProxyReceiveBytes.WithLabelValues(
-       strconv.FormatInt(paramtable.GetNodeID(), 10),
-       metrics.HybridSearchLabel,
-       request.GetCollectionName(),
-   ).Add(float64(receiveSize))
+   metrics.GetStats(ctx).
+       SetNodeID(paramtable.GetNodeID()).
+       SetInboundLabel(metrics.HybridSearchLabel).
+       SetCollectionName(request.GetCollectionName())

    if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
        return &milvuspb.SearchResults{
@@ -65,6 +65,7 @@ type searchTask struct {
    mustUsePartitionKey    bool
    resultSizeInsufficient bool
    isTopkReduce           bool
+   isRecallEvaluation     bool

    userOutputFields  []string
    userDynamicFields []string
@@ -647,10 +648,14 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
    t.queryChannelsTs = make(map[string]uint64)
    t.relatedDataSize = 0
    isTopkReduce := false
+   isRecallEvaluation := false
    for _, r := range toReduceResults {
        if r.GetIsTopkReduce() {
            isTopkReduce = true
        }
+       if r.GetIsRecallEvaluation() {
+           isRecallEvaluation = true
+       }
        t.relatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize()
        for ch, ts := range r.GetChannelsMvcc() {
            t.queryChannelsTs[ch] = ts
@@ -731,6 +736,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
    }
    t.resultSizeInsufficient = resultSizeInsufficient
    t.isTopkReduce = isTopkReduce
+   t.isRecallEvaluation = isRecallEvaluation
    t.result.CollectionName = t.collectionName
    t.fillInFieldInfo()
@@ -1253,6 +1253,75 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int
    return pkNames, fieldIDs
}

+func recallCal[T string | int64](results []T, gts []T) float32 {
+   hit := 0
+   total := 0
+   for _, r := range results {
+       total++
+       for _, gt := range gts {
+           if r == gt {
+               hit++
+               break
+           }
+       }
+   }
+   return float32(hit) / float32(total)
+}
+
+func computeRecall(results *schemapb.SearchResultData, gts *schemapb.SearchResultData) error {
+   if results.GetNumQueries() != gts.GetNumQueries() {
+       return fmt.Errorf("num of queries is inconsistent between search results(%d) and ground truth(%d)", results.GetNumQueries(), gts.GetNumQueries())
+   }
+
+   switch results.GetIds().GetIdField().(type) {
+   case *schemapb.IDs_IntId:
+       switch gts.GetIds().GetIdField().(type) {
+       case *schemapb.IDs_IntId:
+           currentResultIndex := int64(0)
+           currentGTIndex := int64(0)
+           recalls := make([]float32, 0, results.GetNumQueries())
+           for i := 0; i < int(results.GetNumQueries()); i++ {
+               currentResultTopk := results.GetTopks()[i]
+               currentGTTopk := gts.GetTopks()[i]
+               recalls = append(recalls, recallCal(results.GetIds().GetIntId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk],
+                   gts.GetIds().GetIntId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk]))
+               currentResultIndex += currentResultTopk
+               currentGTIndex += currentGTTopk
+           }
+           results.Recalls = recalls
+           return nil
+       case *schemapb.IDs_StrId:
+           return fmt.Errorf("pk type is inconsistent between search results(int64) and ground truth(string)")
+       default:
+           return fmt.Errorf("unsupported pk type")
+       }
+
+   case *schemapb.IDs_StrId:
+       switch gts.GetIds().GetIdField().(type) {
+       case *schemapb.IDs_StrId:
+           currentResultIndex := int64(0)
+           currentGTIndex := int64(0)
+           recalls := make([]float32, 0, results.GetNumQueries())
+           for i := 0; i < int(results.GetNumQueries()); i++ {
+               currentResultTopk := results.GetTopks()[i]
+               currentGTTopk := gts.GetTopks()[i]
+               recalls = append(recalls, recallCal(results.GetIds().GetStrId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk],
+                   gts.GetIds().GetStrId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk]))
+               currentResultIndex += currentResultTopk
+               currentGTIndex += currentGTTopk
+           }
+           results.Recalls = recalls
+           return nil
+       case *schemapb.IDs_IntId:
+           return fmt.Errorf("pk type is inconsistent between search results(string) and ground truth(int64)")
+       default:
+           return fmt.Errorf("unsupported pk type")
+       }
+   default:
+       return fmt.Errorf("unsupported pk type")
+   }
+}
+
// Support wildcard in output fields:
//
// "*" - all fields
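Worked example: in the first test case below, the first query returns IDs {"11", "9"} against a ground truth of {"11", "10"}, so one of two returned IDs is confirmed and the per-query recall is 1/2 = 0.5; the other two queries match their ground truth exactly and score 1.0. `recallCal` divides the hit count by the number of IDs returned for that query, so each recall value is relative to that query's own topk slice.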
@@ -3079,3 +3079,165 @@ func TestValidateFunctionBasicParams(t *testing.T) {
        assert.Error(t, err)
    })
}
+
+func TestComputeRecall(t *testing.T) {
+   t.Run("normal case1", func(t *testing.T) {
+       result1 := &schemapb.SearchResultData{
+           NumQueries: 3,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_StrId{
+                   StrId: &schemapb.StringArray{
+                       Data: []string{"11", "9", "8", "5", "3", "1"},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.1},
+           Topks:  []int64{2, 2, 2},
+       }
+
+       gt := &schemapb.SearchResultData{
+           NumQueries: 3,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_StrId{
+                   StrId: &schemapb.StringArray{
+                       Data: []string{"11", "10", "8", "5", "3", "1"},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.98, 0.8, 0.5, 0.3, 0.1},
+           Topks:  []int64{2, 2, 2},
+       }
+
+       err := computeRecall(result1, gt)
+       assert.NoError(t, err)
+       assert.Equal(t, result1.Recalls[0], float32(0.5))
+       assert.Equal(t, result1.Recalls[1], float32(1.0))
+       assert.Equal(t, result1.Recalls[2], float32(1.0))
+   })
+
+   t.Run("normal case2", func(t *testing.T) {
+       result1 := &schemapb.SearchResultData{
+           NumQueries: 2,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_IntId{
+                   IntId: &schemapb.LongArray{
+                       Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
+           Topks:  []int64{5, 5},
+       }
+
+       gt := &schemapb.SearchResultData{
+           NumQueries: 2,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_IntId{
+                   IntId: &schemapb.LongArray{
+                       Data: []int64{11, 9, 6, 5, 4, 1, 34, 23, 22, 20},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
+           Topks:  []int64{5, 5},
+       }
+
+       err := computeRecall(result1, gt)
+       assert.NoError(t, err)
+       assert.Equal(t, result1.Recalls[0], float32(0.6))
+       assert.Equal(t, result1.Recalls[1], float32(0.8))
+   })
+
+   t.Run("not match size", func(t *testing.T) {
+       result1 := &schemapb.SearchResultData{
+           NumQueries: 2,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_IntId{
+                   IntId: &schemapb.LongArray{
+                       Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
+           Topks:  []int64{5, 5},
+       }
+
+       gt := &schemapb.SearchResultData{
+           NumQueries: 1,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_IntId{
+                   IntId: &schemapb.LongArray{
+                       Data: []int64{11, 9, 6, 5, 4},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3},
+           Topks:  []int64{5},
+       }
+
+       err := computeRecall(result1, gt)
+       assert.Error(t, err)
+   })
+
+   t.Run("not match type1", func(t *testing.T) {
+       result1 := &schemapb.SearchResultData{
+           NumQueries: 2,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_IntId{
+                   IntId: &schemapb.LongArray{
+                       Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
+           Topks:  []int64{5, 5},
+       }
+
+       gt := &schemapb.SearchResultData{
+           NumQueries: 2,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_StrId{
+                   StrId: &schemapb.StringArray{
+                       Data: []string{"11", "10", "8", "5", "3", "1", "23", "22", "21", "20"},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
+           Topks:  []int64{5, 5},
+       }
+
+       err := computeRecall(result1, gt)
+       assert.Error(t, err)
+   })
+
+   t.Run("not match type2", func(t *testing.T) {
+       result1 := &schemapb.SearchResultData{
+           NumQueries: 2,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_StrId{
+                   StrId: &schemapb.StringArray{
+                       Data: []string{"11", "10", "8", "5", "3", "1", "23", "22", "21", "20"},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
+           Topks:  []int64{5, 5},
+       }
+
+       gt := &schemapb.SearchResultData{
+           NumQueries: 2,
+           Ids: &schemapb.IDs{
+               IdField: &schemapb.IDs_IntId{
+                   IntId: &schemapb.LongArray{
+                       Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21},
+                   },
+               },
+           },
+           Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4},
+           Topks:  []int64{5, 5},
+       }
+
+       err := computeRecall(result1, gt)
+       assert.Error(t, err)
+   })
+}
@@ -65,6 +65,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult

    channelsMvcc := make(map[string]uint64)
    isTopkReduce := false
+   isRecallEvaluation := false
    for _, r := range results {
        for ch, ts := range r.GetChannelsMvcc() {
            channelsMvcc[ch] = ts
@@ -72,6 +73,9 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
        if r.GetIsTopkReduce() {
            isTopkReduce = true
        }
+       if r.GetIsRecallEvaluation() {
+           isRecallEvaluation = true
+       }
        // shouldn't let new SearchResults.MetricType to be empty, though the req.MetricType is empty
        if info.GetMetricType() == "" {
            info.SetMetricType(r.MetricType)
@@ -126,6 +130,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
    searchResults.CostAggregation.TotalRelatedDataSize = relatedDataSize
    searchResults.ChannelsMvcc = channelsMvcc
    searchResults.IsTopkReduce = isTopkReduce
+   searchResults.IsRecallEvaluation = isRecallEvaluation
    return searchResults, nil
}
@@ -733,6 +733,7 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
    if req.GetReq().GetIsTopkReduce() {
        resp.IsTopkReduce = true
    }
+   resp.IsRecallEvaluation = req.GetReq().GetIsRecallEvaluation()
    return resp, nil
}
@@ -1177,7 +1177,7 @@ func (suite *ServiceSuite) syncDistribution(ctx context.Context) {
}

// Test Search
-func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataType, fieldID int64, metricType string, isTopkReduce bool) (*internalpb.SearchRequest, error) {
+func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataType, fieldID int64, metricType string, isTopkReduce bool, isRecallEvaluation bool) (*internalpb.SearchRequest, error) {
    placeHolder, err := genPlaceHolderGroup(nq)
    if err != nil {
        return nil, err
@@ -1202,6 +1202,7 @@ func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataTyp
        Nq:                 nq,
        MvccTimestamp:      typeutil.MaxTimestamp,
        IsTopkReduce:       isTopkReduce,
+       IsRecallEvaluation: isRecallEvaluation,
    }, nil
}
@@ -1212,7 +1213,7 @@ func (suite *ServiceSuite) TestSearch_Normal() {
    suite.TestLoadSegments_Int64()
    suite.syncDistribution(ctx)

-   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false)
+   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false, false)
    req := &querypb.SearchRequest{
        Req: creq,
@@ -1237,7 +1238,7 @@ func (suite *ServiceSuite) TestSearch_Concurrent() {
    futures := make([]*conc.Future[*internalpb.SearchResults], 0, concurrency)
    for i := 0; i < concurrency; i++ {
        future := conc.Go(func() (*internalpb.SearchResults, error) {
-           creq, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType, false)
+           creq, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType, false, false)
            req := &querypb.SearchRequest{
                Req: creq,
@@ -1263,7 +1264,7 @@ func (suite *ServiceSuite) TestSearch_Failed() {

    // data
    schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
-   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType", false)
+   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType", false, false)
    req := &querypb.SearchRequest{
        Req: creq,
@@ -1388,7 +1389,7 @@ func (suite *ServiceSuite) TestSearchSegments_Normal() {
    suite.TestWatchDmChannelsInt64()
    suite.TestLoadSegments_Int64()

-   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false)
+   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false, false)
    req := &querypb.SearchRequest{
        Req: creq,
@@ -1400,13 +1401,15 @@ func (suite *ServiceSuite) TestSearchSegments_Normal() {
    rsp, err := suite.node.SearchSegments(ctx, req)
    suite.NoError(err)
    suite.Equal(rsp.GetIsTopkReduce(), false)
+   suite.Equal(rsp.GetIsRecallEvaluation(), false)
    suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())

-   req.Req, err = suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, true)
+   req.Req, err = suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, true, true)
    suite.NoError(err)
    rsp, err = suite.node.SearchSegments(ctx, req)
    suite.NoError(err)
    suite.Equal(rsp.GetIsTopkReduce(), true)
+   suite.Equal(rsp.GetIsRecallEvaluation(), true)
    suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
}
@@ -1416,7 +1419,7 @@ func (suite *ServiceSuite) TestStreamingSearch() {
    suite.TestWatchDmChannelsInt64()
    suite.TestLoadSegments_Int64()
    paramtable.Get().Save(paramtable.Get().QueryNodeCfg.UseStreamComputing.Key, "true")
-   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false)
+   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false, true)
    req := &querypb.SearchRequest{
        Req:             creq,
        FromShardLeader: true,
@@ -1430,6 +1433,7 @@ func (suite *ServiceSuite) TestStreamingSearch() {
    rsp, err := suite.node.SearchSegments(ctx, req)
    suite.NoError(err)
    suite.Equal(false, rsp.GetIsTopkReduce())
+   suite.Equal(true, rsp.GetIsRecallEvaluation())
    suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode())
}
@@ -1438,7 +1442,7 @@ func (suite *ServiceSuite) TestStreamingSearchGrowing() {
    // pre
    suite.TestWatchDmChannelsInt64()
    paramtable.Get().Save(paramtable.Get().QueryNodeCfg.UseStreamComputing.Key, "true")
-   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false)
+   creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false, false)
    req := &querypb.SearchRequest{
        Req:             creq,
        FromShardLeader: true,
@@ -28,6 +28,7 @@ func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, query
    // no hook applied or disabled, just return
    if queryHook == nil || !paramtable.Get().AutoIndexConfig.Enable.GetAsBool() {
        req.Req.IsTopkReduce = false
+       req.Req.IsRecallEvaluation = false
        return req, nil
    }
@@ -68,8 +69,9 @@ func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, query
        common.SegmentNumKey:   estSegmentNum,
        common.WithFilterKey:   withFilter,
        common.DataTypeKey:     int32(plan.GetVectorAnns().GetVectorType()),
-       common.WithOptimizeKey: paramtable.Get().AutoIndexConfig.EnableOptimize.GetAsBool() && req.GetReq().GetIsTopkReduce(),
+       common.WithOptimizeKey: paramtable.Get().AutoIndexConfig.EnableOptimize.GetAsBool() && req.GetReq().GetIsTopkReduce() && queryInfo.GetGroupByFieldId() < 0,
        common.CollectionKey:   req.GetReq().GetCollectionID(),
+       common.RecallEvalKey:   req.GetReq().GetIsRecallEvaluation(),
    }
    if withFilter && channelNum > 1 {
        params[common.ChannelNumKey] = channelNum
@@ -90,6 +92,11 @@ func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, query
        }
        req.Req.SerializedExprPlan = serializedExprPlan
        req.Req.IsTopkReduce = isTopkReduce
+       if isRecallEvaluation, ok := params[common.RecallEvalKey]; ok {
+           req.Req.IsRecallEvaluation = isRecallEvaluation.(bool) && queryInfo.GetGroupByFieldId() < 0
+       } else {
+           req.Req.IsRecallEvaluation = false
+       }
        log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
    default:
        log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
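OptimizeSearchParams hands the recall flag to the query hook through the shared params map (`common.RecallEvalKey`) and only keeps whatever boolean comes back when the request has no group-by field. Below is a self-contained sketch of that round trip: the key value "recall_eval" comes from the `common` constants added in this commit, the hook signature mirrors the mocked `Run(map[string]any)` callback in the test that follows, and everything else is illustrative rather than the real API.

```go
package main

import "fmt"

// recallEvalKey matches common.RecallEvalKey ("recall_eval") added in this commit.
const recallEvalKey = "recall_eval"

// optimize models the handshake in OptimizeSearchParams: seed the params map
// from the request, let the hook rewrite it, then accept the hook's answer
// only when the request has no group-by field.
func optimize(isRecallEvaluation bool, groupByFieldID int64, hook func(map[string]any) error) (bool, error) {
	params := map[string]any{recallEvalKey: isRecallEvaluation}
	if err := hook(params); err != nil {
		return false, err
	}
	if v, ok := params[recallEvalKey]; ok {
		return v.(bool) && groupByFieldID < 0, nil
	}
	return false, nil
}

func main() {
	hook := func(p map[string]any) error { p[recallEvalKey] = true; return nil }
	on, _ := optimize(true, -1, hook)   // no group-by: recall evaluation stays on
	off, _ := optimize(true, 101, hook) // group-by search: recall evaluation is dropped
	fmt.Println(on, off)                // true false
}
```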
@@ -41,6 +41,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
    mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
        params[common.TopKKey] = int64(50)
        params[common.SearchParamKey] = `{"param": 2}`
+       params[common.RecallEvalKey] = true
    }).Return(nil)
    suite.queryHook = mockHook
    defer func() {
@@ -48,20 +49,21 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
        suite.queryHook = nil
    }()

-   getPlan := func(topk int64) *planpb.PlanNode {
+   getPlan := func(topk int64, groupByField int64) *planpb.PlanNode {
        return &planpb.PlanNode{
            Node: &planpb.PlanNode_VectorAnns{
                VectorAnns: &planpb.VectorANNS{
                    QueryInfo: &planpb.QueryInfo{
-                       Topk:         topk,
-                       SearchParams: `{"param": 1}`,
+                       Topk:           topk,
+                       SearchParams:   `{"param": 1}`,
+                       GroupByFieldId: groupByField,
                    },
                },
            },
        }
    }

-   bs, err := proto.Marshal(getPlan(100))
+   bs, err := proto.Marshal(getPlan(100, 101))
    suite.Require().NoError(err)

    req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
@@ -72,9 +74,9 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
        TotalChannelNum: 2,
    }, suite.queryHook, 2)
    suite.NoError(err)
-   suite.verifyQueryInfo(req, 50, true, `{"param": 2}`)
+   suite.verifyQueryInfo(req, 50, true, false, `{"param": 2}`)

-   bs, err = proto.Marshal(getPlan(50))
+   bs, err = proto.Marshal(getPlan(50, -1))
    suite.Require().NoError(err)
    req, err = OptimizeSearchParams(ctx, &querypb.SearchRequest{
        Req: &internalpb.SearchRequest{
@@ -84,7 +86,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
        TotalChannelNum: 2,
    }, suite.queryHook, 2)
    suite.NoError(err)
-   suite.verifyQueryInfo(req, 50, false, `{"param": 2}`)
+   suite.verifyQueryInfo(req, 50, false, true, `{"param": 2}`)
    })

    suite.Run("disable optimization", func() {
@@ -112,7 +114,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
        TotalChannelNum: 2,
    }, suite.queryHook, 2)
    suite.NoError(err)
-   suite.verifyQueryInfo(req, 100, false, `{"param": 1}`)
+   suite.verifyQueryInfo(req, 100, false, false, `{"param": 1}`)
    })

    suite.Run("no_hook", func() {
@@ -140,7 +142,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
        TotalChannelNum: 2,
    }, suite.queryHook, 2)
    suite.NoError(err)
-   suite.verifyQueryInfo(req, 100, false, `{"param": 1}`)
+   suite.verifyQueryInfo(req, 100, false, false, `{"param": 1}`)
    })

    suite.Run("other_plannode", func() {
@@ -221,7 +223,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
    })
}

-func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, isTopkReduce bool, param string) {
+func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, isTopkReduce bool, isRecallEvaluation bool, param string) {
    planBytes := req.GetReq().GetSerializedExprPlan()

    plan := planpb.PlanNode{}
@@ -232,6 +234,7 @@ func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK in
    suite.Equal(topK, queryInfo.GetTopk())
    suite.Equal(param, queryInfo.GetSearchParams())
    suite.Equal(isTopkReduce, req.GetReq().GetIsTopkReduce())
+   suite.Equal(isRecallEvaluation, req.GetReq().GetIsRecallEvaluation())
}

func TestOptimizeSearchParam(t *testing.T) {
@@ -123,6 +123,7 @@ const (
    ChannelNumKey   = "channel_num"
    WithOptimizeKey = "with_optimize"
    CollectionKey   = "collection"
+   RecallEvalKey   = "recall_eval"

    IndexParamsKey = "params"
    IndexTypeKey   = "index_type"
@@ -408,6 +408,15 @@ var (
            Name: "retry_search_result_insufficient_cnt",
            Help: "counter of retry search which does not have enough results",
        }, []string{nodeIDLabelName, queryTypeLabelName, collectionName})
+
+   // ProxyRecallSearchCount records the counter that users issue recall evaluation requests, which are cpu-intensive
+   ProxyRecallSearchCount = prometheus.NewCounterVec(
+       prometheus.CounterOpts{
+           Namespace: milvusNamespace,
+           Subsystem: typeutil.ProxyRole,
+           Name:      "recall_search_cnt",
+           Help:      "counter of recall search",
+       }, []string{nodeIDLabelName, queryTypeLabelName, collectionName})
)

// RegisterProxy registers Proxy metrics
@@ -468,6 +477,7 @@ func RegisterProxy(registry *prometheus.Registry) {
    registry.MustRegister(MaxInsertRate)
    registry.MustRegister(ProxyRetrySearchCount)
    registry.MustRegister(ProxyRetrySearchResultInsufficientCount)
+   registry.MustRegister(ProxyRecallSearchCount)

    RegisterStreamingServiceClient(registry)
}
@@ -593,4 +603,9 @@ func CleanupProxyCollectionMetrics(nodeID int64, collection string) {
        queryTypeLabelName: HybridSearchLabel,
        collectionName:     collection,
    })
+   ProxyRecallSearchCount.Delete(prometheus.Labels{
+       nodeIDLabelName:    strconv.FormatInt(nodeID, 10),
+       queryTypeLabelName: SearchLabel,
+       collectionName:     collection,
+   })
}