diff --git a/internal/datanode/flow_graph_delete_node.go b/internal/datanode/flow_graph_delete_node.go index 3c732f7d7a..680612f694 100644 --- a/internal/datanode/flow_graph_delete_node.go +++ b/internal/datanode/flow_graph_delete_node.go @@ -68,13 +68,13 @@ func (dn *deleteNode) Operate(in []Msg) []Msg { // filterSegmentByPK returns the bloom filter check result. // If the key may exists in the segment, returns it in map. // If the key not exists in the segment, the segment is filter out. -func (dn *deleteNode) filterSegmentByPK(pks []int64) (map[int64][]int64, error) { +func (dn *deleteNode) filterSegmentByPK(partID UniqueID, pks []int64) (map[int64][]int64, error) { if pks == nil { return nil, errors.New("pks is nil") } results := make(map[int64][]int64) buf := make([]byte, 8) - segments := dn.replica.getSegments(dn.channelName) + segments := dn.replica.filterSegments(dn.channelName, partID) for _, segment := range segments { for _, pk := range pks { binary.BigEndian.PutUint64(buf, uint64(pk)) diff --git a/internal/datanode/flow_graph_delete_node_test.go b/internal/datanode/flow_graph_delete_node_test.go index 0c2c221992..ced7c42518 100644 --- a/internal/datanode/flow_graph_delete_node_test.go +++ b/internal/datanode/flow_graph_delete_node_test.go @@ -27,7 +27,7 @@ type mockReplica struct { flushedSegments map[UniqueID]*Segment } -func (replica *mockReplica) getSegments(channelName string) []*Segment { +func (replica *mockReplica) filterSegments(channelName string, partitionID UniqueID) []*Segment { results := make([]*Segment, 0) for _, value := range replica.newSegments { results = append(results, value) @@ -148,7 +148,7 @@ func Test_GetSegmentsByPKs(t *testing.T) { mockReplica.flushedSegments[segment5.segmentID] = segment5 mockReplica.flushedSegments[segment6.segmentID] = segment6 dn := newDeleteNode(mockReplica, "test", make(chan *flushMsg)) - results, err := dn.filterSegmentByPK([]int64{0, 1, 2, 3, 4}) + results, err := dn.filterSegmentByPK(0, []int64{0, 1, 2, 3, 4}) assert.Nil(t, err) expected := map[int64][]int64{ 0: {1, 2, 3}, @@ -160,5 +160,4 @@ func Test_GetSegmentsByPKs(t *testing.T) { for key, value := range expected { assert.ElementsMatch(t, value, results[key]) } - } diff --git a/internal/datanode/segment_replica.go b/internal/datanode/segment_replica.go index 5644c5b123..94ca07ccd7 100644 --- a/internal/datanode/segment_replica.go +++ b/internal/datanode/segment_replica.go @@ -43,7 +43,7 @@ type Replica interface { addNewSegment(segID, collID, partitionID UniqueID, channelName string, startPos, endPos *internalpb.MsgPosition) error addNormalSegment(segID, collID, partitionID UniqueID, channelName string, numOfRows int64, cp *segmentCheckPoint) error - getSegments(channelName string) []*Segment + filterSegments(channelName string, partitionID UniqueID) []*Segment listNewSegmentsStartPositions() []*datapb.SegmentStartPosition listSegmentsCheckPoints() map[UniqueID]segmentCheckPoint updateSegmentEndPosition(segID UniqueID, endPos *internalpb.MsgPosition) @@ -223,24 +223,28 @@ func (replica *SegmentReplica) addNewSegment(segID, collID, partitionID UniqueID return nil } -// getSegments return segments with same channelName -func (replica *SegmentReplica) getSegments(channelName string) []*Segment { +// filterSegments return segments with same channelName and partition ID +func (replica *SegmentReplica) filterSegments(channelName string, partitionID UniqueID) []*Segment { replica.segMu.Lock() defer replica.segMu.Unlock() results := make([]*Segment, 0) - for _, value := range replica.newSegments { - if value.channelName == channelName { - results = append(results, value) + + isMatched := func(segment *Segment, chanName string, partID UniqueID) bool { + return segment.channelName == chanName && (partID == 0 || segment.partitionID == partID) + } + for _, seg := range replica.newSegments { + if isMatched(seg, channelName, partitionID) { + results = append(results, seg) } } - for _, value := range replica.normalSegments { - if value.channelName == channelName { - results = append(results, value) + for _, seg := range replica.normalSegments { + if isMatched(seg, channelName, partitionID) { + results = append(results, seg) } } - for _, value := range replica.flushedSegments { - if value.channelName == channelName { - results = append(results, value) + for _, seg := range replica.flushedSegments { + if isMatched(seg, channelName, partitionID) { + results = append(results, seg) } } return results diff --git a/internal/datanode/segment_replica_test.go b/internal/datanode/segment_replica_test.go index 5a4733993c..4c68eea031 100644 --- a/internal/datanode/segment_replica_test.go +++ b/internal/datanode/segment_replica_test.go @@ -559,7 +559,7 @@ func TestSegmentReplica_InterfaceMethod(te *testing.T) { err = replica.addFlushedSegment(1, 1, 2, "insert-01", int64(0)) assert.Nil(t, err) - totalSegments := replica.getSegments("insert-01") + totalSegments := replica.filterSegments("insert-01", 0) assert.Equal(t, len(totalSegments), 3) }) }