enhance: add lock interceptor and recoverable txn manager (#41640)

issue: #41544

- add a lock interceptor at vchannel granularity.
- make txn manager recoverable and add FailTxnAtVChannel operation.

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-05-09 11:14:53 +08:00 committed by GitHub
parent 44c0799331
commit 3dd9a1147b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 300 additions and 43 deletions

View File

@ -0,0 +1,23 @@
package lock
import (
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
"github.com/milvus-io/milvus/pkg/v2/util/lock"
)
// NewInterceptorBuilder creates a new redo interceptor builder.
// TODO: add it into wal after recovery storage is merged.
func NewInterceptorBuilder() interceptors.InterceptorBuilder {
return &interceptorBuilder{}
}
// interceptorBuilder is the builder for redo interceptor.
type interceptorBuilder struct{}
// Build creates a new redo interceptor.
func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) interceptors.Interceptor {
return &lockAppendInterceptor{
vchannelLocker: lock.NewKeyLock[string](),
// TODO: txnManager will be intiailized by param txnManager: param.TxnManager,
}
}

View File

@ -0,0 +1,50 @@
package lock
import (
"context"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/lock"
)
type lockAppendInterceptor struct {
vchannelLocker *lock.KeyLock[string]
txnManager *txn.TxnManager
}
func (r *lockAppendInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (message.MessageID, error) {
g := r.acquireLockGuard(ctx, msg)
defer g()
return append(ctx, msg)
}
// acquireLockGuard acquires the lock for the vchannel and return a function as a guard.
func (r *lockAppendInterceptor) acquireLockGuard(_ context.Context, msg message.MutableMessage) func() {
// Acquire the write lock for the vchannel.
vchannel := msg.VChannel()
if msg.MessageType().IsExclusiveRequired() {
r.vchannelLocker.Lock(vchannel)
return func() {
// For exclusive messages, we need to fail all transactions at the vchannel.
// Otherwise, the transaction message may cross the exclusive message.
// e.g. an exclusive message like `ManualFlush` happens, it will flush all the growing segment.
// But the transaction insert message that use those segments may not be committed,
// if we allow it to be committed, a insert message can be seen after the manual flush message, lead to the wrong wal message order.
// So we need to fail all transactions at the vchannel, it will be retried at client side with new txn.
//
// the append operation of exclusive message should be low rate, so it's acceptable to fail all transactions at the vchannel.
r.txnManager.FailTxnAtVChannel(vchannel)
r.vchannelLocker.Unlock(vchannel)
}
}
r.vchannelLocker.RLock(vchannel)
return func() {
r.vchannelLocker.RUnlock(vchannel)
}
}
// Close the interceptor release all the resources.
func (r *lockAppendInterceptor) Close() {}

View File

@ -21,6 +21,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -112,8 +113,16 @@ func TestSegmentAllocManager(t *testing.T) {
assert.True(t, m.IsNoWaitSeal()) // result2 is acked, so new seal segment will be sealed right away.
// interactive with txn
txnManager := txn.NewTxnManager(types.PChannelInfo{Name: "test"})
txn, err := txnManager.BeginNewTxn(context.Background(), tsoutil.GetCurrentTime(), time.Second)
txnManager := txn.NewTxnManager(types.PChannelInfo{Name: "test"}, nil)
msg := message.NewBeginTxnMessageBuilderV2().
WithVChannel("v1").
WithHeader(&message.BeginTxnMessageHeader{KeepaliveMilliseconds: 1000}).
WithBody(&message.BeginTxnMessageBody{}).
MustBuildMutable().
WithTimeTick(tsoutil.GetCurrentTime())
beginTxnMsg, _ := message.AsMutableBeginTxnMessageV2(msg)
txn, err := txnManager.BeginNewTxn(ctx, beginTxnMsg)
assert.NoError(t, err)
txn.BeginDone()

View File

@ -23,8 +23,10 @@ func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) in
operator := newTimeTickSyncOperator(param)
// initialize operation can be async to avoid block the build operation.
resource.Resource().TimeTickInspector().RegisterSyncOperator(operator)
return &timeTickAppendInterceptor{
operator: operator,
txnManager: txn.NewTxnManager(param.ChannelInfo),
// TODO: it's just a placeholder, should be replaced after recovery storage is merged.
txnManager: txn.NewTxnManager(param.ChannelInfo, nil),
}
}

View File

@ -145,7 +145,7 @@ func (impl *timeTickAppendInterceptor) handleBegin(ctx context.Context, msg mess
return nil, nil, err
}
// Begin transaction will generate a txn context.
session, err := impl.txnManager.BeginNewTxn(ctx, msg.TimeTick(), time.Duration(beginTxnMsg.Header().KeepaliveMilliseconds)*time.Millisecond)
session, err := impl.txnManager.BeginNewTxn(ctx, beginTxnMsg)
if err != nil {
session.BeginRollback()
return nil, nil, err

View File

@ -14,10 +14,30 @@ type txnSessionKeyType int
var txnSessionKeyValue txnSessionKeyType = 1
// newTxnSession creates a new transaction session.
func newTxnSession(
vchannel string,
txnContext message.TxnContext,
timetick uint64,
metricsGuard *metricsutil.TxnMetricsGuard,
) *TxnSession {
return &TxnSession{
mu: sync.Mutex{},
vchannel: vchannel,
lastTimetick: timetick,
txnContext: txnContext,
inFlightCount: 0,
state: message.TxnStateBegin,
doneWait: nil,
rollback: false,
metricsGuard: metricsGuard,
}
}
// TxnSession is a session for a transaction.
type TxnSession struct {
mu sync.Mutex
vchannel string // The vchannel of the session.
lastTimetick uint64 // session last timetick.
expired bool // The flag indicates the transaction has trigger expired once.
txnContext message.TxnContext // transaction id of the session
@ -29,6 +49,11 @@ type TxnSession struct {
metricsGuard *metricsutil.TxnMetricsGuard // The metrics guard for the session.
}
// VChannel returns the vchannel of the session.
func (s *TxnSession) VChannel() string {
return s.vchannel
}
// TxnContext returns the txn context of the session.
func (s *TxnSession) TxnContext() message.TxnContext {
return s.txnContext

View File

@ -9,11 +9,13 @@ import (
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
)
@ -27,8 +29,11 @@ func TestSession(t *testing.T) {
resource.InitForTest(t)
ctx := context.Background()
m := NewTxnManager(types.PChannelInfo{Name: "test"})
session, err := m.BeginNewTxn(ctx, 0, 10*time.Millisecond)
m := NewTxnManager(types.PChannelInfo{Name: "test"}, nil)
<-m.RecoverDone()
session, err := m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond))
assert.Equal(t, session.VChannel(), "v1")
assert.Equal(t, session.State(), message.TxnStateBegin)
assert.NotNil(t, session)
assert.NoError(t, err)
@ -41,7 +46,7 @@ func TestSession(t *testing.T) {
assert.Equal(t, message.TxnStateRollbacked, session.state)
assert.True(t, session.IsExpiredOrDone(0))
session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond)
session, err = m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond))
assert.NoError(t, err)
session.BeginDone()
assert.Equal(t, message.TxnStateInFlight, session.state)
@ -59,7 +64,7 @@ func TestSession(t *testing.T) {
serr = status.AsStreamingError(err)
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code)
session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond)
session, err = m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond))
assert.NoError(t, err)
session.BeginDone()
assert.NoError(t, err)
@ -75,7 +80,7 @@ func TestSession(t *testing.T) {
assert.Equal(t, message.TxnStateCommitted, session.state)
// Test Commit timeout.
session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond)
session, err = m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond))
assert.NoError(t, err)
session.BeginDone()
err = session.AddNewMessage(ctx, 0)
@ -94,7 +99,7 @@ func TestSession(t *testing.T) {
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code)
// Test Rollback
session, _ = m.BeginNewTxn(context.Background(), 0, 10*time.Millisecond)
session, _ = m.BeginNewTxn(context.Background(), newBeginTxnMessage(0, 10*time.Millisecond))
session.BeginDone()
// Rollback expired.
err = session.RequestRollback(context.Background(), expiredTs)
@ -103,7 +108,7 @@ func TestSession(t *testing.T) {
assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code)
// Rollback success
session, _ = m.BeginNewTxn(context.Background(), 0, 10*time.Millisecond)
session, _ = m.BeginNewTxn(context.Background(), newBeginTxnMessage(0, 10*time.Millisecond))
session.BeginDone()
err = session.RequestRollback(context.Background(), 0)
assert.NoError(t, err)
@ -112,7 +117,7 @@ func TestSession(t *testing.T) {
func TestManager(t *testing.T) {
resource.InitForTest(t)
m := NewTxnManager(types.PChannelInfo{Name: "test"})
m := NewTxnManager(types.PChannelInfo{Name: "test"}, nil)
wg := &sync.WaitGroup{}
@ -121,7 +126,7 @@ func TestManager(t *testing.T) {
for i := 0; i < 20; i++ {
go func(i int) {
defer wg.Done()
session, err := m.BeginNewTxn(context.Background(), 0, time.Duration(i+1)*time.Millisecond)
session, err := m.BeginNewTxn(context.Background(), newBeginTxnMessage(0, time.Duration(i+1)*time.Millisecond))
assert.NoError(t, err)
assert.NotNil(t, session)
session.BeginDone()
@ -176,6 +181,48 @@ func TestManager(t *testing.T) {
assert.Equal(t, int32(0), count.Load())
}
func TestManagerRecoverAndFailAll(t *testing.T) {
resource.InitForTest(t)
now := time.Now()
beginMsg1 := newImmutableBeginTxnMessageWithVChannel("v1", 1, tsoutil.ComposeTSByTime(now, 0), 10*time.Millisecond)
beginMsg2 := newImmutableBeginTxnMessageWithVChannel("v2", 2, tsoutil.ComposeTSByTime(now, 1), 10*time.Millisecond)
beginMsg3 := newImmutableBeginTxnMessageWithVChannel("v1", 3, tsoutil.ComposeTSByTime(now, 2), 10*time.Millisecond)
builders := map[message.TxnID]*message.ImmutableTxnMessageBuilder{
message.TxnID(1): message.NewImmutableTxnMessageBuilder(beginMsg1),
message.TxnID(2): message.NewImmutableTxnMessageBuilder(beginMsg2),
message.TxnID(3): message.NewImmutableTxnMessageBuilder(beginMsg3),
}
builders[message.TxnID(1)].Add(message.NewInsertMessageBuilderV1().
WithVChannel("v1").
WithHeader(&message.InsertMessageHeader{}).
WithBody(&msgpb.InsertRequest{}).
MustBuildMutable().
WithTimeTick(tsoutil.ComposeTSByTime(now, 3)).
WithLastConfirmedUseMessageID().
WithTxnContext(message.TxnContext{
TxnID: message.TxnID(1),
Keepalive: 10 * time.Millisecond,
}).
IntoImmutableMessage(rmq.NewRmqID(1)))
m := NewTxnManager(types.PChannelInfo{Name: "test"}, builders)
select {
case <-m.RecoverDone():
t.Errorf("txn manager should not be recovered")
case <-time.After(1 * time.Millisecond):
}
m.FailTxnAtVChannel("v1")
select {
case <-m.RecoverDone():
t.Errorf("txn manager should not be recovered")
case <-time.After(1 * time.Millisecond):
}
m.FailTxnAtVChannel("v2")
<-m.RecoverDone()
}
func TestWithContext(t *testing.T) {
session := &TxnSession{}
ctx := WithTxnSession(context.Background(), session)
@ -183,3 +230,38 @@ func TestWithContext(t *testing.T) {
session = GetTxnSessionFromContext(ctx)
assert.NotNil(t, session)
}
func newBeginTxnMessage(timetick uint64, keepalive time.Duration) message.MutableBeginTxnMessageV2 {
return newBeginTxnMessageWithVChannel("v1", timetick, keepalive)
}
func newBeginTxnMessageWithVChannel(vchannel string, timetick uint64, keepalive time.Duration) message.MutableBeginTxnMessageV2 {
msg := message.NewBeginTxnMessageBuilderV2().
WithVChannel(vchannel).
WithHeader(&message.BeginTxnMessageHeader{KeepaliveMilliseconds: keepalive.Milliseconds()}).
WithBody(&message.BeginTxnMessageBody{}).
MustBuildMutable().
WithTimeTick(timetick)
beginTxnMsg, _ := message.AsMutableBeginTxnMessageV2(msg)
return beginTxnMsg
}
func newImmutableBeginTxnMessageWithVChannel(vchannel string, txnID int64, timetick uint64, keepalive time.Duration) message.ImmutableBeginTxnMessageV2 {
msg := message.NewBeginTxnMessageBuilderV2().
WithVChannel(vchannel).
WithHeader(&message.BeginTxnMessageHeader{
KeepaliveMilliseconds: keepalive.Milliseconds(),
}).
WithBody(&message.BeginTxnMessageBody{}).
MustBuildMutable().
WithTimeTick(timetick).
WithLastConfirmed(rmq.NewRmqID(1)).
WithLastConfirmedUseMessageID().
WithTxnContext(message.TxnContext{
TxnID: message.TxnID(txnID),
Keepalive: keepalive,
}).
IntoImmutableMessage(rmq.NewRmqID(1))
return message.MustAsImmutableBeginTxnMessageV2(msg)
}

View File

@ -18,31 +18,69 @@ import (
)
// NewTxnManager creates a new transaction manager.
func NewTxnManager(pchannel types.PChannelInfo) *TxnManager {
return &TxnManager{
mu: sync.Mutex{},
sessions: make(map[message.TxnID]*TxnSession),
closed: nil,
metrics: metricsutil.NewTxnMetrics(pchannel.Name),
logger: resource.Resource().Logger().With(log.FieldComponent("txn-manager")),
// incoming buffer is used to recover the uncommitted messages for txn manager.
func NewTxnManager(pchannel types.PChannelInfo, uncommittedTxnBuilders map[message.TxnID]*message.ImmutableTxnMessageBuilder) *TxnManager {
m := metricsutil.NewTxnMetrics(pchannel.Name)
sessions := make(map[message.TxnID]*TxnSession, len(uncommittedTxnBuilders))
recoveredSessions := make(map[message.TxnID]struct{}, len(uncommittedTxnBuilders))
sessionIDs := make([]int64, 0, len(uncommittedTxnBuilders))
for _, builder := range uncommittedTxnBuilders {
beginMessages, body := builder.Messages()
session := newTxnSession(
beginMessages.VChannel(),
*beginMessages.TxnContext(), // must be the txn message.
beginMessages.TimeTick(),
m.BeginTxn(),
)
for _, msg := range body {
session.AddNewMessage(context.Background(), msg.TimeTick())
session.AddNewMessageDoneAndKeepalive(msg.TimeTick())
}
sessions[session.TxnContext().TxnID] = session
recoveredSessions[session.TxnContext().TxnID] = struct{}{}
sessionIDs = append(sessionIDs, int64(session.TxnContext().TxnID))
}
txnManager := &TxnManager{
mu: sync.Mutex{},
recoveredSessions: recoveredSessions,
recoveredSessionsDoneChan: make(chan struct{}),
sessions: sessions,
closed: nil,
metrics: m,
}
txnManager.notifyRecoverDone()
txnManager.SetLogger(resource.Resource().Logger().With(log.FieldComponent("txn-manager")))
txnManager.Logger().Info("txn manager recovered with txn", zap.Int64s("txnIDs", sessionIDs))
return txnManager
}
// TxnManager is the manager of transactions.
// We don't support cross wal transaction by now and
// We don't support the transaction lives after the wal transferred to another streaming node.
type TxnManager struct {
log.Binder
mu sync.Mutex
recoveredSessions map[message.TxnID]struct{}
recoveredSessionsDoneChan chan struct{}
sessions map[message.TxnID]*TxnSession
closed lifetime.SafeChan
metrics *metricsutil.TxnMetrics
logger *log.MLogger
}
// RecoverDone returns a channel that is closed when all transactions are cleaned up.
func (m *TxnManager) RecoverDone() <-chan struct{} {
return m.recoveredSessionsDoneChan
}
// BeginNewTxn starts a new transaction with a session.
// We only support a transaction work on a streaming node, once the wal is transferred to another node,
// the transaction is treated as expired (rollback), and user will got a expired error, then perform a retry.
func (m *TxnManager) BeginNewTxn(ctx context.Context, timetick uint64, keepalive time.Duration) (*TxnSession, error) {
func (m *TxnManager) BeginNewTxn(ctx context.Context, msg message.MutableBeginTxnMessageV2) (*TxnSession, error) {
timetick := msg.TimeTick()
vchannel := msg.VChannel()
keepalive := time.Duration(msg.Header().KeepaliveMilliseconds) * time.Millisecond
if keepalive == 0 {
// If keepalive is 0, the txn set the keepalive with default keepalive.
keepalive = paramtable.Get().StreamingCfg.TxnDefaultKeepaliveTimeout.GetAsDurationByParse()
@ -62,24 +100,35 @@ func (m *TxnManager) BeginNewTxn(ctx context.Context, timetick uint64, keepalive
if m.closed != nil {
return nil, status.NewTransactionExpired("manager closed")
}
metricsGuard := m.metrics.BeginTxn()
session := &TxnSession{
mu: sync.Mutex{},
lastTimetick: timetick,
txnContext: message.TxnContext{
txnCtx := message.TxnContext{
TxnID: message.TxnID(id),
Keepalive: keepalive,
},
inFlightCount: 0,
state: message.TxnStateBegin,
doneWait: nil,
rollback: false,
metricsGuard: metricsGuard,
}
session := newTxnSession(vchannel, txnCtx, timetick, m.metrics.BeginTxn())
m.sessions[session.TxnContext().TxnID] = session
return session, nil
}
// FailTxnAtVChannel fails all transactions at the specified vchannel.
func (m *TxnManager) FailTxnAtVChannel(vchannel string) {
// avoid the txn to be committed.
m.mu.Lock()
defer m.mu.Unlock()
ids := make([]int64, 0, len(m.sessions))
for id, session := range m.sessions {
if session.VChannel() == vchannel {
session.Cleanup()
delete(m.sessions, id)
delete(m.recoveredSessions, id)
ids = append(ids, int64(id))
}
}
if len(ids) > 0 {
m.Logger().Info("transaction interrupted", zap.String("vchannel", vchannel), zap.Int64s("txnIDs", ids))
}
m.notifyRecoverDone()
}
// CleanupTxnUntil cleans up the transactions until the specified timestamp.
func (m *TxnManager) CleanupTxnUntil(ts uint64) {
m.mu.Lock()
@ -89,6 +138,7 @@ func (m *TxnManager) CleanupTxnUntil(ts uint64) {
if session.IsExpiredOrDone(ts) {
session.Cleanup()
delete(m.sessions, id)
delete(m.recoveredSessions, id)
}
}
@ -96,6 +146,16 @@ func (m *TxnManager) CleanupTxnUntil(ts uint64) {
if len(m.sessions) == 0 && m.closed != nil {
m.closed.Close()
}
m.notifyRecoverDone()
}
// notifyRecoverDone notifies the recover done channel if all transactions from recover info is done.
func (m *TxnManager) notifyRecoverDone() {
if len(m.recoveredSessions) == 0 && m.recoveredSessions != nil {
close(m.recoveredSessionsDoneChan)
m.recoveredSessions = nil
}
}
// GetSessionOfTxn returns the session of the transaction.
@ -121,7 +181,7 @@ func (m *TxnManager) GracefulClose(ctx context.Context) error {
m.closed.Close()
}
}
m.logger.Info("there's still txn session in txn manager, waiting for them to be consumed", zap.Int("session count", len(m.sessions)))
m.Logger().Info("graceful close txn manager", zap.Int("activeTxnCount", len(m.sessions)))
m.mu.Unlock()
select {

View File

@ -29,6 +29,11 @@ func (b *TxnBuffer) Bytes() int {
return b.bytes
}
// GetUncommittedMessageBuilder returns the uncommitted message builders.
func (b *TxnBuffer) GetUncommittedMessageBuilder() map[message.TxnID]*message.ImmutableTxnMessageBuilder {
return b.builders
}
// HandleImmutableMessages handles immutable messages.
// The timetick of msgs should be in ascending order, and the timetick of all messages is less than or equal to ts.
// Hold the uncommitted txn messages until the commit or rollback message comes and pop the committed txn messages.

View File

@ -72,6 +72,7 @@ func TestTxnBuffer(t *testing.T) {
assert.Len(t, msgs, 1)
}
createUnCommitted()
assert.Len(t, b.GetUncommittedMessageBuilder(), 1)
msgs = b.HandleImmutableMessages([]message.ImmutableMessage{
newCommitMessage(t, txnCtx, tsoutil.AddPhysicalDurationOnTs(baseTso, 500*time.Millisecond)),
}, tsoutil.AddPhysicalDurationOnTs(baseTso, 600*time.Millisecond))