Mirror of https://gitee.com/milvus-io/milvus.git (synced 2025-12-06 17:18:35 +08:00)

Make proxy use roundrobin to choose replica (#17063)

Fixes: #17055
Signed-off-by: yangxuan <xuan.yang@zilliz.com>

This commit is contained in:
parent b37b87eb97
commit 127dd34b37
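Before this change the proxy always queried the first shard leader returned by QueryCoord; this commit keeps a per-channel replica list and rotates it on every lookup, so consecutive requests for the same DML channel land on different query nodes. A minimal standalone sketch of the rotation idea (type and function names here are illustrative, not the Milvus code):

package main

import "fmt"

type replica struct {
	nodeID  int64
	address string
}

// rotate moves the head replica to the tail, the same
// append(leaders[1:], leaders[0]) trick the diff below uses.
func rotate(leaders []replica) []replica {
	if len(leaders) <= 1 {
		return leaders
	}
	return append(leaders[1:], leaders[0])
}

func main() {
	leaders := []replica{{1, "addr1"}, {2, "addr2"}, {3, "addr3"}}
	for i := 0; i < 3; i++ {
		fmt.Println("request", i, "goes to node", leaders[0].nodeID) // nodes 1, 2, 3 in turn
		leaders = rotate(leaders)
	}
}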
@@ -712,7 +712,7 @@ func (c *ChannelManager) Release(nodeID UniqueID, channelName string) error {
 
 	toReleaseChannel := c.getChannelByNodeAndName(nodeID, channelName)
 	if toReleaseChannel == nil {
-		return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
+		return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
 	}
 
 	toReleaseUpdates := getReleaseOp(nodeID, toReleaseChannel)
@@ -731,7 +731,7 @@ func (c *ChannelManager) toDelete(nodeID UniqueID, channelName string) error {
 
 	ch := c.getChannelByNodeAndName(nodeID, channelName)
 	if ch == nil {
-		return fmt.Errorf("fail to find matching nodID: %d with channelName: %s", nodeID, channelName)
+		return fmt.Errorf("fail to find matching nodeID: %d with channelName: %s", nodeID, channelName)
 	}
 
 	if !c.isMarkedDrop(channelName) {
@@ -53,7 +53,7 @@ type Cache interface {
 	GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error)
 	// GetCollectionSchema get collection's schema.
 	GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
-	GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error)
+	GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error)
 	ClearShards(collectionName string)
 	RemoveCollection(ctx context.Context, collectionName string)
 	RemovePartition(ctx context.Context, collectionName string, partitionName string)
@@ -70,7 +70,7 @@ type collectionInfo struct {
 	collID              typeutil.UniqueID
 	schema              *schemapb.CollectionSchema
 	partInfo            map[string]*partitionInfo
-	shardLeaders        []*querypb.ShardLeadersList
+	shardLeaders        map[string][]queryNode
 	createdTimestamp    uint64
 	createdUtcTimestamp uint64
 }
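The collectionInfo cache thus switches from QueryCoord's flat []*querypb.ShardLeadersList to a map keyed by DML channel name, one replica list per shard. A hedged illustration of the new shape, reusing the queryNode struct this diff introduces (channel names and addresses invented for the example):

shardLeaders := map[string][]queryNode{
	"by-dev-dml_0": {
		{nodeID: 1, address: "querynode-1:21123"},
		{nodeID: 2, address: "querynode-2:21123"},
	},
	"by-dev-dml_1": {
		{nodeID: 2, address: "querynode-2:21123"},
		{nodeID: 3, address: "querynode-3:21123"},
	},
}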
@@ -528,7 +528,7 @@ func (m *MetaCache) GetCredUsernames(ctx context.Context) ([]string, error) {
 }
 
 // GetShards update cache if withCache == false
-func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error) {
+func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
 	info, err := m.GetCollectionInfo(ctx, collectionName)
 	if err != nil {
 		return nil, err
@@ -536,7 +536,12 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
 
 	if withCache {
 		if len(info.shardLeaders) > 0 {
-			return info.shardLeaders, nil
+			shards := updateShardsWithRoundRobin(info.shardLeaders)
+
+			m.mu.Lock()
+			m.collInfo[collectionName].shardLeaders = shards
+			m.mu.Unlock()
+			return shards, nil
 		}
 		log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
 			zap.String("collectionName", collectionName))
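Note the design point on the cache-hit path: rotating on read turns the shard cache into shared mutable state, which is why the write-back above happens under m.mu. A self-contained sketch of that read-rotate-publish pattern (simplified, and with the rotation itself also moved under the lock):

package main

import (
	"fmt"
	"sync"
)

type shardCache struct {
	mu      sync.Mutex
	leaders []int64 // node IDs, head is the preferred replica
}

// next returns the current head and rotates the list, so the
// following caller is steered to a different replica.
func (c *shardCache) next() int64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	head := c.leaders[0]
	c.leaders = append(c.leaders[1:], c.leaders[0])
	return head
}

func main() {
	c := &shardCache{leaders: []int64{1, 2, 3}}
	fmt.Println(c.next(), c.next(), c.next(), c.next()) // 1 2 3 1
}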
@@ -557,7 +562,9 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
 		return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
 	}
 
-	shards := resp.GetShards()
+	shards := parseShardLeaderList2QueryNode(resp.GetShards())
+
+	shards = updateShardsWithRoundRobin(shards)
 
 	m.mu.Lock()
 	m.collInfo[collectionName].shardLeaders = shards
@@ -566,6 +573,22 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) (map[string][]queryNode, error) {
 	return shards, nil
 }
 
+func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]queryNode {
+	shard2QueryNodes := make(map[string][]queryNode)
+
+	for _, leaders := range shardsLeaders {
+		qns := make([]queryNode, len(leaders.GetNodeIds()))
+
+		for j := range qns {
+			qns[j] = queryNode{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j]}
+		}
+
+		shard2QueryNodes[leaders.GetChannelName()] = qns
+	}
+
+	return shard2QueryNodes
+}
+
 // ClearShards clear the shard leader cache of a collection
 func (m *MetaCache) ClearShards(collectionName string) {
 	log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName))
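parseShardLeaderList2QueryNode is a straight conversion from QueryCoord's response into the channel-keyed map. A usage sketch with invented values (the struct fields match the querypb.ShardLeadersList accessors used above):

leaders := []*querypb.ShardLeadersList{
	{ChannelName: "channel-1", NodeIds: []int64{1, 2}, NodeAddrs: []string{"addr1", "addr2"}},
}
shard2Nodes := parseShardLeaderList2QueryNode(leaders)
// shard2Nodes["channel-1"] == []queryNode{{1, "addr1"}, {2, "addr2"}}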
@@ -344,8 +344,8 @@ func TestMetaCache_GetShards(t *testing.T) {
 		assert.NoError(t, err)
 		assert.NotEmpty(t, shards)
 		assert.Equal(t, 1, len(shards))
-		assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		assert.Equal(t, 3, len(shards[0].GetNodeIds()))
+		assert.Equal(t, 3, len(shards["channel-1"]))
 
 		// get from cache
 		qc.validShardLeaders = false
@@ -353,8 +353,7 @@ func TestMetaCache_GetShards(t *testing.T) {
 		assert.NoError(t, err)
 		assert.NotEmpty(t, shards)
 		assert.Equal(t, 1, len(shards))
-		assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-		assert.Equal(t, 3, len(shards[0].GetNodeIds()))
+		assert.Equal(t, 3, len(shards["channel-1"]))
 	})
 }
 
@@ -387,8 +386,7 @@ func TestMetaCache_ClearShards(t *testing.T) {
 	require.NoError(t, err)
 	require.NotEmpty(t, shards)
 	require.Equal(t, 1, len(shards))
-	require.Equal(t, 3, len(shards[0].GetNodeAddrs()))
-	require.Equal(t, 3, len(shards[0].GetNodeIds()))
+	require.Equal(t, 3, len(shards["channel-1"]))
 
 	globalMetaCache.ClearShards(collectionName)
 
@@ -3,11 +3,11 @@ package proxy
 import (
 	"context"
 	"errors"
+	"fmt"
 
 	qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
 
 	"github.com/milvus-io/milvus/internal/log"
-	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/types"
 
 	"go.uber.org/zap"
@@ -15,7 +15,7 @@ import (
 
 type getQueryNodePolicy func(context.Context, string) (types.QueryNode, error)
 
-type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error
+type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error
 
 // TODO add another policy to enbale the use of cache
 // defaultGetQueryNodePolicy creates QueryNode client for every address everytime
@@ -40,23 +40,45 @@ var (
 	errInvalidShardLeaders = errors.New("Invalid shard leader")
 )
 
-func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error {
+type queryNode struct {
+	nodeID  UniqueID
+	address string
+}
+
+func (q queryNode) String() string {
+	return fmt.Sprintf("<NodeID: %d>", q.nodeID)
+}
+
+func updateShardsWithRoundRobin(shardsLeaders map[string][]queryNode) map[string][]queryNode {
+	for channelID, leaders := range shardsLeaders {
+		if len(leaders) <= 1 {
+			continue
+		}
+
+		shardsLeaders[channelID] = append(leaders[1:], leaders[0])
+	}
+
+	return shardsLeaders
+}
+
+func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders []queryNode) error {
 	var (
 		err     = errBegin
 		current = 0
 		qn      types.QueryNode
 	)
-	replicaNum := len(leaders.GetNodeIds())
+	replicaNum := len(leaders)
 
 	for err != nil && current < replicaNum {
-		currentID := leaders.GetNodeIds()[current]
+		currentID := leaders[current].nodeID
 		if err != errBegin {
 			log.Warn("retry with another QueryNode",
 				zap.Int("retries numbers", current),
-				zap.String("leader", leaders.GetChannelName()), zap.Int64("nodeID", currentID))
+				zap.Int64("nodeID", currentID))
 		}
 
-		qn, err = getQueryNodePolicy(ctx, leaders.GetNodeAddrs()[current])
+		qn, err = getQueryNodePolicy(ctx, leaders[current].address)
 		if err != nil {
 			log.Warn("fail to get valid QueryNode", zap.Int64("nodeID", currentID),
 				zap.Error(err))
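Because queryNode implements fmt.Stringer with a value receiver, slices of it render compactly in the %v-formatted log fields this diff adds, keeping addresses out of the messages. A quick check of what such a field renders as:

qns := []queryNode{{1, "addr1"}, {2, "addr2"}}
fmt.Println(fmt.Sprintf("%v", qns)) // [<NodeID: 1> <NodeID: 2>]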
@@ -68,7 +90,6 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error {
 		err = query(currentID, qn)
 		if err != nil {
 			log.Warn("fail to Query with shard leader",
-				zap.String("leader", leaders.GetChannelName()),
 				zap.Int64("nodeID", currentID),
 				zap.Error(err))
 		}
@@ -76,9 +97,8 @@ func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error {
 	}
 
 	if current == replicaNum && err != nil {
-		log.Warn("no shard leaders available for channel",
-			zap.String("channel name", leaders.GetChannelName()),
-			zap.Int64s("leaders", leaders.GetNodeIds()), zap.Error(err))
+		log.Warn("no shard leaders available",
+			zap.String("leaders", fmt.Sprintf("%v", leaders)), zap.Error(err))
 		// needs to return the error from query
 		return err
 	}
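The retry loop above walks the (already rotated) replica list: start at leaders[0], and on any error from getQueryNodePolicy or query move to the next node, until one succeeds or the list is exhausted, in which case the last query error is returned. A self-contained simulation of that failover behavior (names are illustrative):

package main

import (
	"errors"
	"fmt"
)

// failover mirrors roundRobinPolicy's loop: try each leader in order,
// stop at the first success, otherwise return the last error.
func failover(leaders []int64, query func(int64) error) error {
	err := errors.New("begin") // stand-in for the errBegin sentinel
	for current := 0; err != nil && current < len(leaders); current++ {
		err = query(leaders[current])
	}
	return err
}

func main() {
	// Node 1 is down; the request fails over to node 2.
	err := failover([]int64{1, 2, 3}, func(id int64) error {
		if id == 1 {
			return errors.New("node 1 unreachable")
		}
		fmt.Println("served by node", id)
		return nil
	})
	fmt.Println("err:", err) // served by node 2, then err: <nil>
}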
@@ -5,11 +5,53 @@ import (
 	"fmt"
 	"testing"
 
-	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/types"
 
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
+	"go.uber.org/zap"
 )
 
+func TestUpdateShardsWithRoundRobin(t *testing.T) {
+	in := map[string][]queryNode{
+		"channel-1": {
+			{1, "addr1"},
+			{2, "addr2"},
+		},
+		"channel-2": {
+			{20, "addr20"},
+			{21, "addr21"},
+		},
+	}
+
+	out := updateShardsWithRoundRobin(in)
+
+	assert.Equal(t, int64(2), out["channel-1"][0].nodeID)
+	assert.Equal(t, "addr2", out["channel-1"][0].address)
+	assert.Equal(t, int64(21), out["channel-2"][0].nodeID)
+	assert.Equal(t, "addr21", out["channel-2"][0].address)
+
+	t.Run("check print", func(t *testing.T) {
+		qns := []queryNode{
+			{1, "addr1"},
+			{2, "addr2"},
+			{20, "addr20"},
+			{21, "addr21"},
+		}
+
+		res := fmt.Sprintf("list: %v", qns)
+
+		log.Debug("Check String func",
+			zap.Any("Any", qns),
+			zap.Any("ok", qns[0]),
+			zap.String("ok2", res),
+		)
+	})
+}
+
 func TestRoundRobinPolicy(t *testing.T) {
 	var (
 		getQueryNodePolicy = mockGetQueryNodePolicy
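Assuming these tests live in the proxy package as the imports suggest, they can be run in isolation with the standard Go tooling (package path hypothetical):

go test -v -run 'TestUpdateShardsWithRoundRobin|TestRoundRobinPolicy' ./internal/proxy/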
@@ -31,11 +73,12 @@ func TestRoundRobinPolicy(t *testing.T) {
 		t.Run(test.description, func(t *testing.T) {
 			query := (&mockQuery{isvalid: false}).query
 
-			leaders := &querypb.ShardLeadersList{
-				ChannelName: t.Name(),
-				NodeIds:     test.leaderIDs,
-				NodeAddrs:   make([]string, len(test.leaderIDs)),
+			leaders := make([]queryNode, 0, len(test.leaderIDs))
+			for _, ID := range test.leaderIDs {
+				leaders = append(leaders, queryNode{ID, "random-addr"})
 			}
 
 			err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 			require.Error(t, err)
 		})
@@ -55,10 +98,10 @@ func TestRoundRobinPolicy(t *testing.T) {
 
 	for _, test := range allPassTests {
 		query := (&mockQuery{isvalid: true}).query
-		leaders := &querypb.ShardLeadersList{
-			ChannelName: t.Name(),
-			NodeIds:     test.leaderIDs,
-			NodeAddrs:   make([]string, len(test.leaderIDs)),
+		leaders := make([]queryNode, 0, len(test.leaderIDs))
+		for _, ID := range test.leaderIDs {
+			leaders = append(leaders, queryNode{ID, "random-addr"})
 		}
 		err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 		require.NoError(t, err)
@@ -77,10 +120,10 @@ func TestRoundRobinPolicy(t *testing.T) {
 
 	for _, test := range passAtLast {
 		query := (&mockQuery{isvalid: true}).query
-		leaders := &querypb.ShardLeadersList{
-			ChannelName: t.Name(),
-			NodeIds:     test.leaderIDs,
-			NodeAddrs:   make([]string, len(test.leaderIDs)),
+		leaders := make([]queryNode, 0, len(test.leaderIDs))
+		for _, ID := range test.leaderIDs {
+			leaders = append(leaders, queryNode{ID, "random-addr"})
 		}
 		err := roundRobinPolicy(ctx, getQueryNodePolicy, query, leaders)
 		require.NoError(t, err)
@@ -244,16 +244,17 @@ func (t *queryTask) Execute(ctx context.Context) error {
 	t.resultBuf = make(chan *internalpb.RetrieveResults, len(shards))
 	t.toReduceResults = make([]*internalpb.RetrieveResults, 0, len(shards))
 	t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
-	for _, shard := range shards {
-		s := shard
+	for channelID, leaders := range shards {
+		channelID := channelID
+		leaders := leaders
 		t.runningGroup.Go(func() error {
 			log.Debug("proxy starting to query one shard",
 				zap.Int64("collectionID", t.CollectionID),
 				zap.String("collection name", t.collectionName),
-				zap.String("shard channel", s.GetChannelName()),
+				zap.String("shard channel", channelID),
 				zap.Uint64("timeoutTs", t.TimeoutTimestamp))
 
-			err := t.queryShard(t.runningGroupCtx, s)
+			err := t.queryShard(t.runningGroupCtx, leaders, channelID)
 			if err != nil {
 				return err
 			}
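The channelID := channelID and leaders := leaders re-declarations matter here: with Go versions before 1.22, range variables are reused across iterations, so without per-iteration copies every goroutine launched by runningGroup.Go could observe the final channel's values. The shadowing pins each goroutine to its own shard, as in this fragment (g and queryShard stand in for the fields used above):

for channelID, leaders := range shards {
	channelID := channelID // per-iteration copy for the closure below
	leaders := leaders
	g.Go(func() error {
		return queryShard(ctx, leaders, channelID)
	})
}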
@@ -344,12 +345,12 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
 	return nil
 }
 
-func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
+func (t *queryTask) queryShard(ctx context.Context, leaders []queryNode, channelID string) error {
 	query := func(nodeID UniqueID, qn types.QueryNode) error {
 		req := &querypb.QueryRequest{
 			Req:           t.RetrieveRequest,
 			IsShardLeader: true,
-			DmlChannel:    leaders.GetChannelName(),
+			DmlChannel:    channelID,
 		}
 
 		result, err := qn.Query(ctx, req)
@@ -364,14 +365,14 @@ func (t *queryTask) queryShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
 			return fmt.Errorf("fail to Query, QueryNode ID = %d, reason=%s", nodeID, result.GetStatus().GetReason())
 		}
 
-		log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", leaders.GetChannelName()))
+		log.Debug("get query result", zap.Int64("nodeID", nodeID), zap.String("channelID", channelID))
 		t.resultBuf <- result
 		return nil
 	}
 
 	err := t.queryShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, query, leaders)
 	if err != nil {
-		log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders.GetNodeIds()))
+		log.Warn("fail to Query to all shard leaders", zap.Int64("taskID", t.ID()), zap.Any("shard leaders", leaders))
 		return err
 	}
 
@@ -258,27 +258,28 @@ func (t *searchTask) Execute(ctx context.Context) error {
 	defer tr.Elapse("done")
 
 	executeSearch := func(withCache bool) error {
-		shards, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
+		shard2Leaders, err := globalMetaCache.GetShards(ctx, withCache, t.collectionName, t.qc)
 		if err != nil {
 			return err
 		}
 
-		t.resultBuf = make(chan *internalpb.SearchResults, len(shards))
-		t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shards))
+		t.resultBuf = make(chan *internalpb.SearchResults, len(shard2Leaders))
+		t.toReduceResults = make([]*internalpb.SearchResults, 0, len(shard2Leaders))
 		t.runningGroup, t.runningGroupCtx = errgroup.WithContext(ctx)
 
 		// TODO: try to merge rpc send to different shard leaders.
 		// If two shard leader is on the same querynode maybe we should merge request to save rpc
-		for _, shard := range shards {
-			s := shard
+		for channelID, leaders := range shard2Leaders {
+			channelID := channelID
+			leaders := leaders
 			t.runningGroup.Go(func() error {
 				log.Debug("proxy starting to query one shard",
 					zap.Int64("collectionID", t.CollectionID),
 					zap.String("collection name", t.collectionName),
-					zap.String("shard channel", s.GetChannelName()),
+					zap.String("shard channel", channelID),
 					zap.Uint64("timeoutTs", t.TimeoutTimestamp))
 
-				err := t.searchShard(t.runningGroupCtx, s)
+				err := t.searchShard(t.runningGroupCtx, leaders, channelID)
 				if err != nil {
 					return err
 				}
@@ -393,13 +394,13 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
 	return nil
 }
 
-func (t *searchTask) searchShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
+func (t *searchTask) searchShard(ctx context.Context, leaders []queryNode, channelID string) error {
 
 	search := func(nodeID UniqueID, qn types.QueryNode) error {
 		req := &querypb.SearchRequest{
 			Req:           t.SearchRequest,
 			IsShardLeader: true,
-			DmlChannel:    leaders.GetChannelName(),
+			DmlChannel:    channelID,
 		}
 
 		result, err := qn.Search(ctx, req)
@@ -420,7 +421,7 @@ func (t *searchTask) searchShard(ctx context.Context, leaders *querypb.ShardLeadersList) error {
 
 	err := t.searchShardPolicy(t.TraceCtx(), t.getQueryNodePolicy, search, leaders)
 	if err != nil {
-		log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders.GetNodeIds()))
+		log.Warn("fail to search to all shard leaders", zap.Any("shard leaders", leaders))
 		return err
 	}
 