From 4c91e05a5d75fd940dac088a5ab58e2fcb88c6ea Mon Sep 17 00:00:00 2001 From: Chun Han <116052805+MrPresent-Han@users.noreply.github.com> Date: Wed, 15 Jan 2025 10:36:59 +0800 Subject: [PATCH] enhance: fix inconsistenty of alias and db for query iterator(#39045) (#39248) related: #39045 pr: https://github.com/milvus-io/milvus/pull/39216 Signed-off-by: MrPresent-Han Co-authored-by: MrPresent-Han --- internal/proxy/search_util.go | 19 +++++++++++------- internal/proxy/task.go | 1 + internal/proxy/task_query.go | 32 ++++++++++++++++++++++-------- internal/proxy/task_query_test.go | 23 +++++++++++++++++++++ internal/proxy/task_search.go | 5 +++++ internal/proxy/task_search_test.go | 24 ++++++++++++++++++++++ 6 files changed, 89 insertions(+), 15 deletions(-) diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index c3a8254827..bb4b254d24 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -77,10 +77,11 @@ func (r *rankParams) String() string { } type SearchInfo struct { - planInfo *planpb.QueryInfo - offset int64 - parseError error - isIterator bool + planInfo *planpb.QueryInfo + offset int64 + parseError error + isIterator bool + collectionID int64 } func parseSearchIteratorV2Info(searchParamsPair []*commonpb.KeyValuePair, groupByFieldId int64, isIterator bool, offset int64, queryTopK *int64) (*planpb.SearchIteratorV2Info, error) { @@ -184,6 +185,9 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb isIteratorStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair) isIterator := (isIteratorStr == "True") || (isIteratorStr == "true") + collectionIDStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, searchParamsPair) + collectionId, _ := strconv.ParseInt(collectionIDStr, 0, 64) + if err := validateLimit(topK); err != nil { if isIterator { // 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem @@ -289,9 +293,10 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb Hints: hints, SearchIteratorV2Info: planSearchIteratorV2Info, }, - offset: offset, - isIterator: isIterator, - parseError: nil, + offset: offset, + isIterator: isIterator, + parseError: nil, + collectionID: collectionId, } } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 99ca815496..b10835d075 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -55,6 +55,7 @@ const ( IgnoreGrowingKey = "ignore_growing" ReduceStopForBestKey = "reduce_stop_for_best" IteratorField = "iterator" + CollectionID = "collection_id" GroupByFieldKey = "group_by_field" GroupSizeKey = "group_size" StrictGroupSize = "strict_group_size" diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 85febb525d..1729e56408 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -76,10 +76,11 @@ type queryTask struct { } type queryParams struct { - limit int64 - offset int64 - reduceType reduce.IReduceType - isIterator bool + limit int64 + offset int64 + reduceType reduce.IReduceType + isIterator bool + collectionID int64 } // translateToOutputFieldIDs translates output fields name to output fields id. @@ -146,6 +147,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e reduceStopForBest bool isIterator bool err error + collectionID int64 ) reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair) // if reduce_stop_for_best is provided @@ -167,6 +169,15 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e } } + collectionIdStr, err := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, queryParamsPair) + if err == nil { + collectionID, err = strconv.ParseInt(collectionIdStr, 0, 64) + if err != nil { + return nil, merr.WrapErrParameterInvalid("int value for collection_id", CollectionID, + "value for collection id is invalid") + } + } + reduceType := reduce.IReduceNoOrder if isIterator { if reduceStopForBest { @@ -201,10 +212,11 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e } return &queryParams{ - limit: limit, - offset: offset, - reduceType: reduceType, - isIterator: isIterator, + limit: limit, + offset: offset, + reduceType: reduceType, + isIterator: isIterator, + collectionID: collectionID, }, nil } @@ -364,6 +376,10 @@ func (t *queryTask) PreExecute(ctx context.Context) error { if err != nil { return err } + if queryParams.collectionID > 0 && queryParams.collectionID != t.GetCollectionID() { + return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("Input collection id is not consistent to collectionID in the context," + + "alias or database may have changed")) + } if queryParams.reduceType == reduce.IReduceInOrderForBest { t.RetrieveRequest.ReduceStopForBest = true } diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 244d091075..43ffb2a096 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -193,6 +193,29 @@ func TestQueryTask_all(t *testing.T) { Value: "trxxxx", }) assert.Error(t, task.PreExecute(ctx)) + task.request.QueryParams = task.request.QueryParams[0 : len(task.request.QueryParams)-1] + + // check parse collection id + task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{ + Key: CollectionID, + Value: "trxxxx", + }) + err := task.PreExecute(ctx) + assert.Error(t, err) + task.request.QueryParams = task.request.QueryParams[0 : len(task.request.QueryParams)-1] + + // check collection id consistency + task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{ + Key: LimitKey, + Value: "11", + }) + task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{ + Key: CollectionID, + Value: "8080", + }) + err = task.PreExecute(ctx) + assert.Error(t, err) + task.request.QueryParams = make([]*commonpb.KeyValuePair, 0) result1 := &internalpb.RetrieveResults{ Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult}, diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index ab1b7a5592..98e073baa7 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -521,6 +521,11 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string if searchInfo.parseError != nil { return nil, nil, 0, false, searchInfo.parseError } + if searchInfo.collectionID > 0 && searchInfo.collectionID != t.GetCollectionID() { + return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("collection id:%d in the request is not consistent to that in the search context,"+ + "alias or database may have been changed: %d", searchInfo.collectionID, t.GetCollectionID()) + } + annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName) if searchInfo.planInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector { return nil, nil, 0, false, errors.New("not support search_group_by operation based on binary vector column") diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 04f5bec997..d369472435 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -473,6 +473,30 @@ func TestSearchTask_PreExecute(t *testing.T) { st.PostExecute(context.TODO()) assert.Equal(t, st.result.GetSessionTs(), enqueueTs) }) + + t.Run("search inconsistent collection_id", func(t *testing.T) { + collName := "search_inconsistent_collection" + funcutil.GenRandomStr() + createColl(t, collName, rc) + + st := getSearchTask(t, collName) + st.request.SearchParams = getValidSearchParams() + st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{ + Key: CollectionID, + Value: "8080", + }) + st.request.DslType = commonpb.DslType_BoolExprV1 + + _, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp) + enqueueTs := uint64(100000) + st.SetTs(enqueueTs) + assert.Error(t, st.PreExecute(ctx)) + }) } func getQueryCoord() *mocks.MockQueryCoord {