Support configurable policy for query node, Add user level schedule policy (#23718)

Signed-off-by: chyezh <ye.zhen@zilliz.com>
Author: chyezh, 2023-05-16 10:55:22 +08:00 (committed by GitHub)
parent bc86aa666f
commit 20054dc42b
18 changed files with 767 additions and 83 deletions

View File

@ -242,6 +242,18 @@ queryNode:
# if the lag is larger than this config, scheduler will return error without waiting.
# the valid value is [3600, infinite)
maxTimestampLag: 86400
# read task schedule policy: fifo (by default), user-task-polling.
scheduleReadPolicy:
# fifo: a FIFO queue schedules the tasks in arrival order.
# user-task-polling:
# The tasks of each user are polled one by one and scheduled.
# Scheduling is fair at task granularity.
# The policy groups tasks by the authenticated username,
# and an empty username is treated as a single user.
# When there is only one user, the policy decays into FIFO.
name: fifo
# user-task-polling configuration:
taskQueueExpire: 60 # 1 min by default; expiration time of a user's inner task queue after it becomes empty.
grouping:
enabled: true
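
The user-task-polling policy is easiest to picture as round-robin over per-user FIFO queues. The following standalone Go sketch is not part of this commit; the `task` type and demo users are made up purely to illustrate the fairness described in the comments above: each round takes at most one task per user, so a heavy user cannot starve a light one.

```go
// Standalone sketch (not Milvus code): per-user FIFO queues polled round-robin.
package main

import "fmt"

type task struct {
	user string
	id   int
}

func main() {
	// Per-user FIFO queues: "a" submits many tasks, "b" only a few.
	queues := map[string][]task{
		"a": {{"a", 1}, {"a", 2}, {"a", 3}, {"a", 4}},
		"b": {{"b", 1}, {"b", 2}},
	}
	users := []string{"a", "b"} // polling order

	for remaining := 6; remaining > 0; {
		for _, u := range users {
			q := queues[u]
			if len(q) == 0 {
				continue // this user's queue is drained, skip it
			}
			t := q[0]
			queues[u] = q[1:]
			remaining--
			fmt.Printf("schedule user=%s task=%d\n", t.user, t.id)
		}
	}
	// Output interleaves the users: a1 b1 a2 b2 a3 a4.
}
```

Once `b`'s queue drains, only `a` remains and the schedule degenerates into plain FIFO over `a`'s tasks, which matches the "decays into FIFO" note in the config.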

View File

@ -149,6 +149,20 @@ var (
queryTypeLabelName,
})
QueryNodeSQPerUserLatencyInQueue = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: milvusNamespace,
Subsystem: typeutil.QueryNodeRole,
Name: "sq_queue_user_latency",
Help: "latency per user of search or query in queue",
Buckets: buckets,
}, []string{
nodeIDLabelName,
queryTypeLabelName,
usernameLabelName,
},
)
QueryNodeSQSegmentLatency = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: milvusNamespace,
@ -380,6 +394,7 @@ func RegisterQueryNode(registry *prometheus.Registry) {
registry.MustRegister(QueryNodeSQCount)
registry.MustRegister(QueryNodeSQReqLatency)
registry.MustRegister(QueryNodeSQLatencyInQueue)
registry.MustRegister(QueryNodeSQPerUserLatencyInQueue)
registry.MustRegister(QueryNodeSQSegmentLatency)
registry.MustRegister(QueryNodeSQSegmentLatencyInCore)
registry.MustRegister(QueryNodeReduceLatency)
@ -422,5 +437,4 @@ func CleanupQueryNodeCollectionMetrics(nodeID int64, collectionID int64) {
collectionIDLabelName: fmt.Sprint(collectionID),
})
}
}

View File

@ -41,7 +41,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
})
assert.NotNil(t, err)
ctx = GetContext(context.Background(), "alice:123456")
ctx = getContextWithAuthorization(context.Background(), "alice:123456")
client := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
mgr := newShardClientMgr()
@ -65,13 +65,13 @@ func TestPrivilegeInterceptor(t *testing.T) {
}, nil
}
_, err = PrivilegeInterceptor(GetContext(context.Background(), "foo:123456"), &milvuspb.LoadCollectionRequest{
_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "foo:123456"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
assert.NotNil(t, err)
_, err = PrivilegeInterceptor(GetContext(context.Background(), "root:123456"), &milvuspb.LoadCollectionRequest{
_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "root:123456"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
@ -103,7 +103,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
})
assert.Nil(t, err)
fooCtx := GetContext(context.Background(), "foo:123456")
fooCtx := getContextWithAuthorization(context.Background(), "foo:123456")
_, err = PrivilegeInterceptor(fooCtx, &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
@ -130,7 +130,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
})
assert.Nil(t, err)
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
@ -148,7 +148,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
go func() {
defer g.Done()
assert.NotPanics(t, func() {
PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
DbName: "db_test",
CollectionName: "col1",
})
@ -161,7 +161,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
getPolicyModel("foo")
})
})
}
func TestResourceGroupPrivilege(t *testing.T) {
@ -173,7 +172,7 @@ func TestResourceGroupPrivilege(t *testing.T) {
_, err := PrivilegeInterceptor(ctx, &milvuspb.ListResourceGroupsRequest{})
assert.NotNil(t, err)
ctx = GetContext(context.Background(), "fooo:123456")
ctx = getContextWithAuthorization(context.Background(), "fooo:123456")
client := &MockRootCoordClientInterface{}
queryCoord := &MockQueryCoordClientInterface{}
mgr := newShardClientMgr()
@ -198,29 +197,28 @@ func TestResourceGroupPrivilege(t *testing.T) {
}
InitMetaCache(ctx, client, queryCoord, mgr)
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.CreateResourceGroupRequest{
_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.CreateResourceGroupRequest{
ResourceGroup: "rg",
})
assert.Nil(t, err)
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.DropResourceGroupRequest{
_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.DropResourceGroupRequest{
ResourceGroup: "rg",
})
assert.Nil(t, err)
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.DescribeResourceGroupRequest{
_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.DescribeResourceGroupRequest{
ResourceGroup: "rg",
})
assert.Nil(t, err)
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.ListResourceGroupsRequest{})
_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.ListResourceGroupsRequest{})
assert.Nil(t, err)
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.TransferNodeRequest{})
_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.TransferNodeRequest{})
assert.Nil(t, err)
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.TransferReplicaRequest{})
_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.TransferReplicaRequest{})
assert.Nil(t, err)
})
}

View File

@ -45,7 +45,7 @@ func TestProxyRpcLimit(t *testing.T) {
t.Setenv("ROCKSMQ_PATH", path)
defer os.RemoveAll(path)
ctx := GetContext(context.Background(), "root:123456")
ctx := getContextWithAuthorization(context.Background(), "root:123456")
localMsg := true
factory := dependency.NewDefaultFactory(localMsg)

View File

@ -353,12 +353,12 @@ func (s *proxyTestServer) GetStatisticsChannel(ctx context.Context, request *int
func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p paramtable.GrpcServerConfig) {
defer wg.Done()
var kaep = keepalive.EnforcementPolicy{
kaep := keepalive.EnforcementPolicy{
MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection
PermitWithoutStream: true, // Allow pings even when there are no active streams
}
var kasp = keepalive.ServerParameters{
kasp := keepalive.ServerParameters{
Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active
Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead
}
@ -425,7 +425,7 @@ func TestProxy(t *testing.T) {
defer os.RemoveAll(path)
ctx, cancel := context.WithCancel(context.Background())
ctx = GetContext(ctx, "root:123456")
ctx = getContextWithAuthorization(ctx, "root:123456")
localMsg := true
Params.InitOnce()
factory := dependency.MockDefaultFactory(localMsg, &Params)
@ -758,7 +758,6 @@ func TestProxy(t *testing.T) {
resp, err = proxy.CreateCollection(ctx, reqInvalidField)
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
wg.Add(1)
@ -842,7 +841,6 @@ func TestProxy(t *testing.T) {
DbName: dbName,
CollectionName: collectionName,
})
})
wg.Add(1)
@ -1423,7 +1421,7 @@ func TestProxy(t *testing.T) {
},
}
//resp, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{
// resp, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{
_, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{
Base: nil,
OpLeft: opLeft,
@ -2109,7 +2107,7 @@ func TestProxy(t *testing.T) {
t.Run("credential UPDATE api", func(t *testing.T) {
defer wg.Done()
rootCtx := ctx
fooCtx := GetContext(context.Background(), "foo:123456")
fooCtx := getContextWithAuthorization(context.Background(), "foo:123456")
ctx = fooCtx
originUsers := Params.CommonCfg.SuperUsers
Params.CommonCfg.SuperUsers = []string{"root"}
@ -3863,7 +3861,6 @@ func TestProxy_ListImportTasks(t *testing.T) {
}
func TestProxy_GetStatistics(t *testing.T) {
}
func TestProxy_GetLoadState(t *testing.T) {

View File

@ -21,6 +21,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/contextutil"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/timerecord"
@ -229,7 +230,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
log.Ctx(ctx).Debug("Validate partition names.",
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))
//fetch search_growing from search param
// fetch search_growing from search param
var ignoreGrowing bool
for i, kv := range t.request.GetQueryParams() {
if kv.GetKey() == IgnoreGrowingKey {
@ -358,6 +359,11 @@ func (t *queryTask) Execute(ctx context.Context) error {
defer tr.CtxElapse(ctx, "done")
log := log.Ctx(ctx)
// Add user name into context if it exists.
if username, _ := GetCurUserFromContext(ctx); username != "" {
ctx = contextutil.WithUserInGrpcMetadata(ctx, username)
}
executeQuery := func() error {
shards, err := globalMetaCache.GetShards(ctx, true, t.collectionName)
if err != nil {

View File

@ -23,6 +23,7 @@ import (
"github.com/milvus-io/milvus/internal/querynode"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/contextutil"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/retry"
@ -255,7 +256,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
log.Ctx(ctx).Debug("translate output fields", zap.Int64("msgID", t.ID()),
zap.Strings("output fields", t.request.GetOutputFields()))
//fetch search_growing from search param
// fetch search_growing from search param
var ignoreGrowing bool
for i, kv := range t.request.GetSearchParams() {
if kv.GetKey() == IgnoreGrowingKey {
@ -380,6 +381,11 @@ func (t *searchTask) Execute(ctx context.Context) error {
defer tr.CtxElapse(ctx, "done")
log := log.Ctx(ctx)
// Add user name into context if it exists.
if username, _ := GetCurUserFromContext(ctx); username != "" {
ctx = contextutil.WithUserInGrpcMetadata(ctx, username)
}
executeSearch := func() error {
shard2Leaders, err := globalMetaCache.GetShards(ctx, true, t.collectionName)
if err != nil {
@ -723,7 +729,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
return ret, err
}
//printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
}
var (

View File

@ -711,7 +711,7 @@ func TestIsDefaultRole(t *testing.T) {
assert.Equal(t, false, IsDefaultRole("manager"))
}
func GetContext(ctx context.Context, originValue string) context.Context {
func getContextWithAuthorization(ctx context.Context, originValue string) context.Context {
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
contextMap := map[string]string{
@ -740,12 +740,12 @@ func TestGetCurUserFromContext(t *testing.T) {
_, err = GetCurUserFromContext(metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{})))
assert.NotNil(t, err)
_, err = GetCurUserFromContext(GetContext(context.Background(), "123456"))
_, err = GetCurUserFromContext(getContextWithAuthorization(context.Background(), "123456"))
assert.NotNil(t, err)
root := "root"
password := "123456"
username, err := GetCurUserFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password)))
username, err := GetCurUserFromContext(getContextWithAuthorization(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password)))
assert.Nil(t, err)
assert.Equal(t, "root", username)
}

View File

@ -26,6 +26,7 @@ import (
"time"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/contextutil"
"github.com/milvus-io/milvus/internal/util/errorutil"
"github.com/golang/protobuf/proto"
@ -875,13 +876,19 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
metrics.SearchLabel).Observe(float64(historicalTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.SearchLabel).Observe(float64(historicalTask.reduceDur.Milliseconds()))
// In queue latency per user.
metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues(
fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.SearchLabel,
contextutil.GetUserFromGrpcMetadata(historicalTask.Ctx()),
).Observe(float64(historicalTask.queueDur.Milliseconds()))
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
return historicalTask.Ret, nil
}
//from Proxy
// from Proxy
tr := timerecord.NewTimeRecorder("SearchShard")
log.Ctx(ctx).Debug("start do search",
zap.String("vChannel", dmlChannel),
@ -900,6 +907,8 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
errCluster error
)
defer cancel()
// Passthrough username across grpc.
searchCtx = contextutil.PassthroughUserInGrpcMetadata(searchCtx)
// optimize search params
if req.Req.GetSerializedExprPlan() != nil && node.queryHook != nil {
@ -1046,6 +1055,12 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
metrics.QueryLabel).Observe(float64(queryTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.QueryLabel).Observe(float64(queryTask.reduceDur.Milliseconds()))
// In queue latency per user.
metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues(
fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.QueryLabel,
contextutil.GetUserFromGrpcMetadata(queryTask.Ctx()),
).Observe(float64(queryTask.queueDur.Milliseconds()))
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()
@ -1063,6 +1078,9 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
queryCtx, cancel := context.WithCancel(ctx)
defer cancel()
// Passthrough username across grpc.
queryCtx = contextutil.PassthroughUserInGrpcMetadata(queryCtx)
var results []*internalpb.RetrieveResults
var errCluster error
var withStreamingFunc queryWithStreaming

View File

@ -2,23 +2,313 @@ package querynode
import (
"container/list"
"container/ring"
"time"
"github.com/milvus-io/milvus/internal/util/contextutil"
)
type scheduleReadTaskPolicy func(sqTasks *list.List, targetUsage int32, maxNum int32) ([]readTask, int32)
const (
scheduleReadPolicyNameFIFO = "fifo"
scheduleReadPolicyNameUserTaskPolling = "user-task-polling"
)
func defaultScheduleReadPolicy(sqTasks *list.List, targetUsage int32, maxNum int32) ([]readTask, int32) {
var (
_ scheduleReadPolicy = &fifoScheduleReadPolicy{}
_ scheduleReadPolicy = &userTaskPollingScheduleReadPolicy{}
)
type scheduleReadPolicy interface {
// Number of read tasks currently held by the policy.
len() int
// Add a new task into the ready tasks.
addTask(rt readTask)
// Merge a new task into an existing ready task.
// Return true if the merge succeeds.
mergeTask(rt readTask) bool
// Schedule ready tasks to run, limited by the CPU budget and the max task number.
schedule(targetCPUUsage int32, maxNum int32) ([]readTask, int32)
}
// Create a new schedule policy.
func newReadScheduleTaskPolicy(policyName string) scheduleReadPolicy {
switch policyName {
case "":
fallthrough
case scheduleReadPolicyNameFIFO:
return newFIFOScheduleReadPolicy()
case scheduleReadPolicyNameUserTaskPolling:
return newUserTaskPollingScheduleReadPolicy(
Params.QueryNodeCfg.ScheduleReadPolicy.UserTaskPolling.TaskQueueExpire)
default:
panic("invalid read schedule task policy")
}
}
// Create a new user based task queue.
func newUserBasedTaskQueue(username string) *userBasedTaskQueue {
return &userBasedTaskQueue{
cleanupTimestamp: time.Now(),
queue: list.New(),
username: username,
}
}
// User based task queue.
type userBasedTaskQueue struct {
cleanupTimestamp time.Time
queue *list.List
username string
}
// Get the number of tasks in the user queue.
func (q *userBasedTaskQueue) len() int {
return q.queue.Len()
}
// Add a new task to end of queue.
func (q *userBasedTaskQueue) push(t readTask) {
q.queue.PushBack(t)
}
// Get the first task from queue.
func (q *userBasedTaskQueue) front() readTask {
element := q.queue.Front()
if element == nil {
return nil
}
return element.Value.(readTask)
}
// Remove the first task from queue.
func (q *userBasedTaskQueue) pop() {
front := q.queue.Front()
if front != nil {
q.queue.Remove(front)
}
if q.queue.Len() == 0 {
q.cleanupTimestamp = time.Now()
}
}
// Return true if the user task queue is empty and has stayed empty for at least duration d.
func (q *userBasedTaskQueue) expire(d time.Duration) bool {
if q.queue.Len() != 0 {
return false
}
if time.Since(q.cleanupTimestamp) > d {
return true
}
return false
}
// Try to merge a new task into an existing task in the queue.
func (q *userBasedTaskQueue) mergeTask(rt readTask) bool {
for element := q.queue.Back(); element != nil; element = element.Prev() {
task := element.Value.(readTask)
if task.CanMergeWith(rt) {
task.Merge(rt)
return true
}
}
return false
}
// Implement the user-based (user-task-polling) schedule read policy.
type userTaskPollingScheduleReadPolicy struct {
taskCount int // task count in the policy.
route map[string]*ring.Ring // maps username to its node of the task ring.
checkpoint *ring.Ring // ring node to resume scheduling from; each node's Value is a *userBasedTaskQueue.
taskQueueExpire time.Duration
}
// Create a new user-based schedule read policy.
func newUserTaskPollingScheduleReadPolicy(taskQueueExpire time.Duration) scheduleReadPolicy {
return &userTaskPollingScheduleReadPolicy{
taskCount: 0,
route: make(map[string]*ring.Ring),
checkpoint: nil,
taskQueueExpire: taskQueueExpire,
}
}
// Get length of task queue.
func (p *userTaskPollingScheduleReadPolicy) len() int {
return p.taskCount
}
// Add a new task into ready task queue.
func (p *userTaskPollingScheduleReadPolicy) addTask(rt readTask) {
username := contextutil.GetUserFromGrpcMetadata(rt.Ctx()) // tasks with an empty username share a single queue.
if r, ok := p.route[username]; ok {
// Add new task to the back of queue if queue exist.
r.Value.(*userBasedTaskQueue).push(rt)
} else {
// Create a new list, and add it to the route and queues.
newQueue := newUserBasedTaskQueue(username)
newQueue.push(rt)
newRing := ring.New(1)
newRing.Value = newQueue
p.route[username] = newRing
if p.checkpoint == nil {
// Use the new ring as the checkpoint if none exists yet.
p.checkpoint = newRing
} else {
// Add the new ring before the checkpoint.
p.checkpoint.Prev().Link(newRing)
}
}
p.taskCount++
}
// Merge a new task into an existing ready task.
func (p *userTaskPollingScheduleReadPolicy) mergeTask(rt readTask) bool {
if p.taskCount == 0 {
return false
}
username := contextutil.GetUserFromGrpcMetadata(rt.Ctx()) // tasks with an empty username share a single queue.
// Try to merge into a task of the same user first.
if r, ok := p.route[username]; ok {
// Try to merge task into queue.
if r.Value.(*userBasedTaskQueue).mergeTask(rt) {
return true
}
}
// Otherwise, try other users' queues, walking backward from the checkpoint.
node := p.checkpoint.Prev()
queuesLen := p.checkpoint.Len()
for i := 0; i < queuesLen; i++ {
prev := node.Prev()
queue := node.Value.(*userBasedTaskQueue)
if queue.len() == 0 || queue.username == username {
continue
}
if queue.mergeTask(rt) {
return true
}
node = prev
}
return false
}
func (p *userTaskPollingScheduleReadPolicy) schedule(targetCPUUsage int32, maxNum int32) (result []readTask, usage int32) {
// Return directly if there's no task ready.
if p.taskCount == 0 {
return
}
queuesLen := p.checkpoint.Len()
checkpoint := p.checkpoint
// TODO: infinite loop.
L:
for {
readyCount := len(result)
for i := 0; i < queuesLen; i++ {
if len(result) >= int(maxNum) {
break L
}
next := checkpoint.Next()
// Find task in this queue.
taskQueue := checkpoint.Value.(*userBasedTaskQueue)
// empty task queue for this user.
if taskQueue.len() == 0 {
// expire the queue.
if taskQueue.expire(p.taskQueueExpire) {
delete(p.route, taskQueue.username)
if checkpoint.Len() == 1 {
checkpoint = nil
break L
} else {
checkpoint.Prev().Unlink(1)
}
}
checkpoint = next
continue
}
// Peek the first read task of the queue and check whether the CPU budget is enough.
task := taskQueue.front()
tUsage := task.CPUUsage()
if usage+tUsage > targetCPUUsage {
break L
}
// Pop the task and add to schedule list.
usage += tUsage
result = append(result, task)
taskQueue.pop()
p.taskCount--
checkpoint = next
}
// Stop loop if no task is added.
if readyCount == len(result) {
break L
}
}
// Update checkpoint.
p.checkpoint = checkpoint
return result, usage
}
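
The trickiest part of the polling policy is its use of `container/ring`: `addTask` links a new user's queue in just before the checkpoint, and `schedule` unlinks a node whose queue has expired. Below is a minimal standalone sketch of those two ring operations; the string values stand in for `*userBasedTaskQueue` and nothing here is Milvus code.

```go
// r.Link(s) makes s the next node of r, so linking on checkpoint.Prev()
// inserts the new node just before the checkpoint; Prev().Unlink(1)
// removes the current node.
package main

import (
	"container/ring"
	"fmt"
)

func main() {
	// A single-node ring acting as the checkpoint.
	checkpoint := ring.New(1)
	checkpoint.Value = "queue-of-user-a"

	// Insert a new user's queue right before the checkpoint,
	// mirroring p.checkpoint.Prev().Link(newRing) in addTask.
	newNode := ring.New(1)
	newNode.Value = "queue-of-user-b"
	checkpoint.Prev().Link(newNode)
	checkpoint.Do(func(v interface{}) { fmt.Println(v) }) // queue-of-user-a, queue-of-user-b

	// Drop the node at the checkpoint (say its queue expired),
	// mirroring checkpoint.Prev().Unlink(1) in schedule.
	next := checkpoint.Next()
	checkpoint.Prev().Unlink(1)
	checkpoint = next

	fmt.Println(checkpoint.Len(), checkpoint.Value) // 1 queue-of-user-b
}
```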
// Implement default FIFO policy.
type fifoScheduleReadPolicy struct {
ready *list.List // Save ready to run task.
}
// Create a new default schedule read policy
func newFIFOScheduleReadPolicy() scheduleReadPolicy {
return &fifoScheduleReadPolicy{
ready: list.New(),
}
}
// Read task count inside.
func (p *fifoScheduleReadPolicy) len() int {
return p.ready.Len()
}
// Add a new task into ready task.
func (p *fifoScheduleReadPolicy) addTask(rt readTask) {
p.ready.PushBack(rt)
}
// Merge a new task into an existing ready task.
func (p *fifoScheduleReadPolicy) mergeTask(rt readTask) bool {
// Iterate the tasks in the queue in reverse order.
for task := p.ready.Back(); task != nil; task = task.Prev() {
taskExist := task.Value.(readTask)
if taskExist.CanMergeWith(rt) {
taskExist.Merge(rt)
return true
}
}
return false
}
// Schedule ready tasks to run.
func (p *fifoScheduleReadPolicy) schedule(targetCPUUsage int32, maxNum int32) ([]readTask, int32) {
var ret []readTask
usage := int32(0)
var next *list.Element
for e := sqTasks.Front(); e != nil && maxNum > 0; e = next {
for e := p.ready.Front(); e != nil && maxNum > 0; e = next {
next = e.Next()
t, _ := e.Value.(readTask)
tUsage := t.CPUUsage()
if usage+tUsage > targetUsage {
if usage+tUsage > targetCPUUsage {
break
}
usage += tUsage
sqTasks.Remove(e)
p.ready.Remove(e)
rateCol.rtCounter.sub(t, readyQueueType)
ret = append(ret, t)
maxNum--

View File

@ -1,60 +1,296 @@
package querynode
import (
"container/list"
"context"
"fmt"
"math"
"strings"
"testing"
"time"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/crypto"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
)
func TestScheduler_newReadScheduleTaskPolicy(t *testing.T) {
policy := newReadScheduleTaskPolicy(scheduleReadPolicyNameFIFO)
assert.IsType(t, policy, &fifoScheduleReadPolicy{})
policy = newReadScheduleTaskPolicy("")
assert.IsType(t, policy, &fifoScheduleReadPolicy{})
policy = newReadScheduleTaskPolicy(scheduleReadPolicyNameUserTaskPolling)
assert.IsType(t, policy, &userTaskPollingScheduleReadPolicy{})
assert.Panics(t, func() {
newReadScheduleTaskPolicy("other")
})
}
func TestScheduler_defaultScheduleReadPolicy(t *testing.T) {
readyReadTasks := list.New()
for i := 1; i <= 10; i++ {
policy := newFIFOScheduleReadPolicy()
testBasicScheduleReadPolicy(t, policy)
for i := 1; i <= 100; i++ {
t := mockReadTask{
cpuUsage: int32(i * 10),
}
readyReadTasks.PushBack(&t)
policy.addTask(&t)
}
scheduleFunc := defaultScheduleReadPolicy
targetUsage := int32(100)
maxNum := int32(2)
tasks, cur := scheduleFunc(readyReadTasks, targetUsage, maxNum)
tasks, cur := policy.schedule(targetUsage, maxNum)
assert.Equal(t, int32(30), cur)
assert.Equal(t, int32(2), int32(len(tasks)))
targetUsage = 300
maxNum = 0
tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, int32(0), cur)
assert.Equal(t, 0, len(tasks))
targetUsage = 0
maxNum = 0
tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, int32(0), cur)
assert.Equal(t, 0, len(tasks))
targetUsage = 0
maxNum = 300
tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, int32(0), cur)
assert.Equal(t, 0, len(tasks))
actual := int32(180) // sum(3..6) * 10 3 + 4 + 5 + 6
targetUsage = int32(190) // > actual
maxNum = math.MaxInt32
tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, actual, cur)
assert.Equal(t, 4, len(tasks))
actual = 340 // sum(7..10) * 10 , 7+ 8 + 9 + 10
targetUsage = 340
maxNum = 4
tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, actual, cur)
assert.Equal(t, 4, len(tasks))
actual = 4995 * 10 // sum(11..100)
targetUsage = actual
maxNum = 90
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, actual, cur)
assert.Equal(t, 90, len(tasks))
assert.Equal(t, 0, policy.len())
}
func TestScheduler_userTaskPollingScheduleReadPolicy(t *testing.T) {
policy := newUserTaskPollingScheduleReadPolicy(time.Minute)
testBasicScheduleReadPolicy(t, policy)
for i := 1; i <= 100; i++ {
policy.addTask(&mockReadTask{
cpuUsage: int32(i * 10),
mockTask: mockTask{
baseTask: baseTask{
ctx: getContextWithAuthorization(context.Background(), fmt.Sprintf("user%d:123456", i%10)),
},
},
})
}
targetUsage := int32(100)
maxNum := int32(2)
tasks, cur := policy.schedule(targetUsage, maxNum)
assert.Equal(t, int32(30), cur)
assert.Equal(t, int32(2), int32(len(tasks)))
assert.Equal(t, 98, policy.len())
targetUsage = 300
maxNum = 0
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, int32(0), cur)
assert.Equal(t, 0, len(tasks))
assert.Equal(t, 98, policy.len())
targetUsage = 0
maxNum = 0
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, int32(0), cur)
assert.Equal(t, 0, len(tasks))
assert.Equal(t, 98, policy.len())
targetUsage = 0
maxNum = 300
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, int32(0), cur)
assert.Equal(t, 0, len(tasks))
assert.Equal(t, 98, policy.len())
actual := int32(180) // sum(3..6) * 10 3 + 4 + 5 + 6
targetUsage = int32(190) // > actual
maxNum = math.MaxInt32
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, actual, cur)
assert.Equal(t, 4, len(tasks))
assert.Equal(t, 94, policy.len())
actual = 340 // sum(7..10) * 10 , 7+ 8 + 9 + 10
targetUsage = 340
maxNum = 4
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, actual, cur)
assert.Equal(t, 4, len(tasks))
assert.Equal(t, 90, policy.len())
actual = 4995 * 10 // sum(11..100)
targetUsage = actual
maxNum = 90
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, actual, cur)
assert.Equal(t, 90, len(tasks))
assert.Equal(t, 0, policy.len())
time.Sleep(time.Minute + time.Second)
policy.addTask(&mockReadTask{
cpuUsage: int32(1),
mockTask: mockTask{
baseTask: baseTask{
ctx: getContextWithAuthorization(context.Background(), fmt.Sprintf("user%d:123456", 11)),
},
},
})
tasks, cur = policy.schedule(targetUsage, maxNum)
assert.Equal(t, int32(1), cur)
assert.Equal(t, 1, len(tasks))
assert.Equal(t, 0, policy.len())
policyInner := policy.(*userTaskPollingScheduleReadPolicy)
assert.Equal(t, 1, len(policyInner.route))
assert.Equal(t, 1, policyInner.checkpoint.Len())
}
func Test_userBasedTaskQueue(t *testing.T) {
n := 50
q := newUserBasedTaskQueue("test_user")
for i := 1; i <= n; i++ {
q.push(&mockReadTask{
cpuUsage: int32(i * 10),
mockTask: mockTask{
baseTask: baseTask{
ctx: getContextWithAuthorization(context.Background(), "default:123456"),
},
},
})
assert.Equal(t, q.len(), i)
assert.Equal(t, q.expire(time.Second), false)
}
for i := 0; i < n; i++ {
q.pop()
assert.Equal(t, q.len(), n-(i+1))
assert.Equal(t, q.expire(time.Second), false)
}
time.Sleep(time.Second)
assert.Equal(t, q.expire(time.Second), true)
}
func testBasicScheduleReadPolicy(t *testing.T, policy scheduleReadPolicy) {
// test, push and schedule.
for i := 1; i <= 50; i++ {
cpuUsage := int32(i * 10)
id := 1
policy.addTask(&mockReadTask{
cpuUsage: cpuUsage,
mockTask: mockTask{
baseTask: baseTask{
ctx: getContextWithAuthorization(context.Background(), "default:123456"),
id: UniqueID(id),
},
},
})
assert.Equal(t, policy.len(), 1)
task, cost := policy.schedule(cpuUsage, 1)
assert.Equal(t, cost, cpuUsage)
assert.Equal(t, len(task), 1)
assert.Equal(t, task[0].ID(), int64(id))
}
// test, can not merge and schedule.
cpuUsage := int32(100)
notMergeTask := &mockReadTask{
cpuUsage: cpuUsage,
canMerge: false,
mockTask: mockTask{
baseTask: baseTask{
ctx: getContextWithAuthorization(context.Background(), "default:123456"),
},
},
}
assert.False(t, policy.mergeTask(notMergeTask))
policy.addTask(notMergeTask)
assert.Equal(t, policy.len(), 1)
task2 := &mockReadTask{
cpuUsage: cpuUsage,
canMerge: false,
mockTask: mockTask{
baseTask: baseTask{
ctx: getContextWithAuthorization(context.Background(), "default:123456"),
},
},
}
assert.False(t, policy.mergeTask(task2))
assert.Equal(t, policy.len(), 1)
policy.addTask(notMergeTask)
assert.Equal(t, policy.len(), 2)
task, cost := policy.schedule(2*cpuUsage, 1)
assert.Equal(t, cost, cpuUsage)
assert.Equal(t, len(task), 1)
assert.Equal(t, policy.len(), 1)
assert.False(t, task[0].(*mockReadTask).merged)
task, cost = policy.schedule(2*cpuUsage, 1)
assert.Equal(t, cost, cpuUsage)
assert.Equal(t, len(task), 1)
assert.Equal(t, policy.len(), 0)
assert.False(t, task[0].(*mockReadTask).merged)
// test, can merge and schedule.
mergeTask := &mockReadTask{
cpuUsage: cpuUsage,
canMerge: true,
mockTask: mockTask{
baseTask: baseTask{
ctx: getContextWithAuthorization(context.Background(), "default:123456"),
},
},
}
policy.addTask(mergeTask)
task2 = &mockReadTask{
cpuUsage: cpuUsage,
mockTask: mockTask{
baseTask: baseTask{
ctx: getContextWithAuthorization(context.Background(), "default:123456"),
},
},
}
assert.True(t, policy.mergeTask(task2))
assert.Equal(t, policy.len(), 1)
task, cost = policy.schedule(cpuUsage, 1)
assert.Equal(t, cost, cpuUsage)
assert.Equal(t, len(task), 1)
assert.Equal(t, policy.len(), 0)
assert.True(t, task[0].(*mockReadTask).merged)
}
func getContextWithAuthorization(ctx context.Context, originValue string) context.Context {
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
contextMap := map[string]string{
authKey: authValue,
}
md := metadata.New(contextMap)
return metadata.NewIncomingContext(ctx, md)
}

View File

@ -33,6 +33,7 @@ import (
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/contextutil"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo"
@ -127,9 +128,11 @@ type ShardSegmentDetector interface {
type ShardNodeBuilder func(nodeID int64, addr string) shardQueryNode
// withStreaming function types let search/query detect that the corresponding streaming request is done.
type searchWithStreaming func(ctx context.Context) (error, *internalpb.SearchResults)
type queryWithStreaming func(ctx context.Context) (error, *internalpb.RetrieveResults)
type getStatisticsWithStreaming func(ctx context.Context) (error, *internalpb.GetStatisticsResponse)
type (
searchWithStreaming func(ctx context.Context) (error, *internalpb.SearchResults)
queryWithStreaming func(ctx context.Context) (error, *internalpb.RetrieveResults)
getStatisticsWithStreaming func(ctx context.Context) (error, *internalpb.GetStatisticsResponse)
)
// ShardCluster maintains the ShardCluster information and perform shard level operations
type ShardCluster struct {
@ -156,7 +159,8 @@ type ShardCluster struct {
// NewShardCluster create a ShardCluster with provided information.
func NewShardCluster(collectionID int64, replicaID int64, vchannelName string, version int64,
nodeDetector ShardNodeDetector, segmentDetector ShardSegmentDetector, nodeBuilder ShardNodeBuilder) *ShardCluster {
nodeDetector ShardNodeDetector, segmentDetector ShardSegmentDetector, nodeBuilder ShardNodeBuilder,
) *ShardCluster {
sc := &ShardCluster{
state: atomic.NewInt32(int32(unavailable)),
@ -933,6 +937,13 @@ func getSearchWithStreamingFunc(searchCtx context.Context, req *querypb.SearchRe
metrics.SearchLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.SearchLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
// In queue latency per user.
metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues(
fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.SearchLabel,
contextutil.GetUserFromGrpcMetadata(streamingTask.Ctx()),
).Observe(float64(streamingTask.queueDur.Milliseconds()))
return nil, streamingTask.Ret
}
}
@ -943,7 +954,6 @@ func getQueryWithStreamingFunc(queryCtx context.Context, req *querypb.QueryReque
streamingTask.DataScope = querypb.DataScope_Streaming
streamingTask.QS = qs
err := node.scheduler.AddReadTask(queryCtx, streamingTask)
if err != nil {
return err, nil
}
@ -955,6 +965,13 @@ func getQueryWithStreamingFunc(queryCtx context.Context, req *querypb.QueryReque
metrics.QueryLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
metrics.QueryLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
// In queue latency per user.
metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues(
fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
metrics.QueryLabel,
contextutil.GetUserFromGrpcMetadata(streamingTask.Ctx()),
).Observe(float64(streamingTask.queueDur.Milliseconds()))
return nil, streamingTask.Ret
}
}

View File

@ -40,7 +40,6 @@ type taskScheduler struct {
// for search and query start
unsolvedReadTasks *list.List
readyReadTasks *list.List
receiveReadTaskChan chan readTask
executeReadTaskChan chan readTask
@ -50,7 +49,8 @@ type taskScheduler struct {
// tSafeReplica
tSafeReplica TSafeReplicaInterface
schedule scheduleReadTaskPolicy
// schedule policy for ready read task.
readySchedulePolicy scheduleReadPolicy
// for search and query end
cpuUsage int32 // 1200 means 1200% 12 cores
@ -77,13 +77,12 @@ func newTaskScheduler(ctx context.Context, tSafeReplica TSafeReplicaInterface) *
ctx: ctx1,
cancel: cancel,
unsolvedReadTasks: list.New(),
readyReadTasks: list.New(),
receiveReadTaskChan: make(chan readTask, Params.QueryNodeCfg.MaxReceiveChanSize),
executeReadTaskChan: make(chan readTask, maxExecuteReadChanLen),
notifyChan: make(chan struct{}, 1),
tSafeReplica: tSafeReplica,
maxCPUUsage: int32(getNumCPU() * 100),
schedule: defaultScheduleReadPolicy,
readySchedulePolicy: newReadScheduleTaskPolicy(Params.QueryNodeCfg.ScheduleReadPolicy.Name),
}
s.queue = newQueryNodeTaskQueue(s)
return s
@ -248,7 +247,7 @@ func (s *taskScheduler) AddReadTask(ctx context.Context, t readTask) error {
func (s *taskScheduler) popAndAddToExecute() {
readConcurrency := atomic.LoadInt32(&s.readConcurrency)
metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(readConcurrency))
if s.readyReadTasks.Len() == 0 {
if s.readySchedulePolicy.len() == 0 {
return
}
curUsage := atomic.LoadInt32(&s.cpuUsage)
@ -266,7 +265,7 @@ func (s *taskScheduler) popAndAddToExecute() {
return
}
tasks, deltaUsage := s.schedule(s.readyReadTasks, targetUsage, remain)
tasks, deltaUsage := s.readySchedulePolicy.schedule(targetUsage, remain)
atomic.AddInt32(&s.cpuUsage, deltaUsage)
for _, t := range tasks {
s.executeReadTaskChan <- t
@ -366,23 +365,12 @@ func (s *taskScheduler) tryMergeReadTasks() {
}
if ready {
if !Params.QueryNodeCfg.GroupEnabled {
s.readyReadTasks.PushBack(t)
s.readySchedulePolicy.addTask(t)
rateCol.rtCounter.add(t, readyQueueType)
} else {
merged := false
for m := s.readyReadTasks.Back(); m != nil; m = m.Prev() {
mTask, ok := m.Value.(readTask)
if !ok {
continue
}
if mTask.CanMergeWith(t) {
mTask.Merge(t)
merged = true
break
}
}
if !merged {
s.readyReadTasks.PushBack(t)
// Try to merge task first, otherwise add it.
if !s.readySchedulePolicy.mergeTask(t) {
s.readySchedulePolicy.addTask(t)
rateCol.rtCounter.add(t, readyQueueType)
}
}
@ -391,5 +379,5 @@ func (s *taskScheduler) tryMergeReadTasks() {
}
}
metrics.QueryNodeReadTaskUnsolveLen.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(s.unsolvedReadTasks.Len()))
metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(s.readyReadTasks.Len()))
metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(s.readySchedulePolicy.len()))
}
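
The scheduler now only talks to the `scheduleReadPolicy` interface: `tryMergeReadTasks` first offers a task to `mergeTask` and falls back to `addTask`, and `popAndAddToExecute` later drains tasks via `schedule` under a CPU budget. A toy standalone sketch of that contract follows; the simplified `task`, `policy`, and `fifo` types here are stand-ins, not the Milvus ones.

```go
// Toy sketch of the scheduler/policy contract: merge first, otherwise add,
// then drain under a CPU budget.
package main

import "fmt"

type task struct {
	cpu      int32
	mergedIn int // how many other tasks were folded into this one
}

type policy interface {
	addTask(t *task)
	mergeTask(t *task) bool
	schedule(cpuBudget int32, maxNum int) ([]*task, int32)
}

// fifo: a single queue that merges every new task into the most recent one.
type fifo struct{ ready []*task }

func (p *fifo) addTask(t *task) { p.ready = append(p.ready, t) }

func (p *fifo) mergeTask(t *task) bool {
	if len(p.ready) == 0 {
		return false
	}
	p.ready[len(p.ready)-1].mergedIn++ // toy merge rule
	return true
}

func (p *fifo) schedule(cpuBudget int32, maxNum int) (out []*task, used int32) {
	for len(p.ready) > 0 && len(out) < maxNum {
		t := p.ready[0]
		if used+t.cpu > cpuBudget {
			break
		}
		used += t.cpu
		out = append(out, t)
		p.ready = p.ready[1:]
	}
	return out, used
}

func main() {
	var p policy = &fifo{}
	for i := 0; i < 4; i++ {
		t := &task{cpu: 50}
		if !p.mergeTask(t) { // mirror tryMergeReadTasks: merge first, else add
			p.addTask(t)
		}
	}
	got, used := p.schedule(100, 10)
	fmt.Println(len(got), used, got[0].mergedIn) // 1 50 3
}
```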

View File

@ -71,6 +71,7 @@ type mockReadTask struct {
timeoutError error
step typeutil.TaskStep
readyError error
merged bool
}
func (m *mockReadTask) GetCollectionID() UniqueID {
@ -82,7 +83,7 @@ func (m *mockReadTask) Ready() (bool, error) {
}
func (m *mockReadTask) Merge(o readTask) {
m.merged = true
}
func (m *mockReadTask) CPUUsage() int32 {

View File

@ -38,6 +38,8 @@ const (
HeaderAuthorize = "authorization"
HeaderDBName = "dbName"
// HeaderUser is the key of the username in grpc metadata.
HeaderUser = "user"
// HeaderSourceID identify requests from Milvus members and client requests
HeaderSourceID = "sourceId"
// MemberCredID id for Milvus members (data/index/query node/coord component)

View File

@ -16,7 +16,14 @@
package contextutil
import "context"
import (
"context"
"strings"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/crypto"
"google.golang.org/grpc/metadata"
)
type ctxTenantKey struct{}
@ -37,3 +44,47 @@ func TenantID(ctx context.Context) string {
return ""
}
// Passthrough "user" field in grpc from incoming to outgoing.
func PassthroughUserInGrpcMetadata(ctx context.Context) context.Context {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return ctx
}
user := md[strings.ToLower(util.HeaderUser)]
if len(user) == 0 || len(user[0]) == 0 {
return ctx
}
return metadata.AppendToOutgoingContext(ctx, util.HeaderUser, user[0])
}
// Set "user" field in grpc outgoing metadata
func WithUserInGrpcMetadata(ctx context.Context, user string) context.Context {
return metadata.AppendToOutgoingContext(ctx, util.HeaderUser, crypto.Base64Encode(user))
}
// Get "user" field from grpc metadata, empty string will returned if not set.
func GetUserFromGrpcMetadata(ctx context.Context) string {
if user := getUserFromGrpcMetadataAux(ctx, metadata.FromIncomingContext); user != "" {
return user
}
return getUserFromGrpcMetadataAux(ctx, metadata.FromOutgoingContext)
}
// Aux function for `GetUserFromGrpcMetadata`.
func getUserFromGrpcMetadataAux(ctx context.Context, mdGetter func(ctx context.Context) (metadata.MD, bool)) string {
md, ok := mdGetter(ctx)
if !ok {
return ""
}
user := md[strings.ToLower(util.HeaderUser)]
if len(user) == 0 || len(user[0]) == 0 {
return ""
}
// The value may be duplicated in the metadata, but all copies should be the same.
rawuser, err := crypto.Base64Decode(user[0])
if err != nil {
return ""
}
return rawuser
}
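
To make the propagation concrete, here is a standalone sketch of the round trip: the proxy writes the base64-encoded username into outgoing metadata, the query node re-exports it when it fans out to other nodes, and the per-user metric label decodes it again. It mirrors the helpers above with plain `grpc/metadata` and `encoding/base64` calls rather than importing them; `headerUser` and the user name are invented for the example.

```go
// Standalone sketch of username propagation across one grpc hop.
package main

import (
	"context"
	"encoding/base64"
	"fmt"

	"google.golang.org/grpc/metadata"
)

const headerUser = "user"

func main() {
	// Proxy side: attach the user to outgoing metadata (cf. WithUserInGrpcMetadata).
	ctx := metadata.AppendToOutgoingContext(context.Background(),
		headerUser, base64.StdEncoding.EncodeToString([]byte("alice")))

	// grpc delivers the outgoing metadata as incoming metadata on the server side.
	md, _ := metadata.FromOutgoingContext(ctx)
	serverCtx := metadata.NewIncomingContext(context.Background(), md)

	// Query node side: copy the field to the outgoing context for the next hop
	// (cf. PassthroughUserInGrpcMetadata)...
	if in, ok := metadata.FromIncomingContext(serverCtx); ok {
		if v := in.Get(headerUser); len(v) > 0 && v[0] != "" {
			serverCtx = metadata.AppendToOutgoingContext(serverCtx, headerUser, v[0])
		}
	}

	// ...and decode it wherever the per-user metric label is needed
	// (cf. GetUserFromGrpcMetadata).
	out, _ := metadata.FromOutgoingContext(serverCtx)
	raw, _ := base64.StdEncoding.DecodeString(out.Get(headerUser)[0])
	fmt.Println(string(raw)) // alice
}
```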

View File

@ -0,0 +1,29 @@
package contextutil
import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/crypto"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
)
// Test UserInGrpcMetadata related function
func Test_UserInGrpcMetadata(t *testing.T) {
testUser := "test_user_131"
ctx := context.Background()
assert.Equal(t, GetUserFromGrpcMetadata(ctx), "")
ctx = WithUserInGrpcMetadata(ctx, testUser)
assert.Equal(t, GetUserFromGrpcMetadata(ctx), testUser)
md := metadata.Pairs(util.HeaderUser, crypto.Base64Encode(testUser))
ctx = metadata.NewIncomingContext(context.Background(), md)
assert.Equal(t, getUserFromGrpcMetadataAux(ctx, metadata.FromIncomingContext), testUser)
ctx = PassthroughUserInGrpcMetadata(ctx)
assert.Equal(t, getUserFromGrpcMetadataAux(ctx, metadata.FromOutgoingContext), testUser)
assert.Equal(t, GetUserFromGrpcMetadata(ctx), testUser)
}

View File

@ -34,11 +34,11 @@ const (
// DefaultIndexSliceSize defines the default slice size of index file when serializing.
DefaultIndexSliceSize = 16
DefaultGracefulTime = 5000 //ms
DefaultGracefulTime = 5000 // ms
DefaultGracefulStopTimeout = 30 // s
DefaultThreadCoreCoefficient = 10
DefaultSessionTTL = 20 //s
DefaultSessionTTL = 20 // s
DefaultSessionRetryTimes = 30
DefaultMaxDegree = 56
@ -1117,6 +1117,9 @@ type queryNodeConfig struct {
CPURatio float64
MaxTimestampLag time.Duration
// schedule
ScheduleReadPolicy queryNodeConfigScheduleReadPolicy
GCHelperEnabled bool
MinimumGOGCConfig int
MaximumGOGCConfig int
@ -1124,6 +1127,15 @@ type queryNodeConfig struct {
GracefulStopTimeout int64
}
type queryNodeConfigScheduleReadPolicy struct {
Name string
UserTaskPolling queryNodeConfigScheduleUserTaskPolling
}
type queryNodeConfigScheduleUserTaskPolling struct {
TaskQueueExpire time.Duration
}
func (p *queryNodeConfig) init(base *BaseTable) {
p.Base = base
p.NodeID.Store(UniqueID(0))
@ -1153,6 +1165,9 @@ func (p *queryNodeConfig) init(base *BaseTable) {
p.initDiskCapacity()
p.initMaxDiskUsagePercentage()
// Initialize scheduler.
p.initScheduleReadPolicy()
p.initGCTunerEnbaled()
p.initMaximumGOGC()
p.initMinimumGOGC()
@ -1346,6 +1361,11 @@ func (p *queryNodeConfig) initDiskCapacity() {
p.DiskCapacityLimit = diskSize * 1024 * 1024 * 1024
}
func (p *queryNodeConfig) initScheduleReadPolicy() {
p.ScheduleReadPolicy.Name = p.Base.LoadWithDefault("queryNode.scheduler.scheduleReadPolicy.name", "fifo")
p.ScheduleReadPolicy.UserTaskPolling.TaskQueueExpire = time.Duration(p.Base.ParseInt64WithDefault("queryNode.scheduler.scheduleReadPolicy.taskQueueExpire", 60)) * time.Second
}
func (p *queryNodeConfig) initGCTunerEnbaled() {
p.GCHelperEnabled = p.Base.ParseBool("queryNode.gchelper.enabled", true)
}
@ -1514,7 +1534,6 @@ func (p *dataCoordConfig) initSegmentMinSizeFromIdleToSealed() {
func (p *dataCoordConfig) initSegmentExpansionRate() {
p.SegmentExpansionRate = p.Base.ParseFloatWithDefault("dataCoord.segment.expansionRate", 1.25)
log.Info("init segment expansion rate", zap.Float64("value", p.SegmentExpansionRate))
}
func (p *dataCoordConfig) initSegmentMaxBinlogFileNumber() {
@ -1647,7 +1666,7 @@ type dataNodeConfig struct {
Base *BaseTable
// ID of the current node
//NodeID atomic.Value
// NodeID atomic.Value
NodeID atomic.Value
// IP of the current DataNode
IP string