diff --git a/internal/streamingnode/server/wal/adaptor/wal_test.go b/internal/streamingnode/server/wal/adaptor/wal_test.go index 02aa63a398..e28cccbe45 100644 --- a/internal/streamingnode/server/wal/adaptor/wal_test.go +++ b/internal/streamingnode/server/wal/adaptor/wal_test.go @@ -12,8 +12,8 @@ import ( "github.com/cockroachdb/errors" "github.com/remeh/sizedwaitgroup" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" @@ -49,8 +49,8 @@ type walTestFramework struct { } func TestFencedError(t *testing.T) { - assert.True(t, errors.IsAny(errors.Mark(errors.New("test"), walimpls.ErrFenced), context.Canceled, walimpls.ErrFenced)) - assert.True(t, errors.IsAny(errors.Wrap(walimpls.ErrFenced, "some message"), context.Canceled, walimpls.ErrFenced)) + require.True(t, errors.IsAny(errors.Mark(errors.New("test"), walimpls.ErrFenced), context.Canceled, walimpls.ErrFenced)) + require.True(t, errors.IsAny(errors.Wrap(walimpls.ErrFenced, "some message"), context.Canceled, walimpls.ErrFenced)) } func TestWAL(t *testing.T) { @@ -100,8 +100,8 @@ func (f *walTestFramework) Run() { loopCnt := 3 wg.Add(loopCnt) o, err := f.b.Build() - assert.NoError(f.t, err) - assert.NotNil(f.t, o) + require.NoError(f.t, err) + require.NotNil(f.t, o) defer o.Close() for i := 0; i < loopCnt; i++ { @@ -145,9 +145,9 @@ func (f *testOneWALFramework) Run() { Channel: pChannel, DisableFlusher: true, }) - assert.NoError(f.t, err) - assert.NotNil(f.t, rwWAL) - assert.Equal(f.t, pChannel.Name, rwWAL.Channel().Name) + require.NoError(f.t, err) + require.NotNil(f.t, rwWAL) + require.Equal(f.t, pChannel.Name, rwWAL.Channel().Name) // TODO: add test here after remove the flusher component. 
// metrics := rwWAL.Metrics() @@ -156,7 +156,7 @@ func (f *testOneWALFramework) Run() { Channel: pChannel, DisableFlusher: true, }) - assert.NoError(f.t, err) + require.NoError(f.t, err) metrics := roWAL.Metrics() _ = metrics.(types.ROWALMetrics) f.testReadAndWrite(ctx, rwWAL, roWAL) @@ -175,8 +175,8 @@ func (f *testOneWALFramework) Run() { MustBuildMutable() result, err := rwWAL.Append(ctx, createMsg) - assert.Nil(f.t, result) - assert.True(f.t, status.AsStreamingError(err).IsFenced()) + require.Nil(f.t, result) + require.True(f.t, status.AsStreamingError(err).IsFenced()) walimplstest.DisableFenced(pChannel.Name) rwWAL.Close() } @@ -184,8 +184,8 @@ func (f *testOneWALFramework) Run() { func (f *testOneWALFramework) testReadAndWrite(ctx context.Context, rwWAL wal.WAL, roWAL wal.ROWAL) { cp, err := rwWAL.GetReplicateCheckpoint() - assert.True(f.t, status.AsStreamingError(err).IsReplicateViolation()) - assert.Nil(f.t, cp) + require.True(f.t, status.AsStreamingError(err).IsReplicateViolation()) + require.Nil(f.t, cp) f.testSendCreateCollection(ctx, rwWAL) defer f.testSendDropCollection(ctx, rwWAL) @@ -200,15 +200,15 @@ func (f *testOneWALFramework) testReadAndWrite(ctx context.Context, rwWAL wal.WA go func() { defer wg.Done() lastMVCC, err := rwWAL.GetLatestMVCCTimestamp(context.Background(), testVChannel) - assert.NoError(f.t, err) + require.NoError(f.t, err) for { select { case <-appendDone: return case <-time.After(time.Duration(rand.Int31n(100)) * time.Millisecond): newMVCC, err := rwWAL.GetLatestMVCCTimestamp(context.Background(), testVChannel) - assert.NoError(f.t, err) - assert.GreaterOrEqual(f.t, newMVCC, lastMVCC) + require.NoError(f.t, err) + require.GreaterOrEqual(f.t, newMVCC, lastMVCC) lastMVCC = newMVCC } } @@ -220,25 +220,25 @@ func (f *testOneWALFramework) testReadAndWrite(ctx context.Context, rwWAL wal.WA }() var err error newWritten, err = f.testAppend(ctx, rwWAL) - assert.NoError(f.t, err) + require.NoError(f.t, err) }() go func() { defer wg.Done() 
var err error read1, err = f.testRead(ctx, rwWAL) - assert.NoError(f.t, err) + require.NoError(f.t, err) }() go func() { defer wg.Done() var err error read3, err = f.testRead(ctx, roWAL) - assert.NoError(f.t, err) + require.NoError(f.t, err) }() go func() { defer wg.Done() var err error read2, err = f.testRead(ctx, rwWAL) - assert.NoError(f.t, err) + require.NoError(f.t, err) }() wg.Wait() @@ -280,11 +280,11 @@ func (f *testOneWALFramework) testSendCreateCollection(ctx context.Context, w wa WithBody(&msgpb.CreateCollectionRequest{}). WithVChannel(testVChannel). BuildMutable() - assert.NoError(f.t, err) + require.NoError(f.t, err) msgID, err := w.Append(ctx, createMsg) - assert.NoError(f.t, err) - assert.NotNil(f.t, msgID) + require.NoError(f.t, err) + require.NotNil(f.t, msgID) } func (f *testOneWALFramework) testSendDropCollection(ctx context.Context, w wal.WAL) { @@ -299,12 +299,12 @@ func (f *testOneWALFramework) testSendDropCollection(ctx context.Context, w wal. WithBody(&msgpb.DropCollectionRequest{}). WithVChannel(testVChannel). BuildMutable() - assert.NoError(f.t, err) + require.NoError(f.t, err) done := make(chan struct{}) w.AppendAsync(ctx, dropMsg, func(ar *wal.AppendResult, err error) { - assert.NoError(f.t, err) - assert.NotNil(f.t, ar) + require.NoError(f.t, err) + require.NotNil(f.t, ar) close(done) }) <-done @@ -327,11 +327,11 @@ func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]mess }). WithBody(&message.BeginTxnMessageBody{}). 
BuildMutable() - assert.NoError(f.t, err) - assert.NotNil(f.t, msg) + require.NoError(f.t, err) + require.NotNil(f.t, msg) appendResult, err := w.Append(ctx, msg) - assert.NoError(f.t, err) - assert.NotNil(f.t, appendResult) + require.NoError(f.t, err) + require.NotNil(f.t, appendResult) immutableMsg := msg.IntoImmutableMessage(appendResult.MessageID) begin := message.MustAsImmutableBeginTxnMessageV2(immutableMsg) @@ -341,8 +341,8 @@ func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]mess msg = message.CreateTestEmptyInsertMesage(int64(i), map[string]string{}) msg.WithTxnContext(*txnCtx) appendResult, err = w.Append(ctx, msg) - assert.NoError(f.t, err) - assert.NotNil(f.t, msg) + require.NoError(f.t, err) + require.NotNil(f.t, msg) b.Add(msg.IntoImmutableMessage(appendResult.MessageID)) } @@ -357,8 +357,8 @@ func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]mess "const": "t", }) appendResult, err := w.Append(ctx, msg) - assert.NoError(f.t, err) - assert.NotNil(f.t, appendResult) + require.NoError(f.t, err) + require.NotNil(f.t, appendResult) messages[i] = msg.IntoImmutableMessage(appendResult.MessageID) } else { b, txnCtx := createPartOfTxn() @@ -372,17 +372,17 @@ func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]mess "const": "t", }). 
BuildMutable() - assert.NoError(f.t, err) - assert.NotNil(f.t, msg) + require.NoError(f.t, err) + require.NotNil(f.t, msg) appendResult, err := w.Append(ctx, msg.WithTxnContext(*txnCtx)) - assert.NoError(f.t, err) - assert.NotNil(f.t, appendResult) + require.NoError(f.t, err) + require.NotNil(f.t, appendResult) immutableMsg := msg.IntoImmutableMessage(appendResult.MessageID) commit, err := message.AsImmutableCommitTxnMessageV2(immutableMsg) - assert.NoError(f.t, err) + require.NoError(f.t, err) txn, err := b.Build(commit) - assert.NoError(f.t, err) + require.NoError(f.t, err) messages[i] = txn } @@ -395,11 +395,11 @@ func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]mess WithHeader(&message.RollbackTxnMessageHeader{}). WithBody(&message.RollbackTxnMessageBody{}). BuildMutable() - assert.NoError(f.t, err) - assert.NotNil(f.t, msg) + require.NoError(f.t, err) + require.NotNil(f.t, msg) appendResult, err := w.Append(ctx, msg.WithTxnContext(*txnCtx)) - assert.NoError(f.t, err) - assert.NotNil(f.t, appendResult) + require.NoError(f.t, err) + require.NotNil(f.t, appendResult) } } }(i) @@ -412,7 +412,7 @@ func (f *testOneWALFramework) testAppend(ctx context.Context, w wal.WAL) ([]mess "term": strconv.FormatInt(int64(f.term), 10), }) appendResult, err := w.Append(ctx, msg) - assert.NoError(f.t, err) + require.NoError(f.t, err) messages[f.messageCount-1] = msg.IntoImmutableMessage(appendResult.MessageID) return messages, nil } @@ -425,7 +425,7 @@ func (f *testOneWALFramework) testRead(ctx context.Context, w wal.ROWAL) ([]mess options.DeliverFilterMessageType(message.MessageTypeInsert), }, }) - assert.NoError(f.t, err) + require.NoError(f.t, err) defer s.Close() expectedCnt := f.messageCount + len(f.written) @@ -442,8 +442,8 @@ func (f *testOneWALFramework) testRead(ctx context.Context, w wal.ROWAL) ([]mess if msg.MessageType() != message.MessageTypeInsert && msg.MessageType() != message.MessageTypeTxn { continue } - assert.NotNil(f.t, msg) - 
assert.True(f.t, ok) + require.NotNil(f.t, msg) + require.True(f.t, ok) msgs = append(msgs, msg) termString, ok := msg.Properties().Get("term") if !ok { @@ -479,7 +479,7 @@ func (f *testOneWALFramework) testReadWithOption(ctx context.Context, w wal.ROWA options.DeliverFilterMessageType(message.MessageTypeInsert), }, }) - assert.NoError(f.t, err) + require.NoError(f.t, err) maxTimeTick := f.maxTimeTickWritten() msgCount := 0 lastTimeTick := readFromMsg.TimeTick() - 1 @@ -489,9 +489,9 @@ func (f *testOneWALFramework) testReadWithOption(ctx context.Context, w wal.ROWA continue } msgCount++ - assert.NotNil(f.t, msg) - assert.True(f.t, ok) - assert.Greater(f.t, msg.TimeTick(), lastTimeTick) + require.NotNil(f.t, msg) + require.True(f.t, ok) + require.Greater(f.t, msg.TimeTick(), lastTimeTick) lastTimeTick = msg.TimeTick() if msg.TimeTick() >= maxTimeTick { break @@ -499,7 +499,7 @@ func (f *testOneWALFramework) testReadWithOption(ctx context.Context, w wal.ROWA } // shouldn't lost any message. 
- assert.Equal(f.t, f.countTheTimeTick(readFromMsg.TimeTick()), msgCount) + require.Equal(f.t, f.countTheTimeTick(readFromMsg.TimeTick()), msgCount) s.Close() }() } @@ -508,42 +508,42 @@ func (f *testOneWALFramework) testReadWithOption(ctx context.Context, w wal.ROWA func (f *testOneWALFramework) assertSortByTimeTickMessageList(msgs []message.ImmutableMessage) { for i := 1; i < len(msgs); i++ { - assert.Less(f.t, msgs[i-1].TimeTick(), msgs[i].TimeTick()) + require.Less(f.t, msgs[i-1].TimeTick(), msgs[i].TimeTick()) } } func (f *testOneWALFramework) assertEqualMessageList(msgs1 []message.ImmutableMessage, msgs2 []message.ImmutableMessage) { - assert.Equal(f.t, len(msgs2), len(msgs1)) + require.Equal(f.t, len(msgs2), len(msgs1)) for i := 0; i < len(msgs1); i++ { - assert.Equal(f.t, msgs1[i].MessageType(), msgs2[i].MessageType()) + require.Equal(f.t, msgs1[i].MessageType(), msgs2[i].MessageType()) if msgs1[i].MessageType() == message.MessageTypeInsert { - assert.True(f.t, msgs1[i].MessageID().EQ(msgs2[i].MessageID())) - // assert.True(f.t, bytes.Equal(msgs1[i].Payload(), msgs2[i].Payload())) + require.True(f.t, msgs1[i].MessageID().EQ(msgs2[i].MessageID())) + // require.True(f.t, bytes.Equal(msgs1[i].Payload(), msgs2[i].Payload())) id1, ok1 := msgs1[i].Properties().Get("id") id2, ok2 := msgs2[i].Properties().Get("id") - assert.True(f.t, ok1) - assert.True(f.t, ok2) - assert.Equal(f.t, id1, id2) + require.True(f.t, ok1) + require.True(f.t, ok2) + require.Equal(f.t, id1, id2) id1, ok1 = msgs1[i].Properties().Get("const") id2, ok2 = msgs2[i].Properties().Get("const") - assert.True(f.t, ok1) - assert.True(f.t, ok2) - assert.Equal(f.t, id1, id2) + require.True(f.t, ok1) + require.True(f.t, ok2) + require.Equal(f.t, id1, id2) } if msgs1[i].MessageType() == message.MessageTypeTxn { txn1 := message.AsImmutableTxnMessage(msgs1[i]) txn2 := message.AsImmutableTxnMessage(msgs2[i]) - assert.Equal(f.t, txn1.Size(), txn2.Size()) + require.Equal(f.t, txn1.Size(), txn2.Size()) id1, 
ok1 := txn1.Commit().Properties().Get("id") id2, ok2 := txn2.Commit().Properties().Get("id") - assert.True(f.t, ok1) - assert.True(f.t, ok2) - assert.Equal(f.t, id1, id2) + require.True(f.t, ok1) + require.True(f.t, ok2) + require.Equal(f.t, id1, id2) id1, ok1 = txn1.Commit().Properties().Get("const") id2, ok2 = txn2.Commit().Properties().Get("const") - assert.True(f.t, ok1) - assert.True(f.t, ok2) - assert.Equal(f.t, id1, id2) + require.True(f.t, ok1) + require.True(f.t, ok2) + require.Equal(f.t, id1, id2) } } } diff --git a/internal/streamingnode/server/wal/interceptors/shard/shards/shard_manager_test.go b/internal/streamingnode/server/wal/interceptors/shard/shards/shard_manager_test.go index 81f9529f37..b6ed81819a 100644 --- a/internal/streamingnode/server/wal/interceptors/shard/shards/shard_manager_test.go +++ b/internal/streamingnode/server/wal/interceptors/shard/shards/shard_manager_test.go @@ -38,7 +38,7 @@ func TestShardManager(t *testing.T) { w.EXPECT().Append(mock.Anything, mock.Anything).Return(&types.AppendResult{ MessageID: rmq.NewRmqID(1), TimeTick: 1000, - }, nil) + }, nil).Maybe() f := syncutil.NewFuture[wal.WAL]() f.Set(w) diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack.go index 765605c20f..b88a660884 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/ack/ack.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/ack.go @@ -5,10 +5,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) -var ( - _ typeutil.HeapInterface = (*ackersOrderByTimestamp)(nil) - _ typeutil.HeapInterface = (*ackersOrderByEndTimestamp)(nil) -) +var _ typeutil.HeapInterface = (*ackersOrderByTimestamp)(nil) // Acker records the timestamp and last confirmed message id that has not been acknowledged. 
type Acker struct { @@ -53,16 +50,6 @@ func (h ackersOrderByTimestamp) Less(i, j int) bool { return h.ackers[i].detail.BeginTimestamp < h.ackers[j].detail.BeginTimestamp } -// ackersOrderByEndTimestamp is a heap underlying represent of timestampAck. -type ackersOrderByEndTimestamp struct { - ackers -} - -// Less returns true if the element at index i is less than the element at index j. -func (h ackersOrderByEndTimestamp) Less(i, j int) bool { - return h.ackers[i].detail.EndTimestamp < h.ackers[j].detail.EndTimestamp -} - // ackers is a heap underlying represent of timestampAck. type ackers []*Acker diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/last_confirmed.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/last_confirmed.go index 82c818e57b..1cee365918 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/ack/last_confirmed.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/last_confirmed.go @@ -7,8 +7,9 @@ import ( ) type uncommittedTxnInfo struct { - session *txn.TxnSession // if nil, it's a non-txn(autocommit) message. - messageID message.MessageID // the message id of the txn begins. + session *txn.TxnSession // if nil, it's a non-txn(autocommit) message. + messageID message.MessageID // the message id of the txn begins. + EndTimestamp uint64 // the end timestamp of the txn. } // newLastConfirmedManager creates a new last confirmed manager. @@ -32,8 +33,9 @@ func (m *lastConfirmedManager) AddConfirmedDetails(details sortedDetails, ts uin continue } m.notDoneTxnMessage.Push(&uncommittedTxnInfo{ - session: detail.TxnSession, - messageID: detail.Message.MessageID(), + session: detail.TxnSession, + messageID: detail.Message.MessageID(), + EndTimestamp: detail.EndTimestamp, }) } m.updateLastConfirmedMessageID(ts) @@ -46,7 +48,13 @@ func (m *lastConfirmedManager) GetLastConfirmedMessageID() message.MessageID { // updateLastConfirmedMessageID updates the last confirmed message id. 
func (m *lastConfirmedManager) updateLastConfirmedMessageID(ts uint64) { + // An entry can be used to update the last confirmed message id only if its end timestamp is less than the last confirmed time tick. + // Once the end timestamp is greater than or equal to the last confirmed time tick, a concurrent write operation may still be in progress, + // so there may be some messages whose message id is less than that of the peek of the notDoneTxnMessage. + // Otherwise, the message ids in the notDoneTxnMessage are dense and continuous, so they can be used to update the last confirmed message id. + // This keeps the LastConfirmedMessageID promise; also see the message.LastConfirmedMessageID() method. for m.notDoneTxnMessage.Len() > 0 && + m.notDoneTxnMessage.Peek().EndTimestamp < ts && (m.notDoneTxnMessage.Peek().session == nil || m.notDoneTxnMessage.Peek().session.IsExpiredOrDone(ts)) { info := m.notDoneTxnMessage.Pop() if m.lastConfirmedMessageID.LT(info.messageID) { diff --git a/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go b/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go index d9ff2c6112..79050828e9 100644 --- a/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go +++ b/internal/streamingnode/server/wal/interceptors/timetick/ack/manager.go @@ -19,8 +19,6 @@ type AckManager struct { notAckHeap typeutil.Heap[*Acker] // A minimum heap of timestampAck to search minimum allocated but not ack timestamp in list. // Actually, the notAckHeap can be replaced by a list because of the the allocate operation is protected by mutex, // keep it as a heap to make the code more readable. - ackHeap typeutil.Heap[*Acker] // A minimum heap of timestampAck to search minimum ack timestamp in list. - // It is used to detect the concurrent operation to find the last confirmed message id. acknowledgedDetails sortedDetails // All ack details which time tick less than lastConfirmedTimeTick will be temporarily kept here until sync operation happens. 
lastConfirmedManager *lastConfirmedManager // The last confirmed message id manager. metrics *metricsutil.TimeTickMetrics @@ -36,7 +34,6 @@ func NewAckManager( mu: sync.Mutex{}, lastAllocatedTimeTick: 0, notAckHeap: typeutil.NewHeap[*Acker](&ackersOrderByTimestamp{}), - ackHeap: typeutil.NewHeap[*Acker](&ackersOrderByEndTimestamp{}), lastConfirmedTimeTick: lastConfirmedTimeTick, lastConfirmedManager: newLastConfirmedManager(lastConfirmedMessageID), metrics: metrics, @@ -108,7 +105,6 @@ func (ta *AckManager) ack(acker *Acker) { acker.acknowledged = true acker.detail.EndTimestamp = ta.lastAllocatedTimeTick - ta.ackHeap.Push(acker) ta.metrics.CountAcknowledgeTimeTick(acker.ackDetail().IsSync) ta.popUntilLastAllAcknowledged() } @@ -129,16 +125,6 @@ func (ta *AckManager) popUntilLastAllAcknowledged() { ta.lastConfirmedTimeTick = acknowledgedDetails[len(acknowledgedDetails)-1].BeginTimestamp ta.metrics.UpdateLastConfirmedTimeTick(ta.lastConfirmedTimeTick) - // pop all EndTimestamp is less than lastConfirmedTimeTick. - // All the messages which EndTimetick less than lastConfirmedTimeTick have been committed into wal. - // So the MessageID of those messages is dense and continuous. - confirmedDetails := make(sortedDetails, 0, 5) - for ta.ackHeap.Len() > 0 && ta.ackHeap.Peek().detail.EndTimestamp < ta.lastConfirmedTimeTick { - ack := ta.ackHeap.Pop() - confirmedDetails = append(confirmedDetails, ack.ackDetail()) - } - ta.lastConfirmedManager.AddConfirmedDetails(confirmedDetails, ta.lastConfirmedTimeTick) - // TODO: cache update operation is also performed here. - + ta.lastConfirmedManager.AddConfirmedDetails(acknowledgedDetails, ta.lastConfirmedTimeTick) ta.acknowledgedDetails = append(ta.acknowledgedDetails, acknowledgedDetails...) 
} diff --git a/pkg/streaming/util/message/message.go b/pkg/streaming/util/message/message.go index 705c911493..045e7a8988 100644 --- a/pkg/streaming/util/message/message.go +++ b/pkg/streaming/util/message/message.go @@ -161,8 +161,8 @@ type ImmutableMessage interface { MessageID() MessageID // LastConfirmedMessageID returns the last confirmed message id of current message. - // last confirmed message is always a timetick message. - // Read from this message id will guarantee the time tick greater than this message is consumed. + // Reading from this message id guarantees that all messages with a time tick greater than this message's time tick are consumed, + // and the same guarantee holds for txn messages. // Available only when the message's version greater than 0. // Otherwise, it will panic. LastConfirmedMessageID() MessageID