From 2fe8677cbf3b6561bc62daec4e13d059824da6bb Mon Sep 17 00:00:00 2001 From: Jiquan Long Date: Fri, 29 Apr 2022 18:01:49 +0800 Subject: [PATCH] Enable dimension check in Proxy when create index request received (#16718) Signed-off-by: dragondriver --- internal/proxy/mock_cache_test.go | 43 +++ internal/proxy/task.go | 123 +++++--- internal/proxy/task_test.go | 265 ++++++++++++++++++ internal/util/indexparamcheck/conf_adapter.go | 25 +- .../util/indexparamcheck/conf_adapter_test.go | 21 +- 5 files changed, 426 insertions(+), 51 deletions(-) create mode 100644 internal/proxy/mock_cache_test.go diff --git a/internal/proxy/mock_cache_test.go b/internal/proxy/mock_cache_test.go new file mode 100644 index 0000000000..b5e9eb49a5 --- /dev/null +++ b/internal/proxy/mock_cache_test.go @@ -0,0 +1,43 @@ +package proxy + +import ( + "context" + + "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/util/typeutil" +) + +type getCollectionIDFunc func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) +type getCollectionSchemaFunc func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) + +type mockCache struct { + Cache + getIDFunc getCollectionIDFunc + getSchemaFunc getCollectionSchemaFunc +} + +func (m *mockCache) GetCollectionID(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { + if m.getIDFunc != nil { + return m.getIDFunc(ctx, collectionName) + } + return 0, nil +} + +func (m *mockCache) GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { + if m.getSchemaFunc != nil { + return m.getSchemaFunc(ctx, collectionName) + } + return nil, nil +} + +func (m *mockCache) setGetIDFunc(f getCollectionIDFunc) { + m.getIDFunc = f +} + +func (m *mockCache) setGetSchemaFunc(f getCollectionSchemaFunc) { + m.getSchemaFunc = f +} + +func newMockCache() *mockCache { + return &mockCache{} +} diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 5e622aaed3..162b215f36 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1795,36 +1795,13 @@ func (cit *createIndexTask) OnEnqueue() error { return nil } -func (cit *createIndexTask) PreExecute(ctx context.Context) error { - cit.Base.MsgType = commonpb.MsgType_CreateIndex - cit.Base.SourceID = Params.ProxyCfg.GetNodeID() - - collName, fieldName := cit.CollectionName, cit.FieldName - - col, err := globalMetaCache.GetCollectionInfo(ctx, collName) - if err != nil { - return err - } - cit.collectionID = col.collID - - if err := validateCollectionName(collName); err != nil { - return err - } - - if err := validateFieldName(fieldName); err != nil { - return err - } - - // check index param, not accurate, only some static rules +func parseIndexParams(m []*commonpb.KeyValuePair) (map[string]string, error) { indexParams := make(map[string]string) - for _, kv := range cit.CreateIndexRequest.ExtraParams { + for _, kv := range m { if kv.Key == "params" { // TODO(dragondriver): change `params` to const variable params, err := funcutil.ParseIndexParamsMap(kv.Value) if err != nil { - log.Warn("Failed to parse index params", - zap.String("params", kv.Value), - zap.Error(err)) - continue + return nil, err } for k, v := range params { indexParams[k] = v @@ -1833,23 +1810,68 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error { indexParams[kv.Key] = kv.Value } } - indexType, exist := indexParams["index_type"] // TODO(dragondriver): change `index_type` to const variable + _, exist := indexParams["index_type"] // TODO(dragondriver): change `index_type` to const variable if !exist { - indexType = indexparamcheck.IndexFaissIvfPQ // IVF_PQ is the default index type + indexParams["index_type"] = indexparamcheck.IndexFaissIvfPQ // IVF_PQ is the default index type } + return indexParams, nil +} - //TODO:: add default index type for VarChar type field +func (cit *createIndexTask) getIndexedField(ctx context.Context) (*schemapb.FieldSchema, error) { + schema, err := globalMetaCache.GetCollectionSchema(ctx, cit.GetCollectionName()) + if err != nil { + log.Error("failed to get collection schema", zap.Error(err)) + return nil, fmt.Errorf("failed to get collection schema: %s", err) + } + schemaHelper, err := typeutil.CreateSchemaHelper(schema) + if err != nil { + log.Error("failed to parse collection schema", zap.Error(err)) + return nil, fmt.Errorf("failed to parse collection schema: %s", err) + } + field, err := schemaHelper.GetFieldFromName(cit.GetFieldName()) + if err != nil { + log.Error("create index on non-exist field", zap.Error(err)) + return nil, fmt.Errorf("cannot create index on non-exist field: %s", cit.GetFieldName()) + } + return field, nil +} + +func fillDimension(field *schemapb.FieldSchema, indexParams map[string]string) error { + vecDataTypes := []schemapb.DataType{ + schemapb.DataType_FloatVector, + schemapb.DataType_BinaryVector, + } + if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) { + return nil + } + params := make([]*commonpb.KeyValuePair, 0, len(field.GetTypeParams())+len(field.GetIndexParams())) + params = append(params, field.GetTypeParams()...) + params = append(params, field.GetIndexParams()...) + dimensionInSchema, err := funcutil.GetAttrByKeyFromRepeatedKV("dim", params) + if err != nil { + return fmt.Errorf("dimension not found in schema") + } + dimension, exist := indexParams["dim"] + if exist { + if dimensionInSchema != dimension { + return fmt.Errorf("dimension mismatch, dimension in schema: %s, dimension: %s", dimensionInSchema, dimension) + } + } else { + indexParams["dim"] = dimensionInSchema + } + return nil +} + +func checkTrain(field *schemapb.FieldSchema, indexParams map[string]string) error { + indexType := indexParams["index_type"] // skip params check of non-vector field. vecDataTypes := []schemapb.DataType{ schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, } - - for _, f := range col.schema.GetFields() { - if f.GetName() == fieldName && !funcutil.SliceContain(vecDataTypes, f.GetDataType()) { - return indexparamcheck.CheckIndexValid(f.GetDataType(), indexType, indexParams) - } + if !funcutil.SliceContain(vecDataTypes, field.GetDataType()) { + return indexparamcheck.CheckIndexValid(field.GetDataType(), indexType, indexParams) } adapter, err := indexparamcheck.GetConfAdapterMgrInstance().GetAdapter(indexType) @@ -1858,15 +1880,46 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error { return fmt.Errorf("invalid index type: %s", indexType) } + if err := fillDimension(field, indexParams); err != nil { + return err + } + ok := adapter.CheckTrain(indexParams) if !ok { log.Warn("Create index with invalid params", zap.Any("index_params", indexParams)) - return fmt.Errorf("invalid index params: %v", cit.CreateIndexRequest.ExtraParams) + return fmt.Errorf("invalid index params: %v", indexParams) } return nil } +func (cit *createIndexTask) PreExecute(ctx context.Context) error { + cit.Base.MsgType = commonpb.MsgType_CreateIndex + cit.Base.SourceID = Params.ProxyCfg.GetNodeID() + + collName := cit.CollectionName + + collID, err := globalMetaCache.GetCollectionID(ctx, collName) + if err != nil { + return err + } + cit.collectionID = collID + + field, err := cit.getIndexedField(ctx) + if err != nil { + return err + } + + // check index param, not accurate, only some static rules + indexParams, err := parseIndexParams(cit.GetExtraParams()) + if err != nil { + log.Error("failed to parse index params", zap.Error(err)) + return fmt.Errorf("failed to parse index params: %s", err) + } + + return checkTrain(field, indexParams) +} + func (cit *createIndexTask) Execute(ctx context.Context) error { var err error cit.result, err = cit.rootCoord.CreateIndex(ctx, cit.CreateIndexRequest) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index ea97b6adde..fd7ca0f423 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -21,11 +21,14 @@ import ( "context" "encoding/binary" "encoding/json" + "errors" "math/rand" "strconv" "testing" "time" + "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" @@ -2169,3 +2172,265 @@ func TestAlterAlias_all(t *testing.T) { assert.NoError(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) } + +func Test_createIndexTask_getIndexedField(t *testing.T) { + collectionName := "test" + fieldName := "test" + + cit := &createIndexTask{ + CreateIndexRequest: &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: fieldName, + }, + } + + t.Run("normal", func(t *testing.T) { + cache := newMockCache() + cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { + return &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: fieldName, + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: nil, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + AutoID: false, + }, + }, + }, nil + }) + globalMetaCache = cache + field, err := cit.getIndexedField(context.Background()) + assert.NoError(t, err) + assert.Equal(t, fieldName, field.GetName()) + }) + + t.Run("schema not found", func(t *testing.T) { + cache := newMockCache() + cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { + return nil, errors.New("mock") + }) + globalMetaCache = cache + _, err := cit.getIndexedField(context.Background()) + assert.Error(t, err) + }) + + t.Run("invalid schema", func(t *testing.T) { + cache := newMockCache() + cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { + return &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: fieldName, + }, + { + Name: fieldName, // duplicate + }, + }, + }, nil + }) + globalMetaCache = cache + _, err := cit.getIndexedField(context.Background()) + assert.Error(t, err) + }) + + t.Run("field not found", func(t *testing.T) { + cache := newMockCache() + cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { + return &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: fieldName + fieldName, + }, + }, + }, nil + }) + globalMetaCache = cache + _, err := cit.getIndexedField(context.Background()) + assert.Error(t, err) + }) +} + +func Test_fillDimension(t *testing.T) { + t.Run("scalar", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Int64, + } + assert.NoError(t, fillDimension(f, nil)) + }) + + t.Run("no dim in schema", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, + } + assert.Error(t, fillDimension(f, nil)) + }) + + t.Run("dimension mismatch", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + } + assert.Error(t, fillDimension(f, map[string]string{"dim": "8"})) + }) + + t.Run("normal", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + } + m := map[string]string{} + assert.NoError(t, fillDimension(f, m)) + assert.Equal(t, "128", m["dim"]) + }) +} + +func Test_checkTrain(t *testing.T) { + t.Run("normal", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + } + m := map[string]string{ + "index_type": "IVF_FLAT", + "nlist": "1024", + "metric_type": "L2", + } + assert.NoError(t, checkTrain(f, m)) + }) + + t.Run("scalar", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_Int64, + } + m := map[string]string{ + "index_type": "scalar", + } + assert.NoError(t, checkTrain(f, m)) + }) + + t.Run("dimension mismatch", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + } + m := map[string]string{ + "index_type": "IVF_FLAT", + "nlist": "1024", + "metric_type": "L2", + "dim": "8", + } + assert.Error(t, checkTrain(f, m)) + }) + + t.Run("invalid params", func(t *testing.T) { + f := &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + } + m := map[string]string{ + "index_type": "IVF_FLAT", + "metric_type": "L2", + } + assert.Error(t, checkTrain(f, m)) + }) +} + +func Test_createIndexTask_PreExecute(t *testing.T) { + collectionName := "test" + fieldName := "test" + + cit := &createIndexTask{ + CreateIndexRequest: &milvuspb.CreateIndexRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_CreateIndex, + }, + CollectionName: collectionName, + FieldName: fieldName, + }, + } + + t.Run("normal", func(t *testing.T) { + cache := newMockCache() + cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { + return 100, nil + }) + cache.setGetSchemaFunc(func(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) { + return &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: fieldName, + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: nil, + IndexParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + AutoID: false, + }, + }, + }, nil + }) + globalMetaCache = cache + cit.CreateIndexRequest.ExtraParams = []*commonpb.KeyValuePair{ + { + Key: "index_type", + Value: "IVF_FLAT", + }, + { + Key: "nlist", + Value: "1024", + }, + { + Key: "metric_type", + Value: "L2", + }, + } + assert.NoError(t, cit.PreExecute(context.Background())) + }) + + t.Run("collection not found", func(t *testing.T) { + cache := newMockCache() + cache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { + return 0, errors.New("mock") + }) + globalMetaCache = cache + assert.Error(t, cit.PreExecute(context.Background())) + }) +} diff --git a/internal/util/indexparamcheck/conf_adapter.go b/internal/util/indexparamcheck/conf_adapter.go index 13c4e6976b..c687cd7c17 100644 --- a/internal/util/indexparamcheck/conf_adapter.go +++ b/internal/util/indexparamcheck/conf_adapter.go @@ -131,10 +131,9 @@ type BaseConfAdapter struct { // CheckTrain check whether the params contains supported metrics types func (adapter *BaseConfAdapter) CheckTrain(params map[string]string) bool { - // dimension is specified when create collection - //if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) { - // return false - //} + if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) { + return false + } return CheckStrByValues(params, Metric, METRICS) } @@ -179,8 +178,8 @@ func (adapter *IVFPQConfAdapter) CheckTrain(params map[string]string) bool { func (adapter *IVFPQConfAdapter) checkPQParams(params map[string]string) bool { dimStr, dimensionExist := params[DIM] - if !dimensionExist { // dimension is specified when creating collection - return true + if !dimensionExist { + return false } dimension, err := strconv.Atoi(dimStr) @@ -260,10 +259,9 @@ type BinIDMAPConfAdapter struct { // CheckTrain checks if a binary flat index can be built with the specific parameters. func (adapter *BinIDMAPConfAdapter) CheckTrain(params map[string]string) bool { - // dimension is specified when create collection - //if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) { - // return false - //} + if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) { + return false + } return CheckStrByValues(params, Metric, BinIDMapMetrics) } @@ -278,10 +276,9 @@ type BinIVFConfAdapter struct { // CheckTrain checks if a binary ivf index can be built with specific parameters. func (adapter *BinIVFConfAdapter) CheckTrain(params map[string]string) bool { - // dimension is specified when create collection - //if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) { - // return false - //} + if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) { + return false + } if !CheckIntByRange(params, NLIST, MinNList, MaxNList) { return false diff --git a/internal/util/indexparamcheck/conf_adapter_test.go b/internal/util/indexparamcheck/conf_adapter_test.go index 9ad513585c..dd351077e5 100644 --- a/internal/util/indexparamcheck/conf_adapter_test.go +++ b/internal/util/indexparamcheck/conf_adapter_test.go @@ -12,6 +12,7 @@ package indexparamcheck import ( + "fmt" "strconv" "testing" ) @@ -50,11 +51,15 @@ func TestBaseConfAdapter_CheckTrain(t *testing.T) { DIM: strconv.Itoa(128), Metric: L2, } + paramsWithoutDim := map[string]string{ + Metric: L2, + } cases := []struct { params map[string]string want bool }{ {validParams, true}, + {paramsWithoutDim, false}, } adapter := newBaseConfAdapter() @@ -141,7 +146,7 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) { {validParamsWithoutNbits, true}, {invalidIVFParamsMin(), false}, {invalidIVFParamsMax(), false}, - {validParamsWithoutDim, true}, + {validParamsWithoutDim, false}, {invalidParamsDim, false}, {invalidParamsNbits, false}, {invalidParamsWithoutIVF, false}, @@ -150,8 +155,9 @@ func TestIVFPQConfAdapter_CheckTrain(t *testing.T) { } adapter := newIVFPQConfAdapter() - for _, test := range cases { + for i, test := range cases { if got := adapter.CheckTrain(test.params); got != test.want { + fmt.Printf("i: %d, params: %v\n", i, test.params) t.Errorf("IVFPQConfAdapter.CheckTrain(%v) = %v", test.params, test.want) } } @@ -187,11 +193,15 @@ func TestBinIDMAPConfAdapter_CheckTrain(t *testing.T) { DIM: strconv.Itoa(128), Metric: JACCARD, } + paramsWithoutDim := map[string]string{ + Metric: JACCARD, + } cases := []struct { params map[string]string want bool }{ {validParams, true}, + {paramsWithoutDim, false}, } adapter := newBinIDMAPConfAdapter() @@ -211,6 +221,12 @@ func TestBinIVFConfAdapter_CheckTrain(t *testing.T) { NBITS: strconv.Itoa(8), Metric: JACCARD, } + paramsWithoutDim := map[string]string{ + NLIST: strconv.Itoa(100), + IVFM: strconv.Itoa(4), + NBITS: strconv.Itoa(8), + Metric: JACCARD, + } invalidParams := copyParams(validParams) invalidParams[Metric] = L2 @@ -220,6 +236,7 @@ func TestBinIVFConfAdapter_CheckTrain(t *testing.T) { want bool }{ {validParams, true}, + {paramsWithoutDim, false}, {invalidIVFParamsMin(), false}, {invalidIVFParamsMax(), false}, {invalidParams, false},