mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
enhance: Refactor channel dist manager interface (#31119)
issue: #31091 This PR add GetByFilter interface in channel dist manager, instead of all kind of get func --------- Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
parent
16d869c57e
commit
0944a1f790
@ -149,7 +149,7 @@ func (b *RowCountBasedBalancer) convertToNodeItemsByChannel(nodeIDs []int64) []*
|
||||
ret := make([]*nodeItem, 0, len(nodeIDs))
|
||||
for _, nodeInfo := range b.getNodes(nodeIDs) {
|
||||
node := nodeInfo.ID()
|
||||
channels := b.dist.ChannelDistManager.GetByNode(node)
|
||||
channels := b.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(node))
|
||||
|
||||
// more channel num, less priority
|
||||
nodeItem := newNodeItem(len(channels), node)
|
||||
@ -292,7 +292,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
|
||||
|
||||
segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool {
|
||||
// if the segment are redundant, skip it's balance for now
|
||||
return len(b.dist.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
|
||||
return len(b.dist.SegmentDistManager.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
|
||||
})
|
||||
|
||||
if len(nodesWithLessRow) == 0 || len(segmentsToMove) == 0 {
|
||||
@ -311,7 +311,7 @@ func (b *RowCountBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNode
|
||||
func (b *RowCountBasedBalancer) genStoppingChannelPlan(replica *meta.Replica, onlineNodes []int64, offlineNodes []int64) []ChannelAssignPlan {
|
||||
channelPlans := make([]ChannelAssignPlan, 0)
|
||||
for _, nodeID := range offlineNodes {
|
||||
dmChannels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)
|
||||
dmChannels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID))
|
||||
plans := b.AssignChannel(dmChannels, onlineNodes, false)
|
||||
for i := range plans {
|
||||
plans[i].From = nodeID
|
||||
@ -326,7 +326,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
|
||||
channelPlans := make([]ChannelAssignPlan, 0)
|
||||
if len(onlineNodes) > 1 {
|
||||
// start to balance channels on all available nodes
|
||||
channelDist := b.dist.ChannelDistManager.GetChannelDistByReplica(replica)
|
||||
channelDist := b.dist.ChannelDistManager.GetByFilter(meta.WithReplica2Channel(replica))
|
||||
if len(channelDist) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -336,7 +336,7 @@ func (b *RowCountBasedBalancer) genChannelPlan(replica *meta.Replica, onlineNode
|
||||
nodeWithLessChannel := make([]int64, 0)
|
||||
channelsToMove := make([]*meta.DmChannel, 0)
|
||||
for _, node := range onlineNodes {
|
||||
channels := b.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), node)
|
||||
channels := b.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))
|
||||
|
||||
if len(channels) <= average {
|
||||
nodeWithLessChannel = append(nodeWithLessChannel, node)
|
||||
|
||||
@ -315,7 +315,7 @@ func (b *ScoreBasedBalancer) genSegmentPlan(replica *meta.Replica, onlineNodes [
|
||||
|
||||
// if the segment are redundant, skip it's balance for now
|
||||
segmentsToMove = lo.Filter(segmentsToMove, func(s *meta.Segment, _ int) bool {
|
||||
return len(b.dist.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
|
||||
return len(b.dist.SegmentDistManager.GetByFilter(meta.WithReplica(replica), meta.WithSegmentID(s.GetID()))) == 1
|
||||
})
|
||||
|
||||
if len(segmentsToMove) == 0 {
|
||||
|
||||
@ -176,7 +176,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica,
|
||||
// 3. print stopping nodes channel distribution
|
||||
distInfo += "[stoppingNodesChannelDist:"
|
||||
for stoppingNodeID := range stoppingNodesSegments {
|
||||
stoppingNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), stoppingNodeID)
|
||||
stoppingNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(stoppingNodeID))
|
||||
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", stoppingNodeID, len(stoppingNodeChannels))
|
||||
distInfo += "channels:["
|
||||
for _, stoppingChan := range stoppingNodeChannels {
|
||||
@ -189,7 +189,7 @@ func PrintCurrentReplicaDist(replica *meta.Replica,
|
||||
// 4. print normal nodes channel distribution
|
||||
distInfo += "[normalNodesChannelDist:"
|
||||
for normalNodeID := range nodeSegments {
|
||||
normalNodeChannels := channelManager.GetByCollectionAndNode(replica.GetCollectionID(), normalNodeID)
|
||||
normalNodeChannels := channelManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(normalNodeID))
|
||||
distInfo += fmt.Sprintf("[nodeID:%d, count:%d,", normalNodeID, len(normalNodeChannels))
|
||||
distInfo += "channels:["
|
||||
for _, normalNodeChan := range normalNodeChannels {
|
||||
|
||||
@ -91,7 +91,7 @@ func (c *ChannelChecker) Check(ctx context.Context) []task.Task {
|
||||
}
|
||||
}
|
||||
|
||||
channels := c.dist.ChannelDistManager.GetAll()
|
||||
channels := c.dist.ChannelDistManager.GetByFilter()
|
||||
released := utils.FilterReleased(channels, collectionIDs)
|
||||
releaseTasks := c.createChannelReduceTasks(ctx, released, meta.NilReplica)
|
||||
task.SetReason("collection released", releaseTasks...)
|
||||
@ -163,7 +163,7 @@ func (c *ChannelChecker) getDmChannelDiff(collectionID int64,
|
||||
func (c *ChannelChecker) getChannelDist(replica *meta.Replica) []*meta.DmChannel {
|
||||
dist := make([]*meta.DmChannel, 0)
|
||||
for _, nodeID := range replica.GetNodes() {
|
||||
dist = append(dist, c.dist.ChannelDistManager.GetByCollectionAndNode(replica.GetCollectionID(), nodeID)...)
|
||||
dist = append(dist, c.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(nodeID))...)
|
||||
}
|
||||
return dist
|
||||
}
|
||||
|
||||
@ -95,7 +95,7 @@ func (c *SegmentChecker) Check(ctx context.Context) []task.Task {
|
||||
}
|
||||
|
||||
// find already released segments which are not contained in target
|
||||
segments := c.dist.SegmentDistManager.GetByFilter(nil)
|
||||
segments := c.dist.SegmentDistManager.GetByFilter()
|
||||
released := utils.FilterReleased(segments, collectionIDs)
|
||||
reduceTasks := c.createSegmentReduceTasks(ctx, released, meta.NilReplica, querypb.DataScope_Historical)
|
||||
task.SetReason("collection released", reduceTasks...)
|
||||
@ -150,7 +150,6 @@ func (c *SegmentChecker) getGrowingSegmentDiff(collectionID int64,
|
||||
zap.Int64("replicaID", replica.ID))
|
||||
|
||||
leaders := c.dist.ChannelDistManager.GetShardLeadersByReplica(replica)
|
||||
// distMgr.LeaderViewManager.
|
||||
for channelName, node := range leaders {
|
||||
view := c.dist.LeaderViewManager.GetLeaderShardView(node, channelName)
|
||||
if view == nil {
|
||||
|
||||
2
internal/querycoordv2/dist/dist_handler.go
vendored
2
internal/querycoordv2/dist/dist_handler.go
vendored
@ -212,7 +212,7 @@ func (dh *distHandler) getDistribution(ctx context.Context) (*querypb.GetDataDis
|
||||
defer dh.mu.Unlock()
|
||||
|
||||
channels := make(map[string]*msgpb.MsgPosition)
|
||||
for _, channel := range dh.dist.ChannelDistManager.GetByNode(dh.nodeID) {
|
||||
for _, channel := range dh.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(dh.nodeID)) {
|
||||
targetChannel := dh.target.GetDmChannel(channel.GetCollectionID(), channel.GetChannelName(), meta.CurrentTarget)
|
||||
if targetChannel == nil {
|
||||
continue
|
||||
|
||||
@ -49,7 +49,7 @@ func waitCollectionReleased(dist *meta.DistributionManager, checkerController *c
|
||||
return partitionSet.Contain(segment.GetPartitionID())
|
||||
})
|
||||
} else {
|
||||
channels = dist.ChannelDistManager.GetByCollection(collection)
|
||||
channels = dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(collection))
|
||||
}
|
||||
|
||||
if len(channels)+len(segments) == 0 {
|
||||
|
||||
@ -25,6 +25,32 @@ import (
|
||||
. "github.com/milvus-io/milvus/pkg/util/typeutil"
|
||||
)
|
||||
|
||||
type ChannelDistFilter = func(ch *DmChannel) bool
|
||||
|
||||
func WithCollectionID2Channel(collectionID int64) ChannelDistFilter {
|
||||
return func(ch *DmChannel) bool {
|
||||
return ch.GetCollectionID() == collectionID
|
||||
}
|
||||
}
|
||||
|
||||
func WithNodeID2Channel(nodeID int64) ChannelDistFilter {
|
||||
return func(ch *DmChannel) bool {
|
||||
return ch.Node == nodeID
|
||||
}
|
||||
}
|
||||
|
||||
func WithReplica2Channel(replica *Replica) ChannelDistFilter {
|
||||
return func(ch *DmChannel) bool {
|
||||
return ch.GetCollectionID() == replica.GetCollectionID() && replica.Contains(ch.Node)
|
||||
}
|
||||
}
|
||||
|
||||
func WithChannelName2Channel(channelName string) ChannelDistFilter {
|
||||
return func(ch *DmChannel) bool {
|
||||
return ch.GetChannelName() == channelName
|
||||
}
|
||||
}
|
||||
|
||||
type DmChannel struct {
|
||||
*datapb.VchannelInfo
|
||||
Node int64
|
||||
@ -58,33 +84,7 @@ func NewChannelDistManager() *ChannelDistManager {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ChannelDistManager) GetByNode(nodeID UniqueID) []*DmChannel {
|
||||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
|
||||
return m.getByNode(nodeID)
|
||||
}
|
||||
|
||||
func (m *ChannelDistManager) getByNode(nodeID UniqueID) []*DmChannel {
|
||||
channels, ok := m.channels[nodeID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return channels
|
||||
}
|
||||
|
||||
func (m *ChannelDistManager) GetAll() []*DmChannel {
|
||||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
|
||||
result := make([]*DmChannel, 0)
|
||||
for _, channels := range m.channels {
|
||||
result = append(result, channels...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// todo by liuwei: should consider the case of duplicate leader exists
|
||||
// GetShardLeader returns the node whthin the given replicaNodes and subscribing the given shard,
|
||||
// returns (0, false) if not found.
|
||||
func (m *ChannelDistManager) GetShardLeader(replica *Replica, shard string) (int64, bool) {
|
||||
@ -103,6 +103,7 @@ func (m *ChannelDistManager) GetShardLeader(replica *Replica, shard string) (int
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// todo by liuwei: should consider the case of duplicate leader exists
|
||||
func (m *ChannelDistManager) GetShardLeadersByReplica(replica *Replica) map[string]int64 {
|
||||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
@ -119,35 +120,25 @@ func (m *ChannelDistManager) GetShardLeadersByReplica(replica *Replica) map[stri
|
||||
return ret
|
||||
}
|
||||
|
||||
func (m *ChannelDistManager) GetChannelDistByReplica(replica *Replica) map[string][]int64 {
|
||||
// return all channels in list which match all given filters
|
||||
func (m *ChannelDistManager) GetByFilter(filters ...ChannelDistFilter) []*DmChannel {
|
||||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
|
||||
ret := make(map[string][]int64)
|
||||
for _, node := range replica.GetNodes() {
|
||||
channels := m.channels[node]
|
||||
for _, dmc := range channels {
|
||||
if dmc.GetCollectionID() == replica.GetCollectionID() {
|
||||
channelName := dmc.GetChannelName()
|
||||
_, ok := ret[channelName]
|
||||
if !ok {
|
||||
ret[channelName] = make([]int64, 0)
|
||||
}
|
||||
ret[channelName] = append(ret[channelName], node)
|
||||
mergedFilters := func(ch *DmChannel) bool {
|
||||
for _, fn := range filters {
|
||||
if fn != nil && !fn(ch) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (m *ChannelDistManager) GetByCollection(collectionID UniqueID) []*DmChannel {
|
||||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
return true
|
||||
}
|
||||
|
||||
ret := make([]*DmChannel, 0)
|
||||
for _, channels := range m.channels {
|
||||
for _, channel := range channels {
|
||||
if channel.CollectionID == collectionID {
|
||||
if mergedFilters(channel) {
|
||||
ret = append(ret, channel)
|
||||
}
|
||||
}
|
||||
@ -155,20 +146,6 @@ func (m *ChannelDistManager) GetByCollection(collectionID UniqueID) []*DmChannel
|
||||
return ret
|
||||
}
|
||||
|
||||
func (m *ChannelDistManager) GetByCollectionAndNode(collectionID, nodeID UniqueID) []*DmChannel {
|
||||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
|
||||
channels := make([]*DmChannel, 0)
|
||||
for _, channel := range m.getByNode(nodeID) {
|
||||
if channel.CollectionID == collectionID {
|
||||
channels = append(channels, channel)
|
||||
}
|
||||
}
|
||||
|
||||
return channels
|
||||
}
|
||||
|
||||
func (m *ChannelDistManager) Update(nodeID UniqueID, channels ...*DmChannel) {
|
||||
m.rwmutex.Lock()
|
||||
defer m.rwmutex.Unlock()
|
||||
|
||||
@ -66,36 +66,36 @@ func (suite *ChannelDistManagerSuite) TestGetBy() {
|
||||
dist := suite.dist
|
||||
|
||||
// Test GetAll
|
||||
channels := dist.GetAll()
|
||||
channels := dist.GetByFilter(nil)
|
||||
suite.Len(channels, 4)
|
||||
|
||||
// Test GetByNode
|
||||
for _, node := range suite.nodes {
|
||||
channels := dist.GetByNode(node)
|
||||
channels := dist.GetByFilter(WithNodeID2Channel(node))
|
||||
suite.AssertNode(channels, node)
|
||||
}
|
||||
|
||||
// Test GetByCollection
|
||||
channels = dist.GetByCollection(suite.collection)
|
||||
channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection))
|
||||
suite.Len(channels, 4)
|
||||
suite.AssertCollection(channels, suite.collection)
|
||||
channels = dist.GetByCollection(-1)
|
||||
channels = dist.GetByFilter(WithCollectionID2Channel(-1))
|
||||
suite.Len(channels, 0)
|
||||
|
||||
// Test GetByNodeAndCollection
|
||||
// 1. Valid node and valid collection
|
||||
for _, node := range suite.nodes {
|
||||
channels := dist.GetByCollectionAndNode(suite.collection, node)
|
||||
channels := dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(node))
|
||||
suite.AssertNode(channels, node)
|
||||
suite.AssertCollection(channels, suite.collection)
|
||||
}
|
||||
|
||||
// 2. Valid node and invalid collection
|
||||
channels = dist.GetByCollectionAndNode(-1, suite.nodes[1])
|
||||
channels = dist.GetByFilter(WithCollectionID2Channel(-1), WithNodeID2Channel(suite.nodes[1]))
|
||||
suite.Len(channels, 0)
|
||||
|
||||
// 3. Invalid node and valid collection
|
||||
channels = dist.GetByCollectionAndNode(suite.collection, -1)
|
||||
channels = dist.GetByFilter(WithCollectionID2Channel(suite.collection), WithNodeID2Channel(-1))
|
||||
suite.Len(channels, 0)
|
||||
}
|
||||
|
||||
@ -148,47 +148,6 @@ func (suite *ChannelDistManagerSuite) TestGetShardLeader() {
|
||||
suite.Equal(leaders["dmc1"], suite.nodes[1])
|
||||
}
|
||||
|
||||
func (suite *ChannelDistManagerSuite) TestGetChannelDistByReplica() {
|
||||
replica := NewReplica(
|
||||
&querypb.Replica{
|
||||
CollectionID: suite.collection,
|
||||
},
|
||||
typeutil.NewUniqueSet(11, 22, 33),
|
||||
)
|
||||
|
||||
ch1 := &DmChannel{
|
||||
VchannelInfo: &datapb.VchannelInfo{
|
||||
CollectionID: suite.collection,
|
||||
ChannelName: "test-channel1",
|
||||
},
|
||||
Node: 11,
|
||||
Version: 1,
|
||||
}
|
||||
ch2 := &DmChannel{
|
||||
VchannelInfo: &datapb.VchannelInfo{
|
||||
CollectionID: suite.collection,
|
||||
ChannelName: "test-channel1",
|
||||
},
|
||||
Node: 22,
|
||||
Version: 1,
|
||||
}
|
||||
ch3 := &DmChannel{
|
||||
VchannelInfo: &datapb.VchannelInfo{
|
||||
CollectionID: suite.collection,
|
||||
ChannelName: "test-channel2",
|
||||
},
|
||||
Node: 33,
|
||||
Version: 1,
|
||||
}
|
||||
suite.dist.Update(11, ch1)
|
||||
suite.dist.Update(22, ch2)
|
||||
suite.dist.Update(33, ch3)
|
||||
|
||||
dist := suite.dist.GetChannelDistByReplica(replica)
|
||||
suite.Len(dist["test-channel1"], 2)
|
||||
suite.Len(dist["test-channel2"], 1)
|
||||
}
|
||||
|
||||
func (suite *ChannelDistManagerSuite) AssertNames(channels []*DmChannel, names ...string) bool {
|
||||
for _, channel := range channels {
|
||||
hasChannel := false
|
||||
|
||||
@ -108,16 +108,19 @@ func (m *SegmentDistManager) GetByFilter(filters ...SegmentDistFilter) []*Segmen
|
||||
m.rwmutex.RLock()
|
||||
defer m.rwmutex.RUnlock()
|
||||
|
||||
mergedFilters := func(s *Segment) bool {
|
||||
for _, f := range filters {
|
||||
if f != nil && !f(s) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
ret := make([]*Segment, 0)
|
||||
for _, segments := range m.segments {
|
||||
for _, segment := range segments {
|
||||
allMatch := true
|
||||
for _, f := range filters {
|
||||
if f != nil && !f(segment) {
|
||||
allMatch = false
|
||||
}
|
||||
}
|
||||
if allMatch {
|
||||
if mergedFilters(segment) {
|
||||
ret = append(ret, segment)
|
||||
}
|
||||
}
|
||||
|
||||
@ -98,8 +98,8 @@ func (ob *ReplicaObserver) checkNodesInReplica() {
|
||||
)
|
||||
|
||||
for node := range outboundNodes {
|
||||
channels := ob.distMgr.ChannelDistManager.GetByCollectionAndNode(collectionID, node)
|
||||
segments := ob.distMgr.SegmentDistManager.GetByFilter(meta.WithCollectionID(collectionID), meta.WithNodeID(node))
|
||||
channels := ob.distMgr.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(node))
|
||||
|
||||
if len(channels) == 0 && len(segments) == 0 {
|
||||
replica.RemoveNode(node)
|
||||
|
||||
@ -147,7 +147,7 @@ func (s *Server) GetQueryNodeDistribution(ctx context.Context, req *querypb.GetQ
|
||||
}
|
||||
|
||||
segments := s.dist.SegmentDistManager.GetByFilter(meta.WithNodeID(req.GetNodeID()))
|
||||
channels := s.dist.ChannelDistManager.GetByNode(req.NodeID)
|
||||
channels := s.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(req.GetNodeID()))
|
||||
return &querypb.GetQueryNodeDistributionResponse{
|
||||
Status: merr.Success(),
|
||||
ChannelNames: lo.Map(channels, func(c *meta.DmChannel, _ int) string { return c.GetChannelName() }),
|
||||
@ -364,7 +364,7 @@ func (s *Server) TransferChannel(ctx context.Context, req *querypb.TransferChann
|
||||
dstNodeSet.Remove(srcNode)
|
||||
|
||||
// check sealed segment list
|
||||
channels := s.dist.ChannelDistManager.GetByCollectionAndNode(replica.CollectionID, srcNode)
|
||||
channels := s.dist.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(replica.GetCollectionID()), meta.WithNodeID2Channel(srcNode))
|
||||
toBalance := typeutil.NewSet[*meta.DmChannel]()
|
||||
if req.GetTransferAll() {
|
||||
toBalance.Insert(channels...)
|
||||
@ -421,8 +421,8 @@ func (s *Server) CheckQueryNodeDistribution(ctx context.Context, req *querypb.Ch
|
||||
}
|
||||
|
||||
// check channel list
|
||||
channelOnSrc := s.dist.ChannelDistManager.GetByNode(req.GetSourceNodeID())
|
||||
channelOnDst := s.dist.ChannelDistManager.GetByNode(req.GetTargetNodeID())
|
||||
channelOnSrc := s.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(req.GetSourceNodeID()))
|
||||
channelOnDst := s.dist.ChannelDistManager.GetByFilter(meta.WithNodeID2Channel(req.GetTargetNodeID()))
|
||||
channelDstMap := lo.SliceToMap(channelOnDst, func(ch *meta.DmChannel) (string, *meta.DmChannel) {
|
||||
return ch.GetChannelName(), ch
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user