fix: Fix exclude nodes clearing logic position in load balancer retry (#42577)

issue: #42561
Move the exclude-nodes clearing logic from ExecuteWithRetry into
selectNode, after the shard leader cache refresh, to ensure proper retry
behavior:
- Remove the premature exclude clearing in ExecuteWithRetry, which ran
before the shard leader cache was updated
- Add exclude clearing in selectNode after the shard leader cache is
refreshed, when all replicas are excluded (see the sketch after this list)
- Ensure multiple retries can properly update the shard leader cache and
clear the exclude list when needed
- Add comprehensive tests for edge cases, including empty shard leaders
and mixed serviceable-node scenarios
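
For illustration, a minimal, self-contained Go sketch of the clearing check
that selectNode now applies after refreshing the shard leader cache. The
nodeInfo and uniqueSet types and the helper name below are simplified
stand-ins for the proxy's internal types (nodeInfo, typeutil.UniqueSet) and
are hypothetical; the authoritative logic is the LBPolicyImpl.selectNode diff
below.

package main

import "fmt"

// Simplified stand-ins for the proxy's internal types; the real definitions
// live in internal/proxy (nodeInfo) and the typeutil package (UniqueSet).
type nodeInfo struct {
    nodeID      int64
    serviceable bool
}

type uniqueSet map[int64]struct{}

func (s uniqueSet) Contain(id int64) bool { _, ok := s[id]; return ok }
func (s uniqueSet) Len() int              { return len(s) }
func (s uniqueSet) Clear() {
    for k := range s {
        delete(s, k)
    }
}

// clearIfAllReplicasExcluded mirrors the check selectNode performs right after
// refreshing the shard leader cache: when every replica of the shard is in
// excludeNodes, the set is cleared so the retry can pick a node again instead
// of failing the request.
func clearIfAllReplicasExcluded(shardLeaders []nodeInfo, excludeNodes uniqueSet) {
    if len(shardLeaders) == 0 || len(shardLeaders) != excludeNodes.Len() {
        return
    }
    for _, node := range shardLeaders {
        if !excludeNodes.Contain(node.nodeID) {
            return // at least one replica is still usable, keep the exclude list
        }
    }
    excludeNodes.Clear()
}

func main() {
    leaders := []nodeInfo{{nodeID: 1, serviceable: true}, {nodeID: 2, serviceable: true}}
    excluded := uniqueSet{1: {}, 2: {}}
    clearIfAllReplicasExcluded(leaders, excluded)
    fmt.Println("excluded after clearing:", excluded.Len()) // prints 0
}

When only some replicas are excluded, the set is left untouched, which is what
the new partial-exclusion tests assert.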

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
wei liu 2025-06-17 08:15:24 +08:00 committed by GitHub
parent 5539636b5b
commit 0b4a17c22b
6 changed files with 605 additions and 206 deletions

@ -41,10 +41,8 @@ type ChannelWorkload struct {
collectionName string
collectionID int64
channel string
shardLeaders []nodeInfo
nq int64
exec executeFunc
retryTimes uint
}
type CollectionWorkLoad struct {
@ -105,51 +103,72 @@ func (lb *LBPolicyImpl) Start(ctx context.Context) {
}
}
// GetShardLeaders should always retry until ctx done, except the collection is not loaded.
func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, collName string, collectionID int64, withCache bool) (map[string][]nodeInfo, error) {
var shardLeaders map[string][]nodeInfo
// use retry to handle query coord service not ready
// GetShard will retry until ctx done, except the collection is not loaded.
// return all replicas of shard from cache if withCache is true, otherwise return shard leaders from coord.
func (lb *LBPolicyImpl) GetShard(ctx context.Context, dbName string, collName string, collectionID int64, channel string, withCache bool) ([]nodeInfo, error) {
var shardLeaders []nodeInfo
err := retry.Handle(ctx, func() (bool, error) {
var err error
shardLeaders, err = globalMetaCache.GetShards(ctx, withCache, dbName, collName, collectionID)
if err != nil {
return !errors.Is(err, merr.ErrCollectionNotLoaded), err
}
return false, nil
shardLeaders, err = globalMetaCache.GetShard(ctx, withCache, dbName, collName, collectionID, channel)
return !errors.Is(err, merr.ErrCollectionNotLoaded), err
})
return shardLeaders, err
}
// GetShardLeaderList will retry until ctx done, except the collection is not loaded.
// return all shard(channel) from cache if withCache is true, otherwise return shard leaders from coord.
func (lb *LBPolicyImpl) GetShardLeaderList(ctx context.Context, dbName string, collName string, collectionID int64, withCache bool) ([]string, error) {
var ret []string
err := retry.Handle(ctx, func() (bool, error) {
var err error
ret, err = globalMetaCache.GetShardLeaderList(ctx, dbName, collName, collectionID, withCache)
return !errors.Is(err, merr.ErrCollectionNotLoaded), err
})
return ret, err
}
// try to select the best node from the available nodes
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload *ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) {
log := log.Ctx(ctx)
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes *typeutil.UniqueSet) (nodeInfo, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
)
// Select node using specified nodes
trySelectNode := func(nodes []nodeInfo) (nodeInfo, error) {
candidateNodes := make(map[int64]nodeInfo)
serviceableNodes := make(map[int64]nodeInfo)
// Filter nodes based on excludeNodes
for _, node := range nodes {
if !excludeNodes.Contain(node.nodeID) {
if node.serviceable {
serviceableNodes[node.nodeID] = node
trySelectNode := func(withCache bool) (nodeInfo, error) {
shardLeaders, err := lb.GetShard(ctx, workload.db, workload.collectionName, workload.collectionID, workload.channel, withCache)
if err != nil {
log.Warn("failed to get shard delegator",
zap.Error(err))
return nodeInfo{}, err
}
// if all available delegators have been excluded even after refreshing the shard leader cache,
// clear excludeNodes and try to select a node again instead of failing the request in selectNode
if len(shardLeaders) > 0 && len(shardLeaders) == excludeNodes.Len() {
allReplicaExcluded := true
for _, node := range shardLeaders {
if !excludeNodes.Contain(node.nodeID) {
allReplicaExcluded = false
break
}
candidateNodes[node.nodeID] = node
}
if allReplicaExcluded {
log.Warn("all replicas are excluded after refresh shard leader cache, clear it and try to select node")
excludeNodes.Clear()
}
}
var err error
candidateNodes := make(map[int64]nodeInfo)
serviceableNodes := make(map[int64]nodeInfo)
defer func() {
if err != nil {
candidatesInStr := lo.Map(nodes, func(node nodeInfo, _ int) string {
candidatesInStr := lo.Map(shardLeaders, func(node nodeInfo, _ int) string {
return node.String()
})
serviceableNodesInStr := lo.Map(lo.Values(serviceableNodes), func(node nodeInfo, _ int) string {
return node.String()
})
log.Warn("failed to select shard",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64s("excluded", excludeNodes.Collect()),
zap.String("candidates", strings.Join(candidatesInStr, ", ")),
zap.String("serviceableNodes", strings.Join(serviceableNodesInStr, ", ")),
@ -157,8 +176,17 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
}
}()
// Filter nodes based on excludeNodes
for _, node := range shardLeaders {
if !excludeNodes.Contain(node.nodeID) {
if node.serviceable {
serviceableNodes[node.nodeID] = node
}
candidateNodes[node.nodeID] = node
}
}
if len(candidateNodes) == 0 {
err = merr.WrapErrChannelNotAvailable(workload.channel)
err = merr.WrapErrChannelNotAvailable(workload.channel, "no available shard leaders")
return nodeInfo{}, err
}
@ -182,23 +210,13 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
return candidateNodes[targetNodeID], nil
}
// First attempt with current shard leaders
targetNode, err := trySelectNode(workload.shardLeaders)
// If failed, refresh cache and retry
// First attempt with current shard leaders cache
withShardLeaderCache := true
targetNode, err := trySelectNode(withShardLeaderCache)
if err != nil {
globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
shardLeaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, false)
if err != nil {
log.Warn("failed to get shard delegator",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Error(err))
return nodeInfo{}, err
}
workload.shardLeaders = shardLeaders[workload.channel]
// Second attempt with fresh shard leaders
targetNode, err = trySelectNode(workload.shardLeaders)
withShardLeaderCache = false
targetNode, err = trySelectNode(withShardLeaderCache)
if err != nil {
return nodeInfo{}, err
}
@ -209,20 +227,17 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
// ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
)
var lastErr error
excludeNodes := typeutil.NewUniqueSet()
tryExecute := func() (bool, error) {
// if keeping retry after all nodes are excluded, try to clean excludeNodes
if excludeNodes.Len() == len(workload.shardLeaders) {
excludeNodes.Clear()
}
balancer := lb.getBalancer()
targetNode, err := lb.selectNode(ctx, balancer, &workload, excludeNodes)
targetNode, err := lb.selectNode(ctx, balancer, workload, &excludeNodes)
if err != nil {
log.Warn("failed to select node for shard",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode.nodeID),
zap.Int64s("excluded", excludeNodes.Collect()),
zap.Error(err),
@ -238,8 +253,6 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
client, err := lb.clientMgr.GetClient(ctx, targetNode)
if err != nil {
log.Warn("search/query channel failed, node not available",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode.nodeID),
zap.Error(err))
excludeNodes.Insert(targetNode.nodeID)
@ -251,8 +264,6 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
err = workload.exec(ctx, targetNode.nodeID, client, workload.channel)
if err != nil {
log.Warn("search/query channel failed",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode.nodeID),
zap.Error(err))
excludeNodes.Insert(targetNode.nodeID)
@ -263,10 +274,15 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
return true, nil
}
// if failed, try to execute with partial result
err := retry.Handle(ctx, tryExecute, retry.Attempts(workload.retryTimes))
shardLeaders, err := lb.GetShard(ctx, workload.db, workload.collectionName, workload.collectionID, workload.channel, true)
if err != nil {
log.Ctx(ctx).Warn("failed to execute with partial result",
log.Warn("failed to get shard leaders", zap.Error(err))
return err
}
retryTimes := max(lb.retryOnReplica, len(shardLeaders))
err = retry.Handle(ctx, tryExecute, retry.Attempts(uint(retryTimes)))
if err != nil {
log.Warn("failed to execute",
zap.String("channel", workload.channel),
zap.Error(err))
}
@ -276,69 +292,54 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
// Execute will execute collection workload in parallel
func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error {
dml2leaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, true)
log := log.Ctx(ctx).With(
zap.Int64("collectionID", workload.collectionID),
)
channelList, err := lb.GetShardLeaderList(ctx, workload.db, workload.collectionName, workload.collectionID, true)
if err != nil {
log.Ctx(ctx).Warn("failed to get shards", zap.Error(err))
log.Warn("failed to get shards", zap.Error(err))
return err
}
totalChannels := len(dml2leaders)
if totalChannels == 0 {
log.Ctx(ctx).Info("no shard leaders found", zap.Int64("collectionID", workload.collectionID))
if len(channelList) == 0 {
log.Info("no shard leaders found", zap.Int64("collectionID", workload.collectionID))
return merr.WrapErrCollectionNotLoaded(workload.collectionID)
}
wg, _ := errgroup.WithContext(ctx)
// Launch a goroutine for each channel
for k, v := range dml2leaders {
channel := k
nodes := v
channelRetryTimes := lb.retryOnReplica
if len(nodes) > 0 {
channelRetryTimes *= len(nodes)
}
for _, channel := range channelList {
wg.Go(func() error {
return lb.ExecuteWithRetry(ctx, ChannelWorkload{
db: workload.db,
collectionName: workload.collectionName,
collectionID: workload.collectionID,
channel: channel,
shardLeaders: nodes,
nq: workload.nq,
exec: workload.exec,
retryTimes: uint(channelRetryTimes),
})
})
}
return wg.Wait()
}
// Execute will execute any one channel in collection workload
func (lb *LBPolicyImpl) ExecuteOneChannel(ctx context.Context, workload CollectionWorkLoad) error {
dml2leaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, true)
channelList, err := lb.GetShardLeaderList(ctx, workload.db, workload.collectionName, workload.collectionID, true)
if err != nil {
log.Ctx(ctx).Warn("failed to get shards", zap.Error(err))
return err
}
// let every request retry at least twice, so it can retry after the shard leader cache is updated
for k, v := range dml2leaders {
channel := k
nodes := v
channelRetryTimes := lb.retryOnReplica
if len(nodes) > 0 {
channelRetryTimes *= len(nodes)
}
for _, channel := range channelList {
return lb.ExecuteWithRetry(ctx, ChannelWorkload{
db: workload.db,
collectionName: workload.collectionName,
collectionID: workload.collectionID,
channel: channel,
shardLeaders: nodes,
nq: workload.nq,
exec: workload.exec,
retryTimes: uint(channelRetryTimes),
})
}
return fmt.Errorf("no acitvate sheard leader exist for collection: %s", workload.collectionName)

@ -178,14 +178,14 @@ func (s *LBPolicySuite) TestSelectNode() {
ctx := context.Background()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
}, typeutil.NewUniqueSet())
// shardLeaders: s.nodes,
nq: 1,
}, &typeutil.UniqueSet{})
s.NoError(err)
s.Equal(int64(5), targetNode.nodeID)
@ -194,14 +194,14 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
}, typeutil.NewUniqueSet())
// shardLeaders: s.nodes,
nq: 1,
}, &typeutil.UniqueSet{})
s.NoError(err)
s.Equal(int64(3), targetNode.nodeID)
@ -209,29 +209,30 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: []nodeInfo{},
nq: 1,
}, typeutil.NewUniqueSet())
// shardLeaders: []nodeInfo{},
nq: 1,
}, &typeutil.UniqueSet{})
s.ErrorIs(err, merr.ErrNodeNotAvailable)
// test all nodes have been excluded, expected failure
// test all nodes have been excluded, expected to clear excludeNodes and try to select node again
excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
}, typeutil.NewUniqueSet(s.nodeIDs...))
s.ErrorIs(err, merr.ErrChannelNotAvailable)
// shardLeaders: s.nodes,
nq: 1,
}, &excludeNodes)
s.ErrorIs(err, merr.ErrNodeNotAvailable)
// test get shard leaders failed, retry to select node failed
s.lbBalancer.ExpectedCalls = nil
@ -240,14 +241,14 @@ func (s *LBPolicySuite) TestSelectNode() {
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return nil, merr.ErrServiceUnavailable
}
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
}, typeutil.NewUniqueSet())
// shardLeaders: s.nodes,
nq: 1,
}, &typeutil.UniqueSet{})
s.ErrorIs(err, merr.ErrServiceUnavailable)
}
@ -265,12 +266,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
retryTimes: 1,
})
s.NoError(err)
@ -283,12 +282,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
retryTimes: 1,
})
s.ErrorIs(err, merr.ErrNodeNotAvailable)
@ -304,12 +301,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
retryTimes: 1,
})
s.Error(err)
@ -327,12 +322,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
retryTimes: 2,
})
s.NoError(err)
@ -351,7 +344,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
counter++
@ -360,7 +352,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
}
return nil
},
retryTimes: 2,
})
s.NoError(err)
@ -376,13 +367,11 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
shardLeaders: s.nodes,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
_, err := qn.Search(ctx, nil)
return err
},
retryTimes: 2,
})
s.True(merr.IsCanceledOrTimeout(err))
}
@ -461,7 +450,7 @@ func (s *LBPolicySuite) TestExecute() {
},
})
s.Error(err)
s.Equal(int64(26), counter.Load())
s.Equal(int64(6), counter.Load())
// test get shard leader failed
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
@ -501,7 +490,7 @@ func (s *LBPolicySuite) TestNewLBPolicy() {
policy.Close()
}
func (s *LBPolicySuite) TestGetShardLeaders() {
func (s *LBPolicySuite) TestGetShard() {
ctx := context.Background()
// ErrCollectionNotFullyLoaded is retriable, expected to retry until ctx done or success
@ -518,7 +507,7 @@ func (s *LBPolicySuite) TestGetShardLeaders() {
return nil, nil
}
_, err := s.lbPolicy.GetShardLeaders(ctx, dbName, s.collectionName, s.collectionID, true)
_, err := s.lbPolicy.GetShard(ctx, dbName, s.collectionName, s.collectionID, s.channels[0], true)
s.NoError(err)
s.Equal(int64(0), counter.Load())
@ -529,12 +518,250 @@ func (s *LBPolicySuite) TestGetShardLeaders() {
counter.Inc()
return nil, merr.ErrCollectionNotLoaded
}
_, err = s.lbPolicy.GetShardLeaders(ctx, dbName, s.collectionName, s.collectionID, true)
_, err = s.lbPolicy.GetShard(ctx, dbName, s.collectionName, s.collectionID, s.channels[0], true)
log.Info("check err", zap.Error(err))
s.Error(err)
s.Equal(int64(1), counter.Load())
}
func (s *LBPolicySuite) TestSelectNodeEdgeCases() {
ctx := context.Background()
// Test case 1: Empty shard leaders after refresh, should fail gracefully
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1)
// Setup mock to return empty shard leaders
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: []int64{}, // Empty node list
NodeAddrs: []string{},
Serviceable: []bool{},
},
},
}, nil
}
excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...)
_, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.Error(err)
log.Info("test case 1")
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
// Test case 2: Single replica scenario - exclude it, refresh shows same single replica, should clear and succeed
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Times(1)
singleNodeList := []int64{1}
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: singleNodeList,
NodeAddrs: []string{"localhost:9000"},
Serviceable: []bool{true},
},
},
}, nil
}
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude the single node
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(1), targetNode.nodeID)
s.Equal(0, excludeNodes.Len()) // Should be cleared
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
mixedNodeIDs := []int64{1, 2, 3}
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil).Times(1)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: mixedNodeIDs,
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, false, true},
},
},
}, nil
}
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude node 1, node 3 should be available
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(3), targetNode.nodeID)
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
}
func (s *LBPolicySuite) TestGetShardLeaderList() {
ctx := context.Background()
// Test normal scenario with cache
channelList, err := s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, true)
s.NoError(err)
s.Equal(len(s.channels), len(channelList))
s.Contains(channelList, s.channels[0])
s.Contains(channelList, s.channels[1])
// Test without cache - should refresh from coordinator
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
channelList, err = s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, false)
s.NoError(err)
s.Equal(len(s.channels), len(channelList))
// Test error case - collection not loaded
counter := atomic.NewInt64(0)
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
counter.Inc()
return nil, merr.ErrCollectionNotLoaded
}
_, err = s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, true)
s.Error(err)
s.ErrorIs(err, merr.ErrCollectionNotLoaded)
s.Equal(int64(1), counter.Load())
}
func (s *LBPolicySuite) TestSelectNodeWithExcludeClearing() {
ctx := context.Background()
// Test exclude nodes clearing when all replicas are excluded after cache refresh
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
// First attempt fails due to no candidates
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1)
// Second attempt succeeds after exclude nodes are cleared
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Times(1)
// Setup mock to return only excluded nodes first, then same nodes for retry
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: []int64{1, 2}, // All these will be excluded
NodeAddrs: []string{"localhost:9000", "localhost:9001"},
Serviceable: []bool{true, true},
},
},
}, nil
}
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
excludeNodes := typeutil.NewUniqueSet(int64(1), int64(2)) // Exclude all available nodes
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(1), targetNode.nodeID)
s.Equal(0, excludeNodes.Len()) // Should be cleared when all replicas were excluded
// Test exclude nodes NOT cleared when only partial replicas are excluded
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(2, nil).Times(1)
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: []int64{1, 2, 3}, // Node 2 and 3 are still available
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil
}
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Only exclude node 1
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(2), targetNode.nodeID)
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
// Test empty shard leaders scenario
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: []int64{}, // Empty shard leaders
NodeAddrs: []string{},
Serviceable: []bool{},
},
},
}, nil
}
excludeNodes = typeutil.NewUniqueSet(int64(1))
_, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.Error(err)
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared for empty shard leaders
}
func TestLBPolicySuite(t *testing.T) {
suite.Run(t, new(LBPolicySuite))
}

@ -70,7 +70,8 @@ type Cache interface {
GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error)
// GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error)
GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error)
GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error)
GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error)
DeprecateShardCache(database, collectionName string)
InvalidateShardLeaderCache(collections []int64)
ListShardLocation() map[int64]nodeInfo
@ -283,6 +284,14 @@ type shardLeaders struct {
shardLeaders map[string][]nodeInfo
}
func (sl *shardLeaders) Get(channel string) []nodeInfo {
return sl.shardLeaders[channel]
}
func (sl *shardLeaders) GetShardLeaderList() []string {
return lo.Keys(sl.shardLeaders)
}
type shardLeadersReader struct {
leaders *shardLeaders
idx int64
@ -944,15 +953,39 @@ func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
m.credMap[username].Sha256Password = credInfo.Sha256Password
}
// GetShards update cache if withCache == false
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) {
method := "GetShards"
log := log.Ctx(ctx).With(
zap.String("db", database),
zap.String("collectionName", collectionName),
zap.Int64("collectionID", collectionID))
func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
method := "GetShard"
// check cache first
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
if cacheShardLeaders == nil || !withCache {
// refresh shard leader cache
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
cacheShardLeaders = newShardLeaders
}
return cacheShardLeaders.Get(channel), nil
}
func (m *MetaCache) GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error) {
method := "GetShardLeaderList"
// check cache first
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
if cacheShardLeaders == nil || !withCache {
// refresh shard leader cache
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
cacheShardLeaders = newShardLeaders
}
return cacheShardLeaders.GetShardLeaderList(), nil
}
func (m *MetaCache) getCachedShardLeaders(database, collectionName, caller string) *shardLeaders {
m.leaderMut.RLock()
var cacheShardLeaders *shardLeaders
db, ok := m.collLeader[database]
@ -962,45 +995,44 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
cacheShardLeaders = db[collectionName]
}
m.leaderMut.RUnlock()
if withCache {
if cacheShardLeaders != nil {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc()
iterator := cacheShardLeaders.GetReader()
return iterator.Shuffle(), nil
}
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
if cacheShardLeaders != nil {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheHitLabel).Inc()
} else {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheMissLabel).Inc()
}
info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
return cacheShardLeaders
}
func (m *MetaCache) updateShardLocationCache(ctx context.Context, database, collectionName string, collectionID int64) (*shardLeaders, error) {
log := log.Ctx(ctx).With(
zap.String("db", database),
zap.String("collectionName", collectionName),
zap.Int64("collectionID", collectionID))
method := "updateShardLocationCache"
tr := timerecord.NewTimeRecorder(method)
defer metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).
Observe(float64(tr.ElapseSpan().Milliseconds()))
req := &querypb.GetShardLeadersRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
CollectionID: info.collID,
CollectionID: collectionID,
WithUnserviceableShards: true,
}
tr := timerecord.NewTimeRecorder("UpdateShardCache")
resp, err := m.mixCoord.GetShardLeaders(ctx, req)
if err != nil {
return nil, err
}
if err = merr.Error(resp.GetStatus()); err != nil {
if err := merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
log.Error("failed to get shard locations",
zap.Int64("collectionID", collectionID),
zap.Error(err))
return nil, err
}
shards := parseShardLeaderList2QueryNode(resp.GetShards())
newShardLeaders := &shardLeaders{
collectionID: info.collID,
shardLeaders: shards,
idx: atomic.NewInt64(0),
}
// convert shards map to string for logging
if log.Logger.Level() == zap.DebugLevel {
@ -1015,23 +1047,20 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
log.Debug("update shard leader cache", zap.String("newShardLeaders", strings.Join(shardStr, ", ")))
}
newShardLeaders := &shardLeaders{
collectionID: collectionID,
shardLeaders: shards,
idx: atomic.NewInt64(0),
}
m.leaderMut.Lock()
if _, ok := m.collLeader[database]; !ok {
m.collLeader[database] = make(map[string]*shardLeaders)
}
m.collLeader[database][collectionName] = newShardLeaders
iterator := newShardLeaders.GetReader()
ret := iterator.Shuffle()
m.leaderMut.Unlock()
nodeInfos := make([]string, 0)
for ch, shardLeader := range newShardLeaders.shardLeaders {
for _, nodeInfo := range shardLeader {
nodeInfos = append(nodeInfos, fmt.Sprintf("channel %s, nodeID: %d, nodeAddr: %s", ch, nodeInfo.nodeID, nodeInfo.address))
}
}
log.Debug("fill new collection shard leader", zap.Strings("nodeInfos", nodeInfos))
metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
return ret, nil
return newShardLeaders, nil
}
func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]nodeInfo {

@ -1362,7 +1362,7 @@ func TestMetaCache_GetPartitionError(t *testing.T) {
assert.Equal(t, id, typeutil.UniqueID(0))
}
func TestMetaCache_GetShards(t *testing.T) {
func TestMetaCache_GetShard(t *testing.T) {
var (
ctx = context.Background()
collectionName = "collection1"
@ -1375,13 +1375,13 @@ func TestMetaCache_GetShards(t *testing.T) {
require.Nil(t, err)
t.Run("No collection in meta cache", func(t *testing.T) {
shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists", 0)
shards, err := globalMetaCache.GetShard(ctx, true, dbName, "non-exists", 0, "channel-1")
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) {
shards, err := globalMetaCache.GetShards(ctx, false, dbName, collectionName, collectionID)
shards, err := globalMetaCache.GetShard(ctx, false, dbName, collectionName, collectionID, "channel-1")
assert.Error(t, err)
assert.Empty(t, shards)
})
@ -1401,14 +1401,12 @@ func TestMetaCache_GetShards(t *testing.T) {
}, nil
}
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
shards, err := globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards["channel-1"]))
assert.Equal(t, 3, len(shards))
// get from cache
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
@ -1418,12 +1416,11 @@ func TestMetaCache_GetShards(t *testing.T) {
}, nil
}
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
shards, err = globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards["channel-1"]))
assert.Equal(t, 3, len(shards))
})
}
@ -1462,11 +1459,10 @@ func TestMetaCache_ClearShards(t *testing.T) {
}, nil
}
shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
shards, err := globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
require.NoError(t, err)
require.NotEmpty(t, shards)
require.Equal(t, 1, len(shards))
require.Equal(t, 3, len(shards["channel-1"]))
require.Equal(t, 3, len(shards))
globalMetaCache.DeprecateShardCache(dbName, collectionName)
@ -1479,7 +1475,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
}, nil
}
shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID)
shards, err = globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
assert.Error(t, err)
assert.Empty(t, shards)
})
@ -1843,18 +1839,18 @@ func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) {
},
}, nil
}
nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
nodeInfos, err := globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1")
assert.NoError(t, err)
assert.Len(t, nodeInfos["channel-1"], 3)
assert.Len(t, nodeInfos, 3)
assert.Equal(t, called.Load(), int32(1))
globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1")
assert.Equal(t, called.Load(), int32(1))
globalMetaCache.InvalidateShardLeaderCache([]int64{1})
nodeInfos, err = globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1)
nodeInfos, err = globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1")
assert.NoError(t, err)
assert.Len(t, nodeInfos["channel-1"], 3)
assert.Len(t, nodeInfos, 3)
assert.Equal(t, called.Load(), int32(2))
}
@ -2155,3 +2151,86 @@ func TestMetaCache_Parallel(t *testing.T) {
_, ok = cache.collInfo[dbName]["collection1"]
assert.True(t, ok)
}
func TestMetaCache_GetShardLeaderList(t *testing.T) {
var (
ctx = context.Background()
collectionName = "collection1"
collectionID = int64(1)
)
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
require.Nil(t, err)
t.Run("No collection in meta cache", func(t *testing.T) {
channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, "non-exists", 0, true)
assert.Error(t, err)
assert.Empty(t, channels)
})
t.Run("Get channel list without cache", func(t *testing.T) {
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: merr.Success(),
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
{
ChannelName: "channel-2",
NodeIds: []int64{4, 5, 6},
NodeAddrs: []string{"localhost:9003", "localhost:9004", "localhost:9005"},
Serviceable: []bool{true, true, true},
},
},
}, nil
}
channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, false)
assert.NoError(t, err)
assert.Equal(t, 2, len(channels))
assert.Contains(t, channels, "channel-1")
assert.Contains(t, channels, "channel-2")
})
t.Run("Get channel list with cache", func(t *testing.T) {
// First call should populate cache
channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true)
assert.NoError(t, err)
assert.Equal(t, 2, len(channels))
// Mock should return error but cache should be used
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "should not be called when using cache",
},
}, nil
}
channels, err = globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true)
assert.NoError(t, err)
assert.Equal(t, 2, len(channels))
assert.Contains(t, channels, "channel-1")
assert.Contains(t, channels, "channel-2")
})
t.Run("Error from coordinator", func(t *testing.T) {
// Deprecate cache first
globalMetaCache.DeprecateShardCache(dbName, collectionName)
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return nil, errors.New("coordinator error")
}
channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true)
assert.Error(t, err)
assert.Empty(t, channels)
})
}

@ -757,29 +757,29 @@ func (_c *MockCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context
return _c
}
// GetShards provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID
func (_m *MockCache) GetShards(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64) (map[string][]nodeInfo, error) {
ret := _m.Called(ctx, withCache, database, collectionName, collectionID)
// GetShard provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID, channel
func (_m *MockCache) GetShard(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
ret := _m.Called(ctx, withCache, database, collectionName, collectionID, channel)
if len(ret) == 0 {
panic("no return value specified for GetShards")
panic("no return value specified for GetShard")
}
var r0 map[string][]nodeInfo
var r0 []nodeInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64) (map[string][]nodeInfo, error)); ok {
return rf(ctx, withCache, database, collectionName, collectionID)
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) ([]nodeInfo, error)); ok {
return rf(ctx, withCache, database, collectionName, collectionID, channel)
}
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64) map[string][]nodeInfo); ok {
r0 = rf(ctx, withCache, database, collectionName, collectionID)
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) []nodeInfo); ok {
r0 = rf(ctx, withCache, database, collectionName, collectionID, channel)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string][]nodeInfo)
r0 = ret.Get(0).([]nodeInfo)
}
}
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64) error); ok {
r1 = rf(ctx, withCache, database, collectionName, collectionID)
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64, string) error); ok {
r1 = rf(ctx, withCache, database, collectionName, collectionID, channel)
} else {
r1 = ret.Error(1)
}
@ -787,34 +787,97 @@ func (_m *MockCache) GetShards(ctx context.Context, withCache bool, database str
return r0, r1
}
// MockCache_GetShards_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShards'
type MockCache_GetShards_Call struct {
// MockCache_GetShard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShard'
type MockCache_GetShard_Call struct {
*mock.Call
}
// GetShards is a helper method to define mock.On call
// GetShard is a helper method to define mock.On call
// - ctx context.Context
// - withCache bool
// - database string
// - collectionName string
// - collectionID int64
func (_e *MockCache_Expecter) GetShards(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}) *MockCache_GetShards_Call {
return &MockCache_GetShards_Call{Call: _e.mock.On("GetShards", ctx, withCache, database, collectionName, collectionID)}
// - channel string
func (_e *MockCache_Expecter) GetShard(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}, channel interface{}) *MockCache_GetShard_Call {
return &MockCache_GetShard_Call{Call: _e.mock.On("GetShard", ctx, withCache, database, collectionName, collectionID, channel)}
}
func (_c *MockCache_GetShards_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64)) *MockCache_GetShards_Call {
func (_c *MockCache_GetShard_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string)) *MockCache_GetShard_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64))
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64), args[5].(string))
})
return _c
}
func (_c *MockCache_GetShards_Call) Return(_a0 map[string][]nodeInfo, _a1 error) *MockCache_GetShards_Call {
func (_c *MockCache_GetShard_Call) Return(_a0 []nodeInfo, _a1 error) *MockCache_GetShard_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCache_GetShards_Call) RunAndReturn(run func(context.Context, bool, string, string, int64) (map[string][]nodeInfo, error)) *MockCache_GetShards_Call {
func (_c *MockCache_GetShard_Call) RunAndReturn(run func(context.Context, bool, string, string, int64, string) ([]nodeInfo, error)) *MockCache_GetShard_Call {
_c.Call.Return(run)
return _c
}
// GetShardLeaderList provides a mock function with given fields: ctx, database, collectionName, collectionID, withCache
func (_m *MockCache) GetShardLeaderList(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool) ([]string, error) {
ret := _m.Called(ctx, database, collectionName, collectionID, withCache)
if len(ret) == 0 {
panic("no return value specified for GetShardLeaderList")
}
var r0 []string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) ([]string, error)); ok {
return rf(ctx, database, collectionName, collectionID, withCache)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) []string); ok {
r0 = rf(ctx, database, collectionName, collectionID, withCache)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, bool) error); ok {
r1 = rf(ctx, database, collectionName, collectionID, withCache)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetShardLeaderList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShardLeaderList'
type MockCache_GetShardLeaderList_Call struct {
*mock.Call
}
// GetShardLeaderList is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
// - collectionID int64
// - withCache bool
func (_e *MockCache_Expecter) GetShardLeaderList(ctx interface{}, database interface{}, collectionName interface{}, collectionID interface{}, withCache interface{}) *MockCache_GetShardLeaderList_Call {
return &MockCache_GetShardLeaderList_Call{Call: _e.mock.On("GetShardLeaderList", ctx, database, collectionName, collectionID, withCache)}
}
func (_c *MockCache_GetShardLeaderList_Call) Run(run func(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool)) *MockCache_GetShardLeaderList_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(int64), args[4].(bool))
})
return _c
}
func (_c *MockCache_GetShardLeaderList_Call) Return(_a0 []string, _a1 error) *MockCache_GetShardLeaderList_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCache_GetShardLeaderList_Call) RunAndReturn(run func(context.Context, string, string, int64, bool) ([]string, error)) *MockCache_GetShardLeaderList_Call {
_c.Call.Return(run)
return _c
}

@ -1152,7 +1152,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(info, nil).Maybe()
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe()
cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe()
cache.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]nodeInfo{}, nil).Maybe()
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
globalMetaCache = cache
@ -3652,7 +3652,7 @@ func TestSearchTask_Requery(t *testing.T) {
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(schema, nil).Maybe()
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe()
cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe()
cache.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]nodeInfo{}, nil).Maybe()
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
globalMetaCache = cache