diff --git a/internal/dataservice/segment_allocator.go b/internal/dataservice/segment_allocator.go index 626e59d11a..5cf782bc25 100644 --- a/internal/dataservice/segment_allocator.go +++ b/internal/dataservice/segment_allocator.go @@ -44,20 +44,16 @@ func (err errRemainInSufficient) Error() string { // segmentAllocator is used to allocate rows for segments and record the allocations. type segmentAllocatorInterface interface { - // OpenSegment add the segment to allocator and set it allocatable - OpenSegment(ctx context.Context, segmentInfo *datapb.SegmentInfo) error // AllocSegment allocate rows and record the allocation. AllocSegment(ctx context.Context, collectionID UniqueID, partitionID UniqueID, channelName string, requestRows int) (UniqueID, int, Timestamp, error) - // GetSealedSegments get all sealed segment. - GetSealedSegments(ctx context.Context) ([]UniqueID, error) - // SealSegment set segment sealed, the segment will not be allocated anymore. - SealSegment(ctx context.Context, segmentID UniqueID) error // DropSegment drop the segment from allocator. DropSegment(ctx context.Context, segmentID UniqueID) + // SealAllSegments get all opened segment ids of collection. return success and failed segment ids + SealAllSegments(ctx context.Context, collectionID UniqueID) error + // GetSealedSegments get all sealed segment. + GetSealedSegments(ctx context.Context) ([]UniqueID, error) // ExpireAllocations check all allocations' expire time and remove the expired allocation. ExpireAllocations(ctx context.Context, timeTick Timestamp) error - // SealAllSegments get all opened segment ids of collection. return success and failed segment ids - SealAllSegments(ctx context.Context, collectionID UniqueID) // IsAllocationsExpired check all allocations of segment expired. IsAllocationsExpired(ctx context.Context, segmentID UniqueID, ts Timestamp) (bool, error) } @@ -113,17 +109,6 @@ func newSegmentAllocator(meta *meta, allocator allocatorInterface, opts ...Optio return alloc } -func (s *segmentAllocator) OpenSegment(ctx context.Context, segmentInfo *datapb.SegmentInfo) error { - sp, _ := trace.StartSpanFromContext(ctx) - defer sp.Finish() - s.mu.Lock() - defer s.mu.Unlock() - if _, ok := s.segments[segmentInfo.ID]; ok { - return fmt.Errorf("segment %d already exist", segmentInfo.ID) - } - return s.open(segmentInfo) -} - func (s *segmentAllocator) open(segmentInfo *datapb.SegmentInfo) error { totalRows, err := s.estimateTotalRows(segmentInfo.CollectionID) if err != nil { @@ -314,23 +299,6 @@ func (s *segmentAllocator) checkSegmentSealed(segStatus *segmentStatus) (bool, e return float64(segMeta.NumRows) >= s.segmentThresholdFactor*float64(segStatus.total), nil } -func (s *segmentAllocator) SealSegment(ctx context.Context, segmentID UniqueID) error { - sp, _ := trace.StartSpanFromContext(ctx) - defer sp.Finish() - s.mu.Lock() - defer s.mu.Unlock() - status, ok := s.segments[segmentID] - if !ok { - return nil - } - - if err := s.sealSegmentInMeta(segmentID); err != nil { - return err - } - status.sealed = true - return nil -} - func (s *segmentAllocator) HasSegment(ctx context.Context, segmentID UniqueID) bool { sp, _ := trace.StartSpanFromContext(ctx) defer sp.Finish() @@ -388,7 +356,7 @@ func (s *segmentAllocator) IsAllocationsExpired(ctx context.Context, segmentID U return status.lastExpireTime <= ts, nil } -func (s *segmentAllocator) SealAllSegments(ctx context.Context, collectionID UniqueID) { +func (s *segmentAllocator) SealAllSegments(ctx context.Context, collectionID UniqueID) error { sp, _ := trace.StartSpanFromContext(ctx) defer sp.Finish() s.mu.Lock() @@ -398,7 +366,29 @@ func (s *segmentAllocator) SealAllSegments(ctx context.Context, collectionID Uni if status.sealed { continue } + if err := s.sealSegmentInMeta(status.id); err != nil { + return err + } status.sealed = true } } + return nil +} + +// only for test +func (s *segmentAllocator) SealSegment(ctx context.Context, segmentID UniqueID) error { + sp, _ := trace.StartSpanFromContext(ctx) + defer sp.Finish() + s.mu.Lock() + defer s.mu.Unlock() + status, ok := s.segments[segmentID] + if !ok { + return nil + } + + if err := s.sealSegmentInMeta(segmentID); err != nil { + return err + } + status.sealed = true + return nil } diff --git a/internal/dataservice/segment_allocator_test.go b/internal/dataservice/segment_allocator_test.go index a4d96dc2e4..219fae5cb9 100644 --- a/internal/dataservice/segment_allocator_test.go +++ b/internal/dataservice/segment_allocator_test.go @@ -14,7 +14,6 @@ import ( "context" "log" "math" - "strconv" "testing" "time" @@ -41,15 +40,6 @@ func TestAllocSegment(t *testing.T) { Schema: schema, }) assert.Nil(t, err) - id, err := mockAllocator.allocID() - assert.Nil(t, err) - segmentInfo, err := BuildSegment(collID, 100, id, "c1") - assert.Nil(t, err) - err = meta.AddSegment(segmentInfo) - assert.Nil(t, err) - err = segAllocator.OpenSegment(ctx, segmentInfo) - assert.Nil(t, err) - cases := []struct { collectionID UniqueID partitionID UniqueID @@ -73,43 +63,6 @@ func TestAllocSegment(t *testing.T) { } } -func TestSealSegment(t *testing.T) { - ctx := context.Background() - Params.Init() - mockAllocator := newMockAllocator() - meta, err := newMemoryMeta(mockAllocator) - assert.Nil(t, err) - segAllocator := newSegmentAllocator(meta, mockAllocator) - - schema := newTestSchema() - collID, err := mockAllocator.allocID() - assert.Nil(t, err) - err = meta.AddCollection(&datapb.CollectionInfo{ - ID: collID, - Schema: schema, - }) - assert.Nil(t, err) - var lastSegID UniqueID - for i := 0; i < 10; i++ { - id, err := mockAllocator.allocID() - assert.Nil(t, err) - segmentInfo, err := BuildSegment(collID, 100, id, "c"+strconv.Itoa(i)) - assert.Nil(t, err) - err = meta.AddSegment(segmentInfo) - assert.Nil(t, err) - err = segAllocator.OpenSegment(ctx, segmentInfo) - assert.Nil(t, err) - lastSegID = segmentInfo.ID - } - - err = segAllocator.SealSegment(ctx, lastSegID) - assert.Nil(t, err) - segAllocator.SealAllSegments(ctx, collID) - sealedSegments, err := segAllocator.GetSealedSegments(ctx) - assert.Nil(t, err) - assert.EqualValues(t, 10, len(sealedSegments)) -} - func TestExpireSegment(t *testing.T) { ctx := context.Background() Params.Init() @@ -126,14 +79,6 @@ func TestExpireSegment(t *testing.T) { Schema: schema, }) assert.Nil(t, err) - id, err := mockAllocator.allocID() - assert.Nil(t, err) - segmentInfo, err := BuildSegment(collID, 100, id, "c1") - assert.Nil(t, err) - err = meta.AddSegment(segmentInfo) - assert.Nil(t, err) - err = segAllocator.OpenSegment(ctx, segmentInfo) - assert.Nil(t, err) id1, _, et, err := segAllocator.AllocSegment(ctx, collID, 100, "c1", 10) ts2, _ := tsoutil.ParseTS(et) diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go index 0518844345..40c9c5b127 100644 --- a/internal/dataservice/server.go +++ b/internal/dataservice/server.go @@ -555,7 +555,12 @@ func (s *Server) Flush(ctx context.Context, req *datapb.FlushRequest) (*commonpb Reason: "server is initializing", }, nil } - s.segAllocator.SealAllSegments(ctx, req.CollectionID) + if err := s.segAllocator.SealAllSegments(ctx, req.CollectionID); err != nil { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("Seal all segments error %s", err), + }, nil + } return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, nil diff --git a/internal/dataservice/server_test.go b/internal/dataservice/server_test.go index b103d09f41..702ce3838d 100644 --- a/internal/dataservice/server_test.go +++ b/internal/dataservice/server_test.go @@ -204,22 +204,8 @@ func TestFlush(t *testing.T) { Partitions: []int64{}, }) assert.Nil(t, err) - segments := []struct { - id UniqueID - collectionID UniqueID - }{ - {1, 0}, - {2, 0}, - } - for _, segment := range segments { - err = svr.segAllocator.OpenSegment(context.TODO(), &datapb.SegmentInfo{ - ID: segment.id, - CollectionID: segment.collectionID, - PartitionID: 0, - State: commonpb.SegmentState_Growing, - }) - assert.Nil(t, err) - } + segID, _, _, err := svr.segAllocator.AllocSegment(context.TODO(), 0, 1, "channel-1", 1) + assert.Nil(t, err) resp, err := svr.Flush(context.TODO(), &datapb.FlushRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Flush, @@ -234,7 +220,8 @@ func TestFlush(t *testing.T) { assert.EqualValues(t, commonpb.ErrorCode_Success, resp.ErrorCode) ids, err := svr.segAllocator.GetSealedSegments(context.TODO()) assert.Nil(t, err) - assert.ElementsMatch(t, ids, []UniqueID{1, 2}) + assert.EqualValues(t, 1, len(ids)) + assert.EqualValues(t, segID, ids[0]) } func TestGetComponentStates(t *testing.T) { diff --git a/internal/dataservice/watcher_test.go b/internal/dataservice/watcher_test.go index dbb3690846..9d53572fb5 100644 --- a/internal/dataservice/watcher_test.go +++ b/internal/dataservice/watcher_test.go @@ -16,10 +16,9 @@ import ( "testing" "time" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/msgstream" "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/stretchr/testify/assert" @@ -73,39 +72,28 @@ func TestWatcher(t *testing.T) { assert.Nil(t, err) cases := []struct { - sealed bool - allocation bool - expired bool - expected bool + sealed bool + expired bool + expected bool }{ - {false, false, true, false}, - {false, true, true, false}, - {false, true, false, false}, - {true, false, true, true}, - {true, true, false, false}, - {true, true, true, true}, + {false, true, false}, + {false, true, false}, + {false, false, false}, + {true, true, true}, + {true, false, false}, + {true, true, true}, } - segIDs := make([]UniqueID, len(cases)) - for i, c := range cases { - segID, err := allocator.allocID() + segIDs := make([]UniqueID, 0) + for i := range cases { + segID, _, _, err := segAllocator.AllocSegment(ctx, collID, partID, "channel"+strconv.Itoa(i), 100) assert.Nil(t, err) - segIDs[i] = segID - segInfo, err := BuildSegment(collID, partID, segID, "channel"+strconv.Itoa(i)) - assert.Nil(t, err) - err = meta.AddSegment(segInfo) - assert.Nil(t, err) - err = segAllocator.OpenSegment(ctx, segInfo) - assert.Nil(t, err) - if c.allocation && c.expired { - _, _, _, err := segAllocator.AllocSegment(ctx, collID, partID, "channel"+strconv.Itoa(i), 100) - assert.Nil(t, err) - } + segIDs = append(segIDs, segID) } time.Sleep(time.Duration(Params.SegIDAssignExpiration+1000) * time.Millisecond) for i, c := range cases { - if c.allocation && !c.expired { + if !c.expired { _, _, _, err := segAllocator.AllocSegment(ctx, collID, partID, "channel"+strconv.Itoa(i), 100) assert.Nil(t, err) }