From 024beddfe61746e9a498c8039802684a529c4bed Mon Sep 17 00:00:00 2001 From: huanghaoyuanhhy Date: Tue, 14 Mar 2023 18:09:54 +0800 Subject: [PATCH] Make GCS OAuth token thread-safe (#22714) Signed-off-by: huanghaoyuan --- go.mod | 2 +- go.sum | 2 ++ internal/storage/gcp/gcp.go | 15 ++++++++++----- internal/storage/gcp/gcp_test.go | 6 +++--- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index eb1fb30bdd..38cf11f91d 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( go.etcd.io/etcd/api/v3 v3.5.5 go.etcd.io/etcd/client/v3 v3.5.5 go.etcd.io/etcd/server/v3 v3.5.5 - go.uber.org/atomic v1.7.0 + go.uber.org/atomic v1.10.0 go.uber.org/automaxprocs v1.4.0 go.uber.org/zap v1.17.0 golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 diff --git a/go.sum b/go.sum index 7013731d6d..d030f10b89 100644 --- a/go.sum +++ b/go.sum @@ -884,6 +884,8 @@ go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0= go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q= go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= diff --git a/internal/storage/gcp/gcp.go b/internal/storage/gcp/gcp.go index 033ceda3ec..28cb964575 100644 --- a/internal/storage/gcp/gcp.go +++ b/internal/storage/gcp/gcp.go @@ -7,6 +7,7 @@ import ( "github.com/cockroachdb/errors" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" + "go.uber.org/atomic" "golang.org/x/oauth2" "golang.org/x/oauth2/google" ) @@ -15,7 +16,7 @@ import ( type WrapHTTPTransport struct { tokenSrc oauth2.TokenSource backend transport - currentToken *oauth2.Token + currentToken atomic.Pointer[oauth2.Token] } // transport abstracts http.Transport to simplify test @@ -41,14 +42,18 @@ func NewWrapHTTPTransport(secure bool) (*WrapHTTPTransport, error) { func (t *WrapHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { // here Valid() means the token won't be expired in 10 sec // so the http client timeout shouldn't be longer, or we need to change the default `expiryDelta` time - if !t.currentToken.Valid() { - var err error - t.currentToken, err = t.tokenSrc.Token() + currentToken := t.currentToken.Load() + if currentToken.Valid() { + req.Header.Set("Authorization", "Bearer "+currentToken.AccessToken) + } else { + newToken, err := t.tokenSrc.Token() if err != nil { return nil, errors.Wrap(err, "failed to acquire token") } + t.currentToken.Store(newToken) + req.Header.Set("Authorization", "Bearer "+newToken.AccessToken) } - req.Header.Set("Authorization", "Bearer "+t.currentToken.AccessToken) + return t.backend.RoundTrip(req) } diff --git a/internal/storage/gcp/gcp_test.go b/internal/storage/gcp/gcp_test.go index a7e7c64289..9695316a3d 100644 --- a/internal/storage/gcp/gcp_test.go +++ b/internal/storage/gcp/gcp_test.go @@ -75,7 +75,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) { }) t.Run("invalid token, refresh failed", func(t *testing.T) { - ts.currentToken = nil + ts.currentToken.Store(nil) ts.tokenSrc = &mockTokenSource{err: errors.New("mock error")} req, err := http.NewRequest("GET", "http://example.com", nil) assert.NoError(t, err) @@ -84,7 +84,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) { }) t.Run("invalid token, refresh ok", func(t *testing.T) { - ts.currentToken = nil + ts.currentToken.Store(nil) ts.tokenSrc = &mockTokenSource{err: nil} req, err := http.NewRequest("GET", "http://example.com", nil) assert.NoError(t, err) @@ -92,7 +92,7 @@ func TestGCPWrappedHTTPTransport_RoundTrip(t *testing.T) { assert.NoError(t, err) }) - ts.currentToken = &oauth2.Token{} + ts.currentToken.Store(&oauth2.Token{}) t.Run("valid token, call failed", func(t *testing.T) { ts.backend = &mockTransport{err: errors.New("mock error")} req, err := http.NewRequest("GET", "http://example.com", nil)