mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
enhance: merge sort support multiple fields (#44191)
issue: #44011 --------- Signed-off-by: sunby <sunbingyi1992@gmail.com>
This commit is contained in:
parent
2e98cb0103
commit
ef392bb1b2
@ -106,7 +106,7 @@ func mergeSortMultipleSegments(ctx context.Context,
|
||||
log.Warn("compaction only support int64 and varchar pk field")
|
||||
}
|
||||
|
||||
if _, err = storage.MergeSort(compactionParams.BinLogMaxSize, plan.GetSchema(), segmentReaders, writer, predicate); err != nil {
|
||||
if _, err = storage.MergeSort(compactionParams.BinLogMaxSize, plan.GetSchema(), segmentReaders, writer, predicate, []int64{pkField.FieldID}); err != nil {
|
||||
writer.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -25,7 +25,6 @@ import (
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func Sort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader,
|
||||
@ -202,7 +201,7 @@ func NewPriorityQueue[T any](less func(x, y *T) bool) *PriorityQueue[T] {
|
||||
}
|
||||
|
||||
func MergeSort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader,
|
||||
rw RecordWriter, predicate func(r Record, ri, i int) bool,
|
||||
rw RecordWriter, predicate func(r Record, ri, i int) bool, sortedByFieldIDs []int64,
|
||||
) (numRows int, err error) {
|
||||
// Fast path: no readers provided
|
||||
if len(rr) == 0 {
|
||||
@ -231,43 +230,50 @@ func MergeSort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordR
|
||||
}
|
||||
}
|
||||
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
comparators := make([]func(x, y *index) int, 0, len(sortedByFieldIDs))
|
||||
for _, fid := range sortedByFieldIDs {
|
||||
switch recs[0].Column(fid).(type) {
|
||||
case *array.Int64:
|
||||
comparators = append(comparators, func(x, y *index) int {
|
||||
xVal := recs[x.ri].Column(fid).(*array.Int64).Value(x.i)
|
||||
yVal := recs[y.ri].Column(fid).(*array.Int64).Value(y.i)
|
||||
if xVal < yVal {
|
||||
return -1
|
||||
}
|
||||
if xVal > yVal {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
case *array.String:
|
||||
comparators = append(comparators, func(x, y *index) int {
|
||||
xVal := recs[x.ri].Column(fid).(*array.String).Value(x.i)
|
||||
yVal := recs[y.ri].Column(fid).(*array.String).Value(y.i)
|
||||
if xVal < yVal {
|
||||
return -1
|
||||
}
|
||||
if xVal > yVal {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
default:
|
||||
return 0, merr.WrapErrParameterInvalidMsg("unsupported type for sorting key")
|
||||
}
|
||||
}
|
||||
pkFieldId := pkField.FieldID
|
||||
|
||||
var pq *PriorityQueue[index]
|
||||
switch recs[0].Column(pkFieldId).(type) {
|
||||
case *array.Int64:
|
||||
pq = NewPriorityQueue(func(x, y *index) bool {
|
||||
xVal := recs[x.ri].Column(pkFieldId).(*array.Int64).Value(x.i)
|
||||
yVal := recs[y.ri].Column(pkFieldId).(*array.Int64).Value(y.i)
|
||||
|
||||
if xVal != yVal {
|
||||
return xVal < yVal
|
||||
pq := NewPriorityQueue(func(x, y *index) bool {
|
||||
for _, cmp := range comparators {
|
||||
c := cmp(x, y)
|
||||
if c < 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if x.ri != y.ri {
|
||||
return x.ri < y.ri
|
||||
if c > 0 {
|
||||
return false
|
||||
}
|
||||
return x.i < y.i
|
||||
})
|
||||
case *array.String:
|
||||
pq = NewPriorityQueue(func(x, y *index) bool {
|
||||
xVal := recs[x.ri].Column(pkFieldId).(*array.String).Value(x.i)
|
||||
yVal := recs[y.ri].Column(pkFieldId).(*array.String).Value(y.i)
|
||||
|
||||
if xVal != yVal {
|
||||
return xVal < yVal
|
||||
}
|
||||
|
||||
if x.ri != y.ri {
|
||||
return x.ri < y.ri
|
||||
}
|
||||
return x.i < y.i
|
||||
})
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
endPositions := make([]int, len(recs))
|
||||
var enqueueAll func(ri int) error
|
||||
|
||||
@ -112,7 +112,7 @@ func TestMergeSort(t *testing.T) {
|
||||
t.Run("merge sort", func(t *testing.T) {
|
||||
gotNumRows, err := MergeSort(batchSize, generateTestSchema(), getReaders(), rw, func(r Record, ri, i int) bool {
|
||||
return true
|
||||
})
|
||||
}, []int64{common.RowIDField})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10000, gotNumRows)
|
||||
err = rw.Close()
|
||||
@ -125,7 +125,7 @@ func TestMergeSort(t *testing.T) {
|
||||
// cover a single record (1024 rows) that is deleted, or the last data in the record is deleted
|
||||
// index 1023 is deleted. records (1024-2048) and (5000-6023) are all deleted
|
||||
return pk < 2000 || (pk >= 3050 && pk < 5000) || pk >= 7000
|
||||
})
|
||||
}, []int64{common.RowIDField})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5950, gotNumRows)
|
||||
err = rw.Close()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user