From cbb350c5525f1a40410b6205edccf21e0894f02b Mon Sep 17 00:00:00 2001 From: congqixia Date: Wed, 11 Oct 2023 17:03:34 +0800 Subject: [PATCH] Add broker for datanode grpc operations (#27631) Signed-off-by: Congqi Xia --- Makefile | 1 + internal/datanode/broker/broker.go | 52 ++ internal/datanode/broker/datacoord.go | 133 +++++ internal/datanode/broker/datacoord_test.go | 296 +++++++++++ internal/datanode/broker/mock_broker.go | 555 +++++++++++++++++++++ internal/datanode/broker/rootcoord.go | 114 +++++ internal/datanode/broker/rootcoord_test.go | 241 +++++++++ 7 files changed, 1392 insertions(+) create mode 100644 internal/datanode/broker/broker.go create mode 100644 internal/datanode/broker/datacoord.go create mode 100644 internal/datanode/broker/datacoord_test.go create mode 100644 internal/datanode/broker/mock_broker.go create mode 100644 internal/datanode/broker/rootcoord.go create mode 100644 internal/datanode/broker/rootcoord_test.go diff --git a/Makefile b/Makefile index 4584903c3d..9232b8d905 100644 --- a/Makefile +++ b/Makefile @@ -422,6 +422,7 @@ generate-mockery-datacoord: getdeps generate-mockery-datanode: getdeps $(INSTALL_PATH)/mockery --name=Allocator --dir=$(PWD)/internal/datanode/allocator --output=$(PWD)/internal/datanode/allocator --filename=mock_allocator.go --with-expecter --structname=MockAllocator --outpkg=allocator --inpackage + $(INSTALL_PATH)/mockery --name=Broker --dir=$(PWD)/internal/datanode/broker --output=$(PWD)/internal/datanode/broker/ --filename=mock_broker.go --with-expecter --structname=MockBroker --outpkg=broker --inpackage generate-mockery-metastore: getdeps $(INSTALL_PATH)/mockery --name=RootCoordCatalog --dir=$(PWD)/internal/metastore --output=$(PWD)/internal/metastore/mocks --filename=mock_rootcoord_catalog.go --with-expecter --structname=RootCoordCatalog --outpkg=mocks diff --git a/internal/datanode/broker/broker.go b/internal/datanode/broker/broker.go new file mode 100644 index 0000000000..688792866d --- /dev/null +++ b/internal/datanode/broker/broker.go @@ -0,0 +1,52 @@ +package broker + +import ( + "context" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// Broker is the interface for datanode to interact with other components. +type Broker interface { + RootCoord + DataCoord +} + +type coordBroker struct { + *rootCoordBroker + *dataCoordBroker +} + +func NewCoordBroker(rc types.RootCoordClient, dc types.DataCoordClient) Broker { + return &coordBroker{ + rootCoordBroker: &rootCoordBroker{ + client: rc, + }, + dataCoordBroker: &dataCoordBroker{ + client: dc, + }, + } +} + +// RootCoord is the interface wraps `RootCoord` grpc call +type RootCoord interface { + DescribeCollection(ctx context.Context, collectionID typeutil.UniqueID, ts typeutil.Timestamp) (*milvuspb.DescribeCollectionResponse, error) + ShowPartitions(ctx context.Context, dbName, collectionName string) (map[string]int64, error) + ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) error + AllocTimestamp(ctx context.Context, num uint32) (ts uint64, count uint32, err error) +} + +// DataCoord is the interface wraps `DataCoord` grpc call +type DataCoord interface { + AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]typeutil.UniqueID, error) + ReportTimeTick(ctx context.Context, msgs []*msgpb.DataNodeTtMsg) error + GetSegmentInfo(ctx context.Context, segmentIDs []int64) ([]*datapb.SegmentInfo, error) + UpdateChannelCheckpoint(ctx context.Context, channelName string, cp *msgpb.MsgPosition) error + SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) error + DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) error +} diff --git a/internal/datanode/broker/datacoord.go b/internal/datanode/broker/datacoord.go new file mode 100644 index 0000000000..fa8f0fbe3d --- /dev/null +++ b/internal/datanode/broker/datacoord.go @@ -0,0 +1,133 @@ +package broker + +import ( + "context" + + "github.com/samber/lo" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type dataCoordBroker struct { + client types.DataCoordClient +} + +func (dc *dataCoordBroker) AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]typeutil.UniqueID, error) { + req := &datapb.AssignSegmentIDRequest{ + NodeID: paramtable.GetNodeID(), + PeerRole: typeutil.ProxyRole, + SegmentIDRequests: reqs, + } + + resp, err := dc.client.AssignSegmentID(ctx, req) + + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to call datacoord AssignSegmentID", zap.Error(err)) + return nil, err + } + + return lo.Map(resp.GetSegIDAssignments(), func(result *datapb.SegmentIDAssignment, _ int) typeutil.UniqueID { + return result.GetSegID() + }), nil +} + +func (dc *dataCoordBroker) ReportTimeTick(ctx context.Context, msgs []*msgpb.DataNodeTtMsg) error { + log := log.Ctx(ctx) + + req := &datapb.ReportDataNodeTtMsgsRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_DataNodeTt), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + Msgs: msgs, + } + + resp, err := dc.client.ReportDataNodeTtMsgs(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to report datanodeTtMsgs", zap.Error(err)) + return err + } + return nil +} + +func (dc *dataCoordBroker) GetSegmentInfo(ctx context.Context, segmentIDs []int64) ([]*datapb.SegmentInfo, error) { + log := log.Ctx(ctx).With( + zap.Int64s("segmentIDs", segmentIDs), + ) + + infoResp, err := dc.client.GetSegmentInfo(ctx, &datapb.GetSegmentInfoRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_SegmentInfo), + commonpbutil.WithMsgID(0), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + SegmentIDs: segmentIDs, + IncludeUnHealthy: true, + }) + if err := merr.CheckRPCCall(infoResp, err); err != nil { + log.Warn("Fail to get SegmentInfo by ids from datacoord", zap.Error(err)) + return nil, err + } + + return infoResp.Infos, nil +} + +func (dc *dataCoordBroker) UpdateChannelCheckpoint(ctx context.Context, channelName string, cp *msgpb.MsgPosition) error { + channelCPTs, _ := tsoutil.ParseTS(cp.GetTimestamp()) + log := log.Ctx(ctx).With( + zap.String("channelName", channelName), + zap.Time("channelCheckpointTime", channelCPTs), + ) + + req := &datapb.UpdateChannelCheckpointRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + VChannel: channelName, + Position: cp, + } + + resp, err := dc.client.UpdateChannelCheckpoint(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to update channel checkpoint", zap.Error(err)) + return err + } + return nil +} + +func (dc *dataCoordBroker) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) error { + log := log.Ctx(ctx) + + resp, err := dc.client.SaveBinlogPaths(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to SaveBinlogPaths", zap.Error(err)) + return err + } + + return nil +} + +func (dc *dataCoordBroker) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) error { + log := log.Ctx(ctx) + + resp, err := dc.client.DropVirtualChannel(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + if resp.GetStatus().GetErrorCode() == commonpb.ErrorCode_MetaFailed { + err = merr.WrapErrChannelNotFound(req.GetChannelName()) + } + log.Warn("failed to SaveBinlogPaths", zap.Error(err)) + return err + } + + return nil +} diff --git a/internal/datanode/broker/datacoord_test.go b/internal/datanode/broker/datacoord_test.go new file mode 100644 index 0000000000..74f1957a0b --- /dev/null +++ b/internal/datanode/broker/datacoord_test.go @@ -0,0 +1,296 @@ +package broker + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" +) + +type dataCoordSuite struct { + suite.Suite + + dc *mocks.MockDataCoordClient + broker Broker +} + +func (s *dataCoordSuite) SetupSuite() { + paramtable.Init() +} + +func (s *dataCoordSuite) SetupTest() { + s.dc = mocks.NewMockDataCoordClient(s.T()) + s.broker = NewCoordBroker(nil, s.dc) +} + +func (s *dataCoordSuite) resetMock() { + s.dc.AssertExpectations(s.T()) + s.dc.ExpectedCalls = nil +} + +func (s *dataCoordSuite) TestAssignSegmentID() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + reqs := []*datapb.SegmentIDRequest{ + {CollectionID: 100, Count: 1000}, + {CollectionID: 100, Count: 2000}, + } + + s.Run("normal_case", func() { + s.dc.EXPECT().AssignSegmentID(mock.Anything, mock.Anything). + Return(&datapb.AssignSegmentIDResponse{ + Status: merr.Status(nil), + SegIDAssignments: lo.Map(reqs, func(req *datapb.SegmentIDRequest, _ int) *datapb.SegmentIDAssignment { + return &datapb.SegmentIDAssignment{ + Status: merr.Status(nil), + SegID: 10001, + Count: req.GetCount(), + } + }), + }, nil) + + segmentIDs, err := s.broker.AssignSegmentID(ctx, reqs...) + s.NoError(err) + s.Equal(len(segmentIDs), len(reqs)) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().AssignSegmentID(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, err := s.broker.AssignSegmentID(ctx, reqs...) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().AssignSegmentID(mock.Anything, mock.Anything). + Return(&datapb.AssignSegmentIDResponse{ + Status: merr.Status(errors.New("mock")), + }, nil) + + _, err := s.broker.AssignSegmentID(ctx, reqs...) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestReportTimeTick() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + msgs := []*msgpb.DataNodeTtMsg{ + {Timestamp: 1000, ChannelName: "dml_0"}, + {Timestamp: 2000, ChannelName: "dml_1"}, + } + + s.Run("normal_case", func() { + s.dc.EXPECT().ReportDataNodeTtMsgs(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.ReportDataNodeTtMsgsRequest, _ ...grpc.CallOption) { + s.Equal(msgs, req.GetMsgs()) + }). + Return(merr.Status(nil), nil) + + err := s.broker.ReportTimeTick(ctx, msgs) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().ReportDataNodeTtMsgs(mock.Anything, mock.Anything). + Return(merr.Status(errors.New("mock")), nil) + + err := s.broker.ReportTimeTick(ctx, msgs) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestGetSegmentInfo() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + segmentIDs := []int64{1, 2, 3} + + s.Run("normal_case", func() { + s.dc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.GetSegmentInfoRequest, _ ...grpc.CallOption) { + s.ElementsMatch(segmentIDs, req.GetSegmentIDs()) + s.True(req.GetIncludeUnHealthy()) + }). + Return(&datapb.GetSegmentInfoResponse{ + Status: merr.Status(nil), + Infos: lo.Map(segmentIDs, func(id int64, _ int) *datapb.SegmentInfo { + return &datapb.SegmentInfo{ID: id} + }), + }, nil) + infos, err := s.broker.GetSegmentInfo(ctx, segmentIDs) + s.NoError(err) + s.ElementsMatch(segmentIDs, lo.Map(infos, func(info *datapb.SegmentInfo, _ int) int64 { return info.GetID() })) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + _, err := s.broker.GetSegmentInfo(ctx, segmentIDs) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything). + Return(&datapb.GetSegmentInfoResponse{ + Status: merr.Status(errors.New("mock")), + }, nil) + _, err := s.broker.GetSegmentInfo(ctx, segmentIDs) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestUpdateChannelCheckpoint() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + channelName := "dml_0" + checkpoint := &msgpb.MsgPosition{ + ChannelName: channelName, + MsgID: []byte{1, 2, 3}, + Timestamp: tsoutil.ComposeTSByTime(time.Now(), 0), + } + + s.Run("normal_case", func() { + s.dc.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.UpdateChannelCheckpointRequest, _ ...grpc.CallOption) { + s.Equal(channelName, req.GetVChannel()) + cp := req.GetPosition() + s.Equal(checkpoint.MsgID, cp.GetMsgID()) + s.Equal(checkpoint.ChannelName, cp.GetChannelName()) + s.Equal(checkpoint.Timestamp, cp.GetTimestamp()) + }). + Return(merr.Status(nil), nil) + + err := s.broker.UpdateChannelCheckpoint(ctx, channelName, checkpoint) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + err := s.broker.UpdateChannelCheckpoint(ctx, channelName, checkpoint) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().UpdateChannelCheckpoint(mock.Anything, mock.Anything). + Return(merr.Status(errors.New("mock")), nil) + + err := s.broker.UpdateChannelCheckpoint(ctx, channelName, checkpoint) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestSaveBinlogPaths() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := &datapb.SaveBinlogPathsRequest{ + Channel: "dml_0", + } + + s.Run("normal_case", func() { + s.dc.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.SaveBinlogPathsRequest, _ ...grpc.CallOption) { + s.Equal("dml_0", req.GetChannel()) + }). + Return(merr.Status(nil), nil) + err := s.broker.SaveBinlogPaths(ctx, req) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + err := s.broker.SaveBinlogPaths(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().SaveBinlogPaths(mock.Anything, mock.Anything). + Return(merr.Status(errors.New("mock")), nil) + err := s.broker.SaveBinlogPaths(ctx, req) + s.Error(err) + s.resetMock() + }) +} + +func (s *dataCoordSuite) TestDropVirtualChannel() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req := &datapb.DropVirtualChannelRequest{ + ChannelName: "dml_0", + } + + s.Run("normal_case", func() { + s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *datapb.DropVirtualChannelRequest, _ ...grpc.CallOption) { + s.Equal("dml_0", req.GetChannelName()) + }). + Return(&datapb.DropVirtualChannelResponse{Status: merr.Status(nil)}, nil) + err := s.broker.DropVirtualChannel(ctx, req) + s.NoError(err) + s.resetMock() + }) + + s.Run("datacoord_return_error", func() { + s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + err := s.broker.DropVirtualChannel(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_failure_status", func() { + s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). + Return(&datapb.DropVirtualChannelResponse{Status: merr.Status(errors.New("mock"))}, nil) + err := s.broker.DropVirtualChannel(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("datacoord_return_legacy_MetaFailed", func() { + s.dc.EXPECT().DropVirtualChannel(mock.Anything, mock.Anything). + Return(&datapb.DropVirtualChannelResponse{Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_MetaFailed}}, nil) + err := s.broker.DropVirtualChannel(ctx, req) + s.Error(err) + s.ErrorIs(err, merr.ErrChannelNotFound) + s.resetMock() + }) +} + +func TestDataCoordBroker(t *testing.T) { + suite.Run(t, new(dataCoordSuite)) +} diff --git a/internal/datanode/broker/mock_broker.go b/internal/datanode/broker/mock_broker.go new file mode 100644 index 0000000000..706884fdd2 --- /dev/null +++ b/internal/datanode/broker/mock_broker.go @@ -0,0 +1,555 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package broker + +import ( + context "context" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + datapb "github.com/milvus-io/milvus/internal/proto/datapb" + + mock "github.com/stretchr/testify/mock" + + msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + + rootcoordpb "github.com/milvus-io/milvus/internal/proto/rootcoordpb" +) + +// MockBroker is an autogenerated mock type for the Broker type +type MockBroker struct { + mock.Mock +} + +type MockBroker_Expecter struct { + mock *mock.Mock +} + +func (_m *MockBroker) EXPECT() *MockBroker_Expecter { + return &MockBroker_Expecter{mock: &_m.Mock} +} + +// AllocTimestamp provides a mock function with given fields: ctx, num +func (_m *MockBroker) AllocTimestamp(ctx context.Context, num uint32) (uint64, uint32, error) { + ret := _m.Called(ctx, num) + + var r0 uint64 + var r1 uint32 + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, uint32) (uint64, uint32, error)); ok { + return rf(ctx, num) + } + if rf, ok := ret.Get(0).(func(context.Context, uint32) uint64); ok { + r0 = rf(ctx, num) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(context.Context, uint32) uint32); ok { + r1 = rf(ctx, num) + } else { + r1 = ret.Get(1).(uint32) + } + + if rf, ok := ret.Get(2).(func(context.Context, uint32) error); ok { + r2 = rf(ctx, num) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockBroker_AllocTimestamp_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllocTimestamp' +type MockBroker_AllocTimestamp_Call struct { + *mock.Call +} + +// AllocTimestamp is a helper method to define mock.On call +// - ctx context.Context +// - num uint32 +func (_e *MockBroker_Expecter) AllocTimestamp(ctx interface{}, num interface{}) *MockBroker_AllocTimestamp_Call { + return &MockBroker_AllocTimestamp_Call{Call: _e.mock.On("AllocTimestamp", ctx, num)} +} + +func (_c *MockBroker_AllocTimestamp_Call) Run(run func(ctx context.Context, num uint32)) *MockBroker_AllocTimestamp_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uint32)) + }) + return _c +} + +func (_c *MockBroker_AllocTimestamp_Call) Return(ts uint64, count uint32, err error) *MockBroker_AllocTimestamp_Call { + _c.Call.Return(ts, count, err) + return _c +} + +func (_c *MockBroker_AllocTimestamp_Call) RunAndReturn(run func(context.Context, uint32) (uint64, uint32, error)) *MockBroker_AllocTimestamp_Call { + _c.Call.Return(run) + return _c +} + +// AssignSegmentID provides a mock function with given fields: ctx, reqs +func (_m *MockBroker) AssignSegmentID(ctx context.Context, reqs ...*datapb.SegmentIDRequest) ([]int64, error) { + _va := make([]interface{}, len(reqs)) + for _i := range reqs { + _va[_i] = reqs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 []int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ...*datapb.SegmentIDRequest) ([]int64, error)); ok { + return rf(ctx, reqs...) + } + if rf, ok := ret.Get(0).(func(context.Context, ...*datapb.SegmentIDRequest) []int64); ok { + r0 = rf(ctx, reqs...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int64) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ...*datapb.SegmentIDRequest) error); ok { + r1 = rf(ctx, reqs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_AssignSegmentID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignSegmentID' +type MockBroker_AssignSegmentID_Call struct { + *mock.Call +} + +// AssignSegmentID is a helper method to define mock.On call +// - ctx context.Context +// - reqs ...*datapb.SegmentIDRequest +func (_e *MockBroker_Expecter) AssignSegmentID(ctx interface{}, reqs ...interface{}) *MockBroker_AssignSegmentID_Call { + return &MockBroker_AssignSegmentID_Call{Call: _e.mock.On("AssignSegmentID", + append([]interface{}{ctx}, reqs...)...)} +} + +func (_c *MockBroker_AssignSegmentID_Call) Run(run func(ctx context.Context, reqs ...*datapb.SegmentIDRequest)) *MockBroker_AssignSegmentID_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]*datapb.SegmentIDRequest, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(*datapb.SegmentIDRequest) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockBroker_AssignSegmentID_Call) Return(_a0 []int64, _a1 error) *MockBroker_AssignSegmentID_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_AssignSegmentID_Call) RunAndReturn(run func(context.Context, ...*datapb.SegmentIDRequest) ([]int64, error)) *MockBroker_AssignSegmentID_Call { + _c.Call.Return(run) + return _c +} + +// DescribeCollection provides a mock function with given fields: ctx, collectionID, ts +func (_m *MockBroker) DescribeCollection(ctx context.Context, collectionID int64, ts uint64) (*milvuspb.DescribeCollectionResponse, error) { + ret := _m.Called(ctx, collectionID, ts) + + var r0 *milvuspb.DescribeCollectionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) (*milvuspb.DescribeCollectionResponse, error)); ok { + return rf(ctx, collectionID, ts) + } + if rf, ok := ret.Get(0).(func(context.Context, int64, uint64) *milvuspb.DescribeCollectionResponse); ok { + r0 = rf(ctx, collectionID, ts) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.DescribeCollectionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int64, uint64) error); ok { + r1 = rf(ctx, collectionID, ts) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_DescribeCollection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DescribeCollection' +type MockBroker_DescribeCollection_Call struct { + *mock.Call +} + +// DescribeCollection is a helper method to define mock.On call +// - ctx context.Context +// - collectionID int64 +// - ts uint64 +func (_e *MockBroker_Expecter) DescribeCollection(ctx interface{}, collectionID interface{}, ts interface{}) *MockBroker_DescribeCollection_Call { + return &MockBroker_DescribeCollection_Call{Call: _e.mock.On("DescribeCollection", ctx, collectionID, ts)} +} + +func (_c *MockBroker_DescribeCollection_Call) Run(run func(ctx context.Context, collectionID int64, ts uint64)) *MockBroker_DescribeCollection_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64), args[2].(uint64)) + }) + return _c +} + +func (_c *MockBroker_DescribeCollection_Call) Return(_a0 *milvuspb.DescribeCollectionResponse, _a1 error) *MockBroker_DescribeCollection_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_DescribeCollection_Call) RunAndReturn(run func(context.Context, int64, uint64) (*milvuspb.DescribeCollectionResponse, error)) *MockBroker_DescribeCollection_Call { + _c.Call.Return(run) + return _c +} + +// DropVirtualChannel provides a mock function with given fields: ctx, req +func (_m *MockBroker) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.DropVirtualChannelRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_DropVirtualChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DropVirtualChannel' +type MockBroker_DropVirtualChannel_Call struct { + *mock.Call +} + +// DropVirtualChannel is a helper method to define mock.On call +// - ctx context.Context +// - req *datapb.DropVirtualChannelRequest +func (_e *MockBroker_Expecter) DropVirtualChannel(ctx interface{}, req interface{}) *MockBroker_DropVirtualChannel_Call { + return &MockBroker_DropVirtualChannel_Call{Call: _e.mock.On("DropVirtualChannel", ctx, req)} +} + +func (_c *MockBroker_DropVirtualChannel_Call) Run(run func(ctx context.Context, req *datapb.DropVirtualChannelRequest)) *MockBroker_DropVirtualChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.DropVirtualChannelRequest)) + }) + return _c +} + +func (_c *MockBroker_DropVirtualChannel_Call) Return(_a0 error) *MockBroker_DropVirtualChannel_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_DropVirtualChannel_Call) RunAndReturn(run func(context.Context, *datapb.DropVirtualChannelRequest) error) *MockBroker_DropVirtualChannel_Call { + _c.Call.Return(run) + return _c +} + +// GetSegmentInfo provides a mock function with given fields: ctx, segmentIDs +func (_m *MockBroker) GetSegmentInfo(ctx context.Context, segmentIDs []int64) ([]*datapb.SegmentInfo, error) { + ret := _m.Called(ctx, segmentIDs) + + var r0 []*datapb.SegmentInfo + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []int64) ([]*datapb.SegmentInfo, error)); ok { + return rf(ctx, segmentIDs) + } + if rf, ok := ret.Get(0).(func(context.Context, []int64) []*datapb.SegmentInfo); ok { + r0 = rf(ctx, segmentIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*datapb.SegmentInfo) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []int64) error); ok { + r1 = rf(ctx, segmentIDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_GetSegmentInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSegmentInfo' +type MockBroker_GetSegmentInfo_Call struct { + *mock.Call +} + +// GetSegmentInfo is a helper method to define mock.On call +// - ctx context.Context +// - segmentIDs []int64 +func (_e *MockBroker_Expecter) GetSegmentInfo(ctx interface{}, segmentIDs interface{}) *MockBroker_GetSegmentInfo_Call { + return &MockBroker_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", ctx, segmentIDs)} +} + +func (_c *MockBroker_GetSegmentInfo_Call) Run(run func(ctx context.Context, segmentIDs []int64)) *MockBroker_GetSegmentInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]int64)) + }) + return _c +} + +func (_c *MockBroker_GetSegmentInfo_Call) Return(_a0 []*datapb.SegmentInfo, _a1 error) *MockBroker_GetSegmentInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_GetSegmentInfo_Call) RunAndReturn(run func(context.Context, []int64) ([]*datapb.SegmentInfo, error)) *MockBroker_GetSegmentInfo_Call { + _c.Call.Return(run) + return _c +} + +// ReportImport provides a mock function with given fields: ctx, req +func (_m *MockBroker) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *rootcoordpb.ImportResult) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_ReportImport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportImport' +type MockBroker_ReportImport_Call struct { + *mock.Call +} + +// ReportImport is a helper method to define mock.On call +// - ctx context.Context +// - req *rootcoordpb.ImportResult +func (_e *MockBroker_Expecter) ReportImport(ctx interface{}, req interface{}) *MockBroker_ReportImport_Call { + return &MockBroker_ReportImport_Call{Call: _e.mock.On("ReportImport", ctx, req)} +} + +func (_c *MockBroker_ReportImport_Call) Run(run func(ctx context.Context, req *rootcoordpb.ImportResult)) *MockBroker_ReportImport_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*rootcoordpb.ImportResult)) + }) + return _c +} + +func (_c *MockBroker_ReportImport_Call) Return(_a0 error) *MockBroker_ReportImport_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_ReportImport_Call) RunAndReturn(run func(context.Context, *rootcoordpb.ImportResult) error) *MockBroker_ReportImport_Call { + _c.Call.Return(run) + return _c +} + +// ReportTimeTick provides a mock function with given fields: ctx, msgs +func (_m *MockBroker) ReportTimeTick(ctx context.Context, msgs []*msgpb.DataNodeTtMsg) error { + ret := _m.Called(ctx, msgs) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []*msgpb.DataNodeTtMsg) error); ok { + r0 = rf(ctx, msgs) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_ReportTimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReportTimeTick' +type MockBroker_ReportTimeTick_Call struct { + *mock.Call +} + +// ReportTimeTick is a helper method to define mock.On call +// - ctx context.Context +// - msgs []*msgpb.DataNodeTtMsg +func (_e *MockBroker_Expecter) ReportTimeTick(ctx interface{}, msgs interface{}) *MockBroker_ReportTimeTick_Call { + return &MockBroker_ReportTimeTick_Call{Call: _e.mock.On("ReportTimeTick", ctx, msgs)} +} + +func (_c *MockBroker_ReportTimeTick_Call) Run(run func(ctx context.Context, msgs []*msgpb.DataNodeTtMsg)) *MockBroker_ReportTimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]*msgpb.DataNodeTtMsg)) + }) + return _c +} + +func (_c *MockBroker_ReportTimeTick_Call) Return(_a0 error) *MockBroker_ReportTimeTick_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_ReportTimeTick_Call) RunAndReturn(run func(context.Context, []*msgpb.DataNodeTtMsg) error) *MockBroker_ReportTimeTick_Call { + _c.Call.Return(run) + return _c +} + +// SaveBinlogPaths provides a mock function with given fields: ctx, req +func (_m *MockBroker) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *datapb.SaveBinlogPathsRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_SaveBinlogPaths_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveBinlogPaths' +type MockBroker_SaveBinlogPaths_Call struct { + *mock.Call +} + +// SaveBinlogPaths is a helper method to define mock.On call +// - ctx context.Context +// - req *datapb.SaveBinlogPathsRequest +func (_e *MockBroker_Expecter) SaveBinlogPaths(ctx interface{}, req interface{}) *MockBroker_SaveBinlogPaths_Call { + return &MockBroker_SaveBinlogPaths_Call{Call: _e.mock.On("SaveBinlogPaths", ctx, req)} +} + +func (_c *MockBroker_SaveBinlogPaths_Call) Run(run func(ctx context.Context, req *datapb.SaveBinlogPathsRequest)) *MockBroker_SaveBinlogPaths_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*datapb.SaveBinlogPathsRequest)) + }) + return _c +} + +func (_c *MockBroker_SaveBinlogPaths_Call) Return(_a0 error) *MockBroker_SaveBinlogPaths_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_SaveBinlogPaths_Call) RunAndReturn(run func(context.Context, *datapb.SaveBinlogPathsRequest) error) *MockBroker_SaveBinlogPaths_Call { + _c.Call.Return(run) + return _c +} + +// ShowPartitions provides a mock function with given fields: ctx, dbName, collectionName +func (_m *MockBroker) ShowPartitions(ctx context.Context, dbName string, collectionName string) (map[string]int64, error) { + ret := _m.Called(ctx, dbName, collectionName) + + var r0 map[string]int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (map[string]int64, error)); ok { + return rf(ctx, dbName, collectionName) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) map[string]int64); ok { + r0 = rf(ctx, dbName, collectionName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]int64) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, dbName, collectionName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockBroker_ShowPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ShowPartitions' +type MockBroker_ShowPartitions_Call struct { + *mock.Call +} + +// ShowPartitions is a helper method to define mock.On call +// - ctx context.Context +// - dbName string +// - collectionName string +func (_e *MockBroker_Expecter) ShowPartitions(ctx interface{}, dbName interface{}, collectionName interface{}) *MockBroker_ShowPartitions_Call { + return &MockBroker_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", ctx, dbName, collectionName)} +} + +func (_c *MockBroker_ShowPartitions_Call) Run(run func(ctx context.Context, dbName string, collectionName string)) *MockBroker_ShowPartitions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockBroker_ShowPartitions_Call) Return(_a0 map[string]int64, _a1 error) *MockBroker_ShowPartitions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockBroker_ShowPartitions_Call) RunAndReturn(run func(context.Context, string, string) (map[string]int64, error)) *MockBroker_ShowPartitions_Call { + _c.Call.Return(run) + return _c +} + +// UpdateChannelCheckpoint provides a mock function with given fields: ctx, channelName, cp +func (_m *MockBroker) UpdateChannelCheckpoint(ctx context.Context, channelName string, cp *msgpb.MsgPosition) error { + ret := _m.Called(ctx, channelName, cp) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.MsgPosition) error); ok { + r0 = rf(ctx, channelName, cp) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroker_UpdateChannelCheckpoint_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateChannelCheckpoint' +type MockBroker_UpdateChannelCheckpoint_Call struct { + *mock.Call +} + +// UpdateChannelCheckpoint is a helper method to define mock.On call +// - ctx context.Context +// - channelName string +// - cp *msgpb.MsgPosition +func (_e *MockBroker_Expecter) UpdateChannelCheckpoint(ctx interface{}, channelName interface{}, cp interface{}) *MockBroker_UpdateChannelCheckpoint_Call { + return &MockBroker_UpdateChannelCheckpoint_Call{Call: _e.mock.On("UpdateChannelCheckpoint", ctx, channelName, cp)} +} + +func (_c *MockBroker_UpdateChannelCheckpoint_Call) Run(run func(ctx context.Context, channelName string, cp *msgpb.MsgPosition)) *MockBroker_UpdateChannelCheckpoint_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.MsgPosition)) + }) + return _c +} + +func (_c *MockBroker_UpdateChannelCheckpoint_Call) Return(_a0 error) *MockBroker_UpdateChannelCheckpoint_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroker_UpdateChannelCheckpoint_Call) RunAndReturn(run func(context.Context, string, *msgpb.MsgPosition) error) *MockBroker_UpdateChannelCheckpoint_Call { + _c.Call.Return(run) + return _c +} + +// NewMockBroker creates a new instance of MockBroker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockBroker(t interface { + mock.TestingT + Cleanup(func()) +}) *MockBroker { + mock := &MockBroker{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/datanode/broker/rootcoord.go b/internal/datanode/broker/rootcoord.go new file mode 100644 index 0000000000..47129f8487 --- /dev/null +++ b/internal/datanode/broker/rootcoord.go @@ -0,0 +1,114 @@ +package broker + +import ( + "context" + "fmt" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type rootCoordBroker struct { + client types.RootCoordClient +} + +func (rc *rootCoordBroker) DescribeCollection(ctx context.Context, collectionID typeutil.UniqueID, timestamp typeutil.Timestamp) (*milvuspb.DescribeCollectionResponse, error) { + log := log.Ctx(ctx).With( + zap.Int64("collectionID", collectionID), + zap.Uint64("timestamp", timestamp), + ) + req := &milvuspb.DescribeCollectionRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + // please do not specify the collection name alone after database feature. + CollectionID: collectionID, + TimeStamp: timestamp, + } + + resp, err := rc.client.DescribeCollectionInternal(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to DescribeCollectionInternal", zap.Error(err)) + return nil, err + } + + return resp, nil +} + +func (rc *rootCoordBroker) ShowPartitions(ctx context.Context, dbName, collectionName string) (map[string]int64, error) { + req := &milvuspb.ShowPartitionsRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), + ), + DbName: dbName, + CollectionName: collectionName, + } + + log := log.Ctx(ctx).With( + zap.String("dbName", dbName), + zap.String("collectionName", collectionName), + ) + + resp, err := rc.client.ShowPartitions(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to get partitions of collection", zap.Error(err)) + return nil, err + } + + partitionNames := resp.GetPartitionNames() + partitionIDs := resp.GetPartitionIDs() + if len(partitionNames) != len(partitionIDs) { + log.Warn("partition names and ids are unequal", + zap.Int("partitionNameNumber", len(partitionNames)), + zap.Int("partitionIDNumber", len(partitionIDs))) + return nil, fmt.Errorf("partition names and ids are unequal, number of names: %d, number of ids: %d", + len(partitionNames), len(partitionIDs)) + } + + partitions := make(map[string]int64) + for i := 0; i < len(partitionNames); i++ { + partitions[partitionNames[i]] = partitionIDs[i] + } + + return partitions, nil +} + +func (rc *rootCoordBroker) AllocTimestamp(ctx context.Context, num uint32) (uint64, uint32, error) { + log := log.Ctx(ctx) + + req := &rootcoordpb.AllocTimestampRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_RequestTSO), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + Count: num, + } + + resp, err := rc.client.AllocTimestamp(ctx, req) + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to AllocTimestamp", zap.Error(err)) + return 0, 0, err + } + return resp.GetTimestamp(), resp.GetCount(), nil +} + +func (rc *rootCoordBroker) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) error { + log := log.Ctx(ctx) + resp, err := rc.client.ReportImport(ctx, req) + + if err := merr.CheckRPCCall(resp, err); err != nil { + log.Warn("failed to ReportImport", zap.Error(err)) + return err + } + return nil +} diff --git a/internal/datanode/broker/rootcoord_test.go b/internal/datanode/broker/rootcoord_test.go new file mode 100644 index 0000000000..e08279fe2f --- /dev/null +++ b/internal/datanode/broker/rootcoord_test.go @@ -0,0 +1,241 @@ +package broker + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/tsoutil" +) + +type rootCoordSuite struct { + suite.Suite + + rc *mocks.MockRootCoordClient + broker Broker +} + +func (s *rootCoordSuite) SetupSuite() { + paramtable.Init() +} + +func (s *rootCoordSuite) SetupTest() { + s.rc = mocks.NewMockRootCoordClient(s.T()) + s.broker = NewCoordBroker(s.rc, nil) +} + +func (s *rootCoordSuite) resetMock() { + s.rc.AssertExpectations(s.T()) + s.rc.ExpectedCalls = nil +} + +func (s *rootCoordSuite) TestDescribeCollection() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + collectionID := int64(100) + timestamp := tsoutil.ComposeTSByTime(time.Now(), 0) + + s.Run("normal_case", func() { + collName := "test_collection_name" + + s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) { + s.Equal(collectionID, req.GetCollectionID()) + s.Equal(timestamp, req.GetTimeStamp()) + }). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(nil), + CollectionID: collectionID, + CollectionName: collName, + }, nil) + + resp, err := s.broker.DescribeCollection(ctx, collectionID, timestamp) + s.NoError(err) + s.Equal(collectionID, resp.GetCollectionID()) + s.Equal(collName, resp.GetCollectionName()) + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, err := s.broker.DescribeCollection(ctx, collectionID, timestamp) + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_failure_status", func() { + s.rc.EXPECT().DescribeCollectionInternal(mock.Anything, mock.Anything). + Return(&milvuspb.DescribeCollectionResponse{ + Status: merr.Status(errors.New("mocked")), + }, nil) + + _, err := s.broker.DescribeCollection(ctx, collectionID, timestamp) + s.Error(err) + s.resetMock() + }) +} + +func (s *rootCoordSuite) TestShowPartitions() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dbName := "defaultDB" + collName := "testCollection" + + s.Run("normal_case", func() { + partitions := map[string]int64{ + "part1": 1001, + "part2": 1002, + "part3": 1003, + } + + names := lo.Keys(partitions) + ids := lo.Map(names, func(name string, _ int) int64 { + return partitions[name] + }) + + s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *milvuspb.ShowPartitionsRequest, _ ...grpc.CallOption) { + s.Equal(dbName, req.GetDbName()) + s.Equal(collName, req.GetCollectionName()) + }). + Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Status(nil), + PartitionIDs: ids, + PartitionNames: names, + }, nil) + partNameIDs, err := s.broker.ShowPartitions(ctx, dbName, collName) + s.NoError(err) + s.Equal(len(partitions), len(partNameIDs)) + for name, id := range partitions { + result, ok := partNameIDs[name] + s.True(ok) + s.Equal(id, result) + } + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + _, err := s.broker.ShowPartitions(ctx, dbName, collName) + s.Error(err) + s.resetMock() + }) + + s.Run("partition_id_name_not_match", func() { + s.rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything). + Return(&milvuspb.ShowPartitionsResponse{ + Status: merr.Status(nil), + PartitionIDs: []int64{1, 2}, + PartitionNames: []string{"part1"}, + }, nil) + + _, err := s.broker.ShowPartitions(ctx, dbName, collName) + s.Error(err) + s.resetMock() + }) +} + +func (s *rootCoordSuite) TestAllocTimestamp() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("normal_case", func() { + num := rand.Intn(10) + 1 + ts := tsoutil.ComposeTSByTime(time.Now(), 0) + s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *rootcoordpb.AllocTimestampRequest, _ ...grpc.CallOption) { + s.EqualValues(num, req.GetCount()) + }). + Return(&rootcoordpb.AllocTimestampResponse{ + Status: merr.Status(nil), + Timestamp: ts, + Count: uint32(num), + }, nil) + + timestamp, cnt, err := s.broker.AllocTimestamp(ctx, uint32(num)) + s.NoError(err) + s.Equal(ts, timestamp) + s.EqualValues(num, cnt) + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + _, _, err := s.broker.AllocTimestamp(ctx, 1) + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_failure_status", func() { + s.rc.EXPECT().AllocTimestamp(mock.Anything, mock.Anything). + Return(&rootcoordpb.AllocTimestampResponse{Status: merr.Status(errors.New("mock"))}, nil) + _, _, err := s.broker.AllocTimestamp(ctx, 1) + s.Error(err) + s.resetMock() + }) +} + +func (s *rootCoordSuite) TestReportImport() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + taskID := rand.Int63() + + req := &rootcoordpb.ImportResult{ + Status: merr.Status(nil), + TaskId: taskID, + } + + s.Run("normal_case", func() { + s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything). + Run(func(_ context.Context, req *rootcoordpb.ImportResult, _ ...grpc.CallOption) { + s.Equal(taskID, req.GetTaskId()) + }). + Return(merr.Status(nil), nil) + + err := s.broker.ReportImport(ctx, req) + s.NoError(err) + s.resetMock() + }) + + s.Run("rootcoord_return_error", func() { + s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything). + Return(nil, errors.New("mock")) + + err := s.broker.ReportImport(ctx, req) + s.Error(err) + s.resetMock() + }) + + s.Run("rootcoord_return_failure_status", func() { + s.rc.EXPECT().ReportImport(mock.Anything, mock.Anything). + Return(merr.Status(errors.New("mock")), nil) + + err := s.broker.ReportImport(ctx, req) + s.Error(err) + s.resetMock() + }) +} + +func TestRootCoordBroker(t *testing.T) { + suite.Run(t, new(rootCoordSuite)) +}