diff --git a/internal/datanode/flow_graph_delete_node.go b/internal/datanode/flow_graph_delete_node.go
index bf66355577..366ec0314c 100644
--- a/internal/datanode/flow_graph_delete_node.go
+++ b/internal/datanode/flow_graph_delete_node.go
@@ -13,6 +13,8 @@ package datanode
 
 import (
 	"context"
+	"encoding/binary"
+	"errors"
 
 	"go.uber.org/zap"
 
@@ -55,6 +57,27 @@ func (ddn *deleteNode) Operate(in []Msg) []Msg {
 	return []Msg{}
 }
 
+func getSegmentsByPKs(pks []int64, segments []*Segment) (map[int64][]int64, error) {
+	if pks == nil {
+		return nil, errors.New("pks is nil when getSegmentsByPKs")
+	}
+	if segments == nil {
+		return nil, errors.New("segments is nil when getSegmentsByPKs")
+	}
+	results := make(map[int64][]int64)
+	buf := make([]byte, 8)
+	for _, segment := range segments {
+		for _, pk := range pks {
+			binary.BigEndian.PutUint64(buf, uint64(pk))
+			exist := segment.pkFilter.Test(buf)
+			if exist {
+				results[segment.segmentID] = append(results[segment.segmentID], pk)
+			}
+		}
+	}
+	return results, nil
+}
+
 func newDeleteDNode(ctx context.Context, replica Replica) *deleteNode {
 	baseNode := BaseNode{}
 	baseNode.SetMaxParallelism(Params.FlowGraphMaxQueueLength)
diff --git a/internal/datanode/flow_graph_delete_node_test.go b/internal/datanode/flow_graph_delete_node_test.go
index 1f3a7e02c1..8e42e9e542 100644
--- a/internal/datanode/flow_graph_delete_node_test.go
+++ b/internal/datanode/flow_graph_delete_node_test.go
@@ -13,8 +13,10 @@ package datanode
 
 import (
 	"context"
+	"encoding/binary"
 	"testing"
 
+	"github.com/bits-and-blooms/bloom/v3"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -35,3 +37,53 @@ func TestFlowGraphDeleteNode_Operate_Invalid_Size(t *testing.T) {
 	result := deleteNode.Operate([]Msg{Msg1, Msg2})
 	assert.Equal(t, len(result), 0)
 }
+
+func TestGetSegmentsByPKs(t *testing.T) {
+	buf := make([]byte, 8)
+	filter1 := bloom.NewWithEstimates(1000000, 0.01)
+	for i := 0; i < 3; i++ {
+		binary.BigEndian.PutUint64(buf, uint64(i))
+		filter1.Add(buf)
+	}
+	filter2 := bloom.NewWithEstimates(1000000, 0.01)
+	for i := 3; i < 5; i++ {
+		binary.BigEndian.PutUint64(buf, uint64(i))
+		filter2.Add(buf)
+	}
+	segment1 := &Segment{
+		segmentID: 1,
+		pkFilter:  filter1,
+	}
+	segment2 := &Segment{
+		segmentID: 2,
+		pkFilter:  filter1,
+	}
+	segment3 := &Segment{
+		segmentID: 3,
+		pkFilter:  filter1,
+	}
+	segment4 := &Segment{
+		segmentID: 4,
+		pkFilter:  filter2,
+	}
+	segment5 := &Segment{
+		segmentID: 5,
+		pkFilter:  filter2,
+	}
+	segments := []*Segment{segment1, segment2, segment3, segment4, segment5}
+	results, err := getSegmentsByPKs([]int64{0, 1, 2, 3, 4}, segments)
+	assert.Nil(t, err)
+	expected := map[int64][]int64{
+		1: {0, 1, 2},
+		2: {0, 1, 2},
+		3: {0, 1, 2},
+		4: {3, 4},
+		5: {3, 4},
+	}
+	assert.Equal(t, expected, results)
+
+	_, err = getSegmentsByPKs(nil, segments)
+	assert.NotNil(t, err)
+	_, err = getSegmentsByPKs([]int64{0, 1, 2, 3, 4}, nil)
+	assert.NotNil(t, err)
+}
diff --git a/internal/querynode/query_collection.go b/internal/querynode/query_collection.go
index ce95213fa9..7bf395d6cd 100644
--- a/internal/querynode/query_collection.go
+++ b/internal/querynode/query_collection.go
@@ -1136,6 +1136,27 @@ func (q *queryCollection) retrieve(msg queryMsg) error {
 	return nil
 }
 
+func getSegmentsByPKs(pks []int64, segments []*Segment) (map[int64][]int64, error) {
+	if pks == nil {
+		return nil, fmt.Errorf("pks is nil when getSegmentsByPKs")
+	}
+	if segments == nil {
+		return nil, fmt.Errorf("segments is nil when getSegmentsByPKs")
+	}
+	results := make(map[int64][]int64)
+	buf := make([]byte, 8)
+	for _, segment := range segments {
+		for _, pk := range pks {
+			binary.BigEndian.PutUint64(buf, uint64(pk))
+			exist := segment.pkFilter.Test(buf)
+			if exist {
+				results[segment.segmentID] = append(results[segment.segmentID], pk)
+			}
+		}
+	}
+	return results, nil
+}
+
 func mergeRetrieveResults(dataArr []*segcorepb.RetrieveResults) (*segcorepb.RetrieveResults, error) {
 	var final *segcorepb.RetrieveResults
 	for _, data := range dataArr {
diff --git a/internal/querynode/query_collection_test.go b/internal/querynode/query_collection_test.go
index 699aaac6a7..807ec08a2d 100644
--- a/internal/querynode/query_collection_test.go
+++ b/internal/querynode/query_collection_test.go
@@ -7,6 +7,7 @@ import (
 	"math/rand"
 	"testing"
 
+	"github.com/bits-and-blooms/bloom/v3"
 	"github.com/golang/protobuf/proto"
 	"github.com/stretchr/testify/assert"
 
@@ -128,3 +129,53 @@ func TestQueryCollection_withoutVChannel(t *testing.T) {
 	historical.close()
 	streaming.close()
 }
+
+func TestGetSegmentsByPKs(t *testing.T) {
+	buf := make([]byte, 8)
+	filter1 := bloom.NewWithEstimates(1000000, 0.01)
+	for i := 0; i < 3; i++ {
+		binary.BigEndian.PutUint64(buf, uint64(i))
+		filter1.Add(buf)
+	}
+	filter2 := bloom.NewWithEstimates(1000000, 0.01)
+	for i := 3; i < 5; i++ {
+		binary.BigEndian.PutUint64(buf, uint64(i))
+		filter2.Add(buf)
+	}
+	segment1 := &Segment{
+		segmentID: 1,
+		pkFilter:  filter1,
+	}
+	segment2 := &Segment{
+		segmentID: 2,
+		pkFilter:  filter1,
+	}
+	segment3 := &Segment{
+		segmentID: 3,
+		pkFilter:  filter1,
+	}
+	segment4 := &Segment{
+		segmentID: 4,
+		pkFilter:  filter2,
+	}
+	segment5 := &Segment{
+		segmentID: 5,
+		pkFilter:  filter2,
+	}
+	segments := []*Segment{segment1, segment2, segment3, segment4, segment5}
+	results, err := getSegmentsByPKs([]int64{0, 1, 2, 3, 4}, segments)
+	assert.Nil(t, err)
+	expected := map[int64][]int64{
+		1: {0, 1, 2},
+		2: {0, 1, 2},
+		3: {0, 1, 2},
+		4: {3, 4},
+		5: {3, 4},
+	}
+	assert.Equal(t, expected, results)
+
+	_, err = getSegmentsByPKs(nil, segments)
+	assert.NotNil(t, err)
+	_, err = getSegmentsByPKs([]int64{0, 1, 2, 3, 4}, nil)
+	assert.NotNil(t, err)
+}
diff --git a/internal/querynode/segment.go b/internal/querynode/segment.go
index 2b77434aaa..7ae7050423 100644
--- a/internal/querynode/segment.go
+++ b/internal/querynode/segment.go
@@ -29,6 +29,7 @@ import (
 	"sync"
 	"unsafe"
 
+	"github.com/bits-and-blooms/bloom/v3"
 	"github.com/stretchr/testify/assert"
 	"go.uber.org/zap"
 
@@ -90,6 +91,8 @@ type Segment struct {
 
 	vectorFieldMutex sync.RWMutex // guards vectorFieldInfos
 	vectorFieldInfos map[UniqueID]*VectorFieldInfo
+
+	pkFilter *bloom.BloomFilter // bloom filter of pk inside a segment
 }
 
 //-------------------------------------------------------------------------------------- common interfaces
diff --git a/internal/util/segmentfilter/segment_filter.go b/internal/util/segmentfilter/segment_filter.go
deleted file mode 100644
index 3a25c0f46c..0000000000
--- a/internal/util/segmentfilter/segment_filter.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package segmentfilter
-
-import (
-	"github.com/bits-and-blooms/bloom/v3"
-	"github.com/milvus-io/milvus/internal/proto/datapb"
-)
-
-// SegmentFilter is used to know which segments may have data corresponding
-// to the primary key
-type SegmentFilter struct {
-	segmentInfos []*datapb.SegmentInfo
-	bloomFilters []*bloom.BloomFilter
-}
-
-func NewSegmentFilter(segmentInfos []*datapb.SegmentInfo) *SegmentFilter {
-	return &SegmentFilter{
-		segmentInfos: segmentInfos,
-	}
-}
-
-func (sf *SegmentFilter) init() {
-	panic("This method has not been implemented")
-}
-
-// GetSegmentByPK pass a list of primary key and retrun an map of
-//
-func (sf *SegmentFilter) GetSegmentByPK(pk []string) map[int64][]string {
-	panic("This method has not been implemented")
-}