Chun Han b7ee93fc52
feat: support query aggregation (#36380) (#44394)
related: #36380

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
- Core invariant: aggregation is centralized and schema-aware — all
aggregate functions are created via the exec Aggregate registry
(milvus::exec::Aggregate) and validated by ValidateAggFieldType, use a
single in-memory accumulator layout (Accumulator/RowContainer) and
grouping primitives (GroupingSet, HashTable, VectorHasher), ensuring
consistent typing, null semantics and offsets across planner → exec →
reducer conversion paths (toAggregateInfo, Aggregate::create,
GroupingSet, AggResult converters).

- Removed / simplified logic: removed ad‑hoc count/group-by and reducer
code (CountNode/PhyCountNode, GroupByNode/PhyGroupByNode, cntReducer and
its tests) and consolidated into a unified AggregationNode →
PhyAggregationNode + GroupingSet + HashTable execution path and
centralized reducers (MilvusAggReducer, InternalAggReducer,
SegcoreAggReducer). AVG now implemented compositionally (SUM + COUNT)
rather than a bespoke operator, eliminating duplicate implementations.

- Why this does NOT cause data loss or regressions: existing data-access
and serialization paths are preserved and explicitly validated —
bulk_subscript / bulk_script_field_data and FieldData creation are used
for output materialization; converters (InternalResult2AggResult ↔
AggResult2internalResult, SegcoreResults2AggResult ↔
AggResult2segcoreResult) enforce shape/type/row-count validation; proxy
and plan-level checks (MatchAggregationExpression,
translateOutputFields, ValidateAggFieldType, translateGroupByFieldIds)
reject unsupported inputs (ARRAY/JSON, unsupported datatypes) early.
Empty-result generation and explicit error returns guard against silent
corruption.

- New capability and scope: end-to-end GROUP BY and aggregation support
added across the stack — proto (plan.proto, RetrieveRequest fields
group_by_field_ids/aggregates), planner nodes (AggregationNode,
ProjectNode, SearchGroupByNode), exec operators (PhyAggregationNode,
PhyProjectNode) and aggregation core (Aggregate implementations:
Sum/Count/Min/Max, SimpleNumericAggregate, RowContainer, GroupingSet,
HashTable) plus proxy/querynode reducers and tests — enabling grouped
and global aggregation (sum, count, min, max, avg via sum+count) with
schema-aware validation and reduction.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
2026-01-06 16:29:25 +08:00

325 lines
9.6 KiB
Go

package segments
/*
#cgo pkg-config: milvus_core
#include "segcore/collection_c.h"
#include "segcore/segment_c.h"
#include "segcore/segcore_init_c.h"
#include "common/init_c.h"
*/
import "C"
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"strconv"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querynodev2/segments/metricsutil"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// errLazyLoadTimeout is the cause passed to contextutil.WithTimeoutCause by
// withLazyLoadTimeoutContext, so contexts that expire while waiting for a lazy
// load report this error as their cause.
var (
	errLazyLoadTimeout = merr.WrapErrServiceInternal("lazy load time out")
)
// GetPkField returns the first field in schema that is flagged as the primary
// key, or nil when no field carries that flag.
func GetPkField(schema *schemapb.CollectionSchema) *schemapb.FieldSchema {
	fields := schema.GetFields()
	for i := range fields {
		if fields[i].GetIsPrimaryKey() {
			return fields[i]
		}
	}
	return nil
}
// TODO: move this function to a proper file.

// GetPrimaryKeys extracts the primary keys carried by an insert message. It
// dispatches on the message layout: column-based messages are decoded from
// their field data, row-based messages from their row blobs.
func GetPrimaryKeys(msg *msgstream.InsertMsg, schema *schemapb.CollectionSchema) ([]storage.PrimaryKey, error) {
	if !msg.IsRowBased() {
		return getPKsFromColumnBasedInsertMsg(msg, schema)
	}
	return getPKsFromRowBasedInsertMsg(msg, schema)
}
// getPKsFromRowBasedInsertMsg decodes int64 primary keys from a row-based
// insert message.
//
// The byte offset of the primary key inside each row blob is computed by
// summing the encoded width of every schema field that precedes the primary
// key field. NOTE(review): this assumes rows are a fixed-width concatenation
// of fields in schema order — confirm against the row encoder. Eight bytes at
// that offset are then decoded with common.Endian as an int64 primary key.
//
// It returns an error when:
//   - a vector field's dim type param is not an integer;
//   - a SparseFloatVector field precedes the primary key;
//   - a row blob is too short to contain the primary key.
//
// In the last case the function previously panicked on an out-of-range slice.
func getPKsFromRowBasedInsertMsg(msg *msgstream.InsertMsg, schema *schemapb.CollectionSchema) ([]storage.PrimaryKey, error) {
	// vectorDim returns the integer value of the field's dim type param. It
	// returns 0 when the param is absent, so such a field adds nothing to the
	// offset (same as the previous behavior).
	vectorDim := func(field *schemapb.FieldSchema) (int, error) {
		for _, t := range field.TypeParams {
			if t.Key == common.DimKey {
				dim, err := strconv.Atoi(t.Value)
				if err != nil {
					return 0, fmt.Errorf("strconv wrong on get dim, err = %w", err)
				}
				return dim, nil
			}
		}
		return 0, nil
	}

	offset := 0
	for _, field := range schema.Fields {
		if field.IsPrimaryKey {
			break
		}
		switch field.DataType {
		case schemapb.DataType_Bool, schemapb.DataType_Int8:
			offset++
		case schemapb.DataType_Int16:
			offset += 2
		case schemapb.DataType_Int32, schemapb.DataType_Float:
			offset += 4
		case schemapb.DataType_Timestamptz, schemapb.DataType_Int64, schemapb.DataType_Double:
			offset += 8
		case schemapb.DataType_FloatVector:
			dim, err := vectorDim(field)
			if err != nil {
				return nil, err
			}
			offset += dim * 4
		case schemapb.DataType_BinaryVector:
			dim, err := vectorDim(field)
			if err != nil {
				return nil, err
			}
			offset += dim / 8
		case schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
			dim, err := vectorDim(field)
			if err != nil {
				return nil, err
			}
			offset += dim * 2
		case schemapb.DataType_SparseFloatVector:
			return nil, errors.New("SparseFloatVector not support in row based message")
		}
	}
	log.Debug("primary key offset in row-based insert message", zap.Int("offset", offset))

	// Slice out the 8-byte primary key of every row. Reject rows that are too
	// short rather than panicking on an out-of-range slice expression.
	blobReaders := make([]io.Reader, len(msg.RowData))
	for i, blob := range msg.RowData {
		value := blob.GetValue()
		if len(value) < offset+8 {
			return nil, fmt.Errorf("row %d of row based insert message is too short to hold primary key: need %d bytes, got %d",
				i, offset+8, len(value))
		}
		blobReaders[i] = bytes.NewReader(value[offset : offset+8])
	}

	pks := make([]storage.PrimaryKey, len(blobReaders))
	for i, reader := range blobReaders {
		var int64PkValue int64
		err := binary.Read(reader, common.Endian, &int64PkValue)
		if err != nil {
			log.Warn("binary read blob value failed", zap.Error(err))
			return nil, err
		}
		pks[i] = storage.NewInt64PrimaryKey(int64PkValue)
	}
	return pks, nil
}
// getPKsFromColumnBasedInsertMsg finds the primary key field in schema, picks
// the matching column out of the message's field data, and converts that
// column into primary keys. Any lookup or conversion error is returned as-is
// with a nil slice.
func getPKsFromColumnBasedInsertMsg(msg *msgstream.InsertMsg, schema *schemapb.CollectionSchema) ([]storage.PrimaryKey, error) {
	pkSchema, err := typeutil.GetPrimaryFieldSchema(schema)
	if err != nil {
		return nil, err
	}

	pkColumn, err := typeutil.GetPrimaryFieldData(msg.GetFieldsData(), pkSchema)
	if err != nil {
		return nil, err
	}

	keys, err := storage.ParseFieldData2PrimaryKeys(pkColumn)
	if err != nil {
		return nil, err
	}
	return keys, nil
}
// mergeRequestCost merges the costs of a request. The costs may come from
// different workers in the same channel, or from different channels in the
// same collection. For now it simply keeps the cost with the highest response
// time; on ties the earliest entry wins.
//
// nil entries are skipped. Previously, a nil entry that followed a non-nil one
// dereferenced a nil pointer and panicked, while a leading nil was silently
// tolerated. The result is nil when requestCosts is empty or holds only nil
// entries.
func mergeRequestCost(requestCosts []*internalpb.CostAggregation) *internalpb.CostAggregation {
	var result *internalpb.CostAggregation
	for _, cost := range requestCosts {
		if cost == nil {
			continue
		}
		if result == nil || result.ResponseTime < cost.ResponseTime {
			result = cost
		}
	}
	return result
}
// getIndexEngineVersion returns the minimal and current index engine versions
// reported by segcore through cgo. The cgo calls are submitted to the dynamic
// pool, and the function waits on the resulting future (Await) before
// returning the captured values.
func getIndexEngineVersion() (minimal, current int32) {
	future := GetDynamicPool().Submit(func() (any, error) {
		minimal = int32(C.GetMinimalIndexVersion())
		current = int32(C.GetCurrentIndexVersion())
		return nil, nil
	})
	future.Await()
	return minimal, current
}
// getSegmentMetricLabel builds the metrics label for segment. The label
// carries the segment's database name and resource group.
func getSegmentMetricLabel(segment Segment) metricsutil.SegmentLabel {
	var label metricsutil.SegmentLabel
	label.DatabaseName = segment.DatabaseName()
	label.ResourceGroup = segment.ResourceGroup()
	return label
}
// FilterZeroValuesFromSlice returns a newly built slice with the non-zero
// values of intVals, in their original order. intVals itself is not modified.
// The result is nil when intVals contains no non-zero value.
func FilterZeroValuesFromSlice(intVals []int64) []int64 {
	var nonZero []int64
	for i := 0; i < len(intVals); i++ {
		if intVals[i] == 0 {
			continue
		}
		nonZero = append(nonZero, intVals[i])
	}
	return nonZero
}
// withLazyLoadTimeoutContext derives a context from ctx that times out after
// the query node's LazyLoadWaitTimeout setting, read as milliseconds. The
// context carries errLazyLoadTimeout as its timeout cause. The caller must
// invoke the returned CancelFunc.
func withLazyLoadTimeoutContext(ctx context.Context) (context.Context, context.CancelFunc) {
	timeout := paramtable.Get().QueryNodeCfg.LazyLoadWaitTimeout.GetAsDuration(time.Millisecond)
	// TODO: use context.WithTimeoutCause instead of contextutil.WithTimeoutCause in go1.21
	// (drop the contextutil import at the same time, it is only used here).
	timeoutCtx, cancel := contextutil.WithTimeoutCause(ctx, timeout, errLazyLoadTimeout)
	return timeoutCtx, cancel
}
// GetSegmentRelatedDataSize returns the data size attributed to segment. For
// sealed segments this is the total log size recorded in the segment's load
// info; for every other segment type it is the segment's MemSize.
func GetSegmentRelatedDataSize(segment Segment) int64 {
	switch segment.Type() {
	case SegmentTypeSealed:
		return calculateSegmentLogSize(segment.LoadInfo())
	default:
		return segment.MemSize()
	}
}
// calculateSegmentLogSize sums the LogSize of every binlog referenced by
// segmentLoadInfo: the field binlogs (BinlogPaths), the stats logs, and the
// delta logs.
func calculateSegmentLogSize(segmentLoadInfo *querypb.SegmentLoadInfo) int64 {
	var total int64

	// Field data binlogs.
	binlogs := segmentLoadInfo.BinlogPaths
	for i := range binlogs {
		total += getFieldSizeFromFieldBinlog(binlogs[i])
	}

	// Stats logs.
	statslogs := segmentLoadInfo.Statslogs
	for i := range statslogs {
		total += getFieldSizeFromFieldBinlog(statslogs[i])
	}

	// Delete (delta) logs.
	deltalogs := segmentLoadInfo.Deltalogs
	for i := range deltalogs {
		total += getFieldSizeFromFieldBinlog(deltalogs[i])
	}

	return total
}
// calculateSegmentMemorySize sums getBinlogDataMemorySize over the field
// binlogs (BinlogPaths), stats logs, and delta logs referenced by
// segmentLoadInfo.
func calculateSegmentMemorySize(segmentLoadInfo *querypb.SegmentLoadInfo) int64 {
	var total int64

	// Field data binlogs.
	binlogs := segmentLoadInfo.BinlogPaths
	for i := range binlogs {
		total += getBinlogDataMemorySize(binlogs[i])
	}

	// Stats logs.
	statslogs := segmentLoadInfo.Statslogs
	for i := range statslogs {
		total += getBinlogDataMemorySize(statslogs[i])
	}

	// Delete (delta) logs.
	deltalogs := segmentLoadInfo.Deltalogs
	for i := range deltalogs {
		total += getBinlogDataMemorySize(deltalogs[i])
	}

	return total
}
// getFieldSizeFromFieldBinlog returns the sum of LogSize over all binlogs
// contained in fieldBinlog.
func getFieldSizeFromFieldBinlog(fieldBinlog *datapb.FieldBinlog) int64 {
	var total int64
	binlogs := fieldBinlog.Binlogs
	for i := range binlogs {
		total += binlogs[i].LogSize
	}
	return total
}
// getFieldSchema looks up the field with the given fieldID. Top-level schema
// fields are searched first, then the sub-fields of every struct-array field.
// It returns an error when no field matches.
func getFieldSchema(schema *schemapb.CollectionSchema, fieldID int64) (*schemapb.FieldSchema, error) {
	for _, f := range schema.Fields {
		if f.FieldID != fieldID {
			continue
		}
		return f, nil
	}

	for _, structField := range schema.StructArrayFields {
		for _, sub := range structField.Fields {
			if sub.FieldID != fieldID {
				continue
			}
			return sub, nil
		}
	}

	return nil, fmt.Errorf("field %d not found in schema", fieldID)
}
// isIndexMmapEnable reports whether the index described by indexInfo should be
// loaded via mmap.
//
// If the index params explicitly disable mmap, the answer is false right away.
// Otherwise mmap is used only when both of these hold:
//   - the index type supports mmap (vector index manager for vector fields,
//     scalar mmap index check otherwise);
//   - mmap is enabled, either by the query node's global vector/scalar index
//     setting or explicitly in the index params.
func isIndexMmapEnable(fieldSchema *schemapb.FieldSchema, indexInfo *querypb.FieldIndexInfo) bool {
	explicitEnable, explicitlySet := common.IsMmapIndexEnabled(indexInfo.IndexParams...)
	// Fast path: an explicit "disabled" needs no index type check.
	if explicitlySet && !explicitEnable {
		return false
	}

	indexType := common.GetIndexType(indexInfo.IndexParams)
	var supported, globallyEnabled bool
	if typeutil.IsVectorType(fieldSchema.GetDataType()) {
		supported = vecindexmgr.GetVecIndexMgrInstance().IsMMapSupported(indexType)
		globallyEnabled = params.Params.QueryNodeCfg.MmapVectorIndex.GetAsBool()
	} else {
		supported = indexparamcheck.IsScalarMmapIndex(indexType)
		globallyEnabled = params.Params.QueryNodeCfg.MmapScalarIndex.GetAsBool()
	}
	return supported && (globallyEnabled || explicitEnable)
}
// isDataMmapEnable reports whether the data of fieldSchema should be loaded
// via mmap. Besides deciding whether the raw field data is mmapped, the result
// also affects the field's stats indexes, such as the text match index and the
// JSON key stats index.
//
// An explicit mmap setting in the field's type params takes precedence.
// Otherwise the query node's global default for vector or scalar fields
// applies.
func isDataMmapEnable(fieldSchema *schemapb.FieldSchema) bool {
	if enabled, ok := common.IsMmapDataEnabled(fieldSchema.GetTypeParams()...); ok {
		return enabled
	}
	switch {
	case typeutil.IsVectorType(fieldSchema.GetDataType()):
		return params.Params.QueryNodeCfg.MmapVectorField.GetAsBool()
	default:
		return params.Params.QueryNodeCfg.MmapScalarField.GetAsBool()
	}
}
// isGrowingMmapEnable reports the query node's GrowingMmapEnabled setting,
// i.e. whether mmap is enabled for growing segments.
func isGrowingMmapEnable() bool {
	enabled := params.Params.QueryNodeCfg.GrowingMmapEnabled.GetAsBool()
	return enabled
}