mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
enhance: add unify vector index config management (#36846)
issue: #34298 Signed-off-by: xianliang.li <xianliang.li@zilliz.com>
This commit is contained in:
parent
cd2655c861
commit
3224e58c5b
@ -1066,3 +1066,15 @@ streaming:
|
||||
backoffMultiplier: 2 # The multiplier of balance task trigger backoff, 2 by default
|
||||
txn:
|
||||
defaultKeepaliveTimeout: 10s # The default keepalive timeout for wal txn, 10s by default
|
||||
|
||||
# Any configuration related to the knowhere vector search engine
|
||||
knowhere:
|
||||
enable: true # When enable this configuration, the index parameters defined following will be automatically populated as index parameters, without requiring user input.
|
||||
DISKANN: # Index parameters for diskann
|
||||
build: # Diskann build params
|
||||
max_degree: 56 # Maximum degree of the Vamana graph
|
||||
search_list_size: 100 # Size of the candidate list during building graph
|
||||
pq_code_budget_gb_ratio: 0.125 # Size limit on the PQ code (compared with raw data)
|
||||
search_cache_budget_gb_ratio: 0.1 # Ratio of cached node numbers to raw data
|
||||
search: # Diskann search params
|
||||
beam_width_ratio: 4.0 # Ratio between the maximum number of IO requests per search iteration and CPU number.
|
||||
@ -29,10 +29,12 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proto/workerpb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/indexparams"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
@ -159,6 +161,19 @@ func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskSchedule
|
||||
|
||||
fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID)
|
||||
binlogIDs := getBinLogIDs(segment, fieldID)
|
||||
|
||||
// When new index parameters are added, these parameters need to be updated to ensure they are included during the index-building process.
|
||||
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() {
|
||||
var ret error
|
||||
indexParams, ret = Params.KnowhereConfig.UpdateIndexParams(GetIndexType(indexParams), paramtable.BuildStage, indexParams)
|
||||
|
||||
if ret != nil {
|
||||
log.Ctx(ctx).Warn("failed to update index build params defined in yaml", zap.Int64("taskID", it.taskID), zap.Error(ret))
|
||||
it.SetState(indexpb.JobState_JobStateInit, ret.Error())
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if isDiskANNIndex(GetIndexType(indexParams)) {
|
||||
var err error
|
||||
indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams)
|
||||
|
||||
@ -210,6 +210,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
|
||||
zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion()))
|
||||
|
||||
indexType := it.newIndexParams[common.IndexTypeKey]
|
||||
var fieldDataSize uint64
|
||||
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
|
||||
// check index node support disk index
|
||||
if !Params.IndexNodeCfg.EnableDisk.GetAsBool() {
|
||||
@ -225,7 +226,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
|
||||
log.Warn("IndexNode get local used size failed")
|
||||
return err
|
||||
}
|
||||
fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
|
||||
fieldDataSize, err = estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType())
|
||||
if err != nil {
|
||||
log.Warn("IndexNode get local used size failed")
|
||||
return err
|
||||
@ -247,6 +248,11 @@ func (it *indexBuildTask) Execute(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// system resource-related parameters, such as memory limits, CPU limits, and disk limits, are appended here to the parameter list
|
||||
if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() {
|
||||
it.newIndexParams, _ = Params.KnowhereConfig.MergeResourceParams(fieldDataSize, paramtable.BuildStage, it.newIndexParams)
|
||||
}
|
||||
|
||||
storageConfig := &indexcgopb.StorageConfig{
|
||||
Address: it.req.GetStorageConfig().GetAddress(),
|
||||
AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(),
|
||||
|
||||
@ -341,6 +341,14 @@ func (cit *createIndexTask) parseIndexParams() error {
|
||||
if !exist {
|
||||
return fmt.Errorf("IndexType not specified")
|
||||
}
|
||||
// index parameters defined in the YAML file are merged with the user-provided parameters during create stage
|
||||
if Params.KnowhereConfig.Enable.GetAsBool() {
|
||||
var err error
|
||||
indexParamsMap, err = Params.KnowhereConfig.MergeIndexParams(indexType, paramtable.BuildStage, indexParamsMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) {
|
||||
err := indexparams.FillDiskIndexParams(Params, indexParamsMap)
|
||||
if err != nil {
|
||||
|
||||
@ -39,6 +39,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
@ -1053,6 +1054,43 @@ func Test_parseIndexParams(t *testing.T) {
|
||||
err := cit.parseIndexParams()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("verify merge params with yaml", func(t *testing.T) {
|
||||
paramtable.Init()
|
||||
Params.Save("knowhere.HNSW.build.M", "3000")
|
||||
Params.Save("knowhere.HNSW.build.efConstruction", "120")
|
||||
defer Params.Reset("knowhere.HNSW.build.M")
|
||||
defer Params.Reset("knowhere.HNSW.build.efConstruction")
|
||||
|
||||
cit := &createIndexTask{
|
||||
Condition: nil,
|
||||
req: &milvuspb.CreateIndexRequest{
|
||||
ExtraParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: common.IndexTypeKey,
|
||||
Value: "HNSW",
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: metric.L2,
|
||||
},
|
||||
},
|
||||
IndexName: "",
|
||||
},
|
||||
fieldSchema: &schemapb.FieldSchema{
|
||||
FieldID: 101,
|
||||
Name: "FieldVector",
|
||||
IsPrimaryKey: false,
|
||||
DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: common.DimKey, Value: "768"},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := cit.parseIndexParams()
|
||||
// Out of range in json: param 'M' (3000) should be in range [2, 2048]
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_wrapUserIndexParams(t *testing.T) {
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
@ -30,6 +31,10 @@ var (
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
)
|
||||
|
||||
const (
|
||||
NotFormatPrefix = "knowhere."
|
||||
)
|
||||
|
||||
func Init(opts ...Option) (*Manager, error) {
|
||||
o := &Options{}
|
||||
for _, opt := range opts {
|
||||
@ -55,7 +60,17 @@ func Init(opts ...Option) (*Manager, error) {
|
||||
|
||||
var formattedKeys = typeutil.NewConcurrentMap[string, string]()
|
||||
|
||||
func lowerKey(key string) string {
|
||||
if strings.HasPrefix(key, NotFormatPrefix) {
|
||||
return key
|
||||
}
|
||||
return strings.ToLower(key)
|
||||
}
|
||||
|
||||
func formatKey(key string) string {
|
||||
if strings.HasPrefix(key, NotFormatPrefix) {
|
||||
return key
|
||||
}
|
||||
cached, ok := formattedKeys.Get(key)
|
||||
if ok {
|
||||
return cached
|
||||
@ -64,3 +79,43 @@ func formatKey(key string) string {
|
||||
formattedKeys.Insert(key, result)
|
||||
return result
|
||||
}
|
||||
|
||||
func flattenNode(node *yaml.Node, parentKey string, result map[string]string) {
|
||||
// The content of the node should contain key-value pairs in a MappingNode
|
||||
if node.Kind == yaml.MappingNode {
|
||||
for i := 0; i < len(node.Content); i += 2 {
|
||||
keyNode := node.Content[i]
|
||||
valueNode := node.Content[i+1]
|
||||
|
||||
key := keyNode.Value
|
||||
// Construct the full key with parent hierarchy
|
||||
fullKey := key
|
||||
if parentKey != "" {
|
||||
fullKey = parentKey + "." + key
|
||||
}
|
||||
|
||||
switch valueNode.Kind {
|
||||
case yaml.ScalarNode:
|
||||
// Scalar value, store it as a string
|
||||
result[lowerKey(fullKey)] = valueNode.Value
|
||||
result[formatKey(fullKey)] = valueNode.Value
|
||||
case yaml.MappingNode:
|
||||
// Nested map, process recursively
|
||||
flattenNode(valueNode, fullKey, result)
|
||||
case yaml.SequenceNode:
|
||||
// List (sequence), process elements
|
||||
var listStr string
|
||||
for j, item := range valueNode.Content {
|
||||
if j > 0 {
|
||||
listStr += ","
|
||||
}
|
||||
if item.Kind == yaml.ScalarNode {
|
||||
listStr += item.Value
|
||||
}
|
||||
}
|
||||
result[lowerKey(fullKey)] = listStr
|
||||
result[formatKey(fullKey)] = listStr
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,14 +17,16 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
"github.com/spf13/cast"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
)
|
||||
@ -115,7 +117,6 @@ func (fs *FileSource) UpdateOptions(opts Options) {
|
||||
}
|
||||
|
||||
func (fs *FileSource) loadFromFile() error {
|
||||
yamlReader := viper.New()
|
||||
newConfig := make(map[string]string)
|
||||
var configFiles []string
|
||||
|
||||
@ -128,37 +129,35 @@ func (fs *FileSource) loadFromFile() error {
|
||||
continue
|
||||
}
|
||||
|
||||
yamlReader.SetConfigFile(configFile)
|
||||
if err := yamlReader.ReadInConfig(); err != nil {
|
||||
ext := filepath.Ext(configFile)
|
||||
if len(ext) == 0 || ext[1:] != "yaml" {
|
||||
return fmt.Errorf("Unsupported Config Type: " + ext)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configFile)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Read config failed: "+configFile)
|
||||
}
|
||||
|
||||
for _, key := range yamlReader.AllKeys() {
|
||||
val := yamlReader.Get(key)
|
||||
str, err := cast.ToStringE(val)
|
||||
if err != nil {
|
||||
switch val := val.(type) {
|
||||
case []any:
|
||||
str = str[:0]
|
||||
for _, v := range val {
|
||||
ss, err := cast.ToStringE(v)
|
||||
if err != nil {
|
||||
log.Warn("cast to string failed", zap.Any("value", v))
|
||||
}
|
||||
if str == "" {
|
||||
str = ss
|
||||
} else {
|
||||
str = str + "," + ss
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
log.Warn("val is not a slice", zap.Any("value", val))
|
||||
// handle empty file
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var node yaml.Node
|
||||
decoder := yaml.NewDecoder(bytes.NewReader(data))
|
||||
if err := decoder.Decode(&node); err != nil {
|
||||
return errors.Wrap(err, "YAML unmarshal failed: "+configFile)
|
||||
}
|
||||
newConfig[key] = str
|
||||
newConfig[formatKey(key)] = str
|
||||
|
||||
if node.Kind == yaml.DocumentNode && len(node.Content) > 0 {
|
||||
// Get the content of the Document Node
|
||||
contentNode := node.Content[0]
|
||||
|
||||
// Recursively process the content of the Document Node
|
||||
flattenNode(contentNode, "", newConfig)
|
||||
} else if node.Kind == yaml.MappingNode {
|
||||
flattenNode(&node, "", newConfig)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -76,6 +76,7 @@ type ComponentParam struct {
|
||||
DataCoordCfg dataCoordConfig
|
||||
DataNodeCfg dataNodeConfig
|
||||
IndexNodeCfg indexNodeConfig
|
||||
KnowhereConfig knowhereConfig
|
||||
HTTPCfg httpConfig
|
||||
LogCfg logConfig
|
||||
RoleCfg roleConfig
|
||||
@ -134,6 +135,7 @@ func (p *ComponentParam) init(bt *BaseTable) {
|
||||
p.LogCfg.init(bt)
|
||||
p.RoleCfg.init(bt)
|
||||
p.GpuConfig.init(bt)
|
||||
p.KnowhereConfig.init(bt)
|
||||
|
||||
p.RootCoordGrpcServerCfg.Init("rootCoord", bt)
|
||||
p.ProxyGrpcServerCfg.Init("proxy", bt)
|
||||
|
||||
118
pkg/util/paramtable/knowhere_param.go
Normal file
118
pkg/util/paramtable/knowhere_param.go
Normal file
@ -0,0 +1,118 @@
|
||||
package paramtable
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/pkg/util/hardware"
|
||||
)
|
||||
|
||||
type knowhereConfig struct {
|
||||
Enable ParamItem `refreshable:"true"`
|
||||
IndexParam ParamGroup `refreshable:"true"`
|
||||
}
|
||||
|
||||
const (
|
||||
BuildStage = "build"
|
||||
LoadStage = "load"
|
||||
SearchStage = "search"
|
||||
)
|
||||
|
||||
const (
|
||||
BuildDramBudgetKey = "build_dram_budget_gb"
|
||||
NumBuildThreadKey = "num_build_thread"
|
||||
VecFieldSizeKey = "vec_field_size_gb"
|
||||
)
|
||||
|
||||
func (p *knowhereConfig) init(base *BaseTable) {
|
||||
p.IndexParam = ParamGroup{
|
||||
KeyPrefix: "knowhere.",
|
||||
Version: "2.5.0",
|
||||
}
|
||||
p.IndexParam.Init(base.mgr)
|
||||
|
||||
p.Enable = ParamItem{
|
||||
Key: "knowhere.enable",
|
||||
Version: "2.5.0",
|
||||
DefaultValue: "true",
|
||||
}
|
||||
p.Enable.Init(base.mgr)
|
||||
}
|
||||
|
||||
func (p *knowhereConfig) getIndexParam(indexType string, stage string) map[string]string {
|
||||
matchedParam := make(map[string]string)
|
||||
|
||||
params := p.IndexParam.GetValue()
|
||||
prefix := indexType + "." + stage + "."
|
||||
|
||||
for k, v := range params {
|
||||
if strings.HasPrefix(k, prefix) {
|
||||
matchedParam[strings.TrimPrefix(k, prefix)] = v
|
||||
}
|
||||
}
|
||||
|
||||
return matchedParam
|
||||
}
|
||||
|
||||
func GetKeyFromSlice(indexParams []*commonpb.KeyValuePair, key string) string {
|
||||
for _, param := range indexParams {
|
||||
if param.Key == key {
|
||||
return param.Value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *knowhereConfig) GetRuntimeParameter(stage string) (map[string]string, error) {
|
||||
params := make(map[string]string)
|
||||
|
||||
if stage == BuildStage {
|
||||
params[BuildDramBudgetKey] = fmt.Sprintf("%f", float32(hardware.GetFreeMemoryCount())/(1<<30))
|
||||
params[NumBuildThreadKey] = strconv.Itoa(int(float32(hardware.GetCPUNum())))
|
||||
}
|
||||
|
||||
return params, nil
|
||||
}
|
||||
|
||||
func (p *knowhereConfig) UpdateIndexParams(indexType string, stage string, indexParams []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) {
|
||||
defaultParams := p.getIndexParam(indexType, stage)
|
||||
|
||||
for key, val := range defaultParams {
|
||||
if GetKeyFromSlice(indexParams, key) == "" {
|
||||
indexParams = append(indexParams,
|
||||
&commonpb.KeyValuePair{
|
||||
Key: key,
|
||||
Value: val,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return indexParams, nil
|
||||
}
|
||||
|
||||
func (p *knowhereConfig) MergeIndexParams(indexType string, stage string, indexParam map[string]string) (map[string]string, error) {
|
||||
defaultParams := p.getIndexParam(indexType, stage)
|
||||
|
||||
for key, val := range defaultParams {
|
||||
_, existed := indexParam[key]
|
||||
if !existed {
|
||||
indexParam[key] = val
|
||||
}
|
||||
}
|
||||
|
||||
return indexParam, nil
|
||||
}
|
||||
|
||||
func (p *knowhereConfig) MergeResourceParams(vecFieldSize uint64, stage string, indexParam map[string]string) (map[string]string, error) {
|
||||
param, _ := p.GetRuntimeParameter(stage)
|
||||
|
||||
for key, val := range param {
|
||||
indexParam[key] = val
|
||||
}
|
||||
|
||||
indexParam[VecFieldSizeKey] = fmt.Sprintf("%f", float32(vecFieldSize)/(1<<30))
|
||||
|
||||
return indexParam, nil
|
||||
}
|
||||
243
pkg/util/paramtable/knowhere_param_test.go
Normal file
243
pkg/util/paramtable/knowhere_param_test.go
Normal file
@ -0,0 +1,243 @@
|
||||
package paramtable
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
)
|
||||
|
||||
func TestKnowhereConfig_GetIndexParam(t *testing.T) {
|
||||
bt := NewBaseTable(SkipRemote(true))
|
||||
cfg := &knowhereConfig{}
|
||||
cfg.init(bt)
|
||||
|
||||
// Set some initial config
|
||||
indexParams := map[string]interface{}{
|
||||
"knowhere.IVF_FLAT.build.nlist": 1024,
|
||||
"knowhere.HNSW.build.efConstruction": 360,
|
||||
"knowhere.DISKANN.search.search_list": 100,
|
||||
}
|
||||
|
||||
for key, val := range indexParams {
|
||||
valStr, _ := json.Marshal(val)
|
||||
bt.Save(key, string(valStr))
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
indexType string
|
||||
stage string
|
||||
expectedKey string
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "IVF_FLAT Build",
|
||||
indexType: "IVF_FLAT",
|
||||
stage: BuildStage,
|
||||
expectedKey: "nlist",
|
||||
expectedValue: "1024",
|
||||
},
|
||||
{
|
||||
name: "HNSW Build",
|
||||
indexType: "HNSW",
|
||||
stage: BuildStage,
|
||||
expectedKey: "efConstruction",
|
||||
expectedValue: "360",
|
||||
},
|
||||
{
|
||||
name: "DISKANN Search",
|
||||
indexType: "DISKANN",
|
||||
stage: SearchStage,
|
||||
expectedKey: "search_list",
|
||||
expectedValue: "100",
|
||||
},
|
||||
{
|
||||
name: "Non-existent",
|
||||
indexType: "NON_EXISTENT",
|
||||
stage: BuildStage,
|
||||
expectedKey: "",
|
||||
expectedValue: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := cfg.getIndexParam(tt.indexType, tt.stage)
|
||||
if tt.expectedKey != "" {
|
||||
assert.Contains(t, result, tt.expectedKey, "The result should contain the expected key")
|
||||
assert.Equal(t, tt.expectedValue, result[tt.expectedKey], "The value for the key should match the expected value")
|
||||
} else {
|
||||
assert.Empty(t, result, "The result should be empty for non-existent index type")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKnowhereConfig_GetRuntimeParameter(t *testing.T) {
|
||||
cfg := &knowhereConfig{}
|
||||
|
||||
params, err := cfg.GetRuntimeParameter(BuildStage)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, params, BuildDramBudgetKey)
|
||||
assert.Contains(t, params, NumBuildThreadKey)
|
||||
|
||||
params, err = cfg.GetRuntimeParameter(SearchStage)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, params)
|
||||
}
|
||||
|
||||
func TestKnowhereConfig_UpdateParameter(t *testing.T) {
|
||||
bt := NewBaseTable(SkipRemote(true))
|
||||
cfg := &knowhereConfig{}
|
||||
cfg.init(bt)
|
||||
|
||||
// Set some initial config
|
||||
indexParams := map[string]interface{}{
|
||||
"knowhere.IVF_FLAT.build.nlist": 1024,
|
||||
"knowhere.IVF_FLAT.build.num_build_thread": 12,
|
||||
}
|
||||
|
||||
for key, val := range indexParams {
|
||||
valStr, _ := json.Marshal(val)
|
||||
bt.Save(key, string(valStr))
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
indexType string
|
||||
stage string
|
||||
inputParams []*commonpb.KeyValuePair
|
||||
expectedParams map[string]string
|
||||
}{
|
||||
{
|
||||
name: "IVF_FLAT Build",
|
||||
indexType: "IVF_FLAT",
|
||||
stage: BuildStage,
|
||||
inputParams: []*commonpb.KeyValuePair{
|
||||
{Key: "nlist", Value: "128"},
|
||||
{Key: "existing_key", Value: "existing_value"},
|
||||
},
|
||||
expectedParams: map[string]string{
|
||||
"existing_key": "existing_value",
|
||||
"nlist": "128",
|
||||
"num_build_thread": "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := cfg.UpdateIndexParams(tt.indexType, tt.stage, tt.inputParams)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for key, expectedValue := range tt.expectedParams {
|
||||
assert.Equal(t, expectedValue, GetKeyFromSlice(result, key), "The value for key %s should match the expected value", key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKnowhereConfig_MergeParameter(t *testing.T) {
|
||||
bt := NewBaseTable(SkipRemote(true))
|
||||
cfg := &knowhereConfig{}
|
||||
cfg.init(bt)
|
||||
|
||||
indexParams := map[string]interface{}{
|
||||
"knowhere.IVF_FLAT.build.nlist": 1024,
|
||||
"knowhere.IVF_FLAT.build.num_build_thread": 12,
|
||||
}
|
||||
|
||||
for key, val := range indexParams {
|
||||
valStr, _ := json.Marshal(val)
|
||||
bt.Save(key, string(valStr))
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
indexType string
|
||||
stage string
|
||||
inputParams map[string]string
|
||||
expectedParams map[string]string
|
||||
}{
|
||||
{
|
||||
name: "IVF_FLAT Build",
|
||||
indexType: "IVF_FLAT",
|
||||
stage: BuildStage,
|
||||
inputParams: map[string]string{
|
||||
"nlist": "128",
|
||||
"existing_key": "existing_value",
|
||||
},
|
||||
expectedParams: map[string]string{
|
||||
"existing_key": "existing_value",
|
||||
"nlist": "128",
|
||||
"num_build_thread": "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := cfg.MergeIndexParams(tt.indexType, tt.stage, tt.inputParams)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for key, expectedValue := range tt.expectedParams {
|
||||
assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKnowhereConfig_MergeWithResource(t *testing.T) {
|
||||
cfg := &knowhereConfig{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
vecFieldSize uint64
|
||||
inputParams map[string]string
|
||||
expectedParams map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Merge with resource",
|
||||
vecFieldSize: 1024 * 1024 * 1024,
|
||||
inputParams: map[string]string{
|
||||
"existing_key": "existing_value",
|
||||
},
|
||||
expectedParams: map[string]string{
|
||||
"existing_key": "existing_value",
|
||||
BuildDramBudgetKey: "", // We can't predict the exact value, but it should exist
|
||||
NumBuildThreadKey: "", // We can't predict the exact value, but it should exist
|
||||
VecFieldSizeKey: "1.000000",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := cfg.MergeResourceParams(tt.vecFieldSize, BuildStage, tt.inputParams)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for key, expectedValue := range tt.expectedParams {
|
||||
if expectedValue != "" {
|
||||
assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key)
|
||||
} else {
|
||||
assert.Contains(t, result, key, "The result should contain the key %s", key)
|
||||
assert.NotEmpty(t, result[key], "The value for key %s should not be empty", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetKeyFromSlice(t *testing.T) {
|
||||
indexParams := []*commonpb.KeyValuePair{
|
||||
{Key: "key1", Value: "value1"},
|
||||
{Key: "key2", Value: "value2"},
|
||||
}
|
||||
|
||||
assert.Equal(t, "value1", GetKeyFromSlice(indexParams, "key1"))
|
||||
assert.Equal(t, "value2", GetKeyFromSlice(indexParams, "key2"))
|
||||
assert.Equal(t, "", GetKeyFromSlice(indexParams, "non_existent_key"))
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user