diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index d7cc948f70..f6e9e03fb7 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -83,7 +83,7 @@ func ParseUsernamePassword(c *gin.Context) (string, string, bool) { username, password, ok := c.Request.BasicAuth() if !ok { token := GetAuthorization(c) - i := strings.IndexAny(token, util.CredentialSeperator) + i := strings.IndexAny(token, util.CredentialSeparator) if i != -1 { username = token[:i] password = token[i+1:] diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index c11d3cb89d..8dddaa4398 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "math" + "net/http/httptest" "strconv" "strings" "testing" @@ -2734,3 +2735,42 @@ func TestGenFunctionScore(t *testing.T) { assert.NoError(t, err) } } + +func TestParseUsernamePassword(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("token with credential separator", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Request.Header.Set("Authorization", "Bearer testuser:testpass") + + username, password, ok := ParseUsernamePassword(c) + assert.True(t, ok) + assert.Equal(t, "testuser", username) + assert.Equal(t, "testpass", password) + }) + + t.Run("token without credential separator", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Request.Header.Set("Authorization", "Bearer tokenonly") + + username, password, ok := ParseUsernamePassword(c) + assert.False(t, ok) + assert.Equal(t, "", username) + assert.Equal(t, "", password) + }) + + t.Run("empty authorization header", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/", nil) + + username, password, ok := ParseUsernamePassword(c) + assert.False(t, ok) + assert.Equal(t, "", username) + assert.Equal(t, "", password) + }) +} diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index b8ac7e0524..d7d2d4cf5c 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -130,12 +130,12 @@ func authenticate(c *gin.Context) { if proxy.PasswordVerify(c, username, password) { log.Ctx(context.TODO()).Debug("auth successful", zap.String("username", username)) c.Set(httpserver.ContextUsername, username) - c.Set(httpserver.ContextToken, fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, password)) + c.Set(httpserver.ContextToken, fmt.Sprintf("%s%s%s", username, util.CredentialSeparator, password)) return } } rawToken := httpserver.GetAuthorization(c) - if rawToken != "" && !strings.Contains(rawToken, util.CredentialSeperator) { + if rawToken != "" && !strings.Contains(rawToken, util.CredentialSeparator) { user, err := proxy.VerifyAPIKey(rawToken) if err == nil { c.Set(httpserver.ContextUsername, user) diff --git a/internal/proxy/accesslog/info/util.go b/internal/proxy/accesslog/info/util.go index a985929ca0..27a5cea50e 100644 --- a/internal/proxy/accesslog/info/util.go +++ b/internal/proxy/accesslog/info/util.go @@ -49,7 +49,7 @@ func getCurUserFromContext(ctx context.Context) (string, error) { if err != nil { return "", fmt.Errorf("fail to decode the token, token: %s", token) } - secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2) + secrets := strings.SplitN(rawToken, util.CredentialSeparator, 2) if len(secrets) < 2 { return "", fmt.Errorf("fail to get user info from the raw token, raw token: %s", rawToken) } diff --git a/internal/proxy/accesslog/info/util_test.go b/internal/proxy/accesslog/info/util_test.go index 47f633c697..22e329bba8 100644 --- a/internal/proxy/accesslog/info/util_test.go +++ b/internal/proxy/accesslog/info/util_test.go @@ -17,11 +17,16 @@ package info import ( + "context" + "strings" "testing" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/util" + "github.com/milvus-io/milvus/pkg/v2/util/crypto" ) func TestGetSdkTypeByUserAgent(t *testing.T) { @@ -149,3 +154,56 @@ func TestGetLengthFromTemplateValue(t *testing.T) { assert.Equal(t, 1, getLengthFromTemplateValue(tv)) }) } + +func TestGetCurUserFromContext(t *testing.T) { + t.Run("valid context with user info", func(t *testing.T) { + ctx := context.Background() + token := crypto.Base64Encode("testuser:testpassword") + md := metadata.New(map[string]string{ + strings.ToLower(util.HeaderAuthorize): token, + }) + ctx = metadata.NewIncomingContext(ctx, md) + + username, err := getCurUserFromContext(ctx) + assert.NoError(t, err) + assert.Equal(t, "testuser", username) + }) + + t.Run("no metadata in context", func(t *testing.T) { + ctx := context.Background() + _, err := getCurUserFromContext(ctx) + assert.Error(t, err) + }) + + t.Run("no authorization in metadata", func(t *testing.T) { + ctx := context.Background() + md := metadata.New(map[string]string{}) + ctx = metadata.NewIncomingContext(ctx, md) + + _, err := getCurUserFromContext(ctx) + assert.Error(t, err) + }) + + t.Run("invalid token format", func(t *testing.T) { + ctx := context.Background() + md := metadata.New(map[string]string{ + strings.ToLower(util.HeaderAuthorize): "invalid_base64!@#", + }) + ctx = metadata.NewIncomingContext(ctx, md) + + _, err := getCurUserFromContext(ctx) + assert.Error(t, err) + }) + + t.Run("token without separator", func(t *testing.T) { + ctx := context.Background() + token := crypto.Base64Encode("tokenwithoutseparator") + md := metadata.New(map[string]string{ + strings.ToLower(util.HeaderAuthorize): token, + }) + ctx = metadata.NewIncomingContext(ctx, md) + + _, err := getCurUserFromContext(ctx) + assert.Error(t, err) + }) +} diff --git a/internal/proxy/authentication_interceptor.go b/internal/proxy/authentication_interceptor.go index 7badadddd3..dd1579a6ff 100644 --- a/internal/proxy/authentication_interceptor.go +++ b/internal/proxy/authentication_interceptor.go @@ -23,7 +23,7 @@ import ( ) func parseMD(rawToken string) (username, password string) { - secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2) + secrets := strings.SplitN(rawToken, util.CredentialSeparator, 2) if len(secrets) < 2 { log.Warn("invalid token format, length of secrets less than 2") return @@ -81,14 +81,14 @@ func AuthenticationInterceptor(ctx context.Context) (context.Context, error) { return nil, status.Error(codes.Unauthenticated, "invalid token format") } - if !strings.Contains(rawToken, util.CredentialSeperator) { + if !strings.Contains(rawToken, util.CredentialSeparator) { user, err := VerifyAPIKey(rawToken) if err != nil { log.Warn("fail to verify apikey", zap.Error(err)) return nil, status.Error(codes.Unauthenticated, "auth check failure, please check api key is correct") } metrics.UserRPCCounter.WithLabelValues(user).Inc() - userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeperator, util.PasswordHolder) + userToken := fmt.Sprintf("%s%s%s", user, util.CredentialSeparator, util.PasswordHolder) md[strings.ToLower(util.HeaderAuthorize)] = []string{crypto.Base64Encode(userToken)} md[util.HeaderToken] = []string{rawToken} ctx = metadata.NewIncomingContext(ctx, md) diff --git a/internal/proxy/trace_log_interceptor_test.go b/internal/proxy/trace_log_interceptor_test.go index 41f9cf5c25..d7c0fc5212 100644 --- a/internal/proxy/trace_log_interceptor_test.go +++ b/internal/proxy/trace_log_interceptor_test.go @@ -48,7 +48,7 @@ func TestTraceLogInterceptor(t *testing.T) { _, _ = TraceLogInterceptor(context.Background(), &milvuspb.ShowCollectionsRequest{}, &grpc.UnaryServerInfo{}, handler) // simple mode - ctx := GetContext(context.Background(), fmt.Sprintf("%s%s%s", "foo", util.CredentialSeperator, "FOO123456")) + ctx := GetContext(context.Background(), fmt.Sprintf("%s%s%s", "foo", util.CredentialSeparator, "FOO123456")) _ = paramtable.Get().Save(paramtable.Get().CommonCfg.TraceLogMode.Key, "1") { _, _ = TraceLogInterceptor(ctx, &milvuspb.CreateCollectionRequest{ diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 55d7fc84d4..a410024d31 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1409,7 +1409,7 @@ func NewContextWithMetadata(ctx context.Context, username string, dbName string) ctx = contextutil.AppendToIncomingContext(ctx, dbKey, dbName) } if username != "" { - originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username) + originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeparator, username) authKey := strings.ToLower(util.HeaderAuthorize) authValue := crypto.Base64Encode(originValue) ctx = contextutil.AppendToIncomingContext(ctx, authKey, authValue) @@ -1420,7 +1420,7 @@ func NewContextWithMetadata(ctx context.Context, username string, dbName string) func AppendUserInfoForRPC(ctx context.Context) context.Context { curUser, _ := GetCurUserFromContext(ctx) if curUser != "" { - originValue := fmt.Sprintf("%s%s%s", curUser, util.CredentialSeperator, curUser) + originValue := fmt.Sprintf("%s%s%s", curUser, util.CredentialSeparator, curUser) authKey := strings.ToLower(util.HeaderAuthorize) authValue := crypto.Base64Encode(originValue) ctx = metadata.AppendToOutgoingContext(ctx, authKey, authValue) diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 1282d3538e..7ec71b6abf 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -802,7 +802,7 @@ func TestGetCurUserFromContext(t *testing.T) { root := "root" password := "123456" - username, err := GetCurUserFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password))) + username, err := GetCurUserFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeparator, password))) assert.NoError(t, err) assert.Equal(t, "root", username) } @@ -2373,6 +2373,63 @@ func TestAppendUserInfoForRPC(t *testing.T) { assert.Equal(t, expectAuth, authorization[0]) } +func TestNewContextWithMetadata(t *testing.T) { + t.Run("with username and dbName", func(t *testing.T) { + ctx := context.Background() + ctx = NewContextWithMetadata(ctx, "testuser", "testdb") + + md, ok := metadata.FromIncomingContext(ctx) + assert.True(t, ok) + + // Check dbName + dbNameKey := strings.ToLower(util.HeaderDBName) + dbNameVal, ok := md[dbNameKey] + assert.True(t, ok) + assert.Equal(t, "testdb", dbNameVal[0]) + + // Check authorization + authKey := strings.ToLower(util.HeaderAuthorize) + authVal, ok := md[authKey] + assert.True(t, ok) + expectedAuth := crypto.Base64Encode("testuser:testuser") + assert.Equal(t, expectedAuth, authVal[0]) + }) + + t.Run("with empty username", func(t *testing.T) { + ctx := context.Background() + ctx = NewContextWithMetadata(ctx, "", "testdb") + + md, ok := metadata.FromIncomingContext(ctx) + assert.True(t, ok) + + // Check dbName is set + dbNameKey := strings.ToLower(util.HeaderDBName) + dbNameVal, ok := md[dbNameKey] + assert.True(t, ok) + assert.Equal(t, "testdb", dbNameVal[0]) + + // Check authorization is not set + authKey := strings.ToLower(util.HeaderAuthorize) + _, ok = md[authKey] + assert.False(t, ok) + }) + + t.Run("with empty dbName", func(t *testing.T) { + ctx := context.Background() + ctx = NewContextWithMetadata(ctx, "testuser", "") + + md, ok := metadata.FromIncomingContext(ctx) + assert.True(t, ok) + + // Check authorization is set + authKey := strings.ToLower(util.HeaderAuthorize) + authVal, ok := md[authKey] + assert.True(t, ok) + expectedAuth := crypto.Base64Encode("testuser:testuser") + assert.Equal(t, expectedAuth, authVal[0]) + }) +} + func TestGetCostValue(t *testing.T) { t.Run("empty status", func(t *testing.T) { { diff --git a/pkg/util/constant.go b/pkg/util/constant.go index 7b71763f7e..85180efc1b 100644 --- a/pkg/util/constant.go +++ b/pkg/util/constant.go @@ -46,7 +46,7 @@ const ( HeaderAuthorize = "authorization" HeaderToken = "token" - CredentialSeperator = ":" + CredentialSeparator = ":" UserRoot = "root" PasswordHolder = "___" DefaultTenant = "" diff --git a/pkg/util/contextutil/context_util.go b/pkg/util/contextutil/context_util.go index c224e570ab..5a77f2b792 100644 --- a/pkg/util/contextutil/context_util.go +++ b/pkg/util/contextutil/context_util.go @@ -101,7 +101,7 @@ func GetAuthInfoFromContext(ctx context.Context) (string, string, error) { if err != nil { return "", "", fmt.Errorf("fail to decode the token, token: %s", token) } - secrets := strings.SplitN(rawToken, util.CredentialSeperator, 2) + secrets := strings.SplitN(rawToken, util.CredentialSeparator, 2) if len(secrets) < 2 { return "", "", fmt.Errorf("fail to get user info from the raw token, raw token: %s", rawToken) } diff --git a/pkg/util/contextutil/context_util_test.go b/pkg/util/contextutil/context_util_test.go index 40f6f91b9c..7f7963c6cd 100644 --- a/pkg/util/contextutil/context_util_test.go +++ b/pkg/util/contextutil/context_util_test.go @@ -82,12 +82,12 @@ func TestGetCurUserFromContext(t *testing.T) { root := "root" password := "123456" - username, err := GetCurUserFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password))) + username, err := GetCurUserFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeparator, password))) assert.NoError(t, err) assert.Equal(t, root, username) { - u, p, e := GetAuthInfoFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password))) + u, p, e := GetAuthInfoFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeparator, password))) assert.NoError(t, e) assert.Equal(t, "root", u) assert.Equal(t, password, p)