diff --git a/internal/datacoord/channel_manager.go b/internal/datacoord/channel_manager.go index dda51dd2d0..3dca44910e 100644 --- a/internal/datacoord/channel_manager.go +++ b/internal/datacoord/channel_manager.go @@ -63,10 +63,9 @@ func defaultFactory(hash *consistent.Consistent) ChannelPolicyFactory { // NewChannelManager return a new ChannelManager func NewChannelManager(kv kv.TxnKV, posProvider positionProvider, options ...ChannelManagerOpt) (*ChannelManager, error) { - hashring := consistent.New() c := &ChannelManager{ posProvider: posProvider, - factory: defaultFactory(hashring), + factory: NewChannelPolicyFactoryV1(kv), store: NewChannelStore(kv), } diff --git a/internal/datacoord/channel_manager_factory.go b/internal/datacoord/channel_manager_factory.go index 5e2dfa18e0..c227bf4d45 100644 --- a/internal/datacoord/channel_manager_factory.go +++ b/internal/datacoord/channel_manager_factory.go @@ -47,7 +47,7 @@ func NewChannelPolicyFactoryV1(kv kv.TxnKV) *ChannelPolicyFactoryV1 { // NewRegisterPolicy implementing ChannelPolicyFactory returns BufferChannelAssignPolicy func (f *ChannelPolicyFactoryV1) NewRegisterPolicy() RegisterPolicy { - return BufferChannelAssignPolicy + return AvgAssignRegisterPolicy } // NewDeregisterPolicy implementing ChannelPolicyFactory returns AvgAssignUnregisteredChannels diff --git a/internal/datacoord/policy.go b/internal/datacoord/policy.go index bc834f245e..139558e87c 100644 --- a/internal/datacoord/policy.go +++ b/internal/datacoord/policy.go @@ -50,6 +50,63 @@ func BufferChannelAssignPolicy(store ROChannelStore, nodeID int64) ChannelOpSet return opSet } +func AvgAssignRegisterPolicy(store ROChannelStore, nodeID int64) ChannelOpSet { + opSet := BufferChannelAssignPolicy(store, nodeID) + if len(opSet) != 0 { + return opSet + } + + infos := store.GetNodesChannels() + infos = filterNode(infos, nodeID) + + channelNum := 0 + for _, info := range infos { + channelNum += len(info.Channels) + } + avg := channelNum / (len(store.GetNodes()) + 1) + if avg == 0 { + return nil + } + + // sort in descending order and reallocate + sort.Slice(infos, func(i, j int) bool { + return len(infos[i].Channels) > len(infos[j].Channels) + }) + + deletes := make(map[int64][]*channel) + adds := make(map[int64][]*channel) + for i := 0; i < avg; { + t := infos[i%len(infos)] + idx := i / len(infos) + if idx >= len(t.Channels) { + continue + } + deletes[t.NodeID] = append(deletes[t.NodeID], t.Channels[idx]) + adds[nodeID] = append(adds[nodeID], t.Channels[idx]) + i++ + } + + opSet = ChannelOpSet{} + for k, v := range deletes { + opSet.Delete(k, v) + } + for k, v := range adds { + opSet.Add(k, v) + } + return opSet +} + +func filterNode(infos []*NodeChannelInfo, nodeID int64) []*NodeChannelInfo { + filtered := make([]*NodeChannelInfo, 0) + for _, info := range infos { + if info.NodeID == nodeID { + continue + } + filtered = append(filtered, info) + } + return filtered +} + // ConsistentHashRegisterPolicy use a consistent hash to matain the mapping func ConsistentHashRegisterPolicy(hashring *consistent.Consistent) RegisterPolicy { return func(store ROChannelStore, nodeID int64) ChannelOpSet { diff --git a/internal/datacoord/policy_test.go b/internal/datacoord/policy_test.go index 93f5033439..a192d3f0fd 100644 --- a/internal/datacoord/policy_test.go +++ b/internal/datacoord/policy_test.go @@ -428,3 +428,121 @@ func TestBgCheckWithMaxWatchDuration(t *testing.T) { }) } } + +func TestAvgAssignRegisterPolicy(t *testing.T) { + type args struct { + store ROChannelStore + nodeID int64 + } + tests := []struct { + name string + args args + want ChannelOpSet + }{ + { + "test empty", + args{ + &ChannelStore{ + memkv.NewMemoryKV(), + map[int64]*NodeChannelInfo{}, + }, + 1, + }, + nil, + }, + { + "test with buffer channel", + args{ + &ChannelStore{ + memkv.NewMemoryKV(), + map[int64]*NodeChannelInfo{ + bufferID: {bufferID, []*channel{{"ch1", 1}}}, + }, + }, + 1, + }, + []*ChannelOp{ + { + Type: Delete, + NodeID: bufferID, + Channels: []*channel{{"ch1", 1}}, + }, + { + Type: Add, + NodeID: 1, + Channels: []*channel{{"ch1", 1}}, + }, + }, + }, + { + "test with avg assign", + args{ + &ChannelStore{ + memkv.NewMemoryKV(), + map[int64]*NodeChannelInfo{ + 1: {1, []*channel{{"ch1", 1}, {"ch2", 1}}}, + 2: {2, []*channel{{"ch3", 1}, {"ch4", 1}}}, + }, + }, + 3, + }, + []*ChannelOp{ + { + Type: Delete, + NodeID: 1, + Channels: []*channel{{"ch1", 1}}, + }, + { + Type: Add, + NodeID: 3, + Channels: []*channel{{"ch1", 1}}, + }, + }, + }, + { + "test with avg equals to zero", + args{ + &ChannelStore{ + memkv.NewMemoryKV(), + map[int64]*NodeChannelInfo{ + 1: {1, []*channel{{"ch1", 1}}}, + 2: {2, []*channel{{"ch3", 1}}}, + }, + }, + 3, + }, + nil, + }, + { + "test node with empty channel", + args{ + &ChannelStore{ + memkv.NewMemoryKV(), + map[int64]*NodeChannelInfo{ + 1: {1, []*channel{{"ch1", 1}, {"ch2", 1}, {"ch3", 1}}}, + 2: {2, []*channel{}}, + }, + }, + 3, + }, + []*ChannelOp{ + { + Type: Delete, + NodeID: 1, + Channels: []*channel{{"ch1", 1}}, + }, + { + Type: Add, + NodeID: 3, + Channels: []*channel{{"ch1", 1}}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := AvgAssignRegisterPolicy(tt.args.store, tt.args.nodeID) + assert.EqualValues(t, tt.want, got) + }) + } +}