From bd5fab1e539958ed90c0ad7e05bc52b0705842e3 Mon Sep 17 00:00:00 2001 From: aoiasd <45024769+aoiasd@users.noreply.github.com> Date: Fri, 31 Mar 2023 14:36:22 +0800 Subject: [PATCH] Remove merge policy of proxy RoundRobin policy (#23021) Signed-off-by: aoiasd --- internal/proxy/impl.go | 10 +- internal/proxy/task_policies.go | 176 ++++++-------------------- internal/proxy/task_policies_test.go | 114 +---------------- internal/proxy/task_query.go | 4 +- internal/proxy/task_query_test.go | 4 +- internal/proxy/task_search.go | 8 +- internal/proxy/task_search_test.go | 4 +- internal/proxy/task_statistic.go | 4 +- internal/proxy/task_statistic_test.go | 12 +- 9 files changed, 70 insertions(+), 266 deletions(-) diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index b033ea9d51..72737c0353 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2590,10 +2590,9 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* ), ReqID: paramtable.GetNodeID(), }, - request: request, - qc: node.queryCoord, - queryShardPolicy: mergeRoundRobinPolicy, - shardMgr: node.shardMgr, + request: request, + qc: node.queryCoord, + shardMgr: node.shardMgr, } method := "Query" @@ -2924,8 +2923,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista qc: node.queryCoord, ids: ids.IdArray, - queryShardPolicy: mergeRoundRobinPolicy, - shardMgr: node.shardMgr, + shardMgr: node.shardMgr, } log := log.Ctx(ctx).With( diff --git a/internal/proxy/task_policies.go b/internal/proxy/task_policies.go index f7c8568bc8..af69f5d340 100644 --- a/internal/proxy/task_policies.go +++ b/internal/proxy/task_policies.go @@ -2,161 +2,69 @@ package proxy import ( "context" - "fmt" - "strings" - "sync" "github.com/cockroachdb/errors" + "golang.org/x/sync/errgroup" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/merr" "go.uber.org/zap" ) // type pickShardPolicy func(ctx context.Context, mgr *shardClientMgr, query func(UniqueID, types.QueryNode) error, leaders []nodeInfo) error -type pickShardPolicy func(context.Context, *shardClientMgr, func(context.Context, UniqueID, types.QueryNode, []string, int) error, map[string][]nodeInfo) error +type queryFunc func(context.Context, UniqueID, types.QueryNode, ...string) error +type pickShardPolicy func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error var ( - errBegin = errors.New("begin error") errInvalidShardLeaders = errors.New("Invalid shard leader") ) -func updateShardsWithRoundRobin(shardsLeaders map[string][]nodeInfo) { - for channelID, leaders := range shardsLeaders { - if len(leaders) <= 1 { - continue - } - - shardsLeaders[channelID] = append(leaders[1:], leaders[0]) - } -} - -// mergeErrSet merges all errors in ErrSet -func mergeErrSet(errSet map[string]error) error { - var builder strings.Builder - for channel, err := range errSet { - if err == nil { - continue - } - - builder.WriteString(fmt.Sprintf("Channel: %s returns err: %s", channel, err.Error())) - } - return errors.New(builder.String()) -} - -// group dml shard leader with same nodeID -func groupShardleadersWithSameQueryNode( - ctx context.Context, - shard2leaders map[string][]nodeInfo, - nexts map[string]int, errSet map[string]error, - mgr *shardClientMgr) (map[int64][]string, map[int64]types.QueryNode, error) { - // check if all leaders were checked - for dml, idx := range nexts { - if idx >= len(shard2leaders[dml]) { - log.Ctx(ctx).Warn("no shard leaders were available", - zap.String("channel", dml), - zap.String("leaders", fmt.Sprintf("%v", shard2leaders[dml]))) - if err, ok := errSet[dml]; ok { - return nil, nil, err - } - return nil, nil, fmt.Errorf("no available shard leader") - } - } - qnSet := make(map[int64]types.QueryNode) - node2dmls := make(map[int64][]string) - updates := make(map[string]int) - - for dml, idx := range nexts { - updates[dml] = idx + 1 - nodeInfo := shard2leaders[dml][idx] - if _, ok := qnSet[nodeInfo.nodeID]; !ok { - qn, err := mgr.GetClient(ctx, nodeInfo.nodeID) - if err != nil { - log.Ctx(ctx).Warn("failed to get shard leader", zap.Int64("nodeID", nodeInfo.nodeID), zap.Error(err)) - // if get client failed, just record error and wait for next round to get client and do query - errSet[dml] = err - continue - } - qnSet[nodeInfo.nodeID] = qn - } - if _, ok := node2dmls[nodeInfo.nodeID]; !ok { - node2dmls[nodeInfo.nodeID] = make([]string, 0) - } - node2dmls[nodeInfo.nodeID] = append(node2dmls[nodeInfo.nodeID], dml) - } - // update idxes - for dml, idx := range updates { - nexts[dml] = idx - } - return node2dmls, qnSet, nil -} - -// mergeRoundRobinPolicy first group shard leaders with same querynode, then do the query with multiple dml channels -// if request failed, it finds shard leader for failed dml channels, and again groups shard leaders and do the query +// RoundRobinPolicy do the query with multiple dml channels +// if request failed, it finds shard leader for failed dml channels // -// Suppose qn0 is the shard leader for dml-channel0 and dml-channel1, if search for dml-channel0 succeeded, but -// failed for dml-channel1. In this case, an error returned from qn0, and next shard leaders for dml-channel0 and dml-channel1 will be -// retrieved and dml-channel0 therefore will again be searched. -// -// TODO: In this senario, qn0 should return a partial success results for dml-channel0, and only retrys for dml-channel1 -func mergeRoundRobinPolicy( +func RoundRobinPolicy( ctx context.Context, mgr *shardClientMgr, - query func(context.Context, UniqueID, types.QueryNode, []string, int) error, + query queryFunc, dml2leaders map[string][]nodeInfo) error { - nexts := make(map[string]int) - errSet := make(map[string]error) // record err for dml channels - totalChannelNum := len(dml2leaders) - for dml := range dml2leaders { - nexts[dml] = 0 - } - for len(nexts) > 0 { - node2dmls, nodeset, err := groupShardleadersWithSameQueryNode(ctx, dml2leaders, nexts, errSet, mgr) - if err != nil { - log.Ctx(ctx).Warn("failed to search/query with round-robin policy", zap.Error(mergeErrSet(errSet))) - return err - } - wg := &sync.WaitGroup{} - mu := &sync.Mutex{} - wg.Add(len(node2dmls)) - for nodeID, channels := range node2dmls { - nodeID := nodeID - channels := channels - qn := nodeset[nodeID] - go func() { - defer wg.Done() - if err := query(ctx, nodeID, qn, channels, totalChannelNum); err != nil { - log.Ctx(ctx).Warn("failed to do query with node", zap.Int64("nodeID", nodeID), - zap.Strings("dmlChannels", channels), zap.Error(err)) - mu.Lock() - defer mu.Unlock() - for _, ch := range channels { - errSet[ch] = err - } - return - } - mu.Lock() - defer mu.Unlock() - for _, channel := range channels { - delete(nexts, channel) - delete(errSet, channel) - } - }() - } - wg.Wait() - if len(nexts) > 0 { - nextSet := make(map[string]int64) - for dml, idx := range nexts { - if idx >= len(dml2leaders[dml]) { - nextSet[dml] = -1 - } else { - nextSet[dml] = dml2leaders[dml][idx].nodeID - } + + queryChannel := func(ctx context.Context, channel string) error { + var combineErr error + leaders := dml2leaders[channel] + + for _, target := range leaders { + qn, err := mgr.GetClient(ctx, target.nodeID) + if err != nil { + log.Warn("query channel failed, node not available", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err)) + combineErr = merr.Combine(combineErr, err) + continue } - log.Ctx(ctx).Warn("retry another query node with round robin", zap.Any("Nexts", nextSet)) + err = query(ctx, target.nodeID, qn, channel) + if err != nil { + log.Warn("query channel failed", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err)) + combineErr = merr.Combine(combineErr, err) + continue + } + return nil } + + log.Ctx(ctx).Error("failed to do query on all shard leader", + zap.String("channel", channel), zap.Error(combineErr)) + return combineErr } - return nil + + wg, ctx := errgroup.WithContext(ctx) + for channel := range dml2leaders { + channel := channel + wg.Go(func() error { + err := queryChannel(ctx, channel) + return err + }) + } + + err := wg.Wait() + return err } diff --git a/internal/proxy/task_policies_test.go b/internal/proxy/task_policies_test.go index 4497058078..5c83e8612e 100644 --- a/internal/proxy/task_policies_test.go +++ b/internal/proxy/task_policies_test.go @@ -8,116 +8,12 @@ import ( "sync" "testing" - "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/types" "github.com/stretchr/testify/assert" - - "go.uber.org/zap" ) -func TestUpdateShardsWithRoundRobin(t *testing.T) { - list := map[string][]nodeInfo{ - "channel-1": { - {1, "addr1"}, - {2, "addr2"}, - }, - "channel-2": { - {20, "addr20"}, - {21, "addr21"}, - }, - } - - updateShardsWithRoundRobin(list) - - assert.Equal(t, int64(2), list["channel-1"][0].nodeID) - assert.Equal(t, "addr2", list["channel-1"][0].address) - assert.Equal(t, int64(21), list["channel-2"][0].nodeID) - assert.Equal(t, "addr21", list["channel-2"][0].address) - - t.Run("check print", func(t *testing.T) { - qns := []nodeInfo{ - {1, "addr1"}, - {2, "addr2"}, - {20, "addr20"}, - {21, "addr21"}, - } - - res := fmt.Sprintf("list: %v", qns) - - log.Debug("Check String func", - zap.Any("Any", qns), - zap.Any("ok", qns[0]), - zap.String("ok2", res), - ) - - }) -} - -func TestGroupShardLeadersWithSameQueryNode(t *testing.T) { - var err error - - var ( - ctx = context.TODO() - ) - - mgr := newShardClientMgr() - - shard2leaders := map[string][]nodeInfo{ - "c0": {{nodeID: 0, address: "fake"}, {nodeID: 1, address: "fake"}, {nodeID: 2, address: "fake"}}, - "c1": {{nodeID: 1, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}}, - "c2": {{nodeID: 0, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}}, - "c3": {{nodeID: 1, address: "fake"}, {nodeID: 3, address: "fake"}, {nodeID: 4, address: "fake"}}, - } - mgr.UpdateShardLeaders(nil, shard2leaders) - nexts := map[string]int{ - "c0": 0, - "c1": 0, - "c2": 0, - "c3": 0, - } - errSet := map[string]error{} - node2dmls, qnSet, err := groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr) - assert.Nil(t, err) - for nodeID := range node2dmls { - sort.Slice(node2dmls[nodeID], func(i, j int) bool { return node2dmls[nodeID][i] < node2dmls[nodeID][j] }) - } - - cli0, err := mgr.GetClient(ctx, 0) - assert.Nil(t, err) - cli1, err := mgr.GetClient(ctx, 1) - assert.Nil(t, err) - cli2, err := mgr.GetClient(ctx, 2) - assert.Nil(t, err) - cli3, err := mgr.GetClient(ctx, 3) - assert.Nil(t, err) - - assert.Equal(t, node2dmls, map[int64][]string{0: {"c0", "c2"}, 1: {"c1", "c3"}}) - assert.Equal(t, qnSet, map[int64]types.QueryNode{0: cli0, 1: cli1}) - assert.Equal(t, nexts, map[string]int{"c0": 1, "c1": 1, "c2": 1, "c3": 1}) - // delete client1 in client mgr - delete(mgr.clients.data, 1) - node2dmls, qnSet, err = groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr) - assert.Nil(t, err) - for nodeID := range node2dmls { - sort.Slice(node2dmls[nodeID], func(i, j int) bool { return node2dmls[nodeID][i] < node2dmls[nodeID][j] }) - } - assert.Equal(t, node2dmls, map[int64][]string{2: {"c1", "c2"}, 3: {"c3"}}) - assert.Equal(t, qnSet, map[int64]types.QueryNode{2: cli2, 3: cli3}) - assert.Equal(t, nexts, map[string]int{"c0": 2, "c1": 2, "c2": 2, "c3": 2}) - assert.NotNil(t, errSet["c0"]) - - nexts["c0"] = 3 - _, _, err = groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr) - assert.True(t, strings.Contains(err.Error(), errSet["c0"].Error())) - - nexts["c0"] = 2 - nexts["c1"] = 3 - _, _, err = groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr) - assert.Equal(t, err, fmt.Errorf("no available shard leader")) -} - -func TestMergeRoundRobinPolicy(t *testing.T) { +func TestRoundRobinPolicy(t *testing.T) { var err error var ( @@ -137,7 +33,7 @@ func TestMergeRoundRobinPolicy(t *testing.T) { querier := &mockQuery{} querier.init() - err = mergeRoundRobinPolicy(ctx, mgr, querier.query, shard2leaders) + err = RoundRobinPolicy(ctx, mgr, querier.query, shard2leaders) assert.Nil(t, err) assert.Equal(t, querier.records(), map[UniqueID][]string{0: {"c0", "c2"}, 1: {"c1", "c3"}}) @@ -145,7 +41,7 @@ func TestMergeRoundRobinPolicy(t *testing.T) { querier.init() querier.failset[0] = mockerr - err = mergeRoundRobinPolicy(ctx, mgr, querier.query, shard2leaders) + err = RoundRobinPolicy(ctx, mgr, querier.query, shard2leaders) assert.Nil(t, err) assert.Equal(t, querier.records(), map[int64][]string{1: {"c0", "c1", "c3"}, 2: {"c2"}}) @@ -153,7 +49,7 @@ func TestMergeRoundRobinPolicy(t *testing.T) { querier.failset[0] = mockerr querier.failset[2] = mockerr querier.failset[3] = mockerr - err = mergeRoundRobinPolicy(ctx, mgr, querier.query, shard2leaders) + err = RoundRobinPolicy(ctx, mgr, querier.query, shard2leaders) assert.True(t, strings.Contains(err.Error(), mockerr.Error())) } @@ -167,7 +63,7 @@ type mockQuery struct { failset map[UniqueID]error } -func (m *mockQuery) query(_ context.Context, nodeID UniqueID, qn types.QueryNode, chs []string, _ int) error { +func (m *mockQuery) query(_ context.Context, nodeID UniqueID, qn types.QueryNode, chs ...string) error { m.mu.Lock() defer m.mu.Unlock() if err, ok := m.failset[nodeID]; ok { diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 5ae0c5c483..bcb1307051 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -243,7 +243,7 @@ func (t *queryTask) createPlan(ctx context.Context) error { func (t *queryTask) PreExecute(ctx context.Context) error { if t.queryShardPolicy == nil { - t.queryShardPolicy = mergeRoundRobinPolicy + t.queryShardPolicy = RoundRobinPolicy } t.Base.MsgType = commonpb.MsgType_Retrieve @@ -454,7 +454,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error { return nil } -func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string, channelNum int) error { +func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs ...string) error { retrieveReq := typeutil.Clone(t.RetrieveRequest) retrieveReq.GetBase().TargetID = nodeID req := &querypb.QueryRequest{ diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 44a1716d61..8570705920 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -44,7 +44,7 @@ func TestQueryTask_all(t *testing.T) { expr = fmt.Sprintf("%s > 0", testInt64Field) hitNum = 10 - errPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error { + errPolicy = func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error { return fmt.Errorf("fake error") } ) @@ -181,7 +181,7 @@ func TestQueryTask_all(t *testing.T) { task.queryShardPolicy = errPolicy assert.Error(t, task.Execute(ctx)) - task.queryShardPolicy = mergeRoundRobinPolicy + task.queryShardPolicy = RoundRobinPolicy result1 := &internalpb.RetrieveResults{ Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult}, Status: &commonpb.Status{ diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 0a79ac8442..9f828c59a9 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -52,6 +52,7 @@ type searchTask struct { qc types.QueryCoord tr *timerecord.TimeRecorder collectionName string + channelNum int32 schema *schemapb.CollectionSchema offset int64 @@ -207,7 +208,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { defer sp.End() if t.searchShardPolicy == nil { - t.searchShardPolicy = mergeRoundRobinPolicy + t.searchShardPolicy = RoundRobinPolicy } t.Base.MsgType = commonpb.MsgType_Search @@ -358,6 +359,7 @@ func (t *searchTask) Execute(ctx context.Context) error { } t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders)) t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders)) + t.channelNum = int32(len(shard2Leaders)) if err := t.searchShardPolicy(ctx, t.shardMgr, t.searchShard, shard2Leaders); err != nil { log.Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders))) return err @@ -439,14 +441,14 @@ func (t *searchTask) PostExecute(ctx context.Context) error { return nil } -func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string, channelNum int) error { +func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs ...string) error { searchReq := typeutil.Clone(t.SearchRequest) searchReq.GetBase().TargetID = nodeID req := &querypb.SearchRequest{ Req: searchReq, DmlChannels: channelIDs, Scope: querypb.DataScope_All, - TotalChannelNum: int32(channelNum), + TotalChannelNum: t.channelNum, } queryNode := querynode.GetQueryNode() diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index ada257c869..01d389a2b5 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1698,7 +1698,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { shardsNum = int32(2) collectionName = t.Name() + funcutil.GenRandomStr() - errPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error { + errPolicy = func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error { return fmt.Errorf("fake error") } ) @@ -1820,7 +1820,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { task.searchShardPolicy = errPolicy assert.Error(t, task.Execute(ctx)) - task.searchShardPolicy = mergeRoundRobinPolicy + task.searchShardPolicy = RoundRobinPolicy qn.searchError = fmt.Errorf("mock error") assert.Error(t, task.Execute(ctx)) diff --git a/internal/proxy/task_statistic.go b/internal/proxy/task_statistic.go index 731ddea178..6a6d5651d8 100644 --- a/internal/proxy/task_statistic.go +++ b/internal/proxy/task_statistic.go @@ -109,7 +109,7 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error { defer sp.End() if g.statisticShardPolicy == nil { - g.statisticShardPolicy = mergeRoundRobinPolicy + g.statisticShardPolicy = RoundRobinPolicy } // TODO: Maybe we should create a new MsgType: GetStatistics? @@ -299,7 +299,7 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro return nil } -func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string, channelNum int) error { +func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs ...string) error { nodeReq := proto.Clone(g.GetStatisticsRequest).(*internalpb.GetStatisticsRequest) nodeReq.Base.TargetID = nodeID req := &querypb.GetStatisticsRequest{ diff --git a/internal/proxy/task_statistic_test.go b/internal/proxy/task_statistic_test.go index 74cc3a5953..826eea088d 100644 --- a/internal/proxy/task_statistic_test.go +++ b/internal/proxy/task_statistic_test.go @@ -155,27 +155,27 @@ func TestStatisticTask_all(t *testing.T) { assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp) task.ctx = ctx - task.statisticShardPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error { + task.statisticShardPolicy = func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error { return fmt.Errorf("fake error") } task.fromQueryNode = true assert.Error(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) - task.statisticShardPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error { + task.statisticShardPolicy = func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error { return errInvalidShardLeaders } task.fromQueryNode = true assert.Error(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) - task.statisticShardPolicy = mergeRoundRobinPolicy + task.statisticShardPolicy = RoundRobinPolicy task.fromQueryNode = true qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, errors.New("GetStatistics failed")).Times(3) assert.Error(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) - task.statisticShardPolicy = mergeRoundRobinPolicy + task.statisticShardPolicy = RoundRobinPolicy task.fromQueryNode = true qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{ Status: &commonpb.Status{ @@ -186,7 +186,7 @@ func TestStatisticTask_all(t *testing.T) { assert.Error(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) - task.statisticShardPolicy = mergeRoundRobinPolicy + task.statisticShardPolicy = RoundRobinPolicy task.fromQueryNode = true qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{ Status: &commonpb.Status{ @@ -197,7 +197,7 @@ func TestStatisticTask_all(t *testing.T) { assert.Error(t, task.Execute(ctx)) assert.NoError(t, task.PostExecute(ctx)) - task.statisticShardPolicy = mergeRoundRobinPolicy + task.statisticShardPolicy = RoundRobinPolicy task.fromQueryNode = true qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, nil).Once() assert.NoError(t, task.Execute(ctx))