enhance: add unified vector index config management (#36846)

issue: #34298

Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
This commit is contained in:
foxspy 2024-11-01 06:18:21 +08:00 committed by GitHub
parent cd2655c861
commit 3224e58c5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 538 additions and 42 deletions

View File

@ -1066,3 +1066,15 @@ streaming:
backoffMultiplier: 2 # The multiplier of balance task trigger backoff, 2 by default
txn:
defaultKeepaliveTimeout: 10s # The default keepalive timeout for wal txn, 10s by default
# Any configuration related to the knowhere vector search engine
knowhere:
enable: true # When enabled, the index parameters defined below are automatically populated as index parameters, without requiring user input.
DISKANN: # Index parameters for diskann
build: # Diskann build params
max_degree: 56 # Maximum degree of the Vamana graph
search_list_size: 100 # Size of the candidate list during building graph
pq_code_budget_gb_ratio: 0.125 # Size limit on the PQ code (compared with raw data)
search_cache_budget_gb_ratio: 0.1 # Ratio of cached node numbers to raw data
search: # Diskann search params
beam_width_ratio: 4.0 # Ratio between the maximum number of IO requests per search iteration and CPU number.

View File

@ -29,10 +29,12 @@ import (
"github.com/milvus-io/milvus/internal/proto/workerpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -159,6 +161,19 @@ func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskSchedule
fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID)
binlogIDs := getBinLogIDs(segment, fieldID)
// When new index parameters are added, these parameters need to be updated to ensure they are included during the index-building process.
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() {
var ret error
indexParams, ret = Params.KnowhereConfig.UpdateIndexParams(GetIndexType(indexParams), paramtable.BuildStage, indexParams)
if ret != nil {
log.Ctx(ctx).Warn("failed to update index build params defined in yaml", zap.Int64("taskID", it.taskID), zap.Error(ret))
it.SetState(indexpb.JobState_JobStateInit, ret.Error())
return false
}
}
if isDiskANNIndex(GetIndexType(indexParams)) {
var err error
indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams)

View File

@ -210,6 +210,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
indexType := it.newIndexParams[common.IndexTypeKey]
var fieldDataSize uint64
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
// check index node support disk index
if !Params.IndexNodeCfg.EnableDisk.GetAsBool() {
@ -225,7 +226,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
log.Warn("IndexNode get local used size failed")
return err
}
fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
fieldDataSize, err = estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
@ -247,6 +248,11 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
}
}
// system resource-related parameters, such as memory limits, CPU limits, and disk limits, are appended here to the parameter list
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() {
it.newIndexParams, _ = Params.KnowhereConfig.MergeResourceParams(fieldDataSize, paramtable.BuildStage, it.newIndexParams)
}
storageConfig := &indexcgopb.StorageConfig{
Address: it.req.GetStorageConfig().GetAddress(),
AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(),

View File

@ -341,6 +341,14 @@ func (cit *createIndexTask) parseIndexParams() error {
if !exist {
return fmt.Errorf("IndexType not specified")
}
// index parameters defined in the YAML file are merged with the user-provided parameters during create stage
if Params.KnowhereConfig.Enable.GetAsBool() {
var err error
indexParamsMap, err = Params.KnowhereConfig.MergeIndexParams(indexType, paramtable.BuildStage, indexParamsMap)
if err != nil {
return err
}
}
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
err := indexparams.FillDiskIndexParams(Params, indexParamsMap)
if err != nil {

View File

@ -39,6 +39,7 @@ import (
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -1053,6 +1054,43 @@ func Test_parseIndexParams(t *testing.T) {
err := cit.parseIndexParams()
assert.Error(t, err)
})
t.Run("verify merge params with yaml", func(t *testing.T) {
paramtable.Init()
Params.Save("knowhere.HNSW.build.M", "3000")
Params.Save("knowhere.HNSW.build.efConstruction", "120")
defer Params.Reset("knowhere.HNSW.build.M")
defer Params.Reset("knowhere.HNSW.build.efConstruction")
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "HNSW",
},
{
Key: common.MetricTypeKey,
Value: metric.L2,
},
},
IndexName: "",
},
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "FieldVector",
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "768"},
},
},
}
err := cit.parseIndexParams()
// Out of range in json: param 'M' (3000) should be in range [2, 2048]
assert.Error(t, err)
})
}
func Test_wrapUserIndexParams(t *testing.T) {

View File

@ -20,6 +20,7 @@ import (
"strings"
"github.com/cockroachdb/errors"
"gopkg.in/yaml.v3"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -30,6 +31,10 @@ var (
ErrKeyNotFound = errors.New("key not found")
)
// NotFormatPrefix is the key prefix that is exempt from key normalization:
// lowerKey and formatKey return any key starting with it unchanged.
const (
NotFormatPrefix = "knowhere."
)
func Init(opts ...Option) (*Manager, error) {
o := &Options{}
for _, opt := range opts {
@ -55,7 +60,17 @@ func Init(opts ...Option) (*Manager, error) {
// formattedKeys caches formatKey results, keyed by the original (unformatted) key.
var formattedKeys = typeutil.NewConcurrentMap[string, string]()
// lowerKey returns key converted to lower case. Keys that start with
// NotFormatPrefix are returned as-is so that their original case is kept.
func lowerKey(key string) string {
	if !strings.HasPrefix(key, NotFormatPrefix) {
		return strings.ToLower(key)
	}
	return key
}
func formatKey(key string) string {
if strings.HasPrefix(key, NotFormatPrefix) {
return key
}
cached, ok := formattedKeys.Get(key)
if ok {
return cached
@ -64,3 +79,43 @@ func formatKey(key string) string {
formattedKeys.Insert(key, result)
return result
}
// flattenNode walks a YAML mapping node and writes every leaf it finds into
// result. The map key of a leaf is the dot-joined path of YAML keys leading to
// it, prefixed with parentKey when parentKey is non-empty. Every leaf is stored
// twice: under lowerKey(path) and under formatKey(path).
//
// Value handling by node kind:
//   - scalar: stored verbatim as a string;
//   - mapping: flattened recursively with the current path as parent;
//   - sequence: stored as a comma-separated string of its scalar items
//     (a non-scalar item contributes an empty element).
//
// Any other value kind is ignored, and a node that is not a mapping is a no-op.
func flattenNode(node *yaml.Node, parentKey string, result map[string]string) {
	if node.Kind != yaml.MappingNode {
		return
	}

	// A mapping node keeps its entries as alternating key/value children.
	for idx := 0; idx < len(node.Content); idx += 2 {
		k, v := node.Content[idx], node.Content[idx+1]

		path := k.Value
		if parentKey != "" {
			path = parentKey + "." + path
		}

		switch v.Kind {
		case yaml.MappingNode:
			flattenNode(v, path, result)
		case yaml.ScalarNode:
			result[lowerKey(path)] = v.Value
			result[formatKey(path)] = v.Value
		case yaml.SequenceNode:
			// Non-scalar items keep their slot as "" so the separators
			// still line up with the item positions.
			items := make([]string, len(v.Content))
			for j, elem := range v.Content {
				if elem.Kind == yaml.ScalarNode {
					items[j] = elem.Value
				}
			}
			joined := strings.Join(items, ",")
			result[lowerKey(path)] = joined
			result[formatKey(path)] = joined
		}
	}
}

View File

@ -17,14 +17,16 @@
package config
import (
"bytes"
"fmt"
"os"
"path/filepath"
"sync"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/spf13/cast"
"github.com/spf13/viper"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
"github.com/milvus-io/milvus/pkg/log"
)
@ -115,7 +117,6 @@ func (fs *FileSource) UpdateOptions(opts Options) {
}
func (fs *FileSource) loadFromFile() error {
yamlReader := viper.New()
newConfig := make(map[string]string)
var configFiles []string
@ -128,37 +129,35 @@ func (fs *FileSource) loadFromFile() error {
continue
}
yamlReader.SetConfigFile(configFile)
if err := yamlReader.ReadInConfig(); err != nil {
ext := filepath.Ext(configFile)
if len(ext) == 0 || ext[1:] != "yaml" {
return fmt.Errorf("Unsupported Config Type: " + ext)
}
data, err := os.ReadFile(configFile)
if err != nil {
return errors.Wrap(err, "Read config failed: "+configFile)
}
for _, key := range yamlReader.AllKeys() {
val := yamlReader.Get(key)
str, err := cast.ToStringE(val)
if err != nil {
switch val := val.(type) {
case []any:
str = str[:0]
for _, v := range val {
ss, err := cast.ToStringE(v)
if err != nil {
log.Warn("cast to string failed", zap.Any("value", v))
}
if str == "" {
str = ss
} else {
str = str + "," + ss
}
}
// handle empty file
if len(data) == 0 {
continue
}
default:
log.Warn("val is not a slice", zap.Any("value", val))
continue
}
}
newConfig[key] = str
newConfig[formatKey(key)] = str
var node yaml.Node
decoder := yaml.NewDecoder(bytes.NewReader(data))
if err := decoder.Decode(&node); err != nil {
return errors.Wrap(err, "YAML unmarshal failed: "+configFile)
}
if node.Kind == yaml.DocumentNode && len(node.Content) > 0 {
// Get the content of the Document Node
contentNode := node.Content[0]
// Recursively process the content of the Document Node
flattenNode(contentNode, "", newConfig)
} else if node.Kind == yaml.MappingNode {
flattenNode(&node, "", newConfig)
}
}

View File

@ -69,17 +69,18 @@ type ComponentParam struct {
GpuConfig gpuConfig
TraceCfg traceConfig
RootCoordCfg rootCoordConfig
ProxyCfg proxyConfig
QueryCoordCfg queryCoordConfig
QueryNodeCfg queryNodeConfig
DataCoordCfg dataCoordConfig
DataNodeCfg dataNodeConfig
IndexNodeCfg indexNodeConfig
HTTPCfg httpConfig
LogCfg logConfig
RoleCfg roleConfig
StreamingCfg streamingConfig
RootCoordCfg rootCoordConfig
ProxyCfg proxyConfig
QueryCoordCfg queryCoordConfig
QueryNodeCfg queryNodeConfig
DataCoordCfg dataCoordConfig
DataNodeCfg dataNodeConfig
IndexNodeCfg indexNodeConfig
KnowhereConfig knowhereConfig
HTTPCfg httpConfig
LogCfg logConfig
RoleCfg roleConfig
StreamingCfg streamingConfig
RootCoordGrpcServerCfg GrpcServerConfig
ProxyGrpcServerCfg GrpcServerConfig
@ -134,6 +135,7 @@ func (p *ComponentParam) init(bt *BaseTable) {
p.LogCfg.init(bt)
p.RoleCfg.init(bt)
p.GpuConfig.init(bt)
p.KnowhereConfig.init(bt)
p.RootCoordGrpcServerCfg.Init("rootCoord", bt)
p.ProxyGrpcServerCfg.Init("proxy", bt)

View File

@ -0,0 +1,118 @@
package paramtable
import (
"fmt"
"strconv"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/util/hardware"
)
// knowhereConfig holds the configuration items that live under the
// "knowhere." key prefix.
type knowhereConfig struct {
// Enable is bound to the "knowhere.enable" key; its default value is "true".
Enable ParamItem `refreshable:"true"`
// IndexParam groups every key that starts with "knowhere.".
IndexParam ParamGroup `refreshable:"true"`
}
// Stage names. getIndexParam uses a stage as the middle segment of an
// "<indexType>.<stage>.<param>" key when selecting configured parameters.
const (
BuildStage = "build"
LoadStage = "load"
SearchStage = "search"
)
// Keys of the resource-related entries that GetRuntimeParameter and
// MergeResourceParams write into an index parameter map.
const (
// BuildDramBudgetKey holds hardware.GetFreeMemoryCount() divided by 1<<30.
BuildDramBudgetKey = "build_dram_budget_gb"
// NumBuildThreadKey holds hardware.GetCPUNum().
NumBuildThreadKey = "num_build_thread"
// VecFieldSizeKey holds the vector field size divided by 1<<30.
VecFieldSizeKey = "vec_field_size_gb"
)
// init registers the knowhere configuration items with the manager of base:
// the "knowhere." parameter group first, then the "knowhere.enable" switch
// (default "true").
func (p *knowhereConfig) init(base *BaseTable) {
	mgr := base.mgr

	p.IndexParam = ParamGroup{
		KeyPrefix: "knowhere.",
		Version:   "2.5.0",
	}
	p.IndexParam.Init(mgr)

	p.Enable = ParamItem{
		Key:          "knowhere.enable",
		Version:      "2.5.0",
		DefaultValue: "true",
	}
	p.Enable.Init(mgr)
}
// getIndexParam collects the configured parameters of one index type and
// stage. It scans the entries returned by p.IndexParam.GetValue() for keys
// that start with "<indexType>.<stage>." and returns them with that prefix
// removed. The result is an empty, non-nil map when nothing matches.
func (p *knowhereConfig) getIndexParam(indexType string, stage string) map[string]string {
	wanted := indexType + "." + stage + "."
	selected := make(map[string]string)
	for name, value := range p.IndexParam.GetValue() {
		if !strings.HasPrefix(name, wanted) {
			continue
		}
		selected[name[len(wanted):]] = value
	}
	return selected
}
// GetKeyFromSlice returns the value of the first pair in indexParams whose
// key equals key. It returns "" when no pair matches, so a missing key and a
// key with an empty value are indistinguishable to the caller.
func GetKeyFromSlice(indexParams []*commonpb.KeyValuePair, key string) string {
	for i := range indexParams {
		if kv := indexParams[i]; kv.Key == key {
			return kv.Value
		}
	}
	return ""
}
// GetRuntimeParameter returns the resource parameters for stage.
//
// For BuildStage the map contains BuildDramBudgetKey (the free memory count
// divided by 1<<30, formatted with "%f") and NumBuildThreadKey (the CPU
// number). For every other stage the map is empty. The returned error is
// always nil.
func (p *knowhereConfig) GetRuntimeParameter(stage string) (map[string]string, error) {
	if stage != BuildStage {
		return make(map[string]string), nil
	}

	freeGB := float32(hardware.GetFreeMemoryCount()) / (1 << 30)
	threads := int(float32(hardware.GetCPUNum()))
	return map[string]string{
		BuildDramBudgetKey: fmt.Sprintf("%f", freeGB),
		NumBuildThreadKey:  strconv.Itoa(threads),
	}, nil
}
// UpdateIndexParams appends the configured parameters of indexType/stage to
// indexParams and returns the extended slice.
//
// A configured parameter is appended only when GetKeyFromSlice reports "" for
// its key, i.e. when the key is absent from indexParams or present with an
// empty value. Existing pairs are never modified. Because the configured
// parameters are iterated as a map, the order of the appended pairs is not
// fixed. The returned error is always nil.
func (p *knowhereConfig) UpdateIndexParams(indexType string, stage string, indexParams []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) {
	for name, fallback := range p.getIndexParam(indexType, stage) {
		if GetKeyFromSlice(indexParams, name) != "" {
			continue
		}
		indexParams = append(indexParams, &commonpb.KeyValuePair{Key: name, Value: fallback})
	}
	return indexParams, nil
}
// MergeIndexParams fills indexParam with the configured parameters of
// indexType/stage and returns it.
//
// Only keys that are missing from indexParam are added; a key that is already
// present keeps its value, even when that value is empty. The map is modified
// in place. A nil indexParam is accepted: a fresh map is allocated when there
// is at least one configured parameter to add (writing into a nil map would
// panic), otherwise nil is returned unchanged. The returned error is always
// nil.
func (p *knowhereConfig) MergeIndexParams(indexType string, stage string, indexParam map[string]string) (map[string]string, error) {
	defaultParams := p.getIndexParam(indexType, stage)

	// Guard against a nil destination before the first write.
	if indexParam == nil && len(defaultParams) > 0 {
		indexParam = make(map[string]string, len(defaultParams))
	}

	for key, val := range defaultParams {
		if _, existed := indexParam[key]; !existed {
			indexParam[key] = val
		}
	}

	return indexParam, nil
}
// MergeResourceParams writes the resource-related parameters for stage into
// indexParam and returns it.
//
// Every entry returned by GetRuntimeParameter(stage) overwrites the entry with
// the same key in indexParam. VecFieldSizeKey is then set to vecFieldSize
// divided by 1<<30, formatted with "%f". The map is modified in place; a nil
// indexParam is replaced by a freshly allocated map, because writing into a
// nil map would panic.
//
// If GetRuntimeParameter fails, its error is returned together with the
// unmodified indexParam, so that a caller which discards the error still keeps
// the parameters it passed in.
func (p *knowhereConfig) MergeResourceParams(vecFieldSize uint64, stage string, indexParam map[string]string) (map[string]string, error) {
	runtimeParams, err := p.GetRuntimeParameter(stage)
	if err != nil {
		return indexParam, err
	}

	if indexParam == nil {
		// +1 leaves room for VecFieldSizeKey.
		indexParam = make(map[string]string, len(runtimeParams)+1)
	}

	for key, val := range runtimeParams {
		indexParam[key] = val
	}
	indexParam[VecFieldSizeKey] = fmt.Sprintf("%f", float32(vecFieldSize)/(1<<30))

	return indexParam, nil
}

View File

@ -0,0 +1,243 @@
package paramtable
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)
// TestKnowhereConfig_GetIndexParam saves a few
// "knowhere.<indexType>.<stage>.<param>" values and checks that getIndexParam
// returns them keyed by the bare parameter name, and returns nothing for an
// index type that has no saved keys.
func TestKnowhereConfig_GetIndexParam(t *testing.T) {
	bt := NewBaseTable(SkipRemote(true))
	cfg := &knowhereConfig{}
	cfg.init(bt)

	// Seed the table with one parameter per index type.
	seed := map[string]interface{}{
		"knowhere.IVF_FLAT.build.nlist":       1024,
		"knowhere.HNSW.build.efConstruction":  360,
		"knowhere.DISKANN.search.search_list": 100,
	}
	for k, v := range seed {
		encoded, _ := json.Marshal(v)
		bt.Save(k, string(encoded))
	}

	type testCase struct {
		name          string
		indexType     string
		stage         string
		expectedKey   string
		expectedValue string
	}
	cases := []testCase{
		{name: "IVF_FLAT Build", indexType: "IVF_FLAT", stage: BuildStage, expectedKey: "nlist", expectedValue: "1024"},
		{name: "HNSW Build", indexType: "HNSW", stage: BuildStage, expectedKey: "efConstruction", expectedValue: "360"},
		{name: "DISKANN Search", indexType: "DISKANN", stage: SearchStage, expectedKey: "search_list", expectedValue: "100"},
		{name: "Non-existent", indexType: "NON_EXISTENT", stage: BuildStage, expectedKey: "", expectedValue: ""},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := cfg.getIndexParam(tc.indexType, tc.stage)

			// An empty expected key means "nothing should match".
			if tc.expectedKey == "" {
				assert.Empty(t, got, "The result should be empty for non-existent index type")
				return
			}
			assert.Contains(t, got, tc.expectedKey, "The result should contain the expected key")
			assert.Equal(t, tc.expectedValue, got[tc.expectedKey], "The value for the key should match the expected value")
		})
	}
}
// TestKnowhereConfig_GetRuntimeParameter checks that the build stage reports
// both resource keys and that the search stage reports nothing.
func TestKnowhereConfig_GetRuntimeParameter(t *testing.T) {
	cfg := &knowhereConfig{}

	buildParams, err := cfg.GetRuntimeParameter(BuildStage)
	assert.NoError(t, err)
	for _, key := range []string{BuildDramBudgetKey, NumBuildThreadKey} {
		assert.Contains(t, buildParams, key)
	}

	searchParams, err := cfg.GetRuntimeParameter(SearchStage)
	assert.NoError(t, err)
	assert.Empty(t, searchParams)
}
// TestKnowhereConfig_UpdateParameter checks that UpdateIndexParams keeps the
// pairs supplied by the caller and appends configured parameters that the
// caller did not supply.
func TestKnowhereConfig_UpdateParameter(t *testing.T) {
	bt := NewBaseTable(SkipRemote(true))
	cfg := &knowhereConfig{}
	cfg.init(bt)

	// Configured defaults for IVF_FLAT at build stage.
	seed := map[string]interface{}{
		"knowhere.IVF_FLAT.build.nlist":            1024,
		"knowhere.IVF_FLAT.build.num_build_thread": 12,
	}
	for k, v := range seed {
		encoded, _ := json.Marshal(v)
		bt.Save(k, string(encoded))
	}

	type testCase struct {
		name           string
		indexType      string
		stage          string
		inputParams    []*commonpb.KeyValuePair
		expectedParams map[string]string
	}
	cases := []testCase{
		{
			name:      "IVF_FLAT Build",
			indexType: "IVF_FLAT",
			stage:     BuildStage,
			inputParams: []*commonpb.KeyValuePair{
				{Key: "nlist", Value: "128"},
				{Key: "existing_key", Value: "existing_value"},
			},
			// The caller's nlist wins; num_build_thread comes from the config.
			expectedParams: map[string]string{
				"existing_key":     "existing_value",
				"nlist":            "128",
				"num_build_thread": "12",
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := cfg.UpdateIndexParams(tc.indexType, tc.stage, tc.inputParams)
			assert.NoError(t, err)

			for key, want := range tc.expectedParams {
				assert.Equal(t, want, GetKeyFromSlice(got, key), "The value for key %s should match the expected value", key)
			}
		})
	}
}
// TestKnowhereConfig_MergeParameter checks that MergeIndexParams keeps the
// entries supplied by the caller and adds configured parameters that the
// caller did not supply.
func TestKnowhereConfig_MergeParameter(t *testing.T) {
	bt := NewBaseTable(SkipRemote(true))
	cfg := &knowhereConfig{}
	cfg.init(bt)

	// Configured defaults for IVF_FLAT at build stage.
	seed := map[string]interface{}{
		"knowhere.IVF_FLAT.build.nlist":            1024,
		"knowhere.IVF_FLAT.build.num_build_thread": 12,
	}
	for k, v := range seed {
		encoded, _ := json.Marshal(v)
		bt.Save(k, string(encoded))
	}

	type testCase struct {
		name           string
		indexType      string
		stage          string
		inputParams    map[string]string
		expectedParams map[string]string
	}
	cases := []testCase{
		{
			name:      "IVF_FLAT Build",
			indexType: "IVF_FLAT",
			stage:     BuildStage,
			inputParams: map[string]string{
				"nlist":        "128",
				"existing_key": "existing_value",
			},
			// The caller's nlist wins; num_build_thread comes from the config.
			expectedParams: map[string]string{
				"existing_key":     "existing_value",
				"nlist":            "128",
				"num_build_thread": "12",
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := cfg.MergeIndexParams(tc.indexType, tc.stage, tc.inputParams)
			assert.NoError(t, err)

			for key, want := range tc.expectedParams {
				assert.Equal(t, want, got[key], "The value for key %s should match the expected value", key)
			}
		})
	}
}
// TestKnowhereConfig_MergeWithResource checks that MergeResourceParams keeps
// existing entries, adds the build-stage resource keys, and reports a 1 GiB
// vector field as "1.000000".
func TestKnowhereConfig_MergeWithResource(t *testing.T) {
	cfg := &knowhereConfig{}

	type testCase struct {
		name           string
		vecFieldSize   uint64
		inputParams    map[string]string
		expectedParams map[string]string
	}
	cases := []testCase{
		{
			name:         "Merge with resource",
			vecFieldSize: 1024 * 1024 * 1024,
			inputParams: map[string]string{
				"existing_key": "existing_value",
			},
			// An empty expected value means "must be present and non-empty":
			// the exact value depends on the machine running the test.
			expectedParams: map[string]string{
				"existing_key":     "existing_value",
				BuildDramBudgetKey: "",
				NumBuildThreadKey:  "",
				VecFieldSizeKey:    "1.000000",
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := cfg.MergeResourceParams(tc.vecFieldSize, BuildStage, tc.inputParams)
			assert.NoError(t, err)

			for key, want := range tc.expectedParams {
				if want == "" {
					assert.Contains(t, got, key, "The result should contain the key %s", key)
					assert.NotEmpty(t, got[key], "The value for key %s should not be empty", key)
					continue
				}
				assert.Equal(t, want, got[key], "The value for key %s should match the expected value", key)
			}
		})
	}
}
// TestGetKeyFromSlice checks lookups of present keys and the "" result for a
// key that is not in the slice.
func TestGetKeyFromSlice(t *testing.T) {
	pairs := []*commonpb.KeyValuePair{
		{Key: "key1", Value: "value1"},
		{Key: "key2", Value: "value2"},
	}

	lookups := []struct {
		key  string
		want string
	}{
		{key: "key1", want: "value1"},
		{key: "key2", want: "value2"},
		{key: "non_existent_key", want: ""},
	}
	for _, l := range lookups {
		assert.Equal(t, l.want, GetKeyFromSlice(pairs, l.key))
	}
}