diff --git a/internal/distributed/querycoord/client/client.go b/internal/distributed/querycoord/client/client.go index d0fa7ebd09..b8a7a345da 100644 --- a/internal/distributed/querycoord/client/client.go +++ b/internal/distributed/querycoord/client/client.go @@ -402,3 +402,102 @@ func (c *Client) DeactivateChecker(ctx context.Context, req *querypb.DeactivateC return client.DeactivateChecker(ctx, req) }) } + +func (c *Client) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) { + req = typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*querypb.ListQueryNodeResponse, error) { + return client.ListQueryNode(ctx, req) + }) +} + +func (c *Client) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) { + req = typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*querypb.GetQueryNodeDistributionResponse, error) { + return client.GetQueryNodeDistribution(ctx, req) + }) +} + +func (c *Client) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + req = typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) { + return client.SuspendBalance(ctx, req) + }) +} + +func (c *Client) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + req = typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) { + return client.ResumeBalance(ctx, req) + }) +} + +func (c *Client) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + req = typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) { + return client.SuspendNode(ctx, req) + }) +} + +func (c *Client) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + req = typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) { + return client.ResumeNode(ctx, req) + }) +} + +func (c *Client) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + req = 
typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) { + return client.TransferSegment(ctx, req) + }) +} + +func (c *Client) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + req = typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) { + return client.TransferChannel(ctx, req) + }) +} + +func (c *Client) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + req = typeutil.Clone(req) + commonpbutil.UpdateMsgBase( + req.GetBase(), + commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.grpcClient.GetNodeID())), + ) + return wrapGrpcCall(ctx, c, func(client querypb.QueryCoordClient) (*commonpb.Status, error) { + return client.CheckQueryNodeDistribution(ctx, req) + }) +} diff --git a/internal/distributed/querycoord/client/client_test.go b/internal/distributed/querycoord/client/client_test.go index 8b9a83665a..0b14ed48b2 100644 --- a/internal/distributed/querycoord/client/client_test.go +++ b/internal/distributed/querycoord/client/client_test.go @@ -158,6 +158,33 @@ func Test_NewClient(t *testing.T) { r30, err := client.DeactivateChecker(ctx, nil) retCheck(retNotNil, r30, err) + + r31, err := client.ListQueryNode(ctx, nil) + retCheck(retNotNil, r31, err) + + r32, err := client.GetQueryNodeDistribution(ctx, nil) + retCheck(retNotNil, r32, err) + + r33, err := client.SuspendBalance(ctx, nil) + retCheck(retNotNil, r33, err) + + r34, err := client.ResumeBalance(ctx, nil) + retCheck(retNotNil, r34, err) + + r35, err := client.SuspendNode(ctx, nil) + retCheck(retNotNil, r35, err) + + r36, err := client.ResumeNode(ctx, nil) + retCheck(retNotNil, r36, err) + + r37, err := client.TransferSegment(ctx, nil) + retCheck(retNotNil, r37, err) + + r38, err := client.TransferChannel(ctx, nil) + retCheck(retNotNil, r38, err) + + r39, err := client.CheckQueryNodeDistribution(ctx, nil) + retCheck(retNotNil, r39, err) } client.(*Client).grpcClient = &mock.GRPCClientBase[querypb.QueryCoordClient]{ diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go index f81546fa45..5fe344b6e3 100644 --- a/internal/distributed/querycoord/service.go +++ b/internal/distributed/querycoord/service.go @@ -446,3 +446,39 @@ func (s *Server) DeactivateChecker(ctx context.Context, req *querypb.DeactivateC func (s *Server) ListCheckers(ctx context.Context, req *querypb.ListCheckersRequest) (*querypb.ListCheckersResponse, error) { return s.queryCoord.ListCheckers(ctx, req) } + +func (s *Server) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error) { + return s.queryCoord.ListQueryNode(ctx, req) +} + +func (s *Server) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error) { + return s.queryCoord.GetQueryNodeDistribution(ctx, req) +} + +func (s *Server) SuspendBalance(ctx context.Context, req 
*querypb.SuspendBalanceRequest) (*commonpb.Status, error) { + return s.queryCoord.SuspendBalance(ctx, req) +} + +func (s *Server) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest) (*commonpb.Status, error) { + return s.queryCoord.ResumeBalance(ctx, req) +} + +func (s *Server) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest) (*commonpb.Status, error) { + return s.queryCoord.SuspendNode(ctx, req) +} + +func (s *Server) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest) (*commonpb.Status, error) { + return s.queryCoord.ResumeNode(ctx, req) +} + +func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest) (*commonpb.Status, error) { + return s.queryCoord.TransferSegment(ctx, req) +} + +func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest) (*commonpb.Status, error) { + return s.queryCoord.TransferChannel(ctx, req) +} + +func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error) { + return s.queryCoord.CheckQueryNodeDistribution(ctx, req) +} diff --git a/internal/distributed/querycoord/service_test.go b/internal/distributed/querycoord/service_test.go index 7c4f26f0b1..08ce7f7d77 100644 --- a/internal/distributed/querycoord/service_test.go +++ b/internal/distributed/querycoord/service_test.go @@ -283,6 +283,78 @@ func Test_NewServer(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) }) + t.Run("ListQueryNode", func(t *testing.T) { + req := &querypb.ListQueryNodeRequest{} + mqc.EXPECT().ListQueryNode(mock.Anything, req).Return(&querypb.ListQueryNodeResponse{Status: merr.Success()}, nil) + resp, err := server.ListQueryNode(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + t.Run("GetQueryNodeDistribution", func(t *testing.T) { + req := &querypb.GetQueryNodeDistributionRequest{} + mqc.EXPECT().GetQueryNodeDistribution(mock.Anything, req).Return(&querypb.GetQueryNodeDistributionResponse{Status: merr.Success()}, nil) + resp, err := server.GetQueryNodeDistribution(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + t.Run("SuspendBalance", func(t *testing.T) { + req := &querypb.SuspendBalanceRequest{} + mqc.EXPECT().SuspendBalance(mock.Anything, req).Return(merr.Success(), nil) + resp, err := server.SuspendBalance(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + + t.Run("ResumeBalance", func(t *testing.T) { + req := &querypb.ResumeBalanceRequest{} + mqc.EXPECT().ResumeBalance(mock.Anything, req).Return(merr.Success(), nil) + resp, err := server.ResumeBalance(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + + t.Run("SuspendNode", func(t *testing.T) { + req := &querypb.SuspendNodeRequest{} + mqc.EXPECT().SuspendNode(mock.Anything, req).Return(merr.Success(), nil) + resp, err := server.SuspendNode(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + + t.Run("ResumeNode", func(t *testing.T) { + req := &querypb.ResumeNodeRequest{} + mqc.EXPECT().ResumeNode(mock.Anything, req).Return(merr.Success(), nil) + resp, err := server.ResumeNode(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + + t.Run("TransferSegment", func(t 
*testing.T) { + req := &querypb.TransferSegmentRequest{} + mqc.EXPECT().TransferSegment(mock.Anything, req).Return(merr.Success(), nil) + resp, err := server.TransferSegment(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + + t.Run("TransferChannel", func(t *testing.T) { + req := &querypb.TransferChannelRequest{} + mqc.EXPECT().TransferChannel(mock.Anything, req).Return(merr.Success(), nil) + resp, err := server.TransferChannel(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + + t.Run("CheckQueryNodeDistribution", func(t *testing.T) { + req := &querypb.CheckQueryNodeDistributionRequest{} + mqc.EXPECT().CheckQueryNodeDistribution(mock.Anything, req).Return(merr.Success(), nil) + resp, err := server.CheckQueryNodeDistribution(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + }) + err = server.Stop() assert.NoError(t, err) } diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index e53d07ad91..5b7dbba1f1 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -74,6 +74,10 @@ func (s *Server) GetStatistics(ctx context.Context, request *querypb.GetStatisti return s.querynode.GetStatistics(ctx, request) } +func (s *Server) GetQueryNode() types.QueryNodeComponent { + return s.querynode +} + // NewServer create a new QueryNode grpc server. func NewServer(ctx context.Context, factory dependency.Factory) (*Server, error) { ctx1, cancel := context.WithCancel(ctx) diff --git a/internal/mocks/mock_querycoord.go b/internal/mocks/mock_querycoord.go index 2fd90e2095..d9099ac22c 100644 --- a/internal/mocks/mock_querycoord.go +++ b/internal/mocks/mock_querycoord.go @@ -144,6 +144,61 @@ func (_c *MockQueryCoord_CheckHealth_Call) RunAndReturn(run func(context.Context return _c } +// CheckQueryNodeDistribution provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) CheckQueryNodeDistribution(_a0 context.Context, _a1 *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoord_CheckQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckQueryNodeDistribution' +type MockQueryCoord_CheckQueryNodeDistribution_Call struct { + *mock.Call +} + +// CheckQueryNodeDistribution is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.CheckQueryNodeDistributionRequest +func (_e *MockQueryCoord_Expecter) CheckQueryNodeDistribution(_a0 interface{}, _a1 interface{}) *MockQueryCoord_CheckQueryNodeDistribution_Call { + return &MockQueryCoord_CheckQueryNodeDistribution_Call{Call: _e.mock.On("CheckQueryNodeDistribution", _a0, _a1)} +} + +func (_c *MockQueryCoord_CheckQueryNodeDistribution_Call) Run(run 
func(_a0 context.Context, _a1 *querypb.CheckQueryNodeDistributionRequest)) *MockQueryCoord_CheckQueryNodeDistribution_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.CheckQueryNodeDistributionRequest)) + }) + return _c +} + +func (_c *MockQueryCoord_CheckQueryNodeDistribution_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_CheckQueryNodeDistribution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoord_CheckQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error)) *MockQueryCoord_CheckQueryNodeDistribution_Call { + _c.Call.Return(run) + return _c +} + // CreateResourceGroup provides a mock function with given fields: _a0, _a1 func (_m *MockQueryCoord) CreateResourceGroup(_a0 context.Context, _a1 *milvuspb.CreateResourceGroupRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) @@ -529,6 +584,61 @@ func (_c *MockQueryCoord_GetPartitionStates_Call) RunAndReturn(run func(context. return _c } +// GetQueryNodeDistribution provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) GetQueryNodeDistribution(_a0 context.Context, _a1 *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *querypb.GetQueryNodeDistributionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest) *querypb.GetQueryNodeDistributionResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.GetQueryNodeDistributionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetQueryNodeDistributionRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoord_GetQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQueryNodeDistribution' +type MockQueryCoord_GetQueryNodeDistribution_Call struct { + *mock.Call +} + +// GetQueryNodeDistribution is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.GetQueryNodeDistributionRequest +func (_e *MockQueryCoord_Expecter) GetQueryNodeDistribution(_a0 interface{}, _a1 interface{}) *MockQueryCoord_GetQueryNodeDistribution_Call { + return &MockQueryCoord_GetQueryNodeDistribution_Call{Call: _e.mock.On("GetQueryNodeDistribution", _a0, _a1)} +} + +func (_c *MockQueryCoord_GetQueryNodeDistribution_Call) Run(run func(_a0 context.Context, _a1 *querypb.GetQueryNodeDistributionRequest)) *MockQueryCoord_GetQueryNodeDistribution_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.GetQueryNodeDistributionRequest)) + }) + return _c +} + +func (_c *MockQueryCoord_GetQueryNodeDistribution_Call) Return(_a0 *querypb.GetQueryNodeDistributionResponse, _a1 error) *MockQueryCoord_GetQueryNodeDistribution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoord_GetQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error)) *MockQueryCoord_GetQueryNodeDistribution_Call { + _c.Call.Return(run) + return _c +} + // GetReplicas provides 
a mock function with given fields: _a0, _a1 func (_m *MockQueryCoord) GetReplicas(_a0 context.Context, _a1 *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) { ret := _m.Called(_a0, _a1) @@ -900,6 +1010,61 @@ func (_c *MockQueryCoord_ListCheckers_Call) RunAndReturn(run func(context.Contex return _c } +// ListQueryNode provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) ListQueryNode(_a0 context.Context, _a1 *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error) { + ret := _m.Called(_a0, _a1) + + var r0 *querypb.ListQueryNodeResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest) *querypb.ListQueryNodeResponse); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.ListQueryNodeResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ListQueryNodeRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoord_ListQueryNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListQueryNode' +type MockQueryCoord_ListQueryNode_Call struct { + *mock.Call +} + +// ListQueryNode is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.ListQueryNodeRequest +func (_e *MockQueryCoord_Expecter) ListQueryNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ListQueryNode_Call { + return &MockQueryCoord_ListQueryNode_Call{Call: _e.mock.On("ListQueryNode", _a0, _a1)} +} + +func (_c *MockQueryCoord_ListQueryNode_Call) Run(run func(_a0 context.Context, _a1 *querypb.ListQueryNodeRequest)) *MockQueryCoord_ListQueryNode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.ListQueryNodeRequest)) + }) + return _c +} + +func (_c *MockQueryCoord_ListQueryNode_Call) Return(_a0 *querypb.ListQueryNodeResponse, _a1 error) *MockQueryCoord_ListQueryNode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoord_ListQueryNode_Call) RunAndReturn(run func(context.Context, *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error)) *MockQueryCoord_ListQueryNode_Call { + _c.Call.Return(run) + return _c +} + // ListResourceGroups provides a mock function with given fields: _a0, _a1 func (_m *MockQueryCoord) ListResourceGroups(_a0 context.Context, _a1 *milvuspb.ListResourceGroupsRequest) (*milvuspb.ListResourceGroupsResponse, error) { ret := _m.Called(_a0, _a1) @@ -1271,6 +1436,116 @@ func (_c *MockQueryCoord_ReleasePartitions_Call) RunAndReturn(run func(context.C return _c } +// ResumeBalance provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) ResumeBalance(_a0 context.Context, _a1 *querypb.ResumeBalanceRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeBalanceRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = 
ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoord_ResumeBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeBalance' +type MockQueryCoord_ResumeBalance_Call struct { + *mock.Call +} + +// ResumeBalance is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.ResumeBalanceRequest +func (_e *MockQueryCoord_Expecter) ResumeBalance(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ResumeBalance_Call { + return &MockQueryCoord_ResumeBalance_Call{Call: _e.mock.On("ResumeBalance", _a0, _a1)} +} + +func (_c *MockQueryCoord_ResumeBalance_Call) Run(run func(_a0 context.Context, _a1 *querypb.ResumeBalanceRequest)) *MockQueryCoord_ResumeBalance_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.ResumeBalanceRequest)) + }) + return _c +} + +func (_c *MockQueryCoord_ResumeBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_ResumeBalance_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoord_ResumeBalance_Call) RunAndReturn(run func(context.Context, *querypb.ResumeBalanceRequest) (*commonpb.Status, error)) *MockQueryCoord_ResumeBalance_Call { + _c.Call.Return(run) + return _c +} + +// ResumeNode provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) ResumeNode(_a0 context.Context, _a1 *querypb.ResumeNodeRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeNodeRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoord_ResumeNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeNode' +type MockQueryCoord_ResumeNode_Call struct { + *mock.Call +} + +// ResumeNode is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.ResumeNodeRequest +func (_e *MockQueryCoord_Expecter) ResumeNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_ResumeNode_Call { + return &MockQueryCoord_ResumeNode_Call{Call: _e.mock.On("ResumeNode", _a0, _a1)} +} + +func (_c *MockQueryCoord_ResumeNode_Call) Run(run func(_a0 context.Context, _a1 *querypb.ResumeNodeRequest)) *MockQueryCoord_ResumeNode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.ResumeNodeRequest)) + }) + return _c +} + +func (_c *MockQueryCoord_ResumeNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_ResumeNode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoord_ResumeNode_Call) RunAndReturn(run func(context.Context, *querypb.ResumeNodeRequest) (*commonpb.Status, error)) *MockQueryCoord_ResumeNode_Call { + _c.Call.Return(run) + return _c +} + // SetAddress provides a mock function with given fields: address func (_m *MockQueryCoord) SetAddress(address string) { _m.Called(address) @@ -1734,6 +2009,116 @@ func (_c *MockQueryCoord_Stop_Call) RunAndReturn(run func() error) *MockQueryCoo return _c } +// SuspendBalance provides a mock function with given fields: _a0, _a1 +func (_m 
*MockQueryCoord) SuspendBalance(_a0 context.Context, _a1 *querypb.SuspendBalanceRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendBalanceRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoord_SuspendBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendBalance' +type MockQueryCoord_SuspendBalance_Call struct { + *mock.Call +} + +// SuspendBalance is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.SuspendBalanceRequest +func (_e *MockQueryCoord_Expecter) SuspendBalance(_a0 interface{}, _a1 interface{}) *MockQueryCoord_SuspendBalance_Call { + return &MockQueryCoord_SuspendBalance_Call{Call: _e.mock.On("SuspendBalance", _a0, _a1)} +} + +func (_c *MockQueryCoord_SuspendBalance_Call) Run(run func(_a0 context.Context, _a1 *querypb.SuspendBalanceRequest)) *MockQueryCoord_SuspendBalance_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.SuspendBalanceRequest)) + }) + return _c +} + +func (_c *MockQueryCoord_SuspendBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_SuspendBalance_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoord_SuspendBalance_Call) RunAndReturn(run func(context.Context, *querypb.SuspendBalanceRequest) (*commonpb.Status, error)) *MockQueryCoord_SuspendBalance_Call { + _c.Call.Return(run) + return _c +} + +// SuspendNode provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) SuspendNode(_a0 context.Context, _a1 *querypb.SuspendNodeRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendNodeRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoord_SuspendNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendNode' +type MockQueryCoord_SuspendNode_Call struct { + *mock.Call +} + +// SuspendNode is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.SuspendNodeRequest +func (_e *MockQueryCoord_Expecter) SuspendNode(_a0 interface{}, _a1 interface{}) *MockQueryCoord_SuspendNode_Call { + return &MockQueryCoord_SuspendNode_Call{Call: _e.mock.On("SuspendNode", _a0, _a1)} +} + +func (_c *MockQueryCoord_SuspendNode_Call) Run(run func(_a0 context.Context, _a1 *querypb.SuspendNodeRequest)) *MockQueryCoord_SuspendNode_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.SuspendNodeRequest)) + }) + return _c +} + +func (_c 
*MockQueryCoord_SuspendNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_SuspendNode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoord_SuspendNode_Call) RunAndReturn(run func(context.Context, *querypb.SuspendNodeRequest) (*commonpb.Status, error)) *MockQueryCoord_SuspendNode_Call { + _c.Call.Return(run) + return _c +} + // SyncNewCreatedPartition provides a mock function with given fields: _a0, _a1 func (_m *MockQueryCoord) SyncNewCreatedPartition(_a0 context.Context, _a1 *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) @@ -1789,6 +2174,61 @@ func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) RunAndReturn(run func(con return _c } +// TransferChannel provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) TransferChannel(_a0 context.Context, _a1 *querypb.TransferChannelRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferChannelRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoord_TransferChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferChannel' +type MockQueryCoord_TransferChannel_Call struct { + *mock.Call +} + +// TransferChannel is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.TransferChannelRequest +func (_e *MockQueryCoord_Expecter) TransferChannel(_a0 interface{}, _a1 interface{}) *MockQueryCoord_TransferChannel_Call { + return &MockQueryCoord_TransferChannel_Call{Call: _e.mock.On("TransferChannel", _a0, _a1)} +} + +func (_c *MockQueryCoord_TransferChannel_Call) Run(run func(_a0 context.Context, _a1 *querypb.TransferChannelRequest)) *MockQueryCoord_TransferChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.TransferChannelRequest)) + }) + return _c +} + +func (_c *MockQueryCoord_TransferChannel_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_TransferChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoord_TransferChannel_Call) RunAndReturn(run func(context.Context, *querypb.TransferChannelRequest) (*commonpb.Status, error)) *MockQueryCoord_TransferChannel_Call { + _c.Call.Return(run) + return _c +} + // TransferNode provides a mock function with given fields: _a0, _a1 func (_m *MockQueryCoord) TransferNode(_a0 context.Context, _a1 *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { ret := _m.Called(_a0, _a1) @@ -1899,6 +2339,61 @@ func (_c *MockQueryCoord_TransferReplica_Call) RunAndReturn(run func(context.Con return _c } +// TransferSegment provides a mock function with given fields: _a0, _a1 +func (_m *MockQueryCoord) TransferSegment(_a0 context.Context, _a1 *querypb.TransferSegmentRequest) (*commonpb.Status, error) { + ret := _m.Called(_a0, _a1) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest) (*commonpb.Status, error)); ok { + return rf(_a0, _a1) + } + if rf, ok 
:= ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest) *commonpb.Status); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferSegmentRequest) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoord_TransferSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferSegment' +type MockQueryCoord_TransferSegment_Call struct { + *mock.Call +} + +// TransferSegment is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *querypb.TransferSegmentRequest +func (_e *MockQueryCoord_Expecter) TransferSegment(_a0 interface{}, _a1 interface{}) *MockQueryCoord_TransferSegment_Call { + return &MockQueryCoord_TransferSegment_Call{Call: _e.mock.On("TransferSegment", _a0, _a1)} +} + +func (_c *MockQueryCoord_TransferSegment_Call) Run(run func(_a0 context.Context, _a1 *querypb.TransferSegmentRequest)) *MockQueryCoord_TransferSegment_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*querypb.TransferSegmentRequest)) + }) + return _c +} + +func (_c *MockQueryCoord_TransferSegment_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_TransferSegment_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoord_TransferSegment_Call) RunAndReturn(run func(context.Context, *querypb.TransferSegmentRequest) (*commonpb.Status, error)) *MockQueryCoord_TransferSegment_Call { + _c.Call.Return(run) + return _c +} + // UpdateStateCode provides a mock function with given fields: stateCode func (_m *MockQueryCoord) UpdateStateCode(stateCode commonpb.StateCode) { _m.Called(stateCode) diff --git a/internal/mocks/mock_querycoord_client.go b/internal/mocks/mock_querycoord_client.go index 947bff1387..e8f3972bef 100644 --- a/internal/mocks/mock_querycoord_client.go +++ b/internal/mocks/mock_querycoord_client.go @@ -171,6 +171,76 @@ func (_c *MockQueryCoordClient_CheckHealth_Call) RunAndReturn(run func(context.C return _c } +// CheckQueryNodeDistribution provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) CheckQueryNodeDistribution(ctx context.Context, in *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
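+		// a typed stub for the error result was registered (e.g. via
+		// RunAndReturn), so it was invoked above with the live arguments;
+		// the else branch falls back to the value recorded with Return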
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_CheckQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckQueryNodeDistribution' +type MockQueryCoordClient_CheckQueryNodeDistribution_Call struct { + *mock.Call +} + +// CheckQueryNodeDistribution is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.CheckQueryNodeDistributionRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) CheckQueryNodeDistribution(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_CheckQueryNodeDistribution_Call { + return &MockQueryCoordClient_CheckQueryNodeDistribution_Call{Call: _e.mock.On("CheckQueryNodeDistribution", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_CheckQueryNodeDistribution_Call) Run(run func(ctx context.Context, in *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_CheckQueryNodeDistribution_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.CheckQueryNodeDistributionRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_CheckQueryNodeDistribution_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_CheckQueryNodeDistribution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_CheckQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.CheckQueryNodeDistributionRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_CheckQueryNodeDistribution_Call { + _c.Call.Return(run) + return _c +} + // Close provides a mock function with given fields: func (_m *MockQueryCoordClient) Close() error { ret := _m.Called() @@ -702,6 +772,76 @@ func (_c *MockQueryCoordClient_GetPartitionStates_Call) RunAndReturn(run func(co return _c } +// GetQueryNodeDistribution provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) GetQueryNodeDistribution(ctx context.Context, in *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.GetQueryNodeDistributionResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) *querypb.GetQueryNodeDistributionResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.GetQueryNodeDistributionResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_GetQueryNodeDistribution_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQueryNodeDistribution' +type MockQueryCoordClient_GetQueryNodeDistribution_Call struct { + *mock.Call +} + +// GetQueryNodeDistribution is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.GetQueryNodeDistributionRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) GetQueryNodeDistribution(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_GetQueryNodeDistribution_Call { + return &MockQueryCoordClient_GetQueryNodeDistribution_Call{Call: _e.mock.On("GetQueryNodeDistribution", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_GetQueryNodeDistribution_Call) Run(run func(ctx context.Context, in *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_GetQueryNodeDistribution_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.GetQueryNodeDistributionRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_GetQueryNodeDistribution_Call) Return(_a0 *querypb.GetQueryNodeDistributionResponse, _a1 error) *MockQueryCoordClient_GetQueryNodeDistribution_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_GetQueryNodeDistribution_Call) RunAndReturn(run func(context.Context, *querypb.GetQueryNodeDistributionRequest, ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error)) *MockQueryCoordClient_GetQueryNodeDistribution_Call { + _c.Call.Return(run) + return _c +} + // GetReplicas provides a mock function with given fields: ctx, in, opts func (_m *MockQueryCoordClient) GetReplicas(ctx context.Context, in *milvuspb.GetReplicasRequest, opts ...grpc.CallOption) (*milvuspb.GetReplicasResponse, error) { _va := make([]interface{}, len(opts)) @@ -1122,6 +1262,76 @@ func (_c *MockQueryCoordClient_ListCheckers_Call) RunAndReturn(run func(context. return _c } +// ListQueryNode provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) ListQueryNode(ctx context.Context, in *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *querypb.ListQueryNodeResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) *querypb.ListQueryNodeResponse); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*querypb.ListQueryNodeResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_ListQueryNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListQueryNode' +type MockQueryCoordClient_ListQueryNode_Call struct { + *mock.Call +} + +// ListQueryNode is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ListQueryNodeRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) ListQueryNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ListQueryNode_Call { + return &MockQueryCoordClient_ListQueryNode_Call{Call: _e.mock.On("ListQueryNode", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_ListQueryNode_Call) Run(run func(ctx context.Context, in *querypb.ListQueryNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ListQueryNode_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ListQueryNodeRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_ListQueryNode_Call) Return(_a0 *querypb.ListQueryNodeResponse, _a1 error) *MockQueryCoordClient_ListQueryNode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_ListQueryNode_Call) RunAndReturn(run func(context.Context, *querypb.ListQueryNodeRequest, ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error)) *MockQueryCoordClient_ListQueryNode_Call { + _c.Call.Return(run) + return _c +} + // ListResourceGroups provides a mock function with given fields: ctx, in, opts func (_m *MockQueryCoordClient) ListResourceGroups(ctx context.Context, in *milvuspb.ListResourceGroupsRequest, opts ...grpc.CallOption) (*milvuspb.ListResourceGroupsResponse, error) { _va := make([]interface{}, len(opts)) @@ -1542,6 +1752,146 @@ func (_c *MockQueryCoordClient_ReleasePartitions_Call) RunAndReturn(run func(con return _c } +// ResumeBalance provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) ResumeBalance(ctx context.Context, in *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_ResumeBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeBalance' +type MockQueryCoordClient_ResumeBalance_Call struct { + *mock.Call +} + +// ResumeBalance is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ResumeBalanceRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) ResumeBalance(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ResumeBalance_Call { + return &MockQueryCoordClient_ResumeBalance_Call{Call: _e.mock.On("ResumeBalance", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_ResumeBalance_Call) Run(run func(ctx context.Context, in *querypb.ResumeBalanceRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ResumeBalance_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ResumeBalanceRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_ResumeBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_ResumeBalance_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_ResumeBalance_Call) RunAndReturn(run func(context.Context, *querypb.ResumeBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_ResumeBalance_Call { + _c.Call.Return(run) + return _c +} + +// ResumeNode provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) ResumeNode(ctx context.Context, in *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_ResumeNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ResumeNode' +type MockQueryCoordClient_ResumeNode_Call struct { + *mock.Call +} + +// ResumeNode is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.ResumeNodeRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) ResumeNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_ResumeNode_Call { + return &MockQueryCoordClient_ResumeNode_Call{Call: _e.mock.On("ResumeNode", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_ResumeNode_Call) Run(run func(ctx context.Context, in *querypb.ResumeNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_ResumeNode_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.ResumeNodeRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_ResumeNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_ResumeNode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_ResumeNode_Call) RunAndReturn(run func(context.Context, *querypb.ResumeNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_ResumeNode_Call { + _c.Call.Return(run) + return _c +} + // ShowCollections provides a mock function with given fields: ctx, in, opts func (_m *MockQueryCoordClient) ShowCollections(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) { _va := make([]interface{}, len(opts)) @@ -1752,6 +2102,146 @@ func (_c *MockQueryCoordClient_ShowPartitions_Call) RunAndReturn(run func(contex return _c } +// SuspendBalance provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) SuspendBalance(ctx context.Context, in *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_SuspendBalance_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendBalance' +type MockQueryCoordClient_SuspendBalance_Call struct { + *mock.Call +} + +// SuspendBalance is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.SuspendBalanceRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) SuspendBalance(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_SuspendBalance_Call { + return &MockQueryCoordClient_SuspendBalance_Call{Call: _e.mock.On("SuspendBalance", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_SuspendBalance_Call) Run(run func(ctx context.Context, in *querypb.SuspendBalanceRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_SuspendBalance_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.SuspendBalanceRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_SuspendBalance_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_SuspendBalance_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_SuspendBalance_Call) RunAndReturn(run func(context.Context, *querypb.SuspendBalanceRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_SuspendBalance_Call { + _c.Call.Return(run) + return _c +} + +// SuspendNode provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) SuspendNode(ctx context.Context, in *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_SuspendNode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SuspendNode' +type MockQueryCoordClient_SuspendNode_Call struct { + *mock.Call +} + +// SuspendNode is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.SuspendNodeRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) SuspendNode(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_SuspendNode_Call { + return &MockQueryCoordClient_SuspendNode_Call{Call: _e.mock.On("SuspendNode", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_SuspendNode_Call) Run(run func(ctx context.Context, in *querypb.SuspendNodeRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_SuspendNode_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.SuspendNodeRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_SuspendNode_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_SuspendNode_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_SuspendNode_Call) RunAndReturn(run func(context.Context, *querypb.SuspendNodeRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_SuspendNode_Call { + _c.Call.Return(run) + return _c +} + // SyncNewCreatedPartition provides a mock function with given fields: ctx, in, opts func (_m *MockQueryCoordClient) SyncNewCreatedPartition(ctx context.Context, in *querypb.SyncNewCreatedPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { _va := make([]interface{}, len(opts)) @@ -1822,6 +2312,76 @@ func (_c *MockQueryCoordClient_SyncNewCreatedPartition_Call) RunAndReturn(run fu return _c } +// TransferChannel provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) TransferChannel(ctx context.Context, in *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_TransferChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferChannel' +type MockQueryCoordClient_TransferChannel_Call struct { + *mock.Call +} + +// TransferChannel is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.TransferChannelRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) TransferChannel(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_TransferChannel_Call { + return &MockQueryCoordClient_TransferChannel_Call{Call: _e.mock.On("TransferChannel", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_TransferChannel_Call) Run(run func(ctx context.Context, in *querypb.TransferChannelRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_TransferChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.TransferChannelRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_TransferChannel_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_TransferChannel_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_TransferChannel_Call) RunAndReturn(run func(context.Context, *querypb.TransferChannelRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_TransferChannel_Call { + _c.Call.Return(run) + return _c +} + // TransferNode provides a mock function with given fields: ctx, in, opts func (_m *MockQueryCoordClient) TransferNode(ctx context.Context, in *milvuspb.TransferNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { _va := make([]interface{}, len(opts)) @@ -1962,6 +2522,76 @@ func (_c *MockQueryCoordClient_TransferReplica_Call) RunAndReturn(run func(conte return _c } +// TransferSegment provides a mock function with given fields: ctx, in, opts +func (_m *MockQueryCoordClient) TransferSegment(ctx context.Context, in *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, in) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 *commonpb.Status + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) (*commonpb.Status, error)); ok { + return rf(ctx, in, opts...) + } + if rf, ok := ret.Get(0).(func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) *commonpb.Status); ok { + r0 = rf(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*commonpb.Status) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) error); ok { + r1 = rf(ctx, in, opts...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockQueryCoordClient_TransferSegment_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TransferSegment' +type MockQueryCoordClient_TransferSegment_Call struct { + *mock.Call +} + +// TransferSegment is a helper method to define mock.On call +// - ctx context.Context +// - in *querypb.TransferSegmentRequest +// - opts ...grpc.CallOption +func (_e *MockQueryCoordClient_Expecter) TransferSegment(ctx interface{}, in interface{}, opts ...interface{}) *MockQueryCoordClient_TransferSegment_Call { + return &MockQueryCoordClient_TransferSegment_Call{Call: _e.mock.On("TransferSegment", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *MockQueryCoordClient_TransferSegment_Call) Run(run func(ctx context.Context, in *querypb.TransferSegmentRequest, opts ...grpc.CallOption)) *MockQueryCoordClient_TransferSegment_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]grpc.CallOption, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(grpc.CallOption) + } + } + run(args[0].(context.Context), args[1].(*querypb.TransferSegmentRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockQueryCoordClient_TransferSegment_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoordClient_TransferSegment_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockQueryCoordClient_TransferSegment_Call) RunAndReturn(run func(context.Context, *querypb.TransferSegmentRequest, ...grpc.CallOption) (*commonpb.Status, error)) *MockQueryCoordClient_TransferSegment_Call { + _c.Call.Return(run) + return _c +} + // NewMockQueryCoordClient creates a new instance of MockQueryCoordClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
func NewMockQueryCoordClient(t interface {
diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto
index 22d8e4c92a..f000212fcf 100644
--- a/internal/proto/query_coord.proto
+++ b/internal/proto/query_coord.proto
@@ -86,13 +86,21 @@ service QueryCoord {
       returns (DescribeResourceGroupResponse) {
   }
 
-  // ops interfaces
-  rpc ListCheckers(ListCheckersRequest) returns (ListCheckersResponse) {
-  }
-  rpc ActivateChecker(ActivateCheckerRequest) returns (common.Status) {
-  }
-  rpc DeactivateChecker(DeactivateCheckerRequest) returns (common.Status) {
-  }
+
+  // ops interfaces
+  rpc ListCheckers(ListCheckersRequest) returns (ListCheckersResponse) {}
+  rpc ActivateChecker(ActivateCheckerRequest) returns (common.Status) {}
+  rpc DeactivateChecker(DeactivateCheckerRequest) returns (common.Status) {}
+
+  rpc ListQueryNode(ListQueryNodeRequest) returns (ListQueryNodeResponse) {}
+  rpc GetQueryNodeDistribution(GetQueryNodeDistributionRequest) returns (GetQueryNodeDistributionResponse) {}
+  rpc SuspendBalance(SuspendBalanceRequest) returns (common.Status) {}
+  rpc ResumeBalance(ResumeBalanceRequest) returns (common.Status) {}
+  rpc SuspendNode(SuspendNodeRequest) returns (common.Status) {}
+  rpc ResumeNode(ResumeNodeRequest) returns (common.Status) {}
+  rpc TransferSegment(TransferSegmentRequest) returns (common.Status) {}
+  rpc TransferChannel(TransferChannelRequest) returns (common.Status) {}
+  rpc CheckQueryNodeDistribution(CheckQueryNodeDistributionRequest) returns (common.Status) {}
 }
 
 service QueryNode {
@@ -793,3 +801,73 @@ message CollectionTarget {
     repeated ChannelTarget Channel_targets = 2;
     int64 version = 3;
 }
+message NodeInfo {
+  int64 ID = 2;
+  string address = 3;
+  string state = 4;
+}
+
+message ListQueryNodeRequest {
+  common.MsgBase base = 1;
+}
+
+message ListQueryNodeResponse {
+  common.Status status = 1;
+  repeated NodeInfo nodeInfos = 2;
+}
+
+message GetQueryNodeDistributionRequest {
+  common.MsgBase base = 1;
+  int64 nodeID = 2;
+}
+
+message GetQueryNodeDistributionResponse {
+  common.Status status = 1;
+  int64 ID = 2;
+  repeated string channel_names = 3;
+  repeated int64 sealed_segmentIDs = 4;
+}
+
+message SuspendBalanceRequest {
+  common.MsgBase base = 1;
+}
+
+message ResumeBalanceRequest {
+  common.MsgBase base = 1;
+}
+
+message SuspendNodeRequest {
+  common.MsgBase base = 1;
+  int64 nodeID = 2;
+}
+
+message ResumeNodeRequest {
+  common.MsgBase base = 1;
+  int64 nodeID = 2;
+}
+
+message TransferSegmentRequest {
+  common.MsgBase base = 1;
+  int64 segmentID = 2;
+  int64 source_nodeID = 3;
+  int64 target_nodeID = 4;
+  bool transfer_all = 5;
+  bool to_all_nodes = 6;
+  bool copy_mode = 7;
+}
+
+message TransferChannelRequest {
+  common.MsgBase base = 1;
+  string channel_name = 2;
+  int64 source_nodeID = 3;
+  int64 target_nodeID = 4;
+  bool transfer_all = 5;
+  bool to_all_nodes = 6;
+  bool copy_mode = 7;
+}
+
+message CheckQueryNodeDistributionRequest {
+  common.MsgBase base = 1;
+  int64 source_nodeID = 3;
+  int64 target_nodeID = 4;
+}
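Note on the transfer messages above: the boolean flags encode three scopes -- a single segment/channel (explicit IDs), everything on the source node (transfer_all), and distribution across the cluster (to_all_nodes) -- while copy_mode, judging by its name and the handler defaults below, selects copying rather than moving. A minimal sketch of building such a request from the generated Go types, matching how the proxy handlers in this patch populate them (illustrative only, not part of the patch):

	// Move segment 1001 from query node 1 to query node 2.
	req := &querypb.TransferSegmentRequest{
		Base:         commonpbutil.NewMsgBase(),
		SegmentID:    1001,  // ignored by the handler when TransferAll is set
		SourceNodeID: 1,
		TargetNodeID: 2,     // ignored when ToAllNodes is set
		TransferAll:  false,
		ToAllNodes:   false,
		CopyMode:     true,  // copy instead of move; exact semantics assumed from the flag name
	}
	status, err := client.TransferSegment(ctx, req) // client is any querypb.QueryCoordClient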
"github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/merr" ) // this file contains proxy management restful API handler @@ -32,6 +36,17 @@ import ( const ( mgrRouteGcPause = `/management/datacoord/garbage_collection/pause` mgrRouteGcResume = `/management/datacoord/garbage_collection/resume` + + mgrSuspendQueryCoordBalance = `/management/querycoord/balance/suspend` + mgrResumeQueryCoordBalance = `/management/querycoord/balance/resume` + mgrTransferSegment = `/management/querycoord/transfer/segment` + mgrTransferChannel = `/management/querycoord/transfer/channel` + + mgrSuspendQueryNode = `/management/querycoord/node/suspend` + mgrResumeQueryNode = `/management/querycoord/node/resume` + mgrListQueryNode = `/management/querycoord/node/list` + mgrGetQueryNodeDistribution = `/management/querycoord/distribution/get` + mgrCheckQueryNodeDistribution = `/management/querycoord/distribution/check` ) var mgrRouteRegisterOnce sync.Once @@ -46,6 +61,42 @@ func RegisterMgrRoute(proxy *Proxy) { Path: mgrRouteGcResume, HandlerFunc: proxy.ResumeDatacoordGC, }) + management.Register(&management.Handler{ + Path: mgrListQueryNode, + HandlerFunc: proxy.ListQueryNode, + }) + management.Register(&management.Handler{ + Path: mgrGetQueryNodeDistribution, + HandlerFunc: proxy.GetQueryNodeDistribution, + }) + management.Register(&management.Handler{ + Path: mgrSuspendQueryCoordBalance, + HandlerFunc: proxy.SuspendQueryCoordBalance, + }) + management.Register(&management.Handler{ + Path: mgrResumeQueryCoordBalance, + HandlerFunc: proxy.ResumeQueryCoordBalance, + }) + management.Register(&management.Handler{ + Path: mgrSuspendQueryNode, + HandlerFunc: proxy.SuspendQueryNode, + }) + management.Register(&management.Handler{ + Path: mgrResumeQueryNode, + HandlerFunc: proxy.ResumeQueryNode, + }) + management.Register(&management.Handler{ + Path: mgrTransferSegment, + HandlerFunc: proxy.TransferSegment, + }) + management.Register(&management.Handler{ + Path: mgrTransferChannel, + HandlerFunc: proxy.TransferChannel, + }) + management.Register(&management.Handler{ + Path: mgrCheckQueryNodeDistribution, + HandlerFunc: proxy.CheckQueryNodeDistribution, + }) }) } @@ -91,3 +142,362 @@ func (node *Proxy) ResumeDatacoordGC(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"msg": "OK"}`)) } + +func (node *Proxy) ListQueryNode(w http.ResponseWriter, req *http.Request) { + resp, err := node.queryCoord.ListQueryNode(req.Context(), &querypb.ListQueryNodeRequest{ + Base: commonpbutil.NewMsgBase(), + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf(`{"msg": "failed to list query node, %s"}`, err.Error()))) + return + } + + if !merr.Ok(resp.GetStatus()) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf(`{"msg": "failed to list query node, %s"}`, resp.GetStatus().GetReason()))) + return + } + + w.WriteHeader(http.StatusOK) + // skip marshal status to output + resp.Status = nil + bytes, err := json.Marshal(resp) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(fmt.Sprintf(`{"msg": "failed to list query node, %s"}`, err.Error()))) + return + } + w.Write(bytes) +} + +func (node *Proxy) GetQueryNodeDistribution(w http.ResponseWriter, req *http.Request) { + err := req.ParseForm() + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error()))) + return + } + + nodeID, err := 
strconv.ParseInt(req.FormValue("node_id"), 10, 64)
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
+		return
+	}
+
+	resp, err := node.queryCoord.GetQueryNodeDistribution(req.Context(), &querypb.GetQueryNodeDistributionRequest{
+		Base:   commonpbutil.NewMsgBase(),
+		NodeID: nodeID,
+	})
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
+		return
+	}
+
+	if !merr.Ok(resp.GetStatus()) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, resp.GetStatus().GetReason())))
+		return
+	}
+	// skip marshaling status into the output
+	resp.Status = nil
+	// marshal before writing the header so a failure can still return 500
+	bytes, err := json.Marshal(resp)
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to get query node distribution, %s"}`, err.Error())))
+		return
+	}
+	w.WriteHeader(http.StatusOK)
+	w.Write(bytes)
+}
+
+func (node *Proxy) SuspendQueryCoordBalance(w http.ResponseWriter, req *http.Request) {
+	resp, err := node.queryCoord.SuspendBalance(req.Context(), &querypb.SuspendBalanceRequest{
+		Base: commonpbutil.NewMsgBase(),
+	})
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend balance, %s"}`, err.Error())))
+		return
+	}
+
+	if !merr.Ok(resp) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend balance, %s"}`, resp.GetReason())))
+		return
+	}
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte(`{"msg": "OK"}`))
+}
+
+func (node *Proxy) ResumeQueryCoordBalance(w http.ResponseWriter, req *http.Request) {
+	resp, err := node.queryCoord.ResumeBalance(req.Context(), &querypb.ResumeBalanceRequest{
+		Base: commonpbutil.NewMsgBase(),
+	})
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume balance, %s"}`, err.Error())))
+		return
+	}
+
+	if !merr.Ok(resp) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume balance, %s"}`, resp.GetReason())))
+		return
+	}
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte(`{"msg": "OK"}`))
+}
+
+func (node *Proxy) SuspendQueryNode(w http.ResponseWriter, req *http.Request) {
+	err := req.ParseForm()
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, err.Error())))
+		return
+	}
+
+	nodeID, err := strconv.ParseInt(req.FormValue("node_id"), 10, 64)
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, err.Error())))
+		return
+	}
+	resp, err := node.queryCoord.SuspendNode(req.Context(), &querypb.SuspendNodeRequest{
+		Base:   commonpbutil.NewMsgBase(),
+		NodeID: nodeID,
+	})
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, err.Error())))
+		return
+	}
+
+	if !merr.Ok(resp) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to suspend node, %s"}`, resp.GetReason())))
+		return
+	}
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte(`{"msg": "OK"}`))
+}
+
+func (node *Proxy) ResumeQueryNode(w http.ResponseWriter, req *http.Request) {
+	err := req.ParseForm()
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, err.Error())))
+		return
+	}
+
+	nodeID, err := strconv.ParseInt(req.FormValue("node_id"), 10, 64)
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, err.Error())))
+		return
+	}
+	resp, err := node.queryCoord.ResumeNode(req.Context(), &querypb.ResumeNodeRequest{
+		Base:   commonpbutil.NewMsgBase(),
+		NodeID: nodeID,
+	})
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, err.Error())))
+		return
+	}
+
+	if !merr.Ok(resp) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to resume node, %s"}`, resp.GetReason())))
+		return
+	}
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte(`{"msg": "OK"}`))
+}
+
+func (node *Proxy) TransferSegment(w http.ResponseWriter, req *http.Request) {
+	err := req.ParseForm()
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
+		return
+	}
+
+	request := &querypb.TransferSegmentRequest{
+		Base: commonpbutil.NewMsgBase(),
+	}
+
+	source, err := strconv.ParseInt(req.FormValue("source_node_id"), 10, 64)
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
+		return
+	}
+	request.SourceNodeID = source
+
+	target := req.FormValue("target_node_id")
+	if len(target) == 0 {
+		request.ToAllNodes = true
+	} else {
+		value, err := strconv.ParseInt(target, 10, 64)
+		if err != nil {
+			w.WriteHeader(http.StatusBadRequest)
+			w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
+			return
+		}
+		request.TargetNodeID = value
+	}
+
+	segmentID := req.FormValue("segment_id")
+	if len(segmentID) == 0 {
+		request.TransferAll = true
+	} else {
+		value, err := strconv.ParseInt(segmentID, 10, 64)
+		if err != nil {
+			w.WriteHeader(http.StatusBadRequest)
+			w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
+			return
+		}
+		request.SegmentID = value
+	}
+
+	copyMode := req.FormValue("copy_mode")
+	if len(copyMode) == 0 {
+		request.CopyMode = true
+	} else {
+		value, err := strconv.ParseBool(copyMode)
+		if err != nil {
+			w.WriteHeader(http.StatusBadRequest)
+			w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
+			return
+		}
+		request.CopyMode = value
+	}
+
+	resp, err := node.queryCoord.TransferSegment(req.Context(), request)
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, err.Error())))
+		return
+	}
+
+	if !merr.Ok(resp) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer segment, %s"}`, resp.GetReason())))
+		return
+	}
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte(`{"msg": "OK"}`))
+}
+
+func (node *Proxy) TransferChannel(w http.ResponseWriter, req *http.Request) {
+	err := req.ParseForm()
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
+		return
+	}
+
+	request := &querypb.TransferChannelRequest{
+		Base: commonpbutil.NewMsgBase(),
+	}
+
+	source, err := strconv.ParseInt(req.FormValue("source_node_id"), 10, 64)
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
+		return
+	}
+	request.SourceNodeID = source
+
+	target := req.FormValue("target_node_id")
+	if len(target) == 0 {
+		request.ToAllNodes = true
+	} else {
+		value, err := strconv.ParseInt(target, 10, 64)
+		if err != nil {
+			w.WriteHeader(http.StatusBadRequest)
+			w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
+			return
+		}
+		request.TargetNodeID = value
+	}
+
+	channel := req.FormValue("channel_name")
+	if len(channel) == 0 {
+		request.TransferAll = true
+	} else {
+		request.ChannelName = channel
+	}
+
+	copyMode := req.FormValue("copy_mode")
+	if len(copyMode) == 0 {
+		request.CopyMode = false
+	} else {
+		value, err := strconv.ParseBool(copyMode)
+		if err != nil {
+			w.WriteHeader(http.StatusBadRequest)
+			w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
+			return
+		}
+		request.CopyMode = value
+	}
+
+	resp, err := node.queryCoord.TransferChannel(req.Context(), request)
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, err.Error())))
+		return
+	}
+
+	if !merr.Ok(resp) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to transfer channel, %s"}`, resp.GetReason())))
+		return
+	}
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte(`{"msg": "OK"}`))
+}
+
+func (node *Proxy) CheckQueryNodeDistribution(w http.ResponseWriter, req *http.Request) {
+	err := req.ParseForm()
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
+		return
+	}
+
+	source, err := strconv.ParseInt(req.FormValue("source_node_id"), 10, 64)
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
+		return
+	}
+
+	target, err := strconv.ParseInt(req.FormValue("target_node_id"), 10, 64)
+	if err != nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
+		return
+	}
+	resp, err := node.queryCoord.CheckQueryNodeDistribution(req.Context(), &querypb.CheckQueryNodeDistributionRequest{
+		Base:         commonpbutil.NewMsgBase(),
+		SourceNodeID: source,
+		TargetNodeID: target,
+	})
+	if err != nil {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, err.Error())))
+		return
+	}
+
+	if !merr.Ok(resp) {
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write([]byte(fmt.Sprintf(`{"msg": "failed to check whether query node has same distribution, %s"}`, resp.GetReason())))
+		return
+	}
+	w.WriteHeader(http.StatusOK)
+	w.Write([]byte(`{"msg": "OK"}`))
+}
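All of these handlers accept application/x-www-form-urlencoded POSTs, exactly as the tests below construct them. A minimal driver that exercises the suspend/transfer/resume flow end to end; this sketch is not part of the patch, and the proxy's HTTP listen address is an assumption that depends on your deployment:

	package main

	import (
		"fmt"
		"io"
		"net/http"
		"net/url"
	)

	// post sends one form-encoded request to a management route and prints the response.
	func post(base, route string, form url.Values) {
		resp, err := http.PostForm(base+route, form)
		if err != nil {
			fmt.Println(route, "error:", err)
			return
		}
		defer resp.Body.Close()
		body, _ := io.ReadAll(resp.Body)
		fmt.Printf("%s -> %d %s\n", route, resp.StatusCode, body)
	}

	func main() {
		base := "http://localhost:9091" // assumed proxy HTTP address

		// Pause automatic balancing, move one segment by hand, then resume.
		post(base, "/management/querycoord/balance/suspend", nil)
		post(base, "/management/querycoord/transfer/segment", url.Values{
			"source_node_id": {"1"},
			"target_node_id": {"2"},
			"segment_id":     {"1001"},
			"copy_mode":      {"false"},
		})
		post(base, "/management/querycoord/balance/resume", nil)
	}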
"github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/pkg/util/merr" ) type ProxyManagementSuite struct { suite.Suite - datacoord *mocks.MockDataCoordClient - proxy *Proxy + querycoord *mocks.MockQueryCoordClient + datacoord *mocks.MockDataCoordClient + proxy *Proxy } func (s *ProxyManagementSuite) SetupTest() { s.datacoord = mocks.NewMockDataCoordClient(s.T()) + s.querycoord = mocks.NewMockQueryCoordClient(s.T()) + s.proxy = &Proxy{ - dataCoord: s.datacoord, + dataCoord: s.datacoord, + queryCoord: s.querycoord, } } @@ -158,6 +165,527 @@ func (s *ProxyManagementSuite) TestResumeDatacoordGC() { }) } +func (s *ProxyManagementSuite) TestListQueryNode() { + s.Run("normal", func() { + s.SetupTest() + defer s.TearDownTest() + + s.querycoord.EXPECT().ListQueryNode(mock.Anything, mock.Anything).Return(&querypb.ListQueryNodeResponse{ + Status: merr.Success(), + NodeInfos: []*querypb.NodeInfo{ + { + ID: 1, + Address: "localhost", + State: "Healthy", + }, + }, + }, nil) + + req, err := http.NewRequest(http.MethodPost, mgrListQueryNode, nil) + s.Require().NoError(err) + + recorder := httptest.NewRecorder() + s.proxy.ListQueryNode(recorder, req) + s.Equal(http.StatusOK, recorder.Code) + s.Equal(`{"nodeInfos":[{"ID":1,"address":"localhost","state":"Healthy"}]}`, recorder.Body.String()) + }) + + s.Run("return_error", func() { + s.SetupTest() + defer s.TearDownTest() + + s.querycoord.EXPECT().ListQueryNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error")) + req, err := http.NewRequest(http.MethodPost, mgrListQueryNode, nil) + s.Require().NoError(err) + + recorder := httptest.NewRecorder() + s.proxy.ListQueryNode(recorder, req) + s.Equal(http.StatusInternalServerError, recorder.Code) + }) + + s.Run("return_failure", func() { + s.SetupTest() + defer s.TearDownTest() + + s.querycoord.EXPECT().ListQueryNode(mock.Anything, mock.Anything).Return(&querypb.ListQueryNodeResponse{ + Status: merr.Status(merr.ErrServiceNotReady), + }, nil) + + req, err := http.NewRequest(http.MethodPost, mgrListQueryNode, nil) + s.Require().NoError(err) + + recorder := httptest.NewRecorder() + s.proxy.ListQueryNode(recorder, req) + s.Equal(http.StatusInternalServerError, recorder.Code) + }) +} + +func (s *ProxyManagementSuite) TestGetQueryNodeDistribution() { + s.Run("normal", func() { + s.SetupTest() + defer s.TearDownTest() + + s.querycoord.EXPECT().GetQueryNodeDistribution(mock.Anything, mock.Anything).Return(&querypb.GetQueryNodeDistributionResponse{ + Status: merr.Success(), + ID: 1, + ChannelNames: []string{"channel-1"}, + SealedSegmentIDs: []int64{1, 2, 3}, + }, nil) + + req, err := http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("node_id=1")) + s.Require().NoError(err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + recorder := httptest.NewRecorder() + s.proxy.GetQueryNodeDistribution(recorder, req) + s.Equal(http.StatusOK, recorder.Code) + s.Equal(`{"ID":1,"channel_names":["channel-1"],"sealed_segmentIDs":[1,2,3]}`, recorder.Body.String()) + }) + + s.Run("return_error", func() { + s.SetupTest() + defer s.TearDownTest() + + // test invalid request body + req, err := http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, nil) + s.Require().NoError(err) + recorder := httptest.NewRecorder() + s.proxy.GetQueryNodeDistribution(recorder, req) + s.Equal(http.StatusBadRequest, recorder.Code) + + // test miss requested param + req, err = http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("")) 
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.GetQueryNodeDistribution(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test rpc return error
+		s.querycoord.EXPECT().GetQueryNodeDistribution(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
+		req, err = http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.GetQueryNodeDistribution(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+
+	s.Run("return_failure", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().GetQueryNodeDistribution(mock.Anything, mock.Anything).Return(&querypb.GetQueryNodeDistributionResponse{
+			Status: merr.Status(merr.ErrServiceNotReady),
+		}, nil)
+		req, err := http.NewRequest(http.MethodPost, mgrGetQueryNodeDistribution, strings.NewReader("node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder := httptest.NewRecorder()
+		s.proxy.GetQueryNodeDistribution(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+}
+
+func (s *ProxyManagementSuite) TestSuspendQueryCoordBalance() {
+	s.Run("normal", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().SuspendBalance(mock.Anything, mock.Anything).Return(merr.Success(), nil)
+
+		req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryCoordBalance, nil)
+		s.Require().NoError(err)
+
+		recorder := httptest.NewRecorder()
+		s.proxy.SuspendQueryCoordBalance(recorder, req)
+		s.Equal(http.StatusOK, recorder.Code)
+		s.Equal(`{"msg": "OK"}`, recorder.Body.String())
+	})
+
+	s.Run("return_error", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().SuspendBalance(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
+		req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryCoordBalance, nil)
+		s.Require().NoError(err)
+
+		recorder := httptest.NewRecorder()
+		s.proxy.SuspendQueryCoordBalance(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+
+	s.Run("return_failure", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().SuspendBalance(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
+		req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryCoordBalance, nil)
+		s.Require().NoError(err)
+
+		recorder := httptest.NewRecorder()
+		s.proxy.SuspendQueryCoordBalance(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+}
+
+func (s *ProxyManagementSuite) TestResumeQueryCoordBalance() {
+	s.Run("normal", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().ResumeBalance(mock.Anything, mock.Anything).Return(merr.Success(), nil)
+
+		req, err := http.NewRequest(http.MethodPost, mgrResumeQueryCoordBalance, nil)
+		s.Require().NoError(err)
+
+		recorder := httptest.NewRecorder()
+		s.proxy.ResumeQueryCoordBalance(recorder, req)
+		s.Equal(http.StatusOK, recorder.Code)
+		s.Equal(`{"msg": "OK"}`, recorder.Body.String())
+	})
+
+	s.Run("return_error", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().ResumeBalance(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
+		req, err := http.NewRequest(http.MethodPost, mgrResumeQueryCoordBalance, nil)
+		s.Require().NoError(err)
+
+		recorder := httptest.NewRecorder()
+		s.proxy.ResumeQueryCoordBalance(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+
+	s.Run("return_failure", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().ResumeBalance(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
+		req, err := http.NewRequest(http.MethodPost, mgrResumeQueryCoordBalance, nil)
+		s.Require().NoError(err)
+
+		recorder := httptest.NewRecorder()
+		s.proxy.ResumeQueryCoordBalance(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+}
+
+func (s *ProxyManagementSuite) TestSuspendQueryNode() {
+	s.Run("normal", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().SuspendNode(mock.Anything, mock.Anything).Return(merr.Success(), nil)
+
+		req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader("node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+		recorder := httptest.NewRecorder()
+		s.proxy.SuspendQueryNode(recorder, req)
+		s.Equal(http.StatusOK, recorder.Code)
+		s.Equal(`{"msg": "OK"}`, recorder.Body.String())
+	})
+
+	s.Run("return_error", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		// test invalid request body
+		req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryNode, nil)
+		s.Require().NoError(err)
+		recorder := httptest.NewRecorder()
+		s.proxy.SuspendQueryNode(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test missing required param
+		req, err = http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader(""))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.SuspendQueryNode(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test rpc return error
+		s.querycoord.EXPECT().SuspendNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
+		req, err = http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader("node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.SuspendQueryNode(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+
+	s.Run("return_failure", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().SuspendNode(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
+		req, err := http.NewRequest(http.MethodPost, mgrSuspendQueryNode, strings.NewReader("node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder := httptest.NewRecorder()
+		s.proxy.SuspendQueryNode(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+}
+
+func (s *ProxyManagementSuite) TestResumeQueryNode() {
+	s.Run("normal", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().ResumeNode(mock.Anything, mock.Anything).Return(merr.Success(), nil)
+
+		req, err := http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader("node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+		recorder := httptest.NewRecorder()
+		s.proxy.ResumeQueryNode(recorder, req)
+		s.Equal(http.StatusOK, recorder.Code)
+		s.Equal(`{"msg": "OK"}`, recorder.Body.String())
+	})
+
+	s.Run("return_error", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		// test invalid request body
+		req, err := http.NewRequest(http.MethodPost, mgrResumeQueryNode, nil)
+		s.Require().NoError(err)
+		recorder := httptest.NewRecorder()
+		s.proxy.ResumeQueryNode(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test missing required param
+		req, err = http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader(""))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.ResumeQueryNode(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test rpc return error
+		s.querycoord.EXPECT().ResumeNode(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
+		req, err = http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader("node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.ResumeQueryNode(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+
+	s.Run("return_failure", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().ResumeNode(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
+		req, err := http.NewRequest(http.MethodPost, mgrResumeQueryNode, strings.NewReader("node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder := httptest.NewRecorder()
+		s.proxy.ResumeQueryNode(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+}
+
+func (s *ProxyManagementSuite) TestTransferSegment() {
+	s.Run("normal", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().TransferSegment(mock.Anything, mock.Anything).Return(merr.Success(), nil)
+
+		req, err := http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1&copy_mode=false"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder := httptest.NewRecorder()
+		s.proxy.TransferSegment(recorder, req)
+		s.Equal(http.StatusOK, recorder.Code)
+		s.Equal(`{"msg": "OK"}`, recorder.Body.String())
+
+		// test with default params
+		req, err = http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.TransferSegment(recorder, req)
+		s.Equal(http.StatusOK, recorder.Code)
+		s.Equal(`{"msg": "OK"}`, recorder.Body.String())
+	})
+
+	s.Run("return_error", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		// test invalid request body
+		req, err := http.NewRequest(http.MethodPost, mgrTransferSegment, nil)
+		s.Require().NoError(err)
+		recorder := httptest.NewRecorder()
+		s.proxy.TransferSegment(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test missing required param
+		req, err = http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader(""))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.TransferSegment(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test rpc return error
+		s.querycoord.EXPECT().TransferSegment(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
+		req, err = http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1&target_node_id=1&segment_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.TransferSegment(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+
+	s.Run("return_failure", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().TransferSegment(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
+		req, err := http.NewRequest(http.MethodPost, mgrTransferSegment, strings.NewReader("source_node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder := httptest.NewRecorder()
+		s.proxy.TransferSegment(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+}
+
+func (s *ProxyManagementSuite) TestTransferChannel() {
+	s.Run("normal", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().TransferChannel(mock.Anything, mock.Anything).Return(merr.Success(), nil)
+
+		req, err := http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1&target_node_id=1&channel_name=channel_1&copy_mode=false"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder := httptest.NewRecorder()
+		s.proxy.TransferChannel(recorder, req)
+		s.Equal(http.StatusOK, recorder.Code)
+		s.Equal(`{"msg": "OK"}`, recorder.Body.String())
+
+		// test with default params
+		req, err = http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.TransferChannel(recorder, req)
+		s.Equal(http.StatusOK, recorder.Code)
+		s.Equal(`{"msg": "OK"}`, recorder.Body.String())
+	})
+
+	s.Run("return_error", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		// test invalid request body
+		req, err := http.NewRequest(http.MethodPost, mgrTransferChannel, nil)
+		s.Require().NoError(err)
+		recorder := httptest.NewRecorder()
+		s.proxy.TransferChannel(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test missing required param
+		req, err = http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader(""))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.TransferChannel(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test rpc return error
+		s.querycoord.EXPECT().TransferChannel(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
+		req, err = http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1&target_node_id=1&channel_name=channel_1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.TransferChannel(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+
+	s.Run("return_failure", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().TransferChannel(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
+		req, err := http.NewRequest(http.MethodPost, mgrTransferChannel, strings.NewReader("source_node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder := httptest.NewRecorder()
+		s.proxy.TransferChannel(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+}
+
+func (s *ProxyManagementSuite) TestCheckQueryNodeDistribution() {
+	s.Run("normal", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().CheckQueryNodeDistribution(mock.Anything, mock.Anything).Return(merr.Success(), nil)
+
+		req, err := http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader("source_node_id=1&target_node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder := httptest.NewRecorder()
+		s.proxy.CheckQueryNodeDistribution(recorder, req)
+		s.Equal(http.StatusOK, recorder.Code)
+		s.Equal(`{"msg": "OK"}`, recorder.Body.String())
+	})
+
+	s.Run("return_error", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		// test invalid request body
+		req, err := http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, nil)
+		s.Require().NoError(err)
+		recorder := httptest.NewRecorder()
+		s.proxy.CheckQueryNodeDistribution(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test missing required param
+		req, err = http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader(""))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.CheckQueryNodeDistribution(recorder, req)
+		s.Equal(http.StatusBadRequest, recorder.Code)
+
+		// test rpc return error
+		s.querycoord.EXPECT().CheckQueryNodeDistribution(mock.Anything, mock.Anything).Return(nil, errors.New("mocked error"))
+		req, err = http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader("source_node_id=1&target_node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder = httptest.NewRecorder()
+		s.proxy.CheckQueryNodeDistribution(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+
+	s.Run("return_failure", func() {
+		s.SetupTest()
+		defer s.TearDownTest()
+
+		s.querycoord.EXPECT().CheckQueryNodeDistribution(mock.Anything, mock.Anything).Return(merr.Status(merr.ErrServiceNotReady), nil)
+		req, err := http.NewRequest(http.MethodPost, mgrCheckQueryNodeDistribution, strings.NewReader("source_node_id=1&target_node_id=1"))
+		s.Require().NoError(err)
+		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+		recorder := httptest.NewRecorder()
+		s.proxy.CheckQueryNodeDistribution(recorder, req)
+		s.Equal(http.StatusInternalServerError, recorder.Code)
+	})
+}
+
 func TestProxyManagement(t *testing.T) {
 	suite.Run(t, new(ProxyManagementSuite))
 }
diff --git a/internal/querycoordv2/balance/balance.go b/internal/querycoordv2/balance/balance.go
index ed9568a382..ccbc42124c 100644
--- a/internal/querycoordv2/balance/balance.go
+++ b/internal/querycoordv2/balance/balance.go
@@ -20,6 +20,8 @@ import (
 	"fmt"
 	"sort"
 
+	"github.com/samber/lo"
+
 	"github.com/milvus-io/milvus/internal/querycoordv2/meta"
 	"github.com/milvus-io/milvus/internal/querycoordv2/session"
 	"github.com/milvus-io/milvus/internal/querycoordv2/task"
@@ -57,8 +59,8 @@ var (
 )
 
 type Balance interface {
-	AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan
-	AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan
+	AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan
+	AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan
 	BalanceReplica(replica *meta.Replica) ([]SegmentAssignPlan, []ChannelAssignPlan)
 }
 
@@ -67,7 +69,15 @@ type RoundRobinBalancer struct {
 	nodeManager *session.NodeManager
 }
 
-func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
+func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
+	// filter out suspended and stopping nodes during assignment; manual balance skips this check
+	if !manualBalance {
+		nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+			info := b.nodeManager.Get(node)
+			return info != nil && info.GetState() == session.NodeStateNormal
+		})
+	}
+
 	nodesInfo := b.getNodes(nodes)
 	if len(nodesInfo) == 0 {
 		return nil
@@ -90,7 +100,14 @@ func (b *RoundRobinBalancer) AssignSegment(collectionID int64, segments []*meta.
 	return ret
 }
 
-func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan {
+func (b *RoundRobinBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
+	// filter out suspended and stopping nodes during assignment; manual balance skips this check
+	if !manualBalance {
+		nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+			info := b.nodeManager.Get(node)
+			return info != nil && info.GetState() == session.NodeStateNormal
+		})
+	}
 	nodesInfo := b.getNodes(nodes)
 	if len(nodesInfo) == 0 {
 		return nil
diff --git a/internal/querycoordv2/balance/balance_test.go b/internal/querycoordv2/balance/balance_test.go
index d49eb87ac4..35f9769e1a 100644
--- a/internal/querycoordv2/balance/balance_test.go
+++ b/internal/querycoordv2/balance/balance_test.go
@@ -97,7 +97,7 @@ func (suite *BalanceTestSuite) TestAssignBalance() {
 					suite.mockScheduler.EXPECT().GetNodeSegmentDelta(c.nodeIDs[i]).Return(c.deltaCnts[i])
 				}
 			}
-			plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs)
+			plans := suite.roundRobinBalancer.AssignSegment(0, c.assignments, c.nodeIDs, false)
 			suite.ElementsMatch(c.expectPlans, plans)
 		})
 	}
@@ -161,7 +161,7 @@ func (suite *BalanceTestSuite) TestAssignChannel() {
 					suite.mockScheduler.EXPECT().GetNodeChannelDelta(c.nodeIDs[i]).Return(c.deltaCnts[i])
 				}
 			}
-			plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs)
+			plans := suite.roundRobinBalancer.AssignChannel(c.assignments, c.nodeIDs, false)
 			suite.ElementsMatch(c.expectPlans, plans)
 		})
 	}
diff --git a/internal/querycoordv2/balance/mock_balancer.go b/internal/querycoordv2/balance/mock_balancer.go
index f97367b4c3..f1f2250e30 100644
--- a/internal/querycoordv2/balance/mock_balancer.go
+++ b/internal/querycoordv2/balance/mock_balancer.go
@@ -20,13 +20,13 @@ func (_m *MockBalancer) EXPECT() *MockBalancer_Expecter {
 	return &MockBalancer_Expecter{mock: &_m.Mock}
 }
 
-// AssignChannel provides a mock function with given fields: channels, nodes
-func (_m *MockBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan {
-	ret := _m.Called(channels, nodes)
+// AssignChannel provides a mock function with given fields: channels, nodes, manualBalance
+func (_m *MockBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
+	ret := _m.Called(channels, nodes, manualBalance)
 
 	var r0 []ChannelAssignPlan
-	if rf, ok := ret.Get(0).(func([]*meta.DmChannel, []int64) []ChannelAssignPlan); 
ok { - r0 = rf(channels, nodes) + if rf, ok := ret.Get(0).(func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan); ok { + r0 = rf(channels, nodes, manualBalance) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]ChannelAssignPlan) @@ -44,13 +44,14 @@ type MockBalancer_AssignChannel_Call struct { // AssignChannel is a helper method to define mock.On call // - channels []*meta.DmChannel // - nodes []int64 -func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}) *MockBalancer_AssignChannel_Call { - return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes)} +// - manualBalance bool +func (_e *MockBalancer_Expecter) AssignChannel(channels interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignChannel_Call { + return &MockBalancer_AssignChannel_Call{Call: _e.mock.On("AssignChannel", channels, nodes, manualBalance)} } -func (_c *MockBalancer_AssignChannel_Call) Run(run func(channels []*meta.DmChannel, nodes []int64)) *MockBalancer_AssignChannel_Call { +func (_c *MockBalancer_AssignChannel_Call) Run(run func(channels []*meta.DmChannel, nodes []int64, manualBalance bool)) *MockBalancer_AssignChannel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]*meta.DmChannel), args[1].([]int64)) + run(args[0].([]*meta.DmChannel), args[1].([]int64), args[2].(bool)) }) return _c } @@ -60,18 +61,18 @@ func (_c *MockBalancer_AssignChannel_Call) Return(_a0 []ChannelAssignPlan) *Mock return _c } -func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func([]*meta.DmChannel, []int64) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call { +func (_c *MockBalancer_AssignChannel_Call) RunAndReturn(run func([]*meta.DmChannel, []int64, bool) []ChannelAssignPlan) *MockBalancer_AssignChannel_Call { _c.Call.Return(run) return _c } -// AssignSegment provides a mock function with given fields: collectionID, segments, nodes -func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan { - ret := _m.Called(collectionID, segments, nodes) +// AssignSegment provides a mock function with given fields: collectionID, segments, nodes, manualBalance +func (_m *MockBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan { + ret := _m.Called(collectionID, segments, nodes, manualBalance) var r0 []SegmentAssignPlan - if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64) []SegmentAssignPlan); ok { - r0 = rf(collectionID, segments, nodes) + if rf, ok := ret.Get(0).(func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan); ok { + r0 = rf(collectionID, segments, nodes, manualBalance) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]SegmentAssignPlan) @@ -90,13 +91,14 @@ type MockBalancer_AssignSegment_Call struct { // - collectionID int64 // - segments []*meta.Segment // - nodes []int64 -func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}) *MockBalancer_AssignSegment_Call { - return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes)} +// - manualBalance bool +func (_e *MockBalancer_Expecter) AssignSegment(collectionID interface{}, segments interface{}, nodes interface{}, manualBalance interface{}) *MockBalancer_AssignSegment_Call { + return &MockBalancer_AssignSegment_Call{Call: _e.mock.On("AssignSegment", collectionID, segments, nodes, manualBalance)} } -func (_c 
*MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64)) *MockBalancer_AssignSegment_Call {
+func (_c *MockBalancer_AssignSegment_Call) Run(run func(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool)) *MockBalancer_AssignSegment_Call {
 	_c.Call.Run(func(args mock.Arguments) {
-		run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64))
+		run(args[0].(int64), args[1].([]*meta.Segment), args[2].([]int64), args[3].(bool))
 	})
 	return _c
 }
@@ -106,7 +108,7 @@ func (_c *MockBalancer_AssignSegment_Call) Return(_a0 []SegmentAssignPlan) *Mock
 	return _c
 }
 
-func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(int64, []*meta.Segment, []int64) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call {
+func (_c *MockBalancer_AssignSegment_Call) RunAndReturn(run func(int64, []*meta.Segment, []int64, bool) []SegmentAssignPlan) *MockBalancer_AssignSegment_Call {
 	_c.Call.Return(run)
 	return _c
 }
diff --git a/internal/querycoordv2/balance/rowcount_based_balancer.go b/internal/querycoordv2/balance/rowcount_based_balancer.go
index 3225418c40..a3b615895b 100644
--- a/internal/querycoordv2/balance/rowcount_based_balancer.go
+++ b/internal/querycoordv2/balance/rowcount_based_balancer.go
@@ -41,7 +41,15 @@ type RowCountBasedBalancer struct {
 
 // AssignSegment, when row count based balancer assign segments, it will assign segment to node with least global row count.
 // try to make every query node has same row count.
-func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
+func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
+	// filter out suspended and stopping nodes during assignment; manual balance skips this check
+	if !manualBalance {
+		nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+			info := b.nodeManager.Get(node)
+			return info != nil && info.GetState() == session.NodeStateNormal
+		})
+	}
+
 	nodeItems := b.convertToNodeItemsBySegment(nodes)
 	if len(nodeItems) == 0 {
 		return nil
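This pre-assignment filter, repeated in each balancer, is what makes SuspendNode effective: any node whose session state is not NodeStateNormal silently drops out of the candidate list unless manualBalance is set. A small illustrative sketch of that behavior against the session APIs used elsewhere in this patch (node IDs are hypothetical; session.NewNodeManager is assumed to be the standard constructor; not part of the patch):

	nm := session.NewNodeManager()
	for id, state := range map[int64]session.State{
		1: session.NodeStateNormal,
		2: session.NodeStateSuspend, // e.g. suspended via /management/querycoord/node/suspend
	} {
		info := session.NewNodeInfo(session.ImmutableNodeInfo{NodeID: id, Address: "localhost", Hostname: "localhost"})
		info.SetState(state)
		nm.Add(info)
	}

	// same predicate as the balancers above
	candidates := lo.Filter([]int64{1, 2}, func(node int64, _ int) bool {
		info := nm.Get(node)
		return info != nil && info.GetState() == session.NodeStateNormal
	})
	// candidates == []int64{1}: node 2 only becomes eligible again after ResumeNode,
	// or when the caller passes manualBalance=true and the filter is skipped.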
@@ -75,7 +83,15 @@ func (b *RowCountBasedBalancer) AssignSegment(collectionID int64, segments []*me
 
 // AssignSegment, when row count based balancer assign segments, it will assign channel to node with least global channel count.
 // try to make every query node has channel count
-func (b *RowCountBasedBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64) []ChannelAssignPlan {
+func (b *RowCountBasedBalancer) AssignChannel(channels []*meta.DmChannel, nodes []int64, manualBalance bool) []ChannelAssignPlan {
+	// filter out suspended and stopping nodes during assignment; manual balance skips this check
+	if !manualBalance {
+		nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+			info := b.nodeManager.Get(node)
+			return info != nil && info.GetState() == session.NodeStateNormal
+		})
+	}
+
 	nodeItems := b.convertToNodeItemsByChannel(nodes)
 	nodeItems = lo.Shuffle(nodeItems)
 	if len(nodeItems) == 0 {
@@ -215,7 +231,7 @@ func (b *RowCountBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, on
 			b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil &&
 			segment.GetLevel() != datapb.SegmentLevel_L0
 	})
-	plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes)
+	plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes, false)
 	for i := range plans {
 		plans[i].From = nodeID
 		plans[i].Replica = replica
@@ -283,7 +299,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
 		return nil
 	}
 
-	segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, nodesWithLessRow)
+	segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, nodesWithLessRow, false)
 	for i := range segmentPlans {
 		segmentPlans[i].From = segmentPlans[i].Segment.Node
 		segmentPlans[i].Replica = replica
@@ -296,7 +312,7 @@ func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, on
 	channelPlans := make([]ChannelAssignPlan, 0)
 	for _, nodeID := range offlineNodes {
 		dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
-		plans := b.AssignChannel(dmChannels, onlineNodes)
+		plans := b.AssignChannel(dmChannels, onlineNodes, false)
 		for i := range plans {
 			plans[i].From = nodeID
 			plans[i].Replica = replica
@@ -334,7 +350,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
 		return nil
 	}
 
-	channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel)
+	channelPlans := b.AssignChannel(channelsToMove, nodeWithLessChannel, false)
 	for i := range channelPlans {
 		channelPlans[i].From = channelPlans[i].Channel.Node
 		channelPlans[i].Replica = replica
diff --git a/internal/querycoordv2/balance/rowcount_based_balancer_test.go b/internal/querycoordv2/balance/rowcount_based_balancer_test.go
index b2c14246b3..ecc283f11a 100644
--- a/internal/querycoordv2/balance/rowcount_based_balancer_test.go
+++ b/internal/querycoordv2/balance/rowcount_based_balancer_test.go
@@ -136,12 +136,67 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegment() {
 				nodeInfo.SetState(c.states[i])
 				suite.balancer.nodeManager.Add(nodeInfo)
 			}
-			plans := balancer.AssignSegment(0, c.assignments, c.nodes)
+			plans := balancer.AssignSegment(0, c.assignments, c.nodes, false)
 			assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans, plans)
 		})
 	}
 }
 
+func (suite *RowCountBasedBalancerTestSuite) TestSuspendNode() {
+	cases := []struct {
+		name          string
+		distributions map[int64][]*meta.Segment
+		assignments   []*meta.Segment
+		nodes         []int64
+		segmentCnts   []int
+		states        []session.State
+		expectPlans   []SegmentAssignPlan
+	}{
+		{
+			name: "test suspend node",
+			distributions: map[int64][]*meta.Segment{
+				2: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 20}, Node: 2}},
+				3: {{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 30}, Node: 3}},
+			},
+			assignments: []*meta.Segment{
+				{SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 5}},
+				{SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}},
+				{SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 15}},
+			},
+			nodes:       []int64{1, 2, 3, 4},
+			states:      []session.State{session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend},
+			segmentCnts: []int{0, 1, 1, 0},
+			expectPlans: []SegmentAssignPlan{},
+		},
+	}
+
+	for _, c := range cases {
+		suite.Run(c.name, func() {
+			// We haven't found a better way to do setup and teardown for subtests yet.
+			// If you find one, please replace this.
+			suite.SetupSuite()
+			defer suite.TearDownTest()
+			balancer := suite.balancer
+			for node, s := range c.distributions {
+				balancer.dist.SegmentDistManager.Update(node, s...)
+			}
+			for i := range c.nodes {
+				nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{
+					NodeID:   c.nodes[i],
+					Address:  "localhost",
+					Hostname: "localhost",
+				})
+				nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i]))
+				nodeInfo.SetState(c.states[i])
+				suite.balancer.nodeManager.Add(nodeInfo)
+			}
+			plans := balancer.AssignSegment(0, c.assignments, c.nodes, false)
+			// all nodes have been suspended, so there is no node to assign segments to
+			suite.ElementsMatch(c.expectPlans, plans)
+		})
+	}
+}
+
 func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
 	cases := []struct {
 		name          string
@@ -888,7 +943,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestAssignSegmentWithGrowing() {
 		NumOfGrowingRows: 50,
 	}
 	suite.balancer.dist.LeaderViewManager.Update(1, leaderView)
-	plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions))
+	plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false)
 	for _, p := range plans {
 		suite.Equal(int64(2), p.To)
 	}
diff --git a/internal/querycoordv2/balance/score_based_balancer.go b/internal/querycoordv2/balance/score_based_balancer.go
index 0cc7adffb5..737cb1830f 100644
--- a/internal/querycoordv2/balance/score_based_balancer.go
+++ b/internal/querycoordv2/balance/score_based_balancer.go
@@ -50,7 +50,15 @@ func NewScoreBasedBalancer(scheduler task.Scheduler,
 }
 
 // AssignSegment got a segment list, and try to assign each segment to node's with lowest score
-func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64) []SegmentAssignPlan {
+func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.Segment, nodes []int64, manualBalance bool) []SegmentAssignPlan {
+	// filter out suspended and stopping nodes during assignment; manual balance skips this check
+	if !manualBalance {
+		nodes = lo.Filter(nodes, func(node int64, _ int) bool {
+			info := b.nodeManager.Get(node)
+			return info != nil && info.GetState() == session.NodeStateNormal
+		})
+	}
+
 	// calculate each node's score
 	nodeItems := b.convertToNodeItems(collectionID, nodes)
 	if len(nodeItems) == 0 {
@@ -87,7 +95,8 @@ func (b *ScoreBasedBalancer) AssignSegment(collectionID int64, segments []*meta.
 		sourceNode := nodeItemsMap[s.Node]
 		// if segment's node exist, which means this segment comes from balancer. we should consider the benefit
we should consider the benefit // if the segment reassignment doesn't get enough benefit, we should skip this reassignment - if sourceNode != nil && !b.hasEnoughBenefit(sourceNode, targetNode, priorityChange) { + // note: the benefit check is skipped for manual balance + if !manualBalance && sourceNode != nil && !b.hasEnoughBenefit(sourceNode, targetNode, priorityChange) { return } @@ -249,7 +258,7 @@ func (b *ScoreBasedBalancer) genStoppingSegmentPlan(replica *meta.Replica, onlin b.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.NextTarget) != nil && segment.GetLevel() != datapb.SegmentLevel_L0 }) - plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes) + plans := b.AssignSegment(replica.CollectionID, segments, onlineNodes, false) for i := range plans { plans[i].From = nodeID plans[i].Replica = replica @@ -313,7 +322,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [ return nil } - segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, onlineNodes) + segmentPlans := b.AssignSegment(replica.CollectionID, segmentsToMove, onlineNodes, false) for i := range segmentPlans { segmentPlans[i].From = segmentPlans[i].Segment.Node segmentPlans[i].Replica = replica diff --git a/internal/querycoordv2/balance/score_based_balancer_test.go b/internal/querycoordv2/balance/score_based_balancer_test.go index 8b62aa3ab8..a16022de6d 100644 --- a/internal/querycoordv2/balance/score_based_balancer_test.go +++ b/internal/querycoordv2/balance/score_based_balancer_test.go @@ -232,13 +232,68 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegment() { suite.balancer.nodeManager.Add(nodeInfo) } for i := range c.collectionIDs { - plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes) + plans := balancer.AssignSegment(c.collectionIDs[i], c.assignments[i], c.nodes, false) assertSegmentAssignPlanElementMatch(&suite.Suite, c.expectPlans[i], plans) } }) } } +func (suite *ScoreBasedBalancerTestSuite) TestSuspendNode() { + cases := []struct { + name string + distributions map[int64][]*meta.Segment + assignments []*meta.Segment + nodes []int64 + segmentCnts []int + states []session.State + expectPlans []SegmentAssignPlan + }{ + { + name: "test suspend node", + distributions: map[int64][]*meta.Segment{ + 2: {{SegmentInfo: &datapb.SegmentInfo{ID: 1, NumOfRows: 20}, Node: 2}}, + 3: {{SegmentInfo: &datapb.SegmentInfo{ID: 2, NumOfRows: 30}, Node: 3}}, + }, + assignments: []*meta.Segment{ + {SegmentInfo: &datapb.SegmentInfo{ID: 3, NumOfRows: 5}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 4, NumOfRows: 10}}, + {SegmentInfo: &datapb.SegmentInfo{ID: 5, NumOfRows: 15}}, + }, + nodes: []int64{1, 2, 3, 4}, + states: []session.State{session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend, session.NodeStateSuspend}, + segmentCnts: []int{0, 1, 1, 0}, + expectPlans: []SegmentAssignPlan{}, + }, + } + + for _, c := range cases { + suite.Run(c.name, func() { + // I have not found a better way to do the setup and teardown work for subtests yet. + // If you find one, please replace this. + suite.SetupSuite() + defer suite.TearDownTest() + balancer := suite.balancer + for node, s := range c.distributions { + balancer.dist.SegmentDistManager.Update(node, s...) 
+ } + for i := range c.nodes { + nodeInfo := session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: c.nodes[i], + Address: "localhost", + Hostname: "localhost", + }) + nodeInfo.UpdateStats(session.WithSegmentCnt(c.segmentCnts[i])) + nodeInfo.SetState(c.states[i]) + suite.balancer.nodeManager.Add(nodeInfo) + } + plans := balancer.AssignSegment(0, c.assignments, c.nodes, false) + // all nodes have been suspended, so there is no node to assign segments to + suite.ElementsMatch(c.expectPlans, plans) + }) + } +} + func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() { suite.SetupSuite() defer suite.TearDownTest() @@ -279,7 +334,7 @@ func (suite *ScoreBasedBalancerTestSuite) TestAssignSegmentWithGrowing() { NumOfGrowingRows: 50, } suite.balancer.dist.LeaderViewManager.Update(1, leaderView) - plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions)) + plans := balancer.AssignSegment(1, toAssign, lo.Keys(distributions), false) for _, p := range plans { suite.Equal(int64(2), p.To) } diff --git a/internal/querycoordv2/checkers/channel_checker.go b/internal/querycoordv2/checkers/channel_checker.go index 0217add440..3c721725b0 100644 --- a/internal/querycoordv2/checkers/channel_checker.go +++ b/internal/querycoordv2/checkers/channel_checker.go @@ -222,7 +222,7 @@ func (c *ChannelChecker) createChannelLoadTask(ctx context.Context, channels []* availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) }) - plans := c.balancer.AssignChannel(channels, availableNodes) + plans := c.balancer.AssignChannel(channels, availableNodes, false) for i := range plans { plans[i].Replica = replica } diff --git a/internal/querycoordv2/checkers/channel_checker_test.go b/internal/querycoordv2/checkers/channel_checker_test.go index 64c28f3830..96551ef819 100644 --- a/internal/querycoordv2/checkers/channel_checker_test.go +++ b/internal/querycoordv2/checkers/channel_checker_test.go @@ -100,7 +100,7 @@ func (suite *ChannelCheckerTestSuite) setNodeAvailable(nodes ...int64) { func (suite *ChannelCheckerTestSuite) createMockBalancer() balance.Balance { balancer := balance.NewMockBalancer(suite.T()) - balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything).Maybe().Return(func(channels []*meta.DmChannel, nodes []int64) []balance.ChannelAssignPlan { + balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(channels []*meta.DmChannel, nodes []int64, _ bool) []balance.ChannelAssignPlan { plans := make([]balance.ChannelAssignPlan, 0, len(channels)) for i, c := range channels { plan := balance.ChannelAssignPlan{ diff --git a/internal/querycoordv2/checkers/controller_test.go b/internal/querycoordv2/checkers/controller_test.go index c3d41c1bd5..339d184b0c 100644 --- a/internal/querycoordv2/checkers/controller_test.go +++ b/internal/querycoordv2/checkers/controller_test.go @@ -134,11 +134,11 @@ func (suite *CheckerControllerSuite) TestBasic() { assignSegCounter := atomic.NewInt32(0) assingChanCounter := atomic.NewInt32(0) - suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(i1 int64, s []*meta.Segment, i2 []int64) []balance.SegmentAssignPlan { + suite.balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(i1 int64, s []*meta.Segment, i2 []int64, i4 bool) []balance.SegmentAssignPlan { assignSegCounter.Inc() return nil }) - suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything).RunAndReturn(func(dc 
[]*meta.DmChannel, i []int64) []balance.ChannelAssignPlan { + suite.balancer.EXPECT().AssignChannel(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(dc []*meta.DmChannel, i []int64, _ bool) []balance.ChannelAssignPlan { assingChanCounter.Inc() return nil }) diff --git a/internal/querycoordv2/checkers/segment_checker.go b/internal/querycoordv2/checkers/segment_checker.go index 90911a5aa5..34d9d27be4 100644 --- a/internal/querycoordv2/checkers/segment_checker.go +++ b/internal/querycoordv2/checkers/segment_checker.go @@ -400,7 +400,7 @@ func (c *SegmentChecker) createSegmentLoadTasks(ctx context.Context, segments [] SegmentInfo: s, } }) - shardPlans := c.balancer.AssignSegment(replica.CollectionID, segmentInfos, availableNodes) + shardPlans := c.balancer.AssignSegment(replica.CollectionID, segmentInfos, availableNodes, false) for i := range shardPlans { shardPlans[i].Replica = replica } diff --git a/internal/querycoordv2/checkers/segment_checker_test.go b/internal/querycoordv2/checkers/segment_checker_test.go index 6911f0752d..c982dfff1b 100644 --- a/internal/querycoordv2/checkers/segment_checker_test.go +++ b/internal/querycoordv2/checkers/segment_checker_test.go @@ -87,7 +87,7 @@ func (suite *SegmentCheckerTestSuite) TearDownTest() { func (suite *SegmentCheckerTestSuite) createMockBalancer() balance.Balance { balancer := balance.NewMockBalancer(suite.T()) - balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64) []balance.SegmentAssignPlan { + balancer.EXPECT().AssignSegment(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Maybe().Return(func(collectionID int64, segments []*meta.Segment, nodes []int64, _ bool) []balance.SegmentAssignPlan { plans := make([]balance.SegmentAssignPlan, 0, len(segments)) for i, s := range segments { plan := balance.SegmentAssignPlan{ diff --git a/internal/querycoordv2/handlers.go b/internal/querycoordv2/handlers.go index 8a35b89219..0d1a06525d 100644 --- a/internal/querycoordv2/handlers.go +++ b/internal/querycoordv2/handlers.go @@ -22,6 +22,7 @@ import ( "sync" "time" + "github.com/cockroachdb/errors" "github.com/samber/lo" "go.uber.org/zap" @@ -87,78 +88,61 @@ func (s *Server) getCollectionSegmentInfo(collection int64) []*querypb.SegmentIn return lo.Values(infos) } -// parseBalanceRequest parses the load balance request, -// returns the collection, replica, and segments -func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRequest, replica *meta.Replica) error { - srcNode := req.GetSourceNodeIDs()[0] - dstNodeSet := typeutil.NewUniqueSet(req.GetDstNodeIDs()...) - if dstNodeSet.Len() == 0 { - outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica) - availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { - stop, err := s.nodeMgr.IsStoppingNode(node) - if err != nil { - return false - } - return !outboundNodes.Contain(node) && !stop - }) - dstNodeSet.Insert(availableNodes...) 
+// generate balance segment tasks and submit them to the scheduler +// if sync is true, this call waits for the tasks to finish, up to the segment task timeout +// if copyMode is true, this call generates load segment tasks instead of balance segment tasks +func (s *Server) balanceSegments(ctx context.Context, + collectionID int64, + replica *meta.Replica, + srcNode int64, + dstNodes []int64, + segments []*meta.Segment, + sync bool, + copyMode bool, +) error { + log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID), zap.Int64("srcNode", srcNode)) + plans := s.balancer.AssignSegment(collectionID, segments, dstNodes, true) + for i := range plans { + plans[i].From = srcNode + plans[i].Replica = replica } - dstNodeSet.Remove(srcNode) - - toBalance := typeutil.NewSet[*meta.Segment]() - // Only balance segments in targets - segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(srcNode)) - segments = lo.Filter(segments, func(segment *meta.Segment, _ int) bool { - return s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil - }) - allSegments := make(map[int64]*meta.Segment) - for _, segment := range segments { - allSegments[segment.GetID()] = segment - } - - if len(req.GetSealedSegmentIDs()) == 0 { - toBalance.Insert(segments...) - } else { - for _, segmentID := range req.GetSealedSegmentIDs() { - segment, ok := allSegments[segmentID] - if !ok { - return fmt.Errorf("segment %d not found in source node %d", segmentID, srcNode) - } - toBalance.Insert(segment) - } - } - - log := log.With( - zap.Int64("collectionID", req.GetCollectionID()), - zap.Int64("srcNodeID", srcNode), - zap.Int64s("destNodeIDs", dstNodeSet.Collect()), - ) - plans := s.balancer.AssignSegment(req.GetCollectionID(), toBalance.Collect(), dstNodeSet.Collect()) tasks := make([]task.Task, 0, len(plans)) for _, plan := range plans { log.Info("manually balance segment...", - zap.Int64("destNodeID", plan.To), + zap.Int64("replica", plan.Replica.ID), + zap.String("channel", plan.Segment.InsertChannel), + zap.Int64("from", plan.From), + zap.Int64("to", plan.To), zap.Int64("segmentID", plan.Segment.GetID()), ) - task, err := task.NewSegmentTask(ctx, + actions := make([]task.Action, 0) + loadAction := task.NewSegmentActionWithScope(plan.To, task.ActionTypeGrow, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical) + actions = append(actions, loadAction) + if !copyMode { + // in copy mode, the release action is skipped + releaseAction := task.NewSegmentActionWithScope(plan.From, task.ActionTypeReduce, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical) + actions = append(actions, releaseAction) + } + + task, err := task.NewSegmentTask(s.ctx, Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), - task.WrapIDSource(req.GetBase().GetMsgID()), - req.GetCollectionID(), - replica, - task.NewSegmentActionWithScope(plan.To, task.ActionTypeGrow, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical), - task.NewSegmentActionWithScope(srcNode, task.ActionTypeReduce, plan.Segment.GetInsertChannel(), plan.Segment.GetID(), querypb.DataScope_Historical), + utils.ManualBalance, + collectionID, + plan.Replica, + actions..., ) if err != nil { log.Warn("create segment task for balance failed", - zap.Int64("collection", req.GetCollectionID()), - zap.Int64("replica", replica.GetID()), + zap.Int64("replica", 
plan.Replica.ID), zap.String("channel", plan.Segment.InsertChannel), - zap.Int64("from", srcNode), + zap.Int64("from", plan.From), zap.Int64("to", plan.To), + zap.Int64("segmentID", plan.Segment.GetID()), zap.Error(err), ) continue } + task.SetReason("manual balance") err = s.taskScheduler.Add(task) if err != nil { task.Cancel(err) @@ -166,7 +150,92 @@ func (s *Server) balanceSegments(ctx context.Context, req *querypb.LoadBalanceRe } tasks = append(tasks, task) } - return task.Wait(ctx, Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), tasks...) + + if sync { + err := task.Wait(ctx, Params.QueryCoordCfg.SegmentTaskTimeout.GetAsDuration(time.Millisecond), tasks...) + if err != nil { + msg := "failed to wait for all balance tasks to finish" + log.Warn(msg, zap.Error(err)) + return errors.Wrap(err, msg) + } + } + + return nil +} + +// generate balance channel tasks and submit them to the scheduler +// if sync is true, this call waits for the tasks to finish, up to the channel task timeout +// if copyMode is true, this call generates a load channel task instead of a balance channel task +func (s *Server) balanceChannels(ctx context.Context, + collectionID int64, + replica *meta.Replica, + srcNode int64, + dstNodes []int64, + channels []*meta.DmChannel, + sync bool, + copyMode bool, +) error { + log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID)) + + plans := s.balancer.AssignChannel(channels, dstNodes, true) + for i := range plans { + plans[i].From = srcNode + plans[i].Replica = replica + } + + tasks := make([]task.Task, 0, len(plans)) + for _, plan := range plans { + log.Info("manually balance channel...", + zap.Int64("replica", plan.Replica.ID), + zap.String("channel", plan.Channel.GetChannelName()), + zap.Int64("from", plan.From), + zap.Int64("to", plan.To), + ) + + actions := make([]task.Action, 0) + loadAction := task.NewChannelAction(plan.To, task.ActionTypeGrow, plan.Channel.GetChannelName()) + actions = append(actions, loadAction) + if !copyMode { + // in copy mode, the release action is skipped + releaseAction := task.NewChannelAction(plan.From, task.ActionTypeReduce, plan.Channel.GetChannelName()) + actions = append(actions, releaseAction) + } + task, err := task.NewChannelTask(s.ctx, + Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), + utils.ManualBalance, + collectionID, + plan.Replica, + actions..., ) + if err != nil { + log.Warn("create channel task for balance failed", + zap.Int64("replica", plan.Replica.ID), + zap.String("channel", plan.Channel.GetChannelName()), + zap.Int64("from", plan.From), + zap.Int64("to", plan.To), + zap.Error(err), + ) + continue + } + task.SetReason("manual balance") + err = s.taskScheduler.Add(task) + if err != nil { + task.Cancel(err) + return err + } + tasks = append(tasks, task) + } + + if sync { + err := task.Wait(ctx, Params.QueryCoordCfg.ChannelTaskTimeout.GetAsDuration(time.Millisecond), tasks...) 
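+ // note: task.Wait is expected to return a non-nil error if any submitted task fails or the timeout elapses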
+ if err != nil { + msg := "failed to wait for all balance tasks to finish" + log.Warn(msg, zap.Error(err)) + return errors.Wrap(err, msg) + } + } + + return nil } // TODO(dragondriver): add more detail metrics diff --git a/internal/querycoordv2/meta/replica_manager.go b/internal/querycoordv2/meta/replica_manager.go index 34bcbd1c76..f3c77cdc63 100644 --- a/internal/querycoordv2/meta/replica_manager.go +++ b/internal/querycoordv2/meta/replica_manager.go @@ -232,7 +232,7 @@ func (m *ReplicaManager) GetByCollection(collectionID typeutil.UniqueID) []*Repl m.rwmutex.RLock() defer m.rwmutex.RUnlock() - replicas := make([]*Replica, 0, 3) + replicas := make([]*Replica, 0) for _, replica := range m.replicas { if replica.CollectionID == collectionID { replicas = append(replicas, replica) @@ -255,6 +255,20 @@ func (m *ReplicaManager) GetByCollectionAndNode(collectionID, nodeID typeutil.Un return nil } +func (m *ReplicaManager) GetByNode(nodeID typeutil.UniqueID) []*Replica { + m.rwmutex.RLock() + defer m.rwmutex.RUnlock() + + replicas := make([]*Replica, 0) + for _, replica := range m.replicas { + if replica.nodes.Contain(nodeID) { + replicas = append(replicas, replica) + } + } + + return replicas +} + func (m *ReplicaManager) GetByCollectionAndRG(collectionID int64, rgName string) []*Replica { m.rwmutex.RLock() defer m.rwmutex.RUnlock() diff --git a/internal/querycoordv2/meta/replica_manager_test.go b/internal/querycoordv2/meta/replica_manager_test.go index 5255f59437..579eaaac8d 100644 --- a/internal/querycoordv2/meta/replica_manager_test.go +++ b/internal/querycoordv2/meta/replica_manager_test.go @@ -102,6 +102,7 @@ func (suite *ReplicaManagerSuite) TestGet() { for _, replica := range replicas { suite.Equal(collection, replica.GetCollectionID()) suite.Equal(replica, mgr.Get(replica.GetID())) + suite.Equal(len(replica.Replica.GetNodes()), replica.Len()) suite.Equal(replica.Replica.GetNodes(), replica.GetNodes()) replicaNodes[replica.GetID()] = replica.Replica.GetNodes() nodes = append(nodes, replica.Replica.Nodes...) @@ -117,6 +118,24 @@ } } +func (suite *ReplicaManagerSuite) TestGetByNode() { + mgr := suite.mgr + + randomNodeID := int64(11111) + testReplica1, err := mgr.spawn(3002, DefaultResourceGroupName) + suite.NoError(err) + testReplica1.AddNode(randomNodeID) + + testReplica2, err := mgr.spawn(3002, DefaultResourceGroupName) + suite.NoError(err) + testReplica2.AddNode(randomNodeID) + + mgr.Put(testReplica1, testReplica2) + + replicas := mgr.GetByNode(randomNodeID) + suite.Len(replicas, 2) +} + func (suite *ReplicaManagerSuite) TestRecover() { mgr := suite.mgr diff --git a/internal/querycoordv2/ops_service_test.go b/internal/querycoordv2/ops_service_test.go new file mode 100644 index 0000000000..4ec45df183 --- /dev/null +++ b/internal/querycoordv2/ops_service_test.go @@ -0,0 +1,888 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package querycoordv2 + +import ( + "context" + "testing" + + "github.com/samber/lo" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "go.uber.org/atomic" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/kv" + etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/metastore" + "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querycoordv2/balance" + "github.com/milvus-io/milvus/internal/querycoordv2/checkers" + "github.com/milvus-io/milvus/internal/querycoordv2/dist" + "github.com/milvus-io/milvus/internal/querycoordv2/job" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + "github.com/milvus-io/milvus/internal/querycoordv2/observers" + "github.com/milvus-io/milvus/internal/querycoordv2/params" + "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/querycoordv2/task" + "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/util/etcd" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metricsinfo" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type OpsServiceSuite struct { + suite.Suite + + // Dependencies + kv kv.MetaKv + store metastore.QueryCoordCatalog + dist *meta.DistributionManager + meta *meta.Meta + targetMgr *meta.TargetManager + broker *meta.MockBroker + targetObserver *observers.TargetObserver + cluster *session.MockCluster + nodeMgr *session.NodeManager + jobScheduler *job.Scheduler + taskScheduler *task.MockScheduler + balancer balance.Balance + + distMgr *meta.DistributionManager + distController *dist.MockController + checkerController *checkers.CheckerController + + // Test object + server *Server +} + +func (suite *OpsServiceSuite) SetupSuite() { + paramtable.Init() +} + +func (suite *OpsServiceSuite) SetupTest() { + config := params.GenerateEtcdConfig() + cli, err := etcd.GetEtcdClient( + config.UseEmbedEtcd.GetAsBool(), + config.EtcdUseSSL.GetAsBool(), + config.Endpoints.GetAsStrings(), + config.EtcdTLSCert.GetValue(), + config.EtcdTLSKey.GetValue(), + config.EtcdTLSCACert.GetValue(), + config.EtcdTLSMinVersion.GetValue()) + suite.Require().NoError(err) + suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) + + suite.store = querycoord.NewCatalog(suite.kv) + suite.dist = meta.NewDistributionManager() + suite.nodeMgr = session.NewNodeManager() + suite.meta = meta.NewMeta(params.RandomIncrementIDAllocator(), suite.store, suite.nodeMgr) + suite.broker = meta.NewMockBroker(suite.T()) + suite.targetMgr = meta.NewTargetManager(suite.broker, suite.meta) + suite.targetObserver = observers.NewTargetObserver( + suite.meta, + suite.targetMgr, + suite.dist, + suite.broker, + suite.cluster, + ) + suite.cluster = 
session.NewMockCluster(suite.T()) + suite.jobScheduler = job.NewScheduler() + suite.taskScheduler = task.NewMockScheduler(suite.T()) + suite.jobScheduler.Start() + suite.balancer = balance.NewScoreBasedBalancer( + suite.taskScheduler, + suite.nodeMgr, + suite.dist, + suite.meta, + suite.targetMgr, + ) + meta.GlobalFailedLoadCache = meta.NewFailedLoadCache() + suite.distMgr = meta.NewDistributionManager() + suite.distController = dist.NewMockController(suite.T()) + + suite.checkerController = checkers.NewCheckerController(suite.meta, suite.distMgr, + suite.targetMgr, suite.balancer, suite.nodeMgr, suite.taskScheduler, suite.broker) + + suite.server = &Server{ + kv: suite.kv, + store: suite.store, + session: sessionutil.NewSessionWithEtcd(context.Background(), Params.EtcdCfg.MetaRootPath.GetValue(), cli), + metricsCacheManager: metricsinfo.NewMetricsCacheManager(), + dist: suite.dist, + meta: suite.meta, + targetMgr: suite.targetMgr, + broker: suite.broker, + targetObserver: suite.targetObserver, + nodeMgr: suite.nodeMgr, + cluster: suite.cluster, + jobScheduler: suite.jobScheduler, + taskScheduler: suite.taskScheduler, + balancer: suite.balancer, + distController: suite.distController, + ctx: context.Background(), + checkerController: suite.checkerController, + } + suite.server.collectionObserver = observers.NewCollectionObserver( + suite.server.dist, + suite.server.meta, + suite.server.targetMgr, + suite.targetObserver, + &checkers.CheckerController{}, + ) + + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) +} + +func (suite *OpsServiceSuite) TestActiveCheckers() { + // test server unhealthy + suite.server.UpdateStateCode(commonpb.StateCode_Abnormal) + ctx := context.Background() + resp, err := suite.server.ListCheckers(ctx, &querypb.ListCheckersRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp.Status)) + + resp1, err := suite.server.DeactivateChecker(ctx, &querypb.DeactivateCheckerRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp1)) + + resp2, err := suite.server.ActivateChecker(ctx, &querypb.ActivateCheckerRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp2)) + + // test active success + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err = suite.server.ListCheckers(ctx, &querypb.ListCheckersRequest{}) + suite.NoError(err) + suite.True(merr.Ok(resp.Status)) + suite.Len(resp.GetCheckerInfos(), 5) + + resp4, err := suite.server.DeactivateChecker(ctx, &querypb.DeactivateCheckerRequest{ + CheckerID: int32(utils.ChannelChecker), + }) + suite.NoError(err) + suite.True(merr.Ok(resp4)) + suite.False(suite.checkerController.IsActive(utils.ChannelChecker)) + + resp5, err := suite.server.ActivateChecker(ctx, &querypb.ActivateCheckerRequest{ + CheckerID: int32(utils.ChannelChecker), + }) + suite.NoError(err) + suite.True(merr.Ok(resp5)) + suite.True(suite.checkerController.IsActive(utils.ChannelChecker)) +} + +func (suite *OpsServiceSuite) TestListQueryNode() { + // test server unhealthy + suite.server.UpdateStateCode(commonpb.StateCode_Abnormal) + ctx := context.Background() + resp, err := suite.server.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{}) + suite.NoError(err) + suite.Equal(0, len(resp.GetNodeInfos())) + suite.False(merr.Ok(resp.Status)) + // test server healthy + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 111, + Address: "localhost", + Hostname: "localhost", + })) + resp, err = suite.server.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{}) + 
suite.NoError(err) + suite.Equal(1, len(resp.GetNodeInfos())) +} + +func (suite *OpsServiceSuite) TestGetQueryNodeDistribution() { + // test server unhealthy + suite.server.UpdateStateCode(commonpb.StateCode_Abnormal) + ctx := context.Background() + resp, err := suite.server.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp.Status)) + + // test node not found + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err = suite.server.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{ + NodeID: 1, + }) + suite.NoError(err) + suite.False(merr.Ok(resp.Status)) + + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + // test success + channels := []*meta.DmChannel{ + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: "channel1", + }, + Node: 1, + }, + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: "channel2", + }, + Node: 1, + }, + } + + segments := []*meta.Segment{ + { + SegmentInfo: &datapb.SegmentInfo{ + ID: 1, + CollectionID: 1, + PartitionID: 1, + InsertChannel: "channel1", + }, + Node: 1, + }, + { + SegmentInfo: &datapb.SegmentInfo{ + ID: 2, + CollectionID: 1, + PartitionID: 1, + InsertChannel: "channel2", + }, + Node: 1, + }, + } + suite.dist.ChannelDistManager.Update(1, channels...) + suite.dist.SegmentDistManager.Update(1, segments...) + + resp, err = suite.server.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{ + NodeID: 1, + }) + + suite.NoError(err) + suite.True(merr.Ok(resp.Status)) + suite.Equal(2, len(resp.GetChannelNames())) + suite.Equal(2, len(resp.GetSealedSegmentIDs())) +} + +func (suite *OpsServiceSuite) TestCheckQueryNodeDistribution() { + // test server unhealthy + suite.server.UpdateStateCode(commonpb.StateCode_Abnormal) + ctx := context.Background() + resp, err := suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + // test node not found + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{ + TargetNodeID: 2, + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + + resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{ + SourceNodeID: 1, + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + // test success + channels := []*meta.DmChannel{ + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: "channel1", + }, + Node: 1, + }, + { + VchannelInfo: &datapb.VchannelInfo{ + CollectionID: 1, + ChannelName: "channel2", + }, + Node: 1, + }, + } + + segments := []*meta.Segment{ + { + SegmentInfo: &datapb.SegmentInfo{ + ID: 1, + CollectionID: 1, + PartitionID: 1, + InsertChannel: "channel1", + }, + Node: 1, + }, + { + SegmentInfo: &datapb.SegmentInfo{ + ID: 2, + CollectionID: 1, + PartitionID: 1, + InsertChannel: "channel2", + }, + Node: 1, + }, + } + suite.dist.ChannelDistManager.Update(1, channels...) + suite.dist.SegmentDistManager.Update(1, segments...) 
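+ // the distribution above lives on node 1 only, so checking it against an empty target node should fail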
+ + resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{ + SourceNodeID: 1, + TargetNodeID: 2, + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + suite.dist.ChannelDistManager.Update(2, channels...) + suite.dist.SegmentDistManager.Update(2, segments...) + resp, err = suite.server.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{ + SourceNodeID: 1, + TargetNodeID: 1, + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) +} + +func (suite *OpsServiceSuite) TestSuspendAndResumeBalance() { + // test server unhealthy + suite.server.UpdateStateCode(commonpb.StateCode_Abnormal) + ctx := context.Background() + resp, err := suite.server.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + resp, err = suite.server.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + // test suspend success + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err = suite.server.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{}) + suite.NoError(err) + suite.True(merr.Ok(resp)) + suite.False(suite.checkerController.IsActive(utils.BalanceChecker)) + + resp, err = suite.server.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{}) + suite.NoError(err) + suite.True(merr.Ok(resp)) + suite.True(suite.checkerController.IsActive(utils.BalanceChecker)) +} + +func (suite *OpsServiceSuite) TestSuspendAndResumeNode() { + // test server unhealthy + suite.server.UpdateStateCode(commonpb.StateCode_Abnormal) + ctx := context.Background() + resp, err := suite.server.SuspendNode(ctx, &querypb.SuspendNodeRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + suite.server.UpdateStateCode(commonpb.StateCode_Abnormal) + resp, err = suite.server.ResumeNode(ctx, &querypb.ResumeNodeRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + // test node not found + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err = suite.server.SuspendNode(ctx, &querypb.SuspendNodeRequest{ + NodeID: 1, + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + resp, err = suite.server.ResumeNode(ctx, &querypb.ResumeNodeRequest{ + NodeID: 1, + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + // test success + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + resp, err = suite.server.SuspendNode(ctx, &querypb.SuspendNodeRequest{ + NodeID: 1, + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + node := suite.nodeMgr.Get(1) + suite.Equal(session.NodeStateSuspend, node.GetState()) + + resp, err = suite.server.ResumeNode(ctx, &querypb.ResumeNodeRequest{ + NodeID: 1, + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + node = suite.nodeMgr.Get(1) + suite.Equal(session.NodeStateNormal, node.GetState()) +} + +func (suite *OpsServiceSuite) TestTransferSegment() { + ctx := context.Background() + + // test server unhealthy + suite.server.UpdateStateCode(commonpb.StateCode_Abnormal) + resp, err := suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + // test source node not healthy + resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{ + SourceNodeID: 1, + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + 
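// register the source node so the following transfer requests pass the source-node health check +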
suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + + collectionID := int64(1) + partitionID := int64(1) + replicaID := int64(1) + nodes := []int64{1, 2, 3, 4} + replica := utils.CreateTestReplica(replicaID, collectionID, nodes) + suite.meta.ReplicaManager.Put(replica) + collection := utils.CreateTestCollection(collectionID, 1) + partition := utils.CreateTestPartition(partitionID, collectionID) + suite.meta.PutCollection(collection, partition) + segmentIDs := []int64{1, 2, 3, 4} + channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"} + + // test target node not healthy + resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 2, + Address: "localhost", + Hostname: "localhost", + })) + + // test segment does not exist on the source node + resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + SegmentID: segmentIDs[0], + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + segments := []*datapb.SegmentInfo{ + { + ID: segmentIDs[0], + CollectionID: collectionID, + PartitionID: partitionID, + InsertChannel: channelNames[0], + NumOfRows: 1, + }, + { + ID: segmentIDs[1], + CollectionID: collectionID, + PartitionID: partitionID, + InsertChannel: channelNames[1], + NumOfRows: 1, + }, + { + ID: segmentIDs[2], + CollectionID: collectionID, + PartitionID: partitionID, + InsertChannel: channelNames[2], + NumOfRows: 1, + }, + { + ID: segmentIDs[3], + CollectionID: collectionID, + PartitionID: partitionID, + InsertChannel: channelNames[3], + NumOfRows: 1, + }, + } + + channels := []*datapb.VchannelInfo{ + { + CollectionID: collectionID, + ChannelName: channelNames[0], + }, + { + CollectionID: collectionID, + ChannelName: channelNames[1], + }, + { + CollectionID: collectionID, + ChannelName: channelNames[2], + }, + { + CollectionID: collectionID, + ChannelName: channelNames[3], + }, + } + segmentInfos := lo.Map(segments, func(segment *datapb.SegmentInfo, _ int) *meta.Segment { + return &meta.Segment{ + SegmentInfo: segment, + Node: nodes[0], + } + }) + channelInfos := lo.Map(channels, func(channel *datapb.VchannelInfo, _ int) *meta.DmChannel { + return &meta.DmChannel{ + VchannelInfo: channel, + Node: nodes[0], + } + }) + suite.dist.SegmentDistManager.Update(1, segmentInfos[0]) + + // test segment not in current target, expect no task assigned and success + resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + SegmentID: segmentIDs[0], + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil) + suite.targetMgr.UpdateCollectionNextTarget(1) + suite.targetMgr.UpdateCollectionCurrentTarget(1) + suite.dist.SegmentDistManager.Update(1, segmentInfos...) + suite.dist.ChannelDistManager.Update(1, channelInfos...) 
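+ // bring up all four nodes and put them into the default resource group so they are valid transfer targets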
+ + for _, node := range nodes { + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: node, + Address: "localhost", + Hostname: "localhost", + })) + suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node) + } + + // test transfer segment success, expect to generate 1 balance segment task + suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { + actions := t.Actions() + suite.Equal(len(actions), 2) + suite.Equal(actions[0].Node(), int64(2)) + return nil + }) + resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + SegmentID: segmentIDs[0], + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + + // test copy mode, expect to generate 1 load segment task + suite.taskScheduler.ExpectedCalls = nil + suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { + actions := t.Actions() + suite.Equal(len(actions), 1) + suite.Equal(actions[0].Node(), int64(2)) + return nil + }) + resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + SegmentID: segmentIDs[0], + CopyMode: true, + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + + // test transfer all segments, expect to generate 4 balance segment tasks + suite.taskScheduler.ExpectedCalls = nil + counter := atomic.NewInt64(0) + suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { + actions := t.Actions() + suite.Equal(len(actions), 2) + suite.Equal(actions[0].Node(), int64(2)) + counter.Inc() + return nil + }) + resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + TransferAll: true, + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + suite.Equal(counter.Load(), int64(4)) + + // test transfer all segments to all nodes, expect to generate 4 balance segment tasks + suite.taskScheduler.ExpectedCalls = nil + counter = atomic.NewInt64(0) + nodeSet := typeutil.NewUniqueSet() + suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { + actions := t.Actions() + suite.Equal(len(actions), 2) + nodeSet.Insert(actions[0].Node()) + counter.Inc() + return nil + }) + resp, err = suite.server.TransferSegment(ctx, &querypb.TransferSegmentRequest{ + SourceNodeID: nodes[0], + TransferAll: true, + ToAllNodes: true, + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + suite.Equal(counter.Load(), int64(4)) + suite.Len(nodeSet.Collect(), 3) +} + +func (suite *OpsServiceSuite) TestTransferChannel() { + ctx := context.Background() + + // test server unhealthy + suite.server.UpdateStateCode(commonpb.StateCode_Abnormal) + resp, err := suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{}) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + suite.server.UpdateStateCode(commonpb.StateCode_Healthy) + // test source node not healthy + resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{ + SourceNodeID: 1, + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + + collectionID := int64(1) + partitionID := int64(1) + replicaID := int64(1) + nodes := []int64{1, 2, 3, 4} + replica := utils.CreateTestReplica(replicaID, collectionID, nodes) + suite.meta.ReplicaManager.Put(replica) + collection := 
utils.CreateTestCollection(collectionID, 1) + partition := utils.CreateTestPartition(partitionID, collectionID) + suite.meta.PutCollection(collection, partition) + segmentIDs := []int64{1, 2, 3, 4} + channelNames := []string{"channel-1", "channel-2", "channel-3", "channel-4"} + + // test target node not healthy + resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: 2, + Address: "localhost", + Hostname: "localhost", + })) + + segments := []*datapb.SegmentInfo{ + { + ID: segmentIDs[0], + CollectionID: collectionID, + PartitionID: partitionID, + InsertChannel: channelNames[0], + NumOfRows: 1, + }, + { + ID: segmentIDs[1], + CollectionID: collectionID, + PartitionID: partitionID, + InsertChannel: channelNames[1], + NumOfRows: 1, + }, + { + ID: segmentIDs[2], + CollectionID: collectionID, + PartitionID: partitionID, + InsertChannel: channelNames[2], + NumOfRows: 1, + }, + { + ID: segmentIDs[3], + CollectionID: collectionID, + PartitionID: partitionID, + InsertChannel: channelNames[3], + NumOfRows: 1, + }, + } + + channels := []*datapb.VchannelInfo{ + { + CollectionID: collectionID, + ChannelName: channelNames[0], + }, + { + CollectionID: collectionID, + ChannelName: channelNames[1], + }, + { + CollectionID: collectionID, + ChannelName: channelNames[2], + }, + { + CollectionID: collectionID, + ChannelName: channelNames[3], + }, + } + segmentInfos := lo.Map(segments, func(segment *datapb.SegmentInfo, _ int) *meta.Segment { + return &meta.Segment{ + SegmentInfo: segment, + Node: nodes[0], + } + }) + suite.dist.SegmentDistManager.Update(1, segmentInfos...) + channelInfos := lo.Map(channels, func(channel *datapb.VchannelInfo, _ int) *meta.DmChannel { + return &meta.DmChannel{ + VchannelInfo: channel, + Node: nodes[0], + } + }) + + // test channel does not exist on the source node + resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + ChannelName: channelNames[0], + }) + suite.NoError(err) + suite.False(merr.Ok(resp)) + + suite.dist.ChannelDistManager.Update(1, channelInfos[0]) + + // test channel not in current target, expect no task assigned and success + resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + ChannelName: channelNames[0], + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + + suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(channels, segments, nil) + suite.targetMgr.UpdateCollectionNextTarget(1) + suite.targetMgr.UpdateCollectionCurrentTarget(1) + suite.dist.SegmentDistManager.Update(1, segmentInfos...) + suite.dist.ChannelDistManager.Update(1, channelInfos...) 
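+ // register every node in the default resource group before exercising the channel transfer paths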
+ + for _, node := range nodes { + suite.nodeMgr.Add(session.NewNodeInfo(session.ImmutableNodeInfo{ + NodeID: node, + Address: "localhost", + Hostname: "localhost", + })) + suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node) + } + + // test transfer channel success, expect to generate 1 balance channel task + suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { + actions := t.Actions() + suite.Equal(len(actions), 2) + suite.Equal(actions[0].Node(), int64(2)) + return nil + }) + resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + ChannelName: channelNames[0], + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + + // test copy mode, expect to generate 1 load channel task + suite.taskScheduler.ExpectedCalls = nil + suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { + actions := t.Actions() + suite.Equal(len(actions), 1) + suite.Equal(actions[0].Node(), int64(2)) + return nil + }) + resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + ChannelName: channelNames[0], + CopyMode: true, + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + + // test transfer all channels, expect to generate 4 balance channel tasks + suite.taskScheduler.ExpectedCalls = nil + counter := atomic.NewInt64(0) + suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { + actions := t.Actions() + suite.Equal(len(actions), 2) + suite.Equal(actions[0].Node(), int64(2)) + counter.Inc() + return nil + }) + resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{ + SourceNodeID: nodes[0], + TargetNodeID: nodes[1], + TransferAll: true, + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + suite.Equal(counter.Load(), int64(4)) + + // test transfer all channels to all nodes, expect to generate 4 balance channel tasks + suite.taskScheduler.ExpectedCalls = nil + counter = atomic.NewInt64(0) + nodeSet := typeutil.NewUniqueSet() + suite.taskScheduler.EXPECT().Add(mock.Anything).RunAndReturn(func(t task.Task) error { + actions := t.Actions() + suite.Equal(len(actions), 2) + nodeSet.Insert(actions[0].Node()) + counter.Inc() + return nil + }) + resp, err = suite.server.TransferChannel(ctx, &querypb.TransferChannelRequest{ + SourceNodeID: nodes[0], + TransferAll: true, + ToAllNodes: true, + }) + suite.NoError(err) + suite.True(merr.Ok(resp)) + suite.Equal(counter.Load(), int64(4)) + suite.Len(nodeSet.Collect(), 3) +} + +func TestOpsService(t *testing.T) { + suite.Run(t, new(OpsServiceSuite)) +} diff --git a/internal/querycoordv2/ops_services.go b/internal/querycoordv2/ops_services.go index d3fe4ddd73..08cc1fad56 100644 --- a/internal/querycoordv2/ops_services.go +++ b/internal/querycoordv2/ops_services.go @@ -19,10 +19,14 @@ package querycoordv2 import ( "context" + "github.com/cockroachdb/errors" + "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querycoordv2/meta" + "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/utils" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" @@ -93,3 +97,368 @@ func (s *Server) DeactivateChecker(ctx context.Context, req *querypb.DeactivateC } return merr.Success(), nil } + +// return all 
available query nodes; for each node, return its (nodeID, ip_address) +func (s *Server) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest) (*querypb.ListQueryNodeResponse, error) { + log := log.Ctx(ctx) + log.Info("ListQueryNode request received") + + errMsg := "failed to list querynode state" + if err := merr.CheckHealthy(s.State()); err != nil { + log.Warn(errMsg, zap.Error(err)) + return &querypb.ListQueryNodeResponse{ + Status: merr.Status(errors.Wrap(err, errMsg)), + }, nil + } + + nodes := lo.Map(s.nodeMgr.GetAll(), func(nodeInfo *session.NodeInfo, _ int) *querypb.NodeInfo { + return &querypb.NodeInfo{ + ID: nodeInfo.ID(), + Address: nodeInfo.Addr(), + State: nodeInfo.GetState().String(), + } + }) + + return &querypb.ListQueryNodeResponse{ + Status: merr.Success(), + NodeInfos: nodes, + }, nil +} + +// return the query node's data distribution: for the given nodeID, return its (channel_name_list, sealed_segment_list) +func (s *Server) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest) (*querypb.GetQueryNodeDistributionResponse, error) { + log := log.Ctx(ctx).With(zap.Int64("nodeID", req.GetNodeID())) + log.Info("GetQueryNodeDistribution request received") + + errMsg := "failed to get query node distribution" + if err := merr.CheckHealthy(s.State()); err != nil { + log.Warn(errMsg, zap.Error(err)) + return &querypb.GetQueryNodeDistributionResponse{ + Status: merr.Status(errors.Wrap(err, errMsg)), + }, nil + } + + if s.nodeMgr.Get(req.GetNodeID()) == nil { + err := merr.WrapErrNodeNotFound(req.GetNodeID(), errMsg) + log.Warn(errMsg, zap.Error(err)) + return &querypb.GetQueryNodeDistributionResponse{ + Status: merr.Status(err), + }, nil + } + + segments := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetNodeID())) + channels := s.dist.ChannelDistManager.GetByNode(req.GetNodeID()) + return &querypb.GetQueryNodeDistributionResponse{ + Status: merr.Success(), + ChannelNames: lo.Map(channels, func(c *meta.DmChannel, _ int) string { return c.GetChannelName() }), + SealedSegmentIDs: lo.Map(segments, func(s *meta.Segment, _ int) int64 { return s.GetID() }), + }, nil +} + +// suspend background balance for all query nodes, including stopping balance and auto balance +func (s *Server) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest) (*commonpb.Status, error) { + log := log.Ctx(ctx) + log.Info("SuspendBalance request received") + + errMsg := "failed to suspend balance for all querynode" + if err := merr.CheckHealthy(s.State()); err != nil { + return merr.Status(err), nil + } + + err := s.checkerController.Deactivate(utils.BalanceChecker) + if err != nil { + log.Warn(errMsg, zap.Error(err)) + return merr.Status(err), nil + } + + return merr.Success(), nil +} + +// resume background balance for all query nodes, including stopping balance and auto balance +func (s *Server) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest) (*commonpb.Status, error) { + log := log.Ctx(ctx) + + log.Info("ResumeBalance request received") + + errMsg := "failed to resume balance for all querynode" + if err := merr.CheckHealthy(s.State()); err != nil { + return merr.Status(err), nil + } + + err := s.checkerController.Activate(utils.BalanceChecker) + if err != nil { + log.Warn(errMsg, zap.Error(err)) + return merr.Status(err), nil + } + + return merr.Success(), nil +} + +// suspend the node from resource operations: for the given node, suspend load_segment/sub_channel operations +func (s *Server) SuspendNode(ctx context.Context, req 
*querypb.SuspendNodeRequest) (*commonpb.Status, error) { + log := log.Ctx(ctx) + + log.Info("SuspendNode request received", zap.Int64("nodeID", req.GetNodeID())) + + errMsg := "failed to suspend query node" + if err := merr.CheckHealthy(s.State()); err != nil { + log.Warn(errMsg, zap.Error(err)) + return merr.Status(err), nil + } + + if s.nodeMgr.Get(req.GetNodeID()) == nil { + err := merr.WrapErrNodeNotFound(req.GetNodeID(), errMsg) + log.Warn(errMsg, zap.Error(err)) + return merr.Status(err), nil + } + + err := s.nodeMgr.Suspend(req.GetNodeID()) + if err != nil { + log.Warn(errMsg, zap.Error(err)) + return merr.Status(err), nil + } + + return merr.Success(), nil +} + +// resume the node for resource operations: for the given node, resume load_segment/sub_channel operations +func (s *Server) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest) (*commonpb.Status, error) { + log := log.Ctx(ctx) + log.Info("ResumeNode request received", zap.Int64("nodeID", req.GetNodeID())) + + errMsg := "failed to resume query node" + if err := merr.CheckHealthy(s.State()); err != nil { + log.Warn(errMsg, zap.Error(err)) + return merr.Status(errors.Wrap(err, errMsg)), nil + } + + if s.nodeMgr.Get(req.GetNodeID()) == nil { + err := merr.WrapErrNodeNotFound(req.GetNodeID(), errMsg) + log.Warn(errMsg, zap.Error(err)) + return merr.Status(err), nil + } + + err := s.nodeMgr.Resume(req.GetNodeID()) + if err != nil { + log.Warn(errMsg, zap.Error(err)) + return merr.Status(errors.Wrap(err, errMsg)), nil + } + + return merr.Success(), nil +} + +// transfer segment from the source node to the target node, +// if no segment_id is specified, default to transferring all segments on the source node. +// if no target_nodeId is specified, default to moving segments to all other nodes +func (s *Server) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest) (*commonpb.Status, error) { + log := log.Ctx(ctx) + + log.Info("TransferSegment request received", + zap.Int64("source", req.GetSourceNodeID()), + zap.Int64("dest", req.GetTargetNodeID()), + zap.Int64("segment", req.GetSegmentID())) + + if err := merr.CheckHealthy(s.State()); err != nil { + msg := "failed to transfer segment" + log.Warn(msg, zap.Error(err)) + return merr.Status(errors.Wrap(err, msg)), nil + } + + // check whether srcNode is healthy + srcNode := req.GetSourceNodeID() + if err := s.isStoppingNode(srcNode); err != nil { + err := merr.WrapErrNodeNotAvailable(srcNode, "the source node is invalid") + return merr.Status(err), nil + } + + replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID()) + for _, replica := range replicas { + // when no dst node is specified, default to using all other nodes in the same replica + dstNodeSet := typeutil.NewUniqueSet() + if req.GetToAllNodes() { + outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica) + availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) }) + dstNodeSet.Insert(availableNodes...) + } else { + // check whether dstNode is healthy + if err := s.isStoppingNode(req.GetTargetNodeID()); err != nil { + err := merr.WrapErrNodeNotAvailable(req.GetTargetNodeID(), "the target node is invalid") + return merr.Status(err), nil + } + dstNodeSet.Insert(req.GetTargetNodeID()) + } + dstNodeSet.Remove(srcNode) + + // check sealed segment list + segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(replica.GetCollectionID()), meta.WithNodeID(srcNode)) + + toBalance := typeutil.NewSet[*meta.Segment]() + if req.GetTransferAll() { + toBalance.Insert(segments...) 
+ } else { + // check whether the sealed segment exists + segment, ok := lo.Find(segments, func(s *meta.Segment) bool { return s.GetID() == req.GetSegmentID() }) + if !ok { + err := merr.WrapErrSegmentNotFound(req.GetSegmentID(), "segment not found in source node") + return merr.Status(err), nil + } + + existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil + if !existInTarget { + log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", req.GetSegmentID())) + } else { + toBalance.Insert(segment) + } + } + + err := s.balanceSegments(ctx, replica.GetCollectionID(), replica, srcNode, dstNodeSet.Collect(), toBalance.Collect(), false, req.GetCopyMode()) + if err != nil { + msg := "failed to balance segments" + log.Warn(msg, zap.Error(err)) + return merr.Status(errors.Wrap(err, msg)), nil + } + } + return merr.Success(), nil +} + +// transfer channel from the source node to the target node, +// if no channel_name is specified, default to transferring all channels on the source node. +// if no target_nodeId is specified, default to moving channels to all other nodes +func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest) (*commonpb.Status, error) { + log := log.Ctx(ctx) + + log.Info("TransferChannel request received", + zap.Int64("source", req.GetSourceNodeID()), + zap.Int64("dest", req.GetTargetNodeID()), + zap.String("channel", req.GetChannelName())) + + if err := merr.CheckHealthy(s.State()); err != nil { + msg := "failed to transfer channel" + log.Warn(msg, zap.Error(err)) + return merr.Status(errors.Wrap(err, msg)), nil + } + + // check whether srcNode is healthy + srcNode := req.GetSourceNodeID() + if err := s.isStoppingNode(srcNode); err != nil { + err := merr.WrapErrNodeNotAvailable(srcNode, "the source node is invalid") + return merr.Status(err), nil + } + + replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID()) + for _, replica := range replicas { + // when no dst node is specified, default to using all other nodes in the same replica + dstNodeSet := typeutil.NewUniqueSet() + if req.GetToAllNodes() { + outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica) + availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) }) + dstNodeSet.Insert(availableNodes...) + } else { + // check whether dstNode is healthy + if err := s.isStoppingNode(req.GetTargetNodeID()); err != nil { + err := merr.WrapErrNodeNotAvailable(req.GetTargetNodeID(), "the target node is invalid") + return merr.Status(err), nil + } + dstNodeSet.Insert(req.GetTargetNodeID()) + } + dstNodeSet.Remove(srcNode) + + // check channel list + channels := s.dist.ChannelDistManager.GetByCollectionAndNode(replica.CollectionID, srcNode) + toBalance := typeutil.NewSet[*meta.DmChannel]() + if req.GetTransferAll() { + toBalance.Insert(channels...) 
+ +// transfer a channel from the source node to the target node. +// if no channel_name is specified, all channels on the source node are transferred by default. +// if no target_nodeId is specified, the channels are moved to all other nodes by default +func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest) (*commonpb.Status, error) { + log := log.Ctx(ctx) + + log.Info("TransferChannel request received", + zap.Int64("source", req.GetSourceNodeID()), + zap.Int64("dest", req.GetTargetNodeID()), + zap.String("channel", req.GetChannelName())) + + if err := merr.CheckHealthy(s.State()); err != nil { + msg := "failed to transfer channel" + log.Warn(msg, zap.Error(err)) + return merr.Status(errors.Wrap(err, msg)), nil + } + + // check whether srcNode is healthy + srcNode := req.GetSourceNodeID() + if err := s.isStoppingNode(srcNode); err != nil { + err := merr.WrapErrNodeNotAvailable(srcNode, "the source node is invalid") + return merr.Status(err), nil + } + + replicas := s.meta.ReplicaManager.GetByNode(req.GetSourceNodeID()) + for _, replica := range replicas { + // when no dst node is specified, default to use all other nodes in the same replica + dstNodeSet := typeutil.NewUniqueSet() + if req.GetToAllNodes() { + outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica) + availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) }) + dstNodeSet.Insert(availableNodes...) + } else { + // check whether dstNode is healthy + if err := s.isStoppingNode(req.GetTargetNodeID()); err != nil { + err := merr.WrapErrNodeNotAvailable(req.GetTargetNodeID(), "the target node is invalid") + return merr.Status(err), nil + } + dstNodeSet.Insert(req.GetTargetNodeID()) + } + dstNodeSet.Remove(srcNode) + + // check the channel list + channels := s.dist.ChannelDistManager.GetByCollectionAndNode(replica.CollectionID, srcNode) + toBalance := typeutil.NewSet[*meta.DmChannel]() + if req.GetTransferAll() { + toBalance.Insert(channels...) + } else { + // check whether the requested channel exists + channel, ok := lo.Find(channels, func(ch *meta.DmChannel) bool { return ch.GetChannelName() == req.GetChannelName() }) + if !ok { + err := merr.WrapErrChannelNotFound(req.GetChannelName(), "channel not found in source node") + return merr.Status(err), nil + } + existInTarget := s.targetMgr.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget) != nil + if !existInTarget { + log.Info("channel doesn't exist in current target, skip it", zap.String("channelName", channel.GetChannelName())) + } else { + toBalance.Insert(channel) + } + } + + err := s.balanceChannels(ctx, replica.GetCollectionID(), replica, srcNode, dstNodeSet.Collect(), toBalance.Collect(), false, req.GetCopyMode()) + if err != nil { + msg := "failed to balance channels" + log.Warn(msg, zap.Error(err)) + return merr.Status(errors.Wrap(err, msg)), nil + } + } + return merr.Success(), nil +} + +func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest) (*commonpb.Status, error) { + log := log.Ctx(ctx) + + log.Info("CheckQueryNodeDistribution request received", + zap.Int64("source", req.GetSourceNodeID()), + zap.Int64("dest", req.GetTargetNodeID())) + + errMsg := "failed to check query node distribution" + if err := merr.CheckHealthy(s.State()); err != nil { + log.Warn(errMsg, zap.Error(err)) + return merr.Status(err), nil + } + + sourceNode := s.nodeMgr.Get(req.GetSourceNodeID()) + if sourceNode == nil { + err := merr.WrapErrNodeNotFound(req.GetSourceNodeID(), "source node not found") + log.Warn(errMsg, zap.Error(err)) + return merr.Status(err), nil + } + + targetNode := s.nodeMgr.Get(req.GetTargetNodeID()) + if targetNode == nil { + err := merr.WrapErrNodeNotFound(req.GetTargetNodeID(), "target node not found") + log.Warn(errMsg, zap.Error(err)) + return merr.Status(err), nil + } + + // check channel list + channelOnSrc := s.dist.ChannelDistManager.GetByNode(req.GetSourceNodeID()) + channelOnDst := s.dist.ChannelDistManager.GetByNode(req.GetTargetNodeID()) + channelDstMap := lo.SliceToMap(channelOnDst, func(ch *meta.DmChannel) (string, *meta.DmChannel) { + return ch.GetChannelName(), ch + }) + for _, ch := range channelOnSrc { + if _, ok := channelDstMap[ch.GetChannelName()]; !ok { + return merr.Status(merr.WrapErrChannelLack(ch.GetChannelName())), nil + } + } + channelSrcMap := lo.SliceToMap(channelOnSrc, func(ch *meta.DmChannel) (string, *meta.DmChannel) { + return ch.GetChannelName(), ch + }) + for _, ch := range channelOnDst { + if _, ok := channelSrcMap[ch.GetChannelName()]; !ok { + return merr.Status(merr.WrapErrChannelLack(ch.GetChannelName())), nil + } + } + + // check segment list + segmentOnSrc := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetSourceNodeID())) + segmentOnDst := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetTargetNodeID())) + segmentDstMap := lo.SliceToMap(segmentOnDst, func(s *meta.Segment) (int64, *meta.Segment) { + return s.GetID(), s + }) + for _, s := range segmentOnSrc { + if _, ok := segmentDstMap[s.GetID()]; !ok { + return merr.Status(merr.WrapErrSegmentLack(s.GetID())), nil + } + } + segmentSrcMap := lo.SliceToMap(segmentOnSrc, func(s *meta.Segment) (int64, *meta.Segment) { + return s.GetID(), s + }) + for _, s := range segmentOnDst { + if _, ok := segmentSrcMap[s.GetID()]; !ok { + return merr.Status(merr.WrapErrSegmentLack(s.GetID())), nil + } + } + + return merr.Success(), nil +}
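
CheckQueryNodeDistribution above runs the same one-directional "lack" check four times (channels and segments, in both directions), so the source and target distributions must match exactly. The core of that check, condensed into a hypothetical generic helper:

package main

// lackOf returns the keys present in want but missing from got. Calling it
// in both directions, as the handler does, amounts to checking that the two
// distributions are equal as sets.
func lackOf[K comparable, V any](want, got []V, key func(V) K) []K {
	gotSet := make(map[K]struct{}, len(got))
	for _, v := range got {
		gotSet[key(v)] = struct{}{}
	}
	var missing []K
	for _, v := range want {
		if _, ok := gotSet[key(v)]; !ok {
			missing = append(missing, key(v))
		}
	}
	return missing
}

With such a helper, the channel comparison would reduce to lackOf(channelOnSrc, channelOnDst, (*meta.DmChannel).GetChannelName) followed by the reverse call, and likewise for segments keyed by GetID.
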
diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index 8e5e941ef7..db882801ac 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -682,24 +682,67 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques return merr.Status(errors.Wrap(err, fmt.Sprintf("can't balance, because the source node[%d] is invalid", srcNode))), nil } - for _, dstNode := range req.GetDstNodeIDs() { - if !replica.Contains(dstNode) { - err := merr.WrapErrNodeNotFound(dstNode, "destination node not found in the same replica") - log.Warn("failed to balance to the destination node", zap.Error(err)) - return merr.Status(err), nil + + // when no dst node is specified, default to use all other nodes in the same replica + dstNodeSet := typeutil.NewUniqueSet() + if len(req.GetDstNodeIDs()) == 0 { + outboundNodes := s.meta.ResourceManager.CheckOutboundNodes(replica) + availableNodes := lo.Filter(replica.Replica.GetNodes(), func(node int64, _ int) bool { return !outboundNodes.Contain(node) }) + dstNodeSet.Insert(availableNodes...) + } else { + for _, dstNode := range req.GetDstNodeIDs() { + if !replica.Contains(dstNode) { + err := merr.WrapErrNodeNotFound(dstNode, "destination node not found in the same replica") + log.Warn("failed to balance to the destination node", zap.Error(err)) + return merr.Status(err), nil + } + dstNodeSet.Insert(dstNode) } + } + + // check whether dstNode is healthy + for dstNode := range dstNodeSet { if err := s.isStoppingNode(dstNode); err != nil { return merr.Status(errors.Wrap(err, fmt.Sprintf("can't balance, because the destination node[%d] is invalid", dstNode))), nil } } - err := s.balanceSegments(ctx, req, replica) + // check the sealed segment list + segments := s.dist.SegmentDistManager.GetByFilter(meta.WithCollectionID(req.GetCollectionID()), meta.WithNodeID(srcNode)) + segmentsMap := lo.SliceToMap(segments, func(s *meta.Segment) (int64, *meta.Segment) { + return s.GetID(), s + }) + + toBalance := typeutil.NewSet[*meta.Segment]() + if len(req.GetSealedSegmentIDs()) == 0 { + toBalance.Insert(segments...)
+ } else { + // check whether the requested sealed segments exist + for _, segmentID := range req.GetSealedSegmentIDs() { + segment, ok := segmentsMap[segmentID] + if !ok { + err := merr.WrapErrSegmentNotFound(segmentID, "segment not found in source node") + return merr.Status(err), nil + } + + // Only balance segments in targets + existInTarget := s.targetMgr.GetSealedSegment(segment.GetCollectionID(), segment.GetID(), meta.CurrentTarget) != nil + if !existInTarget { + log.Info("segment doesn't exist in current target, skip it", zap.Int64("segmentID", segmentID)) + continue + } + toBalance.Insert(segment) + } + } + + err := s.balanceSegments(ctx, replica.GetCollectionID(), replica, srcNode, dstNodeSet.Collect(), toBalance.Collect(), true, false) if err != nil { msg := "failed to balance segments" log.Warn(msg, zap.Error(err)) return merr.Status(errors.Wrap(err, msg)), nil } + return merr.Success(), nil } diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index 35283f4f87..39f2de6383 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -1174,6 +1174,51 @@ func (suite *ServiceSuite) TestLoadBalance() { suite.Equal(resp.GetCode(), merr.Code(merr.ErrServiceNotReady)) } +func (suite *ServiceSuite) TestLoadBalanceWithNoDstNode() { + suite.loadAll() + ctx := context.Background() + server := suite.server + + // Test balancing with no dst node specified + for _, collection := range suite.collections { + replicas := suite.meta.ReplicaManager.GetByCollection(collection) + nodes := replicas[0].GetNodes() + srcNode := nodes[0] + suite.updateCollectionStatus(collection, querypb.LoadStatus_Loaded) + suite.updateSegmentDist(collection, srcNode) + segments := suite.getAllSegments(collection) + req := &querypb.LoadBalanceRequest{ + CollectionID: collection, + SourceNodeIDs: []int64{srcNode}, + SealedSegmentIDs: segments, + } + suite.taskScheduler.ExpectedCalls = make([]*mock.Call, 0) + suite.taskScheduler.EXPECT().Add(mock.Anything).Run(func(task task.Task) { + actions := task.Actions() + suite.Len(actions, 2) + growAction, reduceAction := actions[0], actions[1] + suite.Contains(nodes, growAction.Node()) + suite.Equal(srcNode, reduceAction.Node()) + task.Cancel(nil) + }).Return(nil) + resp, err := server.LoadBalance(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) + suite.taskScheduler.AssertExpectations(suite.T()) + } + + // Test when server is not healthy + server.UpdateStateCode(commonpb.StateCode_Initializing) + req := &querypb.LoadBalanceRequest{ + CollectionID: suite.collections[0], + SourceNodeIDs: []int64{1}, + DstNodeIDs: []int64{100 + 1}, + } + resp, err := server.LoadBalance(ctx, req) + suite.NoError(err) + suite.Equal(resp.GetCode(), merr.Code(merr.ErrServiceNotReady)) +} + func (suite *ServiceSuite) TestLoadBalanceWithEmptySegmentList() { suite.loadAll() ctx := context.Background() diff --git a/internal/querycoordv2/session/node_manager.go b/internal/querycoordv2/session/node_manager.go index 744ff9ed08..43799ae467 100644 --- a/internal/querycoordv2/session/node_manager.go +++ b/internal/querycoordv2/session/node_manager.go @@ -23,8 +23,11 @@ import ( "github.com/blang/semver/v4" "go.uber.org/atomic" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/merr" ) type Manager interface { @@ -33,6 +36,9 @@ type Manager interface { Remove(nodeID int64) Get(nodeID int64) *NodeInfo GetAll() []*NodeInfo + + 
Suspend(nodeID int64) error + Resume(nodeID int64) error + } type NodeManager struct { @@ -62,6 +68,42 @@ func (m *NodeManager) Stopping(nodeID int64) { } } +func (m *NodeManager) Suspend(nodeID int64) error { + m.mu.Lock() + defer m.mu.Unlock() + nodeInfo, ok := m.nodes[nodeID] + if !ok { + return merr.WrapErrNodeNotFound(nodeID) + } + switch nodeInfo.GetState() { + case NodeStateNormal: + nodeInfo.SetState(NodeStateSuspend) + return nil + default: + log.Warn("failed to suspend query node", zap.Int64("nodeID", nodeID), zap.String("state", nodeInfo.GetState().String())) + return merr.WrapErrNodeStateUnexpected(nodeID, nodeInfo.GetState().String(), "failed to suspend query node") + } +} + +func (m *NodeManager) Resume(nodeID int64) error { + m.mu.Lock() + defer m.mu.Unlock() + nodeInfo, ok := m.nodes[nodeID] + if !ok { + return merr.WrapErrNodeNotFound(nodeID) + } + + switch nodeInfo.GetState() { + case NodeStateSuspend: + nodeInfo.SetState(NodeStateNormal) + return nil + + default: + log.Warn("failed to resume query node", zap.Int64("nodeID", nodeID), zap.String("state", nodeInfo.GetState().String())) + return merr.WrapErrNodeStateUnexpected(nodeID, nodeInfo.GetState().String(), "failed to resume query node") + } +} + func (m *NodeManager) IsStoppingNode(nodeID int64) (bool, error) { m.mu.RLock() defer m.mu.RUnlock() @@ -98,8 +140,9 @@ func NewNodeManager() *NodeManager { type State int const ( - NodeStateNormal = iota - NodeStateStopping + NormalStateName = "Normal" + StoppingStateName = "Stopping" + SuspendStateName = "Suspend" ) type ImmutableNodeInfo struct { @@ -109,6 +152,22 @@ type ImmutableNodeInfo struct { Version semver.Version } +const ( + NodeStateNormal State = iota + NodeStateStopping + NodeStateSuspend +) + +var stateNameMap = map[State]string{ + NodeStateNormal: NormalStateName, + NodeStateStopping: StoppingStateName, + NodeStateSuspend: SuspendStateName, +} + +func (s State) String() string { + return stateNameMap[s] +} + type NodeInfo struct { stats mu sync.RWMutex @@ -161,6 +220,12 @@ func (n *NodeInfo) SetState(s State) { n.state = s } +func (n *NodeInfo) GetState() State { + n.mu.RLock() + defer n.mu.RUnlock() + return n.state +} + func (n *NodeInfo) UpdateStats(opts ...StatsOption) { n.mu.Lock() for _, opt := range opts {
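
Suspend/Resume above form a small state machine: Normal ⇄ Suspend, while Stopping can only be left through an explicit SetState. A sketch of the legal transitions; runnable only inside the repo tree, since the session package sits under internal/, and the node ID is illustrative:

package main

import (
	"errors"
	"fmt"

	"github.com/milvus-io/milvus/internal/querycoordv2/session"
	"github.com/milvus-io/milvus/pkg/util/merr"
)

func main() {
	nm := session.NewNodeManager()
	nm.Add(session.NewNodeInfo(session.ImmutableNodeInfo{
		NodeID:   1,
		Address:  "localhost",
		Hostname: "localhost",
	}))

	fmt.Println(nm.Suspend(1)) // <nil>: Normal -> Suspend is allowed
	// A second Suspend hits the default branch of the switch above.
	fmt.Println(errors.Is(nm.Suspend(1), merr.ErrNodeStateUnexpected)) // true
	fmt.Println(nm.Resume(1)) // <nil>: Suspend -> Normal is allowed
	// Resuming an already-Normal node is likewise rejected.
	fmt.Println(errors.Is(nm.Resume(1), merr.ErrNodeStateUnexpected)) // true
}
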
diff --git a/internal/querycoordv2/session/node_manager_test.go b/internal/querycoordv2/session/node_manager_test.go new file mode 100644 index 0000000000..fd49fa051f --- /dev/null +++ b/internal/querycoordv2/session/node_manager_test.go @@ -0,0 +1,110 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "testing" + "time" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type NodeManagerSuite struct { + suite.Suite + + nodeManager *NodeManager +} + +func (s *NodeManagerSuite) SetupTest() { + s.nodeManager = NewNodeManager() +} + +func (s *NodeManagerSuite) TearDownTest() { +} + +func (s *NodeManagerSuite) TestNodeOperation() { + s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + })) + s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{ + NodeID: 2, + Address: "localhost", + Hostname: "localhost", + })) + s.nodeManager.Add(NewNodeInfo(ImmutableNodeInfo{ + NodeID: 3, + Address: "localhost", + Hostname: "localhost", + })) + + s.NotNil(s.nodeManager.Get(1)) + s.Len(s.nodeManager.GetAll(), 3) + s.nodeManager.Remove(1) + s.Nil(s.nodeManager.Get(1)) + s.Len(s.nodeManager.GetAll(), 2) + + s.nodeManager.Stopping(2) + s.True(s.nodeManager.IsStoppingNode(2)) + err := s.nodeManager.Resume(2) + s.ErrorIs(err, merr.ErrNodeStateUnexpected) + s.True(s.nodeManager.IsStoppingNode(2)) + node := s.nodeManager.Get(2) + node.SetState(NodeStateNormal) + s.False(s.nodeManager.IsStoppingNode(2)) + + err = s.nodeManager.Resume(3) + s.ErrorIs(err, merr.ErrNodeStateUnexpected) + + s.NoError(s.nodeManager.Suspend(3)) + node = s.nodeManager.Get(3) + s.NotNil(node) + s.Equal(NodeStateSuspend, node.GetState()) + s.NoError(s.nodeManager.Resume(3)) + node = s.nodeManager.Get(3) + s.NotNil(node) + s.Equal(NodeStateNormal, node.GetState()) +} + +func (s *NodeManagerSuite) TestNodeInfo() { + node := NewNodeInfo(ImmutableNodeInfo{ + NodeID: 1, + Address: "localhost", + Hostname: "localhost", + }) + s.Equal(int64(1), node.ID()) + s.Equal("localhost", node.Addr()) + node.setChannelCnt(1) + node.setSegmentCnt(1) + s.Equal(1, node.ChannelCnt()) + s.Equal(1, node.SegmentCnt()) + + node.UpdateStats(WithSegmentCnt(5)) + node.UpdateStats(WithChannelCnt(5)) + s.Equal(5, node.ChannelCnt()) + s.Equal(5, node.SegmentCnt()) + + node.SetLastHeartbeat(time.Now()) + s.NotNil(node.LastHeartbeat()) +} + +func TestNodeManagerSuite(t *testing.T) { + suite.Run(t, new(NodeManagerSuite)) +} diff --git a/internal/querycoordv2/utils/checker.go b/internal/querycoordv2/utils/checker.go index 3c3bdeb31f..0234ff2e98 100644 --- a/internal/querycoordv2/utils/checker.go +++ b/internal/querycoordv2/utils/checker.go @@ -27,6 +27,7 @@ const ( BalanceCheckerName = "balance_checker" IndexCheckerName = "index_checker" LeaderCheckerName = "leader_checker" + ManualBalanceName = "manual_balance" ) type CheckerType int32 @@ -37,6 +38,7 @@ const ( BalanceChecker IndexChecker LeaderChecker + ManualBalance ) var checkerNames = map[CheckerType]string{ @@ -45,6 +47,7 @@ var checkerNames = map[CheckerType]string{ BalanceChecker: BalanceCheckerName, IndexChecker: IndexCheckerName, LeaderChecker: LeaderCheckerName, + ManualBalance: ManualBalanceName, } func (s CheckerType) String() string { diff --git a/internal/util/mock/grpc_querycoord_client.go b/internal/util/mock/grpc_querycoord_client.go index bde03927b8..89632c1a84 100644 --- a/internal/util/mock/grpc_querycoord_client.go +++ b/internal/util/mock/grpc_querycoord_client.go @@ -141,3 +141,39 @@ func (m *GrpcQueryCoordClient) ActivateChecker(ctx context.Context, in *querypb.
func (m *GrpcQueryCoordClient) DeactivateChecker(ctx context.Context, in *querypb.DeactivateCheckerRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { return &commonpb.Status{}, m.Err } + +func (m *GrpcQueryCoordClient) ListQueryNode(ctx context.Context, req *querypb.ListQueryNodeRequest, opts ...grpc.CallOption) (*querypb.ListQueryNodeResponse, error) { + return &querypb.ListQueryNodeResponse{}, m.Err +} + +func (m *GrpcQueryCoordClient) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQueryNodeDistributionRequest, opts ...grpc.CallOption) (*querypb.GetQueryNodeDistributionResponse, error) { + return &querypb.GetQueryNodeDistributionResponse{}, m.Err +} + +func (m *GrpcQueryCoordClient) SuspendBalance(ctx context.Context, req *querypb.SuspendBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *GrpcQueryCoordClient) ResumeBalance(ctx context.Context, req *querypb.ResumeBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *GrpcQueryCoordClient) SuspendNode(ctx context.Context, req *querypb.SuspendNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *GrpcQueryCoordClient) ResumeNode(ctx context.Context, req *querypb.ResumeNodeRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *GrpcQueryCoordClient) TransferSegment(ctx context.Context, req *querypb.TransferSegmentRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *GrpcQueryCoordClient) TransferChannel(ctx context.Context, req *querypb.TransferChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *GrpcQueryCoordClient) CheckQueryNodeDistribution(ctx context.Context, req *querypb.CheckQueryNodeDistributionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} diff --git a/pkg/util/merr/errors.go b/pkg/util/merr/errors.go index 9e91ee65d2..68af2f4230 100644 --- a/pkg/util/merr/errors.go +++ b/pkg/util/merr/errors.go @@ -90,11 +90,12 @@ var ( ErrDatabaseInvalidName = newMilvusError("invalid database name", 802, false) // Node related - ErrNodeNotFound = newMilvusError("node not found", 901, false) - ErrNodeOffline = newMilvusError("node offline", 902, false) - ErrNodeLack = newMilvusError("node lacks", 903, false) - ErrNodeNotMatch = newMilvusError("node not match", 904, false) - ErrNodeNotAvailable = newMilvusError("node not available", 905, false) + ErrNodeNotFound = newMilvusError("node not found", 901, false) + ErrNodeOffline = newMilvusError("node offline", 902, false) + ErrNodeLack = newMilvusError("node lacks", 903, false) + ErrNodeNotMatch = newMilvusError("node not match", 904, false) + ErrNodeNotAvailable = newMilvusError("node not available", 905, false) + ErrNodeStateUnexpected = newMilvusError("node state unexpected", 906, false) // IO related ErrIoKeyNotFound = newMilvusError("key not found", 1000, false)
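
The new error carries both the node ID and the offending state, so callers can branch on it with plain errors.Is and surface the 906 code. A small self-contained sketch (the node ID and state value are illustrative):

package main

import (
	"errors"
	"fmt"

	"github.com/milvus-io/milvus/pkg/util/merr"
)

func main() {
	// WrapErrNodeStateUnexpected attaches the node and its current state
	// as fields, then prefixes the caller-supplied message.
	err := merr.WrapErrNodeStateUnexpected(7, "Stopping", "failed to suspend node")

	fmt.Println(errors.Is(err, merr.ErrNodeStateUnexpected)) // true
	fmt.Println(merr.Code(err))                              // 906
	fmt.Println(err)                                         // wrapped message with node/state fields
}
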
diff --git a/pkg/util/merr/errors_test.go b/pkg/util/merr/errors_test.go index f44e6dee17..93dd6392ee 100644 --- a/pkg/util/merr/errors_test.go +++ b/pkg/util/merr/errors_test.go @@ -120,6 +120,7 @@ func (s *ErrSuite) TestWrap() { s.ErrorIs(WrapErrNodeNotFound(1, "failed to get node"), ErrNodeNotFound) s.ErrorIs(WrapErrNodeOffline(1, "failed to access node"), ErrNodeOffline) s.ErrorIs(WrapErrNodeLack(3, 1, "need more nodes"), ErrNodeLack) + s.ErrorIs(WrapErrNodeStateUnexpected(1, "Stopping", "failed to suspend node"), ErrNodeStateUnexpected) // IO related s.ErrorIs(WrapErrIoKeyNotFound("test_key", "failed to read"), ErrIoKeyNotFound) diff --git a/pkg/util/merr/utils.go b/pkg/util/merr/utils.go index 60a9bd427e..0b078e73de 100644 --- a/pkg/util/merr/utils.go +++ b/pkg/util/merr/utils.go @@ -731,6 +731,14 @@ func WrapErrNodeNotAvailable(id int64, msg ...string) error { return err } +func WrapErrNodeStateUnexpected(id int64, state string, msg ...string) error { + err := wrapFields(ErrNodeStateUnexpected, value("node", id), value("state", state)) + if len(msg) > 0 { + err = errors.Wrap(err, strings.Join(msg, "->")) + } + return err +} + func WrapErrNodeNotMatch(expectedNodeID, actualNodeID int64, msg ...string) error { err := wrapFields(ErrNodeNotMatch, value("expectedNodeID", expectedNodeID), diff --git a/tests/integration/rollingupgrade/manual_rolling_upgrade_test.go b/tests/integration/rollingupgrade/manual_rolling_upgrade_test.go new file mode 100644 index 0000000000..d071351b07 --- /dev/null +++ b/tests/integration/rollingupgrade/manual_rolling_upgrade_test.go @@ -0,0 +1,364 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package rollingupgrade + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/samber/lo" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/tests/integration" +) + +type ManualRollingUpgradeSuite struct { + integration.MiniClusterSuite +} + +func (s *ManualRollingUpgradeSuite) SetupSuite() { + paramtable.Init() + params := paramtable.Get() + params.Save(params.QueryCoordCfg.BalanceCheckInterval.Key, "2000") + + rand.Seed(time.Now().UnixNano()) + s.Require().NoError(s.SetupEmbedEtcd()) +} + +func (s *ManualRollingUpgradeSuite) TearDownSuite() { + params := paramtable.Get() + params.Reset(params.QueryCoordCfg.BalanceCheckInterval.Key) + + s.TearDownEmbedEtcd() +} + +func (s *ManualRollingUpgradeSuite) TestTransfer() { + c := s.Cluster + ctx, cancel := context.WithCancel(c.GetContext()) + defer cancel() + + prefix := "TestTransfer" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + dim := 128 + rowNum := 3000 + insertRound := 5 + + schema := integration.ConstructSchema(collectionName, dim, true) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: 2, + }) + s.NoError(err) + + err = merr.Error(createCollectionStatus) + if err != nil { + log.Warn("createCollectionStatus fail reason", zap.Error(err)) + } + + log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus)) + showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + s.True(merr.Ok(showCollectionsResp.GetStatus())) + log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) + + // insert data, then flush to generate sealed segments + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim) + hashKeys := integration.GenerateHashKeys(rowNum) + for i := range lo.Range(insertRound) { + insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + }) + s.NoError(err) + s.True(merr.Ok(insertResult.GetStatus())) + log.Info("Insert succeeded", zap.Int("round", i+1)) + resp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{ + DbName: dbName, + CollectionNames: []string{collectionName}, + }) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + } + + // create index + createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP), + }) + s.NoError(err) + err = merr.Error(createIndexStatus) + if err != nil { + log.Warn("createIndexStatus fail reason", zap.Error(err)) + } + + s.WaitForIndexBuilt(ctx, collectionName, 
integration.FloatVecField) + log.Info("Create index done") + + // load + loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + err = merr.Error(loadStatus) + if err != nil { + log.Warn("LoadCollection fail reason", zap.Error(err)) + } + s.WaitForLoad(ctx, collectionName) + log.Info("Load collection done") + + // suspend balance + resp2, err := s.Cluster.QueryCoord.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{}) + s.NoError(err) + s.True(merr.Ok(resp2)) + + // get origin qn + qnServer1 := s.Cluster.QueryNode + qn1 := qnServer1.GetQueryNode() + + // add new querynode + qnServer2 := s.Cluster.AddQueryNode() + time.Sleep(5 * time.Second) + qn2 := qnServer2.GetQueryNode() + + // expected 2 querynodes found + resp3, err := s.Cluster.QueryCoordClient.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{}) + s.NoError(err) + s.Len(resp3.GetNodeInfos(), 2) + + // since balance has been suspended, qn2 won't have any segment/channel distribution + resp4, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{ + NodeID: qn2.GetNodeID(), + }) + s.NoError(err) + s.Len(resp4.GetChannelNames(), 0) + s.Len(resp4.GetSealedSegmentIDs(), 0) + + resp5, err := s.Cluster.QueryCoordClient.TransferChannel(ctx, &querypb.TransferChannelRequest{ + SourceNodeID: qn1.GetNodeID(), + TargetNodeID: qn2.GetNodeID(), + TransferAll: true, + }) + s.NoError(err) + s.True(merr.Ok(resp5)) + + // wait for transfer channel done + s.Eventually(func() bool { + resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{ + NodeID: qn1.GetNodeID(), + }) + s.NoError(err) + return len(resp.GetChannelNames()) == 0 + }, 10*time.Second, 1*time.Second) + + // test transfer segment + resp6, err := s.Cluster.QueryCoordClient.TransferSegment(ctx, &querypb.TransferSegmentRequest{ + SourceNodeID: qn1.GetNodeID(), + TargetNodeID: qn2.GetNodeID(), + TransferAll: true, + }) + s.NoError(err) + s.True(merr.Ok(resp6)) + + // wait for transfer segment done + s.Eventually(func() bool { + resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{ + NodeID: qn1.GetNodeID(), + }) + s.NoError(err) + return len(resp.GetSealedSegmentIDs()) == 0 + }, 10*time.Second, 1*time.Second) + + // resume balance, segments/channels will be balanced back to qn1 + resp7, err := s.Cluster.QueryCoord.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{}) + s.NoError(err) + s.True(merr.Ok(resp7)) + + s.Eventually(func() bool { + resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{ + NodeID: qn1.GetNodeID(), + }) + s.NoError(err) + return len(resp.GetSealedSegmentIDs()) > 0 || len(resp.GetChannelNames()) > 0 + }, 10*time.Second, 1*time.Second) + + log.Info("==================") + log.Info("==================") + log.Info("TestTransfer succeeded") + log.Info("==================") + log.Info("==================") +} + +func (s *ManualRollingUpgradeSuite) TestSuspendNode() { + c := s.Cluster + ctx, cancel := context.WithCancel(c.GetContext()) + defer cancel() + + prefix := "TestSuspendNode" + dbName := "" + collectionName := prefix + funcutil.GenRandomStr() + dim := 128 + rowNum := 3000 + insertRound := 5 + + schema := integration.ConstructSchema(collectionName, dim, true) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := 
c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: 2, + }) + s.NoError(err) + + err = merr.Error(createCollectionStatus) + if err != nil { + log.Warn("createCollectionStatus fail reason", zap.Error(err)) + } + + log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus)) + showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + s.True(merr.Ok(showCollectionsResp.GetStatus())) + log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) + + // insert data, then flush to generate sealed segments + fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim) + hashKeys := integration.GenerateHashKeys(rowNum) + for i := range lo.Range(insertRound) { + insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{ + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + }) + s.NoError(err) + s.True(merr.Ok(insertResult.GetStatus())) + log.Info("Insert succeeded", zap.Int("round", i+1)) + resp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{ + DbName: dbName, + CollectionNames: []string{collectionName}, + }) + s.NoError(err) + s.True(merr.Ok(resp.GetStatus())) + } + + // create index + createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP), + }) + s.NoError(err) + err = merr.Error(createIndexStatus) + if err != nil { + log.Warn("createIndexStatus fail reason", zap.Error(err)) + } + + s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField) + log.Info("Create index done") + + // add new querynode + qnServer2 := s.Cluster.AddQueryNode() + time.Sleep(5 * time.Second) + qn2 := qnServer2.GetQueryNode() + + // expected 2 querynodes found + resp3, err := s.Cluster.QueryCoordClient.ListQueryNode(ctx, &querypb.ListQueryNodeRequest{}) + s.NoError(err) + s.Len(resp3.GetNodeInfos(), 2) + + // suspend node + resp2, err := s.Cluster.QueryCoord.SuspendNode(ctx, &querypb.SuspendNodeRequest{ + NodeID: qn2.GetNodeID(), + }) + s.NoError(err) + s.True(merr.Ok(resp2)) + + // load + loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + }) + s.NoError(err) + err = merr.Error(loadStatus) + if err != nil { + log.Warn("LoadCollection fail reason", zap.Error(err)) + } + s.WaitForLoad(ctx, collectionName) + log.Info("Load collection done") + + // since the node has been suspended, no segment/channel will be loaded onto this qn + resp4, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{ + NodeID: qn2.GetNodeID(), + }) + s.NoError(err) + s.Len(resp4.GetChannelNames(), 0) + s.Len(resp4.GetSealedSegmentIDs(), 0) + + // resume node, segments/channels will be balanced to qn2 + resp5, err := s.Cluster.QueryCoord.ResumeNode(ctx, &querypb.ResumeNodeRequest{ + NodeID: qn2.GetNodeID(), + }) + s.NoError(err) + s.True(merr.Ok(resp5)) + + s.Eventually(func() bool { + resp, err := s.Cluster.QueryCoordClient.GetQueryNodeDistribution(ctx, &querypb.GetQueryNodeDistributionRequest{ + NodeID: qn2.GetNodeID(), + }) + s.NoError(err) + return 
len(resp.GetSealedSegmentIDs()) > 0 || len(resp.GetChannelNames()) > 0 + }, 10*time.Second, 1*time.Second) + + log.Info("==================") + log.Info("==================") + log.Info("TestSuspendNode succeeded") + log.Info("==================") + log.Info("==================") +} + +func TestManualRollingUpgrade(t *testing.T) { + suite.Run(t, new(ManualRollingUpgradeSuite)) +}
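
End to end, the two tests above encode the operator workflow this patch is built for. One plausible sequence, sketched with statuses and error handling elided; qc is assumed to be a connected querypb.QueryCoordClient, and oldNode/newNode are illustrative IDs:

// 1. Freeze automatic balancing so the balancer does not fight the operator.
qc.SuspendBalance(ctx, &querypb.SuspendBalanceRequest{})

// 2. Mirror the old node's channels and segments onto the upgraded node.
//    CopyMode keeps the source serving while the destination warms up.
qc.TransferChannel(ctx, &querypb.TransferChannelRequest{
	SourceNodeID: oldNode, TargetNodeID: newNode, TransferAll: true, CopyMode: true,
})
qc.TransferSegment(ctx, &querypb.TransferSegmentRequest{
	SourceNodeID: oldNode, TargetNodeID: newNode, TransferAll: true, CopyMode: true,
})

// 3. Verify the two distributions match before stopping the old node.
qc.CheckQueryNodeDistribution(ctx, &querypb.CheckQueryNodeDistributionRequest{
	SourceNodeID: oldNode, TargetNodeID: newNode,
})

// 4. Hand control back to the automatic balancer.
qc.ResumeBalance(ctx, &querypb.ResumeBalanceRequest{})
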