mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 01:58:34 +08:00
enhance: add unify vector index config management (#36846)
issue: #34298 Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
This commit is contained in:
parent
cd2655c861
commit
3224e58c5b
@ -1066,3 +1066,15 @@ streaming:
|
|||||||
backoffMultiplier: 2 # The multiplier of balance task trigger backoff, 2 by default
|
backoffMultiplier: 2 # The multiplier of balance task trigger backoff, 2 by default
|
||||||
txn:
|
txn:
|
||||||
defaultKeepaliveTimeout: 10s # The default keepalive timeout for wal txn, 10s by default
|
defaultKeepaliveTimeout: 10s # The default keepalive timeout for wal txn, 10s by default
|
||||||
|
|
||||||
|
# Any configuration related to the knowhere vector search engine
|
||||||
|
knowhere:
|
||||||
|
enable: true # When enable this configuration, the index parameters defined following will be automatically populated as index parameters, without requiring user input.
|
||||||
|
DISKANN: # Index parameters for diskann
|
||||||
|
build: # Diskann build params
|
||||||
|
max_degree: 56 # Maximum degree of the Vamana graph
|
||||||
|
search_list_size: 100 # Size of the candidate list during building graph
|
||||||
|
pq_code_budget_gb_ratio: 0.125 # Size limit on the PQ code (compared with raw data)
|
||||||
|
search_cache_budget_gb_ratio: 0.1 # Ratio of cached node numbers to raw data
|
||||||
|
search: # Diskann search params
|
||||||
|
beam_width_ratio: 4.0 # Ratio between the maximum number of IO requests per search iteration and CPU number.
|
||||||
@ -29,10 +29,12 @@ import (
|
|||||||
"github.com/milvus-io/milvus/internal/proto/workerpb"
|
"github.com/milvus-io/milvus/internal/proto/workerpb"
|
||||||
"github.com/milvus-io/milvus/internal/storage"
|
"github.com/milvus-io/milvus/internal/storage"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -159,6 +161,19 @@ func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskSchedule
|
|||||||
|
|
||||||
fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID)
|
fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID)
|
||||||
binlogIDs := getBinLogIDs(segment, fieldID)
|
binlogIDs := getBinLogIDs(segment, fieldID)
|
||||||
|
|
||||||
|
// When new index parameters are added, these parameters need to be updated to ensure they are included during the index-building process.
|
||||||
|
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() {
|
||||||
|
var ret error
|
||||||
|
indexParams, ret = Params.KnowhereConfig.UpdateIndexParams(GetIndexType(indexParams), paramtable.BuildStage, indexParams)
|
||||||
|
|
||||||
|
if ret != nil {
|
||||||
|
log.Ctx(ctx).Warn("failed to update index build params defined in yaml", zap.Int64("taskID", it.taskID), zap.Error(ret))
|
||||||
|
it.SetState(indexpb.JobState_JobStateInit, ret.Error())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if isDiskANNIndex(GetIndexType(indexParams)) {
|
if isDiskANNIndex(GetIndexType(indexParams)) {
|
||||||
var err error
|
var err error
|
||||||
indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams)
|
indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams)
|
||||||
|
|||||||
@ -210,6 +210,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
|
|||||||
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
|
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
|
||||||
|
|
||||||
indexType := it.newIndexParams[common.IndexTypeKey]
|
indexType := it.newIndexParams[common.IndexTypeKey]
|
||||||
|
var fieldDataSize uint64
|
||||||
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
|
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
|
||||||
// check index node support disk index
|
// check index node support disk index
|
||||||
if !Params.IndexNodeCfg.EnableDisk.GetAsBool() {
|
if !Params.IndexNodeCfg.EnableDisk.GetAsBool() {
|
||||||
@ -225,7 +226,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
|
|||||||
log.Warn("IndexNode get local used size failed")
|
log.Warn("IndexNode get local used size failed")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
|
fieldDataSize, err = estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("IndexNode get local used size failed")
|
log.Warn("IndexNode get local used size failed")
|
||||||
return err
|
return err
|
||||||
@ -247,6 +248,11 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// system resource-related parameters, such as memory limits, CPU limits, and disk limits, are appended here to the parameter list
|
||||||
|
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() {
|
||||||
|
it.newIndexParams, _ = Params.KnowhereConfig.MergeResourceParams(fieldDataSize, paramtable.BuildStage, it.newIndexParams)
|
||||||
|
}
|
||||||
|
|
||||||
storageConfig := &indexcgopb.StorageConfig{
|
storageConfig := &indexcgopb.StorageConfig{
|
||||||
Address: it.req.GetStorageConfig().GetAddress(),
|
Address: it.req.GetStorageConfig().GetAddress(),
|
||||||
AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(),
|
AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(),
|
||||||
|
|||||||
@ -341,6 +341,14 @@ func (cit *createIndexTask) parseIndexParams() error {
|
|||||||
if !exist {
|
if !exist {
|
||||||
return fmt.Errorf("IndexType not specified")
|
return fmt.Errorf("IndexType not specified")
|
||||||
}
|
}
|
||||||
|
// index parameters defined in the YAML file are merged with the user-provided parameters during create stage
|
||||||
|
if Params.KnowhereConfig.Enable.GetAsBool() {
|
||||||
|
var err error
|
||||||
|
indexParamsMap, err = Params.KnowhereConfig.MergeIndexParams(indexType, paramtable.BuildStage, indexParamsMap)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
|
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
|
||||||
err := indexparams.FillDiskIndexParams(Params, indexParamsMap)
|
err := indexparams.FillDiskIndexParams(Params, indexParamsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -39,6 +39,7 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/common"
|
"github.com/milvus-io/milvus/pkg/common"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
@ -1053,6 +1054,43 @@ func Test_parseIndexParams(t *testing.T) {
|
|||||||
err := cit.parseIndexParams()
|
err := cit.parseIndexParams()
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("verify merge params with yaml", func(t *testing.T) {
|
||||||
|
paramtable.Init()
|
||||||
|
Params.Save("knowhere.HNSW.build.M", "3000")
|
||||||
|
Params.Save("knowhere.HNSW.build.efConstruction", "120")
|
||||||
|
defer Params.Reset("knowhere.HNSW.build.M")
|
||||||
|
defer Params.Reset("knowhere.HNSW.build.efConstruction")
|
||||||
|
|
||||||
|
cit := &createIndexTask{
|
||||||
|
Condition: nil,
|
||||||
|
req: &milvuspb.CreateIndexRequest{
|
||||||
|
ExtraParams: []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: common.IndexTypeKey,
|
||||||
|
Value: "HNSW",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: common.MetricTypeKey,
|
||||||
|
Value: metric.L2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IndexName: "",
|
||||||
|
},
|
||||||
|
fieldSchema: &schemapb.FieldSchema{
|
||||||
|
FieldID: 101,
|
||||||
|
Name: "FieldVector",
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
DataType: schemapb.DataType_FloatVector,
|
||||||
|
TypeParams: []*commonpb.KeyValuePair{
|
||||||
|
{Key: common.DimKey, Value: "768"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := cit.parseIndexParams()
|
||||||
|
// Out of range in json: param 'M' (3000) should be in range [2, 2048]
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_wrapUserIndexParams(t *testing.T) {
|
func Test_wrapUserIndexParams(t *testing.T) {
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||||
)
|
)
|
||||||
@ -30,6 +31,10 @@ var (
|
|||||||
ErrKeyNotFound = errors.New("key not found")
|
ErrKeyNotFound = errors.New("key not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NotFormatPrefix = "knowhere."
|
||||||
|
)
|
||||||
|
|
||||||
func Init(opts ...Option) (*Manager, error) {
|
func Init(opts ...Option) (*Manager, error) {
|
||||||
o := &Options{}
|
o := &Options{}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
@ -55,7 +60,17 @@ func Init(opts ...Option) (*Manager, error) {
|
|||||||
|
|
||||||
var formattedKeys = typeutil.NewConcurrentMap[string, string]()
|
var formattedKeys = typeutil.NewConcurrentMap[string, string]()
|
||||||
|
|
||||||
|
func lowerKey(key string) string {
|
||||||
|
if strings.HasPrefix(key, NotFormatPrefix) {
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
return strings.ToLower(key)
|
||||||
|
}
|
||||||
|
|
||||||
func formatKey(key string) string {
|
func formatKey(key string) string {
|
||||||
|
if strings.HasPrefix(key, NotFormatPrefix) {
|
||||||
|
return key
|
||||||
|
}
|
||||||
cached, ok := formattedKeys.Get(key)
|
cached, ok := formattedKeys.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
return cached
|
return cached
|
||||||
@ -64,3 +79,43 @@ func formatKey(key string) string {
|
|||||||
formattedKeys.Insert(key, result)
|
formattedKeys.Insert(key, result)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func flattenNode(node *yaml.Node, parentKey string, result map[string]string) {
|
||||||
|
// The content of the node should contain key-value pairs in a MappingNode
|
||||||
|
if node.Kind == yaml.MappingNode {
|
||||||
|
for i := 0; i < len(node.Content); i += 2 {
|
||||||
|
keyNode := node.Content[i]
|
||||||
|
valueNode := node.Content[i+1]
|
||||||
|
|
||||||
|
key := keyNode.Value
|
||||||
|
// Construct the full key with parent hierarchy
|
||||||
|
fullKey := key
|
||||||
|
if parentKey != "" {
|
||||||
|
fullKey = parentKey + "." + key
|
||||||
|
}
|
||||||
|
|
||||||
|
switch valueNode.Kind {
|
||||||
|
case yaml.ScalarNode:
|
||||||
|
// Scalar value, store it as a string
|
||||||
|
result[lowerKey(fullKey)] = valueNode.Value
|
||||||
|
result[formatKey(fullKey)] = valueNode.Value
|
||||||
|
case yaml.MappingNode:
|
||||||
|
// Nested map, process recursively
|
||||||
|
flattenNode(valueNode, fullKey, result)
|
||||||
|
case yaml.SequenceNode:
|
||||||
|
// List (sequence), process elements
|
||||||
|
var listStr string
|
||||||
|
for j, item := range valueNode.Content {
|
||||||
|
if j > 0 {
|
||||||
|
listStr += ","
|
||||||
|
}
|
||||||
|
if item.Kind == yaml.ScalarNode {
|
||||||
|
listStr += item.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result[lowerKey(fullKey)] = listStr
|
||||||
|
result[formatKey(fullKey)] = listStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -17,14 +17,16 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
"github.com/samber/lo"
|
"github.com/samber/lo"
|
||||||
"github.com/spf13/cast"
|
|
||||||
"github.com/spf13/viper"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/pkg/log"
|
"github.com/milvus-io/milvus/pkg/log"
|
||||||
)
|
)
|
||||||
@ -115,7 +117,6 @@ func (fs *FileSource) UpdateOptions(opts Options) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (fs *FileSource) loadFromFile() error {
|
func (fs *FileSource) loadFromFile() error {
|
||||||
yamlReader := viper.New()
|
|
||||||
newConfig := make(map[string]string)
|
newConfig := make(map[string]string)
|
||||||
var configFiles []string
|
var configFiles []string
|
||||||
|
|
||||||
@ -128,37 +129,35 @@ func (fs *FileSource) loadFromFile() error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
yamlReader.SetConfigFile(configFile)
|
ext := filepath.Ext(configFile)
|
||||||
if err := yamlReader.ReadInConfig(); err != nil {
|
if len(ext) == 0 || ext[1:] != "yaml" {
|
||||||
|
return fmt.Errorf("Unsupported Config Type: " + ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configFile)
|
||||||
|
if err != nil {
|
||||||
return errors.Wrap(err, "Read config failed: "+configFile)
|
return errors.Wrap(err, "Read config failed: "+configFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range yamlReader.AllKeys() {
|
// handle empty file
|
||||||
val := yamlReader.Get(key)
|
if len(data) == 0 {
|
||||||
str, err := cast.ToStringE(val)
|
|
||||||
if err != nil {
|
|
||||||
switch val := val.(type) {
|
|
||||||
case []any:
|
|
||||||
str = str[:0]
|
|
||||||
for _, v := range val {
|
|
||||||
ss, err := cast.ToStringE(v)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn("cast to string failed", zap.Any("value", v))
|
|
||||||
}
|
|
||||||
if str == "" {
|
|
||||||
str = ss
|
|
||||||
} else {
|
|
||||||
str = str + "," + ss
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
log.Warn("val is not a slice", zap.Any("value", val))
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var node yaml.Node
|
||||||
|
decoder := yaml.NewDecoder(bytes.NewReader(data))
|
||||||
|
if err := decoder.Decode(&node); err != nil {
|
||||||
|
return errors.Wrap(err, "YAML unmarshal failed: "+configFile)
|
||||||
}
|
}
|
||||||
newConfig[key] = str
|
|
||||||
newConfig[formatKey(key)] = str
|
if node.Kind == yaml.DocumentNode && len(node.Content) > 0 {
|
||||||
|
// Get the content of the Document Node
|
||||||
|
contentNode := node.Content[0]
|
||||||
|
|
||||||
|
// Recursively process the content of the Document Node
|
||||||
|
flattenNode(contentNode, "", newConfig)
|
||||||
|
} else if node.Kind == yaml.MappingNode {
|
||||||
|
flattenNode(&node, "", newConfig)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -76,6 +76,7 @@ type ComponentParam struct {
|
|||||||
DataCoordCfg dataCoordConfig
|
DataCoordCfg dataCoordConfig
|
||||||
DataNodeCfg dataNodeConfig
|
DataNodeCfg dataNodeConfig
|
||||||
IndexNodeCfg indexNodeConfig
|
IndexNodeCfg indexNodeConfig
|
||||||
|
KnowhereConfig knowhereConfig
|
||||||
HTTPCfg httpConfig
|
HTTPCfg httpConfig
|
||||||
LogCfg logConfig
|
LogCfg logConfig
|
||||||
RoleCfg roleConfig
|
RoleCfg roleConfig
|
||||||
@ -134,6 +135,7 @@ func (p *ComponentParam) init(bt *BaseTable) {
|
|||||||
p.LogCfg.init(bt)
|
p.LogCfg.init(bt)
|
||||||
p.RoleCfg.init(bt)
|
p.RoleCfg.init(bt)
|
||||||
p.GpuConfig.init(bt)
|
p.GpuConfig.init(bt)
|
||||||
|
p.KnowhereConfig.init(bt)
|
||||||
|
|
||||||
p.RootCoordGrpcServerCfg.Init("rootCoord", bt)
|
p.RootCoordGrpcServerCfg.Init("rootCoord", bt)
|
||||||
p.ProxyGrpcServerCfg.Init("proxy", bt)
|
p.ProxyGrpcServerCfg.Init("proxy", bt)
|
||||||
|
|||||||
118
pkg/util/paramtable/knowhere_param.go
Normal file
118
pkg/util/paramtable/knowhere_param.go
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
package paramtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
|
"github.com/milvus-io/milvus/pkg/util/hardware"
|
||||||
|
)
|
||||||
|
|
||||||
|
type knowhereConfig struct {
|
||||||
|
Enable ParamItem `refreshable:"true"`
|
||||||
|
IndexParam ParamGroup `refreshable:"true"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
BuildStage = "build"
|
||||||
|
LoadStage = "load"
|
||||||
|
SearchStage = "search"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
BuildDramBudgetKey = "build_dram_budget_gb"
|
||||||
|
NumBuildThreadKey = "num_build_thread"
|
||||||
|
VecFieldSizeKey = "vec_field_size_gb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *knowhereConfig) init(base *BaseTable) {
|
||||||
|
p.IndexParam = ParamGroup{
|
||||||
|
KeyPrefix: "knowhere.",
|
||||||
|
Version: "2.5.0",
|
||||||
|
}
|
||||||
|
p.IndexParam.Init(base.mgr)
|
||||||
|
|
||||||
|
p.Enable = ParamItem{
|
||||||
|
Key: "knowhere.enable",
|
||||||
|
Version: "2.5.0",
|
||||||
|
DefaultValue: "true",
|
||||||
|
}
|
||||||
|
p.Enable.Init(base.mgr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *knowhereConfig) getIndexParam(indexType string, stage string) map[string]string {
|
||||||
|
matchedParam := make(map[string]string)
|
||||||
|
|
||||||
|
params := p.IndexParam.GetValue()
|
||||||
|
prefix := indexType + "." + stage + "."
|
||||||
|
|
||||||
|
for k, v := range params {
|
||||||
|
if strings.HasPrefix(k, prefix) {
|
||||||
|
matchedParam[strings.TrimPrefix(k, prefix)] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return matchedParam
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetKeyFromSlice(indexParams []*commonpb.KeyValuePair, key string) string {
|
||||||
|
for _, param := range indexParams {
|
||||||
|
if param.Key == key {
|
||||||
|
return param.Value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *knowhereConfig) GetRuntimeParameter(stage string) (map[string]string, error) {
|
||||||
|
params := make(map[string]string)
|
||||||
|
|
||||||
|
if stage == BuildStage {
|
||||||
|
params[BuildDramBudgetKey] = fmt.Sprintf("%f", float32(hardware.GetFreeMemoryCount())/(1<<30))
|
||||||
|
params[NumBuildThreadKey] = strconv.Itoa(int(float32(hardware.GetCPUNum())))
|
||||||
|
}
|
||||||
|
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *knowhereConfig) UpdateIndexParams(indexType string, stage string, indexParams []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) {
|
||||||
|
defaultParams := p.getIndexParam(indexType, stage)
|
||||||
|
|
||||||
|
for key, val := range defaultParams {
|
||||||
|
if GetKeyFromSlice(indexParams, key) == "" {
|
||||||
|
indexParams = append(indexParams,
|
||||||
|
&commonpb.KeyValuePair{
|
||||||
|
Key: key,
|
||||||
|
Value: val,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return indexParams, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *knowhereConfig) MergeIndexParams(indexType string, stage string, indexParam map[string]string) (map[string]string, error) {
|
||||||
|
defaultParams := p.getIndexParam(indexType, stage)
|
||||||
|
|
||||||
|
for key, val := range defaultParams {
|
||||||
|
_, existed := indexParam[key]
|
||||||
|
if !existed {
|
||||||
|
indexParam[key] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return indexParam, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *knowhereConfig) MergeResourceParams(vecFieldSize uint64, stage string, indexParam map[string]string) (map[string]string, error) {
|
||||||
|
param, _ := p.GetRuntimeParameter(stage)
|
||||||
|
|
||||||
|
for key, val := range param {
|
||||||
|
indexParam[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
indexParam[VecFieldSizeKey] = fmt.Sprintf("%f", float32(vecFieldSize)/(1<<30))
|
||||||
|
|
||||||
|
return indexParam, nil
|
||||||
|
}
|
||||||
243
pkg/util/paramtable/knowhere_param_test.go
Normal file
243
pkg/util/paramtable/knowhere_param_test.go
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
package paramtable
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKnowhereConfig_GetIndexParam(t *testing.T) {
|
||||||
|
bt := NewBaseTable(SkipRemote(true))
|
||||||
|
cfg := &knowhereConfig{}
|
||||||
|
cfg.init(bt)
|
||||||
|
|
||||||
|
// Set some initial config
|
||||||
|
indexParams := map[string]interface{}{
|
||||||
|
"knowhere.IVF_FLAT.build.nlist": 1024,
|
||||||
|
"knowhere.HNSW.build.efConstruction": 360,
|
||||||
|
"knowhere.DISKANN.search.search_list": 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, val := range indexParams {
|
||||||
|
valStr, _ := json.Marshal(val)
|
||||||
|
bt.Save(key, string(valStr))
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
indexType string
|
||||||
|
stage string
|
||||||
|
expectedKey string
|
||||||
|
expectedValue string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IVF_FLAT Build",
|
||||||
|
indexType: "IVF_FLAT",
|
||||||
|
stage: BuildStage,
|
||||||
|
expectedKey: "nlist",
|
||||||
|
expectedValue: "1024",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HNSW Build",
|
||||||
|
indexType: "HNSW",
|
||||||
|
stage: BuildStage,
|
||||||
|
expectedKey: "efConstruction",
|
||||||
|
expectedValue: "360",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DISKANN Search",
|
||||||
|
indexType: "DISKANN",
|
||||||
|
stage: SearchStage,
|
||||||
|
expectedKey: "search_list",
|
||||||
|
expectedValue: "100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-existent",
|
||||||
|
indexType: "NON_EXISTENT",
|
||||||
|
stage: BuildStage,
|
||||||
|
expectedKey: "",
|
||||||
|
expectedValue: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := cfg.getIndexParam(tt.indexType, tt.stage)
|
||||||
|
if tt.expectedKey != "" {
|
||||||
|
assert.Contains(t, result, tt.expectedKey, "The result should contain the expected key")
|
||||||
|
assert.Equal(t, tt.expectedValue, result[tt.expectedKey], "The value for the key should match the expected value")
|
||||||
|
} else {
|
||||||
|
assert.Empty(t, result, "The result should be empty for non-existent index type")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKnowhereConfig_GetRuntimeParameter(t *testing.T) {
|
||||||
|
cfg := &knowhereConfig{}
|
||||||
|
|
||||||
|
params, err := cfg.GetRuntimeParameter(BuildStage)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, params, BuildDramBudgetKey)
|
||||||
|
assert.Contains(t, params, NumBuildThreadKey)
|
||||||
|
|
||||||
|
params, err = cfg.GetRuntimeParameter(SearchStage)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKnowhereConfig_UpdateParameter(t *testing.T) {
|
||||||
|
bt := NewBaseTable(SkipRemote(true))
|
||||||
|
cfg := &knowhereConfig{}
|
||||||
|
cfg.init(bt)
|
||||||
|
|
||||||
|
// Set some initial config
|
||||||
|
indexParams := map[string]interface{}{
|
||||||
|
"knowhere.IVF_FLAT.build.nlist": 1024,
|
||||||
|
"knowhere.IVF_FLAT.build.num_build_thread": 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, val := range indexParams {
|
||||||
|
valStr, _ := json.Marshal(val)
|
||||||
|
bt.Save(key, string(valStr))
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
indexType string
|
||||||
|
stage string
|
||||||
|
inputParams []*commonpb.KeyValuePair
|
||||||
|
expectedParams map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IVF_FLAT Build",
|
||||||
|
indexType: "IVF_FLAT",
|
||||||
|
stage: BuildStage,
|
||||||
|
inputParams: []*commonpb.KeyValuePair{
|
||||||
|
{Key: "nlist", Value: "128"},
|
||||||
|
{Key: "existing_key", Value: "existing_value"},
|
||||||
|
},
|
||||||
|
expectedParams: map[string]string{
|
||||||
|
"existing_key": "existing_value",
|
||||||
|
"nlist": "128",
|
||||||
|
"num_build_thread": "12",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := cfg.UpdateIndexParams(tt.indexType, tt.stage, tt.inputParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
for key, expectedValue := range tt.expectedParams {
|
||||||
|
assert.Equal(t, expectedValue, GetKeyFromSlice(result, key), "The value for key %s should match the expected value", key)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKnowhereConfig_MergeParameter(t *testing.T) {
|
||||||
|
bt := NewBaseTable(SkipRemote(true))
|
||||||
|
cfg := &knowhereConfig{}
|
||||||
|
cfg.init(bt)
|
||||||
|
|
||||||
|
indexParams := map[string]interface{}{
|
||||||
|
"knowhere.IVF_FLAT.build.nlist": 1024,
|
||||||
|
"knowhere.IVF_FLAT.build.num_build_thread": 12,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, val := range indexParams {
|
||||||
|
valStr, _ := json.Marshal(val)
|
||||||
|
bt.Save(key, string(valStr))
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
indexType string
|
||||||
|
stage string
|
||||||
|
inputParams map[string]string
|
||||||
|
expectedParams map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IVF_FLAT Build",
|
||||||
|
indexType: "IVF_FLAT",
|
||||||
|
stage: BuildStage,
|
||||||
|
inputParams: map[string]string{
|
||||||
|
"nlist": "128",
|
||||||
|
"existing_key": "existing_value",
|
||||||
|
},
|
||||||
|
expectedParams: map[string]string{
|
||||||
|
"existing_key": "existing_value",
|
||||||
|
"nlist": "128",
|
||||||
|
"num_build_thread": "12",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := cfg.MergeIndexParams(tt.indexType, tt.stage, tt.inputParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
for key, expectedValue := range tt.expectedParams {
|
||||||
|
assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKnowhereConfig_MergeWithResource(t *testing.T) {
|
||||||
|
cfg := &knowhereConfig{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
vecFieldSize uint64
|
||||||
|
inputParams map[string]string
|
||||||
|
expectedParams map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Merge with resource",
|
||||||
|
vecFieldSize: 1024 * 1024 * 1024,
|
||||||
|
inputParams: map[string]string{
|
||||||
|
"existing_key": "existing_value",
|
||||||
|
},
|
||||||
|
expectedParams: map[string]string{
|
||||||
|
"existing_key": "existing_value",
|
||||||
|
BuildDramBudgetKey: "", // We can't predict the exact value, but it should exist
|
||||||
|
NumBuildThreadKey: "", // We can't predict the exact value, but it should exist
|
||||||
|
VecFieldSizeKey: "1.000000",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := cfg.MergeResourceParams(tt.vecFieldSize, BuildStage, tt.inputParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
for key, expectedValue := range tt.expectedParams {
|
||||||
|
if expectedValue != "" {
|
||||||
|
assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key)
|
||||||
|
} else {
|
||||||
|
assert.Contains(t, result, key, "The result should contain the key %s", key)
|
||||||
|
assert.NotEmpty(t, result[key], "The value for key %s should not be empty", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetKeyFromSlice(t *testing.T) {
|
||||||
|
indexParams := []*commonpb.KeyValuePair{
|
||||||
|
{Key: "key1", Value: "value1"},
|
||||||
|
{Key: "key2", Value: "value2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, "value1", GetKeyFromSlice(indexParams, "key1"))
|
||||||
|
assert.Equal(t, "value2", GetKeyFromSlice(indexParams, "key2"))
|
||||||
|
assert.Equal(t, "", GetKeyFromSlice(indexParams, "non_existent_key"))
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user