// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package shardclient

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"go.uber.org/atomic"

	"github.com/milvus-io/milvus/internal/mocks"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

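// TestShardClientMgr covers the basic manager lifecycle: create a manager with
// a stubbed client creator, fetch a client for a node, and verify that Close
// releases every cached client.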
func TestShardClientMgr(t *testing.T) {
	ctx := context.Background()
	nodeInfo := NodeInfo{
		NodeID: 1,
	}

	qn := mocks.NewMockQueryNodeClient(t)
	qn.EXPECT().Close().Return(nil)
	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
		return qn, nil
	}

	mixcoord := mocks.NewMockMixCoordClient(t)

	mgr := NewShardClientMgr(mixcoord)
	mgr.SetClientCreatorFunc(creator)
	_, err := mgr.GetClient(ctx, nodeInfo)
	assert.Nil(t, err)

	mgr.Close()
	assert.Equal(t, mgr.clients.Len(), 0)
}

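// TestShardClient checks lazy initialization of a single shard client: no
// underlying connections are created until getClient is called, which builds a
// pool of QueryNodePoolingSize clients. It also checks the Close semantics
// asserted below: Close(false) reports false (the client stays open) while
// Close(true) reports true.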
func TestShardClient(t *testing.T) {
	nodeInfo := NodeInfo{
		NodeID: 1,
	}

	qn := mocks.NewMockQueryNodeClient(t)
	qn.EXPECT().Close().Return(nil).Maybe()
	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
		return qn, nil
	}
	shardClient := newShardClient(nodeInfo, creator, 3*time.Second)
	assert.Equal(t, len(shardClient.clients), 0)
	assert.Equal(t, false, shardClient.initialized.Load())
	assert.Equal(t, false, shardClient.isClosed)

	ctx := context.Background()
	_, err := shardClient.getClient(ctx)
	assert.Nil(t, err)
	assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt())

	// test close
	closed := shardClient.Close(false)
	assert.False(t, closed)
	closed = shardClient.Close(true)
	assert.True(t, closed)
}

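// TestPurgeClient exercises the background purge loop: a client is expected to
// survive while it is still referenced by the shard leader cache or while its
// lastActiveTs is within expiredDuration, and to be purged only once the cache
// entry is deprecated and the client has been idle longer than expiredDuration.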
func TestPurgeClient(t *testing.T) {
	node := NodeInfo{
		NodeID: 1,
	}

	returnEmptyResult := atomic.NewBool(false)

	qn := mocks.NewMockQueryNodeClient(t)
	qn.EXPECT().Close().Return(nil).Maybe()
	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
		return qn, nil
	}

	s := &shardClientMgrImpl{
		clients:         typeutil.NewConcurrentMap[UniqueID, *shardClient](),
		clientCreator:   creator,
		closeCh:         make(chan struct{}),
		purgeInterval:   1 * time.Second,
		expiredDuration: 3 * time.Second,

		collLeader: map[string]map[string]*shardLeaders{
			"default": {
				"test": {
					idx:          atomic.NewInt64(0),
					collectionID: 1,
					shardLeaders: map[string][]NodeInfo{
						"0": {node},
					},
				},
			},
		},
	}

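	// The collLeader entry above keeps node 1 referenced, so the purge loop is
	// expected to keep its client alive even after lastActiveTs expires; only
	// after DeprecateShardCache drops the reference should the client be purged.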
	go s.PurgeClient()
	defer s.Close()
	_, err := s.GetClient(context.Background(), node)
	assert.Nil(t, err)
	qnClient, ok := s.clients.Get(1)
	assert.True(t, ok)
	assert.True(t, qnClient.lastActiveTs.Load() > 0)

	time.Sleep(2 * time.Second)
	// the client should not be purged before expiredDuration elapses
	assert.Equal(t, s.clients.Len(), 1)
	assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() >= 2*time.Second.Nanoseconds())

	_, err = s.GetClient(context.Background(), node)
	assert.Nil(t, err)
	time.Sleep(2 * time.Second)
	// GetClient should refresh lastActiveTs, so the client should not be purged
	assert.Equal(t, s.clients.Len(), 1)
	assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() < 3*time.Second.Nanoseconds())

	time.Sleep(2 * time.Second)
	// the client has exceeded expiredDuration but is still referenced by the
	// shard leader cache, so it should not be purged yet
	assert.Equal(t, s.clients.Len(), 1)
	assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() > 3*time.Second.Nanoseconds())

	s.DeprecateShardCache("default", "test")
	returnEmptyResult.Store(true)
	time.Sleep(2 * time.Second)
	// once the client is removed from the shard leader cache, it should be purged
	assert.Eventually(t, func() bool {
		return s.clients.Len() == 0
	}, 10*time.Second, 1*time.Second)
}

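// TestDeprecateShardCache verifies that DeprecateShardCache removes the cached
// shard leaders of a single (database, collection) pair, leaves other
// collections untouched, and drops the database entry once its last collection
// is cleared.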
func TestDeprecateShardCache(t *testing.T) {
	node := NodeInfo{
		NodeID: 1,
	}

	qn := mocks.NewMockQueryNodeClient(t)
	qn.EXPECT().Close().Return(nil).Maybe()
	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
		return qn, nil
	}

	mixcoord := mocks.NewMockMixCoordClient(t)
	mgr := NewShardClientMgr(mixcoord)
	mgr.SetClientCreatorFunc(creator)

	t.Run("Clear with no collection info", func(t *testing.T) {
		mgr.DeprecateShardCache("default", "collection_not_exist")
		// Should not panic or error
	})

	t.Run("Clear existing collection from cache", func(t *testing.T) {
		// Add a collection to the cache first
		mgr.leaderMut.Lock()
		mgr.collLeader = map[string]map[string]*shardLeaders{
			"default": {
				"test_collection": {
					idx:          atomic.NewInt64(0),
					collectionID: 100,
					shardLeaders: map[string][]NodeInfo{
						"channel-1": {node},
					},
				},
			},
		}
		mgr.leaderMut.Unlock()

		mgr.DeprecateShardCache("default", "test_collection")

		// Verify the cache entry is cleared
		mgr.leaderMut.RLock()
		defer mgr.leaderMut.RUnlock()
		_, exists := mgr.collLeader["default"]["test_collection"]
		assert.False(t, exists)
	})

	t.Run("Clear one collection, keep others", func(t *testing.T) {
		mgr.leaderMut.Lock()
		mgr.collLeader = map[string]map[string]*shardLeaders{
			"default": {
				"collection1": {
					idx:          atomic.NewInt64(0),
					collectionID: 100,
					shardLeaders: map[string][]NodeInfo{
						"channel-1": {node},
					},
				},
				"collection2": {
					idx:          atomic.NewInt64(0),
					collectionID: 101,
					shardLeaders: map[string][]NodeInfo{
						"channel-2": {node},
					},
				},
			},
		}
		mgr.leaderMut.Unlock()

		mgr.DeprecateShardCache("default", "collection1")

		// Verify collection1 is cleared but collection2 remains
		mgr.leaderMut.RLock()
		defer mgr.leaderMut.RUnlock()
		_, exists := mgr.collLeader["default"]["collection1"]
		assert.False(t, exists)
		_, exists = mgr.collLeader["default"]["collection2"]
		assert.True(t, exists)
	})

	t.Run("Clear last collection in database removes database", func(t *testing.T) {
		mgr.leaderMut.Lock()
		mgr.collLeader = map[string]map[string]*shardLeaders{
			"test_db": {
				"last_collection": {
					idx:          atomic.NewInt64(0),
					collectionID: 200,
					shardLeaders: map[string][]NodeInfo{
						"channel-1": {node},
					},
				},
			},
		}
		mgr.leaderMut.Unlock()

		mgr.DeprecateShardCache("test_db", "last_collection")

		// Verify the database entry is also removed
		mgr.leaderMut.RLock()
		defer mgr.leaderMut.RUnlock()
		_, exists := mgr.collLeader["test_db"]
		assert.False(t, exists)
	})

	mgr.Close()
}

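// TestInvalidateShardLeaderCache verifies invalidation by collection ID: only
// matching collections are dropped, unknown IDs are ignored, empty database
// entries are removed, and a collection ID shared across databases is
// invalidated in every database that contains it.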
func TestInvalidateShardLeaderCache(t *testing.T) {
	node := NodeInfo{
		NodeID: 1,
	}

	qn := mocks.NewMockQueryNodeClient(t)
	qn.EXPECT().Close().Return(nil).Maybe()
	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
		return qn, nil
	}

	mixcoord := mocks.NewMockMixCoordClient(t)
	mgr := NewShardClientMgr(mixcoord)
	mgr.SetClientCreatorFunc(creator)

	t.Run("Invalidate single collection", func(t *testing.T) {
		mgr.leaderMut.Lock()
		mgr.collLeader = map[string]map[string]*shardLeaders{
			"default": {
				"collection1": {
					idx:          atomic.NewInt64(0),
					collectionID: 100,
					shardLeaders: map[string][]NodeInfo{
						"channel-1": {node},
					},
				},
				"collection2": {
					idx:          atomic.NewInt64(0),
					collectionID: 101,
					shardLeaders: map[string][]NodeInfo{
						"channel-2": {node},
					},
				},
			},
		}
		mgr.leaderMut.Unlock()

		mgr.InvalidateShardLeaderCache([]int64{100})

		// Verify collection with ID 100 is removed, but 101 remains
		mgr.leaderMut.RLock()
		defer mgr.leaderMut.RUnlock()
		_, exists := mgr.collLeader["default"]["collection1"]
		assert.False(t, exists)
		_, exists = mgr.collLeader["default"]["collection2"]
		assert.True(t, exists)
	})

	t.Run("Invalidate multiple collections", func(t *testing.T) {
		mgr.leaderMut.Lock()
		mgr.collLeader = map[string]map[string]*shardLeaders{
			"default": {
				"collection1": {
					idx:          atomic.NewInt64(0),
					collectionID: 100,
					shardLeaders: map[string][]NodeInfo{
						"channel-1": {node},
					},
				},
				"collection2": {
					idx:          atomic.NewInt64(0),
					collectionID: 101,
					shardLeaders: map[string][]NodeInfo{
						"channel-2": {node},
					},
				},
				"collection3": {
					idx:          atomic.NewInt64(0),
					collectionID: 102,
					shardLeaders: map[string][]NodeInfo{
						"channel-3": {node},
					},
				},
			},
		}
		mgr.leaderMut.Unlock()

		mgr.InvalidateShardLeaderCache([]int64{100, 102})

		// Verify collections 100 and 102 are removed, but 101 remains
		mgr.leaderMut.RLock()
		defer mgr.leaderMut.RUnlock()
		_, exists := mgr.collLeader["default"]["collection1"]
		assert.False(t, exists)
		_, exists = mgr.collLeader["default"]["collection2"]
		assert.True(t, exists)
		_, exists = mgr.collLeader["default"]["collection3"]
		assert.False(t, exists)
	})

	t.Run("Invalidate non-existent collection", func(t *testing.T) {
		mgr.leaderMut.Lock()
		mgr.collLeader = map[string]map[string]*shardLeaders{
			"default": {
				"collection1": {
					idx:          atomic.NewInt64(0),
					collectionID: 100,
					shardLeaders: map[string][]NodeInfo{
						"channel-1": {node},
					},
				},
			},
		}
		mgr.leaderMut.Unlock()

		mgr.InvalidateShardLeaderCache([]int64{999})

		// Verify collection1 still exists
		mgr.leaderMut.RLock()
		defer mgr.leaderMut.RUnlock()
		_, exists := mgr.collLeader["default"]["collection1"]
		assert.True(t, exists)
	})

	t.Run("Invalidate all collections in database removes database", func(t *testing.T) {
		mgr.leaderMut.Lock()
		mgr.collLeader = map[string]map[string]*shardLeaders{
			"test_db": {
				"collection1": {
					idx:          atomic.NewInt64(0),
					collectionID: 200,
					shardLeaders: map[string][]NodeInfo{
						"channel-1": {node},
					},
				},
				"collection2": {
					idx:          atomic.NewInt64(0),
					collectionID: 201,
					shardLeaders: map[string][]NodeInfo{
						"channel-2": {node},
					},
				},
			},
		}
		mgr.leaderMut.Unlock()

		mgr.InvalidateShardLeaderCache([]int64{200, 201})

		// Verify database is removed
		mgr.leaderMut.RLock()
		defer mgr.leaderMut.RUnlock()
		_, exists := mgr.collLeader["test_db"]
		assert.False(t, exists)
	})

	t.Run("Invalidate across multiple databases", func(t *testing.T) {
		mgr.leaderMut.Lock()
		mgr.collLeader = map[string]map[string]*shardLeaders{
			"db1": {
				"collection1": {
					idx:          atomic.NewInt64(0),
					collectionID: 100,
					shardLeaders: map[string][]NodeInfo{
						"channel-1": {node},
					},
				},
			},
			"db2": {
				"collection2": {
					idx:          atomic.NewInt64(0),
					collectionID: 100, // Same collection ID in different database
					shardLeaders: map[string][]NodeInfo{
						"channel-2": {node},
					},
				},
			},
		}
		mgr.leaderMut.Unlock()

		mgr.InvalidateShardLeaderCache([]int64{100})

		// Verify collection is removed from both databases
		mgr.leaderMut.RLock()
		defer mgr.leaderMut.RUnlock()
		_, exists := mgr.collLeader["db1"]
		assert.False(t, exists) // db1 should be removed
		_, exists = mgr.collLeader["db2"]
		assert.False(t, exists) // db2 should be removed
	})

	mgr.Close()
}

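// TestShuffleShardLeaders verifies the reader's round-robin rotation: each
// GetReader call advances idx, and Shuffle places leaders[idx%len(leaders)]
// first while keeping every replica of every channel in the result.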
func TestShuffleShardLeaders(t *testing.T) {
	t.Run("Shuffle with multiple nodes", func(t *testing.T) {
		shards := map[string][]NodeInfo{
			"channel-1": {
				{NodeID: 1, Address: "localhost:9000", Serviceable: true},
				{NodeID: 2, Address: "localhost:9001", Serviceable: true},
				{NodeID: 3, Address: "localhost:9002", Serviceable: true},
			},
		}
		sl := &shardLeaders{
			idx:          atomic.NewInt64(5),
			collectionID: 100,
			shardLeaders: shards,
		}

		reader := sl.GetReader()
		result := reader.Shuffle()

		// Verify result has same channel
		assert.Len(t, result, 1)
		assert.Contains(t, result, "channel-1")

		// Verify all nodes are present
		assert.Len(t, result["channel-1"], 3)

		// Verify the first node is based on idx rotation (idx=6, 6%3=0, so nodeID 1 should be first)
		assert.Equal(t, int64(1), result["channel-1"][0].NodeID)

		// Verify all nodes are still present (shuffled)
		nodeIDs := make(map[int64]bool)
		for _, node := range result["channel-1"] {
			nodeIDs[node.NodeID] = true
		}
		assert.True(t, nodeIDs[1])
		assert.True(t, nodeIDs[2])
		assert.True(t, nodeIDs[3])
	})

	t.Run("Shuffle rotates first replica based on idx", func(t *testing.T) {
		shards := map[string][]NodeInfo{
			"channel-1": {
				{NodeID: 1, Address: "localhost:9000", Serviceable: true},
				{NodeID: 2, Address: "localhost:9001", Serviceable: true},
				{NodeID: 3, Address: "localhost:9002", Serviceable: true},
			},
		}
		sl := &shardLeaders{
			idx:          atomic.NewInt64(5),
			collectionID: 100,
			shardLeaders: shards,
		}

		// First read, idx will be 6 (5+1), 6%3=0, so first replica should be leaders[0] which is nodeID 1
		reader := sl.GetReader()
		result := reader.Shuffle()
		assert.Equal(t, int64(1), result["channel-1"][0].NodeID)

		// Second read, idx will be 7 (6+1), 7%3=1, so first replica should be leaders[1] which is nodeID 2
		reader = sl.GetReader()
		result = reader.Shuffle()
		assert.Equal(t, int64(2), result["channel-1"][0].NodeID)

		// Third read, idx will be 8 (7+1), 8%3=2, so first replica should be leaders[2] which is nodeID 3
		reader = sl.GetReader()
		result = reader.Shuffle()
		assert.Equal(t, int64(3), result["channel-1"][0].NodeID)
	})

	t.Run("Shuffle with single node", func(t *testing.T) {
		shards := map[string][]NodeInfo{
			"channel-1": {
				{NodeID: 1, Address: "localhost:9000", Serviceable: true},
			},
		}
		sl := &shardLeaders{
			idx:          atomic.NewInt64(0),
			collectionID: 100,
			shardLeaders: shards,
		}

		reader := sl.GetReader()
		result := reader.Shuffle()

		assert.Len(t, result["channel-1"], 1)
		assert.Equal(t, int64(1), result["channel-1"][0].NodeID)
	})

	t.Run("Shuffle with multiple channels", func(t *testing.T) {
		shards := map[string][]NodeInfo{
			"channel-1": {
				{NodeID: 1, Address: "localhost:9000", Serviceable: true},
				{NodeID: 2, Address: "localhost:9001", Serviceable: true},
			},
			"channel-2": {
				{NodeID: 3, Address: "localhost:9002", Serviceable: true},
				{NodeID: 4, Address: "localhost:9003", Serviceable: true},
			},
		}
		sl := &shardLeaders{
			idx:          atomic.NewInt64(0),
			collectionID: 100,
			shardLeaders: shards,
		}

		reader := sl.GetReader()
		result := reader.Shuffle()

		// Verify both channels are present
		assert.Len(t, result, 2)
		assert.Contains(t, result, "channel-1")
		assert.Contains(t, result, "channel-2")

		// Verify each channel has correct number of nodes
		assert.Len(t, result["channel-1"], 2)
		assert.Len(t, result["channel-2"], 2)
	})

	t.Run("Shuffle with empty leaders", func(t *testing.T) {
		shards := map[string][]NodeInfo{}
		sl := &shardLeaders{
			idx:          atomic.NewInt64(0),
			collectionID: 100,
			shardLeaders: shards,
		}

		reader := sl.GetReader()
		result := reader.Shuffle()

		assert.Len(t, result, 0)
	})
}

// func BenchmarkShardClientMgr(b *testing.B) {
// 	node := nodeInfo{
// 		nodeID: 1,
// 	}
// 	cache := NewMockCache(b)
// 	cache.EXPECT().ListShardLocation().Return(map[int64]nodeInfo{
// 		1: node,
// 	}).Maybe()
// 	globalMetaCache = cache
// 	qn := mocks.NewMockQueryNodeClient(b)
// 	qn.EXPECT().Close().Return(nil).Maybe()

// 	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
// 		return qn, nil
// 	}
// 	s := &shardClientMgrImpl{
// 		clients:         typeutil.NewConcurrentMap[UniqueID, *shardClient](),
// 		clientCreator:   creator,
// 		closeCh:         make(chan struct{}),
// 		purgeInterval:   1 * time.Second,
// 		expiredDuration: 10 * time.Second,
// 	}
// 	go s.PurgeClient()
// 	defer s.Close()

// 	b.ResetTimer()
// 	b.RunParallel(func(pb *testing.PB) {
// 		for pb.Next() {
// 			_, err := s.GetClient(context.Background(), node)
// 			assert.Nil(b, err)
// 		}
// 	})
// }