mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
fix: Do not allow importing null elements of array fields from parquet (#43964)
issue: https://github.com/milvus-io/milvus/issues/43819
Before this fix: null elements are converted to zero or empty strings.
After this fix: the import job returns the error "array element is not allowed to be null value for field xxx".
Signed-off-by: yhmo <yihua.mo@zilliz.com>
This commit is contained in:
parent
575345ae7b
commit
ccb0db92e7
@ -149,9 +149,9 @@ func (c *FieldReader) Next(count int64) (any, any, error) {
|
||||
return data, nil, typeutil.VerifyFloats64(data.([]float64))
|
||||
case schemapb.DataType_VarChar, schemapb.DataType_String:
|
||||
if c.field.GetNullable() || c.field.GetDefaultValue() != nil {
|
||||
return ReadNullableVarcharData(c, count)
|
||||
return ReadNullableStringData(c, count, true)
|
||||
}
|
||||
data, err := ReadVarcharData(c, count)
|
||||
data, err := ReadStringData(c, count, true)
|
||||
return data, nil, err
|
||||
case schemapb.DataType_JSON:
|
||||
// json has not support default_value
|
||||
@ -220,12 +220,13 @@ func ReadBoolData(pcr *FieldReader, count int64) (any, error) {
|
||||
data := make([]bool, 0, count)
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
dataNums := chunk.Data().Len()
|
||||
boolReader, ok := chunk.(*array.Boolean)
|
||||
if boolReader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("not nullable, but has null value")
|
||||
if chunk.NullN() > 0 {
|
||||
return nil, WrapNullRowErr(pcr.field)
|
||||
}
|
||||
boolReader, ok := chunk.(*array.Boolean)
|
||||
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("bool", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
for i := 0; i < dataNums; i++ {
|
||||
data = append(data, boolReader.Value(i))
|
||||
@ -275,7 +276,7 @@ func ReadNullableBoolData(pcr *FieldReader, count int64) (any, []bool, error) {
|
||||
// the chunk type may be *array.Null if the data in chunk is all null
|
||||
_, ok := chunk.(*array.Null)
|
||||
if !ok {
|
||||
return nil, nil, WrapTypeErr("bool|null", chunk.DataType().Name(), pcr.field)
|
||||
return nil, nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
validData = append(validData, make([]bool, dataNums)...)
|
||||
data = append(data, make([]bool, dataNums)...)
|
||||
@ -307,57 +308,42 @@ func ReadIntegerOrFloatData[T constraints.Integer | constraints.Float](pcr *Fiel
|
||||
data := make([]T, 0, count)
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
dataNums := chunk.Data().Len()
|
||||
if chunk.NullN() > 0 {
|
||||
return nil, WrapNullRowErr(pcr.field)
|
||||
}
|
||||
switch chunk.DataType().ID() {
|
||||
case arrow.INT8:
|
||||
int8Reader := chunk.(*array.Int8)
|
||||
if int8Reader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("not nullable, but has null value")
|
||||
}
|
||||
for i := 0; i < dataNums; i++ {
|
||||
data = append(data, T(int8Reader.Value(i)))
|
||||
}
|
||||
case arrow.INT16:
|
||||
int16Reader := chunk.(*array.Int16)
|
||||
if int16Reader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("not nullable, but has null value")
|
||||
}
|
||||
for i := 0; i < dataNums; i++ {
|
||||
data = append(data, T(int16Reader.Value(i)))
|
||||
}
|
||||
case arrow.INT32:
|
||||
int32Reader := chunk.(*array.Int32)
|
||||
if int32Reader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("not nullable, but has null value")
|
||||
}
|
||||
for i := 0; i < dataNums; i++ {
|
||||
data = append(data, T(int32Reader.Value(i)))
|
||||
}
|
||||
case arrow.INT64:
|
||||
int64Reader := chunk.(*array.Int64)
|
||||
if int64Reader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("not nullable, but has null value")
|
||||
}
|
||||
for i := 0; i < dataNums; i++ {
|
||||
data = append(data, T(int64Reader.Value(i)))
|
||||
}
|
||||
case arrow.FLOAT32:
|
||||
float32Reader := chunk.(*array.Float32)
|
||||
if float32Reader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("not nullable, but has null value")
|
||||
}
|
||||
for i := 0; i < dataNums; i++ {
|
||||
data = append(data, T(float32Reader.Value(i)))
|
||||
}
|
||||
case arrow.FLOAT64:
|
||||
float64Reader := chunk.(*array.Float64)
|
||||
if float64Reader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("not nullable, but has null value")
|
||||
}
|
||||
for i := 0; i < dataNums; i++ {
|
||||
data = append(data, T(float64Reader.Value(i)))
|
||||
}
|
||||
default:
|
||||
return nil, WrapTypeErr("integer|float", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
}
|
||||
if len(data) == 0 {
|
||||
@ -417,7 +403,7 @@ func ReadNullableIntegerOrFloatData[T constraints.Integer | constraints.Float](p
|
||||
validData = append(validData, make([]bool, dataNums)...)
|
||||
data = append(data, make([]T, dataNums)...)
|
||||
default:
|
||||
return nil, nil, WrapTypeErr("integer|float|null", chunk.DataType().Name(), pcr.field)
|
||||
return nil, nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
}
|
||||
if len(data) != len(validData) {
|
||||
@ -426,6 +412,7 @@ func ReadNullableIntegerOrFloatData[T constraints.Integer | constraints.Float](p
|
||||
if len(data) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if pcr.field.GetDefaultValue() != nil {
|
||||
defaultValue, err := nullutil.GetDefaultValue(pcr.field)
|
||||
if err != nil {
|
||||
@ -486,10 +473,10 @@ func ReadStructData(pcr *FieldReader, count int64) ([]map[string]arrow.Array, er
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
structReader, ok := chunk.(*array.Struct)
|
||||
if structReader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("has null value, but struct doesn't support nullable yet")
|
||||
return nil, WrapNullRowErr(pcr.field)
|
||||
}
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("struct", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
|
||||
structType := structReader.DataType().(*arrow.StructType)
|
||||
@ -505,92 +492,34 @@ func ReadStructData(pcr *FieldReader, count int64) ([]map[string]arrow.Array, er
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func ReadStringData(pcr *FieldReader, count int64) (any, error) {
|
||||
func ReadStringData(pcr *FieldReader, count int64, isVarcharField bool) (any, error) {
|
||||
chunked, err := pcr.columnReader.NextBatch(count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]string, 0, count)
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
dataNums := chunk.Data().Len()
|
||||
stringReader, ok := chunk.(*array.String)
|
||||
if stringReader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("not nullable, but has null value")
|
||||
var maxLength int64
|
||||
if isVarcharField {
|
||||
maxLength, err = parameterutil.GetMaxLength(pcr.field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("string", chunk.DataType().Name(), pcr.field)
|
||||
}
|
||||
for i := 0; i < dataNums; i++ {
|
||||
data = append(data, stringReader.Value(i))
|
||||
}
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func ReadNullableStringData(pcr *FieldReader, count int64) (any, []bool, error) {
|
||||
chunked, err := pcr.columnReader.NextBatch(count)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
data := make([]string, 0, count)
|
||||
validData := make([]bool, 0, count)
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
dataNums := chunk.Data().Len()
|
||||
stringReader, ok := chunk.(*array.String)
|
||||
if !ok {
|
||||
// the chunk type may be *array.Null if the data in chunk is all null
|
||||
_, ok := chunk.(*array.Null)
|
||||
if !ok {
|
||||
return nil, nil, WrapTypeErr("string|null", chunk.DataType().Name(), pcr.field)
|
||||
}
|
||||
validData = append(validData, make([]bool, dataNums)...)
|
||||
data = append(data, make([]string, dataNums)...)
|
||||
} else {
|
||||
validData = append(validData, bytesToValidData(dataNums, stringReader.NullBitmapBytes())...)
|
||||
for i := 0; i < dataNums; i++ {
|
||||
if stringReader.IsNull(i) {
|
||||
data = append(data, "")
|
||||
continue
|
||||
}
|
||||
data = append(data, stringReader.ValueStr(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(data) != len(validData) {
|
||||
return nil, nil, merr.WrapErrParameterInvalid(len(data), len(validData), "length of data is not equal to length of valid_data")
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
return data, validData, nil
|
||||
}
|
||||
|
||||
func ReadVarcharData(pcr *FieldReader, count int64) (any, error) {
|
||||
chunked, err := pcr.columnReader.NextBatch(count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make([]string, 0, count)
|
||||
maxLength, err := parameterutil.GetMaxLength(pcr.field)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
dataNums := chunk.Data().Len()
|
||||
stringReader, ok := chunk.(*array.String)
|
||||
if stringReader.NullN() > 0 {
|
||||
return nil, merr.WrapErrParameterInvalidMsg("not nullable, but has null value")
|
||||
if chunk.NullN() > 0 {
|
||||
return nil, WrapNullRowErr(pcr.field)
|
||||
}
|
||||
stringReader, ok := chunk.(*array.String)
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("string", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
for i := 0; i < dataNums; i++ {
|
||||
value := stringReader.Value(i)
|
||||
if err = common.CheckValidString(value, maxLength, pcr.field); err != nil {
|
||||
return nil, err
|
||||
if isVarcharField {
|
||||
if err = common.CheckValidString(value, maxLength, pcr.field); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
data = append(data, value)
|
||||
}
|
||||
@ -601,17 +530,20 @@ func ReadVarcharData(pcr *FieldReader, count int64) (any, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func ReadNullableVarcharData(pcr *FieldReader, count int64) (any, []bool, error) {
|
||||
func ReadNullableStringData(pcr *FieldReader, count int64, isVarcharField bool) (any, []bool, error) {
|
||||
chunked, err := pcr.columnReader.NextBatch(count)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
data := make([]string, 0, count)
|
||||
maxLength, err := parameterutil.GetMaxLength(pcr.field)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
validData := make([]bool, 0, count)
|
||||
var maxLength int64
|
||||
if isVarcharField {
|
||||
maxLength, err = parameterutil.GetMaxLength(pcr.field)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
dataNums := chunk.Data().Len()
|
||||
stringReader, ok := chunk.(*array.String)
|
||||
@ -619,7 +551,7 @@ func ReadNullableVarcharData(pcr *FieldReader, count int64) (any, []bool, error)
|
||||
// the chunk type may be *array.Null if the data in chunk is all null
|
||||
_, ok := chunk.(*array.Null)
|
||||
if !ok {
|
||||
return nil, nil, WrapTypeErr("string|null", chunk.DataType().Name(), pcr.field)
|
||||
return nil, nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
validData = append(validData, make([]bool, dataNums)...)
|
||||
data = append(data, make([]string, dataNums)...)
|
||||
@ -630,9 +562,11 @@ func ReadNullableVarcharData(pcr *FieldReader, count int64) (any, []bool, error)
|
||||
data = append(data, "")
|
||||
continue
|
||||
}
|
||||
value := stringReader.ValueStr(i)
|
||||
if err = common.CheckValidString(value, maxLength, pcr.field); err != nil {
|
||||
return nil, nil, err
|
||||
value := stringReader.Value(i)
|
||||
if isVarcharField {
|
||||
if err = common.CheckValidString(value, maxLength, pcr.field); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
data = append(data, value)
|
||||
}
|
||||
@ -644,7 +578,7 @@ func ReadNullableVarcharData(pcr *FieldReader, count int64) (any, []bool, error)
|
||||
if len(data) == 0 {
|
||||
return nil, nil, nil
|
||||
}
|
||||
if pcr.field.GetDefaultValue() != nil {
|
||||
if isVarcharField && pcr.field.GetDefaultValue() != nil {
|
||||
defaultValue := pcr.field.GetDefaultValue().GetStringData()
|
||||
return fillWithDefaultValueImpl(data, defaultValue, validData, pcr.field)
|
||||
}
|
||||
@ -653,7 +587,7 @@ func ReadNullableVarcharData(pcr *FieldReader, count int64) (any, []bool, error)
|
||||
|
||||
func ReadJSONData(pcr *FieldReader, count int64) (any, error) {
|
||||
// JSON field read data from string array Parquet
|
||||
data, err := ReadStringData(pcr, count)
|
||||
data, err := ReadStringData(pcr, count, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -681,7 +615,7 @@ func ReadJSONData(pcr *FieldReader, count int64) (any, error) {
|
||||
|
||||
func ReadNullableJSONData(pcr *FieldReader, count int64) (any, []bool, error) {
|
||||
// JSON field read data from string array Parquet
|
||||
data, validData, err := ReadNullableStringData(pcr, count)
|
||||
data, validData, err := ReadNullableStringData(pcr, count, false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -733,11 +667,11 @@ func ReadBinaryData(pcr *FieldReader, count int64) (any, error) {
|
||||
}
|
||||
uint8Reader, ok := listReader.ListValues().(*array.Uint8)
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("binary", listReader.ListValues().DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, listReader.ListValues().DataType().Name())
|
||||
}
|
||||
data = append(data, uint8Reader.Uint8Values()...)
|
||||
default:
|
||||
return nil, WrapTypeErr("binary", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
}
|
||||
if len(data) == 0 {
|
||||
@ -887,7 +821,7 @@ func parseSparseFloatVectorStructs(structs []map[string]arrow.Array) ([][]byte,
|
||||
func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) {
|
||||
// read sparse vector from JSON-format string
|
||||
if pcr.sparseIsString {
|
||||
data, err := ReadStringData(pcr, count)
|
||||
data, err := ReadStringData(pcr, count, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -994,20 +928,31 @@ func ReadBoolArrayData(pcr *FieldReader, count int64) (any, error) {
|
||||
}
|
||||
data := make([][]bool, 0, count)
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
if chunk.NullN() > 0 {
|
||||
// Array field is not nullable, but some arrays are null
|
||||
return nil, WrapNullRowErr(pcr.field)
|
||||
}
|
||||
listReader, ok := chunk.(*array.List)
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
boolReader, ok := listReader.ListValues().(*array.Boolean)
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("boolArray", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
offsets := listReader.Offsets()
|
||||
getArrayData(offsets, func(i int) (bool, error) {
|
||||
err = getArrayData(offsets, func(i int) (bool, error) {
|
||||
if boolReader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return false, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return boolReader.Value(i), nil
|
||||
}, func(arr []bool, valid bool) {
|
||||
data = append(data, arr)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
@ -1028,7 +973,7 @@ func ReadNullableBoolArrayData(pcr *FieldReader, count int64) (any, []bool, erro
|
||||
// the chunk type may be *array.Null if the data in chunk is all null
|
||||
_, ok := chunk.(*array.Null)
|
||||
if !ok {
|
||||
return nil, nil, WrapTypeErr("list|null", chunk.DataType().Name(), pcr.field)
|
||||
return nil, nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
dataNums := chunk.Data().Len()
|
||||
validData = append(validData, make([]bool, dataNums)...)
|
||||
@ -1036,15 +981,21 @@ func ReadNullableBoolArrayData(pcr *FieldReader, count int64) (any, []bool, erro
|
||||
} else {
|
||||
boolReader, ok := listReader.ListValues().(*array.Boolean)
|
||||
if !ok {
|
||||
return nil, nil, WrapTypeErr("boolArray", chunk.DataType().Name(), pcr.field)
|
||||
return nil, nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
offsets := listReader.Offsets()
|
||||
getArrayData(offsets, func(i int) (bool, error) {
|
||||
err = getArrayData(offsets, func(i int) (bool, error) {
|
||||
if boolReader.IsNull(i) {
|
||||
return false, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return boolReader.Value(i), nil
|
||||
}, func(arr []bool, valid bool) {
|
||||
data = append(data, arr)
|
||||
validData = append(validData, valid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(data) != len(validData) {
|
||||
@ -1064,9 +1015,13 @@ func ReadIntegerOrFloatArrayData[T constraints.Integer | constraints.Float](pcr
|
||||
data := make([][]T, 0, count)
|
||||
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
if chunk.NullN() > 0 {
|
||||
// Array field is not nullable, but some arrays are null
|
||||
return nil, WrapNullRowErr(pcr.field)
|
||||
}
|
||||
listReader, ok := chunk.(*array.List)
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
offsets := listReader.Offsets()
|
||||
dataType := pcr.field.GetDataType()
|
||||
@ -1079,48 +1034,90 @@ func ReadIntegerOrFloatArrayData[T constraints.Integer | constraints.Float](pcr
|
||||
switch valueReader.DataType().ID() {
|
||||
case arrow.INT8:
|
||||
int8Reader := valueReader.(*array.Int8)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if int8Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(int8Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case arrow.INT16:
|
||||
int16Reader := valueReader.(*array.Int16)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if int16Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(int16Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case arrow.INT32:
|
||||
int32Reader := valueReader.(*array.Int32)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if int32Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(int32Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case arrow.INT64:
|
||||
int64Reader := valueReader.(*array.Int64)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if int64Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(int64Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case arrow.FLOAT32:
|
||||
float32Reader := valueReader.(*array.Float32)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if float32Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0.0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(float32Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case arrow.FLOAT64:
|
||||
float64Reader := valueReader.(*array.Float64)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if float64Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0.0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(float64Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, WrapTypeErr("integerArray|floatArray", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
}
|
||||
if len(data) == 0 {
|
||||
@ -1143,7 +1140,7 @@ func ReadNullableIntegerOrFloatArrayData[T constraints.Integer | constraints.Flo
|
||||
// the chunk type may be *array.Null if the data in chunk is all null
|
||||
_, ok := chunk.(*array.Null)
|
||||
if !ok {
|
||||
return nil, nil, WrapTypeErr("list|null", chunk.DataType().Name(), pcr.field)
|
||||
return nil, nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
dataNums := chunk.Data().Len()
|
||||
validData = append(validData, make([]bool, dataNums)...)
|
||||
@ -1160,54 +1157,96 @@ func ReadNullableIntegerOrFloatArrayData[T constraints.Integer | constraints.Flo
|
||||
switch valueReader.DataType().ID() {
|
||||
case arrow.INT8:
|
||||
int8Reader := valueReader.(*array.Int8)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if int8Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(int8Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
validData = append(validData, valid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
case arrow.INT16:
|
||||
int16Reader := valueReader.(*array.Int16)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if int16Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(int16Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
validData = append(validData, valid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
case arrow.INT32:
|
||||
int32Reader := valueReader.(*array.Int32)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if int32Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(int32Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
validData = append(validData, valid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
case arrow.INT64:
|
||||
int64Reader := valueReader.(*array.Int64)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if int64Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(int64Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
validData = append(validData, valid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
case arrow.FLOAT32:
|
||||
float32Reader := valueReader.(*array.Float32)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if float32Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0.0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(float32Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
validData = append(validData, valid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
case arrow.FLOAT64:
|
||||
float64Reader := valueReader.(*array.Float64)
|
||||
getArrayData(offsets, func(i int) (T, error) {
|
||||
err = getArrayData(offsets, func(i int) (T, error) {
|
||||
if float64Reader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return 0.0, WrapNullElementErr(pcr.field)
|
||||
}
|
||||
return T(float64Reader.Value(i)), nil
|
||||
}, func(arr []T, valid bool) {
|
||||
data = append(data, arr)
|
||||
validData = append(validData, valid)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
default:
|
||||
return nil, nil, WrapTypeErr("integerArray|floatArray", chunk.DataType().Name(), pcr.field)
|
||||
return nil, nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1231,16 +1270,24 @@ func ReadStringArrayData(pcr *FieldReader, count int64) (any, error) {
|
||||
}
|
||||
data := make([][]string, 0, count)
|
||||
for _, chunk := range chunked.Chunks() {
|
||||
if chunk.NullN() > 0 {
|
||||
// Array field is not nullable, but some arrays are null
|
||||
return nil, WrapNullRowErr(pcr.field)
|
||||
}
|
||||
listReader, ok := chunk.(*array.List)
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("list", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
stringReader, ok := listReader.ListValues().(*array.String)
|
||||
if !ok {
|
||||
return nil, WrapTypeErr("stringArray", chunk.DataType().Name(), pcr.field)
|
||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
offsets := listReader.Offsets()
|
||||
err = getArrayData(offsets, func(i int) (string, error) {
|
||||
if stringReader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return "", WrapNullElementErr(pcr.field)
|
||||
}
|
||||
val := stringReader.Value(i)
|
||||
if err = common.CheckValidString(val, maxLength, pcr.field); err != nil {
|
||||
return val, err
|
||||
@ -1276,7 +1323,7 @@ func ReadNullableStringArrayData(pcr *FieldReader, count int64) (any, []bool, er
|
||||
// the chunk type may be *array.Null if the data in chunk is all null
|
||||
_, ok := chunk.(*array.Null)
|
||||
if !ok {
|
||||
return nil, nil, WrapTypeErr("list|null", chunk.DataType().Name(), pcr.field)
|
||||
return nil, nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
dataNums := chunk.Data().Len()
|
||||
validData = append(validData, make([]bool, dataNums)...)
|
||||
@ -1284,10 +1331,14 @@ func ReadNullableStringArrayData(pcr *FieldReader, count int64) (any, []bool, er
|
||||
} else {
|
||||
stringReader, ok := listReader.ListValues().(*array.String)
|
||||
if !ok {
|
||||
return nil, nil, WrapTypeErr("stringArray", chunk.DataType().Name(), pcr.field)
|
||||
return nil, nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||
}
|
||||
offsets := listReader.Offsets()
|
||||
err = getArrayData(offsets, func(i int) (string, error) {
|
||||
if stringReader.IsNull(i) {
|
||||
// array contains null values is not allowed
|
||||
return "", WrapNullElementErr(pcr.field)
|
||||
}
|
||||
val := stringReader.Value(i)
|
||||
if err = common.CheckValidString(val, maxLength, pcr.field); err != nil {
|
||||
return val, err
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"github.com/apache/arrow/go/v17/arrow/array"
|
||||
"github.com/apache/arrow/go/v17/arrow/memory"
|
||||
"github.com/apache/arrow/go/v17/parquet"
|
||||
"github.com/apache/arrow/go/v17/parquet/file"
|
||||
"github.com/apache/arrow/go/v17/parquet/pqarrow"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@ -336,3 +337,503 @@ func TestParseSparseFloatVectorStructs(t *testing.T) {
|
||||
isValidFunc(genInt64ArrList(indices), genFloat32ArrList(values))
|
||||
isValidFunc(genInt64ArrList(indices), genFloat64ArrList(values))
|
||||
}
|
||||
|
||||
func TestReadFieldData(t *testing.T) {
|
||||
checkFunc := func(dataHasNull bool, readScehamIsNullable bool, dataType schemapb.DataType, elementType schemapb.DataType) {
|
||||
fieldName := dataType.String()
|
||||
if elementType != schemapb.DataType_None {
|
||||
fieldName = fieldName + "_" + elementType.String()
|
||||
}
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: fieldName,
|
||||
DataType: dataType,
|
||||
ElementType: elementType,
|
||||
Nullable: dataHasNull,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
{
|
||||
Key: "max_length",
|
||||
Value: "1000",
|
||||
},
|
||||
{
|
||||
Key: "max_capacity",
|
||||
Value: "50",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
arrDataType, err := convertToArrowDataType(schema.Fields[0], false)
|
||||
assert.NoError(t, err)
|
||||
arrFields := make([]arrow.Field, 0)
|
||||
arrFields = append(arrFields, arrow.Field{
|
||||
Name: schema.Fields[0].Name,
|
||||
Type: arrDataType,
|
||||
Nullable: true,
|
||||
Metadata: arrow.Metadata{},
|
||||
})
|
||||
pqSchema := arrow.NewSchema(arrFields, nil)
|
||||
|
||||
filePath := fmt.Sprintf("/tmp/test_%d_reader.parquet", rand.Int())
|
||||
defer os.Remove(filePath)
|
||||
wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fw, err := pqarrow.NewFileWriter(pqSchema, wf,
|
||||
parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(100)), pqarrow.DefaultWriterProps())
|
||||
assert.NoError(t, err)
|
||||
|
||||
rowCount := 5
|
||||
nullPercent := 0
|
||||
if dataHasNull {
|
||||
nullPercent = 50
|
||||
}
|
||||
insertData, err := testutil.CreateInsertData(schema, rowCount, nullPercent)
|
||||
assert.NoError(t, err)
|
||||
columns, err := testutil.BuildArrayData(schema, insertData, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
recordBatch := array.NewRecord(pqSchema, columns, int64(rowCount))
|
||||
err = fw.Write(recordBatch)
|
||||
assert.NoError(t, err)
|
||||
fw.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
f := storage.NewChunkManagerFactory("local", objectstorage.RootPath(testOutputPath))
|
||||
cm, err := f.NewPersistentStorageChunkManager(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
schema.Fields[0].Nullable = readScehamIsNullable
|
||||
reader, err := NewReader(ctx, cm, schema, filePath, 64*1024*1024)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, reader)
|
||||
defer reader.Close()
|
||||
|
||||
_, err = reader.Read()
|
||||
if !readScehamIsNullable && dataHasNull {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
dataHasNull bool
|
||||
readScehamIsNullable bool
|
||||
dataType schemapb.DataType
|
||||
elementType schemapb.DataType
|
||||
}
|
||||
buildCaseFunc := func(dataHasNull bool, readScehamIsNullable bool, dataType schemapb.DataType, elementType schemapb.DataType) *testCase {
|
||||
name := fmt.Sprintf("dataHasNull='%v' schemaNullable='%v' dataType='%s' elementType='%s'",
|
||||
dataHasNull, readScehamIsNullable, dataType, elementType)
|
||||
return &testCase{
|
||||
name: name,
|
||||
dataHasNull: dataHasNull,
|
||||
readScehamIsNullable: readScehamIsNullable,
|
||||
dataType: dataType,
|
||||
elementType: elementType,
|
||||
}
|
||||
}
|
||||
cases := make([]*testCase, 0)
|
||||
|
||||
nullableDataTypes := []schemapb.DataType{
|
||||
schemapb.DataType_Bool,
|
||||
schemapb.DataType_Int8,
|
||||
schemapb.DataType_Int16,
|
||||
schemapb.DataType_Int32,
|
||||
schemapb.DataType_Int64,
|
||||
schemapb.DataType_Float,
|
||||
schemapb.DataType_Double,
|
||||
schemapb.DataType_VarChar,
|
||||
}
|
||||
for _, dataType := range nullableDataTypes {
|
||||
cases = append(cases, buildCaseFunc(true, true, dataType, schemapb.DataType_None))
|
||||
cases = append(cases, buildCaseFunc(true, false, dataType, schemapb.DataType_None))
|
||||
cases = append(cases, buildCaseFunc(false, true, dataType, schemapb.DataType_None))
|
||||
cases = append(cases, buildCaseFunc(false, true, dataType, schemapb.DataType_None))
|
||||
}
|
||||
|
||||
elementTypes := []schemapb.DataType{
|
||||
schemapb.DataType_Bool,
|
||||
schemapb.DataType_Int8,
|
||||
schemapb.DataType_Int16,
|
||||
schemapb.DataType_Int32,
|
||||
schemapb.DataType_Int64,
|
||||
schemapb.DataType_Float,
|
||||
schemapb.DataType_Double,
|
||||
schemapb.DataType_VarChar,
|
||||
}
|
||||
for _, elementType := range elementTypes {
|
||||
cases = append(cases, buildCaseFunc(true, true, schemapb.DataType_Array, elementType))
|
||||
cases = append(cases, buildCaseFunc(true, false, schemapb.DataType_Array, elementType))
|
||||
cases = append(cases, buildCaseFunc(false, true, schemapb.DataType_Array, elementType))
|
||||
cases = append(cases, buildCaseFunc(false, false, schemapb.DataType_Array, elementType))
|
||||
}
|
||||
|
||||
notNullableTypes := []schemapb.DataType{
|
||||
schemapb.DataType_JSON,
|
||||
schemapb.DataType_FloatVector,
|
||||
schemapb.DataType_BinaryVector,
|
||||
schemapb.DataType_SparseFloatVector,
|
||||
schemapb.DataType_Float16Vector,
|
||||
schemapb.DataType_BFloat16Vector,
|
||||
schemapb.DataType_Int8Vector,
|
||||
}
|
||||
for _, dataType := range notNullableTypes {
|
||||
cases = append(cases, buildCaseFunc(false, false, dataType, schemapb.DataType_None))
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
checkFunc(tt.dataHasNull, tt.readScehamIsNullable, tt.dataType, tt.elementType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTypeMismatch(t *testing.T) {
|
||||
checkFunc := func(srcDataType schemapb.DataType, srcElementType schemapb.DataType, dstDataType schemapb.DataType, dstElementType schemapb.DataType, nullalbe bool) {
|
||||
fieldName := "test_field"
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: fieldName,
|
||||
DataType: srcDataType,
|
||||
ElementType: srcElementType,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
{
|
||||
Key: "max_length",
|
||||
Value: "1000",
|
||||
},
|
||||
{
|
||||
Key: "max_capacity",
|
||||
Value: "50",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
arrDataType, err := convertToArrowDataType(schema.Fields[0], false)
|
||||
assert.NoError(t, err)
|
||||
arrFields := make([]arrow.Field, 0)
|
||||
arrFields = append(arrFields, arrow.Field{
|
||||
Name: schema.Fields[0].Name,
|
||||
Type: arrDataType,
|
||||
Nullable: true,
|
||||
Metadata: arrow.Metadata{},
|
||||
})
|
||||
pqSchema := arrow.NewSchema(arrFields, nil)
|
||||
|
||||
filePath := fmt.Sprintf("/tmp/test_%d_reader.parquet", rand.Int())
|
||||
defer os.Remove(filePath)
|
||||
wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fw, err := pqarrow.NewFileWriter(pqSchema, wf,
|
||||
parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(100)), pqarrow.DefaultWriterProps())
|
||||
assert.NoError(t, err)
|
||||
|
||||
rowCount := 5
|
||||
insertData, err := testutil.CreateInsertData(schema, rowCount, 0)
|
||||
assert.NoError(t, err)
|
||||
columns, err := testutil.BuildArrayData(schema, insertData, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
recordBatch := array.NewRecord(pqSchema, columns, int64(rowCount))
|
||||
err = fw.Write(recordBatch)
|
||||
assert.NoError(t, err)
|
||||
fw.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
f := storage.NewChunkManagerFactory("local", objectstorage.RootPath(testOutputPath))
|
||||
cm, err := f.NewPersistentStorageChunkManager(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
schema.Fields[0].DataType = dstDataType
|
||||
schema.Fields[0].ElementType = dstElementType
|
||||
schema.Fields[0].Nullable = nullalbe
|
||||
cmReader, err := cm.Reader(ctx, filePath)
|
||||
assert.NoError(t, err)
|
||||
reader, err := file.NewParquetReader(cmReader, file.WithReadProps(&parquet.ReaderProperties{
|
||||
BufferSize: 65535,
|
||||
BufferedStreamEnabled: true,
|
||||
}))
|
||||
assert.NoError(t, err)
|
||||
|
||||
readProps := pqarrow.ArrowReadProperties{
|
||||
BatchSize: int64(rowCount),
|
||||
}
|
||||
fileReader, err := pqarrow.NewFileReader(reader, readProps, memory.DefaultAllocator)
|
||||
assert.NoError(t, err)
|
||||
columnReader, err := NewFieldReader(ctx, fileReader, 0, schema.Fields[0])
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _, err = columnReader.Next(int64(rowCount))
|
||||
if srcDataType != dstDataType || srcElementType != dstElementType {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
srcDataType schemapb.DataType
|
||||
srcElementType schemapb.DataType
|
||||
dstDataType schemapb.DataType
|
||||
dstElementType schemapb.DataType
|
||||
nullable bool
|
||||
}
|
||||
buildCaseFunc := func(srcDataType schemapb.DataType, srcElementType schemapb.DataType, dstDataType schemapb.DataType, dstElementType schemapb.DataType, nullable bool) *testCase {
|
||||
name := fmt.Sprintf("srcDataType='%s' srcElementType='%s' dstDataType='%s' dstElementType='%s' nullable='%v'",
|
||||
srcDataType, srcElementType, dstDataType, dstElementType, nullable)
|
||||
return &testCase{
|
||||
name: name,
|
||||
srcDataType: srcDataType,
|
||||
srcElementType: srcElementType,
|
||||
dstDataType: dstDataType,
|
||||
dstElementType: dstElementType,
|
||||
nullable: nullable,
|
||||
}
|
||||
}
|
||||
cases := make([]*testCase, 0)
|
||||
|
||||
scalarDataTypes := []schemapb.DataType{
|
||||
schemapb.DataType_Bool,
|
||||
schemapb.DataType_Int8,
|
||||
schemapb.DataType_Int16,
|
||||
schemapb.DataType_Int32,
|
||||
schemapb.DataType_Int64,
|
||||
schemapb.DataType_Float,
|
||||
schemapb.DataType_Double,
|
||||
schemapb.DataType_VarChar,
|
||||
}
|
||||
for _, dataType := range scalarDataTypes {
|
||||
srcDataType := schemapb.DataType_Bool
|
||||
if dataType == schemapb.DataType_Bool {
|
||||
srcDataType = schemapb.DataType_Int8
|
||||
}
|
||||
cases = append(cases, buildCaseFunc(srcDataType, schemapb.DataType_None, dataType, schemapb.DataType_None, true))
|
||||
cases = append(cases, buildCaseFunc(srcDataType, schemapb.DataType_None, dataType, schemapb.DataType_None, false))
|
||||
}
|
||||
|
||||
elementTypes := []schemapb.DataType{
|
||||
schemapb.DataType_Bool,
|
||||
schemapb.DataType_Int8,
|
||||
schemapb.DataType_Int16,
|
||||
schemapb.DataType_Int32,
|
||||
schemapb.DataType_Int64,
|
||||
schemapb.DataType_Float,
|
||||
schemapb.DataType_Double,
|
||||
schemapb.DataType_VarChar,
|
||||
}
|
||||
for _, elementType := range elementTypes {
|
||||
srcElementType := schemapb.DataType_Bool
|
||||
if elementType == schemapb.DataType_Bool {
|
||||
srcElementType = schemapb.DataType_Int8
|
||||
}
|
||||
// element type mismatch
|
||||
cases = append(cases, buildCaseFunc(schemapb.DataType_Array, srcElementType, schemapb.DataType_Array, elementType, true))
|
||||
cases = append(cases, buildCaseFunc(schemapb.DataType_Array, srcElementType, schemapb.DataType_Array, elementType, false))
|
||||
// not a list
|
||||
cases = append(cases, buildCaseFunc(schemapb.DataType_Bool, schemapb.DataType_None, schemapb.DataType_Array, elementType, true))
|
||||
cases = append(cases, buildCaseFunc(schemapb.DataType_Bool, schemapb.DataType_None, schemapb.DataType_Array, elementType, false))
|
||||
}
|
||||
|
||||
notNullableTypes := []schemapb.DataType{
|
||||
schemapb.DataType_JSON,
|
||||
schemapb.DataType_FloatVector,
|
||||
schemapb.DataType_BinaryVector,
|
||||
schemapb.DataType_SparseFloatVector,
|
||||
schemapb.DataType_Float16Vector,
|
||||
schemapb.DataType_BFloat16Vector,
|
||||
schemapb.DataType_Int8Vector,
|
||||
}
|
||||
for _, dataType := range notNullableTypes {
|
||||
srcDataType := schemapb.DataType_Bool
|
||||
if dataType == schemapb.DataType_Bool {
|
||||
srcDataType = schemapb.DataType_Int8
|
||||
}
|
||||
// not a list
|
||||
cases = append(cases, buildCaseFunc(srcDataType, schemapb.DataType_None, dataType, schemapb.DataType_None, false))
|
||||
// element type mismatch
|
||||
cases = append(cases, buildCaseFunc(schemapb.DataType_Array, schemapb.DataType_Bool, dataType, schemapb.DataType_None, false))
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
checkFunc(tt.srcDataType, tt.srcElementType, tt.dstDataType, tt.dstElementType, tt.nullable)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestArrayNullElement(t *testing.T) {
|
||||
checkFunc := func(dataType schemapb.DataType, elementType schemapb.DataType) {
|
||||
fieldName := "test_field"
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: fieldName,
|
||||
DataType: dataType,
|
||||
ElementType: elementType,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: "dim",
|
||||
Value: "16",
|
||||
},
|
||||
{
|
||||
Key: "max_length",
|
||||
Value: "1000",
|
||||
},
|
||||
{
|
||||
Key: "max_capacity",
|
||||
Value: "50",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
arrDataType, err := convertToArrowDataType(schema.Fields[0], false)
|
||||
assert.NoError(t, err)
|
||||
arrFields := make([]arrow.Field, 0)
|
||||
arrFields = append(arrFields, arrow.Field{
|
||||
Name: schema.Fields[0].Name,
|
||||
Type: arrDataType,
|
||||
Nullable: true,
|
||||
Metadata: arrow.Metadata{},
|
||||
})
|
||||
pqSchema := arrow.NewSchema(arrFields, nil)
|
||||
|
||||
filePath := fmt.Sprintf("/tmp/test_%d_reader.parquet", rand.Int())
|
||||
defer os.Remove(filePath)
|
||||
wf, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE, 0o666)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fw, err := pqarrow.NewFileWriter(pqSchema, wf,
|
||||
parquet.NewWriterProperties(parquet.WithMaxRowGroupLength(100)), pqarrow.DefaultWriterProps())
|
||||
assert.NoError(t, err)
|
||||
|
||||
mem := memory.NewGoAllocator()
|
||||
columns := make([]arrow.Array, 0, len(schema.Fields))
|
||||
switch elementType {
|
||||
case schemapb.DataType_Bool:
|
||||
builder := array.NewListBuilder(mem, &arrow.BooleanType{})
|
||||
valueBuilder := builder.ValueBuilder().(*array.BooleanBuilder)
|
||||
valueBuilder.AppendValues([]bool{true, false}, []bool{true, false})
|
||||
builder.AppendValues([]int32{0}, []bool{true})
|
||||
columns = append(columns, builder.NewListArray())
|
||||
case schemapb.DataType_Int8:
|
||||
builder := array.NewListBuilder(mem, &arrow.Int8Type{})
|
||||
valueBuilder := builder.ValueBuilder().(*array.Int8Builder)
|
||||
valueBuilder.AppendValues([]int8{1, 2}, []bool{true, false})
|
||||
builder.AppendValues([]int32{0}, []bool{true})
|
||||
columns = append(columns, builder.NewListArray())
|
||||
case schemapb.DataType_Int16:
|
||||
builder := array.NewListBuilder(mem, &arrow.Int16Type{})
|
||||
valueBuilder := builder.ValueBuilder().(*array.Int16Builder)
|
||||
valueBuilder.AppendValues([]int16{1, 2}, []bool{true, false})
|
||||
builder.AppendValues([]int32{0}, []bool{true})
|
||||
columns = append(columns, builder.NewListArray())
|
||||
case schemapb.DataType_Int32:
|
||||
builder := array.NewListBuilder(mem, &arrow.Int32Type{})
|
||||
valueBuilder := builder.ValueBuilder().(*array.Int32Builder)
|
||||
valueBuilder.AppendValues([]int32{1, 2}, []bool{true, false})
|
||||
builder.AppendValues([]int32{0}, []bool{true})
|
||||
columns = append(columns, builder.NewListArray())
|
||||
case schemapb.DataType_Int64:
|
||||
builder := array.NewListBuilder(mem, &arrow.Int64Type{})
|
||||
valueBuilder := builder.ValueBuilder().(*array.Int64Builder)
|
||||
valueBuilder.AppendValues([]int64{1, 2}, []bool{true, false})
|
||||
builder.AppendValues([]int32{0}, []bool{true})
|
||||
columns = append(columns, builder.NewListArray())
|
||||
case schemapb.DataType_Float:
|
||||
builder := array.NewListBuilder(mem, &arrow.Float32Type{})
|
||||
valueBuilder := builder.ValueBuilder().(*array.Float32Builder)
|
||||
valueBuilder.AppendValues([]float32{0.1, 0.2}, []bool{true, false})
|
||||
builder.AppendValues([]int32{0}, []bool{true})
|
||||
columns = append(columns, builder.NewListArray())
|
||||
case schemapb.DataType_Double:
|
||||
builder := array.NewListBuilder(mem, &arrow.Float64Type{})
|
||||
valueBuilder := builder.ValueBuilder().(*array.Float64Builder)
|
||||
valueBuilder.AppendValues([]float64{0.1, 0.2}, []bool{true, false})
|
||||
builder.AppendValues([]int32{0}, []bool{true})
|
||||
columns = append(columns, builder.NewListArray())
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
builder := array.NewListBuilder(mem, &arrow.StringType{})
|
||||
valueBuilder := builder.ValueBuilder().(*array.StringBuilder)
|
||||
valueBuilder.AppendValues([]string{"a", "b"}, []bool{true, false})
|
||||
builder.AppendValues([]int32{0}, []bool{true})
|
||||
columns = append(columns, builder.NewListArray())
|
||||
default:
|
||||
break
|
||||
}
|
||||
|
||||
recordBatch := array.NewRecord(pqSchema, columns, int64(1))
|
||||
err = fw.Write(recordBatch)
|
||||
assert.NoError(t, err)
|
||||
fw.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
f := storage.NewChunkManagerFactory("local", objectstorage.RootPath(testOutputPath))
|
||||
cm, err := f.NewPersistentStorageChunkManager(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
reader, err := NewReader(ctx, cm, schema, filePath, 64*1024*1024)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, reader)
|
||||
defer reader.Close()
|
||||
|
||||
_, err = reader.Read()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
dataType schemapb.DataType
|
||||
elementType schemapb.DataType
|
||||
}
|
||||
buildCaseFunc := func(dataType schemapb.DataType, elementType schemapb.DataType) *testCase {
|
||||
name := fmt.Sprintf("dataType='%s' elementType='%s'", dataType, elementType)
|
||||
return &testCase{
|
||||
name: name,
|
||||
dataType: dataType,
|
||||
elementType: elementType,
|
||||
}
|
||||
}
|
||||
cases := make([]*testCase, 0)
|
||||
|
||||
elementTypes := []schemapb.DataType{
|
||||
schemapb.DataType_Bool,
|
||||
schemapb.DataType_Int8,
|
||||
schemapb.DataType_Int16,
|
||||
schemapb.DataType_Int32,
|
||||
schemapb.DataType_Int64,
|
||||
schemapb.DataType_Float,
|
||||
schemapb.DataType_Double,
|
||||
schemapb.DataType_VarChar,
|
||||
}
|
||||
for _, elementType := range elementTypes {
|
||||
cases = append(cases, buildCaseFunc(schemapb.DataType_Array, elementType))
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
checkFunc(tt.dataType, tt.elementType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -36,10 +36,32 @@ const (
|
||||
sparseVectorValues = "values"
|
||||
)
|
||||
|
||||
func WrapTypeErr(expect string, actual string, field *schemapb.FieldSchema) error {
|
||||
func WrapTypeErr(expect *schemapb.FieldSchema, actual string) error {
|
||||
nullable := ""
|
||||
if expect.GetNullable() {
|
||||
nullable = "nullable"
|
||||
}
|
||||
elementType := ""
|
||||
if expect.GetDataType() == schemapb.DataType_Array {
|
||||
elementType = expect.GetElementType().String()
|
||||
}
|
||||
// error message examples:
|
||||
// "expect 'Int32' type for field 'xxx', but got 'bool' type"
|
||||
// "expect nullable 'Int32 Array' type for field 'xxx', but got 'bool' type"
|
||||
// "expect 'FloatVector' type for field 'xxx', but got 'bool' type"
|
||||
return merr.WrapErrImportFailed(
|
||||
fmt.Sprintf("expect '%s' type for field '%s', but got '%s' type",
|
||||
expect, field.GetName(), actual))
|
||||
fmt.Sprintf("expect %s '%s %s' type for field '%s', but got '%s' type",
|
||||
nullable, elementType, expect.GetDataType().String(), expect.GetName(), actual))
|
||||
}
|
||||
|
||||
func WrapNullRowErr(field *schemapb.FieldSchema) error {
|
||||
return merr.WrapErrImportFailed(
|
||||
fmt.Sprintf("the field '%s' is not nullable but the file contains null value", field.GetName()))
|
||||
}
|
||||
|
||||
func WrapNullElementErr(field *schemapb.FieldSchema) error {
|
||||
return merr.WrapErrImportFailed(
|
||||
fmt.Sprintf("array element is not allowed to be null value for field '%s'", field.GetName()))
|
||||
}
|
||||
|
||||
func CreateFieldReaders(ctx context.Context, fileReader *pqarrow.FileReader, schema *schemapb.CollectionSchema) (map[int64]*FieldReader, error) {
|
||||
@ -323,8 +345,8 @@ func isSchemaEqual(schema *schemapb.CollectionSchema, arrSchema *arrow.Schema) e
|
||||
return err
|
||||
}
|
||||
if !isArrowDataTypeConvertible(arrField.Type, toArrDataType, field) {
|
||||
return merr.WrapErrImportFailed(fmt.Sprintf("field '%s' type mis-match, milvus data type '%s', arrow data type get '%s'",
|
||||
field.Name, field.DataType.String(), arrField.Type.String()))
|
||||
return merr.WrapErrImportFailed(fmt.Sprintf("field '%s' type mis-match, expect arrow type '%s', get arrow data type '%s'",
|
||||
field.Name, toArrDataType.String(), arrField.Type.String()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user