diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index ea18f99f5e..c22c47b72b 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -40,6 +40,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/hook" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/proxy" "github.com/milvus-io/milvus/internal/types" @@ -381,7 +382,13 @@ func wrapperProxyWithLimit(ctx context.Context, ginCtx *gin.Context, req any, ch username = "" } - response, err := proxy.HookInterceptor(context.WithValue(ctx, hook.GinParamsKey, ginCtx.Keys), req, username.(string), fullMethod, handler) + forwardHandler := func(reqCtx context.Context, req any) (any, error) { + interceptor := streaming.ForwardDMLToLegacyProxyUnaryServerInterceptor() + return interceptor(reqCtx, req, &grpc.UnaryServerInfo{FullMethod: fullMethod}, func(ctx context.Context, req any) (interface{}, error) { + return handler(ctx, req) + }) + } + response, err := proxy.HookInterceptor(context.WithValue(ctx, hook.GinParamsKey, ginCtx.Keys), req, username.(string), fullMethod, forwardHandler) if err == nil { status, ok := requestutil.GetStatusFromResponse(response) if ok { diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index eacc0319ef..934354bbea 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/distributed/streaming" mhttp "github.com/milvus-io/milvus/internal/http" "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/mocks" @@ -74,6 +75,7 @@ func (req *DefaultReq) GetDbName() string { return req.DbName } func init() { paramtable.Init() + streaming.SetupNoopWALForTest() } func sendReqAndVerify(t *testing.T, testEngine *gin.Engine, testName, method string, testcase requestBodyTestCase) { diff --git a/internal/distributed/streaming/forward.go b/internal/distributed/streaming/forward.go index 17cdefd915..26007d1094 100644 --- a/internal/distributed/streaming/forward.go +++ b/internal/distributed/streaming/forward.go @@ -59,6 +59,10 @@ func newForwardService(streamingCoordClient client.Client) *forwardServiceImpl { return fs } +type ForwardService interface { + ForwardDMLToLegacyProxy(ctx context.Context, request any) (any, error) +} + // forwardServiceImpl is the implementation of FallbackService. type forwardServiceImpl struct { log.Binder @@ -246,7 +250,7 @@ func ForwardDMLToLegacyProxyUnaryServerInterceptor() grpc.UnaryServerInterceptor } // try to forward the request to the legacy proxy. - resp, err := WAL().(*walAccesserImpl).forwardService.ForwardDMLToLegacyProxy(ctx, req) + resp, err := WAL().ForwardService().ForwardDMLToLegacyProxy(ctx, req) if err == nil { return resp, nil } diff --git a/internal/distributed/streaming/streaming.go b/internal/distributed/streaming/streaming.go index 9a12ef582e..10f9cb9123 100644 --- a/internal/distributed/streaming/streaming.go +++ b/internal/distributed/streaming/streaming.go @@ -134,6 +134,9 @@ type Balancer interface { // WALAccesser is the interfaces to interact with the milvus write ahead log. type WALAccesser interface { + // ForwardService returns the forward service of the wal. + ForwardService() ForwardService + // Replicate returns the replicate service of the wal. Replicate() ReplicateService diff --git a/internal/distributed/streaming/test_streaming.go b/internal/distributed/streaming/test_streaming.go index b4295bfc09..71c8db85d0 100644 --- a/internal/distributed/streaming/test_streaming.go +++ b/internal/distributed/streaming/test_streaming.go @@ -238,6 +238,16 @@ func (n *noopWALAccesser) UpdateReplicateConfiguration(ctx context.Context, conf return nil } +func (n *noopWALAccesser) ForwardService() ForwardService { + return &noopForwardService{} +} + +type noopForwardService struct{} + +func (n *noopForwardService) ForwardDMLToLegacyProxy(ctx context.Context, request any) (any, error) { + return nil, ErrForwardDisabled +} + type noopScanner struct{} func (n *noopScanner) Done() <-chan struct{} { diff --git a/internal/distributed/streaming/wal.go b/internal/distributed/streaming/wal.go index 0cb9f7d9c3..4b2170871a 100644 --- a/internal/distributed/streaming/wal.go +++ b/internal/distributed/streaming/wal.go @@ -65,6 +65,10 @@ type walAccesserImpl struct { forwardService *forwardServiceImpl } +func (w *walAccesserImpl) ForwardService() ForwardService { + return w.forwardService +} + func (w *walAccesserImpl) Replicate() ReplicateService { return replicateService{w} } diff --git a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go index 7f61e40848..c385568259 100644 --- a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go +++ b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go @@ -288,6 +288,53 @@ func (_c *MockWALAccesser_ControlChannel_Call) RunAndReturn(run func() string) * return _c } +// ForwardService provides a mock function with no fields +func (_m *MockWALAccesser) ForwardService() streaming.ForwardService { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ForwardService") + } + + var r0 streaming.ForwardService + if rf, ok := ret.Get(0).(func() streaming.ForwardService); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(streaming.ForwardService) + } + } + + return r0 +} + +// MockWALAccesser_ForwardService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ForwardService' +type MockWALAccesser_ForwardService_Call struct { + *mock.Call +} + +// ForwardService is a helper method to define mock.On call +func (_e *MockWALAccesser_Expecter) ForwardService() *MockWALAccesser_ForwardService_Call { + return &MockWALAccesser_ForwardService_Call{Call: _e.mock.On("ForwardService")} +} + +func (_c *MockWALAccesser_ForwardService_Call) Run(run func()) *MockWALAccesser_ForwardService_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockWALAccesser_ForwardService_Call) Return(_a0 streaming.ForwardService) *MockWALAccesser_ForwardService_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWALAccesser_ForwardService_Call) RunAndReturn(run func() streaming.ForwardService) *MockWALAccesser_ForwardService_Call { + _c.Call.Return(run) + return _c +} + // Local provides a mock function with no fields func (_m *MockWALAccesser) Local() streaming.Local { ret := _m.Called()