diff --git a/internal/storage/payload_benchmark_test.go b/internal/storage/payload_benchmark_test.go index 4a89e025ed..2b2c4b1ee9 100644 --- a/internal/storage/payload_benchmark_test.go +++ b/internal/storage/payload_benchmark_test.go @@ -9,8 +9,8 @@ import ( // workload setting for benchmark const ( - numElements = 1000 - vectorDim = 8 + numElements = 10000 + vectorDim = 512 ) func BenchmarkPayloadReader_Bool(b *testing.B) { diff --git a/internal/storage/payload_reader.go b/internal/storage/payload_reader.go index a990285417..2c354ccafd 100644 --- a/internal/storage/payload_reader.go +++ b/internal/storage/payload_reader.go @@ -4,11 +4,11 @@ import ( "bytes" "errors" "fmt" - "reflect" "github.com/apache/arrow/go/v8/arrow" "github.com/apache/arrow/go/v8/parquet" "github.com/apache/arrow/go/v8/parquet/file" + "github.com/milvus-io/milvus/internal/proto/schemapb" ) @@ -80,20 +80,20 @@ func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) { if r.colType != schemapb.DataType_Bool { return nil, fmt.Errorf("failed to get bool from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() + reader, ok := r.reader.RowGroup(0).Column(0).(*file.BooleanColumnChunkReader) + if !ok { + return nil, fmt.Errorf("expect type *file.BooleanColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) + } + + values := make([]bool, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) if err != nil { return nil, err } - ret := make([]bool, r.numRows) - var i int64 - for i = 0; i < r.numRows; i++ { - v, hasValue := dumper.Next() - if !hasValue { - return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i) - } - ret[i] = v.(bool) + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) } - return ret, nil + return values, nil } // GetByteFromPayload returns byte slice from payload @@ -101,18 
+101,23 @@ func (r *PayloadReader) GetByteFromPayload() ([]byte, error) { if r.colType != schemapb.DataType_Int8 { return nil, fmt.Errorf("failed to get byte from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() + reader, ok := r.reader.RowGroup(0).Column(0).(*file.Int32ColumnChunkReader) + if !ok { + return nil, fmt.Errorf("expect type *file.Int32ColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) + } + + values := make([]int32, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) if err != nil { return nil, err } + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) + } + ret := make([]byte, r.numRows) - var i int64 - for i = 0; i < r.numRows; i++ { - v, hasValue := dumper.Next() - if !hasValue { - return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i) - } - ret[i] = byte(v.(int32)) + for i := int64(0); i < r.numRows; i++ { + ret[i] = byte(values[i]) } return ret, nil } @@ -122,19 +127,23 @@ func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) { if r.colType != schemapb.DataType_Int8 { return nil, fmt.Errorf("failed to get int8 from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() + + reader, ok := r.reader.RowGroup(0).Column(0).(*file.Int32ColumnChunkReader) + if !ok { + return nil, fmt.Errorf("expect type *file.Int32ColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) + } + values := make([]int32, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) if err != nil { return nil, err } + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) + } + ret := make([]int8, r.numRows) - var i int64 - for i = 0; i < r.numRows; i++ { - v, hasValue := dumper.Next() - if !hasValue { 
- return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i) - } - // need to trasfer because parquet didn't support int8 - ret[i] = int8(v.(int32)) + for i := int64(0); i < r.numRows; i++ { + ret[i] = int8(values[i]) } return ret, nil } @@ -143,19 +152,23 @@ func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) { if r.colType != schemapb.DataType_Int16 { return nil, fmt.Errorf("failed to get int16 from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() + + reader, ok := r.reader.RowGroup(0).Column(0).(*file.Int32ColumnChunkReader) + if !ok { + return nil, fmt.Errorf("expect type *file.Int32ColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) + } + values := make([]int32, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) if err != nil { return nil, err } + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) + } + ret := make([]int16, r.numRows) - var i int64 - for i = 0; i < r.numRows; i++ { - v, hasValue := dumper.Next() - if !hasValue { - return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i) - } - // need to trasfer because parquet didn't support int16 - ret[i] = int16(v.(int32)) + for i := int64(0); i < r.numRows; i++ { + ret[i] = int16(values[i]) } return ret, nil } @@ -164,98 +177,104 @@ func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) { if r.colType != schemapb.DataType_Int32 { return nil, fmt.Errorf("failed to get int32 from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() + + reader, ok := r.reader.RowGroup(0).Column(0).(*file.Int32ColumnChunkReader) + if !ok { + return nil, fmt.Errorf("expect type *file.Int32ColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) + } + values := make([]int32, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, 
nil) if err != nil { return nil, err } - ret := make([]int32, r.numRows) - var i int64 - for i = 0; i < r.numRows; i++ { - v, hasValue := dumper.Next() - if !hasValue { - return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i) - } - ret[i] = v.(int32) + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) } - return ret, nil + + return values, nil } func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) { if r.colType != schemapb.DataType_Int64 { return nil, fmt.Errorf("failed to get int64 from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() + reader, ok := r.reader.RowGroup(0).Column(0).(*file.Int64ColumnChunkReader) + if !ok { + return nil, fmt.Errorf("expect type *file.Int64ColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) + } + + values := make([]int64, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) if err != nil { return nil, err } - ret := make([]int64, r.numRows) - var i int64 - for i = 0; i < r.numRows; i++ { - v, hasValue := dumper.Next() - if !hasValue { - return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i) - } - ret[i] = v.(int64) + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) } - return ret, nil + return values, nil } func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) { if r.colType != schemapb.DataType_Float { return nil, fmt.Errorf("failed to get float32 from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() + reader, ok := r.reader.RowGroup(0).Column(0).(*file.Float32ColumnChunkReader) + if !ok { + return nil, fmt.Errorf("expect type *file.Float32ColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) + } + + values := 
make([]float32, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) if err != nil { return nil, err } - ret := make([]float32, r.numRows) - var i int64 - for i = 0; i < r.numRows; i++ { - v, hasValue := dumper.Next() - if !hasValue { - return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i) - } - ret[i] = v.(float32) + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) } - return ret, nil + return values, nil } func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) { if r.colType != schemapb.DataType_Double { - return nil, fmt.Errorf("failed to get double from datatype %v", r.colType.String()) + return nil, fmt.Errorf("failed to get double from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() + reader, ok := r.reader.RowGroup(0).Column(0).(*file.Float64ColumnChunkReader) + if !ok { + return nil, fmt.Errorf("expect type *file.Float64ColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) + } + + values := make([]float64, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) if err != nil { return nil, err } - ret := make([]float64, r.numRows) - var i int64 - for i = 0; i < r.numRows; i++ { - v, hasValue := dumper.Next() - if !hasValue { - return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i) - } - ret[i] = v.(float64) + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) } - return ret, nil + return values, nil } func (r *PayloadReader) GetStringFromPayload() ([]string, error) { if r.colType != schemapb.DataType_String && r.colType != schemapb.DataType_VarChar { return nil, fmt.Errorf("failed to get string from datatype %v", r.colType.String()) } - dumper, err := 
r.createDumper() + + reader, ok := r.reader.RowGroup(0).Column(0).(*file.ByteArrayColumnChunkReader) + if !ok { + return nil, fmt.Errorf("expect type *file.ByteArrayColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) + } + values := make([]parquet.ByteArray, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) if err != nil { return nil, err } + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) + } + ret := make([]string, r.numRows) - var i int64 - for i = 0; i < r.numRows; i++ { - v, hasValue := dumper.Next() - if !hasValue { - return nil, fmt.Errorf("unmatched row number: expect %v, actual %v", r.numRows, i) - } - ret[i] = v.(parquet.ByteArray).String() + for i := 0; i < int(r.numRows); i++ { + ret[i] = values[i].String() } return ret, nil } @@ -265,20 +284,24 @@ func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) { if r.colType != schemapb.DataType_BinaryVector { return nil, -1, fmt.Errorf("failed to get binary vector from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() - if err != nil { - return nil, -1, err + + reader, ok := r.reader.RowGroup(0).Column(0).(*file.FixedLenByteArrayColumnChunkReader) + if !ok { + return nil, -1, fmt.Errorf("expect type *file.FixedLenByteArrayColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) } dim := r.reader.RowGroup(0).Column(0).Descriptor().TypeLength() + values := make([]parquet.FixedLenByteArray, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) + if err != nil { + return nil, -1, err + } + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, -1, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) + } ret := make([]byte, int64(dim)*r.numRows) for i := 0; i < int(r.numRows); i++ { - v, ok := dumper.Next() - 
if !ok { - return nil, -1, fmt.Errorf("unmatched row number: row %v, dim %v", r.numRows, dim) - } - parquetArray := v.(parquet.FixedLenByteArray) - copy(ret[i*dim:(i+1)*dim], parquetArray) + copy(ret[i*dim:(i+1)*dim], values[i]) } return ret, dim * 8, nil } @@ -288,20 +311,23 @@ func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) { if r.colType != schemapb.DataType_FloatVector { return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String()) } - dumper, err := r.createDumper() - if err != nil { - return nil, -1, err + reader, ok := r.reader.RowGroup(0).Column(0).(*file.FixedLenByteArrayColumnChunkReader) + if !ok { + return nil, -1, fmt.Errorf("expect type *file.FixedLenByteArrayColumnChunkReader, but got %T", r.reader.RowGroup(0).Column(0)) } dim := r.reader.RowGroup(0).Column(0).Descriptor().TypeLength() / 4 + values := make([]parquet.FixedLenByteArray, r.numRows) + total, valuesRead, err := reader.ReadBatch(r.numRows, values, nil, nil) + if err != nil { + return nil, -1, err + } + if total != r.numRows || int64(valuesRead) != r.numRows { + return nil, -1, fmt.Errorf("expect %d rows, but got total = %d and valuesRead = %d", r.numRows, total, valuesRead) + } ret := make([]float32, int64(dim)*r.numRows) for i := 0; i < int(r.numRows); i++ { - v, ok := dumper.Next() - if !ok { - return nil, -1, fmt.Errorf("unmatched row number: row %v, dim %v", r.numRows, dim) - } - parquetArray := v.(parquet.FixedLenByteArray) - copy(arrow.Float32Traits.CastToBytes(ret[i*dim:(i+1)*dim]), parquetArray) + copy(arrow.Float32Traits.CastToBytes(ret[i*dim:(i+1)*dim]), values[i]) } return ret, dim, nil } @@ -314,128 +340,3 @@ func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) { func (r *PayloadReader) Close() { r.reader.Close() } - -type Dumper struct { - reader file.ColumnChunkReader - batchSize int64 - valueOffset int - valuesBuffered int - - levelOffset int64 - levelsBuffered int64 - defLevels []int16 - repLevels 
[]int16 - - valueBuffer interface{} -} - -func (r *PayloadReader) createDumper() (*Dumper, error) { - var valueBuffer interface{} - switch r.reader.RowGroup(0).Column(0).(type) { - case *file.BooleanColumnChunkReader: - if r.colType != schemapb.DataType_Bool { - return nil, errors.New("incorrect data type") - } - valueBuffer = make([]bool, r.numRows) - case *file.Int32ColumnChunkReader: - if r.colType != schemapb.DataType_Int32 && r.colType != schemapb.DataType_Int16 && r.colType != schemapb.DataType_Int8 { - return nil, fmt.Errorf("incorrect data type, expect int32/int16/int8 but find %v", r.colType.String()) - } - valueBuffer = make([]int32, r.numRows) - case *file.Int64ColumnChunkReader: - if r.colType != schemapb.DataType_Int64 { - return nil, fmt.Errorf("incorrect data type, expect int64 but find %v", r.colType.String()) - } - valueBuffer = make([]int64, r.numRows) - case *file.Float32ColumnChunkReader: - if r.colType != schemapb.DataType_Float { - return nil, fmt.Errorf("incorrect data type, expect float32 but find %v", r.colType.String()) - } - valueBuffer = make([]float32, r.numRows) - case *file.Float64ColumnChunkReader: - if r.colType != schemapb.DataType_Double { - return nil, fmt.Errorf("incorrect data type, expect float64 but find %v", r.colType.String()) - } - valueBuffer = make([]float64, r.numRows) - case *file.ByteArrayColumnChunkReader: - if r.colType != schemapb.DataType_String && r.colType != schemapb.DataType_VarChar { - return nil, fmt.Errorf("incorrect data type, expect string/varchar but find %v", r.colType.String()) - } - valueBuffer = make([]parquet.ByteArray, r.numRows) - case *file.FixedLenByteArrayColumnChunkReader: - if r.colType != schemapb.DataType_FloatVector && r.colType != schemapb.DataType_BinaryVector { - return nil, fmt.Errorf("incorrect data type, expect floavector/binaryvector but find %v", r.colType.String()) - } - valueBuffer = make([]parquet.FixedLenByteArray, r.numRows) - } - - return &Dumper{ - reader: 
r.reader.RowGroup(0).Column(0), - batchSize: r.numRows, - defLevels: make([]int16, r.numRows), - repLevels: make([]int16, r.numRows), - valueBuffer: valueBuffer, - }, nil -} - -func (dump *Dumper) readNextBatch() { - switch reader := dump.reader.(type) { - case *file.BooleanColumnChunkReader: - values := dump.valueBuffer.([]bool) - dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels) - case *file.Int32ColumnChunkReader: - values := dump.valueBuffer.([]int32) - dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels) - case *file.Int64ColumnChunkReader: - values := dump.valueBuffer.([]int64) - dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels) - case *file.Float32ColumnChunkReader: - values := dump.valueBuffer.([]float32) - dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels) - case *file.Float64ColumnChunkReader: - values := dump.valueBuffer.([]float64) - dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels) - case *file.Int96ColumnChunkReader: - values := dump.valueBuffer.([]parquet.Int96) - dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels) - case *file.ByteArrayColumnChunkReader: - values := dump.valueBuffer.([]parquet.ByteArray) - dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels) - case *file.FixedLenByteArrayColumnChunkReader: - values := dump.valueBuffer.([]parquet.FixedLenByteArray) - dump.levelsBuffered, dump.valuesBuffered, _ = reader.ReadBatch(dump.batchSize, values, dump.defLevels, dump.repLevels) - } - - dump.valueOffset = 0 - dump.levelOffset = 0 -} - -func (dump *Dumper) hasNext() bool { - return 
dump.levelOffset < dump.levelsBuffered || dump.reader.HasNext() -} - -func (dump *Dumper) Next() (interface{}, bool) { - if dump.levelOffset == dump.levelsBuffered { - if !dump.hasNext() { - return nil, false - } - dump.readNextBatch() - if dump.levelsBuffered == 0 { - return nil, false - } - } - - defLevel := dump.defLevels[int(dump.levelOffset)] - // repLevel := dump.repLevels[int(dump.levelOffset)] - dump.levelOffset++ - - if defLevel < dump.reader.Descriptor().MaxDefinitionLevel() { - return nil, true - } - - vb := reflect.ValueOf(dump.valueBuffer) - v := vb.Index(dump.valueOffset).Interface() - dump.valueOffset++ - - return v, true -} diff --git a/internal/storage/payload_test.go b/internal/storage/payload_test.go index 6e5f78c2b1..89615b82bd 100644 --- a/internal/storage/payload_test.go +++ b/internal/storage/payload_test.go @@ -20,12 +20,13 @@ import ( "fmt" "testing" - "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus/internal/proto/schemapb" ) -func TestPayload_ReaderandWriter(t *testing.T) { +func TestPayload_ReaderAndWriter(t *testing.T) { t.Run("TestBool", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) @@ -655,6 +656,27 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, err = r.GetBoolFromPayload() assert.NotNil(t, err) }) + t.Run("TestGetBoolError2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_Bool) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddBoolToPayload([]bool{true, false, true}) + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Bool, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, err = r.GetBoolFromPayload() + assert.NotNil(t, err) + }) t.Run("TestGetInt8Error", func(t *testing.T) { w, err := 
NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -679,6 +701,27 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, err = r.GetInt8FromPayload() assert.NotNil(t, err) }) + t.Run("TestGetInt8Error2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_Int8) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddInt8ToPayload([]int8{1, 2, 3}) + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Int8, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, err = r.GetInt8FromPayload() + assert.NotNil(t, err) + }) t.Run("TestGetInt16Error", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -703,6 +746,27 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, err = r.GetInt16FromPayload() assert.NotNil(t, err) }) + t.Run("TestGetInt16Error2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_Int16) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddInt16ToPayload([]int16{1, 2, 3}) + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Int16, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, err = r.GetInt16FromPayload() + assert.NotNil(t, err) + }) t.Run("TestGetInt32Error", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -727,6 +791,27 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, err = r.GetInt32FromPayload() assert.NotNil(t, err) }) + t.Run("TestGetInt32Error2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_Int32) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddInt32ToPayload([]int32{1, 2, 3}) + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err := 
w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Int32, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, err = r.GetInt32FromPayload() + assert.NotNil(t, err) + }) t.Run("TestGetInt64Error", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -751,6 +836,27 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, err = r.GetInt64FromPayload() assert.NotNil(t, err) }) + t.Run("TestGetInt64Error2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_Int64) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddInt64ToPayload([]int64{1, 2, 3}) + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Int64, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, err = r.GetInt64FromPayload() + assert.NotNil(t, err) + }) t.Run("TestGetFloatError", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -775,6 +881,27 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, err = r.GetFloatFromPayload() assert.NotNil(t, err) }) + t.Run("TestGetFloatError2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_Float) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddFloatToPayload([]float32{1, 2, 3}) + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Float, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, err = r.GetFloatFromPayload() + assert.NotNil(t, err) + }) t.Run("TestGetDoubleError", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -799,6 +926,27 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, err = r.GetDoubleFromPayload() assert.NotNil(t, err) }) + 
t.Run("TestGetDoubleError2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_Double) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddDoubleToPayload([]float64{1, 2, 3}) + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Double, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, err = r.GetDoubleFromPayload() + assert.NotNil(t, err) + }) t.Run("TestGetStringError", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -823,6 +971,31 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, err = r.GetStringFromPayload() assert.NotNil(t, err) }) + t.Run("TestGetStringError2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_String) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddOneStringToPayload("hello0") + assert.Nil(t, err) + err = w.AddOneStringToPayload("hello1") + assert.Nil(t, err) + err = w.AddOneStringToPayload("hello2") + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_String, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, err = r.GetStringFromPayload() + assert.NotNil(t, err) + }) t.Run("TestGetBinaryVectorError", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -847,6 +1020,27 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, _, err = r.GetBinaryVectorFromPayload() assert.NotNil(t, err) }) + t.Run("TestGetBinaryVectorError2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8) + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err 
:= w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_BinaryVector, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, _, err = r.GetBinaryVectorFromPayload() + assert.NotNil(t, err) + }) t.Run("TestGetFloatVectorError", func(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_Bool) require.Nil(t, err) @@ -871,6 +1065,27 @@ func TestPayload_ReaderandWriter(t *testing.T) { _, _, err = r.GetFloatVectorFromPayload() assert.NotNil(t, err) }) + t.Run("TestGetFloatVectorError2", func(t *testing.T) { + w, err := NewPayloadWriter(schemapb.DataType_FloatVector) + require.Nil(t, err) + require.NotNil(t, w) + + err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8) + assert.Nil(t, err) + + err = w.FinishPayloadWriter() + assert.Nil(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + assert.Nil(t, err) + + r, err := NewPayloadReader(schemapb.DataType_FloatVector, buffer) + assert.Nil(t, err) + + r.numRows = 99 + _, _, err = r.GetFloatVectorFromPayload() + assert.NotNil(t, err) + }) t.Run("TestWriteLargeSizeData", func(t *testing.T) { t.Skip("Large data skip for online ut")