enhance: merge sort supports multiple sort fields (#44191)

issue: #44011

---------

Signed-off-by: sunby <sunbingyi1992@gmail.com>
This commit is contained in:
Bingyi Sun 2025-09-04 19:53:53 +08:00 committed by GitHub
parent 2e98cb0103
commit ef392bb1b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 44 additions and 38 deletions

View File

@ -106,7 +106,7 @@ func mergeSortMultipleSegments(ctx context.Context,
log.Warn("compaction only support int64 and varchar pk field")
}
if _, err = storage.MergeSort(compactionParams.BinLogMaxSize, plan.GetSchema(), segmentReaders, writer, predicate); err != nil {
if _, err = storage.MergeSort(compactionParams.BinLogMaxSize, plan.GetSchema(), segmentReaders, writer, predicate, []int64{pkField.FieldID}); err != nil {
writer.Close()
return nil, err
}

View File

@ -25,7 +25,6 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func Sort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader,
@ -202,7 +201,7 @@ func NewPriorityQueue[T any](less func(x, y *T) bool) *PriorityQueue[T] {
}
func MergeSort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader,
rw RecordWriter, predicate func(r Record, ri, i int) bool,
rw RecordWriter, predicate func(r Record, ri, i int) bool, sortedByFieldIDs []int64,
) (numRows int, err error) {
// Fast path: no readers provided
if len(rr) == 0 {
@ -231,43 +230,50 @@ func MergeSort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordR
}
}
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
if err != nil {
return 0, err
comparators := make([]func(x, y *index) int, 0, len(sortedByFieldIDs))
for _, fid := range sortedByFieldIDs {
switch recs[0].Column(fid).(type) {
case *array.Int64:
comparators = append(comparators, func(x, y *index) int {
xVal := recs[x.ri].Column(fid).(*array.Int64).Value(x.i)
yVal := recs[y.ri].Column(fid).(*array.Int64).Value(y.i)
if xVal < yVal {
return -1
}
if xVal > yVal {
return 1
}
return 0
})
case *array.String:
comparators = append(comparators, func(x, y *index) int {
xVal := recs[x.ri].Column(fid).(*array.String).Value(x.i)
yVal := recs[y.ri].Column(fid).(*array.String).Value(y.i)
if xVal < yVal {
return -1
}
if xVal > yVal {
return 1
}
return 0
})
default:
return 0, merr.WrapErrParameterInvalidMsg("unsupported type for sorting key")
}
}
pkFieldId := pkField.FieldID
var pq *PriorityQueue[index]
switch recs[0].Column(pkFieldId).(type) {
case *array.Int64:
pq = NewPriorityQueue(func(x, y *index) bool {
xVal := recs[x.ri].Column(pkFieldId).(*array.Int64).Value(x.i)
yVal := recs[y.ri].Column(pkFieldId).(*array.Int64).Value(y.i)
if xVal != yVal {
return xVal < yVal
pq := NewPriorityQueue(func(x, y *index) bool {
for _, cmp := range comparators {
c := cmp(x, y)
if c < 0 {
return true
}
if x.ri != y.ri {
return x.ri < y.ri
if c > 0 {
return false
}
return x.i < y.i
})
case *array.String:
pq = NewPriorityQueue(func(x, y *index) bool {
xVal := recs[x.ri].Column(pkFieldId).(*array.String).Value(x.i)
yVal := recs[y.ri].Column(pkFieldId).(*array.String).Value(y.i)
if xVal != yVal {
return xVal < yVal
}
if x.ri != y.ri {
return x.ri < y.ri
}
return x.i < y.i
})
}
}
return false
})
endPositions := make([]int, len(recs))
var enqueueAll func(ri int) error

View File

@ -112,7 +112,7 @@ func TestMergeSort(t *testing.T) {
t.Run("merge sort", func(t *testing.T) {
gotNumRows, err := MergeSort(batchSize, generateTestSchema(), getReaders(), rw, func(r Record, ri, i int) bool {
return true
})
}, []int64{common.RowIDField})
assert.NoError(t, err)
assert.Equal(t, 10000, gotNumRows)
err = rw.Close()
@ -125,7 +125,7 @@ func TestMergeSort(t *testing.T) {
// cover the cases where an entire record (1024 rows) is deleted, or where the last data in a record is deleted
// index 1023 is deleted. records (1024-2048) and (5000-6023) are all deleted
return pk < 2000 || (pk >= 3050 && pk < 5000) || pk >= 7000
})
}, []int64{common.RowIDField})
assert.NoError(t, err)
assert.Equal(t, 5950, gotNumRows)
err = rw.Close()