fix: [2.3] correct autoindex segment num (#28429)

issue: #28386 
pr: #28387

Signed-off-by: chasingegg <chao.gao@zilliz.com>
This commit is contained in:
Gao 2023-11-28 19:24:26 +08:00 committed by GitHub
parent fda452ea4d
commit ccca932cc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 277 additions and 262 deletions

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/delegator/deletebuffer"
"github.com/milvus-io/milvus/internal/querynodev2/optimizers"
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tsafe"
@ -106,6 +107,8 @@ type shardDelegator struct {
loader segments.Loader
tsCond *sync.Cond
latestTsafe *atomic.Uint64
// queryHook
queryHook optimizers.QueryHook
}
// getLogger returns the zap logger with pre-defined shard attributes.
@ -226,6 +229,13 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
zap.Int("sealedNum", sealedNum),
zap.Int("growingNum", len(growing)),
)
req, err = optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
if err != nil {
log.Warn("failed to optimize search params", zap.Error(err))
return nil, err
}
tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
if err != nil {
log.Warn("Search organizeSubTask failed", zap.Error(err))
@ -636,7 +646,7 @@ func (sd *shardDelegator) Close() {
// NewShardDelegator creates a new ShardDelegator instance with all fields initialized.
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader,
factory msgstream.Factory, startTs uint64,
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook,
) (ShardDelegator, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
zap.Int64("replicaID", replicaID),
@ -669,6 +679,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
latestTsafe: atomic.NewUint64(startTs),
loader: loader,
factory: factory,
queryHook: queryHook,
}
m := sync.Mutex{}
sd.tsCond = sync.NewCond(&m)

View File

@ -135,7 +135,7 @@ func (s *DelegatorDataSuite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000)
}, 10000, nil)
s.Require().NoError(err)
}

View File

@ -159,7 +159,7 @@ func (s *DelegatorSuite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil
},
}, 10000)
}, 10000, nil)
s.Require().NoError(err)
}

View File

@ -21,20 +21,17 @@ import (
"fmt"
"strconv"
"github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil"
@ -292,72 +289,6 @@ func (node *QueryNode) queryStreamSegments(ctx context.Context, req *querypb.Que
return nil
}
// optimizeSearchParams lets the node's queryHook rewrite the top-k and search
// params of a vector ANNS search plan carried by req.
//
// It returns req untouched when no hook is configured, when the plan is not a
// vector ANNS node, or when the ANNS node carries no query info. A missing
// serialized plan or one that cannot be unmarshalled/marshalled yields a
// parameter-invalid error; a failing hook, or a hook that writes back values
// of an unexpected type, yields a service-unavailable error.
//
// On success req.Req.SerializedExprPlan is replaced in place with the
// re-serialized plan and the same req pointer is returned.
func (node *QueryNode) optimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, deleg delegator.ShardDelegator) (*querypb.SearchRequest, error) {
	// no hook applied, just return
	if node.queryHook == nil {
		return req, nil
	}

	log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID()))

	serializedPlan := req.GetReq().GetSerializedExprPlan()
	// plan not found
	if serializedPlan == nil {
		log.Warn("serialized plan not found")
		return req, merr.WrapErrParameterInvalid("serialized search plan", "nil")
	}

	channelNum := req.GetTotalChannelNum()
	// not set, change to conservative channel num 1
	if channelNum <= 0 {
		channelNum = 1
	}

	plan := planpb.PlanNode{}
	if err := proto.Unmarshal(serializedPlan, &plan); err != nil {
		log.Warn("failed to unmarshal plan", zap.Error(err))
		return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
	}

	switch plan.GetNode().(type) {
	case *planpb.PlanNode_VectorAnns:
		queryInfo := plan.GetVectorAnns().GetQueryInfo()
		// The generated getters are nil-safe but the field writes below are
		// not: without query info there is nothing to optimize.
		if queryInfo == nil {
			log.Warn("query info not found in vector search plan, skip optimization")
			return req, nil
		}

		// ignore growing ones for now since they will always be brute force
		sealed, _ := deleg.GetSegmentInfo(true)
		sealedNum := lo.Reduce(sealed, func(sum int, item delegator.SnapshotItem, _ int) int {
			return sum + len(item.Segments)
		}, 0)
		// use shardNum * segments num in shard to estimate total segment number
		estSegmentNum := sealedNum * int(channelNum)
		withFilter := plan.GetVectorAnns().GetPredicates() != nil
		params := map[string]any{
			common.TopKKey:        queryInfo.GetTopk(),
			common.SearchParamKey: queryInfo.GetSearchParams(),
			common.SegmentNumKey:  estSegmentNum,
			common.WithFilterKey:  withFilter,
			common.CollectionKey:  req.GetReq().GetCollectionID(),
		}
		if err := node.queryHook.Run(params); err != nil {
			log.Warn("failed to execute queryHook", zap.Error(err))
			return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
		}

		// The hook owns the map while it runs; validate what it wrote back
		// instead of panicking on an unexpected type.
		topK, ok := params[common.TopKKey].(int64)
		if !ok {
			log.Warn("queryHook returned topk of unexpected type", zap.Any("topk", params[common.TopKKey]))
			return nil, merr.WrapErrServiceUnavailable(fmt.Sprintf("unexpected topk type %T", params[common.TopKKey]), "queryHook execution failed")
		}
		searchParams, ok := params[common.SearchParamKey].(string)
		if !ok {
			log.Warn("queryHook returned search params of unexpected type", zap.Any("searchParams", params[common.SearchParamKey]))
			return nil, merr.WrapErrServiceUnavailable(fmt.Sprintf("unexpected search params type %T", params[common.SearchParamKey]), "queryHook execution failed")
		}
		queryInfo.Topk = topK
		queryInfo.SearchParams = searchParams

		serializedExprPlan, err := proto.Marshal(&plan)
		if err != nil {
			log.Warn("failed to marshal optimized plan", zap.Error(err))
			return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
		}
		req.Req.SerializedExprPlan = serializedExprPlan
		log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
	default:
		log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
	}
	return req, nil
}
func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchRequest, channel string) (*internalpb.SearchResults, error) {
log := log.Ctx(ctx).With(
zap.Int64("msgID", req.GetReq().GetBase().GetMsgID()),
@ -397,11 +328,6 @@ func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchReq
log.Warn("Query failed, failed to get shard delegator for search", zap.Error(err))
return nil, err
}
req, err = node.optimizeSearchParams(ctx, req, sd)
if err != nil {
log.Warn("failed to optimize search params", zap.Error(err))
return nil, err
}
// do search
results, err := sd.Search(searchCtx, req)
if err != nil {

View File

@ -21,22 +21,16 @@ import (
"os"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/optimizers"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
@ -143,180 +137,3 @@ func (suite *HandlersSuite) TestLoadGrowingSegments() {
func TestHandlersSuite(t *testing.T) {
suite.Run(t, new(HandlersSuite))
}
// OptimizeSearchParamSuite groups the tests for QueryNode.optimizeSearchParams.
type OptimizeSearchParamSuite struct {
	suite.Suite

	// Fixed identifiers assigned once in SetupSuite.
	collectionID   int64
	collectionName string
	segmentID      int64
	channel        string

	// node is rebuilt for every test in SetupTest.
	node *QueryNode
	// delegator is a mocked shard delegator configured in SetupSuite.
	delegator *delegator.MockShardDelegator

	// factory is the mocked dependency factory the node is created from.
	factory *dependency.MockFactory
}

// SetupSuite assigns the fixed identifiers and prepares a mocked delegator
// whose GetSegmentInfo always answers with one sealed segment (ID 100) on
// node 1 and an empty growing list.
func (suite *OptimizeSearchParamSuite) SetupSuite() {
	suite.collectionID = 111
	suite.collectionName = "test-collection"
	suite.segmentID = 1
	suite.channel = "test-channel"

	sealed := []delegator.SnapshotItem{
		{
			NodeID:   1,
			Segments: []delegator.SegmentEntry{{SegmentID: 100}},
		},
	}
	growing := []delegator.SegmentEntry{}

	suite.delegator = &delegator.MockShardDelegator{}
	suite.delegator.EXPECT().GetSegmentInfo(mock.Anything).Return(sealed, growing)
}

// SetupTest creates a fresh mock factory and a QueryNode built on top of it.
func (suite *OptimizeSearchParamSuite) SetupTest() {
	factory := dependency.NewMockFactory(suite.T())
	suite.factory = factory
	suite.node = NewQueryNode(context.Background(), factory)
}

// TearDownTest has nothing to release.
func (suite *OptimizeSearchParamSuite) TearDownTest() {}
// TestOptimizeSearchParam drives QueryNode.optimizeSearchParams through five
// scenarios: hook rewrites the plan, no hook installed, non-ANNS plan node,
// request without a serialized plan, and a hook that returns an error.
func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// rewriteParams is the side effect attached to every mocked hook call.
	rewriteParams := func(params map[string]any) {
		params[common.TopKKey] = int64(50)
		params[common.SearchParamKey] = `{"param": 2}`
	}
	// annsPlan builds a vector ANNS plan with topk 100 and {"param": 1}.
	annsPlan := func() *planpb.PlanNode {
		info := &planpb.QueryInfo{
			Topk:         100,
			SearchParams: `{"param": 1}`,
		}
		return &planpb.PlanNode{
			Node: &planpb.PlanNode_VectorAnns{
				VectorAnns: &planpb.VectorANNS{QueryInfo: info},
			},
		}
	}
	// marshalPlan serializes plan and aborts the running sub-test on failure.
	marshalPlan := func(plan *planpb.PlanNode) []byte {
		data, err := proto.Marshal(plan)
		suite.Require().NoError(err)
		return data
	}
	// twoChannelRequest wraps planBytes into a request spanning two channels.
	twoChannelRequest := func(planBytes []byte) *querypb.SearchRequest {
		return &querypb.SearchRequest{
			Req:             &internalpb.SearchRequest{SerializedExprPlan: planBytes},
			TotalChannelNum: 2,
		}
	}
	// clearHook removes whatever hook a sub-test installed on the node.
	clearHook := func() { suite.node.queryHook = nil }

	suite.Run("normal_run", func() {
		hook := optimizers.NewMockQueryHook(suite.T())
		hook.EXPECT().Run(mock.Anything).Run(rewriteParams).Return(nil)
		suite.node.queryHook = hook
		defer clearHook()

		planBytes := marshalPlan(annsPlan())

		optimized, err := suite.node.optimizeSearchParams(ctx, twoChannelRequest(planBytes), suite.delegator)
		suite.NoError(err)
		suite.verifyQueryInfo(optimized, 50, `{"param": 2}`)
	})

	suite.Run("no_hook", func() {
		suite.node.queryHook = nil

		planBytes := marshalPlan(annsPlan())

		optimized, err := suite.node.optimizeSearchParams(ctx, twoChannelRequest(planBytes), suite.delegator)
		suite.NoError(err)
		suite.verifyQueryInfo(optimized, 100, `{"param": 1}`)
	})

	suite.Run("other_plannode", func() {
		hook := optimizers.NewMockQueryHook(suite.T())
		// Maybe(): the hook must not be required for a non-ANNS plan.
		hook.EXPECT().Run(mock.Anything).Run(rewriteParams).Return(nil).Maybe()
		suite.node.queryHook = hook
		defer clearHook()

		planBytes := marshalPlan(&planpb.PlanNode{
			Node: &planpb.PlanNode_Query{},
		})

		optimized, err := suite.node.optimizeSearchParams(ctx, twoChannelRequest(planBytes), suite.delegator)
		suite.NoError(err)
		suite.Equal(planBytes, optimized.GetReq().GetSerializedExprPlan())
	})

	suite.Run("no_serialized_plan", func() {
		hook := optimizers.NewMockQueryHook(suite.T())
		suite.node.queryHook = hook
		defer clearHook()

		request := &querypb.SearchRequest{
			Req:             &internalpb.SearchRequest{},
			TotalChannelNum: 2,
		}
		_, err := suite.node.optimizeSearchParams(ctx, request, suite.delegator)
		suite.Error(err)
	})

	suite.Run("hook_run_error", func() {
		hook := optimizers.NewMockQueryHook(suite.T())
		hook.EXPECT().Run(mock.Anything).Run(rewriteParams).Return(merr.WrapErrServiceInternal("mocked"))
		suite.node.queryHook = hook
		defer clearHook()

		planBytes := marshalPlan(annsPlan())

		// TotalChannelNum is deliberately left unset here.
		request := &querypb.SearchRequest{
			Req: &internalpb.SearchRequest{SerializedExprPlan: planBytes},
		}
		_, err := suite.node.optimizeSearchParams(ctx, request, suite.delegator)
		suite.Error(err)
	})
}
// verifyQueryInfo decodes the serialized plan carried by req and asserts that
// its vector ANNS query info holds the expected top-k and search params.
// Decoding failure aborts the test.
func (suite *OptimizeSearchParamSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) {
	var decoded planpb.PlanNode
	suite.Require().NoError(proto.Unmarshal(req.GetReq().GetSerializedExprPlan(), &decoded))

	info := decoded.GetVectorAnns().GetQueryInfo()
	suite.Equal(topK, info.GetTopk())
	suite.Equal(param, info.GetSearchParams())
}
func TestOptimizeSearchParam(t *testing.T) {
suite.Run(t, new(OptimizeSearchParamSuite))
}

View File

@ -1,5 +1,19 @@
package optimizers
import (
"context"
"fmt"
"github.com/golang/protobuf/proto"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
)
// QueryHook is the interface for search/query parameter optimizer.
type QueryHook interface {
Run(map[string]any) error
@ -7,3 +21,64 @@ type QueryHook interface {
InitTuningConfig(map[string]string) error
DeleteTuningConfig(string) error
}
// OptimizeSearchParams lets queryHook rewrite the top-k and search params of
// the vector ANNS search plan carried by req.
//
// numSegments is multiplied by the request's total channel number (treated as
// 1 when unset or non-positive) to estimate the segment count handed to the
// hook.
//
// It returns req untouched when queryHook is nil, when the plan is not a
// vector ANNS node, or when the ANNS node carries no query info. A missing
// serialized plan or one that cannot be unmarshalled/marshalled yields a
// parameter-invalid error; a failing hook, or a hook that writes back values
// of an unexpected type, yields a service-unavailable error.
//
// On success req.Req.SerializedExprPlan is replaced in place with the
// re-serialized plan and the same req pointer is returned.
func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, queryHook QueryHook, numSegments int) (*querypb.SearchRequest, error) {
	// no hook applied, just return
	if queryHook == nil {
		return req, nil
	}

	log := log.Ctx(ctx).With(zap.Int64("collection", req.GetReq().GetCollectionID()))

	serializedPlan := req.GetReq().GetSerializedExprPlan()
	// plan not found
	if serializedPlan == nil {
		log.Warn("serialized plan not found")
		return req, merr.WrapErrParameterInvalid("serialized search plan", "nil")
	}

	channelNum := req.GetTotalChannelNum()
	// not set, change to conservative channel num 1
	if channelNum <= 0 {
		channelNum = 1
	}

	plan := planpb.PlanNode{}
	if err := proto.Unmarshal(serializedPlan, &plan); err != nil {
		log.Warn("failed to unmarshal plan", zap.Error(err))
		return nil, merr.WrapErrParameterInvalid("valid serialized search plan", "no unmarshalable one", err.Error())
	}

	switch plan.GetNode().(type) {
	case *planpb.PlanNode_VectorAnns:
		queryInfo := plan.GetVectorAnns().GetQueryInfo()
		// The generated getters are nil-safe but the field writes below are
		// not: without query info there is nothing to optimize.
		if queryInfo == nil {
			log.Warn("query info not found in vector search plan, skip optimization")
			return req, nil
		}

		// use shardNum * segments num in shard to estimate total segment number
		estSegmentNum := numSegments * int(channelNum)
		withFilter := plan.GetVectorAnns().GetPredicates() != nil
		params := map[string]any{
			common.TopKKey:        queryInfo.GetTopk(),
			common.SearchParamKey: queryInfo.GetSearchParams(),
			common.SegmentNumKey:  estSegmentNum,
			common.WithFilterKey:  withFilter,
			common.CollectionKey:  req.GetReq().GetCollectionID(),
		}
		if err := queryHook.Run(params); err != nil {
			log.Warn("failed to execute queryHook", zap.Error(err))
			return nil, merr.WrapErrServiceUnavailable(err.Error(), "queryHook execution failed")
		}

		// The hook owns the map while it runs; validate what it wrote back
		// instead of panicking on an unexpected type.
		topK, ok := params[common.TopKKey].(int64)
		if !ok {
			log.Warn("queryHook returned topk of unexpected type", zap.Any("topk", params[common.TopKKey]))
			return nil, merr.WrapErrServiceUnavailable(fmt.Sprintf("unexpected topk type %T", params[common.TopKKey]), "queryHook execution failed")
		}
		searchParams, ok := params[common.SearchParamKey].(string)
		if !ok {
			log.Warn("queryHook returned search params of unexpected type", zap.Any("searchParams", params[common.SearchParamKey]))
			return nil, merr.WrapErrServiceUnavailable(fmt.Sprintf("unexpected search params type %T", params[common.SearchParamKey]), "queryHook execution failed")
		}
		queryInfo.Topk = topK
		queryInfo.SearchParams = searchParams

		serializedExprPlan, err := proto.Marshal(&plan)
		if err != nil {
			log.Warn("failed to marshal optimized plan", zap.Error(err))
			return nil, merr.WrapErrParameterInvalid("marshalable search plan", "plan with marshal error", err.Error())
		}
		req.Req.SerializedExprPlan = serializedExprPlan
		log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo))
	default:
		log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode())))
	}
	return req, nil
}

View File

@ -0,0 +1,173 @@
package optimizers
import (
"context"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/merr"
)
// QueryHookSuite groups the tests for OptimizeSearchParams.
type QueryHookSuite struct {
	suite.Suite

	// queryHook is the hook a sub-test installs before calling
	// OptimizeSearchParams; it is cleared again after every test.
	queryHook QueryHook
}

// SetupTest has nothing to prepare.
func (suite *QueryHookSuite) SetupTest() {}

// TearDownTest drops any hook left behind by the finished test.
func (suite *QueryHookSuite) TearDownTest() {
	suite.queryHook = nil
}
// TestOptimizeSearchParam drives OptimizeSearchParams through five scenarios:
// hook rewrites the plan, no hook installed, non-ANNS plan node, request
// without a serialized plan, and a hook that returns an error.
func (suite *QueryHookSuite) TestOptimizeSearchParam() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// rewriteParams is the side effect attached to every mocked hook call.
	rewriteParams := func(params map[string]any) {
		params[common.TopKKey] = int64(50)
		params[common.SearchParamKey] = `{"param": 2}`
	}
	// annsPlan builds a vector ANNS plan with topk 100 and {"param": 1}.
	annsPlan := func() *planpb.PlanNode {
		info := &planpb.QueryInfo{
			Topk:         100,
			SearchParams: `{"param": 1}`,
		}
		return &planpb.PlanNode{
			Node: &planpb.PlanNode_VectorAnns{
				VectorAnns: &planpb.VectorANNS{QueryInfo: info},
			},
		}
	}
	// marshalPlan serializes plan and aborts the running sub-test on failure.
	marshalPlan := func(plan *planpb.PlanNode) []byte {
		data, err := proto.Marshal(plan)
		suite.Require().NoError(err)
		return data
	}
	// twoChannelRequest wraps planBytes into a request spanning two channels.
	twoChannelRequest := func(planBytes []byte) *querypb.SearchRequest {
		return &querypb.SearchRequest{
			Req:             &internalpb.SearchRequest{SerializedExprPlan: planBytes},
			TotalChannelNum: 2,
		}
	}
	// clearHook removes whatever hook a sub-test installed on the suite.
	clearHook := func() { suite.queryHook = nil }

	suite.Run("normal_run", func() {
		hook := NewMockQueryHook(suite.T())
		hook.EXPECT().Run(mock.Anything).Run(rewriteParams).Return(nil)
		suite.queryHook = hook
		defer clearHook()

		planBytes := marshalPlan(annsPlan())

		optimized, err := OptimizeSearchParams(ctx, twoChannelRequest(planBytes), suite.queryHook, 2)
		suite.NoError(err)
		suite.verifyQueryInfo(optimized, 50, `{"param": 2}`)
	})

	suite.Run("no_hook", func() {
		suite.queryHook = nil

		planBytes := marshalPlan(annsPlan())

		optimized, err := OptimizeSearchParams(ctx, twoChannelRequest(planBytes), suite.queryHook, 2)
		suite.NoError(err)
		suite.verifyQueryInfo(optimized, 100, `{"param": 1}`)
	})

	suite.Run("other_plannode", func() {
		hook := NewMockQueryHook(suite.T())
		// Maybe(): the hook must not be required for a non-ANNS plan.
		hook.EXPECT().Run(mock.Anything).Run(rewriteParams).Return(nil).Maybe()
		suite.queryHook = hook
		defer clearHook()

		planBytes := marshalPlan(&planpb.PlanNode{
			Node: &planpb.PlanNode_Query{},
		})

		optimized, err := OptimizeSearchParams(ctx, twoChannelRequest(planBytes), suite.queryHook, 2)
		suite.NoError(err)
		suite.Equal(planBytes, optimized.GetReq().GetSerializedExprPlan())
	})

	suite.Run("no_serialized_plan", func() {
		hook := NewMockQueryHook(suite.T())
		suite.queryHook = hook
		defer clearHook()

		request := &querypb.SearchRequest{
			Req:             &internalpb.SearchRequest{},
			TotalChannelNum: 2,
		}
		_, err := OptimizeSearchParams(ctx, request, suite.queryHook, 2)
		suite.Error(err)
	})

	suite.Run("hook_run_error", func() {
		hook := NewMockQueryHook(suite.T())
		hook.EXPECT().Run(mock.Anything).Run(rewriteParams).Return(merr.WrapErrServiceInternal("mocked"))
		suite.queryHook = hook
		defer clearHook()

		planBytes := marshalPlan(annsPlan())

		// TotalChannelNum is deliberately left unset here.
		request := &querypb.SearchRequest{
			Req: &internalpb.SearchRequest{SerializedExprPlan: planBytes},
		}
		_, err := OptimizeSearchParams(ctx, request, suite.queryHook, 2)
		suite.Error(err)
	})
}
// verifyQueryInfo decodes the serialized plan carried by req and asserts that
// its vector ANNS query info holds the expected top-k and search params.
// Decoding failure aborts the test.
func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, param string) {
	var decoded planpb.PlanNode
	suite.Require().NoError(proto.Unmarshal(req.GetReq().GetSerializedExprPlan(), &decoded))

	info := decoded.GetVectorAnns().GetQueryInfo()
	suite.Equal(topK, info.GetTopk())
	suite.Equal(param, info.GetSearchParams())
}
func TestOptimizeSearchParam(t *testing.T) {
suite.Run(t, new(QueryHookSuite))
}

View File

@ -254,8 +254,21 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
node.composeIndexMeta(req.GetIndexInfoList(), req.Schema), req.GetLoadMeta())
collection := node.manager.Collection.Get(req.GetCollectionID())
collection.SetMetricType(req.GetLoadMeta().GetMetricType())
delegator, err := delegator.NewShardDelegator(ctx, req.GetCollectionID(), req.GetReplicaID(), channel.GetChannelName(), req.GetVersion(),
node.clusterManager, node.manager, node.tSafeManager, node.loader, node.factory, channel.GetSeekPosition().GetTimestamp())
delegator, err := delegator.NewShardDelegator(
ctx,
req.GetCollectionID(),
req.GetReplicaID(),
channel.GetChannelName(),
req.GetVersion(),
node.clusterManager,
node.manager,
node.tSafeManager,
node.loader,
node.factory,
channel.GetSeekPosition().GetTimestamp(),
node.queryHook,
)
if err != nil {
log.Warn("failed to create shard delegator", zap.Error(err))
return merr.Status(err), nil