Support configurable policy for query node, Add user level schedule policy (#23718)

Signed-off-by: chyezh <ye.zhen@zilliz.com>

Parent: bc86aa666f
Commit: 20054dc42b
@@ -242,6 +242,18 @@ queryNode:
# if the lag is larger than this config, the scheduler will return an error without waiting.
# the valid value is [3600, infinite)
maxTimestampLag: 86400
# read task schedule policy: fifo (by default), user-task-polling.
scheduleReadPolicy:
  # fifo: a FIFO queue supports the scheduling.
  # user-task-polling:
  #   The user's tasks will be polled one by one and scheduled.
  #   Scheduling is fair at task granularity.
  #   The policy is based on the username used for authentication;
  #   an empty username is treated as a single user.
  #   When there are no multiple users, the policy decays into FIFO.
  name: fifo
  # user-task-polling configuration:
  taskQueueExpire: 60 # 1 min by default; expiration time of an inner per-user task queue once it becomes empty.

grouping:
  enabled: true
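The comments above describe the new user-task-polling behaviour: one FIFO queue per authenticated user, polled round-robin so that no single user monopolizes the read queue, and degenerating to plain FIFO when only one user is active. A stand-alone Go sketch of that idea (illustrative only, not code from this commit):

package main

import "fmt"

// Illustrative sketch only: user-task-polling keeps one FIFO queue per user
// and polls the users round-robin, so scheduling is fair at task granularity;
// with a single user it degenerates to plain FIFO.
func pollUsers(queues map[string][]string, order []string, budget int) []string {
	var scheduled []string
	for len(scheduled) < budget {
		progress := false
		for _, user := range order {
			if len(scheduled) >= budget {
				break
			}
			q := queues[user]
			if len(q) == 0 {
				continue
			}
			scheduled = append(scheduled, q[0]) // pop the user's oldest task
			queues[user] = q[1:]
			progress = true
		}
		if !progress { // all queues drained
			break
		}
	}
	return scheduled
}

func main() {
	queues := map[string][]string{
		"alice": {"a1", "a2", "a3"},
		"bob":   {"b1"},
	}
	// Tasks interleave across users instead of alice starving bob.
	fmt.Println(pollUsers(queues, []string{"alice", "bob"}, 4)) // [a1 b1 a2 a3]
}

Switching a QueryNode to this behaviour only requires setting name: user-task-polling under scheduleReadPolicy; taskQueueExpire controls how long an idle per-user queue is kept before it is dropped.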
@@ -149,6 +149,20 @@ var (
			queryTypeLabelName,
		})

	QueryNodeSQPerUserLatencyInQueue = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Namespace: milvusNamespace,
			Subsystem: typeutil.QueryNodeRole,
			Name:      "sq_queue_user_latency",
			Help:      "latency per user of search or query in queue",
			Buckets:   buckets,
		}, []string{
			nodeIDLabelName,
			queryTypeLabelName,
			usernameLabelName,
		},
	)
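QueryNodeSQPerUserLatencyInQueue extends the existing in-queue latency metric with a username label, so queue time can be broken down per user. A self-contained sketch of registering and feeding such a histogram with client_golang (the namespace, bucket values and label strings below are placeholders, not the real metrics-package constants):

package main

import (
	"fmt"
	"time"

	"github.com/prometheus/client_golang/prometheus"
)

// Stand-alone illustration of a per-user queue-latency histogram.
var sqPerUserLatencyInQueue = prometheus.NewHistogramVec(
	prometheus.HistogramOpts{
		Namespace: "milvus",
		Subsystem: "querynode",
		Name:      "sq_queue_user_latency",
		Help:      "latency per user of search or query in queue",
		Buckets:   prometheus.ExponentialBuckets(1, 2, 12), // placeholder buckets
	},
	[]string{"node_id", "query_type", "username"},
)

func main() {
	registry := prometheus.NewRegistry()
	registry.MustRegister(sqPerUserLatencyInQueue)

	// Observe the time a task spent queued, labelled by the requesting user.
	queueDur := 42 * time.Millisecond
	sqPerUserLatencyInQueue.
		WithLabelValues(fmt.Sprint(1), "search", "alice").
		Observe(float64(queueDur.Milliseconds()))
}

In the commit itself, the observations happen in the QueryNode hunks further down, using the username recovered from gRPC metadata via contextutil.GetUserFromGrpcMetadata.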
	QueryNodeSQSegmentLatency = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Namespace: milvusNamespace,

@@ -380,6 +394,7 @@ func RegisterQueryNode(registry *prometheus.Registry) {
	registry.MustRegister(QueryNodeSQCount)
	registry.MustRegister(QueryNodeSQReqLatency)
	registry.MustRegister(QueryNodeSQLatencyInQueue)
	registry.MustRegister(QueryNodeSQPerUserLatencyInQueue)
	registry.MustRegister(QueryNodeSQSegmentLatency)
	registry.MustRegister(QueryNodeSQSegmentLatencyInCore)
	registry.MustRegister(QueryNodeReduceLatency)

@@ -422,5 +437,4 @@ func CleanupQueryNodeCollectionMetrics(nodeID int64, collectionID int64) {
		collectionIDLabelName: fmt.Sprint(collectionID),
	})
}

}

@@ -41,7 +41,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
	})
	assert.NotNil(t, err)

	ctx = GetContext(context.Background(), "alice:123456")
	ctx = getContextWithAuthorization(context.Background(), "alice:123456")
	client := &MockRootCoordClientInterface{}
	queryCoord := &MockQueryCoordClientInterface{}
	mgr := newShardClientMgr()

@@ -65,13 +65,13 @@ func TestPrivilegeInterceptor(t *testing.T) {
		}, nil
	}

	_, err = PrivilegeInterceptor(GetContext(context.Background(), "foo:123456"), &milvuspb.LoadCollectionRequest{
	_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "foo:123456"), &milvuspb.LoadCollectionRequest{
		DbName:         "db_test",
		CollectionName: "col1",
	})
	assert.NotNil(t, err)

	_, err = PrivilegeInterceptor(GetContext(context.Background(), "root:123456"), &milvuspb.LoadCollectionRequest{
	_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "root:123456"), &milvuspb.LoadCollectionRequest{
		DbName:         "db_test",
		CollectionName: "col1",
	})

@@ -103,7 +103,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
	})
	assert.Nil(t, err)

	fooCtx := GetContext(context.Background(), "foo:123456")
	fooCtx := getContextWithAuthorization(context.Background(), "foo:123456")
	_, err = PrivilegeInterceptor(fooCtx, &milvuspb.LoadCollectionRequest{
		DbName:         "db_test",
		CollectionName: "col1",

@@ -130,7 +130,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
	})
	assert.Nil(t, err)

	_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
	_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
		DbName:         "db_test",
		CollectionName: "col1",
	})

@@ -148,7 +148,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
	go func() {
		defer g.Done()
		assert.NotPanics(t, func() {
			PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
			PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.LoadCollectionRequest{
				DbName:         "db_test",
				CollectionName: "col1",
			})

@@ -161,7 +161,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
			getPolicyModel("foo")
		})
	})

}

func TestResourceGroupPrivilege(t *testing.T) {

@@ -173,7 +172,7 @@ func TestResourceGroupPrivilege(t *testing.T) {
	_, err := PrivilegeInterceptor(ctx, &milvuspb.ListResourceGroupsRequest{})
	assert.NotNil(t, err)

	ctx = GetContext(context.Background(), "fooo:123456")
	ctx = getContextWithAuthorization(context.Background(), "fooo:123456")
	client := &MockRootCoordClientInterface{}
	queryCoord := &MockQueryCoordClientInterface{}
	mgr := newShardClientMgr()

@@ -198,29 +197,28 @@ func TestResourceGroupPrivilege(t *testing.T) {
	}
	InitMetaCache(ctx, client, queryCoord, mgr)

	_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.CreateResourceGroupRequest{
	_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.CreateResourceGroupRequest{
		ResourceGroup: "rg",
	})
	assert.Nil(t, err)

	_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.DropResourceGroupRequest{
	_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.DropResourceGroupRequest{
		ResourceGroup: "rg",
	})
	assert.Nil(t, err)

	_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.DescribeResourceGroupRequest{
	_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.DescribeResourceGroupRequest{
		ResourceGroup: "rg",
	})
	assert.Nil(t, err)

	_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.ListResourceGroupsRequest{})
	_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.ListResourceGroupsRequest{})
	assert.Nil(t, err)

	_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.TransferNodeRequest{})
	_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.TransferNodeRequest{})
	assert.Nil(t, err)

	_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.TransferReplicaRequest{})
	_, err = PrivilegeInterceptor(getContextWithAuthorization(context.Background(), "fooo:123456"), &milvuspb.TransferReplicaRequest{})
	assert.Nil(t, err)
	})

}

@@ -45,7 +45,7 @@ func TestProxyRpcLimit(t *testing.T) {
	t.Setenv("ROCKSMQ_PATH", path)
	defer os.RemoveAll(path)

	ctx := GetContext(context.Background(), "root:123456")
	ctx := getContextWithAuthorization(context.Background(), "root:123456")
	localMsg := true
	factory := dependency.NewDefaultFactory(localMsg)

@@ -353,12 +353,12 @@ func (s *proxyTestServer) GetStatisticsChannel(ctx context.Context, request *int
func (s *proxyTestServer) startGrpc(ctx context.Context, wg *sync.WaitGroup, p paramtable.GrpcServerConfig) {
	defer wg.Done()

	var kaep = keepalive.EnforcementPolicy{
	kaep := keepalive.EnforcementPolicy{
		MinTime:             5 * time.Second, // If a client pings more than once every 5 seconds, terminate the connection
		PermitWithoutStream: true,            // Allow pings even when there are no active streams
	}

	var kasp = keepalive.ServerParameters{
	kasp := keepalive.ServerParameters{
		Time:    60 * time.Second, // Ping the client if it is idle for 60 seconds to ensure the connection is still active
		Timeout: 10 * time.Second, // Wait 10 seconds for the ping ack before assuming the connection is dead
	}

@@ -425,7 +425,7 @@ func TestProxy(t *testing.T) {
	defer os.RemoveAll(path)

	ctx, cancel := context.WithCancel(context.Background())
	ctx = GetContext(ctx, "root:123456")
	ctx = getContextWithAuthorization(ctx, "root:123456")
	localMsg := true
	Params.InitOnce()
	factory := dependency.MockDefaultFactory(localMsg, &Params)

@@ -758,7 +758,6 @@ func TestProxy(t *testing.T) {
		resp, err = proxy.CreateCollection(ctx, reqInvalidField)
		assert.NoError(t, err)
		assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)

	})

	wg.Add(1)

@@ -842,7 +841,6 @@ func TestProxy(t *testing.T) {
			DbName:         dbName,
			CollectionName: collectionName,
		})

	})

	wg.Add(1)

@@ -1423,7 +1421,7 @@ func TestProxy(t *testing.T) {
		},
	}

	//resp, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{
	// resp, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{
	_, err := proxy.CalcDistance(ctx, &milvuspb.CalcDistanceRequest{
		Base:   nil,
		OpLeft: opLeft,

@@ -2109,7 +2107,7 @@ func TestProxy(t *testing.T) {
	t.Run("credential UPDATE api", func(t *testing.T) {
		defer wg.Done()
		rootCtx := ctx
		fooCtx := GetContext(context.Background(), "foo:123456")
		fooCtx := getContextWithAuthorization(context.Background(), "foo:123456")
		ctx = fooCtx
		originUsers := Params.CommonCfg.SuperUsers
		Params.CommonCfg.SuperUsers = []string{"root"}

@@ -3863,7 +3861,6 @@ func TestProxy_ListImportTasks(t *testing.T) {
}

func TestProxy_GetStatistics(t *testing.T) {

}

func TestProxy_GetLoadState(t *testing.T) {

@@ -21,6 +21,7 @@ import (
	"github.com/milvus-io/milvus/internal/proto/planpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/internal/util/contextutil"
	"github.com/milvus-io/milvus/internal/util/funcutil"
	"github.com/milvus-io/milvus/internal/util/retry"
	"github.com/milvus-io/milvus/internal/util/timerecord"

@@ -229,7 +230,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
	log.Ctx(ctx).Debug("Validate partition names.",
		zap.Int64("msgID", t.ID()), zap.Any("requestType", "query"))

	//fetch search_growing from search param
	// fetch search_growing from search param
	var ignoreGrowing bool
	for i, kv := range t.request.GetQueryParams() {
		if kv.GetKey() == IgnoreGrowingKey {

@@ -358,6 +359,11 @@ func (t *queryTask) Execute(ctx context.Context) error {
	defer tr.CtxElapse(ctx, "done")
	log := log.Ctx(ctx)

	// Add the user name into the context if it exists.
	if username, _ := GetCurUserFromContext(ctx); username != "" {
		ctx = contextutil.WithUserInGrpcMetadata(ctx, username)
	}

	executeQuery := func() error {
		shards, err := globalMetaCache.GetShards(ctx, true, t.collectionName)
		if err != nil {
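The added block above is the proxy side of the username plumbing: the authenticated user is copied into the outgoing gRPC metadata so the QueryNode can group tasks and label metrics per user (the same block is added to the search task below). A self-contained sketch of the idea, with stand-in helpers that only mirror the contextutil functions introduced later in this commit:

package main

import (
	"context"
	"encoding/base64"
	"fmt"

	"google.golang.org/grpc/metadata"
)

// Stand-in helpers mirroring WithUserInGrpcMetadata / GetUserFromGrpcMetadata;
// the "user" key matches the diff, the implementation here is illustrative only.
func withUser(ctx context.Context, user string) context.Context {
	return metadata.AppendToOutgoingContext(ctx, "user", base64.StdEncoding.EncodeToString([]byte(user)))
}

func userFrom(ctx context.Context) string {
	md, ok := metadata.FromOutgoingContext(ctx)
	if !ok {
		return ""
	}
	vals := md.Get("user")
	if len(vals) == 0 {
		return ""
	}
	raw, err := base64.StdEncoding.DecodeString(vals[0])
	if err != nil {
		return ""
	}
	return string(raw)
}

func main() {
	// Proxy side: attach the authenticated username before fanning out
	// search/query RPCs; the QueryNode side reads it back for per-user
	// scheduling and metrics.
	ctx := withUser(context.Background(), "alice")
	fmt.Println(userFrom(ctx)) // alice
}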
@@ -23,6 +23,7 @@ import (
	"github.com/milvus-io/milvus/internal/querynode"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/internal/util/commonpbutil"
	"github.com/milvus-io/milvus/internal/util/contextutil"
	"github.com/milvus-io/milvus/internal/util/distance"
	"github.com/milvus-io/milvus/internal/util/funcutil"
	"github.com/milvus-io/milvus/internal/util/retry"

@@ -255,7 +256,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
	log.Ctx(ctx).Debug("translate output fields", zap.Int64("msgID", t.ID()),
		zap.Strings("output fields", t.request.GetOutputFields()))

	//fetch search_growing from search param
	// fetch search_growing from search param
	var ignoreGrowing bool
	for i, kv := range t.request.GetSearchParams() {
		if kv.GetKey() == IgnoreGrowingKey {

@@ -380,6 +381,11 @@ func (t *searchTask) Execute(ctx context.Context) error {
	defer tr.CtxElapse(ctx, "done")
	log := log.Ctx(ctx)

	// Add the user name into the context if it exists.
	if username, _ := GetCurUserFromContext(ctx); username != "" {
		ctx = contextutil.WithUserInGrpcMetadata(ctx, username)
	}

	executeSearch := func() error {
		shard2Leaders, err := globalMetaCache.GetShards(ctx, true, t.collectionName)
		if err != nil {

@@ -723,7 +729,7 @@ func reduceSearchResultData(ctx context.Context, subSearchResultData []*schemapb
		log.Ctx(ctx).Warn("invalid search results", zap.Error(err))
		return ret, err
	}
	//printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
	// printSearchResultData(sData, strconv.FormatInt(int64(i), 10))
}

var (

@@ -711,7 +711,7 @@ func TestIsDefaultRole(t *testing.T) {
	assert.Equal(t, false, IsDefaultRole("manager"))
}

func GetContext(ctx context.Context, originValue string) context.Context {
func getContextWithAuthorization(ctx context.Context, originValue string) context.Context {
	authKey := strings.ToLower(util.HeaderAuthorize)
	authValue := crypto.Base64Encode(originValue)
	contextMap := map[string]string{

@@ -740,12 +740,12 @@ func TestGetCurUserFromContext(t *testing.T) {
	_, err = GetCurUserFromContext(metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{})))
	assert.NotNil(t, err)

	_, err = GetCurUserFromContext(GetContext(context.Background(), "123456"))
	_, err = GetCurUserFromContext(getContextWithAuthorization(context.Background(), "123456"))
	assert.NotNil(t, err)

	root := "root"
	password := "123456"
	username, err := GetCurUserFromContext(GetContext(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password)))
	username, err := GetCurUserFromContext(getContextWithAuthorization(context.Background(), fmt.Sprintf("%s%s%s", root, util.CredentialSeperator, password)))
	assert.Nil(t, err)
	assert.Equal(t, "root", username)
}

@@ -26,6 +26,7 @@ import (
	"time"

	"github.com/milvus-io/milvus/internal/util/commonpbutil"
	"github.com/milvus-io/milvus/internal/util/contextutil"
	"github.com/milvus-io/milvus/internal/util/errorutil"

	"github.com/golang/protobuf/proto"

@@ -875,13 +876,19 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
			metrics.SearchLabel).Observe(float64(historicalTask.queueDur.Milliseconds()))
		metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.SearchLabel).Observe(float64(historicalTask.reduceDur.Milliseconds()))
		// In-queue latency per user.
		metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues(
			fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.SearchLabel,
			contextutil.GetUserFromGrpcMetadata(historicalTask.Ctx()),
		).Observe(float64(historicalTask.queueDur.Milliseconds()))
		latency := tr.ElapseSpan()
		metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
		metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel).Inc()
		return historicalTask.Ret, nil
	}

	//from Proxy
	// from Proxy
	tr := timerecord.NewTimeRecorder("SearchShard")
	log.Ctx(ctx).Debug("start do search",
		zap.String("vChannel", dmlChannel),

@@ -900,6 +907,8 @@ func (node *QueryNode) searchWithDmlChannel(ctx context.Context, req *querypb.Se
		errCluster error
	)
	defer cancel()
	// Passthrough the username across grpc.
	searchCtx = contextutil.PassthroughUserInGrpcMetadata(searchCtx)

	// optimize search params
	if req.Req.GetSerializedExprPlan() != nil && node.queryHook != nil {

@@ -1046,6 +1055,12 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
			metrics.QueryLabel).Observe(float64(queryTask.queueDur.Milliseconds()))
		metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.QueryLabel).Observe(float64(queryTask.reduceDur.Milliseconds()))
		// In-queue latency per user.
		metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues(
			fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.QueryLabel,
			contextutil.GetUserFromGrpcMetadata(queryTask.Ctx()),
		).Observe(float64(queryTask.queueDur.Milliseconds()))
		latency := tr.ElapseSpan()
		metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
		metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel).Inc()

@@ -1063,6 +1078,9 @@ func (node *QueryNode) queryWithDmlChannel(ctx context.Context, req *querypb.Que
	queryCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Passthrough the username across grpc.
	queryCtx = contextutil.PassthroughUserInGrpcMetadata(queryCtx)

	var results []*internalpb.RetrieveResults
	var errCluster error
	var withStreamingFunc queryWithStreaming

@@ -2,23 +2,313 @@ package querynode

import (
	"container/list"
	"container/ring"
	"time"

	"github.com/milvus-io/milvus/internal/util/contextutil"
)

type scheduleReadTaskPolicy func(sqTasks *list.List, targetUsage int32, maxNum int32) ([]readTask, int32)
const (
	scheduleReadPolicyNameFIFO            = "fifo"
	scheduleReadPolicyNameUserTaskPolling = "user-task-polling"
)

func defaultScheduleReadPolicy(sqTasks *list.List, targetUsage int32, maxNum int32) ([]readTask, int32) {
var (
	_ scheduleReadPolicy = &fifoScheduleReadPolicy{}
	_ scheduleReadPolicy = &userTaskPollingScheduleReadPolicy{}
)

type scheduleReadPolicy interface {
	// Read task count inside.
	len() int

	// Add a new task into the ready tasks.
	addTask(rt readTask)

	// Merge a new task into an existing ready task.
	// Return true if the merge succeeds.
	mergeTask(rt readTask) bool

	// Schedule new tasks to run.
	schedule(targetCPUUsage int32, maxNum int32) ([]readTask, int32)
}

// Create a new schedule policy.
func newReadScheduleTaskPolicy(policyName string) scheduleReadPolicy {
	switch policyName {
	case "":
		fallthrough
	case scheduleReadPolicyNameFIFO:
		return newFIFOScheduleReadPolicy()
	case scheduleReadPolicyNameUserTaskPolling:
		return newUserTaskPollingScheduleReadPolicy(
			Params.QueryNodeCfg.ScheduleReadPolicy.UserTaskPolling.TaskQueueExpire)
	default:
		panic("invalid read schedule task policy")
	}
}
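newReadScheduleTaskPolicy is the factory the task scheduler uses to pick a policy from queryNode.scheduler.scheduleReadPolicy.name. The sketch below uses stand-in types (not the real readTask/scheduleReadPolicy) to show the contract the scheduler relies on: merge into an existing task when possible, otherwise enqueue, then pull work under a CPU budget, as the taskScheduler hunks further down do.

package main

import "fmt"

// Illustrative stand-ins only; the real types live in the querynode package.
type task struct {
	user string
	cpu  int32
}

type policy interface {
	len() int
	addTask(t task)
	mergeTask(t task) bool
	schedule(targetCPU, maxNum int32) ([]task, int32)
}

// A trivial FIFO satisfying the contract; it never merges.
type fifo struct{ ready []task }

func (p *fifo) len() int            { return len(p.ready) }
func (p *fifo) addTask(t task)      { p.ready = append(p.ready, t) }
func (p *fifo) mergeTask(task) bool { return false } // the real policies try CanMergeWith/Merge here
func (p *fifo) schedule(targetCPU, maxNum int32) ([]task, int32) {
	var out []task
	var usage int32
	for len(p.ready) > 0 && int32(len(out)) < maxNum {
		t := p.ready[0]
		if usage+t.cpu > targetCPU { // respect the CPU budget, like the real policies
			break
		}
		usage += t.cpu
		out = append(out, t)
		p.ready = p.ready[1:]
	}
	return out, usage
}

func main() {
	var p policy = &fifo{}
	for _, t := range []task{{"alice", 40}, {"bob", 80}, {"alice", 20}} {
		// Merge into an existing task if possible, otherwise enqueue,
		// the way tryMergeReadTasks() does below.
		if !p.mergeTask(t) {
			p.addTask(t)
		}
	}
	// Pull work under a CPU budget, the way popAndAddToExecute() calls
	// readySchedulePolicy.schedule().
	tasks, usage := p.schedule(100, 10)
	fmt.Println(len(tasks), usage, p.len()) // 1 40 2 (bob's 80 would exceed the budget)
}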
// Create a new user-based task queue.
func newUserBasedTaskQueue(username string) *userBasedTaskQueue {
	return &userBasedTaskQueue{
		cleanupTimestamp: time.Now(),
		queue:            list.New(),
		username:         username,
	}
}

// User-based task queue.
type userBasedTaskQueue struct {
	cleanupTimestamp time.Time
	queue            *list.List
	username         string
}

// Get the length of the user's task queue.
func (q *userBasedTaskQueue) len() int {
	return q.queue.Len()
}

// Add a new task to the end of the queue.
func (q *userBasedTaskQueue) push(t readTask) {
	q.queue.PushBack(t)
}

// Get the first task from the queue.
func (q *userBasedTaskQueue) front() readTask {
	element := q.queue.Front()
	if element == nil {
		return nil
	}
	return element.Value.(readTask)
}

// Remove the first task from the queue.
func (q *userBasedTaskQueue) pop() {
	front := q.queue.Front()
	if front != nil {
		q.queue.Remove(front)
	}
	if q.queue.Len() == 0 {
		q.cleanupTimestamp = time.Now()
	}
}

// Return true if the user's task queue is empty and has been empty for duration d.
func (q *userBasedTaskQueue) expire(d time.Duration) bool {
	if q.queue.Len() != 0 {
		return false
	}
	if time.Since(q.cleanupTimestamp) > d {
		return true
	}
	return false
}

// Merge a new task into a task already in the queue.
func (q *userBasedTaskQueue) mergeTask(rt readTask) bool {
	for element := q.queue.Back(); element != nil; element = element.Prev() {
		task := element.Value.(readTask)
		if task.CanMergeWith(rt) {
			task.Merge(rt)
			return true
		}
	}
	return false
}

// Implement the user-based schedule read policy.
type userTaskPollingScheduleReadPolicy struct {
	taskCount       int                   // task count in the policy.
	route           map[string]*ring.Ring // map username to node of task ring.
	checkpoint      *ring.Ring            // last not-scheduled ring node, ring.Ring[list.List]
	taskQueueExpire time.Duration
}

// Create a new user-based schedule read policy.
func newUserTaskPollingScheduleReadPolicy(taskQueueExpire time.Duration) scheduleReadPolicy {
	return &userTaskPollingScheduleReadPolicy{
		taskCount:       0,
		route:           make(map[string]*ring.Ring),
		checkpoint:      nil,
		taskQueueExpire: taskQueueExpire,
	}
}

// Get the length of the task queue.
func (p *userTaskPollingScheduleReadPolicy) len() int {
	return p.taskCount
}

// Add a new task into the ready task queue.
func (p *userTaskPollingScheduleReadPolicy) addTask(rt readTask) {
	username := contextutil.GetUserFromGrpcMetadata(rt.Ctx()) // empty users compete on a single list.
	if r, ok := p.route[username]; ok {
		// Add the new task to the back of the queue if the queue exists.
		r.Value.(*userBasedTaskQueue).push(rt)
	} else {
		// Create a new list, and add it to the route and queues.
		newQueue := newUserBasedTaskQueue(username)
		newQueue.push(rt)
		newRing := ring.New(1)
		newRing.Value = newQueue
		p.route[username] = newRing
		if p.checkpoint == nil {
			// Create a new ring if none exists.
			p.checkpoint = newRing
		} else {
			// Add the new ring before the checkpoint.
			p.checkpoint.Prev().Link(newRing)
		}
	}
	p.taskCount++
}
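addTask keeps one ring node per user and inserts a new user's queue just before the checkpoint via Prev().Link(); schedule() later drops an expired node with Prev().Unlink(1). A small container/ring demo of exactly those two calls (stand-alone; the string values stand in for the per-user queues):

package main

import (
	"container/ring"
	"fmt"
)

func dump(r *ring.Ring) {
	var users []string
	r.Do(func(v interface{}) { users = append(users, v.(string)) })
	fmt.Println(users)
}

func main() {
	checkpoint := ring.New(1)
	checkpoint.Value = "alice"

	// addTask() for a new user: link a fresh node just before the checkpoint,
	// so it is visited last in the current polling round.
	bob := ring.New(1)
	bob.Value = "bob"
	checkpoint.Prev().Link(bob)
	dump(checkpoint) // [alice bob]

	carol := ring.New(1)
	carol.Value = "carol"
	checkpoint.Prev().Link(carol)
	dump(checkpoint) // [alice bob carol]

	// schedule() expiring the checkpointed queue: Unlink(1) after Prev()
	// removes the checkpoint node itself, then polling continues from next.
	next := checkpoint.Next()
	checkpoint.Prev().Unlink(1)
	dump(next) // [bob carol]
}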
// Merge a new task into an existing ready task.
func (p *userTaskPollingScheduleReadPolicy) mergeTask(rt readTask) bool {
	if p.taskCount == 0 {
		return false
	}

	username := contextutil.GetUserFromGrpcMetadata(rt.Ctx()) // empty users compete on a single list.
	// Applied to tasks of the same user first.
	if r, ok := p.route[username]; ok {
		// Try to merge the task into the queue.
		if r.Value.(*userBasedTaskQueue).mergeTask(rt) {
			return true
		}
	}

	// Try to merge the task into other users' queues before the checkpoint.
	node := p.checkpoint.Prev()
	queuesLen := p.checkpoint.Len()
	for i := 0; i < queuesLen; i++ {
		prev := node.Prev()
		queue := node.Value.(*userBasedTaskQueue)
		if queue.len() == 0 || queue.username == username {
			continue
		}
		if queue.mergeTask(rt) {
			return true
		}
		node = prev
	}

	return false
}

func (p *userTaskPollingScheduleReadPolicy) schedule(targetCPUUsage int32, maxNum int32) (result []readTask, usage int32) {
	// Return directly if there's no task ready.
	if p.taskCount == 0 {
		return
	}

	queuesLen := p.checkpoint.Len()
	checkpoint := p.checkpoint
	// TODO: infinite loop.
L:
	for {
		readyCount := len(result)
		for i := 0; i < queuesLen; i++ {
			if len(result) >= int(maxNum) {
				break L
			}

			next := checkpoint.Next()
			// Find tasks in this queue.
			taskQueue := checkpoint.Value.(*userBasedTaskQueue)

			// Empty task queue for this user.
			if taskQueue.len() == 0 {
				// Expire the queue.
				if taskQueue.expire(p.taskQueueExpire) {
					delete(p.route, taskQueue.username)
					if checkpoint.Len() == 1 {
						checkpoint = nil
						break L
					} else {
						checkpoint.Prev().Unlink(1)
					}
				}
				checkpoint = next
				continue
			}

			// Read the first read task of the queue and check whether the CPU budget is enough.
			task := taskQueue.front()
			tUsage := task.CPUUsage()
			if usage+tUsage > targetCPUUsage {
				break L
			}

			// Pop the task and add it to the schedule list.
			usage += tUsage
			result = append(result, task)
			taskQueue.pop()
			p.taskCount--

			checkpoint = next
		}

		// Stop the loop if no task was added.
		if readyCount == len(result) {
			break L
		}
	}

	// Update the checkpoint.
	p.checkpoint = checkpoint
	return result, usage
}

// Implement the default FIFO policy.
type fifoScheduleReadPolicy struct {
	ready *list.List // Saves the ready-to-run tasks.
}

// Create a new default schedule read policy.
func newFIFOScheduleReadPolicy() scheduleReadPolicy {
	return &fifoScheduleReadPolicy{
		ready: list.New(),
	}
}

// Read task count inside.
func (p *fifoScheduleReadPolicy) len() int {
	return p.ready.Len()
}

// Add a new task into the ready tasks.
func (p *fifoScheduleReadPolicy) addTask(rt readTask) {
	p.ready.PushBack(rt)
}

// Merge a new task into an existing ready task.
func (p *fifoScheduleReadPolicy) mergeTask(rt readTask) bool {
	// Iterate the tasks in the queue in reverse.
	for task := p.ready.Back(); task != nil; task = task.Prev() {
		taskExist := task.Value.(readTask)
		if taskExist.CanMergeWith(rt) {
			taskExist.Merge(rt)
			return true
		}
	}
	return false
}

// Schedule new tasks.
func (p *fifoScheduleReadPolicy) schedule(targetCPUUsage int32, maxNum int32) (result []readTask, usage int32) {
	var ret []readTask
	usage := int32(0)
	var next *list.Element
	for e := sqTasks.Front(); e != nil && maxNum > 0; e = next {
	for e := p.ready.Front(); e != nil && maxNum > 0; e = next {
		next = e.Next()
		t, _ := e.Value.(readTask)
		tUsage := t.CPUUsage()
		if usage+tUsage > targetUsage {
		if usage+tUsage > targetCPUUsage {
			break
		}
		usage += tUsage
		sqTasks.Remove(e)
		p.ready.Remove(e)
		rateCol.rtCounter.sub(t, readyQueueType)
		ret = append(ret, t)
		maxNum--

@@ -1,60 +1,296 @@
package querynode

import (
	"container/list"
	"context"
	"fmt"
	"math"
	"strings"
	"testing"
	"time"

	"github.com/milvus-io/milvus/internal/util"
	"github.com/milvus-io/milvus/internal/util/crypto"
	"github.com/stretchr/testify/assert"
	"google.golang.org/grpc/metadata"
)

func TestScheduler_newReadScheduleTaskPolicy(t *testing.T) {
	policy := newReadScheduleTaskPolicy(scheduleReadPolicyNameFIFO)
	assert.IsType(t, policy, &fifoScheduleReadPolicy{})
	policy = newReadScheduleTaskPolicy("")
	assert.IsType(t, policy, &fifoScheduleReadPolicy{})
	policy = newReadScheduleTaskPolicy(scheduleReadPolicyNameUserTaskPolling)
	assert.IsType(t, policy, &userTaskPollingScheduleReadPolicy{})
	assert.Panics(t, func() {
		newReadScheduleTaskPolicy("other")
	})
}

func TestScheduler_defaultScheduleReadPolicy(t *testing.T) {
	readyReadTasks := list.New()
	for i := 1; i <= 10; i++ {
	policy := newFIFOScheduleReadPolicy()
	testBasicScheduleReadPolicy(t, policy)

	for i := 1; i <= 100; i++ {
		t := mockReadTask{
			cpuUsage: int32(i * 10),
		}
		readyReadTasks.PushBack(&t)
		policy.addTask(&t)
	}

	scheduleFunc := defaultScheduleReadPolicy

	targetUsage := int32(100)
	maxNum := int32(2)

	tasks, cur := scheduleFunc(readyReadTasks, targetUsage, maxNum)
	tasks, cur := policy.schedule(targetUsage, maxNum)
	assert.Equal(t, int32(30), cur)
	assert.Equal(t, int32(2), int32(len(tasks)))

	targetUsage = 300
	maxNum = 0
	tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, int32(0), cur)
	assert.Equal(t, 0, len(tasks))

	targetUsage = 0
	maxNum = 0
	tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, int32(0), cur)
	assert.Equal(t, 0, len(tasks))

	targetUsage = 0
	maxNum = 300
	tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, int32(0), cur)
	assert.Equal(t, 0, len(tasks))

	actual := int32(180)     // sum(3..6) * 10, 3 + 4 + 5 + 6
	targetUsage = int32(190) // > actual
	maxNum = math.MaxInt32
	tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, actual, cur)
	assert.Equal(t, 4, len(tasks))

	actual = 340 // sum(7..10) * 10, 7 + 8 + 9 + 10
	targetUsage = 340
	maxNum = 4
	tasks, cur = scheduleFunc(readyReadTasks, targetUsage, maxNum)
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, actual, cur)
	assert.Equal(t, 4, len(tasks))

	actual = 4995 * 10 // sum(11..100)
	targetUsage = actual
	maxNum = 90
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, actual, cur)
	assert.Equal(t, 90, len(tasks))
	assert.Equal(t, 0, policy.len())
}

func TestScheduler_userTaskPollingScheduleReadPolicy(t *testing.T) {
	policy := newUserTaskPollingScheduleReadPolicy(time.Minute)
	testBasicScheduleReadPolicy(t, policy)

	for i := 1; i <= 100; i++ {
		policy.addTask(&mockReadTask{
			cpuUsage: int32(i * 10),
			mockTask: mockTask{
				baseTask: baseTask{
					ctx: getContextWithAuthorization(context.Background(), fmt.Sprintf("user%d:123456", i%10)),
				},
			},
		})
	}
	targetUsage := int32(100)
	maxNum := int32(2)

	tasks, cur := policy.schedule(targetUsage, maxNum)
	assert.Equal(t, int32(30), cur)
	assert.Equal(t, int32(2), int32(len(tasks)))
	assert.Equal(t, 98, policy.len())

	targetUsage = 300
	maxNum = 0
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, int32(0), cur)
	assert.Equal(t, 0, len(tasks))
	assert.Equal(t, 98, policy.len())

	targetUsage = 0
	maxNum = 0
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, int32(0), cur)
	assert.Equal(t, 0, len(tasks))
	assert.Equal(t, 98, policy.len())

	targetUsage = 0
	maxNum = 300
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, int32(0), cur)
	assert.Equal(t, 0, len(tasks))
	assert.Equal(t, 98, policy.len())

	actual := int32(180)     // sum(3..6) * 10, 3 + 4 + 5 + 6
	targetUsage = int32(190) // > actual
	maxNum = math.MaxInt32
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, actual, cur)
	assert.Equal(t, 4, len(tasks))
	assert.Equal(t, 94, policy.len())

	actual = 340 // sum(7..10) * 10, 7 + 8 + 9 + 10
	targetUsage = 340
	maxNum = 4
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, actual, cur)
	assert.Equal(t, 4, len(tasks))
	assert.Equal(t, 90, policy.len())

	actual = 4995 * 10 // sum(11..100)
	targetUsage = actual
	maxNum = 90
	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, actual, cur)
	assert.Equal(t, 90, len(tasks))
	assert.Equal(t, 0, policy.len())

	time.Sleep(time.Minute + time.Second)
	policy.addTask(&mockReadTask{
		cpuUsage: int32(1),
		mockTask: mockTask{
			baseTask: baseTask{
				ctx: getContextWithAuthorization(context.Background(), fmt.Sprintf("user%d:123456", 11)),
			},
		},
	})

	tasks, cur = policy.schedule(targetUsage, maxNum)
	assert.Equal(t, int32(1), cur)
	assert.Equal(t, 1, len(tasks))
	assert.Equal(t, 0, policy.len())
	policyInner := policy.(*userTaskPollingScheduleReadPolicy)
	assert.Equal(t, 1, len(policyInner.route))
	assert.Equal(t, 1, policyInner.checkpoint.Len())
}

func Test_userBasedTaskQueue(t *testing.T) {
	n := 50
	q := newUserBasedTaskQueue("test_user")
	for i := 1; i <= n; i++ {
		q.push(&mockReadTask{
			cpuUsage: int32(i * 10),
			mockTask: mockTask{
				baseTask: baseTask{
					ctx: getContextWithAuthorization(context.Background(), "default:123456"),
				},
			},
		})
		assert.Equal(t, q.len(), i)
		assert.Equal(t, q.expire(time.Second), false)
	}

	for i := 0; i < n; i++ {
		q.pop()
		assert.Equal(t, q.len(), n-(i+1))
		assert.Equal(t, q.expire(time.Second), false)
	}

	time.Sleep(time.Second)
	assert.Equal(t, q.expire(time.Second), true)
}

func testBasicScheduleReadPolicy(t *testing.T, policy scheduleReadPolicy) {
	// Test: push and schedule.
	for i := 1; i <= 50; i++ {
		cpuUsage := int32(i * 10)
		id := 1

		policy.addTask(&mockReadTask{
			cpuUsage: cpuUsage,
			mockTask: mockTask{
				baseTask: baseTask{
					ctx: getContextWithAuthorization(context.Background(), "default:123456"),
					id:  UniqueID(id),
				},
			},
		})
		assert.Equal(t, policy.len(), 1)
		task, cost := policy.schedule(cpuUsage, 1)
		assert.Equal(t, cost, cpuUsage)
		assert.Equal(t, len(task), 1)
		assert.Equal(t, task[0].ID(), int64(id))
	}

	// Test: cannot merge, then schedule.
	cpuUsage := int32(100)
	notMergeTask := &mockReadTask{
		cpuUsage: cpuUsage,
		canMerge: false,
		mockTask: mockTask{
			baseTask: baseTask{
				ctx: getContextWithAuthorization(context.Background(), "default:123456"),
			},
		},
	}
	assert.False(t, policy.mergeTask(notMergeTask))
	policy.addTask(notMergeTask)
	assert.Equal(t, policy.len(), 1)
	task2 := &mockReadTask{
		cpuUsage: cpuUsage,
		canMerge: false,
		mockTask: mockTask{
			baseTask: baseTask{
				ctx: getContextWithAuthorization(context.Background(), "default:123456"),
			},
		},
	}
	assert.False(t, policy.mergeTask(task2))
	assert.Equal(t, policy.len(), 1)
	policy.addTask(notMergeTask)
	assert.Equal(t, policy.len(), 2)
	task, cost := policy.schedule(2*cpuUsage, 1)
	assert.Equal(t, cost, cpuUsage)
	assert.Equal(t, len(task), 1)
	assert.Equal(t, policy.len(), 1)
	assert.False(t, task[0].(*mockReadTask).merged)
	task, cost = policy.schedule(2*cpuUsage, 1)
	assert.Equal(t, cost, cpuUsage)
	assert.Equal(t, len(task), 1)
	assert.Equal(t, policy.len(), 0)
	assert.False(t, task[0].(*mockReadTask).merged)

	// Test: can merge, then schedule.
	mergeTask := &mockReadTask{
		cpuUsage: cpuUsage,
		canMerge: true,
		mockTask: mockTask{
			baseTask: baseTask{
				ctx: getContextWithAuthorization(context.Background(), "default:123456"),
			},
		},
	}
	policy.addTask(mergeTask)
	task2 = &mockReadTask{
		cpuUsage: cpuUsage,
		mockTask: mockTask{
			baseTask: baseTask{
				ctx: getContextWithAuthorization(context.Background(), "default:123456"),
			},
		},
	}
	assert.True(t, policy.mergeTask(task2))
	assert.Equal(t, policy.len(), 1)
	task, cost = policy.schedule(cpuUsage, 1)
	assert.Equal(t, cost, cpuUsage)
	assert.Equal(t, len(task), 1)
	assert.Equal(t, policy.len(), 0)
	assert.True(t, task[0].(*mockReadTask).merged)
}

func getContextWithAuthorization(ctx context.Context, originValue string) context.Context {
	authKey := strings.ToLower(util.HeaderAuthorize)
	authValue := crypto.Base64Encode(originValue)
	contextMap := map[string]string{
		authKey: authValue,
	}
	md := metadata.New(contextMap)
	return metadata.NewIncomingContext(ctx, md)
}

@@ -33,6 +33,7 @@ import (
	"github.com/milvus-io/milvus/internal/metrics"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/util/contextutil"
	"github.com/milvus-io/milvus/internal/util/funcutil"
	"github.com/milvus-io/milvus/internal/util/typeutil"
	"github.com/samber/lo"

@@ -127,9 +128,11 @@ type ShardSegmentDetector interface {
type ShardNodeBuilder func(nodeID int64, addr string) shardQueryNode

// withStreaming function types let search detect that the corresponding streaming search is done.
type searchWithStreaming func(ctx context.Context) (error, *internalpb.SearchResults)
type queryWithStreaming func(ctx context.Context) (error, *internalpb.RetrieveResults)
type getStatisticsWithStreaming func(ctx context.Context) (error, *internalpb.GetStatisticsResponse)
type (
	searchWithStreaming        func(ctx context.Context) (error, *internalpb.SearchResults)
	queryWithStreaming         func(ctx context.Context) (error, *internalpb.RetrieveResults)
	getStatisticsWithStreaming func(ctx context.Context) (error, *internalpb.GetStatisticsResponse)
)

// ShardCluster maintains the ShardCluster information and performs shard-level operations
type ShardCluster struct {

@@ -156,7 +159,8 @@ type ShardCluster struct {

// NewShardCluster creates a ShardCluster with the provided information.
func NewShardCluster(collectionID int64, replicaID int64, vchannelName string, version int64,
	nodeDetector ShardNodeDetector, segmentDetector ShardSegmentDetector, nodeBuilder ShardNodeBuilder) *ShardCluster {
	nodeDetector ShardNodeDetector, segmentDetector ShardSegmentDetector, nodeBuilder ShardNodeBuilder,
) *ShardCluster {
	sc := &ShardCluster{
		state: atomic.NewInt32(int32(unavailable)),

@@ -933,6 +937,13 @@ func getSearchWithStreamingFunc(searchCtx context.Context, req *querypb.SearchRe
			metrics.SearchLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
		metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
			metrics.SearchLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))
		// In-queue latency per user.
		metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues(
			fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.SearchLabel,
			contextutil.GetUserFromGrpcMetadata(streamingTask.Ctx()),
		).Observe(float64(streamingTask.queueDur.Milliseconds()))

		return nil, streamingTask.Ret
	}
}

@@ -943,7 +954,6 @@ func getQueryWithStreamingFunc(queryCtx context.Context, req *querypb.QueryReque
		streamingTask.DataScope = querypb.DataScope_Streaming
		streamingTask.QS = qs
		err := node.scheduler.AddReadTask(queryCtx, streamingTask)

		if err != nil {
			return err, nil
		}

@@ -955,6 +965,13 @@ func getQueryWithStreamingFunc(queryCtx context.Context, req *querypb.QueryReque
			metrics.QueryLabel).Observe(float64(streamingTask.queueDur.Milliseconds()))
		metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(nodeID),
			metrics.QueryLabel).Observe(float64(streamingTask.reduceDur.Milliseconds()))

		// In-queue latency per user.
		metrics.QueryNodeSQPerUserLatencyInQueue.WithLabelValues(
			fmt.Sprint(Params.QueryNodeCfg.GetNodeID()),
			metrics.QueryLabel,
			contextutil.GetUserFromGrpcMetadata(streamingTask.Ctx()),
		).Observe(float64(streamingTask.queueDur.Milliseconds()))
		return nil, streamingTask.Ret
	}
}

@@ -40,7 +40,6 @@ type taskScheduler struct {

	// for search and query start
	unsolvedReadTasks *list.List
	readyReadTasks    *list.List

	receiveReadTaskChan chan readTask
	executeReadTaskChan chan readTask

@@ -50,7 +49,8 @@ type taskScheduler struct {
	// tSafeReplica
	tSafeReplica TSafeReplicaInterface

	schedule scheduleReadTaskPolicy
	// schedule policy for ready read tasks.
	readySchedulePolicy scheduleReadPolicy
	// for search and query end

	cpuUsage int32 // 1200 means 1200%, 12 cores

@@ -77,13 +77,12 @@ func newTaskScheduler(ctx context.Context, tSafeReplica TSafeReplicaInterface) *
		ctx:                 ctx1,
		cancel:              cancel,
		unsolvedReadTasks:   list.New(),
		readyReadTasks:      list.New(),
		receiveReadTaskChan: make(chan readTask, Params.QueryNodeCfg.MaxReceiveChanSize),
		executeReadTaskChan: make(chan readTask, maxExecuteReadChanLen),
		notifyChan:          make(chan struct{}, 1),
		tSafeReplica:        tSafeReplica,
		maxCPUUsage:         int32(getNumCPU() * 100),
		schedule:            defaultScheduleReadPolicy,
		readySchedulePolicy: newReadScheduleTaskPolicy(Params.QueryNodeCfg.ScheduleReadPolicy.Name),
	}
	s.queue = newQueryNodeTaskQueue(s)
	return s

@@ -248,7 +247,7 @@ func (s *taskScheduler) AddReadTask(ctx context.Context, t readTask) error {
func (s *taskScheduler) popAndAddToExecute() {
	readConcurrency := atomic.LoadInt32(&s.readConcurrency)
	metrics.QueryNodeReadTaskConcurrency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(readConcurrency))
	if s.readyReadTasks.Len() == 0 {
	if s.readySchedulePolicy.len() == 0 {
		return
	}
	curUsage := atomic.LoadInt32(&s.cpuUsage)

@@ -266,7 +265,7 @@ func (s *taskScheduler) popAndAddToExecute() {
		return
	}

	tasks, deltaUsage := s.schedule(s.readyReadTasks, targetUsage, remain)
	tasks, deltaUsage := s.readySchedulePolicy.schedule(targetUsage, remain)
	atomic.AddInt32(&s.cpuUsage, deltaUsage)
	for _, t := range tasks {
		s.executeReadTaskChan <- t

@@ -366,23 +365,12 @@ func (s *taskScheduler) tryMergeReadTasks() {
		}
		if ready {
			if !Params.QueryNodeCfg.GroupEnabled {
				s.readyReadTasks.PushBack(t)
				s.readySchedulePolicy.addTask(t)
				rateCol.rtCounter.add(t, readyQueueType)
			} else {
				merged := false
				for m := s.readyReadTasks.Back(); m != nil; m = m.Prev() {
					mTask, ok := m.Value.(readTask)
					if !ok {
						continue
					}
					if mTask.CanMergeWith(t) {
						mTask.Merge(t)
						merged = true
						break
					}
				}
				if !merged {
					s.readyReadTasks.PushBack(t)
				// Try to merge the task first, otherwise add it.
				if !s.readySchedulePolicy.mergeTask(t) {
					s.readySchedulePolicy.addTask(t)
					rateCol.rtCounter.add(t, readyQueueType)
				}
			}

@@ -391,5 +379,5 @@ func (s *taskScheduler) tryMergeReadTasks() {
		}
	}
	metrics.QueryNodeReadTaskUnsolveLen.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(s.unsolvedReadTasks.Len()))
	metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(s.readyReadTasks.Len()))
	metrics.QueryNodeReadTaskReadyLen.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(s.readySchedulePolicy.len()))
}

@@ -71,6 +71,7 @@ type mockReadTask struct {
	timeoutError error
	step         typeutil.TaskStep
	readyError   error
	merged       bool
}

func (m *mockReadTask) GetCollectionID() UniqueID {

@@ -82,7 +83,7 @@ func (m *mockReadTask) Ready() (bool, error) {
}

func (m *mockReadTask) Merge(o readTask) {

	m.merged = true
}

func (m *mockReadTask) CPUUsage() int32 {

@@ -38,6 +38,8 @@ const (

	HeaderAuthorize = "authorization"
	HeaderDBName    = "dbName"
	// HeaderUser is the key of the username in grpc metadata.
	HeaderUser = "user"
	// HeaderSourceID identifies requests from Milvus members and client requests
	HeaderSourceID = "sourceId"
	// MemberCredID id for Milvus members (data/index/query node/coord component)

@@ -16,7 +16,14 @@

package contextutil

import "context"
import (
	"context"
	"strings"

	"github.com/milvus-io/milvus/internal/util"
	"github.com/milvus-io/milvus/internal/util/crypto"
	"google.golang.org/grpc/metadata"
)

type ctxTenantKey struct{}

@@ -37,3 +44,47 @@ func TenantID(ctx context.Context) string {

	return ""
}

// Passthrough the "user" field in grpc metadata from incoming to outgoing.
func PassthroughUserInGrpcMetadata(ctx context.Context) context.Context {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ctx
	}
	user := md[strings.ToLower(util.HeaderUser)]
	if len(user) == 0 || len(user[0]) == 0 {
		return ctx
	}
	return metadata.AppendToOutgoingContext(ctx, util.HeaderUser, user[0])
}

// Set the "user" field in grpc outgoing metadata.
func WithUserInGrpcMetadata(ctx context.Context, user string) context.Context {
	return metadata.AppendToOutgoingContext(ctx, util.HeaderUser, crypto.Base64Encode(user))
}

// Get the "user" field from grpc metadata; an empty string is returned if it is not set.
func GetUserFromGrpcMetadata(ctx context.Context) string {
	if user := getUserFromGrpcMetadataAux(ctx, metadata.FromIncomingContext); user != "" {
		return user
	}
	return getUserFromGrpcMetadataAux(ctx, metadata.FromOutgoingContext)
}

// Aux function for `GetUserFromGrpcMetadata`.
func getUserFromGrpcMetadataAux(ctx context.Context, mdGetter func(ctx context.Context) (metadata.MD, bool)) string {
	md, ok := mdGetter(ctx)
	if !ok {
		return ""
	}
	user := md[strings.ToLower(util.HeaderUser)]
	if len(user) == 0 || len(user[0]) == 0 {
		return ""
	}
	// It may be duplicated in the metadata, but it should always be the same.
	rawuser, err := crypto.Base64Decode(user[0])
	if err != nil {
		return ""
	}
	return rawuser
}

internal/util/contextutil/context_util_test.go (new file, 29 lines)
@@ -0,0 +1,29 @@
package contextutil

import (
	"context"
	"testing"

	"github.com/milvus-io/milvus/internal/util"
	"github.com/milvus-io/milvus/internal/util/crypto"
	"github.com/stretchr/testify/assert"
	"google.golang.org/grpc/metadata"
)

// Test the UserInGrpcMetadata related functions.
func Test_UserInGrpcMetadata(t *testing.T) {
	testUser := "test_user_131"
	ctx := context.Background()
	assert.Equal(t, GetUserFromGrpcMetadata(ctx), "")

	ctx = WithUserInGrpcMetadata(ctx, testUser)
	assert.Equal(t, GetUserFromGrpcMetadata(ctx), testUser)

	md := metadata.Pairs(util.HeaderUser, crypto.Base64Encode(testUser))
	ctx = metadata.NewIncomingContext(context.Background(), md)
	assert.Equal(t, getUserFromGrpcMetadataAux(ctx, metadata.FromIncomingContext), testUser)

	ctx = PassthroughUserInGrpcMetadata(ctx)
	assert.Equal(t, getUserFromGrpcMetadataAux(ctx, metadata.FromOutgoingContext), testUser)
	assert.Equal(t, GetUserFromGrpcMetadata(ctx), testUser)
}

@@ -34,11 +34,11 @@ const (

	// DefaultIndexSliceSize defines the default slice size of index file when serializing.
	DefaultIndexSliceSize        = 16
	DefaultGracefulTime          = 5000 //ms
	DefaultGracefulTime          = 5000 // ms
	DefaultGracefulStopTimeout   = 30   // s
	DefaultThreadCoreCoefficient = 10

	DefaultSessionTTL        = 20 //s
	DefaultSessionTTL        = 20 // s
	DefaultSessionRetryTimes = 30

	DefaultMaxDegree = 56

@@ -1117,6 +1117,9 @@ type queryNodeConfig struct {
	CPURatio        float64
	MaxTimestampLag time.Duration

	// schedule
	ScheduleReadPolicy queryNodeConfigScheduleReadPolicy

	GCHelperEnabled   bool
	MinimumGOGCConfig int
	MaximumGOGCConfig int

@@ -1124,6 +1127,15 @@ type queryNodeConfig struct {
	GracefulStopTimeout int64
}

type queryNodeConfigScheduleReadPolicy struct {
	Name            string
	UserTaskPolling queryNodeConfigScheduleUserTaskPolling
}

type queryNodeConfigScheduleUserTaskPolling struct {
	TaskQueueExpire time.Duration
}

func (p *queryNodeConfig) init(base *BaseTable) {
	p.Base = base
	p.NodeID.Store(UniqueID(0))

@@ -1153,6 +1165,9 @@ func (p *queryNodeConfig) init(base *BaseTable) {
	p.initDiskCapacity()
	p.initMaxDiskUsagePercentage()

	// Initialize the scheduler.
	p.initScheduleReadPolicy()

	p.initGCTunerEnbaled()
	p.initMaximumGOGC()
	p.initMinimumGOGC()

@@ -1346,6 +1361,11 @@ func (p *queryNodeConfig) initDiskCapacity() {
	p.DiskCapacityLimit = diskSize * 1024 * 1024 * 1024
}

func (p *queryNodeConfig) initScheduleReadPolicy() {
	p.ScheduleReadPolicy.Name = p.Base.LoadWithDefault("queryNode.scheduler.scheduleReadPolicy.name", "fifo")
	p.ScheduleReadPolicy.UserTaskPolling.TaskQueueExpire = time.Duration(p.Base.ParseInt64WithDefault("queryNode.scheduler.scheduleReadPolicy.taskQueueExpire", 60)) * time.Second
}

func (p *queryNodeConfig) initGCTunerEnbaled() {
	p.GCHelperEnabled = p.Base.ParseBool("queryNode.gchelper.enabled", true)
}

@@ -1514,7 +1534,6 @@ func (p *dataCoordConfig) initSegmentMinSizeFromIdleToSealed() {
func (p *dataCoordConfig) initSegmentExpansionRate() {
	p.SegmentExpansionRate = p.Base.ParseFloatWithDefault("dataCoord.segment.expansionRate", 1.25)
	log.Info("init segment expansion rate", zap.Float64("value", p.SegmentExpansionRate))

}

func (p *dataCoordConfig) initSegmentMaxBinlogFileNumber() {

@@ -1647,7 +1666,7 @@ type dataNodeConfig struct {
	Base *BaseTable

	// ID of the current node
	//NodeID atomic.Value
	// NodeID atomic.Value
	NodeID atomic.Value
	// IP of the current DataNode
	IP string