milvus/internal/agg/aggregate_reducer.go
Chun Han b7ee93fc52
feat: support query aggregation(#36380) (#44394)
related: #36380

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
- Core invariant: aggregation is centralized and schema-aware — all
aggregate functions are created via the exec Aggregate registry
(milvus::exec::Aggregate) and validated by ValidateAggFieldType, use a
single in-memory accumulator layout (Accumulator/RowContainer) and
grouping primitives (GroupingSet, HashTable, VectorHasher), ensuring
consistent typing, null semantics and offsets across planner → exec →
reducer conversion paths (toAggregateInfo, Aggregate::create,
GroupingSet, AggResult converters).

- Removed / simplified logic: removed ad‑hoc count/group-by and reducer
code (CountNode/PhyCountNode, GroupByNode/PhyGroupByNode, cntReducer and
its tests) and consolidated into a unified AggregationNode →
PhyAggregationNode + GroupingSet + HashTable execution path and
centralized reducers (MilvusAggReducer, InternalAggReducer,
SegcoreAggReducer). AVG now implemented compositionally (SUM + COUNT)
rather than a bespoke operator, eliminating duplicate implementations.

- Why this does NOT cause data loss or regressions: existing data-access
and serialization paths are preserved and explicitly validated —
bulk_subscript / bulk_script_field_data and FieldData creation are used
for output materialization; converters (InternalResult2AggResult ↔
AggResult2internalResult, SegcoreResults2AggResult ↔
AggResult2segcoreResult) enforce shape/type/row-count validation; proxy
and plan-level checks (MatchAggregationExpression,
translateOutputFields, ValidateAggFieldType, translateGroupByFieldIds)
reject unsupported inputs (ARRAY/JSON, unsupported datatypes) early.
Empty-result generation and explicit error returns guard against silent
corruption.

- New capability and scope: end-to-end GROUP BY and aggregation support
added across the stack — proto (plan.proto, RetrieveRequest fields
group_by_field_ids/aggregates), planner nodes (AggregationNode,
ProjectNode, SearchGroupByNode), exec operators (PhyAggregationNode,
PhyProjectNode) and aggregation core (Aggregate implementations:
Sum/Count/Min/Max, SimpleNumericAggregate, RowContainer, GroupingSet,
HashTable) plus proxy/querynode reducers and tests — enabling grouped
and global aggregation (sum, count, min, max, avg via sum+count) with
schema-aware validation and reduction.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
2026-01-06 16:29:25 +08:00

460 lines
15 KiB
Go

package agg
import (
"context"
"fmt"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/proto/segcorepb"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// GroupAggReducer merges partial aggregation results (e.g. produced by
// different segments or query nodes) into a single grouped result. Rows are
// bucketed by a hash of the group-by columns; matching rows are folded
// together with the configured aggregate operators.
type GroupAggReducer struct {
	groupByFieldIds []int64 // field IDs of the GROUP BY columns; they form the leading columns of every result
	aggregates      []*planpb.Aggregate // aggregate specs, one per trailing result column
	hashValsMap     map[uint64]*Bucket // hash(group key) -> bucket of merged rows; populated during Reduce, so a reducer instance is single-use
	groupLimit      int64 // maximum number of distinct groups to emit; -1 means unlimited
	schema          *schemapb.CollectionSchema // collection schema, used for type validation and empty-result generation
}
// NewGroupAggReducer builds a reducer over the given group-by fields and
// aggregate specs. groupLimit caps the number of distinct groups produced
// (-1 disables the cap); schema is required for result-type validation.
func NewGroupAggReducer(groupByFieldIds []int64, aggregates []*planpb.Aggregate, groupLimit int64, schema *schemapb.CollectionSchema) *GroupAggReducer {
	reducer := &GroupAggReducer{
		groupByFieldIds: groupByFieldIds,
		aggregates:      aggregates,
		groupLimit:      groupLimit,
		schema:          schema,
	}
	// The bucket map must be ready before Reduce starts inserting rows.
	reducer.hashValsMap = make(map[uint64]*Bucket)
	return reducer
}
// AggregationResult is the columnar payload exchanged between aggregation
// reducers: group-by columns first, then one column per aggregate.
type AggregationResult struct {
	fieldDatas       []*schemapb.FieldData // output columns; never nil after NewAggregationResult
	allRetrieveCount int64 // total number of rows scanned by the source(s) that produced this result
}
// NewAggregationResult wraps the given columns and retrieve count into an
// AggregationResult. A nil fieldDatas slice is normalized to an empty one so
// callers can always append safely.
func NewAggregationResult(fieldDatas []*schemapb.FieldData, allRetrieveCount int64) *AggregationResult {
	result := &AggregationResult{allRetrieveCount: allRetrieveCount}
	if fieldDatas == nil {
		result.fieldDatas = make([]*schemapb.FieldData, 0)
	} else {
		result.fieldDatas = fieldDatas
	}
	return result
}
// GetFieldDatas returns the fieldDatas slice (the underlying slice, not a copy).
func (ar *AggregationResult) GetFieldDatas() []*schemapb.FieldData {
	return ar.fieldDatas
}
// GetAllRetrieveCount returns the total number of rows scanned by the sources
// that produced this result.
func (ar *AggregationResult) GetAllRetrieveCount() int64 {
	return ar.allRetrieveCount
}
// EmptyAggResult builds an AggregationResult with one empty column per
// group-by field and per aggregate. It is used when there are no input
// results to reduce.
func (reducer *GroupAggReducer) EmptyAggResult() (*AggregationResult, error) {
	helper, err := typeutil.CreateSchemaHelper(reducer.schema)
	if err != nil {
		return nil, err
	}
	result := NewAggregationResult(nil, 0)

	// Group-by columns: typed straight from the schema, zero rows.
	for _, groupFieldID := range reducer.groupByFieldIds {
		field, err := helper.GetFieldFromID(groupFieldID)
		if err != nil {
			return nil, err
		}
		emptyData, err := typeutil.GenEmptyFieldData(field)
		if err != nil {
			return nil, err
		}
		result.fieldDatas = append(result.fieldDatas, emptyData)
	}

	// Aggregate columns: count is always Int64; other ops derive their result
	// type from the input field's type.
	for _, agg := range reducer.aggregates {
		if agg.GetOp() == planpb.AggregateOp_count {
			result.fieldDatas = append(result.fieldDatas, genEmptyLongFieldData(schemapb.DataType_Int64, []int64{0}))
			continue
		}
		field, err := helper.GetFieldFromID(agg.GetFieldId())
		if err != nil {
			return nil, fmt.Errorf("failed to get field schema for aggregate fieldID %d: %w", agg.GetFieldId(), err)
		}
		resultType, err := getAggregateResultType(agg.GetOp(), field.GetDataType())
		if err != nil {
			return nil, fmt.Errorf("failed to get result type for aggregate fieldID %d: %w", agg.GetFieldId(), err)
		}
		emptyData, err := genEmptyFieldDataByType(resultType)
		if err != nil {
			return nil, fmt.Errorf("failed to generate empty field data for result type %s: %w", resultType.String(), err)
		}
		result.fieldDatas = append(result.fieldDatas, emptyData)
	}
	return result, nil
}
// genEmptyLongFieldData builds a scalar FieldData of the given type backed by
// the provided int64 values. Used for Int64 and Timestamptz aggregate columns.
func genEmptyLongFieldData(dataType schemapb.DataType, data []int64) *schemapb.FieldData {
	longData := &schemapb.ScalarField_LongData{
		LongData: &schemapb.LongArray{Data: data},
	}
	return &schemapb.FieldData{
		Type: dataType,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{Data: longData},
		},
	}
}
// genEmptyFieldDataByType generates a placeholder FieldData for an aggregate
// result column of the given type. Numeric types get a single zero-valued row;
// string types get an empty array. Unsupported types yield an error.
func genEmptyFieldDataByType(dataType schemapb.DataType) (*schemapb.FieldData, error) {
	// Helper wrapping a scalar payload in a FieldData of the requested type.
	wrap := func(scalar *schemapb.ScalarField) *schemapb.FieldData {
		return &schemapb.FieldData{
			Type:  dataType,
			Field: &schemapb.FieldData_Scalars{Scalars: scalar},
		}
	}
	switch dataType {
	case schemapb.DataType_Int64, schemapb.DataType_Timestamptz:
		// Both are long-backed columns.
		return genEmptyLongFieldData(dataType, []int64{0}), nil
	case schemapb.DataType_Double:
		return wrap(&schemapb.ScalarField{
			Data: &schemapb.ScalarField_DoubleData{DoubleData: &schemapb.DoubleArray{Data: []float64{0}}},
		}), nil
	case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
		return wrap(&schemapb.ScalarField{
			Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{0}}},
		}), nil
	case schemapb.DataType_Float:
		return wrap(&schemapb.ScalarField{
			Data: &schemapb.ScalarField_FloatData{FloatData: &schemapb.FloatArray{Data: []float32{0}}},
		}), nil
	case schemapb.DataType_VarChar, schemapb.DataType_String, schemapb.DataType_Text:
		// NOTE(review): string types start with zero rows while numeric types
		// start with a single zero row — presumably intentional (there is no
		// natural zero string); confirm downstream consumers expect this.
		return wrap(&schemapb.ScalarField{
			Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{}}},
		}), nil
	default:
		return nil, fmt.Errorf("unsupported data type for aggregate result: %s", dataType.String())
	}
}
// getAggregateResultType maps an aggregate operator and its input field type
// to the type of the produced column:
//   - count -> Int64
//   - avg   -> Double (computed as sum/count)
//   - min/max -> input type unchanged
//   - sum   -> Int64 for integer inputs, Double for float inputs,
//     Timestamptz preserved; anything else is an error.
func getAggregateResultType(op planpb.AggregateOp, inputType schemapb.DataType) (schemapb.DataType, error) {
	switch op {
	case planpb.AggregateOp_count:
		return schemapb.DataType_Int64, nil
	case planpb.AggregateOp_avg:
		return schemapb.DataType_Double, nil
	case planpb.AggregateOp_min, planpb.AggregateOp_max:
		return inputType, nil
	case planpb.AggregateOp_sum:
		switch inputType {
		case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64:
			// Integer sums widen to Int64.
			return schemapb.DataType_Int64, nil
		case schemapb.DataType_Timestamptz:
			return schemapb.DataType_Timestamptz, nil
		case schemapb.DataType_Float, schemapb.DataType_Double:
			// Float sums widen to Double.
			return schemapb.DataType_Double, nil
		}
		return schemapb.DataType_None, fmt.Errorf("unsupported input type %s for sum aggregation", inputType.String())
	}
	return schemapb.DataType_None, fmt.Errorf("unknown aggregate operator: %d", op)
}
// validateAggregationResults validates the input AggregationResult slice.
// It checks:
//  1. Each result's fieldDatas length equals numGroupingKeys + numAggs
//  2. No nil fieldData in any result
//  3. Each fieldData's Type matches the expected type derived from the schema
func (reducer *GroupAggReducer) validateAggregationResults(results []*AggregationResult) error {
	if reducer.schema == nil {
		return fmt.Errorf("schema is nil, cannot validate field types")
	}
	helper, err := typeutil.CreateSchemaHelper(reducer.schema)
	if err != nil {
		return fmt.Errorf("failed to create schema helper: %w", err)
	}
	numGroupingKeys := len(reducer.groupByFieldIds)
	numAggs := len(reducer.aggregates)
	expectedColumnCount := numGroupingKeys + numAggs

	// Expected column types: grouping keys first (schema type), then one per aggregate.
	expectedTypes := make([]schemapb.DataType, 0, expectedColumnCount)
	for _, fieldID := range reducer.groupByFieldIds {
		field, err := helper.GetFieldFromID(fieldID)
		if err != nil {
			return fmt.Errorf("failed to get field schema for groupBy fieldID %d: %w", fieldID, err)
		}
		expectedTypes = append(expectedTypes, field.GetDataType())
	}
	for _, agg := range reducer.aggregates {
		if agg.GetOp() == planpb.AggregateOp_count {
			// count aggregation always returns Int64
			expectedTypes = append(expectedTypes, schemapb.DataType_Int64)
			continue
		}
		field, err := helper.GetFieldFromID(agg.GetFieldId())
		if err != nil {
			return fmt.Errorf("failed to get field schema for aggregate fieldID %d: %w", agg.GetFieldId(), err)
		}
		resultType, err := getAggregateResultType(agg.GetOp(), field.GetDataType())
		if err != nil {
			return fmt.Errorf("failed to get aggregate result type for aggregate fieldID %d: %w", agg.GetFieldId(), err)
		}
		expectedTypes = append(expectedTypes, resultType)
	}

	// Validate every result against the expected shape and column types.
	for resultIdx, result := range results {
		if result == nil {
			return fmt.Errorf("result at index %d is nil", resultIdx)
		}
		fieldDatas := result.GetFieldDatas()
		if len(fieldDatas) != expectedColumnCount {
			return fmt.Errorf("result at index %d has fieldDatas length %d, expected %d (numGroupingKeys=%d, numAggs=%d)",
				resultIdx, len(fieldDatas), expectedColumnCount, numGroupingKeys, numAggs)
		}
		for colIdx, fieldData := range fieldDatas {
			if fieldData == nil {
				return fmt.Errorf("result at index %d has nil fieldData at column %d", resultIdx, colIdx)
			}
			if actualType := fieldData.GetType(); actualType != expectedTypes[colIdx] {
				return fmt.Errorf("result at index %d, column %d has type %s, expected %s",
					resultIdx, colIdx, schemapb.DataType_name[int32(actualType)], schemapb.DataType_name[int32(expectedTypes[colIdx])])
			}
		}
	}
	return nil
}
// Reduce merges multiple partial aggregation results into one.
//
// Fast paths: zero inputs produce an empty (schema-typed) result; a single
// input is returned as-is. Otherwise, the first numGroupingKeys columns of
// each result are treated as group keys and the rest as aggregate
// accumulators. With no grouping keys the reduction is global (one output
// row); with keys, rows are hashed into buckets and identical keys are folded
// together, bounded by groupLimit (-1 = unlimited).
//
// BUGFIX: the nil check on each result now precedes the first dereference;
// previously allRetrieveCount was accumulated (calling a method on result)
// before the nil guard, so a nil entry would panic instead of returning the
// intended error.
func (reducer *GroupAggReducer) Reduce(ctx context.Context, results []*AggregationResult) (*AggregationResult, error) {
	if len(results) == 0 {
		return reducer.EmptyAggResult()
	}
	// Validate input results (shape, nil entries, column types) before processing.
	if err := reducer.validateAggregationResults(results); err != nil {
		return nil, err
	}
	if len(results) == 1 {
		return results[0], nil
	}
	// 0. set up aggregate operators from their protobuf specs
	aggs := make([]AggregateBase, len(reducer.aggregates))
	for idx, aggPb := range reducer.aggregates {
		agg, err := FromPB(aggPb)
		if err != nil {
			return nil, err
		}
		aggs[idx] = agg
	}
	// 1. set up hashers (group-by columns) and accumulators (aggregate columns),
	// typed from the first result's columns (all results were validated to match).
	numGroupingKeys := len(reducer.groupByFieldIds)
	numAggs := len(reducer.aggregates)
	hashers := make([]FieldAccessor, numGroupingKeys)
	accumulators := make([]FieldAccessor, numAggs)
	firstFieldData := results[0].GetFieldDatas()
	outputColumnCount := len(firstFieldData)
	for idx, fieldData := range firstFieldData {
		accessor, err := NewFieldAccessor(fieldData.GetType())
		if err != nil {
			return nil, err
		}
		if idx < numGroupingKeys {
			hashers[idx] = accessor
		} else {
			accumulators[idx-numGroupingKeys] = accessor
		}
	}
	reducedResult := NewAggregationResult(nil, 0)
	isGlobal := numGroupingKeys == 0
	if isGlobal {
		// Global aggregation: each source contributes exactly one row; fold all
		// rows into the first one column-by-column (here every column is an
		// accumulator, since outputColumnCount == numAggs).
		reducedResult.fieldDatas = typeutil.PrepareResultFieldData(firstFieldData, 1)
		rows := make([]*Row, len(results))
		for idx, result := range results {
			reducedResult.allRetrieveCount += result.GetAllRetrieveCount()
			fieldValues := make([]*FieldValue, outputColumnCount)
			for col := 0; col < outputColumnCount; col++ {
				fieldData := result.GetFieldDatas()[col]
				accumulators[col].SetVals(fieldData)
				fieldValues[col] = NewFieldValue(accumulators[col].ValAt(0))
			}
			rows[idx] = NewRow(fieldValues)
		}
		for r := 1; r < len(rows); r++ {
			for c := 0; c < outputColumnCount; c++ {
				rows[0].UpdateFieldValue(rows[r], c, aggs[c])
			}
		}
		AssembleSingleRow(outputColumnCount, rows[0], reducedResult.fieldDatas)
		return reducedResult, nil
	}
	// 2. grouped aggregation: hash every row's group key and merge into buckets
	var totalRowCount int64 = 0
processResults:
	for _, result := range results {
		// Check limit before processing each shard to avoid unnecessary work.
		// NOTE(review): shards skipped here do not contribute to
		// allRetrieveCount — presumably acceptable once the group cap is hit;
		// confirm against AllRetrieveCount's intended semantics.
		if reducer.groupLimit != -1 && totalRowCount >= reducer.groupLimit {
			break processResults
		}
		// Guard against nil BEFORE dereferencing result (see doc comment).
		if result == nil {
			return nil, fmt.Errorf("input result from any sources cannot be nil")
		}
		reducedResult.allRetrieveCount += result.GetAllRetrieveCount()
		fieldDatas := result.GetFieldDatas()
		if outputColumnCount != len(fieldDatas) {
			return nil, fmt.Errorf("retrieved results from different segments have different size of columns")
		}
		if outputColumnCount == 0 {
			return nil, fmt.Errorf("retrieved results have no column data")
		}
		// Bind this result's columns to the accessors and cross-check that all
		// columns agree on the row count.
		rowCount := -1
		for i := 0; i < outputColumnCount; i++ {
			fieldData := fieldDatas[i]
			if i < numGroupingKeys {
				hashers[i].SetVals(fieldData)
			} else {
				accumulators[i-numGroupingKeys].SetVals(fieldData)
			}
			if rowCount == -1 {
				rowCount = hashers[i].RowCount()
			} else if i < numGroupingKeys {
				if rowCount != hashers[i].RowCount() {
					return nil, fmt.Errorf("field data:%d for different columns have different row count, %d vs %d, wrong state",
						i, rowCount, hashers[i].RowCount())
				}
			} else if rowCount != accumulators[i-numGroupingKeys].RowCount() {
				return nil, fmt.Errorf("field data:%d for different columns have different row count, %d vs %d, wrong state",
					i, rowCount, accumulators[i-numGroupingKeys].RowCount())
			}
		}
		for row := 0; row < rowCount; row++ {
			// Check limit before processing each row to avoid unnecessary hashing and copying.
			if reducer.groupLimit != -1 && totalRowCount >= reducer.groupLimit {
				break processResults
			}
			// Build the row's field values and mix the group-key columns into
			// a single hash value.
			rowFieldValues := make([]*FieldValue, outputColumnCount)
			var hashVal uint64
			for col := 0; col < outputColumnCount; col++ {
				if col < numGroupingKeys {
					if col > 0 {
						hashVal = typeutil2.HashMix(hashVal, hashers[col].Hash(row))
					} else {
						hashVal = hashers[col].Hash(row)
					}
					rowFieldValues[col] = NewFieldValue(hashers[col].ValAt(row))
				} else {
					rowFieldValues[col] = NewFieldValue(accumulators[col-numGroupingKeys].ValAt(row))
				}
			}
			newRow := NewRow(rowFieldValues)
			if bucket := reducer.hashValsMap[hashVal]; bucket == nil {
				// First row for this hash: open a new bucket.
				newBucket := NewBucket()
				newBucket.AddRow(newRow)
				totalRowCount++
				reducer.hashValsMap[hashVal] = newBucket
			} else if rowIdx := bucket.Find(newRow, numGroupingKeys); rowIdx == NONE {
				// Hash collision with a different group key: keep as a new row.
				bucket.AddRow(newRow)
				totalRowCount++
			} else {
				// Same group key: fold the aggregate columns into the existing row.
				bucket.Accumulate(newRow, rowIdx, numGroupingKeys, aggs)
			}
			// Don't guarantee specific groups to be returned before milvus support order by
		}
	}
	// 3. assemble reduced buckets into the output result
	reducedResult.fieldDatas = typeutil.PrepareResultFieldData(firstFieldData, totalRowCount)
	for _, bucket := range reducer.hashValsMap {
		if err := AssembleBucket(bucket, reducedResult.GetFieldDatas()); err != nil {
			return nil, err
		}
	}
	return reducedResult, nil
}
// InternalResult2AggResult wraps each internalpb.RetrieveResults into an
// AggregationResult, preserving order.
func InternalResult2AggResult(results []*internalpb.RetrieveResults) []*AggregationResult {
	aggResults := make([]*AggregationResult, 0, len(results))
	for _, res := range results {
		aggResults = append(aggResults, NewAggregationResult(res.GetFieldsData(), res.GetAllRetrieveCount()))
	}
	return aggResults
}
// AggResult2internalResult converts an AggregationResult back into an
// internalpb.RetrieveResults for transport.
func AggResult2internalResult(aggRes *AggregationResult) *internalpb.RetrieveResults {
	result := &internalpb.RetrieveResults{
		FieldsData:       aggRes.GetFieldDatas(),
		AllRetrieveCount: aggRes.GetAllRetrieveCount(),
	}
	return result
}
// SegcoreResults2AggResult wraps each segcorepb.RetrieveResults into an
// AggregationResult, preserving order. A nil entry is rejected with an error.
func SegcoreResults2AggResult(results []*segcorepb.RetrieveResults) ([]*AggregationResult, error) {
	aggResults := make([]*AggregationResult, 0, len(results))
	for _, res := range results {
		if res == nil {
			return nil, fmt.Errorf("input segcore query results from any sources cannot be nil")
		}
		aggResults = append(aggResults, NewAggregationResult(res.GetFieldsData(), res.GetAllRetrieveCount()))
	}
	return aggResults, nil
}
// AggResult2segcoreResult converts an AggregationResult back into a
// segcorepb.RetrieveResults for transport.
func AggResult2segcoreResult(aggRes *AggregationResult) *segcorepb.RetrieveResults {
	result := &segcorepb.RetrieveResults{
		FieldsData:       aggRes.GetFieldDatas(),
		AllRetrieveCount: aggRes.GetAllRetrieveCount(),
	}
	return result
}