enhance: Refactor channel dist manager interface (#31119)

issue: #31091
This PR add GetByFilter interface in channel dist manager, instead of
all kind of get func

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
wei liu 2024-04-02 10:23:14 +08:00 committed by GitHub
parent 16d869c57e
commit 0944a1f790
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 72 additions and 134 deletions

View File

@ -149,7 +149,7 @@ func (b *RowCountBasedBalancer) convertToNodeItemsByChannel(nodeIDs []int64) []*
ret := make([]*nodeItem, 0, len(nodeIDs))
for _, nodeInfo := range b.getNodes(nodeIDs) {
node := nodeInfo.ID()
channels := b.dist.ChannelDistManager.GetByNode(node)
channels := b.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(node))
// more channel num, less priority
nodeItem := newNodeItem(len(channels), node)
@ -292,7 +292,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool {
// if the segment are redundant, skip it's balance for now
return len(b.dist.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
return len(b.dist.SegmentDistManager.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
})
if len(nodesWithLessRow) == 0 || len(segmentsToMove) == 0 {
@ -311,7 +311,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan {
channelPlans := make([]ChannelAssignPlan, 0)
for _, nodeID := range offlineNodes {
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
dmChannels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID))
plans := b.AssignChannel(dmChannels, onlineNodes, false)
for i := range plans {
plans[i].From = nodeID
@ -326,7 +326,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
channelPlans := make([]ChannelAssignPlan, 0)
if len(onlineNodes) > 1 {
// start to balance channels on all available nodes
channelDist := b.dist.ChannelDistManager.GetChannelDistByReplica(replica)
channelDist := b.dist.ChannelDistManager.GetByFilter(meta.WithReplica2Channel(replica))
if len(channelDist) == 0 {
return nil
}
@ -336,7 +336,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
nodeWithLessChannel := make([]int64, 0)
channelsToMove := make([]*meta.DmChannel, 0)
for _, node := range onlineNodes {
channels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), node)
channels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))
if len(channels) <= average {
nodeWithLessChannel = append(nodeWithLessChannel, node)

View File

@ -315,7 +315,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [
// if the segment are redundant, skip it's balance for now
segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool {
return len(b.dist.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
return len(b.dist.SegmentDistManager.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
})
if len(segmentsToMove) == 0 {

View File

@ -176,7 +176,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica,
// 3. print stopping nodes channel distribution
distInfo += "[stoppingNodesChannelDist:"
for stoppingNodeID := range stoppingNodesSegments {
stoppingNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), stoppingNodeID)
stoppingNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(stoppingNodeID))
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", stoppingNodeID, len(stoppingNodeChannels))
distInfo += "channels:["
for _, stoppingChan := range stoppingNodeChannels {
@ -189,7 +189,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica,
// 4. print normal nodes channel distribution
distInfo += "[normalNodesChannelDist:"
for normalNodeID := range nodeSegments {
normalNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), normalNodeID)
normalNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(normalNodeID))
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", normalNodeID, len(normalNodeChannels))
distInfo += "channels:["
for _, normalNodeChan := range normalNodeChannels {

View File

@ -91,7 +91,7 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task {
}
}
channels := c.dist.ChannelDistManager.GetAll()
channels := c.dist.ChannelDistManager.GetByFilter()
released := utils.FilterReleased(channels, collectionIDs)
releaseTasks := c.createChannelReduceTasks(ctx, released, meta.NilReplica)
task.SetReason("collection released", releaseTasks...)
@ -163,7 +163,7 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64,
func (c *ChannelChecker) getChannelDist(replica *meta.Replica) []*meta.DmChannel {
dist := make([]*meta.DmChannel, 0)
for _, nodeID := range replica.GetNodes() {
dist = append(dist, c.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)...)
dist = append(dist, c.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID))...)
}
return dist
}

View File

@ -95,7 +95,7 @@ func (c *SegmentChecker) Check(ctx context.Context) []task.Task {
}
// find already released segments which are not contained in target
segments := c.dist.SegmentDistManager.GetByFilter(nil)
segments := c.dist.SegmentDistManager.GetByFilter()
released := utils.FilterReleased(segments, collectionIDs)
reduceTasks := c.createSegmentReduceTasks(ctx, released, meta.NilReplica, querypb.DataScope_Historical)
task.SetReason("collection released", reduceTasks...)
@ -150,7 +150,6 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64,
zap.Int64("replicaID", replica.ID))
leaders := c.dist.ChannelDistManager.GetShardLeadersByReplica(replica)
// distMgr.LeaderViewManager.
for channelName, node := range leaders {
view := c.dist.LeaderViewManager.GetLeaderShardView(node, channelName)
if view == nil {

View File

@ -212,7 +212,7 @@ func (dh *distHandler) getDistribution(ctx context.Context) (*querypb.GetDataDis
defer dh.mu.Unlock()
channels := make(map[string]*msgpb.MsgPosition)
for _, channel := range dh.dist.ChannelDistManager.GetByNode(dh.nodeID) {
for _, channel := range dh.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(dh.nodeID)) {
targetChannel := dh.target.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget)
if targetChannel == nil {
continue

View File

@ -49,7 +49,7 @@ func waitCollectionReleased(dist *meta.DistributionManager, checkerController *c
return partitionSet.Contain(segment.GetPartitionID())
})
} else {
channels = dist.ChannelDistManager.GetByCollection(collection)
channels = dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(collection))
}
if len(channels)+len(segments) == 0 {

View File

@ -25,6 +25,32 @@ import (
. "github.com/milvus-io/milvus/pkg/util/typeutil"
)
type ChannelDistFilter = func(ch *DmChannel) bool
func WithCollectionID2Channel(collectionID int64) ChannelDistFilter {
return func(ch *DmChannel) bool {
return ch.GetCollectionID() == collectionID
}
}
func WithNodeID2Channel(nodeID int64) ChannelDistFilter {
return func(ch *DmChannel) bool {
return ch.Node == nodeID
}
}
func WithReplica2Channel(replica *Replica) ChannelDistFilter {
return func(ch *DmChannel) bool {
return ch.GetCollectionID() == replica.GetCollectionID() && replica.Contains(ch.Node)
}
}
func WithChannelName2Channel(channelName string) ChannelDistFilter {
return func(ch *DmChannel) bool {
return ch.GetChannelName() == channelName
}
}
type DmChannel struct {
*datapb.VchannelInfo
Node int64
@ -58,33 +84,7 @@ func NewChannelDistManager() *ChannelDistManager {
}
}
func (m *ChannelDistManager) GetByNode(nodeID UniqueID) []*DmChannel {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
return m.getByNode(nodeID)
}
func (m *ChannelDistManager) getByNode(nodeID UniqueID) []*DmChannel {
channels, ok := m.channels[nodeID]
if !ok {
return nil
}
return channels
}
func (m *ChannelDistManager) GetAll() []*DmChannel {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
result := make([]*DmChannel, 0)
for _, channels := range m.channels {
result = append(result, channels...)
}
return result
}
// todo by liuwei: should consider the case of duplicate leader exists
// GetShardLeader returns the node whthin the given replicaNodes and subscribing the given shard,
// returns (0, false) if not found.
func (m *ChannelDistManager) GetShardLeader(replica *Replica, shard string) (int64, bool) {
@ -103,6 +103,7 @@ func (m *ChannelDistManager) GetShardLeader(replica *Replica, shard string) (int
return 0, false
}
// todo by liuwei: should consider the case of duplicate leader exists
func (m *ChannelDistManager) GetShardLeadersByReplica(replica *Replica) map[string]int64 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
@ -119,35 +120,25 @@ func (m *ChannelDistManager) GetShardLeadersByReplica(replica *Replica) map[stri
return ret
}
func (m *ChannelDistManager) GetChannelDistByReplica(replica *Replica) map[string][]int64 {
// return all channels in list which match all given filters
func (m *ChannelDistManager) GetByFilter(filters ...ChannelDistFilter) []*DmChannel {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
ret := make(map[string][]int64)
for _, node := range replica.GetNodes() {
channels := m.channels[node]
for _, dmc := range channels {
if dmc.GetCollectionID() == replica.GetCollectionID() {
channelName := dmc.GetChannelName()
_, ok := ret[channelName]
if !ok {
ret[channelName] = make([]int64, 0)
}
ret[channelName] = append(ret[channelName], node)
mergedFilters := func(ch *DmChannel) bool {
for _, fn := range filters {
if fn != nil && !fn(ch) {
return false
}
}
}
return ret
}
func (m *ChannelDistManager) GetByCollection(collectionID UniqueID) []*DmChannel {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
return true
}
ret := make([]*DmChannel, 0)
for _, channels := range m.channels {
for _, channel := range channels {
if channel.CollectionID == collectionID {
if mergedFilters(channel) {
ret = append(ret, channel)
}
}
@ -155,20 +146,6 @@ func (m *ChannelDistManager) GetByCollection(collectionID UniqueID) []*DmChannel
return ret
}
func (m *ChannelDistManager) GetByCollectionAndNode(collectionID, nodeID UniqueID) []*DmChannel {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
channels := make([]*DmChannel, 0)
for _, channel := range m.getByNode(nodeID) {
if channel.CollectionID == collectionID {
channels = append(channels, channel)
}
}
return channels
}
func (m *ChannelDistManager) Update(nodeID UniqueID, channels ...*DmChannel) {
m.rwmutex.Lock()
defer m.rwmutex.Unlock()

View File

@ -66,36 +66,36 @@ func (suite *ChannelDistManagerSuite) TestGetBy() {
dist := suite.dist
// Test GetAll
channels := dist.GetAll()
channels := dist.GetByFilter(nil)
suite.Len(channels, 4)
// Test GetByNode
for _, node := range suite.nodes {
channels := dist.GetByNode(node)
channels := dist.GetByFilter(WithNodeID2Channel(node))
suite.AssertNode(channels, node)
}
// Test GetByCollection
channels = dist.GetByCollection(suite.collection)
channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection))
suite.Len(channels, 4)
suite.AssertCollection(channels, suite.collection)
channels = dist.GetByCollection(-1)
channels = dist.GetByFilter(WithCollectionID2Channel(-1))
suite.Len(channels, 0)
// Test GetByNodeAndCollection
// 1. Valid node and valid collection
for _, node := range suite.nodes {
channels := dist.GetByCollectionAndNode(suite.collection, node)
channels := dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(node))
suite.AssertNode(channels, node)
suite.AssertCollection(channels, suite.collection)
}
// 2. Valid node and invalid collection
channels = dist.GetByCollectionAndNode(-1, suite.nodes[1])
channels = dist.GetByFilter(WithCollectionID2Channel(-1), WithNodeID2Channel(suite.nodes[1]))
suite.Len(channels, 0)
// 3. Invalid node and valid collection
channels = dist.GetByCollectionAndNode(suite.collection, -1)
channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(-1))
suite.Len(channels, 0)
}
@ -148,47 +148,6 @@ func (suite *ChannelDistManagerSuite) TestGetShardLeader() {
suite.Equal(leaders["dmc1"], suite.nodes[1])
}
func (suite *ChannelDistManagerSuite) TestGetChannelDistByReplica() {
replica := NewReplica(
&querypb.Replica{
CollectionID: suite.collection,
},
typeutil.NewUniqueSet(11, 22, 33),
)
ch1 := &DmChannel{
VchannelInfo: &datapb.VchannelInfo{
CollectionID: suite.collection,
ChannelName: "test-channel1",
},
Node: 11,
Version: 1,
}
ch2 := &DmChannel{
VchannelInfo: &datapb.VchannelInfo{
CollectionID: suite.collection,
ChannelName: "test-channel1",
},
Node: 22,
Version: 1,
}
ch3 := &DmChannel{
VchannelInfo: &datapb.VchannelInfo{
CollectionID: suite.collection,
ChannelName: "test-channel2",
},
Node: 33,
Version: 1,
}
suite.dist.Update(11, ch1)
suite.dist.Update(22, ch2)
suite.dist.Update(33, ch3)
dist := suite.dist.GetChannelDistByReplica(replica)
suite.Len(dist["test-channel1"], 2)
suite.Len(dist["test-channel2"], 1)
}
func (suite *ChannelDistManagerSuite) AssertNames(channels []*DmChannel, names ...string) bool {
for _, channel := range channels {
hasChannel := false

View File

@ -108,16 +108,19 @@ func (m *SegmentDistManager) GetByFilter(filters ...SegmentDistFilter) []*Segmen
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
mergedFilters := func(s *Segment) bool {
for _, f := range filters {
if f != nil && !f(s) {
return false
}
}
return true
}
ret := make([]*Segment, 0)
for _, segments := range m.segments {
for _, segment := range segments {
allMatch := true
for _, f := range filters {
if f != nil && !f(segment) {
allMatch = false
}
}
if allMatch {
if mergedFilters(segment) {
ret = append(ret, segment)
}
}

View File

@ -98,8 +98,8 @@ func (ob *ReplicaObserver) checkNodesInReplica() {
)
for node := range outboundNodes {
channels := ob.distMgr.ChannelDistManager.GetByCollectionAndNode(collectionID, node)
segments := ob.distMgr.SegmentDistManager.GetByFilter(meta.WithCollectionID(collectionID), meta.WithNodeID(node))
channels := ob.distMgr.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))
if len(channels) == 0 && len(segments) == 0 {
replica.RemoveNode(node)

View File

@ -147,7 +147,7 @@ func (s *Server) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQ
}
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetNodeID()))
channels := s.dist.ChannelDistManager.GetByNode(req.NodeID)
channels := s.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(req.GetNodeID()))
return &querypb.GetQueryNodeDistributionResponse{
Status: merr.Success(),
ChannelNames: lo.Map(channels, func(c *meta.DmChannel, _ int) string { return c.GetChannelName() }),
@ -364,7 +364,7 @@ func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChann
dstNodeSet.Remove(srcNode)
// check sealed segment list
channels := s.dist.ChannelDistManager.GetByCollectionAndNode(replica.CollectionID, srcNode)
channels := s.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(srcNode))
toBalance := typeutil.NewSet[*meta.DmChannel]()
if req.GetTransferAll() {
toBalance.Insert(channels...)
@ -421,8 +421,8 @@ func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.Ch
}
// check channel list
channelOnSrc := s.dist.ChannelDistManager.GetByNode(req.GetSourceNodeID())
channelOnDst := s.dist.ChannelDistManager.GetByNode(req.GetTargetNodeID())
channelOnSrc := s.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(req.GetSourceNodeID()))
channelOnDst := s.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(req.GetTargetNodeID()))
channelDstMap := lo.SliceToMap(channelOnDst, func(ch *meta.DmChannel) (string, *meta.DmChannel) {
return ch.GetChannelName(), ch
})