diff --git a/internal/flushcommon/syncmgr/pack_writer.go b/internal/flushcommon/syncmgr/pack_writer.go
index 9959838d97..e70d40972b 100644
--- a/internal/flushcommon/syncmgr/pack_writer.go
+++ b/internal/flushcommon/syncmgr/pack_writer.go
@@ -18,8 +18,12 @@ package syncmgr
 
 import (
 	"context"
+	"fmt"
 	"path"
 
+	"github.com/apache/arrow/go/v17/arrow"
+	"github.com/apache/arrow/go/v17/arrow/array"
+	"github.com/apache/arrow/go/v17/arrow/memory"
 	"github.com/samber/lo"
 	"go.uber.org/zap"
 
@@ -306,22 +310,100 @@ func (bw *BulkPackWriter) writeDelta(ctx context.Context, pack *SyncPack) (*data
 	if pack.deltaData == nil {
 		return &datapb.FieldBinlog{}, nil
 	}
-	s, err := NewStorageSerializer(bw.metaCache, bw.schema)
-	if err != nil {
-		return nil, err
+
+	pkField := func() *schemapb.FieldSchema {
+		for _, field := range bw.schema.Fields {
+			if field.IsPrimaryKey {
+				return field
+			}
+		}
+		return nil
+	}()
+	if pkField == nil {
+		return nil, fmt.Errorf("primary key field not found")
 	}
-	deltaBlob, err := s.serializeDeltalog(pack)
+
+	logID := bw.nextID()
+	k := metautil.JoinIDPath(pack.collectionID, pack.partitionID, pack.segmentID, logID)
+	path := path.Join(bw.chunkManager.RootPath(), common.SegmentDeltaLogPath, k)
+	writer, err := storage.NewDeltalogWriter(
+		ctx, pack.collectionID, pack.partitionID, pack.segmentID, logID, pkField.DataType, path,
+		storage.WithUploader(func(ctx context.Context, kvs map[string][]byte) error {
+			// Get the only blob in the map
+			if len(kvs) != 1 {
+				return fmt.Errorf("expected 1 blob, got %d", len(kvs))
+			}
+			for _, blob := range kvs {
+				return bw.chunkManager.Write(ctx, path, blob)
+			}
+			return nil
+		}),
+	)
 	if err != nil {
 		return nil, err
 	}
-	k := metautil.JoinIDPath(pack.collectionID, pack.partitionID, pack.segmentID, bw.nextID())
-	deltalog, err := bw.writeLog(ctx, deltaBlob, common.SegmentDeltaLogPath, k, pack)
+	pkType := func() arrow.DataType {
+		switch pkField.DataType {
+		case schemapb.DataType_Int64:
+			return arrow.PrimitiveTypes.Int64
+		case schemapb.DataType_VarChar:
+			return arrow.BinaryTypes.String
+		default:
+			return nil
+		}
+	}()
+	if pkType == nil {
+		return nil, fmt.Errorf("unexpected pk type %v", pkField.DataType)
+	}
+
+	pkBuilder := array.NewBuilder(memory.DefaultAllocator, pkType)
+	tsBuilder := array.NewBuilder(memory.DefaultAllocator, arrow.PrimitiveTypes.Int64)
+	defer pkBuilder.Release()
+	defer tsBuilder.Release()
+
+	for i := int64(0); i < pack.deltaData.RowCount; i++ {
+		switch pkField.DataType {
+		case schemapb.DataType_Int64:
+			pkBuilder.(*array.Int64Builder).Append(pack.deltaData.Pks[i].GetValue().(int64))
+		case schemapb.DataType_VarChar:
+			pkBuilder.(*array.StringBuilder).Append(pack.deltaData.Pks[i].GetValue().(string))
+		default:
+			return nil, fmt.Errorf("unexpected pk type %v", pkField.DataType)
+		}
+		tsBuilder.(*array.Int64Builder).Append(int64(pack.deltaData.Tss[i]))
+	}
+
+	pkArray := pkBuilder.NewArray()
+	tsArray := tsBuilder.NewArray()
+	record := storage.NewSimpleArrowRecord(array.NewRecord(arrow.NewSchema([]arrow.Field{
+		{Name: "pk", Type: pkType},
+		{Name: "ts", Type: arrow.PrimitiveTypes.Int64},
+	}, nil), []arrow.Array{pkArray, tsArray}, pack.deltaData.RowCount), map[storage.FieldID]int{
+		common.RowIDField:     0,
+		common.TimeStampField: 1,
+	})
+	err = writer.Write(record)
 	if err != nil {
 		return nil, err
 	}
+	err = writer.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	deltalog := &datapb.Binlog{
+		EntriesNum:    pack.deltaData.RowCount,
+		TimestampFrom: pack.tsFrom,
+		TimestampTo:   pack.tsTo,
+		LogPath:       path,
+		LogSize:
pack.deltaData.Size() / 4, // Not used + MemorySize: pack.deltaData.Size(), + } + bw.sizeWritten += deltalog.LogSize + return &datapb.FieldBinlog{ - FieldID: s.pkField.GetFieldID(), + FieldID: pkField.GetFieldID(), Binlogs: []*datapb.Binlog{deltalog}, }, nil } diff --git a/internal/flushcommon/syncmgr/pack_writer_test.go b/internal/flushcommon/syncmgr/pack_writer_test.go index 8eaf364728..1935456e39 100644 --- a/internal/flushcommon/syncmgr/pack_writer_test.go +++ b/internal/flushcommon/syncmgr/pack_writer_test.go @@ -151,14 +151,14 @@ func TestBulkPackWriter_Write(t *testing.T) { { EntriesNum: 10, LogPath: "files/delta_log/123/456/789/10000", - LogSize: 592, - MemorySize: 327, + LogSize: 60, + MemorySize: 240, }, }, }, wantStats: map[int64]*datapb.FieldBinlog{}, wantBm25Stats: map[int64]*datapb.FieldBinlog{}, - wantSize: 592, + wantSize: 60, wantErr: nil, }, } diff --git a/internal/flushcommon/syncmgr/pack_writer_v2.go b/internal/flushcommon/syncmgr/pack_writer_v2.go index fd9eaa4121..fbed49f70f 100644 --- a/internal/flushcommon/syncmgr/pack_writer_v2.go +++ b/internal/flushcommon/syncmgr/pack_writer_v2.go @@ -205,7 +205,7 @@ func (bw *BulkPackWriterV2) writeInserts(ctx context.Context, pack *SyncPack) (m return logs, nil } -func (bw *BulkPackWriterV2) serializeBinlog(ctx context.Context, pack *SyncPack) (storage.Record, error) { +func (bw *BulkPackWriterV2) serializeBinlog(_ context.Context, pack *SyncPack) (storage.Record, error) { if len(pack.insertData) == 0 { return nil, nil } diff --git a/internal/flushcommon/syncmgr/storage_serializer.go b/internal/flushcommon/syncmgr/storage_serializer.go index dec1ab9a6c..8e6790ac8e 100644 --- a/internal/flushcommon/syncmgr/storage_serializer.go +++ b/internal/flushcommon/syncmgr/storage_serializer.go @@ -18,7 +18,6 @@ package syncmgr import ( "context" - "fmt" "strconv" "github.com/samber/lo" @@ -181,31 +180,6 @@ func (s *storageV1Serializer) serializeMergedBM25Stats(pack *SyncPack) (map[int6 return blobs, nil } -func (s *storageV1Serializer) serializeDeltalog(pack *SyncPack) (*storage.Blob, error) { - if len(pack.deltaData.Pks) == 0 { - return &storage.Blob{}, nil - } - - writer, finalizer, err := storage.CreateDeltalogWriter(pack.collectionID, pack.partitionID, pack.segmentID, pack.deltaData.Pks[0].Type(), 1024) - if err != nil { - return nil, err - } - - if len(pack.deltaData.Pks) != len(pack.deltaData.Tss) { - return nil, fmt.Errorf("pk and ts should have same length in delta log, but get %d and %d", len(pack.deltaData.Pks), len(pack.deltaData.Tss)) - } - - for i := 0; i < len(pack.deltaData.Pks); i++ { - deleteLog := storage.NewDeleteLog(pack.deltaData.Pks[i], pack.deltaData.Tss[i]) - err = writer.WriteValue(deleteLog) - if err != nil { - return nil, err - } - } - writer.Close() - return finalizer() -} - func hasBM25Function(schema *schemapb.CollectionSchema) bool { for _, function := range schema.GetFunctions() { if function.GetType() == schemapb.FunctionType_BM25 { diff --git a/internal/flushcommon/syncmgr/storage_serializer_test.go b/internal/flushcommon/syncmgr/storage_serializer_test.go index 4da6bf128c..0c57d64275 100644 --- a/internal/flushcommon/syncmgr/storage_serializer_test.go +++ b/internal/flushcommon/syncmgr/storage_serializer_test.go @@ -241,18 +241,6 @@ func (s *StorageV1SerializerSuite) TestSerializeInsert() { }) } -func (s *StorageV1SerializerSuite) TestSerializeDelete() { - s.Run("serialize_normal", func() { - pack := s.getBasicPack() - pack.WithDeleteData(s.getDeleteBuffer()) - pack.WithTimeRange(50, 100) - - blob, err 
:= s.serializer.serializeDeltalog(pack) - s.NoError(err) - s.NotNil(blob) - }) -} - func (s *StorageV1SerializerSuite) TestBadSchema() { mockCache := metacache.NewMockMetaCache(s.T()) _, err := NewStorageSerializer(mockCache, &schemapb.CollectionSchema{}) diff --git a/internal/storage/rw.go b/internal/storage/rw.go index c80fd85e10..4fb698fad7 100644 --- a/internal/storage/rw.go +++ b/internal/storage/rw.go @@ -70,9 +70,6 @@ type rwOptions struct { } func (o *rwOptions) validate() error { - if o.storageConfig == nil { - return merr.WrapErrServiceInternal("storage config is nil") - } if o.collectionID == 0 { log.Warn("storage config collection id is empty when init BinlogReader") // return merr.WrapErrServiceInternal("storage config collection id is empty") @@ -86,6 +83,9 @@ func (o *rwOptions) validate() error { return merr.WrapErrServiceInternal("downloader is nil for v1 reader") } case StorageV2: + if o.storageConfig == nil { + return merr.WrapErrServiceInternal("storage config is nil") + } default: return merr.WrapErrServiceInternal(fmt.Sprintf("unsupported storage version %d", o.version)) } @@ -266,7 +266,7 @@ func NewBinlogRecordReader(ctx context.Context, binlogs []*datapb.FieldBinlog, s if err != nil { return nil, err } - rr, err = newCompositeBinlogRecordReader(schema, blobsReader, binlogReaderOpts...) + rr = newIterativeCompositeBinlogRecordReader(schema, rwOptions.neededFields, blobsReader, binlogReaderOpts...) case StorageV2: if len(binlogs) <= 0 { return nil, sio.EOF @@ -288,16 +288,14 @@ func NewBinlogRecordReader(ctx context.Context, binlogs []*datapb.FieldBinlog, s paths[j] = append(paths[j], logPath) } } - rr, err = newPackedRecordReader(paths, schema, rwOptions.bufferSize, rwOptions.storageConfig, pluginContext) + // FIXME: add needed fields support + rr = newIterativePackedRecordReader(paths, schema, rwOptions.bufferSize, rwOptions.storageConfig, pluginContext) default: return nil, merr.WrapErrServiceInternal(fmt.Sprintf("unsupported storage version %d", rwOptions.version)) } if err != nil { return nil, err } - if rwOptions.neededFields != nil { - rr.SetNeededFields(rwOptions.neededFields) - } return rr, nil } @@ -361,3 +359,36 @@ func NewBinlogRecordWriter(ctx context.Context, collectionID, partitionID, segme } return nil, merr.WrapErrServiceInternal(fmt.Sprintf("unsupported storage version %d", rwOptions.version)) } + +func NewDeltalogWriter( + ctx context.Context, + collectionID, partitionID, segmentID, logID UniqueID, + pkType schemapb.DataType, + path string, + option ...RwOption, +) (RecordWriter, error) { + rwOptions := DefaultWriterOptions() + for _, opt := range option { + opt(rwOptions) + } + if err := rwOptions.validate(); err != nil { + return nil, err + } + return NewLegacyDeltalogWriter(collectionID, partitionID, segmentID, logID, pkType, rwOptions.uploader, path) +} + +func NewDeltalogReader( + pkField *schemapb.FieldSchema, + paths []string, + option ...RwOption, +) (RecordReader, error) { + rwOptions := DefaultReaderOptions() + for _, opt := range option { + opt(rwOptions) + } + if err := rwOptions.validate(); err != nil { + return nil, err + } + + return NewLegacyDeltalogReader(pkField, rwOptions.downloader, paths) +} diff --git a/internal/storage/serde.go b/internal/storage/serde.go index 445a9567ab..a35c474926 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -46,7 +46,6 @@ type Record interface { type RecordReader interface { Next() (Record, error) - SetNeededFields(fields typeutil.Set[int64]) Close() error } diff --git 
a/internal/storage/serde_delta.go b/internal/storage/serde_delta.go new file mode 100644 index 0000000000..1c88bea1e5 --- /dev/null +++ b/internal/storage/serde_delta.go @@ -0,0 +1,627 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "strconv" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/common" + "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +// newDeltalogOneFieldReader creates a reader for the old single-field deltalog format +func newDeltalogOneFieldReader(blobs []*Blob) (*DeserializeReaderImpl[*DeleteLog], error) { + reader := newIterativeCompositeBinlogRecordReader( + &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + DataType: schemapb.DataType_VarChar, + }, + }, + }, + nil, + MakeBlobsReader(blobs)) + return NewDeserializeReader(reader, func(r Record, v []*DeleteLog) error { + for i := 0; i < r.Len(); i++ { + if v[i] == nil { + v[i] = &DeleteLog{} + } + // retrieve the only field + a := r.(*compositeRecord).recs[0].(*array.String) + strVal := a.Value(i) + if err := v[i].Parse(strVal); err != nil { + return err + } + } + return nil + }), nil +} + +// DeltalogStreamWriter writes deltalog in the old JSON format +type DeltalogStreamWriter struct { + collectionID UniqueID + partitionID UniqueID + segmentID UniqueID + fieldSchema *schemapb.FieldSchema + + buf bytes.Buffer + rw *singleFieldRecordWriter +} + +func (dsw *DeltalogStreamWriter) GetRecordWriter() (RecordWriter, error) { + if dsw.rw != nil { + return dsw.rw, nil + } + rw, err := newSingleFieldRecordWriter(dsw.fieldSchema, &dsw.buf, WithRecordWriterProps(getFieldWriterProps(dsw.fieldSchema))) + if err != nil { + return nil, err + } + dsw.rw = rw + return rw, nil +} + +func (dsw *DeltalogStreamWriter) Finalize() (*Blob, error) { + if dsw.rw == nil { + return nil, io.ErrUnexpectedEOF + } + dsw.rw.Close() + + var b bytes.Buffer + if err := dsw.writeDeltalogHeaders(&b); err != nil { + return nil, err + } + if _, err := b.Write(dsw.buf.Bytes()); err != nil { + return nil, err + } + return &Blob{ + Value: b.Bytes(), + RowNum: int64(dsw.rw.numRows), + MemorySize: int64(dsw.rw.writtenUncompressed), + }, nil +} + +func (dsw *DeltalogStreamWriter) writeDeltalogHeaders(w io.Writer) error { + // Write magic number + if err := binary.Write(w, common.Endian, MagicNumber); err != nil { + return err + } + // Write descriptor + de := 
NewBaseDescriptorEvent(dsw.collectionID, dsw.partitionID, dsw.segmentID) + de.PayloadDataType = dsw.fieldSchema.DataType + de.descriptorEventData.AddExtra(originalSizeKey, strconv.Itoa(int(dsw.rw.writtenUncompressed))) + if err := de.Write(w); err != nil { + return err + } + // Write event header + eh := newEventHeader(DeleteEventType) + // Write event data + ev := newDeleteEventData() + ev.StartTimestamp = 1 + ev.EndTimestamp = 1 + eh.EventLength = int32(dsw.buf.Len()) + eh.GetMemoryUsageInBytes() + int32(binary.Size(ev)) + // eh.NextPosition = eh.EventLength + w.Offset() + if err := eh.Write(w); err != nil { + return err + } + if err := ev.WriteEventData(w); err != nil { + return err + } + return nil +} + +func newDeltalogStreamWriter(collectionID, partitionID, segmentID UniqueID) *DeltalogStreamWriter { + return &DeltalogStreamWriter{ + collectionID: collectionID, + partitionID: partitionID, + segmentID: segmentID, + fieldSchema: &schemapb.FieldSchema{ + FieldID: common.RowIDField, + Name: "delta", + DataType: schemapb.DataType_String, + }, + } +} + +func newDeltalogSerializeWriter(eventWriter *DeltalogStreamWriter, batchSize int) (*SerializeWriterImpl[*DeleteLog], error) { + rws := make(map[FieldID]RecordWriter, 1) + rw, err := eventWriter.GetRecordWriter() + if err != nil { + return nil, err + } + rws[0] = rw + compositeRecordWriter := NewCompositeRecordWriter(rws) + return NewSerializeRecordWriter(compositeRecordWriter, func(v []*DeleteLog) (Record, error) { + builder := array.NewBuilder(memory.DefaultAllocator, arrow.BinaryTypes.String) + + for _, vv := range v { + strVal, err := json.Marshal(vv) + if err != nil { + return nil, err + } + + builder.AppendValueFromString(string(strVal)) + } + arr := []arrow.Array{builder.NewArray()} + field := []arrow.Field{{ + Name: "delta", + Type: arrow.BinaryTypes.String, + Nullable: false, + }} + field2Col := map[FieldID]int{ + 0: 0, + } + return NewSimpleArrowRecord(array.NewRecord(arrow.NewSchema(field, nil), arr, int64(len(v))), field2Col), nil + }, batchSize), nil +} + +var _ RecordReader = (*simpleArrowRecordReader)(nil) + +// simpleArrowRecordReader reads simple arrow records from blobs +type simpleArrowRecordReader struct { + blobs []*Blob + + blobPos int + rr array.RecordReader + closer func() + + r simpleArrowRecord +} + +func (crr *simpleArrowRecordReader) iterateNextBatch() error { + if crr.closer != nil { + crr.closer() + } + + crr.blobPos++ + if crr.blobPos >= len(crr.blobs) { + return io.EOF + } + + reader, err := NewBinlogReader(crr.blobs[crr.blobPos].Value) + if err != nil { + return err + } + + er, err := reader.NextEventReader() + if err != nil { + return err + } + rr, err := er.GetArrowRecordReader() + if err != nil { + return err + } + crr.rr = rr + crr.closer = func() { + crr.rr.Release() + er.Close() + reader.Close() + } + + return nil +} + +func (crr *simpleArrowRecordReader) Next() (Record, error) { + if crr.rr == nil { + if len(crr.blobs) == 0 { + return nil, io.EOF + } + crr.blobPos = -1 + crr.r = simpleArrowRecord{ + field2Col: make(map[FieldID]int), + } + if err := crr.iterateNextBatch(); err != nil { + return nil, err + } + } + + composeRecord := func() bool { + if ok := crr.rr.Next(); !ok { + return false + } + record := crr.rr.Record() + for i := range record.Schema().Fields() { + crr.r.field2Col[FieldID(i)] = i + } + crr.r.r = record + return true + } + + if ok := composeRecord(); !ok { + if err := crr.iterateNextBatch(); err != nil { + return nil, err + } + if ok := composeRecord(); !ok { + return nil, io.EOF + } 
+ } + return &crr.r, nil +} + +func (crr *simpleArrowRecordReader) SetNeededFields(_ typeutil.Set[int64]) { + // no-op for simple arrow record reader +} + +func (crr *simpleArrowRecordReader) Close() error { + if crr.closer != nil { + crr.closer() + } + return nil +} + +func newSimpleArrowRecordReader(blobs []*Blob) (*simpleArrowRecordReader, error) { + return &simpleArrowRecordReader{ + blobs: blobs, + }, nil +} + +// MultiFieldDeltalogStreamWriter writes deltalog in the new multi-field parquet format +type MultiFieldDeltalogStreamWriter struct { + collectionID UniqueID + partitionID UniqueID + segmentID UniqueID + pkType schemapb.DataType + + buf bytes.Buffer + rw *multiFieldRecordWriter +} + +func newMultiFieldDeltalogStreamWriter(collectionID, partitionID, segmentID UniqueID, pkType schemapb.DataType) *MultiFieldDeltalogStreamWriter { + return &MultiFieldDeltalogStreamWriter{ + collectionID: collectionID, + partitionID: partitionID, + segmentID: segmentID, + pkType: pkType, + } +} + +func (dsw *MultiFieldDeltalogStreamWriter) GetRecordWriter() (RecordWriter, error) { + if dsw.rw != nil { + return dsw.rw, nil + } + + fieldIDs := []FieldID{common.RowIDField, common.TimeStampField} // Not used. + fields := []arrow.Field{ + { + Name: "pk", + Type: serdeMap[dsw.pkType].arrowType(0, schemapb.DataType_None), + Nullable: false, + }, + { + Name: "ts", + Type: arrow.PrimitiveTypes.Int64, + Nullable: false, + }, + } + + rw, err := newMultiFieldRecordWriter(fieldIDs, fields, &dsw.buf) + if err != nil { + return nil, err + } + dsw.rw = rw + return rw, nil +} + +func (dsw *MultiFieldDeltalogStreamWriter) Finalize() (*Blob, error) { + if dsw.rw == nil { + return nil, io.ErrUnexpectedEOF + } + dsw.rw.Close() + + var b bytes.Buffer + if err := dsw.writeDeltalogHeaders(&b); err != nil { + return nil, err + } + if _, err := b.Write(dsw.buf.Bytes()); err != nil { + return nil, err + } + return &Blob{ + Value: b.Bytes(), + RowNum: int64(dsw.rw.numRows), + MemorySize: int64(dsw.rw.writtenUncompressed), + }, nil +} + +func (dsw *MultiFieldDeltalogStreamWriter) writeDeltalogHeaders(w io.Writer) error { + // Write magic number + if err := binary.Write(w, common.Endian, MagicNumber); err != nil { + return err + } + // Write descriptor + de := NewBaseDescriptorEvent(dsw.collectionID, dsw.partitionID, dsw.segmentID) + de.PayloadDataType = schemapb.DataType_Int64 + de.descriptorEventData.AddExtra(originalSizeKey, strconv.Itoa(int(dsw.rw.writtenUncompressed))) + de.descriptorEventData.AddExtra(version, MultiField) + if err := de.Write(w); err != nil { + return err + } + // Write event header + eh := newEventHeader(DeleteEventType) + // Write event data + ev := newDeleteEventData() + ev.StartTimestamp = 1 + ev.EndTimestamp = 1 + eh.EventLength = int32(dsw.buf.Len()) + eh.GetMemoryUsageInBytes() + int32(binary.Size(ev)) + // eh.NextPosition = eh.EventLength + w.Offset() + if err := eh.Write(w); err != nil { + return err + } + if err := ev.WriteEventData(w); err != nil { + return err + } + return nil +} + +func newDeltalogMultiFieldWriter(eventWriter *MultiFieldDeltalogStreamWriter, batchSize int) (*SerializeWriterImpl[*DeleteLog], error) { + rw, err := eventWriter.GetRecordWriter() + if err != nil { + return nil, err + } + return NewSerializeRecordWriter[*DeleteLog](rw, func(v []*DeleteLog) (Record, error) { + fields := []arrow.Field{ + { + Name: "pk", + Type: serdeMap[schemapb.DataType(v[0].PkType)].arrowType(0, schemapb.DataType_None), + Nullable: false, + }, + { + Name: "ts", + Type: arrow.PrimitiveTypes.Int64, + 
Nullable: false, + }, + } + arrowSchema := arrow.NewSchema(fields, nil) + builder := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema) + defer builder.Release() + + pkType := schemapb.DataType(v[0].PkType) + switch pkType { + case schemapb.DataType_Int64: + pb := builder.Field(0).(*array.Int64Builder) + for _, vv := range v { + pk := vv.Pk.GetValue().(int64) + pb.Append(pk) + } + case schemapb.DataType_VarChar: + pb := builder.Field(0).(*array.StringBuilder) + for _, vv := range v { + pk := vv.Pk.GetValue().(string) + pb.Append(pk) + } + default: + return nil, fmt.Errorf("unexpected pk type %v", v[0].PkType) + } + + for _, vv := range v { + builder.Field(1).(*array.Int64Builder).Append(int64(vv.Ts)) + } + + arr := []arrow.Array{builder.Field(0).NewArray(), builder.Field(1).NewArray()} + + field2Col := map[FieldID]int{ + common.RowIDField: 0, + common.TimeStampField: 1, + } + return NewSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), field2Col), nil + }, batchSize), nil +} + +func newDeltalogMultiFieldReader(blobs []*Blob) (*DeserializeReaderImpl[*DeleteLog], error) { + reader, err := newSimpleArrowRecordReader(blobs) + if err != nil { + return nil, err + } + return NewDeserializeReader(reader, func(r Record, v []*DeleteLog) error { + rec, ok := r.(*simpleArrowRecord) + if !ok { + return errors.New("can not cast to simple arrow record") + } + fields := rec.r.Schema().Fields() + switch fields[0].Type.ID() { + case arrow.INT64: + arr := r.Column(0).(*array.Int64) + for j := 0; j < r.Len(); j++ { + if v[j] == nil { + v[j] = &DeleteLog{} + } + v[j].Pk = NewInt64PrimaryKey(arr.Value(j)) + } + case arrow.STRING: + arr := r.Column(0).(*array.String) + for j := 0; j < r.Len(); j++ { + if v[j] == nil { + v[j] = &DeleteLog{} + } + v[j].Pk = NewVarCharPrimaryKey(arr.Value(j)) + } + default: + return fmt.Errorf("unexpected delta log pkType %v", fields[0].Type.Name()) + } + + arr := r.Column(1).(*array.Int64) + for j := 0; j < r.Len(); j++ { + v[j].Ts = uint64(arr.Value(j)) + } + return nil + }), nil +} + +// newDeltalogDeserializeReader is the entry point for the delta log reader. +// It includes newDeltalogOneFieldReader, which uses the existing log format with only one column in a log file, +// and newDeltalogMultiFieldReader, which uses the new format and supports multiple fields in a log file. 
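// --- Editorial usage sketch (not part of this patch) ---
// Why: the writer/reader pair above can be hard to see end-to-end. This is a minimal DeleteLog
// round trip through the multi-field (parquet) deltalog path, assuming it sits in the storage
// package next to this code; the IDs and batch size below are placeholder values.
func exampleDeltalogRoundTrip() error {
	eventWriter := newMultiFieldDeltalogStreamWriter(1, 2, 3, schemapb.DataType_Int64)
	writer, err := newDeltalogMultiFieldWriter(eventWriter, 1024)
	if err != nil {
		return err
	}
	for i := int64(0); i < 10; i++ {
		if err := writer.WriteValue(NewDeleteLog(NewInt64PrimaryKey(i), uint64(100+i))); err != nil {
			return err
		}
	}
	if err := writer.Close(); err != nil {
		return err
	}
	blob, err := eventWriter.Finalize() // headers plus parquet payload as a single Blob
	if err != nil {
		return err
	}
	// newDeltalogDeserializeReader (defined right below) inspects the descriptor extras and
	// dispatches to the multi-field reader for this blob.
	reader, err := newDeltalogDeserializeReader([]*Blob{blob})
	if err != nil {
		return err
	}
	defer reader.Close()
	for {
		log, err := reader.NextValue()
		if err != nil {
			break // io.EOF ends the stream
		}
		_, _ = (*log).Pk, (*log).Ts // consume the restored delete entry
	}
	return nil
}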
+func newDeltalogDeserializeReader(blobs []*Blob) (*DeserializeReaderImpl[*DeleteLog], error) { + if supportMultiFieldFormat(blobs) { + return newDeltalogMultiFieldReader(blobs) + } + return newDeltalogOneFieldReader(blobs) +} + +// supportMultiFieldFormat checks delta log description data to see if it is the format with +// pk and ts column separately +func supportMultiFieldFormat(blobs []*Blob) bool { + if len(blobs) > 0 { + reader, err := NewBinlogReader(blobs[0].Value) + if err != nil { + return false + } + defer reader.Close() + version := reader.descriptorEventData.Extras[version] + return version != nil && version.(string) == MultiField + } + return false +} + +// CreateDeltalogReader creates a deltalog reader based on the format version +func CreateDeltalogReader(blobs []*Blob) (*DeserializeReaderImpl[*DeleteLog], error) { + return newDeltalogDeserializeReader(blobs) +} + +// createDeltalogWriter creates a deltalog writer based on the configured format +func createDeltalogWriter(collectionID, partitionID, segmentID UniqueID, pkType schemapb.DataType, batchSize int, +) (*SerializeWriterImpl[*DeleteLog], func() (*Blob, error), error) { + format := paramtable.Get().DataNodeCfg.DeltalogFormat.GetValue() + switch format { + case "json": + eventWriter := newDeltalogStreamWriter(collectionID, partitionID, segmentID) + writer, err := newDeltalogSerializeWriter(eventWriter, batchSize) + return writer, eventWriter.Finalize, err + case "parquet": + eventWriter := newMultiFieldDeltalogStreamWriter(collectionID, partitionID, segmentID, pkType) + writer, err := newDeltalogMultiFieldWriter(eventWriter, batchSize) + return writer, eventWriter.Finalize, err + default: + return nil, nil, merr.WrapErrParameterInvalid("unsupported deltalog format %s", format) + } +} + +type LegacyDeltalogWriter struct { + path string + pkType schemapb.DataType + writer *SerializeWriterImpl[*DeleteLog] + finalizer func() (*Blob, error) + writtenUncompressed uint64 + + uploader uploaderFn +} + +var _ RecordWriter = (*LegacyDeltalogWriter)(nil) + +func NewLegacyDeltalogWriter( + collectionID, partitionID, segmentID, logID UniqueID, pkType schemapb.DataType, uploader uploaderFn, path string, +) (*LegacyDeltalogWriter, error) { + writer, finalizer, err := createDeltalogWriter(collectionID, partitionID, segmentID, pkType, 4096) + if err != nil { + return nil, err + } + + return &LegacyDeltalogWriter{ + path: path, + pkType: pkType, + writer: writer, + finalizer: finalizer, + uploader: uploader, + }, nil +} + +func (w *LegacyDeltalogWriter) Write(rec Record) error { + newDeleteLog := func(i int) (*DeleteLog, error) { + ts := Timestamp(rec.Column(1).(*array.Int64).Value(i)) + switch w.pkType { + case schemapb.DataType_Int64: + pk := NewInt64PrimaryKey(rec.Column(0).(*array.Int64).Value(i)) + return NewDeleteLog(pk, ts), nil + case schemapb.DataType_VarChar: + pk := NewVarCharPrimaryKey(rec.Column(0).(*array.String).Value(i)) + return NewDeleteLog(pk, ts), nil + default: + return nil, fmt.Errorf("unexpected pk type %v", w.pkType) + } + } + + for i := range rec.Len() { + deleteLog, err := newDeleteLog(i) + if err != nil { + return err + } + err = w.writer.WriteValue(deleteLog) + if err != nil { + return err + } + } + w.writtenUncompressed += (rec.Column(0).Data().SizeInBytes() + rec.Column(1).Data().SizeInBytes()) + return nil +} + +func (w *LegacyDeltalogWriter) Close() error { + err := w.writer.Close() + if err != nil { + return err + } + blob, err := w.finalizer() + if err != nil { + return err + } + + return 
w.uploader(context.Background(), map[string][]byte{blob.Key: blob.Value}) +} + +func (w *LegacyDeltalogWriter) GetWrittenUncompressed() uint64 { + return w.writtenUncompressed +} + +func NewLegacyDeltalogReader(pkField *schemapb.FieldSchema, downloader downloaderFn, paths []string) (RecordReader, error) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + pkField, + { + FieldID: common.TimeStampField, + DataType: schemapb.DataType_Int64, + }, + }, + } + + chunkPos := 0 + blobsReader := func() ([]*Blob, error) { + path := paths[chunkPos] + chunkPos++ + blobs, err := downloader(context.Background(), []string{path}) + if err != nil { + return nil, err + } + return []*Blob{{Key: path, Value: blobs[0]}}, nil + } + + return newIterativeCompositeBinlogRecordReader( + schema, + nil, + blobsReader, + nil, + ), nil +} diff --git a/internal/storage/serde_delta_test.go b/internal/storage/serde_delta_test.go new file mode 100644 index 0000000000..1498532039 --- /dev/null +++ b/internal/storage/serde_delta_test.go @@ -0,0 +1,155 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
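// Editorial note (not part of this patch): the tests below toggle the deltalog format knob that
// createDeltalogWriter reads. A caller would select the format the same way, for example:
//
//	params := paramtable.Get()
//	params.Save(params.DataNodeCfg.DeltalogFormat.Key, "parquet") // or "json" for the legacy layout
//
// "json" keeps the single-column DeleteLog encoding; "parquet" produces the multi-field pk/ts layout.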
+ +package storage + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" +) + +func TestDeltalogReaderWriter(t *testing.T) { + const ( + testCollectionID = int64(1) + testPartitionID = int64(2) + testSegmentID = int64(3) + testBatchSize = 1024 + testNumLogs = 100 + ) + + type deleteLogGenerator func(i int) *DeleteLog + + tests := []struct { + name string + format string + pkType schemapb.DataType + logGenerator deleteLogGenerator + wantErr bool + }{ + { + name: "Int64 PK - JSON format", + format: "json", + pkType: schemapb.DataType_Int64, + logGenerator: func(i int) *DeleteLog { + return NewDeleteLog(NewInt64PrimaryKey(int64(i)), uint64(100+i)) + }, + wantErr: false, + }, + { + name: "VarChar PK - JSON format", + format: "json", + pkType: schemapb.DataType_VarChar, + logGenerator: func(i int) *DeleteLog { + return NewDeleteLog(NewVarCharPrimaryKey("key_"+string(rune(i))), uint64(100+i)) + }, + wantErr: false, + }, + { + name: "Int64 PK - Parquet format", + format: "parquet", + pkType: schemapb.DataType_Int64, + logGenerator: func(i int) *DeleteLog { + return NewDeleteLog(NewInt64PrimaryKey(int64(i)), uint64(100+i)) + }, + wantErr: false, + }, + { + name: "VarChar PK - Parquet format", + format: "parquet", + pkType: schemapb.DataType_VarChar, + logGenerator: func(i int) *DeleteLog { + return NewDeleteLog(NewVarCharPrimaryKey("key_"+string(rune(i))), uint64(100+i)) + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set deltalog format + originalFormat := paramtable.Get().DataNodeCfg.DeltalogFormat.GetValue() + paramtable.Get().Save(paramtable.Get().DataNodeCfg.DeltalogFormat.Key, tt.format) + defer paramtable.Get().Save(paramtable.Get().DataNodeCfg.DeltalogFormat.Key, originalFormat) + + writer, finalizer, err := createDeltalogWriter(testCollectionID, testPartitionID, testSegmentID, tt.pkType, testBatchSize) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.NotNil(t, writer) + assert.NotNil(t, finalizer) + + // Write delete logs + expectedLogs := make([]*DeleteLog, 0, testNumLogs) + for i := 0; i < testNumLogs; i++ { + deleteLog := tt.logGenerator(i) + expectedLogs = append(expectedLogs, deleteLog) + err = writer.WriteValue(deleteLog) + require.NoError(t, err) + } + + err = writer.Close() + require.NoError(t, err) + + blob, err := finalizer() + require.NoError(t, err) + assert.NotNil(t, blob) + assert.Greater(t, len(blob.Value), 0) + + // Test round trip + reader, err := CreateDeltalogReader([]*Blob{blob}) + require.NoError(t, err) + require.NotNil(t, reader) + + // Read and verify contents + readLogs := make([]*DeleteLog, 0) + for { + log, err := reader.NextValue() + if err != nil { + break + } + if log != nil { + readLogs = append(readLogs, *log) + } + } + + assert.Equal(t, len(expectedLogs), len(readLogs)) + for i := 0; i < len(expectedLogs); i++ { + assert.Equal(t, expectedLogs[i].Ts, readLogs[i].Ts) + assert.Equal(t, expectedLogs[i].Pk.GetValue(), readLogs[i].Pk.GetValue()) + } + + err = reader.Close() + assert.NoError(t, err) + }) + } +} + +func TestDeltalogStreamWriter_NoRecordWriter(t *testing.T) { + writer := newDeltalogStreamWriter(1, 2, 3) + assert.NotNil(t, writer) + + // Finalize without getting record writer should return error + blob, err := writer.Finalize() + assert.Error(t, err) + assert.Nil(t, blob) +} diff --git 
a/internal/storage/serde_events.go b/internal/storage/serde_events.go index 511fa639cf..2dc2adb931 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -28,172 +28,104 @@ import ( "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" "github.com/apache/arrow/go/v17/arrow/memory" - "github.com/cockroachdb/errors" "github.com/samber/lo" "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/hook" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" - "github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" - "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/metautil" - "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) -var _ RecordReader = (*CompositeBinlogRecordReader)(nil) +type IterativeRecordReader struct { + cur RecordReader + iterate func() (RecordReader, error) +} + +// Close implements RecordReader. +func (ir *IterativeRecordReader) Close() error { + if ir.cur != nil { + return ir.cur.Close() + } + return nil +} + +var _ RecordReader = (*IterativeRecordReader)(nil) + +func (ir *IterativeRecordReader) Next() (Record, error) { + if ir.cur == nil { + r, err := ir.iterate() + if err != nil { + return nil, err + } + ir.cur = r + } + rec, err := ir.cur.Next() + if err == io.EOF { + closeErr := ir.cur.Close() + if closeErr != nil { + return nil, closeErr + } + ir.cur, err = ir.iterate() + if err != nil { + return nil, err + } + rec, err = ir.cur.Next() + } + return rec, err +} // ChunkedBlobsReader returns a chunk composed of blobs, or io.EOF if no more data type ChunkedBlobsReader func() ([]*Blob, error) type CompositeBinlogRecordReader struct { - BlobsReader ChunkedBlobsReader - schema *schemapb.CollectionSchema - index map[FieldID]int16 - + fields map[FieldID]*schemapb.FieldSchema + index map[FieldID]int16 brs []*BinlogReader - bropts []BinlogReaderOption rrs []array.RecordReader } -func (crr *CompositeBinlogRecordReader) iterateNextBatch() error { - if crr.brs != nil { - for _, er := range crr.brs { - if er != nil { - er.Close() - } - } - } - if crr.rrs != nil { - for _, rr := range crr.rrs { - if rr != nil { - rr.Release() - } - } - } - - blobs, err := crr.BlobsReader() - if err != nil { - return err - } - - fieldNum := len(crr.schema.Fields) - for _, f := range crr.schema.StructArrayFields { - fieldNum += len(f.Fields) - } - - crr.rrs = make([]array.RecordReader, fieldNum) - crr.brs = make([]*BinlogReader, fieldNum) - - for _, b := range blobs { - reader, err := NewBinlogReader(b.Value, crr.bropts...) 
- if err != nil { - return err - } - - er, err := reader.NextEventReader() - if err != nil { - return err - } - i := crr.index[reader.FieldID] - rr, err := er.GetArrowRecordReader() - if err != nil { - return err - } - crr.rrs[i] = rr - crr.brs[i] = reader - } - return nil -} +var _ RecordReader = (*CompositeBinlogRecordReader)(nil) func (crr *CompositeBinlogRecordReader) Next() (Record, error) { - if crr.rrs == nil { - if err := crr.iterateNextBatch(); err != nil { + recs := make([]arrow.Array, len(crr.fields)) + nonExistingFields := make([]*schemapb.FieldSchema, 0) + nRows := 0 + for _, f := range crr.fields { + idx := crr.index[f.FieldID] + if crr.rrs[idx] != nil { + if ok := crr.rrs[idx].Next(); !ok { + return nil, io.EOF + } + r := crr.rrs[idx].Record() + recs[idx] = r.Column(0) + if nRows == 0 { + nRows = int(r.NumRows()) + } + if nRows != int(r.NumRows()) { + return nil, merr.WrapErrServiceInternal(fmt.Sprintf("number of rows mismatch for field %d", f.FieldID)) + } + } else { + nonExistingFields = append(nonExistingFields, f) + } + } + for _, f := range nonExistingFields { + // If the field is not in the current batch, fill with null array + arr, err := GenerateEmptyArrayFromSchema(f, nRows) + if err != nil { return nil, err } + recs[crr.index[f.FieldID]] = arr } - - composeRecord := func() (Record, error) { - fieldNum := len(crr.schema.Fields) - for _, f := range crr.schema.StructArrayFields { - fieldNum += len(f.Fields) - } - recs := make([]arrow.Array, fieldNum) - - appendFieldRecord := func(f *schemapb.FieldSchema) error { - idx := crr.index[f.FieldID] - if crr.rrs[idx] != nil { - if ok := crr.rrs[idx].Next(); !ok { - return io.EOF - } - recs[idx] = crr.rrs[idx].Record().Column(0) - } else { - // If the field is not in the current batch, fill with null array - // Note that we're intentionally not filling default value here, because the - // deserializer will fill them later. 
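// --- Editorial usage sketch (not part of this patch) ---
// Why: the per-batch bookkeeping removed here is replaced by IterativeRecordReader. This sketch
// shows how a caller consumes it: one CompositeBinlogRecordReader is opened per blob chunk, and
// Next() closes an exhausted chunk and opens the next one until the ChunkedBlobsReader reports
// io.EOF. The schema and blobs arguments are assumed to be supplied by the caller.
func exampleIterativeRead(schema *schemapb.CollectionSchema, blobs []*Blob) error {
	rr := newIterativeCompositeBinlogRecordReader(schema, nil /* nil = read all fields */, MakeBlobsReader(blobs))
	defer rr.Close()
	for {
		rec, err := rr.Next()
		if err == io.EOF {
			return nil // every chunk has been consumed
		}
		if err != nil {
			return err
		}
		_ = rec.Len() // process one batch of rows
	}
}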
- numRows := int(crr.rrs[0].Record().NumRows()) - arr, err := GenerateEmptyArrayFromSchema(f, numRows) - if err != nil { - return err - } - recs[idx] = arr - } - return nil - } - - for _, f := range crr.schema.Fields { - if err := appendFieldRecord(f); err != nil { - return nil, err - } - } - for _, f := range crr.schema.StructArrayFields { - for _, sf := range f.Fields { - if err := appendFieldRecord(sf); err != nil { - return nil, err - } - } - } - return &compositeRecord{ - index: crr.index, - recs: recs, - }, nil - } - - // Try compose records - r, err := composeRecord() - if err == io.EOF { - // if EOF, try iterate next batch (blob) - if err := crr.iterateNextBatch(); err != nil { - return nil, err - } - r, err = composeRecord() // try compose again - } - if err != nil { - return nil, err - } - return r, nil -} - -func (crr *CompositeBinlogRecordReader) SetNeededFields(neededFields typeutil.Set[int64]) { - if neededFields == nil { - return - } - - crr.schema = &schemapb.CollectionSchema{ - Fields: lo.Filter(crr.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool { - return neededFields.Contain(field.GetFieldID()) - }), - } - index := make(map[FieldID]int16) - for i, f := range crr.schema.Fields { - index[f.FieldID] = int16(i) - } - crr.index = index + return &compositeRecord{ + index: crr.index, + recs: recs, + }, nil } func (crr *CompositeBinlogRecordReader) Close() error { @@ -256,28 +188,74 @@ func MakeBlobsReader(blobs []*Blob) ChunkedBlobsReader { } } -func newCompositeBinlogRecordReader(schema *schemapb.CollectionSchema, blobsReader ChunkedBlobsReader, opts ...BinlogReaderOption) (*CompositeBinlogRecordReader, error) { +func newCompositeBinlogRecordReader( + schema *schemapb.CollectionSchema, + neededFields typeutil.Set[int64], + blobs []*Blob, + opts ...BinlogReaderOption, +) (*CompositeBinlogRecordReader, error) { + allFields := typeutil.GetAllFieldSchemas(schema) + if neededFields != nil { + allFields = lo.Filter(allFields, func(field *schemapb.FieldSchema, _ int) bool { + return neededFields.Contain(field.GetFieldID()) + }) + } + idx := 0 index := make(map[FieldID]int16) - for _, f := range schema.Fields { + fields := make(map[FieldID]*schemapb.FieldSchema) + for _, f := range allFields { index[f.FieldID] = int16(idx) + fields[f.FieldID] = f idx++ } - for _, f := range schema.StructArrayFields { - for _, sf := range f.Fields { - index[sf.FieldID] = int16(idx) - idx++ + + rrs := make([]array.RecordReader, len(allFields)) + brs := make([]*BinlogReader, len(allFields)) + for _, b := range blobs { + reader, err := NewBinlogReader(b.Value, opts...) + if err != nil { + return nil, err } + + er, err := reader.NextEventReader() + if err != nil { + return nil, err + } + i := index[reader.FieldID] + rr, err := er.GetArrowRecordReader() + if err != nil { + return nil, err + } + rrs[i] = rr + brs[i] = reader } return &CompositeBinlogRecordReader{ - schema: schema, - BlobsReader: blobsReader, - index: index, - bropts: opts, + fields: fields, + index: index, + rrs: rrs, + brs: brs, }, nil } +func newIterativeCompositeBinlogRecordReader( + schema *schemapb.CollectionSchema, + neededFields typeutil.Set[int64], + chunkedBlobs ChunkedBlobsReader, + opts ...BinlogReaderOption, +) *IterativeRecordReader { + return &IterativeRecordReader{ + iterate: func() (RecordReader, error) { + blobs, err := chunkedBlobs() + if err != nil { + return nil, err + } + return newCompositeBinlogRecordReader(schema, neededFields, blobs, opts...) 
+ }, + } +} + func ValueDeserializerWithSelectedFields(r Record, v []*Value, fieldSchema []*schemapb.FieldSchema, shouldCopy bool) error { return valueDeserializer(r, v, fieldSchema, shouldCopy) } @@ -358,46 +336,14 @@ func valueDeserializer(r Record, v []*Value, fields []*schemapb.FieldSchema, sho return nil } -func NewBinlogDeserializeReader(schema *schemapb.CollectionSchema, blobsReader ChunkedBlobsReader, shouldCopy bool) (*DeserializeReaderImpl[*Value], error) { - reader, err := newCompositeBinlogRecordReader(schema, blobsReader) - if err != nil { - return nil, err - } - +func NewBinlogDeserializeReader(schema *schemapb.CollectionSchema, blobsReader ChunkedBlobsReader, shouldCopy bool, +) (*DeserializeReaderImpl[*Value], error) { + reader := newIterativeCompositeBinlogRecordReader(schema, nil, blobsReader) return NewDeserializeReader(reader, func(r Record, v []*Value) error { return ValueDeserializerWithSchema(r, v, schema, shouldCopy) }), nil } -func newDeltalogOneFieldReader(blobs []*Blob) (*DeserializeReaderImpl[*DeleteLog], error) { - reader, err := newCompositeBinlogRecordReader( - &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - { - DataType: schemapb.DataType_VarChar, - }, - }, - }, - MakeBlobsReader(blobs)) - if err != nil { - return nil, err - } - return NewDeserializeReader(reader, func(r Record, v []*DeleteLog) error { - for i := 0; i < r.Len(); i++ { - if v[i] == nil { - v[i] = &DeleteLog{} - } - // retrieve the only field - a := r.(*compositeRecord).recs[0].(*array.String) - strVal := a.Value(i) - if err := v[i].Parse(strVal); err != nil { - return err - } - } - return nil - }), nil -} - type HeaderExtraWriterOption func(header *descriptorEvent) func WithEncryptionKey(ezID int64, edek []byte) HeaderExtraWriterOption { @@ -663,15 +609,15 @@ type CompositeBinlogRecordWriter struct { chunkSize uint64 rootPath string maxRowNum int64 - pkstats *PrimaryKeyStats - bm25Stats map[int64]*BM25Stats // writers and stats generated at runtime - fieldWriters map[FieldID]*BinlogStreamWriter - rw RecordWriter - tsFrom typeutil.Timestamp - tsTo typeutil.Timestamp - rowNum int64 + fieldWriters map[FieldID]*BinlogStreamWriter + rw RecordWriter + pkCollector *PkStatsCollector + bm25Collector *Bm25StatsCollector + tsFrom typeutil.Timestamp + tsTo typeutil.Timestamp + rowNum int64 // results fieldBinlogs map[FieldID]*datapb.FieldBinlog @@ -689,6 +635,7 @@ func (c *CompositeBinlogRecordWriter) Write(r Record) error { return err } + // Track timestamps tsArray := r.Column(common.TimeStampField).(*array.Int64) rows := r.Len() for i := 0; i < rows; i++ { @@ -699,37 +646,20 @@ func (c *CompositeBinlogRecordWriter) Write(r Record) error { if ts > c.tsTo { c.tsTo = ts } + } - switch schemapb.DataType(c.pkstats.PkType) { - case schemapb.DataType_Int64: - pkArray := r.Column(c.pkstats.FieldID).(*array.Int64) - pk := &Int64PrimaryKey{ - Value: pkArray.Value(i), - } - c.pkstats.Update(pk) - case schemapb.DataType_VarChar: - pkArray := r.Column(c.pkstats.FieldID).(*array.String) - pk := &VarCharPrimaryKey{ - Value: pkArray.Value(i), - } - c.pkstats.Update(pk) - default: - panic("invalid data type") - } - - for fieldID, stats := range c.bm25Stats { - field, ok := r.Column(fieldID).(*array.Binary) - if !ok { - return errors.New("bm25 field value not found") - } - stats.AppendBytes(field.Value(i)) - } + // Collect statistics + if err := c.pkCollector.Collect(r); err != nil { + return err + } + if err := c.bm25Collector.Collect(r); err != nil { + return err } if err := c.rw.Write(r); err != nil 
{ return err } - c.rowNum += int64(rows) + c.rowNum += int64(r.Len()) // flush if size exceeds chunk size if c.rw.GetWrittenUncompressed() >= c.chunkSize { @@ -763,18 +693,15 @@ func (c *CompositeBinlogRecordWriter) resetWriters() { } func (c *CompositeBinlogRecordWriter) Close() error { - if err := c.writeStats(); err != nil { - return err - } - if err := c.writeBm25Stats(); err != nil { - return err - } if c.rw != nil { // if rw is not nil, it means there is data to be flushed if err := c.FlushChunk(); err != nil { return err } } + if err := c.writeStats(); err != nil { + return err + } return nil } @@ -846,89 +773,39 @@ func (c *CompositeBinlogRecordWriter) Schema() *schemapb.CollectionSchema { } func (c *CompositeBinlogRecordWriter) writeStats() error { - if c.pkstats == nil { - return nil - } - - id, err := c.allocator.AllocOne() + // Write PK stats + pkStatsMap, err := c.pkCollector.Digest( + c.collectionID, + c.partitionID, + c.segmentID, + c.rootPath, + c.rowNum, + c.allocator, + c.BlobsWriter, + ) if err != nil { return err } + // Extract single PK stats from map + for _, statsLog := range pkStatsMap { + c.statsLog = statsLog + break + } - codec := NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ - ID: c.collectionID, - Schema: c.schema, - }) - sblob, err := codec.SerializePkStats(c.pkstats, c.rowNum) + // Write BM25 stats + bm25StatsLog, err := c.bm25Collector.Digest( + c.collectionID, + c.partitionID, + c.segmentID, + c.rootPath, + c.rowNum, + c.allocator, + c.BlobsWriter, + ) if err != nil { return err } - - sblob.Key = metautil.BuildStatsLogPath(c.rootPath, - c.collectionID, c.partitionID, c.segmentID, c.pkstats.FieldID, id) - - if err := c.BlobsWriter([]*Blob{sblob}); err != nil { - return err - } - - c.statsLog = &datapb.FieldBinlog{ - FieldID: c.pkstats.FieldID, - Binlogs: []*datapb.Binlog{ - { - LogSize: int64(len(sblob.GetValue())), - MemorySize: int64(len(sblob.GetValue())), - LogPath: sblob.Key, - EntriesNum: c.rowNum, - }, - }, - } - return nil -} - -func (c *CompositeBinlogRecordWriter) writeBm25Stats() error { - if len(c.bm25Stats) == 0 { - return nil - } - id, _, err := c.allocator.Alloc(uint32(len(c.bm25Stats))) - if err != nil { - return err - } - - if c.bm25StatsLog == nil { - c.bm25StatsLog = make(map[FieldID]*datapb.FieldBinlog) - } - for fid, stats := range c.bm25Stats { - bytes, err := stats.Serialize() - if err != nil { - return err - } - key := metautil.BuildBm25LogPath(c.rootPath, - c.collectionID, c.partitionID, c.segmentID, fid, id) - blob := &Blob{ - Key: key, - Value: bytes, - RowNum: stats.NumRow(), - MemorySize: int64(len(bytes)), - } - if err := c.BlobsWriter([]*Blob{blob}); err != nil { - return err - } - - fieldLog := &datapb.FieldBinlog{ - FieldID: fid, - Binlogs: []*datapb.Binlog{ - { - LogSize: int64(len(blob.GetValue())), - MemorySize: int64(len(blob.GetValue())), - LogPath: key, - EntriesNum: c.rowNum, - }, - }, - } - - c.bm25StatsLog[fid] = fieldLog - id++ - } + c.bm25StatsLog = bm25StatsLog return nil } @@ -949,26 +826,7 @@ func newCompositeBinlogRecordWriter(collectionID, partitionID, segmentID UniqueI blobsWriter ChunkedBlobsWriter, allocator allocator.Interface, chunkSize uint64, rootPath string, maxRowNum int64, options ...StreamWriterOption, ) (*CompositeBinlogRecordWriter, error) { - pkField, err := typeutil.GetPrimaryFieldSchema(schema) - if err != nil { - return nil, err - } - stats, err := NewPrimaryKeyStats(pkField.GetFieldID(), int64(pkField.GetDataType()), maxRowNum) - if err != nil { - return nil, err - } - bm25FieldIDs := 
lo.FilterMap(schema.GetFunctions(), func(function *schemapb.FunctionSchema, _ int) (int64, bool) { - if function.GetType() == schemapb.FunctionType_BM25 { - return function.GetOutputFieldIds()[0], true - } - return 0, false - }) - bm25Stats := make(map[int64]*BM25Stats, len(bm25FieldIDs)) - for _, fid := range bm25FieldIDs { - bm25Stats[fid] = NewBM25Stats() - } - - return &CompositeBinlogRecordWriter{ + writer := &CompositeBinlogRecordWriter{ collectionID: collectionID, partitionID: partitionID, segmentID: segmentID, @@ -978,10 +836,25 @@ func newCompositeBinlogRecordWriter(collectionID, partitionID, segmentID UniqueI chunkSize: chunkSize, rootPath: rootPath, maxRowNum: maxRowNum, - pkstats: stats, - bm25Stats: bm25Stats, options: options, - }, nil + tsFrom: math.MaxUint64, + tsTo: 0, + } + + // Create stats collectors + var err error + writer.pkCollector, err = NewPkStatsCollector( + collectionID, + schema, + maxRowNum, + ) + if err != nil { + return nil, err + } + + writer.bm25Collector = NewBm25StatsCollector(schema) + + return writer, nil } // BinlogValueWriter is a BinlogRecordWriter with SerializeWriter[*Value] mixin. @@ -1034,452 +907,3 @@ func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, se }, batchSize), }, nil } - -type DeltalogStreamWriter struct { - collectionID UniqueID - partitionID UniqueID - segmentID UniqueID - fieldSchema *schemapb.FieldSchema - - buf bytes.Buffer - rw *singleFieldRecordWriter -} - -func (dsw *DeltalogStreamWriter) GetRecordWriter() (RecordWriter, error) { - if dsw.rw != nil { - return dsw.rw, nil - } - rw, err := newSingleFieldRecordWriter(dsw.fieldSchema, &dsw.buf, WithRecordWriterProps(getFieldWriterProps(dsw.fieldSchema))) - if err != nil { - return nil, err - } - dsw.rw = rw - return rw, nil -} - -func (dsw *DeltalogStreamWriter) Finalize() (*Blob, error) { - if dsw.rw == nil { - return nil, io.ErrUnexpectedEOF - } - dsw.rw.Close() - - var b bytes.Buffer - if err := dsw.writeDeltalogHeaders(&b); err != nil { - return nil, err - } - if _, err := b.Write(dsw.buf.Bytes()); err != nil { - return nil, err - } - return &Blob{ - Value: b.Bytes(), - RowNum: int64(dsw.rw.numRows), - MemorySize: int64(dsw.rw.writtenUncompressed), - }, nil -} - -func (dsw *DeltalogStreamWriter) writeDeltalogHeaders(w io.Writer) error { - // Write magic number - if err := binary.Write(w, common.Endian, MagicNumber); err != nil { - return err - } - // Write descriptor - de := NewBaseDescriptorEvent(dsw.collectionID, dsw.partitionID, dsw.segmentID) - de.PayloadDataType = dsw.fieldSchema.DataType - de.descriptorEventData.AddExtra(originalSizeKey, strconv.Itoa(int(dsw.rw.writtenUncompressed))) - if err := de.Write(w); err != nil { - return err - } - // Write event header - eh := newEventHeader(DeleteEventType) - // Write event data - ev := newDeleteEventData() - ev.StartTimestamp = 1 - ev.EndTimestamp = 1 - eh.EventLength = int32(dsw.buf.Len()) + eh.GetMemoryUsageInBytes() + int32(binary.Size(ev)) - // eh.NextPosition = eh.EventLength + w.Offset() - if err := eh.Write(w); err != nil { - return err - } - if err := ev.WriteEventData(w); err != nil { - return err - } - return nil -} - -func newDeltalogStreamWriter(collectionID, partitionID, segmentID UniqueID) *DeltalogStreamWriter { - return &DeltalogStreamWriter{ - collectionID: collectionID, - partitionID: partitionID, - segmentID: segmentID, - fieldSchema: &schemapb.FieldSchema{ - FieldID: common.RowIDField, - Name: "delta", - DataType: schemapb.DataType_String, - }, - } -} - -func 
newDeltalogSerializeWriter(eventWriter *DeltalogStreamWriter, batchSize int) (*SerializeWriterImpl[*DeleteLog], error) { - rws := make(map[FieldID]RecordWriter, 1) - rw, err := eventWriter.GetRecordWriter() - if err != nil { - return nil, err - } - rws[0] = rw - compositeRecordWriter := NewCompositeRecordWriter(rws) - return NewSerializeRecordWriter(compositeRecordWriter, func(v []*DeleteLog) (Record, error) { - builder := array.NewBuilder(memory.DefaultAllocator, arrow.BinaryTypes.String) - - for _, vv := range v { - strVal, err := json.Marshal(vv) - if err != nil { - return nil, err - } - - builder.AppendValueFromString(string(strVal)) - } - arr := []arrow.Array{builder.NewArray()} - field := []arrow.Field{{ - Name: "delta", - Type: arrow.BinaryTypes.String, - Nullable: false, - }} - field2Col := map[FieldID]int{ - 0: 0, - } - return NewSimpleArrowRecord(array.NewRecord(arrow.NewSchema(field, nil), arr, int64(len(v))), field2Col), nil - }, batchSize), nil -} - -var _ RecordReader = (*simpleArrowRecordReader)(nil) - -type simpleArrowRecordReader struct { - blobs []*Blob - - blobPos int - rr array.RecordReader - closer func() - - r simpleArrowRecord -} - -func (crr *simpleArrowRecordReader) iterateNextBatch() error { - if crr.closer != nil { - crr.closer() - } - - crr.blobPos++ - if crr.blobPos >= len(crr.blobs) { - return io.EOF - } - - reader, err := NewBinlogReader(crr.blobs[crr.blobPos].Value) - if err != nil { - return err - } - - er, err := reader.NextEventReader() - if err != nil { - return err - } - rr, err := er.GetArrowRecordReader() - if err != nil { - return err - } - crr.rr = rr - crr.closer = func() { - crr.rr.Release() - er.Close() - reader.Close() - } - - return nil -} - -func (crr *simpleArrowRecordReader) Next() (Record, error) { - if crr.rr == nil { - if len(crr.blobs) == 0 { - return nil, io.EOF - } - crr.blobPos = -1 - crr.r = simpleArrowRecord{ - field2Col: make(map[FieldID]int), - } - if err := crr.iterateNextBatch(); err != nil { - return nil, err - } - } - - composeRecord := func() bool { - if ok := crr.rr.Next(); !ok { - return false - } - record := crr.rr.Record() - for i := range record.Schema().Fields() { - crr.r.field2Col[FieldID(i)] = i - } - crr.r.r = record - return true - } - - if ok := composeRecord(); !ok { - if err := crr.iterateNextBatch(); err != nil { - return nil, err - } - if ok := composeRecord(); !ok { - return nil, io.EOF - } - } - return &crr.r, nil -} - -func (crr *simpleArrowRecordReader) SetNeededFields(_ typeutil.Set[int64]) { - // no-op for simple arrow record reader -} - -func (crr *simpleArrowRecordReader) Close() error { - if crr.closer != nil { - crr.closer() - } - return nil -} - -func newSimpleArrowRecordReader(blobs []*Blob) (*simpleArrowRecordReader, error) { - return &simpleArrowRecordReader{ - blobs: blobs, - }, nil -} - -func newMultiFieldDeltalogStreamWriter(collectionID, partitionID, segmentID UniqueID, pkType schemapb.DataType) *MultiFieldDeltalogStreamWriter { - return &MultiFieldDeltalogStreamWriter{ - collectionID: collectionID, - partitionID: partitionID, - segmentID: segmentID, - pkType: pkType, - } -} - -type MultiFieldDeltalogStreamWriter struct { - collectionID UniqueID - partitionID UniqueID - segmentID UniqueID - pkType schemapb.DataType - - buf bytes.Buffer - rw *multiFieldRecordWriter -} - -func (dsw *MultiFieldDeltalogStreamWriter) GetRecordWriter() (RecordWriter, error) { - if dsw.rw != nil { - return dsw.rw, nil - } - - fieldIDs := []FieldID{common.RowIDField, common.TimeStampField} // Not used. 
- fields := []arrow.Field{ - { - Name: "pk", - Type: serdeMap[dsw.pkType].arrowType(0, schemapb.DataType_None), - Nullable: false, - }, - { - Name: "ts", - Type: arrow.PrimitiveTypes.Int64, - Nullable: false, - }, - } - - rw, err := newMultiFieldRecordWriter(fieldIDs, fields, &dsw.buf) - if err != nil { - return nil, err - } - dsw.rw = rw - return rw, nil -} - -func (dsw *MultiFieldDeltalogStreamWriter) Finalize() (*Blob, error) { - if dsw.rw == nil { - return nil, io.ErrUnexpectedEOF - } - dsw.rw.Close() - - var b bytes.Buffer - if err := dsw.writeDeltalogHeaders(&b); err != nil { - return nil, err - } - if _, err := b.Write(dsw.buf.Bytes()); err != nil { - return nil, err - } - return &Blob{ - Value: b.Bytes(), - RowNum: int64(dsw.rw.numRows), - MemorySize: int64(dsw.rw.writtenUncompressed), - }, nil -} - -func (dsw *MultiFieldDeltalogStreamWriter) writeDeltalogHeaders(w io.Writer) error { - // Write magic number - if err := binary.Write(w, common.Endian, MagicNumber); err != nil { - return err - } - // Write descriptor - de := NewBaseDescriptorEvent(dsw.collectionID, dsw.partitionID, dsw.segmentID) - de.PayloadDataType = schemapb.DataType_Int64 - de.descriptorEventData.AddExtra(originalSizeKey, strconv.Itoa(int(dsw.rw.writtenUncompressed))) - de.descriptorEventData.AddExtra(version, MultiField) - if err := de.Write(w); err != nil { - return err - } - // Write event header - eh := newEventHeader(DeleteEventType) - // Write event data - ev := newDeleteEventData() - ev.StartTimestamp = 1 - ev.EndTimestamp = 1 - eh.EventLength = int32(dsw.buf.Len()) + eh.GetMemoryUsageInBytes() + int32(binary.Size(ev)) - // eh.NextPosition = eh.EventLength + w.Offset() - if err := eh.Write(w); err != nil { - return err - } - if err := ev.WriteEventData(w); err != nil { - return err - } - return nil -} - -func newDeltalogMultiFieldWriter(eventWriter *MultiFieldDeltalogStreamWriter, batchSize int) (*SerializeWriterImpl[*DeleteLog], error) { - rw, err := eventWriter.GetRecordWriter() - if err != nil { - return nil, err - } - return NewSerializeRecordWriter[*DeleteLog](rw, func(v []*DeleteLog) (Record, error) { - fields := []arrow.Field{ - { - Name: "pk", - Type: serdeMap[schemapb.DataType(v[0].PkType)].arrowType(0, schemapb.DataType_None), - Nullable: false, - }, - { - Name: "ts", - Type: arrow.PrimitiveTypes.Int64, - Nullable: false, - }, - } - arrowSchema := arrow.NewSchema(fields, nil) - builder := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema) - defer builder.Release() - - pkType := schemapb.DataType(v[0].PkType) - switch pkType { - case schemapb.DataType_Int64: - pb := builder.Field(0).(*array.Int64Builder) - for _, vv := range v { - pk := vv.Pk.GetValue().(int64) - pb.Append(pk) - } - case schemapb.DataType_VarChar: - pb := builder.Field(0).(*array.StringBuilder) - for _, vv := range v { - pk := vv.Pk.GetValue().(string) - pb.Append(pk) - } - default: - return nil, fmt.Errorf("unexpected pk type %v", v[0].PkType) - } - - for _, vv := range v { - builder.Field(1).(*array.Int64Builder).Append(int64(vv.Ts)) - } - - arr := []arrow.Array{builder.Field(0).NewArray(), builder.Field(1).NewArray()} - - field2Col := map[FieldID]int{ - common.RowIDField: 0, - common.TimeStampField: 1, - } - return NewSimpleArrowRecord(array.NewRecord(arrowSchema, arr, int64(len(v))), field2Col), nil - }, batchSize), nil -} - -func newDeltalogMultiFieldReader(blobs []*Blob) (*DeserializeReaderImpl[*DeleteLog], error) { - reader, err := newSimpleArrowRecordReader(blobs) - if err != nil { - return nil, err - } - return 
NewDeserializeReader(reader, func(r Record, v []*DeleteLog) error { - rec, ok := r.(*simpleArrowRecord) - if !ok { - return errors.New("can not cast to simple arrow record") - } - fields := rec.r.Schema().Fields() - switch fields[0].Type.ID() { - case arrow.INT64: - arr := r.Column(0).(*array.Int64) - for j := 0; j < r.Len(); j++ { - if v[j] == nil { - v[j] = &DeleteLog{} - } - v[j].Pk = NewInt64PrimaryKey(arr.Value(j)) - } - case arrow.STRING: - arr := r.Column(0).(*array.String) - for j := 0; j < r.Len(); j++ { - if v[j] == nil { - v[j] = &DeleteLog{} - } - v[j].Pk = NewVarCharPrimaryKey(arr.Value(j)) - } - default: - return fmt.Errorf("unexpected delta log pkType %v", fields[0].Type.Name()) - } - - arr := r.Column(1).(*array.Int64) - for j := 0; j < r.Len(); j++ { - v[j].Ts = uint64(arr.Value(j)) - } - return nil - }), nil -} - -// NewDeltalogDeserializeReader is the entry point for the delta log reader. -// It includes NewDeltalogOneFieldReader, which uses the existing log format with only one column in a log file, -// and NewDeltalogMultiFieldReader, which uses the new format and supports multiple fields in a log file. -func newDeltalogDeserializeReader(blobs []*Blob) (*DeserializeReaderImpl[*DeleteLog], error) { - if supportMultiFieldFormat(blobs) { - return newDeltalogMultiFieldReader(blobs) - } - return newDeltalogOneFieldReader(blobs) -} - -// check delta log description data to see if it is the format with -// pk and ts column separately -func supportMultiFieldFormat(blobs []*Blob) bool { - if len(blobs) > 0 { - reader, err := NewBinlogReader(blobs[0].Value) - if err != nil { - return false - } - defer reader.Close() - version := reader.descriptorEventData.Extras[version] - return version != nil && version.(string) == MultiField - } - return false -} - -func CreateDeltalogReader(blobs []*Blob) (*DeserializeReaderImpl[*DeleteLog], error) { - return newDeltalogDeserializeReader(blobs) -} - -func CreateDeltalogWriter(collectionID, partitionID, segmentID UniqueID, pkType schemapb.DataType, batchSize int, -) (*SerializeWriterImpl[*DeleteLog], func() (*Blob, error), error) { - format := paramtable.Get().DataNodeCfg.DeltalogFormat.GetValue() - if format == "json" { - eventWriter := newDeltalogStreamWriter(collectionID, partitionID, segmentID) - writer, err := newDeltalogSerializeWriter(eventWriter, batchSize) - return writer, eventWriter.Finalize, err - } else if format == "parquet" { - eventWriter := newMultiFieldDeltalogStreamWriter(collectionID, partitionID, segmentID, pkType) - writer, err := newDeltalogMultiFieldWriter(eventWriter, batchSize) - return writer, eventWriter.Finalize, err - } - return nil, nil, merr.WrapErrParameterInvalid("unsupported deltalog format %s", format) -} diff --git a/internal/storage/serde_events_v2.go b/internal/storage/serde_events_v2.go index 404ad3481a..699c944f81 100644 --- a/internal/storage/serde_events_v2.go +++ b/internal/storage/serde_events_v2.go @@ -23,7 +23,6 @@ import ( "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" - "github.com/cockroachdb/errors" "github.com/samber/lo" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -31,9 +30,7 @@ import ( "github.com/milvus-io/milvus/internal/storagecommon" "github.com/milvus-io/milvus/internal/storagev2/packed" "github.com/milvus-io/milvus/pkg/v2/common" - "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" - "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb" "github.com/milvus-io/milvus/pkg/v2/proto/indexcgopb" 
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -43,63 +40,18 @@ import ( ) type packedRecordReader struct { - paths [][]string - chunk int - reader *packed.PackedReader - - bufferSize int64 - arrowSchema *arrow.Schema - field2Col map[FieldID]int - storageConfig *indexpb.StorageConfig - storagePluginContext *indexcgopb.StoragePluginContext + reader *packed.PackedReader + field2Col map[FieldID]int } var _ RecordReader = (*packedRecordReader)(nil) -func (pr *packedRecordReader) iterateNextBatch() error { - if pr.reader != nil { - if err := pr.reader.Close(); err != nil { - return err - } - } - - if pr.chunk >= len(pr.paths) { - return io.EOF - } - - reader, err := packed.NewPackedReader(pr.paths[pr.chunk], pr.arrowSchema, pr.bufferSize, pr.storageConfig, pr.storagePluginContext) - pr.chunk++ - if err != nil { - return errors.Newf("New binlog record packed reader error: %w", err) - } - pr.reader = reader - return nil -} - func (pr *packedRecordReader) Next() (Record, error) { - if pr.reader == nil { - if err := pr.iterateNextBatch(); err != nil { - return nil, err - } + rec, err := pr.reader.ReadNext() + if err != nil { + return nil, err } - - for { - rec, err := pr.reader.ReadNext() - if err == io.EOF { - if err := pr.iterateNextBatch(); err != nil { - return nil, err - } - continue - } else if err != nil { - return nil, err - } - return NewSimpleArrowRecord(rec, pr.field2Col), nil - } -} - -func (pr *packedRecordReader) SetNeededFields(fields typeutil.Set[int64]) { - // TODO, push down SetNeededFields to packedReader after implemented - // no-op for now + return NewSimpleArrowRecord(rec, pr.field2Col), nil } func (pr *packedRecordReader) Close() error { @@ -109,7 +61,12 @@ func (pr *packedRecordReader) Close() error { return nil } -func newPackedRecordReader(paths [][]string, schema *schemapb.CollectionSchema, bufferSize int64, storageConfig *indexpb.StorageConfig, storagePluginContext *indexcgopb.StoragePluginContext, +func newPackedRecordReader( + paths []string, + schema *schemapb.CollectionSchema, + bufferSize int64, + storageConfig *indexpb.StorageConfig, + storagePluginContext *indexcgopb.StoragePluginContext, ) (*packedRecordReader, error) { arrowSchema, err := ConvertToArrowSchema(schema) if err != nil { @@ -120,27 +77,34 @@ func newPackedRecordReader(paths [][]string, schema *schemapb.CollectionSchema, for i, field := range allFields { field2Col[field.FieldID] = i } - return &packedRecordReader{ - paths: paths, - bufferSize: bufferSize, - arrowSchema: arrowSchema, - field2Col: field2Col, - storageConfig: storageConfig, - storagePluginContext: storagePluginContext, - }, nil -} - -// Deprecated -func NewPackedDeserializeReader(paths [][]string, schema *schemapb.CollectionSchema, - bufferSize int64, shouldCopy bool, -) (*DeserializeReaderImpl[*Value], error) { - reader, err := newPackedRecordReader(paths, schema, bufferSize, nil, nil) + reader, err := packed.NewPackedReader(paths, arrowSchema, bufferSize, storageConfig, storagePluginContext) if err != nil { return nil, err } - return NewDeserializeReader(reader, func(r Record, v []*Value) error { - return ValueDeserializerWithSchema(r, v, schema, shouldCopy) - }), nil + return &packedRecordReader{ + reader: reader, + field2Col: field2Col, + }, nil +} + +func newIterativePackedRecordReader( + paths [][]string, + schema *schemapb.CollectionSchema, + bufferSize int64, + storageConfig *indexpb.StorageConfig, + storagePluginContext *indexcgopb.StoragePluginContext, +) 
*IterativeRecordReader { + chunk := 0 + return &IterativeRecordReader{ + iterate: func() (RecordReader, error) { + if chunk >= len(paths) { + return nil, io.EOF + } + currentPaths := paths[chunk] + chunk++ + return newPackedRecordReader(currentPaths, schema, bufferSize, storageConfig, storagePluginContext) + }, + } } var _ RecordWriter = (*packedRecordWriter)(nil) @@ -236,7 +200,22 @@ func (pw *packedRecordWriter) Close() error { return nil } -func NewPackedRecordWriter(bucketName string, paths []string, schema *schemapb.CollectionSchema, bufferSize int64, multiPartUploadSize int64, columnGroups []storagecommon.ColumnGroup, storageConfig *indexpb.StorageConfig, storagePluginContext *indexcgopb.StoragePluginContext) (*packedRecordWriter, error) { +func NewPackedRecordWriter( + bucketName string, + paths []string, + schema *schemapb.CollectionSchema, + bufferSize int64, + multiPartUploadSize int64, + columnGroups []storagecommon.ColumnGroup, + storageConfig *indexpb.StorageConfig, + storagePluginContext *indexcgopb.StoragePluginContext, +) (*packedRecordWriter, error) { + // Validate PK field exists before proceeding + _, err := typeutil.GetPrimaryFieldSchema(schema) + if err != nil { + return nil, err + } + arrowSchema, err := ConvertToArrowSchema(schema) if err != nil { return nil, merr.WrapErrServiceInternal( @@ -320,8 +299,8 @@ type PackedBinlogRecordWriter struct { // writer and stats generated at runtime writer *packedRecordWriter - pkstats *PrimaryKeyStats - bm25Stats map[int64]*BM25Stats + pkCollector *PkStatsCollector + bm25Collector *Bm25StatsCollector tsFrom typeutil.Timestamp tsTo typeutil.Timestamp rowNum int64 @@ -338,6 +317,7 @@ func (pw *PackedBinlogRecordWriter) Write(r Record) error { return err } + // Track timestamps tsArray := r.Column(common.TimeStampField).(*array.Int64) rows := r.Len() for i := 0; i < rows; i++ { @@ -348,31 +328,14 @@ func (pw *PackedBinlogRecordWriter) Write(r Record) error { if ts > pw.tsTo { pw.tsTo = ts } + } - switch schemapb.DataType(pw.pkstats.PkType) { - case schemapb.DataType_Int64: - pkArray := r.Column(pw.pkstats.FieldID).(*array.Int64) - pk := &Int64PrimaryKey{ - Value: pkArray.Value(i), - } - pw.pkstats.Update(pk) - case schemapb.DataType_VarChar: - pkArray := r.Column(pw.pkstats.FieldID).(*array.String) - pk := &VarCharPrimaryKey{ - Value: pkArray.Value(i), - } - pw.pkstats.Update(pk) - default: - panic("invalid data type") - } - - for fieldID, stats := range pw.bm25Stats { - field, ok := r.Column(fieldID).(*array.Binary) - if !ok { - return errors.New("bm25 field value not found") - } - stats.AppendBytes(field.Value(i)) - } + // Collect statistics + if err := pw.pkCollector.Collect(r); err != nil { + return err + } + if err := pw.bm25Collector.Collect(r); err != nil { + return err } err := pw.writer.Write(r) @@ -433,9 +396,6 @@ func (pw *PackedBinlogRecordWriter) Close() error { if err := pw.writeStats(); err != nil { return err } - if err := pw.writeBm25Stats(); err != nil { - return err - } return nil } @@ -467,89 +427,39 @@ func (pw *PackedBinlogRecordWriter) finalizeBinlogs() { } func (pw *PackedBinlogRecordWriter) writeStats() error { - if pw.pkstats == nil { - return nil - } - - id, err := pw.allocator.AllocOne() + // Write PK stats + pkStatsMap, err := pw.pkCollector.Digest( + pw.collectionID, + pw.partitionID, + pw.segmentID, + pw.storageConfig.GetRootPath(), + pw.rowNum, + pw.allocator, + pw.BlobsWriter, + ) if err != nil { return err } + // Extract single PK stats from map + for _, statsLog := range pkStatsMap { + pw.statsLog 
= statsLog + break + } - codec := NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ - ID: pw.collectionID, - Schema: pw.schema, - }) - sblob, err := codec.SerializePkStats(pw.pkstats, pw.rowNum) + // Write BM25 stats + bm25StatsLog, err := pw.bm25Collector.Digest( + pw.collectionID, + pw.partitionID, + pw.segmentID, + pw.storageConfig.GetRootPath(), + pw.rowNum, + pw.allocator, + pw.BlobsWriter, + ) if err != nil { return err } - - sblob.Key = metautil.BuildStatsLogPath(pw.storageConfig.GetRootPath(), - pw.collectionID, pw.partitionID, pw.segmentID, pw.pkstats.FieldID, id) - - if err := pw.BlobsWriter([]*Blob{sblob}); err != nil { - return err - } - - pw.statsLog = &datapb.FieldBinlog{ - FieldID: pw.pkstats.FieldID, - Binlogs: []*datapb.Binlog{ - { - LogSize: int64(len(sblob.GetValue())), - MemorySize: int64(len(sblob.GetValue())), - LogPath: sblob.Key, - EntriesNum: pw.rowNum, - }, - }, - } - return nil -} - -func (pw *PackedBinlogRecordWriter) writeBm25Stats() error { - if len(pw.bm25Stats) == 0 { - return nil - } - id, _, err := pw.allocator.Alloc(uint32(len(pw.bm25Stats))) - if err != nil { - return err - } - - if pw.bm25StatsLog == nil { - pw.bm25StatsLog = make(map[FieldID]*datapb.FieldBinlog) - } - for fid, stats := range pw.bm25Stats { - bytes, err := stats.Serialize() - if err != nil { - return err - } - key := metautil.BuildBm25LogPath(pw.storageConfig.GetRootPath(), - pw.collectionID, pw.partitionID, pw.segmentID, fid, id) - blob := &Blob{ - Key: key, - Value: bytes, - RowNum: stats.NumRow(), - MemorySize: int64(len(bytes)), - } - if err := pw.BlobsWriter([]*Blob{blob}); err != nil { - return err - } - - fieldLog := &datapb.FieldBinlog{ - FieldID: fid, - Binlogs: []*datapb.Binlog{ - { - LogSize: int64(len(blob.GetValue())), - MemorySize: int64(len(blob.GetValue())), - LogPath: key, - EntriesNum: pw.rowNum, - }, - }, - } - - pw.bm25StatsLog[fid] = fieldLog - id++ - } + pw.bm25StatsLog = bm25StatsLog return nil } @@ -587,27 +497,8 @@ func newPackedBinlogRecordWriter(collectionID, partitionID, segmentID UniqueID, if err != nil { return nil, merr.WrapErrParameterInvalid("convert collection schema [%s] to arrow schema error: %s", schema.Name, err.Error()) } - pkField, err := typeutil.GetPrimaryFieldSchema(schema) - if err != nil { - log.Warn("failed to get pk field from schema") - return nil, err - } - stats, err := NewPrimaryKeyStats(pkField.GetFieldID(), int64(pkField.GetDataType()), maxRowNum) - if err != nil { - return nil, err - } - bm25FieldIDs := lo.FilterMap(schema.GetFunctions(), func(function *schemapb.FunctionSchema, _ int) (int64, bool) { - if function.GetType() == schemapb.FunctionType_BM25 { - return function.GetOutputFieldIds()[0], true - } - return 0, false - }) - bm25Stats := make(map[int64]*BM25Stats, len(bm25FieldIDs)) - for _, fid := range bm25FieldIDs { - bm25Stats[fid] = NewBM25Stats() - } - return &PackedBinlogRecordWriter{ + writer := &PackedBinlogRecordWriter{ collectionID: collectionID, partitionID: partitionID, segmentID: segmentID, @@ -619,12 +510,23 @@ func newPackedBinlogRecordWriter(collectionID, partitionID, segmentID UniqueID, bufferSize: bufferSize, multiPartUploadSize: multiPartUploadSize, columnGroups: columnGroups, - pkstats: stats, - bm25Stats: bm25Stats, storageConfig: storageConfig, storagePluginContext: storagePluginContext, + tsFrom: typeutil.MaxTimestamp, + tsTo: 0, + } - tsFrom: typeutil.MaxTimestamp, - tsTo: 0, - }, nil + // Create stats collectors + writer.pkCollector, err = NewPkStatsCollector( + collectionID, + schema, + maxRowNum, + ) + 
if err != nil { + return nil, err + } + + writer.bm25Collector = NewBm25StatsCollector(schema) + + return writer, nil } diff --git a/internal/storage/serde_events_v2_test.go b/internal/storage/serde_events_v2_test.go index becaaf5768..2bd288dba0 100644 --- a/internal/storage/serde_events_v2_test.go +++ b/internal/storage/serde_events_v2_test.go @@ -24,14 +24,15 @@ import ( "github.com/milvus-io/milvus/internal/storagecommon" "github.com/milvus-io/milvus/internal/util/initcore" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) func TestPackedSerde(t *testing.T) { t.Run("test binlog packed serde v2", func(t *testing.T) { - t.Skip("storage v2 cgo not ready yet") + paramtable.Get().Save(paramtable.Get().CommonCfg.StorageType.Key, "local") initcore.InitLocalArrowFileSystem("/tmp") size := 10 - bucketName := "a-bucket" + bucketName := "" paths := [][]string{{"/tmp/0"}, {"/tmp/1"}} bufferSize := int64(10 * 1024 * 1024) // 10MB schema := generateTestSchema() @@ -70,16 +71,18 @@ func TestPackedSerde(t *testing.T) { prepareChunkData(chunkPaths, size) } - reader, err := NewPackedDeserializeReader(paths, schema, bufferSize, false) - assert.NoError(t, err) + reader := newIterativePackedRecordReader(paths, schema, bufferSize, nil, nil) defer reader.Close() - for i := 0; i < size*len(paths); i++ { - value, err := reader.NextValue() + nRows := 0 + for { + rec, err := reader.Next() + if err == io.EOF { + break + } assert.NoError(t, err) - assertTestData(t, i%10+1, *value) + nRows += rec.Len() } - _, err = reader.NextValue() - assert.Equal(t, err, io.EOF) + assert.Equal(t, size*len(paths), nRows) }) } diff --git a/internal/storage/sort_test.go b/internal/storage/sort_test.go index 82dd73d9f5..22ecc50422 100644 --- a/internal/storage/sort_test.go +++ b/internal/storage/sort_test.go @@ -31,12 +31,10 @@ func TestSort(t *testing.T) { getReaders := func() []RecordReader { blobs, err := generateTestDataWithSeed(10, 3) assert.NoError(t, err) - reader10, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs)) - assert.NoError(t, err) + reader10 := newIterativeCompositeBinlogRecordReader(generateTestSchema(), nil, MakeBlobsReader(blobs)) blobs, err = generateTestDataWithSeed(20, 3) assert.NoError(t, err) - reader20, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs)) - assert.NoError(t, err) + reader20 := newIterativeCompositeBinlogRecordReader(generateTestSchema(), nil, MakeBlobsReader(blobs)) rr := []RecordReader{reader20, reader10} return rr } @@ -82,12 +80,10 @@ func TestMergeSort(t *testing.T) { getReaders := func() []RecordReader { blobs, err := generateTestDataWithSeed(1000, 5000) assert.NoError(t, err) - reader10, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs)) - assert.NoError(t, err) + reader10 := newIterativeCompositeBinlogRecordReader(generateTestSchema(), nil, MakeBlobsReader(blobs)) blobs, err = generateTestDataWithSeed(4000, 5000) assert.NoError(t, err) - reader20, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs)) - assert.NoError(t, err) + reader20 := newIterativeCompositeBinlogRecordReader(generateTestSchema(), nil, MakeBlobsReader(blobs)) rr := []RecordReader{reader20, reader10} return rr } @@ -138,12 +134,10 @@ func BenchmarkSort(b *testing.B) { batch := 500000 blobs, err := generateTestDataWithSeed(batch, batch) assert.NoError(b, err) - reader10, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs)) - assert.NoError(b, err) + 
reader10 := newIterativeCompositeBinlogRecordReader(generateTestSchema(), nil, MakeBlobsReader(blobs)) blobs, err = generateTestDataWithSeed(batch*2+1, batch) assert.NoError(b, err) - reader20, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs)) - assert.NoError(b, err) + reader20 := newIterativeCompositeBinlogRecordReader(generateTestSchema(), nil, MakeBlobsReader(blobs)) rr := []RecordReader{reader20, reader10} rw := &MockRecordWriter{ @@ -174,12 +168,10 @@ func TestSortByMoreThanOneField(t *testing.T) { blobs, err := generateTestDataWithSeed(10, batchSize) assert.NoError(t, err) - reader10, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs)) - assert.NoError(t, err) + reader10 := newIterativeCompositeBinlogRecordReader(generateTestSchema(), nil, MakeBlobsReader(blobs)) blobs, err = generateTestDataWithSeed(20, batchSize) assert.NoError(t, err) - reader20, err := newCompositeBinlogRecordReader(generateTestSchema(), MakeBlobsReader(blobs)) - assert.NoError(t, err) + reader20 := newIterativeCompositeBinlogRecordReader(generateTestSchema(), nil, MakeBlobsReader(blobs)) rr := []RecordReader{reader20, reader10} lastPK := int64(-1) diff --git a/internal/storage/stats_collector.go b/internal/storage/stats_collector.go new file mode 100644 index 0000000000..df4b310c99 --- /dev/null +++ b/internal/storage/stats_collector.go @@ -0,0 +1,278 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package storage + +import ( + "strconv" + + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/cockroachdb/errors" + "github.com/samber/lo" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/pkg/v2/proto/datapb" + "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb" + "github.com/milvus-io/milvus/pkg/v2/util/metautil" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +// StatsCollector collects statistics from records +type StatsCollector interface { + // Collect collects statistics from a record + Collect(r Record) error + // Digest serializes the collected statistics, writes them to storage, + // and returns the field binlog metadata + Digest( + collectionID, partitionID, segmentID UniqueID, + rootPath string, + rowNum int64, + allocator allocator.Interface, + blobsWriter ChunkedBlobsWriter, + ) (map[FieldID]*datapb.FieldBinlog, error) +} + +// PkStatsCollector collects primary key statistics +type PkStatsCollector struct { + pkstats *PrimaryKeyStats + collectionID UniqueID // needed for initializing codecs, TODO: remove this + schema *schemapb.CollectionSchema +} + +// Collect collects primary key stats from the record +func (c *PkStatsCollector) Collect(r Record) error { + if c.pkstats == nil { + return nil + } + + rows := r.Len() + for i := 0; i < rows; i++ { + switch schemapb.DataType(c.pkstats.PkType) { + case schemapb.DataType_Int64: + pkArray := r.Column(c.pkstats.FieldID).(*array.Int64) + pk := &Int64PrimaryKey{ + Value: pkArray.Value(i), + } + c.pkstats.Update(pk) + case schemapb.DataType_VarChar: + pkArray := r.Column(c.pkstats.FieldID).(*array.String) + pk := &VarCharPrimaryKey{ + Value: pkArray.Value(i), + } + c.pkstats.Update(pk) + default: + panic("invalid data type") + } + } + return nil +} + +// Digest serializes the collected primary key statistics, writes them to storage, +// and returns the field binlog metadata +func (c *PkStatsCollector) Digest( + collectionID, partitionID, segmentID UniqueID, + rootPath string, + rowNum int64, + allocator allocator.Interface, + blobsWriter ChunkedBlobsWriter, +) (map[FieldID]*datapb.FieldBinlog, error) { + if c.pkstats == nil { + return nil, nil + } + + // Serialize PK stats + codec := NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ + ID: c.collectionID, + Schema: c.schema, + }) + sblob, err := codec.SerializePkStats(c.pkstats, rowNum) + if err != nil { + return nil, err + } + + // Get pk field ID + pkField, err := typeutil.GetPrimaryFieldSchema(c.schema) + if err != nil { + return nil, err + } + + // Allocate ID for stats blob + id, err := allocator.AllocOne() + if err != nil { + return nil, err + } + + // Assign proper path to the blob + fieldID := pkField.GetFieldID() + sblob.Key = metautil.BuildStatsLogPath(rootPath, + c.collectionID, partitionID, segmentID, fieldID, id) + + // Write the blob + if err := blobsWriter([]*Blob{sblob}); err != nil { + return nil, err + } + + // Return as map for interface consistency + return map[FieldID]*datapb.FieldBinlog{ + fieldID: { + FieldID: fieldID, + Binlogs: []*datapb.Binlog{ + { + LogSize: int64(len(sblob.GetValue())), + MemorySize: int64(len(sblob.GetValue())), + LogPath: sblob.Key, + EntriesNum: rowNum, + }, + }, + }, + }, nil +} + +// NewPkStatsCollector creates a new primary key stats collector +func NewPkStatsCollector( + collectionID UniqueID, + schema *schemapb.CollectionSchema, + maxRowNum int64, +) (*PkStatsCollector, error) { + pkField, err := 
typeutil.GetPrimaryFieldSchema(schema) + if err != nil { + return nil, err + } + stats, err := NewPrimaryKeyStats(pkField.GetFieldID(), int64(pkField.GetDataType()), maxRowNum) + if err != nil { + return nil, err + } + + return &PkStatsCollector{ + pkstats: stats, + collectionID: collectionID, + schema: schema, + }, nil +} + +// Bm25StatsCollector collects BM25 statistics +type Bm25StatsCollector struct { + bm25Stats map[int64]*BM25Stats +} + +// Collect collects BM25 statistics from the record +func (c *Bm25StatsCollector) Collect(r Record) error { + if len(c.bm25Stats) == 0 { + return nil + } + + rows := r.Len() + for fieldID, stats := range c.bm25Stats { + field, ok := r.Column(fieldID).(*array.Binary) + if !ok { + return errors.New("bm25 field value not found") + } + for i := 0; i < rows; i++ { + stats.AppendBytes(field.Value(i)) + } + } + return nil +} + +// Digest serializes the collected BM25 statistics, writes them to storage, +// and returns the field binlog metadata +func (c *Bm25StatsCollector) Digest( + collectionID, partitionID, segmentID UniqueID, + rootPath string, + rowNum int64, + allocator allocator.Interface, + blobsWriter ChunkedBlobsWriter, +) (map[FieldID]*datapb.FieldBinlog, error) { + if len(c.bm25Stats) == 0 { + return nil, nil + } + + // Serialize BM25 stats into blobs + blobs := make([]*Blob, 0, len(c.bm25Stats)) + for fid, stats := range c.bm25Stats { + bytes, err := stats.Serialize() + if err != nil { + return nil, err + } + blob := &Blob{ + Key: strconv.FormatInt(fid, 10), // temporary key, will be replaced below + Value: bytes, + RowNum: stats.NumRow(), + MemorySize: int64(len(bytes)), + } + blobs = append(blobs, blob) + } + + // Allocate IDs for stats blobs + id, _, err := allocator.Alloc(uint32(len(blobs))) + if err != nil { + return nil, err + } + + result := make(map[FieldID]*datapb.FieldBinlog) + + // Process each blob and assign proper paths + for _, blob := range blobs { + // Parse the field ID from the temporary key + fieldID, parseErr := strconv.ParseInt(blob.Key, 10, 64) + if parseErr != nil { + // This should not happen for BM25 blobs + continue + } + + blob.Key = metautil.BuildBm25LogPath(rootPath, + collectionID, partitionID, segmentID, fieldID, id) + + result[fieldID] = &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: []*datapb.Binlog{ + { + LogSize: int64(len(blob.GetValue())), + MemorySize: int64(len(blob.GetValue())), + LogPath: blob.Key, + EntriesNum: rowNum, + }, + }, + } + id++ + } + + // Write all blobs + if err := blobsWriter(blobs); err != nil { + return nil, err + } + + return result, nil +} + +// NewBm25StatsCollector creates a new BM25 stats collector +func NewBm25StatsCollector(schema *schemapb.CollectionSchema) *Bm25StatsCollector { + bm25FieldIDs := lo.FilterMap(schema.GetFunctions(), func(function *schemapb.FunctionSchema, _ int) (int64, bool) { + if function.GetType() == schemapb.FunctionType_BM25 { + return function.GetOutputFieldIds()[0], true + } + return 0, false + }) + bm25Stats := make(map[int64]*BM25Stats, len(bm25FieldIDs)) + for _, fid := range bm25FieldIDs { + bm25Stats[fid] = NewBM25Stats() + } + + return &Bm25StatsCollector{ + bm25Stats: bm25Stats, + } +} diff --git a/internal/storage/stats_collector_test.go b/internal/storage/stats_collector_test.go new file mode 100644 index 0000000000..8862302029 --- /dev/null +++ b/internal/storage/stats_collector_test.go @@ -0,0 +1,261 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "fmt" + "testing" + + "github.com/apache/arrow/go/v17/arrow" + "github.com/apache/arrow/go/v17/arrow/array" + "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/pkg/v2/common" +) + +func TestPkStatsCollector(t *testing.T) { + collectionID := int64(1) + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: common.RowIDField, DataType: schemapb.DataType_Int64}, + {FieldID: common.TimeStampField, DataType: schemapb.DataType_Int64}, + { + FieldID: 100, + Name: "pk", + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + }, + }, + } + + t.Run("collect and digest int64 pk", func(t *testing.T) { + collector, err := NewPkStatsCollector(collectionID, schema, 100) + require.NoError(t, err) + require.NotNil(t, collector) + + // Create test record + fields := []arrow.Field{ + {Name: "pk", Type: arrow.PrimitiveTypes.Int64}, + } + arrowSchema := arrow.NewSchema(fields, nil) + builder := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema) + defer builder.Release() + + pkBuilder := builder.Field(0).(*array.Int64Builder) + for i := 0; i < 10; i++ { + pkBuilder.Append(int64(i)) + } + + rec := builder.NewRecord() + field2Col := map[FieldID]int{100: 0} + record := NewSimpleArrowRecord(rec, field2Col) + + // Collect stats + err = collector.Collect(record) + assert.NoError(t, err) + + // Digest stats + alloc := allocator.NewLocalAllocator(1, 100) + writer := func(blobs []*Blob) error { return nil } + + resultMap, err := collector.Digest(collectionID, 1, 2, "/tmp", 10, alloc, writer) + assert.NoError(t, err) + assert.NotNil(t, resultMap) + assert.Len(t, resultMap, 1) + }) + + t.Run("varchar pk", func(t *testing.T) { + varcharSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: common.RowIDField, DataType: schemapb.DataType_Int64}, + {FieldID: common.TimeStampField, DataType: schemapb.DataType_Int64}, + { + FieldID: 100, + Name: "pk", + DataType: schemapb.DataType_VarChar, + IsPrimaryKey: true, + }, + }, + } + + collector, err := NewPkStatsCollector(collectionID, varcharSchema, 100) + require.NoError(t, err) + + // Create test record with varchar pk + fields := []arrow.Field{ + {Name: "pk", Type: arrow.BinaryTypes.String}, + } + arrowSchema := arrow.NewSchema(fields, nil) + builder := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema) + defer builder.Release() + + pkBuilder := builder.Field(0).(*array.StringBuilder) + for i := 0; i < 10; i++ { + pkBuilder.Append(fmt.Sprintf("key_%d", i)) + } + + rec := builder.NewRecord() + field2Col := map[FieldID]int{100: 0} + record := NewSimpleArrowRecord(rec, 
field2Col) + + err = collector.Collect(record) + assert.NoError(t, err) + }) +} + +func TestBm25StatsCollector(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: common.RowIDField, DataType: schemapb.DataType_Int64}, + {FieldID: common.TimeStampField, DataType: schemapb.DataType_Int64}, + { + FieldID: 100, + Name: "text", + DataType: schemapb.DataType_VarChar, + }, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "bm25_function", + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{100}, + OutputFieldIds: []int64{101}, + OutputFieldNames: []string{"bm25_field"}, + }, + }, + } + + t.Run("collect bm25 stats", func(t *testing.T) { + collector := NewBm25StatsCollector(schema) + assert.NotNil(t, collector) + assert.NotNil(t, collector.bm25Stats) + }) + + t.Run("digest with empty stats", func(t *testing.T) { + collector := NewBm25StatsCollector(schema) + + alloc := allocator.NewLocalAllocator(1, 100) + writer := func(blobs []*Blob) error { return nil } + + _, err := collector.Digest(1, 1, 2, "/tmp", 10, alloc, writer) + assert.NoError(t, err) + }) +} + +func TestNewPkStatsCollector_NoPkField(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: common.RowIDField, DataType: schemapb.DataType_Int64}, + {FieldID: common.TimeStampField, DataType: schemapb.DataType_Int64}, + }, + } + + collector, err := NewPkStatsCollector(1, schema, 100) + assert.Error(t, err) + assert.Nil(t, collector) +} + +func TestPkStatsCollector_DigestEndToEnd(t *testing.T) { + collectionID := int64(1) + partitionID := int64(2) + segmentID := int64(3) + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + }, + }, + } + + collector, err := NewPkStatsCollector(collectionID, schema, 100) + require.NoError(t, err) + + // Create test record + fields := []arrow.Field{ + {Name: "pk", Type: arrow.PrimitiveTypes.Int64}, + } + arrowSchema := arrow.NewSchema(fields, nil) + builder := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema) + defer builder.Release() + + pkBuilder := builder.Field(0).(*array.Int64Builder) + for i := 0; i < 10; i++ { + pkBuilder.Append(int64(i)) + } + + rec := builder.NewRecord() + field2Col := map[FieldID]int{100: 0} + record := NewSimpleArrowRecord(rec, field2Col) + + err = collector.Collect(record) + require.NoError(t, err) + + alloc := allocator.NewLocalAllocator(1, 100) + var writtenBlobs []*Blob + writer := func(blobs []*Blob) error { + writtenBlobs = blobs + return nil + } + + // Test Digest which includes writing + binlogMap, err := collector.Digest(collectionID, partitionID, segmentID, + "/tmp", 10, alloc, writer) + assert.NoError(t, err) + assert.NotNil(t, binlogMap) + assert.Len(t, binlogMap, 1) + + binlog := binlogMap[100] + assert.NotNil(t, binlog) + assert.Equal(t, int64(100), binlog.FieldID) + assert.Len(t, binlog.Binlogs, 1) + assert.Contains(t, binlog.Binlogs[0].LogPath, "stats_log") + assert.NotNil(t, writtenBlobs) + assert.Len(t, writtenBlobs, 1) +} + +func TestBm25StatsCollector_DigestEndToEnd(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "text", DataType: schemapb.DataType_VarChar}, + }, + Functions: []*schemapb.FunctionSchema{ + { + Name: "bm25_function", + Type: schemapb.FunctionType_BM25, + InputFieldIds: []int64{100}, + OutputFieldIds: []int64{101}, + OutputFieldNames: 
[]string{"bm25_field"}, + }, + }, + } + + collector := NewBm25StatsCollector(schema) + + alloc := allocator.NewLocalAllocator(1, 100) + writer := func(blobs []*Blob) error { return nil } + + // Test with empty stats + _, err := collector.Digest(1, 2, 3, "/tmp", 10, alloc, writer) + assert.NoError(t, err) +}