enhance: add lock interceptor and recoverable txn manager (#41640)
issue: #41544

- add a lock interceptor at vchannel granularity.
- make txn manager recoverable and add FailTxnAtVChannel operation.

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
parent
44c0799331
commit
3dd9a1147b
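At the API level, the commit changes TxnManager in two ways: the constructor now takes an optional map of uncommitted txn builders to recover from, and BeginNewTxn consumes the begin-txn message itself (vchannel, time tick, and keepalive are read from the message) instead of separate arguments. A condensed, test-style sketch of the resulting call pattern, assembled from the tests in this diff (the function name is hypothetical; everything it calls appears in the hunks below):

package txn

import (
    "context"
    "testing"

    "github.com/milvus-io/milvus/internal/streamingnode/server/resource"
    "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
    "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
    "github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
)

func TestBeginNewTxnFlow(t *testing.T) {
    resource.InitForTest(t)
    ctx := context.Background()

    // nil builders: nothing to recover, so RecoverDone is closed immediately.
    m := NewTxnManager(types.PChannelInfo{Name: "test"}, nil)
    <-m.RecoverDone()

    // BeginNewTxn now takes the begin-txn message; vchannel, time tick and
    // keepalive are all carried by the message.
    msg := message.NewBeginTxnMessageBuilderV2().
        WithVChannel("v1").
        WithHeader(&message.BeginTxnMessageHeader{KeepaliveMilliseconds: 1000}).
        WithBody(&message.BeginTxnMessageBody{}).
        MustBuildMutable().
        WithTimeTick(tsoutil.GetCurrentTime())
    beginTxnMsg, _ := message.AsMutableBeginTxnMessageV2(msg)

    session, err := m.BeginNewTxn(ctx, beginTxnMsg)
    if err != nil {
        t.Fatal(err)
    }
    session.BeginDone()

    // An exclusive append on "v1" (see the lock interceptor below) interrupts
    // every in-flight txn on that vchannel; clients retry with a new txn.
    m.FailTxnAtVChannel("v1")
}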
@@ -0,0 +1,23 @@
+package lock
+
+import (
+    "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
+
+    "github.com/milvus-io/milvus/pkg/v2/util/lock"
+)
+
+// NewInterceptorBuilder creates a new lock interceptor builder.
+// TODO: add it into wal after recovery storage is merged.
+func NewInterceptorBuilder() interceptors.InterceptorBuilder {
+    return &interceptorBuilder{}
+}
+
+// interceptorBuilder is the builder for the lock interceptor.
+type interceptorBuilder struct{}
+
+// Build creates a new lock interceptor.
+func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) interceptors.Interceptor {
+    return &lockAppendInterceptor{
+        vchannelLocker: lock.NewKeyLock[string](),
+        // TODO: txnManager will be initialized from param: txnManager: param.TxnManager.
+    }
+}
@@ -0,0 +1,50 @@
+package lock
+
+import (
+    "context"
+
+    "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
+    "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
+    "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
+    "github.com/milvus-io/milvus/pkg/v2/util/lock"
+)
+
+// lockAppendInterceptor serializes appends at vchannel granularity: ordinary
+// messages share a read lock, exclusive messages take the write lock.
+type lockAppendInterceptor struct {
+    vchannelLocker *lock.KeyLock[string]
+    txnManager     *txn.TxnManager
+}
+
+// DoAppend performs the append while holding the vchannel lock.
+func (r *lockAppendInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (message.MessageID, error) {
+    g := r.acquireLockGuard(ctx, msg)
+    defer g()
+
+    return append(ctx, msg)
+}
+
+// acquireLockGuard acquires the lock for the vchannel and returns a function as a guard.
+func (r *lockAppendInterceptor) acquireLockGuard(_ context.Context, msg message.MutableMessage) func() {
+    // Acquire the write lock for the vchannel.
+    vchannel := msg.VChannel()
+    if msg.MessageType().IsExclusiveRequired() {
+        r.vchannelLocker.Lock(vchannel)
+        return func() {
+            // For exclusive messages, we need to fail all transactions at the vchannel.
+            // Otherwise, a transaction message may cross the exclusive message.
+            // e.g. when an exclusive message like `ManualFlush` happens, it flushes all the growing segments,
+            // but the transaction insert messages that use those segments may not be committed yet.
+            // If we allowed them to commit, an insert message could be seen after the manual flush message,
+            // leading to a wrong wal message order.
+            // So we fail all transactions at the vchannel; they will be retried at the client side with a new txn.
+            //
+            // Appends of exclusive messages should be low rate, so it's acceptable to fail all transactions at the vchannel.
+            r.txnManager.FailTxnAtVChannel(vchannel)
+            r.vchannelLocker.Unlock(vchannel)
+        }
+    }
+    r.vchannelLocker.RLock(vchannel)
+    return func() {
+        r.vchannelLocker.RUnlock(vchannel)
+    }
+}
+
+// Close releases all the resources held by the interceptor.
+func (r *lockAppendInterceptor) Close() {}
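The guard logic above depends only on four KeyLock methods that are visible in this diff: RLock/RUnlock for ordinary appends and Lock/Unlock for exclusive ones, all keyed by vchannel. A standalone sketch of that locking behavior (illustrative only, not part of the commit; the main package and print statements are mine):

package main

import (
    "fmt"
    "sync"

    "github.com/milvus-io/milvus/pkg/v2/util/lock"
)

func main() {
    locker := lock.NewKeyLock[string]()
    var wg sync.WaitGroup

    // Ordinary appends on the same vchannel hold the read lock and may run concurrently.
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            locker.RLock("v1")
            defer locker.RUnlock("v1")
            fmt.Println("append", i, "on v1")
        }(i)
    }
    wg.Wait()

    // An exclusive message (e.g. ManualFlush) takes the write lock on "v1",
    // so it waits for in-flight appends and blocks new ones on that vchannel.
    locker.Lock("v1")
    fmt.Println("exclusive message on v1")
    locker.Unlock("v1")

    // Appends on a different vchannel use a different key and are unaffected.
    locker.RLock("v2")
    locker.RUnlock("v2")
}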
@@ -21,6 +21,7 @@ import (
     "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
     "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
     "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
+    "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
     "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
     "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
     "github.com/milvus-io/milvus/pkg/v2/util/merr"
@@ -112,8 +113,16 @@ func TestSegmentAllocManager(t *testing.T) {
     assert.True(t, m.IsNoWaitSeal()) // result2 is acked, so new seal segment will be sealed right away.

     // interactive with txn
-    txnManager := txn.NewTxnManager(types.PChannelInfo{Name: "test"})
-    txn, err := txnManager.BeginNewTxn(context.Background(), tsoutil.GetCurrentTime(), time.Second)
+    txnManager := txn.NewTxnManager(types.PChannelInfo{Name: "test"}, nil)
+    msg := message.NewBeginTxnMessageBuilderV2().
+        WithVChannel("v1").
+        WithHeader(&message.BeginTxnMessageHeader{KeepaliveMilliseconds: 1000}).
+        WithBody(&message.BeginTxnMessageBody{}).
+        MustBuildMutable().
+        WithTimeTick(tsoutil.GetCurrentTime())
+
+    beginTxnMsg, _ := message.AsMutableBeginTxnMessageV2(msg)
+    txn, err := txnManager.BeginNewTxn(ctx, beginTxnMsg)
     assert.NoError(t, err)
     txn.BeginDone()
@@ -23,8 +23,10 @@ func (b *interceptorBuilder) Build(param *interceptors.InterceptorBuildParam) in
     operator := newTimeTickSyncOperator(param)
     // initialize operation can be async to avoid block the build operation.
     resource.Resource().TimeTickInspector().RegisterSyncOperator(operator)
     return &timeTickAppendInterceptor{
         operator:   operator,
-        txnManager: txn.NewTxnManager(param.ChannelInfo),
+        // TODO: it's just a placeholder, should be replaced after recovery storage is merged.
+        txnManager: txn.NewTxnManager(param.ChannelInfo, nil),
     }
 }
@@ -145,7 +145,7 @@ func (impl *timeTickAppendInterceptor) handleBegin(ctx context.Context, msg mess
         return nil, nil, err
     }
     // Begin transaction will generate a txn context.
-    session, err := impl.txnManager.BeginNewTxn(ctx, msg.TimeTick(), time.Duration(beginTxnMsg.Header().KeepaliveMilliseconds)*time.Millisecond)
+    session, err := impl.txnManager.BeginNewTxn(ctx, beginTxnMsg)
     if err != nil {
         session.BeginRollback()
         return nil, nil, err
@@ -14,10 +14,30 @@ type txnSessionKeyType int

 var txnSessionKeyValue txnSessionKeyType = 1

+// newTxnSession creates a new transaction session.
+func newTxnSession(
+    vchannel string,
+    txnContext message.TxnContext,
+    timetick uint64,
+    metricsGuard *metricsutil.TxnMetricsGuard,
+) *TxnSession {
+    return &TxnSession{
+        mu:            sync.Mutex{},
+        vchannel:      vchannel,
+        lastTimetick:  timetick,
+        txnContext:    txnContext,
+        inFlightCount: 0,
+        state:         message.TxnStateBegin,
+        doneWait:      nil,
+        rollback:      false,
+        metricsGuard:  metricsGuard,
+    }
+}
+
 // TxnSession is a session for a transaction.
 type TxnSession struct {
     mu sync.Mutex
+    vchannel     string             // The vchannel of the session.
     lastTimetick uint64             // session last timetick.
     expired      bool               // The flag indicates the transaction has trigger expired once.
     txnContext   message.TxnContext // transaction id of the session
@@ -29,6 +49,11 @@ type TxnSession struct {
     metricsGuard *metricsutil.TxnMetricsGuard // The metrics guard for the session.
 }

+// VChannel returns the vchannel of the session.
+func (s *TxnSession) VChannel() string {
+    return s.vchannel
+}
+
 // TxnContext returns the txn context of the session.
 func (s *TxnSession) TxnContext() message.TxnContext {
     return s.txnContext
@@ -9,11 +9,13 @@ import (
     "github.com/stretchr/testify/assert"
     "go.uber.org/atomic"

+    "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
     "github.com/milvus-io/milvus/internal/streamingnode/server/resource"
     "github.com/milvus-io/milvus/internal/util/streamingutil/status"
     "github.com/milvus-io/milvus/pkg/v2/proto/streamingpb"
     "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
     "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
+    "github.com/milvus-io/milvus/pkg/v2/streaming/walimpls/impls/rmq"
     "github.com/milvus-io/milvus/pkg/v2/util/paramtable"
     "github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
 )
@@ -27,8 +29,11 @@ func TestSession(t *testing.T) {
     resource.InitForTest(t)
     ctx := context.Background()

-    m := NewTxnManager(types.PChannelInfo{Name: "test"})
-    session, err := m.BeginNewTxn(ctx, 0, 10*time.Millisecond)
+    m := NewTxnManager(types.PChannelInfo{Name: "test"}, nil)
+    <-m.RecoverDone()
+    session, err := m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond))
+    assert.Equal(t, session.VChannel(), "v1")
+    assert.Equal(t, session.State(), message.TxnStateBegin)
     assert.NotNil(t, session)
     assert.NoError(t, err)
@@ -41,7 +46,7 @@ func TestSession(t *testing.T) {
     assert.Equal(t, message.TxnStateRollbacked, session.state)
     assert.True(t, session.IsExpiredOrDone(0))

-    session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond)
+    session, err = m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond))
     assert.NoError(t, err)
     session.BeginDone()
     assert.Equal(t, message.TxnStateInFlight, session.state)
@@ -59,7 +64,7 @@ func TestSession(t *testing.T) {
     serr = status.AsStreamingError(err)
     assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code)

-    session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond)
+    session, err = m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond))
     assert.NoError(t, err)
     session.BeginDone()
     assert.NoError(t, err)
@@ -75,7 +80,7 @@ func TestSession(t *testing.T) {
     assert.Equal(t, message.TxnStateCommitted, session.state)

     // Test Commit timeout.
-    session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond)
+    session, err = m.BeginNewTxn(ctx, newBeginTxnMessage(0, 10*time.Millisecond))
     assert.NoError(t, err)
     session.BeginDone()
     err = session.AddNewMessage(ctx, 0)
@@ -94,7 +99,7 @@ func TestSession(t *testing.T) {
     assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code)

     // Test Rollback
-    session, _ = m.BeginNewTxn(context.Background(), 0, 10*time.Millisecond)
+    session, _ = m.BeginNewTxn(context.Background(), newBeginTxnMessage(0, 10*time.Millisecond))
     session.BeginDone()
     // Rollback expired.
     err = session.RequestRollback(context.Background(), expiredTs)
@@ -103,7 +108,7 @@ func TestSession(t *testing.T) {
     assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code)

     // Rollback success
-    session, _ = m.BeginNewTxn(context.Background(), 0, 10*time.Millisecond)
+    session, _ = m.BeginNewTxn(context.Background(), newBeginTxnMessage(0, 10*time.Millisecond))
     session.BeginDone()
     err = session.RequestRollback(context.Background(), 0)
     assert.NoError(t, err)
@@ -112,7 +117,7 @@

 func TestManager(t *testing.T) {
     resource.InitForTest(t)
-    m := NewTxnManager(types.PChannelInfo{Name: "test"})
+    m := NewTxnManager(types.PChannelInfo{Name: "test"}, nil)

     wg := &sync.WaitGroup{}

@@ -121,7 +126,7 @@ func TestManager(t *testing.T) {
     for i := 0; i < 20; i++ {
         go func(i int) {
             defer wg.Done()
-            session, err := m.BeginNewTxn(context.Background(), 0, time.Duration(i+1)*time.Millisecond)
+            session, err := m.BeginNewTxn(context.Background(), newBeginTxnMessage(0, time.Duration(i+1)*time.Millisecond))
             assert.NoError(t, err)
             assert.NotNil(t, session)
             session.BeginDone()
@@ -176,6 +181,48 @@ func TestManager(t *testing.T) {
     assert.Equal(t, int32(0), count.Load())
 }

+func TestManagerRecoverAndFailAll(t *testing.T) {
+    resource.InitForTest(t)
+    now := time.Now()
+    beginMsg1 := newImmutableBeginTxnMessageWithVChannel("v1", 1, tsoutil.ComposeTSByTime(now, 0), 10*time.Millisecond)
+    beginMsg2 := newImmutableBeginTxnMessageWithVChannel("v2", 2, tsoutil.ComposeTSByTime(now, 1), 10*time.Millisecond)
+    beginMsg3 := newImmutableBeginTxnMessageWithVChannel("v1", 3, tsoutil.ComposeTSByTime(now, 2), 10*time.Millisecond)
+    builders := map[message.TxnID]*message.ImmutableTxnMessageBuilder{
+        message.TxnID(1): message.NewImmutableTxnMessageBuilder(beginMsg1),
+        message.TxnID(2): message.NewImmutableTxnMessageBuilder(beginMsg2),
+        message.TxnID(3): message.NewImmutableTxnMessageBuilder(beginMsg3),
+    }
+
+    builders[message.TxnID(1)].Add(message.NewInsertMessageBuilderV1().
+        WithVChannel("v1").
+        WithHeader(&message.InsertMessageHeader{}).
+        WithBody(&msgpb.InsertRequest{}).
+        MustBuildMutable().
+        WithTimeTick(tsoutil.ComposeTSByTime(now, 3)).
+        WithLastConfirmedUseMessageID().
+        WithTxnContext(message.TxnContext{
+            TxnID:     message.TxnID(1),
+            Keepalive: 10 * time.Millisecond,
+        }).
+        IntoImmutableMessage(rmq.NewRmqID(1)))
+
+    m := NewTxnManager(types.PChannelInfo{Name: "test"}, builders)
+    select {
+    case <-m.RecoverDone():
+        t.Errorf("txn manager should not be recovered")
+    case <-time.After(1 * time.Millisecond):
+    }
+
+    m.FailTxnAtVChannel("v1")
+    select {
+    case <-m.RecoverDone():
+        t.Errorf("txn manager should not be recovered")
+    case <-time.After(1 * time.Millisecond):
+    }
+    m.FailTxnAtVChannel("v2")
+    <-m.RecoverDone()
+}
+
 func TestWithContext(t *testing.T) {
     session := &TxnSession{}
     ctx := WithTxnSession(context.Background(), session)
@@ -183,3 +230,38 @@ func TestWithContext(t *testing.T) {
     session = GetTxnSessionFromContext(ctx)
     assert.NotNil(t, session)
 }
+
+func newBeginTxnMessage(timetick uint64, keepalive time.Duration) message.MutableBeginTxnMessageV2 {
+    return newBeginTxnMessageWithVChannel("v1", timetick, keepalive)
+}
+
+func newBeginTxnMessageWithVChannel(vchannel string, timetick uint64, keepalive time.Duration) message.MutableBeginTxnMessageV2 {
+    msg := message.NewBeginTxnMessageBuilderV2().
+        WithVChannel(vchannel).
+        WithHeader(&message.BeginTxnMessageHeader{KeepaliveMilliseconds: keepalive.Milliseconds()}).
+        WithBody(&message.BeginTxnMessageBody{}).
+        MustBuildMutable().
+        WithTimeTick(timetick)
+
+    beginTxnMsg, _ := message.AsMutableBeginTxnMessageV2(msg)
+    return beginTxnMsg
+}
+
+func newImmutableBeginTxnMessageWithVChannel(vchannel string, txnID int64, timetick uint64, keepalive time.Duration) message.ImmutableBeginTxnMessageV2 {
+    msg := message.NewBeginTxnMessageBuilderV2().
+        WithVChannel(vchannel).
+        WithHeader(&message.BeginTxnMessageHeader{
+            KeepaliveMilliseconds: keepalive.Milliseconds(),
+        }).
+        WithBody(&message.BeginTxnMessageBody{}).
+        MustBuildMutable().
+        WithTimeTick(timetick).
+        WithLastConfirmed(rmq.NewRmqID(1)).
+        WithLastConfirmedUseMessageID().
+        WithTxnContext(message.TxnContext{
+            TxnID:     message.TxnID(txnID),
+            Keepalive: keepalive,
+        }).
+        IntoImmutableMessage(rmq.NewRmqID(1))
+    return message.MustAsImmutableBeginTxnMessageV2(msg)
+}
@@ -18,31 +18,69 @@ import (
 )

 // NewTxnManager creates a new transaction manager.
-func NewTxnManager(pchannel types.PChannelInfo) *TxnManager {
-    return &TxnManager{
-        mu:       sync.Mutex{},
-        sessions: make(map[message.TxnID]*TxnSession),
-        closed:   nil,
-        metrics:  metricsutil.NewTxnMetrics(pchannel.Name),
-        logger:   resource.Resource().Logger().With(log.FieldComponent("txn-manager")),
-    }
-}
+// The incoming buffer is used to recover the uncommitted messages for the txn manager.
+func NewTxnManager(pchannel types.PChannelInfo, uncommittedTxnBuilders map[message.TxnID]*message.ImmutableTxnMessageBuilder) *TxnManager {
+    m := metricsutil.NewTxnMetrics(pchannel.Name)
+    sessions := make(map[message.TxnID]*TxnSession, len(uncommittedTxnBuilders))
+    recoveredSessions := make(map[message.TxnID]struct{}, len(uncommittedTxnBuilders))
+    sessionIDs := make([]int64, 0, len(uncommittedTxnBuilders))
+    for _, builder := range uncommittedTxnBuilders {
+        beginMessages, body := builder.Messages()
+        session := newTxnSession(
+            beginMessages.VChannel(),
+            *beginMessages.TxnContext(), // must be the txn message.
+            beginMessages.TimeTick(),
+            m.BeginTxn(),
+        )
+        for _, msg := range body {
+            session.AddNewMessage(context.Background(), msg.TimeTick())
+            session.AddNewMessageDoneAndKeepalive(msg.TimeTick())
+        }
+        sessions[session.TxnContext().TxnID] = session
+        recoveredSessions[session.TxnContext().TxnID] = struct{}{}
+        sessionIDs = append(sessionIDs, int64(session.TxnContext().TxnID))
+    }
+    txnManager := &TxnManager{
+        mu:                        sync.Mutex{},
+        recoveredSessions:         recoveredSessions,
+        recoveredSessionsDoneChan: make(chan struct{}),
+        sessions:                  sessions,
+        closed:                    nil,
+        metrics:                   m,
+    }
+    txnManager.notifyRecoverDone()
+    txnManager.SetLogger(resource.Resource().Logger().With(log.FieldComponent("txn-manager")))
+    txnManager.Logger().Info("txn manager recovered with txn", zap.Int64s("txnIDs", sessionIDs))
+    return txnManager
+}

 // TxnManager is the manager of transactions.
 // We don't support cross wal transaction by now and
 // We don't support the transaction lives after the wal transferred to another streaming node.
 type TxnManager struct {
+    log.Binder
+
     mu sync.Mutex
+    recoveredSessions         map[message.TxnID]struct{}
+    recoveredSessionsDoneChan chan struct{}
     sessions map[message.TxnID]*TxnSession
     closed   lifetime.SafeChan
     metrics  *metricsutil.TxnMetrics
-    logger   *log.MLogger
 }

+// RecoverDone returns a channel that is closed when all recovered transactions are cleaned up.
+func (m *TxnManager) RecoverDone() <-chan struct{} {
+    return m.recoveredSessionsDoneChan
+}
+
 // BeginNewTxn starts a new transaction with a session.
 // We only support a transaction work on a streaming node, once the wal is transferred to another node,
 // the transaction is treated as expired (rollback), and user will got a expired error, then perform a retry.
-func (m *TxnManager) BeginNewTxn(ctx context.Context, timetick uint64, keepalive time.Duration) (*TxnSession, error) {
+func (m *TxnManager) BeginNewTxn(ctx context.Context, msg message.MutableBeginTxnMessageV2) (*TxnSession, error) {
+    timetick := msg.TimeTick()
+    vchannel := msg.VChannel()
+    keepalive := time.Duration(msg.Header().KeepaliveMilliseconds) * time.Millisecond
+
     if keepalive == 0 {
         // If keepalive is 0, the txn set the keepalive with default keepalive.
         keepalive = paramtable.Get().StreamingCfg.TxnDefaultKeepaliveTimeout.GetAsDurationByParse()
@@ -62,24 +100,35 @@ func (m *TxnManager) BeginNewTxn(ctx context.Context, timetick uint64, keepalive
     if m.closed != nil {
         return nil, status.NewTransactionExpired("manager closed")
     }
-    metricsGuard := m.metrics.BeginTxn()
-    session := &TxnSession{
-        mu:           sync.Mutex{},
-        lastTimetick: timetick,
-        txnContext: message.TxnContext{
-            TxnID:     message.TxnID(id),
-            Keepalive: keepalive,
-        },
-        inFlightCount: 0,
-        state:         message.TxnStateBegin,
-        doneWait:      nil,
-        rollback:      false,
-        metricsGuard:  metricsGuard,
-    }
+    txnCtx := message.TxnContext{
+        TxnID:     message.TxnID(id),
+        Keepalive: keepalive,
+    }
+    session := newTxnSession(vchannel, txnCtx, timetick, m.metrics.BeginTxn())
     m.sessions[session.TxnContext().TxnID] = session
     return session, nil
 }
+
+// FailTxnAtVChannel fails all transactions at the specified vchannel.
+func (m *TxnManager) FailTxnAtVChannel(vchannel string) {
+    // avoid the txn to be committed.
+    m.mu.Lock()
+    defer m.mu.Unlock()
+    ids := make([]int64, 0, len(m.sessions))
+    for id, session := range m.sessions {
+        if session.VChannel() == vchannel {
+            session.Cleanup()
+            delete(m.sessions, id)
+            delete(m.recoveredSessions, id)
+            ids = append(ids, int64(id))
+        }
+    }
+    if len(ids) > 0 {
+        m.Logger().Info("transaction interrupted", zap.String("vchannel", vchannel), zap.Int64s("txnIDs", ids))
+    }
+    m.notifyRecoverDone()
+}

 // CleanupTxnUntil cleans up the transactions until the specified timestamp.
 func (m *TxnManager) CleanupTxnUntil(ts uint64) {
     m.mu.Lock()
@@ -89,6 +138,7 @@ func (m *TxnManager) CleanupTxnUntil(ts uint64) {
         if session.IsExpiredOrDone(ts) {
             session.Cleanup()
             delete(m.sessions, id)
+            delete(m.recoveredSessions, id)
         }
     }

@@ -96,6 +146,16 @@ func (m *TxnManager) CleanupTxnUntil(ts uint64) {
     if len(m.sessions) == 0 && m.closed != nil {
         m.closed.Close()
     }
+
+    m.notifyRecoverDone()
+}
+
+// notifyRecoverDone notifies the recover done channel once all transactions from the recover info are done.
+func (m *TxnManager) notifyRecoverDone() {
+    if len(m.recoveredSessions) == 0 && m.recoveredSessions != nil {
+        close(m.recoveredSessionsDoneChan)
+        m.recoveredSessions = nil
+    }
 }

 // GetSessionOfTxn returns the session of the transaction.
@@ -121,7 +181,7 @@ func (m *TxnManager) GracefulClose(ctx context.Context) error {
             m.closed.Close()
         }
     }
-    m.logger.Info("there's still txn session in txn manager, waiting for them to be consumed", zap.Int("session count", len(m.sessions)))
+    m.Logger().Info("graceful close txn manager", zap.Int("activeTxnCount", len(m.sessions)))
     m.mu.Unlock()

     select {
@@ -29,6 +29,11 @@ func (b *TxnBuffer) Bytes() int {
     return b.bytes
 }

+// GetUncommittedMessageBuilder returns the uncommitted message builders.
+func (b *TxnBuffer) GetUncommittedMessageBuilder() map[message.TxnID]*message.ImmutableTxnMessageBuilder {
+    return b.builders
+}
+
 // HandleImmutableMessages handles immutable messages.
 // The timetick of msgs should be in ascending order, and the timetick of all messages is less than or equal to ts.
 // Hold the uncommitted txn messages until the commit or rollback message comes and pop the committed txn messages.
@@ -72,6 +72,7 @@ func TestTxnBuffer(t *testing.T) {
         assert.Len(t, msgs, 1)
     }
     createUnCommitted()
+    assert.Len(t, b.GetUncommittedMessageBuilder(), 1)
     msgs = b.HandleImmutableMessages([]message.ImmutableMessage{
         newCommitMessage(t, txnCtx, tsoutil.AddPhysicalDurationOnTs(baseTso, 500*time.Millisecond)),
     }, tsoutil.AddPhysicalDurationOnTs(baseTso, 600*time.Millisecond))
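Taken together with the TODOs earlier in this diff ("add it into wal after recovery storage is merged", "it's just a placeholder"), the intended wiring appears to be: the recovery path replays the wal into a TxnBuffer, hands its uncommitted builders to NewTxnManager, and the recovered transactions are drained before RecoverDone closes. A hedged sketch of that glue, which is not part of this commit; only GetUncommittedMessageBuilder, NewTxnManager, RecoverDone, CleanupTxnUntil, and FailTxnAtVChannel are real APIs from the diff, while the package, helper names, and the source interface are hypothetical:

package recoverysketch

import (
    "time"

    "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
    "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
    "github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)

// uncommittedSource stands in for whatever the recovery storage hands over;
// the TxnBuffer extended in this diff satisfies it via GetUncommittedMessageBuilder.
type uncommittedSource interface {
    GetUncommittedMessageBuilder() map[message.TxnID]*message.ImmutableTxnMessageBuilder
}

// recoverTxnManager rebuilds txn sessions from the uncommitted builders.
// The channel returned by RecoverDone is closed only after every recovered
// txn has expired (CleanupTxnUntil) or been interrupted (FailTxnAtVChannel).
func recoverTxnManager(pchannel types.PChannelInfo, src uncommittedSource) *txn.TxnManager {
    return txn.NewTxnManager(pchannel, src.GetUncommittedMessageBuilder())
}

// waitRecovered gates on recovery completion the same way the new test does.
func waitRecovered(m *txn.TxnManager, timeout time.Duration) bool {
    select {
    case <-m.RecoverDone():
        return true
    case <-time.After(timeout):
        return false
    }
}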