mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-02-04 11:18:44 +08:00
enhance: storage sort can sort by multiple fields (#43994)
https://github.com/milvus-io/milvus/issues/44011 this is to support compaction that sorts records by partition key and pk in the future --------- Signed-off-by: sunby <sunbingyi1992@gmail.com>
This commit is contained in:
parent
d55bf49bf1
commit
6624011927
@ -67,6 +67,7 @@ type sortCompactionTask struct {
|
||||
tr *timerecord.TimeRecorder
|
||||
|
||||
compactionParams compaction.Params
|
||||
sortByFieldIDs []int64
|
||||
}
|
||||
|
||||
var _ Compactor = (*sortCompactionTask)(nil)
|
||||
@ -76,6 +77,7 @@ func NewSortCompactionTask(
|
||||
binlogIO io.BinlogIO,
|
||||
plan *datapb.CompactionPlan,
|
||||
compactionParams compaction.Params,
|
||||
sortByFieldIDs []int64,
|
||||
) *sortCompactionTask {
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
return &sortCompactionTask{
|
||||
@ -87,6 +89,7 @@ func NewSortCompactionTask(
|
||||
currentTime: time.Now(),
|
||||
done: make(chan struct{}, 1),
|
||||
compactionParams: compactionParams,
|
||||
sortByFieldIDs: sortByFieldIDs,
|
||||
}
|
||||
}
|
||||
|
||||
@ -216,7 +219,7 @@ func (t *sortCompactionTask) sortSegment(ctx context.Context) (*datapb.Compactio
|
||||
}
|
||||
defer rr.Close()
|
||||
rrs := []storage.RecordReader{rr}
|
||||
numValidRows, err := storage.Sort(t.compactionParams.BinLogMaxSize, t.plan.GetSchema(), rrs, srw, predicate)
|
||||
numValidRows, err := storage.Sort(t.compactionParams.BinLogMaxSize, t.plan.GetSchema(), rrs, srw, predicate, t.sortByFieldIDs)
|
||||
if err != nil {
|
||||
log.Warn("sort failed", zap.Error(err))
|
||||
return nil, err
|
||||
|
||||
@ -35,6 +35,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func TestSortCompactionTaskSuite(t *testing.T) {
|
||||
@ -84,7 +85,10 @@ func (s *SortCompactionTaskSuite) setupTest() {
|
||||
CollectionTtl: time.Since(getMilvusBirthday().Add(-time.Hour)).Nanoseconds(),
|
||||
}
|
||||
|
||||
s.task = NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams())
|
||||
pk, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema())
|
||||
s.NoError(err)
|
||||
|
||||
s.task = NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams(), []int64{pk.GetFieldID()})
|
||||
}
|
||||
|
||||
func (s *SortCompactionTaskSuite) SetupTest() {
|
||||
@ -99,9 +103,13 @@ func (s *SortCompactionTaskSuite) TestNewSortCompactionTask() {
|
||||
PlanID: 123,
|
||||
Type: datapb.CompactionType_SortCompaction,
|
||||
SlotUsage: 8,
|
||||
Schema: s.meta.GetSchema(),
|
||||
}
|
||||
|
||||
task := NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams())
|
||||
pk, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema())
|
||||
s.NoError(err)
|
||||
|
||||
task := NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams(), []int64{pk.GetFieldID()})
|
||||
|
||||
s.NotNil(task)
|
||||
s.Equal(plan.GetPlanID(), task.GetPlanID())
|
||||
@ -242,7 +250,10 @@ func (s *SortCompactionTaskSuite) setupBM25Test() {
|
||||
TotalRows: 3,
|
||||
}
|
||||
|
||||
s.task = NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams())
|
||||
pk, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema())
|
||||
s.NoError(err)
|
||||
|
||||
s.task = NewSortCompactionTask(context.Background(), s.mockBinlogIO, plan, compaction.GenParams(), []int64{pk.GetFieldID()})
|
||||
}
|
||||
|
||||
func (s *SortCompactionTaskSuite) prepareSortCompactionWithBM25Task() {
|
||||
@ -367,9 +378,13 @@ func TestSortCompactionTaskBasic(t *testing.T) {
|
||||
SegmentBinlogs: []*datapb.CompactionSegmentBinlogs{
|
||||
{SegmentID: 100},
|
||||
},
|
||||
Schema: genTestCollectionMeta().GetSchema(),
|
||||
}
|
||||
|
||||
task := NewSortCompactionTask(ctx, mockBinlogIO, plan, compaction.GenParams())
|
||||
pk, err := typeutil.GetPrimaryFieldSchema(plan.GetSchema())
|
||||
assert.NoError(t, err)
|
||||
|
||||
task := NewSortCompactionTask(ctx, mockBinlogIO, plan, compaction.GenParams(), []int64{pk.GetFieldID()})
|
||||
|
||||
assert.NotNil(t, task)
|
||||
assert.Equal(t, int64(123), task.GetPlanID())
|
||||
|
||||
@ -252,7 +252,7 @@ func (st *statsTask) sort(ctx context.Context) ([]*datapb.FieldBinlog, error) {
|
||||
defer rr.Close()
|
||||
|
||||
rrs := []storage.RecordReader{rr}
|
||||
numValidRows, err := storage.Sort(st.req.GetBinlogMaxSize(), st.req.GetSchema(), rrs, srw, predicate)
|
||||
numValidRows, err := storage.Sort(st.req.GetBinlogMaxSize(), st.req.GetSchema(), rrs, srw, predicate, []int64{pkField.FieldID})
|
||||
if err != nil {
|
||||
log.Warn("sort failed", zap.Int64("taskID", st.req.GetTaskID()), zap.Error(err))
|
||||
return nil, err
|
||||
|
||||
@ -234,11 +234,16 @@ func (node *DataNode) CompactionV2(ctx context.Context, req *datapb.CompactionPl
|
||||
if req.GetPreAllocatedSegmentIDs() == nil || req.GetPreAllocatedSegmentIDs().GetBegin() == 0 {
|
||||
return merr.Status(merr.WrapErrParameterInvalidMsg("invalid pre-allocated segmentID range")), nil
|
||||
}
|
||||
pk, err := typeutil.GetPrimaryFieldSchema(req.GetSchema())
|
||||
if err != nil {
|
||||
return merr.Status(err), err
|
||||
}
|
||||
task = compactor.NewSortCompactionTask(
|
||||
taskCtx,
|
||||
binlogIO,
|
||||
req,
|
||||
compactionParams,
|
||||
[]int64{pk.GetFieldID()},
|
||||
)
|
||||
default:
|
||||
log.Warn("Unknown compaction type", zap.String("type", req.GetType().String()))
|
||||
|
||||
@ -24,11 +24,12 @@ import (
|
||||
"github.com/apache/arrow/go/v17/arrow/array"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func Sort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader,
|
||||
rw RecordWriter, predicate func(r Record, ri, i int) bool,
|
||||
rw RecordWriter, predicate func(r Record, ri, i int) bool, sortByFieldIDs []int64,
|
||||
) (int, error) {
|
||||
records := make([]Record, 0)
|
||||
|
||||
@ -69,24 +70,55 @@ func Sort(batchSize uint64, schema *schemapb.CollectionSchema, rr []RecordReader
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
pkField, err := typeutil.GetPrimaryFieldSchema(schema)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
pkFieldId := pkField.FieldID
|
||||
if len(sortByFieldIDs) > 0 {
|
||||
type keyCmp func(x, y *index) int
|
||||
comparators := make([]keyCmp, 0, len(sortByFieldIDs))
|
||||
for _, fid := range sortByFieldIDs {
|
||||
switch records[0].Column(fid).(type) {
|
||||
case *array.Int64:
|
||||
f := func(x, y *index) int {
|
||||
xVal := records[x.ri].Column(fid).(*array.Int64).Value(x.i)
|
||||
yVal := records[y.ri].Column(fid).(*array.Int64).Value(y.i)
|
||||
if xVal < yVal {
|
||||
return -1
|
||||
}
|
||||
if xVal > yVal {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
comparators = append(comparators, f)
|
||||
case *array.String:
|
||||
f := func(x, y *index) int {
|
||||
xVal := records[x.ri].Column(fid).(*array.String).Value(x.i)
|
||||
yVal := records[y.ri].Column(fid).(*array.String).Value(y.i)
|
||||
if xVal < yVal {
|
||||
return -1
|
||||
}
|
||||
if xVal > yVal {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
comparators = append(comparators, f)
|
||||
default:
|
||||
return 0, merr.WrapErrParameterInvalidMsg("unsupported type for sorting key")
|
||||
}
|
||||
}
|
||||
|
||||
switch records[0].Column(pkFieldId).(type) {
|
||||
case *array.Int64:
|
||||
sort.Slice(indices, func(i, j int) bool {
|
||||
pki := records[indices[i].ri].Column(pkFieldId).(*array.Int64).Value(indices[i].i)
|
||||
pkj := records[indices[j].ri].Column(pkFieldId).(*array.Int64).Value(indices[j].i)
|
||||
return pki < pkj
|
||||
})
|
||||
case *array.String:
|
||||
sort.Slice(indices, func(i, j int) bool {
|
||||
pki := records[indices[i].ri].Column(pkFieldId).(*array.String).Value(indices[i].i)
|
||||
pkj := records[indices[j].ri].Column(pkFieldId).(*array.String).Value(indices[j].i)
|
||||
return pki < pkj
|
||||
x := indices[i]
|
||||
y := indices[j]
|
||||
for _, cmp := range comparators {
|
||||
c := cmp(x, y)
|
||||
if c < 0 {
|
||||
return true
|
||||
}
|
||||
if c > 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -59,7 +59,7 @@ func TestSort(t *testing.T) {
|
||||
t.Run("sort", func(t *testing.T) {
|
||||
gotNumRows, err := Sort(batchSize, generateTestSchema(), getReaders(), rw, func(r Record, ri, i int) bool {
|
||||
return true
|
||||
})
|
||||
}, []int64{common.RowIDField})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 6, gotNumRows)
|
||||
err = rw.Close()
|
||||
@ -70,7 +70,7 @@ func TestSort(t *testing.T) {
|
||||
gotNumRows, err := Sort(batchSize, generateTestSchema(), getReaders(), rw, func(r Record, ri, i int) bool {
|
||||
pk := r.Column(common.RowIDField).(*array.Int64).Value(i)
|
||||
return pk >= 20
|
||||
})
|
||||
}, []int64{common.RowIDField})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, gotNumRows)
|
||||
err = rw.Close()
|
||||
@ -163,7 +163,45 @@ func BenchmarkSort(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
Sort(batchSize, generateTestSchema(), rr, rw, func(r Record, ri, i int) bool {
|
||||
return true
|
||||
})
|
||||
}, []int64{common.RowIDField})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSortByMoreThanOneField(t *testing.T) {
|
||||
const batchSize = 10000
|
||||
sortByFieldIDs := []int64{common.RowIDField, common.TimeStampField}
|
||||
|
||||
blobs, err := generateTestDataWithSeed(10, batchSize)
|
||||
assert.NoError(t, err)
|
||||
reader10, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs))
|
||||
assert.NoError(t, err)
|
||||
blobs, err = generateTestDataWithSeed(20, batchSize)
|
||||
assert.NoError(t, err)
|
||||
reader20, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs))
|
||||
assert.NoError(t, err)
|
||||
rr := []RecordReader{reader20, reader10}
|
||||
|
||||
lastPK := int64(-1)
|
||||
lastTS := int64(-1)
|
||||
rw := &MockRecordWriter{
|
||||
writefn: func(r Record) error {
|
||||
pk := r.Column(common.RowIDField).(*array.Int64).Value(0)
|
||||
ts := r.Column(common.TimeStampField).(*array.Int64).Value(0)
|
||||
assert.True(t, pk > lastPK || (pk == lastPK && ts > lastTS))
|
||||
lastPK = pk
|
||||
return nil
|
||||
},
|
||||
|
||||
closefn: func() error {
|
||||
lastPK = int64(-1)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
gotNumRows, err := Sort(batchSize, generateTestSchema(), rr, rw, func(r Record, ri, i int) bool {
|
||||
return true
|
||||
}, sortByFieldIDs)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, batchSize*2, gotNumRows)
|
||||
assert.NoError(t, rw.Close())
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user