diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index 0ac7137ce6..6fc8626ce0 100644 --- a/internal/util/grpcclient/client.go +++ b/internal/util/grpcclient/client.go @@ -464,29 +464,27 @@ func (c *ClientBase[T]) call(ctx context.Context, caller func(client T) (any, er ctx, cancel := context.WithCancel(ctx) defer cancel() - err := retry.Do(ctx, func() error { + err := retry.Handle(ctx, func() (bool, error) { if wrapper == nil { if ok := c.checkNodeSessionExist(ctx); !ok { // if session doesn't exist, no need to reset connection for datanode/indexnode/querynode - return retry.Unrecoverable(merr.ErrNodeNotFound) + return false, merr.ErrNodeNotFound } err := errors.Wrap(clientErr, "empty grpc client") log.Warn("grpc client is nil, maybe fail to get client in the retry state", zap.Error(err)) resetClientFunc() - return err + return true, err } + wrapper.Pin() var err error ret, err = caller(wrapper.client) wrapper.Unpin() + if err != nil { var needRetry, needReset bool needRetry, needReset, err = c.checkGrpcErr(ctx, err) - if !needRetry { - // stop retry - err = retry.Unrecoverable(err) - } if needReset { log.Warn("start to reset connection because of specific reasons", zap.Error(err)) resetClientFunc() @@ -498,7 +496,7 @@ func (c *ClientBase[T]) call(ctx context.Context, caller func(client T) (any, er resetClientFunc() } } - return err + return needRetry, err } // reset counter c.ctxCounter.Store(0) @@ -512,19 +510,19 @@ func (c *ClientBase[T]) call(ctx context.Context, caller func(client T) (any, er default: // it will directly return the result log.Warn("unknown return type", zap.Any("return", ret)) - return nil + return false, nil } if status == nil { log.Warn("status is nil, please fix it", zap.Stack("stack")) - return nil + return false, nil } err = merr.Error(status) if err != nil && merr.IsRetryableErr(err) { - return err + return true, err } - return nil + return false, nil }, retry.Attempts(uint(c.MaxAttempts)), // Because 
the previous InitialBackoff and MaxBackoff were float, and the unit was s. // For compatibility, this is multiplied by 1000. diff --git a/pkg/util/retry/retry.go b/pkg/util/retry/retry.go index 2597d62d97..05d92d1d4c 100644 --- a/pkg/util/retry/retry.go +++ b/pkg/util/retry/retry.go @@ -80,6 +80,63 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error { return el } +// Handle runs fn with a retry mechanism. +// fn returns (shouldRetry, err): a nil err ends Handle successfully, and a non-nil err is retried only when shouldRetry is true. +// Options configure the retry settings used here (c.attempts, c.sleep, c.maxSleepTime). +func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error { + if !funcutil.CheckCtxValid(ctx) { + return ctx.Err() + } + + log := log.Ctx(ctx) + c := newDefaultConfig() + + for _, opt := range opts { + opt(c) + } + + var lastErr error + for i := uint(0); i < c.attempts; i++ { + if shouldRetry, err := fn(); err != nil { + if i%4 == 0 { + log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err)) + } + + if !shouldRetry { + if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil { + return lastErr + } + return err + } + + deadline, ok := ctx.Deadline() + if ok && time.Until(deadline) < c.sleep { + // the ctx deadline arrives before the next sleep would end, so return now instead of sleeping until ctx is done + if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil { + return lastErr + } + return err + } + + lastErr = err + + select { + case <-time.After(c.sleep): + case <-ctx.Done(): + return lastErr + } + + c.sleep *= 2 + if c.sleep > c.maxSleepTime { + c.sleep = c.maxSleepTime + } + } else { + return nil + } + } + return lastErr +} + // errUnrecoverable is error instance for unrecoverable. 
var errUnrecoverable = errors.New("unrecoverable error") diff --git a/pkg/util/retry/retry_test.go b/pkg/util/retry/retry_test.go index abdef0af51..70c059c456 100644 --- a/pkg/util/retry/retry_test.go +++ b/pkg/util/retry/retry_test.go @@ -181,3 +181,50 @@ func TestRetryErrorParam(t *testing.T) { assert.Equal(t, 3, runTimes) } } + +func TestHandle(t *testing.T) { + // an already-canceled ctx makes Handle return context.Canceled + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := Handle(ctx, func() (bool, error) { + return false, nil + }, Attempts(5)) + assert.ErrorIs(t, err, context.Canceled) + + fakeErr := errors.New("mock retry error") + // fn asks for a retry on its first two calls and succeeds on the third + counter := 0 + err = Handle(context.Background(), func() (bool, error) { + counter++ + if counter < 3 { + return true, fakeErr + } + return false, nil + }, Attempts(10)) + assert.NoError(t, err) + + // ctx1 times out after 2s; Handle is expected to return fakeErr before fn reaches its fifth, successful call + counter = 0 + ctx1, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + err = Handle(ctx1, func() (bool, error) { + counter++ + if counter < 5 { + return true, fakeErr + } + return false, nil + }, Attempts(10)) + assert.ErrorIs(t, err, fakeErr) + + // fn returns an error with shouldRetry=false, so Handle returns fakeErr + err = Handle(context.Background(), func() (bool, error) { + return false, fakeErr + }, Attempts(10)) + assert.ErrorIs(t, err, fakeErr) + + // fn succeeds on the first call, so Handle returns nil + err = Handle(context.Background(), func() (bool, error) { + return false, nil + }, Attempts(10)) + assert.NoError(t, err) +}