milvus/internal/agg/aggregate_util.go
Chun Han b7ee93fc52
feat: support query aggregation (#36380) (#44394)
related: #36380

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
- Core invariant: aggregation is centralized and schema-aware — all
aggregate functions are created via the exec Aggregate registry
(milvus::exec::Aggregate), validated by ValidateAggFieldType, and share a
single in-memory accumulator layout (Accumulator/RowContainer) and
grouping primitives (GroupingSet, HashTable, VectorHasher), ensuring
consistent typing, null semantics, and offsets across the planner → exec →
reducer conversion paths (toAggregateInfo, Aggregate::create,
GroupingSet, AggResult converters).

- Removed / simplified logic: removed ad‑hoc count/group-by and reducer
code (CountNode/PhyCountNode, GroupByNode/PhyGroupByNode, cntReducer and
its tests) and consolidated into a unified AggregationNode →
PhyAggregationNode + GroupingSet + HashTable execution path and
centralized reducers (MilvusAggReducer, InternalAggReducer,
SegcoreAggReducer). AVG is now implemented compositionally (SUM + COUNT)
rather than as a bespoke operator, eliminating duplicate implementations.

- Why this does NOT cause data loss or regressions: existing data-access
and serialization paths are preserved and explicitly validated —
bulk_subscript / bulk_script_field_data and FieldData creation are used
for output materialization; converters (InternalResult2AggResult ↔
AggResult2internalResult, SegcoreResults2AggResult ↔
AggResult2segcoreResult) enforce shape/type/row-count validation; proxy
and plan-level checks (MatchAggregationExpression,
translateOutputFields, ValidateAggFieldType, translateGroupByFieldIds)
reject unsupported inputs (ARRAY/JSON, unsupported datatypes) early.
Empty-result generation and explicit error returns guard against silent
corruption.

- New capability and scope: end-to-end GROUP BY and aggregation support
added across the stack — proto (plan.proto, RetrieveRequest fields
group_by_field_ids/aggregates), planner nodes (AggregationNode,
ProjectNode, SearchGroupByNode), exec operators (PhyAggregationNode,
PhyProjectNode) and aggregation core (Aggregate implementations:
Sum/Count/Min/Max, SimpleNumericAggregate, RowContainer, GroupingSet,
HashTable) plus proxy/querynode reducers and tests — enabling grouped
and global aggregation (sum, count, min, max, avg via sum+count) with
schema-aware validation and reduction.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
2026-01-06 16:29:25 +08:00

495 lines
15 KiB
Go

package agg
import (
"encoding/binary"
"fmt"
"hash"
"hash/fnv"
"math"
"unsafe"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)
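// NewFieldAccessor returns a FieldAccessor for the given scalar data type.
// Int8/Int16/Int32 share one accessor because schemapb.FieldData stores all
// three in the int32-backed IntData array.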
func NewFieldAccessor(fieldType schemapb.DataType) (FieldAccessor, error) {
switch fieldType {
case schemapb.DataType_Bool:
return newBoolFieldAccessor(), nil
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
return newInt32FieldAccessor(), nil
case schemapb.DataType_Int64:
return newInt64FieldAccessor(), nil
case schemapb.DataType_Timestamptz:
return newTimestamptzFieldAccessor(), nil
case schemapb.DataType_VarChar, schemapb.DataType_String:
return newStringFieldAccessor(), nil
case schemapb.DataType_Float:
return newFloat32FieldAccessor(), nil
case schemapb.DataType_Double:
return newFloat64FieldAccessor(), nil
default:
return nil, fmt.Errorf("unsupported data type for field accessor: %s", fieldType.String())
}
}
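// FieldAccessor provides uniform, type-erased access to one scalar column:
// SetVals binds the column's FieldData, Hash yields a per-row FNV-1a hash
// for grouping, and ValAt extracts a row's value for output assembly.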
type FieldAccessor interface {
Hash(idx int) uint64
ValAt(idx int) interface{}
SetVals(fieldData *schemapb.FieldData)
RowCount() int
}
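// Hypothetical sketch (not part of the original file): one way a caller
// could fold the per-column hashes of a single row into one grouping key.
// The boost-style combiner constant and the helper name are illustrative
// assumptions, not an API of this package.
func rowGroupKey(accessors []FieldAccessor, row int) uint64 {
	var key uint64
	for _, acc := range accessors {
		key ^= acc.Hash(row) + 0x9e3779b97f4a7c15 + (key << 6) + (key >> 2)
	}
	return key
}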
type Int32FieldAccessor struct {
vals []int32
hasher hash.Hash64
buffer []byte
}
func (i32Field *Int32FieldAccessor) Hash(idx int) uint64 {
if idx < 0 || idx >= len(i32Field.vals) {
panic(fmt.Sprintf("Int32FieldAccessor.Hash: index %d out of range [0,%d)", idx, len(i32Field.vals)))
}
i32Field.hasher.Reset()
val := i32Field.vals[idx]
binary.LittleEndian.PutUint32(i32Field.buffer, uint32(val))
i32Field.hasher.Write(i32Field.buffer)
ret := i32Field.hasher.Sum64()
return ret
}
func (i32Field *Int32FieldAccessor) SetVals(fieldData *schemapb.FieldData) {
i32Field.vals = fieldData.GetScalars().GetIntData().GetData()
}
func (i32Field *Int32FieldAccessor) RowCount() int {
return len(i32Field.vals)
}
func (i32Field *Int32FieldAccessor) ValAt(idx int) interface{} {
return i32Field.vals[idx]
}
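// newInt32FieldAccessor pre-allocates a 4-byte scratch buffer so Hash can
// encode each value without a per-call allocation; the other fixed-width
// accessors below do the same with buffers sized to their types.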
func newInt32FieldAccessor() FieldAccessor {
return &Int32FieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 4)}
}
type Int64FieldAccessor struct {
vals []int64
hasher hash.Hash64
buffer []byte
}
func (i64Field *Int64FieldAccessor) Hash(idx int) uint64 {
if idx < 0 || idx >= len(i64Field.vals) {
panic(fmt.Sprintf("Int64FieldAccessor.Hash: index %d out of range [0,%d)", idx, len(i64Field.vals)))
}
i64Field.hasher.Reset()
val := i64Field.vals[idx]
binary.LittleEndian.PutUint64(i64Field.buffer, uint64(val))
i64Field.hasher.Write(i64Field.buffer)
return i64Field.hasher.Sum64()
}
func (i64Field *Int64FieldAccessor) SetVals(fieldData *schemapb.FieldData) {
i64Field.vals = fieldData.GetScalars().GetLongData().GetData()
}
func (i64Field *Int64FieldAccessor) RowCount() int {
return len(i64Field.vals)
}
func (i64Field *Int64FieldAccessor) ValAt(idx int) interface{} {
return i64Field.vals[idx]
}
func newInt64FieldAccessor() FieldAccessor {
return &Int64FieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 8)}
}
type TimestamptzFieldAccessor struct {
vals []int64
hasher hash.Hash64
buffer []byte
}
func (tzField *TimestamptzFieldAccessor) Hash(idx int) uint64 {
if idx < 0 || idx >= len(tzField.vals) {
panic(fmt.Sprintf("TimestamptzFieldAccessor.Hash: index %d out of range [0,%d)", idx, len(tzField.vals)))
}
tzField.hasher.Reset()
val := tzField.vals[idx]
binary.LittleEndian.PutUint64(tzField.buffer, uint64(val))
tzField.hasher.Write(tzField.buffer)
return tzField.hasher.Sum64()
}
func (tzField *TimestamptzFieldAccessor) SetVals(fieldData *schemapb.FieldData) {
tzField.vals = fieldData.GetScalars().GetTimestamptzData().GetData()
}
func (tzField *TimestamptzFieldAccessor) RowCount() int {
return len(tzField.vals)
}
func (tzField *TimestamptzFieldAccessor) ValAt(idx int) interface{} {
return tzField.vals[idx]
}
func newTimestamptzFieldAccessor() FieldAccessor {
return &TimestamptzFieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 8)}
}
// BoolFieldAccessor
type BoolFieldAccessor struct {
vals []bool
hasher hash.Hash64
buffer []byte
}
func (boolField *BoolFieldAccessor) Hash(idx int) uint64 {
if idx < 0 || idx >= len(boolField.vals) {
panic(fmt.Sprintf("BoolFieldAccessor.Hash: index %d out of range [0,%d)", idx, len(boolField.vals)))
}
boolField.hasher.Reset()
val := boolField.vals[idx]
if val {
boolField.buffer[0] = 1
} else {
boolField.buffer[0] = 0
}
boolField.hasher.Write(boolField.buffer[:1])
return boolField.hasher.Sum64()
}
func (boolField *BoolFieldAccessor) SetVals(fieldData *schemapb.FieldData) {
boolField.vals = fieldData.GetScalars().GetBoolData().GetData()
}
func (boolField *BoolFieldAccessor) RowCount() int {
return len(boolField.vals)
}
func (boolField *BoolFieldAccessor) ValAt(idx int) interface{} {
return boolField.vals[idx]
}
func newBoolFieldAccessor() FieldAccessor {
return &BoolFieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 1)}
}
// Float32FieldAccessor
type Float32FieldAccessor struct {
vals []float32
hasher hash.Hash64
buffer []byte
}
func (f32FieldAccessor *Float32FieldAccessor) Hash(idx int) uint64 {
if idx < 0 || idx >= len(f32FieldAccessor.vals) {
panic(fmt.Sprintf("Float32FieldAccessor.Hash: index %d out of range [0,%d)", idx, len(f32FieldAccessor.vals)))
}
f32FieldAccessor.hasher.Reset()
val := f32FieldAccessor.vals[idx]
binary.LittleEndian.PutUint32(f32FieldAccessor.buffer, math.Float32bits(val))
f32FieldAccessor.hasher.Write(f32FieldAccessor.buffer[:4])
return f32FieldAccessor.hasher.Sum64()
}
func (f32FieldAccessor *Float32FieldAccessor) SetVals(fieldData *schemapb.FieldData) {
f32FieldAccessor.vals = fieldData.GetScalars().GetFloatData().GetData()
}
func (f32FieldAccessor *Float32FieldAccessor) RowCount() int {
return len(f32FieldAccessor.vals)
}
func (f32FieldAccessor *Float32FieldAccessor) ValAt(idx int) interface{} {
return f32FieldAccessor.vals[idx]
}
func newFloat32FieldAccessor() FieldAccessor {
return &Float32FieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 4)}
}
// Float64FieldAccessor
type Float64FieldAccessor struct {
vals []float64
hasher hash.Hash64
buffer []byte
}
func (f64Field *Float64FieldAccessor) Hash(idx int) uint64 {
if idx < 0 || idx >= len(f64Field.vals) {
panic(fmt.Sprintf("Float64FieldAccessor.Hash: index %d out of range [0,%d)", idx, len(f64Field.vals)))
}
f64Field.hasher.Reset()
val := f64Field.vals[idx]
binary.LittleEndian.PutUint64(f64Field.buffer, math.Float64bits(val))
f64Field.hasher.Write(f64Field.buffer)
return f64Field.hasher.Sum64()
}
func (f64Field *Float64FieldAccessor) SetVals(fieldData *schemapb.FieldData) {
f64Field.vals = fieldData.GetScalars().GetDoubleData().GetData()
}
func (f64Field *Float64FieldAccessor) RowCount() int {
return len(f64Field.vals)
}
func (f64Field *Float64FieldAccessor) ValAt(idx int) interface{} {
return f64Field.vals[idx]
}
func newFloat64FieldAccessor() FieldAccessor {
return &Float64FieldAccessor{hasher: fnv.New64a(), buffer: make([]byte, 8)}
}
// StringFieldAccessor
type StringFieldAccessor struct {
vals []string
hasher hash.Hash64
}
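// Hash views the string's backing bytes in place via unsafe.StringData and
// unsafe.Slice (Go 1.20+); the bytes are only read, never written, so the
// zero-copy aliasing is safe.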
func (stringField *StringFieldAccessor) Hash(idx int) uint64 {
if idx < 0 || idx >= len(stringField.vals) {
panic(fmt.Sprintf("StringFieldAccessor.Hash: index %d out of range [0,%d)", idx, len(stringField.vals)))
}
stringField.hasher.Reset()
val := stringField.vals[idx]
b := unsafe.Slice(unsafe.StringData(val), len(val))
stringField.hasher.Write(b)
return stringField.hasher.Sum64()
}
func (stringField *StringFieldAccessor) SetVals(fieldData *schemapb.FieldData) {
stringField.vals = fieldData.GetScalars().GetStringData().GetData()
}
func (stringField *StringFieldAccessor) RowCount() int {
return len(stringField.vals)
}
func (stringField *StringFieldAccessor) ValAt(idx int) interface{} {
return stringField.vals[idx]
}
func newStringFieldAccessor() FieldAccessor {
return &StringFieldAccessor{hasher: fnv.New64a()}
}
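// AssembleBucket appends every row of the bucket onto the output columns,
// one FieldData per column, in row-major order.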
func AssembleBucket(bucket *Bucket, fieldDatas []*schemapb.FieldData) error {
colCount := len(fieldDatas)
for r := 0; r < bucket.RowCount(); r++ {
row := bucket.RowAt(r)
if err := AssembleSingleRow(colCount, row, fieldDatas); err != nil {
return err
}
}
return nil
}
func AssembleSingleRow(colCount int, row *Row, fieldDatas []*schemapb.FieldData) error {
for c := 0; c < colCount; c++ {
err := AssembleSingleValue(row.ValAt(c), fieldDatas[c])
if err != nil {
return err
}
}
return nil
}
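// AssembleSingleValue appends one typed value to fieldData. The matching
// scalar payload (BoolData, IntData, LongData, ...) must already be
// allocated: the nil-safe getters return nil on an empty FieldData, and the
// append below would then panic with a nil-pointer dereference.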
func AssembleSingleValue(val interface{}, fieldData *schemapb.FieldData) error {
switch fieldData.GetType() {
case schemapb.DataType_Bool:
boolVal, ok := val.(bool)
if !ok {
return fmt.Errorf("type assertion failed: expected bool, got %T", val)
}
fieldData.GetScalars().GetBoolData().Data = append(fieldData.GetScalars().GetBoolData().GetData(), boolVal)
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
intVal, ok := val.(int32)
if !ok {
return fmt.Errorf("type assertion failed: expected int32, got %T", val)
}
fieldData.GetScalars().GetIntData().Data = append(fieldData.GetScalars().GetIntData().GetData(), intVal)
case schemapb.DataType_Int64:
int64Val, ok := val.(int64)
if !ok {
return fmt.Errorf("type assertion failed: expected int64, got %T", val)
}
fieldData.GetScalars().GetLongData().Data = append(fieldData.GetScalars().GetLongData().GetData(), int64Val)
case schemapb.DataType_Timestamptz:
timestampVal, ok := val.(int64)
if !ok {
return fmt.Errorf("type assertion failed: expected int64 for Timestamptz, got %T", val)
}
fieldData.GetScalars().GetTimestamptzData().Data = append(fieldData.GetScalars().GetTimestamptzData().GetData(), timestampVal)
case schemapb.DataType_Float:
floatVal, ok := val.(float32)
if !ok {
return fmt.Errorf("type assertion failed: expected float32, got %T", val)
}
fieldData.GetScalars().GetFloatData().Data = append(fieldData.GetScalars().GetFloatData().GetData(), floatVal)
case schemapb.DataType_Double:
doubleVal, ok := val.(float64)
if !ok {
return fmt.Errorf("type assertion failed: expected float64, got %T", val)
}
fieldData.GetScalars().GetDoubleData().Data = append(fieldData.GetScalars().GetDoubleData().GetData(), doubleVal)
case schemapb.DataType_VarChar, schemapb.DataType_String:
stringVal, ok := val.(string)
if !ok {
return fmt.Errorf("type assertion failed: expected string, got %T", val)
}
fieldData.GetScalars().GetStringData().Data = append(fieldData.GetScalars().GetStringData().GetData(), stringVal)
default:
return fmt.Errorf("unsupported DataType:%d", fieldData.GetType())
}
return nil
}
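// AggregationFieldMap maps each user-requested output field to the column
// index (or indices) of the reduced result, where group-by keys occupy the
// leading columns and each aggregate follows in declaration order.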
type AggregationFieldMap struct {
userOriginalOutputFields []string
userOriginalOutputFieldIdxes [][]int // Each user output field can map to multiple field indices (e.g., avg maps to sum and count)
}
func (aggMap *AggregationFieldMap) Count() int {
return len(aggMap.userOriginalOutputFields)
}
// IndexAt returns the first index for the given user output field index.
// For avg aggregation, this returns the sum index. The method is kept for
// backward compatibility.
func (aggMap *AggregationFieldMap) IndexAt(idx int) int {
if len(aggMap.userOriginalOutputFieldIdxes[idx]) > 0 {
return aggMap.userOriginalOutputFieldIdxes[idx][0]
}
return -1
}
// IndexesAt returns all indices for the given user output field index.
// For avg aggregation, this returns both sum and count indices.
// For other aggregations, this returns a slice with a single index.
func (aggMap *AggregationFieldMap) IndexesAt(idx int) []int {
return aggMap.userOriginalOutputFieldIdxes[idx]
}
func (aggMap *AggregationFieldMap) NameAt(idx int) string {
return aggMap.userOriginalOutputFields[idx]
}
func NewAggregationFieldMap(originalUserOutputFields []string, groupByFields []string, aggs []AggregateBase) *AggregationFieldMap {
numGroupingKeys := len(groupByFields)
groupByFieldMap := make(map[string]int, len(groupByFields))
for i, field := range groupByFields {
groupByFieldMap[field] = i
}
// Build a map from originalName to all indices (for avg, this will include both sum and count indices)
aggFieldMap := make(map[string][]int, len(aggs))
for i, agg := range aggs {
originalName := agg.OriginalName()
idx := i + numGroupingKeys
// Check if this aggregate is part of an avg aggregation
var isAvg bool
switch a := agg.(type) {
case *SumAggregate:
isAvg = a.isAvg
case *CountAggregate:
isAvg = a.isAvg
}
if isAvg {
// For avg aggregates, both sum and count share the same originalName
// Add this index to the list for this originalName
aggFieldMap[originalName] = append(aggFieldMap[originalName], idx)
} else {
// For non-avg aggregates, each originalName maps to a single index
aggFieldMap[originalName] = []int{idx}
}
}
userOriginalOutputFieldIdxes := make([][]int, len(originalUserOutputFields))
for i, outputField := range originalUserOutputFields {
if idx, exist := groupByFieldMap[outputField]; exist {
// Group by field maps to a single index
userOriginalOutputFieldIdxes[i] = []int{idx}
} else if indices, exist := aggFieldMap[outputField]; exist {
// Aggregate field may map to multiple indices (for avg: sum and count)
userOriginalOutputFieldIdxes[i] = indices
} else {
// Field not found, set empty slice
userOriginalOutputFieldIdxes[i] = []int{}
}
}
return &AggregationFieldMap{
	userOriginalOutputFields:     originalUserOutputFields,
	userOriginalOutputFieldIdxes: userOriginalOutputFieldIdxes,
}
}
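// Illustrative example (hypothetical field names): with
//   originalUserOutputFields = ["color", "avg(price)"]
//   groupByFields            = ["color"]
//   aggs                     = [sum(price){isAvg}, count(price){isAvg}]
// the map resolves "color" to []int{0} and "avg(price)" to []int{1, 2},
// i.e. the sum and count columns that ComputeAvgFromSumAndCount later
// folds into a single Double column.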
// ComputeAvgFromSumAndCount computes average from sum and count field data.
// It takes sumFieldData and countFieldData, computes avg = sum / count for each row,
// and returns a new Double FieldData containing the average values.
func ComputeAvgFromSumAndCount(sumFieldData *schemapb.FieldData, countFieldData *schemapb.FieldData) (*schemapb.FieldData, error) {
if sumFieldData == nil || countFieldData == nil {
return nil, fmt.Errorf("sumFieldData and countFieldData cannot be nil")
}
sumType := sumFieldData.GetType()
countType := countFieldData.GetType()
if countType != schemapb.DataType_Int64 {
return nil, fmt.Errorf("count field must be Int64 type, got %s", countType.String())
}
countData := countFieldData.GetScalars().GetLongData().GetData()
rowCount := len(countData)
// Create result FieldData with Double type
result := &schemapb.FieldData{
Type: schemapb.DataType_Double,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_DoubleData{
DoubleData: &schemapb.DoubleArray{}, // Data is assigned once computed below
},
},
},
}
resultData := make([]float64, 0, rowCount)
// Compute avg = sum / count for each row
switch sumType {
case schemapb.DataType_Int64:
sumData := sumFieldData.GetScalars().GetLongData().GetData()
if len(sumData) != rowCount {
return nil, fmt.Errorf("sum and count field data must have the same length, got sum:%d, count:%d", len(sumData), rowCount)
}
for i := 0; i < rowCount; i++ {
if countData[i] == 0 {
return nil, fmt.Errorf("division by zero: count is 0 at row %d", i)
}
resultData = append(resultData, float64(sumData[i])/float64(countData[i]))
}
case schemapb.DataType_Double:
sumData := sumFieldData.GetScalars().GetDoubleData().GetData()
if len(sumData) != rowCount {
return nil, fmt.Errorf("sum and count field data must have the same length, got sum:%d, count:%d", len(sumData), rowCount)
}
for i := 0; i < rowCount; i++ {
if countData[i] == 0 {
return nil, fmt.Errorf("division by zero: count is 0 at row %d", i)
}
resultData = append(resultData, sumData[i]/float64(countData[i]))
}
default:
return nil, fmt.Errorf("unsupported sum field type for avg computation: %s", sumType.String())
}
result.GetScalars().GetDoubleData().Data = resultData
return result, nil
}
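// Minimal usage sketch (hypothetical helper, not part of the original file):
// resolve the two columns an avg output field maps to and fold them into the
// final Double column. Assumes the sum column was registered before count,
// as IndexAt's doc comment indicates.
func resolveAvgColumn(aggMap *AggregationFieldMap, outIdx int, cols []*schemapb.FieldData) (*schemapb.FieldData, error) {
	idxes := aggMap.IndexesAt(outIdx)
	if len(idxes) != 2 {
		return nil, fmt.Errorf("avg output %q expects sum and count columns, got %d", aggMap.NameAt(outIdx), len(idxes))
	}
	return ComputeAvgFromSumAndCount(cols[idxes[0]], cols[idxes[1]])
}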