From 3d07b6682cedbb36146cfa0bbd384395c42e883b Mon Sep 17 00:00:00 2001 From: "yihao.dai" Date: Mon, 8 Jan 2024 19:42:49 +0800 Subject: [PATCH] feat: Add import reader for numpy (#29253) This PR implements a new numpy reader for import. issue: https://github.com/milvus-io/milvus/issues/28521 --------- Signed-off-by: bigsheeper --- internal/storage/insert_data.go | 15 + .../util/importutilv2/numpy/field_reader.go | 294 +++++++++++++++++ internal/util/importutilv2/numpy/reader.go | 82 +++++ .../util/importutilv2/numpy/reader_test.go | 308 ++++++++++++++++++ internal/util/importutilv2/numpy/util.go | 242 ++++++++++++++ 5 files changed, 941 insertions(+) create mode 100644 internal/util/importutilv2/numpy/field_reader.go create mode 100644 internal/util/importutilv2/numpy/reader.go create mode 100644 internal/util/importutilv2/numpy/reader_test.go create mode 100644 internal/util/importutilv2/numpy/util.go diff --git a/internal/storage/insert_data.go b/internal/storage/insert_data.go index 4fafbf160c..fcfdb8d192 100644 --- a/internal/storage/insert_data.go +++ b/internal/storage/insert_data.go @@ -127,6 +127,7 @@ type FieldData interface { GetMemorySize() int RowNum() int GetRow(i int) any + GetRows() any AppendRow(row interface{}) error AppendRows(rows interface{}) error GetDataType() schemapb.DataType @@ -298,6 +299,20 @@ func (data *Float16VectorFieldData) GetRow(i int) interface{} { return data.Data[i*data.Dim*2 : (i+1)*data.Dim*2] } +func (data *BoolFieldData) GetRows() any { return data.Data } +func (data *Int8FieldData) GetRows() any { return data.Data } +func (data *Int16FieldData) GetRows() any { return data.Data } +func (data *Int32FieldData) GetRows() any { return data.Data } +func (data *Int64FieldData) GetRows() any { return data.Data } +func (data *FloatFieldData) GetRows() any { return data.Data } +func (data *DoubleFieldData) GetRows() any { return data.Data } +func (data *StringFieldData) GetRows() any { return data.Data } +func (data *ArrayFieldData) GetRows() any { return data.Data } +func (data *JSONFieldData) GetRows() any { return data.Data } +func (data *BinaryVectorFieldData) GetRows() any { return data.Data } +func (data *FloatVectorFieldData) GetRows() any { return data.Data } +func (data *Float16VectorFieldData) GetRows() any { return data.Data } + // AppendRow implements FieldData.AppendRow func (data *BoolFieldData) AppendRow(row interface{}) error { v, ok := row.(bool) diff --git a/internal/util/importutilv2/numpy/field_reader.go b/internal/util/importutilv2/numpy/field_reader.go new file mode 100644 index 0000000000..8c4405b116 --- /dev/null +++ b/internal/util/importutilv2/numpy/field_reader.go @@ -0,0 +1,294 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
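// FieldReader (below) wraps a single field's npy stream: it validates the
// numpy header against the field schema (dtype, shape, dim), records the byte
// order declared in the header, and decodes rows in batches via Next.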
+ +package numpy + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "unicode/utf8" + + "github.com/samber/lo" + "github.com/sbinet/npyio" + "github.com/sbinet/npyio/npy" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type FieldReader struct { + reader io.Reader + npyReader *npy.Reader + order binary.ByteOrder + + dim int64 + field *schemapb.FieldSchema + + readPosition int +} + +func NewFieldReader(reader io.Reader, field *schemapb.FieldSchema) (*FieldReader, error) { + r, err := npyio.NewReader(reader) + if err != nil { + return nil, err + } + + var dim int64 = 1 + if typeutil.IsVectorType(field.GetDataType()) { + dim, err = typeutil.GetDim(field) + if err != nil { + return nil, err + } + } + + err = validateHeader(r, field, int(dim)) + if err != nil { + return nil, err + } + + cr := &FieldReader{ + reader: reader, + npyReader: r, + dim: dim, + field: field, + } + cr.setByteOrder() + return cr, nil +} + +func ReadN[T any](reader io.Reader, order binary.ByteOrder, n int64) ([]T, error) { + data := make([]T, n) + err := binary.Read(reader, order, &data) + if err != nil { + return nil, err + } + return data, nil +} + +func (c *FieldReader) getCount(count int64) int64 { + shape := c.npyReader.Header.Descr.Shape + if len(shape) == 0 { + return 0 + } + total := 1 + for i := 0; i < len(shape); i++ { + total *= shape[i] + } + if total == 0 { + return 0 + } + if c.field.GetDataType() == schemapb.DataType_BinaryVector { + count *= c.dim / 8 + } else if c.field.GetDataType() == schemapb.DataType_FloatVector { + count *= c.dim + } + if int(count) > (total - c.readPosition) { + return int64(total - c.readPosition) + } + return count +} + +func (c *FieldReader) Next(count int64) (any, error) { + readCount := c.getCount(count) + if readCount == 0 { + return nil, nil + } + var ( + data any + err error + ) + dt := c.field.GetDataType() + switch dt { + case schemapb.DataType_Bool: + data, err = ReadN[bool](c.reader, c.order, readCount) + if err != nil { + return nil, err + } + c.readPosition += int(readCount) + case schemapb.DataType_Int8: + data, err = ReadN[int8](c.reader, c.order, readCount) + if err != nil { + return nil, err + } + c.readPosition += int(readCount) + case schemapb.DataType_Int16: + data, err = ReadN[int16](c.reader, c.order, readCount) + if err != nil { + return nil, err + } + c.readPosition += int(readCount) + case schemapb.DataType_Int32: + data, err = ReadN[int32](c.reader, c.order, readCount) + if err != nil { + return nil, err + } + c.readPosition += int(readCount) + case schemapb.DataType_Int64: + data, err = ReadN[int64](c.reader, c.order, readCount) + if err != nil { + return nil, err + } + c.readPosition += int(readCount) + case schemapb.DataType_Float: + data, err = ReadN[float32](c.reader, c.order, readCount) + if err != nil { + return nil, err + } + c.readPosition += int(readCount) + case schemapb.DataType_Double: + data, err = ReadN[float64](c.reader, c.order, readCount) + if err != nil { + return nil, err + } + c.readPosition += int(readCount) + case schemapb.DataType_VarChar: + data, err = c.ReadString(readCount) + c.readPosition += int(readCount) + if err != nil { + return nil, err + } + case schemapb.DataType_JSON: + var strs []string + strs, err = c.ReadString(readCount) + if err != nil { + return nil, err + } + byteArr := make([][]byte, 0) + for _, str := range strs { + var dummy interface{} + err = json.Unmarshal([]byte(str), 
&dummy)
			if err != nil {
				return nil, merr.WrapErrImportFailed(
					fmt.Sprintf("failed to parse value '%v' for JSON field '%s', error: %v", str, c.field.GetName(), err))
			}
			byteArr = append(byteArr, []byte(str))
		}
		data = byteArr
		c.readPosition += int(readCount)
	case schemapb.DataType_BinaryVector:
		data, err = ReadN[uint8](c.reader, c.order, readCount)
		if err != nil {
			return nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_FloatVector:
		var elementType schemapb.DataType
		elementType, err = convertNumpyType(c.npyReader.Header.Descr.Type)
		if err != nil {
			return nil, err
		}
		switch elementType {
		case schemapb.DataType_Float:
			data, err = ReadN[float32](c.reader, c.order, readCount)
			if err != nil {
				return nil, err
			}
			err = typeutil.VerifyFloats32(data.([]float32))
			if err != nil {
				return nil, err
			}
		case schemapb.DataType_Double:
			var data64 []float64
			data64, err = ReadN[float64](c.reader, c.order, readCount)
			if err != nil {
				return nil, err
			}
			err = typeutil.VerifyFloats64(data64)
			if err != nil {
				return nil, err
			}
			data = lo.Map(data64, func(f float64, _ int) float32 {
				return float32(f)
			})
		}
		c.readPosition += int(readCount)
	default:
		return nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type: %s", dt.String()))
	}
	return data, nil
}

func (c *FieldReader) Close() {}

// setByteOrder sets BigEndian/LittleEndian; the logic of this method is copied from the npyio lib
func (c *FieldReader) setByteOrder() {
	var nativeEndian binary.ByteOrder
	v := uint16(1)
	switch byte(v >> 8) {
	case 0:
		nativeEndian = binary.LittleEndian
	case 1:
		nativeEndian = binary.BigEndian
	}

	switch c.npyReader.Header.Descr.Type[0] {
	case '<':
		c.order = binary.LittleEndian
	case '>':
		c.order = binary.BigEndian
	default:
		c.order = nativeEndian
	}
}

func (c *FieldReader) ReadString(count int64) ([]string, error) {
	// varchar length, this is the max length; some items are shorter than this length, but they still occupy maxLen bytes
	maxLen, utf, err := stringLen(c.npyReader.Header.Descr.Type)
	if err != nil || maxLen <= 0 {
		return nil, merr.WrapErrImportFailed(
			fmt.Sprintf("failed to get max length %d of varchar from numpy file header, error: %v", maxLen, err))
	}

	// read data
	data := make([]string, 0, count)
	for len(data) < int(count) {
		if utf {
			// in a numpy file with utf32 encoding, the dtype looks like "<U2":
			// "<" is the byte order (little-endian), "U" means utf32 encoding, "2" is the max string length,
			// each character occupies 4 bytes, so each string occupies 4*maxLen bytes
			raw, err := io.ReadAll(io.LimitReader(c.reader, utf8.UTFMax*int64(maxLen)))
			if err != nil {
				return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to read utf32 bytes from numpy file, error: %v", err))
			}
			str, err := decodeUtf32(raw, c.order)
			if err != nil {
				return nil, err
			}
			data = append(data, str)
		} else {
			// in a numpy file with ansi encoding, the dtype looks like "S2": each string occupies maxLen bytes,
			// and the first zero byte (if any) marks the end of the string
			buf := make([]byte, maxLen)
			if _, err := io.ReadFull(c.reader, buf); err != nil {
				return nil, merr.WrapErrImportFailed(fmt.Sprintf("failed to read ascii bytes from numpy file, error: %v", err))
			}
			n := bytes.Index(buf, []byte{0})
			if n > 0 {
				buf = buf[:n]
			}
			data = append(data, string(buf))
		}
	}
	return data, nil
}
diff --git a/internal/util/importutilv2/numpy/reader.go b/internal/util/importutilv2/numpy/reader.go
new file mode 100644
index 0000000000..5606449bb1
--- /dev/null
+++ b/internal/util/importutilv2/numpy/reader.go
@@ -0,0 +1,82 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +package numpy + +import ( + "io" + + "github.com/samber/lo" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" +) + +type Reader struct { + schema *schemapb.CollectionSchema + count int64 + frs map[int64]*FieldReader // fieldID -> FieldReader +} + +func NewReader(schema *schemapb.CollectionSchema, readers map[int64]io.Reader, bufferSize int) (*Reader, error) { + fields := lo.KeyBy(schema.GetFields(), func(field *schemapb.FieldSchema) int64 { + return field.GetFieldID() + }) + count, err := calcRowCount(bufferSize, schema) + if err != nil { + return nil, err + } + crs := make(map[int64]*FieldReader) + for fieldID, r := range readers { + cr, err := NewFieldReader(r, fields[fieldID]) + if err != nil { + return nil, err + } + crs[fieldID] = cr + } + return &Reader{ + schema: schema, + count: count, + frs: crs, + }, nil +} + +func (r *Reader) Read() (*storage.InsertData, error) { + insertData, err := storage.NewInsertData(r.schema) + if err != nil { + return nil, err + } + for fieldID, cr := range r.frs { + data, err := cr.Next(r.count) + if err != nil { + return nil, err + } + if data == nil { + return nil, nil + } + err = insertData.Data[fieldID].AppendRows(data) + if err != nil { + return nil, err + } + } + return insertData, nil +} + +func (r *Reader) Close() { + for _, cr := range r.frs { + cr.Close() + } +} diff --git a/internal/util/importutilv2/numpy/reader_test.go b/internal/util/importutilv2/numpy/reader_test.go new file mode 100644 index 0000000000..42dae31da0 --- /dev/null +++ b/internal/util/importutilv2/numpy/reader_test.go @@ -0,0 +1,308 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package numpy + +import ( + "bytes" + rand2 "crypto/rand" + "fmt" + "io" + "math" + "math/rand" + "strconv" + "strings" + "testing" + + "github.com/samber/lo" + "github.com/sbinet/npyio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "golang.org/x/exp/slices" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type ReaderSuite struct { + suite.Suite + + numRows int + pkDataType schemapb.DataType + vecDataType schemapb.DataType +} + +func (suite *ReaderSuite) SetupSuite() { + paramtable.Get().Init(paramtable.NewBaseTable()) +} + +func (suite *ReaderSuite) SetupTest() { + // default suite params + suite.numRows = 100 + suite.pkDataType = schemapb.DataType_Int64 + suite.vecDataType = schemapb.DataType_FloatVector +} + +func createInsertData(t *testing.T, schema *schemapb.CollectionSchema, rowCount int) *storage.InsertData { + insertData, err := storage.NewInsertData(schema) + assert.NoError(t, err) + for _, field := range schema.GetFields() { + switch field.GetDataType() { + case schemapb.DataType_Bool: + boolData := make([]bool, 0) + for i := 0; i < rowCount; i++ { + boolData = append(boolData, i%3 != 0) + } + insertData.Data[field.GetFieldID()] = &storage.BoolFieldData{Data: boolData} + case schemapb.DataType_Float: + floatData := make([]float32, 0) + for i := 0; i < rowCount; i++ { + floatData = append(floatData, float32(i/2)) + } + insertData.Data[field.GetFieldID()] = &storage.FloatFieldData{Data: floatData} + case schemapb.DataType_Double: + doubleData := make([]float64, 0) + for i := 0; i < rowCount; i++ { + doubleData = append(doubleData, float64(i/5)) + } + insertData.Data[field.GetFieldID()] = &storage.DoubleFieldData{Data: doubleData} + case schemapb.DataType_Int8: + int8Data := make([]int8, 0) + for i := 0; i < rowCount; i++ { + int8Data = append(int8Data, int8(i%256)) + } + insertData.Data[field.GetFieldID()] = &storage.Int8FieldData{Data: int8Data} + case schemapb.DataType_Int16: + int16Data := make([]int16, 0) + for i := 0; i < rowCount; i++ { + int16Data = append(int16Data, int16(i%65536)) + } + insertData.Data[field.GetFieldID()] = &storage.Int16FieldData{Data: int16Data} + case schemapb.DataType_Int32: + int32Data := make([]int32, 0) + for i := 0; i < rowCount; i++ { + int32Data = append(int32Data, int32(i%1000)) + } + insertData.Data[field.GetFieldID()] = &storage.Int32FieldData{Data: int32Data} + case schemapb.DataType_Int64: + int64Data := make([]int64, 0) + for i := 0; i < rowCount; i++ { + int64Data = append(int64Data, int64(i)) + } + insertData.Data[field.GetFieldID()] = &storage.Int64FieldData{Data: int64Data} + case schemapb.DataType_BinaryVector: + dim, err := typeutil.GetDim(field) + assert.NoError(t, err) + binVecData := make([]byte, 0) + total := rowCount * int(dim) / 8 + for i := 0; i < total; i++ { + binVecData = append(binVecData, byte(i%256)) + } + insertData.Data[field.GetFieldID()] = &storage.BinaryVectorFieldData{Data: binVecData, Dim: int(dim)} + case schemapb.DataType_FloatVector: + dim, err := typeutil.GetDim(field) + assert.NoError(t, err) + floatVecData := make([]float32, 0) + total := rowCount * int(dim) + for i := 0; i < total; i++ { + floatVecData = append(floatVecData, rand.Float32()) + } + insertData.Data[field.GetFieldID()] = 
&storage.FloatVectorFieldData{Data: floatVecData, Dim: int(dim)} + case schemapb.DataType_Float16Vector: + dim, err := typeutil.GetDim(field) + assert.NoError(t, err) + total := int64(rowCount) * dim * 2 + float16VecData := make([]byte, total) + _, err = rand2.Read(float16VecData) + assert.NoError(t, err) + insertData.Data[field.GetFieldID()] = &storage.Float16VectorFieldData{Data: float16VecData, Dim: int(dim)} + case schemapb.DataType_String, schemapb.DataType_VarChar: + varcharData := make([]string, 0) + for i := 0; i < rowCount; i++ { + varcharData = append(varcharData, strconv.Itoa(i)) + } + insertData.Data[field.GetFieldID()] = &storage.StringFieldData{Data: varcharData} + case schemapb.DataType_JSON: + jsonData := make([][]byte, 0) + for i := 0; i < rowCount; i++ { + jsonData = append(jsonData, []byte(fmt.Sprintf("{\"y\": %d}", i))) + } + insertData.Data[field.GetFieldID()] = &storage.JSONFieldData{Data: jsonData} + case schemapb.DataType_Array: + arrayData := make([]*schemapb.ScalarField, 0) + for i := 0; i < rowCount; i++ { + arrayData = append(arrayData, &schemapb.ScalarField{ + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{ + Data: []int32{int32(i), int32(i + 1), int32(i + 2)}, + }, + }, + }) + } + insertData.Data[field.GetFieldID()] = &storage.ArrayFieldData{Data: arrayData} + default: + panic(fmt.Sprintf("unexpected data type: %s", field.GetDataType().String())) + } + } + return insertData +} + +func CreateReader(data interface{}) (io.Reader, error) { + buf := new(bytes.Buffer) + err := npyio.Write(buf, data) + if err != nil { + return nil, err + } + return strings.NewReader(buf.String()), nil +} + +func (suite *ReaderSuite) run(dt schemapb.DataType) { + const dim = 8 + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: suite.pkDataType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "max_length", + Value: "256", + }, + }, + }, + { + FieldID: 101, + Name: "vec", + DataType: suite.vecDataType, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: fmt.Sprintf("%d", dim), + }, + }, + }, + { + FieldID: 102, + Name: dt.String(), + DataType: dt, + ElementType: schemapb.DataType_Int32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "max_length", + Value: "256", + }, + }, + }, + }, + } + insertData := createInsertData(suite.T(), schema, suite.numRows) + fieldIDToField := lo.KeyBy(schema.GetFields(), func(field *schemapb.FieldSchema) int64 { + return field.GetFieldID() + }) + + readers := make(map[int64]io.Reader) + for fieldID, fieldData := range insertData.Data { + dataType := fieldIDToField[fieldID].GetDataType() + if dataType == schemapb.DataType_JSON { + jsonStrs := make([]string, 0, fieldData.RowNum()) + for i := 0; i < fieldData.RowNum(); i++ { + row := fieldData.GetRow(i) + jsonStrs = append(jsonStrs, string(row.([]byte))) + } + reader, err := CreateReader(jsonStrs) + suite.NoError(err) + readers[fieldID] = reader + } else if dataType == schemapb.DataType_FloatVector { + chunked := lo.Chunk(insertData.Data[fieldID].GetRows().([]float32), dim) + chunkedRows := make([][dim]float32, len(chunked)) + for i, innerSlice := range chunked { + copy(chunkedRows[i][:], innerSlice[:]) + } + reader, err := CreateReader(chunkedRows) + suite.NoError(err) + readers[fieldID] = reader + } else if dataType == schemapb.DataType_BinaryVector { + chunked := lo.Chunk(insertData.Data[fieldID].GetRows().([]byte), dim/8) + chunkedRows := make([][dim / 8]byte, 
len(chunked)) + for i, innerSlice := range chunked { + copy(chunkedRows[i][:], innerSlice[:]) + } + reader, err := CreateReader(chunkedRows) + suite.NoError(err) + readers[fieldID] = reader + } else { + reader, err := CreateReader(insertData.Data[fieldID].GetRows()) + suite.NoError(err) + readers[fieldID] = reader + } + } + + reader, err := NewReader(schema, readers, math.MaxInt) + suite.NoError(err) + + checkFn := func(actualInsertData *storage.InsertData, offsetBegin, expectRows int) { + expectInsertData := insertData + for fieldID, data := range actualInsertData.Data { + suite.Equal(expectRows, data.RowNum()) + fieldDataType := typeutil.GetField(schema, fieldID).GetDataType() + for i := 0; i < expectRows; i++ { + expect := expectInsertData.Data[fieldID].GetRow(i + offsetBegin) + actual := data.GetRow(i) + if fieldDataType == schemapb.DataType_Array { + suite.True(slices.Equal(expect.(*schemapb.ScalarField).GetIntData().GetData(), actual.(*schemapb.ScalarField).GetIntData().GetData())) + } else { + suite.Equal(expect, actual) + } + } + } + } + + res, err := reader.Read() + suite.NoError(err) + checkFn(res, 0, suite.numRows) +} + +func (suite *ReaderSuite) TestReadScalarFields() { + suite.run(schemapb.DataType_Bool) + suite.run(schemapb.DataType_Int8) + suite.run(schemapb.DataType_Int16) + suite.run(schemapb.DataType_Int32) + suite.run(schemapb.DataType_Int64) + suite.run(schemapb.DataType_Float) + suite.run(schemapb.DataType_Double) + suite.run(schemapb.DataType_VarChar) + suite.run(schemapb.DataType_JSON) +} + +func (suite *ReaderSuite) TestStringPK() { + suite.pkDataType = schemapb.DataType_VarChar + suite.run(schemapb.DataType_Int32) +} + +func (suite *ReaderSuite) TestBinaryVector() { + suite.vecDataType = schemapb.DataType_BinaryVector + suite.run(schemapb.DataType_Int32) +} + +func TestUtil(t *testing.T) { + suite.Run(t, new(ReaderSuite)) +} diff --git a/internal/util/importutilv2/numpy/util.go b/internal/util/importutilv2/numpy/util.go new file mode 100644 index 0000000000..b0b4bf5586 --- /dev/null +++ b/internal/util/importutilv2/numpy/util.go @@ -0,0 +1,242 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
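// Helpers in this file: stringLen extracts the max string length from an npy
// dtype such as "S8" or "<U8"; decodeUtf32 converts numpy's UTF-32 payload to
// UTF-8; convertNumpyType maps npy dtypes to Milvus data types; validateHeader
// checks dtype and shape against the schema; calcRowCount derives the
// per-batch row count from the buffer size.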
+ +package numpy + +import ( + "encoding/binary" + "fmt" + "reflect" + "regexp" + "strconv" + "unicode/utf8" + + "github.com/sbinet/npyio" + "github.com/sbinet/npyio/npy" + "golang.org/x/text/encoding/unicode" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +var ( + reStrPre = regexp.MustCompile(`^[|]*?(\d.*)[Sa]$`) + reStrPost = regexp.MustCompile(`^[|]*?[Sa](\d.*)$`) + reUniPre = regexp.MustCompile(`^[<|>]*?(\d.*)U$`) + reUniPost = regexp.MustCompile(`^[<|>]*?U(\d.*)$`) +) + +func stringLen(dtype string) (int, bool, error) { + var utf bool + switch { + case reStrPre.MatchString(dtype), reStrPost.MatchString(dtype): + utf = false + case reUniPre.MatchString(dtype), reUniPost.MatchString(dtype): + utf = true + } + + if m := reStrPre.FindStringSubmatch(dtype); m != nil { + v, err := strconv.Atoi(m[1]) + if err != nil { + return 0, false, err + } + return v, utf, nil + } + if m := reStrPost.FindStringSubmatch(dtype); m != nil { + v, err := strconv.Atoi(m[1]) + if err != nil { + return 0, false, err + } + return v, utf, nil + } + if m := reUniPre.FindStringSubmatch(dtype); m != nil { + v, err := strconv.Atoi(m[1]) + if err != nil { + return 0, false, err + } + return v, utf, nil + } + if m := reUniPost.FindStringSubmatch(dtype); m != nil { + v, err := strconv.Atoi(m[1]) + if err != nil { + return 0, false, err + } + return v, utf, nil + } + + return 0, false, merr.WrapErrImportFailed(fmt.Sprintf("dtype '%s' of numpy file is not varchar data type", dtype)) +} + +func decodeUtf32(src []byte, order binary.ByteOrder) (string, error) { + if len(src)%4 != 0 { + return "", merr.WrapErrImportFailed(fmt.Sprintf("invalid utf32 bytes length %d, the byte array length should be multiple of 4", len(src))) + } + + var str string + for len(src) > 0 { + // check the high bytes, if high bytes are 0, the UNICODE is less than U+FFFF, we can use unicode.UTF16 to decode + isUtf16 := false + var lowbytesPosition int + uOrder := unicode.LittleEndian + if order == binary.LittleEndian { + if src[2] == 0 && src[3] == 0 { + isUtf16 = true + } + lowbytesPosition = 0 + } else { + if src[0] == 0 && src[1] == 0 { + isUtf16 = true + } + lowbytesPosition = 2 + uOrder = unicode.BigEndian + } + + if isUtf16 { + // use unicode.UTF16 to decode the low bytes to utf8 + // utf32 and utf16 is same if the unicode code is less than 65535 + if src[lowbytesPosition] != 0 || src[lowbytesPosition+1] != 0 { + decoder := unicode.UTF16(uOrder, unicode.IgnoreBOM).NewDecoder() + res, err := decoder.Bytes(src[lowbytesPosition : lowbytesPosition+2]) + if err != nil { + return "", merr.WrapErrImportFailed(fmt.Sprintf("failed to decode utf32 binary bytes, error: %v", err)) + } + str += string(res) + } + } else { + // convert the 4 bytes to a unicode and encode to utf8 + // Golang strongly opposes utf32 coding, this kind of encoding has been excluded from standard lib + var x uint32 + if order == binary.LittleEndian { + x = uint32(src[3])<<24 | uint32(src[2])<<16 | uint32(src[1])<<8 | uint32(src[0]) + } else { + x = uint32(src[0])<<24 | uint32(src[1])<<16 | uint32(src[2])<<8 | uint32(src[3]) + } + r := rune(x) + utf8Code := make([]byte, 4) + utf8.EncodeRune(utf8Code, r) + if r == utf8.RuneError { + return "", merr.WrapErrImportFailed(fmt.Sprintf("failed to convert 4 bytes unicode %d to utf8 rune", x)) + } + str += string(utf8Code) + } + + src = src[4:] + } + return str, nil +} + +// convertNumpyType gets data type converted from numpy 
header description,
// for vector field, the type is uint8(binary vector) or float32/float64(float vector)
func convertNumpyType(typeStr string) (schemapb.DataType, error) {
	switch typeStr {
	case "b1", "<b1", "|b1", "bool":
		return schemapb.DataType_Bool, nil
	case "u1", "<u1", "|u1", "uint8": // binary vector data is stored as uint8
		return schemapb.DataType_BinaryVector, nil
	case "i1", "<i1", "|i1", ">i1", "int8":
		return schemapb.DataType_Int8, nil
	case "i2", "<i2", "|i2", ">i2", "int16":
		return schemapb.DataType_Int16, nil
	case "i4", "<i4", "|i4", ">i4", "int32":
		return schemapb.DataType_Int32, nil
	case "i8", "<i8", "|i8", ">i8", "int64":
		return schemapb.DataType_Int64, nil
	case "f4", "<f4", "|f4", ">f4", "float32":
		return schemapb.DataType_Float, nil
	case "f8", "<f8", "|f8", ">f8", "float64":
		return schemapb.DataType_Double, nil
	default:
		rt := npyio.TypeFrom(typeStr)
		if rt == reflect.TypeOf((*string)(nil)).Elem() {
			// Note: JSON field and VARCHAR field are using string type numpy
			return schemapb.DataType_VarChar, nil
		}
		return schemapb.DataType_None, merr.WrapErrImportFailed(
			fmt.Sprintf("the numpy file dtype '%s' is not supported", typeStr))
	}
}

func wrapElementTypeError(eleType schemapb.DataType, field *schemapb.FieldSchema) error {
	return merr.WrapErrImportFailed(fmt.Sprintf("expected element type '%s' for field '%s', got type '%s'",
		field.GetDataType().String(), field.GetName(), eleType.String()))
}

func wrapDimError(actualDim int, expectDim int, field *schemapb.FieldSchema) error {
	return merr.WrapErrImportFailed(fmt.Sprintf("expected dim '%d' for %s field '%s', got dim '%d'",
		expectDim, field.GetDataType().String(), field.GetName(), actualDim))
}

func wrapShapeError(actualShape int, expectShape int, field *schemapb.FieldSchema) error {
	return merr.WrapErrImportFailed(fmt.Sprintf("expected shape '%d' for %s field '%s', got shape '%d'",
		expectShape, field.GetDataType().String(), field.GetName(), actualShape))
}

func validateHeader(npyReader *npy.Reader, field *schemapb.FieldSchema, dim int) error {
	elementType, err := convertNumpyType(npyReader.Header.Descr.Type)
	if err != nil {
		return err
	}
	shape := npyReader.Header.Descr.Shape

	switch field.GetDataType() {
	case schemapb.DataType_FloatVector:
		if elementType != schemapb.DataType_Float && elementType != schemapb.DataType_Double {
			return wrapElementTypeError(elementType, field)
		}
		if len(shape) != 2 {
			return wrapShapeError(len(shape), 2, field)
		}
		if shape[1] != dim {
			return wrapDimError(shape[1], dim, field)
		}
	case schemapb.DataType_BinaryVector:
		if elementType != schemapb.DataType_BinaryVector {
			return wrapElementTypeError(elementType, field)
		}
		if len(shape) != 2 {
			return wrapShapeError(len(shape), 2, field)
		}
		if shape[1] != dim/8 {
			return wrapDimError(shape[1]*8, dim, field)
		}
	case schemapb.DataType_VarChar, schemapb.DataType_JSON:
		if len(shape) != 1 {
			return wrapShapeError(len(shape), 1, field)
		}
	case schemapb.DataType_None, schemapb.DataType_Array,
		schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
		return merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type: %s", field.GetDataType().String()))
	default:
		if elementType != field.GetDataType() {
			return wrapElementTypeError(elementType, field)
		}
		if len(shape) != 1 {
			return wrapShapeError(len(shape), 1, field)
		}
	}
	return nil
}

func calcRowCount(bufferSize int, schema *schemapb.CollectionSchema) (int64, error) {
	sizePerRecord, err := typeutil.EstimateSizePerRecord(schema)
	if err != nil {
		return 0, err
	}
	rowCount := int64(bufferSize) / int64(sizePerRecord)
	return rowCount, nil
}
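Below is a minimal usage sketch of the reader added by this patch, assuming one
pre-opened .npy stream per field. The importNumpy helper, the paths map, and
the 64 MB buffer budget are illustrative assumptions, not part of this change:

package example

import (
	"io"
	"os"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/util/importutilv2/numpy"
)

// importNumpy drives numpy.Reader over one .npy file per field.
// The paths map (fieldID -> file path) is hypothetical.
func importNumpy(schema *schemapb.CollectionSchema, paths map[int64]string) error {
	readers := make(map[int64]io.Reader, len(paths))
	for fieldID, path := range paths {
		f, err := os.Open(path)
		if err != nil {
			return err
		}
		defer f.Close()
		readers[fieldID] = f
	}
	// calcRowCount turns the buffer size into a per-batch row count.
	reader, err := numpy.NewReader(schema, readers, 64*1024*1024) // 64 MB, an assumed budget
	if err != nil {
		return err
	}
	defer reader.Close()
	for {
		batch, err := reader.Read()
		if err != nil {
			return err
		}
		if batch == nil { // a nil batch means the columns are exhausted
			return nil
		}
		// hand the *storage.InsertData batch to the import pipeline here
	}
}

Read hands back one *storage.InsertData per call and returns a nil batch once
any field's column is exhausted, which is what ends the loop above.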