diff --git a/internal/querynodev2/segments/manager.go b/internal/querynodev2/segments/manager.go index c732ce401f..59919ce64c 100644 --- a/internal/querynodev2/segments/manager.go +++ b/internal/querynodev2/segments/manager.go @@ -25,6 +25,7 @@ package segments import "C" import ( + "context" "fmt" "sync" @@ -64,6 +65,26 @@ func WithID(id int64) SegmentFilter { } } +type SegmentAction func(segment Segment) bool + +func IncreaseVersion(version int64) SegmentAction { + return func(segment Segment) bool { + log := log.Ctx(context.Background()).With( + zap.Int64("segmentID", segment.ID()), + zap.String("type", segment.Type().String()), + zap.Int64("segmentVersion", segment.Version()), + zap.Int64("updateVersion", version), + ) + for oldVersion := segment.Version(); oldVersion < version; { + if segment.CASVersion(oldVersion, version) { + return true + } + } + log.Warn("segment version cannot go backwards, skip update") + return false + } +} + type actionType int32 const ( @@ -95,6 +116,9 @@ type SegmentManager interface { GetAndPinBy(filters ...SegmentFilter) ([]Segment, error) GetAndPin(segments []int64, filters ...SegmentFilter) ([]Segment, error) Unpin(segments []Segment) + + UpdateSegmentBy(action SegmentAction, filters ...SegmentFilter) int + GetSealed(segmentID UniqueID) Segment GetGrowing(segmentID UniqueID) Segment Empty() bool @@ -105,8 +129,6 @@ type SegmentManager interface { Remove(segmentID UniqueID, scope querypb.DataScope) (int, int) RemoveBy(filters ...SegmentFilter) (int, int) Clear() - - UpdateSegmentVersion(segmentType SegmentType, segmentID int64, newVersion int64) } var _ SegmentManager = (*segmentManager)(nil) @@ -176,34 +198,27 @@ func (mgr *segmentManager) Put(segmentType SegmentType, segments ...Segment) { mgr.updateMetric() } -func (mgr *segmentManager) UpdateSegmentVersion(segmentType SegmentType, segmentID int64, newVersion int64) { - mgr.mu.Lock() - defer mgr.mu.Unlock() +func (mgr *segmentManager) UpdateSegmentBy(action SegmentAction, filters ...SegmentFilter) int { + mgr.mu.RLock() + defer mgr.mu.RUnlock() - segment, ok := mgr.sealedSegments[segmentID] - if !ok { - segment, ok = mgr.growingSegments[segmentID] + updated := 0 + for _, segment := range mgr.growingSegments { + if filter(segment, filters...) { + if action(segment) { + updated++ + } + } } - if !ok { - log.Warn("segment not exist, skip segment version change", - zap.Any("type", segmentType), - zap.Int64("segmentID", segmentID), - zap.Int64("newVersion", newVersion), - ) - return + for _, segment := range mgr.sealedSegments { + if filter(segment, filters...) { + if action(segment) { + updated++ + } + } } - - if segment.Version() >= newVersion { - log.Warn("Invalid segment version changed, skip it", - zap.Int64("segmentID", segment.ID()), - zap.Any("type", segmentType), - zap.Int64("oldVersion", segment.Version()), - zap.Int64("newVersion", newVersion)) - return - } - - segment.UpdateVersion(newVersion) + return updated } func (mgr *segmentManager) Get(segmentID UniqueID) Segment { diff --git a/internal/querynodev2/segments/manager_test.go b/internal/querynodev2/segments/manager_test.go index 69d5f5251d..5a8a374355 100644 --- a/internal/querynodev2/segments/manager_test.go +++ b/internal/querynodev2/segments/manager_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/suite" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -111,6 +112,39 @@ func (s *ManagerSuite) TestRemoveBy() { } } +func (s *ManagerSuite) TestUpdateBy() { + action := IncreaseVersion(1) + + s.Equal(2, s.mgr.UpdateSegmentBy(action, WithType(SegmentTypeSealed))) + s.Equal(1, s.mgr.UpdateSegmentBy(action, WithType(SegmentTypeGrowing))) + + segments := s.mgr.GetBy() + for _, segment := range segments { + s.Equal(int64(1), segment.Version()) + } +} + +func (s *ManagerSuite) TestIncreaseVersion() { + action := IncreaseVersion(1) + + segment := NewMockSegment(s.T()) + segment.EXPECT().ID().Return(100) + segment.EXPECT().Type().Return(commonpb.SegmentState_Sealed) + segment.EXPECT().Version().Return(1) + + s.False(action(segment), "version already gte version") + segment.AssertExpectations(s.T()) + + segment = NewMockSegment(s.T()) + segment.EXPECT().ID().Return(100) + segment.EXPECT().Type().Return(commonpb.SegmentState_Sealed) + segment.EXPECT().Version().Return(0) + segment.EXPECT().CASVersion(int64(0), int64(1)).Return(true) + + s.True(action(segment), "version lt execute CAS") + segment.AssertExpectations(s.T()) +} + func TestManager(t *testing.T) { suite.Run(t, new(ManagerSuite)) } diff --git a/internal/querynodev2/segments/mock_segment.go b/internal/querynodev2/segments/mock_segment.go index 3f6f9212cd..a06d54c9f5 100644 --- a/internal/querynodev2/segments/mock_segment.go +++ b/internal/querynodev2/segments/mock_segment.go @@ -60,6 +60,49 @@ func (_c *MockSegment_AddIndex_Call) RunAndReturn(run func(int64, *IndexedFieldI return _c } +// CASVersion provides a mock function with given fields: _a0, _a1 +func (_m *MockSegment) CASVersion(_a0 int64, _a1 int64) bool { + ret := _m.Called(_a0, _a1) + + var r0 bool + if rf, ok := ret.Get(0).(func(int64, int64) bool); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockSegment_CASVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CASVersion' +type MockSegment_CASVersion_Call struct { + *mock.Call +} + +// CASVersion is a helper method to define mock.On call +// - _a0 int64 +// - _a1 int64 +func (_e *MockSegment_Expecter) CASVersion(_a0 interface{}, _a1 interface{}) *MockSegment_CASVersion_Call { + return &MockSegment_CASVersion_Call{Call: _e.mock.On("CASVersion", _a0, _a1)} +} + +func (_c *MockSegment_CASVersion_Call) Run(run func(_a0 int64, _a1 int64)) *MockSegment_CASVersion_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64), args[1].(int64)) + }) + return _c +} + +func (_c *MockSegment_CASVersion_Call) Return(_a0 bool) *MockSegment_CASVersion_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSegment_CASVersion_Call) RunAndReturn(run func(int64, int64) bool) *MockSegment_CASVersion_Call { + _c.Call.Return(run) + return _c +} + // Collection provides a mock function with given fields: func (_m *MockSegment) Collection() int64 { ret := _m.Called() @@ -836,39 +879,6 @@ func (_c *MockSegment_UpdateBloomFilter_Call) RunAndReturn(run func([]storage.Pr return _c } -// UpdateVersion provides a mock function with given fields: version -func (_m *MockSegment) UpdateVersion(version int64) { - _m.Called(version) -} - -// MockSegment_UpdateVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateVersion' -type MockSegment_UpdateVersion_Call struct { - *mock.Call -} - -// UpdateVersion is a helper method to define mock.On call -// - version int64 -func (_e *MockSegment_Expecter) UpdateVersion(version interface{}) *MockSegment_UpdateVersion_Call { - return &MockSegment_UpdateVersion_Call{Call: _e.mock.On("UpdateVersion", version)} -} - -func (_c *MockSegment_UpdateVersion_Call) Run(run func(version int64)) *MockSegment_UpdateVersion_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(int64)) - }) - return _c -} - -func (_c *MockSegment_UpdateVersion_Call) Return() *MockSegment_UpdateVersion_Call { - _c.Call.Return() - return _c -} - -func (_c *MockSegment_UpdateVersion_Call) RunAndReturn(run func(int64)) *MockSegment_UpdateVersion_Call { - _c.Call.Return(run) - return _c -} - // Version provides a mock function with given fields: func (_m *MockSegment) Version() int64 { ret := _m.Called() diff --git a/internal/querynodev2/segments/mock_segment_manager.go b/internal/querynodev2/segments/mock_segment_manager.go index bc6a37be3b..7c2413712f 100644 --- a/internal/querynodev2/segments/mock_segment_manager.go +++ b/internal/querynodev2/segments/mock_segment_manager.go @@ -664,37 +664,59 @@ func (_c *MockSegmentManager_Unpin_Call) RunAndReturn(run func([]Segment)) *Mock return _c } -// UpdateSegmentVersion provides a mock function with given fields: segmentType, segmentID, newVersion -func (_m *MockSegmentManager) UpdateSegmentVersion(segmentType commonpb.SegmentState, segmentID int64, newVersion int64) { - _m.Called(segmentType, segmentID, newVersion) +// UpdateSegmentBy provides a mock function with given fields: action, filters +func (_m *MockSegmentManager) UpdateSegmentBy(action SegmentAction, filters ...SegmentFilter) int { + _va := make([]interface{}, len(filters)) + for _i := range filters { + _va[_i] = filters[_i] + } + var _ca []interface{} + _ca = append(_ca, action) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 int + if rf, ok := ret.Get(0).(func(SegmentAction, ...SegmentFilter) int); ok { + r0 = rf(action, filters...) + } else { + r0 = ret.Get(0).(int) + } + + return r0 } -// MockSegmentManager_UpdateSegmentVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateSegmentVersion' -type MockSegmentManager_UpdateSegmentVersion_Call struct { +// MockSegmentManager_UpdateSegmentBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateSegmentBy' +type MockSegmentManager_UpdateSegmentBy_Call struct { *mock.Call } -// UpdateSegmentVersion is a helper method to define mock.On call -// - segmentType commonpb.SegmentState -// - segmentID int64 -// - newVersion int64 -func (_e *MockSegmentManager_Expecter) UpdateSegmentVersion(segmentType interface{}, segmentID interface{}, newVersion interface{}) *MockSegmentManager_UpdateSegmentVersion_Call { - return &MockSegmentManager_UpdateSegmentVersion_Call{Call: _e.mock.On("UpdateSegmentVersion", segmentType, segmentID, newVersion)} +// UpdateSegmentBy is a helper method to define mock.On call +// - action SegmentAction +// - filters ...SegmentFilter +func (_e *MockSegmentManager_Expecter) UpdateSegmentBy(action interface{}, filters ...interface{}) *MockSegmentManager_UpdateSegmentBy_Call { + return &MockSegmentManager_UpdateSegmentBy_Call{Call: _e.mock.On("UpdateSegmentBy", + append([]interface{}{action}, filters...)...)} } -func (_c *MockSegmentManager_UpdateSegmentVersion_Call) Run(run func(segmentType commonpb.SegmentState, segmentID int64, newVersion int64)) *MockSegmentManager_UpdateSegmentVersion_Call { +func (_c *MockSegmentManager_UpdateSegmentBy_Call) Run(run func(action SegmentAction, filters ...SegmentFilter)) *MockSegmentManager_UpdateSegmentBy_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(commonpb.SegmentState), args[1].(int64), args[2].(int64)) + variadicArgs := make([]SegmentFilter, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(SegmentFilter) + } + } + run(args[0].(SegmentAction), variadicArgs...) }) return _c } -func (_c *MockSegmentManager_UpdateSegmentVersion_Call) Return() *MockSegmentManager_UpdateSegmentVersion_Call { - _c.Call.Return() +func (_c *MockSegmentManager_UpdateSegmentBy_Call) Return(_a0 int) *MockSegmentManager_UpdateSegmentBy_Call { + _c.Call.Return(_a0) return _c } -func (_c *MockSegmentManager_UpdateSegmentVersion_Call) RunAndReturn(run func(commonpb.SegmentState, int64, int64)) *MockSegmentManager_UpdateSegmentVersion_Call { +func (_c *MockSegmentManager_UpdateSegmentBy_Call) RunAndReturn(run func(SegmentAction, ...SegmentFilter) int) *MockSegmentManager_UpdateSegmentBy_Call { _c.Call.Return(run) return _c } diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 927931db71..5c141f1faf 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -126,8 +126,8 @@ func (s *baseSegment) Version() int64 { return s.version.Load() } -func (s *baseSegment) UpdateVersion(version int64) { - s.version.Store(version) +func (s *baseSegment) CASVersion(old, newVersion int64) bool { + return s.version.CompareAndSwap(old, newVersion) } func (s *baseSegment) UpdateBloomFilter(pks []storage.PrimaryKey) { diff --git a/internal/querynodev2/segments/segment_interface.go b/internal/querynodev2/segments/segment_interface.go index d33e1e8d40..1be88aadde 100644 --- a/internal/querynodev2/segments/segment_interface.go +++ b/internal/querynodev2/segments/segment_interface.go @@ -30,6 +30,7 @@ type Segment interface { Partition() int64 Shard() string Version() int64 + CASVersion(int64, int64) bool StartPosition() *msgpb.MsgPosition Type() SegmentType RLock() error @@ -56,5 +57,4 @@ type Segment interface { // Bloom filter related UpdateBloomFilter(pks []storage.PrimaryKey) MayPkExist(pk storage.PrimaryKey) bool - UpdateVersion(version int64) } diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index 1fbd2454da..a674782e5e 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -260,7 +260,8 @@ func (loader *segmentLoader) prepare(segmentType SegmentType, version int64, seg loader.loadingSegments.Insert(segment.GetSegmentID(), make(chan struct{})) } else { // try to update segment version before skip load operation - loader.manager.Segment.UpdateSegmentVersion(segmentType, segment.SegmentID, version) + loader.manager.Segment.UpdateSegmentBy(IncreaseVersion(version), + WithType(segmentType), WithID(segment.SegmentID)) log.Info("skip loaded/loading segment", zap.Int64("segmentID", segment.GetSegmentID()), zap.Bool("isLoaded", len(loader.manager.Segment.GetBy(WithType(segmentType), WithID(segment.GetSegmentID()))) > 0), zap.Bool("isLoading", loader.loadingSegments.Contain(segment.GetSegmentID())), diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index 6d6099a46b..9788b00016 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -198,6 +198,17 @@ func (suite *SegmentSuite) TestValidateIndexedFieldsData() { suite.Error(err) } +func (suite *SegmentSuite) TestCASVersion() { + segment := suite.sealed + + curVersion := segment.Version() + suite.False(segment.CASVersion(curVersion-1, curVersion+1)) + suite.NotEqual(curVersion+1, segment.Version()) + + suite.True(segment.CASVersion(curVersion, curVersion+1)) + suite.Equal(curVersion+1, segment.Version()) +} + func (suite *SegmentSuite) TestSegmentReleased() { DeleteSegment(suite.sealed)