diff --git a/internal/streamingnode/server/wal/interceptors/lock/builder.go b/internal/streamingnode/server/wal/interceptors/lock/builder.go new file mode 100644 index 0000000000..6bc933474a --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/lock/builder.go @@ -0,0 +1,23 @@ +package lock + +import ( + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" + "github.com/milvus-io/milvus/pkg/v2/util/lock" +) + +// NewInterceptorBuilder creates a new redo interceptor builder. +// TODO: add it into wal after recovery storage is merged. +func NewInterceptorBuilder() interceptors.InterceptorBuilder { + return &interceptorBuilder{} +} + +// interceptorBuilder is the builder for redo interceptor. +type interceptorBuilder struct{} + +// Build creates a new redo interceptor. +func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) interceptors.Interceptor { + return &lockAppendInterceptor{ + vchannelLocker: lock.NewKeyLock[string](), + // TODO: txnManager will be intiailized by param txnManager: param.TxnManager, + } +} diff --git a/internal/streamingnode/server/wal/interceptors/lock/lock_interceptor.go b/internal/streamingnode/server/wal/interceptors/lock/lock_interceptor.go new file mode 100644 index 0000000000..d8f0c63b84 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/lock/lock_interceptor.go @@ -0,0 +1,50 @@ +package lock + +import ( + "context" + + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + "github.com/milvus-io/milvus/pkg/v2/util/lock" +) + +type lockAppendInterceptor struct { + vchannelLocker *lock.KeyLock[string] + txnManager *txn.TxnManager +} + +func (r *lockAppendInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (message.MessageID, error) { + g := r.acquireLockGuard(ctx, msg) + defer g() + + return append(ctx, msg) +} + +// acquireLockGuard acquires the lock for the vchannel and return a function as a guard. +func (r *lockAppendInterceptor) acquireLockGuard(_ context.Context, msg message.MutableMessage) func() { + // Acquire the write lock for the vchannel. + vchannel := msg.VChannel() + if msg.MessageType().IsExclusiveRequired() { + r.vchannelLocker.Lock(vchannel) + return func() { + // For exclusive messages, we need to fail all transactions at the vchannel. + // Otherwise, the transaction message may cross the exclusive message. + // e.g. an exclusive message like `ManualFlush` happens, it will flush all the growing segment. + // But the transaction insert message that use those segments may not be committed, + // if we allow it to be committed, a insert message can be seen after the manual flush message, lead to the wrong wal message order. + // So we need to fail all transactions at the vchannel, it will be retried at client side with new txn. + // + // the append operation of exclusive message should be low rate, so it's acceptable to fail all transactions at the vchannel. + r.txnManager.FailTxnAtVChannel(vchannel) + r.vchannelLocker.Unlock(vchannel) + } + } + r.vchannelLocker.RLock(vchannel) + return func() { + r.vchannelLocker.RUnlock(vchannel) + } +} + +// Close the interceptor release all the resources. +func (r *lockAppendInterceptor) Close() {} diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go index 33365c077e..ce32b9588e 100644 --- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go +++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -112,8 +113,16 @@ func TestSegmentAllocManager(t *testing.T) { assert.True(t, m.IsNoWaitSeal()) // result2 is acked, so new seal segment will be sealed right away. // interactive with txn - txnManager := txn.NewTxnManager(types.PChannelInfo{Name: "test"}) - txn, err := txnManager.BeginNewTxn(context.Background(), tsoutil.GetCurrentTime(), time.Second) + txnManager := txn.NewTxnManager(types.PChannelInfo{Name: "test"}, nil) + msg := message.NewBeginTxnMessageBuilderV2(). + WithVChannel("v1"). + WithHeader(&message.BeginTxnMessageHeader{KeepaliveMilliseconds: 1000}). + WithBody(&message.BeginTxnMessageBody{}). + MustBuildMutable(). + WithTimeTick(tsoutil.GetCurrentTime()) + + beginTxnMsg, _ := message.AsMutableBeginTxnMessageV2(msg) + txn, err := txnManager.BeginNewTxn(ctx, beginTxnMsg) assert.NoError(t, err) txn.BeginDone() diff --git a/internal/streamingnode/server/wal/interceptors/timetick/builder.go b/internal/streamingnode/server/wal/interceptors/timetick/builder.go index aab721febd..e88f7b02aa 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/builder.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/builder.go @@ -23,8 +23,10 @@ func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) in operator := newTimeTickSyncOperator(param) // initialize operation can be async to avoid block the build operation. resource.Resource().TimeTickInspector().RegisterSyncOperator(operator) + return &timeTickAppendInterceptor{ - operator: operator, - txnManager: txn.NewTxnManager(param.ChannelInfo), + operator: operator, + // TODO: it's just a placeholder, should be replaced after recovery storage is merged. + txnManager: txn.NewTxnManager(param.ChannelInfo, nil), } } diff --git a/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go b/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go index faf916d8d5..82716ba06e 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go @@ -145,7 +145,7 @@ func (impl *timeTickAppendInterceptor) handleBegin(ctx context.Context, msg mess return nil, nil, err } // Begin transaction will generate a txn context. - session, err := impl.txnManager.BeginNewTxn(ctx, msg.TimeTick(), time.Duration(beginTxnMsg.Header().KeepaliveMilliseconds)*time.Millisecond) + session, err := impl.txnManager.BeginNewTxn(ctx, beginTxnMsg) if err != nil { session.BeginRollback() return nil, nil, err diff --git a/internal/streamingnode/server/wal/interceptors/txn/session.go b/internal/streamingnode/server/wal/interceptors/txn/session.go index 80afff6d0e..702c8c8b00 100644 --- a/internal/streamingnode/server/wal/interceptors/txn/session.go +++ b/internal/streamingnode/server/wal/interceptors/txn/session.go @@ -14,10 +14,30 @@ type txnSessionKeyType int var txnSessionKeyValue txnSessionKeyType = 1 +// newTxnSession creates a new transaction session. +func newTxnSession( + vchannel string, + txnContext message.TxnContext, + timetick uint64, + metricsGuard *metricsutil.TxnMetricsGuard, +) *TxnSession { + return &TxnSession{ + mu: sync.Mutex{}, + vchannel: vchannel, + lastTimetick: timetick, + txnContext: txnContext, + inFlightCount: 0, + state: message.TxnStateBegin, + doneWait: nil, + rollback: false, + metricsGuard: metricsGuard, + } +} + // TxnSession is a session for a transaction. type TxnSession struct { - mu sync.Mutex - + mu sync.Mutex + vchannel string // The vchannel of the session. lastTimetick uint64 // session last timetick. expired bool // The flag indicates the transaction has trigger expired once. txnContext message.TxnContext // transaction id of the session @@ -29,6 +49,11 @@ type TxnSession struct { metricsGuard *metricsutil.TxnMetricsGuard // The metrics guard for the session. } +// VChannel returns the vchannel of the session. +func (s *TxnSession) VChannel() string { + return s.vchannel +} + // TxnContext returns the txn context of the session. func (s *TxnSession) TxnContext() message.TxnContext { return s.txnContext diff --git a/internal/streamingnode/server/wal/interceptors/txn/session_test.go b/internal/streamingnode/server/wal/interceptors/txn/session_test.go index 7e07821d4b..334818a462 100644 --- a/internal/streamingnode/server/wal/interceptors/txn/session_test.go +++ b/internal/streamingnode/server/wal/interceptors/txn/session_test.go @@ -9,11 +9,13 @@ import ( "github.com/stretchr/testify/assert" "go.uber.org/atomic" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" + "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" ) @@ -27,8 +29,11 @@ func TestSession(t *testing.T) { resource.InitForTest(t) ctx := context.Background() - m := NewTxnManager(types.PChannelInfo{Name: "test"}) - session, err := m.BeginNewTxn(ctx, 0, 10*time.Millisecond) + m := NewTxnManager(types.PChannelInfo{Name: "test"}, nil) + <-m.RecoverDone() + session, err := m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond)) + assert.Equal(t, session.VChannel(), "v1") + assert.Equal(t, session.State(), message.TxnStateBegin) assert.NotNil(t, session) assert.NoError(t, err) @@ -41,7 +46,7 @@ func TestSession(t *testing.T) { assert.Equal(t, message.TxnStateRollbacked, session.state) assert.True(t, session.IsExpiredOrDone(0)) - session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond) + session, err = m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond)) assert.NoError(t, err) session.BeginDone() assert.Equal(t, message.TxnStateInFlight, session.state) @@ -59,7 +64,7 @@ func TestSession(t *testing.T) { serr = status.AsStreamingError(err) assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code) - session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond) + session, err = m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond)) assert.NoError(t, err) session.BeginDone() assert.NoError(t, err) @@ -75,7 +80,7 @@ func TestSession(t *testing.T) { assert.Equal(t, message.TxnStateCommitted, session.state) // Test Commit timeout. - session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond) + session, err = m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond)) assert.NoError(t, err) session.BeginDone() err = session.AddNewMessage(ctx, 0) @@ -94,7 +99,7 @@ func TestSession(t *testing.T) { assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code) // Test Rollback - session, _ = m.BeginNewTxn(context.Background(), 0, 10*time.Millisecond) + session, _ = m.BeginNewTxn(context.Background(), newBeginTxnMessage(0, 10*time.Millisecond)) session.BeginDone() // Rollback expired. err = session.RequestRollback(context.Background(), expiredTs) @@ -103,7 +108,7 @@ func TestSession(t *testing.T) { assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code) // Rollback success - session, _ = m.BeginNewTxn(context.Background(), 0, 10*time.Millisecond) + session, _ = m.BeginNewTxn(context.Background(), newBeginTxnMessage(0, 10*time.Millisecond)) session.BeginDone() err = session.RequestRollback(context.Background(), 0) assert.NoError(t, err) @@ -112,7 +117,7 @@ func TestSession(t *testing.T) { func TestManager(t *testing.T) { resource.InitForTest(t) - m := NewTxnManager(types.PChannelInfo{Name: "test"}) + m := NewTxnManager(types.PChannelInfo{Name: "test"}, nil) wg := &sync.WaitGroup{} @@ -121,7 +126,7 @@ func TestManager(t *testing.T) { for i := 0; i < 20; i++ { go func(i int) { defer wg.Done() - session, err := m.BeginNewTxn(context.Background(), 0, time.Duration(i+1)*time.Millisecond) + session, err := m.BeginNewTxn(context.Background(), newBeginTxnMessage(0, time.Duration(i+1)*time.Millisecond)) assert.NoError(t, err) assert.NotNil(t, session) session.BeginDone() @@ -176,6 +181,48 @@ func TestManager(t *testing.T) { assert.Equal(t, int32(0), count.Load()) } +func TestManagerRecoverAndFailAll(t *testing.T) { + resource.InitForTest(t) + now := time.Now() + beginMsg1 := newImmutableBeginTxnMessageWithVChannel("v1", 1, tsoutil.ComposeTSByTime(now, 0), 10*time.Millisecond) + beginMsg2 := newImmutableBeginTxnMessageWithVChannel("v2", 2, tsoutil.ComposeTSByTime(now, 1), 10*time.Millisecond) + beginMsg3 := newImmutableBeginTxnMessageWithVChannel("v1", 3, tsoutil.ComposeTSByTime(now, 2), 10*time.Millisecond) + builders := map[message.TxnID]*message.ImmutableTxnMessageBuilder{ + message.TxnID(1): message.NewImmutableTxnMessageBuilder(beginMsg1), + message.TxnID(2): message.NewImmutableTxnMessageBuilder(beginMsg2), + message.TxnID(3): message.NewImmutableTxnMessageBuilder(beginMsg3), + } + + builders[message.TxnID(1)].Add(message.NewInsertMessageBuilderV1(). + WithVChannel("v1"). + WithHeader(&message.InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{}). + MustBuildMutable(). + WithTimeTick(tsoutil.ComposeTSByTime(now, 3)). + WithLastConfirmedUseMessageID(). + WithTxnContext(message.TxnContext{ + TxnID: message.TxnID(1), + Keepalive: 10 * time.Millisecond, + }). + IntoImmutableMessage(rmq.NewRmqID(1))) + + m := NewTxnManager(types.PChannelInfo{Name: "test"}, builders) + select { + case <-m.RecoverDone(): + t.Errorf("txn manager should not be recovered") + case <-time.After(1 * time.Millisecond): + } + + m.FailTxnAtVChannel("v1") + select { + case <-m.RecoverDone(): + t.Errorf("txn manager should not be recovered") + case <-time.After(1 * time.Millisecond): + } + m.FailTxnAtVChannel("v2") + <-m.RecoverDone() +} + func TestWithContext(t *testing.T) { session := &TxnSession{} ctx := WithTxnSession(context.Background(), session) @@ -183,3 +230,38 @@ func TestWithContext(t *testing.T) { session = GetTxnSessionFromContext(ctx) assert.NotNil(t, session) } + +func newBeginTxnMessage(timetick uint64, keepalive time.Duration) message.MutableBeginTxnMessageV2 { + return newBeginTxnMessageWithVChannel("v1", timetick, keepalive) +} + +func newBeginTxnMessageWithVChannel(vchannel string, timetick uint64, keepalive time.Duration) message.MutableBeginTxnMessageV2 { + msg := message.NewBeginTxnMessageBuilderV2(). + WithVChannel(vchannel). + WithHeader(&message.BeginTxnMessageHeader{KeepaliveMilliseconds: keepalive.Milliseconds()}). + WithBody(&message.BeginTxnMessageBody{}). + MustBuildMutable(). + WithTimeTick(timetick) + + beginTxnMsg, _ := message.AsMutableBeginTxnMessageV2(msg) + return beginTxnMsg +} + +func newImmutableBeginTxnMessageWithVChannel(vchannel string, txnID int64, timetick uint64, keepalive time.Duration) message.ImmutableBeginTxnMessageV2 { + msg := message.NewBeginTxnMessageBuilderV2(). + WithVChannel(vchannel). + WithHeader(&message.BeginTxnMessageHeader{ + KeepaliveMilliseconds: keepalive.Milliseconds(), + }). + WithBody(&message.BeginTxnMessageBody{}). + MustBuildMutable(). + WithTimeTick(timetick). + WithLastConfirmed(rmq.NewRmqID(1)). + WithLastConfirmedUseMessageID(). + WithTxnContext(message.TxnContext{ + TxnID: message.TxnID(txnID), + Keepalive: keepalive, + }). + IntoImmutableMessage(rmq.NewRmqID(1)) + return message.MustAsImmutableBeginTxnMessageV2(msg) +} diff --git a/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go b/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go index e965d9c3f4..f4676a3250 100644 --- a/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go +++ b/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go @@ -18,31 +18,69 @@ import ( ) // NewTxnManager creates a new transaction manager. -func NewTxnManager(pchannel types.PChannelInfo) *TxnManager { - return &TxnManager{ - mu: sync.Mutex{}, - sessions: make(map[message.TxnID]*TxnSession), - closed: nil, - metrics: metricsutil.NewTxnMetrics(pchannel.Name), - logger: resource.Resource().Logger().With(log.FieldComponent("txn-manager")), +// incoming buffer is used to recover the uncommitted messages for txn manager. +func NewTxnManager(pchannel types.PChannelInfo, uncommittedTxnBuilders map[message.TxnID]*message.ImmutableTxnMessageBuilder) *TxnManager { + m := metricsutil.NewTxnMetrics(pchannel.Name) + sessions := make(map[message.TxnID]*TxnSession, len(uncommittedTxnBuilders)) + recoveredSessions := make(map[message.TxnID]struct{}, len(uncommittedTxnBuilders)) + sessionIDs := make([]int64, 0, len(uncommittedTxnBuilders)) + for _, builder := range uncommittedTxnBuilders { + beginMessages, body := builder.Messages() + session := newTxnSession( + beginMessages.VChannel(), + *beginMessages.TxnContext(), // must be the txn message. + beginMessages.TimeTick(), + m.BeginTxn(), + ) + for _, msg := range body { + session.AddNewMessage(context.Background(), msg.TimeTick()) + session.AddNewMessageDoneAndKeepalive(msg.TimeTick()) + } + sessions[session.TxnContext().TxnID] = session + recoveredSessions[session.TxnContext().TxnID] = struct{}{} + sessionIDs = append(sessionIDs, int64(session.TxnContext().TxnID)) } + txnManager := &TxnManager{ + mu: sync.Mutex{}, + recoveredSessions: recoveredSessions, + recoveredSessionsDoneChan: make(chan struct{}), + sessions: sessions, + closed: nil, + metrics: m, + } + txnManager.notifyRecoverDone() + txnManager.SetLogger(resource.Resource().Logger().With(log.FieldComponent("txn-manager"))) + txnManager.Logger().Info("txn manager recovered with txn", zap.Int64s("txnIDs", sessionIDs)) + return txnManager } // TxnManager is the manager of transactions. // We don't support cross wal transaction by now and // We don't support the transaction lives after the wal transferred to another streaming node. type TxnManager struct { - mu sync.Mutex - sessions map[message.TxnID]*TxnSession - closed lifetime.SafeChan - metrics *metricsutil.TxnMetrics - logger *log.MLogger + log.Binder + + mu sync.Mutex + recoveredSessions map[message.TxnID]struct{} + recoveredSessionsDoneChan chan struct{} + sessions map[message.TxnID]*TxnSession + closed lifetime.SafeChan + metrics *metricsutil.TxnMetrics +} + +// RecoverDone returns a channel that is closed when all transactions are cleaned up. +func (m *TxnManager) RecoverDone() <-chan struct{} { + return m.recoveredSessionsDoneChan } // BeginNewTxn starts a new transaction with a session. // We only support a transaction work on a streaming node, once the wal is transferred to another node, // the transaction is treated as expired (rollback), and user will got a expired error, then perform a retry. -func (m *TxnManager) BeginNewTxn(ctx context.Context, timetick uint64, keepalive time.Duration) (*TxnSession, error) { +func (m *TxnManager) BeginNewTxn(ctx context.Context, msg message.MutableBeginTxnMessageV2) (*TxnSession, error) { + timetick := msg.TimeTick() + vchannel := msg.VChannel() + keepalive := time.Duration(msg.Header().KeepaliveMilliseconds) * time.Millisecond + if keepalive == 0 { // If keepalive is 0, the txn set the keepalive with default keepalive. keepalive = paramtable.Get().StreamingCfg.TxnDefaultKeepaliveTimeout.GetAsDurationByParse() @@ -62,24 +100,35 @@ func (m *TxnManager) BeginNewTxn(ctx context.Context, timetick uint64, keepalive if m.closed != nil { return nil, status.NewTransactionExpired("manager closed") } - metricsGuard := m.metrics.BeginTxn() - session := &TxnSession{ - mu: sync.Mutex{}, - lastTimetick: timetick, - txnContext: message.TxnContext{ - TxnID: message.TxnID(id), - Keepalive: keepalive, - }, - inFlightCount: 0, - state: message.TxnStateBegin, - doneWait: nil, - rollback: false, - metricsGuard: metricsGuard, + txnCtx := message.TxnContext{ + TxnID: message.TxnID(id), + Keepalive: keepalive, } + session := newTxnSession(vchannel, txnCtx, timetick, m.metrics.BeginTxn()) m.sessions[session.TxnContext().TxnID] = session return session, nil } +// FailTxnAtVChannel fails all transactions at the specified vchannel. +func (m *TxnManager) FailTxnAtVChannel(vchannel string) { + // avoid the txn to be committed. + m.mu.Lock() + defer m.mu.Unlock() + ids := make([]int64, 0, len(m.sessions)) + for id, session := range m.sessions { + if session.VChannel() == vchannel { + session.Cleanup() + delete(m.sessions, id) + delete(m.recoveredSessions, id) + ids = append(ids, int64(id)) + } + } + if len(ids) > 0 { + m.Logger().Info("transaction interrupted", zap.String("vchannel", vchannel), zap.Int64s("txnIDs", ids)) + } + m.notifyRecoverDone() +} + // CleanupTxnUntil cleans up the transactions until the specified timestamp. func (m *TxnManager) CleanupTxnUntil(ts uint64) { m.mu.Lock() @@ -89,6 +138,7 @@ func (m *TxnManager) CleanupTxnUntil(ts uint64) { if session.IsExpiredOrDone(ts) { session.Cleanup() delete(m.sessions, id) + delete(m.recoveredSessions, id) } } @@ -96,6 +146,16 @@ func (m *TxnManager) CleanupTxnUntil(ts uint64) { if len(m.sessions) == 0 && m.closed != nil { m.closed.Close() } + + m.notifyRecoverDone() +} + +// notifyRecoverDone notifies the recover done channel if all transactions from recover info is done. +func (m *TxnManager) notifyRecoverDone() { + if len(m.recoveredSessions) == 0 && m.recoveredSessions != nil { + close(m.recoveredSessionsDoneChan) + m.recoveredSessions = nil + } } // GetSessionOfTxn returns the session of the transaction. @@ -121,7 +181,7 @@ func (m *TxnManager) GracefulClose(ctx context.Context) error { m.closed.Close() } } - m.logger.Info("there's still txn session in txn manager, waiting for them to be consumed", zap.Int("session count", len(m.sessions))) + m.Logger().Info("graceful close txn manager", zap.Int("activeTxnCount", len(m.sessions))) m.mu.Unlock() select { diff --git a/internal/streamingnode/server/wal/utility/txn_buffer.go b/internal/streamingnode/server/wal/utility/txn_buffer.go index 20f9351bf9..aa5ac2eefb 100644 --- a/internal/streamingnode/server/wal/utility/txn_buffer.go +++ b/internal/streamingnode/server/wal/utility/txn_buffer.go @@ -29,6 +29,11 @@ func (b *TxnBuffer) Bytes() int { return b.bytes } +// GetUncommittedMessageBuilder returns the uncommitted message builders. +func (b *TxnBuffer) GetUncommittedMessageBuilder() map[message.TxnID]*message.ImmutableTxnMessageBuilder { + return b.builders +} + // HandleImmutableMessages handles immutable messages. // The timetick of msgs should be in ascending order, and the timetick of all messages is less than or equal to ts. // Hold the uncommitted txn messages until the commit or rollback message comes and pop the committed txn messages. diff --git a/internal/streamingnode/server/wal/utility/txn_buffer_test.go b/internal/streamingnode/server/wal/utility/txn_buffer_test.go index 364c0aceca..50e4346a2d 100644 --- a/internal/streamingnode/server/wal/utility/txn_buffer_test.go +++ b/internal/streamingnode/server/wal/utility/txn_buffer_test.go @@ -72,6 +72,7 @@ func TestTxnBuffer(t *testing.T) { assert.Len(t, msgs, 1) } createUnCommitted() + assert.Len(t, b.GetUncommittedMessageBuilder(), 1) msgs = b.HandleImmutableMessages([]message.ImmutableMessage{ newCommitMessage(t, txnCtx, tsoutil.AddPhysicalDurationOnTs(baseTso, 500*time.Millisecond)), }, tsoutil.AddPhysicalDurationOnTs(baseTso, 600*time.Millisecond))