From 3cc59a0d69958dd13cc92bd6efe7a19252e55b38 Mon Sep 17 00:00:00 2001 From: Gao Date: Tue, 30 Sep 2025 00:31:06 +0800 Subject: [PATCH] enhance: add storage usage for delete/upsert/restful (#44512) #44212 Also, record metrics only when storageUsageTracking is enabled. Use MB for scanned_remote counter and scanned_total counter metrics to avoid overflow. --------- Signed-off-by: chasingegg --- .../distributed/proxy/httpserver/constant.go | 4 + .../proxy/httpserver/handler_v2.go | 146 +++++++++++++++--- internal/proxy/impl.go | 98 +++++++----- internal/proxy/task_delete.go | 15 ++ internal/proxy/task_upsert.go | 22 +-- internal/proxy/task_upsert_test.go | 14 +- internal/proxy/util.go | 42 +++++ internal/proxy/util_test.go | 82 ++++++++++ pkg/metrics/proxy_metrics.go | 84 ++++++---- 9 files changed, 399 insertions(+), 108 deletions(-) diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index 8a07e6185b..34895ac293 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -132,6 +132,10 @@ const ( HTTPReturnHas = "has" + HTTPReturnScannedRemoteBytes = "scanned_remote_bytes" + HTTPReturnScannedTotalBytes = "scanned_total_bytes" + HTTPReturnCacheHitRatio = "cache_hit_ratio" + HTTPReturnFieldName = "name" HTTPReturnFieldID = "id" HTTPReturnFieldType = "type" diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index fd3dd26978..60eea4d422 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -860,11 +860,23 @@ func (h *HandlersV2) query(ctx context.Context, c *gin.Context, anyReq any, dbNa HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - HTTPReturnStream(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(nil), - HTTPReturnData: outputData, - HTTPReturnCost: proxy.GetCostValue(queryResp.GetStatus()), - }) + scannedRemoteBytes, scannedTotalBytes, cacheHitRatio, isValid := proxy.GetStorageCost(queryResp.GetStatus()) + if proxy.Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() && isValid { + HTTPReturnStream(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: outputData, + HTTPReturnCost: proxy.GetCostValue(queryResp.GetStatus()), + HTTPReturnScannedRemoteBytes: scannedRemoteBytes, + HTTPReturnScannedTotalBytes: scannedTotalBytes, + HTTPReturnCacheHitRatio: cacheHitRatio, + }) + } else { + HTTPReturnStream(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: outputData, + HTTPReturnCost: proxy.GetCostValue(queryResp.GetStatus()), + }) + } } } return resp, err @@ -916,11 +928,23 @@ func (h *HandlersV2) get(ctx context.Context, c *gin.Context, anyReq any, dbName HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - HTTPReturnStream(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(nil), - HTTPReturnData: outputData, - HTTPReturnCost: proxy.GetCostValue(queryResp.GetStatus()), - }) + scannedRemoteBytes, scannedTotalBytes, cacheHitRatio, isValid := proxy.GetStorageCost(queryResp.GetStatus()) + if proxy.Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() && isValid { + HTTPReturnStream(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: outputData, + HTTPReturnCost: proxy.GetCostValue(queryResp.GetStatus()), + HTTPReturnScannedRemoteBytes: scannedRemoteBytes, + HTTPReturnScannedTotalBytes: scannedTotalBytes, + HTTPReturnCacheHitRatio: cacheHitRatio, + }) + } else { + HTTPReturnStream(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: outputData, + HTTPReturnCost: proxy.GetCostValue(queryResp.GetStatus()), + }) + } } } return resp, err @@ -1082,28 +1106,62 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN if err == nil { upsertResp := resp.(*milvuspb.MutationResult) cost := proxy.GetCostValue(upsertResp.GetStatus()) + scannedRemoteBytes, scannedTotalBytes, cacheHitRatio, isValid := proxy.GetStorageCost(upsertResp.GetStatus()) switch upsertResp.IDs.GetIdField().(type) { case *schemapb.IDs_IntId: allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) if allowJS { + if proxy.Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() && isValid { + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}, + HTTPReturnCost: cost, + HTTPReturnScannedRemoteBytes: scannedRemoteBytes, + HTTPReturnScannedTotalBytes: scannedTotalBytes, + HTTPReturnCacheHitRatio: cacheHitRatio, + }) + } else { + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}, + HTTPReturnCost: cost, + }) + } + } else { + if proxy.Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() && isValid { + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": formatInt64(upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}, + HTTPReturnCost: cost, + HTTPReturnScannedRemoteBytes: scannedRemoteBytes, + HTTPReturnScannedTotalBytes: scannedTotalBytes, + HTTPReturnCacheHitRatio: cacheHitRatio, + }) + } else { + HTTPReturn(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": formatInt64(upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}, + HTTPReturnCost: cost, + }) + } + } + case *schemapb.IDs_StrId: + if proxy.Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() && isValid { HTTPReturn(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(nil), - HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}, - HTTPReturnCost: cost, + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}, + HTTPReturnCost: cost, + HTTPReturnScannedRemoteBytes: scannedRemoteBytes, + HTTPReturnScannedTotalBytes: scannedTotalBytes, + HTTPReturnCacheHitRatio: cacheHitRatio, }) } else { HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(nil), - HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": formatInt64(upsertResp.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}, + HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}, HTTPReturnCost: cost, }) } - case *schemapb.IDs_StrId: - HTTPReturn(c, http.StatusOK, gin.H{ - HTTPReturnCode: merr.Code(nil), - HTTPReturnData: gin.H{"upsertCount": upsertResp.UpsertCnt, "upsertIds": upsertResp.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}, - HTTPReturnCost: cost, - }) default: HTTPReturn(c, http.StatusOK, gin.H{ HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), @@ -1241,6 +1299,7 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN if err == nil { searchResp := resp.(*milvuspb.SearchResults) cost := proxy.GetCostValue(searchResp.GetStatus()) + scannedRemoteBytes, scannedTotalBytes, cacheHitRatio, isValid := proxy.GetStorageCost(searchResp.GetStatus()) if searchResp.Results.TopK == int64(0) { HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: []interface{}{}, HTTPReturnCost: cost}) } else { @@ -1254,9 +1313,45 @@ func (h *HandlersV2) search(ctx context.Context, c *gin.Context, anyReq any, dbN }) } else { if len(searchResp.Results.Recalls) > 0 { - HTTPReturnStream(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: outputData, HTTPReturnCost: cost, HTTPReturnRecalls: searchResp.Results.Recalls, HTTPReturnTopks: searchResp.Results.Topks}) + if proxy.Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() && isValid { + HTTPReturnStream(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: outputData, + HTTPReturnCost: cost, + HTTPReturnRecalls: searchResp.Results.Recalls, + HTTPReturnTopks: searchResp.Results.Topks, + HTTPReturnScannedRemoteBytes: scannedRemoteBytes, + HTTPReturnScannedTotalBytes: scannedTotalBytes, + HTTPReturnCacheHitRatio: cacheHitRatio, + }) + } else { + HTTPReturnStream(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: outputData, + HTTPReturnCost: cost, + HTTPReturnRecalls: searchResp.Results.Recalls, + HTTPReturnTopks: searchResp.Results.Topks, + }) + } } else { - HTTPReturnStream(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: outputData, HTTPReturnCost: cost, HTTPReturnTopks: searchResp.Results.Topks}) + if proxy.Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() && isValid { + HTTPReturnStream(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: outputData, + HTTPReturnCost: cost, + HTTPReturnTopks: searchResp.Results.Topks, + HTTPReturnScannedRemoteBytes: scannedRemoteBytes, + HTTPReturnScannedTotalBytes: scannedTotalBytes, + HTTPReturnCacheHitRatio: cacheHitRatio, + }) + } else { + HTTPReturnStream(c, http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(nil), + HTTPReturnData: outputData, + HTTPReturnCost: cost, + HTTPReturnTopks: searchResp.Results.Topks, + }) + } } } } @@ -1355,6 +1450,7 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq if err == nil { searchResp := resp.(*milvuspb.SearchResults) cost := proxy.GetCostValue(searchResp.GetStatus()) + scannedRemoteBytes, scannedTotalBytes, cacheHitRatio, isValid := proxy.GetStorageCost(searchResp.GetStatus()) if searchResp.Results.TopK == int64(0) { HTTPReturn(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: []interface{}{}, HTTPReturnCost: cost}) } else { @@ -1367,7 +1463,11 @@ func (h *HandlersV2) advancedSearch(ctx context.Context, c *gin.Context, anyReq HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), }) } else { - HTTPReturnStream(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: outputData, HTTPReturnCost: cost, HTTPReturnTopks: searchResp.Results.Topks}) + if proxy.Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() && isValid { + HTTPReturnStream(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: outputData, HTTPReturnCost: cost, HTTPReturnTopks: searchResp.Results.Topks, HTTPReturnScannedRemoteBytes: scannedRemoteBytes, HTTPReturnScannedTotalBytes: scannedTotalBytes, HTTPReturnCacheHitRatio: cacheHitRatio}) + } else { + HTTPReturnStream(c, http.StatusOK, gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: outputData, HTTPReturnCost: cost, HTTPReturnTopks: searchResp.Results.Topks}) + } } } } diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 3b50bf3bc1..6f60eaee8e 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2491,6 +2491,16 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest) }) SetReportValue(dr.result.GetStatus(), v) + if Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() { + metrics.ProxyScannedRemoteMB.WithLabelValues(nodeID, metrics.DeleteLabel, dbName, collectionName).Add(float64(dr.scannedRemoteBytes.Load()) / 1024 / 1024) + metrics.ProxyScannedTotalMB.WithLabelValues(nodeID, metrics.DeleteLabel, dbName, collectionName).Add(float64(dr.scannedTotalBytes.Load()) / 1024 / 1024) + } + + SetStorageCost(dr.result.GetStatus(), segcore.StorageCost{ + ScannedRemoteBytes: dr.scannedRemoteBytes.Load(), + ScannedTotalBytes: dr.scannedTotalBytes.Load(), + }) + if merr.Ok(dr.result.GetStatus()) { metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeDelete, dbName, username).Add(float64(v)) } @@ -2621,6 +2631,11 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) hookutil.FailCntKey: len(it.result.ErrIndex), }) SetReportValue(it.result.GetStatus(), v) + SetStorageCost(it.result.GetStatus(), it.storageCost) + if Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() { + metrics.ProxyScannedRemoteMB.WithLabelValues(nodeID, metrics.UpsertLabel, dbName, collectionName).Add(float64(it.storageCost.ScannedRemoteBytes) / 1024 / 1024) + metrics.ProxyScannedTotalMB.WithLabelValues(nodeID, metrics.UpsertLabel, dbName, collectionName).Add(float64(it.storageCost.ScannedTotalBytes) / 1024 / 1024) + } if merr.Ok(it.result.GetStatus()) { metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeUpsert, dbName, username).Add(float64(v)) } @@ -2881,19 +2896,21 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, collectionName, ).Observe(float64(searchDur)) - metrics.ProxyScannedRemoteBytes.WithLabelValues( - nodeID, - metrics.SearchLabel, - dbName, - collectionName, - ).Add(float64(qt.storageCost.ScannedRemoteBytes)) + if Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() { + metrics.ProxyScannedRemoteMB.WithLabelValues( + nodeID, + metrics.SearchLabel, + dbName, + collectionName, + ).Add(float64(qt.storageCost.ScannedRemoteBytes) / 1024 / 1024) - metrics.ProxyScannedTotalBytes.WithLabelValues( - nodeID, - metrics.SearchLabel, - dbName, - collectionName, - ).Add(float64(qt.storageCost.ScannedTotalBytes)) + metrics.ProxyScannedTotalMB.WithLabelValues( + nodeID, + metrics.SearchLabel, + dbName, + collectionName, + ).Add(float64(qt.storageCost.ScannedTotalBytes) / 1024 / 1024) + } if qt.result != nil { username := GetCurUserFromContextOrDefault(ctx) @@ -3103,19 +3120,21 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea collectionName, ).Observe(float64(searchDur)) - metrics.ProxyScannedRemoteBytes.WithLabelValues( - nodeID, - metrics.HybridSearchLabel, - dbName, - collectionName, - ).Add(float64(qt.storageCost.ScannedRemoteBytes)) + if Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() { + metrics.ProxyScannedRemoteMB.WithLabelValues( + nodeID, + metrics.HybridSearchLabel, + dbName, + collectionName, + ).Add(float64(qt.storageCost.ScannedRemoteBytes) / 1024 / 1024) - metrics.ProxyScannedTotalBytes.WithLabelValues( - nodeID, - metrics.HybridSearchLabel, - dbName, - collectionName, - ).Add(float64(qt.storageCost.ScannedTotalBytes)) + metrics.ProxyScannedTotalMB.WithLabelValues( + nodeID, + metrics.HybridSearchLabel, + dbName, + collectionName, + ).Add(float64(qt.storageCost.ScannedTotalBytes) / 1024 / 1024) + } if qt.result != nil { sentSize := proto.Size(qt.result) @@ -3358,20 +3377,6 @@ func (node *Proxy) query(ctx context.Context, qt *queryTask, sp trace.Span) (*mi request.DbName, request.CollectionName, ).Observe(float64(tr.ElapseSpan().Milliseconds())) - - metrics.ProxyScannedRemoteBytes.WithLabelValues( - strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.QueryLabel, - request.DbName, - request.CollectionName, - ).Add(float64(qt.storageCost.ScannedRemoteBytes)) - - metrics.ProxyScannedTotalBytes.WithLabelValues( - strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.QueryLabel, - request.DbName, - request.CollectionName, - ).Add(float64(qt.storageCost.ScannedTotalBytes)) } return qt.result, qt.storageCost, nil @@ -3422,6 +3427,23 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (* method := "Query" res, storageCost, err := node.query(ctx, qt, sp) + + if Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() { + metrics.ProxyScannedRemoteMB.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.QueryLabel, + request.DbName, + request.CollectionName, + ).Add(float64(qt.storageCost.ScannedRemoteBytes) / 1024 / 1024) + + metrics.ProxyScannedTotalMB.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.QueryLabel, + request.DbName, + request.CollectionName, + ).Add(float64(qt.storageCost.ScannedTotalBytes) / 1024 / 1024) + } + if err != nil || !merr.Ok(res.Status) { return res, err } diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index d20d440c1c..759d64ec40 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -19,6 +19,7 @@ import ( "github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/exprutil" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" @@ -64,6 +65,7 @@ type deleteTask struct { // result count int64 allQueryCnt int64 + storageCost segcore.StorageCost sessionTS Timestamp } @@ -255,6 +257,9 @@ type deleteRunner struct { allQueryCnt atomic.Int64 sessionTS atomic.Uint64 + + scannedRemoteBytes atomic.Int64 + scannedTotalBytes atomic.Int64 } func (dr *deleteRunner) Init(ctx context.Context) error { @@ -459,6 +464,8 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe close(taskCh) }() var allQueryCnt int64 + var scannedRemoteBytes int64 + var scannedTotalBytes int64 // wait all task finish var sessionTS uint64 for task := range taskCh { @@ -468,6 +475,8 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe } dr.count.Add(task.count) allQueryCnt += task.allQueryCnt + scannedRemoteBytes += task.storageCost.ScannedRemoteBytes + scannedTotalBytes += task.storageCost.ScannedTotalBytes if sessionTS < task.sessionTS { sessionTS = task.sessionTS } @@ -479,6 +488,8 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe } dr.allQueryCnt.Add(allQueryCnt) dr.sessionTS.Store(sessionTS) + dr.scannedRemoteBytes.Add(scannedRemoteBytes) + dr.scannedTotalBytes.Add(scannedTotalBytes) return nil } } @@ -522,6 +533,10 @@ func (dr *deleteRunner) receiveQueryResult(ctx context.Context, client querypb.Q return err } task.allQueryCnt = result.GetAllRetrieveCount() + task.storageCost = segcore.StorageCost{ + ScannedRemoteBytes: result.GetScannedRemoteBytes(), + ScannedTotalBytes: result.GetScannedTotalBytes(), + } taskCh <- task } diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index b7c2b61bd3..431810b247 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" @@ -76,6 +77,8 @@ type upsertTask struct { deletePKs *schemapb.IDs insertFieldData []*schemapb.FieldData + + storageCost segcore.StorageCost } // TraceCtx returns upsertTask context @@ -155,7 +158,7 @@ func (it *upsertTask) OnEnqueue() error { return nil } -func retrieveByPKs(ctx context.Context, t *upsertTask, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) { +func retrieveByPKs(ctx context.Context, t *upsertTask, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, segcore.StorageCost, error) { log := log.Ctx(ctx).With(zap.String("collectionName", t.req.GetCollectionName())) var err error queryReq := &milvuspb.QueryRequest{ @@ -173,7 +176,7 @@ func retrieveByPKs(ctx context.Context, t *upsertTask, ids *schemapb.IDs, output } pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema) if err != nil { - return nil, err + return nil, segcore.StorageCost{}, err } var partitionIDs []int64 @@ -188,12 +191,12 @@ func retrieveByPKs(ctx context.Context, t *upsertTask, ids *schemapb.IDs, output partName := t.upsertMsg.DeleteMsg.PartitionName if err := validatePartitionTag(partName, true); err != nil { log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err)) - return nil, err + return nil, segcore.StorageCost{}, err } partID, err := globalMetaCache.GetPartitionID(ctx, t.req.GetDbName(), t.req.GetCollectionName(), partName) if err != nil { log.Warn("Failed to get partition id", zap.String("partitionName", partName), zap.Error(err)) - return nil, err + return nil, segcore.StorageCost{}, err } partitionIDs = []int64{partID} queryReq.PartitionNames = []string{partName} @@ -222,12 +225,11 @@ func retrieveByPKs(ctx context.Context, t *upsertTask, ids *schemapb.IDs, output defer func() { sp.End() }() - // ignore storage cost? - queryResult, _, err := t.node.(*Proxy).query(ctx, qt, sp) + queryResult, storageCost, err := t.node.(*Proxy).query(ctx, qt, sp) if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil { - return nil, err + return nil, storageCost, err } - return queryResult, err + return queryResult, storageCost, err } func (it *upsertTask) queryPreExecute(ctx context.Context) error { @@ -261,12 +263,12 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error { tr := timerecord.NewTimeRecorder("Proxy-Upsert-retrieveByPKs") // retrieve by primary key to get original field data - resp, err := retrieveByPKs(ctx, it, upsertIDs, []string{"*"}) + resp, storageCost, err := retrieveByPKs(ctx, it, upsertIDs, []string{"*"}) if err != nil { log.Info("retrieve by primary key failed", zap.Error(err)) return err } - + it.storageCost = storageCost if len(resp.GetFieldsData()) == 0 { return merr.WrapErrParameterInvalidMsg("retrieve by primary key failed, no data found") } diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index ede82b06d0..a0b2bd17fa 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -695,7 +695,7 @@ func TestRetrieveByPKs_Success(t *testing.T) { }, } - result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"}) + result, _, err := retrieveByPKs(context.Background(), task, ids, []string{"*"}) // Verify results assert.NoError(t, err) @@ -717,7 +717,7 @@ func TestRetrieveByPKs_GetPrimaryFieldSchemaError(t *testing.T) { }, } - result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"}) + result, _, err := retrieveByPKs(context.Background(), task, ids, []string{"*"}) assert.Error(t, err) assert.Nil(t, result) @@ -750,7 +750,7 @@ func TestRetrieveByPKs_PartitionKeyMode(t *testing.T) { }, } - result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"}) + result, _, err := retrieveByPKs(context.Background(), task, ids, []string{"*"}) assert.NoError(t, err) assert.NotNil(t, result) @@ -829,7 +829,7 @@ func TestUpdateTask_queryPreExecute_Success(t *testing.T) { }, }, }, - }, nil).Build() + }, segcore.StorageCost{}, nil).Build() mockey.Mock(typeutil.NewIDsChecker).Return(&typeutil.IDsChecker{}, nil).Build() @@ -1148,7 +1148,7 @@ func TestUpsertTask_queryPreExecute_MixLogic(t *testing.T) { node: &Proxy{}, } - mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, nil).Build() + mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, segcore.StorageCost{}, nil).Build() defer mockRetrieve.UnPatch() err := task.queryPreExecute(context.Background()) @@ -1237,7 +1237,7 @@ func TestUpsertTask_queryPreExecute_PureInsert(t *testing.T) { node: &Proxy{}, } - mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, nil).Build() + mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, segcore.StorageCost{}, nil).Build() defer mockRetrieve.UnPatch() err := task.queryPreExecute(context.Background()) @@ -1325,7 +1325,7 @@ func TestUpsertTask_queryPreExecute_PureUpdate(t *testing.T) { node: &Proxy{}, } - mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, nil).Build() + mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, segcore.StorageCost{}, nil).Build() defer mockRetrieve.UnPatch() err := task.queryPreExecute(context.Background()) diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 862c403788..6fb43da272 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -2418,6 +2418,9 @@ func SetReportValue(status *commonpb.Status, value int) { } func SetStorageCost(status *commonpb.Status, storageCost segcore.StorageCost) { + if !Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() { + return + } if storageCost.ScannedTotalBytes <= 0 { return } @@ -2448,6 +2451,45 @@ func GetCostValue(status *commonpb.Status) int { return value } +// final return value means value is valid or not +func GetStorageCost(status *commonpb.Status) (int64, int64, float64, bool) { + if status == nil || status.ExtraInfo == nil { + return 0, 0, 0, false + } + var scannedRemoteBytes int64 + var scannedTotalBytes int64 + var cacheHitRatio float64 + var err error + if value, ok := status.ExtraInfo["scanned_remote_bytes"]; ok { + scannedRemoteBytes, err = strconv.ParseInt(value, 10, 64) + if err != nil { + log.Warn("scanned_remote_bytes is not a valid int64", zap.String("value", value), zap.Error(err)) + return 0, 0, 0, false + } + } else { + return 0, 0, 0, false + } + if value, ok := status.ExtraInfo["scanned_total_bytes"]; ok { + scannedTotalBytes, err = strconv.ParseInt(value, 10, 64) + if err != nil { + log.Warn("scanned_total_bytes is not a valid int64", zap.String("value", value), zap.Error(err)) + return 0, 0, 0, false + } + } else { + return 0, 0, 0, false + } + if value, ok := status.ExtraInfo["cache_hit_ratio"]; ok { + cacheHitRatio, err = strconv.ParseFloat(value, 64) + if err != nil { + log.Warn("cache_hit_ratio is not a valid float64", zap.String("value", value), zap.Error(err)) + return 0, 0, 0, false + } + } else { + return 0, 0, 0, false + } + return scannedRemoteBytes, scannedTotalBytes, cacheHitRatio, true +} + // GetRequestInfo returns collection name and rateType of request and return tokens needed. func GetRequestInfo(ctx context.Context, req proto.Message) (int64, map[int64][]int64, internalpb.RateType, int, error) { switch r := req.(type) { diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 19f8abee6f..15f3902a9f 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -4593,3 +4593,85 @@ func TestLackOfFieldsDataBySchema(t *testing.T) { }) } } + +func TestGetStorageCost(t *testing.T) { + // nil and empty cases + t.Run("nil or empty status", func(t *testing.T) { + { + remote, total, ratio, ok := GetStorageCost(nil) + assert.Equal(t, int64(0), remote) + assert.Equal(t, int64(0), total) + assert.Equal(t, 0.0, ratio) + assert.False(t, ok) + } + { + remote, total, ratio, ok := GetStorageCost(&commonpb.Status{}) + assert.Equal(t, int64(0), remote) + assert.Equal(t, int64(0), total) + assert.Equal(t, 0.0, ratio) + assert.False(t, ok) + } + }) + + // missing keys should result in zeros + t.Run("missing keys", func(t *testing.T) { + st := &commonpb.Status{ExtraInfo: map[string]string{ + "scanned_remote_bytes": "100", + }} + remote, total, ratio, ok := GetStorageCost(st) + assert.Equal(t, int64(0), remote) + assert.Equal(t, int64(0), total) + assert.Equal(t, 0.0, ratio) + assert.False(t, ok) + }) + + // invalid number formats should result in zeros + t.Run("invalid formats", func(t *testing.T) { + st := &commonpb.Status{ExtraInfo: map[string]string{ + "scanned_remote_bytes": "x", + "scanned_total_bytes": "200", + "cache_hit_ratio": "0.5", + }} + remote, total, ratio, ok := GetStorageCost(st) + assert.Equal(t, int64(0), remote) + assert.Equal(t, int64(0), total) + assert.Equal(t, 0.0, ratio) + assert.False(t, ok) + + st = &commonpb.Status{ExtraInfo: map[string]string{ + "scanned_remote_bytes": "100", + "scanned_total_bytes": "y", + "cache_hit_ratio": "0.5", + }} + remote, total, ratio, ok = GetStorageCost(st) + assert.Equal(t, int64(0), remote) + assert.Equal(t, int64(0), total) + assert.Equal(t, 0.0, ratio) + assert.False(t, ok) + + st = &commonpb.Status{ExtraInfo: map[string]string{ + "scanned_remote_bytes": "100", + "scanned_total_bytes": "200", + "cache_hit_ratio": "abc", + }} + remote, total, ratio, ok = GetStorageCost(st) + assert.Equal(t, int64(0), remote) + assert.Equal(t, int64(0), total) + assert.Equal(t, 0.0, ratio) + assert.False(t, ok) + }) + + // success case + t.Run("success", func(t *testing.T) { + st := &commonpb.Status{ExtraInfo: map[string]string{ + "scanned_remote_bytes": "123", + "scanned_total_bytes": "456", + "cache_hit_ratio": "0.27", + }} + remote, total, ratio, ok := GetStorageCost(st) + assert.Equal(t, int64(123), remote) + assert.Equal(t, int64(456), total) + assert.InDelta(t, 0.27, ratio, 1e-9) + assert.True(t, ok) + }) +} diff --git a/pkg/metrics/proxy_metrics.go b/pkg/metrics/proxy_metrics.go index 9eb7fe2420..a85a24e66a 100644 --- a/pkg/metrics/proxy_metrics.go +++ b/pkg/metrics/proxy_metrics.go @@ -465,21 +465,21 @@ var ( Buckets: buckets, }, []string{nodeIDLabelName, collectionName, functionTypeName, functionProvider, functionName}) - ProxyScannedRemoteBytes = prometheus.NewCounterVec( + ProxyScannedRemoteMB = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: milvusNamespace, Subsystem: typeutil.ProxyRole, - Name: "scanned_remote_bytes", - Help: "the scanned remote bytes", - }, []string{nodeIDLabelName, queryTypeLabelName, databaseLabelName, collectionName}) + Name: "scanned_remote_mb", + Help: "the scanned remote megabytes", + }, []string{nodeIDLabelName, msgTypeLabelName, databaseLabelName, collectionName}) - ProxyScannedTotalBytes = prometheus.NewCounterVec( + ProxyScannedTotalMB = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: milvusNamespace, Subsystem: typeutil.ProxyRole, - Name: "scanned_total_bytes", - Help: "the scanned total bytes", - }, []string{nodeIDLabelName, queryTypeLabelName, databaseLabelName, collectionName}) + Name: "scanned_total_mb", + Help: "the scanned total megabytes", + }, []string{nodeIDLabelName, msgTypeLabelName, databaseLabelName, collectionName}) ) // RegisterProxy registers Proxy metrics @@ -550,8 +550,8 @@ func RegisterProxy(registry *prometheus.Registry) { registry.MustRegister(ProxyFunctionlatency) - registry.MustRegister(ProxyScannedRemoteBytes) - registry.MustRegister(ProxyScannedTotalBytes) + registry.MustRegister(ProxyScannedRemoteMB) + registry.MustRegister(ProxyScannedTotalMB) RegisterStreamingServiceClient(registry) } @@ -714,28 +714,52 @@ func CleanupProxyCollectionMetrics(nodeID int64, dbName string, collection strin databaseLabelName: dbName, collectionName: collection, }) - ProxyScannedRemoteBytes.Delete(prometheus.Labels{ - nodeIDLabelName: strconv.FormatInt(nodeID, 10), - queryTypeLabelName: SearchLabel, - databaseLabelName: dbName, - collectionName: collection, + ProxyScannedRemoteMB.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: SearchLabel, + databaseLabelName: dbName, + collectionName: collection, }) - ProxyScannedRemoteBytes.Delete(prometheus.Labels{ - nodeIDLabelName: strconv.FormatInt(nodeID, 10), - queryTypeLabelName: QueryLabel, - databaseLabelName: dbName, - collectionName: collection, + ProxyScannedRemoteMB.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: QueryLabel, + databaseLabelName: dbName, + collectionName: collection, }) - ProxyScannedTotalBytes.Delete(prometheus.Labels{ - nodeIDLabelName: strconv.FormatInt(nodeID, 10), - queryTypeLabelName: SearchLabel, - databaseLabelName: dbName, - collectionName: collection, + ProxyScannedRemoteMB.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: UpsertLabel, + databaseLabelName: dbName, + collectionName: collection, }) - ProxyScannedTotalBytes.Delete(prometheus.Labels{ - nodeIDLabelName: strconv.FormatInt(nodeID, 10), - queryTypeLabelName: QueryLabel, - databaseLabelName: dbName, - collectionName: collection, + ProxyScannedRemoteMB.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: DeleteLabel, + databaseLabelName: dbName, + collectionName: collection, + }) + ProxyScannedTotalMB.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: SearchLabel, + databaseLabelName: dbName, + collectionName: collection, + }) + ProxyScannedTotalMB.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: QueryLabel, + databaseLabelName: dbName, + collectionName: collection, + }) + ProxyScannedTotalMB.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: UpsertLabel, + databaseLabelName: dbName, + collectionName: collection, + }) + ProxyScannedTotalMB.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + msgTypeLabelName: DeleteLabel, + databaseLabelName: dbName, + collectionName: collection, }) }