diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 45c6233706..81970e6566 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/delegator/deletebuffer" + "github.com/milvus-io/milvus/internal/querynodev2/optimizers" "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" @@ -106,6 +107,8 @@ type shardDelegator struct { loader segments.Loader tsCond *sync.Cond latestTsafe *atomic.Uint64 + // queryHook + queryHook optimizers.QueryHook } // getLogger returns the zap logger with pre-defined shard attributes. @@ -226,6 +229,13 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest zap.Int("sealedNum", sealedNum), zap.Int("growingNum", len(growing)), ) + + req, err = optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum) + if err != nil { + log.Warn("failed to optimize search params", zap.Error(err)) + return nil, err + } + tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest) if err != nil { log.Warn("Search organizeSubTask failed", zap.Error(err)) @@ -636,7 +646,7 @@ func (sd *shardDelegator) Close() { // NewShardDelegator creates a new ShardDelegator instance with all fields initialized. func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64, workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader, - factory msgstream.Factory, startTs uint64, + factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, ) (ShardDelegator, error) { log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID), zap.Int64("replicaID", replicaID), @@ -669,6 +679,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni latestTsafe: atomic.NewUint64(startTs), loader: loader, factory: factory, + queryHook: queryHook, } m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) diff --git a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index 2d93cac83b..4f7eac04b4 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -135,7 +135,7 @@ func (s *DelegatorDataSuite) SetupTest() { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { return s.mq, nil }, - }, 10000) + }, 10000, nil) s.Require().NoError(err) } diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index bedc51e44d..a8f6e4bae7 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -159,7 +159,7 @@ func (s *DelegatorSuite) SetupTest() { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { return s.mq, nil }, - }, 10000) + }, 10000, nil) s.Require().NoError(err) } diff --git a/internal/querynodev2/handlers.go b/internal/querynodev2/handlers.go index 8f734acea8..0e5bfe754b 100644 --- a/internal/querynodev2/handlers.go +++ b/internal/querynodev2/handlers.go @@ -21,20 +21,17 @@ import ( "fmt" "strconv" - "github.com/golang/protobuf/proto" "github.com/samber/lo" "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tasks" "github.com/milvus-io/milvus/internal/util/streamrpc" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -292,72 +289,6 @@ func (node *QueryNode) queryStreamSegments(ctx context.Context, req *querypb.Que return nil } -func (node *QueryNode) optimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, deleg delegator.ShardDelegator) (*querypb.SearchRequest, error) { - // no hook applied, just return - if node.queryHook == nil { - return req, nil - } - - log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID())) - - serializedPlan := req.GetReq().GetSerializedExprPlan() - // plan not found - if serializedPlan == nil { - log.Warn("serialized plan not found") - return req, merr.WrapErrParameterInvalid("serialized search plan", "nil") - } - - channelNum := req.GetTotalChannelNum() - // not set, change to conservative channel num 1 - if channelNum <= 0 { - channelNum = 1 - } - - plan := planpb.PlanNode{} - err := proto.Unmarshal(serializedPlan, &plan) - if err != nil { - log.Warn("failed to unmarshal plan", zap.Error(err)) - return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error()) - } - - switch plan.GetNode().(type) { - case *planpb.PlanNode_VectorAnns: - // ignore growing ones for now since they will always be brute force - sealed, _ := deleg.GetSegmentInfo(true) - sealedNum := lo.Reduce(sealed, func(sum int, item delegator.SnapshotItem, _ int) int { - return sum + len(item.Segments) - }, 0) - // use shardNum * segments num in shard to estimate total segment number - estSegmentNum := sealedNum * int(channelNum) - withFilter := (plan.GetVectorAnns().GetPredicates() != nil) - queryInfo := plan.GetVectorAnns().GetQueryInfo() - params := map[string]any{ - common.TopKKey: queryInfo.GetTopk(), - common.SearchParamKey: queryInfo.GetSearchParams(), - common.SegmentNumKey: estSegmentNum, - common.WithFilterKey: withFilter, - common.CollectionKey: req.GetReq().GetCollectionID(), - } - err := node.queryHook.Run(params) - if err != nil { - log.Warn("failed to execute queryHook", zap.Error(err)) - return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed") - } - queryInfo.Topk = params[common.TopKKey].(int64) - queryInfo.SearchParams = params[common.SearchParamKey].(string) - serializedExprPlan, err := proto.Marshal(&plan) - if err != nil { - log.Warn("failed to marshal optimized plan", zap.Error(err)) - return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error()) - } - req.Req.SerializedExprPlan = serializedExprPlan - log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo)) - default: - log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode()))) - } - return req, nil -} - func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchRequest, channel string) (*internalpb.SearchResults, error) { log := log.Ctx(ctx).With( zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()), @@ -397,11 +328,6 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err)) return nil, err } - req, err = node.optimizeSearchParams(ctx, req, sd) - if err != nil { - log.Warn("failed to optimize search params", zap.Error(err)) - return nil, err - } // do search results, err := sd.Search(searchCtx, req) if err != nil { diff --git a/internal/querynodev2/handlers_test.go b/internal/querynodev2/handlers_test.go index 0088d6f531..a9b2ed0f1d 100644 --- a/internal/querynodev2/handlers_test.go +++ b/internal/querynodev2/handlers_test.go @@ -21,22 +21,16 @@ import ( "os" "testing" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" - "github.com/milvus-io/milvus/internal/querynodev2/optimizers" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/etcd" - "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -143,180 +137,3 @@ func (suite *HandlersSuite) TestLoadGrowingSegments() { func TestHandlersSuite(t *testing.T) { suite.Run(t, new(HandlersSuite)) } - -type OptimizeSearchParamSuite struct { - suite.Suite - // Data - collectionID int64 - collectionName string - segmentID int64 - channel string - - node *QueryNode - delegator *delegator.MockShardDelegator - // Mock - factory *dependency.MockFactory -} - -func (suite *OptimizeSearchParamSuite) SetupSuite() { - suite.collectionID = 111 - suite.collectionName = "test-collection" - suite.segmentID = 1 - suite.channel = "test-channel" - - suite.delegator = &delegator.MockShardDelegator{} - suite.delegator.EXPECT().GetSegmentInfo(mock.Anything).Return([]delegator.SnapshotItem{{NodeID: 1, Segments: []delegator.SegmentEntry{{SegmentID: 100}}}}, []delegator.SegmentEntry{}) -} - -func (suite *OptimizeSearchParamSuite) SetupTest() { - suite.factory = dependency.NewMockFactory(suite.T()) - suite.node = NewQueryNode(context.Background(), suite.factory) -} - -func (suite *OptimizeSearchParamSuite) TearDownTest() { -} - -func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - suite.Run("normal_run", func() { - mockHook := optimizers.NewMockQueryHook(suite.T()) - mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { - params[common.TopKKey] = int64(50) - params[common.SearchParamKey] = `{"param": 2}` - }).Return(nil) - suite.node.queryHook = mockHook - defer func() { suite.node.queryHook = nil }() - - plan := &planpb.PlanNode{ - Node: &planpb.PlanNode_VectorAnns{ - VectorAnns: &planpb.VectorANNS{ - QueryInfo: &planpb.QueryInfo{ - Topk: 100, - SearchParams: `{"param": 1}`, - }, - }, - }, - } - bs, err := proto.Marshal(plan) - suite.Require().NoError(err) - - req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{ - Req: &internalpb.SearchRequest{ - SerializedExprPlan: bs, - }, - TotalChannelNum: 2, - }, suite.delegator) - suite.NoError(err) - suite.verifyQueryInfo(req, 50, `{"param": 2}`) - }) - - suite.Run("no_hook", func() { - suite.node.queryHook = nil - plan := &planpb.PlanNode{ - Node: &planpb.PlanNode_VectorAnns{ - VectorAnns: &planpb.VectorANNS{ - QueryInfo: &planpb.QueryInfo{ - Topk: 100, - SearchParams: `{"param": 1}`, - }, - }, - }, - } - bs, err := proto.Marshal(plan) - suite.Require().NoError(err) - - req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{ - Req: &internalpb.SearchRequest{ - SerializedExprPlan: bs, - }, - TotalChannelNum: 2, - }, suite.delegator) - suite.NoError(err) - suite.verifyQueryInfo(req, 100, `{"param": 1}`) - }) - - suite.Run("other_plannode", func() { - mockHook := optimizers.NewMockQueryHook(suite.T()) - mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { - params[common.TopKKey] = int64(50) - params[common.SearchParamKey] = `{"param": 2}` - }).Return(nil).Maybe() - suite.node.queryHook = mockHook - defer func() { suite.node.queryHook = nil }() - - plan := &planpb.PlanNode{ - Node: &planpb.PlanNode_Query{}, - } - bs, err := proto.Marshal(plan) - suite.Require().NoError(err) - - req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{ - Req: &internalpb.SearchRequest{ - SerializedExprPlan: bs, - }, - TotalChannelNum: 2, - }, suite.delegator) - suite.NoError(err) - suite.Equal(bs, req.GetReq().GetSerializedExprPlan()) - }) - - suite.Run("no_serialized_plan", func() { - mockHook := optimizers.NewMockQueryHook(suite.T()) - suite.node.queryHook = mockHook - defer func() { suite.node.queryHook = nil }() - - _, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{ - Req: &internalpb.SearchRequest{}, - TotalChannelNum: 2, - }, suite.delegator) - suite.Error(err) - }) - - suite.Run("hook_run_error", func() { - mockHook := optimizers.NewMockQueryHook(suite.T()) - mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { - params[common.TopKKey] = int64(50) - params[common.SearchParamKey] = `{"param": 2}` - }).Return(merr.WrapErrServiceInternal("mocked")) - suite.node.queryHook = mockHook - defer func() { suite.node.queryHook = nil }() - - plan := &planpb.PlanNode{ - Node: &planpb.PlanNode_VectorAnns{ - VectorAnns: &planpb.VectorANNS{ - QueryInfo: &planpb.QueryInfo{ - Topk: 100, - SearchParams: `{"param": 1}`, - }, - }, - }, - } - bs, err := proto.Marshal(plan) - suite.Require().NoError(err) - - _, err = suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{ - Req: &internalpb.SearchRequest{ - SerializedExprPlan: bs, - }, - }, suite.delegator) - suite.Error(err) - }) -} - -func (suite *OptimizeSearchParamSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) { - planBytes := req.GetReq().GetSerializedExprPlan() - - plan := planpb.PlanNode{} - err := proto.Unmarshal(planBytes, &plan) - suite.Require().NoError(err) - - queryInfo := plan.GetVectorAnns().GetQueryInfo() - suite.Equal(topK, queryInfo.GetTopk()) - suite.Equal(param, queryInfo.GetSearchParams()) -} - -func TestOptimizeSearchParam(t *testing.T) { - suite.Run(t, new(OptimizeSearchParamSuite)) -} diff --git a/internal/querynodev2/optimizers/query_hook.go b/internal/querynodev2/optimizers/query_hook.go index c3703feba1..faaf990d1d 100644 --- a/internal/querynodev2/optimizers/query_hook.go +++ b/internal/querynodev2/optimizers/query_hook.go @@ -1,5 +1,19 @@ package optimizers +import ( + "context" + "fmt" + + "github.com/golang/protobuf/proto" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" +) + // QueryHook is the interface for search/query parameter optimizer. type QueryHook interface { Run(map[string]any) error @@ -7,3 +21,64 @@ type QueryHook interface { InitTuningConfig(map[string]string) error DeleteTuningConfig(string) error } + +func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, queryHook QueryHook, numSegments int) (*querypb.SearchRequest, error) { + // no hook applied, just return + if queryHook == nil { + return req, nil + } + + log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID())) + + serializedPlan := req.GetReq().GetSerializedExprPlan() + // plan not found + if serializedPlan == nil { + log.Warn("serialized plan not found") + return req, merr.WrapErrParameterInvalid("serialized search plan", "nil") + } + + channelNum := req.GetTotalChannelNum() + // not set, change to conservative channel num 1 + if channelNum <= 0 { + channelNum = 1 + } + + plan := planpb.PlanNode{} + err := proto.Unmarshal(serializedPlan, &plan) + if err != nil { + log.Warn("failed to unmarshal plan", zap.Error(err)) + return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error()) + } + + switch plan.GetNode().(type) { + case *planpb.PlanNode_VectorAnns: + // use shardNum * segments num in shard to estimate total segment number + estSegmentNum := numSegments * int(channelNum) + withFilter := (plan.GetVectorAnns().GetPredicates() != nil) + queryInfo := plan.GetVectorAnns().GetQueryInfo() + params := map[string]any{ + common.TopKKey: queryInfo.GetTopk(), + common.SearchParamKey: queryInfo.GetSearchParams(), + common.SegmentNumKey: estSegmentNum, + common.WithFilterKey: withFilter, + common.CollectionKey: req.GetReq().GetCollectionID(), + } + err := queryHook.Run(params) + if err != nil { + log.Warn("failed to execute queryHook", zap.Error(err)) + return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed") + } + queryInfo.Topk = params[common.TopKKey].(int64) + queryInfo.SearchParams = params[common.SearchParamKey].(string) + serializedExprPlan, err := proto.Marshal(&plan) + if err != nil { + log.Warn("failed to marshal optimized plan", zap.Error(err)) + return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error()) + } + req.Req.SerializedExprPlan = serializedExprPlan + log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo)) + default: + log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode()))) + } + return req, nil +} diff --git a/internal/querynodev2/optimizers/query_hook_test.go b/internal/querynodev2/optimizers/query_hook_test.go new file mode 100644 index 0000000000..132619b5e3 --- /dev/null +++ b/internal/querynodev2/optimizers/query_hook_test.go @@ -0,0 +1,173 @@ +package optimizers + +import ( + "context" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type QueryHookSuite struct { + suite.Suite + queryHook QueryHook +} + +func (suite *QueryHookSuite) SetupTest() { +} + +func (suite *QueryHookSuite) TearDownTest() { + suite.queryHook = nil +} + +func (suite *QueryHookSuite) TestOptimizeSearchParam() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + suite.Run("normal_run", func() { + mockHook := NewMockQueryHook(suite.T()) + mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { + params[common.TopKKey] = int64(50) + params[common.SearchParamKey] = `{"param": 2}` + }).Return(nil) + suite.queryHook = mockHook + defer func() { suite.queryHook = nil }() + + plan := &planpb.PlanNode{ + Node: &planpb.PlanNode_VectorAnns{ + VectorAnns: &planpb.VectorANNS{ + QueryInfo: &planpb.QueryInfo{ + Topk: 100, + SearchParams: `{"param": 1}`, + }, + }, + }, + } + bs, err := proto.Marshal(plan) + suite.Require().NoError(err) + + req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + SerializedExprPlan: bs, + }, + TotalChannelNum: 2, + }, suite.queryHook, 2) + suite.NoError(err) + suite.verifyQueryInfo(req, 50, `{"param": 2}`) + }) + + suite.Run("no_hook", func() { + suite.queryHook = nil + plan := &planpb.PlanNode{ + Node: &planpb.PlanNode_VectorAnns{ + VectorAnns: &planpb.VectorANNS{ + QueryInfo: &planpb.QueryInfo{ + Topk: 100, + SearchParams: `{"param": 1}`, + }, + }, + }, + } + bs, err := proto.Marshal(plan) + suite.Require().NoError(err) + + req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + SerializedExprPlan: bs, + }, + TotalChannelNum: 2, + }, suite.queryHook, 2) + suite.NoError(err) + suite.verifyQueryInfo(req, 100, `{"param": 1}`) + }) + + suite.Run("other_plannode", func() { + mockHook := NewMockQueryHook(suite.T()) + mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { + params[common.TopKKey] = int64(50) + params[common.SearchParamKey] = `{"param": 2}` + }).Return(nil).Maybe() + suite.queryHook = mockHook + defer func() { suite.queryHook = nil }() + + plan := &planpb.PlanNode{ + Node: &planpb.PlanNode_Query{}, + } + bs, err := proto.Marshal(plan) + suite.Require().NoError(err) + + req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + SerializedExprPlan: bs, + }, + TotalChannelNum: 2, + }, suite.queryHook, 2) + suite.NoError(err) + suite.Equal(bs, req.GetReq().GetSerializedExprPlan()) + }) + + suite.Run("no_serialized_plan", func() { + mockHook := NewMockQueryHook(suite.T()) + suite.queryHook = mockHook + defer func() { suite.queryHook = nil }() + + _, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{}, + TotalChannelNum: 2, + }, suite.queryHook, 2) + suite.Error(err) + }) + + suite.Run("hook_run_error", func() { + mockHook := NewMockQueryHook(suite.T()) + mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { + params[common.TopKKey] = int64(50) + params[common.SearchParamKey] = `{"param": 2}` + }).Return(merr.WrapErrServiceInternal("mocked")) + suite.queryHook = mockHook + defer func() { suite.queryHook = nil }() + + plan := &planpb.PlanNode{ + Node: &planpb.PlanNode_VectorAnns{ + VectorAnns: &planpb.VectorANNS{ + QueryInfo: &planpb.QueryInfo{ + Topk: 100, + SearchParams: `{"param": 1}`, + }, + }, + }, + } + bs, err := proto.Marshal(plan) + suite.Require().NoError(err) + + _, err = OptimizeSearchParams(ctx, &querypb.SearchRequest{ + Req: &internalpb.SearchRequest{ + SerializedExprPlan: bs, + }, + }, suite.queryHook, 2) + suite.Error(err) + }) +} + +func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) { + planBytes := req.GetReq().GetSerializedExprPlan() + + plan := planpb.PlanNode{} + err := proto.Unmarshal(planBytes, &plan) + suite.Require().NoError(err) + + queryInfo := plan.GetVectorAnns().GetQueryInfo() + suite.Equal(topK, queryInfo.GetTopk()) + suite.Equal(param, queryInfo.GetSearchParams()) +} + +func TestOptimizeSearchParam(t *testing.T) { + suite.Run(t, new(QueryHookSuite)) +} diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index ef5d483813..6209806f1a 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -254,8 +254,21 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm node.composeIndexMeta(req.GetIndexInfoList(), req.Schema), req.GetLoadMeta()) collection := node.manager.Collection.Get(req.GetCollectionID()) collection.SetMetricType(req.GetLoadMeta().GetMetricType()) - delegator, err := delegator.NewShardDelegator(ctx, req.GetCollectionID(), req.GetReplicaID(), channel.GetChannelName(), req.GetVersion(), - node.clusterManager, node.manager, node.tSafeManager, node.loader, node.factory, channel.GetSeekPosition().GetTimestamp()) + + delegator, err := delegator.NewShardDelegator( + ctx, + req.GetCollectionID(), + req.GetReplicaID(), + channel.GetChannelName(), + req.GetVersion(), + node.clusterManager, + node.manager, + node.tSafeManager, + node.loader, + node.factory, + channel.GetSeekPosition().GetTimestamp(), + node.queryHook, + ) if err != nil { log.Warn("failed to create shard delegator", zap.Error(err)) return merr.Status(err), nil