diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 1d936d847c..ac9eb9c60a 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2801,6 +2801,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest) node: node, lb: node.lbPolicy, enableMaterializedView: node.enableMaterializedView, + mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(), } guaranteeTs := request.GuaranteeTimestamp @@ -2997,11 +2998,12 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea ), ReqID: paramtable.GetNodeID(), }, - request: newSearchReq, - tr: timerecord.NewTimeRecorder(method), - qc: node.queryCoord, - node: node, - lb: node.lbPolicy, + request: newSearchReq, + tr: timerecord.NewTimeRecorder(method), + qc: node.queryCoord, + node: node, + lb: node.lbPolicy, + mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(), } guaranteeTs := request.GuaranteeTimestamp @@ -3411,9 +3413,10 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* ), ReqID: paramtable.GetNodeID(), }, - request: request, - qc: node.queryCoord, - lb: node.lbPolicy, + request: request, + qc: node.queryCoord, + lb: node.lbPolicy, + mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(), } res, err := node.query(ctx, qt) if merr.Ok(res.Status) && err == nil { diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 679c86b658..ebbfd7e2d5 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -1590,6 +1590,14 @@ func TestProxy(t *testing.T) { resp, err := proxy.Search(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + + { + Params.Save(Params.ProxyCfg.MustUsePartitionKey.Key, "true") + resp, err := proxy.Search(ctx, req) + assert.NoError(t, err) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + Params.Reset(Params.ProxyCfg.MustUsePartitionKey.Key) + } }) constructAdvancedSearchRequest := func() *milvuspb.SearchRequest { diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 54a9898e9e..e418c77ce8 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -212,6 +212,12 @@ func (t *createCollectionTask) validatePartitionKey() error { } } + mustPartitionKey := Params.ProxyCfg.MustUsePartitionKey.GetAsBool() + if mustPartitionKey && idx == -1 { + return merr.WrapErrParameterInvalidMsg("partition key must be set when creating the collection" + + " because the mustUsePartitionKey config is true") + } + if idx == -1 { if t.GetNumPartitions() != 0 { return fmt.Errorf("num_partitions should only be specified with partition key field enabled") diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 3f8d3cbc82..77536e874f 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -112,6 +112,13 @@ func (it *insertTask) PreExecute(ctx context.Context) error { return err } + maxInsertSize := Params.QuotaConfig.MaxInsertSize.GetAsInt() + if maxInsertSize != -1 && it.insertMsg.Size() > maxInsertSize { + log.Warn("insert request size exceeds maxInsertSize", + zap.Int("request size", it.insertMsg.Size()), zap.Int("maxInsertSize", maxInsertSize)) + return merr.WrapErrParameterTooLarge("insert request size exceeds maxInsertSize") + } + schema, err := globalMetaCache.GetCollectionSchema(ctx, it.insertMsg.GetDbName(), collectionName) if err != nil { log.Warn("get collection schema from global meta cache failed", zap.String("collectionName", collectionName), zap.Error(err)) diff --git a/internal/proxy/task_insert_test.go b/internal/proxy/task_insert_test.go index 4f7ef10af2..083398ea98 100644 --- a/internal/proxy/task_insert_test.go +++ b/internal/proxy/task_insert_test.go @@ -11,6 +11,8 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) func TestInsertTask_CheckAligned(t *testing.T) { @@ -285,3 +287,23 @@ func TestInsertTask(t *testing.T) { assert.ElementsMatch(t, channels, it.pChannels) }) } + +func TestMaxInsertSize(t *testing.T) { + t.Run("test MaxInsertSize", func(t *testing.T) { + paramtable.Init() + Params.Save(Params.QuotaConfig.MaxInsertSize.Key, "1") + defer Params.Reset(Params.QuotaConfig.MaxInsertSize.Key) + it := insertTask{ + ctx: context.Background(), + insertMsg: &msgstream.InsertMsg{ + InsertRequest: msgpb.InsertRequest{ + DbName: "hooooooo", + CollectionName: "fooooo", + }, + }, + } + err := it.PreExecute(context.Background()) + assert.Error(t, err) + assert.ErrorIs(t, err, merr.ErrParameterTooLarge) + }) +} diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index ff7f886df6..618805a4f9 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -68,6 +68,7 @@ type queryTask struct { reQuery bool allQueryCnt int64 totalRelatedDataSize int64 + mustUsePartitionKey bool } type queryParams struct { @@ -303,6 +304,10 @@ func (t *queryTask) PreExecute(ctx context.Context) error { if t.partitionKeyMode && len(t.request.GetPartitionNames()) != 0 { return errors.New("not support manually specifying the partition names if partition key mode is used") } + if t.mustUsePartitionKey && !t.partitionKeyMode { + return merr.WrapErrParameterInvalidMsg("must use partition key in the query request " + + "because the mustUsePartitionKey config is true") + } for _, tag := range t.request.PartitionNames { if err := validatePartitionTag(tag, false); err != nil { diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 99d6bae869..fddf9e42c4 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -170,6 +170,15 @@ func TestQueryTask_all(t *testing.T) { assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp) task.ctx = ctx1 assert.NoError(t, task.PreExecute(ctx)) + + { + task.mustUsePartitionKey = true + err := task.PreExecute(ctx) + assert.Error(t, err) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + task.mustUsePartitionKey = false + } + // after preExecute assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp) diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 019e0b7ae8..2863c19d88 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -61,6 +61,7 @@ type searchTask struct { requery bool partitionKeyMode bool enableMaterializedView bool + mustUsePartitionKey bool userOutputFields []string @@ -135,6 +136,10 @@ func (t *searchTask) PreExecute(ctx context.Context) error { if t.partitionKeyMode && len(t.request.GetPartitionNames()) != 0 { return errors.New("not support manually specifying the partition names if partition key mode is used") } + if t.mustUsePartitionKey && !t.partitionKeyMode { + return merr.WrapErrParameterInvalidMsg("must use partition key in the search request " + + "because the mustUsePartitionKey config is true") + } if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 { // translate partition name to partition ids. Use regex-pattern to match partition name. diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 55c860781a..042d4d4ef7 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -328,6 +328,14 @@ func TestSearchTask_PreExecute(t *testing.T) { assert.NoError(t, task.PreExecute(ctx)) assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp) + { + task.mustUsePartitionKey = true + err = task.PreExecute(ctx) + assert.Error(t, err) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + task.mustUsePartitionKey = false + } + // field not exist task.ctx = context.TODO() task.request.OutputFields = []string{testInt64Field + funcutil.GenRandomStr()} diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index 70dc12ddaa..01250dc0ac 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -696,6 +696,12 @@ func TestCreateCollectionTask(t *testing.T) { err = task.PreExecute(ctx) assert.NoError(t, err) + Params.Save(Params.ProxyCfg.MustUsePartitionKey.Key, "true") + err = task.PreExecute(ctx) + assert.Error(t, err) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + Params.Reset(Params.ProxyCfg.MustUsePartitionKey.Key) + task.Schema = []byte{0x1, 0x2, 0x3, 0x4} err = task.PreExecute(ctx) assert.Error(t, err) diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index 6790ccd6b7..630f4d02be 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -109,8 +109,9 @@ var ( ErrIoUnexpectEOF = newMilvusError("unexpected EOF", 1002, true) // Parameter related - ErrParameterInvalid = newMilvusError("invalid parameter", 1100, false) - ErrParameterMissing = newMilvusError("missing parameter", 1101, false) + ErrParameterInvalid = newMilvusError("invalid parameter", 1100, false) + ErrParameterMissing = newMilvusError("missing parameter", 1101, false) + ErrParameterTooLarge = newMilvusError("parameter too large", 1102, false) // Metrics related ErrMetricNotFound = newMilvusError("metric not found", 1200, false) diff --git a/pkg/util/merr/errors_test.go b/pkg/util/merr/errors_test.go index ff55a70b12..51bc6d854b 100644 --- a/pkg/util/merr/errors_test.go +++ b/pkg/util/merr/errors_test.go @@ -136,6 +136,7 @@ func (s *ErrSuite) TestWrap() { s.ErrorIs(WrapErrParameterInvalid(8, 1, "failed to create"), ErrParameterInvalid) s.ErrorIs(WrapErrParameterInvalidRange(1, 1<<16, 0, "topk should be in range"), ErrParameterInvalid) s.ErrorIs(WrapErrParameterMissing("alias_name", "no alias parameter"), ErrParameterMissing) + s.ErrorIs(WrapErrParameterTooLarge("unit test"), ErrParameterTooLarge) // Metrics related s.ErrorIs(WrapErrMetricNotFound("unknown", "failed to get metric"), ErrMetricNotFound) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index 7846d503ce..397aa9cbab 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -863,6 +863,14 @@ func WrapErrParameterMissing[T any](param T, msg ...string) error { return err } +func WrapErrParameterTooLarge(name string, msg ...string) error { + err := wrapFields(ErrParameterTooLarge, value("message", name)) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "->")) + } + return err +} + // Metrics related func WrapErrMetricNotFound(name string, msg ...string) error { err := wrapFields(ErrMetricNotFound, value("metric", name)) diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 6694a600c8..0e2b62129b 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -1011,6 +1011,7 @@ type proxyConfig struct { RetryTimesOnReplica ParamItem `refreshable:"true"` RetryTimesOnHealthCheck ParamItem `refreshable:"true"` PartitionNameRegexp ParamItem `refreshable:"true"` + MustUsePartitionKey ParamItem `refreshable:"true"` AccessLog AccessLogConfig @@ -1338,6 +1339,15 @@ please adjust in embedded Milvus: false`, } p.PartitionNameRegexp.Init(base.mgr) + p.MustUsePartitionKey = ParamItem{ + Key: "proxy.mustUsePartitionKey", + Version: "2.4.1", + DefaultValue: "false", + Doc: "switch for whether proxy must use partition key for the collection", + Export: true, + } + p.MustUsePartitionKey.Init(base.mgr) + p.GracefulStopTimeout = ParamItem{ Key: "proxy.gracefulStopTimeout", Version: "2.3.7", diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index 3f4b90277d..5627e6696c 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -170,6 +170,10 @@ func TestComponentParam(t *testing.T) { params.Save("proxy.gracefulStopTimeout", "100") assert.Equal(t, 100*time.Second, Params.GracefulStopTimeout.GetAsDuration(time.Second)) + + assert.False(t, Params.MustUsePartitionKey.GetAsBool()) + params.Save("proxy.mustUsePartitionKey", "true") + assert.True(t, Params.MustUsePartitionKey.GetAsBool()) }) // t.Run("test proxyConfig panic", func(t *testing.T) { diff --git a/pkg/util/paramtable/quota_param.go b/pkg/util/paramtable/quota_param.go index 8c108f1ac5..33996a7b77 100644 --- a/pkg/util/paramtable/quota_param.go +++ b/pkg/util/paramtable/quota_param.go @@ -128,6 +128,7 @@ type quotaConfig struct { NQLimit ParamItem `refreshable:"true"` MaxQueryResultWindow ParamItem `refreshable:"true"` MaxOutputSize ParamItem `refreshable:"true"` + MaxInsertSize ParamItem `refreshable:"true"` MaxResourceGroupNumOfQueryNode ParamItem `refreshable:"true"` // limit writing @@ -1537,6 +1538,15 @@ Check https://milvus.io/docs/limitations.md for more details.`, } p.MaxOutputSize.Init(base.mgr) + p.MaxInsertSize = ParamItem{ + Key: "quotaAndLimits.limits.maxInsertSize", + Version: "2.4.1", + DefaultValue: "-1", // -1 means no limit, the unit is byte + Doc: `maximum size of a single insert request, in bytes, -1 means no limit`, + Export: true, + } + p.MaxInsertSize.Init(base.mgr) + p.MaxResourceGroupNumOfQueryNode = ParamItem{ Key: "quotaAndLimits.limits.maxResourceGroupNumOfQueryNode", Version: "2.4.1", diff --git a/pkg/util/paramtable/quota_param_test.go b/pkg/util/paramtable/quota_param_test.go index 0ca8f7006c..14f589a348 100644 --- a/pkg/util/paramtable/quota_param_test.go +++ b/pkg/util/paramtable/quota_param_test.go @@ -175,11 +175,16 @@ func TestQuotaParam(t *testing.T) { }) t.Run("test limits", func(t *testing.T) { + params.Init(NewBaseTable(SkipRemote(true))) assert.Equal(t, 65536, qc.MaxCollectionNum.GetAsInt()) assert.Equal(t, 65536, qc.MaxCollectionNumPerDB.GetAsInt()) assert.Equal(t, 1024, params.QuotaConfig.MaxResourceGroupNumOfQueryNode.GetAsInt()) params.Save(params.QuotaConfig.MaxResourceGroupNumOfQueryNode.Key, "512") assert.Equal(t, 512, params.QuotaConfig.MaxResourceGroupNumOfQueryNode.GetAsInt()) + + assert.Equal(t, -1, qc.MaxInsertSize.GetAsInt()) + baseParams.Save(params.QuotaConfig.MaxInsertSize.Key, "1024") + assert.Equal(t, 1024, qc.MaxInsertSize.GetAsInt()) }) t.Run("test limit writing", func(t *testing.T) {