enhance: refine the retry error (#28573)

return the last error instead of combining all errors, to improve
readability and error handling

resolve: #28572

---------

Signed-off-by: yah01 <yah2er0ne@outlook.com>
This commit is contained in:
yah01 2023-11-30 18:34:32 +08:00 committed by GitHub
parent b4353ca4ce
commit bf633bb5d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 43 additions and 47 deletions

View File

@ -126,7 +126,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
log.Warn("no available shard delegator found", log.Warn("no available shard delegator found",
zap.Int64s("nodes", nodes), zap.Int64s("nodes", nodes),
zap.Int64s("excluded", excludeNodes.Collect())) zap.Int64s("excluded", excludeNodes.Collect()))
return -1, merr.WrapErrServiceUnavailable("no available shard delegator found") return -1, merr.WrapErrChannelNotAvailable("no available shard delegator found")
} }
targetNode, err = lb.balancer.SelectNode(ctx, availableNodes, workload.nq) targetNode, err = lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
@ -210,7 +210,7 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID }) nodes := lo.Map(nodes, func(node nodeInfo, _ int) int64 { return node.nodeID })
retryOnReplica := Params.ProxyCfg.RetryTimesOnReplica.GetAsInt() retryOnReplica := Params.ProxyCfg.RetryTimesOnReplica.GetAsInt()
wg.Go(func() error { wg.Go(func() error {
err := lb.ExecuteWithRetry(ctx, ChannelWorkload{ return lb.ExecuteWithRetry(ctx, ChannelWorkload{
db: workload.db, db: workload.db,
collectionName: workload.collectionName, collectionName: workload.collectionName,
collectionID: workload.collectionID, collectionID: workload.collectionID,
@ -220,12 +220,10 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
exec: workload.exec, exec: workload.exec,
retryTimes: uint(len(nodes) * retryOnReplica), retryTimes: uint(len(nodes) * retryOnReplica),
}) })
return err
}) })
} }
err = wg.Wait() return wg.Wait()
return err
} }
func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) { func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {

View File

@ -213,7 +213,7 @@ func (s *LBPolicySuite) TestSelectNode() {
shardLeaders: s.nodes, shardLeaders: s.nodes,
nq: 1, nq: 1,
}, typeutil.NewUniqueSet(s.nodes...)) }, typeutil.NewUniqueSet(s.nodes...))
s.ErrorIs(err, merr.ErrServiceUnavailable) s.ErrorIs(err, merr.ErrChannelNotAvailable)
s.Equal(int64(-1), targetNode) s.Equal(int64(-1), targetNode)
// test get shard leaders failed, retry to select node failed // test get shard leaders failed, retry to select node failed

View File

@ -494,7 +494,7 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query
} }
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode query result error", zap.Any("errorCode", result.GetStatus().GetErrorCode()), zap.String("reason", result.GetStatus().GetReason())) log.Warn("QueryNode query result error", zap.Any("errorCode", result.GetStatus().GetErrorCode()), zap.String("reason", result.GetStatus().GetReason()))
return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason()) return errors.Wrapf(merr.Error(result.GetStatus()), "fail to Query on QueryNode %d", nodeID)
} }
log.Debug("get query result") log.Debug("get query result")

View File

@ -18,7 +18,6 @@ package proxy
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
"testing" "testing"
"time" "time"
@ -211,12 +210,10 @@ func TestQueryTask_all(t *testing.T) {
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{ qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
Status: &commonpb.Status{ Status: merr.Status(merr.ErrChannelNotAvailable),
ErrorCode: commonpb.ErrorCode_NotShardLeader,
},
}, nil) }, nil)
err = task.Execute(ctx) err = task.Execute(ctx)
assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) assert.ErrorIs(t, err, merr.ErrChannelNotAvailable)
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()

View File

@ -539,7 +539,7 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode search result error", log.Warn("QueryNode search result error",
zap.String("reason", result.GetStatus().GetReason())) zap.String("reason", result.GetStatus().GetReason()))
return fmt.Errorf("fail to Search, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason()) return errors.Wrapf(merr.Error(result.GetStatus()), "fail to search on QueryNode %d", nodeID)
} }
t.resultBuf.Insert(result) t.resultBuf.Insert(result)
t.lb.UpdateCostMetrics(nodeID, result.CostAggregation) t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)

View File

@ -19,7 +19,6 @@ import (
"context" "context"
"fmt" "fmt"
"strconv" "strconv"
"strings"
"testing" "testing"
"time" "time"
@ -1719,12 +1718,10 @@ func TestSearchTask_ErrExecute(t *testing.T) {
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{ qn.EXPECT().Search(mock.Anything, mock.Anything).Return(&internalpb.SearchResults{
Status: &commonpb.Status{ Status: merr.Status(merr.ErrChannelNotAvailable),
ErrorCode: commonpb.ErrorCode_NotShardLeader,
},
}, nil) }, nil)
err = task.Execute(ctx) err = task.Execute(ctx)
assert.True(t, strings.Contains(err.Error(), errInvalidShardLeaders.Error())) assert.ErrorIs(t, err, merr.ErrChannelNotAvailable)
qn.ExpectedCalls = nil qn.ExpectedCalls = nil
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()

View File

@ -302,7 +302,7 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64
zap.Int64("nodeID", nodeID), zap.Int64("nodeID", nodeID),
zap.String("reason", result.GetStatus().GetReason())) zap.String("reason", result.GetStatus().GetReason()))
globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName) globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return fmt.Errorf("fail to get statistic, QueryNode ID=%d, reason=%s", nodeID, result.GetStatus().GetReason()) return errors.Wrapf(merr.Error(result.GetStatus()), "fail to get statistic on QueryNode ID=%d", nodeID)
} }
g.resultBuf.Insert(result) g.resultBuf.Insert(result)

View File

@ -773,21 +773,21 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
}, nil }, nil
} }
failRet := &internalpb.SearchResults{ resp := &internalpb.SearchResults{
Status: merr.Success(), Status: merr.Success(),
} }
collection := node.manager.Collection.Get(req.GetReq().GetCollectionID()) collection := node.manager.Collection.Get(req.GetReq().GetCollectionID())
if collection == nil { if collection == nil {
failRet.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID())) resp.Status = merr.Status(merr.WrapErrCollectionNotFound(req.GetReq().GetCollectionID()))
return failRet, nil return resp, nil
} }
// Check if the metric type specified in search params matches the metric type in the index info. // Check if the metric type specified in search params matches the metric type in the index info.
if !req.GetFromShardLeader() && req.GetReq().GetMetricType() != "" { if !req.GetFromShardLeader() && req.GetReq().GetMetricType() != "" {
if req.GetReq().GetMetricType() != collection.GetMetricType() { if req.GetReq().GetMetricType() != collection.GetMetricType() {
failRet.Status = merr.Status(merr.WrapErrParameterInvalid(collection.GetMetricType(), req.GetReq().GetMetricType(), resp.Status = merr.Status(merr.WrapErrParameterInvalid(collection.GetMetricType(), req.GetReq().GetMetricType(),
fmt.Sprintf("collection:%d, metric type not match", collection.ID()))) fmt.Sprintf("collection:%d, metric type not match", collection.ID())))
return failRet, nil return resp, nil
} }
} }
@ -796,10 +796,9 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
req.Req.MetricType = collection.GetMetricType() req.Req.MetricType = collection.GetMetricType()
} }
var toReduceResults []*internalpb.SearchResults toReduceResults := make([]*internalpb.SearchResults, len(req.GetDmlChannels()))
var mu sync.Mutex
runningGp, runningCtx := errgroup.WithContext(ctx) runningGp, runningCtx := errgroup.WithContext(ctx)
for _, ch := range req.GetDmlChannels() { for i, ch := range req.GetDmlChannels() {
ch := ch ch := ch
req := &querypb.SearchRequest{ req := &querypb.SearchRequest{
Req: req.Req, Req: req.Req,
@ -810,31 +809,30 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
TotalChannelNum: req.TotalChannelNum, TotalChannelNum: req.TotalChannelNum,
} }
i := i
runningGp.Go(func() error { runningGp.Go(func() error {
ret, err := node.searchChannel(runningCtx, req, ch) ret, err := node.searchChannel(runningCtx, req, ch)
mu.Lock()
defer mu.Unlock()
if err != nil { if err != nil {
failRet.Status = merr.Status(err)
return err return err
} }
if ret.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { if err := merr.Error(ret.GetStatus()); err != nil {
return merr.Error(failRet.GetStatus()) return err
} }
toReduceResults = append(toReduceResults, ret) toReduceResults[i] = ret
return nil return nil
}) })
} }
if err := runningGp.Wait(); err != nil { if err := runningGp.Wait(); err != nil {
return failRet, nil resp.Status = merr.Status(err)
return resp, nil
} }
tr.RecordSpan() tr.RecordSpan()
result, err := segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType()) result, err := segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err != nil { if err != nil {
log.Warn("failed to reduce search results", zap.Error(err)) log.Warn("failed to reduce search results", zap.Error(err))
failRet.Status = merr.Status(err) resp.Status = merr.Status(err)
return failRet, nil return resp, nil
} }
reduceLatency := tr.RecordSpan() reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards). metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards).

View File

@ -38,31 +38,36 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
opt(c) opt(c)
} }
var el error var lastErr error
for i := uint(0); i < c.attempts; i++ { for i := uint(0); i < c.attempts; i++ {
if err := fn(); err != nil { if err := fn(); err != nil {
if i%4 == 0 { if i%4 == 0 {
log.Error("retry func failed", zap.Uint("retry time", i), zap.Error(err)) log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err))
} }
err = errors.Wrapf(err, "attempt #%d", i)
el = merr.Combine(el, err)
if !IsRecoverable(err) { if !IsRecoverable(err) {
return el if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil {
return lastErr
}
return err
} }
deadline, ok := ctx.Deadline() deadline, ok := ctx.Deadline()
if ok && time.Until(deadline) < c.sleep { if ok && time.Until(deadline) < c.sleep {
// to avoid sleep until ctx done // to avoid sleep until ctx done
return el if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil {
return lastErr
}
return err
} }
lastErr = err
select { select {
case <-time.After(c.sleep): case <-time.After(c.sleep):
case <-ctx.Done(): case <-ctx.Done():
return merr.Combine(el, ctx.Err()) return lastErr
} }
c.sleep *= 2 c.sleep *= 2
@ -73,7 +78,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
return nil return nil
} }
} }
return el return lastErr
} }
// errUnrecoverable is error instance for unrecoverable. // errUnrecoverable is error instance for unrecoverable.

View File

@ -127,8 +127,9 @@ func TestContextDeadline(t *testing.T) {
func TestContextCancel(t *testing.T) { func TestContextCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
mockErr := errors.New("mock error")
testFn := func() error { testFn := func() error {
return fmt.Errorf("some error") return mockErr
} }
go func() { go func() {
@ -138,7 +139,7 @@ func TestContextCancel(t *testing.T) {
err := Do(ctx, testFn) err := Do(ctx, testFn)
assert.Error(t, err) assert.Error(t, err)
assert.True(t, merr.IsCanceledOrTimeout(err)) assert.ErrorIs(t, err, mockErr)
t.Log(err) t.Log(err)
} }