diff --git a/internal/querynodev2/segments/manager.go b/internal/querynodev2/segments/manager.go index 778c96dca9..205db118f3 100644 --- a/internal/querynodev2/segments/manager.go +++ b/internal/querynodev2/segments/manager.go @@ -173,9 +173,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) { zap.Int64("newVersion", segment.Version()), ) // delete redundant segment - if s, ok := segment.(*LocalSegment); ok { - DeleteSegment(s) - } + segment.Release() continue } replacedSegment = append(replacedSegment, oldSegment) @@ -206,7 +204,7 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) { if len(replacedSegment) > 0 { go func() { for _, segment := range replacedSegment { - remove(segment.(*LocalSegment)) + remove(segment) } }() } @@ -411,7 +409,7 @@ func (mgr *segmentManager) Remove(segmentID UniqueID, scope querypb.DataScope) ( mgr.mu.Lock() var removeGrowing, removeSealed int - var growing, sealed *LocalSegment + var growing, sealed Segment switch scope { case querypb.DataScope_Streaming: growing = mgr.removeSegmentWithType(SegmentTypeGrowing, segmentID) @@ -450,20 +448,20 @@ func (mgr *segmentManager) Remove(segmentID UniqueID, scope querypb.DataScope) ( return removeGrowing, removeSealed } -func (mgr *segmentManager) removeSegmentWithType(typ SegmentType, segmentID UniqueID) *LocalSegment { +func (mgr *segmentManager) removeSegmentWithType(typ SegmentType, segmentID UniqueID) Segment { switch typ { case SegmentTypeGrowing: s, ok := mgr.growingSegments[segmentID] if ok { delete(mgr.growingSegments, segmentID) - return s.(*LocalSegment) + return s } case SegmentTypeSealed: s, ok := mgr.sealedSegments[segmentID] if ok { delete(mgr.sealedSegments, segmentID) - return s.(*LocalSegment) + return s } default: return nil @@ -475,7 +473,7 @@ func (mgr *segmentManager) removeSegmentWithType(typ SegmentType, segmentID Uniq func (mgr *segmentManager) RemoveBy(filters ...SegmentFilter) (int, int) { mgr.mu.Lock() - var removeGrowing, removeSealed []*LocalSegment + var removeGrowing, removeSealed []Segment for id, segment := range mgr.growingSegments { if filter(segment, filters...) { s := mgr.removeSegmentWithType(SegmentTypeGrowing, id) @@ -513,12 +511,12 @@ func (mgr *segmentManager) Clear() { for id, segment := range mgr.growingSegments { delete(mgr.growingSegments, id) - remove(segment.(*LocalSegment)) + remove(segment) } for id, segment := range mgr.sealedSegments { delete(mgr.sealedSegments, id) - remove(segment.(*LocalSegment)) + remove(segment) } mgr.updateMetric() } @@ -538,9 +536,9 @@ func (mgr *segmentManager) updateMetric() { metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(partiations.Len())) } -func remove(segment *LocalSegment) bool { +func remove(segment Segment) bool { rowNum := segment.RowNum() - DeleteSegment(segment) + segment.Release() metrics.QueryNodeNumSegments.WithLabelValues( fmt.Sprint(paramtable.GetNodeID()), diff --git a/internal/querynodev2/segments/mock_segment.go b/internal/querynodev2/segments/mock_segment.go index a06d54c9f5..b2bb133c56 100644 --- a/internal/querynodev2/segments/mock_segment.go +++ b/internal/querynodev2/segments/mock_segment.go @@ -3,7 +3,10 @@ package segments import ( + context "context" + commonpb "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + mock "github.com/stretchr/testify/mock" msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" @@ -680,6 +683,93 @@ func (_c *MockSegment_RUnlock_Call) RunAndReturn(run func()) *MockSegment_RUnloc return _c } +// Release provides a mock function with given fields: +func (_m *MockSegment) Release() { + _m.Called() +} + +// MockSegment_Release_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Release' +type MockSegment_Release_Call struct { + *mock.Call +} + +// Release is a helper method to define mock.On call +func (_e *MockSegment_Expecter) Release() *MockSegment_Release_Call { + return &MockSegment_Release_Call{Call: _e.mock.On("Release")} +} + +func (_c *MockSegment_Release_Call) Run(run func()) *MockSegment_Release_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSegment_Release_Call) Return() *MockSegment_Release_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSegment_Release_Call) RunAndReturn(run func()) *MockSegment_Release_Call { + _c.Call.Return(run) + return _c +} + +// Retrieve provides a mock function with given fields: ctx, plan +func (_m *MockSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { + ret := _m.Called(ctx, plan) + + var r0 *segcorepb.RetrieveResults + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan) (*segcorepb.RetrieveResults, error)); ok { + return rf(ctx, plan) + } + if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan) *segcorepb.RetrieveResults); ok { + r0 = rf(ctx, plan) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*segcorepb.RetrieveResults) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *RetrievePlan) error); ok { + r1 = rf(ctx, plan) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSegment_Retrieve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Retrieve' +type MockSegment_Retrieve_Call struct { + *mock.Call +} + +// Retrieve is a helper method to define mock.On call +// - ctx context.Context +// - plan *RetrievePlan +func (_e *MockSegment_Expecter) Retrieve(ctx interface{}, plan interface{}) *MockSegment_Retrieve_Call { + return &MockSegment_Retrieve_Call{Call: _e.mock.On("Retrieve", ctx, plan)} +} + +func (_c *MockSegment_Retrieve_Call) Run(run func(ctx context.Context, plan *RetrievePlan)) *MockSegment_Retrieve_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*RetrievePlan)) + }) + return _c +} + +func (_c *MockSegment_Retrieve_Call) Return(_a0 *segcorepb.RetrieveResults, _a1 error) *MockSegment_Retrieve_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSegment_Retrieve_Call) RunAndReturn(run func(context.Context, *RetrievePlan) (*segcorepb.RetrieveResults, error)) *MockSegment_Retrieve_Call { + _c.Call.Return(run) + return _c +} + // RowNum provides a mock function with given fields: func (_m *MockSegment) RowNum() int64 { ret := _m.Called() @@ -721,6 +811,61 @@ func (_c *MockSegment_RowNum_Call) RunAndReturn(run func() int64) *MockSegment_R return _c } +// Search provides a mock function with given fields: ctx, searchReq +func (_m *MockSegment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { + ret := _m.Called(ctx, searchReq) + + var r0 *SearchResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *SearchRequest) (*SearchResult, error)); ok { + return rf(ctx, searchReq) + } + if rf, ok := ret.Get(0).(func(context.Context, *SearchRequest) *SearchResult); ok { + r0 = rf(ctx, searchReq) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*SearchResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *SearchRequest) error); ok { + r1 = rf(ctx, searchReq) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockSegment_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search' +type MockSegment_Search_Call struct { + *mock.Call +} + +// Search is a helper method to define mock.On call +// - ctx context.Context +// - searchReq *SearchRequest +func (_e *MockSegment_Expecter) Search(ctx interface{}, searchReq interface{}) *MockSegment_Search_Call { + return &MockSegment_Search_Call{Call: _e.mock.On("Search", ctx, searchReq)} +} + +func (_c *MockSegment_Search_Call) Run(run func(ctx context.Context, searchReq *SearchRequest)) *MockSegment_Search_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*SearchRequest)) + }) + return _c +} + +func (_c *MockSegment_Search_Call) Return(_a0 *SearchResult, _a1 error) *MockSegment_Search_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSegment_Search_Call) RunAndReturn(run func(context.Context, *SearchRequest) (*SearchResult, error)) *MockSegment_Search_Call { + _c.Call.Return(run) + return _c +} + // Shard provides a mock function with given fields: func (_m *MockSegment) Shard() string { ret := _m.Called() @@ -879,6 +1024,49 @@ func (_c *MockSegment_UpdateBloomFilter_Call) RunAndReturn(run func([]storage.Pr return _c } +// ValidateIndexedFieldsData provides a mock function with given fields: ctx, result +func (_m *MockSegment) ValidateIndexedFieldsData(ctx context.Context, result *segcorepb.RetrieveResults) error { + ret := _m.Called(ctx, result) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *segcorepb.RetrieveResults) error); ok { + r0 = rf(ctx, result) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockSegment_ValidateIndexedFieldsData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ValidateIndexedFieldsData' +type MockSegment_ValidateIndexedFieldsData_Call struct { + *mock.Call +} + +// ValidateIndexedFieldsData is a helper method to define mock.On call +// - ctx context.Context +// - result *segcorepb.RetrieveResults +func (_e *MockSegment_Expecter) ValidateIndexedFieldsData(ctx interface{}, result interface{}) *MockSegment_ValidateIndexedFieldsData_Call { + return &MockSegment_ValidateIndexedFieldsData_Call{Call: _e.mock.On("ValidateIndexedFieldsData", ctx, result)} +} + +func (_c *MockSegment_ValidateIndexedFieldsData_Call) Run(run func(ctx context.Context, result *segcorepb.RetrieveResults)) *MockSegment_ValidateIndexedFieldsData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*segcorepb.RetrieveResults)) + }) + return _c +} + +func (_c *MockSegment_ValidateIndexedFieldsData_Call) Return(_a0 error) *MockSegment_ValidateIndexedFieldsData_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSegment_ValidateIndexedFieldsData_Call) RunAndReturn(run func(context.Context, *segcorepb.RetrieveResults) error) *MockSegment_ValidateIndexedFieldsData_Call { + _c.Call.Return(run) + return _c +} + // Version provides a mock function with given fields: func (_m *MockSegment) Version() int64 { ret := _m.Called() diff --git a/internal/querynodev2/segments/reduce_test.go b/internal/querynodev2/segments/reduce_test.go index d299561e7c..573a3dfe6f 100644 --- a/internal/querynodev2/segments/reduce_test.go +++ b/internal/querynodev2/segments/reduce_test.go @@ -99,7 +99,7 @@ func (suite *ReduceSuite) SetupTest() { } func (suite *ReduceSuite) TearDownTest() { - DeleteSegment(suite.segment) + suite.segment.Release() DeleteCollection(suite.collection) ctx := context.Background() suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath) diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go index 019e5e9645..4e6b1195ff 100644 --- a/internal/querynodev2/segments/retrieve.go +++ b/internal/querynodev2/segments/retrieve.go @@ -45,9 +45,8 @@ func retrieveOnSegments(ctx context.Context, segments []Segment, segType Segment for i, segment := range segments { wg.Add(1) - go func(segment Segment, i int) { + go func(seg Segment, i int) { defer wg.Done() - seg := segment.(*LocalSegment) tr := timerecord.NewTimeRecorder("retrieveOnSegments") result, err := seg.Retrieve(ctx, plan) if err != nil { diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index cffc16efab..67397c1c36 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -129,8 +129,8 @@ func (suite *RetrieveSuite) SetupTest() { } func (suite *RetrieveSuite) TearDownTest() { - DeleteSegment(suite.sealed) - DeleteSegment(suite.growing) + suite.sealed.Release() + suite.growing.Release() DeleteCollection(suite.collection) ctx := context.Background() suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath) @@ -179,7 +179,7 @@ func (suite *RetrieveSuite) TestRetrieveNilSegment() { plan, err := genSimpleRetrievePlan(suite.collection) suite.NoError(err) - DeleteSegment(suite.sealed) + suite.sealed.Release() res, segments, err := RetrieveHistorical(context.TODO(), suite.manager, plan, suite.collectionID, []int64{suite.partitionID}, diff --git a/internal/querynodev2/segments/search.go b/internal/querynodev2/segments/search.go index af4b46097a..8bc71ee61d 100644 --- a/internal/querynodev2/segments/search.go +++ b/internal/querynodev2/segments/search.go @@ -52,9 +52,8 @@ func searchSegments(ctx context.Context, segments []Segment, segType SegmentType // calling segment search in goroutines for i, segment := range segments { wg.Add(1) - go func(segment Segment, i int) { + go func(seg Segment, i int) { defer wg.Done() - seg := segment.(*LocalSegment) if !seg.ExistIndex(searchReq.searchFieldID) { mu.Lock() segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID()) diff --git a/internal/querynodev2/segments/search_test.go b/internal/querynodev2/segments/search_test.go index ee80de6ccd..9c7d257f55 100644 --- a/internal/querynodev2/segments/search_test.go +++ b/internal/querynodev2/segments/search_test.go @@ -122,7 +122,7 @@ func (suite *SearchSuite) SetupTest() { } func (suite *SearchSuite) TearDownTest() { - DeleteSegment(suite.sealed) + suite.sealed.Release() DeleteCollection(suite.collection) ctx := context.Background() suite.chunkManager.RemoveWithPrefix(ctx, paramtable.Get().MinioCfg.RootPath.GetValue()) diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 7fd5a797ac..8cfd9eb197 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -307,31 +307,6 @@ func (s *LocalSegment) Type() SegmentType { return s.typ } -func DeleteSegment(segment *LocalSegment) { - /* - void - deleteSegment(CSegmentInterface segment); - */ - // wait all read ops finished - var ptr C.CSegmentInterface - - segment.ptrLock.Lock() - ptr = segment.ptr - segment.ptr = nil - segment.ptrLock.Unlock() - - if ptr == nil { - return - } - - C.DeleteSegment(ptr) - log.Info("delete segment from memory", - zap.Int64("collectionID", segment.collectionID), - zap.Int64("partitionID", segment.partitionID), - zap.Int64("segmentID", segment.ID()), - zap.String("segmentType", segment.typ.String())) -} - func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { /* CStatus @@ -892,3 +867,29 @@ func (s *LocalSegment) UpdateFieldRawDataSize(numRows int64, fieldBinlog *datapb return nil } + +func (s *LocalSegment) Release() { + /* + void + deleteSegment(CSegmentInterface segment); + */ + // wait all read ops finished + var ptr C.CSegmentInterface + + s.ptrLock.Lock() + ptr = s.ptr + s.ptr = nil + s.ptrLock.Unlock() + + if ptr == nil { + return + } + + C.DeleteSegment(ptr) + log.Info("delete segment from memory", + zap.Int64("collectionID", s.collectionID), + zap.Int64("partitionID", s.partitionID), + zap.Int64("segmentID", s.ID()), + zap.String("segmentType", s.typ.String()), + ) +} diff --git a/internal/querynodev2/segments/segment_interface.go b/internal/querynodev2/segments/segment_interface.go index 1be88aadde..601223d0a2 100644 --- a/internal/querynodev2/segments/segment_interface.go +++ b/internal/querynodev2/segments/segment_interface.go @@ -17,6 +17,8 @@ package segments import ( + "context" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" storage "github.com/milvus-io/milvus/internal/storage" @@ -57,4 +59,11 @@ type Segment interface { // Bloom filter related UpdateBloomFilter(pks []storage.PrimaryKey) MayPkExist(pk storage.PrimaryKey) bool + + // Read operations + Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) + Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) + ValidateIndexedFieldsData(ctx context.Context, result *segcorepb.RetrieveResults) error + + Release() } diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index d226807768..54c0a100b6 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -190,7 +190,7 @@ func (loader *segmentLoader) Load(ctx context.Context, newSegments := make(map[int64]*LocalSegment, len(infos)) clearAll := func() { for _, s := range newSegments { - DeleteSegment(s) + s.Release() } debug.FreeOSMemory() } diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index 9788b00016..49d643bc3e 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -114,8 +114,8 @@ func (suite *SegmentSuite) SetupTest() { func (suite *SegmentSuite) TearDownTest() { ctx := context.Background() - DeleteSegment(suite.sealed) - DeleteSegment(suite.growing) + suite.sealed.Release() + suite.growing.Release() DeleteCollection(suite.collection) suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath) } @@ -183,7 +183,7 @@ func (suite *SegmentSuite) TestValidateIndexedFieldsData() { suite.NoError(err) // index doesn't have index type - DeleteSegment(suite.sealed) + suite.sealed.Release() suite.True(suite.sealed.ExistIndex(101)) err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) suite.Error(err) @@ -192,7 +192,7 @@ func (suite *SegmentSuite) TestValidateIndexedFieldsData() { index := suite.sealed.GetIndex(101) _, indexParams := genIndexParams(IndexHNSW, metric.L2) index.IndexInfo.IndexParams = funcutil.Map2KeyValuePair(indexParams) - DeleteSegment(suite.sealed) + suite.sealed.Release() suite.True(suite.sealed.ExistIndex(101)) err = suite.sealed.ValidateIndexedFieldsData(context.Background(), result) suite.Error(err) @@ -210,7 +210,7 @@ func (suite *SegmentSuite) TestCASVersion() { } func (suite *SegmentSuite) TestSegmentReleased() { - DeleteSegment(suite.sealed) + suite.sealed.Release() suite.sealed.ptrLock.RLock() suite.False(suite.sealed.isValid())