From b8086cb62b7690e6e4fe04073ef2a4968384e24d Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Tue, 9 Dec 2025 10:59:13 +0800 Subject: [PATCH] fix: lost database in restful v2 (#46171) issue: #45812 --------- Signed-off-by: chyezh --- .../proxy/httpserver/handler_v2.go | 9 +---- internal/distributed/proxy/service.go | 3 +- internal/distributed/streaming/forward.go | 38 ++++++++++++++++--- .../distributed/streaming/test_streaming.go | 2 +- pkg/util/contextutil/context_util.go | 19 +++++++++- pkg/util/contextutil/context_util_test.go | 22 +++++++++++ 6 files changed, 76 insertions(+), 17 deletions(-) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 4c931c62ae..d428d4bd60 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -34,7 +34,6 @@ import ( "go.opentelemetry.io/otel/trace" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -50,7 +49,6 @@ import ( "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" - "github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util/crypto" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -386,13 +384,10 @@ func wrapperProxyWithLimit(ctx context.Context, ginCtx *gin.Context, req any, ch forwardHandler := func(reqCtx context.Context, req any) (any, error) { interceptor := streaming.ForwardLegacyProxyUnaryServerInterceptor() - newCtx := reqCtx if token, ok := ginCtx.Get(ContextToken); ok { - newCtx = metadata.NewIncomingContext(reqCtx, metadata.MD{ - util.HeaderAuthorize: []string{token.(string)}, - }) + interceptor = streaming.ForwardLegacyProxyUnaryServerInterceptor(streaming.OptForwardAuth(token.(string))) } - return interceptor(newCtx, req, &grpc.UnaryServerInfo{FullMethod: fullMethod}, func(ctx context.Context, req any) (interface{}, error) { + return interceptor(reqCtx, req, &grpc.UnaryServerInfo{FullMethod: fullMethod}, func(ctx context.Context, req any) (interface{}, error) { return handler(ctx, req) }) } diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index c61c1142af..4c29602c61 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -67,7 +67,6 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/proxypb" "github.com/milvus-io/milvus/pkg/v2/tracer" "github.com/milvus-io/milvus/pkg/v2/util" - "github.com/milvus-io/milvus/pkg/v2/util/crypto" "github.com/milvus-io/milvus/pkg/v2/util/etcd" "github.com/milvus-io/milvus/pkg/v2/util/interceptor" "github.com/milvus-io/milvus/pkg/v2/util/logutil" @@ -131,7 +130,7 @@ func authenticate(c *gin.Context) { if proxy.PasswordVerify(c, username, password) { log.Ctx(context.TODO()).Debug("auth successful", zap.String("username", username)) c.Set(httpserver.ContextUsername, username) - c.Set(httpserver.ContextToken, crypto.Base64Encode(fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, password))) + c.Set(httpserver.ContextToken, fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, password)) return } } diff --git a/internal/distributed/streaming/forward.go b/internal/distributed/streaming/forward.go index a9f7b91563..fcb3fe3d9d 100644 --- a/internal/distributed/streaming/forward.go +++ b/internal/distributed/streaming/forward.go @@ -40,6 +40,9 @@ import ( "github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc" "github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver" "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/pkg/v2/util" + "github.com/milvus-io/milvus/pkg/v2/util/contextutil" + "github.com/milvus-io/milvus/pkg/v2/util/crypto" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -60,7 +63,22 @@ func newForwardService(streamingCoordClient client.Client) *forwardServiceImpl { } type ForwardService interface { - ForwardLegacyProxy(ctx context.Context, request any) (any, error) + ForwardLegacyProxy(ctx context.Context, request any, forwardAuth ...ForwardOption) (any, error) +} + +// OptForwardAuth is the option to set the auth token for the forward service. +func OptForwardAuth(authToken string) ForwardOption { + return func(fs *forwardOption) { + fs.authToken = authToken + } +} + +// ForwardOption is the option for the forward service. +type ForwardOption func(*forwardOption) + +// forwardOption is the option for the forward service. +type forwardOption struct { + authToken string } // forwardServiceImpl is the implementation of FallbackService. @@ -75,12 +93,12 @@ type forwardServiceImpl struct { } // ForwardLegacyProxy forwards the request to the legacy proxy. -func (fs *forwardServiceImpl) ForwardLegacyProxy(ctx context.Context, request any) (any, error) { +func (fs *forwardServiceImpl) ForwardLegacyProxy(ctx context.Context, request any, opts ...ForwardOption) (any, error) { if err := fs.checkIfForwardDisabledWithLock(ctx); err != nil { return nil, err } - return fs.forwardLegacyProxy(ctx, request) + return fs.forwardLegacyProxy(ctx, request, opts...) } // checkIfForwardDisabledWithLock checks if the forward is disabled with lock. @@ -92,12 +110,20 @@ func (fs *forwardServiceImpl) checkIfForwardDisabledWithLock(ctx context.Context } // forwardLegacyProxy forwards the request to the legacy proxy. -func (fs *forwardServiceImpl) forwardLegacyProxy(ctx context.Context, request any) (any, error) { +func (fs *forwardServiceImpl) forwardLegacyProxy(ctx context.Context, request any, opts ...ForwardOption) (any, error) { s, err := fs.getLegacyProxyService(ctx) if err != nil { return nil, err } + var optForwardOption forwardOption + for _, opt := range opts { + opt(&optForwardOption) + } + if optForwardOption.authToken != "" { + ctx = contextutil.SetToIncomingContext(ctx, util.HeaderAuthorize, crypto.Base64Encode(optForwardOption.authToken)) + } + var result proto.Message switch req := request.(type) { case *milvuspb.InsertRequest: @@ -247,7 +273,7 @@ func (fs *forwardServiceImpl) markForwardDisabled() { // the dml cannot be executed at new 2.6.x proxy until all 2.5.x proxies are down. // // so we need to forward the request to the 2.5.x proxy. -func ForwardLegacyProxyUnaryServerInterceptor() grpc.UnaryServerInterceptor { +func ForwardLegacyProxyUnaryServerInterceptor(opts ...ForwardOption) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if info.FullMethod != milvuspb.MilvusService_Insert_FullMethodName && info.FullMethod != milvuspb.MilvusService_Delete_FullMethodName && @@ -259,7 +285,7 @@ func ForwardLegacyProxyUnaryServerInterceptor() grpc.UnaryServerInterceptor { } // try to forward the request to the legacy proxy. - resp, err := WAL().ForwardService().ForwardLegacyProxy(ctx, req) + resp, err := WAL().ForwardService().ForwardLegacyProxy(ctx, req, opts...) if err == nil { return resp, nil } diff --git a/internal/distributed/streaming/test_streaming.go b/internal/distributed/streaming/test_streaming.go index 9ba76a0527..b7d68316ca 100644 --- a/internal/distributed/streaming/test_streaming.go +++ b/internal/distributed/streaming/test_streaming.go @@ -244,7 +244,7 @@ func (n *noopWALAccesser) ForwardService() ForwardService { type noopForwardService struct{} -func (n *noopForwardService) ForwardLegacyProxy(ctx context.Context, request any) (any, error) { +func (n *noopForwardService) ForwardLegacyProxy(ctx context.Context, request any, opts ...ForwardOption) (any, error) { return nil, ErrForwardDisabled } diff --git a/pkg/util/contextutil/context_util.go b/pkg/util/contextutil/context_util.go index 3cebc6f8f5..c224e570ab 100644 --- a/pkg/util/contextutil/context_util.go +++ b/pkg/util/contextutil/context_util.go @@ -51,7 +51,7 @@ func TenantID(ctx context.Context) string { func AppendToIncomingContext(ctx context.Context, kv ...string) context.Context { if len(kv)%2 == 1 { - panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv))) + panic(fmt.Sprintf("metadata: AppendToIncomingContext got an odd number of input pairs for metadata: %d", len(kv))) } md, ok := metadata.FromIncomingContext(ctx) if !ok { @@ -65,6 +65,23 @@ func AppendToIncomingContext(ctx context.Context, kv ...string) context.Context return metadata.NewIncomingContext(ctx, md) } +// SetToIncomingContext sets the metadata to the incoming context. +func SetToIncomingContext(ctx context.Context, kv ...string) context.Context { + if len(kv)%2 == 1 { + panic(fmt.Sprintf("metadata: SetToIncomingContext got an odd number of input pairs for metadata: %d", len(kv))) + } + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + md = metadata.New(make(map[string]string, len(kv)/2)) + } + for i, s := range kv { + if i%2 == 0 { + md.Set(s, kv[i+1]) + } + } + return metadata.NewIncomingContext(ctx, md) +} + func GetCurUserFromContext(ctx context.Context) (string, error) { username, _, err := GetAuthInfoFromContext(ctx) return username, err diff --git a/pkg/util/contextutil/context_util_test.go b/pkg/util/contextutil/context_util_test.go index ea4eb8cc47..40f6f91b9c 100644 --- a/pkg/util/contextutil/context_util_test.go +++ b/pkg/util/contextutil/context_util_test.go @@ -48,6 +48,28 @@ func TestAppendToIncomingContext(t *testing.T) { }) } +func TestSetToIncomingContext(t *testing.T) { + t.Run("invalid kvs", func(t *testing.T) { + assert.Panics(t, func() { + // nolint + SetToIncomingContext(context.Background(), "foo") + }) + }) + + t.Run("valid kvs", func(t *testing.T) { + ctx := context.Background() + ctx = SetToIncomingContext(ctx, "foo", "bar1") + md, ok := metadata.FromIncomingContext(ctx) + assert.True(t, ok) + assert.Equal(t, "bar1", md.Get("foo")[0]) + + ctx = SetToIncomingContext(ctx, "foo", "bar2") + md, ok = metadata.FromIncomingContext(ctx) + assert.True(t, ok) + assert.Equal(t, "bar2", md.Get("foo")[0]) + }) +} + func TestGetCurUserFromContext(t *testing.T) { _, err := GetCurUserFromContext(context.Background()) assert.Error(t, err)