enhance: Enable dynamic update replica selection policy (#35860)

issue: #35859

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
wei liu 2024-09-13 17:05:15 +08:00 committed by GitHub
parent c03eb6f664
commit bd658a6510
5 changed files with 100 additions and 135 deletions
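The patch replaces the proxy's fixed, construction-time balancer with a per-request lookup: both balancers are built up front and kept in a map, and a getBalancer() closure re-reads proxy.replicaSelectionPolicy on every call, falling back to look_aside for unknown values. A minimal self-contained sketch of the pattern — the names here (balancer, policy, getBalancer) are illustrative stand-ins, not the Milvus types:

package main

import "fmt"

type balancer interface{ Name() string }

type lookAside struct{}

func (lookAside) Name() string { return "look_aside" }

type roundRobin struct{}

func (roundRobin) Name() string { return "round_robin" }

func main() {
	// Both policies are constructed once and kept alive for the process
	// lifetime, mirroring balancerMap in the patch.
	balancerMap := map[string]balancer{
		"look_aside":  lookAside{},
		"round_robin": roundRobin{},
	}

	policy := "look_aside" // stands in for the live proxy.replicaSelectionPolicy value

	// Resolved on every request instead of once at construction; unknown
	// values fall back to look_aside, as in the patch.
	getBalancer := func() balancer {
		if b, ok := balancerMap[policy]; ok {
			return b
		}
		return balancerMap["look_aside"]
	}

	fmt.Println(getBalancer().Name()) // look_aside
	policy = "round_robin"            // config updated at runtime
	fmt.Println(getBalancer().Name()) // round_robin
}

Because both balancers stay started for the proxy's lifetime, flipping the config key only changes which one subsequent requests route through — no proxy restart needed.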

internal/proxy/lb_policy.go

@@ -61,77 +61,75 @@ type LBPolicy interface {
 	Close()
 }
 
+const (
+	RoundRobin = "round_robin"
+	LookAside  = "look_aside"
+)
+
 type LBPolicyImpl struct {
-	balancer  LBBalancer
-	clientMgr shardClientMgr
+	getBalancer func() LBBalancer
+	clientMgr   shardClientMgr
+	balancerMap map[string]LBBalancer
 }
 
 func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
-	balancePolicy := params.Params.ProxyCfg.ReplicaSelectionPolicy.GetValue()
+	balancerMap := make(map[string]LBBalancer)
+	balancerMap[LookAside] = NewLookAsideBalancer(clientMgr)
+	balancerMap[RoundRobin] = NewRoundRobinBalancer()
 
-	var balancer LBBalancer
-	switch balancePolicy {
-	case "round_robin":
-		log.Info("use round_robin policy on replica selection")
-		balancer = NewRoundRobinBalancer()
-	default:
-		log.Info("use look_aside policy on replica selection")
-		balancer = NewLookAsideBalancer(clientMgr)
+	getBalancer := func() LBBalancer {
+		balancePolicy := params.Params.ProxyCfg.ReplicaSelectionPolicy.GetValue()
+		if _, ok := balancerMap[balancePolicy]; !ok {
+			return balancerMap[LookAside]
+		}
+		return balancerMap[balancePolicy]
 	}
 
 	return &LBPolicyImpl{
-		balancer:  balancer,
-		clientMgr: clientMgr,
+		getBalancer: getBalancer,
+		clientMgr:   clientMgr,
+		balancerMap: balancerMap,
 	}
 }
 
 func (lb *LBPolicyImpl) Start(ctx context.Context) {
-	lb.balancer.Start(ctx)
+	for _, lb := range lb.balancerMap {
+		lb.Start(ctx)
+	}
 }
 
 // try to select the best node from the available nodes
-func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
-	log := log.Ctx(ctx).With(
-		zap.Int64("collectionID", workload.collectionID),
-		zap.String("collectionName", workload.collectionName),
-		zap.String("channelName", workload.channel),
-	)
-
-	filterAvailableNodes := func(node int64, _ int) bool {
-		return !excludeNodes.Contain(node)
-	}
-
-	getShardLeaders := func() ([]int64, error) {
-		shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
-		if err != nil {
-			return nil, err
-		}
-
-		return lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID }), nil
-	}
-
-	availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes)
-	targetNode, err := lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
+func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
+	availableNodes := lo.FilterMap(workload.shardLeaders, func(node int64, _ int) (int64, bool) { return node, !excludeNodes.Contain(node) })
+	targetNode, err := balancer.SelectNode(ctx, availableNodes, workload.nq)
 	if err != nil {
+		log := log.Ctx(ctx)
 		globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
-		nodes, err := getShardLeaders()
-		if err != nil || len(nodes) == 0 {
+		shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
+		if err != nil {
 			log.Warn("failed to get shard delegator",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Error(err))
 			return -1, err
 		}
 
-		availableNodes := lo.Filter(nodes, filterAvailableNodes)
+		availableNodes := lo.FilterMap(shardLeaders[workload.channel], func(node nodeInfo, _ int) (int64, bool) { return node.nodeID, !excludeNodes.Contain(node.nodeID) })
 		if len(availableNodes) == 0 {
+			nodes := lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID })
 			log.Warn("no available shard delegator found",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64s("nodes", nodes),
 				zap.Int64s("excluded", excludeNodes.Collect()))
 			return -1, merr.WrapErrChannelNotAvailable("no available shard delegator found")
 		}
 
-		targetNode, err = lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
+		targetNode, err = balancer.SelectNode(ctx, availableNodes, workload.nq)
 		if err != nil {
 			log.Warn("failed to select shard",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64s("availableNodes", availableNodes),
 				zap.Error(err))
 			return -1, err
@@ -144,17 +142,15 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
 // ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
 func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
 	excludeNodes := typeutil.NewUniqueSet()
-	log := log.Ctx(ctx).With(
-		zap.Int64("collectionID", workload.collectionID),
-		zap.String("collectionName", workload.collectionName),
-		zap.String("channelName", workload.channel),
-	)
 
 	var lastErr error
 	err := retry.Do(ctx, func() error {
-		targetNode, err := lb.selectNode(ctx, workload, excludeNodes)
+		balancer := lb.getBalancer()
+		targetNode, err := lb.selectNode(ctx, balancer, workload, excludeNodes)
 		if err != nil {
 			log.Warn("failed to select node for shard",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64("nodeID", targetNode),
 				zap.Error(err),
 			)
@@ -163,16 +159,18 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
 			}
 			return err
 		}
 
+		// cancel work load which assign to the target node
+		defer balancer.CancelWorkload(targetNode, workload.nq)
 		client, err := lb.clientMgr.GetClient(ctx, targetNode)
 		if err != nil {
 			log.Warn("search/query channel failed, node not available",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64("nodeID", targetNode),
 				zap.Error(err))
 			excludeNodes.Insert(targetNode)
 
-			// cancel work load which assign to the target node
-			lb.balancer.CancelWorkload(targetNode, workload.nq)
 			lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode, workload.channel)
 			return lastErr
 		}
@@ -180,16 +178,15 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
 		err = workload.exec(ctx, targetNode, client, workload.channel)
 		if err != nil {
 			log.Warn("search/query channel failed",
+				zap.Int64("collectionID", workload.collectionID),
+				zap.String("channelName", workload.channel),
 				zap.Int64("nodeID", targetNode),
 				zap.Error(err))
 			excludeNodes.Insert(targetNode)
-			lb.balancer.CancelWorkload(targetNode, workload.nq)
 
 			lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode, workload.channel)
 			return lastErr
 		}
 
-		lb.balancer.CancelWorkload(targetNode, workload.nq)
 		return nil
 	}, retry.Attempts(workload.retryTimes))
@@ -232,9 +229,11 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
 }
 
 func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
-	lb.balancer.UpdateCostMetrics(node, cost)
+	lb.getBalancer().UpdateCostMetrics(node, cost)
 }
 
 func (lb *LBPolicyImpl) Close() {
-	lb.balancer.Close()
+	for _, lb := range lb.balancerMap {
+		lb.Close()
+	}
 }
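Worth noting in the hunks above: selectNode now receives the balancer explicitly instead of reading a struct field, and ExecuteWithRetry resolves it once per attempt. Selection and the deferred CancelWorkload therefore always hit the same instance, even if the configured policy flips between attempts. A self-contained sketch of that "resolve once, pin for the attempt" idea — the types below are illustrative stand-ins, not the Milvus API:

package main

import "fmt"

type balancer struct {
	name     string
	inflight int
}

func (b *balancer) selectNode() { b.inflight++ }
func (b *balancer) cancel()     { b.inflight-- }

// current stands in for the balancer named by the live config.
var current = &balancer{name: "look_aside"}

func getBalancer() *balancer { return current }

func attempt() {
	b := getBalancer() // pinned for the whole attempt
	b.selectNode()
	defer b.cancel() // pairs with the selection above, on the same instance

	// Config flips mid-attempt; the deferred cancel still targets b.
	current = &balancer{name: "round_robin"}
	fmt.Println(b.name, b.inflight) // look_aside 1
}

func main() {
	attempt()
	fmt.Println(current.name) // round_robin
}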

internal/proxy/lb_policy_test.go

@@ -101,7 +101,9 @@ func (s *LBPolicySuite) SetupTest() {
 	s.lbBalancer.EXPECT().Start(context.Background()).Maybe()
 	s.lbPolicy = NewLBPolicyImpl(s.mgr)
 	s.lbPolicy.Start(context.Background())
-	s.lbPolicy.balancer = s.lbBalancer
+	s.lbPolicy.getBalancer = func() LBBalancer {
+		return s.lbBalancer
+	}
 
 	err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr)
 	s.NoError(err)
@@ -163,7 +165,7 @@ func (s *LBPolicySuite) loadCollection() {
 func (s *LBPolicySuite) TestSelectNode() {
 	ctx := context.Background()
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
-	targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -178,7 +180,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 	s.lbBalancer.ExpectedCalls = nil
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
-	targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -192,7 +194,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 	// test select node always fails, expected failure
 	s.lbBalancer.ExpectedCalls = nil
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
-	targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -206,7 +208,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 	// test all nodes has been excluded, expected failure
 	s.lbBalancer.ExpectedCalls = nil
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
-	targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -222,7 +224,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 	s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
 	s.qc.ExpectedCalls = nil
 	s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable)
-	targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
+	targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
 		db:             dbName,
 		collectionName: s.collectionName,
 		collectionID:   s.collectionID,
@@ -419,17 +421,17 @@ func (s *LBPolicySuite) TestUpdateCostMetrics() {
 func (s *LBPolicySuite) TestNewLBPolicy() {
 	policy := NewLBPolicyImpl(s.mgr)
-	s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.LookAsideBalancer")
+	s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
 	policy.Close()
 
 	Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "round_robin")
 	policy = NewLBPolicyImpl(s.mgr)
-	s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.RoundRobinBalancer")
+	s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.RoundRobinBalancer")
 	policy.Close()
 
 	Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "look_aside")
 	policy = NewLBPolicyImpl(s.mgr)
-	s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.LookAsideBalancer")
+	s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
 	policy.Close()
 }

internal/proxy/meta_cache.go

@@ -952,11 +952,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
 		zap.String("collectionName", collectionName),
 		zap.Int64("collectionID", collectionID))
 
-	info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
-	if err != nil {
-		return nil, err
-	}
-
 	cacheShardLeaders, ok := m.getCollectionShardLeader(database, collectionName)
 	if withCache {
 		if ok {
@@ -968,6 +963,12 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
 		metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
 		log.Info("no shard cache for collection, try to get shard leaders from QueryCoord")
 	}
 
+	info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
+	if err != nil {
+		return nil, err
+	}
+
 	req := &querypb.GetShardLeadersRequest{
 		Base: commonpbutil.NewMsgBase(
 			commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
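The meta_cache.go change is a pure reordering: getFullCollectionInfo, which may call out to the coordinator, used to run before the shard-leader cache was consulted, so even a cache hit paid for it. Moving it below the cache check lets the hit path return straight from cache. A self-contained sketch of the resulting control flow, with hypothetical stand-ins for the MetaCache internals:

package main

import (
	"errors"
	"fmt"
)

// Hypothetical stand-ins for the MetaCache internals, for illustration only.
var shardCache = map[string][]string{"c1": {"cached-dml_0"}}

func cacheLookup(name string) ([]string, bool) { v, ok := shardCache[name]; return v, ok }

func fetchCollectionInfo(name string) (string, error) { // stands in for getFullCollectionInfo
	if name == "" {
		return "", errors.New("collection not found")
	}
	return name + "-id", nil
}

func fetchShardLeaders(collID string) ([]string, error) { // stands in for the GetShardLeaders RPC
	return []string{collID + "-dml_0"}, nil
}

// getShards mirrors the reordered flow: the cache is consulted first, and
// the collection-info fetch is only paid on a miss.
func getShards(withCache bool, name string) ([]string, error) {
	if withCache {
		if leaders, ok := cacheLookup(name); ok {
			return leaders, nil // hit: no collection-info fetch at all
		}
	}
	info, err := fetchCollectionInfo(name)
	if err != nil {
		return nil, err
	}
	return fetchShardLeaders(info)
}

func main() {
	leaders, _ := getShards(true, "c1")
	fmt.Println(leaders) // [cached-dml_0]
}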

internal/proxy/roundrobin_balancer.go

@@ -22,18 +22,14 @@ import (
 	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/pkg/util/merr"
-	"github.com/milvus-io/milvus/pkg/util/typeutil"
 )
 
 type RoundRobinBalancer struct {
-	// request num send to each node
-	nodeWorkload *typeutil.ConcurrentMap[int64, *atomic.Int64]
+	idx atomic.Int64
 }
 
 func NewRoundRobinBalancer() *RoundRobinBalancer {
-	return &RoundRobinBalancer{
-		nodeWorkload: typeutil.NewConcurrentMap[int64, *atomic.Int64](),
-	}
+	return &RoundRobinBalancer{}
 }
 
 func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) {
@@ -41,32 +37,11 @@ func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []in
 		return -1, merr.ErrNodeNotAvailable
 	}
 
-	targetNode := int64(-1)
-	var targetNodeWorkload *atomic.Int64
-	for _, node := range availableNodes {
-		workload, ok := b.nodeWorkload.Get(node)
-		if !ok {
-			workload = atomic.NewInt64(0)
-			b.nodeWorkload.Insert(node, workload)
-		}
-
-		if targetNodeWorkload == nil || workload.Load() < targetNodeWorkload.Load() {
-			targetNode = node
-			targetNodeWorkload = workload
-		}
-	}
-
-	targetNodeWorkload.Add(cost)
-	return targetNode, nil
+	idx := b.idx.Inc()
+	return availableNodes[int(idx)%len(availableNodes)], nil
 }
 
-func (b *RoundRobinBalancer) CancelWorkload(node int64, nq int64) {
-	load, ok := b.nodeWorkload.Get(node)
-	if ok {
-		load.Sub(nq)
-	}
-}
+func (b *RoundRobinBalancer) CancelWorkload(node int64, nq int64) {}
 
 func (b *RoundRobinBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {}
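The balancer itself also gets simpler: per-node workload accounting is dropped in favor of a single atomic counter taken modulo the candidate count, and CancelWorkload becomes a no-op. A self-contained sketch of the new selection step, using stdlib sync/atomic here rather than the go.uber.org/atomic package the repo imports:

package main

import (
	"fmt"
	"sync/atomic"
)

type roundRobin struct {
	idx atomic.Int64
}

func (b *roundRobin) selectNode(nodes []int64) int64 {
	// Increment-then-mod: each call advances the shared counter, spreading
	// consecutive requests across the candidate list without locks.
	idx := b.idx.Add(1)
	return nodes[int(idx)%len(nodes)]
}

func main() {
	b := &roundRobin{}
	nodes := []int64{7, 8, 9}
	for i := 0; i < 5; i++ {
		fmt.Println(b.selectNode(nodes)) // 8 9 7 8 9
	}
}

One trade-off of this design: since the index is reduced modulo len(availableNodes), the rotation is only strictly even while the candidate list stays stable, but selection becomes O(1), lock-free, and allocation-free.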

internal/proxy/roundrobin_balancer_test.go

@@ -20,6 +20,8 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/suite"
+
+	"github.com/milvus-io/milvus/pkg/util/merr"
 )
@@ -33,48 +35,34 @@ func (s *RoundRobinBalancerSuite) SetupTest() {
 	s.balancer.Start(context.Background())
 }
 
-func (s *RoundRobinBalancerSuite) TestRoundRobin() {
-	availableNodes := []int64{1, 2}
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
+func TestSelectNode(t *testing.T) {
+	balancer := NewRoundRobinBalancer()
 
-	workload, ok := s.balancer.nodeWorkload.Get(1)
-	s.True(ok)
-	s.Equal(int64(2), workload.Load())
-	workload, ok = s.balancer.nodeWorkload.Get(1)
-	s.True(ok)
-	s.Equal(int64(2), workload.Load())
+	// Test case 1: Empty availableNodes
+	_, err1 := balancer.SelectNode(context.Background(), []int64{}, 0)
+	if err1 != merr.ErrNodeNotAvailable {
+		t.Errorf("Expected ErrNodeNotAvailable, got %v", err1)
+	}
 
-	s.balancer.SelectNode(context.TODO(), availableNodes, 3)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.balancer.SelectNode(context.TODO(), availableNodes, 1)
+	// Test case 2: Non-empty availableNodes
+	availableNodes := []int64{1, 2, 3}
+	selectedNode2, err2 := balancer.SelectNode(context.Background(), availableNodes, 0)
+	if err2 != nil {
+		t.Errorf("Expected no error, got %v", err2)
+	}
+	if selectedNode2 < 1 || selectedNode2 > 3 {
+		t.Errorf("Expected a node in the range [1, 3], got %d", selectedNode2)
+	}
 
-	workload, ok = s.balancer.nodeWorkload.Get(1)
-	s.True(ok)
-	s.Equal(int64(5), workload.Load())
-	workload, ok = s.balancer.nodeWorkload.Get(1)
-	s.True(ok)
-	s.Equal(int64(5), workload.Load())
-}
-
-func (s *RoundRobinBalancerSuite) TestNoAvailableNode() {
-	availableNodes := []int64{}
-	_, err := s.balancer.SelectNode(context.TODO(), availableNodes, 1)
-	s.Error(err)
-}
-
-func (s *RoundRobinBalancerSuite) TestCancelWorkload() {
-	availableNodes := []int64{101}
-	_, err := s.balancer.SelectNode(context.TODO(), availableNodes, 5)
-	s.NoError(err)
-
-	workload, ok := s.balancer.nodeWorkload.Get(101)
-	s.True(ok)
-	s.Equal(int64(5), workload.Load())
-	s.balancer.CancelWorkload(101, 5)
-	s.Equal(int64(0), workload.Load())
+	// Test case 3: Boundary case
+	availableNodes = []int64{1}
+	selectedNode3, err3 := balancer.SelectNode(context.Background(), availableNodes, 0)
+	if err3 != nil {
+		t.Errorf("Expected no error, got %v", err3)
+	}
+	if selectedNode3 != 1 {
+		t.Errorf("Expected 1, got %d", selectedNode3)
+	}
 }
 
 func TestRoundRobinBalancerSuite(t *testing.T) {