mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
fix: [2.3] correct autoindex segment num (#28429)
issue: #28386 pr: #28387 Signed-off-by: chasingegg <chao.gao@zilliz.com>
This commit is contained in:
parent
fda452ea4d
commit
ccca932cc6
@ -34,6 +34,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator/deletebuffer"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/optimizers"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
|
||||
@ -106,6 +107,8 @@ type shardDelegator struct {
|
||||
loader segments.Loader
|
||||
tsCond *sync.Cond
|
||||
latestTsafe *atomic.Uint64
|
||||
// queryHook
|
||||
queryHook optimizers.QueryHook
|
||||
}
|
||||
|
||||
// getLogger returns the zap logger with pre-defined shard attributes.
|
||||
@ -226,6 +229,13 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
||||
zap.Int("sealedNum", sealedNum),
|
||||
zap.Int("growingNum", len(growing)),
|
||||
)
|
||||
|
||||
req, err = optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
|
||||
if err != nil {
|
||||
log.Warn("failed to optimize search params", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
|
||||
if err != nil {
|
||||
log.Warn("Search organizeSubTask failed", zap.Error(err))
|
||||
@ -636,7 +646,7 @@ func (sd *shardDelegator) Close() {
|
||||
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
|
||||
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
|
||||
workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader,
|
||||
factory msgstream.Factory, startTs uint64,
|
||||
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook,
|
||||
) (ShardDelegator, error) {
|
||||
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
|
||||
zap.Int64("replicaID", replicaID),
|
||||
@ -669,6 +679,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
|
||||
latestTsafe: atomic.NewUint64(startTs),
|
||||
loader: loader,
|
||||
factory: factory,
|
||||
queryHook: queryHook,
|
||||
}
|
||||
m := sync.Mutex{}
|
||||
sd.tsCond = sync.NewCond(&m)
|
||||
|
||||
@ -135,7 +135,7 @@ func (s *DelegatorDataSuite) SetupTest() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000)
|
||||
}, 10000, nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
||||
@ -159,7 +159,7 @@ func (s *DelegatorSuite) SetupTest() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000)
|
||||
}, 10000, nil)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
|
||||
@ -21,20 +21,17 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/samber/lo"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/segments"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
|
||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
@ -292,72 +289,6 @@ func (node *QueryNode) queryStreamSegments(ctx context.Context, req *querypb.Que
|
||||
return nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) optimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, deleg delegator.ShardDelegator) (*querypb.SearchRequest, error) {
|
||||
// no hook applied, just return
|
||||
if node.queryHook == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID()))
|
||||
|
||||
serializedPlan := req.GetReq().GetSerializedExprPlan()
|
||||
// plan not found
|
||||
if serializedPlan == nil {
|
||||
log.Warn("serialized plan not found")
|
||||
return req, merr.WrapErrParameterInvalid("serialized search plan", "nil")
|
||||
}
|
||||
|
||||
channelNum := req.GetTotalChannelNum()
|
||||
// not set, change to conservative channel num 1
|
||||
if channelNum <= 0 {
|
||||
channelNum = 1
|
||||
}
|
||||
|
||||
plan := planpb.PlanNode{}
|
||||
err := proto.Unmarshal(serializedPlan, &plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to unmarshal plan", zap.Error(err))
|
||||
return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
|
||||
}
|
||||
|
||||
switch plan.GetNode().(type) {
|
||||
case *planpb.PlanNode_VectorAnns:
|
||||
// ignore growing ones for now since they will always be brute force
|
||||
sealed, _ := deleg.GetSegmentInfo(true)
|
||||
sealedNum := lo.Reduce(sealed, func(sum int, item delegator.SnapshotItem, _ int) int {
|
||||
return sum + len(item.Segments)
|
||||
}, 0)
|
||||
// use shardNum * segments num in shard to estimate total segment number
|
||||
estSegmentNum := sealedNum * int(channelNum)
|
||||
withFilter := (plan.GetVectorAnns().GetPredicates() != nil)
|
||||
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
||||
params := map[string]any{
|
||||
common.TopKKey: queryInfo.GetTopk(),
|
||||
common.SearchParamKey: queryInfo.GetSearchParams(),
|
||||
common.SegmentNumKey: estSegmentNum,
|
||||
common.WithFilterKey: withFilter,
|
||||
common.CollectionKey: req.GetReq().GetCollectionID(),
|
||||
}
|
||||
err := node.queryHook.Run(params)
|
||||
if err != nil {
|
||||
log.Warn("failed to execute queryHook", zap.Error(err))
|
||||
return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
|
||||
}
|
||||
queryInfo.Topk = params[common.TopKKey].(int64)
|
||||
queryInfo.SearchParams = params[common.SearchParamKey].(string)
|
||||
serializedExprPlan, err := proto.Marshal(&plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to marshal optimized plan", zap.Error(err))
|
||||
return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
|
||||
}
|
||||
req.Req.SerializedExprPlan = serializedExprPlan
|
||||
log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
|
||||
default:
|
||||
log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchRequest, channel string) (*internalpb.SearchResults, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
|
||||
@ -397,11 +328,6 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
|
||||
log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
req, err = node.optimizeSearchParams(ctx, req, sd)
|
||||
if err != nil {
|
||||
log.Warn("failed to optimize search params", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
// do search
|
||||
results, err := sd.Search(searchCtx, req)
|
||||
if err != nil {
|
||||
|
||||
@ -21,22 +21,16 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/datapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/optimizers"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/etcd"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
)
|
||||
|
||||
@ -143,180 +137,3 @@ func (suite *HandlersSuite) TestLoadGrowingSegments() {
|
||||
func TestHandlersSuite(t *testing.T) {
|
||||
suite.Run(t, new(HandlersSuite))
|
||||
}
|
||||
|
||||
type OptimizeSearchParamSuite struct {
|
||||
suite.Suite
|
||||
// Data
|
||||
collectionID int64
|
||||
collectionName string
|
||||
segmentID int64
|
||||
channel string
|
||||
|
||||
node *QueryNode
|
||||
delegator *delegator.MockShardDelegator
|
||||
// Mock
|
||||
factory *dependency.MockFactory
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) SetupSuite() {
|
||||
suite.collectionID = 111
|
||||
suite.collectionName = "test-collection"
|
||||
suite.segmentID = 1
|
||||
suite.channel = "test-channel"
|
||||
|
||||
suite.delegator = &delegator.MockShardDelegator{}
|
||||
suite.delegator.EXPECT().GetSegmentInfo(mock.Anything).Return([]delegator.SnapshotItem{{NodeID: 1, Segments: []delegator.SegmentEntry{{SegmentID: 100}}}}, []delegator.SegmentEntry{})
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) SetupTest() {
|
||||
suite.factory = dependency.NewMockFactory(suite.T())
|
||||
suite.node = NewQueryNode(context.Background(), suite.factory)
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) TearDownTest() {
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
suite.Run("normal_run", func() {
|
||||
mockHook := optimizers.NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(nil)
|
||||
suite.node.queryHook = mockHook
|
||||
defer func() { suite.node.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.delegator)
|
||||
suite.NoError(err)
|
||||
suite.verifyQueryInfo(req, 50, `{"param": 2}`)
|
||||
})
|
||||
|
||||
suite.Run("no_hook", func() {
|
||||
suite.node.queryHook = nil
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.delegator)
|
||||
suite.NoError(err)
|
||||
suite.verifyQueryInfo(req, 100, `{"param": 1}`)
|
||||
})
|
||||
|
||||
suite.Run("other_plannode", func() {
|
||||
mockHook := optimizers.NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(nil).Maybe()
|
||||
suite.node.queryHook = mockHook
|
||||
defer func() { suite.node.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_Query{},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.delegator)
|
||||
suite.NoError(err)
|
||||
suite.Equal(bs, req.GetReq().GetSerializedExprPlan())
|
||||
})
|
||||
|
||||
suite.Run("no_serialized_plan", func() {
|
||||
mockHook := optimizers.NewMockQueryHook(suite.T())
|
||||
suite.node.queryHook = mockHook
|
||||
defer func() { suite.node.queryHook = nil }()
|
||||
|
||||
_, err := suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.delegator)
|
||||
suite.Error(err)
|
||||
})
|
||||
|
||||
suite.Run("hook_run_error", func() {
|
||||
mockHook := optimizers.NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(merr.WrapErrServiceInternal("mocked"))
|
||||
suite.node.queryHook = mockHook
|
||||
defer func() { suite.node.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
_, err = suite.node.optimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
}, suite.delegator)
|
||||
suite.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *OptimizeSearchParamSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) {
|
||||
planBytes := req.GetReq().GetSerializedExprPlan()
|
||||
|
||||
plan := planpb.PlanNode{}
|
||||
err := proto.Unmarshal(planBytes, &plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
||||
suite.Equal(topK, queryInfo.GetTopk())
|
||||
suite.Equal(param, queryInfo.GetSearchParams())
|
||||
}
|
||||
|
||||
func TestOptimizeSearchParam(t *testing.T) {
|
||||
suite.Run(t, new(OptimizeSearchParamSuite))
|
||||
}
|
||||
|
||||
@ -1,5 +1,19 @@
|
||||
package optimizers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/log"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
// QueryHook is the interface for search/query parameter optimizer.
|
||||
type QueryHook interface {
|
||||
Run(map[string]any) error
|
||||
@ -7,3 +21,64 @@ type QueryHook interface {
|
||||
InitTuningConfig(map[string]string) error
|
||||
DeleteTuningConfig(string) error
|
||||
}
|
||||
|
||||
func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, queryHook QueryHook, numSegments int) (*querypb.SearchRequest, error) {
|
||||
// no hook applied, just return
|
||||
if queryHook == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID()))
|
||||
|
||||
serializedPlan := req.GetReq().GetSerializedExprPlan()
|
||||
// plan not found
|
||||
if serializedPlan == nil {
|
||||
log.Warn("serialized plan not found")
|
||||
return req, merr.WrapErrParameterInvalid("serialized search plan", "nil")
|
||||
}
|
||||
|
||||
channelNum := req.GetTotalChannelNum()
|
||||
// not set, change to conservative channel num 1
|
||||
if channelNum <= 0 {
|
||||
channelNum = 1
|
||||
}
|
||||
|
||||
plan := planpb.PlanNode{}
|
||||
err := proto.Unmarshal(serializedPlan, &plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to unmarshal plan", zap.Error(err))
|
||||
return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
|
||||
}
|
||||
|
||||
switch plan.GetNode().(type) {
|
||||
case *planpb.PlanNode_VectorAnns:
|
||||
// use shardNum * segments num in shard to estimate total segment number
|
||||
estSegmentNum := numSegments * int(channelNum)
|
||||
withFilter := (plan.GetVectorAnns().GetPredicates() != nil)
|
||||
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
||||
params := map[string]any{
|
||||
common.TopKKey: queryInfo.GetTopk(),
|
||||
common.SearchParamKey: queryInfo.GetSearchParams(),
|
||||
common.SegmentNumKey: estSegmentNum,
|
||||
common.WithFilterKey: withFilter,
|
||||
common.CollectionKey: req.GetReq().GetCollectionID(),
|
||||
}
|
||||
err := queryHook.Run(params)
|
||||
if err != nil {
|
||||
log.Warn("failed to execute queryHook", zap.Error(err))
|
||||
return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
|
||||
}
|
||||
queryInfo.Topk = params[common.TopKKey].(int64)
|
||||
queryInfo.SearchParams = params[common.SearchParamKey].(string)
|
||||
serializedExprPlan, err := proto.Marshal(&plan)
|
||||
if err != nil {
|
||||
log.Warn("failed to marshal optimized plan", zap.Error(err))
|
||||
return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
|
||||
}
|
||||
req.Req.SerializedExprPlan = serializedExprPlan
|
||||
log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
|
||||
default:
|
||||
log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
173
internal/querynodev2/optimizers/query_hook_test.go
Normal file
173
internal/querynodev2/optimizers/query_hook_test.go
Normal file
@ -0,0 +1,173 @@
|
||||
package optimizers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/planpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
)
|
||||
|
||||
type QueryHookSuite struct {
|
||||
suite.Suite
|
||||
queryHook QueryHook
|
||||
}
|
||||
|
||||
func (suite *QueryHookSuite) SetupTest() {
|
||||
}
|
||||
|
||||
func (suite *QueryHookSuite) TearDownTest() {
|
||||
suite.queryHook = nil
|
||||
}
|
||||
|
||||
func (suite *QueryHookSuite) TestOptimizeSearchParam() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
suite.Run("normal_run", func() {
|
||||
mockHook := NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(nil)
|
||||
suite.queryHook = mockHook
|
||||
defer func() { suite.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.queryHook, 2)
|
||||
suite.NoError(err)
|
||||
suite.verifyQueryInfo(req, 50, `{"param": 2}`)
|
||||
})
|
||||
|
||||
suite.Run("no_hook", func() {
|
||||
suite.queryHook = nil
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.queryHook, 2)
|
||||
suite.NoError(err)
|
||||
suite.verifyQueryInfo(req, 100, `{"param": 1}`)
|
||||
})
|
||||
|
||||
suite.Run("other_plannode", func() {
|
||||
mockHook := NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(nil).Maybe()
|
||||
suite.queryHook = mockHook
|
||||
defer func() { suite.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_Query{},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.queryHook, 2)
|
||||
suite.NoError(err)
|
||||
suite.Equal(bs, req.GetReq().GetSerializedExprPlan())
|
||||
})
|
||||
|
||||
suite.Run("no_serialized_plan", func() {
|
||||
mockHook := NewMockQueryHook(suite.T())
|
||||
suite.queryHook = mockHook
|
||||
defer func() { suite.queryHook = nil }()
|
||||
|
||||
_, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{},
|
||||
TotalChannelNum: 2,
|
||||
}, suite.queryHook, 2)
|
||||
suite.Error(err)
|
||||
})
|
||||
|
||||
suite.Run("hook_run_error", func() {
|
||||
mockHook := NewMockQueryHook(suite.T())
|
||||
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
|
||||
params[common.TopKKey] = int64(50)
|
||||
params[common.SearchParamKey] = `{"param": 2}`
|
||||
}).Return(merr.WrapErrServiceInternal("mocked"))
|
||||
suite.queryHook = mockHook
|
||||
defer func() { suite.queryHook = nil }()
|
||||
|
||||
plan := &planpb.PlanNode{
|
||||
Node: &planpb.PlanNode_VectorAnns{
|
||||
VectorAnns: &planpb.VectorANNS{
|
||||
QueryInfo: &planpb.QueryInfo{
|
||||
Topk: 100,
|
||||
SearchParams: `{"param": 1}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
bs, err := proto.Marshal(plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
_, err = OptimizeSearchParams(ctx, &querypb.SearchRequest{
|
||||
Req: &internalpb.SearchRequest{
|
||||
SerializedExprPlan: bs,
|
||||
},
|
||||
}, suite.queryHook, 2)
|
||||
suite.Error(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) {
|
||||
planBytes := req.GetReq().GetSerializedExprPlan()
|
||||
|
||||
plan := planpb.PlanNode{}
|
||||
err := proto.Unmarshal(planBytes, &plan)
|
||||
suite.Require().NoError(err)
|
||||
|
||||
queryInfo := plan.GetVectorAnns().GetQueryInfo()
|
||||
suite.Equal(topK, queryInfo.GetTopk())
|
||||
suite.Equal(param, queryInfo.GetSearchParams())
|
||||
}
|
||||
|
||||
func TestOptimizeSearchParam(t *testing.T) {
|
||||
suite.Run(t, new(QueryHookSuite))
|
||||
}
|
||||
@ -254,8 +254,21 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
|
||||
node.composeIndexMeta(req.GetIndexInfoList(), req.Schema), req.GetLoadMeta())
|
||||
collection := node.manager.Collection.Get(req.GetCollectionID())
|
||||
collection.SetMetricType(req.GetLoadMeta().GetMetricType())
|
||||
delegator, err := delegator.NewShardDelegator(ctx, req.GetCollectionID(), req.GetReplicaID(), channel.GetChannelName(), req.GetVersion(),
|
||||
node.clusterManager, node.manager, node.tSafeManager, node.loader, node.factory, channel.GetSeekPosition().GetTimestamp())
|
||||
|
||||
delegator, err := delegator.NewShardDelegator(
|
||||
ctx,
|
||||
req.GetCollectionID(),
|
||||
req.GetReplicaID(),
|
||||
channel.GetChannelName(),
|
||||
req.GetVersion(),
|
||||
node.clusterManager,
|
||||
node.manager,
|
||||
node.tSafeManager,
|
||||
node.loader,
|
||||
node.factory,
|
||||
channel.GetSeekPosition().GetTimestamp(),
|
||||
node.queryHook,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn("failed to create shard delegator", zap.Error(err))
|
||||
return merr.Status(err), nil
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user