From be62f82da443d173e8ea9d656dc02fe01a760ab7 Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Wed, 31 May 2023 10:17:29 +0800 Subject: [PATCH] Add consistency_level in search/query request (#24541) Signed-off-by: zhenshan.cao --- go.mod | 2 +- go.sum | 2 + internal/proxy/meta_cache.go | 3 + internal/proxy/task_search.go | 40 ++++++++++++-- internal/proxy/util.go | 13 +++++ internal/proxy/util_test.go | 32 +++++++++++ internal/util/typeutil/schema.go | 10 ++++ internal/util/typeutil/schema_test.go | 58 ++++++++++++++++++++ tests/python_client/requirements.txt | 2 +- tests/python_client/testcases/test_search.py | 36 ++++++++++++ 10 files changed, 190 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 2f59961b8e..7c684c424e 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( github.com/klauspost/compress v1.14.2 github.com/lingdor/stackerror v0.0.0-20191119040541-976d8885ed76 github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d - github.com/milvus-io/milvus-proto/go-api v0.0.0-20230526035721-38841a224dac + github.com/milvus-io/milvus-proto/go-api v0.0.0-20230529035323-00da05d318d1 github.com/minio/minio-go/v7 v7.0.17 github.com/opentracing/opentracing-go v1.2.0 github.com/panjf2000/ants/v2 v2.4.8 diff --git a/go.sum b/go.sum index 194c21dcc7..ebfea23519 100644 --- a/go.sum +++ b/go.sum @@ -512,6 +512,8 @@ github.com/milvus-io/milvus-proto/go-api v0.0.0-20230522080721-2975bfe7a190 h1:Z github.com/milvus-io/milvus-proto/go-api v0.0.0-20230522080721-2975bfe7a190/go.mod h1:148qnlmZ0Fdm1Fq+Mj/OW2uDoEP25g3mjh0vMGtkgmk= github.com/milvus-io/milvus-proto/go-api v0.0.0-20230526035721-38841a224dac h1:MC4X/pkkGvKEXhIiO52ZA0SX0c2MMNhqVoxOLIe8q/M= github.com/milvus-io/milvus-proto/go-api v0.0.0-20230526035721-38841a224dac/go.mod h1:148qnlmZ0Fdm1Fq+Mj/OW2uDoEP25g3mjh0vMGtkgmk= +github.com/milvus-io/milvus-proto/go-api v0.0.0-20230529035323-00da05d318d1 h1:ushb9LriQIuX6ephDhS3SdMqEFq7OlgMBE5ruTwUEhE= +github.com/milvus-io/milvus-proto/go-api v0.0.0-20230529035323-00da05d318d1/go.mod h1:148qnlmZ0Fdm1Fq+Mj/OW2uDoEP25g3mjh0vMGtkgmk= github.com/milvus-io/pulsar-client-go v0.6.8 h1:fZdZH73aPRszu2fazyeeahQEz34tyn1Pt9EkqJmV100= github.com/milvus-io/pulsar-client-go v0.6.8/go.mod h1:oFIlYIk23tamkSLttw849qphmMIpHY8ztEBWDWJW+sc= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 278293a8a4..06885bb1ef 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -94,6 +94,7 @@ type collectionInfo struct { createdUtcTimestamp uint64 isLoaded bool database string + consistencyLevel commonpb.ConsistencyLevel } func (info *collectionInfo) isCollectionCached() bool { @@ -401,6 +402,7 @@ func (m *MetaCache) updateCollection(coll *milvuspb.DescribeCollectionResponse, m.collInfo[database][collectionName].collID = coll.CollectionID m.collInfo[database][collectionName].createdTimestamp = coll.CreatedTimestamp m.collInfo[database][collectionName].createdUtcTimestamp = coll.CreatedUtcTimestamp + m.collInfo[database][collectionName].consistencyLevel = coll.ConsistencyLevel } func (m *MetaCache) GetPartitionID(ctx context.Context, database, collectionName string, partitionName string) (typeutil.UniqueID, error) { @@ -560,6 +562,7 @@ func (m *MetaCache) describeCollection(ctx context.Context, database, collection CreatedTimestamp: coll.CreatedTimestamp, CreatedUtcTimestamp: coll.CreatedUtcTimestamp, DbName: coll.GetDbName(), + ConsistencyLevel: 
coll.ConsistencyLevel, } for _, field := range coll.Schema.Fields { if field.FieldID >= common.StartOfUserFieldID { diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 8ad6f85842..9e3aa7f69d 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -274,10 +274,16 @@ func (t *searchTask) PreExecute(ctx context.Context) error { partitionNames := t.request.GetPartitionNames() if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams()) - if err != nil { - return errors.New(AnnsFieldKey + " not found in search_params") + if err != nil || len(annsField) == 0 { + if enableMultipleVectorFields { + return errors.New(AnnsFieldKey + " not found in search_params") + } + vecFieldSchema, err2 := typeutil.GetVectorFieldSchema(t.schema) + if err2 != nil { + return errors.New(AnnsFieldKey + " not found in schema") + } + annsField = vecFieldSchema.Name } - queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams()) if err != nil { return err @@ -343,12 +349,32 @@ func (t *searchTask) PreExecute(ctx context.Context) error { if err != nil { return err } - t.SearchRequest.TravelTimestamp = travelTimestamp + collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName) + if err2 != nil { + log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute failed to GetCollectionInfo from cache", + zap.Any("collectionName", collectionName), zap.Error(err2)) + return err2 + } guaranteeTs := t.request.GetGuaranteeTimestamp() - guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs()) - t.SearchRequest.GuaranteeTimestamp = guaranteeTs + var consistencyLevel commonpb.ConsistencyLevel + useDefaultConsistency := t.request.GetUseDefaultConsistency() + if useDefaultConsistency { + consistencyLevel = collectionInfo.consistencyLevel + guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel) + } else { + consistencyLevel = t.request.GetConsistencyLevel() + //Compatibility logic, parse guarantee timestamp + if consistencyLevel == 0 && guaranteeTs > 0 { + guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs()) + } else { + // parse from guarantee timestamp and user input consistency level + guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel) + } + } + t.SearchRequest.GuaranteeTimestamp = guaranteeTs + t.SearchRequest.TravelTimestamp = travelTimestamp if deadline, ok := t.TraceCtx().Deadline(); ok { t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0) } @@ -369,6 +395,8 @@ func (t *searchTask) PreExecute(ctx context.Context) error { log.Ctx(ctx).Debug("search PreExecute done.", zap.Int64("msgID", t.ID()), zap.Uint64("travel_ts", travelTimestamp), zap.Uint64("guarantee_ts", guaranteeTs), + zap.Bool("use_default_consistency", useDefaultConsistency), + zap.Any("consistency level", consistencyLevel), zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp())) return nil diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 1e047b5af1..ac4d67641b 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -699,6 +699,19 @@ func ReplaceID2Name(oldStr string, id int64, name string) string { return strings.ReplaceAll(oldStr, strconv.FormatInt(id, 10), name) } +func parseGuaranteeTsFromConsistency(ts, tMax typeutil.Timestamp, consistency commonpb.ConsistencyLevel) typeutil.Timestamp { + switch consistency { + case 
commonpb.ConsistencyLevel_Strong: + ts = tMax + case commonpb.ConsistencyLevel_Bounded: + ratio := time.Duration(-Params.CommonCfg.GracefulTime) + ts = tsoutil.AddPhysicalDurationOnTs(tMax, ratio*time.Millisecond) + case commonpb.ConsistencyLevel_Eventually: + ts = 1 + } + return ts +} + func parseGuaranteeTs(ts, tMax typeutil.Timestamp) typeutil.Timestamp { switch ts { case strongTS: diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 7361624f8d..1d46dbf0b9 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -1010,3 +1010,35 @@ func Test_isPartitionIsLoaded(t *testing.T) { assert.False(t, loaded) }) } + +func Test_ParseGuaranteeTs(t *testing.T) { + strongTs := typeutil.Timestamp(0) + boundedTs := typeutil.Timestamp(2) + tsNow := tsoutil.GetCurrentTime() + tsMax := tsoutil.GetCurrentTime() + + assert.Equal(t, tsMax, parseGuaranteeTs(strongTs, tsMax)) + ratio := time.Duration(-Params.CommonCfg.GracefulTime) + assert.Equal(t, tsoutil.AddPhysicalDurationOnTs(tsMax, ratio*time.Millisecond), parseGuaranteeTs(boundedTs, tsMax)) + assert.Equal(t, tsNow, parseGuaranteeTs(tsNow, tsMax)) +} + +func Test_ParseGuaranteeTsFromConsistency(t *testing.T) { + strong := commonpb.ConsistencyLevel_Strong + bounded := commonpb.ConsistencyLevel_Bounded + eventually := commonpb.ConsistencyLevel_Eventually + session := commonpb.ConsistencyLevel_Session + customized := commonpb.ConsistencyLevel_Customized + + tsDefault := typeutil.Timestamp(0) + tsEventually := typeutil.Timestamp(1) + tsNow := tsoutil.GetCurrentTime() + tsMax := tsoutil.GetCurrentTime() + + assert.Equal(t, tsMax, parseGuaranteeTsFromConsistency(tsDefault, tsMax, strong)) + ratio := time.Duration(-Params.CommonCfg.GracefulTime) + assert.Equal(t, tsoutil.AddPhysicalDurationOnTs(tsMax, ratio*time.Millisecond), parseGuaranteeTsFromConsistency(tsDefault, tsMax, bounded)) + assert.Equal(t, tsNow, parseGuaranteeTsFromConsistency(tsNow, tsMax, session)) + assert.Equal(t, tsNow, parseGuaranteeTsFromConsistency(tsNow, tsMax, customized)) + assert.Equal(t, tsEventually, parseGuaranteeTsFromConsistency(tsDefault, tsMax, eventually)) +} diff --git a/internal/util/typeutil/schema.go b/internal/util/typeutil/schema.go index 9069a9e78c..ff68640e59 100644 --- a/internal/util/typeutil/schema.go +++ b/internal/util/typeutil/schema.go @@ -229,6 +229,16 @@ func CreateSchemaHelper(schema *schemapb.CollectionSchema) (*SchemaHelper, error return &schemaHelper, nil } +// GetVectorFieldSchema get vector field schema from collection schema. 
+func GetVectorFieldSchema(schema *schemapb.CollectionSchema) (*schemapb.FieldSchema, error) { + for _, fieldSchema := range schema.Fields { + if IsVectorType(fieldSchema.DataType) { + return fieldSchema, nil + } + } + return nil, errors.New("vector field is not found") +} + // GetPrimaryKeyField returns the schema of the primary key func (helper *SchemaHelper) GetPrimaryKeyField() (*schemapb.FieldSchema, error) { if helper.primaryKeyOffset == -1 { diff --git a/internal/util/typeutil/schema_test.go b/internal/util/typeutil/schema_test.go index 5d58f55c62..6be4c4e1db 100644 --- a/internal/util/typeutil/schema_test.go +++ b/internal/util/typeutil/schema_test.go @@ -210,6 +210,64 @@ func TestSchema(t *testing.T) { }) } +func TestSchema_GetVectorFieldSchema(t *testing.T) { + + schemaNormal := &schemapb.CollectionSchema{ + Name: "testColl", + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "field_int64", + IsPrimaryKey: true, + Description: "", + DataType: 5, + }, + { + FieldID: 107, + Name: "field_float_vector", + IsPrimaryKey: false, + Description: "", + DataType: 101, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "128", + }, + }, + }, + }, + } + + t.Run("GetVectorFieldSchema", func(t *testing.T) { + fieldSchema, err := GetVectorFieldSchema(schemaNormal) + assert.Equal(t, "field_float_vector", fieldSchema.Name) + assert.Nil(t, err) + }) + + schemaInvalid := &schemapb.CollectionSchema{ + Name: "testColl", + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "field_int64", + IsPrimaryKey: true, + Description: "", + DataType: 5, + }, + }, + } + + t.Run("GetVectorFieldSchemaInvalid", func(t *testing.T) { + _, err := GetVectorFieldSchema(schemaInvalid) + assert.Error(t, err) + }) + +} + func TestSchema_invalid(t *testing.T) { t.Run("Duplicate field name", func(t *testing.T) { schema := &schemapb.CollectionSchema{ diff --git a/tests/python_client/requirements.txt b/tests/python_client/requirements.txt index fb7d6a43b5..0d515f1d76 100644 --- a/tests/python_client/requirements.txt +++ b/tests/python_client/requirements.txt @@ -12,7 +12,7 @@ allure-pytest==2.7.0 pytest-print==0.2.1 pytest-level==0.1.1 pytest-xdist==2.5.0 -pymilvus==2.2.9.dev16 +pymilvus==2.2.9.dev28 pytest-rerunfailures==9.1.1 git+https://github.com/Projectplace/pytest-tags ndg-httpsclient diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py index c95c3df891..dd421d6e8b 100644 --- a/tests/python_client/testcases/test_search.py +++ b/tests/python_client/testcases/test_search.py @@ -991,6 +991,42 @@ class TestCollectionSearch(TestcaseBase): "ids": insert_ids, "limit": default_limit}) + + @pytest.mark.tags(CaseLabel.L0) + def test_search_normal_without_specify_anns_field(self): + """ + target: test search normal case + method: create connection, collection, insert and search + expected: 1. search returned with 0 before travel timestamp + 2. search successfully with limit(topK) after travel timestamp + """ + nq = 2 + dim = 32 + auto_id = True + # 1. initialize with data + collection_w, _, _, insert_ids, time_stamp = \ + self.init_collection_general(prefix, True, auto_id=auto_id, dim=dim, is_flush=True)[0:5] + # 2. 
search before insert time_stamp + log.info("test_search_normal: searching collection %s" % collection_w.name) + vectors = [[random.random() for _ in range(dim)] for _ in range(nq)] + collection_w.search(vectors[:nq], "", + default_search_params, default_limit, + default_search_exp, + travel_timestamp=time_stamp - 1, + check_task=CheckTasks.err_res, + check_items={"err_code": 1, + "err_msg": f"only support to travel back to 0s so far"}) + # 3. search after insert time_stamp + collection_w.search(vectors[:nq], "", + default_search_params, default_limit, + default_search_exp, + travel_timestamp=0, + guarantee_timestamp=0, + check_task=CheckTasks.check_search_results, + check_items={"nq": nq, + "ids": insert_ids, + "limit": default_limit}) + @pytest.mark.tags(CaseLabel.L0) def test_search_with_hit_vectors(self, nq, dim, auto_id): """
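
The diff above derives the search/query guarantee timestamp either from the collection's default consistency level (cached in collectionInfo and used when use_default_consistency is set) or from the consistency_level carried on the request itself. Below is a minimal, self-contained Go sketch of that mapping; guaranteeTsFromConsistency, composeTS, addPhysicalDuration and the hard-coded gracefulTime are simplified stand-ins for parseGuaranteeTsFromConsistency and the tsoutil/Params helpers, not the proxy's actual code.

package main

import (
	"fmt"
	"time"
)

// Timestamp mimics typeutil.Timestamp: a hybrid TSO with the physical time
// in milliseconds in the high bits and a logical counter in the low bits.
type Timestamp = uint64

type ConsistencyLevel int

const (
	Strong ConsistencyLevel = iota
	Session
	Bounded
	Eventually
	Customized
)

const (
	logicalBits  = 18
	gracefulTime = 5000 * time.Millisecond // illustrative; the proxy reads Params.CommonCfg.GracefulTime
	eventuallyTs = Timestamp(1)            // guarantee_ts == 1 means "do not wait" on the query side
)

// composeTS and addPhysicalDuration are simplified stand-ins for the tsoutil helpers.
func composeTS(physical time.Time) Timestamp {
	return Timestamp(physical.UnixMilli()) << logicalBits
}

func addPhysicalDuration(ts Timestamp, d time.Duration) Timestamp {
	physical := int64(ts>>logicalBits) + d.Milliseconds()
	logical := ts & ((1 << logicalBits) - 1)
	return Timestamp(physical)<<logicalBits | logical
}

// guaranteeTsFromConsistency mirrors parseGuaranteeTsFromConsistency:
// Strong waits for everything up to the request timestamp, Bounded tolerates
// gracefulTime of staleness, Eventually never waits, and Session/Customized
// keep whatever guarantee timestamp the client supplied.
func guaranteeTsFromConsistency(clientTs, requestTs Timestamp, level ConsistencyLevel) Timestamp {
	switch level {
	case Strong:
		return requestTs
	case Bounded:
		return addPhysicalDuration(requestTs, -gracefulTime)
	case Eventually:
		return eventuallyTs
	default: // Session, Customized
		return clientTs
	}
}

func main() {
	now := composeTS(time.Now())
	clientTs := Timestamp(0)

	for _, level := range []ConsistencyLevel{Strong, Bounded, Eventually, Session} {
		fmt.Printf("level=%d guarantee_ts=%d\n", level, guaranteeTsFromConsistency(clientTs, now, level))
	}
}

Outside this sketch, the compatibility branch in PreExecute still applies: when use_default_consistency is false and the request carries consistency level 0 together with a non-zero guarantee timestamp, the proxy keeps the legacy parseGuaranteeTs path so older clients retain their previous behavior.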