diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index 8ce6c57fc8..3594f16792 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -3,6 +3,7 @@ package proxy import ( "fmt" "math" + "reflect" "go.uber.org/zap" @@ -364,6 +365,72 @@ func (v *validateUtil) checkIntegerFieldData(field *schemapb.FieldData, fieldSch return nil } +func (v *validateUtil) checkArrayElement(array *schemapb.ArrayArray, field *schemapb.FieldSchema) error { + switch field.GetElementType() { + case schemapb.DataType_Bool: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_BoolData)(nil)) { + return merr.WrapErrParameterInvalid("bool array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_IntData)(nil)) { + return merr.WrapErrParameterInvalid("int array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + if v.checkOverflow { + if field.GetElementType() == schemapb.DataType_Int8 { + if err := verifyOverflowByRange(row.GetIntData().GetData(), math.MinInt8, math.MaxInt8); err != nil { + return err + } + } + if field.GetElementType() == schemapb.DataType_Int16 { + if err := verifyOverflowByRange(row.GetIntData().GetData(), math.MinInt16, math.MaxInt16); err != nil { + return err + } + } + } + } + case schemapb.DataType_Int64: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_LongData)(nil)) { + return merr.WrapErrParameterInvalid("int64 array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + case schemapb.DataType_Float: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_FloatData)(nil)) { + return merr.WrapErrParameterInvalid("float array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + case schemapb.DataType_Double: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_DoubleData)(nil)) { + return merr.WrapErrParameterInvalid("double array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + case schemapb.DataType_VarChar, schemapb.DataType_String: + for _, row := range array.GetData() { + actualType := reflect.TypeOf(row.GetData()) + if actualType != reflect.TypeOf((*schemapb.ScalarField_StringData)(nil)) { + return merr.WrapErrParameterInvalid("string array", + fmt.Sprintf("%s array", actualType.String()), "insert data does not match") + } + } + } + return nil +} + func (v *validateUtil) checkArrayFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { data := field.GetScalars().GetArrayData() if data == nil { @@ -390,7 +457,7 @@ func (v *validateUtil) checkArrayFieldData(field *schemapb.FieldData, fieldSchem } } } - return nil + return v.checkArrayElement(data, fieldSchema) } func verifyLengthPerRow[E interface{ ~string | ~[]byte }](strArr []E, maxLength int64) error { diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go index 9b43d26de0..1bd2ed9a7d 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/validate_util_test.go @@ -1301,6 +1301,434 @@ func Test_validateUtil_Validate(t *testing.T) { assert.Error(t, err) }) + t.Run("element type not match", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{true, false}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Bool, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + v := newValidateUtil(withMaxCapCheck()) + err := v.Validate(data, schema, 1) + assert.Error(t, err) + + data = []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int8, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int16, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + data = []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Float, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + data = []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_DoubleData{ + DoubleData: &schemapb.DoubleArray{ + Data: []float64{1, 2, 3}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Double, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + data = []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"a", "b", "c"}, + }, + }, + }, + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + data = []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + { + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int64, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err = newValidateUtil(withMaxCapCheck()).Validate(data, schema, 1) + assert.Error(t, err) + }) + + t.Run("array element overflow", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3, 1 << 9}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int8, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err := newValidateUtil(withMaxCapCheck(), withOverflowCheck()).Validate(data, schema, 1) + assert.Error(t, err) + + data = []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{1, 2, 3, 1 << 9, 1 << 17}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + schema = &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int16, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + }, + }, + } + + err = newValidateUtil(withMaxCapCheck(), withOverflowCheck()).Validate(data, schema, 1) + assert.Error(t, err) + }) + t.Run("normal case", func(t *testing.T) { data := []*schemapb.FieldData{ {