fix pull target (#23491)

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
wei liu 2023-04-18 18:30:32 +08:00 committed by GitHub
parent eb690ef033
commit cbfe7a45ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 2335 additions and 837 deletions

View File

@ -357,7 +357,7 @@ generate-mockery: getdeps
$(PWD)/bin/mockery --name=Cluster --dir=$(PWD)/internal/querycoordv2/session --output=$(PWD)/internal/querycoordv2/session --filename=mock_cluster.go --with-expecter --structname=MockCluster --outpkg=session --inpackage $(PWD)/bin/mockery --name=Cluster --dir=$(PWD)/internal/querycoordv2/session --output=$(PWD)/internal/querycoordv2/session --filename=mock_cluster.go --with-expecter --structname=MockCluster --outpkg=session --inpackage
$(PWD)/bin/mockery --name=Store --dir=$(PWD)/internal/querycoordv2/meta --output=$(PWD)/internal/querycoordv2/meta --filename=mock_store.go --with-expecter --structname=MockStore --outpkg=meta --inpackage $(PWD)/bin/mockery --name=Store --dir=$(PWD)/internal/querycoordv2/meta --output=$(PWD)/internal/querycoordv2/meta --filename=mock_store.go --with-expecter --structname=MockStore --outpkg=meta --inpackage
$(PWD)/bin/mockery --name=Balance --dir=$(PWD)/internal/querycoordv2/balance --output=$(PWD)/internal/querycoordv2/balance --filename=mock_balancer.go --with-expecter --structname=MockBalancer --outpkg=balance --inpackage $(PWD)/bin/mockery --name=Balance --dir=$(PWD)/internal/querycoordv2/balance --output=$(PWD)/internal/querycoordv2/balance --filename=mock_balancer.go --with-expecter --structname=MockBalancer --outpkg=balance --inpackage
$(PWD)/bin/mockery --name=Controller --dir=$(PWD)/internal/querycoordv2/dist --output=$(PWD)/internal/querycoordv2/dist --filename=mock_controller.go --with-expecter --structname=MockController --outpkg=dist --inpackage $(PWD)/bin/mockery --name=Controller --dir=$(PWD)/internal/querycoordv2/dist --output=$(PWD)/internal/querycoordv2/dist --filename=mock_controller.go --with-expecter --structname=MockController --outpkg=dist --inpackage
# internal/querynode # internal/querynode
$(PWD)/bin/mockery --name=TSafeReplicaInterface --dir=$(PWD)/internal/querynode --output=$(PWD)/internal/querynode --filename=mock_tsafe_replica_test.go --with-expecter --structname=MockTSafeReplicaInterface --outpkg=querynode --inpackage $(PWD)/bin/mockery --name=TSafeReplicaInterface --dir=$(PWD)/internal/querynode --output=$(PWD)/internal/querynode --filename=mock_tsafe_replica_test.go --with-expecter --structname=MockTSafeReplicaInterface --outpkg=querynode --inpackage
# internal/rootcoord # internal/rootcoord

View File

@ -21,6 +21,7 @@ import (
"time" "time"
"github.com/milvus-io/milvus/pkg/util/tsoutil" "github.com/milvus-io/milvus/pkg/util/tsoutil"
"github.com/samber/lo"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
@ -35,7 +36,7 @@ import (
// Handler handles some channel method for ChannelManager // Handler handles some channel method for ChannelManager
type Handler interface { type Handler interface {
// GetQueryVChanPositions gets the information recovery needed of a channel for QueryCoord // GetQueryVChanPositions gets the information recovery needed of a channel for QueryCoord
GetQueryVChanPositions(channel *channel, partitionID UniqueID) *datapb.VchannelInfo GetQueryVChanPositions(channel *channel, partitionIDs ...UniqueID) *datapb.VchannelInfo
// GetDataVChanPositions gets the information recovery needed of a channel for DataNode // GetDataVChanPositions gets the information recovery needed of a channel for DataNode
GetDataVChanPositions(channel *channel, partitionID UniqueID) *datapb.VchannelInfo GetDataVChanPositions(channel *channel, partitionID UniqueID) *datapb.VchannelInfo
CheckShouldDropChannel(channel string, collectionID UniqueID) bool CheckShouldDropChannel(channel string, collectionID UniqueID) bool
@ -101,7 +102,7 @@ func (h *ServerHandler) GetDataVChanPositions(channel *channel, partitionID Uniq
// GetQueryVChanPositions gets vchannel latest postitions with provided dml channel names for QueryCoord, // GetQueryVChanPositions gets vchannel latest postitions with provided dml channel names for QueryCoord,
// we expect QueryCoord gets the indexed segments to load, so the flushed segments below are actually the indexed segments, // we expect QueryCoord gets the indexed segments to load, so the flushed segments below are actually the indexed segments,
// the unflushed segments are actually the segments without index, even they are flushed. // the unflushed segments are actually the segments without index, even they are flushed.
func (h *ServerHandler) GetQueryVChanPositions(channel *channel, partitionID UniqueID) *datapb.VchannelInfo { func (h *ServerHandler) GetQueryVChanPositions(channel *channel, partitionIDs ...UniqueID) *datapb.VchannelInfo {
// cannot use GetSegmentsByChannel since dropped segments are needed here // cannot use GetSegmentsByChannel since dropped segments are needed here
segments := h.s.meta.SelectSegments(func(s *SegmentInfo) bool { segments := h.s.meta.SelectSegments(func(s *SegmentInfo) bool {
return s.InsertChannel == channel.Name && !s.GetIsFake() return s.InsertChannel == channel.Name && !s.GetIsFake()
@ -123,8 +124,11 @@ func (h *ServerHandler) GetQueryVChanPositions(channel *channel, partitionID Uni
unIndexedIDs = make(typeutil.UniqueSet) unIndexedIDs = make(typeutil.UniqueSet)
droppedIDs = make(typeutil.UniqueSet) droppedIDs = make(typeutil.UniqueSet)
) )
validPartitions := lo.Filter(partitionIDs, func(partitionID int64, _ int) bool { return partitionID > allPartitionID })
partitionSet := typeutil.NewUniqueSet(validPartitions...)
for _, s := range segments { for _, s := range segments {
if (partitionID > allPartitionID && s.PartitionID != partitionID) || if (partitionSet.Len() > 0 && !partitionSet.Contain(s.PartitionID)) ||
(s.GetStartPosition() == nil && s.GetDmlPosition() == nil) { (s.GetStartPosition() == nil && s.GetDmlPosition() == nil) {
continue continue
} }
@ -165,7 +169,7 @@ func (h *ServerHandler) GetQueryVChanPositions(channel *channel, partitionID Uni
return &datapb.VchannelInfo{ return &datapb.VchannelInfo{
CollectionID: channel.CollectionID, CollectionID: channel.CollectionID,
ChannelName: channel.Name, ChannelName: channel.Name,
SeekPosition: h.GetChannelSeekPosition(channel, partitionID), SeekPosition: h.GetChannelSeekPosition(channel, partitionIDs...),
FlushedSegmentIds: indexedIDs.Collect(), FlushedSegmentIds: indexedIDs.Collect(),
UnflushedSegmentIds: unIndexedIDs.Collect(), UnflushedSegmentIds: unIndexedIDs.Collect(),
DroppedSegmentIds: droppedIDs.Collect(), DroppedSegmentIds: droppedIDs.Collect(),
@ -175,15 +179,18 @@ func (h *ServerHandler) GetQueryVChanPositions(channel *channel, partitionID Uni
// getEarliestSegmentDMLPos returns the earliest dml position of segments, // getEarliestSegmentDMLPos returns the earliest dml position of segments,
// this is mainly for COMPATIBILITY with old version <=2.1.x // this is mainly for COMPATIBILITY with old version <=2.1.x
func (h *ServerHandler) getEarliestSegmentDMLPos(channel *channel, partitionID UniqueID) *msgpb.MsgPosition { func (h *ServerHandler) getEarliestSegmentDMLPos(channel *channel, partitionIDs ...UniqueID) *msgpb.MsgPosition {
var minPos *msgpb.MsgPosition var minPos *msgpb.MsgPosition
var minPosSegID int64 var minPosSegID int64
var minPosTs uint64 var minPosTs uint64
segments := h.s.meta.SelectSegments(func(s *SegmentInfo) bool { segments := h.s.meta.SelectSegments(func(s *SegmentInfo) bool {
return s.InsertChannel == channel.Name return s.InsertChannel == channel.Name
}) })
validPartitions := lo.Filter(partitionIDs, func(partitionID int64, _ int) bool { return partitionID > allPartitionID })
partitionSet := typeutil.NewUniqueSet(validPartitions...)
for _, s := range segments { for _, s := range segments {
if (partitionID > allPartitionID && s.PartitionID != partitionID) || if (partitionSet.Len() > 0 && !partitionSet.Contain(s.PartitionID)) ||
(s.GetStartPosition() == nil && s.GetDmlPosition() == nil) { (s.GetStartPosition() == nil && s.GetDmlPosition() == nil) {
continue continue
} }
@ -247,7 +254,7 @@ func (h *ServerHandler) getCollectionStartPos(channel *channel) *msgpb.MsgPositi
// 2. Segments earliest dml position; // 2. Segments earliest dml position;
// 3. Collection start position; // 3. Collection start position;
// And would return if any position is valid. // And would return if any position is valid.
func (h *ServerHandler) GetChannelSeekPosition(channel *channel, partitionID UniqueID) *msgpb.MsgPosition { func (h *ServerHandler) GetChannelSeekPosition(channel *channel, partitionIDs ...UniqueID) *msgpb.MsgPosition {
var seekPosition *msgpb.MsgPosition var seekPosition *msgpb.MsgPosition
seekPosition = h.s.meta.GetChannelCheckpoint(channel.Name) seekPosition = h.s.meta.GetChannelCheckpoint(channel.Name)
if seekPosition != nil { if seekPosition != nil {
@ -258,7 +265,7 @@ func (h *ServerHandler) GetChannelSeekPosition(channel *channel, partitionID Uni
return seekPosition return seekPosition
} }
seekPosition = h.getEarliestSegmentDMLPos(channel, partitionID) seekPosition = h.getEarliestSegmentDMLPos(channel, partitionIDs...)
if seekPosition != nil { if seekPosition != nil {
log.Info("channel seek position set from earliest segment dml position", log.Info("channel seek position set from earliest segment dml position",
zap.String("channel", channel.Name), zap.String("channel", channel.Name),

View File

@ -846,7 +846,7 @@ func newMockHandler() *mockHandler {
return &mockHandler{} return &mockHandler{}
} }
func (h *mockHandler) GetQueryVChanPositions(channel *channel, partitionID UniqueID) *datapb.VchannelInfo { func (h *mockHandler) GetQueryVChanPositions(channel *channel, partitionID ...UniqueID) *datapb.VchannelInfo {
return &datapb.VchannelInfo{ return &datapb.VchannelInfo{
CollectionID: channel.CollectionID, CollectionID: channel.CollectionID,
ChannelName: channel.Name, ChannelName: channel.Name,

View File

@ -744,6 +744,102 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
return resp, nil return resp, nil
} }
// GetRecoveryInfoV2 get recovery info for segment
// Called by: QueryCoord.
func (s *Server) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2) (*datapb.GetRecoveryInfoResponseV2, error) {
collectionID := req.GetCollectionID()
partitionIDs := req.GetPartitionIDs()
log := log.With(
zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", partitionIDs),
)
log.Info("get recovery info request received")
resp := &datapb.GetRecoveryInfoResponseV2{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}
if s.isClosed() {
resp.Status.Reason = serverNotServingErrMsg
return resp, nil
}
dresp, err := s.rootCoordClient.DescribeCollectionInternal(s.ctx, &milvuspb.DescribeCollectionRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
CollectionID: collectionID,
})
if err = VerifyResponse(dresp, err); err != nil {
log.Error("get collection info from rootcoord failed",
zap.Error(err))
resp.Status.Reason = err.Error()
return resp, nil
}
channels := dresp.GetVirtualChannelNames()
channelInfos := make([]*datapb.VchannelInfo, 0, len(channels))
flushedIDs := make(typeutil.UniqueSet)
for _, c := range channels {
channelInfo := s.handler.GetQueryVChanPositions(&channel{Name: c, CollectionID: collectionID}, partitionIDs...)
channelInfos = append(channelInfos, channelInfo)
log.Info("datacoord append channelInfo in GetRecoveryInfo",
zap.Any("channelInfo", channelInfo),
)
flushedIDs.Insert(channelInfo.GetFlushedSegmentIds()...)
}
segmentInfos := make([]*datapb.SegmentInfo, 0)
for id := range flushedIDs {
segment := s.meta.GetSegment(id)
if segment == nil {
errMsg := fmt.Sprintf("failed to get segment %d", id)
log.Error(errMsg)
resp.Status.Reason = errMsg
return resp, nil
}
// Skip non-flushing, non-flushed and dropped segments.
if segment.State != commonpb.SegmentState_Flushed && segment.State != commonpb.SegmentState_Flushing && segment.State != commonpb.SegmentState_Dropped {
continue
}
// Also skip bulk insert segments.
if segment.GetIsImporting() {
continue
}
binlogs := segment.GetBinlogs()
if len(binlogs) == 0 {
continue
}
rowCount := segmentutil.CalcRowCountFromBinLog(segment.SegmentInfo)
if rowCount != segment.NumOfRows && rowCount > 0 {
log.Warn("segment row number meta inconsistent with bin log row count and will be corrected",
zap.Int64("segment ID", segment.GetID()),
zap.Int64("segment meta row count (wrong)", segment.GetNumOfRows()),
zap.Int64("segment bin log row count (correct)", rowCount))
} else {
rowCount = segment.NumOfRows
}
segmentInfos = append(segmentInfos, &datapb.SegmentInfo{
ID: segment.ID,
PartitionID: segment.PartitionID,
CollectionID: segment.CollectionID,
InsertChannel: segment.InsertChannel,
NumOfRows: rowCount,
Binlogs: segment.Binlogs,
Statslogs: segment.Statslogs,
Deltalogs: segment.Deltalogs,
})
}
resp.Channels = channelInfos
resp.Segments = segmentInfos
resp.Status.ErrorCode = commonpb.ErrorCode_Success
return resp, nil
}
// GetFlushedSegments returns all segment matches provided criterion and in state Flushed or Dropped (compacted but not GCed yet) // GetFlushedSegments returns all segment matches provided criterion and in state Flushed or Dropped (compacted but not GCed yet)
// If requested partition id < 0, ignores the partition id filter // If requested partition id < 0, ignores the partition id filter
func (s *Server) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { func (s *Server) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) {

View File

@ -2,14 +2,22 @@ package datacoord
import ( import (
"context" "context"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/msgpb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/util/metautil"
) )
func TestBroadcastAlteredCollection(t *testing.T) { func TestBroadcastAlteredCollection(t *testing.T) {
@ -88,3 +96,492 @@ func TestServer_GcConfirm(t *testing.T) {
assert.False(t, resp.GetGcFinished()) assert.False(t, resp.GetGcFinished())
}) })
} }
func TestGetRecoveryInfoV2(t *testing.T) {
t.Run("test get recovery info with no segments", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
return newMockRootCoordService(), nil
}
req := &datapb.GetRecoveryInfoRequestV2{
CollectionID: 0,
}
resp, err := svr.GetRecoveryInfoV2(context.TODO(), req)
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.EqualValues(t, 0, len(resp.GetSegments()))
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.Nil(t, resp.GetChannels()[0].SeekPosition)
})
createSegment := func(id, collectionID, partitionID, numOfRows int64, posTs uint64,
channel string, state commonpb.SegmentState) *datapb.SegmentInfo {
return &datapb.SegmentInfo{
ID: id,
CollectionID: collectionID,
PartitionID: partitionID,
InsertChannel: channel,
NumOfRows: numOfRows,
State: state,
DmlPosition: &msgpb.MsgPosition{
ChannelName: channel,
MsgID: []byte{},
Timestamp: posTs,
},
StartPosition: &msgpb.MsgPosition{
ChannelName: "",
MsgID: []byte{},
MsgGroup: "",
Timestamp: 0,
},
}
}
t.Run("test get earliest position of flushed segments as seek position", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
return newMockRootCoordService(), nil
}
svr.meta.AddCollection(&collectionInfo{
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint("vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 10,
})
assert.NoError(t, err)
err = svr.meta.CreateIndex(&model.Index{
TenantID: "",
CollectionID: 0,
FieldID: 2,
IndexID: 0,
IndexName: "",
})
assert.Nil(t, err)
seg1 := createSegment(0, 0, 0, 100, 10, "vchan1", commonpb.SegmentState_Flushed)
seg1.Binlogs = []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 20,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 0, 1, 901),
},
{
EntriesNum: 20,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 0, 1, 902),
},
{
EntriesNum: 20,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 0, 1, 903),
},
},
},
}
seg2 := createSegment(1, 0, 0, 100, 20, "vchan1", commonpb.SegmentState_Flushed)
seg2.Binlogs = []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 30,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 1, 1, 801),
},
{
EntriesNum: 70,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 1, 1, 802),
},
},
},
}
err = svr.meta.AddSegment(NewSegmentInfo(seg1))
assert.Nil(t, err)
err = svr.meta.AddSegment(NewSegmentInfo(seg2))
assert.Nil(t, err)
err = svr.meta.AddSegmentIndex(&model.SegmentIndex{
SegmentID: seg1.ID,
BuildID: seg1.ID,
})
assert.Nil(t, err)
err = svr.meta.FinishTask(&indexpb.IndexTaskInfo{
BuildID: seg1.ID,
State: commonpb.IndexState_Finished,
})
assert.Nil(t, err)
err = svr.meta.AddSegmentIndex(&model.SegmentIndex{
SegmentID: seg2.ID,
BuildID: seg2.ID,
})
assert.Nil(t, err)
err = svr.meta.FinishTask(&indexpb.IndexTaskInfo{
BuildID: seg2.ID,
State: commonpb.IndexState_Finished,
})
assert.Nil(t, err)
req := &datapb.GetRecoveryInfoRequestV2{
CollectionID: 0,
}
resp, err := svr.GetRecoveryInfoV2(context.TODO(), req)
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.EqualValues(t, 0, len(resp.GetChannels()[0].GetUnflushedSegmentIds()))
assert.ElementsMatch(t, []int64{0, 1}, resp.GetChannels()[0].GetFlushedSegmentIds())
assert.EqualValues(t, 10, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
assert.EqualValues(t, 2, len(resp.GetSegments()))
// Row count corrected from 100 + 100 -> 100 + 60.
assert.EqualValues(t, 160, resp.GetSegments()[0].GetNumOfRows()+resp.GetSegments()[1].GetNumOfRows())
})
t.Run("test get recovery of unflushed segments ", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
return newMockRootCoordService(), nil
}
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint("vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 0,
})
assert.NoError(t, err)
seg1 := createSegment(3, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing)
seg1.Binlogs = []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 20,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 3, 1, 901),
},
{
EntriesNum: 20,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 3, 1, 902),
},
{
EntriesNum: 20,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 3, 1, 903),
},
},
},
}
seg2 := createSegment(4, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Growing)
seg2.Binlogs = []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 30,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 4, 1, 801),
},
{
EntriesNum: 70,
LogPath: metautil.BuildInsertLogPath("a", 0, 0, 4, 1, 802),
},
},
},
}
err = svr.meta.AddSegment(NewSegmentInfo(seg1))
assert.Nil(t, err)
err = svr.meta.AddSegment(NewSegmentInfo(seg2))
assert.Nil(t, err)
req := &datapb.GetRecoveryInfoRequestV2{
CollectionID: 0,
}
resp, err := svr.GetRecoveryInfoV2(context.TODO(), req)
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.EqualValues(t, 0, len(resp.GetSegments()))
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.NotNil(t, resp.GetChannels()[0].SeekPosition)
assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
})
t.Run("test get binlogs", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
svr.meta.AddCollection(&collectionInfo{
Schema: newTestSchema(),
})
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
return newMockRootCoordService(), nil
}
binlogReq := &datapb.SaveBinlogPathsRequest{
SegmentID: 0,
CollectionID: 0,
Field2BinlogPaths: []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
LogPath: "/binlog/file1",
},
{
LogPath: "/binlog/file2",
},
},
},
},
Field2StatslogPaths: []*datapb.FieldBinlog{
{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
LogPath: "/stats_log/file1",
},
{
LogPath: "/stats_log/file2",
},
},
},
},
Deltalogs: []*datapb.FieldBinlog{
{
Binlogs: []*datapb.Binlog{
{
TimestampFrom: 0,
TimestampTo: 1,
LogPath: "/stats_log/file1",
LogSize: 1,
},
},
},
},
}
segment := createSegment(0, 0, 1, 100, 10, "vchan1", commonpb.SegmentState_Flushed)
err := svr.meta.AddSegment(NewSegmentInfo(segment))
assert.Nil(t, err)
err = svr.meta.CreateIndex(&model.Index{
TenantID: "",
CollectionID: 0,
FieldID: 2,
IndexID: 0,
IndexName: "",
})
assert.Nil(t, err)
err = svr.meta.AddSegmentIndex(&model.SegmentIndex{
SegmentID: segment.ID,
BuildID: segment.ID,
})
assert.Nil(t, err)
err = svr.meta.FinishTask(&indexpb.IndexTaskInfo{
BuildID: segment.ID,
State: commonpb.IndexState_Finished,
})
assert.Nil(t, err)
err = svr.channelManager.AddNode(0)
assert.Nil(t, err)
err = svr.channelManager.Watch(&channel{Name: "vchan1", CollectionID: 0})
assert.Nil(t, err)
sResp, err := svr.SaveBinlogPaths(context.TODO(), binlogReq)
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, sResp.ErrorCode)
req := &datapb.GetRecoveryInfoRequestV2{
CollectionID: 0,
PartitionIDs: []int64{1},
}
resp, err := svr.GetRecoveryInfoV2(context.TODO(), req)
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.EqualValues(t, 1, len(resp.GetSegments()))
assert.EqualValues(t, 0, resp.GetSegments()[0].GetID())
assert.EqualValues(t, 1, len(resp.GetSegments()[0].GetBinlogs()))
assert.EqualValues(t, 1, resp.GetSegments()[0].GetBinlogs()[0].GetFieldID())
for i, binlog := range resp.GetSegments()[0].GetBinlogs()[0].GetBinlogs() {
assert.Equal(t, fmt.Sprintf("/binlog/file%d", i+1), binlog.GetLogPath())
}
})
t.Run("with dropped segments", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
return newMockRootCoordService(), nil
}
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint("vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 0,
})
assert.NoError(t, err)
seg1 := createSegment(7, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing)
seg2 := createSegment(8, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Dropped)
err = svr.meta.AddSegment(NewSegmentInfo(seg1))
assert.Nil(t, err)
err = svr.meta.AddSegment(NewSegmentInfo(seg2))
assert.Nil(t, err)
req := &datapb.GetRecoveryInfoRequestV2{
CollectionID: 0,
}
resp, err := svr.GetRecoveryInfoV2(context.TODO(), req)
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.EqualValues(t, 0, len(resp.GetSegments()))
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.NotNil(t, resp.GetChannels()[0].SeekPosition)
assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
assert.Len(t, resp.GetChannels()[0].GetDroppedSegmentIds(), 1)
assert.Equal(t, UniqueID(8), resp.GetChannels()[0].GetDroppedSegmentIds()[0])
})
t.Run("with fake segments", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
return newMockRootCoordService(), nil
}
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint("vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 0,
})
require.NoError(t, err)
seg1 := createSegment(7, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Growing)
seg2 := createSegment(8, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Flushed)
seg2.IsFake = true
err = svr.meta.AddSegment(NewSegmentInfo(seg1))
assert.Nil(t, err)
err = svr.meta.AddSegment(NewSegmentInfo(seg2))
assert.Nil(t, err)
req := &datapb.GetRecoveryInfoRequestV2{
CollectionID: 0,
}
resp, err := svr.GetRecoveryInfoV2(context.TODO(), req)
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.EqualValues(t, 0, len(resp.GetSegments()))
assert.EqualValues(t, 1, len(resp.GetChannels()))
assert.NotNil(t, resp.GetChannels()[0].SeekPosition)
assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
})
t.Run("with continuous compaction", func(t *testing.T) {
svr := newTestServer(t, nil)
defer closeTestServer(t, svr)
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
return newMockRootCoordService(), nil
}
svr.meta.AddCollection(&collectionInfo{
ID: 0,
Schema: newTestSchema(),
})
err := svr.meta.UpdateChannelCheckpoint("vchan1", &msgpb.MsgPosition{
ChannelName: "vchan1",
Timestamp: 0,
})
assert.NoError(t, err)
seg1 := createSegment(9, 0, 0, 100, 30, "vchan1", commonpb.SegmentState_Dropped)
seg2 := createSegment(10, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Dropped)
seg3 := createSegment(11, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Dropped)
seg3.CompactionFrom = []int64{9, 10}
seg4 := createSegment(12, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Dropped)
seg5 := createSegment(13, 0, 0, 100, 40, "vchan1", commonpb.SegmentState_Flushed)
seg5.CompactionFrom = []int64{11, 12}
err = svr.meta.AddSegment(NewSegmentInfo(seg1))
assert.Nil(t, err)
err = svr.meta.AddSegment(NewSegmentInfo(seg2))
assert.Nil(t, err)
err = svr.meta.AddSegment(NewSegmentInfo(seg3))
assert.Nil(t, err)
err = svr.meta.AddSegment(NewSegmentInfo(seg4))
assert.Nil(t, err)
err = svr.meta.AddSegment(NewSegmentInfo(seg5))
assert.Nil(t, err)
err = svr.meta.CreateIndex(&model.Index{
TenantID: "",
CollectionID: 0,
FieldID: 2,
IndexID: 0,
IndexName: "_default_idx_2",
IsDeleted: false,
CreateTime: 0,
TypeParams: nil,
IndexParams: nil,
IsAutoIndex: false,
UserIndexParams: nil,
})
assert.Nil(t, err)
svr.meta.segments.SetSegmentIndex(seg4.ID, &model.SegmentIndex{
SegmentID: seg4.ID,
CollectionID: 0,
PartitionID: 0,
NumRows: 100,
IndexID: 0,
BuildID: 0,
NodeID: 0,
IndexVersion: 1,
IndexState: commonpb.IndexState_Finished,
FailReason: "",
IsDeleted: false,
CreateTime: 0,
IndexFileKeys: nil,
IndexSize: 0,
})
req := &datapb.GetRecoveryInfoRequestV2{
CollectionID: 0,
}
resp, err := svr.GetRecoveryInfoV2(context.TODO(), req)
assert.Nil(t, err)
assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
assert.NotNil(t, resp.GetChannels()[0].SeekPosition)
assert.NotEqual(t, 0, resp.GetChannels()[0].GetSeekPosition().GetTimestamp())
assert.Len(t, resp.GetChannels()[0].GetDroppedSegmentIds(), 0)
assert.ElementsMatch(t, []UniqueID{9, 10}, resp.GetChannels()[0].GetUnflushedSegmentIds())
assert.ElementsMatch(t, []UniqueID{12}, resp.GetChannels()[0].GetFlushedSegmentIds())
})
t.Run("with closed server", func(t *testing.T) {
svr := newTestServer(t, nil)
closeTestServer(t, svr)
resp, err := svr.GetRecoveryInfoV2(context.TODO(), &datapb.GetRecoveryInfoRequestV2{})
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.GetStatus().GetErrorCode())
assert.Equal(t, serverNotServingErrMsg, resp.GetStatus().GetReason())
})
}

View File

@ -54,7 +54,7 @@ type MockAllocator_Alloc_Call struct {
} }
// Alloc is a helper method to define mock.On call // Alloc is a helper method to define mock.On call
// - count uint32 // - count uint32
func (_e *MockAllocator_Expecter) Alloc(count interface{}) *MockAllocator_Alloc_Call { func (_e *MockAllocator_Expecter) Alloc(count interface{}) *MockAllocator_Alloc_Call {
return &MockAllocator_Alloc_Call{Call: _e.mock.On("Alloc", count)} return &MockAllocator_Alloc_Call{Call: _e.mock.On("Alloc", count)}
} }
@ -170,8 +170,8 @@ type MockAllocator_GetGenerator_Call struct {
} }
// GetGenerator is a helper method to define mock.On call // GetGenerator is a helper method to define mock.On call
// - count int // - count int
// - done <-chan struct{} // - done <-chan struct{}
func (_e *MockAllocator_Expecter) GetGenerator(count interface{}, done interface{}) *MockAllocator_GetGenerator_Call { func (_e *MockAllocator_Expecter) GetGenerator(count interface{}, done interface{}) *MockAllocator_GetGenerator_Call {
return &MockAllocator_GetGenerator_Call{Call: _e.mock.On("GetGenerator", count, done)} return &MockAllocator_GetGenerator_Call{Call: _e.mock.On("GetGenerator", count, done)}
} }

View File

@ -423,6 +423,31 @@ func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
return ret.(*datapb.GetRecoveryInfoResponse), err return ret.(*datapb.GetRecoveryInfoResponse), err
} }
// GetRecoveryInfoV2 request segment recovery info of collection/partitions
//
// ctx is the context to control request deadline and cancellation
// req contains the collection/partitions id to query
//
// response struct `GetRecoveryInfoResponseV2` contains the list of segments info and corresponding vchannel info
// error is returned only when some communication issue occurs
func (c *Client) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2) (*datapb.GetRecoveryInfoResponseV2, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)),
)
ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) {
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.GetRecoveryInfoV2(ctx, req)
})
if err != nil || ret == nil {
return nil, err
}
return ret.(*datapb.GetRecoveryInfoResponseV2), err
}
// GetFlushedSegments returns flushed segment list of requested collection/parition // GetFlushedSegments returns flushed segment list of requested collection/parition
// //
// ctx is the context to control request deadline and cancellation // ctx is the context to control request deadline and cancellation

View File

@ -208,6 +208,9 @@ func Test_NewClient(t *testing.T) {
ret, err := client.CheckHealth(ctx, nil) ret, err := client.CheckHealth(ctx, nil)
retCheck(retNotNil, ret, err) retCheck(retNotNil, ret, err)
} }
r40, err := client.GetRecoveryInfoV2(ctx, nil)
retCheck(retNotNil, r40, err)
} }
client.grpcClient = &mock.GRPCClientBase[datapb.DataCoordClient]{ client.grpcClient = &mock.GRPCClientBase[datapb.DataCoordClient]{

View File

@ -283,6 +283,11 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
return s.dataCoord.GetRecoveryInfo(ctx, req) return s.dataCoord.GetRecoveryInfo(ctx, req)
} }
// GetRecoveryInfoV2 gets information for recovering channels
func (s *Server) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2) (*datapb.GetRecoveryInfoResponseV2, error) {
return s.dataCoord.GetRecoveryInfoV2(ctx, req)
}
// GetFlushedSegments get all flushed segments of a partition // GetFlushedSegments get all flushed segments of a partition
func (s *Server) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { func (s *Server) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) {
return s.dataCoord.GetFlushedSegments(ctx, req) return s.dataCoord.GetFlushedSegments(ctx, req)

View File

@ -375,6 +375,10 @@ func (m *MockDataCoord) GetRecoveryInfo(ctx context.Context, req *datapb.GetReco
return nil, nil return nil, nil
} }
func (m *MockDataCoord) GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2) (*datapb.GetRecoveryInfoResponseV2, error) {
return nil, nil
}
func (m *MockDataCoord) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { func (m *MockDataCoord) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) {
return nil, nil return nil, nil
} }

File diff suppressed because it is too large Load Diff

View File

@ -41,6 +41,7 @@ service DataCoord {
rpc SaveBinlogPaths(SaveBinlogPathsRequest) returns (common.Status){} rpc SaveBinlogPaths(SaveBinlogPathsRequest) returns (common.Status){}
rpc GetRecoveryInfo(GetRecoveryInfoRequest) returns (GetRecoveryInfoResponse){} rpc GetRecoveryInfo(GetRecoveryInfoRequest) returns (GetRecoveryInfoResponse){}
rpc GetRecoveryInfoV2(GetRecoveryInfoRequestV2) returns (GetRecoveryInfoResponseV2){}
rpc GetFlushedSegments(GetFlushedSegmentsRequest) returns(GetFlushedSegmentsResponse){} rpc GetFlushedSegments(GetFlushedSegmentsRequest) returns(GetFlushedSegmentsResponse){}
rpc GetSegmentsByStates(GetSegmentsByStatesRequest) returns(GetSegmentsByStatesResponse){} rpc GetSegmentsByStates(GetSegmentsByStatesRequest) returns(GetSegmentsByStatesResponse){}
rpc GetFlushAllState(milvus.GetFlushAllStateRequest) returns(milvus.GetFlushAllStateResponse) {} rpc GetFlushAllState(milvus.GetFlushAllStateRequest) returns(milvus.GetFlushAllStateResponse) {}
@ -369,6 +370,18 @@ message GetRecoveryInfoRequest {
int64 partitionID = 3; int64 partitionID = 3;
} }
message GetRecoveryInfoResponseV2 {
common.Status status = 1;
repeated VchannelInfo channels = 2;
repeated SegmentInfo segments = 3;
}
message GetRecoveryInfoRequestV2 {
common.MsgBase base = 1;
int64 collectionID = 2;
repeated int64 partitionIDs = 3;
}
message GetSegmentsByStatesRequest { message GetSegmentsByStatesRequest {
common.MsgBase base = 1; common.MsgBase base = 1;
int64 collectionID = 2; int64 collectionID = 2;

File diff suppressed because it is too large Load Diff

View File

@ -42,8 +42,8 @@ type MockBalancer_AssignChannel_Call struct {
} }
// AssignChannel is a helper method to define mock.On call // AssignChannel is a helper method to define mock.On call
// - channels []*meta.DmChannel // - channels []*meta.DmChannel
// - nodes []int64 // - nodes []int64
func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}) *MockBalancer_AssignChannel_Call { func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}) *MockBalancer_AssignChannel_Call {
return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes)} return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes)}
} }

View File

@ -282,24 +282,29 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
defer suite.TearDownTest() defer suite.TearDownTest()
balancer := suite.balancer balancer := suite.balancer
collection := utils.CreateTestCollection(1, 1) collection := utils.CreateTestCollection(1, 1)
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
}, },
{ {
SegmentID: 2, ID: 2,
PartitionID: 1,
}, },
{ {
SegmentID: 3, ID: 3,
PartitionID: 1,
}, },
{ {
SegmentID: 4, ID: 4,
PartitionID: 1,
}, },
{ {
SegmentID: 5, ID: 5,
PartitionID: 1,
}, },
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(nil, segments, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil)
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1) balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1)
collection.LoadPercentage = 100 collection.LoadPercentage = 100
@ -308,8 +313,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
balancer.meta.CollectionManager.PutCollection(collection) balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
suite.broker.ExpectedCalls = nil suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe() suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(nil, segments, nil)
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0) suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0)
for node, s := range c.distributions { for node, s := range c.distributions {
@ -344,8 +348,8 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
shouldMock bool shouldMock bool
distributions map[int64][]*meta.Segment distributions map[int64][]*meta.Segment
distributionChannels map[int64][]*meta.DmChannel distributionChannels map[int64][]*meta.DmChannel
segmentInCurrent []*datapb.SegmentBinlogs segmentInCurrent []*datapb.SegmentInfo
segmentInNext []*datapb.SegmentBinlogs segmentInNext []*datapb.SegmentInfo
expectPlans []SegmentAssignPlan expectPlans []SegmentAssignPlan
expectChannelPlans []ChannelAssignPlan expectChannelPlans []ChannelAssignPlan
}{ }{
@ -366,39 +370,49 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
{SegmentInfo: &datapb.SegmentInfo{ID: 5, CollectionID: 1, NumOfRows: 10}, Node: 3}, {SegmentInfo: &datapb.SegmentInfo{ID: 5, CollectionID: 1, NumOfRows: 10}, Node: 3},
}, },
}, },
segmentInCurrent: []*datapb.SegmentBinlogs{ segmentInCurrent: []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
}, },
{ {
SegmentID: 2, ID: 2,
PartitionID: 1,
}, },
{ {
SegmentID: 3, ID: 3,
PartitionID: 1,
}, },
{ {
SegmentID: 4, ID: 4,
PartitionID: 1,
}, },
{ {
SegmentID: 5, ID: 5,
PartitionID: 1,
}, },
}, },
segmentInNext: []*datapb.SegmentBinlogs{ segmentInNext: []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
}, },
{ {
SegmentID: 2, ID: 2,
PartitionID: 1,
}, },
{ {
SegmentID: 3, ID: 3,
PartitionID: 1,
}, },
{ {
SegmentID: 4, ID: 4,
PartitionID: 1,
}, },
{ {
SegmentID: 5, ID: 5,
PartitionID: 1,
}, },
}, },
distributionChannels: map[int64][]*meta.DmChannel{ distributionChannels: map[int64][]*meta.DmChannel{
@ -417,7 +431,6 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
{Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v3"}, Node: 3}, From: 3, To: 1, ReplicaID: 1}, {Channel: &meta.DmChannel{VchannelInfo: &datapb.VchannelInfo{CollectionID: 1, ChannelName: "v3"}, Node: 3}, From: 3, To: 1, ReplicaID: 1},
}, },
}, },
{ {
name: "not exist in next target", name: "not exist in next target",
nodes: []int64{1, 2, 3}, nodes: []int64{1, 2, 3},
@ -435,29 +448,36 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
{SegmentInfo: &datapb.SegmentInfo{ID: 5, CollectionID: 1, NumOfRows: 10}, Node: 3}, {SegmentInfo: &datapb.SegmentInfo{ID: 5, CollectionID: 1, NumOfRows: 10}, Node: 3},
}, },
}, },
segmentInCurrent: []*datapb.SegmentBinlogs{ segmentInCurrent: []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
}, },
{ {
SegmentID: 2, ID: 2,
PartitionID: 1,
}, },
{ {
SegmentID: 3, ID: 3,
PartitionID: 1,
}, },
{ {
SegmentID: 4, ID: 4,
PartitionID: 1,
}, },
{ {
SegmentID: 5, ID: 5,
PartitionID: 1,
}, },
}, },
segmentInNext: []*datapb.SegmentBinlogs{ segmentInNext: []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
}, },
{ {
SegmentID: 2, ID: 2,
PartitionID: 1,
}, },
}, },
distributionChannels: map[int64][]*meta.DmChannel{ distributionChannels: map[int64][]*meta.DmChannel{
@ -482,7 +502,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
balancer := suite.balancer balancer := suite.balancer
collection := utils.CreateTestCollection(1, 1) collection := utils.CreateTestCollection(1, 1)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(nil, c.segmentInCurrent, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, c.segmentInCurrent, nil)
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1) balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1)
collection.LoadPercentage = 100 collection.LoadPercentage = 100
@ -491,8 +511,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnPartStopping() {
balancer.meta.CollectionManager.PutCollection(collection) balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
suite.broker.ExpectedCalls = nil suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe() suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, c.segmentInNext, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(nil, c.segmentInNext, nil)
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0) suite.mockScheduler.Mock.On("GetNodeChannelDelta", mock.Anything).Return(0)
for node, s := range c.distributions { for node, s := range c.distributions {
@ -572,24 +591,29 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() {
defer suite.TearDownTest() defer suite.TearDownTest()
balancer := suite.balancer balancer := suite.balancer
collection := utils.CreateTestCollection(1, 1) collection := utils.CreateTestCollection(1, 1)
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
}, },
{ {
SegmentID: 2, ID: 2,
PartitionID: 1,
}, },
{ {
SegmentID: 3, ID: 3,
PartitionID: 1,
}, },
{ {
SegmentID: 4, ID: 4,
PartitionID: 1,
}, },
{ {
SegmentID: 5, ID: 5,
PartitionID: 1,
}, },
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(nil, segments, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(nil, segments, nil)
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1) balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1)
collection.LoadPercentage = 100 collection.LoadPercentage = 100

View File

@ -232,7 +232,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
notExistedNodes []int64 notExistedNodes []int64
collectionIDs []int64 collectionIDs []int64
replicaIDs []int64 replicaIDs []int64
collectionsSegments [][]*datapb.SegmentBinlogs collectionsSegments [][]*datapb.SegmentInfo
states []session.State states []session.State
shouldMock bool shouldMock bool
distributions map[int64][]*meta.Segment distributions map[int64][]*meta.Segment
@ -245,9 +245,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
nodes: []int64{1, 2}, nodes: []int64{1, 2},
collectionIDs: []int64{1}, collectionIDs: []int64{1},
replicaIDs: []int64{1}, replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{ collectionsSegments: [][]*datapb.SegmentInfo{
{ {
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, {ID: 1, PartitionID: 1},
{ID: 2, PartitionID: 1},
{ID: 3, PartitionID: 1},
}, },
}, },
states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, states: []session.State{session.NodeStateNormal, session.NodeStateNormal},
@ -268,9 +270,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
nodes: []int64{1, 2}, nodes: []int64{1, 2},
collectionIDs: []int64{1}, collectionIDs: []int64{1},
replicaIDs: []int64{1}, replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{ collectionsSegments: [][]*datapb.SegmentInfo{
{ {
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, {ID: 1, PartitionID: 1},
{ID: 2, PartitionID: 1},
{ID: 3, PartitionID: 1},
}, },
}, },
states: []session.State{session.NodeStateNormal, session.NodeStateNormal}, states: []session.State{session.NodeStateNormal, session.NodeStateNormal},
@ -288,6 +292,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
}, },
} }
suite.mockScheduler.EXPECT().GetSegmentTaskNum().Return(0)
for _, c := range cases { for _, c := range cases {
suite.Run(c.name, func() { suite.Run(c.name, func() {
suite.SetupSuite() suite.SetupSuite()
@ -299,9 +304,8 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceOneRound() {
for i := range c.collectionIDs { for i := range c.collectionIDs {
collection := utils.CreateTestCollection(c.collectionIDs[i], int32(c.replicaIDs[i])) collection := utils.CreateTestCollection(c.collectionIDs[i], int32(c.replicaIDs[i]))
collections = append(collections, collection) collections = append(collections, collection)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, c.collectionIDs[i], c.replicaIDs[i]).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, c.collectionIDs[i]).Return(
nil, c.collectionsSegments[i], nil) nil, c.collectionsSegments[i], nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionIDs[i]).Return([]int64{c.collectionIDs[i]}, nil).Maybe()
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(c.collectionIDs[i], c.collectionIDs[i]) balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(c.collectionIDs[i], c.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionIDs[i], c.collectionIDs[i]) balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionIDs[i], c.collectionIDs[i])
collection.LoadPercentage = 100 collection.LoadPercentage = 100
@ -344,7 +348,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() {
notExistedNodes []int64 notExistedNodes []int64
collectionIDs []int64 collectionIDs []int64
replicaIDs []int64 replicaIDs []int64
collectionsSegments [][]*datapb.SegmentBinlogs collectionsSegments [][]*datapb.SegmentInfo
states []session.State states []session.State
shouldMock bool shouldMock bool
distributions []map[int64][]*meta.Segment distributions []map[int64][]*meta.Segment
@ -354,12 +358,14 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() {
nodes: []int64{1, 2, 3}, nodes: []int64{1, 2, 3},
collectionIDs: []int64{1, 2}, collectionIDs: []int64{1, 2},
replicaIDs: []int64{1, 2}, replicaIDs: []int64{1, 2},
collectionsSegments: [][]*datapb.SegmentBinlogs{ collectionsSegments: [][]*datapb.SegmentInfo{
{ {
{SegmentID: 1}, {SegmentID: 3}, {ID: 1, PartitionID: 1},
{ID: 3, PartitionID: 1},
}, },
{ {
{SegmentID: 2}, {SegmentID: 4}, {ID: 2, PartitionID: 2},
{ID: 4, PartitionID: 2},
}, },
}, },
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal},
@ -401,14 +407,14 @@ func (suite *ScoreBasedBalancerTestSuite) TestBalanceMultiRound() {
defer suite.TearDownTest() defer suite.TearDownTest()
balancer := suite.balancer balancer := suite.balancer
suite.mockScheduler.EXPECT().GetSegmentTaskNum().Return(0)
//1. set up target for multi collections //1. set up target for multi collections
collections := make([]*meta.Collection, 0, len(balanceCase.collectionIDs)) collections := make([]*meta.Collection, 0, len(balanceCase.collectionIDs))
for i := range balanceCase.collectionIDs { for i := range balanceCase.collectionIDs {
collection := utils.CreateTestCollection(balanceCase.collectionIDs[i], int32(balanceCase.replicaIDs[i])) collection := utils.CreateTestCollection(balanceCase.collectionIDs[i], int32(balanceCase.replicaIDs[i]))
collections = append(collections, collection) collections = append(collections, collection)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, balanceCase.collectionIDs[i], balanceCase.replicaIDs[i]).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, balanceCase.collectionIDs[i]).Return(
nil, balanceCase.collectionsSegments[i], nil) nil, balanceCase.collectionsSegments[i], nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, balanceCase.collectionIDs[i]).Return([]int64{balanceCase.collectionIDs[i]}, nil).Maybe()
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i]) balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i]) balancer.targetMgr.UpdateCollectionCurrentTarget(balanceCase.collectionIDs[i], balanceCase.collectionIDs[i])
collection.LoadPercentage = 100 collection.LoadPercentage = 100
@ -458,7 +464,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
notExistedNodes []int64 notExistedNodes []int64
collectionIDs []int64 collectionIDs []int64
replicaIDs []int64 replicaIDs []int64
collectionsSegments [][]*datapb.SegmentBinlogs collectionsSegments [][]*datapb.SegmentInfo
states []session.State states []session.State
shouldMock bool shouldMock bool
distributions map[int64][]*meta.Segment distributions map[int64][]*meta.Segment
@ -472,9 +478,11 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
outBoundNodes: []int64{}, outBoundNodes: []int64{},
collectionIDs: []int64{1}, collectionIDs: []int64{1},
replicaIDs: []int64{1}, replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{ collectionsSegments: [][]*datapb.SegmentInfo{
{ {
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, {ID: 1, PartitionID: 1},
{ID: 2, PartitionID: 1},
{ID: 3, PartitionID: 1},
}, },
}, },
states: []session.State{session.NodeStateStopping, session.NodeStateNormal, session.NodeStateNormal}, states: []session.State{session.NodeStateStopping, session.NodeStateNormal, session.NodeStateNormal},
@ -501,9 +509,9 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
outBoundNodes: []int64{}, outBoundNodes: []int64{},
collectionIDs: []int64{1}, collectionIDs: []int64{1},
replicaIDs: []int64{1}, replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{ collectionsSegments: [][]*datapb.SegmentInfo{
{ {
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, {ID: 1}, {ID: 2}, {ID: 3},
}, },
}, },
states: []session.State{session.NodeStateStopping, session.NodeStateStopping, session.NodeStateStopping}, states: []session.State{session.NodeStateStopping, session.NodeStateStopping, session.NodeStateStopping},
@ -525,9 +533,9 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
outBoundNodes: []int64{1, 2, 3}, outBoundNodes: []int64{1, 2, 3},
collectionIDs: []int64{1}, collectionIDs: []int64{1},
replicaIDs: []int64{1}, replicaIDs: []int64{1},
collectionsSegments: [][]*datapb.SegmentBinlogs{ collectionsSegments: [][]*datapb.SegmentInfo{
{ {
{SegmentID: 1}, {SegmentID: 2}, {SegmentID: 3}, {ID: 1}, {ID: 2}, {ID: 3},
}, },
}, },
states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal}, states: []session.State{session.NodeStateNormal, session.NodeStateNormal, session.NodeStateNormal},
@ -558,9 +566,8 @@ func (suite *ScoreBasedBalancerTestSuite) TestStoppedBalance() {
for i := range c.collectionIDs { for i := range c.collectionIDs {
collection := utils.CreateTestCollection(c.collectionIDs[i], int32(c.replicaIDs[i])) collection := utils.CreateTestCollection(c.collectionIDs[i], int32(c.replicaIDs[i]))
collections = append(collections, collection) collections = append(collections, collection)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, c.collectionIDs[i], c.replicaIDs[i]).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, c.collectionIDs[i]).Return(
nil, c.collectionsSegments[i], nil) nil, c.collectionsSegments[i], nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, c.collectionIDs[i]).Return([]int64{c.collectionIDs[i]}, nil).Maybe()
balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(c.collectionIDs[i], c.collectionIDs[i]) balancer.targetMgr.UpdateCollectionNextTargetWithPartitions(c.collectionIDs[i], c.collectionIDs[i])
balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionIDs[i], c.collectionIDs[i]) balancer.targetMgr.UpdateCollectionCurrentTarget(c.collectionIDs[i], c.collectionIDs[i])
collection.LoadPercentage = 100 collection.LoadPercentage = 100

View File

@ -114,7 +114,7 @@ func (suite *ChannelCheckerTestSuite) TestLoadChannel() {
}, },
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, nil, nil) channels, nil, nil)
checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -153,9 +153,9 @@ func (suite *ChannelCheckerTestSuite) TestRepeatedChannels() {
err = checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) err = checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
suite.NoError(err) suite.NoError(err)
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
InsertChannel: "test-insert-channel", InsertChannel: "test-insert-channel",
}, },
} }
@ -166,7 +166,7 @@ func (suite *ChannelCheckerTestSuite) TestRepeatedChannels() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil) channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
checker.dist.ChannelDistManager.Update(1, utils.CreateTestChannel(1, 1, 1, "test-insert-channel")) checker.dist.ChannelDistManager.Update(1, utils.CreateTestChannel(1, 1, 1, "test-insert-channel"))

View File

@ -111,13 +111,14 @@ func (suite *SegmentCheckerTestSuite) TestLoadSegments() {
checker.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, 2) checker.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, 2)
// set target // set target
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
InsertChannel: "test-insert-channel", InsertChannel: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
nil, segments, nil) nil, segments, nil)
checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -166,13 +167,14 @@ func (suite *SegmentCheckerTestSuite) TestReleaseRepeatedSegments() {
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
// set target // set target
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
InsertChannel: "test-insert-channel", InsertChannel: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
nil, segments, nil) nil, segments, nil)
checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -206,9 +208,10 @@ func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() {
checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) checker.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) checker.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 3, ID: 3,
PartitionID: 1,
InsertChannel: "test-insert-channel", InsertChannel: "test-insert-channel",
}, },
} }
@ -219,7 +222,7 @@ func (suite *SegmentCheckerTestSuite) TestReleaseGrowingSegments() {
SeekPosition: &msgpb.MsgPosition{Timestamp: 10}, SeekPosition: &msgpb.MsgPosition{Timestamp: 10},
}, },
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil) channels, segments, nil)
checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) checker.targetMgr.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
checker.targetMgr.UpdateCollectionCurrentTarget(int64(1), int64(1)) checker.targetMgr.UpdateCollectionCurrentTarget(int64(1), int64(1))

View File

@ -106,19 +106,18 @@ func (suite *JobSuite) SetupSuite() {
ChannelName: channel, ChannelName: channel,
}) })
} }
segmentBinlogs := []*datapb.SegmentInfo{}
for partition, segments := range partitions { for partition, segments := range partitions {
segmentBinlogs := []*datapb.SegmentBinlogs{}
for _, segment := range segments { for _, segment := range segments {
segmentBinlogs = append(segmentBinlogs, &datapb.SegmentBinlogs{ segmentBinlogs = append(segmentBinlogs, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
PartitionID: partition,
InsertChannel: suite.channels[collection][segment%2], InsertChannel: suite.channels[collection][segment%2],
}) })
} }
suite.broker.EXPECT().
GetRecoveryInfo(mock.Anything, collection, partition).
Return(vChannels, segmentBinlogs, nil)
} }
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(vChannels, segmentBinlogs, nil)
} }
suite.cluster = session.NewMockCluster(suite.T()) suite.cluster = session.NewMockCluster(suite.T())
@ -578,8 +577,7 @@ func (suite *JobSuite) TestLoadPartition() {
suite.meta.ResourceManager.AddResourceGroup("rg3") suite.meta.ResourceManager.AddResourceGroup("rg3")
// test load 3 replica in 1 rg, should pass rg check // test load 3 replica in 1 rg, should pass rg check
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(999)).Return([]int64{888}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(999)).Return(nil, nil, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(999), int64(888)).Return(nil, nil, nil)
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
CollectionID: 999, CollectionID: 999,
PartitionIDs: []int64{888}, PartitionIDs: []int64{888},
@ -602,8 +600,7 @@ func (suite *JobSuite) TestLoadPartition() {
suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error()) suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error())
// test load 3 replica in 3 rg, should pass rg check // test load 3 replica in 3 rg, should pass rg check
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(999)).Return([]int64{888}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(999)).Return(nil, nil, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(999), int64(888)).Return(nil, nil, nil)
req = &querypb.LoadPartitionsRequest{ req = &querypb.LoadPartitionsRequest{
CollectionID: 999, CollectionID: 999,
PartitionIDs: []int64{888}, PartitionIDs: []int64{888},

View File

@ -47,6 +47,7 @@ type Broker interface {
GetRecoveryInfo(ctx context.Context, collectionID UniqueID, partitionID UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentBinlogs, error) GetRecoveryInfo(ctx context.Context, collectionID UniqueID, partitionID UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentBinlogs, error)
GetSegmentInfo(ctx context.Context, segmentID ...UniqueID) (*datapb.GetSegmentInfoResponse, error) GetSegmentInfo(ctx context.Context, segmentID ...UniqueID) (*datapb.GetSegmentInfoResponse, error)
GetIndexInfo(ctx context.Context, collectionID UniqueID, segmentID UniqueID) ([]*querypb.FieldIndexInfo, error) GetIndexInfo(ctx context.Context, collectionID UniqueID, segmentID UniqueID) ([]*querypb.FieldIndexInfo, error)
GetRecoveryInfoV2(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentInfo, error)
} }
type CoordinatorBroker struct { type CoordinatorBroker struct {
@ -135,6 +136,32 @@ func (broker *CoordinatorBroker) GetRecoveryInfo(ctx context.Context, collection
return recoveryInfo.Channels, recoveryInfo.Binlogs, nil return recoveryInfo.Channels, recoveryInfo.Binlogs, nil
} }
func (broker *CoordinatorBroker) GetRecoveryInfoV2(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) ([]*datapb.VchannelInfo, []*datapb.SegmentInfo, error) {
ctx, cancel := context.WithTimeout(ctx, brokerRPCTimeout)
defer cancel()
getRecoveryInfoRequest := &datapb.GetRecoveryInfoRequestV2{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetRecoveryInfo),
),
CollectionID: collectionID,
PartitionIDs: partitionIDs,
}
recoveryInfo, err := broker.dataCoord.GetRecoveryInfoV2(ctx, getRecoveryInfoRequest)
if err != nil {
log.Error("get recovery info failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs), zap.Error(err))
return nil, nil, err
}
if recoveryInfo.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
err = errors.New(recoveryInfo.GetStatus().GetReason())
log.Error("get recovery info failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs), zap.Error(err))
return nil, nil, err
}
return recoveryInfo.Channels, recoveryInfo.Segments, nil
}
func (broker *CoordinatorBroker) GetSegmentInfo(ctx context.Context, ids ...UniqueID) (*datapb.GetSegmentInfoResponse, error) { func (broker *CoordinatorBroker) GetSegmentInfo(ctx context.Context, ids ...UniqueID) (*datapb.GetSegmentInfoResponse, error) {
ctx, cancel := context.WithTimeout(ctx, brokerRPCTimeout) ctx, cancel := context.WithTimeout(ctx, brokerRPCTimeout)
defer cancel() defer cancel()

View File

@ -28,6 +28,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/datapb"
) )
func TestCoordinatorBroker_GetCollectionSchema(t *testing.T) { func TestCoordinatorBroker_GetCollectionSchema(t *testing.T) {
@ -73,3 +74,43 @@ func TestCoordinatorBroker_GetCollectionSchema(t *testing.T) {
assert.Equal(t, "test_schema", schema.GetName()) assert.Equal(t, "test_schema", schema.GetName())
}) })
} }
func TestCoordinatorBroker_GetRecoveryInfo(t *testing.T) {
t.Run("normal case", func(t *testing.T) {
dc := mocks.NewDataCoord(t)
dc.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(&datapb.GetRecoveryInfoResponseV2{}, nil)
ctx := context.Background()
broker := &CoordinatorBroker{dataCoord: dc}
_, _, err := broker.GetRecoveryInfoV2(ctx, 1)
assert.NoError(t, err)
})
t.Run("get error", func(t *testing.T) {
dc := mocks.NewDataCoord(t)
fakeErr := errors.New("fake error")
dc.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(nil, fakeErr)
ctx := context.Background()
broker := &CoordinatorBroker{dataCoord: dc}
_, _, err := broker.GetRecoveryInfoV2(ctx, 1)
assert.ErrorIs(t, err, fakeErr)
})
t.Run("return non-success code", func(t *testing.T) {
dc := mocks.NewDataCoord(t)
dc.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(&datapb.GetRecoveryInfoResponseV2{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
},
}, nil)
ctx := context.Background()
broker := &CoordinatorBroker{dataCoord: dc}
_, _, err := broker.GetRecoveryInfoV2(ctx, 1)
assert.Error(t, err)
})
}

View File

@ -56,8 +56,8 @@ type MockBroker_GetCollectionSchema_Call struct {
} }
// GetCollectionSchema is a helper method to define mock.On call // GetCollectionSchema is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - collectionID int64 // - collectionID int64
func (_e *MockBroker_Expecter) GetCollectionSchema(ctx interface{}, collectionID interface{}) *MockBroker_GetCollectionSchema_Call { func (_e *MockBroker_Expecter) GetCollectionSchema(ctx interface{}, collectionID interface{}) *MockBroker_GetCollectionSchema_Call {
return &MockBroker_GetCollectionSchema_Call{Call: _e.mock.On("GetCollectionSchema", ctx, collectionID)} return &MockBroker_GetCollectionSchema_Call{Call: _e.mock.On("GetCollectionSchema", ctx, collectionID)}
} }
@ -103,9 +103,9 @@ type MockBroker_GetIndexInfo_Call struct {
} }
// GetIndexInfo is a helper method to define mock.On call // GetIndexInfo is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - collectionID int64 // - collectionID int64
// - segmentID int64 // - segmentID int64
func (_e *MockBroker_Expecter) GetIndexInfo(ctx interface{}, collectionID interface{}, segmentID interface{}) *MockBroker_GetIndexInfo_Call { func (_e *MockBroker_Expecter) GetIndexInfo(ctx interface{}, collectionID interface{}, segmentID interface{}) *MockBroker_GetIndexInfo_Call {
return &MockBroker_GetIndexInfo_Call{Call: _e.mock.On("GetIndexInfo", ctx, collectionID, segmentID)} return &MockBroker_GetIndexInfo_Call{Call: _e.mock.On("GetIndexInfo", ctx, collectionID, segmentID)}
} }
@ -151,8 +151,8 @@ type MockBroker_GetPartitions_Call struct {
} }
// GetPartitions is a helper method to define mock.On call // GetPartitions is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - collectionID int64 // - collectionID int64
func (_e *MockBroker_Expecter) GetPartitions(ctx interface{}, collectionID interface{}) *MockBroker_GetPartitions_Call { func (_e *MockBroker_Expecter) GetPartitions(ctx interface{}, collectionID interface{}) *MockBroker_GetPartitions_Call {
return &MockBroker_GetPartitions_Call{Call: _e.mock.On("GetPartitions", ctx, collectionID)} return &MockBroker_GetPartitions_Call{Call: _e.mock.On("GetPartitions", ctx, collectionID)}
} }
@ -207,9 +207,9 @@ type MockBroker_GetRecoveryInfo_Call struct {
} }
// GetRecoveryInfo is a helper method to define mock.On call // GetRecoveryInfo is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - collectionID int64 // - collectionID int64
// - partitionID int64 // - partitionID int64
func (_e *MockBroker_Expecter) GetRecoveryInfo(ctx interface{}, collectionID interface{}, partitionID interface{}) *MockBroker_GetRecoveryInfo_Call { func (_e *MockBroker_Expecter) GetRecoveryInfo(ctx interface{}, collectionID interface{}, partitionID interface{}) *MockBroker_GetRecoveryInfo_Call {
return &MockBroker_GetRecoveryInfo_Call{Call: _e.mock.On("GetRecoveryInfo", ctx, collectionID, partitionID)} return &MockBroker_GetRecoveryInfo_Call{Call: _e.mock.On("GetRecoveryInfo", ctx, collectionID, partitionID)}
} }
@ -226,6 +226,77 @@ func (_c *MockBroker_GetRecoveryInfo_Call) Return(_a0 []*datapb.VchannelInfo, _a
return _c return _c
} }
// GetRecoveryInfoV2 provides a mock function with given fields: ctx, collectionID, partitionIDs
func (_m *MockBroker) GetRecoveryInfoV2(ctx context.Context, collectionID int64, partitionIDs ...int64) ([]*datapb.VchannelInfo, []*datapb.SegmentInfo, error) {
_va := make([]interface{}, len(partitionIDs))
for _i := range partitionIDs {
_va[_i] = partitionIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, collectionID)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 []*datapb.VchannelInfo
if rf, ok := ret.Get(0).(func(context.Context, int64, ...int64) []*datapb.VchannelInfo); ok {
r0 = rf(ctx, collectionID, partitionIDs...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*datapb.VchannelInfo)
}
}
var r1 []*datapb.SegmentInfo
if rf, ok := ret.Get(1).(func(context.Context, int64, ...int64) []*datapb.SegmentInfo); ok {
r1 = rf(ctx, collectionID, partitionIDs...)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).([]*datapb.SegmentInfo)
}
}
var r2 error
if rf, ok := ret.Get(2).(func(context.Context, int64, ...int64) error); ok {
r2 = rf(ctx, collectionID, partitionIDs...)
} else {
r2 = ret.Error(2)
}
return r0, r1, r2
}
// MockBroker_GetRecoveryInfoV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRecoveryInfoV2'
type MockBroker_GetRecoveryInfoV2_Call struct {
*mock.Call
}
// GetRecoveryInfoV2 is a helper method to define mock.On call
// - ctx context.Context
// - collectionID int64
// - partitionIDs ...int64
func (_e *MockBroker_Expecter) GetRecoveryInfoV2(ctx interface{}, collectionID interface{}, partitionIDs ...interface{}) *MockBroker_GetRecoveryInfoV2_Call {
return &MockBroker_GetRecoveryInfoV2_Call{Call: _e.mock.On("GetRecoveryInfoV2",
append([]interface{}{ctx, collectionID}, partitionIDs...)...)}
}
func (_c *MockBroker_GetRecoveryInfoV2_Call) Run(run func(ctx context.Context, collectionID int64, partitionIDs ...int64)) *MockBroker_GetRecoveryInfoV2_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]int64, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(int64)
}
}
run(args[0].(context.Context), args[1].(int64), variadicArgs...)
})
return _c
}
func (_c *MockBroker_GetRecoveryInfoV2_Call) Return(_a0 []*datapb.VchannelInfo, _a1 []*datapb.SegmentInfo, _a2 error) *MockBroker_GetRecoveryInfoV2_Call {
_c.Call.Return(_a0, _a1, _a2)
return _c
}
// GetSegmentInfo provides a mock function with given fields: ctx, segmentID // GetSegmentInfo provides a mock function with given fields: ctx, segmentID
func (_m *MockBroker) GetSegmentInfo(ctx context.Context, segmentID ...int64) (*datapb.GetSegmentInfoResponse, error) { func (_m *MockBroker) GetSegmentInfo(ctx context.Context, segmentID ...int64) (*datapb.GetSegmentInfoResponse, error) {
_va := make([]interface{}, len(segmentID)) _va := make([]interface{}, len(segmentID))
@ -262,8 +333,8 @@ type MockBroker_GetSegmentInfo_Call struct {
} }
// GetSegmentInfo is a helper method to define mock.On call // GetSegmentInfo is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - segmentID ...int64 // - segmentID ...int64
func (_e *MockBroker_Expecter) GetSegmentInfo(ctx interface{}, segmentID ...interface{}) *MockBroker_GetSegmentInfo_Call { func (_e *MockBroker_Expecter) GetSegmentInfo(ctx interface{}, segmentID ...interface{}) *MockBroker_GetSegmentInfo_Call {
return &MockBroker_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", return &MockBroker_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo",
append([]interface{}{ctx}, segmentID...)...)} append([]interface{}{ctx}, segmentID...)...)}

View File

@ -220,7 +220,7 @@ type MockStore_ReleaseCollection_Call struct {
} }
// ReleaseCollection is a helper method to define mock.On call // ReleaseCollection is a helper method to define mock.On call
// - collection int64 // - collection int64
func (_e *MockStore_Expecter) ReleaseCollection(collection interface{}) *MockStore_ReleaseCollection_Call { func (_e *MockStore_Expecter) ReleaseCollection(collection interface{}) *MockStore_ReleaseCollection_Call {
return &MockStore_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", collection)} return &MockStore_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", collection)}
} }
@ -264,8 +264,8 @@ type MockStore_ReleasePartition_Call struct {
} }
// ReleasePartition is a helper method to define mock.On call // ReleasePartition is a helper method to define mock.On call
// - collection int64 // - collection int64
// - partitions ...int64 // - partitions ...int64
func (_e *MockStore_Expecter) ReleasePartition(collection interface{}, partitions ...interface{}) *MockStore_ReleasePartition_Call { func (_e *MockStore_Expecter) ReleasePartition(collection interface{}, partitions ...interface{}) *MockStore_ReleasePartition_Call {
return &MockStore_ReleasePartition_Call{Call: _e.mock.On("ReleasePartition", return &MockStore_ReleasePartition_Call{Call: _e.mock.On("ReleasePartition",
append([]interface{}{collection}, partitions...)...)} append([]interface{}{collection}, partitions...)...)}
@ -309,8 +309,8 @@ type MockStore_ReleaseReplica_Call struct {
} }
// ReleaseReplica is a helper method to define mock.On call // ReleaseReplica is a helper method to define mock.On call
// - collection int64 // - collection int64
// - replica int64 // - replica int64
func (_e *MockStore_Expecter) ReleaseReplica(collection interface{}, replica interface{}) *MockStore_ReleaseReplica_Call { func (_e *MockStore_Expecter) ReleaseReplica(collection interface{}, replica interface{}) *MockStore_ReleaseReplica_Call {
return &MockStore_ReleaseReplica_Call{Call: _e.mock.On("ReleaseReplica", collection, replica)} return &MockStore_ReleaseReplica_Call{Call: _e.mock.On("ReleaseReplica", collection, replica)}
} }
@ -347,7 +347,7 @@ type MockStore_ReleaseReplicas_Call struct {
} }
// ReleaseReplicas is a helper method to define mock.On call // ReleaseReplicas is a helper method to define mock.On call
// - collectionID int64 // - collectionID int64
func (_e *MockStore_Expecter) ReleaseReplicas(collectionID interface{}) *MockStore_ReleaseReplicas_Call { func (_e *MockStore_Expecter) ReleaseReplicas(collectionID interface{}) *MockStore_ReleaseReplicas_Call {
return &MockStore_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", collectionID)} return &MockStore_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", collectionID)}
} }
@ -384,7 +384,7 @@ type MockStore_RemoveResourceGroup_Call struct {
} }
// RemoveResourceGroup is a helper method to define mock.On call // RemoveResourceGroup is a helper method to define mock.On call
// - rgName string // - rgName string
func (_e *MockStore_Expecter) RemoveResourceGroup(rgName interface{}) *MockStore_RemoveResourceGroup_Call { func (_e *MockStore_Expecter) RemoveResourceGroup(rgName interface{}) *MockStore_RemoveResourceGroup_Call {
return &MockStore_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", rgName)} return &MockStore_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", rgName)}
} }
@ -428,8 +428,8 @@ type MockStore_SaveCollection_Call struct {
} }
// SaveCollection is a helper method to define mock.On call // SaveCollection is a helper method to define mock.On call
// - collection *querypb.CollectionLoadInfo // - collection *querypb.CollectionLoadInfo
// - partitions ...*querypb.PartitionLoadInfo // - partitions ...*querypb.PartitionLoadInfo
func (_e *MockStore_Expecter) SaveCollection(collection interface{}, partitions ...interface{}) *MockStore_SaveCollection_Call { func (_e *MockStore_Expecter) SaveCollection(collection interface{}, partitions ...interface{}) *MockStore_SaveCollection_Call {
return &MockStore_SaveCollection_Call{Call: _e.mock.On("SaveCollection", return &MockStore_SaveCollection_Call{Call: _e.mock.On("SaveCollection",
append([]interface{}{collection}, partitions...)...)} append([]interface{}{collection}, partitions...)...)}
@ -479,7 +479,7 @@ type MockStore_SavePartition_Call struct {
} }
// SavePartition is a helper method to define mock.On call // SavePartition is a helper method to define mock.On call
// - info ...*querypb.PartitionLoadInfo // - info ...*querypb.PartitionLoadInfo
func (_e *MockStore_Expecter) SavePartition(info ...interface{}) *MockStore_SavePartition_Call { func (_e *MockStore_Expecter) SavePartition(info ...interface{}) *MockStore_SavePartition_Call {
return &MockStore_SavePartition_Call{Call: _e.mock.On("SavePartition", return &MockStore_SavePartition_Call{Call: _e.mock.On("SavePartition",
append([]interface{}{}, info...)...)} append([]interface{}{}, info...)...)}
@ -523,7 +523,7 @@ type MockStore_SaveReplica_Call struct {
} }
// SaveReplica is a helper method to define mock.On call // SaveReplica is a helper method to define mock.On call
// - replica *querypb.Replica // - replica *querypb.Replica
func (_e *MockStore_Expecter) SaveReplica(replica interface{}) *MockStore_SaveReplica_Call { func (_e *MockStore_Expecter) SaveReplica(replica interface{}) *MockStore_SaveReplica_Call {
return &MockStore_SaveReplica_Call{Call: _e.mock.On("SaveReplica", replica)} return &MockStore_SaveReplica_Call{Call: _e.mock.On("SaveReplica", replica)}
} }
@ -566,7 +566,7 @@ type MockStore_SaveResourceGroup_Call struct {
} }
// SaveResourceGroup is a helper method to define mock.On call // SaveResourceGroup is a helper method to define mock.On call
// - rgs ...*querypb.ResourceGroup // - rgs ...*querypb.ResourceGroup
func (_e *MockStore_Expecter) SaveResourceGroup(rgs ...interface{}) *MockStore_SaveResourceGroup_Call { func (_e *MockStore_Expecter) SaveResourceGroup(rgs ...interface{}) *MockStore_SaveResourceGroup_Call {
return &MockStore_SaveResourceGroup_Call{Call: _e.mock.On("SaveResourceGroup", return &MockStore_SaveResourceGroup_Call{Call: _e.mock.On("SaveResourceGroup",
append([]interface{}{}, rgs...)...)} append([]interface{}{}, rgs...)...)}

View File

@ -21,12 +21,12 @@ import (
"sync" "sync"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
"github.com/samber/lo"
"go.uber.org/zap"
) )
type TargetScope = int32 type TargetScope = int32
@ -66,7 +66,7 @@ func (mgr *TargetManager) UpdateCollectionCurrentTarget(collectionID int64, part
log := log.With(zap.Int64("collectionID", collectionID), log := log.With(zap.Int64("collectionID", collectionID),
zap.Int64s("PartitionIDs", partitionIDs)) zap.Int64s("PartitionIDs", partitionIDs))
log.Info("start to update current target for collection") log.Debug("start to update current target for collection")
newTarget := mgr.next.getCollectionTarget(collectionID) newTarget := mgr.next.getCollectionTarget(collectionID)
if newTarget == nil || newTarget.IsEmpty() { if newTarget == nil || newTarget.IsEmpty() {
@ -76,7 +76,7 @@ func (mgr *TargetManager) UpdateCollectionCurrentTarget(collectionID int64, part
mgr.current.updateCollectionTarget(collectionID, newTarget) mgr.current.updateCollectionTarget(collectionID, newTarget)
mgr.next.removeCollectionTarget(collectionID) mgr.next.removeCollectionTarget(collectionID)
log.Info("finish to update current target for collection", log.Debug("finish to update current target for collection",
zap.Int64s("segments", newTarget.GetAllSegmentIDs()), zap.Int64s("segments", newTarget.GetAllSegmentIDs()),
zap.Strings("channels", newTarget.GetAllDmChannelNames())) zap.Strings("channels", newTarget.GetAllDmChannelNames()))
} }
@ -115,9 +115,10 @@ func (mgr *TargetManager) UpdateCollectionNextTarget(collectionID int64) error {
} }
func (mgr *TargetManager) updateCollectionNextTarget(collectionID int64, partitionIDs ...int64) error { func (mgr *TargetManager) updateCollectionNextTarget(collectionID int64, partitionIDs ...int64) error {
log := log.With(zap.Int64("collectionID", collectionID)) log := log.With(zap.Int64("collectionID", collectionID),
zap.Int64s("PartitionIDs", partitionIDs))
log.Info("start to update next targets for collection") log.Debug("start to update next targets for collection")
newTarget, err := mgr.PullNextTarget(mgr.broker, collectionID, partitionIDs...) newTarget, err := mgr.PullNextTarget(mgr.broker, collectionID, partitionIDs...)
if err != nil { if err != nil {
log.Error("failed to get next targets for collection", log.Error("failed to get next targets for collection",
@ -127,14 +128,14 @@ func (mgr *TargetManager) updateCollectionNextTarget(collectionID int64, partiti
mgr.next.updateCollectionTarget(collectionID, newTarget) mgr.next.updateCollectionTarget(collectionID, newTarget)
log.Info("finish to update next targets for collection", log.Debug("finish to update next targets for collection",
zap.Int64s("segments", newTarget.GetAllSegmentIDs()), zap.Int64s("segments", newTarget.GetAllSegmentIDs()),
zap.Strings("channels", newTarget.GetAllDmChannelNames())) zap.Strings("channels", newTarget.GetAllDmChannelNames()))
return nil return nil
} }
func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, chosenPartitionIDs ...int64) (*CollectionTarget, error) { func (mgr *TargetManager) PullNextTargetV1(broker Broker, collectionID int64, chosenPartitionIDs ...int64) (*CollectionTarget, error) {
log.Info("start to pull next targets for partition", log.Info("start to pull next targets for partition",
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64s("chosenPartitionIDs", chosenPartitionIDs)) zap.Int64s("chosenPartitionIDs", chosenPartitionIDs))
@ -189,6 +190,56 @@ func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, chos
return NewCollectionTarget(segments, dmChannels), nil return NewCollectionTarget(segments, dmChannels), nil
} }
func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, chosenPartitionIDs ...int64) (*CollectionTarget, error) {
log.Debug("start to pull next targets for collection",
zap.Int64("collectionID", collectionID),
zap.Int64s("chosenPartitionIDs", chosenPartitionIDs))
channelInfos := make(map[string][]*datapb.VchannelInfo)
segments := make(map[int64]*datapb.SegmentInfo, 0)
dmChannels := make(map[string]*DmChannel)
if len(chosenPartitionIDs) == 0 {
return NewCollectionTarget(segments, dmChannels), nil
}
tryPullNextTargetV1 := func() (*CollectionTarget, error) {
// for rolling upgrade, when call GetRecoveryInfoV2 failed, back to retry GetRecoveryInfo
target, err := mgr.PullNextTargetV1(broker, collectionID, chosenPartitionIDs...)
if err != nil {
return nil, err
}
return target, nil
}
// we should pull `channel targets` from all partitions because QueryNodes need to load
// the complete growing segments. And we should pull `segments targets` only from the chosen partitions.
vChannelInfos, segmentInfos, err := broker.GetRecoveryInfoV2(context.TODO(), collectionID)
if err != nil {
if funcutil.IsGrpcErr(err) {
return tryPullNextTargetV1()
}
return nil, err
}
for _, info := range vChannelInfos {
channelInfos[info.GetChannelName()] = append(channelInfos[info.GetChannelName()], info)
}
partitionSet := typeutil.NewUniqueSet(chosenPartitionIDs...)
for _, segmentInfo := range segmentInfos {
if partitionSet.Contain(segmentInfo.GetPartitionID()) {
segments[segmentInfo.GetID()] = segmentInfo
}
}
for _, infos := range channelInfos {
merged := mgr.mergeDmChannelInfo(infos)
dmChannels[merged.GetChannelName()] = merged
}
return NewCollectionTarget(segments, dmChannels), nil
}
func (mgr *TargetManager) mergeDmChannelInfo(infos []*datapb.VchannelInfo) *DmChannel { func (mgr *TargetManager) mergeDmChannelInfo(infos []*datapb.VchannelInfo) *DmChannel {
var dmChannel *DmChannel var dmChannel *DmChannel

View File

@ -22,6 +22,8 @@ import (
"github.com/samber/lo" "github.com/samber/lo"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
@ -115,18 +117,19 @@ func (suite *TargetManagerSuite) SetupTest() {
}) })
} }
for partition, segments := range suite.segments[collection] { allSegments := make([]*datapb.SegmentInfo, 0)
allSegments := make([]*datapb.SegmentBinlogs, 0) for partitionID, segments := range suite.segments[collection] {
for _, segment := range segments { for _, segment := range segments {
allSegments = append(allSegments, &datapb.SegmentInfo{
allSegments = append(allSegments, &datapb.SegmentBinlogs{ ID: segment,
SegmentID: segment,
InsertChannel: suite.channels[collection][0], InsertChannel: suite.channels[collection][0],
CollectionID: collection,
PartitionID: partitionID,
}) })
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments, nil)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(dmChannels, allSegments, nil)
suite.mgr.UpdateCollectionNextTargetWithPartitions(collection, suite.partitions[collection]...) suite.mgr.UpdateCollectionNextTargetWithPartitions(collection, suite.partitions[collection]...)
} }
} }
@ -181,7 +184,20 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() {
}, },
} }
nextTargetSegments := []*datapb.SegmentBinlogs{ nextTargetSegments := []*datapb.SegmentInfo{
{
ID: 11,
PartitionID: 1,
InsertChannel: "channel-1",
},
{
ID: 12,
PartitionID: 1,
InsertChannel: "channel-2",
},
}
nextTargetBinlogs := []*datapb.SegmentBinlogs{
{ {
SegmentID: 11, SegmentID: 11,
InsertChannel: "channel-1", InsertChannel: "channel-1",
@ -192,13 +208,26 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() {
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collectionID, int64(1)).Return(nextTargetChannels, nextTargetSegments, nil)
suite.mgr.UpdateCollectionNextTargetWithPartitions(collectionID, int64(1)) suite.mgr.UpdateCollectionNextTargetWithPartitions(collectionID, int64(1))
suite.assertSegments([]int64{11, 12}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) suite.assertSegments([]int64{11, 12}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget))
suite.assertChannels([]string{"channel-1", "channel-2"}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) suite.assertChannels([]string{"channel-1", "channel-2"}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget))
suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget)) suite.assertSegments([]int64{}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, CurrentTarget))
suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget))
suite.broker.ExpectedCalls = nil
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nil, nil, status.Errorf(codes.NotFound, "fake not found"))
suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collectionID, int64(1)).Return(nextTargetChannels, nextTargetBinlogs, nil)
err := suite.mgr.UpdateCollectionNextTargetWithPartitions(collectionID, int64(1))
suite.NoError(err)
err = suite.mgr.UpdateCollectionNextTargetWithPartitions(collectionID)
suite.Error(err)
err = suite.mgr.UpdateCollectionNextTarget(collectionID)
suite.NoError(err)
} }
func (suite *TargetManagerSuite) TestRemovePartition() { func (suite *TargetManagerSuite) TestRemovePartition() {

View File

@ -58,8 +58,8 @@ type MockQueryNodeServer_Delete_Call struct {
} }
// Delete is a helper method to define mock.On call // Delete is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.DeleteRequest // - _a1 *querypb.DeleteRequest
func (_e *MockQueryNodeServer_Expecter) Delete(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Delete_Call { func (_e *MockQueryNodeServer_Expecter) Delete(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Delete_Call {
return &MockQueryNodeServer_Delete_Call{Call: _e.mock.On("Delete", _a0, _a1)} return &MockQueryNodeServer_Delete_Call{Call: _e.mock.On("Delete", _a0, _a1)}
} }
@ -105,8 +105,8 @@ type MockQueryNodeServer_GetComponentStates_Call struct {
} }
// GetComponentStates is a helper method to define mock.On call // GetComponentStates is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *milvuspb.GetComponentStatesRequest // - _a1 *milvuspb.GetComponentStatesRequest
func (_e *MockQueryNodeServer_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetComponentStates_Call { func (_e *MockQueryNodeServer_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetComponentStates_Call {
return &MockQueryNodeServer_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} return &MockQueryNodeServer_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)}
} }
@ -152,8 +152,8 @@ type MockQueryNodeServer_GetDataDistribution_Call struct {
} }
// GetDataDistribution is a helper method to define mock.On call // GetDataDistribution is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.GetDataDistributionRequest // - _a1 *querypb.GetDataDistributionRequest
func (_e *MockQueryNodeServer_Expecter) GetDataDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetDataDistribution_Call { func (_e *MockQueryNodeServer_Expecter) GetDataDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetDataDistribution_Call {
return &MockQueryNodeServer_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", _a0, _a1)} return &MockQueryNodeServer_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", _a0, _a1)}
} }
@ -199,8 +199,8 @@ type MockQueryNodeServer_GetMetrics_Call struct {
} }
// GetMetrics is a helper method to define mock.On call // GetMetrics is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *milvuspb.GetMetricsRequest // - _a1 *milvuspb.GetMetricsRequest
func (_e *MockQueryNodeServer_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetMetrics_Call { func (_e *MockQueryNodeServer_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetMetrics_Call {
return &MockQueryNodeServer_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} return &MockQueryNodeServer_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)}
} }
@ -246,8 +246,8 @@ type MockQueryNodeServer_GetSegmentInfo_Call struct {
} }
// GetSegmentInfo is a helper method to define mock.On call // GetSegmentInfo is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.GetSegmentInfoRequest // - _a1 *querypb.GetSegmentInfoRequest
func (_e *MockQueryNodeServer_Expecter) GetSegmentInfo(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetSegmentInfo_Call { func (_e *MockQueryNodeServer_Expecter) GetSegmentInfo(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetSegmentInfo_Call {
return &MockQueryNodeServer_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", _a0, _a1)} return &MockQueryNodeServer_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", _a0, _a1)}
} }
@ -293,8 +293,8 @@ type MockQueryNodeServer_GetStatistics_Call struct {
} }
// GetStatistics is a helper method to define mock.On call // GetStatistics is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.GetStatisticsRequest // - _a1 *querypb.GetStatisticsRequest
func (_e *MockQueryNodeServer_Expecter) GetStatistics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatistics_Call { func (_e *MockQueryNodeServer_Expecter) GetStatistics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatistics_Call {
return &MockQueryNodeServer_GetStatistics_Call{Call: _e.mock.On("GetStatistics", _a0, _a1)} return &MockQueryNodeServer_GetStatistics_Call{Call: _e.mock.On("GetStatistics", _a0, _a1)}
} }
@ -340,8 +340,8 @@ type MockQueryNodeServer_GetStatisticsChannel_Call struct {
} }
// GetStatisticsChannel is a helper method to define mock.On call // GetStatisticsChannel is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *internalpb.GetStatisticsChannelRequest // - _a1 *internalpb.GetStatisticsChannelRequest
func (_e *MockQueryNodeServer_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatisticsChannel_Call { func (_e *MockQueryNodeServer_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatisticsChannel_Call {
return &MockQueryNodeServer_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} return &MockQueryNodeServer_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)}
} }
@ -387,8 +387,8 @@ type MockQueryNodeServer_GetTimeTickChannel_Call struct {
} }
// GetTimeTickChannel is a helper method to define mock.On call // GetTimeTickChannel is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *internalpb.GetTimeTickChannelRequest // - _a1 *internalpb.GetTimeTickChannelRequest
func (_e *MockQueryNodeServer_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetTimeTickChannel_Call { func (_e *MockQueryNodeServer_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetTimeTickChannel_Call {
return &MockQueryNodeServer_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)} return &MockQueryNodeServer_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)}
} }
@ -434,8 +434,8 @@ type MockQueryNodeServer_LoadPartitions_Call struct {
} }
// LoadPartitions is a helper method to define mock.On call // LoadPartitions is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.LoadPartitionsRequest // - _a1 *querypb.LoadPartitionsRequest
func (_e *MockQueryNodeServer_Expecter) LoadPartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadPartitions_Call { func (_e *MockQueryNodeServer_Expecter) LoadPartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadPartitions_Call {
return &MockQueryNodeServer_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", _a0, _a1)} return &MockQueryNodeServer_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", _a0, _a1)}
} }
@ -481,8 +481,8 @@ type MockQueryNodeServer_LoadSegments_Call struct {
} }
// LoadSegments is a helper method to define mock.On call // LoadSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.LoadSegmentsRequest // - _a1 *querypb.LoadSegmentsRequest
func (_e *MockQueryNodeServer_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadSegments_Call { func (_e *MockQueryNodeServer_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadSegments_Call {
return &MockQueryNodeServer_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)} return &MockQueryNodeServer_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)}
} }
@ -528,8 +528,8 @@ type MockQueryNodeServer_Query_Call struct {
} }
// Query is a helper method to define mock.On call // Query is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.QueryRequest // - _a1 *querypb.QueryRequest
func (_e *MockQueryNodeServer_Expecter) Query(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Query_Call { func (_e *MockQueryNodeServer_Expecter) Query(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Query_Call {
return &MockQueryNodeServer_Query_Call{Call: _e.mock.On("Query", _a0, _a1)} return &MockQueryNodeServer_Query_Call{Call: _e.mock.On("Query", _a0, _a1)}
} }
@ -575,8 +575,8 @@ type MockQueryNodeServer_ReleaseCollection_Call struct {
} }
// ReleaseCollection is a helper method to define mock.On call // ReleaseCollection is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleaseCollectionRequest // - _a1 *querypb.ReleaseCollectionRequest
func (_e *MockQueryNodeServer_Expecter) ReleaseCollection(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseCollection_Call { func (_e *MockQueryNodeServer_Expecter) ReleaseCollection(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseCollection_Call {
return &MockQueryNodeServer_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", _a0, _a1)} return &MockQueryNodeServer_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", _a0, _a1)}
} }
@ -622,8 +622,8 @@ type MockQueryNodeServer_ReleasePartitions_Call struct {
} }
// ReleasePartitions is a helper method to define mock.On call // ReleasePartitions is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleasePartitionsRequest // - _a1 *querypb.ReleasePartitionsRequest
func (_e *MockQueryNodeServer_Expecter) ReleasePartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleasePartitions_Call { func (_e *MockQueryNodeServer_Expecter) ReleasePartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleasePartitions_Call {
return &MockQueryNodeServer_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", _a0, _a1)} return &MockQueryNodeServer_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", _a0, _a1)}
} }
@ -669,8 +669,8 @@ type MockQueryNodeServer_ReleaseSegments_Call struct {
} }
// ReleaseSegments is a helper method to define mock.On call // ReleaseSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleaseSegmentsRequest // - _a1 *querypb.ReleaseSegmentsRequest
func (_e *MockQueryNodeServer_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseSegments_Call { func (_e *MockQueryNodeServer_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseSegments_Call {
return &MockQueryNodeServer_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)} return &MockQueryNodeServer_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)}
} }
@ -716,8 +716,8 @@ type MockQueryNodeServer_Search_Call struct {
} }
// Search is a helper method to define mock.On call // Search is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SearchRequest // - _a1 *querypb.SearchRequest
func (_e *MockQueryNodeServer_Expecter) Search(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Search_Call { func (_e *MockQueryNodeServer_Expecter) Search(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Search_Call {
return &MockQueryNodeServer_Search_Call{Call: _e.mock.On("Search", _a0, _a1)} return &MockQueryNodeServer_Search_Call{Call: _e.mock.On("Search", _a0, _a1)}
} }
@ -763,8 +763,8 @@ type MockQueryNodeServer_ShowConfigurations_Call struct {
} }
// ShowConfigurations is a helper method to define mock.On call // ShowConfigurations is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *internalpb.ShowConfigurationsRequest // - _a1 *internalpb.ShowConfigurationsRequest
func (_e *MockQueryNodeServer_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ShowConfigurations_Call { func (_e *MockQueryNodeServer_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ShowConfigurations_Call {
return &MockQueryNodeServer_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} return &MockQueryNodeServer_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)}
} }
@ -810,8 +810,8 @@ type MockQueryNodeServer_SyncDistribution_Call struct {
} }
// SyncDistribution is a helper method to define mock.On call // SyncDistribution is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SyncDistributionRequest // - _a1 *querypb.SyncDistributionRequest
func (_e *MockQueryNodeServer_Expecter) SyncDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncDistribution_Call { func (_e *MockQueryNodeServer_Expecter) SyncDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncDistribution_Call {
return &MockQueryNodeServer_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", _a0, _a1)} return &MockQueryNodeServer_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", _a0, _a1)}
} }
@ -857,8 +857,8 @@ type MockQueryNodeServer_SyncReplicaSegments_Call struct {
} }
// SyncReplicaSegments is a helper method to define mock.On call // SyncReplicaSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SyncReplicaSegmentsRequest // - _a1 *querypb.SyncReplicaSegmentsRequest
func (_e *MockQueryNodeServer_Expecter) SyncReplicaSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncReplicaSegments_Call { func (_e *MockQueryNodeServer_Expecter) SyncReplicaSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncReplicaSegments_Call {
return &MockQueryNodeServer_SyncReplicaSegments_Call{Call: _e.mock.On("SyncReplicaSegments", _a0, _a1)} return &MockQueryNodeServer_SyncReplicaSegments_Call{Call: _e.mock.On("SyncReplicaSegments", _a0, _a1)}
} }
@ -904,8 +904,8 @@ type MockQueryNodeServer_UnsubDmChannel_Call struct {
} }
// UnsubDmChannel is a helper method to define mock.On call // UnsubDmChannel is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.UnsubDmChannelRequest // - _a1 *querypb.UnsubDmChannelRequest
func (_e *MockQueryNodeServer_Expecter) UnsubDmChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_UnsubDmChannel_Call { func (_e *MockQueryNodeServer_Expecter) UnsubDmChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_UnsubDmChannel_Call {
return &MockQueryNodeServer_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", _a0, _a1)} return &MockQueryNodeServer_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", _a0, _a1)}
} }
@ -951,8 +951,8 @@ type MockQueryNodeServer_WatchDmChannels_Call struct {
} }
// WatchDmChannels is a helper method to define mock.On call // WatchDmChannels is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.WatchDmChannelsRequest // - _a1 *querypb.WatchDmChannelsRequest
func (_e *MockQueryNodeServer_Expecter) WatchDmChannels(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_WatchDmChannels_Call { func (_e *MockQueryNodeServer_Expecter) WatchDmChannels(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_WatchDmChannels_Call {
return &MockQueryNodeServer_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", _a0, _a1)} return &MockQueryNodeServer_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", _a0, _a1)}
} }

View File

@ -407,7 +407,6 @@ func (suite *CollectionObserverSuite) load(collection int64) {
}) })
} }
allSegments := make(map[int64][]*datapb.SegmentBinlogs, 0) // partitionID -> segments
dmChannels := make([]*datapb.VchannelInfo, 0) dmChannels := make([]*datapb.VchannelInfo, 0)
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
dmChannels = append(dmChannels, &datapb.VchannelInfo{ dmChannels = append(dmChannels, &datapb.VchannelInfo{
@ -416,17 +415,17 @@ func (suite *CollectionObserverSuite) load(collection int64) {
}) })
} }
allSegments := make([]*datapb.SegmentInfo, 0) // partitionID -> segments
for _, segment := range suite.segments[collection] { for _, segment := range suite.segments[collection] {
allSegments[segment.PartitionID] = append(allSegments[segment.PartitionID], &datapb.SegmentBinlogs{ allSegments = append(allSegments, &datapb.SegmentInfo{
SegmentID: segment.GetID(), ID: segment.GetID(),
PartitionID: segment.PartitionID,
InsertChannel: segment.GetInsertChannel(), InsertChannel: segment.GetInsertChannel(),
}) })
} }
partitions := suite.partitions[collection] partitions := suite.partitions[collection]
for _, partition := range partitions { suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Return(dmChannels, allSegments, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments[partition], nil)
}
suite.targetMgr.UpdateCollectionNextTargetWithPartitions(collection, partitions...) suite.targetMgr.UpdateCollectionNextTargetWithPartitions(collection, partitions...)
} }

View File

@ -86,9 +86,10 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() {
observer := suite.observer observer := suite.observer
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
InsertChannel: "test-insert-channel", InsertChannel: "test-insert-channel",
}, },
} }
@ -98,7 +99,6 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
info := &datapb.SegmentInfo{ info := &datapb.SegmentInfo{
ID: 1, ID: 1,
CollectionID: 1, CollectionID: 1,
@ -110,7 +110,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil) suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil)
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return( suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return(
&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil) &datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
observer.target.UpdateCollectionCurrentTarget(1) observer.target.UpdateCollectionCurrentTarget(1)
@ -171,9 +171,10 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() {
observer := suite.observer observer := suite.observer
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
InsertChannel: "test-insert-channel", InsertChannel: "test-insert-channel",
}, },
} }
@ -185,7 +186,6 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() {
} }
schema := utils.CreateTestSchema() schema := utils.CreateTestSchema()
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil) suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
info := &datapb.SegmentInfo{ info := &datapb.SegmentInfo{
ID: 1, ID: 1,
CollectionID: 1, CollectionID: 1,
@ -195,7 +195,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() {
loadInfo := utils.PackSegmentLoadInfo(info, nil) loadInfo := utils.PackSegmentLoadInfo(info, nil)
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return( suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return(
&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil) &datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
observer.target.UpdateCollectionCurrentTarget(1) observer.target.UpdateCollectionCurrentTarget(1)
@ -255,9 +255,10 @@ func (suite *LeaderObserverTestSuite) TestIgnoreBalancedSegment() {
observer := suite.observer observer := suite.observer
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
InsertChannel: "test-insert-channel", InsertChannel: "test-insert-channel",
}, },
} }
@ -268,8 +269,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreBalancedSegment() {
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
observer.target.UpdateCollectionCurrentTarget(1) observer.target.UpdateCollectionCurrentTarget(1)
@ -295,9 +295,10 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() {
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 2)) observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 2))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(2, 1, []int64{3, 4})) observer.meta.ReplicaManager.Put(utils.CreateTestReplica(2, 1, []int64{3, 4}))
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 1, ID: 1,
PartitionID: 1,
InsertChannel: "test-insert-channel", InsertChannel: "test-insert-channel",
}, },
} }
@ -307,7 +308,6 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
info := &datapb.SegmentInfo{ info := &datapb.SegmentInfo{
ID: 1, ID: 1,
CollectionID: 1, CollectionID: 1,
@ -318,7 +318,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() {
schema := utils.CreateTestSchema() schema := utils.CreateTestSchema()
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return( suite.broker.EXPECT().GetSegmentInfo(mock.Anything, int64(1)).Return(
&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil) &datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{info}}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
channels, segments, nil) channels, segments, nil)
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil) suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -436,9 +436,10 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncRemovedSegments() {
observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1)) observer.meta.CollectionManager.PutCollection(utils.CreateTestCollection(1, 1))
observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2})) observer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, []int64{1, 2}))
segments := []*datapb.SegmentBinlogs{ segments := []*datapb.SegmentInfo{
{ {
SegmentID: 2, ID: 2,
PartitionID: 1,
InsertChannel: "test-insert-channel", InsertChannel: "test-insert-channel",
}, },
} }
@ -450,8 +451,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncRemovedSegments() {
} }
schema := utils.CreateTestSchema() schema := utils.CreateTestSchema()
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil) suite.broker.EXPECT().GetCollectionSchema(mock.Anything, int64(1)).Return(schema, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, int64(1)).Return(
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))

View File

@ -49,7 +49,7 @@ type TargetObserverSuite struct {
collectionID int64 collectionID int64
partitionID int64 partitionID int64
nextTargetSegments []*datapb.SegmentBinlogs nextTargetSegments []*datapb.SegmentInfo
nextTargetChannels []*datapb.VchannelInfo nextTargetChannels []*datapb.VchannelInfo
} }
@ -106,19 +106,20 @@ func (suite *TargetObserverSuite) SetupTest() {
}, },
} }
suite.nextTargetSegments = []*datapb.SegmentBinlogs{ suite.nextTargetSegments = []*datapb.SegmentInfo{
{ {
SegmentID: 11, ID: 11,
PartitionID: suite.partitionID,
InsertChannel: "channel-1", InsertChannel: "channel-1",
}, },
{ {
SegmentID: 12, ID: 12,
PartitionID: suite.partitionID,
InsertChannel: "channel-2", InsertChannel: "channel-2",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{suite.partitionID}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, mock.Anything).Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything).Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
} }
func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
@ -153,16 +154,16 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
suite.broker.AssertExpectations(suite.T()) suite.broker.AssertExpectations(suite.T())
suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0] suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0]
suite.nextTargetSegments = append(suite.nextTargetSegments, &datapb.SegmentBinlogs{ suite.nextTargetSegments = append(suite.nextTargetSegments, &datapb.SegmentInfo{
SegmentID: 13, ID: 13,
PartitionID: suite.partitionID,
InsertChannel: "channel-1", InsertChannel: "channel-1",
}) })
suite.targetMgr.UpdateCollectionCurrentTarget(suite.collectionID) suite.targetMgr.UpdateCollectionCurrentTarget(suite.collectionID)
// Pull next again // Pull next again
suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{suite.partitionID}, nil)
suite.broker.EXPECT(). suite.broker.EXPECT().
GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything). GetRecoveryInfoV2(mock.Anything, mock.Anything).
Return(suite.nextTargetChannels, suite.nextTargetSegments, nil) Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
suite.Eventually(func() bool { suite.Eventually(func() bool {
return len(suite.targetMgr.GetHistoricalSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 && return len(suite.targetMgr.GetHistoricalSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 &&

View File

@ -18,7 +18,6 @@ package querycoordv2
import ( import (
"context" "context"
"sync"
"testing" "testing"
"time" "time"
@ -160,8 +159,7 @@ func (suite *ServerSuite) TestRecoverFailed() {
broker := meta.NewMockBroker(suite.T()) broker := meta.NewMockBroker(suite.T())
for _, collection := range suite.collections { for _, collection := range suite.collections {
broker.EXPECT().GetPartitions(mock.Anything, collection).Return([]int64{1}, nil) broker.EXPECT().GetRecoveryInfoV2(context.TODO(), collection).Return(nil, nil, errors.New("CollectionNotExist"))
broker.EXPECT().GetRecoveryInfo(context.TODO(), collection, mock.Anything).Return(nil, nil, errors.New("CollectionNotExist"))
} }
suite.server.targetMgr = meta.NewTargetManager(broker, suite.server.meta) suite.server.targetMgr = meta.NewTargetManager(broker, suite.server.meta)
err = suite.server.Start() err = suite.server.Start()
@ -347,41 +345,25 @@ func (suite *ServerSuite) assertLoaded(collection int64) {
} }
func (suite *ServerSuite) expectGetRecoverInfo(collection int64) { func (suite *ServerSuite) expectGetRecoverInfo(collection int64) {
var ( vChannels := []*datapb.VchannelInfo{}
mu sync.Mutex for _, channel := range suite.channels[collection] {
vChannels []*datapb.VchannelInfo vChannels = append(vChannels, &datapb.VchannelInfo{
segmentBinlogs []*datapb.SegmentBinlogs CollectionID: collection,
) ChannelName: channel,
})
for partition, segments := range suite.segments[collection] {
segments := segments
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Maybe().Return(func(ctx context.Context, collectionID, partitionID int64) []*datapb.VchannelInfo {
mu.Lock()
vChannels = []*datapb.VchannelInfo{}
for _, channel := range suite.channels[collection] {
vChannels = append(vChannels, &datapb.VchannelInfo{
CollectionID: collection,
ChannelName: channel,
})
}
segmentBinlogs = []*datapb.SegmentBinlogs{}
for _, segment := range segments {
segmentBinlogs = append(segmentBinlogs, &datapb.SegmentBinlogs{
SegmentID: segment,
InsertChannel: suite.channels[collection][segment%2],
})
}
return vChannels
},
func(ctx context.Context, collectionID, partitionID int64) []*datapb.SegmentBinlogs {
return segmentBinlogs
},
func(ctx context.Context, collectionID, partitionID int64) error {
mu.Unlock()
return nil
},
)
} }
segmentInfos := []*datapb.SegmentInfo{}
for _, segments := range suite.segments[collection] {
for _, segment := range segments {
segmentInfos = append(segmentInfos, &datapb.SegmentInfo{
ID: segment,
PartitionID: suite.partitions[collection][0],
InsertChannel: suite.channels[collection][segment%2],
})
}
}
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collection).Maybe().Return(vChannels, segmentInfos, nil)
} }
func (suite *ServerSuite) expectLoadAndReleasePartitions(querynode *mocks.MockQueryNode) { func (suite *ServerSuite) expectLoadAndReleasePartitions(querynode *mocks.MockQueryNode) {
@ -391,39 +373,41 @@ func (suite *ServerSuite) expectLoadAndReleasePartitions(querynode *mocks.MockQu
func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, dataCoord *coordMocks.DataCoord) { func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, dataCoord *coordMocks.DataCoord) {
var ( var (
vChannels []*datapb.VchannelInfo vChannels []*datapb.VchannelInfo
segmentBinlogs []*datapb.SegmentBinlogs segmentInfos []*datapb.SegmentInfo
) )
for partition, segments := range suite.segments[collection] { getRecoveryInfoRequest := &datapb.GetRecoveryInfoRequestV2{
segments := segments Base: commonpbutil.NewMsgBase(
getRecoveryInfoRequest := &datapb.GetRecoveryInfoRequest{ commonpbutil.WithMsgType(commonpb.MsgType_GetRecoveryInfo),
Base: commonpbutil.NewMsgBase( ),
commonpbutil.WithMsgType(commonpb.MsgType_GetRecoveryInfo), CollectionID: collection,
), }
vChannels = []*datapb.VchannelInfo{}
for _, channel := range suite.channels[collection] {
vChannels = append(vChannels, &datapb.VchannelInfo{
CollectionID: collection, CollectionID: collection,
PartitionID: partition, ChannelName: channel,
} })
vChannels = []*datapb.VchannelInfo{} }
for _, channel := range suite.channels[collection] {
vChannels = append(vChannels, &datapb.VchannelInfo{ segmentInfos = []*datapb.SegmentInfo{}
CollectionID: collection, for _, segments := range suite.segments[collection] {
ChannelName: channel,
})
}
segmentBinlogs = []*datapb.SegmentBinlogs{}
for _, segment := range segments { for _, segment := range segments {
segmentBinlogs = append(segmentBinlogs, &datapb.SegmentBinlogs{ segmentInfos = append(segmentInfos, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
InsertChannel: suite.channels[collection][segment%2], InsertChannel: suite.channels[collection][segment%2],
}) })
} }
dataCoord.EXPECT().GetRecoveryInfo(mock.Anything, getRecoveryInfoRequest).Maybe().Return(&datapb.GetRecoveryInfoResponse{
Status: merr.Status(nil),
Channels: vChannels,
Binlogs: segmentBinlogs,
}, nil)
} }
dataCoord.EXPECT().GetRecoveryInfoV2(mock.Anything, getRecoveryInfoRequest).Return(&datapb.GetRecoveryInfoResponseV2{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Channels: vChannels,
Segments: segmentInfos,
}, nil).Maybe()
} }
func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status querypb.LoadStatus) { func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status querypb.LoadStatus) {

View File

@ -775,10 +775,6 @@ func (suite *ServiceSuite) TestLoadPartition() {
// Test load all partitions // Test load all partitions
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).
Return(append(suite.partitions[collection], 999), nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, int64(999)).
Return(nil, nil, nil)
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
@ -1651,7 +1647,7 @@ func (suite *ServiceSuite) assertSegments(collection int64, segments []*querypb.
} }
func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) { func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil) suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
vChannels := []*datapb.VchannelInfo{} vChannels := []*datapb.VchannelInfo{}
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
vChannels = append(vChannels, &datapb.VchannelInfo{ vChannels = append(vChannels, &datapb.VchannelInfo{
@ -1660,19 +1656,20 @@ func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) {
}) })
} }
segmentBinlogs := []*datapb.SegmentInfo{}
for partition, segments := range suite.segments[collection] { for partition, segments := range suite.segments[collection] {
segmentBinlogs := []*datapb.SegmentBinlogs{}
for _, segment := range segments { for _, segment := range segments {
segmentBinlogs = append(segmentBinlogs, &datapb.SegmentBinlogs{ segmentBinlogs = append(segmentBinlogs, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
InsertChannel: suite.channels[collection][segment%2], InsertChannel: suite.channels[collection][segment%2],
PartitionID: partition,
CollectionID: collection,
}) })
} }
suite.broker.EXPECT().
GetRecoveryInfo(mock.Anything, collection, partition).
Return(vChannels, segmentBinlogs, nil)
} }
suite.broker.EXPECT().
GetRecoveryInfoV2(mock.Anything, collection, mock.Anything, mock.Anything).
Return(vChannels, segmentBinlogs, nil)
} }
func (suite *ServiceSuite) getAllSegments(collection int64) []int64 { func (suite *ServiceSuite) getAllSegments(collection int64) []int64 {

View File

@ -56,8 +56,8 @@ type MockCluster_GetComponentStates_Call struct {
} }
// GetComponentStates is a helper method to define mock.On call // GetComponentStates is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
func (_e *MockCluster_Expecter) GetComponentStates(ctx interface{}, nodeID interface{}) *MockCluster_GetComponentStates_Call { func (_e *MockCluster_Expecter) GetComponentStates(ctx interface{}, nodeID interface{}) *MockCluster_GetComponentStates_Call {
return &MockCluster_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx, nodeID)} return &MockCluster_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx, nodeID)}
} }
@ -103,9 +103,9 @@ type MockCluster_GetDataDistribution_Call struct {
} }
// GetDataDistribution is a helper method to define mock.On call // GetDataDistribution is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.GetDataDistributionRequest // - req *querypb.GetDataDistributionRequest
func (_e *MockCluster_Expecter) GetDataDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetDataDistribution_Call { func (_e *MockCluster_Expecter) GetDataDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetDataDistribution_Call {
return &MockCluster_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", ctx, nodeID, req)} return &MockCluster_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", ctx, nodeID, req)}
} }
@ -151,9 +151,9 @@ type MockCluster_GetMetrics_Call struct {
} }
// GetMetrics is a helper method to define mock.On call // GetMetrics is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *milvuspb.GetMetricsRequest // - req *milvuspb.GetMetricsRequest
func (_e *MockCluster_Expecter) GetMetrics(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetMetrics_Call { func (_e *MockCluster_Expecter) GetMetrics(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetMetrics_Call {
return &MockCluster_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, nodeID, req)} return &MockCluster_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, nodeID, req)}
} }
@ -199,9 +199,9 @@ type MockCluster_LoadPartitions_Call struct {
} }
// LoadPartitions is a helper method to define mock.On call // LoadPartitions is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.LoadPartitionsRequest // - req *querypb.LoadPartitionsRequest
func (_e *MockCluster_Expecter) LoadPartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadPartitions_Call { func (_e *MockCluster_Expecter) LoadPartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadPartitions_Call {
return &MockCluster_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, nodeID, req)} return &MockCluster_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, nodeID, req)}
} }
@ -247,9 +247,9 @@ type MockCluster_LoadSegments_Call struct {
} }
// LoadSegments is a helper method to define mock.On call // LoadSegments is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.LoadSegmentsRequest // - req *querypb.LoadSegmentsRequest
func (_e *MockCluster_Expecter) LoadSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadSegments_Call { func (_e *MockCluster_Expecter) LoadSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadSegments_Call {
return &MockCluster_LoadSegments_Call{Call: _e.mock.On("LoadSegments", ctx, nodeID, req)} return &MockCluster_LoadSegments_Call{Call: _e.mock.On("LoadSegments", ctx, nodeID, req)}
} }
@ -295,9 +295,9 @@ type MockCluster_ReleasePartitions_Call struct {
} }
// ReleasePartitions is a helper method to define mock.On call // ReleasePartitions is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.ReleasePartitionsRequest // - req *querypb.ReleasePartitionsRequest
func (_e *MockCluster_Expecter) ReleasePartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleasePartitions_Call { func (_e *MockCluster_Expecter) ReleasePartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleasePartitions_Call {
return &MockCluster_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, nodeID, req)} return &MockCluster_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, nodeID, req)}
} }
@ -343,9 +343,9 @@ type MockCluster_ReleaseSegments_Call struct {
} }
// ReleaseSegments is a helper method to define mock.On call // ReleaseSegments is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.ReleaseSegmentsRequest // - req *querypb.ReleaseSegmentsRequest
func (_e *MockCluster_Expecter) ReleaseSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleaseSegments_Call { func (_e *MockCluster_Expecter) ReleaseSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleaseSegments_Call {
return &MockCluster_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", ctx, nodeID, req)} return &MockCluster_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", ctx, nodeID, req)}
} }
@ -373,7 +373,7 @@ type MockCluster_Start_Call struct {
} }
// Start is a helper method to define mock.On call // Start is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
func (_e *MockCluster_Expecter) Start(ctx interface{}) *MockCluster_Start_Call { func (_e *MockCluster_Expecter) Start(ctx interface{}) *MockCluster_Start_Call {
return &MockCluster_Start_Call{Call: _e.mock.On("Start", ctx)} return &MockCluster_Start_Call{Call: _e.mock.On("Start", ctx)}
} }
@ -446,9 +446,9 @@ type MockCluster_SyncDistribution_Call struct {
} }
// SyncDistribution is a helper method to define mock.On call // SyncDistribution is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.SyncDistributionRequest // - req *querypb.SyncDistributionRequest
func (_e *MockCluster_Expecter) SyncDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_SyncDistribution_Call { func (_e *MockCluster_Expecter) SyncDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_SyncDistribution_Call {
return &MockCluster_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", ctx, nodeID, req)} return &MockCluster_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", ctx, nodeID, req)}
} }
@ -494,9 +494,9 @@ type MockCluster_UnsubDmChannel_Call struct {
} }
// UnsubDmChannel is a helper method to define mock.On call // UnsubDmChannel is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.UnsubDmChannelRequest // - req *querypb.UnsubDmChannelRequest
func (_e *MockCluster_Expecter) UnsubDmChannel(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_UnsubDmChannel_Call { func (_e *MockCluster_Expecter) UnsubDmChannel(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_UnsubDmChannel_Call {
return &MockCluster_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", ctx, nodeID, req)} return &MockCluster_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", ctx, nodeID, req)}
} }
@ -542,9 +542,9 @@ type MockCluster_WatchDmChannels_Call struct {
} }
// WatchDmChannels is a helper method to define mock.On call // WatchDmChannels is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.WatchDmChannelsRequest // - req *querypb.WatchDmChannelsRequest
func (_e *MockCluster_Expecter) WatchDmChannels(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_WatchDmChannels_Call { func (_e *MockCluster_Expecter) WatchDmChannels(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_WatchDmChannels_Call {
return &MockCluster_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", ctx, nodeID, req)} return &MockCluster_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", ctx, nodeID, req)}
} }

View File

@ -13,14 +13,6 @@ type MockScheduler struct {
mock.Mock mock.Mock
} }
func (_m *MockScheduler) GetChannelTaskNum() int {
return 0
}
func (_m *MockScheduler) GetSegmentTaskNum() int {
return 0
}
type MockScheduler_Expecter struct { type MockScheduler_Expecter struct {
mock *mock.Mock mock *mock.Mock
} }
@ -122,6 +114,42 @@ func (_c *MockScheduler_Dispatch_Call) Return() *MockScheduler_Dispatch_Call {
return _c return _c
} }
// GetChannelTaskNum provides a mock function with given fields:
func (_m *MockScheduler) GetChannelTaskNum() int {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// MockScheduler_GetChannelTaskNum_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannelTaskNum'
type MockScheduler_GetChannelTaskNum_Call struct {
*mock.Call
}
// GetChannelTaskNum is a helper method to define mock.On call
func (_e *MockScheduler_Expecter) GetChannelTaskNum() *MockScheduler_GetChannelTaskNum_Call {
return &MockScheduler_GetChannelTaskNum_Call{Call: _e.mock.On("GetChannelTaskNum")}
}
func (_c *MockScheduler_GetChannelTaskNum_Call) Run(run func()) *MockScheduler_GetChannelTaskNum_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockScheduler_GetChannelTaskNum_Call) Return(_a0 int) *MockScheduler_GetChannelTaskNum_Call {
_c.Call.Return(_a0)
return _c
}
// GetNodeChannelDelta provides a mock function with given fields: nodeID // GetNodeChannelDelta provides a mock function with given fields: nodeID
func (_m *MockScheduler) GetNodeChannelDelta(nodeID int64) int { func (_m *MockScheduler) GetNodeChannelDelta(nodeID int64) int {
ret := _m.Called(nodeID) ret := _m.Called(nodeID)
@ -196,6 +224,42 @@ func (_c *MockScheduler_GetNodeSegmentDelta_Call) Return(_a0 int) *MockScheduler
return _c return _c
} }
// GetSegmentTaskNum provides a mock function with given fields:
func (_m *MockScheduler) GetSegmentTaskNum() int {
ret := _m.Called()
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// MockScheduler_GetSegmentTaskNum_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentTaskNum'
type MockScheduler_GetSegmentTaskNum_Call struct {
*mock.Call
}
// GetSegmentTaskNum is a helper method to define mock.On call
func (_e *MockScheduler_Expecter) GetSegmentTaskNum() *MockScheduler_GetSegmentTaskNum_Call {
return &MockScheduler_GetSegmentTaskNum_Call{Call: _e.mock.On("GetSegmentTaskNum")}
}
func (_c *MockScheduler_GetSegmentTaskNum_Call) Run(run func()) *MockScheduler_GetSegmentTaskNum_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockScheduler_GetSegmentTaskNum_Call) Return(_a0 int) *MockScheduler_GetSegmentTaskNum_Call {
_c.Call.Return(_a0)
return _c
}
// RemoveByNode provides a mock function with given fields: node // RemoveByNode provides a mock function with given fields: node
func (_m *MockScheduler) RemoveByNode(node int64) { func (_m *MockScheduler) RemoveByNode(node int64) {
_m.Called(node) _m.Called(node)

View File

@ -196,10 +196,6 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
}}, }},
}, nil) }, nil)
} }
// for _, partition := range partitions {
// suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).
// Return(channels, nil, nil)
// }
suite.cluster.EXPECT().WatchDmChannels(mock.Anything, targetNode, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil) suite.cluster.EXPECT().WatchDmChannels(mock.Anything, targetNode, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
// Test subscribe channel task // Test subscribe channel task
@ -224,8 +220,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(dmChannels, nil, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
suite.AssertTaskNum(0, len(suite.subChannels), len(suite.subChannels), 0) suite.AssertTaskNum(0, len(suite.subChannels), len(suite.subChannels), 0)
@ -320,8 +315,7 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(dmChannels, nil, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
// Only first channel exists // Only first channel exists
@ -372,8 +366,6 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
}, nil) }, nil)
suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil) suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil)
} }
// suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).
// Return([]*datapb.VchannelInfo{channel}, nil, nil)
suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil) suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
// Test load segment task // Test load segment task
@ -382,11 +374,12 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
ChannelName: channel.ChannelName, ChannelName: channel.ChannelName,
})) }))
tasks := []Task{} tasks := []Task{}
segments := make([]*datapb.SegmentBinlogs, 0) segments := make([]*datapb.SegmentInfo, 0)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
segments = append(segments, &datapb.SegmentBinlogs{ segments = append(segments, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
InsertChannel: channel.ChannelName, InsertChannel: channel.ChannelName,
PartitionID: 1,
}) })
task, err := NewSegmentTask( task, err := NewSegmentTask(
ctx, ctx,
@ -401,8 +394,7 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(nil, segments, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
@ -463,10 +455,11 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
ChannelName: channel.ChannelName, ChannelName: channel.ChannelName,
})) }))
tasks := []Task{} tasks := []Task{}
segmentInfos := make([]*datapb.SegmentBinlogs, 0) segmentInfos := make([]*datapb.SegmentInfo, 0)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
segmentInfos = append(segmentInfos, &datapb.SegmentBinlogs{ segmentInfos = append(segmentInfos, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
PartitionID: 1,
InsertChannel: channel.ChannelName, InsertChannel: channel.ChannelName,
}) })
task, err := NewSegmentTask( task, err := NewSegmentTask(
@ -482,8 +475,7 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(nil, segmentInfos, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
@ -649,8 +641,6 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
}, nil) }, nil)
suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil) suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil)
} }
// suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).
// Return([]*datapb.VchannelInfo{channel}, nil, nil)
suite.cluster.EXPECT().LoadSegments(mock.Anything, leader, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil) suite.cluster.EXPECT().LoadSegments(mock.Anything, leader, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
suite.cluster.EXPECT().ReleaseSegments(mock.Anything, leader, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil) suite.cluster.EXPECT().ReleaseSegments(mock.Anything, leader, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
@ -668,12 +658,13 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
} }
tasks := []Task{} tasks := []Task{}
segments := make([]*meta.Segment, 0) segments := make([]*meta.Segment, 0)
segmentInfos := make([]*datapb.SegmentBinlogs, 0) segmentInfos := make([]*datapb.SegmentInfo, 0)
for _, segment := range suite.moveSegments { for _, segment := range suite.moveSegments {
segments = append(segments, segments = append(segments,
utils.CreateTestSegment(suite.collection, partition, segment, sourceNode, 1, channel.ChannelName)) utils.CreateTestSegment(suite.collection, partition, segment, sourceNode, 1, channel.ChannelName))
segmentInfos = append(segmentInfos, &datapb.SegmentBinlogs{ segmentInfos = append(segmentInfos, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
PartitionID: 1,
InsertChannel: channel.ChannelName, InsertChannel: channel.ChannelName,
}) })
view.Segments[segment] = &querypb.SegmentDist{NodeID: sourceNode, Version: 0} view.Segments[segment] = &querypb.SegmentDist{NodeID: sourceNode, Version: 0}
@ -692,8 +683,7 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
suite.target.UpdateCollectionCurrentTarget(suite.collection, int64(1)) suite.target.UpdateCollectionCurrentTarget(suite.collection, int64(1))
suite.dist.SegmentDistManager.Update(sourceNode, segments...) suite.dist.SegmentDistManager.Update(sourceNode, segments...)
@ -748,8 +738,6 @@ func (suite *TaskSuite) TestTaskCanceled() {
}, nil) }, nil)
suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil) suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil)
} }
// suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).
// Return([]*datapb.VchannelInfo{channel}, nil, nil)
suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil) suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
// Test load segment task // Test load segment task
@ -758,10 +746,11 @@ func (suite *TaskSuite) TestTaskCanceled() {
ChannelName: channel.ChannelName, ChannelName: channel.ChannelName,
})) }))
tasks := []Task{} tasks := []Task{}
segmentInfos := []*datapb.SegmentBinlogs{} segmentInfos := []*datapb.SegmentInfo{}
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
segmentInfos = append(segmentInfos, &datapb.SegmentBinlogs{ segmentInfos = append(segmentInfos, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
PartitionID: partition,
InsertChannel: channel.GetChannelName(), InsertChannel: channel.GetChannelName(),
}) })
task, err := NewSegmentTask( task, err := NewSegmentTask(
@ -779,8 +768,7 @@ func (suite *TaskSuite) TestTaskCanceled() {
} }
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{partition}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(nil, segmentInfos, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, partition) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, partition)
// Process tasks // Process tasks
@ -826,8 +814,6 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
}, nil) }, nil)
suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil) suite.broker.EXPECT().GetIndexInfo(mock.Anything, suite.collection, segment).Return(nil, nil)
} }
// suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).
// Return([]*datapb.VchannelInfo{channel}, nil, nil)
suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil) suite.cluster.EXPECT().LoadSegments(mock.Anything, targetNode, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
// Test load segment task // Test load segment task
@ -837,10 +823,11 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
ChannelName: channel.ChannelName, ChannelName: channel.ChannelName,
})) }))
tasks := []Task{} tasks := []Task{}
segmentInfos := make([]*datapb.SegmentBinlogs, 0) segmentInfos := make([]*datapb.SegmentInfo, 0)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
segmentInfos = append(segmentInfos, &datapb.SegmentBinlogs{ segmentInfos = append(segmentInfos, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
PartitionID: 1,
InsertChannel: channel.GetChannelName(), InsertChannel: channel.GetChannelName(),
}) })
task, err := NewSegmentTask( task, err := NewSegmentTask(
@ -856,8 +843,7 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(nil, segmentInfos, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
@ -877,18 +863,18 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
view.Segments[segment] = &querypb.SegmentDist{NodeID: targetNode, Version: 0} view.Segments[segment] = &querypb.SegmentDist{NodeID: targetNode, Version: 0}
} }
suite.dist.LeaderViewManager.Update(targetNode, view) suite.dist.LeaderViewManager.Update(targetNode, view)
segmentInfos = make([]*datapb.SegmentBinlogs, 0) segmentInfos = make([]*datapb.SegmentInfo, 0)
for _, segment := range suite.loadSegments[1:] { for _, segment := range suite.loadSegments[1:] {
segmentInfos = append(segmentInfos, &datapb.SegmentBinlogs{ segmentInfos = append(segmentInfos, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
PartitionID: 2,
InsertChannel: channel.GetChannelName(), InsertChannel: channel.GetChannelName(),
}) })
} }
bakExpectations := suite.broker.ExpectedCalls bakExpectations := suite.broker.ExpectedCalls
suite.broker.AssertExpectations(suite.T()) suite.broker.AssertExpectations(suite.T())
suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0] suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0]
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{2}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(nil, segmentInfos, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(2)).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(2)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(2))
suite.dispatchAndWait(targetNode) suite.dispatchAndWait(targetNode)
suite.AssertTaskNum(0, 0, 0, 0) suite.AssertTaskNum(0, 0, 0, 0)
@ -1073,10 +1059,11 @@ func (suite *TaskSuite) TestNoExecutor() {
CollectionID: suite.collection, CollectionID: suite.collection,
ChannelName: channel.ChannelName, ChannelName: channel.ChannelName,
})) }))
segments := make([]*datapb.SegmentBinlogs, 0) segments := make([]*datapb.SegmentInfo, 0)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
segments = append(segments, &datapb.SegmentBinlogs{ segments = append(segments, &datapb.SegmentInfo{
SegmentID: segment, ID: segment,
PartitionID: 1,
InsertChannel: channel.ChannelName, InsertChannel: channel.ChannelName,
}) })
task, err := NewSegmentTask( task, err := NewSegmentTask(
@ -1091,8 +1078,7 @@ func (suite *TaskSuite) TestNoExecutor() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil) suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, suite.collection).Return(nil, segments, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)

View File

@ -242,6 +242,15 @@ type DataCoord interface {
// error is returned only when some communication issue occurs // error is returned only when some communication issue occurs
GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error)
// GetRecoveryInfoV2 request segment recovery info of collection or batch partitions
//
// ctx is the context to control request deadline and cancellation
// req contains the collection/partitions id to query
//
// response struct `GetRecoveryInfoResponseV2` contains the list of segments info and corresponding vchannel info
// error is returned only when some communication issue occurs
GetRecoveryInfoV2(ctx context.Context, req *datapb.GetRecoveryInfoRequestV2) (*datapb.GetRecoveryInfoResponseV2, error)
// SaveBinlogPaths updates segments binlogs(including insert binlogs, stats logs and delta logs) // SaveBinlogPaths updates segments binlogs(including insert binlogs, stats logs and delta logs)
// and related message stream positions // and related message stream positions
// //

View File

@ -96,6 +96,10 @@ func (m *GrpcDataCoordClient) GetRecoveryInfo(ctx context.Context, in *datapb.Ge
return &datapb.GetRecoveryInfoResponse{}, m.Err return &datapb.GetRecoveryInfoResponse{}, m.Err
} }
func (m *GrpcDataCoordClient) GetRecoveryInfoV2(ctx context.Context, in *datapb.GetRecoveryInfoRequestV2, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponseV2, error) {
return &datapb.GetRecoveryInfoResponseV2{}, m.Err
}
func (m *GrpcDataCoordClient) GetFlushedSegments(ctx context.Context, in *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error) { func (m *GrpcDataCoordClient) GetFlushedSegments(ctx context.Context, in *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error) {
return &datapb.GetFlushedSegmentsResponse{}, m.Err return &datapb.GetFlushedSegmentsResponse{}, m.Err
} }