diff --git a/internal/datanode/segment_replica.go b/internal/datanode/segment_replica.go index ea06e5296d..2f52d3ab3c 100644 --- a/internal/datanode/segment_replica.go +++ b/internal/datanode/segment_replica.go @@ -59,7 +59,10 @@ type Replica interface { updateSegmentEndPosition(segID UniqueID, endPos *internalpb.MsgPosition) updateSegmentCheckPoint(segID UniqueID) updateSegmentPKRange(segID UniqueID, rowIDs []int64) + refreshFlushedSegmentPKRange(segID UniqueID, rowIDs []int64) + addFlushedSegmentWithPKs(segID, collID, partID UniqueID, channelName string, numOfRow int64, rowIDs []int64) hasSegment(segID UniqueID, countFlushed bool) bool + removeSegment(segID UniqueID) updateStatistics(segID UniqueID, numRows int64) getSegmentStatisticsUpdates(segID UniqueID) (*internalpb.SegmentStatisticsUpdates, error) @@ -515,8 +518,13 @@ func (replica *SegmentReplica) updateSegmentPKRange(segID UniqueID, pks []int64) log.Warn("No match segment to update PK range", zap.Int64("ID", segID)) } -func (replica *SegmentReplica) removeSegment(segID UniqueID) error { - return nil +func (replica *SegmentReplica) removeSegment(segID UniqueID) { + replica.segMu.Lock() + defer replica.segMu.Unlock() + + delete(replica.newSegments, segID) + delete(replica.normalSegments, segID) + delete(replica.flushedSegments, segID) } // hasSegment checks whether this replica has a segment according to segment ID. @@ -628,3 +636,55 @@ func (replica *SegmentReplica) updateSegmentCheckPoint(segID UniqueID) { log.Warn("There's no segment", zap.Int64("ID", segID)) } + +// please call hasSegment first +func (replica *SegmentReplica) refreshFlushedSegmentPKRange(segID UniqueID, rowIDs []int64) { + replica.segMu.Lock() + defer replica.segMu.Unlock() + + seg, ok := replica.flushedSegments[segID] + if ok { + seg.pkFilter.ClearAll() + seg.updatePKRange(rowIDs) + return + } + + log.Warn("No match segment to update PK range", zap.Int64("ID", segID)) +} + +func (replica *SegmentReplica) addFlushedSegmentWithPKs(segID, collID, partID UniqueID, channelName string, numOfRows int64, rowIDs []int64) { + if collID != replica.collectionID { + log.Warn("Mismatch collection", + zap.Int64("input ID", collID), + zap.Int64("expected ID", replica.collectionID)) + return + } + + log.Debug("Add Flushed segment", + zap.Int64("segment ID", segID), + zap.Int64("collection ID", collID), + zap.Int64("partition ID", partID), + zap.String("channel name", channelName), + ) + + seg := &Segment{ + collectionID: collID, + partitionID: partID, + segmentID: segID, + channelName: channelName, + numRows: numOfRows, + + pkFilter: bloom.NewWithEstimates(bloomFilterSize, maxBloomFalsePositive), + minPK: math.MaxInt64, // use max value, represents no value + maxPK: math.MinInt64, // use min value represents no value + } + + seg.updatePKRange(rowIDs) + + seg.isNew.Store(false) + seg.isFlushed.Store(true) + + replica.segMu.Lock() + replica.flushedSegments[segID] = seg + replica.segMu.Unlock() +} diff --git a/internal/datanode/segment_replica_test.go b/internal/datanode/segment_replica_test.go index 86a919bb27..cf0fd41ed2 100644 --- a/internal/datanode/segment_replica_test.go +++ b/internal/datanode/segment_replica_test.go @@ -247,10 +247,53 @@ func TestSegmentReplica(t *testing.T) { }) } -func TestSegmentReplica_InterfaceMethod(te *testing.T) { +func TestSegmentReplica_InterfaceMethod(t *testing.T) { rc := &RootCoordFactory{} - te.Run("Test_addNewSegment", func(to *testing.T) { + t.Run("Test refreshFlushedSegmentPKRange", func(t *testing.T) { + replica, err := newReplica(context.TODO(), rc, 1) + require.NoError(t, err) + + require.False(t, replica.hasSegment(100, true)) + replica.refreshFlushedSegmentPKRange(100, []int64{10}) + + replica.addFlushedSegmentWithPKs(100, 1, 10, "a", 1, []int64{9}) + require.True(t, replica.hasSegment(100, true)) + replica.refreshFlushedSegmentPKRange(100, []int64{10}) + + }) + + t.Run("Test addFlushedSegmentWithPKs", func(t *testing.T) { + tests := []struct { + isvalid bool + + incollID UniqueID + replicaCollID UniqueID + description string + }{ + {true, 1, 1, "valid input collection with replica collection"}, + {false, 1, 2, "invalid input collection with replica collection"}, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + replica, err := newReplica(context.TODO(), rc, test.replicaCollID) + require.NoError(t, err) + if test.isvalid { + replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, []int64{9}) + + assert.True(t, replica.hasSegment(100, true)) + assert.False(t, replica.hasSegment(100, false)) + } else { + replica.addFlushedSegmentWithPKs(100, test.incollID, 10, "a", 1, []int64{9}) + assert.False(t, replica.hasSegment(100, true)) + assert.False(t, replica.hasSegment(100, false)) + } + }) + } + }) + + t.Run("Test_addNewSegment", func(t *testing.T) { tests := []struct { isValidCase bool replicaCollID UniqueID @@ -270,7 +313,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { } for _, test := range tests { - to.Run(test.description, func(t *testing.T) { + t.Run(test.description, func(t *testing.T) { sr, err := newReplica(context.Background(), rc, test.replicaCollID) assert.Nil(t, err) require.False(t, sr.hasSegment(test.inSegID, true)) @@ -289,7 +332,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { } }) - te.Run("Test_addNormalSegment", func(to *testing.T) { + t.Run("Test_addNormalSegment", func(t *testing.T) { tests := []struct { isValidCase bool replicaCollID UniqueID @@ -306,7 +349,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { } for _, test := range tests { - to.Run(test.description, func(t *testing.T) { + t.Run(test.description, func(t *testing.T) { sr, err := newReplica(context.Background(), rc, test.replicaCollID) sr.minIOKV = &mockMinioKV{} assert.Nil(t, err) @@ -325,7 +368,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { } }) - te.Run("Test_listSegmentsCheckPoints", func(to *testing.T) { + t.Run("Test_listSegmentsCheckPoints", func(t *testing.T) { tests := []struct { newSegID UniqueID newSegCP *segmentCheckPoint @@ -355,7 +398,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { } for _, test := range tests { - to.Run(test.description, func(t *testing.T) { + t.Run(test.description, func(t *testing.T) { sr := SegmentReplica{ newSegments: make(map[UniqueID]*Segment), normalSegments: make(map[UniqueID]*Segment), @@ -381,7 +424,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { } }) - te.Run("Test_updateSegmentEndPosition", func(to *testing.T) { + t.Run("Test_updateSegmentEndPosition", func(t *testing.T) { tests := []struct { newSegID UniqueID normalSegID UniqueID @@ -405,7 +448,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { } for _, test := range tests { - to.Run(test.description, func(t *testing.T) { + t.Run(test.description, func(t *testing.T) { sr := SegmentReplica{ newSegments: make(map[UniqueID]*Segment), normalSegments: make(map[UniqueID]*Segment), @@ -422,14 +465,13 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { sr.flushedSegments[test.flushedSegID] = &Segment{} } sr.updateSegmentEndPosition(test.inSegID, new(internalpb.MsgPosition)) - err := sr.removeSegment(0) - assert.Nil(t, err) + sr.removeSegment(0) }) } }) - te.Run("Test_updateStatistics", func(to *testing.T) { + t.Run("Test_updateStatistics", func(t *testing.T) { tests := []struct { isvalidCase bool @@ -455,7 +497,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { description: "input seg 301 not in flushedSegments"}, } for _, test := range tests { - to.Run(test.description, func(t *testing.T) { + t.Run(test.description, func(t *testing.T) { sr := SegmentReplica{ newSegments: make(map[UniqueID]*Segment), normalSegments: make(map[UniqueID]*Segment), @@ -492,7 +534,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { } }) - te.Run("Test_getCollectionSchema", func(to *testing.T) { + t.Run("Test_getCollectionSchema", func(t *testing.T) { tests := []struct { isValid bool replicaCollID UniqueID @@ -507,7 +549,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { } for _, test := range tests { - to.Run(test.description, func(t *testing.T) { + t.Run(test.description, func(t *testing.T) { sr, err := newReplica(context.Background(), rc, test.replicaCollID) assert.Nil(t, err) @@ -530,43 +572,43 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { }) - te.Run("Test_addSegmentMinIOLoadError", func(to *testing.T) { + t.Run("Test_addSegmentMinIOLoadError", func(t *testing.T) { sr, err := newReplica(context.Background(), rc, 1) - assert.Nil(to, err) + assert.Nil(t, err) sr.minIOKV = &mockMinioKVError{} cpPos := &internalpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(10)} cp := &segmentCheckPoint{int64(10), *cpPos} err = sr.addNormalSegment(1, 1, 2, "insert-01", int64(10), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, cp) - assert.NotNil(to, err) + assert.NotNil(t, err) err = sr.addFlushedSegment(1, 1, 2, "insert-01", int64(0), []*datapb.FieldBinlog{getSimpleFieldBinlog()}) - assert.NotNil(to, err) + assert.NotNil(t, err) }) - te.Run("Test_addSegmentStatsError", func(to *testing.T) { + t.Run("Test_addSegmentStatsError", func(t *testing.T) { sr, err := newReplica(context.Background(), rc, 1) - assert.Nil(to, err) + assert.Nil(t, err) sr.minIOKV = &mockMinioKVStatsError{} cpPos := &internalpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(10)} cp := &segmentCheckPoint{int64(10), *cpPos} err = sr.addNormalSegment(1, 1, 2, "insert-01", int64(10), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, cp) - assert.NotNil(to, err) + assert.NotNil(t, err) err = sr.addFlushedSegment(1, 1, 2, "insert-01", int64(0), []*datapb.FieldBinlog{getSimpleFieldBinlog()}) - assert.NotNil(to, err) + assert.NotNil(t, err) }) - te.Run("Test_addSegmentPkfilterError", func(to *testing.T) { + t.Run("Test_addSegmentPkfilterError", func(t *testing.T) { sr, err := newReplica(context.Background(), rc, 1) - assert.Nil(to, err) + assert.Nil(t, err) sr.minIOKV = &mockPkfilterMergeError{} cpPos := &internalpb.MsgPosition{ChannelName: "insert-01", Timestamp: Timestamp(10)} cp := &segmentCheckPoint{int64(10), *cpPos} err = sr.addNormalSegment(1, 1, 2, "insert-01", int64(10), []*datapb.FieldBinlog{getSimpleFieldBinlog()}, cp) - assert.NotNil(to, err) + assert.NotNil(t, err) err = sr.addFlushedSegment(1, 1, 2, "insert-01", int64(0), []*datapb.FieldBinlog{getSimpleFieldBinlog()}) - assert.NotNil(to, err) + assert.NotNil(t, err) }) }