mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 17:48:29 +08:00
enhance: add metrics for counting number of nun-zeros/tokens of sparse/FTS search (#38329)
sparse vectors may have arbitrary number of non zeros and it is hard to optimize without knowing the actual distribution of nnz. this PR adds a metric for analyzing that. issue: https://github.com/milvus-io/milvus/issues/35853 comparing with https://github.com/milvus-io/milvus/pull/38328, this includes also metric for FTS in query node delegator also fixed a bug of sparse when searching by pk Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
This commit is contained in:
parent
b14a0c4bf5
commit
75e64b993f
@ -488,6 +488,7 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
metrics.ProxySearchSparseNumNonZeros.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), t.collectionName).Observe(float64(typeutil.EstimateSparseVectorNNZFromPlaceholderGroup(t.request.PlaceholderGroup, int(t.request.GetNq()))))
|
||||||
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
|
||||||
t.SearchRequest.Topk = queryInfo.GetTopk()
|
t.SearchRequest.Topk = queryInfo.GetTopk()
|
||||||
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
t.SearchRequest.MetricType = queryInfo.GetMetricType()
|
||||||
|
|||||||
@ -1037,6 +1037,10 @@ func (sd *shardDelegator) buildBM25IDF(req *internalpb.SearchRequest) (float64,
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, idf := range idfSparseVector {
|
||||||
|
metrics.QueryNodeSearchFTSNumTokens.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(sd.collectionID)).Observe(float64(typeutil.SparseFloatRowElementCount(idf)))
|
||||||
|
}
|
||||||
|
|
||||||
err = SetBM25Params(req, avgdl)
|
err = SetBM25Params(req, avgdl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
|||||||
@ -417,6 +417,16 @@ var (
|
|||||||
Name: "recall_search_cnt",
|
Name: "recall_search_cnt",
|
||||||
Help: "counter of recall search",
|
Help: "counter of recall search",
|
||||||
}, []string{nodeIDLabelName, queryTypeLabelName, collectionName})
|
}, []string{nodeIDLabelName, queryTypeLabelName, collectionName})
|
||||||
|
|
||||||
|
// ProxySearchSparseNumNonZeros records the estimated number of non-zeros in each sparse search task
|
||||||
|
ProxySearchSparseNumNonZeros = prometheus.NewHistogramVec(
|
||||||
|
prometheus.HistogramOpts{
|
||||||
|
Namespace: milvusNamespace,
|
||||||
|
Subsystem: typeutil.ProxyRole,
|
||||||
|
Name: "search_sparse_num_non_zeros",
|
||||||
|
Help: "the number of non-zeros in each sparse search task",
|
||||||
|
Buckets: buckets,
|
||||||
|
}, []string{nodeIDLabelName, collectionName})
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegisterProxy registers Proxy metrics
|
// RegisterProxy registers Proxy metrics
|
||||||
@ -479,6 +489,8 @@ func RegisterProxy(registry *prometheus.Registry) {
|
|||||||
registry.MustRegister(ProxyRetrySearchResultInsufficientCount)
|
registry.MustRegister(ProxyRetrySearchResultInsufficientCount)
|
||||||
registry.MustRegister(ProxyRecallSearchCount)
|
registry.MustRegister(ProxyRecallSearchCount)
|
||||||
|
|
||||||
|
registry.MustRegister(ProxySearchSparseNumNonZeros)
|
||||||
|
|
||||||
RegisterStreamingServiceClient(registry)
|
RegisterStreamingServiceClient(registry)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -352,6 +352,18 @@ var (
|
|||||||
nodeIDLabelName,
|
nodeIDLabelName,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
QueryNodeSearchFTSNumTokens = prometheus.NewHistogramVec(
|
||||||
|
prometheus.HistogramOpts{
|
||||||
|
Namespace: milvusNamespace,
|
||||||
|
Subsystem: typeutil.QueryNodeRole,
|
||||||
|
Name: "search_fts_num_tokens",
|
||||||
|
Help: "number of tokens in each Full Text Search search task",
|
||||||
|
Buckets: buckets,
|
||||||
|
}, []string{
|
||||||
|
nodeIDLabelName,
|
||||||
|
collectionIDLabelName,
|
||||||
|
})
|
||||||
|
|
||||||
QueryNodeSearchGroupSize = prometheus.NewHistogramVec(
|
QueryNodeSearchGroupSize = prometheus.NewHistogramVec(
|
||||||
prometheus.HistogramOpts{
|
prometheus.HistogramOpts{
|
||||||
Namespace: milvusNamespace,
|
Namespace: milvusNamespace,
|
||||||
@ -832,6 +844,7 @@ func RegisterQueryNode(registry *prometheus.Registry) {
|
|||||||
registry.MustRegister(QueryNodeEvictedReadReqCount)
|
registry.MustRegister(QueryNodeEvictedReadReqCount)
|
||||||
registry.MustRegister(QueryNodeSearchGroupTopK)
|
registry.MustRegister(QueryNodeSearchGroupTopK)
|
||||||
registry.MustRegister(QueryNodeSearchTopK)
|
registry.MustRegister(QueryNodeSearchTopK)
|
||||||
|
registry.MustRegister(QueryNodeSearchFTSNumTokens)
|
||||||
registry.MustRegister(QueryNodeNumFlowGraphs)
|
registry.MustRegister(QueryNodeNumFlowGraphs)
|
||||||
registry.MustRegister(QueryNodeNumEntities)
|
registry.MustRegister(QueryNodeNumEntities)
|
||||||
registry.MustRegister(QueryNodeEntitiesSize)
|
registry.MustRegister(QueryNodeEntitiesSize)
|
||||||
|
|||||||
@ -2,7 +2,6 @@ package funcutil
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
@ -97,14 +96,10 @@ func fieldDataToPlaceholderValue(fieldData *schemapb.FieldData) (*commonpb.Place
|
|||||||
return nil, errors.New("vector data is not schemapb.VectorField_SparseFloatVector")
|
return nil, errors.New("vector data is not schemapb.VectorField_SparseFloatVector")
|
||||||
}
|
}
|
||||||
vec := vectors.SparseFloatVector
|
vec := vectors.SparseFloatVector
|
||||||
bytes, err := proto.Marshal(vec)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to marshal schemapb.SparseFloatArray to bytes: %w", err)
|
|
||||||
}
|
|
||||||
placeholderValue := &commonpb.PlaceholderValue{
|
placeholderValue := &commonpb.PlaceholderValue{
|
||||||
Tag: "$0",
|
Tag: "$0",
|
||||||
Type: commonpb.PlaceholderType_SparseFloatVector,
|
Type: commonpb.PlaceholderType_SparseFloatVector,
|
||||||
Values: [][]byte{bytes},
|
Values: vec.Contents,
|
||||||
}
|
}
|
||||||
return placeholderValue, nil
|
return placeholderValue, nil
|
||||||
case schemapb.DataType_VarChar:
|
case schemapb.DataType_VarChar:
|
||||||
|
|||||||
@ -1919,3 +1919,11 @@ func SparseFloatRowDim(row []byte) int64 {
|
|||||||
}
|
}
|
||||||
return int64(SparseFloatRowIndexAt(row, SparseFloatRowElementCount(row)-1)) + 1
|
return int64(SparseFloatRowIndexAt(row, SparseFloatRowElementCount(row)-1)) + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// placeholderGroup is a serialized PlaceholderGroup, return estimated total
|
||||||
|
// number of non-zero elements of all the sparse vectors in the placeholderGroup
|
||||||
|
// This is a rough estimate, and should be used only for statistics.
|
||||||
|
func EstimateSparseVectorNNZFromPlaceholderGroup(placeholderGroup []byte, nq int) int {
|
||||||
|
overheadBytes := math.Max(10, float64(nq*3))
|
||||||
|
return (len(placeholderGroup) - int(overheadBytes)) / 8
|
||||||
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"math/rand"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@ -2714,3 +2715,67 @@ func TestParseJsonSparseFloatRowBytes(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test EstimateSparseVectorNNZFromPlaceholderGroup: given a PlaceholderGroup
|
||||||
|
// with various nq and averageNNZ, test if the estimated number of non-zero
|
||||||
|
// elements is close to the actual number.
|
||||||
|
func TestSparsePlaceholderGroupSize(t *testing.T) {
|
||||||
|
nqs := []int{1, 10, 100, 1000, 10000}
|
||||||
|
averageNNZs := []int{1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048}
|
||||||
|
numCases := 0
|
||||||
|
casesWithLargeError := 0
|
||||||
|
for _, nq := range nqs {
|
||||||
|
for _, averageNNZ := range averageNNZs {
|
||||||
|
variants := make([]int, 0)
|
||||||
|
for i := 1; i <= averageNNZ/2; i *= 2 {
|
||||||
|
variants = append(variants, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, variant := range variants {
|
||||||
|
numCases++
|
||||||
|
contents := make([][]byte, nq)
|
||||||
|
contentsSize := 0
|
||||||
|
totalNNZ := 0
|
||||||
|
for i := range contents {
|
||||||
|
// nnz of each row is in range [averageNNZ - variant/2, averageNNZ + variant/2] and at least 1.
|
||||||
|
nnz := averageNNZ + variant/2 + rand.Intn(variant)
|
||||||
|
if nnz < 1 {
|
||||||
|
nnz = 1
|
||||||
|
}
|
||||||
|
indices := make([]uint32, nnz)
|
||||||
|
values := make([]float32, nnz)
|
||||||
|
for j := 0; j < nnz; j++ {
|
||||||
|
indices[j] = uint32(i*averageNNZ + j)
|
||||||
|
values[j] = float32(i*averageNNZ + j)
|
||||||
|
}
|
||||||
|
contents[i] = CreateSparseFloatRow(indices, values)
|
||||||
|
contentsSize += len(contents[i])
|
||||||
|
totalNNZ += nnz
|
||||||
|
}
|
||||||
|
|
||||||
|
placeholderGroup := &commonpb.PlaceholderGroup{
|
||||||
|
Placeholders: []*commonpb.PlaceholderValue{
|
||||||
|
{
|
||||||
|
Tag: "$0",
|
||||||
|
Type: commonpb.PlaceholderType_SparseFloatVector,
|
||||||
|
Values: contents,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
bytes, _ := proto.Marshal(placeholderGroup)
|
||||||
|
estimatedNNZ := EstimateSparseVectorNNZFromPlaceholderGroup(bytes, nq)
|
||||||
|
errorRatio := (float64(totalNNZ-estimatedNNZ) / float64(totalNNZ)) * 100
|
||||||
|
assert.Less(t, errorRatio, 10.0)
|
||||||
|
if errorRatio > 5.0 {
|
||||||
|
casesWithLargeError++
|
||||||
|
}
|
||||||
|
// keep the logs for easy debugging.
|
||||||
|
// fmt.Printf("nq: %d, total nnz: %d, overhead bytes: %d, len of bytes: %d\n", nq, totalNNZ, len(bytes)-contentsSize, len(bytes))
|
||||||
|
// fmt.Printf("\tnq: %d, total nnz: %d, estimated nnz: %d, diff: %d, error ratio: %f%%\n", nq, totalNNZ, estimatedNNZ, totalNNZ-estimatedNNZ, errorRatio)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
largeErrorRatio := (float64(casesWithLargeError) / float64(numCases)) * 100
|
||||||
|
// no more than 2% cases have large error ratio.
|
||||||
|
assert.Less(t, largeErrorRatio, 2.0)
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user