From d896e3119e95235e5d67e1b4baa070609e4d9e41 Mon Sep 17 00:00:00 2001
From: dragondriver
Date: Tue, 13 Jul 2021 16:33:55 +0800
Subject: [PATCH] Fix 6419, check if num_rows is greater than zero in proxy (#6439)

Signed-off-by: dragondriver
---
 internal/proxy/error.go      |  48 +++++
 internal/proxy/task.go       | 143 ++++++++++++-
 internal/proxy/task_test.go  | 388 +++++++++++++++++++++++++++++++++++
 internal/proxy/test_utils.go | 197 ++++++++++++++++++
 4 files changed, 773 insertions(+), 3 deletions(-)
 create mode 100644 internal/proxy/error.go
 create mode 100644 internal/proxy/task_test.go
 create mode 100644 internal/proxy/test_utils.go

diff --git a/internal/proxy/error.go b/internal/proxy/error.go
new file mode 100644
index 0000000000..a15c786569
--- /dev/null
+++ b/internal/proxy/error.go
@@ -0,0 +1,48 @@
+package proxy
+
+import (
+	"errors"
+	"fmt"
+
+	"github.com/milvus-io/milvus/internal/proto/schemapb"
+)
+
+// TODO(dragondriver): add more common error type
+
+func errInvalidNumRows(numRows uint32) error {
+	return fmt.Errorf("invalid num_rows: %d", numRows)
+}
+
+func errNumRowsLessThanOrEqualToZero(numRows uint32) error {
+	return fmt.Errorf("num_rows(%d) should be greater than 0", numRows)
+}
+
+func errNumRowsOfFieldDataMismatchPassed(idx int, fieldNumRows, passedNumRows uint32) error {
+	return fmt.Errorf("the num_rows(%d) of %dth field is not equal to passed NumRows(%d)", fieldNumRows, idx, passedNumRows)
+}
+
+var errEmptyFieldData = errors.New("empty field data")
+
+func errFieldsLessThanNeeded(fieldsNum, needed int) error {
+	return fmt.Errorf("the length(%d) of passed fields is less than needed(%d)", fieldsNum, needed)
+}
+
+func errUnsupportedDataType(dType schemapb.DataType) error {
+	return fmt.Errorf("%v is not supported now", dType)
+}
+
+func errUnsupportedDType(dType string) error {
+	return fmt.Errorf("%s is not supported now", dType)
+}
+
+func errInvalidDim(dim int) error {
+	return fmt.Errorf("invalid dim: %d", dim)
+}
+
+func errDimLessThanOrEqualToZero(dim int) error {
+	return fmt.Errorf("dim(%d) should be greater than 0", dim)
+}
+
+func errDimShouldDivide8(dim int) error {
+	return fmt.Errorf("dim(%d) should divide 8", dim)
+}
diff --git a/internal/proxy/task.go b/internal/proxy/task.go
index d60ccb5862..c7be5f607a 100644
--- a/internal/proxy/task.go
+++ b/internal/proxy/task.go
@@ -197,6 +197,133 @@ func (it *InsertTask) OnEnqueue() error {
 	return nil
 }
 
+func getNumRowsOfScalarField(datas interface{}) uint32 {
+	realTypeDatas := reflect.ValueOf(datas)
+	return uint32(realTypeDatas.Len())
+}
+
+func getNumRowsOfFloatVectorField(fDatas []float32, dim int64) (uint32, error) {
+	if dim <= 0 {
+		return 0, errDimLessThanOrEqualToZero(int(dim))
+	}
+	l := len(fDatas)
+	if int64(l)%dim != 0 {
+		return 0, fmt.Errorf("the length(%d) of float data should divide the dim(%d)", l, dim)
+	}
+	return uint32(int(int64(l) / dim)), nil
+}
+
+func getNumRowsOfBinaryVectorField(bDatas []byte, dim int64) (uint32, error) {
+	if dim <= 0 {
+		return 0, errDimLessThanOrEqualToZero(int(dim))
+	}
+	if dim%8 != 0 {
+		return 0, errDimShouldDivide8(int(dim))
+	}
+	l := len(bDatas)
+	if (8*int64(l))%dim != 0 {
+		return 0, fmt.Errorf("the num(%d) of all bits should divide the dim(%d)", 8*l, dim)
+	}
+	return uint32(int((8 * int64(l)) / dim)), nil
+}
+
+func (it *InsertTask) checkLengthOfFieldsData() error {
+	neededFieldsNum := 0
+	for _, field := range it.schema.Fields {
+		if !field.AutoID {
+			neededFieldsNum++
+		}
+	}
+
+	if len(it.req.FieldsData) < neededFieldsNum {
+		return errFieldsLessThanNeeded(len(it.req.FieldsData), neededFieldsNum)
+	}
+
+	return nil
+}
+
+func (it *InsertTask) checkRowNums() error {
+	if it.req.NumRows <= 0 {
+		return errNumRowsLessThanOrEqualToZero(it.req.NumRows)
+	}
+
+	if err := it.checkLengthOfFieldsData(); err != nil {
+		return err
+	}
+
+	rowNums := it.req.NumRows
+
+	for i, field := range it.req.FieldsData {
+		switch field.Field.(type) {
+		case *schemapb.FieldData_Scalars:
+			scalarField := field.GetScalars()
+			switch scalarField.Data.(type) {
+			case *schemapb.ScalarField_BoolData:
+				fieldNumRows := getNumRowsOfScalarField(scalarField.GetBoolData().Data)
+				if fieldNumRows != rowNums {
+					return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
+				}
+			case *schemapb.ScalarField_IntData:
+				fieldNumRows := getNumRowsOfScalarField(scalarField.GetIntData().Data)
+				if fieldNumRows != rowNums {
+					return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
+				}
+			case *schemapb.ScalarField_LongData:
+				fieldNumRows := getNumRowsOfScalarField(scalarField.GetLongData().Data)
+				if fieldNumRows != rowNums {
+					return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
+				}
+			case *schemapb.ScalarField_FloatData:
+				fieldNumRows := getNumRowsOfScalarField(scalarField.GetFloatData().Data)
+				if fieldNumRows != rowNums {
+					return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
+				}
+			case *schemapb.ScalarField_DoubleData:
+				fieldNumRows := getNumRowsOfScalarField(scalarField.GetDoubleData().Data)
+				if fieldNumRows != rowNums {
+					return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
+				}
+			case *schemapb.ScalarField_BytesData:
+				return errUnsupportedDType("bytes")
+			case *schemapb.ScalarField_StringData:
+				return errUnsupportedDType("string")
+			case nil:
+				continue
+			default:
+				continue
+			}
+		case *schemapb.FieldData_Vectors:
+			vectorField := field.GetVectors()
+			switch vectorField.Data.(type) {
+			case *schemapb.VectorField_FloatVector:
+				dim := vectorField.GetDim()
+				fieldNumRows, err := getNumRowsOfFloatVectorField(vectorField.GetFloatVector().Data, dim)
+				if err != nil {
+					return err
+				}
+				if fieldNumRows != rowNums {
+					return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
+				}
+			case *schemapb.VectorField_BinaryVector:
+				dim := vectorField.GetDim()
+				fieldNumRows, err := getNumRowsOfBinaryVectorField(vectorField.GetBinaryVector(), dim)
+				if err != nil {
+					return err
+				}
+				if fieldNumRows != rowNums {
+					return errNumRowsOfFieldDataMismatchPassed(i, fieldNumRows, rowNums)
+				}
+			case nil:
+				continue
+			default:
+				continue
+			}
+		}
+	}
+
+	return nil
+}
+
 // TODO(dragondriver): ignore the order of fields in request, use the order of CollectionSchema to reorganize data
 func (it *InsertTask) transferColumnBasedRequestToRowBasedData() error {
 	dTypes := make([]schemapb.DataType, 0, len(it.req.FieldsData))
@@ -441,11 +568,16 @@ func (it *InsertTask) transferColumnBasedRequestToRowBasedData() error {
 
 func (it *InsertTask) checkFieldAutoID() error {
 	// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
-	rowNums := it.req.NumRows
-	if len(it.req.FieldsData) == 0 || rowNums == 0 {
-		return fmt.Errorf("do not contain any data")
+	if it.req.NumRows <= 0 {
+		return errNumRowsLessThanOrEqualToZero(it.req.NumRows)
 	}
 
+	if err := it.checkLengthOfFieldsData(); err != nil {
+		return err
+	}
+
+	rowNums := it.req.NumRows
+
 	primaryFieldName := ""
 	autoIDFieldName := ""
 	autoIDLoc := -1
@@ -611,6 +743,11 @@ func (it *InsertTask) PreExecute(ctx context.Context) error {
 	}
 	it.schema = collSchema
 
+	err = it.checkRowNums()
+	if err != nil {
+		return err
+	}
+
 	err = it.checkFieldAutoID()
 	if err != nil {
 		return err
diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go
new file mode 100644
index 0000000000..0bd7469638
--- /dev/null
+++ b/internal/proxy/task_test.go
@@ -0,0 +1,388 @@
+package proxy
+
+import (
+	"testing"
+
+	"github.com/milvus-io/milvus/internal/proto/milvuspb"
+
+	"github.com/milvus-io/milvus/internal/proto/schemapb"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGetNumRowsOfScalarField(t *testing.T) {
+	cases := []struct {
+		datas interface{}
+		want  uint32
+	}{
+		{[]bool{}, 0},
+		{[]bool{true, false}, 2},
+		{[]int32{}, 0},
+		{[]int32{1, 2}, 2},
+		{[]int64{}, 0},
+		{[]int64{1, 2}, 2},
+		{[]float32{}, 0},
+		{[]float32{1.0, 2.0}, 2},
+		{[]float64{}, 0},
+		{[]float64{1.0, 2.0}, 2},
+	}
+
+	for _, test := range cases {
+		if got := getNumRowsOfScalarField(test.datas); got != test.want {
+			t.Errorf("getNumRowsOfScalarField(%v) = %v, want %v", test.datas, got, test.want)
+		}
+	}
+}
+
+func TestGetNumRowsOfFloatVectorField(t *testing.T) {
+	cases := []struct {
+		fDatas   []float32
+		dim      int64
+		want     uint32
+		errIsNil bool
+	}{
+		{[]float32{}, -1, 0, false}, // dim <= 0
+		{[]float32{}, 0, 0, false},  // dim <= 0
+		{[]float32{1.0}, 128, 0, false}, // length % dim != 0
+		{[]float32{}, 128, 0, true},
+		{[]float32{1.0, 2.0}, 2, 1, true},
+		{[]float32{1.0, 2.0, 3.0, 4.0}, 2, 2, true},
+	}
+
+	for _, test := range cases {
+		got, err := getNumRowsOfFloatVectorField(test.fDatas, test.dim)
+		if test.errIsNil {
+			assert.Equal(t, nil, err)
+			if got != test.want {
+				t.Errorf("getNumRowsOfFloatVectorField(%v, %v) = %v, want %v", test.fDatas, test.dim, got, test.want)
+			}
+		} else {
+			assert.NotEqual(t, nil, err)
+		}
+	}
+}
+
+func TestGetNumRowsOfBinaryVectorField(t *testing.T) {
+	cases := []struct {
+		bDatas   []byte
+		dim      int64
+		want     uint32
+		errIsNil bool
+	}{
+		{[]byte{}, -1, 0, false}, // dim <= 0
+		{[]byte{}, 0, 0, false},  // dim <= 0
+		{[]byte{1.0}, 128, 0, false}, // length % dim != 0
+		{[]byte{}, 128, 0, true},
+		{[]byte{1.0}, 1, 0, false}, // dim % 8 != 0
+		{[]byte{1.0}, 4, 0, false}, // dim % 8 != 0
+		{[]byte{1.0, 2.0}, 8, 2, true},
+		{[]byte{1.0, 2.0}, 16, 1, true},
+		{[]byte{1.0, 2.0, 3.0, 4.0}, 8, 4, true},
+		{[]byte{1.0, 2.0, 3.0, 4.0}, 16, 2, true},
+		{[]byte{1.0}, 128, 0, false}, // (8*l) % dim != 0
+	}
+
+	for _, test := range cases {
+		got, err := getNumRowsOfBinaryVectorField(test.bDatas, test.dim)
+		if test.errIsNil {
+			assert.Equal(t, nil, err)
+			if got != test.want {
+				t.Errorf("getNumRowsOfBinaryVectorField(%v, %v) = %v, want %v", test.bDatas, test.dim, got, test.want)
+			}
+		} else {
+			assert.NotEqual(t, nil, err)
+		}
+	}
+}
+
+func TestInsertTask_checkLengthOfFieldsData(t *testing.T) {
+	var err error
+
+	// schema is empty, though this won't happen in a real system
+	case1 := InsertTask{
+		schema: &schemapb.CollectionSchema{
+			Name:        "TestInsertTask_checkLengthOfFieldsData",
+			Description: "TestInsertTask_checkLengthOfFieldsData",
+			AutoID:      false,
+			Fields:      []*schemapb.FieldSchema{},
+		},
+		req: &milvuspb.InsertRequest{
+			DbName:         "TestInsertTask_checkLengthOfFieldsData",
+			CollectionName: "TestInsertTask_checkLengthOfFieldsData",
+			PartitionName:  "TestInsertTask_checkLengthOfFieldsData",
+			FieldsData:     nil,
+		},
+	}
+	err = case1.checkLengthOfFieldsData()
+	assert.Equal(t, nil, err)
+
+	// schema has two fields, neither of them is autoID
+	case2 := InsertTask{
+		schema: &schemapb.CollectionSchema{
+			Name:        "TestInsertTask_checkLengthOfFieldsData",
+			Description: "TestInsertTask_checkLengthOfFieldsData",
+			AutoID:      false,
+			Fields: []*schemapb.FieldSchema{
+				{
+					AutoID:   false,
+					DataType: schemapb.DataType_Int64,
+				},
+				{
+					AutoID:   false,
+					DataType: schemapb.DataType_Int64,
+				},
+			},
+		},
+	}
+	// passed fields is empty
+	case2.req = &milvuspb.InsertRequest{}
+	err = case2.checkLengthOfFieldsData()
+	assert.NotEqual(t, nil, err)
+	// the num of passed fields is less than needed
+	case2.req = &milvuspb.InsertRequest{
+		FieldsData: []*schemapb.FieldData{
+			{
+				Type: schemapb.DataType_Int64,
+			},
+		},
+	}
+	err = case2.checkLengthOfFieldsData()
+	assert.NotEqual(t, nil, err)
+	// satisfied
+	case2.req = &milvuspb.InsertRequest{
+		FieldsData: []*schemapb.FieldData{
+			{
+				Type: schemapb.DataType_Int64,
+			},
+			{
+				Type: schemapb.DataType_Int64,
+			},
+		},
+	}
+	err = case2.checkLengthOfFieldsData()
+	assert.Equal(t, nil, err)
+
+	// schema has two fields, one of them is autoID
+	case3 := InsertTask{
+		schema: &schemapb.CollectionSchema{
+			Name:        "TestInsertTask_checkLengthOfFieldsData",
+			Description: "TestInsertTask_checkLengthOfFieldsData",
+			AutoID:      false,
+			Fields: []*schemapb.FieldSchema{
+				{
+					AutoID:   true,
+					DataType: schemapb.DataType_Int64,
+				},
+				{
+					AutoID:   false,
+					DataType: schemapb.DataType_Int64,
+				},
+			},
+		},
+	}
+	// passed fields is empty
+	case3.req = &milvuspb.InsertRequest{}
+	err = case3.checkLengthOfFieldsData()
+	assert.NotEqual(t, nil, err)
+	// satisfied
+	case3.req = &milvuspb.InsertRequest{
+		FieldsData: []*schemapb.FieldData{
+			{
+				Type: schemapb.DataType_Int64,
+			},
+		},
+	}
+	err = case3.checkLengthOfFieldsData()
+	assert.Equal(t, nil, err)
+
+	// schema has one field which is autoID
+	case4 := InsertTask{
+		schema: &schemapb.CollectionSchema{
+			Name:        "TestInsertTask_checkLengthOfFieldsData",
+			Description: "TestInsertTask_checkLengthOfFieldsData",
+			AutoID:      false,
+			Fields: []*schemapb.FieldSchema{
+				{
+					AutoID:   true,
+					DataType: schemapb.DataType_Int64,
+				},
+			},
+		},
+	}
+	// passed fields is empty
+	// satisfied
+	case4.req = &milvuspb.InsertRequest{}
+	err = case4.checkLengthOfFieldsData()
+	assert.Equal(t, nil, err)
+}
+
+func TestInsertTask_checkRowNums(t *testing.T) {
+	var err error
+
+	// passed NumRows is zero, so not greater than 0
+	case1 := InsertTask{
+		req: &milvuspb.InsertRequest{
+			NumRows: 0,
+		},
+	}
+	err = case1.checkRowNums()
+	assert.NotEqual(t, nil, err)
+
+	// checkLengthOfFieldsData was already checked by TestInsertTask_checkLengthOfFieldsData
+
+	numRows := 20
+	dim := 128
+	case2 := InsertTask{
+		schema: &schemapb.CollectionSchema{
+			Name:        "TestInsertTask_checkRowNums",
+			Description: "TestInsertTask_checkRowNums",
+			AutoID:      false,
+			Fields: []*schemapb.FieldSchema{
+				{DataType: schemapb.DataType_Bool},
+				{DataType: schemapb.DataType_Int8},
+				{DataType: schemapb.DataType_Int16},
+				{DataType: schemapb.DataType_Int32},
+				{DataType: schemapb.DataType_Int64},
+				{DataType: schemapb.DataType_Float},
+				{DataType: schemapb.DataType_Double},
+				{DataType: schemapb.DataType_FloatVector},
+				{DataType: schemapb.DataType_BinaryVector},
+			},
+		},
+	}
+
+	// satisfied
+	case2.req = &milvuspb.InsertRequest{
+		NumRows: uint32(numRows),
+		FieldsData: []*schemapb.FieldData{
+			newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows),
+			newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows),
+			newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows),
+			newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows),
+			newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows),
+			newScalarFieldData(schemapb.DataType_Float, "Float", numRows),
+			newScalarFieldData(schemapb.DataType_Double, "Double", numRows),
+			newFloatVectorFieldData("FloatVector", numRows, dim),
+			newBinaryVectorFieldData("BinaryVector", numRows, dim),
+		},
+	}
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+
+	// less bool data
+	case2.req.FieldsData[0] = newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows/2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// more bool data
+	case2.req.FieldsData[0] = newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows*2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// revert
+	case2.req.FieldsData[0] = newScalarFieldData(schemapb.DataType_Bool, "Bool", numRows)
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+
+	// less int8 data
+	case2.req.FieldsData[1] = newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows/2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// more int8 data
+	case2.req.FieldsData[1] = newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows*2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// revert
+	case2.req.FieldsData[1] = newScalarFieldData(schemapb.DataType_Int8, "Int8", numRows)
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+
+	// less int16 data
+	case2.req.FieldsData[2] = newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows/2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// more int16 data
+	case2.req.FieldsData[2] = newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows*2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// revert
+	case2.req.FieldsData[2] = newScalarFieldData(schemapb.DataType_Int16, "Int16", numRows)
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+
+	// less int32 data
+	case2.req.FieldsData[3] = newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows/2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// more int32 data
+	case2.req.FieldsData[3] = newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows*2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// revert
+	case2.req.FieldsData[3] = newScalarFieldData(schemapb.DataType_Int32, "Int32", numRows)
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+
+	// less int64 data
+	case2.req.FieldsData[4] = newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows/2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// more int64 data
+	case2.req.FieldsData[4] = newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows*2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// revert
+	case2.req.FieldsData[4] = newScalarFieldData(schemapb.DataType_Int64, "Int64", numRows)
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+
+	// less float data
+	case2.req.FieldsData[5] = newScalarFieldData(schemapb.DataType_Float, "Float", numRows/2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// more float data
+	case2.req.FieldsData[5] = newScalarFieldData(schemapb.DataType_Float, "Float", numRows*2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// revert
+	case2.req.FieldsData[5] = newScalarFieldData(schemapb.DataType_Float, "Float", numRows)
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+
+	// less double data
+	case2.req.FieldsData[6] = newScalarFieldData(schemapb.DataType_Double, "Double", numRows/2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// more double data
+	case2.req.FieldsData[6] = newScalarFieldData(schemapb.DataType_Double, "Double", numRows*2)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// revert
+	case2.req.FieldsData[6] = newScalarFieldData(schemapb.DataType_Double, "Double", numRows)
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+
+	// less float vectors
+	case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// more float vectors
+	case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// revert
+	case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+
+	// less binary vectors
+	case2.req.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// more binary vectors
+	case2.req.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
+	err = case2.checkRowNums()
+	assert.NotEqual(t, nil, err)
+	// revert
+	case2.req.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
+	err = case2.checkRowNums()
+	assert.Equal(t, nil, err)
+}
diff --git a/internal/proxy/test_utils.go b/internal/proxy/test_utils.go
new file mode 100644
index 0000000000..bd360cd215
--- /dev/null
+++ b/internal/proxy/test_utils.go
@@ -0,0 +1,197 @@
+package proxy
+
+import (
+	"math/rand"
+
+	"github.com/milvus-io/milvus/internal/proto/schemapb"
+)
+
+func generateBoolArray(numRows int) []bool {
+	ret := make([]bool, 0, numRows)
+	for i := 0; i < numRows; i++ {
+		ret = append(ret, rand.Int()%2 == 0)
+	}
+	return ret
+}
+
+func generateInt8Array(numRows int) []int8 {
+	ret := make([]int8, 0, numRows)
+	for i := 0; i < numRows; i++ {
+		ret = append(ret, int8(rand.Int()))
+	}
+	return ret
+}
+
+func generateInt16Array(numRows int) []int16 {
+	ret := make([]int16, 0, numRows)
+	for i := 0; i < numRows; i++ {
+		ret = append(ret, int16(rand.Int()))
+	}
+	return ret
+}
+
+func generateInt32Array(numRows int) []int32 {
+	ret := make([]int32, 0, numRows)
+	for i := 0; i < numRows; i++ {
+		ret = append(ret, int32(rand.Int()))
+	}
+	return ret
+}
+
+func generateInt64Array(numRows int) []int64 {
+	ret := make([]int64, 0, numRows)
+	for i := 0; i < numRows; i++ {
+		ret = append(ret, int64(rand.Int()))
+	}
+	return ret
+}
+
+func generateFloat32Array(numRows int) []float32 {
+	ret := make([]float32, 0, numRows)
+	for i := 0; i < numRows; i++ {
+		ret = append(ret, rand.Float32())
+	}
+	return ret
+}
+
+func generateFloat64Array(numRows int) []float64 {
+	ret := make([]float64, 0, numRows)
+	for i := 0; i < numRows; i++ {
+		ret = append(ret, rand.Float64())
+	}
+	return ret
+}
+
+func generateFloatVectors(numRows, dim int) []float32 {
+	total := numRows * dim
+	ret := make([]float32, 0, total)
+	for i := 0; i < total; i++ {
+		ret = append(ret, rand.Float32())
+	}
+	return ret
+}
+
+func generateBinaryVectors(numRows, dim int) []byte {
+	total := (numRows * dim) / 8
+	ret := make([]byte, total)
+	_, err := rand.Read(ret)
+	if err != nil {
+		panic(err)
+	}
+	return ret
+}
+
+func newScalarFieldData(dType schemapb.DataType, fieldName string, numRows int) *schemapb.FieldData {
+	ret := &schemapb.FieldData{
+		Type:      dType,
+		FieldName: fieldName,
+		Field:     nil,
+	}
+
+	switch dType {
+	case schemapb.DataType_Bool:
+		ret.Field = &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_BoolData{
+					BoolData: &schemapb.BoolArray{
+						Data: generateBoolArray(numRows),
+					},
+				},
+			},
+		}
+	case schemapb.DataType_Int8:
+		ret.Field = &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_IntData{
+					IntData: &schemapb.IntArray{
+						Data: generateInt32Array(numRows),
+					},
+				},
+			},
+		}
+	case schemapb.DataType_Int16:
+		ret.Field = &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_IntData{
+					IntData: &schemapb.IntArray{
+						Data: generateInt32Array(numRows),
+					},
+				},
+			},
+		}
+	case schemapb.DataType_Int32:
+		ret.Field = &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_IntData{
+					IntData: &schemapb.IntArray{
+						Data: generateInt32Array(numRows),
+					},
+				},
+			},
+		}
+	case schemapb.DataType_Int64:
+		ret.Field = &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_LongData{
+					LongData: &schemapb.LongArray{
+						Data: generateInt64Array(numRows),
+					},
+				},
+			},
+		}
+	case schemapb.DataType_Float:
+		ret.Field = &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_FloatData{
+					FloatData: &schemapb.FloatArray{
+						Data: generateFloat32Array(numRows),
+					},
+				},
+			},
+		}
+	case schemapb.DataType_Double:
+		ret.Field = &schemapb.FieldData_Scalars{
+			Scalars: &schemapb.ScalarField{
+				Data: &schemapb.ScalarField_DoubleData{
+					DoubleData: &schemapb.DoubleArray{
+						Data: generateFloat64Array(numRows),
+					},
+				},
+			},
+		}
+	}
+
+	return ret
+}
+
+func newFloatVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
+	return &schemapb.FieldData{
+		Type:      schemapb.DataType_FloatVector,
+		FieldName: fieldName,
+		Field: &schemapb.FieldData_Vectors{
+			Vectors: &schemapb.VectorField{
+				Dim: int64(dim),
+				Data: &schemapb.VectorField_FloatVector{
+					FloatVector: &schemapb.FloatArray{
+						Data: generateFloatVectors(numRows, dim),
+					},
+				},
+			},
+		},
+	}
+}
+
+func newBinaryVectorFieldData(fieldName string, numRows, dim int) *schemapb.FieldData {
+	return &schemapb.FieldData{
+		Type:      schemapb.DataType_BinaryVector,
+		FieldName: fieldName,
+		Field: &schemapb.FieldData_Vectors{
+			Vectors: &schemapb.VectorField{
+				Dim: int64(dim),
+				Data: &schemapb.VectorField_BinaryVector{
+					BinaryVector: generateBinaryVectors(numRows, dim),
+				},
+			},
+		},
+	}
+}
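
Editor's note (not part of the patch): the rule that checkRowNums enforces is plain arithmetic. A scalar field must carry exactly NumRows values, a float vector field must carry NumRows*dim float32 values, and a binary vector field must carry NumRows*dim/8 bytes with dim divisible by 8. The standalone Go sketch below illustrates that arithmetic under those assumptions; the helper names floatVectorRows and binaryVectorRows are illustrative only and are not part of the Milvus proxy API.

// Standalone sketch: the same row-count arithmetic used by
// getNumRowsOfFloatVectorField and getNumRowsOfBinaryVectorField.
package main

import (
	"errors"
	"fmt"
)

// floatVectorRows derives the row count of a flat []float32 buffer:
// rows = len(data) / dim, with dim > 0 and len(data) a multiple of dim.
func floatVectorRows(data []float32, dim int64) (uint32, error) {
	if dim <= 0 {
		return 0, fmt.Errorf("dim(%d) should be greater than 0", dim)
	}
	if int64(len(data))%dim != 0 {
		return 0, errors.New("float data length is not a multiple of dim")
	}
	return uint32(int64(len(data)) / dim), nil
}

// binaryVectorRows derives the row count of a packed binary vector buffer:
// rows = 8*len(data) / dim, with dim > 0, dim divisible by 8,
// and the total bit count a multiple of dim.
func binaryVectorRows(data []byte, dim int64) (uint32, error) {
	if dim <= 0 || dim%8 != 0 {
		return 0, fmt.Errorf("invalid binary vector dim: %d", dim)
	}
	bits := 8 * int64(len(data))
	if bits%dim != 0 {
		return 0, errors.New("bit count is not a multiple of dim")
	}
	return uint32(bits / dim), nil
}

func main() {
	const numRows, dim = 20, 128

	// 20 rows of dim-128 float vectors -> 20*128 = 2560 float32 values.
	fRows, _ := floatVectorRows(make([]float32, numRows*dim), dim)

	// 20 rows of dim-128 binary vectors -> 20*128/8 = 320 bytes.
	bRows, _ := binaryVectorRows(make([]byte, numRows*dim/8), dim)

	fmt.Println(fRows, bRows) // prints: 20 20
}

With NumRows = 20 and dim = 128, as in TestInsertTask_checkRowNums, the float vector field is expected to hold 2560 float32 values and the binary vector field 320 bytes, which is exactly what the newFloatVectorFieldData and newBinaryVectorFieldData test helpers generate.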