From fe20366b5c1fee108ca0834ee2f297a2ff2415f0 Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 15 Aug 2024 19:18:53 +0800 Subject: [PATCH] enhance: Remove duplicated schema helper creation in proxy (#35489) Related to PRs of #35415 Signed-off-by: Congqi Xia --- internal/proxy/task_index.go | 20 +------ internal/proxy/task_insert.go | 2 +- internal/proxy/task_test.go | 73 +++++++++++------------ internal/proxy/task_upsert.go | 2 +- internal/proxy/util.go | 6 +- internal/proxy/validate_util.go | 10 ++-- internal/proxy/validate_util_test.go | 88 ++++++++++++++++++++-------- 7 files changed, 110 insertions(+), 91 deletions(-) diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index dfa70e3374..607fb1b89c 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -350,12 +350,7 @@ func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.Fiel log.Error("failed to get collection schema", zap.Error(err)) return nil, fmt.Errorf("failed to get collection schema: %s", err) } - schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema) - if err != nil { - log.Error("failed to parse collection schema", zap.Error(err)) - return nil, fmt.Errorf("failed to parse collection schema: %s", err) - } - field, err := schemaHelper.GetFieldFromName(cit.req.GetFieldName()) + field, err := schema.schemaHelper.GetFieldFromName(cit.req.GetFieldName()) if err != nil { log.Error("create index on non-exist field", zap.Error(err)) return nil, fmt.Errorf("cannot create index on non-exist field: %s", cit.req.GetFieldName()) @@ -678,11 +673,6 @@ func (dit *describeIndexTask) Execute(ctx context.Context) error { log.Error("failed to get collection schema", zap.Error(err)) return fmt.Errorf("failed to get collection schema: %s", err) } - schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema) - if err != nil { - log.Error("failed to parse collection schema", zap.Error(err)) - return fmt.Errorf("failed to parse collection schema: %s", err) - } resp, err := dit.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{CollectionID: dit.collectionID, IndexName: dit.IndexName, Timestamp: dit.Timestamp}) if err != nil { @@ -700,7 +690,7 @@ func (dit *describeIndexTask) Execute(ctx context.Context) error { return err } for _, indexInfo := range resp.IndexInfos { - field, err := schemaHelper.GetFieldFromID(indexInfo.FieldID) + field, err := schema.schemaHelper.GetFieldFromID(indexInfo.FieldID) if err != nil { log.Error("failed to get collection field", zap.Error(err)) return fmt.Errorf("failed to get collection field: %d", indexInfo.FieldID) @@ -802,11 +792,7 @@ func (dit *getIndexStatisticsTask) Execute(ctx context.Context) error { log.Error("failed to get collection schema", zap.String("collection_name", dit.GetCollectionName()), zap.Error(err)) return fmt.Errorf("failed to get collection schema: %s", dit.GetCollectionName()) } - schemaHelper, err := typeutil.CreateSchemaHelper(schema.CollectionSchema) - if err != nil { - log.Error("failed to parse collection schema", zap.String("collection_name", schema.GetName()), zap.Error(err)) - return fmt.Errorf("failed to parse collection schema: %s", dit.GetCollectionName()) - } + schemaHelper := schema.schemaHelper resp, err := dit.datacoord.GetIndexStatistics(ctx, &indexpb.GetIndexStatisticsRequest{ CollectionID: dit.collectionID, IndexName: dit.IndexName, diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 26c961676e..be46569b66 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -213,7 +213,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { } if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck(), withMaxCapCheck()). - Validate(it.insertMsg.GetFieldsData(), schema.CollectionSchema, it.insertMsg.NRows()); err != nil { + Validate(it.insertMsg.GetFieldsData(), schema.schemaHelper, it.insertMsg.NRows()); err != nil { return merr.WrapErrAsInputError(err) } diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 06f80220ef..a85813c633 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -2088,6 +2088,35 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { }, } + idField := &schemapb.FieldSchema{ + FieldID: 100, + Name: "id", + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: nil, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + AutoID: false, + } + vectorField := &schemapb.FieldSchema{ + FieldID: 101, + Name: fieldName, + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: nil, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + AutoID: false, + } + t.Run("normal", func(t *testing.T) { cache := NewMockCache(t) cache.On("GetCollectionSchema", @@ -2096,20 +2125,8 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { mock.AnythingOfType("string"), ).Return(newSchemaInfo(&schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ - { - FieldID: 100, - Name: fieldName, - IsPrimaryKey: false, - DataType: schemapb.DataType_FloatVector, - TypeParams: nil, - IndexParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "128", - }, - }, - AutoID: false, - }, + idField, + vectorField, }, }), nil) @@ -2131,28 +2148,9 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { assert.Error(t, err) }) - t.Run("invalid schema", func(t *testing.T) { - cache := NewMockCache(t) - cache.On("GetCollectionSchema", - mock.Anything, // context.Context - mock.AnythingOfType("string"), - mock.AnythingOfType("string"), - ).Return(newSchemaInfo(&schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - { - Name: fieldName, - }, - { - Name: fieldName, // duplicate - }, - }, - }), nil) - globalMetaCache = cache - _, err := cit.getIndexedField(context.Background()) - assert.Error(t, err) - }) - t.Run("field not found", func(t *testing.T) { + otherField := typeutil.Clone(vectorField) + otherField.Name = otherField.Name + "_other" cache := NewMockCache(t) cache.On("GetCollectionSchema", mock.Anything, // context.Context @@ -2160,9 +2158,8 @@ func Test_createIndexTask_getIndexedField(t *testing.T) { mock.AnythingOfType("string"), ).Return(newSchemaInfo(&schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ - { - Name: fieldName + fieldName, - }, + idField, + otherField, }, }), nil) globalMetaCache = cache diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index f9054bf54f..ed82be44bb 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -223,7 +223,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { } if err := newValidateUtil(withNANCheck(), withOverflowCheck(), withMaxLenCheck()). - Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema, it.upsertMsg.InsertMsg.NRows()); err != nil { + Validate(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.schemaHelper, it.upsertMsg.InsertMsg.NRows()); err != nil { return err } diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 2a2ba490c2..99438abb4e 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1012,11 +1012,7 @@ func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary userOutputFieldsMap[outputFieldName] = true } else { if schema.EnableDynamicField { - schemaH, err := typeutil.CreateSchemaHelper(schema.CollectionSchema) - if err != nil { - return nil, nil, err - } - err = planparserv2.ParseIdentifier(schemaH, outputFieldName, func(expr *planpb.Expr) error { + err := planparserv2.ParseIdentifier(schema.schemaHelper, outputFieldName, func(expr *planpb.Expr) error { if len(expr.GetColumnExpr().GetInfo().GetNestedPath()) == 1 && expr.GetColumnExpr().GetInfo().GetNestedPath()[0] == outputFieldName { return nil diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index 2c879c70e2..ff5be7e7f7 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -56,12 +56,10 @@ func (v *validateUtil) apply(opts ...validateOption) { } } -func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.CollectionSchema, numRows uint64) error { - helper, err := typeutil.CreateSchemaHelper(schema) - if err != nil { - return err +func (v *validateUtil) Validate(data []*schemapb.FieldData, helper *typeutil.SchemaHelper, numRows uint64) error { + if helper == nil { + return merr.WrapErrServiceInternal("nil schema helper provided for Validation") } - for _, field := range data { fieldSchema, err := helper.GetFieldFromName(field.GetFieldName()) if err != nil { @@ -122,7 +120,7 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, schema *schemapb.Col } } - err = v.fillWithValue(data, helper, int(numRows)) + err := v.fillWithValue(data, helper, int(numRows)) if err != nil { return err } diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go index 5c4079dbe1..3e56fa1f45 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/validate_util_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -1430,8 +1431,10 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil() + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err := v.Validate(data, schema, 100) + err = v.Validate(data, helper, 100) assert.Error(t, err) }) @@ -1560,8 +1563,10 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil(withNANCheck(), withMaxLenCheck()) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err := v.Validate(data, schema, 2) + err = v.Validate(data, helper, 2) assert.Error(t, err) }) @@ -1690,7 +1695,9 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil(withNANCheck(), withMaxLenCheck()) - err := v.Validate(data, schema, 2) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + err = v.Validate(data, helper, 2) assert.Error(t, err) // Validate JSON length @@ -1719,7 +1726,9 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, } - err = v.Validate(data, schema, 2) + helper, err = typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + err = v.Validate(data, helper, 2) assert.Error(t, err) }) @@ -1751,8 +1760,9 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil(withOverflowCheck()) - - err := v.Validate(data, schema, 2) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + err = v.Validate(data, helper, 2) assert.Error(t, err) }) @@ -1788,8 +1798,10 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil() + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err := v.Validate(data, schema, 100) + err = v.Validate(data, helper, 100) assert.Error(t, err) }) @@ -1836,8 +1848,10 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil(withMaxCapCheck()) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err := v.Validate(data, schema, 1) + err = v.Validate(data, helper, 1) assert.Error(t, err) }) @@ -1889,8 +1903,10 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil(withMaxCapCheck(), withMaxLenCheck()) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err := v.Validate(data, schema, 1) + err = v.Validate(data, helper, 1) assert.Error(t, err) }) @@ -1932,8 +1948,10 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil(withMaxCapCheck()) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err := v.Validate(data, schema, 1) + err = v.Validate(data, helper, 1) assert.Error(t, err) }) @@ -1980,8 +1998,10 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil(withMaxCapCheck()) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err := v.Validate(data, schema, 1) + err = v.Validate(data, helper, 1) assert.Error(t, err) }) @@ -2035,7 +2055,9 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil(withMaxCapCheck()) - err := v.Validate(data, schema, 1) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + err = v.Validate(data, helper, 1) assert.Error(t, err) data = []*schemapb.FieldData{ @@ -2084,8 +2106,10 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, } + helper, err = typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + err = newValidateUtil(withMaxCapCheck()).Validate(data, helper, 1) assert.Error(t, err) schema = &schemapb.CollectionSchema{ @@ -2103,8 +2127,10 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, } + helper, err = typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + err = newValidateUtil(withMaxCapCheck()).Validate(data, helper, 1) assert.Error(t, err) schema = &schemapb.CollectionSchema{ @@ -2123,7 +2149,10 @@ func Test_validateUtil_Validate(t *testing.T) { }, } - err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + helper, err = typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + + err = newValidateUtil(withMaxCapCheck()).Validate(data, helper, 1) assert.Error(t, err) data = []*schemapb.FieldData{ @@ -2172,7 +2201,9 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, } - err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + helper, err = typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + err = newValidateUtil(withMaxCapCheck()).Validate(data, helper, 1) assert.Error(t, err) data = []*schemapb.FieldData{ @@ -2221,8 +2252,10 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, } + helper, err = typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + err = newValidateUtil(withMaxCapCheck()).Validate(data, helper, 1) assert.Error(t, err) data = []*schemapb.FieldData{ @@ -2271,8 +2304,10 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, } + helper, err = typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + err = newValidateUtil(withMaxCapCheck()).Validate(data, helper, 1) assert.Error(t, err) data = []*schemapb.FieldData{ @@ -2321,8 +2356,10 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, } + helper, err = typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + err = newValidateUtil(withMaxCapCheck()).Validate(data, helper, 1) assert.Error(t, err) }) @@ -2366,8 +2403,9 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, } - - err := newValidateUtil(withMaxCapCheck(), withOverflowCheck()).Validate(data, schema, 1) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + err = newValidateUtil(withMaxCapCheck(), withOverflowCheck()).Validate(data, helper, 1) assert.Error(t, err) data = []*schemapb.FieldData{ @@ -2409,8 +2447,10 @@ func Test_validateUtil_Validate(t *testing.T) { }, }, } + helper, err = typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err = newValidateUtil(withMaxCapCheck(), withOverflowCheck()).Validate(data, schema, 1) + err = newValidateUtil(withMaxCapCheck(), withOverflowCheck()).Validate(data, helper, 1) assert.Error(t, err) }) @@ -2830,8 +2870,10 @@ func Test_validateUtil_Validate(t *testing.T) { } v := newValidateUtil(withNANCheck(), withMaxLenCheck(), withOverflowCheck(), withMaxCapCheck()) + helper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) - err := v.Validate(data, schema, 2) + err = v.Validate(data, helper, 2) assert.NoError(t, err) })