enhance: add unify vector index config management (#36846)

issue: #34298

Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
This commit is contained in:
foxspy 2024-11-01 06:18:21 +08:00 committed by GitHub
parent cd2655c861
commit 3224e58c5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 538 additions and 42 deletions

View File

@ -1066,3 +1066,15 @@ streaming:
backoffMultiplier: 2 # The multiplier of balance task trigger backoff, 2 by default
txn:
defaultKeepaliveTimeout: 10s # The default keepalive timeout for wal txn, 10s by default
# Any configuration related to the knowhere vector search engine
knowhere:
enable: true # When enable this configuration, the index parameters defined following will be automatically populated as index parameters, without requiring user input.
DISKANN: # Index parameters for diskann
build: # Diskann build params
max_degree: 56 # Maximum degree of the Vamana graph
search_list_size: 100 # Size of the candidate list during building graph
pq_code_budget_gb_ratio: 0.125 # Size limit on the PQ code (compared with raw data)
search_cache_budget_gb_ratio: 0.1 # Ratio of cached node numbers to raw data
search: # Diskann search params
beam_width_ratio: 4.0 # Ratio between the maximum number of IO requests per search iteration and CPU number.

View File

@ -29,10 +29,12 @@ import (
"github.com/milvus-io/milvus/internal/proto/workerpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/indexparams"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -159,6 +161,19 @@ func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskSchedule
fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID)
binlogIDs := getBinLogIDs(segment, fieldID)
// When new index parameters are added, these parameters need to be updated to ensure they are included during the index-building process.
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() {
var ret error
indexParams, ret = Params.KnowhereConfig.UpdateIndexParams(GetIndexType(indexParams), paramtable.BuildStage, indexParams)
if ret != nil {
log.Ctx(ctx).Warn("failed to update index build params defined in yaml", zap.Int64("taskID", it.taskID), zap.Error(ret))
it.SetState(indexpb.JobState_JobStateInit, ret.Error())
return false
}
}
if isDiskANNIndex(GetIndexType(indexParams)) {
var err error
indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams)

View File

@ -210,6 +210,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
indexType := it.newIndexParams[common.IndexTypeKey]
var fieldDataSize uint64
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
// check index node support disk index
if !Params.IndexNodeCfg.EnableDisk.GetAsBool() {
@ -225,7 +226,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
log.Warn("IndexNode get local used size failed")
return err
}
fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
fieldDataSize, err = estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
if err != nil {
log.Warn("IndexNode get local used size failed")
return err
@ -247,6 +248,11 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
}
}
// system resource-related parameters, such as memory limits, CPU limits, and disk limits, are appended here to the parameter list
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() {
it.newIndexParams, _ = Params.KnowhereConfig.MergeResourceParams(fieldDataSize, paramtable.BuildStage, it.newIndexParams)
}
storageConfig := &indexcgopb.StorageConfig{
Address: it.req.GetStorageConfig().GetAddress(),
AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(),

View File

@ -341,6 +341,14 @@ func (cit *createIndexTask) parseIndexParams() error {
if !exist {
return fmt.Errorf("IndexType not specified")
}
// index parameters defined in the YAML file are merged with the user-provided parameters during create stage
if Params.KnowhereConfig.Enable.GetAsBool() {
var err error
indexParamsMap, err = Params.KnowhereConfig.MergeIndexParams(indexType, paramtable.BuildStage, indexParamsMap)
if err != nil {
return err
}
}
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
err := indexparams.FillDiskIndexParams(Params, indexParamsMap)
if err != nil {

View File

@ -39,6 +39,7 @@ import (
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -1053,6 +1054,43 @@ func Test_parseIndexParams(t *testing.T) {
err := cit.parseIndexParams()
assert.Error(t, err)
})
t.Run("verify merge params with yaml", func(t *testing.T) {
paramtable.Init()
Params.Save("knowhere.HNSW.build.M", "3000")
Params.Save("knowhere.HNSW.build.efConstruction", "120")
defer Params.Reset("knowhere.HNSW.build.M")
defer Params.Reset("knowhere.HNSW.build.efConstruction")
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "HNSW",
},
{
Key: common.MetricTypeKey,
Value: metric.L2,
},
},
IndexName: "",
},
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "FieldVector",
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "768"},
},
},
}
err := cit.parseIndexParams()
// Out of range in json: param 'M' (3000) should be in range [2, 2048]
assert.Error(t, err)
})
}
func Test_wrapUserIndexParams(t *testing.T) {

View File

@ -20,6 +20,7 @@ import (
"strings"
"github.com/cockroachdb/errors"
"gopkg.in/yaml.v3"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
@ -30,6 +31,10 @@ var (
ErrKeyNotFound = errors.New("key not found")
)
const (
NotFormatPrefix = "knowhere."
)
func Init(opts ...Option) (*Manager, error) {
o := &Options{}
for _, opt := range opts {
@ -55,7 +60,17 @@ func Init(opts ...Option) (*Manager, error) {
var formattedKeys = typeutil.NewConcurrentMap[string, string]()
func lowerKey(key string) string {
if strings.HasPrefix(key, NotFormatPrefix) {
return key
}
return strings.ToLower(key)
}
func formatKey(key string) string {
if strings.HasPrefix(key, NotFormatPrefix) {
return key
}
cached, ok := formattedKeys.Get(key)
if ok {
return cached
@ -64,3 +79,43 @@ func formatKey(key string) string {
formattedKeys.Insert(key, result)
return result
}
func flattenNode(node *yaml.Node, parentKey string, result map[string]string) {
// The content of the node should contain key-value pairs in a MappingNode
if node.Kind == yaml.MappingNode {
for i := 0; i < len(node.Content); i += 2 {
keyNode := node.Content[i]
valueNode := node.Content[i+1]
key := keyNode.Value
// Construct the full key with parent hierarchy
fullKey := key
if parentKey != "" {
fullKey = parentKey + "." + key
}
switch valueNode.Kind {
case yaml.ScalarNode:
// Scalar value, store it as a string
result[lowerKey(fullKey)] = valueNode.Value
result[formatKey(fullKey)] = valueNode.Value
case yaml.MappingNode:
// Nested map, process recursively
flattenNode(valueNode, fullKey, result)
case yaml.SequenceNode:
// List (sequence), process elements
var listStr string
for j, item := range valueNode.Content {
if j > 0 {
listStr += ","
}
if item.Kind == yaml.ScalarNode {
listStr += item.Value
}
}
result[lowerKey(fullKey)] = listStr
result[formatKey(fullKey)] = listStr
}
}
}
}

View File

@ -17,14 +17,16 @@
package config
import (
"bytes"
"fmt"
"os"
"path/filepath"
"sync"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/spf13/cast"
"github.com/spf13/viper"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
"github.com/milvus-io/milvus/pkg/log"
)
@ -115,7 +117,6 @@ func (fs *FileSource) UpdateOptions(opts Options) {
}
func (fs *FileSource) loadFromFile() error {
yamlReader := viper.New()
newConfig := make(map[string]string)
var configFiles []string
@ -128,37 +129,35 @@ func (fs *FileSource) loadFromFile() error {
continue
}
yamlReader.SetConfigFile(configFile)
if err := yamlReader.ReadInConfig(); err != nil {
ext := filepath.Ext(configFile)
if len(ext) == 0 || ext[1:] != "yaml" {
return fmt.Errorf("Unsupported Config Type: " + ext)
}
data, err := os.ReadFile(configFile)
if err != nil {
return errors.Wrap(err, "Read config failed: "+configFile)
}
for _, key := range yamlReader.AllKeys() {
val := yamlReader.Get(key)
str, err := cast.ToStringE(val)
if err != nil {
switch val := val.(type) {
case []any:
str = str[:0]
for _, v := range val {
ss, err := cast.ToStringE(v)
if err != nil {
log.Warn("cast to string failed", zap.Any("value", v))
}
if str == "" {
str = ss
} else {
str = str + "," + ss
}
}
default:
log.Warn("val is not a slice", zap.Any("value", val))
// handle empty file
if len(data) == 0 {
continue
}
var node yaml.Node
decoder := yaml.NewDecoder(bytes.NewReader(data))
if err := decoder.Decode(&node); err != nil {
return errors.Wrap(err, "YAML unmarshal failed: "+configFile)
}
newConfig[key] = str
newConfig[formatKey(key)] = str
if node.Kind == yaml.DocumentNode && len(node.Content) > 0 {
// Get the content of the Document Node
contentNode := node.Content[0]
// Recursively process the content of the Document Node
flattenNode(contentNode, "", newConfig)
} else if node.Kind == yaml.MappingNode {
flattenNode(&node, "", newConfig)
}
}

View File

@ -76,6 +76,7 @@ type ComponentParam struct {
DataCoordCfg dataCoordConfig
DataNodeCfg dataNodeConfig
IndexNodeCfg indexNodeConfig
KnowhereConfig knowhereConfig
HTTPCfg httpConfig
LogCfg logConfig
RoleCfg roleConfig
@ -134,6 +135,7 @@ func (p *ComponentParam) init(bt *BaseTable) {
p.LogCfg.init(bt)
p.RoleCfg.init(bt)
p.GpuConfig.init(bt)
p.KnowhereConfig.init(bt)
p.RootCoordGrpcServerCfg.Init("rootCoord", bt)
p.ProxyGrpcServerCfg.Init("proxy", bt)

View File

@ -0,0 +1,118 @@
package paramtable
import (
"fmt"
"strconv"
"strings"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/pkg/util/hardware"
)
type knowhereConfig struct {
Enable ParamItem `refreshable:"true"`
IndexParam ParamGroup `refreshable:"true"`
}
const (
BuildStage = "build"
LoadStage = "load"
SearchStage = "search"
)
const (
BuildDramBudgetKey = "build_dram_budget_gb"
NumBuildThreadKey = "num_build_thread"
VecFieldSizeKey = "vec_field_size_gb"
)
func (p *knowhereConfig) init(base *BaseTable) {
p.IndexParam = ParamGroup{
KeyPrefix: "knowhere.",
Version: "2.5.0",
}
p.IndexParam.Init(base.mgr)
p.Enable = ParamItem{
Key: "knowhere.enable",
Version: "2.5.0",
DefaultValue: "true",
}
p.Enable.Init(base.mgr)
}
func (p *knowhereConfig) getIndexParam(indexType string, stage string) map[string]string {
matchedParam := make(map[string]string)
params := p.IndexParam.GetValue()
prefix := indexType + "." + stage + "."
for k, v := range params {
if strings.HasPrefix(k, prefix) {
matchedParam[strings.TrimPrefix(k, prefix)] = v
}
}
return matchedParam
}
func GetKeyFromSlice(indexParams []*commonpb.KeyValuePair, key string) string {
for _, param := range indexParams {
if param.Key == key {
return param.Value
}
}
return ""
}
func (p *knowhereConfig) GetRuntimeParameter(stage string) (map[string]string, error) {
params := make(map[string]string)
if stage == BuildStage {
params[BuildDramBudgetKey] = fmt.Sprintf("%f", float32(hardware.GetFreeMemoryCount())/(1<<30))
params[NumBuildThreadKey] = strconv.Itoa(int(float32(hardware.GetCPUNum())))
}
return params, nil
}
func (p *knowhereConfig) UpdateIndexParams(indexType string, stage string, indexParams []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) {
defaultParams := p.getIndexParam(indexType, stage)
for key, val := range defaultParams {
if GetKeyFromSlice(indexParams, key) == "" {
indexParams = append(indexParams,
&commonpb.KeyValuePair{
Key: key,
Value: val,
})
}
}
return indexParams, nil
}
func (p *knowhereConfig) MergeIndexParams(indexType string, stage string, indexParam map[string]string) (map[string]string, error) {
defaultParams := p.getIndexParam(indexType, stage)
for key, val := range defaultParams {
_, existed := indexParam[key]
if !existed {
indexParam[key] = val
}
}
return indexParam, nil
}
func (p *knowhereConfig) MergeResourceParams(vecFieldSize uint64, stage string, indexParam map[string]string) (map[string]string, error) {
param, _ := p.GetRuntimeParameter(stage)
for key, val := range param {
indexParam[key] = val
}
indexParam[VecFieldSizeKey] = fmt.Sprintf("%f", float32(vecFieldSize)/(1<<30))
return indexParam, nil
}

View File

@ -0,0 +1,243 @@
package paramtable
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)
func TestKnowhereConfig_GetIndexParam(t *testing.T) {
bt := NewBaseTable(SkipRemote(true))
cfg := &knowhereConfig{}
cfg.init(bt)
// Set some initial config
indexParams := map[string]interface{}{
"knowhere.IVF_FLAT.build.nlist": 1024,
"knowhere.HNSW.build.efConstruction": 360,
"knowhere.DISKANN.search.search_list": 100,
}
for key, val := range indexParams {
valStr, _ := json.Marshal(val)
bt.Save(key, string(valStr))
}
tests := []struct {
name string
indexType string
stage string
expectedKey string
expectedValue string
}{
{
name: "IVF_FLAT Build",
indexType: "IVF_FLAT",
stage: BuildStage,
expectedKey: "nlist",
expectedValue: "1024",
},
{
name: "HNSW Build",
indexType: "HNSW",
stage: BuildStage,
expectedKey: "efConstruction",
expectedValue: "360",
},
{
name: "DISKANN Search",
indexType: "DISKANN",
stage: SearchStage,
expectedKey: "search_list",
expectedValue: "100",
},
{
name: "Non-existent",
indexType: "NON_EXISTENT",
stage: BuildStage,
expectedKey: "",
expectedValue: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := cfg.getIndexParam(tt.indexType, tt.stage)
if tt.expectedKey != "" {
assert.Contains(t, result, tt.expectedKey, "The result should contain the expected key")
assert.Equal(t, tt.expectedValue, result[tt.expectedKey], "The value for the key should match the expected value")
} else {
assert.Empty(t, result, "The result should be empty for non-existent index type")
}
})
}
}
func TestKnowhereConfig_GetRuntimeParameter(t *testing.T) {
cfg := &knowhereConfig{}
params, err := cfg.GetRuntimeParameter(BuildStage)
assert.NoError(t, err)
assert.Contains(t, params, BuildDramBudgetKey)
assert.Contains(t, params, NumBuildThreadKey)
params, err = cfg.GetRuntimeParameter(SearchStage)
assert.NoError(t, err)
assert.Empty(t, params)
}
func TestKnowhereConfig_UpdateParameter(t *testing.T) {
bt := NewBaseTable(SkipRemote(true))
cfg := &knowhereConfig{}
cfg.init(bt)
// Set some initial config
indexParams := map[string]interface{}{
"knowhere.IVF_FLAT.build.nlist": 1024,
"knowhere.IVF_FLAT.build.num_build_thread": 12,
}
for key, val := range indexParams {
valStr, _ := json.Marshal(val)
bt.Save(key, string(valStr))
}
tests := []struct {
name string
indexType string
stage string
inputParams []*commonpb.KeyValuePair
expectedParams map[string]string
}{
{
name: "IVF_FLAT Build",
indexType: "IVF_FLAT",
stage: BuildStage,
inputParams: []*commonpb.KeyValuePair{
{Key: "nlist", Value: "128"},
{Key: "existing_key", Value: "existing_value"},
},
expectedParams: map[string]string{
"existing_key": "existing_value",
"nlist": "128",
"num_build_thread": "12",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := cfg.UpdateIndexParams(tt.indexType, tt.stage, tt.inputParams)
assert.NoError(t, err)
for key, expectedValue := range tt.expectedParams {
assert.Equal(t, expectedValue, GetKeyFromSlice(result, key), "The value for key %s should match the expected value", key)
}
})
}
}
func TestKnowhereConfig_MergeParameter(t *testing.T) {
bt := NewBaseTable(SkipRemote(true))
cfg := &knowhereConfig{}
cfg.init(bt)
indexParams := map[string]interface{}{
"knowhere.IVF_FLAT.build.nlist": 1024,
"knowhere.IVF_FLAT.build.num_build_thread": 12,
}
for key, val := range indexParams {
valStr, _ := json.Marshal(val)
bt.Save(key, string(valStr))
}
tests := []struct {
name string
indexType string
stage string
inputParams map[string]string
expectedParams map[string]string
}{
{
name: "IVF_FLAT Build",
indexType: "IVF_FLAT",
stage: BuildStage,
inputParams: map[string]string{
"nlist": "128",
"existing_key": "existing_value",
},
expectedParams: map[string]string{
"existing_key": "existing_value",
"nlist": "128",
"num_build_thread": "12",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := cfg.MergeIndexParams(tt.indexType, tt.stage, tt.inputParams)
assert.NoError(t, err)
for key, expectedValue := range tt.expectedParams {
assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key)
}
})
}
}
func TestKnowhereConfig_MergeWithResource(t *testing.T) {
cfg := &knowhereConfig{}
tests := []struct {
name string
vecFieldSize uint64
inputParams map[string]string
expectedParams map[string]string
}{
{
name: "Merge with resource",
vecFieldSize: 1024 * 1024 * 1024,
inputParams: map[string]string{
"existing_key": "existing_value",
},
expectedParams: map[string]string{
"existing_key": "existing_value",
BuildDramBudgetKey: "", // We can't predict the exact value, but it should exist
NumBuildThreadKey: "", // We can't predict the exact value, but it should exist
VecFieldSizeKey: "1.000000",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := cfg.MergeResourceParams(tt.vecFieldSize, BuildStage, tt.inputParams)
assert.NoError(t, err)
for key, expectedValue := range tt.expectedParams {
if expectedValue != "" {
assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key)
} else {
assert.Contains(t, result, key, "The result should contain the key %s", key)
assert.NotEmpty(t, result[key], "The value for key %s should not be empty", key)
}
}
})
}
}
func TestGetKeyFromSlice(t *testing.T) {
indexParams := []*commonpb.KeyValuePair{
{Key: "key1", Value: "value1"},
{Key: "key2", Value: "value2"},
}
assert.Equal(t, "value1", GetKeyFromSlice(indexParams, "key1"))
assert.Equal(t, "value2", GetKeyFromSlice(indexParams, "key2"))
assert.Equal(t, "", GetKeyFromSlice(indexParams, "non_existent_key"))
}