From 0b4a17c22b60330d16e0c2fdf4e870ff0c745254 Mon Sep 17 00:00:00 2001 From: wei liu Date: Tue, 17 Jun 2025 08:15:24 +0800 Subject: [PATCH] fix: Fix exclude nodes clearing logic position in load balancer retry (#42577) issue: #42561 Move the exclude nodes clearing logic from ExecuteWithRetry to selectNode after shard leader cache refresh to ensure proper retry behavior: - Remove premature exclude clearing in ExecuteWithRetry that happened before shard leader cache update - Add exclude clearing logic in selectNode after refreshing shard leader cache when all replicas are excluded - Ensure multiple retries can properly update shard leader cache and clear exclude list when needed - Add comprehensive tests for edge cases including empty shard leaders and mixed serviceable node scenarios --------- Signed-off-by: Wei Liu --- internal/proxy/lb_policy.go | 167 ++++++++-------- internal/proxy/lb_policy_test.go | 303 +++++++++++++++++++++++++---- internal/proxy/meta_cache.go | 115 +++++++---- internal/proxy/meta_cache_test.go | 117 +++++++++-- internal/proxy/mock_cache.go | 105 ++++++++-- internal/proxy/task_search_test.go | 4 +- 6 files changed, 605 insertions(+), 206 deletions(-) diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index 9a85add535..3fa20e2c26 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -41,10 +41,8 @@ type ChannelWorkload struct { collectionName string collectionID int64 channel string - shardLeaders []nodeInfo nq int64 exec executeFunc - retryTimes uint } type CollectionWorkLoad struct { @@ -105,51 +103,72 @@ func (lb *LBPolicyImpl) Start(ctx context.Context) { } } -// GetShardLeaders should always retry until ctx done, except the collection is not loaded. -func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, collName string, collectionID int64, withCache bool) (map[string][]nodeInfo, error) { - var shardLeaders map[string][]nodeInfo - // use retry to handle query coord service not ready +// GetShard will retry until ctx done, except the collection is not loaded. +// return all replicas of shard from cache if withCache is true, otherwise return shard leaders from coord. +func (lb *LBPolicyImpl) GetShard(ctx context.Context, dbName string, collName string, collectionID int64, channel string, withCache bool) ([]nodeInfo, error) { + var shardLeaders []nodeInfo err := retry.Handle(ctx, func() (bool, error) { var err error - shardLeaders, err = globalMetaCache.GetShards(ctx, withCache, dbName, collName, collectionID) - if err != nil { - return !errors.Is(err, merr.ErrCollectionNotLoaded), err - } - return false, nil + shardLeaders, err = globalMetaCache.GetShard(ctx, withCache, dbName, collName, collectionID, channel) + return !errors.Is(err, merr.ErrCollectionNotLoaded), err }) - return shardLeaders, err } +// GetShardLeaderList will retry until ctx done, except the collection is not loaded. +// return all shard(channel) from cache if withCache is true, otherwise return shard leaders from coord. 
+func (lb *LBPolicyImpl) GetShardLeaderList(ctx context.Context, dbName string, collName string, collectionID int64, withCache bool) ([]string, error) { + var ret []string + err := retry.Handle(ctx, func() (bool, error) { + var err error + ret, err = globalMetaCache.GetShardLeaderList(ctx, dbName, collName, collectionID, withCache) + return !errors.Is(err, merr.ErrCollectionNotLoaded), err + }) + return ret, err +} + // try to select the best node from the available nodes -func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload *ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) { - log := log.Ctx(ctx) +func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes *typeutil.UniqueSet) (nodeInfo, error) { + log := log.Ctx(ctx).With( + zap.Int64("collectionID", workload.collectionID), + zap.String("channelName", workload.channel), + ) // Select node using specified nodes - trySelectNode := func(nodes []nodeInfo) (nodeInfo, error) { - candidateNodes := make(map[int64]nodeInfo) - serviceableNodes := make(map[int64]nodeInfo) - // Filter nodes based on excludeNodes - for _, node := range nodes { - if !excludeNodes.Contain(node.nodeID) { - if node.serviceable { - serviceableNodes[node.nodeID] = node + trySelectNode := func(withCache bool) (nodeInfo, error) { + shardLeaders, err := lb.GetShard(ctx, workload.db, workload.collectionName, workload.collectionID, workload.channel, withCache) + if err != nil { + log.Warn("failed to get shard delegator", + zap.Error(err)) + return nodeInfo{}, err + } + + // if all available delegator has been excluded even after refresh shard leader cache + // we should clear excludeNodes and try to select node again instead of failing the request at selectNode + if len(shardLeaders) > 0 && len(shardLeaders) == excludeNodes.Len() { + allReplicaExcluded := true + for _, node := range shardLeaders { + if !excludeNodes.Contain(node.nodeID) { + allReplicaExcluded = false + break } - candidateNodes[node.nodeID] = node + } + if allReplicaExcluded { + log.Warn("all replicas are excluded after refresh shard leader cache, clear it and try to select node") + excludeNodes.Clear() } } - var err error + candidateNodes := make(map[int64]nodeInfo) + serviceableNodes := make(map[int64]nodeInfo) defer func() { if err != nil { - candidatesInStr := lo.Map(nodes, func(node nodeInfo, _ int) string { + candidatesInStr := lo.Map(shardLeaders, func(node nodeInfo, _ int) string { return node.String() }) serviceableNodesInStr := lo.Map(lo.Values(serviceableNodes), func(node nodeInfo, _ int) string { return node.String() }) log.Warn("failed to select shard", - zap.Int64("collectionID", workload.collectionID), - zap.String("channelName", workload.channel), zap.Int64s("excluded", excludeNodes.Collect()), zap.String("candidates", strings.Join(candidatesInStr, ", ")), zap.String("serviceableNodes", strings.Join(serviceableNodesInStr, ", ")), @@ -157,8 +176,17 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor } }() + // Filter nodes based on excludeNodes + for _, node := range shardLeaders { + if !excludeNodes.Contain(node.nodeID) { + if node.serviceable { + serviceableNodes[node.nodeID] = node + } + candidateNodes[node.nodeID] = node + } + } if len(candidateNodes) == 0 { - err = merr.WrapErrChannelNotAvailable(workload.channel) + err = merr.WrapErrChannelNotAvailable(workload.channel, "no available shard leaders") return nodeInfo{}, err } @@ -182,23 +210,13 @@ func (lb 
*LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor return candidateNodes[targetNodeID], nil } - // First attempt with current shard leaders - targetNode, err := trySelectNode(workload.shardLeaders) - // If failed, refresh cache and retry + // First attempt with current shard leaders cache + withShardLeaderCache := true + targetNode, err := trySelectNode(withShardLeaderCache) if err != nil { - globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName) - shardLeaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, false) - if err != nil { - log.Warn("failed to get shard delegator", - zap.Int64("collectionID", workload.collectionID), - zap.String("channelName", workload.channel), - zap.Error(err)) - return nodeInfo{}, err - } - - workload.shardLeaders = shardLeaders[workload.channel] // Second attempt with fresh shard leaders - targetNode, err = trySelectNode(workload.shardLeaders) + withShardLeaderCache = false + targetNode, err = trySelectNode(withShardLeaderCache) if err != nil { return nodeInfo{}, err } @@ -209,20 +227,17 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor // ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes. func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error { + log := log.Ctx(ctx).With( + zap.Int64("collectionID", workload.collectionID), + zap.String("channelName", workload.channel), + ) var lastErr error excludeNodes := typeutil.NewUniqueSet() tryExecute := func() (bool, error) { - // if keeping retry after all nodes are excluded, try to clean excludeNodes - if excludeNodes.Len() == len(workload.shardLeaders) { - excludeNodes.Clear() - } - balancer := lb.getBalancer() - targetNode, err := lb.selectNode(ctx, balancer, &workload, excludeNodes) + targetNode, err := lb.selectNode(ctx, balancer, workload, &excludeNodes) if err != nil { log.Warn("failed to select node for shard", - zap.Int64("collectionID", workload.collectionID), - zap.String("channelName", workload.channel), zap.Int64("nodeID", targetNode.nodeID), zap.Int64s("excluded", excludeNodes.Collect()), zap.Error(err), @@ -238,8 +253,6 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo client, err := lb.clientMgr.GetClient(ctx, targetNode) if err != nil { log.Warn("search/query channel failed, node not available", - zap.Int64("collectionID", workload.collectionID), - zap.String("channelName", workload.channel), zap.Int64("nodeID", targetNode.nodeID), zap.Error(err)) excludeNodes.Insert(targetNode.nodeID) @@ -251,8 +264,6 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo err = workload.exec(ctx, targetNode.nodeID, client, workload.channel) if err != nil { log.Warn("search/query channel failed", - zap.Int64("collectionID", workload.collectionID), - zap.String("channelName", workload.channel), zap.Int64("nodeID", targetNode.nodeID), zap.Error(err)) excludeNodes.Insert(targetNode.nodeID) @@ -263,10 +274,15 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo return true, nil } - // if failed, try to execute with partial result - err := retry.Handle(ctx, tryExecute, retry.Attempts(workload.retryTimes)) + shardLeaders, err := lb.GetShard(ctx, workload.db, workload.collectionName, workload.collectionID, workload.channel, true) if err != nil { - log.Ctx(ctx).Warn("failed to execute with partial result", + 
log.Warn("failed to get shard leaders", zap.Error(err)) + return err + } + retryTimes := max(lb.retryOnReplica, len(shardLeaders)) + err = retry.Handle(ctx, tryExecute, retry.Attempts(uint(retryTimes))) + if err != nil { + log.Warn("failed to execute", zap.String("channel", workload.channel), zap.Error(err)) } @@ -276,69 +292,54 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo // Execute will execute collection workload in parallel func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error { - dml2leaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, true) + log := log.Ctx(ctx).With( + zap.Int64("collectionID", workload.collectionID), + ) + channelList, err := lb.GetShardLeaderList(ctx, workload.db, workload.collectionName, workload.collectionID, true) if err != nil { - log.Ctx(ctx).Warn("failed to get shards", zap.Error(err)) + log.Warn("failed to get shards", zap.Error(err)) return err } - totalChannels := len(dml2leaders) - if totalChannels == 0 { - log.Ctx(ctx).Info("no shard leaders found", zap.Int64("collectionID", workload.collectionID)) + if len(channelList) == 0 { + log.Info("no shard leaders found", zap.Int64("collectionID", workload.collectionID)) return merr.WrapErrCollectionNotLoaded(workload.collectionID) } wg, _ := errgroup.WithContext(ctx) // Launch a goroutine for each channel - for k, v := range dml2leaders { - channel := k - nodes := v - channelRetryTimes := lb.retryOnReplica - if len(nodes) > 0 { - channelRetryTimes *= len(nodes) - } + for _, channel := range channelList { wg.Go(func() error { return lb.ExecuteWithRetry(ctx, ChannelWorkload{ db: workload.db, collectionName: workload.collectionName, collectionID: workload.collectionID, channel: channel, - shardLeaders: nodes, nq: workload.nq, exec: workload.exec, - retryTimes: uint(channelRetryTimes), }) }) } - return wg.Wait() } // Execute will execute any one channel in collection workload func (lb *LBPolicyImpl) ExecuteOneChannel(ctx context.Context, workload CollectionWorkLoad) error { - dml2leaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, true) + channelList, err := lb.GetShardLeaderList(ctx, workload.db, workload.collectionName, workload.collectionID, true) if err != nil { log.Ctx(ctx).Warn("failed to get shards", zap.Error(err)) return err } // let every request could retry at least twice, which could retry after update shard leader cache - for k, v := range dml2leaders { - channel := k - nodes := v - channelRetryTimes := lb.retryOnReplica - if len(nodes) > 0 { - channelRetryTimes *= len(nodes) - } + for _, channel := range channelList { return lb.ExecuteWithRetry(ctx, ChannelWorkload{ db: workload.db, collectionName: workload.collectionName, collectionID: workload.collectionID, channel: channel, - shardLeaders: nodes, nq: workload.nq, exec: workload.exec, - retryTimes: uint(channelRetryTimes), }) } return fmt.Errorf("no acitvate sheard leader exist for collection: %s", workload.collectionName) diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index c0ce2a598b..20b59a74fc 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -178,14 +178,14 @@ func (s *LBPolicySuite) TestSelectNode() { ctx := context.Background() s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil) - targetNode, err := 
s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{ + targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, - nq: 1, - }, typeutil.NewUniqueSet()) + // shardLeaders: s.nodes, + nq: 1, + }, &typeutil.UniqueSet{}) s.NoError(err) s.Equal(int64(5), targetNode.nodeID) @@ -194,14 +194,14 @@ func (s *LBPolicySuite) TestSelectNode() { s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil) - targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{ + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, - nq: 1, - }, typeutil.NewUniqueSet()) + // shardLeaders: s.nodes, + nq: 1, + }, &typeutil.UniqueSet{}) s.NoError(err) s.Equal(int64(3), targetNode.nodeID) @@ -209,29 +209,30 @@ func (s *LBPolicySuite) TestSelectNode() { s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) - targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{ + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: []nodeInfo{}, - nq: 1, - }, typeutil.NewUniqueSet()) + // shardLeaders: []nodeInfo{}, + nq: 1, + }, &typeutil.UniqueSet{}) s.ErrorIs(err, merr.ErrNodeNotAvailable) - // test all nodes has been excluded, expected failure + // test all nodes has been excluded, expected clear excludeNodes and try to select node again + excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...) 
s.lbBalancer.ExpectedCalls = nil s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable) - targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{ + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, - nq: 1, - }, typeutil.NewUniqueSet(s.nodeIDs...)) - s.ErrorIs(err, merr.ErrChannelNotAvailable) + // shardLeaders: s.nodes, + nq: 1, + }, &excludeNodes) + s.ErrorIs(err, merr.ErrNodeNotAvailable) // test get shard leaders failed, retry to select node failed s.lbBalancer.ExpectedCalls = nil @@ -240,14 +241,14 @@ func (s *LBPolicySuite) TestSelectNode() { s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { return nil, merr.ErrServiceUnavailable } - targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{ + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ db: dbName, collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, - nq: 1, - }, typeutil.NewUniqueSet()) + // shardLeaders: s.nodes, + nq: 1, + }, &typeutil.UniqueSet{}) s.ErrorIs(err, merr.ErrServiceUnavailable) } @@ -265,12 +266,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, nq: 1, exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, - retryTimes: 1, }) s.NoError(err) @@ -283,12 +282,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, nq: 1, exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, - retryTimes: 1, }) s.ErrorIs(err, merr.ErrNodeNotAvailable) @@ -304,12 +301,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, nq: 1, exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, - retryTimes: 1, }) s.Error(err) @@ -327,12 +322,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, nq: 1, exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { return nil }, - retryTimes: 2, }) s.NoError(err) @@ -351,7 +344,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, nq: 1, exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { counter++ @@ -360,7 +352,6 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { } return nil }, - retryTimes: 2, }) s.NoError(err) @@ -376,13 +367,11 @@ func (s *LBPolicySuite) TestExecuteWithRetry() { collectionName: s.collectionName, collectionID: s.collectionID, channel: s.channels[0], - shardLeaders: s.nodes, nq: 1, exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error { _, 
err := qn.Search(ctx, nil) return err }, - retryTimes: 2, }) s.True(merr.IsCanceledOrTimeout(err)) } @@ -461,7 +450,7 @@ func (s *LBPolicySuite) TestExecute() { }, }) s.Error(err) - s.Equal(int64(26), counter.Load()) + s.Equal(int64(6), counter.Load()) // test get shard leader failed globalMetaCache.DeprecateShardCache(dbName, s.collectionName) @@ -501,7 +490,7 @@ func (s *LBPolicySuite) TestNewLBPolicy() { policy.Close() } -func (s *LBPolicySuite) TestGetShardLeaders() { +func (s *LBPolicySuite) TestGetShard() { ctx := context.Background() // ErrCollectionNotFullyLoaded is retriable, expected to retry until ctx done or success @@ -518,7 +507,7 @@ func (s *LBPolicySuite) TestGetShardLeaders() { return nil, nil } - _, err := s.lbPolicy.GetShardLeaders(ctx, dbName, s.collectionName, s.collectionID, true) + _, err := s.lbPolicy.GetShard(ctx, dbName, s.collectionName, s.collectionID, s.channels[0], true) s.NoError(err) s.Equal(int64(0), counter.Load()) @@ -529,12 +518,250 @@ func (s *LBPolicySuite) TestGetShardLeaders() { counter.Inc() return nil, merr.ErrCollectionNotLoaded } - _, err = s.lbPolicy.GetShardLeaders(ctx, dbName, s.collectionName, s.collectionID, true) + _, err = s.lbPolicy.GetShard(ctx, dbName, s.collectionName, s.collectionID, s.channels[0], true) log.Info("check err", zap.Error(err)) s.Error(err) s.Equal(int64(1), counter.Load()) } +func (s *LBPolicySuite) TestSelectNodeEdgeCases() { + ctx := context.Background() + + // Test case 1: Empty shard leaders after refresh, should fail gracefully + s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) + s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1) + + // Setup mock to return empty shard leaders + successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} + s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + return &querypb.GetShardLeadersResponse{ + Status: &successStatus, + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: s.channels[0], + NodeIds: []int64{}, // Empty node list + NodeAddrs: []string{}, + Serviceable: []bool{}, + }, + }, + }, nil + } + + excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...) 
+ _, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ + db: dbName, + collectionName: s.collectionName, + collectionID: s.collectionID, + channel: s.channels[0], + nq: 1, + }, &excludeNodes) + s.Error(err) + + log.Info("test case 1") + globalMetaCache.DeprecateShardCache(dbName, s.collectionName) + // Test case 2: Single replica scenario - exclude it, refresh shows same single replica, should clear and succeed + s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) + s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Times(1) + + singleNodeList := []int64{1} + s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + return &querypb.GetShardLeadersResponse{ + Status: &successStatus, + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: s.channels[0], + NodeIds: singleNodeList, + NodeAddrs: []string{"localhost:9000"}, + Serviceable: []bool{true}, + }, + }, + }, nil + } + + excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude the single node + targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ + db: dbName, + collectionName: s.collectionName, + collectionID: s.collectionID, + channel: s.channels[0], + nq: 1, + }, &excludeNodes) + s.NoError(err) + s.Equal(int64(1), targetNode.nodeID) + s.Equal(0, excludeNodes.Len()) // Should be cleared + + globalMetaCache.DeprecateShardCache(dbName, s.collectionName) + mixedNodeIDs := []int64{1, 2, 3} + + s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) + s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1) + s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil).Times(1) + + s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + return &querypb.GetShardLeadersResponse{ + Status: &successStatus, + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: s.channels[0], + NodeIds: mixedNodeIDs, + NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"}, + Serviceable: []bool{true, false, true}, + }, + }, + }, nil + } + + excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude node 1, node 3 should be available + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ + db: dbName, + collectionName: s.collectionName, + collectionID: s.collectionID, + channel: s.channels[0], + nq: 1, + }, &excludeNodes) + s.NoError(err) + s.Equal(int64(3), targetNode.nodeID) + s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded +} + +func (s *LBPolicySuite) TestGetShardLeaderList() { + ctx := context.Background() + + // Test normal scenario with cache + channelList, err := s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, true) + s.NoError(err) + s.Equal(len(s.channels), len(channelList)) + s.Contains(channelList, s.channels[0]) + s.Contains(channelList, s.channels[1]) + + // Test without cache - should refresh from coordinator + globalMetaCache.DeprecateShardCache(dbName, s.collectionName) + channelList, err = s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, false) + s.NoError(err) + s.Equal(len(s.channels), len(channelList)) + + // Test 
error case - collection not loaded + counter := atomic.NewInt64(0) + globalMetaCache.DeprecateShardCache(dbName, s.collectionName) + s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + counter.Inc() + return nil, merr.ErrCollectionNotLoaded + } + _, err = s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, true) + s.Error(err) + s.ErrorIs(err, merr.ErrCollectionNotLoaded) + s.Equal(int64(1), counter.Load()) +} + +func (s *LBPolicySuite) TestSelectNodeWithExcludeClearing() { + ctx := context.Background() + + // Test exclude nodes clearing when all replicas are excluded after cache refresh + s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) + // First attempt fails due to no candidates + s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1) + // Second attempt succeeds after exclude nodes are cleared + s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Times(1) + + // Setup mock to return only excluded nodes first, then same nodes for retry + successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} + s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + return &querypb.GetShardLeadersResponse{ + Status: &successStatus, + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: s.channels[0], + NodeIds: []int64{1, 2}, // All these will be excluded + NodeAddrs: []string{"localhost:9000", "localhost:9001"}, + Serviceable: []bool{true, true}, + }, + }, + }, nil + } + + globalMetaCache.DeprecateShardCache(dbName, s.collectionName) + excludeNodes := typeutil.NewUniqueSet(int64(1), int64(2)) // Exclude all available nodes + targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ + db: dbName, + collectionName: s.collectionName, + collectionID: s.collectionID, + channel: s.channels[0], + nq: 1, + }, &excludeNodes) + + s.NoError(err) + s.Equal(int64(1), targetNode.nodeID) + s.Equal(0, excludeNodes.Len()) // Should be cleared when all replicas were excluded + + // Test exclude nodes NOT cleared when only partial replicas are excluded + s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) + s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(2, nil).Times(1) + + globalMetaCache.DeprecateShardCache(dbName, s.collectionName) + s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + return &querypb.GetShardLeadersResponse{ + Status: &successStatus, + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: s.channels[0], + NodeIds: []int64{1, 2, 3}, // Node 2 and 3 are still available + NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"}, + Serviceable: []bool{true, true, true}, + }, + }, + }, nil + } + + excludeNodes = typeutil.NewUniqueSet(int64(1)) // Only exclude node 1 + targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ + db: dbName, + collectionName: s.collectionName, + collectionID: s.collectionID, + channel: s.channels[0], + nq: 1, + }, &excludeNodes) + + s.NoError(err) + s.Equal(int64(2), targetNode.nodeID) + 
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded + + // Test empty shard leaders scenario + s.lbBalancer.ExpectedCalls = nil + s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything) + + globalMetaCache.DeprecateShardCache(dbName, s.collectionName) + s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + return &querypb.GetShardLeadersResponse{ + Status: &successStatus, + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: s.channels[0], + NodeIds: []int64{}, // Empty shard leaders + NodeAddrs: []string{}, + Serviceable: []bool{}, + }, + }, + }, nil + } + + excludeNodes = typeutil.NewUniqueSet(int64(1)) + _, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{ + db: dbName, + collectionName: s.collectionName, + collectionID: s.collectionID, + channel: s.channels[0], + nq: 1, + }, &excludeNodes) + + s.Error(err) + s.Equal(1, excludeNodes.Len()) // Should NOT be cleared for empty shard leaders +} + func TestLBPolicySuite(t *testing.T) { suite.Run(t, new(LBPolicySuite)) } diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index bb5f72ec95..80e0afc692 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -70,7 +70,8 @@ type Cache interface { GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error) // GetCollectionSchema get collection's schema. GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error) - GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) + GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) + GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error) DeprecateShardCache(database, collectionName string) InvalidateShardLeaderCache(collections []int64) ListShardLocation() map[int64]nodeInfo @@ -283,6 +284,14 @@ type shardLeaders struct { shardLeaders map[string][]nodeInfo } +func (sl *shardLeaders) Get(channel string) []nodeInfo { + return sl.shardLeaders[channel] +} + +func (sl *shardLeaders) GetShardLeaderList() []string { + return lo.Keys(sl.shardLeaders) +} + type shardLeadersReader struct { leaders *shardLeaders idx int64 @@ -944,15 +953,39 @@ func (m *MetaCache) UpdateCredential(credInfo *internalpb.CredentialInfo) { m.credMap[username].Sha256Password = credInfo.Sha256Password } -// GetShards update cache if withCache == false -func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, collectionName string, collectionID int64) (map[string][]nodeInfo, error) { - method := "GetShards" - log := log.Ctx(ctx).With( - zap.String("db", database), - zap.String("collectionName", collectionName), - zap.Int64("collectionID", collectionID)) - +func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) { + method := "GetShard" // check cache first + cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method) + if cacheShardLeaders == nil || !withCache { + // refresh shard leader cache + newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID) + if err != nil { + return nil, err + } + cacheShardLeaders 
= newShardLeaders + } + + return cacheShardLeaders.Get(channel), nil +} + +func (m *MetaCache) GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error) { + method := "GetShardLeaderList" + // check cache first + cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method) + if cacheShardLeaders == nil || !withCache { + // refresh shard leader cache + newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID) + if err != nil { + return nil, err + } + cacheShardLeaders = newShardLeaders + } + + return cacheShardLeaders.GetShardLeaderList(), nil +} + +func (m *MetaCache) getCachedShardLeaders(database, collectionName, caller string) *shardLeaders { m.leaderMut.RLock() var cacheShardLeaders *shardLeaders db, ok := m.collLeader[database] @@ -962,45 +995,44 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col cacheShardLeaders = db[collectionName] } m.leaderMut.RUnlock() - if withCache { - if cacheShardLeaders != nil { - metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheHitLabel).Inc() - iterator := cacheShardLeaders.GetReader() - return iterator.Shuffle(), nil - } - metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc() + if cacheShardLeaders != nil { + metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheHitLabel).Inc() + } else { + metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheMissLabel).Inc() } - info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID) - if err != nil { - return nil, err - } + return cacheShardLeaders +} + +func (m *MetaCache) updateShardLocationCache(ctx context.Context, database, collectionName string, collectionID int64) (*shardLeaders, error) { + log := log.Ctx(ctx).With( + zap.String("db", database), + zap.String("collectionName", collectionName), + zap.Int64("collectionID", collectionID)) + + method := "updateShardLocationCache" + tr := timerecord.NewTimeRecorder(method) + defer metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method). 
+ Observe(float64(tr.ElapseSpan().Milliseconds())) req := &querypb.GetShardLeadersRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders), commonpbutil.WithSourceID(paramtable.GetNodeID()), ), - CollectionID: info.collID, + CollectionID: collectionID, WithUnserviceableShards: true, } - - tr := timerecord.NewTimeRecorder("UpdateShardCache") resp, err := m.mixCoord.GetShardLeaders(ctx, req) - if err != nil { - return nil, err - } - if err = merr.Error(resp.GetStatus()); err != nil { + if err := merr.CheckRPCCall(resp.GetStatus(), err); err != nil { + log.Error("failed to get shard locations", + zap.Int64("collectionID", collectionID), + zap.Error(err)) return nil, err } shards := parseShardLeaderList2QueryNode(resp.GetShards()) - newShardLeaders := &shardLeaders{ - collectionID: info.collID, - shardLeaders: shards, - idx: atomic.NewInt64(0), - } // convert shards map to string for logging if log.Logger.Level() == zap.DebugLevel { @@ -1015,23 +1047,20 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col log.Debug("update shard leader cache", zap.String("newShardLeaders", strings.Join(shardStr, ", "))) } + newShardLeaders := &shardLeaders{ + collectionID: collectionID, + shardLeaders: shards, + idx: atomic.NewInt64(0), + } + m.leaderMut.Lock() if _, ok := m.collLeader[database]; !ok { m.collLeader[database] = make(map[string]*shardLeaders) } m.collLeader[database][collectionName] = newShardLeaders - iterator := newShardLeaders.GetReader() - ret := iterator.Shuffle() m.leaderMut.Unlock() - nodeInfos := make([]string, 0) - for ch, shardLeader := range newShardLeaders.shardLeaders { - for _, nodeInfo := range shardLeader { - nodeInfos = append(nodeInfos, fmt.Sprintf("channel %s, nodeID: %d, nodeAddr: %s", ch, nodeInfo.nodeID, nodeInfo.address)) - } - } - log.Debug("fill new collection shard leader", zap.Strings("nodeInfos", nodeInfos)) - metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).Observe(float64(tr.ElapseSpan().Milliseconds())) - return ret, nil + + return newShardLeaders, nil } func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]nodeInfo { diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 2921f14e39..f10ef65af6 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -1362,7 +1362,7 @@ func TestMetaCache_GetPartitionError(t *testing.T) { assert.Equal(t, id, typeutil.UniqueID(0)) } -func TestMetaCache_GetShards(t *testing.T) { +func TestMetaCache_GetShard(t *testing.T) { var ( ctx = context.Background() collectionName = "collection1" @@ -1375,13 +1375,13 @@ func TestMetaCache_GetShards(t *testing.T) { require.Nil(t, err) t.Run("No collection in meta cache", func(t *testing.T) { - shards, err := globalMetaCache.GetShards(ctx, true, dbName, "non-exists", 0) + shards, err := globalMetaCache.GetShard(ctx, true, dbName, "non-exists", 0, "channel-1") assert.Error(t, err) assert.Empty(t, shards) }) t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) { - shards, err := globalMetaCache.GetShards(ctx, false, dbName, collectionName, collectionID) + shards, err := globalMetaCache.GetShard(ctx, false, dbName, collectionName, collectionID, "channel-1") assert.Error(t, err) assert.Empty(t, shards) }) @@ -1401,14 +1401,12 @@ func TestMetaCache_GetShards(t *testing.T) { }, nil } - shards, err := globalMetaCache.GetShards(ctx, true, dbName, 
collectionName, collectionID) + shards, err := globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1") assert.NoError(t, err) assert.NotEmpty(t, shards) - assert.Equal(t, 1, len(shards)) - assert.Equal(t, 3, len(shards["channel-1"])) + assert.Equal(t, 3, len(shards)) // get from cache - rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) { return &querypb.GetShardLeadersResponse{ Status: &commonpb.Status{ @@ -1418,12 +1416,11 @@ func TestMetaCache_GetShards(t *testing.T) { }, nil } - shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID) + shards, err = globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1") assert.NoError(t, err) assert.NotEmpty(t, shards) - assert.Equal(t, 1, len(shards)) - assert.Equal(t, 3, len(shards["channel-1"])) + assert.Equal(t, 3, len(shards)) }) } @@ -1462,11 +1459,10 @@ func TestMetaCache_ClearShards(t *testing.T) { }, nil } - shards, err := globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID) + shards, err := globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1") require.NoError(t, err) require.NotEmpty(t, shards) - require.Equal(t, 1, len(shards)) - require.Equal(t, 3, len(shards["channel-1"])) + require.Equal(t, 3, len(shards)) globalMetaCache.DeprecateShardCache(dbName, collectionName) @@ -1479,7 +1475,7 @@ func TestMetaCache_ClearShards(t *testing.T) { }, nil } - shards, err = globalMetaCache.GetShards(ctx, true, dbName, collectionName, collectionID) + shards, err = globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1") assert.Error(t, err) assert.Empty(t, shards) }) @@ -1843,18 +1839,18 @@ func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) { }, }, nil } - nodeInfos, err := globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1) + nodeInfos, err := globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1") assert.NoError(t, err) - assert.Len(t, nodeInfos["channel-1"], 3) + assert.Len(t, nodeInfos, 3) assert.Equal(t, called.Load(), int32(1)) - globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1) + globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1") assert.Equal(t, called.Load(), int32(1)) globalMetaCache.InvalidateShardLeaderCache([]int64{1}) - nodeInfos, err = globalMetaCache.GetShards(ctx, true, dbName, "collection1", 1) + nodeInfos, err = globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1") assert.NoError(t, err) - assert.Len(t, nodeInfos["channel-1"], 3) + assert.Len(t, nodeInfos, 3) assert.Equal(t, called.Load(), int32(2)) } @@ -2155,3 +2151,86 @@ func TestMetaCache_Parallel(t *testing.T) { _, ok = cache.collInfo[dbName]["collection1"] assert.True(t, ok) } + +func TestMetaCache_GetShardLeaderList(t *testing.T) { + var ( + ctx = context.Background() + collectionName = "collection1" + collectionID = int64(1) + ) + + rootCoord := &MockMixCoordClientInterface{} + mgr := newShardClientMgr() + err := InitMetaCache(ctx, rootCoord, mgr) + require.Nil(t, err) + + t.Run("No collection in meta cache", func(t *testing.T) { + channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, "non-exists", 0, true) + assert.Error(t, err) + assert.Empty(t, channels) + }) + + t.Run("Get channel list without cache", func(t *testing.T) { + rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) 
(*querypb.GetShardLeadersResponse, error) { + return &querypb.GetShardLeadersResponse{ + Status: merr.Success(), + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: "channel-1", + NodeIds: []int64{1, 2, 3}, + NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"}, + Serviceable: []bool{true, true, true}, + }, + { + ChannelName: "channel-2", + NodeIds: []int64{4, 5, 6}, + NodeAddrs: []string{"localhost:9003", "localhost:9004", "localhost:9005"}, + Serviceable: []bool{true, true, true}, + }, + }, + }, nil + } + + channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, false) + assert.NoError(t, err) + assert.Equal(t, 2, len(channels)) + assert.Contains(t, channels, "channel-1") + assert.Contains(t, channels, "channel-2") + }) + + t.Run("Get channel list with cache", func(t *testing.T) { + // First call should populate cache + channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true) + assert.NoError(t, err) + assert.Equal(t, 2, len(channels)) + + // Mock should return error but cache should be used + rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) { + return &querypb.GetShardLeadersResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "should not be called when using cache", + }, + }, nil + } + + channels, err = globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true) + assert.NoError(t, err) + assert.Equal(t, 2, len(channels)) + assert.Contains(t, channels, "channel-1") + assert.Contains(t, channels, "channel-2") + }) + + t.Run("Error from coordinator", func(t *testing.T) { + // Deprecate cache first + globalMetaCache.DeprecateShardCache(dbName, collectionName) + + rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) { + return nil, errors.New("coordinator error") + } + + channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true) + assert.Error(t, err) + assert.Empty(t, channels) + }) +} diff --git a/internal/proxy/mock_cache.go b/internal/proxy/mock_cache.go index 1bd04c3185..0aede95362 100644 --- a/internal/proxy/mock_cache.go +++ b/internal/proxy/mock_cache.go @@ -757,29 +757,29 @@ func (_c *MockCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context return _c } -// GetShards provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID -func (_m *MockCache) GetShards(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64) (map[string][]nodeInfo, error) { - ret := _m.Called(ctx, withCache, database, collectionName, collectionID) +// GetShard provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID, channel +func (_m *MockCache) GetShard(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) { + ret := _m.Called(ctx, withCache, database, collectionName, collectionID, channel) if len(ret) == 0 { - panic("no return value specified for GetShards") + panic("no return value specified for GetShard") } - var r0 map[string][]nodeInfo + var r0 []nodeInfo var r1 error - if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64) (map[string][]nodeInfo, error)); ok { - return rf(ctx, withCache, database, 
collectionName, collectionID) + if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) ([]nodeInfo, error)); ok { + return rf(ctx, withCache, database, collectionName, collectionID, channel) } - if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64) map[string][]nodeInfo); ok { - r0 = rf(ctx, withCache, database, collectionName, collectionID) + if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) []nodeInfo); ok { + r0 = rf(ctx, withCache, database, collectionName, collectionID, channel) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(map[string][]nodeInfo) + r0 = ret.Get(0).([]nodeInfo) } } - if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64) error); ok { - r1 = rf(ctx, withCache, database, collectionName, collectionID) + if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64, string) error); ok { + r1 = rf(ctx, withCache, database, collectionName, collectionID, channel) } else { r1 = ret.Error(1) } @@ -787,34 +787,97 @@ func (_m *MockCache) GetShards(ctx context.Context, withCache bool, database str return r0, r1 } -// MockCache_GetShards_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShards' -type MockCache_GetShards_Call struct { +// MockCache_GetShard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShard' +type MockCache_GetShard_Call struct { *mock.Call } -// GetShards is a helper method to define mock.On call +// GetShard is a helper method to define mock.On call // - ctx context.Context // - withCache bool // - database string // - collectionName string // - collectionID int64 -func (_e *MockCache_Expecter) GetShards(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}) *MockCache_GetShards_Call { - return &MockCache_GetShards_Call{Call: _e.mock.On("GetShards", ctx, withCache, database, collectionName, collectionID)} +// - channel string +func (_e *MockCache_Expecter) GetShard(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}, channel interface{}) *MockCache_GetShard_Call { + return &MockCache_GetShard_Call{Call: _e.mock.On("GetShard", ctx, withCache, database, collectionName, collectionID, channel)} } -func (_c *MockCache_GetShards_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64)) *MockCache_GetShards_Call { +func (_c *MockCache_GetShard_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string)) *MockCache_GetShard_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64)) + run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64), args[5].(string)) }) return _c } -func (_c *MockCache_GetShards_Call) Return(_a0 map[string][]nodeInfo, _a1 error) *MockCache_GetShards_Call { +func (_c *MockCache_GetShard_Call) Return(_a0 []nodeInfo, _a1 error) *MockCache_GetShard_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockCache_GetShards_Call) RunAndReturn(run func(context.Context, bool, string, string, int64) (map[string][]nodeInfo, error)) *MockCache_GetShards_Call { +func (_c *MockCache_GetShard_Call) RunAndReturn(run func(context.Context, bool, string, string, int64, 
string) ([]nodeInfo, error)) *MockCache_GetShard_Call { + _c.Call.Return(run) + return _c +} + +// GetShardLeaderList provides a mock function with given fields: ctx, database, collectionName, collectionID, withCache +func (_m *MockCache) GetShardLeaderList(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool) ([]string, error) { + ret := _m.Called(ctx, database, collectionName, collectionID, withCache) + + if len(ret) == 0 { + panic("no return value specified for GetShardLeaderList") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) ([]string, error)); ok { + return rf(ctx, database, collectionName, collectionID, withCache) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) []string); ok { + r0 = rf(ctx, database, collectionName, collectionID, withCache) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, bool) error); ok { + r1 = rf(ctx, database, collectionName, collectionID, withCache) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCache_GetShardLeaderList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShardLeaderList' +type MockCache_GetShardLeaderList_Call struct { + *mock.Call +} + +// GetShardLeaderList is a helper method to define mock.On call +// - ctx context.Context +// - database string +// - collectionName string +// - collectionID int64 +// - withCache bool +func (_e *MockCache_Expecter) GetShardLeaderList(ctx interface{}, database interface{}, collectionName interface{}, collectionID interface{}, withCache interface{}) *MockCache_GetShardLeaderList_Call { + return &MockCache_GetShardLeaderList_Call{Call: _e.mock.On("GetShardLeaderList", ctx, database, collectionName, collectionID, withCache)} +} + +func (_c *MockCache_GetShardLeaderList_Call) Run(run func(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool)) *MockCache_GetShardLeaderList_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(int64), args[4].(bool)) + }) + return _c +} + +func (_c *MockCache_GetShardLeaderList_Call) Return(_a0 []string, _a1 error) *MockCache_GetShardLeaderList_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCache_GetShardLeaderList_Call) RunAndReturn(run func(context.Context, string, string, int64, bool) ([]string, error)) *MockCache_GetShardLeaderList_Call { _c.Call.Return(run) return _c } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index a6b24084ff..1b7a2690c9 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1152,7 +1152,7 @@ func TestSearchTask_WithFunctions(t *testing.T) { cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(info, nil).Maybe() cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe() cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe() - cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe() + cache.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, 
mock.Anything, mock.Anything).Return([]nodeInfo{}, nil).Maybe() cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe() globalMetaCache = cache @@ -3652,7 +3652,7 @@ func TestSearchTask_Requery(t *testing.T) { cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(schema, nil).Maybe() cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe() cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe() - cache.EXPECT().GetShards(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(map[string][]nodeInfo{}, nil).Maybe() + cache.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]nodeInfo{}, nil).Maybe() cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe() globalMetaCache = cache
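
Note (not part of the patch): the retry shape this change moves into selectNode — fetch the shard-leader view, and clear the exclude set only when every replica in that (possibly refreshed) view is already excluded — can be illustrated with a minimal standalone sketch. All names below (node, exclude, getShard, selectNode) are simplified, hypothetical stand-ins, not the Milvus types (nodeInfo, typeutil.UniqueSet, LBPolicyImpl) used in the diff above.

package main

import (
	"errors"
	"fmt"
)

// node and exclude are toy stand-ins for nodeInfo and typeutil.UniqueSet.
type node struct {
	id          int64
	serviceable bool
}

type exclude map[int64]struct{}

func (e exclude) has(id int64) bool { _, ok := e[id]; return ok }
func (e exclude) clear() {
	for k := range e {
		delete(e, k)
	}
}

// getShard mimics the cache lookup: withCache=false would stand for refreshing
// the shard-leader view from the coordinator; here it always returns two replicas.
func getShard(withCache bool) []node {
	return []node{{id: 1, serviceable: true}, {id: 2, serviceable: true}}
}

// selectNode mirrors the patched flow: fetch the shard-leader view, clear the
// exclude set when every returned replica is already excluded, then pick the
// first serviceable, non-excluded replica.
func selectNode(ex exclude, withCache bool) (node, error) {
	leaders := getShard(withCache)

	allExcluded := len(leaders) > 0 && len(leaders) == len(ex)
	for _, n := range leaders {
		if !ex.has(n.id) {
			allExcluded = false
			break
		}
	}
	if allExcluded {
		ex.clear() // give previously failed replicas another chance
	}

	for _, n := range leaders {
		if n.serviceable && !ex.has(n.id) {
			return n, nil
		}
	}
	return node{}, errors.New("no available shard leader")
}

func main() {
	ex := exclude{1: {}, 2: {}} // every replica has already failed once
	n, err := selectNode(ex, true)
	fmt.Println(n.id, err, len(ex)) // exclude set is cleared, node 1 is selected
}

Keeping the clearing inside node selection, rather than in the outer retry loop as before, means the decision is made against the current shard-leader view rather than a stale one, which is the behavior the new TestSelectNodeWithExcludeClearing and TestSelectNodeEdgeCases cases assert.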