From 338ed2fed4b935ac38e9641d9ea4e7675457f971 Mon Sep 17 00:00:00 2001 From: jiaqizho Date: Tue, 23 Sep 2025 09:58:09 +0800 Subject: [PATCH] enhance: Introduce sparse filter in query (#44347) issue: #44373 The current commit implements sparse filtering in query tasks using the statistical information (Bloom filter/MinMax) of the Primary Key (PK). The statistical information of the PK is bound to the segment during the segment loading phase. A new filter has been added to the segment filter to enable the sparse filtering functionality. Signed-off-by: jiaqizho --- .../querynodev2/delegator/delegator_data.go | 2 +- .../delegator/delegator_data_test.go | 38 +-- .../querynodev2/pkoracle/bloom_filter_set.go | 10 + internal/querynodev2/segments/mock_loader.go | 33 ++- internal/querynodev2/segments/mock_segment.go | 174 +++++++++++ internal/querynodev2/segments/retrieve.go | 17 +- .../querynodev2/segments/retrieve_test.go | 271 +++++++++++++++--- internal/querynodev2/segments/segment.go | 22 ++ .../querynodev2/segments/segment_filter.go | 177 ++++++++++++ .../querynodev2/segments/segment_interface.go | 7 + .../querynodev2/segments/segment_loader.go | 48 +++- .../segments/segment_loader_test.go | 14 +- internal/querynodev2/segments/validate.go | 8 +- internal/querynodev2/tasks/query_task.go | 10 +- pkg/util/paramtable/component_param.go | 9 + pkg/util/paramtable/component_param_test.go | 1 + .../milvus_client/test_milvus_client_query.py | 56 ++++ 17 files changed, 808 insertions(+), 89 deletions(-) diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index a922b4602c..94b4a5b36c 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -498,7 +498,7 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg } } - candidates, err := sd.loader.LoadBloomFilterSet(ctx, req.GetCollectionID(), req.GetVersion(), infos...) + candidates, err := sd.loader.LoadBloomFilterSet(ctx, req.GetCollectionID(), infos...) if err != nil { log.Warn("failed to load bloom filter set for segment", zap.Error(err)) return err diff --git a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index e13207d8d2..6bd34e26ad 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -325,8 +325,8 @@ func (s *DelegatorDataSuite) TestProcessDelete() { return ms }) }, nil) - s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything). - Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { + s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.Anything). 
+ Call.Return(func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet { bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed) bf := bloomfilter.NewBloomFilterWithType(paramtable.Get().CommonCfg.BloomFilterSize.GetAsUint(), @@ -341,7 +341,7 @@ func (s *DelegatorDataSuite) TestProcessDelete() { bfs.AddHistoricalStats(pks) return bfs }) - }, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error { + }, func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) error { return nil }) @@ -560,12 +560,12 @@ func (s *DelegatorDataSuite) TestLoadSegmentsWithBm25() { statsMap.Insert(1, map[int64]*storage.BM25Stats{101: stats}) s.loader.EXPECT().LoadBM25Stats(mock.Anything, s.collectionID, mock.Anything).Return(statsMap, nil) - s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything). - Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { + s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.Anything). + Call.Return(func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet { return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed) }) - }, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error { + }, func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) error { return nil }) @@ -659,12 +659,12 @@ func (s *DelegatorDataSuite) TestLoadSegments() { s.loader.ExpectedCalls = nil }() - s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything). - Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { + s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.Anything). + Call.Return(func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet { return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed) }) - }, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error { + }, func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) error { return nil }) @@ -717,8 +717,8 @@ func (s *DelegatorDataSuite) TestLoadSegments() { s.loader.ExpectedCalls = nil }() - s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything). - Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { + s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.Anything). 
+ Call.Return(func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet { bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed) bf := bloomfilter.NewBloomFilterWithType( @@ -734,7 +734,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() { bfs.AddHistoricalStats(pks) return bfs }) - }, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error { + }, func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) error { return nil }) @@ -866,7 +866,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() { s.loader.ExpectedCalls = nil }() - s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything). + s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.Anything). Return(nil, errors.New("mocked error")) workers := make(map[int64]*cluster.MockWorker) @@ -905,12 +905,12 @@ func (s *DelegatorDataSuite) TestLoadSegments() { s.loader.ExpectedCalls = nil }() - s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything). - Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { + s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.Anything). + Call.Return(func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet { return pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed) }) - }, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error { + }, func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) error { return nil }) @@ -1203,8 +1203,8 @@ func (s *DelegatorDataSuite) TestReleaseSegment() { return ms }) }, nil) - s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.AnythingOfType("int64"), mock.Anything). - Call.Return(func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet { + s.loader.EXPECT().LoadBloomFilterSet(mock.Anything, s.collectionID, mock.Anything). 
+		Call.Return(func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet {
 			return lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *pkoracle.BloomFilterSet {
 				bfs := pkoracle.NewBloomFilterSet(info.GetSegmentID(), info.GetPartitionID(), commonpb.SegmentState_Sealed)
 				bf := bloomfilter.NewBloomFilterWithType(
@@ -1220,7 +1220,7 @@ func (s *DelegatorDataSuite) TestReleaseSegment() {
 				bfs.AddHistoricalStats(pks)
 				return bfs
 			})
-		}, func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) error {
+		}, func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) error {
 			return nil
 		})
 
diff --git a/internal/querynodev2/pkoracle/bloom_filter_set.go b/internal/querynodev2/pkoracle/bloom_filter_set.go
index cf546ea9b1..0801e1549c 100644
--- a/internal/querynodev2/pkoracle/bloom_filter_set.go
+++ b/internal/querynodev2/pkoracle/bloom_filter_set.go
@@ -89,6 +89,16 @@ func (s *BloomFilterSet) Type() commonpb.SegmentState {
 	return s.segType
 }
 
+// Stats returns the current PK statistics (bloom filter plus min/max PK), if loaded.
+func (s *BloomFilterSet) Stats() *storage.PkStatistics {
+	return s.currentStat
+}
+
+// BloomFilterExist reports whether any bloom filter (current or historical) has been loaded.
+func (s *BloomFilterSet) BloomFilterExist() bool {
+	return s.currentStat != nil || s.historyStats != nil
+}
+
 // UpdateBloomFilter updates currentStats with provided pks.
 func (s *BloomFilterSet) UpdateBloomFilter(pks []storage.PrimaryKey) {
 	s.statsMutex.Lock()
diff --git a/internal/querynodev2/segments/mock_loader.go b/internal/querynodev2/segments/mock_loader.go
index a619ca21c5..f783e3e48c 100644
--- a/internal/querynodev2/segments/mock_loader.go
+++ b/internal/querynodev2/segments/mock_loader.go
@@ -183,14 +183,14 @@ func (_c *MockLoader_LoadBM25Stats_Call) RunAndReturn(run func(context.Context,
 	return _c
 }
 
-// LoadBloomFilterSet provides a mock function with given fields: ctx, collectionID, version, infos
-func (_m *MockLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
+// LoadBloomFilterSet provides a mock function with given fields: ctx, collectionID, infos
+func (_m *MockLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
 	_va := make([]interface{}, len(infos))
 	for _i := range infos {
 		_va[_i] = infos[_i]
 	}
 	var _ca []interface{}
-	_ca = append(_ca, ctx, collectionID, version)
+	_ca = append(_ca, ctx, collectionID)
 	_ca = append(_ca, _va...)
 	ret := _m.Called(_ca...)
 
@@ -200,19 +200,19 @@ func (_m *MockLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64
 	var r0 []*pkoracle.BloomFilterSet
 	var r1 error
-	if rf, ok := ret.Get(0).(func(context.Context, int64, int64, ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error)); ok {
-		return rf(ctx, collectionID, version, infos...)
+	if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error)); ok {
+		return rf(ctx, collectionID, infos...)
 	}
-	if rf, ok := ret.Get(0).(func(context.Context, int64, int64, ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet); ok {
-		r0 = rf(ctx, collectionID, version, infos...)
+	if rf, ok := ret.Get(0).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) []*pkoracle.BloomFilterSet); ok {
+		r0 = rf(ctx, collectionID, infos...)
} else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*pkoracle.BloomFilterSet) } } - if rf, ok := ret.Get(1).(func(context.Context, int64, int64, ...*querypb.SegmentLoadInfo) error); ok { - r1 = rf(ctx, collectionID, version, infos...) + if rf, ok := ret.Get(1).(func(context.Context, int64, ...*querypb.SegmentLoadInfo) error); ok { + r1 = rf(ctx, collectionID, infos...) } else { r1 = ret.Error(1) } @@ -228,22 +228,21 @@ type MockLoader_LoadBloomFilterSet_Call struct { // LoadBloomFilterSet is a helper method to define mock.On call // - ctx context.Context // - collectionID int64 -// - version int64 // - infos ...*querypb.SegmentLoadInfo -func (_e *MockLoader_Expecter) LoadBloomFilterSet(ctx interface{}, collectionID interface{}, version interface{}, infos ...interface{}) *MockLoader_LoadBloomFilterSet_Call { +func (_e *MockLoader_Expecter) LoadBloomFilterSet(ctx interface{}, collectionID interface{}, infos ...interface{}) *MockLoader_LoadBloomFilterSet_Call { return &MockLoader_LoadBloomFilterSet_Call{Call: _e.mock.On("LoadBloomFilterSet", - append([]interface{}{ctx, collectionID, version}, infos...)...)} + append([]interface{}{ctx, collectionID}, infos...)...)} } -func (_c *MockLoader_LoadBloomFilterSet_Call) Run(run func(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo)) *MockLoader_LoadBloomFilterSet_Call { +func (_c *MockLoader_LoadBloomFilterSet_Call) Run(run func(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo)) *MockLoader_LoadBloomFilterSet_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]*querypb.SegmentLoadInfo, len(args)-3) - for i, a := range args[3:] { + variadicArgs := make([]*querypb.SegmentLoadInfo, len(args)-2) + for i, a := range args[2:] { if a != nil { variadicArgs[i] = a.(*querypb.SegmentLoadInfo) } } - run(args[0].(context.Context), args[1].(int64), args[2].(int64), variadicArgs...) + run(args[0].(context.Context), args[1].(int64), variadicArgs...) 
}) return _c } @@ -253,7 +252,7 @@ func (_c *MockLoader_LoadBloomFilterSet_Call) Return(_a0 []*pkoracle.BloomFilter return _c } -func (_c *MockLoader_LoadBloomFilterSet_Call) RunAndReturn(run func(context.Context, int64, int64, ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error)) *MockLoader_LoadBloomFilterSet_Call { +func (_c *MockLoader_LoadBloomFilterSet_Call) RunAndReturn(run func(context.Context, int64, ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error)) *MockLoader_LoadBloomFilterSet_Call { _c.Call.Return(run) return _c } diff --git a/internal/querynodev2/segments/mock_segment.go b/internal/querynodev2/segments/mock_segment.go index 069199c4d6..f8a0cf6769 100644 --- a/internal/querynodev2/segments/mock_segment.go +++ b/internal/querynodev2/segments/mock_segment.go @@ -15,6 +15,8 @@ import ( msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" + querypb "github.com/milvus-io/milvus/pkg/v2/proto/querypb" segcore "github.com/milvus-io/milvus/internal/util/segcore" @@ -85,6 +87,51 @@ func (_c *MockSegment_BatchPkExist_Call) RunAndReturn(run func(*storage.BatchLoc return _c } +// BloomFilterExist provides a mock function with no fields +func (_m *MockSegment) BloomFilterExist() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for BloomFilterExist") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockSegment_BloomFilterExist_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BloomFilterExist' +type MockSegment_BloomFilterExist_Call struct { + *mock.Call +} + +// BloomFilterExist is a helper method to define mock.On call +func (_e *MockSegment_Expecter) BloomFilterExist() *MockSegment_BloomFilterExist_Call { + return &MockSegment_BloomFilterExist_Call{Call: _e.mock.On("BloomFilterExist")} +} + +func (_c *MockSegment_BloomFilterExist_Call) Run(run func()) *MockSegment_BloomFilterExist_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSegment_BloomFilterExist_Call) Return(_a0 bool) *MockSegment_BloomFilterExist_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSegment_BloomFilterExist_Call) RunAndReturn(run func() bool) *MockSegment_BloomFilterExist_Call { + _c.Call.Return(run) + return _c +} + // CASVersion provides a mock function with given fields: _a0, _a1 func (_m *MockSegment) CASVersion(_a0 int64, _a1 int64) bool { ret := _m.Called(_a0, _a1) @@ -598,6 +645,100 @@ func (_c *MockSegment_GetIndexByID_Call) RunAndReturn(run func(int64) *IndexedFi return _c } +// GetMaxPk provides a mock function with no fields +func (_m *MockSegment) GetMaxPk() *storage.PrimaryKey { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetMaxPk") + } + + var r0 *storage.PrimaryKey + if rf, ok := ret.Get(0).(func() *storage.PrimaryKey); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*storage.PrimaryKey) + } + } + + return r0 +} + +// MockSegment_GetMaxPk_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMaxPk' +type MockSegment_GetMaxPk_Call struct { + *mock.Call +} + +// GetMaxPk is a helper method to define mock.On call +func (_e *MockSegment_Expecter) GetMaxPk() *MockSegment_GetMaxPk_Call { + return &MockSegment_GetMaxPk_Call{Call: _e.mock.On("GetMaxPk")} +} + +func (_c 
*MockSegment_GetMaxPk_Call) Run(run func()) *MockSegment_GetMaxPk_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSegment_GetMaxPk_Call) Return(_a0 *storage.PrimaryKey) *MockSegment_GetMaxPk_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSegment_GetMaxPk_Call) RunAndReturn(run func() *storage.PrimaryKey) *MockSegment_GetMaxPk_Call { + _c.Call.Return(run) + return _c +} + +// GetMinPk provides a mock function with no fields +func (_m *MockSegment) GetMinPk() *storage.PrimaryKey { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetMinPk") + } + + var r0 *storage.PrimaryKey + if rf, ok := ret.Get(0).(func() *storage.PrimaryKey); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*storage.PrimaryKey) + } + } + + return r0 +} + +// MockSegment_GetMinPk_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetMinPk' +type MockSegment_GetMinPk_Call struct { + *mock.Call +} + +// GetMinPk is a helper method to define mock.On call +func (_e *MockSegment_Expecter) GetMinPk() *MockSegment_GetMinPk_Call { + return &MockSegment_GetMinPk_Call{Call: _e.mock.On("GetMinPk")} +} + +func (_c *MockSegment_GetMinPk_Call) Run(run func()) *MockSegment_GetMinPk_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSegment_GetMinPk_Call) Return(_a0 *storage.PrimaryKey) *MockSegment_GetMinPk_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSegment_GetMinPk_Call) RunAndReturn(run func() *storage.PrimaryKey) *MockSegment_GetMinPk_Call { + _c.Call.Return(run) + return _c +} + // HasRawData provides a mock function with given fields: fieldID func (_m *MockSegment) HasRawData(fieldID int64) bool { ret := _m.Called(fieldID) @@ -1768,6 +1909,39 @@ func (_c *MockSegment_Search_Call) RunAndReturn(run func(context.Context, *segco return _c } +// SetBloomFilter provides a mock function with given fields: bf +func (_m *MockSegment) SetBloomFilter(bf *pkoracle.BloomFilterSet) { + _m.Called(bf) +} + +// MockSegment_SetBloomFilter_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetBloomFilter' +type MockSegment_SetBloomFilter_Call struct { + *mock.Call +} + +// SetBloomFilter is a helper method to define mock.On call +// - bf *pkoracle.BloomFilterSet +func (_e *MockSegment_Expecter) SetBloomFilter(bf interface{}) *MockSegment_SetBloomFilter_Call { + return &MockSegment_SetBloomFilter_Call{Call: _e.mock.On("SetBloomFilter", bf)} +} + +func (_c *MockSegment_SetBloomFilter_Call) Run(run func(bf *pkoracle.BloomFilterSet)) *MockSegment_SetBloomFilter_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*pkoracle.BloomFilterSet)) + }) + return _c +} + +func (_c *MockSegment_SetBloomFilter_Call) Return() *MockSegment_SetBloomFilter_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSegment_SetBloomFilter_Call) RunAndReturn(run func(*pkoracle.BloomFilterSet)) *MockSegment_SetBloomFilter_Call { + _c.Run(run) + return _c +} + // Shard provides a mock function with no fields func (_m *MockSegment) Shard() metautil.Channel { ret := _m.Called() diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go index c4ae9f5fc5..c5728faefa 100644 --- a/internal/querynodev2/segments/retrieve.go +++ b/internal/querynodev2/segments/retrieve.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/log" 
"github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/proto/segcorepb" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -164,7 +165,7 @@ func retrieveOnSegmentsWithStream(ctx context.Context, mgr *Manager, segments [] } // retrieve will retrieve all the validate target segments -func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest) ([]RetrieveSegmentResult, []Segment, error) { +func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *querypb.QueryRequest, queryPlan *planpb.PlanNode) ([]RetrieveSegmentResult, []Segment, error) { if ctx.Err() != nil { return nil, nil, ctx.Err() } @@ -172,6 +173,7 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu var err error var SegType commonpb.SegmentState var retrieveSegments []Segment + var segFilters []SegmentFilter = make([]SegmentFilter, 0) segIDs := req.GetSegmentIDs() collID := req.Req.GetCollectionID() @@ -180,10 +182,18 @@ func Retrieve(ctx context.Context, manager *Manager, plan *RetrievePlan, req *qu if req.GetScope() == querypb.DataScope_Historical { SegType = SegmentTypeSealed - retrieveSegments, err = validateOnHistorical(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs) + segFilters = append(segFilters, WithType(SegmentTypeSealed)) + if paramtable.Get().QueryNodeCfg.EnableSparseFilterInQuery.GetAsBool() { + segFilters = append(segFilters, WithSparseFilter(queryPlan)) + } + retrieveSegments, err = validate(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs, segFilters...) } else { SegType = SegmentTypeGrowing - retrieveSegments, err = validateOnStream(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs) + segFilters = append(segFilters, WithType(SegmentTypeGrowing)) + if paramtable.Get().QueryNodeCfg.EnableSparseFilterInQuery.GetAsBool() { + segFilters = append(segFilters, WithSparseFilter(queryPlan)) + } + retrieveSegments, err = validate(ctx, manager, collID, req.GetReq().GetPartitionIDs(), segIDs, segFilters...) 
} if err != nil { @@ -202,6 +212,7 @@ func RetrieveStream(ctx context.Context, manager *Manager, plan *RetrievePlan, r segIDs := req.GetSegmentIDs() collID := req.Req.GetCollectionID() + log.Ctx(ctx).Debug("retrieve stream on segments", zap.Int64s("segmentIDs", segIDs), zap.Int64("collectionID", collID)) if req.GetScope() == querypb.DataScope_Historical { SegType = SegmentTypeSealed diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index 4f3edd6ca8..b81082cd58 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -26,20 +26,27 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" + "github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/initcore" "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" "github.com/milvus-io/milvus/pkg/v2/proto/querypb" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) type RetrieveSuite struct { suite.Suite + // schema + ctx context.Context + schema *schemapb.CollectionSchema + // Dependencies rootPath string chunkManager storage.ChunkManager @@ -60,23 +67,26 @@ func (suite *RetrieveSuite) SetupSuite() { func (suite *RetrieveSuite) SetupTest() { var err error - ctx := context.Background() + suite.ctx = context.Background() msgLength := 100 suite.rootPath = suite.T().Name() chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) - suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx) + suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(suite.ctx) initcore.InitRemoteChunkManager(paramtable.Get()) + initcore.InitLocalChunkManager(suite.rootPath) + initcore.InitMmapManager(paramtable.Get(), 1) + initcore.InitTieredStorage(paramtable.Get()) suite.collectionID = 100 suite.partitionID = 10 - suite.segmentID = 1 + suite.segmentID = 100 suite.manager = NewManager() - schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) - indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, schema) + suite.schema = mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) + indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, suite.schema) suite.manager.Collection.PutOrRef(suite.collectionID, - schema, + suite.schema, indexMeta, &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, @@ -85,57 +95,88 @@ func (suite *RetrieveSuite) SetupTest() { }, ) suite.collection = suite.manager.Collection.Get(suite.collectionID) + loader := NewLoader(suite.ctx, suite.manager, suite.chunkManager) - suite.sealed, err = NewSegment(ctx, - suite.collection, - suite.manager.Segment, - SegmentTypeSealed, - 0, - &querypb.SegmentLoadInfo{ - SegmentID: suite.segmentID, - CollectionID: suite.collectionID, - PartitionID: suite.partitionID, - NumOfRows: int64(msgLength), - InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), - Level: datapb.SegmentLevel_Legacy, - }, - ) - suite.Require().NoError(err) - - binlogs, 
_, err := mock_segcore.SaveBinLog(ctx, + binlogs, statslogs, err := mock_segcore.SaveBinLog(suite.ctx, suite.collectionID, suite.partitionID, suite.segmentID, msgLength, - schema, + suite.schema, suite.chunkManager, ) suite.Require().NoError(err) + + sealLoadInfo := querypb.SegmentLoadInfo{ + SegmentID: suite.segmentID, + CollectionID: suite.collectionID, + PartitionID: suite.partitionID, + NumOfRows: int64(msgLength), + BinlogPaths: binlogs, + Statslogs: statslogs, + InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), + Level: datapb.SegmentLevel_Legacy, + } + + suite.sealed, err = NewSegment(suite.ctx, + suite.collection, + suite.manager.Segment, + SegmentTypeSealed, + 0, + &sealLoadInfo, + ) + suite.Require().NoError(err) + + bfs, err := loader.loadSingleBloomFilterSet(suite.ctx, suite.collectionID, &sealLoadInfo, SegmentTypeSealed) + suite.Require().NoError(err) + suite.sealed.SetBloomFilter(bfs) + for _, binlog := range binlogs { - err = suite.sealed.(*LocalSegment).LoadFieldData(ctx, binlog.FieldID, int64(msgLength), binlog) + err = suite.sealed.(*LocalSegment).LoadFieldData(suite.ctx, binlog.FieldID, int64(msgLength), binlog) suite.Require().NoError(err) } - suite.growing, err = NewSegment(ctx, + binlogs, statlogs, err := mock_segcore.SaveBinLog(suite.ctx, + suite.collectionID, + suite.partitionID, + suite.segmentID+1, + msgLength, + suite.schema, + suite.chunkManager, + ) + suite.Require().NoError(err) + + growingLoadInfo := querypb.SegmentLoadInfo{ + SegmentID: suite.segmentID + 1, + CollectionID: suite.collectionID, + PartitionID: suite.partitionID, + BinlogPaths: binlogs, + Statslogs: statlogs, + InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), + Level: datapb.SegmentLevel_Legacy, + } + + // allow growing segment use the bloom filter + paramtable.Get().QueryNodeCfg.SkipGrowingSegmentBF.SwapTempValue("false") + + suite.growing, err = NewSegment(suite.ctx, suite.collection, suite.manager.Segment, SegmentTypeGrowing, 0, - &querypb.SegmentLoadInfo{ - SegmentID: suite.segmentID + 1, - CollectionID: suite.collectionID, - PartitionID: suite.partitionID, - InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID), - Level: datapb.SegmentLevel_Legacy, - }, + &growingLoadInfo, ) suite.Require().NoError(err) + bfs, err = loader.loadSingleBloomFilterSet(suite.ctx, suite.collectionID, &growingLoadInfo, SegmentTypeGrowing) + suite.Require().NoError(err) + suite.growing.SetBloomFilter(bfs) + insertMsg, err := mock_segcore.GenInsertMsg(suite.collection.GetCCollection(), suite.partitionID, suite.growing.ID(), msgLength) suite.Require().NoError(err) insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg) suite.Require().NoError(err) - err = suite.growing.Insert(ctx, insertMsg.RowIDs, insertMsg.Timestamps, insertRecord) + err = suite.growing.Insert(suite.ctx, insertMsg.RowIDs, insertMsg.Timestamps, insertRecord) suite.Require().NoError(err) suite.manager.Segment.Put(context.Background(), SegmentTypeSealed, suite.sealed) @@ -163,7 +204,7 @@ func (suite *RetrieveSuite) TestRetrieveSealed() { Scope: querypb.DataScope_Historical, } - res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) + res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req, nil) suite.NoError(err) suite.Len(res[0].Result.Offset, 3) suite.manager.Segment.Unpin(segments) @@ -176,6 +217,160 @@ func (suite *RetrieveSuite) TestRetrieveSealed() { suite.Len(resultByOffsets.Offset, 0) } 
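The TestRetrieveWithFilter case added below exercises the pruning rules end to end. As a reader's aid, here is a minimal, self-contained Go sketch of those rules; segmentStats, eq, and the other names are illustrative stand-ins, not the actual Milvus API. An equality predicate prunes a segment when the bloom filter rules the PK out or the PK falls outside [min, max]; AND prunes when either side prunes; OR prunes only when both sides prune; any sub-expression the walker cannot reason about must keep the segment.

```go
package main

import "fmt"

type segmentStats struct {
	minPK, maxPK int64
	mayContain   func(pk int64) bool // stand-in for the bloom filter probe
}

type expr interface {
	// keep reports whether the segment must be kept (cannot be pruned).
	keep(s segmentStats) bool
}

type eq struct{ pk int64 }

func (e eq) keep(s segmentStats) bool {
	if !s.mayContain(e.pk) { // bloom filter says "definitely absent"
		return false
	}
	return e.pk >= s.minPK && e.pk <= s.maxPK // min/max check
}

type and struct{ l, r expr }

func (a and) keep(s segmentStats) bool { return a.l.keep(s) && a.r.keep(s) }

type or struct{ l, r expr }

func (o or) keep(s segmentStats) bool { return o.l.keep(s) || o.r.keep(s) }

// unknown models any sub-expression the walker cannot reason about
// (e.g. a non-PK column): it must never cause pruning.
type unknown struct{}

func (unknown) keep(segmentStats) bool { return true }

func main() {
	// Segment holding PKs [0..4]; this toy bloom filter has no false positives.
	s := segmentStats{minPK: 0, maxPK: 4, mayContain: func(pk int64) bool { return pk >= 0 && pk <= 4 }}

	fmt.Println(eq{pk: 2}.keep(s))                 // true: the segment may contain pk 2
	fmt.Println(eq{pk: 9}.keep(s))                 // false: prunable, 9 is outside the stats
	fmt.Println(and{eq{pk: 9}, unknown{}}.keep(s)) // false: one prunable AND side suffices
	fmt.Println(or{eq{pk: 9}, unknown{}}.keep(s))  // true: OR prunes only if both sides prune
}
```

A real bloom filter can return false positives, which is why the filter may only ever answer "definitely absent" or "maybe present"; pruning on "maybe present" would be incorrect.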
+func (suite *RetrieveSuite) TestRetrieveWithFilter() {
+	plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection())
+	suite.NoError(err)
+
+	suite.Run("SealSegmentFilter", func() {
+		// a PK that does not exist in any segment
+		exprStr := "int64Field == 10000000"
+		schemaHelper, _ := typeutil.CreateSchemaHelper(suite.schema)
+		planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr, nil)
+		suite.NoError(err)
+
+		req := &querypb.QueryRequest{
+			Req: &internalpb.RetrieveRequest{
+				CollectionID: suite.collectionID,
+				PartitionIDs: []int64{suite.partitionID},
+			},
+			SegmentIDs: []int64{suite.sealed.ID()},
+			Scope:      querypb.DataScope_Historical,
+		}
+
+		res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req, planNode)
+		suite.NoError(err)
+		suite.Len(res, 0)
+		suite.manager.Segment.Unpin(segments)
+	})
+
+	suite.Run("GrowingSegmentFilter", func() {
+		exprStr := "int64Field == 10000000"
+		schemaHelper, _ := typeutil.CreateSchemaHelper(suite.schema)
+		planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr, nil)
+		suite.NoError(err)
+
+		req := &querypb.QueryRequest{
+			Req: &internalpb.RetrieveRequest{
+				CollectionID: suite.collectionID,
+				PartitionIDs: []int64{suite.partitionID},
+			},
+			SegmentIDs: []int64{suite.growing.ID()},
+			Scope:      querypb.DataScope_Streaming,
+		}
+
+		res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req, planNode)
+		suite.NoError(err)
+		suite.Len(res, 0)
+		suite.manager.Segment.Unpin(segments)
+	})
+
+	suite.Run("SegmentFilterRules", func() {
+		// create 10 more sealed segments to exercise the BF rules
+		// segment N holds the PK range [0...N-1], e.g.
+		//    seg1  [0]
+		//    seg2  [0, 1]
+		//    ...
+		//    seg10 [0, 1, 2, ..., 9]
+		loader := NewLoader(suite.ctx, suite.manager, suite.chunkManager)
+		for i := range 10 {
+			segid := int64(i + 1)
+			msgLen := i + 1
+			bl, sl, err := mock_segcore.SaveBinLog(suite.ctx,
+				suite.collectionID,
+				suite.partitionID,
+				segid,
+				msgLen,
+				suite.schema,
+				suite.chunkManager)
+
+			suite.Require().NoError(err)
+
+			sealLoadInfo := querypb.SegmentLoadInfo{
+				SegmentID:     segid,
+				CollectionID:  suite.collectionID,
+				PartitionID:   suite.partitionID,
+				NumOfRows:     int64(msgLen),
+				BinlogPaths:   bl,
+				Statslogs:     sl,
+				InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
+				Level:         datapb.SegmentLevel_Legacy,
+			}
+
+			sealseg, err := NewSegment(suite.ctx,
+				suite.collection,
+				suite.manager.Segment,
+				SegmentTypeSealed,
+				0,
+				&sealLoadInfo,
+			)
+			suite.Require().NoError(err)
+
+			bfs, err := loader.loadSingleBloomFilterSet(suite.ctx, suite.collectionID,
+				&sealLoadInfo, SegmentTypeSealed)
+			suite.Require().NoError(err)
+			sealseg.SetBloomFilter(bfs)
+
+			suite.manager.Segment.Put(suite.ctx, SegmentTypeSealed, sealseg)
+		}
+
+		exprs := map[string]int{
+			// empty expression: no plan, no pruning
+			"": 10,
+			// prunes half of the sealed segments
+			"int64Field == 5": 5,
+			"int64Field == 6": 4,
+			// AND operator: int8Field has no stats, but the int64Field (pk) side can still prune
+			"int64Field == 6 and int8Field == -10000": 4,
+			// nested expression
+			"int64Field == 6 and (int64Field == 7 or int8Field == -10000)": 4,
+			// OR operator
+			// cannot prune: OR requires both sides to be prunable
+			"int64Field == 6 or int8Field == -10000": 10,
+			// can prune: the prunable equality dominates the OR
+			"int64Field == 6 or (int64Field == 7 and int8Field == -10000)": 4,
+			// IN operator
+			"int64Field IN [7, 8, 9]": 3,
+			// NOT IN must never prune
+			"int64Field NOT IN [7, 8, 9]": 10,
+			"NOT (int64Field IN [7, 8, 9])": 10,
+			// empty IN list
+			"int64Field IN []": 10,
+		}
+
+		req
:= &querypb.QueryRequest{ + Req: &internalpb.RetrieveRequest{ + CollectionID: suite.collectionID, + PartitionIDs: []int64{suite.partitionID}, + }, + SegmentIDs: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + Scope: querypb.DataScope_Historical, + } + + for exprStr, expect := range exprs { + schemaHelper, _ := typeutil.CreateSchemaHelper(suite.schema) + var planNode *planpb.PlanNode + if exprStr == "" { + planNode = nil + err = nil + } else { + planNode, err = planparserv2.CreateRetrievePlan(schemaHelper, exprStr, nil) + } + suite.NoError(err) + + res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req, planNode) + suite.NoError(err) + suite.Len(res, expect) + suite.manager.Segment.Unpin(segments) + } + + // remove the segs + for i := range 10 { + suite.manager.Segment.Remove(suite.ctx, int64(i+1) /*segmentID*/, querypb.DataScope_Historical) + } + }) +} + func (suite *RetrieveSuite) TestRetrieveGrowing() { plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection()) suite.NoError(err) @@ -189,7 +384,7 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() { Scope: querypb.DataScope_Streaming, } - res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) + res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req, nil) suite.NoError(err) suite.Len(res[0].Result.Offset, 3) suite.manager.Segment.Unpin(segments) @@ -259,7 +454,7 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() { Scope: querypb.DataScope_Streaming, } - res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) + res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req, nil) suite.Error(err) suite.Len(res, 0) suite.manager.Segment.Unpin(segments) @@ -279,7 +474,7 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() { Scope: querypb.DataScope_Historical, } - res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req) + res, segments, err := Retrieve(context.TODO(), suite.manager, plan, req, nil) suite.ErrorIs(err, merr.ErrSegmentNotLoaded) suite.Len(res, 0) suite.manager.Segment.Unpin(segments) diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 71a1c6bc18..11435af5b9 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -188,6 +188,14 @@ func (s *baseSegment) LoadInfo() *querypb.SegmentLoadInfo { return s.loadInfo.Load() } +func (s *baseSegment) SetBloomFilter(bf *pkoracle.BloomFilterSet) { + s.bloomFilterSet = bf +} + +func (s *baseSegment) BloomFilterExist() bool { + return s.bloomFilterSet.BloomFilterExist() +} + func (s *baseSegment) UpdateBloomFilter(pks []storage.PrimaryKey) { if s.skipGrowingBF { return @@ -219,6 +227,20 @@ func (s *baseSegment) MayPkExist(pk *storage.LocationsCache) bool { return s.bloomFilterSet.MayPkExist(pk) } +func (s *baseSegment) GetMinPk() *storage.PrimaryKey { + if s.bloomFilterSet.Stats() == nil { + return nil + } + return &s.bloomFilterSet.Stats().MinPK +} + +func (s *baseSegment) GetMaxPk() *storage.PrimaryKey { + if s.bloomFilterSet.Stats() == nil { + return nil + } + return &s.bloomFilterSet.Stats().MaxPK +} + func (s *baseSegment) BatchPkExist(lc *storage.BatchLocationsCache) []bool { if s.skipGrowingBF { allPositive := make([]bool, lc.Size()) diff --git a/internal/querynodev2/segments/segment_filter.go b/internal/querynodev2/segments/segment_filter.go index e9401a7d15..a2004ccebe 100644 --- a/internal/querynodev2/segments/segment_filter.go +++ 
b/internal/querynodev2/segments/segment_filter.go
@@ -17,7 +17,13 @@ package segments
 
 import (
+	"go.uber.org/zap"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	storage "github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/pkg/v2/log"
 	"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
+	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
 	"github.com/milvus-io/milvus/pkg/v2/util/metautil"
 	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )
@@ -147,3 +153,175 @@ func WithoutLevel(level datapb.SegmentLevel) SegmentFilter {
 		return segment.Level() != level
 	})
 }
+
+// Sparse filter: PK-statistics (bloom filter / min-max) based segment pruning
+
+type filterFunc func(pk storage.PrimaryKey, op planpb.OpType) bool
+
+func binaryExprWalker(expr *planpb.BinaryExpr, filter filterFunc) bool {
+	switch expr.Op {
+	case planpb.BinaryExpr_LogicalAnd:
+		// AND: prunable as soon as either side is prunable
+		return exprWalker(expr.Left, filter) &&
+			exprWalker(expr.Right, filter)
+
+	case planpb.BinaryExpr_LogicalOr:
+		// OR: prunable only if both sides are prunable
+		return exprWalker(expr.Left, filter) ||
+			exprWalker(expr.Right, filter)
+	}
+
+	// unknown operator: never prune
+	return true
+}
+
+func unaryRangeExprWalker(expr *planpb.UnaryRangeExpr, filter filterFunc) bool {
+	if expr.GetColumnInfo() == nil ||
+		!expr.GetColumnInfo().GetIsPrimaryKey() ||
+		expr.GetValue() == nil {
+		// not a comparison against the primary key, never prune
+		return true
+	}
+
+	var pk storage.PrimaryKey
+	dt := expr.GetColumnInfo().GetDataType()
+
+	switch dt {
+	case schemapb.DataType_Int64:
+		pk = storage.NewInt64PrimaryKey(expr.GetValue().GetInt64Val())
+	case schemapb.DataType_VarChar:
+		pk = storage.NewVarCharPrimaryKey(expr.GetValue().GetStringVal())
+	default:
+		log.Warn("unknown pk type",
+			zap.Int("type", int(dt)),
+			zap.String("expr", expr.String()))
+		return true
+	}
+
+	return filter(pk, expr.Op)
+}
+
+func termExprWalker(expr *planpb.TermExpr, filter filterFunc) bool {
+	noFilter := true
+	if expr.GetColumnInfo() == nil ||
+		!expr.GetColumnInfo().GetIsPrimaryKey() {
+		return noFilter
+	}
+
+	// a nil value list can match nothing, so the segment is prunable
+	if expr.GetValues() == nil {
+		return false
+	}
+
+	var pk storage.PrimaryKey
+	dt := expr.GetColumnInfo().GetDataType()
+	invals := expr.GetValues()
+
+	for _, pkval := range invals {
+		switch dt {
+		case schemapb.DataType_Int64:
+			pk = storage.NewInt64PrimaryKey(pkval.GetInt64Val())
+		case schemapb.DataType_VarChar:
+			pk = storage.NewVarCharPrimaryKey(pkval.GetStringVal())
+		default:
+			log.Warn("unknown pk type",
+				zap.Int("type", int(dt)),
+				zap.String("expr", expr.String()))
+			return noFilter
+		}
+
+		// keep the segment as soon as one candidate PK may exist in it
+		noFilter = filter(pk, planpb.OpType_Equal)
+		if noFilter {
+			break
+		}
+	}
+
+	return noFilter
+}
+
+// exprWalker walks the predicate and returns true if the segment must be kept;
+// false means the PK statistics prove the predicate cannot match in this segment.
+func exprWalker(expr *planpb.Expr, filter filterFunc) bool {
+	switch expr := expr.GetExpr().(type) {
+	case *planpb.Expr_BinaryExpr:
+		return binaryExprWalker(expr.BinaryExpr, filter)
+	case *planpb.Expr_UnaryRangeExpr:
+		return unaryRangeExprWalker(expr.UnaryRangeExpr, filter)
+	case *planpb.Expr_TermExpr:
+		return termExprWalker(expr.TermExpr, filter)
+	}
+
+	return true
+}
+
+func doSparseFilter(seg Segment, plan *planpb.PlanNode) bool {
+	queryPlan := plan.GetQuery()
+	if queryPlan == nil {
+		// not a query plan, keep the segment
+		return true
+	}
+
+	pexpr := queryPlan.GetPredicates()
+	if pexpr == nil {
+		return true
+	}
+
+	return exprWalker(pexpr, func(pk storage.PrimaryKey, op planpb.OpType) bool {
+		noFilter := true
+		existMinMax := seg.GetMinPk() != nil && seg.GetMaxPk() != nil
+		var minPk, maxPk storage.PrimaryKey
+		if existMinMax {
+			minPk = *seg.GetMinPk()
+			maxPk = *seg.GetMaxPk()
+		}
+
+		switch op {
+		case planpb.OpType_Equal:
+			// bloom filter check
+			existBF := seg.BloomFilterExist()
+			if existBF {
+				lc := storage.NewLocationsCache(pk)
+				// keep the segment only if the BF reports the key may exist
+				noFilter = seg.MayPkExist(lc)
+			}
+
+			// already prunable by the BF, skip the min/max check
+			if !noFilter {
+				break
+			}
+
+			// min/max check: prunable when pk falls outside [minPk, maxPk]
+			noFilter = !(existMinMax && (minPk.GT(pk) || maxPk.LT(pk)))
+		case planpb.OpType_GreaterThan:
+			noFilter = !(existMinMax && maxPk.LE(pk))
+		case planpb.OpType_GreaterEqual:
+			noFilter = !(existMinMax && maxPk.LT(pk))
+		case planpb.OpType_LessThan:
+			noFilter = !(existMinMax && minPk.GE(pk))
+		case planpb.OpType_LessEqual:
+			noFilter = !(existMinMax && minPk.GT(pk))
+		}
+
+		return noFilter
+	})
+}
+
+type SegmentSparseFilter SegmentType
+
+func WithSparseFilter(plan *planpb.PlanNode) SegmentFilter {
+	return SegmentFilterFunc(func(segment Segment) bool {
+		if plan == nil {
+			log.Debug("SparseFilter with nil plan")
+			return true
+		}
+
+		rc := doSparseFilter(segment, plan)
+
+		log.Debug("SparseFilter",
+			zap.Int64("segmentID", segment.ID()),
+			zap.Bool("noFilter", rc),
+			zap.Bool("bfExist", segment.BloomFilterExist()))
+		return rc
+	})
+}
diff --git a/internal/querynodev2/segments/segment_interface.go b/internal/querynodev2/segments/segment_interface.go
index 692cea0726..43473d0135 100644
--- a/internal/querynodev2/segments/segment_interface.go
+++ b/internal/querynodev2/segments/segment_interface.go
@@ -20,6 +20,7 @@ import (
 	"context"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
+	pkoracle "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
 	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/segcore"
 	"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
@@ -88,10 +89,16 @@ type Segment interface {
 	Release(ctx context.Context, opts ...releaseOption)
 
 	// Bloom filter related
+	SetBloomFilter(bf *pkoracle.BloomFilterSet)
+	BloomFilterExist() bool
 	UpdateBloomFilter(pks []storage.PrimaryKey)
 	MayPkExist(lc *storage.LocationsCache) bool
 	BatchPkExist(lc *storage.BatchLocationsCache) []bool
 
+	// PK min/max from the loaded PK statistics
+	GetMinPk() *storage.PrimaryKey
+	GetMaxPk() *storage.PrimaryKey
+
 	// BM25 stats
 	UpdateBM25Stats(stats map[int64]*storage.BM25Stats)
 	GetBM25Stats() map[int64]*storage.BM25Stats
diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go
index dfcc5a9ffc..d324df2624 100644
--- a/internal/querynodev2/segments/segment_loader.go
+++ b/internal/querynodev2/segments/segment_loader.go
@@ -81,7 +81,7 @@ type Loader interface {
 	LoadDeltaLogs(ctx context.Context, segment Segment, deltaLogs []*datapb.FieldBinlog) error
 
 	// LoadBloomFilterSet loads needed statslog for RemoteSegment.
-	LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error)
+	LoadBloomFilterSet(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error)
 
 	// LoadBM25Stats loads BM25 statslog for RemoteSegment
 	LoadBM25Stats(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) (*typeutil.ConcurrentMap[int64, map[int64]*storage.BM25Stats], error)
@@ -367,6 +367,15 @@ func (loader *segmentLoader) Load(ctx context.Context,
 		return errors.Wrap(err, "At LoadDeltaLogs")
 	}
 
+	if !segment.BloomFilterExist() {
+		log.Debug("bloom filter not loaded yet, loading it from statslogs", zap.Int64("segmentID", segment.ID()))
+		bfs, err := loader.loadSingleBloomFilterSet(ctx, loadInfo.GetCollectionID(), loadInfo, segment.Type())
+		if err != nil {
+			return errors.Wrap(err, "At LoadBloomFilter")
+		}
+		segment.SetBloomFilter(bfs)
+	}
+
 	if err = segment.FinishLoad(); err != nil {
 		return errors.Wrap(err, "At FinishLoad")
 	}
@@ -635,7 +644,40 @@ func (loader *segmentLoader) LoadBM25Stats(ctx context.Context, collectionID int
 	return loadedStats, nil
 }
 
-func (loader *segmentLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, version int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
+// loadSingleBloomFilterSet loads the PK statistics (bloom filter and min/max) for a single segment.
+func (loader *segmentLoader) loadSingleBloomFilterSet(ctx context.Context, collectionID int64, loadInfo *querypb.SegmentLoadInfo, segtype SegmentType) (*pkoracle.BloomFilterSet, error) {
+	log := log.Ctx(ctx).With(
+		zap.Int64("collectionID", collectionID),
+		zap.Int64("segmentID", loadInfo.GetSegmentID()))
+
+	collection := loader.manager.Collection.Get(collectionID)
+	if collection == nil {
+		err := merr.WrapErrCollectionNotFound(collectionID)
+		log.Warn("failed to get collection while loading segment", zap.Error(err))
+		return nil, err
+	}
+	pkField := GetPkField(collection.Schema())
+
+	partitionID := loadInfo.PartitionID
+	segmentID := loadInfo.SegmentID
+	bfs := pkoracle.NewBloomFilterSet(segmentID, partitionID, segtype)
+
+	log.Info("loading bloom filter from remote storage...")
+	pkStatsBinlogs, logType := loader.filterPKStatsBinlogs(loadInfo.Statslogs, pkField.GetFieldID())
+	err := loader.loadBloomFilter(ctx, segmentID, bfs, pkStatsBinlogs, logType)
+	if err != nil {
+		log.Warn("load remote segment bloom filter failed",
+			zap.Int64("partitionID", partitionID),
+			zap.Int64("segmentID", segmentID),
+			zap.Error(err),
+		)
+		return nil, err
+	}
+
+	return bfs, nil
+}
+
+func (loader *segmentLoader) LoadBloomFilterSet(ctx context.Context, collectionID int64, infos ...*querypb.SegmentLoadInfo) ([]*pkoracle.BloomFilterSet, error) {
 	log := log.Ctx(ctx).With(
 		zap.Int64("collectionID", collectionID),
 		zap.Int64s("segmentIDs", lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) int64 {
diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go
index 1dd7e037e7..d87c9945da 100644
--- a/internal/querynodev2/segments/segment_loader_test.go
+++ b/internal/querynodev2/segments/segment_loader_test.go
@@ -238,12 +238,14 @@ func (suite *SegmentLoaderSuite) TestLoadMultipleSegments() {
 	segments, err := suite.loader.Load(ctx, suite.collectionID, SegmentTypeSealed, 0, loadInfos...)
suite.NoError(err) - // Won't load bloom filter with sealed segments + // Will load bloom filter with sealed segments for _, segment := range segments { for pk := 0; pk < 100; pk++ { lc := storage.NewLocationsCache(storage.NewInt64PrimaryKey(int64(pk))) - exist := segment.MayPkExist(lc) - suite.Require().False(exist) + exist := segment.BloomFilterExist() + suite.Require().True(exist) + exist = segment.MayPkExist(lc) + suite.Require().True(exist) } } @@ -277,7 +279,9 @@ func (suite *SegmentLoaderSuite) TestLoadMultipleSegments() { for _, segment := range segments { for pk := 0; pk < 100; pk++ { lc := storage.NewLocationsCache(storage.NewInt64PrimaryKey(int64(pk))) - exist := segment.MayPkExist(lc) + exist := segment.BloomFilterExist() + suite.True(exist) + exist = segment.MayPkExist(lc) suite.True(exist) } } @@ -363,7 +367,7 @@ func (suite *SegmentLoaderSuite) TestLoadBloomFilter() { }) } - bfs, err := suite.loader.LoadBloomFilterSet(ctx, suite.collectionID, 0, loadInfos...) + bfs, err := suite.loader.LoadBloomFilterSet(ctx, suite.collectionID, loadInfos...) suite.NoError(err) for _, bf := range bfs { diff --git a/internal/querynodev2/segments/validate.go b/internal/querynodev2/segments/validate.go index 53a20c27fd..6bcfd46c7d 100644 --- a/internal/querynodev2/segments/validate.go +++ b/internal/querynodev2/segments/validate.go @@ -25,7 +25,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/util/merr" ) -func validate(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64, segmentFilter SegmentFilter) ([]Segment, error) { +func validate(ctx context.Context, manager *Manager, collectionID int64, partitionIDs []int64, segmentIDs []int64, segmentFilter ...SegmentFilter) ([]Segment, error) { collection := manager.Collection.Get(collectionID) if collection == nil { return nil, merr.WrapErrCollectionNotFound(collectionID) @@ -43,14 +43,16 @@ func validate(ctx context.Context, manager *Manager, collectionID int64, partiti }() if len(segmentIDs) == 0 { // legacy logic - segments, err = manager.Segment.GetAndPinBy(segmentFilter, SegmentFilterFunc(func(s Segment) bool { + segmentFilter = append(segmentFilter, SegmentFilterFunc(func(s Segment) bool { return s.Collection() == collectionID })) + + segments, err = manager.Segment.GetAndPinBy(segmentFilter...) if err != nil { return nil, err } } else { - segments, err = manager.Segment.GetAndPin(segmentIDs, segmentFilter) + segments, err = manager.Segment.GetAndPin(segmentIDs, segmentFilter...) 
		if err != nil {
 			return nil, err
 		}
diff --git a/internal/querynodev2/tasks/query_task.go b/internal/querynodev2/tasks/query_task.go
index a81d59873b..296276d081 100644
--- a/internal/querynodev2/tasks/query_task.go
+++ b/internal/querynodev2/tasks/query_task.go
@@ -9,6 +9,7 @@ import (
 	"github.com/samber/lo"
 	"go.opentelemetry.io/otel"
 	"go.opentelemetry.io/otel/trace"
+	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus/internal/querynodev2/segments"
@@ -16,6 +17,7 @@ import (
 	"github.com/milvus-io/milvus/internal/util/segcore"
 	"github.com/milvus-io/milvus/pkg/v2/metrics"
 	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
+	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
 	"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
 	"github.com/milvus-io/milvus/pkg/v2/proto/segcorepb"
 	"github.com/milvus-io/milvus/pkg/v2/util/merr"
@@ -36,6 +38,7 @@ func NewQueryTask(ctx context.Context,
 		ctx:            ctx,
 		collection:     collection,
 		segmentManager: manager,
+		plan:           &planpb.PlanNode{},
 		req:            req,
 		notifier:       make(chan error, 1),
 		tr:             timerecord.NewTimeRecorderWithTrace(ctx, "queryTask"),
@@ -48,6 +51,7 @@ type QueryTask struct {
 	collection     *segments.Collection
 	segmentManager *segments.Manager
 	req            *querypb.QueryRequest
+	plan           *planpb.PlanNode // used for sparse (PK stats based) segment filtering
 	result         *internalpb.RetrieveResults
 	notifier       chan error
 	tr             *timerecord.TimeRecorder
@@ -87,6 +91,11 @@ func (t *QueryTask) PreExecute() error {
 		username).
 		Observe(inQueueDurationMS)
 
+	// Unmarshal the original serialized plan; the sparse filter uses it to prune segments
+	if err := proto.Unmarshal(t.req.Req.GetSerializedExprPlan(), t.plan); err != nil {
+		return err
+	}
+
 	return nil
 }
 
@@ -113,7 +120,8 @@ func (t *QueryTask) Execute() error {
 		return err
 	}
 	defer retrievePlan.Delete()
-	results, pinnedSegments, err := segments.Retrieve(t.ctx, t.segmentManager, retrievePlan, t.req)
+
+	results, pinnedSegments, err := segments.Retrieve(t.ctx, t.segmentManager, retrievePlan, t.req, t.plan)
 	defer t.segmentManager.Segment.Unpin(pinnedSegments)
 	if err != nil {
 		return err
diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go
index 636709a980..be545fb828 100644
--- a/pkg/util/paramtable/component_param.go
+++ b/pkg/util/paramtable/component_param.go
@@ -3159,6 +3159,7 @@ type queryNodeConfig struct {
 	QueryStreamMaxBatchSize ParamItem `refreshable:"false"`
 
 	// BF
+	EnableSparseFilterInQuery      ParamItem `refreshable:"true"`
 	SkipGrowingSegmentBF           ParamItem `refreshable:"true"`
 	BloomFilterApplyParallelFactor ParamItem `refreshable:"true"`
 
@@ -4277,6 +4278,14 @@ user-task-polling:
 	}
 	p.BloomFilterApplyParallelFactor.Init(base.mgr)
 
+	p.EnableSparseFilterInQuery = ParamItem{
+		Key:          "queryNode.enableSparseFilterInQuery",
+		Version:      "2.6.2",
+		DefaultValue: "true",
+		Doc:          "Enable sparse segment filtering (PK bloom filter / min-max pruning) in query.",
+	}
+	p.EnableSparseFilterInQuery.Init(base.mgr)
+
 	p.SkipGrowingSegmentBF = ParamItem{
 		Key:     "queryNode.skipGrowingSegmentBF",
 		Version: "2.5",
diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go
index 8377f7263f..f47bdc5e52 100644
--- a/pkg/util/paramtable/component_param_test.go
+++ b/pkg/util/paramtable/component_param_test.go
@@ -502,6 +502,7 @@ func TestComponentParam(t *testing.T) {
 
 		assert.Equal(t, 2, Params.BloomFilterApplyParallelFactor.GetAsInt())
 		assert.Equal(t, true, Params.SkipGrowingSegmentBF.GetAsBool())
+		assert.Equal(t, true, Params.EnableSparseFilterInQuery.GetAsBool())
 
 		assert.Equal(t, "/var/lib/milvus/data/mmap", Params.MmapDirPath.GetValue())
 
diff --git a/tests/python_client/milvus_client/test_milvus_client_query.py b/tests/python_client/milvus_client/test_milvus_client_query.py
index ce5a04cee1..e2c5d6f6da 100644
--- a/tests/python_client/milvus_client/test_milvus_client_query.py
+++ b/tests/python_client/milvus_client/test_milvus_client_query.py
@@ -3824,6 +3824,62 @@ class TestQueryOperation(TestMilvusClientV2Base):
         assert res_one == res_two, "Query results should be identical when querying the same partition repeatedly"
         self.drop_collection(client, collection_name)
 
+    @pytest.mark.tags(CaseLabel.L1)
+    def test_milvus_client_query_with_bloom_filter(self):
+        """
+        target: test query with a bloom filter on the PK
+        method: compare query latency between a predicate the bloom filter cannot prune
+                and one it can prune
+        expected: the prunable query should be faster than the non-prunable one
+        """
+        client = self._client()
+        collection_name = cf.gen_collection_name_by_testcase_name()
+        # 1. create collection
+        schema = self.create_schema(client, enable_dynamic_field=True)[0]
+        schema.add_field(default_primary_key_field_name, DataType.INT64, is_primary=True, auto_id=False)
+        schema.add_field(default_vector_field_name, DataType.FLOAT_VECTOR, dim=5)
+        self.create_collection(client, collection_name=collection_name, schema=schema)
+
+        index_params = client.prepare_index_params()
+        index_params.add_index(field_name=default_vector_field_name, index_type="HNSW", metric_type="L2")
+        self.create_index(client, collection_name, index_params)
+
+        # 2. insert data
+        schema_info = self.describe_collection(client, collection_name)[0]
+        insert_offset = 0
+        insert_nb = 1000
+        for i in range(10):
+            rows = cf.gen_row_data_by_schema(nb=insert_nb, schema=schema_info, start=insert_offset)
+            self.insert(client, collection_name, rows)
+            self.flush(client, collection_name)
+            insert_offset += insert_nb
+
+        # 3. load
+        self.load_collection(client, collection_name)
+
+        # 4. query that cannot be pruned (the predicate matches every PK)
+        start_time = time.perf_counter()
+        res = self.query(client, collection_name=collection_name,
+                         filter=f"{default_primary_key_field_name} != -1", output_fields=["count(*)"]
+                         )[0]
+        end_time = time.perf_counter()
+        run_time1 = end_time - start_time
+
+        # query that the bloom filter can prune (the PK never exists)
+        start_time = time.perf_counter()
+        res = self.query(client, collection_name=collection_name,
+                         filter=f"{default_primary_key_field_name} == -1", output_fields=["count(*)"]
+                         )[0]
+        end_time = time.perf_counter()
+        run_time2 = end_time - start_time
+
+        log.info(f"rt1: {run_time1}s rt2: {run_time2}s")
+
+        # 5. verify the non-prunable query is slower than the prunable one
+        assert run_time1 > run_time2
+
+        # 6. clean up
+        self.drop_collection(client, collection_name)
 
 class TestMilvusClientGetInvalid(TestMilvusClientV2Base):
     """ Test case of search interface """
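The range operators handled by doSparseFilter follow the usual interval logic. Here is a minimal, self-contained Go sketch of that min/max pruning table for a segment whose PKs lie in [min, max]; the names (op, prunable) are illustrative stand-ins, not the actual Milvus API: `pk > v` is prunable when max <= v, `pk >= v` when max < v, `pk < v` when min >= v, and `pk <= v` when min > v.

```go
package main

import "fmt"

type op int

const (
	greaterThan op = iota
	greaterEqual
	lessThan
	lessEqual
)

// prunable reports whether a segment whose PKs lie in [min, max] can be
// skipped for the predicate "pk <op> v": it is skippable exactly when the
// whole [min, max] range falls outside the half-line the operator selects.
func prunable(min, max, v int64, o op) bool {
	switch o {
	case greaterThan: // pk > v matches nothing when max <= v
		return max <= v
	case greaterEqual: // pk >= v matches nothing when max < v
		return max < v
	case lessThan: // pk < v matches nothing when min >= v
		return min >= v
	case lessEqual: // pk <= v matches nothing when min > v
		return min > v
	}
	return false // unknown operator: never prune
}

func main() {
	// Segment PK range [10, 20].
	fmt.Println(prunable(10, 20, 20, greaterThan))  // true: no pk > 20 exists
	fmt.Println(prunable(10, 20, 20, greaterEqual)) // false: pk == 20 may match
	fmt.Println(prunable(10, 20, 10, lessThan))     // true: no pk < 10 exists
	fmt.Println(prunable(10, 20, 10, lessEqual))    // false: pk == 10 may match
}
```

Note the asymmetry between the strict and non-strict operators at the interval boundary: the boundary value itself may still match under >= and <=, so only the strict forms can prune when v equals the segment's min or max.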