mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
fix: fix parquet import bug in STRUCT (#45028)
issue: https://github.com/milvus-io/milvus/issues/45006 ref: https://github.com/milvus-io/milvus/issues/42148 Previsouly, the parquet import is implemented based on that the STRUCT in the parquet files is hanlded in the way that each field in struct is stored in a single column. However, in the user's perspective, the array of STRUCT contains data is something like STRUCT_A: for one row, [struct{field1_1, field2_1, field3_1}, struct{field1_2, field2_2, field3_2}, ...], rather than {[field1_1, field1_2, ...], [field2_1, field2_2, ...], [field3_1, field3_2, field3_3, ...]}. This PR fixes this. --------- Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
This commit is contained in:
parent
7c627260f3
commit
ce2862d325
@ -47,6 +47,9 @@ type FieldReader struct {
|
||||
dim int
|
||||
field *schemapb.FieldSchema
|
||||
sparseIsString bool
|
||||
|
||||
// structReader is non-nil when Struct Array field exists
|
||||
structReader *StructFieldReader
|
||||
}
|
||||
|
||||
func NewFieldReader(ctx context.Context, reader *pqarrow.FileReader, columnIndex int, field *schemapb.FieldSchema) (*FieldReader, error) {
|
||||
@ -81,6 +84,11 @@ func NewFieldReader(ctx context.Context, reader *pqarrow.FileReader, columnIndex
|
||||
}
|
||||
|
||||
func (c *FieldReader) Next(count int64) (any, any, error) {
|
||||
// Check if this FieldReader wraps a StructFieldReader
|
||||
if c.structReader != nil {
|
||||
return c.structReader.Next(count)
|
||||
}
|
||||
|
||||
switch c.field.GetDataType() {
|
||||
case schemapb.DataType_Bool:
|
||||
if c.field.GetNullable() || c.field.GetDefaultValue() != nil {
|
||||
|
||||
@ -681,7 +681,7 @@ func TestParquetReaderWithStructArray(t *testing.T) {
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 201,
|
||||
Name: "int_array",
|
||||
Name: "struct_array[int_array]",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
@ -690,7 +690,7 @@ func TestParquetReaderWithStructArray(t *testing.T) {
|
||||
},
|
||||
{
|
||||
FieldID: 202,
|
||||
Name: "float_array",
|
||||
Name: "struct_array[float_array]",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Float,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
@ -699,7 +699,7 @@ func TestParquetReaderWithStructArray(t *testing.T) {
|
||||
},
|
||||
{
|
||||
FieldID: 203,
|
||||
Name: "vector_array",
|
||||
Name: "struct_array[vector_array]",
|
||||
DataType: schemapb.DataType_ArrayOfVector,
|
||||
ElementType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
@ -716,7 +716,7 @@ func TestParquetReaderWithStructArray(t *testing.T) {
|
||||
filePath := fmt.Sprintf("/tmp/test_struct_array_%d.parquet", rand.Int())
|
||||
defer os.Remove(filePath)
|
||||
|
||||
numRows := 10
|
||||
numRows := 50
|
||||
f, err := os.Create(filePath)
|
||||
assert.NoError(t, err)
|
||||
|
||||
|
||||
292
internal/util/importutilv2/parquet/struct_field_reader.go
Normal file
292
internal/util/importutilv2/parquet/struct_field_reader.go
Normal file
@ -0,0 +1,292 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package parquet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/apache/arrow/go/v17/arrow"
|
||||
"github.com/apache/arrow/go/v17/arrow/array"
|
||||
"github.com/apache/arrow/go/v17/parquet/pqarrow"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
// StructFieldReader reads a specific field from a list<struct> column
|
||||
type StructFieldReader struct {
|
||||
columnReader *pqarrow.ColumnReader
|
||||
field *schemapb.FieldSchema
|
||||
fieldIndex int
|
||||
dim int
|
||||
}
|
||||
|
||||
// NewStructFieldReader creates a reader for extracting a field from nested struct
|
||||
func NewStructFieldReader(ctx context.Context, fileReader *pqarrow.FileReader, columnIndex int,
|
||||
fieldIndex int, field *schemapb.FieldSchema,
|
||||
) (*FieldReader, error) {
|
||||
columnReader, err := fileReader.GetColumn(ctx, columnIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dim := 0
|
||||
if typeutil.IsVectorType(field.GetDataType()) && !typeutil.IsSparseFloatVectorType(field.GetDataType()) {
|
||||
d, err := typeutil.GetDim(field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dim = int(d)
|
||||
} else if field.GetDataType() == schemapb.DataType_ArrayOfVector {
|
||||
// For ArrayOfVector, get the dimension from the element type
|
||||
d, err := typeutil.GetDim(field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dim = int(d)
|
||||
}
|
||||
|
||||
sfr := &StructFieldReader{
|
||||
columnReader: columnReader,
|
||||
field: field,
|
||||
fieldIndex: fieldIndex,
|
||||
dim: dim,
|
||||
}
|
||||
|
||||
fr := &FieldReader{
|
||||
columnIndex: columnIndex,
|
||||
columnReader: columnReader,
|
||||
field: field,
|
||||
dim: dim,
|
||||
structReader: sfr,
|
||||
}
|
||||
|
||||
return fr, nil
|
||||
}
|
||||
|
||||
// Next extracts the specific field from struct array
|
||||
func (r *StructFieldReader) Next(count int64) (any, any, error) {
|
||||
chunked, err := r.columnReader.NextBatch(count)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// If no more data, return nil to signal EOF
|
||||
if chunked.Len() == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
switch r.field.GetDataType() {
|
||||
case schemapb.DataType_Array:
|
||||
return r.readArrayField(chunked)
|
||||
case schemapb.DataType_ArrayOfVector:
|
||||
return r.readArrayOfVectorField(chunked)
|
||||
default:
|
||||
return nil, nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type for struct field: %v", r.field.GetDataType()))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *StructFieldReader) toScalarField(data []interface{}) *schemapb.ScalarField {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch r.field.GetElementType() {
|
||||
case schemapb.DataType_Int32:
|
||||
intData := make([]int32, len(data))
|
||||
for i, v := range data {
|
||||
if val, ok := v.(int32); ok {
|
||||
intData[i] = val
|
||||
}
|
||||
}
|
||||
return &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{Data: intData},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
floatData := make([]float32, len(data))
|
||||
for i, v := range data {
|
||||
if val, ok := v.(float32); ok {
|
||||
floatData[i] = val
|
||||
}
|
||||
}
|
||||
return &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{Data: floatData},
|
||||
},
|
||||
}
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
strData := make([]string, len(data))
|
||||
for i, v := range data {
|
||||
if val, ok := v.(string); ok {
|
||||
strData[i] = val
|
||||
}
|
||||
}
|
||||
return &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{Data: strData},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *StructFieldReader) readArrayField(chunked *arrow.Chunked) (any, any, error) {
|
||||
result := make([]*schemapb.ScalarField, 0)
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
listArray, ok := chunk.(*array.List)
|
||||
if !ok {
|
||||
return nil, nil, merr.WrapErrImportFailed("expected list array for struct field")
|
||||
}
|
||||
|
||||
structArray, ok := listArray.ListValues().(*array.Struct)
|
||||
if !ok {
|
||||
return nil, nil, merr.WrapErrImportFailed("expected struct in list")
|
||||
}
|
||||
|
||||
fieldArray := structArray.Field(r.fieldIndex)
|
||||
offsets := listArray.Offsets()
|
||||
|
||||
for i := 0; i < len(offsets)-1; i++ {
|
||||
startIdx := offsets[i]
|
||||
endIdx := offsets[i+1]
|
||||
|
||||
var combinedData []interface{}
|
||||
for structIdx := startIdx; structIdx < endIdx; structIdx++ {
|
||||
switch field := fieldArray.(type) {
|
||||
case *array.Boolean:
|
||||
if !field.IsNull(int(structIdx)) {
|
||||
combinedData = append(combinedData, field.Value(int(structIdx)))
|
||||
}
|
||||
case *array.Int8:
|
||||
if !field.IsNull(int(structIdx)) {
|
||||
combinedData = append(combinedData, field.Value(int(structIdx)))
|
||||
}
|
||||
case *array.Int16:
|
||||
if !field.IsNull(int(structIdx)) {
|
||||
combinedData = append(combinedData, field.Value(int(structIdx)))
|
||||
}
|
||||
case *array.Int32:
|
||||
if !field.IsNull(int(structIdx)) {
|
||||
combinedData = append(combinedData, field.Value(int(structIdx)))
|
||||
}
|
||||
case *array.Int64:
|
||||
if !field.IsNull(int(structIdx)) {
|
||||
combinedData = append(combinedData, field.Value(int(structIdx)))
|
||||
}
|
||||
case *array.Float32:
|
||||
if !field.IsNull(int(structIdx)) {
|
||||
combinedData = append(combinedData, field.Value(int(structIdx)))
|
||||
}
|
||||
case *array.Float64:
|
||||
if !field.IsNull(int(structIdx)) {
|
||||
combinedData = append(combinedData, field.Value(int(structIdx)))
|
||||
}
|
||||
case *array.String:
|
||||
if !field.IsNull(int(structIdx)) {
|
||||
combinedData = append(combinedData, field.Value(int(structIdx)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a single ScalarField for this row
|
||||
scalarField := r.toScalarField(combinedData)
|
||||
if scalarField != nil {
|
||||
result = append(result, scalarField)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil, nil
|
||||
}
|
||||
|
||||
func (r *StructFieldReader) readArrayOfVectorField(chunked *arrow.Chunked) (any, any, error) {
|
||||
var result []*schemapb.VectorField
|
||||
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
listArray, ok := chunk.(*array.List)
|
||||
if !ok {
|
||||
return nil, nil, merr.WrapErrImportFailed("expected list array for struct field")
|
||||
}
|
||||
|
||||
structArray, ok := listArray.ListValues().(*array.Struct)
|
||||
if !ok {
|
||||
return nil, nil, merr.WrapErrImportFailed("expected struct in list")
|
||||
}
|
||||
|
||||
// Get the field array - it should be a list<primitives> (one vector per struct)
|
||||
fieldArray, ok := structArray.Field(r.fieldIndex).(*array.List)
|
||||
if !ok {
|
||||
return nil, nil, merr.WrapErrImportFailed("expected list array for vector field")
|
||||
}
|
||||
|
||||
offsets := listArray.Offsets()
|
||||
|
||||
// Process each row
|
||||
for i := 0; i < len(offsets)-1; i++ {
|
||||
startIdx := offsets[i]
|
||||
endIdx := offsets[i+1]
|
||||
|
||||
// Extract vectors based on element type
|
||||
switch r.field.GetElementType() {
|
||||
case schemapb.DataType_FloatVector:
|
||||
var allVectors []float32
|
||||
for structIdx := startIdx; structIdx < endIdx; structIdx++ {
|
||||
vecStart, vecEnd := fieldArray.ValueOffsets(int(structIdx))
|
||||
if floatArr, ok := fieldArray.ListValues().(*array.Float32); ok {
|
||||
for j := vecStart; j < vecEnd; j++ {
|
||||
allVectors = append(allVectors, floatArr.Value(int(j)))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(allVectors) > 0 {
|
||||
vectorField := &schemapb.VectorField{
|
||||
Dim: int64(r.dim),
|
||||
Data: &schemapb.VectorField_FloatVector{
|
||||
FloatVector: &schemapb.FloatArray{Data: allVectors},
|
||||
},
|
||||
}
|
||||
result = append(result, vectorField)
|
||||
}
|
||||
|
||||
case schemapb.DataType_BinaryVector:
|
||||
return nil, nil, merr.WrapErrImportFailed("ArrayOfVector with BinaryVector element type is not implemented yet")
|
||||
|
||||
case schemapb.DataType_Float16Vector:
|
||||
return nil, nil, merr.WrapErrImportFailed("ArrayOfVector with Float16Vector element type is not implemented yet")
|
||||
|
||||
case schemapb.DataType_BFloat16Vector:
|
||||
return nil, nil, merr.WrapErrImportFailed("ArrayOfVector with BFloat16Vector element type is not implemented yet")
|
||||
|
||||
case schemapb.DataType_Int8Vector:
|
||||
return nil, nil, merr.WrapErrImportFailed("ArrayOfVector with Int8Vector element type is not implemented yet")
|
||||
|
||||
case schemapb.DataType_SparseFloatVector:
|
||||
return nil, nil, merr.WrapErrImportFailed("ArrayOfVector with SparseFloatVector element type is not implemented yet")
|
||||
|
||||
default:
|
||||
return nil, nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported ArrayOfVector element type: %v", r.field.GetElementType()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil, nil
|
||||
}
|
||||
@ -26,7 +26,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/storage"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
@ -78,6 +77,43 @@ func CreateFieldReaders(ctx context.Context, fileReader *pqarrow.FileReader, sch
|
||||
return nil, merr.WrapErrImportFailed(fmt.Sprintf("get parquet schema failed, err=%v", err))
|
||||
}
|
||||
|
||||
// Check if we have nested struct format
|
||||
nestedStructs := make(map[string]int) // struct name -> column index
|
||||
for _, structField := range schema.StructArrayFields {
|
||||
for i, pqField := range pqSchema.Fields() {
|
||||
if pqField.Name != structField.Name {
|
||||
continue
|
||||
}
|
||||
listType, ok := pqField.Type.(*arrow.ListType)
|
||||
if !ok {
|
||||
return nil, merr.WrapErrImportFailed(fmt.Sprintf("struct field is not a list of structs: %s", structField.Name))
|
||||
}
|
||||
structType, ok := listType.Elem().(*arrow.StructType)
|
||||
if !ok {
|
||||
return nil, merr.WrapErrImportFailed(fmt.Sprintf("struct field is not a list of structs: %s", structField.Name))
|
||||
}
|
||||
nestedStructs[structField.Name] = i
|
||||
// Verify struct fields match
|
||||
for _, subField := range structField.Fields {
|
||||
fieldName, err := typeutil.ExtractStructFieldName(subField.Name)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrImportFailed(err.Error())
|
||||
}
|
||||
found := false
|
||||
for _, f := range structType.Fields() {
|
||||
if f.Name == fieldName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, merr.WrapErrImportFailed(fmt.Sprintf("field not found in struct: %s", fieldName))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Original flat format handling
|
||||
err = isSchemaEqual(schema, pqSchema)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrImportFailed(fmt.Sprintf("schema not equal, err=%v", err))
|
||||
@ -88,6 +124,11 @@ func CreateFieldReaders(ctx context.Context, fileReader *pqarrow.FileReader, sch
|
||||
crs := make(map[int64]*FieldReader)
|
||||
allowInsertAutoID, _ := common.IsAllowInsertAutoID(schema.GetProperties()...)
|
||||
for i, pqField := range pqSchema.Fields() {
|
||||
// Skip if it's a struct column
|
||||
if _, isStruct := nestedStructs[pqField.Name]; isStruct {
|
||||
continue
|
||||
}
|
||||
|
||||
field, ok := nameToField[pqField.Name]
|
||||
if !ok {
|
||||
// redundant fields, ignore. only accepts a special field "$meta" to store dynamic data
|
||||
@ -117,6 +158,46 @@ func CreateFieldReaders(ctx context.Context, fileReader *pqarrow.FileReader, sch
|
||||
readFields[field.GetName()] = field.GetFieldID()
|
||||
}
|
||||
|
||||
for _, structField := range schema.StructArrayFields {
|
||||
columnIndex, ok := nestedStructs[structField.Name]
|
||||
if !ok {
|
||||
return nil, merr.WrapErrImportFailed(fmt.Sprintf("struct field not found in parquet schema: %s", structField.Name))
|
||||
}
|
||||
|
||||
listType := pqSchema.Field(columnIndex).Type.(*arrow.ListType)
|
||||
structType := listType.Elem().(*arrow.StructType)
|
||||
|
||||
// Create reader for each sub-field
|
||||
for _, subField := range structField.Fields {
|
||||
// Find field index in struct
|
||||
fieldName, err := typeutil.ExtractStructFieldName(subField.Name)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrImportFailed(err.Error())
|
||||
}
|
||||
|
||||
fieldIndex := -1
|
||||
for i, f := range structType.Fields() {
|
||||
if f.Name == fieldName {
|
||||
fieldIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if fieldIndex == -1 {
|
||||
return nil, merr.WrapErrImportFailed(fmt.Sprintf("field not found in struct: %s", fieldName))
|
||||
}
|
||||
|
||||
// Create struct field reader
|
||||
reader, err := NewStructFieldReader(ctx, fileReader, columnIndex, fieldIndex, subField)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
crs[subField.FieldID] = reader
|
||||
readFields[subField.Name] = subField.FieldID
|
||||
}
|
||||
}
|
||||
|
||||
// this loop is for "are there any fields not provided in the parquet file?"
|
||||
for _, field := range nameToField {
|
||||
// auto-id field, function output field already checked
|
||||
@ -234,6 +315,29 @@ func IsValidSparseVectorSchema(arrowType arrow.DataType) (bool, bool) {
|
||||
return arrowID == arrow.STRING, true
|
||||
}
|
||||
|
||||
// For ArrayOfVector, use natural user format (list of list of primitives)
|
||||
// instead of internal fixed_size_binary format
|
||||
func convertElementTypeOfVectorArrayToArrowType(field *schemapb.FieldSchema) (arrow.DataType, error) {
|
||||
if field.GetDataType() != schemapb.DataType_ArrayOfVector {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("field is not a vector array: %v", field.GetDataType().String())
|
||||
}
|
||||
|
||||
var elemType arrow.DataType
|
||||
switch field.GetElementType() {
|
||||
case schemapb.DataType_FloatVector:
|
||||
elemType = arrow.ListOf(arrow.PrimitiveTypes.Float32)
|
||||
case schemapb.DataType_BinaryVector:
|
||||
elemType = arrow.ListOf(arrow.PrimitiveTypes.Uint8)
|
||||
case schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
|
||||
elemType = arrow.ListOf(arrow.PrimitiveTypes.Float32)
|
||||
case schemapb.DataType_Int8Vector:
|
||||
elemType = arrow.ListOf(arrow.PrimitiveTypes.Int8)
|
||||
default:
|
||||
return nil, merr.WrapErrParameterInvalidMsg("unsupported element type for ArrayOfVector: %v", field.GetElementType().String())
|
||||
}
|
||||
return elemType, nil
|
||||
}
|
||||
|
||||
func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.DataType, error) {
|
||||
dataType := field.GetDataType()
|
||||
if isArray {
|
||||
@ -295,11 +399,7 @@ func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.Da
|
||||
Metadata: arrow.Metadata{},
|
||||
}), nil
|
||||
case schemapb.DataType_ArrayOfVector:
|
||||
dim, err := typeutil.GetDim(field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
elemType, err := storage.VectorArrayToArrowType(field.GetElementType(), int(dim))
|
||||
elemType, err := convertElementTypeOfVectorArrayToArrowType(field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -317,11 +417,9 @@ func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.Da
|
||||
// This method is used only by import util and related tests. Returned arrow.Schema
|
||||
// doesn't include function output fields.
|
||||
func ConvertToArrowSchemaForUT(schema *schemapb.CollectionSchema, useNullType bool) (*arrow.Schema, error) {
|
||||
// Get all fields including struct sub-fields
|
||||
allFields := typeutil.GetAllFieldSchemas(schema)
|
||||
arrFields := make([]arrow.Field, 0, len(allFields))
|
||||
arrFields := make([]arrow.Field, 0, 10)
|
||||
|
||||
for _, field := range allFields {
|
||||
for _, field := range schema.Fields {
|
||||
if typeutil.IsAutoPKField(field) || field.GetIsFunctionOutput() {
|
||||
continue
|
||||
}
|
||||
@ -344,19 +442,66 @@ func ConvertToArrowSchemaForUT(schema *schemapb.CollectionSchema, useNullType bo
|
||||
Metadata: arrow.Metadata{},
|
||||
})
|
||||
}
|
||||
|
||||
for _, structField := range schema.StructArrayFields {
|
||||
// Build struct fields for row-wise format
|
||||
structFields := make([]arrow.Field, 0, len(structField.Fields))
|
||||
for _, subField := range structField.Fields {
|
||||
fieldName, err := typeutil.ExtractStructFieldName(subField.Name)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrImportFailed(err.Error())
|
||||
}
|
||||
|
||||
var arrDataType arrow.DataType
|
||||
switch subField.DataType {
|
||||
case schemapb.DataType_Array:
|
||||
arrDataType, err = convertToArrowDataType(subField, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case schemapb.DataType_ArrayOfVector:
|
||||
arrDataType, err = convertElementTypeOfVectorArrayToArrowType(subField)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
err = merr.WrapErrParameterInvalidMsg("unsupported data type in struct: %v", subField.DataType.String())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
structFields = append(structFields, arrow.Field{
|
||||
Name: fieldName,
|
||||
Type: arrDataType,
|
||||
Nullable: subField.GetNullable(),
|
||||
})
|
||||
}
|
||||
|
||||
// Create list<struct> type
|
||||
structType := arrow.StructOf(structFields...)
|
||||
listType := arrow.ListOf(structType)
|
||||
|
||||
arrFields = append(arrFields, arrow.Field{
|
||||
Name: structField.Name,
|
||||
Type: listType,
|
||||
Nullable: false,
|
||||
})
|
||||
}
|
||||
|
||||
return arrow.NewSchema(arrFields, nil), nil
|
||||
}
|
||||
|
||||
func isSchemaEqual(schema *schemapb.CollectionSchema, arrSchema *arrow.Schema) error {
|
||||
// Get all fields including struct sub-fields
|
||||
allFields := typeutil.GetAllFieldSchemas(schema)
|
||||
|
||||
arrNameToField := lo.KeyBy(arrSchema.Fields(), func(field arrow.Field) string {
|
||||
return field.Name
|
||||
})
|
||||
|
||||
// Check all fields (including struct sub-fields which are stored as separate columns)
|
||||
for _, field := range allFields {
|
||||
for _, field := range schema.Fields {
|
||||
// ignore autoPKField and functionOutputField
|
||||
if typeutil.IsAutoPKField(field) || field.GetIsFunctionOutput() {
|
||||
continue
|
||||
@ -381,6 +526,77 @@ func isSchemaEqual(schema *schemapb.CollectionSchema, arrSchema *arrow.Schema) e
|
||||
field.Name, toArrDataType.String(), arrField.Type.String()))
|
||||
}
|
||||
}
|
||||
|
||||
for _, structField := range schema.StructArrayFields {
|
||||
arrStructField, ok := arrNameToField[structField.Name]
|
||||
if !ok {
|
||||
return merr.WrapErrImportFailed(fmt.Sprintf("struct field not found in arrow schema: %s", structField.Name))
|
||||
}
|
||||
|
||||
// Verify the arrow field is list<struct> type
|
||||
listType, ok := arrStructField.Type.(*arrow.ListType)
|
||||
if !ok {
|
||||
return merr.WrapErrImportFailed(fmt.Sprintf("struct field '%s' should be list type in arrow schema, but got '%s'",
|
||||
structField.Name, arrStructField.Type.String()))
|
||||
}
|
||||
|
||||
structType, ok := listType.Elem().(*arrow.StructType)
|
||||
if !ok {
|
||||
return merr.WrapErrImportFailed(fmt.Sprintf("struct field '%s' should contain struct elements in arrow schema, but got '%s'",
|
||||
structField.Name, listType.Elem().String()))
|
||||
}
|
||||
|
||||
// Create a map of struct field names to arrow.Field for quick lookup
|
||||
structFieldMap := make(map[string]arrow.Field)
|
||||
for _, arrowField := range structType.Fields() {
|
||||
structFieldMap[arrowField.Name] = arrowField
|
||||
}
|
||||
|
||||
// Verify each sub-field in the struct
|
||||
for _, subField := range structField.Fields {
|
||||
// Extract actual field name (remove structName[] prefix if present)
|
||||
fieldName, err := typeutil.ExtractStructFieldName(subField.Name)
|
||||
if err != nil {
|
||||
return merr.WrapErrImportFailed(err.Error())
|
||||
}
|
||||
|
||||
arrowSubField, ok := structFieldMap[fieldName]
|
||||
if !ok {
|
||||
return merr.WrapErrImportFailed(fmt.Sprintf("sub-field '%s' not found in struct '%s' of arrow schema",
|
||||
fieldName, structField.Name))
|
||||
}
|
||||
|
||||
// Convert Milvus field type to expected Arrow type
|
||||
var expectedArrowType arrow.DataType
|
||||
|
||||
switch subField.DataType {
|
||||
case schemapb.DataType_Array:
|
||||
// For Array type, need to convert based on element type
|
||||
expectedArrowType, err = convertToArrowDataType(subField, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case schemapb.DataType_ArrayOfVector:
|
||||
expectedArrowType, err = convertElementTypeOfVectorArrayToArrowType(subField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type in struct field: %v", subField.DataType))
|
||||
}
|
||||
|
||||
// Check if the arrow type is convertible to the expected type
|
||||
if !isArrowDataTypeConvertible(arrowSubField.Type, expectedArrowType, subField) {
|
||||
return merr.WrapErrImportFailed(fmt.Sprintf("sub-field '%s' in struct '%s' type mis-match, expect arrow type '%s', got '%s'",
|
||||
fieldName, structField.Name, expectedArrowType.String(), arrowSubField.Type.String()))
|
||||
}
|
||||
}
|
||||
|
||||
if len(structFieldMap) != len(structField.Fields) {
|
||||
return merr.WrapErrImportFailed(fmt.Sprintf("struct field number dismatch: %s, expect %d, got %d", structField.Name, len(structField.Fields), len(structFieldMap)))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -418,14 +418,18 @@ func BuildSparseVectorData(mem *memory.GoAllocator, contents [][]byte, arrowType
|
||||
|
||||
func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.InsertData, useNullType bool) ([]arrow.Array, error) {
|
||||
mem := memory.NewGoAllocator()
|
||||
// Get all fields including struct sub-fields
|
||||
allFields := typeutil.GetAllFieldSchemas(schema)
|
||||
// Filter out auto-generated and function output fields
|
||||
fields := lo.Filter(allFields, func(field *schemapb.FieldSchema, _ int) bool {
|
||||
return !(field.GetIsPrimaryKey() && field.GetAutoID()) && !field.GetIsFunctionOutput()
|
||||
columns := make([]arrow.Array, 0)
|
||||
|
||||
// Filter out auto-generated, function output, and nested struct sub-fields
|
||||
fields := lo.Filter(schema.Fields, func(field *schemapb.FieldSchema, _ int) bool {
|
||||
// Skip auto PK, function output, and struct sub-fields (if using nested format)
|
||||
if (field.GetIsPrimaryKey() && field.GetAutoID()) || field.GetIsFunctionOutput() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
columns := make([]arrow.Array, 0, len(fields))
|
||||
// Build regular field columns
|
||||
for _, field := range fields {
|
||||
fieldID := field.GetFieldID()
|
||||
dataType := field.GetDataType()
|
||||
@ -844,6 +848,192 @@ func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.Inser
|
||||
columns = append(columns, listBuilder.NewListArray())
|
||||
}
|
||||
}
|
||||
|
||||
// Process StructArrayFields as nested list<struct> format
|
||||
for _, structField := range schema.StructArrayFields {
|
||||
// Build arrow fields for the struct
|
||||
structFields := make([]arrow.Field, 0, len(structField.Fields))
|
||||
for _, subField := range structField.Fields {
|
||||
// Extract actual field name (remove structName[] prefix)
|
||||
fieldName := subField.Name
|
||||
if len(structField.Name) > 0 && len(subField.Name) > len(structField.Name)+2 {
|
||||
fieldName = subField.Name[len(structField.Name)+1 : len(subField.Name)-1]
|
||||
}
|
||||
|
||||
// Determine arrow type for the field
|
||||
var arrType arrow.DataType
|
||||
switch subField.DataType {
|
||||
case schemapb.DataType_Array:
|
||||
switch subField.ElementType {
|
||||
case schemapb.DataType_Bool:
|
||||
arrType = arrow.FixedWidthTypes.Boolean
|
||||
case schemapb.DataType_Int8:
|
||||
arrType = arrow.PrimitiveTypes.Int8
|
||||
case schemapb.DataType_Int16:
|
||||
arrType = arrow.PrimitiveTypes.Int16
|
||||
case schemapb.DataType_Int32:
|
||||
arrType = arrow.PrimitiveTypes.Int32
|
||||
case schemapb.DataType_Int64:
|
||||
arrType = arrow.PrimitiveTypes.Int64
|
||||
case schemapb.DataType_Float:
|
||||
arrType = arrow.PrimitiveTypes.Float32
|
||||
case schemapb.DataType_Double:
|
||||
arrType = arrow.PrimitiveTypes.Float64
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
arrType = arrow.BinaryTypes.String
|
||||
default:
|
||||
// Default to string for unknown element types
|
||||
arrType = arrow.BinaryTypes.String
|
||||
}
|
||||
case schemapb.DataType_ArrayOfVector:
|
||||
// For user data, use list<float> format for vectors
|
||||
switch subField.ElementType {
|
||||
case schemapb.DataType_FloatVector:
|
||||
arrType = arrow.ListOf(arrow.PrimitiveTypes.Float32)
|
||||
case schemapb.DataType_BinaryVector:
|
||||
arrType = arrow.ListOf(arrow.PrimitiveTypes.Uint8)
|
||||
case schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
|
||||
arrType = arrow.ListOf(arrow.PrimitiveTypes.Float32)
|
||||
case schemapb.DataType_Int8Vector:
|
||||
arrType = arrow.ListOf(arrow.PrimitiveTypes.Int8)
|
||||
default:
|
||||
panic("unimplemented element type for ArrayOfVector")
|
||||
}
|
||||
default:
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
structFields = append(structFields, arrow.Field{
|
||||
Name: fieldName,
|
||||
Type: arrType,
|
||||
Nullable: subField.GetNullable(),
|
||||
})
|
||||
}
|
||||
|
||||
// Build list<struct> column
|
||||
listBuilder := array.NewListBuilder(mem, arrow.StructOf(structFields...))
|
||||
structBuilder := listBuilder.ValueBuilder().(*array.StructBuilder)
|
||||
|
||||
// Get row count from first sub-field
|
||||
var rowCount int
|
||||
for _, subField := range structField.Fields {
|
||||
if data, ok := insertData.Data[subField.FieldID]; ok {
|
||||
rowCount = data.RowNum()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// row to column
|
||||
for i := 0; i < rowCount; i++ {
|
||||
var arrayLen int
|
||||
subField := structField.Fields[0]
|
||||
data := insertData.Data[subField.FieldID]
|
||||
if data == nil {
|
||||
panic(fmt.Sprintf("data for struct sub-field %s (ID: %d) is nil", subField.Name, subField.FieldID))
|
||||
}
|
||||
rowData := data.GetRow(i)
|
||||
switch subField.DataType {
|
||||
case schemapb.DataType_Array:
|
||||
scalarField := rowData.(*schemapb.ScalarField)
|
||||
switch subField.ElementType {
|
||||
case schemapb.DataType_Bool:
|
||||
arrayLen = len(scalarField.GetBoolData().GetData())
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||
arrayLen = len(scalarField.GetIntData().GetData())
|
||||
case schemapb.DataType_Int64:
|
||||
arrayLen = len(scalarField.GetLongData().GetData())
|
||||
case schemapb.DataType_Float:
|
||||
arrayLen = len(scalarField.GetFloatData().GetData())
|
||||
case schemapb.DataType_Double:
|
||||
arrayLen = len(scalarField.GetDoubleData().GetData())
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
arrayLen = len(scalarField.GetStringData().GetData())
|
||||
}
|
||||
case schemapb.DataType_ArrayOfVector:
|
||||
vectorField := rowData.(*schemapb.VectorField)
|
||||
if vectorField.GetFloatVector() != nil {
|
||||
dim, _ := typeutil.GetDim(subField)
|
||||
arrayLen = len(vectorField.GetFloatVector().Data) / int(dim)
|
||||
}
|
||||
}
|
||||
|
||||
listBuilder.Append(true)
|
||||
// generate a struct for each array element
|
||||
for j := 0; j < arrayLen; j++ {
|
||||
// add data for each field at this position
|
||||
for fieldIdx, subField := range structField.Fields {
|
||||
data := insertData.Data[subField.FieldID]
|
||||
fieldBuilder := structBuilder.FieldBuilder(fieldIdx)
|
||||
|
||||
rowData := data.GetRow(i)
|
||||
switch subField.DataType {
|
||||
case schemapb.DataType_Array:
|
||||
scalarField := rowData.(*schemapb.ScalarField)
|
||||
switch subField.ElementType {
|
||||
case schemapb.DataType_Bool:
|
||||
if boolData := scalarField.GetBoolData(); boolData != nil && j < len(boolData.GetData()) {
|
||||
fieldBuilder.(*array.BooleanBuilder).Append(boolData.GetData()[j])
|
||||
} else {
|
||||
fieldBuilder.(*array.BooleanBuilder).AppendNull()
|
||||
}
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||
if intData := scalarField.GetIntData(); intData != nil && j < len(intData.GetData()) {
|
||||
fieldBuilder.(*array.Int32Builder).Append(intData.GetData()[j])
|
||||
} else {
|
||||
fieldBuilder.(*array.Int32Builder).AppendNull()
|
||||
}
|
||||
case schemapb.DataType_Int64:
|
||||
if longData := scalarField.GetLongData(); longData != nil && j < len(longData.GetData()) {
|
||||
fieldBuilder.(*array.Int64Builder).Append(longData.GetData()[j])
|
||||
} else {
|
||||
fieldBuilder.(*array.Int64Builder).AppendNull()
|
||||
}
|
||||
case schemapb.DataType_Float:
|
||||
if floatData := scalarField.GetFloatData(); floatData != nil && j < len(floatData.GetData()) {
|
||||
fieldBuilder.(*array.Float32Builder).Append(floatData.GetData()[j])
|
||||
} else {
|
||||
fieldBuilder.(*array.Float32Builder).AppendNull()
|
||||
}
|
||||
case schemapb.DataType_Double:
|
||||
if doubleData := scalarField.GetDoubleData(); doubleData != nil && j < len(doubleData.GetData()) {
|
||||
fieldBuilder.(*array.Float64Builder).Append(doubleData.GetData()[j])
|
||||
} else {
|
||||
fieldBuilder.(*array.Float64Builder).AppendNull()
|
||||
}
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
if stringData := scalarField.GetStringData(); stringData != nil && j < len(stringData.GetData()) {
|
||||
fieldBuilder.(*array.StringBuilder).Append(stringData.GetData()[j])
|
||||
} else {
|
||||
fieldBuilder.(*array.StringBuilder).AppendNull()
|
||||
}
|
||||
}
|
||||
|
||||
case schemapb.DataType_ArrayOfVector:
|
||||
vectorField := rowData.(*schemapb.VectorField)
|
||||
listBuilder := fieldBuilder.(*array.ListBuilder)
|
||||
listBuilder.Append(true)
|
||||
|
||||
if floatVectors := vectorField.GetFloatVector(); floatVectors != nil {
|
||||
dim, _ := typeutil.GetDim(subField)
|
||||
floatBuilder := listBuilder.ValueBuilder().(*array.Float32Builder)
|
||||
start := j * int(dim)
|
||||
end := start + int(dim)
|
||||
if end <= len(floatVectors.Data) {
|
||||
for k := start; k < end; k++ {
|
||||
floatBuilder.Append(floatVectors.Data[k])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
structBuilder.Append(true)
|
||||
}
|
||||
}
|
||||
|
||||
columns = append(columns, listBuilder.NewArray())
|
||||
}
|
||||
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
@ -2462,3 +2463,14 @@ func IsBm25FunctionInputField(coll *schemapb.CollectionSchema, field *schemapb.F
|
||||
func ConcatStructFieldName(structName string, fieldName string) string {
|
||||
return fmt.Sprintf("%s[%s]", structName, fieldName)
|
||||
}
|
||||
|
||||
func ExtractStructFieldName(fieldName string) (string, error) {
|
||||
parts := strings.Split(fieldName, "[")
|
||||
if len(parts) == 1 {
|
||||
return fieldName, nil
|
||||
} else if len(parts) == 2 {
|
||||
return parts[1][:len(parts[1])-1], nil
|
||||
} else {
|
||||
return "", fmt.Errorf("invalid struct field name: %s, more than one [ found", fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
@ -301,7 +301,7 @@ func (s *BulkInsertSuite) runForStructArray() {
|
||||
}
|
||||
|
||||
func (s *BulkInsertSuite) TestImportWithVectorArray() {
|
||||
fileTypeArr := []importutilv2.FileType{importutilv2.CSV, importutilv2.JSON}
|
||||
fileTypeArr := []importutilv2.FileType{importutilv2.CSV, importutilv2.JSON, importutilv2.Parquet}
|
||||
for _, fileType := range fileTypeArr {
|
||||
s.fileType = fileType
|
||||
s.vecType = schemapb.DataType_FloatVector
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user