From 73adf2a5cc3c81f41a9f0cd7f7af1f70cac6cfc2 Mon Sep 17 00:00:00 2001 From: chyezh Date: Mon, 8 Apr 2024 17:09:16 +0800 Subject: [PATCH] fix: use stateful lock to avoid load and release on LocalSegment concurrently (#31606) issue: #31605 --------- Signed-off-by: chyezh --- internal/querynodev2/segments/manager.go | 12 +- internal/querynodev2/segments/mock_segment.go | 8 +- internal/querynodev2/segments/segment.go | 164 ++++++------- .../querynodev2/segments/segment_interface.go | 6 +- internal/querynodev2/segments/segment_l0.go | 4 +- .../querynodev2/segments/segment_loader.go | 36 ++- internal/querynodev2/segments/segment_test.go | 7 +- .../segments/state/load_state_lock.go | 205 ++++++++++++++++ .../segments/state/load_state_lock_guard.go | 45 ++++ .../segments/state/load_state_lock_test.go | 224 ++++++++++++++++++ 10 files changed, 602 insertions(+), 109 deletions(-) create mode 100644 internal/querynodev2/segments/state/load_state_lock.go create mode 100644 internal/querynodev2/segments/state/load_state_lock_guard.go create mode 100644 internal/querynodev2/segments/state/load_state_lock_test.go diff --git a/internal/querynodev2/segments/manager.go b/internal/querynodev2/segments/manager.go index a7e626c066..9ebbab1296 100644 --- a/internal/querynodev2/segments/manager.go +++ b/internal/querynodev2/segments/manager.go @@ -370,7 +370,7 @@ func (mgr *segmentManager) GetAndPinBy(filters ...SegmentFilter) ([]Segment, err defer func() { if err != nil { for _, segment := range ret { - segment.RUnlock() + segment.Unpin() } } }() @@ -379,7 +379,7 @@ func (mgr *segmentManager) GetAndPinBy(filters ...SegmentFilter) ([]Segment, err if segment.Level() == datapb.SegmentLevel_L0 { return true } - err = segment.RLock() + err = segment.PinIfNotReleased() if err != nil { return false } @@ -399,7 +399,7 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter) defer func() { if err != nil { for _, segment := range lockedSegments { - segment.RUnlock() + 
segment.Unpin() } } }() @@ -417,14 +417,14 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter) sealedExist = sealedExist && filter(sealed, filters...) if growingExist { - err = growing.RLock() + err = growing.PinIfNotReleased() if err != nil { return nil, err } lockedSegments = append(lockedSegments, growing) } if sealedExist { - err = sealed.RLock() + err = sealed.PinIfNotReleased() if err != nil { return nil, err } @@ -442,7 +442,7 @@ func (mgr *segmentManager) GetAndPin(segments []int64, filters ...SegmentFilter) func (mgr *segmentManager) Unpin(segments []Segment) { for _, segment := range segments { - segment.RUnlock() + segment.Unpin() } } diff --git a/internal/querynodev2/segments/mock_segment.go b/internal/querynodev2/segments/mock_segment.go index a433973321..878ccca5ed 100644 --- a/internal/querynodev2/segments/mock_segment.go +++ b/internal/querynodev2/segments/mock_segment.go @@ -874,8 +874,8 @@ func (_c *MockSegment_Partition_Call) RunAndReturn(run func() int64) *MockSegmen return _c } -// RLock provides a mock function with given fields: -func (_m *MockSegment) RLock() error { +// PinIfNotReleased provides a mock function with given fields: +func (_m *MockSegment) PinIfNotReleased() error { ret := _m.Called() var r0 error @@ -915,8 +915,8 @@ func (_c *MockSegment_RLock_Call) RunAndReturn(run func() error) *MockSegment_RL return _c } -// RUnlock provides a mock function with given fields: -func (_m *MockSegment) RUnlock() { +// Unpin provides a mock function with given fields: +func (_m *MockSegment) Unpin() { _m.Called() } diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index aabbfad421..6d5cd008ea 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -31,7 +31,6 @@ import ( "io" "strconv" "strings" - "sync" "unsafe" "github.com/apache/arrow/go/v12/arrow/array" @@ -50,6 +49,7 @@ import ( 
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" + "github.com/milvus-io/milvus/internal/querynodev2/segments/state" "github.com/milvus-io/milvus/internal/storage" typeutil_internal "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" @@ -213,7 +213,7 @@ var _ Segment = (*LocalSegment)(nil) // Segment is a wrapper of the underlying C-structure segment. type LocalSegment struct { baseSegment - ptrLock sync.RWMutex // protects segmentPtr + ptrLock *state.LoadStateLock ptr C.CSegmentInterface // cached results, to avoid too many CGO calls @@ -242,10 +242,13 @@ func NewSegment(ctx context.Context, return NewL0Segment(collection, segmentType, version, loadInfo) } var cSegType C.SegmentType + var locker *state.LoadStateLock switch segmentType { case SegmentTypeSealed: cSegType = C.Sealed + locker = state.NewLoadStateLock(state.LoadStateOnlyMeta) case SegmentTypeGrowing: + locker = state.NewLoadStateLock(state.LoadStateDataLoaded) cSegType = C.Growing default: return nil, fmt.Errorf("illegal segment type %d when create segment %d", segmentType, loadInfo.GetSegmentID()) @@ -275,6 +278,7 @@ func NewSegment(ctx context.Context, segment := &LocalSegment{ baseSegment: newBaseSegment(collection, segmentType, version, loadInfo), + ptrLock: locker, ptr: newPtr, lastDeltaTimestamp: atomic.NewUint64(0), fields: typeutil.NewConcurrentMap[int64, *FieldInfo](), @@ -308,11 +312,14 @@ func NewSegmentV2( } var segmentPtr C.CSegmentInterface var status C.CStatus + var locker *state.LoadStateLock switch segmentType { case SegmentTypeSealed: status = C.NewSegment(collection.collectionPtr, C.Sealed, C.int64_t(loadInfo.GetSegmentID()), &segmentPtr) + locker = state.NewLoadStateLock(state.LoadStateOnlyMeta) case SegmentTypeGrowing: status = C.NewSegment(collection.collectionPtr, C.Growing, C.int64_t(loadInfo.GetSegmentID()), 
&segmentPtr) + locker = state.NewLoadStateLock(state.LoadStateDataLoaded) default: return nil, fmt.Errorf("illegal segment type %d when create segment %d", segmentType, loadInfo.GetSegmentID()) } @@ -338,6 +345,7 @@ func NewSegmentV2( segment := &LocalSegment{ baseSegment: newBaseSegment(collection, segmentType, version, loadInfo), + ptrLock: locker, ptr: segmentPtr, lastDeltaTimestamp: atomic.NewUint64(0), fields: typeutil.NewConcurrentMap[int64, *FieldInfo](), @@ -355,41 +363,31 @@ func NewSegmentV2( return segment, nil } -func (s *LocalSegment) isValid() bool { - return s.ptr != nil -} - -// RLock acquires the `ptrLock` and returns true if the pointer is valid +// PinIfNotReleased acquires the `ptrLock` and returns true if the pointer is valid // Provide ONLY the read lock operations, // don't make `ptrLock` public to avoid abusing of the mutex. -func (s *LocalSegment) RLock() error { - s.ptrLock.RLock() - if !s.isValid() { - s.ptrLock.RUnlock() +func (s *LocalSegment) PinIfNotReleased() error { + if !s.ptrLock.PinIfNotReleased() { return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } return nil } -func (s *LocalSegment) RUnlock() { - s.ptrLock.RUnlock() +func (s *LocalSegment) Unpin() { + s.ptrLock.Unpin() } func (s *LocalSegment) InsertCount() int64 { - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - return s.insertCount.Load() } func (s *LocalSegment) RowNum() int64 { - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if !s.isValid() { + // if segment is not loaded, return 0 (maybe not loaded or release by lru) + if !s.ptrLock.RLockIf(state.IsDataLoaded) { log.Warn("segment is not valid", zap.Int64("segmentID", s.ID())) return 0 } + defer s.ptrLock.RUnlock() rowNum := s.rowNum.Load() if rowNum < 0 { @@ -406,12 +404,10 @@ func (s *LocalSegment) RowNum() int64 { } func (s *LocalSegment) MemSize() int64 { - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if !s.isValid() { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return 0 } + defer 
s.ptrLock.RUnlock() memSize := s.memSize.Load() if memSize < 0 { @@ -449,11 +445,11 @@ func (s *LocalSegment) ExistIndex(fieldID int64) bool { } func (s *LocalSegment) HasRawData(fieldID int64) bool { - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - if !s.isValid() { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return false } + defer s.ptrLock.RUnlock() + ret := C.HasRawData(s.ptr, C.int64_t(fieldID)) return bool(ret) } @@ -482,12 +478,11 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S zap.Int64("segmentID", s.ID()), zap.String("segmentType", s.segmentType.String()), ) - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { + // TODO: check if the segment is readable but not released. too many related logic need to be refactor. return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } + defer s.ptrLock.RUnlock() traceCtx := ParseCTraceContext(ctx) @@ -520,12 +515,11 @@ func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*S } func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { + // TODO: check if the segment is readable but not released. too many related logic need to be refactor. 
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } + defer s.ptrLock.RUnlock() log := log.Ctx(ctx).With( zap.Int64("collectionID", s.Collection()), @@ -616,13 +610,10 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps [] if s.Type() != SegmentTypeGrowing { return fmt.Errorf("unexpected segmentType when segmentInsert, segmentType = %s", s.segmentType.String()) } - - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } + defer s.ptrLock.RUnlock() offset, err := s.preInsert(ctx, len(rowIDs)) if err != nil { @@ -676,13 +667,10 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys []storage.Primary if len(primaryKeys) == 0 { return nil } - - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } + defer s.ptrLock.RUnlock() cOffset := C.int64_t(0) // depre cSize := C.int64_t(len(primaryKeys)) @@ -743,12 +731,10 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys []storage.Primary // -------------------------------------------------------------------------------------- interfaces for sealed segment func (s *LocalSegment) LoadMultiFieldData(ctx context.Context, rowCount int64, fields []*datapb.FieldBinlog) error { - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } + defer s.ptrLock.RUnlock() log := log.Ctx(ctx).With( zap.Int64("collectionID", s.Collection()), @@ -847,16 +833,14 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun s.loadStatus.Store(string(options.LoadStatus)) - s.ptrLock.RLock() + if !s.ptrLock.RLockIf(state.IsNotReleased) { + return merr.WrapErrSegmentNotLoaded(s.ID(), 
"segment released") + } defer s.ptrLock.RUnlock() ctx, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, fmt.Sprintf("LoadFieldData-%d-%d", s.ID(), fieldID)) defer sp.End() - if s.ptr == nil { - return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") - } - log := log.Ctx(ctx).With( zap.Int64("collectionID", s.Collection()), zap.Int64("partitionID", s.Partition()), @@ -1011,12 +995,10 @@ func (s *LocalSegment) LoadDeltaData2(ctx context.Context, schema *schemapb.Coll } func (s *LocalSegment) AddFieldDataInfo(ctx context.Context, rowCount int64, fields []*datapb.FieldBinlog) error { - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } + defer s.ptrLock.RUnlock() log := log.Ctx(ctx).With( zap.Int64("collectionID", s.Collection()), @@ -1066,12 +1048,10 @@ func (s *LocalSegment) LoadDeltaData(ctx context.Context, deltaData *storage.Del pks, tss := deltaData.Pks, deltaData.Tss rowNum := deltaData.RowCount - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } + defer s.ptrLock.RUnlock() log := log.Ctx(ctx).With( zap.Int64("collectionID", s.Collection()), @@ -1224,12 +1204,10 @@ func (s *LocalSegment) UpdateIndexInfo(ctx context.Context, indexInfo *querypb.F zap.Int64("segmentID", s.ID()), zap.Int64("fieldID", indexInfo.FieldID), ) - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } + defer s.ptrLock.RUnlock() var status C.CStatus GetDynamicPool().Submit(func() (any, error) { @@ -1263,12 +1241,10 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64) { zap.Int64("segmentID", s.ID()), zap.Int64("fieldID", fieldID), ) - s.ptrLock.RLock() - defer 
s.ptrLock.RUnlock() - - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return } + defer s.ptrLock.RUnlock() var status C.CStatus @@ -1287,11 +1263,11 @@ func (s *LocalSegment) WarmupChunkCache(ctx context.Context, fieldID int64) { }).Await() case "async": GetLoadPool().Submit(func() (any, error) { - s.ptrLock.RLock() - defer s.ptrLock.RUnlock() - if s.ptr == nil { + if !s.ptrLock.RLockIf(state.IsNotReleased) { return nil, nil } + defer s.ptrLock.RUnlock() + cFieldID := C.int64_t(fieldID) status = C.WarmupChunkCache(s.ptr, cFieldID) if err := HandleCStatus(ctx, &status, ""); err != nil { @@ -1357,26 +1333,17 @@ func (s *LocalSegment) Release(opts ...releaseOption) { for _, opt := range opts { opt(options) } - - /* - void - deleteSegment(CSegmentInterface segment); - */ - var ptr C.CSegmentInterface - - // wait all read ops finished - s.ptrLock.Lock() - ptr = s.ptr - s.ptr = nil - if options.Scope == ReleaseScopeData { - s.loadStatus.Store(string(LoadStatusMeta)) - } - s.ptrLock.Unlock() - - if ptr == nil { + stateLockGuard := s.startRelease(options.Scope) + if stateLockGuard == nil { // release is already done. return } + // release will never fail + defer stateLockGuard.Done(nil) + + // wait all read ops finished + ptr := s.ptr if options.Scope == ReleaseScopeData { + s.loadStatus.Store(string(LoadStatusMeta)) C.ClearSegmentData(ptr) return } @@ -1406,3 +1373,20 @@ func (s *LocalSegment) Release(opts ...releaseOption) { zap.Int64("insertCount", s.InsertCount()), ) } + +// StartLoadData starts the loading process of the segment. +func (s *LocalSegment) StartLoadData() (state.LoadStateLockGuard, error) { + return s.ptrLock.StartLoadData() +} + +// startRelease starts the releasing process of the segment. 
+func (s *LocalSegment) startRelease(scope ReleaseScope) state.LoadStateLockGuard { + switch scope { + case ReleaseScopeData: + return s.ptrLock.StartReleaseData() + case ReleaseScopeAll: + return s.ptrLock.StartReleaseAll() + default: + panic(fmt.Sprintf("unexpected release scope %d", scope)) + } +} diff --git a/internal/querynodev2/segments/segment_interface.go b/internal/querynodev2/segments/segment_interface.go index f8a340dce0..e36612a2a1 100644 --- a/internal/querynodev2/segments/segment_interface.go +++ b/internal/querynodev2/segments/segment_interface.go @@ -62,8 +62,10 @@ type Segment interface { Level() datapb.SegmentLevel LoadStatus() LoadStatus LoadInfo() *querypb.SegmentLoadInfo - RLock() error - RUnlock() + // PinIfNotReleased the segment to prevent it from being released + PinIfNotReleased() error + // Unpin the segment to allow it to be released + Unpin() // Stats related // InsertCount returns the number of inserted rows, not effected by deletion diff --git a/internal/querynodev2/segments/segment_l0.go b/internal/querynodev2/segments/segment_l0.go index 6740ce457e..582c98a08d 100644 --- a/internal/querynodev2/segments/segment_l0.go +++ b/internal/querynodev2/segments/segment_l0.go @@ -68,11 +68,11 @@ func NewL0Segment(collection *Collection, return segment, nil } -func (s *L0Segment) RLock() error { +func (s *L0Segment) PinIfNotReleased() error { return nil } -func (s *L0Segment) RUnlock() {} +func (s *L0Segment) Unpin() {} func (s *L0Segment) InsertCount() int64 { return 0 diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index 928ca70d13..f95035605e 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -375,7 +375,23 @@ func (loader *segmentLoaderV2) LoadSegment(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo, loadstatus LoadStatus, -) error { +) (err error) { + // TODO: we should create a 
transaction-like api to load segment for segment interface, + // but not do many things in segment loader. + stateLockGuard, err := segment.StartLoadData() + // segment can not do load now. + if err != nil { + return err + } + defer func() { + // segment is already loaded. + // TODO: if stateLockGuard is nil, we should not call LoadSegment anymore. + // but current Load is not clear enough to do an actual state transition, keep previous logic to avoid introduced bug. + if stateLockGuard != nil { + stateLockGuard.Done(err) + } + }() + log := log.Ctx(ctx).With( zap.Int64("collectionID", segment.Collection()), zap.Int64("partitionID", segment.Partition()), @@ -1008,7 +1024,23 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, segment *LocalSegment, loadInfo *querypb.SegmentLoadInfo, loadStatus LoadStatus, -) error { +) (err error) { + // TODO: we should create a transaction-like api to load segment for segment interface, + // but not do many things in segment loader. + stateLockGuard, err := segment.StartLoadData() + // segment can not do load now. + if err != nil { + return err + } + defer func() { + // segment is already loaded. + // TODO: if stateLockGuard is nil, we should not call LoadSegment anymore. + // but current Load is not clear enough to do an actual state transition, keep previous logic to avoid introduced bug. 
+ if stateLockGuard != nil { + stateLockGuard.Done(err) + } + }() + log := log.Ctx(ctx).With( zap.Int64("collectionID", segment.Collection()), zap.Int64("partitionID", segment.Partition()), diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index 31271388af..03a0a4e7b3 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -94,10 +94,13 @@ func (suite *SegmentSuite) SetupTest() { suite.chunkManager, ) suite.Require().NoError(err) + g, err := suite.sealed.(*LocalSegment).StartLoadData() + suite.Require().NoError(err) for _, binlog := range binlogs { err = suite.sealed.(*LocalSegment).LoadFieldData(ctx, binlog.FieldID, int64(msgLength), binlog) suite.Require().NoError(err) } + g.Done(nil) suite.growing, err = NewSegment(ctx, suite.collection, @@ -198,9 +201,7 @@ func (suite *SegmentSuite) TestSegmentReleased() { sealed := suite.sealed.(*LocalSegment) - sealed.ptrLock.RLock() - suite.False(sealed.isValid()) - sealed.ptrLock.RUnlock() + suite.False(sealed.ptrLock.PinIfNotReleased()) suite.EqualValues(0, sealed.RowNum()) suite.EqualValues(0, sealed.MemSize()) suite.False(sealed.HasRawData(101)) diff --git a/internal/querynodev2/segments/state/load_state_lock.go b/internal/querynodev2/segments/state/load_state_lock.go new file mode 100644 index 0000000000..3cafcd42a6 --- /dev/null +++ b/internal/querynodev2/segments/state/load_state_lock.go @@ -0,0 +1,205 @@ +package state + +import ( + "fmt" + "sync" + + "github.com/cockroachdb/errors" + "go.uber.org/atomic" +) + +type loadStateEnum int + +// LoadState represent the state transition of segment. +// LoadStateOnlyMeta: segment is created with meta, but not loaded. +// LoadStateDataLoading: segment is loading data. +// LoadStateDataLoaded: segment is full loaded, ready to be searched or queried. +// LoadStateDataReleasing: segment is releasing data. +// LoadStateReleased: segment is released. 
+// LoadStateOnlyMeta -> LoadStateDataLoading -> LoadStateDataLoaded -> LoadStateDataReleasing -> (LoadStateReleased or LoadStateOnlyMeta) +const ( + LoadStateOnlyMeta loadStateEnum = iota + LoadStateDataLoading // Intended for a single loading goroutine. NOTE(review): RLockIf(IsNotReleased) still succeeds in this state. + LoadStateDataLoaded + LoadStateDataReleasing // Intended for a single releasing goroutine. NOTE(review): RLockIf(IsNotReleased) still succeeds in this state. + LoadStateReleased +) + +// String returns the human-readable name of the load state, or "unknown" for an unrecognized value. +func (ls loadStateEnum) String() string { + switch ls { + case LoadStateOnlyMeta: + return "meta" + case LoadStateDataLoading: + return "loading-data" + case LoadStateDataLoaded: + return "loaded" + case LoadStateDataReleasing: + return "releasing-data" + case LoadStateReleased: + return "released" + default: + return "unknown" + } +} + +// NewLoadStateLock creates a LoadStateLock starting in state; it panics unless state is LoadStateOnlyMeta or LoadStateDataLoaded. +func NewLoadStateLock(state loadStateEnum) *LoadStateLock { + if state != LoadStateOnlyMeta && state != LoadStateDataLoaded { + panic(fmt.Sprintf("invalid state for construction of LoadStateLock, %s", state.String())) + } + + mu := &sync.RWMutex{} + return &LoadStateLock{ + mu: mu, + cv: sync.Cond{L: mu}, + state: state, + refCnt: atomic.NewInt32(0), + } +} + +// LoadStateLock tracks a segment's load state behind an RWMutex, with a condition variable on that mutex and a pin reference count. +type LoadStateLock struct { + mu *sync.RWMutex + cv sync.Cond + state loadStateEnum + refCnt *atomic.Int32 + // StartReleaseAll blocks until refCnt is 0. + // refCnt is changed only while the read lock is held (see PinIfNotReleased and Unpin). +} + +// RLockIf takes the read lock and keeps it only if pred accepts the current state; when it returns true the caller must call RUnlock. +func (ls *LoadStateLock) RLockIf(pred StatePredicate) bool { + ls.mu.RLock() + if !pred(ls.state) { + ls.mu.RUnlock() + return false + } + return true +} + +// RUnlock releases the read lock taken by a successful RLockIf. +func (ls *LoadStateLock) RUnlock() { + ls.mu.RUnlock() +} + +// PinIfNotReleased increments the pin count unless the state is LoadStateReleased, and reports whether the pin was taken. 
+func (ls *LoadStateLock) PinIfNotReleased() bool { + ls.mu.RLock() + defer ls.mu.RUnlock() + if ls.state == LoadStateReleased { + return false + } + ls.refCnt.Inc() + return true +} + +// Unpin decrements the pin count and wakes waiters when it reaches zero; it panics if the count would go negative. +func (ls *LoadStateLock) Unpin() { + ls.mu.RLock() + defer ls.mu.RUnlock() + newCnt := ls.refCnt.Dec() + if newCnt < 0 { + panic("unpin more than pin") + } + if newCnt == 0 { + // wake goroutines blocked on cv (e.g. in StartReleaseAll) now that refCnt is zero. + ls.cv.Broadcast() + } +} + +// StartLoadData moves LoadStateOnlyMeta to LoadStateDataLoading and returns a guard; it returns (nil, nil) if the data is already loaded. +// It fails fast with an error in any other state instead of waiting. +func (ls *LoadStateLock) StartLoadData() (LoadStateLockGuard, error) { + // only meta can be loaded. + ls.cv.L.Lock() + defer ls.cv.L.Unlock() + + if ls.state == LoadStateDataLoaded { + return nil, nil + } + if ls.state != LoadStateOnlyMeta { + return nil, errors.New("segment is not in LoadStateOnlyMeta, cannot start to loading data") + } + ls.state = LoadStateDataLoading + ls.cv.Broadcast() + + return newLoadStateLockGuard(ls, LoadStateOnlyMeta, LoadStateDataLoaded), nil +} + +// StartReleaseData waits until no load or release is in progress, then starts releasing segment data; it returns nil when there is nothing to release. +func (ls *LoadStateLock) StartReleaseData() (g LoadStateLockGuard) { + ls.cv.L.Lock() + defer ls.cv.L.Unlock() + + ls.waitUntilCanReleaseData() + + switch ls.state { + case LoadStateDataLoaded: + ls.state = LoadStateDataReleasing + ls.cv.Broadcast() + return newLoadStateLockGuard(ls, LoadStateDataLoaded, LoadStateOnlyMeta) + case LoadStateOnlyMeta: + // already in the target state, do nothing. + return nil + case LoadStateReleased: + // segment is already fully released, no data left to release. + return nil + default: + panic(fmt.Sprintf("unreachable code: invalid state when releasing data, %s", ls.state.String())) + } +} + +// StartReleaseAll waits until no load or release is in progress and the pin count is zero, then marks the segment released; it returns nil if it was already released. 
+func (ls *LoadStateLock) StartReleaseAll() (g LoadStateLockGuard) { + ls.cv.L.Lock() + defer ls.cv.L.Unlock() + + ls.waitUntilCanReleaseAll() + + switch ls.state { + case LoadStateDataLoaded: + ls.state = LoadStateReleased + ls.cv.Broadcast() + return newNopLoadStateLockGuard() + case LoadStateOnlyMeta: + ls.state = LoadStateReleased + ls.cv.Broadcast() + return newNopLoadStateLockGuard() + case LoadStateReleased: + // already in the target state, do nothing. + return nil + default: + panic(fmt.Sprintf("unreachable code: invalid state when releasing data, %s", ls.state.String())) + } +} + +// waitUntilCanReleaseData blocks on cv until the state is neither loading nor releasing; the caller must hold cv.L. +func (ls *LoadStateLock) waitUntilCanReleaseData() { + state := ls.state + for state != LoadStateDataLoaded && state != LoadStateOnlyMeta && state != LoadStateReleased { + ls.cv.Wait() + state = ls.state + } +} + +// waitUntilCanReleaseAll blocks on cv until the state is neither loading nor releasing and refCnt is zero; the caller must hold cv.L. +func (ls *LoadStateLock) waitUntilCanReleaseAll() { + state := ls.state + for (state != LoadStateDataLoaded && state != LoadStateOnlyMeta && state != LoadStateReleased) || ls.refCnt.Load() != 0 { + ls.cv.Wait() + state = ls.state + } +} + +type StatePredicate func(state loadStateEnum) bool + +// IsNotReleased reports whether state is anything other than LoadStateReleased. +func IsNotReleased(state loadStateEnum) bool { + return state != LoadStateReleased +} + +// IsDataLoaded reports whether state is exactly LoadStateDataLoaded. +func IsDataLoaded(state loadStateEnum) bool { + return state == LoadStateDataLoaded +} diff --git a/internal/querynodev2/segments/state/load_state_lock_guard.go b/internal/querynodev2/segments/state/load_state_lock_guard.go new file mode 100644 index 0000000000..ffebe9d499 --- /dev/null +++ b/internal/querynodev2/segments/state/load_state_lock_guard.go @@ -0,0 +1,45 @@ +package state + +type LoadStateLockGuard interface { + Done(err error) +} + +// newLoadStateLockGuard creates a guard whose Done moves ls to target on success or back to original on failure. 
+func newLoadStateLockGuard(ls *LoadStateLock, original loadStateEnum, target loadStateEnum) *loadStateLockGuard { + return &loadStateLockGuard{ + ls: ls, + original: original, + target: target, + } +} + +// loadStateLockGuard finishes a state transition of a LoadStateLock when Done is called. +type loadStateLockGuard struct { + ls *LoadStateLock + original loadStateEnum + target loadStateEnum +} + +// Done finishes the transition: it sets the state to target when err is nil, or back to original otherwise, and wakes all waiters. +func (g *loadStateLockGuard) Done(err error) { + g.ls.cv.L.Lock() + g.ls.cv.Broadcast() + defer g.ls.cv.L.Unlock() + + if err != nil { + g.ls.state = g.original + return + } + g.ls.state = g.target +} + +// newNopLoadStateLockGuard creates a LoadStateLockGuard that does nothing. +func newNopLoadStateLockGuard() LoadStateLockGuard { + return nopLockGuard{} +} + +// nopLockGuard is a guard that does nothing. +type nopLockGuard struct{} + +// Done does nothing. +func (nopLockGuard) Done(err error) {} diff --git a/internal/querynodev2/segments/state/load_state_lock_test.go b/internal/querynodev2/segments/state/load_state_lock_test.go new file mode 100644 index 0000000000..f59b618540 --- /dev/null +++ b/internal/querynodev2/segments/state/load_state_lock_test.go @@ -0,0 +1,224 @@ +package state + +import ( + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" +) + +func TestLoadStateLoadData(t *testing.T) { + l := NewLoadStateLock(LoadStateOnlyMeta) + // Test Load Data, roll back + g, err := l.StartLoadData() + assert.NoError(t, err) + assert.NotNil(t, g) + assert.Equal(t, LoadStateDataLoading, l.state) + g.Done(errors.New("test")) + assert.Equal(t, LoadStateOnlyMeta, l.state) + + // Test Load Data, success + g, err = l.StartLoadData() + assert.NoError(t, err) + assert.NotNil(t, g) + assert.Equal(t, LoadStateDataLoading, l.state) + g.Done(nil) + assert.Equal(t, LoadStateDataLoaded, l.state) + + // already loaded: StartLoadData returns a nil guard and no error. 
+ g, err = l.StartLoadData() + assert.NoError(t, err) + assert.Nil(t, g) + + for _, s := range []loadStateEnum{ + LoadStateDataLoading, + LoadStateDataReleasing, + LoadStateReleased, + } { + l.state = s + g, err = l.StartLoadData() + assert.Error(t, err) + assert.Nil(t, g) + } +} + +func TestStartReleaseData(t *testing.T) { + l := NewLoadStateLock(LoadStateOnlyMeta) + // Test Release Data, nothing to do on only meta. + g := l.StartReleaseData() + assert.Nil(t, g) + assert.Equal(t, LoadStateOnlyMeta, l.state) + + // roll back + // never roll back on current using. + l.state = LoadStateDataLoaded + g = l.StartReleaseData() + assert.Equal(t, LoadStateDataReleasing, l.state) + assert.NotNil(t, g) + g.Done(errors.New("test")) + assert.Equal(t, LoadStateDataLoaded, l.state) + + // success + l.state = LoadStateDataLoaded + g = l.StartReleaseData() + assert.Equal(t, LoadStateDataReleasing, l.state) + assert.NotNil(t, g) + g.Done(nil) + assert.Equal(t, LoadStateOnlyMeta, l.state) + + // nothing to do on released + l.state = LoadStateReleased + g = l.StartReleaseData() + assert.Nil(t, g) + + // test blocking. + l.state = LoadStateOnlyMeta + g, err := l.StartLoadData() + assert.NoError(t, err) + + ch := make(chan struct{}) + go func() { + g := l.StartReleaseData() + assert.NotNil(t, g) + g.Done(nil) + close(ch) + }() + + // should be blocked while the load is still in progress. + select { + case <-ch: + t.Errorf("should be blocked") + case <-time.After(500 * time.Millisecond): + } + // loading finished. + g.Done(nil) + + // release can be started. + select { + case <-ch: + case <-time.After(500 * time.Millisecond): + t.Errorf("should not be blocked") + } + assert.Equal(t, LoadStateOnlyMeta, l.state) +} + +func TestStartReleaseAll(t *testing.T) { + l := NewLoadStateLock(LoadStateOnlyMeta) + // Test Release All on only meta: the state moves straight to released. 
+ g := l.StartReleaseAll() + assert.NotNil(t, g) + assert.Equal(t, LoadStateReleased, l.state) + g.Done(nil) + assert.Equal(t, LoadStateReleased, l.state) + + // roll back + // only the data release can roll back; StartReleaseAll returns a no-op guard. + l.state = LoadStateDataLoaded + g = l.StartReleaseData() + assert.Equal(t, LoadStateDataReleasing, l.state) + assert.NotNil(t, g) + g.Done(errors.New("test")) + assert.Equal(t, LoadStateDataLoaded, l.state) + + // success + l.state = LoadStateDataLoaded + g = l.StartReleaseAll() + assert.Equal(t, LoadStateReleased, l.state) + assert.NotNil(t, g) + g.Done(nil) + assert.Equal(t, LoadStateReleased, l.state) + + // nothing to do on released + l.state = LoadStateReleased + g = l.StartReleaseAll() + assert.Nil(t, g) + + // test blocking. + l.state = LoadStateOnlyMeta + g, err := l.StartLoadData() + assert.NoError(t, err) + + ch := make(chan struct{}) + go func() { + g := l.StartReleaseAll() + assert.NotNil(t, g) + g.Done(nil) + close(ch) + }() + + // should be blocked while the load is still in progress. + select { + case <-ch: + t.Errorf("should be blocked") + case <-time.After(500 * time.Millisecond): + } + // loading finished. + g.Done(nil) + + // release can be started. 
+ select { + case <-ch: + case <-time.After(500 * time.Millisecond): + t.Errorf("should not be blocked") + } + assert.Equal(t, LoadStateReleased, l.state) +} + +func TestRLock(t *testing.T) { + l := NewLoadStateLock(LoadStateOnlyMeta) + assert.True(t, l.RLockIf(IsNotReleased)) + l.RUnlock() + assert.False(t, l.RLockIf(IsDataLoaded)) + + l = NewLoadStateLock(LoadStateDataLoaded) + assert.True(t, l.RLockIf(IsNotReleased)) + l.RUnlock() + assert.True(t, l.RLockIf(IsDataLoaded)) + l.RUnlock() + + l = NewLoadStateLock(LoadStateOnlyMeta) + l.StartReleaseAll().Done(nil) + assert.False(t, l.RLockIf(IsNotReleased)) + assert.False(t, l.RLockIf(IsDataLoaded)) +} + +func TestPin(t *testing.T) { + l := NewLoadStateLock(LoadStateOnlyMeta) + assert.True(t, l.PinIfNotReleased()) + l.Unpin() + + l.StartReleaseAll().Done(nil) + assert.False(t, l.PinIfNotReleased()) + + l = NewLoadStateLock(LoadStateDataLoaded) + assert.True(t, l.PinIfNotReleased()) + + ch := make(chan struct{}) + go func() { + l.StartReleaseAll().Done(nil) + close(ch) + }() + + select { + case <-ch: + t.Errorf("should be blocked") + case <-time.After(500 * time.Millisecond): + } + + // a second pin keeps the release blocked until both pins are dropped. + assert.True(t, l.PinIfNotReleased()) + l.Unpin() + select { + case <-ch: + t.Errorf("should be blocked") + case <-time.After(500 * time.Millisecond): + } + l.Unpin() + <-ch + + assert.Panics(t, func() { + // one Unpin more than the successful pins + l.Unpin() + }) +}