diff --git a/internal/proxy/lb_policy.go b/internal/proxy/lb_policy.go index 417a2221fd..68617372f8 100644 --- a/internal/proxy/lb_policy.go +++ b/internal/proxy/lb_policy.go @@ -126,7 +126,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload log.Warn("no available shard delegator found", zap.Int64s("nodes", nodes), zap.Int64s("excluded", excludeNodes.Collect())) - return -1, merr.WrapErrServiceUnavailable("no available shard delegator found") + return -1, merr.WrapErrChannelNotAvailable("no available shard delegator found") } targetNode, err = lb.balancer.SelectNode(ctx, availableNodes, workload.nq) @@ -210,7 +210,7 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID }) retryOnReplica := Params.ProxyCfg.RetryTimesOnReplica.GetAsInt() wg.Go(func() error { - err := lb.ExecuteWithRetry(ctx, ChannelWorkload{ + return lb.ExecuteWithRetry(ctx, ChannelWorkload{ db: workload.db, collectionName: workload.collectionName, collectionID: workload.collectionID, @@ -220,12 +220,10 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad exec: workload.exec, retryTimes: uint(len(nodes) * retryOnReplica), }) - return err }) } - err = wg.Wait() - return err + return wg.Wait() } func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) { diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index 19da3d6b0d..b3a89ef5f5 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -213,7 +213,7 @@ func (s *LBPolicySuite) TestSelectNode() { shardLeaders: s.nodes, nq: 1, }, typeutil.NewUniqueSet(s.nodes...)) - s.ErrorIs(err, merr.ErrServiceUnavailable) + s.ErrorIs(err, merr.ErrChannelNotAvailable) s.Equal(int64(-1), targetNode) // test get shard leaders failed, retry to select node failed diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 45f15ac72b..c45a1ea2b3 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -494,7 +494,7 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query } if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("QueryNode query result error", zap.Any("errorCode", result.GetStatus().GetErrorCode()), zap.String("reason", result.GetStatus().GetReason())) - return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason()) + return errors.Wrapf(merr.Error(result.GetStatus()), "fail to Query on QueryNode %d", nodeID) } log.Debug("get query result") diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 7cef1d2a4e..f3e56080c1 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -18,7 +18,6 @@ package proxy import ( "context" "fmt" - "strings" "testing" "time" @@ -211,12 +210,10 @@ func TestQueryTask_all(t *testing.T) { qn.ExpectedCalls = nil qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_NotShardLeader, - }, + Status: merr.Status(merr.ErrChannelNotAvailable), }, nil) err = task.Execute(ctx) - assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) + assert.ErrorIs(t, err, merr.ErrChannelNotAvailable) qn.ExpectedCalls = nil qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 5aaf7ff644..0ea69a3f21 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -539,7 +539,7 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { log.Warn("QueryNode search result error", zap.String("reason", result.GetStatus().GetReason())) - return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason()) + return errors.Wrapf(merr.Error(result.GetStatus()), "fail to search on QueryNode %d", nodeID) } t.resultBuf.Insert(result) t.lb.UpdateCostMetrics(nodeID, result.CostAggregation) diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index f744a4a777..4725205d30 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -19,7 +19,6 @@ import ( "context" "fmt" "strconv" - "strings" "testing" "time" @@ -1719,12 +1718,10 @@ func TestSearchTask_ErrExecute(t *testing.T) { qn.ExpectedCalls = nil qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_NotShardLeader, - }, + Status: merr.Status(merr.ErrChannelNotAvailable), }, nil) err = task.Execute(ctx) - assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) + assert.ErrorIs(t, err, merr.ErrChannelNotAvailable) qn.ExpectedCalls = nil qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() diff --git a/internal/proxy/task_statistic.go b/internal/proxy/task_statistic.go index 9a676f692e..2b4104c082 100644 --- a/internal/proxy/task_statistic.go +++ b/internal/proxy/task_statistic.go @@ -302,7 +302,7 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64 zap.Int64("nodeID", nodeID), zap.String("reason", result.GetStatus().GetReason())) globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName) - return fmt.Errorf("fail to get statistic, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason()) + return errors.Wrapf(merr.Error(result.GetStatus()), "fail to get statistic on QueryNode ID=%d", nodeID) } g.resultBuf.Insert(result) diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 281fc6abbe..4da7c02e51 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -773,21 +773,21 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( }, nil } - failRet := &internalpb.SearchResults{ + resp := &internalpb.SearchResults{ Status: merr.Success(), } collection := node.manager.Collection.Get(req.GetReq().GetCollectionID()) if collection == nil { - failRet.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID())) - return failRet, nil + resp.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID())) + return resp, nil } // Check if the metric type specified in search params matches the metric type in the index info. if !req.GetFromShardLeader() && req.GetReq().GetMetricType() != "" { if req.GetReq().GetMetricType() != collection.GetMetricType() { - failRet.Status = merr.Status(merr.WrapErrParameterInvalid(collection.GetMetricType(), req.GetReq().GetMetricType(), + resp.Status = merr.Status(merr.WrapErrParameterInvalid(collection.GetMetricType(), req.GetReq().GetMetricType(), fmt.Sprintf("collection:%d, metric type not match", collection.ID()))) - return failRet, nil + return resp, nil } } @@ -796,10 +796,9 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( req.Req.MetricType = collection.GetMetricType() } - var toReduceResults []*internalpb.SearchResults - var mu sync.Mutex + toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels())) runningGp, runningCtx := errgroup.WithContext(ctx) - for _, ch := range req.GetDmlChannels() { + for i, ch := range req.GetDmlChannels() { ch := ch req := &querypb.SearchRequest{ Req: req.Req, @@ -810,31 +809,30 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) ( TotalChannelNum: req.TotalChannelNum, } + i := i runningGp.Go(func() error { ret, err := node.searchChannel(runningCtx, req, ch) - mu.Lock() - defer mu.Unlock() if err != nil { - failRet.Status = merr.Status(err) return err } - if ret.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { - return merr.Error(failRet.GetStatus()) + if err := merr.Error(ret.GetStatus()); err != nil { + return err } - toReduceResults = append(toReduceResults, ret) + toReduceResults[i] = ret return nil }) } if err := runningGp.Wait(); err != nil { - return failRet, nil + resp.Status = merr.Status(err) + return resp, nil } tr.RecordSpan() result, err := segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) if err != nil { log.Warn("failed to reduce search results", zap.Error(err)) - failRet.Status = merr.Status(err) - return failRet, nil + resp.Status = merr.Status(err) + return resp, nil } reduceLatency := tr.RecordSpan() metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards). diff --git a/pkg/util/retry/retry.go b/pkg/util/retry/retry.go index cb1a8a2092..afb01ab31b 100644 --- a/pkg/util/retry/retry.go +++ b/pkg/util/retry/retry.go @@ -38,31 +38,36 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error { opt(c) } - var el error + var lastErr error for i := uint(0); i < c.attempts; i++ { if err := fn(); err != nil { if i%4 == 0 { - log.Error("retry func failed", zap.Uint("retry time", i), zap.Error(err)) + log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err)) } - err = errors.Wrapf(err, "attempt #%d", i) - el = merr.Combine(el, err) - if !IsRecoverable(err) { - return el + if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil { + return lastErr + } + return err } deadline, ok := ctx.Deadline() if ok && time.Until(deadline) < c.sleep { // to avoid sleep until ctx done - return el + if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil { + return lastErr + } + return err } + lastErr = err + select { case <-time.After(c.sleep): case <-ctx.Done(): - return merr.Combine(el, ctx.Err()) + return lastErr } c.sleep *= 2 @@ -73,7 +78,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error { return nil } } - return el + return lastErr } // errUnrecoverable is error instance for unrecoverable. diff --git a/pkg/util/retry/retry_test.go b/pkg/util/retry/retry_test.go index afc0861838..d0a2c501e4 100644 --- a/pkg/util/retry/retry_test.go +++ b/pkg/util/retry/retry_test.go @@ -127,8 +127,9 @@ func TestContextDeadline(t *testing.T) { func TestContextCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) + mockErr := errors.New("mock error") testFn := func() error { - return fmt.Errorf("some error") + return mockErr } go func() { @@ -138,7 +139,7 @@ func TestContextCancel(t *testing.T) { err := Do(ctx, testFn) assert.Error(t, err) - assert.True(t, merr.IsCanceledOrTimeout(err)) + assert.ErrorIs(t, err, mockErr) t.Log(err) }