feat: limit search result entries (#42522)

See: #42521

Signed-off-by: Ted Xu <ted.xu@zilliz.com>
This commit is contained in:
Ted Xu 2025-06-05 12:08:33 +08:00 committed by GitHub
parent c262f987db
commit 35c17523de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 216 additions and 0 deletions

View File

@ -304,6 +304,10 @@ proxy:
ddlConcurrency: 16 # The concurrent execution number of DDL at proxy.
dclConcurrency: 16 # The concurrent execution number of DCL at proxy.
mustUsePartitionKey: false # switch for whether proxy must use partition key for the collection
# maximum number of result entries, typically Nq * TopK * GroupSize.
# It costs additional memory and time to process a large number of result entries.
# If the number of result entries exceeds this limit, the search will be rejected.
maxResultEntries: 1000000
accessLog:
enable: false # Whether to enable the access log feature.
minioEnable: false # Whether to upload local access log files to MinIO. This parameter can be specified when proxy.accessLog.filename is not empty.

View File

@ -278,6 +278,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
if err = ValidateTask(t); err != nil {
return err
}
log.Debug("search PreExecute done.",
zap.Uint64("guarantee_ts", guaranteeTs),
zap.Bool("use_default_consistency", useDefaultConsistency),

View File

@ -776,6 +776,34 @@ func TestSearchTask_PreExecute(t *testing.T) {
assert.Error(t, err)
})
t.Run("reject large num of result entries", func(t *testing.T) {
collName := "test_large_num_of_result_entries" + funcutil.GenRandomStr()
createColl(t, collName, qc)
task := getSearchTask(t, collName)
task.SearchRequest.Nq = 1000
task.SearchRequest.Topk = 1001
err = task.PreExecute(ctx)
assert.Error(t, err)
task.SearchRequest.Nq = 100
task.SearchRequest.Topk = 100
task.SearchRequest.GroupSize = 200
err = task.PreExecute(ctx)
assert.Error(t, err)
task.SearchRequest.IsAdvanced = true
task.SearchRequest.SubReqs = []*internalpb.SubSearchRequest{
{
Topk: 100,
Nq: 100,
GroupSize: 200,
},
}
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("collection not exist", func(t *testing.T) {
collName := "test_collection_not_exist" + funcutil.GenRandomStr()
task := getSearchTask(t, collName)

View File

@ -0,0 +1,69 @@
package proxy
import (
"fmt"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
type (
	// validator is implemented by anything that can check a task of type T
	// and report why it must be rejected.
	validator[T any] interface {
		validate(task T) error
	}

	// searchTaskValidator is the stateless validator used for search tasks.
	searchTaskValidator struct{}
)

// searchTaskValidatorInstance is the package-wide validator that ValidateTask
// dispatches search tasks to.
var searchTaskValidatorInstance validator[*searchTask] = new(searchTaskValidator)
// saturatingMulInt64 multiplies two int64 values. When both operands are
// positive and the exact product does not fit in an int64, it returns the
// largest int64 instead of silently wrapping around; every other input uses
// plain multiplication.
func saturatingMulInt64(a, b int64) int64 {
	const maxInt64 = int64(1<<63 - 1)
	if a > 0 && b > 0 && a > maxInt64/b {
		return maxInt64
	}
	return a * b
}

// checkResultEntries returns an error when the number of result entries,
// nq * topk (additionally multiplied by groupSize when groupSize is positive),
// is greater than the configured proxy.maxResultEntries.
func (v *searchTaskValidator) checkResultEntries(nq, topk, groupSize int64) error {
	// Saturate rather than wrap: a wrapped product could become small or
	// negative and let an oversized request pass the limit check.
	nEntries := saturatingMulInt64(nq, topk)
	// A non-positive group size means "no grouping" and does not scale the count.
	if groupSize > 0 {
		nEntries = saturatingMulInt64(nEntries, groupSize)
	}
	limit := paramtable.Get().ProxyCfg.MaxResultEntries.GetAsInt64()
	if nEntries > limit {
		return fmt.Errorf("number of result entries is too large: nq=%d, topk=%d, groupSize=%d, limit=%d",
			nq, topk, groupSize, limit)
	}
	return nil
}

// validateSubSearch checks one sub-request of a hybrid search against the
// result-entries limit. A nil subReq reads as all zeros through the generated
// getters and is therefore accepted.
func (v *searchTaskValidator) validateSubSearch(subReq *internalpb.SubSearchRequest) error {
	return v.checkResultEntries(subReq.GetNq(), subReq.GetTopk(), subReq.GetGroupSize())
}

// validateSearch checks a plain (non-hybrid) search task against the
// result-entries limit.
func (v *searchTaskValidator) validateSearch(search *searchTask) error {
	return v.checkResultEntries(search.GetNq(), search.GetTopk(), search.GetGroupSize())
}
// validate checks a search task. A plain search is validated as a whole; a
// hybrid (advanced) search is validated sub-request by sub-request, and the
// first failing sub-request's error is returned.
func (v *searchTaskValidator) validate(search *searchTask) error {
	req := search.SearchRequest
	if !req.GetIsAdvanced() {
		return v.validateSearch(search)
	}

	for _, sub := range req.GetSubReqs() {
		if subErr := v.validateSubSearch(sub); subErr != nil {
			return subErr
		}
	}
	return nil
}
// ValidateTask runs the validator registered for the task's concrete type.
// Only *searchTask is validated; any other value, including nil, is accepted
// and yields a nil error.
func ValidateTask(task any) error {
	if st, isSearch := task.(*searchTask); isSearch {
		return searchTaskValidatorInstance.validate(st)
	}
	return nil
}

View File

@ -0,0 +1,99 @@
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// TestValidateTask exercises ValidateTask with plain searches, grouped
// searches, hybrid searches and non-search values. The expectations rely on
// the default value of proxy.maxResultEntries (1000000).
func TestValidateTask(t *testing.T) {
	// Initialize paramtable so the configured limit can be read.
	paramtable.Init()

	tests := []struct {
		name        string
		task        any
		expectError bool
		errorMsg    string
	}{
		{
			name: "valid search task",
			task: &searchTask{
				SearchRequest: &internalpb.SearchRequest{
					Nq:         10,
					Topk:       100,
					IsAdvanced: false,
				},
			},
			expectError: false,
		},
		{
			name: "invalid search task",
			task: &searchTask{
				SearchRequest: &internalpb.SearchRequest{
					Nq:         1000,
					Topk:       2000,
					IsAdvanced: false,
				},
			},
			expectError: true,
			errorMsg:    "number of result entries is too large",
		},
		{
			name: "invalid search task with group size",
			task: &searchTask{
				SearchRequest: &internalpb.SearchRequest{
					Nq:         100,
					Topk:       200,
					GroupSize:  100,
					IsAdvanced: false,
				},
			},
			expectError: true,
			errorMsg:    "number of result entries is too large",
		},
		{
			name: "valid search task with sub-request",
			task: &searchTask{
				SearchRequest: &internalpb.SearchRequest{
					IsAdvanced: true,
					SubReqs: []*internalpb.SubSearchRequest{
						{
							Nq:        10,
							Topk:      100,
							GroupSize: 10,
						},
					},
				},
			},
			expectError: false,
		},
		{
			name: "invalid search task with sub-request",
			task: &searchTask{
				SearchRequest: &internalpb.SearchRequest{
					IsAdvanced: true,
					SubReqs: []*internalpb.SubSearchRequest{
						{
							Nq:        100,
							Topk:      200,
							GroupSize: 100,
						},
					},
				},
			},
			expectError: true,
			errorMsg:    "number of result entries is too large",
		},
		{
			name:        "non-search task should return nil",
			task:        "not a search task",
			expectError: false,
		},
		{
			name:        "nil task should return nil",
			task:        nil,
			expectError: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := ValidateTask(tt.task)
			if !tt.expectError {
				assert.NoError(t, err)
				return
			}
			// ErrorContains checks for a non-nil error and its message in one
			// step, so an unexpected nil error is reported as a test failure
			// instead of panicking on err.Error().
			assert.ErrorContains(t, err, tt.errorMsg)
		})
	}
}

View File

@ -1522,6 +1522,7 @@ type proxyConfig struct {
SkipPartitionKeyCheck ParamItem `refreshable:"true"`
MaxVarCharLength ParamItem `refreshable:"false"`
MaxTextLength ParamItem `refreshable:"false"`
MaxResultEntries ParamItem `refreshable:"true"`
AccessLog AccessLogConfig
@ -1943,6 +1944,17 @@ please adjust in embedded Milvus: false`,
}
p.MaxTextLength.Init(base.mgr)
p.MaxResultEntries = ParamItem{
Key: "proxy.maxResultEntries",
Version: "2.6.0",
DefaultValue: strconv.Itoa(1000000),
Doc: `maximum number of result entries, typically Nq * TopK * GroupSize.
It costs additional memory and time to process a large number of result entries.
If the number of result entries exceeds this limit, the search will be rejected.`,
Export: true,
}
p.MaxResultEntries.Init(base.mgr)
p.GracefulStopTimeout = ParamItem{
Key: "proxy.gracefulStopTimeout",
Version: "2.3.7",