From 712d1644d8b55c3e2d00740b04dd2e896165db03 Mon Sep 17 00:00:00 2001
From: groot
Date: Mon, 31 Mar 2025 14:20:31 +0800
Subject: [PATCH] enhance: bulkinsert supports parsing sparse vector from
 parquet struct (#40874)

issue: https://github.com/milvus-io/milvus/issues/40777
pr: https://github.com/milvus-io/milvus/pull/40927

Signed-off-by: yhmo
---
 .../util/importutilv2/parquet/field_reader.go | 264 ++++++++++++++++--
 .../importutilv2/parquet/file_reader_test.go  | 175 ++++++++++++
 .../util/importutilv2/parquet/reader_test.go  | 115 ++++++++
 internal/util/importutilv2/parquet/util.go    |  43 +++
 internal/util/testutil/test_util.go           | 133 +++++++--
 5 files changed, 694 insertions(+), 36 deletions(-)

diff --git a/internal/util/importutilv2/parquet/field_reader.go b/internal/util/importutilv2/parquet/field_reader.go
index c3cecd5dd1..eeda07e31d 100644
--- a/internal/util/importutilv2/parquet/field_reader.go
+++ b/internal/util/importutilv2/parquet/field_reader.go
@@ -40,8 +40,9 @@ type FieldReader struct {
 	columnIndex  int
 	columnReader *pqarrow.ColumnReader
 
-	dim   int
-	field *schemapb.FieldSchema
+	dim            int
+	field          *schemapb.FieldSchema
+	sparseIsString bool
 }
 
 func NewFieldReader(ctx context.Context, reader *pqarrow.FileReader, columnIndex int, field *schemapb.FieldSchema) (*FieldReader, error) {
@@ -58,11 +59,19 @@ func NewFieldReader(ctx context.Context, reader *pqarrow.FileReader, columnIndex
 		}
 	}
 
+	// set a flag here to record whether the sparse vector is stored as a JSON-format string
+	// or as a parquet struct, so that we don't re-check it every time Next() is called
+	sparseIsString := true
+	if field.GetDataType() == schemapb.DataType_SparseFloatVector {
+		_, sparseIsString = IsValidSparseVectorSchema(columnReader.Field().Type)
+	}
+
 	cr := &FieldReader{
-		columnIndex:  columnIndex,
-		columnReader: columnReader,
-		dim:          int(dim),
-		field:        field,
+		columnIndex:    columnIndex,
+		columnReader:   columnReader,
+		dim:            int(dim),
+		field:          field,
+		sparseIsString: sparseIsString,
 	}
 	return cr, nil
 }
@@ -416,6 +425,74 @@ func ReadNullableIntegerOrFloatData[T constraints.Integer | constraints.Float](p
 	return data, validData, nil
 }
 
+// This method returns a []map[string]arrow.Array
+// Each map[string]arrow.Array represents a struct
+// For example 1:
+//
+//	struct {
+//	    name string
+//	    age  int
+//	}
+//
+// ReadStructData() will return a list like:
+//
+//	[
+//	    {"name": ["a", "b", "c"], "age": [4, 5, 6]},
+//	    {"name": ["e", "f"], "age": [7, 8]}
+//	]
+//
+// The value type of "name" is array.String, the value type of "age" is array.Int32
+// The length of the list is equal to the length of chunked.Chunks()
+//
+// For a sparse vector, the map[string]arrow.Array looks like {"indices": array.List, "values": array.List}
+// For example 2:
+//
+//	struct {
+//	    indices []uint32
+//	    values  []float32
+//	}
+//
+// ReadStructData() will return a list like:
+//
+//	[
+//	    {"indices": [[1, 2, 3], [4, 5], [6, 7]], "values": [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]},
+//	    {"indices": [[8], [9, 10]], "values": [[0.8], [0.9, 1.0]]}
+//	]
+//
+// The value type of "indices" is array.List with element type array.Uint32
+// The value type of "values" is array.List with element type array.Float32
+// The length of the list is equal to the length of chunked.Chunks()
+//
+// Note: for now, ReadStructData() is only used by the SparseFloatVector type, which is not nullable;
+// create a new method ReadNullableStructData() if we support a nullable struct type in the future.
+func ReadStructData(pcr *FieldReader, count int64) ([]map[string]arrow.Array, error) {
+	chunked, err := pcr.columnReader.NextBatch(count)
+	if err != nil {
+		return nil, err
+	}
+	data := make([]map[string]arrow.Array, 0, count)
+	for _, chunk := range chunked.Chunks() {
+		// check the type assertion result before dereferencing structReader
+		structReader, ok := chunk.(*array.Struct)
+		if !ok {
+			return nil, WrapTypeErr("struct", chunk.DataType().Name(), pcr.field)
+		}
+		if structReader.NullN() > 0 {
+			return nil, merr.WrapErrParameterInvalidMsg("has null value, but struct doesn't support nullable yet")
+		}
+
+		structType := structReader.DataType().(*arrow.StructType)
+		st := make(map[string]arrow.Array)
+		for k, field := range structType.Fields() {
+			st[field.Name] = structReader.Field(k)
+		}
+		data = append(data, st)
+	}
+	if len(data) == 0 {
+		return nil, nil
+	}
+	return data, nil
+}
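Editor's note: a minimal sketch (not part of the patch) of how the per-chunk maps returned above are meant to be consumed for a sparse-vector column; the "indices"/"values" sub-field names follow the doc comment, and the helper name is hypothetical:

```go
// hypothetical helper; pcr is a *FieldReader over a sparse-vector struct column
func readSparseStructChunks(pcr *FieldReader) error {
	chunks, err := ReadStructData(pcr, 1024)
	if err != nil {
		return err
	}
	for _, st := range chunks {
		// each map entry is the arrow.Array of one struct sub-field in this chunk
		indicesCol := st["indices"].(*array.List)
		valuesCol := st["values"].(*array.List)
		_, _ = indicesCol, valuesCol // handed to parseSparseFloatVectorStructs below
	}
	return nil
}
```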
+
 func ReadStringData(pcr *FieldReader, count int64) (any, error) {
 	chunked, err := pcr.columnReader.NextBatch(count)
 	if err != nil {
@@ -670,8 +747,165 @@ func parseSparseFloatRowVector(str string) ([]byte, uint32, error) {
 	return rowVec, maxIdx, nil
 }
 
+// This method accepts input from ReadStructData()
+// For a sparse vector, the map[string]arrow.Array looks like {"indices": array.List, "values": array.List}
+// Although "indices" and "values" are two-dimensional lists, array.List provides ListValues() and ValueOffsets()
+// to expose a flattened one-dimensional array. We use the start/end positions from ValueOffsets() to slice the
+// correct sparse vector out of ListValues().
+// Note that accessors such as array.Uint32.Value(i) take an int position, so the max length of
+// indices/values is bounded by the max value of int32
+func parseSparseFloatVectorStructs(structs []map[string]arrow.Array) ([][]byte, uint32, error) {
+	byteArr := make([][]byte, 0)
+	maxDim := uint32(0)
+	for _, st := range structs {
+		indices, ok1 := st[sparseVectorIndice]
+		values, ok2 := st[sparseVectorValues]
+		if !ok1 || !ok2 {
+			return nil, 0, merr.WrapErrImportFailed("Invalid parquet struct for SparseFloatVector: 'indices' or 'values' is missing")
+		}
+
+		indicesList, ok1 := indices.(*array.List)
+		valuesList, ok2 := values.(*array.List)
+		if !ok1 || !ok2 {
+			return nil, 0, merr.WrapErrImportFailed("Invalid parquet struct for SparseFloatVector: 'indices' or 'values' is not a list")
+		}
+
+		// Len() is the number of rows in this row group
+		if indices.Len() != values.Len() {
+			msg := fmt.Sprintf("Invalid parquet struct for SparseFloatVector: number of rows of 'indices' and 'values' mismatched, '%d' vs '%d'", indices.Len(), values.Len())
+			return nil, 0, merr.WrapErrImportFailed(msg)
+		}
+
+		// technically, DataType() of an array.List must be arrow.ListType, but we still check here to ensure safety
+		indicesListType, ok1 := indicesList.DataType().(*arrow.ListType)
+		valuesListType, ok2 := valuesList.DataType().(*arrow.ListType)
+		if !ok1 || !ok2 {
+			return nil, 0, merr.WrapErrImportFailed("Invalid parquet struct for SparseFloatVector: incorrect arrow type of 'indices' or 'values'")
+		}
+
+		indexDataType := indicesListType.Elem().ID()
+		valueDataType := valuesListType.Elem().ID()
+
+		// array.Uint32/array.Int64/array.Float32/array.Float64 implement arrow.Array,
+		// but ListValues() returns the arrow.Array interface, which has no Value(int) method.
+		// To call array.Uint32.Value(int), we must explicitly cast ListValues() to *array.Uint32.
+		// So we declare two function types here to avoid type casting inside the "for" loop
+		type GetIndex func(position int) uint32
+		type GetValue func(position int) float32
+
+		var getIndexFunc GetIndex
+		switch indexDataType {
+		case arrow.INT32:
+			indicesList := indicesList.ListValues().(*array.Int32)
+			getIndexFunc = func(position int) uint32 {
+				return (uint32)(indicesList.Value(position))
+			}
+		case arrow.UINT32:
+			indicesList := indicesList.ListValues().(*array.Uint32)
+			getIndexFunc = func(position int) uint32 {
+				return indicesList.Value(position)
+			}
+		case arrow.INT64:
+			indicesList := indicesList.ListValues().(*array.Int64)
+			getIndexFunc = func(position int) uint32 {
+				return (uint32)(indicesList.Value(position))
+			}
+		case arrow.UINT64:
+			indicesList := indicesList.ListValues().(*array.Uint64)
+			getIndexFunc = func(position int) uint32 {
+				return (uint32)(indicesList.Value(position))
+			}
+		default:
+			msg := fmt.Sprintf("Invalid parquet struct for SparseFloatVector: index type must be uint32/int32/uint64/int64 but actual type is '%s'", indicesListType.Elem().Name())
+			return nil, 0, merr.WrapErrImportFailed(msg)
+		}
+
+		var getValueFunc GetValue
+		switch valueDataType {
+		case arrow.FLOAT32:
+			valuesList := valuesList.ListValues().(*array.Float32)
+			getValueFunc = func(position int) float32 {
+				return valuesList.Value(position)
+			}
+		case arrow.FLOAT64:
+			valuesList := valuesList.ListValues().(*array.Float64)
+			getValueFunc = func(position int) float32 {
+				return (float32)(valuesList.Value(position))
+			}
+		default:
+			msg := fmt.Sprintf("Invalid parquet struct for SparseFloatVector: value type must be float32 or float64 but actual type is '%s'", valuesListType.Elem().Name())
+			return nil, 0, merr.WrapErrImportFailed(msg)
+		}
+
+		for i := 0; i < indicesList.Len(); i++ {
+			start, end := indicesList.ValueOffsets(i)
+			start2, end2 := valuesList.ValueOffsets(i)
+			rowLen := (int)(end - start)
+			rowLenValues := (int)(end2 - start2)
+			if rowLenValues != rowLen {
+				msg := fmt.Sprintf("Invalid parquet struct for SparseFloatVector: number of elements of 'indices' and 'values' mismatched, '%d' vs '%d'", rowLen, rowLenValues)
+				return nil, 0, merr.WrapErrImportFailed(msg)
+			}
+
+			rowIndices := make([]uint32, rowLen)
+			rowValues := make([]float32, rowLen)
+			for j := start; j < end; j++ { // use j to avoid shadowing the outer loop variable
+				rowIndices[j-start] = getIndexFunc((int)(j))
+				rowValues[j-start] = getValueFunc((int)(j))
+			}
+
+			// ensure the indices are sorted
+			sortedIndices, sortedValues := typeutil.SortSparseFloatRow(rowIndices, rowValues)
+			rowVec := typeutil.CreateSparseFloatRow(sortedIndices, sortedValues)
+			if err := typeutil.ValidateSparseFloatRows(rowVec); err != nil {
+				return byteArr, maxDim, err
+			}
+
+			// set maxDim to the last value of sortedIndices, since the indices are sorted
+			if len(sortedIndices) > 0 && sortedIndices[len(sortedIndices)-1] > maxDim {
+				maxDim = sortedIndices[len(sortedIndices)-1]
+			}
+			byteArr = append(byteArr, rowVec) // rowVec could be an empty sparse vector
+		}
+	}
+	return byteArr, maxDim, nil
+}
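Editor's note: for readers unfamiliar with the offsets trick used above, a standalone sketch of recovering one row from a flattened arrow list (assumes uint32 elements and the arrow v12 Go APIs already used by this patch; the helper name is hypothetical):

```go
// listRowUint32 returns row i of a list column as a plain slice.
func listRowUint32(lst *array.List, i int) []uint32 {
	start, end := lst.ValueOffsets(i)        // element bounds of row i in the flattened child array
	flat := lst.ListValues().(*array.Uint32) // the one-dimensional backing array
	row := make([]uint32, 0, end-start)
	for j := start; j < end; j++ {
		row = append(row, flat.Value(int(j)))
	}
	return row
}
```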
+
 func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) {
-	data, err := ReadStringData(pcr, count)
+	// read sparse vectors from JSON-format strings
+	if pcr.sparseIsString {
+		data, err := ReadStringData(pcr, count)
+		if err != nil {
+			return nil, err
+		}
+		if data == nil {
+			return nil, nil
+		}
+
+		byteArr := make([][]byte, 0, count)
+		maxDim := uint32(0)
+
+		for _, str := range data.([]string) {
+			rowVec, rowMaxIdx, err := parseSparseFloatRowVector(str)
+			if err != nil {
+				return nil, err
+			}
+
+			byteArr = append(byteArr, rowVec)
+			if rowMaxIdx > maxDim {
+				maxDim = rowMaxIdx
+			}
+		}
+
+		return &storage.SparseFloatVectorFieldData{
+			SparseFloatArray: schemapb.SparseFloatArray{
+				Dim: 
int64(maxDim), + Contents: byteArr, + }, + }, nil + } + + // read sparse vector from parquet struct + data, err := ReadStructData(pcr, count) if err != nil { return nil, err } @@ -679,19 +913,9 @@ func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) { return nil, nil } - byteArr := make([][]byte, 0, count) - maxDim := uint32(0) - - for _, str := range data.([]string) { - rowVec, rowMaxIdx, err := parseSparseFloatRowVector(str) - if err != nil { - return nil, err - } - - byteArr = append(byteArr, rowVec) - if rowMaxIdx > maxDim { - maxDim = rowMaxIdx - } + byteArr, maxDim, err := parseSparseFloatVectorStructs(data) + if err != nil { + return nil, err } return &storage.SparseFloatVectorFieldData{ diff --git a/internal/util/importutilv2/parquet/file_reader_test.go b/internal/util/importutilv2/parquet/file_reader_test.go index 5ea091c68e..4804b8992c 100644 --- a/internal/util/importutilv2/parquet/file_reader_test.go +++ b/internal/util/importutilv2/parquet/file_reader_test.go @@ -3,6 +3,9 @@ package parquet import ( "testing" + "github.com/apache/arrow/go/v12/arrow" + "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/memory" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" @@ -90,3 +93,175 @@ func TestParseSparseFloatRowVector(t *testing.T) { }) } } + +func TestParseSparseFloatVectorStructs(t *testing.T) { + mem := memory.NewGoAllocator() + + checkFunc := func(indices arrow.Array, values arrow.Array, expectSucceed bool) ([][]byte, uint32) { + st := make(map[string]arrow.Array) + if indices != nil { + st[sparseVectorIndice] = indices + } + if values != nil { + st[sparseVectorValues] = values + } + + structs := make([]map[string]arrow.Array, 0) + structs = append(structs, st) + + byteArr, maxDim, err := parseSparseFloatVectorStructs(structs) + if expectSucceed { + assert.NoError(t, err) + } else { + assert.Error(t, err) + } + return byteArr, maxDim + } + + genInt32Arr := func(len int) *array.Int32 { + builder := array.NewInt32Builder(mem) + data := make([]int32, 0) + validData := make([]bool, 0) + for i := 0; i < len; i++ { + data = append(data, (int32)(i)) + validData = append(validData, i%2 == 0) + } + + builder.AppendValues(data, validData) + return builder.NewInt32Array() + } + + genFloat32Arr := func(len int) *array.Float32 { + builder := array.NewFloat32Builder(mem) + data := make([]float32, 0) + validData := make([]bool, 0) + for i := 0; i < len; i++ { + data = append(data, (float32)(i)) + validData = append(validData, i%2 == 0) + } + builder.AppendValues(data, validData) + return builder.NewFloat32Array() + } + + genInt32ArrList := func(arr []uint32) *array.List { + builder := array.NewListBuilder(mem, &arrow.Int32Type{}) + builder.Append(true) + for _, v := range arr { + builder.ValueBuilder().(*array.Int32Builder).Append((int32)(v)) + } + return builder.NewListArray() + } + + genUint32ArrList := func(arr []uint32) *array.List { + builder := array.NewListBuilder(mem, &arrow.Uint32Type{}) + if arr != nil { + builder.Append(true) + for _, v := range arr { + builder.ValueBuilder().(*array.Uint32Builder).Append(v) + } + } + return builder.NewListArray() + } + + genInt64ArrList := func(arr []uint32) *array.List { + builder := array.NewListBuilder(mem, &arrow.Int64Type{}) + if arr != nil { + builder.Append(true) + for _, v := range arr { + builder.ValueBuilder().(*array.Int64Builder).Append((int64)(v)) + } + } + return builder.NewListArray() + } + + genUint64ArrList := func(arr []uint32) 
*array.List {
+		builder := array.NewListBuilder(mem, &arrow.Uint64Type{})
+		if arr != nil {
+			builder.Append(true)
+			for _, v := range arr {
+				builder.ValueBuilder().(*array.Uint64Builder).Append((uint64)(v))
+			}
+		}
+		return builder.NewListArray()
+	}
+
+	genFloat32ArrList := func(arr []float32) *array.List {
+		builder := array.NewListBuilder(mem, &arrow.Float32Type{})
+		if arr != nil {
+			builder.Append(true)
+			for _, v := range arr {
+				builder.ValueBuilder().(*array.Float32Builder).Append(v)
+			}
+		}
+		return builder.NewListArray()
+	}
+
+	genFloat64ArrList := func(arr []float32) *array.List {
+		builder := array.NewListBuilder(mem, &arrow.Float64Type{})
+		if arr != nil {
+			builder.Append(true)
+			for _, v := range arr {
+				builder.ValueBuilder().(*array.Float64Builder).Append((float64)(v))
+			}
+		}
+		return builder.NewListArray()
+	}
+
+	// indices field is missing
+	checkFunc(nil, genFloat32ArrList([]float32{0.1}), false)
+
+	// values field is missing
+	checkFunc(genUint32ArrList([]uint32{1, 2}), nil, false)
+
+	// indices is not an array.List
+	checkFunc(genInt32Arr(2), genFloat32ArrList([]float32{0.1, 0.2}), false)
+
+	// values is not an array.List
+	checkFunc(genUint32ArrList([]uint32{1, 2}), genFloat32Arr(2), false)
+
+	// indices is not a list of int32/uint32/int64/uint64
+	checkFunc(genFloat32ArrList([]float32{0.1, 0.2, 0.3}), genFloat32ArrList([]float32{0.1, 0.2, 0.3}), false)
+
+	// values is not a list of float32/float64
+	checkFunc(genUint32ArrList([]uint32{1, 2, 3}), genUint32ArrList([]uint32{1, 2, 3}), false)
+
+	// row numbers of indices and values are different
+	checkFunc(genUint32ArrList([]uint32{1, 2}), genFloat32ArrList(nil), false)
+
+	// element numbers of indices and values are different
+	checkFunc(genUint32ArrList([]uint32{1, 2}), genFloat32ArrList([]float32{0.1}), false)
+
+	// duplicated indices
+	checkFunc(genUint32ArrList([]uint32{4, 5, 4}), genFloat32ArrList([]float32{0.11, 0.22, 0.23}), false)
+
+	// check the result is correct
+	// it can handle empty indices/values
+	byteArr, maxDim := checkFunc(genUint32ArrList([]uint32{}), genFloat32ArrList([]float32{}), true)
+	assert.Equal(t, uint32(0), maxDim)
+	assert.Equal(t, 1, len(byteArr))
+	assert.Equal(t, 0, len(byteArr[0]))
+
+	// note that the input indices are not sorted; parseSparseFloatVectorStructs
+	// still returns the correct maxDim and byteArr
+	indices := []uint32{25, 78, 56}
+	values := []float32{0.11, 0.22, 0.23}
+	sortedIndices, sortedValues := typeutil.SortSparseFloatRow(indices, values)
+	rowBytes := typeutil.CreateSparseFloatRow(sortedIndices, sortedValues)
+
+	isValidFunc := func(indices arrow.Array, values arrow.Array) {
+		byteArr, maxDim := checkFunc(indices, values, true)
+		assert.Equal(t, uint32(78), maxDim)
+		assert.Equal(t, 1, len(byteArr))
+		assert.Equal(t, rowBytes, byteArr[0])
+	}
+
+	// ensure all supported types are handled correctly
+	isValidFunc(genUint32ArrList(indices), genFloat32ArrList(values))
+	isValidFunc(genUint32ArrList(indices), genFloat64ArrList(values))
+	isValidFunc(genInt32ArrList(indices), genFloat32ArrList(values))
+	isValidFunc(genInt32ArrList(indices), genFloat64ArrList(values))
+	isValidFunc(genUint64ArrList(indices), genFloat32ArrList(values))
+	isValidFunc(genUint64ArrList(indices), genFloat64ArrList(values))
+	isValidFunc(genInt64ArrList(indices), genFloat32ArrList(values))
+	isValidFunc(genInt64ArrList(indices), genFloat64ArrList(values))
+}
diff --git a/internal/util/importutilv2/parquet/reader_test.go b/internal/util/importutilv2/parquet/reader_test.go
index 1f6a8e0923..a45cd099ce 100644
--- a/internal/util/importutilv2/parquet/reader_test.go +++ b/internal/util/importutilv2/parquet/reader_test.go @@ -24,7 +24,9 @@ import ( "os" "testing" + "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/memory" "github.com/apache/arrow/go/v12/parquet" "github.com/apache/arrow/go/v12/parquet/pqarrow" "github.com/stretchr/testify/assert" @@ -357,6 +359,107 @@ func (s *ReaderSuite) runWithDefaultValue(dataType schemapb.DataType, elemType s checkFn(res, 0, s.numRows) } +func (s *ReaderSuite) runWithSparseVector(indicesType arrow.DataType, valuesType arrow.DataType) { + // milvus schema + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + AutoID: false, + }, + { + FieldID: 101, + Name: "sparse", + DataType: schemapb.DataType_SparseFloatVector, + }, + }, + } + + // arrow schema + arrowFields := make([]arrow.Field, 0) + arrowFields = append(arrowFields, arrow.Field{ + Name: "pk", + Type: &arrow.Int64Type{}, + Nullable: false, + Metadata: arrow.Metadata{}, + }) + + sparseFields := []arrow.Field{ + {Name: sparseVectorIndice, Type: arrow.ListOf(indicesType)}, + {Name: sparseVectorValues, Type: arrow.ListOf(valuesType)}, + } + arrowFields = append(arrowFields, arrow.Field{ + Name: "sparse", + Type: arrow.StructOf(sparseFields...), + Nullable: false, + Metadata: arrow.Metadata{}, + }) + pqSchema := arrow.NewSchema(arrowFields, nil) + + // parquet writer + filePath := fmt.Sprintf("/tmp/test_%d_sparse_reader.parquet", rand.Int()) + defer os.Remove(filePath) + + // prepare milvus data + insertData, err := testutil.CreateInsertData(schema, s.numRows, 0) + assert.NoError(s.T(), err) + + // use a function here because the fw.Close() must be called before we read the parquet file + func() { + wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666) + assert.NoError(s.T(), err) + fw, err := pqarrow.NewFileWriter(pqSchema, wf, parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(int64(s.numRows))), pqarrow.DefaultWriterProps()) + assert.NoError(s.T(), err) + defer fw.Close() + + // prepare parquet data + arrowColumns := make([]arrow.Array, 0, len(schema.Fields)) + mem := memory.NewGoAllocator() + builder := array.NewInt64Builder(mem) + int64Data := insertData.Data[schema.Fields[0].FieldID].(*storage.Int64FieldData).Data + validData := insertData.Data[schema.Fields[0].FieldID].(*storage.Int64FieldData).ValidData + builder.AppendValues(int64Data, validData) + arrowColumns = append(arrowColumns, builder.NewInt64Array()) + + contents := insertData.Data[schema.Fields[1].FieldID].(*storage.SparseFloatVectorFieldData).GetContents() + arr, err := testutil.BuildSparseVectorData(mem, contents, arrowFields[1].Type) + assert.NoError(s.T(), err) + arrowColumns = append(arrowColumns, arr) + + // write parquet + recordBatch := array.NewRecord(pqSchema, arrowColumns, int64(s.numRows)) + err = fw.Write(recordBatch) + assert.NoError(s.T(), err) + }() + + // read parquet + ctx := context.Background() + f := storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus_test/test_parquet_reader/")) + cm, err := f.NewPersistentStorageChunkManager(ctx) + assert.NoError(s.T(), err) + reader, err := NewReader(ctx, cm, schema, filePath, 64*1024*1024) + assert.NoError(s.T(), err) + + checkFn := func(actualInsertData *storage.InsertData, offsetBegin, expectRows int) { + expectInsertData := insertData + for fieldID, data := range 
actualInsertData.Data {
+			s.Equal(expectRows, data.RowNum())
+			for i := 0; i < expectRows; i++ {
+				expect := expectInsertData.Data[fieldID].GetRow(i + offsetBegin)
+				actual := data.GetRow(i)
+				s.Equal(expect, actual)
+			}
+		}
+	}
+
+	res, err := reader.Read()
+	assert.NoError(s.T(), err)
+	checkFn(res, 0, s.numRows)
+}
+
 func (s *ReaderSuite) TestReadScalarFieldsWithDefaultValue() {
 	s.runWithDefaultValue(schemapb.DataType_Bool, schemapb.DataType_None, true, 0)
 	s.runWithDefaultValue(schemapb.DataType_Int8, schemapb.DataType_None, true, 0)
@@ -493,10 +596,22 @@ func (s *ReaderSuite) TestVector() {
 	s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0)
 	s.vecDataType = schemapb.DataType_BFloat16Vector
 	s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0)
+	// this test case only tests parsing sparse vectors from JSON-format strings
 	s.vecDataType = schemapb.DataType_SparseFloatVector
 	s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0)
 }
 
+func (s *ReaderSuite) TestSparseVector() {
+	s.runWithSparseVector(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float32)
+	s.runWithSparseVector(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float64)
+	s.runWithSparseVector(arrow.PrimitiveTypes.Uint32, arrow.PrimitiveTypes.Float32)
+	s.runWithSparseVector(arrow.PrimitiveTypes.Uint32, arrow.PrimitiveTypes.Float64)
+	s.runWithSparseVector(arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Float32)
+	s.runWithSparseVector(arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Float64)
+	s.runWithSparseVector(arrow.PrimitiveTypes.Uint64, arrow.PrimitiveTypes.Float32)
+	s.runWithSparseVector(arrow.PrimitiveTypes.Uint64, arrow.PrimitiveTypes.Float64)
+}
+
 func TestUtil(t *testing.T) {
 	suite.Run(t, new(ReaderSuite))
 }
diff --git a/internal/util/importutilv2/parquet/util.go b/internal/util/importutilv2/parquet/util.go
index 4b001014bf..5f0b10ff1e 100644
--- a/internal/util/importutilv2/parquet/util.go
+++ b/internal/util/importutilv2/parquet/util.go
@@ -29,6 +29,11 @@ import (
 	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )
 
+const (
+	sparseVectorIndice = "indices"
+	sparseVectorValues = "values"
+)
+
 func WrapTypeErr(expect string, actual string, field *schemapb.FieldSchema) error {
 	return merr.WrapErrImportFailed(
 		fmt.Sprintf("expect '%s' type for field '%s', but got '%s' type",
@@ -146,11 +151,49 @@ func isArrowDataTypeConvertible(src arrow.DataType, dst arrow.DataType, field *s
 	case arrow.NULL:
 		// if nullable==true or has set default_value, can use null type
 		return field.GetNullable() || field.GetDefaultValue() != nil
+	case arrow.STRUCT:
+		if field.GetDataType() == schemapb.DataType_SparseFloatVector {
+			valid, _ := IsValidSparseVectorSchema(src)
+			return valid
+		}
+		return false
 	default:
 		return false
 	}
 }
 
+// This method returns two booleans
+// The first boolean means whether the arrowType is a valid sparse vector schema
+// The second boolean: true means the sparse vector is stored as a JSON-format string,
+// false means the sparse vector is stored as a parquet struct
+func IsValidSparseVectorSchema(arrowType arrow.DataType) (bool, bool) {
+	arrowID := arrowType.ID()
+	if arrowID == arrow.STRUCT {
+		arrType := arrowType.(*arrow.StructType)
+		indicesType, ok1 := arrType.FieldByName(sparseVectorIndice)
+		valuesType, ok2 := arrType.FieldByName(sparseVectorValues)
+		if !ok1 || !ok2 {
+			return false, false
+		}
+
+		// indices can be an int32/uint32/int64/uint64 list
+		// values can be a float32/float64 list
+		isValidType := func(finger string, expectedType arrow.DataType) bool {
+			return finger == arrow.ListOf(expectedType).Fingerprint()
+		}
+		indicesFinger := indicesType.Type.Fingerprint()
+		valuesFinger := valuesType.Type.Fingerprint()
+		indicesTypeIsOK := (isValidType(indicesFinger, arrow.PrimitiveTypes.Int32) ||
+			isValidType(indicesFinger, arrow.PrimitiveTypes.Uint32) ||
+			isValidType(indicesFinger, arrow.PrimitiveTypes.Int64) ||
+			isValidType(indicesFinger, arrow.PrimitiveTypes.Uint64))
+		valuesTypeIsOK := (isValidType(valuesFinger, arrow.PrimitiveTypes.Float32) ||
+			isValidType(valuesFinger, arrow.PrimitiveTypes.Float64))
+		return indicesTypeIsOK && valuesTypeIsOK, false
+	}
+	return arrowID == arrow.STRING, true
+}
+
 func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.DataType, error) {
 	dataType := field.GetDataType()
 	if isArray {
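Editor's note: to make the two accepted layouts concrete, a small sketch (not part of the patch) of calling IsValidSparseVectorSchema with the arrow v12 type constructors this patch already imports:

```go
func sparseSchemaExamples() {
	// struct<indices: list<uint32>, values: list<float32>> -> accepted, struct path
	sparseType := arrow.StructOf(
		arrow.Field{Name: "indices", Type: arrow.ListOf(arrow.PrimitiveTypes.Uint32)},
		arrow.Field{Name: "values", Type: arrow.ListOf(arrow.PrimitiveTypes.Float32)},
	)
	valid, isString := IsValidSparseVectorSchema(sparseType) // true, false

	// a plain utf8 column -> accepted, JSON-string path
	valid, isString = IsValidSparseVectorSchema(arrow.BinaryTypes.String) // true, true
	_, _ = valid, isString
}
```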
diff --git a/internal/util/testutil/test_util.go b/internal/util/testutil/test_util.go
index 3d8bb637ea..322eb3bf1c 100644
--- a/internal/util/testutil/test_util.go
+++ b/internal/util/testutil/test_util.go
@@ -2,6 +2,7 @@ package testutil
 
 import (
 	"fmt"
+	"math"
 	"math/rand"
 	"strconv"
 
@@ -277,6 +278,118 @@ func CreateFieldWithDefaultValue(dataType schemapb.DataType, id int64, nullable
 	return field, nil
 }
 
+func BuildSparseVectorData(mem *memory.GoAllocator, contents [][]byte, arrowType arrow.DataType) (arrow.Array, error) {
+	if arrowType == nil || arrowType.ID() == arrow.STRING {
+		// build sparse vectors as JSON-format strings
+		builder := array.NewStringBuilder(mem)
+		rows := len(contents)
+		jsonBytesData := make([][]byte, 0)
+		for i := 0; i < rows; i++ {
+			rowVecData := contents[i]
+			mapData := typeutil.SparseFloatBytesToMap(rowVecData)
+			// convert to JSON format
+			jsonBytes, err := json.Marshal(mapData)
+			if err != nil {
+				return nil, err
+			}
+			jsonBytesData = append(jsonBytesData, jsonBytes)
+		}
+		builder.AppendValues(lo.Map(jsonBytesData, func(bs []byte, _ int) string {
+			return string(bs)
+		}), nil)
+		return builder.NewStringArray(), nil
+	} else if arrowType.ID() == arrow.STRUCT {
+		// build sparse vectors as parquet structs
+		stType, _ := arrowType.(*arrow.StructType)
+		indicesField, ok1 := stType.FieldByName("indices")
+		valuesField, ok2 := stType.FieldByName("values")
+		if !ok1 || !ok2 {
+			return nil, merr.WrapErrParameterInvalidMsg("Indices field or values field is missing for sparse vector")
+		}
+
+		indicesList, ok1 := indicesField.Type.(*arrow.ListType)
+		valuesList, ok2 := valuesField.Type.(*arrow.ListType)
+		if !ok1 || !ok2 {
+			return nil, merr.WrapErrParameterInvalidMsg("Indices type and values type of sparse vector should be list")
+		}
+		indexType := indicesList.Elem().ID()
+		valueType := valuesList.Elem().ID()
+
+		fields := []arrow.Field{indicesField, valuesField}
+		structType := arrow.StructOf(fields...)
+		builder := array.NewStructBuilder(mem, structType)
+		indicesBuilder := builder.FieldBuilder(0).(*array.ListBuilder)
+		valuesBuilder := builder.FieldBuilder(1).(*array.ListBuilder)
+
+		// array.Uint32Builder/array.Int64Builder/array.Float32Builder/array.Float64Builder
+		// are derived from array.Builder, but array.Builder doesn't have an Append() interface.
+		// To call array.Uint32Builder.Append(uint32), we must explicitly cast indicesBuilder.ValueBuilder()
+		// to *array.Uint32Builder.
+		// So we declare two function types here to avoid type casting inside the "for" loop
+		type AppendIndex func(index uint32)
+		type AppendValue func(value float32)
+
+		var appendIndexFunc AppendIndex
+		switch indexType {
+		case arrow.INT32:
+			indicesArrayBuilder := indicesBuilder.ValueBuilder().(*array.Int32Builder)
+			appendIndexFunc = func(index uint32) {
+				indicesArrayBuilder.Append((int32)(index))
+			}
+		case arrow.UINT32:
+			indicesArrayBuilder := indicesBuilder.ValueBuilder().(*array.Uint32Builder)
+			appendIndexFunc = func(index uint32) {
+				indicesArrayBuilder.Append(index)
+			}
+		case arrow.INT64:
+			indicesArrayBuilder := indicesBuilder.ValueBuilder().(*array.Int64Builder)
+			appendIndexFunc = func(index uint32) {
+				indicesArrayBuilder.Append((int64)(index))
+			}
+		case arrow.UINT64:
+			indicesArrayBuilder := indicesBuilder.ValueBuilder().(*array.Uint64Builder)
+			appendIndexFunc = func(index uint32) {
+				indicesArrayBuilder.Append((uint64)(index))
+			}
+		default:
+			msg := fmt.Sprintf("Not able to write this type (%s) for sparse vector index", indexType.String())
+			return nil, merr.WrapErrImportFailed(msg)
+		}
+
+		var appendValueFunc AppendValue
+		switch valueType {
+		case arrow.FLOAT32:
+			valuesArrayBuilder := valuesBuilder.ValueBuilder().(*array.Float32Builder)
+			appendValueFunc = func(value float32) {
+				valuesArrayBuilder.Append(value)
+			}
+		case arrow.FLOAT64:
+			valuesArrayBuilder := valuesBuilder.ValueBuilder().(*array.Float64Builder)
+			appendValueFunc = func(value float32) {
+				valuesArrayBuilder.Append((float64)(value))
+			}
+		default:
+			msg := fmt.Sprintf("Not able to write this type (%s) for sparse vector value", valueType.String())
+			return nil, merr.WrapErrImportFailed(msg)
+		}
+
+		for i := 0; i < len(contents); i++ {
+			builder.Append(true)
+			indicesBuilder.Append(true)
+			valuesBuilder.Append(true)
+			rowVecData := contents[i]
+			elemCount := len(rowVecData) / 8
+			for j := 0; j < elemCount; j++ {
+				appendIndexFunc(common.Endian.Uint32(rowVecData[j*8:]))
+				appendValueFunc(math.Float32frombits(common.Endian.Uint32(rowVecData[j*8+4:])))
+			}
+		}
+		return builder.NewStructArray(), nil
+	}
+
+	return nil, merr.WrapErrParameterInvalidMsg("Invalid arrow data type for sparse vector")
+}
+
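Editor's note: a hypothetical usage sketch for BuildSparseVectorData (the row encoding comes from typeutil.CreateSparseFloatRow, which this patch also uses in its tests; the function name is illustrative only):

```go
func buildSparseColumnExample() (arrow.Array, error) {
	mem := memory.NewGoAllocator()
	// one sparse row {2: 0.5, 7: 0.25}, binary-encoded
	row := typeutil.CreateSparseFloatRow([]uint32{2, 7}, []float32{0.5, 0.25})
	structType := arrow.StructOf(
		arrow.Field{Name: "indices", Type: arrow.ListOf(arrow.PrimitiveTypes.Uint32)},
		arrow.Field{Name: "values", Type: arrow.ListOf(arrow.PrimitiveTypes.Float32)},
	)
	// struct layout; passing nil (or a string type) instead builds the JSON-string column
	return BuildSparseVectorData(mem, [][]byte{row}, structType)
}
```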
 func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.InsertData, useNullType bool) ([]arrow.Array, error) {
 	mem := memory.NewGoAllocator()
 	columns := make([]arrow.Array, 0, len(schema.Fields))
@@ -401,24 +514,12 @@ func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.Inser
 			builder.AppendValues(offsets, valid)
 			columns = append(columns, builder.NewListArray())
 		case schemapb.DataType_SparseFloatVector:
-			builder := array.NewStringBuilder(mem)
 			contents := insertData.Data[fieldID].(*storage.SparseFloatVectorFieldData).GetContents()
-			rows := len(contents)
-			jsonBytesData := make([][]byte, 0)
-			for i := 0; i < rows; i++ {
-				rowVecData := contents[i]
-				mapData := typeutil.SparseFloatBytesToMap(rowVecData)
-				// convert to JSON format
-				jsonBytes, err := json.Marshal(mapData)
-				if err != nil {
-					return nil, err
-				}
-				jsonBytesData = append(jsonBytesData, jsonBytes)
+			arr, err := BuildSparseVectorData(mem, contents, nil)
+			if err != nil {
+				return nil, err
 			}
-			builder.AppendValues(lo.Map(jsonBytesData, func(bs []byte, _ int) string {
-				return string(bs)
-			}), nil)
-			columns = append(columns, builder.NewStringArray())
+			columns = append(columns, arr)
 		case schemapb.DataType_JSON:
 			builder := array.NewStringBuilder(mem)
 			jsonData := insertData.Data[fieldID].(*storage.JSONFieldData).Data
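Editor's note: for completeness, the JSON-string layout that the string path consumes. Each row is a JSON object mapping index to value; this format is inferred from the SparseFloatBytesToMap + json.Marshal round-trip above, so treat it as a sketch rather than a spec:

```go
func parseJSONSparseRowExample() error {
	// keys are indices, values are the corresponding weights
	rowVec, maxIdx, err := parseSparseFloatRowVector(`{"25": 0.11, "56": 0.23, "78": 0.22}`)
	if err != nil {
		return err
	}
	_ = rowVec // binary-encoded sparse row
	_ = maxIdx // largest index observed in the row
	return nil
}
```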