mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
feat: limit search result entries (#42522)
See: #42521 Signed-off-by: Ted Xu <ted.xu@zilliz.com>
This commit is contained in:
parent
c262f987db
commit
35c17523de
@ -304,6 +304,10 @@ proxy:
|
||||
ddlConcurrency: 16 # The concurrent execution number of DDL at proxy.
|
||||
dclConcurrency: 16 # The concurrent execution number of DCL at proxy.
|
||||
mustUsePartitionKey: false # switch for whether proxy must use partition key for the collection
|
||||
# maximum number of result entries, typically Nq * TopK * GroupSize.
|
||||
# It costs additional memory and time to process a large number of result entries.
|
||||
# If the number of result entries exceeds this limit, the search will be rejected.
|
||||
maxResultEntries: 1000000
|
||||
accessLog:
|
||||
enable: false # Whether to enable the access log feature.
|
||||
minioEnable: false # Whether to upload local access log files to MinIO. This parameter can be specified when proxy.accessLog.filename is not empty.
|
||||
|
||||
@ -278,6 +278,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||
|
||||
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()
|
||||
|
||||
if err = ValidateTask(t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug("search PreExecute done.",
|
||||
zap.Uint64("guarantee_ts", guaranteeTs),
|
||||
zap.Bool("use_default_consistency", useDefaultConsistency),
|
||||
|
||||
@ -776,6 +776,34 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("reject large num of result entries", func(t *testing.T) {
|
||||
collName := "test_large_num_of_result_entries" + funcutil.GenRandomStr()
|
||||
createColl(t, collName, qc)
|
||||
|
||||
task := getSearchTask(t, collName)
|
||||
task.SearchRequest.Nq = 1000
|
||||
task.SearchRequest.Topk = 1001
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
task.SearchRequest.Nq = 100
|
||||
task.SearchRequest.Topk = 100
|
||||
task.SearchRequest.GroupSize = 200
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
|
||||
task.SearchRequest.IsAdvanced = true
|
||||
task.SearchRequest.SubReqs = []*internalpb.SubSearchRequest{
|
||||
{
|
||||
Topk: 100,
|
||||
Nq: 100,
|
||||
GroupSize: 200,
|
||||
},
|
||||
}
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("collection not exist", func(t *testing.T) {
|
||||
collName := "test_collection_not_exist" + funcutil.GenRandomStr()
|
||||
task := getSearchTask(t, collName)
|
||||
|
||||
69
internal/proxy/task_validator.go
Normal file
69
internal/proxy/task_validator.go
Normal file
@ -0,0 +1,69 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
// validator is a generic interface for validating tasks
|
||||
type validator[T any] interface {
|
||||
validate(request T) error
|
||||
}
|
||||
|
||||
// searchTaskValidator validates search tasks
|
||||
type searchTaskValidator struct{}
|
||||
|
||||
var searchTaskValidatorInstance validator[*searchTask] = &searchTaskValidator{}
|
||||
|
||||
func (v *searchTaskValidator) validateSubSearch(subReq *internalpb.SubSearchRequest) error {
|
||||
// check if number of result entries is too large
|
||||
nEntries := subReq.GetNq() * subReq.GetTopk()
|
||||
// if there is group size, multiply it
|
||||
if subReq.GetGroupSize() > 0 {
|
||||
nEntries *= subReq.GroupSize
|
||||
}
|
||||
if nEntries > paramtable.Get().ProxyCfg.MaxResultEntries.GetAsInt64() {
|
||||
return fmt.Errorf("number of result entries is too large")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *searchTaskValidator) validateSearch(search *searchTask) error {
|
||||
// check if number of result entries is too large
|
||||
nEntries := search.GetNq() * search.GetTopk()
|
||||
// if there is group size, multiply it
|
||||
if search.GetGroupSize() > 0 {
|
||||
nEntries *= search.GroupSize
|
||||
}
|
||||
if nEntries > paramtable.Get().ProxyCfg.MaxResultEntries.GetAsInt64() {
|
||||
return fmt.Errorf("number of result entries is too large")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *searchTaskValidator) validate(search *searchTask) error {
|
||||
// if it is a hybrid search, check all sub-searches
|
||||
if search.SearchRequest.GetIsAdvanced() {
|
||||
for _, subReq := range search.SearchRequest.GetSubReqs() {
|
||||
if err := v.validateSubSearch(subReq); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := v.validateSearch(search); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateTask(task any) error {
|
||||
switch t := task.(type) {
|
||||
case *searchTask:
|
||||
return searchTaskValidatorInstance.validate(t)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
99
internal/proxy/task_validator_test.go
Normal file
99
internal/proxy/task_validator_test.go
Normal file
@ -0,0 +1,99 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
func TestValidateTask(t *testing.T) {
|
||||
// Initialize paramtable for testing
|
||||
paramtable.Init()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
task any
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid search task",
|
||||
task: &searchTask{
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Nq: 10,
|
||||
Topk: 100,
|
||||
IsAdvanced: false,
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid search task",
|
||||
task: &searchTask{
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Nq: 1000,
|
||||
Topk: 2000,
|
||||
IsAdvanced: false,
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "number of result entries is too large",
|
||||
},
|
||||
{
|
||||
name: "invalid search task with group size",
|
||||
task: &searchTask{
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Nq: 100,
|
||||
Topk: 200,
|
||||
GroupSize: 100,
|
||||
IsAdvanced: false,
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "number of result entries is too large",
|
||||
},
|
||||
{
|
||||
name: "invalid search task with sub-request",
|
||||
task: &searchTask{
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
IsAdvanced: true,
|
||||
SubReqs: []*internalpb.SubSearchRequest{
|
||||
{
|
||||
Nq: 100,
|
||||
Topk: 200,
|
||||
GroupSize: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "number of result entries is too large",
|
||||
},
|
||||
{
|
||||
name: "non-search task should return nil",
|
||||
task: "not a search task",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "nil task should return nil",
|
||||
task: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateTask(tt.task)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1522,6 +1522,7 @@ type proxyConfig struct {
|
||||
SkipPartitionKeyCheck ParamItem `refreshable:"true"`
|
||||
MaxVarCharLength ParamItem `refreshable:"false"`
|
||||
MaxTextLength ParamItem `refreshable:"false"`
|
||||
MaxResultEntries ParamItem `refreshable:"true"`
|
||||
|
||||
AccessLog AccessLogConfig
|
||||
|
||||
@ -1943,6 +1944,17 @@ please adjust in embedded Milvus: false`,
|
||||
}
|
||||
p.MaxTextLength.Init(base.mgr)
|
||||
|
||||
p.MaxResultEntries = ParamItem{
|
||||
Key: "proxy.maxResultEntries",
|
||||
Version: "2.6.0",
|
||||
DefaultValue: strconv.Itoa(1000000),
|
||||
Doc: `maximum number of result entries, typically Nq * TopK * GroupSize.
|
||||
It costs additional memory and time to process a large number of result entries.
|
||||
If the number of result entries exceeds this limit, the search will be rejected.`,
|
||||
Export: true,
|
||||
}
|
||||
p.MaxResultEntries.Init(base.mgr)
|
||||
|
||||
p.GracefulStopTimeout = ParamItem{
|
||||
Key: "proxy.gracefulStopTimeout",
|
||||
Version: "2.3.7",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user