From ef392bb1b2aee2bfac71fde2e1b66285d7d4b87b Mon Sep 17 00:00:00 2001 From: Bingyi Sun Date: Thu, 4 Sep 2025 19:53:53 +0800 Subject: [PATCH] enhance: merge sort support multiple fields (#44191) issue: #44011 --------- Signed-off-by: sunby --- internal/datanode/compactor/merge_sort.go | 2 +- internal/storage/sort.go | 76 ++++++++++++----------- internal/storage/sort_test.go | 4 +- 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/internal/datanode/compactor/merge_sort.go b/internal/datanode/compactor/merge_sort.go index 14d2825cc5..3e01a010ff 100644 --- a/internal/datanode/compactor/merge_sort.go +++ b/internal/datanode/compactor/merge_sort.go @@ -106,7 +106,7 @@ func mergeSortMultipleSegments(ctx context.Context, log.Warn("compaction only support int64 and varchar pk field") } - if _, err = storage.MergeSort(compactionParams.BinLogMaxSize, plan.GetSchema(), segmentReaders, writer, predicate); err != nil { + if _, err = storage.MergeSort(compactionParams.BinLogMaxSize, plan.GetSchema(), segmentReaders, writer, predicate, []int64{pkField.FieldID}); err != nil { writer.Close() return nil, err } diff --git a/internal/storage/sort.go b/internal/storage/sort.go index dee25adc0f..436c6558c9 100644 --- a/internal/storage/sort.go +++ b/internal/storage/sort.go @@ -25,7 +25,6 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/v2/util/merr" - "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) func Sort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader, @@ -202,7 +201,7 @@ func NewPriorityQueue[T any](less func(x, y *T) bool) *PriorityQueue[T] { } func MergeSort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader, - rw RecordWriter, predicate func(r Record, ri, i int) bool, + rw RecordWriter, predicate func(r Record, ri, i int) bool, sortedByFieldIDs []int64, ) (numRows int, err error) { // Fast path: no readers provided if len(rr) == 0 { @@ -231,43 +230,50 @@ func MergeSort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordR } } - pkField, err := typeutil.GetPrimaryFieldSchema(schema) - if err != nil { - return 0, err + comparators := make([]func(x, y *index) int, 0, len(sortedByFieldIDs)) + for _, fid := range sortedByFieldIDs { + switch recs[0].Column(fid).(type) { + case *array.Int64: + comparators = append(comparators, func(x, y *index) int { + xVal := recs[x.ri].Column(fid).(*array.Int64).Value(x.i) + yVal := recs[y.ri].Column(fid).(*array.Int64).Value(y.i) + if xVal < yVal { + return -1 + } + if xVal > yVal { + return 1 + } + return 0 + }) + case *array.String: + comparators = append(comparators, func(x, y *index) int { + xVal := recs[x.ri].Column(fid).(*array.String).Value(x.i) + yVal := recs[y.ri].Column(fid).(*array.String).Value(y.i) + if xVal < yVal { + return -1 + } + if xVal > yVal { + return 1 + } + return 0 + }) + default: + return 0, merr.WrapErrParameterInvalidMsg("unsupported type for sorting key") + } } - pkFieldId := pkField.FieldID - var pq *PriorityQueue[index] - switch recs[0].Column(pkFieldId).(type) { - case *array.Int64: - pq = NewPriorityQueue(func(x, y *index) bool { - xVal := recs[x.ri].Column(pkFieldId).(*array.Int64).Value(x.i) - yVal := recs[y.ri].Column(pkFieldId).(*array.Int64).Value(y.i) - - if xVal != yVal { - return xVal < yVal + pq := NewPriorityQueue(func(x, y *index) bool { + for _, cmp := range comparators { + c := cmp(x, y) + if c < 0 { + return true } - - if x.ri != y.ri { - return x.ri < y.ri + if c > 0 { + return false } - return x.i < y.i - }) - case *array.String: - pq = NewPriorityQueue(func(x, y *index) bool { - xVal := recs[x.ri].Column(pkFieldId).(*array.String).Value(x.i) - yVal := recs[y.ri].Column(pkFieldId).(*array.String).Value(y.i) - - if xVal != yVal { - return xVal < yVal - } - - if x.ri != y.ri { - return x.ri < y.ri - } - return x.i < y.i - }) - } + } + return false + }) endPositions := make([]int, len(recs)) var enqueueAll func(ri int) error diff --git a/internal/storage/sort_test.go b/internal/storage/sort_test.go index 5e587883d3..82dd73d9f5 100644 --- a/internal/storage/sort_test.go +++ b/internal/storage/sort_test.go @@ -112,7 +112,7 @@ func TestMergeSort(t *testing.T) { t.Run("merge sort", func(t *testing.T) { gotNumRows, err := MergeSort(batchSize, generateTestSchema(), getReaders(), rw, func(r Record, ri, i int) bool { return true - }) + }, []int64{common.RowIDField}) assert.NoError(t, err) assert.Equal(t, 10000, gotNumRows) err = rw.Close() @@ -125,7 +125,7 @@ func TestMergeSort(t *testing.T) { // cover a single record (1024 rows) that is deleted, or the last data in the record is deleted // index 1023 is deleted. records (1024-2048) and (5000-6023) are all deleted return pk < 2000 || (pk >= 3050 && pk < 5000) || pk >= 7000 - }) + }, []int64{common.RowIDField}) assert.NoError(t, err) assert.Equal(t, 5950, gotNumRows) err = rw.Close()