From 903450f5c62af39b6eb3c85ece637429998d4618 Mon Sep 17 00:00:00 2001 From: Chun Han <116052805+MrPresent-Han@users.noreply.github.com> Date: Wed, 16 Oct 2024 18:51:23 +0800 Subject: [PATCH] enhance: add ts support for iterator(#22718) (#36572) related: #22718 Signed-off-by: MrPresent-Han Co-authored-by: MrPresent-Han --- go.mod | 2 +- go.sum | 4 +- internal/proxy/search_util.go | 65 +++-- internal/proxy/task_query.go | 13 +- internal/proxy/task_query_test.go | 303 +++++++++++++------- internal/proxy/task_search.go | 43 +-- internal/proxy/task_search_test.go | 129 ++++++--- internal/querynodev2/delegator/delegator.go | 4 +- 8 files changed, 372 insertions(+), 191 deletions(-) diff --git a/go.mod b/go.mod index 6586f2754f..33079cb112 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.17.7 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497 + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240930043709-0c23514e4c34 github.com/minio/minio-go/v7 v7.0.61 github.com/pingcap/log v1.1.1-0.20221015072633-39906604fb81 github.com/prometheus/client_golang v1.14.0 diff --git a/go.sum b/go.sum index 2f49de7f13..30b1ff14fe 100644 --- a/go.sum +++ b/go.sum @@ -625,8 +625,8 @@ github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119 h1:9VXijWu github.com/milvus-io/cgosymbolizer v0.0.0-20240722103217-b7dee0e50119/go.mod h1:DvXTE/K/RtHehxU8/GtDs4vFtfw64jJ3PaCnFri8CRg= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b h1:TfeY0NxYxZzUfIfYe5qYDBzt4ZYRqzUjTR6CvUzjat8= github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b/go.mod h1:iwW+9cWfIzzDseEBCCeDSN5SD16Tidvy8cwQ7ZY8Qj4= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497 h1:t4sQMbSy05p8qgMGvEGyLYYLoZ9fD1dushS1bj5X6+0= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240923125106-ef9b8fd69497/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240930043709-0c23514e4c34 h1:Fwxpg98128gfWRbQ1A3PMP9o2IfYZk7RSEy8rcoCWDA= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240930043709-0c23514e4c34/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/pulsar-client-go v0.12.1 h1:O2JZp1tsYiO7C0MQ4hrUY/aJXnn2Gry6hpm7UodghmE= github.com/milvus-io/pulsar-client-go v0.12.1/go.mod h1:dkutuH4oS2pXiGm+Ti7fQZ4MRjrMPZ8IJeEGAWMeckk= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index d499872b19..116d8b656a 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -75,22 +75,29 @@ func (r *rankParams) String() string { return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal()) } +type SearchInfo struct { + planInfo *planpb.QueryInfo + offset int64 + parseError error + isIterator bool +} + // parseSearchInfo returns QueryInfo and offset -func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) { +func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) *SearchInfo { var topK int64 isAdvanced := rankParams != nil externalLimit := rankParams.GetLimit() + rankParams.GetOffset() topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair) if err != nil { if externalLimit <= 0 { - return nil, 0, fmt.Errorf("%s is required", TopKKey) + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s is required", TopKKey)} } topK = externalLimit } else { topKInParam, err := strconv.ParseInt(topKStr, 0, 64) if err != nil { if externalLimit <= 0 { - return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr) + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)} } topK = externalLimit } else { @@ -98,15 +105,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } } - isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair) + isIteratorStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair) + isIterator := (isIteratorStr == "True") || (isIteratorStr == "true") if err := validateLimit(topK); err != nil { - if isIterator == "True" { + if isIterator { // 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem // 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here topK = Params.QuotaConfig.TopKLimit.GetAsInt64() } else { - return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err) + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)} } } @@ -117,12 +125,12 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb if err == nil { offset, err = strconv.ParseInt(offsetStr, 0, 64) if err != nil { - return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr) + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)} } if offset != 0 { if err := validateLimit(offset); err != nil { - return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err) + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)} } } } @@ -130,7 +138,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb queryTopK := topK + offset if err := validateLimit(queryTopK); err != nil { - return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err) + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)} } // 2. parse metrics type @@ -147,11 +155,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64) if err != nil { - return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)} } if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) { - return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)} } // 4. parse search param str @@ -168,30 +176,35 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } else { groupByInfo := parseGroupByInfo(searchParamsPair, schema) if groupByInfo.err != nil { - return nil, 0, groupByInfo.err + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: groupByInfo.err} } groupByFieldId, groupSize, groupStrictSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetGroupStrictSize() } // 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search - if isIterator == "True" && groupByFieldId > 0 { - return nil, 0, merr.WrapErrParameterInvalid("", "", - "Not allowed to do groupBy when doing iteration") + if isIterator && groupByFieldId > 0 { + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "", + "Not allowed to do groupBy when doing iteration")} } if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 { - return nil, 0, merr.WrapErrParameterInvalid("", "", - "Not allowed to do range-search when doing search-group-by") + return &SearchInfo{planInfo: nil, offset: 0, isIterator: false, parseError: merr.WrapErrParameterInvalid("", "", + "Not allowed to do range-search when doing search-group-by")} } - return &planpb.QueryInfo{ - Topk: queryTopK, - MetricType: metricType, - SearchParams: searchParamStr, - RoundDecimal: roundDecimal, - GroupByFieldId: groupByFieldId, - GroupSize: groupSize, - GroupStrictSize: groupStrictSize, - }, offset, nil + return &SearchInfo{ + planInfo: &planpb.QueryInfo{ + Topk: queryTopK, + MetricType: metricType, + SearchParams: searchParamStr, + RoundDecimal: roundDecimal, + GroupByFieldId: groupByFieldId, + GroupSize: groupSize, + GroupStrictSize: groupStrictSize, + }, + offset: offset, + isIterator: isIterator, + parseError: nil, + } } func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) { diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index a79af7e7d7..8188274259 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -79,6 +79,7 @@ type queryParams struct { limit int64 offset int64 reduceType reduce.IReduceType + isIterator bool } // translateToOutputFieldIDs translates output fields name to output fields id. @@ -178,7 +179,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair) // if limit is not provided if err != nil { - return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType}, nil + return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType, isIterator: isIterator}, nil } limit, err = strconv.ParseInt(limitStr, 0, 64) if err != nil { @@ -203,6 +204,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e limit: limit, offset: offset, reduceType: reduceType, + isIterator: isIterator, }, nil } @@ -461,6 +463,11 @@ func (t *queryTask) PreExecute(ctx context.Context) error { } } t.GuaranteeTimestamp = guaranteeTs + // need modify mvccTs and guaranteeTs for iterator specially + if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() > 0 { + t.MvccTimestamp = t.request.GetGuaranteeTimestamp() + t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp() + } deadline, ok := t.TraceCtx().Deadline() if ok { @@ -542,6 +549,10 @@ func (t *queryTask) PostExecute(ctx context.Context) error { t.result.OutputFields = t.userOutputFields metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Observe(float64(tr.RecordSpan().Milliseconds())) + if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() == 0 { + // first page for iteration, need to set up sessionTs for iterator + t.result.SessionTs = t.BeginTs() + } log.Debug("Query PostExecute done") return nil } diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 9f2ec742ef..e2296dc9eb 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -132,126 +132,223 @@ func TestQueryTask_all(t *testing.T) { require.NoError(t, err) require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) - // test begins - task := &queryTask{ - Condition: NewTaskCondition(ctx), - RetrieveRequest: &internalpb.RetrieveRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Retrieve, - SourceID: paramtable.GetNodeID(), + t.Run("test query task parameters", func(t *testing.T) { + task := &queryTask{ + Condition: NewTaskCondition(ctx), + RetrieveRequest: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + SourceID: paramtable.GetNodeID(), + }, + CollectionID: collectionID, + OutputFieldsId: make([]int64, len(fieldName2Types)), }, - CollectionID: collectionID, - OutputFieldsId: make([]int64, len(fieldName2Types)), - }, - ctx: ctx, - result: &milvuspb.QueryResults{ - Status: merr.Success(), - FieldsData: []*schemapb.FieldData{}, - }, - request: &milvuspb.QueryRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Retrieve, - SourceID: paramtable.GetNodeID(), + ctx: ctx, + result: &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{}, }, - CollectionName: collectionName, - Expr: expr, - QueryParams: []*commonpb.KeyValuePair{ - { - Key: IgnoreGrowingKey, - Value: "false", + request: &milvuspb.QueryRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + SourceID: paramtable.GetNodeID(), + }, + CollectionName: collectionName, + Expr: expr, + QueryParams: []*commonpb.KeyValuePair{ + { + Key: IgnoreGrowingKey, + Value: "false", + }, }, }, - }, - qc: qc, - lb: lb, - } + qc: qc, + lb: lb, + } - assert.NoError(t, task.OnEnqueue()) + assert.NoError(t, task.OnEnqueue()) - // test query task with timeout - ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second) - defer cancel1() - // before preExecute - assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp) - task.ctx = ctx1 - assert.NoError(t, task.PreExecute(ctx)) + // test query task with timeout + ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second) + defer cancel1() + // before preExecute + assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp) + task.ctx = ctx1 + assert.NoError(t, task.PreExecute(ctx)) - { - task.mustUsePartitionKey = true - err := task.PreExecute(ctx) - assert.Error(t, err) - assert.ErrorIs(t, err, merr.ErrParameterInvalid) - task.mustUsePartitionKey = false - } + { + task.mustUsePartitionKey = true + err := task.PreExecute(ctx) + assert.Error(t, err) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + task.mustUsePartitionKey = false + } - // after preExecute - assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp) + // after preExecute + assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp) - // check reduce_stop_for_best - assert.Equal(t, false, task.RetrieveRequest.GetReduceStopForBest()) - task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{ - Key: ReduceStopForBestKey, - Value: "trxxxx", - }) - assert.Error(t, task.PreExecute(ctx)) + // check reduce_stop_for_best + assert.Equal(t, false, task.RetrieveRequest.GetReduceStopForBest()) + task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{ + Key: ReduceStopForBestKey, + Value: "trxxxx", + }) + assert.Error(t, task.PreExecute(ctx)) - result1 := &internalpb.RetrieveResults{ - Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult}, - Status: merr.Success(), - Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{Data: testutils.GenerateInt64Array(hitNum)}, + result1 := &internalpb.RetrieveResults{ + Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult}, + Status: merr.Success(), + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: testutils.GenerateInt64Array(hitNum)}, + }, }, - }, - } + } - outputFieldIDs := make([]UniqueID, 0, len(fieldName2Types)) - for i := 0; i < len(fieldName2Types); i++ { - outputFieldIDs = append(outputFieldIDs, int64(common.StartOfUserFieldID+i)) - } - task.RetrieveRequest.OutputFieldsId = outputFieldIDs - for fieldName, dataType := range fieldName2Types { - result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum)) - } - result1.FieldsData = append(result1.FieldsData, generateFieldData(schemapb.DataType_Int64, common.TimeStampFieldName, hitNum)) - task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField) - task.ctx = ctx - qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() - qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) - assert.Error(t, task.Execute(ctx)) + outputFieldIDs := make([]UniqueID, 0, len(fieldName2Types)) + for i := 0; i < len(fieldName2Types); i++ { + outputFieldIDs = append(outputFieldIDs, int64(common.StartOfUserFieldID+i)) + } + task.RetrieveRequest.OutputFieldsId = outputFieldIDs + for fieldName, dataType := range fieldName2Types { + result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum)) + } + result1.FieldsData = append(result1.FieldsData, generateFieldData(schemapb.DataType_Int64, common.TimeStampFieldName, hitNum)) + task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField) + task.ctx = ctx + qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")) + assert.Error(t, task.Execute(ctx)) - qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() - qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ - Status: merr.Status(merr.ErrChannelNotAvailable), - }, nil) - err = task.Execute(ctx) - assert.ErrorIs(t, err, merr.ErrChannelNotAvailable) + qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ + Status: merr.Status(merr.ErrChannelNotAvailable), + }, nil) + err = task.Execute(ctx) + assert.ErrorIs(t, err, merr.ErrChannelNotAvailable) - qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() - qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - }, nil) - assert.Error(t, task.Execute(ctx)) + qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, + }, nil) + assert.Error(t, task.Execute(ctx)) - qn.ExpectedCalls = nil - qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() - qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil) - assert.NoError(t, task.Execute(ctx)) + qn.ExpectedCalls = nil + qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() + qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil) + assert.NoError(t, task.Execute(ctx)) - task.queryParams = &queryParams{ - limit: 100, - offset: 100, - } - assert.NoError(t, task.PostExecute(ctx)) + task.queryParams = &queryParams{ + limit: 100, + offset: 100, + } + assert.NoError(t, task.PostExecute(ctx)) - for i := 0; i < len(task.result.FieldsData); i++ { - assert.NotEqual(t, task.result.FieldsData[i].FieldId, common.TimeStampField) - } + for i := 0; i < len(task.result.FieldsData); i++ { + assert.NotEqual(t, task.result.FieldsData[i].FieldId, common.TimeStampField) + } + }) + + t.Run("test query for iterator", func(t *testing.T) { + qt := &queryTask{ + Condition: NewTaskCondition(ctx), + RetrieveRequest: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + SourceID: paramtable.GetNodeID(), + }, + CollectionID: collectionID, + OutputFieldsId: make([]int64, len(fieldName2Types)), + }, + ctx: ctx, + result: &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{}, + }, + request: &milvuspb.QueryRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + SourceID: paramtable.GetNodeID(), + }, + CollectionName: collectionName, + Expr: expr, + QueryParams: []*commonpb.KeyValuePair{ + { + Key: IgnoreGrowingKey, + Value: "false", + }, + { + Key: IteratorField, + Value: "True", + }, + }, + }, + qc: qc, + lb: lb, + resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{}, + } + // simulate scheduler enqueue task + enqueTs := uint64(10000) + qt.SetTs(enqueTs) + qtErr := qt.PreExecute(context.TODO()) + assert.Nil(t, qtErr) + assert.True(t, qt.queryParams.isIterator) + qt.resultBuf.Insert(&internalpb.RetrieveResults{}) + qtErr = qt.PostExecute(context.TODO()) + assert.Nil(t, qtErr) + // after first page, sessionTs is set + assert.True(t, qt.result.GetSessionTs() > 0) + + // next page query task + qt = &queryTask{ + Condition: NewTaskCondition(ctx), + RetrieveRequest: &internalpb.RetrieveRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + SourceID: paramtable.GetNodeID(), + }, + CollectionID: collectionID, + OutputFieldsId: make([]int64, len(fieldName2Types)), + }, + ctx: ctx, + result: &milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{}, + }, + request: &milvuspb.QueryRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + SourceID: paramtable.GetNodeID(), + }, + CollectionName: collectionName, + Expr: expr, + QueryParams: []*commonpb.KeyValuePair{ + { + Key: IgnoreGrowingKey, + Value: "false", + }, + { + Key: IteratorField, + Value: "True", + }, + }, + GuaranteeTimestamp: enqueTs, + }, + qc: qc, + lb: lb, + resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{}, + } + qtErr = qt.PreExecute(context.TODO()) + assert.Nil(t, qtErr) + assert.True(t, qt.queryParams.isIterator) + // from the second page, the mvccTs is set to the sessionTs init in the first page + assert.Equal(t, enqueTs, qt.GetMvccTimestamp()) + }) } func Test_translateToOutputFieldIDs(t *testing.T) { diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index f9728aa9a3..4e14c7b938 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -80,6 +80,8 @@ type searchTask struct { reScorers []reScorer rankParams *rankParams groupScorer func(group *Group) error + + isIterator bool } func (t *searchTask) CanSkipAllocTimestamp() bool { @@ -249,6 +251,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error { } t.SearchRequest.GuaranteeTimestamp = guaranteeTs t.SearchRequest.ConsistencyLevel = consistencyLevel + if t.isIterator && t.request.GetGuaranteeTimestamp() > 0 { + t.MvccTimestamp = t.request.GetGuaranteeTimestamp() + t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp() + } if deadline, ok := t.TraceCtx().Deadline(); ok { t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0) @@ -351,7 +357,7 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error { t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs())) t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs())) for index, subReq := range t.request.GetSubReqs() { - plan, queryInfo, offset, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl()) + plan, queryInfo, offset, _, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl()) if err != nil { return err } @@ -444,11 +450,12 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName)) // fetch search_growing from search param - plan, queryInfo, offset, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl()) + plan, queryInfo, offset, isIterator, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl()) if err != nil { return err } + t.isIterator = isIterator t.SearchRequest.Offset = offset t.SearchRequest.FieldId = queryInfo.GetQueryFieldId() @@ -492,40 +499,40 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { return nil } -func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, error) { +func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) { annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params) if err != nil || len(annsFieldName) == 0 { vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema) if len(vecFields) == 0 { - return nil, nil, 0, errors.New(AnnsFieldKey + " not found in schema") + return nil, nil, 0, false, errors.New(AnnsFieldKey + " not found in schema") } if enableMultipleVectorFields && len(vecFields) > 1 { - return nil, nil, 0, errors.New("multiple anns_fields exist, please specify a anns_field in search_params") + return nil, nil, 0, false, errors.New("multiple anns_fields exist, please specify a anns_field in search_params") } annsFieldName = vecFields[0].Name } - queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams) - if parseErr != nil { - return nil, nil, 0, parseErr + searchInfo := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams) + if searchInfo.parseError != nil { + return nil, nil, 0, false, searchInfo.parseError } annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName) - if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector { - return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column") + if searchInfo.planInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector { + return nil, nil, 0, false, errors.New("not support search_group_by operation based on binary vector column") } - queryInfo.QueryFieldId = annField.GetFieldID() - plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, queryInfo) + searchInfo.planInfo.QueryFieldId = annField.GetFieldID() + plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, searchInfo.planInfo) if planErr != nil { log.Warn("failed to create query plan", zap.Error(planErr), zap.String("dsl", dsl), // may be very large if large term passed. - zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo)) - return nil, nil, 0, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr) + zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo)) + return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr) } log.Debug("create query plan", zap.String("dsl", t.request.Dsl), // may be very large if large term passed. - zap.String("anns field", annsFieldName), zap.Any("query info", queryInfo)) - return plan, queryInfo, offset, nil + zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo)) + return plan, searchInfo.planInfo, searchInfo.offset, searchInfo.isIterator, nil } func (t *searchTask) tryParsePartitionIDsFromPlan(plan *planpb.PlanNode) ([]int64, error) { @@ -718,6 +725,10 @@ func (t *searchTask) PostExecute(ctx context.Context) error { } t.result.Results.OutputFields = t.userOutputFields t.result.CollectionName = t.request.GetCollectionName() + if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 { + // first page for iteration, need to set up sessionTs for iterator + t.result.SessionTs = t.BeginTs() + } metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds())) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 70799c9fde..b0b0e77d95 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -301,6 +301,55 @@ func TestSearchTask_PreExecute(t *testing.T) { task.request.OutputFields = []string{testFloatVecField} assert.NoError(t, task.PreExecute(ctx)) }) + + t.Run("search consistent iterator pre_ts", func(t *testing.T) { + collName := "search_with_timeout" + funcutil.GenRandomStr() + createColl(t, collName, rc) + + st := getSearchTask(t, collName) + st.request.SearchParams = getValidSearchParams() + st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + st.request.GuaranteeTimestamp = 1000 + st.request.DslType = commonpb.DslType_BoolExprV1 + + ctxTimeout, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp) + + st.ctx = ctxTimeout + assert.NoError(t, st.PreExecute(ctx)) + assert.True(t, st.isIterator) + assert.True(t, st.GetMvccTimestamp() > 0) + assert.Equal(t, uint64(1000), st.GetGuaranteeTimestamp()) + }) + + t.Run("search consistent iterator post_ts", func(t *testing.T) { + collName := "search_with_timeout" + funcutil.GenRandomStr() + createColl(t, collName, rc) + + st := getSearchTask(t, collName) + st.request.SearchParams = getValidSearchParams() + st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + st.request.DslType = commonpb.DslType_BoolExprV1 + + _, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp) + enqueueTs := uint64(100000) + st.SetTs(enqueueTs) + assert.NoError(t, st.PreExecute(ctx)) + assert.True(t, st.isIterator) + assert.True(t, st.GetMvccTimestamp() == 0) + st.resultBuf.Insert(&internalpb.SearchResults{}) + st.PostExecute(context.TODO()) + assert.Equal(t, st.result.GetSessionTs(), enqueueTs) + }) } func getQueryCoord() *mocks.MockQueryCoord { @@ -2235,11 +2284,11 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - info, offset, err := parseSearchInfo(test.validParams, nil, nil) - assert.NoError(t, err) - assert.NotNil(t, info) + searchInfo := parseSearchInfo(test.validParams, nil, nil) + assert.NoError(t, searchInfo.parseError) + assert.NotNil(t, searchInfo.planInfo) if test.description == "offsetParam" { - assert.Equal(t, targetOffset, offset) + assert.Equal(t, targetOffset, searchInfo.offset) } }) } @@ -2256,11 +2305,11 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { limit: externalLimit, } - info, offset, err := parseSearchInfo(offsetParam, nil, rank) - assert.NoError(t, err) - assert.NotNil(t, info) - assert.Equal(t, int64(10), info.GetTopk()) - assert.Equal(t, int64(0), offset) + searchInfo := parseSearchInfo(offsetParam, nil, rank) + assert.NoError(t, searchInfo.parseError) + assert.NotNil(t, searchInfo.planInfo) + assert.Equal(t, int64(10), searchInfo.planInfo.GetTopk()) + assert.Equal(t, int64(0), searchInfo.offset) }) t.Run("parseSearchInfo groupBy info for hybrid search", func(t *testing.T) { @@ -2309,15 +2358,15 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { Value: "true", }) - info, _, err := parseSearchInfo(params, schema, testRankParams) - assert.NoError(t, err) - assert.NotNil(t, info) + searchInfo := parseSearchInfo(params, schema, testRankParams) + assert.NoError(t, searchInfo.parseError) + assert.NotNil(t, searchInfo.planInfo) // all group_by related parameters should be aligned to parameters // set by main request rather than inner sub request - assert.Equal(t, int64(101), info.GetGroupByFieldId()) - assert.Equal(t, int64(3), info.GetGroupSize()) - assert.False(t, info.GetGroupStrictSize()) + assert.Equal(t, int64(101), searchInfo.planInfo.GetGroupByFieldId()) + assert.Equal(t, int64(3), searchInfo.planInfo.GetGroupSize()) + assert.False(t, searchInfo.planInfo.GetGroupStrictSize()) }) t.Run("parseSearchInfo error", func(t *testing.T) { @@ -2399,12 +2448,12 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - info, offset, err := parseSearchInfo(test.invalidParams, nil, nil) - assert.Error(t, err) - assert.Nil(t, info) - assert.Zero(t, offset) + searchInfo := parseSearchInfo(test.invalidParams, nil, nil) + assert.Error(t, searchInfo.parseError) + assert.Nil(t, searchInfo.planInfo) + assert.Zero(t, searchInfo.offset) - t.Logf("err=%s", err.Error()) + t.Logf("err=%s", searchInfo.parseError) }) } }) @@ -2426,9 +2475,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, nil) - assert.Nil(t, info) - assert.ErrorIs(t, err, merr.ErrParameterInvalid) + searchInfo := parseSearchInfo(normalParam, schema, nil) + assert.Nil(t, searchInfo.planInfo) + assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid) }) t.Run("check range-search and groupBy", func(t *testing.T) { normalParam := getValidSearchParams() @@ -2445,9 +2494,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, nil) - assert.Nil(t, info) - assert.ErrorIs(t, err, merr.ErrParameterInvalid) + searchInfo := parseSearchInfo(normalParam, schema, nil) + assert.Nil(t, searchInfo.planInfo) + assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid) }) t.Run("check nullable and groupBy", func(t *testing.T) { normalParam := getValidSearchParams() @@ -2464,9 +2513,9 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, nil) - assert.Nil(t, info) - assert.ErrorIs(t, err, merr.ErrParameterInvalid) + searchInfo := parseSearchInfo(normalParam, schema, nil) + assert.Nil(t, searchInfo.planInfo) + assert.ErrorIs(t, searchInfo.parseError, merr.ErrParameterInvalid) }) t.Run("check iterator and topK", func(t *testing.T) { normalParam := getValidSearchParams() @@ -2483,10 +2532,10 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, nil) - assert.NotNil(t, info) - assert.NoError(t, err) - assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk) + searchInfo := parseSearchInfo(normalParam, schema, nil) + assert.NotNil(t, searchInfo.planInfo) + assert.NoError(t, searchInfo.parseError) + assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), searchInfo.planInfo.GetTopk()) }) t.Run("check max group size", func(t *testing.T) { @@ -2503,15 +2552,15 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, nil) - assert.Nil(t, info) - assert.Error(t, err) - assert.True(t, strings.Contains(err.Error(), "exceeds configured max group size")) + searchInfo := parseSearchInfo(normalParam, schema, nil) + assert.Nil(t, searchInfo.planInfo) + assert.Error(t, searchInfo.parseError) + assert.True(t, strings.Contains(searchInfo.parseError.Error(), "exceeds configured max group size")) resetSearchParamsValue(normalParam, GroupSizeKey, `10`) - info, _, err = parseSearchInfo(normalParam, schema, nil) - assert.NotNil(t, info) - assert.NoError(t, err) + searchInfo = parseSearchInfo(normalParam, schema, nil) + assert.NotNil(t, searchInfo.planInfo) + assert.NoError(t, searchInfo.parseError) }) } diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index a4c90dd14a..8fd8447914 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -439,7 +439,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq // wait tsafe waitTr := timerecord.NewTimeRecorder("wait tSafe") - tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) + tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp()) if err != nil { log.Warn("delegator query failed to wait tsafe", zap.Error(err)) return err @@ -512,7 +512,7 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) // wait tsafe waitTr := timerecord.NewTimeRecorder("wait tSafe") - tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) + tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp()) if err != nil { log.Warn("delegator query failed to wait tsafe", zap.Error(err)) return nil, err