diff --git a/internal/proxy/task_policies.go b/internal/proxy/task_policies.go index e8948edec5..f2886323a3 100644 --- a/internal/proxy/task_policies.go +++ b/internal/proxy/task_policies.go @@ -51,7 +51,9 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy for err != nil && current < replicaNum { currentID := leaders.GetNodeIds()[current] if err != errBegin { - log.Warn("retry with another QueryNode", zap.String("leader", leaders.GetChannelName()), zap.Int64("nodeID", currentID)) + log.Warn("retry with another QueryNode", + zap.Int("retries numbers", current), + zap.String("leader", leaders.GetChannelName()), zap.Int64("nodeID", currentID)) } qn, err = getQueryNodePolicy(ctx, leaders.GetNodeAddrs()[current]) diff --git a/internal/proxy/task_policies_test.go b/internal/proxy/task_policies_test.go new file mode 100644 index 0000000000..e697d927f4 --- /dev/null +++ b/internal/proxy/task_policies_test.go @@ -0,0 +1,109 @@ +package proxy + +import ( + "context" + "fmt" + "testing" + + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/types" + "github.com/stretchr/testify/require" +) + +func TestRoundRobinPolicy(t *testing.T) { + var ( + getQueryNodePolicy = mockGetQueryNodePolicy + ctx = context.TODO() + ) + + t.Run("All fails", func(t *testing.T) { + allFailTests := []struct { + leaderIDs []UniqueID + + description string + }{ + {[]UniqueID{1}, "one invalid shard leader"}, + {[]UniqueID{1, 2}, "two invalid shard leaders"}, + {[]UniqueID{1, 1}, "two invalid same shard leaders"}, + } + + for _, test := range allFailTests { + t.Run(test.description, func(t *testing.T) { + query := (&mockQuery{isvalid: false}).query + + leaders := &querypb.ShardLeadersList{ + ChannelName: t.Name(), + NodeIds: test.leaderIDs, + NodeAddrs: make([]string, len(test.leaderIDs)), + } + err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders) + require.Error(t, err) + }) + } + }) + + t.Run("Pass at the first try", func(t *testing.T) { + allPassTests := []struct { + leaderIDs []UniqueID + + description string + }{ + {[]UniqueID{1}, "one valid shard leader"}, + {[]UniqueID{1, 2}, "two valid shard leaders"}, + {[]UniqueID{1, 1}, "two valid same shard leaders"}, + } + + for _, test := range allPassTests { + query := (&mockQuery{isvalid: true}).query + leaders := &querypb.ShardLeadersList{ + ChannelName: t.Name(), + NodeIds: test.leaderIDs, + NodeAddrs: make([]string, len(test.leaderIDs)), + } + err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders) + require.NoError(t, err) + } + }) + + t.Run("Pass at the second try", func(t *testing.T) { + passAtLast := []struct { + leaderIDs []UniqueID + + description string + }{ + {[]UniqueID{-1, 2}, "invalid vs valid shard leaders"}, + {[]UniqueID{-1, -1, 3}, "invalid, invalid, and valid shard leaders"}, + } + + for _, test := range passAtLast { + query := (&mockQuery{isvalid: true}).query + leaders := &querypb.ShardLeadersList{ + ChannelName: t.Name(), + NodeIds: test.leaderIDs, + NodeAddrs: make([]string, len(test.leaderIDs)), + } + err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders) + require.NoError(t, err) + } + }) +} + +func mockGetQueryNodePolicy(ctx context.Context, address string) (types.QueryNode, error) { + return &QueryNodeMock{address: address}, nil +} + +type mockQuery struct { + isvalid bool +} + +func (m *mockQuery) query(nodeID UniqueID, qn types.QueryNode) error { + if nodeID == -1 { + return fmt.Errorf("error at condition") + } + + if m.isvalid { + return nil + } + + return fmt.Errorf("mock error in query, NodeID=%d", nodeID) +}