diff --git a/client/index/scann.go b/client/index/scann.go index fe69b6cb0d..c897593b13 100644 --- a/client/index/scann.go +++ b/client/index/scann.go @@ -45,7 +45,7 @@ func NewSCANNIndex(metricType MetricType, nlist int, withRawData bool) Index { metricType: metricType, indexType: SCANN, }, - nlist: nlist, + nlist: nlist, withRawData: withRawData, } } diff --git a/client/index/sparse.go b/client/index/sparse.go index 0dec11200c..e835c68bfb 100644 --- a/client/index/sparse.go +++ b/client/index/sparse.go @@ -5,7 +5,7 @@ import ( ) const ( - dropRatio = `drop_ratio_build` + dropRatio = `drop_ratio_build` ) var _ Index = sparseInvertedIndex{} @@ -13,19 +13,19 @@ var _ Index = sparseInvertedIndex{} // IndexSparseInverted index type for SPARSE_INVERTED_INDEX type sparseInvertedIndex struct { baseIndex - dropRatio float64 + dropRatio float64 } func (idx sparseInvertedIndex) Params() map[string]string { return map[string]string{ - MetricTypeKey: string(idx.metricType), - IndexTypeKey: string(SparseInverted), - dropRatio: fmt.Sprintf("%v", idx.dropRatio), + MetricTypeKey: string(idx.metricType), + IndexTypeKey: string(SparseInverted), + dropRatio: fmt.Sprintf("%v", idx.dropRatio), } } func NewSparseInvertedIndex(metricType MetricType, dropRatio float64) Index { - return sparseInvertedIndex { + return sparseInvertedIndex{ baseIndex: baseIndex{ metricType: metricType, indexType: SparseInverted, @@ -36,22 +36,23 @@ func NewSparseInvertedIndex(metricType MetricType, dropRatio float64) Index { } var _ Index = sparseWANDIndex{} + type sparseWANDIndex struct { baseIndex - dropRatio float64 + dropRatio float64 } func (idx sparseWANDIndex) Params() map[string]string { return map[string]string{ - MetricTypeKey: string(idx.metricType), - IndexTypeKey: string(SparseWAND), - dropRatio: fmt.Sprintf("%v", idx.dropRatio), + MetricTypeKey: string(idx.metricType), + IndexTypeKey: string(SparseWAND), + dropRatio: fmt.Sprintf("%v", idx.dropRatio), } } // IndexSparseWAND index type for SPARSE_WAND, weak-and func NewSparseWANDIndex(metricType MetricType, dropRatio float64) Index { - return sparseWANDIndex { + return sparseWANDIndex{ baseIndex: baseIndex{ metricType: metricType, indexType: SparseWAND, @@ -60,4 +61,3 @@ func NewSparseWANDIndex(metricType MetricType, dropRatio float64) Index { dropRatio: dropRatio, } } - diff --git a/internal/mocks/proto/mock_streamingpb/mock_StreamingNodeHandlerService_ConsumeServer.go b/internal/mocks/proto/mock_streamingpb/mock_StreamingNodeHandlerService_ConsumeServer.go new file mode 100644 index 0000000000..151bb30156 --- /dev/null +++ b/internal/mocks/proto/mock_streamingpb/mock_StreamingNodeHandlerService_ConsumeServer.go @@ -0,0 +1,378 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_streamingpb + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + metadata "google.golang.org/grpc/metadata" + + streamingpb "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +// MockStreamingNodeHandlerService_ConsumeServer is an autogenerated mock type for the StreamingNodeHandlerService_ConsumeServer type +type MockStreamingNodeHandlerService_ConsumeServer struct { + mock.Mock +} + +type MockStreamingNodeHandlerService_ConsumeServer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockStreamingNodeHandlerService_ConsumeServer) EXPECT() *MockStreamingNodeHandlerService_ConsumeServer_Expecter { + return &MockStreamingNodeHandlerService_ConsumeServer_Expecter{mock: &_m.Mock} +} + +// Context provides a mock function with given fields: +func (_m *MockStreamingNodeHandlerService_ConsumeServer) Context() context.Context { + ret := _m.Called() + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// MockStreamingNodeHandlerService_ConsumeServer_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context' +type MockStreamingNodeHandlerService_ConsumeServer_Context_Call struct { + *mock.Call +} + +// Context is a helper method to define mock.On call +func (_e *MockStreamingNodeHandlerService_ConsumeServer_Expecter) Context() *MockStreamingNodeHandlerService_ConsumeServer_Context_Call { + return &MockStreamingNodeHandlerService_ConsumeServer_Context_Call{Call: _e.mock.On("Context")} +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_Context_Call) Run(run func()) *MockStreamingNodeHandlerService_ConsumeServer_Context_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_Context_Call) Return(_a0 context.Context) *MockStreamingNodeHandlerService_ConsumeServer_Context_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_Context_Call) RunAndReturn(run func() context.Context) *MockStreamingNodeHandlerService_ConsumeServer_Context_Call { + _c.Call.Return(run) + return _c +} + +// Recv provides a mock function with given fields: +func (_m *MockStreamingNodeHandlerService_ConsumeServer) Recv() (*streamingpb.ConsumeRequest, error) { + ret := _m.Called() + + var r0 *streamingpb.ConsumeRequest + var r1 error + if rf, ok := ret.Get(0).(func() (*streamingpb.ConsumeRequest, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *streamingpb.ConsumeRequest); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*streamingpb.ConsumeRequest) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockStreamingNodeHandlerService_ConsumeServer_Recv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Recv' +type MockStreamingNodeHandlerService_ConsumeServer_Recv_Call struct { + *mock.Call +} + +// Recv is a helper method to define mock.On call +func (_e *MockStreamingNodeHandlerService_ConsumeServer_Expecter) Recv() *MockStreamingNodeHandlerService_ConsumeServer_Recv_Call { + return &MockStreamingNodeHandlerService_ConsumeServer_Recv_Call{Call: _e.mock.On("Recv")} +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_Recv_Call) Run(run func()) *MockStreamingNodeHandlerService_ConsumeServer_Recv_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_Recv_Call) Return(_a0 *streamingpb.ConsumeRequest, _a1 error) *MockStreamingNodeHandlerService_ConsumeServer_Recv_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_Recv_Call) RunAndReturn(run func() (*streamingpb.ConsumeRequest, error)) *MockStreamingNodeHandlerService_ConsumeServer_Recv_Call { + _c.Call.Return(run) + return _c +} + +// RecvMsg provides a mock function with given fields: m +func (_m *MockStreamingNodeHandlerService_ConsumeServer) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg' +type MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call struct { + *mock.Call +} + +// RecvMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockStreamingNodeHandlerService_ConsumeServer_Expecter) RecvMsg(m interface{}) *MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call { + return &MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)} +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call) Run(run func(m interface{})) *MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call) RunAndReturn(run func(interface{}) error) *MockStreamingNodeHandlerService_ConsumeServer_RecvMsg_Call { + _c.Call.Return(run) + return _c +} + +// Send provides a mock function with given fields: _a0 +func (_m *MockStreamingNodeHandlerService_ConsumeServer) Send(_a0 *streamingpb.ConsumeResponse) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(*streamingpb.ConsumeResponse) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ConsumeServer_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send' +type MockStreamingNodeHandlerService_ConsumeServer_Send_Call struct { + *mock.Call +} + +// Send is a helper method to define mock.On call +// - _a0 *streamingpb.ConsumeResponse +func (_e *MockStreamingNodeHandlerService_ConsumeServer_Expecter) Send(_a0 interface{}) *MockStreamingNodeHandlerService_ConsumeServer_Send_Call { + return &MockStreamingNodeHandlerService_ConsumeServer_Send_Call{Call: _e.mock.On("Send", _a0)} +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_Send_Call) Run(run func(_a0 *streamingpb.ConsumeResponse)) *MockStreamingNodeHandlerService_ConsumeServer_Send_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*streamingpb.ConsumeResponse)) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_Send_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ConsumeServer_Send_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_Send_Call) RunAndReturn(run func(*streamingpb.ConsumeResponse) error) *MockStreamingNodeHandlerService_ConsumeServer_Send_Call { + _c.Call.Return(run) + return _c +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *MockStreamingNodeHandlerService_ConsumeServer) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendHeader' +type MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call struct { + *mock.Call +} + +// SendHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockStreamingNodeHandlerService_ConsumeServer_Expecter) SendHeader(_a0 interface{}) *MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call { + return &MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call{Call: _e.mock.On("SendHeader", _a0)} +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call) Run(run func(_a0 metadata.MD)) *MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockStreamingNodeHandlerService_ConsumeServer_SendHeader_Call { + _c.Call.Return(run) + return _c +} + +// SendMsg provides a mock function with given fields: m +func (_m *MockStreamingNodeHandlerService_ConsumeServer) SendMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg' +type MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call struct { + *mock.Call +} + +// SendMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockStreamingNodeHandlerService_ConsumeServer_Expecter) SendMsg(m interface{}) *MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call { + return &MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call{Call: _e.mock.On("SendMsg", m)} +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call) Run(run func(m interface{})) *MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call) RunAndReturn(run func(interface{}) error) *MockStreamingNodeHandlerService_ConsumeServer_SendMsg_Call { + _c.Call.Return(run) + return _c +} + +// SetHeader provides a mock function with given fields: _a0 +func (_m *MockStreamingNodeHandlerService_ConsumeServer) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHeader' +type MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call struct { + *mock.Call +} + +// SetHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockStreamingNodeHandlerService_ConsumeServer_Expecter) SetHeader(_a0 interface{}) *MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call { + return &MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call{Call: _e.mock.On("SetHeader", _a0)} +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call) Run(run func(_a0 metadata.MD)) *MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockStreamingNodeHandlerService_ConsumeServer_SetHeader_Call { + _c.Call.Return(run) + return _c +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *MockStreamingNodeHandlerService_ConsumeServer) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} + +// MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTrailer' +type MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call struct { + *mock.Call +} + +// SetTrailer is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockStreamingNodeHandlerService_ConsumeServer_Expecter) SetTrailer(_a0 interface{}) *MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call { + return &MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call{Call: _e.mock.On("SetTrailer", _a0)} +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call) Run(run func(_a0 metadata.MD)) *MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call) Return() *MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call { + _c.Call.Return() + return _c +} + +func (_c *MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call) RunAndReturn(run func(metadata.MD)) *MockStreamingNodeHandlerService_ConsumeServer_SetTrailer_Call { + _c.Call.Return(run) + return _c +} + +// NewMockStreamingNodeHandlerService_ConsumeServer creates a new instance of MockStreamingNodeHandlerService_ConsumeServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockStreamingNodeHandlerService_ConsumeServer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockStreamingNodeHandlerService_ConsumeServer { + mock := &MockStreamingNodeHandlerService_ConsumeServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/proto/mock_streamingpb/mock_StreamingNodeHandlerService_ProduceServer.go b/internal/mocks/proto/mock_streamingpb/mock_StreamingNodeHandlerService_ProduceServer.go new file mode 100644 index 0000000000..d4397f07fb --- /dev/null +++ b/internal/mocks/proto/mock_streamingpb/mock_StreamingNodeHandlerService_ProduceServer.go @@ -0,0 +1,378 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_streamingpb + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + metadata "google.golang.org/grpc/metadata" + + streamingpb "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +// MockStreamingNodeHandlerService_ProduceServer is an autogenerated mock type for the StreamingNodeHandlerService_ProduceServer type +type MockStreamingNodeHandlerService_ProduceServer struct { + mock.Mock +} + +type MockStreamingNodeHandlerService_ProduceServer_Expecter struct { + mock *mock.Mock +} + +func (_m *MockStreamingNodeHandlerService_ProduceServer) EXPECT() *MockStreamingNodeHandlerService_ProduceServer_Expecter { + return &MockStreamingNodeHandlerService_ProduceServer_Expecter{mock: &_m.Mock} +} + +// Context provides a mock function with given fields: +func (_m *MockStreamingNodeHandlerService_ProduceServer) Context() context.Context { + ret := _m.Called() + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// MockStreamingNodeHandlerService_ProduceServer_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context' +type MockStreamingNodeHandlerService_ProduceServer_Context_Call struct { + *mock.Call +} + +// Context is a helper method to define mock.On call +func (_e *MockStreamingNodeHandlerService_ProduceServer_Expecter) Context() *MockStreamingNodeHandlerService_ProduceServer_Context_Call { + return &MockStreamingNodeHandlerService_ProduceServer_Context_Call{Call: _e.mock.On("Context")} +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_Context_Call) Run(run func()) *MockStreamingNodeHandlerService_ProduceServer_Context_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_Context_Call) Return(_a0 context.Context) *MockStreamingNodeHandlerService_ProduceServer_Context_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_Context_Call) RunAndReturn(run func() context.Context) *MockStreamingNodeHandlerService_ProduceServer_Context_Call { + _c.Call.Return(run) + return _c +} + +// Recv provides a mock function with given fields: +func (_m *MockStreamingNodeHandlerService_ProduceServer) Recv() (*streamingpb.ProduceRequest, error) { + ret := _m.Called() + + var r0 *streamingpb.ProduceRequest + var r1 error + if rf, ok := ret.Get(0).(func() (*streamingpb.ProduceRequest, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *streamingpb.ProduceRequest); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*streamingpb.ProduceRequest) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockStreamingNodeHandlerService_ProduceServer_Recv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Recv' +type MockStreamingNodeHandlerService_ProduceServer_Recv_Call struct { + *mock.Call +} + +// Recv is a helper method to define mock.On call +func (_e *MockStreamingNodeHandlerService_ProduceServer_Expecter) Recv() *MockStreamingNodeHandlerService_ProduceServer_Recv_Call { + return &MockStreamingNodeHandlerService_ProduceServer_Recv_Call{Call: _e.mock.On("Recv")} +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_Recv_Call) Run(run func()) *MockStreamingNodeHandlerService_ProduceServer_Recv_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_Recv_Call) Return(_a0 *streamingpb.ProduceRequest, _a1 error) *MockStreamingNodeHandlerService_ProduceServer_Recv_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_Recv_Call) RunAndReturn(run func() (*streamingpb.ProduceRequest, error)) *MockStreamingNodeHandlerService_ProduceServer_Recv_Call { + _c.Call.Return(run) + return _c +} + +// RecvMsg provides a mock function with given fields: m +func (_m *MockStreamingNodeHandlerService_ProduceServer) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg' +type MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call struct { + *mock.Call +} + +// RecvMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockStreamingNodeHandlerService_ProduceServer_Expecter) RecvMsg(m interface{}) *MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call { + return &MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)} +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call) Run(run func(m interface{})) *MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call) RunAndReturn(run func(interface{}) error) *MockStreamingNodeHandlerService_ProduceServer_RecvMsg_Call { + _c.Call.Return(run) + return _c +} + +// Send provides a mock function with given fields: _a0 +func (_m *MockStreamingNodeHandlerService_ProduceServer) Send(_a0 *streamingpb.ProduceResponse) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(*streamingpb.ProduceResponse) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ProduceServer_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send' +type MockStreamingNodeHandlerService_ProduceServer_Send_Call struct { + *mock.Call +} + +// Send is a helper method to define mock.On call +// - _a0 *streamingpb.ProduceResponse +func (_e *MockStreamingNodeHandlerService_ProduceServer_Expecter) Send(_a0 interface{}) *MockStreamingNodeHandlerService_ProduceServer_Send_Call { + return &MockStreamingNodeHandlerService_ProduceServer_Send_Call{Call: _e.mock.On("Send", _a0)} +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_Send_Call) Run(run func(_a0 *streamingpb.ProduceResponse)) *MockStreamingNodeHandlerService_ProduceServer_Send_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*streamingpb.ProduceResponse)) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_Send_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ProduceServer_Send_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_Send_Call) RunAndReturn(run func(*streamingpb.ProduceResponse) error) *MockStreamingNodeHandlerService_ProduceServer_Send_Call { + _c.Call.Return(run) + return _c +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *MockStreamingNodeHandlerService_ProduceServer) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendHeader' +type MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call struct { + *mock.Call +} + +// SendHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockStreamingNodeHandlerService_ProduceServer_Expecter) SendHeader(_a0 interface{}) *MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call { + return &MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call{Call: _e.mock.On("SendHeader", _a0)} +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call) Run(run func(_a0 metadata.MD)) *MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockStreamingNodeHandlerService_ProduceServer_SendHeader_Call { + _c.Call.Return(run) + return _c +} + +// SendMsg provides a mock function with given fields: m +func (_m *MockStreamingNodeHandlerService_ProduceServer) SendMsg(m interface{}) error { + ret := _m.Called(m) + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg' +type MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call struct { + *mock.Call +} + +// SendMsg is a helper method to define mock.On call +// - m interface{} +func (_e *MockStreamingNodeHandlerService_ProduceServer_Expecter) SendMsg(m interface{}) *MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call { + return &MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call{Call: _e.mock.On("SendMsg", m)} +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call) Run(run func(m interface{})) *MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call) RunAndReturn(run func(interface{}) error) *MockStreamingNodeHandlerService_ProduceServer_SendMsg_Call { + _c.Call.Return(run) + return _c +} + +// SetHeader provides a mock function with given fields: _a0 +func (_m *MockStreamingNodeHandlerService_ProduceServer) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHeader' +type MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call struct { + *mock.Call +} + +// SetHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockStreamingNodeHandlerService_ProduceServer_Expecter) SetHeader(_a0 interface{}) *MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call { + return &MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call{Call: _e.mock.On("SetHeader", _a0)} +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call) Run(run func(_a0 metadata.MD)) *MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call) Return(_a0 error) *MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call) RunAndReturn(run func(metadata.MD) error) *MockStreamingNodeHandlerService_ProduceServer_SetHeader_Call { + _c.Call.Return(run) + return _c +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *MockStreamingNodeHandlerService_ProduceServer) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} + +// MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTrailer' +type MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call struct { + *mock.Call +} + +// SetTrailer is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *MockStreamingNodeHandlerService_ProduceServer_Expecter) SetTrailer(_a0 interface{}) *MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call { + return &MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call{Call: _e.mock.On("SetTrailer", _a0)} +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call) Run(run func(_a0 metadata.MD)) *MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call) Return() *MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call { + _c.Call.Return() + return _c +} + +func (_c *MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call) RunAndReturn(run func(metadata.MD)) *MockStreamingNodeHandlerService_ProduceServer_SetTrailer_Call { + _c.Call.Return(run) + return _c +} + +// NewMockStreamingNodeHandlerService_ProduceServer creates a new instance of MockStreamingNodeHandlerService_ProduceServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockStreamingNodeHandlerService_ProduceServer(t interface { + mock.TestingT + Cleanup(func()) +}) *MockStreamingNodeHandlerService_ProduceServer { + mock := &MockStreamingNodeHandlerService_ProduceServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/streamingnode/server/mock_wal/mock_Scanner.go b/internal/mocks/streamingnode/server/mock_wal/mock_Scanner.go index be26d70e76..25fd0a2f9f 100644 --- a/internal/mocks/streamingnode/server/mock_wal/mock_Scanner.go +++ b/internal/mocks/streamingnode/server/mock_wal/mock_Scanner.go @@ -5,6 +5,8 @@ package mock_wal import ( message "github.com/milvus-io/milvus/pkg/streaming/util/message" mock "github.com/stretchr/testify/mock" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" ) // MockScanner is an autogenerated mock type for the Scanner type @@ -63,6 +65,47 @@ func (_c *MockScanner_Chan_Call) RunAndReturn(run func() <-chan message.Immutabl return _c } +// Channel provides a mock function with given fields: +func (_m *MockScanner) Channel() types.PChannelInfo { + ret := _m.Called() + + var r0 types.PChannelInfo + if rf, ok := ret.Get(0).(func() types.PChannelInfo); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.PChannelInfo) + } + + return r0 +} + +// MockScanner_Channel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Channel' +type MockScanner_Channel_Call struct { + *mock.Call +} + +// Channel is a helper method to define mock.On call +func (_e *MockScanner_Expecter) Channel() *MockScanner_Channel_Call { + return &MockScanner_Channel_Call{Call: _e.mock.On("Channel")} +} + +func (_c *MockScanner_Channel_Call) Run(run func()) *MockScanner_Channel_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockScanner_Channel_Call) Return(_a0 types.PChannelInfo) *MockScanner_Channel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockScanner_Channel_Call) RunAndReturn(run func() types.PChannelInfo) *MockScanner_Channel_Call { + _c.Call.Return(run) + return _c +} + // Close provides a mock function with given fields: func (_m *MockScanner) Close() error { ret := _m.Called() diff --git a/internal/mocks/streamingnode/server/mock_wal/mock_WAL.go b/internal/mocks/streamingnode/server/mock_wal/mock_WAL.go index 0cf318fe5e..4721914fde 100644 --- a/internal/mocks/streamingnode/server/mock_wal/mock_WAL.go +++ b/internal/mocks/streamingnode/server/mock_wal/mock_WAL.go @@ -244,6 +244,47 @@ func (_c *MockWAL_Read_Call) RunAndReturn(run func(context.Context, wal.ReadOpti return _c } +// WALName provides a mock function with given fields: +func (_m *MockWAL) WALName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockWAL_WALName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WALName' +type MockWAL_WALName_Call struct { + *mock.Call +} + +// WALName is a helper method to define mock.On call +func (_e *MockWAL_Expecter) WALName() *MockWAL_WALName_Call { + return &MockWAL_WALName_Call{Call: _e.mock.On("WALName")} +} + +func (_c *MockWAL_WALName_Call) Run(run func()) *MockWAL_WALName_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockWAL_WALName_Call) Return(_a0 string) *MockWAL_WALName_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWAL_WALName_Call) RunAndReturn(run func() string) *MockWAL_WALName_Call { + _c.Call.Return(run) + return _c +} + // NewMockWAL creates a new instance of MockWAL. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockWAL(t interface { diff --git a/internal/mocks/streamingnode/server/mock_walmanager/mock_Manager.go b/internal/mocks/streamingnode/server/mock_walmanager/mock_Manager.go new file mode 100644 index 0000000000..4c12954e6c --- /dev/null +++ b/internal/mocks/streamingnode/server/mock_walmanager/mock_Manager.go @@ -0,0 +1,264 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_walmanager + +import ( + context "context" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" + mock "github.com/stretchr/testify/mock" + + wal "github.com/milvus-io/milvus/internal/streamingnode/server/wal" +) + +// MockManager is an autogenerated mock type for the Manager type +type MockManager struct { + mock.Mock +} + +type MockManager_Expecter struct { + mock *mock.Mock +} + +func (_m *MockManager) EXPECT() *MockManager_Expecter { + return &MockManager_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockManager) Close() { + _m.Called() +} + +// MockManager_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockManager_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockManager_Expecter) Close() *MockManager_Close_Call { + return &MockManager_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockManager_Close_Call) Run(run func()) *MockManager_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockManager_Close_Call) Return() *MockManager_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockManager_Close_Call) RunAndReturn(run func()) *MockManager_Close_Call { + _c.Call.Return(run) + return _c +} + +// GetAllAvailableChannels provides a mock function with given fields: +func (_m *MockManager) GetAllAvailableChannels() ([]types.PChannelInfo, error) { + ret := _m.Called() + + var r0 []types.PChannelInfo + var r1 error + if rf, ok := ret.Get(0).(func() ([]types.PChannelInfo, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() []types.PChannelInfo); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.PChannelInfo) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockManager_GetAllAvailableChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAllAvailableChannels' +type MockManager_GetAllAvailableChannels_Call struct { + *mock.Call +} + +// GetAllAvailableChannels is a helper method to define mock.On call +func (_e *MockManager_Expecter) GetAllAvailableChannels() *MockManager_GetAllAvailableChannels_Call { + return &MockManager_GetAllAvailableChannels_Call{Call: _e.mock.On("GetAllAvailableChannels")} +} + +func (_c *MockManager_GetAllAvailableChannels_Call) Run(run func()) *MockManager_GetAllAvailableChannels_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockManager_GetAllAvailableChannels_Call) Return(_a0 []types.PChannelInfo, _a1 error) *MockManager_GetAllAvailableChannels_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockManager_GetAllAvailableChannels_Call) RunAndReturn(run func() ([]types.PChannelInfo, error)) *MockManager_GetAllAvailableChannels_Call { + _c.Call.Return(run) + return _c +} + +// GetAvailableWAL provides a mock function with given fields: channel +func (_m *MockManager) GetAvailableWAL(channel types.PChannelInfo) (wal.WAL, error) { + ret := _m.Called(channel) + + var r0 wal.WAL + var r1 error + if rf, ok := ret.Get(0).(func(types.PChannelInfo) (wal.WAL, error)); ok { + return rf(channel) + } + if rf, ok := ret.Get(0).(func(types.PChannelInfo) wal.WAL); ok { + r0 = rf(channel) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(wal.WAL) + } + } + + if rf, ok := ret.Get(1).(func(types.PChannelInfo) error); ok { + r1 = rf(channel) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockManager_GetAvailableWAL_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAvailableWAL' +type MockManager_GetAvailableWAL_Call struct { + *mock.Call +} + +// GetAvailableWAL is a helper method to define mock.On call +// - channel types.PChannelInfo +func (_e *MockManager_Expecter) GetAvailableWAL(channel interface{}) *MockManager_GetAvailableWAL_Call { + return &MockManager_GetAvailableWAL_Call{Call: _e.mock.On("GetAvailableWAL", channel)} +} + +func (_c *MockManager_GetAvailableWAL_Call) Run(run func(channel types.PChannelInfo)) *MockManager_GetAvailableWAL_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(types.PChannelInfo)) + }) + return _c +} + +func (_c *MockManager_GetAvailableWAL_Call) Return(_a0 wal.WAL, _a1 error) *MockManager_GetAvailableWAL_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockManager_GetAvailableWAL_Call) RunAndReturn(run func(types.PChannelInfo) (wal.WAL, error)) *MockManager_GetAvailableWAL_Call { + _c.Call.Return(run) + return _c +} + +// Open provides a mock function with given fields: ctx, channel +func (_m *MockManager) Open(ctx context.Context, channel types.PChannelInfo) error { + ret := _m.Called(ctx, channel) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.PChannelInfo) error); ok { + r0 = rf(ctx, channel) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockManager_Open_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Open' +type MockManager_Open_Call struct { + *mock.Call +} + +// Open is a helper method to define mock.On call +// - ctx context.Context +// - channel types.PChannelInfo +func (_e *MockManager_Expecter) Open(ctx interface{}, channel interface{}) *MockManager_Open_Call { + return &MockManager_Open_Call{Call: _e.mock.On("Open", ctx, channel)} +} + +func (_c *MockManager_Open_Call) Run(run func(ctx context.Context, channel types.PChannelInfo)) *MockManager_Open_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PChannelInfo)) + }) + return _c +} + +func (_c *MockManager_Open_Call) Return(_a0 error) *MockManager_Open_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockManager_Open_Call) RunAndReturn(run func(context.Context, types.PChannelInfo) error) *MockManager_Open_Call { + _c.Call.Return(run) + return _c +} + +// Remove provides a mock function with given fields: ctx, channel +func (_m *MockManager) Remove(ctx context.Context, channel types.PChannelInfo) error { + ret := _m.Called(ctx, channel) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, types.PChannelInfo) error); ok { + r0 = rf(ctx, channel) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockManager_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove' +type MockManager_Remove_Call struct { + *mock.Call +} + +// Remove is a helper method to define mock.On call +// - ctx context.Context +// - channel types.PChannelInfo +func (_e *MockManager_Expecter) Remove(ctx interface{}, channel interface{}) *MockManager_Remove_Call { + return &MockManager_Remove_Call{Call: _e.mock.On("Remove", ctx, channel)} +} + +func (_c *MockManager_Remove_Call) Run(run func(ctx context.Context, channel types.PChannelInfo)) *MockManager_Remove_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.PChannelInfo)) + }) + return _c +} + +func (_c *MockManager_Remove_Call) Return(_a0 error) *MockManager_Remove_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockManager_Remove_Call) RunAndReturn(run func(context.Context, types.PChannelInfo) error) *MockManager_Remove_Call { + _c.Call.Return(run) + return _c +} + +// NewMockManager creates a new instance of MockManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockManager(t interface { + mock.TestingT + Cleanup(func()) +}) *MockManager { + mock := &MockManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/proto/streaming.proto b/internal/proto/streaming.proto index 4c5b951229..c6d2d7c4fe 100644 --- a/internal/proto/streaming.proto +++ b/internal/proto/streaming.proto @@ -26,12 +26,6 @@ message Message { message PChannelInfo { string name = 1; // channel name int64 term = 2; // A monotonic increasing term, every time the channel is recovered or moved to another streamingnode, the term will increase by meta server. - int64 server_id = 3; // The log node id address of the channel. -} - -// VChannelInfo is the information of a vchannel info. -message VChannelInfo { - string name = 1; } // DeliverPolicy is the policy to deliver message. @@ -80,6 +74,7 @@ enum StreamingCode { STREAMING_CODE_IGNORED_OPERATION = 7; // ignored operation STREAMING_CODE_INNER = 8; // underlying service failure. STREAMING_CODE_EOF = 9; // end of stream, generated by grpc status. + STREAMING_CODE_INVAILD_ARGUMENT = 10; // invalid argument STREAMING_CODE_UNKNOWN = 999; // unknown error } @@ -87,4 +82,184 @@ enum StreamingCode { message StreamingError { StreamingCode code = 1; string cause = 2; -} \ No newline at end of file +} + + +// +// StreamingNodeHandlerService +// + +// StreamingNodeHandlerService is the service to handle log messages. +// All handler operation will be blocked until the channel is ready read or write on that log node. +// Server: all log node. Running on every log node. +// Client: all log produce or consuming node. +service StreamingNodeHandlerService { + // Produce is a bi-directional streaming RPC to send messages to a channel. + // All messages sent to a channel will be assigned a unique messageID. + // The messageID is used to identify the message in the channel. + // The messageID isn't promised to be monotonous increasing with the sequence of responsing. + // Error: + // If channel isn't assign to this log node, the RPC will return error CHANNEL_NOT_EXIST. + // If channel is moving away to other log node, the RPC will return error CHANNEL_FENCED. + rpc Produce(stream ProduceRequest) returns (stream ProduceResponse) {}; + + // Consume is a server streaming RPC to receive messages from a channel. + // All message after given startMessageID and excluding will be sent to the client by stream. + // If no more message in the channel, the stream will be blocked until new message coming. + // Error: + // If channel isn't assign to this log node, the RPC will return error CHANNEL_NOT_EXIST. + // If channel is moving away to other log node, the RPC will return error CHANNEL_FENCED. + rpc Consume(stream ConsumeRequest) returns (stream ConsumeResponse) {}; +} + +// ProduceRequest is the request of the Produce RPC. +// Channel name will be passthrough in the header of stream bu not in the request body. +message ProduceRequest { + oneof request { + ProduceMessageRequest produce = 2; + CloseProducerRequest close = 3; + } +} + +// CreateProducerRequest is the request of the CreateProducer RPC. +// CreateProducerRequest is passed in the header of stream. +message CreateProducerRequest { + PChannelInfo pchannel = 1; +} + +// ProduceMessageRequest is the request of the Produce RPC. +message ProduceMessageRequest { + int64 request_id = 1; // request id for reply. + Message message = 2; // message to be sent. +} + +// CloseProducerRequest is the request of the CloseProducer RPC. +// After CloseProducerRequest is requested, no more ProduceRequest can be sent. +message CloseProducerRequest { +} + +// ProduceResponse is the response of the Produce RPC. +message ProduceResponse { + oneof response { + CreateProducerResponse create = 1; + ProduceMessageResponse produce = 2; + CloseProducerResponse close = 3; + } +} + +// CreateProducerResponse is the result of the CreateProducer RPC. +message CreateProducerResponse { + int64 producer_id = 1; // A unique producer id on streamingnode for this producer in streamingnode lifetime. + // Is used to identify the producer in streamingnode for other unary grpc call at producer level. +} + +message ProduceMessageResponse { + int64 request_id = 1; + oneof response { + ProduceMessageResponseResult result = 2; + StreamingError error = 3; + } +} + +// ProduceMessageResponseResult is the result of the produce message streaming RPC. +message ProduceMessageResponseResult { + MessageID id = 1; // the offset of the message in the channel +} + +// CloseProducerResponse is the result of the CloseProducer RPC. +message CloseProducerResponse { +} + +// ConsumeRequest is the request of the Consume RPC. +// Add more control block in future. +message ConsumeRequest { + oneof request { + CloseConsumerRequest close = 1; + } +} + +// CloseConsumerRequest is the request of the CloseConsumer RPC. +// After CloseConsumerRequest is requested, no more ConsumeRequest can be sent. +message CloseConsumerRequest { +} + +// CreateConsumerRequest is the request of the CreateConsumer RPC. +// CreateConsumerRequest is passed in the header of stream. +message CreateConsumerRequest { + PChannelInfo pchannel = 1; + DeliverPolicy deliver_policy = 2; // deliver policy. + repeated DeliverFilter deliver_filters = 3; // deliver filter. +} + +// ConsumeResponse is the reponse of the Consume RPC. +message ConsumeResponse { + oneof response { + CreateConsumerResponse create = 1; + ConsumeMessageReponse consume = 2; + CloseConsumerResponse close = 3; + } +} + +message CreateConsumerResponse { +} + +message ConsumeMessageReponse { + MessageID id = 1; // message id of message. + Message message = 2; // message to be consumed. +} + +message CloseConsumerResponse { +} + +// +// StreamingNodeManagerService +// + +// StreamingNodeManagerService is the log manage operation on log node. +// Server: all log node. Running on every log node. +// Client: log coord. There should be only one client globally to call this service on all streamingnode. +service StreamingNodeManagerService { + // Assign is a unary RPC to assign a channel on a log node. + // Block until the channel assignd is ready to read or write on the log node. + // Error: + // If the channel already exists, return error with code CHANNEL_EXIST. + rpc Assign(StreamingNodeManagerAssignRequest) returns (StreamingNodeManagerAssignResponse) {}; + + // Remove is unary RPC to remove a channel on a log node. + // Data of the channel on flying would be sent or flused as much as possible. + // Block until the resource of channel is released on the log node. + // New incoming request of handler of this channel will be rejected with special error. + // Error: + // If the channel does not exist, return error with code CHANNEL_NOT_EXIST. + rpc Remove(StreamingNodeManagerRemoveRequest) returns (StreamingNodeManagerRemoveResponse) {}; + + // rpc CollectStatus() ... + // CollectStatus is unary RPC to collect all avaliable channel info and load balance info on a log node. + // Used to recover channel info on log coord, collect balance info and health check. + rpc CollectStatus(StreamingNodeManagerCollectStatusRequest) returns (StreamingNodeManagerCollectStatusResponse) {}; +} + +// StreamingManagerAssignRequest is the request message of Assign RPC. +message StreamingNodeManagerAssignRequest { + PChannelInfo pchannel = 1; +} + +message StreamingNodeManagerAssignResponse { +} + +message StreamingNodeManagerRemoveRequest { + PChannelInfo pchannel = 1; +} + +message StreamingNodeManagerRemoveResponse {} + +message StreamingNodeManagerCollectStatusRequest { +} + +message StreamingNodeBalanceAttributes { + // TODO: traffic of pchannel or other things. +} + +message StreamingNodeManagerCollectStatusResponse { + StreamingNodeBalanceAttributes balance_attributes = 1; +} diff --git a/internal/streamingnode/server/service/handler.go b/internal/streamingnode/server/service/handler.go new file mode 100644 index 0000000000..0251f45797 --- /dev/null +++ b/internal/streamingnode/server/service/handler.go @@ -0,0 +1,55 @@ +package service + +import ( + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/service/handler/consumer" + "github.com/milvus-io/milvus/internal/streamingnode/server/service/handler/producer" + "github.com/milvus-io/milvus/internal/streamingnode/server/walmanager" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +var _ HandlerService = (*handlerServiceImpl)(nil) + +// NewHandlerService creates a new handler service. +func NewHandlerService(walManager walmanager.Manager) HandlerService { + return &handlerServiceImpl{ + walManager: walManager, + } +} + +type HandlerService = streamingpb.StreamingNodeHandlerServiceServer + +// handlerServiceImpl implements HandlerService. +// handlerServiceImpl is just a rpc level to handle incoming grpc. +// It should not handle any wal related logic, just +// 1. recv request and transfer param into wal +// 2. wait wal handling result and transform it into grpc response (convert error into grpc error) +// 3. send response to client. +type handlerServiceImpl struct { + walManager walmanager.Manager +} + +// Produce creates a new producer for the channel on this log node. +func (hs *handlerServiceImpl) Produce(streamServer streamingpb.StreamingNodeHandlerService_ProduceServer) error { + metrics.StreamingNodeProducerTotal.WithLabelValues(paramtable.GetStringNodeID()).Inc() + defer metrics.StreamingNodeProducerTotal.WithLabelValues(paramtable.GetStringNodeID()).Dec() + + p, err := producer.CreateProduceServer(hs.walManager, streamServer) + if err != nil { + return err + } + return p.Execute() +} + +// Consume creates a new consumer for the channel on this log node. +func (hs *handlerServiceImpl) Consume(streamServer streamingpb.StreamingNodeHandlerService_ConsumeServer) error { + metrics.StreamingNodeConsumerTotal.WithLabelValues(paramtable.GetStringNodeID()).Inc() + defer metrics.StreamingNodeConsumerTotal.WithLabelValues(paramtable.GetStringNodeID()).Dec() + + c, err := consumer.CreateConsumeServer(hs.walManager, streamServer) + if err != nil { + return err + } + return c.Execute() +} diff --git a/internal/streamingnode/server/service/handler/consumer/consume_grpc_server_helper.go b/internal/streamingnode/server/service/handler/consumer/consume_grpc_server_helper.go new file mode 100644 index 0000000000..444ec8295c --- /dev/null +++ b/internal/streamingnode/server/service/handler/consumer/consume_grpc_server_helper.go @@ -0,0 +1,37 @@ +package consumer + +import "github.com/milvus-io/milvus/internal/proto/streamingpb" + +// consumeGrpcServerHelper is a wrapped consumer server of log messages. +type consumeGrpcServerHelper struct { + streamingpb.StreamingNodeHandlerService_ConsumeServer +} + +// SendConsumeMessage sends the consume result to client. +func (p *consumeGrpcServerHelper) SendConsumeMessage(resp *streamingpb.ConsumeMessageReponse) error { + return p.Send(&streamingpb.ConsumeResponse{ + Response: &streamingpb.ConsumeResponse_Consume{ + Consume: resp, + }, + }) +} + +// SendCreated sends the create response to client. +func (p *consumeGrpcServerHelper) SendCreated(resp *streamingpb.CreateConsumerResponse) error { + return p.Send(&streamingpb.ConsumeResponse{ + Response: &streamingpb.ConsumeResponse_Create{ + Create: resp, + }, + }) +} + +// SendClosed sends the close response to client. +// no more message should be sent after sending close response. +func (p *consumeGrpcServerHelper) SendClosed() error { + // wait for all consume messages are processed. + return p.Send(&streamingpb.ConsumeResponse{ + Response: &streamingpb.ConsumeResponse_Close{ + Close: &streamingpb.CloseConsumerResponse{}, + }, + }) +} diff --git a/internal/streamingnode/server/service/handler/consumer/consume_server.go b/internal/streamingnode/server/service/handler/consumer/consume_server.go new file mode 100644 index 0000000000..6340965cf4 --- /dev/null +++ b/internal/streamingnode/server/service/handler/consumer/consume_server.go @@ -0,0 +1,191 @@ +package consumer + +import ( + "io" + "strconv" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/streamingnode/server/walmanager" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/internal/util/streamingutil/typeconverter" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// CreateConsumeServer create a new consumer. +// Expected message sequence: +// CreateConsumeServer: +// -> ConsumeResponse 1 +// -> ConsumeResponse 2 +// -> ConsumeResponse 3 +// CloseConsumer: +func CreateConsumeServer(walManager walmanager.Manager, streamServer streamingpb.StreamingNodeHandlerService_ConsumeServer) (*ConsumeServer, error) { + createReq, err := contextutil.GetCreateConsumer(streamServer.Context()) + if err != nil { + return nil, status.NewInvaildArgument("create consumer request is required") + } + + pchanelInfo := typeconverter.NewPChannelInfoFromProto(createReq.Pchannel) + l, err := walManager.GetAvailableWAL(pchanelInfo) + if err != nil { + return nil, err + } + + deliverPolicy, err := typeconverter.NewDeliverPolicyFromProto(l.WALName(), createReq.GetDeliverPolicy()) + if err != nil { + return nil, status.NewInvaildArgument("at convert deliver policy, err: %s", err.Error()) + } + deliverFilters, err := newMessageFilter(createReq.DeliverFilters) + if err != nil { + return nil, status.NewInvaildArgument("at convert deliver filters, err: %s", err.Error()) + } + scanner, err := l.Read(streamServer.Context(), wal.ReadOption{ + DeliverPolicy: deliverPolicy, + MessageFilter: deliverFilters, + }) + if err != nil { + return nil, err + } + consumeServer := &consumeGrpcServerHelper{ + StreamingNodeHandlerService_ConsumeServer: streamServer, + } + if err := consumeServer.SendCreated(&streamingpb.CreateConsumerResponse{}); err != nil { + // release the scanner to avoid resource leak. + if err := scanner.Close(); err != nil { + log.Warn("close scanner failed at create consume server", zap.Error(err)) + } + return nil, errors.Wrap(err, "at send created") + } + return &ConsumeServer{ + scanner: scanner, + consumeServer: consumeServer, + logger: log.With(zap.String("channel", l.Channel().Name), zap.Int64("term", l.Channel().Term)), // Add trace info for all log. + closeCh: make(chan struct{}), + }, nil +} + +// ConsumeServer is a ConsumeServer of log messages. +type ConsumeServer struct { + scanner wal.Scanner + consumeServer *consumeGrpcServerHelper + logger *log.MLogger + closeCh chan struct{} +} + +// Execute executes the consumer. +func (c *ConsumeServer) Execute() error { + // recv loop will be blocked until the stream is closed. + // 1. close by client. + // 2. close by server context cancel by return of outside Execute. + go c.recvLoop() + + // Start a send loop on current goroutine. + // the loop will be blocked until: + // 1. the stream is broken. + // 2. recv arm recv close signal. + // 3. scanner is quit with expected error. + return c.sendLoop() +} + +// sendLoop sends the message to client. +func (c *ConsumeServer) sendLoop() (err error) { + defer func() { + if err := c.scanner.Close(); err != nil { + c.logger.Warn("close scanner failed", zap.Error(err)) + } + if err != nil { + c.logger.Warn("send arm of stream closed by unexpected error", zap.Error(err)) + return + } + c.logger.Info("send arm of stream closed") + }() + // Read ahead buffer is implemented by scanner. + // Do not add buffer here. + for { + select { + case msg, ok := <-c.scanner.Chan(): + if !ok { + return status.NewInner("scanner error: %s", c.scanner.Error()) + } + // Send Consumed message to client and do metrics. + messageSize := msg.EstimateSize() + if err := c.consumeServer.SendConsumeMessage(&streamingpb.ConsumeMessageReponse{ + Id: &streamingpb.MessageID{ + Id: msg.MessageID().Marshal(), + }, + Message: &streamingpb.Message{ + Payload: msg.Payload(), + Properties: msg.Properties().ToRawMap(), + }, + }); err != nil { + return status.NewInner("send consume message failed: %s", err.Error()) + } + metrics.StreamingNodeConsumeBytes.WithLabelValues( + paramtable.GetStringNodeID(), + c.scanner.Channel().Name, + strconv.FormatInt(c.scanner.Channel().Term, 10), + ).Observe(float64(messageSize)) + case <-c.closeCh: + c.logger.Info("close channel notified") + if err := c.consumeServer.SendClosed(); err != nil { + c.logger.Warn("send close failed", zap.Error(err)) + return status.NewInner("close send server failed: %s", err.Error()) + } + return nil + case <-c.consumeServer.Context().Done(): + return c.consumeServer.Context().Err() + } + } +} + +// recvLoop receives messages from client. +func (c *ConsumeServer) recvLoop() (err error) { + defer func() { + close(c.closeCh) + if err != nil { + c.logger.Warn("recv arm of stream closed by unexpected error", zap.Error(err)) + return + } + c.logger.Info("recv arm of stream closed") + }() + + for { + req, err := c.consumeServer.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + switch req := req.Request.(type) { + case *streamingpb.ConsumeRequest_Close: + c.logger.Info("close request received") + // we will receive io.EOF soon, just do nothing here. + default: + // skip unknown message here, to keep the forward compatibility. + c.logger.Warn("unknown request type", zap.Any("request", req)) + } + } +} + +func newMessageFilter(filters []*streamingpb.DeliverFilter) (wal.MessageFilter, error) { + fs, err := typeconverter.NewDeliverFiltersFromProtos(filters) + if err != nil { + return nil, err + } + return func(msg message.ImmutableMessage) bool { + for _, f := range fs { + if !f.Filter(msg) { + return false + } + } + return true + }, nil +} diff --git a/internal/streamingnode/server/service/handler/consumer/consume_server_test.go b/internal/streamingnode/server/service/handler/consumer/consume_server_test.go new file mode 100644 index 0000000000..446b023fae --- /dev/null +++ b/internal/streamingnode/server/service/handler/consumer/consume_server_test.go @@ -0,0 +1,272 @@ +package consumer + +import ( + "context" + "io" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/internal/mocks/proto/mock_streamingpb" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_walmanager" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/walmanager" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestMain(m *testing.M) { + paramtable.Init() + m.Run() +} + +func TestNewMessageFilter(t *testing.T) { + filters := []*streamingpb.DeliverFilter{ + { + Filter: &streamingpb.DeliverFilter_TimeTickGt{ + TimeTickGt: &streamingpb.DeliverFilterTimeTickGT{ + TimeTick: 1, + }, + }, + }, + { + Filter: &streamingpb.DeliverFilter_Vchannel{ + Vchannel: &streamingpb.DeliverFilterVChannel{ + Vchannel: "test", + }, + }, + }, + } + filterFunc, err := newMessageFilter(filters) + assert.NoError(t, err) + + msg := mock_message.NewMockImmutableMessage(t) + msg.EXPECT().TimeTick().Return(2).Maybe() + msg.EXPECT().VChannel().Return("test2").Maybe() + assert.False(t, filterFunc(msg)) + + msg = mock_message.NewMockImmutableMessage(t) + msg.EXPECT().TimeTick().Return(1).Maybe() + msg.EXPECT().VChannel().Return("test").Maybe() + assert.False(t, filterFunc(msg)) + + msg = mock_message.NewMockImmutableMessage(t) + msg.EXPECT().TimeTick().Return(2).Maybe() + msg.EXPECT().VChannel().Return("test").Maybe() + assert.True(t, filterFunc(msg)) + + filters = []*streamingpb.DeliverFilter{ + { + Filter: &streamingpb.DeliverFilter_TimeTickGte{ + TimeTickGte: &streamingpb.DeliverFilterTimeTickGTE{ + TimeTick: 1, + }, + }, + }, + { + Filter: &streamingpb.DeliverFilter_Vchannel{ + Vchannel: &streamingpb.DeliverFilterVChannel{ + Vchannel: "test", + }, + }, + }, + } + filterFunc, err = newMessageFilter(filters) + assert.NoError(t, err) + + msg = mock_message.NewMockImmutableMessage(t) + msg.EXPECT().TimeTick().Return(1).Maybe() + msg.EXPECT().VChannel().Return("test").Maybe() + assert.True(t, filterFunc(msg)) +} + +func TestCreateConsumeServer(t *testing.T) { + manager := mock_walmanager.NewMockManager(t) + grpcConsumeServer := mock_streamingpb.NewMockStreamingNodeHandlerService_ConsumeServer(t) + + // No metadata in context should report error + grpcConsumeServer.EXPECT().Context().Return(context.Background()) + assertCreateConsumeServerFail(t, manager, grpcConsumeServer) + + // wal not exist should report error. + meta, _ := metadata.FromOutgoingContext(contextutil.WithCreateConsumer(context.Background(), &streamingpb.CreateConsumerRequest{ + Pchannel: &streamingpb.PChannelInfo{ + Name: "test", + Term: 1, + }, + DeliverPolicy: &streamingpb.DeliverPolicy{ + Policy: &streamingpb.DeliverPolicy_All{}, + }, + })) + ctx := metadata.NewIncomingContext(context.Background(), meta) + grpcConsumeServer.ExpectedCalls = nil + grpcConsumeServer.EXPECT().Context().Return(ctx) + manager.EXPECT().GetAvailableWAL(types.PChannelInfo{Name: "test", Term: int64(1)}).Return(nil, errors.New("wal not exist")) + assertCreateConsumeServerFail(t, manager, grpcConsumeServer) + + // Return error if create scanner failed. + l := mock_wal.NewMockWAL(t) + l.EXPECT().Read(mock.Anything, mock.Anything).Return(nil, errors.New("create scanner failed")) + l.EXPECT().WALName().Return("test") + manager.ExpectedCalls = nil + manager.EXPECT().GetAvailableWAL(types.PChannelInfo{Name: "test", Term: int64(1)}).Return(l, nil) + assertCreateConsumeServerFail(t, manager, grpcConsumeServer) + + // Return error if send created failed. + grpcConsumeServer.EXPECT().Send(mock.Anything).Return(errors.New("send created failed")) + l.EXPECT().Read(mock.Anything, mock.Anything).Unset() + s := mock_wal.NewMockScanner(t) + s.EXPECT().Close().Return(nil) + l.EXPECT().Read(mock.Anything, mock.Anything).Return(s, nil) + assertCreateConsumeServerFail(t, manager, grpcConsumeServer) + + // Passed. + grpcConsumeServer.EXPECT().Send(mock.Anything).Unset() + grpcConsumeServer.EXPECT().Send(mock.Anything).Return(nil) + + l.EXPECT().Channel().Return(types.PChannelInfo{ + Name: "test", + Term: 1, + }) + server, err := CreateConsumeServer(manager, grpcConsumeServer) + assert.NoError(t, err) + assert.NotNil(t, server) +} + +func TestConsumeServerRecvArm(t *testing.T) { + grpcConsumerServer := mock_streamingpb.NewMockStreamingNodeHandlerService_ConsumeServer(t) + server := &ConsumeServer{ + consumeServer: &consumeGrpcServerHelper{ + StreamingNodeHandlerService_ConsumeServer: grpcConsumerServer, + }, + logger: log.With(), + closeCh: make(chan struct{}), + } + recvCh := make(chan *streamingpb.ConsumeRequest) + grpcConsumerServer.EXPECT().Recv().RunAndReturn(func() (*streamingpb.ConsumeRequest, error) { + req, ok := <-recvCh + if ok { + return req, nil + } + return nil, io.EOF + }) + + // Test recv arm + ch := make(chan error) + go func() { + ch <- server.recvLoop() + }() + + // should be blocked. + testChannelShouldBeBlocked(t, ch, 500*time.Millisecond) + testChannelShouldBeBlocked(t, server.closeCh, 500*time.Millisecond) + + // cancelConsumerCh should be closed after receiving close request. + recvCh <- &streamingpb.ConsumeRequest{ + Request: &streamingpb.ConsumeRequest_Close{}, + } + close(recvCh) + <-server.closeCh + assert.NoError(t, <-ch) + + // Test unexpected recv error. + grpcConsumerServer.EXPECT().Recv().Unset() + grpcConsumerServer.EXPECT().Recv().Return(nil, io.ErrUnexpectedEOF) + server.closeCh = make(chan struct{}) + assert.ErrorIs(t, server.recvLoop(), io.ErrUnexpectedEOF) +} + +func TestConsumerServeSendArm(t *testing.T) { + grpcConsumerServer := mock_streamingpb.NewMockStreamingNodeHandlerService_ConsumeServer(t) + scanner := mock_wal.NewMockScanner(t) + s := &ConsumeServer{ + consumeServer: &consumeGrpcServerHelper{ + StreamingNodeHandlerService_ConsumeServer: grpcConsumerServer, + }, + logger: log.With(), + scanner: scanner, + closeCh: make(chan struct{}), + } + ctx, cancel := context.WithCancel(context.Background()) + grpcConsumerServer.EXPECT().Context().Return(ctx) + grpcConsumerServer.EXPECT().Send(mock.Anything).RunAndReturn(func(cr *streamingpb.ConsumeResponse) error { return nil }).Times(2) + + scanCh := make(chan message.ImmutableMessage, 1) + scanner.EXPECT().Channel().Return(types.PChannelInfo{}) + scanner.EXPECT().Chan().Return(scanCh) + scanner.EXPECT().Close().Return(nil).Times(3) + + // Test send arm + ch := make(chan error) + go func() { + ch <- s.sendLoop() + }() + + // should be blocked. + testChannelShouldBeBlocked(t, ch, 500*time.Millisecond) + + // test send. + msg := mock_message.NewMockImmutableMessage(t) + msg.EXPECT().MessageID().Return(walimplstest.NewTestMessageID(1)) + msg.EXPECT().EstimateSize().Return(0) + msg.EXPECT().Payload().Return([]byte{}) + properties := mock_message.NewMockRProperties(t) + properties.EXPECT().ToRawMap().Return(map[string]string{}) + msg.EXPECT().Properties().Return(properties) + scanCh <- msg + + // test scanner broken. + scanner.EXPECT().Error().Return(io.EOF) + close(scanCh) + err := <-ch + sErr := status.AsStreamingError(err) + assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_INNER, sErr.Code) + + // test cancel by client. + scanner.EXPECT().Chan().Unset() + scanner.EXPECT().Chan().Return(make(<-chan message.ImmutableMessage)) + go func() { + ch <- s.sendLoop() + }() + // should be blocked. + testChannelShouldBeBlocked(t, ch, 500*time.Millisecond) + close(s.closeCh) + assert.NoError(t, <-ch) + + // test cancel by server context. + s.closeCh = make(chan struct{}) + go func() { + ch <- s.sendLoop() + }() + testChannelShouldBeBlocked(t, ch, 500*time.Millisecond) + cancel() + assert.ErrorIs(t, <-ch, context.Canceled) +} + +func assertCreateConsumeServerFail(t *testing.T, manager walmanager.Manager, grpcConsumeServer streamingpb.StreamingNodeHandlerService_ConsumeServer) { + server, err := CreateConsumeServer(manager, grpcConsumeServer) + assert.Nil(t, server) + assert.Error(t, err) +} + +func testChannelShouldBeBlocked[T any](t *testing.T, ch <-chan T, d time.Duration) { + // should be blocked. + ctx, cancel := context.WithTimeout(context.Background(), d) + defer cancel() + select { + case <-ch: + t.Errorf("should be block") + case <-ctx.Done(): + } +} diff --git a/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go b/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go new file mode 100644 index 0000000000..44a8b13a37 --- /dev/null +++ b/internal/streamingnode/server/service/handler/producer/produce_grpc_server_helper.go @@ -0,0 +1,39 @@ +package producer + +import ( + "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +// produceGrpcServerHelper is a wrapped producer server of log messages. +type produceGrpcServerHelper struct { + streamingpb.StreamingNodeHandlerService_ProduceServer +} + +// SendProduceMessage sends the produce result to client. +func (p *produceGrpcServerHelper) SendProduceMessage(resp *streamingpb.ProduceMessageResponse) error { + return p.Send(&streamingpb.ProduceResponse{ + Response: &streamingpb.ProduceResponse_Produce{ + Produce: resp, + }, + }) +} + +// SendCreated sends the create response to client. +func (p *produceGrpcServerHelper) SendCreated() error { + return p.Send(&streamingpb.ProduceResponse{ + Response: &streamingpb.ProduceResponse_Create{ + Create: &streamingpb.CreateProducerResponse{}, + }, + }) +} + +// SendClosed sends the close response to client. +// no more message should be sent after sending close response. +func (p *produceGrpcServerHelper) SendClosed() error { + // wait for all produce messages are processed. + return p.Send(&streamingpb.ProduceResponse{ + Response: &streamingpb.ProduceResponse_Close{ + Close: &streamingpb.CloseProducerResponse{}, + }, + }) +} diff --git a/internal/streamingnode/server/service/handler/producer/produce_server.go b/internal/streamingnode/server/service/handler/producer/produce_server.go new file mode 100644 index 0000000000..954fc3a9b7 --- /dev/null +++ b/internal/streamingnode/server/service/handler/producer/produce_server.go @@ -0,0 +1,232 @@ +package producer + +import ( + "io" + "strconv" + "sync" + "time" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/streamingnode/server/walmanager" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/internal/util/streamingutil/typeconverter" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// CreateProduceServer create a new producer. +// Expected message sequence: +// CreateProducer (Header) +// ProduceRequest 1 -> ProduceResponse Or Error 1 +// ProduceRequest 2 -> ProduceResponse Or Error 2 +// ProduceRequest 3 -> ProduceResponse Or Error 3 +// CloseProducer +func CreateProduceServer(walManager walmanager.Manager, streamServer streamingpb.StreamingNodeHandlerService_ProduceServer) (*ProduceServer, error) { + createReq, err := contextutil.GetCreateProducer(streamServer.Context()) + if err != nil { + return nil, status.NewInvaildArgument("create producer request is required") + } + l, err := walManager.GetAvailableWAL(typeconverter.NewPChannelInfoFromProto(createReq.Pchannel)) + if err != nil { + return nil, err + } + + produceServer := &produceGrpcServerHelper{ + StreamingNodeHandlerService_ProduceServer: streamServer, + } + if err := produceServer.SendCreated(); err != nil { + return nil, errors.Wrap(err, "at send created") + } + return &ProduceServer{ + wal: l, + produceServer: produceServer, + logger: log.With(zap.String("channel", l.Channel().Name), zap.Int64("term", l.Channel().Term)), + produceMessageCh: make(chan *streamingpb.ProduceMessageResponse), + appendWG: sync.WaitGroup{}, + }, nil +} + +// ProduceServer is a ProduceServer of log messages. +type ProduceServer struct { + wal wal.WAL + produceServer *produceGrpcServerHelper + logger *log.MLogger + produceMessageCh chan *streamingpb.ProduceMessageResponse // All processing messages result should sent from theses channel. + appendWG sync.WaitGroup +} + +// Execute starts the producer. +func (p *ProduceServer) Execute() error { + // Start a recv arm to handle the control message from client. + go func() { + // recv loop will be blocked until the stream is closed. + // 1. close by client. + // 2. close by server context cancel by return of outside Execute. + _ = p.recvLoop() + }() + + // Start a send loop on current main goroutine. + // the loop will be blocked until: + // 1. the stream is broken. + // 2. recv arm recv closed and all response is sent. + return p.sendLoop() +} + +// sendLoop sends the message to client. +func (p *ProduceServer) sendLoop() (err error) { + defer func() { + if err != nil { + p.logger.Warn("send arm of stream closed by unexpected error", zap.Error(err)) + return + } + p.logger.Info("send arm of stream closed") + }() + for { + select { + case resp, ok := <-p.produceMessageCh: + if !ok { + // all message has been sent, sent close response. + p.produceServer.SendClosed() + return nil + } + if err := p.produceServer.SendProduceMessage(resp); err != nil { + return err + } + case <-p.produceServer.Context().Done(): + return errors.Wrap(p.produceServer.Context().Err(), "cancel send loop by stream server") + } + } +} + +// recvLoop receives the message from client. +func (p *ProduceServer) recvLoop() (err error) { + defer func() { + p.appendWG.Wait() + close(p.produceMessageCh) + if err != nil { + p.logger.Warn("recv arm of stream closed by unexpected error", zap.Error(err)) + return + } + p.logger.Info("recv arm of stream closed") + }() + + for { + req, err := p.produceServer.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + switch req := req.Request.(type) { + case *streamingpb.ProduceRequest_Produce: + p.handleProduce(req.Produce) + case *streamingpb.ProduceRequest_Close: + p.logger.Info("recv arm of stream start to close, waiting for all append request finished...") + // we will receive io.EOF after that. + default: + // skip message here, to keep the forward compatibility. + p.logger.Warn("unknown request type", zap.Any("request", req)) + } + } +} + +// handleProduce handles the produce message request. +func (p *ProduceServer) handleProduce(req *streamingpb.ProduceMessageRequest) { + p.logger.Debug("recv produce message from client", zap.Int64("requestID", req.RequestId)) + msg := message.NewMutableMessageBuilder(). + WithPayload(req.GetMessage().GetPayload()). + WithProperties(req.GetMessage().GetProperties()). + BuildMutable() + + if err := p.validateMessage(msg); err != nil { + p.logger.Warn("produce message validation failed", zap.Int64("requestID", req.RequestId), zap.Error(err)) + p.sendProduceResult(req.RequestId, nil, err) + return + } + + // Append message to wal. + // Concurrent append request can be executed concurrently. + messageSize := msg.EstimateSize() + now := time.Now() + p.appendWG.Add(1) + p.wal.AppendAsync(p.produceServer.Context(), msg, func(id message.MessageID, err error) { + defer func() { + p.appendWG.Done() + p.updateMetrics(messageSize, time.Since(now).Seconds(), err) + }() + p.sendProduceResult(req.RequestId, id, err) + }) +} + +// validateMessage validates the message. +func (p *ProduceServer) validateMessage(msg message.MutableMessage) error { + // validate the msg. + if !msg.Version().GT(message.VersionOld) { + return status.NewInner("unsupported message version") + } + if !msg.MessageType().Valid() { + return status.NewInner("unsupported message type") + } + if msg.Payload() == nil { + return status.NewInner("empty payload for message") + } + return nil +} + +// sendProduceResult sends the produce result to client. +func (p *ProduceServer) sendProduceResult(reqID int64, id message.MessageID, err error) { + resp := &streamingpb.ProduceMessageResponse{ + RequestId: reqID, + } + if err != nil { + p.logger.Warn("append message to wal failed", zap.Int64("requestID", reqID), zap.Error(err)) + resp.Response = &streamingpb.ProduceMessageResponse_Error{ + Error: status.AsStreamingError(err).AsPBError(), + } + } else { + resp.Response = &streamingpb.ProduceMessageResponse_Result{ + Result: &streamingpb.ProduceMessageResponseResult{ + Id: &streamingpb.MessageID{ + Id: id.Marshal(), + }, + }, + } + } + + // If server context is canceled, it means the stream has been closed. + // all pending response message should be dropped, client side will handle it. + select { + case p.produceMessageCh <- resp: + p.logger.Debug("send produce message response to client", zap.Int64("requestID", reqID), zap.Any("messageID", id), zap.Error(err)) + case <-p.produceServer.Context().Done(): + p.logger.Warn("stream closed before produce message response sent", zap.Int64("requestID", reqID), zap.Any("messageID", id)) + return + } +} + +// updateMetrics updates the metrics. +func (p *ProduceServer) updateMetrics(messageSize int, cost float64, err error) { + name := p.wal.Channel().Name + term := strconv.FormatInt(p.wal.Channel().Term, 10) + metrics.StreamingNodeProduceBytes.WithLabelValues(paramtable.GetStringNodeID(), name, term, getStatusLabel(err)).Observe(float64(messageSize)) + metrics.StreamingNodeProduceDurationSeconds.WithLabelValues(paramtable.GetStringNodeID(), name, term, getStatusLabel(err)).Observe(cost) +} + +// getStatusLabel returns the status label of error. +func getStatusLabel(err error) string { + if status.IsCanceled(err) { + return metrics.CancelLabel + } + if err != nil { + return metrics.FailLabel + } + return metrics.SuccessLabel +} diff --git a/internal/streamingnode/server/service/handler/producer/produce_server_test.go b/internal/streamingnode/server/service/handler/producer/produce_server_test.go new file mode 100644 index 0000000000..7e76b2b6bf --- /dev/null +++ b/internal/streamingnode/server/service/handler/producer/produce_server_test.go @@ -0,0 +1,287 @@ +package producer + +import ( + "context" + "io" + "sync" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.uber.org/atomic" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/internal/mocks/proto/mock_streamingpb" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_walmanager" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/walmanager" + "github.com/milvus-io/milvus/internal/util/streamingutil/service/contextutil" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func TestMain(m *testing.M) { + paramtable.Init() + m.Run() +} + +func TestCreateProduceServer(t *testing.T) { + manager := mock_walmanager.NewMockManager(t) + grpcProduceServer := mock_streamingpb.NewMockStreamingNodeHandlerService_ProduceServer(t) + + // No metadata in context should report error + grpcProduceServer.EXPECT().Context().Return(context.Background()) + assertCreateProduceServerFail(t, manager, grpcProduceServer) + + // wal not exist should report error. + meta, _ := metadata.FromOutgoingContext(contextutil.WithCreateProducer(context.Background(), &streamingpb.CreateProducerRequest{ + Pchannel: &streamingpb.PChannelInfo{ + Name: "test", + Term: 1, + }, + })) + ctx := metadata.NewIncomingContext(context.Background(), meta) + grpcProduceServer.ExpectedCalls = nil + grpcProduceServer.EXPECT().Context().Return(ctx) + manager.EXPECT().GetAvailableWAL(types.PChannelInfo{Name: "test", Term: 1}).Return(nil, errors.New("wal not exist")) + assertCreateProduceServerFail(t, manager, grpcProduceServer) + + // Return error if create scanner failed. + l := mock_wal.NewMockWAL(t) + manager.ExpectedCalls = nil + manager.EXPECT().GetAvailableWAL(types.PChannelInfo{Name: "test", Term: 1}).Return(l, nil) + grpcProduceServer.EXPECT().Send(mock.Anything).Return(errors.New("send created failed")) + assertCreateProduceServerFail(t, manager, grpcProduceServer) + + // Passed. + grpcProduceServer.EXPECT().Send(mock.Anything).Unset() + grpcProduceServer.EXPECT().Send(mock.Anything).Return(nil) + + l.EXPECT().Channel().Return(types.PChannelInfo{ + Name: "test", + Term: 1, + }) + server, err := CreateProduceServer(manager, grpcProduceServer) + assert.NoError(t, err) + assert.NotNil(t, server) +} + +func TestProduceSendArm(t *testing.T) { + grpcProduceServer := mock_streamingpb.NewMockStreamingNodeHandlerService_ProduceServer(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + grpcProduceServer.EXPECT().Context().Return(ctx) + + success := atomic.NewInt32(0) + produceFailure := atomic.NewBool(false) + grpcProduceServer.EXPECT().Send(mock.Anything).RunAndReturn(func(pr *streamingpb.ProduceResponse) error { + if !produceFailure.Load() { + success.Inc() + return nil + } + return errors.New("send failure") + }) + + p := &ProduceServer{ + produceServer: &produceGrpcServerHelper{ + StreamingNodeHandlerService_ProduceServer: grpcProduceServer, + }, + logger: log.With(), + produceMessageCh: make(chan *streamingpb.ProduceMessageResponse, 10), + appendWG: sync.WaitGroup{}, + } + + // test send arm success. + ch := make(chan error) + go func() { + ch <- p.sendLoop() + }() + + p.produceMessageCh <- &streamingpb.ProduceMessageResponse{ + RequestId: 1, + Response: &streamingpb.ProduceMessageResponse_Result{ + Result: &streamingpb.ProduceMessageResponseResult{ + Id: &streamingpb.MessageID{ + Id: walimplstest.NewTestMessageID(1).Marshal(), + }, + }, + }, + } + close(p.produceMessageCh) + assert.Nil(t, <-ch) + assert.Equal(t, int32(2), success.Load()) + + // test send arm failure + p = &ProduceServer{ + produceServer: &produceGrpcServerHelper{ + StreamingNodeHandlerService_ProduceServer: grpcProduceServer, + }, + logger: log.With(), + produceMessageCh: make(chan *streamingpb.ProduceMessageResponse, 10), + appendWG: sync.WaitGroup{}, + } + + ch = make(chan error) + go func() { + ch <- p.sendLoop() + }() + + success.Store(0) + produceFailure.Store(true) + + p.produceMessageCh <- &streamingpb.ProduceMessageResponse{ + RequestId: 1, + Response: &streamingpb.ProduceMessageResponse_Result{ + Result: &streamingpb.ProduceMessageResponseResult{ + Id: &streamingpb.MessageID{ + Id: walimplstest.NewTestMessageID(1).Marshal(), + }, + }, + }, + } + assert.Error(t, <-ch) + + // test send arm failure + p = &ProduceServer{ + produceServer: &produceGrpcServerHelper{ + StreamingNodeHandlerService_ProduceServer: grpcProduceServer, + }, + logger: log.With(), + produceMessageCh: make(chan *streamingpb.ProduceMessageResponse, 10), + appendWG: sync.WaitGroup{}, + } + + ch = make(chan error) + go func() { + ch <- p.sendLoop() + }() + cancel() + assert.Error(t, <-ch) +} + +func TestProduceServerRecvArm(t *testing.T) { + grpcProduceServer := mock_streamingpb.NewMockStreamingNodeHandlerService_ProduceServer(t) + recvCh := make(chan *streamingpb.ProduceRequest) + grpcProduceServer.EXPECT().Recv().RunAndReturn(func() (*streamingpb.ProduceRequest, error) { + req, ok := <-recvCh + if ok { + return req, nil + } + return nil, io.EOF + }) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + grpcProduceServer.EXPECT().Context().Return(ctx) + + l := mock_wal.NewMockWAL(t) + l.EXPECT().Channel().Return(types.PChannelInfo{ + Name: "test", + Term: 1, + }) + l.EXPECT().AppendAsync(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, mm message.MutableMessage, f func(message.MessageID, error)) { + msgID := walimplstest.NewTestMessageID(1) + f(msgID, nil) + }) + + p := &ProduceServer{ + wal: l, + produceServer: &produceGrpcServerHelper{ + StreamingNodeHandlerService_ProduceServer: grpcProduceServer, + }, + logger: log.With(), + produceMessageCh: make(chan *streamingpb.ProduceMessageResponse, 10), + appendWG: sync.WaitGroup{}, + } + + // Test send arm + ch := make(chan error) + go func() { + ch <- p.recvLoop() + }() + + req := &streamingpb.ProduceRequest{ + Request: &streamingpb.ProduceRequest_Produce{ + Produce: &streamingpb.ProduceMessageRequest{ + RequestId: 1, + Message: &streamingpb.Message{ + Payload: []byte("test"), + Properties: map[string]string{ + "_v": "1", + "_t": "1", + }, + }, + }, + }, + } + recvCh <- req + + msg := <-p.produceMessageCh + assert.Equal(t, int64(1), msg.RequestId) + assert.NotNil(t, msg.Response.(*streamingpb.ProduceMessageResponse_Result).Result.Id) + + // Test send error. + l.EXPECT().AppendAsync(mock.Anything, mock.Anything, mock.Anything).Unset() + l.EXPECT().AppendAsync(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, mm message.MutableMessage, f func(message.MessageID, error)) { + f(nil, errors.New("append error")) + }) + + req.Request.(*streamingpb.ProduceRequest_Produce).Produce.RequestId = 2 + recvCh <- req + msg = <-p.produceMessageCh + assert.Equal(t, int64(2), msg.RequestId) + assert.NotNil(t, msg.Response.(*streamingpb.ProduceMessageResponse_Error).Error) + + // Test send close and EOF. + recvCh <- &streamingpb.ProduceRequest{ + Request: &streamingpb.ProduceRequest_Close{}, + } + p.appendWG.Wait() + + close(recvCh) + // produceMessageCh should be closed. + <-p.produceMessageCh + // recvLoop should closed. + err := <-ch + assert.NoError(t, err) + + p = &ProduceServer{ + wal: l, + produceServer: &produceGrpcServerHelper{ + StreamingNodeHandlerService_ProduceServer: grpcProduceServer, + }, + logger: log.With(), + produceMessageCh: make(chan *streamingpb.ProduceMessageResponse), + appendWG: sync.WaitGroup{}, + } + + // Test recv failure. + grpcProduceServer.EXPECT().Recv().Unset() + grpcProduceServer.EXPECT().Recv().RunAndReturn(func() (*streamingpb.ProduceRequest, error) { + return nil, io.ErrUnexpectedEOF + }) + + assert.ErrorIs(t, p.recvLoop(), io.ErrUnexpectedEOF) +} + +func assertCreateProduceServerFail(t *testing.T, manager walmanager.Manager, grpcProduceServer streamingpb.StreamingNodeHandlerService_ProduceServer) { + server, err := CreateProduceServer(manager, grpcProduceServer) + assert.Nil(t, server) + assert.Error(t, err) +} + +func testChannelShouldBeBlocked[T any](t *testing.T, ch <-chan T, d time.Duration) { + // should be blocked. + ctx, cancel := context.WithTimeout(context.Background(), d) + defer cancel() + select { + case <-ch: + t.Errorf("should be block") + case <-ctx.Done(): + } +} diff --git a/internal/streamingnode/server/service/manager.go b/internal/streamingnode/server/service/manager.go new file mode 100644 index 0000000000..f3ad42d2b5 --- /dev/null +++ b/internal/streamingnode/server/service/manager.go @@ -0,0 +1,57 @@ +package service + +import ( + "context" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/walmanager" + "github.com/milvus-io/milvus/internal/util/streamingutil/typeconverter" +) + +var _ ManagerService = (*managerServiceImpl)(nil) + +// NewManagerService create a streamingnode manager service. +func NewManagerService(m walmanager.Manager) ManagerService { + return &managerServiceImpl{ + m, + } +} + +type ManagerService interface { + streamingpb.StreamingNodeManagerServiceServer +} + +// managerServiceImpl implements ManagerService. +// managerServiceImpl is just a rpc level to handle incoming grpc. +// all manager logic should be done in wal.Manager. +type managerServiceImpl struct { + walManager walmanager.Manager +} + +// Assign assigns a wal instance for the channel on this Manager. +// After assign returns, the wal instance is ready to use. +func (ms *managerServiceImpl) Assign(ctx context.Context, req *streamingpb.StreamingNodeManagerAssignRequest) (*streamingpb.StreamingNodeManagerAssignResponse, error) { + pchannelInfo := typeconverter.NewPChannelInfoFromProto(req.GetPchannel()) + if err := ms.walManager.Open(ctx, pchannelInfo); err != nil { + return nil, err + } + return &streamingpb.StreamingNodeManagerAssignResponse{}, nil +} + +// Remove removes the wal instance for the channel. +// After remove returns, the wal instance is removed and all underlying read write operation should be rejected. +func (ms *managerServiceImpl) Remove(ctx context.Context, req *streamingpb.StreamingNodeManagerRemoveRequest) (*streamingpb.StreamingNodeManagerRemoveResponse, error) { + pchannelInfo := typeconverter.NewPChannelInfoFromProto(req.GetPchannel()) + if err := ms.walManager.Remove(ctx, pchannelInfo); err != nil { + return nil, err + } + return &streamingpb.StreamingNodeManagerRemoveResponse{}, nil +} + +// CollectStatus collects the status of all wal instances in these streamingnode. +func (ms *managerServiceImpl) CollectStatus(ctx context.Context, req *streamingpb.StreamingNodeManagerCollectStatusRequest) (*streamingpb.StreamingNodeManagerCollectStatusResponse, error) { + // TODO: collect traffic metric for load balance. + return &streamingpb.StreamingNodeManagerCollectStatusResponse{ + BalanceAttributes: &streamingpb.StreamingNodeBalanceAttributes{}, + }, nil +} diff --git a/internal/streamingnode/server/wal/adaptor/opener_test.go b/internal/streamingnode/server/wal/adaptor/opener_test.go index b525aeda7e..f2b28cf104 100644 --- a/internal/streamingnode/server/wal/adaptor/opener_test.go +++ b/internal/streamingnode/server/wal/adaptor/opener_test.go @@ -66,9 +66,8 @@ func TestOpenerAdaptor(t *testing.T) { defer wg.Done() wal, err := opener.Open(context.Background(), &wal.OpenOption{ Channel: types.PChannelInfo{ - Name: fmt.Sprintf("test_%d", i), - Term: int64(i), - ServerID: 1, + Name: fmt.Sprintf("test_%d", i), + Term: int64(i), }, }) if err != nil { @@ -108,9 +107,8 @@ func TestOpenerAdaptor(t *testing.T) { // open a wal after opener closed should return shutdown error. _, err := opener.Open(context.Background(), &wal.OpenOption{ Channel: types.PChannelInfo{ - Name: "test_after_close", - Term: int64(1), - ServerID: 1, + Name: "test_after_close", + Term: int64(1), }, }) assertShutdownError(t, err) diff --git a/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go b/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go index 90a718a485..9861fb680b 100644 --- a/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go @@ -1,12 +1,13 @@ package adaptor import ( - "github.com/milvus-io/milvus/pkg/log" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility" + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls" "github.com/milvus-io/milvus/pkg/streaming/walimpls/helper" ) @@ -46,6 +47,11 @@ type scannerAdaptorImpl struct { cleanup func() } +// Channel returns the channel assignment info of the wal. +func (s *scannerAdaptorImpl) Channel() types.PChannelInfo { + return s.innerWAL.Channel() +} + // Chan returns the channel of message. func (s *scannerAdaptorImpl) Chan() <-chan message.ImmutableMessage { return s.sendingCh diff --git a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go index 978865fdc2..e2a0d24136 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go @@ -63,6 +63,10 @@ type walAdaptorImpl struct { cleanup func() } +func (w *walAdaptorImpl) WALName() string { + return w.inner.WALName() +} + // Channel returns the channel info of wal. func (w *walAdaptorImpl) Channel() types.PChannelInfo { return w.inner.Channel() diff --git a/internal/streamingnode/server/wal/adaptor/wal_test.go b/internal/streamingnode/server/wal/adaptor/wal_test.go index c99b4bc2d9..a48ae83d63 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_test.go +++ b/internal/streamingnode/server/wal/adaptor/wal_test.go @@ -86,9 +86,8 @@ func (f *testOneWALFramework) Run() { ctx := context.Background() for ; f.term <= 3; f.term++ { pChannel := types.PChannelInfo{ - Name: f.pchannel, - Term: int64(f.term), - ServerID: 1, + Name: f.pchannel, + Term: int64(f.term), } w, err := f.opener.Open(ctx, &wal.OpenOption{ Channel: pChannel, @@ -96,7 +95,6 @@ func (f *testOneWALFramework) Run() { assert.NoError(f.t, err) assert.NotNil(f.t, w) assert.Equal(f.t, pChannel.Name, w.Channel().Name) - assert.Equal(f.t, pChannel.ServerID, w.Channel().ServerID) f.testReadAndWrite(ctx, w) // close the wal diff --git a/internal/streamingnode/server/wal/scanner.go b/internal/streamingnode/server/wal/scanner.go index 27b71604f2..f9ea7a65a2 100644 --- a/internal/streamingnode/server/wal/scanner.go +++ b/internal/streamingnode/server/wal/scanner.go @@ -3,6 +3,7 @@ package wal import ( "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/options" + "github.com/milvus-io/milvus/pkg/streaming/util/types" ) type MessageFilter = func(message.ImmutableMessage) bool @@ -18,6 +19,9 @@ type Scanner interface { // Chan returns the channel of message. Chan() <-chan message.ImmutableMessage + // Channel returns the channel assignment info of the wal. + Channel() types.PChannelInfo + // Error returns the error of scanner failed. // Will block until scanner is closed or Chan is dry out. Error() error diff --git a/internal/streamingnode/server/wal/wal.go b/internal/streamingnode/server/wal/wal.go index b9f624c877..3cc3a847e9 100644 --- a/internal/streamingnode/server/wal/wal.go +++ b/internal/streamingnode/server/wal/wal.go @@ -10,6 +10,8 @@ import ( // WAL is the WAL framework interface. // !!! Don't implement it directly, implement walimpls.WAL instead. type WAL interface { + WALName() string + // Channel returns the channel assignment info of the wal. Channel() types.PChannelInfo diff --git a/internal/streamingnode/server/walmanager/manager.go b/internal/streamingnode/server/walmanager/manager.go index 33892fc3ad..811ae42f15 100644 --- a/internal/streamingnode/server/walmanager/manager.go +++ b/internal/streamingnode/server/walmanager/manager.go @@ -16,7 +16,7 @@ type Manager interface { // GetAvailableWAL returns a available wal instance for the channel. // Return nil if the wal instance is not found. - GetAvailableWAL(channelName string, term int64) (wal.WAL, error) + GetAvailableWAL(channel types.PChannelInfo) (wal.WAL, error) // GetAllAvailableWALInfo returns all available channel info. GetAllAvailableChannels() ([]types.PChannelInfo, error) diff --git a/internal/streamingnode/server/walmanager/manager_impl.go b/internal/streamingnode/server/walmanager/manager_impl.go index d6aee61f20..70f4ed26b5 100644 --- a/internal/streamingnode/server/walmanager/manager_impl.go +++ b/internal/streamingnode/server/walmanager/manager_impl.go @@ -80,21 +80,21 @@ func (m *managerImpl) Remove(ctx context.Context, channel types.PChannelInfo) (e // GetAvailableWAL returns a available wal instance for the channel. // Return nil if the wal instance is not found. -func (m *managerImpl) GetAvailableWAL(channelName string, term int64) (wal.WAL, error) { +func (m *managerImpl) GetAvailableWAL(channel types.PChannelInfo) (wal.WAL, error) { // reject operation if manager is closing. if m.lifetime.Add(lifetime.IsWorking) != nil { return nil, status.NewOnShutdownError("wal manager is closed") } defer m.lifetime.Done() - l := m.getWALLifetime(channelName).GetWAL() + l := m.getWALLifetime(channel.Name).GetWAL() if l == nil { - return nil, status.NewChannelNotExist(channelName) + return nil, status.NewChannelNotExist(channel.Name) } - channelTerm := l.Channel().Term - if channelTerm != term { - return nil, status.NewUnmatchedChannelTerm(channelName, term, channelTerm) + currentTerm := l.Channel().Term + if currentTerm != channel.Term { + return nil, status.NewUnmatchedChannelTerm(channel.Name, channel.Term, currentTerm) } return l, nil } diff --git a/internal/streamingnode/server/walmanager/manager_impl_test.go b/internal/streamingnode/server/walmanager/manager_impl_test.go index 93c90c6b28..dbeb8ee026 100644 --- a/internal/streamingnode/server/walmanager/manager_impl_test.go +++ b/internal/streamingnode/server/walmanager/manager_impl_test.go @@ -34,7 +34,7 @@ func TestManager(t *testing.T) { m := newManager(opener) channelName := "ch1" - l, err := m.GetAvailableWAL(channelName, 1) + l, err := m.GetAvailableWAL(types.PChannelInfo{Name: channelName, Term: 1}) assertErrorChannelNotExist(t, err) assert.Nil(t, l) @@ -45,7 +45,7 @@ func TestManager(t *testing.T) { err = m.Remove(context.Background(), types.PChannelInfo{Name: channelName, Term: 1}) assert.NoError(t, err) - l, err = m.GetAvailableWAL(channelName, 1) + l, err = m.GetAvailableWAL(types.PChannelInfo{Name: channelName, Term: 1}) assertErrorChannelNotExist(t, err) assert.Nil(t, l) @@ -53,7 +53,7 @@ func TestManager(t *testing.T) { Name: channelName, Term: 1, }) - assertErrorTermExpired(t, err) + assertErrorOperationIgnored(t, err) err = m.Open(context.Background(), types.PChannelInfo{ Name: channelName, @@ -62,13 +62,13 @@ func TestManager(t *testing.T) { assert.NoError(t, err) err = m.Remove(context.Background(), types.PChannelInfo{Name: channelName, Term: 1}) - assertErrorTermExpired(t, err) + assertErrorOperationIgnored(t, err) - l, err = m.GetAvailableWAL(channelName, 1) + l, err = m.GetAvailableWAL(types.PChannelInfo{Name: channelName, Term: 1}) assertErrorTermExpired(t, err) assert.Nil(t, l) - l, err = m.GetAvailableWAL(channelName, 2) + l, err = m.GetAvailableWAL(types.PChannelInfo{Name: channelName, Term: 2}) assert.NoError(t, err) assert.NotNil(t, l) @@ -101,7 +101,7 @@ func TestManager(t *testing.T) { err = m.Remove(context.Background(), types.PChannelInfo{Name: channelName, Term: 2}) assertShutdownError(t, err) - l, err = m.GetAvailableWAL(channelName, 2) + l, err = m.GetAvailableWAL(types.PChannelInfo{Name: channelName, Term: 2}) assertShutdownError(t, err) assert.Nil(t, l) } diff --git a/internal/streamingnode/server/walmanager/wal_lifetime.go b/internal/streamingnode/server/walmanager/wal_lifetime.go index d00ddc37d7..ee1a6f145c 100644 --- a/internal/streamingnode/server/walmanager/wal_lifetime.go +++ b/internal/streamingnode/server/walmanager/wal_lifetime.go @@ -53,7 +53,7 @@ func (w *walLifetime) Open(ctx context.Context, channel types.PChannelInfo) erro // Set expected WAL state to available at given term. expected := newAvailableExpectedState(ctx, channel) if !w.statePair.SetExpectedState(expected) { - return status.NewUnmatchedChannelTerm("expired term, cannot change expected state for open") + return status.NewIgnoreOperation("channel %s with expired term %d, cannot change expected state for open", channel.Name, channel.Term) } // Wait until the WAL state is ready or term expired or error occurs. @@ -65,7 +65,7 @@ func (w *walLifetime) Remove(ctx context.Context, term int64) error { // Set expected WAL state to unavailable at given term. expected := newUnavailableExpectedState(term) if !w.statePair.SetExpectedState(expected) { - return status.NewUnmatchedChannelTerm("expired term, cannot change expected state for remove") + return status.NewIgnoreOperation("expired term %d, cannot change expected state for remove", term) } // Wait until the WAL state is ready or term expired or error occurs. diff --git a/internal/streamingnode/server/walmanager/wal_lifetime_test.go b/internal/streamingnode/server/walmanager/wal_lifetime_test.go index 11feed0789..8d8187f316 100644 --- a/internal/streamingnode/server/walmanager/wal_lifetime_test.go +++ b/internal/streamingnode/server/walmanager/wal_lifetime_test.go @@ -38,7 +38,7 @@ func TestWALLifetime(t *testing.T) { // Test expired term remove. err = wlt.Remove(context.Background(), 1) - assertErrorTermExpired(t, err) + assertErrorOperationIgnored(t, err) assert.NotNil(t, wlt.GetWAL()) assert.Equal(t, channel, wlt.GetWAL().Channel().Name) assert.Equal(t, int64(2), wlt.GetWAL().Channel().Term) @@ -53,7 +53,7 @@ func TestWALLifetime(t *testing.T) { Name: channel, Term: 1, }) - assertErrorTermExpired(t, err) + assertErrorOperationIgnored(t, err) assert.Nil(t, wlt.GetWAL()) // Test open after close. @@ -92,7 +92,7 @@ func TestWALLifetime(t *testing.T) { Name: channel, Term: 11, }) - assertErrorTermExpired(t, err) + assertErrorOperationIgnored(t, err) wlt.Open(context.Background(), types.PChannelInfo{ Name: channel, diff --git a/internal/streamingnode/server/walmanager/wal_state_pair_test.go b/internal/streamingnode/server/walmanager/wal_state_pair_test.go index d23290456d..226456a5f1 100644 --- a/internal/streamingnode/server/walmanager/wal_state_pair_test.go +++ b/internal/streamingnode/server/walmanager/wal_state_pair_test.go @@ -70,6 +70,12 @@ func TestStatePair(t *testing.T) { } } +func assertErrorOperationIgnored(t *testing.T, err error) { + assert.Error(t, err) + logErr := status.AsStreamingError(err) + assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_IGNORED_OPERATION, logErr.Code) +} + func assertErrorTermExpired(t *testing.T, err error) { assert.Error(t, err) logErr := status.AsStreamingError(err) diff --git a/internal/streamingservice/.mockery.yaml b/internal/streamingservice/.mockery.yaml index 55423e9140..9c592d044b 100644 --- a/internal/streamingservice/.mockery.yaml +++ b/internal/streamingservice/.mockery.yaml @@ -19,3 +19,10 @@ packages: google.golang.org/grpc: interfaces: ClientStream: + github.com/milvus-io/milvus/internal/proto/streamingpb: + interfaces: + StreamingNodeHandlerService_ConsumeServer: + StreamingNodeHandlerService_ProduceServer: + github.com/milvus-io/milvus/internal/streamingnode/server/walmanager: + interfaces: + Manager: diff --git a/internal/util/streamingutil/service/contextutil/create_consumer.go b/internal/util/streamingutil/service/contextutil/create_consumer.go new file mode 100644 index 0000000000..ffb8e16bd0 --- /dev/null +++ b/internal/util/streamingutil/service/contextutil/create_consumer.go @@ -0,0 +1,51 @@ +package contextutil + +import ( + "context" + "encoding/base64" + "fmt" + + "github.com/cockroachdb/errors" + "github.com/golang/protobuf/proto" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +const ( + createConsumerKey = "create-consumer" +) + +// WithCreateConsumer attaches create consumer request to context. +func WithCreateConsumer(ctx context.Context, req *streamingpb.CreateConsumerRequest) context.Context { + bytes, err := proto.Marshal(req) + if err != nil { + panic(fmt.Sprintf("unreachable: marshal create consumer request should never failed, %+v", req)) + } + // use base64 encoding to transfer binary to text. + msg := base64.StdEncoding.EncodeToString(bytes) + return metadata.AppendToOutgoingContext(ctx, createConsumerKey, msg) +} + +// GetCreateConsumer gets create consumer request from context. +func GetCreateConsumer(ctx context.Context) (*streamingpb.CreateConsumerRequest, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, errors.New("create consumer metadata not found from incoming context") + } + msg := md.Get(createConsumerKey) + if len(msg) == 0 { + return nil, errors.New("create consumer metadata not found") + } + + bytes, err := base64.StdEncoding.DecodeString(msg[0]) + if err != nil { + return nil, errors.Wrap(err, "decode create consumer metadata failed") + } + + req := &streamingpb.CreateConsumerRequest{} + if err := proto.Unmarshal(bytes, req); err != nil { + return nil, errors.Wrap(err, "unmarshal create consumer request failed") + } + return req, nil +} diff --git a/internal/util/streamingutil/service/contextutil/create_consumer_test.go b/internal/util/streamingutil/service/contextutil/create_consumer_test.go new file mode 100644 index 0000000000..8991070808 --- /dev/null +++ b/internal/util/streamingutil/service/contextutil/create_consumer_test.go @@ -0,0 +1,70 @@ +package contextutil + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +func TestWithCreateConsumer(t *testing.T) { + req := &streamingpb.CreateConsumerRequest{ + Pchannel: &streamingpb.PChannelInfo{ + Name: "test", + Term: 1, + }, + DeliverPolicy: &streamingpb.DeliverPolicy{ + Policy: &streamingpb.DeliverPolicy_All{}, + }, + } + ctx := WithCreateConsumer(context.Background(), req) + + md, ok := metadata.FromOutgoingContext(ctx) + assert.True(t, ok) + assert.NotNil(t, md) + + ctx = metadata.NewIncomingContext(context.Background(), md) + req2, err := GetCreateConsumer(ctx) + assert.Nil(t, err) + assert.Equal(t, req.Pchannel.Name, req2.Pchannel.Name) + assert.Equal(t, req.Pchannel.Term, req2.Pchannel.Term) + assert.Equal(t, req.DeliverPolicy.String(), req2.DeliverPolicy.String()) + + // panic case. + assert.Panics(t, func() { WithCreateConsumer(context.Background(), nil) }) +} + +func TestGetCreateConsumer(t *testing.T) { + // empty context. + req, err := GetCreateConsumer(context.Background()) + assert.Error(t, err) + assert.Nil(t, req) + + // key not exist. + md := metadata.New(map[string]string{}) + req, err = GetCreateConsumer(metadata.NewIncomingContext(context.Background(), md)) + assert.Error(t, err) + assert.Nil(t, req) + + // invalid value. + md = metadata.New(map[string]string{ + createConsumerKey: "invalid", + }) + req, err = GetCreateConsumer(metadata.NewIncomingContext(context.Background(), md)) + assert.Error(t, err) + assert.Nil(t, req) + + // unmarshal error. + md = metadata.New(map[string]string{ + createConsumerKey: base64.StdEncoding.EncodeToString([]byte("invalid")), + }) + req, err = GetCreateConsumer(metadata.NewIncomingContext(context.Background(), md)) + assert.Error(t, err) + assert.Nil(t, req) + + // normal case is tested on TestWithCreateConsumer. +} diff --git a/internal/util/streamingutil/service/contextutil/create_producer.go b/internal/util/streamingutil/service/contextutil/create_producer.go new file mode 100644 index 0000000000..e8e4aa8d26 --- /dev/null +++ b/internal/util/streamingutil/service/contextutil/create_producer.go @@ -0,0 +1,51 @@ +package contextutil + +import ( + "context" + "encoding/base64" + "fmt" + + "github.com/cockroachdb/errors" + "github.com/golang/protobuf/proto" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +const ( + createProducerKey = "create-producer" +) + +// WithCreateProducer attaches create producer request to context. +func WithCreateProducer(ctx context.Context, req *streamingpb.CreateProducerRequest) context.Context { + bytes, err := proto.Marshal(req) + if err != nil { + panic(fmt.Sprintf("unreachable: marshal create producer request failed, %+v", err)) + } + // use base64 encoding to transfer binary to text. + msg := base64.StdEncoding.EncodeToString(bytes) + return metadata.AppendToOutgoingContext(ctx, createProducerKey, msg) +} + +// GetCreateProducer gets create producer request from context. +func GetCreateProducer(ctx context.Context) (*streamingpb.CreateProducerRequest, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, errors.New("create producer metadata not found from incoming context") + } + msg := md.Get(createProducerKey) + if len(msg) == 0 { + return nil, errors.New("create consumer metadata not found") + } + + bytes, err := base64.StdEncoding.DecodeString(msg[0]) + if err != nil { + return nil, errors.Wrap(err, "decode create consumer metadata failed") + } + + req := &streamingpb.CreateProducerRequest{} + if err := proto.Unmarshal(bytes, req); err != nil { + return nil, errors.Wrap(err, "unmarshal create producer request failed") + } + return req, nil +} diff --git a/internal/util/streamingutil/service/contextutil/create_producer_test.go b/internal/util/streamingutil/service/contextutil/create_producer_test.go new file mode 100644 index 0000000000..aac67e6104 --- /dev/null +++ b/internal/util/streamingutil/service/contextutil/create_producer_test.go @@ -0,0 +1,66 @@ +package contextutil + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +func TestWithCreateProducer(t *testing.T) { + req := &streamingpb.CreateProducerRequest{ + Pchannel: &streamingpb.PChannelInfo{ + Name: "test", + Term: 1, + }, + } + ctx := WithCreateProducer(context.Background(), req) + + md, ok := metadata.FromOutgoingContext(ctx) + assert.True(t, ok) + assert.NotNil(t, md) + + ctx = metadata.NewIncomingContext(context.Background(), md) + req2, err := GetCreateProducer(ctx) + assert.Nil(t, err) + assert.Equal(t, req.Pchannel.Name, req2.Pchannel.Name) + assert.Equal(t, req.Pchannel.Term, req2.Pchannel.Term) + + // panic case. + assert.Panics(t, func() { WithCreateProducer(context.Background(), nil) }) +} + +func TestGetCreateProducer(t *testing.T) { + // empty context. + req, err := GetCreateProducer(context.Background()) + assert.Error(t, err) + assert.Nil(t, req) + + // key not exist. + md := metadata.New(map[string]string{}) + req, err = GetCreateProducer(metadata.NewIncomingContext(context.Background(), md)) + assert.Error(t, err) + assert.Nil(t, req) + + // invalid value. + md = metadata.New(map[string]string{ + createProducerKey: "invalid", + }) + req, err = GetCreateProducer(metadata.NewIncomingContext(context.Background(), md)) + assert.Error(t, err) + assert.Nil(t, req) + + // unmarshal error. + md = metadata.New(map[string]string{ + createProducerKey: base64.StdEncoding.EncodeToString([]byte("invalid")), + }) + req, err = GetCreateProducer(metadata.NewIncomingContext(context.Background(), md)) + assert.Error(t, err) + assert.Nil(t, req) + + // normal case is tested on TestWithCreateProducer. +} diff --git a/internal/util/streamingutil/status/rpc_error.go b/internal/util/streamingutil/status/rpc_error.go index c2ce2c1128..d204e0a96f 100644 --- a/internal/util/streamingutil/status/rpc_error.go +++ b/internal/util/streamingutil/status/rpc_error.go @@ -22,6 +22,7 @@ var streamingErrorToGRPCStatus = map[streamingpb.StreamingCode]codes.Code{ streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM: codes.FailedPrecondition, streamingpb.StreamingCode_STREAMING_CODE_IGNORED_OPERATION: codes.FailedPrecondition, streamingpb.StreamingCode_STREAMING_CODE_INNER: codes.Unavailable, + streamingpb.StreamingCode_STREAMING_CODE_INVAILD_ARGUMENT: codes.InvalidArgument, streamingpb.StreamingCode_STREAMING_CODE_UNKNOWN: codes.Unknown, } diff --git a/internal/util/streamingutil/status/streaming_error.go b/internal/util/streamingutil/status/streaming_error.go index a27800d35a..28a705fc9a 100644 --- a/internal/util/streamingutil/status/streaming_error.go +++ b/internal/util/streamingutil/status/streaming_error.go @@ -57,13 +57,13 @@ func NewChannelExist(format string, args ...interface{}) *StreamingError { } // NewChannelNotExist creates a new StreamingError with code STREAMING_CODE_CHANNEL_NOT_EXIST. -func NewChannelNotExist(format string, args ...interface{}) *StreamingError { - return New(streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST, format, args...) +func NewChannelNotExist(channel string) *StreamingError { + return New(streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST, "%s not exist", channel) } // NewUnmatchedChannelTerm creates a new StreamingError with code StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM. -func NewUnmatchedChannelTerm(format string, args ...interface{}) *StreamingError { - return New(streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM, format, args...) +func NewUnmatchedChannelTerm(channel string, expectedTerm int64, currentTerm int64) *StreamingError { + return New(streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM, "channel %s at term %d is expected, but current term is %d", channel, expectedTerm, currentTerm) } // NewIgnoreOperation creates a new StreamingError with code STREAMING_CODE_IGNORED_OPERATION. @@ -76,6 +76,11 @@ func NewInner(format string, args ...interface{}) *StreamingError { return New(streamingpb.StreamingCode_STREAMING_CODE_INNER, format, args...) } +// NewInvaildArgument creates a new StreamingError with code STREAMING_CODE_INVAILD_ARGUMENT. +func NewInvaildArgument(format string, args ...interface{}) *StreamingError { + return New(streamingpb.StreamingCode_STREAMING_CODE_INVAILD_ARGUMENT, format, args...) +} + // New creates a new StreamingError with the given code and cause. func New(code streamingpb.StreamingCode, format string, args ...interface{}) *StreamingError { if len(args) == 0 { diff --git a/internal/util/streamingutil/status/streaming_error_test.go b/internal/util/streamingutil/status/streaming_error_test.go index 9c7c8dce2f..9becfcf0fd 100644 --- a/internal/util/streamingutil/status/streaming_error_test.go +++ b/internal/util/streamingutil/status/streaming_error_test.go @@ -39,8 +39,8 @@ func TestStreamingError(t *testing.T) { pbErr = streamingErr.AsPBError() assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST, pbErr.Code) - streamingErr = NewUnmatchedChannelTerm("test") - assert.Contains(t, streamingErr.Error(), "code: STREAMING_CODE_UNMATCHED_CHANNEL_TERM, cause: test") + streamingErr = NewUnmatchedChannelTerm("test", 1, 2) + assert.Contains(t, streamingErr.Error(), "code: STREAMING_CODE_UNMATCHED_CHANNEL_TERM, cause: channel test") assert.True(t, streamingErr.IsWrongStreamingNode()) pbErr = streamingErr.AsPBError() assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM, pbErr.Code) diff --git a/internal/util/streamingutil/typeconverter/deliver.go b/internal/util/streamingutil/typeconverter/deliver.go new file mode 100644 index 0000000000..7c4f33bf61 --- /dev/null +++ b/internal/util/streamingutil/typeconverter/deliver.go @@ -0,0 +1,137 @@ +package typeconverter + +import ( + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/options" +) + +// NewDeliverPolicyFromProto converts protobuf DeliverPolicy to DeliverPolicy +func NewDeliverPolicyFromProto(name string, policy *streamingpb.DeliverPolicy) (options.DeliverPolicy, error) { + switch policy := policy.GetPolicy().(type) { + case *streamingpb.DeliverPolicy_All: + return options.DeliverPolicyAll(), nil + case *streamingpb.DeliverPolicy_Latest: + return options.DeliverPolicyLatest(), nil + case *streamingpb.DeliverPolicy_StartFrom: + msgID, err := message.UnmarshalMessageID(name, policy.StartFrom.GetId()) + if err != nil { + return nil, err + } + return options.DeliverPolicyStartFrom(msgID), nil + case *streamingpb.DeliverPolicy_StartAfter: + msgID, err := message.UnmarshalMessageID(name, policy.StartAfter.GetId()) + if err != nil { + return nil, err + } + return options.DeliverPolicyStartAfter(msgID), nil + default: + return nil, errors.New("unknown deliver policy") + } +} + +// NewProtoFromDeliverPolicy converts DeliverPolicy to protobuf DeliverPolicy +func NewProtoFromDeliverPolicy(policy options.DeliverPolicy) (*streamingpb.DeliverPolicy, error) { + switch policy.Policy() { + case options.DeliverPolicyTypeAll: + return &streamingpb.DeliverPolicy{ + Policy: &streamingpb.DeliverPolicy_All{}, + }, nil + case options.DeliverPolicyTypeLatest: + return &streamingpb.DeliverPolicy{ + Policy: &streamingpb.DeliverPolicy_Latest{}, + }, nil + case options.DeliverPolicyTypeStartFrom: + return &streamingpb.DeliverPolicy{ + Policy: &streamingpb.DeliverPolicy_StartFrom{ + StartFrom: &streamingpb.MessageID{ + Id: policy.MessageID().Marshal(), + }, + }, + }, nil + case options.DeliverPolicyTypeStartAfter: + return &streamingpb.DeliverPolicy{ + Policy: &streamingpb.DeliverPolicy_StartAfter{ + StartAfter: &streamingpb.MessageID{ + Id: policy.MessageID().Marshal(), + }, + }, + }, nil + default: + return nil, errors.New("unknown deliver policy") + } +} + +// NewProtosFromDeliverFilters converts DeliverFilter to protobuf DeliverFilter +func NewProtosFromDeliverFilters(filter []options.DeliverFilter) ([]*streamingpb.DeliverFilter, error) { + protos := make([]*streamingpb.DeliverFilter, 0, len(filter)) + for _, f := range filter { + proto, err := NewProtoFromDeliverFilter(f) + if err != nil { + return nil, err + } + protos = append(protos, proto) + } + return protos, nil +} + +// NewProtoFromDeliverFilter converts DeliverFilter to protobuf DeliverFilter +func NewProtoFromDeliverFilter(filter options.DeliverFilter) (*streamingpb.DeliverFilter, error) { + switch filter.Type() { + case options.DeliverFilterTypeTimeTickGT: + return &streamingpb.DeliverFilter{ + Filter: &streamingpb.DeliverFilter_TimeTickGt{ + TimeTickGt: &streamingpb.DeliverFilterTimeTickGT{ + TimeTick: filter.(interface{ TimeTick() uint64 }).TimeTick(), + }, + }, + }, nil + case options.DeliverFilterTypeTimeTickGTE: + return &streamingpb.DeliverFilter{ + Filter: &streamingpb.DeliverFilter_TimeTickGte{ + TimeTickGte: &streamingpb.DeliverFilterTimeTickGTE{ + TimeTick: filter.(interface{ TimeTick() uint64 }).TimeTick(), + }, + }, + }, nil + case options.DeliverFilterTypeVChannel: + return &streamingpb.DeliverFilter{ + Filter: &streamingpb.DeliverFilter_Vchannel{ + Vchannel: &streamingpb.DeliverFilterVChannel{ + Vchannel: filter.(interface{ VChannel() string }).VChannel(), + }, + }, + }, nil + default: + return nil, errors.New("unknown deliver filter") + } +} + +// NewDeliverFiltersFromProtos converts protobuf DeliverFilter to DeliverFilter +func NewDeliverFiltersFromProtos(protos []*streamingpb.DeliverFilter) ([]options.DeliverFilter, error) { + filters := make([]options.DeliverFilter, 0, len(protos)) + for _, p := range protos { + f, err := NewDeliverFilterFromProto(p) + if err != nil { + return nil, err + } + filters = append(filters, f) + } + return filters, nil +} + +// NewDeliverFilterFromProto converts protobuf DeliverFilter to DeliverFilter +func NewDeliverFilterFromProto(proto *streamingpb.DeliverFilter) (options.DeliverFilter, error) { + switch proto.Filter.(type) { + case *streamingpb.DeliverFilter_TimeTickGt: + return options.DeliverFilterTimeTickGT(proto.GetTimeTickGt().GetTimeTick()), nil + case *streamingpb.DeliverFilter_TimeTickGte: + return options.DeliverFilterTimeTickGTE(proto.GetTimeTickGte().GetTimeTick()), nil + case *streamingpb.DeliverFilter_Vchannel: + return options.DeliverFilterVChannel(proto.GetVchannel().GetVchannel()), nil + default: + return nil, errors.New("unknown deliver filter") + } +} diff --git a/internal/util/streamingutil/typeconverter/deliver_test.go b/internal/util/streamingutil/typeconverter/deliver_test.go new file mode 100644 index 0000000000..77ca100631 --- /dev/null +++ b/internal/util/streamingutil/typeconverter/deliver_test.go @@ -0,0 +1,73 @@ +package typeconverter + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/options" +) + +func TestDeliverFilter(t *testing.T) { + filters := []options.DeliverFilter{ + options.DeliverFilterTimeTickGT(1), + options.DeliverFilterTimeTickGTE(2), + options.DeliverFilterVChannel("vchannel"), + } + pbFilters, err := NewProtosFromDeliverFilters(filters) + assert.NoError(t, err) + assert.Equal(t, len(filters), len(pbFilters)) + filters2, err := NewDeliverFiltersFromProtos(pbFilters) + assert.NoError(t, err) + assert.Equal(t, len(filters), len(filters2)) + for idx, filter := range filters { + filter2 := filters2[idx] + assert.Equal(t, filter.Type(), filter2.Type()) + switch filter.Type() { + case options.DeliverFilterTypeTimeTickGT: + assert.Equal(t, filter.(interface{ TimeTick() uint64 }).TimeTick(), filter2.(interface{ TimeTick() uint64 }).TimeTick()) + case options.DeliverFilterTypeTimeTickGTE: + assert.Equal(t, filter.(interface{ TimeTick() uint64 }).TimeTick(), filter2.(interface{ TimeTick() uint64 }).TimeTick()) + case options.DeliverFilterTypeVChannel: + assert.Equal(t, filter.(interface{ VChannel() string }).VChannel(), filter2.(interface{ VChannel() string }).VChannel()) + } + } +} + +func TestDeliverPolicy(t *testing.T) { + policy := options.DeliverPolicyAll() + pbPolicy, err := NewProtoFromDeliverPolicy(policy) + assert.NoError(t, err) + policy2, err := NewDeliverPolicyFromProto("mock", pbPolicy) + assert.NoError(t, err) + assert.Equal(t, policy.Policy(), policy2.Policy()) + + policy = options.DeliverPolicyLatest() + pbPolicy, err = NewProtoFromDeliverPolicy(policy) + assert.NoError(t, err) + policy2, err = NewDeliverPolicyFromProto("mock", pbPolicy) + assert.NoError(t, err) + assert.Equal(t, policy.Policy(), policy2.Policy()) + + msgID := mock_message.NewMockMessageID(t) + msgID.EXPECT().Marshal().Return([]byte("mock")) + message.RegisterMessageIDUnmsarshaler("mock", func(b []byte) (message.MessageID, error) { + return msgID, nil + }) + + policy = options.DeliverPolicyStartFrom(msgID) + pbPolicy, err = NewProtoFromDeliverPolicy(policy) + assert.NoError(t, err) + policy2, err = NewDeliverPolicyFromProto("mock", pbPolicy) + assert.NoError(t, err) + assert.Equal(t, policy.Policy(), policy2.Policy()) + + policy = options.DeliverPolicyStartAfter(msgID) + pbPolicy, err = NewProtoFromDeliverPolicy(policy) + assert.NoError(t, err) + policy2, err = NewDeliverPolicyFromProto("mock", pbPolicy) + assert.NoError(t, err) + assert.Equal(t, policy.Policy(), policy2.Policy()) +} diff --git a/internal/util/streamingutil/typeconverter/pchannel_info.go b/internal/util/streamingutil/typeconverter/pchannel_info.go new file mode 100644 index 0000000000..267b467180 --- /dev/null +++ b/internal/util/streamingutil/typeconverter/pchannel_info.go @@ -0,0 +1,34 @@ +package typeconverter + +import ( + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +// NewPChannelInfoFromProto converts protobuf PChannelInfo to PChannelInfo +func NewPChannelInfoFromProto(pchannel *streamingpb.PChannelInfo) types.PChannelInfo { + if pchannel.GetName() == "" { + panic("pchannel name is empty") + } + if pchannel.GetTerm() <= 0 { + panic("pchannel term is empty or negetive") + } + return types.PChannelInfo{ + Name: pchannel.GetName(), + Term: pchannel.GetTerm(), + } +} + +// NewProtoFromPChannelInfo converts PChannelInfo to protobuf PChannelInfo +func NewProtoFromPChannelInfo(pchannel types.PChannelInfo) *streamingpb.PChannelInfo { + if pchannel.Name == "" { + panic("pchannel name is empty") + } + if pchannel.Term <= 0 { + panic("pchannel term is empty or negetive") + } + return &streamingpb.PChannelInfo{ + Name: pchannel.Name, + Term: pchannel.Term, + } +} diff --git a/internal/util/streamingutil/typeconverter/pchannel_info_test.go b/internal/util/streamingutil/typeconverter/pchannel_info_test.go new file mode 100644 index 0000000000..7aeeeb441e --- /dev/null +++ b/internal/util/streamingutil/typeconverter/pchannel_info_test.go @@ -0,0 +1,34 @@ +package typeconverter + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +func TestPChannelInfo(t *testing.T) { + info := types.PChannelInfo{Name: "pchannel", Term: 1} + pbInfo := NewProtoFromPChannelInfo(info) + + info2 := NewPChannelInfoFromProto(pbInfo) + assert.Equal(t, info.Name, info2.Name) + assert.Equal(t, info.Term, info2.Term) + + assert.Panics(t, func() { + NewProtoFromPChannelInfo(types.PChannelInfo{Name: "", Term: 1}) + }) + assert.Panics(t, func() { + NewProtoFromPChannelInfo(types.PChannelInfo{Name: "c", Term: -1}) + }) + + assert.Panics(t, func() { + NewPChannelInfoFromProto(&streamingpb.PChannelInfo{Name: "", Term: 1}) + }) + + assert.Panics(t, func() { + NewPChannelInfoFromProto(&streamingpb.PChannelInfo{Name: "c", Term: -1}) + }) +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 91304fbb1e..0183c446da 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -85,6 +85,7 @@ const ( collectionIDLabelName = "collection_id" partitionIDLabelName = "partition_id" channelNameLabelName = "channel_name" + channelTermLabelName = "channel_term" functionLabelName = "function_name" queryTypeLabelName = "query_type" collectionName = "collection_name" diff --git a/pkg/metrics/streaming_service_metrics.go b/pkg/metrics/streaming_service_metrics.go index b929a08def..9e2899d7f8 100644 --- a/pkg/metrics/streaming_service_metrics.go +++ b/pkg/metrics/streaming_service_metrics.go @@ -60,11 +60,6 @@ var ( Help: "Total of pchannels", }) - // StreamingCoordVChannelTotal = newStreamingCoordGaugeVec(prometheus.GaugeOpts{ - // Name: "vchannel_total", - // Help: "Total of vchannels", - // }) - StreamingCoordAssignmentListenerTotal = newStreamingCoordGaugeVec(prometheus.GaugeOpts{ Name: "assignment_listener_total", Help: "Total of assignment listener", @@ -95,19 +90,19 @@ var ( Name: "produce_bytes", Help: "Bytes of produced message", Buckets: bytesBuckets, - }) + }, channelNameLabelName, channelTermLabelName, statusLabelName) StreamingNodeConsumeBytes = newStreamingNodeHistogramVec(prometheus.HistogramOpts{ Name: "consume_bytes", Help: "Bytes of consumed message", Buckets: bytesBuckets, - }) + }, channelNameLabelName, channelTermLabelName) StreamingNodeProduceDurationSeconds = newStreamingNodeHistogramVec(prometheus.HistogramOpts{ Name: "produce_duration_seconds", Help: "Duration of producing message", Buckets: secondsBuckets, - }, statusLabelName) + }, channelNameLabelName, channelTermLabelName, statusLabelName) ) func RegisterStreamingServiceClient(registry *prometheus.Registry) { diff --git a/pkg/mocks/streaming/mock_walimpls/mock_WALImpls.go b/pkg/mocks/streaming/mock_walimpls/mock_WALImpls.go index ab08675849..f85f320cb8 100644 --- a/pkg/mocks/streaming/mock_walimpls/mock_WALImpls.go +++ b/pkg/mocks/streaming/mock_walimpls/mock_WALImpls.go @@ -209,6 +209,47 @@ func (_c *MockWALImpls_Read_Call) RunAndReturn(run func(context.Context, walimpl return _c } +// WALName provides a mock function with given fields: +func (_m *MockWALImpls) WALName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockWALImpls_WALName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WALName' +type MockWALImpls_WALName_Call struct { + *mock.Call +} + +// WALName is a helper method to define mock.On call +func (_e *MockWALImpls_Expecter) WALName() *MockWALImpls_WALName_Call { + return &MockWALImpls_WALName_Call{Call: _e.mock.On("WALName")} +} + +func (_c *MockWALImpls_WALName_Call) Run(run func()) *MockWALImpls_WALName_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockWALImpls_WALName_Call) Return(_a0 string) *MockWALImpls_WALName_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWALImpls_WALName_Call) RunAndReturn(run func() string) *MockWALImpls_WALName_Call { + _c.Call.Return(run) + return _c +} + // NewMockWALImpls creates a new instance of MockWALImpls. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockWALImpls(t interface { diff --git a/pkg/streaming/util/options/deliver.go b/pkg/streaming/util/options/deliver.go index ebdd7554cf..71e1416629 100644 --- a/pkg/streaming/util/options/deliver.go +++ b/pkg/streaming/util/options/deliver.go @@ -5,52 +5,28 @@ import ( ) const ( - deliverOrderTimetick DeliverOrder = 1 + DeliverPolicyTypeAll deliverPolicyType = 1 + DeliverPolicyTypeLatest deliverPolicyType = 2 + DeliverPolicyTypeStartFrom deliverPolicyType = 3 + DeliverPolicyTypeStartAfter deliverPolicyType = 4 - DeliverPolicyTypeAll DeliverPolicyType = 1 - DeliverPolicyTypeLatest DeliverPolicyType = 2 - DeliverPolicyTypeStartFrom DeliverPolicyType = 3 - DeliverPolicyTypeStartAfter DeliverPolicyType = 4 + DeliverFilterTypeTimeTickGT deliverFilterType = 1 + DeliverFilterTypeTimeTickGTE deliverFilterType = 2 + DeliverFilterTypeVChannel deliverFilterType = 3 ) -// DeliverOrder is the order of delivering messages. type ( - DeliverOrder int - DeliverPolicyType int + deliverPolicyType int + deliverFilterType int ) // DeliverPolicy is the policy of delivering messages. type DeliverPolicy interface { - Policy() DeliverPolicyType + Policy() deliverPolicyType MessageID() message.MessageID } -type deliverPolicyWithoutMessageID struct { - policy DeliverPolicyType -} - -func (d *deliverPolicyWithoutMessageID) Policy() DeliverPolicyType { - return d.policy -} - -func (d *deliverPolicyWithoutMessageID) MessageID() message.MessageID { - panic("not implemented") -} - -type deliverPolicyWithMessageID struct { - policy DeliverPolicyType - messageID message.MessageID -} - -func (d *deliverPolicyWithMessageID) Policy() DeliverPolicyType { - return d.policy -} - -func (d *deliverPolicyWithMessageID) MessageID() message.MessageID { - return d.messageID -} - // DeliverPolicyAll delivers all messages. func DeliverPolicyAll() DeliverPolicy { return &deliverPolicyWithoutMessageID{ @@ -81,7 +57,34 @@ func DeliverPolicyStartAfter(messageID message.MessageID) DeliverPolicy { } } -// DeliverOrderTimeTick delivers messages by time tick. -func DeliverOrderTimeTick() DeliverOrder { - return deliverOrderTimetick +// DeliverFilter is the filter of delivering messages. +type DeliverFilter interface { + Type() deliverFilterType + + Filter(message.ImmutableMessage) bool +} + +// +// DeliverFilters +// + +// DeliverFilterTimeTickGT delivers messages by time tick greater than the specified time tick. +func DeliverFilterTimeTickGT(timeTick uint64) DeliverFilter { + return &deliverFilterTimeTickGT{ + timeTick: timeTick, + } +} + +// DeliverFilterTimeTickGTE delivers messages by time tick greater than or equal to the specified time tick. +func DeliverFilterTimeTickGTE(timeTick uint64) DeliverFilter { + return &deliverFilterTimeTickGTE{ + timeTick: timeTick, + } +} + +// DeliverFilterVChannel delivers messages filtered by vchannel. +func DeliverFilterVChannel(vchannel string) DeliverFilter { + return &deliverFilterVChannel{ + vchannel: vchannel, + } } diff --git a/pkg/streaming/util/options/deliver_impl.go b/pkg/streaming/util/options/deliver_impl.go new file mode 100644 index 0000000000..e6e99abc1c --- /dev/null +++ b/pkg/streaming/util/options/deliver_impl.go @@ -0,0 +1,81 @@ +package options + +import "github.com/milvus-io/milvus/pkg/streaming/util/message" + +// deliverPolicyWithoutMessageID is the policy of delivering messages without messageID. +type deliverPolicyWithoutMessageID struct { + policy deliverPolicyType +} + +func (d *deliverPolicyWithoutMessageID) Policy() deliverPolicyType { + return d.policy +} + +func (d *deliverPolicyWithoutMessageID) MessageID() message.MessageID { + panic("not implemented") +} + +// deliverPolicyWithMessageID is the policy of delivering messages with messageID. +type deliverPolicyWithMessageID struct { + policy deliverPolicyType + messageID message.MessageID +} + +func (d *deliverPolicyWithMessageID) Policy() deliverPolicyType { + return d.policy +} + +func (d *deliverPolicyWithMessageID) MessageID() message.MessageID { + return d.messageID +} + +// deliverFilterTimeTickGT delivers messages by time tick greater than the specified time tick. +type deliverFilterTimeTickGT struct { + timeTick uint64 +} + +func (f *deliverFilterTimeTickGT) Type() deliverFilterType { + return DeliverFilterTypeTimeTickGT +} + +func (f *deliverFilterTimeTickGT) TimeTick() uint64 { + return f.timeTick +} + +func (f *deliverFilterTimeTickGT) Filter(msg message.ImmutableMessage) bool { + return msg.TimeTick() > f.timeTick +} + +// deliverFilterTimeTickGTE delivers messages by time tick greater than or equal to the specified time tick. +type deliverFilterTimeTickGTE struct { + timeTick uint64 +} + +func (f *deliverFilterTimeTickGTE) Type() deliverFilterType { + return DeliverFilterTypeTimeTickGTE +} + +func (f *deliverFilterTimeTickGTE) TimeTick() uint64 { + return f.timeTick +} + +func (f *deliverFilterTimeTickGTE) Filter(msg message.ImmutableMessage) bool { + return msg.TimeTick() >= f.timeTick +} + +// deliverFilterVChannel delivers messages by vchannel. +type deliverFilterVChannel struct { + vchannel string +} + +func (f *deliverFilterVChannel) Type() deliverFilterType { + return DeliverFilterTypeVChannel +} + +func (f *deliverFilterVChannel) VChannel() string { + return f.vchannel +} + +func (f *deliverFilterVChannel) Filter(msg message.ImmutableMessage) bool { + return msg.VChannel() == f.vchannel +} diff --git a/pkg/streaming/util/options/deliver_test.go b/pkg/streaming/util/options/deliver_test.go index 721e0cdc31..bf72ab5f29 100644 --- a/pkg/streaming/util/options/deliver_test.go +++ b/pkg/streaming/util/options/deliver_test.go @@ -8,7 +8,7 @@ import ( "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" ) -func TestDeliver(t *testing.T) { +func TestDeliverPolicy(t *testing.T) { policy := DeliverPolicyAll() assert.Equal(t, DeliverPolicyTypeAll, policy.Policy()) assert.Panics(t, func() { @@ -30,3 +30,35 @@ func TestDeliver(t *testing.T) { assert.Equal(t, DeliverPolicyTypeStartAfter, policy.Policy()) assert.Equal(t, messageID, policy.MessageID()) } + +func TestDeliverFilter(t *testing.T) { + filter := DeliverFilterTimeTickGT(1) + assert.Equal(t, uint64(1), filter.(interface{ TimeTick() uint64 }).TimeTick()) + assert.Equal(t, DeliverFilterTypeTimeTickGT, filter.Type()) + msg := mock_message.NewMockImmutableMessage(t) + msg.EXPECT().TimeTick().Return(uint64(1)) + assert.False(t, filter.Filter(msg)) + msg.EXPECT().TimeTick().Unset() + msg.EXPECT().TimeTick().Return(uint64(2)) + assert.True(t, filter.Filter(msg)) + + filter = DeliverFilterTimeTickGTE(2) + assert.Equal(t, uint64(2), filter.(interface{ TimeTick() uint64 }).TimeTick()) + assert.Equal(t, DeliverFilterTypeTimeTickGTE, filter.Type()) + msg.EXPECT().TimeTick().Unset() + msg.EXPECT().TimeTick().Return(uint64(1)) + assert.False(t, filter.Filter(msg)) + msg.EXPECT().TimeTick().Unset() + msg.EXPECT().TimeTick().Return(uint64(2)) + assert.True(t, filter.Filter(msg)) + + filter = DeliverFilterVChannel("vchannel") + assert.Equal(t, "vchannel", filter.(interface{ VChannel() string }).VChannel()) + assert.Equal(t, DeliverFilterTypeVChannel, filter.Type()) + msg.EXPECT().VChannel().Unset() + msg.EXPECT().VChannel().Return("vchannel2") + assert.False(t, filter.Filter(msg)) + msg.EXPECT().VChannel().Unset() + msg.EXPECT().VChannel().Return("vchannel") + assert.True(t, filter.Filter(msg)) +} diff --git a/pkg/streaming/util/types/pchannel_info.go b/pkg/streaming/util/types/pchannel_info.go index 66656450c6..1da295d639 100644 --- a/pkg/streaming/util/types/pchannel_info.go +++ b/pkg/streaming/util/types/pchannel_info.go @@ -6,7 +6,6 @@ const ( // PChannelInfo is the struct for pchannel info. type PChannelInfo struct { - Name string // name of pchannel. - Term int64 // term of pchannel. - ServerID int64 // assigned streaming node server id of pchannel. + Name string // name of pchannel. + Term int64 // term of pchannel. } diff --git a/pkg/streaming/walimpls/helper/wal_helper_test.go b/pkg/streaming/walimpls/helper/wal_helper_test.go index 7917a7adde..e3b9c1b79b 100644 --- a/pkg/streaming/walimpls/helper/wal_helper_test.go +++ b/pkg/streaming/walimpls/helper/wal_helper_test.go @@ -12,9 +12,8 @@ import ( func TestWALHelper(t *testing.T) { h := NewWALHelper(&walimpls.OpenOption{ Channel: types.PChannelInfo{ - Name: "test", - Term: 1, - ServerID: 1, + Name: "test", + Term: 1, }, }) assert.NotNil(t, h.Channel()) diff --git a/pkg/streaming/walimpls/impls/pulsar/wal.go b/pkg/streaming/walimpls/impls/pulsar/wal.go index 54e073f8f9..60a753d63e 100644 --- a/pkg/streaming/walimpls/impls/pulsar/wal.go +++ b/pkg/streaming/walimpls/impls/pulsar/wal.go @@ -20,6 +20,10 @@ type walImpl struct { p pulsar.Producer } +func (w *walImpl) WALName() string { + return walName +} + func (w *walImpl) Append(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) { id, err := w.p.Send(ctx, &pulsar.ProducerMessage{ Payload: msg.Payload(), diff --git a/pkg/streaming/walimpls/impls/rmq/wal.go b/pkg/streaming/walimpls/impls/rmq/wal.go index a00b9ef043..16a9cee0e3 100644 --- a/pkg/streaming/walimpls/impls/rmq/wal.go +++ b/pkg/streaming/walimpls/impls/rmq/wal.go @@ -24,6 +24,10 @@ type walImpl struct { c client.Client } +func (w *walImpl) WALName() string { + return walName +} + // Append appends a message to the wal. func (w *walImpl) Append(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) { id, err := w.p.Send(&common.ProducerMessage{ diff --git a/pkg/streaming/walimpls/impls/walimplstest/wal.go b/pkg/streaming/walimpls/impls/walimplstest/wal.go index 595ad8aec3..0dd3448685 100644 --- a/pkg/streaming/walimpls/impls/walimplstest/wal.go +++ b/pkg/streaming/walimpls/impls/walimplstest/wal.go @@ -19,6 +19,10 @@ type walImpls struct { datas *messageLog } +func (w *walImpls) WALName() string { + return WALName +} + func (w *walImpls) Append(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) { return w.datas.Append(ctx, msg) } diff --git a/pkg/streaming/walimpls/test_framework.go b/pkg/streaming/walimpls/test_framework.go index 89f6131d45..7b345e94e3 100644 --- a/pkg/streaming/walimpls/test_framework.go +++ b/pkg/streaming/walimpls/test_framework.go @@ -98,9 +98,8 @@ func (f *testOneWALImplsFramework) Run() { // test a read write loop for ; f.term <= 3; f.term++ { pChannel := types.PChannelInfo{ - Name: f.pchannel, - Term: int64(f.term), - ServerID: 1, + Name: f.pchannel, + Term: int64(f.term), } // create a wal. w, err := f.opener.Open(ctx, &OpenOption{ @@ -109,7 +108,6 @@ func (f *testOneWALImplsFramework) Run() { assert.NoError(f.t, err) assert.NotNil(f.t, w) assert.Equal(f.t, pChannel.Name, w.Channel().Name) - assert.Equal(f.t, pChannel.ServerID, w.Channel().ServerID) assert.Equal(f.t, pChannel.Term, w.Channel().Term) f.testReadAndWrite(ctx, w) diff --git a/pkg/streaming/walimpls/wal.go b/pkg/streaming/walimpls/wal.go index 6d65fd1606..64c87f7d2c 100644 --- a/pkg/streaming/walimpls/wal.go +++ b/pkg/streaming/walimpls/wal.go @@ -8,6 +8,9 @@ import ( ) type WALImpls interface { + // WALName returns the name of the wal. + WALName() string + // Channel returns the channel assignment info of the wal. // Should be read-only. Channel() types.PChannelInfo