package agg

import (
	"fmt"
	"reflect"
	"regexp"
	"strings"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
)
const (
	kSum   = "sum"
	kCount = "count"
	kAvg   = "avg"
	kMin   = "min"
	kMax   = "max"
)

var (
	// Build and compile the aggregation pattern once at package init to avoid
	// recompiling the regular expression on every match.
	aggregationTypes   = kSum + `|` + kCount + `|` + kAvg + `|` + kMin + `|` + kMax
	aggregationPattern = regexp.MustCompile(`(?i)^(` + aggregationTypes + `)\s*\(\s*([\w\*]*)\s*\)$`)
)
// MatchAggregationExpression reports whether the expression is an aggregation
// call and, if so, returns the lower-cased operator name and the trimmed
// operator parameter.
func MatchAggregationExpression(expression string) (bool, string, string) {
	// FindStringSubmatch returns the full match followed by the submatches.
	matches := aggregationPattern.FindStringSubmatch(expression)
	if len(matches) > 0 {
		// Return true, the operator, and the captured parameter.
		return true, strings.ToLower(matches[1]), strings.TrimSpace(matches[2])
	}
	return false, "", ""
}
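// A minimal usage sketch (illustrative only; the literal expressions below
// are made-up inputs, not part of the original file):
//
//	isAgg, op, param := MatchAggregationExpression("SUM( price )")
//	// isAgg == true, op == "sum", param == "price"
//	isAgg, op, param = MatchAggregationExpression("count(*)")
//	// isAgg == true, op == "count", param == "*"
//	isAgg, op, param = MatchAggregationExpression("price + 1")
//	// isAgg == false, op == "", param == ""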
// AggregateBase is the common interface implemented by every aggregate
// operator (count, sum, min, max, and the sum/count pair backing avg).
type AggregateBase interface {
	Name() string
	Update(target *FieldValue, incoming *FieldValue) error
	ToPB() *planpb.Aggregate
	FieldID() int64
	OriginalName() string
}
func isSupportedAggregateName(aggregateName string) bool {
	switch aggregateName {
	case kCount, kSum, kAvg, kMin, kMax:
		return true
	default:
		return false
	}
}
// NewAggregate validates the operator name and the field type, then constructs
// the aggregate operator(s) for the given field. avg expands into a sum/count
// pair, which is why a slice is returned.
func NewAggregate(aggregateName string, aggFieldID int64, originalName string, fieldType schemapb.DataType) ([]AggregateBase, error) {
	if !isSupportedAggregateName(aggregateName) {
		return nil, fmt.Errorf("invalid Aggregation operator %s", aggregateName)
	}

	if err := ValidateAggFieldType(aggregateName, fieldType); err != nil {
		return nil, err
	}

	switch aggregateName {
	case kCount:
		return []AggregateBase{&CountAggregate{fieldID: aggFieldID, originalName: originalName, isAvg: false}}, nil
	case kSum:
		return []AggregateBase{&SumAggregate{fieldID: aggFieldID, originalName: originalName, isAvg: false}}, nil
	case kAvg:
		// avg is implemented as a sum/count pair; the final value is computed
		// as sum/count later.
		return []AggregateBase{
			&SumAggregate{fieldID: aggFieldID, originalName: originalName, isAvg: true},
			&CountAggregate{fieldID: aggFieldID, originalName: originalName, isAvg: true},
		}, nil
	case kMin:
		return []AggregateBase{&MinAggregate{fieldID: aggFieldID, originalName: originalName}}, nil
	case kMax:
		return []AggregateBase{&MaxAggregate{fieldID: aggFieldID, originalName: originalName}}, nil
	default:
		// unreachable: guarded by the isSupportedAggregateName check above
		return nil, fmt.Errorf("invalid Aggregation operator %s", aggregateName)
	}
}
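// A minimal construction sketch (illustrative only; the field ID 101 and the
// Int64 field type are hypothetical, and success assumes ValidateAggFieldType
// accepts that combination):
//
//	aggs, err := NewAggregate(kAvg, 101, "avg(price)", schemapb.DataType_Int64)
//	// on success, len(aggs) == 2: a SumAggregate and a CountAggregate, both
//	// flagged isAvg so the final avg can be derived as sum/count downstream.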
// FromPB reconstructs an aggregate operator from its plan protobuf form. Note
// that avg has no case here: it is serialized as its sum and count parts.
func FromPB(pb *planpb.Aggregate) (AggregateBase, error) {
	switch pb.Op {
	case planpb.AggregateOp_count:
		return &CountAggregate{fieldID: pb.GetFieldId(), isAvg: false}, nil
	case planpb.AggregateOp_sum:
		return &SumAggregate{fieldID: pb.GetFieldId(), isAvg: false}, nil
	case planpb.AggregateOp_min:
		return &MinAggregate{fieldID: pb.GetFieldId()}, nil
	case planpb.AggregateOp_max:
		return &MaxAggregate{fieldID: pb.GetFieldId()}, nil
	default:
		return nil, fmt.Errorf("invalid Aggregation operator %d", pb.Op)
	}
}
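// A minimal round-trip sketch (illustrative only, continuing the aggs slice
// from the sketch above and assuming ToPB emits at least the operator and
// field ID):
//
//	pbs := AggregatesToPB(aggs)
//	restored, err := FromPB(pbs[0])
//	// restored is a SumAggregate with the same field ID; as the constructors
//	// above show, originalName and the isAvg flag are not restored here.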
// AccumulateFieldValue adds incoming.val into target.val for supported numeric
// types, initializing target.val on first use and rejecting type mismatches.
func AccumulateFieldValue(target *FieldValue, incoming *FieldValue) error {
	if target == nil || incoming == nil {
		return fmt.Errorf("target or incoming field value is nil")
	}

	// A nil val means the target has not been initialized yet.
	if target.val == nil {
		target.val = incoming.val
		return nil
	}
	// Verify that the types match before accumulating.
	if reflect.TypeOf(target.val) != reflect.TypeOf(incoming.val) {
		return fmt.Errorf("type mismatch: target is %T, incoming is %T", target.val, incoming.val)
	}

	switch v := target.val.(type) {
	case int:
		target.val = v + incoming.val.(int)
	case int32:
		target.val = v + incoming.val.(int32)
	case int64:
		target.val = v + incoming.val.(int64)
	case float32:
		target.val = v + incoming.val.(float32)
	case float64:
		target.val = v + incoming.val.(float64)
	default:
		return fmt.Errorf("unsupported type: %T", target.val)
	}
	return nil
}
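// A minimal accumulation sketch (illustrative only):
//
//	total := NewFieldValue(nil) // nil val: the first call initializes it
//	_ = AccumulateFieldValue(total, NewFieldValue(int64(3))) // total.val == int64(3)
//	_ = AccumulateFieldValue(total, NewFieldValue(int64(4))) // total.val == int64(7)
//	err := AccumulateFieldValue(total, NewFieldValue(float64(1))) // type mismatch error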
func AggregatesToPB(aggregates []AggregateBase) []*planpb.Aggregate {
	ret := make([]*planpb.Aggregate, len(aggregates))
	for idx, agg := range aggregates {
		ret[idx] = agg.ToPB()
	}
	return ret
}
type FieldValue struct {
	val interface{}
}

func NewFieldValue(v interface{}) *FieldValue {
	return &FieldValue{val: v}
}
// Row holds the group-by key columns followed by the aggregate columns.
type Row struct {
	fieldValues []*FieldValue
}

func (r *Row) ColumnCount() int {
	return len(r.fieldValues)
}

func (r *Row) ValAt(col int) interface{} {
	return r.fieldValues[col].val
}
// Equal reports whether the first keyCount columns of the two rows match.
func (r *Row) Equal(other *Row, keyCount int) bool {
	// Ensure both rows carry enough fields for the key comparison.
	if len(r.fieldValues) < keyCount || len(other.fieldValues) < keyCount || len(r.fieldValues) != len(other.fieldValues) {
		return false
	}
	// Compare only the key columns for equality.
	for i := 0; i < keyCount; i++ {
		if r.fieldValues[i].val != other.fieldValues[i].val {
			return false
		}
	}
	return true
}
// UpdateFieldValue merges newRow's value at col into this row via the given
// aggregate, propagating any accumulation error instead of dropping it.
func (r *Row) UpdateFieldValue(newRow *Row, col int, agg AggregateBase) error {
	return agg.Update(r.fieldValues[col], newRow.fieldValues[col])
}
func (r *Row) ToString() string {
	var builder strings.Builder
	builder.WriteString("agg-row:")
	for _, fv := range r.fieldValues {
		builder.WriteString(fmt.Sprintf("%v,", fv.val))
	}
	return builder.String()
}

func NewRow(fieldValues []*FieldValue) *Row {
	return &Row{fieldValues: fieldValues}
}
// Bucket holds aggregation result rows; rows with equal keys are merged in
// place via Accumulate rather than appended.
type Bucket struct {
	rows []*Row
}

func (bucket *Bucket) AddRow(row *Row) {
	bucket.rows = append(bucket.rows, row)
}

func (bucket *Bucket) RowAt(idx int) *Row {
	return bucket.rows[idx]
}

func (bucket *Bucket) RowCount() int {
	return len(bucket.rows)
}
// Accumulate merges the incoming row into the bucket row at idx: key columns
// are left untouched, and each aggregate column is updated in turn.
func (bucket *Bucket) Accumulate(row *Row, idx int, keyCount int, aggs []AggregateBase) error {
	if idx >= len(bucket.rows) || idx < 0 {
		return fmt.Errorf("wrong idx:%d for bucket", idx)
	}
	targetRow := bucket.rows[idx]
	if targetRow == nil {
		return fmt.Errorf("nil row at the target idx:%d, cannot accumulate the row", idx)
	}
	if row.ColumnCount() != targetRow.ColumnCount() {
		return fmt.Errorf("column count:%d in the row must equal the target row's column count:%d", row.ColumnCount(), targetRow.ColumnCount())
	}
	if row.ColumnCount() != keyCount+len(aggs) {
		return fmt.Errorf("column count:%d in the row must be the sum of keyCount:%d and the number of aggs:%d", row.ColumnCount(), keyCount, len(aggs))
	}
	for col := keyCount; col < row.ColumnCount(); col++ {
		if err := targetRow.UpdateFieldValue(row, col, aggs[col-keyCount]); err != nil {
			return err
		}
	}
	return nil
}
// NONE is the sentinel returned by Find when no row in the bucket matches the
// key columns of the probe row.
const NONE int = -1

func (bucket *Bucket) Find(row *Row, keyCount int) int {
	for idx, existingRow := range bucket.rows {
		if existingRow.Equal(row, keyCount) {
			return idx
		}
	}
	return NONE
}

func NewBucket() *Bucket {
	return &Bucket{rows: make([]*Row, 0, 1)}
}
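// A minimal grouping sketch (illustrative only; keyCount is 1 and aggs is a
// hypothetical slice of aggregates built elsewhere, e.g. by NewAggregate):
//
//	bucket := NewBucket()
//	row := NewRow([]*FieldValue{NewFieldValue("groupA"), NewFieldValue(int64(5))})
//	if idx := bucket.Find(row, 1); idx == NONE {
//		bucket.AddRow(row) // first row seen for this key
//	} else {
//		_ = bucket.Accumulate(row, idx, 1, aggs) // merge aggregate columns in place
//	}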