milvus/internal/storage/payload_writer.go
marcelo-cjl 3b599441fd
feat: Add nullable vector support for proxy and querynode (#46305)
related: #45993 

This commit extends nullable vector support to the proxy layer and
querynode, and adds comprehensive validation, search reduce, and field
data handling for nullable vectors with sparse storage.
    
Proxy layer changes:
- Update validate_util.go checkAligned() with a getExpectedVectorRows()
  helper to validate nullable vector field alignment using the valid data count
- Update checkFloatVectorFieldData/checkSparseFloatVectorFieldData for
  nullable vector validation with proper row count expectations
- Add FieldDataIdxComputer in typeutil/schema.go for logical-to-physical
  index translation during search reduce operations
- Update search_reduce_util.go reduceSearchResultData to use
  idxComputers for correct field data indexing with nullable vectors
- Update task.go, task_query.go, and task_upsert.go for nullable vector
  handling
- Update msg_pack.go with nullable vector field data processing
    
QueryNode layer changes:
- Update segments/result.go for nullable vector result handling
- Update segments/search_reduce.go with nullable vector offset
  translation
    
Storage and index changes:
- Update data_codec.go and utils.go for nullable vector serialization
- Update indexcgowrapper/dataset.go and index.go for nullable vector
  indexing
    
Utility changes:
- Add FieldDataIdxComputer struct with a Compute() method for efficient
  logical-to-physical index mapping across multiple field data
- Update EstimateEntitySize() and AppendFieldData() with a fieldIdxs
  parameter
- Update funcutil.go with nullable vector support functions

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * Full support for nullable vector fields (float, binary, float16,
    bfloat16, int8, sparse) across ingest, storage, indexing, search and
    retrieval; logical↔physical offset mapping preserves row semantics.
  * Client: compaction control and compaction-state APIs.

* **Bug Fixes**
  * Improved validation for adding vector fields (nullable + dimension
    checks) and corrected search/query behavior for nullable vectors.

* **Chores**
  * Persisted validity maps with indexes and on-disk formats.

* **Tests**
  * Extensive new and updated end-to-end nullable-vector tests.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: marcelo-cjl <marcelo.chen@zilliz.com>
2025-12-24 10:13:19 +08:00

1371 lines
39 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"bytes"
"encoding/binary"
"fmt"
"math"
"sync"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/apache/arrow/go/v17/parquet"
"github.com/apache/arrow/go/v17/parquet/compress"
"github.com/apache/arrow/go/v17/parquet/pqarrow"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
_ "github.com/milvus-io/milvus/internal/storage/compress" // register a custom zstd codec here, to avoid to much memory usage when serializing.
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// PayloadWriterOptions is a functional option applied to a NativePayloadWriter
// by NewPayloadWriter before the Arrow builder is created.
type PayloadWriterOptions func(*NativePayloadWriter)

// Compile-time check that *NativePayloadWriter satisfies PayloadWriterInterface.
var _ PayloadWriterInterface = (*NativePayloadWriter)(nil)
// WithNullable sets whether the payload column may contain null rows.
func WithNullable(nullable bool) PayloadWriterOptions {
	setNullable := func(writer *NativePayloadWriter) {
		writer.nullable = nullable
	}
	return setNullable
}
// WithWriterProps replaces the parquet writer properties used when the payload
// is serialized. Without this option NewPayloadWriter uses zstd compression
// at level 3.
func WithWriterProps(writerProps *parquet.WriterProperties) PayloadWriterOptions {
	setProps := func(writer *NativePayloadWriter) {
		writer.writerProps = writerProps
	}
	return setProps
}
// WithDim sets the vector dimension of the column. NewPayloadWriter requires
// it for non-sparse vector types and resets it to 1 for types that do not
// need one.
func WithDim(dim int) PayloadWriterOptions {
	return func(writer *NativePayloadWriter) {
		nullableDim := NewNullableInt(dim)
		writer.dim = nullableDim
	}
}
// WithElementType sets the element vector type; NewPayloadWriter requires it
// for ArrayOfVector columns.
func WithElementType(elementType schemapb.DataType) PayloadWriterOptions {
	elem := elementType
	return func(writer *NativePayloadWriter) {
		writer.elementType = &elem
	}
}
// NativePayloadWriter accumulates the rows of a single column in an Arrow
// builder and, on FinishPayloadWriter, encodes them as a parquet table into
// an in-memory buffer.
type NativePayloadWriter struct {
	dataType    schemapb.DataType  // Milvus type of the column
	elementType *schemapb.DataType // element vector type; required for ArrayOfVector
	arrowType   arrow.DataType     // Arrow type of the column being built
	builder     array.Builder      // holds appended rows until FinishPayloadWriter
	finished    bool               // set by FinishPayloadWriter; further appends are rejected
	flushedRows int                // rows already moved from builder into output
	output      *bytes.Buffer      // parquet-encoded payload
	releaseOnce sync.Once          // makes builder.Release run at most once
	dim         *NullableInt       // vector dimension; 1 for types that need none
	nullable    bool               // whether rows may be null
	writerProps *parquet.WriterProperties
}
// NewPayloadWriter creates a NativePayloadWriter for a column of type colType.
//
// Defaults: not nullable, zstd compression at level 3. Vector types other
// than sparse float require WithDim; every other type has its dimension reset
// to 1. ArrayOfVector additionally requires WithElementType. Nullable
// non-sparse vectors are built as Arrow variable-length binary values; all
// remaining non-ArrayOfVector types use MilvusDataTypeToArrowType.
//
// A parameter-invalid error is returned when a required option is missing.
func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions) (PayloadWriterInterface, error) {
	w := &NativePayloadWriter{
		dataType:    colType,
		finished:    false,
		flushedRows: 0,
		output:      new(bytes.Buffer),
		nullable:    false,
		writerProps: parquet.NewWriterProperties(
			parquet.WithCompression(compress.Codecs.Zstd),
			parquet.WithCompressionLevel(3),
		),
		dim: &NullableInt{},
	}
	for _, o := range options {
		o(w)
	}

	// writer for sparse float vector doesn't require dim
	isDenseVector := typeutil.IsVectorType(colType) && !typeutil.IsSparseFloatVectorType(colType)
	if isDenseVector {
		if w.dim == nil || w.dim.IsNull() {
			return nil, merr.WrapErrParameterInvalidMsg("incorrect input numbers")
		}
	} else {
		w.dim = NewNullableInt(1)
	}

	switch {
	case colType == schemapb.DataType_ArrayOfVector:
		if w.elementType == nil {
			return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires elementType, use WithElementType option")
		}
		// dim starts as a non-nil &NullableInt{}, so a bare nil check can never
		// detect a missing WithDim; an unset dim is reported by IsNull().
		if w.dim == nil || w.dim.IsNull() {
			return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires dim to be specified")
		}
		elemType, err := VectorArrayToArrowType(*w.elementType, *w.dim.Value)
		if err != nil {
			return nil, err
		}
		w.arrowType = arrow.ListOf(elemType)
		w.builder = array.NewListBuilder(memory.DefaultAllocator, elemType)
	case w.nullable && isDenseVector:
		// Nullable dense vectors use variable-length binary so null rows can be
		// represented without a fixed-size slot.
		w.arrowType = &arrow.BinaryType{}
		w.builder = array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary)
	default:
		w.arrowType = MilvusDataTypeToArrowType(colType, *w.dim.Value)
		w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType)
	}
	return w, nil
}
// AddDataToPayloadForUT is a unit-test helper that routes data to the typed
// Add*ToPayload method selected by w.dataType.
//
// data must have the Go type that method expects (for example []int64 for
// Int64, string for VarChar, *SparseFloatVectorFieldData for
// SparseFloatVector); otherwise a parameter-invalid error is returned.
// For the one-row types (String/VarChar, Array, JSON, Geometry) validData
// holds at most one flag and must be present exactly when the writer is
// nullable. SparseFloatVector and ArrayOfVector ignore validData. Any other
// data type yields an "unsupported datatype" error.
func (w *NativePayloadWriter) AddDataToPayloadForUT(data interface{}, validData []bool) error {
	// oneRowValidity checks validData for the single-row types and returns the
	// validity of that row (true when no flag is needed).
	oneRowValidity := func() (bool, error) {
		switch n := len(validData); {
		case n > 1:
			return false, merr.WrapErrParameterInvalidMsg("wrong input length when add data to payload")
		case n == 0 && w.nullable:
			return false, merr.WrapErrParameterInvalidMsg("need pass valid_data when nullable==true")
		case n == 1 && !w.nullable:
			return false, merr.WrapErrParameterInvalidMsg("no need pass valid_data when nullable==false")
		case n == 1:
			return validData[0], nil
		default:
			return true, nil
		}
	}

	switch w.dataType {
	case schemapb.DataType_Bool:
		if v, ok := data.([]bool); ok {
			return w.AddBoolToPayload(v, validData)
		}
	case schemapb.DataType_Int8:
		if v, ok := data.([]int8); ok {
			return w.AddInt8ToPayload(v, validData)
		}
	case schemapb.DataType_Int16:
		if v, ok := data.([]int16); ok {
			return w.AddInt16ToPayload(v, validData)
		}
	case schemapb.DataType_Int32:
		if v, ok := data.([]int32); ok {
			return w.AddInt32ToPayload(v, validData)
		}
	case schemapb.DataType_Int64:
		if v, ok := data.([]int64); ok {
			return w.AddInt64ToPayload(v, validData)
		}
	case schemapb.DataType_Float:
		if v, ok := data.([]float32); ok {
			return w.AddFloatToPayload(v, validData)
		}
	case schemapb.DataType_Double:
		if v, ok := data.([]float64); ok {
			return w.AddDoubleToPayload(v, validData)
		}
	case schemapb.DataType_Timestamptz:
		if v, ok := data.([]int64); ok {
			return w.AddTimestamptzToPayload(v, validData)
		}
	case schemapb.DataType_String, schemapb.DataType_VarChar:
		if v, ok := data.(string); ok {
			isValid, err := oneRowValidity()
			if err != nil {
				return err
			}
			return w.AddOneStringToPayload(v, isValid)
		}
	case schemapb.DataType_Array:
		if v, ok := data.(*schemapb.ScalarField); ok {
			isValid, err := oneRowValidity()
			if err != nil {
				return err
			}
			return w.AddOneArrayToPayload(v, isValid)
		}
	case schemapb.DataType_JSON:
		if v, ok := data.([]byte); ok {
			isValid, err := oneRowValidity()
			if err != nil {
				return err
			}
			return w.AddOneJSONToPayload(v, isValid)
		}
	case schemapb.DataType_Geometry:
		if v, ok := data.([]byte); ok {
			isValid, err := oneRowValidity()
			if err != nil {
				return err
			}
			return w.AddOneGeometryToPayload(v, isValid)
		}
	case schemapb.DataType_BinaryVector:
		if v, ok := data.([]byte); ok {
			return w.AddBinaryVectorToPayload(v, w.dim.GetValue(), validData)
		}
	case schemapb.DataType_FloatVector:
		if v, ok := data.([]float32); ok {
			return w.AddFloatVectorToPayload(v, w.dim.GetValue(), validData)
		}
	case schemapb.DataType_Float16Vector:
		if v, ok := data.([]byte); ok {
			return w.AddFloat16VectorToPayload(v, w.dim.GetValue(), validData)
		}
	case schemapb.DataType_BFloat16Vector:
		if v, ok := data.([]byte); ok {
			return w.AddBFloat16VectorToPayload(v, w.dim.GetValue(), validData)
		}
	case schemapb.DataType_SparseFloatVector:
		if v, ok := data.(*SparseFloatVectorFieldData); ok {
			return w.AddSparseFloatVectorToPayload(v)
		}
	case schemapb.DataType_Int8Vector:
		if v, ok := data.([]int8); ok {
			return w.AddInt8VectorToPayload(v, w.dim.GetValue(), validData)
		}
	case schemapb.DataType_ArrayOfVector:
		if v, ok := data.(*VectorArrayFieldData); ok {
			return w.AddVectorArrayFieldDataToPayload(v)
		}
		return merr.WrapErrParameterInvalidMsg("incorrect data type: expected *VectorArrayFieldData")
	default:
		return errors.New("unsupported datatype")
	}
	// Only reached when a supported type received a value of the wrong Go type.
	return merr.WrapErrParameterInvalidMsg("incorrect data type")
}
// AddBoolToPayload appends a batch of bool values. For a nullable writer,
// validData must carry one flag per value (false marks a null row). For a
// non-nullable writer it must be empty. Appending to a finished writer or
// appending an empty batch is an error.
func (w *NativePayloadWriter) AddBoolToPayload(data []bool, validData []bool) error {
	switch {
	case w.finished:
		return errors.New("can't append data to finished bool payload")
	case len(data) == 0:
		return errors.New("can't add empty msgs into bool payload")
	case !w.nullable && len(validData) != 0:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)))
	case w.nullable && len(data) != len(validData):
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data)))
	}

	boolBuilder, ok := w.builder.(*array.BooleanBuilder)
	if !ok {
		return errors.New("failed to cast ArrayBuilder")
	}
	boolBuilder.AppendValues(data, validData)
	return nil
}
// AddByteToPayload appends a batch of bytes to an Int8-backed payload. Each
// byte is reinterpreted as an int8. For a nullable writer, validData must
// carry one flag per value, and rows flagged false are stored as null. For a
// non-nullable writer it must be empty. Appending to a finished writer or
// appending an empty batch is an error.
func (w *NativePayloadWriter) AddByteToPayload(data []byte, validData []bool) error {
	if w.finished {
		return errors.New("can't append data to finished byte payload")
	}
	if len(data) == 0 {
		return errors.New("can't add empty msgs into byte payload")
	}
	if !w.nullable && len(validData) != 0 {
		msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
		return merr.WrapErrParameterInvalidMsg(msg)
	}
	if w.nullable && len(data) != len(validData) {
		msg := fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data))
		return merr.WrapErrParameterInvalidMsg(msg)
	}
	builder, ok := w.builder.(*array.Int8Builder)
	if !ok {
		return errors.New("failed to cast ByteBuilder")
	}
	builder.Reserve(len(data))
	for i, b := range data {
		// Exactly one slot per input row. A null must replace the value, not
		// follow it, or the column would gain an extra element per null row.
		if w.nullable && !validData[i] {
			builder.AppendNull()
			continue
		}
		builder.Append(int8(b))
	}
	return nil
}
// AddInt8ToPayload appends a batch of int8 values. For a nullable writer,
// validData must carry one flag per value (false marks a null row). For a
// non-nullable writer it must be empty. Appending to a finished writer or
// appending an empty batch is an error.
func (w *NativePayloadWriter) AddInt8ToPayload(data []int8, validData []bool) error {
	switch {
	case w.finished:
		return errors.New("can't append data to finished int8 payload")
	case len(data) == 0:
		return errors.New("can't add empty msgs into int8 payload")
	case !w.nullable && len(validData) != 0:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)))
	case w.nullable && len(data) != len(validData):
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data)))
	}

	if b, ok := w.builder.(*array.Int8Builder); ok {
		b.AppendValues(data, validData)
		return nil
	}
	return errors.New("failed to cast Int8Builder")
}
// AddInt16ToPayload appends a batch of int16 values. For a nullable writer,
// validData must carry one flag per value (false marks a null row). For a
// non-nullable writer it must be empty. Appending to a finished writer or
// appending an empty batch is an error.
func (w *NativePayloadWriter) AddInt16ToPayload(data []int16, validData []bool) error {
	switch {
	case w.finished:
		return errors.New("can't append data to finished int16 payload")
	case len(data) == 0:
		return errors.New("can't add empty msgs into int16 payload")
	case !w.nullable && len(validData) != 0:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)))
	case w.nullable && len(data) != len(validData):
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data)))
	}

	i16, ok := w.builder.(*array.Int16Builder)
	if !ok {
		return errors.New("failed to cast Int16Builder")
	}
	i16.AppendValues(data, validData)
	return nil
}
// AddInt32ToPayload appends a batch of int32 values. For a nullable writer,
// validData must carry one flag per value (false marks a null row). For a
// non-nullable writer it must be empty. Appending to a finished writer or
// appending an empty batch is an error.
func (w *NativePayloadWriter) AddInt32ToPayload(data []int32, validData []bool) error {
	switch {
	case w.finished:
		return errors.New("can't append data to finished int32 payload")
	case len(data) == 0:
		return errors.New("can't add empty msgs into int32 payload")
	case !w.nullable && len(validData) != 0:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)))
	case w.nullable && len(data) != len(validData):
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data)))
	}

	if b, ok := w.builder.(*array.Int32Builder); ok {
		b.AppendValues(data, validData)
		return nil
	}
	return errors.New("failed to cast Int32Builder")
}
// AddInt64ToPayload appends a batch of int64 values. For a nullable writer,
// validData must carry one flag per value (false marks a null row). For a
// non-nullable writer it must be empty. Appending to a finished writer or
// appending an empty batch is an error.
func (w *NativePayloadWriter) AddInt64ToPayload(data []int64, validData []bool) error {
	switch {
	case w.finished:
		return errors.New("can't append data to finished int64 payload")
	case len(data) == 0:
		return errors.New("can't add empty msgs into int64 payload")
	case !w.nullable && len(validData) != 0:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)))
	case w.nullable && len(data) != len(validData):
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data)))
	}

	i64, ok := w.builder.(*array.Int64Builder)
	if !ok {
		return errors.New("failed to cast Int64Builder")
	}
	i64.AppendValues(data, validData)
	return nil
}
// AddFloatToPayload appends a batch of float32 values. For a nullable writer,
// validData must carry one flag per value (false marks a null row). For a
// non-nullable writer it must be empty. Appending to a finished writer or
// appending an empty batch is an error.
func (w *NativePayloadWriter) AddFloatToPayload(data []float32, validData []bool) error {
	switch {
	case w.finished:
		return errors.New("can't append data to finished float payload")
	case len(data) == 0:
		return errors.New("can't add empty msgs into float payload")
	case !w.nullable && len(validData) != 0:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)))
	case w.nullable && len(data) != len(validData):
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data)))
	}

	if b, ok := w.builder.(*array.Float32Builder); ok {
		b.AppendValues(data, validData)
		return nil
	}
	return errors.New("failed to cast FloatBuilder")
}
// AddDoubleToPayload appends a batch of float64 values. For a nullable
// writer, validData must carry one flag per value (false marks a null row).
// For a non-nullable writer it must be empty. Appending to a finished writer
// or appending an empty batch is an error.
func (w *NativePayloadWriter) AddDoubleToPayload(data []float64, validData []bool) error {
	switch {
	case w.finished:
		return errors.New("can't append data to finished double payload")
	case len(data) == 0:
		return errors.New("can't add empty msgs into double payload")
	case !w.nullable && len(validData) != 0:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)))
	case w.nullable && len(data) != len(validData):
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data)))
	}

	f64, ok := w.builder.(*array.Float64Builder)
	if !ok {
		return errors.New("failed to cast DoubleBuilder")
	}
	f64.AppendValues(data, validData)
	return nil
}
// AddTimestamptzToPayload appends a batch of Timestamptz values, stored in an
// Int64 column. For a nullable writer, validData must carry one flag per value
// (false marks a null row). For a non-nullable writer it must be empty.
// Appending to a finished writer or appending an empty batch is an error.
func (w *NativePayloadWriter) AddTimestamptzToPayload(data []int64, validData []bool) error {
	switch {
	case w.finished:
		return errors.New("can't append data to finished int64 payload")
	case len(data) == 0:
		return errors.New("can't add empty msgs into int64 payload")
	case !w.nullable && len(validData) != 0:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)))
	case w.nullable && len(data) != len(validData):
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("length of validData(%d) must equal to data(%d) when nullable", len(validData), len(data)))
	}

	if b, ok := w.builder.(*array.Int64Builder); ok {
		b.AppendValues(data, validData)
		return nil
	}
	return errors.New("failed to cast Int64Builder")
}
// AddOneStringToPayload appends a single string row. When isValid is false
// the row is stored as null; that is only permitted for a nullable writer.
func (w *NativePayloadWriter) AddOneStringToPayload(data string, isValid bool) error {
	if w.finished {
		return errors.New("can't append data to finished string payload")
	}
	if !isValid && !w.nullable {
		return merr.WrapErrParameterInvalidMsg("not support null when nullable is false")
	}
	sb, ok := w.builder.(*array.StringBuilder)
	if !ok {
		return errors.New("failed to cast StringBuilder")
	}
	if isValid {
		sb.Append(data)
		return nil
	}
	sb.AppendNull()
	return nil
}
// AddOneArrayToPayload appends a single Array row, stored as the protobuf
// encoding of data in a binary column. When isValid is false the row is stored
// as null (nullable writers only) and data is not marshaled.
//
// A marshal failure is returned with its underlying cause attached.
func (w *NativePayloadWriter) AddOneArrayToPayload(data *schemapb.ScalarField, isValid bool) error {
	if w.finished {
		return errors.New("can't append data to finished array payload")
	}
	if !w.nullable && !isValid {
		return merr.WrapErrParameterInvalidMsg("not support null when nullable is false")
	}
	// Only valid rows carry a value, so only they need to be serialized.
	// Named "encoded" to avoid shadowing the bytes package.
	var encoded []byte
	if isValid {
		var err error
		encoded, err = proto.Marshal(data)
		if err != nil {
			return errors.Wrap(err, "Marshal ListValue failed")
		}
	}
	builder, ok := w.builder.(*array.BinaryBuilder)
	if !ok {
		return errors.New("failed to cast BinaryBuilder")
	}
	if !isValid {
		builder.AppendNull()
		return nil
	}
	builder.Append(encoded)
	return nil
}
// AddOneJSONToPayload appends a single JSON row, given as raw bytes, to a
// binary column. When isValid is false the row is stored as null; that is
// only permitted for a nullable writer.
func (w *NativePayloadWriter) AddOneJSONToPayload(data []byte, isValid bool) error {
	if w.finished {
		return errors.New("can't append data to finished json payload")
	}
	if !isValid && !w.nullable {
		return merr.WrapErrParameterInvalidMsg("not support null when nullable is false")
	}
	bb, ok := w.builder.(*array.BinaryBuilder)
	if !ok {
		return errors.New("failed to cast JsonBuilder")
	}
	if isValid {
		bb.Append(data)
		return nil
	}
	bb.AppendNull()
	return nil
}
// AddOneGeometryToPayload appends a single Geometry row, given as raw bytes,
// to a binary column. When isValid is false the row is stored as null; that
// is only permitted for a nullable writer.
func (w *NativePayloadWriter) AddOneGeometryToPayload(data []byte, isValid bool) error {
	if w.finished {
		return errors.New("can't append data to finished geometry payload")
	}
	if !isValid && !w.nullable {
		return merr.WrapErrParameterInvalidMsg("not support null when nullable is false")
	}
	bb, ok := w.builder.(*array.BinaryBuilder)
	if !ok {
		return errors.New("failed to cast geometryBuilder")
	}
	if isValid {
		bb.Append(data)
		return nil
	}
	bb.AppendNull()
	return nil
}
// AddBinaryVectorToPayload appends binary vectors of dimension dim bits
// (dim/8 bytes per vector).
//
// When the writer is nullable and validData is non-empty, validData holds one
// flag per logical row and data holds only the valid rows, packed back to
// back; rows flagged false are appended as null. Otherwise every dim/8-byte
// chunk of data is one row, and validData must be empty for a non-nullable
// writer.
//
// It is an error to append to a finished writer, to pass a dim smaller than 8,
// to pass empty data outside the nullable-with-validData path, or to pass data
// that does not split evenly into whole vectors.
func (w *NativePayloadWriter) AddBinaryVectorToPayload(data []byte, dim int, validData []bool) error {
	if w.finished {
		return errors.New("can't append data to finished binary vector payload")
	}
	byteLength := dim / 8
	// A zero/negative row width would make the division below panic.
	if byteLength <= 0 {
		msg := fmt.Sprintf("invalid dim(%d) for binary vector payload, dim must be at least 8", dim)
		return merr.WrapErrParameterInvalidMsg(msg)
	}
	var numRows int
	if w.nullable && len(validData) > 0 {
		numRows = len(validData)
		validCount := 0
		for _, valid := range validData {
			if valid {
				validCount++
			}
		}
		expectedDataLen := validCount * byteLength
		if len(data) != expectedDataLen {
			msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
	} else {
		if len(data) == 0 {
			return errors.New("can't add empty msgs into binary vector payload")
		}
		if !w.nullable && len(validData) != 0 {
			msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		// Reject a trailing partial vector instead of silently dropping it.
		if len(data)%byteLength != 0 {
			msg := fmt.Sprintf("data length(%d) must be a multiple of byteLength(%d)", len(data), byteLength)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		numRows = len(data) / byteLength
	}
	if w.nullable {
		builder, ok := w.builder.(*array.BinaryBuilder)
		if !ok {
			return errors.New("failed to cast to BinaryBuilder for nullable BinaryVector")
		}
		builder.Reserve(numRows)
		dataIdx := 0
		for i := 0; i < numRows; i++ {
			if len(validData) > 0 && !validData[i] {
				builder.AppendNull()
				continue
			}
			builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength])
			dataIdx++
		}
		return nil
	}
	builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable BinaryVector")
	}
	builder.Reserve(numRows)
	for i := 0; i < numRows; i++ {
		builder.Append(data[i*byteLength : (i+1)*byteLength])
	}
	return nil
}
// AddFloatVectorToPayload appends float32 vectors of dimension dim. Each
// vector is encoded as dim*4 bytes, with every component written via
// common.Endian.
//
// When the writer is nullable and validData is non-empty, validData holds one
// flag per logical row and data holds only the valid rows, packed back to
// back; rows flagged false are appended as null. Otherwise every dim values
// of data form one row, and validData must be empty for a non-nullable writer.
//
// It is an error to append to a finished writer, to pass a non-positive dim,
// to pass empty data outside the nullable-with-validData path, or to pass data
// whose length is not a multiple of dim.
func (w *NativePayloadWriter) AddFloatVectorToPayload(data []float32, dim int, validData []bool) error {
	if w.finished {
		return errors.New("can't append data to finished float vector payload")
	}
	// A non-positive dim would make the division below panic.
	if dim <= 0 {
		msg := fmt.Sprintf("invalid dim(%d) for float vector payload, dim must be positive", dim)
		return merr.WrapErrParameterInvalidMsg(msg)
	}
	var numRows int
	if w.nullable && len(validData) > 0 {
		numRows = len(validData)
		validCount := 0
		for _, valid := range validData {
			if valid {
				validCount++
			}
		}
		expectedDataLen := validCount * dim
		if len(data) != expectedDataLen {
			msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * dim(%d) = %d", len(data), validCount, dim, expectedDataLen)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
	} else {
		if len(data) == 0 {
			return errors.New("can't add empty msgs into float vector payload")
		}
		if !w.nullable && len(validData) != 0 {
			msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		// Reject a trailing partial vector instead of silently dropping it.
		if len(data)%dim != 0 {
			msg := fmt.Sprintf("data length(%d) must be a multiple of dim(%d)", len(data), dim)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		numRows = len(data) / dim
	}

	byteLength := dim * 4
	// rowBytes is reused for every row; Arrow builders copy the value on Append.
	rowBytes := make([]byte, byteLength)
	encodeRow := func(vec []float32) []byte {
		for j, v := range vec {
			common.Endian.PutUint32(rowBytes[j*4:], math.Float32bits(v))
		}
		return rowBytes
	}

	if w.nullable {
		builder, ok := w.builder.(*array.BinaryBuilder)
		if !ok {
			return errors.New("failed to cast to BinaryBuilder for nullable FloatVector")
		}
		builder.Reserve(numRows)
		dataIdx := 0
		for i := 0; i < numRows; i++ {
			if len(validData) > 0 && !validData[i] {
				builder.AppendNull()
				continue
			}
			builder.Append(encodeRow(data[dataIdx*dim : (dataIdx+1)*dim]))
			dataIdx++
		}
		return nil
	}
	builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable FloatVector")
	}
	builder.Reserve(numRows)
	for i := 0; i < numRows; i++ {
		builder.Append(encodeRow(data[i*dim : (i+1)*dim]))
	}
	return nil
}
// AddFloat16VectorToPayload appends float16 vectors of dimension dim, supplied
// as raw bytes (dim*2 bytes per vector).
//
// When the writer is nullable and validData is non-empty, validData holds one
// flag per logical row and data holds only the valid rows, packed back to
// back; rows flagged false are appended as null. Otherwise every dim*2 bytes
// of data form one row, and validData must be empty for a non-nullable writer.
//
// It is an error to append to a finished writer, to pass a non-positive dim,
// to pass empty data outside the nullable-with-validData path, or to pass data
// that does not split evenly into whole vectors.
func (w *NativePayloadWriter) AddFloat16VectorToPayload(data []byte, dim int, validData []bool) error {
	if w.finished {
		return errors.New("can't append data to finished float16 payload")
	}
	// A non-positive dim would make the division below panic.
	if dim <= 0 {
		msg := fmt.Sprintf("invalid dim(%d) for float16 payload, dim must be positive", dim)
		return merr.WrapErrParameterInvalidMsg(msg)
	}
	byteLength := dim * 2
	var numRows int
	if w.nullable && len(validData) > 0 {
		numRows = len(validData)
		validCount := 0
		for _, valid := range validData {
			if valid {
				validCount++
			}
		}
		expectedDataLen := validCount * byteLength
		if len(data) != expectedDataLen {
			msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
	} else {
		if len(data) == 0 {
			return errors.New("can't add empty msgs into float16 payload")
		}
		if !w.nullable && len(validData) != 0 {
			msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		// Reject a trailing partial vector instead of silently dropping it.
		if len(data)%byteLength != 0 {
			msg := fmt.Sprintf("data length(%d) must be a multiple of byteLength(%d)", len(data), byteLength)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		numRows = len(data) / byteLength
	}
	if w.nullable {
		builder, ok := w.builder.(*array.BinaryBuilder)
		if !ok {
			return errors.New("failed to cast to BinaryBuilder for nullable Float16Vector")
		}
		builder.Reserve(numRows)
		dataIdx := 0
		for i := 0; i < numRows; i++ {
			if len(validData) > 0 && !validData[i] {
				builder.AppendNull()
				continue
			}
			builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength])
			dataIdx++
		}
		return nil
	}
	builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable Float16Vector")
	}
	builder.Reserve(numRows)
	for i := 0; i < numRows; i++ {
		builder.Append(data[i*byteLength : (i+1)*byteLength])
	}
	return nil
}
// AddBFloat16VectorToPayload appends bfloat16 vectors of dimension dim,
// supplied as raw bytes (dim*2 bytes per vector).
//
// When the writer is nullable and validData is non-empty, validData holds one
// flag per logical row and data holds only the valid rows, packed back to
// back; rows flagged false are appended as null. Otherwise every dim*2 bytes
// of data form one row, and validData must be empty for a non-nullable writer.
//
// It is an error to append to a finished writer, to pass a non-positive dim,
// to pass empty data outside the nullable-with-validData path, or to pass data
// that does not split evenly into whole vectors.
func (w *NativePayloadWriter) AddBFloat16VectorToPayload(data []byte, dim int, validData []bool) error {
	if w.finished {
		return errors.New("can't append data to finished BFloat16 payload")
	}
	// A non-positive dim would make the division below panic.
	if dim <= 0 {
		msg := fmt.Sprintf("invalid dim(%d) for BFloat16 payload, dim must be positive", dim)
		return merr.WrapErrParameterInvalidMsg(msg)
	}
	byteLength := dim * 2
	var numRows int
	if w.nullable && len(validData) > 0 {
		numRows = len(validData)
		validCount := 0
		for _, valid := range validData {
			if valid {
				validCount++
			}
		}
		expectedDataLen := validCount * byteLength
		if len(data) != expectedDataLen {
			msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
	} else {
		if len(data) == 0 {
			return errors.New("can't add empty msgs into BFloat16 payload")
		}
		if !w.nullable && len(validData) != 0 {
			msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		// Reject a trailing partial vector instead of silently dropping it.
		if len(data)%byteLength != 0 {
			msg := fmt.Sprintf("data length(%d) must be a multiple of byteLength(%d)", len(data), byteLength)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		numRows = len(data) / byteLength
	}
	if w.nullable {
		builder, ok := w.builder.(*array.BinaryBuilder)
		if !ok {
			return errors.New("failed to cast to BinaryBuilder for nullable BFloat16Vector")
		}
		builder.Reserve(numRows)
		dataIdx := 0
		for i := 0; i < numRows; i++ {
			if len(validData) > 0 && !validData[i] {
				builder.AppendNull()
				continue
			}
			builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength])
			dataIdx++
		}
		return nil
	}
	builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable BFloat16Vector")
	}
	builder.Reserve(numRows)
	for i := 0; i < numRows; i++ {
		builder.Append(data[i*byteLength : (i+1)*byteLength])
	}
	return nil
}
// AddSparseFloatVectorToPayload appends sparse float vectors to a binary
// column. Each entry of data.SparseFloatArray.Contents is one row.
//
// When the writer is nullable and data.ValidData is non-empty, ValidData holds
// one flag per logical row and Contents holds only the valid rows; rows
// flagged false are appended as null. Otherwise every Contents entry is one
// row, and ValidData must be empty for a non-nullable writer.
//
// It is an error to append to a finished writer or to pass nil data.
func (w *NativePayloadWriter) AddSparseFloatVectorToPayload(data *SparseFloatVectorFieldData) error {
	if w.finished {
		return errors.New("can't append data to finished sparse float vector payload")
	}
	// Fail with a clear error instead of a nil-pointer panic below.
	if data == nil {
		return merr.WrapErrParameterInvalidMsg("sparse float vector field data must not be nil")
	}
	var numRows int
	if w.nullable && len(data.ValidData) > 0 {
		numRows = len(data.ValidData)
		validCount := 0
		for _, valid := range data.ValidData {
			if valid {
				validCount++
			}
		}
		if len(data.SparseFloatArray.Contents) != validCount {
			msg := fmt.Sprintf("when nullable, Contents length(%d) must equal to valid count(%d)", len(data.SparseFloatArray.Contents), validCount)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
	} else {
		numRows = len(data.SparseFloatArray.Contents)
		if !w.nullable && len(data.ValidData) != 0 {
			msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(data.ValidData))
			return merr.WrapErrParameterInvalidMsg(msg)
		}
	}
	builder, ok := w.builder.(*array.BinaryBuilder)
	if !ok {
		return errors.New("failed to cast SparseFloatVectorBuilder")
	}
	builder.Reserve(numRows)
	dataIdx := 0
	for i := 0; i < numRows; i++ {
		if w.nullable && len(data.ValidData) > 0 && !data.ValidData[i] {
			builder.AppendNull()
			continue
		}
		builder.Append(data.SparseFloatArray.Contents[dataIdx])
		dataIdx++
	}
	return nil
}
// AddInt8VectorToPayload appends int8 vectors of dimension dim. Each vector is
// stored as its dim raw bytes.
//
// When the writer is nullable and validData is non-empty, validData holds one
// flag per logical row and data holds only the valid rows, packed back to
// back; rows flagged false are appended as null. Otherwise every dim values
// of data form one row, and validData must be empty for a non-nullable writer.
//
// It is an error to append to a finished writer, to pass a non-positive dim,
// to pass empty data outside the nullable-with-validData path, or to pass data
// whose length is not a multiple of dim.
func (w *NativePayloadWriter) AddInt8VectorToPayload(data []int8, dim int, validData []bool) error {
	if w.finished {
		return errors.New("can't append data to finished int8 vector payload")
	}
	// A non-positive dim would make the division below panic.
	if dim <= 0 {
		msg := fmt.Sprintf("invalid dim(%d) for int8 vector payload, dim must be positive", dim)
		return merr.WrapErrParameterInvalidMsg(msg)
	}
	var numRows int
	if w.nullable && len(validData) > 0 {
		numRows = len(validData)
		validCount := 0
		for _, valid := range validData {
			if valid {
				validCount++
			}
		}
		expectedDataLen := validCount * dim
		if len(data) != expectedDataLen {
			msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * dim(%d) = %d", len(data), validCount, dim, expectedDataLen)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
	} else {
		if len(data) == 0 {
			return errors.New("can't add empty msgs into int8 vector payload")
		}
		if !w.nullable && len(validData) != 0 {
			msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		// Reject a trailing partial vector instead of silently dropping it.
		if len(data)%dim != 0 {
			msg := fmt.Sprintf("data length(%d) must be a multiple of dim(%d)", len(data), dim)
			return merr.WrapErrParameterInvalidMsg(msg)
		}
		numRows = len(data) / dim
	}

	// rowAt returns the byte view of the idx-th packed vector in data.
	rowAt := func(idx int) []byte {
		return arrow.Int8Traits.CastToBytes(data[idx*dim : (idx+1)*dim])
	}

	if w.nullable {
		builder, ok := w.builder.(*array.BinaryBuilder)
		if !ok {
			return errors.New("failed to cast to BinaryBuilder for nullable Int8Vector")
		}
		builder.Reserve(numRows)
		dataIdx := 0
		for i := 0; i < numRows; i++ {
			if len(validData) > 0 && !validData[i] {
				builder.AppendNull()
				continue
			}
			builder.Append(rowAt(dataIdx))
			dataIdx++
		}
		return nil
	}
	builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable Int8Vector")
	}
	builder.Reserve(numRows)
	for i := 0; i < numRows; i++ {
		builder.Append(rowAt(i))
	}
	return nil
}
// FinishPayloadWriter seals the writer and writes every buffered row into the
// output buffer as a single-column ("val") parquet table.
//
// ArrayOfVector columns record elementType and dim in the Arrow field
// metadata, and nullable non-sparse vector columns record dim. In those two
// cases the Arrow schema is stored in the parquet file via WithStoreSchema.
// The writer is marked finished before any validation, so it cannot be
// reused even when this call returns an error.
func (w *NativePayloadWriter) FinishPayloadWriter() error {
	if w.finished {
		return errors.New("can't reuse a finished writer")
	}
	w.finished = true

	isVectorArray := w.dataType == schemapb.DataType_ArrayOfVector
	isNullableDenseVector := w.nullable && typeutil.IsVectorType(w.dataType) && !typeutil.IsSparseFloatVectorType(w.dataType)

	var metadata arrow.Metadata
	switch {
	case isVectorArray:
		if w.elementType == nil {
			return errors.New("element type for DataType_ArrayOfVector must be set")
		}
		metadata = arrow.NewMetadata(
			[]string{"elementType", "dim"},
			[]string{fmt.Sprintf("%d", int32(*w.elementType)), fmt.Sprintf("%d", w.dim.GetValue())},
		)
	case isNullableDenseVector:
		metadata = arrow.NewMetadata([]string{"dim"}, []string{fmt.Sprintf("%d", w.dim.GetValue())})
	}

	field := arrow.Field{Name: "val", Type: w.arrowType, Nullable: w.nullable, Metadata: metadata}
	schema := arrow.NewSchema([]arrow.Field{field}, nil)

	// NewArray drains the builder, so count its rows first.
	w.flushedRows += w.builder.Len()
	arr := w.builder.NewArray()
	defer arr.Release()
	col := arrow.NewColumnFromArr(field, arr)
	defer col.Release()
	tbl := array.NewTable(schema, []arrow.Column{col}, int64(col.Len()))
	defer tbl.Release()

	props := pqarrow.DefaultWriterProps()
	if isVectorArray || isNullableDenseVector {
		// Persist the Arrow schema so the field metadata above is kept.
		props = pqarrow.NewArrowWriterProperties(pqarrow.WithStoreSchema())
	}
	return pqarrow.WriteTable(tbl, w.output, 1024*1024*1024, w.writerProps, props)
}
// Reserve forwards size to the Reserve method of the underlying Arrow
// builder, letting callers pre-size it before appending rows.
func (w *NativePayloadWriter) Reserve(size int) {
	b := w.builder
	b.Reserve(size)
}
// GetPayloadBufferFromWriter returns the bytes accumulated in w.output.
// An empty buffer is reported as an error, mirroring the C++ payload writer.
func (w *NativePayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) {
	if buf := w.output.Bytes(); len(buf) > 0 {
		return buf, nil
	}
	// The cpp version of payload writer handles the empty buffer as error.
	return nil, errors.New("empty buffer")
}
// GetPayloadLengthFromWriter reports how many rows this writer has seen.
// The count is the rows already flushed by FinishPayloadWriter plus the rows
// still buffered in the builder. The returned error is always nil.
func (w *NativePayloadWriter) GetPayloadLengthFromWriter() (int, error) {
	pending := w.builder.Len()
	return pending + w.flushedRows, nil
}
// ReleasePayloadWriter releases the underlying Arrow builder. The release is
// guarded by releaseOnce, so only the first call has any effect.
func (w *NativePayloadWriter) ReleasePayloadWriter() {
	// A closure (not a method value) keeps w.builder from being read until
	// the Once actually runs it.
	releaseBuilder := func() {
		w.builder.Release()
	}
	w.releaseOnce.Do(releaseBuilder)
}
// Close releases the writer's builder by delegating to ReleasePayloadWriter.
// That release is once-guarded, so calling Close repeatedly is harmless.
func (w *NativePayloadWriter) Close() {
	w.ReleasePayloadWriter()
}
// MilvusDataTypeToArrowType maps a Milvus data type to the Arrow type used for
// its payload column.
//
// dim is only used by fixed-width vector types, where it sets the byte width
// of each value:
//   - FloatVector: dim*4
//   - Float16Vector and BFloat16Vector: dim*2
//   - Int8Vector: dim
//   - BinaryVector: dim/8
//
// Variable-length types (Array, JSON, Geometry, SparseFloatVector) map to
// Arrow binary, and all string-like types map to Arrow string.
//
// It panics for ArrayOfVector, which needs element-type information (use
// VectorArrayToArrowType). It also panics for any unsupported data type.
func MilvusDataTypeToArrowType(dataType schemapb.DataType, dim int) arrow.DataType {
	switch dataType {
	// Scalars.
	case schemapb.DataType_Bool:
		return &arrow.BooleanType{}
	case schemapb.DataType_Int8:
		return &arrow.Int8Type{}
	case schemapb.DataType_Int16:
		return &arrow.Int16Type{}
	case schemapb.DataType_Int32:
		return &arrow.Int32Type{}
	case schemapb.DataType_Int64, schemapb.DataType_Timestamptz:
		return &arrow.Int64Type{}
	case schemapb.DataType_Float:
		return &arrow.Float32Type{}
	case schemapb.DataType_Double:
		return &arrow.Float64Type{}
	case schemapb.DataType_VarChar, schemapb.DataType_String, schemapb.DataType_Text:
		return &arrow.StringType{}

	// Variable-length serialized values.
	case schemapb.DataType_Array, schemapb.DataType_JSON, schemapb.DataType_Geometry,
		schemapb.DataType_SparseFloatVector:
		return &arrow.BinaryType{}

	// Dense vectors with a fixed per-row width.
	case schemapb.DataType_FloatVector:
		return &arrow.FixedSizeBinaryType{ByteWidth: dim * 4}
	case schemapb.DataType_BinaryVector:
		return &arrow.FixedSizeBinaryType{ByteWidth: dim / 8}
	case schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector:
		return &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}
	case schemapb.DataType_Int8Vector:
		return &arrow.FixedSizeBinaryType{ByteWidth: dim}

	case schemapb.DataType_ArrayOfVector:
		// ArrayOfVector requires elementType, should use VectorArrayToArrowType instead
		panic("ArrayOfVector type requires elementType information, use VectorArrayToArrowType")
	default:
		panic("unsupported data type")
	}
}
// AddVectorArrayFieldDataToPayload appends the rows of a VectorArray
// (ArrayOfVector) field to the payload as an Arrow list column. The writer's
// builder must be an *array.ListBuilder. Encoding is delegated to the
// per-element-type helper selected by data.ElementType.
//
// It returns an error in these cases:
//   - the writer is already finished;
//   - data contains no rows;
//   - the builder is not a ListBuilder;
//   - data.ElementType is not FloatVector, BinaryVector, Float16Vector,
//     BFloat16Vector or Int8Vector;
//   - the selected helper returns an error.
func (w *NativePayloadWriter) AddVectorArrayFieldDataToPayload(data *VectorArrayFieldData) error {
	switch {
	case w.finished:
		return errors.New("can't append data to finished vector array payload")
	case len(data.Data) == 0:
		return errors.New("can't add empty vector array field data")
	}

	listBuilder, isList := w.builder.(*array.ListBuilder)
	if !isList {
		return errors.New("failed to cast to ListBuilder for VectorArray")
	}

	var appendRows func(*array.ListBuilder, *VectorArrayFieldData) error
	switch data.ElementType {
	case schemapb.DataType_FloatVector:
		appendRows = w.addFloatVectorArrayToPayload
	case schemapb.DataType_BinaryVector:
		appendRows = w.addBinaryVectorArrayToPayload
	case schemapb.DataType_Float16Vector:
		appendRows = w.addFloat16VectorArrayToPayload
	case schemapb.DataType_BFloat16Vector:
		appendRows = w.addBFloat16VectorArrayToPayload
	case schemapb.DataType_Int8Vector:
		appendRows = w.addInt8VectorArrayToPayload
	default:
		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", data.ElementType.String()))
	}
	return appendRows(listBuilder, data)
}
// addFloatVectorArrayToPayload appends FloatVector elements of a VectorArray
// field to builder. Each entry of data.Data becomes one list row.
//
// A row's floats are split into vectors of data.Dim values. Each vector is
// encoded as little-endian float32s and appended as one fixed-size binary
// value. The dimension always comes from the validated data.Dim, as in the
// other element-type helpers. The per-row VectorField.Dim is not used: a zero
// value there would cause a division by zero.
//
// It returns an error in these cases:
//   - data.Dim is not positive;
//   - the list's value builder is not a FixedSizeBinaryBuilder;
//   - a row does not hold a FloatVector;
//   - a row's float count is not a multiple of data.Dim.
//
// Rows appended before an error remain in builder.
func (w *NativePayloadWriter) addFloatVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
	if data.Dim <= 0 {
		return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
	}
	valueBuilder, ok := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast value builder to FixedSizeBinaryBuilder for FloatVector array")
	}
	dim := int(data.Dim)
	// FixedSizeBinaryBuilder.Append copies its input, so a single scratch
	// buffer can be reused for every vector.
	buf := make([]byte, dim*4)
	for rowIdx, vectorField := range data.Data {
		floatVector := vectorField.GetFloatVector()
		if floatVector == nil {
			return merr.WrapErrParameterInvalidMsg("expected FloatVector but got different type")
		}
		floatData := floatVector.GetData()
		// Reject a trailing partial vector instead of silently dropping it.
		if len(floatData)%dim != 0 {
			return merr.WrapErrParameterInvalidMsg(fmt.Sprintf(
				"row %d: FloatVector array holds %d values, not a multiple of dim %d", rowIdx, len(floatData), dim))
		}
		// Start a new list for this row.
		builder.Append(true)
		for start := 0; start < len(floatData); start += dim {
			for j, f := range floatData[start : start+dim] {
				binary.LittleEndian.PutUint32(buf[j*4:], math.Float32bits(f))
			}
			valueBuilder.Append(buf)
		}
	}
	return nil
}
// addBinaryVectorArrayToPayload appends BinaryVector elements of a VectorArray
// field to builder. Each entry of data.Data becomes one list row.
//
// Each binary vector packs data.Dim bits, so it occupies (data.Dim+7)/8
// bytes. A row's bytes are split into chunks of that width, and each chunk is
// appended as one fixed-size binary value.
//
// It returns an error in these cases:
//   - data.Dim is not positive;
//   - the list's value builder is not a FixedSizeBinaryBuilder;
//   - a row does not hold a BinaryVector;
//   - a row's byte length is not a multiple of the vector width.
//
// Rows appended before an error remain in builder.
func (w *NativePayloadWriter) addBinaryVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
	if data.Dim <= 0 {
		return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
	}
	valueBuilder, ok := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast value builder to FixedSizeBinaryBuilder for BinaryVector array")
	}
	byteWidth := int((data.Dim + 7) / 8)
	for rowIdx, vectorField := range data.Data {
		binaryData := vectorField.GetBinaryVector()
		if binaryData == nil {
			return merr.WrapErrParameterInvalidMsg("expected BinaryVector but got different type")
		}
		// Reject a trailing partial vector instead of silently dropping it.
		if len(binaryData)%byteWidth != 0 {
			return merr.WrapErrParameterInvalidMsg(fmt.Sprintf(
				"row %d: BinaryVector array holds %d bytes, not a multiple of vector width %d", rowIdx, len(binaryData), byteWidth))
		}
		// Start a new list for this row.
		builder.Append(true)
		for start := 0; start < len(binaryData); start += byteWidth {
			valueBuilder.Append(binaryData[start : start+byteWidth])
		}
	}
	return nil
}
// addFloat16VectorArrayToPayload appends Float16Vector elements of a
// VectorArray field to builder. Each entry of data.Data becomes one list row.
//
// Each vector occupies data.Dim*2 bytes (two bytes per dimension). A row's raw
// bytes are split into chunks of that width, and each chunk is appended as one
// fixed-size binary value.
//
// It returns an error in these cases:
//   - data.Dim is not positive;
//   - the list's value builder is not a FixedSizeBinaryBuilder;
//   - a row does not hold a Float16Vector;
//   - a row's byte length is not a multiple of the vector width.
//
// Rows appended before an error remain in builder.
func (w *NativePayloadWriter) addFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
	if data.Dim <= 0 {
		return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
	}
	valueBuilder, ok := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast value builder to FixedSizeBinaryBuilder for Float16Vector array")
	}
	byteWidth := int(data.Dim * 2)
	for rowIdx, vectorField := range data.Data {
		float16Data := vectorField.GetFloat16Vector()
		if float16Data == nil {
			return merr.WrapErrParameterInvalidMsg("expected Float16Vector but got different type")
		}
		// Reject a trailing partial vector instead of silently dropping it.
		if len(float16Data)%byteWidth != 0 {
			return merr.WrapErrParameterInvalidMsg(fmt.Sprintf(
				"row %d: Float16Vector array holds %d bytes, not a multiple of vector width %d", rowIdx, len(float16Data), byteWidth))
		}
		// Start a new list for this row.
		builder.Append(true)
		for start := 0; start < len(float16Data); start += byteWidth {
			valueBuilder.Append(float16Data[start : start+byteWidth])
		}
	}
	return nil
}
// addBFloat16VectorArrayToPayload appends BFloat16Vector elements of a
// VectorArray field to builder. Each entry of data.Data becomes one list row.
//
// Each vector occupies data.Dim*2 bytes (two bytes per dimension). A row's raw
// bytes are split into chunks of that width, and each chunk is appended as one
// fixed-size binary value.
//
// It returns an error in these cases:
//   - data.Dim is not positive;
//   - the list's value builder is not a FixedSizeBinaryBuilder;
//   - a row does not hold a BFloat16Vector;
//   - a row's byte length is not a multiple of the vector width.
//
// Rows appended before an error remain in builder.
func (w *NativePayloadWriter) addBFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
	if data.Dim <= 0 {
		return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
	}
	valueBuilder, ok := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast value builder to FixedSizeBinaryBuilder for BFloat16Vector array")
	}
	byteWidth := int(data.Dim * 2)
	for rowIdx, vectorField := range data.Data {
		bfloat16Data := vectorField.GetBfloat16Vector()
		if bfloat16Data == nil {
			return merr.WrapErrParameterInvalidMsg("expected BFloat16Vector but got different type")
		}
		// Reject a trailing partial vector instead of silently dropping it.
		if len(bfloat16Data)%byteWidth != 0 {
			return merr.WrapErrParameterInvalidMsg(fmt.Sprintf(
				"row %d: BFloat16Vector array holds %d bytes, not a multiple of vector width %d", rowIdx, len(bfloat16Data), byteWidth))
		}
		// Start a new list for this row.
		builder.Append(true)
		for start := 0; start < len(bfloat16Data); start += byteWidth {
			valueBuilder.Append(bfloat16Data[start : start+byteWidth])
		}
	}
	return nil
}
// addInt8VectorArrayToPayload appends Int8Vector elements of a VectorArray
// field to builder. Each entry of data.Data becomes one list row.
//
// Each vector occupies data.Dim bytes (one byte per dimension). A row's raw
// bytes are split into chunks of that width, and each chunk is appended as one
// fixed-size binary value.
//
// It returns an error in these cases:
//   - data.Dim is not positive;
//   - the list's value builder is not a FixedSizeBinaryBuilder;
//   - a row does not hold an Int8Vector;
//   - a row's byte length is not a multiple of data.Dim.
//
// Rows appended before an error remain in builder.
func (w *NativePayloadWriter) addInt8VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
	if data.Dim <= 0 {
		return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
	}
	valueBuilder, ok := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
	if !ok {
		return errors.New("failed to cast value builder to FixedSizeBinaryBuilder for Int8Vector array")
	}
	byteWidth := int(data.Dim)
	for rowIdx, vectorField := range data.Data {
		int8Data := vectorField.GetInt8Vector()
		if int8Data == nil {
			return merr.WrapErrParameterInvalidMsg("expected Int8Vector but got different type")
		}
		// Reject a trailing partial vector instead of silently dropping it.
		if len(int8Data)%byteWidth != 0 {
			return merr.WrapErrParameterInvalidMsg(fmt.Sprintf(
				"row %d: Int8Vector array holds %d bytes, not a multiple of dim %d", rowIdx, len(int8Data), byteWidth))
		}
		// Start a new list for this row.
		builder.Append(true)
		for start := 0; start < len(int8Data); start += byteWidth {
			valueBuilder.Append(int8Data[start : start+byteWidth])
		}
	}
	return nil
}