diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index 59e746d367..e0971ea400 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -106,9 +106,8 @@ func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, coll var err error shardLeaders, err = globalMetaCache.GetShards(ctx, withCache, dbName, collName, collectionID) if err != nil { - return !errors.Is(err, merr.ErrCollectionLoaded), err + return !errors.Is(err, merr.ErrCollectionNotLoaded), err } - return false, nil }) diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index abaf3c215e..d427dd0b8e 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -21,9 +21,12 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/pingcap/log" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "go.uber.org/atomic" + "go.uber.org/zap" + "google.golang.org/grpc" "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -456,6 +459,42 @@ func (s *LBPolicySuite) TestNewLBPolicy() { policy.Close() } +func (s *LBPolicySuite) TestGetShardLeaders() { + ctx := context.Background() + + // ErrCollectionNotFullyLoaded is retriable, expected to retry until ctx done or success + counter := atomic.NewInt64(0) + globalMetaCache.DeprecateShardCache(dbName, s.collectionName) + s.qc.ExpectedCalls = nil + s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + counter.Inc() + return nil, merr.ErrCollectionNotFullyLoaded + }).Times(5) + s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + log.Info("return rpc success") + return nil, nil + }).Times(5) + _, err := s.lbPolicy.GetShardLeaders(ctx, dbName, s.collectionName, s.collectionID, true) + s.NoError(err) + s.Equal(int64(5), counter.Load()) + + // ErrServiceUnavailable is not retriable, expected to fail fast + counter.Store(0) + globalMetaCache.DeprecateShardCache(dbName, s.collectionName) + s.qc.ExpectedCalls = nil + s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) { + counter.Inc() + return nil, merr.ErrCollectionNotLoaded + }) + _, err = s.lbPolicy.GetShardLeaders(ctx, dbName, s.collectionName, s.collectionID, true) + log.Info("check err", zap.Error(err)) + s.Error(err) + s.Equal(int64(1), counter.Load()) +} + func TestLBPolicySuite(t *testing.T) { suite.Run(t, new(LBPolicySuite)) }