mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
enhance: bulkinsert supports parsing sparse vector form parquet struct (#40874)
issue: https://github.com/milvus-io/milvus/issues/40777 pr: https://github.com/milvus-io/milvus/pull/40927 Signed-off-by: yhmo <yihua.mo@zilliz.com>
This commit is contained in:
parent
0ba389e434
commit
712d1644d8
@ -40,8 +40,9 @@ type FieldReader struct {
|
||||
columnIndex int
|
||||
columnReader *pqarrow.ColumnReader
|
||||
|
||||
dim int
|
||||
field *schemapb.FieldSchema
|
||||
dim int
|
||||
field *schemapb.FieldSchema
|
||||
sparseIsString bool
|
||||
}
|
||||
|
||||
func NewFieldReader(ctx context.Context, reader *pqarrow.FileReader, columnIndex int, field *schemapb.FieldSchema) (*FieldReader, error) {
|
||||
@ -58,11 +59,19 @@ func NewFieldReader(ctx context.Context, reader *pqarrow.FileReader, columnIndex
|
||||
}
|
||||
}
|
||||
|
||||
// set a flag here to know whether a sparse vector is stored as JSON-format string or parquet struct
|
||||
// because we don't intend to check it every time the Next() is called
|
||||
sparseIsString := true
|
||||
if field.GetDataType() == schemapb.DataType_SparseFloatVector {
|
||||
_, sparseIsString = IsValidSparseVectorSchema(columnReader.Field().Type)
|
||||
}
|
||||
|
||||
cr := &FieldReader{
|
||||
columnIndex: columnIndex,
|
||||
columnReader: columnReader,
|
||||
dim: int(dim),
|
||||
field: field,
|
||||
columnIndex: columnIndex,
|
||||
columnReader: columnReader,
|
||||
dim: int(dim),
|
||||
field: field,
|
||||
sparseIsString: sparseIsString,
|
||||
}
|
||||
return cr, nil
|
||||
}
|
||||
@ -416,6 +425,74 @@ func ReadNullableIntegerOrFloatData[T constraints.Integer | constraints.Float](p
|
||||
return data, validData, nil
|
||||
}
|
||||
|
||||
// This method returns a []map[string]arrow.Array
|
||||
// map[string]arrow.Array represents a struct
|
||||
// For example 1:
|
||||
//
|
||||
// struct {
|
||||
// name string
|
||||
// age int
|
||||
// }
|
||||
//
|
||||
// The ReadStructData() will return a list like:
|
||||
//
|
||||
// [
|
||||
// {"name": ["a", "b", "c"], "age": [4, 5, 6]},
|
||||
// {"name": ["e", "f"], "age": [7, 8]}
|
||||
// ]
|
||||
//
|
||||
// Value type of "name" is array.String, value type of "age" is array.Int32
|
||||
// The length of the list is equal to the length of chunked.Chunks()
|
||||
//
|
||||
// For sparse vector, the map[string]arrow.Array is like {"indices": array.List, "values": array.List}
|
||||
// For example 2:
|
||||
//
|
||||
// struct {
|
||||
// indices []uint32
|
||||
// values []float32
|
||||
// }
|
||||
//
|
||||
// The ReadStructData() will return a list like:
|
||||
//
|
||||
// [
|
||||
// {"indices": [[1, 2, 3], [4, 5], [6, 7]], "values": [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]},
|
||||
// {"indices": [[8], [9, 10]], "values": [[0.8], [0.9, 1.0]]}
|
||||
// ]
|
||||
//
|
||||
// Value type of "indices" is array.List, element type is array.Uint32
|
||||
// Value type of "values" is array.List, element type is array.Float32
|
||||
// The length of the list is equal to the length of chunked.Chunks()
|
||||
//
|
||||
// Note: now the ReadStructData() is used by SparseVector type and SparseVector is not nullable,
|
||||
// create a new method ReadNullableStructData() if we have nullable struct type in future.
|
||||
func ReadStructData(pcr *FieldReader, count int64) ([]map[string]arrow.Array, error) {
|
||||
chunked, err := pcr.columnReader.NextBatch(count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]map[string]arrow.Array, 0, count)
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
structReader, ok := chunk.(*array.Struct)
|
||||
if structReader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("has null value, but struct doesn't support nullable yet")
|
||||
}
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("struct", chunk.DataType().Name(), pcr.field)
|
||||
}
|
||||
|
||||
structType := structReader.DataType().(*arrow.StructType)
|
||||
st := make(map[string]arrow.Array)
|
||||
for k, field := range structType.Fields() {
|
||||
st[field.Name] = structReader.Field(k)
|
||||
}
|
||||
data = append(data, st)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func ReadStringData(pcr *FieldReader, count int64) (any, error) {
|
||||
chunked, err := pcr.columnReader.NextBatch(count)
|
||||
if err != nil {
|
||||
@ -670,8 +747,165 @@ func parseSparseFloatRowVector(str string) ([]byte, uint32, error) {
|
||||
return rowVec, maxIdx, nil
|
||||
}
|
||||
|
||||
// This method accepts input from ReadStructData()
|
||||
// For sparse vector, the map[string]arrow.Array is like {"indices": array.List, "values": array.List}
|
||||
// Although "indices" and "values" is two-dim list, the array.List provides ListValues() and ValueOffsets()
|
||||
// to return one-dim list. We use the start/end position of ValueOffsets() to get the correct sparse vector
|
||||
// from ListValues().
|
||||
// Note that arrow.Uint32.Value(int i) accepts an int32 value, the max length of indices/values is max value of int32
|
||||
func parseSparseFloatVectorStructs(structs []map[string]arrow.Array) ([][]byte, uint32, error) {
|
||||
byteArr := make([][]byte, 0)
|
||||
maxDim := uint32(0)
|
||||
for _, st := range structs {
|
||||
indices, ok1 := st[sparseVectorIndice]
|
||||
values, ok2 := st[sparseVectorValues]
|
||||
if !ok1 || !ok2 {
|
||||
return nil, 0, merr.WrapErrImportFailed("Invalid parquet struct for SparseFloatVector: 'indices' or 'values' missed")
|
||||
}
|
||||
|
||||
indicesList, ok1 := indices.(*array.List)
|
||||
valuesList, ok2 := values.(*array.List)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, 0, merr.WrapErrImportFailed("Invalid parquet struct for SparseFloatVector: 'indices' or 'values' is not list")
|
||||
}
|
||||
|
||||
// Len() is the number of rows in this row group
|
||||
if indices.Len() != values.Len() {
|
||||
msg := fmt.Sprintf("Invalid parquet struct for SparseFloatVector: number of rows of 'indices' and 'values' mismatched, '%d' vs '%d'", indices.Len(), values.Len())
|
||||
return nil, 0, merr.WrapErrImportFailed(msg)
|
||||
}
|
||||
|
||||
// technically, DataType() of array.List must be arrow.ListType, but we still check here to ensure safety
|
||||
indicesListType, ok1 := indicesList.DataType().(*arrow.ListType)
|
||||
valuesListType, ok2 := valuesList.DataType().(*arrow.ListType)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, 0, merr.WrapErrImportFailed("Invalid parquet struct for SparseFloatVector: incorrect arrow type of 'indices' or 'values'")
|
||||
}
|
||||
|
||||
indexDataType := indicesListType.Elem().ID()
|
||||
valueDataType := valuesListType.Elem().ID()
|
||||
|
||||
// The array.Uint32/array.Int64/array.Float32/array.Float64 are derived from arrow.Array
|
||||
// The ListValues() returns arrow.Array interface, but the arrow.Array doesn't have Value(int) interface
|
||||
// To call array.Uint32.Value(int), we need to explicitly cast the ListValues() to array.Uint32
|
||||
// So, we declare two methods here to avoid type casting in the "for" loop
|
||||
type GetIndex func(position int) uint32
|
||||
type GetValue func(position int) float32
|
||||
|
||||
var getIndexFunc GetIndex
|
||||
switch indexDataType {
|
||||
case arrow.INT32:
|
||||
indicesList := indicesList.ListValues().(*array.Int32)
|
||||
getIndexFunc = func(position int) uint32 {
|
||||
return (uint32)(indicesList.Value(position))
|
||||
}
|
||||
case arrow.UINT32:
|
||||
indicesList := indicesList.ListValues().(*array.Uint32)
|
||||
getIndexFunc = func(position int) uint32 {
|
||||
return indicesList.Value(position)
|
||||
}
|
||||
case arrow.INT64:
|
||||
indicesList := indicesList.ListValues().(*array.Int64)
|
||||
getIndexFunc = func(position int) uint32 {
|
||||
return (uint32)(indicesList.Value(position))
|
||||
}
|
||||
case arrow.UINT64:
|
||||
indicesList := indicesList.ListValues().(*array.Uint64)
|
||||
getIndexFunc = func(position int) uint32 {
|
||||
return (uint32)(indicesList.Value(position))
|
||||
}
|
||||
default:
|
||||
msg := fmt.Sprintf("Invalid parquet struct for SparseFloatVector: index type must be uint32/int32/uint64/int64 but actual type is '%s'", indicesListType.Elem().Name())
|
||||
return nil, 0, merr.WrapErrImportFailed(msg)
|
||||
}
|
||||
|
||||
var getValueFunc GetValue
|
||||
switch valueDataType {
|
||||
case arrow.FLOAT32:
|
||||
valuesList := valuesList.ListValues().(*array.Float32)
|
||||
getValueFunc = func(position int) float32 {
|
||||
return valuesList.Value(position)
|
||||
}
|
||||
case arrow.FLOAT64:
|
||||
valuesList := valuesList.ListValues().(*array.Float64)
|
||||
getValueFunc = func(position int) float32 {
|
||||
return (float32)(valuesList.Value(position))
|
||||
}
|
||||
default:
|
||||
msg := fmt.Sprintf("Invalid parquet struct for SparseFloatVector: value type must be float32 or float64 but actual type is '%s'", valuesListType.Elem().Name())
|
||||
return nil, 0, merr.WrapErrImportFailed(msg)
|
||||
}
|
||||
|
||||
for i := 0; i < indicesList.Len(); i++ {
|
||||
start, end := indicesList.ValueOffsets(i)
|
||||
start2, end2 := valuesList.ValueOffsets(i)
|
||||
rowLen := (int)(end - start)
|
||||
rowLenValues := (int)(end2 - start2)
|
||||
if rowLenValues != rowLen {
|
||||
msg := fmt.Sprintf("Invalid parquet struct for SparseFloatVector: number of elements of 'indices' and 'values' mismatched, '%d' vs '%d'", rowLen, rowLenValues)
|
||||
return nil, 0, merr.WrapErrImportFailed(msg)
|
||||
}
|
||||
|
||||
rowIndices := make([]uint32, rowLen)
|
||||
rowValues := make([]float32, rowLen)
|
||||
for i := start; i < end; i++ {
|
||||
rowIndices[i-start] = getIndexFunc((int)(i))
|
||||
rowValues[i-start] = getValueFunc((int)(i))
|
||||
}
|
||||
|
||||
// ensure the indices is sorted
|
||||
sortedIndices, sortedValues := typeutil.SortSparseFloatRow(rowIndices, rowValues)
|
||||
rowVec := typeutil.CreateSparseFloatRow(sortedIndices, sortedValues)
|
||||
if err := typeutil.ValidateSparseFloatRows(rowVec); err != nil {
|
||||
return byteArr, maxDim, err
|
||||
}
|
||||
|
||||
// set the maxDim as the last value of sortedIndices since it has been sorted
|
||||
if len(sortedIndices) > 0 && sortedIndices[len(sortedIndices)-1] > maxDim {
|
||||
maxDim = sortedIndices[len(sortedIndices)-1]
|
||||
}
|
||||
byteArr = append(byteArr, rowVec) // rowVec could be an empty sparse
|
||||
}
|
||||
}
|
||||
return byteArr, maxDim, nil
|
||||
}
|
||||
|
||||
func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) {
|
||||
data, err := ReadStringData(pcr, count)
|
||||
// read sparse vector from JSON-format string
|
||||
if pcr.sparseIsString {
|
||||
data, err := ReadStringData(pcr, count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
byteArr := make([][]byte, 0, count)
|
||||
maxDim := uint32(0)
|
||||
|
||||
for _, str := range data.([]string) {
|
||||
rowVec, rowMaxIdx, err := parseSparseFloatRowVector(str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
byteArr = append(byteArr, rowVec)
|
||||
if rowMaxIdx > maxDim {
|
||||
maxDim = rowMaxIdx
|
||||
}
|
||||
}
|
||||
|
||||
return &storage.SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: int64(maxDim),
|
||||
Contents: byteArr,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// read sparse vector from parquet struct
|
||||
data, err := ReadStructData(pcr, count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -679,19 +913,9 @@ func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
byteArr := make([][]byte, 0, count)
|
||||
maxDim := uint32(0)
|
||||
|
||||
for _, str := range data.([]string) {
|
||||
rowVec, rowMaxIdx, err := parseSparseFloatRowVector(str)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
byteArr = append(byteArr, rowVec)
|
||||
if rowMaxIdx > maxDim {
|
||||
maxDim = rowMaxIdx
|
||||
}
|
||||
byteArr, maxDim, err := parseSparseFloatVectorStructs(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &storage.SparseFloatVectorFieldData{
|
||||
|
||||
@ -3,6 +3,9 @@ package parquet
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/apache/arrow/go/v12/arrow"
|
||||
"github.com/apache/arrow/go/v12/arrow/array"
|
||||
"github.com/apache/arrow/go/v12/arrow/memory"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
@ -90,3 +93,175 @@ func TestParseSparseFloatRowVector(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSparseFloatVectorStructs(t *testing.T) {
|
||||
mem := memory.NewGoAllocator()
|
||||
|
||||
checkFunc := func(indices arrow.Array, values arrow.Array, expectSucceed bool) ([][]byte, uint32) {
|
||||
st := make(map[string]arrow.Array)
|
||||
if indices != nil {
|
||||
st[sparseVectorIndice] = indices
|
||||
}
|
||||
if values != nil {
|
||||
st[sparseVectorValues] = values
|
||||
}
|
||||
|
||||
structs := make([]map[string]arrow.Array, 0)
|
||||
structs = append(structs, st)
|
||||
|
||||
byteArr, maxDim, err := parseSparseFloatVectorStructs(structs)
|
||||
if expectSucceed {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
return byteArr, maxDim
|
||||
}
|
||||
|
||||
genInt32Arr := func(len int) *array.Int32 {
|
||||
builder := array.NewInt32Builder(mem)
|
||||
data := make([]int32, 0)
|
||||
validData := make([]bool, 0)
|
||||
for i := 0; i < len; i++ {
|
||||
data = append(data, (int32)(i))
|
||||
validData = append(validData, i%2 == 0)
|
||||
}
|
||||
|
||||
builder.AppendValues(data, validData)
|
||||
return builder.NewInt32Array()
|
||||
}
|
||||
|
||||
genFloat32Arr := func(len int) *array.Float32 {
|
||||
builder := array.NewFloat32Builder(mem)
|
||||
data := make([]float32, 0)
|
||||
validData := make([]bool, 0)
|
||||
for i := 0; i < len; i++ {
|
||||
data = append(data, (float32)(i))
|
||||
validData = append(validData, i%2 == 0)
|
||||
}
|
||||
builder.AppendValues(data, validData)
|
||||
return builder.NewFloat32Array()
|
||||
}
|
||||
|
||||
genInt32ArrList := func(arr []uint32) *array.List {
|
||||
builder := array.NewListBuilder(mem, &arrow.Int32Type{})
|
||||
builder.Append(true)
|
||||
for _, v := range arr {
|
||||
builder.ValueBuilder().(*array.Int32Builder).Append((int32)(v))
|
||||
}
|
||||
return builder.NewListArray()
|
||||
}
|
||||
|
||||
genUint32ArrList := func(arr []uint32) *array.List {
|
||||
builder := array.NewListBuilder(mem, &arrow.Uint32Type{})
|
||||
if arr != nil {
|
||||
builder.Append(true)
|
||||
for _, v := range arr {
|
||||
builder.ValueBuilder().(*array.Uint32Builder).Append(v)
|
||||
}
|
||||
}
|
||||
return builder.NewListArray()
|
||||
}
|
||||
|
||||
genInt64ArrList := func(arr []uint32) *array.List {
|
||||
builder := array.NewListBuilder(mem, &arrow.Int64Type{})
|
||||
if arr != nil {
|
||||
builder.Append(true)
|
||||
for _, v := range arr {
|
||||
builder.ValueBuilder().(*array.Int64Builder).Append((int64)(v))
|
||||
}
|
||||
}
|
||||
return builder.NewListArray()
|
||||
}
|
||||
|
||||
genUint64ArrList := func(arr []uint32) *array.List {
|
||||
builder := array.NewListBuilder(mem, &arrow.Uint64Type{})
|
||||
if arr != nil {
|
||||
builder.Append(true)
|
||||
for _, v := range arr {
|
||||
builder.ValueBuilder().(*array.Uint64Builder).Append((uint64)(v))
|
||||
}
|
||||
}
|
||||
return builder.NewListArray()
|
||||
}
|
||||
|
||||
genFloat32ArrList := func(arr []float32) *array.List {
|
||||
builder := array.NewListBuilder(mem, &arrow.Float32Type{})
|
||||
if arr != nil {
|
||||
builder.Append(true)
|
||||
for _, v := range arr {
|
||||
builder.ValueBuilder().(*array.Float32Builder).Append(v)
|
||||
}
|
||||
}
|
||||
return builder.NewListArray()
|
||||
}
|
||||
|
||||
genFloat64ArrList := func(arr []float32) *array.List {
|
||||
builder := array.NewListBuilder(mem, &arrow.Float64Type{})
|
||||
if arr != nil {
|
||||
builder.Append(true)
|
||||
for _, v := range arr {
|
||||
builder.ValueBuilder().(*array.Float64Builder).Append((float64)(v))
|
||||
}
|
||||
}
|
||||
return builder.NewListArray()
|
||||
}
|
||||
|
||||
// idices field missed
|
||||
checkFunc(nil, genFloat32ArrList([]float32{0.1}), false)
|
||||
|
||||
// values field missed
|
||||
checkFunc(genUint32ArrList([]uint32{1, 2}), nil, false)
|
||||
|
||||
// indices is not array.List
|
||||
checkFunc(genInt32Arr(2), genFloat32ArrList([]float32{0.1, 0.2}), false)
|
||||
|
||||
// values is not array.List
|
||||
checkFunc(genUint32ArrList([]uint32{1, 2}), genFloat32Arr(2), false)
|
||||
|
||||
// indices is not list of int32/uint32/int64/uint64 array
|
||||
checkFunc(genFloat32ArrList([]float32{0.1, 0.2, 0.3}), genFloat32ArrList([]float32{0.1, 0.2, 0.3}), false)
|
||||
|
||||
// values is not list of float32/float64 array
|
||||
checkFunc(genUint32ArrList([]uint32{1, 2, 3}), genUint32ArrList([]uint32{1, 2, 3}), false)
|
||||
|
||||
// row number of indices and values are different
|
||||
checkFunc(genUint32ArrList([]uint32{1, 2}), genFloat32ArrList(nil), false)
|
||||
|
||||
// element number of indices and values are different
|
||||
checkFunc(genUint32ArrList([]uint32{1, 2}), genFloat32ArrList([]float32{0.1}), false)
|
||||
|
||||
// duplicated indices
|
||||
checkFunc(genUint32ArrList([]uint32{4, 5, 4}), genFloat32ArrList([]float32{0.11, 0.22, 0.23}), false)
|
||||
|
||||
// check result is correct
|
||||
// can handle empty indices/values
|
||||
byteArr, maxDim := checkFunc(genUint32ArrList([]uint32{}), genFloat32ArrList([]float32{}), true)
|
||||
assert.Equal(t, uint32(0), maxDim)
|
||||
assert.Equal(t, 1, len(byteArr))
|
||||
assert.Equal(t, 0, len(byteArr[0]))
|
||||
|
||||
// note that the input indices is not sorted, the parseSparseFloatVectorStructs
|
||||
// returns correct maxDim and byteArr
|
||||
indices := []uint32{25, 78, 56}
|
||||
values := []float32{0.11, 0.22, 0.23}
|
||||
sortedIndices, sortedValues := typeutil.SortSparseFloatRow(indices, values)
|
||||
rowBytes := typeutil.CreateSparseFloatRow(sortedIndices, sortedValues)
|
||||
|
||||
isValidFunc := func(indices arrow.Array, values arrow.Array) {
|
||||
byteArr, maxDim := checkFunc(indices, values, true)
|
||||
assert.Equal(t, uint32(78), maxDim)
|
||||
assert.Equal(t, 1, len(byteArr))
|
||||
assert.Equal(t, rowBytes, byteArr[0])
|
||||
}
|
||||
|
||||
// ensure all supported types are correct
|
||||
isValidFunc(genUint32ArrList(indices), genFloat32ArrList(values))
|
||||
isValidFunc(genUint32ArrList(indices), genFloat64ArrList(values))
|
||||
isValidFunc(genInt32ArrList(indices), genFloat32ArrList(values))
|
||||
isValidFunc(genInt32ArrList(indices), genFloat64ArrList(values))
|
||||
isValidFunc(genUint64ArrList(indices), genFloat32ArrList(values))
|
||||
isValidFunc(genUint64ArrList(indices), genFloat64ArrList(values))
|
||||
isValidFunc(genInt64ArrList(indices), genFloat32ArrList(values))
|
||||
isValidFunc(genInt64ArrList(indices), genFloat64ArrList(values))
|
||||
}
|
||||
|
||||
@ -24,7 +24,9 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/apache/arrow/go/v12/arrow"
|
||||
"github.com/apache/arrow/go/v12/arrow/array"
|
||||
"github.com/apache/arrow/go/v12/arrow/memory"
|
||||
"github.com/apache/arrow/go/v12/parquet"
|
||||
"github.com/apache/arrow/go/v12/parquet/pqarrow"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -357,6 +359,107 @@ func (s *ReaderSuite) runWithDefaultValue(dataType schemapb.DataType, elemType s
|
||||
checkFn(res, 0, s.numRows)
|
||||
}
|
||||
|
||||
func (s *ReaderSuite) runWithSparseVector(indicesType arrow.DataType, valuesType arrow.DataType) {
|
||||
// milvus schema
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "pk",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
AutoID: false,
|
||||
},
|
||||
{
|
||||
FieldID: 101,
|
||||
Name: "sparse",
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// arrow schema
|
||||
arrowFields := make([]arrow.Field, 0)
|
||||
arrowFields = append(arrowFields, arrow.Field{
|
||||
Name: "pk",
|
||||
Type: &arrow.Int64Type{},
|
||||
Nullable: false,
|
||||
Metadata: arrow.Metadata{},
|
||||
})
|
||||
|
||||
sparseFields := []arrow.Field{
|
||||
{Name: sparseVectorIndice, Type: arrow.ListOf(indicesType)},
|
||||
{Name: sparseVectorValues, Type: arrow.ListOf(valuesType)},
|
||||
}
|
||||
arrowFields = append(arrowFields, arrow.Field{
|
||||
Name: "sparse",
|
||||
Type: arrow.StructOf(sparseFields...),
|
||||
Nullable: false,
|
||||
Metadata: arrow.Metadata{},
|
||||
})
|
||||
pqSchema := arrow.NewSchema(arrowFields, nil)
|
||||
|
||||
// parquet writer
|
||||
filePath := fmt.Sprintf("/tmp/test_%d_sparse_reader.parquet", rand.Int())
|
||||
defer os.Remove(filePath)
|
||||
|
||||
// prepare milvus data
|
||||
insertData, err := testutil.CreateInsertData(schema, s.numRows, 0)
|
||||
assert.NoError(s.T(), err)
|
||||
|
||||
// use a function here because the fw.Close() must be called before we read the parquet file
|
||||
func() {
|
||||
wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666)
|
||||
assert.NoError(s.T(), err)
|
||||
fw, err := pqarrow.NewFileWriter(pqSchema, wf, parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(int64(s.numRows))), pqarrow.DefaultWriterProps())
|
||||
assert.NoError(s.T(), err)
|
||||
defer fw.Close()
|
||||
|
||||
// prepare parquet data
|
||||
arrowColumns := make([]arrow.Array, 0, len(schema.Fields))
|
||||
mem := memory.NewGoAllocator()
|
||||
builder := array.NewInt64Builder(mem)
|
||||
int64Data := insertData.Data[schema.Fields[0].FieldID].(*storage.Int64FieldData).Data
|
||||
validData := insertData.Data[schema.Fields[0].FieldID].(*storage.Int64FieldData).ValidData
|
||||
builder.AppendValues(int64Data, validData)
|
||||
arrowColumns = append(arrowColumns, builder.NewInt64Array())
|
||||
|
||||
contents := insertData.Data[schema.Fields[1].FieldID].(*storage.SparseFloatVectorFieldData).GetContents()
|
||||
arr, err := testutil.BuildSparseVectorData(mem, contents, arrowFields[1].Type)
|
||||
assert.NoError(s.T(), err)
|
||||
arrowColumns = append(arrowColumns, arr)
|
||||
|
||||
// write parquet
|
||||
recordBatch := array.NewRecord(pqSchema, arrowColumns, int64(s.numRows))
|
||||
err = fw.Write(recordBatch)
|
||||
assert.NoError(s.T(), err)
|
||||
}()
|
||||
|
||||
// read parquet
|
||||
ctx := context.Background()
|
||||
f := storage.NewChunkManagerFactory("local", storage.RootPath("/tmp/milvus_test/test_parquet_reader/"))
|
||||
cm, err := f.NewPersistentStorageChunkManager(ctx)
|
||||
assert.NoError(s.T(), err)
|
||||
reader, err := NewReader(ctx, cm, schema, filePath, 64*1024*1024)
|
||||
assert.NoError(s.T(), err)
|
||||
|
||||
checkFn := func(actualInsertData *storage.InsertData, offsetBegin, expectRows int) {
|
||||
expectInsertData := insertData
|
||||
for fieldID, data := range actualInsertData.Data {
|
||||
s.Equal(expectRows, data.RowNum())
|
||||
for i := 0; i < expectRows; i++ {
|
||||
expect := expectInsertData.Data[fieldID].GetRow(i + offsetBegin)
|
||||
actual := data.GetRow(i)
|
||||
s.Equal(expect, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res, err := reader.Read()
|
||||
assert.NoError(s.T(), err)
|
||||
checkFn(res, 0, s.numRows)
|
||||
}
|
||||
|
||||
func (s *ReaderSuite) TestReadScalarFieldsWithDefaultValue() {
|
||||
s.runWithDefaultValue(schemapb.DataType_Bool, schemapb.DataType_None, true, 0)
|
||||
s.runWithDefaultValue(schemapb.DataType_Int8, schemapb.DataType_None, true, 0)
|
||||
@ -493,10 +596,22 @@ func (s *ReaderSuite) TestVector() {
|
||||
s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0)
|
||||
s.vecDataType = schemapb.DataType_BFloat16Vector
|
||||
s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0)
|
||||
// this test case only test parsing sparse vector from JSON-format string
|
||||
s.vecDataType = schemapb.DataType_SparseFloatVector
|
||||
s.run(schemapb.DataType_Int32, schemapb.DataType_None, false, 0)
|
||||
}
|
||||
|
||||
func (s *ReaderSuite) TestSparseVector() {
|
||||
s.runWithSparseVector(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float32)
|
||||
s.runWithSparseVector(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float64)
|
||||
s.runWithSparseVector(arrow.PrimitiveTypes.Uint32, arrow.PrimitiveTypes.Float32)
|
||||
s.runWithSparseVector(arrow.PrimitiveTypes.Uint32, arrow.PrimitiveTypes.Float64)
|
||||
s.runWithSparseVector(arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Float32)
|
||||
s.runWithSparseVector(arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Float64)
|
||||
s.runWithSparseVector(arrow.PrimitiveTypes.Uint64, arrow.PrimitiveTypes.Float32)
|
||||
s.runWithSparseVector(arrow.PrimitiveTypes.Uint64, arrow.PrimitiveTypes.Float64)
|
||||
}
|
||||
|
||||
func TestUtil(t *testing.T) {
|
||||
suite.Run(t, new(ReaderSuite))
|
||||
}
|
||||
|
||||
@ -29,6 +29,11 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
const (
|
||||
sparseVectorIndice = "indices"
|
||||
sparseVectorValues = "values"
|
||||
)
|
||||
|
||||
func WrapTypeErr(expect string, actual string, field *schemapb.FieldSchema) error {
|
||||
return merr.WrapErrImportFailed(
|
||||
fmt.Sprintf("expect '%s' type for field '%s', but got '%s' type",
|
||||
@ -146,11 +151,49 @@ func isArrowDataTypeConvertible(src arrow.DataType, dst arrow.DataType, field *s
|
||||
case arrow.NULL:
|
||||
// if nullable==true or has set default_value, can use null type
|
||||
return field.GetNullable() || field.GetDefaultValue() != nil
|
||||
case arrow.STRUCT:
|
||||
if field.GetDataType() == schemapb.DataType_SparseFloatVector {
|
||||
valid, _ := IsValidSparseVectorSchema(src)
|
||||
return valid
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// This method returns two booleans
|
||||
// The first boolean value means the arrowType is a valid sparse vector schema
|
||||
// The second boolean value: true means the sparse vector is stored as JSON-format string,
|
||||
// false means the sparse vector is stored as parquet struct
|
||||
func IsValidSparseVectorSchema(arrowType arrow.DataType) (bool, bool) {
|
||||
arrowID := arrowType.ID()
|
||||
if arrowID == arrow.STRUCT {
|
||||
arrType := arrowType.(*arrow.StructType)
|
||||
indicesType, ok1 := arrType.FieldByName(sparseVectorIndice)
|
||||
valuesType, ok2 := arrType.FieldByName(sparseVectorValues)
|
||||
if !ok1 || !ok2 {
|
||||
return false, false
|
||||
}
|
||||
|
||||
// indices can be uint32 list or int64 list
|
||||
// values can be float32 list or float64 list
|
||||
isValidType := func(finger string, expectedType arrow.DataType) bool {
|
||||
return finger == arrow.ListOf(expectedType).Fingerprint()
|
||||
}
|
||||
indicesFinger := indicesType.Type.Fingerprint()
|
||||
valuesFinger := valuesType.Type.Fingerprint()
|
||||
indicesTypeIsOK := (isValidType(indicesFinger, arrow.PrimitiveTypes.Int32) ||
|
||||
isValidType(indicesFinger, arrow.PrimitiveTypes.Uint32) ||
|
||||
isValidType(indicesFinger, arrow.PrimitiveTypes.Int64) ||
|
||||
isValidType(indicesFinger, arrow.PrimitiveTypes.Uint64))
|
||||
valuesTypeIsOK := (isValidType(valuesFinger, arrow.PrimitiveTypes.Float32) ||
|
||||
isValidType(valuesFinger, arrow.PrimitiveTypes.Float64))
|
||||
return indicesTypeIsOK && valuesTypeIsOK, false
|
||||
}
|
||||
return arrowID == arrow.STRING, true
|
||||
}
|
||||
|
||||
func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.DataType, error) {
|
||||
dataType := field.GetDataType()
|
||||
if isArray {
|
||||
|
||||
@ -2,6 +2,7 @@ package testutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
|
||||
@ -277,6 +278,118 @@ func CreateFieldWithDefaultValue(dataType schemapb.DataType, id int64, nullable
|
||||
return field, nil
|
||||
}
|
||||
|
||||
func BuildSparseVectorData(mem *memory.GoAllocator, contents [][]byte, arrowType arrow.DataType) (arrow.Array, error) {
|
||||
if arrowType == nil || arrowType.ID() == arrow.STRING {
|
||||
// build sparse vector as JSON-format string
|
||||
builder := array.NewStringBuilder(mem)
|
||||
rows := len(contents)
|
||||
jsonBytesData := make([][]byte, 0)
|
||||
for i := 0; i < rows; i++ {
|
||||
rowVecData := contents[i]
|
||||
mapData := typeutil.SparseFloatBytesToMap(rowVecData)
|
||||
// convert to JSON format
|
||||
jsonBytes, err := json.Marshal(mapData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonBytesData = append(jsonBytesData, jsonBytes)
|
||||
}
|
||||
builder.AppendValues(lo.Map(jsonBytesData, func(bs []byte, _ int) string {
|
||||
return string(bs)
|
||||
}), nil)
|
||||
return builder.NewStringArray(), nil
|
||||
} else if arrowType.ID() == arrow.STRUCT {
|
||||
// build sparse vector as parquet struct
|
||||
stType, _ := arrowType.(*arrow.StructType)
|
||||
indicesField, ok1 := stType.FieldByName("indices")
|
||||
valuesField, ok2 := stType.FieldByName("values")
|
||||
if !ok1 || !ok2 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("Indices type or values type is missed for sparse vector")
|
||||
}
|
||||
|
||||
indicesList, ok1 := indicesField.Type.(*arrow.ListType)
|
||||
valuesList, ok2 := valuesField.Type.(*arrow.ListType)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("Indices type and values type of sparse vector should be list")
|
||||
}
|
||||
indexType := indicesList.Elem().ID()
|
||||
valueType := valuesList.Elem().ID()
|
||||
|
||||
fields := []arrow.Field{indicesField, valuesField}
|
||||
structType := arrow.StructOf(fields...)
|
||||
builder := array.NewStructBuilder(mem, structType)
|
||||
indicesBuilder := builder.FieldBuilder(0).(*array.ListBuilder)
|
||||
valuesBuilder := builder.FieldBuilder(1).(*array.ListBuilder)
|
||||
|
||||
// The array.Uint32Builder/array.Int64Builder/array.Float32Builder/array.Float64Builder
|
||||
// are derived from array.Builder, but array.Builder doesn't have Append() interface
|
||||
// To call array.Uint32Builder.Value(uint32), we need to explicitly cast the indicesBuilder.ValueBuilder()
|
||||
// to array.Uint32Builder
|
||||
// So, we declare two methods here to avoid type casting in the "for" loop
|
||||
type AppendIndex func(index uint32)
|
||||
type AppendValue func(value float32)
|
||||
|
||||
var appendIndexFunc AppendIndex
|
||||
switch indexType {
|
||||
case arrow.INT32:
|
||||
indicesArrayBuilder := indicesBuilder.ValueBuilder().(*array.Int32Builder)
|
||||
appendIndexFunc = func(index uint32) {
|
||||
indicesArrayBuilder.Append((int32)(index))
|
||||
}
|
||||
case arrow.UINT32:
|
||||
indicesArrayBuilder := indicesBuilder.ValueBuilder().(*array.Uint32Builder)
|
||||
appendIndexFunc = func(index uint32) {
|
||||
indicesArrayBuilder.Append(index)
|
||||
}
|
||||
case arrow.INT64:
|
||||
indicesArrayBuilder := indicesBuilder.ValueBuilder().(*array.Int64Builder)
|
||||
appendIndexFunc = func(index uint32) {
|
||||
indicesArrayBuilder.Append((int64)(index))
|
||||
}
|
||||
case arrow.UINT64:
|
||||
indicesArrayBuilder := indicesBuilder.ValueBuilder().(*array.Uint64Builder)
|
||||
appendIndexFunc = func(index uint32) {
|
||||
indicesArrayBuilder.Append((uint64)(index))
|
||||
}
|
||||
default:
|
||||
msg := fmt.Sprintf("Not able to write this type (%s) for sparse vector index", indexType.String())
|
||||
return nil, merr.WrapErrImportFailed(msg)
|
||||
}
|
||||
|
||||
var appendValueFunc AppendValue
|
||||
switch valueType {
|
||||
case arrow.FLOAT32:
|
||||
valuesArrayBuilder := valuesBuilder.ValueBuilder().(*array.Float32Builder)
|
||||
appendValueFunc = func(value float32) {
|
||||
valuesArrayBuilder.Append(value)
|
||||
}
|
||||
case arrow.FLOAT64:
|
||||
valuesArrayBuilder := valuesBuilder.ValueBuilder().(*array.Float64Builder)
|
||||
appendValueFunc = func(value float32) {
|
||||
valuesArrayBuilder.Append((float64)(value))
|
||||
}
|
||||
default:
|
||||
msg := fmt.Sprintf("Not able to write this type (%s) for sparse vector index", indexType.String())
|
||||
return nil, merr.WrapErrImportFailed(msg)
|
||||
}
|
||||
|
||||
for i := 0; i < len(contents); i++ {
|
||||
builder.Append(true)
|
||||
indicesBuilder.Append(true)
|
||||
valuesBuilder.Append(true)
|
||||
rowVecData := contents[i]
|
||||
elemCount := len(rowVecData) / 8
|
||||
for j := 0; j < elemCount; j++ {
|
||||
appendIndexFunc(common.Endian.Uint32(rowVecData[j*8:]))
|
||||
appendValueFunc(math.Float32frombits(common.Endian.Uint32(rowVecData[j*8+4:])))
|
||||
}
|
||||
}
|
||||
return builder.NewStructArray(), nil
|
||||
}
|
||||
|
||||
return nil, merr.WrapErrParameterInvalidMsg("Invalid arrow data type for sparse vector")
|
||||
}
|
||||
|
||||
func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.InsertData, useNullType bool) ([]arrow.Array, error) {
|
||||
mem := memory.NewGoAllocator()
|
||||
columns := make([]arrow.Array, 0, len(schema.Fields))
|
||||
@ -401,24 +514,12 @@ func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.Inser
|
||||
builder.AppendValues(offsets, valid)
|
||||
columns = append(columns, builder.NewListArray())
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
builder := array.NewStringBuilder(mem)
|
||||
contents := insertData.Data[fieldID].(*storage.SparseFloatVectorFieldData).GetContents()
|
||||
rows := len(contents)
|
||||
jsonBytesData := make([][]byte, 0)
|
||||
for i := 0; i < rows; i++ {
|
||||
rowVecData := contents[i]
|
||||
mapData := typeutil.SparseFloatBytesToMap(rowVecData)
|
||||
// convert to JSON format
|
||||
jsonBytes, err := json.Marshal(mapData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonBytesData = append(jsonBytesData, jsonBytes)
|
||||
arr, err := BuildSparseVectorData(mem, contents, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
builder.AppendValues(lo.Map(jsonBytesData, func(bs []byte, _ int) string {
|
||||
return string(bs)
|
||||
}), nil)
|
||||
columns = append(columns, builder.NewStringArray())
|
||||
columns = append(columns, arr)
|
||||
case schemapb.DataType_JSON:
|
||||
builder := array.NewStringBuilder(mem)
|
||||
jsonData := insertData.Data[fieldID].(*storage.JSONFieldData).Data
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user