feat: Implement partial result support on node down (#42009)

issue: https://github.com/milvus-io/milvus/issues/41690
This commit implements partial search result functionality when query
nodes go down, improving system availability during node failures. The
changes include:

- Enhanced load balancing in proxy (lb_policy.go) to handle node
failures with retry support
- Added partial search result capability in querynode delegator and
distribution logic
- Implemented tests for various partial result scenarios when nodes go
down
- Added metrics to track partial search results in querynode_metrics.go
- Added the partialResultRequiredDataRatio configuration parameter to control the
minimum data ratio required before a partial result may be returned (1 disables it)
- Replaced old partial_search_test.go with more comprehensive
partial_result_on_node_down_test.go
- Updated proto definitions and improved retry logic

These changes improve query resilience: when some query nodes are unavailable,
users receive partial results instead of an outright failure, as long as the
accessible portion of the data meets the configured minimum ratio (an
illustrative sketch of this gating decision follows the commit metadata below).

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
Author: wei liu <wei.liu@zilliz.com>, committed by GitHub on 2025-05-28 00:12:28 +08:00
Commit: 54619eaa2c (parent: 57b58ad778)
37 changed files with 3447 additions and 2625 deletions
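
To make the gating decision concrete, here is a small self-contained Go sketch of the rule the querynode delegator applies after fanning sub-requests out to workers. The helper name partialResultGate and its signature are illustrative only, not part of the Milvus code (the actual check lives in executeSubTasks in delegator.go further down, which also collects per-segment failure lists and combines errors with merr.Combine); only the rule itself mirrors this commit: compare the fraction of successfully served segments against partialResultRequiredDataRatio, where the default of 1 keeps the old all-or-nothing behavior.

package main

import (
    "errors"
    "fmt"
)

// partialResultGate decides whether a request may be answered with partial
// results: it passes when the fraction of successfully served segments is at
// least requiredRatio, and otherwise returns the combined error. A
// requiredRatio of 1 keeps the old all-or-nothing behavior.
func partialResultGate(successSegments, failedSegments []int64, requiredRatio float64, errs []error) error {
    total := len(successSegments) + len(failedSegments)
    if total == 0 {
        return nil
    }
    accessRatio := float64(len(successSegments)) / float64(total)
    if accessRatio < requiredRatio {
        // too little of the data is reachable: surface the collected failures
        return errors.Join(errs...)
    }
    if accessRatio < 1.0 {
        fmt.Printf("serving partial result from %.0f%% of the segments\n", accessRatio*100)
    }
    return nil
}

func main() {
    failures := []error{errors.New("worker 3 is offline")}
    // 8 of 10 segments reachable, required ratio 0.7: the request still succeeds (partially).
    fmt.Println(partialResultGate([]int64{1, 2, 3, 4, 5, 6, 7, 8}, []int64{9, 10}, 0.7, failures))
    // With the default ratio of 1, the same failure fails the whole request.
    fmt.Println(partialResultGate([]int64{1, 2, 3, 4, 5, 6, 7, 8}, []int64{9, 10}, 1.0, failures))
}

The same ratio also decides, in the Search and Query paths of the delegator below, whether it waits for tsafe and how much of the distribution PinReadableSegments must cover.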

View File

@@ -329,6 +329,7 @@ proxy:
 slowQuerySpanInSeconds: 5 # query whose executed time exceeds the `slowQuerySpanInSeconds` can be considered slow, in seconds.
 queryNodePooling:
 size: 10 # the size for shardleader(querynode) client pool
+partialResultRequiredDataRatio: 1 # partial result required data ratio, default to 1 which means disable partial result, otherwise, it will be used as the minimum data ratio for partial result
 http:
 enabled: true # Whether to enable the http server
 debug_mode: false # Whether to enable http server debug mode

View File

@@ -17,6 +17,7 @@ package proxy
 import (
     "context"
+    "strings"
     "github.com/cockroachdb/errors"
     "github.com/samber/lo"
@@ -119,22 +120,70 @@ func (lb *LBPolicyImpl) GetShardLeaders(ctx context.Context, dbName string, coll
 }
 // try to select the best node from the available nodes
-func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) {
-    filterDelegator := func(nodes []nodeInfo) map[int64]nodeInfo {
-        ret := make(map[int64]nodeInfo)
+func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload *ChannelWorkload, excludeNodes typeutil.UniqueSet) (nodeInfo, error) {
+    log := log.Ctx(ctx)
+    // Select node using specified nodes
+    trySelectNode := func(nodes []nodeInfo) (nodeInfo, error) {
+        candidateNodes := make(map[int64]nodeInfo)
+        serviceableNodes := make(map[int64]nodeInfo)
+        // Filter nodes based on excludeNodes
         for _, node := range nodes {
             if !excludeNodes.Contain(node.nodeID) {
-                ret[node.nodeID] = node
+                if node.serviceable {
+                    serviceableNodes[node.nodeID] = node
+                }
+                candidateNodes[node.nodeID] = node
             }
         }
-        return ret
+        var err error
+        defer func() {
+            if err != nil {
+                candidatesInStr := lo.Map(nodes, func(node nodeInfo, _ int) string {
+                    return node.String()
+                })
+                serviceableNodesInStr := lo.Map(lo.Values(serviceableNodes), func(node nodeInfo, _ int) string {
+                    return node.String()
+                })
+                log.Warn("failed to select shard",
+                    zap.Int64("collectionID", workload.collectionID),
+                    zap.String("channelName", workload.channel),
+                    zap.Int64s("excluded", excludeNodes.Collect()),
+                    zap.String("candidates", strings.Join(candidatesInStr, ", ")),
+                    zap.String("serviceableNodes", strings.Join(serviceableNodesInStr, ", ")),
+                    zap.Error(err))
+            }
+        }()
+        if len(candidateNodes) == 0 {
+            err = merr.WrapErrChannelNotAvailable(workload.channel)
+            return nodeInfo{}, err
+        }
+        balancer.RegisterNodeInfo(lo.Values(candidateNodes))
+        // prefer serviceable nodes
+        var targetNodeID int64
+        if len(serviceableNodes) > 0 {
+            targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(serviceableNodes), workload.nq)
+        } else {
+            targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(candidateNodes), workload.nq)
+        }
+        if err != nil {
+            return nodeInfo{}, err
+        }
+        if _, ok := candidateNodes[targetNodeID]; !ok {
+            err = merr.WrapErrNodeNotAvailable(targetNodeID)
+            return nodeInfo{}, err
+        }
+        return candidateNodes[targetNodeID], nil
     }
-    availableNodes := filterDelegator(workload.shardLeaders)
-    balancer.RegisterNodeInfo(lo.Values(availableNodes))
-    targetNode, err := balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq)
+    // First attempt with current shard leaders
+    targetNode, err := trySelectNode(workload.shardLeaders)
+    // If failed, refresh cache and retry
     if err != nil {
-        log := log.Ctx(ctx)
         globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
         shardLeaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, false)
         if err != nil {
@@ -145,51 +194,41 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
             return nodeInfo{}, err
         }
-        availableNodes = filterDelegator(shardLeaders[workload.channel])
-        if len(availableNodes) == 0 {
-            log.Warn("no available shard delegator found",
-                zap.Int64("collectionID", workload.collectionID),
-                zap.String("channelName", workload.channel),
-                zap.Int64s("availableNodes", lo.Keys(availableNodes)),
-                zap.Int64s("excluded", excludeNodes.Collect()))
-            return nodeInfo{}, merr.WrapErrChannelNotAvailable("no available shard delegator found")
-        }
-        balancer.RegisterNodeInfo(lo.Values(availableNodes))
-        targetNode, err = balancer.SelectNode(ctx, lo.Keys(availableNodes), workload.nq)
+        workload.shardLeaders = shardLeaders[workload.channel]
+        // Second attempt with fresh shard leaders
+        targetNode, err = trySelectNode(workload.shardLeaders)
         if err != nil {
-            log.Warn("failed to select shard",
-                zap.Int64("collectionID", workload.collectionID),
-                zap.String("channelName", workload.channel),
-                zap.Int64s("availableNodes", lo.Keys(availableNodes)),
-                zap.Int64s("excluded", excludeNodes.Collect()),
-                zap.Error(err))
             return nodeInfo{}, err
         }
     }
-    return availableNodes[targetNode], nil
+    return targetNode, nil
 }
 // ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
 func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
-    excludeNodes := typeutil.NewUniqueSet()
     var lastErr error
-    err := retry.Do(ctx, func() error {
+    excludeNodes := typeutil.NewUniqueSet()
+    tryExecute := func() (bool, error) {
+        // if keeping retry after all nodes are excluded, try to clean excludeNodes
+        if excludeNodes.Len() == len(workload.shardLeaders) {
+            excludeNodes.Clear()
+        }
        balancer := lb.getBalancer()
-        targetNode, err := lb.selectNode(ctx, balancer, workload, excludeNodes)
+        targetNode, err := lb.selectNode(ctx, balancer, &workload, excludeNodes)
         if err != nil {
             log.Warn("failed to select node for shard",
                 zap.Int64("collectionID", workload.collectionID),
                 zap.String("channelName", workload.channel),
                 zap.Int64("nodeID", targetNode.nodeID),
+                zap.Int64s("excluded", excludeNodes.Collect()),
                 zap.Error(err),
             )
             if lastErr != nil {
-                return lastErr
+                return true, lastErr
             }
-            return err
+            return true, err
         }
         // cancel work load which assign to the target node
         defer balancer.CancelWorkload(targetNode.nodeID, workload.nq)
@@ -204,7 +243,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
             excludeNodes.Insert(targetNode.nodeID)
             lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.nodeID, workload.channel)
-            return lastErr
+            return true, lastErr
         }
         err = workload.exec(ctx, targetNode.nodeID, client, workload.channel)
@@ -216,11 +255,19 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
                 zap.Error(err))
             excludeNodes.Insert(targetNode.nodeID)
             lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode.nodeID, workload.channel)
-            return lastErr
+            return true, lastErr
         }
-        return nil
-    }, retry.Attempts(workload.retryTimes))
+        return true, nil
+    }
+    // if failed, try to execute with partial result
+    err := retry.Handle(ctx, tryExecute, retry.Attempts(workload.retryTimes))
+    if err != nil {
+        log.Ctx(ctx).Warn("failed to execute with partial result",
+            zap.String("channel", workload.channel),
+            zap.Error(err))
+    }
     return err
 }
@@ -233,8 +280,14 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
         return err
     }
-    // let every request could retry at least twice, which could retry after update shard leader cache
-    wg, ctx := errgroup.WithContext(ctx)
+    totalChannels := len(dml2leaders)
+    if totalChannels == 0 {
+        log.Ctx(ctx).Info("no shard leaders found", zap.Int64("collectionID", workload.collectionID))
+        return merr.WrapErrCollectionNotLoaded(workload.collectionID)
+    }
+    wg, _ := errgroup.WithContext(ctx)
+    // Launch a goroutine for each channel
     for k, v := range dml2leaders {
         channel := k
         nodes := v
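
The selection change above can be read as a two-step preference: drop excluded nodes, prefer shard leaders that are serviceable, and only fall back to an unserviceable candidate when no serviceable one is left, so that a partial result can still be attempted. A minimal self-contained Go sketch under those assumptions; the node type and the pick-the-first-candidate "balancer" below are illustrative stand-ins, not the proxy's nodeInfo or LBBalancer:

package main

import "fmt"

type node struct {
    id          int64
    serviceable bool
}

// selectCandidate mirrors the preference order used by selectNode in the diff:
// filter by the exclude set, prefer serviceable nodes, otherwise fall back to
// any remaining candidate.
func selectCandidate(nodes []node, excluded map[int64]bool) (node, error) {
    candidates := make([]node, 0, len(nodes))
    serviceable := make([]node, 0, len(nodes))
    for _, n := range nodes {
        if excluded[n.id] {
            continue
        }
        candidates = append(candidates, n)
        if n.serviceable {
            serviceable = append(serviceable, n)
        }
    }
    if len(candidates) == 0 {
        return node{}, fmt.Errorf("no shard leader available")
    }
    if len(serviceable) > 0 {
        return serviceable[0], nil // a real balancer would load-balance here
    }
    return candidates[0], nil
}

func main() {
    nodes := []node{{1, false}, {2, true}, {3, true}}
    picked, _ := selectCandidate(nodes, map[int64]bool{2: true})
    fmt.Println("picked node", picked.id) // node 3: serviceable and not excluded
}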

View File

@@ -69,8 +69,9 @@ func (s *LBPolicySuite) SetupTest() {
 for i := 1; i <= 5; i++ {
     s.nodeIDs = append(s.nodeIDs, int64(i))
     s.nodes = append(s.nodes, nodeInfo{
         nodeID:      int64(i),
         address:     "localhost",
+        serviceable: true,
     })
 }
 s.channels = []string{"channel1", "channel2"}
@@ -84,11 +85,13 @@
             ChannelName: s.channels[0],
             NodeIds:     s.nodeIDs,
             NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
+            Serviceable: []bool{true, true, true, true, true},
         },
         {
             ChannelName: s.channels[1],
             NodeIds:     s.nodeIDs,
             NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
+            Serviceable: []bool{true, true, true, true, true},
         },
     },
 }, nil
@@ -175,7 +178,7 @@ func (s *LBPolicySuite) TestSelectNode() {
 ctx := context.Background()
 s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
 s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
-targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
+targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
     db:             dbName,
     collectionName: s.collectionName,
     collectionID:   s.collectionID,
@@ -191,12 +194,12 @@
 s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
 s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
 s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
-targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
+targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
     db:             dbName,
     collectionName: s.collectionName,
     collectionID:   s.collectionID,
     channel:        s.channels[0],
-    shardLeaders:   []nodeInfo{},
+    shardLeaders:   s.nodes,
     nq:             1,
 }, typeutil.NewUniqueSet())
 s.NoError(err)
@@ -206,7 +209,7 @@
 s.lbBalancer.ExpectedCalls = nil
 s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
 s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
-targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
+targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
     db:             dbName,
     collectionName: s.collectionName,
     collectionID:   s.collectionID,
@@ -220,7 +223,7 @@
 s.lbBalancer.ExpectedCalls = nil
 s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
 s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
-targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
+targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
     db:             dbName,
     collectionName: s.collectionName,
     collectionID:   s.collectionID,
@@ -237,7 +240,7 @@
 s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
     return nil, merr.ErrServiceUnavailable
 }
-targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
+targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, &ChannelWorkload{
     db:             dbName,
     collectionName: s.collectionName,
     collectionID:   s.collectionID,
@@ -291,7 +294,7 @@ func (s *LBPolicySuite) TestExecuteWithRetry() {
 // test get client failed, and retry failed, expected success
 s.mgr.ExpectedCalls = nil
-s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
+s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(2)
 s.lbBalancer.ExpectedCalls = nil
 s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
 s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
@@ -313,6 +316,10 @@
 s.mgr.ExpectedCalls = nil
 s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
 s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
+s.lbBalancer.ExpectedCalls = nil
+s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
+    return availableNodes[0], nil
+})
 s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
 s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
 err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
@@ -334,7 +341,9 @@
 s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
 s.lbBalancer.ExpectedCalls = nil
 s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
-s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
+s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
+    return availableNodes[0], nil
+})
 s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
 counter := 0
 err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
@@ -384,7 +393,9 @@ func (s *LBPolicySuite) TestExecute() {
 // test all channel success
 s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
 s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
-s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
+s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
+    return availableNodes[0], nil
+})
 s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
 err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
     db: dbName,

View File

@@ -970,7 +970,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
     }
     metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
-    log.Info("no shard cache for collection, try to get shard leaders from QueryCoord")
 }
 info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
@@ -983,7 +982,8 @@
     commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
     commonpbutil.WithSourceID(paramtable.GetNodeID()),
 ),
-CollectionID: info.collID,
+CollectionID:            info.collID,
+WithUnserviceableShards: true,
 }
 tr := timerecord.NewTimeRecorder("UpdateShardCache")
@@ -1002,6 +1002,19 @@
     idx: atomic.NewInt64(0),
 }
+// convert shards map to string for logging
+if log.Logger.Level() == zap.DebugLevel {
+    shardStr := make([]string, 0, len(shards))
+    for channel, nodes := range shards {
+        nodeStrs := make([]string, 0, len(nodes))
+        for _, node := range nodes {
+            nodeStrs = append(nodeStrs, node.String())
+        }
+        shardStr = append(shardStr, fmt.Sprintf("%s:[%s]", channel, strings.Join(nodeStrs, ", ")))
+    }
+    log.Debug("update shard leader cache", zap.String("newShardLeaders", strings.Join(shardStr, ", ")))
+}
 m.leaderMut.Lock()
 if _, ok := m.collLeader[database]; !ok {
     m.collLeader[database] = make(map[string]*shardLeaders)
@@ -1028,7 +1041,7 @@ func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) m
 qns := make([]nodeInfo, len(leaders.GetNodeIds()))
 for j := range qns {
-    qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
+    qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j], leaders.GetServiceable()[j]}
 }
 shard2QueryNodes[leaders.GetChannelName()] = qns

View File

@@ -1395,6 +1395,7 @@ func TestMetaCache_GetShards(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil
@@ -1455,6 +1456,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil
@@ -1836,6 +1838,7 @@ func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil

View File

@@ -1505,6 +1505,7 @@ func (coord *MixCoordMock) GetShardLeaders(ctx context.Context, in *querypb.GetS
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil

View File

@@ -20,12 +20,13 @@ import (
 type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error)
 type nodeInfo struct {
-    nodeID  UniqueID
-    address string
+    nodeID      UniqueID
+    address     string
+    serviceable bool
 }
 func (n nodeInfo) String() string {
-    return fmt.Sprintf("<NodeID: %d>", n.nodeID)
+    return fmt.Sprintf("<NodeID: %d, serviceable: %v, address: %s>", n.nodeID, n.serviceable, n.address)
 }
 var errClosed = errors.New("client is closed")

View File

@@ -161,6 +161,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)
@@ -185,6 +186,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)
@@ -210,6 +212,7 @@ func TestDropIndexTask_PreExecute(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)
@@ -241,6 +244,7 @@ func getMockQueryCoord() *mocks.MockMixCoordClient {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)

View File

@@ -2994,6 +2994,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil

View File

@@ -65,6 +65,7 @@ func (s *StatisticTaskSuite) SetupTest() {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil

View File

@@ -978,6 +978,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)
@@ -1002,6 +1003,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)
@@ -1026,6 +1028,7 @@ func Test_isCollectionIsLoaded(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)
@@ -1057,6 +1060,7 @@ func Test_isPartitionIsLoaded(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)
@@ -1082,6 +1086,7 @@ func Test_isPartitionIsLoaded(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)
@@ -1107,6 +1112,7 @@ func Test_isPartitionIsLoaded(t *testing.T) {
     ChannelName: "channel-1",
     NodeIds:     []int64{1, 2, 3},
     NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
+    Serviceable: []bool{true, true, true},
 },
 },
 }, nil)

View File

@@ -31,6 +31,7 @@ import (
     "github.com/milvus-io/milvus/internal/querycoordv2/session"
     "github.com/milvus-io/milvus/internal/querycoordv2/utils"
     "github.com/milvus-io/milvus/pkg/v2/log"
+    "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
     "github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
     "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
     "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
@@ -264,6 +265,12 @@ func (ob *TargetObserver) check(ctx context.Context, collectionID int64) {
     if ob.shouldUpdateNextTarget(ctx, collectionID) {
         // update next target in collection level
         ob.updateNextTarget(ctx, collectionID)
+
+        // sync next target to delegator if current target not exist, to support partial search
+        if !ob.targetMgr.IsCurrentTargetExist(ctx, collectionID, -1) {
+            newVersion := ob.targetMgr.GetCollectionTargetVersion(ctx, collectionID, meta.NextTarget)
+            ob.syncNextTargetToDelegator(ctx, collectionID, ob.distMgr.ChannelDistManager.GetByFilter(meta.WithCollectionID2Channel(collectionID)), newVersion)
+        }
     }
 }
@@ -412,14 +419,18 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect
     collReadyDelegatorList = append(collReadyDelegatorList, chReadyDelegatorList...)
 }
+    return ob.syncNextTargetToDelegator(ctx, collectionID, collReadyDelegatorList, newVersion)
+}
+
+// sync next target info to delegator as readable snapshot
+// 1. if next target is changed before delegator becomes serviceable, we need to sync the new next target to delegator to support partial search
+// 2. if next target is ready to read, we need to sync the next target to delegator to support full search
+func (ob *TargetObserver) syncNextTargetToDelegator(ctx context.Context, collectionID int64, collReadyDelegatorList []*meta.DmChannel, newVersion int64) bool {
     var partitions []int64
     var indexInfo []*indexpb.IndexInfo
     var err error
     for _, d := range collReadyDelegatorList {
-        updateVersionAction := ob.checkNeedUpdateTargetVersion(ctx, d.View, newVersion)
-        if updateVersionAction == nil {
-            continue
-        }
+        updateVersionAction := ob.genSyncAction(ctx, d.View, newVersion)
         replica := ob.meta.ReplicaManager.GetByCollectionAndNode(ctx, collectionID, d.Node)
         if replica == nil {
             log.Warn("replica not found", zap.Int64("nodeID", d.Node), zap.Int64("collectionID", collectionID))
@@ -441,19 +452,16 @@ func (ob *TargetObserver) shouldUpdateCurrentTarget(ctx context.Context, collect
         }
     }
-    if !ob.sync(ctx, replica, d.View, []*querypb.SyncAction{updateVersionAction}, partitions, indexInfo) {
+    if !ob.syncToDelegator(ctx, replica, d.View, updateVersionAction, partitions, indexInfo) {
         return false
     }
 }
 return true
 }
-func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, LeaderView *meta.LeaderView, diffs []*querypb.SyncAction,
+func (ob *TargetObserver) syncToDelegator(ctx context.Context, replica *meta.Replica, LeaderView *meta.LeaderView, action *querypb.SyncAction,
     partitions []int64, indexInfo []*indexpb.IndexInfo,
 ) bool {
-    if len(diffs) == 0 {
-        return true
-    }
     replicaID := replica.GetID()
     log := log.With(
@@ -469,7 +477,7 @@ func (ob *TargetObserver) sync(ctx context.Context, replica *meta.Replica, Leade
     CollectionID: LeaderView.CollectionID,
     ReplicaID:    replicaID,
     Channel:      LeaderView.Channel,
-    Actions:      diffs,
+    Actions:      []*querypb.SyncAction{action},
     LoadMeta: &querypb.LoadMetaInfo{
         LoadType:     ob.meta.GetLoadType(ctx, LeaderView.CollectionID),
         CollectionID: LeaderView.CollectionID,
@@ -496,31 +504,34 @@
     return true
 }
-func (ob *TargetObserver) checkNeedUpdateTargetVersion(ctx context.Context, leaderView *meta.LeaderView, targetVersion int64) *querypb.SyncAction {
-    log.Ctx(ctx).WithRateGroup("qcv2.LeaderObserver", 1, 60)
-    if targetVersion <= leaderView.TargetVersion {
-        return nil
-    }
-
-    log.RatedInfo(10, "Update readable segment version",
-        zap.Int64("collectionID", leaderView.CollectionID),
-        zap.String("channelName", leaderView.Channel),
-        zap.Int64("nodeID", leaderView.ID),
-        zap.Int64("oldVersion", leaderView.TargetVersion),
-        zap.Int64("newVersion", targetVersion),
-    )
+// sync next target info to delegator
+// 1. if next target is changed before delegator becomes serviceable, we need to sync the new next target to delegator to support partial search
+// 2. if next target is ready to read, we need to sync the next target to delegator to support full search
+func (ob *TargetObserver) genSyncAction(ctx context.Context, leaderView *meta.LeaderView, targetVersion int64) *querypb.SyncAction {
+    log.Ctx(ctx).WithRateGroup("qcv2.LeaderObserver", 1, 60).
+        RatedInfo(10, "Update readable segment version",
+            zap.Int64("collectionID", leaderView.CollectionID),
+            zap.String("channelName", leaderView.Channel),
+            zap.Int64("nodeID", leaderView.ID),
+            zap.Int64("oldVersion", leaderView.TargetVersion),
+            zap.Int64("newVersion", targetVersion),
+        )
     sealedSegments := ob.targetMgr.GetSealedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
     growingSegments := ob.targetMgr.GetGrowingSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
     droppedSegments := ob.targetMgr.GetDroppedSegmentsByChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTarget)
     channel := ob.targetMgr.GetDmChannel(ctx, leaderView.CollectionID, leaderView.Channel, meta.NextTargetFirst)
+    sealedSegmentRowCount := lo.MapValues(sealedSegments, func(segment *datapb.SegmentInfo, _ int64) int64 {
+        return segment.GetNumOfRows()
+    })
     action := &querypb.SyncAction{
-        Type:            querypb.SyncType_UpdateVersion,
-        GrowingInTarget: growingSegments.Collect(),
-        SealedInTarget:  lo.Keys(sealedSegments),
-        DroppedInTarget: droppedSegments,
-        TargetVersion:   targetVersion,
+        Type:                  querypb.SyncType_UpdateVersion,
+        GrowingInTarget:       growingSegments.Collect(),
+        SealedInTarget:        lo.Keys(sealedSegmentRowCount),
+        DroppedInTarget:       droppedSegments,
+        TargetVersion:         targetVersion,
+        SealedSegmentRowCount: sealedSegmentRowCount,
     }
     if channel != nil {
View File

@@ -263,7 +263,7 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
 }, 7*time.Second, 1*time.Second)
 ch1View := suite.distMgr.ChannelDistManager.GetByFilter(meta.WithChannelName2Channel("channel-1"))[0].View
-action := suite.observer.checkNeedUpdateTargetVersion(ctx, ch1View, 100)
+action := suite.observer.genSyncAction(ctx, ch1View, 100)
 suite.Equal(action.GetDeleteCP().Timestamp, uint64(200))
 }

View File

@@ -902,7 +902,7 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade
 }, nil
 }
-leaders, err := utils.GetShardLeaders(ctx, s.meta, s.targetMgr, s.dist, s.nodeMgr, req.GetCollectionID())
+leaders, err := utils.GetShardLeaders(ctx, s.meta, s.targetMgr, s.dist, s.nodeMgr, req.GetCollectionID(), req.GetWithUnserviceableShards())
 return &querypb.GetShardLeadersResponse{
     Status: merr.Status(err),
     Shards: leaders,

View File

@@ -35,6 +35,7 @@ import (
     "github.com/milvus-io/milvus/internal/querycoordv2/session"
     "github.com/milvus-io/milvus/internal/querycoordv2/utils"
     "github.com/milvus-io/milvus/pkg/v2/log"
+    "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
     "github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
     "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
     "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
@@ -384,6 +385,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error {
     return merr.WrapErrServiceInternal(fmt.Sprintf("failed to get partitions for collection=%d", task.CollectionID()))
 }
+version := ex.targetMgr.GetCollectionTargetVersion(ctx, task.CollectionID(), meta.NextTargetFirst)
 req := packSubChannelRequest(
     task,
     action,
@@ -392,6 +394,7 @@
     dmChannel,
     indexInfo,
     partitions,
+    version,
 )
 err = fillSubChannelRequest(ctx, req, ex.broker, ex.shouldIncludeFlushedSegmentInfo(action.Node()))
 if err != nil {
@@ -400,6 +403,12 @@
     return err
 }
+sealedSegments := ex.targetMgr.GetSealedSegmentsByChannel(ctx, dmChannel.CollectionID, dmChannel.ChannelName, meta.NextTarget)
+sealedSegmentRowCount := lo.MapValues(sealedSegments, func(segment *datapb.SegmentInfo, _ int64) int64 {
+    return segment.GetNumOfRows()
+})
+req.SealedSegmentRowCount = sealedSegmentRowCount
+
 ts := dmChannel.GetSeekPosition().GetTimestamp()
 log.Info("subscribe channel...",
     zap.Uint64("checkpoint", ts),

View File

@@ -208,6 +208,7 @@ func packSubChannelRequest(
     channel *meta.DmChannel,
     indexInfo []*indexpb.IndexInfo,
     partitions []int64,
+    targetVersion int64,
 ) *querypb.WatchDmChannelsRequest {
     return &querypb.WatchDmChannelsRequest{
         Base: commonpbutil.NewMsgBase(
@@ -223,6 +224,7 @@
         ReplicaID:     task.ReplicaID(),
         Version:       time.Now().UnixNano(),
         IndexInfoList: indexInfo,
+        TargetVersion: targetVersion,
     }
 }
@@ -253,6 +255,7 @@ func fillSubChannelRequest(
     req.SegmentInfos = lo.SliceToMap(segmentInfos, func(info *datapb.SegmentInfo) (int64, *datapb.SegmentInfo) {
         return info.GetID(), info
     })
+
     return nil
 }

View File

@@ -100,8 +100,14 @@ func checkLoadStatus(ctx context.Context, m *meta.Meta, collectionID int64) erro
     return nil
 }
-func GetShardLeadersWithChannels(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager,
-    nodeMgr *session.NodeManager, collectionID int64, channels map[string]*meta.DmChannel,
+func GetShardLeadersWithChannels(
+    ctx context.Context,
+    m *meta.Meta,
+    dist *meta.DistributionManager,
+    nodeMgr *session.NodeManager,
+    collectionID int64,
+    channels map[string]*meta.DmChannel,
+    withUnserviceableShards bool,
 ) ([]*querypb.ShardLeadersList, error) {
     ret := make([]*querypb.ShardLeadersList, 0)
@@ -111,9 +117,10 @@
     ids := make([]int64, 0, len(replicas))
     addrs := make([]string, 0, len(replicas))
+    serviceable := make([]bool, 0, len(replicas))
     for _, replica := range replicas {
         leader := dist.ChannelDistManager.GetShardLeader(channel.GetChannelName(), replica)
-        if leader == nil || !leader.IsServiceable() {
+        if leader == nil || (!withUnserviceableShards && !leader.IsServiceable()) {
             log.WithRateGroup("util.GetShardLeaders", 1, 60).
                 Warn("leader is not available in replica", zap.String("channel", channel.GetChannelName()), zap.Int64("replicaID", replica.GetID()))
             continue
@@ -122,11 +129,11 @@
         if info != nil {
             ids = append(ids, info.ID())
             addrs = append(addrs, info.Addr())
+            serviceable = append(serviceable, leader.IsServiceable())
         }
     }
-    // to avoid node down during GetShardLeaders
-    if len(ids) == 0 {
+    if len(ids) == 0 && !withUnserviceableShards {
         err := merr.WrapErrChannelNotAvailable(channel.GetChannelName())
         msg := fmt.Sprintf("channel %s is not available in any replica", channel.GetChannelName())
         log.Warn(msg, zap.Error(err))
@@ -137,13 +144,22 @@
         ChannelName: channel.GetChannelName(),
         NodeIds:     ids,
         NodeAddrs:   addrs,
+        Serviceable: serviceable,
     })
 }
 return ret, nil
 }
-func GetShardLeaders(ctx context.Context, m *meta.Meta, targetMgr meta.TargetManagerInterface, dist *meta.DistributionManager, nodeMgr *session.NodeManager, collectionID int64) ([]*querypb.ShardLeadersList, error) {
+func GetShardLeaders(ctx context.Context,
+    m *meta.Meta,
+    targetMgr meta.TargetManagerInterface,
+    dist *meta.DistributionManager,
+    nodeMgr *session.NodeManager,
+    collectionID int64,
+    withUnserviceableShards bool,
+) ([]*querypb.ShardLeadersList, error) {
+    // skip check load status if withUnserviceableShards is true
     if err := checkLoadStatus(ctx, m, collectionID); err != nil {
         return nil, err
     }
@@ -155,7 +171,7 @@
     log.Ctx(ctx).Warn("failed to get channels", zap.Error(err))
     return nil, err
 }
-return GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels)
+return GetShardLeadersWithChannels(ctx, m, dist, nodeMgr, collectionID, channels, withUnserviceableShards)
 }
 // CheckCollectionsQueryable check all channels are watched and all segments are loaded for this collection
@@ -194,7 +210,7 @@ func checkCollectionQueryable(ctx context.Context, m *meta.Meta, targetMgr meta.
     return err
 }
-shardList, err := GetShardLeadersWithChannels(ctx, m, targetMgr, dist, nodeMgr, collectionID, channels)
+shardList, err := GetShardLeadersWithChannels(ctx, m, dist, nodeMgr, collectionID, channels, false)
 if err != nil {
     return err
 }

View File

@ -30,12 +30,12 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/cluster"
@ -88,8 +88,8 @@ type ShardDelegator interface {
LoadL0(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error LoadL0(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error
LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error
ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error
SyncTargetVersion(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, droppedInTarget []int64, checkpoint *msgpb.MsgPosition, deleteSeekPos *msgpb.MsgPosition) SyncTargetVersion(action *querypb.SyncAction, partitions []int64)
GetQueryView() *channelQueryView GetChannelQueryView() *channelQueryView
GetDeleteBufferSize() (entryNum int64, memorySize int64) GetDeleteBufferSize() (entryNum int64, memorySize int64)
// manage exclude segments // manage exclude segments
@ -369,21 +369,32 @@ func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest
req.Req.GetIsIterator(), req.Req.GetIsIterator(),
) )
partialResultRequiredDataRatio := paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
// wait tsafe // wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe") waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp) var tSafe uint64
if err != nil { var err error
log.Warn("delegator search failed to wait tsafe", zap.Error(err)) if partialResultRequiredDataRatio >= 1.0 {
return nil, err tSafe, err = sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
} if err != nil {
if req.GetReq().GetMvccTimestamp() == 0 { log.Warn("delegator search failed to wait tsafe", zap.Error(err))
req.Req.MvccTimestamp = tSafe return nil, err
}
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
}
} else {
tSafe = sd.GetTSafe()
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
}
} }
metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues( metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel). fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds())) Observe(float64(waitTr.ElapseSpan().Milliseconds()))
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...) sealed, growing, version, err := sd.distribution.PinReadableSegments(partialResultRequiredDataRatio, req.GetReq().GetPartitionIDs()...)
if err != nil { if err != nil {
log.Warn("delegator failed to search, current distribution is not serviceable", zap.Error(err)) log.Warn("delegator failed to search, current distribution is not serviceable", zap.Error(err))
return nil, err return nil, err
@ -500,7 +511,7 @@ func (sd *shardDelegator) QueryStream(ctx context.Context, req *querypb.QueryReq
fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel). fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds())) Observe(float64(waitTr.ElapseSpan().Milliseconds()))
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...) sealed, growing, version, err := sd.distribution.PinReadableSegments(float64(1.0), req.GetReq().GetPartitionIDs()...)
if err != nil { if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err)) log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err))
return err return err
@ -562,21 +573,31 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest)
req.Req.GetIsIterator(), req.Req.GetIsIterator(),
) )
partialResultRequiredDataRatio := paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
// wait tsafe // wait tsafe
waitTr := timerecord.NewTimeRecorder("wait tSafe") waitTr := timerecord.NewTimeRecorder("wait tSafe")
tSafe, err := sd.waitTSafe(ctx, req.Req.GetGuaranteeTimestamp()) var tSafe uint64
if err != nil { var err error
log.Warn("delegator query failed to wait tsafe", zap.Error(err)) if partialResultRequiredDataRatio >= 1.0 {
return nil, err tSafe, err = sd.waitTSafe(ctx, req.Req.GuaranteeTimestamp)
} if err != nil {
if req.GetReq().GetMvccTimestamp() == 0 { log.Warn("delegator search failed to wait tsafe", zap.Error(err))
req.Req.MvccTimestamp = tSafe return nil, err
}
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = tSafe
}
} else {
if req.GetReq().GetMvccTimestamp() == 0 {
req.Req.MvccTimestamp = sd.GetTSafe()
}
} }
metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues( metrics.QueryNodeSQLatencyWaitTSafe.WithLabelValues(
fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel). fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).
Observe(float64(waitTr.ElapseSpan().Milliseconds())) Observe(float64(waitTr.ElapseSpan().Milliseconds()))
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.GetReq().GetPartitionIDs()...) sealed, growing, version, err := sd.distribution.PinReadableSegments(partialResultRequiredDataRatio, req.GetReq().GetPartitionIDs()...)
if err != nil { if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err)) log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err))
return nil, err return nil, err
@ -646,7 +667,7 @@ func (sd *shardDelegator) GetStatistics(ctx context.Context, req *querypb.GetSta
return nil, err return nil, err
} }
sealed, growing, version, err := sd.distribution.PinReadableSegments(req.Req.GetPartitionIDs()...) sealed, growing, version, err := sd.distribution.PinReadableSegments(1.0, req.Req.GetPartitionIDs()...)
if err != nil { if err != nil {
log.Warn("delegator failed to GetStatistics, current distribution is not servicable") log.Warn("delegator failed to GetStatistics, current distribution is not servicable")
return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable") return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not serviceable")
@@ -708,13 +729,14 @@ func organizeSubTask[T any](ctx context.Context,
 		// update request
 		req := modify(req, scope, segmentIDs, workerID)
+		// for partial search, tolerate some workers being offline
 		worker, err := sd.workerManager.GetWorker(ctx, workerID)
 		if err != nil {
-			log.Warn("failed to get worker",
+			log.Warn("failed to get worker for sub task",
 				zap.Int64("nodeID", workerID),
+				zap.Int64s("segments", segmentIDs),
 				zap.Error(err),
 			)
-			return fmt.Errorf("failed to get worker %d, %w", workerID, err)
 		}
 		result = append(result, subTask[T]{
@@ -744,50 +766,110 @@ func executeSubTasks[T any, R interface {
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
-	var wg sync.WaitGroup
-	wg.Add(len(tasks))
-	resultCh := make(chan R, len(tasks))
-	errCh := make(chan error, 1)
+	var partialResultRequiredDataRatio float64
+	if taskType == "Query" || taskType == "Search" {
+		partialResultRequiredDataRatio = paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
+	} else {
+		partialResultRequiredDataRatio = 1.0
+	}
+	wg, ctx := errgroup.WithContext(ctx)
+	type channelResult struct {
+		nodeID   int64
+		segments []int64
+		result   R
+		err      error
+	}
+	// Buffered channel to collect results from all goroutines
+	resultCh := make(chan channelResult, len(tasks))
 	for _, task := range tasks {
-		go func(task subTask[T]) {
-			defer wg.Done()
-			result, err := execute(ctx, task.req, task.worker)
-			if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
-				err = fmt.Errorf("worker(%d) query failed: %s", task.targetID, result.GetStatus().GetReason())
+		task := task // capture loop variable
+		wg.Go(func() error {
+			var result R
+			var err error
+			if task.targetID == -1 || task.worker == nil {
+				var segments []int64
+				if req, ok := any(task.req).(interface{ GetSegmentIDs() []int64 }); ok {
+					segments = req.GetSegmentIDs()
+				} else {
+					segments = []int64{}
+				}
+				err = fmt.Errorf("segments not loaded in any worker: %v", segments[:min(len(segments), 10)])
+			} else {
+				result, err = execute(ctx, task.req, task.worker)
+				if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
+					err = fmt.Errorf("worker(%d) query failed: %s", task.targetID, result.GetStatus().GetReason())
+				}
 			}
 			if err != nil {
 				log.Warn("failed to execute sub task",
 					zap.String("taskType", taskType),
 					zap.Int64("nodeID", task.targetID),
 					zap.Error(err),
 				)
-				select {
-				case errCh <- err: // must be the first
-				default: // skip other errors
+				// check if partial result is disabled, if so, let all sub tasks fail fast
+				if partialResultRequiredDataRatio == 1 {
+					return err
 				}
-				cancel()
-				return
 			}
-			resultCh <- result
-		}(task)
+			taskResult := channelResult{
+				nodeID: task.targetID,
+				result: result,
+				err:    err,
+			}
+			if req, ok := any(task.req).(interface{ GetSegmentIDs() []int64 }); ok {
+				taskResult.segments = req.GetSegmentIDs()
+			}
+			resultCh <- taskResult
+			return nil
+		})
 	}
-	wg.Wait()
-	close(resultCh)
-	select {
-	case err := <-errCh:
-		log.Warn("Delegator execute subTask failed",
+	// Wait for all tasks to complete
+	if err := wg.Wait(); err != nil {
+		log.Warn("some tasks failed to complete",
 			zap.String("taskType", taskType),
 			zap.Error(err),
 		)
 		return nil, err
-	default:
 	}
-	results := make([]R, 0, len(tasks))
-	for result := range resultCh {
-		results = append(results, result)
-	}
+	close(resultCh)
+	successSegmentList := []int64{}
+	failureSegmentList := []int64{}
+	var errors []error
+	// Collect results
+	results := make([]R, 0, len(tasks))
+	for item := range resultCh {
+		if item.err == nil {
+			successSegmentList = append(successSegmentList, item.segments...)
+			results = append(results, item.result)
+		} else {
+			failureSegmentList = append(failureSegmentList, item.segments...)
+			errors = append(errors, item.err)
+		}
+	}
+	accessDataRatio := 1.0
+	totalSegments := len(successSegmentList) + len(failureSegmentList)
+	if totalSegments > 0 {
+		accessDataRatio = float64(len(successSegmentList)) / float64(totalSegments)
+		if accessDataRatio < 1.0 {
+			log.Info("partial result executed successfully",
+				zap.String("taskType", taskType),
+				zap.Float64("successRatio", accessDataRatio),
+				zap.Float64("partialResultRequiredDataRatio", partialResultRequiredDataRatio),
+				zap.Int("totalSegments", totalSegments),
+				zap.Int64s("failureSegmentList", failureSegmentList),
+			)
+		}
+	}
+	if accessDataRatio < partialResultRequiredDataRatio {
+		return nil, merr.Combine(errors...)
 	}
 	return results, nil
 }
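For illustration, the acceptance rule in executeSubTasks reduces to a single ratio check over succeeded versus failed segments. The sketch below is standalone, uses hypothetical names, and is not part of the patch:

package main

import (
	"errors"
	"fmt"
)

// decidePartialResult mirrors the check above: a result may be returned when the
// ratio of accessible segments meets the configured minimum; requiredRatio == 1
// means partial results are disabled and any failure rejects the whole request.
func decidePartialResult(successSegments, failedSegments []int64, requiredRatio float64) (float64, error) {
	total := len(successSegments) + len(failedSegments)
	if total == 0 {
		return 1.0, nil // nothing to read, treat as fully covered
	}
	ratio := float64(len(successSegments)) / float64(total)
	if ratio < requiredRatio {
		return ratio, errors.New("accessible data ratio below the required threshold")
	}
	return ratio, nil
}

func main() {
	ratio, err := decidePartialResult([]int64{1, 2, 3}, []int64{4}, 0.5)
	fmt.Println(ratio, err) // 0.75 <nil>: partial result accepted
	ratio, err = decidePartialResult([]int64{1, 2, 3}, []int64{4}, 1.0)
	fmt.Println(ratio, err) // 0.75 with error: partial results disabled
}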
@ -891,7 +973,7 @@ func (sd *shardDelegator) UpdateSchema(ctx context.Context, schema *schemapb.Col
log.Info("delegator received update schema event") log.Info("delegator received update schema event")
sealed, growing, version, err := sd.distribution.PinReadableSegments() sealed, growing, version, err := sd.distribution.PinReadableSegments(1.0)
if err != nil { if err != nil {
log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err)) log.Warn("delegator failed to query, current distribution is not serviceable", zap.Error(err))
return err return err
@ -1011,6 +1093,7 @@ func (sd *shardDelegator) loadPartitionStats(ctx context.Context, partStatsVersi
func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64, func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64,
workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader, workerManager cluster.Manager, manager *segments.Manager, loader segments.Loader,
factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager, factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager,
queryView *channelQueryView,
) (ShardDelegator, error) { ) (ShardDelegator, error) {
log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID), log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID),
zap.Int64("replicaID", replicaID), zap.Int64("replicaID", replicaID),
@ -1041,7 +1124,7 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni
segmentManager: manager.Segment, segmentManager: manager.Segment,
workerManager: workerManager, workerManager: workerManager,
lifetime: lifetime.NewLifetime(lifetime.Initializing), lifetime: lifetime.NewLifetime(lifetime.Initializing),
distribution: NewDistribution(channel), distribution: NewDistribution(channel, queryView),
deleteBuffer: deletebuffer.NewListDeleteBuffer[*deletebuffer.Item](startTs, sizePerBlock, deleteBuffer: deletebuffer.NewListDeleteBuffer[*deletebuffer.Item](startTs, sizePerBlock,
[]string{fmt.Sprint(paramtable.GetNodeID()), channel}), []string{fmt.Sprint(paramtable.GetNodeID()), channel}),
pkOracle: pkoracle.NewPkOracle(), pkOracle: pkoracle.NewPkOracle(),

@ -959,70 +959,45 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele
return nil return nil
} }
func (sd *shardDelegator) SyncTargetVersion( func (sd *shardDelegator) SyncTargetVersion(action *querypb.SyncAction, partitions []int64) {
newVersion int64, sd.distribution.SyncTargetVersion(action, partitions)
partitions []int64, // clean delete buffer after distribution becomes serviceable
growingInTarget []int64, if sd.distribution.queryView.Serviceable() {
sealedInTarget []int64, checkpoint := action.GetCheckpoint()
droppedInTarget []int64, deleteSeekPos := action.GetDeleteCP()
checkpoint *msgpb.MsgPosition, if deleteSeekPos == nil {
deleteSeekPos *msgpb.MsgPosition, // for compatible with 2.4, we use checkpoint as deleteCP when deleteCP is nil
) { deleteSeekPos = checkpoint
growings := sd.segmentManager.GetBy( log.Info("use checkpoint as deleteCP",
segments.WithType(segments.SegmentTypeGrowing), zap.String("channelName", sd.vchannelName),
segments.WithChannel(sd.vchannelName), zap.Time("deleteSeekPos", tsoutil.PhysicalTime(action.GetCheckpoint().GetTimestamp())))
)
sealedSet := typeutil.NewUniqueSet(sealedInTarget...)
growingSet := typeutil.NewUniqueSet(growingInTarget...)
droppedSet := typeutil.NewUniqueSet(droppedInTarget...)
redundantGrowing := typeutil.NewUniqueSet()
for _, s := range growings {
if growingSet.Contain(s.ID()) {
continue
} }
// sealed segment already exists, make growing segment redundant start := time.Now()
if sealedSet.Contain(s.ID()) { sizeBeforeClean, _ := sd.deleteBuffer.Size()
redundantGrowing.Insert(s.ID()) l0NumBeforeClean := len(sd.deleteBuffer.ListL0())
} sd.deleteBuffer.UnRegister(deleteSeekPos.GetTimestamp())
sizeAfterClean, _ := sd.deleteBuffer.Size()
// sealed segment already dropped, make growing segment redundant l0NumAfterClean := len(sd.deleteBuffer.ListL0())
if droppedSet.Contain(s.ID()) {
redundantGrowing.Insert(s.ID()) if sizeAfterClean < sizeBeforeClean || l0NumAfterClean < l0NumBeforeClean {
log.Info("clean delete buffer",
zap.String("channel", sd.vchannelName),
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(deleteSeekPos.GetTimestamp())),
zap.Time("channelCP", tsoutil.PhysicalTime(checkpoint.GetTimestamp())),
zap.Int64("sizeBeforeClean", sizeBeforeClean),
zap.Int64("sizeAfterClean", sizeAfterClean),
zap.Int("l0NumBeforeClean", l0NumBeforeClean),
zap.Int("l0NumAfterClean", l0NumAfterClean),
zap.Duration("cost", time.Since(start)),
)
} }
sd.RefreshLevel0DeletionStats()
} }
redundantGrowingIDs := redundantGrowing.Collect()
if len(redundantGrowing) > 0 {
log.Warn("found redundant growing segments",
zap.Int64s("growingSegments", redundantGrowingIDs))
}
sd.distribution.SyncTargetVersion(newVersion, partitions, growingInTarget, sealedInTarget, redundantGrowingIDs)
start := time.Now()
sizeBeforeClean, _ := sd.deleteBuffer.Size()
l0NumBeforeClean := len(sd.deleteBuffer.ListL0())
sd.deleteBuffer.UnRegister(deleteSeekPos.GetTimestamp())
sizeAfterClean, _ := sd.deleteBuffer.Size()
l0NumAfterClean := len(sd.deleteBuffer.ListL0())
if sizeAfterClean < sizeBeforeClean || l0NumAfterClean < l0NumBeforeClean {
log.Info("clean delete buffer",
zap.String("channel", sd.vchannelName),
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(deleteSeekPos.GetTimestamp())),
zap.Time("channelCP", tsoutil.PhysicalTime(checkpoint.GetTimestamp())),
zap.Int64("sizeBeforeClean", sizeBeforeClean),
zap.Int64("sizeAfterClean", sizeAfterClean),
zap.Int("l0NumBeforeClean", l0NumBeforeClean),
zap.Int("l0NumAfterClean", l0NumAfterClean),
zap.Duration("cost", time.Since(start)),
)
}
sd.RefreshLevel0DeletionStats()
} }
func (sd *shardDelegator) GetQueryView() *channelQueryView { func (sd *shardDelegator) GetChannelQueryView() *channelQueryView {
return sd.distribution.GetQueryView() return sd.distribution.queryView
} }
func (sd *shardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) { func (sd *shardDelegator) AddExcludedSegments(excludeInfo map[int64]uint64) {

@ -194,7 +194,7 @@ func (s *DelegatorDataSuite) genCollectionWithFunction() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil return s.mq, nil
}, },
}, 10000, nil, s.chunkManager) }, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err) s.NoError(err)
s.delegator = delegator.(*shardDelegator) s.delegator = delegator.(*shardDelegator)
} }
@ -216,7 +216,7 @@ func (s *DelegatorDataSuite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil return s.mq, nil
}, },
}, 10000, nil, s.chunkManager) }, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err) s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator) sd, ok := delegator.(*shardDelegator)
s.Require().True(ok) s.Require().True(ok)
@ -419,7 +419,16 @@ func (s *DelegatorDataSuite) TestProcessDelete() {
s.Require().NoError(err) s.Require().NoError(err)
// sync target version, make delegator serviceable // sync target version, make delegator serviceable
s.delegator.SyncTargetVersion(time.Now().UnixNano(), []int64{500}, []int64{1001}, []int64{1000}, nil, &msgpb.MsgPosition{}, &msgpb.MsgPosition{}) s.delegator.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 2001,
GrowingInTarget: []int64{1001},
SealedSegmentRowCount: map[int64]int64{
1000: 100,
},
DroppedInTarget: []int64{},
Checkpoint: &msgpb.MsgPosition{},
DeleteCP: &msgpb.MsgPosition{},
}, []int64{500, 501})
s.delegator.ProcessDelete([]*DeleteData{ s.delegator.ProcessDelete([]*DeleteData{
{ {
PartitionID: 500, PartitionID: 500,
@ -799,7 +808,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil return s.mq, nil
}, },
}, 10000, nil, nil) }, 10000, nil, nil, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err) s.NoError(err)
growing0 := segments.NewMockSegment(s.T()) growing0 := segments.NewMockSegment(s.T())
@ -1422,8 +1431,15 @@ func (s *DelegatorDataSuite) TestSyncTargetVersion() {
s.manager.Segment.Put(context.Background(), segments.SegmentTypeGrowing, ms) s.manager.Segment.Put(context.Background(), segments.SegmentTypeGrowing, ms)
} }
s.delegator.SyncTargetVersion(int64(5), []int64{1}, []int64{1}, []int64{2}, []int64{3, 4}, &msgpb.MsgPosition{}, &msgpb.MsgPosition{}) s.delegator.SyncTargetVersion(&querypb.SyncAction{
s.Equal(int64(5), s.delegator.GetQueryView().GetVersion()) TargetVersion: 5,
GrowingInTarget: []int64{1},
SealedInTarget: []int64{2},
DroppedInTarget: []int64{3, 4},
Checkpoint: &msgpb.MsgPosition{},
DeleteCP: &msgpb.MsgPosition{},
}, []int64{500, 501})
s.Equal(int64(5), s.delegator.GetChannelQueryView().GetVersion())
} }
func (s *DelegatorDataSuite) TestLevel0Deletions() { func (s *DelegatorDataSuite) TestLevel0Deletions() {

@ -163,7 +163,7 @@ func (s *DelegatorSuite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil return s.mq, nil
}, },
}, 10000, nil, s.chunkManager) }, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err) s.Require().NoError(err)
} }
@ -202,7 +202,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil return s.mq, nil
}, },
}, 10000, nil, s.chunkManager) }, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Error(err) s.Error(err)
}) })
@ -245,7 +245,7 @@ func (s *DelegatorSuite) TestCreateDelegatorWithFunction() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil return s.mq, nil
}, },
}, 10000, nil, s.chunkManager) }, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.NoError(err) s.NoError(err)
}) })
} }
@ -325,7 +325,19 @@ func (s *DelegatorSuite) initSegments() {
Version: 2001, Version: 2001,
}, },
) )
s.delegator.SyncTargetVersion(2001, []int64{500, 501}, []int64{1004}, []int64{1000, 1001, 1002, 1003}, []int64{}, &msgpb.MsgPosition{}, &msgpb.MsgPosition{}) s.delegator.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 2001,
GrowingInTarget: []int64{1004},
SealedSegmentRowCount: map[int64]int64{
1000: 100,
1001: 100,
1002: 100,
1003: 100,
},
DroppedInTarget: []int64{},
Checkpoint: &msgpb.MsgPosition{},
DeleteCP: &msgpb.MsgPosition{},
}, []int64{500, 501})
} }
func (s *DelegatorSuite) TestSearch() { func (s *DelegatorSuite) TestSearch() {
@ -888,7 +900,8 @@ func (s *DelegatorSuite) TestQueryStream() {
Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()}, Req: &internalpb.RetrieveRequest{Base: commonpbutil.NewMsgBase()},
DmlChannels: []string{s.vchannelName}, DmlChannels: []string{s.vchannelName},
}, server) }, server)
s.True(errors.Is(err, mockErr)) s.Error(err)
s.ErrorContains(err, "segments not loaded in any worker")
}) })
s.Run("worker_return_error", func() { s.Run("worker_return_error", func() {
@ -1251,7 +1264,7 @@ func (s *DelegatorSuite) TestUpdateSchema() {
s.Run("worker_manager_error", func() { s.Run("worker_manager_error", func() {
s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).RunAndReturn(func(ctx context.Context, i int64) (cluster.Worker, error) { s.workerManager.EXPECT().GetWorker(mock.Anything, mock.AnythingOfType("int64")).RunAndReturn(func(ctx context.Context, i int64) (cluster.Worker, error) {
return nil, merr.WrapErrServiceInternal("mocked") return nil, merr.WrapErrServiceInternal("mocked")
}).Once() })
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()

@ -155,7 +155,7 @@ func (s *StreamingForwardSuite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil return s.mq, nil
}, },
}, 10000, nil, s.chunkManager) }, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err) s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator) sd, ok := delegator.(*shardDelegator)
@ -185,7 +185,13 @@ func (s *StreamingForwardSuite) TestBFStreamingForward() {
PartitionID: 1, PartitionID: 1,
SegmentID: 102, SegmentID: 102,
}) })
delegator.distribution.SyncTargetVersion(1, []int64{1}, []int64{100}, []int64{101, 102}, nil) delegator.distribution.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1,
GrowingInTarget: []int64{100},
SealedInTarget: []int64{101, 102},
DroppedInTarget: nil,
Checkpoint: nil,
}, []int64{1})
// Setup pk oracle // Setup pk oracle
// empty bfs will not match // empty bfs will not match
@ -238,7 +244,13 @@ func (s *StreamingForwardSuite) TestDirectStreamingForward() {
PartitionID: 1, PartitionID: 1,
SegmentID: 102, SegmentID: 102,
}) })
delegator.distribution.SyncTargetVersion(1, []int64{1}, []int64{100}, []int64{101, 102}, nil) delegator.distribution.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1,
GrowingInTarget: []int64{100},
SealedInTarget: []int64{101, 102},
DroppedInTarget: nil,
Checkpoint: nil,
}, []int64{1})
// Setup pk oracle // Setup pk oracle
// empty bfs will not match // empty bfs will not match
@ -386,7 +398,7 @@ func (s *GrowingMergeL0Suite) SetupTest() {
NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) {
return s.mq, nil return s.mq, nil
}, },
}, 10000, nil, s.chunkManager) }, 10000, nil, s.chunkManager, NewChannelQueryView(nil, nil, nil, initialTargetVersion))
s.Require().NoError(err) s.Require().NoError(err)
sd, ok := delegator.(*shardDelegator) sd, ok := delegator.(*shardDelegator)

@ -17,6 +17,7 @@
package delegator package delegator
import ( import (
"fmt"
"sync" "sync"
"github.com/samber/lo" "github.com/samber/lo"
@ -25,6 +26,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
) )
@ -56,12 +58,26 @@ func getClosedCh() chan struct{} {
} }
// channelQueryView maintains the sealed segment list which should be used for search/query. // channelQueryView maintains the sealed segment list which should be used for search/query.
// for a new delegator, a fresh channelQueryView comes from WatchChannel, and queryView updates come from querycoord before it becomes serviceable
// after the delegator becomes serviceable, the queryView is only updated via SyncTargetVersion
type channelQueryView struct { type channelQueryView struct {
sealedSegments []int64 // sealed segment list which should be used for search/query growingSegments typeutil.UniqueSet // growing segment list which should be used for search/query
partitions typeutil.UniqueSet // partitions list which sealed segments belong to sealedSegmentRowCount map[int64]int64 // sealed segment list which should be used for search/query, segmentID -> row count
version int64 // version of current query view, same as targetVersion in qc partitions typeutil.UniqueSet // partitions list which sealed segments belong to
version int64 // version of current query view, same as targetVersion in qc
serviceable *atomic.Bool loadedRatio *atomic.Float64 // loaded ratio of current query view, set serviceable to true if loadedRatio == 1.0
unloadedSealedSegments []SegmentEntry // workerID -> -1
}
func NewChannelQueryView(growings []int64, sealedSegmentRowCount map[int64]int64, partitions []int64, version int64) *channelQueryView {
return &channelQueryView{
growingSegments: typeutil.NewUniqueSet(growings...),
sealedSegmentRowCount: sealedSegmentRowCount,
partitions: typeutil.NewUniqueSet(partitions...),
version: version,
loadedRatio: atomic.NewFloat64(0),
}
} }
func (q *channelQueryView) GetVersion() int64 { func (q *channelQueryView) GetVersion() int64 {
@ -69,7 +85,11 @@ func (q *channelQueryView) GetVersion() int64 {
} }
func (q *channelQueryView) Serviceable() bool { func (q *channelQueryView) Serviceable() bool {
return q.serviceable.Load() return q.loadedRatio.Load() >= 1.0
}
func (q *channelQueryView) GetLoadedRatio() float64 {
return q.loadedRatio.Load()
} }
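A reduced illustration of the serviceable rule introduced here, assuming the same go.uber.org/atomic package as the surrounding code: the view counts as serviceable only once its loaded ratio reaches 1.0.

package main

import (
	"fmt"

	"go.uber.org/atomic"
)

// miniQueryView keeps only the field needed to show the rule; it is not the real channelQueryView.
type miniQueryView struct {
	loadedRatio *atomic.Float64
}

func (q *miniQueryView) Serviceable() bool { return q.loadedRatio.Load() >= 1.0 }

func main() {
	v := &miniQueryView{loadedRatio: atomic.NewFloat64(0)}
	v.loadedRatio.Store(0.8)
	fmt.Println(v.Serviceable()) // false: only 80% of expected sealed rows are loaded
	v.loadedRatio.Store(1.0)
	fmt.Println(v.Serviceable()) // true: full coverage
}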
// distribution is the struct to store segment distribution. // distribution is the struct to store segment distribution.
@ -107,22 +127,17 @@ type SegmentEntry struct {
Offline bool // if delegator failed to execute forwardDelete/Query/Search on segment, it will be offline Offline bool // if delegator failed to execute forwardDelete/Query/Search on segment, it will be offline
} }
// NewDistribution creates a new distribution instance with all field initialized. func NewDistribution(channelName string, queryView *channelQueryView) *distribution {
func NewDistribution(channelName string) *distribution {
dist := &distribution{ dist := &distribution{
channelName: channelName, channelName: channelName,
growingSegments: make(map[UniqueID]SegmentEntry), growingSegments: make(map[UniqueID]SegmentEntry),
sealedSegments: make(map[UniqueID]SegmentEntry), sealedSegments: make(map[UniqueID]SegmentEntry),
snapshots: typeutil.NewConcurrentMap[int64, *snapshot](), snapshots: typeutil.NewConcurrentMap[int64, *snapshot](),
current: atomic.NewPointer[snapshot](nil), current: atomic.NewPointer[snapshot](nil),
queryView: &channelQueryView{ queryView: queryView,
serviceable: atomic.NewBool(false),
partitions: typeutil.NewSet[int64](),
version: initialTargetVersion,
},
} }
dist.genSnapshot() dist.genSnapshot()
dist.updateServiceable("NewDistribution")
return dist return dist
} }
@ -132,12 +147,14 @@ func (d *distribution) SetIDFOracle(idfOracle IDFOracle) {
d.idfOracle = idfOracle d.idfOracle = idfOracle
} }
func (d *distribution) PinReadableSegments(partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry, version int64, err error) { // return segment distribution in query view
func (d *distribution) PinReadableSegments(requiredLoadRatio float64, partitions ...int64) (sealed []SnapshotItem, growing []SegmentEntry, version int64, err error) {
d.mut.RLock() d.mut.RLock()
defer d.mut.RUnlock() defer d.mut.RUnlock()
if !d.Serviceable() { if d.queryView.GetLoadedRatio() < requiredLoadRatio {
return nil, nil, -1, merr.WrapErrChannelNotAvailable("channel distribution is not serviceable") return nil, nil, -1, merr.WrapErrChannelNotAvailable(d.channelName,
fmt.Sprintf("channel distribution is not serviceable, required load ratio is %f, current load ratio is %f", requiredLoadRatio, d.queryView.GetLoadedRatio()))
} }
current := d.current.Load() current := d.current.Load()
@ -153,6 +170,15 @@ func (d *distribution) PinReadableSegments(partitions ...int64) (sealed []Snapsh
targetVersion := current.GetTargetVersion() targetVersion := current.GetTargetVersion()
filterReadable := d.readableFilter(targetVersion) filterReadable := d.readableFilter(targetVersion)
sealed, growing = d.filterSegments(sealed, growing, filterReadable) sealed, growing = d.filterSegments(sealed, growing, filterReadable)
if len(d.queryView.unloadedSealedSegments) > 0 {
// append distribution of unloaded segment
sealed = append(sealed, SnapshotItem{
NodeID: -1,
Segments: d.queryView.unloadedSealedSegments,
})
}
return return
} }
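To show how callers are expected to use the new signature, here is a hypothetical wrapper (assumed to live in the delegator package; not part of the patch): a required ratio of 1.0 demands a fully serviceable distribution, a lower value tolerates missing data, and unloaded sealed segments come back grouped under node ID -1.

// pinForPartialSearch sketches a caller of PinReadableSegments that tolerates
// partially loaded distributions.
func pinForPartialSearch(dist *distribution, requiredRatio float64, partitions ...int64) error {
	sealed, growing, version, err := dist.PinReadableSegments(requiredRatio, partitions...)
	if err != nil {
		// loaded ratio is below requiredRatio: not even a partial result can be served
		return err
	}
	defer dist.Unpin(version)
	for _, item := range sealed {
		if item.NodeID == -1 {
			// entries under node -1 are unloaded sealed segments kept for partial-result accounting
			continue
		}
		_ = item.Segments // dispatch per-node sub tasks here
	}
	_ = growing
	return nil
}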
@ -213,26 +239,50 @@ func (d *distribution) getTargetVersion() int64 {
// Serviceable returns wether current snapshot is serviceable. // Serviceable returns wether current snapshot is serviceable.
func (d *distribution) Serviceable() bool { func (d *distribution) Serviceable() bool {
return d.queryView.serviceable.Load() return d.queryView.Serviceable()
} }
// for now, the delegator becomes serviceable only after watchDmChannel is done,
// so we regard all needed growing segments as loaded and compute the load ratio from sealed segments only
func (d *distribution) updateServiceable(triggerAction string) { func (d *distribution) updateServiceable(triggerAction string) {
if d.queryView.version != initialTargetVersion { loadedSealedSegments := int64(0)
serviceable := true totalSealedRowCount := int64(0)
for _, s := range d.queryView.sealedSegments { unloadedSealedSegments := make([]SegmentEntry, 0)
if entry, ok := d.sealedSegments[s]; !ok || entry.Offline { for id, rowCount := range d.queryView.sealedSegmentRowCount {
serviceable = false if entry, ok := d.sealedSegments[id]; ok && !entry.Offline {
break loadedSealedSegments += rowCount
} } else {
} unloadedSealedSegments = append(unloadedSealedSegments, SegmentEntry{SegmentID: id, NodeID: -1})
if serviceable != d.queryView.serviceable.Load() {
d.queryView.serviceable.Store(serviceable)
log.Info("channel distribution serviceable changed",
zap.String("channel", d.channelName),
zap.Bool("serviceable", serviceable),
zap.String("action", triggerAction))
} }
totalSealedRowCount += rowCount
} }
// unloaded segment entry list for partial result
d.queryView.unloadedSealedSegments = unloadedSealedSegments
loadedRatio := 0.0
if len(d.queryView.sealedSegmentRowCount) == 0 {
loadedRatio = 1.0
} else if loadedSealedSegments == 0 {
loadedRatio = 0.0
} else {
loadedRatio = float64(loadedSealedSegments) / float64(totalSealedRowCount)
}
serviceable := loadedRatio >= 1.0
if serviceable != d.queryView.Serviceable() {
log.Info("channel distribution serviceable changed",
zap.String("channel", d.channelName),
zap.Bool("serviceable", serviceable),
zap.Float64("loadedRatio", loadedRatio),
zap.Int64("loadedSealedRowCount", loadedSealedSegments),
zap.Int64("totalSealedRowCount", totalSealedRowCount),
zap.Int("unloadedSealedSegmentNum", len(unloadedSealedSegments)),
zap.Int("totalSealedSegmentNum", len(d.queryView.sealedSegmentRowCount)),
zap.String("action", triggerAction))
}
d.queryView.loadedRatio.Store(loadedRatio)
} }
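The loaded-ratio computation above is row-count weighted. A standalone sketch with plain maps, for illustration only:

package main

import "fmt"

// loadedRatio returns the fraction of expected sealed rows that are currently
// loaded and online; an empty expectation counts as fully serviceable.
func loadedRatio(expectedRows map[int64]int64, loadedOnline map[int64]bool) float64 {
	if len(expectedRows) == 0 {
		return 1.0
	}
	var loadedRows, totalRows int64
	for segID, rows := range expectedRows {
		if loadedOnline[segID] {
			loadedRows += rows
		}
		totalRows += rows
	}
	if loadedRows == 0 {
		return 0.0
	}
	return float64(loadedRows) / float64(totalRows)
}

func main() {
	expected := map[int64]int64{4: 100, 5: 100, 6: 100}
	fmt.Println(loadedRatio(expected, map[int64]bool{4: true}))                   // 0.333..., not serviceable
	fmt.Println(loadedRatio(expected, map[int64]bool{4: true, 5: true, 6: true})) // 1, serviceable
}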
// AddDistributions add multiple segment entries. // AddDistributions add multiple segment entries.
@ -257,8 +307,14 @@ func (d *distribution) AddDistributions(entries ...SegmentEntry) {
// remain the target version for already loaded segment to void skipping this segment when executing search // remain the target version for already loaded segment to void skipping this segment when executing search
entry.TargetVersion = oldEntry.TargetVersion entry.TargetVersion = oldEntry.TargetVersion
} else { } else {
// waiting for sync target version, to become readable _, ok := d.queryView.sealedSegmentRowCount[entry.SegmentID]
entry.TargetVersion = unreadableTargetVersion if ok || d.queryView.growingSegments.Contain(entry.SegmentID) {
// set segment version to query view version, to support partial result
entry.TargetVersion = d.queryView.GetVersion()
} else {
// set segment version to unreadableTargetVersion, if it's not in query view
entry.TargetVersion = unreadableTargetVersion
}
} }
d.sealedSegments[entry.SegmentID] = entry d.sealedSegments[entry.SegmentID] = entry
} }
@ -306,65 +362,73 @@ func (d *distribution) MarkOfflineSegments(segmentIDs ...int64) {
} }
} }
// UpdateTargetVersion update readable segment version // update readable channel view
func (d *distribution) SyncTargetVersion(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, redundantGrowings []int64) { // 1. update readable channel view to support partial result before distribution is serviceable
// 2. update readable channel view to support full result after new distribution is serviceable
// Notice: if we did not need compatibility with 2.5.x, we could simply install the new query view to serve queries;
// it would become serviceable automatically, making a second sync after the distribution is serviceable unnecessary
func (d *distribution) SyncTargetVersion(action *querypb.SyncAction, partitions []int64) {
d.mut.Lock() d.mut.Lock()
defer d.mut.Unlock() defer d.mut.Unlock()
for _, segmentID := range growingInTarget {
entry, ok := d.growingSegments[segmentID]
if !ok {
log.Warn("readable growing segment lost, consume from dml seems too slow",
zap.Int64("segmentID", segmentID))
continue
}
entry.TargetVersion = newVersion
d.growingSegments[segmentID] = entry
}
for _, segmentID := range redundantGrowings {
entry, ok := d.growingSegments[segmentID]
if !ok {
continue
}
entry.TargetVersion = redundantTargetVersion
d.growingSegments[segmentID] = entry
}
for _, segmentID := range sealedInTarget {
entry, ok := d.sealedSegments[segmentID]
if !ok {
continue
}
entry.TargetVersion = newVersion
d.sealedSegments[segmentID] = entry
}
oldValue := d.queryView.version oldValue := d.queryView.version
d.queryView = &channelQueryView{ d.queryView = &channelQueryView{
sealedSegments: sealedInTarget, growingSegments: typeutil.NewUniqueSet(action.GetGrowingInTarget()...),
partitions: typeutil.NewUniqueSet(partitions...), sealedSegmentRowCount: action.GetSealedSegmentRowCount(),
version: newVersion, partitions: typeutil.NewUniqueSet(partitions...),
serviceable: d.queryView.serviceable, version: action.GetTargetVersion(),
loadedRatio: atomic.NewFloat64(0),
} }
// update working partition list sealedSet := typeutil.NewUniqueSet(action.GetSealedInTarget()...)
d.genSnapshot() droppedSet := typeutil.NewUniqueSet(action.GetDroppedInTarget()...)
redundantGrowings := make([]int64, 0)
for _, s := range d.growingSegments {
// sealed segment already exists or dropped, make growing segment redundant
if sealedSet.Contain(s.SegmentID) || droppedSet.Contain(s.SegmentID) {
s.TargetVersion = redundantTargetVersion
d.growingSegments[s.SegmentID] = s
redundantGrowings = append(redundantGrowings, s.SegmentID)
}
}
d.queryView.growingSegments.Range(func(s UniqueID) bool {
entry, ok := d.growingSegments[s]
if !ok {
log.Warn("readable growing segment lost, consume from dml seems too slow",
zap.Int64("segmentID", s))
return true
}
entry.TargetVersion = action.GetTargetVersion()
d.growingSegments[s] = entry
return true
})
for id := range d.queryView.sealedSegmentRowCount {
entry, ok := d.sealedSegments[id]
if !ok {
continue
}
entry.TargetVersion = action.GetTargetVersion()
d.sealedSegments[id] = entry
}
d.genSnapshot()
if d.idfOracle != nil { if d.idfOracle != nil {
d.idfOracle.SetNext(d.current.Load()) d.idfOracle.SetNext(d.current.Load())
d.idfOracle.LazyRemoveGrowings(newVersion, redundantGrowings...) d.idfOracle.LazyRemoveGrowings(action.GetTargetVersion(), redundantGrowings...)
} }
// if the leader view holds fewer sealed segments than the target, mark the delegator as unserviceable
d.updateServiceable("SyncTargetVersion") d.updateServiceable("SyncTargetVersion")
log.Info("Update channel query view", log.Info("Update channel query view",
zap.String("channel", d.channelName), zap.String("channel", d.channelName),
zap.Int64s("partitions", partitions), zap.Int64s("partitions", partitions),
zap.Int64("oldVersion", oldValue), zap.Int64("oldVersion", oldValue),
zap.Int64("newVersion", newVersion), zap.Int64("newVersion", action.GetTargetVersion()),
zap.Int("growingSegmentNum", len(growingInTarget)), zap.Bool("serviceable", d.queryView.Serviceable()),
zap.Int("sealedSegmentNum", len(sealedInTarget)), zap.Float64("loadedRatio", d.queryView.GetLoadedRatio()),
zap.Int("growingSegmentNum", len(action.GetGrowingInTarget())),
zap.Int("sealedSegmentNum", len(action.GetSealedInTarget())),
) )
} }
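The redundant-growing detection above relies on set membership: any growing segment already covered by the sealed target, or dropped from it, is retired. An equivalent sketch with plain maps instead of typeutil sets:

package main

import "fmt"

// redundantGrowings returns growing segment IDs that are already sealed in the
// target or dropped from it, mirroring the set checks in SyncTargetVersion.
func redundantGrowings(growing, sealedInTarget, droppedInTarget []int64) []int64 {
	covered := make(map[int64]struct{}, len(sealedInTarget)+len(droppedInTarget))
	for _, id := range sealedInTarget {
		covered[id] = struct{}{}
	}
	for _, id := range droppedInTarget {
		covered[id] = struct{}{}
	}
	redundant := make([]int64, 0)
	for _, id := range growing {
		if _, ok := covered[id]; ok {
			redundant = append(redundant, id)
		}
	}
	return redundant
}

func main() {
	fmt.Println(redundantGrowings([]int64{1, 2, 3}, []int64{2}, []int64{3})) // [2 3]
}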

@ -21,7 +21,10 @@ import (
"time" "time"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
) )
type DistributionSuite struct { type DistributionSuite struct {
@ -30,7 +33,7 @@ type DistributionSuite struct {
} }
func (s *DistributionSuite) SetupTest() { func (s *DistributionSuite) SetupTest() {
s.dist = NewDistribution("channel-1") s.dist = NewDistribution("channel-1", NewChannelQueryView(nil, nil, nil, initialTargetVersion))
} }
func (s *DistributionSuite) TearDownTest() { func (s *DistributionSuite) TearDownTest() {
@ -44,6 +47,7 @@ func (s *DistributionSuite) TestAddDistribution() {
growing []SegmentEntry growing []SegmentEntry
expected []SnapshotItem expected []SnapshotItem
expectedSignalClosed bool expectedSignalClosed bool
expectedLoadRatio float64
} }
cases := []testCase{ cases := []testCase{
@ -177,8 +181,10 @@ func (s *DistributionSuite) TestAddDistribution() {
s.SetupTest() s.SetupTest()
defer s.TearDownTest() defer s.TearDownTest()
s.dist.AddGrowing(tc.growing...) s.dist.AddGrowing(tc.growing...)
s.dist.SyncTargetVersion(1000, nil, nil, nil, nil) s.dist.SyncTargetVersion(&querypb.SyncAction{
_, _, version, err := s.dist.PinReadableSegments() TargetVersion: 1000,
}, nil)
_, _, version, err := s.dist.PinReadableSegments(1.0)
s.Require().NoError(err) s.Require().NoError(err)
s.dist.AddDistributions(tc.input...) s.dist.AddDistributions(tc.input...)
sealed, _ := s.dist.PeekSegments(false) sealed, _ := s.dist.PeekSegments(false)
@ -225,11 +231,6 @@ func (s *DistributionSuite) TestAddGrowing() {
} }
cases := []testCase{ cases := []testCase{
{
tag: "nil input",
input: nil,
expected: []SegmentEntry{},
},
{ {
tag: "normal_case", tag: "normal_case",
input: []SegmentEntry{ input: []SegmentEntry{
@ -261,8 +262,11 @@ func (s *DistributionSuite) TestAddGrowing() {
defer s.TearDownTest() defer s.TearDownTest()
s.dist.AddGrowing(tc.input...) s.dist.AddGrowing(tc.input...)
s.dist.SyncTargetVersion(1000, tc.workingParts, []int64{1, 2}, nil, nil) s.dist.SyncTargetVersion(&querypb.SyncAction{
_, growing, version, err := s.dist.PinReadableSegments() TargetVersion: 1000,
GrowingInTarget: []int64{1, 2},
}, tc.workingParts)
_, growing, version, err := s.dist.PinReadableSegments(1.0)
s.Require().NoError(err) s.Require().NoError(err)
defer s.dist.Unpin(version) defer s.dist.Unpin(version)
@ -452,15 +456,19 @@ func (s *DistributionSuite) TestRemoveDistribution() {
growingIDs := lo.Map(tc.presetGrowing, func(item SegmentEntry, idx int) int64 { growingIDs := lo.Map(tc.presetGrowing, func(item SegmentEntry, idx int) int64 {
return item.SegmentID return item.SegmentID
}) })
sealedIDs := lo.Map(tc.presetSealed, func(item SegmentEntry, idx int) int64 { sealedSegmentRowCount := lo.SliceToMap(tc.presetSealed, func(item SegmentEntry) (int64, int64) {
return item.SegmentID return item.SegmentID, 100
}) })
s.dist.SyncTargetVersion(time.Now().Unix(), nil, growingIDs, sealedIDs, nil) s.dist.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1000,
GrowingInTarget: growingIDs,
SealedSegmentRowCount: sealedSegmentRowCount,
}, nil)
var version int64 var version int64
if tc.withMockRead { if tc.withMockRead {
var err error var err error
_, _, version, err = s.dist.PinReadableSegments() _, _, version, err = s.dist.PinReadableSegments(1.0)
s.Require().NoError(err) s.Require().NoError(err)
} }
@ -675,10 +683,14 @@ func (s *DistributionSuite) TestMarkOfflineSegments() {
defer s.TearDownTest() defer s.TearDownTest()
s.dist.AddDistributions(tc.input...) s.dist.AddDistributions(tc.input...)
sealedSegmentID := lo.Map(tc.input, func(t SegmentEntry, _ int) int64 { sealedSegmentRowCount := lo.SliceToMap(tc.input, func(t SegmentEntry) (int64, int64) {
return t.SegmentID return t.SegmentID, 100
}) })
s.dist.SyncTargetVersion(1000, nil, nil, sealedSegmentID, nil) s.dist.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 1000,
SealedSegmentRowCount: sealedSegmentRowCount,
DroppedInTarget: nil,
}, nil)
s.dist.MarkOfflineSegments(tc.offlines...) s.dist.MarkOfflineSegments(tc.offlines...)
s.Equal(tc.serviceable, s.dist.Serviceable()) s.Equal(tc.serviceable, s.dist.Serviceable())
@ -740,29 +752,187 @@ func (s *DistributionSuite) Test_SyncTargetVersion() {
s.dist.AddGrowing(growing...) s.dist.AddGrowing(growing...)
s.dist.AddDistributions(sealed...) s.dist.AddDistributions(sealed...)
s.dist.SyncTargetVersion(2, []int64{1}, []int64{2, 3}, []int64{6}, []int64{}) s.dist.SyncTargetVersion(&querypb.SyncAction{
TargetVersion: 2,
GrowingInTarget: []int64{1},
SealedSegmentRowCount: map[int64]int64{4: 100, 5: 200},
DroppedInTarget: []int64{6},
}, []int64{1})
s1, s2, _, err := s.dist.PinReadableSegments() s1, s2, _, err := s.dist.PinReadableSegments(1.0)
s.Require().NoError(err) s.Require().NoError(err)
s.Len(s1[0].Segments, 1) s.Len(s1[0].Segments, 2)
s.Len(s2, 2) s.Len(s2, 1)
s1, s2, _ = s.dist.PinOnlineSegments() s1, s2, _ = s.dist.PinOnlineSegments()
s.Len(s1[0].Segments, 3) s.Len(s1[0].Segments, 3)
s.Len(s2, 3) s.Len(s2, 3)
s.dist.queryView.serviceable.Store(true) s.dist.SyncTargetVersion(&querypb.SyncAction{
s.dist.SyncTargetVersion(2, []int64{1}, []int64{222}, []int64{}, []int64{}) TargetVersion: 2,
s.True(s.dist.Serviceable()) GrowingInTarget: []int64{1},
SealedSegmentRowCount: map[int64]int64{333: 100},
s.dist.SyncTargetVersion(2, []int64{1}, []int64{}, []int64{333}, []int64{}) DroppedInTarget: []int64{},
}, []int64{1})
s.False(s.dist.Serviceable()) s.False(s.dist.Serviceable())
_, _, _, err = s.dist.PinReadableSegments(1.0)
s.dist.SyncTargetVersion(2, []int64{1}, []int64{}, []int64{333}, []int64{1, 2, 3})
_, _, _, err = s.dist.PinReadableSegments()
s.Error(err) s.Error(err)
} }
func TestDistributionSuite(t *testing.T) { func TestDistributionSuite(t *testing.T) {
suite.Run(t, new(DistributionSuite)) suite.Run(t, new(DistributionSuite))
} }
func TestNewChannelQueryView(t *testing.T) {
growings := []int64{1, 2, 3}
sealedWithRowCount := map[int64]int64{4: 100, 5: 200, 6: 300}
partitions := []int64{7, 8, 9}
version := int64(10)
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
assert.NotNil(t, view)
assert.ElementsMatch(t, growings, view.growingSegments.Collect())
assert.ElementsMatch(t, lo.Keys(sealedWithRowCount), lo.Keys(view.sealedSegmentRowCount))
assert.True(t, view.partitions.Contain(7))
assert.True(t, view.partitions.Contain(8))
assert.True(t, view.partitions.Contain(9))
assert.Equal(t, version, view.version)
assert.Equal(t, float64(0), view.loadedRatio.Load())
assert.False(t, view.Serviceable())
}
func TestDistribution_NewDistribution(t *testing.T) {
channelName := "test_channel"
growings := []int64{1, 2, 3}
sealedWithRowCount := map[int64]int64{4: 100, 5: 200, 6: 300}
partitions := []int64{7, 8, 9}
version := int64(10)
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
dist := NewDistribution(channelName, view)
assert.NotNil(t, dist)
assert.Equal(t, channelName, dist.channelName)
assert.Equal(t, view, dist.queryView)
assert.NotNil(t, dist.growingSegments)
assert.NotNil(t, dist.sealedSegments)
assert.NotNil(t, dist.snapshots)
assert.NotNil(t, dist.current)
}
func TestDistribution_UpdateServiceable(t *testing.T) {
channelName := "test_channel"
growings := []int64{1, 2, 3}
sealedWithRowCount := map[int64]int64{4: 100, 5: 100, 6: 100}
partitions := []int64{7, 8, 9}
version := int64(10)
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
dist := NewDistribution(channelName, view)
// Test with no segments loaded
dist.updateServiceable("test")
assert.False(t, dist.Serviceable())
assert.Equal(t, float64(0), dist.queryView.GetLoadedRatio())
// Test with some segments loaded
dist.sealedSegments[4] = SegmentEntry{
SegmentID: 4,
Offline: false,
}
dist.growingSegments[1] = SegmentEntry{
SegmentID: 1,
}
dist.updateServiceable("test")
assert.False(t, dist.Serviceable())
assert.Equal(t, float64(2)/float64(6), dist.queryView.GetLoadedRatio())
// Test with all segments loaded
for id := range sealedWithRowCount {
dist.sealedSegments[id] = SegmentEntry{
SegmentID: id,
Offline: false,
}
}
for _, id := range growings {
dist.growingSegments[id] = SegmentEntry{
SegmentID: id,
}
}
dist.updateServiceable("test")
assert.True(t, dist.Serviceable())
assert.Equal(t, float64(1), dist.queryView.GetLoadedRatio())
}
func TestDistribution_SyncTargetVersion(t *testing.T) {
channelName := "test_channel"
growings := []int64{1, 2, 3}
sealedWithRowCount := map[int64]int64{4: 100, 5: 100, 6: 100}
partitions := []int64{7, 8, 9}
version := int64(10)
view := NewChannelQueryView(growings, sealedWithRowCount, partitions, version)
dist := NewDistribution(channelName, view)
// Add some initial segments
dist.growingSegments[1] = SegmentEntry{
SegmentID: 1,
}
dist.sealedSegments[4] = SegmentEntry{
SegmentID: 4,
}
// Create a new sync action
action := &querypb.SyncAction{
GrowingInTarget: []int64{1, 2},
SealedSegmentRowCount: map[int64]int64{4: 100, 5: 100},
DroppedInTarget: []int64{3},
TargetVersion: version + 1,
}
// Sync the view
dist.SyncTargetVersion(action, partitions)
// Verify the changes
assert.Equal(t, action.GetTargetVersion(), dist.queryView.version)
assert.ElementsMatch(t, action.GetGrowingInTarget(), dist.queryView.growingSegments.Collect())
assert.ElementsMatch(t, lo.Keys(action.GetSealedSegmentRowCount()), lo.Keys(dist.queryView.sealedSegmentRowCount))
assert.True(t, dist.queryView.partitions.Contain(7))
assert.True(t, dist.queryView.partitions.Contain(8))
assert.True(t, dist.queryView.partitions.Contain(9))
// Verify growing segment target version
assert.Equal(t, action.GetTargetVersion(), dist.growingSegments[1].TargetVersion)
// Verify sealed segment target version
assert.Equal(t, action.GetTargetVersion(), dist.sealedSegments[4].TargetVersion)
}
func TestDistribution_MarkOfflineSegments(t *testing.T) {
channelName := "test_channel"
view := NewChannelQueryView([]int64{}, map[int64]int64{1: 100, 2: 200}, []int64{}, 0)
dist := NewDistribution(channelName, view)
// Add some segments
dist.sealedSegments[1] = SegmentEntry{
SegmentID: 1,
NodeID: 100,
Version: 1,
}
dist.sealedSegments[2] = SegmentEntry{
SegmentID: 2,
NodeID: 100,
Version: 1,
}
// Mark segments offline
dist.MarkOfflineSegments(1, 2)
// Verify the changes
assert.True(t, dist.sealedSegments[1].Offline)
assert.True(t, dist.sealedSegments[2].Offline)
assert.Equal(t, int64(-1), dist.sealedSegments[1].NodeID)
assert.Equal(t, int64(-1), dist.sealedSegments[2].NodeID)
assert.Equal(t, unreadableTargetVersion, dist.sealedSegments[1].Version)
assert.Equal(t, unreadableTargetVersion, dist.sealedSegments[2].Version)
}

@ -8,8 +8,6 @@ import (
internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
mock "github.com/stretchr/testify/mock" mock "github.com/stretchr/testify/mock"
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
querypb "github.com/milvus-io/milvus/pkg/v2/proto/querypb" querypb "github.com/milvus-io/milvus/pkg/v2/proto/querypb"
schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" schemapb "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -243,12 +241,12 @@ func (_c *MockShardDelegator_GetPartitionStatsVersions_Call) RunAndReturn(run fu
return _c return _c
} }
// GetQueryView provides a mock function with no fields // GetChannelQueryView provides a mock function with no fields
func (_m *MockShardDelegator) GetQueryView() *channelQueryView { func (_m *MockShardDelegator) GetChannelQueryView() *channelQueryView {
ret := _m.Called() ret := _m.Called()
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for GetQueryView") panic("no return value specified for GetChannelQueryView")
} }
var r0 *channelQueryView var r0 *channelQueryView
@ -263,29 +261,29 @@ func (_m *MockShardDelegator) GetQueryView() *channelQueryView {
return r0 return r0
} }
// MockShardDelegator_GetQueryView_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetQueryView' // MockShardDelegator_GetChannelQueryView_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChannelQueryView'
type MockShardDelegator_GetQueryView_Call struct { type MockShardDelegator_GetChannelQueryView_Call struct {
*mock.Call *mock.Call
} }
// GetQueryView is a helper method to define mock.On call // GetChannelQueryView is a helper method to define mock.On call
func (_e *MockShardDelegator_Expecter) GetQueryView() *MockShardDelegator_GetQueryView_Call { func (_e *MockShardDelegator_Expecter) GetChannelQueryView() *MockShardDelegator_GetChannelQueryView_Call {
return &MockShardDelegator_GetQueryView_Call{Call: _e.mock.On("GetQueryView")} return &MockShardDelegator_GetChannelQueryView_Call{Call: _e.mock.On("GetChannelQueryView")}
} }
func (_c *MockShardDelegator_GetQueryView_Call) Run(run func()) *MockShardDelegator_GetQueryView_Call { func (_c *MockShardDelegator_GetChannelQueryView_Call) Run(run func()) *MockShardDelegator_GetChannelQueryView_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run() run()
}) })
return _c return _c
} }
func (_c *MockShardDelegator_GetQueryView_Call) Return(_a0 *channelQueryView) *MockShardDelegator_GetQueryView_Call { func (_c *MockShardDelegator_GetChannelQueryView_Call) Return(_a0 *channelQueryView) *MockShardDelegator_GetChannelQueryView_Call {
_c.Call.Return(_a0) _c.Call.Return(_a0)
return _c return _c
} }
func (_c *MockShardDelegator_GetQueryView_Call) RunAndReturn(run func() *channelQueryView) *MockShardDelegator_GetQueryView_Call { func (_c *MockShardDelegator_GetChannelQueryView_Call) RunAndReturn(run func() *channelQueryView) *MockShardDelegator_GetChannelQueryView_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
@ -1037,9 +1035,9 @@ func (_c *MockShardDelegator_SyncPartitionStats_Call) RunAndReturn(run func(cont
return _c return _c
} }
// SyncTargetVersion provides a mock function with given fields: newVersion, partitions, growingInTarget, sealedInTarget, droppedInTarget, checkpoint, deleteSeekPos // SyncTargetVersion provides a mock function with given fields: action, partitions
func (_m *MockShardDelegator) SyncTargetVersion(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, droppedInTarget []int64, checkpoint *msgpb.MsgPosition, deleteSeekPos *msgpb.MsgPosition) { func (_m *MockShardDelegator) SyncTargetVersion(action *querypb.SyncAction, partitions []int64) {
_m.Called(newVersion, partitions, growingInTarget, sealedInTarget, droppedInTarget, checkpoint, deleteSeekPos) _m.Called(action, partitions)
} }
// MockShardDelegator_SyncTargetVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncTargetVersion' // MockShardDelegator_SyncTargetVersion_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncTargetVersion'
@ -1048,20 +1046,15 @@ type MockShardDelegator_SyncTargetVersion_Call struct {
} }
// SyncTargetVersion is a helper method to define mock.On call // SyncTargetVersion is a helper method to define mock.On call
// - newVersion int64 // - action *querypb.SyncAction
// - partitions []int64 // - partitions []int64
// - growingInTarget []int64 func (_e *MockShardDelegator_Expecter) SyncTargetVersion(action interface{}, partitions interface{}) *MockShardDelegator_SyncTargetVersion_Call {
// - sealedInTarget []int64 return &MockShardDelegator_SyncTargetVersion_Call{Call: _e.mock.On("SyncTargetVersion", action, partitions)}
// - droppedInTarget []int64
// - checkpoint *msgpb.MsgPosition
// - deleteSeekPos *msgpb.MsgPosition
func (_e *MockShardDelegator_Expecter) SyncTargetVersion(newVersion interface{}, partitions interface{}, growingInTarget interface{}, sealedInTarget interface{}, droppedInTarget interface{}, checkpoint interface{}, deleteSeekPos interface{}) *MockShardDelegator_SyncTargetVersion_Call {
return &MockShardDelegator_SyncTargetVersion_Call{Call: _e.mock.On("SyncTargetVersion", newVersion, partitions, growingInTarget, sealedInTarget, droppedInTarget, checkpoint, deleteSeekPos)}
} }
func (_c *MockShardDelegator_SyncTargetVersion_Call) Run(run func(newVersion int64, partitions []int64, growingInTarget []int64, sealedInTarget []int64, droppedInTarget []int64, checkpoint *msgpb.MsgPosition, deleteSeekPos *msgpb.MsgPosition)) *MockShardDelegator_SyncTargetVersion_Call { func (_c *MockShardDelegator_SyncTargetVersion_Call) Run(run func(action *querypb.SyncAction, partitions []int64)) *MockShardDelegator_SyncTargetVersion_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64), args[1].([]int64), args[2].([]int64), args[3].([]int64), args[4].([]int64), args[5].(*msgpb.MsgPosition), args[6].(*msgpb.MsgPosition)) run(args[0].(*querypb.SyncAction), args[1].([]int64))
}) })
return _c return _c
} }
@ -1071,7 +1064,7 @@ func (_c *MockShardDelegator_SyncTargetVersion_Call) Return() *MockShardDelegato
return _c return _c
} }
func (_c *MockShardDelegator_SyncTargetVersion_Call) RunAndReturn(run func(int64, []int64, []int64, []int64, []int64, *msgpb.MsgPosition, *msgpb.MsgPosition)) *MockShardDelegator_SyncTargetVersion_Call { func (_c *MockShardDelegator_SyncTargetVersion_Call) RunAndReturn(run func(*querypb.SyncAction, []int64)) *MockShardDelegator_SyncTargetVersion_Call {
_c.Run(run) _c.Run(run)
return _c return _c
} }

@ -231,6 +231,7 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
if err != nil { if err != nil {
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader, fmt.Sprint(req.GetReq().GetCollectionID())).Inc() metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, metrics.FailLabel, metrics.Leader, fmt.Sprint(req.GetReq().GetCollectionID())).Inc()
} }
metrics.QueryNodePartialResultCount.WithLabelValues(fmt.Sprint(node.GetNodeID()), metrics.QueryLabel, fmt.Sprint(req.GetReq().GetCollectionID())).Inc()
}() }()
log.Debug("start do query with channel", log.Debug("start do query with channel",

@ -251,6 +251,13 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
} }
}() }()
queryView := delegator.NewChannelQueryView(
channel.GetUnflushedSegmentIds(),
req.GetSealedSegmentRowCount(),
req.GetPartitionIDs(),
req.GetTargetVersion(),
)
delegator, err := delegator.NewShardDelegator( delegator, err := delegator.NewShardDelegator(
ctx, ctx,
req.GetCollectionID(), req.GetCollectionID(),
@ -264,6 +271,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm
channel.GetSeekPosition().GetTimestamp(), channel.GetSeekPosition().GetTimestamp(),
node.queryHook, node.queryHook,
node.chunkManager, node.chunkManager,
queryView,
) )
if err != nil { if err != nil {
log.Warn("failed to create shard delegator", zap.Error(err)) log.Warn("failed to create shard delegator", zap.Error(err))
@ -1252,7 +1260,7 @@ func (node *QueryNode) GetDataDistribution(ctx context.Context, req *querypb.Get
numOfGrowingRows += segment.InsertCount() numOfGrowingRows += segment.InsertCount()
} }
queryView := delegator.GetQueryView() queryView := delegator.GetChannelQueryView()
leaderViews = append(leaderViews, &querypb.LeaderView{ leaderViews = append(leaderViews, &querypb.LeaderView{
Collection: delegator.Collection(), Collection: delegator.Collection(),
Channel: key, Channel: key,
@ -1355,16 +1363,7 @@ func (node *QueryNode) SyncDistribution(ctx context.Context, req *querypb.SyncDi
return id, action.GetCheckpoint().Timestamp return id, action.GetCheckpoint().Timestamp
}) })
shardDelegator.AddExcludedSegments(flushedInfo) shardDelegator.AddExcludedSegments(flushedInfo)
deleteCP := action.GetDeleteCP() shardDelegator.SyncTargetVersion(action, req.GetLoadMeta().GetPartitionIDs())
if deleteCP == nil {
// for compatible with 2.4, we use checkpoint as deleteCP when deleteCP is nil
deleteCP = action.GetCheckpoint()
log.Info("use checkpoint as deleteCP",
zap.String("channelName", req.GetChannel()),
zap.Time("deleteSeekPos", tsoutil.PhysicalTime(action.GetCheckpoint().GetTimestamp())))
}
shardDelegator.SyncTargetVersion(action.GetTargetVersion(), req.GetLoadMeta().GetPartitionIDs(), action.GetGrowingInTarget(),
action.GetSealedInTarget(), action.GetDroppedInTarget(), action.GetCheckpoint(), deleteCP)
case querypb.SyncType_UpdatePartitionStats: case querypb.SyncType_UpdatePartitionStats:
log.Info("sync update partition stats versions") log.Info("sync update partition stats versions")
shardDelegator.SyncPartitionStats(ctx, action.PartitionStatsVersions) shardDelegator.SyncPartitionStats(ctx, action.PartitionStatsVersions)

@ -1393,9 +1393,13 @@ func (suite *ServiceSuite) TestSearch_Failed() {
} }
syncVersionAction := &querypb.SyncAction{ syncVersionAction := &querypb.SyncAction{
Type: querypb.SyncType_UpdateVersion, Type: querypb.SyncType_UpdateVersion,
SealedInTarget: []int64{1, 2, 3}, SealedSegmentRowCount: map[int64]int64{
TargetVersion: time.Now().UnixMilli(), 1: 100,
2: 200,
3: 300,
},
TargetVersion: time.Now().UnixMilli(),
} }
syncReq.Actions = []*querypb.SyncAction{syncVersionAction} syncReq.Actions = []*querypb.SyncAction{syncVersionAction}

@ -814,6 +814,18 @@ var (
cgoNameLabelName, cgoNameLabelName,
cgoTypeLabelName, cgoTypeLabelName,
}) })
QueryNodePartialResultCount = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: milvusNamespace,
Subsystem: typeutil.QueryNodeRole,
Name: "partial_result_count",
Help: "count of partial result",
}, []string{
nodeIDLabelName,
queryTypeLabelName,
collectionIDLabelName,
})
) )
// RegisterQueryNode registers QueryNode metrics // RegisterQueryNode registers QueryNode metrics
@ -885,6 +897,7 @@ func RegisterQueryNode(registry *prometheus.Registry) {
registry.MustRegister(QueryNodeDeleteBufferSize) registry.MustRegister(QueryNodeDeleteBufferSize)
registry.MustRegister(QueryNodeDeleteBufferRowNum) registry.MustRegister(QueryNodeDeleteBufferRowNum)
registry.MustRegister(QueryNodeCGOCallLatency) registry.MustRegister(QueryNodeCGOCallLatency)
registry.MustRegister(QueryNodePartialResultCount)
// Add cgo metrics // Add cgo metrics
RegisterCGOMetrics(registry) RegisterCGOMetrics(registry)
@ -933,6 +946,13 @@ func CleanupQueryNodeCollectionMetrics(nodeID int64, collectionID int64) {
collectionIDLabelName: collectionIDLabel, collectionIDLabelName: collectionIDLabel,
}) })
QueryNodePartialResultCount.
DeletePartialMatch(
prometheus.Labels{
nodeIDLabelName: nodeIDLabel,
collectionIDLabelName: collectionIDLabel,
})
QueryNodeSearchHitSegmentNum. QueryNodeSearchHitSegmentNum.
DeletePartialMatch( DeletePartialMatch(
prometheus.Labels{ prometheus.Labels{
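The new counter follows the usual client_golang pattern. A standalone sketch with illustrative label values (not the registered Milvus metric itself):

package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
)

// partialResultCount mimics the metric added above with the same label set.
var partialResultCount = prometheus.NewCounterVec(
	prometheus.CounterOpts{
		Namespace: "milvus",
		Subsystem: "querynode",
		Name:      "partial_result_count",
		Help:      "count of partial result",
	},
	[]string{"node_id", "query_type", "collection_id"},
)

func main() {
	registry := prometheus.NewRegistry()
	registry.MustRegister(partialResultCount)
	// one partial search result served by node 1 for collection 100
	partialResultCount.WithLabelValues("1", "search", "100").Inc()
	// label-scoped cleanup, as done when collection metrics are released
	deleted := partialResultCount.DeletePartialMatch(prometheus.Labels{"node_id": "1", "collection_id": "100"})
	fmt.Println("deleted series:", deleted)
}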

@ -281,6 +281,7 @@ message GetSegmentInfoResponse {
message GetShardLeadersRequest { message GetShardLeadersRequest {
common.MsgBase base = 1; common.MsgBase base = 1;
int64 collectionID = 2; int64 collectionID = 2;
bool with_unserviceable_shards = 3;
} }
message GetShardLeadersResponse { message GetShardLeadersResponse {
@ -297,6 +298,7 @@ message ShardLeadersList { // All leaders of all replicas of one shard
string channel_name = 1; string channel_name = 1;
repeated int64 node_ids = 2; repeated int64 node_ids = 2;
repeated string node_addrs = 3; repeated string node_addrs = 3;
repeated bool serviceable = 4;
} }
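With the new serviceable flags, a client can filter shard leaders per channel; indices in the repeated serviceable field line up with node_ids and node_addrs. A hypothetical helper, assuming the standard protoc-gen-go accessors for querypb:

// pickServiceableLeaders returns the node IDs whose delegator reported itself serviceable.
func pickServiceableLeaders(list *querypb.ShardLeadersList) []int64 {
	nodes := make([]int64, 0, len(list.GetNodeIds()))
	for i, id := range list.GetNodeIds() {
		if i < len(list.GetServiceable()) && list.GetServiceable()[i] {
			nodes = append(nodes, id)
		}
	}
	return nodes
}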
message SyncNewCreatedPartitionRequest { message SyncNewCreatedPartitionRequest {
@ -334,6 +336,8 @@ message WatchDmChannelsRequest {
int64 offlineNodeID = 11; int64 offlineNodeID = 11;
int64 version = 12; int64 version = 12;
repeated index.IndexInfo index_info_list = 13; repeated index.IndexInfo index_info_list = 13;
int64 target_version = 14;
map<int64, int64> sealed_segment_row_count = 15; // segmentID -> row count, same as unflushedSegmentIds in vchannelInfo
} }
message UnsubDmChannelRequest { message UnsubDmChannelRequest {
@ -625,7 +629,7 @@ message LeaderView {
} }
message LeaderViewStatus { message LeaderViewStatus {
bool serviceable = 10; bool serviceable = 1;
} }
message SegmentDist { message SegmentDist {
@ -725,6 +729,7 @@ message SyncAction {
msg.MsgPosition checkpoint = 11; msg.MsgPosition checkpoint = 11;
map<int64, int64> partition_stats_versions = 12; map<int64, int64> partition_stats_versions = 12;
msg.MsgPosition deleteCP = 13; msg.MsgPosition deleteCP = 13;
map<int64, int64> sealed_segment_row_count = 14; // segmentID -> row count, same as sealedInTarget
} }
message SyncDistributionRequest { message SyncDistributionRequest {

File diff suppressed because it is too large.

@ -2856,6 +2856,8 @@ type queryNodeConfig struct {
IDFEnableDisk ParamItem `refreshable:"true"` IDFEnableDisk ParamItem `refreshable:"true"`
IDFLocalPath ParamItem `refreshable:"true"` IDFLocalPath ParamItem `refreshable:"true"`
IDFWriteConcurrenct ParamItem `refreshable:"true"` IDFWriteConcurrenct ParamItem `refreshable:"true"`
// partial search
PartialResultRequiredDataRatio ParamItem `refreshable:"true"`
} }
func (p *queryNodeConfig) init(base *BaseTable) { func (p *queryNodeConfig) init(base *BaseTable) {
@ -3786,6 +3788,15 @@ user-task-polling:
Export: true, Export: true,
} }
p.WorkerPoolingSize.Init(base.mgr) p.WorkerPoolingSize.Init(base.mgr)
p.PartialResultRequiredDataRatio = ParamItem{
Key: "proxy.partialResultRequiredDataRatio",
Version: "2.6.0",
DefaultValue: "1",
Doc:          `partial result required data ratio; defaults to 1, which disables partial results; any smaller value is used as the minimum accessible data ratio for returning a partial result`,
Export: true,
}
p.PartialResultRequiredDataRatio.Init(base.mgr)
} }
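Callers read the knob through paramtable. A small sketch of defensive use (the clamp helper is hypothetical; the accessor path is the one added above, and the paramtable import is assumed):

// requiredDataRatio reads the configured minimum data ratio and falls back to 1
// (partial results disabled) for out-of-range values.
func requiredDataRatio() float64 {
	v := paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.GetAsFloat()
	if v <= 0 || v > 1 {
		return 1.0
	}
	return v
}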
// ///////////////////////////////////////////////////////////////////////////// // /////////////////////////////////////////////////////////////////////////////

@ -489,6 +489,10 @@ func TestComponentParam(t *testing.T) {
assert.Equal(t, "/var/lib/milvus/data/mmap", Params.MmapDirPath.GetValue()) assert.Equal(t, "/var/lib/milvus/data/mmap", Params.MmapDirPath.GetValue())
assert.Equal(t, 60*time.Second, Params.DiskSizeFetchInterval.GetAsDuration(time.Second)) assert.Equal(t, 60*time.Second, Params.DiskSizeFetchInterval.GetAsDuration(time.Second))
assert.Equal(t, 1.0, Params.PartialResultRequiredDataRatio.GetAsFloat())
params.Save(Params.PartialResultRequiredDataRatio.Key, "0.8")
assert.Equal(t, 0.8, Params.PartialResultRequiredDataRatio.GetAsFloat())
}) })
t.Run("test dataCoordConfig", func(t *testing.T) { t.Run("test dataCoordConfig", func(t *testing.T) {

@ -13,6 +13,8 @@ package retry
import ( import (
"context" "context"
"runtime"
"strconv"
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
@ -23,6 +25,14 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/merr"
) )
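// getCaller returns the file:line of the stack frame `skip` levels above this
// call; Do and Handle pass 2, so the logged caller is the retry call site
// rather than a frame inside the retry package.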
func getCaller(skip int) string {
_, file, line, ok := runtime.Caller(skip)
if !ok {
return "unknown"
}
return file + ":" + strconv.Itoa(line)
}
// Do will run function with retry mechanism. // Do will run function with retry mechanism.
// fn is the func to run. // fn is the func to run.
// Option can control the retry times and timeout. // Option can control the retry times and timeout.
@ -43,7 +53,10 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
for i := uint(0); c.attempts == 0 || i < c.attempts; i++ { for i := uint(0); c.attempts == 0 || i < c.attempts; i++ {
if err := fn(); err != nil { if err := fn(); err != nil {
if i%4 == 0 { if i%4 == 0 {
log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err)) log.Warn("retry func failed",
zap.Uint("retried", i),
zap.Error(err),
zap.String("caller", getCaller(2)))
} }
if !IsRecoverable(err) { if !IsRecoverable(err) {
@ -52,6 +65,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
zap.Uint("retried", i), zap.Uint("retried", i),
zap.Uint("attempt", c.attempts), zap.Uint("attempt", c.attempts),
zap.Bool("isContextErr", isContextErr), zap.Bool("isContextErr", isContextErr),
zap.String("caller", getCaller(2)),
) )
if isContextErr && lastErr != nil { if isContextErr && lastErr != nil {
return lastErr return lastErr
@ -62,6 +76,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
log.Warn("retry func failed, not be retryable", log.Warn("retry func failed, not be retryable",
zap.Uint("retried", i), zap.Uint("retried", i),
zap.Uint("attempt", c.attempts), zap.Uint("attempt", c.attempts),
zap.String("caller", getCaller(2)),
) )
return err return err
} }
@ -73,6 +88,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
zap.Uint("retried", i), zap.Uint("retried", i),
zap.Uint("attempt", c.attempts), zap.Uint("attempt", c.attempts),
zap.Bool("isContextErr", isContextErr), zap.Bool("isContextErr", isContextErr),
zap.String("caller", getCaller(2)),
) )
if isContextErr && lastErr != nil { if isContextErr && lastErr != nil {
return lastErr return lastErr
@ -88,6 +104,7 @@ func Do(ctx context.Context, fn func() error, opts ...Option) error {
log.Warn("retry func failed, ctx done", log.Warn("retry func failed, ctx done",
zap.Uint("retried", i), zap.Uint("retried", i),
zap.Uint("attempt", c.attempts), zap.Uint("attempt", c.attempts),
zap.String("caller", getCaller(2)),
) )
return lastErr return lastErr
} }
@ -127,11 +144,22 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
for i := uint(0); i < c.attempts; i++ { for i := uint(0); i < c.attempts; i++ {
if shouldRetry, err := fn(); err != nil { if shouldRetry, err := fn(); err != nil {
if i%4 == 0 { if i%4 == 0 {
log.Warn("retry func failed", zap.Uint("retried", i), zap.Error(err)) log.Warn("retry func failed",
zap.Uint("retried", i),
zap.String("caller", getCaller(2)),
zap.Error(err),
)
} }
if !shouldRetry { if !shouldRetry {
if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil { isContextErr := errors.IsAny(err, context.Canceled, context.DeadlineExceeded)
log.Warn("retry func failed, not be recoverable",
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.Bool("isContextErr", isContextErr),
zap.String("caller", getCaller(2)),
)
if isContextErr && lastErr != nil {
return lastErr return lastErr
} }
return err return err
@ -139,8 +167,14 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
deadline, ok := ctx.Deadline() deadline, ok := ctx.Deadline()
if ok && time.Until(deadline) < c.sleep { if ok && time.Until(deadline) < c.sleep {
// to avoid sleep until ctx done isContextErr := errors.IsAny(err, context.Canceled, context.DeadlineExceeded)
if errors.IsAny(err, context.Canceled, context.DeadlineExceeded) && lastErr != nil { log.Warn("retry func failed, deadline",
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.Bool("isContextErr", isContextErr),
zap.String("caller", getCaller(2)),
)
if isContextErr && lastErr != nil {
return lastErr return lastErr
} }
return err return err
@ -151,6 +185,11 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
select { select {
case <-time.After(c.sleep): case <-time.After(c.sleep):
case <-ctx.Done(): case <-ctx.Done():
log.Warn("retry func failed, ctx done",
zap.Uint("retried", i),
zap.Uint("attempt", c.attempts),
zap.String("caller", getCaller(2)),
)
return lastErr return lastErr
} }
@ -162,6 +201,12 @@ func Handle(ctx context.Context, fn func() (bool, error), opts ...Option) error
return nil return nil
} }
} }
if lastErr != nil {
log.Warn("retry func failed, reach max retry",
zap.Uint("attempt", c.attempts),
zap.String("caller", getCaller(2)),
)
}
return lastErr return lastErr
} }
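
As a usage note (a sketch, not part of this patch): call sites do not change, but the periodic "retry func failed" warnings now also carry a caller field pointing at the line that invoked retry.Do or retry.Handle. The import path and the fetchSegment helper below are assumptions for illustration:

package sketch

import (
	"context"

	"github.com/milvus-io/milvus/pkg/v2/util/retry"
)

// fetchSegment stands in for any flaky operation; it is not a real Milvus API.
func fetchSegment(ctx context.Context, segmentID int64) error { return nil }

func loadWithRetry(ctx context.Context, segmentID int64) error {
	// With the change above, the warnings logged inside retry.Do include
	// caller=<this file>:<line of the retry.Do call below>.
	return retry.Do(ctx, func() error {
		return fetchSegment(ctx, segmentID)
	})
}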

View File

@ -349,6 +349,7 @@ func (s *HybridSearchSuite) TestHybridSearchSingleSubReq() {
}) })
s.NoError(err) s.NoError(err)
s.NoError(merr.Error(loadStatus)) s.NoError(merr.Error(loadStatus))
s.WaitForLoad(ctx, collectionName)
// search // search
expr := fmt.Sprintf("%s > 0", integration.Int64Field) expr := fmt.Sprintf("%s > 0", integration.Int64Field)

View File

@ -0,0 +1,542 @@
package partialsearch
import (
"context"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/tests/integration"
)
const (
dim = 128
dbName = ""
timeout = 10 * time.Second
)
type PartialSearchTestSuit struct {
integration.MiniClusterSuite
}
func (s *PartialSearchTestSuit) SetupSuite() {
paramtable.Init()
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.GracefulStopTimeout.Key, "1")
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.AutoBalanceInterval.Key, "10000")
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "10000")
paramtable.Get().Save(paramtable.Get().QueryCoordCfg.TaskExecutionCap.Key, "1")
// make queries survive when the delegator is down
paramtable.Get().Save(paramtable.Get().ProxyCfg.RetryTimesOnReplica.Key, "10")
s.Require().NoError(s.SetupEmbedEtcd())
}
func (s *PartialSearchTestSuit) initCollection(collectionName string, replica int, channelNum int, segmentNum int, segmentRowNum int) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
s.CreateCollectionWithConfiguration(ctx, &integration.CreateCollectionConfig{
DBName: dbName,
Dim: dim,
CollectionName: collectionName,
ChannelNum: channelNum,
SegmentNum: segmentNum,
RowNumPerSegment: segmentRowNum,
})
for i := 1; i < replica; i++ {
s.Cluster.AddQueryNode()
}
// load
loadStatus, err := s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
ReplicaNumber: int32(replica),
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
s.True(merr.Ok(loadStatus))
s.WaitForLoad(ctx, collectionName)
log.Info("initCollection Done")
}
func (s *PartialSearchTestSuit) executeQuery(collection string) (int, error) {
ctx := context.Background()
queryResult, err := s.Cluster.Proxy.Query(ctx, &milvuspb.QueryRequest{
DbName: "",
CollectionName: collection,
Expr: "",
OutputFields: []string{"count(*)"},
})
if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil {
log.Info("query failed", zap.Error(err))
return 0, err
}
return int(queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]), nil
}
// expect partial results to be returned, with no search failures
func (s *PartialSearchTestSuit) TestSingleNodeDownOnSingleReplica() {
partialResultRequiredDataRatio := 0.3
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 6 querynodes
for i := 1; i < 6; i++ {
s.Cluster.AddQueryNode()
}
// init collection with 1 replica, 2 channels, 6 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 6
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0)) // must fix this corner case
// stop a querynode in the single replica; expect partial results but no search failures
s.Cluster.QueryNode.Stop()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
// expect some search failures, but partial results are returned before all data has been reloaded
// when all querynodes go down, partial search shortens the recovery time
func (s *PartialSearchTestSuit) TestAllNodeDownOnSingleReplica() {
partialResultRequiredDataRatio := 0.5
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 2 querynodes
s.Cluster.AddQueryNode()
// init collection with 1 replica, 2 channels, 10 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 10
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
partialResultRecoverTs := atomic.NewInt64(0)
fullResultRecoverTs := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if failCounter.Load() > 0 {
if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
partialResultRecoverTs.Store(time.Now().UnixNano())
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
} else {
log.Info("query return full result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
fullResultRecoverTs.Store(time.Now().UnixNano())
}
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
// stop all querynodes in the single replica; search failures are expected
for _, qn := range s.Cluster.GetAllQueryNodes() {
qn.Stop()
}
s.Cluster.AddQueryNode()
time.Sleep(20 * time.Second)
s.True(failCounter.Load() >= 0)
s.True(partialResultCounter.Load() >= 0)
log.Info("partialResultRecoverTs", zap.Int64("partialResultRecoverTs", partialResultRecoverTs.Load()), zap.Int64("fullResultRecoverTs", fullResultRecoverTs.Load()))
s.True(partialResultRecoverTs.Load() < fullResultRecoverTs.Load())
close(stopSearchCh)
wg.Wait()
}
// expect full results and no search failures
// since we do not always pick the best replica to serve a query, a partial result may still be returned even when only one replica is partially loaded
func (s *PartialSearchTestSuit) TestSingleNodeDownOnMultiReplica() {
partialResultRequiredDataRatio := 0.5
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 4 querynodes
qn1 := s.Cluster.AddQueryNode()
s.Cluster.AddQueryNode()
// init collection with 2 replicas, 2 channels, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 2
segmentRowNum := 2000
s.initCollection(name, 2, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
// stop a querynode in one replica; expect no search failures
qn1.Stop()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
// expect partial results to be returned, with no search failures
func (s *PartialSearchTestSuit) TestEachReplicaHasNodeDownOnMultiReplica() {
partialResultRequiredDataRatio := 0.3
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 12 querynodes
for i := 2; i < 12; i++ {
s.Cluster.AddQueryNode()
}
// init collection with 2 replicas, 2 channels, 18 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 9
segmentRowNum := 2000
s.initCollection(name, 2, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
ctx := context.Background()
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
continue
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
replicaResp, err := s.Cluster.Proxy.GetReplicas(ctx, &milvuspb.GetReplicasRequest{
DbName: dbName,
CollectionName: name,
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, replicaResp.GetStatus().GetErrorCode())
s.Equal(2, len(replicaResp.GetReplicas()))
// for each replica, choose a querynode to stop
for _, replica := range replicaResp.GetReplicas() {
for _, qn := range s.Cluster.GetAllQueryNodes() {
if funcutil.SliceContain(replica.GetNodeIds(), qn.GetQueryNode().GetNodeID()) {
qn.Stop()
break
}
}
}
time.Sleep(10 * time.Second)
s.True(failCounter.Load() >= 0)
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
// when the partial result required data ratio is set too high, partial search will not be triggered and search failures are expected
// for example, with a ratio of 0.8, a single querynode crash loses 50% of the data, so partial search cannot be triggered
func (s *PartialSearchTestSuit) TestPartialResultRequiredDataRatioTooHigh() {
partialResultRequiredDataRatio := 0.8
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 2 querynodes
qn1 := s.Cluster.AddQueryNode()
// init collection with 1 replica, 2 channels, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 2
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
qn1.Stop()
time.Sleep(10 * time.Second)
s.True(failCounter.Load() >= 0)
s.True(partialResultCounter.Load() == 0)
close(stopSearchCh)
wg.Wait()
}
// set the partial result required data ratio to 0; expect no partial results and no search failures even after all querynodes are down
func (s *PartialSearchTestSuit) TestSearchNeverFails() {
partialResultRequiredDataRatio := 0.0
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// init cluster with 2 querynodes
s.Cluster.AddQueryNode()
// init collection with 1 replica, 2 channels, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 2
segmentNumInChannel := 2
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
for _, qn := range s.Cluster.GetAllQueryNodes() {
qn.Stop()
}
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
func (s *PartialSearchTestSuit) TestSkipWaitTSafe() {
partialResultRequiredDataRatio := 0.5
paramtable.Get().Save(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key, fmt.Sprintf("%f", partialResultRequiredDataRatio))
defer paramtable.Get().Reset(paramtable.Get().QueryNodeCfg.PartialResultRequiredDataRatio.Key)
// mock tsafe Delay
paramtable.Get().Save(paramtable.Get().ProxyCfg.TimeTickInterval.Key, "30000")
// init cluster with 5 querynodes
for i := 1; i < 5; i++ {
s.Cluster.AddQueryNode()
}
// init collection with 1 replica, 1 channel, 4 segments, 2000 rows per segment
// expect each node has 1 channel and 2 segments
name := "test_balance_" + funcutil.GenRandomStr()
channelNum := 1
segmentNumInChannel := 4
segmentRowNum := 2000
s.initCollection(name, 1, channelNum, segmentNumInChannel, segmentRowNum)
totalEntities := segmentNumInChannel * segmentRowNum
stopSearchCh := make(chan struct{})
failCounter := atomic.NewInt64(0)
partialResultCounter := atomic.NewInt64(0)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopSearchCh:
log.Info("stop search")
return
default:
numEntities, err := s.executeQuery(name)
if err != nil {
log.Info("query failed", zap.Error(err))
failCounter.Inc()
} else if numEntities < totalEntities {
log.Info("query return partial result", zap.Int("numEntities", numEntities), zap.Int("totalEntities", totalEntities))
partialResultCounter.Inc()
s.True(numEntities >= int((float64(totalEntities) * partialResultRequiredDataRatio)))
}
}
}
}()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.Equal(partialResultCounter.Load(), int64(0))
s.Cluster.QueryNode.Stop()
time.Sleep(10 * time.Second)
s.Equal(failCounter.Load(), int64(0))
s.True(partialResultCounter.Load() >= 0)
close(stopSearchCh)
wg.Wait()
}
func TestPartialResult(t *testing.T) {
suite.Run(t, new(PartialSearchTestSuit))
}

View File

@ -1,347 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package partialsearch
import (
"context"
"fmt"
"strconv"
"sync"
"testing"
"time"
"github.com/stretchr/testify/suite"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
"github.com/milvus-io/milvus/tests/integration"
)
type PartialSearchSuite struct {
integration.MiniClusterSuite
dim int
numCollections int
rowsPerCollection int
waitTimeInSec time.Duration
prefix string
}
func (s *PartialSearchSuite) setupParam() {
s.dim = 128
s.numCollections = 1
s.rowsPerCollection = 100
s.waitTimeInSec = time.Second * 10
}
func (s *PartialSearchSuite) loadCollection(collectionName string, dim int, wg *sync.WaitGroup) {
c := s.Cluster
dbName := ""
schema := integration.ConstructSchema(collectionName, dim, true)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
createCollectionStatus, err := c.Proxy.CreateCollection(context.TODO(), &milvuspb.CreateCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
})
s.NoError(err)
err = merr.Error(createCollectionStatus)
s.NoError(err)
showCollectionsResp, err := c.Proxy.ShowCollections(context.TODO(), &milvuspb.ShowCollectionsRequest{})
s.NoError(err)
s.True(merr.Ok(showCollectionsResp.GetStatus()))
batchSize := 500000
for start := 0; start < s.rowsPerCollection; start += batchSize {
rowNum := batchSize
if start+batchSize > s.rowsPerCollection {
rowNum = s.rowsPerCollection - start
}
fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
hashKeys := integration.GenerateHashKeys(rowNum)
insertResult, err := c.Proxy.Insert(context.TODO(), &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: []*schemapb.FieldData{fVecColumn},
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.True(merr.Ok(insertResult.GetStatus()))
}
log.Info("=========================Data insertion finished=========================")
// flush
flushResp, err := c.Proxy.Flush(context.TODO(), &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
ids := segmentIDs.GetData()
s.Require().NotEmpty(segmentIDs)
s.Require().True(has)
flushTs, has := flushResp.GetCollFlushTs()[collectionName]
s.True(has)
segments, err := c.MetaWatcher.ShowSegments()
s.NoError(err)
s.NotEmpty(segments)
s.WaitForFlush(context.TODO(), ids, flushTs, dbName, collectionName)
log.Info("=========================Data flush finished=========================")
// create index
createIndexStatus, err := c.Proxy.CreateIndex(context.TODO(), &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: integration.FloatVecField,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.IP),
})
s.NoError(err)
err = merr.Error(createIndexStatus)
s.NoError(err)
s.WaitForIndexBuilt(context.TODO(), collectionName, integration.FloatVecField)
log.Info("=========================Index created=========================")
// load
loadStatus, err := c.Proxy.LoadCollection(context.TODO(), &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
err = merr.Error(loadStatus)
s.NoError(err)
s.WaitForLoad(context.TODO(), collectionName)
log.Info("=========================Collection loaded=========================")
wg.Done()
}
func (s *PartialSearchSuite) checkCollectionLoaded(collectionName string) bool {
loadProgress, err := s.Cluster.Proxy.GetLoadingProgress(context.TODO(), &milvuspb.GetLoadingProgressRequest{
DbName: "",
CollectionName: collectionName,
})
s.NoError(err)
if loadProgress.GetProgress() != int64(100) {
return false
}
return true
}
func (s *PartialSearchSuite) checkCollectionsLoaded(startCollectionID, endCollectionID int) bool {
notLoaded := 0
loaded := 0
for idx := startCollectionID; idx < endCollectionID; idx++ {
collectionName := s.prefix + "_" + strconv.Itoa(idx)
if s.checkCollectionLoaded(collectionName) {
notLoaded++
} else {
loaded++
}
}
log.Info(fmt.Sprintf("loading status: %d/%d", loaded, endCollectionID-startCollectionID+1))
return notLoaded == 0
}
func (s *PartialSearchSuite) checkAllCollectionsLoaded() bool {
return s.checkCollectionsLoaded(0, s.numCollections)
}
func (s *PartialSearchSuite) search(collectionName string, dim int) {
c := s.Cluster
var err error
// Query
queryReq := &milvuspb.QueryRequest{
Base: nil,
CollectionName: collectionName,
PartitionNames: nil,
Expr: "",
OutputFields: []string{"count(*)"},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
queryResult, err := c.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
s.Equal(queryResult.Status.ErrorCode, commonpb.ErrorCode_Success)
s.Equal(len(queryResult.FieldsData), 1)
numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]
s.Equal(numEntities, int64(s.rowsPerCollection))
// Search
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 10
topk := 10
roundDecimal := -1
radius := 10
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
params["radius"] = radius
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.Proxy.Search(context.TODO(), searchReq)
err = merr.Error(searchResult.GetStatus())
s.NoError(err)
}
func (s *PartialSearchSuite) FailOnSearch(collectionName string) {
c := s.Cluster
expr := fmt.Sprintf("%s > 0", integration.Int64Field)
nq := 10
topk := 10
roundDecimal := -1
radius := 10
params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP)
params["radius"] = radius
searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, s.dim, topk, roundDecimal)
searchResult, err := c.Proxy.Search(context.TODO(), searchReq)
s.NoError(err)
err = merr.Error(searchResult.GetStatus())
s.Require().Error(err)
}
func (s *PartialSearchSuite) setupData() {
// Add the second query node
log.Info("=========================Start to inject data=========================")
s.prefix = "TestPartialSearchUtil" + funcutil.GenRandomStr()
searchName := s.prefix + "_0"
wg := sync.WaitGroup{}
for idx := 0; idx < s.numCollections; idx++ {
wg.Add(1)
go s.loadCollection(s.prefix+"_"+strconv.Itoa(idx), s.dim, &wg)
}
wg.Wait()
log.Info("=========================Data injection finished=========================")
s.checkAllCollectionsLoaded()
log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName))
s.search(searchName, s.dim)
log.Info("=========================Search finished=========================")
time.Sleep(s.waitTimeInSec)
s.checkAllCollectionsLoaded()
log.Info(fmt.Sprintf("=========================start to search2 %s=========================", searchName))
s.search(searchName, s.dim)
log.Info("=========================Search2 finished=========================")
s.checkAllCollectionsReady()
}
func (s *PartialSearchSuite) checkCollectionsReady(startCollectionID, endCollectionID int) {
for i := startCollectionID; i < endCollectionID; i++ {
collectionName := s.prefix + "_" + strconv.Itoa(i)
s.search(collectionName, s.dim)
queryReq := &milvuspb.QueryRequest{
CollectionName: collectionName,
Expr: "",
OutputFields: []string{"count(*)"},
}
_, err := s.Cluster.Proxy.Query(context.TODO(), queryReq)
s.NoError(err)
}
}
func (s *PartialSearchSuite) checkAllCollectionsReady() {
s.checkCollectionsReady(0, s.numCollections)
}
func (s *PartialSearchSuite) releaseSegmentsReq(collectionID, nodeID, segmentID typeutil.UniqueID, shard string) *querypb.ReleaseSegmentsRequest {
req := &querypb.ReleaseSegmentsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ReleaseSegments),
commonpbutil.WithMsgID(1<<30),
commonpbutil.WithTargetID(nodeID),
),
NodeID: nodeID,
CollectionID: collectionID,
SegmentIDs: []int64{segmentID},
Scope: querypb.DataScope_Historical,
Shard: shard,
NeedTransfer: false,
}
return req
}
func (s *PartialSearchSuite) describeCollection(name string) (int64, []string) {
resp, err := s.Cluster.Proxy.DescribeCollection(context.TODO(), &milvuspb.DescribeCollectionRequest{
DbName: "default",
CollectionName: name,
})
s.NoError(err)
log.Info(fmt.Sprintf("describe collection: %v", resp))
return resp.CollectionID, resp.VirtualChannelNames
}
func (s *PartialSearchSuite) getSegmentIDs(collectionName string) []int64 {
resp, err := s.Cluster.Proxy.GetPersistentSegmentInfo(context.TODO(), &milvuspb.GetPersistentSegmentInfoRequest{
DbName: "default",
CollectionName: collectionName,
})
s.NoError(err)
var res []int64
for _, seg := range resp.Infos {
res = append(res, seg.SegmentID)
}
return res
}
func (s *PartialSearchSuite) TestPartialSearch() {
s.setupParam()
s.setupData()
startCollectionID := 0
endCollectionID := 0
// Search should work in the beginning
s.checkCollectionsReady(startCollectionID, endCollectionID)
// Test case with one segment released
// Partial search does not work yet.
c := s.Cluster
q1 := c.QueryNode
c.MixCoord.StopCheckerForTestOnly()
collectionName := s.prefix + "_0"
nodeID := q1.GetServerIDForTestOnly()
collectionID, channels := s.describeCollection(collectionName)
segs := s.getSegmentIDs(collectionName)
s.Require().Positive(len(segs))
s.Require().Positive(len(channels))
segmentID := segs[0]
shard := channels[0]
req := s.releaseSegmentsReq(collectionID, nodeID, segmentID, shard)
q1.ReleaseSegments(context.TODO(), req)
s.FailOnSearch(collectionName)
c.MixCoord.StartCheckerForTestOnly()
}
func TestPartialSearchUtil(t *testing.T) {
suite.Run(t, new(PartialSearchSuite))
}