diff --git a/internal/.mockery.yaml b/internal/.mockery.yaml index d50d113192..646dbb2b53 100644 --- a/internal/.mockery.yaml +++ b/internal/.mockery.yaml @@ -34,6 +34,9 @@ packages: Interceptor: InterceptorWithReady: InterceptorBuilder: + github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector: + interfaces: + SealOperator: google.golang.org/grpc: interfaces: ClientStream: @@ -54,6 +57,7 @@ packages: github.com/milvus-io/milvus/internal/metastore: interfaces: StreamingCoordCataLog: + StreamingNodeCataLog: github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer: interfaces: Discoverer: @@ -71,3 +75,4 @@ packages: google.golang.org/grpc/balancer: interfaces: SubConn: + diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go index 3aa64919b9..6a6081b5b3 100644 --- a/internal/datacoord/services_test.go +++ b/internal/datacoord/services_test.go @@ -475,6 +475,10 @@ func (s *ServerSuite) TestFlush_NormalCase() { expireTs := allocations[0].ExpireTime segID := allocations[0].SegmentID + info, err := s.testServer.segmentManager.AllocNewGrowingSegment(context.TODO(), 0, 1, 1, "channel-1") + s.NoError(err) + s.NotNil(info) + resp, err := s.testServer.Flush(context.TODO(), req) s.NoError(err) s.EqualValues(commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) diff --git a/internal/distributed/streamingnode/service.go b/internal/distributed/streamingnode/service.go index 101af807f9..27ee50a77f 100644 --- a/internal/distributed/streamingnode/service.go +++ b/internal/distributed/streamingnode/service.go @@ -27,6 +27,7 @@ import ( "github.com/cockroachdb/errors" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" + "github.com/tikv/client-go/v2/txnkv" clientv3 "go.etcd.io/etcd/client/v3" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.uber.org/zap" @@ -36,19 +37,24 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" dcc "github.com/milvus-io/milvus/internal/distributed/datacoord/client" rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + tikvkv "github.com/milvus-io/milvus/internal/kv/tikv" streamingnodeserver "github.com/milvus-io/milvus/internal/streamingnode/server" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/componentutil" kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv" "github.com/milvus-io/milvus/internal/util/sessionutil" streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor" + "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/tracer" + "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/interceptor" "github.com/milvus-io/milvus/pkg/util/logutil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/retry" + "github.com/milvus-io/milvus/pkg/util/tikv" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -59,6 +65,7 @@ type Server struct { // session of current server. 
session *sessionutil.Session + metaKV kv.MetaKv // server streamingnode *streamingnodeserver.Server @@ -69,6 +76,7 @@ type Server struct { // component client etcdCli *clientv3.Client + tikvCli *txnkv.Client rootCoord types.RootCoordClient dataCoord types.DataCoordClient } @@ -112,14 +120,14 @@ func (s *Server) stop() { log.Warn("streamingnode unregister session failed", zap.Error(err)) } - // Stop grpc server. - log.Info("streamingnode stop grpc server...") - s.grpcServer.GracefulStop() - // Stop StreamingNode service. log.Info("streamingnode stop service...") s.streamingnode.Stop() + // Stop grpc server. + log.Info("streamingnode stop grpc server...") + s.grpcServer.GracefulStop() + // Stop all session log.Info("streamingnode stop session...") s.session.Stop() @@ -130,6 +138,13 @@ func (s *Server) stop() { log.Warn("streamingnode stop rootCoord client failed", zap.Error(err)) } + // Stop tikv + if s.tikvCli != nil { + if err := s.tikvCli.Close(); err != nil { + log.Warn("streamingnode stop tikv client failed", zap.Error(err)) + } + } + // Wait for grpc server to stop. log.Info("wait for grpc server stop...") <-s.grpcServerChan @@ -153,6 +168,9 @@ func (s *Server) init() (err error) { // Create etcd client. s.etcdCli, _ = kvfactory.GetEtcdAndPath() + if err := s.initMeta(); err != nil { + return err + } if err := s.allocateAddress(); err != nil { return err } @@ -174,6 +192,7 @@ func (s *Server) init() (err error) { WithRootCoordClient(s.rootCoord). WithDataCoordClient(s.dataCoord). WithSession(s.session). + WithMetaKV(s.metaKV). Build() if err := s.streamingnode.Init(context.Background()); err != nil { return errors.Wrap(err, "StreamingNode service init failed") @@ -218,6 +237,29 @@ func (s *Server) initSession() error { return nil } +func (s *Server) initMeta() error { + params := paramtable.Get() + metaType := params.MetaStoreCfg.MetaStoreType.GetValue() + log.Info("data coordinator connecting to metadata store", zap.String("metaType", metaType)) + metaRootPath := "" + if metaType == util.MetaStoreTypeTiKV { + var err error + s.tikvCli, err = tikv.GetTiKVClient(¶mtable.Get().TiKVCfg) + if err != nil { + log.Warn("Streamingnode init tikv client failed", zap.Error(err)) + return err + } + metaRootPath = params.TiKVCfg.MetaRootPath.GetValue() + s.metaKV = tikvkv.NewTiKV(s.tikvCli, metaRootPath, + tikvkv.WithRequestTimeout(paramtable.Get().ServiceParam.TiKVCfg.RequestTimeout.GetAsDuration(time.Millisecond))) + } else if metaType == util.MetaStoreTypeEtcd { + metaRootPath = params.EtcdCfg.MetaRootPath.GetValue() + s.metaKV = etcdkv.NewEtcdKV(s.etcdCli, metaRootPath, + etcdkv.WithRequestTimeout(paramtable.Get().ServiceParam.EtcdCfg.RequestTimeout.GetAsDuration(time.Millisecond))) + } + return nil +} + func (s *Server) initRootCoord() (err error) { log.Info("StreamingNode connect to rootCoord...") s.rootCoord, err = rcc.NewClient(context.Background()) diff --git a/internal/metastore/catalog.go b/internal/metastore/catalog.go index 1e3e1cf5c7..bb2ed9316d 100644 --- a/internal/metastore/catalog.go +++ b/internal/metastore/catalog.go @@ -198,3 +198,10 @@ type StreamingCoordCataLog interface { // SavePChannel save a pchannel info to metastore. 
SavePChannels(ctx context.Context, info []*streamingpb.PChannelMeta) error } + +// StreamingNodeCataLog is the interface for streamingnode catalog +type StreamingNodeCataLog interface { + ListSegmentAssignment(ctx context.Context, pChannelName string) ([]*streamingpb.SegmentAssignmentMeta, error) + + SaveSegmentAssignments(ctx context.Context, pChannelName string, infos []*streamingpb.SegmentAssignmentMeta) error +} diff --git a/internal/metastore/kv/streamingnode/constant.go b/internal/metastore/kv/streamingnode/constant.go new file mode 100644 index 0000000000..111b0ef7fd --- /dev/null +++ b/internal/metastore/kv/streamingnode/constant.go @@ -0,0 +1,6 @@ +package streamingnode + +const ( + MetaPrefix = "streamingnode-meta" + SegmentAssignMeta = MetaPrefix + "/segment-assign" +) diff --git a/internal/metastore/kv/streamingnode/kv_catalog.go b/internal/metastore/kv/streamingnode/kv_catalog.go new file mode 100644 index 0000000000..37b2250113 --- /dev/null +++ b/internal/metastore/kv/streamingnode/kv_catalog.go @@ -0,0 +1,92 @@ +package streamingnode + +import ( + "context" + "path" + "strconv" + + "github.com/cockroachdb/errors" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus/internal/metastore" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/kv" + "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/etcd" +) + +// NewCataLog creates a new catalog instance +func NewCataLog(metaKV kv.MetaKv) metastore.StreamingNodeCataLog { + return &catalog{ + metaKV: metaKV, + } +} + +// catalog is a kv based catalog. +type catalog struct { + metaKV kv.MetaKv +} + +func (c *catalog) ListSegmentAssignment(ctx context.Context, pChannelName string) ([]*streamingpb.SegmentAssignmentMeta, error) { + prefix := buildSegmentAssignmentMetaPath(pChannelName) + keys, values, err := c.metaKV.LoadWithPrefix(prefix) + if err != nil { + return nil, err + } + + infos := make([]*streamingpb.SegmentAssignmentMeta, 0, len(values)) + for k, value := range values { + info := &streamingpb.SegmentAssignmentMeta{} + if err = proto.Unmarshal([]byte(value), info); err != nil { + return nil, errors.Wrapf(err, "unmarshal pchannel %s failed", keys[k]) + } + infos = append(infos, info) + } + return infos, nil +} + +// SaveSegmentAssignments saves the segment assignment info to meta storage. 
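Aside: the new catalog persists one protobuf blob per segment under `streamingnode-meta/segment-assign/<pchannel>/<segmentID>`, and `SaveSegmentAssignments` (implemented just below) deletes FLUSHED entries instead of rewriting them. A minimal sketch of how a caller might round-trip assignments through the interface; the function and pchannel name are illustrative, and the catalog value is assumed to come from `streamingnode.NewCataLog(metaKV)`:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/metastore"
	"github.com/milvus-io/milvus/internal/proto/streamingpb"
)

// persistAndRecover is illustrative only: save two assignments, then read back
// what survives. The FLUSHED one is removed from meta; the GROWING one is
// (re)written under streamingnode-meta/segment-assign/pchannel-1/<segmentID>.
func persistAndRecover(ctx context.Context, catalog metastore.StreamingNodeCataLog) error {
	err := catalog.SaveSegmentAssignments(ctx, "pchannel-1", []*streamingpb.SegmentAssignmentMeta{
		{SegmentId: 1, State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED},
		{SegmentId: 2, State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING},
	})
	if err != nil {
		return err
	}

	// On recovery, only the non-flushed assignments come back.
	metas, err := catalog.ListSegmentAssignment(ctx, "pchannel-1")
	if err != nil {
		return err
	}
	_ = metas // e.g. rebuild the in-memory segment assign managers from metas
	return nil
}
```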
+func (c *catalog) SaveSegmentAssignments(ctx context.Context, pChannelName string, infos []*streamingpb.SegmentAssignmentMeta) error { + kvs := make(map[string]string, len(infos)) + removes := make([]string, 0) + for _, info := range infos { + key := buildSegmentAssignmentMetaPathOfSegment(pChannelName, info.GetSegmentId()) + if info.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED { + // Flushed segment should be removed from meta + removes = append(removes, key) + continue + } + + data, err := proto.Marshal(info) + if err != nil { + return errors.Wrapf(err, "marshal segment %d at pchannel %s failed", info.GetSegmentId(), pChannelName) + } + kvs[key] = string(data) + } + + if len(removes) > 0 { + if err := etcd.RemoveByBatchWithLimit(removes, util.MaxEtcdTxnNum, func(partialRemoves []string) error { + return c.metaKV.MultiRemove(partialRemoves) + }); err != nil { + return err + } + } + + if len(kvs) > 0 { + return etcd.SaveByBatchWithLimit(kvs, util.MaxEtcdTxnNum, func(partialKvs map[string]string) error { + return c.metaKV.MultiSave(partialKvs) + }) + } + return nil +} + +// buildSegmentAssignmentMetaPath builds the path for segment assignment +// streamingnode-meta/segment-assign/${pChannelName} +func buildSegmentAssignmentMetaPath(pChannelName string) string { + return path.Join(SegmentAssignMeta, pChannelName) +} + +// buildSegmentAssignmentMetaPathOfSegment builds the path for segment assignment +func buildSegmentAssignmentMetaPathOfSegment(pChannelName string, segmentID int64) string { + return path.Join(SegmentAssignMeta, pChannelName, strconv.FormatInt(segmentID, 10)) +} diff --git a/internal/metastore/kv/streamingnode/kv_catalog_test.go b/internal/metastore/kv/streamingnode/kv_catalog_test.go new file mode 100644 index 0000000000..b1685bf4e9 --- /dev/null +++ b/internal/metastore/kv/streamingnode/kv_catalog_test.go @@ -0,0 +1,43 @@ +package streamingnode + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus/internal/kv/mocks" + "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +func TestCatalog(t *testing.T) { + kv := mocks.NewMetaKv(t) + k := "p1" + v := streamingpb.SegmentAssignmentMeta{} + vs, err := proto.Marshal(&v) + assert.NoError(t, err) + + kv.EXPECT().LoadWithPrefix(mock.Anything).Return([]string{k}, []string{string(vs)}, nil) + catalog := NewCataLog(kv) + ctx := context.Background() + metas, err := catalog.ListSegmentAssignment(ctx, "p1") + assert.Len(t, metas, 1) + assert.NoError(t, err) + + kv.EXPECT().MultiRemove(mock.Anything).Return(nil) + kv.EXPECT().MultiSave(mock.Anything).Return(nil) + + err = catalog.SaveSegmentAssignments(ctx, "p1", []*streamingpb.SegmentAssignmentMeta{ + { + SegmentId: 1, + State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED, + }, + { + SegmentId: 2, + State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_PENDING, + }, + }) + assert.NoError(t, err) +} diff --git a/internal/mocks/mock_metastore/mock_StreamingNodeCataLog.go b/internal/mocks/mock_metastore/mock_StreamingNodeCataLog.go new file mode 100644 index 0000000000..ce000248ba --- /dev/null +++ b/internal/mocks/mock_metastore/mock_StreamingNodeCataLog.go @@ -0,0 +1,137 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mock_metastore + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + streamingpb "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +// MockStreamingNodeCataLog is an autogenerated mock type for the StreamingNodeCataLog type +type MockStreamingNodeCataLog struct { + mock.Mock +} + +type MockStreamingNodeCataLog_Expecter struct { + mock *mock.Mock +} + +func (_m *MockStreamingNodeCataLog) EXPECT() *MockStreamingNodeCataLog_Expecter { + return &MockStreamingNodeCataLog_Expecter{mock: &_m.Mock} +} + +// ListSegmentAssignment provides a mock function with given fields: ctx, pChannelName +func (_m *MockStreamingNodeCataLog) ListSegmentAssignment(ctx context.Context, pChannelName string) ([]*streamingpb.SegmentAssignmentMeta, error) { + ret := _m.Called(ctx, pChannelName) + + var r0 []*streamingpb.SegmentAssignmentMeta + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]*streamingpb.SegmentAssignmentMeta, error)); ok { + return rf(ctx, pChannelName) + } + if rf, ok := ret.Get(0).(func(context.Context, string) []*streamingpb.SegmentAssignmentMeta); ok { + r0 = rf(ctx, pChannelName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*streamingpb.SegmentAssignmentMeta) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, pChannelName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockStreamingNodeCataLog_ListSegmentAssignment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListSegmentAssignment' +type MockStreamingNodeCataLog_ListSegmentAssignment_Call struct { + *mock.Call +} + +// ListSegmentAssignment is a helper method to define mock.On call +// - ctx context.Context +// - pChannelName string +func (_e *MockStreamingNodeCataLog_Expecter) ListSegmentAssignment(ctx interface{}, pChannelName interface{}) *MockStreamingNodeCataLog_ListSegmentAssignment_Call { + return &MockStreamingNodeCataLog_ListSegmentAssignment_Call{Call: _e.mock.On("ListSegmentAssignment", ctx, pChannelName)} +} + +func (_c *MockStreamingNodeCataLog_ListSegmentAssignment_Call) Run(run func(ctx context.Context, pChannelName string)) *MockStreamingNodeCataLog_ListSegmentAssignment_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockStreamingNodeCataLog_ListSegmentAssignment_Call) Return(_a0 []*streamingpb.SegmentAssignmentMeta, _a1 error) *MockStreamingNodeCataLog_ListSegmentAssignment_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockStreamingNodeCataLog_ListSegmentAssignment_Call) RunAndReturn(run func(context.Context, string) ([]*streamingpb.SegmentAssignmentMeta, error)) *MockStreamingNodeCataLog_ListSegmentAssignment_Call { + _c.Call.Return(run) + return _c +} + +// SaveSegmentAssignments provides a mock function with given fields: ctx, pChannelName, infos +func (_m *MockStreamingNodeCataLog) SaveSegmentAssignments(ctx context.Context, pChannelName string, infos []*streamingpb.SegmentAssignmentMeta) error { + ret := _m.Called(ctx, pChannelName, infos) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, []*streamingpb.SegmentAssignmentMeta) error); ok { + r0 = rf(ctx, pChannelName, infos) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockStreamingNodeCataLog_SaveSegmentAssignments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveSegmentAssignments' +type 
MockStreamingNodeCataLog_SaveSegmentAssignments_Call struct { + *mock.Call +} + +// SaveSegmentAssignments is a helper method to define mock.On call +// - ctx context.Context +// - pChannelName string +// - infos []*streamingpb.SegmentAssignmentMeta +func (_e *MockStreamingNodeCataLog_Expecter) SaveSegmentAssignments(ctx interface{}, pChannelName interface{}, infos interface{}) *MockStreamingNodeCataLog_SaveSegmentAssignments_Call { + return &MockStreamingNodeCataLog_SaveSegmentAssignments_Call{Call: _e.mock.On("SaveSegmentAssignments", ctx, pChannelName, infos)} +} + +func (_c *MockStreamingNodeCataLog_SaveSegmentAssignments_Call) Run(run func(ctx context.Context, pChannelName string, infos []*streamingpb.SegmentAssignmentMeta)) *MockStreamingNodeCataLog_SaveSegmentAssignments_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].([]*streamingpb.SegmentAssignmentMeta)) + }) + return _c +} + +func (_c *MockStreamingNodeCataLog_SaveSegmentAssignments_Call) Return(_a0 error) *MockStreamingNodeCataLog_SaveSegmentAssignments_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockStreamingNodeCataLog_SaveSegmentAssignments_Call) RunAndReturn(run func(context.Context, string, []*streamingpb.SegmentAssignmentMeta) error) *MockStreamingNodeCataLog_SaveSegmentAssignments_Call { + _c.Call.Return(run) + return _c +} + +// NewMockStreamingNodeCataLog creates a new instance of MockStreamingNodeCataLog. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockStreamingNodeCataLog(t interface { + mock.TestingT + Cleanup(func()) +}) *MockStreamingNodeCataLog { + mock := &MockStreamingNodeCataLog{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/streamingnode/server/mock_walmanager/mock_Manager.go b/internal/mocks/streamingnode/server/mock_walmanager/mock_Manager.go index 16b3bbe6f1..4c12954e6c 100644 --- a/internal/mocks/streamingnode/server/mock_walmanager/mock_Manager.go +++ b/internal/mocks/streamingnode/server/mock_walmanager/mock_Manager.go @@ -109,17 +109,17 @@ func (_c *MockManager_GetAllAvailableChannels_Call) RunAndReturn(run func() ([]t return _c } -// GetAvailableWAL provides a mock function with given fields: _a0 -func (_m *MockManager) GetAvailableWAL(_a0 types.PChannelInfo) (wal.WAL, error) { - ret := _m.Called(_a0) +// GetAvailableWAL provides a mock function with given fields: channel +func (_m *MockManager) GetAvailableWAL(channel types.PChannelInfo) (wal.WAL, error) { + ret := _m.Called(channel) var r0 wal.WAL var r1 error if rf, ok := ret.Get(0).(func(types.PChannelInfo) (wal.WAL, error)); ok { - return rf(_a0) + return rf(channel) } if rf, ok := ret.Get(0).(func(types.PChannelInfo) wal.WAL); ok { - r0 = rf(_a0) + r0 = rf(channel) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(wal.WAL) @@ -127,7 +127,7 @@ func (_m *MockManager) GetAvailableWAL(_a0 types.PChannelInfo) (wal.WAL, error) } if rf, ok := ret.Get(1).(func(types.PChannelInfo) error); ok { - r1 = rf(_a0) + r1 = rf(channel) } else { r1 = ret.Error(1) } @@ -141,12 +141,12 @@ type MockManager_GetAvailableWAL_Call struct { } // GetAvailableWAL is a helper method to define mock.On call -// - _a0 types.PChannelInfo -func (_e *MockManager_Expecter) GetAvailableWAL(_a0 interface{}) *MockManager_GetAvailableWAL_Call { - return &MockManager_GetAvailableWAL_Call{Call: 
_e.mock.On("GetAvailableWAL", _a0)} +// - channel types.PChannelInfo +func (_e *MockManager_Expecter) GetAvailableWAL(channel interface{}) *MockManager_GetAvailableWAL_Call { + return &MockManager_GetAvailableWAL_Call{Call: _e.mock.On("GetAvailableWAL", channel)} } -func (_c *MockManager_GetAvailableWAL_Call) Run(run func(_a0 types.PChannelInfo)) *MockManager_GetAvailableWAL_Call { +func (_c *MockManager_GetAvailableWAL_Call) Run(run func(channel types.PChannelInfo)) *MockManager_GetAvailableWAL_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(types.PChannelInfo)) }) diff --git a/internal/mocks/streamingnode/server/wal/interceptors/segment/mock_inspector/mock_SealOperator.go b/internal/mocks/streamingnode/server/wal/interceptors/segment/mock_inspector/mock_SealOperator.go new file mode 100644 index 0000000000..697da341fe --- /dev/null +++ b/internal/mocks/streamingnode/server/wal/interceptors/segment/mock_inspector/mock_SealOperator.go @@ -0,0 +1,203 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_inspector + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + stats "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +// MockSealOperator is an autogenerated mock type for the SealOperator type +type MockSealOperator struct { + mock.Mock +} + +type MockSealOperator_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSealOperator) EXPECT() *MockSealOperator_Expecter { + return &MockSealOperator_Expecter{mock: &_m.Mock} +} + +// Channel provides a mock function with given fields: +func (_m *MockSealOperator) Channel() types.PChannelInfo { + ret := _m.Called() + + var r0 types.PChannelInfo + if rf, ok := ret.Get(0).(func() types.PChannelInfo); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(types.PChannelInfo) + } + + return r0 +} + +// MockSealOperator_Channel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Channel' +type MockSealOperator_Channel_Call struct { + *mock.Call +} + +// Channel is a helper method to define mock.On call +func (_e *MockSealOperator_Expecter) Channel() *MockSealOperator_Channel_Call { + return &MockSealOperator_Channel_Call{Call: _e.mock.On("Channel")} +} + +func (_c *MockSealOperator_Channel_Call) Run(run func()) *MockSealOperator_Channel_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSealOperator_Channel_Call) Return(_a0 types.PChannelInfo) *MockSealOperator_Channel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSealOperator_Channel_Call) RunAndReturn(run func() types.PChannelInfo) *MockSealOperator_Channel_Call { + _c.Call.Return(run) + return _c +} + +// IsNoWaitSeal provides a mock function with given fields: +func (_m *MockSealOperator) IsNoWaitSeal() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockSealOperator_IsNoWaitSeal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsNoWaitSeal' +type MockSealOperator_IsNoWaitSeal_Call struct { + *mock.Call +} + +// IsNoWaitSeal is a helper method to define mock.On call +func (_e *MockSealOperator_Expecter) IsNoWaitSeal() *MockSealOperator_IsNoWaitSeal_Call { + return &MockSealOperator_IsNoWaitSeal_Call{Call: _e.mock.On("IsNoWaitSeal")} +} + +func (_c *MockSealOperator_IsNoWaitSeal_Call) 
Run(run func()) *MockSealOperator_IsNoWaitSeal_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSealOperator_IsNoWaitSeal_Call) Return(_a0 bool) *MockSealOperator_IsNoWaitSeal_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSealOperator_IsNoWaitSeal_Call) RunAndReturn(run func() bool) *MockSealOperator_IsNoWaitSeal_Call { + _c.Call.Return(run) + return _c +} + +// TryToSealSegments provides a mock function with given fields: ctx, infos +func (_m *MockSealOperator) TryToSealSegments(ctx context.Context, infos ...stats.SegmentBelongs) { + _va := make([]interface{}, len(infos)) + for _i := range infos { + _va[_i] = infos[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + _m.Called(_ca...) +} + +// MockSealOperator_TryToSealSegments_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TryToSealSegments' +type MockSealOperator_TryToSealSegments_Call struct { + *mock.Call +} + +// TryToSealSegments is a helper method to define mock.On call +// - ctx context.Context +// - infos ...stats.SegmentBelongs +func (_e *MockSealOperator_Expecter) TryToSealSegments(ctx interface{}, infos ...interface{}) *MockSealOperator_TryToSealSegments_Call { + return &MockSealOperator_TryToSealSegments_Call{Call: _e.mock.On("TryToSealSegments", + append([]interface{}{ctx}, infos...)...)} +} + +func (_c *MockSealOperator_TryToSealSegments_Call) Run(run func(ctx context.Context, infos ...stats.SegmentBelongs)) *MockSealOperator_TryToSealSegments_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]stats.SegmentBelongs, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(stats.SegmentBelongs) + } + } + run(args[0].(context.Context), variadicArgs...) 
+ }) + return _c +} + +func (_c *MockSealOperator_TryToSealSegments_Call) Return() *MockSealOperator_TryToSealSegments_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSealOperator_TryToSealSegments_Call) RunAndReturn(run func(context.Context, ...stats.SegmentBelongs)) *MockSealOperator_TryToSealSegments_Call { + _c.Call.Return(run) + return _c +} + +// TryToSealWaitedSegment provides a mock function with given fields: ctx +func (_m *MockSealOperator) TryToSealWaitedSegment(ctx context.Context) { + _m.Called(ctx) +} + +// MockSealOperator_TryToSealWaitedSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TryToSealWaitedSegment' +type MockSealOperator_TryToSealWaitedSegment_Call struct { + *mock.Call +} + +// TryToSealWaitedSegment is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockSealOperator_Expecter) TryToSealWaitedSegment(ctx interface{}) *MockSealOperator_TryToSealWaitedSegment_Call { + return &MockSealOperator_TryToSealWaitedSegment_Call{Call: _e.mock.On("TryToSealWaitedSegment", ctx)} +} + +func (_c *MockSealOperator_TryToSealWaitedSegment_Call) Run(run func(ctx context.Context)) *MockSealOperator_TryToSealWaitedSegment_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockSealOperator_TryToSealWaitedSegment_Call) Return() *MockSealOperator_TryToSealWaitedSegment_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSealOperator_TryToSealWaitedSegment_Call) RunAndReturn(run func(context.Context)) *MockSealOperator_TryToSealWaitedSegment_Call { + _c.Call.Return(run) + return _c +} + +// NewMockSealOperator creates a new instance of MockSealOperator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockSealOperator(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSealOperator { + mock := &MockSealOperator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/proto/streaming.proto b/internal/proto/streaming.proto index 7b6740e843..2b32ad3d78 100644 --- a/internal/proto/streaming.proto +++ b/internal/proto/streaming.proto @@ -391,3 +391,42 @@ message StreamingNodeBalanceAttributes { message StreamingNodeManagerCollectStatusResponse { StreamingNodeBalanceAttributes balance_attributes = 1; } + +/// +/// SegmentAssignment +/// +// SegmentAssignmentMeta is the stat of segment assignment. +// These meta is only used to recover status at streaming node segment +// assignment, don't use it outside. +// Used to storage the segment assignment stat +// at meta-store. The WAL use it to determine when to make the segment sealed. +message SegmentAssignmentMeta { + int64 collection_id = 1; + int64 partition_id = 2; + int64 segment_id = 3; + string vchannel = 4; + SegmentAssignmentState state = 5; + SegmentAssignmentStat stat = 6; +} + +// SegmentAssignmentState is the state of segment assignment. +// The state machine can be described as following: +// 1. PENDING -> GROWING -> SEALED -> FLUSHED +enum SegmentAssignmentState { + SEGMENT_ASSIGNMENT_STATE_UNKNOWN = 0; // should never used. + SEGMENT_ASSIGNMENT_STATE_PENDING = 1; + SEGMENT_ASSIGNMENT_STATE_GROWING = 2; + SEGMENT_ASSIGNMENT_STATE_SEALED = 3; + SEGMENT_ASSIGNMENT_STATE_FLUSHED = 4; // can never be seen, because it's + // removed physically when enter FLUSHED. 
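Aside: the enum comment above pins the legal lifecycle to PENDING -> GROWING -> SEALED -> FLUSHED, with FLUSHED never persisted. A small, hypothetical helper (not part of this patch) that encodes those transitions shows how recovery code could reject illegal state changes early:

```go
package example

import "github.com/milvus-io/milvus/internal/proto/streamingpb"

// validNextState encodes the documented single-step lifecycle.
var validNextState = map[streamingpb.SegmentAssignmentState]streamingpb.SegmentAssignmentState{
	streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_PENDING: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING,
	streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_SEALED,
	streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_SEALED:  streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED,
}

// canTransfer reports whether from -> to is one of the documented transitions.
func canTransfer(from, to streamingpb.SegmentAssignmentState) bool {
	return validNextState[from] == to
}
```

With this, `canTransfer(GROWING, FLUSHED)` correctly reports false: a growing segment must be sealed before it can be flushed.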
+} + +// SegmentAssignmentStat is the stat of segment assignment. +message SegmentAssignmentStat { + uint64 max_binary_size = 1; + uint64 inserted_rows = 2; + uint64 inserted_binary_size = 3; + int64 create_timestamp_nanoseconds = 4; + int64 last_modified_timestamp_nanoseconds = 5; + uint64 binlog_counter = 6; +} diff --git a/internal/streamingcoord/server/balancer/balance_timer.go b/internal/streamingcoord/server/balancer/balance_timer.go deleted file mode 100644 index 53443930a1..0000000000 --- a/internal/streamingcoord/server/balancer/balance_timer.go +++ /dev/null @@ -1,55 +0,0 @@ -package balancer - -import ( - "time" - - "github.com/cenkalti/backoff/v4" - - "github.com/milvus-io/milvus/pkg/util/paramtable" -) - -// newBalanceTimer creates a new balanceTimer -func newBalanceTimer() *balanceTimer { - return &balanceTimer{ - backoff: backoff.NewExponentialBackOff(), - newIncomingBackOff: false, - } -} - -// balanceTimer is a timer for balance operation -type balanceTimer struct { - backoff *backoff.ExponentialBackOff - newIncomingBackOff bool - enableBackoff bool -} - -// EnableBackoffOrNot enables or disables backoff -func (t *balanceTimer) EnableBackoff() { - if !t.enableBackoff { - t.enableBackoff = true - t.newIncomingBackOff = true - } -} - -// DisableBackoff disables backoff -func (t *balanceTimer) DisableBackoff() { - t.enableBackoff = false -} - -// NextTimer returns the next timer and the duration of the timer -func (t *balanceTimer) NextTimer() (<-chan time.Time, time.Duration) { - if !t.enableBackoff { - balanceInterval := paramtable.Get().StreamingCoordCfg.AutoBalanceTriggerInterval.GetAsDurationByParse() - return time.After(balanceInterval), balanceInterval - } - if t.newIncomingBackOff { - t.newIncomingBackOff = false - // reconfig backoff - t.backoff.InitialInterval = paramtable.Get().StreamingCoordCfg.AutoBalanceBackoffInitialInterval.GetAsDurationByParse() - t.backoff.Multiplier = paramtable.Get().StreamingCoordCfg.AutoBalanceBackoffMultiplier.GetAsFloat() - t.backoff.MaxInterval = paramtable.Get().StreamingCoordCfg.AutoBalanceTriggerInterval.GetAsDurationByParse() - t.backoff.Reset() - } - nextBackoff := t.backoff.NextBackOff() - return time.After(nextBackoff), nextBackoff -} diff --git a/internal/streamingcoord/server/balancer/balancer_impl.go b/internal/streamingcoord/server/balancer/balancer_impl.go index 49a9bbc15e..36e2cd8a4b 100644 --- a/internal/streamingcoord/server/balancer/balancer_impl.go +++ b/internal/streamingcoord/server/balancer/balancer_impl.go @@ -2,6 +2,7 @@ package balancer import ( "context" + "time" "github.com/cockroachdb/errors" "go.uber.org/zap" @@ -13,6 +14,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/lifetime" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/syncutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -106,7 +108,7 @@ func (b *balancerImpl) execute() { b.logger.Info("balancer execute finished") }() - balanceTimer := newBalanceTimer() + balanceTimer := typeutil.NewBackoffTimer(&backoffConfigFetcher{}) nodeChanged, err := resource.Resource().StreamingNodeManagerClient().WatchNodeChanged(b.backgroundTaskNotifier.Context()) if err != nil { b.logger.Error("fail to watch node changed", zap.Error(err)) @@ -284,3 +286,17 @@ func generateCurrentLayout(channelsInMeta map[string]*channel.PChannelMeta, allN AllNodesInfo: allNodesInfo, } } + +type backoffConfigFetcher struct{} + +func (f 
*backoffConfigFetcher) BackoffConfig() typeutil.BackoffConfig { + return typeutil.BackoffConfig{ + InitialInterval: paramtable.Get().StreamingCoordCfg.AutoBalanceBackoffInitialInterval.GetAsDurationByParse(), + Multiplier: paramtable.Get().StreamingCoordCfg.AutoBalanceBackoffMultiplier.GetAsFloat(), + MaxInterval: paramtable.Get().StreamingCoordCfg.AutoBalanceTriggerInterval.GetAsDurationByParse(), + } +} + +func (f *backoffConfigFetcher) DefaultInterval() time.Duration { + return paramtable.Get().StreamingCoordCfg.AutoBalanceTriggerInterval.GetAsDurationByParse() +} diff --git a/internal/streamingnode/server/builder.go b/internal/streamingnode/server/builder.go index eeb5237fa9..f235fe4703 100644 --- a/internal/streamingnode/server/builder.go +++ b/internal/streamingnode/server/builder.go @@ -4,10 +4,12 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/metastore/kv/streamingnode" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/componentutil" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -19,6 +21,7 @@ type ServerBuilder struct { rc types.RootCoordClient dc types.DataCoordClient session *sessionutil.Session + kv kv.MetaKv } // NewServerBuilder creates a new server builder. @@ -56,12 +59,19 @@ func (b *ServerBuilder) WithSession(session *sessionutil.Session) *ServerBuilder return b } +// WithMetaKV sets meta kv to the server builder. +func (b *ServerBuilder) WithMetaKV(kv kv.MetaKv) *ServerBuilder { + b.kv = kv + return b +} + // Build builds a streaming node server. func (s *ServerBuilder) Build() *Server { resource.Init( resource.OptETCD(s.etcdClient), resource.OptRootCoordClient(s.rc), resource.OptDataCoordClient(s.dc), + resource.OptStreamingNodeCatalog(streamingnode.NewCataLog(s.kv)), ) return &Server{ session: s.session, diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go index 54170315cb..aa0f11aa4b 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go @@ -96,7 +96,8 @@ func (s *FlusherSuite) SetupSuite() { wbMgr.EXPECT().Start().Return() wbMgr.EXPECT().Stop().Return() - resource.Init( + resource.InitForTest( + s.T(), resource.OptSyncManager(syncMgr), resource.OptBufferManager(wbMgr), resource.OptRootCoordClient(rootcoord), diff --git a/internal/streamingnode/server/resource/resource.go b/internal/streamingnode/server/resource/resource.go index 409c1cc7cd..5cff55f217 100644 --- a/internal/streamingnode/server/resource/resource.go +++ b/internal/streamingnode/server/resource/resource.go @@ -7,9 +7,12 @@ import ( "github.com/milvus-io/milvus/internal/flushcommon/syncmgr" "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" + "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" "github.com/milvus-io/milvus/internal/types" ) @@ 
-67,6 +70,13 @@ func OptDataCoordClient(dataCoordClient types.DataCoordClient) optResourceInit { } } +// OptStreamingNodeCatalog provides the streaming node catalog to the resource. +func OptStreamingNodeCatalog(catalog metastore.StreamingNodeCataLog) optResourceInit { + return func(r *resourceImpl) { + r.streamingNodeCatalog = catalog + } +} + // Init initializes the singleton of resources. // Should be call when streaming node startup. func Init(opts ...optResourceInit) { @@ -76,10 +86,15 @@ func Init(opts ...optResourceInit) { } r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) + r.segmentAssignStatsManager = stats.NewStatsManager() + r.segmentSealedInspector = inspector.NewSealedInspector(r.segmentAssignStatsManager.SealNotifier()) assertNotNil(r.TSOAllocator()) assertNotNil(r.RootCoordClient()) assertNotNil(r.DataCoordClient()) + assertNotNil(r.StreamingNodeCatalog()) + assertNotNil(r.SegmentAssignStatsManager()) + assertNotNil(r.SegmentSealedInspector()) } // Resource access the underlying singleton of resources. @@ -94,12 +109,15 @@ type resourceImpl struct { syncMgr syncmgr.SyncManager wbMgr writebuffer.BufferManager - timestampAllocator idalloc.Allocator - idAllocator idalloc.Allocator - etcdClient *clientv3.Client - chunkManager storage.ChunkManager - rootCoordClient types.RootCoordClient - dataCoordClient types.DataCoordClient + timestampAllocator idalloc.Allocator + idAllocator idalloc.Allocator + etcdClient *clientv3.Client + chunkManager storage.ChunkManager + rootCoordClient types.RootCoordClient + dataCoordClient types.DataCoordClient + streamingNodeCatalog metastore.StreamingNodeCataLog + segmentAssignStatsManager *stats.StatsManager + segmentSealedInspector inspector.SealOperationInspector } // Flusher returns the flusher. @@ -147,6 +165,21 @@ func (r *resourceImpl) DataCoordClient() types.DataCoordClient { return r.dataCoordClient } +// StreamingNodeCataLog returns the streaming node catalog. +func (r *resourceImpl) StreamingNodeCatalog() metastore.StreamingNodeCataLog { + return r.streamingNodeCatalog +} + +// SegmentAssignStatManager returns the segment assign stats manager. +func (r *resourceImpl) SegmentAssignStatsManager() *stats.StatsManager { + return r.segmentAssignStatsManager +} + +// SegmentSealedInspector returns the segment sealed inspector. +func (r *resourceImpl) SegmentSealedInspector() inspector.SealOperationInspector { + return r.segmentSealedInspector +} + // assertNotNil panics if the resource is nil. 
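Aside: once `Init` has run with these options, the rest of the streaming node reaches the new singletons through `resource.Resource()`. A hedged sketch of the expected consumption pattern during pchannel recovery; `recoverPChannel` and its parameters are illustrative, the real logic lives in the segment interceptor and manager packages later in this patch:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector"
	"github.com/milvus-io/milvus/pkg/streaming/util/types"
)

// recoverPChannel is illustrative only: rebuild segment-assignment state for one
// pchannel from the catalog, then register the recovered manager with the seal
// inspector. mgr is assumed to be the recovered PChannelSegmentAllocManager.
func recoverPChannel(ctx context.Context, pchannel types.PChannelInfo, mgr inspector.SealOperator) error {
	metas, err := resource.Resource().StreamingNodeCatalog().ListSegmentAssignment(ctx, pchannel.Name)
	if err != nil {
		return err
	}
	_ = metas // rebuild in-memory segment alloc managers from the persisted metas (omitted)

	// Register so the background inspector starts driving seal checks for this pchannel.
	// (Method name keeps the spelling used by the interface in this patch.)
	resource.Resource().SegmentSealedInspector().RegsiterPChannelManager(mgr)
	return nil
}
```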
func assertNotNil(v interface{}) { iv := reflect.ValueOf(v) diff --git a/internal/streamingnode/server/resource/resource_test.go b/internal/streamingnode/server/resource/resource_test.go index 6aaed02e34..27dde00197 100644 --- a/internal/streamingnode/server/resource/resource_test.go +++ b/internal/streamingnode/server/resource/resource_test.go @@ -7,6 +7,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/mocks/mock_metastore" ) func TestInit(t *testing.T) { @@ -23,6 +24,7 @@ func TestInit(t *testing.T) { OptETCD(&clientv3.Client{}), OptRootCoordClient(mocks.NewMockRootCoordClient(t)), OptDataCoordClient(mocks.NewMockDataCoordClient(t)), + OptStreamingNodeCatalog(mock_metastore.NewMockStreamingNodeCataLog(t)), ) assert.NotNil(t, Resource().TSOAllocator()) @@ -31,5 +33,5 @@ func TestInit(t *testing.T) { } func TestInitForTest(t *testing.T) { - InitForTest() + InitForTest(t) } diff --git a/internal/streamingnode/server/resource/test_utility.go b/internal/streamingnode/server/resource/test_utility.go index 1bb2bd3a8a..547db9595e 100644 --- a/internal/streamingnode/server/resource/test_utility.go +++ b/internal/streamingnode/server/resource/test_utility.go @@ -3,15 +3,28 @@ package resource -import "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" +import ( + "testing" + + "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" +) // InitForTest initializes the singleton of resources for test. -func InitForTest(opts ...optResourceInit) { +func InitForTest(t *testing.T, opts ...optResourceInit) { r = &resourceImpl{} for _, opt := range opts { opt(r) } if r.rootCoordClient != nil { r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) + r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) + } else { + r.rootCoordClient = idalloc.NewMockRootCoordClient(t) + r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) + r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) } + r.segmentAssignStatsManager = stats.NewStatsManager() + r.segmentSealedInspector = inspector.NewSealedInspector(r.segmentAssignStatsManager.SealNotifier()) } diff --git a/internal/streamingnode/server/wal/adaptor/builder.go b/internal/streamingnode/server/wal/adaptor/builder.go index 0e8084a4e7..6190fca490 100644 --- a/internal/streamingnode/server/wal/adaptor/builder.go +++ b/internal/streamingnode/server/wal/adaptor/builder.go @@ -4,6 +4,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/ddl" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick" "github.com/milvus-io/milvus/pkg/streaming/walimpls" ) @@ -32,6 +33,7 @@ func (b builderAdaptorImpl) Build() (wal.Opener, error) { // Add all interceptor here. 
return adaptImplsToOpener(o, []interceptors.InterceptorBuilder{ timetick.NewInterceptorBuilder(), + segment.NewInterceptorBuilder(), ddl.NewInterceptorBuilder(), }), nil } diff --git a/internal/streamingnode/server/wal/adaptor/wal_test.go b/internal/streamingnode/server/wal/adaptor/wal_test.go index def37b7c32..83b405e3f0 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_test.go +++ b/internal/streamingnode/server/wal/adaptor/wal_test.go @@ -12,7 +12,16 @@ import ( "github.com/remeh/sizedwaitgroup" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/flushcommon/syncmgr" + "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/mocks/mock_metastore" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_flusher" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" @@ -30,9 +39,7 @@ type walTestFramework struct { } func TestWAL(t *testing.T) { - rc := idalloc.NewMockRootCoordClient(t) - resource.InitForTest(resource.OptRootCoordClient(rc)) - + initResourceForTest(t) b := registry.MustGetBuilder(walimplstest.WALName) f := &walTestFramework{ b: b, @@ -42,6 +49,36 @@ func TestWAL(t *testing.T) { f.Run() } +func initResourceForTest(t *testing.T) { + rc := idalloc.NewMockRootCoordClient(t) + rc.EXPECT().GetPChannelInfo(mock.Anything, mock.Anything).Return(&rootcoordpb.GetPChannelInfoResponse{}, nil) + + dc := mocks.NewMockDataCoordClient(t) + dc.EXPECT().AllocSegment(mock.Anything, mock.Anything).Return(&datapb.AllocSegmentResponse{}, nil) + catalog := mock_metastore.NewMockStreamingNodeCataLog(t) + catalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).Return(nil, nil) + catalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).Return(nil) + + syncMgr := syncmgr.NewMockSyncManager(t) + wbMgr := writebuffer.NewMockBufferManager(t) + + flusher := mock_flusher.NewMockFlusher(t) + flusher.EXPECT().RegisterPChannel(mock.Anything, mock.Anything).Return(nil).Maybe() + flusher.EXPECT().UnregisterPChannel(mock.Anything).Return().Maybe() + flusher.EXPECT().RegisterVChannel(mock.Anything, mock.Anything).Return() + flusher.EXPECT().UnregisterVChannel(mock.Anything).Return() + + resource.InitForTest( + t, + resource.OptSyncManager(syncMgr), + resource.OptBufferManager(wbMgr), + resource.OptRootCoordClient(rc), + resource.OptDataCoordClient(dc), + resource.OptFlusher(flusher), + resource.OptStreamingNodeCatalog(catalog), + ) +} + func (f *walTestFramework) Run() { wg := sync.WaitGroup{} loopCnt := 3 @@ -82,6 +119,7 @@ type testOneWALFramework struct { func (f *testOneWALFramework) Run() { ctx := context.Background() + for ; f.term <= 3; f.term++ { pChannel := types.PChannelInfo{ Name: f.pchannel, @@ -101,6 +139,9 @@ func (f *testOneWALFramework) Run() { } func (f *testOneWALFramework) testReadAndWrite(ctx context.Context, w wal.WAL) { + f.testSendCreateCollection(ctx, w) + defer f.testSendDropCollection(ctx, w) + // Test read and write. 
wg := sync.WaitGroup{} wg.Add(3) @@ -142,6 +183,35 @@ func (f *testOneWALFramework) testReadAndWrite(ctx context.Context, w wal.WAL) { f.testReadWithOption(ctx, w) } +func (f *testOneWALFramework) testSendCreateCollection(ctx context.Context, w wal.WAL) { + // create collection before start test + createMsg, err := message.NewCreateCollectionMessageBuilderV1(). + WithHeader(&message.CreateCollectionMessageHeader{ + CollectionId: 1, + PartitionIds: []int64{2}, + }). + WithBody(&msgpb.CreateCollectionRequest{}).BuildMutable() + assert.NoError(f.t, err) + + msgID, err := w.Append(ctx, createMsg.WithVChannel("v1")) + assert.NoError(f.t, err) + assert.NotNil(f.t, msgID) +} + +func (f *testOneWALFramework) testSendDropCollection(ctx context.Context, w wal.WAL) { + // drop collection after test + dropMsg, err := message.NewDropCollectionMessageBuilderV1(). + WithHeader(&message.DropCollectionMessageHeader{ + CollectionId: 1, + }). + WithBody(&msgpb.DropCollectionRequest{}).BuildMutable() + assert.NoError(f.t, err) + + msgID, err := w.Append(ctx, dropMsg.WithVChannel("v1")) + assert.NoError(f.t, err) + assert.NotNil(f.t, msgID) +} + func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]message.ImmutableMessage, error) { messages := make([]message.ImmutableMessage, f.messageCount) swg := sizedwaitgroup.New(10) @@ -178,6 +248,9 @@ func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]mess func (f *testOneWALFramework) testRead(ctx context.Context, w wal.WAL) ([]message.ImmutableMessage, error) { s, err := w.Read(ctx, wal.ReadOption{ DeliverPolicy: options.DeliverPolicyAll(), + MessageFilter: func(im message.ImmutableMessage) bool { + return im.MessageType() == message.MessageTypeInsert + }, }) assert.NoError(f.t, err) defer s.Close() @@ -218,7 +291,7 @@ func (f *testOneWALFramework) testReadWithOption(ctx context.Context, w wal.WAL) s, err := w.Read(ctx, wal.ReadOption{ DeliverPolicy: options.DeliverPolicyStartFrom(readFromMsg.LastConfirmedMessageID()), MessageFilter: func(im message.ImmutableMessage) bool { - return im.TimeTick() >= readFromMsg.TimeTick() + return im.TimeTick() >= readFromMsg.TimeTick() && im.MessageType() == message.MessageTypeInsert }, }) assert.NoError(f.t, err) diff --git a/internal/streamingnode/server/wal/interceptors/segment/builder.go b/internal/streamingnode/server/wal/interceptors/segment/builder.go new file mode 100644 index 0000000000..3b4f62dc2a --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/builder.go @@ -0,0 +1,31 @@ +package segment + +import ( + "context" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/manager" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +func NewInterceptorBuilder() interceptors.InterceptorBuilder { + return &interceptorBuilder{} +} + +type interceptorBuilder struct{} + +func (b *interceptorBuilder) Build(param interceptors.InterceptorBuildParam) interceptors.BasicInterceptor { + assignManager := syncutil.NewFuture[*manager.PChannelSegmentAllocManager]() + ctx, cancel := context.WithCancel(context.Background()) + segmentInterceptor := &segmentInterceptor{ + ctx: ctx, + cancel: cancel, + logger: log.With(zap.Any("pchannel", param.WALImpls.Channel())), + assignManager: assignManager, + } + go segmentInterceptor.recoverPChannelManager(param) + return segmentInterceptor +} diff --git 
a/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go b/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go new file mode 100644 index 0000000000..91619334e1 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/inspector/impls.go @@ -0,0 +1,142 @@ +package inspector + +import ( + "context" + "time" + + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +const ( + defaultSealAllInterval = 10 * time.Second +) + +// NewSealedInspector creates a new seal inspector. +func NewSealedInspector(n *stats.SealSignalNotifier) SealOperationInspector { + s := &sealOperationInspectorImpl{ + taskNotifier: syncutil.NewAsyncTaskNotifier[struct{}](), + managers: typeutil.NewConcurrentMap[string, SealOperator](), + notifier: n, + backOffTimer: typeutil.NewBackoffTimer(typeutil.BackoffTimerConfig{ + Default: 1 * time.Second, + Backoff: typeutil.BackoffConfig{ + InitialInterval: 20 * time.Millisecond, + Multiplier: 2.0, + MaxInterval: 1 * time.Second, + }, + }), + triggerCh: make(chan string), + } + go s.background() + return s +} + +// sealOperationInspectorImpl is the implementation of SealInspector. +type sealOperationInspectorImpl struct { + taskNotifier *syncutil.AsyncTaskNotifier[struct{}] + + managers *typeutil.ConcurrentMap[string, SealOperator] + notifier *stats.SealSignalNotifier + backOffTimer *typeutil.BackoffTimer + triggerCh chan string +} + +// TriggerSealWaited implements SealInspector.TriggerSealWaited. +func (s *sealOperationInspectorImpl) TriggerSealWaited(ctx context.Context, pchannel string) error { + select { + case <-ctx.Done(): + return ctx.Err() + case s.triggerCh <- pchannel: + return nil + } +} + +// RegsiterPChannelManager implements SealInspector.RegsiterPChannelManager. +func (s *sealOperationInspectorImpl) RegsiterPChannelManager(m SealOperator) { + _, loaded := s.managers.GetOrInsert(m.Channel().Name, m) + if loaded { + panic("pchannel manager already exists, critical bug in code") + } +} + +// UnregisterPChannelManager implements SealInspector.UnregisterPChannelManager. +func (s *sealOperationInspectorImpl) UnregisterPChannelManager(m SealOperator) { + _, loaded := s.managers.GetAndRemove(m.Channel().Name) + if !loaded { + panic("pchannel manager not found, critical bug in code") + } +} + +// Close implements SealInspector.Close. +func (s *sealOperationInspectorImpl) Close() { + s.taskNotifier.Cancel() + s.taskNotifier.BlockUntilFinish() +} + +// background is the background task to inspect if a segment should be sealed or not. +func (s *sealOperationInspectorImpl) background() { + defer s.taskNotifier.Finish(struct{}{}) + + sealAllTicker := time.NewTicker(defaultSealAllInterval) + defer sealAllTicker.Stop() + + var backoffCh <-chan time.Time + for { + if s.shouldEnableBackoff() { + // start a backoff if there's some pchannel wait for seal. + s.backOffTimer.EnableBackoff() + backoffCh, _ = s.backOffTimer.NextTimer() + } else { + s.backOffTimer.DisableBackoff() + } + + select { + case <-s.taskNotifier.Context().Done(): + return + case pchannel := <-s.triggerCh: + if manager, ok := s.managers.Get(pchannel); ok { + manager.TryToSealWaitedSegment(s.taskNotifier.Context()) + } + case <-s.notifier.WaitChan(): + s.tryToSealPartition(s.notifier.Get()) + case <-backoffCh: + // only seal waited segment for backoff. 
+ s.managers.Range(func(_ string, pm SealOperator) bool { + pm.TryToSealWaitedSegment(s.taskNotifier.Context()) + return true + }) + case <-sealAllTicker.C: + s.managers.Range(func(_ string, pm SealOperator) bool { + pm.TryToSealSegments(s.taskNotifier.Context()) + return true + }) + } + } +} + +// shouldEnableBackoff checks if the backoff should be enabled. +// if there's any pchannel has a segment wait for seal, enable backoff. +func (s *sealOperationInspectorImpl) shouldEnableBackoff() bool { + enableBackoff := false + s.managers.Range(func(_ string, pm SealOperator) bool { + if !pm.IsNoWaitSeal() { + enableBackoff = true + return false + } + return true + }) + return enableBackoff +} + +// tryToSealPartition tries to seal the segment with the specified policies. +func (s *sealOperationInspectorImpl) tryToSealPartition(infos typeutil.Set[stats.SegmentBelongs]) { + for info := range infos { + pm, ok := s.managers.Get(info.PChannel) + if !ok { + continue + } + pm.TryToSealSegments(s.taskNotifier.Context(), info) + } +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector.go b/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector.go new file mode 100644 index 0000000000..3fef273441 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector.go @@ -0,0 +1,41 @@ +package inspector + +import ( + "context" + + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +// SealOperationInspector is the inspector to check if a segment should be sealed or not. +type SealOperationInspector interface { + // TriggerSealWaited triggers the seal waited segment. + TriggerSealWaited(ctx context.Context, pchannel string) error + + // RegisterPChannelManager registers a pchannel manager. + RegsiterPChannelManager(m SealOperator) + + // UnregisterPChannelManager unregisters a pchannel manager. + UnregisterPChannelManager(m SealOperator) + + // Close closes the inspector. + Close() +} + +// SealOperator is a segment seal operator. +type SealOperator interface { + // Channel returns the pchannel info. + Channel() types.PChannelInfo + + // TryToSealSegments tries to seal the segment, if info is given, seal operation is only applied to related partitions and waiting seal segments, + // Otherwise, seal operation is applied to all partitions. + // Return false if there's some segment wait for seal but not sealed. + TryToSealSegments(ctx context.Context, infos ...stats.SegmentBelongs) + + // TryToSealWaitedSegment tries to seal the wait for sealing segment. + // Return false if there's some segment wait for seal but not sealed. + TryToSealWaitedSegment(ctx context.Context) + + // IsNoWaitSeal returns whether there's no segment wait for seal. 
+ IsNoWaitSeal() bool +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector_test.go b/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector_test.go new file mode 100644 index 0000000000..5e2894a216 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/inspector/inspector_test.go @@ -0,0 +1,63 @@ +package inspector + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/mock" + "go.uber.org/atomic" + + "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/wal/interceptors/segment/mock_inspector" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +func TestSealedInspector(t *testing.T) { + notifier := stats.NewSealSignalNotifier() + inspector := NewSealedInspector(notifier) + + o := mock_inspector.NewMockSealOperator(t) + ops := atomic.NewInt32(0) + + o.EXPECT().Channel().Return(types.PChannelInfo{Name: "v1"}) + o.EXPECT().TryToSealSegments(mock.Anything, mock.Anything). + RunAndReturn(func(ctx context.Context, sb ...stats.SegmentBelongs) { + ops.Add(1) + }) + o.EXPECT().TryToSealWaitedSegment(mock.Anything). + RunAndReturn(func(ctx context.Context) { + ops.Add(1) + }) + o.EXPECT().IsNoWaitSeal().RunAndReturn(func() bool { + return ops.Load()%2 == 0 + }) + + inspector.RegsiterPChannelManager(o) + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < 5; i++ { + inspector.TriggerSealWaited(context.Background(), "v1") + ops.Add(1) + } + }() + go func() { + defer wg.Done() + for i := 0; i < 5; i++ { + notifier.AddAndNotify(stats.SegmentBelongs{ + PChannel: "v1", + VChannel: "vv1", + CollectionID: 12, + PartitionID: 1, + }) + time.Sleep(5 * time.Millisecond) + } + time.Sleep(500 * time.Millisecond) + }() + wg.Wait() + inspector.UnregisterPChannelManager(o) + inspector.Close() +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/params.go b/internal/streamingnode/server/wal/interceptors/segment/manager/params.go new file mode 100644 index 0000000000..086e570ac8 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/params.go @@ -0,0 +1,27 @@ +package manager + +import ( + "go.uber.org/atomic" + + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" +) + +// AssignSegmentRequest is a request to allocate segment. +type AssignSegmentRequest struct { + CollectionID int64 + PartitionID int64 + InsertMetrics stats.InsertMetrics +} + +// AssignSegmentResult is a result of segment allocation. +// The sum of Results.Row is equal to InserMetrics.NumRows. +type AssignSegmentResult struct { + SegmentID int64 + Acknowledge *atomic.Int32 // used to ack the segment assign result has been consumed +} + +// Ack acks the segment assign result has been consumed. +// Must be only call once after the segment assign result has been consumed. 
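Aside: the inspector test above drives a mocked operator; the sketch below exercises the same contract with a hand-written no-op `SealOperator` and the notifier handshake, so the registration/trigger flow is visible outside the mock plumbing. `noopSealOperator` and the channel names are illustrative only:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector"
	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats"
	"github.com/milvus-io/milvus/pkg/streaming/util/types"
)

// noopSealOperator satisfies inspector.SealOperator but never has work to do;
// a real operator seals segments and persists their state via the catalog.
type noopSealOperator struct{}

func (noopSealOperator) Channel() types.PChannelInfo { return types.PChannelInfo{Name: "pchannel-1"} }

func (noopSealOperator) TryToSealSegments(ctx context.Context, infos ...stats.SegmentBelongs) {}

func (noopSealOperator) TryToSealWaitedSegment(ctx context.Context) {}

func (noopSealOperator) IsNoWaitSeal() bool { return true }

func runInspector(ctx context.Context) error {
	notifier := stats.NewSealSignalNotifier()
	insp := inspector.NewSealedInspector(notifier)
	defer insp.Close()

	op := noopSealOperator{}
	insp.RegsiterPChannelManager(op) // spelling follows the interface in this patch
	defer insp.UnregisterPChannelManager(op)

	// The stats side signals growth on a partition; the inspector picks it up
	// from the notifier and asks the owning operator to try a seal.
	notifier.AddAndNotify(stats.SegmentBelongs{
		PChannel:     "pchannel-1",
		VChannel:     "vchannel-1",
		CollectionID: 100,
		PartitionID:  1,
	})

	// Explicitly retry the waiting-for-seal segments of one pchannel.
	return insp.TriggerSealWaited(ctx, "pchannel-1")
}
```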
+func (r *AssignSegmentResult) Ack() { + r.Acknowledge.Dec() +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go new file mode 100644 index 0000000000..a711c4b9e6 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go @@ -0,0 +1,236 @@ +package manager + +import ( + "context" + "sync" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/policy" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +// newPartitionSegmentManager creates a new partition segment assign manager. +func newPartitionSegmentManager( + pchannel types.PChannelInfo, + vchannel string, + collectionID int64, + paritionID int64, + segments []*segmentAllocManager, +) *partitionSegmentManager { + return &partitionSegmentManager{ + mu: sync.Mutex{}, + logger: log.With( + zap.Any("pchannel", pchannel), + zap.String("vchannel", vchannel), + zap.Int64("collectionID", collectionID), + zap.Int64("partitionID", paritionID)), + pchannel: pchannel, + vchannel: vchannel, + collectionID: collectionID, + paritionID: paritionID, + segments: segments, + } +} + +// partitionSegmentManager is a assign manager of determined partition on determined vchannel. +type partitionSegmentManager struct { + mu sync.Mutex + logger *log.MLogger + pchannel types.PChannelInfo + vchannel string + collectionID int64 + paritionID int64 + segments []*segmentAllocManager // there will be very few segments in this list. +} + +func (m *partitionSegmentManager) CollectionID() int64 { + return m.collectionID +} + +// AssignSegment assigns a segment for a assign segment request. +func (m *partitionSegmentManager) AssignSegment(ctx context.Context, insert stats.InsertMetrics) (*AssignSegmentResult, error) { + m.mu.Lock() + defer m.mu.Unlock() + + return m.assignSegment(ctx, insert) +} + +// CollectShouldBeSealed try to collect all segments that should be sealed. +func (m *partitionSegmentManager) CollectShouldBeSealed() []*segmentAllocManager { + m.mu.Lock() + defer m.mu.Unlock() + + shouldBeSealedSegments := make([]*segmentAllocManager, 0, len(m.segments)) + segments := make([]*segmentAllocManager, 0, len(m.segments)) + for _, segment := range m.segments { + // A already sealed segment may be came from recovery. + if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_SEALED { + shouldBeSealedSegments = append(shouldBeSealedSegments, segment) + m.logger.Info("segment has been sealed, remove it from assignment", + zap.Int64("segmentID", segment.GetSegmentID()), + zap.String("state", segment.GetState().String()), + zap.Any("stat", segment.GetStat()), + ) + continue + } + // policy hitted segment should be removed from assignment manager. 
+ if m.hitSealPolicy(segment) { + shouldBeSealedSegments = append(shouldBeSealedSegments, segment) + continue + } + segments = append(segments, segment) + } + m.segments = segments + return shouldBeSealedSegments +} + +// CollectDirtySegmentsAndClear collects all segments in the manager and clear the maanger. +func (m *partitionSegmentManager) CollectDirtySegmentsAndClear() []*segmentAllocManager { + m.mu.Lock() + defer m.mu.Unlock() + dirtySegments := make([]*segmentAllocManager, 0, len(m.segments)) + for _, segment := range m.segments { + if segment.IsDirtyEnough() { + dirtySegments = append(dirtySegments, segment) + } + } + m.segments = make([]*segmentAllocManager, 0) + return dirtySegments +} + +// CollectAllCanBeSealedAndClear collects all segments that can be sealed and clear the manager. +func (m *partitionSegmentManager) CollectAllCanBeSealedAndClear() []*segmentAllocManager { + m.mu.Lock() + defer m.mu.Unlock() + canBeSealed := make([]*segmentAllocManager, 0, len(m.segments)) + for _, segment := range m.segments { + if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING || + segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_SEALED { + canBeSealed = append(canBeSealed, segment) + } + } + m.segments = make([]*segmentAllocManager, 0) + return canBeSealed +} + +// hitSealPolicy checks if the segment should be sealed by policy. +func (m *partitionSegmentManager) hitSealPolicy(segmentMeta *segmentAllocManager) bool { + stat := segmentMeta.GetStat() + for _, p := range policy.GetSegmentAsyncSealPolicy() { + if result := p.ShouldBeSealed(stat); result.ShouldBeSealed { + m.logger.Info("segment should be sealed by policy", + zap.Int64("segmentID", segmentMeta.GetSegmentID()), + zap.String("policy", result.PolicyName), + zap.Any("stat", stat), + zap.Any("extraInfo", result.ExtraInfo), + ) + return true + } + } + return false +} + +// allocNewGrowingSegment allocates a new growing segment. +// After this operation, the growing segment can be seen at datacoord. +func (m *partitionSegmentManager) allocNewGrowingSegment(ctx context.Context) (*segmentAllocManager, error) { + // A pending segment may be already created when failure or recovery. + pendingSegment := m.findPendingSegmentInMeta() + if pendingSegment == nil { + // if there's no pending segment, create a new pending segment. + var err error + if pendingSegment, err = m.createNewPendingSegment(ctx); err != nil { + return nil, err + } + } + + // Transfer the pending segment into growing state. + // Alloc the growing segment at datacoord first. + resp, err := resource.Resource().DataCoordClient().AllocSegment(ctx, &datapb.AllocSegmentRequest{ + CollectionId: pendingSegment.GetCollectionID(), + PartitionId: pendingSegment.GetPartitionID(), + SegmentId: pendingSegment.GetSegmentID(), + Vchannel: pendingSegment.GetVChannel(), + }) + if err := merr.CheckRPCCall(resp, err); err != nil { + return nil, errors.Wrap(err, "failed to alloc growing segment at datacoord") + } + + // Getnerate growing segment limitation. + limitation := policy.GetSegmentLimitationPolicy().GenerateLimitation() + + // Commit it into streaming node meta. + // growing segment can be assigned now. 
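+	// (The commit below is what makes the pending -> growing transition durable in the
+	// streaming node catalog; see the state table on segmentAllocManager in segment_manager.go.)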
+ tx := pendingSegment.BeginModification() + tx.IntoGrowing(&limitation) + if err := tx.Commit(ctx); err != nil { + return nil, errors.Wrapf(err, "failed to commit modification of segment assignment into growing, segmentID: %d", pendingSegment.GetSegmentID()) + } + m.logger.Info( + "generate new growing segment", + zap.Int64("segmentID", pendingSegment.GetSegmentID()), + zap.String("limitationPolicy", limitation.PolicyName), + zap.Uint64("segmentBinarySize", limitation.SegmentSize), + zap.Any("extraInfo", limitation.ExtraInfo), + ) + return pendingSegment, nil +} + +// findPendingSegmentInMeta finds a pending segment in the meta list. +func (m *partitionSegmentManager) findPendingSegmentInMeta() *segmentAllocManager { + // Found if there's already a pending segment. + for _, segment := range m.segments { + if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_PENDING { + return segment + } + } + return nil +} + +// createNewPendingSegment creates a new pending segment. +// pending segment only have a segment id, it's not a real segment, +// and will be transfer into growing state until registering to datacoord. +// The segment id is always allocated from rootcoord to avoid repeated. +// Pending state is used to avoid growing segment leak at datacoord. +func (m *partitionSegmentManager) createNewPendingSegment(ctx context.Context) (*segmentAllocManager, error) { + // Allocate new segment id and create ts from remote. + segmentID, err := resource.Resource().IDAllocator().Allocate(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to allocate segment id") + } + meta := newSegmentAllocManager(m.pchannel, m.collectionID, m.paritionID, int64(segmentID), m.vchannel) + tx := meta.BeginModification() + if err := tx.Commit(ctx); err != nil { + return nil, errors.Wrap(err, "failed to commit segment assignment modification") + } + m.segments = append(m.segments, meta) + return meta, nil +} + +// assignSegment assigns a segment for a assign segment request and return should trigger a seal operation. +func (m *partitionSegmentManager) assignSegment(ctx context.Context, insert stats.InsertMetrics) (*AssignSegmentResult, error) { + // Alloc segment for insert at previous segments. + for _, segment := range m.segments { + inserted, ack := segment.AllocRows(ctx, insert) + if inserted { + return &AssignSegmentResult{SegmentID: segment.GetSegmentID(), Acknowledge: ack}, nil + } + } + + // If not inserted, ask a new growing segment to insert. 
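+	// Only one extra attempt is made here: if even an empty growing segment cannot hold the
+	// insert, the request is rejected with an error below.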
+ newGrowingSegment, err := m.allocNewGrowingSegment(ctx) + if err != nil { + return nil, err + } + if inserted, ack := newGrowingSegment.AllocRows(ctx, insert); inserted { + return &AssignSegmentResult{SegmentID: newGrowingSegment.GetSegmentID(), Acknowledge: ack}, nil + } + return nil, errors.Errorf("too large insert message, cannot hold in empty growing segment, stats: %+v", insert) +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go new file mode 100644 index 0000000000..094676cbe7 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go @@ -0,0 +1,232 @@ +package manager + +import ( + "sync" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// buildNewPartitionManagers builds new partition managers. +func buildNewPartitionManagers( + pchannel types.PChannelInfo, + rawMetas []*streamingpb.SegmentAssignmentMeta, + collectionInfos []*rootcoordpb.CollectionInfoOnPChannel, +) (*partitionSegmentManagers, []*segmentAllocManager) { + // create a map to check if the partition exists. + partitionExist := make(map[int64]struct{}, len(collectionInfos)) + // collectionMap is a map from collectionID to collectionInfo. + collectionInfoMap := make(map[int64]*rootcoordpb.CollectionInfoOnPChannel, len(collectionInfos)) + for _, collectionInfo := range collectionInfos { + for _, partition := range collectionInfo.GetPartitions() { + partitionExist[partition.GetPartitionId()] = struct{}{} + } + collectionInfoMap[collectionInfo.GetCollectionId()] = collectionInfo + } + + // recover the segment infos from the streaming node segment assignment meta storage + waitForSealed := make([]*segmentAllocManager, 0) + metaMaps := make(map[int64][]*segmentAllocManager) + for _, rawMeta := range rawMetas { + m := newSegmentAllocManagerFromProto(pchannel, rawMeta) + if _, ok := partitionExist[rawMeta.GetPartitionId()]; !ok { + // related collection or partition is not exist. + // should be sealed right now. + waitForSealed = append(waitForSealed, m) + continue + } + if _, ok := metaMaps[rawMeta.GetPartitionId()]; !ok { + metaMaps[rawMeta.GetPartitionId()] = make([]*segmentAllocManager, 0, 2) + } + metaMaps[rawMeta.GetPartitionId()] = append(metaMaps[rawMeta.GetPartitionId()], m) + } + + // create managers list. + managers := typeutil.NewConcurrentMap[int64, *partitionSegmentManager]() + for collectionID, collectionInfo := range collectionInfoMap { + for _, partition := range collectionInfo.GetPartitions() { + segmentManagers := make([]*segmentAllocManager, 0) + // recovery meta is recovered , use it. + if managers, ok := metaMaps[partition.GetPartitionId()]; ok { + segmentManagers = managers + } + // otherwise, just create a new manager. 
+ _, ok := managers.GetOrInsert(partition.GetPartitionId(), newPartitionSegmentManager( + pchannel, + collectionInfo.GetVchannel(), + collectionID, + partition.GetPartitionId(), + segmentManagers, + )) + if ok { + panic("partition manager already exists when buildNewPartitionManagers in segment assignment service, there's a bug in system") + } + } + } + return &partitionSegmentManagers{ + mu: sync.Mutex{}, + logger: log.With(zap.Any("pchannel", pchannel)), + pchannel: pchannel, + managers: managers, + collectionInfos: collectionInfoMap, + }, waitForSealed +} + +// partitionSegmentManagers is a collection of partition managers. +type partitionSegmentManagers struct { + mu sync.Mutex + + logger *log.MLogger + pchannel types.PChannelInfo + managers *typeutil.ConcurrentMap[int64, *partitionSegmentManager] // map partitionID to partition manager + collectionInfos map[int64]*rootcoordpb.CollectionInfoOnPChannel // map collectionID to collectionInfo +} + +// NewCollection creates a new partition manager. +func (m *partitionSegmentManagers) NewCollection(collectionID int64, vchannel string, partitionID []int64) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.collectionInfos[collectionID]; ok { + m.logger.Warn("collection already exists when NewCollection in segment assignment service", + zap.Int64("collectionID", collectionID), + ) + return + } + + m.collectionInfos[collectionID] = newCollectionInfo(collectionID, vchannel, partitionID) + for _, partitionID := range partitionID { + if _, loaded := m.managers.GetOrInsert(partitionID, newPartitionSegmentManager( + m.pchannel, + vchannel, + collectionID, + partitionID, + make([]*segmentAllocManager, 0), + )); loaded { + m.logger.Warn("partition already exists when NewCollection in segment assignment service, it's may be a bug in system", + zap.Int64("collectionID", collectionID), + zap.Int64("partitionID", partitionID), + ) + } + } +} + +// NewPartition creates a new partition manager. +func (m *partitionSegmentManagers) NewPartition(collectionID int64, partitionID int64) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.collectionInfos[collectionID]; !ok { + m.logger.Warn("collection not exists when NewPartition in segment assignment service, it's may be a bug in system", + zap.Int64("collectionID", collectionID), + zap.Int64("partitionID", partitionID), + ) + return + } + m.collectionInfos[collectionID].Partitions = append(m.collectionInfos[collectionID].Partitions, &rootcoordpb.PartitionInfoOnPChannel{ + PartitionId: partitionID, + }) + + if _, loaded := m.managers.GetOrInsert(partitionID, newPartitionSegmentManager( + m.pchannel, + m.collectionInfos[collectionID].Vchannel, + collectionID, + partitionID, + make([]*segmentAllocManager, 0), + )); loaded { + m.logger.Warn( + "partition already exists when NewPartition in segment assignment service, it's may be a bug in system", + zap.Int64("collectionID", collectionID), + zap.Int64("partitionID", partitionID)) + } +} + +// Get gets a partition manager from the partition managers. +func (m *partitionSegmentManagers) Get(collectionID int64, partitionID int64) (*partitionSegmentManager, error) { + pm, ok := m.managers.Get(partitionID) + if !ok { + return nil, errors.Errorf("partition %d in collection %d not found in segment assignment service", partitionID, collectionID) + } + return pm, nil +} + +// RemoveCollection removes a collection manager from the partition managers. +// Return the segments that need to be sealed. 
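+// The caller (see PChannelSegmentAllocManager.RemoveCollection) is expected to push the
+// returned segments into the seal queue and wait until they are all flushed.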
+func (m *partitionSegmentManagers) RemoveCollection(collectionID int64) []*segmentAllocManager { + m.mu.Lock() + defer m.mu.Unlock() + + collectionInfo, ok := m.collectionInfos[collectionID] + if !ok { + m.logger.Warn("collection not exists when RemoveCollection in segment assignment service", zap.Int64("collectionID", collectionID)) + return nil + } + delete(m.collectionInfos, collectionID) + + needSealed := make([]*segmentAllocManager, 0) + for _, partition := range collectionInfo.Partitions { + pm, ok := m.managers.Get(partition.PartitionId) + if ok { + needSealed = append(needSealed, pm.CollectAllCanBeSealedAndClear()...) + } + m.managers.Remove(partition.PartitionId) + } + return needSealed +} + +// RemovePartition removes a partition manager from the partition managers. +func (m *partitionSegmentManagers) RemovePartition(collectionID int64, partitionID int64) []*segmentAllocManager { + m.mu.Lock() + defer m.mu.Unlock() + + collectionInfo, ok := m.collectionInfos[collectionID] + if !ok { + m.logger.Warn("collection not exists when RemovePartition in segment assignment service", zap.Int64("collectionID", collectionID)) + return nil + } + partitions := make([]*rootcoordpb.PartitionInfoOnPChannel, 0, len(collectionInfo.Partitions)-1) + for _, partition := range collectionInfo.Partitions { + if partition.PartitionId != partitionID { + partitions = append(partitions, partition) + } + } + collectionInfo.Partitions = partitions + + pm, loaded := m.managers.GetAndRemove(partitionID) + if !loaded { + m.logger.Warn("partition not exists when RemovePartition in segment assignment service", + zap.Int64("collectionID", collectionID), + zap.Int64("partitionID", partitionID)) + return nil + } + return pm.CollectAllCanBeSealedAndClear() +} + +// Range ranges the partition managers. +func (m *partitionSegmentManagers) Range(f func(pm *partitionSegmentManager)) { + m.managers.Range(func(_ int64, pm *partitionSegmentManager) bool { + f(pm) + return true + }) +} + +// newCollectionInfo creates a new collection info. 
+func newCollectionInfo(collectionID int64, vchannel string, partitionIDs []int64) *rootcoordpb.CollectionInfoOnPChannel {
+	info := &rootcoordpb.CollectionInfoOnPChannel{
+		CollectionId: collectionID,
+		Vchannel:     vchannel,
+		Partitions:   make([]*rootcoordpb.PartitionInfoOnPChannel, 0, len(partitionIDs)),
+	}
+	for _, partitionID := range partitionIDs {
+		info.Partitions = append(info.Partitions, &rootcoordpb.PartitionInfoOnPChannel{
+			PartitionId: partitionID,
+		})
+	}
+	return info
+}
diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go
new file mode 100644
index 0000000000..a72b289272
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go
@@ -0,0 +1,231 @@
+package manager
+
+import (
+	"context"
+
+	"github.com/cockroachdb/errors"
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
+	"github.com/milvus-io/milvus/internal/proto/streamingpb"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats"
+	"github.com/milvus-io/milvus/pkg/log"
+	"github.com/milvus-io/milvus/pkg/streaming/util/types"
+	"github.com/milvus-io/milvus/pkg/util/lifetime"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+	"github.com/milvus-io/milvus/pkg/util/syncutil"
+)
+
+// RecoverPChannelSegmentAllocManager recovers the segment assignment manager at the specified pchannel.
+func RecoverPChannelSegmentAllocManager(
+	ctx context.Context,
+	pchannel types.PChannelInfo,
+	wal *syncutil.Future[wal.WAL],
+) (*PChannelSegmentAllocManager, error) {
+	// recover streaming node growing segment metas.
+	rawMetas, err := resource.Resource().StreamingNodeCatalog().ListSegmentAssignment(ctx, pchannel.Name)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to list segment assignment from catalog")
+	}
+	// get collection and partition info from rootcoord.
+	resp, err := resource.Resource().RootCoordClient().GetPChannelInfo(ctx, &rootcoordpb.GetPChannelInfoRequest{
+		Pchannel: pchannel.Name,
+	})
+	if err := merr.CheckRPCCall(resp, err); err != nil {
+		return nil, errors.Wrap(err, "failed to get pchannel info from rootcoord")
+	}
+	managers, waitForSealed := buildNewPartitionManagers(pchannel, rawMetas, resp.GetCollections())
+
+	// PChannelSegmentAllocManager is the segment assign manager of determined pchannel.
+	logger := log.With(zap.Any("pchannel", pchannel))
+
+	return &PChannelSegmentAllocManager{
+		lifetime: lifetime.NewLifetime(lifetime.Working),
+		logger:   logger,
+		pchannel: pchannel,
+		managers: managers,
+		helper:   newSealQueue(logger, wal, waitForSealed),
+	}, nil
+}
+
+// PChannelSegmentAllocManager is the segment assignment manager of a determined pchannel.
+type PChannelSegmentAllocManager struct {
+	lifetime lifetime.Lifetime[lifetime.State]
+
+	logger   *log.MLogger
+	pchannel types.PChannelInfo
+	managers *partitionSegmentManagers
+	// helper is the seal queue that holds the segments waiting to be sealed.
+	helper *sealQueue
+}
+
+// Channel returns the pchannel info.
+func (m *PChannelSegmentAllocManager) Channel() types.PChannelInfo {
+	return m.pchannel
+}
+
+// NewCollection creates the segment managers for a new collection with the specified vchannel and partitionIDs.
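+// It is invoked by the segment assignment interceptor when a create-collection message is appended to the WAL.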
+func (m *PChannelSegmentAllocManager) NewCollection(collectionID int64, vchannel string, partitionIDs []int64) error { + if err := m.checkLifetime(); err != nil { + return err + } + defer m.lifetime.Done() + + m.managers.NewCollection(collectionID, vchannel, partitionIDs) + return nil +} + +// NewPartition creates a new partition with the specified partitionID. +func (m *PChannelSegmentAllocManager) NewPartition(collectionID int64, partitionID int64) error { + if err := m.checkLifetime(); err != nil { + return err + } + defer m.lifetime.Done() + + m.managers.NewPartition(collectionID, partitionID) + return nil +} + +// AssignSegment assigns a segment for a assign segment request. +func (m *PChannelSegmentAllocManager) AssignSegment(ctx context.Context, req *AssignSegmentRequest) (*AssignSegmentResult, error) { + if err := m.checkLifetime(); err != nil { + return nil, err + } + defer m.lifetime.Done() + + manager, err := m.managers.Get(req.CollectionID, req.PartitionID) + if err != nil { + return nil, err + } + return manager.AssignSegment(ctx, req.InsertMetrics) +} + +// RemoveCollection removes the specified collection. +func (m *PChannelSegmentAllocManager) RemoveCollection(ctx context.Context, collectionID int64) error { + if err := m.checkLifetime(); err != nil { + return err + } + defer m.lifetime.Done() + + waitForSealed := m.managers.RemoveCollection(collectionID) + m.helper.AsyncSeal(waitForSealed...) + + // trigger a seal operation in background rightnow. + resource.Resource().SegmentSealedInspector().TriggerSealWaited(ctx, m.pchannel.Name) + + // wait for all segment has been flushed. + return m.helper.WaitUntilNoWaitSeal(ctx) +} + +// RemovePartition removes the specified partitions. +func (m *PChannelSegmentAllocManager) RemovePartition(ctx context.Context, collectionID int64, partitionID int64) error { + if err := m.checkLifetime(); err != nil { + return err + } + defer m.lifetime.Done() + + // Remove the given partition from the partition managers. + // And seal all segments that should be sealed. + waitForSealed := m.managers.RemovePartition(collectionID, partitionID) + m.helper.AsyncSeal(waitForSealed...) + + // trigger a seal operation in background rightnow. + resource.Resource().SegmentSealedInspector().TriggerSealWaited(ctx, m.pchannel.Name) + + // wait for all segment has been flushed. + return m.helper.WaitUntilNoWaitSeal(ctx) +} + +// TryToSealSegments tries to seal the specified segments. +func (m *PChannelSegmentAllocManager) TryToSealSegments(ctx context.Context, infos ...stats.SegmentBelongs) { + if err := m.lifetime.Add(lifetime.IsWorking); err != nil { + return + } + defer m.lifetime.Done() + + if len(infos) == 0 { + // if no segment info specified, try to seal all segments. + m.managers.Range(func(pm *partitionSegmentManager) { + m.helper.AsyncSeal(pm.CollectShouldBeSealed()...) + }) + } else { + // if some segment info specified, try to seal the specified partition. + for _, info := range infos { + if pm, err := m.managers.Get(info.CollectionID, info.PartitionID); err == nil { + m.helper.AsyncSeal(pm.CollectShouldBeSealed()...) + } + } + } + m.helper.SealAllWait(ctx) +} + +// TryToSealWaitedSegment tries to seal the wait for sealing segment. +func (m *PChannelSegmentAllocManager) TryToSealWaitedSegment(ctx context.Context) { + if err := m.lifetime.Add(lifetime.IsWorking); err != nil { + return + } + defer m.lifetime.Done() + + m.helper.SealAllWait(ctx) +} + +// IsNoWaitSeal returns whether the segment manager is no segment wait for seal. 
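+// The background seal inspector uses it to decide whether to enable backoff and keep retrying.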
+func (m *PChannelSegmentAllocManager) IsNoWaitSeal() bool { + return m.helper.IsEmpty() +} + +// WaitUntilNoWaitSeal waits until no segment wait for seal. +func (m *PChannelSegmentAllocManager) WaitUntilNoWaitSeal(ctx context.Context) error { + if err := m.lifetime.Add(lifetime.IsWorking); err != nil { + return err + } + defer m.lifetime.Done() + + return m.helper.WaitUntilNoWaitSeal(ctx) +} + +// checkLifetime checks the lifetime of the segment manager. +func (m *PChannelSegmentAllocManager) checkLifetime() error { + if err := m.lifetime.Add(lifetime.IsWorking); err != nil { + m.logger.Warn("unreachable: segment assignment manager is not working, so the wal is on closing", zap.Error(err)) + return errors.New("segment assignment manager is not working") + } + return nil +} + +// Close try to persist all stats and invalid the manager. +func (m *PChannelSegmentAllocManager) Close(ctx context.Context) { + m.logger.Info("segment assignment manager start to close") + m.lifetime.SetState(lifetime.Stopped) + m.lifetime.Wait() + + // Try to seal all wait + m.helper.SealAllWait(ctx) + m.logger.Info("seal all waited segments done", zap.Int("waitCounter", m.helper.WaitCounter())) + + segments := make([]*segmentAllocManager, 0) + m.managers.Range(func(pm *partitionSegmentManager) { + segments = append(segments, pm.CollectDirtySegmentsAndClear()...) + }) + + // commitAllSegmentsOnSamePChannel commits all segments on the same pchannel. + protoSegments := make([]*streamingpb.SegmentAssignmentMeta, 0, len(segments)) + for _, segment := range segments { + protoSegments = append(protoSegments, segment.Snapshot()) + } + + m.logger.Info("segment assignment manager save all dirty segment assignments info", zap.Int("segmentCount", len(protoSegments))) + if err := resource.Resource().StreamingNodeCatalog().SaveSegmentAssignments(ctx, m.pchannel.Name, protoSegments); err != nil { + m.logger.Warn("commit segment assignment at pchannel failed", zap.Error(err)) + } + + // remove the stats from stats manager. 
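+	// Only growing segments are registered in the stats manager, so only those need to be unregistered here.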
+ m.logger.Info("segment assignment manager remove all segment stats from stats manager") + for _, segment := range segments { + if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + resource.Resource().SegmentAssignStatsManager().UnregisterSealedSegment(segment.GetSegmentID()) + } + } +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go new file mode 100644 index 0000000000..4c24043c4d --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go @@ -0,0 +1,293 @@ +package manager + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/mocks/mock_metastore" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_wal" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +func TestSegmentAllocManager(t *testing.T) { + initializeTestState(t) + + w := mock_wal.NewMockWAL(t) + w.EXPECT().Append(mock.Anything, mock.Anything).Return(nil, nil) + f := syncutil.NewFuture[wal.WAL]() + f.Set(w) + + m, err := RecoverPChannelSegmentAllocManager(context.Background(), types.PChannelInfo{Name: "v1"}, f) + assert.NoError(t, err) + assert.NotNil(t, m) + + ctx := context.Background() + + // Ask for allocate segment + result, err := m.AssignSegment(ctx, &AssignSegmentRequest{ + CollectionID: 1, + PartitionID: 1, + InsertMetrics: stats.InsertMetrics{ + Rows: 100, + BinarySize: 100, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, result) + + // Ask for allocate more segment, will generated new growing segment. + result2, err := m.AssignSegment(ctx, &AssignSegmentRequest{ + CollectionID: 1, + PartitionID: 1, + InsertMetrics: stats.InsertMetrics{ + Rows: 1024 * 1024, + BinarySize: 1024 * 1024, // 1MB setting at paramtable. + }, + }) + assert.NoError(t, err) + assert.NotNil(t, result2) + + // Ask for seal segment. + // Here already have a sealed segment, and a growing segment wait for seal, but the result is not acked. + m.TryToSealSegments(ctx) + assert.False(t, m.IsNoWaitSeal()) + + // The following segment assign will trigger a reach limit, so new seal segment will be created. + result3, err := m.AssignSegment(ctx, &AssignSegmentRequest{ + CollectionID: 1, + PartitionID: 1, + InsertMetrics: stats.InsertMetrics{ + Rows: 1, + BinarySize: 1, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, result3) + m.TryToSealSegments(ctx) + assert.False(t, m.IsNoWaitSeal()) // result2 is not acked, so new seal segment will not be sealed right away. 
+ + result.Ack() + result2.Ack() + result3.Ack() + m.TryToSealWaitedSegment(ctx) + assert.True(t, m.IsNoWaitSeal()) // result2 is acked, so new seal segment will be sealed right away. + + // Try to seal a partition. + m.TryToSealSegments(ctx, stats.SegmentBelongs{ + CollectionID: 1, + VChannel: "v1", + PartitionID: 2, + PChannel: "v1", + }) + assert.True(t, m.IsNoWaitSeal()) + + // Try to seal with a policy + resource.Resource().SegmentAssignStatsManager().UpdateOnFlush(6000, stats.FlushOperationMetrics{ + BinLogCounter: 100, + }) + // ask a unacknowledgement seal for partition 3 to avoid seal operation. + result, err = m.AssignSegment(ctx, &AssignSegmentRequest{ + CollectionID: 1, + PartitionID: 3, + InsertMetrics: stats.InsertMetrics{ + Rows: 100, + BinarySize: 100, + }, + }) + assert.NoError(t, err) + assert.NotNil(t, result) + + // Should be collected but not sealed. + m.TryToSealSegments(ctx) + assert.False(t, m.IsNoWaitSeal()) + result.Ack() + // Should be sealed. + m.TryToSealSegments(ctx) + assert.True(t, m.IsNoWaitSeal()) + + m.Close(ctx) +} + +func TestCreateAndDropCollection(t *testing.T) { + initializeTestState(t) + + w := mock_wal.NewMockWAL(t) + w.EXPECT().Append(mock.Anything, mock.Anything).Return(nil, nil) + f := syncutil.NewFuture[wal.WAL]() + f.Set(w) + + m, err := RecoverPChannelSegmentAllocManager(context.Background(), types.PChannelInfo{Name: "v1"}, f) + assert.NoError(t, err) + assert.NotNil(t, m) + resource.Resource().SegmentSealedInspector().RegsiterPChannelManager(m) + + ctx := context.Background() + + testRequest := &AssignSegmentRequest{ + CollectionID: 100, + PartitionID: 101, + InsertMetrics: stats.InsertMetrics{ + Rows: 100, + BinarySize: 200, + }, + } + + resp, err := m.AssignSegment(ctx, testRequest) + assert.Error(t, err) + assert.Nil(t, resp) + + m.NewCollection(100, "v1", []int64{101, 102, 103}) + resp, err = m.AssignSegment(ctx, testRequest) + assert.NoError(t, err) + assert.NotNil(t, resp) + resp.Ack() + + testRequest.PartitionID = 104 + resp, err = m.AssignSegment(ctx, testRequest) + assert.Error(t, err) + assert.Nil(t, resp) + + m.NewPartition(100, 104) + resp, err = m.AssignSegment(ctx, testRequest) + assert.NoError(t, err) + assert.NotNil(t, resp) + resp.Ack() + + m.RemovePartition(ctx, 100, 104) + assert.True(t, m.IsNoWaitSeal()) + resp, err = m.AssignSegment(ctx, testRequest) + assert.Error(t, err) + assert.Nil(t, resp) + + m.RemoveCollection(ctx, 100) + resp, err = m.AssignSegment(ctx, testRequest) + assert.True(t, m.IsNoWaitSeal()) + assert.Error(t, err) + assert.Nil(t, resp) +} + +func newStat(insertedBinarySize uint64, maxBinarySize uint64) *streamingpb.SegmentAssignmentStat { + return &streamingpb.SegmentAssignmentStat{ + MaxBinarySize: maxBinarySize, + InsertedRows: insertedBinarySize, + InsertedBinarySize: insertedBinarySize, + CreateTimestampNanoseconds: time.Now().UnixNano(), + LastModifiedTimestampNanoseconds: time.Now().UnixNano(), + } +} + +// initializeTestState is a helper function to initialize the status for testing. 
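+// The layout legend below uses c/p/s for collection/partition/segment; the suffix on each
+// segment id denotes its recovered state: p = pending, g = growing, s = sealed.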
+func initializeTestState(t *testing.T) { + // c 1 + // p 1 + // s 1000p + // p 2 + // s 2000g, 3000g, 4000s, 5000g + // p 3 + // s 6000g + + paramtable.Init() + paramtable.Get().DataCoordCfg.SegmentSealProportionJitter.SwapTempValue("0.0") + paramtable.Get().DataCoordCfg.SegmentMaxSize.SwapTempValue("1") + + streamingNodeCatalog := mock_metastore.NewMockStreamingNodeCataLog(t) + dataCoordClient := mocks.NewMockDataCoordClient(t) + dataCoordClient.EXPECT().AllocSegment(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, asr *datapb.AllocSegmentRequest, co ...grpc.CallOption) (*datapb.AllocSegmentResponse, error) { + return &datapb.AllocSegmentResponse{ + SegmentInfo: &datapb.SegmentInfo{ + ID: asr.GetSegmentId(), + CollectionID: asr.GetCollectionId(), + PartitionID: asr.GetPartitionId(), + }, + Status: merr.Success(), + }, nil + }) + + rootCoordClient := idalloc.NewMockRootCoordClient(t) + rootCoordClient.EXPECT().GetPChannelInfo(mock.Anything, mock.Anything).Return(&rootcoordpb.GetPChannelInfoResponse{ + Collections: []*rootcoordpb.CollectionInfoOnPChannel{ + { + CollectionId: 1, + Partitions: []*rootcoordpb.PartitionInfoOnPChannel{ + {PartitionId: 1}, + {PartitionId: 2}, + {PartitionId: 3}, + }, + }, + }, + }, nil) + + resource.InitForTest(t, + resource.OptStreamingNodeCatalog(streamingNodeCatalog), + resource.OptDataCoordClient(dataCoordClient), + resource.OptRootCoordClient(rootCoordClient), + ) + streamingNodeCatalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).Return( + []*streamingpb.SegmentAssignmentMeta{ + { + CollectionId: 1, + PartitionId: 1, + SegmentId: 1000, + Vchannel: "v1", + State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_PENDING, + Stat: nil, + }, + { + CollectionId: 1, + PartitionId: 2, + SegmentId: 2000, + Vchannel: "v1", + State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING, + Stat: newStat(1000, 1000), + }, + { + CollectionId: 1, + PartitionId: 2, + SegmentId: 3000, + Vchannel: "v1", + State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING, + Stat: newStat(100, 1000), + }, + { + CollectionId: 1, + PartitionId: 2, + SegmentId: 4000, + Vchannel: "v1", + State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_SEALED, + Stat: newStat(900, 1000), + }, + { + CollectionId: 1, + PartitionId: 2, + SegmentId: 5000, + Vchannel: "v1", + State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING, + Stat: newStat(900, 1000), + }, + { + CollectionId: 1, + PartitionId: 3, + SegmentId: 6000, + Vchannel: "v1", + State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING, + Stat: newStat(100, 1000), + }, + }, nil) + streamingNodeCatalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).Return(nil) +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/seal_queue.go b/internal/streamingnode/server/wal/interceptors/segment/manager/seal_queue.go new file mode 100644 index 0000000000..c31fd9f2d6 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/seal_queue.go @@ -0,0 +1,195 @@ +package manager + +import ( + "context" + "sync" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/syncutil" +) + +// newSealQueue 
creates a new seal helper queue. +func newSealQueue(logger *log.MLogger, wal *syncutil.Future[wal.WAL], waitForSealed []*segmentAllocManager) *sealQueue { + return &sealQueue{ + cond: syncutil.NewContextCond(&sync.Mutex{}), + logger: logger, + wal: wal, + waitForSealed: waitForSealed, + waitCounter: len(waitForSealed), + } +} + +// sealQueue is a helper to seal segments. +type sealQueue struct { + cond *syncutil.ContextCond + logger *log.MLogger + wal *syncutil.Future[wal.WAL] + waitForSealed []*segmentAllocManager + waitCounter int // wait counter count the real wait segment count, it is not equal to waitForSealed length. + // some segments may be in sealing process. +} + +// AsyncSeal adds a segment into the queue, and will be sealed at next time. +func (q *sealQueue) AsyncSeal(manager ...*segmentAllocManager) { + q.cond.LockAndBroadcast() + defer q.cond.L.Unlock() + + q.waitForSealed = append(q.waitForSealed, manager...) + q.waitCounter += len(manager) +} + +// SealAllWait seals all segments in the queue. +// If the operation is failure, the segments will be collected and will be retried at next time. +// Return true if all segments are sealed, otherwise return false. +func (q *sealQueue) SealAllWait(ctx context.Context) { + q.cond.L.Lock() + segments := q.waitForSealed + q.waitForSealed = make([]*segmentAllocManager, 0) + q.cond.L.Unlock() + + q.tryToSealSegments(ctx, segments...) +} + +// IsEmpty returns whether the queue is empty. +func (q *sealQueue) IsEmpty() bool { + q.cond.L.Lock() + defer q.cond.L.Unlock() + + return q.waitCounter == 0 +} + +// WaitCounter returns the wait counter. +func (q *sealQueue) WaitCounter() int { + q.cond.L.Lock() + defer q.cond.L.Unlock() + + return q.waitCounter +} + +// WaitUntilNoWaitSeal waits until no segment in the queue. +func (q *sealQueue) WaitUntilNoWaitSeal(ctx context.Context) error { + // wait until the wait counter becomes 0. + q.cond.L.Lock() + for q.waitCounter > 0 { + if err := q.cond.Wait(ctx); err != nil { + return err + } + } + q.cond.L.Unlock() + return nil +} + +// tryToSealSegments tries to seal segments, return the undone segments. +func (q *sealQueue) tryToSealSegments(ctx context.Context, segments ...*segmentAllocManager) { + if len(segments) == 0 { + return + } + undone, sealedSegments := q.transferSegmentStateIntoSealed(ctx, segments...) + + // send flush message into wal. + for collectionID, vchannelSegments := range sealedSegments { + for vchannel, segments := range vchannelSegments { + if err := q.sendFlushMessageIntoWAL(ctx, collectionID, vchannel, segments); err != nil { + q.logger.Warn("fail to send flush message into wal", zap.String("vchannel", vchannel), zap.Int64("collectionID", collectionID), zap.Error(err)) + undone = append(undone, segments...) + continue + } + for _, segment := range segments { + tx := segment.BeginModification() + tx.IntoFlushed() + if err := tx.Commit(ctx); err != nil { + q.logger.Warn("flushed segment failed at commit, maybe sent repeated flush message into wal", zap.Int64("segmentID", segment.GetSegmentID()), zap.Error(err)) + undone = append(undone, segment) + } + } + } + } + + q.cond.LockAndBroadcast() + q.waitForSealed = append(q.waitForSealed, undone...) + // the undone one should be retried at next time, so the counter should not decrease. + q.waitCounter -= (len(segments) - len(undone)) + q.cond.L.Unlock() +} + +// transferSegmentStateIntoSealed transfers the segment state into sealed. 
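+// It returns the segments that could not be sealed yet (to be retried later) and the successfully
+// sealed segments grouped by collection id and vchannel.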
+func (q *sealQueue) transferSegmentStateIntoSealed(ctx context.Context, segments ...*segmentAllocManager) ([]*segmentAllocManager, map[int64]map[string][]*segmentAllocManager) { + // undone sealed segment should be done at next time. + undone := make([]*segmentAllocManager, 0) + sealedSegments := make(map[int64]map[string][]*segmentAllocManager) + for _, segment := range segments { + if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + tx := segment.BeginModification() + tx.IntoSealed() + if err := tx.Commit(ctx); err != nil { + q.logger.Warn("seal segment failed at commit", zap.Int64("segmentID", segment.GetSegmentID()), zap.Error(err)) + undone = append(undone, segment) + continue + } + } + // assert here. + if segment.GetState() != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_SEALED { + panic("unreachable code: segment should be sealed here") + } + + // if there'are flying acks, wait them acked, delay the sealed at next retry. + ackSem := segment.AckSem() + if ackSem > 0 { + undone = append(undone, segment) + q.logger.Info("segment has been sealed, but there are flying acks, delay it", zap.Int64("segmentID", segment.GetSegmentID()), zap.Int32("ackSem", ackSem)) + continue + } + + // collect all sealed segments and no flying ack segment. + if _, ok := sealedSegments[segment.GetCollectionID()]; !ok { + sealedSegments[segment.GetCollectionID()] = make(map[string][]*segmentAllocManager) + } + if _, ok := sealedSegments[segment.GetCollectionID()][segment.GetVChannel()]; !ok { + sealedSegments[segment.GetCollectionID()][segment.GetVChannel()] = make([]*segmentAllocManager, 0) + } + sealedSegments[segment.GetCollectionID()][segment.GetVChannel()] = append(sealedSegments[segment.GetCollectionID()][segment.GetVChannel()], segment) + } + return undone, sealedSegments +} + +// sendFlushMessageIntoWAL sends a flush message into wal. +func (m *sealQueue) sendFlushMessageIntoWAL(ctx context.Context, collectionID int64, vchannel string, segments []*segmentAllocManager) error { + segmentIDs := make([]int64, 0, len(segments)) + for _, segment := range segments { + segmentIDs = append(segmentIDs, segment.GetSegmentID()) + } + msg, err := m.createNewFlushMessage(collectionID, vchannel, segmentIDs) + if err != nil { + return errors.Wrap(err, "at create new flush message") + } + + msgID, err := m.wal.Get().Append(ctx, msg) + if err != nil { + m.logger.Warn("send flush message into wal failed", zap.Int64("collectionID", collectionID), zap.String("vchannel", vchannel), zap.Int64s("segmentIDs", segmentIDs), zap.Error(err)) + return err + } + m.logger.Info("send flush message into wal", zap.Int64("collectionID", collectionID), zap.String("vchannel", vchannel), zap.Int64s("segmentIDs", segmentIDs), zap.Any("msgID", msgID)) + return nil +} + +// createNewFlushMessage creates a new flush message. +func (m *sealQueue) createNewFlushMessage(collectionID int64, vchannel string, segmentIDs []int64) (message.MutableMessage, error) { + // Create a flush message. + msg, err := message.NewFlushMessageBuilderV1(). + WithHeader(&message.FlushMessageHeader{}). 
+ WithBody(&message.FlushMessagePayload{ + CollectionId: collectionID, + SegmentId: segmentIDs, + }).BuildMutable() + if err != nil { + return nil, errors.Wrap(err, "at create new flush message") + } + msg.WithVChannel(vchannel) + return msg, nil +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go new file mode 100644 index 0000000000..1f02381eda --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go @@ -0,0 +1,252 @@ +package manager + +import ( + "context" + "time" + + "go.uber.org/atomic" + "go.uber.org/zap" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/policy" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +const dirtyThreshold = 30 * 1024 * 1024 // 30MB + +// newSegmentAllocManagerFromProto creates a new segment assignment meta from proto. +func newSegmentAllocManagerFromProto( + pchannel types.PChannelInfo, + inner *streamingpb.SegmentAssignmentMeta, +) *segmentAllocManager { + stat := stats.NewSegmentStatFromProto(inner.Stat) + // Growing segment's stat should be registered to stats manager. + // Async sealed policy will use it. + if inner.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + resource.Resource().SegmentAssignStatsManager().RegisterNewGrowingSegment(stats.SegmentBelongs{ + CollectionID: inner.GetCollectionId(), + PartitionID: inner.GetPartitionId(), + PChannel: pchannel.Name, + VChannel: inner.GetVchannel(), + }, inner.GetSegmentId(), stat) + stat = nil + } + return &segmentAllocManager{ + pchannel: pchannel, + inner: inner, + immutableStat: stat, + ackSem: atomic.NewInt32(0), + dirtyBytes: 0, + } +} + +// newSegmentAllocManager creates a new segment assignment meta. +func newSegmentAllocManager( + pchannel types.PChannelInfo, + collectionID int64, + partitionID int64, + segmentID int64, + vchannel string, +) *segmentAllocManager { + return &segmentAllocManager{ + pchannel: pchannel, + inner: &streamingpb.SegmentAssignmentMeta{ + CollectionId: collectionID, + PartitionId: partitionID, + SegmentId: segmentID, + Vchannel: vchannel, + State: streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_PENDING, + Stat: nil, + }, + immutableStat: nil, // immutable stat can be seen after sealed. + ackSem: atomic.NewInt32(0), + dirtyBytes: 0, + } +} + +// segmentAllocManager is the meta of segment assignment, +// only used to recover the assignment status on streaming node. +// !!! Not Concurrent Safe +// The state transfer is as follows: +// Pending -> Growing -> Sealed -> Flushed. +// +// The recovery process is as follows: +// +// | State | DataCoord View | Writable | WAL Status | Recovery | +// |-- | -- | -- | -- | -- | +// | Pending | Not exist | No | Not exist | 1. Check datacoord if exist; transfer into growing if exist. | +// | Growing | Exist | Yes | Insert Message Exist; Seal Message Not Exist | nothing | +// | Sealed | Exist | No | Insert Message Exist; Seal Message Maybe Exist | Resend a Seal Message and transfer into Flushed. 
| +// | Flushed | Exist | No | Insert Message Exist; Seal Message Exist | Already physically deleted, nothing to do | +type segmentAllocManager struct { + pchannel types.PChannelInfo + inner *streamingpb.SegmentAssignmentMeta + immutableStat *stats.SegmentStats // after sealed or flushed, the stat is immutable and cannot be seen by stats manager. + ackSem *atomic.Int32 // the ackSem is increased when segment allocRows, decreased when the segment is acked. + dirtyBytes uint64 // records the dirty bytes that didn't persist. +} + +// GetCollectionID returns the collection id of the segment assignment meta. +func (s *segmentAllocManager) GetCollectionID() int64 { + return s.inner.GetCollectionId() +} + +// GetPartitionID returns the partition id of the segment assignment meta. +func (s *segmentAllocManager) GetPartitionID() int64 { + return s.inner.GetPartitionId() +} + +// GetSegmentID returns the segment id of the segment assignment meta. +func (s *segmentAllocManager) GetSegmentID() int64 { + return s.inner.GetSegmentId() +} + +// GetVChannel returns the vchannel of the segment assignment meta. +func (s *segmentAllocManager) GetVChannel() string { + return s.inner.GetVchannel() +} + +// State returns the state of the segment assignment meta. +func (s *segmentAllocManager) GetState() streamingpb.SegmentAssignmentState { + return s.inner.GetState() +} + +// Stat get the stat of segments. +// Pending segment will return nil. +// Growing segment will return a snapshot. +// Sealed segment will return the final. +func (s *segmentAllocManager) GetStat() *stats.SegmentStats { + if s.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + return resource.Resource().SegmentAssignStatsManager().GetStatsOfSegment(s.GetSegmentID()) + } + return s.immutableStat +} + +// AckSem returns the ack sem. +func (s *segmentAllocManager) AckSem() int32 { + return s.ackSem.Load() +} + +// AllocRows ask for rows from current segment. +// Only growing and not fenced segment can alloc rows. +func (s *segmentAllocManager) AllocRows(ctx context.Context, m stats.InsertMetrics) (bool, *atomic.Int32) { + // if the segment is not growing or reach limit, return false directly. + if s.inner.GetState() != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + return false, nil + } + inserted := resource.Resource().SegmentAssignStatsManager().AllocRows(s.GetSegmentID(), m) + if !inserted { + return false, nil + } + s.dirtyBytes += m.BinarySize + s.ackSem.Inc() + + // persist stats if too dirty. + s.persistStatsIfTooDirty(ctx) + return inserted, s.ackSem +} + +// Snapshot returns the snapshot of the segment assignment meta. +func (s *segmentAllocManager) Snapshot() *streamingpb.SegmentAssignmentMeta { + copied := proto.Clone(s.inner).(*streamingpb.SegmentAssignmentMeta) + copied.Stat = stats.NewProtoFromSegmentStat(s.GetStat()) + return copied +} + +// IsDirtyEnough returns if the dirty bytes is enough to persist. +func (s *segmentAllocManager) IsDirtyEnough() bool { + // only growing segment can be dirty. + return s.inner.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING && s.dirtyBytes >= dirtyThreshold +} + +// PersisteStatsIfTooDirty persists the stats if the dirty bytes is too large. 
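+// "Too large" is defined by the dirtyThreshold constant at the top of this file (30MB).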
+func (s *segmentAllocManager) persistStatsIfTooDirty(ctx context.Context) { + if s.inner.GetState() != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + return + } + if s.dirtyBytes < dirtyThreshold { + return + } + if err := resource.Resource().StreamingNodeCatalog().SaveSegmentAssignments(ctx, s.pchannel.Name, []*streamingpb.SegmentAssignmentMeta{ + s.Snapshot(), + }); err != nil { + log.Warn("failed to persist stats of segment", zap.Int64("segmentID", s.GetSegmentID()), zap.Error(err)) + } + s.dirtyBytes = 0 +} + +// BeginModification begins the modification of the segment assignment meta. +// Do a copy of the segment assignment meta, update the remote meta storage, than modifies the original. +func (s *segmentAllocManager) BeginModification() *mutableSegmentAssignmentMeta { + copied := s.Snapshot() + return &mutableSegmentAssignmentMeta{ + original: s, + modifiedCopy: copied, + } +} + +// mutableSegmentAssignmentMeta is the mutable version of segment assignment meta. +type mutableSegmentAssignmentMeta struct { + original *segmentAllocManager + modifiedCopy *streamingpb.SegmentAssignmentMeta +} + +// IntoGrowing transfers the segment assignment meta into growing state. +func (m *mutableSegmentAssignmentMeta) IntoGrowing(limitation *policy.SegmentLimitation) { + if m.modifiedCopy.State != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_PENDING { + panic("tranfer state to growing from non-pending state") + } + m.modifiedCopy.State = streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING + now := time.Now().UnixNano() + m.modifiedCopy.Stat = &streamingpb.SegmentAssignmentStat{ + MaxBinarySize: limitation.SegmentSize, + CreateTimestampNanoseconds: now, + LastModifiedTimestampNanoseconds: now, + } +} + +// IntoSealed transfers the segment assignment meta into sealed state. +func (m *mutableSegmentAssignmentMeta) IntoSealed() { + if m.modifiedCopy.State != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + panic("tranfer state to sealed from non-growing state") + } + m.modifiedCopy.State = streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_SEALED +} + +// IntoFlushed transfers the segment assignment meta into flushed state. +// Will be delted physically when transfer into flushed state. +func (m *mutableSegmentAssignmentMeta) IntoFlushed() { + if m.modifiedCopy.State != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_SEALED { + panic("tranfer state to flushed from non-sealed state") + } + m.modifiedCopy.State = streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_FLUSHED +} + +// Commit commits the modification. +func (m *mutableSegmentAssignmentMeta) Commit(ctx context.Context) error { + if err := resource.Resource().StreamingNodeCatalog().SaveSegmentAssignments(ctx, m.original.pchannel.Name, []*streamingpb.SegmentAssignmentMeta{ + m.modifiedCopy, + }); err != nil { + return err + } + if m.original.GetState() != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING && + m.modifiedCopy.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + // if the state transferred into growing, register the stats to stats manager. 
+ resource.Resource().SegmentAssignStatsManager().RegisterNewGrowingSegment(stats.SegmentBelongs{ + CollectionID: m.original.GetCollectionID(), + PartitionID: m.original.GetPartitionID(), + PChannel: m.original.pchannel.Name, + VChannel: m.original.GetVChannel(), + }, m.original.GetSegmentID(), stats.NewSegmentStatFromProto(m.modifiedCopy.Stat)) + } else if m.original.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING && + m.modifiedCopy.GetState() != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING { + // if the state transferred from growing into others, remove the stats from stats manager. + m.original.immutableStat = resource.Resource().SegmentAssignStatsManager().UnregisterSealedSegment(m.original.GetSegmentID()) + } + m.original.inner = m.modifiedCopy + return nil +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/policy/global_seal_policy.go b/internal/streamingnode/server/wal/interceptors/segment/policy/global_seal_policy.go new file mode 100644 index 0000000000..4ddd6ff5a6 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/policy/global_seal_policy.go @@ -0,0 +1,14 @@ +package policy + +import "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + +func GetGlobalAsyncSealPolicy() []GlobalAsyncSealPolicy { + // TODO: dynamic policy can be applied here in future. + return []GlobalAsyncSealPolicy{} +} + +// GlobalAsyncSealPolicy is the policy to check if a global segment should be sealed or not. +type GlobalAsyncSealPolicy interface { + // ShouldSealed checks if the segment should be sealed, and return the reason string. + ShouldSealed(m stats.StatsManager) +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/policy/segment_limitation_policy.go b/internal/streamingnode/server/wal/interceptors/segment/policy/segment_limitation_policy.go new file mode 100644 index 0000000000..7de8d3a502 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/policy/segment_limitation_policy.go @@ -0,0 +1,59 @@ +package policy + +import ( + "math/rand" + + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// GetSegmentLimitationPolicy returns the segment limitation policy. +func GetSegmentLimitationPolicy() SegmentLimitationPolicy { + // TODO: dynamic policy can be applied here in future. + return jitterSegmentLimitationPolicy{} +} + +// SegmentLimitation is the limitation of the segment. +type SegmentLimitation struct { + PolicyName string + SegmentSize uint64 + ExtraInfo interface{} +} + +// SegmentLimitationPolicy is the interface to generate the limitation of the segment. +type SegmentLimitationPolicy interface { + // GenerateLimitation generates the limitation of the segment. + GenerateLimitation() SegmentLimitation +} + +// jitterSegmentLimitationPolicyExtraInfo is the extra info of the jitter segment limitation policy. +type jitterSegmentLimitationPolicyExtraInfo struct { + Jitter float64 + JitterRatio float64 + MaxSegmentSize uint64 +} + +// jiiterSegmentLimitationPolicy is the policy to generate the limitation of the segment. +// Add a jitter to the segment size limitation to scatter the segment sealing time. +type jitterSegmentLimitationPolicy struct{} + +// GenerateLimitation generates the limitation of the segment. +func (p jitterSegmentLimitationPolicy) GenerateLimitation() SegmentLimitation { + // TODO: It's weird to set such a parameter into datacoord configuration. 
+ // Refactor it in the future + jitter := paramtable.Get().DataCoordCfg.SegmentSealProportionJitter.GetAsFloat() + jitterRatio := 1 - jitter*rand.Float64() // generate a random number in [1-jitter, 1] + if jitterRatio <= 0 || jitterRatio > 1 { + jitterRatio = 1 + } + maxSegmentSize := uint64(paramtable.Get().DataCoordCfg.SegmentMaxSize.GetAsInt64() * 1024 * 1024) + segmentSize := uint64(jitterRatio * float64(maxSegmentSize)) + return SegmentLimitation{ + PolicyName: "jitter_segment_limitation", + SegmentSize: segmentSize, + ExtraInfo: jitterSegmentLimitationPolicyExtraInfo{ + Jitter: jitter, + JitterRatio: jitterRatio, + MaxSegmentSize: maxSegmentSize, + }, + } +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/policy/segment_seal_policy.go b/internal/streamingnode/server/wal/interceptors/segment/policy/segment_seal_policy.go new file mode 100644 index 0000000000..1a8110770f --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/policy/segment_seal_policy.go @@ -0,0 +1,114 @@ +package policy + +import ( + "time" + + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +// GetSegmentAsyncSealPolicy returns the segment async seal policy. +func GetSegmentAsyncSealPolicy() []SegmentAsyncSealPolicy { + // TODO: dynamic policy can be applied here in future. + return []SegmentAsyncSealPolicy{ + &sealByCapacity{}, + &sealByBinlogFileNumber{}, + &sealByLifetime{}, + &sealByIdleTime{}, + } +} + +// SealPolicyResult is the result of the seal policy. +type SealPolicyResult struct { + PolicyName string + ShouldBeSealed bool + ExtraInfo interface{} +} + +// SegmentAsyncSealPolicy is the policy to check if a segment should be sealed or not. +// Those policies are called asynchronously, so the stat is not real time. +// A policy should be stateless, and only check by segment stats. +// quick enough to be called. +type SegmentAsyncSealPolicy interface { + // ShouldBeSealed checks if the segment should be sealed, and return the reason string. + ShouldBeSealed(stats *stats.SegmentStats) SealPolicyResult +} + +// sealByCapacity is a policy to seal the segment by the capacity. +type sealByCapacity struct{} + +// ShouldBeSealed checks if the segment should be sealed, and return the reason string. +func (p *sealByCapacity) ShouldBeSealed(stats *stats.SegmentStats) SealPolicyResult { + return SealPolicyResult{ + PolicyName: "seal_by_capacity", + ShouldBeSealed: stats.ReachLimit, + ExtraInfo: nil, + } +} + +// sealByBinlogFileNumberExtraInfo is the extra info of the seal by binlog file number policy. +type sealByBinlogFileNumberExtraInfo struct { + BinLogFileNumberLimit int +} + +// sealByBinlogFileNumber is a policy to seal the segment by the binlog file number. +type sealByBinlogFileNumber struct{} + +// ShouldBeSealed checks if the segment should be sealed, and return the reason string. +func (p *sealByBinlogFileNumber) ShouldBeSealed(stats *stats.SegmentStats) SealPolicyResult { + limit := paramtable.Get().DataCoordCfg.SegmentMaxBinlogFileNumber.GetAsInt() + shouldBeSealed := stats.BinLogCounter >= uint64(limit) + return SealPolicyResult{ + PolicyName: "seal_by_binlog_file_number", + ShouldBeSealed: shouldBeSealed, + ExtraInfo: &sealByBinlogFileNumberExtraInfo{ + BinLogFileNumberLimit: limit, + }, + } +} + +// sealByLifetimeExtraInfo is the extra info of the seal by lifetime policy. 
+type sealByLifetimeExtraInfo struct { + MaxLifeTime time.Duration +} + +// sealByLifetime is a policy to seal the segment by the lifetime. +type sealByLifetime struct{} + +// ShouldBeSealed checks if the segment should be sealed, and return the reason string. +func (p *sealByLifetime) ShouldBeSealed(stats *stats.SegmentStats) SealPolicyResult { + lifetime := paramtable.Get().DataCoordCfg.SegmentMaxLifetime.GetAsDuration(time.Second) + shouldBeSealed := time.Since(stats.CreateTime) > lifetime + return SealPolicyResult{ + PolicyName: "seal_by_lifetime", + ShouldBeSealed: shouldBeSealed, + ExtraInfo: sealByLifetimeExtraInfo{ + MaxLifeTime: lifetime, + }, + } +} + +// sealByIdleTimeExtraInfo is the extra info of the seal by idle time policy. +type sealByIdleTimeExtraInfo struct { + IdleTime time.Duration + MinimalSize uint64 +} + +// sealByIdleTime is a policy to seal the segment by the idle time. +type sealByIdleTime struct{} + +// ShouldBeSealed checks if the segment should be sealed, and return the reason string. +func (p *sealByIdleTime) ShouldBeSealed(stats *stats.SegmentStats) SealPolicyResult { + idleTime := paramtable.Get().DataCoordCfg.SegmentMaxIdleTime.GetAsDuration(time.Second) + minSize := uint64(paramtable.Get().DataCoordCfg.SegmentMinSizeFromIdleToSealed.GetAsInt() * 1024 * 1024) + + shouldBeSealed := stats.Insert.BinarySize > minSize && time.Since(stats.LastModifiedTime) > idleTime + return SealPolicyResult{ + PolicyName: "seal_by_idle_time", + ShouldBeSealed: shouldBeSealed, + ExtraInfo: sealByIdleTimeExtraInfo{ + IdleTime: idleTime, + MinimalSize: minSize, + }, + } +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go b/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go new file mode 100644 index 0000000000..f544912db5 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go @@ -0,0 +1,200 @@ +package segment + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/manager" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var _ interceptors.AppendInterceptor = (*segmentInterceptor)(nil) + +// segmentInterceptor is the implementation of segment assignment interceptor. +type segmentInterceptor struct { + ctx context.Context + cancel context.CancelFunc + + logger *log.MLogger + assignManager *syncutil.Future[*manager.PChannelSegmentAllocManager] +} + +// Ready returns a channel that will be closed when the segment interceptor is ready. +func (impl *segmentInterceptor) Ready() <-chan struct{} { + // Wait for segment assignment manager ready. + return impl.assignManager.Done() +} + +// DoAppend assigns segment for every partition in the message. 
+func (impl *segmentInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (msgID message.MessageID, err error) { + switch msg.MessageType() { + case message.MessageTypeCreateCollection: + return impl.handleCreateCollection(ctx, msg, appendOp) + case message.MessageTypeDropCollection: + return impl.handleDropCollection(ctx, msg, appendOp) + case message.MessageTypeCreatePartition: + return impl.handleCreatePartition(ctx, msg, appendOp) + case message.MessageTypeDropPartition: + return impl.handleDropPartition(ctx, msg, appendOp) + case message.MessageTypeInsert: + return impl.handleInsertMessage(ctx, msg, appendOp) + default: + return appendOp(ctx, msg) + } +} + +// handleCreateCollection handles the create collection message. +func (impl *segmentInterceptor) handleCreateCollection(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (message.MessageID, error) { + createCollectionMsg, err := message.AsMutableCreateCollectionMessageV1(msg) + if err != nil { + return nil, err + } + // send the create collection message. + msgID, err := appendOp(ctx, msg) + if err != nil { + return msgID, err + } + + // Set up the partition managers for the collection so that new incoming insert messages can be assigned segments. + h := createCollectionMsg.Header() + impl.assignManager.Get().NewCollection(h.GetCollectionId(), msg.VChannel(), h.GetPartitionIds()) + return msgID, nil +} + +// handleDropCollection handles the drop collection message. +func (impl *segmentInterceptor) handleDropCollection(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (message.MessageID, error) { + dropCollectionMessage, err := message.AsMutableDropCollectionMessageV1(msg) + if err != nil { + return nil, err + } + // Dropping a collection removes all its partition managers from the assignment service. + h := dropCollectionMessage.Header() + if err := impl.assignManager.Get().RemoveCollection(ctx, h.GetCollectionId()); err != nil { + return nil, err + } + + // send the drop collection message. + return appendOp(ctx, msg) +} + +// handleCreatePartition handles the create partition message. +func (impl *segmentInterceptor) handleCreatePartition(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (message.MessageID, error) { + createPartitionMessage, err := message.AsMutableCreatePartitionMessageV1(msg) + if err != nil { + return nil, err + } + // send the create partition message. + msgID, err := appendOp(ctx, msg) + if err != nil { + return msgID, err + } + + // Set up the partition manager for the new partition so that new incoming insert messages can be assigned segments. + h := createPartitionMessage.Header() + // the error can never happen because of the wal lifetime control. + _ = impl.assignManager.Get().NewPartition(h.GetCollectionId(), h.GetPartitionId()) + return msgID, nil +} + +// handleDropPartition handles the drop partition message. +func (impl *segmentInterceptor) handleDropPartition(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (message.MessageID, error) { + dropPartitionMessage, err := message.AsMutableDropPartitionMessageV1(msg) + if err != nil { + return nil, err + } + + // drop partition, remove the partition manager from the assignment service. + h := dropPartitionMessage.Header() + if err := impl.assignManager.Get().RemovePartition(ctx, h.GetCollectionId(), h.GetPartitionId()); err != nil { + return nil, err + } + + // send the drop partition message.
+ return appendOp(ctx, msg) +} + +// handleInsertMessage handles the insert message. +func (impl *segmentInterceptor) handleInsertMessage(ctx context.Context, msg message.MutableMessage, appendOp interceptors.Append) (message.MessageID, error) { + insertMsg, err := message.AsMutableInsertMessageV1(msg) + if err != nil { + return nil, err + } + // Assign a segment for every partition in the insert message. + // In the current implementation an insert message only has one partition, but we need to merge messages for partition-key in the future. + header := insertMsg.Header() + for _, partition := range header.GetPartitions() { + result, err := impl.assignManager.Get().AssignSegment(ctx, &manager.AssignSegmentRequest{ + CollectionID: header.GetCollectionId(), + PartitionID: partition.GetPartitionId(), + InsertMetrics: stats.InsertMetrics{ + Rows: partition.GetRows(), + BinarySize: partition.GetBinarySize(), + }, + }) + if err != nil { + return nil, status.NewInner("segment assignment failure with error: %s", err.Error()) + } + // once the segment assignment is done, we need to ack the result; + // if another partition fails to assign a segment or the wal write fails, + // the segment assignment will not be rolled back, to keep the implementation simple. + defer result.Ack() + + // Attach the segment assignment to the message. + partition.SegmentAssignment = &message.SegmentAssignment{ + SegmentId: result.SegmentID, + } + } + // Update the insert message header. + insertMsg.OverwriteHeader(header) + + return appendOp(ctx, msg) +} + +// Close closes the segment interceptor. +func (impl *segmentInterceptor) Close() { + // unregister the pchannels + resource.Resource().SegmentSealedInspector().UnregisterPChannelManager(impl.assignManager.Get()) + impl.assignManager.Get().Close(context.Background()) +} + +// recoverPChannelManager recovers PChannel Assignment Manager. +func (impl *segmentInterceptor) recoverPChannelManager(param interceptors.InterceptorBuildParam) { + timer := typeutil.NewBackoffTimer(typeutil.BackoffTimerConfig{ + Default: time.Second, + Backoff: typeutil.BackoffConfig{ + InitialInterval: 10 * time.Millisecond, + Multiplier: 2.0, + MaxInterval: time.Second, + }, + }) + timer.EnableBackoff() + for counter := 0; ; counter++ { + pm, err := manager.RecoverPChannelSegmentAllocManager(impl.ctx, param.WALImpls.Channel(), param.WAL) + if err != nil { + ch, d := timer.NextTimer() + impl.logger.Warn("recover PChannel Assignment Manager failed, wait a backoff", zap.Int("retry", counter), zap.Duration("nextRetryInterval", d), zap.Error(err)) + select { + case <-impl.ctx.Done(): + impl.logger.Info("segment interceptor has been closed", zap.Error(impl.ctx.Err())) + return + case <-ch: + continue + } + } + + // register the manager into the inspector to do the seal asynchronously + resource.Resource().SegmentSealedInspector().RegsiterPChannelManager(pm) + impl.assignManager.Set(pm) + impl.logger.Info("recover PChannel Assignment Manager success") + return + } +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/stats/signal_notifier.go b/internal/streamingnode/server/wal/interceptors/segment/stats/signal_notifier.go new file mode 100644 index 0000000000..6b6ae07815 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/stats/signal_notifier.go @@ -0,0 +1,49 @@ +package stats + +import ( + "sync" + + "github.com/milvus-io/milvus/pkg/util/syncutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// NewSealSignalNotifier creates a new seal signal notifier.
+func NewSealSignalNotifier() *SealSignalNotifier { + return &SealSignalNotifier{ + cond: syncutil.NewContextCond(&sync.Mutex{}), + signal: typeutil.NewSet[SegmentBelongs](), + } +} + +// SealSignalNotifier is a notifier for seal signals. +type SealSignalNotifier struct { + cond *syncutil.ContextCond + signal typeutil.Set[SegmentBelongs] +} + +// AddAndNotify adds a signal and notifies the waiter. +func (n *SealSignalNotifier) AddAndNotify(belongs SegmentBelongs) { + n.cond.LockAndBroadcast() + n.signal.Insert(belongs) + n.cond.L.Unlock() +} + +func (n *SealSignalNotifier) WaitChan() <-chan struct{} { + n.cond.L.Lock() + if n.signal.Len() > 0 { + n.cond.L.Unlock() + ch := make(chan struct{}) + close(ch) + return ch + } + return n.cond.WaitChan() +} + +// Get gets the signal. +func (n *SealSignalNotifier) Get() typeutil.Set[SegmentBelongs] { + n.cond.L.Lock() + signal := n.signal + n.signal = typeutil.NewSet[SegmentBelongs]() + n.cond.L.Unlock() + return signal +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/stats/stats.go b/internal/streamingnode/server/wal/interceptors/segment/stats/stats.go new file mode 100644 index 0000000000..7d284ddeb6 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/stats/stats.go @@ -0,0 +1,83 @@ +package stats + +import ( + "time" + + "github.com/milvus-io/milvus/internal/proto/streamingpb" +) + +// SegmentStats is the usage stats of a segment. +type SegmentStats struct { + Insert InsertMetrics + MaxBinarySize uint64 // MaxBinarySize that can be assigned to the current segment; it's a fixed value set when the segment is transferred into growing. + CreateTime time.Time // created timestamp of this segment; it's a fixed value set when the segment is created, not a tso. + LastModifiedTime time.Time // LastModifiedTime is the last write time of this segment; it's not a tso, just a local time. + BinLogCounter uint64 // BinLogCounter is the counter of binlogs; it's an async stat, not real time. + ReachLimit bool // ReachLimit is a flag that indicates the segment has reached the limit at least once. +} + +// NewSegmentStatFromProto creates a new segment assignment stat from proto. +func NewSegmentStatFromProto(statProto *streamingpb.SegmentAssignmentStat) *SegmentStats { + if statProto == nil { + return nil + } + return &SegmentStats{ + Insert: InsertMetrics{ + Rows: statProto.InsertedRows, + BinarySize: statProto.InsertedBinarySize, + }, + MaxBinarySize: statProto.MaxBinarySize, + CreateTime: time.Unix(0, statProto.CreateTimestampNanoseconds), + BinLogCounter: statProto.BinlogCounter, + LastModifiedTime: time.Unix(0, statProto.LastModifiedTimestampNanoseconds), + } +} + +// NewProtoFromSegmentStat creates a new proto from segment assignment stat. +func NewProtoFromSegmentStat(stat *SegmentStats) *streamingpb.SegmentAssignmentStat { + if stat == nil { + return nil + } + return &streamingpb.SegmentAssignmentStat{ + MaxBinarySize: stat.MaxBinarySize, + InsertedRows: stat.Insert.Rows, + InsertedBinarySize: stat.Insert.BinarySize, + CreateTimestampNanoseconds: stat.CreateTime.UnixNano(), + BinlogCounter: stat.BinLogCounter, + LastModifiedTimestampNanoseconds: stat.LastModifiedTime.UnixNano(), + } +} + +// FlushOperationMetrics is the metrics of flush operation. +type FlushOperationMetrics struct { + BinLogCounter uint64 +} + +// AllocRows allocates space for rows on the current segment. +// Return true if the rows are assigned; false if the segment does not have enough space.
+func (s *SegmentStats) AllocRows(m InsertMetrics) bool { + if m.BinarySize > s.BinaryCanBeAssign() { + s.ReachLimit = true + return false + } + + s.Insert.Collect(m) + s.LastModifiedTime = time.Now() + return true +} + +// BinaryCanBeAssign returns the remaining binary size that can still be inserted. +func (s *SegmentStats) BinaryCanBeAssign() uint64 { + return s.MaxBinarySize - s.Insert.BinarySize +} + +// UpdateOnFlush updates the stats of segment on flush. +func (s *SegmentStats) UpdateOnFlush(f FlushOperationMetrics) { + s.BinLogCounter = f.BinLogCounter +} + +// Copy copies the segment stats. +func (s *SegmentStats) Copy() *SegmentStats { + s2 := *s + return &s2 +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go new file mode 100644 index 0000000000..f571717728 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager.go @@ -0,0 +1,162 @@ +package stats + +import ( + "fmt" + "sync" +) + +// StatsManager is the manager of stats. +// It manages the insert stats of all segments, used to check if a segment has enough space to insert or should be sealed. +// If lock contention becomes a problem, we can optimize it by applying a lock per segment. +type StatsManager struct { + mu sync.Mutex + totalStats InsertMetrics + pchannelStats map[string]*InsertMetrics + vchannelStats map[string]*InsertMetrics + segmentStats map[int64]*SegmentStats // map[SegmentID]SegmentStats + segmentIndex map[int64]SegmentBelongs // map[SegmentID]channels + sealNotifier *SealSignalNotifier +} + +type SegmentBelongs struct { + PChannel string + VChannel string + CollectionID int64 + PartitionID int64 +} + +// NewStatsManager creates a new stats manager. +func NewStatsManager() *StatsManager { + return &StatsManager{ + mu: sync.Mutex{}, + totalStats: InsertMetrics{}, + pchannelStats: make(map[string]*InsertMetrics), + vchannelStats: make(map[string]*InsertMetrics), + segmentStats: make(map[int64]*SegmentStats), + segmentIndex: make(map[int64]SegmentBelongs), + sealNotifier: NewSealSignalNotifier(), + } +} + +// RegisterNewGrowingSegment registers a new growing segment. +// It delegates the stats management to the stats manager. +func (m *StatsManager) RegisterNewGrowingSegment(belongs SegmentBelongs, segmentID int64, stats *SegmentStats) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.segmentStats[segmentID]; ok { + panic(fmt.Sprintf("register a segment %d that already exist, critical bug", segmentID)) + } + + m.segmentStats[segmentID] = stats + m.segmentIndex[segmentID] = belongs + m.totalStats.Collect(stats.Insert) + if _, ok := m.pchannelStats[belongs.PChannel]; !ok { + m.pchannelStats[belongs.PChannel] = &InsertMetrics{} + } + m.pchannelStats[belongs.PChannel].Collect(stats.Insert) + + if _, ok := m.vchannelStats[belongs.VChannel]; !ok { + m.vchannelStats[belongs.VChannel] = &InsertMetrics{} + } + m.vchannelStats[belongs.VChannel].Collect(stats.Insert) +} + +// AllocRows allocates a number of rows on the current segment. +func (m *StatsManager) AllocRows(segmentID int64, insert InsertMetrics) bool { + m.mu.Lock() + defer m.mu.Unlock() + + // Must exist, otherwise it's a bug. + info, ok := m.segmentIndex[segmentID] + if !ok { + panic(fmt.Sprintf("alloc rows on a segment %d that not exist", segmentID)) + } + inserted := m.segmentStats[segmentID].AllocRows(insert) + + // update the total stats if inserted.
+ if inserted { + m.totalStats.Collect(insert) + m.pchannelStats[info.PChannel].Collect(insert) + m.vchannelStats[info.VChannel].Collect(insert) + return true + } + + // If not inserted, the current segment cannot hold the message, so notify the seal manager to seal the segment. + m.sealNotifier.AddAndNotify(info) + return false +} + +// SealNotifier returns the seal notifier. +func (m *StatsManager) SealNotifier() *SealSignalNotifier { + // no lock here, because it's read only. + return m.sealNotifier +} + +// GetStatsOfSegment gets the stats of segment. +func (m *StatsManager) GetStatsOfSegment(segmentID int64) *SegmentStats { + m.mu.Lock() + defer m.mu.Unlock() + return m.segmentStats[segmentID].Copy() +} + +// UpdateOnFlush updates the stats of segment on flush. +// It's an async update operation, so it is not required to succeed; unknown segments are ignored. +func (m *StatsManager) UpdateOnFlush(segmentID int64, flush FlushOperationMetrics) { + m.mu.Lock() + defer m.mu.Unlock() + + // Silently ignore segments that are no longer registered. + if _, ok := m.segmentIndex[segmentID]; !ok { + return + } + m.segmentStats[segmentID].UpdateOnFlush(flush) + + // binlog counter is updated, notify the seal manager to do seal scanning. + m.sealNotifier.AddAndNotify(m.segmentIndex[segmentID]) +} + +// UnregisterSealedSegment unregisters the sealed segment. +func (m *StatsManager) UnregisterSealedSegment(segmentID int64) *SegmentStats { + m.mu.Lock() + defer m.mu.Unlock() + + // Must exist, otherwise it's a bug. + info, ok := m.segmentIndex[segmentID] + if !ok { + panic(fmt.Sprintf("unregister a segment %d that not exist, critical bug", segmentID)) + } + + stats := m.segmentStats[segmentID] + m.pchannelStats[info.PChannel].Subtract(stats.Insert) + m.vchannelStats[info.VChannel].Subtract(stats.Insert) + + m.totalStats.Subtract(stats.Insert) + delete(m.segmentStats, segmentID) + delete(m.segmentIndex, segmentID) + if m.pchannelStats[info.PChannel].BinarySize == 0 { + delete(m.pchannelStats, info.PChannel) + } + if m.vchannelStats[info.VChannel].BinarySize == 0 { + delete(m.vchannelStats, info.VChannel) + } + return stats +} + +// InsertMetrics is the metrics of an insert operation. +type InsertMetrics struct { + Rows uint64 + BinarySize uint64 +} + +// Collect collects other metrics. +func (m *InsertMetrics) Collect(other InsertMetrics) { + m.Rows += other.Rows + m.BinarySize += other.BinarySize +} + +// Subtract subtracts other metrics.
+func (m *InsertMetrics) Subtract(other InsertMetrics) { + m.Rows -= other.Rows + m.BinarySize -= other.BinarySize +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go new file mode 100644 index 0000000000..0a01abbb7c --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_manager_test.go @@ -0,0 +1,114 @@ +package stats + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStatsManager(t *testing.T) { + m := NewStatsManager() + + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 2}, 3, createSegmentStats(100, 100, 300)) + assert.Len(t, m.segmentStats, 1) + assert.Len(t, m.vchannelStats, 1) + assert.Len(t, m.pchannelStats, 1) + assert.Len(t, m.segmentIndex, 1) + + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 3}, 4, createSegmentStats(100, 100, 300)) + assert.Len(t, m.segmentStats, 2) + assert.Len(t, m.segmentIndex, 2) + assert.Len(t, m.vchannelStats, 1) + assert.Len(t, m.pchannelStats, 1) + + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel2", CollectionID: 2, PartitionID: 4}, 5, createSegmentStats(100, 100, 300)) + assert.Len(t, m.segmentStats, 3) + assert.Len(t, m.segmentIndex, 3) + assert.Len(t, m.vchannelStats, 2) + assert.Len(t, m.pchannelStats, 1) + + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel2", VChannel: "vchannel3", CollectionID: 2, PartitionID: 5}, 6, createSegmentStats(100, 100, 300)) + assert.Len(t, m.segmentStats, 4) + assert.Len(t, m.segmentIndex, 4) + assert.Len(t, m.vchannelStats, 3) + assert.Len(t, m.pchannelStats, 2) + + assert.Panics(t, func() { + m.RegisterNewGrowingSegment(SegmentBelongs{PChannel: "pchannel", VChannel: "vchannel", CollectionID: 1, PartitionID: 2}, 3, createSegmentStats(100, 100, 300)) + }) + + shouldBlock(t, m.SealNotifier().WaitChan()) + + m.AllocRows(3, InsertMetrics{Rows: 50, BinarySize: 50}) + stat := m.GetStatsOfSegment(3) + assert.Equal(t, uint64(150), stat.Insert.BinarySize) + + shouldBlock(t, m.SealNotifier().WaitChan()) + m.AllocRows(5, InsertMetrics{Rows: 250, BinarySize: 250}) + <-m.SealNotifier().WaitChan() + infos := m.SealNotifier().Get() + assert.Len(t, infos, 1) + + m.AllocRows(6, InsertMetrics{Rows: 150, BinarySize: 150}) + shouldBlock(t, m.SealNotifier().WaitChan()) + + assert.Equal(t, uint64(250), m.vchannelStats["vchannel3"].BinarySize) + assert.Equal(t, uint64(100), m.vchannelStats["vchannel2"].BinarySize) + assert.Equal(t, uint64(250), m.vchannelStats["vchannel"].BinarySize) + + assert.Equal(t, uint64(350), m.pchannelStats["pchannel"].BinarySize) + assert.Equal(t, uint64(250), m.pchannelStats["pchannel2"].BinarySize) + + m.UpdateOnFlush(3, FlushOperationMetrics{BinLogCounter: 100}) + <-m.SealNotifier().WaitChan() + infos = m.SealNotifier().Get() + assert.Len(t, infos, 1) + m.UpdateOnFlush(1000, FlushOperationMetrics{BinLogCounter: 100}) + shouldBlock(t, m.SealNotifier().WaitChan()) + + m.AllocRows(3, InsertMetrics{Rows: 400, BinarySize: 400}) + m.AllocRows(5, InsertMetrics{Rows: 250, BinarySize: 250}) + m.AllocRows(6, InsertMetrics{Rows: 400, BinarySize: 400}) + <-m.SealNotifier().WaitChan() + infos = m.SealNotifier().Get() + assert.Len(t, infos, 3) + + m.UnregisterSealedSegment(3) + m.UnregisterSealedSegment(4) + m.UnregisterSealedSegment(5) + 
m.UnregisterSealedSegment(6) + assert.Empty(t, m.segmentStats) + assert.Empty(t, m.vchannelStats) + assert.Empty(t, m.pchannelStats) + assert.Empty(t, m.segmentIndex) + + assert.Panics(t, func() { + m.AllocRows(100, InsertMetrics{Rows: 100, BinarySize: 100}) + }) + assert.Panics(t, func() { + m.UnregisterSealedSegment(1) + }) +} + +func createSegmentStats(row uint64, binarySize uint64, maxBinarSize uint64) *SegmentStats { + return &SegmentStats{ + Insert: InsertMetrics{ + Rows: row, + BinarySize: binarySize, + }, + MaxBinarySize: maxBinarSize, + CreateTime: time.Now(), + LastModifiedTime: time.Now(), + BinLogCounter: 0, + } +} + +func shouldBlock(t *testing.T, ch <-chan struct{}) { + select { + case <-ch: + t.Errorf("should block but not") + case <-time.After(10 * time.Millisecond): + return + } +} diff --git a/internal/streamingnode/server/wal/interceptors/segment/stats/stats_test.go b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_test.go new file mode 100644 index 0000000000..bdef19f136 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/segment/stats/stats_test.go @@ -0,0 +1,75 @@ +package stats + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStatsConvention(t *testing.T) { + assert.Nil(t, NewProtoFromSegmentStat(nil)) + stat := &SegmentStats{ + Insert: InsertMetrics{ + Rows: 1, + BinarySize: 2, + }, + MaxBinarySize: 2, + CreateTime: time.Now(), + LastModifiedTime: time.Now(), + BinLogCounter: 3, + } + pb := NewProtoFromSegmentStat(stat) + assert.Equal(t, stat.MaxBinarySize, pb.MaxBinarySize) + assert.Equal(t, stat.Insert.Rows, pb.InsertedRows) + assert.Equal(t, stat.Insert.BinarySize, pb.InsertedBinarySize) + assert.Equal(t, stat.CreateTime.UnixNano(), pb.CreateTimestampNanoseconds) + assert.Equal(t, stat.LastModifiedTime.UnixNano(), pb.LastModifiedTimestampNanoseconds) + assert.Equal(t, stat.BinLogCounter, pb.BinlogCounter) + + stat2 := NewSegmentStatFromProto(pb) + assert.Equal(t, stat.MaxBinarySize, stat2.MaxBinarySize) + assert.Equal(t, stat.Insert.Rows, stat2.Insert.Rows) + assert.Equal(t, stat.Insert.BinarySize, stat2.Insert.BinarySize) + assert.Equal(t, stat.CreateTime.UnixNano(), stat2.CreateTime.UnixNano()) + assert.Equal(t, stat.LastModifiedTime.UnixNano(), stat2.LastModifiedTime.UnixNano()) + assert.Equal(t, stat.BinLogCounter, stat2.BinLogCounter) +} + +func TestSegmentStats(t *testing.T) { + now := time.Now() + stat := &SegmentStats{ + Insert: InsertMetrics{ + Rows: 100, + BinarySize: 200, + }, + MaxBinarySize: 400, + CreateTime: now, + LastModifiedTime: now, + BinLogCounter: 3, + } + + insert1 := InsertMetrics{ + Rows: 60, + BinarySize: 120, + } + inserted := stat.AllocRows(insert1) + assert.True(t, inserted) + assert.Equal(t, stat.Insert.Rows, uint64(160)) + assert.Equal(t, stat.Insert.BinarySize, uint64(320)) + assert.True(t, time.Now().After(now)) + + insert1 = InsertMetrics{ + Rows: 100, + BinarySize: 100, + } + inserted = stat.AllocRows(insert1) + assert.False(t, inserted) + assert.Equal(t, stat.Insert.Rows, uint64(160)) + assert.Equal(t, stat.Insert.BinarySize, uint64(320)) + + stat.UpdateOnFlush(FlushOperationMetrics{ + BinLogCounter: 4, + }) + assert.Equal(t, uint64(4), stat.BinLogCounter) +} diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go index 55f9be181d..efbebe451a 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go +++ 
b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go @@ -19,7 +19,7 @@ func TestAck(t *testing.T) { ctx := context.Background() rc := idalloc.NewMockRootCoordClient(t) - resource.InitForTest(resource.OptRootCoordClient(rc)) + resource.InitForTest(t, resource.OptRootCoordClient(rc)) ackManager := NewAckManager() msgID := mock_message.NewMockMessageID(t) diff --git a/internal/streamingnode/server/walmanager/manager_impl_test.go b/internal/streamingnode/server/walmanager/manager_impl_test.go index 8907676e15..74953326c3 100644 --- a/internal/streamingnode/server/walmanager/manager_impl_test.go +++ b/internal/streamingnode/server/walmanager/manager_impl_test.go @@ -30,7 +30,8 @@ func TestManager(t *testing.T) { flusher := mock_flusher.NewMockFlusher(t) flusher.EXPECT().RegisterPChannel(mock.Anything, mock.Anything).Return(nil) - resource.Init( + resource.InitForTest( + t, resource.OptFlusher(flusher), resource.OptRootCoordClient(rootcoord), resource.OptDataCoordClient(datacoord), diff --git a/internal/streamingnode/server/walmanager/wal_lifetime_test.go b/internal/streamingnode/server/walmanager/wal_lifetime_test.go index 7ce53ab393..d34bfe4f88 100644 --- a/internal/streamingnode/server/walmanager/wal_lifetime_test.go +++ b/internal/streamingnode/server/walmanager/wal_lifetime_test.go @@ -25,7 +25,8 @@ func TestWALLifetime(t *testing.T) { flusher.EXPECT().RegisterPChannel(mock.Anything, mock.Anything).Return(nil) flusher.EXPECT().UnregisterPChannel(mock.Anything).Return() - resource.Init( + resource.InitForTest( + t, resource.OptFlusher(flusher), resource.OptRootCoordClient(rootcoord), resource.OptDataCoordClient(datacoord), diff --git a/pkg/streaming/util/message/builder.go b/pkg/streaming/util/message/builder.go index b3ef5feb5e..38283d4990 100644 --- a/pkg/streaming/util/message/builder.go +++ b/pkg/streaming/util/message/builder.go @@ -43,6 +43,7 @@ var ( NewDropCollectionMessageBuilderV1 = createNewMessageBuilderV1[*DropCollectionMessageHeader, *msgpb.DropCollectionRequest]() NewCreatePartitionMessageBuilderV1 = createNewMessageBuilderV1[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest]() NewDropPartitionMessageBuilderV1 = createNewMessageBuilderV1[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest]() + NewFlushMessageBuilderV1 = createNewMessageBuilderV1[*FlushMessageHeader, *FlushMessagePayload]() ) // createNewMessageBuilderV1 creates a new message builder with v1 marker. diff --git a/pkg/streaming/util/message/messagepb/message.proto b/pkg/streaming/util/message/messagepb/message.proto index 9783691740..102338641c 100644 --- a/pkg/streaming/util/message/messagepb/message.proto +++ b/pkg/streaming/util/message/messagepb/message.proto @@ -17,8 +17,8 @@ option go_package = "github.com/milvus-io/milvus/pkg/streaming/util/message/mess /// 7. DropPartitionRequest /// -// FlushMessagePayload is the payload of flush message. -message FlushMessagePayload { +// FlushMessageBody is the body of flush message. +message FlushMessageBody { int64 collection_id = 1; // indicate which the collection that segment belong to. repeated int64 segment_id = 2; // indicate which segment to flush. @@ -63,7 +63,8 @@ message FlushMessageHeader {} // CreateCollectionMessageHeader is the header of create collection message. message CreateCollectionMessageHeader { - int64 collection_id = 1; + int64 collection_id = 1; + repeated int64 partition_ids = 2; } // DropCollectionMessageHeader is the header of drop collection message. 
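The new NewFlushMessageBuilderV1 together with the renamed FlushMessageBody gives producers a typed way to emit flush messages. Below is a minimal sketch of how it might be used; only WithHeader, WithBody and WithVChannel appear in this diff, so the BuildMutable finalizer and the field names are assumptions based on the other V1 builders and the protobuf naming convention, not confirmed API.

package example

import (
	"github.com/milvus-io/milvus/pkg/streaming/util/message"
)

// buildFlushMessage sketches assembling a flush message for two segments of a
// collection. BuildMutable is an assumed finalizer mirroring the other V1
// builders; the collection and segment IDs are illustrative.
func buildFlushMessage() (message.MutableMessage, error) {
	msg, err := message.NewFlushMessageBuilderV1().
		WithHeader(&message.FlushMessageHeader{}).
		WithBody(&message.FlushMessagePayload{
			CollectionId: 1,                 // collection that owns the segments
			SegmentId:    []int64{100, 101}, // segments to flush
		}).
		BuildMutable() // assumed, not shown in this diff
	if err != nil {
		return nil, err
	}
	// Messages are routed by vchannel, mirroring the WithVChannel usage in test_case.go.
	return msg.WithVChannel("v1"), nil
}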
diff --git a/pkg/streaming/util/message/specialized_message.go b/pkg/streaming/util/message/specialized_message.go index 54c8496e69..3a079e3834 100644 --- a/pkg/streaming/util/message/specialized_message.go +++ b/pkg/streaming/util/message/specialized_message.go @@ -21,6 +21,8 @@ type ( DropCollectionMessageHeader = messagepb.DropCollectionMessageHeader CreatePartitionMessageHeader = messagepb.CreatePartitionMessageHeader DropPartitionMessageHeader = messagepb.DropPartitionMessageHeader + FlushMessageHeader = messagepb.FlushMessageHeader + FlushMessagePayload = messagepb.FlushMessageBody ) // messageTypeMap maps the proto message type to the message type. @@ -32,6 +34,7 @@ var messageTypeMap = map[reflect.Type]MessageType{ reflect.TypeOf(&DropCollectionMessageHeader{}): MessageTypeDropCollection, reflect.TypeOf(&CreatePartitionMessageHeader{}): MessageTypeCreatePartition, reflect.TypeOf(&DropPartitionMessageHeader{}): MessageTypeDropPartition, + reflect.TypeOf(&FlushMessageHeader{}): MessageTypeFlush, } // List all specialized message types. diff --git a/pkg/streaming/util/message/test_case.go b/pkg/streaming/util/message/test_case.go index 25ce304e58..3b6e655e4a 100644 --- a/pkg/streaming/util/message/test_case.go +++ b/pkg/streaming/util/message/test_case.go @@ -104,6 +104,7 @@ func CreateTestInsertMessage(t *testing.T, segmentID int64, totalRows int, timet func CreateTestCreateCollectionMessage(t *testing.T, collectionID int64, timetick uint64, messageID MessageID) MutableMessage { header := &CreateCollectionMessageHeader{ CollectionId: collectionID, + PartitionIds: []int64{2}, } payload := &msgpb.CreateCollectionRequest{ Base: &commonpb.MsgBase{ @@ -132,7 +133,16 @@ func CreateTestCreateCollectionMessage(t *testing.T, collectionID int64, timetic // CreateTestEmptyInsertMesage creates an empty insert message for testing func CreateTestEmptyInsertMesage(msgID int64, extraProperties map[string]string) MutableMessage { msg, err := NewInsertMessageBuilderV1(). - WithHeader(&InsertMessageHeader{}). + WithHeader(&InsertMessageHeader{ + CollectionId: 1, + Partitions: []*PartitionSegmentAssignment{ + { + PartitionId: 2, + Rows: 1000, + BinarySize: 1024 * 1024, + }, + }, + }). WithBody(&msgpb.InsertRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Insert, @@ -144,5 +154,5 @@ func CreateTestEmptyInsertMesage(msgID int64, extraProperties map[string]string) if err != nil { panic(err) } - return msg + return msg.WithVChannel("v1") } diff --git a/pkg/util/syncutil/context_condition_variable.go b/pkg/util/syncutil/context_condition_variable.go index 211253860a..5ca5a4a405 100644 --- a/pkg/util/syncutil/context_condition_variable.go +++ b/pkg/util/syncutil/context_condition_variable.go @@ -63,6 +63,18 @@ func (cv *ContextCond) Wait(ctx context.Context) error { return nil } +// WaitChan returns a channel that can be used to wait for a broadcast. +// Should be called after Lock. +// The channel is closed when a broadcast is received. +func (cv *ContextCond) WaitChan() <-chan struct{} { + if cv.ch == nil { + cv.ch = make(chan struct{}) + } + ch := cv.ch + cv.L.Unlock() + return ch +} + // noCopy may be added to structs which must not be copied // after the first use. 
// diff --git a/pkg/util/syncutil/context_condition_variable_test.go b/pkg/util/syncutil/context_condition_variable_test.go index 7988078478..4480283632 100644 --- a/pkg/util/syncutil/context_condition_variable_test.go +++ b/pkg/util/syncutil/context_condition_variable_test.go @@ -13,7 +13,7 @@ func TestContextCond(t *testing.T) { cv := NewContextCond(&sync.Mutex{}) cv.L.Lock() go func() { - time.Sleep(1 * time.Second) + time.Sleep(10 * time.Millisecond) cv.LockAndBroadcast() cv.L.Unlock() }() @@ -23,7 +23,7 @@ func TestContextCond(t *testing.T) { cv.L.Lock() go func() { - time.Sleep(1 * time.Second) + time.Sleep(20 * time.Millisecond) cv.LockAndBroadcast() cv.L.Unlock() }() diff --git a/pkg/util/syncutil/versioned_notifier.go b/pkg/util/syncutil/versioned_notifier.go index c1e48e134a..61bcb11250 100644 --- a/pkg/util/syncutil/versioned_notifier.go +++ b/pkg/util/syncutil/versioned_notifier.go @@ -78,3 +78,30 @@ func (vl *VersionedListener) Wait(ctx context.Context) error { vl.inner.cond.L.Unlock() return nil } + +// WaitChan returns a channel that will be closed when the next notification comes. +// Use Sync to sync the listener to the latest version to avoid redundant notify. +// +// ch := vl.WaitChan() +// <-ch +// vl.Sync() +// ... make use of the notification ... +func (vl *VersionedListener) WaitChan() <-chan struct{} { + vl.inner.cond.L.Lock() + // Return a closed channel if the version is newer than the last notified version. + if vl.lastNotifiedVersion < vl.inner.version { + vl.lastNotifiedVersion = vl.inner.version + vl.inner.cond.L.Unlock() + ch := make(chan struct{}) + close(ch) + return ch + } + return vl.inner.cond.WaitChan() +} + +// Sync syncs the listener to the latest version. +func (vl *VersionedListener) Sync() { + vl.inner.cond.L.Lock() + vl.lastNotifiedVersion = vl.inner.version + vl.inner.cond.L.Unlock() +} diff --git a/pkg/util/syncutil/versioned_notifier_test.go b/pkg/util/syncutil/versioned_notifier_test.go index 98497fd1d5..161df09eb7 100644 --- a/pkg/util/syncutil/versioned_notifier_test.go +++ b/pkg/util/syncutil/versioned_notifier_test.go @@ -13,6 +13,7 @@ func TestLatestVersionedNotifier(t *testing.T) { // Create a listener at the latest version listener := vn.Listen(VersionedListenAtLatest) + useWaitChanListener := vn.Listen(VersionedListenAtLatest) // Start a goroutine to wait for the notification done := make(chan struct{}) @@ -24,8 +25,15 @@ func TestLatestVersionedNotifier(t *testing.T) { close(done) }() + done2 := make(chan struct{}) + go func() { + ch := useWaitChanListener.WaitChan() + <-ch + close(done2) + }() + // Should be blocked. - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() select { case <-done: @@ -38,6 +46,7 @@ func TestLatestVersionedNotifier(t *testing.T) { // Wait for the goroutine to finish <-done + <-done2 } func TestEarliestVersionedNotifier(t *testing.T) { @@ -45,6 +54,7 @@ func TestEarliestVersionedNotifier(t *testing.T) { // Create a listener at the latest version listener := vn.Listen(VersionedListenAtEarliest) + useWaitChanListener := vn.Listen(VersionedListenAtLatest) // Should be non-blocked. err := listener.Wait(context.Background()) @@ -60,21 +70,50 @@ func TestEarliestVersionedNotifier(t *testing.T) { close(done) }() + done2 := make(chan struct{}) + go func() { + ch := useWaitChanListener.WaitChan() + <-ch + close(done2) + }() + // Should be blocked. 
- ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() select { case <-done: t.Errorf("Wait returned before NotifyAll") + case <-done2: + t.Errorf("WaitChan returned before NotifyAll") case <-ctx.Done(): } + + // Notify all listeners + vn.NotifyAll() + + // Wait for the goroutine to finish + <-done + <-done2 + + // should not be blocked + useWaitChanListener = vn.Listen(VersionedListenAtEarliest) + <-useWaitChanListener.WaitChan() + + // should blocked + useWaitChanListener = vn.Listen(VersionedListenAtEarliest) + useWaitChanListener.Sync() + select { + case <-time.After(10 * time.Millisecond): + case <-useWaitChanListener.WaitChan(): + t.Errorf("WaitChan returned before NotifyAll") + } } func TestTimeoutListeningVersionedNotifier(t *testing.T) { vn := NewVersionedNotifier() listener := vn.Listen(VersionedListenAtLatest) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() err := listener.Wait(ctx) assert.Error(t, err) diff --git a/pkg/util/typeutil/backoff_timer.go b/pkg/util/typeutil/backoff_timer.go new file mode 100644 index 0000000000..dd26b136fe --- /dev/null +++ b/pkg/util/typeutil/backoff_timer.go @@ -0,0 +1,96 @@ +package typeutil + +import ( + "time" + + "github.com/cenkalti/backoff/v4" +) + +var _ BackoffTimerConfigFetcher = BackoffTimerConfig{} + +// BackoffTimerConfigFetcher is the interface to fetch backoff timer configuration +type BackoffTimerConfigFetcher interface { + DefaultInterval() time.Duration + BackoffConfig() BackoffConfig +} + +// BackoffTimerConfig is the configuration for backoff timer +// It's also used to be const config fetcher. +// Every DefaultInterval is a fetch loop. 
+type BackoffTimerConfig struct { + Default time.Duration + Backoff BackoffConfig +} + +// BackoffConfig is the configuration for backoff +type BackoffConfig struct { + InitialInterval time.Duration + Multiplier float64 + MaxInterval time.Duration +} + +func (c BackoffTimerConfig) DefaultInterval() time.Duration { + return c.Default +} + +func (c BackoffTimerConfig) BackoffConfig() BackoffConfig { + return c.Backoff +} + +// NewBackoffTimer creates a new BackoffTimer +func NewBackoffTimer(configFetcher BackoffTimerConfigFetcher) *BackoffTimer { + return &BackoffTimer{ + configFetcher: configFetcher, + backoff: nil, + } +} + +// BackoffTimer is a timer that supports optional exponential backoff +type BackoffTimer struct { + configFetcher BackoffTimerConfigFetcher + backoff *backoff.ExponentialBackOff +} + +// EnableBackoff enables the backoff +func (t *BackoffTimer) EnableBackoff() { + if t.backoff == nil { + cfg := t.configFetcher.BackoffConfig() + defaultInterval := t.configFetcher.DefaultInterval() + backoff := backoff.NewExponentialBackOff() + backoff.InitialInterval = cfg.InitialInterval + backoff.Multiplier = cfg.Multiplier + backoff.MaxInterval = cfg.MaxInterval + backoff.MaxElapsedTime = defaultInterval + backoff.Stop = defaultInterval + backoff.Reset() + t.backoff = backoff + } +} + +// DisableBackoff disables the backoff +func (t *BackoffTimer) DisableBackoff() { + t.backoff = nil +} + +// IsBackoffStopped returns whether the backoff has exceeded its max elapsed time +func (t *BackoffTimer) IsBackoffStopped() bool { + if t.backoff != nil { + return t.backoff.GetElapsedTime() > t.backoff.MaxElapsedTime + } + return true +} + +// NextTimer returns the next timer and the duration of the timer +func (t *BackoffTimer) NextTimer() (<-chan time.Time, time.Duration) { + nextBackoff := t.NextInterval() + return time.After(nextBackoff), nextBackoff +} + +// NextInterval returns the next interval +func (t *BackoffTimer) NextInterval() time.Duration { + // if the backoff is enabled, use backoff + if t.backoff != nil { + return t.backoff.NextBackOff() + } + return t.configFetcher.DefaultInterval() +} diff --git a/pkg/util/typeutil/backoff_timer_test.go b/pkg/util/typeutil/backoff_timer_test.go new file mode 100644 index 0000000000..ddc11c933e --- /dev/null +++ b/pkg/util/typeutil/backoff_timer_test.go @@ -0,0 +1,44 @@ +package typeutil + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBackoffTimer(t *testing.T) { + b := NewBackoffTimer(BackoffTimerConfig{ + Default: time.Second, + Backoff: BackoffConfig{ + InitialInterval: 50 * time.Millisecond, + Multiplier: 2, + MaxInterval: 200 * time.Millisecond, + }, + }) + + for i := 0; i < 2; i++ { + assert.Equal(t, time.Second, b.NextInterval()) + assert.Equal(t, time.Second, b.NextInterval()) + assert.Equal(t, time.Second, b.NextInterval()) + assert.True(t, b.IsBackoffStopped()) + + b.EnableBackoff() + assert.False(t, b.IsBackoffStopped()) + timer, backoff := b.NextTimer() + assert.Less(t, backoff, 200*time.Millisecond) + for { + <-timer + if b.IsBackoffStopped() { + break + } + timer, _ = b.NextTimer() + } + assert.True(t, b.IsBackoffStopped()) + + assert.Equal(t, time.Second, b.NextInterval()) + b.DisableBackoff() + assert.Equal(t, time.Second, b.NextInterval()) + assert.True(t, b.IsBackoffStopped()) + } +}
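BackoffTimer is the piece that recoverPChannelManager uses to pace its retries. A minimal sketch of that retry pattern follows, using only the APIs introduced in this diff and the same backoff configuration as the interceptor; the work function is a hypothetical stand-in for an operation such as recovering the pchannel segment allocation manager.

package example

import (
	"context"
	"time"

	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// retryWithBackoff retries work until it succeeds or ctx is cancelled,
// sleeping between attempts according to the exponential backoff above.
// work is a hypothetical stand-in for a recoverable operation.
func retryWithBackoff(ctx context.Context, work func(context.Context) error) error {
	timer := typeutil.NewBackoffTimer(typeutil.BackoffTimerConfig{
		Default: time.Second,
		Backoff: typeutil.BackoffConfig{
			InitialInterval: 10 * time.Millisecond,
			Multiplier:      2.0,
			MaxInterval:     time.Second,
		},
	})
	timer.EnableBackoff()
	for {
		if err := work(ctx); err == nil {
			return nil
		}
		// NextTimer returns a channel that fires after the next backoff interval.
		ch, _ := timer.NextTimer()
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ch:
		}
	}
}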