fix: lost database in restful v2 (#46171)

issue: #45812

---------

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-12-09 10:59:13 +08:00 committed by GitHub
parent 459425ac84
commit b8086cb62b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 76 additions and 17 deletions

View File

@ -34,7 +34,6 @@ import (
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -50,7 +49,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -386,13 +384,10 @@ func wrapperProxyWithLimit(ctx context.Context, ginCtx *gin.Context, req any, ch
forwardHandler := func(reqCtx context.Context, req any) (any, error) {
interceptor := streaming.ForwardLegacyProxyUnaryServerInterceptor()
newCtx := reqCtx
if token, ok := ginCtx.Get(ContextToken); ok {
newCtx = metadata.NewIncomingContext(reqCtx, metadata.MD{
util.HeaderAuthorize: []string{token.(string)},
})
interceptor = streaming.ForwardLegacyProxyUnaryServerInterceptor(streaming.OptForwardAuth(token.(string)))
}
return interceptor(newCtx, req, &grpc.UnaryServerInfo{FullMethod: fullMethod}, func(ctx context.Context, req any) (interface{}, error) {
return interceptor(reqCtx, req, &grpc.UnaryServerInfo{FullMethod: fullMethod}, func(ctx context.Context, req any) (interface{}, error) {
return handler(ctx, req)
})
}

View File

@ -67,7 +67,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/proxypb"
"github.com/milvus-io/milvus/pkg/v2/tracer"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/interceptor"
"github.com/milvus-io/milvus/pkg/v2/util/logutil"
@ -131,7 +130,7 @@ func authenticate(c *gin.Context) {
if proxy.PasswordVerify(c, username, password) {
log.Ctx(context.TODO()).Debug("auth successful", zap.String("username", username))
c.Set(httpserver.ContextUsername, username)
c.Set(httpserver.ContextToken, crypto.Base64Encode(fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, password)))
c.Set(httpserver.ContextToken, fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, password))
return
}
}

View File

@ -40,6 +40,9 @@ import (
"github.com/milvus-io/milvus/internal/util/streamingutil/service/lazygrpc"
"github.com/milvus-io/milvus/internal/util/streamingutil/service/resolver"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -60,7 +63,22 @@ func newForwardService(streamingCoordClient client.Client) *forwardServiceImpl {
}
type ForwardService interface {
ForwardLegacyProxy(ctx context.Context, request any) (any, error)
ForwardLegacyProxy(ctx context.Context, request any, forwardAuth ...ForwardOption) (any, error)
}
// OptForwardAuth is the option to set the auth token for the forward service.
func OptForwardAuth(authToken string) ForwardOption {
return func(fs *forwardOption) {
fs.authToken = authToken
}
}
// ForwardOption is the option for the forward service.
type ForwardOption func(*forwardOption)
// forwardOption is the option for the forward service.
type forwardOption struct {
authToken string
}
// forwardServiceImpl is the implementation of FallbackService.
@ -75,12 +93,12 @@ type forwardServiceImpl struct {
}
// ForwardLegacyProxy forwards the request to the legacy proxy.
func (fs *forwardServiceImpl) ForwardLegacyProxy(ctx context.Context, request any) (any, error) {
func (fs *forwardServiceImpl) ForwardLegacyProxy(ctx context.Context, request any, opts ...ForwardOption) (any, error) {
if err := fs.checkIfForwardDisabledWithLock(ctx); err != nil {
return nil, err
}
return fs.forwardLegacyProxy(ctx, request)
return fs.forwardLegacyProxy(ctx, request, opts...)
}
// checkIfForwardDisabledWithLock checks if the forward is disabled with lock.
@ -92,12 +110,20 @@ func (fs *forwardServiceImpl) checkIfForwardDisabledWithLock(ctx context.Context
}
// forwardLegacyProxy forwards the request to the legacy proxy.
func (fs *forwardServiceImpl) forwardLegacyProxy(ctx context.Context, request any) (any, error) {
func (fs *forwardServiceImpl) forwardLegacyProxy(ctx context.Context, request any, opts ...ForwardOption) (any, error) {
s, err := fs.getLegacyProxyService(ctx)
if err != nil {
return nil, err
}
var optForwardOption forwardOption
for _, opt := range opts {
opt(&optForwardOption)
}
if optForwardOption.authToken != "" {
ctx = contextutil.SetToIncomingContext(ctx, util.HeaderAuthorize, crypto.Base64Encode(optForwardOption.authToken))
}
var result proto.Message
switch req := request.(type) {
case *milvuspb.InsertRequest:
@ -247,7 +273,7 @@ func (fs *forwardServiceImpl) markForwardDisabled() {
// the dml cannot be executed at new 2.6.x proxy until all 2.5.x proxies are down.
//
// so we need to forward the request to the 2.5.x proxy.
func ForwardLegacyProxyUnaryServerInterceptor() grpc.UnaryServerInterceptor {
func ForwardLegacyProxyUnaryServerInterceptor(opts ...ForwardOption) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if info.FullMethod != milvuspb.MilvusService_Insert_FullMethodName &&
info.FullMethod != milvuspb.MilvusService_Delete_FullMethodName &&
@ -259,7 +285,7 @@ func ForwardLegacyProxyUnaryServerInterceptor() grpc.UnaryServerInterceptor {
}
// try to forward the request to the legacy proxy.
resp, err := WAL().ForwardService().ForwardLegacyProxy(ctx, req)
resp, err := WAL().ForwardService().ForwardLegacyProxy(ctx, req, opts...)
if err == nil {
return resp, nil
}

View File

@ -244,7 +244,7 @@ func (n *noopWALAccesser) ForwardService() ForwardService {
type noopForwardService struct{}
func (n *noopForwardService) ForwardLegacyProxy(ctx context.Context, request any) (any, error) {
func (n *noopForwardService) ForwardLegacyProxy(ctx context.Context, request any, opts ...ForwardOption) (any, error) {
return nil, ErrForwardDisabled
}

View File

@ -51,7 +51,7 @@ func TenantID(ctx context.Context) string {
func AppendToIncomingContext(ctx context.Context, kv ...string) context.Context {
if len(kv)%2 == 1 {
panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv)))
panic(fmt.Sprintf("metadata: AppendToIncomingContext got an odd number of input pairs for metadata: %d", len(kv)))
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
@ -65,6 +65,23 @@ func AppendToIncomingContext(ctx context.Context, kv ...string) context.Context
return metadata.NewIncomingContext(ctx, md)
}
// SetToIncomingContext sets the metadata to the incoming context.
func SetToIncomingContext(ctx context.Context, kv ...string) context.Context {
if len(kv)%2 == 1 {
panic(fmt.Sprintf("metadata: SetToIncomingContext got an odd number of input pairs for metadata: %d", len(kv)))
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
md = metadata.New(make(map[string]string, len(kv)/2))
}
for i, s := range kv {
if i%2 == 0 {
md.Set(s, kv[i+1])
}
}
return metadata.NewIncomingContext(ctx, md)
}
func GetCurUserFromContext(ctx context.Context) (string, error) {
username, _, err := GetAuthInfoFromContext(ctx)
return username, err

View File

@ -48,6 +48,28 @@ func TestAppendToIncomingContext(t *testing.T) {
})
}
func TestSetToIncomingContext(t *testing.T) {
t.Run("invalid kvs", func(t *testing.T) {
assert.Panics(t, func() {
// nolint
SetToIncomingContext(context.Background(), "foo")
})
})
t.Run("valid kvs", func(t *testing.T) {
ctx := context.Background()
ctx = SetToIncomingContext(ctx, "foo", "bar1")
md, ok := metadata.FromIncomingContext(ctx)
assert.True(t, ok)
assert.Equal(t, "bar1", md.Get("foo")[0])
ctx = SetToIncomingContext(ctx, "foo", "bar2")
md, ok = metadata.FromIncomingContext(ctx)
assert.True(t, ok)
assert.Equal(t, "bar2", md.Get("foo")[0])
})
}
func TestGetCurUserFromContext(t *testing.T) {
_, err := GetCurUserFromContext(context.Background())
assert.Error(t, err)