enhance: storage sort can sort by multiple fields (#43994)

https://github.com/milvus-io/milvus/issues/44011
This lays the groundwork for a future compaction path that sorts records
by partition key first and by primary key second.
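
A minimal sketch of what that future call site could look like, assuming a
typeutil.GetPartitionKeyFieldSchema helper analogous to the
GetPrimaryFieldSchema lookup used below; names and plumbing here are
illustrative only, not part of this change:

    package example

    import (
        "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
        "github.com/milvus-io/milvus/internal/storage"
        "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
    )

    // sortByPartitionKeyThenPK orders rows by partition key,
    // breaking ties by primary key.
    func sortByPartitionKeyThenPK(batchSize uint64, schema *schemapb.CollectionSchema,
        readers []storage.RecordReader, writer storage.RecordWriter,
    ) (int, error) {
        partKey, err := typeutil.GetPartitionKeyFieldSchema(schema) // assumed helper, see above
        if err != nil {
            return 0, err
        }
        pk, err := typeutil.GetPrimaryFieldSchema(schema)
        if err != nil {
            return 0, err
        }
        keepAll := func(r storage.Record, ri, i int) bool { return true } // no row filtering
        return storage.Sort(batchSize, schema, readers, writer, keepAll,
            []int64{partKey.GetFieldID(), pk.GetFieldID()})
    }

Because the partition key comes first in sortByFieldIDs, rows group by
partition key and fall back to the pk only to break ties.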

---------

Signed-off-by: sunby <sunbingyi1992@gmail.com>
Bingyi Sun 2025-09-03 10:11:52 +08:00 committed by GitHub
parent d55bf49bf1
commit 6624011927
6 changed files with 119 additions and 26 deletions

View File

@@ -67,6 +67,7 @@ type sortCompactionTask struct {
	tr               *timerecord.TimeRecorder
	compactionParams compaction.Params
+	sortByFieldIDs   []int64
}

var _ Compactor = (*sortCompactionTask)(nil)
@@ -76,6 +77,7 @@ func NewSortCompactionTask(
	binlogIO io.BinlogIO,
	plan *datapb.CompactionPlan,
	compactionParams compaction.Params,
+	sortByFieldIDs []int64,
) *sortCompactionTask {
	ctx1, cancel := context.WithCancel(ctx)
	return &sortCompactionTask{
@@ -87,6 +89,7 @@
		currentTime: time.Now(),
		done: make(chan struct{}, 1),
		compactionParams: compactionParams,
+		sortByFieldIDs: sortByFieldIDs,
	}
}
@@ -216,7 +219,7 @@ func (t *sortCompactionTask) sortSegment(ctx context.Context) (*datapb.Compactio
	}
	defer rr.Close()
	rrs := []storage.RecordReader{rr}
-	numValidRows, err := storage.Sort(t.compactionParams.BinLogMaxSize, t.plan.GetSchema(), rrs, srw, predicate)
+	numValidRows, err := storage.Sort(t.compactionParams.BinLogMaxSize, t.plan.GetSchema(), rrs, srw, predicate, t.sortByFieldIDs)
	if err != nil {
		log.Warn("sort failed", zap.Error(err))
		return nil, err

View File

@@ -35,6 +35,7 @@ import (
	"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
+	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestSortCompactionTaskSuite(t *testing.T) {
@@ -84,7 +85,10 @@ func (s *SortCompactionTaskSuite) setupTest() {
		CollectionTtl: time.Since(getMilvusBirthday().Add(-time.Hour)).Nanoseconds(),
	}
-	s.task = NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams())
+	pk, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema())
+	s.NoError(err)
+	s.task = NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams(), []int64{pk.GetFieldID()})
}
func (s *SortCompactionTaskSuite) SetupTest() {
@@ -99,9 +103,13 @@ func (s *SortCompactionTaskSuite) TestNewSortCompactionTask() {
		PlanID: 123,
		Type: datapb.CompactionType_SortCompaction,
		SlotUsage: 8,
+		Schema: s.meta.GetSchema(),
	}
-	task := NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams())
+	pk, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema())
+	s.NoError(err)
+	task := NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams(), []int64{pk.GetFieldID()})
	s.NotNil(task)
	s.Equal(plan.GetPlanID(), task.GetPlanID())
@@ -242,7 +250,10 @@ func (s *SortCompactionTaskSuite) setupBM25Test() {
		TotalRows: 3,
	}
-	s.task = NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams())
+	pk, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema())
+	s.NoError(err)
+	s.task = NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams(), []int64{pk.GetFieldID()})
}
func (s *SortCompactionTaskSuite) prepareSortCompactionWithBM25Task() {
@@ -367,9 +378,13 @@ func TestSortCompactionTaskBasic(t *testing.T) {
		SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{
			{SegmentID: 100},
		},
+		Schema: genTestCollectionMeta().GetSchema(),
	}
-	task := NewSortCompactionTask(ctx, mockBinlogIO, plan, compaction.GenParams())
+	pk, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema())
+	assert.NoError(t, err)
+	task := NewSortCompactionTask(ctx, mockBinlogIO, plan, compaction.GenParams(), []int64{pk.GetFieldID()})
	assert.NotNil(t, task)
	assert.Equal(t, int64(123), task.GetPlanID())

View File

@@ -252,7 +252,7 @@ func (st *statsTask) sort(ctx context.Context) ([]*datapb.FieldBinlog, error) {
	defer rr.Close()
	rrs := []storage.RecordReader{rr}
-	numValidRows, err := storage.Sort(st.req.GetBinlogMaxSize(), st.req.GetSchema(), rrs, srw, predicate)
+	numValidRows, err := storage.Sort(st.req.GetBinlogMaxSize(), st.req.GetSchema(), rrs, srw, predicate, []int64{pkField.FieldID})
	if err != nil {
		log.Warn("sort failed", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
		return nil, err

View File

@@ -234,11 +234,16 @@ func (node *DataNode) CompactionV2(ctx context.Context, req *datapb.CompactionPl
		if req.GetPreAllocatedSegmentIDs() == nil || req.GetPreAllocatedSegmentIDs().GetBegin() == 0 {
			return merr.Status(merr.WrapErrParameterInvalidMsg("invalid pre-allocated segmentID range")), nil
		}
+		pk, err := typeutil.GetPrimaryFieldSchema(req.GetSchema())
+		if err != nil {
+			return merr.Status(err), err
+		}
		task = compactor.NewSortCompactionTask(
			taskCtx,
			binlogIO,
			req,
			compactionParams,
+			[]int64{pk.GetFieldID()},
		)
	default:
		log.Warn("Unknown compaction type", zap.String("type", req.GetType().String()))

View File

@@ -24,11 +24,12 @@ import (
	"github.com/apache/arrow/go/v17/arrow/array"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

func Sort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader,
-	rw RecordWriter, predicate func(r Record, ri, i int) bool,
+	rw RecordWriter, predicate func(r Record, ri, i int) bool, sortByFieldIDs []int64,
) (int, error) {
	records := make([]Record, 0)
@@ -69,24 +70,55 @@ func Sort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader
		return 0, nil
	}
-	pkField, err := typeutil.GetPrimaryFieldSchema(schema)
-	if err != nil {
-		return 0, err
-	}
-	pkFieldId := pkField.FieldID
+	if len(sortByFieldIDs) > 0 {
+		type keyCmp func(x, y *index) int
+		comparators := make([]keyCmp, 0, len(sortByFieldIDs))
+		for _, fid := range sortByFieldIDs {
+			switch records[0].Column(fid).(type) {
+			case *array.Int64:
+				f := func(x, y *index) int {
+					xVal := records[x.ri].Column(fid).(*array.Int64).Value(x.i)
+					yVal := records[y.ri].Column(fid).(*array.Int64).Value(y.i)
+					if xVal < yVal {
+						return -1
+					}
+					if xVal > yVal {
+						return 1
+					}
+					return 0
+				}
+				comparators = append(comparators, f)
+			case *array.String:
+				f := func(x, y *index) int {
+					xVal := records[x.ri].Column(fid).(*array.String).Value(x.i)
+					yVal := records[y.ri].Column(fid).(*array.String).Value(y.i)
+					if xVal < yVal {
+						return -1
+					}
+					if xVal > yVal {
+						return 1
+					}
+					return 0
+				}
+				comparators = append(comparators, f)
+			default:
+				return 0, merr.WrapErrParameterInvalidMsg("unsupported type for sorting key")
+			}
+		}
-	switch records[0].Column(pkFieldId).(type) {
-	case *array.Int64:
-		sort.Slice(indices, func(i, j int) bool {
-			pki := records[indices[i].ri].Column(pkFieldId).(*array.Int64).Value(indices[i].i)
-			pkj := records[indices[j].ri].Column(pkFieldId).(*array.Int64).Value(indices[j].i)
-			return pki < pkj
-		})
-	case *array.String:
		sort.Slice(indices, func(i, j int) bool {
-			pki := records[indices[i].ri].Column(pkFieldId).(*array.String).Value(indices[i].i)
-			pkj := records[indices[j].ri].Column(pkFieldId).(*array.String).Value(indices[j].i)
-			return pki < pkj
+			x := indices[i]
+			y := indices[j]
+			for _, cmp := range comparators {
+				c := cmp(x, y)
+				if c < 0 {
+					return true
+				}
+				if c > 0 {
+					return false
+				}
+			}
+			return false
		})
	}
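
The comparator chain above is a plain lexicographic sort: the first field
whose values differ decides the order, and rows whose keys are all equal
compare as not-less. A self-contained sketch of the same pattern (names and
data here are illustrative only, not from this patch):

    package main

    import (
        "cmp"
        "fmt"
        "sort"
    )

    // row stands in for one record's sort key: partition key first, then primary key.
    type row struct{ partKey, pk int64 }

    // keyCmp mirrors the patch's comparator shape: negative, zero, or positive.
    type keyCmp func(a, b row) int

    func main() {
        comparators := []keyCmp{
            func(a, b row) int { return cmp.Compare(a.partKey, b.partKey) },
            func(a, b row) int { return cmp.Compare(a.pk, b.pk) },
        }
        rows := []row{{2, 1}, {1, 9}, {1, 3}, {2, 0}}
        sort.Slice(rows, func(i, j int) bool {
            for _, c := range comparators {
                if r := c(rows[i], rows[j]); r != 0 {
                    return r < 0 // first differing key decides
                }
            }
            return false // all keys equal: not less-than
        })
        fmt.Println(rows) // [{1 3} {1 9} {2 0} {2 1}]
    }

Note that sort.Slice is not stable, so rows whose keys are fully equal may
land in any relative order; with a unique trailing key (for example the pk)
ties cannot occur.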

View File

@@ -59,7 +59,7 @@ func TestSort(t *testing.T) {
	t.Run("sort", func(t *testing.T) {
		gotNumRows, err := Sort(batchSize, generateTestSchema(), getReaders(), rw, func(r Record, ri, i int) bool {
			return true
-		})
+		}, []int64{common.RowIDField})
		assert.NoError(t, err)
		assert.Equal(t, 6, gotNumRows)
		err = rw.Close()
@@ -70,7 +70,7 @@
		gotNumRows, err := Sort(batchSize, generateTestSchema(), getReaders(), rw, func(r Record, ri, i int) bool {
			pk := r.Column(common.RowIDField).(*array.Int64).Value(i)
			return pk >= 20
-		})
+		}, []int64{common.RowIDField})
		assert.NoError(t, err)
		assert.Equal(t, 3, gotNumRows)
		err = rw.Close()
@@ -163,7 +163,45 @@ func BenchmarkSort(b *testing.B) {
		for i := 0; i < b.N; i++ {
			Sort(batchSize, generateTestSchema(), rr, rw, func(r Record, ri, i int) bool {
				return true
-			})
+			}, []int64{common.RowIDField})
		}
	})
}
+func TestSortByMoreThanOneField(t *testing.T) {
+	const batchSize = 10000
+	sortByFieldIDs := []int64{common.RowIDField, common.TimeStampField}
+	blobs, err := generateTestDataWithSeed(10, batchSize)
+	assert.NoError(t, err)
+	reader10, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs))
+	assert.NoError(t, err)
+	blobs, err = generateTestDataWithSeed(20, batchSize)
+	assert.NoError(t, err)
+	reader20, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs))
+	assert.NoError(t, err)
+	rr := []RecordReader{reader20, reader10}
+	lastPK := int64(-1)
+	lastTS := int64(-1)
+	rw := &MockRecordWriter{
+		writefn: func(r Record) error {
+			pk := r.Column(common.RowIDField).(*array.Int64).Value(0)
+			ts := r.Column(common.TimeStampField).(*array.Int64).Value(0)
+			assert.True(t, pk > lastPK || (pk == lastPK && ts > lastTS))
+			lastPK = pk
+			lastTS = ts
+			return nil
+		},
+		closefn: func() error {
+			lastPK = int64(-1)
+			lastTS = int64(-1)
+			return nil
+		},
+	}
+	gotNumRows, err := Sort(batchSize, generateTestSchema(), rr, rw, func(r Record, ri, i int) bool {
+		return true
+	}, sortByFieldIDs)
+	assert.NoError(t, err)
+	assert.Equal(t, batchSize*2, gotNumRows)
+	assert.NoError(t, rw.Close())
+}