// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package numpy

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
	"unicode/utf8"

	"github.com/samber/lo"
	"github.com/sbinet/npyio"
	"github.com/sbinet/npyio/npy"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/json"
	"github.com/milvus-io/milvus/internal/util/importutilv2/common"
	pkgcommon "github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/parameterutil"
	"github.com/milvus-io/milvus/pkg/v2/util/timestamptz"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

// FieldReader reads the values of a single field from a numpy (*.npy) file.
type FieldReader struct {
	reader    io.Reader
	npyReader *npy.Reader
	order     binary.ByteOrder

	dim   int64
	field *schemapb.FieldSchema

	// timezone is the collection's default timezone
	timezone string

	// readPosition tracks how many scalar elements have been consumed so far
	readPosition int
}

// NewFieldReader parses the numpy header, validates it against the field
// schema, and returns a reader positioned at the first data element.
func NewFieldReader(reader io.Reader, field *schemapb.FieldSchema, timezone string) (*FieldReader, error) {
	r, err := npyio.NewReader(reader)
	if err != nil {
		return nil, err
	}

	var dim int64 = 1
	dataType := field.GetDataType()
	if typeutil.IsVectorType(dataType) && !typeutil.IsSparseFloatVectorType(dataType) {
		dim, err = typeutil.GetDim(field)
		if err != nil {
			return nil, err
		}
	}

	err = validateHeader(r, field, int(dim))
	if err != nil {
		return nil, err
	}

	cr := &FieldReader{
		reader:    reader,
		npyReader: r,
		dim:       dim,
		field:     field,
		timezone:  timezone,
	}
	cr.setByteOrder()
	return cr, nil
}

// ReadN reads n elements of type T from the reader using the given byte order.
func ReadN[T any](reader io.Reader, order binary.ByteOrder, n int64) ([]T, error) {
	data := make([]T, n)
	err := binary.Read(reader, order, &data)
	if err != nil {
		return nil, err
	}
	return data, nil
}

// getCount converts a requested row count into the number of scalar elements
// to read, clamped to the elements remaining in the file. For vector fields
// one row spans several elements, e.g. a BinaryVector with dim=512 stores
// 512/8=64 uint8 elements per row.
func (c *FieldReader) getCount(count int64) int64 {
	shape := c.npyReader.Header.Descr.Shape
	if len(shape) == 0 {
		return 0
	}
	total := 1
	for i := 0; i < len(shape); i++ {
		total *= shape[i]
	}
	if total == 0 {
		return 0
	}
	switch c.field.GetDataType() {
	case schemapb.DataType_BinaryVector:
		count *= c.dim / 8
	case schemapb.DataType_FloatVector:
		count *= c.dim
	case schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
		count *= c.dim * 2
	case schemapb.DataType_Int8Vector:
		count *= c.dim
	}
	if int(count) > (total - c.readPosition) {
		return int64(total - c.readPosition)
	}
	return count
}
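// Next reads up to count rows from the numpy file and returns the column
// data, a validity mask, and an error. The mask is all-true (numpy files
// cannot encode nulls) and is only built for nullable/defaulted fields;
// a (nil, nil, nil) return means no data remains.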
func (c *FieldReader) Next(count int64) (any, any, error) {
	readCount := c.getCount(count)
	if readCount == 0 {
		return nil, nil, nil
	}
	var (
		data      any
		validData []bool
		err       error
	)
	// a numpy file cannot store null values, so every value must be non-null
	// for the file to be accepted; construct an all-true bool array if the
	// field is nullable or has a default value
	if c.field.GetNullable() || c.field.GetDefaultValue() != nil {
		validData = make([]bool, 0, readCount)
		for i := int64(0); i < readCount; i++ {
			validData = append(validData, true)
		}
	}
	dt := c.field.GetDataType()
	switch dt {
	case schemapb.DataType_Bool:
		data, err = ReadN[bool](c.reader, c.order, readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_Int8:
		data, err = ReadN[int8](c.reader, c.order, readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_Int16:
		data, err = ReadN[int16](c.reader, c.order, readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_Int32:
		data, err = ReadN[int32](c.reader, c.order, readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_Int64:
		data, err = ReadN[int64](c.reader, c.order, readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_Float:
		data, err = ReadN[float32](c.reader, c.order, readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_Double:
		data, err = ReadN[float64](c.reader, c.order, readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_Timestamptz:
		var strs []string
		strs, err = c.ReadString(readCount)
		if err != nil {
			return nil, nil, err
		}
		// timestamps are stored as strings; validate each one and convert it
		// to unix microseconds using the collection's default timezone
		int64Ts := make([]int64, 0, len(strs))
		for _, strValue := range strs {
			tz, err := timestamptz.ValidateAndReturnUnixMicroTz(strValue, c.timezone)
			if err != nil {
				return nil, nil, err
			}
			int64Ts = append(int64Ts, tz)
		}
		data = int64Ts
		c.readPosition += int(readCount)
	case schemapb.DataType_VarChar:
		data, err = c.ReadString(readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_Geometry:
		var strs []string
		strs, err = c.ReadString(readCount)
		if err != nil {
			return nil, nil, err
		}
		// geometry values are stored as WKT strings; convert each one to WKB
		byteArr := make([][]byte, 0)
		for _, wktValue := range strs {
			wkbValue, err := pkgcommon.ConvertWKTToWKB(wktValue)
			if err != nil {
				return nil, nil, err
			}
			byteArr = append(byteArr, wkbValue)
		}
		data = byteArr
		c.readPosition += int(readCount)
	case schemapb.DataType_JSON:
		var strs []string
		strs, err = c.ReadString(readCount)
		if err != nil {
			return nil, nil, err
		}
		byteArr := make([][]byte, 0)
		for _, str := range strs {
			var dummy interface{}
			err = json.Unmarshal([]byte(str), &dummy)
			if err != nil {
				return nil, nil, merr.WrapErrImportFailed(
					fmt.Sprintf("failed to parse value '%v' for JSON field '%s', error: %v", str, c.field.GetName(), err))
			}
			// a dynamic field must additionally be a JSON object
			if c.field.GetIsDynamic() {
				var dummy2 map[string]interface{}
				err = json.Unmarshal([]byte(str), &dummy2)
				if err != nil {
					return nil, nil, merr.WrapErrImportFailed(
						fmt.Sprintf("failed to parse value '%v' for dynamic JSON field '%s', error: %v", str, c.field.GetName(), err))
				}
			}
			byteArr = append(byteArr, []byte(str))
		}
		data = byteArr
		c.readPosition += int(readCount)
	case schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
		data, err = ReadN[uint8](c.reader, c.order, readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_Int8Vector:
		data, err = ReadN[int8](c.reader, c.order, readCount)
		if err != nil {
			return nil, nil, err
		}
		c.readPosition += int(readCount)
	case schemapb.DataType_FloatVector:
		// the numpy file may store a float vector as either float32 or
		// float64; float64 input is verified, then narrowed to float32
		var elementType schemapb.DataType
		elementType, err = convertNumpyType(c.npyReader.Header.Descr.Type)
		if err != nil {
			return nil, nil, err
		}
		switch elementType {
		case schemapb.DataType_Float:
			data, err = ReadN[float32](c.reader, c.order, readCount)
			if err != nil {
				return nil, nil, err
			}
			err = typeutil.VerifyFloats32(data.([]float32))
			if err != nil {
				return nil, nil, err
			}
		case schemapb.DataType_Double:
			var data64 []float64
			data64, err = ReadN[float64](c.reader, c.order, readCount)
			if err != nil {
				return nil, nil, err
			}
			err = typeutil.VerifyFloats64(data64)
			if err != nil {
				return nil, nil, err
			}
			data = lo.Map(data64, func(f float64, _ int) float32 {
				return float32(f)
			})
		}
		c.readPosition += int(readCount)
	default:
		return nil, nil, merr.WrapErrImportFailed(fmt.Sprintf("unsupported data type: %s", dt.String()))
	}
	return data, validData, nil
}
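// numpy dtype strings carry a byte-order prefix: '<' for little-endian,
// '>' for big-endian, and '|' (or none) when byte order is not applicable,
// in which case the host's native order is used.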
// setByteOrder sets the reader's byte order from the numpy dtype prefix;
// the selection mirrors the npyio lib, with binary.NativeEndian (Go 1.21+)
// standing in for its host byte-order probe.
func (c *FieldReader) setByteOrder() {
	switch c.npyReader.Header.Descr.Type[0] {
	case '<':
		c.order = binary.LittleEndian
	case '>':
		c.order = binary.BigEndian
	default:
		c.order = binary.NativeEndian
	}
}

// ReadString reads count fixed-width strings from the numpy file and trims
// the padding numpy adds to entries shorter than the declared width.
func (c *FieldReader) ReadString(count int64) ([]string, error) {
	// maxLen is the declared width; shorter items still occupy maxLen
	// characters in the file
	maxLen, utf, err := stringLen(c.npyReader.Header.Descr.Type)
	if err != nil || maxLen <= 0 {
		return nil, merr.WrapErrImportFailed(
			fmt.Sprintf("failed to get max length %d of varchar from numpy file header, error: %v", maxLen, err))
	}
	maxLength, err := parameterutil.GetMaxLength(c.field)
	if c.field.DataType == schemapb.DataType_VarChar && err != nil {
		return nil, err
	}

	// read data
	data := make([]string, 0, count)
	for len(data) < int(count) {
		if utf {
			// with utf32 encoding the dtype looks like "<U2": U means unicode
			// and 2 means each string occupies 2 characters, every character
			// taking 4 bytes, so each string occupies 4*maxLen bytes
			raw, err := io.ReadAll(io.LimitReader(c.reader, utf8.UTFMax*int64(maxLen)))
			if err != nil {
				return nil, merr.WrapErrImportFailed(
					fmt.Sprintf("failed to read utf32 bytes from numpy file, error: %v", err))
			}
			str, err := decodeUtf32(raw, c.order)
			if err != nil {
				return nil, merr.WrapErrImportFailed(
					fmt.Sprintf("failed to decode utf32 bytes, error: %v", err))
			}
			if c.field.DataType == schemapb.DataType_VarChar && int64(len(str)) > maxLength {
				return nil, merr.WrapErrImportFailed(
					fmt.Sprintf("the length (%d) of string exceeds max length (%d) for varchar field '%s'",
						len(str), maxLength, c.field.GetName()))
			}
			if err = common.CheckValidUTF8(str, c.field); err != nil {
				return nil, err
			}
			data = append(data, str)
		} else {
			// with ascii encoding the dtype looks like "S2": each string
			// occupies maxLen bytes, shorter strings are zero-padded
			buf, err := io.ReadAll(io.LimitReader(c.reader, int64(maxLen)))
			if err != nil {
				return nil, merr.WrapErrImportFailed(
					fmt.Sprintf("failed to read ascii bytes from numpy file, error: %v", err))
			}
			n := bytes.Index(buf, []byte{0})
			if n > 0 {
				buf = buf[:n]
			}
			str := string(buf)
			if err = common.CheckValidUTF8(str, c.field); err != nil {
				return nil, err
			}
			data = append(data, str)
		}
	}
	return data, nil
}
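// Usage sketch (illustrative only; npyBytes, the field schema, and the batch
// size below are assumptions, not fixtures of this package):
//
//	field := &schemapb.FieldSchema{Name: "age", DataType: schemapb.DataType_Int64}
//	fr, err := NewFieldReader(bytes.NewReader(npyBytes), field, "UTC")
//	if err != nil {
//		return err
//	}
//	for {
//		data, validData, err := fr.Next(1024) // read up to 1024 rows per batch
//		if err != nil {
//			return err
//		}
//		if data == nil { // file exhausted
//			break
//		}
//		_ = validData // all-true when the field is nullable
//	}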