diff --git a/internal/datacoord/index_service.go b/internal/datacoord/index_service.go index d849f99061..a62e2c5f14 100644 --- a/internal/datacoord/index_service.go +++ b/internal/datacoord/index_service.go @@ -19,6 +19,7 @@ package datacoord import ( "context" "fmt" + "strconv" "time" "github.com/samber/lo" @@ -27,8 +28,10 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -237,6 +240,44 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques return merr.Success(), nil } +func ValidateIndexParams(index *model.Index, key, value string) error { + switch key { + case common.MmapEnabledKey: + indexType := getIndexType(index.IndexParams) + if !indexparamcheck.IsMmapSupported(indexType) { + return merr.WrapErrParameterInvalidMsg("index type %s does not support mmap", indexType) + } + + if _, err := strconv.ParseBool(value); err != nil { + return merr.WrapErrParameterInvalidMsg("invalid %s value: %s, expected: true, false", key, value) + } + } + return nil +} + +func UpdateParams(index *model.Index, from []*commonpb.KeyValuePair, updates []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) { + params := make(map[string]string) + for _, param := range from { + params[param.GetKey()] = param.GetValue() + } + + // update the params + for _, param := range updates { + if err := ValidateIndexParams(index, param.GetKey(), param.GetValue()); err != nil { + log.Warn("failed to alter index params", zap.Error(err)) + return nil, err + } + params[param.GetKey()] = param.GetValue() + } + + return lo.MapToSlice(params, func(k string, v string) *commonpb.KeyValuePair { + return &commonpb.KeyValuePair{ + Key: k, + Value: v, + } + }), nil +} + func (s *Server) AlterIndex(ctx context.Context, req *indexpb.AlterIndexRequest) (*commonpb.Status, error) { log := log.Ctx(ctx).With( zap.Int64("collectionID", req.GetCollectionID()), @@ -250,27 +291,28 @@ func (s *Server) AlterIndex(ctx context.Context, req *indexpb.AlterIndexRequest) } indexes := s.meta.GetIndexesForCollection(req.GetCollectionID(), req.GetIndexName()) - params := make(map[string]string) for _, index := range indexes { - for _, param := range index.UserIndexParams { - params[param.GetKey()] = param.GetValue() + // update user index params + newUserIndexParams, err := UpdateParams(index, index.UserIndexParams, req.GetParams()) + if err != nil { + return merr.Status(err), nil } - - // update the index params - for _, param := range req.GetParams() { - params[param.GetKey()] = param.GetValue() - } - - log.Info("prepare to alter index", + log.Info("alter index user index params", zap.String("indexName", index.IndexName), - zap.Any("params", params), + zap.Any("params", newUserIndexParams), ) - index.UserIndexParams = lo.MapToSlice(params, func(k string, v string) *commonpb.KeyValuePair { - return &commonpb.KeyValuePair{ - Key: k, - Value: v, - } - }) + index.UserIndexParams = newUserIndexParams + + // update index params + newIndexParams, err := UpdateParams(index, index.IndexParams, req.GetParams()) + if err != nil { + return merr.Status(err), nil + } + log.Info("alter index user index params", + zap.String("indexName", index.IndexName), + zap.Any("params", newIndexParams), + ) + index.IndexParams = newIndexParams } err := s.meta.AlterIndex(ctx, indexes...) diff --git a/internal/datacoord/index_service_test.go b/internal/datacoord/index_service_test.go index 59d0fef4c8..de9121a5e7 100644 --- a/internal/datacoord/index_service_test.go +++ b/internal/datacoord/index_service_test.go @@ -41,6 +41,7 @@ import ( "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -579,6 +580,24 @@ func TestServer_AlterIndex(t *testing.T) { s.stateCode.Store(commonpb.StateCode_Healthy) + t.Run("mmap_unsupported", func(t *testing.T) { + indexParams[0].Value = indexparamcheck.IndexRaftCagra + + resp, err := s.AlterIndex(ctx, req) + assert.NoError(t, err) + assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid) + + indexParams[0].Value = indexparamcheck.IndexFaissIvfFlat + }) + + t.Run("param_value_invalied", func(t *testing.T) { + req.Params[0].Value = "abc" + resp, err := s.AlterIndex(ctx, req) + assert.ErrorIs(t, merr.CheckRPCCall(resp, err), merr.ErrParameterInvalid) + + req.Params[0].Value = "true" + }) + t.Run("success", func(t *testing.T) { resp, err := s.AlterIndex(ctx, req) assert.NoError(t, merr.CheckRPCCall(resp, err)) diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 43add9b112..bf768d533c 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -504,6 +504,10 @@ func (t *alterIndexTask) PreExecute(ctx context.Context) error { } t.collectionID = collection + if len(t.req.GetIndexName()) == 0 { + return merr.WrapErrParameterInvalidMsg("index name is empty") + } + if err = validateIndexName(t.req.GetIndexName()); err != nil { return err } diff --git a/pkg/util/indexparamcheck/index_type.go b/pkg/util/indexparamcheck/index_type.go index 4b8291ed9d..b6fb43049e 100644 --- a/pkg/util/indexparamcheck/index_type.go +++ b/pkg/util/indexparamcheck/index_type.go @@ -38,3 +38,13 @@ func IsGpuIndex(indexType IndexType) bool { indexType == IndexRaftIvfPQ || indexType == IndexRaftCagra } + +func IsMmapSupported(indexType IndexType) bool { + return indexType == IndexFaissIDMap || + indexType == IndexFaissIvfFlat || + indexType == IndexFaissIvfPQ || + indexType == IndexFaissIvfSQ8 || + indexType == IndexFaissBinIDMap || + indexType == IndexFaissBinIvfFlat || + indexType == IndexHNSW +}