diff --git a/internal/dataservice/cluster.go b/internal/dataservice/cluster.go index 8169f991de..5cf65856e0 100644 --- a/internal/dataservice/cluster.go +++ b/internal/dataservice/cluster.go @@ -11,153 +11,226 @@ package dataservice import ( - "context" - "errors" - "fmt" "sync" - "go.uber.org/zap" - "github.com/milvus-io/milvus/internal/log" - "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/internalpb" + "go.uber.org/zap" + "golang.org/x/net/context" ) -type dataNode struct { - id int64 - address struct { - ip string - port int64 - } - client types.DataNode - channelNum int -} -type dataNodeCluster struct { - sync.RWMutex - nodes []*dataNode +type cluster struct { + mu sync.RWMutex + ctx context.Context + dataManager *clusterNodeManager + sessionManager sessionManager + + startupPolicy clusterStartupPolicy + registerPolicy dataNodeRegisterPolicy + unregisterPolicy dataNodeUnregisterPolicy + assginPolicy channelAssignPolicy } -func (node *dataNode) String() string { - return fmt.Sprintf("id: %d, address: %s:%d", node.id, node.address.ip, node.address.port) +type clusterOption struct { + apply func(c *cluster) } -func newDataNodeCluster() *dataNodeCluster { - return &dataNodeCluster{ - nodes: make([]*dataNode, 0), +func withStartupPolicy(p clusterStartupPolicy) clusterOption { + return clusterOption{ + apply: func(c *cluster) { c.startupPolicy = p }, } } -func (c *dataNodeCluster) Register(dataNode *dataNode) error { - c.Lock() - defer c.Unlock() - if c.checkDataNodeNotExist(dataNode.address.ip, dataNode.address.port) { - c.nodes = append(c.nodes, dataNode) - return nil +func withRegisterPolicy(p dataNodeRegisterPolicy) clusterOption { + return clusterOption{ + apply: func(c *cluster) { c.registerPolicy = p }, } - return errors.New("datanode already exist") } -func (c *dataNodeCluster) checkDataNodeNotExist(ip string, port int64) bool { - for _, node := range c.nodes { - if node.address.ip == ip && node.address.port == port { - return false +func withUnregistorPolicy(p dataNodeUnregisterPolicy) clusterOption { + return clusterOption{ + apply: func(c *cluster) { c.unregisterPolicy = p }, + } +} + +func withAssignPolicy(p channelAssignPolicy) clusterOption { + return clusterOption{ + apply: func(c *cluster) { c.assginPolicy = p }, + } +} + +func defaultStartupPolicy() clusterStartupPolicy { + return newReWatchOnRestartsStartupPolicy() +} + +func defaultRegisterPolicy() dataNodeRegisterPolicy { + return newDoNothingRegisterPolicy() +} + +func defaultUnregisterPolicy() dataNodeUnregisterPolicy { + return newDoNothingUnregisterPolicy() +} + +func defaultAssignPolicy() channelAssignPolicy { + return newAllAssignPolicy() +} + +func newCluster(ctx context.Context, dataManager *clusterNodeManager, sessionManager sessionManager, opts ...clusterOption) *cluster { + c := &cluster{ + ctx: ctx, + sessionManager: sessionManager, + dataManager: dataManager, + startupPolicy: defaultStartupPolicy(), + registerPolicy: defaultRegisterPolicy(), + unregisterPolicy: defaultUnregisterPolicy(), + assginPolicy: defaultAssignPolicy(), + } + for _, opt := range opts { + opt.apply(c) + } + + return c +} + +func (c *cluster) startup(dataNodes []*datapb.DataNodeInfo) error { + deltaChange := c.dataManager.updateCluster(dataNodes) + nodes := c.dataManager.getDataNodes(false) + rets := c.startupPolicy.apply(nodes, deltaChange) + c.dataManager.updateDataNodes(rets) + rets = c.watch(rets) + c.dataManager.updateDataNodes(rets) + return nil +} + +func (c *cluster) watch(nodes []*datapb.DataNodeInfo) []*datapb.DataNodeInfo { + for _, n := range nodes { + uncompletes := make([]string, 0) + for _, ch := range n.Channels { + if ch.State == datapb.ChannelWatchState_Uncomplete { + uncompletes = append(uncompletes, ch.Name) + } } - } - return true -} - -func (c *dataNodeCluster) GetNumOfNodes() int { - c.RLock() - defer c.RUnlock() - return len(c.nodes) -} - -func (c *dataNodeCluster) GetNodeIDs() []int64 { - c.RLock() - defer c.RUnlock() - ret := make([]int64, 0, len(c.nodes)) - for _, node := range c.nodes { - ret = append(ret, node.id) - } - return ret -} - -func (c *dataNodeCluster) WatchInsertChannels(channels []string) { - ctx := context.TODO() - c.Lock() - defer c.Unlock() - var groups [][]string - if len(channels) < len(c.nodes) { - groups = make([][]string, len(channels)) - } else { - groups = make([][]string, len(c.nodes)) - } - length := len(groups) - for i, channel := range channels { - groups[i%length] = append(groups[i%length], channel) - } - for i, group := range groups { - resp, err := c.nodes[i].client.WatchDmChannels(ctx, &datapb.WatchDmChannelsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_DescribeCollection, - MsgID: -1, // todo - Timestamp: 0, // todo - SourceID: Params.NodeID, - }, - // ChannelNames: group, // TODO - }) - if err = VerifyResponse(resp, err); err != nil { - log.Error("watch dm channels error", zap.Stringer("dataNode", c.nodes[i]), zap.Error(err)) - continue - } - c.nodes[i].channelNum += len(group) - } -} - -func (c *dataNodeCluster) GetDataNodeStates(ctx context.Context) ([]*internalpb.ComponentInfo, error) { - c.RLock() - defer c.RUnlock() - ret := make([]*internalpb.ComponentInfo, 0) - for _, node := range c.nodes { - states, err := node.client.GetComponentStates(ctx) + cli, err := c.sessionManager.getOrCreateSession(n.Address) if err != nil { - log.Error("get component states error", zap.Stringer("dataNode", node), zap.Error(err)) + log.Warn("get session failed", zap.String("addr", n.Address), zap.Error(err)) continue } - ret = append(ret, states.State) + req := &datapb.WatchDmChannelsRequest{ + Base: &commonpb.MsgBase{ + SourceID: Params.NodeID, + }, + //ChannelNames: uncompletes, + } + resp, err := cli.WatchDmChannels(c.ctx, req) + if err != nil { + log.Warn("watch dm channel failed", zap.String("addr", n.Address), zap.Error(err)) + continue + } + if resp.ErrorCode != commonpb.ErrorCode_Success { + log.Warn("watch channels failed", zap.String("address", n.Address), zap.Error(err)) + continue + } + for _, ch := range n.Channels { + if ch.State == datapb.ChannelWatchState_Uncomplete { + ch.State = datapb.ChannelWatchState_Complete + } + } } - return ret, nil + return nodes } -func (c *dataNodeCluster) FlushSegment(request *datapb.FlushSegmentsRequest) { - ctx := context.TODO() - c.Lock() - defer c.Unlock() - for _, node := range c.nodes { - if _, err := node.client.FlushSegments(ctx, request); err != nil { - log.Error("flush segment err", zap.Stringer("dataNode", node), zap.Error(err)) +func (c *cluster) register(n *datapb.DataNodeInfo) { + c.mu.Lock() + defer c.mu.Unlock() + c.dataManager.register(n) + cNodes := c.dataManager.getDataNodes(true) + rets := c.registerPolicy.apply(cNodes, n) + c.dataManager.updateDataNodes(rets) + rets = c.watch(rets) + c.dataManager.updateDataNodes(rets) +} + +func (c *cluster) unregister(n *datapb.DataNodeInfo) { + c.mu.Lock() + defer c.mu.Unlock() + c.dataManager.unregister(n) + cNodes := c.dataManager.getDataNodes(true) + rets := c.unregisterPolicy.apply(cNodes, n) + c.dataManager.updateDataNodes(rets) + rets = c.watch(rets) + c.dataManager.updateDataNodes(rets) +} + +func (c *cluster) watchIfNeeded(channel string) { + c.mu.Lock() + defer c.mu.Unlock() + cNodes := c.dataManager.getDataNodes(true) + rets := c.assginPolicy.apply(cNodes, channel) + c.dataManager.updateDataNodes(rets) + rets = c.watch(rets) + c.dataManager.updateDataNodes(rets) +} + +func (c *cluster) flush(segments []*datapb.SegmentInfo) { + log.Debug("prepare to flush", zap.Any("segments", segments)) + c.mu.Lock() + defer c.mu.Unlock() + + m := make(map[string]map[UniqueID][]UniqueID) // channel-> map[collectionID]segmentIDs + + for _, seg := range segments { + if _, ok := m[seg.InsertChannel]; !ok { + m[seg.InsertChannel] = make(map[UniqueID][]UniqueID) + } + + m[seg.InsertChannel][seg.CollectionID] = append(m[seg.InsertChannel][seg.CollectionID], seg.ID) + } + + dataNodes := c.dataManager.getDataNodes(true) + + channel2Node := make(map[string]string) + for _, node := range dataNodes { + for _, chstatus := range node.Channels { + channel2Node[chstatus.Name] = node.Address + } + } + + for ch, coll2seg := range m { + node, ok := channel2Node[ch] + if !ok { continue } + cli, err := c.sessionManager.getOrCreateSession(node) + if err != nil { + log.Warn("get session failed", zap.String("addr", node), zap.Error(err)) + continue + } + for coll, segs := range coll2seg { + req := &datapb.FlushSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Flush, + SourceID: Params.NodeID, + }, + CollectionID: coll, + SegmentIDs: segs, + } + resp, err := cli.FlushSegments(c.ctx, req) + if err != nil { + log.Warn("flush segment failed", zap.String("addr", node), zap.Error(err)) + continue + } + if resp.ErrorCode != commonpb.ErrorCode_Success { + log.Warn("flush segment failed", zap.String("dataNode", node), zap.Error(err)) + continue + } + log.Debug("flush segments succeed", zap.Any("segmentIDs", segs)) + } } } -func (c *dataNodeCluster) ShutDownClients() { - c.Lock() - defer c.Unlock() - for _, node := range c.nodes { - if err := node.client.Stop(); err != nil { - log.Error("stop client error", zap.Stringer("dataNode", node), zap.Error(err)) - continue - } - } -} - -// Clear only for test -func (c *dataNodeCluster) Clear() { - c.Lock() - defer c.Unlock() - c.nodes = make([]*dataNode, 0) +func (c *cluster) releaseSessions() { + c.mu.Lock() + defer c.mu.Unlock() + c.sessionManager.release() } diff --git a/internal/dataservice/cluster_session_manager.go b/internal/dataservice/cluster_session_manager.go index fc0a6696e1..ce13c98515 100644 --- a/internal/dataservice/cluster_session_manager.go +++ b/internal/dataservice/cluster_session_manager.go @@ -12,28 +12,34 @@ package dataservice import ( "sync" + "time" - grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" "github.com/milvus-io/milvus/internal/types" ) const retryTimes = 2 type sessionManager interface { - sendRequest(addr string, executor func(node types.DataNode) error) error + getOrCreateSession(addr string) (types.DataNode, error) + releaseSession(addr string) + release() } type clusterSessionManager struct { - mu sync.RWMutex - sessions map[string]types.DataNode + mu sync.RWMutex + sessions map[string]types.DataNode + dataClientCreator func(addr string, timeout time.Duration) (types.DataNode, error) } -func newClusterSessionManager() *clusterSessionManager { - return &clusterSessionManager{sessions: make(map[string]types.DataNode)} +func newClusterSessionManager(dataClientCreator func(addr string, timeout time.Duration) (types.DataNode, error)) *clusterSessionManager { + return &clusterSessionManager{ + sessions: make(map[string]types.DataNode), + dataClientCreator: dataClientCreator, + } } func (m *clusterSessionManager) createSession(addr string) error { - cli, err := grpcdatanodeclient.NewClient(addr, 0, []string{}, 0) + cli, err := m.dataClientCreator(addr, 0) if err != nil { return err } @@ -47,8 +53,13 @@ func (m *clusterSessionManager) createSession(addr string) error { return nil } -func (m *clusterSessionManager) getSession(addr string) types.DataNode { - return m.sessions[addr] +func (m *clusterSessionManager) getOrCreateSession(addr string) (types.DataNode, error) { + if !m.hasSession(addr) { + if err := m.createSession(addr); err != nil { + return nil, err + } + } + return m.sessions[addr], nil } func (m *clusterSessionManager) hasSession(addr string) bool { @@ -56,19 +67,17 @@ func (m *clusterSessionManager) hasSession(addr string) bool { return ok } -func (m *clusterSessionManager) sendRequest(addr string, executor func(node types.DataNode) error) error { - m.mu.Lock() - defer m.mu.Unlock() - success := false - var err error - for i := 0; !success && i < retryTimes; i++ { - if i != 0 || !m.hasSession(addr) { - m.createSession(addr) - } - err = executor(m.getSession(addr)) - if err == nil { - return nil - } +func (m *clusterSessionManager) releaseSession(addr string) { + cli, ok := m.sessions[addr] + if !ok { + return + } + _ = cli.Stop() + delete(m.sessions, addr) +} + +func (m *clusterSessionManager) release() { + for _, cli := range m.sessions { + _ = cli.Stop() } - return err } diff --git a/internal/dataservice/cluster_test.go b/internal/dataservice/cluster_test.go index 89ea077651..b19f569d00 100644 --- a/internal/dataservice/cluster_test.go +++ b/internal/dataservice/cluster_test.go @@ -11,90 +11,133 @@ package dataservice import ( + "context" "testing" - "github.com/milvus-io/milvus/internal/proto/internalpb" + memkv "github.com/milvus-io/milvus/internal/kv/mem" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/stretchr/testify/assert" - "golang.org/x/net/context" ) -func TestDataNodeClusterRegister(t *testing.T) { - Params.Init() - cluster := newDataNodeCluster() - dataNodeNum := 3 - ids := make([]int64, 0, dataNodeNum) - for i := 0; i < dataNodeNum; i++ { - c, err := newMockDataNodeClient(int64(i)) - assert.Nil(t, err) - err = c.Init() - assert.Nil(t, err) - err = c.Start() - assert.Nil(t, err) - cluster.Register(&dataNode{ - id: int64(i), - address: struct { - ip string - port int64 - }{"localhost", int64(9999 + i)}, - client: c, - channelNum: 0, - }) - ids = append(ids, int64(i)) +func TestClusterCreate(t *testing.T) { + cPolicy := newMockStartupPolicy() + cluster := createCluster(t, nil, withStartupPolicy(cPolicy)) + addr := "localhost:8080" + nodes := []*datapb.DataNodeInfo{ + { + Address: addr, + Version: 1, + Channels: []*datapb.ChannelStatus{}, + }, } - assert.EqualValues(t, dataNodeNum, cluster.GetNumOfNodes()) - assert.EqualValues(t, ids, cluster.GetNodeIDs()) - states, err := cluster.GetDataNodeStates(context.TODO()) + err := cluster.startup(nodes) assert.Nil(t, err) - assert.EqualValues(t, dataNodeNum, len(states)) - for _, s := range states { - assert.EqualValues(t, internalpb.StateCode_Healthy, s.StateCode) - } - cluster.ShutDownClients() - states, err = cluster.GetDataNodeStates(context.TODO()) - assert.Nil(t, err) - assert.EqualValues(t, dataNodeNum, len(states)) - for _, s := range states { - assert.EqualValues(t, internalpb.StateCode_Abnormal, s.StateCode) - } + dataNodes := cluster.dataManager.getDataNodes(true) + assert.EqualValues(t, 1, len(dataNodes)) + assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) } -func TestWatchChannels(t *testing.T) { - Params.Init() - dataNodeNum := 3 - cases := []struct { - collectionID UniqueID - channels []string - channelNums []int - }{ - {1, []string{"c1"}, []int{1, 0, 0}}, - {1, []string{"c1", "c2", "c3"}, []int{1, 1, 1}}, - {1, []string{"c1", "c2", "c3", "c4"}, []int{2, 1, 1}}, - {1, []string{"c1", "c2", "c3", "c4", "c5", "c6", "c7"}, []int{3, 2, 2}}, +func TestRegister(t *testing.T) { + cPolicy := newMockStartupPolicy() + registerPolicy := newDoNothingRegisterPolicy() + cluster := createCluster(t, nil, withStartupPolicy(cPolicy), withRegisterPolicy(registerPolicy)) + addr := "localhost:8080" + + err := cluster.startup(nil) + assert.Nil(t, err) + cluster.register(&datapb.DataNodeInfo{ + Address: addr, + Version: 1, + Channels: []*datapb.ChannelStatus{}, + }) + dataNodes := cluster.dataManager.getDataNodes(true) + assert.EqualValues(t, 1, len(dataNodes)) + assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) +} + +func TestUnregister(t *testing.T) { + cPolicy := newMockStartupPolicy() + unregisterPolicy := newDoNothingUnregisterPolicy() + cluster := createCluster(t, nil, withStartupPolicy(cPolicy), withUnregistorPolicy(unregisterPolicy)) + addr := "localhost:8080" + nodes := []*datapb.DataNodeInfo{ + { + Address: addr, + Version: 1, + Channels: []*datapb.ChannelStatus{}, + }, + } + err := cluster.startup(nodes) + assert.Nil(t, err) + dataNodes := cluster.dataManager.getDataNodes(true) + assert.EqualValues(t, 1, len(dataNodes)) + assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) + cluster.unregister(&datapb.DataNodeInfo{ + Address: addr, + Version: 1, + Channels: []*datapb.ChannelStatus{}, + }) + dataNodes = cluster.dataManager.getDataNodes(false) + assert.EqualValues(t, 1, len(dataNodes)) + assert.EqualValues(t, offline, cluster.dataManager.dataNodes[addr].status) + assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) +} + +func TestWatchIfNeeded(t *testing.T) { + cPolicy := newMockStartupPolicy() + cluster := createCluster(t, nil, withStartupPolicy(cPolicy)) + addr := "localhost:8080" + nodes := []*datapb.DataNodeInfo{ + { + Address: addr, + Version: 1, + Channels: []*datapb.ChannelStatus{}, + }, + } + err := cluster.startup(nodes) + assert.Nil(t, err) + dataNodes := cluster.dataManager.getDataNodes(true) + assert.EqualValues(t, 1, len(dataNodes)) + assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) + + chName := "ch1" + cluster.watchIfNeeded(chName) + dataNodes = cluster.dataManager.getDataNodes(true) + assert.EqualValues(t, 1, len(dataNodes[addr].Channels)) + assert.EqualValues(t, chName, dataNodes[addr].Channels[0].Name) + cluster.watchIfNeeded(chName) + assert.EqualValues(t, 1, len(dataNodes[addr].Channels)) + assert.EqualValues(t, chName, dataNodes[addr].Channels[0].Name) +} + +func TestFlushSegments(t *testing.T) { + cPolicy := newMockStartupPolicy() + cluster := createCluster(t, nil, withStartupPolicy(cPolicy)) + addr := "localhost:8080" + nodes := []*datapb.DataNodeInfo{ + { + Address: addr, + Version: 1, + Channels: []*datapb.ChannelStatus{}, + }, + } + err := cluster.startup(nodes) + assert.Nil(t, err) + segments := []*datapb.SegmentInfo{ + { + ID: 0, + CollectionID: 0, + InsertChannel: "ch1", + }, } - cluster := newDataNodeCluster() - for _, c := range cases { - for i := 0; i < dataNodeNum; i++ { - c, err := newMockDataNodeClient(int64(i)) - assert.Nil(t, err) - err = c.Init() - assert.Nil(t, err) - err = c.Start() - assert.Nil(t, err) - cluster.Register(&dataNode{ - id: int64(i), - address: struct { - ip string - port int64 - }{"localhost", int64(9999 + i)}, - client: c, - channelNum: 0, - }) - } - cluster.WatchInsertChannels(c.channels) - for i := 0; i < len(cluster.nodes); i++ { - assert.EqualValues(t, c.channelNums[i], cluster.nodes[i].channelNum) - } - cluster.Clear() - } + cluster.flush(segments) +} + +func createCluster(t *testing.T, ch chan interface{}, options ...clusterOption) *cluster { + kv := memkv.NewMemoryKV() + sessionManager := newMockSessionManager(ch) + dataManager, err := newClusterNodeManager(kv) + assert.Nil(t, err) + return newCluster(context.TODO(), dataManager, sessionManager, options...) } diff --git a/internal/dataservice/cluster_v2.go b/internal/dataservice/cluster_v2.go deleted file mode 100644 index 20777d3bb7..0000000000 --- a/internal/dataservice/cluster_v2.go +++ /dev/null @@ -1,229 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. -package dataservice - -import ( - "sync" - - "github.com/milvus-io/milvus/internal/log" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/types" - "go.uber.org/zap" - "golang.org/x/net/context" -) - -type cluster struct { - mu sync.RWMutex - dataManager *clusterNodeManager - sessionManager sessionManager - - startupPolicy clusterStartupPolicy - registerPolicy dataNodeRegisterPolicy - unregisterPolicy dataNodeUnregisterPolicy - assginPolicy channelAssignPolicy -} - -type clusterOption struct { - apply func(c *cluster) -} - -func withStartupPolicy(p clusterStartupPolicy) clusterOption { - return clusterOption{ - apply: func(c *cluster) { c.startupPolicy = p }, - } -} - -func withRegisterPolicy(p dataNodeRegisterPolicy) clusterOption { - return clusterOption{ - apply: func(c *cluster) { c.registerPolicy = p }, - } -} - -func withUnregistorPolicy(p dataNodeUnregisterPolicy) clusterOption { - return clusterOption{ - apply: func(c *cluster) { c.unregisterPolicy = p }, - } -} - -func withAssignPolicy(p channelAssignPolicy) clusterOption { - return clusterOption{ - apply: func(c *cluster) { c.assginPolicy = p }, - } -} - -func defaultStartupPolicy() clusterStartupPolicy { - return newReWatchOnRestartsStartupPolicy() -} - -func defaultRegisterPolicy() dataNodeRegisterPolicy { - return newDoNothingRegisterPolicy() -} - -func defaultUnregisterPolicy() dataNodeUnregisterPolicy { - return newDoNothingUnregisterPolicy() -} - -func defaultAssignPolicy() channelAssignPolicy { - return newAllAssignPolicy() -} - -func newCluster(dataManager *clusterNodeManager, sessionManager sessionManager, opts ...clusterOption) *cluster { - c := &cluster{ - dataManager: dataManager, - sessionManager: sessionManager, - } - c.startupPolicy = defaultStartupPolicy() - c.registerPolicy = defaultRegisterPolicy() - c.unregisterPolicy = defaultUnregisterPolicy() - c.assginPolicy = defaultAssignPolicy() - - for _, opt := range opts { - opt.apply(c) - } - - return c -} - -func (c *cluster) startup(dataNodes []*datapb.DataNodeInfo) error { - deltaChange := c.dataManager.updateCluster(dataNodes) - nodes := c.dataManager.getDataNodes(false) - rets := c.startupPolicy.apply(nodes, deltaChange) - c.dataManager.updateDataNodes(rets) - rets = c.watch(rets) - c.dataManager.updateDataNodes(rets) - return nil -} - -func (c *cluster) watch(nodes []*datapb.DataNodeInfo) []*datapb.DataNodeInfo { - for _, n := range nodes { - uncompletes := make([]string, 0) - for _, ch := range n.Channels { - if ch.State == datapb.ChannelWatchState_Uncomplete { - uncompletes = append(uncompletes, ch.Name) - } - } - executor := func(cli types.DataNode) error { - req := &datapb.WatchDmChannelsRequest{ - Base: &commonpb.MsgBase{ - SourceID: Params.NodeID, - }, - // ChannelNames: uncompletes, // TODO - } - resp, err := cli.WatchDmChannels(context.Background(), req) - if err != nil { - return err - } - if resp.ErrorCode != commonpb.ErrorCode_Success { - log.Warn("watch channels failed", zap.String("address", n.Address), zap.Error(err)) - return nil - } - for _, ch := range n.Channels { - if ch.State == datapb.ChannelWatchState_Uncomplete { - ch.State = datapb.ChannelWatchState_Complete - } - } - return nil - } - - if err := c.sessionManager.sendRequest(n.Address, executor); err != nil { - log.Warn("watch channels failed", zap.String("address", n.Address), zap.Error(err)) - } - } - return nodes -} - -func (c *cluster) register(n *datapb.DataNodeInfo) { - c.mu.Lock() - defer c.mu.Unlock() - c.dataManager.register(n) - cNodes := c.dataManager.getDataNodes(true) - rets := c.registerPolicy.apply(cNodes, n) - c.dataManager.updateDataNodes(rets) - rets = c.watch(rets) - c.dataManager.updateDataNodes(rets) -} - -func (c *cluster) unregister(n *datapb.DataNodeInfo) { - c.mu.Lock() - defer c.mu.Unlock() - c.dataManager.unregister(n) - cNodes := c.dataManager.getDataNodes(true) - rets := c.unregisterPolicy.apply(cNodes, n) - c.dataManager.updateDataNodes(rets) - rets = c.watch(rets) - c.dataManager.updateDataNodes(rets) -} - -func (c *cluster) watchIfNeeded(channel string) { - c.mu.Lock() - defer c.mu.Unlock() - cNodes := c.dataManager.getDataNodes(true) - rets := c.assginPolicy.apply(cNodes, channel) - c.dataManager.updateDataNodes(rets) - rets = c.watch(rets) - c.dataManager.updateDataNodes(rets) -} - -func (c *cluster) flush(segments []*datapb.SegmentInfo) { - c.mu.Lock() - defer c.mu.Unlock() - - m := make(map[string]map[UniqueID][]UniqueID) // channel-> map[collectionID]segmentIDs - - for _, seg := range segments { - if _, ok := m[seg.InsertChannel]; !ok { - m[seg.InsertChannel] = make(map[UniqueID][]UniqueID) - } - - m[seg.InsertChannel][seg.CollectionID] = append(m[seg.InsertChannel][seg.CollectionID], seg.ID) - } - - dataNodes := c.dataManager.getDataNodes(true) - - channel2Node := make(map[string]string) - for _, node := range dataNodes { - for _, chstatus := range node.Channels { - channel2Node[chstatus.Name] = node.Address - } - } - - for ch, coll2seg := range m { - node, ok := channel2Node[ch] - if !ok { - continue - } - for coll, segs := range coll2seg { - executor := func(cli types.DataNode) error { - req := &datapb.FlushSegmentsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Flush, - SourceID: Params.NodeID, - }, - CollectionID: coll, - SegmentIDs: segs, - } - resp, err := cli.FlushSegments(context.Background(), req) - if err != nil { - return err - } - if resp.ErrorCode != commonpb.ErrorCode_Success { - log.Warn("flush segment error", zap.String("dataNode", node), zap.Error(err)) - } - - return nil - } - if err := c.sessionManager.sendRequest(node, executor); err != nil { - log.Warn("flush segment error", zap.String("dataNode", node), zap.Error(err)) - } - } - } -} diff --git a/internal/dataservice/cluster_v2_test.go b/internal/dataservice/cluster_v2_test.go deleted file mode 100644 index e1e2156e58..0000000000 --- a/internal/dataservice/cluster_v2_test.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright (C) 2019-2020 Zilliz. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software distributed under the License -// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -// or implied. See the License for the specific language governing permissions and limitations under the License. -package dataservice - -import ( - "testing" - - memkv "github.com/milvus-io/milvus/internal/kv/mem" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/stretchr/testify/assert" -) - -func TestClusterCreate(t *testing.T) { - cPolicy := newMockStartupPolicy() - cluster := createCluster(t, withStartupPolicy(cPolicy)) - addr := "localhost:8080" - nodes := []*datapb.DataNodeInfo{ - { - Address: addr, - Version: 1, - Channels: []*datapb.ChannelStatus{}, - }, - } - err := cluster.startup(nodes) - assert.Nil(t, err) - dataNodes := cluster.dataManager.getDataNodes(true) - assert.EqualValues(t, 1, len(dataNodes)) - assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) -} - -func TestRegister(t *testing.T) { - cPolicy := newMockStartupPolicy() - registerPolicy := newDoNothingRegisterPolicy() - cluster := createCluster(t, withStartupPolicy(cPolicy), withRegisterPolicy(registerPolicy)) - addr := "localhost:8080" - - err := cluster.startup(nil) - assert.Nil(t, err) - cluster.register(&datapb.DataNodeInfo{ - Address: addr, - Version: 1, - Channels: []*datapb.ChannelStatus{}, - }) - dataNodes := cluster.dataManager.getDataNodes(true) - assert.EqualValues(t, 1, len(dataNodes)) - assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) -} - -func TestUnregister(t *testing.T) { - cPolicy := newMockStartupPolicy() - unregisterPolicy := newDoNothingUnregisterPolicy() - cluster := createCluster(t, withStartupPolicy(cPolicy), withUnregistorPolicy(unregisterPolicy)) - addr := "localhost:8080" - nodes := []*datapb.DataNodeInfo{ - { - Address: addr, - Version: 1, - Channels: []*datapb.ChannelStatus{}, - }, - } - err := cluster.startup(nodes) - assert.Nil(t, err) - dataNodes := cluster.dataManager.getDataNodes(true) - assert.EqualValues(t, 1, len(dataNodes)) - assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) - cluster.unregister(&datapb.DataNodeInfo{ - Address: addr, - Version: 1, - Channels: []*datapb.ChannelStatus{}, - }) - dataNodes = cluster.dataManager.getDataNodes(false) - assert.EqualValues(t, 1, len(dataNodes)) - assert.EqualValues(t, offline, cluster.dataManager.dataNodes[addr].status) - assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) -} - -func TestWatchIfNeeded(t *testing.T) { - cPolicy := newMockStartupPolicy() - cluster := createCluster(t, withStartupPolicy(cPolicy)) - addr := "localhost:8080" - nodes := []*datapb.DataNodeInfo{ - { - Address: addr, - Version: 1, - Channels: []*datapb.ChannelStatus{}, - }, - } - err := cluster.startup(nodes) - assert.Nil(t, err) - dataNodes := cluster.dataManager.getDataNodes(true) - assert.EqualValues(t, 1, len(dataNodes)) - assert.EqualValues(t, "localhost:8080", dataNodes[addr].Address) - - chName := "ch1" - cluster.watchIfNeeded(chName) - dataNodes = cluster.dataManager.getDataNodes(true) - assert.EqualValues(t, 1, len(dataNodes[addr].Channels)) - assert.EqualValues(t, chName, dataNodes[addr].Channels[0].Name) - cluster.watchIfNeeded(chName) - assert.EqualValues(t, 1, len(dataNodes[addr].Channels)) - assert.EqualValues(t, chName, dataNodes[addr].Channels[0].Name) -} - -func TestFlushSegments(t *testing.T) { - cPolicy := newMockStartupPolicy() - cluster := createCluster(t, withStartupPolicy(cPolicy)) - addr := "localhost:8080" - nodes := []*datapb.DataNodeInfo{ - { - Address: addr, - Version: 1, - Channels: []*datapb.ChannelStatus{}, - }, - } - err := cluster.startup(nodes) - assert.Nil(t, err) - segments := []*datapb.SegmentInfo{ - { - ID: 0, - CollectionID: 0, - InsertChannel: "ch1", - }, - } - - cluster.flush(segments) -} - -func createCluster(t *testing.T, options ...clusterOption) *cluster { - kv := memkv.NewMemoryKV() - sessionManager := newMockSessionManager() - dataManager, err := newClusterNodeManager(kv) - assert.Nil(t, err) - return newCluster(dataManager, sessionManager, options...) -} diff --git a/internal/dataservice/grpc_handler.go b/internal/dataservice/grpc_handler.go index 6516dffda1..3097b92ff5 100644 --- a/internal/dataservice/grpc_handler.go +++ b/internal/dataservice/grpc_handler.go @@ -26,12 +26,13 @@ func (s *Server) GetComponentStates(ctx context.Context) (*internalpb.ComponentS ErrorCode: commonpb.ErrorCode_UnexpectedError, }, } - dataNodeStates, err := s.cluster.GetDataNodeStates(ctx) - if err != nil { - resp.Status.Reason = err.Error() - return resp, nil - } - resp.SubcomponentStates = dataNodeStates + // todo GetComponentStates need to be removed + //dataNodeStates, err := s.cluster.GetDataNodeStates(ctx) + //if err != nil { + //resp.Status.Reason = err.Error() + //return resp, nil + //} + //resp.SubcomponentStates = dataNodeStates resp.Status.ErrorCode = commonpb.ErrorCode_Success return resp, nil } @@ -55,58 +56,9 @@ func (s *Server) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp } func (s *Server) RegisterNode(ctx context.Context, req *datapb.RegisterNodeRequest) (*datapb.RegisterNodeResponse, error) { - ret := &datapb.RegisterNodeResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, - } - log.Debug("DataService: RegisterNode:", - zap.String("IP", req.Address.Ip), - zap.Int64("Port", req.Address.Port)) - - node, err := s.newDataNode(req.Address.Ip, req.Address.Port, req.Base.SourceID) - if err != nil { - ret.Status.Reason = err.Error() - return ret, nil - } - - resp, err := node.client.WatchDmChannels(s.ctx, &datapb.WatchDmChannelsRequest{ - Base: &commonpb.MsgBase{ - MsgType: 0, - MsgID: 0, - Timestamp: 0, - SourceID: Params.NodeID, - }, - // ChannelNames: s.insertChannels, // TODO - }) - - if err = VerifyResponse(resp, err); err != nil { - ret.Status.Reason = err.Error() - return ret, nil - } - - if err := s.getDDChannel(); err != nil { - ret.Status.Reason = err.Error() - return ret, nil - } - - if err = s.cluster.Register(node); err != nil { - ret.Status.Reason = err.Error() - return ret, nil - } - - ret.Status.ErrorCode = commonpb.ErrorCode_Success - ret.InitParams = &internalpb.InitParams{ - NodeID: Params.NodeID, - StartParams: []*commonpb.KeyValuePair{ - {Key: "DDChannelName", Value: s.ddChannelMu.name}, - {Key: "SegmentStatisticsChannelName", Value: Params.StatisticsChannelName}, - {Key: "TimeTickChannelName", Value: Params.TimeTickChannelName}, - {Key: "CompleteFlushChannelName", Value: Params.SegmentInfoChannelName}, - }, - } - return ret, nil + return nil, nil } + func (s *Server) Flush(ctx context.Context, req *datapb.FlushRequest) (*commonpb.Status, error) { if !s.checkStateIsHealthy() { return &commonpb.Status{ @@ -192,6 +144,7 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI SegIDAssignments: assigns, }, nil } + func (s *Server) ShowSegments(ctx context.Context, req *datapb.ShowSegmentsRequest) (*datapb.ShowSegmentsResponse, error) { resp := &datapb.ShowSegmentsResponse{ Status: &commonpb.Status{ @@ -280,7 +233,7 @@ func (s *Server) GetInsertChannels(ctx context.Context, req *datapb.GetInsertCha Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, - Values: s.insertChannels, + Values: []string{}, }, nil } diff --git a/internal/dataservice/mock_test.go b/internal/dataservice/mock_test.go index 2a49ce94e4..d9d1c66053 100644 --- a/internal/dataservice/mock_test.go +++ b/internal/dataservice/mock_test.go @@ -70,10 +70,11 @@ type mockDataNodeClient struct { ch chan interface{} } -func newMockDataNodeClient(id int64) (*mockDataNodeClient, error) { +func newMockDataNodeClient(id int64, ch chan interface{}) (*mockDataNodeClient, error) { return &mockDataNodeClient{ id: id, state: internalpb.StateCode_Initializing, + ch: ch, }, nil } @@ -301,12 +302,21 @@ func (p *mockStartupPolicy) apply(oldCluster map[string]*datapb.DataNodeInfo, de } type mockSessionManager struct { + ch chan interface{} } -func newMockSessionManager() sessionManager { - return &mockSessionManager{} +func newMockSessionManager(ch chan interface{}) sessionManager { + return &mockSessionManager{ + ch: ch, + } } -func (m *mockSessionManager) sendRequest(addr string, executor func(node types.DataNode) error) error { - return nil +func (m *mockSessionManager) getOrCreateSession(addr string) (types.DataNode, error) { + return newMockDataNodeClient(0, m.ch) +} + +func (m *mockSessionManager) releaseSession(addr string) { + +} +func (m *mockSessionManager) release() { } diff --git a/internal/dataservice/server.go b/internal/dataservice/server.go index cf72bdeab7..1d7636e964 100644 --- a/internal/dataservice/server.go +++ b/internal/dataservice/server.go @@ -12,7 +12,6 @@ package dataservice import ( "context" - "errors" "fmt" "math/rand" "strconv" @@ -20,9 +19,9 @@ import ( "sync/atomic" "time" + grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" "github.com/milvus-io/milvus/internal/logutil" - grpcdatanodeclient "github.com/milvus-io/milvus/internal/distributed/datanode/client" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/msgstream" @@ -51,39 +50,43 @@ type Server struct { serverLoopCancel context.CancelFunc serverLoopWg sync.WaitGroup state atomic.Value - kvClient *etcdkv.EtcdKV - meta *meta - segAllocator segmentAllocatorInterface - statsHandler *statsHandler - allocator allocatorInterface - cluster *dataNodeCluster - masterClient types.MasterService - ddChannelMu struct { + initOnce sync.Once + startOnce sync.Once + stopOnce sync.Once + + kvClient *etcdkv.EtcdKV + meta *meta + segmentInfoStream msgstream.MsgStream + segAllocator segmentAllocatorInterface + statsHandler *statsHandler + allocator allocatorInterface + cluster *cluster + masterClient types.MasterService + ddChannelMu struct { sync.Mutex name string } - session *sessionutil.Session - flushMsgStream msgstream.MsgStream - insertChannels []string - msFactory msgstream.Factory - createDataNodeClient func(addr string) (types.DataNode, error) + + flushMsgStream msgstream.MsgStream + msFactory msgstream.Factory + + session *sessionutil.Session + activeCh <-chan bool + watchCh <-chan *sessionutil.SessionEvent + + dataClientCreator func(addr string) (types.DataNode, error) } func CreateServer(ctx context.Context, factory msgstream.Factory) (*Server, error) { rand.Seed(time.Now().UnixNano()) s := &Server{ ctx: ctx, - cluster: newDataNodeCluster(), msFactory: factory, } - s.insertChannels = s.getInsertChannels() - s.createDataNodeClient = func(addr string) (types.DataNode, error) { - node, err := grpcdatanodeclient.NewClient(addr, 10*time.Second) - if err != nil { - return nil, err - } - return node, nil + s.dataClientCreator = func(addr string) (types.DataNode, error) { + return grpcdatanodeclient.NewClient(addr) } + s.UpdateStateCode(internalpb.StateCode_Abnormal) log.Debug("DataService", zap.Any("State", s.state.Load())) return s, nil @@ -104,63 +107,116 @@ func (s *Server) SetMasterClient(masterClient types.MasterService) { // Register register data service at etcd func (s *Server) Register() error { - s.session = sessionutil.NewSession(s.ctx, Params.MetaRootPath, []string{Params.EtcdAddress}) - s.session.Init(typeutil.DataServiceRole, Params.IP, true) + s.activeCh = s.session.Init(typeutil.DataServiceRole, Params.IP, true) Params.NodeID = s.session.ServerID return nil } func (s *Server) Init() error { + s.initOnce.Do(func() { + s.session = sessionutil.NewSession(s.ctx, []string{Params.EtcdAddress}) + }) return nil } +var startOnce sync.Once + func (s *Server) Start() error { var err error - m := map[string]interface{}{ - "PulsarAddress": Params.PulsarAddress, - "ReceiveBufSize": 1024, - "PulsarBufSize": 1024} - err = s.msFactory.SetParams(m) + s.startOnce.Do(func() { + m := map[string]interface{}{ + "PulsarAddress": Params.PulsarAddress, + "ReceiveBufSize": 1024, + "PulsarBufSize": 1024} + err = s.msFactory.SetParams(m) + if err != nil { + return + } + + if err = s.initMeta(); err != nil { + return + } + + if err = s.initCluster(); err != nil { + return + } + + if err = s.initSegmentInfoChannel(); err != nil { + return + } + + s.allocator = newAllocator(s.masterClient) + + s.startSegmentAllocator() + s.statsHandler = newStatsHandler(s.meta) + if err = s.initFlushMsgStream(); err != nil { + return + } + + if err = s.initServiceDiscovery(); err != nil { + return + } + + s.startServerLoop() + + s.UpdateStateCode(internalpb.StateCode_Healthy) + log.Debug("start success") + }) + return err +} + +func (s *Server) initCluster() error { + dManager, err := newClusterNodeManager(s.kvClient) if err != nil { return err } + sManager := newClusterSessionManager(s.dataClientCreator) + s.cluster = newCluster(s.ctx, dManager, sManager) + return nil +} - if err := s.initMeta(); err != nil { +func (s *Server) initServiceDiscovery() error { + sessions, rev, err := s.session.GetSessions(typeutil.DataNodeRole) + if err != nil { log.Debug("DataService initMeta failed", zap.Error(err)) return err } + log.Debug("registered sessions", zap.Any("sessions", sessions)) - s.allocator = newAllocator(s.masterClient) + datanodes := make([]*datapb.DataNodeInfo, 0, len(sessions)) + for _, session := range sessions { + datanodes = append(datanodes, &datapb.DataNodeInfo{ + Address: session.Address, + Version: session.ServerID, + Channels: []*datapb.ChannelStatus{}, + }) + } - s.startSegmentAllocator() - s.statsHandler = newStatsHandler(s.meta) - if err = s.loadMetaFromMaster(); err != nil { + if err := s.cluster.startup(datanodes); err != nil { log.Debug("DataService loadMetaFromMaster failed", zap.Error(err)) return err } - if err = s.initMsgProducer(); err != nil { - log.Debug("DataService initMsgProducer failed", zap.Error(err)) - return err - } - s.startServerLoop() - s.UpdateStateCode(internalpb.StateCode_Healthy) - log.Debug("start success") - log.Debug("DataService", zap.Any("State", s.state.Load())) + + s.watchCh = s.session.WatchServices(typeutil.DataNodeRole, rev) + return nil } func (s *Server) startSegmentAllocator() { - stream := s.initSegmentInfoChannel() - helper := createNewSegmentHelper(stream) + helper := createNewSegmentHelper(s.segmentInfoStream) s.segAllocator = newSegmentAllocator(s.meta, s.allocator, withAllocHelper(helper)) } -func (s *Server) initSegmentInfoChannel() msgstream.MsgStream { - segmentInfoStream, _ := s.msFactory.NewMsgStream(s.ctx) - segmentInfoStream.AsProducer([]string{Params.SegmentInfoChannelName}) +func (s *Server) initSegmentInfoChannel() error { + var err error + s.segmentInfoStream, err = s.msFactory.NewMsgStream(s.ctx) + if err != nil { + return err + } + s.segmentInfoStream.AsProducer([]string{Params.SegmentInfoChannelName}) log.Debug("DataService AsProducer: " + Params.SegmentInfoChannelName) - segmentInfoStream.Start() - return segmentInfoStream + s.segmentInfoStream.Start() + return nil } func (s *Server) UpdateStateCode(code internalpb.StateCode) { @@ -189,7 +245,7 @@ func (s *Server) initMeta() error { return retry.Retry(100000, time.Millisecond*200, connectEtcdFn) } -func (s *Server) initMsgProducer() error { +func (s *Server) initFlushMsgStream() error { var err error // segment flush stream s.flushMsgStream, err = s.msFactory.NewMsgStream(s.ctx) @@ -203,72 +259,6 @@ func (s *Server) initMsgProducer() error { return nil } -func (s *Server) loadMetaFromMaster() error { - ctx := context.Background() - log.Debug("loading collection meta from master") - var err error - if err = s.checkMasterIsHealthy(); err != nil { - return err - } - if err = s.getDDChannel(); err != nil { - return err - } - collections, err := s.masterClient.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ShowCollections, - MsgID: -1, // todo add msg id - Timestamp: 0, // todo - SourceID: Params.NodeID, - }, - DbName: "", - }) - if err = VerifyResponse(collections, err); err != nil { - return err - } - for _, collectionName := range collections.CollectionNames { - collection, err := s.masterClient.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_DescribeCollection, - MsgID: -1, // todo - Timestamp: 0, // todo - SourceID: Params.NodeID, - }, - DbName: "", - CollectionName: collectionName, - }) - if err = VerifyResponse(collection, err); err != nil { - log.Error("describe collection error", zap.String("collectionName", collectionName), zap.Error(err)) - continue - } - partitions, err := s.masterClient.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_ShowPartitions, - MsgID: -1, // todo - Timestamp: 0, // todo - SourceID: Params.NodeID, - }, - DbName: "", - CollectionName: collectionName, - CollectionID: collection.CollectionID, - }) - if err = VerifyResponse(partitions, err); err != nil { - log.Error("show partitions error", zap.String("collectionName", collectionName), zap.Int64("collectionID", collection.CollectionID), zap.Error(err)) - continue - } - err = s.meta.AddCollection(&datapb.CollectionInfo{ - ID: collection.CollectionID, - Schema: collection.Schema, - Partitions: partitions.PartitionIDs, - }) - if err != nil { - log.Error("add collection to meta error", zap.Int64("collectionID", collection.CollectionID), zap.Error(err)) - continue - } - } - log.Debug("load collection meta from master complete") - return nil -} - func (s *Server) getDDChannel() error { s.ddChannelMu.Lock() defer s.ddChannelMu.Unlock() @@ -282,37 +272,13 @@ func (s *Server) getDDChannel() error { return nil } -func (s *Server) checkMasterIsHealthy() error { - ticker := time.NewTicker(300 * time.Millisecond) - ctx, cancel := context.WithTimeout(s.ctx, 30*time.Second) - defer func() { - ticker.Stop() - cancel() - }() - for { - var resp *internalpb.ComponentStates - var err error - select { - case <-ctx.Done(): - return errors.New("master is not healthy") - case <-ticker.C: - resp, err = s.masterClient.GetComponentStates(ctx) - if err = VerifyResponse(resp, err); err != nil { - return err - } - } - if resp.State.StateCode == internalpb.StateCode_Healthy { - break - } - } - return nil -} - func (s *Server) startServerLoop() { s.serverLoopCtx, s.serverLoopCancel = context.WithCancel(s.ctx) - s.serverLoopWg.Add(2) + s.serverLoopWg.Add(4) go s.startStatsChannel(s.serverLoopCtx) go s.startDataNodeTtLoop(s.serverLoopCtx) + go s.startWatchService(s.serverLoopCtx) + go s.startActiveCheck(s.serverLoopCtx) } func (s *Server) startStatsChannel(ctx context.Context) { @@ -404,7 +370,6 @@ func (s *Server) startDataNodeTtLoop(ctx context.Context) { } ttMsg := msg.(*msgstream.DataNodeTtMsg) - coll2Segs := make(map[UniqueID][]UniqueID) ch := ttMsg.ChannelName ts := ttMsg.Timestamp segments, err := s.segAllocator.GetFlushableSegments(ctx, ch, ts) @@ -412,6 +377,9 @@ func (s *Server) startDataNodeTtLoop(ctx context.Context) { log.Warn("get flushable segments failed", zap.Error(err)) continue } + + log.Debug("flushable segments", zap.Any("segments", segments)) + segmentInfos := make([]*datapb.SegmentInfo, 0, len(segments)) for _, id := range segments { sInfo, err := s.meta.GetSegment(id) if err != nil { @@ -419,35 +387,74 @@ func (s *Server) startDataNodeTtLoop(ctx context.Context) { zap.Error(err)) continue } - collID, segID := sInfo.CollectionID, sInfo.ID - coll2Segs[collID] = append(coll2Segs[collID], segID) + segmentInfos = append(segmentInfos, sInfo) } - for collID, segIDs := range coll2Segs { - s.cluster.FlushSegment(&datapb.FlushSegmentsRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Flush, - MsgID: -1, // todo add msg id - Timestamp: 0, // todo - SourceID: Params.NodeID, - }, - CollectionID: collID, - SegmentIDs: segIDs, - }) + s.cluster.flush(segmentInfos) + } + } +} + +func (s *Server) startWatchService(ctx context.Context) { + defer s.serverLoopWg.Done() + for { + select { + case <-ctx.Done(): + log.Debug("watch service shutdown") + return + case event := <-s.watchCh: + datanode := &datapb.DataNodeInfo{ + Address: event.Session.Address, + Version: event.Session.ServerID, + Channels: []*datapb.ChannelStatus{}, + } + switch event.EventType { + case sessionutil.SessionAddEvent: + s.cluster.register(datanode) + case sessionutil.SessionDelEvent: + s.cluster.unregister(datanode) + default: + log.Warn("receive unknown service event type", + zap.Any("type", event.EventType)) } } } } +func (s *Server) startActiveCheck(ctx context.Context) { + defer s.serverLoopWg.Done() + + for { + select { + case _, ok := <-s.activeCh: + if ok { + continue + } + s.Stop() + log.Debug("disconnect with etcd") + return + case <-ctx.Done(): + log.Debug("connection check shutdown") + return + } + } +} + +var stopOnce sync.Once + func (s *Server) Stop() error { - s.cluster.ShutDownClients() - s.flushMsgStream.Close() - s.stopServerLoop() + s.stopOnce.Do(func() { + s.cluster.releaseSessions() + s.segmentInfoStream.Close() + s.flushMsgStream.Close() + s.stopServerLoop() + }) return nil } // CleanMeta only for test func (s *Server) CleanMeta() error { + log.Debug("clean meta", zap.Any("kv", s.kvClient)) return s.kvClient.RemoveWithPrefix("") } @@ -456,29 +463,6 @@ func (s *Server) stopServerLoop() { s.serverLoopWg.Wait() } -func (s *Server) newDataNode(ip string, port int64, id UniqueID) (*dataNode, error) { - client, err := s.createDataNodeClient(fmt.Sprintf("%s:%d", ip, port)) - if err != nil { - return nil, err - } - if err := client.Init(); err != nil { - return nil, err - } - - if err := client.Start(); err != nil { - return nil, err - } - return &dataNode{ - id: id, - address: struct { - ip string - port int64 - }{ip: ip, port: port}, - client: client, - channelNum: 0, - }, nil -} - //func (s *Server) validateAllocRequest(collID UniqueID, partID UniqueID, channelName string) error { // if !s.meta.HasCollection(collID) { // return fmt.Errorf("can not find collection %d", collID) diff --git a/internal/dataservice/server_test.go b/internal/dataservice/server_test.go index f98f7703b7..c1183fe887 100644 --- a/internal/dataservice/server_test.go +++ b/internal/dataservice/server_test.go @@ -32,32 +32,8 @@ import ( "go.uber.org/zap" ) -func TestRegisterNode(t *testing.T) { - svr := newTestServer(t) - defer closeTestServer(t, svr) - t.Run("register node", func(t *testing.T) { - resp, err := svr.RegisterNode(context.TODO(), &datapb.RegisterNodeRequest{ - Base: &commonpb.MsgBase{ - MsgType: 0, - MsgID: 0, - Timestamp: 0, - SourceID: 1000, - }, - Address: &commonpb.Address{ - Ip: "localhost", - Port: 1000, - }, - }) - assert.Nil(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - assert.EqualValues(t, 1, svr.cluster.GetNumOfNodes()) - assert.EqualValues(t, []int64{1000}, svr.cluster.GetNodeIDs()) - }) - -} - func TestGetSegmentInfoChannel(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) t.Run("get segment info channel", func(t *testing.T) { resp, err := svr.GetSegmentInfoChannel(context.TODO()) @@ -67,26 +43,6 @@ func TestGetSegmentInfoChannel(t *testing.T) { }) } -func TestGetInsertChannels(t *testing.T) { - svr := newTestServer(t) - defer closeTestServer(t, svr) - t.Run("get insert channels", func(t *testing.T) { - resp, err := svr.GetInsertChannels(context.TODO(), &datapb.GetInsertChannelsRequest{ - Base: &commonpb.MsgBase{ - MsgType: 0, - MsgID: 0, - Timestamp: 0, - SourceID: 1000, - }, - DbID: 0, - CollectionID: 0, - }) - assert.Nil(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - assert.EqualValues(t, svr.getInsertChannels(), resp.Values) - }) -} - func TestAssignSegmentID(t *testing.T) { const collID = 100 const collIDInvalid = 101 @@ -94,7 +50,7 @@ func TestAssignSegmentID(t *testing.T) { const channel0 = "channel0" const channel1 = "channel1" - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) schema := newTestSchema() svr.meta.AddCollection(&datapb.CollectionInfo{ @@ -151,7 +107,7 @@ func TestAssignSegmentID(t *testing.T) { } func TestShowSegments(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) segments := []struct { id UniqueID @@ -202,7 +158,7 @@ func TestShowSegments(t *testing.T) { } func TestFlush(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) schema := newTestSchema() err := svr.meta.AddCollection(&datapb.CollectionInfo{ @@ -231,40 +187,39 @@ func TestFlush(t *testing.T) { assert.EqualValues(t, segID, ids[0]) } -func TestGetComponentStates(t *testing.T) { - svr := newTestServer(t) - defer closeTestServer(t, svr) - cli, err := newMockDataNodeClient(1) - assert.Nil(t, err) - err = cli.Init() - assert.Nil(t, err) - err = cli.Start() - assert.Nil(t, err) +//func TestGetComponentStates(t *testing.T) { +//svr := newTestServer(t) +//defer closeTestServer(t, svr) +//cli := newMockDataNodeClient(1) +//err := cli.Init() +//assert.Nil(t, err) +//err = cli.Start() +//assert.Nil(t, err) - err = svr.cluster.Register(&dataNode{ - id: 1, - address: struct { - ip string - port int64 - }{ - ip: "", - port: 0, - }, - client: cli, - channelNum: 0, - }) - assert.Nil(t, err) +//err = svr.cluster.Register(&dataNode{ +//id: 1, +//address: struct { +//ip string +//port int64 +//}{ +//ip: "", +//port: 0, +//}, +//client: cli, +//channelNum: 0, +//}) +//assert.Nil(t, err) - resp, err := svr.GetComponentStates(context.TODO()) - assert.Nil(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - assert.EqualValues(t, internalpb.StateCode_Healthy, resp.State.StateCode) - assert.EqualValues(t, 1, len(resp.SubcomponentStates)) - assert.EqualValues(t, internalpb.StateCode_Healthy, resp.SubcomponentStates[0].StateCode) -} +//resp, err := svr.GetComponentStates(context.TODO()) +//assert.Nil(t, err) +//assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) +//assert.EqualValues(t, internalpb.StateCode_Healthy, resp.State.StateCode) +//assert.EqualValues(t, 1, len(resp.SubcomponentStates)) +//assert.EqualValues(t, internalpb.StateCode_Healthy, resp.SubcomponentStates[0].StateCode) +//} func TestGetTimeTickChannel(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) resp, err := svr.GetTimeTickChannel(context.TODO()) assert.Nil(t, err) @@ -273,7 +228,7 @@ func TestGetTimeTickChannel(t *testing.T) { } func TestGetStatisticsChannel(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) resp, err := svr.GetStatisticsChannel(context.TODO()) assert.Nil(t, err) @@ -282,7 +237,7 @@ func TestGetStatisticsChannel(t *testing.T) { } func TestGetSegmentStates(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) err := svr.meta.AddSegment(&datapb.SegmentInfo{ ID: 1000, @@ -339,7 +294,7 @@ func TestGetSegmentStates(t *testing.T) { } func TestGetInsertBinlogPaths(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) req := &datapb.GetInsertBinlogPathsRequest{ @@ -351,7 +306,7 @@ func TestGetInsertBinlogPaths(t *testing.T) { } func TestGetCollectionStatistics(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) req := &datapb.GetCollectionStatisticsRequest{ @@ -363,7 +318,7 @@ func TestGetCollectionStatistics(t *testing.T) { } func TestGetSegmentInfo(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) segInfo := &datapb.SegmentInfo{ @@ -380,7 +335,7 @@ func TestGetSegmentInfo(t *testing.T) { } func TestChannel(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) t.Run("Test StatsChannel", func(t *testing.T) { @@ -491,7 +446,7 @@ func TestChannel(t *testing.T) { } func TestSaveBinlogPaths(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) collections := []struct { @@ -613,7 +568,8 @@ func TestSaveBinlogPaths(t *testing.T) { } func TestDataNodeTtChannel(t *testing.T) { - svr := newTestServer(t) + ch := make(chan interface{}, 1) + svr := newTestServer(t, ch) defer closeTestServer(t, svr) svr.meta.AddCollection(&datapb.CollectionInfo{ @@ -622,14 +578,6 @@ func TestDataNodeTtChannel(t *testing.T) { Partitions: []int64{0}, }) - ch := make(chan interface{}, 1) - svr.createDataNodeClient = func(addr string, serverID int64) (types.DataNode, error) { - cli, err := newMockDataNodeClient(0) - assert.Nil(t, err) - cli.ch = ch - return cli, nil - } - ttMsgStream, err := svr.msFactory.NewMsgStream(context.TODO()) assert.Nil(t, err) ttMsgStream.AsProducer([]string{Params.TimeTickChannelName}) @@ -654,20 +602,16 @@ func TestDataNodeTtChannel(t *testing.T) { } } - resp, err := svr.RegisterNode(context.TODO(), &datapb.RegisterNodeRequest{ - Base: &commonpb.MsgBase{ - MsgType: 0, - MsgID: 0, - Timestamp: 0, - SourceID: 0, - }, - Address: &commonpb.Address{ - Ip: "localhost:7777", - Port: 8080, + svr.cluster.register(&datapb.DataNodeInfo{ + Address: "localhost:7777", + Version: 0, + Channels: []*datapb.ChannelStatus{ + { + Name: "ch-1", + State: datapb.ChannelWatchState_Complete, + }, }, }) - assert.Nil(t, err) - assert.EqualValues(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) t.Run("Test segment flush after tt", func(t *testing.T) { resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{ @@ -688,6 +632,7 @@ func TestDataNodeTtChannel(t *testing.T) { assert.EqualValues(t, 1, len(resp.SegIDAssignments)) assign := resp.SegIDAssignments[0] + log.Debug("xxxxxxxxxxxxx", zap.Any("assign", assign)) resp2, err := svr.Flush(context.TODO(), &datapb.FlushRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_Flush, @@ -720,7 +665,7 @@ func TestResumeChannel(t *testing.T) { segmentIDs := make([]int64, 0, 1000) t.Run("Prepare Resume test set", func(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer svr.Stop() i := int64(-1) @@ -743,7 +688,7 @@ func TestResumeChannel(t *testing.T) { }) t.Run("Test ResumeSegmentStatsChannel", func(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) segRows := rand.Int63n(1000) @@ -792,7 +737,7 @@ func TestResumeChannel(t *testing.T) { svr.Stop() time.Sleep(time.Millisecond * 50) - svr = newTestServer(t) + svr = newTestServer(t, nil) defer svr.Stop() <-ch @@ -812,7 +757,7 @@ func TestResumeChannel(t *testing.T) { }) t.Run("Clean up test segments", func(t *testing.T) { - svr := newTestServer(t) + svr := newTestServer(t, nil) defer closeTestServer(t, svr) var err error for _, segID := range segmentIDs { @@ -822,7 +767,7 @@ func TestResumeChannel(t *testing.T) { }) } -func newTestServer(t *testing.T) *Server { +func newTestServer(t *testing.T, receiveCh chan interface{}) *Server { Params.Init() var err error factory := msgstream.NewPmsFactory() @@ -849,8 +794,8 @@ func newTestServer(t *testing.T) *Server { assert.Nil(t, err) defer ms.Stop() svr.SetMasterClient(ms) - svr.createDataNodeClient = func(addr string) (types.DataNode, error) { - return newMockDataNodeClient(0) + svr.dataClientCreator = func(addr string) (types.DataNode, error) { + return newMockDataNodeClient(0, receiveCh) } assert.Nil(t, err) err = svr.Init()