diff --git a/internal/datacoord/cluster_test.go b/internal/datacoord/cluster_test.go index cef5f5f057..98bc9e0dec 100644 --- a/internal/datacoord/cluster_test.go +++ b/internal/datacoord/cluster_test.go @@ -23,9 +23,11 @@ import ( "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" + "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "stathat.com/c/consistent" "github.com/milvus-io/milvus/internal/kv" @@ -33,6 +35,7 @@ import ( "github.com/milvus-io/milvus/internal/kv/mocks" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/metrics" ) func getMetaKv(t *testing.T) kv.MetaKv { @@ -51,21 +54,42 @@ func getWatchKV(t *testing.T) kv.WatchKV { return kv } -func TestClusterCreate(t *testing.T) { - kv := getWatchKV(t) - defer func() { - kv.RemoveWithPrefix("") - kv.Close() - }() +type ClusterSuite struct { + suite.Suite - t.Run("startup normally", func(t *testing.T) { + kv kv.WatchKV +} +func (suite *ClusterSuite) getWatchKV() kv.WatchKV { + rootPath := "/etcd/test/root/" + suite.T().Name() + kv, err := etcdkv.NewWatchKVFactory(rootPath, &Params.EtcdCfg) + suite.Require().NoError(err) + + return kv +} + +func (suite *ClusterSuite) SetupTest() { + kv := getWatchKV(suite.T()) + suite.kv = kv +} + +func (suite *ClusterSuite) TearDownTest() { + if suite.kv != nil { + suite.kv.RemoveWithPrefix("") + suite.kv.Close() + } +} + +func (suite *ClusterSuite) TestCreate() { + kv := suite.kv + + suite.Run("startup_normally", func() { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() sessionManager := NewSessionManager() channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) + suite.NoError(err) cluster := NewCluster(sessionManager, channelManager) defer cluster.Close() addr := "localhost:8080" @@ -75,15 +99,15 @@ func TestClusterCreate(t *testing.T) { } nodes := []*NodeInfo{info} err = cluster.Startup(ctx, nodes) - assert.NoError(t, err) + suite.NoError(err) dataNodes := sessionManager.GetSessions() - assert.EqualValues(t, 1, len(dataNodes)) - assert.EqualValues(t, "localhost:8080", dataNodes[0].info.Address) + suite.EqualValues(1, len(dataNodes)) + suite.EqualValues("localhost:8080", dataNodes[0].info.Address) + + suite.Equal(float64(1), testutil.ToFloat64(metrics.DataCoordNumDataNodes)) }) - t.Run("startup with existed channel data", func(t *testing.T) { - defer kv.RemoveWithPrefix("") - + suite.Run("startup_with_existed_channel_data", func() { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -95,24 +119,24 @@ func TestClusterCreate(t *testing.T) { }, } info1Data, err := proto.Marshal(info1) - assert.NoError(t, err) + suite.NoError(err) err = kv.Save(Params.CommonCfg.DataCoordWatchSubPath.GetValue()+"/1/channel1", string(info1Data)) - assert.NoError(t, err) + suite.NoError(err) sessionManager := NewSessionManager() channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) + suite.NoError(err) cluster := NewCluster(sessionManager, channelManager) defer cluster.Close() err = cluster.Startup(ctx, []*NodeInfo{{NodeID: 1, Address: "localhost:9999"}}) - assert.NoError(t, err) + suite.NoError(err) channels := channelManager.GetChannels() - assert.EqualValues(t, []*NodeChannelInfo{{1, []*channel{{Name: "channel1", CollectionID: 1}}}}, channels) + suite.EqualValues([]*NodeChannelInfo{{1, []*channel{{Name: "channel1", CollectionID: 1}}}}, channels) }) - t.Run("remove all nodes and restart with other nodes", func(t *testing.T) { + suite.Run("remove_all_nodes_and_restart_with_other_nodes", func() { defer kv.RemoveWithPrefix("") ctx, cancel := context.WithCancel(context.TODO()) @@ -120,7 +144,7 @@ func TestClusterCreate(t *testing.T) { sessionManager := NewSessionManager() channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) + suite.NoError(err) cluster := NewCluster(sessionManager, channelManager) addr := "localhost:8080" @@ -130,18 +154,18 @@ func TestClusterCreate(t *testing.T) { } nodes := []*NodeInfo{info} err = cluster.Startup(ctx, nodes) - assert.NoError(t, err) + suite.NoError(err) err = cluster.UnRegister(info) - assert.NoError(t, err) + suite.NoError(err) sessions := sessionManager.GetSessions() - assert.Empty(t, sessions) + suite.Empty(sessions) cluster.Close() sessionManager2 := NewSessionManager() channelManager2, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) + suite.NoError(err) clusterReload := NewCluster(sessionManager2, channelManager2) defer clusterReload.Close() @@ -152,34 +176,30 @@ func TestClusterCreate(t *testing.T) { } nodes = []*NodeInfo{info} err = clusterReload.Startup(ctx, nodes) - assert.NoError(t, err) + suite.NoError(err) sessions = sessionManager2.GetSessions() - assert.EqualValues(t, 1, len(sessions)) - assert.EqualValues(t, 2, sessions[0].info.NodeID) - assert.EqualValues(t, addr, sessions[0].info.Address) + suite.EqualValues(1, len(sessions)) + suite.EqualValues(2, sessions[0].info.NodeID) + suite.EqualValues(addr, sessions[0].info.Address) channels := channelManager2.GetChannels() - assert.EqualValues(t, 1, len(channels)) - assert.EqualValues(t, 2, channels[0].NodeID) + suite.EqualValues(1, len(channels)) + suite.EqualValues(2, channels[0].NodeID) }) - t.Run("loadKv Fails", func(t *testing.T) { + suite.Run("loadkv_fails", func() { defer kv.RemoveWithPrefix("") - metakv := mocks.NewWatchKV(t) + metakv := mocks.NewWatchKV(suite.T()) metakv.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, errors.New("failed")) _, err := NewChannelManager(metakv, newMockHandler()) - assert.Error(t, err) + suite.Error(err) }) } -func TestRegister(t *testing.T) { - kv := getWatchKV(t) - defer func() { - kv.RemoveWithPrefix("") - kv.Close() - }() +func (suite *ClusterSuite) TestRegister() { + kv := suite.kv - t.Run("register to empty cluster", func(t *testing.T) { + suite.Run("register_to_empty_cluster", func() { defer kv.RemoveWithPrefix("") ctx, cancel := context.WithCancel(context.TODO()) @@ -187,24 +207,26 @@ func TestRegister(t *testing.T) { sessionManager := NewSessionManager() channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) + suite.NoError(err) cluster := NewCluster(sessionManager, channelManager) defer cluster.Close() addr := "localhost:8080" err = cluster.Startup(ctx, nil) - assert.NoError(t, err) + suite.NoError(err) info := &NodeInfo{ NodeID: 1, Address: addr, } err = cluster.Register(info) - assert.NoError(t, err) + suite.NoError(err) sessions := sessionManager.GetSessions() - assert.EqualValues(t, 1, len(sessions)) - assert.EqualValues(t, "localhost:8080", sessions[0].info.Address) + suite.EqualValues(1, len(sessions)) + suite.EqualValues("localhost:8080", sessions[0].info.Address) + + suite.Equal(float64(1), testutil.ToFloat64(metrics.DataCoordNumDataNodes)) }) - t.Run("register to empty cluster with buffer channels", func(t *testing.T) { + suite.Run("register_to_empty_cluster_with_buffer_channels", func() { defer kv.RemoveWithPrefix("") ctx, cancel := context.WithCancel(context.TODO()) @@ -212,32 +234,34 @@ func TestRegister(t *testing.T) { sessionManager := NewSessionManager() channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) + suite.NoError(err) err = channelManager.Watch(&channel{ Name: "ch1", CollectionID: 0, }) - assert.NoError(t, err) + suite.NoError(err) cluster := NewCluster(sessionManager, channelManager) defer cluster.Close() addr := "localhost:8080" err = cluster.Startup(ctx, nil) - assert.NoError(t, err) + suite.NoError(err) info := &NodeInfo{ NodeID: 1, Address: addr, } err = cluster.Register(info) - assert.NoError(t, err) + suite.NoError(err) bufferChannels := channelManager.GetBufferChannels() - assert.Empty(t, bufferChannels.Channels) + suite.Empty(bufferChannels.Channels) nodeChannels := channelManager.GetChannels() - assert.EqualValues(t, 1, len(nodeChannels)) - assert.EqualValues(t, 1, nodeChannels[0].NodeID) - assert.EqualValues(t, "ch1", nodeChannels[0].Channels[0].Name) + suite.EqualValues(1, len(nodeChannels)) + suite.EqualValues(1, nodeChannels[0].NodeID) + suite.EqualValues("ch1", nodeChannels[0].Channels[0].Name) + + suite.Equal(float64(1), testutil.ToFloat64(metrics.DataCoordNumDataNodes)) }) - t.Run("register and restart with no channel", func(t *testing.T) { + suite.Run("register_and_restart_with_no_channel", func() { defer kv.RemoveWithPrefix("") ctx, cancel := context.WithCancel(context.TODO()) @@ -245,29 +269,138 @@ func TestRegister(t *testing.T) { sessionManager := NewSessionManager() channelManager, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) + suite.NoError(err) cluster := NewCluster(sessionManager, channelManager) addr := "localhost:8080" err = cluster.Startup(ctx, nil) - assert.NoError(t, err) + suite.NoError(err) info := &NodeInfo{ NodeID: 1, Address: addr, } err = cluster.Register(info) - assert.NoError(t, err) + suite.NoError(err) cluster.Close() sessionManager2 := NewSessionManager() channelManager2, err := NewChannelManager(kv, newMockHandler()) - assert.NoError(t, err) + suite.NoError(err) restartCluster := NewCluster(sessionManager2, channelManager2) defer restartCluster.Close() channels := channelManager2.GetChannels() - assert.Empty(t, channels) + suite.Empty(channels) + + suite.Equal(float64(1), testutil.ToFloat64(metrics.DataCoordNumDataNodes)) }) } +func (suite *ClusterSuite) TestUnregister() { + kv := suite.kv + + suite.Run("remove_node_after_unregister", func() { + defer kv.RemoveWithPrefix("") + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + sessionManager := NewSessionManager() + channelManager, err := NewChannelManager(kv, newMockHandler()) + suite.NoError(err) + cluster := NewCluster(sessionManager, channelManager) + defer cluster.Close() + addr := "localhost:8080" + info := &NodeInfo{ + Address: addr, + NodeID: 1, + } + nodes := []*NodeInfo{info} + err = cluster.Startup(ctx, nodes) + suite.NoError(err) + err = cluster.UnRegister(nodes[0]) + suite.NoError(err) + sessions := sessionManager.GetSessions() + suite.Empty(sessions) + + suite.Equal(float64(0), testutil.ToFloat64(metrics.DataCoordNumDataNodes)) + }) + + suite.Run("move_channel_to_online_nodes_after_unregister", func() { + defer kv.RemoveWithPrefix("") + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + sessionManager := NewSessionManager() + channelManager, err := NewChannelManager(kv, newMockHandler()) + suite.NoError(err) + cluster := NewCluster(sessionManager, channelManager) + defer cluster.Close() + + nodeInfo1 := &NodeInfo{ + Address: "localhost:8080", + NodeID: 1, + } + nodeInfo2 := &NodeInfo{ + Address: "localhost:8081", + NodeID: 2, + } + nodes := []*NodeInfo{nodeInfo1, nodeInfo2} + err = cluster.Startup(ctx, nodes) + suite.NoError(err) + err = cluster.Watch("ch1", 1) + suite.NoError(err) + err = cluster.UnRegister(nodeInfo1) + suite.NoError(err) + + channels := channelManager.GetChannels() + suite.EqualValues(1, len(channels)) + suite.EqualValues(2, channels[0].NodeID) + suite.EqualValues(1, len(channels[0].Channels)) + suite.EqualValues("ch1", channels[0].Channels[0].Name) + + suite.Equal(float64(1), testutil.ToFloat64(metrics.DataCoordNumDataNodes)) + }) + + suite.Run("remove all channels after unregsiter", func() { + defer kv.RemoveWithPrefix("") + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + var mockSessionCreator = func(ctx context.Context, addr string) (types.DataNode, error) { + return newMockDataNodeClient(1, nil) + } + sessionManager := NewSessionManager(withSessionCreator(mockSessionCreator)) + channelManager, err := NewChannelManager(kv, newMockHandler()) + suite.NoError(err) + cluster := NewCluster(sessionManager, channelManager) + defer cluster.Close() + + nodeInfo := &NodeInfo{ + Address: "localhost:8080", + NodeID: 1, + } + err = cluster.Startup(ctx, []*NodeInfo{nodeInfo}) + suite.NoError(err) + err = cluster.Watch("ch_1", 1) + suite.NoError(err) + err = cluster.UnRegister(nodeInfo) + suite.NoError(err) + channels := channelManager.GetChannels() + suite.Empty(channels) + channel := channelManager.GetBufferChannels() + suite.NotNil(channel) + suite.EqualValues(1, len(channel.Channels)) + suite.EqualValues("ch_1", channel.Channels[0].Name) + + suite.Equal(float64(0), testutil.ToFloat64(metrics.DataCoordNumDataNodes)) + }) +} + +func TestCluster(t *testing.T) { + suite.Run(t, new(ClusterSuite)) +} + func TestUnregister(t *testing.T) { kv := getWatchKV(t) defer func() {