diff --git a/cmd/components/streaming_node.go b/cmd/components/streaming_node.go index 2dad02154f..ab46f0b901 100644 --- a/cmd/components/streaming_node.go +++ b/cmd/components/streaming_node.go @@ -29,8 +29,8 @@ type StreamingNode struct { } // NewStreamingNode creates a new StreamingNode -func NewStreamingNode(_ context.Context, _ dependency.Factory) (*StreamingNode, error) { - svr, err := streamingnode.NewServer() +func NewStreamingNode(_ context.Context, factory dependency.Factory) (*StreamingNode, error) { + svr, err := streamingnode.NewServer(factory) if err != nil { return nil, err } diff --git a/internal/.mockery.yaml b/internal/.mockery.yaml index e9c2cdca12..8e4dfe4117 100644 --- a/internal/.mockery.yaml +++ b/internal/.mockery.yaml @@ -8,12 +8,19 @@ packages: github.com/milvus-io/milvus/internal/distributed/streaming: interfaces: WALAccesser: + Utility: github.com/milvus-io/milvus/internal/streamingcoord/server/balancer: interfaces: Balancer: github.com/milvus-io/milvus/internal/streamingnode/client/manager: interfaces: ManagerClient: + github.com/milvus-io/milvus/internal/streamingcoord/client: + interfaces: + Client: + github.com/milvus-io/milvus/internal/streamingnode/client/handler: + interfaces: + HandlerClient: github.com/milvus-io/milvus/internal/streamingnode/client/handler/assignment: interfaces: Watcher: @@ -37,11 +44,11 @@ packages: Interceptor: InterceptorWithReady: InterceptorBuilder: - github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector: - interfaces: + ? github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/inspector + : interfaces: SealOperator: - github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector: - interfaces: + ? github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector + : interfaces: TimeTickSyncOperator: google.golang.org/grpc: interfaces: diff --git a/internal/distributed/streaming/append.go b/internal/distributed/streaming/append.go index 530b9448d4..d44cb84ddf 100644 --- a/internal/distributed/streaming/append.go +++ b/internal/distributed/streaming/append.go @@ -2,7 +2,6 @@ package streaming import ( "context" - "sync" "github.com/milvus-io/milvus/internal/distributed/streaming/internal/producer" "github.com/milvus-io/milvus/pkg/streaming/util/message" @@ -10,161 +9,12 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" ) -// newAppendResponseN creates a new append response. -func newAppendResponseN(n int) AppendResponses { - return AppendResponses{ - Responses: make([]AppendResponse, n), - } -} - -// AppendResponse is the response of one append operation. -type AppendResponse struct { - AppendResult *types.AppendResult - Error error -} - -// AppendResponses is the response of append operation. -type AppendResponses struct { - Responses []AppendResponse -} - -// UnwrapFirstError returns the first error in the responses. -func (a AppendResponses) UnwrapFirstError() error { - for _, r := range a.Responses { - if r.Error != nil { - return r.Error - } - } - return nil -} - -// MaxTimeTick returns the max time tick in the responses. -func (a AppendResponses) MaxTimeTick() uint64 { - maxTimeTick := uint64(0) - for _, r := range a.Responses { - if r.AppendResult.TimeTick > maxTimeTick { - maxTimeTick = r.AppendResult.TimeTick - } - } - return maxTimeTick -} - -// fillAllError fills all the responses with the same error. 
-func (a *AppendResponses) fillAllError(err error) { - for i := range a.Responses { - a.Responses[i].Error = err - } -} - -// fillResponseAtIdx fill the response at idx -func (a *AppendResponses) fillResponseAtIdx(resp AppendResponse, idx int) { - a.Responses[idx] = resp -} - -// dispatchByPChannel dispatches the message into different pchannel. -func (w *walAccesserImpl) dispatchByPChannel(ctx context.Context, msgs ...message.MutableMessage) AppendResponses { - if len(msgs) == 0 { - return newAppendResponseN(0) - } - - // dispatch the messages into different pchannel. - dispatchedMessages, indexes := w.dispatchMessages(msgs...) - - // only one pchannel, append it directly, no more goroutine needed. - if len(dispatchedMessages) == 1 { - for pchannel, msgs := range dispatchedMessages { - return w.appendToPChannel(ctx, pchannel, msgs...) - } - } - - // otherwise, start multiple goroutine to append to different pchannel. - resp := newAppendResponseN(len(msgs)) - wg := sync.WaitGroup{} - wg.Add(len(dispatchedMessages)) - - mu := sync.Mutex{} - for pchannel, msgs := range dispatchedMessages { - pchannel := pchannel - msgs := msgs - idxes := indexes[pchannel] - w.appendExecutionPool.Submit(func() (struct{}, error) { - defer wg.Done() - singleResp := w.appendToPChannel(ctx, pchannel, msgs...) - mu.Lock() - for i, idx := range idxes { - resp.fillResponseAtIdx(singleResp.Responses[i], idx) - } - mu.Unlock() - return struct{}{}, nil - }) - } - wg.Wait() - return resp -} - -// dispatchMessages dispatches the messages into different pchannel. -func (w *walAccesserImpl) dispatchMessages(msgs ...message.MutableMessage) (map[string][]message.MutableMessage, map[string][]int) { - dispatchedMessages := make(map[string][]message.MutableMessage, 0) - // record the index of the message in the msgs, used to fill back response. - indexes := make(map[string][]int, 0) - for idx, msg := range msgs { - pchannel := funcutil.ToPhysicalChannel(msg.VChannel()) - if _, ok := dispatchedMessages[pchannel]; !ok { - dispatchedMessages[pchannel] = make([]message.MutableMessage, 0) - indexes[pchannel] = make([]int, 0) - } - dispatchedMessages[pchannel] = append(dispatchedMessages[pchannel], msg) - indexes[pchannel] = append(indexes[pchannel], idx) - } - return dispatchedMessages, indexes -} - -// appendToPChannel appends the messages to the specified pchannel. -func (w *walAccesserImpl) appendToPChannel(ctx context.Context, pchannel string, msgs ...message.MutableMessage) AppendResponses { - if len(msgs) == 0 { - return newAppendResponseN(0) - } - resp := newAppendResponseN(len(msgs)) - +// appendToWAL appends the message to the wal. +func (w *walAccesserImpl) appendToWAL(ctx context.Context, msg message.MutableMessage) (*types.AppendResult, error) { + pchannel := funcutil.ToPhysicalChannel(msg.VChannel()) // get producer of pchannel. p := w.getProducer(pchannel) - - // if only one message here, append it directly, no more goroutine needed. - // at most time, there's only one message here. - // TODO: only the partition-key with high partition will generate many message in one time on the same pchannel, - // we should optimize the message-format, make it into one; but not the goroutine count. - if len(msgs) == 1 { - produceResult, err := p.Produce(ctx, msgs[0]) - resp.fillResponseAtIdx(AppendResponse{ - AppendResult: produceResult, - Error: err, - }, 0) - return resp - } - - // concurrent produce here. 
-	wg := sync.WaitGroup{}
-	wg.Add(len(msgs))
-
-	mu := sync.Mutex{}
-	for i, msg := range msgs {
-		i := i
-		msg := msg
-		w.appendExecutionPool.Submit(func() (struct{}, error) {
-			defer wg.Done()
-			msgID, err := p.Produce(ctx, msg)
-
-			mu.Lock()
-			resp.fillResponseAtIdx(AppendResponse{
-				AppendResult: msgID,
-				Error:        err,
-			}, i)
-			mu.Unlock()
-			return struct{}{}, nil
-		})
-	}
-	wg.Wait()
-	return resp
+	return p.Produce(ctx, msg)
 }
 
 // createOrGetProducer creates or get a producer.
@@ -183,3 +33,21 @@ func (w *walAccesserImpl) getProducer(pchannel string) *producer.ResumableProduc
 	w.producers[pchannel] = p
 	return p
 }
+
+// assertNoSystemMessage asserts that none of the messages is a system message.
+func assertNoSystemMessage(msgs ...message.MutableMessage) {
+	for _, msg := range msgs {
+		if msg.MessageType().IsSystem() {
+			panic("system message is not allowed to append from client")
+		}
+	}
+}
+
+// assertIsDmlMessage asserts that every message is a dml message.
+// We only support delete and insert messages for txn now.
+func assertIsDmlMessage(msgs ...message.MutableMessage) {
+	for _, msg := range msgs {
+		if msg.MessageType() != message.MessageTypeInsert && msg.MessageType() != message.MessageTypeDelete {
+			panic("only insert and delete message is allowed in txn")
+		}
+	}
+}
diff --git a/internal/distributed/streaming/internal/errs/error.go b/internal/distributed/streaming/internal/errs/error.go
index 84b47ef802..5001ac442b 100644
--- a/internal/distributed/streaming/internal/errs/error.go
+++ b/internal/distributed/streaming/internal/errs/error.go
@@ -6,6 +6,7 @@ import (
 
 // All error in streamingservice package should be marked by streamingservice/errs package.
 var (
-	ErrClosed   = errors.New("closed")
-	ErrCanceled = errors.New("canceled")
+	ErrClosed         = errors.New("closed")
+	ErrCanceled       = errors.New("canceled")
+	ErrTxnUnavailable = errors.New("transaction unavailable")
 )
diff --git a/internal/distributed/streaming/internal/producer/producer.go b/internal/distributed/streaming/internal/producer/producer.go
index d53452ff42..6c080e5fd3 100644
--- a/internal/distributed/streaming/internal/producer/producer.go
+++ b/internal/distributed/streaming/internal/producer/producer.go
@@ -97,6 +97,13 @@ func (p *ResumableProducer) Produce(ctx context.Context, msg message.MutableMess
 		if status.IsCanceled(err) {
 			return nil, errors.Mark(err, errs.ErrCanceled)
 		}
+		if sErr := status.AsStreamingError(err); sErr != nil {
+			// If the error is txn unavailable, it cannot succeed by retrying,
+			// so we mark it and return directly.
+			if sErr.IsTxnUnavilable() {
+				return nil, errors.Mark(err, errs.ErrTxnUnavailable)
+			}
+		}
 	}
 }
diff --git a/internal/distributed/streaming/streaming.go b/internal/distributed/streaming/streaming.go
index 39d4cb3198..3bd722a357 100644
--- a/internal/distributed/streaming/streaming.go
+++ b/internal/distributed/streaming/streaming.go
@@ -2,10 +2,12 @@ package streaming
 
 import (
 	"context"
+	"time"
 
 	kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
 	"github.com/milvus-io/milvus/pkg/streaming/util/options"
+	"github.com/milvus-io/milvus/pkg/streaming/util/types"
 )
 
 var singleton WALAccesser = nil
@@ -19,7 +21,9 @@ func Init() {
 
 // Release releases the resources of the wal accesser.
 func Release() {
-	singleton.Close()
+	if w, ok := singleton.(*walAccesserImpl); ok && w != nil {
+		w.Close()
+	}
 }
 
 // WAL is the entrance to interact with the milvus write ahead log.
@@ -27,6 +31,23 @@ func WAL() WALAccesser {
 	return singleton
 }
 
+// AppendOption is the option for append operation.
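+// Options are applied to a message via applyOpt (see util.go below); currently
+// only BarrierTimeTick is supported. A minimal usage sketch (hedged; `tt` is a
+// hypothetical tso-allocated timestamp, not part of this change):
+//
+//	result, err := streaming.WAL().Append(ctx, msg, streaming.AppendOption{BarrierTimeTick: tt})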
+type AppendOption struct {
+	BarrierTimeTick uint64 // BarrierTimeTick is the barrier time tick of the message.
+	// Must be allocated from tso, otherwise the behavior is undetermined.
+}
+
+type TxnOption struct {
+	// VChannel is the target vchannel to write.
+	// TODO: support cross-wal txn in future.
+	VChannel string
+
+	// Keepalive is the keepalive duration of the transaction.
+	// If the txn doesn't append any message within the keepalive duration, the txn expires.
+	// It must be greater than or equal to 1ms.
+	Keepalive time.Duration
+}
+
 type ReadOption struct {
 	// VChannel is the target vchannel to read.
 	VChannel string
@@ -55,13 +76,47 @@ type Scanner interface {
 // WALAccesser is the interfaces to interact with the milvus write ahead log.
 type WALAccesser interface {
-	// Append writes a record to the log.
-	// !!! Append didn't promise the order of the message and atomic write.
-	Append(ctx context.Context, msgs ...message.MutableMessage) AppendResponses
+	// Txn returns a transaction for writing records to the log.
+	// Once the txn is returned, either Commit or Rollback must be called exactly once, otherwise resources leak on the wal.
+	Txn(ctx context.Context, opts TxnOption) (Txn, error)
+
+	// Append writes a record to the log.
+	Append(ctx context.Context, msgs message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error)
 
 	// Read returns a scanner for reading records from the wal.
 	Read(ctx context.Context, opts ReadOption) Scanner
 
-	// Close closes the wal accesser
-	Close()
+	// Utility returns the utility for writing records to the log.
+	Utility() Utility
+}
+
+// Txn is the interface for writing transactions into the wal.
+type Txn interface {
+	// Append writes a record to the log.
+	Append(ctx context.Context, msg message.MutableMessage, opts ...AppendOption) error
+
+	// Commit commits the transaction.
+	// Commit and Rollback can each be called only once, and are not concurrency-safe with the Append operation.
+	Commit(ctx context.Context) (*types.AppendResult, error)
+
+	// Rollback rolls back the transaction.
+	// Commit and Rollback can each be called only once, and are not concurrency-safe with the Append operation.
+	// TODO: Manual rollback makes no sense for the current single-wal txn.
+	// It is preserved for future cross-wal txn.
+	Rollback(ctx context.Context) error
+}
+
+type Utility interface {
+	// AppendMessages appends messages to the wal.
+	// It is a helper utility function to append messages to the wal.
+	// If the messages belong to one vchannel, they will be sent as a transaction.
+	// Otherwise, they will be sent as individual messages.
+	// !!! This function does not promise atomicity or delivery order of the appended messages.
+	// TODO: Remove after we support cross-wal txn.
+	AppendMessages(ctx context.Context, msgs ...message.MutableMessage) AppendResponses
+
+	// AppendMessagesWithOption appends messages to the wal with the given option.
+	// Same as AppendMessages, but with the given option.
+	// TODO: Remove after we support cross-wal txn.
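+	//
+	// A minimal usage sketch (hedged; assumes an initialized WAL singleton and
+	// a tso-allocated `tt`, which is hypothetical here):
+	//
+	//	resp := streaming.WAL().Utility().AppendMessagesWithOption(ctx,
+	//		streaming.AppendOption{BarrierTimeTick: tt}, msgs...)
+	//	if err := resp.UnwrapFirstError(); err != nil {
+	//		// handle the first failed append
+	//	}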
+ AppendMessagesWithOption(ctx context.Context, opts AppendOption, msgs ...message.MutableMessage) AppendResponses } diff --git a/internal/distributed/streaming/streaming_test.go b/internal/distributed/streaming/streaming_test.go index 945d465dc6..e0db8bb3eb 100644 --- a/internal/distributed/streaming/streaming_test.go +++ b/internal/distributed/streaming/streaming_test.go @@ -18,13 +18,13 @@ const vChannel = "by-dev-rootcoord-dml_4" func TestMain(m *testing.M) { paramtable.Init() - streaming.Init() - defer streaming.Release() m.Run() } func TestStreamingProduce(t *testing.T) { t.Skip() + streaming.Init() + defer streaming.Release() msg, _ := message.NewCreateCollectionMessageBuilderV1(). WithHeader(&message.CreateCollectionMessageHeader{ CollectionId: 1, @@ -35,10 +35,10 @@ func TestStreamingProduce(t *testing.T) { }). WithVChannel(vChannel). BuildMutable() - resp := streaming.WAL().Append(context.Background(), msg) - fmt.Printf("%+v\n", resp) + resp, err := streaming.WAL().Append(context.Background(), msg) + fmt.Printf("%+v\t%+v\n", resp, err) - for i := 0; i < 1000; i++ { + for i := 0; i < 500; i++ { time.Sleep(time.Millisecond * 1) msg, _ := message.NewInsertMessageBuilderV1(). WithHeader(&message.InsertMessageHeader{ @@ -49,8 +49,38 @@ func TestStreamingProduce(t *testing.T) { }). WithVChannel(vChannel). BuildMutable() - resp := streaming.WAL().Append(context.Background(), msg) - fmt.Printf("%+v\n", resp) + resp, err := streaming.WAL().Append(context.Background(), msg) + fmt.Printf("%+v\t%+v\n", resp, err) + } + + for i := 0; i < 500; i++ { + time.Sleep(time.Millisecond * 1) + txn, err := streaming.WAL().Txn(context.Background(), streaming.TxnOption{ + VChannel: vChannel, + Keepalive: 100 * time.Millisecond, + }) + if err != nil { + t.Errorf("txn failed: %v", err) + return + } + for j := 0; j < 5; j++ { + msg, _ := message.NewInsertMessageBuilderV1(). + WithHeader(&message.InsertMessageHeader{ + CollectionId: 1, + }). + WithBody(&msgpb.InsertRequest{ + CollectionID: 1, + }). + WithVChannel(vChannel). + BuildMutable() + err := txn.Append(context.Background(), msg) + fmt.Printf("%+v\n", err) + } + result, err := txn.Commit(context.Background()) + if err != nil { + t.Errorf("txn failed: %v", err) + } + fmt.Printf("%+v\n", result) } msg, _ = message.NewDropCollectionMessageBuilderV1(). @@ -62,12 +92,14 @@ func TestStreamingProduce(t *testing.T) { }). WithVChannel(vChannel). 
BuildMutable() - resp = streaming.WAL().Append(context.Background(), msg) - fmt.Printf("%+v\n", resp) + resp, err = streaming.WAL().Append(context.Background(), msg) + fmt.Printf("%+v\t%+v\n", resp, err) } func TestStreamingConsume(t *testing.T) { t.Skip() + streaming.Init() + defer streaming.Release() ch := make(message.ChanMessageHandler, 10) s := streaming.WAL().Read(context.Background(), streaming.ReadOption{ VChannel: vChannel, @@ -83,8 +115,9 @@ func TestStreamingConsume(t *testing.T) { time.Sleep(10 * time.Millisecond) select { case msg := <-ch: - fmt.Printf("msgID=%+v, tt=%d, lca=%+v, body=%s, idx=%d\n", + fmt.Printf("msgID=%+v, msgType=%+v, tt=%d, lca=%+v, body=%s, idx=%d\n", msg.MessageID(), + msg.MessageType(), msg.TimeTick(), msg.LastConfirmedMessageID(), string(msg.Payload()), diff --git a/internal/distributed/streaming/txn.go b/internal/distributed/streaming/txn.go new file mode 100644 index 0000000000..771f22ec8e --- /dev/null +++ b/internal/distributed/streaming/txn.go @@ -0,0 +1,103 @@ +package streaming + +import ( + "context" + "sync" + + "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" +) + +var _ Txn = (*txnImpl)(nil) + +// txnImpl is the implementation of Txn. +type txnImpl struct { + mu sync.Mutex + inFlightCount int + state message.TxnState + opts TxnOption + txnCtx *message.TxnContext + *walAccesserImpl +} + +// Append writes records to the log. +func (t *txnImpl) Append(ctx context.Context, msg message.MutableMessage, opts ...AppendOption) error { + assertNoSystemMessage(msg) + assertIsDmlMessage(msg) + + t.mu.Lock() + if t.state != message.TxnStateInFlight { + t.mu.Unlock() + return status.NewInvalidTransactionState("Append", message.TxnStateInFlight, t.state) + } + t.inFlightCount++ + t.mu.Unlock() + + defer func() { + t.mu.Lock() + t.inFlightCount-- + t.mu.Unlock() + }() + + // assert if vchannel is equal. + if msg.VChannel() != t.opts.VChannel { + panic("vchannel not match when using transaction") + } + + // setup txn context and add to wal. + applyOpt(msg, opts...) + _, err := t.appendToWAL(ctx, msg.WithTxnContext(*t.txnCtx)) + return err +} + +// Commit commits the transaction. +func (t *txnImpl) Commit(ctx context.Context) (*types.AppendResult, error) { + t.mu.Lock() + if t.state != message.TxnStateInFlight { + t.mu.Unlock() + return nil, status.NewInvalidTransactionState("Commit", message.TxnStateInFlight, t.state) + } + t.state = message.TxnStateCommitted + if t.inFlightCount != 0 { + panic("in flight count not zero when commit") + } + t.mu.Unlock() + defer t.walAccesserImpl.lifetime.Done() + + commit, err := message.NewCommitTxnMessageBuilderV2(). + WithVChannel(t.opts.VChannel). + WithHeader(&message.CommitTxnMessageHeader{}). + WithBody(&message.CommitTxnMessageBody{}). + BuildMutable() + if err != nil { + return nil, err + } + return t.appendToWAL(ctx, commit.WithTxnContext(*t.txnCtx)) +} + +// Rollback rollbacks the transaction. +func (t *txnImpl) Rollback(ctx context.Context) error { + t.mu.Lock() + if t.state != message.TxnStateInFlight { + t.mu.Unlock() + return status.NewInvalidTransactionState("Rollback", message.TxnStateInFlight, t.state) + } + t.state = message.TxnStateRollbacked + if t.inFlightCount != 0 { + panic("in flight count not zero when rollback") + } + t.mu.Unlock() + defer t.walAccesserImpl.lifetime.Done() + + rollback, err := message.NewRollbackTxnMessageBuilderV2(). + WithVChannel(t.opts.VChannel). 
+		WithHeader(&message.RollbackTxnMessageHeader{}).
+		WithBody(&message.RollbackTxnMessageBody{}).
+		BuildMutable()
+	if err != nil {
+		return err
+	}
+	_, err = t.appendToWAL(ctx, rollback.WithTxnContext(*t.txnCtx))
+	return err
+}
diff --git a/internal/distributed/streaming/util.go b/internal/distributed/streaming/util.go
new file mode 100644
index 0000000000..8701ef9462
--- /dev/null
+++ b/internal/distributed/streaming/util.go
@@ -0,0 +1,209 @@
+package streaming
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/streaming/util/types"
+	"github.com/milvus-io/milvus/pkg/util/conc"
+)
+
+type utility struct {
+	appendExecutionPool   *conc.Pool[struct{}]
+	dispatchExecutionPool *conc.Pool[struct{}]
+	*walAccesserImpl
+}
+
+// AppendMessages appends messages to the wal.
+// It is a helper utility function to append messages to the wal.
+// If the messages belong to one vchannel, they will be sent as a transaction.
+// Otherwise, they will be sent as individual messages.
+// !!! This function does not promise atomicity or delivery order of the appended messages.
+func (u *utility) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) AppendResponses {
+	assertNoSystemMessage(msgs...)
+
+	// dispatch the messages into different vchannel.
+	dispatchedMessages, indexes := u.dispatchMessages(msgs...)
+
+	// If there is only one vchannel, append directly without extra goroutines.
+	if len(dispatchedMessages) == 1 {
+		return u.appendToVChannel(ctx, msgs[0].VChannel(), msgs...)
+	}
+
+	// Otherwise append the messages concurrently.
+	mu := &sync.Mutex{}
+	resp := newAppendResponseN(len(msgs))
+
+	wg := &sync.WaitGroup{}
+	wg.Add(len(dispatchedMessages))
+	for vchannel, msgs := range dispatchedMessages {
+		vchannel := vchannel
+		msgs := msgs
+		idxes := indexes[vchannel]
+		u.dispatchExecutionPool.Submit(func() (struct{}, error) {
+			defer wg.Done()
+			singleResp := u.appendToVChannel(ctx, vchannel, msgs...)
+			mu.Lock()
+			for i, idx := range idxes {
+				resp.fillResponseAtIdx(singleResp.Responses[i], idx)
+			}
+			mu.Unlock()
+			return struct{}{}, nil
+		})
+	}
+	wg.Wait()
+	return resp
+}
+
+// AppendMessagesWithOption appends messages to the wal with the given option.
+func (u *utility) AppendMessagesWithOption(ctx context.Context, opts AppendOption, msgs ...message.MutableMessage) AppendResponses {
+	for _, msg := range msgs {
+		applyOpt(msg, opts)
+	}
+	return u.AppendMessages(ctx, msgs...)
+}
+
+// dispatchMessages dispatches the messages into different vchannels.
+func (u *utility) dispatchMessages(msgs ...message.MutableMessage) (map[string][]message.MutableMessage, map[string][]int) {
+	dispatchedMessages := make(map[string][]message.MutableMessage, 0)
+	indexes := make(map[string][]int, 0)
+	for idx, msg := range msgs {
+		vchannel := msg.VChannel()
+		if _, ok := dispatchedMessages[vchannel]; !ok {
+			dispatchedMessages[vchannel] = make([]message.MutableMessage, 0)
+			indexes[vchannel] = make([]int, 0)
+		}
+		dispatchedMessages[vchannel] = append(dispatchedMessages[vchannel], msg)
+		indexes[vchannel] = append(indexes[vchannel], idx)
+	}
+	return dispatchedMessages, indexes
+}
+
+// appendToVChannel appends the messages to the specified vchannel.
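+// A single message is appended directly. Multiple messages are wrapped in a
+// transaction (Txn -> Append per message -> Commit) on that vchannel, so on
+// success every response is filled with the same commit result; on any append
+// failure the transaction is rolled back and all responses carry the first error.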
+func (u *utility) appendToVChannel(ctx context.Context, vchannel string, msgs ...message.MutableMessage) AppendResponses {
+	if len(msgs) == 0 {
+		return newAppendResponseN(0)
+	}
+	resp := newAppendResponseN(len(msgs))
+
+	// If there is only one message, append it directly; no extra goroutine is needed.
+	// Most of the time, there is only one message here.
+	// TODO: only a partition-key collection with many partitions generates many messages
+	// at a time on the same pchannel; we should optimize the message format to merge them
+	// into one, rather than tuning the goroutine count.
+	if len(msgs) == 1 {
+		appendResult, err := u.appendToWAL(ctx, msgs[0])
+		resp.fillResponseAtIdx(AppendResponse{
+			AppendResult: appendResult,
+			Error:        err,
+		}, 0)
+		return resp
+	}
+
+	// Otherwise, we start a transaction to append the messages.
+	// The transaction will be committed when all messages are appended.
+	txn, err := u.Txn(ctx, TxnOption{
+		VChannel:  vchannel,
+		Keepalive: 5 * time.Second,
+	})
+	if err != nil {
+		resp.fillAllError(err)
+		return resp
+	}
+
+	// concurrent produce here.
+	wg := sync.WaitGroup{}
+	wg.Add(len(msgs))
+
+	mu := sync.Mutex{}
+	for i, msg := range msgs {
+		i := i
+		msg := msg
+		u.appendExecutionPool.Submit(func() (struct{}, error) {
+			defer wg.Done()
+			if err := txn.Append(ctx, msg); err != nil {
+				mu.Lock()
+				resp.fillResponseAtIdx(AppendResponse{
+					Error: err,
+				}, i)
+				mu.Unlock()
+			}
+			return struct{}{}, nil
+		})
+	}
+	wg.Wait()
+
+	// If there's any error, roll back the transaction
+	// and fill all responses with the first error.
+	if err := resp.UnwrapFirstError(); err != nil {
+		_ = txn.Rollback(ctx) // rollback failure can be ignored.
+		resp.fillAllError(err)
+		return resp
+	}
+
+	// commit the transaction and fill the response.
+	appendResult, err := txn.Commit(ctx)
+	resp.fillAllResponse(AppendResponse{
+		AppendResult: appendResult,
+		Error:        err,
+	})
+	return resp
+}
+
+// newAppendResponseN creates a new append response.
+func newAppendResponseN(n int) AppendResponses {
+	return AppendResponses{
+		Responses: make([]AppendResponse, n),
+	}
+}
+
+// AppendResponse is the response of one append operation.
+type AppendResponse struct {
+	AppendResult *types.AppendResult
+	Error        error
+}
+
+// AppendResponses is the response of an append operation.
+type AppendResponses struct {
+	Responses []AppendResponse
+}
+
+// UnwrapFirstError returns the first error in the responses.
+func (a AppendResponses) UnwrapFirstError() error {
+	for _, r := range a.Responses {
+		if r.Error != nil {
+			return r.Error
+		}
+	}
+	return nil
+}
+
+// fillAllError fills all the responses with the same error.
+func (a *AppendResponses) fillAllError(err error) {
+	for i := range a.Responses {
+		a.Responses[i].Error = err
+	}
+}
+
+// fillResponseAtIdx fills the response at idx.
+func (a *AppendResponses) fillResponseAtIdx(resp AppendResponse, idx int) {
+	a.Responses[idx] = resp
+}
+
+// fillAllResponse fills all the responses with the same response.
+func (a *AppendResponses) fillAllResponse(resp AppendResponse) {
+	for i := range a.Responses {
+		a.Responses[i] = resp
+	}
+}
+
+// applyOpt applies the append options to the message.
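+// For example, applyOpt(msg, AppendOption{BarrierTimeTick: tt}) attaches the
+// barrier time tick via msg.WithBarrierTimeTick(tt); with no options the
+// message is returned unchanged.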
+func applyOpt(msg message.MutableMessage, opts ...AppendOption) message.MutableMessage { + if len(opts) == 0 { + return msg + } + if opts[0].BarrierTimeTick > 0 { + msg = msg.WithBarrierTimeTick(opts[0].BarrierTimeTick) + } + return msg +} diff --git a/internal/distributed/streaming/wal.go b/internal/distributed/streaming/wal.go index 58c7ef0d05..e5450226b0 100644 --- a/internal/distributed/streaming/wal.go +++ b/internal/distributed/streaming/wal.go @@ -3,6 +3,7 @@ package streaming import ( "context" "sync" + "time" clientv3 "go.etcd.io/etcd/client/v3" @@ -13,6 +14,7 @@ import ( "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/options" + "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/lifetime" @@ -30,8 +32,11 @@ func newWALAccesser(c *clientv3.Client) *walAccesserImpl { handlerClient: handlerClient, producerMutex: sync.Mutex{}, producers: make(map[string]*producer.ResumableProducer), - // TODO: make the pool size configurable. - appendExecutionPool: conc.NewPool[struct{}](10), + utility: &utility{ + // TODO: optimize the pool size, use the streaming api but not goroutines. + appendExecutionPool: conc.NewPool[struct{}](10), + dispatchExecutionPool: conc.NewPool[struct{}](10), + }, } } @@ -43,26 +48,21 @@ type walAccesserImpl struct { streamingCoordAssignmentClient client.Client handlerClient handler.HandlerClient - producerMutex sync.Mutex - producers map[string]*producer.ResumableProducer - appendExecutionPool *conc.Pool[struct{}] + producerMutex sync.Mutex + producers map[string]*producer.ResumableProducer + utility *utility } // Append writes a record to the log. -func (w *walAccesserImpl) Append(ctx context.Context, msgs ...message.MutableMessage) AppendResponses { +func (w *walAccesserImpl) Append(ctx context.Context, msg message.MutableMessage, opts ...AppendOption) (*types.AppendResult, error) { + assertNoSystemMessage(msg) if err := w.lifetime.Add(lifetime.IsWorking); err != nil { - err := status.NewOnShutdownError("wal accesser closed, %s", err.Error()) - resp := newAppendResponseN(len(msgs)) - resp.fillAllError(err) - return resp + return nil, status.NewOnShutdownError("wal accesser closed, %s", err.Error()) } defer w.lifetime.Done() - // If there is only one message, append it to the corresponding pchannel is ok. - if len(msgs) <= 1 { - return w.appendToPChannel(ctx, funcutil.ToPhysicalChannel(msgs[0].VChannel()), msgs...) - } - return w.dispatchByPChannel(ctx, msgs...) + msg = applyOpt(msg, opts...) + return w.appendToWAL(ctx, msg) } // Read returns a scanner for reading records from the wal. @@ -84,6 +84,56 @@ func (w *walAccesserImpl) Read(_ context.Context, opts ReadOption) Scanner { return rc } +func (w *walAccesserImpl) Txn(ctx context.Context, opts TxnOption) (Txn, error) { + if err := w.lifetime.Add(lifetime.IsWorking); err != nil { + return nil, status.NewOnShutdownError("wal accesser closed, %s", err.Error()) + } + + if opts.VChannel == "" { + return nil, status.NewInvaildArgument("vchannel is required") + } + if opts.Keepalive < 1*time.Millisecond { + return nil, status.NewInvaildArgument("ttl must be greater than or equal to 1ms") + } + + // Create a new transaction, send the begin txn message. + beginTxn, err := message.NewBeginTxnMessageBuilderV2(). + WithVChannel(opts.VChannel). 
+ WithHeader(&message.BeginTxnMessageHeader{ + KeepaliveMilliseconds: opts.Keepalive.Milliseconds(), + }). + WithBody(&message.BeginTxnMessageBody{}). + BuildMutable() + if err != nil { + w.lifetime.Done() + return nil, err + } + + appendResult, err := w.appendToWAL(ctx, beginTxn) + if err != nil { + w.lifetime.Done() + return nil, err + } + + // Create new transaction success. + return &txnImpl{ + mu: sync.Mutex{}, + state: message.TxnStateInFlight, + opts: opts, + txnCtx: appendResult.TxnCtx, + walAccesserImpl: w, + }, nil +} + +// Utility returns the utility of the wal accesser. +func (w *walAccesserImpl) Utility() Utility { + return &utility{ + appendExecutionPool: w.utility.appendExecutionPool, + dispatchExecutionPool: w.utility.dispatchExecutionPool, + walAccesserImpl: w, + } +} + // Close closes all the wal accesser. func (w *walAccesserImpl) Close() { w.lifetime.SetState(lifetime.Stopped) diff --git a/internal/distributed/streaming/wal_test.go b/internal/distributed/streaming/wal_test.go new file mode 100644 index 0000000000..258f30ec00 --- /dev/null +++ b/internal/distributed/streaming/wal_test.go @@ -0,0 +1,131 @@ +package streaming + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/distributed/streaming/internal/producer" + "github.com/milvus-io/milvus/internal/mocks/streamingcoord/mock_client" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/handler/mock_producer" + "github.com/milvus-io/milvus/internal/mocks/streamingnode/client/mock_handler" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/types" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + "github.com/milvus-io/milvus/pkg/util/conc" + "github.com/milvus-io/milvus/pkg/util/lifetime" +) + +const ( + vChannel1 = "by-dev-rootcoord-dml_1" + vChannel2 = "by-dev-rootcoord-dml_2" + vChannel3 = "by-dev-rootcoord-dml_3" +) + +func TestWAL(t *testing.T) { + coordClient := mock_client.NewMockClient(t) + coordClient.EXPECT().Close().Return() + handler := mock_handler.NewMockHandlerClient(t) + handler.EXPECT().Close().Return() + + w := &walAccesserImpl{ + lifetime: lifetime.NewLifetime(lifetime.Working), + streamingCoordAssignmentClient: coordClient, + handlerClient: handler, + producerMutex: sync.Mutex{}, + producers: make(map[string]*producer.ResumableProducer), + utility: &utility{ + appendExecutionPool: conc.NewPool[struct{}](10), + dispatchExecutionPool: conc.NewPool[struct{}](10), + }, + } + defer w.Close() + + ctx := context.Background() + + available := make(chan struct{}) + p := mock_producer.NewMockProducer(t) + p.EXPECT().IsAvailable().RunAndReturn(func() bool { + select { + case <-available: + return false + default: + return true + } + }) + p.EXPECT().Produce(mock.Anything, mock.Anything).Return(&types.AppendResult{ + MessageID: walimplstest.NewTestMessageID(1), + TimeTick: 10, + TxnCtx: &message.TxnContext{ + TxnID: 1, + Keepalive: 10 * time.Second, + }, + }, nil) + p.EXPECT().Available().Return(available) + p.EXPECT().Close().Return() + + handler.EXPECT().CreateProducer(mock.Anything, mock.Anything).Return(p, nil) + result, err := w.Append(ctx, newInsertMessage(vChannel1)) + assert.NoError(t, err) + assert.NotNil(t, result) + + // Test committed txn. 
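+	// The committed-txn path below exercises the full client lifecycle:
+	// Txn -> Append (per message) -> Commit. The mock producer above returns
+	// a TxnContext (TxnID 1, 10s keepalive), so the begin message succeeds
+	// and the returned txn carries that context.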
+ txn, err := w.Txn(ctx, TxnOption{ + VChannel: vChannel1, + Keepalive: 10 * time.Second, + }) + assert.NoError(t, err) + assert.NotNil(t, txn) + + err = txn.Append(ctx, newInsertMessage(vChannel1)) + assert.NoError(t, err) + err = txn.Append(ctx, newInsertMessage(vChannel1)) + assert.NoError(t, err) + + result, err = txn.Commit(ctx) + assert.NoError(t, err) + assert.NotNil(t, result) + + // Test rollback txn. + txn, err = w.Txn(ctx, TxnOption{ + VChannel: vChannel1, + Keepalive: 10 * time.Second, + }) + assert.NoError(t, err) + assert.NotNil(t, txn) + + err = txn.Append(ctx, newInsertMessage(vChannel1)) + assert.NoError(t, err) + err = txn.Append(ctx, newInsertMessage(vChannel1)) + assert.NoError(t, err) + + err = txn.Rollback(ctx) + assert.NoError(t, err) + + resp := w.Utility().AppendMessages(ctx, + newInsertMessage(vChannel1), + newInsertMessage(vChannel2), + newInsertMessage(vChannel2), + newInsertMessage(vChannel3), + newInsertMessage(vChannel3), + newInsertMessage(vChannel3), + ) + assert.NoError(t, resp.UnwrapFirstError()) +} + +func newInsertMessage(vChannel string) message.MutableMessage { + msg, err := message.NewInsertMessageBuilderV1(). + WithVChannel(vChannel). + WithHeader(&message.InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{}). + BuildMutable() + if err != nil { + panic(err) + } + return msg +} diff --git a/internal/distributed/streamingnode/service.go b/internal/distributed/streamingnode/service.go index 27ee50a77f..8a909b2324 100644 --- a/internal/distributed/streamingnode/service.go +++ b/internal/distributed/streamingnode/service.go @@ -39,9 +39,11 @@ import ( rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" tikvkv "github.com/milvus-io/milvus/internal/kv/tikv" + "github.com/milvus-io/milvus/internal/storage" streamingnodeserver "github.com/milvus-io/milvus/internal/streamingnode/server" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/componentutil" + "github.com/milvus-io/milvus/internal/util/dependency" kvfactory "github.com/milvus-io/milvus/internal/util/dependency/kv" "github.com/milvus-io/milvus/internal/util/sessionutil" streamingserviceinterceptor "github.com/milvus-io/milvus/internal/util/streamingutil/service/interceptor" @@ -75,28 +77,35 @@ type Server struct { lis net.Listener // component client - etcdCli *clientv3.Client - tikvCli *txnkv.Client - rootCoord types.RootCoordClient - dataCoord types.DataCoordClient + etcdCli *clientv3.Client + tikvCli *txnkv.Client + rootCoord types.RootCoordClient + dataCoord types.DataCoordClient + chunkManager storage.ChunkManager + f dependency.Factory } // NewServer create a new StreamingNode server. -func NewServer() (*Server, error) { +func NewServer(f dependency.Factory) (*Server, error) { return &Server{ stopOnce: sync.Once{}, grpcServerChan: make(chan struct{}), + f: f, }, nil } // Run runs the server. func (s *Server) Run() error { - if err := s.init(); err != nil { + // TODO: We should set a timeout for the process startup. + // But currently, we don't implement. 
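+	// A hedged sketch of what that bound could look like (hypothetical, not
+	// part of this change):
+	//
+	//	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	//	defer cancel()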
+ ctx := context.Background() + + if err := s.init(ctx); err != nil { return err } log.Info("streamingnode init done ...") - if err := s.start(); err != nil { + if err := s.start(ctx); err != nil { return err } log.Info("streamingnode start done ...") @@ -156,7 +165,7 @@ func (s *Server) Health(ctx context.Context) commonpb.StateCode { return s.streamingnode.Health(ctx) } -func (s *Server) init() (err error) { +func (s *Server) init(ctx context.Context) (err error) { defer func() { if err != nil { log.Error("StreamingNode init failed", zap.Error(err)) @@ -174,13 +183,16 @@ func (s *Server) init() (err error) { if err := s.allocateAddress(); err != nil { return err } - if err := s.initSession(); err != nil { + if err := s.initSession(ctx); err != nil { return err } - if err := s.initRootCoord(); err != nil { + if err := s.initRootCoord(ctx); err != nil { return err } - if err := s.initDataCoord(); err != nil { + if err := s.initDataCoord(ctx); err != nil { + return err + } + if err := s.initChunkManager(ctx); err != nil { return err } s.initGRPCServer() @@ -193,14 +205,15 @@ func (s *Server) init() (err error) { WithDataCoordClient(s.dataCoord). WithSession(s.session). WithMetaKV(s.metaKV). + WithChunkManager(s.chunkManager). Build() - if err := s.streamingnode.Init(context.Background()); err != nil { + if err := s.streamingnode.Init(ctx); err != nil { return errors.Wrap(err, "StreamingNode service init failed") } return nil } -func (s *Server) start() (err error) { +func (s *Server) start(ctx context.Context) (err error) { defer func() { if err != nil { log.Error("StreamingNode start failed", zap.Error(err)) @@ -213,7 +226,7 @@ func (s *Server) start() (err error) { s.streamingnode.Start() // Start grpc server. - if err := s.startGPRCServer(); err != nil { + if err := s.startGPRCServer(ctx); err != nil { return errors.Wrap(err, "StreamingNode start gRPC server fail") } @@ -222,8 +235,8 @@ func (s *Server) start() (err error) { return nil } -func (s *Server) initSession() error { - s.session = sessionutil.NewSession(context.Background()) +func (s *Server) initSession(ctx context.Context) error { + s.session = sessionutil.NewSession(ctx) if s.session == nil { return errors.New("session is nil, the etcd client connection may have failed") } @@ -260,36 +273,47 @@ func (s *Server) initMeta() error { return nil } -func (s *Server) initRootCoord() (err error) { +func (s *Server) initRootCoord(ctx context.Context) (err error) { log.Info("StreamingNode connect to rootCoord...") - s.rootCoord, err = rcc.NewClient(context.Background()) + s.rootCoord, err = rcc.NewClient(ctx) if err != nil { return errors.Wrap(err, "StreamingNode try to new RootCoord client failed") } log.Info("StreamingNode try to wait for RootCoord ready") - err = componentutil.WaitForComponentHealthy(context.Background(), s.rootCoord, "RootCoord", 1000000, time.Millisecond*200) + err = componentutil.WaitForComponentHealthy(ctx, s.rootCoord, "RootCoord", 1000000, time.Millisecond*200) if err != nil { return errors.Wrap(err, "StreamingNode wait for RootCoord ready failed") } return nil } -func (s *Server) initDataCoord() (err error) { +func (s *Server) initDataCoord(ctx context.Context) (err error) { log.Info("StreamingNode connect to dataCoord...") - s.dataCoord, err = dcc.NewClient(context.Background()) + s.dataCoord, err = dcc.NewClient(ctx) if err != nil { return errors.Wrap(err, "StreamingNode try to new DataCoord client failed") } log.Info("StreamingNode try to wait for DataCoord ready") - err = 
componentutil.WaitForComponentHealthy(context.Background(), s.dataCoord, "DataCoord", 1000000, time.Millisecond*200)
+	err = componentutil.WaitForComponentHealthy(ctx, s.dataCoord, "DataCoord", 1000000, time.Millisecond*200)
 	if err != nil {
 		return errors.Wrap(err, "StreamingNode wait for DataCoord ready failed")
 	}
 	return nil
 }
 
+func (s *Server) initChunkManager(ctx context.Context) (err error) {
+	log.Info("StreamingNode init chunk manager...")
+	s.f.Init(paramtable.Get())
+	manager, err := s.f.NewPersistentStorageChunkManager(ctx)
+	if err != nil {
+		return errors.Wrap(err, "StreamingNode try to new chunk manager failed")
+	}
+	s.chunkManager = manager
+	return nil
+}
+
 func (s *Server) initGRPCServer() {
 	log.Info("create StreamingNode server...")
 	cfg := &paramtable.Get().StreamingNodeGrpcServerCfg
@@ -357,7 +381,7 @@ func (s *Server) getAddress() (string, error) {
 }
 
 // startGRPCServer starts the grpc server.
-func (s *Server) startGPRCServer() error {
+func (s *Server) startGPRCServer(ctx context.Context) error {
 	errCh := make(chan error, 1)
 	go func() {
 		defer close(s.grpcServerChan)
@@ -372,7 +396,7 @@ func (s *Server) startGPRCServer() error {
 		}
 	}()
-	funcutil.CheckGrpcReady(context.Background(), errCh)
+	funcutil.CheckGrpcReady(ctx, errCh)
 	return <-errCh
 }
diff --git a/internal/metastore/kv/streamingnode/constant.go b/internal/metastore/kv/streamingnode/constant.go
index 111b0ef7fd..d1bf796f28 100644
--- a/internal/metastore/kv/streamingnode/constant.go
+++ b/internal/metastore/kv/streamingnode/constant.go
@@ -1,6 +1,7 @@
 package streamingnode
 
 const (
-	MetaPrefix        = "streamingnode-meta"
-	SegmentAssignMeta = MetaPrefix + "/segment-assign"
+	MetaPrefix             = "streamingnode-meta"
+	SegmentAssignMeta      = MetaPrefix + "/segment-assign"
+	SegmentAssignSubFolder = "s"
 )
diff --git a/internal/metastore/kv/streamingnode/kv_catalog.go b/internal/metastore/kv/streamingnode/kv_catalog.go
index 7fe5d14499..a7b4ac803f 100644
--- a/internal/metastore/kv/streamingnode/kv_catalog.go
+++ b/internal/metastore/kv/streamingnode/kv_catalog.go
@@ -83,10 +83,13 @@ func (c *catalog) SaveSegmentAssignments(ctx context.Context, pChannelName strin
 // buildSegmentAssignmentMetaPath builds the path for segment assignment
 // streamingnode-meta/segment-assign/${pChannelName}
 func buildSegmentAssignmentMetaPath(pChannelName string) string {
-	return path.Join(SegmentAssignMeta, pChannelName)
+	// !!! A bad implementation here, but we can't change the underlying meta kv for compatibility.
+	// The underlying meta kv removes the trailing '/' of the path, which loses the pchannel name.
+	// So we add a special sub path to avoid this.
+	return path.Join(SegmentAssignMeta, pChannelName, SegmentAssignSubFolder) + "/"
 }
 
 // buildSegmentAssignmentMetaPathOfSegment builds the path for segment assignment
 func buildSegmentAssignmentMetaPathOfSegment(pChannelName string, segmentID int64) string {
-	return path.Join(SegmentAssignMeta, pChannelName, strconv.FormatInt(segmentID, 10))
+	return path.Join(SegmentAssignMeta, pChannelName, SegmentAssignSubFolder, strconv.FormatInt(segmentID, 10))
 }
diff --git a/internal/mocks/distributed/mock_streaming/mock_Utility.go b/internal/mocks/distributed/mock_streaming/mock_Utility.go
new file mode 100644
index 0000000000..e2fa616a5a
--- /dev/null
+++ b/internal/mocks/distributed/mock_streaming/mock_Utility.go
@@ -0,0 +1,154 @@
+// Code generated by mockery v2.32.4. DO NOT EDIT.
+ +package mock_streaming + +import ( + context "context" + + message "github.com/milvus-io/milvus/pkg/streaming/util/message" + mock "github.com/stretchr/testify/mock" + + streaming "github.com/milvus-io/milvus/internal/distributed/streaming" +) + +// MockUtility is an autogenerated mock type for the Utility type +type MockUtility struct { + mock.Mock +} + +type MockUtility_Expecter struct { + mock *mock.Mock +} + +func (_m *MockUtility) EXPECT() *MockUtility_Expecter { + return &MockUtility_Expecter{mock: &_m.Mock} +} + +// AppendMessages provides a mock function with given fields: ctx, msgs +func (_m *MockUtility) AppendMessages(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses { + _va := make([]interface{}, len(msgs)) + for _i := range msgs { + _va[_i] = msgs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 streaming.AppendResponses + if rf, ok := ret.Get(0).(func(context.Context, ...message.MutableMessage) streaming.AppendResponses); ok { + r0 = rf(ctx, msgs...) + } else { + r0 = ret.Get(0).(streaming.AppendResponses) + } + + return r0 +} + +// MockUtility_AppendMessages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AppendMessages' +type MockUtility_AppendMessages_Call struct { + *mock.Call +} + +// AppendMessages is a helper method to define mock.On call +// - ctx context.Context +// - msgs ...message.MutableMessage +func (_e *MockUtility_Expecter) AppendMessages(ctx interface{}, msgs ...interface{}) *MockUtility_AppendMessages_Call { + return &MockUtility_AppendMessages_Call{Call: _e.mock.On("AppendMessages", + append([]interface{}{ctx}, msgs...)...)} +} + +func (_c *MockUtility_AppendMessages_Call) Run(run func(ctx context.Context, msgs ...message.MutableMessage)) *MockUtility_AppendMessages_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]message.MutableMessage, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(message.MutableMessage) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockUtility_AppendMessages_Call) Return(_a0 streaming.AppendResponses) *MockUtility_AppendMessages_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockUtility_AppendMessages_Call) RunAndReturn(run func(context.Context, ...message.MutableMessage) streaming.AppendResponses) *MockUtility_AppendMessages_Call { + _c.Call.Return(run) + return _c +} + +// AppendMessagesWithOption provides a mock function with given fields: ctx, opts, msgs +func (_m *MockUtility) AppendMessagesWithOption(ctx context.Context, opts streaming.AppendOption, msgs ...message.MutableMessage) streaming.AppendResponses { + _va := make([]interface{}, len(msgs)) + for _i := range msgs { + _va[_i] = msgs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, opts) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 streaming.AppendResponses + if rf, ok := ret.Get(0).(func(context.Context, streaming.AppendOption, ...message.MutableMessage) streaming.AppendResponses); ok { + r0 = rf(ctx, opts, msgs...) 
+ } else { + r0 = ret.Get(0).(streaming.AppendResponses) + } + + return r0 +} + +// MockUtility_AppendMessagesWithOption_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AppendMessagesWithOption' +type MockUtility_AppendMessagesWithOption_Call struct { + *mock.Call +} + +// AppendMessagesWithOption is a helper method to define mock.On call +// - ctx context.Context +// - opts streaming.AppendOption +// - msgs ...message.MutableMessage +func (_e *MockUtility_Expecter) AppendMessagesWithOption(ctx interface{}, opts interface{}, msgs ...interface{}) *MockUtility_AppendMessagesWithOption_Call { + return &MockUtility_AppendMessagesWithOption_Call{Call: _e.mock.On("AppendMessagesWithOption", + append([]interface{}{ctx, opts}, msgs...)...)} +} + +func (_c *MockUtility_AppendMessagesWithOption_Call) Run(run func(ctx context.Context, opts streaming.AppendOption, msgs ...message.MutableMessage)) *MockUtility_AppendMessagesWithOption_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]message.MutableMessage, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(message.MutableMessage) + } + } + run(args[0].(context.Context), args[1].(streaming.AppendOption), variadicArgs...) + }) + return _c +} + +func (_c *MockUtility_AppendMessagesWithOption_Call) Return(_a0 streaming.AppendResponses) *MockUtility_AppendMessagesWithOption_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockUtility_AppendMessagesWithOption_Call) RunAndReturn(run func(context.Context, streaming.AppendOption, ...message.MutableMessage) streaming.AppendResponses) *MockUtility_AppendMessagesWithOption_Call { + _c.Call.Return(run) + return _c +} + +// NewMockUtility creates a new instance of MockUtility. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockUtility(t interface { + mock.TestingT + Cleanup(func()) +}) *MockUtility { + mock := &MockUtility{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go index 733381f1bd..5090e413bc 100644 --- a/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go +++ b/internal/mocks/distributed/mock_streaming/mock_WALAccesser.go @@ -9,6 +9,8 @@ import ( mock "github.com/stretchr/testify/mock" streaming "github.com/milvus-io/milvus/internal/distributed/streaming" + + types "github.com/milvus-io/milvus/pkg/streaming/util/types" ) // MockWALAccesser is an autogenerated mock type for the WALAccesser type @@ -24,25 +26,37 @@ func (_m *MockWALAccesser) EXPECT() *MockWALAccesser_Expecter { return &MockWALAccesser_Expecter{mock: &_m.Mock} } -// Append provides a mock function with given fields: ctx, msgs -func (_m *MockWALAccesser) Append(ctx context.Context, msgs ...message.MutableMessage) streaming.AppendResponses { - _va := make([]interface{}, len(msgs)) - for _i := range msgs { - _va[_i] = msgs[_i] +// Append provides a mock function with given fields: ctx, msgs, opts +func (_m *MockWALAccesser) Append(ctx context.Context, msgs message.MutableMessage, opts ...streaming.AppendOption) (*types.AppendResult, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] } var _ca []interface{} - _ca = append(_ca, ctx) + _ca = append(_ca, ctx, msgs) _ca = append(_ca, _va...) ret := _m.Called(_ca...) - var r0 streaming.AppendResponses - if rf, ok := ret.Get(0).(func(context.Context, ...message.MutableMessage) streaming.AppendResponses); ok { - r0 = rf(ctx, msgs...) + var r0 *types.AppendResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, ...streaming.AppendOption) (*types.AppendResult, error)); ok { + return rf(ctx, msgs, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, message.MutableMessage, ...streaming.AppendOption) *types.AppendResult); ok { + r0 = rf(ctx, msgs, opts...) } else { - r0 = ret.Get(0).(streaming.AppendResponses) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.AppendResult) + } } - return r0 + if rf, ok := ret.Get(1).(func(context.Context, message.MutableMessage, ...streaming.AppendOption) error); ok { + r1 = rf(ctx, msgs, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 } // MockWALAccesser_Append_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Append' @@ -51,64 +65,33 @@ type MockWALAccesser_Append_Call struct { } // Append is a helper method to define mock.On call -// - ctx context.Context -// - msgs ...message.MutableMessage -func (_e *MockWALAccesser_Expecter) Append(ctx interface{}, msgs ...interface{}) *MockWALAccesser_Append_Call { +// - ctx context.Context +// - msgs message.MutableMessage +// - opts ...streaming.AppendOption +func (_e *MockWALAccesser_Expecter) Append(ctx interface{}, msgs interface{}, opts ...interface{}) *MockWALAccesser_Append_Call { return &MockWALAccesser_Append_Call{Call: _e.mock.On("Append", - append([]interface{}{ctx}, msgs...)...)} + append([]interface{}{ctx, msgs}, opts...)...)} } -func (_c *MockWALAccesser_Append_Call) Run(run func(ctx context.Context, msgs ...message.MutableMessage)) *MockWALAccesser_Append_Call { +func (_c *MockWALAccesser_Append_Call) Run(run func(ctx context.Context, msgs message.MutableMessage, opts ...streaming.AppendOption)) *MockWALAccesser_Append_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]message.MutableMessage, len(args)-1) - for i, a := range args[1:] { + variadicArgs := make([]streaming.AppendOption, len(args)-2) + for i, a := range args[2:] { if a != nil { - variadicArgs[i] = a.(message.MutableMessage) + variadicArgs[i] = a.(streaming.AppendOption) } } - run(args[0].(context.Context), variadicArgs...) + run(args[0].(context.Context), args[1].(message.MutableMessage), variadicArgs...) }) return _c } -func (_c *MockWALAccesser_Append_Call) Return(_a0 streaming.AppendResponses) *MockWALAccesser_Append_Call { - _c.Call.Return(_a0) +func (_c *MockWALAccesser_Append_Call) Return(_a0 *types.AppendResult, _a1 error) *MockWALAccesser_Append_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *MockWALAccesser_Append_Call) RunAndReturn(run func(context.Context, ...message.MutableMessage) streaming.AppendResponses) *MockWALAccesser_Append_Call { - _c.Call.Return(run) - return _c -} - -// Close provides a mock function with given fields: -func (_m *MockWALAccesser) Close() { - _m.Called() -} - -// MockWALAccesser_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' -type MockWALAccesser_Close_Call struct { - *mock.Call -} - -// Close is a helper method to define mock.On call -func (_e *MockWALAccesser_Expecter) Close() *MockWALAccesser_Close_Call { - return &MockWALAccesser_Close_Call{Call: _e.mock.On("Close")} -} - -func (_c *MockWALAccesser_Close_Call) Run(run func()) *MockWALAccesser_Close_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *MockWALAccesser_Close_Call) Return() *MockWALAccesser_Close_Call { - _c.Call.Return() - return _c -} - -func (_c *MockWALAccesser_Close_Call) RunAndReturn(run func()) *MockWALAccesser_Close_Call { +func (_c *MockWALAccesser_Append_Call) RunAndReturn(run func(context.Context, message.MutableMessage, ...streaming.AppendOption) (*types.AppendResult, error)) *MockWALAccesser_Append_Call { _c.Call.Return(run) return _c } @@ -135,8 +118,8 @@ type MockWALAccesser_Read_Call struct { } // Read is a helper method to define mock.On call -// - ctx context.Context -// - opts streaming.ReadOption +// - ctx context.Context +// - opts streaming.ReadOption func (_e *MockWALAccesser_Expecter) Read(ctx interface{}, opts interface{}) *MockWALAccesser_Read_Call { 
return &MockWALAccesser_Read_Call{Call: _e.mock.On("Read", ctx, opts)} } @@ -158,6 +141,104 @@ func (_c *MockWALAccesser_Read_Call) RunAndReturn(run func(context.Context, stre return _c } +// Txn provides a mock function with given fields: ctx, opts +func (_m *MockWALAccesser) Txn(ctx context.Context, opts streaming.TxnOption) (streaming.Txn, error) { + ret := _m.Called(ctx, opts) + + var r0 streaming.Txn + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, streaming.TxnOption) (streaming.Txn, error)); ok { + return rf(ctx, opts) + } + if rf, ok := ret.Get(0).(func(context.Context, streaming.TxnOption) streaming.Txn); ok { + r0 = rf(ctx, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(streaming.Txn) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, streaming.TxnOption) error); ok { + r1 = rf(ctx, opts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockWALAccesser_Txn_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Txn' +type MockWALAccesser_Txn_Call struct { + *mock.Call +} + +// Txn is a helper method to define mock.On call +// - ctx context.Context +// - opts streaming.TxnOption +func (_e *MockWALAccesser_Expecter) Txn(ctx interface{}, opts interface{}) *MockWALAccesser_Txn_Call { + return &MockWALAccesser_Txn_Call{Call: _e.mock.On("Txn", ctx, opts)} +} + +func (_c *MockWALAccesser_Txn_Call) Run(run func(ctx context.Context, opts streaming.TxnOption)) *MockWALAccesser_Txn_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(streaming.TxnOption)) + }) + return _c +} + +func (_c *MockWALAccesser_Txn_Call) Return(_a0 streaming.Txn, _a1 error) *MockWALAccesser_Txn_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockWALAccesser_Txn_Call) RunAndReturn(run func(context.Context, streaming.TxnOption) (streaming.Txn, error)) *MockWALAccesser_Txn_Call { + _c.Call.Return(run) + return _c +} + +// Utility provides a mock function with given fields: +func (_m *MockWALAccesser) Utility() streaming.Utility { + ret := _m.Called() + + var r0 streaming.Utility + if rf, ok := ret.Get(0).(func() streaming.Utility); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(streaming.Utility) + } + } + + return r0 +} + +// MockWALAccesser_Utility_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Utility' +type MockWALAccesser_Utility_Call struct { + *mock.Call +} + +// Utility is a helper method to define mock.On call +func (_e *MockWALAccesser_Expecter) Utility() *MockWALAccesser_Utility_Call { + return &MockWALAccesser_Utility_Call{Call: _e.mock.On("Utility")} +} + +func (_c *MockWALAccesser_Utility_Call) Run(run func()) *MockWALAccesser_Utility_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockWALAccesser_Utility_Call) Return(_a0 streaming.Utility) *MockWALAccesser_Utility_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockWALAccesser_Utility_Call) RunAndReturn(run func() streaming.Utility) *MockWALAccesser_Utility_Call { + _c.Call.Return(run) + return _c +} + // NewMockWALAccesser creates a new instance of MockWALAccesser. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
func NewMockWALAccesser(t interface { diff --git a/internal/mocks/streamingcoord/mock_client/mock_Client.go b/internal/mocks/streamingcoord/mock_client/mock_Client.go new file mode 100644 index 0000000000..719b05c716 --- /dev/null +++ b/internal/mocks/streamingcoord/mock_client/mock_Client.go @@ -0,0 +1,110 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_client + +import ( + client "github.com/milvus-io/milvus/internal/streamingcoord/client" + mock "github.com/stretchr/testify/mock" +) + +// MockClient is an autogenerated mock type for the Client type +type MockClient struct { + mock.Mock +} + +type MockClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockClient) EXPECT() *MockClient_Expecter { + return &MockClient_Expecter{mock: &_m.Mock} +} + +// Assignment provides a mock function with given fields: +func (_m *MockClient) Assignment() client.AssignmentService { + ret := _m.Called() + + var r0 client.AssignmentService + if rf, ok := ret.Get(0).(func() client.AssignmentService); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.AssignmentService) + } + } + + return r0 +} + +// MockClient_Assignment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Assignment' +type MockClient_Assignment_Call struct { + *mock.Call +} + +// Assignment is a helper method to define mock.On call +func (_e *MockClient_Expecter) Assignment() *MockClient_Assignment_Call { + return &MockClient_Assignment_Call{Call: _e.mock.On("Assignment")} +} + +func (_c *MockClient_Assignment_Call) Run(run func()) *MockClient_Assignment_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Assignment_Call) Return(_a0 client.AssignmentService) *MockClient_Assignment_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockClient_Assignment_Call) RunAndReturn(run func() client.AssignmentService) *MockClient_Assignment_Call { + _c.Call.Return(run) + return _c +} + +// Close provides a mock function with given fields: +func (_m *MockClient) Close() { + _m.Called() +} + +// MockClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockClient_Expecter) Close() *MockClient_Close_Call { + return &MockClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockClient_Close_Call) Run(run func()) *MockClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockClient_Close_Call) Return() *MockClient_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockClient_Close_Call) RunAndReturn(run func()) *MockClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// NewMockClient creates a new instance of MockClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockClient { + mock := &MockClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/mocks/streamingnode/client/mock_handler/mock_HandlerClient.go b/internal/mocks/streamingnode/client/mock_handler/mock_HandlerClient.go new file mode 100644 index 0000000000..3cf3fb40b9 --- /dev/null +++ b/internal/mocks/streamingnode/client/mock_handler/mock_HandlerClient.go @@ -0,0 +1,184 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mock_handler + +import ( + context "context" + + consumer "github.com/milvus-io/milvus/internal/streamingnode/client/handler/consumer" + + handler "github.com/milvus-io/milvus/internal/streamingnode/client/handler" + + mock "github.com/stretchr/testify/mock" + + producer "github.com/milvus-io/milvus/internal/streamingnode/client/handler/producer" +) + +// MockHandlerClient is an autogenerated mock type for the HandlerClient type +type MockHandlerClient struct { + mock.Mock +} + +type MockHandlerClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockHandlerClient) EXPECT() *MockHandlerClient_Expecter { + return &MockHandlerClient_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *MockHandlerClient) Close() { + _m.Called() +} + +// MockHandlerClient_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockHandlerClient_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *MockHandlerClient_Expecter) Close() *MockHandlerClient_Close_Call { + return &MockHandlerClient_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *MockHandlerClient_Close_Call) Run(run func()) *MockHandlerClient_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockHandlerClient_Close_Call) Return() *MockHandlerClient_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockHandlerClient_Close_Call) RunAndReturn(run func()) *MockHandlerClient_Close_Call { + _c.Call.Return(run) + return _c +} + +// CreateConsumer provides a mock function with given fields: ctx, opts +func (_m *MockHandlerClient) CreateConsumer(ctx context.Context, opts *handler.ConsumerOptions) (consumer.Consumer, error) { + ret := _m.Called(ctx, opts) + + var r0 consumer.Consumer + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *handler.ConsumerOptions) (consumer.Consumer, error)); ok { + return rf(ctx, opts) + } + if rf, ok := ret.Get(0).(func(context.Context, *handler.ConsumerOptions) consumer.Consumer); ok { + r0 = rf(ctx, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(consumer.Consumer) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *handler.ConsumerOptions) error); ok { + r1 = rf(ctx, opts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockHandlerClient_CreateConsumer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateConsumer' +type MockHandlerClient_CreateConsumer_Call struct { + *mock.Call +} + +// CreateConsumer is a helper method to define mock.On call +// - ctx context.Context +// - opts *handler.ConsumerOptions +func (_e *MockHandlerClient_Expecter) CreateConsumer(ctx interface{}, opts interface{}) *MockHandlerClient_CreateConsumer_Call { + return &MockHandlerClient_CreateConsumer_Call{Call: _e.mock.On("CreateConsumer", ctx, opts)} +} + +func (_c 
*MockHandlerClient_CreateConsumer_Call) Run(run func(ctx context.Context, opts *handler.ConsumerOptions)) *MockHandlerClient_CreateConsumer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*handler.ConsumerOptions)) + }) + return _c +} + +func (_c *MockHandlerClient_CreateConsumer_Call) Return(_a0 consumer.Consumer, _a1 error) *MockHandlerClient_CreateConsumer_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockHandlerClient_CreateConsumer_Call) RunAndReturn(run func(context.Context, *handler.ConsumerOptions) (consumer.Consumer, error)) *MockHandlerClient_CreateConsumer_Call { + _c.Call.Return(run) + return _c +} + +// CreateProducer provides a mock function with given fields: ctx, opts +func (_m *MockHandlerClient) CreateProducer(ctx context.Context, opts *handler.ProducerOptions) (producer.Producer, error) { + ret := _m.Called(ctx, opts) + + var r0 producer.Producer + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *handler.ProducerOptions) (producer.Producer, error)); ok { + return rf(ctx, opts) + } + if rf, ok := ret.Get(0).(func(context.Context, *handler.ProducerOptions) producer.Producer); ok { + r0 = rf(ctx, opts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(producer.Producer) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *handler.ProducerOptions) error); ok { + r1 = rf(ctx, opts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockHandlerClient_CreateProducer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateProducer' +type MockHandlerClient_CreateProducer_Call struct { + *mock.Call +} + +// CreateProducer is a helper method to define mock.On call +// - ctx context.Context +// - opts *handler.ProducerOptions +func (_e *MockHandlerClient_Expecter) CreateProducer(ctx interface{}, opts interface{}) *MockHandlerClient_CreateProducer_Call { + return &MockHandlerClient_CreateProducer_Call{Call: _e.mock.On("CreateProducer", ctx, opts)} +} + +func (_c *MockHandlerClient_CreateProducer_Call) Run(run func(ctx context.Context, opts *handler.ProducerOptions)) *MockHandlerClient_CreateProducer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*handler.ProducerOptions)) + }) + return _c +} + +func (_c *MockHandlerClient_CreateProducer_Call) Return(_a0 producer.Producer, _a1 error) *MockHandlerClient_CreateProducer_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockHandlerClient_CreateProducer_Call) RunAndReturn(run func(context.Context, *handler.ProducerOptions) (producer.Producer, error)) *MockHandlerClient_CreateProducer_Call { + _c.Call.Return(run) + return _c +} + +// NewMockHandlerClient creates a new instance of MockHandlerClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockHandlerClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockHandlerClient { + mock := &MockHandlerClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/rootcoord/create_partition_task.go b/internal/rootcoord/create_partition_task.go index 6d7ceaa73d..76a5c7a718 100644 --- a/internal/rootcoord/create_partition_task.go +++ b/internal/rootcoord/create_partition_task.go @@ -102,6 +102,7 @@ func (t *createPartitionTask) Execute(ctx context.Context) error { baseStep: baseStep{core: t.core}, vchannels: t.collMeta.VirtualChannelNames, partition: partition, + ts: t.GetTs(), }, &nullStep{}) } diff --git a/internal/rootcoord/garbage_collector.go b/internal/rootcoord/garbage_collector.go index e637d4b29a..0c36d59ee0 100644 --- a/internal/rootcoord/garbage_collector.go +++ b/internal/rootcoord/garbage_collector.go @@ -232,10 +232,14 @@ func (c *bgGarbageCollector) notifyPartitionGc(ctx context.Context, pChannels [] } func (c *bgGarbageCollector) notifyPartitionGcByStreamingService(ctx context.Context, vchannels []string, partition *model.Partition) (uint64, error) { + ts, err := c.s.tsoAllocator.GenerateTSO(1) + if err != nil { + return 0, err + } req := &msgpb.DropPartitionRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_DropPartition), - commonpbutil.WithTimeStamp(0), // ts is given by streamingnode. + commonpbutil.WithTimeStamp(ts), commonpbutil.WithSourceID(c.s.session.ServerID), ), PartitionName: partition.PartitionName, @@ -258,12 +262,13 @@ func (c *bgGarbageCollector) notifyPartitionGcByStreamingService(ctx context.Con } msgs = append(msgs, msg) } - resp := streaming.WAL().Append(ctx, msgs...) - if err := resp.UnwrapFirstError(); err != nil { + // ts is used as the barrier time tick to ensure that the time ticks assigned to these messages come after the barrier time tick.
+ if err := streaming.WAL().Utility().AppendMessagesWithOption(ctx, streaming.AppendOption{ + BarrierTimeTick: ts, + }, msgs...).UnwrapFirstError(); err != nil { return 0, err } - // TODO: sheep, return resp.MaxTimeTick(), nil - return c.s.tsoAllocator.GenerateTSO(1) + return ts, nil } func (c *bgGarbageCollector) GcCollectionData(ctx context.Context, coll *model.Collection) (ddlTs Timestamp, err error) { diff --git a/internal/rootcoord/garbage_collector_test.go b/internal/rootcoord/garbage_collector_test.go index d63a066fb7..43b9d34741 100644 --- a/internal/rootcoord/garbage_collector_test.go +++ b/internal/rootcoord/garbage_collector_test.go @@ -547,7 +547,9 @@ func TestGcPartitionData(t *testing.T) { defer streamingutil.UnsetStreamingServiceEnabled() wal := mock_streaming.NewMockWALAccesser(t) - wal.EXPECT().Append(mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) + u := mock_streaming.NewMockUtility(t) + u.EXPECT().AppendMessagesWithOption(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) + wal.EXPECT().Utility().Return(u) streaming.SetWALForTest(wal) tsoAllocator := mocktso.NewAllocator(t) diff --git a/internal/rootcoord/step.go b/internal/rootcoord/step.go index 657e8d66e6..42c47dbcaa 100644 --- a/internal/rootcoord/step.go +++ b/internal/rootcoord/step.go @@ -383,6 +383,7 @@ type broadcastCreatePartitionMsgStep struct { baseStep vchannels []string partition *model.Partition + ts Timestamp } func (s *broadcastCreatePartitionMsgStep) Execute(ctx context.Context) ([]nestedStep, error) { @@ -411,8 +412,9 @@ func (s *broadcastCreatePartitionMsgStep) Execute(ctx context.Context) ([]nested } msgs = append(msgs, msg) } - resp := streaming.WAL().Append(ctx, msgs...) - if err := resp.UnwrapFirstError(); err != nil { + if err := streaming.WAL().Utility().AppendMessagesWithOption(ctx, streaming.AppendOption{ + BarrierTimeTick: s.ts, + }, msgs...).UnwrapFirstError(); err != nil { return nil, err } return nil, nil diff --git a/internal/rootcoord/step_test.go b/internal/rootcoord/step_test.go index c59c50b1d3..958a946b1f 100644 --- a/internal/rootcoord/step_test.go +++ b/internal/rootcoord/step_test.go @@ -123,7 +123,9 @@ func TestSkip(t *testing.T) { func TestBroadcastCreatePartitionMsgStep(t *testing.T) { wal := mock_streaming.NewMockWALAccesser(t) - wal.EXPECT().Append(mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) + u := mock_streaming.NewMockUtility(t) + u.EXPECT().AppendMessagesWithOption(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(streaming.AppendResponses{}) + wal.EXPECT().Utility().Return(u) streaming.SetWALForTest(wal) step := &broadcastCreatePartitionMsgStep{ diff --git a/internal/streamingnode/client/handler/consumer/consumer_impl.go b/internal/streamingnode/client/handler/consumer/consumer_impl.go index 17abf846ec..a879c39a29 100644 --- a/internal/streamingnode/client/handler/consumer/consumer_impl.go +++ b/internal/streamingnode/client/handler/consumer/consumer_impl.go @@ -97,6 +97,7 @@ type consumerImpl struct { logger *log.MLogger msgHandler message.Handler finishErr *syncutil.Future[error] + txnBuilder *message.ImmutableTxnMessageBuilder } // Close close the consumer client. @@ -132,6 +133,16 @@ func (c *consumerImpl) execute() { // recvLoop is the recv arm of the grpc stream. // Throughput of the grpc framework should be ok to use single stream to receive message. 
// Once throughput is not enough, look at https://grpc.io/docs/guides/performance/ to find the solution. +// recvLoop always receives messages from the server in the following sequence: +// - message at timetick 4. +// - message at timetick 5. +// - txn begin message at timetick 1. +// - txn body message at timetick 2. +// - txn body message at timetick 3. +// - txn commit message at timetick 6. +// - message at timetick 7. +// - Close. +// - EOF. func (c *consumerImpl) recvLoop() (err error) { defer func() { if err != nil { @@ -157,11 +168,19 @@ func (c *consumerImpl) recvLoop() (err error) { if err != nil { return err } - c.msgHandler.Handle(message.NewImmutableMesasge( + newImmutableMsg := message.NewImmutableMesasge( msgID, resp.Consume.GetMessage().GetPayload(), resp.Consume.GetMessage().GetProperties(), - )) + ) + if newImmutableMsg.TxnContext() != nil { + c.handleTxnMessage(newImmutableMsg) + } else { + if c.txnBuilder != nil { + panic("unreachable code: txn builder should be nil if we receive a non-txn message") + } + c.msgHandler.Handle(newImmutableMsg) + } case *streamingpb.ConsumeResponse_Close: // Should receive io.EOF after that. // Do nothing at current implementation. @@ -170,3 +189,40 @@ func (c *consumerImpl) recvLoop() (err error) { } } } + +func (c *consumerImpl) handleTxnMessage(msg message.ImmutableMessage) { + switch msg.MessageType() { + case message.MessageTypeBeginTxn: + if c.txnBuilder != nil { + panic("unreachable code: txn builder should be nil if we receive a begin txn message") + } + beginMsg, err := message.AsImmutableBeginTxnMessageV2(msg) + if err != nil { + c.logger.Warn("failed to convert message to begin txn message", zap.Any("messageID", msg.MessageID()), zap.Error(err)) + return + } + c.txnBuilder = message.NewImmutableTxnMessageBuilder(beginMsg) + case message.MessageTypeCommitTxn: + if c.txnBuilder == nil { + panic("unreachable code: txn builder should not be nil if we receive a commit txn message") + } + commitMsg, err := message.AsImmutableCommitTxnMessageV2(msg) + if err != nil { + c.logger.Warn("failed to convert message to commit txn message", zap.Any("messageID", msg.MessageID()), zap.Error(err)) + c.txnBuilder = nil + return + } + msg, err := c.txnBuilder.Build(commitMsg) + c.txnBuilder = nil + if err != nil { + c.logger.Warn("failed to build txn message", zap.Any("messageID", commitMsg.MessageID()), zap.Error(err)) + return + } + c.msgHandler.Handle(msg) + default: + if c.txnBuilder == nil { + panic("unreachable code: txn builder should not be nil if we receive a non-begin txn message") + } + c.txnBuilder.Add(msg) + } +} diff --git a/internal/streamingnode/client/handler/consumer/consumer_test.go b/internal/streamingnode/client/handler/consumer/consumer_test.go index 0011aa853f..656e92e234 100644 --- a/internal/streamingnode/client/handler/consumer/consumer_test.go +++ b/internal/streamingnode/client/handler/consumer/consumer_test.go @@ -4,10 +4,12 @@ import ( "context" "io" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/pkg/mocks/streaming/proto/mock_streamingpb" "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" @@ -15,6 +17,7 @@ import ( "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + 
"github.com/milvus-io/milvus/pkg/util/tsoutil" ) func TestConsumer(t *testing.T) { @@ -58,29 +61,73 @@ func TestConsumer(t *testing.T) { recvCh <- &streamingpb.ConsumeResponse{ Response: &streamingpb.ConsumeResponse_Create{ Create: &streamingpb.CreateConsumerResponse{ - WalName: "test", + WalName: walimplstest.WALName, }, }, } - recvCh <- &streamingpb.ConsumeResponse{ + + mmsg, _ := message.NewInsertMessageBuilderV1(). + WithHeader(&message.InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{}). + WithVChannel("test-1"). + BuildMutable() + recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(1), mmsg) + + consumer, err := CreateConsumer(ctx, opts, c) + assert.NoError(t, err) + assert.NotNil(t, consumer) + msg := <-resultCh + assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(1))) + + txnCtx := message.TxnContext{ + TxnID: 1, + Keepalive: time.Second, + } + mmsg, _ = message.NewBeginTxnMessageBuilderV2(). + WithVChannel("test-1"). + WithHeader(&message.BeginTxnMessageHeader{}). + WithBody(&message.BeginTxnMessageBody{}). + BuildMutable() + recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(2), mmsg.WithTxnContext(txnCtx)) + + mmsg, _ = message.NewInsertMessageBuilderV1(). + WithVChannel("test-1"). + WithHeader(&message.InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{}). + BuildMutable() + recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(3), mmsg.WithTxnContext(txnCtx)) + + mmsg, _ = message.NewCommitTxnMessageBuilderV2(). + WithVChannel("test-1"). + WithHeader(&message.CommitTxnMessageHeader{}). + WithBody(&message.CommitTxnMessageBody{}). + BuildMutable() + recvCh <- newConsumeResponse(walimplstest.NewTestMessageID(4), mmsg.WithTxnContext(txnCtx)) + + msg = <-resultCh + assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(4))) + assert.Equal(t, msg.TxnContext().TxnID, txnCtx.TxnID) + assert.Equal(t, message.MessageTypeTxn, msg.MessageType()) + + consumer.Close() + <-consumer.Done() + assert.NoError(t, consumer.Error()) +} + +func newConsumeResponse(id message.MessageID, msg message.MutableMessage) *streamingpb.ConsumeResponse { + msg.WithTimeTick(tsoutil.GetCurrentTime()) + msg.WithLastConfirmed(walimplstest.NewTestMessageID(0)) + return &streamingpb.ConsumeResponse{ Response: &streamingpb.ConsumeResponse_Consume{ Consume: &streamingpb.ConsumeMessageReponse{ Message: &messagespb.ImmutableMessage{ Id: &messagespb.MessageID{ - Id: walimplstest.NewTestMessageID(1).Marshal(), + Id: id.Marshal(), }, - Payload: []byte{}, - Properties: make(map[string]string), + Payload: msg.Payload(), + Properties: msg.Properties().ToRawMap(), }, }, }, } - consumer, err := CreateConsumer(ctx, opts, c) - assert.NoError(t, err) - assert.NotNil(t, consumer) - consumer.Close() - msg := <-resultCh - assert.True(t, msg.MessageID().EQ(walimplstest.NewTestMessageID(1))) - <-consumer.Done() - assert.NoError(t, consumer.Error()) } diff --git a/internal/streamingnode/client/handler/producer/producer.go b/internal/streamingnode/client/handler/producer/producer.go index 8190f4110c..41dec673d9 100644 --- a/internal/streamingnode/client/handler/producer/producer.go +++ b/internal/streamingnode/client/handler/producer/producer.go @@ -19,6 +19,7 @@ type Producer interface { Assignment() types.PChannelInfoAssigned // Produce sends the produce message to server. + // TODO: Support Batch produce here. Produce(ctx context.Context, msg message.MutableMessage) (*ProduceResult, error) // Check if a producer is available. 
diff --git a/internal/streamingnode/client/handler/producer/producer_impl.go b/internal/streamingnode/client/handler/producer/producer_impl.go index 9e690721b9..1317eaccc6 100644 --- a/internal/streamingnode/client/handler/producer/producer_impl.go +++ b/internal/streamingnode/client/handler/producer/producer_impl.go @@ -297,6 +297,8 @@ func (p *producerImpl) recvLoop() (err error) { result: &ProduceResult{ MessageID: msgID, TimeTick: produceResp.Result.GetTimetick(), + TxnCtx: message.NewTxnContextFromProto(produceResp.Result.GetTxnContext()), + Extra: produceResp.Result.GetExtra(), }, } case *streamingpb.ProduceMessageResponse_Error: diff --git a/internal/streamingnode/client/handler/producer/producer_test.go b/internal/streamingnode/client/handler/producer/producer_test.go index 7c9b0f647f..bea7eda13d 100644 --- a/internal/streamingnode/client/handler/producer/producer_test.go +++ b/internal/streamingnode/client/handler/producer/producer_test.go @@ -51,7 +51,7 @@ func TestProducer(t *testing.T) { recvCh <- &streamingpb.ProduceResponse{ Response: &streamingpb.ProduceResponse_Create{ Create: &streamingpb.CreateProducerResponse{ - WalName: "test", + WalName: walimplstest.WALName, }, }, } diff --git a/internal/streamingnode/server/builder.go b/internal/streamingnode/server/builder.go index f235fe4703..ee4b262f94 100644 --- a/internal/streamingnode/server/builder.go +++ b/internal/streamingnode/server/builder.go @@ -5,6 +5,8 @@ import ( "google.golang.org/grpc" "github.com/milvus-io/milvus/internal/metastore/kv/streamingnode" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/streamingnode/server/flusher/flusherimpl" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/componentutil" @@ -16,12 +18,13 @@ import ( // ServerBuilder is used to build a server. // All components that should be initialized before server initialization should be added here. type ServerBuilder struct { - etcdClient *clientv3.Client - grpcServer *grpc.Server - rc types.RootCoordClient - dc types.DataCoordClient - session *sessionutil.Session - kv kv.MetaKv + etcdClient *clientv3.Client + grpcServer *grpc.Server + rc types.RootCoordClient + dc types.DataCoordClient + session *sessionutil.Session + kv kv.MetaKv + chunkManager storage.ChunkManager } // NewServerBuilder creates a new server builder. @@ -65,14 +68,24 @@ func (b *ServerBuilder) WithMetaKV(kv kv.MetaKv) *ServerBuilder { return b } +// WithChunkManager sets the chunk manager of the server builder. +func (b *ServerBuilder) WithChunkManager(chunkManager storage.ChunkManager) *ServerBuilder { + b.chunkManager = chunkManager + return b +} + // Build builds a streaming node server.
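A hedged sketch of the resulting builder wiring (only the setters visible in this diff are shown; metaKV and cm are placeholder arguments, and the omitted dependencies such as etcd, the gRPC server, the coordinator clients, and the session are set through the sibling With* setters not repeated here):

func buildStreamingNodeServer(metaKV kv.MetaKv, cm storage.ChunkManager) *Server {
	return NewServerBuilder().
		WithMetaKV(metaKV).
		// WithChunkManager is now required so that Build can construct the
		// flusher from the chunk manager instead of a shared resource.
		WithChunkManager(cm).
		Build()
}

The Build method that follows assembles the server from these dependencies.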
func (s *ServerBuilder) Build() *Server { - resource.Init( + resource.Apply( resource.OptETCD(s.etcdClient), resource.OptRootCoordClient(s.rc), resource.OptDataCoordClient(s.dc), resource.OptStreamingNodeCatalog(streamingnode.NewCataLog(s.kv)), ) + resource.Apply( + resource.OptFlusher(flusherimpl.NewFlusher(s.chunkManager)), + ) + resource.Done() return &Server{ session: s.session, grpcServer: s.grpcServer, diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go index e7b7b99fbe..432849273d 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl.go @@ -29,12 +29,12 @@ import ( "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" adaptor2 "github.com/milvus-io/milvus/internal/streamingnode/server/wal/adaptor" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/message/adaptor" "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -55,22 +55,28 @@ type flusherImpl struct { tasks *typeutil.ConcurrentMap[string, wal.WAL] // unwatched vchannels scanners *typeutil.ConcurrentMap[string, wal.Scanner] // watched scanners - stopOnce sync.Once - stopChan chan struct{} + stopOnce sync.Once + stopChan chan struct{} + pipelineParams *util.PipelineParams } -func NewFlusher() flusher.Flusher { - params := GetPipelineParams() +func NewFlusher(chunkManager storage.ChunkManager) flusher.Flusher { + params := getPipelineParams(chunkManager) + return newFlusherWithParam(params) +} + +func newFlusherWithParam(params *util.PipelineParams) flusher.Flusher { fgMgr := pipeline.NewFlowgraphManager() return &flusherImpl{ - fgMgr: fgMgr, - syncMgr: params.SyncMgr, - wbMgr: params.WriteBufferManager, - cpUpdater: params.CheckpointUpdater, - tasks: typeutil.NewConcurrentMap[string, wal.WAL](), - scanners: typeutil.NewConcurrentMap[string, wal.Scanner](), - stopOnce: sync.Once{}, - stopChan: make(chan struct{}), + fgMgr: fgMgr, + syncMgr: params.SyncMgr, + wbMgr: params.WriteBufferManager, + cpUpdater: params.CheckpointUpdater, + tasks: typeutil.NewConcurrentMap[string, wal.WAL](), + scanners: typeutil.NewConcurrentMap[string, wal.Scanner](), + stopOnce: sync.Once{}, + stopChan: make(chan struct{}), + pipelineParams: params, } } @@ -181,11 +187,12 @@ func (f *flusherImpl) buildPipeline(vchannel string, w wal.WAL) error { // Create scanner. policy := options.DeliverPolicyStartFrom(messageID) - filter := func(msg message.ImmutableMessage) bool { return msg.VChannel() == vchannel } handler := adaptor2.NewMsgPackAdaptorHandler() ro := wal.ReadOption{ - DeliverPolicy: policy, - MessageFilter: filter, + DeliverPolicy: policy, + MessageFilter: []options.DeliverFilter{ + options.DeliverFilterVChannel(vchannel), + }, MesasgeHandler: handler, } scanner, err := w.Read(ctx, ro) @@ -194,7 +201,7 @@ func (f *flusherImpl) buildPipeline(vchannel string, w wal.WAL) error { } // Build and add pipeline. 
- ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, GetPipelineParams(), + ds, err := pipeline.NewStreamingNodeDataSyncService(ctx, f.pipelineParams, &datapb.ChannelWatchInfo{Vchan: resp.GetInfo(), Schema: resp.GetSchema()}, handler.Chan()) if err != nil { return err diff --git a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go index aa0f11aa4b..c4d6b18715 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go +++ b/internal/streamingnode/server/flusher/flusherimpl/flusher_impl_test.go @@ -48,6 +48,7 @@ type FlusherSuite struct { pchannel string vchannels []string + syncMgr *syncmgr.MockSyncManager wbMgr *writebuffer.MockBufferManager rootcoord *mocks.MockRootCoordClient @@ -89,22 +90,18 @@ func (s *FlusherSuite) SetupSuite() { }, nil }) - syncMgr := syncmgr.NewMockSyncManager(s.T()) - wbMgr := writebuffer.NewMockBufferManager(s.T()) - wbMgr.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - wbMgr.EXPECT().RemoveChannel(mock.Anything).Return() - wbMgr.EXPECT().Start().Return() - wbMgr.EXPECT().Stop().Return() - resource.InitForTest( s.T(), - resource.OptSyncManager(syncMgr), - resource.OptBufferManager(wbMgr), resource.OptRootCoordClient(rootcoord), resource.OptDataCoordClient(datacoord), ) - s.wbMgr = wbMgr + s.syncMgr = syncmgr.NewMockSyncManager(s.T()) + s.wbMgr = writebuffer.NewMockBufferManager(s.T()) + s.wbMgr.EXPECT().Register(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + s.wbMgr.EXPECT().RemoveChannel(mock.Anything).Return() + s.wbMgr.EXPECT().Start().Return() + s.wbMgr.EXPECT().Stop().Return() s.rootcoord = rootcoord } @@ -131,7 +128,12 @@ func (s *FlusherSuite) SetupTest() { }) s.wal = w - s.flusher = NewFlusher() + m := mocks.NewChunkManager(s.T()) + params := getPipelineParams(m) + params.SyncMgr = s.syncMgr + params.WriteBufferManager = s.wbMgr + + s.flusher = newFlusherWithParam(params) s.flusher.Start() } diff --git a/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go b/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go index 2427d35659..086f924efc 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go +++ b/internal/streamingnode/server/flusher/flusherimpl/pipeline_params.go @@ -18,43 +18,34 @@ package flusherimpl import ( "context" - "sync" "github.com/milvus-io/milvus/internal/flushcommon/broker" + "github.com/milvus-io/milvus/internal/flushcommon/syncmgr" "github.com/milvus-io/milvus/internal/flushcommon/util" + "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" + "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc" "github.com/milvus-io/milvus/pkg/util/paramtable" ) -var ( - pipelineParams *util.PipelineParams - initOnce sync.Once -) - -func initPipelineParams() { - initOnce.Do(func() { - var ( - rsc = resource.Resource() - syncMgr = rsc.SyncManager() - wbMgr = rsc.BufferManager() - coordBroker = broker.NewCoordBroker(rsc.DataCoordClient(), paramtable.GetNodeID()) - cpUpdater = util.NewChannelCheckpointUpdater(coordBroker) - ) - pipelineParams = &util.PipelineParams{ - Ctx: context.Background(), - Broker: coordBroker, - SyncMgr: syncMgr, - ChunkManager: rsc.ChunkManager(), - WriteBufferManager: wbMgr, - CheckpointUpdater: cpUpdater, 
- Allocator: idalloc.NewMAllocator(rsc.IDAllocator()), - FlushMsgHandler: flushMsgHandlerImpl(wbMgr), - } - }) -} - -func GetPipelineParams() *util.PipelineParams { - initPipelineParams() - return pipelineParams +// getPipelineParams initializes the pipeline parameters. +func getPipelineParams(chunkManager storage.ChunkManager) *util.PipelineParams { + var ( + rsc = resource.Resource() + syncMgr = syncmgr.NewSyncManager(chunkManager) + wbMgr = writebuffer.NewManager(syncMgr) + coordBroker = broker.NewCoordBroker(rsc.DataCoordClient(), paramtable.GetNodeID()) + cpUpdater = util.NewChannelCheckpointUpdater(coordBroker) + ) + return &util.PipelineParams{ + Ctx: context.Background(), + Broker: coordBroker, + SyncMgr: syncMgr, + ChunkManager: chunkManager, + WriteBufferManager: wbMgr, + CheckpointUpdater: cpUpdater, + Allocator: idalloc.NewMAllocator(rsc.IDAllocator()), + FlushMsgHandler: flushMsgHandlerImpl(wbMgr), + } } diff --git a/internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go b/internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go index 72d75d9db7..2d3ed8bec6 100644 --- a/internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go +++ b/internal/streamingnode/server/resource/idalloc/test_mock_root_coord_client.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/stretchr/testify/mock" "go.uber.org/atomic" @@ -15,20 +16,32 @@ import ( "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) func NewMockRootCoordClient(t *testing.T) *mocks.MockRootCoordClient { counter := atomic.NewUint64(1) client := mocks.NewMockRootCoordClient(t) + lastAllocate := atomic.NewInt64(0) client.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn( func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { if atr.Count > 1000 { panic(fmt.Sprintf("count %d is too large", atr.Count)) } - c := counter.Add(uint64(atr.Count)) + now := time.Now() + for { + lastAllocateMilli := lastAllocate.Load() + if now.UnixMilli() <= lastAllocateMilli { + now = time.Now() + continue + } + if lastAllocate.CompareAndSwap(lastAllocateMilli, now.UnixMilli()) { + break + } + } return &rootcoordpb.AllocTimestampResponse{ Status: merr.Success(), - Timestamp: c - uint64(atr.Count), + Timestamp: tsoutil.ComposeTSByTime(now, 0), Count: atr.Count, }, nil }, diff --git a/internal/streamingnode/server/resource/resource.go b/internal/streamingnode/server/resource/resource.go index f53a9a65cf..e86d490a22 100644 --- a/internal/streamingnode/server/resource/resource.go +++ b/internal/streamingnode/server/resource/resource.go @@ -5,8 +5,6 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" - "github.com/milvus-io/milvus/internal/flushcommon/syncmgr" - "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/streamingnode/server/flusher" @@ -17,7 +15,7 @@ import ( "github.com/milvus-io/milvus/internal/types" ) -var r *resourceImpl // singleton resource instance +var r = &resourceImpl{} // singleton resource instance // optResourceInit is the option to initialize the resource. 
type optResourceInit func(r *resourceImpl) @@ -29,20 +27,6 @@ func OptFlusher(flusher flusher.Flusher) optResourceInit { } } -// OptSyncManager provides the sync manager to the resource. -func OptSyncManager(syncMgr syncmgr.SyncManager) optResourceInit { - return func(r *resourceImpl) { - r.syncMgr = syncMgr - } -} - -// OptBufferManager provides the write buffer manager to the resource. -func OptBufferManager(wbMgr writebuffer.BufferManager) optResourceInit { - return func(r *resourceImpl) { - r.wbMgr = wbMgr - } -} - // OptETCD provides the etcd client to the resource. func OptETCD(etcd *clientv3.Client) optResourceInit { return func(r *resourceImpl) { @@ -61,6 +45,8 @@ func OptChunkManager(chunkManager storage.ChunkManager) optResourceInit { func OptRootCoordClient(rootCoordClient types.RootCoordClient) optResourceInit { return func(r *resourceImpl) { r.rootCoordClient = rootCoordClient + r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) + r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) } } @@ -78,19 +64,19 @@ func OptStreamingNodeCatalog(catalog metastore.StreamingNodeCataLog) optResource } } -// Init initializes the singleton of resources. +// Apply initializes the singleton of resources. // Should be called at streaming node startup. -func Init(opts ...optResourceInit) { - r = &resourceImpl{} +func Apply(opts ...optResourceInit) { for _, opt := range opts { opt(r) } - r.timestampAllocator = idalloc.NewTSOAllocator(r.rootCoordClient) - r.idAllocator = idalloc.NewIDAllocator(r.rootCoordClient) +} + +// Done finishes all initialization of resources. +func Done() { r.segmentAssignStatsManager = stats.NewStatsManager() r.segmentSealedInspector = sinspector.NewSealedInspector(r.segmentAssignStatsManager.SealNotifier()) r.timeTickInspector = tinspector.NewTimeTickSyncInspector() - assertNotNil(r.TSOAllocator()) assertNotNil(r.RootCoordClient()) assertNotNil(r.DataCoordClient()) @@ -108,10 +94,7 @@ func Resource() *resourceImpl { // resourceImpl is a basic resource dependency for streamingnode server. // All utility on it is concurrent-safe and singleton. type resourceImpl struct { - flusher flusher.Flusher - syncMgr syncmgr.SyncManager - wbMgr writebuffer.BufferManager - + flusher flusher.Flusher timestampAllocator idalloc.Allocator idAllocator idalloc.Allocator etcdClient *clientv3.Client @@ -129,16 +112,6 @@ func (r *resourceImpl) Flusher() flusher.Flusher { return r.flusher } -// SyncManager returns the sync manager. -func (r *resourceImpl) SyncManager() syncmgr.SyncManager { - return r.syncMgr -} - -// BufferManager returns the write buffer manager. -func (r *resourceImpl) BufferManager() writebuffer.BufferManager { - return r.wbMgr -} - // TSOAllocator returns the timestamp allocator to allocate timestamp.
func (r *resourceImpl) TSOAllocator() idalloc.Allocator { return r.timestampAllocator diff --git a/internal/streamingnode/server/resource/resource_test.go b/internal/streamingnode/server/resource/resource_test.go index ad114ebc82..1d8d4f976f 100644 --- a/internal/streamingnode/server/resource/resource_test.go +++ b/internal/streamingnode/server/resource/resource_test.go @@ -11,24 +11,24 @@ import ( "github.com/milvus-io/milvus/pkg/util/paramtable" ) -func TestInit(t *testing.T) { +func TestApply(t *testing.T) { paramtable.Init() + Apply() + Apply(OptETCD(&clientv3.Client{})) + Apply(OptRootCoordClient(mocks.NewMockRootCoordClient(t))) + assert.Panics(t, func() { - Init() + Done() }) - assert.Panics(t, func() { - Init(OptETCD(&clientv3.Client{})) - }) - assert.Panics(t, func() { - Init(OptRootCoordClient(mocks.NewMockRootCoordClient(t))) - }) - Init( + + Apply( OptETCD(&clientv3.Client{}), OptRootCoordClient(mocks.NewMockRootCoordClient(t)), OptDataCoordClient(mocks.NewMockDataCoordClient(t)), OptStreamingNodeCatalog(mock_metastore.NewMockStreamingNodeCataLog(t)), ) + Done() assert.NotNil(t, Resource().TSOAllocator()) assert.NotNil(t, Resource().ETCD()) diff --git a/internal/streamingnode/server/service/handler/consumer/consume_server.go b/internal/streamingnode/server/service/handler/consumer/consume_server.go index 19c4d577e5..29eac7ab61 100644 --- a/internal/streamingnode/server/service/handler/consumer/consume_server.go +++ b/internal/streamingnode/server/service/handler/consumer/consume_server.go @@ -15,7 +15,7 @@ import ( "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/streaming/proto/messagespb" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" - "github.com/milvus-io/milvus/pkg/streaming/util/options" + "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -36,13 +36,9 @@ func CreateConsumeServer(walManager walmanager.Manager, streamServer streamingpb if err != nil { return nil, err } - filter, err := options.GetFilterFunc(createReq.DeliverFilters) - if err != nil { - return nil, err - } scanner, err := l.Read(streamServer.Context(), wal.ReadOption{ DeliverPolicy: createReq.GetDeliverPolicy(), - MessageFilter: filter, + MessageFilter: createReq.DeliverFilters, }) if err != nil { return nil, err @@ -110,24 +106,28 @@ func (c *ConsumeServer) sendLoop() (err error) { if !ok { return status.NewInner("scanner error: %s", c.scanner.Error()) } - // Send Consumed message to client and do metrics. - messageSize := msg.EstimateSize() - if err := c.consumeServer.SendConsumeMessage(&streamingpb.ConsumeMessageReponse{ - Message: &messagespb.ImmutableMessage{ - Id: &messagespb.MessageID{ - Id: msg.MessageID().Marshal(), - }, - Payload: msg.Payload(), - Properties: msg.Properties().ToRawMap(), - }, - }); err != nil { - return status.NewInner("send consume message failed: %s", err.Error()) + // If the message is a transaction message, we should send the sub messages one by one; + // otherwise we can send the full message directly.
+ if txnMsg, ok := msg.(message.ImmutableTxnMessage); ok { + if err := c.sendImmutableMessage(txnMsg.Begin()); err != nil { + return err + } + if err := txnMsg.RangeOver(func(im message.ImmutableMessage) error { + if err := c.sendImmutableMessage(im); err != nil { + return err + } + return nil + }); err != nil { + return err + } + if err := c.sendImmutableMessage(txnMsg.Commit()); err != nil { + return err + } + } else { + if err := c.sendImmutableMessage(msg); err != nil { + return err + } } - metrics.StreamingNodeConsumeBytes.WithLabelValues( - paramtable.GetStringNodeID(), - c.scanner.Channel().Name, - strconv.FormatInt(c.scanner.Channel().Term, 10), - ).Observe(float64(messageSize)) case <-c.closeCh: c.logger.Info("close channel notified") if err := c.consumeServer.SendClosed(); err != nil { @@ -141,6 +141,28 @@ func (c *ConsumeServer) sendLoop() (err error) { } } +func (c *ConsumeServer) sendImmutableMessage(msg message.ImmutableMessage) error { + // Send Consumed message to client and do metrics. + messageSize := msg.EstimateSize() + if err := c.consumeServer.SendConsumeMessage(&streamingpb.ConsumeMessageReponse{ + Message: &messagespb.ImmutableMessage{ + Id: &messagespb.MessageID{ + Id: msg.MessageID().Marshal(), + }, + Payload: msg.Payload(), + Properties: msg.Properties().ToRawMap(), + }, + }); err != nil { + return status.NewInner("send consume message failed: %s", err.Error()) + } + metrics.StreamingNodeConsumeBytes.WithLabelValues( + paramtable.GetStringNodeID(), + c.scanner.Channel().Name, + strconv.FormatInt(c.scanner.Channel().Term, 10), + ).Observe(float64(messageSize)) + return nil +} + // recvLoop receives messages from client. func (c *ConsumeServer) recvLoop() (err error) { defer func() { diff --git a/internal/streamingnode/server/service/handler/consumer/consume_server_test.go b/internal/streamingnode/server/service/handler/consumer/consume_server_test.go index 2698895e46..314e734379 100644 --- a/internal/streamingnode/server/service/handler/consumer/consume_server_test.go +++ b/internal/streamingnode/server/service/handler/consumer/consume_server_test.go @@ -21,7 +21,6 @@ import ( "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/message" - "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -32,66 +31,6 @@ func TestMain(m *testing.M) { m.Run() } -func TestNewMessageFilter(t *testing.T) { - filters := []*streamingpb.DeliverFilter{ - { - Filter: &streamingpb.DeliverFilter_TimeTickGt{ - TimeTickGt: &streamingpb.DeliverFilterTimeTickGT{ - TimeTick: 1, - }, - }, - }, - { - Filter: &streamingpb.DeliverFilter_Vchannel{ - Vchannel: &streamingpb.DeliverFilterVChannel{ - Vchannel: "test", - }, - }, - }, - } - filterFunc, err := options.GetFilterFunc(filters) - assert.NoError(t, err) - - msg := mock_message.NewMockImmutableMessage(t) - msg.EXPECT().TimeTick().Return(2).Maybe() - msg.EXPECT().VChannel().Return("test2").Maybe() - assert.False(t, filterFunc(msg)) - - msg = mock_message.NewMockImmutableMessage(t) - msg.EXPECT().TimeTick().Return(1).Maybe() - msg.EXPECT().VChannel().Return("test").Maybe() - assert.False(t, filterFunc(msg)) - - msg = mock_message.NewMockImmutableMessage(t) - msg.EXPECT().TimeTick().Return(2).Maybe() - 
msg.EXPECT().VChannel().Return("test").Maybe() - assert.True(t, filterFunc(msg)) - - filters = []*streamingpb.DeliverFilter{ - { - Filter: &streamingpb.DeliverFilter_TimeTickGte{ - TimeTickGte: &streamingpb.DeliverFilterTimeTickGTE{ - TimeTick: 1, - }, - }, - }, - { - Filter: &streamingpb.DeliverFilter_Vchannel{ - Vchannel: &streamingpb.DeliverFilterVChannel{ - Vchannel: "test", - }, - }, - }, - } - filterFunc, err = options.GetFilterFunc(filters) - assert.NoError(t, err) - - msg = mock_message.NewMockImmutableMessage(t) - msg.EXPECT().TimeTick().Return(1).Maybe() - msg.EXPECT().VChannel().Return("test").Maybe() - assert.True(t, filterFunc(msg)) -} - func TestCreateConsumeServer(t *testing.T) { manager := mock_walmanager.NewMockManager(t) grpcConsumeServer := mock_streamingpb.NewMockStreamingNodeHandlerService_ConsumeServer(t) @@ -201,9 +140,9 @@ func TestConsumerServeSendArm(t *testing.T) { } ctx, cancel := context.WithCancel(context.Background()) grpcConsumerServer.EXPECT().Context().Return(ctx) - grpcConsumerServer.EXPECT().Send(mock.Anything).RunAndReturn(func(cr *streamingpb.ConsumeResponse) error { return nil }).Times(2) + grpcConsumerServer.EXPECT().Send(mock.Anything).RunAndReturn(func(cr *streamingpb.ConsumeResponse) error { return nil }).Times(7) - scanCh := make(chan message.ImmutableMessage, 1) + scanCh := make(chan message.ImmutableMessage, 5) scanner.EXPECT().Channel().Return(types.PChannelInfo{}) scanner.EXPECT().Chan().Return(scanCh) scanner.EXPECT().Close().Return(nil).Times(3) @@ -227,6 +166,20 @@ func TestConsumerServeSendArm(t *testing.T) { msg.EXPECT().Properties().Return(properties) scanCh <- msg + // test send txn message. + txnMsg := mock_message.NewMockImmutableTxnMessage(t) + txnMsg.EXPECT().Begin().Return(msg) + txnMsg.EXPECT().RangeOver(mock.Anything).RunAndReturn(func(f func(message.ImmutableMessage) error) error { + for i := 0; i < 3; i++ { + if err := f(msg); err != nil { + return err + } + } + return nil + }) + txnMsg.EXPECT().Commit().Return(msg) + scanCh <- txnMsg + // test scanner broken. 
scanner.EXPECT().Error().Return(io.EOF) close(scanCh) diff --git a/internal/streamingnode/server/service/handler/producer/produce_server.go b/internal/streamingnode/server/service/handler/producer/produce_server.go index 5bbf53c39d..06c2b6f629 100644 --- a/internal/streamingnode/server/service/handler/producer/produce_server.go +++ b/internal/streamingnode/server/service/handler/producer/produce_server.go @@ -205,9 +205,6 @@ func (p *ProduceServer) validateMessage(msg message.MutableMessage) error { if !msg.MessageType().Valid() { return status.NewInvaildArgument("unsupported message type") } - if msg.Payload() == nil { - return status.NewInvaildArgument("empty payload for message") - } return nil } @@ -227,7 +224,9 @@ func (p *ProduceServer) sendProduceResult(reqID int64, appendResult *wal.AppendR Id: &messagespb.MessageID{ Id: appendResult.MessageID.Marshal(), }, - Timetick: appendResult.TimeTick, + Timetick: appendResult.TimeTick, + TxnContext: appendResult.TxnCtx.IntoProto(), + Extra: appendResult.Extra, }, } } diff --git a/internal/streamingnode/server/wal/adaptor/opener_test.go b/internal/streamingnode/server/wal/adaptor/opener_test.go index f2b28cf104..44b3d30138 100644 --- a/internal/streamingnode/server/wal/adaptor/opener_test.go +++ b/internal/streamingnode/server/wal/adaptor/opener_test.go @@ -13,6 +13,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/pkg/mocks/streaming/mock_walimpls" + "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls" @@ -78,7 +79,10 @@ func TestOpenerAdaptor(t *testing.T) { assert.NotNil(t, wal) for { - msgID, err := wal.Append(context.Background(), nil) + msg := mock_message.NewMockMutableMessage(t) + msg.EXPECT().WithWALTerm(mock.Anything).Return(msg).Maybe() + + msgID, err := wal.Append(context.Background(), msg) time.Sleep(time.Millisecond * 10) if err != nil { assert.Nil(t, msgID) diff --git a/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go b/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go index 1276b0fd5c..436d00ff5e 100644 --- a/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/scanner_adaptor.go @@ -10,6 +10,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/util/options" "github.com/milvus-io/milvus/pkg/streaming/util/types" "github.com/milvus-io/milvus/pkg/streaming/walimpls" "github.com/milvus-io/milvus/pkg/streaming/walimpls/helper" @@ -29,12 +30,16 @@ func newScannerAdaptor( if readOption.MesasgeHandler == nil { readOption.MesasgeHandler = defaultMessageHandler(make(chan message.ImmutableMessage)) } + options.GetFilterFunc(readOption.MessageFilter) + logger := log.With(zap.String("name", name), zap.String("channel", l.Channel().Name)) s := &scannerAdaptorImpl{ - logger: log.With(zap.String("name", name), zap.String("channel", l.Channel().Name)), + logger: logger, innerWAL: l, readOption: readOption, + filterFunc: options.GetFilterFunc(readOption.MessageFilter), reorderBuffer: utility.NewReOrderBuffer(), pendingQueue: typeutil.NewMultipartQueue[message.ImmutableMessage](), + txnBuffer: utility.NewTxnBuffer(logger), cleanup: cleanup, ScannerHelper: 
helper.NewScannerHelper(name), lastTimeTickInfo: inspector.TimeTickInfo{}, @@ -49,8 +54,10 @@ type scannerAdaptorImpl struct { logger *log.MLogger innerWAL walimpls.WALImpls readOption wal.ReadOption + filterFunc func(message.ImmutableMessage) bool reorderBuffer *utility.ReOrderByTimeTickBuffer // only support time tick reorder now. pendingQueue *typeutil.MultipartQueue[message.ImmutableMessage] // + txnBuffer *utility.TxnBuffer // txn buffer for txn messages. cleanup func() lastTimeTickInfo inspector.TimeTickInfo } @@ -136,8 +143,15 @@ func (s *scannerAdaptorImpl) getUpstream(scanner walimpls.ScannerImpls) <-chan m func (s *scannerAdaptorImpl) handleUpstream(msg message.ImmutableMessage) { if msg.MessageType() == message.MessageTypeTimeTick { // If a time tick message is incoming, - // the reorder buffer can be consumed into a pending queue with latest timetick. + // the reorder buffer can be consumed up to the latest confirmed timetick. + messages := s.reorderBuffer.PopUtilTimeTick(msg.TimeTick()) + + // Some txn messages need to be held until they are confirmed, so handle them in the txn buffer. + msgs := s.txnBuffer.HandleImmutableMessages(messages, msg.TimeTick()) + + // Push the confirmed messages into the pending queue for consuming, + // and push the timetick info forward. + s.pendingQueue.Add(msgs) s.lastTimeTickInfo = inspector.TimeTickInfo{ MessageID: msg.MessageID(), TimeTick: msg.TimeTick(), @@ -145,8 +159,10 @@ func (s *scannerAdaptorImpl) handleUpstream(msg message.ImmutableMessage) { } return } + // Filtering the message if needed. - if s.readOption.MessageFilter != nil && !s.readOption.MessageFilter(msg) { + // System messages should never be filtered. + if s.filterFunc != nil && !s.filterFunc(msg) { return } // otherwise add message into reorder buffer directly. diff --git a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go index 76892aaab8..992bcdd595 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_adaptor.go +++ b/internal/streamingnode/server/wal/adaptor/wal_adaptor.go @@ -7,6 +7,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" @@ -20,7 +21,7 @@ import ( var _ wal.WAL = (*walAdaptorImpl)(nil) -type unwrapMessageIDFunc func(*wal.AppendResult) +type gracefulCloseFunc func() // adaptImplsToWAL creates a new wal from wal impls. func adaptImplsToWAL( @@ -84,16 +85,32 @@ func (w *walAdaptorImpl) Append(ctx context.Context, msg message.MutableMessage) return nil, ctx.Err() case <-w.interceptorBuildResult.Interceptor.Ready(): } + // Set up the term of the wal. + msg = msg.WithWALTerm(w.Channel().Term) // Execute the interceptor and wal append.
- messageID, err := w.interceptorBuildResult.Interceptor.DoAppend(ctx, msg, w.inner.Append) + var extraAppendResult utility.ExtraAppendResult + ctx = utility.WithExtraAppendResult(ctx, &extraAppendResult) + messageID, err := w.interceptorBuildResult.Interceptor.DoAppend(ctx, msg, + func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error) { + if notPersistHint := utility.GetNotPersisted(ctx); notPersistHint != nil { + // do not persist the message if the hint is set. + // only used by time tick sync operator. + return notPersistHint.MessageID, nil + } + return w.inner.Append(ctx, msg) + }) if err != nil { return nil, err } // unwrap the messageID if needed. - r := &wal.AppendResult{MessageID: messageID} - w.interceptorBuildResult.UnwrapMessageIDFunc(r) + r := &wal.AppendResult{ + MessageID: messageID, + TimeTick: extraAppendResult.TimeTick, + TxnCtx: extraAppendResult.TxnCtx, + Extra: extraAppendResult.Extra, + } return r, nil } @@ -150,6 +167,10 @@ func (w *walAdaptorImpl) Available() <-chan struct{} { // Close overrides Scanner Close function. func (w *walAdaptorImpl) Close() { + // graceful close the interceptors before wal closing. + w.interceptorBuildResult.GracefulCloseFunc() + + // begin to close the wal. w.lifetime.SetState(lifetime.Stopped) w.lifetime.Wait() w.lifetime.Close() @@ -167,8 +188,8 @@ func (w *walAdaptorImpl) Close() { } type interceptorBuildResult struct { - Interceptor interceptors.InterceptorWithReady - UnwrapMessageIDFunc unwrapMessageIDFunc + Interceptor interceptors.InterceptorWithReady + GracefulCloseFunc gracefulCloseFunc } func (r interceptorBuildResult) Close() { @@ -182,19 +203,13 @@ func buildInterceptor(builders []interceptors.InterceptorBuilder, param intercep for _, b := range builders { builtIterceptors = append(builtIterceptors, b.Build(param)) } - - unwrapMessageIDFuncs := make([]func(*wal.AppendResult), 0) - for _, i := range builtIterceptors { - if r, ok := i.(interceptors.InterceptorWithUnwrapMessageID); ok { - unwrapMessageIDFuncs = append(unwrapMessageIDFuncs, r.UnwrapMessageID) - } - } - return interceptorBuildResult{ Interceptor: interceptors.NewChainedInterceptor(builtIterceptors...), - UnwrapMessageIDFunc: func(result *wal.AppendResult) { - for _, f := range unwrapMessageIDFuncs { - f(result) + GracefulCloseFunc: func() { + for _, i := range builtIterceptors { + if c, ok := i.(interceptors.InterceptorWithGracefulClose); ok { + c.GracefulClose() + } } }, } diff --git a/internal/streamingnode/server/wal/adaptor/wal_adaptor_test.go b/internal/streamingnode/server/wal/adaptor/wal_adaptor_test.go index aa327fb263..767910103a 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_adaptor_test.go +++ b/internal/streamingnode/server/wal/adaptor/wal_adaptor_test.go @@ -18,6 +18,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/mocks/streaming/mock_walimpls" + "github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" @@ -70,9 +71,12 @@ func TestWALAdaptor(t *testing.T) { lAdapted := adaptImplsToWAL(l, nil, func() {}) assert.NotNil(t, lAdapted.Channel()) - _, err := lAdapted.Append(context.Background(), nil) + + msg := mock_message.NewMockMutableMessage(t) + 
msg.EXPECT().WithWALTerm(mock.Anything).Return(msg).Maybe()
+	_, err := lAdapted.Append(context.Background(), msg)
 	assert.NoError(t, err)
-	lAdapted.AppendAsync(context.Background(), nil, func(mi *wal.AppendResult, err error) {
+	lAdapted.AppendAsync(context.Background(), msg, func(mi *wal.AppendResult, err error) {
 		assert.Nil(t, err)
 	})
@@ -108,9 +112,9 @@ func TestWALAdaptor(t *testing.T) {
 	case <-ch:
 	}
 
-	_, err = lAdapted.Append(context.Background(), nil)
+	_, err = lAdapted.Append(context.Background(), msg)
 	assertShutdownError(t, err)
-	lAdapted.AppendAsync(context.Background(), nil, func(mi *wal.AppendResult, err error) {
+	lAdapted.AppendAsync(context.Background(), msg, func(mi *wal.AppendResult, err error) {
 		assertShutdownError(t, err)
 	})
 	_, err = lAdapted.Read(context.Background(), wal.ReadOption{})
@@ -132,7 +136,9 @@ func TestNoInterceptor(t *testing.T) {
 	lWithInterceptors := adaptImplsToWAL(l, nil, func() {})
 
-	_, err := lWithInterceptors.Append(context.Background(), nil)
+	msg := mock_message.NewMockMutableMessage(t)
+	msg.EXPECT().WithWALTerm(mock.Anything).Return(msg).Maybe()
+	_, err := lWithInterceptors.Append(context.Background(), msg)
 	assert.NoError(t, err)
 	lWithInterceptors.Close()
 }
@@ -162,12 +168,14 @@ func TestWALWithInterceptor(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
 	defer cancel()
 	// Interceptor is not ready, so the append/read will be blocked until timeout.
-	_, err := lWithInterceptors.Append(ctx, nil)
+	msg := mock_message.NewMockMutableMessage(t)
+	msg.EXPECT().WithWALTerm(mock.Anything).Return(msg).Maybe()
+	_, err := lWithInterceptors.Append(ctx, msg)
 	assert.ErrorIs(t, err, context.DeadlineExceeded)
 
 	// Interceptor is ready, so the append/read will return soon.
 	close(readyCh)
-	_, err = lWithInterceptors.Append(context.Background(), nil)
+	_, err = lWithInterceptors.Append(context.Background(), msg)
 	assert.NoError(t, err)
 
 	lWithInterceptors.Close()
diff --git a/internal/streamingnode/server/wal/adaptor/wal_test.go b/internal/streamingnode/server/wal/adaptor/wal_test.go
index d030bab00d..3f3f2cdc7b 100644
--- a/internal/streamingnode/server/wal/adaptor/wal_test.go
+++ b/internal/streamingnode/server/wal/adaptor/wal_test.go
@@ -15,8 +15,6 @@ import (
 	"github.com/stretchr/testify/mock"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
-	"github.com/milvus-io/milvus/internal/flushcommon/syncmgr"
-	"github.com/milvus-io/milvus/internal/flushcommon/writebuffer"
 	"github.com/milvus-io/milvus/internal/mocks"
 	"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
 	"github.com/milvus-io/milvus/internal/mocks/streamingnode/server/mock_flusher"
@@ -59,9 +57,6 @@ func initResourceForTest(t *testing.T) {
 	catalog.EXPECT().ListSegmentAssignment(mock.Anything, mock.Anything).Return(nil, nil)
 	catalog.EXPECT().SaveSegmentAssignments(mock.Anything, mock.Anything, mock.Anything).Return(nil)
 
-	syncMgr := syncmgr.NewMockSyncManager(t)
-	wbMgr := writebuffer.NewMockBufferManager(t)
-
 	flusher := mock_flusher.NewMockFlusher(t)
 	flusher.EXPECT().RegisterPChannel(mock.Anything, mock.Anything).Return(nil).Maybe()
 	flusher.EXPECT().UnregisterPChannel(mock.Anything).Return().Maybe()
@@ -70,8 +65,6 @@ func initResourceForTest(t *testing.T) {
 
 	resource.InitForTest(
 		t,
-		resource.OptSyncManager(syncMgr),
-		resource.OptBufferManager(wbMgr),
 		resource.OptRootCoordClient(rc),
 		resource.OptDataCoordClient(dc),
 		resource.OptFlusher(flusher),
@@ -224,16 +217,91 @@ func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]mess
 		go func(i int) {
 			defer swg.Done()
 			time.Sleep(time.Duration(5+rand.Int31n(10)) * time.Millisecond)
-			// ...rocksmq has a dirty implement of properties,
-			// without commonpb.MsgHeader, it can not work.
-			msg := message.CreateTestEmptyInsertMesage(int64(i), map[string]string{
-				"id":    fmt.Sprintf("%d", i),
-				"const": "t",
-			})
-			appendResult, err := w.Append(ctx, msg)
-			assert.NoError(f.t, err)
-			assert.NotNil(f.t, appendResult)
-			messages[i] = msg.IntoImmutableMessage(appendResult.MessageID)
+
+			createPartOfTxn := func() (*message.ImmutableTxnMessageBuilder, *message.TxnContext) {
+				msg, err := message.NewBeginTxnMessageBuilderV2().
+					WithVChannel("v1").
+					WithHeader(&message.BeginTxnMessageHeader{
+						KeepaliveMilliseconds: 1000,
+					}).
+					WithBody(&message.BeginTxnMessageBody{}).
+					BuildMutable()
+				assert.NoError(f.t, err)
+				assert.NotNil(f.t, msg)
+				appendResult, err := w.Append(ctx, msg)
+				assert.NoError(f.t, err)
+				assert.NotNil(f.t, appendResult)
+
+				immutableMsg := msg.IntoImmutableMessage(appendResult.MessageID)
+				begin, err := message.AsImmutableBeginTxnMessageV2(immutableMsg)
+				assert.NoError(f.t, err)
+				b := message.NewImmutableTxnMessageBuilder(begin)
+				txnCtx := appendResult.TxnCtx
+				for i := 0; i < int(rand.Int31n(5)); i++ {
+					msg = message.CreateTestEmptyInsertMesage(int64(i), map[string]string{})
+					msg.WithTxnContext(*txnCtx)
+					appendResult, err = w.Append(ctx, msg)
+					assert.NoError(f.t, err)
+					assert.NotNil(f.t, msg)
+					b.Add(msg.IntoImmutableMessage(appendResult.MessageID))
+				}
+
+				return b, txnCtx
+			}
+
+			if rand.Int31n(2) == 0 {
+				// ...rocksmq has a dirty implementation of properties,
+				// without commonpb.MsgHeader, it cannot work.
+				msg := message.CreateTestEmptyInsertMesage(int64(i), map[string]string{
+					"id":    fmt.Sprintf("%d", i),
+					"const": "t",
+				})
+				appendResult, err := w.Append(ctx, msg)
+				assert.NoError(f.t, err)
+				assert.NotNil(f.t, appendResult)
+				messages[i] = msg.IntoImmutableMessage(appendResult.MessageID)
+			} else {
+				b, txnCtx := createPartOfTxn()
+
+				msg, err := message.NewCommitTxnMessageBuilderV2().
+					WithVChannel("v1").
+					WithHeader(&message.CommitTxnMessageHeader{}).
+					WithBody(&message.CommitTxnMessageBody{}).
+					WithProperties(map[string]string{
+						"id":    fmt.Sprintf("%d", i),
+						"const": "t",
+					}).
+					BuildMutable()
+				assert.NoError(f.t, err)
+				assert.NotNil(f.t, msg)
+				appendResult, err := w.Append(ctx, msg.WithTxnContext(*txnCtx))
+				assert.NoError(f.t, err)
+				assert.NotNil(f.t, appendResult)
+
+				immutableMsg := msg.IntoImmutableMessage(appendResult.MessageID)
+				commit, err := message.AsImmutableCommitTxnMessageV2(immutableMsg)
+				assert.NoError(f.t, err)
+				txn, err := b.Build(commit)
+				assert.NoError(f.t, err)
+				messages[i] = txn
+			}
+
+			if rand.Int31n(3) == 0 {
+				// produce a rollback or expired message.
+				_, txnCtx := createPartOfTxn()
+				if rand.Int31n(2) == 0 {
+					msg, err := message.NewRollbackTxnMessageBuilderV2().
+						WithVChannel("v1").
+						WithHeader(&message.RollbackTxnMessageHeader{}).
+						WithBody(&message.RollbackTxnMessageBody{}).
+						BuildMutable()
+					assert.NoError(f.t, err)
+					assert.NotNil(f.t, msg)
+					appendResult, err := w.Append(ctx, msg.WithTxnContext(*txnCtx))
+					assert.NoError(f.t, err)
+					assert.NotNil(f.t, appendResult)
+				}
+			}
 		}(i)
 	}
 	swg.Wait()
@@ -252,8 +320,8 @@ func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]mess
 func (f *testOneWALFramework) testRead(ctx context.Context, w wal.WAL) ([]message.ImmutableMessage, error) {
 	s, err := w.Read(ctx, wal.ReadOption{
 		DeliverPolicy: options.DeliverPolicyAll(),
-		MessageFilter: func(im message.ImmutableMessage) bool {
-			return im.MessageType() == message.MessageTypeInsert
+		MessageFilter: []options.DeliverFilter{
+			options.DeliverFilterMessageType(message.MessageTypeInsert),
 		},
 	})
 	assert.NoError(f.t, err)
@@ -263,7 +331,7 @@ func (f *testOneWALFramework) testRead(ctx context.Context, w wal.WAL) ([]messag
 	msgs := make([]message.ImmutableMessage, 0, expectedCnt)
 	for {
 		msg, ok := <-s.Chan()
-		if msg.MessageType() != message.MessageTypeInsert {
+		if msg.MessageType() != message.MessageTypeInsert && msg.MessageType() != message.MessageTypeTxn {
 			continue
 		}
 		assert.NotNil(f.t, msg)
@@ -297,8 +365,9 @@ func (f *testOneWALFramework) testReadWithOption(ctx context.Context, w wal.WAL)
 	readFromMsg := f.written[idx]
 	s, err := w.Read(ctx, wal.ReadOption{
 		DeliverPolicy: options.DeliverPolicyStartFrom(readFromMsg.LastConfirmedMessageID()),
-		MessageFilter: func(im message.ImmutableMessage) bool {
-			return im.TimeTick() >= readFromMsg.TimeTick() && im.MessageType() == message.MessageTypeInsert
+		MessageFilter: []options.DeliverFilter{
+			options.DeliverFilterTimeTickGTE(readFromMsg.TimeTick()),
+			options.DeliverFilterMessageType(message.MessageTypeInsert),
 		},
 	})
 	assert.NoError(f.t, err)
@@ -307,7 +376,7 @@ func (f *testOneWALFramework) testReadWithOption(ctx context.Context, w wal.WAL)
 	lastTimeTick := readFromMsg.TimeTick() - 1
 	for {
 		msg, ok := <-s.Chan()
-		if msg.MessageType() != message.MessageTypeInsert {
+		if msg.MessageType() != message.MessageTypeInsert && msg.MessageType() != message.MessageTypeTxn {
 			continue
 		}
 		msgCount++
@@ -337,18 +406,36 @@ func (f *testOneWALFramework) assertSortByTimeTickMessageList(msgs []message.Imm
 func (f *testOneWALFramework) assertEqualMessageList(msgs1 []message.ImmutableMessage, msgs2 []message.ImmutableMessage) {
 	assert.Equal(f.t, len(msgs2), len(msgs1))
 	for i := 0; i < len(msgs1); i++ {
-		assert.True(f.t, msgs1[i].MessageID().EQ(msgs2[i].MessageID()))
-		// assert.True(f.t, bytes.Equal(msgs1[i].Payload(), msgs2[i].Payload()))
-		id1, ok1 := msgs1[i].Properties().Get("id")
-		id2, ok2 := msgs2[i].Properties().Get("id")
-		assert.True(f.t, ok1)
-		assert.True(f.t, ok2)
-		assert.Equal(f.t, id1, id2)
-		id1, ok1 = msgs1[i].Properties().Get("const")
-		id2, ok2 = msgs2[i].Properties().Get("const")
-		assert.True(f.t, ok1)
-		assert.True(f.t, ok2)
-		assert.Equal(f.t, id1, id2)
+		assert.Equal(f.t, msgs1[i].MessageType(), msgs2[i].MessageType())
+		if msgs1[i].MessageType() == message.MessageTypeInsert {
+			assert.True(f.t, msgs1[i].MessageID().EQ(msgs2[i].MessageID()))
+			// assert.True(f.t, bytes.Equal(msgs1[i].Payload(), msgs2[i].Payload()))
+			id1, ok1 := msgs1[i].Properties().Get("id")
+			id2, ok2 := msgs2[i].Properties().Get("id")
+			assert.True(f.t, ok1)
+			assert.True(f.t, ok2)
+			assert.Equal(f.t, id1, id2)
+			id1, ok1 = msgs1[i].Properties().Get("const")
+			id2, ok2 = msgs2[i].Properties().Get("const")
+			assert.True(f.t, ok1)
+			assert.True(f.t, ok2)
+			assert.Equal(f.t, id1, id2)
+		}
+		if msgs1[i].MessageType() == message.MessageTypeTxn {
+			txn1 := message.AsImmutableTxnMessage(msgs1[i])
+			txn2 := message.AsImmutableTxnMessage(msgs2[i])
+			assert.Equal(f.t, txn1.Size(), txn2.Size())
+			id1, ok1 := txn1.Commit().Properties().Get("id")
+			id2, ok2 := txn2.Commit().Properties().Get("id")
+			assert.True(f.t, ok1)
+			assert.True(f.t, ok2)
+			assert.Equal(f.t, id1, id2)
+			id1, ok1 = txn1.Commit().Properties().Get("const")
+			id2, ok2 = txn2.Commit().Properties().Get("const")
+			assert.True(f.t, ok1)
+			assert.True(f.t, ok2)
+			assert.Equal(f.t, id1, id2)
+		}
 	}
 }
diff --git a/internal/streamingnode/server/wal/interceptors/interceptor.go b/internal/streamingnode/server/wal/interceptors/interceptor.go
index 089189026f..9ff57c64a7 100644
--- a/internal/streamingnode/server/wal/interceptors/interceptor.go
+++ b/internal/streamingnode/server/wal/interceptors/interceptor.go
@@ -39,7 +39,7 @@ type Interceptor interface {
 	// Execute the append operation with interceptor.
 	DoAppend(ctx context.Context, msg message.MutableMessage, append Append) (message.MessageID, error)
 
-	// Close the interceptor release the resources.
+	// Close the interceptor and release all the resources.
 	Close()
 }
 
@@ -57,9 +57,11 @@ type InterceptorWithReady interface {
 	Ready() <-chan struct{}
 }
 
-type InterceptorWithUnwrapMessageID interface {
+// Some interceptors may need to perform a graceful close operation.
+type InterceptorWithGracefulClose interface {
 	Interceptor
 
-	// UnwrapMessageID the message id from the append result.
-	UnwrapMessageID(*wal.AppendResult)
+	// GracefulClose will be called when the wal begins to close.
+	// The interceptor can do some operations before the wal rejects all incoming append operations.
+	GracefulClose()
 }
diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/params.go b/internal/streamingnode/server/wal/interceptors/segment/manager/params.go
index 086e570ac8..26b93a1650 100644
--- a/internal/streamingnode/server/wal/interceptors/segment/manager/params.go
+++ b/internal/streamingnode/server/wal/interceptors/segment/manager/params.go
@@ -4,6 +4,7 @@ import (
 	"go.uber.org/atomic"
 
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
 )
 
 // AssignSegmentRequest is a request to allocate segment.
@@ -11,6 +12,8 @@ type AssignSegmentRequest struct {
 	CollectionID  int64
 	PartitionID   int64
 	InsertMetrics stats.InsertMetrics
+	TimeTick      uint64
+	TxnSession    *txn.TxnSession
 }
 
 // AssignSegmentResult is a result of segment allocation.
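Review note: the two fields added to `AssignSegmentRequest` above carry the append path's time tick and transaction session down into the segment-assignment layer. Below is a minimal sketch of how a call site inside this repo could fill them; the wrapper function and the concrete id values are hypothetical, while the types and methods (`AssignSegment`, `result.Ack()`, `txn.GetTxnSessionFromContext`, `msg.EstimateSize()`) are the ones introduced or touched by this diff:

```go
// Sketch only: a hypothetical call site for the extended AssignSegmentRequest.
func assignForInsert(ctx context.Context, m *manager.PChannelSegmentAllocManager, msg message.MutableMessage) error {
	req := &manager.AssignSegmentRequest{
		CollectionID: 1, // illustrative ids
		PartitionID:  2,
		InsertMetrics: stats.InsertMetrics{
			Rows:       100,
			BinarySize: uint64(msg.EstimateSize()),
		},
		TimeTick:   msg.TimeTick(),                    // consulted by the fencing check
		TxnSession: txn.GetTxnSessionFromContext(ctx), // nil for autocommit messages
	}
	result, err := m.AssignSegment(ctx, req)
	if err != nil {
		return err
	}
	// acknowledge once the insert is actually written, so sealing can proceed.
	result.Ack()
	return nil
}
```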
diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go
index 3f0224dc10..e102d114d3 100644
--- a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go
+++ b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_manager.go
@@ -10,13 +10,14 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/policy"
-	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
 	"github.com/milvus-io/milvus/pkg/streaming/util/types"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 )
 
+var ErrFencedAssign = errors.New("fenced assign")
+
 // newPartitionSegmentManager creates a new partition segment assign manager.
 func newPartitionSegmentManager(
 	pchannel types.PChannelInfo,
@@ -42,13 +43,14 @@ func newPartitionSegmentManager(
 
 // partitionSegmentManager is a assign manager of determined partition on determined vchannel.
 type partitionSegmentManager struct {
-	mu           sync.Mutex
-	logger       *log.MLogger
-	pchannel     types.PChannelInfo
-	vchannel     string
-	collectionID int64
-	paritionID   int64
-	segments     []*segmentAllocManager // there will be very few segments in this list.
+	mu                   sync.Mutex
+	logger               *log.MLogger
+	pchannel             types.PChannelInfo
+	vchannel             string
+	collectionID         int64
+	paritionID           int64
+	segments             []*segmentAllocManager // there will be very few segments in this list.
+	fencedAssignTimeTick uint64                 // the time tick that the assign operation is fenced.
 }
 
 func (m *partitionSegmentManager) CollectionID() int64 {
@@ -56,11 +58,35 @@ func (m *partitionSegmentManager) CollectionID() int64 {
 }
 
 // AssignSegment assigns a segment for a assign segment request.
-func (m *partitionSegmentManager) AssignSegment(ctx context.Context, insert stats.InsertMetrics) (*AssignSegmentResult, error) {
+func (m *partitionSegmentManager) AssignSegment(ctx context.Context, req *AssignSegmentRequest) (*AssignSegmentResult, error) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
-	return m.assignSegment(ctx, insert)
+	// !!! We have promised that the fencedAssignTimeTick is always less than the time tick of any new incoming insert request,
+	// guaranteed by the barrier time tick of ManualFlush; so it's just a promise check here.
+	// If the request time tick is less than or equal to the fenced time tick, the assign operation is fenced.
+	// A special error will be returned to indicate that the assign operation is fenced.
+	// The wal will retry it with a new time tick.
+	if req.TimeTick <= m.fencedAssignTimeTick {
+		return nil, ErrFencedAssign
+	}
+	return m.assignSegment(ctx, req)
+}
+
+// SealAllSegmentsAndFenceUntil seals all segments and fences the assign operation until the given time tick.
+func (m *partitionSegmentManager) SealAllSegmentsAndFenceUntil(timeTick uint64) (sealedSegments []*segmentAllocManager) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	segmentManagers := m.collectShouldBeSealedWithPolicy(func(segmentMeta *segmentAllocManager) bool { return true })
+	// fence the assign operation until the maximum of the incoming time tick and the current fenced time tick.
+	// Any new incoming assignment request with a smaller time tick will be fenced.
+	// So no insert operation before the fenced time tick can be added to a growing segment (no more inserts can be applied to it).
+	// In other words, all insert operations before the fenced time tick will be sealed.
+	if timeTick > m.fencedAssignTimeTick {
+		m.fencedAssignTimeTick = timeTick
+	}
+	return segmentManagers
 }
 
 // CollectShouldBeSealed try to collect all segments that should be sealed.
@@ -68,6 +94,11 @@ func (m *partitionSegmentManager) CollectShouldBeSealed() []*segmentAllocManager
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
+	return m.collectShouldBeSealedWithPolicy(m.hitSealPolicy)
+}
+
+// collectShouldBeSealedWithPolicy collects all segments that should be sealed by policy.
+func (m *partitionSegmentManager) collectShouldBeSealedWithPolicy(predicates func(segmentMeta *segmentAllocManager) bool) []*segmentAllocManager {
 	shouldBeSealedSegments := make([]*segmentAllocManager, 0, len(m.segments))
 	segments := make([]*segmentAllocManager, 0, len(m.segments))
 	for _, segment := range m.segments {
@@ -81,8 +112,10 @@ func (m *partitionSegmentManager) CollectShouldBeSealed() []*segmentAllocManager
 			)
 			continue
 		}
-		// policy hitted segment should be removed from assignment manager.
-		if m.hitSealPolicy(segment) {
+
+		// a policy-hit growing segment should be removed from the assignment manager.
+		if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING &&
+			predicates(segment) {
 			shouldBeSealedSegments = append(shouldBeSealedSegments, segment)
 			continue
 		}
@@ -96,6 +129,7 @@ func (m *partitionSegmentManager) CollectDirtySegmentsAndClear() []*segmentAlloc
 	m.mu.Lock()
 	defer m.mu.Unlock()
+
 	dirtySegments := make([]*segmentAllocManager, 0, len(m.segments))
 	for _, segment := range m.segments {
 		if segment.IsDirtyEnough() {
@@ -110,6 +144,7 @@ func (m *partitionSegmentManager) CollectAllCanBeSealedAndClear() []*segmentAlloc
 	m.mu.Lock()
 	defer m.mu.Unlock()
+
 	canBeSealed := make([]*segmentAllocManager, 0, len(m.segments))
 	for _, segment := range m.segments {
 		if segment.GetState() == streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING ||
@@ -215,10 +250,10 @@ func (m *partitionSegmentManager) createNewPendingSegment(ctx context.Context) (
 }
 
 // assignSegment assigns a segment for a assign segment request and return should trigger a seal operation.
-func (m *partitionSegmentManager) assignSegment(ctx context.Context, insert stats.InsertMetrics) (*AssignSegmentResult, error) {
+func (m *partitionSegmentManager) assignSegment(ctx context.Context, req *AssignSegmentRequest) (*AssignSegmentResult, error) {
 	// Alloc segment for insert at previous segments.
 	for _, segment := range m.segments {
-		inserted, ack := segment.AllocRows(ctx, insert)
+		inserted, ack := segment.AllocRows(ctx, req)
 		if inserted {
 			return &AssignSegmentResult{SegmentID: segment.GetSegmentID(), Acknowledge: ack}, nil
 		}
@@ -229,8 +264,8 @@ func (m *partitionSegmentManager) assignSegment(ctx context.Context, insert stat
 	if err != nil {
 		return nil, err
 	}
-	if inserted, ack := newGrowingSegment.AllocRows(ctx, insert); inserted {
+	if inserted, ack := newGrowingSegment.AllocRows(ctx, req); inserted {
 		return &AssignSegmentResult{SegmentID: newGrowingSegment.GetSegmentID(), Acknowledge: ack}, nil
 	}
-	return nil, errors.Errorf("too large insert message, cannot hold in empty growing segment, stats: %+v", insert)
+	return nil, errors.Errorf("too large insert message, cannot hold in empty growing segment, stats: %+v", req.InsertMetrics)
 }
diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go
index 08680210e0..2cec243849 100644
--- a/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go
+++ b/internal/streamingnode/server/wal/interceptors/segment/manager/partition_managers.go
@@ -208,6 +208,34 @@ func (m *partitionSegmentManagers) RemovePartition(collectionID int64, partition
 	return pm.CollectAllCanBeSealedAndClear()
 }
 
+// SealAllSegmentsAndFenceUntil seals all segments and fences assignment until the given time tick.
+func (m *partitionSegmentManagers) SealAllSegmentsAndFenceUntil(collectionID int64, timetick uint64) ([]*segmentAllocManager, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	collectionInfo, ok := m.collectionInfos[collectionID]
+	if !ok {
+		m.logger.Warn("collection does not exist when Flush in segment assignment service", zap.Int64("collectionID", collectionID))
+		return nil, errors.New("collection not found")
+	}
+
+	sealedSegments := make([]*segmentAllocManager, 0)
+	// collect all partitions
+	for _, partition := range collectionInfo.Partitions {
+		// Seal all segments and fence assignment in the partition manager.
+		pm, ok := m.managers.Get(partition.PartitionId)
+		if !ok {
+			m.logger.Warn("partition not found when Flush in segment assignment service, it may be a bug in the system",
+				zap.Int64("collectionID", collectionID),
+				zap.Int64("partitionID", partition.PartitionId))
+			return nil, errors.New("partition not found")
+		}
+		newSealedSegments := pm.SealAllSegmentsAndFenceUntil(timetick)
+		sealedSegments = append(sealedSegments, newSealedSegments...)
+	}
+	return sealedSegments, nil
+}
+
 // Range ranges the partition managers.
 func (m *partitionSegmentManagers) Range(f func(pm *partitionSegmentManager)) {
 	m.managers.Range(func(_ int64, pm *partitionSegmentManager) bool {
diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go
index 096f159003..adcf87370a 100644
--- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go
+++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager.go
@@ -99,7 +99,7 @@ func (m *PChannelSegmentAllocManager) AssignSegment(ctx context.Context, req *As
 	if err != nil {
 		return nil, err
 	}
-	return manager.AssignSegment(ctx, req.InsertMetrics)
+	return manager.AssignSegment(ctx, req)
 }
 
 // RemoveCollection removes the specified collection.
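Review note: the fencing contract above is small but easy to get wrong, so here is a self-contained toy (deliberately not the real manager, just its invariant) showing what `ErrFencedAssign` and `SealAllSegmentsAndFenceUntil` guarantee together: once the fence is raised to T, any assignment whose time tick is not strictly greater than T is rejected, and the fence only ever moves forward.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// errFencedAssign mirrors ErrFencedAssign from the diff.
var errFencedAssign = errors.New("fenced assign")

type fencedAssigner struct {
	mu    sync.Mutex
	fence uint64 // plays the role of fencedAssignTimeTick
}

// Assign rejects any request whose time tick is not newer than the fence.
func (f *fencedAssigner) Assign(timeTick uint64) error {
	f.mu.Lock()
	defer f.mu.Unlock()
	if timeTick <= f.fence {
		return errFencedAssign
	}
	return nil
}

// FenceUntil raises the fence monotonically, as SealAllSegmentsAndFenceUntil does.
func (f *fencedAssigner) FenceUntil(timeTick uint64) {
	f.mu.Lock()
	defer f.mu.Unlock()
	if timeTick > f.fence {
		f.fence = timeTick
	}
}

func main() {
	a := &fencedAssigner{}
	fmt.Println(a.Assign(5)) // <nil>: fence is still 0
	a.FenceUntil(10)
	a.FenceUntil(3)           // ignored: the fence never moves backward
	fmt.Println(a.Assign(10)) // fenced assign: 10 <= 10
	fmt.Println(a.Assign(11)) // <nil>: strictly newer than the fence
}
```

In the real code the caller (the wal) reacts to the fenced error by retrying the append with a freshly allocated, larger time tick, which is why the promise check in `AssignSegment` is safe.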
@@ -138,6 +138,36 @@ func (m *PChannelSegmentAllocManager) RemovePartition(ctx context.Context, colle
 	return m.helper.WaitUntilNoWaitSeal(ctx)
 }
 
+// SealAllSegmentsAndFenceUntil seals all segments, fences assignment until the given time tick, and returns the sealed segmentIDs.
+func (m *PChannelSegmentAllocManager) SealAllSegmentsAndFenceUntil(ctx context.Context, collectionID int64, timetick uint64) ([]int64, error) {
+	if err := m.checkLifetime(); err != nil {
+		return nil, err
+	}
+	defer m.lifetime.Done()
+
+	// All messages with a time tick less than the incoming time tick belong to the returned sealed segments.
+	// So once the returned sealed segments are flushed, every message with a time tick less than the incoming time tick is flushed.
+	sealedSegments, err := m.managers.SealAllSegmentsAndFenceUntil(collectionID, timetick)
+	if err != nil {
+		return nil, err
+	}
+
+	segmentIDs := make([]int64, 0, len(sealedSegments))
+	for _, segment := range sealedSegments {
+		segmentIDs = append(segmentIDs, segment.GetSegmentID())
+	}
+
+	// trigger a seal operation in the background right now.
+	m.helper.AsyncSeal(sealedSegments...)
+
+	// wait until all segments have been flushed.
+	if err := m.helper.WaitUntilNoWaitSeal(ctx); err != nil {
+		return nil, err
+	}
+
+	return segmentIDs, nil
+}
+
 // TryToSealSegments tries to seal the specified segments.
 func (m *PChannelSegmentAllocManager) TryToSealSegments(ctx context.Context, infos ...stats.SegmentBelongs) {
 	if err := m.lifetime.Add(lifetime.IsWorking); err != nil {
diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go
index 02fe470d3f..ae7dbe36f6 100644
--- a/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go
+++ b/internal/streamingnode/server/wal/interceptors/segment/manager/pchannel_manager_test.go
@@ -18,11 +18,13 @@ import (
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
 	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
 	"github.com/milvus-io/milvus/pkg/streaming/util/types"
 	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/syncutil"
+	"github.com/milvus-io/milvus/pkg/util/tsoutil"
 )
 
 func TestSegmentAllocManager(t *testing.T) {
@@ -47,6 +49,7 @@ func TestSegmentAllocManager(t *testing.T) {
 			Rows:       100,
 			BinarySize: 100,
 		},
+		TimeTick: tsoutil.GetCurrentTime(),
 	})
 	assert.NoError(t, err)
 	assert.NotNil(t, result)
@@ -59,6 +62,7 @@ func TestSegmentAllocManager(t *testing.T) {
 			Rows:       1024 * 1024,
 			BinarySize: 1024 * 1024, // 1MB setting at paramtable.
 		},
+		TimeTick: tsoutil.GetCurrentTime(),
 	})
 	assert.NoError(t, err)
 	assert.NotNil(t, result2)
@@ -76,6 +80,7 @@ func TestSegmentAllocManager(t *testing.T) {
 			Rows:       1,
 			BinarySize: 1,
 		},
+		TimeTick: tsoutil.GetCurrentTime(),
 	})
 	assert.NoError(t, err)
 	assert.NotNil(t, result3)
@@ -88,6 +93,36 @@ func TestSegmentAllocManager(t *testing.T) {
 	m.TryToSealWaitedSegment(ctx)
 	assert.True(t, m.IsNoWaitSeal()) // result2 is acked, so new seal segment will be sealed right away.
+	// interact with txn
+	txnManager := txn.NewTxnManager()
+	txn, err := txnManager.BeginNewTxn(context.Background(), tsoutil.GetCurrentTime(), time.Second)
+	assert.NoError(t, err)
+	txn.BeginDone()
+
+	for i := 0; i < 3; i++ {
+		result, err = m.AssignSegment(ctx, &AssignSegmentRequest{
+			CollectionID: 1,
+			PartitionID:  1,
+			InsertMetrics: stats.InsertMetrics{
+				Rows:       1024 * 1024,
+				BinarySize: 1024 * 1024, // 1MB setting at paramtable.
+			},
+			TxnSession: txn,
+			TimeTick:   tsoutil.GetCurrentTime(),
+		})
+		assert.NoError(t, err)
+		result.Ack()
+	}
+	// because there's an uncommitted txn session, the segment will not be sealed.
+	m.TryToSealSegments(ctx)
+	assert.False(t, m.IsNoWaitSeal())
+
+	err = txn.RequestCommitAndWait(context.Background(), 0)
+	assert.NoError(t, err)
+	txn.CommitDone()
+	m.TryToSealSegments(ctx)
+	assert.True(t, m.IsNoWaitSeal())
+
 	// Try to seal a partition.
 	m.TryToSealSegments(ctx, stats.SegmentBelongs{
 		CollectionID: 1,
@@ -109,6 +144,7 @@ func TestSegmentAllocManager(t *testing.T) {
 			Rows:       100,
 			BinarySize: 100,
 		},
+		TimeTick: tsoutil.GetCurrentTime(),
 	})
 	assert.NoError(t, err)
 	assert.NotNil(t, result)
@@ -121,6 +157,30 @@ func TestSegmentAllocManager(t *testing.T) {
 	m.TryToSealSegments(ctx)
 	assert.True(t, m.IsNoWaitSeal())
 
+	// Test fence
+	ts := tsoutil.GetCurrentTime()
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancel()
+	ids, err := m.SealAllSegmentsAndFenceUntil(ctx, 1, ts)
+	assert.Error(t, err)
+	assert.ErrorIs(t, err, context.DeadlineExceeded)
+	assert.Empty(t, ids)
+	assert.False(t, m.IsNoWaitSeal())
+	m.TryToSealSegments(ctx)
+	assert.True(t, m.IsNoWaitSeal())
+
+	result, err = m.AssignSegment(ctx, &AssignSegmentRequest{
+		CollectionID: 1,
+		PartitionID:  3,
+		InsertMetrics: stats.InsertMetrics{
+			Rows:       100,
+			BinarySize: 100,
+		},
+		TimeTick: ts,
+	})
+	assert.ErrorIs(t, err, ErrFencedAssign)
+	assert.Nil(t, result)
+
 	m.Close(ctx)
 }
 
@@ -146,6 +206,7 @@ func TestCreateAndDropCollection(t *testing.T) {
 			Rows:       100,
 			BinarySize: 200,
 		},
+		TimeTick: tsoutil.GetCurrentTime(),
 	}
 
 	resp, err := m.AssignSegment(ctx, testRequest)
diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/seal_queue.go b/internal/streamingnode/server/wal/interceptors/segment/manager/seal_queue.go
index 4b220ee9fb..7ef2865f66 100644
--- a/internal/streamingnode/server/wal/interceptors/segment/manager/seal_queue.go
+++ b/internal/streamingnode/server/wal/interceptors/segment/manager/seal_queue.go
@@ -95,7 +95,7 @@ func (q *sealQueue) tryToSealSegments(ctx context.Context, segments ...*segmentA
 	// send flush message into wal.
 	for collectionID, vchannelSegments := range sealedSegments {
 		for vchannel, segments := range vchannelSegments {
-			if err := q.sendFlushMessageIntoWAL(ctx, collectionID, vchannel, segments); err != nil {
+			if err := q.sendFlushSegmentsMessageIntoWAL(ctx, collectionID, vchannel, segments); err != nil {
 				q.logger.Warn("fail to send flush message into wal", zap.String("vchannel", vchannel), zap.Int64("collectionID", collectionID), zap.Error(err))
 				undone = append(undone, segments...)
 				continue
@@ -146,6 +146,13 @@ func (q *sealQueue) transferSegmentStateIntoSealed(ctx context.Context, segments
 			continue
 		}
 
+		txnSem := segment.TxnSem()
+		if txnSem > 0 {
+			undone = append(undone, segment)
+			q.logger.Info("segment has been sealed, but there are flying txns, delay it", zap.Int64("segmentID", segment.GetSegmentID()), zap.Int32("txnSem", txnSem))
+			continue
+		}
+
 		// collect all sealed segments and no flying ack segment.
 		if _, ok := sealedSegments[segment.GetCollectionID()]; !ok {
 			sealedSegments[segment.GetCollectionID()] = make(map[string][]*segmentAllocManager)
@@ -158,15 +165,21 @@ func (q *sealQueue) transferSegmentStateIntoSealed(ctx context.Context, segments
 	return undone, sealedSegments
 }
 
-// sendFlushMessageIntoWAL sends a flush message into wal.
-func (m *sealQueue) sendFlushMessageIntoWAL(ctx context.Context, collectionID int64, vchannel string, segments []*segmentAllocManager) error {
+// sendFlushSegmentsMessageIntoWAL sends a flush segments message into wal.
+func (m *sealQueue) sendFlushSegmentsMessageIntoWAL(ctx context.Context, collectionID int64, vchannel string, segments []*segmentAllocManager) error {
 	segmentIDs := make([]int64, 0, len(segments))
 	for _, segment := range segments {
 		segmentIDs = append(segmentIDs, segment.GetSegmentID())
 	}
-	msg, err := m.createNewFlushMessage(collectionID, vchannel, segmentIDs)
+	msg, err := message.NewFlushMessageBuilderV2().
+		WithVChannel(vchannel).
+		WithHeader(&message.FlushMessageHeader{}).
+		WithBody(&message.FlushMessageBody{
+			CollectionId: collectionID,
+			SegmentId:    segmentIDs,
+		}).BuildMutable()
 	if err != nil {
-		return errors.Wrap(err, "at create new flush message")
+		return errors.Wrap(err, "at create new flush segments message")
 	}
 
 	msgID, err := m.wal.Get().Append(ctx, msg)
@@ -179,7 +192,11 @@ func (m *sealQueue) sendFlushMessageIntoWAL(ctx context.Context, collectionID in
 }
 
 // createNewFlushMessage creates a new flush message.
-func (m *sealQueue) createNewFlushMessage(collectionID int64, vchannel string, segmentIDs []int64) (message.MutableMessage, error) {
+func (m *sealQueue) createNewFlushMessage(
+	collectionID int64,
+	vchannel string,
+	segmentIDs []int64,
+) (message.MutableMessage, error) {
 	// Create a flush message.
 	msg, err := message.NewFlushMessageBuilderV2().
 		WithVChannel(vchannel).
diff --git a/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go b/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go
index d2684043f6..594648a661 100644
--- a/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go
+++ b/internal/streamingnode/server/wal/interceptors/segment/manager/segment_manager.go
@@ -40,6 +40,7 @@ func newSegmentAllocManagerFromProto(
 		inner:         inner,
 		immutableStat: stat,
 		ackSem:        atomic.NewInt32(0),
+		txnSem:        atomic.NewInt32(0),
 		dirtyBytes:    0,
 	}
 }
@@ -65,6 +66,7 @@ func newSegmentAllocManager(
 		immutableStat: nil, // immutable stat can be seen after sealed.
 		ackSem:        atomic.NewInt32(0),
 		dirtyBytes:    0,
+		txnSem:        atomic.NewInt32(0),
 	}
 }
 
@@ -88,6 +90,7 @@ type segmentAllocManager struct {
 	immutableStat *stats.SegmentStats // after sealed or flushed, the stat is immutable and cannot be seen by stats manager.
 	ackSem        *atomic.Int32       // the ackSem is increased when segment allocRows, decreased when the segment is acked.
 	dirtyBytes    uint64              // records the dirty bytes that didn't persist.
+	txnSem        *atomic.Int32       // the running txn count of the segment.
 }
 
 // GetCollectionID returns the collection id of the segment assignment meta.
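Review note: the `txnSem` counter added above is what lets the seal queue delay a seal while transactions are still in flight. A stripped-down illustration of the bookkeeping follows; the toy types are hypothetical stand-ins for the repo's `segmentAllocManager` and `TxnSession.RegisterCleanup`:

```go
package main

import (
	"fmt"

	"go.uber.org/atomic"
)

// toySegment demonstrates the txnSem bookkeeping from the diff: every insert
// carried by an uncommitted txn bumps the counter, and the txn's cleanup hook
// releases it again. Sealing is delayed while any txn is still in flight.
type toySegment struct {
	txnSem *atomic.Int32
}

// allocForTxn mirrors AllocRows when req.TxnSession != nil: increment the
// counter now, and register a cleanup that decrements it when the txn ends.
func (s *toySegment) allocForTxn(registerCleanup func(func())) {
	s.txnSem.Inc()
	registerCleanup(func() { s.txnSem.Dec() })
}

// canSeal mirrors the seal queue's check: TxnSem() must drop back to zero.
func (s *toySegment) canSeal() bool { return s.txnSem.Load() == 0 }

func main() {
	seg := &toySegment{txnSem: atomic.NewInt32(0)}
	var cleanups []func()
	seg.allocForTxn(func(f func()) { cleanups = append(cleanups, f) })
	fmt.Println(seg.canSeal()) // false: one txn in flight, seal is delayed
	for _, f := range cleanups {
		f() // the txn committed, rolled back, or expired
	}
	fmt.Println(seg.canSeal()) // true: seal may proceed
}
```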
@@ -131,20 +134,31 @@ func (s *segmentAllocManager) AckSem() int32 {
 	return s.ackSem.Load()
 }
 
+// TxnSem returns the txn sem.
+func (s *segmentAllocManager) TxnSem() int32 {
+	return s.txnSem.Load()
+}
+
 // AllocRows ask for rows from current segment.
 // Only growing and not fenced segment can alloc rows.
-func (s *segmentAllocManager) AllocRows(ctx context.Context, m stats.InsertMetrics) (bool, *atomic.Int32) {
+func (s *segmentAllocManager) AllocRows(ctx context.Context, req *AssignSegmentRequest) (bool, *atomic.Int32) {
 	// if the segment is not growing or reach limit, return false directly.
 	if s.inner.GetState() != streamingpb.SegmentAssignmentState_SEGMENT_ASSIGNMENT_STATE_GROWING {
 		return false, nil
 	}
-	inserted := resource.Resource().SegmentAssignStatsManager().AllocRows(s.GetSegmentID(), m)
+	inserted := resource.Resource().SegmentAssignStatsManager().AllocRows(s.GetSegmentID(), req.InsertMetrics)
 	if !inserted {
 		return false, nil
 	}
-	s.dirtyBytes += m.BinarySize
+	s.dirtyBytes += req.InsertMetrics.BinarySize
 	s.ackSem.Inc()
 
+	// register the txn session cleanup to the segment.
+	if req.TxnSession != nil {
+		s.txnSem.Inc()
+		req.TxnSession.RegisterCleanup(func() { s.txnSem.Dec() }, req.TimeTick)
+	}
+
 	// persist stats if too dirty.
 	s.persistStatsIfTooDirty(ctx)
 	return inserted, s.ackSem
diff --git a/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go b/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go
index d0220e2035..1d85a487bd 100644
--- a/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go
+++ b/internal/streamingnode/server/wal/interceptors/segment/segment_assign_interceptor.go
@@ -10,6 +10,7 @@ import (
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/manager"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/segment/stats"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
 	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
 	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
@@ -137,8 +138,10 @@ func (impl *segmentInterceptor) handleInsertMessage(ctx context.Context, msg mes
 			PartitionID: partition.GetPartitionId(),
 			InsertMetrics: stats.InsertMetrics{
 				Rows:       partition.GetRows(),
-				BinarySize: partition.GetBinarySize(),
+				BinarySize: uint64(msg.EstimateSize()), // TODO: Use partition.BinarySize in the future when partitions are merged into one message.
 			},
+			TimeTick:   msg.TimeTick(),
+			TxnSession: txn.GetTxnSessionFromContext(ctx),
 		})
 		if err != nil {
 			return nil, status.NewInner("segment assignment failure with error: %s", err.Error())
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack.go
index 2fe4c4d11b..8c9b00de02 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack.go
@@ -7,20 +7,16 @@ import (
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
-var _ typeutil.HeapInterface = (*timestampWithAckArray)(nil)
-
-// newAcker creates a new acker.
-func newAcker(ts uint64, lastConfirmedMessageID message.MessageID) *Acker {
-	return &Acker{
-		acknowledged: atomic.NewBool(false),
-		detail:       newAckDetail(ts, lastConfirmedMessageID),
-	}
-}
+var (
+	_ typeutil.HeapInterface = (*ackersOrderByTimestamp)(nil)
+	_ typeutil.HeapInterface = (*ackersOrderByEndTimestamp)(nil)
+)
 
 // Acker records the timestamp and last confirmed message id that has not been acknowledged.
 type Acker struct {
 	acknowledged *atomic.Bool // is acknowledged.
 	detail       *AckDetail   // info is available after acknowledged.
+	manager      *AckManager  // the manager of the acker.
 }
 
 // LastConfirmedMessageID returns the last confirmed message id.
@@ -30,7 +26,7 @@ func (ta *Acker) LastConfirmedMessageID() message.MessageID {
 
 // Timestamp returns the timestamp.
 func (ta *Acker) Timestamp() uint64 {
-	return ta.detail.Timestamp
+	return ta.detail.BeginTimestamp
}
 
 // Ack marks the timestamp as acknowledged.
@@ -39,6 +35,7 @@ func (ta *Acker) Ack(opts ...AckOption) {
 		opt(ta.detail)
 	}
 	ta.acknowledged.Store(true)
+	ta.manager.ack(ta)
 }
 
 // ackDetail returns the ack info, only can be called after acknowledged.
@@ -49,31 +46,46 @@ func (ta *Acker) ackDetail() *AckDetail {
 	return ta.detail
 }
 
-// timestampWithAckArray is a heap underlying represent of timestampAck.
-type timestampWithAckArray []*Acker
-
-// Len returns the length of the heap.
-func (h timestampWithAckArray) Len() int {
-	return len(h)
+// ackersOrderByTimestamp is an underlying heap representation of timestampAck, ordered by begin timestamp.
+type ackersOrderByTimestamp struct {
+	ackers
 }
 
 // Less returns true if the element at index i is less than the element at index j.
-func (h timestampWithAckArray) Less(i, j int) bool {
-	return h[i].detail.Timestamp < h[j].detail.Timestamp
+func (h ackersOrderByTimestamp) Less(i, j int) bool {
+	return h.ackers[i].detail.BeginTimestamp < h.ackers[j].detail.BeginTimestamp
+}
+
+// ackersOrderByEndTimestamp is an underlying heap representation of timestampAck, ordered by end timestamp.
+type ackersOrderByEndTimestamp struct {
+	ackers
+}
+
+// Less returns true if the element at index i is less than the element at index j.
+func (h ackersOrderByEndTimestamp) Less(i, j int) bool {
+	return h.ackers[i].detail.EndTimestamp < h.ackers[j].detail.EndTimestamp
+}
+
+// ackers is an underlying heap representation of timestampAck.
+type ackers []*Acker
+
+// Len returns the length of the heap.
+func (h ackers) Len() int {
+	return len(h)
 }
 
 // Swap swaps the elements at indexes i and j.
-func (h timestampWithAckArray) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+func (h ackers) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
 
 // Push pushes the last one at len.
-func (h *timestampWithAckArray) Push(x interface{}) {
+func (h *ackers) Push(x interface{}) {
 	// Push and Pop use pointer receivers because they modify the slice's length,
 	// not just its contents.
 	*h = append(*h, x.(*Acker))
 }
 
 // Pop pop the last one at len.
-func (h *timestampWithAckArray) Pop() interface{} {
+func (h *ackers) Pop() interface{} {
 	old := *h
 	n := len(old)
 	x := old[n-1]
@@ -82,6 +94,6 @@ func (h *timestampWithAckArray) Pop() interface{} {
 }
 
 // Peek returns the element at the top of the heap.
-func (h *timestampWithAckArray) Peek() interface{} {
+func (h *ackers) Peek() interface{} {
 	return (*h)[0]
 }
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_details.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_details.go
index 7151f443c0..1b73cbbec4 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_details.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_details.go
@@ -31,7 +31,7 @@ func (ad *AckDetails) AddDetails(details sortedDetails) {
 		ad.detail = details
 		return
 	}
-	if ad.detail[len(ad.detail)-1].Timestamp >= details[0].Timestamp {
+	if ad.detail[len(ad.detail)-1].BeginTimestamp >= details[0].BeginTimestamp {
 		panic("unreachable: the details must be sorted by timestamp")
 	}
 	ad.detail = append(ad.detail, details...)
@@ -62,7 +62,10 @@ func (ad *AckDetails) IsNoPersistedMessage() bool {
 // LastAllAcknowledgedTimestamp returns the last timestamp which all timestamps before it have been acknowledged.
-// panic if no timestamp has been acknowledged.
+// returns 0 if no timestamp has been acknowledged.
 func (ad *AckDetails) LastAllAcknowledgedTimestamp() uint64 {
-	return ad.detail[len(ad.detail)-1].Timestamp
+	if len(ad.detail) > 0 {
+		return ad.detail[len(ad.detail)-1].BeginTimestamp
+	}
+	return 0
 }
 
 // EarliestLastConfirmedMessageID returns the last confirmed message id.
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_details_test.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_details_test.go
index eb6f2d2fd2..c4da34a0b1 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_details_test.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_details_test.go
@@ -13,13 +13,13 @@ func TestAckDetails(t *testing.T) {
 	assert.True(t, details.Empty())
 	assert.Equal(t, 0, details.Len())
 	details.AddDetails(sortedDetails{
-		&AckDetail{Timestamp: 1, IsSync: true},
+		&AckDetail{BeginTimestamp: 1, IsSync: true},
 	})
 	assert.True(t, details.IsNoPersistedMessage())
 	assert.Equal(t, uint64(1), details.LastAllAcknowledgedTimestamp())
 	details.AddDetails(sortedDetails{
-		&AckDetail{Timestamp: 2, LastConfirmedMessageID: walimplstest.NewTestMessageID(2)},
-		&AckDetail{Timestamp: 3, LastConfirmedMessageID: walimplstest.NewTestMessageID(1)},
+		&AckDetail{BeginTimestamp: 2, LastConfirmedMessageID: walimplstest.NewTestMessageID(2)},
+		&AckDetail{BeginTimestamp: 3, LastConfirmedMessageID: walimplstest.NewTestMessageID(1)},
 	})
 	assert.False(t, details.IsNoPersistedMessage())
 	assert.Equal(t, uint64(3), details.LastAllAcknowledgedTimestamp())
@@ -27,7 +27,7 @@ func TestAckDetails(t *testing.T) {
 
 	assert.Panics(t, func() {
 		details.AddDetails(sortedDetails{
-			&AckDetail{Timestamp: 1, IsSync: true},
+			&AckDetail{BeginTimestamp: 1, IsSync: true},
 		})
 	})
 
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go
index efbebe451a..95c8e22c15 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack_test.go
@@ -2,13 +2,18 @@ package ack
 
 import (
 	"context"
+	"fmt"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"go.uber.org/atomic"
+	"google.golang.org/grpc"
 
+	"github.com/milvus-io/milvus/internal/mocks"
+	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
-	"github.com/milvus-io/milvus/internal/streamingnode/server/resource/idalloc"
-	"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
+	"github.com/milvus-io/milvus/pkg/util/merr"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 )
 
@@ -18,19 +23,29 @@ func TestAck(t *testing.T) {
 
 	ctx := context.Background()
 
-	rc := idalloc.NewMockRootCoordClient(t)
+	counter := atomic.NewUint64(1)
+	rc := mocks.NewMockRootCoordClient(t)
+	rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything).RunAndReturn(
+		func(ctx context.Context, atr *rootcoordpb.AllocTimestampRequest, co ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) {
+			if atr.Count > 1000 {
+				panic(fmt.Sprintf("count %d is too large", atr.Count))
+			}
+			c := counter.Add(uint64(atr.Count))
+			return &rootcoordpb.AllocTimestampResponse{
+				Status:    merr.Success(),
+				Timestamp: c - uint64(atr.Count),
+				Count:     atr.Count,
+			}, nil
+		},
+	)
 	resource.InitForTest(t, resource.OptRootCoordClient(rc))
 
-	ackManager := NewAckManager()
-	msgID := mock_message.NewMockMessageID(t)
-	msgID.EXPECT().EQ(msgID).Return(true)
-	ackManager.AdvanceLastConfirmedMessageID(msgID)
+	ackManager := NewAckManager(0, nil)
 
 	ackers := map[uint64]*Acker{}
 	for i := 0; i < 10; i++ {
 		acker, err := ackManager.Allocate(ctx)
 		assert.NoError(t, err)
-		assert.True(t, acker.LastConfirmedMessageID().EQ(msgID))
 		ackers[acker.Timestamp()] = acker
 	}
 
@@ -42,28 +57,28 @@ func TestAck(t *testing.T) {
 
 	// notAck: [1, 3, ..., 10]
 	// ack: [2]
-	ackers[2].Ack()
+	ackers[2].Ack(OptSync())
 	details, err = ackManager.SyncAndGetAcknowledged(ctx)
 	assert.NoError(t, err)
 	assert.Empty(t, details)
 
 	// notAck: [1, 3, 5, ..., 10]
 	// ack: [2, 4]
-	ackers[4].Ack()
+	ackers[4].Ack(OptSync())
 	details, err = ackManager.SyncAndGetAcknowledged(ctx)
 	assert.NoError(t, err)
 	assert.Empty(t, details)
 
 	// notAck: [3, 5, ..., 10]
 	// ack: [1, 2, 4]
-	ackers[1].Ack()
+	ackers[1].Ack(OptSync())
 
 	// notAck: [3, 5, ..., 10]
 	// ack: [4]
 	details, err = ackManager.SyncAndGetAcknowledged(ctx)
 	assert.NoError(t, err)
 	assert.Equal(t, 2, len(details))
-	assert.Equal(t, uint64(1), details[0].Timestamp)
-	assert.Equal(t, uint64(2), details[1].Timestamp)
+	assert.Equal(t, uint64(1), details[0].BeginTimestamp)
+	assert.Equal(t, uint64(2), details[1].BeginTimestamp)
 
 	// notAck: [3, 5, ..., 10]
 	// ack: [4]
@@ -74,7 +89,7 @@ func TestAck(t *testing.T) {
 	// notAck: [3]
 	// ack: [4, ..., 10]
 	for i := 5; i <= 10; i++ {
-		ackers[uint64(i)].Ack()
+		ackers[uint64(i)].Ack(OptSync())
 	}
 	details, err = ackManager.SyncAndGetAcknowledged(ctx)
 	assert.NoError(t, err)
@@ -92,7 +107,7 @@ func TestAck(t *testing.T) {
 
 	// notAck: [...,x, y]
 	// ack: [3, ..., 10]
-	ackers[3].Ack()
+	ackers[3].Ack(OptSync())
 
 	// notAck: [...,x, y]
 	// ack: []
@@ -106,8 +121,8 @@ func TestAck(t *testing.T) {
 	assert.NoError(t, err)
 	assert.Empty(t, details)
 
-	tsX.Ack()
-	tsY.Ack()
+	tsX.Ack(OptSync())
+	tsY.Ack(OptSync())
 
 	// notAck: []
 	// ack: []
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/detail.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/detail.go
index b19e9be5b9..96fd24c97b 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/ack/detail.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/detail.go
@@ -3,6 +3,7 @@ package ack
 import (
 	"fmt"
 
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
 )
 
@@ -12,7 +13,7 @@ func newAckDetail(ts uint64, lastConfirmedMessageID message.MessageID) *AckDetai
 		panic(fmt.Sprintf("ts should never less than 0 %d", ts))
 	}
 	return &AckDetail{
-		Timestamp:              ts,
+		BeginTimestamp:         ts,
 		LastConfirmedMessageID: lastConfirmedMessageID,
 		IsSync:                 false,
 		Err:                    nil,
@@ -21,8 +22,12 @@ func newAckDetail(ts uint64, lastConfirmedMessageID message.MessageID) *AckDetai
 
 // AckDetail records the information of acker.
 type AckDetail struct {
-	Timestamp              uint64
+	BeginTimestamp uint64 // the timestamp when the acker is allocated.
+	EndTimestamp   uint64 // the timestamp when the acker is acknowledged.
+	// To avoid timestamp allocation failure, the end timestamp reuses the ack manager's last allocated timestamp.
 	LastConfirmedMessageID message.MessageID
+	MessageID              message.MessageID
+	TxnSession             *txn.TxnSession
 	IsSync                 bool
 	Err                    error
 }
@@ -43,3 +48,17 @@ func OptError(err error) AckOption {
 		detail.Err = err
 	}
 }
+
+// OptMessageID marks the message id for acker.
+func OptMessageID(messageID message.MessageID) AckOption {
+	return func(detail *AckDetail) {
+		detail.MessageID = messageID
+	}
+}
+
+// OptTxnSession marks the session for acker.
+func OptTxnSession(session *txn.TxnSession) AckOption {
+	return func(detail *AckDetail) {
+		detail.TxnSession = session
+	}
+}
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/detail_test.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/detail_test.go
index 36dac55eef..1a9dc27cfe 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/ack/detail_test.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/detail_test.go
@@ -6,7 +6,9 @@ import (
 	"github.com/cockroachdb/errors"
 	"github.com/stretchr/testify/assert"
 
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
 	"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest"
 )
 
 func TestDetail(t *testing.T) {
@@ -17,7 +19,7 @@ func TestDetail(t *testing.T) {
 	msgID.EXPECT().EQ(msgID).Return(true)
 	ackDetail := newAckDetail(1, msgID)
-	assert.Equal(t, uint64(1), ackDetail.BeginTimestamp)
 	assert.True(t, ackDetail.LastConfirmedMessageID.EQ(msgID))
 	assert.False(t, ackDetail.IsSync)
 	assert.NoError(t, ackDetail.Err)
@@ -26,4 +28,10 @@ func TestDetail(t *testing.T) {
 	assert.True(t, ackDetail.IsSync)
 	OptError(errors.New("test"))(ackDetail)
 	assert.Error(t, ackDetail.Err)
+
+	OptMessageID(walimplstest.NewTestMessageID(1))(ackDetail)
+	assert.NotNil(t, ackDetail.MessageID)
+
+	OptTxnSession(&txn.TxnSession{})(ackDetail)
+	assert.NotNil(t, ackDetail.TxnSession)
 }
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/last_confirmed.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/last_confirmed.go
new file mode 100644
index 0000000000..c43d894a87
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/last_confirmed.go
@@ -0,0 +1,89 @@
+package ack
+
+import (
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+type uncommittedTxnInfo struct {
+	session   *txn.TxnSession   // if nil, it's a non-txn (autocommit) message.
+	messageID message.MessageID // the message id where the txn begins.
+}
+
+// newLastConfirmedManager creates a new last confirmed manager.
+func newLastConfirmedManager(lastConfirmedMessageID message.MessageID) *lastConfirmedManager {
+	return &lastConfirmedManager{
+		lastConfirmedMessageID: lastConfirmedMessageID,
+		notDoneTxnMessage:      typeutil.NewHeap[*uncommittedTxnInfo](&uncommittedTxnInfoOrderByMessageID{}),
+	}
+}
+
+// lastConfirmedManager manages the last confirmed message id.
+type lastConfirmedManager struct {
+	lastConfirmedMessageID message.MessageID
+	notDoneTxnMessage      typeutil.Heap[*uncommittedTxnInfo]
+}
+
+// AddConfirmedDetails adds the confirmed details.
+func (m *lastConfirmedManager) AddConfirmedDetails(details sortedDetails, ts uint64) {
+	for _, detail := range details {
+		if detail.IsSync || detail.Err != nil {
+			continue
+		}
+		m.notDoneTxnMessage.Push(&uncommittedTxnInfo{
+			session:   detail.TxnSession,
+			messageID: detail.MessageID,
+		})
+	}
+	m.updateLastConfirmedMessageID(ts)
+}
+
+// GetLastConfirmedMessageID returns the last confirmed message id.
+func (m *lastConfirmedManager) GetLastConfirmedMessageID() message.MessageID {
+	return m.lastConfirmedMessageID
+}
+
+// updateLastConfirmedMessageID updates the last confirmed message id.
+func (m *lastConfirmedManager) updateLastConfirmedMessageID(ts uint64) {
+	for m.notDoneTxnMessage.Len() > 0 &&
+		(m.notDoneTxnMessage.Peek().session == nil || m.notDoneTxnMessage.Peek().session.IsExpiredOrDone(ts)) {
+		info := m.notDoneTxnMessage.Pop()
+		if m.lastConfirmedMessageID.LT(info.messageID) {
+			m.lastConfirmedMessageID = info.messageID
+		}
+	}
+}
+
+// uncommittedTxnInfoOrderByMessageID is the underlying heap array of uncommittedTxnInfo, ordered by message id.
+type uncommittedTxnInfoOrderByMessageID []*uncommittedTxnInfo
+
+func (h uncommittedTxnInfoOrderByMessageID) Len() int {
+	return len(h)
+}
+
+func (h uncommittedTxnInfoOrderByMessageID) Less(i, j int) bool {
+	return h[i].messageID.LT(h[j].messageID)
+}
+
+func (h uncommittedTxnInfoOrderByMessageID) Swap(i, j int) {
+	h[i], h[j] = h[j], h[i]
+}
+
+func (h *uncommittedTxnInfoOrderByMessageID) Push(x interface{}) {
+	*h = append(*h, x.(*uncommittedTxnInfo))
+}
+
+// Pop pops the last one at len.
+func (h *uncommittedTxnInfoOrderByMessageID) Pop() interface{} {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	*h = old[0 : n-1]
+	return x
+}
+
+// Peek returns the element at the top of the heap.
+func (h *uncommittedTxnInfoOrderByMessageID) Peek() interface{} {
+	return (*h)[0]
+}
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go
index 55583d286c..a34f897b07 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go
@@ -4,43 +4,77 @@ import (
 	"context"
 	"sync"
 
+	"go.uber.org/atomic"
+
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/util/syncutil"
 	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
 // AckManager manages the timestampAck.
 type AckManager struct {
-	mu                     sync.Mutex
-	notAckHeap             typeutil.Heap[*Acker] // a minimum heap of timestampAck to search minimum timestamp in list.
-	lastConfirmedMessageID message.MessageID
+	cond                  *syncutil.ContextCond
+	lastAllocatedTimeTick uint64 // The last allocated time tick, the latest timestamp allocated by the allocator.
+	lastConfirmedTimeTick uint64 // The last confirmed time tick, every message whose time tick is less than it has been committed into the wal.
+	notAckHeap typeutil.Heap[*Acker] // A minimum heap of timestampAck to search the minimum allocated-but-not-acknowledged timestamp in the list.
+	ackHeap    typeutil.Heap[*Acker] // A minimum heap of timestampAck to search the minimum acknowledged timestamp in the list.
+	// It is used to find the messages that are safely confirmed, so that the last confirmed message id can be advanced.
+	acknowledgedDetails  sortedDetails         // All ack details whose time tick is less than lastConfirmedTimeTick will be temporarily kept here until a sync operation happens.
+	lastConfirmedManager *lastConfirmedManager // The last confirmed message id manager.
 }
 
 // NewAckManager creates a new timestampAckHelper.
-func NewAckManager() *AckManager {
+func NewAckManager(
+	lastConfirmedTimeTick uint64,
+	lastConfirmedMessageID message.MessageID,
+) *AckManager {
 	return &AckManager{
-		mu:         sync.Mutex{},
-		notAckHeap: typeutil.NewHeap[*Acker](&timestampWithAckArray{}),
+		cond:                  syncutil.NewContextCond(&sync.Mutex{}),
+		lastAllocatedTimeTick: 0,
+		notAckHeap:            typeutil.NewHeap[*Acker](&ackersOrderByTimestamp{}),
+		ackHeap:               typeutil.NewHeap[*Acker](&ackersOrderByEndTimestamp{}),
+		lastConfirmedTimeTick: lastConfirmedTimeTick,
+		lastConfirmedManager:  newLastConfirmedManager(lastConfirmedMessageID),
 	}
 }
 
+// AllocateWithBarrier allocates a timestamp with a barrier.
+func (ta *AckManager) AllocateWithBarrier(ctx context.Context, barrierTimeTick uint64) (*Acker, error) {
+	// wait until the lastConfirmedTimeTick is greater than barrierTimeTick.
+	// re-check in a loop: a broadcast only signals progress, not that the barrier has been passed.
+	ta.cond.L.Lock()
+	for ta.lastConfirmedTimeTick <= barrierTimeTick {
+		if err := ta.cond.Wait(ctx); err != nil {
+			return nil, err
+		}
+	}
+	ta.cond.L.Unlock()
+
+	return ta.Allocate(ctx)
+}
+
 // Allocate allocates a timestamp.
 // Concurrent safe to call with Sync and Allocate.
 func (ta *AckManager) Allocate(ctx context.Context) (*Acker, error) {
-	ta.mu.Lock()
-	defer ta.mu.Unlock()
+	ta.cond.L.Lock()
+	defer ta.cond.L.Unlock()
 
 	// allocate one from underlying allocator first.
 	ts, err := resource.Resource().TSOAllocator().Allocate(ctx)
 	if err != nil {
 		return nil, err
 	}
+	ta.lastAllocatedTimeTick = ts
 
 	// create new timestampAck for ack process.
 	// add ts to heap wait for ack.
-	tsWithAck := newAcker(ts, ta.lastConfirmedMessageID)
-	ta.notAckHeap.Push(tsWithAck)
-	return tsWithAck, nil
+	acker := &Acker{
+		acknowledged: atomic.NewBool(false),
+		detail:       newAckDetail(ts, ta.lastConfirmedManager.GetLastConfirmedMessageID()),
+		manager:      ta,
+	}
+	ta.notAckHeap.Push(acker)
+	return acker, nil
 }
 
 // SyncAndGetAcknowledged syncs the ack records with allocator, and get the last all acknowledged info.
@@ -57,33 +91,52 @@ func (ta *AckManager) SyncAndGetAcknowledged(ctx context.Context) ([]*AckDetail,
 	}
 	tsWithAck.Ack(OptSync())
 
-	// update a new snapshot of acknowledged timestamps after sync up.
-	return ta.popUntilLastAllAcknowledged(), nil
+	ta.cond.L.Lock()
+	defer ta.cond.L.Unlock()
+
+	details := ta.acknowledgedDetails
+	ta.acknowledgedDetails = make(sortedDetails, 0, 5)
+	return details, nil
+}
+
+// ack marks the timestamp as acknowledged.
+func (ta *AckManager) ack(acker *Acker) {
+	ta.cond.L.Lock()
+	defer ta.cond.L.Unlock()
+
+	acker.detail.EndTimestamp = ta.lastAllocatedTimeTick
+	ta.ackHeap.Push(acker)
+	ta.popUntilLastAllAcknowledged()
 }
 
 // popUntilLastAllAcknowledged pops the timestamps until the one that all timestamps before it have been acknowledged.
-func (ta *AckManager) popUntilLastAllAcknowledged() sortedDetails {
-	ta.mu.Lock()
-	defer ta.mu.Unlock()
-
+func (ta *AckManager) popUntilLastAllAcknowledged() {
 	// pop all acknowledged timestamps.
-	details := make(sortedDetails, 0, 5)
+	acknowledgedDetails := make(sortedDetails, 0, 5)
 	for ta.notAckHeap.Len() > 0 && ta.notAckHeap.Peek().acknowledged.Load() {
 		ack := ta.notAckHeap.Pop()
-		details = append(details, ack.ackDetail())
+		acknowledgedDetails = append(acknowledgedDetails, ack.ackDetail())
 	}
-	return details
-}
-
-// AdvanceLastConfirmedMessageID update the last confirmed message id.
-func (ta *AckManager) AdvanceLastConfirmedMessageID(msgID message.MessageID) {
-	if msgID == nil {
+	if len(acknowledgedDetails) == 0 {
 		return
 	}
 
-	ta.mu.Lock()
-	if ta.lastConfirmedMessageID == nil || ta.lastConfirmedMessageID.LT(msgID) {
-		ta.lastConfirmedMessageID = msgID
+	// broadcast to notify that the last confirmed time tick has been updated.
+	ta.cond.UnsafeBroadcast()
+
+	// update last confirmed time tick.
+	ta.lastConfirmedTimeTick = acknowledgedDetails[len(acknowledgedDetails)-1].BeginTimestamp
+
+	// pop all ackers whose EndTimestamp is less than lastConfirmedTimeTick.
+	// Every message whose EndTimestamp is less than lastConfirmedTimeTick has been committed into the wal,
+	// so the MessageIDs of those messages are dense and continuous.
+	confirmedDetails := make(sortedDetails, 0, 5)
+	for ta.ackHeap.Len() > 0 && ta.ackHeap.Peek().detail.EndTimestamp < ta.lastConfirmedTimeTick {
+		ack := ta.ackHeap.Pop()
+		confirmedDetails = append(confirmedDetails, ack.ackDetail())
 	}
-	ta.mu.Unlock()
+	ta.lastConfirmedManager.AddConfirmedDetails(confirmedDetails, ta.lastConfirmedTimeTick)
+	// TODO: cache update operation is also performed here.
+
+	ta.acknowledgedDetails = append(ta.acknowledgedDetails, acknowledgedDetails...)
 }
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/builder.go b/internal/streamingnode/server/wal/interceptors/timetick/builder.go
index 0e5d060e61..2fe398a677 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/builder.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/builder.go
@@ -3,6 +3,7 @@ package timetick
 import (
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
 )
 
 var _ interceptors.InterceptorBuilder = (*interceptorBuilder)(nil)
@@ -24,6 +25,7 @@ func (b *interceptorBuilder) Build(param interceptors.InterceptorBuildParam) int
 	go operator.initialize()
 	resource.Resource().TimeTickInspector().RegisterSyncOperator(operator)
 	return &timeTickAppendInterceptor{
-		operator: operator,
+		operator:   operator,
+		txnManager: txn.NewTxnManager(),
 	}
 }
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/inspector/impls.go b/internal/streamingnode/server/wal/interceptors/timetick/inspector/impls.go
index ee7463934b..4182f6804c 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/inspector/impls.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/inspector/impls.go
@@ -3,6 +3,9 @@ package inspector
 import (
 	"time"
 
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/streaming/util/types"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/util/syncutil"
@@ -41,6 +44,7 @@ func (s *timeTickSyncInspectorImpl) MustGetOperator(pChannelInfo types.PChannelI
 
 // RegisterSyncOperator registers a sync operator.
 func (s *timeTickSyncInspectorImpl) RegisterSyncOperator(operator TimeTickSyncOperator) {
+	log.Info("RegisterSyncOperator", zap.String("channel", operator.Channel().Name))
 	_, loaded := s.operators.GetOrInsert(operator.Channel().Name, operator)
 	if loaded {
 		panic("sync operator already exists, critical bug in code")
@@ -49,6 +53,7 @@ func (s *timeTickSyncInspectorImpl) RegisterSyncOperator(operator TimeTickSyncOp
 
 // UnregisterSyncOperator unregisters a sync operator.
 func (s *timeTickSyncInspectorImpl) UnregisterSyncOperator(operator TimeTickSyncOperator) {
+	log.Info("UnregisterSyncOperator", zap.String("channel", operator.Channel().Name))
 	_, loaded := s.operators.GetAndRemove(operator.Channel().Name)
 	if !loaded {
 		panic("sync operator not found, critical bug in code")
diff --git a/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go b/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go
index a081341e38..4786b745f0 100644
--- a/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go
+++ b/internal/streamingnode/server/wal/interceptors/timetick/timetick_interceptor.go
@@ -2,24 +2,25 @@ package timetick
 
 import (
 	"context"
+	"time"
 
 	"github.com/cockroachdb/errors"
 
 	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
-	"github.com/milvus-io/milvus/internal/streamingnode/server/wal"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors"
 	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/ack"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/txn"
+	"github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility"
+	"github.com/milvus-io/milvus/pkg/log"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
 )
 
-var (
-	_ interceptors.InterceptorWithReady           = (*timeTickAppendInterceptor)(nil)
-	_ interceptors.InterceptorWithUnwrapMessageID = (*timeTickAppendInterceptor)(nil)
-)
+var _ interceptors.InterceptorWithReady = (*timeTickAppendInterceptor)(nil)
 
 // timeTickAppendInterceptor is a append interceptor.
 type timeTickAppendInterceptor struct {
-	operator *timeTickSyncOperator
+	operator   *timeTickSyncOperator
+	txnManager *txn.TxnManager
 }
 
 // Ready implements AppendInterceptor.
@@ -28,41 +29,85 @@ func (impl *timeTickAppendInterceptor) Ready() <-chan struct{} {
 }
 
 // Do implements AppendInterceptor.
-func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (message.MessageID, error) {
-	var timetick uint64
-	var msgID message.MessageID
-	var err error
+func (impl *timeTickAppendInterceptor) DoAppend(ctx context.Context, msg message.MutableMessage, append interceptors.Append) (msgID message.MessageID, err error) {
+	var txnSession *txn.TxnSession
 	if msg.MessageType() != message.MessageTypeTimeTick {
-		// Allocate new acker for message.
+		// Allocate new timestamp acker for message.
 		var acker *ack.Acker
-		var err error
-		if acker, err = impl.operator.AckManager().Allocate(ctx); err != nil {
-			return nil, errors.Wrap(err, "allocate timestamp failed")
+		if msg.BarrierTimeTick() == 0 {
+			if acker, err = impl.operator.AckManager().Allocate(ctx); err != nil {
+				return nil, errors.Wrap(err, "allocate timestamp failed")
+			}
+		} else {
+			if acker, err = impl.operator.AckManager().AllocateWithBarrier(ctx, msg.BarrierTimeTick()); err != nil {
+				return nil, errors.Wrap(err, "allocate timestamp with barrier failed")
+			}
 		}
-		defer func() {
-			acker.Ack(ack.OptError(err))
-			impl.operator.AckManager().AdvanceLastConfirmedMessageID(msgID)
-		}()
-
-		// Assign timestamp to message and call append method.
+		// Assign timestamp to message and call the append method.
 		msg = msg.
 			WithTimeTick(acker.Timestamp()). // message assigned with these timetick.
 			WithLastConfirmed(acker.LastConfirmedMessageID()) // start consuming from these message id, the message which timetick greater than current timetick will never be lost.
-		timetick = acker.Timestamp()
-	} else {
-		timetick = msg.TimeTick()
+
+		defer func() {
+			if err != nil {
+				acker.Ack(ack.OptError(err))
+				return
+			}
+			acker.Ack(
+				ack.OptMessageID(msgID),
+				ack.OptTxnSession(txnSession),
+			)
+		}()
 	}
 
-	// append the message into wal.
-	if msgID, err = append(ctx, msg); err != nil {
-		return nil, err
+	switch msg.MessageType() {
+	case message.MessageTypeBeginTxn:
+		if txnSession, msg, err = impl.handleBegin(ctx, msg); err != nil {
+			return nil, err
+		}
+	case message.MessageTypeCommitTxn:
+		if txnSession, err = impl.handleCommit(ctx, msg); err != nil {
+			return nil, err
+		}
+		defer txnSession.CommitDone()
+	case message.MessageTypeRollbackTxn:
+		if txnSession, err = impl.handleRollback(ctx, msg); err != nil {
+			return nil, err
+		}
+		defer txnSession.RollbackDone()
+	case message.MessageTypeTimeTick:
+		// clean up the expired transaction sessions and the transactions that are already done.
+		impl.txnManager.CleanupTxnUntil(msg.TimeTick())
+	default:
+		// handle the transaction body message.
+		if msg.TxnContext() != nil {
+			if txnSession, err = impl.handleTxnMessage(ctx, msg); err != nil {
+				return nil, err
+			}
+			defer func() {
+				if err != nil {
+					txnSession.AddNewMessageFail()
+					return
+				}
+				// perform keepalive for the transaction session if the append succeeds.
+				txnSession.AddNewMessageAndKeepalive(msg.TimeTick())
+			}()
+		}
 	}
 
-	// wrap message id with timetick.
-	return wrapMessageIDWithTimeTick{
-		MessageID: msgID,
-		timetick:  timetick,
-	}, nil
+	// Attach the txn session to the context
+	// so that all interceptors of the append operation can see it.
+	if txnSession != nil {
+		ctx = txn.WithTxnSession(ctx, txnSession)
+	}
+	msgID, err = impl.appendMsg(ctx, msg, append)
+	return
+}
+
+// GracefulClose implements InterceptorWithGracefulClose.
+func (impl *timeTickAppendInterceptor) GracefulClose() {
+	log.Warn("timeTickAppendInterceptor is closing")
+	impl.txnManager.GracefulClose()
 }
 
 // Close implements AppendInterceptor.
@@ -71,13 +116,84 @@ func (impl *timeTickAppendInterceptor) Close() {
 	impl.operator.Close()
 }
 
-func (impl *timeTickAppendInterceptor) UnwrapMessageID(r *wal.AppendResult) {
-	m := r.MessageID.(wrapMessageIDWithTimeTick)
-	r.MessageID = m.MessageID
-	r.TimeTick = m.timetick
+// handleBegin handles the begin transaction message.
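+//
+// A rough sketch of the expected flow, using only names from this package:
+//
+//	session, msg, err := impl.handleBegin(ctx, beginMsg)
+//	// on success, msg now carries session.TxnContext() and the session has
+//	// been moved to the in-flight state via session.BeginDone().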
+func (impl *timeTickAppendInterceptor) handleBegin(ctx context.Context, msg message.MutableMessage) (*txn.TxnSession, message.MutableMessage, error) {
+	beginTxnMsg, err := message.AsMutableBeginTxnMessageV2(msg)
+	if err != nil {
+		return nil, nil, err
+	}
+	// Begin transaction will generate a txn context.
+	session, err := impl.txnManager.BeginNewTxn(ctx, msg.TimeTick(), time.Duration(beginTxnMsg.Header().KeepaliveMilliseconds)*time.Millisecond)
+	if err != nil {
+		// BeginNewTxn returns a nil session on failure, so guard the rollback.
+		if session != nil {
+			session.BeginRollback()
+		}
+		return nil, nil, err
+	}
+	session.BeginDone()
+	return session, msg.WithTxnContext(session.TxnContext()), nil
 }
 
-type wrapMessageIDWithTimeTick struct {
-	message.MessageID
-	timetick uint64
+// handleCommit handles the commit transaction message.
+func (impl *timeTickAppendInterceptor) handleCommit(ctx context.Context, msg message.MutableMessage) (*txn.TxnSession, error) {
+	commitTxnMsg, err := message.AsMutableCommitTxnMessageV2(msg)
+	if err != nil {
+		return nil, err
+	}
+	session, err := impl.txnManager.GetSessionOfTxn(commitTxnMsg.TxnContext().TxnID)
+	if err != nil {
+		return nil, err
+	}
+
+	// Start committing the transaction and wait for all in-flight messages to be sent.
+	if err = session.RequestCommitAndWait(ctx, msg.TimeTick()); err != nil {
+		return nil, err
+	}
+	return session, nil
+}
+
+// handleRollback handles the rollback transaction message.
+func (impl *timeTickAppendInterceptor) handleRollback(ctx context.Context, msg message.MutableMessage) (session *txn.TxnSession, err error) {
+	rollbackTxnMsg, err := message.AsMutableRollbackTxnMessageV2(msg)
+	if err != nil {
+		return nil, err
+	}
+	session, err = impl.txnManager.GetSessionOfTxn(rollbackTxnMsg.TxnContext().TxnID)
+	if err != nil {
+		return nil, err
+	}
+
+	// Start rolling back the transaction.
+	if err = session.RequestRollback(ctx, msg.TimeTick()); err != nil {
+		return nil, err
+	}
+	return session, nil
+}
+
+// handleTxnMessage handles the transaction body message.
+func (impl *timeTickAppendInterceptor) handleTxnMessage(ctx context.Context, msg message.MutableMessage) (session *txn.TxnSession, err error) {
+	txnContext := msg.TxnContext()
+	session, err = impl.txnManager.GetSessionOfTxn(txnContext.TxnID)
+	if err != nil {
+		return nil, err
+	}
+	// Add the message to the transaction.
+	if err = session.AddNewMessage(ctx, msg.TimeTick()); err != nil {
+		return nil, err
+	}
+	return session, nil
+}
+
+// appendMsg is a helper function to append a message.
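+//
+// Besides appending, it publishes the assigned time tick and txn context into
+// the per-append ExtraAppendResult that the caller is assumed to have placed
+// into ctx beforehand (see wal/utility), roughly:
+//
+//	ctx = utility.WithExtraAppendResult(ctx, &utility.ExtraAppendResult{})
+//	msgID, err := impl.appendMsg(ctx, msg, append)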
+func (impl *timeTickAppendInterceptor) appendMsg( + ctx context.Context, + msg message.MutableMessage, + append func(context.Context, message.MutableMessage) (message.MessageID, error), +) (message.MessageID, error) { + msgID, err := append(ctx, msg) + if err != nil { + return nil, err + } + + utility.AttachAppendResultTimeTick(ctx, msg.TimeTick()) + utility.AttachAppendResultTxnContext(ctx, msg.TxnContext()) + return msgID, nil } diff --git a/internal/streamingnode/server/wal/interceptors/timetick/timetick_sync_operator.go b/internal/streamingnode/server/wal/interceptors/timetick/timetick_sync_operator.go index d3e3032c01..ff091a7059 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/timetick_sync_operator.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/timetick_sync_operator.go @@ -11,6 +11,7 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/ack" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/inspector" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" @@ -31,7 +32,7 @@ func newTimeTickSyncOperator(param interceptors.InterceptorBuildParam) *timeTick pchannel: param.WALImpls.Channel(), ready: make(chan struct{}), interceptorBuildParam: param, - ackManager: ack.NewAckManager(), + ackManager: nil, ackDetails: ack.NewAckDetails(), sourceID: paramtable.GetNodeID(), timeTickNotifier: inspector.NewTimeTickNotifier(), @@ -136,10 +137,13 @@ func (impl *timeTickSyncOperator) blockUntilSyncTimeTickReady() error { lastErr = errors.Wrap(err, "allocate timestamp failed") continue } - if err := impl.sendPersistentTsMsg(impl.ctx, ts, nil, underlyingWALImpls.Append); err != nil { + msgID, err := impl.sendPersistentTsMsg(impl.ctx, ts, nil, underlyingWALImpls.Append) + if err != nil { lastErr = errors.Wrap(err, "send first timestamp message failed") continue } + // initialize ack manager. + impl.ackManager = ack.NewAckManager(ts, msgID) break } // interceptor is ready now. @@ -190,11 +194,11 @@ func (impl *timeTickSyncOperator) sendTsMsg(ctx context.Context, appender func(c if impl.ackDetails.IsNoPersistedMessage() { // there's no persisted message, so no need to send persistent time tick message. - // only update it to notify the scanner. - return impl.notifyNoPersistentTsMsg(ts) + return impl.sendNoPersistentTsMsg(ctx, ts, appender) } // otherwise, send persistent time tick message. - return impl.sendPersistentTsMsg(ctx, ts, lastConfirmedMessageID, appender) + _, err := impl.sendPersistentTsMsg(ctx, ts, lastConfirmedMessageID, appender) + return err } // sendPersistentTsMsg sends persistent time tick message to wal. @@ -202,16 +206,16 @@ func (impl *timeTickSyncOperator) sendPersistentTsMsg(ctx context.Context, ts uint64, lastConfirmedMessageID message.MessageID, appender func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error), -) error { +) (message.MessageID, error) { msg, err := NewTimeTickMsg(ts, lastConfirmedMessageID, impl.sourceID) if err != nil { - return errors.Wrap(err, "at build time tick msg") + return nil, errors.Wrap(err, "at build time tick msg") } // Append it to wal. 
msgID, err := appender(ctx, msg) if err != nil { - return errors.Wrapf(err, + return nil, errors.Wrapf(err, "append time tick msg to wal failed, timestamp: %d, previous message counter: %d", impl.ackDetails.LastAllAcknowledgedTimestamp(), impl.ackDetails.Len(), @@ -220,19 +224,40 @@ func (impl *timeTickSyncOperator) sendPersistentTsMsg(ctx context.Context, // Ack details has been committed to wal, clear it. impl.ackDetails.Clear() - // Update last confirmed message id, so that the ack manager can use it for next time tick ack allocation. - impl.ackManager.AdvanceLastConfirmedMessageID(msgID) // Update last time tick message id and time tick. impl.timeTickNotifier.Update(inspector.TimeTickInfo{ MessageID: msgID, TimeTick: ts, }) - return nil + return msgID, nil } -// notifyNoPersistentTsMsg sends no persistent time tick message. -func (impl *timeTickSyncOperator) notifyNoPersistentTsMsg(ts uint64) error { +// sendNoPersistentTsMsg sends no persistent time tick message to wal. +func (impl *timeTickSyncOperator) sendNoPersistentTsMsg(ctx context.Context, ts uint64, appender func(ctx context.Context, msg message.MutableMessage) (message.MessageID, error)) error { + msg, err := NewTimeTickMsg(ts, nil, impl.sourceID) + if err != nil { + return errors.Wrap(err, "at build time tick msg when send no persist msg") + } + + // with the hint of not persisted message, the underlying wal will not persist it. + // but the interceptors will still be triggered. + ctx = utility.WithNotPersisted(ctx, &utility.NotPersistedHint{ + MessageID: impl.timeTickNotifier.Get().MessageID, + }) + + // Append it to wal. + _, err = appender(ctx, msg) + if err != nil { + return errors.Wrapf(err, + "append no persist time tick msg to wal failed, timestamp: %d, previous message counter: %d", + impl.ackDetails.LastAllAcknowledgedTimestamp(), + impl.ackDetails.Len(), + ) + } + + // Ack details has been committed to wal, clear it. impl.ackDetails.Clear() + // Only update time tick. impl.timeTickNotifier.OnlyUpdateTs(ts) return nil } diff --git a/internal/streamingnode/server/wal/interceptors/timetick/timetick_sync_operator_test.go b/internal/streamingnode/server/wal/interceptors/timetick/timetick_sync_operator_test.go index a428717811..da4ee8ee66 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/timetick_sync_operator_test.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/timetick_sync_operator_test.go @@ -12,6 +12,8 @@ import ( "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/internal/streamingnode/server/wal" "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/interceptors/timetick/ack" + "github.com/milvus-io/milvus/internal/streamingnode/server/wal/utility" "github.com/milvus-io/milvus/pkg/mocks/streaming/mock_walimpls" "github.com/milvus-io/milvus/pkg/streaming/util/message" "github.com/milvus-io/milvus/pkg/streaming/util/types" @@ -50,6 +52,14 @@ func TestTimeTickSyncOperator(t *testing.T) { operator.initialize() <-operator.Ready() l := mock_wal.NewMockWAL(t) + l.EXPECT().Append(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, mm message.MutableMessage) (*types.AppendResult, error) { + hint := utility.GetNotPersisted(ctx) + assert.NotNil(t, hint) + return &types.AppendResult{ + MessageID: hint.MessageID, + TimeTick: mm.TimeTick(), + }, nil + }) walFuture.Set(l) // Test the sync operation, but there is no message to sync. 
@@ -77,7 +87,8 @@ func TestTimeTickSyncOperator(t *testing.T) {
 	operator.Sync(ctx)
 
 	// After ack, a wal operation will be trigger.
-	acker.Ack()
+	acker.Ack(ack.OptMessageID(msgID), ack.OptTxnSession(nil))
+	l.EXPECT().Append(mock.Anything, mock.Anything).Unset()
 	l.EXPECT().Append(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, mm message.MutableMessage) (*types.AppendResult, error) {
 		ts, _ := resource.Resource().TSOAllocator().Allocate(ctx)
 		return &types.AppendResult{
diff --git a/internal/streamingnode/server/wal/interceptors/txn/session.go b/internal/streamingnode/server/wal/interceptors/txn/session.go
new file mode 100644
index 0000000000..3f5d385742
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/txn/session.go
@@ -0,0 +1,262 @@
+package txn
+
+import (
+	"context"
+	"sync"
+
+	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/util/tsoutil"
+)
+
+type txnSessionKeyType int
+
+var txnSessionKeyValue txnSessionKeyType = 1
+
+// TxnSession is a session for a transaction.
+type TxnSession struct {
+	mu sync.Mutex
+
+	lastTimetick     uint64             // the last time tick of the session.
+	expired          bool               // whether the transaction has already been marked expired once.
+	txnContext       message.TxnContext // the transaction context of the session.
+	inFlightCount    int                // the count of in-flight messages of the session.
+	state            message.TxnState   // the state of the session.
+	doneWait         chan struct{}      // the channel for waiting until the transaction is done.
+	rollback         bool               // whether the transaction has been rolled back.
+	cleanupCallbacks []func()           // the cleanup callback functions for the session.
+}
+
+// TxnContext returns the txn context of the session.
+func (s *TxnSession) TxnContext() message.TxnContext {
+	return s.txnContext
+}
+
+// BeginDone marks the transaction as in flight.
+func (s *TxnSession) BeginDone() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.state != message.TxnStateBegin {
+		// unreachable code here.
+		panic("invalid state for in flight")
+	}
+	s.state = message.TxnStateInFlight
+}
+
+// BeginRollback marks the transaction as rolled back at the begin state.
+func (s *TxnSession) BeginRollback() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.state != message.TxnStateBegin {
+		// unreachable code here.
+		panic("invalid state for rollback")
+	}
+	s.state = message.TxnStateRollbacked
+}
+
+// AddNewMessage adds a new message to the session.
+func (s *TxnSession) AddNewMessage(ctx context.Context, timetick uint64) error {
+	// if the txn is expired, return an error.
+	if err := s.checkIfExpired(timetick); err != nil {
+		return err
+	}
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.state != message.TxnStateInFlight {
+		return status.NewInvalidTransactionState("AddNewMessage", message.TxnStateInFlight, s.state)
+	}
+	s.inFlightCount++
+	return nil
+}
+
+// AddNewMessageAndKeepalive decreases the in-flight count of the session and keeps the session alive.
+// It notifies the doneWait channel if the in-flight count reaches 0 and a commit or rollback is waiting.
+func (s *TxnSession) AddNewMessageAndKeepalive(timetick uint64) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// refresh the lease here.
+	if s.lastTimetick < timetick {
+		s.lastTimetick = timetick
+	}
+	s.inFlightCount--
+	if s.doneWait != nil && s.inFlightCount == 0 {
+		close(s.doneWait)
+	}
+}
+
+// AddNewMessageFail decreases the in-flight count of the session but does not refresh the lease.
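+//
+// Together with AddNewMessage and AddNewMessageAndKeepalive this forms the
+// in-flight accounting of a session; a body-message append is expected to
+// look roughly like:
+//
+//	if err := session.AddNewMessage(ctx, tt); err != nil { ... } // in-flight +1
+//	msgID, err := append(ctx, msg)
+//	if err != nil {
+//		session.AddNewMessageFail() // in-flight -1, lease not refreshed
+//		return nil, err
+//	}
+//	session.AddNewMessageAndKeepalive(tt) // in-flight -1, lease refreshed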
+func (s *TxnSession) AddNewMessageFail() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.inFlightCount--
+	if s.doneWait != nil && s.inFlightCount == 0 {
+		close(s.doneWait)
+	}
+}
+
+// IsExpiredOrDone checks if the session is expired or done.
+func (s *TxnSession) IsExpiredOrDone(ts uint64) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	return s.isExpiredOrDone(ts)
+}
+
+// isExpiredOrDone checks if the session is expired or done.
+func (s *TxnSession) isExpiredOrDone(ts uint64) bool {
+	// A timed-out txn or a rolled-back/committed txn should be cleared.
+	// An OnCommit or OnRollback session should not be cleared before its timeout,
+	// to avoid the session cleanup callbacks being called too early.
+	return s.expiredTimeTick() <= ts || s.state == message.TxnStateRollbacked || s.state == message.TxnStateCommitted
+}
+
+// expiredTimeTick returns the expired time tick of the session.
+func (s *TxnSession) expiredTimeTick() uint64 {
+	return tsoutil.AddPhysicalDurationOnTs(s.lastTimetick, s.txnContext.Keepalive)
+}
+
+// RequestCommitAndWait requests a commit of the transaction and waits for all in-flight messages to be sent.
+func (s *TxnSession) RequestCommitAndWait(ctx context.Context, timetick uint64) error {
+	waitCh, err := s.getDoneChan(timetick, message.TxnStateOnCommit)
+	if err != nil {
+		return err
+	}
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-waitCh:
+		return nil
+	}
+}
+
+// CommitDone marks the transaction as committed.
+func (s *TxnSession) CommitDone() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.state != message.TxnStateOnCommit {
+		// unreachable code here.
+		panic("invalid state for commit done")
+	}
+	s.state = message.TxnStateCommitted
+	s.cleanup()
+}
+
+// RequestRollback rolls back the transaction.
+func (s *TxnSession) RequestRollback(ctx context.Context, timetick uint64) error {
+	waitCh, err := s.getDoneChan(timetick, message.TxnStateOnRollback)
+	if err != nil {
+		return err
+	}
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-waitCh:
+		return nil
+	}
+}
+
+// RollbackDone marks the transaction as rolled back.
+func (s *TxnSession) RollbackDone() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.state != message.TxnStateOnRollback {
+		// unreachable code here.
+		panic("invalid state for rollback done")
+	}
+	s.state = message.TxnStateRollbacked
+	s.cleanup()
+}
+
+// RegisterCleanup registers a cleanup function for the session.
+// It will be called when the session is expired or done.
+// !!! A committed/rolled-back or expired session will never be seen by other components,
+// so the cleanup function will always be called.
+func (s *TxnSession) RegisterCleanup(f func(), ts uint64) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.isExpiredOrDone(ts) {
+		panic("unreachable code: register cleanup for expired or done session")
+	}
+	s.cleanupCallbacks = append(s.cleanupCallbacks, f)
+}
+
+// Cleanup cleans up the session.
+func (s *TxnSession) Cleanup() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.cleanup()
+}
+
+// cleanup calls the cleanup functions.
+func (s *TxnSession) cleanup() {
+	for _, f := range s.cleanupCallbacks {
+		f()
+	}
+	s.cleanupCallbacks = nil
+}
+
+// getDoneChan returns the channel for waiting until the transaction is done.
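+//
+// The returned channel is closed once the in-flight count drains to zero, so
+// a commit request waits roughly like this (mirroring RequestCommitAndWait):
+//
+//	ch, err := s.getDoneChan(tt, message.TxnStateOnCommit)
+//	if err != nil {
+//		return err
+//	}
+//	select {
+//	case <-ctx.Done():
+//		return ctx.Err() // the txn stays OnCommit and is cleaned up after it expires
+//	case <-ch:
+//		return nil // all in-flight messages have been appended
+//	}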
+func (s *TxnSession) getDoneChan(timetick uint64, state message.TxnState) (<-chan struct{}, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if err := s.checkIfExpired(timetick); err != nil { + return nil, err + } + + if s.state != message.TxnStateInFlight { + return nil, status.NewInvalidTransactionState("GetWaitChan", message.TxnStateInFlight, s.state) + } + s.state = state + + if s.doneWait == nil { + s.doneWait = make(chan struct{}) + if s.inFlightCount == 0 { + close(s.doneWait) + } + } + return s.doneWait, nil +} + +// checkIfExpired checks if the session is expired. +func (s *TxnSession) checkIfExpired(tt uint64) error { + if s.expired { + return status.NewTransactionExpired("some message has been expired, expired at %d, current %d", s.expiredTimeTick(), tt) + } + expiredTimeTick := s.expiredTimeTick() + if tt >= expiredTimeTick { + // once the session is expired, it will never be active again. + s.expired = true + return status.NewTransactionExpired("transaction expired at %d, current %d", expiredTimeTick, tt) + } + return nil +} + +// WithTxnSession returns a new context with the TxnSession. +func WithTxnSession(ctx context.Context, session *TxnSession) context.Context { + return context.WithValue(ctx, txnSessionKeyValue, session) +} + +// GetTxnSessionFromContext returns the TxnSession from the context. +func GetTxnSessionFromContext(ctx context.Context) *TxnSession { + if ctx == nil { + return nil + } + if v := ctx.Value(txnSessionKeyValue); v != nil { + if session, ok := v.(*TxnSession); ok { + return session + } + } + return nil +} diff --git a/internal/streamingnode/server/wal/interceptors/txn/session_test.go b/internal/streamingnode/server/wal/interceptors/txn/session_test.go new file mode 100644 index 0000000000..30e067c661 --- /dev/null +++ b/internal/streamingnode/server/wal/interceptors/txn/session_test.go @@ -0,0 +1,184 @@ +package txn + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + + "github.com/milvus-io/milvus/internal/streamingnode/server/resource" + "github.com/milvus-io/milvus/internal/util/streamingutil/status" + "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" +) + +func TestMain(m *testing.M) { + paramtable.Init() + m.Run() +} + +func TestSession(t *testing.T) { + resource.InitForTest(t) + ctx := context.Background() + + m := NewTxnManager() + session, err := m.BeginNewTxn(ctx, 0, 10*time.Millisecond) + assert.NotNil(t, session) + assert.NoError(t, err) + + // Test Begin + assert.Equal(t, message.TxnStateBegin, session.state) + assert.False(t, session.IsExpiredOrDone(0)) + expiredTs := tsoutil.AddPhysicalDurationOnTs(0, 10*time.Millisecond) + assert.True(t, session.IsExpiredOrDone(expiredTs)) + session.BeginRollback() + assert.Equal(t, message.TxnStateRollbacked, session.state) + assert.True(t, session.IsExpiredOrDone(0)) + + session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond) + assert.NoError(t, err) + session.BeginDone() + assert.Equal(t, message.TxnStateInFlight, session.state) + assert.False(t, session.IsExpiredOrDone(0)) + + // Test add new message + err = session.AddNewMessage(ctx, expiredTs) + assert.Error(t, err) + serr := status.AsStreamingError(err) + assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code) + + // Test add new message after expire, should expired forever. 
+ err = session.AddNewMessage(ctx, 0) + assert.Error(t, err) + serr = status.AsStreamingError(err) + assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code) + + session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond) + assert.NoError(t, err) + session.BeginDone() + assert.NoError(t, err) + err = session.AddNewMessage(ctx, 0) + assert.NoError(t, err) + session.AddNewMessageAndKeepalive(0) + + // Test Commit. + err = session.RequestCommitAndWait(ctx, 0) + assert.NoError(t, err) + assert.Equal(t, message.TxnStateOnCommit, session.state) + session.CommitDone() + assert.Equal(t, message.TxnStateCommitted, session.state) + + // Test Commit timeout. + session, err = m.BeginNewTxn(ctx, 0, 10*time.Millisecond) + assert.NoError(t, err) + session.BeginDone() + err = session.AddNewMessage(ctx, 0) + assert.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + err = session.RequestCommitAndWait(ctx, 0) + assert.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + + // Test Commit Expired + err = session.RequestCommitAndWait(ctx, expiredTs) + assert.Error(t, err) + serr = status.AsStreamingError(err) + assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code) + + // Test Rollback + session, _ = m.BeginNewTxn(context.Background(), 0, 10*time.Millisecond) + session.BeginDone() + // Rollback expired. + err = session.RequestRollback(context.Background(), expiredTs) + assert.Error(t, err) + serr = status.AsStreamingError(err) + assert.Equal(t, streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, serr.Code) + + // Rollback success + session, _ = m.BeginNewTxn(context.Background(), 0, 10*time.Millisecond) + session.BeginDone() + err = session.RequestRollback(context.Background(), 0) + assert.NoError(t, err) + assert.Equal(t, message.TxnStateOnRollback, session.state) +} + +func TestManager(t *testing.T) { + resource.InitForTest(t) + m := NewTxnManager() + + wg := &sync.WaitGroup{} + + wg.Add(20) + count := atomic.NewInt32(20) + for i := 0; i < 20; i++ { + go func(i int) { + defer wg.Done() + session, err := m.BeginNewTxn(context.Background(), 0, time.Duration(i+1)*time.Millisecond) + assert.NoError(t, err) + assert.NotNil(t, session) + session.BeginDone() + + session, err = m.GetSessionOfTxn(session.TxnContext().TxnID) + assert.NoError(t, err) + assert.NotNil(t, session) + + session.RegisterCleanup(func() { + count.Dec() + }, 0) + if i%3 == 0 { + err := session.RequestCommitAndWait(context.Background(), 0) + session.CommitDone() + assert.NoError(t, err) + } else if i%3 == 1 { + err := session.RequestRollback(context.Background(), 0) + assert.NoError(t, err) + session.RollbackDone() + } + }(i) + } + wg.Wait() + + closed := make(chan struct{}) + go func() { + m.GracefulClose() + close(closed) + }() + + select { + case <-closed: + t.Errorf("manager should not be closed") + case <-time.After(10 * time.Millisecond): + } + + expiredTs := tsoutil.AddPhysicalDurationOnTs(0, 10*time.Millisecond) + m.CleanupTxnUntil(expiredTs) + select { + case <-closed: + t.Errorf("manager should not be closed") + case <-time.After(10 * time.Millisecond): + } + + m.CleanupTxnUntil(tsoutil.AddPhysicalDurationOnTs(0, 20*time.Millisecond)) + select { + case <-closed: + case <-time.After(10 * time.Millisecond): + t.Errorf("manager should be closed") + } + + assert.Equal(t, int32(0), count.Load()) +} + +func TestWithCo(t *testing.T) { + session := &TxnSession{} + ctx := 
WithTxnSession(context.Background(), session)
+
+	session = GetTxnSessionFromContext(ctx)
+	assert.NotNil(t, session)
+}
diff --git a/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go b/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go
new file mode 100644
index 0000000000..507369cec4
--- /dev/null
+++ b/internal/streamingnode/server/wal/interceptors/txn/txn_manager.go
@@ -0,0 +1,106 @@
+package txn
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
+	"github.com/milvus-io/milvus/internal/util/streamingutil/status"
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+)
+
+// NewTxnManager creates a new transaction manager.
+func NewTxnManager() *TxnManager {
+	return &TxnManager{
+		mu:       sync.Mutex{},
+		sessions: make(map[message.TxnID]*TxnSession),
+		closed:   nil,
+	}
+}
+
+// TxnManager is the manager of transactions.
+// We don't support cross-wal transactions for now, and
+// a transaction does not survive after its wal is transferred to another streaming node.
+type TxnManager struct {
+	mu       sync.Mutex
+	sessions map[message.TxnID]*TxnSession
+	closed   chan struct{}
+}
+
+// BeginNewTxn starts a new transaction with a session.
+// A transaction only works on a single streaming node; once the wal is transferred to another node,
+// the transaction is treated as expired (rolled back), the user gets an expired error and should retry.
+func (m *TxnManager) BeginNewTxn(ctx context.Context, timetick uint64, keepalive time.Duration) (*TxnSession, error) {
+	id, err := resource.Resource().IDAllocator().Allocate(ctx)
+	if err != nil {
+		return nil, err
+	}
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	// The manager is on graceful shutdown.
+	// Avoid creating new transactions.
+	if m.closed != nil {
+		return nil, status.NewTransactionExpired("manager closed")
+	}
+	session := &TxnSession{
+		mu:           sync.Mutex{},
+		lastTimetick: timetick,
+		txnContext: message.TxnContext{
+			TxnID:     message.TxnID(id),
+			Keepalive: keepalive,
+		},
+		inFlightCount: 0,
+		state:         message.TxnStateBegin,
+		doneWait:      nil,
+		rollback:      false,
+	}
+
+	m.sessions[session.TxnContext().TxnID] = session
+	return session, nil
+}
+
+// CleanupTxnUntil cleans up the transactions until the specified timestamp.
+func (m *TxnManager) CleanupTxnUntil(ts uint64) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	for id, session := range m.sessions {
+		if session.IsExpiredOrDone(ts) {
+			session.Cleanup()
+			delete(m.sessions, id)
+		}
+	}
+
+	// If the manager is on graceful shutdown and all transactions are cleaned up.
+	if len(m.sessions) == 0 && m.closed != nil {
+		close(m.closed)
+	}
+}
+
+// GetSessionOfTxn returns the session of the transaction.
+func (m *TxnManager) GetSessionOfTxn(id message.TxnID) (*TxnSession, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	session, ok := m.sessions[id]
+	if !ok {
+		return nil, status.NewTransactionExpired("not found in manager")
+	}
+	return session, nil
+}
+
+// GracefulClose waits for all transactions to be cleaned up.
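+//
+// Note that GracefulClose arms the closed channel and returns only after
+// CleanupTxnUntil (or an already-empty session map) closes it, so a shutdown
+// is expected to keep time ticks flowing, roughly:
+//
+//	go m.GracefulClose()              // rejects further BeginNewTxn calls
+//	m.CleanupTxnUntil(latestTimeTick) // unblocks GracefulClose once all sessions are gone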
+func (m *TxnManager) GracefulClose() { + m.mu.Lock() + if m.closed == nil { + m.closed = make(chan struct{}) + if len(m.sessions) == 0 { + close(m.closed) + } + } + m.mu.Unlock() + + <-m.closed +} diff --git a/internal/streamingnode/server/wal/scanner.go b/internal/streamingnode/server/wal/scanner.go index db80346355..1c5b5aade9 100644 --- a/internal/streamingnode/server/wal/scanner.go +++ b/internal/streamingnode/server/wal/scanner.go @@ -17,7 +17,7 @@ var ErrUpstreamClosed = errors.New("upstream closed") // ReadOption is the option for reading records from the wal. type ReadOption struct { DeliverPolicy options.DeliverPolicy - MessageFilter MessageFilter + MessageFilter []options.DeliverFilter MesasgeHandler MessageHandler // message handler for message processing. // If the message handler is nil (no redundant operation need to apply), // the default message handler will be used, and the receiver will be returned from Chan. diff --git a/internal/streamingnode/server/wal/utility/context.go b/internal/streamingnode/server/wal/utility/context.go new file mode 100644 index 0000000000..8b9453e366 --- /dev/null +++ b/internal/streamingnode/server/wal/utility/context.go @@ -0,0 +1,66 @@ +package utility + +import ( + "context" + + "google.golang.org/protobuf/types/known/anypb" + + "github.com/milvus-io/milvus/pkg/streaming/util/message" +) + +// walCtxKey is the key type of extra append result. +type walCtxKey int + +var ( + extraAppendResultValue walCtxKey = 1 + notPersistedValue walCtxKey = 2 +) + +// ExtraAppendResult is the extra append result. +type ExtraAppendResult struct { + TimeTick uint64 + TxnCtx *message.TxnContext + Extra *anypb.Any +} + +// NotPersistedHint is the hint of not persisted message. +type NotPersistedHint struct { + MessageID message.MessageID // The reused MessageID. 
+} + +// WithNotPersisted set not persisted message to context +func WithNotPersisted(ctx context.Context, hint *NotPersistedHint) context.Context { + return context.WithValue(ctx, notPersistedValue, hint) +} + +// GetNotPersisted get not persisted message from context +func GetNotPersisted(ctx context.Context) *NotPersistedHint { + val := ctx.Value(notPersistedValue) + if val == nil { + return nil + } + return val.(*NotPersistedHint) +} + +// WithExtraAppendResult set extra to context +func WithExtraAppendResult(ctx context.Context, r *ExtraAppendResult) context.Context { + return context.WithValue(ctx, extraAppendResultValue, r) +} + +// AttachAppendResultExtra set extra to context +func AttachAppendResultExtra(ctx context.Context, extra *anypb.Any) { + result := ctx.Value(extraAppendResultValue) + result.(*ExtraAppendResult).Extra = extra +} + +// AttachAppendResultTimeTick set time tick to context +func AttachAppendResultTimeTick(ctx context.Context, timeTick uint64) { + result := ctx.Value(extraAppendResultValue) + result.(*ExtraAppendResult).TimeTick = timeTick +} + +// AttachAppendResultTxnContext set txn context to context +func AttachAppendResultTxnContext(ctx context.Context, txnCtx *message.TxnContext) { + result := ctx.Value(extraAppendResultValue) + result.(*ExtraAppendResult).TxnCtx = txnCtx +} diff --git a/internal/streamingnode/server/wal/utility/txn_buffer.go b/internal/streamingnode/server/wal/utility/txn_buffer.go new file mode 100644 index 0000000000..647d14ccdc --- /dev/null +++ b/internal/streamingnode/server/wal/utility/txn_buffer.go @@ -0,0 +1,163 @@ +package utility + +import ( + "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/message" +) + +// NewTxnBuffer creates a new txn buffer. +func NewTxnBuffer(logger *log.MLogger) *TxnBuffer { + return &TxnBuffer{ + logger: logger, + builders: make(map[message.TxnID]*message.ImmutableTxnMessageBuilder), + } +} + +// TxnBuffer is a buffer for txn messages. +type TxnBuffer struct { + logger *log.MLogger + builders map[message.TxnID]*message.ImmutableTxnMessageBuilder +} + +// HandleImmutableMessages handles immutable messages. +// The timetick of msgs should be in ascending order, and the timetick of all messages is less than or equal to ts. +// Hold the uncommitted txn messages until the commit or rollback message comes and pop the committed txn messages. +func (b *TxnBuffer) HandleImmutableMessages(msgs []message.ImmutableMessage, ts uint64) []message.ImmutableMessage { + result := make([]message.ImmutableMessage, 0, len(msgs)) + for _, msg := range msgs { + // Not a txn message, can be consumed right now. + if msg.TxnContext() == nil { + result = append(result, msg) + continue + } + switch msg.MessageType() { + case message.MessageTypeBeginTxn: + b.handleBeginTxn(msg) + case message.MessageTypeCommitTxn: + if newTxnMsg := b.handleCommitTxn(msg); newTxnMsg != nil { + result = append(result, newTxnMsg) + } + case message.MessageTypeRollbackTxn: + b.handleRollbackTxn(msg) + default: + b.handleTxnBodyMessage(msg) + } + } + b.clearExpiredTxn(ts) + return result +} + +// handleBeginTxn handles begin txn message. 
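+//
+// In the expected sequence begin -> body... -> commit (ordered by time tick),
+// only the commit pops a combined txn message out of HandleImmutableMessages:
+//
+//	out := b.HandleImmutableMessages([]message.ImmutableMessage{begin, body}, ts1) // len(out) == 0
+//	out = b.HandleImmutableMessages([]message.ImmutableMessage{commit}, ts2)       // len(out) == 1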
+func (b *TxnBuffer) handleBeginTxn(msg message.ImmutableMessage) {
+	beginMsg, err := message.AsImmutableBeginTxnMessageV2(msg)
+	if err != nil {
+		// use the original msg here; beginMsg is not valid when the conversion fails.
+		b.logger.DPanic(
+			"failed to convert message to begin txn message, it's a critical error",
+			zap.Int64("txnID", int64(msg.TxnContext().TxnID)),
+			zap.Any("messageID", msg.MessageID()),
+			zap.Error(err))
+		return
+	}
+	if _, ok := b.builders[beginMsg.TxnContext().TxnID]; ok {
+		b.logger.Warn(
+			"txn id already exists, so ignore the repeated begin txn message",
+			zap.Int64("txnID", int64(beginMsg.TxnContext().TxnID)),
+			zap.Any("messageID", beginMsg.MessageID()),
+		)
+		return
+	}
+	b.builders[beginMsg.TxnContext().TxnID] = message.NewImmutableTxnMessageBuilder(beginMsg)
+}
+
+// handleCommitTxn handles commit txn message.
+func (b *TxnBuffer) handleCommitTxn(msg message.ImmutableMessage) message.ImmutableMessage {
+	commitMsg, err := message.AsImmutableCommitTxnMessageV2(msg)
+	if err != nil {
+		// use the original msg here; commitMsg is not valid when the conversion fails.
+		b.logger.DPanic(
+			"failed to convert message to commit txn message, it's a critical error",
+			zap.Int64("txnID", int64(msg.TxnContext().TxnID)),
+			zap.Any("messageID", msg.MessageID()),
+			zap.Error(err))
+		return nil
+	}
+	builder, ok := b.builders[commitMsg.TxnContext().TxnID]
+	if !ok {
+		b.logger.Warn(
+			"txn id does not exist, it may be a repeated commit message, so ignore it",
+			zap.Int64("txnID", int64(commitMsg.TxnContext().TxnID)),
+			zap.Any("messageID", commitMsg.MessageID()),
+		)
+		return nil
+	}
+
+	// build the txn message and remove it from the buffer.
+	txnMsg, err := builder.Build(commitMsg)
+	delete(b.builders, commitMsg.TxnContext().TxnID)
+	if err != nil {
+		b.logger.Warn(
+			"failed to build txn message, it's a critical error, some data is lost",
+			zap.Int64("txnID", int64(commitMsg.TxnContext().TxnID)),
+			zap.Any("messageID", commitMsg.MessageID()),
+			zap.Error(err))
+		return nil
+	}
+	b.logger.Debug(
+		"the txn is committed",
+		zap.Int64("txnID", int64(commitMsg.TxnContext().TxnID)),
+		zap.Any("messageID", commitMsg.MessageID()),
+	)
+	return txnMsg
+}
+
+// handleRollbackTxn handles rollback txn message.
+func (b *TxnBuffer) handleRollbackTxn(msg message.ImmutableMessage) {
+	rollbackMsg, err := message.AsImmutableRollbackTxnMessageV2(msg)
+	if err != nil {
+		// use the original msg here; rollbackMsg is not valid when the conversion fails.
+		b.logger.DPanic(
+			"failed to convert message to rollback txn message, it's a critical error",
+			zap.Int64("txnID", int64(msg.TxnContext().TxnID)),
+			zap.Any("messageID", msg.MessageID()),
+			zap.Error(err))
+		return
+	}
+	b.logger.Debug(
+		"the txn is rolled back, so drop the txn from the buffer",
+		zap.Int64("txnID", int64(rollbackMsg.TxnContext().TxnID)),
+		zap.Any("messageID", rollbackMsg.MessageID()),
+	)
+	// just drop the txn from the buffer.
+	delete(b.builders, rollbackMsg.TxnContext().TxnID)
+}
+
+// handleTxnBodyMessage handles txn body message.
+func (b *TxnBuffer) handleTxnBodyMessage(msg message.ImmutableMessage) {
+	builder, ok := b.builders[msg.TxnContext().TxnID]
+	if !ok {
+		b.logger.Warn(
+			"txn id does not exist, so ignore the body message",
+			zap.Int64("txnID", int64(msg.TxnContext().TxnID)),
+			zap.Any("messageID", msg.MessageID()),
+		)
+		return
+	}
+	builder.Add(msg)
+}
+
+// clearExpiredTxn clears the expired txn.
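+//
+// A pending builder is assumed to expire once the keepalive deadline derived
+// from its last buffered message passes; e.g. with Keepalive = 201ms and the
+// last body message at baseTso+300ms, the txn is dropped as soon as
+// ts >= baseTso+501ms (see the txn buffer tests below).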
+func (b *TxnBuffer) clearExpiredTxn(ts uint64) { + for txnID, builder := range b.builders { + if builder.ExpiredTimeTick() <= ts { + delete(b.builders, txnID) + if b.logger.Level().Enabled(zap.DebugLevel) { + b.logger.Debug( + "the txn is expired, so drop the txn from buffer", + zap.Int64("txnID", int64(txnID)), + zap.Uint64("expiredTimeTick", builder.ExpiredTimeTick()), + zap.Uint64("currentTimeTick", ts), + ) + } + } + } +} diff --git a/internal/streamingnode/server/wal/utility/txn_buffer_test.go b/internal/streamingnode/server/wal/utility/txn_buffer_test.go new file mode 100644 index 0000000000..c6280bb56a --- /dev/null +++ b/internal/streamingnode/server/wal/utility/txn_buffer_test.go @@ -0,0 +1,154 @@ +package utility + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/streaming/util/message" + "github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest" + "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var idAllocator = typeutil.NewIDAllocator() + +func TestTxnBuffer(t *testing.T) { + b := NewTxnBuffer(log.With()) + + baseTso := tsoutil.GetCurrentTime() + + msgs := b.HandleImmutableMessages([]message.ImmutableMessage{ + newInsertMessage(t, nil, baseTso), + newInsertMessage(t, nil, baseTso), + newInsertMessage(t, nil, baseTso), + }, tsoutil.AddPhysicalDurationOnTs(baseTso, time.Millisecond)) + assert.Len(t, msgs, 3) + + msgs = b.HandleImmutableMessages([]message.ImmutableMessage{ + newInsertMessage(t, nil, baseTso), + newInsertMessage(t, &message.TxnContext{ + TxnID: 1, + Keepalive: time.Second, + }, baseTso), + newInsertMessage(t, nil, baseTso), + newRollbackMessage(t, &message.TxnContext{ + TxnID: 1, + Keepalive: time.Second, + }, baseTso), + newCommitMessage(t, &message.TxnContext{ + TxnID: 2, + Keepalive: time.Second, + }, baseTso), + }, tsoutil.AddPhysicalDurationOnTs(baseTso, time.Millisecond)) + assert.Len(t, msgs, 2) + + // Test successful commit + txnCtx := &message.TxnContext{ + TxnID: 1, + Keepalive: 201 * time.Millisecond, + } + createUnCommitted := func() { + msgs = b.HandleImmutableMessages([]message.ImmutableMessage{ + newBeginMessage(t, txnCtx, baseTso), + }, tsoutil.AddPhysicalDurationOnTs(baseTso, time.Millisecond)) + assert.Len(t, msgs, 0) + + msgs = b.HandleImmutableMessages([]message.ImmutableMessage{ + newInsertMessage(t, txnCtx, tsoutil.AddPhysicalDurationOnTs(baseTso, 100*time.Millisecond)), + }, tsoutil.AddPhysicalDurationOnTs(baseTso, 200*time.Millisecond)) + assert.Len(t, msgs, 0) + + msgs = b.HandleImmutableMessages([]message.ImmutableMessage{ + newInsertMessage(t, nil, tsoutil.AddPhysicalDurationOnTs(baseTso, 250*time.Millisecond)), + newInsertMessage(t, txnCtx, tsoutil.AddPhysicalDurationOnTs(baseTso, 300*time.Millisecond)), + }, tsoutil.AddPhysicalDurationOnTs(baseTso, 400*time.Millisecond)) + // non txn message should be passed. 
+ assert.Len(t, msgs, 1) + } + createUnCommitted() + msgs = b.HandleImmutableMessages([]message.ImmutableMessage{ + newCommitMessage(t, txnCtx, tsoutil.AddPhysicalDurationOnTs(baseTso, 500*time.Millisecond)), + }, tsoutil.AddPhysicalDurationOnTs(baseTso, 600*time.Millisecond)) + assert.Len(t, msgs, 1) + assert.Len(t, b.builders, 0) + + // Test rollback + txnCtx.TxnID = 2 + createUnCommitted() + msgs = b.HandleImmutableMessages([]message.ImmutableMessage{ + newRollbackMessage(t, txnCtx, tsoutil.AddPhysicalDurationOnTs(baseTso, 500*time.Millisecond)), + }, tsoutil.AddPhysicalDurationOnTs(baseTso, 600*time.Millisecond)) + assert.Len(t, msgs, 0) + assert.Len(t, b.builders, 0) + + // Test expired txn + createUnCommitted() + msgs = b.HandleImmutableMessages([]message.ImmutableMessage{}, tsoutil.AddPhysicalDurationOnTs(baseTso, 500*time.Millisecond)) + assert.Len(t, msgs, 0) + assert.Len(t, b.builders, 1) + msgs = b.HandleImmutableMessages([]message.ImmutableMessage{}, tsoutil.AddPhysicalDurationOnTs(baseTso, 501*time.Millisecond)) + assert.Len(t, msgs, 0) + assert.Len(t, b.builders, 0) +} + +func newInsertMessage(t *testing.T, txnCtx *message.TxnContext, ts uint64) message.ImmutableMessage { + msg, err := message.NewInsertMessageBuilderV1(). + WithVChannel("v1"). + WithHeader(&message.InsertMessageHeader{}). + WithBody(&msgpb.InsertRequest{}). + BuildMutable() + assert.NoError(t, err) + assert.NotNil(t, msg) + if txnCtx != nil { + msg = msg.WithTxnContext(*txnCtx) + } + return msg.WithTimeTick(ts). + WithLastConfirmedUseMessageID(). + IntoImmutableMessage(walimplstest.NewTestMessageID(idAllocator.Allocate())) +} + +func newBeginMessage(t *testing.T, txnCtx *message.TxnContext, ts uint64) message.ImmutableMessage { + msg, err := message.NewBeginTxnMessageBuilderV2(). + WithVChannel("v1"). + WithHeader(&message.BeginTxnMessageHeader{}). + WithBody(&message.BeginTxnMessageBody{}). + BuildMutable() + assert.NoError(t, err) + assert.NotNil(t, msg) + return msg.WithTimeTick(ts). + WithLastConfirmedUseMessageID(). + WithTxnContext(*txnCtx). + IntoImmutableMessage(walimplstest.NewTestMessageID(idAllocator.Allocate())) +} + +func newCommitMessage(t *testing.T, txnCtx *message.TxnContext, ts uint64) message.ImmutableMessage { + msg, err := message.NewCommitTxnMessageBuilderV2(). + WithVChannel("v1"). + WithHeader(&message.CommitTxnMessageHeader{}). + WithBody(&message.CommitTxnMessageBody{}). + BuildMutable() + assert.NoError(t, err) + assert.NotNil(t, msg) + return msg.WithTimeTick(ts). + WithLastConfirmedUseMessageID(). + WithTxnContext(*txnCtx). + IntoImmutableMessage(walimplstest.NewTestMessageID(idAllocator.Allocate())) +} + +func newRollbackMessage(t *testing.T, txnCtx *message.TxnContext, ts uint64) message.ImmutableMessage { + msg, err := message.NewRollbackTxnMessageBuilderV2(). + WithVChannel("v1"). + WithHeader(&message.RollbackTxnMessageHeader{}). + WithBody(&message.RollbackTxnMessageBody{}). + BuildMutable() + assert.NoError(t, err) + assert.NotNil(t, msg) + return msg.WithTimeTick(ts). + WithLastConfirmedUseMessageID(). + WithTxnContext(*txnCtx). 
+ IntoImmutableMessage(walimplstest.NewTestMessageID(idAllocator.Allocate())) +} diff --git a/internal/util/streamingutil/status/rpc_error.go b/internal/util/streamingutil/status/rpc_error.go index ed74bcfa0f..4a4da36131 100644 --- a/internal/util/streamingutil/status/rpc_error.go +++ b/internal/util/streamingutil/status/rpc_error.go @@ -13,16 +13,18 @@ import ( ) var streamingErrorToGRPCStatus = map[streamingpb.StreamingCode]codes.Code{ - streamingpb.StreamingCode_STREAMING_CODE_OK: codes.OK, - streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST: codes.FailedPrecondition, - streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_FENCED: codes.FailedPrecondition, - streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN: codes.FailedPrecondition, - streamingpb.StreamingCode_STREAMING_CODE_INVALID_REQUEST_SEQ: codes.FailedPrecondition, - streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM: codes.FailedPrecondition, - streamingpb.StreamingCode_STREAMING_CODE_IGNORED_OPERATION: codes.FailedPrecondition, - streamingpb.StreamingCode_STREAMING_CODE_INNER: codes.Internal, - streamingpb.StreamingCode_STREAMING_CODE_INVAILD_ARGUMENT: codes.InvalidArgument, - streamingpb.StreamingCode_STREAMING_CODE_UNKNOWN: codes.Unknown, + streamingpb.StreamingCode_STREAMING_CODE_OK: codes.OK, + streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST: codes.FailedPrecondition, + streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_FENCED: codes.FailedPrecondition, + streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN: codes.FailedPrecondition, + streamingpb.StreamingCode_STREAMING_CODE_INVALID_REQUEST_SEQ: codes.FailedPrecondition, + streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM: codes.FailedPrecondition, + streamingpb.StreamingCode_STREAMING_CODE_IGNORED_OPERATION: codes.FailedPrecondition, + streamingpb.StreamingCode_STREAMING_CODE_INNER: codes.Internal, + streamingpb.StreamingCode_STREAMING_CODE_INVAILD_ARGUMENT: codes.InvalidArgument, + streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED: codes.FailedPrecondition, + streamingpb.StreamingCode_STREAMING_CODE_INVALID_TRANSACTION_STATE: codes.FailedPrecondition, + streamingpb.StreamingCode_STREAMING_CODE_UNKNOWN: codes.Unknown, } // NewGRPCStatusFromStreamingError converts StreamingError to grpc status. diff --git a/internal/util/streamingutil/status/streaming_error.go b/internal/util/streamingutil/status/streaming_error.go index 030ef7f00f..1fa176fb49 100644 --- a/internal/util/streamingutil/status/streaming_error.go +++ b/internal/util/streamingutil/status/streaming_error.go @@ -7,6 +7,7 @@ import ( "github.com/cockroachdb/redact" "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb" + "github.com/milvus-io/milvus/pkg/streaming/util/message" ) var _ error = (*StreamingError)(nil) @@ -42,6 +43,12 @@ func (e *StreamingError) IsSkippedOperation() bool { e.Code == streamingpb.StreamingCode_STREAMING_CODE_UNMATCHED_CHANNEL_TERM } +// IsTxnUnavilable returns true if the transaction is unavailable. +func (e *StreamingError) IsTxnUnavilable() bool { + return e.Code == streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED || + e.Code == streamingpb.StreamingCode_STREAMING_CODE_INVALID_TRANSACTION_STATE +} + // NewOnShutdownError creates a new StreamingError with code STREAMING_CODE_ON_SHUTDOWN. func NewOnShutdownError(format string, args ...interface{}) *StreamingError { return New(streamingpb.StreamingCode_STREAMING_CODE_ON_SHUTDOWN, format, args...) 
@@ -57,6 +64,12 @@ func NewInvalidRequestSeq(format string, args ...interface{}) *StreamingError { return New(streamingpb.StreamingCode_STREAMING_CODE_INVALID_REQUEST_SEQ, format, args...) } +// NewChannelFenced creates a new StreamingError with code STREAMING_CODE_CHANNEL_FENCED. +// TODO: Unused by now, add it after enable wal fence. +func NewChannelFenced(channel string) *StreamingError { + return New(streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_FENCED, "%s fenced", channel) +} + // NewChannelNotExist creates a new StreamingError with code STREAMING_CODE_CHANNEL_NOT_EXIST. func NewChannelNotExist(channel string) *StreamingError { return New(streamingpb.StreamingCode_STREAMING_CODE_CHANNEL_NOT_EXIST, "%s not exist", channel) @@ -82,6 +95,16 @@ func NewInvaildArgument(format string, args ...interface{}) *StreamingError { return New(streamingpb.StreamingCode_STREAMING_CODE_INVAILD_ARGUMENT, format, args...) } +// NewTransactionExpired creates a new StreamingError with code STREAMING_CODE_TRANSACTION_EXPIRED. +func NewTransactionExpired(format string, args ...interface{}) *StreamingError { + return New(streamingpb.StreamingCode_STREAMING_CODE_TRANSACTION_EXPIRED, format, args...) +} + +// NewInvalidTransactionState creates a new StreamingError with code STREAMING_CODE_INVALID_TRANSACTION_STATE. +func NewInvalidTransactionState(operation string, expectState message.TxnState, currentState message.TxnState) *StreamingError { + return New(streamingpb.StreamingCode_STREAMING_CODE_INVALID_TRANSACTION_STATE, "invalid transaction state for operation %s, expect %s, current %s", operation, expectState, currentState) +} + // New creates a new StreamingError with the given code and cause. func New(code streamingpb.StreamingCode, format string, args ...interface{}) *StreamingError { if len(args) == 0 { diff --git a/pkg/.mockery_pkg.yaml b/pkg/.mockery_pkg.yaml index b0a1345318..b151a3b520 100644 --- a/pkg/.mockery_pkg.yaml +++ b/pkg/.mockery_pkg.yaml @@ -1,7 +1,7 @@ quiet: False with-expecter: True filename: "mock_{{.InterfaceName}}.go" -dir: "mocks/{{trimPrefix .PackagePath \"github.com/milvus-io/milvus/pkg\" | dir }}/mock_{{.PackageName}}" +dir: 'mocks/{{trimPrefix .PackagePath "github.com/milvus-io/milvus/pkg" | dir }}/mock_{{.PackageName}}' mockname: "Mock{{.InterfaceName}}" outpkg: "mock_{{.PackageName}}" packages: @@ -12,6 +12,7 @@ packages: interfaces: MessageID: ImmutableMessage: + ImmutableTxnMessage: MutableMessage: RProperties: github.com/milvus-io/milvus/pkg/streaming/walimpls: @@ -38,4 +39,3 @@ packages: StreamingNodeHandlerServiceClient: StreamingNodeHandlerService_ConsumeClient: StreamingNodeHandlerService_ProduceClient: - \ No newline at end of file diff --git a/pkg/mocks/streaming/util/mock_message/mock_ImmutableMessage.go b/pkg/mocks/streaming/util/mock_message/mock_ImmutableMessage.go index 426f86320b..3b3551268d 100644 --- a/pkg/mocks/streaming/util/mock_message/mock_ImmutableMessage.go +++ b/pkg/mocks/streaming/util/mock_message/mock_ImmutableMessage.go @@ -20,6 +20,47 @@ func (_m *MockImmutableMessage) EXPECT() *MockImmutableMessage_Expecter { return &MockImmutableMessage_Expecter{mock: &_m.Mock} } +// BarrierTimeTick provides a mock function with given fields: +func (_m *MockImmutableMessage) BarrierTimeTick() uint64 { + ret := _m.Called() + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// MockImmutableMessage_BarrierTimeTick_Call is a *mock.Call that shadows Run/Return methods with 
type explicit version for method 'BarrierTimeTick' +type MockImmutableMessage_BarrierTimeTick_Call struct { + *mock.Call +} + +// BarrierTimeTick is a helper method to define mock.On call +func (_e *MockImmutableMessage_Expecter) BarrierTimeTick() *MockImmutableMessage_BarrierTimeTick_Call { + return &MockImmutableMessage_BarrierTimeTick_Call{Call: _e.mock.On("BarrierTimeTick")} +} + +func (_c *MockImmutableMessage_BarrierTimeTick_Call) Run(run func()) *MockImmutableMessage_BarrierTimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableMessage_BarrierTimeTick_Call) Return(_a0 uint64) *MockImmutableMessage_BarrierTimeTick_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableMessage_BarrierTimeTick_Call) RunAndReturn(run func() uint64) *MockImmutableMessage_BarrierTimeTick_Call { + _c.Call.Return(run) + return _c +} + // EstimateSize provides a mock function with given fields: func (_m *MockImmutableMessage) EstimateSize() int { ret := _m.Called() @@ -315,6 +356,49 @@ func (_c *MockImmutableMessage_TimeTick_Call) RunAndReturn(run func() uint64) *M return _c } +// TxnContext provides a mock function with given fields: +func (_m *MockImmutableMessage) TxnContext() *message.TxnContext { + ret := _m.Called() + + var r0 *message.TxnContext + if rf, ok := ret.Get(0).(func() *message.TxnContext); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*message.TxnContext) + } + } + + return r0 +} + +// MockImmutableMessage_TxnContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TxnContext' +type MockImmutableMessage_TxnContext_Call struct { + *mock.Call +} + +// TxnContext is a helper method to define mock.On call +func (_e *MockImmutableMessage_Expecter) TxnContext() *MockImmutableMessage_TxnContext_Call { + return &MockImmutableMessage_TxnContext_Call{Call: _e.mock.On("TxnContext")} +} + +func (_c *MockImmutableMessage_TxnContext_Call) Run(run func()) *MockImmutableMessage_TxnContext_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableMessage_TxnContext_Call) Return(_a0 *message.TxnContext) *MockImmutableMessage_TxnContext_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableMessage_TxnContext_Call) RunAndReturn(run func() *message.TxnContext) *MockImmutableMessage_TxnContext_Call { + _c.Call.Return(run) + return _c +} + // VChannel provides a mock function with given fields: func (_m *MockImmutableMessage) VChannel() string { ret := _m.Called() diff --git a/pkg/mocks/streaming/util/mock_message/mock_ImmutableTxnMessage.go b/pkg/mocks/streaming/util/mock_message/mock_ImmutableTxnMessage.go new file mode 100644 index 0000000000..2e69295279 --- /dev/null +++ b/pkg/mocks/streaming/util/mock_message/mock_ImmutableTxnMessage.go @@ -0,0 +1,706 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. 
+ +package mock_message + +import ( + message "github.com/milvus-io/milvus/pkg/streaming/util/message" + mock "github.com/stretchr/testify/mock" +) + +// MockImmutableTxnMessage is an autogenerated mock type for the ImmutableTxnMessage type +type MockImmutableTxnMessage struct { + mock.Mock +} + +type MockImmutableTxnMessage_Expecter struct { + mock *mock.Mock +} + +func (_m *MockImmutableTxnMessage) EXPECT() *MockImmutableTxnMessage_Expecter { + return &MockImmutableTxnMessage_Expecter{mock: &_m.Mock} +} + +// BarrierTimeTick provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) BarrierTimeTick() uint64 { + ret := _m.Called() + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// MockImmutableTxnMessage_BarrierTimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BarrierTimeTick' +type MockImmutableTxnMessage_BarrierTimeTick_Call struct { + *mock.Call +} + +// BarrierTimeTick is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) BarrierTimeTick() *MockImmutableTxnMessage_BarrierTimeTick_Call { + return &MockImmutableTxnMessage_BarrierTimeTick_Call{Call: _e.mock.On("BarrierTimeTick")} +} + +func (_c *MockImmutableTxnMessage_BarrierTimeTick_Call) Run(run func()) *MockImmutableTxnMessage_BarrierTimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_BarrierTimeTick_Call) Return(_a0 uint64) *MockImmutableTxnMessage_BarrierTimeTick_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_BarrierTimeTick_Call) RunAndReturn(run func() uint64) *MockImmutableTxnMessage_BarrierTimeTick_Call { + _c.Call.Return(run) + return _c +} + +// Begin provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) Begin() message.ImmutableMessage { + ret := _m.Called() + + var r0 message.ImmutableMessage + if rf, ok := ret.Get(0).(func() message.ImmutableMessage); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.ImmutableMessage) + } + } + + return r0 +} + +// MockImmutableTxnMessage_Begin_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Begin' +type MockImmutableTxnMessage_Begin_Call struct { + *mock.Call +} + +// Begin is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) Begin() *MockImmutableTxnMessage_Begin_Call { + return &MockImmutableTxnMessage_Begin_Call{Call: _e.mock.On("Begin")} +} + +func (_c *MockImmutableTxnMessage_Begin_Call) Run(run func()) *MockImmutableTxnMessage_Begin_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_Begin_Call) Return(_a0 message.ImmutableMessage) *MockImmutableTxnMessage_Begin_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_Begin_Call) RunAndReturn(run func() message.ImmutableMessage) *MockImmutableTxnMessage_Begin_Call { + _c.Call.Return(run) + return _c +} + +// Commit provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) Commit() message.ImmutableMessage { + ret := _m.Called() + + var r0 message.ImmutableMessage + if rf, ok := ret.Get(0).(func() message.ImmutableMessage); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.ImmutableMessage) + } + } + + return r0 +} + +// MockImmutableTxnMessage_Commit_Call is a 
*mock.Call that shadows Run/Return methods with type explicit version for method 'Commit' +type MockImmutableTxnMessage_Commit_Call struct { + *mock.Call +} + +// Commit is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) Commit() *MockImmutableTxnMessage_Commit_Call { + return &MockImmutableTxnMessage_Commit_Call{Call: _e.mock.On("Commit")} +} + +func (_c *MockImmutableTxnMessage_Commit_Call) Run(run func()) *MockImmutableTxnMessage_Commit_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_Commit_Call) Return(_a0 message.ImmutableMessage) *MockImmutableTxnMessage_Commit_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_Commit_Call) RunAndReturn(run func() message.ImmutableMessage) *MockImmutableTxnMessage_Commit_Call { + _c.Call.Return(run) + return _c +} + +// EstimateSize provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) EstimateSize() int { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +// MockImmutableTxnMessage_EstimateSize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EstimateSize' +type MockImmutableTxnMessage_EstimateSize_Call struct { + *mock.Call +} + +// EstimateSize is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) EstimateSize() *MockImmutableTxnMessage_EstimateSize_Call { + return &MockImmutableTxnMessage_EstimateSize_Call{Call: _e.mock.On("EstimateSize")} +} + +func (_c *MockImmutableTxnMessage_EstimateSize_Call) Run(run func()) *MockImmutableTxnMessage_EstimateSize_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_EstimateSize_Call) Return(_a0 int) *MockImmutableTxnMessage_EstimateSize_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_EstimateSize_Call) RunAndReturn(run func() int) *MockImmutableTxnMessage_EstimateSize_Call { + _c.Call.Return(run) + return _c +} + +// LastConfirmedMessageID provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) LastConfirmedMessageID() message.MessageID { + ret := _m.Called() + + var r0 message.MessageID + if rf, ok := ret.Get(0).(func() message.MessageID); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.MessageID) + } + } + + return r0 +} + +// MockImmutableTxnMessage_LastConfirmedMessageID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LastConfirmedMessageID' +type MockImmutableTxnMessage_LastConfirmedMessageID_Call struct { + *mock.Call +} + +// LastConfirmedMessageID is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) LastConfirmedMessageID() *MockImmutableTxnMessage_LastConfirmedMessageID_Call { + return &MockImmutableTxnMessage_LastConfirmedMessageID_Call{Call: _e.mock.On("LastConfirmedMessageID")} +} + +func (_c *MockImmutableTxnMessage_LastConfirmedMessageID_Call) Run(run func()) *MockImmutableTxnMessage_LastConfirmedMessageID_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_LastConfirmedMessageID_Call) Return(_a0 message.MessageID) *MockImmutableTxnMessage_LastConfirmedMessageID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_LastConfirmedMessageID_Call) 
RunAndReturn(run func() message.MessageID) *MockImmutableTxnMessage_LastConfirmedMessageID_Call { + _c.Call.Return(run) + return _c +} + +// MessageID provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) MessageID() message.MessageID { + ret := _m.Called() + + var r0 message.MessageID + if rf, ok := ret.Get(0).(func() message.MessageID); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.MessageID) + } + } + + return r0 +} + +// MockImmutableTxnMessage_MessageID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MessageID' +type MockImmutableTxnMessage_MessageID_Call struct { + *mock.Call +} + +// MessageID is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) MessageID() *MockImmutableTxnMessage_MessageID_Call { + return &MockImmutableTxnMessage_MessageID_Call{Call: _e.mock.On("MessageID")} +} + +func (_c *MockImmutableTxnMessage_MessageID_Call) Run(run func()) *MockImmutableTxnMessage_MessageID_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_MessageID_Call) Return(_a0 message.MessageID) *MockImmutableTxnMessage_MessageID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_MessageID_Call) RunAndReturn(run func() message.MessageID) *MockImmutableTxnMessage_MessageID_Call { + _c.Call.Return(run) + return _c +} + +// MessageType provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) MessageType() message.MessageType { + ret := _m.Called() + + var r0 message.MessageType + if rf, ok := ret.Get(0).(func() message.MessageType); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(message.MessageType) + } + + return r0 +} + +// MockImmutableTxnMessage_MessageType_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MessageType' +type MockImmutableTxnMessage_MessageType_Call struct { + *mock.Call +} + +// MessageType is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) MessageType() *MockImmutableTxnMessage_MessageType_Call { + return &MockImmutableTxnMessage_MessageType_Call{Call: _e.mock.On("MessageType")} +} + +func (_c *MockImmutableTxnMessage_MessageType_Call) Run(run func()) *MockImmutableTxnMessage_MessageType_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_MessageType_Call) Return(_a0 message.MessageType) *MockImmutableTxnMessage_MessageType_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_MessageType_Call) RunAndReturn(run func() message.MessageType) *MockImmutableTxnMessage_MessageType_Call { + _c.Call.Return(run) + return _c +} + +// Payload provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) Payload() []byte { + ret := _m.Called() + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// MockImmutableTxnMessage_Payload_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Payload' +type MockImmutableTxnMessage_Payload_Call struct { + *mock.Call +} + +// Payload is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) Payload() *MockImmutableTxnMessage_Payload_Call { + return &MockImmutableTxnMessage_Payload_Call{Call: _e.mock.On("Payload")} +} + +func (_c 
*MockImmutableTxnMessage_Payload_Call) Run(run func()) *MockImmutableTxnMessage_Payload_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_Payload_Call) Return(_a0 []byte) *MockImmutableTxnMessage_Payload_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_Payload_Call) RunAndReturn(run func() []byte) *MockImmutableTxnMessage_Payload_Call { + _c.Call.Return(run) + return _c +} + +// Properties provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) Properties() message.RProperties { + ret := _m.Called() + + var r0 message.RProperties + if rf, ok := ret.Get(0).(func() message.RProperties); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.RProperties) + } + } + + return r0 +} + +// MockImmutableTxnMessage_Properties_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Properties' +type MockImmutableTxnMessage_Properties_Call struct { + *mock.Call +} + +// Properties is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) Properties() *MockImmutableTxnMessage_Properties_Call { + return &MockImmutableTxnMessage_Properties_Call{Call: _e.mock.On("Properties")} +} + +func (_c *MockImmutableTxnMessage_Properties_Call) Run(run func()) *MockImmutableTxnMessage_Properties_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_Properties_Call) Return(_a0 message.RProperties) *MockImmutableTxnMessage_Properties_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_Properties_Call) RunAndReturn(run func() message.RProperties) *MockImmutableTxnMessage_Properties_Call { + _c.Call.Return(run) + return _c +} + +// RangeOver provides a mock function with given fields: visitor +func (_m *MockImmutableTxnMessage) RangeOver(visitor func(message.ImmutableMessage) error) error { + ret := _m.Called(visitor) + + var r0 error + if rf, ok := ret.Get(0).(func(func(message.ImmutableMessage) error) error); ok { + r0 = rf(visitor) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockImmutableTxnMessage_RangeOver_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RangeOver' +type MockImmutableTxnMessage_RangeOver_Call struct { + *mock.Call +} + +// RangeOver is a helper method to define mock.On call +// - visitor func(message.ImmutableMessage) error +func (_e *MockImmutableTxnMessage_Expecter) RangeOver(visitor interface{}) *MockImmutableTxnMessage_RangeOver_Call { + return &MockImmutableTxnMessage_RangeOver_Call{Call: _e.mock.On("RangeOver", visitor)} +} + +func (_c *MockImmutableTxnMessage_RangeOver_Call) Run(run func(visitor func(message.ImmutableMessage) error)) *MockImmutableTxnMessage_RangeOver_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(func(message.ImmutableMessage) error)) + }) + return _c +} + +func (_c *MockImmutableTxnMessage_RangeOver_Call) Return(_a0 error) *MockImmutableTxnMessage_RangeOver_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_RangeOver_Call) RunAndReturn(run func(func(message.ImmutableMessage) error) error) *MockImmutableTxnMessage_RangeOver_Call { + _c.Call.Return(run) + return _c +} + +// Size provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) Size() int { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = 
ret.Get(0).(int) + } + + return r0 +} + +// MockImmutableTxnMessage_Size_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Size' +type MockImmutableTxnMessage_Size_Call struct { + *mock.Call +} + +// Size is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) Size() *MockImmutableTxnMessage_Size_Call { + return &MockImmutableTxnMessage_Size_Call{Call: _e.mock.On("Size")} +} + +func (_c *MockImmutableTxnMessage_Size_Call) Run(run func()) *MockImmutableTxnMessage_Size_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_Size_Call) Return(_a0 int) *MockImmutableTxnMessage_Size_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_Size_Call) RunAndReturn(run func() int) *MockImmutableTxnMessage_Size_Call { + _c.Call.Return(run) + return _c +} + +// TimeTick provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) TimeTick() uint64 { + ret := _m.Called() + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// MockImmutableTxnMessage_TimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TimeTick' +type MockImmutableTxnMessage_TimeTick_Call struct { + *mock.Call +} + +// TimeTick is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) TimeTick() *MockImmutableTxnMessage_TimeTick_Call { + return &MockImmutableTxnMessage_TimeTick_Call{Call: _e.mock.On("TimeTick")} +} + +func (_c *MockImmutableTxnMessage_TimeTick_Call) Run(run func()) *MockImmutableTxnMessage_TimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_TimeTick_Call) Return(_a0 uint64) *MockImmutableTxnMessage_TimeTick_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_TimeTick_Call) RunAndReturn(run func() uint64) *MockImmutableTxnMessage_TimeTick_Call { + _c.Call.Return(run) + return _c +} + +// TxnContext provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) TxnContext() *message.TxnContext { + ret := _m.Called() + + var r0 *message.TxnContext + if rf, ok := ret.Get(0).(func() *message.TxnContext); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*message.TxnContext) + } + } + + return r0 +} + +// MockImmutableTxnMessage_TxnContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TxnContext' +type MockImmutableTxnMessage_TxnContext_Call struct { + *mock.Call +} + +// TxnContext is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) TxnContext() *MockImmutableTxnMessage_TxnContext_Call { + return &MockImmutableTxnMessage_TxnContext_Call{Call: _e.mock.On("TxnContext")} +} + +func (_c *MockImmutableTxnMessage_TxnContext_Call) Run(run func()) *MockImmutableTxnMessage_TxnContext_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_TxnContext_Call) Return(_a0 *message.TxnContext) *MockImmutableTxnMessage_TxnContext_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_TxnContext_Call) RunAndReturn(run func() *message.TxnContext) *MockImmutableTxnMessage_TxnContext_Call { + _c.Call.Return(run) + return _c +} + +// VChannel provides a mock function with given fields: +func (_m 
*MockImmutableTxnMessage) VChannel() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockImmutableTxnMessage_VChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VChannel' +type MockImmutableTxnMessage_VChannel_Call struct { + *mock.Call +} + +// VChannel is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) VChannel() *MockImmutableTxnMessage_VChannel_Call { + return &MockImmutableTxnMessage_VChannel_Call{Call: _e.mock.On("VChannel")} +} + +func (_c *MockImmutableTxnMessage_VChannel_Call) Run(run func()) *MockImmutableTxnMessage_VChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_VChannel_Call) Return(_a0 string) *MockImmutableTxnMessage_VChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_VChannel_Call) RunAndReturn(run func() string) *MockImmutableTxnMessage_VChannel_Call { + _c.Call.Return(run) + return _c +} + +// Version provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) Version() message.Version { + ret := _m.Called() + + var r0 message.Version + if rf, ok := ret.Get(0).(func() message.Version); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(message.Version) + } + + return r0 +} + +// MockImmutableTxnMessage_Version_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Version' +type MockImmutableTxnMessage_Version_Call struct { + *mock.Call +} + +// Version is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) Version() *MockImmutableTxnMessage_Version_Call { + return &MockImmutableTxnMessage_Version_Call{Call: _e.mock.On("Version")} +} + +func (_c *MockImmutableTxnMessage_Version_Call) Run(run func()) *MockImmutableTxnMessage_Version_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_Version_Call) Return(_a0 message.Version) *MockImmutableTxnMessage_Version_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_Version_Call) RunAndReturn(run func() message.Version) *MockImmutableTxnMessage_Version_Call { + _c.Call.Return(run) + return _c +} + +// WALName provides a mock function with given fields: +func (_m *MockImmutableTxnMessage) WALName() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockImmutableTxnMessage_WALName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WALName' +type MockImmutableTxnMessage_WALName_Call struct { + *mock.Call +} + +// WALName is a helper method to define mock.On call +func (_e *MockImmutableTxnMessage_Expecter) WALName() *MockImmutableTxnMessage_WALName_Call { + return &MockImmutableTxnMessage_WALName_Call{Call: _e.mock.On("WALName")} +} + +func (_c *MockImmutableTxnMessage_WALName_Call) Run(run func()) *MockImmutableTxnMessage_WALName_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockImmutableTxnMessage_WALName_Call) Return(_a0 string) *MockImmutableTxnMessage_WALName_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImmutableTxnMessage_WALName_Call) RunAndReturn(run func() string) *MockImmutableTxnMessage_WALName_Call { + 
_c.Call.Return(run) + return _c +} + +// NewMockImmutableTxnMessage creates a new instance of MockImmutableTxnMessage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockImmutableTxnMessage(t interface { + mock.TestingT + Cleanup(func()) +}) *MockImmutableTxnMessage { + mock := &MockImmutableTxnMessage{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/mocks/streaming/util/mock_message/mock_MutableMessage.go b/pkg/mocks/streaming/util/mock_message/mock_MutableMessage.go index 1006d9f33e..960c012e46 100644 --- a/pkg/mocks/streaming/util/mock_message/mock_MutableMessage.go +++ b/pkg/mocks/streaming/util/mock_message/mock_MutableMessage.go @@ -20,6 +20,47 @@ func (_m *MockMutableMessage) EXPECT() *MockMutableMessage_Expecter { return &MockMutableMessage_Expecter{mock: &_m.Mock} } +// BarrierTimeTick provides a mock function with given fields: +func (_m *MockMutableMessage) BarrierTimeTick() uint64 { + ret := _m.Called() + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// MockMutableMessage_BarrierTimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BarrierTimeTick' +type MockMutableMessage_BarrierTimeTick_Call struct { + *mock.Call +} + +// BarrierTimeTick is a helper method to define mock.On call +func (_e *MockMutableMessage_Expecter) BarrierTimeTick() *MockMutableMessage_BarrierTimeTick_Call { + return &MockMutableMessage_BarrierTimeTick_Call{Call: _e.mock.On("BarrierTimeTick")} +} + +func (_c *MockMutableMessage_BarrierTimeTick_Call) Run(run func()) *MockMutableMessage_BarrierTimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMutableMessage_BarrierTimeTick_Call) Return(_a0 uint64) *MockMutableMessage_BarrierTimeTick_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMutableMessage_BarrierTimeTick_Call) RunAndReturn(run func() uint64) *MockMutableMessage_BarrierTimeTick_Call { + _c.Call.Return(run) + return _c +} + // EstimateSize provides a mock function with given fields: func (_m *MockMutableMessage) EstimateSize() int { ret := _m.Called() @@ -273,6 +314,49 @@ func (_c *MockMutableMessage_TimeTick_Call) RunAndReturn(run func() uint64) *Moc return _c } +// TxnContext provides a mock function with given fields: +func (_m *MockMutableMessage) TxnContext() *message.TxnContext { + ret := _m.Called() + + var r0 *message.TxnContext + if rf, ok := ret.Get(0).(func() *message.TxnContext); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*message.TxnContext) + } + } + + return r0 +} + +// MockMutableMessage_TxnContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TxnContext' +type MockMutableMessage_TxnContext_Call struct { + *mock.Call +} + +// TxnContext is a helper method to define mock.On call +func (_e *MockMutableMessage_Expecter) TxnContext() *MockMutableMessage_TxnContext_Call { + return &MockMutableMessage_TxnContext_Call{Call: _e.mock.On("TxnContext")} +} + +func (_c *MockMutableMessage_TxnContext_Call) Run(run func()) *MockMutableMessage_TxnContext_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockMutableMessage_TxnContext_Call) Return(_a0 *message.TxnContext) *MockMutableMessage_TxnContext_Call 
{ + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMutableMessage_TxnContext_Call) RunAndReturn(run func() *message.TxnContext) *MockMutableMessage_TxnContext_Call { + _c.Call.Return(run) + return _c +} + // VChannel provides a mock function with given fields: func (_m *MockMutableMessage) VChannel() string { ret := _m.Called() @@ -355,6 +439,50 @@ func (_c *MockMutableMessage_Version_Call) RunAndReturn(run func() message.Versi return _c } +// WithBarrierTimeTick provides a mock function with given fields: tt +func (_m *MockMutableMessage) WithBarrierTimeTick(tt uint64) message.MutableMessage { + ret := _m.Called(tt) + + var r0 message.MutableMessage + if rf, ok := ret.Get(0).(func(uint64) message.MutableMessage); ok { + r0 = rf(tt) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.MutableMessage) + } + } + + return r0 +} + +// MockMutableMessage_WithBarrierTimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithBarrierTimeTick' +type MockMutableMessage_WithBarrierTimeTick_Call struct { + *mock.Call +} + +// WithBarrierTimeTick is a helper method to define mock.On call +// - tt uint64 +func (_e *MockMutableMessage_Expecter) WithBarrierTimeTick(tt interface{}) *MockMutableMessage_WithBarrierTimeTick_Call { + return &MockMutableMessage_WithBarrierTimeTick_Call{Call: _e.mock.On("WithBarrierTimeTick", tt)} +} + +func (_c *MockMutableMessage_WithBarrierTimeTick_Call) Run(run func(tt uint64)) *MockMutableMessage_WithBarrierTimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(uint64)) + }) + return _c +} + +func (_c *MockMutableMessage_WithBarrierTimeTick_Call) Return(_a0 message.MutableMessage) *MockMutableMessage_WithBarrierTimeTick_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockMutableMessage_WithBarrierTimeTick_Call) RunAndReturn(run func(uint64) message.MutableMessage) *MockMutableMessage_WithBarrierTimeTick_Call { + _c.Call.Return(run) + return _c +} + // WithLastConfirmed provides a mock function with given fields: id func (_m *MockMutableMessage) WithLastConfirmed(id message.MessageID) message.MutableMessage { ret := _m.Called(id) @@ -486,6 +614,94 @@ func (_c *MockMutableMessage_WithTimeTick_Call) RunAndReturn(run func(uint64) me return _c } +// WithTxnContext provides a mock function with given fields: txnCtx +func (_m *MockMutableMessage) WithTxnContext(txnCtx message.TxnContext) message.MutableMessage { + ret := _m.Called(txnCtx) + + var r0 message.MutableMessage + if rf, ok := ret.Get(0).(func(message.TxnContext) message.MutableMessage); ok { + r0 = rf(txnCtx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.MutableMessage) + } + } + + return r0 +} + +// MockMutableMessage_WithTxnContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithTxnContext' +type MockMutableMessage_WithTxnContext_Call struct { + *mock.Call +} + +// WithTxnContext is a helper method to define mock.On call +// - txnCtx message.TxnContext +func (_e *MockMutableMessage_Expecter) WithTxnContext(txnCtx interface{}) *MockMutableMessage_WithTxnContext_Call { + return &MockMutableMessage_WithTxnContext_Call{Call: _e.mock.On("WithTxnContext", txnCtx)} +} + +func (_c *MockMutableMessage_WithTxnContext_Call) Run(run func(txnCtx message.TxnContext)) *MockMutableMessage_WithTxnContext_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(message.TxnContext)) + }) + return _c +} + +func (_c *MockMutableMessage_WithTxnContext_Call) Return(_a0 
message.MutableMessage) *MockMutableMessage_WithTxnContext_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockMutableMessage_WithTxnContext_Call) RunAndReturn(run func(message.TxnContext) message.MutableMessage) *MockMutableMessage_WithTxnContext_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
+// WithWALTerm provides a mock function with given fields: term
+func (_m *MockMutableMessage) WithWALTerm(term int64) message.MutableMessage {
+	ret := _m.Called(term)
+
+	var r0 message.MutableMessage
+	if rf, ok := ret.Get(0).(func(int64) message.MutableMessage); ok {
+		r0 = rf(term)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(message.MutableMessage)
+		}
+	}
+
+	return r0
+}
+
+// MockMutableMessage_WithWALTerm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithWALTerm'
+type MockMutableMessage_WithWALTerm_Call struct {
+	*mock.Call
+}
+
+// WithWALTerm is a helper method to define mock.On call
+// - term int64
+func (_e *MockMutableMessage_Expecter) WithWALTerm(term interface{}) *MockMutableMessage_WithWALTerm_Call {
+	return &MockMutableMessage_WithWALTerm_Call{Call: _e.mock.On("WithWALTerm", term)}
+}
+
+func (_c *MockMutableMessage_WithWALTerm_Call) Run(run func(term int64)) *MockMutableMessage_WithWALTerm_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		run(args[0].(int64))
+	})
+	return _c
+}
+
+func (_c *MockMutableMessage_WithWALTerm_Call) Return(_a0 message.MutableMessage) *MockMutableMessage_WithWALTerm_Call {
+	_c.Call.Return(_a0)
+	return _c
+}
+
+func (_c *MockMutableMessage_WithWALTerm_Call) RunAndReturn(run func(int64) message.MutableMessage) *MockMutableMessage_WithWALTerm_Call {
+	_c.Call.Return(run)
+	return _c
+}
+
 // NewMockMutableMessage creates a new instance of MockMutableMessage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
 // The first argument is typically a *testing.T value.
 func NewMockMutableMessage(t interface {
diff --git a/pkg/streaming/proto/messages.proto b/pkg/streaming/proto/messages.proto
index da07b36be9..7bd2201152 100644
--- a/pkg/streaming/proto/messages.proto
+++ b/pkg/streaming/proto/messages.proto
@@ -33,6 +33,27 @@ enum MessageType {
     DropCollection = 6;
     CreatePartition = 7;
     DropPartition = 8;
+    // begin transaction message is only used for transactions; once a begin
+    // transaction message is received, all messages combined with the
+    // transaction cannot be consumed until a CommitTxn message
+    // is received.
+    BeginTxn = 900;
+    // commit transaction message is only used for transactions; once a commit
+    // transaction message is received, all messages combined with the
+    // transaction can be consumed; any message combined with the
+    // transaction that is received after the commit transaction message will
+    // be dropped.
+    CommitTxn = 901;
+    // rollback transaction message is only used for transactions; once a
+    // rollback transaction message is received, all messages combined with the
+    // transaction can be discarded; any message combined with the
+    // transaction that is received after the rollback transaction message will
+    // be dropped.
+    RollbackTxn = 902;
+    // txn message is a set of messages combined from the messages in a
+    // transaction. Its txn properties consist of the begin txn message and the
+    // commit txn message.
+    Txn = 999;
 }
 
 ///
@@ -55,6 +76,27 @@ message FlushMessageBody {
     repeated int64 segment_id = 2; // indicate which segment to flush.
 }
 
+// BeginTxnMessageBody is the body of the begin transaction message.
+// It does nothing for now.
+message BeginTxnMessageBody {}
+
+// CommitTxnMessageBody is the body of the commit transaction message.
+// It does nothing for now.
+message CommitTxnMessageBody {}
+
+// RollbackTxnMessageBody is the body of the rollback transaction message.
+// It does nothing for now.
+message RollbackTxnMessageBody {}
+
+// TxnMessageBody is the body of the transaction message.
+// A transaction message is combined from multiple messages.
+// It can only be seen at the consume side.
+// All messages in a transaction message share the same timetick, which is
+// equal to that of the CommitTxn message.
+message TxnMessageBody {
+    repeated Message messages = 1;
+}
+
 ///
 /// Message Header Definitions
 /// Used to fast handling at streaming node write ahead.
@@ -114,3 +156,53 @@ message DropPartitionMessageHeader {
     int64 collection_id = 1;
     int64 partition_id = 2;
 }
+
+// BeginTxnMessageHeader is the header of the begin transaction message.
+// It carries nothing else for now.
+// Add channel info here to implement cross-pchannel transactions.
+message BeginTxnMessageHeader {
+    // the max milliseconds to keep the transaction alive.
+    // the keepalive_milliseconds is never changed within a transaction for now.
+    int64 keepalive_milliseconds = 1;
+}
+
+// CommitTxnMessageHeader is the header of the commit transaction message.
+// It does nothing for now.
+message CommitTxnMessageHeader {}
+
+// RollbackTxnMessageHeader is the header of the rollback transaction
+// message.
+// It does nothing for now.
+message RollbackTxnMessageHeader {}
+
+// TxnMessageHeader is the header of the transaction message.
+// It does nothing for now.
+message TxnMessageHeader {}
+
+// TxnContext is the context of a transaction.
+// It is carried by every message in a transaction.
+message TxnContext {
+    // the unique id of the transaction.
+    // the txn_id is never changed within a transaction.
+    int64 txn_id = 1;
+    // the next keepalive timeout of the transaction.
+    // after the keepalive timeout, the transaction will be expired.
+    int64 keepalive_milliseconds = 2;
+}
+
+enum TxnState {
+    // should never be used.
+    TxnUnknown = 0;
+    // the transaction has begun.
+    TxnBegin = 1;
+    // the transaction is in flight.
+    TxnInFlight = 2;
+    // the transaction is being committed.
+    TxnOnCommit = 3;
+    // the transaction is committed.
+    TxnCommitted = 4;
+    // the transaction is being rolled back.
+    TxnOnRollback = 5;
+    // the transaction is rolled back.
+    TxnRollbacked = 6;
+}
diff --git a/pkg/streaming/proto/streaming.proto b/pkg/streaming/proto/streaming.proto
index ad9f9b25ae..7930d5a3c3 100644
--- a/pkg/streaming/proto/streaming.proto
+++ b/pkg/streaming/proto/streaming.proto
@@ -7,6 +7,7 @@ option go_package = "github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb
 import "messages.proto";
 import "milvus.proto";
 import "google/protobuf/empty.proto";
+import "google/protobuf/any.proto";
 
 //
 // Common
@@ -187,15 +188,17 @@ message DeliverFilterMessageType {
 // StreamingCode is the error code for log internal component.
 enum StreamingCode {
     STREAMING_CODE_OK = 0;
-    STREAMING_CODE_CHANNEL_NOT_EXIST = 1; // channel not exist
-    STREAMING_CODE_CHANNEL_FENCED = 2; // channel is fenced
-    STREAMING_CODE_ON_SHUTDOWN = 3; // component is on shutdown
-    STREAMING_CODE_INVALID_REQUEST_SEQ = 4; // invalid request sequence
-    STREAMING_CODE_UNMATCHED_CHANNEL_TERM = 5; // unmatched channel term
-    STREAMING_CODE_IGNORED_OPERATION = 6; // ignored operation
-    STREAMING_CODE_INNER = 7; // underlying service failure.
- STREAMING_CODE_INVAILD_ARGUMENT = 8; // invalid argument - STREAMING_CODE_UNKNOWN = 999; // unknown error + STREAMING_CODE_CHANNEL_NOT_EXIST = 1; // channel not exist + STREAMING_CODE_CHANNEL_FENCED = 2; // channel is fenced + STREAMING_CODE_ON_SHUTDOWN = 3; // component is on shutdown + STREAMING_CODE_INVALID_REQUEST_SEQ = 4; // invalid request sequence + STREAMING_CODE_UNMATCHED_CHANNEL_TERM = 5; // unmatched channel term + STREAMING_CODE_IGNORED_OPERATION = 6; // ignored operation + STREAMING_CODE_INNER = 7; // underlying service failure. + STREAMING_CODE_INVAILD_ARGUMENT = 8; // invalid argument + STREAMING_CODE_TRANSACTION_EXPIRED = 9; // transaction expired + STREAMING_CODE_INVALID_TRANSACTION_STATE = 10; // invalid transaction state + STREAMING_CODE_UNKNOWN = 999; // unknown error } // StreamingError is the error type for log internal component. @@ -289,6 +292,8 @@ message ProduceMessageResponse { message ProduceMessageResponseResult { messages.MessageID id = 1; // the offset of the message in the channel. uint64 timetick = 2; // the timetick of that message sent. + messages.TxnContext txnContext = 3; // the txn context of the message. + google.protobuf.Any extra = 4; // the extra message. } // CloseProducerResponse is the result of the CloseProducer RPC. diff --git a/pkg/streaming/util/message/builder.go b/pkg/streaming/util/message/builder.go index c53a2e8570..0eacc04555 100644 --- a/pkg/streaming/util/message/builder.go +++ b/pkg/streaming/util/message/builder.go @@ -7,10 +7,11 @@ import ( "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/pkg/util/tsoutil" ) // NewMutableMessage creates a new mutable message. -// Only used at server side for streamingnode internal service, don't use it at client side. +// !!! Only used at server side for streamingnode internal service, don't use it at client side. func NewMutableMessage(payload []byte, properties map[string]string) MutableMessage { return &messageImpl{ payload: payload, @@ -19,6 +20,7 @@ func NewMutableMessage(payload []byte, properties map[string]string) MutableMess } // NewImmutableMessage creates a new immutable message. +// !!! Only used at server side for streaming internal service, don't use it at client side. func NewImmutableMesasge( id MessageID, payload []byte, @@ -43,6 +45,10 @@ var ( NewCreatePartitionMessageBuilderV1 = createNewMessageBuilderV1[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest]() NewDropPartitionMessageBuilderV1 = createNewMessageBuilderV1[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest]() NewFlushMessageBuilderV2 = createNewMessageBuilderV2[*FlushMessageHeader, *FlushMessageBody]() + NewBeginTxnMessageBuilderV2 = createNewMessageBuilderV2[*BeginTxnMessageHeader, *BeginTxnMessageBody]() + NewCommitTxnMessageBuilderV2 = createNewMessageBuilderV2[*CommitTxnMessageHeader, *CommitTxnMessageBody]() + NewRollbackTxnMessageBuilderV2 = createNewMessageBuilderV2[*RollbackTxnMessageHeader, *RollbackTxnMessageBody]() + newTxnMessageBuilderV2 = createNewMessageBuilderV2[*TxnMessageHeader, *TxnMessageBody]() ) // createNewMessageBuilderV1 creates a new message builder with v1 marker. 
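For reference, a minimal sketch of driving the new V2 txn builders added above, mirroring the test usage later in this patch; the package name, the "vchan" vchannel, and the keepalive value are illustrative only, not part of the patch:

package example

import (
	"github.com/milvus-io/milvus/pkg/streaming/util/message"
)

// newExampleBeginTxn is a hypothetical helper: it builds a mutable BeginTxn
// message with the V2 builder, following the WithHeader/WithBody pattern
// used by the existing builders.
func newExampleBeginTxn() (message.MutableMessage, error) {
	return message.NewBeginTxnMessageBuilderV2().
		WithVChannel("vchan").
		WithHeader(&message.BeginTxnMessageHeader{KeepaliveMilliseconds: 1000}).
		WithBody(&message.BeginTxnMessageBody{}).
		BuildMutable()
}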
@@ -143,7 +149,7 @@ func (b *mutableMesasgeBuilder[H, B]) BuildMutable() (MutableMessage, error) {
 	if err != nil {
 		return nil, errors.Wrap(err, "failed to encode header")
 	}
-	b.properties.Set(messageSpecialiedHeader, sp)
+	b.properties.Set(messageHeader, sp)
 
 	payload, err := proto.Marshal(b.body)
 	if err != nil {
@@ -154,3 +160,73 @@ func (b *mutableMesasgeBuilder[H, B]) BuildMutable() (MutableMessage, error) {
 		properties: b.properties,
 	}, nil
 }
+
+// NewImmutableTxnMessageBuilder creates a new txn builder.
+func NewImmutableTxnMessageBuilder(begin ImmutableBeginTxnMessageV2) *ImmutableTxnMessageBuilder {
+	return &ImmutableTxnMessageBuilder{
+		txnCtx:   *begin.TxnContext(),
+		begin:    begin,
+		messages: make([]ImmutableMessage, 0),
+	}
+}
+
+// ImmutableTxnMessageBuilder is a builder for txn messages.
+type ImmutableTxnMessageBuilder struct {
+	txnCtx   TxnContext
+	begin    ImmutableBeginTxnMessageV2
+	messages []ImmutableMessage
+}
+
+// ExpiredTimeTick returns the expired time tick of the txn.
+func (b *ImmutableTxnMessageBuilder) ExpiredTimeTick() uint64 {
+	if len(b.messages) > 0 {
+		return tsoutil.AddPhysicalDurationOnTs(b.messages[len(b.messages)-1].TimeTick(), b.txnCtx.Keepalive)
+	}
+	return tsoutil.AddPhysicalDurationOnTs(b.begin.TimeTick(), b.txnCtx.Keepalive)
+}
+
+// Add pushes a message into the txn builder.
+func (b *ImmutableTxnMessageBuilder) Add(msg ImmutableMessage) *ImmutableTxnMessageBuilder {
+	b.messages = append(b.messages, msg)
+	return b
+}
+
+// Build builds a txn message.
+func (b *ImmutableTxnMessageBuilder) Build(commit ImmutableCommitTxnMessageV2) (ImmutableTxnMessage, error) {
+	msg, err := newImmutableTxnMesasgeFromWAL(b.begin, b.messages, commit)
+	b.begin = nil
+	b.messages = nil
+	return msg, err
+}
+
+// newImmutableTxnMesasgeFromWAL creates a new immutable transaction message.
+func newImmutableTxnMesasgeFromWAL(
+	begin ImmutableBeginTxnMessageV2,
+	body []ImmutableMessage,
+	commit ImmutableCommitTxnMessageV2,
+) (ImmutableTxnMessage, error) {
+	// combine begin and commit messages into one.
+	msg, err := newTxnMessageBuilderV2().
+		WithHeader(&TxnMessageHeader{}).
+		WithBody(&TxnMessageBody{}).
+		WithVChannel(begin.VChannel()).
+		BuildMutable()
+	if err != nil {
+		return nil, err
+	}
+	// we don't modify the begin message's timetick, but we overwrite the
+	// timetick of every body message.
+	for _, m := range body {
+		m.(*immutableMessageImpl).overwriteTimeTick(commit.TimeTick())
+		m.(*immutableMessageImpl).overwriteLastConfirmedMessageID(commit.LastConfirmedMessageID())
+	}
+	immutableMsg := msg.WithTimeTick(commit.TimeTick()).
+		WithLastConfirmed(commit.LastConfirmedMessageID()).
+		WithTxnContext(*commit.TxnContext()).
+		IntoImmutableMessage(commit.MessageID())
+	return &immutableTxnMessageImpl{
+		immutableMessageImpl: *immutableMsg.(*immutableMessageImpl),
+		begin:                begin,
+		messages:             body,
+		commit:               commit,
+	}, nil
+}
diff --git a/pkg/streaming/util/message/message.go b/pkg/streaming/util/message/message.go
index 905ffcd4be..733ed568d8 100644
--- a/pkg/streaming/util/message/message.go
+++ b/pkg/streaming/util/message/message.go
@@ -3,9 +3,10 @@ package message
 import "google.golang.org/protobuf/proto"
 
 var (
-	_ BasicMessage     = (*messageImpl)(nil)
-	_ MutableMessage   = (*messageImpl)(nil)
-	_ ImmutableMessage = (*immutableMessageImpl)(nil)
+	_ BasicMessage        = (*messageImpl)(nil)
+	_ MutableMessage      = (*messageImpl)(nil)
+	_ ImmutableMessage    = (*immutableMessageImpl)(nil)
+	_ ImmutableTxnMessage = (*immutableTxnMessageImpl)(nil)
 )
 
 // BasicMessage is the basic interface of message.
@@ -37,6 +38,13 @@ type BasicMessage interface {
 	// Available only when the message's version greater than 0.
 	// Otherwise, it will panic.
 	TimeTick() uint64
+
+	// BarrierTimeTick returns the barrier time tick of current message.
+	// 0 by default, no fence.
+	BarrierTimeTick() uint64
+
+	// TxnContext returns the transaction context of current message.
+	TxnContext() *TxnContext
 }
 
 // MutableMessage is the mutable message interface.
@@ -44,17 +52,32 @@ type BasicMessage interface {
 type MutableMessage interface {
 	BasicMessage
 
+	// WithBarrierTimeTick sets the barrier time tick of current message.
+	// The barrier time tick promises that the message will be sent after that
+	// time tick, and no message with a smaller time tick will ever be appended
+	// concurrently with it.
+	// !!! preserved for streaming system internal usage, don't call it outside of streaming system.
+	WithBarrierTimeTick(tt uint64) MutableMessage
+
+	// WithWALTerm sets the wal term of current message.
+	// !!! preserved for streaming system internal usage, don't call it outside of streaming system.
+	WithWALTerm(term int64) MutableMessage
+
 	// WithLastConfirmed sets the last confirmed message id of current message.
 	// !!! preserved for streaming system internal usage, don't call it outside of streaming system.
 	WithLastConfirmed(id MessageID) MutableMessage
 
 	// WithLastConfirmedUseMessageID sets the last confirmed message id of current message to be the same as message id.
+	// !!! preserved for streaming system internal usage, don't call it outside of streaming system.
 	WithLastConfirmedUseMessageID() MutableMessage
 
 	// WithTimeTick sets the time tick of current message.
 	// !!! preserved for streaming system internal usage, don't call it outside of streaming system.
 	WithTimeTick(tt uint64) MutableMessage
 
+	// WithTxnContext sets the transaction context of current message.
+	// !!! preserved for streaming system internal usage, don't call it outside of streaming system.
+	WithTxnContext(txnCtx TxnContext) MutableMessage
+
 	// IntoImmutableMessage converts the mutable message to immutable message.
 	IntoImmutableMessage(msgID MessageID) ImmutableMessage
 }
@@ -78,6 +101,26 @@ type ImmutableMessage interface {
 	LastConfirmedMessageID() MessageID
 }
 
+// ImmutableTxnMessage is the read-only transaction message interface.
+// Once a transaction is committed, the wal will generate a transaction message.
+// MessageType() always returns MessageTypeTxn for a transaction message.
+type ImmutableTxnMessage interface {
+	ImmutableMessage
+
+	// Begin returns the begin message of the transaction.
+	Begin() ImmutableMessage
+
+	// Commit returns the commit message of the transaction.
+	Commit() ImmutableMessage
+
+	// RangeOver iterates over the underlying messages in the transaction.
+	// If the visitor returns a non-nil error, the iteration stops.
+	RangeOver(visitor func(ImmutableMessage) error) error
+
+	// Size returns the number of messages in the transaction.
+	Size() int
+}
+
 // specializedMutableMessage is the specialized mutable message interface.
 type specializedMutableMessage[H proto.Message, B proto.Message] interface {
 	BasicMessage
diff --git a/pkg/streaming/util/message/message_builder_test.go b/pkg/streaming/util/message/message_builder_test.go
index 030438a07f..5c2a503392 100644
--- a/pkg/streaming/util/message/message_builder_test.go
+++ b/pkg/streaming/util/message/message_builder_test.go
@@ -11,6 +11,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
 	"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest"
 )
 
 func TestMessage(t *testing.T) {
@@ -32,14 +33,17 @@ func TestMessage(t *testing.T) {
 	assert.Equal(t, "value", v)
 	assert.True(t, ok)
 	assert.Equal(t, message.MessageTypeTimeTick, mutableMessage.MessageType())
-	assert.Equal(t, 32, mutableMessage.EstimateSize())
+	assert.Equal(t, 31, mutableMessage.EstimateSize())
 	mutableMessage.WithTimeTick(123)
+	mutableMessage.WithBarrierTimeTick(456)
+	mutableMessage.WithWALTerm(1)
 	v, ok = mutableMessage.Properties().Get("_tt")
 	assert.True(t, ok)
 	tt, err := message.DecodeUint64(v)
 	assert.Equal(t, uint64(123), tt)
 	assert.NoError(t, err)
 	assert.Equal(t, uint64(123), mutableMessage.TimeTick())
+	assert.Equal(t, uint64(456), mutableMessage.BarrierTimeTick())
 
 	lcMsgID := mock_message.NewMockMessageID(t)
 	lcMsgID.EXPECT().Marshal().Return("lcMsgID")
@@ -113,3 +117,16 @@ func TestMessage(t *testing.T) {
 		message.NewTimeTickMessageBuilderV1().BuildMutable()
 	})
 }
+
+func TestLastConfirmed(t *testing.T) {
+	flush, _ := message.NewFlushMessageBuilderV2().
+		WithVChannel("vchan").
+		WithHeader(&message.FlushMessageHeader{}).
+		WithBody(&message.FlushMessageBody{}).
+		BuildMutable()
+
+	imFlush := flush.WithTimeTick(1).
+		WithLastConfirmedUseMessageID().
+		IntoImmutableMessage(walimplstest.NewTestMessageID(1))
+	assert.True(t, imFlush.LastConfirmedMessageID().EQ(walimplstest.NewTestMessageID(1)))
+}
diff --git a/pkg/streaming/util/message/message_impl.go b/pkg/streaming/util/message/message_impl.go
index be237e680a..7aeba5e61b 100644
--- a/pkg/streaming/util/message/message_impl.go
+++ b/pkg/streaming/util/message/message_impl.go
@@ -2,6 +2,8 @@ package message
 
 import (
 	"fmt"
+
+	"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
 )
 
 type messageImpl struct {
@@ -43,12 +45,21 @@ func (m *messageImpl) EstimateSize() int {
 	return len(m.payload) + m.properties.EstimateSize()
 }
 
-// WithVChannel sets the virtual channel of current message.
-func (m *messageImpl) WithVChannel(vChannel string) MutableMessage {
-	if m.properties.Exist(messageVChannel) {
-		panic("vchannel already set in properties of message")
+// WithBarrierTimeTick sets the barrier time tick of current message.
+func (m *messageImpl) WithBarrierTimeTick(tt uint64) MutableMessage { + if m.properties.Exist(messageBarrierTimeTick) { + panic("barrier time tick already set in properties of message") } - m.properties.Set(messageVChannel, vChannel) + m.properties.Set(messageBarrierTimeTick, EncodeUint64(tt)) + return m +} + +// WithWALTerm sets the wal term of current message. +func (m *messageImpl) WithWALTerm(term int64) MutableMessage { + if m.properties.Exist(messageWALTerm) { + panic("wal term already set in properties of message") + } + m.properties.Set(messageWALTerm, EncodeInt64(term)) return m } @@ -63,6 +74,9 @@ func (m *messageImpl) WithTimeTick(tt uint64) MutableMessage { // WithLastConfirmed sets the last confirmed message id of current message. func (m *messageImpl) WithLastConfirmed(id MessageID) MutableMessage { + if m.properties.Exist(messageLastConfirmedIDSameWithMessageID) { + panic("last confirmed message already set in properties of message") + } if m.properties.Exist(messageLastConfirmed) { panic("last confirmed message already set in properties of message") } @@ -72,7 +86,23 @@ func (m *messageImpl) WithLastConfirmed(id MessageID) MutableMessage { // WithLastConfirmedUseMessageID sets the last confirmed message id of current message to be the same as message id. func (m *messageImpl) WithLastConfirmedUseMessageID() MutableMessage { - m.properties.Set(messageLastConfirmed, messageLastConfirmedValueUseMessageID) + if m.properties.Exist(messageLastConfirmedIDSameWithMessageID) { + panic("last confirmed message already set in properties of message") + } + if m.properties.Exist(messageLastConfirmed) { + panic("last confirmed message already set in properties of message") + } + m.properties.Set(messageLastConfirmedIDSameWithMessageID, "") + return m +} + +// WithTxnContext sets the transaction context of current message. +func (m *messageImpl) WithTxnContext(txnCtx TxnContext) MutableMessage { + pb, err := EncodeProto(txnCtx.IntoProto()) + if err != nil { + panic("should not happen on txn proto") + } + m.properties.Set(messageTxnContext, pb) return m } @@ -84,6 +114,19 @@ func (m *messageImpl) IntoImmutableMessage(id MessageID) ImmutableMessage { } } +// TxnContext returns the transaction context of current message. +func (m *messageImpl) TxnContext() *TxnContext { + value, ok := m.properties.Get(messageTxnContext) + if !ok { + return nil + } + txnCtx := &messagespb.TxnContext{} + if err := DecodeProto(value, txnCtx); err != nil { + panic(fmt.Sprintf("there's a bug in the message codes, dirty txn context %s in properties of message", value)) + } + return NewTxnContextFromProto(txnCtx) +} + // TimeTick returns the time tick of current message. func (m *messageImpl) TimeTick() uint64 { value, ok := m.properties.Get(messageTimeTick) @@ -97,6 +140,19 @@ func (m *messageImpl) TimeTick() uint64 { return tt } +// BarrierTimeTick returns the barrier time tick of current message. +func (m *messageImpl) BarrierTimeTick() uint64 { + value, ok := m.properties.Get(messageBarrierTimeTick) + if !ok { + return 0 + } + tt, err := DecodeUint64(value) + if err != nil { + panic(fmt.Sprintf("there's a bug in the message codes, dirty barrier timetick %s in properties of message", value)) + } + return tt +} + // VChannel returns the vchannel of current message. // If the message is broadcasted, the vchannel will be empty. 
 func (m *messageImpl) VChannel() string {
@@ -123,16 +179,63 @@ func (m *immutableMessageImpl) MessageID() MessageID {
 }
 
 func (m *immutableMessageImpl) LastConfirmedMessageID() MessageID {
+	// same with message id
+	if _, ok := m.properties.Get(messageLastConfirmedIDSameWithMessageID); ok {
+		return m.MessageID()
+	}
 	value, ok := m.properties.Get(messageLastConfirmed)
 	if !ok {
 		panic(fmt.Sprintf("there's a bug in the message codes, last confirmed message lost in properties of message, id: %+v", m.id))
 	}
-	if value == messageLastConfirmedValueUseMessageID {
-		return m.MessageID()
-	}
 	id, err := UnmarshalMessageID(m.id.WALName(), value)
 	if err != nil {
 		panic(fmt.Sprintf("there's a bug in the message codes, dirty last confirmed message in properties of message, id: %+v", m.id))
 	}
 	return id
 }
+
+// overwriteTimeTick overwrites the time tick of current message.
+func (m *immutableMessageImpl) overwriteTimeTick(timetick uint64) {
+	m.properties.Delete(messageTimeTick)
+	m.WithTimeTick(timetick)
+}
+
+// overwriteLastConfirmedMessageID overwrites the last confirmed message id of current message.
+func (m *immutableMessageImpl) overwriteLastConfirmedMessageID(id MessageID) {
+	m.properties.Delete(messageLastConfirmed)
+	m.properties.Delete(messageLastConfirmedIDSameWithMessageID)
+	m.WithLastConfirmed(id)
+}
+
+// immutableTxnMessageImpl is an immutable transaction message.
+type immutableTxnMessageImpl struct {
+	immutableMessageImpl
+	begin    ImmutableMessage
+	messages []ImmutableMessage // the messages wrapped by the transaction message.
+	commit   ImmutableMessage
+}
+
+// Begin returns the begin message of the transaction message.
+func (m *immutableTxnMessageImpl) Begin() ImmutableMessage {
+	return m.begin
+}
+
+// RangeOver iterates over the underlying messages in the transaction message.
+func (m *immutableTxnMessageImpl) RangeOver(fn func(ImmutableMessage) error) error {
+	for _, msg := range m.messages {
+		if err := fn(msg); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// Commit returns the commit message of the transaction message.
+func (m *immutableTxnMessageImpl) Commit() ImmutableMessage {
+	return m.commit
+}
+
+// Size returns the number of messages in the transaction message.
+func (m *immutableTxnMessageImpl) Size() int {
+	return len(m.messages)
+}
diff --git a/pkg/streaming/util/message/message_type.go b/pkg/streaming/util/message/message_type.go
index 0d2d8c90f4..3f102b9447 100644
--- a/pkg/streaming/util/message/message_type.go
+++ b/pkg/streaming/util/message/message_type.go
@@ -18,6 +18,10 @@ const (
 	MessageTypeDropCollection  MessageType = MessageType(messagespb.MessageType_DropCollection)
 	MessageTypeCreatePartition MessageType = MessageType(messagespb.MessageType_CreatePartition)
 	MessageTypeDropPartition   MessageType = MessageType(messagespb.MessageType_DropPartition)
+	MessageTypeTxn             MessageType = MessageType(messagespb.MessageType_Txn)
+	MessageTypeBeginTxn        MessageType = MessageType(messagespb.MessageType_BeginTxn)
+	MessageTypeCommitTxn       MessageType = MessageType(messagespb.MessageType_CommitTxn)
+	MessageTypeRollbackTxn     MessageType = MessageType(messagespb.MessageType_RollbackTxn)
 )
 
 var messageTypeName = map[MessageType]string{
@@ -30,6 +34,10 @@ var messageTypeName = map[MessageType]string{
 	MessageTypeDropCollection:  "DROP_COLLECTION",
 	MessageTypeCreatePartition: "CREATE_PARTITION",
 	MessageTypeDropPartition:   "DROP_PARTITION",
+	MessageTypeTxn:             "TXN",
+	MessageTypeBeginTxn:        "BEGIN_TXN",
+	MessageTypeCommitTxn:       "COMMIT_TXN",
+	MessageTypeRollbackTxn:     "ROLLBACK_TXN",
 }
 
 // String implements fmt.Stringer interface.
@@ -48,6 +56,12 @@ func (t MessageType) Valid() bool {
 	return t != MessageTypeUnknown && ok
 }
 
+// IsSystem checks if the MessageType is a system type.
+func (t MessageType) IsSystem() bool {
+	_, ok := systemMessageType[t]
+	return ok
+}
+
 // unmarshalMessageType unmarshal MessageType from string.
 func unmarshalMessageType(s string) MessageType {
 	i, err := strconv.ParseInt(s, 10, 32)
diff --git a/pkg/streaming/util/message/properties.go b/pkg/streaming/util/message/properties.go
index 10f02b8165..575c7d2146 100644
--- a/pkg/streaming/util/message/properties.go
+++ b/pkg/streaming/util/message/properties.go
@@ -2,17 +2,16 @@ package message
 
 const (
 	// preserved properties
-	messageVersion          = "_v"  // message version for compatibility.
-	messageTypeKey          = "_t"  // message type key.
-	messageTimeTick         = "_tt" // message time tick.
-	messageLastConfirmed    = "_lc" // message last confirmed message id.
-	messageVChannel         = "_vc" // message virtual channel.
-	messageSpecialiedHeader = "_sh" // specialized message header.
-)
-
-const (
-	messageLastConfirmedValueUseMessageID = "use_message_id" // message last confirmed message id is same with message id.
-	// some message type can not set last confirmed message id, but can use the message id as last confirmed id.
+	messageVersion                          = "_v"   // message version for compatibility, see `Version` for more information.
+	messageWALTerm                          = "_wt"  // wal term of a message, always increasing in MessageID order, should never roll back.
+	messageTypeKey                          = "_t"   // message type key.
+	messageTimeTick                         = "_tt"  // message time tick.
+	messageBarrierTimeTick                  = "_btt" // message barrier time tick.
+	messageLastConfirmed                    = "_lc"  // message last confirmed message id.
+	messageLastConfirmedIDSameWithMessageID = "_lcs" // message last confirmed message id is the same as the message id.
+	messageVChannel                         = "_vc"  // message virtual channel.
+	messageHeader                           = "_h"   // specialized message header.
+	messageTxnContext                       = "_tx"  // transaction context.
 )
 
 var (
@@ -57,6 +56,10 @@ func (prop propertiesImpl) Set(key, value string) {
 	prop[key] = value
 }
 
+func (prop propertiesImpl) Delete(key string) {
+	delete(prop, key)
+}
+
 func (prop propertiesImpl) ToRawMap() map[string]string {
 	return map[string]string(prop)
 }
diff --git a/pkg/streaming/util/message/specialized_message.go b/pkg/streaming/util/message/specialized_message.go
index 2958468357..050ec53c38 100644
--- a/pkg/streaming/util/message/specialized_message.go
+++ b/pkg/streaming/util/message/specialized_message.go
@@ -22,7 +22,18 @@ type (
 	CreatePartitionMessageHeader = messagespb.CreatePartitionMessageHeader
 	DropPartitionMessageHeader   = messagespb.DropPartitionMessageHeader
 	FlushMessageHeader           = messagespb.FlushMessageHeader
-	FlushMessageBody             = messagespb.FlushMessageBody
+	BeginTxnMessageHeader        = messagespb.BeginTxnMessageHeader
+	CommitTxnMessageHeader       = messagespb.CommitTxnMessageHeader
+	RollbackTxnMessageHeader     = messagespb.RollbackTxnMessageHeader
+	TxnMessageHeader             = messagespb.TxnMessageHeader
+)
+
+type (
+	FlushMessageBody       = messagespb.FlushMessageBody
+	BeginTxnMessageBody    = messagespb.BeginTxnMessageBody
+	CommitTxnMessageBody   = messagespb.CommitTxnMessageBody
+	RollbackTxnMessageBody = messagespb.RollbackTxnMessageBody
+	TxnMessageBody         = messagespb.TxnMessageBody
 )
 
 // messageTypeMap maps the proto message type to the message type.
@@ -35,6 +46,19 @@ var messageTypeMap = map[reflect.Type]MessageType{
 	reflect.TypeOf(&CreatePartitionMessageHeader{}): MessageTypeCreatePartition,
 	reflect.TypeOf(&DropPartitionMessageHeader{}):   MessageTypeDropPartition,
 	reflect.TypeOf(&FlushMessageHeader{}):           MessageTypeFlush,
+	reflect.TypeOf(&BeginTxnMessageHeader{}):        MessageTypeBeginTxn,
+	reflect.TypeOf(&CommitTxnMessageHeader{}):       MessageTypeCommitTxn,
+	reflect.TypeOf(&RollbackTxnMessageHeader{}):     MessageTypeRollbackTxn,
+	reflect.TypeOf(&TxnMessageHeader{}):             MessageTypeTxn,
+}
+
+// systemMessageType lists system-reserved message types; they must not be
+// provided from outside the streaming system.
+var systemMessageType = map[MessageType]struct{}{
+	MessageTypeTimeTick:    {},
+	MessageTypeBeginTxn:    {},
+	MessageTypeCommitTxn:   {},
+	MessageTypeRollbackTxn: {},
+	MessageTypeTxn:         {},
 }
 
 // List all specialized message types.
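A sketch of the guard that the systemMessageType reservation above implies, using the IsSystem helper added in message_type.go; validateExternalMessage is a hypothetical helper, not part of this patch:

package example

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/streaming/util/message"
)

// validateExternalMessage is a hypothetical boundary check: it rejects the
// system-reserved message types (time tick and the txn control messages)
// when they arrive from outside the streaming system.
func validateExternalMessage(msg message.MutableMessage) error {
	if msg.MessageType().IsSystem() {
		return fmt.Errorf("message type %s is reserved for streaming system internal usage", msg.MessageType())
	}
	return nil
}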
@@ -47,6 +71,9 @@ type ( MutableCreatePartitionMessageV1 = specializedMutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] MutableDropPartitionMessageV1 = specializedMutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] MutableFlushMessageV2 = specializedMutableMessage[*FlushMessageHeader, *FlushMessageBody] + MutableBeginTxnMessageV2 = specializedMutableMessage[*BeginTxnMessageHeader, *BeginTxnMessageBody] + MutableCommitTxnMessageV2 = specializedMutableMessage[*CommitTxnMessageHeader, *CommitTxnMessageBody] + MutableRollbackTxnMessageV2 = specializedMutableMessage[*RollbackTxnMessageHeader, *RollbackTxnMessageBody] ImmutableTimeTickMessageV1 = specializedImmutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg] ImmutableInsertMessageV1 = specializedImmutableMessage[*InsertMessageHeader, *msgpb.InsertRequest] @@ -56,6 +83,9 @@ type ( ImmutableCreatePartitionMessageV1 = specializedImmutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] ImmutableDropPartitionMessageV1 = specializedImmutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] ImmutableFlushMessageV2 = specializedImmutableMessage[*FlushMessageHeader, *FlushMessageBody] + ImmutableBeginTxnMessageV2 = specializedImmutableMessage[*BeginTxnMessageHeader, *BeginTxnMessageBody] + ImmutableCommitTxnMessageV2 = specializedImmutableMessage[*CommitTxnMessageHeader, *CommitTxnMessageBody] + ImmutableRollbackTxnMessageV2 = specializedImmutableMessage[*RollbackTxnMessageHeader, *RollbackTxnMessageBody] ) // List all as functions for specialized messages. @@ -68,6 +98,9 @@ var ( AsMutableCreatePartitionMessageV1 = asSpecializedMutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] AsMutableDropPartitionMessageV1 = asSpecializedMutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] AsMutableFlushMessageV2 = asSpecializedMutableMessage[*FlushMessageHeader, *FlushMessageBody] + AsMutableBeginTxnMessageV2 = asSpecializedMutableMessage[*BeginTxnMessageHeader, *BeginTxnMessageBody] + AsMutableCommitTxnMessageV2 = asSpecializedMutableMessage[*CommitTxnMessageHeader, *CommitTxnMessageBody] + AsMutableRollbackTxnMessageV2 = asSpecializedMutableMessage[*RollbackTxnMessageHeader, *RollbackTxnMessageBody] AsImmutableTimeTickMessageV1 = asSpecializedImmutableMessage[*TimeTickMessageHeader, *msgpb.TimeTickMsg] AsImmutableInsertMessageV1 = asSpecializedImmutableMessage[*InsertMessageHeader, *msgpb.InsertRequest] @@ -77,6 +110,16 @@ var ( AsImmutableCreatePartitionMessageV1 = asSpecializedImmutableMessage[*CreatePartitionMessageHeader, *msgpb.CreatePartitionRequest] AsImmutableDropPartitionMessageV1 = asSpecializedImmutableMessage[*DropPartitionMessageHeader, *msgpb.DropPartitionRequest] AsImmutableFlushMessageV2 = asSpecializedImmutableMessage[*FlushMessageHeader, *FlushMessageBody] + AsImmutableBeginTxnMessageV2 = asSpecializedImmutableMessage[*BeginTxnMessageHeader, *BeginTxnMessageBody] + AsImmutableCommitTxnMessageV2 = asSpecializedImmutableMessage[*CommitTxnMessageHeader, *CommitTxnMessageBody] + AsImmutableRollbackTxnMessageV2 = asSpecializedImmutableMessage[*RollbackTxnMessageHeader, *RollbackTxnMessageBody] + AsImmutableTxnMessage = func(msg ImmutableMessage) ImmutableTxnMessage { + underlying, ok := msg.(*immutableTxnMessageImpl) + if !ok { + return nil + } + return underlying + } ) // asSpecializedMutableMessage converts a MutableMessage to a specialized MutableMessage. 
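On the consume side, the new AsImmutableTxnMessage helper above can be used to unpack a committed transaction; a minimal sketch, where dispatch and handle are hypothetical names:

package example

import (
	"github.com/milvus-io/milvus/pkg/streaming/util/message"
)

// dispatch is a hypothetical consume-side helper: AsImmutableTxnMessage
// returns nil for non-txn messages, so ordinary messages fall through to
// the plain path.
func dispatch(msg message.ImmutableMessage, handle func(message.ImmutableMessage) error) error {
	if txnMsg := message.AsImmutableTxnMessage(msg); txnMsg != nil {
		// Each wrapped message already carries the commit message's time tick
		// and last confirmed message id (see the builder change above).
		return txnMsg.RangeOver(handle)
	}
	return handle(msg)
}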
@@ -94,7 +137,7 @@ func asSpecializedMutableMessage[H proto.Message, B proto.Message](msg MutableMe
 	}
 
 	// Get the specialized header from the message.
-	val, ok := underlying.properties.Get(messageSpecialiedHeader)
+	val, ok := underlying.properties.Get(messageHeader)
 	if !ok {
 		return nil, errors.Errorf("lost specialized header, %s", msgType.String())
 	}
@@ -120,7 +163,11 @@ func asSpecializedMutableMessage[H proto.Message, B proto.Message](msg MutableMe
 // Return nil, error if the message is the target specialized message but failed to decode the specialized header.
 // Return asSpecializedImmutableMessage, nil if the message is the target specialized message and successfully decoded the specialized header.
 func asSpecializedImmutableMessage[H proto.Message, B proto.Message](msg ImmutableMessage) (specializedImmutableMessage[H, B], error) {
-	underlying := msg.(*immutableMessageImpl)
+	underlying, ok := msg.(*immutableMessageImpl)
+	if !ok {
+		// not a plain immutable message; it may be a txn message.
+		return nil, nil
+	}
 
 	var header H
 	msgType := mustGetMessageTypeFromHeader(header)
@@ -130,7 +177,7 @@ func asSpecializedImmutableMessage[H proto.Message, B proto.Message](msg Immutab
 	}
 
 	// Get the specialized header from the message.
-	val, ok := underlying.properties.Get(messageSpecialiedHeader)
+	val, ok := underlying.properties.Get(messageHeader)
 	if !ok {
 		return nil, errors.Errorf("lost specialized header, %s", msgType.String())
 	}
@@ -184,7 +231,7 @@ func (m *specializedMutableMessageImpl[H, B]) OverwriteHeader(header H) {
 	if err != nil {
 		panic(fmt.Sprintf("failed to encode insert header, there's a bug, %+v, %s", m.header, err.Error()))
 	}
-	m.messageImpl.properties.Set(messageSpecialiedHeader, newHeader)
+	m.messageImpl.properties.Set(messageHeader, newHeader)
 }
 
 // specializedImmutableMessageImpl is the specialized immutable message implementation.
diff --git a/pkg/streaming/util/message/txn.go b/pkg/streaming/util/message/txn.go
new file mode 100644
index 0000000000..150f92634a
--- /dev/null
+++ b/pkg/streaming/util/message/txn.go
@@ -0,0 +1,51 @@
+package message
+
+import (
+	"time"
+
+	"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
+)
+
+type (
+	TxnState = messagespb.TxnState
+	TxnID    int64
+)
+
+const (
+	TxnStateBegin      TxnState = messagespb.TxnState_TxnBegin
+	TxnStateInFlight   TxnState = messagespb.TxnState_TxnInFlight
+	TxnStateOnCommit   TxnState = messagespb.TxnState_TxnOnCommit
+	TxnStateCommitted  TxnState = messagespb.TxnState_TxnCommitted
+	TxnStateOnRollback TxnState = messagespb.TxnState_TxnOnRollback
+	TxnStateRollbacked TxnState = messagespb.TxnState_TxnRollbacked
+
+	NonTxnID = TxnID(-1)
+)
+
+// NewTxnContextFromProto generates TxnContext from proto message.
+func NewTxnContextFromProto(proto *messagespb.TxnContext) *TxnContext {
+	if proto == nil {
+		return nil
+	}
+	return &TxnContext{
+		TxnID:     TxnID(proto.TxnId),
+		Keepalive: time.Duration(proto.KeepaliveMilliseconds) * time.Millisecond,
+	}
+}
+
+// TxnContext is the transaction context for message.
+type TxnContext struct {
+	TxnID     TxnID
+	Keepalive time.Duration
+}
+
+// IntoProto converts TxnContext to proto message.
+func (t *TxnContext) IntoProto() *messagespb.TxnContext {
+	if t == nil {
+		return nil
+	}
+	return &messagespb.TxnContext{
+		TxnId:                 int64(t.TxnID),
+		KeepaliveMilliseconds: t.Keepalive.Milliseconds(),
+	}
+}
diff --git a/pkg/streaming/util/message/txn_test.go b/pkg/streaming/util/message/txn_test.go
new file mode 100644
index 0000000000..ce22af12b4
--- /dev/null
+++ b/pkg/streaming/util/message/txn_test.go
@@ -0,0 +1,90 @@
+package message_test
+
+import (
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/errors"
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
+	"github.com/milvus-io/milvus/pkg/streaming/walimpls/impls/walimplstest"
+)
+
+func TestTxn(t *testing.T) {
+	txn := message.NewTxnContextFromProto(&messagespb.TxnContext{
+		TxnId:                 1,
+		KeepaliveMilliseconds: 1000,
+	})
+	assert.Equal(t, message.TxnID(1), txn.TxnID)
+	assert.Equal(t, time.Second, txn.Keepalive)
+
+	assert.Equal(t, int64(1), txn.IntoProto().TxnId)
+	assert.Equal(t, int64(1000), txn.IntoProto().KeepaliveMilliseconds)
+
+	txn = message.NewTxnContextFromProto(nil)
+	assert.Nil(t, txn)
+}
+
+func TestAsImmutableTxnMessage(t *testing.T) {
+	txnCtx := message.TxnContext{
+		TxnID:     1,
+		Keepalive: time.Second,
+	}
+	begin, _ := message.NewBeginTxnMessageBuilderV2().
+		WithVChannel("vchan").
+		WithHeader(&message.BeginTxnMessageHeader{}).
+		WithBody(&message.BeginTxnMessageBody{}).
+		BuildMutable()
+	imBegin := begin.WithTxnContext(txnCtx).
+		WithTimeTick(1).
+		WithLastConfirmed(walimplstest.NewTestMessageID(1)).
+		IntoImmutableMessage(walimplstest.NewTestMessageID(1))
+
+	beginMsg, _ := message.AsImmutableBeginTxnMessageV2(imBegin)
+
+	insert, _ := message.NewInsertMessageBuilderV1().
+		WithVChannel("vchan").
+		WithHeader(&message.InsertMessageHeader{}).
+		WithBody(&msgpb.InsertRequest{}).
+		BuildMutable()
+
+	commit, _ := message.NewCommitTxnMessageBuilderV2().
+		WithVChannel("vchan").
+		WithHeader(&message.CommitTxnMessageHeader{}).
+		WithBody(&message.CommitTxnMessageBody{}).
+		BuildMutable()
+
+	imCommit := commit.WithTxnContext(txnCtx).
+		WithTimeTick(3).
+		WithLastConfirmed(walimplstest.NewTestMessageID(3)).
+		IntoImmutableMessage(walimplstest.NewTestMessageID(4))
+
+	commitMsg, _ := message.AsImmutableCommitTxnMessageV2(imCommit)
+
+	txnMsg, err := message.NewImmutableTxnMessageBuilder(beginMsg).
+		Add(insert.WithTimeTick(2).WithTxnContext(txnCtx).IntoImmutableMessage(walimplstest.NewTestMessageID(2))).
+		Build(commitMsg)
+
+	assert.NoError(t, err)
+	assert.NotNil(t, txnMsg)
+	assert.Equal(t, uint64(3), txnMsg.TimeTick())
+	assert.Equal(t, walimplstest.NewTestMessageID(4), txnMsg.MessageID())
+	assert.Equal(t, walimplstest.NewTestMessageID(3), txnMsg.LastConfirmedMessageID())
+	err = txnMsg.RangeOver(func(msg message.ImmutableMessage) error {
+		assert.Equal(t, uint64(3), msg.TimeTick())
+		return nil
+	})
+	assert.NoError(t, err)
+
+	err = txnMsg.RangeOver(func(msg message.ImmutableMessage) error {
+		return errors.New("error")
+	})
+	assert.Error(t, err)
+
+	assert.NotNil(t, txnMsg.Commit())
+	assert.Equal(t, 1, txnMsg.Size())
+	assert.NotNil(t, txnMsg.Begin())
+}
diff --git a/pkg/streaming/util/message/version.go b/pkg/streaming/util/message/version.go
index 4cbfa25736..502f7042f6 100644
--- a/pkg/streaming/util/message/version.go
+++ b/pkg/streaming/util/message/version.go
@@ -5,7 +5,7 @@ import "strconv"
 var (
 	VersionOld Version = 0 // old version before streamingnode.
 	VersionV1  Version = 1 // The message marshal unmarshal still use msgstream.
-	VersionV2  Version = 2 // The message marshal unmsarhsl is not rely on msgstream.
+	VersionV2  Version = 2 // The message marshal/unmarshal no longer relies on msgstream.
 )
 
 type Version int // message version for compatibility.
diff --git a/pkg/streaming/util/options/deliver.go b/pkg/streaming/util/options/deliver.go
index 4e02aff467..6e5bdd9023 100644
--- a/pkg/streaming/util/options/deliver.go
+++ b/pkg/streaming/util/options/deliver.go
@@ -1,6 +1,8 @@
 package options
 
 import (
+	"fmt"
+
 	"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
 	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
@@ -104,6 +106,9 @@ func DeliverFilterVChannel(vchannel string) DeliverFilter {
 func DeliverFilterMessageType(messageType ...message.MessageType) DeliverFilter {
 	messageTypes := make([]messagespb.MessageType, 0, len(messageType))
 	for _, mt := range messageType {
+		if mt.IsSystem() {
+			panic(fmt.Sprintf("system message type cannot be filtered, %s", mt.String()))
+		}
 		messageTypes = append(messageTypes, messagespb.MessageType(mt))
 	}
 	return &streamingpb.DeliverFilter{
@@ -126,25 +131,40 @@ func IsDeliverFilterTimeTick(filter DeliverFilter) bool {
 }
 
 // GetFilterFunc returns the filter function.
-func GetFilterFunc(filters []DeliverFilter) (func(message.ImmutableMessage) bool, error) {
+func GetFilterFunc(filters []DeliverFilter) func(message.ImmutableMessage) bool {
 	filterFuncs := make([]func(message.ImmutableMessage) bool, 0, len(filters))
 	for _, filter := range filters {
 		filter := filter
 		switch filter.GetFilter().(type) {
 		case *streamingpb.DeliverFilter_TimeTickGt:
 			filterFuncs = append(filterFuncs, func(im message.ImmutableMessage) bool {
-				return im.TimeTick() > filter.GetTimeTickGt().TimeTick
+				// A txn message's time tick is determined by its commit message,
+				// so only the commit message needs time-tick filtering.
+				if im.TxnContext() == nil || im.MessageType() == message.MessageTypeCommitTxn {
+					return im.TimeTick() > filter.GetTimeTickGt().TimeTick
+				}
+				return true
 			})
 		case *streamingpb.DeliverFilter_TimeTickGte:
 			filterFuncs = append(filterFuncs, func(im message.ImmutableMessage) bool {
-				return im.TimeTick() >= filter.GetTimeTickGte().TimeTick
+				// A txn message's time tick is determined by its commit message,
+				// so only the commit message needs time-tick filtering.
+				if im.TxnContext() == nil || im.MessageType() == message.MessageTypeCommitTxn {
+					return im.TimeTick() >= filter.GetTimeTickGte().TimeTick
+				}
+				return true
 			})
 		case *streamingpb.DeliverFilter_Vchannel:
 			filterFuncs = append(filterFuncs, func(im message.ImmutableMessage) bool {
-				return im.VChannel() == filter.GetVchannel().Vchannel
+				// vchannel == "" is a broadcast operation.
+				return im.VChannel() == "" || im.VChannel() == filter.GetVchannel().Vchannel
 			})
 		case *streamingpb.DeliverFilter_MessageType:
 			filterFuncs = append(filterFuncs, func(im message.ImmutableMessage) bool {
+				// system messages cannot be filtered.
+				if im.MessageType().IsSystem() {
+					return true
+				}
 				for _, mt := range filter.GetMessageType().MessageTypes {
 					if im.MessageType() == message.MessageType(mt) {
 						return true
@@ -163,5 +183,5 @@ func GetFilterFunc(filters []DeliverFilter) (func(message.ImmutableMessage) bool
 			}
 		}
 		return true
-	}, nil
+	}
 }
diff --git a/pkg/streaming/util/options/deliver_test.go b/pkg/streaming/util/options/deliver_test.go
index 7d45880bda..ad8014ae62 100644
--- a/pkg/streaming/util/options/deliver_test.go
+++ b/pkg/streaming/util/options/deliver_test.go
@@ -3,8 +3,11 @@ package options
 
 import (
 	"testing"
 
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus/pkg/mocks/streaming/util/mock_message"
 	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
+	"github.com/milvus-io/milvus/pkg/streaming/util/message"
 )
 
 func TestDeliverPolicy(t *testing.T) {
@@ -32,4 +35,113 @@ func TestDeliverFilter(t *testing.T) {
 
 	filter = DeliverFilterVChannel("vchannel")
 	_ = filter.GetFilter().(*streamingpb.DeliverFilter_Vchannel)
+
+	filter = DeliverFilterMessageType(message.MessageTypeDelete)
+	_ = filter.GetFilter().(*streamingpb.DeliverFilter_MessageType)
+}
+
+func TestNewMessageFilter(t *testing.T) {
+	filters := []DeliverFilter{
+		DeliverFilterTimeTickGT(1),
+		DeliverFilterVChannel("test"),
+	}
+	filterFunc := GetFilterFunc(filters)
+
+	msg := mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(2).Maybe()
+	msg.EXPECT().VChannel().Return("test2").Maybe()
+	msg.EXPECT().TxnContext().Return(nil).Maybe()
+	assert.False(t, filterFunc(msg))
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(1).Maybe()
+	msg.EXPECT().VChannel().Return("test").Maybe()
+	msg.EXPECT().TxnContext().Return(nil).Maybe()
+	assert.False(t, filterFunc(msg))
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(1).Maybe()
+	msg.EXPECT().VChannel().Return("").Maybe() // vchannel == "" should not be filtered.
+	msg.EXPECT().TxnContext().Return(nil).Maybe()
+	assert.False(t, filterFunc(msg))
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(2).Maybe()
+	msg.EXPECT().VChannel().Return("test").Maybe()
+	msg.EXPECT().TxnContext().Return(nil).Maybe()
+	assert.True(t, filterFunc(msg))
+
+	// a txn message is filtered by time tick only when its message type is commit.
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(1).Maybe()
+	msg.EXPECT().VChannel().Return("test").Maybe()
+	msg.EXPECT().TxnContext().Return(&message.TxnContext{}).Maybe()
+	msg.EXPECT().MessageType().Return(message.MessageTypeCommitTxn).Maybe()
+	assert.False(t, filterFunc(msg))
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(1).Maybe()
+	msg.EXPECT().VChannel().Return("test").Maybe()
+	msg.EXPECT().TxnContext().Return(&message.TxnContext{}).Maybe()
+	msg.EXPECT().MessageType().Return(message.MessageTypeInsert).Maybe()
+	assert.True(t, filterFunc(msg))
+
+	// a txn message is filtered by time tick only when its message type is commit.
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(1).Maybe()
+	msg.EXPECT().VChannel().Return("test").Maybe()
+	msg.EXPECT().TxnContext().Return(&message.TxnContext{}).Maybe()
+	msg.EXPECT().MessageType().Return(message.MessageTypeCommitTxn).Maybe()
+	assert.False(t, filterFunc(msg))
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(1).Maybe()
+	msg.EXPECT().VChannel().Return("test").Maybe()
+	msg.EXPECT().TxnContext().Return(&message.TxnContext{}).Maybe()
+	msg.EXPECT().MessageType().Return(message.MessageTypeInsert).Maybe()
+	assert.True(t, filterFunc(msg))
+
+	filters = []*streamingpb.DeliverFilter{
+		DeliverFilterTimeTickGTE(1),
+		DeliverFilterVChannel("test"),
+	}
+	filterFunc = GetFilterFunc(filters)
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(1).Maybe()
+	msg.EXPECT().VChannel().Return("test").Maybe()
+	msg.EXPECT().TxnContext().Return(nil).Maybe()
+	assert.True(t, filterFunc(msg))
+
+	// a txn message is filtered by time tick only when its message type is commit.
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(1).Maybe()
+	msg.EXPECT().VChannel().Return("test").Maybe()
+	msg.EXPECT().TxnContext().Return(&message.TxnContext{}).Maybe()
+	msg.EXPECT().MessageType().Return(message.MessageTypeCommitTxn).Maybe()
+	assert.True(t, filterFunc(msg))
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().TimeTick().Return(1).Maybe()
+	msg.EXPECT().VChannel().Return("test").Maybe()
+	msg.EXPECT().TxnContext().Return(&message.TxnContext{}).Maybe()
+	msg.EXPECT().MessageType().Return(message.MessageTypeInsert).Maybe()
+	assert.True(t, filterFunc(msg))
+
+	filters = []*streamingpb.DeliverFilter{
+		DeliverFilterMessageType(message.MessageTypeInsert, message.MessageTypeDelete),
+	}
+	filterFunc = GetFilterFunc(filters)
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().MessageType().Return(message.MessageTypeInsert).Maybe()
+	assert.True(t, filterFunc(msg))
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().MessageType().Return(message.MessageTypeDelete).Maybe()
+	assert.True(t, filterFunc(msg))
+
+	msg = mock_message.NewMockImmutableMessage(t)
+	msg.EXPECT().MessageType().Return(message.MessageTypeFlush).Maybe()
+	assert.False(t, filterFunc(msg))
 }
diff --git a/pkg/streaming/util/types/streaming_node.go b/pkg/streaming/util/types/streaming_node.go
index 125891aee3..4c6a13e699 100644
--- a/pkg/streaming/util/types/streaming_node.go
+++ b/pkg/streaming/util/types/streaming_node.go
@@ -4,6 +4,8 @@ import (
 	"context"
 
 	"github.com/cockroachdb/errors"
+	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/types/known/anypb"
 
 	"github.com/milvus-io/milvus/pkg/streaming/proto/streamingpb"
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
@@ -88,9 +90,25 @@ func (n *StreamingNodeStatus) ErrorOfNode() error {
 
 // AppendResult is the result of append operation.
 type AppendResult struct {
-	// Message is generated by underlying walimpls.
+	// MessageID is generated by underlying walimpls.
 	MessageID message.MessageID
+
+	// TimeTick is the time tick of the message.
 	// Set by timetick interceptor.
 	TimeTick uint64
+
+	// TxnCtx is the transaction context of the message.
+	// If the message does not belong to a transaction, TxnCtx will be nil.
+	TxnCtx *message.TxnContext
+
+	// Extra is the extra information of the append result.
+	Extra *anypb.Any
+}
+
+// GetExtra unmarshals the extra information into the given message.
+func (r *AppendResult) GetExtra(m proto.Message) error {
+	return anypb.UnmarshalTo(r.Extra, m, proto.UnmarshalOptions{
+		DiscardUnknown: true,
+		AllowPartial:   true,
+	})
 }
diff --git a/pkg/streaming/walimpls/impls/walimplstest/builder.go b/pkg/streaming/walimpls/impls/walimplstest/builder.go
index d66feb98d4..e00ac91354 100644
--- a/pkg/streaming/walimpls/impls/walimplstest/builder.go
+++ b/pkg/streaming/walimpls/impls/walimplstest/builder.go
@@ -10,7 +10,7 @@ import (
 )
 
 const (
-	WALName = "test"
+	WALName = "walimplstest"
 )
 
 func init() {
diff --git a/pkg/streaming/walimpls/impls/walimplstest/message_id.go b/pkg/streaming/walimpls/impls/walimplstest/message_id.go
index b36d775381..afc8eb7ca0 100644
--- a/pkg/streaming/walimpls/impls/walimplstest/message_id.go
+++ b/pkg/streaming/walimpls/impls/walimplstest/message_id.go
@@ -4,6 +4,8 @@
 package walimplstest
 
 import (
+	"strconv"
+
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
 )
 
@@ -25,7 +27,7 @@ func UnmarshalTestMessageID(data string) (message.MessageID, error) {
 
 // unmarshalTestMessageID unmarshals the message id.
 func unmarshalTestMessageID(data string) (testMessageID, error) {
-	id, err := message.DecodeInt64(data)
+	id, err := strconv.ParseUint(data, 10, 64)
 	if err != nil {
 		return 0, err
 	}
@@ -57,5 +59,5 @@ func (id testMessageID) EQ(other message.MessageID) bool {
 
 // Marshal marshals the message id.
 func (id testMessageID) Marshal() string {
-	return message.EncodeInt64(int64(id))
+	return strconv.FormatInt(int64(id), 10)
 }
diff --git a/pkg/streaming/walimpls/impls/walimplstest/message_log.go b/pkg/streaming/walimpls/impls/walimplstest/message_log.go
index 82c713e8c8..818c35c535 100644
--- a/pkg/streaming/walimpls/impls/walimplstest/message_log.go
+++ b/pkg/streaming/walimpls/impls/walimplstest/message_log.go
@@ -5,6 +5,7 @@ package walimplstest
 
 import (
 	"context"
+	"encoding/json"
 	"sync"
 
 	"github.com/milvus-io/milvus/pkg/streaming/util/message"
@@ -24,37 +25,61 @@ func newMessageLog() *messageLog {
 	return &messageLog{
 		cond: syncutil.NewContextCond(&sync.Mutex{}),
 		id:   0,
-		logs: make([]message.ImmutableMessage, 0),
+		logs: make([][]byte, 0),
 	}
 }
 
 type messageLog struct {
 	cond *syncutil.ContextCond
 	id   int64
-	logs []message.ImmutableMessage
+	logs [][]byte
+}
+
+type entry struct {
+	ID         int64
+	Payload    []byte
+	Properties map[string]string
 }
 
 func (l *messageLog) Append(_ context.Context, msg message.MutableMessage) (message.MessageID, error) {
 	l.cond.LockAndBroadcast()
 	defer l.cond.L.Unlock()
-	newMessageID := NewTestMessageID(l.id)
+	id := l.id
+	newEntry := entry{
+		ID:         id,
+		Payload:    msg.Payload(),
+		Properties: msg.Properties().ToRawMap(),
+	}
+	data, err := json.Marshal(newEntry)
+	if err != nil {
+		return nil, err
+	}
+
 	l.id++
-	l.logs = append(l.logs, msg.IntoImmutableMessage(newMessageID))
-	return newMessageID, nil
+	l.logs = append(l.logs, data)
+	return NewTestMessageID(id), nil
}
 
 func (l *messageLog) ReadAt(ctx context.Context, idx int) (message.ImmutableMessage, error) {
-	var msg message.ImmutableMessage
 	l.cond.L.Lock()
+
 	for idx >= len(l.logs) {
 		if err := l.cond.Wait(ctx); err != nil {
 			return nil, err
 		}
 	}
-	msg = l.logs[idx]
-	l.cond.L.Unlock()
+	defer l.cond.L.Unlock()
 
-	return msg, nil
+	data := l.logs[idx]
+	var newEntry entry
+	if err := json.Unmarshal(data, &newEntry); err != nil {
+		return nil, err
+	}
+	return message.NewImmutableMesasge(
+		NewTestMessageID(newEntry.ID),
+		newEntry.Payload,
+		newEntry.Properties,
+	), nil
 }
 
 func (l *messageLog) Len() int64 {
diff --git a/pkg/util/syncutil/context_condition_variable.go b/pkg/util/syncutil/context_condition_variable.go
index 5ca5a4a405..7a8d52eaa4 100644
--- a/pkg/util/syncutil/context_condition_variable.go
+++ b/pkg/util/syncutil/context_condition_variable.go
@@ -33,6 +33,15 @@ func (cv *ContextCond) LockAndBroadcast() {
 	}
 }
 
+// UnsafeBroadcast performs a broadcast without locking.
+// !!! Must be called with the lock held !!!
+func (cv *ContextCond) UnsafeBroadcast() {
+	if cv.ch != nil {
+		close(cv.ch)
+		cv.ch = nil
+	}
+}
+
 // Wait waits for a broadcast or context timeout.
 // It blocks until either a broadcast is received or the context is canceled or times out.
 // Returns an error if the context is canceled or times out.
diff --git a/pkg/util/tsoutil/tso.go b/pkg/util/tsoutil/tso.go
index 20913fbc72..0b3b650f29 100644
--- a/pkg/util/tsoutil/tso.go
+++ b/pkg/util/tsoutil/tso.go
@@ -62,12 +62,6 @@ func ParseHybridTs(ts uint64) (int64, int64) {
 	return int64(physical), int64(logical)
 }
 
-// ParseAndFormatHybridTs parses the ts and returns its human-readable format.
-func ParseAndFormatHybridTs(ts uint64) string {
-	physicalTs, _ := ParseHybridTs(ts)
-	return time.Unix(physicalTs, 0).Format(time.RFC3339) // Convert to RFC3339 format
-}
-
 // CalculateDuration returns the number of milliseconds obtained by subtracting ts2 from ts1.
 func CalculateDuration(ts1, ts2 typeutil.Timestamp) int64 {
 	p1, _ := ParseHybridTs(ts1)
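Editor's note on the new AppendResult.Extra field: a producer packs an arbitrary proto message into an anypb.Any, and a consumer recovers it through GetExtra. A minimal self-contained sketch; messagespb.TxnContext is used as the payload purely for illustration, since this patch confirms it exists as a proto message:

package main

import (
	"fmt"

	"google.golang.org/protobuf/types/known/anypb"

	"github.com/milvus-io/milvus/pkg/streaming/proto/messagespb"
	"github.com/milvus-io/milvus/pkg/streaming/util/types"
)

func main() {
	// Producer side: wrap any proto message into the result's Extra field.
	extra, err := anypb.New(&messagespb.TxnContext{TxnId: 1, KeepaliveMilliseconds: 1000})
	if err != nil {
		panic(err)
	}
	result := &types.AppendResult{Extra: extra}

	// Consumer side: unpack the extra payload with GetExtra.
	var txn messagespb.TxnContext
	if err := result.GetExtra(&txn); err != nil {
		panic(err)
	}
	fmt.Println(txn.TxnId) // prints 1
}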