diff --git a/configs/milvus.yaml b/configs/milvus.yaml
index a2a5bcfe62..bf1e67d820 100644
--- a/configs/milvus.yaml
+++ b/configs/milvus.yaml
@@ -304,6 +304,10 @@ proxy:
   ddlConcurrency: 16 # The concurrent execution number of DDL at proxy.
   dclConcurrency: 16 # The concurrent execution number of DCL at proxy.
   mustUsePartitionKey: false # switch for whether proxy must use partition key for the collection
+  # maximum number of result entries, typically Nq * TopK * GroupSize.
+  # It costs additional memory and time to process a large number of result entries.
+  # If the number of result entries exceeds this limit, the search will be rejected.
+  maxResultEntries: 1000000
   accessLog:
     enable: false # Whether to enable the access log feature.
     minioEnable: false # Whether to upload local access log files to MinIO. This parameter can be specified when proxy.accessLog.filename is not empty.
diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go
index e3564b6301..759804f431 100644
--- a/internal/proxy/task_search.go
+++ b/internal/proxy/task_search.go
@@ -278,6 +278,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
 
 	t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
 
+	if err = ValidateTask(t); err != nil {
+		return err
+	}
+
 	log.Debug("search PreExecute done.",
 		zap.Uint64("guarantee_ts", guaranteeTs),
 		zap.Bool("use_default_consistency", useDefaultConsistency),
diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go
index e5211f986f..9b4c6dc5fa 100644
--- a/internal/proxy/task_search_test.go
+++ b/internal/proxy/task_search_test.go
@@ -776,6 +776,34 @@ func TestSearchTask_PreExecute(t *testing.T) {
 		assert.Error(t, err)
 	})
 
+	t.Run("reject large num of result entries", func(t *testing.T) {
+		collName := "test_large_num_of_result_entries" + funcutil.GenRandomStr()
+		createColl(t, collName, qc)
+
+		task := getSearchTask(t, collName)
+		task.SearchRequest.Nq = 1000
+		task.SearchRequest.Topk = 1001
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		task.SearchRequest.Nq = 100
+		task.SearchRequest.Topk = 100
+		task.SearchRequest.GroupSize = 200
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+
+		task.SearchRequest.IsAdvanced = true
+		task.SearchRequest.SubReqs = []*internalpb.SubSearchRequest{
+			{
+				Topk:      100,
+				Nq:        100,
+				GroupSize: 200,
+			},
+		}
+		err = task.PreExecute(ctx)
+		assert.Error(t, err)
+	})
+
 	t.Run("collection not exist", func(t *testing.T) {
 		collName := "test_collection_not_exist" + funcutil.GenRandomStr()
 		task := getSearchTask(t, collName)
diff --git a/internal/proxy/task_validator.go b/internal/proxy/task_validator.go
new file mode 100644
index 0000000000..6cb85f955c
--- /dev/null
+++ b/internal/proxy/task_validator.go
@@ -0,0 +1,97 @@
+package proxy
+
+import (
+	"fmt"
+	"math"
+
+	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
+	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
+)
+
+// validator is a generic interface for validating tasks
+type validator[T any] interface {
+	validate(request T) error
+}
+
+// searchTaskValidator validates search tasks
+type searchTaskValidator struct{}
+
+var searchTaskValidatorInstance validator[*searchTask] = &searchTaskValidator{}
+
+// mulSaturated returns a * b, clamped to math.MaxInt64 or math.MinInt64 when
+// the exact product does not fit in an int64, instead of silently wrapping.
+func mulSaturated(a, b int64) int64 {
+	if a == 0 || b == 0 {
+		return 0
+	}
+	product := a * b
+	// Dividing the product back exposes a wrap-around, except for
+	// math.MinInt64 * -1, which wraps to math.MinInt64 and survives the
+	// division, so that pair is checked explicitly.
+	overflowed := product/b != a ||
+		(a == -1 && b == math.MinInt64) ||
+		(b == -1 && a == math.MinInt64)
+	if !overflowed {
+		return product
+	}
+	if (a < 0) == (b < 0) {
+		return math.MaxInt64
+	}
+	return math.MinInt64
+}
+
+// checkResultEntries returns an error when the number of result entries,
+// nq * topk (times groupSize when it is positive), is greater than
+// proxy.maxResultEntries. The product saturates rather than overflowing.
+func (v *searchTaskValidator) checkResultEntries(nq, topk, groupSize int64) error {
+	nEntries := mulSaturated(nq, topk)
+	// only a positive group size takes part in the product
+	if groupSize > 0 {
+		nEntries = mulSaturated(nEntries, groupSize)
+	}
+	limit := paramtable.Get().ProxyCfg.MaxResultEntries.GetAsInt64()
+	if nEntries > limit {
+		return fmt.Errorf("number of result entries is too large: nq=%d, topk=%d, groupSize=%d, limit=%d",
+			nq, topk, groupSize, limit)
+	}
+	return nil
+}
+
+// validateSubSearch checks the result entry count of one hybrid search sub-request.
+func (v *searchTaskValidator) validateSubSearch(subReq *internalpb.SubSearchRequest) error {
+	return v.checkResultEntries(subReq.GetNq(), subReq.GetTopk(), subReq.GetGroupSize())
+}
+
+// validateSearch checks the result entry count of a non-hybrid search.
+func (v *searchTaskValidator) validateSearch(search *searchTask) error {
+	return v.checkResultEntries(search.GetNq(), search.GetTopk(), search.GetGroupSize())
+}
+
+// validate checks a search task and returns the first violation found.
+// A nil task has nothing to check and is accepted.
+func (v *searchTaskValidator) validate(search *searchTask) error {
+	if search == nil {
+		return nil
+	}
+	if !search.SearchRequest.GetIsAdvanced() {
+		return v.validateSearch(search)
+	}
+	// if it is a hybrid search, check all sub-searches
+	for _, subReq := range search.SearchRequest.GetSubReqs() {
+		if err := v.validateSubSearch(subReq); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// ValidateTask runs the validator that matches the concrete type of task.
+// Tasks of any other type, including a nil task, are accepted.
+func ValidateTask(task any) error {
+	switch t := task.(type) {
+	case *searchTask:
+		return searchTaskValidatorInstance.validate(t)
+	default:
+		return nil
+	}
+}
diff --git a/internal/proxy/task_validator_test.go b/internal/proxy/task_validator_test.go
new file mode 100644
index 0000000000..48c2f3536b
--- /dev/null
+++ b/internal/proxy/task_validator_test.go
@@ -0,0 +1,117 @@
+package proxy
+
+import (
+	"math"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
+	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
+)
+
+func TestValidateTask(t *testing.T) {
+	// Initialize paramtable for testing
+	paramtable.Init()
+
+	tests := []struct {
+		name        string
+		task        any
+		expectError bool
+		errorMsg    string
+	}{
+		{
+			name: "valid search task",
+			task: &searchTask{
+				SearchRequest: &internalpb.SearchRequest{
+					Nq:         10,
+					Topk:       100,
+					IsAdvanced: false,
+				},
+			},
+			expectError: false,
+		},
+		{
+			name: "invalid search task",
+			task: &searchTask{
+				SearchRequest: &internalpb.SearchRequest{
+					Nq:         1000,
+					Topk:       2000,
+					IsAdvanced: false,
+				},
+			},
+			expectError: true,
+			errorMsg:    "number of result entries is too large",
+		},
+		{
+			name: "invalid search task with group size",
+			task: &searchTask{
+				SearchRequest: &internalpb.SearchRequest{
+					Nq:         100,
+					Topk:       200,
+					GroupSize:  100,
+					IsAdvanced: false,
+				},
+			},
+			expectError: true,
+			errorMsg:    "number of result entries is too large",
+		},
+		{
+			name: "overflowing number of result entries",
+			task: &searchTask{
+				SearchRequest: &internalpb.SearchRequest{
+					Nq:         math.MaxInt64,
+					Topk:       2,
+					IsAdvanced: false,
+				},
+			},
+			expectError: true,
+			errorMsg:    "number of result entries is too large",
+		},
+		{
+			name: "invalid search task with sub-request",
+			task: &searchTask{
+				SearchRequest: &internalpb.SearchRequest{
+					IsAdvanced: true,
+					SubReqs: []*internalpb.SubSearchRequest{
+						{
+							Nq:        100,
+							Topk:      200,
+							GroupSize: 100,
+						},
+					},
+				},
+			},
+			expectError: true,
+			errorMsg:    "number of result entries is too large",
+		},
+		{
+			name:        "non-search task should return nil",
+			task:        "not a search task",
+			expectError: false,
+		},
+		{
+			name:        "nil task should return nil",
+			task:        nil,
+			expectError: false,
+		},
+		{
+			name:        "nil search task should return nil",
+			task:        (*searchTask)(nil),
+			expectError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := ValidateTask(tt.task)
+
+			if tt.expectError {
+				assert.Error(t, err)
+				assert.Contains(t, err.Error(), tt.errorMsg)
+			} else {
+				assert.NoError(t, err)
+			}
+		})
+	}
+}
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index e3388010ff..052206916a 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -1522,6 +1522,7 @@ type proxyConfig struct {
 	SkipPartitionKeyCheck        ParamItem `refreshable:"true"`
 	MaxVarCharLength             ParamItem `refreshable:"false"`
 	MaxTextLength                ParamItem `refreshable:"false"`
+	MaxResultEntries             ParamItem `refreshable:"true"`
 
 	AccessLog AccessLogConfig
 
@@ -1943,6 +1944,17 @@ please adjust in embedded Milvus: false`,
 	}
 	p.MaxTextLength.Init(base.mgr)
 
+	p.MaxResultEntries = ParamItem{
+		Key:          "proxy.maxResultEntries",
+		Version:      "2.6.0",
+		DefaultValue: strconv.Itoa(1000000),
+		Doc: `maximum number of result entries, typically Nq * TopK * GroupSize.
+It costs additional memory and time to process a large number of result entries.
+If the number of result entries exceeds this limit, the search will be rejected.`,
+		Export: true,
+	}
+	p.MaxResultEntries.Init(base.mgr)
+
 	p.GracefulStopTimeout = ParamItem{
 		Key:            "proxy.gracefulStopTimeout",
 		Version:        "2.3.7",