diff --git a/internal/querynodev2/optimizers/query_hook.go b/internal/querynodev2/optimizers/query_hook.go index faaf990d1d..1aabb4adcc 100644 --- a/internal/querynodev2/optimizers/query_hook.go +++ b/internal/querynodev2/optimizers/query_hook.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) // QueryHook is the interface for search/query parameter optimizer. @@ -23,8 +24,8 @@ type QueryHook interface { } func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, queryHook QueryHook, numSegments int) (*querypb.SearchRequest, error) { - // no hook applied, just return - if queryHook == nil { + // no hook applied or disabled, just return + if queryHook == nil || !paramtable.Get().AutoIndexConfig.Enable.GetAsBool() { return req, nil } diff --git a/internal/querynodev2/optimizers/query_hook_test.go b/internal/querynodev2/optimizers/query_hook_test.go index 132619b5e3..df97557f88 100644 --- a/internal/querynodev2/optimizers/query_hook_test.go +++ b/internal/querynodev2/optimizers/query_hook_test.go @@ -13,6 +13,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" ) type QueryHookSuite struct { @@ -30,15 +31,20 @@ func (suite *QueryHookSuite) TearDownTest() { func (suite *QueryHookSuite) TestOptimizeSearchParam() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + paramtable.Init() suite.Run("normal_run", func() { + paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") mockHook := NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` }).Return(nil) suite.queryHook = mockHook - defer func() { suite.queryHook = nil }() + defer func() { + paramtable.Get().Reset(paramtable.Get().AutoIndexConfig.Enable.Key) + suite.queryHook = nil + }() plan := &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{ @@ -63,7 +69,37 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { suite.verifyQueryInfo(req, 50, `{"param": 2}`) }) + suite.Run("disable optimization", func() { + mockHook := NewMockQueryHook(suite.T()) + suite.queryHook = mockHook + defer func() { suite.queryHook = nil }() + + plan := &planpb.PlanNode{ + Node: &planpb.PlanNode_VectorAnns{ + VectorAnns: &planpb.VectorANNS{ + QueryInfo: &planpb.QueryInfo{ + Topk: 100, + SearchParams: `{"param": 1}`, + }, + }, + }, + } + bs, err := proto.Marshal(plan) + suite.Require().NoError(err) + + req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + SerializedExprPlan: bs, + }, + TotalChannelNum: 2, + }, suite.queryHook, 2) + suite.NoError(err) + suite.verifyQueryInfo(req, 100, `{"param": 1}`) + }) + suite.Run("no_hook", func() { + paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") + defer paramtable.Get().Reset(paramtable.Get().AutoIndexConfig.Enable.Key) suite.queryHook = nil plan := &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{ @@ -89,13 +125,17 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { }) suite.Run("other_plannode", func() { + paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") mockHook := NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` }).Return(nil).Maybe() suite.queryHook = mockHook - defer func() { suite.queryHook = nil }() + defer func() { + paramtable.Get().Reset(paramtable.Get().AutoIndexConfig.Enable.Key) + suite.queryHook = nil + }() plan := &planpb.PlanNode{ Node: &planpb.PlanNode_Query{}, @@ -114,6 +154,8 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { }) suite.Run("no_serialized_plan", func() { + paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") + defer paramtable.Get().Reset(paramtable.Get().AutoIndexConfig.Enable.Key) mockHook := NewMockQueryHook(suite.T()) suite.queryHook = mockHook defer func() { suite.queryHook = nil }() @@ -126,13 +168,17 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { }) suite.Run("hook_run_error", func() { + paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") mockHook := NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` }).Return(merr.WrapErrServiceInternal("mocked")) suite.queryHook = mockHook - defer func() { suite.queryHook = nil }() + defer func() { + paramtable.Get().Reset(paramtable.Get().AutoIndexConfig.Enable.Key) + suite.queryHook = nil + }() plan := &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{