mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
feat: Implement partial result support on node down (#42009)
issue: https://github.com/milvus-io/milvus/issues/41690 This commit implements partial search result functionality when query nodes go down, improving system availability during node failures. The changes include: - Enhanced load balancing in proxy (lb_policy.go) to handle node failures with retry support - Added partial search result capability in querynode delegator and distribution logic - Implemented tests for various partial result scenarios when nodes go down - Added metrics to track partial search results in querynode_metrics.go - Updated parameter configuration to support partial result required data ratio - Replaced old partial_search_test.go with more comprehensive partial_result_on_node_down_test.go - Updated proto definitions and improved retry logic These changes improve query resilience by returning partial results to users when some query nodes are unavailable, ensuring that queries don't completely fail when a portion of data remains accessible. --------- Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
parent
57b58ad778
commit
54619eaa2c
@ -329,6 +329,7 @@ proxy:
|
||||
slowQuerySpanInSeconds: 5 # query whose executed time exceeds the `slowQuerySpanInSeconds` can be considered slow, in seconds.
|
||||
queryNodePooling:
|
||||
size: 10 # the size for shardleader(querynode) client pool
|
||||
partialResultRequiredDataRatio: 1 # partial result required data ratio, default to 1 which means disable partial result, otherwise, it will be used as the minimum data ratio for partial result
|
||||
http:
|
||||
enabled: true # Whether to enable the http server
|
||||
debug_mode: false # Whether to enable http server debug mode
|
||||
|
||||
@ -17,6 +17,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/samber/lo"
|
||||
@ -119,22 +120,70 @@ func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, coll
|
||||
}
|
||||
|
||||
// try to select the best node from the available nodes
|
||||
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) {
|
||||
filterDelegator := func(nodes []nodeInfo) map[int64]nodeInfo {
|
||||
ret := make(map[int64]nodeInfo)
|
||||
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload *ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) {
|
||||
log := log.Ctx(ctx)
|
||||
// Select node using specified nodes
|
||||
trySelectNode := func(nodes []nodeInfo) (nodeInfo, error) {
|
||||
candidateNodes := make(map[int64]nodeInfo)
|
||||
serviceableNodes := make(map[int64]nodeInfo)
|
||||
// Filter nodes based on excludeNodes
|
||||
for _, node := range nodes {
|
||||
if !excludeNodes.Contain(node.nodeID) {
|
||||
ret[node.nodeID] = node
|
||||
if node.serviceable {
|
||||
serviceableNodes[node.nodeID] = node
|
||||
}
|
||||
candidateNodes[node.nodeID] = node
|
||||
}
|
||||
}
|
||||
return ret
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
candidatesInStr := lo.Map(nodes, func(node nodeInfo, _ int) string {
|
||||
return node.String()
|
||||
})
|
||||
serviceableNodesInStr := lo.Map(lo.Values(serviceableNodes), func(node nodeInfo, _ int) string {
|
||||
return node.String()
|
||||
})
|
||||
log.Warn("failed to select shard",
|
||||
zap.Int64("collectionID", workload.collectionID),
|
||||
zap.String("channelName", workload.channel),
|
||||
zap.Int64s("excluded", excludeNodes.Collect()),
|
||||
zap.String("candidates", strings.Join(candidatesInStr, ", ")),
|
||||
zap.String("serviceableNodes", strings.Join(serviceableNodesInStr, ", ")),
|
||||
zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
if len(candidateNodes) == 0 {
|
||||
err = merr.WrapErrChannelNotAvailable(workload.channel)
|
||||
return nodeInfo{}, err
|
||||
}
|
||||
|
||||
balancer.RegisterNodeInfo(lo.Values(candidateNodes))
|
||||
// prefer serviceable nodes
|
||||
var targetNodeID int64
|
||||
if len(serviceableNodes) > 0 {
|
||||
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(serviceableNodes), workload.nq)
|
||||
} else {
|
||||
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(candidateNodes), workload.nq)
|
||||
}
|
||||
if err != nil {
|
||||
return nodeInfo{}, err
|
||||
}
|
||||
|
||||
if _, ok := candidateNodes[targetNodeID]; !ok {
|
||||
err = merr.WrapErrNodeNotAvailable(targetNodeID)
|
||||
return nodeInfo{}, err
|
||||
}
|
||||
|
||||
return candidateNodes[targetNodeID], nil
|
||||
}
|
||||
|
||||
availableNodes := filterDelegator(workload.shardLeaders)
|
||||
balancer.RegisterNodeInfo(lo.Values(availableNodes))
|
||||
targetNode, err := balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq)
|
||||
// First attempt with current shard leaders
|
||||
targetNode, err := trySelectNode(workload.shardLeaders)
|
||||
// If failed, refresh cache and retry
|
||||
if err != nil {
|
||||
log := log.Ctx(ctx)
|
||||
globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
|
||||
shardLeaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, false)
|
||||
if err != nil {
|
||||
@ -145,51 +194,41 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
|
||||
return nodeInfo{}, err
|
||||
}
|
||||
|
||||
availableNodes = filterDelegator(shardLeaders[workload.channel])
|
||||
if len(availableNodes) == 0 {
|
||||
log.Warn("no available shard delegator found",
|
||||
zap.Int64("collectionID", workload.collectionID),
|
||||
zap.String("channelName", workload.channel),
|
||||
zap.Int64s("availableNodes", lo.Keys(availableNodes)),
|
||||
zap.Int64s("excluded", excludeNodes.Collect()))
|
||||
return nodeInfo{}, merr.WrapErrChannelNotAvailable("no available shard delegator found")
|
||||
}
|
||||
|
||||
balancer.RegisterNodeInfo(lo.Values(availableNodes))
|
||||
targetNode, err = balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq)
|
||||
workload.shardLeaders = shardLeaders[workload.channel]
|
||||
// Second attempt with fresh shard leaders
|
||||
targetNode, err = trySelectNode(workload.shardLeaders)
|
||||
if err != nil {
|
||||
log.Warn("failed to select shard",
|
||||
zap.Int64("collectionID", workload.collectionID),
|
||||
zap.String("channelName", workload.channel),
|
||||
zap.Int64s("availableNodes", lo.Keys(availableNodes)),
|
||||
zap.Int64s("excluded", excludeNodes.Collect()),
|
||||
zap.Error(err))
|
||||
return nodeInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return availableNodes[targetNode], nil
|
||||
return targetNode, nil
|
||||
}
|
||||
|
||||
// ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
|
||||
func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
|
||||
excludeNodes := typeutil.NewUniqueSet()
|
||||
|
||||
var lastErr error
|
||||
err := retry.Do(ctx, func() error {
|
||||
excludeNodes := typeutil.NewUniqueSet()
|
||||
tryExecute := func() (bool, error) {
|
||||
// if keeping retry after all nodes are excluded, try to clean excludeNodes
|
||||
if excludeNodes.Len() == len(workload.shardLeaders) {
|
||||
excludeNodes.Clear()
|
||||
}
|
||||
|
||||
balancer := lb.getBalancer()
|
||||
targetNode, err := lb.selectNode(ctx, balancer, workload, excludeNodes)
|
||||
targetNode, err := lb.selectNode(ctx, balancer, &workload, excludeNodes)
|
||||
if err != nil {
|
||||
log.Warn("failed to select node for shard",
|
||||
zap.Int64("collectionID", workload.collectionID),
|
||||
zap.String("channelName", workload.channel),
|
||||
zap.Int64("nodeID", targetNode.nodeID),
|
||||
zap.Int64s("excluded", excludeNodes.Collect()),
|
||||
zap.Error(err),
|
||||
)
|
||||
if lastErr != nil {
|
||||
return lastErr
|
||||
return true, lastErr
|
||||
}
|
||||
return err
|
||||
return true, err
|
||||
}
|
||||
// cancel work load which assign to the target node
|
||||
defer balancer.CancelWorkload(targetNode.nodeID, workload.nq)
|
||||
@ -204,7 +243,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
|
||||
excludeNodes.Insert(targetNode.nodeID)
|
||||
|
||||
lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.nodeID, workload.channel)
|
||||
return lastErr
|
||||
return true, lastErr
|
||||
}
|
||||
|
||||
err = workload.exec(ctx, targetNode.nodeID, client, workload.channel)
|
||||
@ -216,11 +255,19 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
|
||||
zap.Error(err))
|
||||
excludeNodes.Insert(targetNode.nodeID)
|
||||
lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode.nodeID, workload.channel)
|
||||
return lastErr
|
||||
return true, lastErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}, retry.Attempts(workload.retryTimes))
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// if failed, try to execute with partial result
|
||||
err := retry.Handle(ctx, tryExecute, retry.Attempts(workload.retryTimes))
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to execute with partial result",
|
||||
zap.String("channel", workload.channel),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
@ -233,8 +280,14 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
|
||||
return err
|
||||
}
|
||||
|
||||
// let every request could retry at least twice, which could retry after update shard leader cache
|
||||
wg, ctx := errgroup.WithContext(ctx)
|
||||
totalChannels := len(dml2leaders)
|
||||
if totalChannels == 0 {
|
||||
log.Ctx(ctx).Info("no shard leaders found", zap.Int64("collectionID", workload.collectionID))
|
||||
return merr.WrapErrCollectionNotLoaded(workload.collectionID)
|
||||
}
|
||||
|
||||
wg, _ := errgroup.WithContext(ctx)
|
||||
// Launch a goroutine for each channel
|
||||
for k, v := range dml2leaders {
|
||||
channel := k
|
||||
nodes := v
|
||||
|
||||
@ -69,8 +69,9 @@ func (s *LBPolicySuite) SetupTest() {
|
||||
for i := 1; i <= 5; i++ {
|
||||
s.nodeIDs = append(s.nodeIDs, int64(i))
|
||||
s.nodes = append(s.nodes, nodeInfo{
|
||||
nodeID: int64(i),
|
||||
address: "localhost",
|
||||
nodeID: int64(i),
|
||||
address: "localhost",
|
||||
serviceable: true,
|
||||
})
|
||||
}
|
||||
s.channels = []string{"channel1", "channel2"}
|
||||
@ -84,11 +85,13 @@ func (s *LBPolicySuite) SetupTest() {
|
||||
ChannelName: s.channels[0],
|
||||
NodeIds: s.nodeIDs,
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
|
||||
Serviceable: []bool{true, true, true, true, true},
|
||||
},
|
||||
{
|
||||
ChannelName: s.channels[1],
|
||||
NodeIds: s.nodeIDs,
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
|
||||
Serviceable: []bool{true, true, true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
@ -175,7 +178,7 @@ func (s *LBPolicySuite) TestSelectNode() {
|
||||
ctx := context.Background()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
|
||||
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
@ -191,12 +194,12 @@ func (s *LBPolicySuite) TestSelectNode() {
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
shardLeaders: []nodeInfo{},
|
||||
shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
}, typeutil.NewUniqueSet())
|
||||
s.NoError(err)
|
||||
@ -206,7 +209,7 @@ func (s *LBPolicySuite) TestSelectNode() {
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
@ -220,7 +223,7 @@ func (s *LBPolicySuite) TestSelectNode() {
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
@ -237,7 +240,7 @@ func (s *LBPolicySuite) TestSelectNode() {
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return nil, merr.ErrServiceUnavailable
|
||||
}
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
@ -291,7 +294,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
||||
|
||||
// test get client failed, and retry failed, expected success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(2)
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
@ -313,6 +316,10 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
@ -334,7 +341,9 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
counter := 0
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
@ -384,7 +393,9 @@ func (s *LBPolicySuite) TestExecute() {
|
||||
// test all channel success
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
db: dbName,
|
||||
|
||||
@ -970,7 +970,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
|
||||
}
|
||||
|
||||
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
|
||||
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord")
|
||||
}
|
||||
|
||||
info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
|
||||
@ -983,7 +982,8 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
CollectionID: info.collID,
|
||||
CollectionID: info.collID,
|
||||
WithUnserviceableShards: true,
|
||||
}
|
||||
|
||||
tr := timerecord.NewTimeRecorder("UpdateShardCache")
|
||||
@ -1002,6 +1002,19 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
|
||||
idx: atomic.NewInt64(0),
|
||||
}
|
||||
|
||||
// convert shards map to string for logging
|
||||
if log.Logger.Level() == zap.DebugLevel {
|
||||
shardStr := make([]string, 0, len(shards))
|
||||
for channel, nodes := range shards {
|
||||
nodeStrs := make([]string, 0, len(nodes))
|
||||
for _, node := range nodes {
|
||||
nodeStrs = append(nodeStrs, node.String())
|
||||
}
|
||||
shardStr = append(shardStr, fmt.Sprintf("%s:[%s]", channel, strings.Join(nodeStrs, ", ")))
|
||||
}
|
||||
log.Debug("update shard leader cache", zap.String("newShardLeaders", strings.Join(shardStr, ", ")))
|
||||
}
|
||||
|
||||
m.leaderMut.Lock()
|
||||
if _, ok := m.collLeader[database]; !ok {
|
||||
m.collLeader[database] = make(map[string]*shardLeaders)
|
||||
@ -1028,7 +1041,7 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m
|
||||
qns := make([]nodeInfo, len(leaders.GetNodeIds()))
|
||||
|
||||
for j := range qns {
|
||||
qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
|
||||
qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j], leaders.GetServiceable()[j]}
|
||||
}
|
||||
|
||||
shard2QueryNodes[leaders.GetChannelName()] = qns
|
||||
|
||||
@ -1395,6 +1395,7 @@ func TestMetaCache_GetShards(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
@ -1455,6 +1456,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
@ -1836,6 +1838,7 @@ func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
|
||||
@ -1505,6 +1505,7 @@ func (coord *MixCoordMock) GetShardLeaders(ctx context.Context, in *querypb.GetS
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
|
||||
@ -20,12 +20,13 @@ import (
|
||||
type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error)
|
||||
|
||||
type nodeInfo struct {
|
||||
nodeID UniqueID
|
||||
address string
|
||||
nodeID UniqueID
|
||||
address string
|
||||
serviceable bool
|
||||
}
|
||||
|
||||
func (n nodeInfo) String() string {
|
||||
return fmt.Sprintf("<NodeID: %d>", n.nodeID)
|
||||
return fmt.Sprintf("<NodeID: %d, serviceable: %v, address: %s>", n.nodeID, n.serviceable, n.address)
|
||||
}
|
||||
|
||||
var errClosed = errors.New("client is closed")
|
||||
|
||||
@ -161,6 +161,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
@ -185,6 +186,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
@ -210,6 +212,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
@ -241,6 +244,7 @@ func getMockQueryCoord() *mocks.MockMixCoordClient {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
@ -2994,6 +2994,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
|
||||
@ -65,6 +65,7 @@ func (s *StatisticTaskSuite) SetupTest() {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
|
||||
@ -978,6 +978,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
@ -1002,6 +1003,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
@ -1026,6 +1028,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
@ -1057,6 +1060,7 @@ func Test_isPartitionIsLoaded(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
@ -1082,6 +1086,7 @@ func Test_isPartitionIsLoaded(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
@ -1107,6 +1112,7 @@ func Test_isPartitionIsLoaded(t *testing.T) {
|
||||
ChannelName: "channel-1",
|
||||
NodeIds: []int64{1, 2, 3},
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
@ -31,6 +31,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/session"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
@ -264,6 +265,12 @@ func (ob *TargetObserver) check(ctx context.Context, collectionID int64) {
|
||||
if ob.shouldUpdateNextTarget(ctx, collectionID) {
|
||||
// update next target in collection level
|
||||
ob.updateNextTarget(ctx, collectionID)
|
||||
|
||||
// sync next target to delegator if current target not exist, to support partial search
|
||||
if !ob.targetMgr.IsCurrentTargetExist(ctx, collectionID, -1) {
|
||||
newVersion := ob.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.NextTarget)
|
||||
ob.syncNextTargetToDelegator(ctx, collectionID, ob.distMgr.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(collectionID)), newVersion)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -412,14 +419,18 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect
|
||||
collReadyDelegatorList = append(collReadyDelegatorList, chReadyDelegatorList...)
|
||||
}
|
||||
|
||||
return ob.syncNextTargetToDelegator(ctx, collectionID, collReadyDelegatorList, newVersion)
|
||||
}
|
||||
|
||||
// sync next target info to delegator as readable snapshot
|
||||
// 1. if next target is changed before delegator becomes serviceable, we need to sync the new next target to delegator to support partial search
|
||||
// 2. if next target is ready to read, we need to sync the next target to delegator to support full search
|
||||
func (ob *TargetObserver) syncNextTargetToDelegator(ctx context.Context, collectionID int64, collReadyDelegatorList []*meta.DmChannel, newVersion int64) bool {
|
||||
var partitions []int64
|
||||
var indexInfo []*indexpb.IndexInfo
|
||||
var err error
|
||||
for _, d := range collReadyDelegatorList {
|
||||
updateVersionAction := ob.checkNeedUpdateTargetVersion(ctx, d.View, newVersion)
|
||||
if updateVersionAction == nil {
|
||||
continue
|
||||
}
|
||||
updateVersionAction := ob.genSyncAction(ctx, d.View, newVersion)
|
||||
replica := ob.meta.ReplicaManager.GetByCollectionAndNode(ctx, collectionID, d.Node)
|
||||
if replica == nil {
|
||||
log.Warn("replica not found", zap.Int64("nodeID", d.Node), zap.Int64("collectionID", collectionID))
|
||||
@ -441,19 +452,16 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect
|
||||
}
|
||||
}
|
||||
|
||||
if !ob.sync(ctx, replica, d.View, []*querypb.SyncAction{updateVersionAction}, partitions, indexInfo) {
|
||||
if !ob.syncToDelegator(ctx, replica, d.View, updateVersionAction, partitions, indexInfo) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, LeaderView *meta.LeaderView, diffs []*querypb.SyncAction,
|
||||
func (ob *TargetObserver) syncToDelegator(ctx context.Context, replica *meta.Replica, LeaderView *meta.LeaderView, action *querypb.SyncAction,
|
||||
partitions []int64, indexInfo []*indexpb.IndexInfo,
|
||||
) bool {
|
||||
if len(diffs) == 0 {
|
||||
return true
|
||||
}
|
||||
replicaID := replica.GetID()
|
||||
|
||||
log := log.With(
|
||||
@ -469,7 +477,7 @@ func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, Leade
|
||||
CollectionID: LeaderView.CollectionID,
|
||||
ReplicaID: replicaID,
|
||||
Channel: LeaderView.Channel,
|
||||
Actions: diffs,
|
||||
Actions: []*querypb.SyncAction{action},
|
||||
LoadMeta: &querypb.LoadMetaInfo{
|
||||
LoadType: ob.meta.GetLoadType(ctx, LeaderView.CollectionID),
|
||||
CollectionID: LeaderView.CollectionID,
|
||||
@ -496,31 +504,34 @@ func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, Leade
|
||||
return true
|
||||
}
|
||||
|
||||
func (ob *TargetObserver) checkNeedUpdateTargetVersion(ctx context.Context, leaderView *meta.LeaderView, targetVersion int64) *querypb.SyncAction {
|
||||
log.Ctx(ctx).WithRateGroup("qcv2.LeaderObserver", 1, 60)
|
||||
if targetVersion <= leaderView.TargetVersion {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.RatedInfo(10, "Update readable segment version",
|
||||
zap.Int64("collectionID", leaderView.CollectionID),
|
||||
zap.String("channelName", leaderView.Channel),
|
||||
zap.Int64("nodeID", leaderView.ID),
|
||||
zap.Int64("oldVersion", leaderView.TargetVersion),
|
||||
zap.Int64("newVersion", targetVersion),
|
||||
)
|
||||
// sync next target info to delegator
|
||||
// 1. if next target is changed before delegator becomes serviceable, we need to sync the new next target to delegator to support partial search
|
||||
// 2. if next target is ready to read, we need to sync the next target to delegator to support full search
|
||||
func (ob *TargetObserver) genSyncAction(ctx context.Context, leaderView *meta.LeaderView, targetVersion int64) *querypb.SyncAction {
|
||||
log.Ctx(ctx).WithRateGroup("qcv2.LeaderObserver", 1, 60).
|
||||
RatedInfo(10, "Update readable segment version",
|
||||
zap.Int64("collectionID", leaderView.CollectionID),
|
||||
zap.String("channelName", leaderView.Channel),
|
||||
zap.Int64("nodeID", leaderView.ID),
|
||||
zap.Int64("oldVersion", leaderView.TargetVersion),
|
||||
zap.Int64("newVersion", targetVersion),
|
||||
)
|
||||
|
||||
sealedSegments := ob.targetMgr.GetSealedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
|
||||
growingSegments := ob.targetMgr.GetGrowingSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
|
||||
droppedSegments := ob.targetMgr.GetDroppedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
|
||||
channel := ob.targetMgr.GetDmChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTargetFirst)
|
||||
sealedSegmentRowCount := lo.MapValues(sealedSegments, func(segment *datapb.SegmentInfo, _ int64) int64 {
|
||||
return segment.GetNumOfRows()
|
||||
})
|
||||
|
||||
action := &querypb.SyncAction{
|
||||
Type: querypb.SyncType_UpdateVersion,
|
||||
GrowingInTarget: growingSegments.Collect(),
|
||||
SealedInTarget: lo.Keys(sealedSegments),
|
||||
DroppedInTarget: droppedSegments,
|
||||
TargetVersion: targetVersion,
|
||||
Type: querypb.SyncType_UpdateVersion,
|
||||
GrowingInTarget: growingSegments.Collect(),
|
||||
SealedInTarget: lo.Keys(sealedSegmentRowCount),
|
||||
DroppedInTarget: droppedSegments,
|
||||
TargetVersion: targetVersion,
|
||||
SealedSegmentRowCount: sealedSegmentRowCount,
|
||||
}
|
||||
|
||||
if channel != nil {
|
||||
|
||||
@ -263,7 +263,7 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
|
||||
}, 7*time.Second, 1*time.Second)
|
||||
|
||||
ch1View := suite.distMgr.ChannelDistManager.GetByFilter(meta.WithChannelName2Channel("channel-1"))[0].View
|
||||
action := suite.observer.checkNeedUpdateTargetVersion(ctx, ch1View, 100)
|
||||
action := suite.observer.genSyncAction(ctx, ch1View, 100)
|
||||
suite.Equal(action.GetDeleteCP().Timestamp, uint64(200))
|
||||
}
|
||||
|
||||
|
||||
@ -902,7 +902,7 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade
|
||||
}, nil
|
||||
}
|
||||
|
||||
leaders, err := utils.GetShardLeaders(ctx, s.meta, s.targetMgr, s.dist, s.nodeMgr, req.GetCollectionID())
|
||||
leaders, err := utils.GetShardLeaders(ctx, s.meta, s.targetMgr, s.dist, s.nodeMgr, req.GetCollectionID(), req.GetWithUnserviceableShards())
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: merr.Status(err),
|
||||
Shards: leaders,
|
||||
|
||||
@ -35,6 +35,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/session"
|
||||
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
@ -384,6 +385,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
|
||||
return merr.WrapErrServiceInternal(fmt.Sprintf("failed to get partitions for collection=%d", task.CollectionID()))
|
||||
}
|
||||
|
||||
version := ex.targetMgr.GetCollectionTargetVersion(ctx, task.CollectionID(), meta.NextTargetFirst)
|
||||
req := packSubChannelRequest(
|
||||
task,
|
||||
action,
|
||||
@ -392,6 +394,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
|
||||
dmChannel,
|
||||
indexInfo,
|
||||
partitions,
|
||||
version,
|
||||
)
|
||||
err = fillSubChannelRequest(ctx, req, ex.broker, ex.shouldIncludeFlushedSegmentInfo(action.Node()))
|
||||
if err != nil {
|
||||
@ -400,6 +403,12 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
sealedSegments := ex.targetMgr.GetSealedSegmentsByChannel(ctx, dmChannel.CollectionID, dmChannel.ChannelName, meta.NextTarget)
|
||||
sealedSegmentRowCount := lo.MapValues(sealedSegments, func(segment *datapb.SegmentInfo, _ int64) int64 {
|
||||
return segment.GetNumOfRows()
|
||||
})
|
||||
req.SealedSegmentRowCount = sealedSegmentRowCount
|
||||
|
||||
ts := dmChannel.GetSeekPosition().GetTimestamp()
|
||||
log.Info("subscribe channel...",
|
||||
zap.Uint64("checkpoint", ts),
|
||||
|
||||
@ -208,6 +208,7 @@ func packSubChannelRequest(
|
||||
channel *meta.DmChannel,
|
||||
indexInfo []*indexpb.IndexInfo,
|
||||
partitions []int64,
|
||||
targetVersion int64,
|
||||
) *querypb.WatchDmChannelsRequest {
|
||||
return &querypb.WatchDmChannelsRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
@ -223,6 +224,7 @@ func packSubChannelRequest(
|
||||
ReplicaID: task.ReplicaID(),
|
||||
Version: time.Now().UnixNano(),
|
||||
IndexInfoList: indexInfo,
|
||||
TargetVersion: targetVersion,
|
||||
}
|
||||
}
|
||||
|
||||
@ -253,6 +255,7 @@ func fillSubChannelRequest(
|
||||
req.SegmentInfos = lo.SliceToMap(segmentInfos, func(info *datapb.SegmentInfo) (int64, *datapb.SegmentInfo) {
|
||||
return info.GetID(), info
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -100,8 +100,14 @@ func checkLoadStatus(ctx context.Context, m *meta.Meta, collectionID int64) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager,
|
||||
nodeMgr *session.NodeManager, collectionID int64, channels map[string]*meta.DmChannel,
|
||||
func GetShardLeadersWithChannels(
|
||||
ctx context.Context,
|
||||
m *meta.Meta,
|
||||
dist *meta.DistributionManager,
|
||||
nodeMgr *session.NodeManager,
|
||||
collectionID int64,
|
||||
channels map[string]*meta.DmChannel,
|
||||
withUnserviceableShards bool,
|
||||
) ([]*querypb.ShardLeadersList, error) {
|
||||
ret := make([]*querypb.ShardLeadersList, 0)
|
||||
|
||||
@ -111,9 +117,10 @@ func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr me
|
||||
|
||||
ids := make([]int64, 0, len(replicas))
|
||||
addrs := make([]string, 0, len(replicas))
|
||||
serviceable := make([]bool, 0, len(replicas))
|
||||
for _, replica := range replicas {
|
||||
leader := dist.ChannelDistManager.GetShardLeader(channel.GetChannelName(), replica)
|
||||
if leader == nil || !leader.IsServiceable() {
|
||||
if leader == nil || (!withUnserviceableShards && !leader.IsServiceable()) {
|
||||
log.WithRateGroup("util.GetShardLeaders", 1, 60).
|
||||
Warn("leader is not available in replica", zap.String("channel", channel.GetChannelName()), zap.Int64("replicaID", replica.GetID()))
|
||||
continue
|
||||
@ -122,11 +129,11 @@ func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr me
|
||||
if info != nil {
|
||||
ids = append(ids, info.ID())
|
||||
addrs = append(addrs, info.Addr())
|
||||
serviceable = append(serviceable, leader.IsServiceable())
|
||||
}
|
||||
}
|
||||
|
||||
// to avoid node down during GetShardLeaders
|
||||
if len(ids) == 0 {
|
||||
if len(ids) == 0 && !withUnserviceableShards {
|
||||
err := merr.WrapErrChannelNotAvailable(channel.GetChannelName())
|
||||
msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
|
||||
log.Warn(msg, zap.Error(err))
|
||||
@ -137,13 +144,22 @@ func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr me
|
||||
ChannelName: channel.GetChannelName(),
|
||||
NodeIds: ids,
|
||||
NodeAddrs: addrs,
|
||||
Serviceable: serviceable,
|
||||
})
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager, collectionID int64) ([]*querypb.ShardLeadersList, error) {
|
||||
func GetShardLeaders(ctx context.Context,
|
||||
m *meta.Meta,
|
||||
targetMgr meta.TargetManagerInterface,
|
||||
dist *meta.DistributionManager,
|
||||
nodeMgr *session.NodeManager,
|
||||
collectionID int64,
|
||||
withUnserviceableShards bool,
|
||||
) ([]*querypb.ShardLeadersList, error) {
|
||||
// skip check load status if withUnserviceableShards is true
|
||||
if err := checkLoadStatus(ctx, m, collectionID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -155,7 +171,7 @@ func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr meta.TargetMan
|
||||
log.Ctx(ctx).Warn("failed to get channels", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
return GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels)
|
||||
return GetShardLeadersWithChannels(ctx, m, dist, nodeMgr, collectionID, channels, withUnserviceableShards)
|
||||
}
|
||||
|
||||
// CheckCollectionsQueryable check all channels are watched and all segments are loaded for this collection
|
||||
@ -194,7 +210,7 @@ func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.
|
||||
return err
|
||||
}
|
||||
|
||||
shardList, err := GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels)
|
||||
shardList, err := GetShardLeadersWithChannels(ctx, m, dist, nodeMgr, collectionID, channels, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -30,12 +30,12 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
|
||||
@ -88,8 +88,8 @@ type ShardDelegator interface {
|
||||
LoadL0(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error
|
||||
LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error
|
||||
ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error
|
||||
SyncTargetVersion(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, droppedInTarget []int64, checkpoint *msgpb.MsgPosition, deleteSeekPos *msgpb.MsgPosition)
|
||||
GetQueryView() *channelQueryView
|
||||
SyncTargetVersion(action *querypb.SyncAction, partitions []int64)
|
||||
GetChannelQueryView() *channelQueryView
|
||||
GetDeleteBufferSize() (entryNum int64, memorySize int64)
|
||||
|
||||
// manage exclude segments
|
||||
@ -369,21 +369,32 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
|
||||
req.Req.GetIsIterator(),
|
||||
)
|
||||
|
||||
partialResultRequiredDataRatio := paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
|
||||
// wait tsafe
|
||||
waitTr := timerecord.NewTimeRecorder("wait tSafe")
|
||||
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
|
||||
if err != nil {
|
||||
log.Warn("delegator search failed to wait tsafe", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if req.GetReq().GetMvccTimestamp() == 0 {
|
||||
req.Req.MvccTimestamp = tSafe
|
||||
var tSafe uint64
|
||||
var err error
|
||||
if partialResultRequiredDataRatio >= 1.0 {
|
||||
tSafe, err = sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
|
||||
if err != nil {
|
||||
log.Warn("delegator search failed to wait tsafe", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if req.GetReq().GetMvccTimestamp() == 0 {
|
||||
req.Req.MvccTimestamp = tSafe
|
||||
}
|
||||
} else {
|
||||
tSafe = sd.GetTSafe()
|
||||
if req.GetReq().GetMvccTimestamp() == 0 {
|
||||
req.Req.MvccTimestamp = tSafe
|
||||
}
|
||||
}
|
||||
|
||||
metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).
|
||||
Observe(float64(waitTr.ElapseSpan().Milliseconds()))
|
||||
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(partialResultRequiredDataRatio, req.GetReq().GetPartitionIDs()...)
|
||||
if err != nil {
|
||||
log.Warn("delegator failed to search, current distribution is not serviceable", zap.Error(err))
|
||||
return nil, err
|
||||
@ -500,7 +511,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
|
||||
fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
|
||||
Observe(float64(waitTr.ElapseSpan().Milliseconds()))
|
||||
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(float64(1.0), req.GetReq().GetPartitionIDs()...)
|
||||
if err != nil {
|
||||
log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err))
|
||||
return err
|
||||
@ -562,21 +573,31 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
|
||||
req.Req.GetIsIterator(),
|
||||
)
|
||||
|
||||
partialResultRequiredDataRatio := paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
|
||||
// wait tsafe
|
||||
waitTr := timerecord.NewTimeRecorder("wait tSafe")
|
||||
tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp())
|
||||
if err != nil {
|
||||
log.Warn("delegator query failed to wait tsafe", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if req.GetReq().GetMvccTimestamp() == 0 {
|
||||
req.Req.MvccTimestamp = tSafe
|
||||
var tSafe uint64
|
||||
var err error
|
||||
if partialResultRequiredDataRatio >= 1.0 {
|
||||
tSafe, err = sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
|
||||
if err != nil {
|
||||
log.Warn("delegator search failed to wait tsafe", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
if req.GetReq().GetMvccTimestamp() == 0 {
|
||||
req.Req.MvccTimestamp = tSafe
|
||||
}
|
||||
} else {
|
||||
if req.GetReq().GetMvccTimestamp() == 0 {
|
||||
req.Req.MvccTimestamp = sd.GetTSafe()
|
||||
}
|
||||
}
|
||||
|
||||
metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
|
||||
fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
|
||||
Observe(float64(waitTr.ElapseSpan().Milliseconds()))
|
||||
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...)
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(partialResultRequiredDataRatio, req.GetReq().GetPartitionIDs()...)
|
||||
if err != nil {
|
||||
log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err))
|
||||
return nil, err
|
||||
@ -646,7 +667,7 @@ func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetSta
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.Req.GetPartitionIDs()...)
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(1.0, req.Req.GetPartitionIDs()...)
|
||||
if err != nil {
|
||||
log.Warn("delegator failed to GetStatistics, current distribution is not servicable")
|
||||
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
|
||||
@ -708,13 +729,14 @@ func organizeSubTask[T any](ctx context.Context,
|
||||
// update request
|
||||
req := modify(req, scope, segmentIDs, workerID)
|
||||
|
||||
// for partial search, tolerate some worker are offline
|
||||
worker, err := sd.workerManager.GetWorker(ctx, workerID)
|
||||
if err != nil {
|
||||
log.Warn("failed to get worker",
|
||||
log.Warn("failed to get worker for sub task",
|
||||
zap.Int64("nodeID", workerID),
|
||||
zap.Int64s("segments", segmentIDs),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("failed to get worker %d, %w", workerID, err)
|
||||
}
|
||||
|
||||
result = append(result, subTask[T]{
|
||||
@ -744,50 +766,110 @@ func executeSubTasks[T any, R interface {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(tasks))
|
||||
var partialResultRequiredDataRatio float64
|
||||
if taskType == "Query" || taskType == "Search" {
|
||||
partialResultRequiredDataRatio = paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
|
||||
} else {
|
||||
partialResultRequiredDataRatio = 1.0
|
||||
}
|
||||
|
||||
resultCh := make(chan R, len(tasks))
|
||||
errCh := make(chan error, 1)
|
||||
wg, ctx := errgroup.WithContext(ctx)
|
||||
type channelResult struct {
|
||||
nodeID int64
|
||||
segments []int64
|
||||
result R
|
||||
err error
|
||||
}
|
||||
// Buffered channel to collect results from all goroutines
|
||||
resultCh := make(chan channelResult, len(tasks))
|
||||
for _, task := range tasks {
|
||||
go func(task subTask[T]) {
|
||||
defer wg.Done()
|
||||
result, err := execute(ctx, task.req, task.worker)
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
err = fmt.Errorf("worker(%d) query failed: %s", task.targetID, result.GetStatus().GetReason())
|
||||
task := task // capture loop variable
|
||||
wg.Go(func() error {
|
||||
var result R
|
||||
var err error
|
||||
if task.targetID == -1 || task.worker == nil {
|
||||
var segments []int64
|
||||
if req, ok := any(task.req).(interface{ GetSegmentIDs() []int64 }); ok {
|
||||
segments = req.GetSegmentIDs()
|
||||
} else {
|
||||
segments = []int64{}
|
||||
}
|
||||
err = fmt.Errorf("segments not loaded in any worker: %v", segments[:min(len(segments), 10)])
|
||||
} else {
|
||||
result, err = execute(ctx, task.req, task.worker)
|
||||
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
err = fmt.Errorf("worker(%d) query failed: %s", task.targetID, result.GetStatus().GetReason())
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warn("failed to execute sub task",
|
||||
zap.String("taskType", taskType),
|
||||
zap.Int64("nodeID", task.targetID),
|
||||
zap.Error(err),
|
||||
)
|
||||
select {
|
||||
case errCh <- err: // must be the first
|
||||
default: // skip other errors
|
||||
// check if partial result is disabled, if so, let all sub tasks fail fast
|
||||
if partialResultRequiredDataRatio == 1 {
|
||||
return err
|
||||
}
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
resultCh <- result
|
||||
}(task)
|
||||
|
||||
taskResult := channelResult{
|
||||
nodeID: task.targetID,
|
||||
result: result,
|
||||
err: err,
|
||||
}
|
||||
if req, ok := any(task.req).(interface{ GetSegmentIDs() []int64 }); ok {
|
||||
taskResult.segments = req.GetSegmentIDs()
|
||||
}
|
||||
resultCh <- taskResult
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(resultCh)
|
||||
select {
|
||||
case err := <-errCh:
|
||||
log.Warn("Delegator execute subTask failed",
|
||||
// Wait for all tasks to complete
|
||||
if err := wg.Wait(); err != nil {
|
||||
log.Warn("some tasks failed to complete",
|
||||
zap.String("taskType", taskType),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
close(resultCh)
|
||||
|
||||
successSegmentList := []int64{}
|
||||
failureSegmentList := []int64{}
|
||||
var errors []error
|
||||
|
||||
// Collect results
|
||||
results := make([]R, 0, len(tasks))
|
||||
for item := range resultCh {
|
||||
if item.err == nil {
|
||||
successSegmentList = append(successSegmentList, item.segments...)
|
||||
results = append(results, item.result)
|
||||
} else {
|
||||
failureSegmentList = append(failureSegmentList, item.segments...)
|
||||
errors = append(errors, item.err)
|
||||
}
|
||||
}
|
||||
|
||||
results := make([]R, 0, len(tasks))
|
||||
for result := range resultCh {
|
||||
results = append(results, result)
|
||||
accessDataRatio := 1.0
|
||||
totalSegments := len(successSegmentList) + len(failureSegmentList)
|
||||
if totalSegments > 0 {
|
||||
accessDataRatio = float64(len(successSegmentList)) / float64(totalSegments)
|
||||
if accessDataRatio < 1.0 {
|
||||
log.Info("partial result executed successfully",
|
||||
zap.String("taskType", taskType),
|
||||
zap.Float64("successRatio", accessDataRatio),
|
||||
zap.Float64("partialResultRequiredDataRatio", partialResultRequiredDataRatio),
|
||||
zap.Int("totalSegments", totalSegments),
|
||||
zap.Int64s("failureSegmentList", failureSegmentList),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if accessDataRatio < partialResultRequiredDataRatio {
|
||||
return nil, merr.Combine(errors...)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
@ -891,7 +973,7 @@ func (sd *shardDelegator) UpdateSchema(ctx context.Context, schema *schemapb.Col
|
||||
|
||||
log.Info("delegator received update schema event")
|
||||
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments()
|
||||
sealed, growing, version, err := sd.distribution.PinReadableSegments(1.0)
|
||||
if err != nil {
|
||||
log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err))
|
||||
return err
|
||||
@ -1011,6 +1093,7 @@ func (sd *shardDelegator) loadPartitionStats(ctx context.Context, partStatsVersi
|
||||
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
|
||||
workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader,
|
||||
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
|
||||
queryView *channelQueryView,
|
||||
) (ShardDelegator, error) {
|
||||
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
|
||||
zap.Int64("replicaID", replicaID),
|
||||
@ -1041,7 +1124,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
|
||||
segmentManager: manager.Segment,
|
||||
workerManager: workerManager,
|
||||
lifetime: lifetime.NewLifetime(lifetime.Initializing),
|
||||
distribution: NewDistribution(channel),
|
||||
distribution: NewDistribution(channel, queryView),
|
||||
deleteBuffer: deletebuffer.NewListDeleteBuffer[*deletebuffer.Item](startTs, sizePerBlock,
|
||||
[]string{fmt.Sprint(paramtable.GetNodeID()), channel}),
|
||||
pkOracle: pkoracle.NewPkOracle(),
|
||||
|
||||
@ -959,70 +959,45 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) SyncTargetVersion(
|
||||
newVersion int64,
|
||||
partitions []int64,
|
||||
growingInTarget []int64,
|
||||
sealedInTarget []int64,
|
||||
droppedInTarget []int64,
|
||||
checkpoint *msgpb.MsgPosition,
|
||||
deleteSeekPos *msgpb.MsgPosition,
|
||||
) {
|
||||
growings := sd.segmentManager.GetBy(
|
||||
segments.WithType(segments.SegmentTypeGrowing),
|
||||
segments.WithChannel(sd.vchannelName),
|
||||
)
|
||||
|
||||
sealedSet := typeutil.NewUniqueSet(sealedInTarget...)
|
||||
growingSet := typeutil.NewUniqueSet(growingInTarget...)
|
||||
droppedSet := typeutil.NewUniqueSet(droppedInTarget...)
|
||||
redundantGrowing := typeutil.NewUniqueSet()
|
||||
for _, s := range growings {
|
||||
if growingSet.Contain(s.ID()) {
|
||||
continue
|
||||
func (sd *shardDelegator) SyncTargetVersion(action *querypb.SyncAction, partitions []int64) {
|
||||
sd.distribution.SyncTargetVersion(action, partitions)
|
||||
// clean delete buffer after distribution becomes serviceable
|
||||
if sd.distribution.queryView.Serviceable() {
|
||||
checkpoint := action.GetCheckpoint()
|
||||
deleteSeekPos := action.GetDeleteCP()
|
||||
if deleteSeekPos == nil {
|
||||
// for compatible with 2.4, we use checkpoint as deleteCP when deleteCP is nil
|
||||
deleteSeekPos = checkpoint
|
||||
log.Info("use checkpoint as deleteCP",
|
||||
zap.String("channelName", sd.vchannelName),
|
||||
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(action.GetCheckpoint().GetTimestamp())))
|
||||
}
|
||||
|
||||
// sealed segment already exists, make growing segment redundant
|
||||
if sealedSet.Contain(s.ID()) {
|
||||
redundantGrowing.Insert(s.ID())
|
||||
}
|
||||
|
||||
// sealed segment already dropped, make growing segment redundant
|
||||
if droppedSet.Contain(s.ID()) {
|
||||
redundantGrowing.Insert(s.ID())
|
||||
start := time.Now()
|
||||
sizeBeforeClean, _ := sd.deleteBuffer.Size()
|
||||
l0NumBeforeClean := len(sd.deleteBuffer.ListL0())
|
||||
sd.deleteBuffer.UnRegister(deleteSeekPos.GetTimestamp())
|
||||
sizeAfterClean, _ := sd.deleteBuffer.Size()
|
||||
l0NumAfterClean := len(sd.deleteBuffer.ListL0())
|
||||
|
||||
if sizeAfterClean < sizeBeforeClean || l0NumAfterClean < l0NumBeforeClean {
|
||||
log.Info("clean delete buffer",
|
||||
zap.String("channel", sd.vchannelName),
|
||||
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(deleteSeekPos.GetTimestamp())),
|
||||
zap.Time("channelCP", tsoutil.PhysicalTime(checkpoint.GetTimestamp())),
|
||||
zap.Int64("sizeBeforeClean", sizeBeforeClean),
|
||||
zap.Int64("sizeAfterClean", sizeAfterClean),
|
||||
zap.Int("l0NumBeforeClean", l0NumBeforeClean),
|
||||
zap.Int("l0NumAfterClean", l0NumAfterClean),
|
||||
zap.Duration("cost", time.Since(start)),
|
||||
)
|
||||
}
|
||||
sd.RefreshLevel0DeletionStats()
|
||||
}
|
||||
redundantGrowingIDs := redundantGrowing.Collect()
|
||||
if len(redundantGrowing) > 0 {
|
||||
log.Warn("found redundant growing segments",
|
||||
zap.Int64s("growingSegments", redundantGrowingIDs))
|
||||
}
|
||||
sd.distribution.SyncTargetVersion(newVersion, partitions, growingInTarget, sealedInTarget, redundantGrowingIDs)
|
||||
start := time.Now()
|
||||
sizeBeforeClean, _ := sd.deleteBuffer.Size()
|
||||
l0NumBeforeClean := len(sd.deleteBuffer.ListL0())
|
||||
sd.deleteBuffer.UnRegister(deleteSeekPos.GetTimestamp())
|
||||
sizeAfterClean, _ := sd.deleteBuffer.Size()
|
||||
l0NumAfterClean := len(sd.deleteBuffer.ListL0())
|
||||
|
||||
if sizeAfterClean < sizeBeforeClean || l0NumAfterClean < l0NumBeforeClean {
|
||||
log.Info("clean delete buffer",
|
||||
zap.String("channel", sd.vchannelName),
|
||||
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(deleteSeekPos.GetTimestamp())),
|
||||
zap.Time("channelCP", tsoutil.PhysicalTime(checkpoint.GetTimestamp())),
|
||||
zap.Int64("sizeBeforeClean", sizeBeforeClean),
|
||||
zap.Int64("sizeAfterClean", sizeAfterClean),
|
||||
zap.Int("l0NumBeforeClean", l0NumBeforeClean),
|
||||
zap.Int("l0NumAfterClean", l0NumAfterClean),
|
||||
zap.Duration("cost", time.Since(start)),
|
||||
)
|
||||
}
|
||||
|
||||
sd.RefreshLevel0DeletionStats()
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) GetQueryView() *channelQueryView {
|
||||
return sd.distribution.GetQueryView()
|
||||
func (sd *shardDelegator) GetChannelQueryView() *channelQueryView {
|
||||
return sd.distribution.queryView
|
||||
}
|
||||
|
||||
func (sd *shardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) {
|
||||
|
||||
@ -194,7 +194,7 @@ func (s *DelegatorDataSuite) genCollectionWithFunction() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager)
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.NoError(err)
|
||||
s.delegator = delegator.(*shardDelegator)
|
||||
}
|
||||
@ -216,7 +216,7 @@ func (s *DelegatorDataSuite) SetupTest() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager)
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Require().NoError(err)
|
||||
sd, ok := delegator.(*shardDelegator)
|
||||
s.Require().True(ok)
|
||||
@ -419,7 +419,16 @@ func (s *DelegatorDataSuite) TestProcessDelete() {
|
||||
s.Require().NoError(err)
|
||||
|
||||
// sync target version, make delegator serviceable
|
||||
s.delegator.SyncTargetVersion(time.Now().UnixNano(), []int64{500}, []int64{1001}, []int64{1000}, nil, &msgpb.MsgPosition{}, &msgpb.MsgPosition{})
|
||||
s.delegator.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 2001,
|
||||
GrowingInTarget: []int64{1001},
|
||||
SealedSegmentRowCount: map[int64]int64{
|
||||
1000: 100,
|
||||
},
|
||||
DroppedInTarget: []int64{},
|
||||
Checkpoint: &msgpb.MsgPosition{},
|
||||
DeleteCP: &msgpb.MsgPosition{},
|
||||
}, []int64{500, 501})
|
||||
s.delegator.ProcessDelete([]*DeleteData{
|
||||
{
|
||||
PartitionID: 500,
|
||||
@ -799,7 +808,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, nil)
|
||||
}, 10000, nil, nil, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.NoError(err)
|
||||
|
||||
growing0 := segments.NewMockSegment(s.T())
|
||||
@ -1422,8 +1431,15 @@ func (s *DelegatorDataSuite) TestSyncTargetVersion() {
|
||||
s.manager.Segment.Put(context.Background(), segments.SegmentTypeGrowing, ms)
|
||||
}
|
||||
|
||||
s.delegator.SyncTargetVersion(int64(5), []int64{1}, []int64{1}, []int64{2}, []int64{3, 4}, &msgpb.MsgPosition{}, &msgpb.MsgPosition{})
|
||||
s.Equal(int64(5), s.delegator.GetQueryView().GetVersion())
|
||||
s.delegator.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 5,
|
||||
GrowingInTarget: []int64{1},
|
||||
SealedInTarget: []int64{2},
|
||||
DroppedInTarget: []int64{3, 4},
|
||||
Checkpoint: &msgpb.MsgPosition{},
|
||||
DeleteCP: &msgpb.MsgPosition{},
|
||||
}, []int64{500, 501})
|
||||
s.Equal(int64(5), s.delegator.GetChannelQueryView().GetVersion())
|
||||
}
|
||||
|
||||
func (s *DelegatorDataSuite) TestLevel0Deletions() {
|
||||
|
||||
@ -163,7 +163,7 @@ func (s *DelegatorSuite) SetupTest() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager)
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
@ -202,7 +202,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager)
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Error(err)
|
||||
})
|
||||
|
||||
@ -245,7 +245,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager)
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.NoError(err)
|
||||
})
|
||||
}
|
||||
@ -325,7 +325,19 @@ func (s *DelegatorSuite) initSegments() {
|
||||
Version: 2001,
|
||||
},
|
||||
)
|
||||
s.delegator.SyncTargetVersion(2001, []int64{500, 501}, []int64{1004}, []int64{1000, 1001, 1002, 1003}, []int64{}, &msgpb.MsgPosition{}, &msgpb.MsgPosition{})
|
||||
s.delegator.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 2001,
|
||||
GrowingInTarget: []int64{1004},
|
||||
SealedSegmentRowCount: map[int64]int64{
|
||||
1000: 100,
|
||||
1001: 100,
|
||||
1002: 100,
|
||||
1003: 100,
|
||||
},
|
||||
DroppedInTarget: []int64{},
|
||||
Checkpoint: &msgpb.MsgPosition{},
|
||||
DeleteCP: &msgpb.MsgPosition{},
|
||||
}, []int64{500, 501})
|
||||
}
|
||||
|
||||
func (s *DelegatorSuite) TestSearch() {
|
||||
@ -888,7 +900,8 @@ func (s *DelegatorSuite) TestQueryStream() {
|
||||
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
|
||||
DmlChannels: []string{s.vchannelName},
|
||||
}, server)
|
||||
s.True(errors.Is(err, mockErr))
|
||||
s.Error(err)
|
||||
s.ErrorContains(err, "segments not loaded in any worker")
|
||||
})
|
||||
|
||||
s.Run("worker_return_error", func() {
|
||||
@ -1251,7 +1264,7 @@ func (s *DelegatorSuite) TestUpdateSchema() {
|
||||
s.Run("worker_manager_error", func() {
|
||||
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).RunAndReturn(func(ctx context.Context, i int64) (cluster.Worker, error) {
|
||||
return nil, merr.WrapErrServiceInternal("mocked")
|
||||
}).Once()
|
||||
})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
|
||||
@ -155,7 +155,7 @@ func (s *StreamingForwardSuite) SetupTest() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager)
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Require().NoError(err)
|
||||
|
||||
sd, ok := delegator.(*shardDelegator)
|
||||
@ -185,7 +185,13 @@ func (s *StreamingForwardSuite) TestBFStreamingForward() {
|
||||
PartitionID: 1,
|
||||
SegmentID: 102,
|
||||
})
|
||||
delegator.distribution.SyncTargetVersion(1, []int64{1}, []int64{100}, []int64{101, 102}, nil)
|
||||
delegator.distribution.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 1,
|
||||
GrowingInTarget: []int64{100},
|
||||
SealedInTarget: []int64{101, 102},
|
||||
DroppedInTarget: nil,
|
||||
Checkpoint: nil,
|
||||
}, []int64{1})
|
||||
|
||||
// Setup pk oracle
|
||||
// empty bfs will not match
|
||||
@ -238,7 +244,13 @@ func (s *StreamingForwardSuite) TestDirectStreamingForward() {
|
||||
PartitionID: 1,
|
||||
SegmentID: 102,
|
||||
})
|
||||
delegator.distribution.SyncTargetVersion(1, []int64{1}, []int64{100}, []int64{101, 102}, nil)
|
||||
delegator.distribution.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 1,
|
||||
GrowingInTarget: []int64{100},
|
||||
SealedInTarget: []int64{101, 102},
|
||||
DroppedInTarget: nil,
|
||||
Checkpoint: nil,
|
||||
}, []int64{1})
|
||||
|
||||
// Setup pk oracle
|
||||
// empty bfs will not match
|
||||
@ -386,7 +398,7 @@ func (s *GrowingMergeL0Suite) SetupTest() {
|
||||
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
|
||||
return s.mq, nil
|
||||
},
|
||||
}, 10000, nil, s.chunkManager)
|
||||
}, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
s.Require().NoError(err)
|
||||
|
||||
sd, ok := delegator.(*shardDelegator)
|
||||
|
||||
@ -17,6 +17,7 @@
package delegator

import (
    "fmt"
    "sync"

    "github.com/samber/lo"
@ -25,6 +26,7 @@ import (

    "github.com/milvus-io/milvus/pkg/v2/log"
    "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
    "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
    "github.com/milvus-io/milvus/pkg/v2/util/merr"
    "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -56,12 +58,26 @@ func getClosedCh() chan struct{} {
}

// channelQueryView maintains the sealed segment list which should be used for search/query.
// a new delegator gets a fresh channelQueryView from WatchChannel and receives queryView updates from querycoord before it becomes serviceable
// after the delegator becomes serviceable, the queryView is only updated by SyncTargetVersion
type channelQueryView struct {
    sealedSegments []int64            // sealed segment list which should be used for search/query
    partitions     typeutil.UniqueSet // partitions list which sealed segments belong to
    version        int64              // version of current query view, same as targetVersion in qc
    growingSegments       typeutil.UniqueSet // growing segment list which should be used for search/query
    sealedSegmentRowCount map[int64]int64    // sealed segments which should be used for search/query, segmentID -> row count
    partitions            typeutil.UniqueSet // partitions list which sealed segments belong to
    version               int64              // version of current query view, same as targetVersion in qc

    serviceable            *atomic.Bool
    loadedRatio            *atomic.Float64 // loaded ratio of current query view, set serviceable to true if loadedRatio == 1.0
    unloadedSealedSegments []SegmentEntry  // sealed segments in the target that are not loaded on any worker (NodeID is set to -1)
}

func NewChannelQueryView(growings []int64, sealedSegmentRowCount map[int64]int64, partitions []int64, version int64) *channelQueryView {
    return &channelQueryView{
        growingSegments:       typeutil.NewUniqueSet(growings...),
        sealedSegmentRowCount: sealedSegmentRowCount,
        partitions:            typeutil.NewUniqueSet(partitions...),
        version:               version,
        loadedRatio:           atomic.NewFloat64(0),
    }
}

@ -69,7 +85,11 @@ func (q *channelQueryView) GetVersion() int64 {
}

func (q *channelQueryView) Serviceable() bool {
    return q.serviceable.Load()
    return q.loadedRatio.Load() >= 1.0
}

func (q *channelQueryView) GetLoadedRatio() float64 {
    return q.loadedRatio.Load()
}
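With this change, serviceability is no longer a stored boolean but is derived from the loaded ratio: the view is serviceable only when every sealed row in the target is loaded, while lower ratios can still serve partial results. A minimal self-contained sketch of that relationship follows; it is illustrative and not the delegator code itself.

```go
package main

import "fmt"

// queryViewLite is an illustrative reduction of channelQueryView: serviceability
// is derived from the loaded ratio instead of being kept as a separate flag.
type queryViewLite struct {
	loadedRatio float64
}

func (q queryViewLite) Serviceable() bool { return q.loadedRatio >= 1.0 }

// canServe reports whether a request that tolerates partial results
// (requiredRatio < 1.0) can be answered by this view.
func (q queryViewLite) canServe(requiredRatio float64) bool {
	return q.loadedRatio >= requiredRatio
}

func main() {
	view := queryViewLite{loadedRatio: 0.6}
	fmt.Println(view.Serviceable()) // false: not all sealed rows are loaded
	fmt.Println(view.canServe(0.5)) // true: a 50% partial result is acceptable
	fmt.Println(view.canServe(1.0)) // false: full results still require ratio 1.0
}
```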
// distribution is the struct to store segment distribution.
@ -107,22 +127,17 @@ type SegmentEntry struct {
    Offline bool // if delegator failed to execute forwardDelete/Query/Search on segment, it will be offline
}

// NewDistribution creates a new distribution instance with all fields initialized.
func NewDistribution(channelName string) *distribution {
func NewDistribution(channelName string, queryView *channelQueryView) *distribution {
    dist := &distribution{
        channelName:     channelName,
        growingSegments: make(map[UniqueID]SegmentEntry),
        sealedSegments:  make(map[UniqueID]SegmentEntry),
        snapshots:       typeutil.NewConcurrentMap[int64, *snapshot](),
        current:         atomic.NewPointer[snapshot](nil),
        queryView: &channelQueryView{
            serviceable: atomic.NewBool(false),
            partitions:  typeutil.NewSet[int64](),
            version:     initialTargetVersion,
        },
        queryView: queryView,
    }

    dist.genSnapshot()
    dist.updateServiceable("NewDistribution")
    return dist
}
@ -132,12 +147,14 @@ func (d *distribution) SetIDFOracle(idfOracle IDFOracle) {
    d.idfOracle = idfOracle
}

func (d *distribution) PinReadableSegments(partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry, version int64, err error) {
// return segment distribution in query view
func (d *distribution) PinReadableSegments(requiredLoadRatio float64, partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry, version int64, err error) {
    d.mut.RLock()
    defer d.mut.RUnlock()

    if !d.Serviceable() {
        return nil, nil, -1, merr.WrapErrChannelNotAvailable("channel distribution is not serviceable")
    if d.queryView.GetLoadedRatio() < requiredLoadRatio {
        return nil, nil, -1, merr.WrapErrChannelNotAvailable(d.channelName,
            fmt.Sprintf("channel distribution is not serviceable, required load ratio is %f, current load ratio is %f", requiredLoadRatio, d.queryView.GetLoadedRatio()))
    }

    current := d.current.Load()
@ -153,6 +170,15 @@ func (d *distribution) PinReadableSegments(partitions ...int64) (sealed []Snapsh
    targetVersion := current.GetTargetVersion()
    filterReadable := d.readableFilter(targetVersion)
    sealed, growing = d.filterSegments(sealed, growing, filterReadable)

    if len(d.queryView.unloadedSealedSegments) > 0 {
        // append distribution of unloaded segment
        sealed = append(sealed, SnapshotItem{
            NodeID:   -1,
            Segments: d.queryView.unloadedSealedSegments,
        })
    }

    return
}
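PinReadableSegments now takes the ratio of data the caller is willing to accept and, when some sealed segments are missing, appends a placeholder snapshot item with NodeID -1 so downstream code can account for the unread rows. A hedged sketch of how a caller might derive that ratio from configuration is shown below; the helper and the clamping rule are illustrative assumptions, not existing Milvus code.

```go
package main

import "fmt"

// requiredLoadRatio turns a configured partialResultRequiredDataRatio into the
// ratio handed to PinReadableSegments. Values outside (0, 1] fall back to 1,
// which keeps partial results disabled.
func requiredLoadRatio(configured float64) float64 {
	if configured <= 0 || configured > 1 {
		return 1.0
	}
	return configured
}

func main() {
	for _, v := range []float64{-1, 0, 0.3, 1, 2} {
		fmt.Printf("configured=%.1f -> required=%.1f\n", v, requiredLoadRatio(v))
	}
}
```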
@ -213,26 +239,50 @@ func (d *distribution) getTargetVersion() int64 {

// Serviceable returns whether the current snapshot is serviceable.
func (d *distribution) Serviceable() bool {
    return d.queryView.serviceable.Load()
    return d.queryView.Serviceable()
}

// for now, the delegator becomes serviceable only when watchDmChannel is done,
// so we regard all needed growing segments as loaded and compute the loaded ratio based on sealed segments only
func (d *distribution) updateServiceable(triggerAction string) {
    if d.queryView.version != initialTargetVersion {
    serviceable := true
    for _, s := range d.queryView.sealedSegments {
    if entry, ok := d.sealedSegments[s]; !ok || entry.Offline {
    serviceable = false
    break
    }
    }
    if serviceable != d.queryView.serviceable.Load() {
    d.queryView.serviceable.Store(serviceable)
    log.Info("channel distribution serviceable changed",
    zap.String("channel", d.channelName),
    zap.Bool("serviceable", serviceable),
    zap.String("action", triggerAction))
    loadedSealedSegments := int64(0)
    totalSealedRowCount := int64(0)
    unloadedSealedSegments := make([]SegmentEntry, 0)
    for id, rowCount := range d.queryView.sealedSegmentRowCount {
        if entry, ok := d.sealedSegments[id]; ok && !entry.Offline {
            loadedSealedSegments += rowCount
        } else {
            unloadedSealedSegments = append(unloadedSealedSegments, SegmentEntry{SegmentID: id, NodeID: -1})
        }
        totalSealedRowCount += rowCount
    }

    // unloaded segment entry list for partial result
    d.queryView.unloadedSealedSegments = unloadedSealedSegments

    loadedRatio := 0.0
    if len(d.queryView.sealedSegmentRowCount) == 0 {
        loadedRatio = 1.0
    } else if loadedSealedSegments == 0 {
        loadedRatio = 0.0
    } else {
        loadedRatio = float64(loadedSealedSegments) / float64(totalSealedRowCount)
    }

    serviceable := loadedRatio >= 1.0
    if serviceable != d.queryView.Serviceable() {
        log.Info("channel distribution serviceable changed",
            zap.String("channel", d.channelName),
            zap.Bool("serviceable", serviceable),
            zap.Float64("loadedRatio", loadedRatio),
            zap.Int64("loadedSealedRowCount", loadedSealedSegments),
            zap.Int64("totalSealedRowCount", totalSealedRowCount),
            zap.Int("unloadedSealedSegmentNum", len(unloadedSealedSegments)),
            zap.Int("totalSealedSegmentNum", len(d.queryView.sealedSegmentRowCount)),
            zap.String("action", triggerAction))
    }

    d.queryView.loadedRatio.Store(loadedRatio)
}
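The core of updateServiceable is a row-count-weighted ratio: loaded sealed rows divided by all sealed rows in the target, with an empty target counting as fully loaded. A standalone sketch of that calculation, using the same numbers as the unit tests later in this diff:

```go
package main

import "fmt"

// loadedRatio mirrors the calculation in updateServiceable: the ratio of sealed
// rows that are currently loaded (and not offline) to all sealed rows in the
// target. An empty target counts as fully loaded.
func loadedRatio(targetRows map[int64]int64, loaded map[int64]bool) float64 {
	if len(targetRows) == 0 {
		return 1.0
	}
	var loadedRows, totalRows int64
	for segmentID, rows := range targetRows {
		if loaded[segmentID] {
			loadedRows += rows
		}
		totalRows += rows
	}
	if loadedRows == 0 {
		return 0.0
	}
	return float64(loadedRows) / float64(totalRows)
}

func main() {
	target := map[int64]int64{4: 100, 5: 100, 6: 100}
	// only segment 4 is loaded: 100 of 300 rows -> ratio 1/3, not yet serviceable
	fmt.Println(loadedRatio(target, map[int64]bool{4: true}))
	// all three segments loaded -> ratio 1.0, the view becomes serviceable
	fmt.Println(loadedRatio(target, map[int64]bool{4: true, 5: true, 6: true}))
}
```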
// AddDistributions adds multiple segment entries.
@ -257,8 +307,14 @@ func (d *distribution) AddDistributions(entries ...SegmentEntry) {
    // retain the target version for an already loaded segment to avoid skipping it when executing search
    entry.TargetVersion = oldEntry.TargetVersion
} else {
    // waiting for sync target version, to become readable
    entry.TargetVersion = unreadableTargetVersion
    _, ok := d.queryView.sealedSegmentRowCount[entry.SegmentID]
    if ok || d.queryView.growingSegments.Contain(entry.SegmentID) {
        // set segment version to the query view version, to support partial result
        entry.TargetVersion = d.queryView.GetVersion()
    } else {
        // set segment version to unreadableTargetVersion if it's not in the query view
        entry.TargetVersion = unreadableTargetVersion
    }
}
d.sealedSegments[entry.SegmentID] = entry
}
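The branch above decides which target version a newly registered segment gets: segments already named in the query view become readable at the view's version right away, which is what lets a partially loaded channel answer queries, while anything else stays unreadable until the next sync. A compact restatement of that decision follows; the sentinel value is a placeholder, since the real unreadableTargetVersion constant is not shown in this hunk.

```go
package main

import "fmt"

// unreadableVersion is a hypothetical sentinel playing the role of
// unreadableTargetVersion in the delegator.
const unreadableVersion int64 = -1

// pickTargetVersion restates the decision: in the query view -> readable at the
// view's version; otherwise unreadable until the next SyncTargetVersion.
func pickTargetVersion(inQueryView bool, viewVersion int64) int64 {
	if inQueryView {
		return viewVersion
	}
	return unreadableVersion
}

func main() {
	fmt.Println(pickTargetVersion(true, 2001))  // 2001
	fmt.Println(pickTargetVersion(false, 2001)) // -1
}
```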
@ -306,65 +362,73 @@ func (d *distribution) MarkOfflineSegments(segmentIDs ...int64) {
    }
}

// UpdateTargetVersion update readable segment version
func (d *distribution) SyncTargetVersion(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, redundantGrowings []int64) {
// update readable channel view
// 1. update the readable channel view to support partial results before the distribution is serviceable
// 2. update the readable channel view to support full results after the new distribution is serviceable
// Notice: if we didn't need to be compatible with 2.5.x, we could just install the new query view to serve queries,
// and the new query view would become serviceable automatically; a sync action after the distribution is serviceable would be unnecessary
func (d *distribution) SyncTargetVersion(action *querypb.SyncAction, partitions []int64) {
    d.mut.Lock()
    defer d.mut.Unlock()

    for _, segmentID := range growingInTarget {
        entry, ok := d.growingSegments[segmentID]
        if !ok {
            log.Warn("readable growing segment lost, consume from dml seems too slow",
                zap.Int64("segmentID", segmentID))
            continue
        }
        entry.TargetVersion = newVersion
        d.growingSegments[segmentID] = entry
    }

    for _, segmentID := range redundantGrowings {
        entry, ok := d.growingSegments[segmentID]
        if !ok {
            continue
        }
        entry.TargetVersion = redundantTargetVersion
        d.growingSegments[segmentID] = entry
    }

    for _, segmentID := range sealedInTarget {
        entry, ok := d.sealedSegments[segmentID]
        if !ok {
            continue
        }
        entry.TargetVersion = newVersion
        d.sealedSegments[segmentID] = entry
    }

    oldValue := d.queryView.version
    d.queryView = &channelQueryView{
        sealedSegments: sealedInTarget,
        partitions:     typeutil.NewUniqueSet(partitions...),
        version:        newVersion,
        serviceable:    d.queryView.serviceable,
        growingSegments:       typeutil.NewUniqueSet(action.GetGrowingInTarget()...),
        sealedSegmentRowCount: action.GetSealedSegmentRowCount(),
        partitions:            typeutil.NewUniqueSet(partitions...),
        version:               action.GetTargetVersion(),
        loadedRatio:           atomic.NewFloat64(0),
    }

    // update working partition list
    d.genSnapshot()
    sealedSet := typeutil.NewUniqueSet(action.GetSealedInTarget()...)
    droppedSet := typeutil.NewUniqueSet(action.GetDroppedInTarget()...)
    redundantGrowings := make([]int64, 0)
    for _, s := range d.growingSegments {
        // sealed segment already exists or dropped, make growing segment redundant
        if sealedSet.Contain(s.SegmentID) || droppedSet.Contain(s.SegmentID) {
            s.TargetVersion = redundantTargetVersion
            d.growingSegments[s.SegmentID] = s
            redundantGrowings = append(redundantGrowings, s.SegmentID)
        }
    }

    d.queryView.growingSegments.Range(func(s UniqueID) bool {
        entry, ok := d.growingSegments[s]
        if !ok {
            log.Warn("readable growing segment lost, consume from dml seems too slow",
                zap.Int64("segmentID", s))
            return true
        }
        entry.TargetVersion = action.GetTargetVersion()
        d.growingSegments[s] = entry
        return true
    })

    for id := range d.queryView.sealedSegmentRowCount {
        entry, ok := d.sealedSegments[id]
        if !ok {
            continue
        }
        entry.TargetVersion = action.GetTargetVersion()
        d.sealedSegments[id] = entry
    }

    d.genSnapshot()
    if d.idfOracle != nil {
        d.idfOracle.SetNext(d.current.Load())
        d.idfOracle.LazyRemoveGrowings(newVersion, redundantGrowings...)
        d.idfOracle.LazyRemoveGrowings(action.GetTargetVersion(), redundantGrowings...)
    }
    // if fewer sealed segments are present in the leader view than in the target, the delegator becomes unserviceable
    d.updateServiceable("SyncTargetVersion")

    log.Info("Update channel query view",
        zap.String("channel", d.channelName),
        zap.Int64s("partitions", partitions),
        zap.Int64("oldVersion", oldValue),
        zap.Int64("newVersion", newVersion),
        zap.Int("growingSegmentNum", len(growingInTarget)),
        zap.Int("sealedSegmentNum", len(sealedInTarget)),
        zap.Int64("newVersion", action.GetTargetVersion()),
        zap.Bool("serviceable", d.queryView.Serviceable()),
        zap.Float64("loadedRatio", d.queryView.GetLoadedRatio()),
        zap.Int("growingSegmentNum", len(action.GetGrowingInTarget())),
        zap.Int("sealedSegmentNum", len(action.GetSealedInTarget())),
    )
}
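SyncTargetVersion now consumes the whole querypb.SyncAction, so the coordinator can ship per-segment row counts alongside the target version. A hedged sketch of building such an action is shown below; the field names come from the proto change and tests in this commit, while the surrounding wiring (`dist`, the partition list) is illustrative and assumes the delegator package context.

```go
// action carries the target version plus the row counts used to weight the
// loaded ratio; values here mirror the unit tests in this diff.
action := &querypb.SyncAction{
	TargetVersion:   2001,
	GrowingInTarget: []int64{1001},
	SealedSegmentRowCount: map[int64]int64{
		1000: 100, // segmentID -> row count
	},
	DroppedInTarget: []int64{},
	Checkpoint:      &msgpb.MsgPosition{},
	DeleteCP:        &msgpb.MsgPosition{},
}
dist.SyncTargetVersion(action, []int64{500, 501})
```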
@ -21,7 +21,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
)
|
||||
|
||||
type DistributionSuite struct {
|
||||
@ -30,7 +33,7 @@ type DistributionSuite struct {
|
||||
}
|
||||
|
||||
func (s *DistributionSuite) SetupTest() {
|
||||
s.dist = NewDistribution("channel-1")
|
||||
s.dist = NewDistribution("channel-1", NewChannelQueryView(nil, nil, nil, initialTargetVersion))
|
||||
}
|
||||
|
||||
func (s *DistributionSuite) TearDownTest() {
|
||||
@ -44,6 +47,7 @@ func (s *DistributionSuite) TestAddDistribution() {
|
||||
growing []SegmentEntry
|
||||
expected []SnapshotItem
|
||||
expectedSignalClosed bool
|
||||
expectedLoadRatio float64
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
@ -177,8 +181,10 @@ func (s *DistributionSuite) TestAddDistribution() {
|
||||
s.SetupTest()
|
||||
defer s.TearDownTest()
|
||||
s.dist.AddGrowing(tc.growing...)
|
||||
s.dist.SyncTargetVersion(1000, nil, nil, nil, nil)
|
||||
_, _, version, err := s.dist.PinReadableSegments()
|
||||
s.dist.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 1000,
|
||||
}, nil)
|
||||
_, _, version, err := s.dist.PinReadableSegments(1.0)
|
||||
s.Require().NoError(err)
|
||||
s.dist.AddDistributions(tc.input...)
|
||||
sealed, _ := s.dist.PeekSegments(false)
|
||||
@ -225,11 +231,6 @@ func (s *DistributionSuite) TestAddGrowing() {
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
tag: "nil input",
|
||||
input: nil,
|
||||
expected: []SegmentEntry{},
|
||||
},
|
||||
{
|
||||
tag: "normal_case",
|
||||
input: []SegmentEntry{
|
||||
@ -261,8 +262,11 @@ func (s *DistributionSuite) TestAddGrowing() {
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.dist.AddGrowing(tc.input...)
|
||||
s.dist.SyncTargetVersion(1000, tc.workingParts, []int64{1, 2}, nil, nil)
|
||||
_, growing, version, err := s.dist.PinReadableSegments()
|
||||
s.dist.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 1000,
|
||||
GrowingInTarget: []int64{1, 2},
|
||||
}, tc.workingParts)
|
||||
_, growing, version, err := s.dist.PinReadableSegments(1.0)
|
||||
s.Require().NoError(err)
|
||||
defer s.dist.Unpin(version)
|
||||
|
||||
@ -452,15 +456,19 @@ func (s *DistributionSuite) TestRemoveDistribution() {
|
||||
growingIDs := lo.Map(tc.presetGrowing, func(item SegmentEntry, idx int) int64 {
|
||||
return item.SegmentID
|
||||
})
|
||||
sealedIDs := lo.Map(tc.presetSealed, func(item SegmentEntry, idx int) int64 {
|
||||
return item.SegmentID
|
||||
sealedSegmentRowCount := lo.SliceToMap(tc.presetSealed, func(item SegmentEntry) (int64, int64) {
|
||||
return item.SegmentID, 100
|
||||
})
|
||||
s.dist.SyncTargetVersion(time.Now().Unix(), nil, growingIDs, sealedIDs, nil)
|
||||
s.dist.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 1000,
|
||||
GrowingInTarget: growingIDs,
|
||||
SealedSegmentRowCount: sealedSegmentRowCount,
|
||||
}, nil)
|
||||
|
||||
var version int64
|
||||
if tc.withMockRead {
|
||||
var err error
|
||||
_, _, version, err = s.dist.PinReadableSegments()
|
||||
_, _, version, err = s.dist.PinReadableSegments(1.0)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
@ -675,10 +683,14 @@ func (s *DistributionSuite) TestMarkOfflineSegments() {
|
||||
defer s.TearDownTest()
|
||||
|
||||
s.dist.AddDistributions(tc.input...)
|
||||
sealedSegmentID := lo.Map(tc.input, func(t SegmentEntry, _ int) int64 {
|
||||
return t.SegmentID
|
||||
sealedSegmentRowCount := lo.SliceToMap(tc.input, func(t SegmentEntry) (int64, int64) {
|
||||
return t.SegmentID, 100
|
||||
})
|
||||
s.dist.SyncTargetVersion(1000, nil, nil, sealedSegmentID, nil)
|
||||
s.dist.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 1000,
|
||||
SealedSegmentRowCount: sealedSegmentRowCount,
|
||||
DroppedInTarget: nil,
|
||||
}, nil)
|
||||
s.dist.MarkOfflineSegments(tc.offlines...)
|
||||
s.Equal(tc.serviceable, s.dist.Serviceable())
|
||||
|
||||
@ -740,29 +752,187 @@ func (s *DistributionSuite) Test_SyncTargetVersion() {
|
||||
|
||||
s.dist.AddGrowing(growing...)
|
||||
s.dist.AddDistributions(sealed...)
|
||||
s.dist.SyncTargetVersion(2, []int64{1}, []int64{2, 3}, []int64{6}, []int64{})
|
||||
s.dist.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 2,
|
||||
GrowingInTarget: []int64{1},
|
||||
SealedSegmentRowCount: map[int64]int64{4: 100, 5: 200},
|
||||
DroppedInTarget: []int64{6},
|
||||
}, []int64{1})
|
||||
|
||||
s1, s2, _, err := s.dist.PinReadableSegments()
|
||||
s1, s2, _, err := s.dist.PinReadableSegments(1.0)
|
||||
s.Require().NoError(err)
|
||||
s.Len(s1[0].Segments, 1)
|
||||
s.Len(s2, 2)
|
||||
s.Len(s1[0].Segments, 2)
|
||||
s.Len(s2, 1)
|
||||
|
||||
s1, s2, _ = s.dist.PinOnlineSegments()
|
||||
s.Len(s1[0].Segments, 3)
|
||||
s.Len(s2, 3)
|
||||
|
||||
s.dist.queryView.serviceable.Store(true)
|
||||
s.dist.SyncTargetVersion(2, []int64{1}, []int64{222}, []int64{}, []int64{})
|
||||
s.True(s.dist.Serviceable())
|
||||
|
||||
s.dist.SyncTargetVersion(2, []int64{1}, []int64{}, []int64{333}, []int64{})
|
||||
s.dist.SyncTargetVersion(&querypb.SyncAction{
|
||||
TargetVersion: 2,
|
||||
GrowingInTarget: []int64{1},
|
||||
SealedSegmentRowCount: map[int64]int64{333: 100},
|
||||
DroppedInTarget: []int64{},
|
||||
}, []int64{1})
|
||||
s.False(s.dist.Serviceable())
|
||||
|
||||
s.dist.SyncTargetVersion(2, []int64{1}, []int64{}, []int64{333}, []int64{1, 2, 3})
|
||||
_, _, _, err = s.dist.PinReadableSegments()
|
||||
_, _, _, err = s.dist.PinReadableSegments(1.0)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func TestDistributionSuite(t *testing.T) {
|
||||
suite.Run(t, new(DistributionSuite))
|
||||
}
|
||||
|
||||
func TestNewChannelQueryView(t *testing.T) {
|
||||
growings := []int64{1, 2, 3}
|
||||
sealedWithRowCount := map[int64]int64{4: 100, 5: 200, 6: 300}
|
||||
partitions := []int64{7, 8, 9}
|
||||
version := int64(10)
|
||||
|
||||
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
|
||||
assert.NotNil(t, view)
|
||||
assert.ElementsMatch(t, growings, view.growingSegments.Collect())
|
||||
assert.ElementsMatch(t, lo.Keys(sealedWithRowCount), lo.Keys(view.sealedSegmentRowCount))
|
||||
assert.True(t, view.partitions.Contain(7))
|
||||
assert.True(t, view.partitions.Contain(8))
|
||||
assert.True(t, view.partitions.Contain(9))
|
||||
assert.Equal(t, version, view.version)
|
||||
assert.Equal(t, float64(0), view.loadedRatio.Load())
|
||||
assert.False(t, view.Serviceable())
|
||||
}
|
||||
|
||||
func TestDistribution_NewDistribution(t *testing.T) {
|
||||
channelName := "test_channel"
|
||||
growings := []int64{1, 2, 3}
|
||||
sealedWithRowCount := map[int64]int64{4: 100, 5: 200, 6: 300}
|
||||
partitions := []int64{7, 8, 9}
|
||||
version := int64(10)
|
||||
|
||||
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
|
||||
dist := NewDistribution(channelName, view)
|
||||
|
||||
assert.NotNil(t, dist)
|
||||
assert.Equal(t, channelName, dist.channelName)
|
||||
assert.Equal(t, view, dist.queryView)
|
||||
assert.NotNil(t, dist.growingSegments)
|
||||
assert.NotNil(t, dist.sealedSegments)
|
||||
assert.NotNil(t, dist.snapshots)
|
||||
assert.NotNil(t, dist.current)
|
||||
}
|
||||
|
||||
func TestDistribution_UpdateServiceable(t *testing.T) {
|
||||
channelName := "test_channel"
|
||||
growings := []int64{1, 2, 3}
|
||||
sealedWithRowCount := map[int64]int64{4: 100, 5: 100, 6: 100}
|
||||
partitions := []int64{7, 8, 9}
|
||||
version := int64(10)
|
||||
|
||||
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
|
||||
dist := NewDistribution(channelName, view)
|
||||
|
||||
// Test with no segments loaded
|
||||
dist.updateServiceable("test")
|
||||
assert.False(t, dist.Serviceable())
|
||||
assert.Equal(t, float64(0), dist.queryView.GetLoadedRatio())
|
||||
|
||||
// Test with some segments loaded
|
||||
dist.sealedSegments[4] = SegmentEntry{
|
||||
SegmentID: 4,
|
||||
Offline: false,
|
||||
}
|
||||
dist.growingSegments[1] = SegmentEntry{
|
||||
SegmentID: 1,
|
||||
}
|
||||
dist.updateServiceable("test")
|
||||
assert.False(t, dist.Serviceable())
|
||||
assert.Equal(t, float64(2)/float64(6), dist.queryView.GetLoadedRatio())
|
||||
|
||||
// Test with all segments loaded
|
||||
for id := range sealedWithRowCount {
|
||||
dist.sealedSegments[id] = SegmentEntry{
|
||||
SegmentID: id,
|
||||
Offline: false,
|
||||
}
|
||||
}
|
||||
for _, id := range growings {
|
||||
dist.growingSegments[id] = SegmentEntry{
|
||||
SegmentID: id,
|
||||
}
|
||||
}
|
||||
dist.updateServiceable("test")
|
||||
assert.True(t, dist.Serviceable())
|
||||
assert.Equal(t, float64(1), dist.queryView.GetLoadedRatio())
|
||||
}
|
||||
|
||||
func TestDistribution_SyncTargetVersion(t *testing.T) {
|
||||
channelName := "test_channel"
|
||||
growings := []int64{1, 2, 3}
|
||||
sealedWithRowCount := map[int64]int64{4: 100, 5: 100, 6: 100}
|
||||
partitions := []int64{7, 8, 9}
|
||||
version := int64(10)
|
||||
|
||||
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
|
||||
dist := NewDistribution(channelName, view)
|
||||
|
||||
// Add some initial segments
|
||||
dist.growingSegments[1] = SegmentEntry{
|
||||
SegmentID: 1,
|
||||
}
|
||||
dist.sealedSegments[4] = SegmentEntry{
|
||||
SegmentID: 4,
|
||||
}
|
||||
|
||||
// Create a new sync action
|
||||
action := &querypb.SyncAction{
|
||||
GrowingInTarget: []int64{1, 2},
|
||||
SealedSegmentRowCount: map[int64]int64{4: 100, 5: 100},
|
||||
DroppedInTarget: []int64{3},
|
||||
TargetVersion: version + 1,
|
||||
}
|
||||
|
||||
// Sync the view
|
||||
dist.SyncTargetVersion(action, partitions)
|
||||
|
||||
// Verify the changes
|
||||
assert.Equal(t, action.GetTargetVersion(), dist.queryView.version)
|
||||
assert.ElementsMatch(t, action.GetGrowingInTarget(), dist.queryView.growingSegments.Collect())
|
||||
assert.ElementsMatch(t, lo.Keys(action.GetSealedSegmentRowCount()), lo.Keys(dist.queryView.sealedSegmentRowCount))
|
||||
assert.True(t, dist.queryView.partitions.Contain(7))
|
||||
assert.True(t, dist.queryView.partitions.Contain(8))
|
||||
assert.True(t, dist.queryView.partitions.Contain(9))
|
||||
|
||||
// Verify growing segment target version
|
||||
assert.Equal(t, action.GetTargetVersion(), dist.growingSegments[1].TargetVersion)
|
||||
|
||||
// Verify sealed segment target version
|
||||
assert.Equal(t, action.GetTargetVersion(), dist.sealedSegments[4].TargetVersion)
|
||||
}
|
||||
|
||||
func TestDistribution_MarkOfflineSegments(t *testing.T) {
|
||||
channelName := "test_channel"
|
||||
view := NewChannelQueryView([]int64{}, map[int64]int64{1: 100, 2: 200}, []int64{}, 0)
|
||||
dist := NewDistribution(channelName, view)
|
||||
|
||||
// Add some segments
|
||||
dist.sealedSegments[1] = SegmentEntry{
|
||||
SegmentID: 1,
|
||||
NodeID: 100,
|
||||
Version: 1,
|
||||
}
|
||||
dist.sealedSegments[2] = SegmentEntry{
|
||||
SegmentID: 2,
|
||||
NodeID: 100,
|
||||
Version: 1,
|
||||
}
|
||||
|
||||
// Mark segments offline
|
||||
dist.MarkOfflineSegments(1, 2)
|
||||
|
||||
// Verify the changes
|
||||
assert.True(t, dist.sealedSegments[1].Offline)
|
||||
assert.True(t, dist.sealedSegments[2].Offline)
|
||||
assert.Equal(t, int64(-1), dist.sealedSegments[1].NodeID)
|
||||
assert.Equal(t, int64(-1), dist.sealedSegments[2].NodeID)
|
||||
assert.Equal(t, unreadableTargetVersion, dist.sealedSegments[1].Version)
|
||||
assert.Equal(t, unreadableTargetVersion, dist.sealedSegments[2].Version)
|
||||
}
|
||||
|
||||
@ -8,8 +8,6 @@ import (
|
||||
internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
||||
|
||||
querypb "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
|
||||
schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
@ -243,12 +241,12 @@ func (_c *MockShardDelegator_GetPartitionStatsVersions_Call) RunAndReturn(run fu
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetQueryView provides a mock function with no fields
|
||||
func (_m *MockShardDelegator) GetQueryView() *channelQueryView {
|
||||
// GetChannelQueryView provides a mock function with no fields
|
||||
func (_m *MockShardDelegator) GetChannelQueryView() *channelQueryView {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetQueryView")
|
||||
panic("no return value specified for GetChannelQueryView")
|
||||
}
|
||||
|
||||
var r0 *channelQueryView
|
||||
@ -263,29 +261,29 @@ func (_m *MockShardDelegator) GetQueryView() *channelQueryView {
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockShardDelegator_GetQueryView_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQueryView'
|
||||
type MockShardDelegator_GetQueryView_Call struct {
|
||||
// MockShardDelegator_GetChannelQueryView_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannelQueryView'
|
||||
type MockShardDelegator_GetChannelQueryView_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetQueryView is a helper method to define mock.On call
|
||||
func (_e *MockShardDelegator_Expecter) GetQueryView() *MockShardDelegator_GetQueryView_Call {
|
||||
return &MockShardDelegator_GetQueryView_Call{Call: _e.mock.On("GetQueryView")}
|
||||
// GetChannelQueryView is a helper method to define mock.On call
|
||||
func (_e *MockShardDelegator_Expecter) GetChannelQueryView() *MockShardDelegator_GetChannelQueryView_Call {
|
||||
return &MockShardDelegator_GetChannelQueryView_Call{Call: _e.mock.On("GetChannelQueryView")}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_GetQueryView_Call) Run(run func()) *MockShardDelegator_GetQueryView_Call {
|
||||
func (_c *MockShardDelegator_GetChannelQueryView_Call) Run(run func()) *MockShardDelegator_GetChannelQueryView_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_GetQueryView_Call) Return(_a0 *channelQueryView) *MockShardDelegator_GetQueryView_Call {
|
||||
func (_c *MockShardDelegator_GetChannelQueryView_Call) Return(_a0 *channelQueryView) *MockShardDelegator_GetChannelQueryView_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_GetQueryView_Call) RunAndReturn(run func() *channelQueryView) *MockShardDelegator_GetQueryView_Call {
|
||||
func (_c *MockShardDelegator_GetChannelQueryView_Call) RunAndReturn(run func() *channelQueryView) *MockShardDelegator_GetChannelQueryView_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@ -1037,9 +1035,9 @@ func (_c *MockShardDelegator_SyncPartitionStats_Call) RunAndReturn(run func(cont
|
||||
return _c
|
||||
}
|
||||
|
||||
// SyncTargetVersion provides a mock function with given fields: newVersion, partitions, growingInTarget, sealedInTarget, droppedInTarget, checkpoint, deleteSeekPos
|
||||
func (_m *MockShardDelegator) SyncTargetVersion(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, droppedInTarget []int64, checkpoint *msgpb.MsgPosition, deleteSeekPos *msgpb.MsgPosition) {
|
||||
_m.Called(newVersion, partitions, growingInTarget, sealedInTarget, droppedInTarget, checkpoint, deleteSeekPos)
|
||||
// SyncTargetVersion provides a mock function with given fields: action, partitions
|
||||
func (_m *MockShardDelegator) SyncTargetVersion(action *querypb.SyncAction, partitions []int64) {
|
||||
_m.Called(action, partitions)
|
||||
}
|
||||
|
||||
// MockShardDelegator_SyncTargetVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncTargetVersion'
|
||||
@ -1048,20 +1046,15 @@ type MockShardDelegator_SyncTargetVersion_Call struct {
|
||||
}
|
||||
|
||||
// SyncTargetVersion is a helper method to define mock.On call
|
||||
// - newVersion int64
|
||||
// - action *querypb.SyncAction
|
||||
// - partitions []int64
|
||||
// - growingInTarget []int64
|
||||
// - sealedInTarget []int64
|
||||
// - droppedInTarget []int64
|
||||
// - checkpoint *msgpb.MsgPosition
|
||||
// - deleteSeekPos *msgpb.MsgPosition
|
||||
func (_e *MockShardDelegator_Expecter) SyncTargetVersion(newVersion interface{}, partitions interface{}, growingInTarget interface{}, sealedInTarget interface{}, droppedInTarget interface{}, checkpoint interface{}, deleteSeekPos interface{}) *MockShardDelegator_SyncTargetVersion_Call {
|
||||
return &MockShardDelegator_SyncTargetVersion_Call{Call: _e.mock.On("SyncTargetVersion", newVersion, partitions, growingInTarget, sealedInTarget, droppedInTarget, checkpoint, deleteSeekPos)}
|
||||
func (_e *MockShardDelegator_Expecter) SyncTargetVersion(action interface{}, partitions interface{}) *MockShardDelegator_SyncTargetVersion_Call {
|
||||
return &MockShardDelegator_SyncTargetVersion_Call{Call: _e.mock.On("SyncTargetVersion", action, partitions)}
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_SyncTargetVersion_Call) Run(run func(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, droppedInTarget []int64, checkpoint *msgpb.MsgPosition, deleteSeekPos *msgpb.MsgPosition)) *MockShardDelegator_SyncTargetVersion_Call {
|
||||
func (_c *MockShardDelegator_SyncTargetVersion_Call) Run(run func(action *querypb.SyncAction, partitions []int64)) *MockShardDelegator_SyncTargetVersion_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(int64), args[1].([]int64), args[2].([]int64), args[3].([]int64), args[4].([]int64), args[5].(*msgpb.MsgPosition), args[6].(*msgpb.MsgPosition))
|
||||
run(args[0].(*querypb.SyncAction), args[1].([]int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@ -1071,7 +1064,7 @@ func (_c *MockShardDelegator_SyncTargetVersion_Call) Return() *MockShardDelegato
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardDelegator_SyncTargetVersion_Call) RunAndReturn(run func(int64, []int64, []int64, []int64, []int64, *msgpb.MsgPosition, *msgpb.MsgPosition)) *MockShardDelegator_SyncTargetVersion_Call {
|
||||
func (_c *MockShardDelegator_SyncTargetVersion_Call) RunAndReturn(run func(*querypb.SyncAction, []int64)) *MockShardDelegator_SyncTargetVersion_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@ -231,6 +231,7 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
|
||||
if err != nil {
|
||||
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader, fmt.Sprint(req.GetReq().GetCollectionID())).Inc()
|
||||
}
|
||||
metrics.QueryNodePartialResultCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, fmt.Sprint(req.GetReq().GetCollectionID())).Inc()
|
||||
}()
|
||||
|
||||
log.Debug("start do query with channel",
|
||||
|
||||
@ -251,6 +251,13 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
|
||||
}
|
||||
}()
|
||||
|
||||
queryView := delegator.NewChannelQueryView(
|
||||
channel.GetUnflushedSegmentIds(),
|
||||
req.GetSealedSegmentRowCount(),
|
||||
req.GetPartitionIDs(),
|
||||
req.GetTargetVersion(),
|
||||
)
|
||||
|
||||
delegator, err := delegator.NewShardDelegator(
|
||||
ctx,
|
||||
req.GetCollectionID(),
|
||||
@ -264,6 +271,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
|
||||
channel.GetSeekPosition().GetTimestamp(),
|
||||
node.queryHook,
|
||||
node.chunkManager,
|
||||
queryView,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn("failed to create shard delegator", zap.Error(err))
|
||||
@ -1252,7 +1260,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
|
||||
numOfGrowingRows += segment.InsertCount()
|
||||
}
|
||||
|
||||
queryView := delegator.GetQueryView()
|
||||
queryView := delegator.GetChannelQueryView()
|
||||
leaderViews = append(leaderViews, &querypb.LeaderView{
|
||||
Collection: delegator.Collection(),
|
||||
Channel: key,
|
||||
@ -1355,16 +1363,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
|
||||
return id, action.GetCheckpoint().Timestamp
|
||||
})
|
||||
shardDelegator.AddExcludedSegments(flushedInfo)
|
||||
deleteCP := action.GetDeleteCP()
|
||||
if deleteCP == nil {
|
||||
// for compatible with 2.4, we use checkpoint as deleteCP when deleteCP is nil
|
||||
deleteCP = action.GetCheckpoint()
|
||||
log.Info("use checkpoint as deleteCP",
|
||||
zap.String("channelName", req.GetChannel()),
|
||||
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(action.GetCheckpoint().GetTimestamp())))
|
||||
}
|
||||
shardDelegator.SyncTargetVersion(action.GetTargetVersion(), req.GetLoadMeta().GetPartitionIDs(), action.GetGrowingInTarget(),
|
||||
action.GetSealedInTarget(), action.GetDroppedInTarget(), action.GetCheckpoint(), deleteCP)
|
||||
shardDelegator.SyncTargetVersion(action, req.GetLoadMeta().GetPartitionIDs())
|
||||
case querypb.SyncType_UpdatePartitionStats:
|
||||
log.Info("sync update partition stats versions")
|
||||
shardDelegator.SyncPartitionStats(ctx, action.PartitionStatsVersions)
|
||||
|
||||
@ -1393,9 +1393,13 @@ func (suite *ServiceSuite) TestSearch_Failed() {
|
||||
}
|
||||
|
||||
syncVersionAction := &querypb.SyncAction{
|
||||
Type: querypb.SyncType_UpdateVersion,
|
||||
SealedInTarget: []int64{1, 2, 3},
|
||||
TargetVersion: time.Now().UnixMilli(),
|
||||
Type: querypb.SyncType_UpdateVersion,
|
||||
SealedSegmentRowCount: map[int64]int64{
|
||||
1: 100,
|
||||
2: 200,
|
||||
3: 300,
|
||||
},
|
||||
TargetVersion: time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
syncReq.Actions = []*querypb.SyncAction{syncVersionAction}
|
||||
|
||||
@ -814,6 +814,18 @@ var (
|
||||
cgoNameLabelName,
|
||||
cgoTypeLabelName,
|
||||
})
|
||||
|
||||
QueryNodePartialResultCount = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Namespace: milvusNamespace,
|
||||
Subsystem: typeutil.QueryNodeRole,
|
||||
Name: "partial_result_count",
|
||||
Help: "count of partial result",
|
||||
}, []string{
|
||||
nodeIDLabelName,
|
||||
queryTypeLabelName,
|
||||
collectionIDLabelName,
|
||||
})
|
||||
)
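QueryNodePartialResultCount is a plain counter labelled by node, query type, and collection, bumped whenever a request is answered with partial data. The snippet below shows the same CounterVec pattern in a self-contained program using the Prometheus Go client; the names are illustrative, not the registered Milvus metric.

```go
package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
)

func main() {
	// A counter vector labelled per node, query type, and collection,
	// mirroring the shape of QueryNodePartialResultCount.
	partialResults := prometheus.NewCounterVec(prometheus.CounterOpts{
		Namespace: "demo",
		Subsystem: "querynode",
		Name:      "partial_result_count",
		Help:      "count of partial results",
	}, []string{"node_id", "query_type", "collection_id"})

	registry := prometheus.NewRegistry()
	registry.MustRegister(partialResults)

	// Increment once for a hypothetical partial search on collection 42.
	partialResults.WithLabelValues("1", "search", "42").Inc()

	families, _ := registry.Gather()
	for _, mf := range families {
		fmt.Println(mf.GetName(), mf.GetMetric()[0].GetCounter().GetValue())
	}
}
```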
|
||||
|
||||
// RegisterQueryNode registers QueryNode metrics
|
||||
@ -885,6 +897,7 @@ func RegisterQueryNode(registry *prometheus.Registry) {
|
||||
registry.MustRegister(QueryNodeDeleteBufferSize)
|
||||
registry.MustRegister(QueryNodeDeleteBufferRowNum)
|
||||
registry.MustRegister(QueryNodeCGOCallLatency)
|
||||
registry.MustRegister(QueryNodePartialResultCount)
|
||||
// Add cgo metrics
|
||||
RegisterCGOMetrics(registry)
|
||||
|
||||
@ -933,6 +946,13 @@ func CleanupQueryNodeCollectionMetrics(nodeID int64, collectionID int64) {
|
||||
collectionIDLabelName: collectionIDLabel,
|
||||
})
|
||||
|
||||
QueryNodePartialResultCount.
|
||||
DeletePartialMatch(
|
||||
prometheus.Labels{
|
||||
nodeIDLabelName: nodeIDLabel,
|
||||
collectionIDLabelName: collectionIDLabel,
|
||||
})
|
||||
|
||||
QueryNodeSearchHitSegmentNum.
|
||||
DeletePartialMatch(
|
||||
prometheus.Labels{
|
||||
|
||||
@ -281,6 +281,7 @@ message GetSegmentInfoResponse {
|
||||
message GetShardLeadersRequest {
|
||||
common.MsgBase base = 1;
|
||||
int64 collectionID = 2;
|
||||
bool with_unserviceable_shards = 3;
|
||||
}
|
||||
|
||||
message GetShardLeadersResponse {
|
||||
@ -297,6 +298,7 @@ message ShardLeadersList { // All leaders of all replicas of one shard
|
||||
string channel_name = 1;
|
||||
repeated int64 node_ids = 2;
|
||||
repeated string node_addrs = 3;
|
||||
repeated bool serviceable = 4;
|
||||
}
|
||||
|
||||
message SyncNewCreatedPartitionRequest {
|
||||
@ -334,6 +336,8 @@ message WatchDmChannelsRequest {
|
||||
int64 offlineNodeID = 11;
|
||||
int64 version = 12;
|
||||
repeated index.IndexInfo index_info_list = 13;
|
||||
int64 target_version = 14;
|
||||
map<int64, int64> sealed_segment_row_count = 15; // segmentID -> row count, same as unflushedSegmentIds in vchannelInfo
|
||||
}
|
||||
|
||||
message UnsubDmChannelRequest {
|
||||
@ -625,7 +629,7 @@ message LeaderView {
|
||||
}
|
||||
|
||||
message LeaderViewStatus {
|
||||
bool serviceable = 10;
|
||||
bool serviceable = 1;
|
||||
}
|
||||
|
||||
message SegmentDist {
|
||||
@ -725,6 +729,7 @@ message SyncAction {
|
||||
msg.MsgPosition checkpoint = 11;
|
||||
map<int64, int64> partition_stats_versions = 12;
|
||||
msg.MsgPosition deleteCP = 13;
|
||||
map<int64, int64> sealed_segment_row_count = 14; // segmentID -> row count, same as sealedInTarget
|
||||
}
|
||||
|
||||
message SyncDistributionRequest {
|
||||
|
||||
File diff suppressed because it is too large
@ -2856,6 +2856,8 @@ type queryNodeConfig struct {
|
||||
IDFEnableDisk ParamItem `refreshable:"true"`
|
||||
IDFLocalPath ParamItem `refreshable:"true"`
|
||||
IDFWriteConcurrenct ParamItem `refreshable:"true"`
|
||||
// partial search
|
||||
PartialResultRequiredDataRatio ParamItem `refreshable:"true"`
|
||||
}
|
||||
|
||||
func (p *queryNodeConfig) init(base *BaseTable) {
|
||||
@ -3786,6 +3788,15 @@ user-task-polling:
|
||||
Export: true,
|
||||
}
|
||||
p.WorkerPoolingSize.Init(base.mgr)
|
||||
|
||||
p.PartialResultRequiredDataRatio = ParamItem{
|
||||
Key: "proxy.partialResultRequiredDataRatio",
|
||||
Version: "2.6.0",
|
||||
DefaultValue: "1",
|
||||
Doc: `partial result required data ratio; defaults to 1, which disables partial results, otherwise it is used as the minimum ratio of data required to serve a partial result`,
|
||||
Export: true,
|
||||
}
|
||||
p.PartialResultRequiredDataRatio.Init(base.mgr)
|
||||
}
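PartialResultRequiredDataRatio is marked refreshable and defaults to 1, which keeps partial results disabled. A hedged, repo-context snippet of how a test or serving path might override and read it, matching the paramtable calls used elsewhere in this commit (this assumes the Milvus paramtable package and is not standalone):

```go
// Override the ratio for a test, then read it back as a float.
pt := paramtable.Get()
pt.Save(pt.QueryNodeCfg.PartialResultRequiredDataRatio.Key, "0.8")
defer pt.Reset(pt.QueryNodeCfg.PartialResultRequiredDataRatio.Key)

ratio := pt.QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat() // 0.8
_ = ratio
```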
|
||||
|
||||
// /////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -489,6 +489,10 @@ func TestComponentParam(t *testing.T) {
|
||||
assert.Equal(t, "/var/lib/milvus/data/mmap", Params.MmapDirPath.GetValue())
|
||||
|
||||
assert.Equal(t, 60*time.Second, Params.DiskSizeFetchInterval.GetAsDuration(time.Second))
|
||||
|
||||
assert.Equal(t, 1.0, Params.PartialResultRequiredDataRatio.GetAsFloat())
|
||||
params.Save(Params.PartialResultRequiredDataRatio.Key, "0.8")
|
||||
assert.Equal(t, 0.8, Params.PartialResultRequiredDataRatio.GetAsFloat())
|
||||
})
|
||||
|
||||
t.Run("test dataCoordConfig", func(t *testing.T) {
|
||||
|
||||
@ -13,6 +13,8 @@ package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
@ -23,6 +25,14 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
)
|
||||
|
||||
func getCaller(skip int) string {
|
||||
_, file, line, ok := runtime.Caller(skip)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return file + ":" + strconv.Itoa(line)
|
||||
}
|
||||
|
||||
// Do will run function with retry mechanism.
|
||||
// fn is the func to run.
|
||||
// Option can control the retry times and timeout.
|
||||
@ -43,7 +53,10 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
|
||||
for i := uint(0); c.attempts == 0 || i < c.attempts; i++ {
|
||||
if err := fn(); err != nil {
|
||||
if i%4 == 0 {
|
||||
log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err))
|
||||
log.Warn("retry func failed",
|
||||
zap.Uint("retried", i),
|
||||
zap.Error(err),
|
||||
zap.String("caller", getCaller(2)))
|
||||
}
|
||||
|
||||
if !IsRecoverable(err) {
|
||||
@ -52,6 +65,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
|
||||
zap.Uint("retried", i),
|
||||
zap.Uint("attempt", c.attempts),
|
||||
zap.Bool("isContextErr", isContextErr),
|
||||
zap.String("caller", getCaller(2)),
|
||||
)
|
||||
if isContextErr && lastErr != nil {
|
||||
return lastErr
|
||||
@ -62,6 +76,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
|
||||
log.Warn("retry func failed, not be retryable",
|
||||
zap.Uint("retried", i),
|
||||
zap.Uint("attempt", c.attempts),
|
||||
zap.String("caller", getCaller(2)),
|
||||
)
|
||||
return err
|
||||
}
|
||||
@ -73,6 +88,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
|
||||
zap.Uint("retried", i),
|
||||
zap.Uint("attempt", c.attempts),
|
||||
zap.Bool("isContextErr", isContextErr),
|
||||
zap.String("caller", getCaller(2)),
|
||||
)
|
||||
if isContextErr && lastErr != nil {
|
||||
return lastErr
|
||||
@ -88,6 +104,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
|
||||
log.Warn("retry func failed, ctx done",
|
||||
zap.Uint("retried", i),
|
||||
zap.Uint("attempt", c.attempts),
|
||||
zap.String("caller", getCaller(2)),
|
||||
)
|
||||
return lastErr
|
||||
}
|
||||
@ -127,11 +144,22 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
|
||||
for i := uint(0); i < c.attempts; i++ {
|
||||
if shouldRetry, err := fn(); err != nil {
|
||||
if i%4 == 0 {
|
||||
log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err))
|
||||
log.Warn("retry func failed",
|
||||
zap.Uint("retried", i),
|
||||
zap.String("caller", getCaller(2)),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
if !shouldRetry {
|
||||
if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil {
|
||||
isContextErr := errors.IsAny(err, context.Canceled, context.DeadlineExceeded)
|
||||
log.Warn("retry func failed, not be recoverable",
|
||||
zap.Uint("retried", i),
|
||||
zap.Uint("attempt", c.attempts),
|
||||
zap.Bool("isContextErr", isContextErr),
|
||||
zap.String("caller", getCaller(2)),
|
||||
)
|
||||
if isContextErr && lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
return err
|
||||
@ -139,8 +167,14 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
|
||||
|
||||
deadline, ok := ctx.Deadline()
|
||||
if ok && time.Until(deadline) < c.sleep {
|
||||
// to avoid sleep until ctx done
|
||||
if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil {
|
||||
isContextErr := errors.IsAny(err, context.Canceled, context.DeadlineExceeded)
|
||||
log.Warn("retry func failed, deadline",
|
||||
zap.Uint("retried", i),
|
||||
zap.Uint("attempt", c.attempts),
|
||||
zap.Bool("isContextErr", isContextErr),
|
||||
zap.String("caller", getCaller(2)),
|
||||
)
|
||||
if isContextErr && lastErr != nil {
|
||||
return lastErr
|
||||
}
|
||||
return err
|
||||
@ -151,6 +185,11 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
|
||||
select {
|
||||
case <-time.After(c.sleep):
|
||||
case <-ctx.Done():
|
||||
log.Warn("retry func failed, ctx done",
|
||||
zap.Uint("retried", i),
|
||||
zap.Uint("attempt", c.attempts),
|
||||
zap.String("caller", getCaller(2)),
|
||||
)
|
||||
return lastErr
|
||||
}
|
||||
|
||||
@ -162,6 +201,12 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
log.Warn("retry func failed, reach max retry",
|
||||
zap.Uint("attempt", c.attempts),
|
||||
zap.String("caller", getCaller(2)),
|
||||
)
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
|
||||
@ -349,6 +349,7 @@ func (s *HybridSearchSuite) TestHybridSearchSingleSubReq() {
|
||||
})
|
||||
s.NoError(err)
|
||||
s.NoError(merr.Error(loadStatus))
|
||||
s.WaitForLoad(ctx, collectionName)
|
||||
|
||||
// search
|
||||
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
|
||||
|
||||
@ -0,0 +1,542 @@
|
||||
package partialsearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/tests/integration"
|
||||
)
|
||||
|
||||
const (
|
||||
dim = 128
|
||||
dbName = ""
|
||||
|
||||
timeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type PartialSearchTestSuit struct {
|
||||
integration.MiniClusterSuite
|
||||
}
|
||||
|
||||
func (s *PartialSearchTestSuit) SetupSuite() {
|
||||
paramtable.Init()
|
||||
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.GracefulStopTimeout.Key, "1")
|
||||
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key, "10000")
|
||||
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "10000")
|
||||
|
||||
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.TaskExecutionCap.Key, "1")
|
||||
|
||||
// make query survive when delegator is down
|
||||
paramtable.Get().Save(paramtable.Get().ProxyCfg.RetryTimesOnReplica.Key, "10")
|
||||
|
||||
s.Require().NoError(s.SetupEmbedEtcd())
|
||||
}
|
||||
|
||||
func (s *PartialSearchTestSuit) initCollection(collectionName string, replica int, channelNum int, segmentNum int, segmentRowNum int) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
s.CreateCollectionWithConfiguration(ctx, &integration.CreateCollectionConfig{
|
||||
DBName: dbName,
|
||||
Dim: dim,
|
||||
CollectionName: collectionName,
|
||||
ChannelNum: channelNum,
|
||||
SegmentNum: segmentNum,
|
||||
RowNumPerSegment: segmentRowNum,
|
||||
})
|
||||
|
||||
for i := 1; i < replica; i++ {
|
||||
s.Cluster.AddQueryNode()
|
||||
}
|
||||
|
||||
// load
|
||||
loadStatus, err := s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
ReplicaNumber: int32(replica),
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
|
||||
s.True(merr.Ok(loadStatus))
|
||||
s.WaitForLoad(ctx, collectionName)
|
||||
log.Info("initCollection Done")
|
||||
}
|
||||
|
||||
func (s *PartialSearchTestSuit) executeQuery(collection string) (int, error) {
|
||||
ctx := context.Background()
|
||||
queryResult, err := s.Cluster.Proxy.Query(ctx, &milvuspb.QueryRequest{
|
||||
DbName: "",
|
||||
CollectionName: collection,
|
||||
Expr: "",
|
||||
OutputFields: []string{"count(*)"},
|
||||
})
|
||||
|
||||
if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil {
|
||||
log.Info("query failed", zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int(queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]), nil
|
||||
}
|
||||
|
||||
// expect partial results to be returned, with no search failures
|
||||
func (s *PartialSearchTestSuit) TestSingleNodeDownOnSingleReplica() {
|
||||
	partialResultRequiredDataRatio := 0.3
	paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
	defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)

	// init cluster with 6 query nodes
	for i := 1; i < 6; i++ {
		s.Cluster.AddQueryNode()
	}

	// init collection with 1 replica, 2 channels, 6 segments, 2000 rows per segment
	// expect each node has 1 channel and 2 segments
	name := "test_balance_" + funcutil.GenRandomStr()
	channelNum := 2
	segmentNumInChannel := 6
	segmentRowNum := 2000
	s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
	totalEntities := segmentNumInChannel * segmentRowNum

	stopSearchCh := make(chan struct{})
	failCounter := atomic.NewInt64(0)
	partialResultCounter := atomic.NewInt64(0)

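	// keep querying in the background, counting outright failures and partial results separately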
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-stopSearchCh:
				log.Info("stop search")
				return
			default:
				numEntities, err := s.executeQuery(name)
				if err != nil {
					log.Info("query failed", zap.Error(err))
					failCounter.Inc()
				} else if numEntities < totalEntities {
					log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
					partialResultCounter.Inc()
					s.True(numEntities >= int(float64(totalEntities)*partialResultRequiredDataRatio))
				}
			}
		}
	}()

	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.Equal(partialResultCounter.Load(), int64(0)) // must fix this corner case

	// stop one query node in the single replica; expect partial results but no search failures
	s.Cluster.QueryNode.Stop()
	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.True(partialResultCounter.Load() >= 0)
	close(stopSearchCh)
	wg.Wait()
}

// expect some search failures, but partial results are returned before all data is reloaded;
// when all query nodes go down, partial search can shorten the perceived recovery time
func (s *PartialSearchTestSuit) TestAllNodeDownOnSingleReplica() {
	partialResultRequiredDataRatio := 0.5
	paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
	defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)

	// init cluster with 2 query nodes
	s.Cluster.AddQueryNode()

	// init collection with 1 replica, 2 channels, 10 segments, 2000 rows per segment
	// expect each node has 1 channel and 2 segments
	name := "test_balance_" + funcutil.GenRandomStr()
	channelNum := 2
	segmentNumInChannel := 10
	segmentRowNum := 2000
	s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
	totalEntities := segmentNumInChannel * segmentRowNum

	stopSearchCh := make(chan struct{})
	failCounter := atomic.NewInt64(0)
	partialResultCounter := atomic.NewInt64(0)

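	// timestamps of the most recently observed partial and full results after the outage;
	// the final assertion checks that partial results show up before full results do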
	partialResultRecoverTs := atomic.NewInt64(0)
	fullResultRecoverTs := atomic.NewInt64(0)

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-stopSearchCh:
				log.Info("stop search")
				return
			default:
				numEntities, err := s.executeQuery(name)
				if err != nil {
					log.Info("query failed", zap.Error(err))
					failCounter.Inc()
				} else if failCounter.Load() > 0 {
					if numEntities < totalEntities {
						log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
						partialResultCounter.Inc()
						partialResultRecoverTs.Store(time.Now().UnixNano())
						s.True(numEntities >= int(float64(totalEntities)*partialResultRequiredDataRatio))
					} else {
						log.Info("query return full result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
						fullResultRecoverTs.Store(time.Now().UnixNano())
					}
				}
			}
		}
	}()

	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.Equal(partialResultCounter.Load(), int64(0))

	// stop all query nodes in the single replica; search failures are expected until data is reloaded
	for _, qn := range s.Cluster.GetAllQueryNodes() {
		qn.Stop()
	}
	s.Cluster.AddQueryNode()

	time.Sleep(20 * time.Second)
	s.True(failCounter.Load() >= 0)
	s.True(partialResultCounter.Load() >= 0)
	log.Info("partialResultRecoverTs", zap.Int64("partialResultRecoverTs", partialResultRecoverTs.Load()), zap.Int64("fullResultRecoverTs", fullResultRecoverTs.Load()))
	s.True(partialResultRecoverTs.Load() < fullResultRecoverTs.Load())
	close(stopSearchCh)
	wg.Wait()
}

// expect full results and no search failures;
// because the proxy does not always pick the best replica to serve a query, partial results
// may still show up even when only one replica is partially loaded
func (s *PartialSearchTestSuit) TestSingleNodeDownOnMultiReplica() {
	partialResultRequiredDataRatio := 0.5
	paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
	defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
	// init cluster with 4 query nodes
	qn1 := s.Cluster.AddQueryNode()
	s.Cluster.AddQueryNode()

	// init collection with 2 replicas, 2 channels, 4 segments, 2000 rows per segment
	// expect each node has 1 channel and 2 segments
	name := "test_balance_" + funcutil.GenRandomStr()
	channelNum := 2
	segmentNumInChannel := 2
	segmentRowNum := 2000
	s.initCollection(name, 2, channelNum, segmentNumInChannel, segmentRowNum)
	totalEntities := segmentNumInChannel * segmentRowNum

	stopSearchCh := make(chan struct{})
	failCounter := atomic.NewInt64(0)
	partialResultCounter := atomic.NewInt64(0)

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-stopSearchCh:
				log.Info("stop search")
				return
			default:
				numEntities, err := s.executeQuery(name)
				if err != nil {
					log.Info("query failed", zap.Error(err))
					failCounter.Inc()
				} else if numEntities < totalEntities {
					log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
					partialResultCounter.Inc()
					s.True(numEntities >= int(float64(totalEntities)*partialResultRequiredDataRatio))
				}
			}
		}
	}()

	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.Equal(partialResultCounter.Load(), int64(0))

	// stop one query node in one replica; expect no search failures
	qn1.Stop()
	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.True(partialResultCounter.Load() >= 0)
	close(stopSearchCh)
	wg.Wait()
}

// expect partial results to be returned, with no search failures
func (s *PartialSearchTestSuit) TestEachReplicaHasNodeDownOnMultiReplica() {
	partialResultRequiredDataRatio := 0.3
	paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
	defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
	// init cluster with 12 query nodes
	for i := 2; i < 12; i++ {
		s.Cluster.AddQueryNode()
	}

	// init collection with 2 replicas, 2 channels, 18 segments, 2000 rows per segment
	// expect each node has 1 channel and 2 segments
	name := "test_balance_" + funcutil.GenRandomStr()
	channelNum := 2
	segmentNumInChannel := 9
	segmentRowNum := 2000
	s.initCollection(name, 2, channelNum, segmentNumInChannel, segmentRowNum)
	totalEntities := segmentNumInChannel * segmentRowNum

	ctx := context.Background()
	stopSearchCh := make(chan struct{})
	failCounter := atomic.NewInt64(0)
	partialResultCounter := atomic.NewInt64(0)

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-stopSearchCh:
				log.Info("stop search")
				return
			default:
				numEntities, err := s.executeQuery(name)
				if err != nil {
					log.Info("query failed", zap.Error(err))
					failCounter.Inc()
					continue
				} else if numEntities < totalEntities {
					log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
					partialResultCounter.Inc()
					s.True(numEntities >= int(float64(totalEntities)*partialResultRequiredDataRatio))
				}
			}
		}
	}()

	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.Equal(partialResultCounter.Load(), int64(0))

	replicaResp, err := s.Cluster.Proxy.GetReplicas(ctx, &milvuspb.GetReplicasRequest{
		DbName:         dbName,
		CollectionName: name,
	})
	s.NoError(err)
	s.Equal(commonpb.ErrorCode_Success, replicaResp.GetStatus().GetErrorCode())
	s.Equal(2, len(replicaResp.GetReplicas()))

	// for each replica, choose a query node to stop
	for _, replica := range replicaResp.GetReplicas() {
		for _, qn := range s.Cluster.GetAllQueryNodes() {
			if funcutil.SliceContain(replica.GetNodeIds(), qn.GetQueryNode().GetNodeID()) {
				qn.Stop()
				break
			}
		}
	}

	time.Sleep(10 * time.Second)
	s.True(failCounter.Load() >= 0)
	s.True(partialResultCounter.Load() >= 0)
	close(stopSearchCh)
	wg.Wait()
}

// when the partial result required data ratio is set too high, partial search is not triggered
// and search failures are expected; for example, with 0.8 configured, each query node crash
// loses 50% of the data, so partial search cannot kick in
func (s *PartialSearchTestSuit) TestPartialResultRequiredDataRatioTooHigh() {
	partialResultRequiredDataRatio := 0.8
	paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
	defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
	// init cluster with 2 query nodes
	qn1 := s.Cluster.AddQueryNode()

	// init collection with 1 replica, 2 channels, 4 segments, 2000 rows per segment
	// expect each node has 1 channel and 2 segments
	name := "test_balance_" + funcutil.GenRandomStr()
	channelNum := 2
	segmentNumInChannel := 2
	segmentRowNum := 2000
	s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
	totalEntities := segmentNumInChannel * segmentRowNum

	stopSearchCh := make(chan struct{})
	failCounter := atomic.NewInt64(0)
	partialResultCounter := atomic.NewInt64(0)

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-stopSearchCh:
				log.Info("stop search")
				return
			default:
				numEntities, err := s.executeQuery(name)
				if err != nil {
					log.Info("query failed", zap.Error(err))
					failCounter.Inc()
				} else if numEntities < totalEntities {
					log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
					partialResultCounter.Inc()
					s.True(numEntities >= int(float64(totalEntities)*partialResultRequiredDataRatio))
				}
			}
		}
	}()

	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.Equal(partialResultCounter.Load(), int64(0))

	qn1.Stop()
	time.Sleep(10 * time.Second)
	s.True(failCounter.Load() >= 0)
	s.True(partialResultCounter.Load() == 0)
	close(stopSearchCh)
	wg.Wait()
}

// set the partial result required data ratio to 0; expect no search failures even after all query nodes are down
func (s *PartialSearchTestSuit) TestSearchNeverFails() {
	partialResultRequiredDataRatio := 0.0
	paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
	defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
	// init cluster with 2 query nodes
	s.Cluster.AddQueryNode()

	// init collection with 1 replica, 2 channels, 4 segments, 2000 rows per segment
	// expect each node has 1 channel and 2 segments
	name := "test_balance_" + funcutil.GenRandomStr()
	channelNum := 2
	segmentNumInChannel := 2
	segmentRowNum := 2000
	s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
	totalEntities := segmentNumInChannel * segmentRowNum

	stopSearchCh := make(chan struct{})
	failCounter := atomic.NewInt64(0)
	partialResultCounter := atomic.NewInt64(0)

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-stopSearchCh:
				log.Info("stop search")
				return
			default:
				numEntities, err := s.executeQuery(name)
				if err != nil {
					log.Info("query failed", zap.Error(err))
					failCounter.Inc()
				} else if numEntities < totalEntities {
					log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
					partialResultCounter.Inc()
					s.True(numEntities >= int(float64(totalEntities)*partialResultRequiredDataRatio))
				}
			}
		}
	}()

	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.Equal(partialResultCounter.Load(), int64(0))

	for _, qn := range s.Cluster.GetAllQueryNodes() {
		qn.Stop()
	}
	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.True(partialResultCounter.Load() >= 0)
	close(stopSearchCh)
	wg.Wait()
}

func (s *PartialSearchTestSuit) TestSkipWaitTSafe() {
	partialResultRequiredDataRatio := 0.5
	paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
	defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
	// mock a tsafe delay by slowing down the proxy time tick
	paramtable.Get().Save(paramtable.Get().ProxyCfg.TimeTickInterval.Key, "30000")
	// init cluster with 5 query nodes
	for i := 1; i < 5; i++ {
		s.Cluster.AddQueryNode()
	}

	// init collection with 1 replica, 1 channel, 4 segments, 2000 rows per segment
	// expect each node has 1 channel and 2 segments
	name := "test_balance_" + funcutil.GenRandomStr()
	channelNum := 1
	segmentNumInChannel := 4
	segmentRowNum := 2000
	s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
	totalEntities := segmentNumInChannel * segmentRowNum

	stopSearchCh := make(chan struct{})
	failCounter := atomic.NewInt64(0)
	partialResultCounter := atomic.NewInt64(0)

	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			select {
			case <-stopSearchCh:
				log.Info("stop search")
				return
			default:
				numEntities, err := s.executeQuery(name)
				if err != nil {
					log.Info("query failed", zap.Error(err))
					failCounter.Inc()
				} else if numEntities < totalEntities {
					log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
					partialResultCounter.Inc()
					s.True(numEntities >= int(float64(totalEntities)*partialResultRequiredDataRatio))
				}
			}
		}
	}()

	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.Equal(partialResultCounter.Load(), int64(0))

	s.Cluster.QueryNode.Stop()
	time.Sleep(10 * time.Second)
	s.Equal(failCounter.Load(), int64(0))
	s.True(partialResultCounter.Load() >= 0)
	close(stopSearchCh)
	wg.Wait()
}

func TestPartialResult(t *testing.T) {
	suite.Run(t, new(PartialSearchTestSuit))
}
@ -1,347 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package partialsearch

import (
	"context"
	"fmt"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/suite"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
	"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
	"github.com/milvus-io/milvus/tests/integration"
)

type PartialSearchSuite struct {
	integration.MiniClusterSuite
	dim               int
	numCollections    int
	rowsPerCollection int
	waitTimeInSec     time.Duration
	prefix            string
}

func (s *PartialSearchSuite) setupParam() {
	s.dim = 128
	s.numCollections = 1
	s.rowsPerCollection = 100
	s.waitTimeInSec = time.Second * 10
}

func (s *PartialSearchSuite) loadCollection(collectionName string, dim int, wg *sync.WaitGroup) {
	c := s.Cluster
	dbName := ""
	schema := integration.ConstructSchema(collectionName, dim, true)
	marshaledSchema, err := proto.Marshal(schema)
	s.NoError(err)

	createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      common.DefaultShardsNum,
	})
	s.NoError(err)

	err = merr.Error(createCollectionStatus)
	s.NoError(err)

	showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{})
	s.NoError(err)
	s.True(merr.Ok(showCollectionsResp.GetStatus()))

	batchSize := 500000
	for start := 0; start < s.rowsPerCollection; start += batchSize {
		rowNum := batchSize
		if start+batchSize > s.rowsPerCollection {
			rowNum = s.rowsPerCollection - start
		}
		fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
		hashKeys := integration.GenerateHashKeys(rowNum)
		insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{
			DbName:         dbName,
			CollectionName: collectionName,
			FieldsData:     []*schemapb.FieldData{fVecColumn},
			HashKeys:       hashKeys,
			NumRows:        uint32(rowNum),
		})
		s.NoError(err)
		s.True(merr.Ok(insertResult.GetStatus()))
	}
	log.Info("=========================Data insertion finished=========================")

	// flush
	flushResp, err := c.Proxy.Flush(context.TODO(), &milvuspb.FlushRequest{
		DbName:          dbName,
		CollectionNames: []string{collectionName},
	})
	s.NoError(err)
	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
	ids := segmentIDs.GetData()
	s.Require().NotEmpty(segmentIDs)
	s.Require().True(has)
	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
	s.True(has)

	segments, err := c.MetaWatcher.ShowSegments()
	s.NoError(err)
	s.NotEmpty(segments)
	s.WaitForFlush(context.TODO(), ids, flushTs, dbName, collectionName)
	log.Info("=========================Data flush finished=========================")

	// create index
	createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{
		CollectionName: collectionName,
		FieldName:      integration.FloatVecField,
		IndexName:      "_default",
		ExtraParams:    integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
	})
	s.NoError(err)
	err = merr.Error(createIndexStatus)
	s.NoError(err)
	s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField)
	log.Info("=========================Index created=========================")

	// load
	loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{
		DbName:         dbName,
		CollectionName: collectionName,
	})
	s.NoError(err)
	err = merr.Error(loadStatus)
	s.NoError(err)
	s.WaitForLoad(context.TODO(), collectionName)
	log.Info("=========================Collection loaded=========================")
	wg.Done()
}

func (s *PartialSearchSuite) checkCollectionLoaded(collectionName string) bool {
	loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{
		DbName:         "",
		CollectionName: collectionName,
	})
	s.NoError(err)
	if loadProgress.GetProgress() != int64(100) {
		return false
	}
	return true
}

func (s *PartialSearchSuite) checkCollectionsLoaded(startCollectionID, endCollectionID int) bool {
	notLoaded := 0
	loaded := 0
	for idx := startCollectionID; idx < endCollectionID; idx++ {
		collectionName := s.prefix + "_" + strconv.Itoa(idx)
		if s.checkCollectionLoaded(collectionName) {
			notLoaded++
		} else {
			loaded++
		}
	}
	log.Info(fmt.Sprintf("loading status: %d/%d", loaded, endCollectionID-startCollectionID+1))
	return notLoaded == 0
}

func (s *PartialSearchSuite) checkAllCollectionsLoaded() bool {
	return s.checkCollectionsLoaded(0, s.numCollections)
}

func (s *PartialSearchSuite) search(collectionName string, dim int) {
	c := s.Cluster
	var err error
	// Query
	queryReq := &milvuspb.QueryRequest{
		Base:               nil,
		CollectionName:     collectionName,
		PartitionNames:     nil,
		Expr:               "",
		OutputFields:       []string{"count(*)"},
		TravelTimestamp:    0,
		GuaranteeTimestamp: 0,
	}
	queryResult, err := c.Proxy.Query(context.TODO(), queryReq)
	s.NoError(err)
	s.Equal(queryResult.Status.ErrorCode, commonpb.ErrorCode_Success)
	s.Equal(len(queryResult.FieldsData), 1)
	numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]
	s.Equal(numEntities, int64(s.rowsPerCollection))

	// Search
	expr := fmt.Sprintf("%s > 0", integration.Int64Field)
	nq := 10
	topk := 10
	roundDecimal := -1
	radius := 10

	params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
	params["radius"] = radius
	searchReq := integration.ConstructSearchRequest("", collectionName, expr,
		integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)

	searchResult, _ := c.Proxy.Search(context.TODO(), searchReq)

	err = merr.Error(searchResult.GetStatus())
	s.NoError(err)
}

func (s *PartialSearchSuite) FailOnSearch(collectionName string) {
	c := s.Cluster
	expr := fmt.Sprintf("%s > 0", integration.Int64Field)
	nq := 10
	topk := 10
	roundDecimal := -1
	radius := 10

	params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
	params["radius"] = radius
	searchReq := integration.ConstructSearchRequest("", collectionName, expr,
		integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, s.dim, topk, roundDecimal)

	searchResult, err := c.Proxy.Search(context.TODO(), searchReq)
	s.NoError(err)
	err = merr.Error(searchResult.GetStatus())
	s.Require().Error(err)
}

func (s *PartialSearchSuite) setupData() {
	// Add the second query node
	log.Info("=========================Start to inject data=========================")
	s.prefix = "TestPartialSearchUtil" + funcutil.GenRandomStr()
	searchName := s.prefix + "_0"
	wg := sync.WaitGroup{}
	for idx := 0; idx < s.numCollections; idx++ {
		wg.Add(1)
		go s.loadCollection(s.prefix+"_"+strconv.Itoa(idx), s.dim, &wg)
	}
	wg.Wait()
	log.Info("=========================Data injection finished=========================")
	s.checkAllCollectionsLoaded()
	log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName))
	s.search(searchName, s.dim)
	log.Info("=========================Search finished=========================")
	time.Sleep(s.waitTimeInSec)
	s.checkAllCollectionsLoaded()
	log.Info(fmt.Sprintf("=========================start to search2 %s=========================", searchName))
	s.search(searchName, s.dim)
	log.Info("=========================Search2 finished=========================")
	s.checkAllCollectionsReady()
}

func (s *PartialSearchSuite) checkCollectionsReady(startCollectionID, endCollectionID int) {
	for i := startCollectionID; i < endCollectionID; i++ {
		collectionName := s.prefix + "_" + strconv.Itoa(i)
		s.search(collectionName, s.dim)
		queryReq := &milvuspb.QueryRequest{
			CollectionName: collectionName,
			Expr:           "",
			OutputFields:   []string{"count(*)"},
		}
		_, err := s.Cluster.Proxy.Query(context.TODO(), queryReq)
		s.NoError(err)
	}
}

func (s *PartialSearchSuite) checkAllCollectionsReady() {
	s.checkCollectionsReady(0, s.numCollections)
}

func (s *PartialSearchSuite) releaseSegmentsReq(collectionID, nodeID, segmentID typeutil.UniqueID, shard string) *querypb.ReleaseSegmentsRequest {
	req := &querypb.ReleaseSegmentsRequest{
		Base: commonpbutil.NewMsgBase(
			commonpbutil.WithMsgType(commonpb.MsgType_ReleaseSegments),
			commonpbutil.WithMsgID(1<<30),
			commonpbutil.WithTargetID(nodeID),
		),

		NodeID:       nodeID,
		CollectionID: collectionID,
		SegmentIDs:   []int64{segmentID},
		Scope:        querypb.DataScope_Historical,
		Shard:        shard,
		NeedTransfer: false,
	}
	return req
}

func (s *PartialSearchSuite) describeCollection(name string) (int64, []string) {
	resp, err := s.Cluster.Proxy.DescribeCollection(context.TODO(), &milvuspb.DescribeCollectionRequest{
		DbName:         "default",
		CollectionName: name,
	})
	s.NoError(err)
	log.Info(fmt.Sprintf("describe collection: %v", resp))
	return resp.CollectionID, resp.VirtualChannelNames
}

func (s *PartialSearchSuite) getSegmentIDs(collectionName string) []int64 {
	resp, err := s.Cluster.Proxy.GetPersistentSegmentInfo(context.TODO(), &milvuspb.GetPersistentSegmentInfoRequest{
		DbName:         "default",
		CollectionName: collectionName,
	})
	s.NoError(err)
	var res []int64
	for _, seg := range resp.Infos {
		res = append(res, seg.SegmentID)
	}
	return res
}

func (s *PartialSearchSuite) TestPartialSearch() {
	s.setupParam()
	s.setupData()

	startCollectionID := 0
	endCollectionID := 0
	// Search should work in the beginning
	s.checkCollectionsReady(startCollectionID, endCollectionID)
	// Test case with one segment released
	// Partial search does not work yet.
	c := s.Cluster
	q1 := c.QueryNode
	c.MixCoord.StopCheckerForTestOnly()
	collectionName := s.prefix + "_0"
	nodeID := q1.GetServerIDForTestOnly()
	collectionID, channels := s.describeCollection(collectionName)
	segs := s.getSegmentIDs(collectionName)
	s.Require().Positive(len(segs))
	s.Require().Positive(len(channels))
	segmentID := segs[0]
	shard := channels[0]
	req := s.releaseSegmentsReq(collectionID, nodeID, segmentID, shard)
	q1.ReleaseSegments(context.TODO(), req)
	s.FailOnSearch(collectionName)
	c.MixCoord.StartCheckerForTestOnly()
}

func TestPartialSearchUtil(t *testing.T) {
	suite.Run(t, new(PartialSearchSuite))
}