mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 09:38:39 +08:00
Remove merge policy of proxy RoundRobin policy (#23021)
Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
This commit is contained in:
parent
22aeb72eba
commit
bd5fab1e53
@ -2590,10 +2590,9 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
|||||||
),
|
),
|
||||||
ReqID: paramtable.GetNodeID(),
|
ReqID: paramtable.GetNodeID(),
|
||||||
},
|
},
|
||||||
request: request,
|
request: request,
|
||||||
qc: node.queryCoord,
|
qc: node.queryCoord,
|
||||||
queryShardPolicy: mergeRoundRobinPolicy,
|
shardMgr: node.shardMgr,
|
||||||
shardMgr: node.shardMgr,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
method := "Query"
|
method := "Query"
|
||||||
@ -2924,8 +2923,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
|
|||||||
qc: node.queryCoord,
|
qc: node.queryCoord,
|
||||||
ids: ids.IdArray,
|
ids: ids.IdArray,
|
||||||
|
|
||||||
queryShardPolicy: mergeRoundRobinPolicy,
|
shardMgr: node.shardMgr,
|
||||||
shardMgr: node.shardMgr,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log := log.Ctx(ctx).With(
|
log := log.Ctx(ctx).With(
|
||||||
|
|||||||
@ -2,161 +2,69 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/log"
|
"github.com/milvus-io/milvus/internal/log"
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/merr"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// type pickShardPolicy func(ctx context.Context, mgr *shardClientMgr, query func(UniqueID, types.QueryNode) error, leaders []nodeInfo) error
|
// type pickShardPolicy func(ctx context.Context, mgr *shardClientMgr, query func(UniqueID, types.QueryNode) error, leaders []nodeInfo) error
|
||||||
|
|
||||||
type pickShardPolicy func(context.Context, *shardClientMgr, func(context.Context, UniqueID, types.QueryNode, []string, int) error, map[string][]nodeInfo) error
|
type queryFunc func(context.Context, UniqueID, types.QueryNode, ...string) error
|
||||||
|
type pickShardPolicy func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errBegin = errors.New("begin error")
|
|
||||||
errInvalidShardLeaders = errors.New("Invalid shard leader")
|
errInvalidShardLeaders = errors.New("Invalid shard leader")
|
||||||
)
|
)
|
||||||
|
|
||||||
func updateShardsWithRoundRobin(shardsLeaders map[string][]nodeInfo) {
|
// RoundRobinPolicy do the query with multiple dml channels
|
||||||
for channelID, leaders := range shardsLeaders {
|
// if request failed, it finds shard leader for failed dml channels
|
||||||
if len(leaders) <= 1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
shardsLeaders[channelID] = append(leaders[1:], leaders[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// mergeErrSet merges all errors in ErrSet
|
|
||||||
func mergeErrSet(errSet map[string]error) error {
|
|
||||||
var builder strings.Builder
|
|
||||||
for channel, err := range errSet {
|
|
||||||
if err == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
builder.WriteString(fmt.Sprintf("Channel: %s returns err: %s", channel, err.Error()))
|
|
||||||
}
|
|
||||||
return errors.New(builder.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// group dml shard leader with same nodeID
|
|
||||||
func groupShardleadersWithSameQueryNode(
|
|
||||||
ctx context.Context,
|
|
||||||
shard2leaders map[string][]nodeInfo,
|
|
||||||
nexts map[string]int, errSet map[string]error,
|
|
||||||
mgr *shardClientMgr) (map[int64][]string, map[int64]types.QueryNode, error) {
|
|
||||||
// check if all leaders were checked
|
|
||||||
for dml, idx := range nexts {
|
|
||||||
if idx >= len(shard2leaders[dml]) {
|
|
||||||
log.Ctx(ctx).Warn("no shard leaders were available",
|
|
||||||
zap.String("channel", dml),
|
|
||||||
zap.String("leaders", fmt.Sprintf("%v", shard2leaders[dml])))
|
|
||||||
if err, ok := errSet[dml]; ok {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
return nil, nil, fmt.Errorf("no available shard leader")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
qnSet := make(map[int64]types.QueryNode)
|
|
||||||
node2dmls := make(map[int64][]string)
|
|
||||||
updates := make(map[string]int)
|
|
||||||
|
|
||||||
for dml, idx := range nexts {
|
|
||||||
updates[dml] = idx + 1
|
|
||||||
nodeInfo := shard2leaders[dml][idx]
|
|
||||||
if _, ok := qnSet[nodeInfo.nodeID]; !ok {
|
|
||||||
qn, err := mgr.GetClient(ctx, nodeInfo.nodeID)
|
|
||||||
if err != nil {
|
|
||||||
log.Ctx(ctx).Warn("failed to get shard leader", zap.Int64("nodeID", nodeInfo.nodeID), zap.Error(err))
|
|
||||||
// if get client failed, just record error and wait for next round to get client and do query
|
|
||||||
errSet[dml] = err
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
qnSet[nodeInfo.nodeID] = qn
|
|
||||||
}
|
|
||||||
if _, ok := node2dmls[nodeInfo.nodeID]; !ok {
|
|
||||||
node2dmls[nodeInfo.nodeID] = make([]string, 0)
|
|
||||||
}
|
|
||||||
node2dmls[nodeInfo.nodeID] = append(node2dmls[nodeInfo.nodeID], dml)
|
|
||||||
}
|
|
||||||
// update idxes
|
|
||||||
for dml, idx := range updates {
|
|
||||||
nexts[dml] = idx
|
|
||||||
}
|
|
||||||
return node2dmls, qnSet, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// mergeRoundRobinPolicy first group shard leaders with same querynode, then do the query with multiple dml channels
|
|
||||||
// if request failed, it finds shard leader for failed dml channels, and again groups shard leaders and do the query
|
|
||||||
//
|
//
|
||||||
// Suppose qn0 is the shard leader for dml-channel0 and dml-channel1, if search for dml-channel0 succeeded, but
|
func RoundRobinPolicy(
|
||||||
// failed for dml-channel1. In this case, an error returned from qn0, and next shard leaders for dml-channel0 and dml-channel1 will be
|
|
||||||
// retrieved and dml-channel0 therefore will again be searched.
|
|
||||||
//
|
|
||||||
// TODO: In this senario, qn0 should return a partial success results for dml-channel0, and only retrys for dml-channel1
|
|
||||||
func mergeRoundRobinPolicy(
|
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
mgr *shardClientMgr,
|
mgr *shardClientMgr,
|
||||||
query func(context.Context, UniqueID, types.QueryNode, []string, int) error,
|
query queryFunc,
|
||||||
dml2leaders map[string][]nodeInfo) error {
|
dml2leaders map[string][]nodeInfo) error {
|
||||||
nexts := make(map[string]int)
|
|
||||||
errSet := make(map[string]error) // record err for dml channels
|
queryChannel := func(ctx context.Context, channel string) error {
|
||||||
totalChannelNum := len(dml2leaders)
|
var combineErr error
|
||||||
for dml := range dml2leaders {
|
leaders := dml2leaders[channel]
|
||||||
nexts[dml] = 0
|
|
||||||
}
|
for _, target := range leaders {
|
||||||
for len(nexts) > 0 {
|
qn, err := mgr.GetClient(ctx, target.nodeID)
|
||||||
node2dmls, nodeset, err := groupShardleadersWithSameQueryNode(ctx, dml2leaders, nexts, errSet, mgr)
|
if err != nil {
|
||||||
if err != nil {
|
log.Warn("query channel failed, node not available", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err))
|
||||||
log.Ctx(ctx).Warn("failed to search/query with round-robin policy", zap.Error(mergeErrSet(errSet)))
|
combineErr = merr.Combine(combineErr, err)
|
||||||
return err
|
continue
|
||||||
}
|
|
||||||
wg := &sync.WaitGroup{}
|
|
||||||
mu := &sync.Mutex{}
|
|
||||||
wg.Add(len(node2dmls))
|
|
||||||
for nodeID, channels := range node2dmls {
|
|
||||||
nodeID := nodeID
|
|
||||||
channels := channels
|
|
||||||
qn := nodeset[nodeID]
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
if err := query(ctx, nodeID, qn, channels, totalChannelNum); err != nil {
|
|
||||||
log.Ctx(ctx).Warn("failed to do query with node", zap.Int64("nodeID", nodeID),
|
|
||||||
zap.Strings("dmlChannels", channels), zap.Error(err))
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
for _, ch := range channels {
|
|
||||||
errSet[ch] = err
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
for _, channel := range channels {
|
|
||||||
delete(nexts, channel)
|
|
||||||
delete(errSet, channel)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if len(nexts) > 0 {
|
|
||||||
nextSet := make(map[string]int64)
|
|
||||||
for dml, idx := range nexts {
|
|
||||||
if idx >= len(dml2leaders[dml]) {
|
|
||||||
nextSet[dml] = -1
|
|
||||||
} else {
|
|
||||||
nextSet[dml] = dml2leaders[dml][idx].nodeID
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
log.Ctx(ctx).Warn("retry another query node with round robin", zap.Any("Nexts", nextSet))
|
err = query(ctx, target.nodeID, qn, channel)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("query channel failed", zap.String("channel", channel), zap.Int64("nodeID", target.nodeID), zap.Error(err))
|
||||||
|
combineErr = merr.Combine(combineErr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Ctx(ctx).Error("failed to do query on all shard leader",
|
||||||
|
zap.String("channel", channel), zap.Error(combineErr))
|
||||||
|
return combineErr
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
wg, ctx := errgroup.WithContext(ctx)
|
||||||
|
for channel := range dml2leaders {
|
||||||
|
channel := channel
|
||||||
|
wg.Go(func() error {
|
||||||
|
err := queryChannel(ctx, channel)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
err := wg.Wait()
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,116 +8,12 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/log"
|
|
||||||
"github.com/milvus-io/milvus/internal/types"
|
"github.com/milvus-io/milvus/internal/types"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpdateShardsWithRoundRobin(t *testing.T) {
|
func TestRoundRobinPolicy(t *testing.T) {
|
||||||
list := map[string][]nodeInfo{
|
|
||||||
"channel-1": {
|
|
||||||
{1, "addr1"},
|
|
||||||
{2, "addr2"},
|
|
||||||
},
|
|
||||||
"channel-2": {
|
|
||||||
{20, "addr20"},
|
|
||||||
{21, "addr21"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
updateShardsWithRoundRobin(list)
|
|
||||||
|
|
||||||
assert.Equal(t, int64(2), list["channel-1"][0].nodeID)
|
|
||||||
assert.Equal(t, "addr2", list["channel-1"][0].address)
|
|
||||||
assert.Equal(t, int64(21), list["channel-2"][0].nodeID)
|
|
||||||
assert.Equal(t, "addr21", list["channel-2"][0].address)
|
|
||||||
|
|
||||||
t.Run("check print", func(t *testing.T) {
|
|
||||||
qns := []nodeInfo{
|
|
||||||
{1, "addr1"},
|
|
||||||
{2, "addr2"},
|
|
||||||
{20, "addr20"},
|
|
||||||
{21, "addr21"},
|
|
||||||
}
|
|
||||||
|
|
||||||
res := fmt.Sprintf("list: %v", qns)
|
|
||||||
|
|
||||||
log.Debug("Check String func",
|
|
||||||
zap.Any("Any", qns),
|
|
||||||
zap.Any("ok", qns[0]),
|
|
||||||
zap.String("ok2", res),
|
|
||||||
)
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGroupShardLeadersWithSameQueryNode(t *testing.T) {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
var (
|
|
||||||
ctx = context.TODO()
|
|
||||||
)
|
|
||||||
|
|
||||||
mgr := newShardClientMgr()
|
|
||||||
|
|
||||||
shard2leaders := map[string][]nodeInfo{
|
|
||||||
"c0": {{nodeID: 0, address: "fake"}, {nodeID: 1, address: "fake"}, {nodeID: 2, address: "fake"}},
|
|
||||||
"c1": {{nodeID: 1, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}},
|
|
||||||
"c2": {{nodeID: 0, address: "fake"}, {nodeID: 2, address: "fake"}, {nodeID: 3, address: "fake"}},
|
|
||||||
"c3": {{nodeID: 1, address: "fake"}, {nodeID: 3, address: "fake"}, {nodeID: 4, address: "fake"}},
|
|
||||||
}
|
|
||||||
mgr.UpdateShardLeaders(nil, shard2leaders)
|
|
||||||
nexts := map[string]int{
|
|
||||||
"c0": 0,
|
|
||||||
"c1": 0,
|
|
||||||
"c2": 0,
|
|
||||||
"c3": 0,
|
|
||||||
}
|
|
||||||
errSet := map[string]error{}
|
|
||||||
node2dmls, qnSet, err := groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
for nodeID := range node2dmls {
|
|
||||||
sort.Slice(node2dmls[nodeID], func(i, j int) bool { return node2dmls[nodeID][i] < node2dmls[nodeID][j] })
|
|
||||||
}
|
|
||||||
|
|
||||||
cli0, err := mgr.GetClient(ctx, 0)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
cli1, err := mgr.GetClient(ctx, 1)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
cli2, err := mgr.GetClient(ctx, 2)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
cli3, err := mgr.GetClient(ctx, 3)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, node2dmls, map[int64][]string{0: {"c0", "c2"}, 1: {"c1", "c3"}})
|
|
||||||
assert.Equal(t, qnSet, map[int64]types.QueryNode{0: cli0, 1: cli1})
|
|
||||||
assert.Equal(t, nexts, map[string]int{"c0": 1, "c1": 1, "c2": 1, "c3": 1})
|
|
||||||
// delete client1 in client mgr
|
|
||||||
delete(mgr.clients.data, 1)
|
|
||||||
node2dmls, qnSet, err = groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
for nodeID := range node2dmls {
|
|
||||||
sort.Slice(node2dmls[nodeID], func(i, j int) bool { return node2dmls[nodeID][i] < node2dmls[nodeID][j] })
|
|
||||||
}
|
|
||||||
assert.Equal(t, node2dmls, map[int64][]string{2: {"c1", "c2"}, 3: {"c3"}})
|
|
||||||
assert.Equal(t, qnSet, map[int64]types.QueryNode{2: cli2, 3: cli3})
|
|
||||||
assert.Equal(t, nexts, map[string]int{"c0": 2, "c1": 2, "c2": 2, "c3": 2})
|
|
||||||
assert.NotNil(t, errSet["c0"])
|
|
||||||
|
|
||||||
nexts["c0"] = 3
|
|
||||||
_, _, err = groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr)
|
|
||||||
assert.True(t, strings.Contains(err.Error(), errSet["c0"].Error()))
|
|
||||||
|
|
||||||
nexts["c0"] = 2
|
|
||||||
nexts["c1"] = 3
|
|
||||||
_, _, err = groupShardleadersWithSameQueryNode(ctx, shard2leaders, nexts, errSet, mgr)
|
|
||||||
assert.Equal(t, err, fmt.Errorf("no available shard leader"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMergeRoundRobinPolicy(t *testing.T) {
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -137,7 +33,7 @@ func TestMergeRoundRobinPolicy(t *testing.T) {
|
|||||||
querier := &mockQuery{}
|
querier := &mockQuery{}
|
||||||
querier.init()
|
querier.init()
|
||||||
|
|
||||||
err = mergeRoundRobinPolicy(ctx, mgr, querier.query, shard2leaders)
|
err = RoundRobinPolicy(ctx, mgr, querier.query, shard2leaders)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, querier.records(), map[UniqueID][]string{0: {"c0", "c2"}, 1: {"c1", "c3"}})
|
assert.Equal(t, querier.records(), map[UniqueID][]string{0: {"c0", "c2"}, 1: {"c1", "c3"}})
|
||||||
|
|
||||||
@ -145,7 +41,7 @@ func TestMergeRoundRobinPolicy(t *testing.T) {
|
|||||||
querier.init()
|
querier.init()
|
||||||
querier.failset[0] = mockerr
|
querier.failset[0] = mockerr
|
||||||
|
|
||||||
err = mergeRoundRobinPolicy(ctx, mgr, querier.query, shard2leaders)
|
err = RoundRobinPolicy(ctx, mgr, querier.query, shard2leaders)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, querier.records(), map[int64][]string{1: {"c0", "c1", "c3"}, 2: {"c2"}})
|
assert.Equal(t, querier.records(), map[int64][]string{1: {"c0", "c1", "c3"}, 2: {"c2"}})
|
||||||
|
|
||||||
@ -153,7 +49,7 @@ func TestMergeRoundRobinPolicy(t *testing.T) {
|
|||||||
querier.failset[0] = mockerr
|
querier.failset[0] = mockerr
|
||||||
querier.failset[2] = mockerr
|
querier.failset[2] = mockerr
|
||||||
querier.failset[3] = mockerr
|
querier.failset[3] = mockerr
|
||||||
err = mergeRoundRobinPolicy(ctx, mgr, querier.query, shard2leaders)
|
err = RoundRobinPolicy(ctx, mgr, querier.query, shard2leaders)
|
||||||
assert.True(t, strings.Contains(err.Error(), mockerr.Error()))
|
assert.True(t, strings.Contains(err.Error(), mockerr.Error()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -167,7 +63,7 @@ type mockQuery struct {
|
|||||||
failset map[UniqueID]error
|
failset map[UniqueID]error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockQuery) query(_ context.Context, nodeID UniqueID, qn types.QueryNode, chs []string, _ int) error {
|
func (m *mockQuery) query(_ context.Context, nodeID UniqueID, qn types.QueryNode, chs ...string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
if err, ok := m.failset[nodeID]; ok {
|
if err, ok := m.failset[nodeID]; ok {
|
||||||
|
|||||||
@ -243,7 +243,7 @@ func (t *queryTask) createPlan(ctx context.Context) error {
|
|||||||
|
|
||||||
func (t *queryTask) PreExecute(ctx context.Context) error {
|
func (t *queryTask) PreExecute(ctx context.Context) error {
|
||||||
if t.queryShardPolicy == nil {
|
if t.queryShardPolicy == nil {
|
||||||
t.queryShardPolicy = mergeRoundRobinPolicy
|
t.queryShardPolicy = RoundRobinPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Base.MsgType = commonpb.MsgType_Retrieve
|
t.Base.MsgType = commonpb.MsgType_Retrieve
|
||||||
@ -454,7 +454,7 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string, channelNum int) error {
|
func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs ...string) error {
|
||||||
retrieveReq := typeutil.Clone(t.RetrieveRequest)
|
retrieveReq := typeutil.Clone(t.RetrieveRequest)
|
||||||
retrieveReq.GetBase().TargetID = nodeID
|
retrieveReq.GetBase().TargetID = nodeID
|
||||||
req := &querypb.QueryRequest{
|
req := &querypb.QueryRequest{
|
||||||
|
|||||||
@ -44,7 +44,7 @@ func TestQueryTask_all(t *testing.T) {
|
|||||||
expr = fmt.Sprintf("%s > 0", testInt64Field)
|
expr = fmt.Sprintf("%s > 0", testInt64Field)
|
||||||
hitNum = 10
|
hitNum = 10
|
||||||
|
|
||||||
errPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error {
|
errPolicy = func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error {
|
||||||
return fmt.Errorf("fake error")
|
return fmt.Errorf("fake error")
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -181,7 +181,7 @@ func TestQueryTask_all(t *testing.T) {
|
|||||||
task.queryShardPolicy = errPolicy
|
task.queryShardPolicy = errPolicy
|
||||||
assert.Error(t, task.Execute(ctx))
|
assert.Error(t, task.Execute(ctx))
|
||||||
|
|
||||||
task.queryShardPolicy = mergeRoundRobinPolicy
|
task.queryShardPolicy = RoundRobinPolicy
|
||||||
result1 := &internalpb.RetrieveResults{
|
result1 := &internalpb.RetrieveResults{
|
||||||
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
|
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
|
||||||
Status: &commonpb.Status{
|
Status: &commonpb.Status{
|
||||||
|
|||||||
@ -52,6 +52,7 @@ type searchTask struct {
|
|||||||
qc types.QueryCoord
|
qc types.QueryCoord
|
||||||
tr *timerecord.TimeRecorder
|
tr *timerecord.TimeRecorder
|
||||||
collectionName string
|
collectionName string
|
||||||
|
channelNum int32
|
||||||
schema *schemapb.CollectionSchema
|
schema *schemapb.CollectionSchema
|
||||||
|
|
||||||
offset int64
|
offset int64
|
||||||
@ -207,7 +208,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||||||
defer sp.End()
|
defer sp.End()
|
||||||
|
|
||||||
if t.searchShardPolicy == nil {
|
if t.searchShardPolicy == nil {
|
||||||
t.searchShardPolicy = mergeRoundRobinPolicy
|
t.searchShardPolicy = RoundRobinPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Base.MsgType = commonpb.MsgType_Search
|
t.Base.MsgType = commonpb.MsgType_Search
|
||||||
@ -358,6 +359,7 @@ func (t *searchTask) Execute(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
|
t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
|
||||||
t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
|
t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
|
||||||
|
t.channelNum = int32(len(shard2Leaders))
|
||||||
if err := t.searchShardPolicy(ctx, t.shardMgr, t.searchShard, shard2Leaders); err != nil {
|
if err := t.searchShardPolicy(ctx, t.shardMgr, t.searchShard, shard2Leaders); err != nil {
|
||||||
log.Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders)))
|
log.Warn("failed to do search", zap.Error(err), zap.String("Shards", fmt.Sprintf("%v", shard2Leaders)))
|
||||||
return err
|
return err
|
||||||
@ -439,14 +441,14 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string, channelNum int) error {
|
func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs ...string) error {
|
||||||
searchReq := typeutil.Clone(t.SearchRequest)
|
searchReq := typeutil.Clone(t.SearchRequest)
|
||||||
searchReq.GetBase().TargetID = nodeID
|
searchReq.GetBase().TargetID = nodeID
|
||||||
req := &querypb.SearchRequest{
|
req := &querypb.SearchRequest{
|
||||||
Req: searchReq,
|
Req: searchReq,
|
||||||
DmlChannels: channelIDs,
|
DmlChannels: channelIDs,
|
||||||
Scope: querypb.DataScope_All,
|
Scope: querypb.DataScope_All,
|
||||||
TotalChannelNum: int32(channelNum),
|
TotalChannelNum: t.channelNum,
|
||||||
}
|
}
|
||||||
|
|
||||||
queryNode := querynode.GetQueryNode()
|
queryNode := querynode.GetQueryNode()
|
||||||
|
|||||||
@ -1698,7 +1698,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
|
|||||||
|
|
||||||
shardsNum = int32(2)
|
shardsNum = int32(2)
|
||||||
collectionName = t.Name() + funcutil.GenRandomStr()
|
collectionName = t.Name() + funcutil.GenRandomStr()
|
||||||
errPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error {
|
errPolicy = func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error {
|
||||||
return fmt.Errorf("fake error")
|
return fmt.Errorf("fake error")
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -1820,7 +1820,7 @@ func TestSearchTask_ErrExecute(t *testing.T) {
|
|||||||
task.searchShardPolicy = errPolicy
|
task.searchShardPolicy = errPolicy
|
||||||
assert.Error(t, task.Execute(ctx))
|
assert.Error(t, task.Execute(ctx))
|
||||||
|
|
||||||
task.searchShardPolicy = mergeRoundRobinPolicy
|
task.searchShardPolicy = RoundRobinPolicy
|
||||||
qn.searchError = fmt.Errorf("mock error")
|
qn.searchError = fmt.Errorf("mock error")
|
||||||
assert.Error(t, task.Execute(ctx))
|
assert.Error(t, task.Execute(ctx))
|
||||||
|
|
||||||
|
|||||||
@ -109,7 +109,7 @@ func (g *getStatisticsTask) PreExecute(ctx context.Context) error {
|
|||||||
defer sp.End()
|
defer sp.End()
|
||||||
|
|
||||||
if g.statisticShardPolicy == nil {
|
if g.statisticShardPolicy == nil {
|
||||||
g.statisticShardPolicy = mergeRoundRobinPolicy
|
g.statisticShardPolicy = RoundRobinPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Maybe we should create a new MsgType: GetStatistics?
|
// TODO: Maybe we should create a new MsgType: GetStatistics?
|
||||||
@ -299,7 +299,7 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs []string, channelNum int) error {
|
func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64, qn types.QueryNode, channelIDs ...string) error {
|
||||||
nodeReq := proto.Clone(g.GetStatisticsRequest).(*internalpb.GetStatisticsRequest)
|
nodeReq := proto.Clone(g.GetStatisticsRequest).(*internalpb.GetStatisticsRequest)
|
||||||
nodeReq.Base.TargetID = nodeID
|
nodeReq.Base.TargetID = nodeID
|
||||||
req := &querypb.GetStatisticsRequest{
|
req := &querypb.GetStatisticsRequest{
|
||||||
|
|||||||
@ -155,27 +155,27 @@ func TestStatisticTask_all(t *testing.T) {
|
|||||||
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
|
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
|
||||||
|
|
||||||
task.ctx = ctx
|
task.ctx = ctx
|
||||||
task.statisticShardPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error {
|
task.statisticShardPolicy = func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error {
|
||||||
return fmt.Errorf("fake error")
|
return fmt.Errorf("fake error")
|
||||||
}
|
}
|
||||||
task.fromQueryNode = true
|
task.fromQueryNode = true
|
||||||
assert.Error(t, task.Execute(ctx))
|
assert.Error(t, task.Execute(ctx))
|
||||||
assert.NoError(t, task.PostExecute(ctx))
|
assert.NoError(t, task.PostExecute(ctx))
|
||||||
|
|
||||||
task.statisticShardPolicy = func(context.Context, *shardClientMgr, func(context.Context, int64, types.QueryNode, []string, int) error, map[string][]nodeInfo) error {
|
task.statisticShardPolicy = func(context.Context, *shardClientMgr, queryFunc, map[string][]nodeInfo) error {
|
||||||
return errInvalidShardLeaders
|
return errInvalidShardLeaders
|
||||||
}
|
}
|
||||||
task.fromQueryNode = true
|
task.fromQueryNode = true
|
||||||
assert.Error(t, task.Execute(ctx))
|
assert.Error(t, task.Execute(ctx))
|
||||||
assert.NoError(t, task.PostExecute(ctx))
|
assert.NoError(t, task.PostExecute(ctx))
|
||||||
|
|
||||||
task.statisticShardPolicy = mergeRoundRobinPolicy
|
task.statisticShardPolicy = RoundRobinPolicy
|
||||||
task.fromQueryNode = true
|
task.fromQueryNode = true
|
||||||
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, errors.New("GetStatistics failed")).Times(3)
|
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, errors.New("GetStatistics failed")).Times(3)
|
||||||
assert.Error(t, task.Execute(ctx))
|
assert.Error(t, task.Execute(ctx))
|
||||||
assert.NoError(t, task.PostExecute(ctx))
|
assert.NoError(t, task.PostExecute(ctx))
|
||||||
|
|
||||||
task.statisticShardPolicy = mergeRoundRobinPolicy
|
task.statisticShardPolicy = RoundRobinPolicy
|
||||||
task.fromQueryNode = true
|
task.fromQueryNode = true
|
||||||
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{
|
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{
|
||||||
Status: &commonpb.Status{
|
Status: &commonpb.Status{
|
||||||
@ -186,7 +186,7 @@ func TestStatisticTask_all(t *testing.T) {
|
|||||||
assert.Error(t, task.Execute(ctx))
|
assert.Error(t, task.Execute(ctx))
|
||||||
assert.NoError(t, task.PostExecute(ctx))
|
assert.NoError(t, task.PostExecute(ctx))
|
||||||
|
|
||||||
task.statisticShardPolicy = mergeRoundRobinPolicy
|
task.statisticShardPolicy = RoundRobinPolicy
|
||||||
task.fromQueryNode = true
|
task.fromQueryNode = true
|
||||||
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{
|
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(&internalpb.GetStatisticsResponse{
|
||||||
Status: &commonpb.Status{
|
Status: &commonpb.Status{
|
||||||
@ -197,7 +197,7 @@ func TestStatisticTask_all(t *testing.T) {
|
|||||||
assert.Error(t, task.Execute(ctx))
|
assert.Error(t, task.Execute(ctx))
|
||||||
assert.NoError(t, task.PostExecute(ctx))
|
assert.NoError(t, task.PostExecute(ctx))
|
||||||
|
|
||||||
task.statisticShardPolicy = mergeRoundRobinPolicy
|
task.statisticShardPolicy = RoundRobinPolicy
|
||||||
task.fromQueryNode = true
|
task.fromQueryNode = true
|
||||||
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, nil).Once()
|
qn.EXPECT().GetStatistics(mock.Anything, mock.Anything).Return(nil, nil).Once()
|
||||||
assert.NoError(t, task.Execute(ctx))
|
assert.NoError(t, task.Execute(ctx))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user