diff --git a/internal/kv/etcd/embed_etcd_kv.go b/internal/kv/etcd/embed_etcd_kv.go
index ccb6682b31..754fb81e05 100644
--- a/internal/kv/etcd/embed_etcd_kv.go
+++ b/internal/kv/etcd/embed_etcd_kv.go
@@ -38,6 +38,13 @@ import (
 // implementation assertion
 var _ kv.MetaKv = (*EmbedEtcdKV)(nil)
 
+const (
+	// defaultRetryCount is the maximum number of attempts made by the retried start-up steps.
+	defaultRetryCount = 3
+	// defaultRetryInterval is the pause between two consecutive attempts.
+	defaultRetryInterval = 1 * time.Second
+)
+
 // EmbedEtcdKV use embedded Etcd instance as a KV storage
 type EmbedEtcdKV struct {
 	client *clientv3.Client
@@ -48,9 +55,29 @@ type EmbedEtcdKV struct {
 	requestTimeout time.Duration
 }
 
+// retry calls fn until it returns nil or it has been called attempts times,
+// sleeping for sleep between consecutive calls. It returns nil as soon as a
+// call succeeds, otherwise the error of the last call. fn is always called at
+// least once, even when attempts is zero or negative.
+func retry(attempts int, sleep time.Duration, fn func() error) error {
+	err := fn()
+	for i := 1; err != nil && i < attempts; i++ {
+		time.Sleep(sleep)
+		err = fn()
+	}
+	return err
+}
+
 // NewEmbededEtcdKV creates a new etcd kv.
 func NewEmbededEtcdKV(cfg *embed.Config, rootPath string, options ...Option) (*EmbedEtcdKV, error) {
-	e, err := embed.StartEtcd(cfg)
+	var e *embed.Etcd
+	// Starting the embedded server is retried; startErr stays local to the
+	// closure so the outer err is only written by the result of retry.
+	err := retry(defaultRetryCount, defaultRetryInterval, func() error {
+		var startErr error
+		e, startErr = embed.StartEtcd(cfg)
+		return startErr
+	})
 	if err != nil {
 		return nil, err
 	}
@@ -69,15 +96,24 @@ func NewEmbededEtcdKV(cfg *embed.Config, rootPath string, options ...Option) (*E
 
 		requestTimeout: opt.requestTimeout,
 	}
+	// wait until embed etcd is ready with retry mechanism
+	err = retry(defaultRetryCount, defaultRetryInterval, func() error {
+		select {
+		case <-e.Server.ReadyNotify():
+			log.Info("Embedded etcd is ready!")
+			return nil
+		case <-time.After(60 * time.Second):
+			return errors.New("Embedded etcd took too long to start")
+		}
+	})
 
-	// wait until embed etcd is ready
-	select {
-	case <-e.Server.ReadyNotify():
-		log.Info("Embedded etcd is ready!")
-	case <-time.After(60 * time.Second):
-		e.Server.Stop() // trigger a shutdown
-		return nil, errors.New("Embedded etcd took too long to start")
+	if err != nil {
+		// Stop only after the last attempt: a stopped server never becomes
+		// ready, so stopping between attempts would make every retry futile.
+		e.Server.Stop() // trigger a shutdown
+		return nil, err
 	}
+
 	return kv, nil
 }
 
diff --git a/internal/kv/etcd/etcd_kv_test.go b/internal/kv/etcd/etcd_kv_test.go
index d4ec49d062..0baa1e6c9d 100644
--- a/internal/kv/etcd/etcd_kv_test.go
+++ b/internal/kv/etcd/etcd_kv_test.go
@@ -906,3 +906,59 @@ func TestHasPrefix(t *testing.T) {
 	assert.NoError(t, err)
 	assert.False(t, has)
 }
+
+// TestRetrySuccess checks that retry stops after the first successful call.
+func TestRetrySuccess(t *testing.T) {
+	calls := 0
+	err := retry(defaultRetryCount, 10*time.Millisecond, func() error {
+		calls++
+		return nil
+	})
+	assert.NoError(t, err)
+	assert.Equal(t, 1, calls)
+}
+
+// TestRetryFailure checks that retry gives up after the configured number of
+// attempts and returns the error of the last call.
+func TestRetryFailure(t *testing.T) {
+	expectedErr := errors.New("always fail")
+	calls := 0
+	err := retry(defaultRetryCount, 10*time.Millisecond, func() error {
+		calls++
+		return expectedErr
+	})
+	assert.Equal(t, expectedErr, err)
+	assert.Equal(t, defaultRetryCount, calls)
+}
+
+// TestRetryEventuallySucceeds checks that retry keeps calling fn after
+// failures and returns nil once a call succeeds.
+func TestRetryEventuallySucceeds(t *testing.T) {
+	attempts := 0
+	err := retry(defaultRetryCount, 10*time.Millisecond, func() error {
+		attempts++
+		if attempts < defaultRetryCount {
+			return errors.New("temporary failure")
+		}
+		return nil
+	})
+	assert.NoError(t, err)
+	assert.Equal(t, defaultRetryCount, attempts)
+}
+
+// TestRetryInterval checks that retry sleeps between consecutive attempts.
+// Only the lower bound is asserted: an upper bound depends on scheduler
+// latency and makes the test flaky on loaded machines.
+func TestRetryInterval(t *testing.T) {
+	const interval = 20 * time.Millisecond
+	startTime := time.Now()
+	err := retry(defaultRetryCount, interval, func() error {
+		return errors.New("fail")
+	})
+	elapsed := time.Since(startTime)
+
+	// A failing run sleeps once between each pair of attempts.
+	expectedMin := interval * time.Duration(defaultRetryCount-1)
+	assert.Error(t, err)
+	assert.True(t, elapsed >= expectedMin, "expected at least %v, got %v", expectedMin, elapsed)
+}