mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-02-02 01:06:41 +08:00
fix: correct autoindex segment num (#28387)
Fix #28386 Current code snippet ``` // get delegator sd, ok := node.delegators.Get(channel) if !ok { err := merr.WrapErrChannelNotFound(channel) log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err)) return nil, err } req, err = node.optimizeSearchParams(ctx, req, sd) if err != nil { log.Warn("failed to optimize search params", zap.Error(err)) return nil, err } // do search results, err := sd.Search(searchCtx, req) ``` We could move these into `ShardDelegator`, and directly use sealed segment num in `Search` methods, also segment num got outside could be wrong when we specify partitions. Signed-off-by: chasingegg <chao.gao@zilliz.com>
This commit is contained in:
parent
29249c4bd3
commit
3e77365de5
@ -34,6 +34,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator/deletebuffer"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/optimizers"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
@ -109,6 +110,8 @@ type shardDelegator struct {
|
||||
loader segments.Loader
|
||||
tsCond *sync.Cond
|
||||
latestTsafe *atomic.Uint64
|
||||
// queryHook
|
||||
queryHook optimizers.QueryHook
|
||||
}
|
||||
|
||||
// getLogger returns the zap logger with pre-defined shard attributes.
|
||||
@ -229,6 +232,13 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
||||
zap.Int("sealedNum", sealedNum),
|
||||
zap.Int("growingNum", len(growing)),
|
||||
)
|
||||
|
||||
req, err = optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
|
||||
if err != nil {
|
||||
log.Warn("failed to optimize search params", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
|
||||
if err != nil {
|
||||
log.Warn("Search organizeSubTask failed", zap.Error(err))
|
||||
@ -639,7 +649,7 @@ func (sd *shardDelegator) Close() {
|
||||
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
|
||||
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
|
||||
workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader,
|
||||
factory msgstream.Factory, startTs uint64,
|
||||
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook,
|
||||
) (ShardDelegator, error) {
|
||||
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
|
||||
zap.Int64("replicaID", replicaID),
|
||||
@ -673,6 +683,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
|
||||
latestTsafe: atomic.NewUint64(startTs),
|
||||
loader: loader,
|
||||
factory: factory,
|
||||
queryHook: queryHook,
|
||||
}
|
||||
m := sync.Mutex{}
|
||||
sd.tsCond = sync.NewCond(&m)
|
||||
|
||||
@ -136,7 +136,7 @@ func (s *DelegatorDataSuite) SetupTest() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000)
|
||||
}, 10000, nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
@ -522,7 +522,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000)
|
||||
}, 10000, nil)
|
||||
s.NoError(err)
|
||||
|
||||
growing0 := segments.NewMockSegment(s.T())
|
||||
|
||||
@ -160,7 +160,7 @@ func (s *DelegatorSuite) SetupTest() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000)
|
||||
}, 10000, nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
||||
@ -21,7 +21,6 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/samber/lo"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
@ -29,13 +28,11 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
|
||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
@ -332,72 +329,6 @@ func (node *QueryNode) queryStreamSegments(ctx context.Context, req *querypb.Que
|
||||
return nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) optimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, deleg delegator.ShardDelegator) (*querypb.SearchRequest, error) {
|
||||
// no hook applied, just return
|
||||
if node.queryHook == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID()))
|
||||
|
||||
serializedPlan := req.GetReq().GetSerializedExprPlan()
|
||||
// plan not found
|
||||
if serializedPlan == nil {
|
||||
log.Warn("serialized plan not found")
|
||||
return req, merr.WrapErrParameterInvalid("serialized search plan", "nil")
|
||||
}
|
||||
|
||||
channelNum := req.GetTotalChannelNum()
|
||||
// not set, change to conservative channel num 1
|
||||
if channelNum <= 0 {
|
||||
channelNum = 1
|
||||
}
|
||||
|
||||
plan := planpb.PlanNode{}
|
||||
err := proto.Unmarshal(serializedPlan, &plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to unmarshal plan", zap.Error(err))
|
||||
return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
|
||||
}
|
||||
|
||||
switch plan.GetNode().(type) {
|
||||
case *planpb.PlanNode_VectorAnns:
|
||||
// ignore growing ones for now since they will always be brute force
|
||||
sealed, _ := deleg.GetSegmentInfo(true)
|
||||
sealedNum := lo.Reduce(sealed, func(sum int, item delegator.SnapshotItem, _ int) int {
|
||||
return sum + len(item.Segments)
|
||||
}, 0)
|
||||
// use shardNum * segments num in shard to estimate total segment number
|
||||
estSegmentNum := sealedNum * int(channelNum)
|
||||
withFilter := (plan.GetVectorAnns().GetPredicates() != nil)
|
||||
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
||||
params := map[string]any{
|
||||
common.TopKKey: queryInfo.GetTopk(),
|
||||
common.SearchParamKey: queryInfo.GetSearchParams(),
|
||||
common.SegmentNumKey: estSegmentNum,
|
||||
common.WithFilterKey: withFilter,
|
||||
common.CollectionKey: req.GetReq().GetCollectionID(),
|
||||
}
|
||||
err := node.queryHook.Run(params)
|
||||
if err != nil {
|
||||
log.Warn("failed to execute queryHook", zap.Error(err))
|
||||
return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
|
||||
}
|
||||
queryInfo.Topk = params[common.TopKKey].(int64)
|
||||
queryInfo.SearchParams = params[common.SearchParamKey].(string)
|
||||
serializedExprPlan, err := proto.Marshal(&plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to marshal optimized plan", zap.Error(err))
|
||||
return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
|
||||
}
|
||||
req.Req.SerializedExprPlan = serializedExprPlan
|
||||
log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
|
||||
default:
|
||||
log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchRequest, channel string) (*internalpb.SearchResults, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
|
||||
@ -436,11 +367,6 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
|
||||
log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
req, err = node.optimizeSearchParams(ctx, req, sd)
|
||||
if err != nil {
|
||||
log.Warn("failed to optimize search params", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
// do search
|
||||
results, err := sd.Search(searchCtx, req)
|
||||
if err != nil {
|
||||
|
||||
@ -21,22 +21,16 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/optimizers"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
@ -143,180 +137,3 @@ func (suite *HandlersSuite) TestLoadGrowingSegments() {
|
||||
func TestHandlersSuite(t *testing.T) {
|
||||
suite.Run(t, new(HandlersSuite))
|
||||
}
|
||||
|
||||
type OptimizeSearchParamSuite struct {
|
||||
suite.Suite
|
||||
// Data
|
||||
collectionID int64
|
||||
collectionName string
|
||||
segmentID int64
|
||||
channel string
|
||||
|
||||
node *QueryNode
|
||||
delegator *delegator.MockShardDelegator
|
||||
// Mock
|
||||
factory *dependency.MockFactory
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) SetupSuite() {
|
||||
suite.collectionID = 111
|
||||
suite.collectionName = "test-collection"
|
||||
suite.segmentID = 1
|
||||
suite.channel = "test-channel"
|
||||
|
||||
suite.delegator = &delegator.MockShardDelegator{}
|
||||
suite.delegator.EXPECT().GetSegmentInfo(mock.Anything).Return([]delegator.SnapshotItem{{NodeID: 1, Segments: []delegator.SegmentEntry{{SegmentID: 100}}}}, []delegator.SegmentEntry{})
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) SetupTest() {
|
||||
suite.factory = dependency.NewMockFactory(suite.T())
|
||||
suite.node = NewQueryNode(context.Background(), suite.factory)
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
suite.Run("normal_run", func() {
|
||||
mockHook := optimizers.NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(nil)
|
||||
suite.node.queryHook = mockHook
|
||||
defer func() { suite.node.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.delegator)
|
||||
suite.NoError(err)
|
||||
suite.verifyQueryInfo(req, 50, `{"param": 2}`)
|
||||
})
|
||||
|
||||
suite.Run("no_hook", func() {
|
||||
suite.node.queryHook = nil
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.delegator)
|
||||
suite.NoError(err)
|
||||
suite.verifyQueryInfo(req, 100, `{"param": 1}`)
|
||||
})
|
||||
|
||||
suite.Run("other_plannode", func() {
|
||||
mockHook := optimizers.NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(nil).Maybe()
|
||||
suite.node.queryHook = mockHook
|
||||
defer func() { suite.node.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_Query{},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.delegator)
|
||||
suite.NoError(err)
|
||||
suite.Equal(bs, req.GetReq().GetSerializedExprPlan())
|
||||
})
|
||||
|
||||
suite.Run("no_serialized_plan", func() {
|
||||
mockHook := optimizers.NewMockQueryHook(suite.T())
|
||||
suite.node.queryHook = mockHook
|
||||
defer func() { suite.node.queryHook = nil }()
|
||||
|
||||
_, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.delegator)
|
||||
suite.Error(err)
|
||||
})
|
||||
|
||||
suite.Run("hook_run_error", func() {
|
||||
mockHook := optimizers.NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(merr.WrapErrServiceInternal("mocked"))
|
||||
suite.node.queryHook = mockHook
|
||||
defer func() { suite.node.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
_, err = suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
}, suite.delegator)
|
||||
suite.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) {
|
||||
planBytes := req.GetReq().GetSerializedExprPlan()
|
||||
|
||||
plan := planpb.PlanNode{}
|
||||
err := proto.Unmarshal(planBytes, &plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
||||
suite.Equal(topK, queryInfo.GetTopk())
|
||||
suite.Equal(param, queryInfo.GetSearchParams())
|
||||
}
|
||||
|
||||
func TestOptimizeSearchParam(t *testing.T) {
|
||||
suite.Run(t, new(OptimizeSearchParamSuite))
|
||||
}
|
||||
|
||||
@ -1,5 +1,19 @@
|
||||
package optimizers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
// QueryHook is the interface for search/query parameter optimizer.
|
||||
type QueryHook interface {
|
||||
Run(map[string]any) error
|
||||
@ -7,3 +21,64 @@ type QueryHook interface {
|
||||
InitTuningConfig(map[string]string) error
|
||||
DeleteTuningConfig(string) error
|
||||
}
|
||||
|
||||
func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, queryHook QueryHook, numSegments int) (*querypb.SearchRequest, error) {
|
||||
// no hook applied, just return
|
||||
if queryHook == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID()))
|
||||
|
||||
serializedPlan := req.GetReq().GetSerializedExprPlan()
|
||||
// plan not found
|
||||
if serializedPlan == nil {
|
||||
log.Warn("serialized plan not found")
|
||||
return req, merr.WrapErrParameterInvalid("serialized search plan", "nil")
|
||||
}
|
||||
|
||||
channelNum := req.GetTotalChannelNum()
|
||||
// not set, change to conservative channel num 1
|
||||
if channelNum <= 0 {
|
||||
channelNum = 1
|
||||
}
|
||||
|
||||
plan := planpb.PlanNode{}
|
||||
err := proto.Unmarshal(serializedPlan, &plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to unmarshal plan", zap.Error(err))
|
||||
return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
|
||||
}
|
||||
|
||||
switch plan.GetNode().(type) {
|
||||
case *planpb.PlanNode_VectorAnns:
|
||||
// use shardNum * segments num in shard to estimate total segment number
|
||||
estSegmentNum := numSegments * int(channelNum)
|
||||
withFilter := (plan.GetVectorAnns().GetPredicates() != nil)
|
||||
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
||||
params := map[string]any{
|
||||
common.TopKKey: queryInfo.GetTopk(),
|
||||
common.SearchParamKey: queryInfo.GetSearchParams(),
|
||||
common.SegmentNumKey: estSegmentNum,
|
||||
common.WithFilterKey: withFilter,
|
||||
common.CollectionKey: req.GetReq().GetCollectionID(),
|
||||
}
|
||||
err := queryHook.Run(params)
|
||||
if err != nil {
|
||||
log.Warn("failed to execute queryHook", zap.Error(err))
|
||||
return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
|
||||
}
|
||||
queryInfo.Topk = params[common.TopKKey].(int64)
|
||||
queryInfo.SearchParams = params[common.SearchParamKey].(string)
|
||||
serializedExprPlan, err := proto.Marshal(&plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to marshal optimized plan", zap.Error(err))
|
||||
return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
|
||||
}
|
||||
req.Req.SerializedExprPlan = serializedExprPlan
|
||||
log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
|
||||
default:
|
||||
log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
173
internal/querynodev2/optimizers/query_hook_test.go
Normal file
173
internal/querynodev2/optimizers/query_hook_test.go
Normal file
@ -0,0 +1,173 @@
|
||||
package optimizers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type QueryHookSuite struct {
|
||||
suite.Suite
|
||||
queryHook QueryHook
|
||||
}
|
||||
|
||||
func (suite *QueryHookSuite) SetupTest() {
|
||||
}
|
||||
|
||||
func (suite *QueryHookSuite) TearDownTest() {
|
||||
suite.queryHook = nil
|
||||
}
|
||||
|
||||
func (suite *QueryHookSuite) TestOptimizeSearchParam() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
suite.Run("normal_run", func() {
|
||||
mockHook := NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(nil)
|
||||
suite.queryHook = mockHook
|
||||
defer func() { suite.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.queryHook, 2)
|
||||
suite.NoError(err)
|
||||
suite.verifyQueryInfo(req, 50, `{"param": 2}`)
|
||||
})
|
||||
|
||||
suite.Run("no_hook", func() {
|
||||
suite.queryHook = nil
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.queryHook, 2)
|
||||
suite.NoError(err)
|
||||
suite.verifyQueryInfo(req, 100, `{"param": 1}`)
|
||||
})
|
||||
|
||||
suite.Run("other_plannode", func() {
|
||||
mockHook := NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(nil).Maybe()
|
||||
suite.queryHook = mockHook
|
||||
defer func() { suite.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_Query{},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.queryHook, 2)
|
||||
suite.NoError(err)
|
||||
suite.Equal(bs, req.GetReq().GetSerializedExprPlan())
|
||||
})
|
||||
|
||||
suite.Run("no_serialized_plan", func() {
|
||||
mockHook := NewMockQueryHook(suite.T())
|
||||
suite.queryHook = mockHook
|
||||
defer func() { suite.queryHook = nil }()
|
||||
|
||||
_, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.queryHook, 2)
|
||||
suite.Error(err)
|
||||
})
|
||||
|
||||
suite.Run("hook_run_error", func() {
|
||||
mockHook := NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(merr.WrapErrServiceInternal("mocked"))
|
||||
suite.queryHook = mockHook
|
||||
defer func() { suite.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
_, err = OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
}, suite.queryHook, 2)
|
||||
suite.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) {
|
||||
planBytes := req.GetReq().GetSerializedExprPlan()
|
||||
|
||||
plan := planpb.PlanNode{}
|
||||
err := proto.Unmarshal(planBytes, &plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
||||
suite.Equal(topK, queryInfo.GetTopk())
|
||||
suite.Equal(param, queryInfo.GetSearchParams())
|
||||
}
|
||||
|
||||
func TestOptimizeSearchParam(t *testing.T) {
|
||||
suite.Run(t, new(QueryHookSuite))
|
||||
}
|
||||
@ -268,6 +268,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
|
||||
node.loader,
|
||||
node.factory,
|
||||
channel.GetSeekPosition().GetTimestamp(),
|
||||
node.queryHook,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn("failed to create shard delegator", zap.Error(err))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user