From 735b02cf2bae647e97590f89390cc4e2cc40c247 Mon Sep 17 00:00:00 2001 From: groot Date: Sat, 9 Oct 2021 10:10:59 +0800 Subject: [PATCH] Add unittest for distributed/datanode (#9503) Signed-off-by: yhmo --- internal/datanode/data_node.go | 14 +- internal/distributed/datanode/service.go | 10 +- internal/distributed/datanode/service_test.go | 272 ++++++++++++++++++ internal/types/types.go | 22 ++ 4 files changed, 311 insertions(+), 7 deletions(-) create mode 100644 internal/distributed/datanode/service_test.go diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index 6fdac99049..16d7be94ef 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -138,7 +138,7 @@ func NewDataNode(ctx context.Context, factory msgstream.Factory) *DataNode { } // SetRootCoordInterface sets RootCoord's grpc client, error is returned if repeatedly set. -func (node *DataNode) SetRootCoordInterface(rc types.RootCoord) error { +func (node *DataNode) SetRootCoord(rc types.RootCoord) error { switch { case rc == nil, node.rootCoord != nil: return errors.New("Nil parameter or repeatly set") @@ -149,7 +149,7 @@ func (node *DataNode) SetRootCoordInterface(rc types.RootCoord) error { } // SetDataCoordInterface sets data service's grpc client, error is returned if repeatedly set. -func (node *DataNode) SetDataCoordInterface(ds types.DataCoord) error { +func (node *DataNode) SetDataCoord(ds types.DataCoord) error { switch { case ds == nil, node.dataCoord != nil: return errors.New("Nil parameter or repeatly set") @@ -159,6 +159,11 @@ func (node *DataNode) SetDataCoordInterface(ds types.DataCoord) error { } } +// SetNodeID set node id for DataNode +func (node *DataNode) SetNodeID(id UniqueID) { + node.NodeID = id +} + // Register register datanode to etcd func (node *DataNode) Register() error { node.session = sessionutil.NewSession(node.ctx, Params.MetaRootPath, Params.EtcdEndpoints) @@ -391,6 +396,11 @@ func (node *DataNode) UpdateStateCode(code internalpb.StateCode) { node.State.Store(code) } +// GetStateCode return datanode's state code +func (node *DataNode) GetStateCode() internalpb.StateCode { + return node.State.Load().(internalpb.StateCode) +} + func (node *DataNode) isHealthy() bool { code := node.State.Load().(internalpb.StateCode) return code == internalpb.StateCode_Healthy diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index e9b609f31d..1fb23328d6 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -41,7 +41,7 @@ import ( ) type Server struct { - datanode *dn.DataNode + datanode types.DataNodeComponent wg sync.WaitGroup grpcErrChan chan error grpcServer *grpc.Server @@ -113,11 +113,11 @@ func (s *Server) startGrpcLoop(listener net.Listener) { } func (s *Server) SetRootCoordInterface(ms types.RootCoord) error { - return s.datanode.SetRootCoordInterface(ms) + return s.datanode.SetRootCoord(ms) } func (s *Server) SetDataCoordInterface(ds types.DataCoord) error { - return s.datanode.SetDataCoordInterface(ds) + return s.datanode.SetDataCoord(ds) } func (s *Server) Run() error { @@ -240,7 +240,7 @@ func (s *Server) init() error { } } - s.datanode.NodeID = dn.Params.NodeID + s.datanode.SetNodeID(dn.Params.NodeID) s.datanode.UpdateStateCode(internalpb.StateCode_Initializing) if err := s.datanode.Init(); err != nil { @@ -276,7 +276,7 @@ func (s *Server) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannel } func (s *Server) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) { - if s.datanode.State.Load().(internalpb.StateCode) != internalpb.StateCode_Healthy { + if s.datanode.GetStateCode() != internalpb.StateCode_Healthy { return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, Reason: "DataNode isn't healthy.", diff --git a/internal/distributed/datanode/service_test.go b/internal/distributed/datanode/service_test.go new file mode 100644 index 0000000000..81ac5ac749 --- /dev/null +++ b/internal/distributed/datanode/service_test.go @@ -0,0 +1,272 @@ +// Copyright (C) 2019-2020 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +package grpcdatanode + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/stretchr/testify/assert" +) + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +type MockDataNode struct { + nodeID typeutil.UniqueID + + stateCode internalpb.StateCode + states *internalpb.ComponentStates + status *commonpb.Status + err error + initErr error + startErr error + stopErr error + regErr error + strResp *milvuspb.StringResponse + metricResp *milvuspb.GetMetricsResponse +} + +func (m *MockDataNode) Init() error { + return m.initErr +} + +func (m *MockDataNode) Start() error { + return m.startErr +} + +func (m *MockDataNode) Stop() error { + return m.stopErr +} + +func (m *MockDataNode) Register() error { + return m.regErr +} + +func (m *MockDataNode) SetNodeID(id typeutil.UniqueID) { + m.nodeID = id +} + +func (m *MockDataNode) UpdateStateCode(code internalpb.StateCode) { + m.stateCode = code +} + +func (m *MockDataNode) GetStateCode() internalpb.StateCode { + return m.stateCode +} + +func (m *MockDataNode) SetRootCoord(rc types.RootCoord) error { + return m.err +} + +func (m *MockDataNode) SetDataCoord(dc types.DataCoord) error { + return m.err +} + +func (m *MockDataNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { + return m.states, m.err +} + +func (m *MockDataNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { + return m.strResp, m.err +} + +func (m *MockDataNode) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) { + return m.status, m.err +} + +func (m *MockDataNode) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) { + return m.status, m.err +} + +func (m *MockDataNode) GetMetrics(ctx context.Context, request *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + return m.metricResp, m.err +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +type mockDataCoord struct { + types.DataCoord +} + +func (m *mockDataCoord) Init() error { + return nil +} +func (m *mockDataCoord) Start() error { + return nil +} +func (m *mockDataCoord) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{ + State: &internalpb.ComponentInfo{ + StateCode: internalpb.StateCode_Healthy, + }, + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + SubcomponentStates: []*internalpb.ComponentInfo{ + { + StateCode: internalpb.StateCode_Healthy, + }, + }, + }, nil +} +func (m *mockDataCoord) Stop() error { + return fmt.Errorf("stop error") +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +type mockRootCoord struct { + types.RootCoord +} + +func (m *mockRootCoord) Init() error { + return nil +} +func (m *mockRootCoord) Start() error { + return nil +} +func (m *mockRootCoord) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{ + State: &internalpb.ComponentInfo{ + StateCode: internalpb.StateCode_Healthy, + }, + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + SubcomponentStates: []*internalpb.ComponentInfo{ + { + StateCode: internalpb.StateCode_Healthy, + }, + }, + }, nil +} +func (m *mockRootCoord) Stop() error { + return fmt.Errorf("stop error") +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +func Test_NewServer(t *testing.T) { + ctx := context.Background() + server, err := NewServer(ctx, nil) + assert.Nil(t, err) + assert.NotNil(t, server) + + server.newRootCoordClient = func(string, []string) (types.RootCoord, error) { + return &mockRootCoord{}, nil + } + + server.newDataCoordClient = func(string, []string) (types.DataCoord, error) { + return &mockDataCoord{}, nil + } + + t.Run("Run", func(t *testing.T) { + server.datanode = &MockDataNode{} + err = server.Run() + assert.Nil(t, err) + }) + + t.Run("GetComponentStates", func(t *testing.T) { + server.datanode = &MockDataNode{ + states: &internalpb.ComponentStates{}, + } + states, err := server.GetComponentStates(ctx, nil) + assert.Nil(t, err) + assert.NotNil(t, states) + }) + + t.Run("GetStatisticsChannel", func(t *testing.T) { + server.datanode = &MockDataNode{ + strResp: &milvuspb.StringResponse{}, + } + states, err := server.GetStatisticsChannel(ctx, nil) + assert.Nil(t, err) + assert.NotNil(t, states) + }) + + t.Run("WatchDmChannels", func(t *testing.T) { + server.datanode = &MockDataNode{ + status: &commonpb.Status{}, + } + states, err := server.WatchDmChannels(ctx, nil) + assert.Nil(t, err) + assert.NotNil(t, states) + }) + + t.Run("FlushSegments", func(t *testing.T) { + server.datanode = &MockDataNode{ + status: &commonpb.Status{}, + } + states, err := server.FlushSegments(ctx, nil) + assert.NotNil(t, err) + assert.NotNil(t, states) + }) + + t.Run("GetMetrics", func(t *testing.T) { + server.datanode = &MockDataNode{ + metricResp: &milvuspb.GetMetricsResponse{}, + } + resp, err := server.GetMetrics(ctx, nil) + assert.Nil(t, err) + assert.NotNil(t, resp) + }) + + err = server.Stop() + assert.Nil(t, err) +} + +func Test_Run(t *testing.T) { + ctx := context.Background() + server, err := NewServer(ctx, nil) + assert.Nil(t, err) + assert.NotNil(t, server) + + server.datanode = &MockDataNode{ + regErr: errors.New("error"), + } + + server.newRootCoordClient = func(string, []string) (types.RootCoord, error) { + return &mockRootCoord{}, nil + } + + server.newDataCoordClient = func(string, []string) (types.DataCoord, error) { + return &mockDataCoord{}, nil + } + + err = server.Run() + assert.Error(t, err) + + server.datanode = &MockDataNode{ + startErr: errors.New("error"), + } + + err = server.Run() + assert.Error(t, err) + + server.datanode = &MockDataNode{ + initErr: errors.New("error"), + } + + err = server.Run() + assert.Error(t, err) + + server.datanode = &MockDataNode{ + stopErr: errors.New("error"), + } + + err = server.Stop() + assert.Error(t, err) +} diff --git a/internal/types/types.go b/internal/types/types.go index 3fa18c285c..404a49b701 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -23,6 +23,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" ) // TimeTickProvider is the interface all services implement @@ -59,6 +60,27 @@ type DataNode interface { GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) } +// DataNodeComponent is used by grpc server of DataNode +type DataNodeComponent interface { + DataNode + + // UpdateStateCode updates state code for DataNode + // State includes: Initializing, Healthy and Abnormal + UpdateStateCode(internalpb.StateCode) + + // GetStateCode return state code for DataNode + GetStateCode() internalpb.StateCode + + // SetRootCoord set RootCoord for DataNode + SetRootCoord(RootCoord) error + + // SetDataCoord set DataCoord for DataNode + SetDataCoord(DataCoord) error + + // SetNodeID set node id for DataNode + SetNodeID(typeutil.UniqueID) +} + // DataCoord is the interface `datacoord` package implements type DataCoord interface { Component