diff --git a/configs/milvus.yaml b/configs/milvus.yaml index 840e758d10..71537756de 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -553,3 +553,7 @@ trace: sampleFraction: 0 jaeger: url: # when exporter is jaeger should set the jaeger's URL + +autoIndex: + params: + build: '{"M": 30,"efConstruction": 360,"index_type": "HNSW", "metric_type": "IP"}' diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 9e6e3283fc..69610606e6 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -61,6 +61,7 @@ type createIndexTask struct { isAutoIndex bool newIndexParams []*commonpb.KeyValuePair newTypeParams []*commonpb.KeyValuePair + newExtraParams []*commonpb.KeyValuePair collectionID UniqueID fieldSchema *schemapb.FieldSchema @@ -103,7 +104,22 @@ func (cit *createIndexTask) OnEnqueue() error { return nil } +func wrapUserIndexParams(metricType string) []*commonpb.KeyValuePair { + return []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: AutoIndexName, + }, + { + Key: common.MetricTypeKey, + Value: metricType, + }, + } +} + func (cit *createIndexTask) parseIndexParams() error { + cit.newExtraParams = cit.req.GetExtraParams() + isVecIndex := typeutil.IsVectorType(cit.fieldSchema.DataType) indexParamsMap := make(map[string]string) if !isVecIndex { @@ -133,17 +149,73 @@ func (cit *createIndexTask) parseIndexParams() error { if isVecIndex { specifyIndexType, exist := indexParamsMap[common.IndexTypeKey] - if Params.AutoIndexConfig.Enable.GetAsBool() { + if Params.AutoIndexConfig.Enable.GetAsBool() { // `enable` only for cloud instance. log.Info("create index trigger AutoIndex", zap.String("original type", specifyIndexType), zap.String("final type", Params.AutoIndexConfig.AutoIndexTypeName.GetValue())) + + metricType, metricTypeExist := indexParamsMap[common.MetricTypeKey] + // override params by autoindex for k, v := range Params.AutoIndexConfig.IndexParams.GetAsJSONMap() { indexParamsMap[k] = v } - } else { + + if metricTypeExist { + // make the users' metric type first class citizen. + indexParamsMap[common.MetricTypeKey] = metricType + } + } else { // behavior change after 2.2.9, adapt autoindex logic here. + autoIndexConfig := Params.AutoIndexConfig.IndexParams.GetAsJSONMap() + + useAutoIndex := func() { + fields := make([]zap.Field, 0, len(autoIndexConfig)) + for k, v := range autoIndexConfig { + indexParamsMap[k] = v + fields = append(fields, zap.String(k, v)) + } + log.Ctx(cit.ctx).Info("AutoIndex triggered", fields...) + } + + handle := func(numberParams int) error { + // empty case. + if len(indexParamsMap) == numberParams { + // though we already know there must be metric type, how to make this safer to avoid crash? + metricType := autoIndexConfig[common.MetricTypeKey] + cit.newExtraParams = wrapUserIndexParams(metricType) + useAutoIndex() + return nil + } + + metricType, metricTypeExist := indexParamsMap[common.MetricTypeKey] + + if len(indexParamsMap) > numberParams+1 { + return fmt.Errorf("only metric type can be passed when use AutoIndex") + } + + if len(indexParamsMap) == numberParams+1 { + if !metricTypeExist { + return fmt.Errorf("only metric type can be passed when use AutoIndex") + } + + // only metric type is passed. + cit.newExtraParams = wrapUserIndexParams(metricType) + useAutoIndex() + // make the users' metric type first class citizen. + indexParamsMap[common.MetricTypeKey] = metricType + } + + return nil + } + if !exist { - return fmt.Errorf("IndexType not specified") + if err := handle(0); err != nil { + return err + } + } else if specifyIndexType == AutoIndexName { + if err := handle(1); err != nil { + return err + } } } @@ -298,9 +370,11 @@ func (cit *createIndexTask) PreExecute(ctx context.Context) error { } func (cit *createIndexTask) Execute(ctx context.Context) error { - log.Debug("proxy create index", zap.Int64("collID", cit.collectionID), zap.Int64("fieldID", cit.fieldSchema.GetFieldID()), + log.Ctx(ctx).Info("proxy create index", zap.Int64("collID", cit.collectionID), zap.Int64("fieldID", cit.fieldSchema.GetFieldID()), zap.String("indexName", cit.req.GetIndexName()), zap.Any("typeParams", cit.fieldSchema.GetTypeParams()), - zap.Any("indexParams", cit.req.GetExtraParams())) + zap.Any("indexParams", cit.req.GetExtraParams()), + zap.Any("newExtraParams", cit.newExtraParams), + ) if cit.req.GetIndexName() == "" { cit.req.IndexName = Params.CommonCfg.DefaultIndexName.GetValue() + "_" + strconv.FormatInt(cit.fieldSchema.GetFieldID(), 10) @@ -313,7 +387,7 @@ func (cit *createIndexTask) Execute(ctx context.Context) error { TypeParams: cit.newTypeParams, IndexParams: cit.newIndexParams, IsAutoIndex: cit.isAutoIndex, - UserIndexParams: cit.req.GetExtraParams(), + UserIndexParams: cit.newExtraParams, Timestamp: cit.BeginTs(), } cit.result, err = cit.datacoord.CreateIndex(ctx, req) diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 4625cab58c..c40fd1fe4c 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -22,6 +22,8 @@ import ( "os" "testing" + "github.com/milvus-io/milvus/pkg/config" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -492,3 +494,120 @@ func Test_parseIndexParams(t *testing.T) { assert.Error(t, err) }) } + +func Test_wrapUserIndexParams(t *testing.T) { + params := wrapUserIndexParams("L2") + assert.Equal(t, 2, len(params)) + assert.Equal(t, "index_type", params[0].Key) + assert.Equal(t, AutoIndexName, params[0].Value) + assert.Equal(t, "metric_type", params[1].Key) + assert.Equal(t, "L2", params[1].Value) +} + +func Test_parseIndexParams_AutoIndex(t *testing.T) { + Params.Init() + mgr := config.NewManager() + mgr.SetConfig("autoIndex.enable", "false") + mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW", "metric_type": "IP"}`) + Params.AutoIndexConfig.Enable.Init(mgr) + Params.AutoIndexConfig.IndexParams.Init(mgr) + autoIndexConfig := Params.AutoIndexConfig.IndexParams.GetAsJSONMap() + fieldSchema := &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "8"}, + }, + } + + t.Run("case 1, empty parameters", func(t *testing.T) { + task := &createIndexTask{ + fieldSchema: fieldSchema, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: make([]*commonpb.KeyValuePair, 0), + }, + } + err := task.parseIndexParams() + assert.NoError(t, err) + assert.ElementsMatch(t, []*commonpb.KeyValuePair{ + {Key: common.IndexTypeKey, Value: AutoIndexName}, + {Key: common.MetricTypeKey, Value: autoIndexConfig[common.MetricTypeKey]}, + }, task.newExtraParams) + }) + + t.Run("case 2, only metric type passed", func(t *testing.T) { + task := &createIndexTask{ + fieldSchema: fieldSchema, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + {Key: common.MetricTypeKey, Value: "L2"}, + }, + }, + } + err := task.parseIndexParams() + assert.NoError(t, err) + assert.ElementsMatch(t, []*commonpb.KeyValuePair{ + {Key: common.IndexTypeKey, Value: AutoIndexName}, + {Key: common.MetricTypeKey, Value: "L2"}, + }, task.newExtraParams) + }) + + t.Run("case 3, AutoIndex & metric_type passed", func(t *testing.T) { + task := &createIndexTask{ + fieldSchema: fieldSchema, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + {Key: common.MetricTypeKey, Value: "L2"}, + {Key: common.IndexTypeKey, Value: AutoIndexName}, + }, + }, + } + err := task.parseIndexParams() + assert.NoError(t, err) + assert.ElementsMatch(t, []*commonpb.KeyValuePair{ + {Key: common.IndexTypeKey, Value: AutoIndexName}, + {Key: common.MetricTypeKey, Value: "L2"}, + }, task.newExtraParams) + }) + + t.Run("case 4, duplicate and useless parameters passed", func(t *testing.T) { + task := &createIndexTask{ + fieldSchema: fieldSchema, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + {Key: "not important", Value: "L2"}, + }, + }, + } + err := task.parseIndexParams() + assert.Error(t, err) + }) + + t.Run("case 5, duplicate and useless parameters passed", func(t *testing.T) { + task := &createIndexTask{ + fieldSchema: fieldSchema, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + {Key: common.MetricTypeKey, Value: "L2"}, + {Key: "not important", Value: "L2"}, + }, + }, + } + err := task.parseIndexParams() + assert.Error(t, err) + }) + + t.Run("case 6, autoindex & duplicate", func(t *testing.T) { + task := &createIndexTask{ + fieldSchema: fieldSchema, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + {Key: common.IndexTypeKey, Value: AutoIndexName}, + {Key: common.MetricTypeKey, Value: "L2"}, + {Key: "not important", Value: "L2"}, + }, + }, + } + err := task.parseIndexParams() + assert.Error(t, err) + }) +} diff --git a/pkg/util/funcutil/func.go b/pkg/util/funcutil/func.go index aa8e93a41a..5881e2ef44 100644 --- a/pkg/util/funcutil/func.go +++ b/pkg/util/funcutil/func.go @@ -76,6 +76,12 @@ func JSONToMap(mStr string) (map[string]string, error) { return ret, nil } +func MapToJSON(m map[string]string) []byte { + // error won't happen here. + bs, _ := json.Marshal(m) + return bs +} + const ( // PulsarMaxMessageSizeKey is the key of config item PulsarMaxMessageSizeKey = "maxMessageSize" diff --git a/pkg/util/funcutil/func_test.go b/pkg/util/funcutil/func_test.go index bc4279109e..aa2037071c 100644 --- a/pkg/util/funcutil/func_test.go +++ b/pkg/util/funcutil/func_test.go @@ -21,6 +21,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "reflect" "strconv" "testing" "time" @@ -384,3 +385,13 @@ func TestUserRoleCache(t *testing.T) { _, _, err = DecodeUserRoleCache("foo") assert.Error(t, err) } + +func TestMapToJSON(t *testing.T) { + s := `{"M": 30,"efConstruction": 360,"index_type": "HNSW", "metric_type": "IP"}` + m, err := JSONToMap(s) + assert.NoError(t, err) + j := MapToJSON(m) + got, err := JSONToMap(string(j)) + assert.NoError(t, err) + assert.True(t, reflect.DeepEqual(m, got)) +} diff --git a/pkg/util/indexparamcheck/base_checker.go b/pkg/util/indexparamcheck/base_checker.go index 88935e820c..7d8d8af18a 100644 --- a/pkg/util/indexparamcheck/base_checker.go +++ b/pkg/util/indexparamcheck/base_checker.go @@ -1,13 +1,14 @@ package indexparamcheck import ( + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/schemapb" ) type baseChecker struct { } -func (c *baseChecker) CheckTrain(params map[string]string) error { +func (c baseChecker) CheckTrain(params map[string]string) error { if !CheckIntByRange(params, DIM, DefaultMinDim, DefaultMaxDim) { return errOutOfRange(DIM, DefaultMinDim, DefaultMaxDim) } @@ -16,10 +17,16 @@ func (c *baseChecker) CheckTrain(params map[string]string) error { } // CheckValidDataType check whether the field data type is supported for the index type -func (c *baseChecker) CheckValidDataType(dType schemapb.DataType) error { +func (c baseChecker) CheckValidDataType(dType schemapb.DataType) error { return nil } +func (c baseChecker) SetDefaultMetricTypeIfNotExist(m map[string]string) {} + +func (c baseChecker) StaticCheck(params map[string]string) error { + return errors.New("unsupported index type") +} + func newBaseChecker() IndexChecker { return &baseChecker{} } diff --git a/pkg/util/indexparamcheck/base_checker_test.go b/pkg/util/indexparamcheck/base_checker_test.go index a1fe26521a..6e55107cbf 100644 --- a/pkg/util/indexparamcheck/base_checker_test.go +++ b/pkg/util/indexparamcheck/base_checker_test.go @@ -105,3 +105,8 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) { } } } + +func Test_baseChecker_StaticCheck(t *testing.T) { + // TODO + assert.Error(t, newBaseChecker().StaticCheck(nil)) +} diff --git a/pkg/util/indexparamcheck/bin_flat_checker.go b/pkg/util/indexparamcheck/bin_flat_checker.go index bd8bf886e4..2e0b813c38 100644 --- a/pkg/util/indexparamcheck/bin_flat_checker.go +++ b/pkg/util/indexparamcheck/bin_flat_checker.go @@ -4,11 +4,14 @@ type binFlatChecker struct { binaryVectorBaseChecker } -// CheckTrain checks if a binary flat index can be built with the specific parameters. -func (c *binFlatChecker) CheckTrain(params map[string]string) error { +func (c binFlatChecker) CheckTrain(params map[string]string) error { return c.binaryVectorBaseChecker.CheckTrain(params) } +func (c binFlatChecker) StaticCheck(params map[string]string) error { + return c.staticCheck(params) +} + func newBinFlatChecker() IndexChecker { return &binFlatChecker{} } diff --git a/pkg/util/indexparamcheck/bin_ivf_flat_checker.go b/pkg/util/indexparamcheck/bin_ivf_flat_checker.go index 2a804c1df4..dfcbc316a6 100644 --- a/pkg/util/indexparamcheck/bin_ivf_flat_checker.go +++ b/pkg/util/indexparamcheck/bin_ivf_flat_checker.go @@ -8,11 +8,7 @@ type binIVFFlatChecker struct { binaryVectorBaseChecker } -func (c *binIVFFlatChecker) CheckTrain(params map[string]string) error { - if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil { - return err - } - +func (c binIVFFlatChecker) StaticCheck(params map[string]string) error { if !CheckStrByValues(params, Metric, BinIvfMetrics) { return fmt.Errorf("metric type not found or not supported, supported: %v", BinIvfMetrics) } @@ -24,6 +20,14 @@ func (c *binIVFFlatChecker) CheckTrain(params map[string]string) error { return nil } +func (c binIVFFlatChecker) CheckTrain(params map[string]string) error { + if err := c.binaryVectorBaseChecker.CheckTrain(params); err != nil { + return err + } + + return c.StaticCheck(params) +} + func newBinIVFFlatChecker() IndexChecker { return &binIVFFlatChecker{} } diff --git a/pkg/util/indexparamcheck/binary_vector_base_checker.go b/pkg/util/indexparamcheck/binary_vector_base_checker.go index 74f84cb9af..89b64541c8 100644 --- a/pkg/util/indexparamcheck/binary_vector_base_checker.go +++ b/pkg/util/indexparamcheck/binary_vector_base_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus-proto/go-api/schemapb" ) @@ -10,11 +12,7 @@ type binaryVectorBaseChecker struct { baseChecker } -func (c *binaryVectorBaseChecker) CheckTrain(params map[string]string) error { - if err := c.baseChecker.CheckTrain(params); err != nil { - return err - } - +func (c binaryVectorBaseChecker) staticCheck(params map[string]string) error { if !CheckStrByValues(params, Metric, BinIDMapMetrics) { return fmt.Errorf("metric type not found or not supported, supported: %v", BinIDMapMetrics) } @@ -22,13 +20,25 @@ func (c *binaryVectorBaseChecker) CheckTrain(params map[string]string) error { return nil } -func (c *binaryVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error { +func (c binaryVectorBaseChecker) CheckTrain(params map[string]string) error { + if err := c.baseChecker.CheckTrain(params); err != nil { + return err + } + + return c.staticCheck(params) +} + +func (c binaryVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error { if dType != schemapb.DataType_BinaryVector { return fmt.Errorf("binary vector is only supported") } return nil } +func (c binaryVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string) { + setDefaultIfNotExist(params, common.MetricTypeKey, BinaryVectorDefaultMetricType) +} + func newBinaryVectorBaseChecker() IndexChecker { return &binaryVectorBaseChecker{} } diff --git a/pkg/util/indexparamcheck/constraints.go b/pkg/util/indexparamcheck/constraints.go index 0876d72d15..75be6d06e1 100644 --- a/pkg/util/indexparamcheck/constraints.go +++ b/pkg/util/indexparamcheck/constraints.go @@ -72,3 +72,8 @@ var BinIDMapMetrics = []string{HAMMING, JACCARD, TANIMOTO, SUBSTRUCTURE, SUPERST var BinIvfMetrics = []string{HAMMING, JACCARD, TANIMOTO} // const var supportDimPerSubQuantizer = []int{32, 28, 24, 20, 16, 12, 10, 8, 6, 4, 3, 2, 1} // const var supportSubQuantizer = []int{96, 64, 56, 48, 40, 32, 28, 24, 20, 16, 12, 8, 4, 3, 2, 1} // const + +const ( + FloatVectorDefaultMetricType = IP + BinaryVectorDefaultMetricType = JACCARD +) diff --git a/pkg/util/indexparamcheck/diskann_checker.go b/pkg/util/indexparamcheck/diskann_checker.go index 129754419f..05d581ef5a 100644 --- a/pkg/util/indexparamcheck/diskann_checker.go +++ b/pkg/util/indexparamcheck/diskann_checker.go @@ -5,11 +5,15 @@ type diskannChecker struct { floatVectorBaseChecker } -func (c *diskannChecker) CheckTrain(params map[string]string) error { +func (c diskannChecker) StaticCheck(params map[string]string) error { + return c.staticCheck(params) +} + +func (c diskannChecker) CheckTrain(params map[string]string) error { if !CheckIntByRange(params, DIM, DiskAnnMinDim, DefaultMaxDim) { return errOutOfRange(DIM, DiskAnnMinDim, DefaultMaxDim) } - return c.floatVectorBaseChecker.CheckTrain(params) + return c.StaticCheck(params) } func newDiskannChecker() IndexChecker { diff --git a/pkg/util/indexparamcheck/float_vector_base_checker.go b/pkg/util/indexparamcheck/float_vector_base_checker.go index f9d119decf..b442b85f81 100644 --- a/pkg/util/indexparamcheck/float_vector_base_checker.go +++ b/pkg/util/indexparamcheck/float_vector_base_checker.go @@ -3,6 +3,8 @@ package indexparamcheck import ( "fmt" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus-proto/go-api/schemapb" ) @@ -10,11 +12,7 @@ type floatVectorBaseChecker struct { baseChecker } -func (c *floatVectorBaseChecker) CheckTrain(params map[string]string) error { - if err := c.baseChecker.CheckTrain(params); err != nil { - return err - } - +func (c floatVectorBaseChecker) staticCheck(params map[string]string) error { if !CheckStrByValues(params, Metric, METRICS) { return fmt.Errorf("metric type not found or not supported, supported: %v", METRICS) } @@ -22,13 +20,25 @@ func (c *floatVectorBaseChecker) CheckTrain(params map[string]string) error { return nil } -func (c *floatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error { +func (c floatVectorBaseChecker) CheckTrain(params map[string]string) error { + if err := c.baseChecker.CheckTrain(params); err != nil { + return err + } + + return c.staticCheck(params) +} + +func (c floatVectorBaseChecker) CheckValidDataType(dType schemapb.DataType) error { if dType != schemapb.DataType_FloatVector { return fmt.Errorf("float vector is only supported") } return nil } +func (c floatVectorBaseChecker) SetDefaultMetricTypeIfNotExist(params map[string]string) { + setDefaultIfNotExist(params, common.MetricTypeKey, FloatVectorDefaultMetricType) +} + func newFloatVectorBaseChecker() IndexChecker { return &floatVectorBaseChecker{} } diff --git a/pkg/util/indexparamcheck/hnsw_checker.go b/pkg/util/indexparamcheck/hnsw_checker.go index 2ee3abc565..a47d744db0 100644 --- a/pkg/util/indexparamcheck/hnsw_checker.go +++ b/pkg/util/indexparamcheck/hnsw_checker.go @@ -4,7 +4,7 @@ type hnswChecker struct { floatVectorBaseChecker } -func (c *hnswChecker) CheckTrain(params map[string]string) error { +func (c hnswChecker) StaticCheck(params map[string]string) error { if !CheckIntByRange(params, EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) { return errOutOfRange(EFConstruction, HNSWMinEfConstruction, HNSWMaxEfConstruction) } @@ -13,6 +13,13 @@ func (c *hnswChecker) CheckTrain(params map[string]string) error { return errOutOfRange(HNSWM, HNSWMinM, HNSWMaxM) } + return c.floatVectorBaseChecker.staticCheck(params) +} + +func (c hnswChecker) CheckTrain(params map[string]string) error { + if err := c.StaticCheck(params); err != nil { + return err + } return c.floatVectorBaseChecker.CheckTrain(params) } diff --git a/pkg/util/indexparamcheck/index_checker.go b/pkg/util/indexparamcheck/index_checker.go index 7111dfcb2e..abaa08cf93 100644 --- a/pkg/util/indexparamcheck/index_checker.go +++ b/pkg/util/indexparamcheck/index_checker.go @@ -23,4 +23,6 @@ import ( type IndexChecker interface { CheckTrain(map[string]string) error CheckValidDataType(dType schemapb.DataType) error + SetDefaultMetricTypeIfNotExist(map[string]string) + StaticCheck(map[string]string) error } diff --git a/pkg/util/indexparamcheck/ivf_base_checker.go b/pkg/util/indexparamcheck/ivf_base_checker.go index 244de4761f..9b8a3e2e04 100644 --- a/pkg/util/indexparamcheck/ivf_base_checker.go +++ b/pkg/util/indexparamcheck/ivf_base_checker.go @@ -4,13 +4,20 @@ type ivfBaseChecker struct { floatVectorBaseChecker } -func (c *ivfBaseChecker) CheckTrain(params map[string]string) error { +func (c ivfBaseChecker) StaticCheck(params map[string]string) error { if !CheckIntByRange(params, NLIST, MinNList, MaxNList) { return errOutOfRange(NLIST, MinNList, MaxNList) } // skip check number of rows + return c.floatVectorBaseChecker.staticCheck(params) +} + +func (c ivfBaseChecker) CheckTrain(params map[string]string) error { + if err := c.StaticCheck(params); err != nil { + return err + } return c.floatVectorBaseChecker.CheckTrain(params) } diff --git a/pkg/util/indexparamcheck/utils.go b/pkg/util/indexparamcheck/utils.go index 1dad7fcf65..adca93aeb3 100644 --- a/pkg/util/indexparamcheck/utils.go +++ b/pkg/util/indexparamcheck/utils.go @@ -62,3 +62,10 @@ func CheckStrByValues(params map[string]string, key string, container []string) func errOutOfRange(x interface{}, lb interface{}, ub interface{}) error { return fmt.Errorf("%v out of range: [%v, %v]", x, lb, ub) } + +func setDefaultIfNotExist(params map[string]string, key string, defaultValue string) { + _, exist := params[key] + if !exist { + params[key] = defaultValue + } +} diff --git a/pkg/util/paramtable/autoindex_param.go b/pkg/util/paramtable/autoindex_param.go index 00a1c99514..e540edafc0 100644 --- a/pkg/util/paramtable/autoindex_param.go +++ b/pkg/util/paramtable/autoindex_param.go @@ -17,7 +17,13 @@ package paramtable import ( + "fmt" + + "github.com/milvus-io/milvus/pkg/config" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" ) // ///////////////////////////////////////////////////////////////////////////// @@ -41,8 +47,9 @@ func (p *autoIndexConfig) init(base *BaseTable) { p.Enable.Init(base.mgr) p.IndexParams = ParamItem{ - Key: "autoIndex.params.build", - Version: "2.2.0", + Key: "autoIndex.params.build", + Version: "2.2.0", + DefaultValue: `{"M": 30,"efConstruction": 360,"index_type": "HNSW", "metric_type": "IP"}`, } p.IndexParams.Init(base.mgr) @@ -69,4 +76,36 @@ func (p *autoIndexConfig) init(base *BaseTable) { Version: "2.2.0", } p.AutoIndexTypeName.Init(base.mgr) + + p.panicIfNotValidAndSetDefaultMetricType(base.mgr) +} + +func (p *autoIndexConfig) panicIfNotValidAndSetDefaultMetricType(mgr *config.Manager) { + m := p.IndexParams.GetAsJSONMap() + if m == nil { + panic("autoIndex.build not invalid, should be json format") + } + + indexType, ok := m[common.IndexTypeKey] + if !ok { + panic("autoIndex.build not invalid, index type not found") + } + + checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(indexType) + if err != nil { + panic(fmt.Sprintf("autoIndex.build not invalid, unsupported index type: %s", indexType)) + } + + checker.SetDefaultMetricTypeIfNotExist(m) + + if err := checker.StaticCheck(m); err != nil { + panic(fmt.Sprintf("autoIndex.build not invalid, parameters not invalid, error: %s", err.Error())) + } + + p.reset(m, mgr) +} + +func (p *autoIndexConfig) reset(m map[string]string, mgr *config.Manager) { + j := funcutil.MapToJSON(m) + mgr.SetConfig("autoIndex.params.build", string(j)) } diff --git a/pkg/util/paramtable/autoindex_param_test.go b/pkg/util/paramtable/autoindex_param_test.go index 7a4d26e7c1..a4fa5387d0 100644 --- a/pkg/util/paramtable/autoindex_param_test.go +++ b/pkg/util/paramtable/autoindex_param_test.go @@ -21,6 +21,10 @@ import ( "strconv" "testing" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" + + "github.com/milvus-io/milvus/pkg/config" + "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/pkg/common" @@ -90,3 +94,149 @@ func TestAutoIndexParams_build(t *testing.T) { // CParams.Save(CParams.AutoIndexConfig.IndexParams.Key, string(jsonStrBytes)) // }) } + +func Test_autoIndexConfig_panicIfNotValid(t *testing.T) { + t.Run("not in json format", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", "not in json format") + p := &autoIndexConfig{ + IndexParams: ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.Panics(t, func() { + p.panicIfNotValidAndSetDefaultMetricType(mgr) + }) + }) + + t.Run("index type not found", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"M": 30}`) + p := &autoIndexConfig{ + IndexParams: ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.Panics(t, func() { + p.panicIfNotValidAndSetDefaultMetricType(mgr) + }) + }) + + t.Run("unsupported index type", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"index_type": "not supported"}`) + p := &autoIndexConfig{ + IndexParams: ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.Panics(t, func() { + p.panicIfNotValidAndSetDefaultMetricType(mgr) + }) + }) + + t.Run("normal case, hnsw", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"M": 30,"efConstruction": 360,"index_type": "HNSW"}`) + p := &autoIndexConfig{ + IndexParams: ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.NotPanics(t, func() { + p.panicIfNotValidAndSetDefaultMetricType(mgr) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) + }) + + t.Run("normal case, ivf flat", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`) + p := &autoIndexConfig{ + IndexParams: ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.NotPanics(t, func() { + p.panicIfNotValidAndSetDefaultMetricType(mgr) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) + }) + + t.Run("normal case, ivf flat", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "IVF_FLAT"}`) + p := &autoIndexConfig{ + IndexParams: ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.NotPanics(t, func() { + p.panicIfNotValidAndSetDefaultMetricType(mgr) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) + }) + + t.Run("normal case, diskann", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"index_type": "DISKANN"}`) + p := &autoIndexConfig{ + IndexParams: ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.NotPanics(t, func() { + p.panicIfNotValidAndSetDefaultMetricType(mgr) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, indexparamcheck.FloatVectorDefaultMetricType, metricType) + }) + + t.Run("normal case, bin flat", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"index_type": "BIN_FLAT"}`) + p := &autoIndexConfig{ + IndexParams: ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.NotPanics(t, func() { + p.panicIfNotValidAndSetDefaultMetricType(mgr) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType) + }) + + t.Run("normal case, bin ivf flat", func(t *testing.T) { + mgr := config.NewManager() + mgr.SetConfig("autoIndex.params.build", `{"nlist": 30, "index_type": "BIN_IVF_FLAT"}`) + p := &autoIndexConfig{ + IndexParams: ParamItem{ + Key: "autoIndex.params.build", + }, + } + p.IndexParams.Init(mgr) + assert.NotPanics(t, func() { + p.panicIfNotValidAndSetDefaultMetricType(mgr) + }) + metricType, exist := p.IndexParams.GetAsJSONMap()[common.MetricTypeKey] + assert.True(t, exist) + assert.Equal(t, indexparamcheck.BinaryVectorDefaultMetricType, metricType) + }) +}