From 3224e58c5b2c9a462788ce3a009de38ccdd7bc2b Mon Sep 17 00:00:00 2001 From: foxspy Date: Fri, 1 Nov 2024 06:18:21 +0800 Subject: [PATCH] enhance: add unify vector index config management (#36846) issue: #34298 Signed-off-by: xianliang.li --- configs/milvus.yaml | 12 + internal/datacoord/task_index.go | 15 ++ internal/indexnode/task_index.go | 8 +- internal/proxy/task_index.go | 8 + internal/proxy/task_index_test.go | 38 ++++ pkg/config/config.go | 55 +++++ pkg/config/file_source.go | 59 +++-- pkg/util/paramtable/component_param.go | 24 +- pkg/util/paramtable/knowhere_param.go | 118 ++++++++++ pkg/util/paramtable/knowhere_param_test.go | 243 +++++++++++++++++++++ 10 files changed, 538 insertions(+), 42 deletions(-) create mode 100644 pkg/util/paramtable/knowhere_param.go create mode 100644 pkg/util/paramtable/knowhere_param_test.go diff --git a/configs/milvus.yaml b/configs/milvus.yaml index ff35477d9b..033708a2c6 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -1066,3 +1066,15 @@ streaming: backoffMultiplier: 2 # The multiplier of balance task trigger backoff, 2 by default txn: defaultKeepaliveTimeout: 10s # The default keepalive timeout for wal txn, 10s by default + +# Any configuration related to the knowhere vector search engine +knowhere: + enable: true # When enable this configuration, the index parameters defined following will be automatically populated as index parameters, without requiring user input. + DISKANN: # Index parameters for diskann + build: # Diskann build params + max_degree: 56 # Maximum degree of the Vamana graph + search_list_size: 100 # Size of the candidate list during building graph + pq_code_budget_gb_ratio: 0.125 # Size limit on the PQ code (compared with raw data) + search_cache_budget_gb_ratio: 0.1 # Ratio of cached node numbers to raw data + search: # Diskann search params + beam_width_ratio: 4.0 # Ratio between the maximum number of IO requests per search iteration and CPU number. \ No newline at end of file diff --git a/internal/datacoord/task_index.go b/internal/datacoord/task_index.go index a72cd0019e..efba15570b 100644 --- a/internal/datacoord/task_index.go +++ b/internal/datacoord/task_index.go @@ -29,10 +29,12 @@ import ( "github.com/milvus-io/milvus/internal/proto/workerpb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/indexparams" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -159,6 +161,19 @@ func (it *indexBuildTask) PreCheck(ctx context.Context, dependency *taskSchedule fieldID := dependency.meta.indexMeta.GetFieldIDByIndexID(segIndex.CollectionID, segIndex.IndexID) binlogIDs := getBinLogIDs(segment, fieldID) + + // When new index parameters are added, these parameters need to be updated to ensure they are included during the index-building process. + if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() { + var ret error + indexParams, ret = Params.KnowhereConfig.UpdateIndexParams(GetIndexType(indexParams), paramtable.BuildStage, indexParams) + + if ret != nil { + log.Ctx(ctx).Warn("failed to update index build params defined in yaml", zap.Int64("taskID", it.taskID), zap.Error(ret)) + it.SetState(indexpb.JobState_JobStateInit, ret.Error()) + return false + } + } + if isDiskANNIndex(GetIndexType(indexParams)) { var err error indexParams, err = indexparams.UpdateDiskIndexBuildParams(Params, indexParams) diff --git a/internal/indexnode/task_index.go b/internal/indexnode/task_index.go index 7808ac43f0..87e0f44be8 100644 --- a/internal/indexnode/task_index.go +++ b/internal/indexnode/task_index.go @@ -210,6 +210,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { zap.Int32("currentIndexVersion", it.req.GetCurrentIndexVersion())) indexType := it.newIndexParams[common.IndexTypeKey] + var fieldDataSize uint64 if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) { // check index node support disk index if !Params.IndexNodeCfg.EnableDisk.GetAsBool() { @@ -225,7 +226,7 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { log.Warn("IndexNode get local used size failed") return err } - fieldDataSize, err := estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) + fieldDataSize, err = estimateFieldDataSize(it.req.GetDim(), it.req.GetNumRows(), it.req.GetField().GetDataType()) if err != nil { log.Warn("IndexNode get local used size failed") return err @@ -247,6 +248,11 @@ func (it *indexBuildTask) Execute(ctx context.Context) error { } } + // system resource-related parameters, such as memory limits, CPU limits, and disk limits, are appended here to the parameter list + if vecindexmgr.GetVecIndexMgrInstance().IsVecIndex(indexType) && Params.KnowhereConfig.Enable.GetAsBool() { + it.newIndexParams, _ = Params.KnowhereConfig.MergeResourceParams(fieldDataSize, paramtable.BuildStage, it.newIndexParams) + } + storageConfig := &indexcgopb.StorageConfig{ Address: it.req.GetStorageConfig().GetAddress(), AccessKeyID: it.req.GetStorageConfig().GetAccessKeyID(), diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index f904a708de..3b69f59d4a 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -341,6 +341,14 @@ func (cit *createIndexTask) parseIndexParams() error { if !exist { return fmt.Errorf("IndexType not specified") } + // index parameters defined in the YAML file are merged with the user-provided parameters during create stage + if Params.KnowhereConfig.Enable.GetAsBool() { + var err error + indexParamsMap, err = Params.KnowhereConfig.MergeIndexParams(indexType, paramtable.BuildStage, indexParamsMap) + if err != nil { + return err + } + } if vecindexmgr.GetVecIndexMgrInstance().IsDiskANN(indexType) { err := indexparams.FillDiskIndexParams(Params, indexParamsMap) if err != nil { diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 8d056fadb0..b5aebd59d5 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -1053,6 +1054,43 @@ func Test_parseIndexParams(t *testing.T) { err := cit.parseIndexParams() assert.Error(t, err) }) + + t.Run("verify merge params with yaml", func(t *testing.T) { + paramtable.Init() + Params.Save("knowhere.HNSW.build.M", "3000") + Params.Save("knowhere.HNSW.build.efConstruction", "120") + defer Params.Reset("knowhere.HNSW.build.M") + defer Params.Reset("knowhere.HNSW.build.efConstruction") + + cit := &createIndexTask{ + Condition: nil, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "HNSW", + }, + { + Key: common.MetricTypeKey, + Value: metric.L2, + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldVector", + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "768"}, + }, + }, + } + err := cit.parseIndexParams() + // Out of range in json: param 'M' (3000) should be in range [2, 2048] + assert.Error(t, err) + }) } func Test_wrapUserIndexParams(t *testing.T) { diff --git a/pkg/config/config.go b/pkg/config/config.go index fc93c086f7..8b9f3cfe10 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/cockroachdb/errors" + "gopkg.in/yaml.v3" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -30,6 +31,10 @@ var ( ErrKeyNotFound = errors.New("key not found") ) +const ( + NotFormatPrefix = "knowhere." +) + func Init(opts ...Option) (*Manager, error) { o := &Options{} for _, opt := range opts { @@ -55,7 +60,17 @@ func Init(opts ...Option) (*Manager, error) { var formattedKeys = typeutil.NewConcurrentMap[string, string]() +func lowerKey(key string) string { + if strings.HasPrefix(key, NotFormatPrefix) { + return key + } + return strings.ToLower(key) +} + func formatKey(key string) string { + if strings.HasPrefix(key, NotFormatPrefix) { + return key + } cached, ok := formattedKeys.Get(key) if ok { return cached @@ -64,3 +79,43 @@ func formatKey(key string) string { formattedKeys.Insert(key, result) return result } + +func flattenNode(node *yaml.Node, parentKey string, result map[string]string) { + // The content of the node should contain key-value pairs in a MappingNode + if node.Kind == yaml.MappingNode { + for i := 0; i < len(node.Content); i += 2 { + keyNode := node.Content[i] + valueNode := node.Content[i+1] + + key := keyNode.Value + // Construct the full key with parent hierarchy + fullKey := key + if parentKey != "" { + fullKey = parentKey + "." + key + } + + switch valueNode.Kind { + case yaml.ScalarNode: + // Scalar value, store it as a string + result[lowerKey(fullKey)] = valueNode.Value + result[formatKey(fullKey)] = valueNode.Value + case yaml.MappingNode: + // Nested map, process recursively + flattenNode(valueNode, fullKey, result) + case yaml.SequenceNode: + // List (sequence), process elements + var listStr string + for j, item := range valueNode.Content { + if j > 0 { + listStr += "," + } + if item.Kind == yaml.ScalarNode { + listStr += item.Value + } + } + result[lowerKey(fullKey)] = listStr + result[formatKey(fullKey)] = listStr + } + } + } +} diff --git a/pkg/config/file_source.go b/pkg/config/file_source.go index e8402efe6b..4eace878d8 100644 --- a/pkg/config/file_source.go +++ b/pkg/config/file_source.go @@ -17,14 +17,16 @@ package config import ( + "bytes" + "fmt" "os" + "path/filepath" "sync" "github.com/cockroachdb/errors" "github.com/samber/lo" - "github.com/spf13/cast" - "github.com/spf13/viper" "go.uber.org/zap" + "gopkg.in/yaml.v3" "github.com/milvus-io/milvus/pkg/log" ) @@ -115,7 +117,6 @@ func (fs *FileSource) UpdateOptions(opts Options) { } func (fs *FileSource) loadFromFile() error { - yamlReader := viper.New() newConfig := make(map[string]string) var configFiles []string @@ -128,37 +129,35 @@ func (fs *FileSource) loadFromFile() error { continue } - yamlReader.SetConfigFile(configFile) - if err := yamlReader.ReadInConfig(); err != nil { + ext := filepath.Ext(configFile) + if len(ext) == 0 || ext[1:] != "yaml" { + return fmt.Errorf("Unsupported Config Type: " + ext) + } + + data, err := os.ReadFile(configFile) + if err != nil { return errors.Wrap(err, "Read config failed: "+configFile) } - for _, key := range yamlReader.AllKeys() { - val := yamlReader.Get(key) - str, err := cast.ToStringE(val) - if err != nil { - switch val := val.(type) { - case []any: - str = str[:0] - for _, v := range val { - ss, err := cast.ToStringE(v) - if err != nil { - log.Warn("cast to string failed", zap.Any("value", v)) - } - if str == "" { - str = ss - } else { - str = str + "," + ss - } - } + // handle empty file + if len(data) == 0 { + continue + } - default: - log.Warn("val is not a slice", zap.Any("value", val)) - continue - } - } - newConfig[key] = str - newConfig[formatKey(key)] = str + var node yaml.Node + decoder := yaml.NewDecoder(bytes.NewReader(data)) + if err := decoder.Decode(&node); err != nil { + return errors.Wrap(err, "YAML unmarshal failed: "+configFile) + } + + if node.Kind == yaml.DocumentNode && len(node.Content) > 0 { + // Get the content of the Document Node + contentNode := node.Content[0] + + // Recursively process the content of the Document Node + flattenNode(contentNode, "", newConfig) + } else if node.Kind == yaml.MappingNode { + flattenNode(&node, "", newConfig) } } diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 40a3212430..9e0ff8bdc3 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -69,17 +69,18 @@ type ComponentParam struct { GpuConfig gpuConfig TraceCfg traceConfig - RootCoordCfg rootCoordConfig - ProxyCfg proxyConfig - QueryCoordCfg queryCoordConfig - QueryNodeCfg queryNodeConfig - DataCoordCfg dataCoordConfig - DataNodeCfg dataNodeConfig - IndexNodeCfg indexNodeConfig - HTTPCfg httpConfig - LogCfg logConfig - RoleCfg roleConfig - StreamingCfg streamingConfig + RootCoordCfg rootCoordConfig + ProxyCfg proxyConfig + QueryCoordCfg queryCoordConfig + QueryNodeCfg queryNodeConfig + DataCoordCfg dataCoordConfig + DataNodeCfg dataNodeConfig + IndexNodeCfg indexNodeConfig + KnowhereConfig knowhereConfig + HTTPCfg httpConfig + LogCfg logConfig + RoleCfg roleConfig + StreamingCfg streamingConfig RootCoordGrpcServerCfg GrpcServerConfig ProxyGrpcServerCfg GrpcServerConfig @@ -134,6 +135,7 @@ func (p *ComponentParam) init(bt *BaseTable) { p.LogCfg.init(bt) p.RoleCfg.init(bt) p.GpuConfig.init(bt) + p.KnowhereConfig.init(bt) p.RootCoordGrpcServerCfg.Init("rootCoord", bt) p.ProxyGrpcServerCfg.Init("proxy", bt) diff --git a/pkg/util/paramtable/knowhere_param.go b/pkg/util/paramtable/knowhere_param.go new file mode 100644 index 0000000000..035631fac5 --- /dev/null +++ b/pkg/util/paramtable/knowhere_param.go @@ -0,0 +1,118 @@ +package paramtable + +import ( + "fmt" + "strconv" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/pkg/util/hardware" +) + +type knowhereConfig struct { + Enable ParamItem `refreshable:"true"` + IndexParam ParamGroup `refreshable:"true"` +} + +const ( + BuildStage = "build" + LoadStage = "load" + SearchStage = "search" +) + +const ( + BuildDramBudgetKey = "build_dram_budget_gb" + NumBuildThreadKey = "num_build_thread" + VecFieldSizeKey = "vec_field_size_gb" +) + +func (p *knowhereConfig) init(base *BaseTable) { + p.IndexParam = ParamGroup{ + KeyPrefix: "knowhere.", + Version: "2.5.0", + } + p.IndexParam.Init(base.mgr) + + p.Enable = ParamItem{ + Key: "knowhere.enable", + Version: "2.5.0", + DefaultValue: "true", + } + p.Enable.Init(base.mgr) +} + +func (p *knowhereConfig) getIndexParam(indexType string, stage string) map[string]string { + matchedParam := make(map[string]string) + + params := p.IndexParam.GetValue() + prefix := indexType + "." + stage + "." + + for k, v := range params { + if strings.HasPrefix(k, prefix) { + matchedParam[strings.TrimPrefix(k, prefix)] = v + } + } + + return matchedParam +} + +func GetKeyFromSlice(indexParams []*commonpb.KeyValuePair, key string) string { + for _, param := range indexParams { + if param.Key == key { + return param.Value + } + } + return "" +} + +func (p *knowhereConfig) GetRuntimeParameter(stage string) (map[string]string, error) { + params := make(map[string]string) + + if stage == BuildStage { + params[BuildDramBudgetKey] = fmt.Sprintf("%f", float32(hardware.GetFreeMemoryCount())/(1<<30)) + params[NumBuildThreadKey] = strconv.Itoa(int(float32(hardware.GetCPUNum()))) + } + + return params, nil +} + +func (p *knowhereConfig) UpdateIndexParams(indexType string, stage string, indexParams []*commonpb.KeyValuePair) ([]*commonpb.KeyValuePair, error) { + defaultParams := p.getIndexParam(indexType, stage) + + for key, val := range defaultParams { + if GetKeyFromSlice(indexParams, key) == "" { + indexParams = append(indexParams, + &commonpb.KeyValuePair{ + Key: key, + Value: val, + }) + } + } + + return indexParams, nil +} + +func (p *knowhereConfig) MergeIndexParams(indexType string, stage string, indexParam map[string]string) (map[string]string, error) { + defaultParams := p.getIndexParam(indexType, stage) + + for key, val := range defaultParams { + _, existed := indexParam[key] + if !existed { + indexParam[key] = val + } + } + + return indexParam, nil +} + +func (p *knowhereConfig) MergeResourceParams(vecFieldSize uint64, stage string, indexParam map[string]string) (map[string]string, error) { + param, _ := p.GetRuntimeParameter(stage) + + for key, val := range param { + indexParam[key] = val + } + + indexParam[VecFieldSizeKey] = fmt.Sprintf("%f", float32(vecFieldSize)/(1<<30)) + + return indexParam, nil +} diff --git a/pkg/util/paramtable/knowhere_param_test.go b/pkg/util/paramtable/knowhere_param_test.go new file mode 100644 index 0000000000..87cf771b20 --- /dev/null +++ b/pkg/util/paramtable/knowhere_param_test.go @@ -0,0 +1,243 @@ +package paramtable + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" +) + +func TestKnowhereConfig_GetIndexParam(t *testing.T) { + bt := NewBaseTable(SkipRemote(true)) + cfg := &knowhereConfig{} + cfg.init(bt) + + // Set some initial config + indexParams := map[string]interface{}{ + "knowhere.IVF_FLAT.build.nlist": 1024, + "knowhere.HNSW.build.efConstruction": 360, + "knowhere.DISKANN.search.search_list": 100, + } + + for key, val := range indexParams { + valStr, _ := json.Marshal(val) + bt.Save(key, string(valStr)) + } + + tests := []struct { + name string + indexType string + stage string + expectedKey string + expectedValue string + }{ + { + name: "IVF_FLAT Build", + indexType: "IVF_FLAT", + stage: BuildStage, + expectedKey: "nlist", + expectedValue: "1024", + }, + { + name: "HNSW Build", + indexType: "HNSW", + stage: BuildStage, + expectedKey: "efConstruction", + expectedValue: "360", + }, + { + name: "DISKANN Search", + indexType: "DISKANN", + stage: SearchStage, + expectedKey: "search_list", + expectedValue: "100", + }, + { + name: "Non-existent", + indexType: "NON_EXISTENT", + stage: BuildStage, + expectedKey: "", + expectedValue: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := cfg.getIndexParam(tt.indexType, tt.stage) + if tt.expectedKey != "" { + assert.Contains(t, result, tt.expectedKey, "The result should contain the expected key") + assert.Equal(t, tt.expectedValue, result[tt.expectedKey], "The value for the key should match the expected value") + } else { + assert.Empty(t, result, "The result should be empty for non-existent index type") + } + }) + } +} + +func TestKnowhereConfig_GetRuntimeParameter(t *testing.T) { + cfg := &knowhereConfig{} + + params, err := cfg.GetRuntimeParameter(BuildStage) + assert.NoError(t, err) + assert.Contains(t, params, BuildDramBudgetKey) + assert.Contains(t, params, NumBuildThreadKey) + + params, err = cfg.GetRuntimeParameter(SearchStage) + assert.NoError(t, err) + assert.Empty(t, params) +} + +func TestKnowhereConfig_UpdateParameter(t *testing.T) { + bt := NewBaseTable(SkipRemote(true)) + cfg := &knowhereConfig{} + cfg.init(bt) + + // Set some initial config + indexParams := map[string]interface{}{ + "knowhere.IVF_FLAT.build.nlist": 1024, + "knowhere.IVF_FLAT.build.num_build_thread": 12, + } + + for key, val := range indexParams { + valStr, _ := json.Marshal(val) + bt.Save(key, string(valStr)) + } + + tests := []struct { + name string + indexType string + stage string + inputParams []*commonpb.KeyValuePair + expectedParams map[string]string + }{ + { + name: "IVF_FLAT Build", + indexType: "IVF_FLAT", + stage: BuildStage, + inputParams: []*commonpb.KeyValuePair{ + {Key: "nlist", Value: "128"}, + {Key: "existing_key", Value: "existing_value"}, + }, + expectedParams: map[string]string{ + "existing_key": "existing_value", + "nlist": "128", + "num_build_thread": "12", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := cfg.UpdateIndexParams(tt.indexType, tt.stage, tt.inputParams) + assert.NoError(t, err) + + for key, expectedValue := range tt.expectedParams { + assert.Equal(t, expectedValue, GetKeyFromSlice(result, key), "The value for key %s should match the expected value", key) + } + }) + } +} + +func TestKnowhereConfig_MergeParameter(t *testing.T) { + bt := NewBaseTable(SkipRemote(true)) + cfg := &knowhereConfig{} + cfg.init(bt) + + indexParams := map[string]interface{}{ + "knowhere.IVF_FLAT.build.nlist": 1024, + "knowhere.IVF_FLAT.build.num_build_thread": 12, + } + + for key, val := range indexParams { + valStr, _ := json.Marshal(val) + bt.Save(key, string(valStr)) + } + + tests := []struct { + name string + indexType string + stage string + inputParams map[string]string + expectedParams map[string]string + }{ + { + name: "IVF_FLAT Build", + indexType: "IVF_FLAT", + stage: BuildStage, + inputParams: map[string]string{ + "nlist": "128", + "existing_key": "existing_value", + }, + expectedParams: map[string]string{ + "existing_key": "existing_value", + "nlist": "128", + "num_build_thread": "12", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := cfg.MergeIndexParams(tt.indexType, tt.stage, tt.inputParams) + assert.NoError(t, err) + + for key, expectedValue := range tt.expectedParams { + assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key) + } + }) + } +} + +func TestKnowhereConfig_MergeWithResource(t *testing.T) { + cfg := &knowhereConfig{} + + tests := []struct { + name string + vecFieldSize uint64 + inputParams map[string]string + expectedParams map[string]string + }{ + { + name: "Merge with resource", + vecFieldSize: 1024 * 1024 * 1024, + inputParams: map[string]string{ + "existing_key": "existing_value", + }, + expectedParams: map[string]string{ + "existing_key": "existing_value", + BuildDramBudgetKey: "", // We can't predict the exact value, but it should exist + NumBuildThreadKey: "", // We can't predict the exact value, but it should exist + VecFieldSizeKey: "1.000000", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := cfg.MergeResourceParams(tt.vecFieldSize, BuildStage, tt.inputParams) + assert.NoError(t, err) + + for key, expectedValue := range tt.expectedParams { + if expectedValue != "" { + assert.Equal(t, expectedValue, result[key], "The value for key %s should match the expected value", key) + } else { + assert.Contains(t, result, key, "The result should contain the key %s", key) + assert.NotEmpty(t, result[key], "The value for key %s should not be empty", key) + } + } + }) + } +} + +func TestGetKeyFromSlice(t *testing.T) { + indexParams := []*commonpb.KeyValuePair{ + {Key: "key1", Value: "value1"}, + {Key: "key2", Value: "value2"}, + } + + assert.Equal(t, "value1", GetKeyFromSlice(indexParams, "key1")) + assert.Equal(t, "value2", GetKeyFromSlice(indexParams, "key2")) + assert.Equal(t, "", GetKeyFromSlice(indexParams, "non_existent_key")) +}