diff --git a/configs/milvus.yaml b/configs/milvus.yaml index d6608ad5f6..a765003620 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -242,6 +242,18 @@ queryNode: # if the lag is larger than this config, scheduler will return error without waiting. # the valid value is [3600, infinite) maxTimestampLag: 86400 + # read task schedule policy: fifo(by default), user-task-polling. + scheduleReadPolicy: + # fifo: A FIFO queue support the schedule. + # user-task-polling: + # The user's tasks will be polled one by one and scheduled. + # Scheduling is fair on task granularity. + # The policy is based on the username for authentication. + # And an empty username is considered the same user. + # When there are no multi-users, the policy decay into FIFO + name: fifo + # user-task-polling configure: + taskQueueExpire: 60 # 1 min by default, expire time of inner user task queue since queue is empty. grouping: enabled: true diff --git a/internal/metrics/querynode_metrics.go b/internal/metrics/querynode_metrics.go index e842445231..02ab5f5aaf 100644 --- a/internal/metrics/querynode_metrics.go +++ b/internal/metrics/querynode_metrics.go @@ -149,6 +149,20 @@ var ( queryTypeLabelName, }) + QueryNodeSQPerUserLatencyInQueue = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: milvusNamespace, + Subsystem: typeutil.QueryNodeRole, + Name: "sq_queue_user_latency", + Help: "latency per user of search or query in queue", + Buckets: buckets, + }, []string{ + nodeIDLabelName, + queryTypeLabelName, + usernameLabelName, + }, + ) + QueryNodeSQSegmentLatency = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: milvusNamespace, @@ -380,6 +394,7 @@ func RegisterQueryNode(registry *prometheus.Registry) { registry.MustRegister(QueryNodeSQCount) registry.MustRegister(QueryNodeSQReqLatency) registry.MustRegister(QueryNodeSQLatencyInQueue) + registry.MustRegister(QueryNodeSQPerUserLatencyInQueue) registry.MustRegister(QueryNodeSQSegmentLatency) registry.MustRegister(QueryNodeSQSegmentLatencyInCore) registry.MustRegister(QueryNodeReduceLatency) @@ -422,5 +437,4 @@ func CleanupQueryNodeCollectionMetrics(nodeID int64, collectionID int64) { collectionIDLabelName: fmt.Sprint(collectionID), }) } - } diff --git a/internal/proxy/privilege_interceptor_test.go b/internal/proxy/privilege_interceptor_test.go index 24f578f362..98227473bc 100644 --- a/internal/proxy/privilege_interceptor_test.go +++ b/internal/proxy/privilege_interceptor_test.go @@ -41,7 +41,7 @@ func TestPrivilegeInterceptor(t *testing.T) { }) assert.NotNil(t, err) - ctx = GetContext(context.Background(), "alice:123456") + ctx = getContextWithAuthorization(context.Background(), "alice:123456") client := &MockRootCoordClientInterface{} queryCoord := &MockQueryCoordClientInterface{} mgr := newShardClientMgr() @@ -65,13 +65,13 @@ func TestPrivilegeInterceptor(t *testing.T) { }, nil } - _, err = PrivilegeInterceptor(GetContext(context.Background(), "foo:123456"), &milvuspb.LoadCollectionRequest{ + _, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "foo:123456"), &milvuspb.LoadCollectionRequest{ DbName: "db_test", CollectionName: "col1", }) assert.NotNil(t, err) - _, err = PrivilegeInterceptor(GetContext(context.Background(), "root:123456"), &milvuspb.LoadCollectionRequest{ + _, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "root:123456"), &milvuspb.LoadCollectionRequest{ DbName: "db_test", CollectionName: "col1", }) @@ -103,7 +103,7 @@ func TestPrivilegeInterceptor(t 
*testing.T) { }) assert.Nil(t, err) - fooCtx := GetContext(context.Background(), "foo:123456") + fooCtx := getContextWithAuthorization(context.Background(), "foo:123456") _, err = PrivilegeInterceptor(fooCtx, &milvuspb.LoadCollectionRequest{ DbName: "db_test", CollectionName: "col1", @@ -130,7 +130,7 @@ func TestPrivilegeInterceptor(t *testing.T) { }) assert.Nil(t, err) - _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ + _, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ DbName: "db_test", CollectionName: "col1", }) @@ -148,7 +148,7 @@ func TestPrivilegeInterceptor(t *testing.T) { go func() { defer g.Done() assert.NotPanics(t, func() { - PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ + PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{ DbName: "db_test", CollectionName: "col1", }) @@ -161,7 +161,6 @@ func TestPrivilegeInterceptor(t *testing.T) { getPolicyModel("foo") }) }) - } func TestResourceGroupPrivilege(t *testing.T) { @@ -173,7 +172,7 @@ func TestResourceGroupPrivilege(t *testing.T) { _, err := PrivilegeInterceptor(ctx, &milvuspb.ListResourceGroupsRequest{}) assert.NotNil(t, err) - ctx = GetContext(context.Background(), "fooo:123456") + ctx = getContextWithAuthorization(context.Background(), "fooo:123456") client := &MockRootCoordClientInterface{} queryCoord := &MockQueryCoordClientInterface{} mgr := newShardClientMgr() @@ -198,29 +197,28 @@ func TestResourceGroupPrivilege(t *testing.T) { } InitMetaCache(ctx, client, queryCoord, mgr) - _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.CreateResourceGroupRequest{ + _, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.CreateResourceGroupRequest{ ResourceGroup: "rg", }) assert.Nil(t, err) - _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.DropResourceGroupRequest{ + _, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.DropResourceGroupRequest{ ResourceGroup: "rg", }) assert.Nil(t, err) - _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.DescribeResourceGroupRequest{ + _, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.DescribeResourceGroupRequest{ ResourceGroup: "rg", }) assert.Nil(t, err) - _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.ListResourceGroupsRequest{}) + _, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.ListResourceGroupsRequest{}) assert.Nil(t, err) - _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.TransferNodeRequest{}) + _, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.TransferNodeRequest{}) assert.Nil(t, err) - _, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.TransferReplicaRequest{}) + _, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.TransferReplicaRequest{}) assert.Nil(t, err) }) - } diff --git a/internal/proxy/proxy_rpc_test.go b/internal/proxy/proxy_rpc_test.go index 038a605ec1..d2657d7a7b 
100644 --- a/internal/proxy/proxy_rpc_test.go +++ b/internal/proxy/proxy_rpc_test.go @@ -45,7 +45,7 @@ func TestProxyRpcLimit(t *testing.T) { t.Setenv("ROCKSMQ_PATH", path) defer os.RemoveAll(path) - ctx := GetContext(context.Background(), "root:123456") + ctx := getContextWithAuthorization(context.Background(), "root:123456") localMsg := true factory := dependency.NewDefaultFactory(localMsg) diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 6123e0f16a..e232c0e5c6 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -353,12 +353,12 @@ func (s *proxyTestServer) GetStatisticsChannel(ctx context.Context, request *int func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p paramtable.GrpcServerConfig) { defer wg.Done() - var kaep = keepalive.EnforcementPolicy{ + kaep := keepalive.EnforcementPolicy{ MinTime: 5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection PermitWithoutStream: true, // Allow pings even when there are no active streams } - var kasp = keepalive.ServerParameters{ + kasp := keepalive.ServerParameters{ Time: 60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active Timeout: 10 * time.Second, // Wait 10 second for the ping ack before assuming the connection is dead } @@ -425,7 +425,7 @@ func TestProxy(t *testing.T) { defer os.RemoveAll(path) ctx, cancel := context.WithCancel(context.Background()) - ctx = GetContext(ctx, "root:123456") + ctx = getContextWithAuthorization(ctx, "root:123456") localMsg := true Params.InitOnce() factory := dependency.MockDefaultFactory(localMsg, &Params) @@ -758,7 +758,6 @@ func TestProxy(t *testing.T) { resp, err = proxy.CreateCollection(ctx, reqInvalidField) assert.NoError(t, err) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode) - }) wg.Add(1) @@ -842,7 +841,6 @@ func TestProxy(t *testing.T) { DbName: dbName, CollectionName: collectionName, }) - }) wg.Add(1) @@ -1423,7 +1421,7 @@ func TestProxy(t *testing.T) { }, } - //resp, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{ + // resp, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{ _, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{ Base: nil, OpLeft: opLeft, @@ -2109,7 +2107,7 @@ func TestProxy(t *testing.T) { t.Run("credential UPDATE api", func(t *testing.T) { defer wg.Done() rootCtx := ctx - fooCtx := GetContext(context.Background(), "foo:123456") + fooCtx := getContextWithAuthorization(context.Background(), "foo:123456") ctx = fooCtx originUsers := Params.CommonCfg.SuperUsers Params.CommonCfg.SuperUsers = []string{"root"} @@ -3863,7 +3861,6 @@ func TestProxy_ListImportTasks(t *testing.T) { } func TestProxy_GetStatistics(t *testing.T) { - } func TestProxy_GetLoadState(t *testing.T) { diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index c8f74dbc3c..9809fda3e5 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/contextutil" "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/timerecord" @@ -229,7 +230,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { log.Ctx(ctx).Debug("Validate partition names.", 
zap.Int64("msgID", t.ID()), zap.Any("requestType", "query")) - //fetch search_growing from search param + // fetch search_growing from search param var ignoreGrowing bool for i, kv := range t.request.GetQueryParams() { if kv.GetKey() == IgnoreGrowingKey { @@ -358,6 +359,11 @@ func (t *queryTask) Execute(ctx context.Context) error { defer tr.CtxElapse(ctx, "done") log := log.Ctx(ctx) + // Add user name into context if it's exists. + if username, _ := GetCurUserFromContext(ctx); username != "" { + ctx = contextutil.WithUserInGrpcMetadata(ctx, username) + } + executeQuery := func() error { shards, err := globalMetaCache.GetShards(ctx, true, t.collectionName) if err != nil { diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 5668589861..2836330d61 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -23,6 +23,7 @@ import ( "github.com/milvus-io/milvus/internal/querynode" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/commonpbutil" + "github.com/milvus-io/milvus/internal/util/contextutil" "github.com/milvus-io/milvus/internal/util/distance" "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/retry" @@ -255,7 +256,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { log.Ctx(ctx).Debug("translate output fields", zap.Int64("msgID", t.ID()), zap.Strings("output fields", t.request.GetOutputFields())) - //fetch search_growing from search param + // fetch search_growing from search param var ignoreGrowing bool for i, kv := range t.request.GetSearchParams() { if kv.GetKey() == IgnoreGrowingKey { @@ -380,6 +381,11 @@ func (t *searchTask) Execute(ctx context.Context) error { defer tr.CtxElapse(ctx, "done") log := log.Ctx(ctx) + // Add user name into context if it's exists. 
+ if username, _ := GetCurUserFromContext(ctx); username != "" { + ctx = contextutil.WithUserInGrpcMetadata(ctx, username) + } + executeSearch := func() error { shard2Leaders, err := globalMetaCache.GetShards(ctx, true, t.collectionName) if err != nil { @@ -723,7 +729,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb log.Ctx(ctx).Warn("invalid search results", zap.Error(err)) return ret, err } - //printSearchResultData(sData, strconv.FormatInt(int64(i), 10)) + // printSearchResultData(sData, strconv.FormatInt(int64(i), 10)) } var ( diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 5f48b4aca6..9be60bf540 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -711,7 +711,7 @@ func TestIsDefaultRole(t *testing.T) { assert.Equal(t, false, IsDefaultRole("manager")) } -func GetContext(ctx context.Context, originValue string) context.Context { +func getContextWithAuthorization(ctx context.Context, originValue string) context.Context { authKey := strings.ToLower(util.HeaderAuthorize) authValue := crypto.Base64Encode(originValue) contextMap := map[string]string{ @@ -740,12 +740,12 @@ func TestGetCurUserFromContext(t *testing.T) { _, err = GetCurUserFromContext(metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{}))) assert.NotNil(t, err) - _, err = GetCurUserFromContext(GetContext(context.Background(), "123456")) + _, err = GetCurUserFromContext(getContextWithAuthorization(context.Background(), "123456")) assert.NotNil(t, err) root := "root" password := "123456" - username, err := GetCurUserFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password))) + username, err := GetCurUserFromContext(getContextWithAuthorization(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password))) assert.Nil(t, err) assert.Equal(t, "root", username) } diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go index 2d1f62ec0c..63f76d36ae 100644 --- a/internal/querynode/impl.go +++ b/internal/querynode/impl.go @@ -26,6 +26,7 @@ import ( "time" "github.com/milvus-io/milvus/internal/util/commonpbutil" + "github.com/milvus-io/milvus/internal/util/contextutil" "github.com/milvus-io/milvus/internal/util/errorutil" "github.com/golang/protobuf/proto" @@ -875,13 +876,19 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se metrics.SearchLabel).Observe(float64(historicalTask.queueDur.Milliseconds())) metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel).Observe(float64(historicalTask.reduceDur.Milliseconds())) + // In queue latency per user. 
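The new QueryNodeSQPerUserLatencyInQueue histogram keys in-queue latency by a username label, so every distinct authenticated user gets its own series, while unauthenticated requests (empty username) all land in one series. The standalone sketch below is illustrative only, not part of this patch: it uses a throwaway registry and a made-up metric name to show the client_golang labeling pattern the new metric relies on.

```go
// Standalone sketch (not part of this patch) of the labeling pattern behind
// QueryNodeSQPerUserLatencyInQueue: a HistogramVec with a "username" label.
package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
)

func main() {
	reg := prometheus.NewRegistry()
	perUserQueueLatency := prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Name:    "sq_queue_user_latency_demo",
		Help:    "in-queue latency of search/query per user (demo only)",
		Buckets: prometheus.DefBuckets,
	}, []string{"username"})
	reg.MustRegister(perUserQueueLatency)

	// Each distinct username becomes its own series; unauthenticated requests
	// (empty username) all share a single series.
	perUserQueueLatency.WithLabelValues("alice").Observe(12)
	perUserQueueLatency.WithLabelValues("bob").Observe(3)
	perUserQueueLatency.WithLabelValues("").Observe(7)

	families, _ := reg.Gather()
	for _, mf := range families {
		for _, m := range mf.GetMetric() {
			fmt.Println(mf.GetName(), m.GetLabel(), m.GetHistogram().GetSampleCount())
		}
	}
}
```

Label cardinality grows with the number of distinct users, which may be worth keeping in mind for clusters with many credentials.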
+ metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues( + fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), + metrics.SearchLabel, + contextutil.GetUserFromGrpcMetadata(historicalTask.Ctx()), + ).Observe(float64(historicalTask.queueDur.Milliseconds())) latency := tr.ElapseSpan() metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc() return historicalTask.Ret, nil } - //from Proxy + // from Proxy tr := timerecord.NewTimeRecorder("SearchShard") log.Ctx(ctx).Debug("start do search", zap.String("vChannel", dmlChannel), @@ -900,6 +907,8 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se errCluster error ) defer cancel() + // Passthrough username across grpc. + searchCtx = contextutil.PassthroughUserInGrpcMetadata(searchCtx) // optimize search params if req.Req.GetSerializedExprPlan() != nil && node.queryHook != nil { @@ -1046,6 +1055,12 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que metrics.QueryLabel).Observe(float64(queryTask.queueDur.Milliseconds())) metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel).Observe(float64(queryTask.reduceDur.Milliseconds())) + // In queue latency per user. + metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues( + fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), + metrics.QueryLabel, + contextutil.GetUserFromGrpcMetadata(queryTask.Ctx()), + ).Observe(float64(queryTask.queueDur.Milliseconds())) latency := tr.ElapseSpan() metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds())) metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc() @@ -1063,6 +1078,9 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que queryCtx, cancel := context.WithCancel(ctx) defer cancel() + // Passthrough username across grpc. + queryCtx = contextutil.PassthroughUserInGrpcMetadata(queryCtx) + var results []*internalpb.RetrieveResults var errCluster error var withStreamingFunc queryWithStreaming diff --git a/internal/querynode/scheduler_policy.go b/internal/querynode/scheduler_policy.go index a97db3fa5d..71bcc40f3f 100644 --- a/internal/querynode/scheduler_policy.go +++ b/internal/querynode/scheduler_policy.go @@ -2,23 +2,313 @@ package querynode import ( "container/list" + "container/ring" + "time" + + "github.com/milvus-io/milvus/internal/util/contextutil" ) -type scheduleReadTaskPolicy func(sqTasks *list.List, targetUsage int32, maxNum int32) ([]readTask, int32) +const ( + scheduleReadPolicyNameFIFO = "fifo" + scheduleReadPolicyNameUserTaskPolling = "user-task-polling" +) -func defaultScheduleReadPolicy(sqTasks *list.List, targetUsage int32, maxNum int32) ([]readTask, int32) { +var ( + _ scheduleReadPolicy = &fifoScheduleReadPolicy{} + _ scheduleReadPolicy = &userTaskPollingScheduleReadPolicy{} +) + +type scheduleReadPolicy interface { + // Read task count inside. + len() int + + // Add a new task into ready task. + addTask(rt readTask) + + // Merge a new task into exist ready task. + // Return true if merge success. + mergeTask(rt readTask) bool + + // Schedule new task to run. 
+ schedule(targetCPUUsage int32, maxNum int32) ([]readTask, int32) +} + +// Create a new schedule policy. +func newReadScheduleTaskPolicy(policyName string) scheduleReadPolicy { + switch policyName { + case "": + fallthrough + case scheduleReadPolicyNameFIFO: + return newFIFOScheduleReadPolicy() + case scheduleReadPolicyNameUserTaskPolling: + return newUserTaskPollingScheduleReadPolicy( + Params.QueryNodeCfg.ScheduleReadPolicy.UserTaskPolling.TaskQueueExpire) + default: + panic("invalid read schedule task policy") + } +} + +// Create a new user based task queue. +func newUserBasedTaskQueue(username string) *userBasedTaskQueue { + return &userBasedTaskQueue{ + cleanupTimestamp: time.Now(), + queue: list.New(), + username: username, + } +} + +// User based task queue. +type userBasedTaskQueue struct { + cleanupTimestamp time.Time + queue *list.List + username string +} + +// Get length of user task +func (q *userBasedTaskQueue) len() int { + return q.queue.Len() +} + +// Add a new task to end of queue. +func (q *userBasedTaskQueue) push(t readTask) { + q.queue.PushBack(t) +} + +// Get the first task from queue. +func (q *userBasedTaskQueue) front() readTask { + element := q.queue.Front() + if element == nil { + return nil + } + return element.Value.(readTask) +} + +// Remove the first task from queue. +func (q *userBasedTaskQueue) pop() { + front := q.queue.Front() + if front != nil { + q.queue.Remove(front) + } + if q.queue.Len() == 0 { + q.cleanupTimestamp = time.Now() + } +} + +// Return true if user based task is empty and empty for d time. +func (q *userBasedTaskQueue) expire(d time.Duration) bool { + if q.queue.Len() != 0 { + return false + } + if time.Since(q.cleanupTimestamp) > d { + return true + } + return false +} + +// Merge a new task to task in queue. +func (q *userBasedTaskQueue) mergeTask(rt readTask) bool { + for element := q.queue.Back(); element != nil; element = element.Prev() { + task := element.Value.(readTask) + if task.CanMergeWith(rt) { + task.Merge(rt) + return true + } + } + return false +} + +// Implement user based schedule read policy. +type userTaskPollingScheduleReadPolicy struct { + taskCount int // task count in the policy. + route map[string]*ring.Ring // map username to node of task ring. + checkpoint *ring.Ring // last not schedule ring node, ring.Ring[list.List] + taskQueueExpire time.Duration +} + +// Create a new user-based schedule read policy. +func newUserTaskPollingScheduleReadPolicy(taskQueueExpire time.Duration) scheduleReadPolicy { + return &userTaskPollingScheduleReadPolicy{ + taskCount: 0, + route: make(map[string]*ring.Ring), + checkpoint: nil, + taskQueueExpire: taskQueueExpire, + } +} + +// Get length of task queue. +func (p *userTaskPollingScheduleReadPolicy) len() int { + return p.taskCount +} + +// Add a new task into ready task queue. +func (p *userTaskPollingScheduleReadPolicy) addTask(rt readTask) { + username := contextutil.GetUserFromGrpcMetadata(rt.Ctx()) // empty user will compete on single list. + if r, ok := p.route[username]; ok { + // Add new task to the back of queue if queue exist. + r.Value.(*userBasedTaskQueue).push(rt) + } else { + // Create a new list, and add it to the route and queues. + newQueue := newUserBasedTaskQueue(username) + newQueue.push(rt) + newRing := ring.New(1) + newRing.Value = newQueue + p.route[username] = newRing + if p.checkpoint == nil { + // Create new ring if not exist. + p.checkpoint = newRing + } else { + // Add the new ring before the checkpoint. 
+ p.checkpoint.Prev().Link(newRing) + } + } + p.taskCount++ +} + +// Merge a new task into exist ready task. +func (p *userTaskPollingScheduleReadPolicy) mergeTask(rt readTask) bool { + if p.taskCount == 0 { + return false + } + + username := contextutil.GetUserFromGrpcMetadata(rt.Ctx()) // empty user will compete on single list. + // Applied to task with same user first. + if r, ok := p.route[username]; ok { + // Try to merge task into queue. + if r.Value.(*userBasedTaskQueue).mergeTask(rt) { + return true + } + } + + // Try to merge task into other user queue before checkpoint. + node := p.checkpoint.Prev() + queuesLen := p.checkpoint.Len() + for i := 0; i < queuesLen; i++ { + prev := node.Prev() + queue := node.Value.(*userBasedTaskQueue) + if queue.len() == 0 || queue.username == username { + continue + } + if queue.mergeTask(rt) { + return true + } + node = prev + } + + return false +} + +func (p *userTaskPollingScheduleReadPolicy) schedule(targetCPUUsage int32, maxNum int32) (result []readTask, usage int32) { + // Return directly if there's no task ready. + if p.taskCount == 0 { + return + } + + queuesLen := p.checkpoint.Len() + checkpoint := p.checkpoint + // TODO: infinite loop. +L: + for { + readyCount := len(result) + for i := 0; i < queuesLen; i++ { + if len(result) >= int(maxNum) { + break L + } + + next := checkpoint.Next() + // Find task in this queue. + taskQueue := checkpoint.Value.(*userBasedTaskQueue) + + // empty task queue for this user. + if taskQueue.len() == 0 { + // expire the queue. + if taskQueue.expire(p.taskQueueExpire) { + delete(p.route, taskQueue.username) + if checkpoint.Len() == 1 { + checkpoint = nil + break L + } else { + checkpoint.Prev().Unlink(1) + } + } + checkpoint = next + continue + } + + // Read first read task of queue and check if cpu is enough. + task := taskQueue.front() + tUsage := task.CPUUsage() + if usage+tUsage > targetCPUUsage { + break L + } + + // Pop the task and add to schedule list. + usage += tUsage + result = append(result, task) + taskQueue.pop() + p.taskCount-- + + checkpoint = next + } + + // Stop loop if no task is added. + if readyCount == len(result) { + break L + } + } + + // Update checkpoint. + p.checkpoint = checkpoint + return result, usage +} + +// Implement default FIFO policy. +type fifoScheduleReadPolicy struct { + ready *list.List // Save ready to run task. +} + +// Create a new default schedule read policy +func newFIFOScheduleReadPolicy() scheduleReadPolicy { + return &fifoScheduleReadPolicy{ + ready: list.New(), + } +} + +// Read task count inside. +func (p *fifoScheduleReadPolicy) len() int { + return p.ready.Len() +} + +// Add a new task into ready task. +func (p *fifoScheduleReadPolicy) addTask(rt readTask) { + p.ready.PushBack(rt) +} + +// Merge a new task into exist ready task. +func (p *fifoScheduleReadPolicy) mergeTask(rt readTask) bool { + // Reverse Iterate the task in the queue + for task := p.ready.Back(); task != nil; task = task.Prev() { + taskExist := task.Value.(readTask) + if taskExist.CanMergeWith(rt) { + taskExist.Merge(rt) + return true + } + } + return false +} + +// Schedule a new task. 
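The user-task-polling schedule loop above walks a container/ring of per-user queues, taking at most one task from each user per pass before advancing the checkpoint. The standalone program below is a minimal illustration of that ring traversal only (it is not part of the patch and holds no real tasks); visiting every user once per pass is what makes scheduling fair at task granularity.

```go
// Minimal, standalone illustration of the container/ring traversal that
// user-task-polling relies on: one node per user, one visit per user per pass.
package main

import (
	"container/ring"
	"fmt"
)

func main() {
	users := []string{"alice", "bob", "carol"}

	// One ring node per user, mirroring the per-user userBasedTaskQueue nodes.
	r := ring.New(len(users))
	for _, u := range users {
		r.Value = u
		r = r.Next()
	}

	// Two passes over the ring: every user is visited once per pass,
	// regardless of how many tasks each user has queued.
	checkpoint := r
	for pass := 1; pass <= 2; pass++ {
		for i := 0; i < checkpoint.Len(); i++ {
			fmt.Printf("pass %d: serve one task for %v\n", pass, checkpoint.Value)
			checkpoint = checkpoint.Next()
		}
	}
}
```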
+func (p *fifoScheduleReadPolicy) schedule(targetCPUUsage int32, maxNum int32) (result []readTask, usage int32) { var ret []readTask - usage := int32(0) var next *list.Element - for e := sqTasks.Front(); e != nil && maxNum > 0; e = next { + for e := p.ready.Front(); e != nil && maxNum > 0; e = next { next = e.Next() t, _ := e.Value.(readTask) tUsage := t.CPUUsage() - if usage+tUsage > targetUsage { + if usage+tUsage > targetCPUUsage { break } usage += tUsage - sqTasks.Remove(e) + p.ready.Remove(e) rateCol.rtCounter.sub(t, readyQueueType) ret = append(ret, t) maxNum-- diff --git a/internal/querynode/scheduler_policy_test.go b/internal/querynode/scheduler_policy_test.go index 9a14ebbb0e..3cb149b036 100644 --- a/internal/querynode/scheduler_policy_test.go +++ b/internal/querynode/scheduler_policy_test.go @@ -1,60 +1,296 @@ package querynode import ( - "container/list" + "context" + "fmt" "math" + "strings" "testing" + "time" + "github.com/milvus-io/milvus/internal/util" + "github.com/milvus-io/milvus/internal/util/crypto" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" ) +func TestScheduler_newReadScheduleTaskPolicy(t *testing.T) { + policy := newReadScheduleTaskPolicy(scheduleReadPolicyNameFIFO) + assert.IsType(t, policy, &fifoScheduleReadPolicy{}) + policy = newReadScheduleTaskPolicy("") + assert.IsType(t, policy, &fifoScheduleReadPolicy{}) + policy = newReadScheduleTaskPolicy(scheduleReadPolicyNameUserTaskPolling) + assert.IsType(t, policy, &userTaskPollingScheduleReadPolicy{}) + assert.Panics(t, func() { + newReadScheduleTaskPolicy("other") + }) +} + func TestScheduler_defaultScheduleReadPolicy(t *testing.T) { - readyReadTasks := list.New() - for i := 1; i <= 10; i++ { + policy := newFIFOScheduleReadPolicy() + testBasicScheduleReadPolicy(t, policy) + + for i := 1; i <= 100; i++ { t := mockReadTask{ cpuUsage: int32(i * 10), } - readyReadTasks.PushBack(&t) + policy.addTask(&t) } - scheduleFunc := defaultScheduleReadPolicy - targetUsage := int32(100) maxNum := int32(2) - tasks, cur := scheduleFunc(readyReadTasks, targetUsage, maxNum) + tasks, cur := policy.schedule(targetUsage, maxNum) assert.Equal(t, int32(30), cur) assert.Equal(t, int32(2), int32(len(tasks))) targetUsage = 300 maxNum = 0 - tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum) + tasks, cur = policy.schedule(targetUsage, maxNum) assert.Equal(t, int32(0), cur) assert.Equal(t, 0, len(tasks)) targetUsage = 0 maxNum = 0 - tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum) + tasks, cur = policy.schedule(targetUsage, maxNum) assert.Equal(t, int32(0), cur) assert.Equal(t, 0, len(tasks)) targetUsage = 0 maxNum = 300 - tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum) + tasks, cur = policy.schedule(targetUsage, maxNum) assert.Equal(t, int32(0), cur) assert.Equal(t, 0, len(tasks)) actual := int32(180) // sum(3..6) * 10 3 + 4 + 5 + 6 targetUsage = int32(190) // > actual maxNum = math.MaxInt32 - tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum) + tasks, cur = policy.schedule(targetUsage, maxNum) assert.Equal(t, actual, cur) assert.Equal(t, 4, len(tasks)) actual = 340 // sum(7..10) * 10 , 7+ 8 + 9 + 10 targetUsage = 340 maxNum = 4 - tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum) + tasks, cur = policy.schedule(targetUsage, maxNum) assert.Equal(t, actual, cur) assert.Equal(t, 4, len(tasks)) + + actual = 4995 * 10 // sum(11..100) + targetUsage = actual + maxNum = 90 + tasks, cur = policy.schedule(targetUsage, maxNum) + assert.Equal(t, actual, cur) + 
assert.Equal(t, 90, len(tasks)) + assert.Equal(t, 0, policy.len()) +} + +func TestScheduler_userTaskPollingScheduleReadPolicy(t *testing.T) { + policy := newUserTaskPollingScheduleReadPolicy(time.Minute) + testBasicScheduleReadPolicy(t, policy) + + for i := 1; i <= 100; i++ { + policy.addTask(&mockReadTask{ + cpuUsage: int32(i * 10), + mockTask: mockTask{ + baseTask: baseTask{ + ctx: getContextWithAuthorization(context.Background(), fmt.Sprintf("user%d:123456", i%10)), + }, + }, + }) + } + targetUsage := int32(100) + maxNum := int32(2) + + tasks, cur := policy.schedule(targetUsage, maxNum) + assert.Equal(t, int32(30), cur) + assert.Equal(t, int32(2), int32(len(tasks))) + assert.Equal(t, 98, policy.len()) + + targetUsage = 300 + maxNum = 0 + tasks, cur = policy.schedule(targetUsage, maxNum) + assert.Equal(t, int32(0), cur) + assert.Equal(t, 0, len(tasks)) + assert.Equal(t, 98, policy.len()) + + targetUsage = 0 + maxNum = 0 + tasks, cur = policy.schedule(targetUsage, maxNum) + assert.Equal(t, int32(0), cur) + assert.Equal(t, 0, len(tasks)) + assert.Equal(t, 98, policy.len()) + + targetUsage = 0 + maxNum = 300 + tasks, cur = policy.schedule(targetUsage, maxNum) + assert.Equal(t, int32(0), cur) + assert.Equal(t, 0, len(tasks)) + assert.Equal(t, 98, policy.len()) + + actual := int32(180) // sum(3..6) * 10 3 + 4 + 5 + 6 + targetUsage = int32(190) // > actual + maxNum = math.MaxInt32 + tasks, cur = policy.schedule(targetUsage, maxNum) + assert.Equal(t, actual, cur) + assert.Equal(t, 4, len(tasks)) + assert.Equal(t, 94, policy.len()) + + actual = 340 // sum(7..10) * 10 , 7+ 8 + 9 + 10 + targetUsage = 340 + maxNum = 4 + tasks, cur = policy.schedule(targetUsage, maxNum) + assert.Equal(t, actual, cur) + assert.Equal(t, 4, len(tasks)) + assert.Equal(t, 90, policy.len()) + + actual = 4995 * 10 // sum(11..100) + targetUsage = actual + maxNum = 90 + tasks, cur = policy.schedule(targetUsage, maxNum) + assert.Equal(t, actual, cur) + assert.Equal(t, 90, len(tasks)) + assert.Equal(t, 0, policy.len()) + + time.Sleep(time.Minute + time.Second) + policy.addTask(&mockReadTask{ + cpuUsage: int32(1), + mockTask: mockTask{ + baseTask: baseTask{ + ctx: getContextWithAuthorization(context.Background(), fmt.Sprintf("user%d:123456", 11)), + }, + }, + }) + + tasks, cur = policy.schedule(targetUsage, maxNum) + assert.Equal(t, int32(1), cur) + assert.Equal(t, 1, len(tasks)) + assert.Equal(t, 0, policy.len()) + policyInner := policy.(*userTaskPollingScheduleReadPolicy) + assert.Equal(t, 1, len(policyInner.route)) + assert.Equal(t, 1, policyInner.checkpoint.Len()) +} + +func Test_userBasedTaskQueue(t *testing.T) { + n := 50 + q := newUserBasedTaskQueue("test_user") + for i := 1; i <= n; i++ { + q.push(&mockReadTask{ + cpuUsage: int32(i * 10), + mockTask: mockTask{ + baseTask: baseTask{ + ctx: getContextWithAuthorization(context.Background(), "default:123456"), + }, + }, + }) + assert.Equal(t, q.len(), i) + assert.Equal(t, q.expire(time.Second), false) + } + + for i := 0; i < n; i++ { + q.pop() + assert.Equal(t, q.len(), n-(i+1)) + assert.Equal(t, q.expire(time.Second), false) + } + + time.Sleep(time.Second) + assert.Equal(t, q.expire(time.Second), true) +} + +func testBasicScheduleReadPolicy(t *testing.T, policy scheduleReadPolicy) { + // test, push and schedule. 
+ for i := 1; i <= 50; i++ { + cpuUsage := int32(i * 10) + id := 1 + + policy.addTask(&mockReadTask{ + cpuUsage: cpuUsage, + mockTask: mockTask{ + baseTask: baseTask{ + ctx: getContextWithAuthorization(context.Background(), "default:123456"), + id: UniqueID(id), + }, + }, + }) + assert.Equal(t, policy.len(), 1) + task, cost := policy.schedule(cpuUsage, 1) + assert.Equal(t, cost, cpuUsage) + assert.Equal(t, len(task), 1) + assert.Equal(t, task[0].ID(), int64(id)) + } + + // test, can not merge and schedule. + cpuUsage := int32(100) + notMergeTask := &mockReadTask{ + cpuUsage: cpuUsage, + canMerge: false, + mockTask: mockTask{ + baseTask: baseTask{ + ctx: getContextWithAuthorization(context.Background(), "default:123456"), + }, + }, + } + assert.False(t, policy.mergeTask(notMergeTask)) + policy.addTask(notMergeTask) + assert.Equal(t, policy.len(), 1) + task2 := &mockReadTask{ + cpuUsage: cpuUsage, + canMerge: false, + mockTask: mockTask{ + baseTask: baseTask{ + ctx: getContextWithAuthorization(context.Background(), "default:123456"), + }, + }, + } + assert.False(t, policy.mergeTask(task2)) + assert.Equal(t, policy.len(), 1) + policy.addTask(notMergeTask) + assert.Equal(t, policy.len(), 2) + task, cost := policy.schedule(2*cpuUsage, 1) + assert.Equal(t, cost, cpuUsage) + assert.Equal(t, len(task), 1) + assert.Equal(t, policy.len(), 1) + assert.False(t, task[0].(*mockReadTask).merged) + task, cost = policy.schedule(2*cpuUsage, 1) + assert.Equal(t, cost, cpuUsage) + assert.Equal(t, len(task), 1) + assert.Equal(t, policy.len(), 0) + assert.False(t, task[0].(*mockReadTask).merged) + + // test, can merge and schedule. + mergeTask := &mockReadTask{ + cpuUsage: cpuUsage, + canMerge: true, + mockTask: mockTask{ + baseTask: baseTask{ + ctx: getContextWithAuthorization(context.Background(), "default:123456"), + }, + }, + } + policy.addTask(mergeTask) + task2 = &mockReadTask{ + cpuUsage: cpuUsage, + mockTask: mockTask{ + baseTask: baseTask{ + ctx: getContextWithAuthorization(context.Background(), "default:123456"), + }, + }, + } + assert.True(t, policy.mergeTask(task2)) + assert.Equal(t, policy.len(), 1) + task, cost = policy.schedule(cpuUsage, 1) + assert.Equal(t, cost, cpuUsage) + assert.Equal(t, len(task), 1) + assert.Equal(t, policy.len(), 0) + assert.True(t, task[0].(*mockReadTask).merged) +} + +func getContextWithAuthorization(ctx context.Context, originValue string) context.Context { + authKey := strings.ToLower(util.HeaderAuthorize) + authValue := crypto.Base64Encode(originValue) + contextMap := map[string]string{ + authKey: authValue, + } + md := metadata.New(contextMap) + return metadata.NewIncomingContext(ctx, md) } diff --git a/internal/querynode/shard_cluster.go b/internal/querynode/shard_cluster.go index 87804fa35b..f320dfde8f 100644 --- a/internal/querynode/shard_cluster.go +++ b/internal/querynode/shard_cluster.go @@ -33,6 +33,7 @@ import ( "github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/contextutil" "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/samber/lo" @@ -127,9 +128,11 @@ type ShardSegmentDetector interface { type ShardNodeBuilder func(nodeID int64, addr string) shardQueryNode // withStreaming function type to let search detects corresponding search streaming is done. 
-type searchWithStreaming func(ctx context.Context) (error, *internalpb.SearchResults) -type queryWithStreaming func(ctx context.Context) (error, *internalpb.RetrieveResults) -type getStatisticsWithStreaming func(ctx context.Context) (error, *internalpb.GetStatisticsResponse) +type ( + searchWithStreaming func(ctx context.Context) (error, *internalpb.SearchResults) + queryWithStreaming func(ctx context.Context) (error, *internalpb.RetrieveResults) + getStatisticsWithStreaming func(ctx context.Context) (error, *internalpb.GetStatisticsResponse) +) // ShardCluster maintains the ShardCluster information and perform shard level operations type ShardCluster struct { @@ -156,7 +159,8 @@ type ShardCluster struct { // NewShardCluster create a ShardCluster with provided information. func NewShardCluster(collectionID int64, replicaID int64, vchannelName string, version int64, - nodeDetector ShardNodeDetector, segmentDetector ShardSegmentDetector, nodeBuilder ShardNodeBuilder) *ShardCluster { + nodeDetector ShardNodeDetector, segmentDetector ShardSegmentDetector, nodeBuilder ShardNodeBuilder, +) *ShardCluster { sc := &ShardCluster{ state: atomic.NewInt32(int32(unavailable)), @@ -933,6 +937,13 @@ func getSearchWithStreamingFunc(searchCtx context.Context, req *querypb.SearchRe metrics.SearchLabel).Observe(float64(streamingTask.queueDur.Milliseconds())) metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.SearchLabel).Observe(float64(streamingTask.reduceDur.Milliseconds())) + // In queue latency per user. + metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues( + fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), + metrics.SearchLabel, + contextutil.GetUserFromGrpcMetadata(streamingTask.Ctx()), + ).Observe(float64(streamingTask.queueDur.Milliseconds())) + return nil, streamingTask.Ret } } @@ -943,7 +954,6 @@ func getQueryWithStreamingFunc(queryCtx context.Context, req *querypb.QueryReque streamingTask.DataScope = querypb.DataScope_Streaming streamingTask.QS = qs err := node.scheduler.AddReadTask(queryCtx, streamingTask) - if err != nil { return err, nil } @@ -955,6 +965,13 @@ func getQueryWithStreamingFunc(queryCtx context.Context, req *querypb.QueryReque metrics.QueryLabel).Observe(float64(streamingTask.queueDur.Milliseconds())) metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID), metrics.QueryLabel).Observe(float64(streamingTask.reduceDur.Milliseconds())) + + // In queue latency per user. + metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues( + fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), + metrics.QueryLabel, + contextutil.GetUserFromGrpcMetadata(streamingTask.Ctx()), + ).Observe(float64(streamingTask.queueDur.Milliseconds())) return nil, streamingTask.Ret } } diff --git a/internal/querynode/task_scheduler.go b/internal/querynode/task_scheduler.go index 5acee0946f..409c67c252 100644 --- a/internal/querynode/task_scheduler.go +++ b/internal/querynode/task_scheduler.go @@ -40,7 +40,6 @@ type taskScheduler struct { // for search and query start unsolvedReadTasks *list.List - readyReadTasks *list.List receiveReadTaskChan chan readTask executeReadTaskChan chan readTask @@ -50,7 +49,8 @@ type taskScheduler struct { // tSafeReplica tSafeReplica TSafeReplicaInterface - schedule scheduleReadTaskPolicy + // schedule policy for ready read task. 
+ readySchedulePolicy scheduleReadPolicy // for search and query end cpuUsage int32 // 1200 means 1200% 12 cores @@ -77,13 +77,12 @@ func newTaskScheduler(ctx context.Context, tSafeReplica TSafeReplicaInterface) * ctx: ctx1, cancel: cancel, unsolvedReadTasks: list.New(), - readyReadTasks: list.New(), receiveReadTaskChan: make(chan readTask, Params.QueryNodeCfg.MaxReceiveChanSize), executeReadTaskChan: make(chan readTask, maxExecuteReadChanLen), notifyChan: make(chan struct{}, 1), tSafeReplica: tSafeReplica, maxCPUUsage: int32(getNumCPU() * 100), - schedule: defaultScheduleReadPolicy, + readySchedulePolicy: newReadScheduleTaskPolicy(Params.QueryNodeCfg.ScheduleReadPolicy.Name), } s.queue = newQueryNodeTaskQueue(s) return s @@ -248,7 +247,7 @@ func (s *taskScheduler) AddReadTask(ctx context.Context, t readTask) error { func (s *taskScheduler) popAndAddToExecute() { readConcurrency := atomic.LoadInt32(&s.readConcurrency) metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(readConcurrency)) - if s.readyReadTasks.Len() == 0 { + if s.readySchedulePolicy.len() == 0 { return } curUsage := atomic.LoadInt32(&s.cpuUsage) @@ -266,7 +265,7 @@ func (s *taskScheduler) popAndAddToExecute() { return } - tasks, deltaUsage := s.schedule(s.readyReadTasks, targetUsage, remain) + tasks, deltaUsage := s.readySchedulePolicy.schedule(targetUsage, remain) atomic.AddInt32(&s.cpuUsage, deltaUsage) for _, t := range tasks { s.executeReadTaskChan <- t @@ -366,23 +365,12 @@ func (s *taskScheduler) tryMergeReadTasks() { } if ready { if !Params.QueryNodeCfg.GroupEnabled { - s.readyReadTasks.PushBack(t) + s.readySchedulePolicy.addTask(t) rateCol.rtCounter.add(t, readyQueueType) } else { - merged := false - for m := s.readyReadTasks.Back(); m != nil; m = m.Prev() { - mTask, ok := m.Value.(readTask) - if !ok { - continue - } - if mTask.CanMergeWith(t) { - mTask.Merge(t) - merged = true - break - } - } - if !merged { - s.readyReadTasks.PushBack(t) + // Try to merge task first, otherwise add it. + if !s.readySchedulePolicy.mergeTask(t) { + s.readySchedulePolicy.addTask(t) rateCol.rtCounter.add(t, readyQueueType) } } @@ -391,5 +379,5 @@ func (s *taskScheduler) tryMergeReadTasks() { } } metrics.QueryNodeReadTaskUnsolveLen.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(s.unsolvedReadTasks.Len())) - metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(s.readyReadTasks.Len())) + metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(s.readySchedulePolicy.len())) } diff --git a/internal/querynode/task_scheduler_test.go b/internal/querynode/task_scheduler_test.go index e47ad0fa0e..25e871efd3 100644 --- a/internal/querynode/task_scheduler_test.go +++ b/internal/querynode/task_scheduler_test.go @@ -71,6 +71,7 @@ type mockReadTask struct { timeoutError error step typeutil.TaskStep readyError error + merged bool } func (m *mockReadTask) GetCollectionID() UniqueID { @@ -82,7 +83,7 @@ func (m *mockReadTask) Ready() (bool, error) { } func (m *mockReadTask) Merge(o readTask) { - + m.merged = true } func (m *mockReadTask) CPUUsage() int32 { diff --git a/internal/util/constant.go b/internal/util/constant.go index d941b2b873..08d598bff4 100644 --- a/internal/util/constant.go +++ b/internal/util/constant.go @@ -38,6 +38,8 @@ const ( HeaderAuthorize = "authorization" HeaderDBName = "dbName" + // HeaderUser is the key of username in metadata of grpc. 
+ HeaderUser = "user" // HeaderSourceID identify requests from Milvus members and client requests HeaderSourceID = "sourceId" // MemberCredID id for Milvus members (data/index/query node/coord component) diff --git a/internal/util/contextutil/context_util.go b/internal/util/contextutil/context_util.go index 6c9911f4c7..755389385a 100644 --- a/internal/util/contextutil/context_util.go +++ b/internal/util/contextutil/context_util.go @@ -16,7 +16,14 @@ package contextutil -import "context" +import ( + "context" + "strings" + + "github.com/milvus-io/milvus/internal/util" + "github.com/milvus-io/milvus/internal/util/crypto" + "google.golang.org/grpc/metadata" +) type ctxTenantKey struct{} @@ -37,3 +44,47 @@ func TenantID(ctx context.Context) string { return "" } + +// Passthrough "user" field in grpc from incoming to outgoing. +func PassthroughUserInGrpcMetadata(ctx context.Context) context.Context { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx + } + user := md[strings.ToLower(util.HeaderUser)] + if len(user) == 0 || len(user[0]) == 0 { + return ctx + } + return metadata.AppendToOutgoingContext(ctx, util.HeaderUser, user[0]) +} + +// Set "user" field in grpc outgoing metadata +func WithUserInGrpcMetadata(ctx context.Context, user string) context.Context { + return metadata.AppendToOutgoingContext(ctx, util.HeaderUser, crypto.Base64Encode(user)) +} + +// Get "user" field from grpc metadata, empty string will returned if not set. +func GetUserFromGrpcMetadata(ctx context.Context) string { + if user := getUserFromGrpcMetadataAux(ctx, metadata.FromIncomingContext); user != "" { + return user + } + return getUserFromGrpcMetadataAux(ctx, metadata.FromOutgoingContext) +} + +// Aux function for `GetUserFromGrpc` +func getUserFromGrpcMetadataAux(ctx context.Context, mdGetter func(ctx context.Context) (metadata.MD, bool)) string { + md, ok := mdGetter(ctx) + if !ok { + return "" + } + user := md[strings.ToLower(util.HeaderUser)] + if len(user) == 0 || len(user[0]) == 0 { + return "" + } + // It may be duplicated in meta, but should be always same. 
+ rawuser, err := crypto.Base64Decode(user[0]) + if err != nil { + return "" + } + return rawuser +} diff --git a/internal/util/contextutil/context_util_test.go b/internal/util/contextutil/context_util_test.go new file mode 100644 index 0000000000..967698f964 --- /dev/null +++ b/internal/util/contextutil/context_util_test.go @@ -0,0 +1,29 @@ +package contextutil + +import ( + "context" + "testing" + + "github.com/milvus-io/milvus/internal/util" + "github.com/milvus-io/milvus/internal/util/crypto" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" +) + +// Test UserInGrpcMetadata related function +func Test_UserInGrpcMetadata(t *testing.T) { + testUser := "test_user_131" + ctx := context.Background() + assert.Equal(t, GetUserFromGrpcMetadata(ctx), "") + + ctx = WithUserInGrpcMetadata(ctx, testUser) + assert.Equal(t, GetUserFromGrpcMetadata(ctx), testUser) + + md := metadata.Pairs(util.HeaderUser, crypto.Base64Encode(testUser)) + ctx = metadata.NewIncomingContext(context.Background(), md) + assert.Equal(t, getUserFromGrpcMetadataAux(ctx, metadata.FromIncomingContext), testUser) + + ctx = PassthroughUserInGrpcMetadata(ctx) + assert.Equal(t, getUserFromGrpcMetadataAux(ctx, metadata.FromOutgoingContext), testUser) + assert.Equal(t, GetUserFromGrpcMetadata(ctx), testUser) +} diff --git a/internal/util/paramtable/component_param.go b/internal/util/paramtable/component_param.go index dfa3b845bd..f9f28c8e90 100644 --- a/internal/util/paramtable/component_param.go +++ b/internal/util/paramtable/component_param.go @@ -34,11 +34,11 @@ const ( // DefaultIndexSliceSize defines the default slice size of index file when serializing. DefaultIndexSliceSize = 16 - DefaultGracefulTime = 5000 //ms + DefaultGracefulTime = 5000 // ms DefaultGracefulStopTimeout = 30 // s DefaultThreadCoreCoefficient = 10 - DefaultSessionTTL = 20 //s + DefaultSessionTTL = 20 // s DefaultSessionRetryTimes = 30 DefaultMaxDegree = 56 @@ -1117,6 +1117,9 @@ type queryNodeConfig struct { CPURatio float64 MaxTimestampLag time.Duration + // schedule + ScheduleReadPolicy queryNodeConfigScheduleReadPolicy + GCHelperEnabled bool MinimumGOGCConfig int MaximumGOGCConfig int @@ -1124,6 +1127,15 @@ type queryNodeConfig struct { GracefulStopTimeout int64 } +type queryNodeConfigScheduleReadPolicy struct { + Name string + UserTaskPolling queryNodeConfigScheduleUserTaskPolling +} + +type queryNodeConfigScheduleUserTaskPolling struct { + TaskQueueExpire time.Duration +} + func (p *queryNodeConfig) init(base *BaseTable) { p.Base = base p.NodeID.Store(UniqueID(0)) @@ -1153,6 +1165,9 @@ func (p *queryNodeConfig) init(base *BaseTable) { p.initDiskCapacity() p.initMaxDiskUsagePercentage() + // Initialize scheduler. 
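Putting the contextutil helpers above together, the username survives two hops: the proxy writes it into outgoing metadata, the first query node reads it back for the per-user metric label and passes it through to the next query node. The standalone sketch below (not part of the patch) simulates the gRPC hand-off explicitly with metadata.NewIncomingContext, much like the new unit test does.

```go
// Standalone sketch of proxy -> query node -> query node username propagation
// using the contextutil helpers added above.
package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/metadata"

	"github.com/milvus-io/milvus/internal/util/contextutil"
)

func main() {
	// Hop 1: the proxy attaches the authenticated user to the outgoing metadata,
	// as queryTask.Execute / searchTask.Execute now do.
	proxyCtx := contextutil.WithUserInGrpcMetadata(context.Background(), "alice")

	// gRPC delivers outgoing metadata as incoming metadata on the server side.
	md, _ := metadata.FromOutgoingContext(proxyCtx)
	queryNodeCtx := metadata.NewIncomingContext(context.Background(), md)

	// The query node reads the user back (this is what feeds the username label
	// of the per-user in-queue latency metric) ...
	fmt.Println(contextutil.GetUserFromGrpcMetadata(queryNodeCtx)) // alice

	// ... and forwards it unchanged to the next query node.
	nextHopCtx := contextutil.PassthroughUserInGrpcMetadata(queryNodeCtx)
	fmt.Println(contextutil.GetUserFromGrpcMetadata(nextHopCtx)) // alice
}
```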
+ p.initScheduleReadPolicy() + p.initGCTunerEnbaled() p.initMaximumGOGC() p.initMinimumGOGC() @@ -1346,6 +1361,11 @@ func (p *queryNodeConfig) initDiskCapacity() { p.DiskCapacityLimit = diskSize * 1024 * 1024 * 1024 } +func (p *queryNodeConfig) initScheduleReadPolicy() { + p.ScheduleReadPolicy.Name = p.Base.LoadWithDefault("queryNode.scheduler.scheduleReadPolicy.name", "fifo") + p.ScheduleReadPolicy.UserTaskPolling.TaskQueueExpire = time.Duration(p.Base.ParseInt64WithDefault("queryNode.scheduler.scheduleReadPolicy.taskQueueExpire", 60)) * time.Second +} + func (p *queryNodeConfig) initGCTunerEnbaled() { p.GCHelperEnabled = p.Base.ParseBool("queryNode.gchelper.enabled", true) } @@ -1514,7 +1534,6 @@ func (p *dataCoordConfig) initSegmentMinSizeFromIdleToSealed() { func (p *dataCoordConfig) initSegmentExpansionRate() { p.SegmentExpansionRate = p.Base.ParseFloatWithDefault("dataCoord.segment.expansionRate", 1.25) log.Info("init segment expansion rate", zap.Float64("value", p.SegmentExpansionRate)) - } func (p *dataCoordConfig) initSegmentMaxBinlogFileNumber() { @@ -1647,7 +1666,7 @@ type dataNodeConfig struct { Base *BaseTable // ID of the current node - //NodeID atomic.Value + // NodeID atomic.Value NodeID atomic.Value // IP of the current DataNode IP string
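For operators, selecting the new policy is purely a configuration change. A minimal sketch of the relevant fragment follows; the key names match what initScheduleReadPolicy above reads (queryNode.scheduler.scheduleReadPolicy.*), and the nesting under "scheduler" is an assumption based on that prefix rather than something this diff shows directly. The values shown are the patch defaults.

```yaml
queryNode:
  scheduler:
    scheduleReadPolicy:
      name: fifo          # or user-task-polling
      taskQueueExpire: 60 # seconds an empty per-user task queue is kept before cleanup
```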