diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 749fe4e3eb..83cedc4398 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -449,22 +449,6 @@ func TestSaveBinlogPaths(t *testing.T) { } func TestDataNodeTtChannel(t *testing.T) { - ch := make(chan interface{}, 1) - svr := newTestServer(t, ch) - defer closeTestServer(t, svr) - - svr.meta.AddCollection(&datapb.CollectionInfo{ - ID: 0, - Schema: newTestSchema(), - Partitions: []int64{0}, - }) - - ttMsgStream, err := svr.msFactory.NewMsgStream(context.TODO()) - assert.Nil(t, err) - ttMsgStream.AsProducer([]string{Params.TimeTickChannelName}) - ttMsgStream.Start() - defer ttMsgStream.Close() - genMsg := func(msgType commonpb.MsgType, ch string, t Timestamp) *msgstream.DataNodeTtMsg { return &msgstream.DataNodeTtMsg{ BaseMsg: msgstream.BaseMsg{ @@ -482,22 +466,37 @@ func TestDataNodeTtChannel(t *testing.T) { }, } } - info := &datapb.DataNodeInfo{ - Address: "localhost:7777", - Version: 0, - Channels: []*datapb.ChannelStatus{ - { - Name: "ch-1", - State: datapb.ChannelWatchState_Complete, - }, - }, - } - node := NewNodeInfo(context.TODO(), info) - node.client, err = newMockDataNodeClient(1, ch) - assert.Nil(t, err) - svr.cluster.Register(node) - t.Run("Test segment flush after tt", func(t *testing.T) { + ch := make(chan interface{}, 1) + svr := newTestServer(t, ch) + defer closeTestServer(t, svr) + + svr.meta.AddCollection(&datapb.CollectionInfo{ + ID: 0, + Schema: newTestSchema(), + Partitions: []int64{0}, + }) + + ttMsgStream, err := svr.msFactory.NewMsgStream(context.TODO()) + assert.Nil(t, err) + ttMsgStream.AsProducer([]string{Params.TimeTickChannelName}) + ttMsgStream.Start() + defer ttMsgStream.Close() + info := &datapb.DataNodeInfo{ + Address: "localhost:7777", + Version: 0, + Channels: []*datapb.ChannelStatus{ + { + Name: "ch-1", + State: datapb.ChannelWatchState_Complete, + }, + }, + } + node := NewNodeInfo(context.TODO(), info) + node.client, err = newMockDataNodeClient(1, ch) + assert.Nil(t, err) + svr.cluster.Register(node) + resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{ NodeID: 0, PeerRole: "", @@ -540,6 +539,89 @@ func TestDataNodeTtChannel(t *testing.T) { assert.EqualValues(t, assign.SegID, flushReq.SegmentIDs[0]) }) + t.Run("flush segment with different channels", func(t *testing.T) { + ch := make(chan interface{}, 1) + svr := newTestServer(t, ch) + defer closeTestServer(t, svr) + svr.meta.AddCollection(&datapb.CollectionInfo{ + ID: 0, + Schema: newTestSchema(), + Partitions: []int64{0}, + }) + ttMsgStream, err := svr.msFactory.NewMsgStream(context.TODO()) + assert.Nil(t, err) + ttMsgStream.AsProducer([]string{Params.TimeTickChannelName}) + ttMsgStream.Start() + defer ttMsgStream.Close() + info := &datapb.DataNodeInfo{ + Address: "localhost:7777", + Version: 0, + Channels: []*datapb.ChannelStatus{ + { + Name: "ch-1", + State: datapb.ChannelWatchState_Complete, + }, + { + Name: "ch-2", + State: datapb.ChannelWatchState_Complete, + }, + }, + } + node := NewNodeInfo(context.TODO(), info) + node.client, err = newMockDataNodeClient(1, ch) + assert.Nil(t, err) + svr.cluster.Register(node) + resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{ + NodeID: 0, + PeerRole: "", + SegmentIDRequests: []*datapb.SegmentIDRequest{ + { + CollectionID: 0, + PartitionID: 0, + ChannelName: "ch-1", + Count: 100, + }, + { + CollectionID: 0, + PartitionID: 0, + ChannelName: "ch-2", + Count: 100, + }, + }, + }) + assert.Nil(t, err) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + assert.EqualValues(t, 2, len(resp.SegIDAssignments)) + var assign *datapb.SegmentIDAssignment + for _, segment := range resp.SegIDAssignments { + if segment.GetChannelName() == "ch-1" { + assign = segment + break + } + } + assert.NotNil(t, assign) + resp2, err := svr.Flush(context.TODO(), &datapb.FlushRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Flush, + MsgID: 0, + Timestamp: 0, + SourceID: 0, + }, + DbID: 0, + CollectionID: 0, + }) + assert.Nil(t, err) + assert.EqualValues(t, commonpb.ErrorCode_Success, resp2.Status.ErrorCode) + + msgPack := msgstream.MsgPack{} + msg := genMsg(commonpb.MsgType_DataNodeTt, "ch-1", assign.ExpireTime) + msgPack.Msgs = append(msgPack.Msgs, msg) + ttMsgStream.Produce(&msgPack) + flushMsg := <-ch + flushReq := flushMsg.(*datapb.FlushSegmentsRequest) + assert.EqualValues(t, 1, len(flushReq.SegmentIDs)) + assert.EqualValues(t, assign.SegID, flushReq.SegmentIDs[0]) + }) } func TestGetVChannelPos(t *testing.T) {