mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
ref https://github.com/milvus-io/milvus/issues/42148 --------- Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
1141 lines
32 KiB
Go
1141 lines
32 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package storage
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"math"
|
|
"sync"
|
|
|
|
"github.com/apache/arrow/go/v17/arrow"
|
|
"github.com/apache/arrow/go/v17/arrow/array"
|
|
"github.com/apache/arrow/go/v17/arrow/memory"
|
|
"github.com/apache/arrow/go/v17/parquet"
|
|
"github.com/apache/arrow/go/v17/parquet/compress"
|
|
"github.com/apache/arrow/go/v17/parquet/pqarrow"
|
|
"github.com/cockroachdb/errors"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
_ "github.com/milvus-io/milvus/internal/storage/compress" // register a custom zstd codec here, to avoid to much memory usage when serializing.
|
|
"github.com/milvus-io/milvus/pkg/v2/common"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
|
)
|
|
|
|
// Compile-time check that *NativePayloadWriter satisfies PayloadWriterInterface.
var _ PayloadWriterInterface = (*NativePayloadWriter)(nil)
|
|
|
|
// PayloadWriterOptions is a functional option applied to a NativePayloadWriter
// by NewPayloadWriter before the writer's Arrow type and builder are created.
type PayloadWriterOptions func(*NativePayloadWriter)
|
|
|
|
func WithNullable(nullable bool) PayloadWriterOptions {
|
|
return func(w *NativePayloadWriter) {
|
|
w.nullable = nullable
|
|
}
|
|
}
|
|
|
|
func WithWriterProps(writerProps *parquet.WriterProperties) PayloadWriterOptions {
|
|
return func(w *NativePayloadWriter) {
|
|
w.writerProps = writerProps
|
|
}
|
|
}
|
|
|
|
func WithDim(dim int) PayloadWriterOptions {
|
|
return func(w *NativePayloadWriter) {
|
|
w.dim = NewNullableInt(dim)
|
|
}
|
|
}
|
|
|
|
func WithElementType(elementType schemapb.DataType) PayloadWriterOptions {
|
|
return func(w *NativePayloadWriter) {
|
|
w.elementType = &elementType
|
|
}
|
|
}
|
|
|
|
type NativePayloadWriter struct {
|
|
dataType schemapb.DataType
|
|
elementType *schemapb.DataType
|
|
arrowType arrow.DataType
|
|
builder array.Builder
|
|
finished bool
|
|
flushedRows int
|
|
output *bytes.Buffer
|
|
releaseOnce sync.Once
|
|
dim *NullableInt
|
|
nullable bool
|
|
writerProps *parquet.WriterProperties
|
|
}
|
|
|
|
func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions) (PayloadWriterInterface, error) {
|
|
w := &NativePayloadWriter{
|
|
dataType: colType,
|
|
finished: false,
|
|
flushedRows: 0,
|
|
output: new(bytes.Buffer),
|
|
nullable: false,
|
|
writerProps: parquet.NewWriterProperties(
|
|
parquet.WithCompression(compress.Codecs.Zstd),
|
|
parquet.WithCompressionLevel(3),
|
|
),
|
|
dim: &NullableInt{},
|
|
}
|
|
for _, o := range options {
|
|
o(w)
|
|
}
|
|
|
|
// writer for sparse float vector doesn't require dim
|
|
if typeutil.IsVectorType(colType) && !typeutil.IsSparseFloatVectorType(colType) {
|
|
if w.dim.IsNull() {
|
|
return nil, merr.WrapErrParameterInvalidMsg("incorrect input numbers")
|
|
}
|
|
if w.nullable {
|
|
return nil, merr.WrapErrParameterInvalidMsg("vector type does not support nullable")
|
|
}
|
|
} else {
|
|
w.dim = NewNullableInt(1)
|
|
}
|
|
|
|
// Handle ArrayOfVector type with elementType
|
|
if colType == schemapb.DataType_ArrayOfVector {
|
|
if w.elementType == nil {
|
|
return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires elementType, use WithElementType option")
|
|
}
|
|
if w.dim == nil {
|
|
return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires dim to be specified")
|
|
}
|
|
elemType, err := VectorArrayToArrowType(*w.elementType, *w.dim.Value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
w.arrowType = arrow.ListOf(elemType)
|
|
w.builder = array.NewListBuilder(memory.DefaultAllocator, elemType)
|
|
} else {
|
|
w.arrowType = MilvusDataTypeToArrowType(colType, *w.dim.Value)
|
|
w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType)
|
|
}
|
|
return w, nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddDataToPayloadForUT(data interface{}, validData []bool) error {
|
|
switch w.dataType {
|
|
case schemapb.DataType_Bool:
|
|
val, ok := data.([]bool)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddBoolToPayload(val, validData)
|
|
case schemapb.DataType_Int8:
|
|
val, ok := data.([]int8)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddInt8ToPayload(val, validData)
|
|
case schemapb.DataType_Int16:
|
|
val, ok := data.([]int16)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddInt16ToPayload(val, validData)
|
|
case schemapb.DataType_Int32:
|
|
val, ok := data.([]int32)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddInt32ToPayload(val, validData)
|
|
case schemapb.DataType_Int64:
|
|
val, ok := data.([]int64)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddInt64ToPayload(val, validData)
|
|
case schemapb.DataType_Float:
|
|
val, ok := data.([]float32)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddFloatToPayload(val, validData)
|
|
case schemapb.DataType_Double:
|
|
val, ok := data.([]float64)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddDoubleToPayload(val, validData)
|
|
case schemapb.DataType_Timestamptz:
|
|
val, ok := data.([]int64)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddTimestamptzToPayload(val, validData)
|
|
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
|
val, ok := data.(string)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
isValid := true
|
|
if len(validData) > 1 {
|
|
return merr.WrapErrParameterInvalidMsg("wrong input length when add data to payload")
|
|
}
|
|
if len(validData) == 0 && w.nullable {
|
|
return merr.WrapErrParameterInvalidMsg("need pass valid_data when nullable==true")
|
|
}
|
|
if len(validData) == 1 {
|
|
if !w.nullable {
|
|
return merr.WrapErrParameterInvalidMsg("no need pass valid_data when nullable==false")
|
|
}
|
|
isValid = validData[0]
|
|
}
|
|
return w.AddOneStringToPayload(val, isValid)
|
|
case schemapb.DataType_Array:
|
|
val, ok := data.(*schemapb.ScalarField)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
isValid := true
|
|
if len(validData) > 1 {
|
|
return merr.WrapErrParameterInvalidMsg("wrong input length when add data to payload")
|
|
}
|
|
if len(validData) == 0 && w.nullable {
|
|
return merr.WrapErrParameterInvalidMsg("need pass valid_data when nullable==true")
|
|
}
|
|
if len(validData) == 1 {
|
|
if !w.nullable {
|
|
return merr.WrapErrParameterInvalidMsg("no need pass valid_data when nullable==false")
|
|
}
|
|
isValid = validData[0]
|
|
}
|
|
return w.AddOneArrayToPayload(val, isValid)
|
|
case schemapb.DataType_JSON:
|
|
val, ok := data.([]byte)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
isValid := true
|
|
if len(validData) > 1 {
|
|
return merr.WrapErrParameterInvalidMsg("wrong input length when add data to payload")
|
|
}
|
|
if len(validData) == 0 && w.nullable {
|
|
return merr.WrapErrParameterInvalidMsg("need pass valid_data when nullable==true")
|
|
}
|
|
if len(validData) == 1 {
|
|
if !w.nullable {
|
|
return merr.WrapErrParameterInvalidMsg("no need pass valid_data when nullable==false")
|
|
}
|
|
isValid = validData[0]
|
|
}
|
|
return w.AddOneJSONToPayload(val, isValid)
|
|
case schemapb.DataType_Geometry:
|
|
val, ok := data.([]byte)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
isValid := true
|
|
if len(validData) > 1 {
|
|
return merr.WrapErrParameterInvalidMsg("wrong input length when add data to payload")
|
|
}
|
|
if len(validData) == 0 && w.nullable {
|
|
return merr.WrapErrParameterInvalidMsg("need pass valid_data when nullable==true")
|
|
}
|
|
if len(validData) == 1 {
|
|
if !w.nullable {
|
|
return merr.WrapErrParameterInvalidMsg("no need pass valid_data when nullable==false")
|
|
}
|
|
isValid = validData[0]
|
|
}
|
|
return w.AddOneGeometryToPayload(val, isValid)
|
|
case schemapb.DataType_BinaryVector:
|
|
val, ok := data.([]byte)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddBinaryVectorToPayload(val, w.dim.GetValue())
|
|
case schemapb.DataType_FloatVector:
|
|
val, ok := data.([]float32)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddFloatVectorToPayload(val, w.dim.GetValue())
|
|
case schemapb.DataType_Float16Vector:
|
|
val, ok := data.([]byte)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddFloat16VectorToPayload(val, w.dim.GetValue())
|
|
case schemapb.DataType_BFloat16Vector:
|
|
val, ok := data.([]byte)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddBFloat16VectorToPayload(val, w.dim.GetValue())
|
|
case schemapb.DataType_SparseFloatVector:
|
|
val, ok := data.(*SparseFloatVectorFieldData)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddSparseFloatVectorToPayload(val)
|
|
case schemapb.DataType_Int8Vector:
|
|
val, ok := data.([]int8)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type")
|
|
}
|
|
return w.AddInt8VectorToPayload(val, w.dim.GetValue())
|
|
case schemapb.DataType_ArrayOfVector:
|
|
val, ok := data.(*VectorArrayFieldData)
|
|
if !ok {
|
|
return merr.WrapErrParameterInvalidMsg("incorrect data type: expected *VectorArrayFieldData")
|
|
}
|
|
return w.AddVectorArrayFieldDataToPayload(val)
|
|
default:
|
|
return errors.New("unsupported datatype")
|
|
}
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddBoolToPayload(data []bool, validData []bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished bool payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into bool payload")
|
|
}
|
|
|
|
if !w.nullable && len(validData) != 0 {
|
|
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
if w.nullable && len(data) != len(validData) {
|
|
msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.BooleanBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast ArrayBuilder")
|
|
}
|
|
builder.AppendValues(data, validData)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddByteToPayload(data []byte, validData []bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished byte payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into byte payload")
|
|
}
|
|
|
|
if !w.nullable && len(validData) != 0 {
|
|
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
if w.nullable && len(data) != len(validData) {
|
|
msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.Int8Builder)
|
|
if !ok {
|
|
return errors.New("failed to cast ByteBuilder")
|
|
}
|
|
|
|
builder.Reserve(len(data))
|
|
for i := range data {
|
|
builder.Append(int8(data[i]))
|
|
if w.nullable && !validData[i] {
|
|
builder.AppendNull()
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddInt8ToPayload(data []int8, validData []bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished int8 payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into int8 payload")
|
|
}
|
|
|
|
if !w.nullable && len(validData) != 0 {
|
|
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
if w.nullable && len(data) != len(validData) {
|
|
msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.Int8Builder)
|
|
if !ok {
|
|
return errors.New("failed to cast Int8Builder")
|
|
}
|
|
builder.AppendValues(data, validData)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddInt16ToPayload(data []int16, validData []bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished int16 payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into int16 payload")
|
|
}
|
|
|
|
if !w.nullable && len(validData) != 0 {
|
|
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
if w.nullable && len(data) != len(validData) {
|
|
msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.Int16Builder)
|
|
if !ok {
|
|
return errors.New("failed to cast Int16Builder")
|
|
}
|
|
builder.AppendValues(data, validData)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddInt32ToPayload(data []int32, validData []bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished int32 payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into int32 payload")
|
|
}
|
|
|
|
if !w.nullable && len(validData) != 0 {
|
|
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
if w.nullable && len(data) != len(validData) {
|
|
msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.Int32Builder)
|
|
if !ok {
|
|
return errors.New("failed to cast Int32Builder")
|
|
}
|
|
builder.AppendValues(data, validData)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddInt64ToPayload(data []int64, validData []bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished int64 payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into int64 payload")
|
|
}
|
|
|
|
if !w.nullable && len(validData) != 0 {
|
|
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
if w.nullable && len(data) != len(validData) {
|
|
msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.Int64Builder)
|
|
if !ok {
|
|
return errors.New("failed to cast Int64Builder")
|
|
}
|
|
builder.AppendValues(data, validData)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddFloatToPayload(data []float32, validData []bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished float payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into float payload")
|
|
}
|
|
|
|
if !w.nullable && len(validData) != 0 {
|
|
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
if w.nullable && len(data) != len(validData) {
|
|
msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.Float32Builder)
|
|
if !ok {
|
|
return errors.New("failed to cast FloatBuilder")
|
|
}
|
|
builder.AppendValues(data, validData)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddDoubleToPayload(data []float64, validData []bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished double payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into double payload")
|
|
}
|
|
|
|
if !w.nullable && len(validData) != 0 {
|
|
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
if w.nullable && len(data) != len(validData) {
|
|
msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.Float64Builder)
|
|
if !ok {
|
|
return errors.New("failed to cast DoubleBuilder")
|
|
}
|
|
builder.AppendValues(data, validData)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddTimestamptzToPayload(data []int64, validData []bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished int64 payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into int64 payload")
|
|
}
|
|
|
|
if !w.nullable && len(validData) != 0 {
|
|
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
if w.nullable && len(data) != len(validData) {
|
|
msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
|
|
return merr.WrapErrParameterInvalidMsg(msg)
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.Int64Builder)
|
|
if !ok {
|
|
return errors.New("failed to cast Int64Builder")
|
|
}
|
|
builder.AppendValues(data, validData)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddOneStringToPayload(data string, isValid bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished string payload")
|
|
}
|
|
|
|
if !w.nullable && !isValid {
|
|
return merr.WrapErrParameterInvalidMsg("not support null when nullable is false")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.StringBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast StringBuilder")
|
|
}
|
|
|
|
if !isValid {
|
|
builder.AppendNull()
|
|
} else {
|
|
builder.Append(data)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddOneArrayToPayload(data *schemapb.ScalarField, isValid bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished array payload")
|
|
}
|
|
|
|
if !w.nullable && !isValid {
|
|
return merr.WrapErrParameterInvalidMsg("not support null when nullable is false")
|
|
}
|
|
|
|
bytes, err := proto.Marshal(data)
|
|
if err != nil {
|
|
return errors.New("Marshal ListValue failed")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.BinaryBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast BinaryBuilder")
|
|
}
|
|
|
|
if !isValid {
|
|
builder.AppendNull()
|
|
} else {
|
|
builder.Append(bytes)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddOneJSONToPayload(data []byte, isValid bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished json payload")
|
|
}
|
|
|
|
if !w.nullable && !isValid {
|
|
return merr.WrapErrParameterInvalidMsg("not support null when nullable is false")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.BinaryBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast JsonBuilder")
|
|
}
|
|
|
|
if !isValid {
|
|
builder.AppendNull()
|
|
} else {
|
|
builder.Append(data)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddOneGeometryToPayload(data []byte, isValid bool) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished geometry payload")
|
|
}
|
|
|
|
if !w.nullable && !isValid {
|
|
return merr.WrapErrParameterInvalidMsg("not support null when nullable is false")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.BinaryBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast geometryBuilder")
|
|
}
|
|
|
|
if !isValid {
|
|
builder.AppendNull()
|
|
} else {
|
|
builder.Append(data)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddBinaryVectorToPayload(data []byte, dim int) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished binary vector payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into binary vector payload")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast BinaryVectorBuilder")
|
|
}
|
|
|
|
byteLength := dim / 8
|
|
length := len(data) / byteLength
|
|
builder.Reserve(length)
|
|
for i := 0; i < length; i++ {
|
|
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddFloatVectorToPayload(data []float32, dim int) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished float vector payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into float vector payload")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast FloatVectorBuilder")
|
|
}
|
|
|
|
byteLength := dim * 4
|
|
length := len(data) / dim
|
|
|
|
builder.Reserve(length)
|
|
bytesData := make([]byte, byteLength)
|
|
for i := 0; i < length; i++ {
|
|
vec := data[i*dim : (i+1)*dim]
|
|
for j := range vec {
|
|
bytes := math.Float32bits(vec[j])
|
|
common.Endian.PutUint32(bytesData[j*4:], bytes)
|
|
}
|
|
builder.Append(bytesData)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddFloat16VectorToPayload(data []byte, dim int) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished float16 payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into float16 payload")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast Float16Builder")
|
|
}
|
|
|
|
byteLength := dim * 2
|
|
length := len(data) / byteLength
|
|
|
|
builder.Reserve(length)
|
|
for i := 0; i < length; i++ {
|
|
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddBFloat16VectorToPayload(data []byte, dim int) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished BFloat16 payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into BFloat16 payload")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast BFloat16Builder")
|
|
}
|
|
|
|
byteLength := dim * 2
|
|
length := len(data) / byteLength
|
|
|
|
builder.Reserve(length)
|
|
for i := 0; i < length; i++ {
|
|
builder.Append(data[i*byteLength : (i+1)*byteLength])
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddSparseFloatVectorToPayload(data *SparseFloatVectorFieldData) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished sparse float vector payload")
|
|
}
|
|
builder, ok := w.builder.(*array.BinaryBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast SparseFloatVectorBuilder")
|
|
}
|
|
length := len(data.SparseFloatArray.Contents)
|
|
builder.Reserve(length)
|
|
for i := 0; i < length; i++ {
|
|
builder.Append(data.SparseFloatArray.Contents[i])
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) AddInt8VectorToPayload(data []int8, dim int) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished int8 vector payload")
|
|
}
|
|
|
|
if len(data) == 0 {
|
|
return errors.New("can't add empty msgs into int8 vector payload")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast Int8VectorBuilder")
|
|
}
|
|
|
|
byteLength := dim
|
|
length := len(data) / byteLength
|
|
|
|
builder.Reserve(length)
|
|
for i := 0; i < length; i++ {
|
|
vec := data[i*dim : (i+1)*dim]
|
|
vecBytes := arrow.Int8Traits.CastToBytes(vec)
|
|
builder.Append(vecBytes)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) FinishPayloadWriter() error {
|
|
if w.finished {
|
|
return errors.New("can't reuse a finished writer")
|
|
}
|
|
|
|
w.finished = true
|
|
|
|
// Prepare metadata for VectorArray type
|
|
var metadata arrow.Metadata
|
|
if w.dataType == schemapb.DataType_ArrayOfVector {
|
|
if w.elementType == nil {
|
|
return errors.New("element type for DataType_ArrayOfVector must be set")
|
|
}
|
|
|
|
metadata = arrow.NewMetadata(
|
|
[]string{"elementType", "dim"},
|
|
[]string{fmt.Sprintf("%d", int32(*w.elementType)), fmt.Sprintf("%d", w.dim.GetValue())},
|
|
)
|
|
}
|
|
|
|
field := arrow.Field{
|
|
Name: "val",
|
|
Type: w.arrowType,
|
|
Nullable: w.nullable,
|
|
Metadata: metadata,
|
|
}
|
|
schema := arrow.NewSchema([]arrow.Field{
|
|
field,
|
|
}, nil)
|
|
|
|
w.flushedRows += w.builder.Len()
|
|
data := w.builder.NewArray()
|
|
defer data.Release()
|
|
column := arrow.NewColumnFromArr(field, data)
|
|
defer column.Release()
|
|
|
|
table := array.NewTable(schema, []arrow.Column{column}, int64(column.Len()))
|
|
defer table.Release()
|
|
|
|
arrowWriterProps := pqarrow.DefaultWriterProps()
|
|
if w.dataType == schemapb.DataType_ArrayOfVector {
|
|
// Store metadata in the Arrow writer properties
|
|
arrowWriterProps = pqarrow.NewArrowWriterProperties(
|
|
pqarrow.WithStoreSchema(),
|
|
)
|
|
}
|
|
|
|
return pqarrow.WriteTable(table,
|
|
w.output,
|
|
1024*1024*1024,
|
|
w.writerProps,
|
|
arrowWriterProps,
|
|
)
|
|
}
|
|
|
|
func (w *NativePayloadWriter) Reserve(size int) {
|
|
w.builder.Reserve(size)
|
|
}
|
|
|
|
func (w *NativePayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) {
|
|
data := w.output.Bytes()
|
|
|
|
// The cpp version of payload writer handles the empty buffer as error
|
|
if len(data) == 0 {
|
|
return nil, errors.New("empty buffer")
|
|
}
|
|
|
|
return data, nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) GetPayloadLengthFromWriter() (int, error) {
|
|
return w.flushedRows + w.builder.Len(), nil
|
|
}
|
|
|
|
func (w *NativePayloadWriter) ReleasePayloadWriter() {
|
|
w.releaseOnce.Do(func() {
|
|
w.builder.Release()
|
|
})
|
|
}
|
|
|
|
func (w *NativePayloadWriter) Close() {
|
|
w.ReleasePayloadWriter()
|
|
}
|
|
|
|
func MilvusDataTypeToArrowType(dataType schemapb.DataType, dim int) arrow.DataType {
|
|
switch dataType {
|
|
case schemapb.DataType_Bool:
|
|
return &arrow.BooleanType{}
|
|
case schemapb.DataType_Int8:
|
|
return &arrow.Int8Type{}
|
|
case schemapb.DataType_Int16:
|
|
return &arrow.Int16Type{}
|
|
case schemapb.DataType_Int32:
|
|
return &arrow.Int32Type{}
|
|
case schemapb.DataType_Int64, schemapb.DataType_Timestamptz:
|
|
return &arrow.Int64Type{}
|
|
case schemapb.DataType_Float:
|
|
return &arrow.Float32Type{}
|
|
case schemapb.DataType_Double:
|
|
return &arrow.Float64Type{}
|
|
case schemapb.DataType_VarChar, schemapb.DataType_String, schemapb.DataType_Text:
|
|
return &arrow.StringType{}
|
|
case schemapb.DataType_Array:
|
|
return &arrow.BinaryType{}
|
|
case schemapb.DataType_JSON:
|
|
return &arrow.BinaryType{}
|
|
case schemapb.DataType_Geometry:
|
|
return &arrow.BinaryType{}
|
|
case schemapb.DataType_FloatVector:
|
|
return &arrow.FixedSizeBinaryType{
|
|
ByteWidth: dim * 4,
|
|
}
|
|
case schemapb.DataType_BinaryVector:
|
|
return &arrow.FixedSizeBinaryType{
|
|
ByteWidth: dim / 8,
|
|
}
|
|
case schemapb.DataType_Float16Vector:
|
|
return &arrow.FixedSizeBinaryType{
|
|
ByteWidth: dim * 2,
|
|
}
|
|
case schemapb.DataType_BFloat16Vector:
|
|
return &arrow.FixedSizeBinaryType{
|
|
ByteWidth: dim * 2,
|
|
}
|
|
case schemapb.DataType_SparseFloatVector:
|
|
return &arrow.BinaryType{}
|
|
case schemapb.DataType_Int8Vector:
|
|
return &arrow.FixedSizeBinaryType{
|
|
ByteWidth: dim,
|
|
}
|
|
case schemapb.DataType_ArrayOfVector:
|
|
// ArrayOfVector requires elementType, should use VectorArrayToArrowType instead
|
|
panic("ArrayOfVector type requires elementType information, use VectorArrayToArrowType")
|
|
default:
|
|
panic("unsupported data type")
|
|
}
|
|
}
|
|
|
|
// AddVectorArrayFieldDataToPayload adds VectorArrayFieldData to payload using Arrow ListArray
|
|
func (w *NativePayloadWriter) AddVectorArrayFieldDataToPayload(data *VectorArrayFieldData) error {
|
|
if w.finished {
|
|
return errors.New("can't append data to finished vector array payload")
|
|
}
|
|
|
|
if len(data.Data) == 0 {
|
|
return errors.New("can't add empty vector array field data")
|
|
}
|
|
|
|
builder, ok := w.builder.(*array.ListBuilder)
|
|
if !ok {
|
|
return errors.New("failed to cast to ListBuilder for VectorArray")
|
|
}
|
|
|
|
switch data.ElementType {
|
|
case schemapb.DataType_FloatVector:
|
|
return w.addFloatVectorArrayToPayload(builder, data)
|
|
case schemapb.DataType_BinaryVector:
|
|
return w.addBinaryVectorArrayToPayload(builder, data)
|
|
case schemapb.DataType_Float16Vector:
|
|
return w.addFloat16VectorArrayToPayload(builder, data)
|
|
case schemapb.DataType_BFloat16Vector:
|
|
return w.addBFloat16VectorArrayToPayload(builder, data)
|
|
case schemapb.DataType_Int8Vector:
|
|
return w.addInt8VectorArrayToPayload(builder, data)
|
|
default:
|
|
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", data.ElementType.String()))
|
|
}
|
|
}
|
|
|
|
// addFloatVectorArrayToPayload handles FloatVector elements in VectorArray
|
|
func (w *NativePayloadWriter) addFloatVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
|
if data.Dim <= 0 {
|
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
|
}
|
|
|
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
|
|
|
// Each element in data.Data represents one row of VectorArray
|
|
for _, vectorField := range data.Data {
|
|
if vectorField.GetFloatVector() == nil {
|
|
return merr.WrapErrParameterInvalidMsg("expected FloatVector but got different type")
|
|
}
|
|
|
|
// Start a new list for this row
|
|
builder.Append(true)
|
|
|
|
floatData := vectorField.GetFloatVector().GetData()
|
|
|
|
dim := vectorField.GetDim()
|
|
numVectors := len(floatData) / int(dim)
|
|
for i := 0; i < numVectors; i++ {
|
|
start := i * int(dim)
|
|
end := start + int(dim)
|
|
vectorSlice := floatData[start:end]
|
|
|
|
bytes := make([]byte, dim*4)
|
|
for j, f := range vectorSlice {
|
|
binary.LittleEndian.PutUint32(bytes[j*4:], math.Float32bits(f))
|
|
}
|
|
|
|
valueBuilder.Append(bytes)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addBinaryVectorArrayToPayload handles BinaryVector elements in VectorArray
|
|
func (w *NativePayloadWriter) addBinaryVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
|
if data.Dim <= 0 {
|
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
|
}
|
|
|
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
|
|
|
// Each element in data.Data represents one row of VectorArray
|
|
for _, vectorField := range data.Data {
|
|
if vectorField.GetBinaryVector() == nil {
|
|
return merr.WrapErrParameterInvalidMsg("expected BinaryVector but got different type")
|
|
}
|
|
|
|
// Start a new list for this row
|
|
builder.Append(true)
|
|
|
|
binaryData := vectorField.GetBinaryVector()
|
|
byteWidth := (data.Dim + 7) / 8
|
|
numVectors := len(binaryData) / int(byteWidth)
|
|
|
|
for i := 0; i < numVectors; i++ {
|
|
start := i * int(byteWidth)
|
|
end := start + int(byteWidth)
|
|
valueBuilder.Append(binaryData[start:end])
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addFloat16VectorArrayToPayload handles Float16Vector elements in VectorArray
|
|
func (w *NativePayloadWriter) addFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
|
if data.Dim <= 0 {
|
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
|
}
|
|
|
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
|
|
|
// Each element in data.Data represents one row of VectorArray
|
|
for _, vectorField := range data.Data {
|
|
if vectorField.GetFloat16Vector() == nil {
|
|
return merr.WrapErrParameterInvalidMsg("expected Float16Vector but got different type")
|
|
}
|
|
|
|
// Start a new list for this row
|
|
builder.Append(true)
|
|
|
|
float16Data := vectorField.GetFloat16Vector()
|
|
byteWidth := data.Dim * 2
|
|
numVectors := len(float16Data) / int(byteWidth)
|
|
|
|
for i := 0; i < numVectors; i++ {
|
|
start := i * int(byteWidth)
|
|
end := start + int(byteWidth)
|
|
valueBuilder.Append(float16Data[start:end])
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addBFloat16VectorArrayToPayload handles BFloat16Vector elements in VectorArray
|
|
func (w *NativePayloadWriter) addBFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
|
if data.Dim <= 0 {
|
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
|
}
|
|
|
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
|
|
|
// Each element in data.Data represents one row of VectorArray
|
|
for _, vectorField := range data.Data {
|
|
if vectorField.GetBfloat16Vector() == nil {
|
|
return merr.WrapErrParameterInvalidMsg("expected BFloat16Vector but got different type")
|
|
}
|
|
|
|
// Start a new list for this row
|
|
builder.Append(true)
|
|
|
|
bfloat16Data := vectorField.GetBfloat16Vector()
|
|
byteWidth := data.Dim * 2
|
|
numVectors := len(bfloat16Data) / int(byteWidth)
|
|
|
|
for i := 0; i < numVectors; i++ {
|
|
start := i * int(byteWidth)
|
|
end := start + int(byteWidth)
|
|
valueBuilder.Append(bfloat16Data[start:end])
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addInt8VectorArrayToPayload handles Int8Vector elements in VectorArray
|
|
func (w *NativePayloadWriter) addInt8VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
|
if data.Dim <= 0 {
|
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
|
}
|
|
|
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
|
|
|
// Each element in data.Data represents one row of VectorArray
|
|
for _, vectorField := range data.Data {
|
|
if vectorField.GetInt8Vector() == nil {
|
|
return merr.WrapErrParameterInvalidMsg("expected Int8Vector but got different type")
|
|
}
|
|
|
|
// Start a new list for this row
|
|
builder.Append(true)
|
|
|
|
int8Data := vectorField.GetInt8Vector()
|
|
numVectors := len(int8Data) / int(data.Dim)
|
|
|
|
for i := 0; i < numVectors; i++ {
|
|
start := i * int(data.Dim)
|
|
end := start + int(data.Dim)
|
|
valueBuilder.Append(int8Data[start:end])
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|