diff --git a/client/column/columns.go b/client/column/columns.go index c3350e1732..2b1a885d9b 100644 --- a/client/column/columns.go +++ b/client/column/columns.go @@ -46,6 +46,7 @@ type Column interface { SetNullable(bool) ValidateNullable() error CompactNullableValues() + ValidCount() int } var errFieldDataTypeNotMatch = errors.New("FieldData type not matched") @@ -239,10 +240,39 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { } data := x.FloatVector.GetData() dim := int(vectors.GetDim()) + + if len(validData) > 0 { + if end < 0 { + end = len(validData) + } + vector := make([][]float32, 0, end-begin) + dataIdx := 0 + for i := 0; i < begin; i++ { + if validData[i] { + dataIdx++ + } + } + for i := begin; i < end; i++ { + if validData[i] { + v := make([]float32, dim) + copy(v, data[dataIdx*dim:(dataIdx+1)*dim]) + vector = append(vector, v) + dataIdx++ + } else { + vector = append(vector, nil) + } + } + col := NewColumnFloatVector(fd.GetFieldName(), dim, vector) + col.withValidData(validData[begin:end]) + col.nullable = true + col.sparseMode = true + return col, nil + } + if end < 0 { end = len(data) / dim } - vector := make([][]float32, 0, end-begin) // shall not have remanunt + vector := make([][]float32, 0, end-begin) for i := begin; i < end; i++ { v := make([]float32, dim) copy(v, data[i*dim:(i+1)*dim]) @@ -262,6 +292,35 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { } dim := int(vectors.GetDim()) blen := dim / 8 + + if len(validData) > 0 { + if end < 0 { + end = len(validData) + } + vector := make([][]byte, 0, end-begin) + dataIdx := 0 + for i := 0; i < begin; i++ { + if validData[i] { + dataIdx++ + } + } + for i := begin; i < end; i++ { + if validData[i] { + v := make([]byte, blen) + copy(v, data[dataIdx*blen:(dataIdx+1)*blen]) + vector = append(vector, v) + dataIdx++ + } else { + vector = append(vector, nil) + } + } + col := NewColumnBinaryVector(fd.GetFieldName(), dim, vector) + col.withValidData(validData[begin:end]) + col.nullable = true + col.sparseMode = true + return col, nil + } + if end < 0 { end = len(data) / blen } @@ -281,13 +340,43 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { } data := x.Float16Vector dim := int(vectors.GetDim()) + bytePerRow := dim * 2 + + if len(validData) > 0 { + if end < 0 { + end = len(validData) + } + vector := make([][]byte, 0, end-begin) + dataIdx := 0 + for i := 0; i < begin; i++ { + if validData[i] { + dataIdx++ + } + } + for i := begin; i < end; i++ { + if validData[i] { + v := make([]byte, bytePerRow) + copy(v, data[dataIdx*bytePerRow:(dataIdx+1)*bytePerRow]) + vector = append(vector, v) + dataIdx++ + } else { + vector = append(vector, nil) + } + } + col := NewColumnFloat16Vector(fd.GetFieldName(), dim, vector) + col.withValidData(validData[begin:end]) + col.nullable = true + col.sparseMode = true + return col, nil + } + if end < 0 { - end = len(data) / dim / 2 + end = len(data) / bytePerRow } vector := make([][]byte, 0, end-begin) for i := begin; i < end; i++ { - v := make([]byte, dim*2) - copy(v, data[i*dim*2:(i+1)*dim*2]) + v := make([]byte, bytePerRow) + copy(v, data[i*bytePerRow:(i+1)*bytePerRow]) vector = append(vector, v) } return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil @@ -300,13 +389,43 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { } data := x.Bfloat16Vector dim := int(vectors.GetDim()) - if end < 0 { - end = len(data) / dim / 2 + bytePerRow := dim * 2 + + if len(validData) > 0 { + if end < 0 { + end = len(validData) + } + vector := make([][]byte, 0, end-begin) + dataIdx := 0 + for i := 0; i < begin; i++ { + if validData[i] { + dataIdx++ + } + } + for i := begin; i < end; i++ { + if validData[i] { + v := make([]byte, bytePerRow) + copy(v, data[dataIdx*bytePerRow:(dataIdx+1)*bytePerRow]) + vector = append(vector, v) + dataIdx++ + } else { + vector = append(vector, nil) + } + } + col := NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector) + col.withValidData(validData[begin:end]) + col.nullable = true + col.sparseMode = true + return col, nil } - vector := make([][]byte, 0, end-begin) // shall not have remanunt + + if end < 0 { + end = len(data) / bytePerRow + } + vector := make([][]byte, 0, end-begin) for i := begin; i < end; i++ { - v := make([]byte, dim*2) - copy(v, data[i*dim*2:(i+1)*dim*2]) + v := make([]byte, bytePerRow) + copy(v, data[i*bytePerRow:(i+1)*bytePerRow]) vector = append(vector, v) } return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil @@ -317,6 +436,37 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { return nil, errFieldDataTypeNotMatch } data := sparseVectors.Contents + + if len(validData) > 0 { + if end < 0 { + end = len(validData) + } + vectors := make([]entity.SparseEmbedding, 0, end-begin) + dataIdx := 0 + for i := 0; i < begin; i++ { + if validData[i] { + dataIdx++ + } + } + for i := begin; i < end; i++ { + if validData[i] { + vector, err := entity.DeserializeSliceSparseEmbedding(data[dataIdx]) + if err != nil { + return nil, err + } + vectors = append(vectors, vector) + dataIdx++ + } else { + vectors = append(vectors, nil) + } + } + col := NewColumnSparseVectors(fd.GetFieldName(), vectors) + col.withValidData(validData[begin:end]) + col.nullable = true + col.sparseMode = true + return col, nil + } + if end < 0 { end = len(data) } @@ -339,11 +489,41 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) { } data := x.Int8Vector dim := int(vectors.GetDim()) + + if len(validData) > 0 { + if end < 0 { + end = len(validData) + } + vector := make([][]int8, 0, end-begin) + dataIdx := 0 + for i := 0; i < begin; i++ { + if validData[i] { + dataIdx++ + } + } + for i := begin; i < end; i++ { + if validData[i] { + v := make([]int8, dim) + for j := 0; j < dim; j++ { + v[j] = int8(data[dataIdx*dim+j]) + } + vector = append(vector, v) + dataIdx++ + } else { + vector = append(vector, nil) + } + } + col := NewColumnInt8Vector(fd.GetFieldName(), dim, vector) + col.withValidData(validData[begin:end]) + col.nullable = true + col.sparseMode = true + return col, nil + } + if end < 0 { end = len(data) / dim } - vector := make([][]int8, 0, end-begin) // shall not have remanunt - // TODO caiyd: has better way to convert []byte to []int8 ? + vector := make([][]int8, 0, end-begin) for i := begin; i < end; i++ { v := make([]int8, dim) for j := 0; j < dim; j++ { diff --git a/client/column/generic_base.go b/client/column/generic_base.go index b85f10b24f..5c1c6b46bc 100644 --- a/client/column/generic_base.go +++ b/client/column/generic_base.go @@ -301,6 +301,19 @@ func (c *genericColumnBase[T]) CompactNullableValues() { c.values = c.values[0:cnt] } +func (c *genericColumnBase[T]) ValidCount() int { + if !c.nullable || len(c.validData) == 0 { + return len(c.values) + } + count := 0 + for _, v := range c.validData { + if v { + count++ + } + } + return count +} + func (c *genericColumnBase[T]) withValidData(validData []bool) { if len(validData) > 0 { c.nullable = true diff --git a/client/column/nullable.go b/client/column/nullable.go index a2f02a1f0e..19d21061d6 100644 --- a/client/column/nullable.go +++ b/client/column/nullable.go @@ -16,6 +16,12 @@ package column +import ( + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/client/v2/entity" +) + var ( // scalars NewNullableColumnBool NullableColumnCreateFunc[bool, *ColumnBool] = NewNullableColumnCreator(NewColumnBool).New @@ -41,6 +47,76 @@ var ( NewNullableColumnDoubleArray NullableColumnCreateFunc[[]float64, *ColumnDoubleArray] = NewNullableColumnCreator(NewColumnDoubleArray).New ) +func NewNullableColumnFloatVector(fieldName string, dim int, values [][]float32, validData []bool) (*ColumnFloatVector, error) { + if len(values) != getValidCount(validData) { + return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData)) + } + col := NewColumnFloatVector(fieldName, dim, values) + col.withValidData(validData) + col.nullable = true + return col, nil +} + +func NewNullableColumnBinaryVector(fieldName string, dim int, values [][]byte, validData []bool) (*ColumnBinaryVector, error) { + if len(values) != getValidCount(validData) { + return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData)) + } + col := NewColumnBinaryVector(fieldName, dim, values) + col.withValidData(validData) + col.nullable = true + return col, nil +} + +func NewNullableColumnFloat16Vector(fieldName string, dim int, values [][]byte, validData []bool) (*ColumnFloat16Vector, error) { + if len(values) != getValidCount(validData) { + return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData)) + } + col := NewColumnFloat16Vector(fieldName, dim, values) + col.withValidData(validData) + col.nullable = true + return col, nil +} + +func NewNullableColumnBFloat16Vector(fieldName string, dim int, values [][]byte, validData []bool) (*ColumnBFloat16Vector, error) { + if len(values) != getValidCount(validData) { + return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData)) + } + col := NewColumnBFloat16Vector(fieldName, dim, values) + col.withValidData(validData) + col.nullable = true + return col, nil +} + +func NewNullableColumnInt8Vector(fieldName string, dim int, values [][]int8, validData []bool) (*ColumnInt8Vector, error) { + if len(values) != getValidCount(validData) { + return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData)) + } + col := NewColumnInt8Vector(fieldName, dim, values) + col.withValidData(validData) + col.nullable = true + return col, nil +} + +func NewNullableColumnSparseFloatVector(fieldName string, values []entity.SparseEmbedding, validData []bool) (*ColumnSparseFloatVector, error) { + if len(values) != getValidCount(validData) { + return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), getValidCount(validData)) + } + col := NewColumnSparseVectors(fieldName, values) + col.withValidData(validData) + col.nullable = true + return col, nil +} + +func getValidCount(validData []bool) int { + count := 0 + for _, v := range validData { + if v { + count++ + } + } + return count +} + type NullableColumnCreateFunc[T any, Col interface { Column Data() []T diff --git a/client/column/sparse.go b/client/column/sparse.go index 8b68ace541..ee152eec5c 100644 --- a/client/column/sparse.go +++ b/client/column/sparse.go @@ -38,11 +38,15 @@ func NewColumnSparseVectors(name string, values []entity.SparseEmbedding) *Colum func (c *ColumnSparseFloatVector) FieldData() *schemapb.FieldData { fd := c.vectorBase.FieldData() - max := lo.MaxBy(c.values, func(a, b entity.SparseEmbedding) bool { - return a.Dim() > b.Dim() - }) vectors := fd.GetVectors() - vectors.Dim = int64(max.Dim()) + if len(c.values) > 0 { + max := lo.MaxBy(c.values, func(a, b entity.SparseEmbedding) bool { + return a.Dim() > b.Dim() + }) + vectors.Dim = int64(max.Dim()) + } else { + vectors.Dim = 0 + } return fd } diff --git a/client/column/struct.go b/client/column/struct.go index 4b06886fc1..abc2c139f8 100644 --- a/client/column/struct.go +++ b/client/column/struct.go @@ -136,3 +136,7 @@ func (c *columnStructArray) CompactNullableValues() { field.CompactNullableValues() } } + +func (c *columnStructArray) ValidCount() int { + return c.Len() +} diff --git a/client/entity/field.go b/client/entity/field.go index b1795ef51e..66a092fef5 100644 --- a/client/entity/field.go +++ b/client/entity/field.go @@ -206,6 +206,17 @@ const ( FieldTypeStruct FieldType = 201 ) +// IsVectorType returns true if the field type is a vector type +func (t FieldType) IsVectorType() bool { + switch t { + case FieldTypeBinaryVector, FieldTypeFloatVector, FieldTypeFloat16Vector, + FieldTypeBFloat16Vector, FieldTypeSparseVector, FieldTypeInt8Vector: + return true + default: + return false + } +} + // Field represent field schema in milvus type Field struct { ID int64 // field id, generated when collection is created, input value is ignored diff --git a/client/milvusclient/collection.go b/client/milvusclient/collection.go index 5155ccfcc3..2f5e2431c2 100644 --- a/client/milvusclient/collection.go +++ b/client/milvusclient/collection.go @@ -185,6 +185,10 @@ func (c *Client) GetCollectionStats(ctx context.Context, opt GetCollectionOption // AddCollectionField adds a field to a collection. func (c *Client) AddCollectionField(ctx context.Context, opt AddCollectionFieldOption, callOpts ...grpc.CallOption) error { + if err := opt.Validate(); err != nil { + return err + } + req := opt.Request() err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { diff --git a/client/milvusclient/collection_options.go b/client/milvusclient/collection_options.go index ba994ace8e..49b0810996 100644 --- a/client/milvusclient/collection_options.go +++ b/client/milvusclient/collection_options.go @@ -405,6 +405,8 @@ func NewGetCollectionStatsOption(collectionName string) *getCollectionStatsOptio type AddCollectionFieldOption interface { Request() *milvuspb.AddCollectionFieldRequest + // Validate validates the option before sending request + Validate() error } type addCollectionFieldOption struct { @@ -420,6 +422,15 @@ func (c *addCollectionFieldOption) Request() *milvuspb.AddCollectionFieldRequest } } +// Validate validates the option before sending request +func (c *addCollectionFieldOption) Validate() error { + // Vector fields must be nullable when adding to existing collection + if c.fieldSch.DataType.IsVectorType() && !c.fieldSch.Nullable { + return fmt.Errorf("adding vector field to existing collection requires nullable=true, field name = %s", c.fieldSch.Name) + } + return nil +} + func NewAddCollectionFieldOption(collectionName string, field *entity.Field) *addCollectionFieldOption { return &addCollectionFieldOption{ collectionName: collectionName, diff --git a/client/milvusclient/collection_test.go b/client/milvusclient/collection_test.go index 18d21699c1..13058ca33f 100644 --- a/client/milvusclient/collection_test.go +++ b/client/milvusclient/collection_test.go @@ -441,6 +441,37 @@ func (s *CollectionSuite) TestAddCollectionField() { err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field)) s.Error(err) }) + + s.Run("vector_field_without_nullable", func() { + collName := fmt.Sprintf("coll_%s", s.randString(6)) + fieldName := fmt.Sprintf("field_%s", s.randString(6)) + // no mock expected because validation should fail before RPC call + + field := entity.NewField().WithName(fieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(128) + + err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field)) + s.Error(err) + s.Contains(err.Error(), "adding vector field to existing collection requires nullable=true") + }) + + s.Run("vector_field_with_nullable", func() { + collName := fmt.Sprintf("coll_%s", s.randString(6)) + fieldName := fmt.Sprintf("field_%s", s.randString(6)) + s.mock.EXPECT().AddCollectionField(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, acfr *milvuspb.AddCollectionFieldRequest) (*commonpb.Status, error) { + fieldProto := &schemapb.FieldSchema{} + err := proto.Unmarshal(acfr.GetSchema(), fieldProto) + s.Require().NoError(err) + s.Equal(fieldName, fieldProto.GetName()) + s.Equal(schemapb.DataType_FloatVector, fieldProto.GetDataType()) + s.True(fieldProto.GetNullable()) + return merr.Success(), nil + }).Once() + + field := entity.NewField().WithName(fieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(128).WithNullable(true) + + err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field)) + s.NoError(err) + }) } func TestCollection(t *testing.T) { diff --git a/internal/core/src/common/Chunk.h b/internal/core/src/common/Chunk.h index 3fb44b1cce..6e30766011 100644 --- a/internal/core/src/common/Chunk.h +++ b/internal/core/src/common/Chunk.h @@ -126,6 +126,11 @@ class Chunk { return data_; } + FixedVector& + Valid() { + return valid_; + } + virtual bool isValid(int offset) const { if (nullable_) { @@ -559,17 +564,32 @@ class SparseFloatVectorChunk : public Chunk { bool nullable, std::shared_ptr chunk_mmap_guard) : Chunk(row_nums, data, size, nullable, chunk_mmap_guard) { - vec_.resize(row_nums); auto null_bitmap_bytes_num = nullable ? (row_nums + 7) / 8 : 0; auto offsets_ptr = reinterpret_cast(data + null_bitmap_bytes_num); - for (int i = 0; i < row_nums; i++) { - vec_[i] = {(offsets_ptr[i + 1] - offsets_ptr[i]) / - knowhere::sparse::SparseRow< - SparseValueType>::element_size(), - reinterpret_cast(data + offsets_ptr[i]), - false}; - dim_ = std::max(dim_, vec_[i].dim()); + + if (nullable_) { + for (int i = 0; i < row_nums; i++) { + if (isValid(i)) { + vec_.emplace_back( + (offsets_ptr[i + 1] - offsets_ptr[i]) / + knowhere::sparse::SparseRow< + SparseValueType>::element_size(), + reinterpret_cast(data + offsets_ptr[i]), + false); + dim_ = std::max(dim_, vec_.back().dim()); + } + } + } else { + vec_.resize(row_nums); + for (int i = 0; i < row_nums; i++) { + vec_[i] = {(offsets_ptr[i + 1] - offsets_ptr[i]) / + knowhere::sparse::SparseRow< + SparseValueType>::element_size(), + reinterpret_cast(data + offsets_ptr[i]), + false}; + dim_ = std::max(dim_, vec_[i].dim()); + } } } diff --git a/internal/core/src/common/ChunkWriter.cpp b/internal/core/src/common/ChunkWriter.cpp index 358e7ff625..2d73be88f5 100644 --- a/internal/core/src/common/ChunkWriter.cpp +++ b/internal/core/src/common/ChunkWriter.cpp @@ -435,8 +435,10 @@ SparseFloatVectorChunkWriter::calculate_size( for (const auto& data : array_vec) { auto array = std::dynamic_pointer_cast(data); for (int64_t i = 0; i < array->length(); ++i) { - auto str = array->GetView(i); - size += str.size(); + if (!nullable_ || !array->IsNull(i)) { + auto str = array->GetView(i); + size += str.size(); + } } row_nums_ += array->length(); } @@ -459,8 +461,10 @@ SparseFloatVectorChunkWriter::write_to_target( for (const auto& data : array_vec) { auto array = std::dynamic_pointer_cast(data); for (int64_t i = 0; i < array->length(); ++i) { - auto str = array->GetView(i); - strs.emplace_back(str); + if (!nullable_ || !array->IsNull(i)) { + auto str = array->GetView(i); + strs.emplace_back(str); + } } if (nullable_) { null_bitmaps.emplace_back( @@ -478,9 +482,23 @@ SparseFloatVectorChunkWriter::write_to_target( std::vector offsets; offsets.reserve(offset_num); - for (const auto& str : strs) { - offsets.push_back(offset_start_pos); - offset_start_pos += str.size(); + if (nullable_) { + size_t str_idx = 0; + for (const auto& data : array_vec) { + auto array = std::dynamic_pointer_cast(data); + for (int i = 0; i < array->length(); i++) { + offsets.push_back(offset_start_pos); + if (!array->IsNull(i)) { + offset_start_pos += strs[str_idx].size(); + str_idx++; + } + } + } + } else { + for (const auto& str : strs) { + offsets.push_back(offset_start_pos); + offset_start_pos += str.size(); + } } offsets.push_back(offset_start_pos); @@ -524,22 +542,43 @@ create_chunk_writer(const FieldMeta& field_meta) { return std::make_shared>( dim, nullable); case milvus::DataType::VECTOR_FLOAT: + if (nullable) { + return std::make_shared< + NullableVectorChunkWriter>(dim, nullable); + } return std::make_shared< ChunkWriter>( dim, nullable); case milvus::DataType::VECTOR_BINARY: + if (nullable) { + return std::make_shared< + NullableVectorChunkWriter>(dim / 8, + nullable); + } return std::make_shared< ChunkWriter>( dim / 8, nullable); case milvus::DataType::VECTOR_FLOAT16: + if (nullable) { + return std::make_shared< + NullableVectorChunkWriter>(dim, nullable); + } return std::make_shared< ChunkWriter>( dim, nullable); case milvus::DataType::VECTOR_BFLOAT16: + if (nullable) { + return std::make_shared< + NullableVectorChunkWriter>(dim, nullable); + } return std::make_shared< ChunkWriter>( dim, nullable); case milvus::DataType::VECTOR_INT8: + if (nullable) { + return std::make_shared< + NullableVectorChunkWriter>(dim, nullable); + } return std::make_shared< ChunkWriter>( dim, nullable); diff --git a/internal/core/src/common/ChunkWriter.h b/internal/core/src/common/ChunkWriter.h index 68c588844a..d7c3485bac 100644 --- a/internal/core/src/common/ChunkWriter.h +++ b/internal/core/src/common/ChunkWriter.h @@ -129,6 +129,57 @@ class ChunkWriter final : public ChunkWriterBase { const int64_t dim_; }; +template +class NullableVectorChunkWriter final : public ChunkWriterBase { + public: + NullableVectorChunkWriter(int64_t dim, bool nullable) + : ChunkWriterBase(nullable), dim_(dim) { + Assert(nullable && "NullableVectorChunkWriter requires nullable=true"); + } + + std::pair + calculate_size(const arrow::ArrayVector& array_vec) override { + size_t size = 0; + size_t row_nums = 0; + + for (const auto& data : array_vec) { + row_nums += data->length(); + auto binary_array = + std::static_pointer_cast(data); + int64_t valid_count = data->length() - binary_array->null_count(); + size += valid_count * dim_ * sizeof(T); + } + + // null bitmap size + size += (row_nums + 7) / 8; + row_nums_ = row_nums; + return {size, row_nums}; + } + + void + write_to_target(const arrow::ArrayVector& array_vec, + const std::shared_ptr& target) override { + std::vector> null_bitmaps; + for (const auto& data : array_vec) { + null_bitmaps.emplace_back( + data->null_bitmap_data(), data->length(), data->offset()); + } + write_null_bit_maps(null_bitmaps, target); + + for (const auto& data : array_vec) { + auto binary_array = + std::static_pointer_cast(data); + auto data_offset = binary_array->value_offset(0); + auto data_ptr = binary_array->value_data()->data() + data_offset; + int64_t valid_count = data->length() - binary_array->null_count(); + target->write(data_ptr, valid_count * dim_ * sizeof(T)); + } + } + + private: + const int64_t dim_; +}; + template <> inline void ChunkWriter::write_to_target( diff --git a/internal/core/src/common/FieldData.cpp b/internal/core/src/common/FieldData.cpp index 4ed29bb422..0251cc8add 100644 --- a/internal/core/src/common/FieldData.cpp +++ b/internal/core/src/common/FieldData.cpp @@ -20,6 +20,7 @@ #include "arrow/array/array_binary.h" #include "arrow/chunked_array.h" #include "bitset/detail/element_wise.h" +#include "bitset/detail/popcount.h" #include "common/Array.h" #include "common/EasyAssert.h" #include "common/Exception.h" @@ -310,6 +311,18 @@ FieldDataImpl::FillFieldData( case DataType::VECTOR_BFLOAT16: case DataType::VECTOR_INT8: case DataType::VECTOR_BINARY: { + if (nullable_) { + auto binary_array = + std::dynamic_pointer_cast(array); + AssertInfo(binary_array != nullptr, + "nullable vector must use BinaryArray"); + auto data_offset = binary_array->value_offset(0); + return FillFieldData( + binary_array->value_data()->data() + data_offset, + binary_array->null_bitmap_data(), + binary_array->length(), + binary_array->offset()); + } auto array_info = GetDataInfoFromArray( @@ -321,6 +334,20 @@ FieldDataImpl::FillFieldData( "inconsistent data type"); auto arr = std::dynamic_pointer_cast(array); std::vector> values; + + if (nullable_) { + for (int64_t i = 0; i < element_count; ++i) { + if (arr->IsValid(i)) { + auto view = arr->GetString(i); + values.push_back( + CopyAndWrapSparseRow(view.data(), view.size())); + } + } + return FillFieldData(values.data(), + arr->null_bitmap_data(), + arr->length(), + arr->offset()); + } for (size_t index = 0; index < element_count; ++index) { auto view = arr->GetString(index); values.push_back( @@ -572,6 +599,96 @@ template class FieldDataImpl, true>; template class FieldDataImpl; +template +void +FieldDataVectorImpl::FillFieldData( + const void* field_data, + const uint8_t* valid_data, + ssize_t total_element_count, + ssize_t offset) { + AssertInfo(this->nullable_, "requires nullable to be true"); + if (total_element_count == 0) { + return; + } + + int64_t valid_count = 0; + if (valid_data) { + int64_t bit_start = offset; + int64_t bit_end = offset + total_element_count; + + // Handle head: unaligned bits before first full byte + int64_t first_full_byte = (bit_start + 7) / 8; + int64_t last_full_byte = bit_end / 8; + + // Process unaligned head bits + for (int64_t bit_idx = bit_start; + bit_idx < std::min(first_full_byte * 8, bit_end); + ++bit_idx) { + if ((valid_data[bit_idx >> 3] >> (bit_idx & 7)) & 1) { + valid_count++; + } + } + + // Process aligned full bytes with popcount + for (int64_t byte_idx = first_full_byte; byte_idx < last_full_byte; + ++byte_idx) { + valid_count += bitset::detail::PopCountHelper::count( + valid_data[byte_idx]); + } + + // Process unaligned tail bits + for (int64_t bit_idx = + std::max(last_full_byte * 8, first_full_byte * 8); + bit_idx < bit_end; + ++bit_idx) { + if ((valid_data[bit_idx >> 3] >> (bit_idx & 7)) & 1) { + valid_count++; + } + } + } else { + valid_count = total_element_count; + } + + std::lock_guard lck(this->tell_mutex_); + resize_field_data(this->length_ + total_element_count, + this->valid_count_ + valid_count); + + if (valid_data) { + bitset::detail::ElementWiseBitsetPolicy::op_copy( + valid_data, + offset, + this->valid_data_.data(), + this->length_, + total_element_count); + } + + // update logical to physical offset mapping + l2p_mapping_.build(this->valid_data_.data(), + this->valid_count_, + this->length_, + total_element_count, + valid_count); + + if (valid_count > 0) { + std::copy_n(static_cast(field_data), + valid_count * this->dim_, + this->data_.data() + this->valid_count_ * this->dim_); + this->valid_count_ += valid_count; + } + + this->null_count_ = total_element_count - valid_count; + this->length_ += total_element_count; +} + +// explicit instantiations for FieldDataVectorImpl +template class FieldDataVectorImpl; +template class FieldDataVectorImpl; +template class FieldDataVectorImpl; +template class FieldDataVectorImpl; +template class FieldDataVectorImpl; +template class FieldDataVectorImpl, + true>; + FieldDataPtr InitScalarFieldData(const DataType& type, bool nullable, int64_t cap_rows) { switch (type) { diff --git a/internal/core/src/common/FieldData.h b/internal/core/src/common/FieldData.h index ddfe3fbb81..fe9c87e742 100644 --- a/internal/core/src/common/FieldData.h +++ b/internal/core/src/common/FieldData.h @@ -19,6 +19,8 @@ #include #include #include +#include +#include #include @@ -133,24 +135,251 @@ class FieldData : public FieldDataVectorArrayImpl { DataType element_type_; }; +template +class FieldDataVectorImpl : public FieldDataImpl { + private: + struct LogicalToPhysicalMapping { + bool mapping{false}; + std::unordered_map l2p_map; + std::vector l2p_vec; + + int64_t + get_physical_offset(int64_t logical_offset) const { + if (!mapping) { + return logical_offset; + } + if (!l2p_map.empty()) { + auto it = l2p_map.find(logical_offset); + if (it != l2p_map.end()) { + return it->second; + } + return -1; + } + if (logical_offset < static_cast(l2p_vec.size())) { + return l2p_vec[logical_offset]; + } + return -1; + } + + void + build(const uint8_t* valid_data, + int64_t start_physical, + int64_t start_logical, + int64_t total_count, + int64_t valid_count) { + if (total_count == 0) { + return; + } + + mapping = true; + + // use map when valid ratio < 10% + bool use_map = (valid_count * 10 < total_count); + + if (use_map) { + int64_t physical_idx = start_physical; + for (int64_t i = 0; i < total_count; ++i) { + int64_t bit_pos = start_logical + i; + if (valid_data == nullptr || + ((valid_data[bit_pos >> 3] >> (bit_pos & 0x07)) & 1)) { + l2p_map[start_logical + i] = physical_idx++; + } + } + } else { + // resize l2p_vec if needed + int64_t required_size = start_logical + total_count; + if (static_cast(l2p_vec.size()) < required_size) { + l2p_vec.resize(required_size, -1); + } + int64_t physical_idx = start_physical; + for (int64_t i = 0; i < total_count; ++i) { + int64_t bit_pos = start_logical + i; + if (valid_data == nullptr || + ((valid_data[bit_pos >> 3] >> (bit_pos & 0x07)) & 1)) { + l2p_vec[start_logical + i] = physical_idx++; + } else { + l2p_vec[start_logical + i] = -1; + } + } + } + } + }; + + void + resize_field_data(int64_t num_rows, int64_t valid_count) { + Assert(this->nullable_); + std::lock_guard lck(this->num_rows_mutex_); + if (num_rows > this->num_rows_) { + this->num_rows_ = num_rows; + this->valid_data_.resize((num_rows + 7) / 8, 0x00); + } + if (valid_count > this->valid_count_) { + this->data_.resize(valid_count * this->dim_); + } + } + + LogicalToPhysicalMapping l2p_mapping_; + + public: + using FieldDataImpl::FieldDataImpl; + using FieldDataImpl::resize_field_data; + + void + FillFieldData(const void* field_data, + const uint8_t* valid_data, + ssize_t element_count, + ssize_t offset) override; + + const void* + RawValue(ssize_t offset) const override { + auto physical_offset = l2p_mapping_.get_physical_offset(offset); + if (physical_offset == -1) { + return nullptr; + } + return &this->data_[physical_offset * this->dim_]; + } + + int64_t + DataSize() const override { + auto dim = this->dim_; + if (this->nullable_) { + return sizeof(Type) * this->valid_count_ * dim; + } + return sizeof(Type) * this->length_ * dim; + } + + int64_t + DataSize(ssize_t offset) const override { + auto dim = this->dim_; + AssertInfo(offset < this->get_num_rows(), + "field data subscript out of range"); + return sizeof(Type) * dim; + } + + int64_t + get_valid_rows() const override { + if (this->nullable_) { + return this->valid_count_; + } + return this->get_num_rows(); + } +}; + +class FieldDataSparseVectorImpl + : public FieldDataVectorImpl, + true> { + using Base = + FieldDataVectorImpl, true>; + + public: + // Bring base class FillFieldData overloads into scope (for nullable support) + using Base::FillFieldData; + + explicit FieldDataSparseVectorImpl(DataType data_type, + bool nullable = false, + int64_t total_num_rows = 0) + : FieldDataVectorImpl, + true>( + /*dim=*/1, data_type, nullable, total_num_rows), + vec_dim_(0) { + AssertInfo(data_type == DataType::VECTOR_SPARSE_U32_F32, + "invalid data type for sparse vector"); + } + + int64_t + DataSize() const override { + int64_t data_size = 0; + size_t count = nullable_ ? valid_count_ : length_; + for (size_t i = 0; i < count; ++i) { + data_size += data_[i].data_byte_size(); + } + return data_size; + } + + int64_t + DataSize(ssize_t offset) const override { + AssertInfo(offset < get_num_rows(), + "field data subscript out of range"); + size_t count = nullable_ ? valid_count_ : length_; + AssertInfo( + offset < count, + "subscript position don't has valid value offset={}, count={}", + offset, + count); + return data_[offset].data_byte_size(); + } + + void + FillFieldData(const void* source, ssize_t element_count) override { + if (element_count == 0) { + return; + } + + std::lock_guard lck(tell_mutex_); + if (length_ + element_count > get_num_rows()) { + FieldDataImpl::resize_field_data(length_ + element_count); + } + auto ptr = + static_cast*>( + source); + for (int64_t i = 0; i < element_count; ++i) { + auto& row = ptr[i]; + vec_dim_ = std::max(vec_dim_, row.dim()); + } + std::copy_n(ptr, element_count, data_.data() + length_); + length_ += element_count; + } + + void + FillFieldData(const std::shared_ptr& array) override { + auto n = array->length(); + if (n == 0) { + return; + } + + std::lock_guard lck(tell_mutex_); + if (length_ + n > get_num_rows()) { + FieldDataImpl::resize_field_data(length_ + n); + } + + for (int64_t i = 0; i < array->length(); ++i) { + auto view = array->GetView(i); + auto& row = data_[length_ + i]; + row = CopyAndWrapSparseRow(view.data(), view.size()); + vec_dim_ = std::max(vec_dim_, row.dim()); + } + length_ += n; + } + + int64_t + Dim() const { + return vec_dim_; + } + + private: + int64_t vec_dim_ = 0; +}; + template <> -class FieldData : public FieldDataImpl { +class FieldData : public FieldDataVectorImpl { public: explicit FieldData(int64_t dim, DataType data_type, + bool nullable, int64_t buffered_num_rows = 0) - : FieldDataImpl::FieldDataImpl( - dim, data_type, false, buffered_num_rows) { + : FieldDataVectorImpl::FieldDataVectorImpl( + dim, data_type, nullable, buffered_num_rows) { } }; template <> -class FieldData : public FieldDataImpl { +class FieldData : public FieldDataVectorImpl { public: explicit FieldData(int64_t dim, DataType data_type, + bool nullable, int64_t buffered_num_rows = 0) - : FieldDataImpl(dim / 8, data_type, false, buffered_num_rows), + : FieldDataVectorImpl(dim / 8, data_type, nullable, buffered_num_rows), binary_dim_(dim) { Assert(dim % 8 == 0); } @@ -165,43 +394,48 @@ class FieldData : public FieldDataImpl { }; template <> -class FieldData : public FieldDataImpl { +class FieldData : public FieldDataVectorImpl { public: explicit FieldData(int64_t dim, DataType data_type, + bool nullable, int64_t buffered_num_rows = 0) - : FieldDataImpl::FieldDataImpl( - dim, data_type, false, buffered_num_rows) { + : FieldDataVectorImpl::FieldDataVectorImpl( + dim, data_type, nullable, buffered_num_rows) { } }; template <> -class FieldData : public FieldDataImpl { +class FieldData : public FieldDataVectorImpl { public: explicit FieldData(int64_t dim, DataType data_type, + bool nullable, int64_t buffered_num_rows = 0) - : FieldDataImpl::FieldDataImpl( - dim, data_type, false, buffered_num_rows) { + : FieldDataVectorImpl::FieldDataVectorImpl( + dim, data_type, nullable, buffered_num_rows) { } }; template <> class FieldData : public FieldDataSparseVectorImpl { public: - explicit FieldData(DataType data_type, int64_t buffered_num_rows = 0) - : FieldDataSparseVectorImpl(data_type, buffered_num_rows) { + explicit FieldData(DataType data_type, + bool nullable = false, + int64_t buffered_num_rows = 0) + : FieldDataSparseVectorImpl(data_type, nullable, buffered_num_rows) { } }; template <> -class FieldData : public FieldDataImpl { +class FieldData : public FieldDataVectorImpl { public: explicit FieldData(int64_t dim, DataType data_type, + bool nullable, int64_t buffered_num_rows = 0) - : FieldDataImpl::FieldDataImpl( - dim, data_type, false, buffered_num_rows) { + : FieldDataVectorImpl::FieldDataVectorImpl( + dim, data_type, nullable, buffered_num_rows) { } }; diff --git a/internal/core/src/common/FieldDataInterface.h b/internal/core/src/common/FieldDataInterface.h index 54886e7ab1..39deab12f9 100644 --- a/internal/core/src/common/FieldDataInterface.h +++ b/internal/core/src/common/FieldDataInterface.h @@ -127,6 +127,9 @@ class FieldDataBase { virtual bool is_valid(ssize_t offset) const = 0; + virtual int64_t + get_valid_rows() const = 0; + protected: const DataType data_type_; const bool nullable_; @@ -309,6 +312,12 @@ class FieldBitsetImpl : public FieldDataBase { "is_valid(ssize_t offset) not implemented for bitset"); } + int64_t + get_valid_rows() const override { + ThrowInfo(NotImplemented, + "get_valid_rows() not implemented for bitset"); + } + private: FixedVector data_{}; // capacity that data_ can store @@ -340,9 +349,6 @@ class FieldDataImpl : public FieldDataBase { dim_(is_type_entire_row ? 1 : dim) { data_.resize(num_rows_ * dim_); if (nullable) { - if (IsVectorDataType(data_type)) { - ThrowInfo(NotImplemented, "vector type not support null"); - } valid_data_.resize((num_rows_ + 7) / 8, 0xFF); } } @@ -492,6 +498,12 @@ class FieldDataImpl : public FieldDataBase { return num_rows_; } + int64_t + get_valid_rows() const override { + std::shared_lock lck(tell_mutex_); + return static_cast(length_) - null_count_; + } + void resize_field_data(int64_t num_rows) { std::lock_guard lck(num_rows_mutex_); @@ -540,13 +552,12 @@ class FieldDataImpl : public FieldDataBase { FixedVector valid_data_{}; // number of elements data_ can hold int64_t num_rows_; + size_t valid_count_{0}; mutable std::shared_mutex num_rows_mutex_; int64_t null_count_{0}; // number of actual elements in data_ size_t length_{}; mutable std::shared_mutex tell_mutex_; - - private: const ssize_t dim_; }; @@ -803,90 +814,6 @@ class FieldDataJsonImpl : public FieldDataImpl { } }; -class FieldDataSparseVectorImpl - : public FieldDataImpl, true> { - public: - explicit FieldDataSparseVectorImpl(DataType data_type, - int64_t total_num_rows = 0) - : FieldDataImpl, true>( - /*dim=*/1, data_type, false, total_num_rows), - vec_dim_(0) { - AssertInfo(data_type == DataType::VECTOR_SPARSE_U32_F32, - "invalid data type for sparse vector"); - } - - int64_t - DataSize() const override { - int64_t data_size = 0; - for (size_t i = 0; i < length(); ++i) { - data_size += data_[i].data_byte_size(); - } - return data_size; - } - - int64_t - DataSize(ssize_t offset) const override { - AssertInfo(offset < get_num_rows(), - "field data subscript out of range"); - AssertInfo(offset < length(), - "subscript position don't has valid value"); - return data_[offset].data_byte_size(); - } - - // source is a pointer to element_count of - // knowhere::sparse::SparseRow - void - FillFieldData(const void* source, ssize_t element_count) override { - if (element_count == 0) { - return; - } - - std::lock_guard lck(tell_mutex_); - if (length_ + element_count > get_num_rows()) { - resize_field_data(length_ + element_count); - } - auto ptr = - static_cast*>( - source); - for (int64_t i = 0; i < element_count; ++i) { - auto& row = ptr[i]; - vec_dim_ = std::max(vec_dim_, row.dim()); - } - std::copy_n(ptr, element_count, data_.data() + length_); - length_ += element_count; - } - - // each binary in array is a knowhere::sparse::SparseRow - void - FillFieldData(const std::shared_ptr& array) override { - auto n = array->length(); - if (n == 0) { - return; - } - - std::lock_guard lck(tell_mutex_); - if (length_ + n > get_num_rows()) { - resize_field_data(length_ + n); - } - - for (int64_t i = 0; i < array->length(); ++i) { - auto view = array->GetView(i); - auto& row = data_[length_ + i]; - row = CopyAndWrapSparseRow(view.data(), view.size()); - vec_dim_ = std::max(vec_dim_, row.dim()); - } - length_ += n; - } - - int64_t - Dim() const { - return vec_dim_; - } - - private: - int64_t vec_dim_ = 0; -}; - class FieldDataArrayImpl : public FieldDataImpl { public: explicit FieldDataArrayImpl(DataType data_type, diff --git a/internal/core/src/common/FieldMeta.cpp b/internal/core/src/common/FieldMeta.cpp index c1e96799a4..743859cccb 100644 --- a/internal/core/src/common/FieldMeta.cpp +++ b/internal/core/src/common/FieldMeta.cpp @@ -168,6 +168,8 @@ FieldMeta::ParseFrom(const milvus::proto::schema::FieldSchema& schema_proto) { } if (IsVectorDataType(data_type)) { + AssertInfo(!default_value.has_value(), + "vector fields do not support default values"); auto type_map = RepeatedKeyValToMap(schema_proto.type_params()); auto index_map = RepeatedKeyValToMap(schema_proto.index_params()); @@ -183,12 +185,17 @@ FieldMeta::ParseFrom(const milvus::proto::schema::FieldSchema& schema_proto) { data_type, dim, std::nullopt, - false, + nullable, default_value}; } auto metric_type = index_map.at("metric_type"); - return FieldMeta{ - name, field_id, data_type, dim, metric_type, false, default_value}; + return FieldMeta{name, + field_id, + data_type, + dim, + metric_type, + nullable, + default_value}; } if (IsStringDataType(data_type)) { diff --git a/internal/core/src/common/FieldMeta.h b/internal/core/src/common/FieldMeta.h index 62fa9fe7d5..0db6527b5e 100644 --- a/internal/core/src/common/FieldMeta.h +++ b/internal/core/src/common/FieldMeta.h @@ -125,7 +125,8 @@ class FieldMeta { vector_info_(VectorInfo{dim, std::move(metric_type)}), default_value_(std::move(default_value)) { Assert(IsVectorDataType(type_)); - Assert(!nullable); + Assert(!default_value_.has_value() && + "vector fields do not support default values"); } // array of vector type diff --git a/internal/core/src/common/OffsetMapping.cpp b/internal/core/src/common/OffsetMapping.cpp new file mode 100644 index 0000000000..46ece9a14a --- /dev/null +++ b/internal/core/src/common/OffsetMapping.cpp @@ -0,0 +1,194 @@ +#include "common/OffsetMapping.h" + +namespace milvus { + +void +OffsetMapping::Build(const bool* valid_data, + int64_t total_count, + int64_t start_logical, + int64_t start_physical) { + if (total_count == 0 || valid_data == nullptr) { + return; + } + + std::unique_lock lck(mutex_); + enabled_ = true; + total_count_ = start_logical + total_count; + + // Count valid elements first + int64_t valid_count = 0; + for (int64_t i = 0; i < total_count; ++i) { + if (valid_data[i]) { + valid_count++; + } + } + + // Auto-select storage mode: use map when valid ratio < 10% + use_map_ = (valid_count * 10 < total_count); + + if (use_map_) { + // Map mode: only store valid entries + int64_t physical_idx = start_physical; + for (int64_t i = 0; i < total_count; ++i) { + if (valid_data[i]) { + l2p_map_[start_logical + i] = physical_idx; + p2l_map_[physical_idx] = start_logical + i; + physical_idx++; + } + } + } else { + // Vec mode: store all entries + int64_t required_size = start_logical + total_count; + if (static_cast(l2p_vec_.size()) < required_size) { + l2p_vec_.resize(required_size, -1); + } + + int64_t physical_idx = start_physical; + for (int64_t i = 0; i < total_count; ++i) { + if (valid_data[i]) { + l2p_vec_[start_logical + i] = physical_idx; + if (physical_idx >= static_cast(p2l_vec_.size())) { + p2l_vec_.resize(physical_idx + 1, -1); + } + p2l_vec_[physical_idx] = start_logical + i; + physical_idx++; + } else { + l2p_vec_[start_logical + i] = -1; + } + } + } + + valid_count_ += valid_count; +} + +void +OffsetMapping::BuildIncremental(const bool* valid_data, + int64_t count, + int64_t start_logical, + int64_t start_physical) { + if (count == 0 || valid_data == nullptr) { + return; + } + + std::unique_lock lck(mutex_); + enabled_ = true; + total_count_ = start_logical + count; + + // Incremental builds always use vec mode + if (use_map_ && !l2p_map_.empty()) { + // Convert from map to vec if needed + int64_t max_logical = 0; + for (const auto& [logical, physical] : l2p_map_) { + if (logical > max_logical) { + max_logical = logical; + } + } + l2p_vec_.resize(max_logical + 1, -1); + for (const auto& [logical, physical] : l2p_map_) { + l2p_vec_[logical] = physical; + } + int64_t max_physical = 0; + for (const auto& [physical, logical] : p2l_map_) { + if (physical > max_physical) { + max_physical = physical; + } + } + p2l_vec_.resize(max_physical + 1, -1); + for (const auto& [physical, logical] : p2l_map_) { + p2l_vec_[physical] = logical; + } + l2p_map_.clear(); + p2l_map_.clear(); + use_map_ = false; + } + + // Resize l2p_vec if needed + int64_t required_size = start_logical + count; + if (static_cast(l2p_vec_.size()) < required_size) { + l2p_vec_.resize(required_size, -1); + } + + int64_t physical_idx = start_physical; + for (int64_t i = 0; i < count; ++i) { + if (valid_data[i]) { + l2p_vec_[start_logical + i] = physical_idx; + if (physical_idx >= static_cast(p2l_vec_.size())) { + p2l_vec_.resize(physical_idx + 1, -1); + } + p2l_vec_[physical_idx] = start_logical + i; + physical_idx++; + valid_count_++; + } else { + l2p_vec_[start_logical + i] = -1; + } + } +} + +int64_t +OffsetMapping::GetPhysicalOffset(int64_t logical_offset) const { + std::shared_lock lck(mutex_); + if (!enabled_) { + return logical_offset; + } + if (use_map_) { + auto it = l2p_map_.find(static_cast(logical_offset)); + if (it != l2p_map_.end()) { + return it->second; + } + return -1; + } + if (logical_offset < static_cast(l2p_vec_.size())) { + return l2p_vec_[logical_offset]; + } + return -1; +} + +int64_t +OffsetMapping::GetLogicalOffset(int64_t physical_offset) const { + std::shared_lock lck(mutex_); + if (!enabled_) { + return physical_offset; + } + if (use_map_) { + auto it = p2l_map_.find(static_cast(physical_offset)); + if (it != p2l_map_.end()) { + return it->second; + } + return -1; + } + if (physical_offset < static_cast(p2l_vec_.size())) { + return p2l_vec_[physical_offset]; + } + return -1; +} + +bool +OffsetMapping::IsValid(int64_t logical_offset) const { + return GetPhysicalOffset(logical_offset) >= 0; +} + +int64_t +OffsetMapping::GetValidCount() const { + std::shared_lock lck(mutex_); + return valid_count_; +} + +bool +OffsetMapping::IsEnabled() const { + std::shared_lock lck(mutex_); + return enabled_; +} + +int64_t +OffsetMapping::GetNextPhysicalOffset() const { + std::shared_lock lck(mutex_); + return valid_count_; +} + +int64_t +OffsetMapping::GetTotalCount() const { + std::shared_lock lck(mutex_); + return total_count_; +} + +} // namespace milvus diff --git a/internal/core/src/common/OffsetMapping.h b/internal/core/src/common/OffsetMapping.h new file mode 100644 index 0000000000..ed91af955a --- /dev/null +++ b/internal/core/src/common/OffsetMapping.h @@ -0,0 +1,96 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace milvus { + +// Bidirectional offset mapping for nullable vector storage +// Maps between logical offsets (with nulls) and physical offsets (only valid data) +// Supports two storage modes: +// - vec mode: uses vector for both L2P and P2L, efficient when valid ratio >= 10% +// - map mode: uses unordered_map for L2P, efficient when valid ratio < 10% +class OffsetMapping { + public: + OffsetMapping() = default; + + // Build mapping from valid_data (bool array format) + // If use_vec is not specified, auto-select based on valid ratio (< 10% uses map) + void + Build(const bool* valid_data, + int64_t total_count, + int64_t start_logical = 0, + int64_t start_physical = 0); + + // Build mapping incrementally (always uses vec mode for incremental builds) + void + BuildIncremental(const bool* valid_data, + int64_t count, + int64_t start_logical, + int64_t start_physical); + + // Get physical offset from logical offset. Returns -1 if null. + int64_t + GetPhysicalOffset(int64_t logical_offset) const; + + // Get logical offset from physical offset. Returns -1 if not found. + int64_t + GetLogicalOffset(int64_t physical_offset) const; + + // Check if a logical offset is valid (not null) + bool + IsValid(int64_t logical_offset) const; + + // Get count of valid (non-null) elements + int64_t + GetValidCount() const; + + // Check if mapping is enabled + bool + IsEnabled() const; + + // Get next physical offset (for incremental builds) + int64_t + GetNextPhysicalOffset() const; + + // Get total logical count (including nulls) + int64_t + GetTotalCount() const; + + private: + bool enabled_{false}; + bool use_map_{false}; // true: use map for L2P, false: use vec + + // Vec mode storage (uses int32_t to save memory) + std::vector l2p_vec_; // logical -> physical, -1 means null + std::vector p2l_vec_; // physical -> logical + + // Map mode storage (for sparse valid data) + std::unordered_map l2p_map_; // logical -> physical + std::unordered_map p2l_map_; // physical -> logical + + int64_t valid_count_{0}; + int64_t total_count_{0}; // total logical count (including nulls) + mutable std::shared_mutex mutex_; +}; + +} // namespace milvus diff --git a/internal/core/src/common/QueryResult.h b/internal/core/src/common/QueryResult.h index 0e49ce6f3f..96851ace81 100644 --- a/internal/core/src/common/QueryResult.h +++ b/internal/core/src/common/QueryResult.h @@ -29,6 +29,8 @@ #include "common/FieldMeta.h" #include "common/ArrayOffsets.h" +#include "common/OffsetMapping.h" +#include "query/Utils.h" #include "pb/schema.pb.h" #include "knowhere/index/index_node.h" @@ -156,9 +158,10 @@ class VectorIterator { class ChunkMergeIterator : public VectorIterator { public: ChunkMergeIterator(int chunk_count, + const milvus::OffsetMapping& offset_mapping, const std::vector& total_rows_until_chunk = {}, bool larger_is_closer = false) - : total_rows_until_chunk_(total_rows_until_chunk), + : offset_mapping_(&offset_mapping), larger_is_closer_(larger_is_closer), heap_(OffsetDisPairComparator(larger_is_closer)) { iterators_.reserve(chunk_count); @@ -180,7 +183,11 @@ class ChunkMergeIterator : public VectorIterator { origin_pair, top->GetIteratorIdx()); heap_.push(off_dis_pair); } - return top->GetOffDis(); + auto result = top->GetOffDis(); + if (offset_mapping_ != nullptr) { + result.first = offset_mapping_->GetLogicalOffset(result.first); + } + return result; } return std::nullopt; } @@ -231,6 +238,7 @@ class ChunkMergeIterator : public VectorIterator { OffsetDisPairComparator> heap_; bool sealed = false; + const milvus::OffsetMapping* offset_mapping_ = nullptr; std::vector total_rows_until_chunk_; bool larger_is_closer_ = false; //currently, ChunkMergeIterator is guaranteed to be used serially without concurrent problem, in the future @@ -258,6 +266,7 @@ struct SearchResult { int chunk_count, const std::vector& total_rows_until_chunk, const std::vector& kw_iterators, + const milvus::OffsetMapping& offset_mapping, bool larger_is_closer = false) { AssertInfo(kw_iterators.size() == nq * chunk_count, "kw_iterators count:{} is not equal to nq*chunk_count:{}, " @@ -269,8 +278,11 @@ struct SearchResult { for (int i = 0, vec_iter_idx = 0; i < kw_iterators.size(); i++) { vec_iter_idx = vec_iter_idx % nq; if (vector_iterators.size() < nq) { - auto chunk_merge_iter = std::make_shared( - chunk_count, total_rows_until_chunk, larger_is_closer); + auto chunk_merge_iter = + std::make_shared(chunk_count, + offset_mapping, + total_rows_until_chunk, + larger_is_closer); vector_iterators.emplace_back(chunk_merge_iter); } const auto& kw_iterator = kw_iterators[i]; diff --git a/internal/core/src/common/Schema.h b/internal/core/src/common/Schema.h index c3e1f0d962..33c8f2aa99 100644 --- a/internal/core/src/common/Schema.h +++ b/internal/core/src/common/Schema.h @@ -95,7 +95,8 @@ class Schema { AddDebugField(const std::string& name, DataType data_type, int64_t dim, - std::optional metric_type) { + std::optional metric_type, + bool nullable = false) { auto field_id = FieldId(debug_id); debug_id++; auto field_meta = FieldMeta(FieldName(name), @@ -103,7 +104,7 @@ class Schema { data_type, dim, metric_type, - false, + nullable, std::nullopt); this->AddField(std::move(field_meta)); return field_id; @@ -225,7 +226,7 @@ class Schema { std::optional metric_type, bool nullable) { auto field_meta = FieldMeta( - name, id, data_type, dim, metric_type, false, std::nullopt); + name, id, data_type, dim, metric_type, nullable, std::nullopt); this->AddField(std::move(field_meta)); } diff --git a/internal/core/src/common/Utils.h b/internal/core/src/common/Utils.h index 275618a0d1..a891740c3e 100644 --- a/internal/core/src/common/Utils.h +++ b/internal/core/src/common/Utils.h @@ -304,7 +304,9 @@ CopyAndWrapSparseRow(const void* data, template std::unique_ptr[]> SparseBytesToRows(const Iterable& rows, const bool validate = false) { - AssertInfo(rows.size() > 0, "at least 1 sparse row should be provided"); + if (rows.size() == 0) { + return nullptr; + } auto res = std::make_unique[]>( rows.size()); for (size_t i = 0; i < rows.size(); ++i) { diff --git a/internal/core/src/exec/operator/Utils.h b/internal/core/src/exec/operator/Utils.h index a840e2d180..30ba1606eb 100644 --- a/internal/core/src/exec/operator/Utils.h +++ b/internal/core/src/exec/operator/Utils.h @@ -57,7 +57,12 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info, bool larger_is_closer = PositivelyRelated(search_info.metric_type_); search_result.AssembleChunkVectorIterators( - nq, 1, {0}, iterators_val.value(), larger_is_closer); + nq, + 1, + {0}, + iterators_val.value(), + index.GetOffsetMapping(), + larger_is_closer); } else { std::string operator_type = ""; if (search_info.group_by_field_id_.has_value()) { diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index b2165d66bf..9f3cfabe37 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -120,6 +120,26 @@ VectorDiskAnnIndex::Load(milvus::tracer::TraceContext ctx, "failed to Deserialize index, " + KnowhereStatusString(stat)); span_load_engine->End(); + auto local_chunk_manager = + storage::LocalChunkManagerSingleton::GetInstance().GetChunkManager(); + auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix(); + + auto valid_data_path = local_index_path_prefix + "/" + VALID_DATA_KEY; + if (local_chunk_manager->Exist(valid_data_path)) { + size_t count; + local_chunk_manager->Read(valid_data_path, 0, &count, sizeof(size_t)); + size_t byte_size = (count + 7) / 8; + std::vector valid_bitmap(byte_size); + local_chunk_manager->Read( + valid_data_path, sizeof(size_t), valid_bitmap.data(), byte_size); + // Convert bitmap to bool array + std::unique_ptr valid_data(new bool[count]); + for (size_t i = 0; i < count; ++i) { + valid_data[i] = (valid_bitmap[i / 8] >> (i % 8)) & 1; + } + BuildValidData(valid_data.get(), count); + } + SetDim(index_.Dim()); } @@ -298,6 +318,23 @@ VectorDiskAnnIndex::BuildWithDataset(const DatasetPtr& dataset, if (stat != knowhere::Status::success) ThrowInfo(ErrorCode::IndexBuildError, "failed to build index, " + KnowhereStatusString(stat)); + + if (HasValidData()) { + auto valid_data_path = local_index_path_prefix + "/" + VALID_DATA_KEY; + size_t count = offset_mapping_.GetTotalCount(); + local_chunk_manager->Write(valid_data_path, 0, &count, sizeof(size_t)); + size_t byte_size = (count + 7) / 8; + std::vector packed_data(byte_size, 0); + for (size_t i = 0; i < count; ++i) { + if (offset_mapping_.IsValid(i)) { + packed_data[i / 8] |= (1 << (i % 8)); + } + } + local_chunk_manager->Write( + valid_data_path, sizeof(size_t), packed_data.data(), byte_size); + file_manager_->AddFile(valid_data_path); + } + local_chunk_manager->RemoveDir( storage::GetSegmentRawDataPathPrefix(local_chunk_manager, segment_id)); diff --git a/internal/core/src/index/VectorIndex.h b/internal/core/src/index/VectorIndex.h index 95e7cdb6af..5019f70698 100644 --- a/internal/core/src/index/VectorIndex.h +++ b/internal/core/src/index/VectorIndex.h @@ -27,6 +27,7 @@ #include "index/Index.h" #include "common/Types.h" #include "common/BitsetView.h" +#include "common/OffsetMapping.h" #include "common/QueryResult.h" #include "common/QueryInfo.h" #include "common/OpContext.h" @@ -34,6 +35,10 @@ namespace milvus::index { +// valid data keys for nullable vector index serialization +constexpr const char* VALID_DATA_KEY = "valid_data"; +constexpr const char* VALID_DATA_COUNT_KEY = "valid_data_count"; + class VectorIndex : public IndexBase { public: explicit VectorIndex(const IndexType& index_type, @@ -145,6 +150,56 @@ class VectorIndex : public IndexBase { return search_cfg; } + void + UpdateValidData(const bool* valid_data, int64_t count) { + offset_mapping_.BuildIncremental( + valid_data, + count, + offset_mapping_.GetTotalCount(), + offset_mapping_.GetNextPhysicalOffset()); + } + + void + BuildValidData(const bool* valid_data, int64_t total_count) { + offset_mapping_.Build(valid_data, total_count); + } + + bool + IsRowValid(int64_t logical_offset) const { + if (!offset_mapping_.IsEnabled()) { + return true; + } + return offset_mapping_.IsValid(logical_offset); + } + + bool + HasValidData() const { + return offset_mapping_.IsEnabled(); + } + + int64_t + GetValidCount() const { + return offset_mapping_.GetValidCount(); + } + + int64_t + GetPhysicalOffset(int64_t logical_offset) const { + return offset_mapping_.GetPhysicalOffset(logical_offset); + } + + int64_t + GetLogicalOffset(int64_t physical_offset) const { + return offset_mapping_.GetLogicalOffset(physical_offset); + } + + const milvus::OffsetMapping& + GetOffsetMapping() const { + return offset_mapping_; + } + + protected: + milvus::OffsetMapping offset_mapping_; + private: MetricType metric_type_; int64_t dim_; diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index c5822bc91c..2f00cc2e4b 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -146,6 +146,27 @@ VectorMemIndex::Serialize(const Config& config) { ThrowInfo(ErrorCode::UnexpectedError, "failed to serialize index: {}", KnowhereStatusString(stat)); + + // Serialize valid_data from offset_mapping if enabled + if (offset_mapping_.IsEnabled()) { + auto total_count = offset_mapping_.GetTotalCount(); + + std::shared_ptr count_buf(new uint8_t[sizeof(size_t)]); + size_t count = static_cast(total_count); + std::memcpy(count_buf.get(), &count, sizeof(size_t)); + ret.Append(VALID_DATA_COUNT_KEY, count_buf, sizeof(size_t)); + + size_t byte_size = (count + 7) / 8; + std::shared_ptr data(new uint8_t[byte_size]); + std::memset(data.get(), 0, byte_size); + for (size_t i = 0; i < count; ++i) { + if (offset_mapping_.IsValid(i)) { + data[i / 8] |= (1 << (i % 8)); + } + } + ret.Append(VALID_DATA_KEY, data, byte_size); + } + Disassemble(ret); return ret; @@ -160,6 +181,25 @@ VectorMemIndex::LoadWithoutAssemble(const BinarySet& binary_set, ThrowInfo(ErrorCode::UnexpectedError, "failed to Deserialize index: {}", KnowhereStatusString(stat)); + + // Deserialize valid_data bitmap and rebuild offset_mapping + if (binary_set.Contains(VALID_DATA_COUNT_KEY) && + binary_set.Contains(VALID_DATA_KEY)) { + knowhere::BinaryPtr ptr; + ptr = binary_set.GetByName(VALID_DATA_COUNT_KEY); + size_t count; + std::memcpy(&count, ptr->data.get(), sizeof(size_t)); + + ptr = binary_set.GetByName(VALID_DATA_KEY); + // Convert bitmap to bool array + std::unique_ptr valid_data(new bool[count]); + auto bitmap = ptr->data.get(); + for (size_t i = 0; i < count; ++i) { + valid_data[i] = (bitmap[i / 8] >> (i % 8)) & 1; + } + BuildValidData(valid_data.get(), count); + } + SetDim(index_.Dim()); } @@ -339,19 +379,48 @@ VectorMemIndex::Build(const Config& config) { build_config.update(config); build_config.erase(INSERT_FILES_KEY); build_config.erase(VEC_OPT_FIELDS); - if (!IndexIsSparse(GetIndexType())) { - int64_t total_size = 0; - int64_t total_num_rows = 0; - int64_t dim = 0; - for (auto data : field_datas) { - total_size += data->Size(); - total_num_rows += data->get_num_rows(); + bool nullable = false; + int64_t total_valid_rows = 0; + int64_t total_num_rows = 0; + for (auto data : field_datas) { + auto num_rows = data->get_num_rows(); + auto valid_rows = data->get_valid_rows(); + total_valid_rows += valid_rows; + total_num_rows += num_rows; + if (data->IsNullable()) { + nullable = true; + } + } + std::unique_ptr valid_data; + if (nullable) { + valid_data.reset(new bool[total_num_rows]); + int64_t chunk_offset = 0; + for (auto data : field_datas) { + auto rows = data->get_num_rows(); + // Copy valid data from FieldData (bitmap format to bool array) + auto src_bitmap = data->ValidData(); + for (int64_t i = 0; i < rows; ++i) { + valid_data[chunk_offset + i] = + (src_bitmap[i >> 3] >> (i & 7)) & 1; + } + chunk_offset += rows; + } + } + + if (!IndexIsSparse(GetIndexType())) { + int64_t dim = 0; + int64_t total_size = 0; + for (auto data : field_datas) { AssertInfo(dim == 0 || dim == data->get_dim(), "inconsistent dim value between field datas!"); dim = data->get_dim(); + if (elem_type_ == DataType::NONE) { + total_size += data->DataSize(); + } else { + total_size += data->Size(); + } } - auto buf = std::shared_ptr(new uint8_t[total_size]); size_t lim_offset = 0; @@ -362,8 +431,9 @@ VectorMemIndex::Build(const Config& config) { if (elem_type_ == DataType::NONE) { // TODO: avoid copying for (auto data : field_datas) { - std::memcpy(buf.get() + offset, data->Data(), data->Size()); - offset += data->Size(); + auto valid_size = data->DataSize(); + std::memcpy(buf.get() + offset, data->Data(), valid_size); + offset += valid_size; data.reset(); } } else { @@ -396,12 +466,12 @@ VectorMemIndex::Build(const Config& config) { data.reset(); } - total_num_rows = lim_offset; + total_valid_rows = lim_offset; } field_datas.clear(); - auto dataset = GenDataset(total_num_rows, dim, buf.get()); + auto dataset = GenDataset(total_valid_rows, dim, buf.get()); if (!scalar_info.empty()) { dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info)); } @@ -410,12 +480,13 @@ VectorMemIndex::Build(const Config& config) { const_cast(offsets.data())); } BuildWithDataset(dataset, build_config); + if (nullable) { + BuildValidData(valid_data.get(), total_num_rows); + } } else { // sparse - int64_t total_rows = 0; int64_t dim = 0; for (auto field_data : field_datas) { - total_rows += field_data->Length(); dim = std::max( dim, std::dynamic_pointer_cast>( @@ -423,28 +494,31 @@ VectorMemIndex::Build(const Config& config) { ->Dim()); } std::vector> vec( - total_rows); + total_valid_rows); int64_t offset = 0; for (auto field_data : field_datas) { auto ptr = static_cast< const knowhere::sparse::SparseRow*>( field_data->Data()); AssertInfo(ptr, "failed to cast field data to sparse rows"); - for (size_t i = 0; i < field_data->Length(); ++i) { + for (size_t i = 0; i < field_data->get_valid_rows(); ++i) { // this does a deep copy of field_data's data. // TODO: avoid copying by enforcing field data to give up // ownership. - AssertInfo(dim >= ptr[i].dim(), "bad dim"); + dim = std::max(dim, static_cast(ptr[i].dim())); vec[offset + i] = ptr[i]; } - offset += field_data->Length(); + offset += field_data->get_valid_rows(); } - auto dataset = GenDataset(total_rows, dim, vec.data()); + auto dataset = GenDataset(total_valid_rows, dim, vec.data()); dataset->SetIsSparse(true); if (!scalar_info.empty()) { dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info)); } BuildWithDataset(dataset, build_config); + if (nullable) { + BuildValidData(valid_data.get(), total_num_rows); + } } } @@ -572,6 +646,10 @@ VectorMemIndex::GetVector(const DatasetPtr dataset) const { template std::unique_ptr[]> VectorMemIndex::GetSparseVector(const DatasetPtr dataset) const { + if (dataset->GetRows() == 0) { + return nullptr; + } + auto res = index_.GetVectorByIds(dataset); if (!res.has_value()) { ThrowInfo(ErrorCode::UnexpectedError, @@ -646,6 +724,8 @@ void VectorMemIndex::LoadFromFile(const Config& config) { LOG_INFO("load with slice meta: {}", !slice_meta_filepath.empty()); std::chrono::duration load_duration_sum; std::chrono::duration write_disk_duration_sum; + std::unique_ptr valid_data_count_codec; + std::unique_ptr valid_data_codec; // load files in two parts: // 1. EMB_LIST_META: Written separately to embedding_list_meta_writer_ptr (if embedding list type) // 2. All other binaries: Merged and written to file_writer, forming a unified index file for knowhere @@ -683,6 +763,10 @@ void VectorMemIndex::LoadFromFile(const Config& config) { embedding_list_meta_writer_ptr) { embedding_list_meta_writer_ptr->Write( data->PayloadData(), data->PayloadSize()); + } else if (prefix == VALID_DATA_COUNT_KEY) { + valid_data_count_codec = std::move(data); + } else if (prefix == VALID_DATA_KEY) { + valid_data_codec = std::move(data); } else { file_writer.Write(data->PayloadData(), data->PayloadSize()); @@ -724,6 +808,10 @@ void VectorMemIndex::LoadFromFile(const Config& config) { embedding_list_meta_writer_ptr) { embedding_list_meta_writer_ptr->Write( index_data->PayloadData(), index_data->PayloadSize()); + } else if (prefix == VALID_DATA_COUNT_KEY) { + valid_data_count_codec = std::move(index_data); + } else if (prefix == VALID_DATA_KEY) { + valid_data_codec = std::move(index_data); } else { file_writer.Write(index_data->PayloadData(), index_data->PayloadSize()); @@ -768,6 +856,20 @@ void VectorMemIndex::LoadFromFile(const Config& config) { auto dim = index_.Dim(); this->SetDim(index_.Dim()); + // Restore valid_data for nullable vector support + if (valid_data_count_codec && valid_data_codec) { + size_t count; + std::memcpy( + &count, valid_data_count_codec->PayloadData(), sizeof(size_t)); + + std::unique_ptr valid_data(new bool[count]); + auto bitmap = valid_data_codec->PayloadData(); + for (size_t i = 0; i < count; ++i) { + valid_data[i] = (bitmap[i / 8] >> (i % 8)) & 1; + } + BuildValidData(valid_data.get(), count); + } + this->mmap_file_raii_ = std::make_unique(local_filepath.value()); LOG_INFO( diff --git a/internal/core/src/indexbuilder/IndexCreatorBase.h b/internal/core/src/indexbuilder/IndexCreatorBase.h index cfe74095eb..4476c65cfc 100644 --- a/internal/core/src/indexbuilder/IndexCreatorBase.h +++ b/internal/core/src/indexbuilder/IndexCreatorBase.h @@ -22,7 +22,9 @@ class IndexCreatorBase { virtual ~IndexCreatorBase() = default; virtual void - Build(const milvus::DatasetPtr& dataset) = 0; + Build(const milvus::DatasetPtr& dataset, + const bool* valid_data = nullptr, + const int64_t valid_data_len = 0) = 0; virtual void Build() = 0; diff --git a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp index 553c203185..43778077a8 100644 --- a/internal/core/src/indexbuilder/ScalarIndexCreator.cpp +++ b/internal/core/src/indexbuilder/ScalarIndexCreator.cpp @@ -83,7 +83,11 @@ ScalarIndexCreator::ScalarIndexCreator( } void -ScalarIndexCreator::Build(const milvus::DatasetPtr& dataset) { +ScalarIndexCreator::Build(const milvus::DatasetPtr& dataset, + const bool* valid_data, + const int64_t valid_data_len) { + (void)valid_data; + (void)valid_data_len; auto size = dataset->GetRows(); auto data = dataset->GetTensor(); index_->BuildWithRawDataForUT(size, data); diff --git a/internal/core/src/indexbuilder/ScalarIndexCreator.h b/internal/core/src/indexbuilder/ScalarIndexCreator.h index 3d32fc78e3..12513b9af4 100644 --- a/internal/core/src/indexbuilder/ScalarIndexCreator.h +++ b/internal/core/src/indexbuilder/ScalarIndexCreator.h @@ -27,7 +27,9 @@ class ScalarIndexCreator : public IndexCreatorBase { const storage::FileManagerContext& file_manager_context); void - Build(const milvus::DatasetPtr& dataset) override; + Build(const milvus::DatasetPtr& dataset, + const bool* valid_data = nullptr, + const int64_t valid_data_len = 0) override; void Build() override; diff --git a/internal/core/src/indexbuilder/VecIndexCreator.cpp b/internal/core/src/indexbuilder/VecIndexCreator.cpp index 79ce9049dd..d3aa11622d 100644 --- a/internal/core/src/indexbuilder/VecIndexCreator.cpp +++ b/internal/core/src/indexbuilder/VecIndexCreator.cpp @@ -65,8 +65,15 @@ VecIndexCreator::dim() { } void -VecIndexCreator::Build(const milvus::DatasetPtr& dataset) { +VecIndexCreator::Build(const milvus::DatasetPtr& dataset, + const bool* valid_data, + const int64_t valid_data_len) { index_->BuildWithDataset(dataset, config_); + if (valid_data && valid_data_len > 0) { + auto vec_index = dynamic_cast(index_.get()); + AssertInfo(vec_index != nullptr, "failed to cast index to VectorIndex"); + vec_index->BuildValidData(valid_data, valid_data_len); + } } void diff --git a/internal/core/src/indexbuilder/VecIndexCreator.h b/internal/core/src/indexbuilder/VecIndexCreator.h index 6f89c273a3..b349133ea4 100644 --- a/internal/core/src/indexbuilder/VecIndexCreator.h +++ b/internal/core/src/indexbuilder/VecIndexCreator.h @@ -39,7 +39,9 @@ class VecIndexCreator : public IndexCreatorBase { const storage::FileManagerContext& file_manager_context); void - Build(const milvus::DatasetPtr& dataset) override; + Build(const milvus::DatasetPtr& dataset, + const bool* valid_data = nullptr, + const int64_t valid_data_len = 0) override; void Build() override; diff --git a/internal/core/src/indexbuilder/index_c.cpp b/internal/core/src/indexbuilder/index_c.cpp index 6739f8fbf9..8bbe3bfac0 100644 --- a/internal/core/src/indexbuilder/index_c.cpp +++ b/internal/core/src/indexbuilder/index_c.cpp @@ -564,6 +564,35 @@ BuildFloatVecIndex(CIndex index, return status; } +CStatus +BuildFloatVecIndexWithValidData(CIndex index, + int64_t float_value_num, + const float* vectors, + const bool* valid_data, + int64_t valid_data_len) { + SCOPE_CGO_CALL_METRIC(); + + auto status = CStatus(); + try { + AssertInfo(index, + "failed to build float vector index, passed index was null"); + auto real_index = + reinterpret_cast(index); + auto cIndex = + dynamic_cast(real_index); + auto dim = cIndex->dim(); + auto row_nums = float_value_num / dim; + auto ds = knowhere::GenDataSet(row_nums, dim, vectors); + cIndex->Build(ds, valid_data, valid_data_len); + status.error_code = Success; + status.error_msg = ""; + } catch (std::exception& e) { + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + } + return status; +} + CStatus BuildFloat16VecIndex(CIndex index, int64_t float16_value_num, @@ -592,6 +621,36 @@ BuildFloat16VecIndex(CIndex index, return status; } +CStatus +BuildFloat16VecIndexWithValidData(CIndex index, + int64_t float16_value_num, + const uint8_t* vectors, + const bool* valid_data, + int64_t valid_data_len) { + SCOPE_CGO_CALL_METRIC(); + + auto status = CStatus(); + try { + AssertInfo( + index, + "failed to build float16 vector index, passed index was null"); + auto real_index = + reinterpret_cast(index); + auto cIndex = + dynamic_cast(real_index); + auto dim = cIndex->dim(); + auto row_nums = float16_value_num / dim / 2; + auto ds = knowhere::GenDataSet(row_nums, dim, vectors); + cIndex->Build(ds, valid_data, valid_data_len); + status.error_code = Success; + status.error_msg = ""; + } catch (std::exception& e) { + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + } + return status; +} + CStatus BuildBFloat16VecIndex(CIndex index, int64_t bfloat16_value_num, @@ -620,6 +679,36 @@ BuildBFloat16VecIndex(CIndex index, return status; } +CStatus +BuildBFloat16VecIndexWithValidData(CIndex index, + int64_t bfloat16_value_num, + const uint8_t* vectors, + const bool* valid_data, + int64_t valid_data_len) { + SCOPE_CGO_CALL_METRIC(); + + auto status = CStatus(); + try { + AssertInfo( + index, + "failed to build bfloat16 vector index, passed index was null"); + auto real_index = + reinterpret_cast(index); + auto cIndex = + dynamic_cast(real_index); + auto dim = cIndex->dim(); + auto row_nums = bfloat16_value_num / dim / 2; + auto ds = knowhere::GenDataSet(row_nums, dim, vectors); + cIndex->Build(ds, valid_data, valid_data_len); + status.error_code = Success; + status.error_msg = ""; + } catch (std::exception& e) { + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + } + return status; +} + CStatus BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors) { SCOPE_CGO_CALL_METRIC(); @@ -646,6 +735,36 @@ BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors) { return status; } +CStatus +BuildBinaryVecIndexWithValidData(CIndex index, + int64_t data_size, + const uint8_t* vectors, + const bool* valid_data, + int64_t valid_data_len) { + SCOPE_CGO_CALL_METRIC(); + + auto status = CStatus(); + try { + AssertInfo( + index, + "failed to build binary vector index, passed index was null"); + auto real_index = + reinterpret_cast(index); + auto cIndex = + dynamic_cast(real_index); + auto dim = cIndex->dim(); + auto row_nums = (data_size * 8) / dim; + auto ds = knowhere::GenDataSet(row_nums, dim, vectors); + cIndex->Build(ds, valid_data, valid_data_len); + status.error_code = Success; + status.error_msg = ""; + } catch (std::exception& e) { + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + } + return status; +} + CStatus BuildSparseFloatVecIndex(CIndex index, int64_t row_num, @@ -674,6 +793,36 @@ BuildSparseFloatVecIndex(CIndex index, return status; } +CStatus +BuildSparseFloatVecIndexWithValidData(CIndex index, + int64_t row_num, + int64_t dim, + const uint8_t* vectors, + const bool* valid_data, + int64_t valid_data_len) { + SCOPE_CGO_CALL_METRIC(); + + auto status = CStatus(); + try { + AssertInfo( + index, + "failed to build sparse float vector index, passed index was null"); + auto real_index = + reinterpret_cast(index); + auto cIndex = + dynamic_cast(real_index); + auto ds = knowhere::GenDataSet(row_num, dim, vectors); + ds->SetIsSparse(true); + cIndex->Build(ds, valid_data, valid_data_len); + status.error_code = Success; + status.error_msg = ""; + } catch (std::exception& e) { + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + } + return status; +} + CStatus BuildInt8VecIndex(CIndex index, int64_t int8_value_num, const int8_t* vectors) { SCOPE_CGO_CALL_METRIC(); @@ -699,6 +848,35 @@ BuildInt8VecIndex(CIndex index, int64_t int8_value_num, const int8_t* vectors) { return status; } +CStatus +BuildInt8VecIndexWithValidData(CIndex index, + int64_t int8_value_num, + const int8_t* vectors, + const bool* valid_data, + int64_t valid_data_len) { + SCOPE_CGO_CALL_METRIC(); + + auto status = CStatus(); + try { + AssertInfo(index, + "failed to build int8 vector index, passed index was null"); + auto real_index = + reinterpret_cast(index); + auto cIndex = + dynamic_cast(real_index); + auto dim = cIndex->dim(); + auto row_nums = int8_value_num / dim; + auto ds = knowhere::GenDataSet(row_nums, dim, vectors); + cIndex->Build(ds, valid_data, valid_data_len); + status.error_code = Success; + status.error_msg = ""; + } catch (std::exception& e) { + status.error_code = UnexpectedError; + status.error_msg = strdup(e.what()); + } + return status; +} + // field_data: // 1, serialized proto::schema::BoolArray, if type is bool; // 2, serialized proto::schema::StringArray, if type is string; diff --git a/internal/core/src/indexbuilder/index_c.h b/internal/core/src/indexbuilder/index_c.h index aefe7b5085..ff1f47e0cc 100644 --- a/internal/core/src/indexbuilder/index_c.h +++ b/internal/core/src/indexbuilder/index_c.h @@ -55,24 +55,67 @@ CreateIndexForUT(enum CDataType dtype, CStatus BuildFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors); +CStatus +BuildFloatVecIndexWithValidData(CIndex index, + int64_t float_value_num, + const float* vectors, + const bool* valid_data, + int64_t valid_data_len); + CStatus BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors); +CStatus +BuildBinaryVecIndexWithValidData(CIndex index, + int64_t data_size, + const uint8_t* vectors, + const bool* valid_data, + int64_t valid_data_len); + CStatus BuildFloat16VecIndex(CIndex index, int64_t data_size, const uint8_t* vectors); +CStatus +BuildFloat16VecIndexWithValidData(CIndex index, + int64_t data_size, + const uint8_t* vectors, + const bool* valid_data, + int64_t valid_data_len); + CStatus BuildBFloat16VecIndex(CIndex index, int64_t data_size, const uint8_t* vectors); +CStatus +BuildBFloat16VecIndexWithValidData(CIndex index, + int64_t data_size, + const uint8_t* vectors, + const bool* valid_data, + int64_t valid_data_len); + CStatus BuildSparseFloatVecIndex(CIndex index, int64_t row_num, int64_t dim, const uint8_t* vectors); +CStatus +BuildSparseFloatVecIndexWithValidData(CIndex index, + int64_t row_num, + int64_t dim, + const uint8_t* vectors, + const bool* valid_data, + int64_t valid_data_len); + CStatus BuildInt8VecIndex(CIndex index, int64_t data_size, const int8_t* vectors); +CStatus +BuildInt8VecIndexWithValidData(CIndex index, + int64_t data_size, + const int8_t* vectors, + const bool* valid_data, + int64_t valid_data_len); + // field_data: // 1, serialized proto::schema::BoolArray, if type is bool; // 2, serialized proto::schema::StringArray, if type is string; diff --git a/internal/core/src/mmap/ChunkedColumn.h b/internal/core/src/mmap/ChunkedColumn.h index 17a8dd0c59..491e168552 100644 --- a/internal/core/src/mmap/ChunkedColumn.h +++ b/internal/core/src/mmap/ChunkedColumn.h @@ -368,6 +368,33 @@ class ChunkedColumnBase : public ChunkedColumnInterface { return meta->num_rows_until_chunk_; } + void + BuildValidRowIds(milvus::OpContext* op_ctx) override { + if (!nullable_) { + return; + } + auto ca = SemiInlineGet(slot_->PinAllCells(op_ctx)); + int64_t logical_offset = 0; + valid_data_.resize(num_rows_); + valid_count_per_chunk_.resize(num_chunks_); + for (size_t i = 0; i < num_chunks_; i++) { + auto chunk = ca->get_cell_of(i); + auto rows = chunk_row_nums(i); + int64_t valid_count = 0; + for (int64_t j = 0; j < rows; j++) { + if (chunk->isValid(j)) { + valid_data_[logical_offset + j] = true; + valid_count++; + } else { + valid_data_[logical_offset + j] = false; + } + } + valid_count_per_chunk_[i] = valid_count; + logical_offset += rows; + } + BuildOffsetMapping(); + } + protected: bool nullable_{false}; DataType data_type_{DataType::NONE}; diff --git a/internal/core/src/mmap/ChunkedColumnGroup.h b/internal/core/src/mmap/ChunkedColumnGroup.h index 08c56498d2..0586ae1ffc 100644 --- a/internal/core/src/mmap/ChunkedColumnGroup.h +++ b/internal/core/src/mmap/ChunkedColumnGroup.h @@ -667,6 +667,36 @@ class ProxyChunkColumn : public ChunkedColumnInterface { } } + void + BuildValidRowIds(milvus::OpContext* op_ctx) override { + if (!field_meta_.is_nullable()) { + return; + } + auto total_rows = NumRows(); + auto total_chunks = num_chunks(); + valid_data_.resize(total_rows); + valid_count_per_chunk_.resize(total_chunks); + + int64_t logical_offset = 0; + for (int64_t i = 0; i < total_chunks; i++) { + auto group_chunk = group_->GetGroupChunk(op_ctx, i); + auto chunk = group_chunk.get()->GetChunk(field_id_); + auto rows = chunk->RowNums(); + int64_t valid_count = 0; + for (int64_t j = 0; j < rows; j++) { + if (chunk->isValid(j)) { + valid_data_[logical_offset + j] = true; + valid_count++; + } else { + valid_data_[logical_offset + j] = false; + } + } + valid_count_per_chunk_[i] = valid_count; + logical_offset += rows; + } + BuildOffsetMapping(); + } + private: std::shared_ptr group_; FieldId field_id_; diff --git a/internal/core/src/mmap/ChunkedColumnInterface.h b/internal/core/src/mmap/ChunkedColumnInterface.h index f9c79d4b70..1ebc19ca5f 100644 --- a/internal/core/src/mmap/ChunkedColumnInterface.h +++ b/internal/core/src/mmap/ChunkedColumnInterface.h @@ -17,6 +17,7 @@ #include "cachinglayer/CacheSlot.h" #include "common/Chunk.h" +#include "common/OffsetMapping.h" #include "common/bson_view.h" namespace milvus { @@ -131,6 +132,35 @@ class ChunkedColumnInterface { virtual const std::vector& GetNumRowsUntilChunk() const = 0; + const FixedVector& + GetValidData() const { + return valid_data_; + } + + const std::vector& + GetValidCountPerChunk() const { + return valid_count_per_chunk_; + } + + const OffsetMapping& + GetOffsetMapping() const { + return offset_mapping_; + } + + virtual void + BuildValidRowIds(milvus::OpContext* op_ctx) { + ThrowInfo(ErrorCode::Unsupported, + "BuildValidRowIds not supported for this column type"); + } + + // Build offset mapping from valid_data + void + BuildOffsetMapping() { + if (!valid_data_.empty()) { + offset_mapping_.Build(valid_data_.data(), valid_data_.size()); + } + } + virtual void BulkValueAt(milvus::OpContext* op_ctx, std::function fn, @@ -237,6 +267,10 @@ class ChunkedColumnInterface { } protected: + FixedVector valid_data_; + std::vector valid_count_per_chunk_; + OffsetMapping offset_mapping_; + std::pair, std::vector> ToChunkIdAndOffset(const int64_t* offsets, int64_t count) const { AssertInfo(offsets != nullptr, "Offsets cannot be nullptr"); diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index b306542c5a..f797b3dba5 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -22,6 +22,7 @@ #include "query/CachedSearchIterator.h" #include "query/SearchBruteForce.h" #include "query/SearchOnIndex.h" +#include "query/Utils.h" #include "exec/operator/Utils.h" namespace milvus::query { @@ -82,8 +83,6 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, SearchResult& search_result) { auto& schema = segment.get_schema(); auto& record = segment.get_insert_record(); - auto active_row_count = - std::min(int64_t(bitset.size()), segment.get_active_count(timestamp)); // step 1.1: get meta // step 1.2: get which vector field to search @@ -155,6 +154,19 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, // step 3: brute force search where small indexing is unavailable auto vec_ptr = record.get_data_base(vecfield_id); + const auto& offset_mapping = vec_ptr->get_offset_mapping(); + + TargetBitmap transformed_bitset; + BitsetView search_bitset = bitset; + if (offset_mapping.IsEnabled()) { + transformed_bitset = TransformBitset(bitset, offset_mapping); + search_bitset = BitsetView(transformed_bitset); + } + + auto active_count = offset_mapping.IsEnabled() + ? offset_mapping.GetValidCount() + : std::min(int64_t(bitset.size()), + segment.get_active_count(timestamp)); if (info.iterator_v2_info_.has_value()) { AssertInfo(data_type != DataType::VECTOR_ARRAY, @@ -163,17 +175,20 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, CachedSearchIterator cached_iter(search_dataset, vec_ptr, - active_row_count, + active_count, info, index_info, - bitset, + search_bitset, data_type); cached_iter.NextBatch(info, search_result); + if (offset_mapping.IsEnabled()) { + TransformOffset(search_result.seg_offsets_, offset_mapping); + } return; } auto vec_size_per_chunk = vec_ptr->get_size_per_chunk(); - auto max_chunk = upper_div(active_row_count, vec_size_per_chunk); + auto max_chunk = upper_div(active_count, vec_size_per_chunk); // embedding search embedding on embedding list bool embedding_search = false; @@ -188,7 +203,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, auto row_begin = chunk_id * vec_size_per_chunk; auto row_end = - std::min(active_row_count, (chunk_id + 1) * vec_size_per_chunk); + std::min(active_count, (chunk_id + 1) * vec_size_per_chunk); auto size_per_chunk = row_end - row_begin; query::dataset::RawDataset sub_data; @@ -260,7 +275,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, sub_data, info, index_info, - bitset, + search_bitset, vector_type); final_qr.merge(sub_qr); } else { @@ -268,7 +283,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, sub_data, info, index_info, - bitset, + search_bitset, vector_type, element_type, op_context); @@ -286,6 +301,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, max_chunk, chunk_rows, final_qr.chunk_iterators(), + offset_mapping, larger_is_closer); } else { if (info.array_offsets_ != nullptr) { @@ -300,6 +316,9 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, std::move(final_qr.mutable_offsets()); } search_result.distances_ = std::move(final_qr.mutable_distances()); + if (offset_mapping.IsEnabled()) { + TransformOffset(search_result.seg_offsets_, offset_mapping); + } } search_result.unity_topK_ = topk; search_result.total_nq_ = num_queries; diff --git a/internal/core/src/query/SearchOnIndex.cpp b/internal/core/src/query/SearchOnIndex.cpp index fe70ef292b..4b9c675ac3 100644 --- a/internal/core/src/query/SearchOnIndex.cpp +++ b/internal/core/src/query/SearchOnIndex.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #include "SearchOnIndex.h" +#include "Utils.h" #include "exec/operator/Utils.h" #include "CachedSearchIterator.h" @@ -28,23 +29,39 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset, auto dataset = knowhere::GenDataSet(num_queries, dim, search_dataset.query_data); dataset->SetIsSparse(is_sparse); + + const auto& offset_mapping = indexing.GetOffsetMapping(); + TargetBitmap transformed_bitset; + BitsetView search_bitset = bitset; + if (offset_mapping.IsEnabled()) { + transformed_bitset = TransformBitset(bitset, offset_mapping); + search_bitset = BitsetView(transformed_bitset); + } + if (milvus::exec::PrepareVectorIteratorsFromIndex(search_conf, num_queries, dataset, search_result, - bitset, + search_bitset, indexing)) { return; } if (search_conf.iterator_v2_info_.has_value()) { auto iter = - CachedSearchIterator(indexing, dataset, search_conf, bitset); + CachedSearchIterator(indexing, dataset, search_conf, search_bitset); iter.NextBatch(search_conf, search_result); + if (offset_mapping.IsEnabled()) { + TransformOffset(search_result.seg_offsets_, offset_mapping); + } return; } - indexing.Query(dataset, search_conf, bitset, op_context, search_result); + indexing.Query( + dataset, search_conf, search_bitset, op_context, search_result); + if (offset_mapping.IsEnabled()) { + TransformOffset(search_result.seg_offsets_, offset_mapping); + } } } // namespace milvus::query diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index cfde506887..c12215acdc 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -16,12 +16,14 @@ #include "bitset/detail/element_wise.h" #include "cachinglayer/Utils.h" #include "common/BitsetView.h" +#include "common/Consts.h" #include "common/QueryInfo.h" #include "common/Types.h" #include "common/Utils.h" #include "query/CachedSearchIterator.h" #include "query/SearchBruteForce.h" #include "query/SearchOnSealed.h" +#include "query/Utils.h" #include "query/helper.h" #include "exec/operator/Utils.h" @@ -73,21 +75,40 @@ SearchOnSealedIndex(const Schema& schema, auto vec_index = dynamic_cast(accessor->get_cell_of(0)); + const auto& offset_mapping = vec_index->GetOffsetMapping(); + TargetBitmap transformed_bitset; + BitsetView search_bitset = bitset; + if (offset_mapping.IsEnabled()) { + transformed_bitset = TransformBitset(bitset, offset_mapping); + search_bitset = BitsetView(transformed_bitset); + if (offset_mapping.GetValidCount() == 0) { + auto total_num = num_queries * topK; + search_result.seg_offsets_.resize(total_num, INVALID_SEG_OFFSET); + search_result.distances_.resize(total_num, 0.0f); + search_result.total_nq_ = num_queries; + search_result.unity_topK_ = topK; + return; + } + } + if (search_info.iterator_v2_info_.has_value()) { CachedSearchIterator cached_iter( - *vec_index, dataset, search_info, bitset); + *vec_index, dataset, search_info, search_bitset); cached_iter.NextBatch(search_info, search_result); + TransformOffset(search_result.seg_offsets_, offset_mapping); return; } - if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info, - num_queries, - dataset, - search_result, - bitset, - *vec_index)) { + bool use_iterator = + milvus::exec::PrepareVectorIteratorsFromIndex(search_info, + num_queries, + dataset, + search_result, + search_bitset, + *vec_index); + if (!use_iterator) { vec_index->Query( - dataset, search_info, bitset, op_context, search_result); + dataset, search_info, search_bitset, op_context, search_result); float* distances = search_result.distances_.data(); auto total_num = num_queries * topK; if (round_decimal != -1) { @@ -120,6 +141,7 @@ SearchOnSealedIndex(const Schema& schema, search_result.element_level_ = true; } } + TransformOffset(search_result.seg_offsets_, offset_mapping); search_result.total_nq_ = num_queries; search_result.unity_topK_ = topK; } @@ -185,12 +207,30 @@ SearchOnSealedColumn(const Schema& schema, } auto offset = 0; - + const auto& offset_mapping = column->GetOffsetMapping(); + TargetBitmap transformed_bitset; + BitsetView search_bitview = bitview; + if (offset_mapping.IsEnabled()) { + transformed_bitset = TransformBitset(bitview, offset_mapping); + search_bitview = BitsetView(transformed_bitset); + if (offset_mapping.GetValidCount() == 0) { + auto total_num = num_queries * search_info.topk_; + result.seg_offsets_.resize(total_num, INVALID_SEG_OFFSET); + result.distances_.resize(total_num, 0.0f); + result.total_nq_ = num_queries; + result.unity_topK_ = search_info.topk_; + return; + } + } auto vector_chunks = column->GetAllChunks(op_context); + const auto& valid_count_per_chunk = column->GetValidCountPerChunk(); for (int i = 0; i < num_chunk; ++i) { auto pw = vector_chunks[i]; auto vec_data = pw.get()->Data(); auto chunk_size = column->chunk_row_nums(i); + if (offset_mapping.IsEnabled() && !valid_count_per_chunk.empty()) { + chunk_size = valid_count_per_chunk[i]; + } // For element-level search, get element count from VectorArrayOffsets if (is_element_level_search) { @@ -221,7 +261,7 @@ SearchOnSealedColumn(const Schema& schema, raw_dataset, search_info, index_info, - bitview, + search_bitview, data_type); final_qr.merge(sub_qr); } else { @@ -229,7 +269,7 @@ SearchOnSealedColumn(const Schema& schema, raw_dataset, search_info, index_info, - bitview, + search_bitview, data_type, element_type, op_context); @@ -243,6 +283,7 @@ SearchOnSealedColumn(const Schema& schema, num_chunk, column->GetNumRowsUntilChunk(), final_qr.chunk_iterators(), + offset_mapping, larger_is_closer); } else { if (search_info.array_offsets_ != nullptr) { @@ -256,6 +297,9 @@ SearchOnSealedColumn(const Schema& schema, result.seg_offsets_ = std::move(final_qr.mutable_offsets()); } result.distances_ = std::move(final_qr.mutable_distances()); + if (offset_mapping.IsEnabled()) { + TransformOffset(result.seg_offsets_, offset_mapping); + } } result.unity_topK_ = query_dataset.topk; result.total_nq_ = query_dataset.num_queries; diff --git a/internal/core/src/query/Utils.h b/internal/core/src/query/Utils.h index 8e5c637772..0e51f12586 100644 --- a/internal/core/src/query/Utils.h +++ b/internal/core/src/query/Utils.h @@ -14,9 +14,37 @@ #include #include +#include "common/BitsetView.h" +#include "common/OffsetMapping.h" +#include "common/Types.h" #include "common/Utils.h" namespace milvus::query { +inline TargetBitmap +TransformBitset(const BitsetView& bitset, + const milvus::OffsetMapping& mapping) { + TargetBitmap result; + auto count = mapping.GetValidCount(); + result.resize(count); + for (int64_t physical_idx = 0; physical_idx < count; physical_idx++) { + auto logical_idx = mapping.GetLogicalOffset(physical_idx); + if (logical_idx >= 0 && + logical_idx < static_cast(bitset.size())) { + result[physical_idx] = bitset.test(logical_idx); + } + } + return result; +} + +inline void +TransformOffset(std::vector& seg_offsets, + const milvus::OffsetMapping& mapping) { + for (auto& seg_offset : seg_offsets) { + if (seg_offset >= 0) { + seg_offset = mapping.GetLogicalOffset(seg_offset); + } + } +} template inline bool diff --git a/internal/core/src/segcore/ChunkedSegmentSealedBinlogIndexTest.cpp b/internal/core/src/segcore/ChunkedSegmentSealedBinlogIndexTest.cpp index 3417cda369..3bedf6a68a 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedBinlogIndexTest.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedBinlogIndexTest.cpp @@ -14,14 +14,18 @@ #include #include "index/IndexFactory.h" +#include "index/VectorIndex.h" #include "pb/plan.pb.h" #include "query/Plan.h" #include "segcore/segcore_init_c.h" #include "segcore/SegmentSealed.h" +#include "expr/ITypeExpr.h" +#include "plan/PlanNode.h" #include "test_utils/cachinglayer_test_utils.h" #include "test_utils/DataGen.h" #include "test_utils/storage_test_utils.h" +#include "test_utils/GenExprProto.h" using namespace milvus; using namespace milvus::segcore; @@ -35,6 +39,135 @@ GenRandomFloatVecData(int rows, int dim, int seed = 42) { return vecs; } +std::unique_ptr +GenRandomFloat16VecData(int rows, int dim, int seed = 42) { + auto vecs = std::make_unique(rows * dim); + std::mt19937 rng(seed); + std::normal_distribution<> distrib(0.0, 1.0); + for (int i = 0; i < rows * dim; ++i) + vecs[i] = milvus::float16(distrib(rng)); + return vecs; +} + +std::unique_ptr +GenRandomBFloat16VecData(int rows, int dim, int seed = 42) { + auto vecs = std::make_unique(rows * dim); + std::mt19937 rng(seed); + std::normal_distribution<> distrib(0.0, 1.0); + for (int i = 0; i < rows * dim; ++i) + vecs[i] = milvus::bfloat16(distrib(rng)); + return vecs; +} + +std::unique_ptr +GenRandomBinaryVecData(int rows, int dim, int seed = 42) { + assert(dim % 8 == 0); + auto byte_dim = dim / 8; + auto vecs = std::make_unique(rows * byte_dim); + std::mt19937 rng(seed); + for (int i = 0; i < rows * byte_dim; ++i) + vecs[i] = static_cast(rng()); + return vecs; +} + +std::unique_ptr +GenRandomInt8VecData(int rows, int dim, int seed = 42) { + auto vecs = std::make_unique(rows * dim); + std::mt19937 rng(seed); + for (int i = 0; i < rows * dim; ++i) + vecs[i] = static_cast(rng() % 256 - 128); + return vecs; +} + +milvus::proto::plan::VectorType +DataTypeToVectorType(DataType data_type) { + switch (data_type) { + case DataType::VECTOR_FLOAT: + return milvus::proto::plan::VectorType::FloatVector; + case DataType::VECTOR_FLOAT16: + return milvus::proto::plan::VectorType::Float16Vector; + case DataType::VECTOR_BFLOAT16: + return milvus::proto::plan::VectorType::BFloat16Vector; + case DataType::VECTOR_BINARY: + return milvus::proto::plan::VectorType::BinaryVector; + case DataType::VECTOR_INT8: + return milvus::proto::plan::VectorType::Int8Vector; + case DataType::VECTOR_SPARSE_U32_F32: + return milvus::proto::plan::VectorType::SparseFloatVector; + default: + throw std::runtime_error("unsupported vector type"); + } +} + +milvus::proto::common::PlaceholderGroup +CreatePlaceholderGroupForVectorType(DataType data_type, + int64_t num_queries, + int64_t dim, + const void* data) { + namespace ser = milvus::proto::common; + ser::PlaceholderGroup raw_group; + auto value = raw_group.add_placeholders(); + value->set_tag("$0"); + + switch (data_type) { + case DataType::VECTOR_FLOAT: { + value->set_type(ser::PlaceholderType::FloatVector); + auto ptr = static_cast(data); + for (int i = 0; i < num_queries; ++i) { + value->add_values(ptr + i * dim, dim * sizeof(float)); + } + break; + } + case DataType::VECTOR_FLOAT16: { + value->set_type(ser::PlaceholderType::Float16Vector); + auto ptr = static_cast(data); + for (int i = 0; i < num_queries; ++i) { + value->add_values(ptr + i * dim, dim * sizeof(milvus::float16)); + } + break; + } + case DataType::VECTOR_BFLOAT16: { + value->set_type(ser::PlaceholderType::BFloat16Vector); + auto ptr = static_cast(data); + for (int i = 0; i < num_queries; ++i) { + value->add_values(ptr + i * dim, + dim * sizeof(milvus::bfloat16)); + } + break; + } + case DataType::VECTOR_BINARY: { + value->set_type(ser::PlaceholderType::BinaryVector); + auto byte_dim = dim / 8; + auto ptr = static_cast(data); + for (int i = 0; i < num_queries; ++i) { + value->add_values(ptr + i * byte_dim, byte_dim); + } + break; + } + case DataType::VECTOR_INT8: { + value->set_type(ser::PlaceholderType::Int8Vector); + auto ptr = static_cast(data); + for (int i = 0; i < num_queries; ++i) { + value->add_values(ptr + i * dim, dim * sizeof(int8_t)); + } + break; + } + case DataType::VECTOR_SPARSE_U32_F32: { + value->set_type(ser::PlaceholderType::SparseFloatVector); + auto ptr = static_cast< + const knowhere::sparse::SparseRow*>( + data); + for (int i = 0; i < num_queries; ++i) { + value->add_values(ptr[i].data(), ptr[i].data_byte_size()); + } + break; + } + default: + throw std::runtime_error("unsupported vector type for placeholder"); + } + return raw_group; +} + inline float GetKnnSearchRecall( size_t nq, int64_t* gt_ids, size_t gt_k, int64_t* res_ids, size_t res_k) { @@ -62,29 +195,58 @@ using Param = std::tuple>; + /* DenseVectorInterminIndexType*/ std::optional, + /* Nullable */ bool, + /* NullPercent */ int>; class BinlogIndexTest : public ::testing::TestWithParam { void SetUp() override { - std::tie( - data_type, metric_type, index_type, dense_vec_intermin_index_type) = - GetParam(); + std::tie(data_type, + metric_type, + index_type, + dense_vec_intermin_index_type, + nullable, + null_percent) = GetParam(); schema = std::make_shared(); - vec_field_id = - schema->AddDebugField("fakevec", data_type, data_d, metric_type); + valid_count = 0; + vec_field_id = schema->AddDebugField( + "fakevec", data_type, data_d, metric_type, nullable); auto i64_fid = schema->AddDebugField("counter", DataType::INT64); schema->set_primary_field_id(i64_fid); - vec_field_data = - storage::CreateFieldData(data_type, DataType::NONE, false, data_d); + vec_field_data = storage::CreateFieldData( + data_type, DataType::NONE, nullable, data_d); + + if (nullable) { + valid_data.resize((data_n + 7) / 8, 0); + for (int i = 0; i < data_n; ++i) { + bool is_valid = (i % 100) >= null_percent; + if (is_valid) { + valid_data[i >> 3] |= (1 << (i & 0x07)); + valid_count++; + row_ids.push_back(i); + } + } + } else { + valid_count = data_n; + } if (data_type == DataType::VECTOR_FLOAT) { - auto vec_data = GenRandomFloatVecData(data_n, data_d); - vec_field_data->FillFieldData(vec_data.get(), data_n); - raw_dataset = knowhere::GenDataSet(data_n, data_d, vec_data.get()); - raw_dataset->SetIsOwner(true); - vec_data.release(); + auto vec_data = GenRandomFloatVecData(valid_count, data_d); + if (nullable) { + auto vec_field_data_impl = std::dynamic_pointer_cast< + milvus::FieldData>(vec_field_data); + vec_field_data_impl->FillFieldData( + vec_data.get(), valid_data.data(), data_n, 0); + } else { + vec_field_data->FillFieldData(vec_data.get(), data_n); + } + + raw_dataset = + knowhere::GenDataSet(valid_count, data_d, vec_data.get()); + raw_dataset->SetIsOwner(false); + raw_float_data = std::move(vec_data); if (dense_vec_intermin_index_type.has_value() && dense_vec_intermin_index_type.value() == knowhere::IndexEnum::INDEX_FAISS_SCANN_DVR) { @@ -92,18 +254,90 @@ class BinlogIndexTest : public ::testing::TestWithParam { } else { intermin_index_has_raw_data = true; } + } else if (data_type == DataType::VECTOR_FLOAT16) { + auto vec_data = GenRandomFloat16VecData(valid_count, data_d); + if (nullable) { + auto vec_field_data_impl = std::dynamic_pointer_cast< + milvus::FieldData>(vec_field_data); + vec_field_data_impl->FillFieldData( + vec_data.get(), valid_data.data(), data_n, 0); + } else { + vec_field_data->FillFieldData(vec_data.get(), data_n); + } + + raw_dataset = + knowhere::GenDataSet(valid_count, data_d, vec_data.get()); + raw_dataset->SetIsOwner(false); + raw_float16_data = std::move(vec_data); + intermin_index_has_raw_data = true; + } else if (data_type == DataType::VECTOR_BFLOAT16) { + auto vec_data = GenRandomBFloat16VecData(valid_count, data_d); + if (nullable) { + auto vec_field_data_impl = std::dynamic_pointer_cast< + milvus::FieldData>(vec_field_data); + vec_field_data_impl->FillFieldData( + vec_data.get(), valid_data.data(), data_n, 0); + } else { + vec_field_data->FillFieldData(vec_data.get(), data_n); + } + + raw_dataset = + knowhere::GenDataSet(valid_count, data_d, vec_data.get()); + raw_dataset->SetIsOwner(false); + raw_bfloat16_data = std::move(vec_data); + intermin_index_has_raw_data = true; + } else if (data_type == DataType::VECTOR_BINARY) { + auto vec_data = GenRandomBinaryVecData(valid_count, data_d); + if (nullable) { + auto vec_field_data_impl = std::dynamic_pointer_cast< + milvus::FieldData>(vec_field_data); + vec_field_data_impl->FillFieldData( + vec_data.get(), valid_data.data(), data_n, 0); + } else { + vec_field_data->FillFieldData(vec_data.get(), data_n); + } + + raw_dataset = + knowhere::GenDataSet(valid_count, data_d / 8, vec_data.get()); + raw_dataset->SetIsOwner(false); + raw_binary_data = std::move(vec_data); + intermin_index_has_raw_data = true; + } else if (data_type == DataType::VECTOR_INT8) { + auto vec_data = GenRandomInt8VecData(valid_count, data_d); + if (nullable) { + auto vec_field_data_impl = std::dynamic_pointer_cast< + milvus::FieldData>(vec_field_data); + vec_field_data_impl->FillFieldData( + vec_data.get(), valid_data.data(), data_n, 0); + } else { + vec_field_data->FillFieldData(vec_data.get(), data_n); + } + + raw_dataset = + knowhere::GenDataSet(valid_count, data_d, vec_data.get()); + raw_dataset->SetIsOwner(false); + raw_int8_data = std::move(vec_data); + intermin_index_has_raw_data = true; } else if (data_type == DataType::VECTOR_SPARSE_U32_F32) { - auto sparse_vecs = GenerateRandomSparseFloatVector(data_n); - vec_field_data->FillFieldData(sparse_vecs.get(), data_n); + auto sparse_vecs = GenerateRandomSparseFloatVector(valid_count); + if (nullable) { + auto vec_field_data_impl = std::dynamic_pointer_cast< + milvus::FieldData>( + vec_field_data); + vec_field_data_impl->FillFieldData( + sparse_vecs.get(), valid_data.data(), data_n, 0); + } else { + vec_field_data->FillFieldData(sparse_vecs.get(), data_n); + } data_d = std::dynamic_pointer_cast< milvus::FieldData>( vec_field_data) ->Dim(); raw_dataset = - knowhere::GenDataSet(data_n, data_d, sparse_vecs.get()); - raw_dataset->SetIsOwner(true); + knowhere::GenDataSet(valid_count, data_d, sparse_vecs.get()); + raw_dataset->SetIsOwner(false); raw_dataset->SetIsSparse(true); - sparse_vecs.release(); + raw_sparse_data = std::move(sparse_vecs); intermin_index_has_raw_data = false; } else { throw std::runtime_error("not implemented"); @@ -123,6 +357,7 @@ class BinlogIndexTest : public ::testing::TestWithParam { auto& config = SegcoreConfig::default_config(); config.set_chunk_rows(1024); config.set_enable_interim_segment_index(true); + config.set_nlist(16); std::map filedMap = { {vec_field_id, fieldIndexMeta}}; IndexMetaPtr metaPtr = @@ -159,48 +394,241 @@ class BinlogIndexTest : public ::testing::TestWithParam { segment->LoadFieldData(load_info); } + const void* + GetQueryData(int num_queries) { + // Generate random query vectors for search + switch (data_type) { + case DataType::VECTOR_FLOAT: + query_float_data = + GenRandomFloatVecData(num_queries, data_d, 999); + return query_float_data.get(); + case DataType::VECTOR_FLOAT16: + query_float16_data = + GenRandomFloat16VecData(num_queries, data_d, 999); + return query_float16_data.get(); + case DataType::VECTOR_BFLOAT16: + query_bfloat16_data = + GenRandomBFloat16VecData(num_queries, data_d, 999); + return query_bfloat16_data.get(); + case DataType::VECTOR_BINARY: + query_binary_data = + GenRandomBinaryVecData(num_queries, data_d, 999); + return query_binary_data.get(); + case DataType::VECTOR_INT8: + query_int8_data = + GenRandomInt8VecData(num_queries, data_d, 999); + return query_int8_data.get(); + case DataType::VECTOR_SPARSE_U32_F32: + query_sparse_data = GenerateRandomSparseFloatVector( + num_queries, kTestSparseDim, kTestSparseVectorDensity, 999); + return query_sparse_data.get(); + default: + throw std::runtime_error("unsupported vector type"); + } + } + + void + VerifyQueryResults(const std::vector& seg_offsets) { + if (seg_offsets.empty()) { + return; + } + + std::vector valid_offsets; + for (auto offset : seg_offsets) { + if (offset >= 0 && offset < static_cast(data_n)) { + valid_offsets.push_back(offset); + } + } + if (valid_offsets.empty()) { + return; + } + + std::sort(valid_offsets.begin(), valid_offsets.end()); + valid_offsets.erase( + std::unique(valid_offsets.begin(), valid_offsets.end()), + valid_offsets.end()); + + auto i64_fid = schema->get_primary_field_id().value(); + + std::vector values; + for (auto offset : valid_offsets) { + proto::plan::GenericValue val; + val.set_int64_val(offset); + values.push_back(val); + } + + auto term_expr = std::make_shared( + milvus::expr::ColumnInfo( + i64_fid, DataType::INT64, std::vector()), + values); + + auto plan = std::make_unique(schema); + plan->plan_node_ = std::make_unique(); + plan->plan_node_->plannodes_ = + milvus::test::CreateRetrievePlanByExpr(term_expr); + std::vector target_fields{vec_field_id}; + plan->field_ids_ = target_fields; + + auto retrieve_results = segment->Retrieve( + nullptr, plan.get(), MAX_TIMESTAMP, DEFAULT_MAX_OUTPUT_SIZE, false); + + ASSERT_TRUE(retrieve_results != nullptr); + EXPECT_EQ(retrieve_results->fields_data_size(), 1); + + auto& field_data = retrieve_results->fields_data(0); + EXPECT_TRUE(field_data.has_vectors()); + + // Verify the number of returned vectors matches the number of valid offsets we queried + size_t returned_count = 0; + switch (data_type) { + case DataType::VECTOR_FLOAT: + returned_count = + field_data.vectors().float_vector().data_size() / data_d; + break; + case DataType::VECTOR_FLOAT16: + returned_count = field_data.vectors().float16_vector().size() / + (data_d * sizeof(milvus::float16)); + break; + case DataType::VECTOR_BFLOAT16: + returned_count = field_data.vectors().bfloat16_vector().size() / + (data_d * sizeof(milvus::bfloat16)); + break; + case DataType::VECTOR_BINARY: + returned_count = + field_data.vectors().binary_vector().size() / (data_d / 8); + break; + case DataType::VECTOR_INT8: + returned_count = + field_data.vectors().int8_vector().size() / data_d; + break; + case DataType::VECTOR_SPARSE_U32_F32: + returned_count = + field_data.vectors().sparse_float_vector().contents_size(); + break; + default: + break; + } + + if (!nullable) { + EXPECT_EQ(returned_count, valid_offsets.size()) + << "Query returned " << returned_count << " vectors, expected " + << valid_offsets.size(); + } + + EXPECT_GT(returned_count, 0) + << "Query should return at least some vectors"; + } + protected: milvus::SchemaPtr schema; knowhere::MetricType metric_type; DataType data_type; std::optional dense_vec_intermin_index_type = std::nullopt; std::string index_type; - size_t data_n = 5000; - size_t data_d = 4; + bool nullable = false; + int null_percent = 0; + size_t data_n = 1000; + size_t data_d = 8; size_t topk = 10; + size_t valid_count = 0; milvus::FieldDataPtr vec_field_data = nullptr; milvus::segcore::SegmentSealedUPtr segment = nullptr; milvus::FieldId vec_field_id; knowhere::DataSetPtr raw_dataset; bool intermin_index_has_raw_data; + std::vector valid_data; + std::vector row_ids; + + std::unique_ptr raw_float_data; + std::unique_ptr raw_float16_data; + std::unique_ptr raw_bfloat16_data; + std::unique_ptr raw_binary_data; + std::unique_ptr raw_int8_data; + std::unique_ptr[]> + raw_sparse_data; + + // Query data (generated randomly for each search) + mutable std::unique_ptr query_float_data; + mutable std::unique_ptr query_float16_data; + mutable std::unique_ptr query_bfloat16_data; + mutable std::unique_ptr query_binary_data; + mutable std::unique_ptr query_int8_data; + mutable std::unique_ptr< + knowhere::sparse::SparseRow[]> + query_sparse_data; }; -INSTANTIATE_TEST_SUITE_P( - MetricTypeParameters, - BinlogIndexTest, - ::testing::Values( - std::make_tuple(DataType::VECTOR_FLOAT, - knowhere::metric::L2, - knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, - knowhere::IndexEnum:: - INDEX_FAISS_IVFFLAT_CC), // intermin index has data - std::make_tuple( - DataType::VECTOR_FLOAT, - knowhere::metric::L2, - knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, - knowhere::IndexEnum:: - INDEX_FAISS_SCANN_DVR), // intermin index not has data - std::make_tuple( - DataType::VECTOR_SPARSE_U32_F32, - knowhere::metric::IP, - knowhere::IndexEnum:: - INDEX_SPARSE_INVERTED_INDEX, //intermin index not has data - std::nullopt), - std::make_tuple(DataType::VECTOR_SPARSE_U32_F32, - knowhere::metric::IP, - knowhere::IndexEnum:: - INDEX_SPARSE_WAND, // intermin index not has data - std::nullopt))); +static std::vector +GenerateTestParams() { + std::vector params; + + std::vector>> + base_configs = { + {DataType::VECTOR_FLOAT, + knowhere::metric::L2, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + knowhere::IndexEnum:: + INDEX_FAISS_IVFFLAT_CC}, // intermin index has data + {DataType::VECTOR_FLOAT, + knowhere::metric::L2, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + knowhere::IndexEnum:: + INDEX_FAISS_SCANN_DVR}, // intermin index not has data + {DataType::VECTOR_FLOAT16, + knowhere::metric::L2, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC}, + {DataType::VECTOR_BFLOAT16, + knowhere::metric::L2, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC}, + {DataType::VECTOR_BINARY, + knowhere::metric::HAMMING, + knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, + std::nullopt}, + {DataType::VECTOR_INT8, + knowhere::metric::L2, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC}, + {DataType::VECTOR_SPARSE_U32_F32, + knowhere::metric::IP, + knowhere::IndexEnum:: + INDEX_SPARSE_INVERTED_INDEX, //intermin index not has data + std::nullopt}, + {DataType::VECTOR_SPARSE_U32_F32, + knowhere::metric::IP, + knowhere::IndexEnum:: + INDEX_SPARSE_WAND, // intermin index not has data + std::nullopt}, + }; + + std::vector> null_configs = { + {false, 0}, // non-nullable with 0% null + {true, 0}, // nullable with 0% null + {true, 20}, // nullable with 20% null + {true, 100} // nullable with 100% null + }; + + for (const auto& [data_type, metric, index_type, interim_index] : + base_configs) { + for (const auto& [nullable, null_percent] : null_configs) { + params.push_back(std::make_tuple(data_type, + metric, + index_type, + interim_index, + nullable, + null_percent)); + } + } + return params; +} + +INSTANTIATE_TEST_SUITE_P(MetricTypeParameters, + BinlogIndexTest, + ::testing::ValuesIn(GenerateTestParams())); TEST_P(BinlogIndexTest, AccuracyWithLoadFieldData) { IndexMetaPtr collection_index_meta = GetCollectionIndexMeta(index_type); @@ -219,17 +647,28 @@ TEST_P(BinlogIndexTest, AccuracyWithLoadFieldData) { LoadVectorField(); //assert segment has been built binlog index - EXPECT_TRUE(segment->HasIndex(vec_field_id)); + bool supports_interim_index = + (data_type == DataType::VECTOR_FLOAT || + data_type == DataType::VECTOR_FLOAT16 || + data_type == DataType::VECTOR_BFLOAT16 || + data_type == DataType::VECTOR_SPARSE_U32_F32); + int64_t valid_row_count = nullable ? valid_count : data_n; + int64_t threshold = segcore_config.get_nlist() * 39; + bool should_have_index = + supports_interim_index && (valid_row_count >= threshold); + if (should_have_index) { + EXPECT_TRUE(segment->HasIndex(vec_field_id)); + } else { + EXPECT_FALSE(segment->HasIndex(vec_field_id)); + } EXPECT_EQ(segment->get_row_count(), data_n); - EXPECT_TRUE(segment->HasFieldData(vec_field_id)); // 2. search binlog index auto num_queries = 10; - milvus::proto::plan::PlanNode plan_node; auto vector_anns = plan_node.mutable_vector_anns(); - vector_anns->set_vector_type(milvus::proto::plan::VectorType::FloatVector); + vector_anns->set_vector_type(DataTypeToVectorType(data_type)); vector_anns->set_placeholder_tag("$0"); vector_anns->set_field_id(vec_field_id.get()); auto query_info = vector_anns->mutable_query_info(); @@ -239,13 +678,8 @@ TEST_P(BinlogIndexTest, AccuracyWithLoadFieldData) { query_info->set_search_params(R"({"nprobe": 16})"); auto plan_str = plan_node.SerializeAsString(); - auto ph_group_raw = - data_type == DataType::VECTOR_FLOAT - ? CreatePlaceholderGroupFromBlob( - num_queries, - data_d, - GenRandomFloatVecData(num_queries, data_d).get()) - : CreateSparseFloatPlaceholderGroup(num_queries); + auto ph_group_raw = CreatePlaceholderGroupForVectorType( + data_type, num_queries, data_d, GetQueryData(num_queries)); auto plan = milvus::query::CreateSearchPlanByExpr( schema, plan_str.data(), plan_str.size()); @@ -262,40 +696,206 @@ TEST_P(BinlogIndexTest, AccuracyWithLoadFieldData) { EXPECT_EQ(binlog_index_sr->distances_.size(), num_queries * topk); EXPECT_EQ(binlog_index_sr->seg_offsets_.size(), num_queries * topk); - // 3. update vector index + for (int q = 0; q < num_queries; ++q) { + for (size_t k = 0; k < topk; ++k) { + int64_t seg_offset = binlog_index_sr->seg_offsets_[q * topk + k]; + if (seg_offset == -1) { + continue; // No result for this position + } + ASSERT_GE(seg_offset, 0); + ASSERT_LT(seg_offset, static_cast(data_n)); + + if (nullable) { + bool is_valid = + (valid_data[seg_offset >> 3] >> (seg_offset & 0x07)) & 1; + EXPECT_TRUE(is_valid) + << "Search returned invalid (null) row at seg_offset=" + << seg_offset; + } + } + } + + VerifyQueryResults(binlog_index_sr->seg_offsets_); + { - milvus::index::CreateIndexInfo create_index_info; - create_index_info.field_type = data_type; - create_index_info.metric_type = metric_type; - create_index_info.index_type = index_type; - create_index_info.index_engine_version = - knowhere::Version::GetCurrentVersion().VersionNumber(); - auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, milvus::storage::FileManagerContext()); + milvus::proto::plan::PlanNode filtered_plan_node; + auto filtered_vector_anns = filtered_plan_node.mutable_vector_anns(); + filtered_vector_anns->set_vector_type(DataTypeToVectorType(data_type)); + filtered_vector_anns->set_placeholder_tag("$0"); + filtered_vector_anns->set_field_id(vec_field_id.get()); + auto filtered_query_info = filtered_vector_anns->mutable_query_info(); + filtered_query_info->set_topk(topk); + filtered_query_info->set_round_decimal(3); + filtered_query_info->set_metric_type(metric_type); + filtered_query_info->set_search_params(R"({"nprobe": 16})"); - auto build_conf = - knowhere::Json{{knowhere::meta::METRIC_TYPE, metric_type}, - {knowhere::meta::DIM, std::to_string(data_d)}, - {knowhere::indexparam::NLIST, "64"}}; - indexing->BuildWithDataset(raw_dataset, build_conf); + auto i64_fid = schema->get_primary_field_id().value(); + auto* predicate = filtered_vector_anns->mutable_predicates(); + auto* unary_range = predicate->mutable_unary_range_expr(); + auto* col_info = unary_range->mutable_column_info(); + col_info->set_field_id(i64_fid.get()); + col_info->set_data_type(milvus::proto::schema::DataType::Int64); + unary_range->set_op(milvus::proto::plan::OpType::GreaterEqual); + unary_range->mutable_value()->set_int64_val(data_n / 2); - LoadIndexInfo load_info; - load_info.field_id = vec_field_id.get(); - load_info.index_params = GenIndexParams(indexing.get()); - load_info.cache_index = - CreateTestCacheIndex("test", std::move(indexing)); - load_info.index_params["metric_type"] = metric_type; - ASSERT_NO_THROW(segment->LoadIndex(load_info)); - EXPECT_TRUE(segment->HasIndex(vec_field_id)); - EXPECT_EQ(segment->get_row_count(), data_n); - auto ivf_sr = - segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); - auto similary = GetKnnSearchRecall(num_queries, - binlog_index_sr->seg_offsets_.data(), - topk, - ivf_sr->seg_offsets_.data(), - topk); - ASSERT_GT(similary, 0.45); + auto filtered_plan_str = filtered_plan_node.SerializeAsString(); + auto filtered_plan = milvus::query::CreateSearchPlanByExpr( + schema, filtered_plan_str.data(), filtered_plan_str.size()); + auto filtered_ph_group = ParsePlaceholderGroup( + filtered_plan.get(), ph_group_raw.SerializeAsString()); + + auto filtered_sr = segment->Search( + filtered_plan.get(), filtered_ph_group.get(), MAX_TIMESTAMP); + + ASSERT_EQ(filtered_sr->total_nq_, num_queries); + EXPECT_EQ(filtered_sr->unity_topK_, topk); + + for (size_t i = 0; i < filtered_sr->seg_offsets_.size(); ++i) { + int64_t seg_offset = filtered_sr->seg_offsets_[i]; + if (seg_offset != -1) { + EXPECT_GE(seg_offset, data_n / 2) + << "Filtered search returned row " << seg_offset + << " which should have been filtered (pk < " << data_n / 2 + << ")"; + + if (nullable) { + bool is_valid = + (valid_data[seg_offset >> 3] >> (seg_offset & 0x07)) & + 1; + EXPECT_TRUE(is_valid) << "Filtered search returned invalid " + "(null) row at seg_offset=" + << seg_offset; + } + } + } + } + + if (null_percent != 100 && supports_interim_index) { + { + milvus::index::CreateIndexInfo create_index_info; + create_index_info.field_type = data_type; + create_index_info.metric_type = metric_type; + create_index_info.index_type = index_type; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); + auto indexing = + milvus::index::IndexFactory::GetInstance().CreateIndex( + create_index_info, milvus::storage::FileManagerContext()); + + auto build_conf = + knowhere::Json{{knowhere::meta::METRIC_TYPE, metric_type}, + {knowhere::meta::DIM, std::to_string(data_d)}, + {knowhere::indexparam::NLIST, "64"}}; + + indexing->BuildWithDataset(raw_dataset, build_conf); + + if (nullable) { + auto vec_indexing = + dynamic_cast(indexing.get()); + ASSERT_NE(vec_indexing, nullptr); + std::unique_ptr valid_data_bool(new bool[data_n]); + for (int64_t i = 0; i < data_n; ++i) { + valid_data_bool[i] = (valid_data[i >> 3] >> (i & 0x07)) & 1; + } + vec_indexing->UpdateValidData(valid_data_bool.get(), data_n); + } + + LoadIndexInfo load_info; + load_info.field_id = vec_field_id.get(); + load_info.index_params = GenIndexParams(indexing.get()); + load_info.cache_index = + CreateTestCacheIndex("test", std::move(indexing)); + load_info.index_params["metric_type"] = metric_type; + + ASSERT_NO_THROW(segment->LoadIndex(load_info)); + + EXPECT_TRUE(segment->HasIndex(vec_field_id)); + EXPECT_EQ(segment->get_row_count(), data_n); + + std::unique_ptr ivf_sr; + try { + ivf_sr = + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); + } catch (const std::exception& e) { + throw; + } + + ASSERT_EQ(ivf_sr->total_nq_, num_queries); + EXPECT_EQ(ivf_sr->unity_topK_, topk); + EXPECT_EQ(ivf_sr->distances_.size(), num_queries * topk); + EXPECT_EQ(ivf_sr->seg_offsets_.size(), num_queries * topk); + + auto similary = + GetKnnSearchRecall(num_queries, + binlog_index_sr->seg_offsets_.data(), + topk, + ivf_sr->seg_offsets_.data(), + topk); + ASSERT_GT(similary, 0.45); + + VerifyQueryResults(ivf_sr->seg_offsets_); + + { + milvus::proto::plan::PlanNode ivf_filtered_plan_node; + auto ivf_filtered_anns = + ivf_filtered_plan_node.mutable_vector_anns(); + ivf_filtered_anns->set_vector_type( + DataTypeToVectorType(data_type)); + ivf_filtered_anns->set_placeholder_tag("$0"); + ivf_filtered_anns->set_field_id(vec_field_id.get()); + auto ivf_filtered_info = + ivf_filtered_anns->mutable_query_info(); + ivf_filtered_info->set_topk(topk); + ivf_filtered_info->set_round_decimal(3); + ivf_filtered_info->set_metric_type(metric_type); + ivf_filtered_info->set_search_params(R"({"nprobe": 16})"); + + // Add filter: pk >= data_n/2 + auto i64_fid = schema->get_primary_field_id().value(); + auto* predicate = ivf_filtered_anns->mutable_predicates(); + auto* unary_range = predicate->mutable_unary_range_expr(); + auto* col_info = unary_range->mutable_column_info(); + col_info->set_field_id(i64_fid.get()); + col_info->set_data_type(milvus::proto::schema::DataType::Int64); + unary_range->set_op(milvus::proto::plan::OpType::GreaterEqual); + unary_range->mutable_value()->set_int64_val(data_n / 2); + + auto ivf_filtered_str = + ivf_filtered_plan_node.SerializeAsString(); + auto ivf_filtered_plan = milvus::query::CreateSearchPlanByExpr( + schema, ivf_filtered_str.data(), ivf_filtered_str.size()); + auto ivf_filtered_ph = ParsePlaceholderGroup( + ivf_filtered_plan.get(), ph_group_raw.SerializeAsString()); + + auto ivf_filtered_sr = segment->Search(ivf_filtered_plan.get(), + ivf_filtered_ph.get(), + MAX_TIMESTAMP); + + ASSERT_EQ(ivf_filtered_sr->total_nq_, num_queries); + EXPECT_EQ(ivf_filtered_sr->unity_topK_, topk); + + // Verify all returned offsets are >= data_n/2 + for (size_t i = 0; i < ivf_filtered_sr->seg_offsets_.size(); + ++i) { + int64_t seg_offset = ivf_filtered_sr->seg_offsets_[i]; + if (seg_offset != -1) { + EXPECT_GE(seg_offset, data_n / 2) + << "IVF filtered search returned row " << seg_offset + << " which should have been filtered"; + + if (nullable) { + bool is_valid = (valid_data[seg_offset >> 3] >> + (seg_offset & 0x07)) & + 1; + EXPECT_TRUE(is_valid) + << "IVF filtered search returned invalid row " + "at seg_offset=" + << seg_offset; + } + } + } + } + } } } @@ -316,16 +916,32 @@ TEST_P(BinlogIndexTest, AccuracyWithMapFieldData) { LoadVectorField("./data/mmap-test"); //assert segment has been built binlog index - EXPECT_TRUE(segment->HasIndex(vec_field_id)); + bool supports_interim_index = + (data_type == DataType::VECTOR_FLOAT || + data_type == DataType::VECTOR_FLOAT16 || + data_type == DataType::VECTOR_BFLOAT16 || + data_type == DataType::VECTOR_SPARSE_U32_F32); + int64_t valid_row_count = nullable ? valid_count : data_n; + int64_t threshold = segcore_config.get_nlist() * 39; + bool should_have_index = + supports_interim_index && (valid_row_count >= threshold); + if (should_have_index) { + EXPECT_TRUE(segment->HasIndex(vec_field_id)); + } else { + EXPECT_FALSE(segment->HasIndex(vec_field_id)); + } EXPECT_EQ(segment->get_row_count(), data_n); EXPECT_TRUE(segment->HasFieldData(vec_field_id)); // 2. search binlog index - auto num_queries = 10; + auto num_queries = std::min(10, (int)valid_count); + if (num_queries == 0) { + return; + } milvus::proto::plan::PlanNode plan_node; auto vector_anns = plan_node.mutable_vector_anns(); - vector_anns->set_vector_type(milvus::proto::plan::VectorType::FloatVector); + vector_anns->set_vector_type(DataTypeToVectorType(data_type)); vector_anns->set_placeholder_tag("$0"); vector_anns->set_field_id(vec_field_id.get()); @@ -336,13 +952,9 @@ TEST_P(BinlogIndexTest, AccuracyWithMapFieldData) { query_info->set_search_params(R"({"nprobe": 16})"); auto plan_str = plan_node.SerializeAsString(); - auto ph_group_raw = - data_type == DataType::VECTOR_FLOAT - ? CreatePlaceholderGroupFromBlob( - num_queries, - data_d, - GenRandomFloatVecData(num_queries, data_d).get()) - : CreateSparseFloatPlaceholderGroup(num_queries); + // Use the first num_queries vectors from raw data as queries + auto ph_group_raw = CreatePlaceholderGroupForVectorType( + data_type, num_queries, data_d, GetQueryData(num_queries)); auto plan = milvus::query::CreateSearchPlanByExpr( schema, plan_str.data(), plan_str.size()); @@ -359,40 +971,196 @@ TEST_P(BinlogIndexTest, AccuracyWithMapFieldData) { EXPECT_EQ(binlog_index_sr->distances_.size(), num_queries * topk); EXPECT_EQ(binlog_index_sr->seg_offsets_.size(), num_queries * topk); - // 3. update vector index + for (int q = 0; q < num_queries; ++q) { + for (size_t k = 0; k < topk; ++k) { + int64_t seg_offset = binlog_index_sr->seg_offsets_[q * topk + k]; + if (seg_offset == -1) { + continue; + } + ASSERT_GE(seg_offset, 0); + ASSERT_LT(seg_offset, static_cast(data_n)); + + if (nullable) { + bool is_valid = + (valid_data[seg_offset >> 3] >> (seg_offset & 0x07)) & 1; + EXPECT_TRUE(is_valid) + << "Search returned invalid (null) row at seg_offset=" + << seg_offset; + } + } + } + + VerifyQueryResults(binlog_index_sr->seg_offsets_); + { - milvus::index::CreateIndexInfo create_index_info; - create_index_info.field_type = data_type; - create_index_info.metric_type = metric_type; - create_index_info.index_type = index_type; - create_index_info.index_engine_version = - knowhere::Version::GetCurrentVersion().VersionNumber(); - auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, milvus::storage::FileManagerContext()); + milvus::proto::plan::PlanNode filtered_plan_node; + auto filtered_vector_anns = filtered_plan_node.mutable_vector_anns(); + filtered_vector_anns->set_vector_type(DataTypeToVectorType(data_type)); + filtered_vector_anns->set_placeholder_tag("$0"); + filtered_vector_anns->set_field_id(vec_field_id.get()); + auto filtered_query_info = filtered_vector_anns->mutable_query_info(); + filtered_query_info->set_topk(topk); + filtered_query_info->set_round_decimal(3); + filtered_query_info->set_metric_type(metric_type); + filtered_query_info->set_search_params(R"({"nprobe": 16})"); - auto build_conf = - knowhere::Json{{knowhere::meta::METRIC_TYPE, metric_type}, - {knowhere::meta::DIM, std::to_string(data_d)}, - {knowhere::indexparam::NLIST, "64"}}; - indexing->BuildWithDataset(raw_dataset, build_conf); + auto i64_fid = schema->get_primary_field_id().value(); + auto* predicate = filtered_vector_anns->mutable_predicates(); + auto* unary_range = predicate->mutable_unary_range_expr(); + auto* col_info = unary_range->mutable_column_info(); + col_info->set_field_id(i64_fid.get()); + col_info->set_data_type(milvus::proto::schema::DataType::Int64); + unary_range->set_op(milvus::proto::plan::OpType::GreaterEqual); + unary_range->mutable_value()->set_int64_val(data_n / 2); - LoadIndexInfo load_info; - load_info.field_id = vec_field_id.get(); - load_info.index_params = GenIndexParams(indexing.get()); - load_info.cache_index = - CreateTestCacheIndex("test", std::move(indexing)); - load_info.index_params["metric_type"] = metric_type; - ASSERT_NO_THROW(segment->LoadIndex(load_info)); - EXPECT_TRUE(segment->HasIndex(vec_field_id)); - EXPECT_EQ(segment->get_row_count(), data_n); - auto ivf_sr = - segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); - auto similary = GetKnnSearchRecall(num_queries, - binlog_index_sr->seg_offsets_.data(), - topk, - ivf_sr->seg_offsets_.data(), - topk); - ASSERT_GT(similary, 0.45); + auto filtered_plan_str = filtered_plan_node.SerializeAsString(); + auto filtered_plan = milvus::query::CreateSearchPlanByExpr( + schema, filtered_plan_str.data(), filtered_plan_str.size()); + auto filtered_ph_group = ParsePlaceholderGroup( + filtered_plan.get(), ph_group_raw.SerializeAsString()); + + auto filtered_sr = segment->Search( + filtered_plan.get(), filtered_ph_group.get(), MAX_TIMESTAMP); + + ASSERT_EQ(filtered_sr->total_nq_, num_queries); + EXPECT_EQ(filtered_sr->unity_topK_, topk); + + for (size_t i = 0; i < filtered_sr->seg_offsets_.size(); ++i) { + int64_t seg_offset = filtered_sr->seg_offsets_[i]; + if (seg_offset != -1) { + EXPECT_GE(seg_offset, data_n / 2) + << "Filtered search returned row " << seg_offset + << " which should have been filtered (pk < " << data_n / 2 + << ")"; + + if (nullable) { + bool is_valid = + (valid_data[seg_offset >> 3] >> (seg_offset & 0x07)) & + 1; + EXPECT_TRUE(is_valid) << "Filtered search returned invalid " + "(null) row at seg_offset=" + << seg_offset; + } + } + } + } + + if (null_percent != 100 && supports_interim_index) { + // 3. update vector index + { + milvus::index::CreateIndexInfo create_index_info; + create_index_info.field_type = data_type; + create_index_info.metric_type = metric_type; + create_index_info.index_type = index_type; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); + auto indexing = + milvus::index::IndexFactory::GetInstance().CreateIndex( + create_index_info, milvus::storage::FileManagerContext()); + + auto build_conf = + knowhere::Json{{knowhere::meta::METRIC_TYPE, metric_type}, + {knowhere::meta::DIM, std::to_string(data_d)}, + {knowhere::indexparam::NLIST, "64"}}; + + indexing->BuildWithDataset(raw_dataset, build_conf); + + if (nullable) { + auto vec_indexing = + dynamic_cast(indexing.get()); + ASSERT_NE(vec_indexing, nullptr); + std::unique_ptr valid_data_bool(new bool[data_n]); + for (int64_t i = 0; i < data_n; ++i) { + valid_data_bool[i] = (valid_data[i >> 3] >> (i & 0x07)) & 1; + } + vec_indexing->UpdateValidData(valid_data_bool.get(), data_n); + } + + LoadIndexInfo load_info; + load_info.field_id = vec_field_id.get(); + load_info.index_params = GenIndexParams(indexing.get()); + load_info.cache_index = + CreateTestCacheIndex("test", std::move(indexing)); + load_info.index_params["metric_type"] = metric_type; + ASSERT_NO_THROW(segment->LoadIndex(load_info)); + EXPECT_TRUE(segment->HasIndex(vec_field_id)); + EXPECT_EQ(segment->get_row_count(), data_n); + auto ivf_sr = + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); + ASSERT_EQ(ivf_sr->total_nq_, num_queries); + EXPECT_EQ(ivf_sr->unity_topK_, topk); + EXPECT_EQ(ivf_sr->distances_.size(), num_queries * topk); + EXPECT_EQ(ivf_sr->seg_offsets_.size(), num_queries * topk); + + auto similary = + GetKnnSearchRecall(num_queries, + binlog_index_sr->seg_offsets_.data(), + topk, + ivf_sr->seg_offsets_.data(), + topk); + ASSERT_GT(similary, 0.45); + + VerifyQueryResults(ivf_sr->seg_offsets_); + + { + milvus::proto::plan::PlanNode ivf_filtered_plan_node; + auto ivf_filtered_anns = + ivf_filtered_plan_node.mutable_vector_anns(); + ivf_filtered_anns->set_vector_type( + DataTypeToVectorType(data_type)); + ivf_filtered_anns->set_placeholder_tag("$0"); + ivf_filtered_anns->set_field_id(vec_field_id.get()); + auto ivf_filtered_info = + ivf_filtered_anns->mutable_query_info(); + ivf_filtered_info->set_topk(topk); + ivf_filtered_info->set_round_decimal(3); + ivf_filtered_info->set_metric_type(metric_type); + ivf_filtered_info->set_search_params(R"({"nprobe": 16})"); + + auto i64_fid = schema->get_primary_field_id().value(); + auto* predicate = ivf_filtered_anns->mutable_predicates(); + auto* unary_range = predicate->mutable_unary_range_expr(); + auto* col_info = unary_range->mutable_column_info(); + col_info->set_field_id(i64_fid.get()); + col_info->set_data_type(milvus::proto::schema::DataType::Int64); + unary_range->set_op(milvus::proto::plan::OpType::GreaterEqual); + unary_range->mutable_value()->set_int64_val(data_n / 2); + + auto ivf_filtered_str = + ivf_filtered_plan_node.SerializeAsString(); + auto ivf_filtered_plan = milvus::query::CreateSearchPlanByExpr( + schema, ivf_filtered_str.data(), ivf_filtered_str.size()); + auto ivf_filtered_ph = ParsePlaceholderGroup( + ivf_filtered_plan.get(), ph_group_raw.SerializeAsString()); + + auto ivf_filtered_sr = segment->Search(ivf_filtered_plan.get(), + ivf_filtered_ph.get(), + MAX_TIMESTAMP); + + ASSERT_EQ(ivf_filtered_sr->total_nq_, num_queries); + EXPECT_EQ(ivf_filtered_sr->unity_topK_, topk); + + for (size_t i = 0; i < ivf_filtered_sr->seg_offsets_.size(); + ++i) { + int64_t seg_offset = ivf_filtered_sr->seg_offsets_[i]; + if (seg_offset != -1) { + EXPECT_GE(seg_offset, data_n / 2) + << "IVF filtered search returned row " << seg_offset + << " which should have been filtered"; + + if (nullable) { + bool is_valid = (valid_data[seg_offset >> 3] >> + (seg_offset & 0x07)) & + 1; + EXPECT_TRUE(is_valid) + << "IVF filtered search returned invalid row " + "at seg_offset=" + << seg_offset; + } + } + } + } + } } } @@ -408,32 +1176,98 @@ TEST_P(BinlogIndexTest, DisableInterimIndex) { EXPECT_FALSE(segment->HasIndex(vec_field_id)); EXPECT_EQ(segment->get_row_count(), data_n); EXPECT_TRUE(segment->HasFieldData(vec_field_id)); - // load vector index - milvus::index::CreateIndexInfo create_index_info; - create_index_info.field_type = data_type; - create_index_info.metric_type = metric_type; - create_index_info.index_type = index_type; - create_index_info.index_engine_version = - knowhere::Version::GetCurrentVersion().VersionNumber(); - auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex( - create_index_info, milvus::storage::FileManagerContext()); - auto build_conf = - knowhere::Json{{knowhere::meta::METRIC_TYPE, metric_type}, - {knowhere::meta::DIM, std::to_string(data_d)}, - {knowhere::indexparam::NLIST, "64"}}; + bool supports_final_index = (data_type == DataType::VECTOR_FLOAT || + data_type == DataType::VECTOR_FLOAT16 || + data_type == DataType::VECTOR_BFLOAT16 || + data_type == DataType::VECTOR_SPARSE_U32_F32); - indexing->BuildWithDataset(raw_dataset, build_conf); + if (null_percent != 100 && supports_final_index) { + // load vector index + milvus::index::CreateIndexInfo create_index_info; + create_index_info.field_type = data_type; + create_index_info.metric_type = metric_type; + create_index_info.index_type = index_type; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); + auto indexing = milvus::index::IndexFactory::GetInstance().CreateIndex( + create_index_info, milvus::storage::FileManagerContext()); - LoadIndexInfo load_info; - load_info.field_id = vec_field_id.get(); - load_info.index_params = GenIndexParams(indexing.get()); - load_info.cache_index = CreateTestCacheIndex("test", std::move(indexing)); - load_info.index_params["metric_type"] = metric_type; + auto build_conf = + knowhere::Json{{knowhere::meta::METRIC_TYPE, metric_type}, + {knowhere::meta::DIM, std::to_string(data_d)}, + {knowhere::indexparam::NLIST, "64"}}; - ASSERT_NO_THROW(segment->LoadIndex(load_info)); - EXPECT_TRUE(segment->HasIndex(vec_field_id)); - EXPECT_EQ(segment->get_row_count(), data_n); + indexing->BuildWithDataset(raw_dataset, build_conf); + + if (nullable) { + auto vec_indexing = + dynamic_cast(indexing.get()); + ASSERT_NE(vec_indexing, nullptr); + std::unique_ptr valid_data_bool(new bool[data_n]); + for (int64_t i = 0; i < data_n; ++i) { + valid_data_bool[i] = (valid_data[i >> 3] >> (i & 0x07)) & 1; + } + vec_indexing->UpdateValidData(valid_data_bool.get(), data_n); + } + + LoadIndexInfo load_info; + load_info.field_id = vec_field_id.get(); + load_info.index_params = GenIndexParams(indexing.get()); + load_info.cache_index = + CreateTestCacheIndex("test", std::move(indexing)); + load_info.index_params["metric_type"] = metric_type; + + ASSERT_NO_THROW(segment->LoadIndex(load_info)); + EXPECT_TRUE(segment->HasIndex(vec_field_id)); + EXPECT_EQ(segment->get_row_count(), data_n); + + auto num_queries = std::min(10, (int)valid_count); + if (num_queries > 0) { + milvus::proto::plan::PlanNode plan_node; + auto vector_anns = plan_node.mutable_vector_anns(); + vector_anns->set_vector_type(DataTypeToVectorType(data_type)); + vector_anns->set_placeholder_tag("$0"); + vector_anns->set_field_id(vec_field_id.get()); + auto query_info = vector_anns->mutable_query_info(); + query_info->set_topk(topk); + query_info->set_round_decimal(3); + query_info->set_metric_type(metric_type); + query_info->set_search_params(R"({"nprobe": 16})"); + auto plan_str = plan_node.SerializeAsString(); + + auto ph_group_raw = CreatePlaceholderGroupForVectorType( + data_type, num_queries, data_d, GetQueryData(num_queries)); + + auto plan = milvus::query::CreateSearchPlanByExpr( + schema, plan_str.data(), plan_str.size()); + auto ph_group = ParsePlaceholderGroup( + plan.get(), ph_group_raw.SerializeAsString()); + + auto sr = + segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); + ASSERT_EQ(sr->total_nq_, num_queries); + EXPECT_EQ(sr->unity_topK_, topk); + + for (size_t i = 0; i < sr->seg_offsets_.size(); ++i) { + int64_t seg_offset = sr->seg_offsets_[i]; + if (seg_offset != -1) { + ASSERT_GE(seg_offset, 0); + ASSERT_LT(seg_offset, static_cast(data_n)); + if (nullable) { + bool is_valid = (valid_data[seg_offset >> 3] >> + (seg_offset & 0x07)) & + 1; + EXPECT_TRUE(is_valid) + << "Search returned invalid row at seg_offset=" + << seg_offset; + } + } + } + + VerifyQueryResults(sr->seg_offsets_); + } + } } TEST_P(BinlogIndexTest, LoadBingLogWihIDMAP) { diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index 5a4b9a22eb..820671f125 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -918,6 +918,61 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, } } +ChunkedSegmentSealedImpl::ValidResult +ChunkedSegmentSealedImpl::FilterVectorValidOffsets(milvus::OpContext* op_ctx, + FieldId field_id, + const int64_t* seg_offsets, + int64_t count) const { + ValidResult result; + result.valid_count = count; + + if (vector_indexings_.is_ready(field_id)) { + auto field_indexing = vector_indexings_.get_field_indexing(field_id); + auto cache_index = field_indexing->indexing_; + auto ca = SemiInlineGet(cache_index->PinCells(op_ctx, {0})); + auto vec_index = dynamic_cast(ca->get_cell_of(0)); + + if (vec_index != nullptr && vec_index->HasValidData()) { + result.valid_data = std::make_unique(count); + result.valid_offsets.reserve(count); + + for (int64_t i = 0; i < count; ++i) { + bool is_valid = vec_index->IsRowValid(seg_offsets[i]); + result.valid_data[i] = is_valid; + if (is_valid) { + int64_t physical_offset = + vec_index->GetPhysicalOffset(seg_offsets[i]); + if (physical_offset >= 0) { + result.valid_offsets.push_back(physical_offset); + } + } + } + result.valid_count = result.valid_offsets.size(); + } + } else { + auto column = get_column(field_id); + if (column != nullptr && column->IsNullable()) { + result.valid_data = std::make_unique(count); + result.valid_offsets.reserve(count); + + const auto& offset_mapping = column->GetOffsetMapping(); + for (int64_t i = 0; i < count; ++i) { + bool is_valid = offset_mapping.IsValid(seg_offsets[i]); + result.valid_data[i] = is_valid; + if (is_valid) { + int64_t physical_offset = + offset_mapping.GetPhysicalOffset(seg_offsets[i]); + if (physical_offset >= 0) { + result.valid_offsets.push_back(physical_offset); + } + } + } + result.valid_count = result.valid_offsets.size(); + } + } + return result; +} + std::unique_ptr ChunkedSegmentSealedImpl::get_vector(milvus::OpContext* op_ctx, FieldId field_id, @@ -945,16 +1000,29 @@ ChunkedSegmentSealedImpl::get_vector(milvus::OpContext* op_ctx, if (has_raw_data) { // If index has raw data, get vector from memory. - auto ids_ds = GenIdsDataset(count, ids); + ValidResult filter_result; + knowhere::DataSetPtr ids_ds; + int64_t valid_count = count; + const bool* valid_data = nullptr; + if (field_meta.is_nullable()) { + filter_result = + FilterVectorValidOffsets(op_ctx, field_id, ids, count); + ids_ds = GenIdsDataset(filter_result.valid_count, + filter_result.valid_offsets.data()); + valid_count = filter_result.valid_count; + valid_data = filter_result.valid_data.get(); + } else { + ids_ds = GenIdsDataset(count, ids); + } if (field_meta.get_data_type() == DataType::VECTOR_SPARSE_U32_F32) { auto res = vec_index->GetSparseVector(ids_ds); return segcore::CreateVectorDataArrayFrom( - res.get(), count, field_meta); + res.get(), valid_data, count, valid_count, field_meta); } else { // dense vector: auto vector = vec_index->GetVector(ids_ds); return segcore::CreateVectorDataArrayFrom( - vector.data(), count, field_meta); + vector.data(), valid_data, count, valid_count, field_meta); } } @@ -1529,10 +1597,13 @@ ChunkedSegmentSealedImpl::ClearData() { std::unique_ptr ChunkedSegmentSealedImpl::fill_with_empty(FieldId field_id, - int64_t count) const { + int64_t count, + int64_t valid_count, + const void* valid_data) const { auto& field_meta = schema_->operator[](field_id); if (IsVectorDataType(field_meta.get_data_type())) { - return CreateEmptyVectorDataArray(count, field_meta); + return CreateEmptyVectorDataArray( + count, valid_count, valid_data, field_meta); } return CreateEmptyScalarDataArray(count, field_meta); } @@ -1682,8 +1753,22 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx, AssertInfo(column != nullptr, "field {} must exist when getting raw data", field_id.get()); - auto ret = fill_with_empty(field_id, count); - if (column->IsNullable()) { + + int64_t valid_count = count; + const bool* valid_data = nullptr; + const int64_t* valid_offsets = seg_offsets; + ValidResult filter_result; + + if (field_meta.is_vector() && field_meta.is_nullable()) { + filter_result = + FilterVectorValidOffsets(op_ctx, field_id, seg_offsets, count); + valid_count = filter_result.valid_count; + valid_data = filter_result.valid_data.get(); + valid_offsets = filter_result.valid_offsets.data(); + } + auto ret = fill_with_empty(field_id, count, valid_count, valid_data); + + if (!field_meta.is_vector() && column->IsNullable()) { auto dst = ret->mutable_valid_data()->mutable_data(); column->BulkIsValid( op_ctx, @@ -1691,6 +1776,7 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx, seg_offsets, count); } + switch (field_meta.get_data_type()) { case DataType::VARCHAR: case DataType::STRING: @@ -1827,8 +1913,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx, bulk_subscript_impl(op_ctx, field_meta.get_sizeof(), column.get(), - seg_offsets, - count, + valid_offsets, + valid_count, ret->mutable_vectors() ->mutable_float_vector() ->mutable_data() @@ -1840,8 +1926,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx, op_ctx, field_meta.get_sizeof(), column.get(), - seg_offsets, - count, + valid_offsets, + valid_count, ret->mutable_vectors()->mutable_float16_vector()->data()); break; } @@ -1850,8 +1936,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx, op_ctx, field_meta.get_sizeof(), column.get(), - seg_offsets, - count, + valid_offsets, + valid_count, ret->mutable_vectors()->mutable_bfloat16_vector()->data()); break; } @@ -1860,8 +1946,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx, op_ctx, field_meta.get_sizeof(), column.get(), - seg_offsets, - count, + valid_offsets, + valid_count, ret->mutable_vectors()->mutable_binary_vector()->data()); break; } @@ -1870,8 +1956,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx, op_ctx, field_meta.get_sizeof(), column.get(), - seg_offsets, - count, + valid_offsets, + valid_count, ret->mutable_vectors()->mutable_int8_vector()->data()); break; } @@ -1881,7 +1967,7 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx, column->BulkValueAt( op_ctx, [&](const char* value, size_t i) mutable { - auto offset = seg_offsets[i]; + auto offset = valid_offsets[i]; auto row = offset != INVALID_SEG_OFFSET ? static_castdim()); dst->add_contents(row->data(), row->data_byte_size()); }, - seg_offsets, - count); + valid_offsets, + valid_count); dst->set_dim(max_dim); ret->mutable_vectors()->set_dim(dst->dim()); break; @@ -1905,8 +1991,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx, bulk_subscript_vector_array_impl( op_ctx, column.get(), - seg_offsets, - count, + valid_offsets, + valid_count, ret->mutable_vectors()->mutable_vector_array()->mutable_data()); break; } @@ -2267,7 +2353,12 @@ ChunkedSegmentSealedImpl::generate_interim_index(const FieldId field_id, return false; } try { - int64_t row_count = num_rows; + std::shared_ptr vec_data = get_column(field_id); + AssertInfo( + vec_data != nullptr, "vector field {} not loaded", field_id.get()); + int64_t row_count = field_meta.is_nullable() + ? vec_data->GetOffsetMapping().GetValidCount() + : num_rows; // generate index params auto field_binlog_config = std::unique_ptr( @@ -2279,9 +2370,6 @@ ChunkedSegmentSealedImpl::generate_interim_index(const FieldId field_id, if (row_count < field_binlog_config->GetBuildThreshold()) { return false; } - std::shared_ptr vec_data = get_column(field_id); - AssertInfo( - vec_data != nullptr, "vector field {} not loaded", field_id.get()); auto dim = is_sparse ? std::numeric_limits::max() : field_meta.get_dim(); auto interim_index_type = field_binlog_config->GetIndexType(); @@ -2397,6 +2485,10 @@ ChunkedSegmentSealedImpl::load_field_data_common( return; } + if (column->IsNullable() && IsVectorDataType(data_type)) { + column->BuildValidRowIds(nullptr); + } + if (!enable_mmap) { if (!is_proxy_column || is_proxy_column && @@ -2517,12 +2609,7 @@ ChunkedSegmentSealedImpl::Reopen(SchemaPtr sch) { auto absent_fields = sch->AbsentFields(*schema_); for (const auto& field_meta : *absent_fields) { - // vector field is not supported to be "added field", thus if a vector - // field is absent, it means for some reason we want to skip loading this - // field. - if (!IsVectorDataType(field_meta.get_data_type())) { - fill_empty_field(field_meta); - } + fill_empty_field(field_meta); } schema_ = sch; @@ -2597,10 +2684,6 @@ ChunkedSegmentSealedImpl::FinishLoad() { // no filling fields that index already loaded and has raw data continue; } - if (IsVectorDataType(field_meta.get_data_type())) { - // no filling vector fields - continue; - } fill_empty_field(field_meta); } } @@ -2608,10 +2691,11 @@ ChunkedSegmentSealedImpl::FinishLoad() { void ChunkedSegmentSealedImpl::fill_empty_field(const FieldMeta& field_meta) { auto field_id = field_meta.get_id(); + auto data_type = field_meta.get_data_type(); LOG_INFO( "start fill empty field {} (data type {}) for sealed segment " "{}", - field_meta.get_data_type(), + data_type, field_id.get(), id_); int64_t size = num_rows_.value(); @@ -2620,40 +2704,11 @@ ChunkedSegmentSealedImpl::fill_empty_field(const FieldMeta& field_meta) { std::unique_ptr> translator = std::make_unique( get_segment_id(), field_meta, field_data_info, false); - std::shared_ptr column{}; - switch (field_meta.get_data_type()) { - case milvus::DataType::STRING: - case milvus::DataType::VARCHAR: - case milvus::DataType::TEXT: { - column = std::make_shared>( - std::move(translator), field_meta); - break; - } - case milvus::DataType::JSON: { - column = std::make_shared>( - std::move(translator), field_meta); - break; - } - case milvus::DataType::GEOMETRY: { - column = std::make_shared>( - std::move(translator), field_meta); - break; - } - case milvus::DataType::ARRAY: { - column = std::make_shared(std::move(translator), - field_meta); - break; - } - case milvus::DataType::VECTOR_ARRAY: { - column = std::make_shared( - std::move(translator), field_meta); - break; - } - default: { - column = std::make_shared(std::move(translator), - field_meta); - break; - } + auto column = + MakeChunkedColumnBase(data_type, std::move(translator), field_meta); + + if (column->IsNullable() && IsVectorDataType(data_type)) { + column->BuildValidRowIds(nullptr); } fields_.wlock()->emplace(field_id, column); diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h index 5087483037..4b311045f0 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h @@ -880,7 +880,10 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { google::protobuf::RepeatedPtrField* dst); std::unique_ptr - fill_with_empty(FieldId field_id, int64_t count) const; + fill_with_empty(FieldId field_id, + int64_t count, + int64_t valid_count = 0, + const void* valid_data = nullptr) const; std::unique_ptr get_raw_data(milvus::OpContext* op_ctx, @@ -889,6 +892,18 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { const int64_t* seg_offsets, int64_t count) const; + struct ValidResult { + int64_t valid_count = 0; + std::unique_ptr valid_data; + std::vector valid_offsets; + }; + + ValidResult + FilterVectorValidOffsets(milvus::OpContext* op_ctx, + FieldId field_id, + const int64_t* seg_offsets, + int64_t count) const; + void update_row_count(int64_t row_count) { num_rows_ = row_count; diff --git a/internal/core/src/segcore/ConcurrentVector.h b/internal/core/src/segcore/ConcurrentVector.h index 2eaa463d95..58dc42c0d4 100644 --- a/internal/core/src/segcore/ConcurrentVector.h +++ b/internal/core/src/segcore/ConcurrentVector.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,7 @@ #include "common/FieldMeta.h" #include "common/FieldData.h" #include "common/Json.h" +#include "common/OffsetMapping.h" #include "common/Span.h" #include "common/Types.h" #include "common/Utils.h" @@ -79,19 +81,30 @@ class ThreadSafeValidData { } bool - is_valid(size_t offset) { + is_valid(size_t offset) const { std::shared_lock lck(mutex_); - Assert(offset < length_); + AssertInfo(offset < length_, + "offset out of range, offset={}, length_={}", + offset, + length_); return data_[offset]; } bool* get_chunk_data(size_t offset) { std::shared_lock lck(mutex_); - Assert(offset < length_); + AssertInfo(offset < length_, + "offset out of range, offset={}, length_={}", + offset, + length_); return &data_[offset]; } + const FixedVector& + get_data() const { + return data_; + } + private: mutable std::shared_mutex mutex_{}; FixedVector data_; @@ -155,6 +168,40 @@ class VectorBase { virtual void clear() = 0; + virtual bool + is_mapping_storage() const { + return false; + } + + // Get physical offset from logical offset. Returns -1 if not found. + virtual int64_t + get_physical_offset(int64_t logical_offset) const { + return logical_offset; // default: no mapping + } + + // Get logical offset from physical offset. Returns -1 if not found. + virtual int64_t + get_logical_offset(int64_t physical_offset) const { + return physical_offset; // default: no mapping + } + + virtual int64_t + get_valid_count() const { + return 0; + } + + virtual const FixedVector& + get_valid_data() const { + static const FixedVector empty; + return empty; + } + + virtual const OffsetMapping& + get_offset_mapping() const { + static const OffsetMapping empty; + return empty; + } + protected: const int64_t size_per_chunk_; }; @@ -191,10 +238,12 @@ class ConcurrentVectorImpl : public VectorBase { ssize_t elements_per_row, int64_t size_per_chunk, storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr, - ThreadSafeValidDataPtr valid_data_ptr = nullptr) + ThreadSafeValidDataPtr valid_data_ptr = nullptr, + bool use_mapping_storage = false) : VectorBase(size_per_chunk), elements_per_row_(is_type_entire_row ? 1 : elements_per_row), - valid_data_ptr_(valid_data_ptr) { + valid_data_ptr_(valid_data_ptr), + use_mapping_storage_(use_mapping_storage) { chunks_ptr_ = SelectChunkVectorPtr(mmap_descriptor); } @@ -221,19 +270,7 @@ class ConcurrentVectorImpl : public VectorBase { void fill_chunk_data(const std::vector& datas) override { AssertInfo(chunks_ptr_->size() == 0, "non empty concurrent vector"); - - int64_t element_count = 0; - for (auto& field_data : datas) { - element_count += field_data->get_num_rows(); - } - chunks_ptr_->emplace_to_at_least(1, elements_per_row_ * element_count); - int64_t offset = 0; - for (auto& field_data : datas) { - auto num_rows = field_data->get_num_rows(); - set_data( - offset, static_cast(field_data->Data()), num_rows); - offset += num_rows; - } + set_data_raw(0, datas); } void @@ -250,16 +287,39 @@ class ConcurrentVectorImpl : public VectorBase { set_data_raw(ssize_t element_offset, const void* source, ssize_t element_count) override { - if (element_count == 0) { - return; + ssize_t valid_count = 0; + ssize_t storage_offset = 0; + if (use_mapping_storage_) { + if constexpr (!std::is_same_v) { + storage_offset = offset_mapping_.GetNextPhysicalOffset(); + // Build valid_data array for offset mapping + std::unique_ptr valid_data(new bool[element_count]); + for (ssize_t i = 0; i < element_count; ++i) { + bool is_valid = + valid_data_ptr_->is_valid(element_offset + i); + valid_data[i] = is_valid; + if (is_valid) { + valid_count++; + } + } + offset_mapping_.BuildIncremental(valid_data.get(), + element_count, + element_offset, + storage_offset); + } + } else { + valid_count = element_count; + storage_offset = element_offset; + } + if (valid_count > 0) { + auto size = size_per_chunk_ == MAX_ROW_COUNT ? valid_count + : size_per_chunk_; + chunks_ptr_->emplace_to_at_least( + upper_div(storage_offset + valid_count, size), + elements_per_row_ * size); + set_data( + storage_offset, static_cast(source), valid_count); } - auto size = - size_per_chunk_ == MAX_ROW_COUNT ? element_count : size_per_chunk_; - chunks_ptr_->emplace_to_at_least( - upper_div(element_offset + element_count, size), - elements_per_row_ * size); - set_data( - element_offset, static_cast(source), element_count); } const void* @@ -297,8 +357,24 @@ class ConcurrentVectorImpl : public VectorBase { // just for fun, don't use it directly const Type* get_element(ssize_t element_index) const { - auto chunk_id = element_index / size_per_chunk_; - auto chunk_offset = element_index % size_per_chunk_; + auto physical_index = offset_mapping_.GetPhysicalOffset(element_index); + if (physical_index == -1) { + return nullptr; + } + auto chunk_id = physical_index / size_per_chunk_; + auto chunk_offset = physical_index % size_per_chunk_; + auto data = + static_cast(chunks_ptr_->get_chunk_data(chunk_id)); + return data + chunk_offset * elements_per_row_; + } + + const Type* + get_physical_element(ssize_t physical_index) const { + if (physical_index == -1) { + return nullptr; + } + auto chunk_id = physical_index / size_per_chunk_; + auto chunk_offset = physical_index % size_per_chunk_; auto data = static_cast(chunks_ptr_->get_chunk_data(chunk_id)); return data + chunk_offset * elements_per_row_; @@ -344,6 +420,40 @@ class ConcurrentVectorImpl : public VectorBase { return chunks_ptr_->is_mmap(); } + bool + is_mapping_storage() const override { + return use_mapping_storage_; + } + + int64_t + get_physical_offset(int64_t logical_offset) const override { + return offset_mapping_.GetPhysicalOffset(logical_offset); + } + + int64_t + get_logical_offset(int64_t physical_offset) const override { + return offset_mapping_.GetLogicalOffset(physical_offset); + } + + int64_t + get_valid_count() const override { + return offset_mapping_.GetValidCount(); + } + + const milvus::OffsetMapping& + get_offset_mapping() const override { + return offset_mapping_; + } + + const FixedVector& + get_valid_data() const override { + if (valid_data_ptr_ != nullptr) { + return valid_data_ptr_->get_data(); + } + static const FixedVector empty; + return empty; + } + private: void set_data(ssize_t element_offset, @@ -395,9 +505,10 @@ class ConcurrentVectorImpl : public VectorBase { fmt::format("chunk_id out of chunk num, chunk_id={}, chunk_num={}", chunk_id, chunk_num)); - size_t chunk_id_offset = chunk_id * size_per_chunk_ * elements_per_row_; std::optional check_data_valid = std::nullopt; - if (valid_data_ptr_ != nullptr) { + if (valid_data_ptr_ != nullptr && !use_mapping_storage_) { + size_t chunk_id_offset = + chunk_id * size_per_chunk_ * elements_per_row_; check_data_valid = [valid_data_ptr = valid_data_ptr_, beg_id = chunk_id_offset](size_t offset) { return valid_data_ptr->is_valid(beg_id + offset); @@ -414,6 +525,9 @@ class ConcurrentVectorImpl : public VectorBase { const ssize_t elements_per_row_; ChunkVectorPtr chunks_ptr_ = nullptr; ThreadSafeValidDataPtr valid_data_ptr_ = nullptr; + + const bool use_mapping_storage_; + milvus::OffsetMapping offset_mapping_; }; template @@ -496,9 +610,14 @@ class ConcurrentVector int64_t dim /* not use it*/, int64_t size_per_chunk, storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr, - ThreadSafeValidDataPtr valid_data_ptr = nullptr) + ThreadSafeValidDataPtr valid_data_ptr = nullptr, + bool use_mapping_storage = false) : ConcurrentVectorImpl::ConcurrentVectorImpl( - 1, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) { + 1, + size_per_chunk, + std::move(mmap_descriptor), + valid_data_ptr, + use_mapping_storage) { } }; @@ -510,13 +629,15 @@ class ConcurrentVector explicit ConcurrentVector( int64_t size_per_chunk, storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr, - ThreadSafeValidDataPtr valid_data_ptr = nullptr) + ThreadSafeValidDataPtr valid_data_ptr = nullptr, + bool use_mapping_storage = false) : ConcurrentVectorImpl, true>::ConcurrentVectorImpl(1, size_per_chunk, std::move( mmap_descriptor), - valid_data_ptr), + valid_data_ptr, + use_mapping_storage), dim_(0) { } @@ -527,7 +648,16 @@ class ConcurrentVector auto* src = static_cast*>( source); - for (int i = 0; i < element_count; ++i) { + ssize_t source_count = element_count; + if (this->use_mapping_storage_) { + source_count = 0; + for (ssize_t i = 0; i < element_count; ++i) { + if (this->valid_data_ptr_->is_valid(element_offset + i)) { + source_count++; + } + } + } + for (ssize_t i = 0; i < source_count; ++i) { dim_ = std::max(dim_, src[i].dim()); } ConcurrentVectorImpl, @@ -552,9 +682,14 @@ class ConcurrentVector ConcurrentVector(int64_t dim, int64_t size_per_chunk, storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr, - ThreadSafeValidDataPtr valid_data_ptr = nullptr) + ThreadSafeValidDataPtr valid_data_ptr = nullptr, + bool use_mapping_storage = false) : ConcurrentVectorImpl::ConcurrentVectorImpl( - dim, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) { + dim, + size_per_chunk, + std::move(mmap_descriptor), + valid_data_ptr, + use_mapping_storage) { } }; @@ -566,11 +701,13 @@ class ConcurrentVector int64_t dim, int64_t size_per_chunk, storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr, - ThreadSafeValidDataPtr valid_data_ptr = nullptr) + ThreadSafeValidDataPtr valid_data_ptr = nullptr, + bool use_mapping_storage = false) : ConcurrentVectorImpl(dim / 8, size_per_chunk, std::move(mmap_descriptor), - valid_data_ptr) { + valid_data_ptr, + use_mapping_storage) { AssertInfo(dim % 8 == 0, fmt::format("dim is not a multiple of 8, dim={}", dim)); } @@ -583,9 +720,14 @@ class ConcurrentVector ConcurrentVector(int64_t dim, int64_t size_per_chunk, storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr, - ThreadSafeValidDataPtr valid_data_ptr = nullptr) + ThreadSafeValidDataPtr valid_data_ptr = nullptr, + bool use_mapping_storage = false) : ConcurrentVectorImpl::ConcurrentVectorImpl( - dim, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) { + dim, + size_per_chunk, + std::move(mmap_descriptor), + valid_data_ptr, + use_mapping_storage) { } }; @@ -596,9 +738,14 @@ class ConcurrentVector ConcurrentVector(int64_t dim, int64_t size_per_chunk, storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr, - ThreadSafeValidDataPtr valid_data_ptr = nullptr) + ThreadSafeValidDataPtr valid_data_ptr = nullptr, + bool use_mapping_storage = false) : ConcurrentVectorImpl::ConcurrentVectorImpl( - dim, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) { + dim, + size_per_chunk, + std::move(mmap_descriptor), + valid_data_ptr, + use_mapping_storage) { } }; @@ -608,9 +755,14 @@ class ConcurrentVector : public ConcurrentVectorImpl { ConcurrentVector(int64_t dim, int64_t size_per_chunk, storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr, - ThreadSafeValidDataPtr valid_data_ptr = nullptr) + ThreadSafeValidDataPtr valid_data_ptr = nullptr, + bool use_mapping_storage = false) : ConcurrentVectorImpl::ConcurrentVectorImpl( - dim, size_per_chunk, std::move(mmap_descriptor)) { + dim, + size_per_chunk, + std::move(mmap_descriptor), + valid_data_ptr, + use_mapping_storage) { } }; diff --git a/internal/core/src/segcore/FieldIndexing.cpp b/internal/core/src/segcore/FieldIndexing.cpp index b0a303444e..fd229fd087 100644 --- a/internal/core/src/segcore/FieldIndexing.cpp +++ b/internal/core/src/segcore/FieldIndexing.cpp @@ -9,8 +9,10 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#include #include #include +#include #include "common/EasyAssert.h" #include "common/Types.h" @@ -19,6 +21,7 @@ #include "index/StringIndexMarisa.h" #include "common/SystemProperty.h" +#include "segcore/ConcurrentVector.h" #include "segcore/FieldIndexing.h" #include "index/VectorMemIndex.h" #include "IndexConfigGenerator.h" @@ -29,6 +32,104 @@ namespace milvus::segcore { using std::unique_ptr; +void +IndexingRecord::AppendingIndex(int64_t reserved_offset, + int64_t size, + FieldId fieldId, + const DataArray* stream_data, + const InsertRecord& record) { + if (!is_in(fieldId)) { + return; + } + auto& indexing = field_indexings_.at(fieldId); + auto type = indexing->get_data_type(); + auto field_raw_data = record.get_data_base(fieldId); + auto field_meta = schema_.get_fields().at(fieldId); + int64_t valid_count = reserved_offset + size; + if (field_meta.is_nullable() && field_raw_data->is_mapping_storage()) { + valid_count = field_raw_data->get_valid_count(); + } + if (type == DataType::VECTOR_FLOAT && + valid_count >= indexing->get_build_threshold()) { + indexing->AppendSegmentIndexDense( + reserved_offset, + size, + field_raw_data, + stream_data->vectors().float_vector().data().data()); + } else if (type == DataType::VECTOR_FLOAT16 && + valid_count >= indexing->get_build_threshold()) { + indexing->AppendSegmentIndexDense( + reserved_offset, + size, + field_raw_data, + stream_data->vectors().float16_vector().data()); + } else if (type == DataType::VECTOR_BFLOAT16 && + valid_count >= indexing->get_build_threshold()) { + indexing->AppendSegmentIndexDense( + reserved_offset, + size, + field_raw_data, + stream_data->vectors().bfloat16_vector().data()); + } else if (type == DataType::VECTOR_SPARSE_U32_F32 && + valid_count >= indexing->get_build_threshold()) { + auto data = SparseBytesToRows( + stream_data->vectors().sparse_float_vector().contents()); + indexing->AppendSegmentIndexSparse( + reserved_offset, + size, + stream_data->vectors().sparse_float_vector().dim(), + field_raw_data, + data.get()); + } else if (type == DataType::GEOMETRY) { + // For geometry fields, append data incrementally to RTree index + indexing->AppendSegmentIndex( + reserved_offset, size, field_raw_data, stream_data); + } +} + +// concurrent, reentrant +void +IndexingRecord::AppendingIndex(int64_t reserved_offset, + int64_t size, + FieldId fieldId, + const FieldDataPtr data, + const InsertRecord& record) { + if (!is_in(fieldId)) { + return; + } + auto& indexing = field_indexings_.at(fieldId); + auto type = indexing->get_data_type(); + const void* p = data->Data(); + auto vec_base = record.get_data_base(fieldId); + auto field_meta = schema_.get_fields().at(fieldId); + int64_t valid_count = reserved_offset + size; + if (field_meta.is_nullable() && vec_base->is_mapping_storage()) { + valid_count = vec_base->get_valid_count(); + } + + if ((type == DataType::VECTOR_FLOAT || type == DataType::VECTOR_FLOAT16 || + type == DataType::VECTOR_BFLOAT16) && + valid_count >= indexing->get_build_threshold()) { + auto vec_base = record.get_data_base(fieldId); + indexing->AppendSegmentIndexDense( + reserved_offset, size, vec_base, data->Data()); + } else if (type == DataType::VECTOR_SPARSE_U32_F32 && + valid_count >= indexing->get_build_threshold()) { + auto vec_base = record.get_data_base(fieldId); + indexing->AppendSegmentIndexSparse( + reserved_offset, + size, + std::dynamic_pointer_cast>(data) + ->Dim(), + vec_base, + p); + } else if (type == DataType::GEOMETRY) { + // For geometry fields, append data incrementally to RTree index + auto vec_base = record.get_data_base(fieldId); + indexing->AppendSegmentIndex(reserved_offset, size, vec_base, data); + } +} + VectorFieldIndexing::VectorFieldIndexing(const FieldMeta& field_meta, const FieldIndexMeta& field_index_meta, int64_t segment_max_row_count, @@ -140,54 +241,133 @@ VectorFieldIndexing::AppendSegmentIndexSparse(int64_t reserved_offset, int64_t new_data_dim, const VectorBase* field_raw_data, const void* data_source) { + using value_type = knowhere::sparse::SparseRow; + AssertInfo(get_data_type() == DataType::VECTOR_SPARSE_U32_F32, + "Data type of vector field is not VECTOR_SPARSE_U32_F32"); + auto conf = get_build_params(get_data_type()); - auto source = dynamic_cast*>( - field_raw_data); - AssertInfo(source, + auto field_source = + dynamic_cast*>( + field_raw_data); + AssertInfo(field_source, "field_raw_data can't cast to " "ConcurrentVector type"); - AssertInfo(size > 0, "append 0 sparse rows to index is not allowed"); - if (!built_) { - AssertInfo(!sync_with_index_, "index marked synced before built"); - idx_t total_rows = reserved_offset + size; - idx_t chunk_id = 0; - auto dim = source->Dim(); + auto source = static_cast(data_source); - while (total_rows > 0) { - auto mat = static_cast< - const knowhere::sparse::SparseRow*>( - source->get_chunk_data(chunk_id)); - auto rows = std::min(source->get_size_per_chunk(), total_rows); - auto dataset = knowhere::GenDataSet(rows, dim, mat); - dataset->SetIsSparse(true); - try { - if (chunk_id == 0) { - index_->BuildWithDataset(dataset, conf); - } else { - index_->AddWithDataset(dataset, conf); + auto dim = new_data_dim; + auto size_per_chunk = field_raw_data->get_size_per_chunk(); + auto build_threshold = get_build_threshold(); + bool is_mapping_storage = field_raw_data->is_mapping_storage(); + auto& valid_data = field_raw_data->get_valid_data(); + + if (!built_) { + const void* data_ptr = nullptr; + std::vector data_buf; + + int64_t start_chunk = 0; + int64_t end_chunk = (build_threshold - 1) / size_per_chunk; + + if (start_chunk == end_chunk) { + data_ptr = field_raw_data->get_chunk_data(start_chunk); + } else { + data_buf.resize(build_threshold); + int64_t actual_copy_count = 0; + for (int64_t chunk_id = start_chunk; chunk_id <= end_chunk; + ++chunk_id) { + int64_t copy_start = + std::max((int64_t)0, chunk_id * size_per_chunk); + int64_t copy_end = + std::min(build_threshold, (chunk_id + 1) * size_per_chunk); + int64_t copy_count = copy_end - copy_start; + // For mapping storage, chunk data is already compactly stored, + // so we can copy directly from chunk + auto chunk_data = static_cast( + field_raw_data->get_chunk_data(chunk_id)); + int64_t chunk_offset = copy_start - chunk_id * size_per_chunk; + for (int64_t i = 0; i < copy_count; ++i) { + data_buf[actual_copy_count + i] = + chunk_data[chunk_offset + i]; } - } catch (SegcoreError& error) { - LOG_ERROR("growing sparse index build error: {}", error.what()); - recreate_index(get_data_type(), nullptr); - index_cur_ = 0; - return; + actual_copy_count += copy_count; } - index_cur_.fetch_add(rows); - total_rows -= rows; - chunk_id++; + data_ptr = data_buf.data(); + } + + auto dataset = knowhere::GenDataSet(build_threshold, dim, data_ptr); + dataset->SetIsSparse(true); + try { + index_->BuildWithDataset(dataset, conf); + if (is_mapping_storage) { + auto logical_offset = + field_raw_data->get_logical_offset(build_threshold - 1); + auto update_count = logical_offset + 1; + index_->UpdateValidData(valid_data.data(), update_count); + } + built_ = true; + index_cur_.fetch_add(build_threshold); + } catch (SegcoreError& error) { + LOG_ERROR("growing sparse index build error: {}", error.what()); + recreate_index(get_data_type(), field_raw_data); + return; } - built_ = true; - sync_with_index_ = true; - // if not built_, new rows in data_source have already been added to - // source(ConcurrentVector) and thus added to the - // index, thus no need to add again. - return; } - auto dataset = knowhere::GenDataSet(size, new_data_dim, data_source); - dataset->SetIsSparse(true); - index_->AddWithDataset(dataset, conf); - index_cur_.fetch_add(size); + // Append rest data when index has been built + int64_t add_count = 0; + int64_t total_count = 0; + if (valid_data.empty()) { + // Non-nullable case: add all rows + add_count = reserved_offset + size - index_cur_.load(); + total_count = size; + if (add_count <= 0) { + sync_with_index_.store(true); + return; + } + auto data_ptr = source + (total_count - add_count); + auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr); + dataset->SetIsSparse(true); + try { + index_->AddWithDataset(dataset, conf); + index_cur_.fetch_add(add_count); + sync_with_index_.store(true); + } catch (SegcoreError& error) { + LOG_ERROR("growing sparse index add error: {}", error.what()); + recreate_index(get_data_type(), field_raw_data); + } + } else { + // Nullable case: only add valid rows (matching dense vector approach) + auto index_total_count = index_->GetOffsetMapping().GetTotalCount(); + auto add_valid_data_count = reserved_offset + size - index_total_count; + for (auto i = reserved_offset; i < reserved_offset + size; i++) { + if (valid_data[i]) { + total_count++; + if (i >= index_total_count) { + add_count++; + } + } + } + if (add_count <= 0 && add_valid_data_count <= 0) { + sync_with_index_.store(true); + return; + } + if (add_count > 0) { + auto data_ptr = source + (total_count - add_count); + auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr); + dataset->SetIsSparse(true); + try { + index_->AddWithDataset(dataset, conf); + } catch (SegcoreError& error) { + LOG_ERROR("growing sparse index add error: {}", error.what()); + recreate_index(get_data_type(), field_raw_data); + } + } + if (add_valid_data_count > 0) { + index_->UpdateValidData(valid_data.data() + index_total_count, + add_valid_data_count); + } + index_cur_.fetch_add(add_count); + sync_with_index_.store(true); + } } void @@ -203,8 +383,10 @@ VectorFieldIndexing::AppendSegmentIndexDense(int64_t reserved_offset, auto dim = get_dim(); auto conf = get_build_params(get_data_type()); auto size_per_chunk = field_raw_data->get_size_per_chunk(); - //append vector [vector_id_beg, vector_id_end] into index - //build index [vector_id_beg, build_threshold) when index not exist + auto build_threshold = get_build_threshold(); + bool is_mapping_storage = field_raw_data->is_mapping_storage(); + auto& valid_data = field_raw_data->get_valid_data(); + AssertInfo(ConcurrentDenseVectorCheck(field_raw_data, get_data_type()), "vec_base can't cast to ConcurrentVector type"); size_t vec_length; @@ -216,88 +398,112 @@ VectorFieldIndexing::AppendSegmentIndexDense(int64_t reserved_offset, vec_length = dim * sizeof(bfloat16); } if (!built_) { - idx_t vector_id_beg = index_cur_.load(); - Assert(vector_id_beg == 0); - idx_t vector_id_end = get_build_threshold() - 1; - auto chunk_id_beg = vector_id_beg / size_per_chunk; - auto chunk_id_end = vector_id_end / size_per_chunk; + const void* data_ptr; + std::unique_ptr data_buf; + // Chunk data stores valid vectors compactly for both nullable and non-nullable + int64_t start_chunk = 0; + int64_t end_chunk = (build_threshold - 1) / size_per_chunk; - int64_t vec_num = vector_id_end - vector_id_beg + 1; - // for train index - const void* data_addr; - unique_ptr vec_data; - //all train data in one chunk - if (chunk_id_beg == chunk_id_end) { - data_addr = field_raw_data->get_chunk_data(chunk_id_beg); + if (start_chunk == end_chunk) { + auto chunk_data = static_cast( + field_raw_data->get_chunk_data(start_chunk)); + data_ptr = chunk_data; } else { - //merge data from multiple chunks together - vec_data = std::make_unique(vec_num * vec_length); - int64_t offset = 0; - //copy vector data [vector_id_beg, vector_id_end] - for (int chunk_id = chunk_id_beg; chunk_id <= chunk_id_end; - chunk_id++) { - int chunk_offset = 0; - int chunk_copysz = - chunk_id == chunk_id_end - ? vector_id_end - chunk_id * size_per_chunk + 1 - : size_per_chunk; - std::memcpy( - (void*)((const char*)vec_data.get() + offset * vec_length), - (void*)((const char*)field_raw_data->get_chunk_data( - chunk_id) + - chunk_offset * vec_length), - chunk_copysz * vec_length); - offset += chunk_copysz; + data_buf = std::make_unique(build_threshold * vec_length); + int64_t actual_copy_count = 0; + for (int64_t chunk_id = start_chunk; chunk_id <= end_chunk; + ++chunk_id) { + auto chunk_data = static_cast( + field_raw_data->get_chunk_data(chunk_id)); + int64_t copy_start = + std::max((int64_t)0, chunk_id * size_per_chunk); + int64_t copy_end = + std::min(build_threshold, (chunk_id + 1) * size_per_chunk); + int64_t copy_count = copy_end - copy_start; + auto src = + chunk_data + + (copy_start - chunk_id * size_per_chunk) * vec_length; + std::memcpy(data_buf.get() + actual_copy_count * vec_length, + src, + copy_count * vec_length); + actual_copy_count += copy_count; } - data_addr = vec_data.get(); + data_ptr = data_buf.get(); } - auto dataset = knowhere::GenDataSet(vec_num, dim, data_addr); - dataset->SetIsOwner(false); + + auto dataset = knowhere::GenDataSet(build_threshold, dim, data_ptr); try { index_->BuildWithDataset(dataset, conf); + if (is_mapping_storage) { + auto logical_offset = + field_raw_data->get_logical_offset(build_threshold - 1); + auto update_count = logical_offset + 1; + index_->UpdateValidData(valid_data.data(), update_count); + } + built_ = true; + index_cur_.fetch_add(build_threshold); } catch (SegcoreError& error) { LOG_ERROR("growing index build error: {}", error.what()); recreate_index(get_data_type(), field_raw_data); return; } - index_cur_.fetch_add(vec_num); - built_ = true; } //append rest data when index has built - idx_t vector_id_beg = index_cur_.load(); - idx_t vector_id_end = reserved_offset + size - 1; - auto chunk_id_beg = vector_id_beg / size_per_chunk; - auto chunk_id_end = vector_id_end / size_per_chunk; - int64_t vec_num = vector_id_end - vector_id_beg + 1; - - if (vec_num <= 0) { - sync_with_index_.store(true); - return; - } - - if (sync_with_index_.load()) { - Assert(size == vec_num); - auto dataset = knowhere::GenDataSet(vec_num, dim, data_source); - index_->AddWithDataset(dataset, conf); - index_cur_.fetch_add(vec_num); - } else { - for (int chunk_id = chunk_id_beg; chunk_id <= chunk_id_end; - chunk_id++) { - int chunk_offset = chunk_id == chunk_id_beg - ? index_cur_ - chunk_id * size_per_chunk - : 0; - int chunk_sz = - chunk_id == chunk_id_end - ? vector_id_end % size_per_chunk - chunk_offset + 1 - : size_per_chunk - chunk_offset; - auto dataset = knowhere::GenDataSet( - chunk_sz, - dim, - (const char*)field_raw_data->get_chunk_data(chunk_id) + - chunk_offset * vec_length); - index_->AddWithDataset(dataset, conf); - index_cur_.fetch_add(chunk_sz); + int64_t add_count = 0; + int64_t total_count = 0; + if (valid_data.empty()) { + add_count = reserved_offset + size - index_cur_.load(); + total_count = size; + if (add_count <= 0) { + sync_with_index_.store(true); + return; } + auto data_ptr = static_cast(data_source) + + (total_count - add_count) * vec_length; + auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr); + try { + index_->AddWithDataset(dataset, conf); + index_cur_.fetch_add(add_count); + sync_with_index_.store(true); + } catch (SegcoreError& error) { + LOG_ERROR("growing index add error: {}", error.what()); + recreate_index(get_data_type(), field_raw_data); + } + } else { + // Nullable dense vectors: data_source (proto) contains valid vectors compactly + auto index_total_count = index_->GetOffsetMapping().GetTotalCount(); + auto add_valid_data_count = reserved_offset + size - index_total_count; + auto index_cur_val = index_cur_.load(); + // Count valid vectors in this batch range + for (auto i = reserved_offset; i < reserved_offset + size; i++) { + if (valid_data[i]) { + total_count++; + if (i >= index_total_count) { + add_count++; + } + } + } + if (add_count <= 0 && add_valid_data_count <= 0) { + sync_with_index_.store(true); + return; + } + if (add_count > 0) { + // data_source contains valid vectors compactly, skip already indexed ones + auto data_ptr = static_cast(data_source) + + (total_count - add_count) * vec_length; + auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr); + try { + index_->AddWithDataset(dataset, conf); + } catch (SegcoreError& error) { + LOG_ERROR("growing index add error: {}", error.what()); + recreate_index(get_data_type(), field_raw_data); + } + } + if (add_valid_data_count > 0) { + index_->UpdateValidData(valid_data.data() + index_total_count, + add_valid_data_count); + } + index_cur_.fetch_add(add_count); sync_with_index_.store(true); } } diff --git a/internal/core/src/segcore/FieldIndexing.h b/internal/core/src/segcore/FieldIndexing.h index cda0cd1bcb..65123449ab 100644 --- a/internal/core/src/segcore/FieldIndexing.h +++ b/internal/core/src/segcore/FieldIndexing.h @@ -434,93 +434,19 @@ class IndexingRecord { assert(offset_id == schema_.size()); } - // concurrent, reentrant void AppendingIndex(int64_t reserved_offset, int64_t size, FieldId fieldId, const DataArray* stream_data, - const InsertRecord& record) { - if (!is_in(fieldId)) { - return; - } - auto& indexing = field_indexings_.at(fieldId); - auto type = indexing->get_data_type(); - auto field_raw_data = record.get_data_base(fieldId); - if (type == DataType::VECTOR_FLOAT && - reserved_offset + size >= indexing->get_build_threshold()) { - indexing->AppendSegmentIndexDense( - reserved_offset, - size, - field_raw_data, - stream_data->vectors().float_vector().data().data()); - } else if (type == DataType::VECTOR_FLOAT16 && - reserved_offset + size >= indexing->get_build_threshold()) { - indexing->AppendSegmentIndexDense( - reserved_offset, - size, - field_raw_data, - stream_data->vectors().float16_vector().data()); - } else if (type == DataType::VECTOR_BFLOAT16 && - reserved_offset + size >= indexing->get_build_threshold()) { - indexing->AppendSegmentIndexDense( - reserved_offset, - size, - field_raw_data, - stream_data->vectors().bfloat16_vector().data()); - } else if (type == DataType::VECTOR_SPARSE_U32_F32) { - auto data = SparseBytesToRows( - stream_data->vectors().sparse_float_vector().contents()); - indexing->AppendSegmentIndexSparse( - reserved_offset, - size, - stream_data->vectors().sparse_float_vector().dim(), - field_raw_data, - data.get()); - } else if (type == DataType::GEOMETRY) { - // For geometry fields, append data incrementally to RTree index - indexing->AppendSegmentIndex( - reserved_offset, size, field_raw_data, stream_data); - } - } + const InsertRecord& record); - // concurrent, reentrant void AppendingIndex(int64_t reserved_offset, int64_t size, FieldId fieldId, const FieldDataPtr data, - const InsertRecord& record) { - if (!is_in(fieldId)) { - return; - } - auto& indexing = field_indexings_.at(fieldId); - auto type = indexing->get_data_type(); - const void* p = data->Data(); - - if ((type == DataType::VECTOR_FLOAT || - type == DataType::VECTOR_FLOAT16 || - type == DataType::VECTOR_BFLOAT16) && - reserved_offset + size >= indexing->get_build_threshold()) { - auto vec_base = record.get_data_base(fieldId); - indexing->AppendSegmentIndexDense( - reserved_offset, size, vec_base, data->Data()); - } else if (type == DataType::VECTOR_SPARSE_U32_F32) { - auto vec_base = record.get_data_base(fieldId); - indexing->AppendSegmentIndexSparse( - reserved_offset, - size, - std::dynamic_pointer_cast>( - data) - ->Dim(), - vec_base, - p); - } else if (type == DataType::GEOMETRY) { - // For geometry fields, append data incrementally to RTree index - auto vec_base = record.get_data_base(fieldId); - indexing->AppendSegmentIndex(reserved_offset, size, vec_base, data); - } - } + const InsertRecord& record); // for sparse float vector: // * element_size is not used diff --git a/internal/core/src/segcore/IndexConfigGenerator.cpp b/internal/core/src/segcore/IndexConfigGenerator.cpp index 8ee7416cb6..426e069854 100644 --- a/internal/core/src/segcore/IndexConfigGenerator.cpp +++ b/internal/core/src/segcore/IndexConfigGenerator.cpp @@ -87,12 +87,6 @@ VecIndexConfig::VecIndexConfig(const int64_t max_index_row_cout, int64_t VecIndexConfig::GetBuildThreshold() const noexcept { - // For sparse, do not impose a threshold and start using index with any - // number of rows. Unlike dense vector index, growing sparse vector index - // does not require a minimum number of rows to train. - if (is_sparse_) { - return 0; - } auto ratio = config_.get_build_ratio(); assert(ratio >= 0.0 && ratio < 1.0); return std::max(int64_t(max_index_row_count_ * ratio), diff --git a/internal/core/src/segcore/InsertRecord.h b/internal/core/src/segcore/InsertRecord.h index dbb7b6c127..20f055ddc3 100644 --- a/internal/core/src/segcore/InsertRecord.h +++ b/internal/core/src/segcore/InsertRecord.h @@ -1025,7 +1025,6 @@ class InsertRecordGrowing { } // append a column of vector type - // vector not support nullable, not pass valid data ptr template void append_data(FieldId field_id, @@ -1033,9 +1032,15 @@ class InsertRecordGrowing { int64_t size_per_chunk, const storage::MmapChunkDescriptorPtr mmap_descriptor) { static_assert(std::is_base_of_v); - data_.emplace(field_id, - std::make_unique>( - dim, size_per_chunk, mmap_descriptor)); + bool use_mapping_storage = is_valid_data_exist(field_id); + data_.emplace( + field_id, + std::make_unique>( + dim, + size_per_chunk, + mmap_descriptor, + use_mapping_storage ? get_valid_data(field_id) : nullptr, + use_mapping_storage)); } // append a column of scalar or sparse float vector type @@ -1045,13 +1050,23 @@ class InsertRecordGrowing { int64_t size_per_chunk, const storage::MmapChunkDescriptorPtr mmap_descriptor) { static_assert(IsScalar || IsSparse); - data_.emplace( - field_id, - std::make_unique>( - size_per_chunk, - mmap_descriptor, - is_valid_data_exist(field_id) ? get_valid_data(field_id) - : nullptr)); + bool use_mapping_storage = is_valid_data_exist(field_id); + if constexpr (IsSparse) { + data_.emplace( + field_id, + std::make_unique>( + size_per_chunk, + mmap_descriptor, + use_mapping_storage ? get_valid_data(field_id) : nullptr, + use_mapping_storage)); + } else { + data_.emplace( + field_id, + std::make_unique>( + size_per_chunk, + mmap_descriptor, + use_mapping_storage ? get_valid_data(field_id) : nullptr)); + } } void diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index 811b1e1630..38f71440fc 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -312,13 +312,13 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset, AssertInfo(field_id_to_offset.count(field_id), fmt::format("can't find field {}", field_id.get())); auto data_offset = field_id_to_offset[field_id]; + if (field_meta.is_nullable()) { + insert_record_.get_valid_data(field_id)->set_data_raw( + num_rows, + &insert_record_proto->fields_data(data_offset), + field_meta); + } if (!indexing_record_.HasRawData(field_id)) { - if (field_meta.is_nullable()) { - insert_record_.get_valid_data(field_id)->set_data_raw( - num_rows, - &insert_record_proto->fields_data(data_offset), - field_meta); - } insert_record_.get_data_base(field_id)->set_data_raw( reserved_offset, num_rows, @@ -937,14 +937,31 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx, auto& field_meta = schema_->operator[](field_id); auto vec_ptr = insert_record_.get_data_base(field_id); if (field_meta.is_vector()) { - auto result = CreateEmptyVectorDataArray(count, field_meta); + int64_t valid_count = count; + const bool* valid_data = nullptr; + const int64_t* valid_offsets = seg_offsets; + ValidResult filter_result; + + if (field_meta.is_nullable()) { + filter_result = + FilterVectorValidOffsets(op_ctx, field_id, seg_offsets, count); + valid_count = filter_result.valid_count; + valid_data = filter_result.valid_data.get(); + valid_offsets = filter_result.valid_offsets.data(); + } + + auto result = CreateEmptyVectorDataArray( + count, valid_count, valid_data, field_meta); + if (valid_count == 0) { + return result; + } if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) { bulk_subscript_impl(op_ctx, field_id, field_meta.get_sizeof(), vec_ptr, - seg_offsets, - count, + valid_offsets, + valid_count, result->mutable_vectors() ->mutable_float_vector() ->mutable_data() @@ -955,8 +972,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx, field_id, field_meta.get_sizeof(), vec_ptr, - seg_offsets, - count, + valid_offsets, + valid_count, result->mutable_vectors()->mutable_binary_vector()->data()); } else if (field_meta.get_data_type() == DataType::VECTOR_FLOAT16) { bulk_subscript_impl( @@ -964,8 +981,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx, field_id, field_meta.get_sizeof(), vec_ptr, - seg_offsets, - count, + valid_offsets, + valid_count, result->mutable_vectors()->mutable_float16_vector()->data()); } else if (field_meta.get_data_type() == DataType::VECTOR_BFLOAT16) { bulk_subscript_impl( @@ -973,8 +990,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx, field_id, field_meta.get_sizeof(), vec_ptr, - seg_offsets, - count, + valid_offsets, + valid_count, result->mutable_vectors()->mutable_bfloat16_vector()->data()); } else if (field_meta.get_data_type() == DataType::VECTOR_SPARSE_U32_F32) { @@ -982,8 +999,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx, op_ctx, field_id, (const ConcurrentVector*)vec_ptr, - seg_offsets, - count, + valid_offsets, + valid_count, result->mutable_vectors()->mutable_sparse_float_vector()); result->mutable_vectors()->set_dim( result->vectors().sparse_float_vector().dim()); @@ -993,8 +1010,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx, field_id, field_meta.get_sizeof(), vec_ptr, - seg_offsets, - count, + valid_offsets, + valid_count, result->mutable_vectors()->mutable_int8_vector()->data()); } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { bulk_subscript_vector_array_impl(op_ctx, @@ -1190,7 +1207,7 @@ SegmentGrowingImpl::bulk_subscript_sparse_float_vector_impl( [&](size_t i) { auto offset = seg_offsets[i]; return offset != INVALID_SEG_OFFSET - ? vec_raw->get_element(offset) + ? vec_raw->get_physical_element(offset) : nullptr; }, count, @@ -1257,12 +1274,8 @@ SegmentGrowingImpl::bulk_subscript_impl(milvus::OpContext* op_ctx, for (int i = 0; i < count; ++i) { auto dst = output_base + i * element_sizeof; auto offset = seg_offsets[i]; - if (offset == INVALID_SEG_OFFSET) { - memset(dst, 0, element_sizeof); - } else { - auto src = (const uint8_t*)vec.get_element(offset); - memcpy(dst, src, element_sizeof); - } + auto src = (const uint8_t*)vec.get_physical_element(offset); + memcpy(dst, src, element_sizeof); } return; } @@ -1860,4 +1873,68 @@ SegmentGrowingImpl::BuildGeometryCacheForLoad( } } +SegmentGrowingImpl::ValidResult +SegmentGrowingImpl::FilterVectorValidOffsets(milvus::OpContext* op_ctx, + FieldId field_id, + const int64_t* seg_offsets, + int64_t count) const { + ValidResult result; + result.valid_count = count; + + if (indexing_record_.SyncDataWithIndex(field_id)) { + const auto& field_indexing = + indexing_record_.get_vec_field_indexing(field_id); + auto indexing = field_indexing.get_segment_indexing(); + auto vec_index = dynamic_cast(indexing.get()); + + if (vec_index != nullptr && vec_index->HasValidData()) { + result.valid_data = std::make_unique(count); + result.valid_offsets.reserve(count); + + for (int64_t i = 0; i < count; ++i) { + bool is_valid = vec_index->IsRowValid(seg_offsets[i]); + result.valid_data[i] = is_valid; + if (is_valid) { + int64_t physical_offset = + vec_index->GetPhysicalOffset(seg_offsets[i]); + if (physical_offset >= 0) { + result.valid_offsets.push_back(physical_offset); + } + } + } + result.valid_count = result.valid_offsets.size(); + } + } else { + auto vec_base = insert_record_.get_data_base(field_id); + if (vec_base != nullptr) { + const auto& valid_data_vec = vec_base->get_valid_data(); + bool is_mapping_storage = vec_base->is_mapping_storage(); + if (!valid_data_vec.empty()) { + result.valid_data = std::make_unique(count); + result.valid_offsets.reserve(count); + + for (int64_t i = 0; i < count; ++i) { + auto offset = seg_offsets[i]; + bool is_valid = + offset >= 0 && + offset < static_cast(valid_data_vec.size()) && + valid_data_vec[offset]; + result.valid_data[i] = is_valid; + if (is_valid) { + if (is_mapping_storage) { + int64_t physical_offset = + vec_base->get_physical_offset(offset); + if (physical_offset >= 0) { + result.valid_offsets.push_back(physical_offset); + } + } + } + } + result.valid_count = result.valid_offsets.size(); + } + } + } + return result; +} + } // namespace milvus::segcore diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index 79ad1c41fd..7a28a99ea9 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -504,6 +504,17 @@ class SegmentGrowingImpl : public SegmentGrowing { } return nullptr; } + struct ValidResult { + int64_t valid_count = 0; + std::unique_ptr valid_data; + std::vector valid_offsets; + }; + + ValidResult + FilterVectorValidOffsets(milvus::OpContext* op_ctx, + FieldId field_id, + const int64_t* seg_offsets, + int64_t count) const; protected: int64_t diff --git a/internal/core/src/segcore/SegmentGrowingIndexTest.cpp b/internal/core/src/segcore/SegmentGrowingIndexTest.cpp index f531f0fe77..5010690fd2 100644 --- a/internal/core/src/segcore/SegmentGrowingIndexTest.cpp +++ b/internal/core/src/segcore/SegmentGrowingIndexTest.cpp @@ -280,11 +280,10 @@ TEST_P(GrowingIndexTest, Correctness) { auto inserted = (i + 1) * per_batch; // once index built, chunk data will be removed. // growing index will only be built when num rows reached - // get_build_threshold(). This value for sparse is 0, thus sparse index - // will be built since the first chunk. Dense segment buffers the first + // get_build_threshold(). Both sparse and dense segment buffer the first // 2 chunks before building an index in this test case. - if ((!is_sparse && i < 2) || !intermin_index_with_raw_data) { + if (i < 2 || !intermin_index_with_raw_data) { EXPECT_EQ(field_data->num_chunk(), upper_div(inserted, field_data->get_size_per_chunk())); } else { diff --git a/internal/core/src/segcore/SegmentGrowingTest.cpp b/internal/core/src/segcore/SegmentGrowingTest.cpp index 903a1bb5e9..081fa61a6b 100644 --- a/internal/core/src/segcore/SegmentGrowingTest.cpp +++ b/internal/core/src/segcore/SegmentGrowingTest.cpp @@ -12,12 +12,18 @@ #include #include "common/Types.h" +#include "common/IndexMeta.h" #include "knowhere/comp/index_param.h" #include "segcore/SegmentGrowing.h" #include "segcore/SegmentGrowingImpl.h" #include "pb/schema.pb.h" +#include "pb/plan.pb.h" +#include "query/Plan.h" +#include "expr/ITypeExpr.h" +#include "plan/PlanNode.h" #include "test_utils/DataGen.h" #include "test_utils/storage_test_utils.h" +#include "test_utils/GenExprProto.h" using namespace milvus::segcore; using namespace milvus; @@ -435,6 +441,325 @@ TEST(Growing, FillNullableData) { } } +class GrowingNullableTest : public ::testing::TestWithParam< + std::tuple> { + public: + void + SetUp() override { + std::tie(data_type, + metric_type, + index_type, + null_percent, + enable_interim_index) = GetParam(); + } + + DataType data_type; + knowhere::MetricType metric_type; + std::string index_type; + int null_percent; + bool enable_interim_index; +}; + +static std::vector< + std::tuple> +GenerateGrowingNullableTestParams() { + std::vector< + std::tuple> + params; + + // Dense float vectors with IVF_FLAT + std::vector> + base_configs = { + {DataType::VECTOR_FLOAT, + knowhere::metric::L2, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT}, + {DataType::VECTOR_FLOAT, + knowhere::metric::IP, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT}, + {DataType::VECTOR_FLOAT, + knowhere::metric::COSINE, + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT}, + {DataType::VECTOR_SPARSE_U32_F32, + knowhere::metric::IP, + knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX}, + }; + + std::vector null_percents = {0, 20, 100}; + + std::vector interim_index_configs = {true, false}; + + for (const auto& [dtype, metric, idx_type] : base_configs) { + for (int null_pct : null_percents) { + for (bool enable_interim : interim_index_configs) { + params.push_back( + {dtype, metric, idx_type, null_pct, enable_interim}); + } + } + } + return params; +} + +INSTANTIATE_TEST_SUITE_P( + NullableVectorParameters, + GrowingNullableTest, + ::testing::ValuesIn(GenerateGrowingNullableTestParams())); + +TEST_P(GrowingNullableTest, SearchAndQueryNullableVectors) { + using namespace milvus::query; + + bool nullable = true; + + auto schema = std::make_shared(); + auto int64_field = schema->AddDebugField("int64", DataType::INT64); + int64_t dim = 8; + auto vec = schema->AddDebugField( + "embeddings", data_type, dim, metric_type, nullable); + schema->set_primary_field_id(int64_field); + + std::map index_params; + std::map type_params; + if (data_type == DataType::VECTOR_SPARSE_U32_F32) { + index_params = {{"index_type", index_type}, + {"metric_type", metric_type}}; + type_params = {}; + } else { + index_params = {{"index_type", index_type}, + {"metric_type", metric_type}, + {"nlist", "128"}}; + type_params = {{"dim", std::to_string(dim)}}; + } + FieldIndexMeta fieldIndexMeta( + vec, std::move(index_params), std::move(type_params)); + auto config = SegcoreConfig::default_config(); + config.set_chunk_rows(1024); + config.set_enable_interim_segment_index(enable_interim_index); + // Explicitly set interim index type to avoid contamination from other tests + config.set_dense_vector_intermin_index_type( + knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC); + std::map filedMap = {{vec, fieldIndexMeta}}; + IndexMetaPtr metaPtr = + std::make_shared(100000, std::move(filedMap)); + auto segment_growing = CreateGrowingSegment(schema, metaPtr, 1, config); + auto segment = dynamic_cast(segment_growing.get()); + + int64_t batch_size = 2000; + int64_t num_rounds = 10; + int64_t topk = 5; + int64_t num_queries = 2; + Timestamp timestamp = 10000000; + + // Prepare search plan + std::string search_params_fmt; + if (data_type == DataType::VECTOR_SPARSE_U32_F32) { + search_params_fmt = R"( + vector_anns:< + field_id: {} + query_info:< + topk: {} + round_decimal: 3 + metric_type: "{}" + search_params: "{{\"drop_ratio_search\": 0.1}}" + > + placeholder_tag: "$0" + > + )"; + } else { + search_params_fmt = R"( + vector_anns:< + field_id: {} + query_info:< + topk: {} + round_decimal: 3 + metric_type: "{}" + search_params: "{{\"nprobe\": 10}}" + > + placeholder_tag: "$0" + > + )"; + } + + auto raw_plan = + fmt::format(search_params_fmt, vec.get(), topk, metric_type); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + auto plan = + CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size()); + + // Create query vectors + proto::common::PlaceholderGroup ph_group_raw; + if (data_type == DataType::VECTOR_SPARSE_U32_F32) { + ph_group_raw = CreateSparseFloatPlaceholderGroup(num_queries, 42); + } else { + auto query_data = generate_float_vector(num_queries, dim); + ph_group_raw = + CreatePlaceholderGroupFromBlob(num_queries, dim, query_data.data()); + } + + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + + // Store all inserted data for verification + // For nullable vectors, data is stored sparsely (only valid vectors) + // We need a mapping from logical offset to physical offset + std::vector all_float_vectors; // Physical storage (only valid) + std::vector> all_sparse_vectors; + std::vector all_valid_data; // Logical storage (all rows) + std::vector + logical_to_physical; // Maps logical offset to physical + + // Insert data in multiple rounds and test after each round + for (int64_t round = 0; round < num_rounds; round++) { + int64_t total_rows = (round + 1) * batch_size; + int64_t expected_valid_count = + total_rows - (total_rows * null_percent / 100); + + auto dataset = DataGen(schema, + batch_size, + 42 + round, + 0, + 1, + 10, + 1, + false, + true, + false, + null_percent); + + // Build logical to physical mapping for this batch + int64_t base_physical = all_float_vectors.size() / dim; + if (data_type == DataType::VECTOR_SPARSE_U32_F32) { + base_physical = all_sparse_vectors.size(); + } + + auto valid_data_from_dataset = dataset.get_col_valid(vec); + int64_t physical_idx = base_physical; + for (size_t i = 0; i < valid_data_from_dataset.size(); i++) { + if (valid_data_from_dataset[i]) { + logical_to_physical.push_back(physical_idx); + physical_idx++; + } else { + logical_to_physical.push_back(-1); // null + } + } + + // Get original data directly from proto (sparse storage for nullable) + // Data is stored sparsely - only valid vectors are in the proto + if (data_type == DataType::VECTOR_FLOAT) { + auto field_data = dataset.get_col(vec); + auto& float_data = field_data->vectors().float_vector().data(); + all_float_vectors.insert( + all_float_vectors.end(), float_data.begin(), float_data.end()); + } else if (data_type == DataType::VECTOR_SPARSE_U32_F32) { + auto field_data = dataset.get_col(vec); + auto& sparse_array = field_data->vectors().sparse_float_vector(); + for (int i = 0; i < sparse_array.contents_size(); i++) { + auto& content = sparse_array.contents(i); + auto row = CopyAndWrapSparseRow(content.data(), content.size()); + all_sparse_vectors.push_back(std::move(row)); + } + } + all_valid_data.insert(all_valid_data.end(), + valid_data_from_dataset.begin(), + valid_data_from_dataset.end()); + + auto offset = segment->PreInsert(batch_size); + segment->Insert(offset, + batch_size, + dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + + auto& insert_record = segment->get_insert_record(); + ASSERT_TRUE(insert_record.is_valid_data_exist(vec)); + + auto valid_data_ptr = insert_record.get_data_base(vec); + const auto& valid_data = valid_data_ptr->get_valid_data(); + + // Test search + auto sr = + segment_growing->Search(plan.get(), ph_group.get(), timestamp); + + ASSERT_EQ(sr->total_nq_, num_queries); + ASSERT_EQ(sr->unity_topK_, topk); + + if (expected_valid_count == 0) { + auto total_results = sr->get_total_result_count(); + EXPECT_EQ(total_results, 0) + << "Round " << round + << ": 100% null should return 0 results, but got " + << total_results; + } else { + // Verify search results don't contain null vectors + for (size_t i = 0; i < sr->seg_offsets_.size(); i++) { + auto seg_offset = sr->seg_offsets_[i]; + if (seg_offset < 0) { + continue; + } + ASSERT_TRUE(valid_data[seg_offset]) + << "Round " << round + << ": Search returned null vector at offset " << seg_offset; + } + } + + auto vec_result = segment->bulk_subscript( + nullptr, vec, sr->seg_offsets_.data(), sr->seg_offsets_.size()); + ASSERT_TRUE(vec_result != nullptr); + + if (data_type == DataType::VECTOR_FLOAT) { + auto& float_data = vec_result->vectors().float_vector(); + size_t valid_idx = 0; + for (size_t i = 0; i < sr->seg_offsets_.size(); i++) { + auto offset = sr->seg_offsets_[i]; + if (offset < 0) { + continue; // Skip invalid offsets + } + auto physical_idx = logical_to_physical[offset]; + for (int d = 0; d < dim; d++) { + float expected_val = + all_float_vectors[physical_idx * dim + d]; + float actual_val = float_data.data(valid_idx * dim + d); + ASSERT_FLOAT_EQ(expected_val, actual_val) + << "Round " << round << ": Mismatch at logical offset " + << offset << " dim " << d; + } + valid_idx++; + } + } else if (data_type == DataType::VECTOR_SPARSE_U32_F32) { + auto& sparse_data = vec_result->vectors().sparse_float_vector(); + size_t valid_idx = 0; + for (size_t i = 0; i < sr->seg_offsets_.size(); i++) { + auto offset = sr->seg_offsets_[i]; + if (offset < 0) { + continue; // Skip invalid offsets + } + auto physical_idx = logical_to_physical[offset]; + auto& content = sparse_data.contents(valid_idx); + auto retrieved_row = + CopyAndWrapSparseRow(content.data(), content.size()); + const auto& expected_row = all_sparse_vectors[physical_idx]; + ASSERT_EQ(retrieved_row.size(), expected_row.size()) + << "Round " << round + << ": Sparse vector size mismatch at logical offset " + << offset; + for (size_t j = 0; j < retrieved_row.size(); j++) { + ASSERT_EQ(retrieved_row[j].id, expected_row[j].id) + << "Round " << round + << ": Sparse vector id mismatch at logical offset " + << offset << " element " << j; + ASSERT_FLOAT_EQ(retrieved_row[j].val, expected_row[j].val) + << "Round " << round + << ": Sparse vector val mismatch at logical offset " + << offset << " element " << j; + } + valid_idx++; + } + } + } +} + TEST_P(GrowingTest, FillVectorArrayData) { auto schema = std::make_shared(); auto int64_field = schema->AddDebugField("int64", DataType::INT64); diff --git a/internal/core/src/segcore/SegmentInterface.cpp b/internal/core/src/segcore/SegmentInterface.cpp index 491c0c0bfd..96bed94ea9 100644 --- a/internal/core/src/segcore/SegmentInterface.cpp +++ b/internal/core/src/segcore/SegmentInterface.cpp @@ -500,10 +500,18 @@ SegmentInternalInterface::bulk_subscript_not_exist_field( const milvus::FieldMeta& field_meta, int64_t count) const { auto data_type = field_meta.get_data_type(); if (IsVectorDataType(data_type)) { - ThrowInfo(DataTypeInvalid, - fmt::format("unsupported added field type {}", - field_meta.get_data_type())); + AssertInfo(field_meta.is_nullable(), + "Non-nullable vector field should not reach here"); + + auto result = CreateEmptyVectorDataArray(0, field_meta); + + auto valid_data = result->mutable_valid_data(); + for (int64_t i = 0; i < count; ++i) { + valid_data->Add(false); + } + return result; } + auto result = CreateEmptyScalarDataArray(count, field_meta); if (field_meta.default_value().has_value()) { auto res = result->mutable_valid_data()->mutable_data(); diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index dbf45ed316..3c44e04aef 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -434,6 +434,23 @@ CreateEmptyVectorDataArray(int64_t count, const FieldMeta& field_meta) { return data_array; } +std::unique_ptr +CreateEmptyVectorDataArray(int64_t count, + int64_t valid_count, + const void* valid_data, + const FieldMeta& field_meta) { + int64_t data_count = (field_meta.is_nullable() && valid_data != nullptr) + ? valid_count + : count; + auto data_array = CreateEmptyVectorDataArray(data_count, field_meta); + if (field_meta.is_nullable() && valid_data != nullptr) { + auto obj = data_array->mutable_valid_data(); + auto valid_data_bool = reinterpret_cast(valid_data); + obj->Add(valid_data_bool, valid_data_bool + count); + } + return data_array; +} + std::unique_ptr CreateScalarDataArrayFrom(const void* data_raw, const void* valid_data, @@ -444,7 +461,7 @@ CreateScalarDataArrayFrom(const void* data_raw, data_array->set_field_id(field_meta.get_id().get()); data_array->set_type(static_cast( field_meta.get_data_type())); - if (field_meta.is_nullable()) { + if (field_meta.is_nullable() && valid_data != nullptr) { auto valid_data_ = reinterpret_cast(valid_data); auto obj = data_array->mutable_valid_data(); obj->Add(valid_data_, valid_data_ + count); @@ -659,6 +676,22 @@ CreateVectorDataArrayFrom(const void* data_raw, return data_array; } +std::unique_ptr +CreateVectorDataArrayFrom(const void* data_raw, + const void* valid_data, + int64_t count, + int64_t valid_count, + const FieldMeta& field_meta) { + auto data_array = + CreateVectorDataArrayFrom(data_raw, valid_count, field_meta); + if (field_meta.is_nullable() && valid_data != nullptr) { + auto obj = data_array->mutable_valid_data(); + auto valid_data_bool = reinterpret_cast(valid_data); + obj->Add(valid_data_bool, valid_data_bool + count); + } + return data_array; +} + std::unique_ptr CreateDataArrayFrom(const void* data_raw, const void* valid_data, @@ -691,6 +724,21 @@ MergeDataArray(std::vector& merge_bases, AssertInfo(data_type == DataType(src_field_data->type()), "merge field data type not consistent"); if (field_meta.is_vector()) { + bool is_valid = true; + if (nullable) { + auto data = src_field_data->valid_data().data(); + auto obj = data_array->mutable_valid_data(); + is_valid = data[src_offset]; + *(obj->Add()) = is_valid; + } + + if (!is_valid) { + continue; + } + + int64_t physical_offset = + merge_base.getValidDataOffset(field_meta.get_id()); + auto vector_array = data_array->mutable_vectors(); auto dim = 0; if (!IsSparseFloatVectorDataType(data_type)) { @@ -700,17 +748,19 @@ MergeDataArray(std::vector& merge_bases, if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) { auto data = VEC_FIELD_DATA(src_field_data, float).data(); auto obj = vector_array->mutable_float_vector(); - obj->mutable_data()->Add(data + src_offset * dim, - data + (src_offset + 1) * dim); + obj->mutable_data()->Add(data + physical_offset * dim, + data + (physical_offset + 1) * dim); } else if (field_meta.get_data_type() == DataType::VECTOR_FLOAT16) { auto data = VEC_FIELD_DATA(src_field_data, float16); auto obj = vector_array->mutable_float16_vector(); - obj->assign(data, dim * sizeof(float16)); + obj->assign(data + physical_offset * dim * sizeof(float16), + dim * sizeof(float16)); } else if (field_meta.get_data_type() == DataType::VECTOR_BFLOAT16) { auto data = VEC_FIELD_DATA(src_field_data, bfloat16); auto obj = vector_array->mutable_bfloat16_vector(); - obj->assign(data, dim * sizeof(bfloat16)); + obj->assign(data + physical_offset * dim * sizeof(bfloat16), + dim * sizeof(bfloat16)); } else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) { AssertInfo( dim % 8 == 0, @@ -718,26 +768,28 @@ MergeDataArray(std::vector& merge_bases, auto num_bytes = dim / 8; auto data = VEC_FIELD_DATA(src_field_data, binary); auto obj = vector_array->mutable_binary_vector(); - obj->assign(data + src_offset * num_bytes, num_bytes); + obj->assign(data + physical_offset * num_bytes, num_bytes); } else if (field_meta.get_data_type() == DataType::VECTOR_SPARSE_U32_F32) { - auto src = src_field_data->vectors().sparse_float_vector(); + auto& src_vec = src_field_data->vectors().sparse_float_vector(); auto dst = vector_array->mutable_sparse_float_vector(); - if (src.dim() > dst->dim()) { - dst->set_dim(src.dim()); + if (src_vec.dim() > dst->dim()) { + dst->set_dim(src_vec.dim()); } vector_array->set_dim(dst->dim()); - *dst->mutable_contents() = src.contents(); + auto& src_contents = src_vec.contents(physical_offset); + *(dst->mutable_contents()->Add()) = src_contents; } else if (field_meta.get_data_type() == DataType::VECTOR_INT8) { auto data = VEC_FIELD_DATA(src_field_data, int8); auto obj = vector_array->mutable_int8_vector(); - obj->assign(data, dim * sizeof(int8)); + obj->assign(data + physical_offset * dim * sizeof(int8), + dim * sizeof(int8)); } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { - auto data = src_field_data->vectors().vector_array(); + auto& data = src_field_data->vectors().vector_array(); auto obj = vector_array->mutable_vector_array(); obj->set_element_type( proto::schema::DataType(field_meta.get_element_type())); - obj->CopyFrom(data); + *(obj->mutable_data()->Add()) = data.data(physical_offset); } else { ThrowInfo(DataTypeInvalid, fmt::format("unsupported datatype {}", data_type)); diff --git a/internal/core/src/segcore/Utils.h b/internal/core/src/segcore/Utils.h index 91170b8448..bdd0fa3c46 100644 --- a/internal/core/src/segcore/Utils.h +++ b/internal/core/src/segcore/Utils.h @@ -55,6 +55,12 @@ CreateEmptyScalarDataArray(int64_t count, const FieldMeta& field_meta); std::unique_ptr CreateEmptyVectorDataArray(int64_t count, const FieldMeta& field_meta); +std::unique_ptr +CreateEmptyVectorDataArray(int64_t count, + int64_t valid_count, + const void* valid_data, + const FieldMeta& field_meta); + std::unique_ptr CreateScalarDataArrayFrom(const void* data_raw, const void* valid_data, @@ -66,6 +72,13 @@ CreateVectorDataArrayFrom(const void* data_raw, int64_t count, const FieldMeta& field_meta); +std::unique_ptr +CreateVectorDataArrayFrom(const void* data_raw, + const void* valid_data, + int64_t count, + int64_t valid_count, + const FieldMeta& field_meta); + std::unique_ptr CreateDataArrayFrom(const void* data_raw, const void* valid_data, @@ -77,6 +90,7 @@ struct MergeBase { private: std::map>* output_fields_data_; size_t offset_; + std::map valid_data_offsets_; public: MergeBase() { @@ -93,6 +107,20 @@ struct MergeBase { return offset_; } + void + setValidDataOffset(FieldId fieldId, size_t valid_offset) { + valid_data_offsets_[fieldId] = valid_offset; + } + + size_t + getValidDataOffset(FieldId fieldId) const { + auto it = valid_data_offsets_.find(fieldId); + if (it != valid_data_offsets_.end()) { + return it->second; + } + return offset_; + } + milvus::DataArray* get_field_data(FieldId fieldId) const { return (*output_fields_data_)[fieldId].get(); diff --git a/internal/core/src/segcore/UtilsTest.cpp b/internal/core/src/segcore/UtilsTest.cpp index 9d64c3b645..15754ff46f 100644 --- a/internal/core/src/segcore/UtilsTest.cpp +++ b/internal/core/src/segcore/UtilsTest.cpp @@ -94,3 +94,96 @@ TEST(Util_Segcore, GetDeleteBitmap) { delete_record.Query(res_view, insert_barrier, query_timestamp); ASSERT_EQ(res_view.count(), 0); } + +TEST(Util_Segcore, CreateVectorDataArrayFromNullableVectors) { + using namespace milvus; + using namespace milvus::segcore; + + auto schema = std::make_shared(); + auto vec = schema->AddDebugField( + "embeddings", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2, true); + auto& field_meta = (*schema)[vec]; + + int64_t dim = 16; + int64_t total_count = 10; + int64_t valid_count = 5; + + std::vector data(valid_count * dim); + for (size_t i = 0; i < data.size(); ++i) { + data[i] = static_cast(i); + } + + std::unique_ptr valid_flags = std::make_unique(total_count); + for (int64_t i = 0; i < total_count; ++i) { + if (i % 2 == 0) { + valid_flags[i] = true; + } else { + valid_flags[i] = false; + } + } + + auto result = CreateVectorDataArrayFrom( + data.data(), valid_flags.get(), total_count, valid_count, field_meta); + + ASSERT_TRUE(result->valid_data().size() > 0); + ASSERT_EQ(result->valid_data().size(), total_count); + ASSERT_EQ(result->vectors().float_vector().data_size(), valid_count * dim); +} + +TEST(Util_Segcore, MergeDataArrayWithNullableVectors) { + using namespace milvus; + using namespace milvus::segcore; + + auto schema = std::make_shared(); + auto vec = schema->AddDebugField( + "embeddings", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2, true); + auto& field_meta = (*schema)[vec]; + + int64_t dim = 16; + int64_t total_count = 10; + int64_t valid_count = 5; + + std::vector data(valid_count * dim); + for (size_t i = 0; i < data.size(); ++i) { + data[i] = static_cast(i); + } + + std::unique_ptr valid_flags = std::make_unique(total_count); + for (int64_t i = 0; i < total_count; ++i) { + if (i % 2 == 0) { + valid_flags[i] = true; + } else { + valid_flags[i] = false; + } + } + + auto data_array = CreateVectorDataArrayFrom( + data.data(), valid_flags.get(), total_count, valid_count, field_meta); + + std::map> output_fields_data; + output_fields_data[vec] = std::move(data_array); + + std::vector merge_bases; + merge_bases.emplace_back(&output_fields_data, 0); + merge_bases.back().setValidDataOffset(vec, 0); + merge_bases.emplace_back(&output_fields_data, 2); + merge_bases.back().setValidDataOffset(vec, 1); + merge_bases.emplace_back(&output_fields_data, 4); + merge_bases.back().setValidDataOffset(vec, 2); + merge_bases.emplace_back(&output_fields_data, 6); + merge_bases.back().setValidDataOffset(vec, 3); + merge_bases.emplace_back(&output_fields_data, 8); + merge_bases.back().setValidDataOffset(vec, 4); + + auto merged_result = MergeDataArray(merge_bases, field_meta); + + ASSERT_TRUE(merged_result->valid_data().size() > 0); + ASSERT_EQ(merged_result->valid_data().size(), 5); + ASSERT_EQ(merged_result->vectors().float_vector().data_size(), 5 * dim); + + ASSERT_TRUE(merged_result->valid_data(0)); + ASSERT_TRUE(merged_result->valid_data(1)); + ASSERT_TRUE(merged_result->valid_data(2)); + ASSERT_TRUE(merged_result->valid_data(3)); + ASSERT_TRUE(merged_result->valid_data(4)); +} diff --git a/internal/core/src/segcore/reduce/Reduce.cpp b/internal/core/src/segcore/reduce/Reduce.cpp index 87363d70e0..e354e66fb2 100644 --- a/internal/core/src/segcore/reduce/Reduce.cpp +++ b/internal/core/src/segcore/reduce/Reduce.cpp @@ -541,6 +541,27 @@ ReduceHelper::GetSearchResultDataSlice(const int slice_index, // set result offset to fill output fields data result_pairs[loc] = {&search_result->output_fields_data_, ki}; + + for (auto field_id : plan_->target_entries_) { + auto& field_meta = plan_->schema_->operator[](field_id); + if (field_meta.is_vector() && field_meta.is_nullable()) { + auto it = + search_result->output_fields_data_.find(field_id); + if (it != search_result->output_fields_data_.end()) { + auto& field_data = it->second; + if (field_data->valid_data_size() > 0) { + int64_t valid_idx = 0; + for (int64_t i = 0; i < ki; ++i) { + if (field_data->valid_data(i)) { + valid_idx++; + } + } + result_pairs[loc].setValidDataOffset(field_id, + valid_idx); + } + } + } + } } } diff --git a/internal/core/src/segcore/reduce/StreamReduce.cpp b/internal/core/src/segcore/reduce/StreamReduce.cpp index b830cbb163..edc75abd91 100644 --- a/internal/core/src/segcore/reduce/StreamReduce.cpp +++ b/internal/core/src/segcore/reduce/StreamReduce.cpp @@ -18,6 +18,32 @@ namespace milvus::segcore { +void +StreamReducerHelper::SetNullableVectorValidDataOffsets( + const std::map>& + output_fields_data, + int64_t ki, + MergeBase& merge_base) { + for (auto field_id : plan_->target_entries_) { + auto& field_meta = plan_->schema_->operator[](field_id); + if (field_meta.is_vector() && field_meta.is_nullable()) { + auto it = output_fields_data.find(field_id); + if (it != output_fields_data.end()) { + auto& field_data = it->second; + if (field_data->valid_data_size() > 0) { + int64_t physical_offset = 0; + for (int64_t j = 0; j < ki; ++j) { + if (field_data->valid_data(j)) { + physical_offset++; + } + } + merge_base.setValidDataOffset(field_id, physical_offset); + } + } + } + } +} + void StreamReducerHelper::FillEntryData() { for (auto search_result : search_results_to_merge_) { @@ -98,6 +124,10 @@ StreamReducerHelper::AssembleMergedResult() { } merge_output_data_bases[nq_base_offset + loc] = { &search_result->output_fields_data_, ki}; + SetNullableVectorValidDataOffsets( + search_result->output_fields_data_, + ki, + merge_output_data_bases[nq_base_offset + loc]); new_result_offsets[nq_base_offset + loc] = loc; real_topKs[qi]++; } @@ -127,6 +157,10 @@ StreamReducerHelper::AssembleMergedResult() { } merge_output_data_bases[nq_base_offset + loc] = { &merged_search_result->output_fields_data_, ki}; + SetNullableVectorValidDataOffsets( + merged_search_result->output_fields_data_, + ki, + merge_output_data_bases[nq_base_offset + loc]); new_result_offsets[nq_base_offset + loc] = loc; real_topKs[qi]++; } diff --git a/internal/core/src/segcore/reduce/StreamReduce.h b/internal/core/src/segcore/reduce/StreamReduce.h index f25d831413..f3166a0c6f 100644 --- a/internal/core/src/segcore/reduce/StreamReduce.h +++ b/internal/core/src/segcore/reduce/StreamReduce.h @@ -19,6 +19,7 @@ #include "query/PlanImpl.h" #include "common/QueryResult.h" #include "segcore/ReduceStructure.h" +#include "segcore/Utils.h" #include "common/EasyAssert.h" namespace milvus::segcore { @@ -207,6 +208,13 @@ class StreamReducerHelper { void CleanReduceStatus(); + void + SetNullableVectorValidDataOffsets( + const std::map>& + output_fields_data, + int64_t ki, + MergeBase& merge_base); + std::unique_ptr merged_search_result; milvus::query::Plan* plan_; std::vector slice_nqs_; diff --git a/internal/core/src/segcore/storagev1translator/DefaultValueChunkTranslator.cpp b/internal/core/src/segcore/storagev1translator/DefaultValueChunkTranslator.cpp index 5bb5831b1c..49e87e5dcb 100644 --- a/internal/core/src/segcore/storagev1translator/DefaultValueChunkTranslator.cpp +++ b/internal/core/src/segcore/storagev1translator/DefaultValueChunkTranslator.cpp @@ -107,6 +107,16 @@ DefaultValueChunkTranslator::estimated_byte_size_of_cell( case milvus::DataType::ARRAY: value_size = sizeof(Array); break; + case milvus::DataType::VECTOR_FLOAT: + case milvus::DataType::VECTOR_BINARY: + case milvus::DataType::VECTOR_FLOAT16: + case milvus::DataType::VECTOR_BFLOAT16: + case milvus::DataType::VECTOR_INT8: + case milvus::DataType::VECTOR_SPARSE_U32_F32: + AssertInfo(field_meta_.is_nullable(), + "only nullable vector fields can be dynamically added"); + value_size = 0; + break; default: ThrowInfo(DataTypeInvalid, "unsupported default value data type {}", @@ -128,8 +138,15 @@ DefaultValueChunkTranslator::get_cells( AssertInfo(cids.size() == 1 && cids[0] == 0, "DefaultValueChunkTranslator only supports one cell"); auto num_rows = meta_.num_rows_until_chunk_[1]; - auto builder = - milvus::storage::CreateArrowBuilder(field_meta_.get_data_type()); + auto data_type = field_meta_.get_data_type(); + std::shared_ptr builder; + if (IsVectorDataType(data_type)) { + AssertInfo(field_meta_.is_nullable(), + "only nullable vector fields can be dynamically added"); + builder = std::make_shared(); + } else { + builder = milvus::storage::CreateArrowBuilder(data_type); + } arrow::Status ast; if (field_meta_.default_value().has_value()) { ast = builder->Reserve(num_rows); diff --git a/internal/core/src/segcore/storagev1translator/InterimSealedIndexTranslator.cpp b/internal/core/src/segcore/storagev1translator/InterimSealedIndexTranslator.cpp index e871f72a9d..72567d873f 100644 --- a/internal/core/src/segcore/storagev1translator/InterimSealedIndexTranslator.cpp +++ b/internal/core/src/segcore/storagev1translator/InterimSealedIndexTranslator.cpp @@ -143,20 +143,54 @@ InterimSealedIndexTranslator::get_cells( } auto num_chunk = vec_data_->num_chunks(); + const auto& offset_mapping = vec_data_->GetOffsetMapping(); + bool nullable = offset_mapping.IsEnabled(); + const auto& valid_count_per_chunk = + nullable ? vec_data_->GetValidCountPerChunk() : std::vector{}; + + int64_t total_valid_count = + nullable ? offset_mapping.GetValidCount() : vec_data_->NumRows(); + + if (total_valid_count == 0) { + if (nullable) { + const auto& valid_data = vec_data_->GetValidData(); + vec_index->BuildValidData(valid_data.data(), valid_data.size()); + } + std::vector>> + result; + result.emplace_back(std::make_pair(0, std::move(vec_index))); + return result; + } + + bool first_build = true; for (int i = 0; i < num_chunk; ++i) { auto pw = vec_data_->GetChunk(nullptr, i); auto chunk = pw.get(); - auto dataset = knowhere::GenDataSet( - vec_data_->chunk_row_nums(i), dim_, chunk->Data()); + + int64_t actual_row_count = + nullable ? valid_count_per_chunk[i] : vec_data_->chunk_row_nums(i); + + if (actual_row_count == 0) { + continue; + } + + auto dataset = + knowhere::GenDataSet(actual_row_count, dim_, chunk->Data()); dataset->SetIsOwner(false); dataset->SetIsSparse(is_sparse_); - if (i == 0) { + if (first_build) { vec_index->BuildWithDataset(dataset, build_config_); + first_build = false; } else { vec_index->AddWithDataset(dataset, build_config_); } } + + if (nullable) { + const auto& valid_data = vec_data_->GetValidData(); + vec_index->BuildValidData(valid_data.data(), valid_data.size()); + } std::vector>> result; result.emplace_back(std::make_pair(0, std::move(vec_index))); diff --git a/internal/core/src/storage/DataCodecTest.cpp b/internal/core/src/storage/DataCodecTest.cpp index 5dacf6401c..14770e8fbd 100644 --- a/internal/core/src/storage/DataCodecTest.cpp +++ b/internal/core/src/storage/DataCodecTest.cpp @@ -763,7 +763,80 @@ TEST(storage, InsertDataFloatVector) { ASSERT_EQ(data, new_data); } -TEST(storage, InsertDataSparseFloat) { +TEST(storage, InsertDataFloatVectorNullable) { + int DIM = 4; + int num_rows = 100; + + for (int null_percent : {0, 20, 100}) { + int valid_count = num_rows * (100 - null_percent) / 100; + bool is_nullable = true; + + std::vector data(valid_count * DIM); + for (int i = 0; i < valid_count * DIM; ++i) { + data[i] = static_cast(i) * 0.5f; + } + + FieldDataPtr field_data; + std::vector valid_data((num_rows + 7) / 8, 0); + for (int i = 0; i < valid_count; ++i) { + valid_data[i >> 3] |= (1 << (i & 0x07)); + } + + field_data = milvus::storage::CreateFieldData( + storage::DataType::VECTOR_FLOAT, DataType::NONE, true, DIM); + auto field_data_impl = + std::dynamic_pointer_cast>( + field_data); + field_data_impl->FillFieldData( + data.data(), valid_data.data(), num_rows, 0); + + ASSERT_EQ(field_data->get_num_rows(), num_rows); + ASSERT_EQ(field_data->get_valid_rows(), valid_count); + ASSERT_EQ(field_data->get_null_count(), num_rows - valid_count); + ASSERT_EQ(field_data->IsNullable(), is_nullable); + + auto payload_reader = + std::make_shared(field_data); + storage::InsertData insert_data(payload_reader); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = + insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + ASSERT_EQ(new_insert_data->GetTimeRage(), + std::make_pair(Timestamp(0), Timestamp(100))); + + auto new_payload = new_insert_data->GetFieldData(); + + ASSERT_EQ(new_payload->get_data_type(), + storage::DataType::VECTOR_FLOAT); + ASSERT_EQ(new_payload->get_num_rows(), num_rows); + ASSERT_EQ(new_payload->get_valid_rows(), valid_count); + ASSERT_EQ(new_payload->get_null_count(), num_rows - valid_count); + ASSERT_EQ(new_payload->IsNullable(), is_nullable); + + int valid_idx = 0; + for (int i = 0; i < num_rows; ++i) { + if (new_payload->is_valid(i)) { + // RawValue takes logical offset, internally converts to physical + auto vec_ptr = + static_cast(new_payload->RawValue(i)); + for (int j = 0; j < DIM; ++j) { + ASSERT_FLOAT_EQ(vec_ptr[j], data[valid_idx * DIM + j]); + } + valid_idx++; + } + } + } +} + +TEST(storage, InsertDataSparseFloatVector) { auto n_rows = 100; auto vecs = milvus::segcore::GenerateRandomSparseFloatVector( n_rows, kTestSparseDim, kTestSparseVectorDensity); @@ -810,6 +883,75 @@ TEST(storage, InsertDataSparseFloat) { } } +TEST(storage, InsertDataSparseFloatVectorNullable) { + int num_rows = 100; + + for (int null_percent : {0, 20, 100}) { + int valid_count = num_rows * (100 - null_percent) / 100; + bool is_nullable = true; + auto vecs = milvus::segcore::GenerateRandomSparseFloatVector( + valid_count, kTestSparseDim, kTestSparseVectorDensity); + + FieldDataPtr field_data; + std::vector valid_data((num_rows + 7) / 8, 0); + for (int i = 0; i < valid_count; ++i) { + valid_data[i >> 3] |= (1 << (i & 0x07)); + } + + field_data = + milvus::storage::CreateFieldData(DataType::VECTOR_SPARSE_U32_F32, + DataType::NONE, + true, + kTestSparseDim, + num_rows); + + auto field_data_impl = std::dynamic_pointer_cast< + milvus::FieldData>(field_data); + field_data_impl->FillFieldData( + vecs.get(), valid_data.data(), num_rows, 0); + + ASSERT_EQ(field_data->get_num_rows(), num_rows); + ASSERT_EQ(field_data->get_valid_rows(), valid_count); + ASSERT_EQ(field_data->IsNullable(), is_nullable); + + auto payload_reader = + std::make_shared(field_data); + storage::InsertData insert_data(payload_reader); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = + insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_TRUE(new_payload->get_data_type() == + storage::DataType::VECTOR_SPARSE_U32_F32); + ASSERT_EQ(new_payload->get_num_rows(), num_rows); + ASSERT_EQ(new_payload->IsNullable(), is_nullable); + + int valid_idx = 0; + for (int i = 0; i < num_rows; ++i) { + if (new_payload->is_valid(i)) { + auto& original = vecs[valid_idx]; + auto new_vec = static_cast*>(new_payload->RawValue(i)); + ASSERT_EQ(original.size(), new_vec->size()); + for (size_t j = 0; j < original.size(); ++j) { + ASSERT_EQ(original[j].id, (*new_vec)[j].id); + ASSERT_EQ(original[j].val, (*new_vec)[j].val); + } + valid_idx++; + } + } + } +} + TEST(storage, InsertDataBinaryVector) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; int DIM = 16; @@ -841,6 +983,72 @@ TEST(storage, InsertDataBinaryVector) { ASSERT_EQ(data, new_data); } +TEST(storage, InsertDataBinaryVectorNullable) { + int DIM = 128; + int num_rows = 100; + + for (int null_percent : {0, 20, 100}) { + int valid_count = num_rows * (100 - null_percent) / 100; + bool is_nullable = true; + + std::vector data(valid_count * DIM / 8); + for (size_t i = 0; i < data.size(); ++i) { + data[i] = static_cast(i % 256); + } + + FieldDataPtr field_data; + std::vector valid_data((num_rows + 7) / 8, 0); + for (int i = 0; i < valid_count; ++i) { + valid_data[i >> 3] |= (1 << (i & 0x07)); + } + + field_data = milvus::storage::CreateFieldData( + storage::DataType::VECTOR_BINARY, DataType::NONE, true, DIM); + auto field_data_impl = + std::dynamic_pointer_cast>( + field_data); + field_data_impl->FillFieldData( + data.data(), valid_data.data(), num_rows, 0); + + ASSERT_EQ(field_data->get_num_rows(), num_rows); + ASSERT_EQ(field_data->get_valid_rows(), valid_count); + ASSERT_EQ(field_data->IsNullable(), is_nullable); + + auto payload_reader = + std::make_shared(field_data); + storage::InsertData insert_data(payload_reader); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = + insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), + storage::DataType::VECTOR_BINARY); + ASSERT_EQ(new_payload->get_num_rows(), num_rows); + ASSERT_EQ(new_payload->IsNullable(), is_nullable); + + int valid_idx = 0; + for (int i = 0; i < num_rows; ++i) { + if (new_payload->is_valid(i)) { + auto vec_ptr = + static_cast(new_payload->RawValue(i)); + for (int j = 0; j < DIM / 8; ++j) { + ASSERT_EQ(vec_ptr[j], data[valid_idx * DIM / 8 + j]); + } + valid_idx++; + } + } + } +} + TEST(storage, InsertDataFloat16Vector) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; int DIM = 2; @@ -874,6 +1082,72 @@ TEST(storage, InsertDataFloat16Vector) { ASSERT_EQ(data, new_data); } +TEST(storage, InsertDataFloat16VectorNullable) { + int DIM = 4; + int num_rows = 100; + + for (int null_percent : {0, 20, 100}) { + int valid_count = num_rows * (100 - null_percent) / 100; + bool is_nullable = true; + + std::vector data(valid_count * DIM); + for (int i = 0; i < valid_count * DIM; ++i) { + data[i] = static_cast(i * 0.5f); + } + + FieldDataPtr field_data; + std::vector valid_data((num_rows + 7) / 8, 0); + for (int i = 0; i < valid_count; ++i) { + valid_data[i >> 3] |= (1 << (i & 0x07)); + } + + field_data = milvus::storage::CreateFieldData( + storage::DataType::VECTOR_FLOAT16, DataType::NONE, true, DIM); + auto field_data_impl = + std::dynamic_pointer_cast>( + field_data); + field_data_impl->FillFieldData( + data.data(), valid_data.data(), num_rows, 0); + + ASSERT_EQ(field_data->get_num_rows(), num_rows); + ASSERT_EQ(field_data->get_valid_rows(), valid_count); + ASSERT_EQ(field_data->IsNullable(), is_nullable); + + auto payload_reader = + std::make_shared(field_data); + storage::InsertData insert_data(payload_reader); + storage::FieldDataMeta field_data_meta{100, 101, 102, 103}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = + insert_data.Serialize(storage::StorageType::Remote); + std::shared_ptr serialized_data_ptr(serialized_bytes.data(), + [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, serialized_bytes.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType); + + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_EQ(new_payload->get_data_type(), + storage::DataType::VECTOR_FLOAT16); + ASSERT_EQ(new_payload->get_num_rows(), num_rows); + ASSERT_EQ(new_payload->IsNullable(), is_nullable); + + int valid_idx = 0; + for (int i = 0; i < num_rows; ++i) { + if (new_payload->is_valid(i)) { + auto vec_ptr = + static_cast(new_payload->RawValue(i)); + for (int j = 0; j < DIM; ++j) { + ASSERT_EQ(vec_ptr[j], data[valid_idx * DIM + j]); + } + valid_idx++; + } + } + } +} + TEST(storage, IndexData) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; storage::IndexData index_data(data.data(), data.size()); diff --git a/internal/core/src/storage/DiskFileManagerImpl.cpp b/internal/core/src/storage/DiskFileManagerImpl.cpp index c6dbfad4de..6cbf741cee 100644 --- a/internal/core/src/storage/DiskFileManagerImpl.cpp +++ b/internal/core/src/storage/DiskFileManagerImpl.cpp @@ -536,7 +536,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_internal(const Config& config) { int batch_size = batch_files.size(); for (int i = 0; i < batch_size; i++) { auto field_data = field_datas[i].get()->GetFieldData(); - num_rows += uint32_t(field_data->get_num_rows()); + num_rows += uint32_t(field_data->get_valid_rows()); cache_raw_data_to_disk_common( field_data, local_chunk_manager, @@ -634,7 +634,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_common( auto sparse_rows = static_cast*>( field_data->Data()); - for (size_t i = 0; i < field_data->Length(); ++i) { + for (size_t i = 0; i < field_data->get_valid_rows(); ++i) { auto row = sparse_rows[i]; auto row_byte_size = row.data_byte_size(); uint32_t nnz = row.size(); @@ -689,7 +689,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_common( } else { dim = field_data->get_dim(); auto data_size = - field_data->get_num_rows() * milvus::GetVecRowSize(dim); + field_data->get_valid_rows() * milvus::GetVecRowSize(dim); local_chunk_manager->Write(local_data_path, write_offset, const_cast(field_data->Data()), @@ -761,7 +761,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_storage_v2(const Config& config) { fs_); } for (auto& field_data : field_datas) { - num_rows += uint32_t(field_data->get_num_rows()); + num_rows += uint32_t(field_data->get_valid_rows()); cache_raw_data_to_disk_common(field_data, local_chunk_manager, local_data_path, diff --git a/internal/core/src/storage/DiskFileManagerTest.cpp b/internal/core/src/storage/DiskFileManagerTest.cpp index 0ad9d5022e..f05b22c232 100644 --- a/internal/core/src/storage/DiskFileManagerTest.cpp +++ b/internal/core/src/storage/DiskFileManagerTest.cpp @@ -523,6 +523,262 @@ TEST_F(DiskAnnFileManagerTest, CacheOptFieldToDiskOnlyOneCategory) { } } +TEST_F(DiskAnnFileManagerTest, CacheRawDataToDiskNullableVector) { + const int64_t collection_id = 1; + const int64_t partition_id = 2; + const int64_t segment_id = 3; + const int64_t field_id = 100; + const int64_t dim = 128; + const int64_t num_rows = 1000; + + struct VectorTypeInfo { + DataType data_type; + std::string type_name; + size_t element_size; + bool is_sparse; + }; + + std::vector vector_types = { + {DataType::VECTOR_FLOAT, "FLOAT", sizeof(float), false}, + {DataType::VECTOR_FLOAT16, "FLOAT16", sizeof(knowhere::fp16), false}, + {DataType::VECTOR_BFLOAT16, "BFLOAT16", sizeof(knowhere::bf16), false}, + {DataType::VECTOR_INT8, "INT8", sizeof(int8_t), false}, + {DataType::VECTOR_BINARY, "BINARY", dim / 8, false}, + {DataType::VECTOR_SPARSE_U32_F32, "SPARSE", 0, true}}; + + for (const auto& vec_type : vector_types) { + for (int null_percent : {0, 20, 100}) { + int64_t valid_count = num_rows * (100 - null_percent) / 100; + + std::vector valid_data((num_rows + 7) / 8, 0); + for (int64_t i = 0; i < valid_count; ++i) { + valid_data[i >> 3] |= (1 << (i & 0x07)); + } + + FieldDataPtr field_data; + std::vector vec_data; + std::unique_ptr[]> sparse_vecs; + + if (vec_type.is_sparse) { + const int64_t sparse_dim = 1000; + const float sparse_density = 0.1; + sparse_vecs = milvus::segcore::GenerateRandomSparseFloatVector( + valid_count, sparse_dim, sparse_density); + + field_data = + storage::CreateFieldData(DataType::VECTOR_SPARSE_U32_F32, + DataType::NONE, + true, + sparse_dim, + num_rows); + auto field_data_impl = std::dynamic_pointer_cast< + milvus::FieldData>(field_data); + field_data_impl->FillFieldData( + sparse_vecs.get(), valid_data.data(), num_rows, 0); + } else { + if (vec_type.data_type == DataType::VECTOR_BINARY) { + vec_data.resize(valid_count * dim / 8); + } else { + vec_data.resize(valid_count * dim * vec_type.element_size); + } + for (size_t i = 0; i < vec_data.size(); ++i) { + vec_data[i] = static_cast(i % 256); + } + + field_data = storage::CreateFieldData( + vec_type.data_type, DataType::NONE, true, dim); + + if (vec_type.data_type == DataType::VECTOR_FLOAT) { + auto impl = std::dynamic_pointer_cast< + milvus::FieldData>(field_data); + impl->FillFieldData( + vec_data.data(), valid_data.data(), num_rows, 0); + } else if (vec_type.data_type == DataType::VECTOR_FLOAT16) { + auto impl = std::dynamic_pointer_cast< + milvus::FieldData>(field_data); + impl->FillFieldData( + vec_data.data(), valid_data.data(), num_rows, 0); + } else if (vec_type.data_type == DataType::VECTOR_BFLOAT16) { + auto impl = std::dynamic_pointer_cast< + milvus::FieldData>(field_data); + impl->FillFieldData( + vec_data.data(), valid_data.data(), num_rows, 0); + } else if (vec_type.data_type == DataType::VECTOR_INT8) { + auto impl = std::dynamic_pointer_cast< + milvus::FieldData>(field_data); + impl->FillFieldData( + vec_data.data(), valid_data.data(), num_rows, 0); + } else if (vec_type.data_type == DataType::VECTOR_BINARY) { + auto impl = std::dynamic_pointer_cast< + milvus::FieldData>(field_data); + impl->FillFieldData( + vec_data.data(), valid_data.data(), num_rows, 0); + } + } + + ASSERT_EQ(field_data->get_num_rows(), num_rows); + ASSERT_EQ(field_data->get_valid_rows(), valid_count); + + auto payload_reader = + std::make_shared(field_data); + storage::InsertData insert_data(payload_reader); + FieldDataMeta field_data_meta = { + collection_id, partition_id, segment_id, field_id}; + insert_data.SetFieldDataMeta(field_data_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_data = + insert_data.Serialize(storage::StorageType::Remote); + + std::string insert_file_path = "/tmp/diskann/nullable_" + + vec_type.type_name + "_" + + std::to_string(null_percent); + boost::filesystem::remove_all(insert_file_path); + cm_->Write(insert_file_path, + serialized_data.data(), + serialized_data.size()); + + if (vec_type.is_sparse) { + int64_t file_size = cm_->Size(insert_file_path); + std::vector buffer(file_size); + cm_->Read(insert_file_path, buffer.data(), file_size); + + std::shared_ptr serialized_data_ptr( + buffer.data(), [&](uint8_t*) {}); + auto new_insert_data = storage::DeserializeFileData( + serialized_data_ptr, buffer.size()); + ASSERT_EQ(new_insert_data->GetCodecType(), + storage::InsertDataType); + + auto new_payload = new_insert_data->GetFieldData(); + ASSERT_TRUE(new_payload->get_data_type() == + DataType::VECTOR_SPARSE_U32_F32); + ASSERT_EQ(new_payload->get_num_rows(), num_rows) + << "num_rows mismatch for " << vec_type.type_name + << " with null_percent=" << null_percent; + ASSERT_EQ(new_payload->get_valid_rows(), valid_count) + << "valid_rows mismatch for " << vec_type.type_name + << " with null_percent=" << null_percent; + ASSERT_TRUE(new_payload->IsNullable()); + + for (int i = 0; i < num_rows; ++i) { + if (i < valid_count) { + ASSERT_TRUE(new_payload->is_valid(i)) + << "Row " << i + << " should be valid for null_percent=" + << null_percent; + + auto original = &sparse_vecs[i]; + auto new_vec = + static_cast*>( + new_payload->RawValue(i)); + ASSERT_EQ(original->size(), new_vec->size()) + << "Size mismatch at row " << i + << " for null_percent=" << null_percent; + + for (size_t j = 0; j < original->size(); ++j) { + ASSERT_EQ((*original)[j].id, (*new_vec)[j].id) + << "ID mismatch at row " << i << ", element " + << j << " for null_percent=" << null_percent; + ASSERT_EQ((*original)[j].val, (*new_vec)[j].val) + << "Value mismatch at row " << i << ", element " + << j << " for null_percent=" << null_percent; + } + } else { + ASSERT_FALSE(new_payload->is_valid(i)) + << "Row " << i + << " should be null for null_percent=" + << null_percent; + } + } + } else { + IndexMeta index_meta = {segment_id, + field_id, + 1000, + 1, + "test", + "vec_field", + vec_type.data_type, + dim}; + auto file_manager = std::make_shared( + storage::FileManagerContext( + field_data_meta, index_meta, cm_, fs_)); + + milvus::Config config; + config[INSERT_FILES_KEY] = + std::vector{insert_file_path}; + + std::string local_data_path; + if (vec_type.data_type == DataType::VECTOR_FLOAT) { + local_data_path = + file_manager->CacheRawDataToDisk(config); + } else if (vec_type.data_type == DataType::VECTOR_INT8) { + local_data_path = + file_manager->CacheRawDataToDisk(config); + } else if (vec_type.data_type == DataType::VECTOR_FLOAT16) { + local_data_path = + file_manager->CacheRawDataToDisk( + config); + } else if (vec_type.data_type == DataType::VECTOR_BFLOAT16) { + local_data_path = + file_manager->CacheRawDataToDisk( + config); + } else if (vec_type.data_type == DataType::VECTOR_BINARY) { + local_data_path = + file_manager->CacheRawDataToDisk(config); + } + + ASSERT_FALSE(local_data_path.empty()) + << "Failed for " << vec_type.type_name + << " with null_percent=" << null_percent; + + auto local_chunk_manager = + LocalChunkManagerSingleton::GetInstance().GetChunkManager(); + uint32_t read_num_rows = 0; + uint32_t read_dim = 0; + local_chunk_manager->Read( + local_data_path, 0, &read_num_rows, sizeof(read_num_rows)); + local_chunk_manager->Read(local_data_path, + sizeof(read_num_rows), + &read_dim, + sizeof(read_dim)); + + EXPECT_EQ(read_num_rows, valid_count) + << "Mismatch for " << vec_type.type_name + << " with null_percent=" << null_percent; + EXPECT_EQ(read_dim, dim); + + size_t bytes_per_vector = + (vec_type.data_type == DataType::VECTOR_BINARY) + ? (dim / 8) + : (dim * vec_type.element_size); + auto data_size = read_num_rows * bytes_per_vector; + std::vector buffer(data_size); + local_chunk_manager->Read( + local_data_path, + sizeof(read_num_rows) + sizeof(read_dim), + buffer.data(), + data_size); + + EXPECT_EQ(buffer.size(), vec_data.size()) + << "Data size mismatch for " << vec_type.type_name; + for (size_t i = 0; i < std::min(buffer.size(), vec_data.size()); + ++i) { + EXPECT_EQ(buffer[i], vec_data[i]) + << "Data mismatch at byte " << i << " for " + << vec_type.type_name + << " with null_percent=" << null_percent; + } + + local_chunk_manager->Remove(local_data_path); + } + + cm_->Remove(insert_file_path); + } + } +} + TEST_F(DiskAnnFileManagerTest, FileCleanup) { std::string local_index_file_path; std::string local_text_index_file_path; diff --git a/internal/core/src/storage/Event.cpp b/internal/core/src/storage/Event.cpp index f78bb311a3..7d49c3aba0 100644 --- a/internal/core/src/storage/Event.cpp +++ b/internal/core/src/storage/Event.cpp @@ -336,9 +336,13 @@ BaseEventData::Serialize() { auto row = static_cast< const knowhere::sparse::SparseRow*>( field_data->RawValue(offset)); - payload_writer->add_one_binary_payload( - static_cast(row->data()), - row->data_byte_size()); + if (row) { + payload_writer->add_one_binary_payload( + static_cast(row->data()), + row->data_byte_size()); + } else { + payload_writer->add_one_binary_payload(nullptr, -1); + } } break; } diff --git a/internal/core/src/storage/PayloadReader.cpp b/internal/core/src/storage/PayloadReader.cpp index bd10f825f8..c5bec9dc0f 100644 --- a/internal/core/src/storage/PayloadReader.cpp +++ b/internal/core/src/storage/PayloadReader.cpp @@ -71,13 +71,44 @@ PayloadReader::init(const uint8_t* data, int length, bool is_field_data) { auto file_meta = arrow_reader->parquet_reader()->metadata(); // dim is unused for sparse float vector - dim_ = - (IsVectorDataType(column_type_) && - !IsVectorArrayDataType(column_type_) && - !IsSparseFloatVectorDataType(column_type_)) - ? GetDimensionFromFileMetaData( - file_meta->schema()->Column(column_index), column_type_) - : 1; + // For nullable vectors, dim is stored in Arrow schema metadata + if (IsVectorDataType(column_type_) && + !IsVectorArrayDataType(column_type_) && + !IsSparseFloatVectorDataType(column_type_)) { + if (nullable_) { + std::shared_ptr arrow_schema; + auto st = arrow_reader->GetSchema(&arrow_schema); + AssertInfo(st.ok(), "Failed to get arrow schema"); + AssertInfo(arrow_schema->num_fields() == 1, + "Vector field should have exactly 1 field, got {}", + arrow_schema->num_fields()); + + auto field = arrow_schema->field(0); + if (field->HasMetadata()) { + auto metadata = field->metadata(); + if (metadata->Contains(DIM_KEY)) { + auto dim_str = metadata->Get(DIM_KEY).ValueOrDie(); + dim_ = std::stoi(dim_str); + AssertInfo( + dim_ > 0, + "nullable vector dim must be positive, got {}", + dim_); + } else { + ThrowInfo(DataTypeInvalid, + "nullable vector field metadata missing " + "required 'dim' field"); + } + } else { + ThrowInfo(DataTypeInvalid, + "nullable vector field is missing metadata"); + } + } else { + dim_ = GetDimensionFromFileMetaData( + file_meta->schema()->Column(column_index), column_type_); + } + } else { + dim_ = 1; + } // For VectorArray, get element type and dim from Arrow schema metadata auto element_type = DataType::NONE; @@ -133,8 +164,10 @@ PayloadReader::init(const uint8_t* data, int length, bool is_field_data) { field_data_->FillFieldData(array); } - AssertInfo(field_data_->IsFull(), - "field data hasn't been filled done"); + if (!nullable_ || !IsVectorDataType(column_type_)) { + AssertInfo(field_data_->IsFull(), + "field data hasn't been filled done"); + } } else { arrow_reader_ = std::move(arrow_reader); record_batch_reader_ = std::move(rb_reader); diff --git a/internal/core/src/storage/PayloadWriter.cpp b/internal/core/src/storage/PayloadWriter.cpp index e6ce1a0cf3..200b195609 100644 --- a/internal/core/src/storage/PayloadWriter.cpp +++ b/internal/core/src/storage/PayloadWriter.cpp @@ -35,7 +35,6 @@ PayloadWriter::PayloadWriter(const DataType column_type, int dim, bool nullable) AssertInfo(column_type != DataType::VECTOR_SPARSE_U32_F32, "PayloadWriter for Sparse Float Vector should be created " "using the constructor without dimension"); - AssertInfo(nullable == false, "only scalcar type support null now"); init_dimension(dim); } @@ -63,7 +62,7 @@ PayloadWriter::init_dimension(int dim) { } dimension_ = dim; - builder_ = CreateArrowBuilder(column_type_, element_type_, dim); + builder_ = CreateArrowBuilder(column_type_, element_type_, dim, nullable_); schema_ = CreateArrowSchema(column_type_, dim, nullable_); } @@ -112,8 +111,10 @@ PayloadWriter::finish() { std::shared_ptr arrow_properties = parquet::default_arrow_writer_properties(); - if (column_type_ == DataType::VECTOR_ARRAY) { - // For VectorArray, we need to store schema metadata + if (column_type_ == DataType::VECTOR_ARRAY || + (nullable_ && IsVectorDataType(column_type_) && + !IsSparseFloatVectorDataType(column_type_))) { + // For VectorArray and nullable vectors, we need to store schema metadata parquet::ArrowWriterProperties::Builder arrow_props_builder; arrow_props_builder.store_schema(); arrow_properties = arrow_props_builder.build(); diff --git a/internal/core/src/storage/Util.cpp b/internal/core/src/storage/Util.cpp index e2cd5165c4..8d83e36df3 100644 --- a/internal/core/src/storage/Util.cpp +++ b/internal/core/src/storage/Util.cpp @@ -128,13 +128,40 @@ ReadMediumType(BinlogReaderPtr reader) { void add_vector_payload(std::shared_ptr builder, uint8_t* values, - int length) { + const uint8_t* valid_data, + bool nullable, + int length, + int byte_width) { AssertInfo(builder != nullptr, "empty arrow builder"); - auto binary_builder = - std::dynamic_pointer_cast(builder); - auto ast = binary_builder->AppendValues(values, length); - AssertInfo( - ast.ok(), "append value to arrow builder failed: {}", ast.ToString()); + AssertInfo((nullable && valid_data) || !nullable, + "valid_data is required for nullable vectors"); + arrow::Status ast; + + if (nullable) { + auto binary_builder = + std::dynamic_pointer_cast(builder); + int valid_index = 0; + for (int i = 0; i < length; ++i) { + auto bit = (valid_data[i >> 3] >> (i & 0x07)) & 1; + if (bit) { + ast = binary_builder->Append(values + valid_index * byte_width, + byte_width); + valid_index++; + } else { + ast = binary_builder->AppendNull(); + } + AssertInfo(ast.ok(), + "append value to arrow builder failed: {}", + ast.ToString()); + } + } else { + auto binary_builder = + std::dynamic_pointer_cast(builder); + ast = binary_builder->AppendValues(values, length); + AssertInfo(ast.ok(), + "append value to arrow builder failed: {}", + ast.ToString()); + } } // append values for numeric data @@ -223,12 +250,64 @@ AddPayloadToArrowBuilder(std::shared_ptr builder, break; } - case DataType::VECTOR_FLOAT16: - case DataType::VECTOR_BFLOAT16: - case DataType::VECTOR_BINARY: - case DataType::VECTOR_INT8: case DataType::VECTOR_FLOAT: { - add_vector_payload(builder, const_cast(raw_data), length); + AssertInfo(payload.dimension.has_value(), + "dimension is required for VECTOR_FLOAT"); + int byte_width = payload.dimension.value() * sizeof(float); + add_vector_payload(builder, + const_cast(raw_data), + payload.valid_data, + nullable, + length, + byte_width); + break; + } + case DataType::VECTOR_BINARY: { + AssertInfo(payload.dimension.has_value(), + "dimension is required for VECTOR_BINARY"); + int byte_width = (payload.dimension.value() + 7) / 8; + add_vector_payload(builder, + const_cast(raw_data), + payload.valid_data, + nullable, + length, + byte_width); + break; + } + case DataType::VECTOR_FLOAT16: { + AssertInfo(payload.dimension.has_value(), + "dimension is required for VECTOR_FLOAT16"); + int byte_width = payload.dimension.value() * 2; + add_vector_payload(builder, + const_cast(raw_data), + payload.valid_data, + nullable, + length, + byte_width); + break; + } + case DataType::VECTOR_BFLOAT16: { + AssertInfo(payload.dimension.has_value(), + "dimension is required for VECTOR_BFLOAT16"); + int byte_width = payload.dimension.value() * 2; + add_vector_payload(builder, + const_cast(raw_data), + payload.valid_data, + nullable, + length, + byte_width); + break; + } + case DataType::VECTOR_INT8: { + AssertInfo(payload.dimension.has_value(), + "dimension is required for VECTOR_INT8"); + int byte_width = payload.dimension.value() * sizeof(int8_t); + add_vector_payload(builder, + const_cast(raw_data), + payload.valid_data, + nullable, + length, + byte_width); break; } case DataType::VECTOR_SPARSE_U32_F32: { @@ -380,30 +459,48 @@ CreateArrowBuilder(DataType data_type) { } std::shared_ptr -CreateArrowBuilder(DataType data_type, DataType element_type, int dim) { +CreateArrowBuilder(DataType data_type, + DataType element_type, + int dim, + bool nullable) { switch (static_cast(data_type)) { case DataType::VECTOR_FLOAT: { AssertInfo(dim > 0, "invalid dim value: {}", dim); + if (nullable) { + return std::make_shared(); + } return std::make_shared( arrow::fixed_size_binary(dim * sizeof(float))); } case DataType::VECTOR_BINARY: { AssertInfo(dim % 8 == 0 && dim > 0, "invalid dim value: {}", dim); + if (nullable) { + return std::make_shared(); + } return std::make_shared( arrow::fixed_size_binary(dim / 8)); } case DataType::VECTOR_FLOAT16: { AssertInfo(dim > 0, "invalid dim value: {}", dim); + if (nullable) { + return std::make_shared(); + } return std::make_shared( arrow::fixed_size_binary(dim * sizeof(float16))); } case DataType::VECTOR_BFLOAT16: { AssertInfo(dim > 0, "invalid dim value"); + if (nullable) { + return std::make_shared(); + } return std::make_shared( arrow::fixed_size_binary(dim * sizeof(bfloat16))); } case DataType::VECTOR_INT8: { AssertInfo(dim > 0, "invalid dim value"); + if (nullable) { + return std::make_shared(); + } return std::make_shared( arrow::fixed_size_binary(dim * sizeof(int8))); } @@ -576,6 +673,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) { switch (static_cast(data_type)) { case DataType::VECTOR_FLOAT: { AssertInfo(dim > 0, "invalid dim value: {}", dim); + if (nullable) { + auto metadata = std::shared_ptr( + new arrow::KeyValueMetadata()); + metadata->Append(DIM_KEY, std::to_string(dim)); + return arrow::schema( + {arrow::field("val", arrow::binary(), nullable, metadata)}); + } return arrow::schema( {arrow::field("val", arrow::fixed_size_binary(dim * sizeof(float)), @@ -583,11 +687,25 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) { } case DataType::VECTOR_BINARY: { AssertInfo(dim % 8 == 0 && dim > 0, "invalid dim value: {}", dim); + if (nullable) { + auto metadata = std::shared_ptr( + new arrow::KeyValueMetadata()); + metadata->Append(DIM_KEY, std::to_string(dim)); + return arrow::schema( + {arrow::field("val", arrow::binary(), nullable, metadata)}); + } return arrow::schema({arrow::field( "val", arrow::fixed_size_binary(dim / 8), nullable)}); } case DataType::VECTOR_FLOAT16: { AssertInfo(dim > 0, "invalid dim value: {}", dim); + if (nullable) { + auto metadata = std::shared_ptr( + new arrow::KeyValueMetadata()); + metadata->Append(DIM_KEY, std::to_string(dim)); + return arrow::schema( + {arrow::field("val", arrow::binary(), nullable, metadata)}); + } return arrow::schema( {arrow::field("val", arrow::fixed_size_binary(dim * sizeof(float16)), @@ -595,6 +713,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) { } case DataType::VECTOR_BFLOAT16: { AssertInfo(dim > 0, "invalid dim value"); + if (nullable) { + auto metadata = std::shared_ptr( + new arrow::KeyValueMetadata()); + metadata->Append(DIM_KEY, std::to_string(dim)); + return arrow::schema( + {arrow::field("val", arrow::binary(), nullable, metadata)}); + } return arrow::schema( {arrow::field("val", arrow::fixed_size_binary(dim * sizeof(bfloat16)), @@ -606,6 +731,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) { } case DataType::VECTOR_INT8: { AssertInfo(dim > 0, "invalid dim value"); + if (nullable) { + auto metadata = std::shared_ptr( + new arrow::KeyValueMetadata()); + metadata->Append(DIM_KEY, std::to_string(dim)); + return arrow::schema( + {arrow::field("val", arrow::binary(), nullable, metadata)}); + } return arrow::schema( {arrow::field("val", arrow::fixed_size_binary(dim * sizeof(int8)), @@ -1103,22 +1235,22 @@ CreateFieldData(const DataType& type, type, nullable, total_num_rows); case DataType::VECTOR_FLOAT: return std::make_shared>( - dim, type, total_num_rows); + dim, type, nullable, total_num_rows); case DataType::VECTOR_BINARY: return std::make_shared>( - dim, type, total_num_rows); + dim, type, nullable, total_num_rows); case DataType::VECTOR_FLOAT16: return std::make_shared>( - dim, type, total_num_rows); + dim, type, nullable, total_num_rows); case DataType::VECTOR_BFLOAT16: return std::make_shared>( - dim, type, total_num_rows); + dim, type, nullable, total_num_rows); case DataType::VECTOR_SPARSE_U32_F32: return std::make_shared>( - type, total_num_rows); + type, nullable, total_num_rows); case DataType::VECTOR_INT8: return std::make_shared>( - dim, type, total_num_rows); + dim, type, nullable, total_num_rows); case DataType::VECTOR_ARRAY: return std::make_shared>( dim, element_type, total_num_rows); diff --git a/internal/core/src/storage/Util.h b/internal/core/src/storage/Util.h index 0a74612db3..b67ca4a665 100644 --- a/internal/core/src/storage/Util.h +++ b/internal/core/src/storage/Util.h @@ -59,7 +59,10 @@ std::shared_ptr CreateArrowBuilder(DataType data_type); std::shared_ptr -CreateArrowBuilder(DataType data_type, DataType element_type, int dim); +CreateArrowBuilder(DataType data_type, + DataType element_type, + int dim, + bool nullable = false); /// \brief Utility function to create arrow:Scalar from FieldMeta.default_value /// diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h index 7ec8187a71..46a7d8e423 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -326,6 +326,10 @@ GenerateRandomSparseFloatVector(size_t rows, size_t cols = kTestSparseDim, float density = kTestSparseVectorDensity, int seed = 42) { + if (rows == 0) { + return std::make_unique< + knowhere::sparse::SparseRow[]>(0); + } int32_t num_elements = static_cast(rows * cols * density); std::mt19937 rng(seed); @@ -542,7 +546,8 @@ DataGen(SchemaPtr schema, int group_count = 1, bool random_pk = false, bool random_val = true, - bool random_valid = false) { + bool random_valid = false, + int null_percent = 50) { using std::vector; std::default_random_engine random(seed); std::normal_distribution<> distr(0, 1); @@ -635,41 +640,138 @@ DataGen(SchemaPtr schema, return data; }; + auto generate_valid_data = [&](const FieldMeta& field_meta, int64_t N) { + struct Result { + int64_t valid_count; + FixedVector valid_data; + }; + + Result result; + result.valid_data.resize(N); + result.valid_count = 0; + + bool is_nullable = field_meta.is_nullable(); + if (is_nullable) { + for (int i = 0; i < N; ++i) { + if (random_valid) { + int x = rand(); + result.valid_data[i] = x % 2 == 0 ? true : false; + } else { + result.valid_data[i] = (i % 100) >= null_percent; + } + if (result.valid_data[i]) { + result.valid_count++; + } + } + } else { + result.valid_count = N; + } + + return result; + }; + for (auto field_id : schema->get_field_ids()) { auto field_meta = schema->operator[](field_id); switch (field_meta.get_data_type()) { case DataType::VECTOR_FLOAT: { - auto data = generate_float_vector(field_meta, N); - insert_cols(data, N, field_meta, random_valid); + auto [valid_count, valid_data] = + generate_valid_data(field_meta, N); + bool is_nullable = field_meta.is_nullable(); + + auto data = generate_float_vector(field_meta, valid_count); + auto array = milvus::segcore::CreateVectorDataArrayFrom( + data.data(), + is_nullable ? valid_data.data() : nullptr, + N, + valid_count, + field_meta); + insert_data->mutable_fields_data()->AddAllocated( + array.release()); break; } case DataType::VECTOR_BINARY: { - auto data = generate_binary_vector(field_meta, N); - insert_cols(data, N, field_meta, random_valid); + auto [valid_count, valid_data] = + generate_valid_data(field_meta, N); + bool is_nullable = field_meta.is_nullable(); + + auto data = generate_binary_vector(field_meta, valid_count); + auto array = milvus::segcore::CreateVectorDataArrayFrom( + data.data(), + is_nullable ? valid_data.data() : nullptr, + N, + valid_count, + field_meta); + insert_data->mutable_fields_data()->AddAllocated( + array.release()); break; } case DataType::VECTOR_FLOAT16: { - auto data = generate_float16_vector(field_meta, N); - insert_cols(data, N, field_meta, random_valid); + auto [valid_count, valid_data] = + generate_valid_data(field_meta, N); + bool is_nullable = field_meta.is_nullable(); + + auto data = generate_float16_vector(field_meta, valid_count); + auto array = milvus::segcore::CreateVectorDataArrayFrom( + data.data(), + is_nullable ? valid_data.data() : nullptr, + N, + valid_count, + field_meta); + insert_data->mutable_fields_data()->AddAllocated( + array.release()); break; } case DataType::VECTOR_BFLOAT16: { - auto data = generate_bfloat16_vector(field_meta, N); - insert_cols(data, N, field_meta, random_valid); + auto [valid_count, valid_data] = + generate_valid_data(field_meta, N); + bool is_nullable = field_meta.is_nullable(); + + auto data = generate_bfloat16_vector(field_meta, valid_count); + auto array = milvus::segcore::CreateVectorDataArrayFrom( + data.data(), + is_nullable ? valid_data.data() : nullptr, + N, + valid_count, + field_meta); + insert_data->mutable_fields_data()->AddAllocated( + array.release()); break; } case DataType::VECTOR_SPARSE_U32_F32: { - auto res = GenerateRandomSparseFloatVector( - N, kTestSparseDim, kTestSparseVectorDensity, seed); - auto array = milvus::segcore::CreateDataArrayFrom( - res.get(), nullptr, N, field_meta); + auto [valid_count, valid_data] = + generate_valid_data(field_meta, N); + bool is_nullable = field_meta.is_nullable(); + + auto res = + GenerateRandomSparseFloatVector(valid_count, + kTestSparseDim, + kTestSparseVectorDensity, + seed); + + auto array = milvus::segcore::CreateVectorDataArrayFrom( + res.get(), + is_nullable ? valid_data.data() : nullptr, + N, + valid_count, + field_meta); insert_data->mutable_fields_data()->AddAllocated( array.release()); break; } case DataType::VECTOR_INT8: { - auto data = generate_int8_vector(field_meta, N); - insert_cols(data, N, field_meta, random_valid); + auto [valid_count, valid_data] = + generate_valid_data(field_meta, N); + bool is_nullable = field_meta.is_nullable(); + + auto data = generate_int8_vector(field_meta, valid_count); + auto array = milvus::segcore::CreateVectorDataArrayFrom( + data.data(), + is_nullable ? valid_data.data() : nullptr, + N, + valid_count, + field_meta); + insert_data->mutable_fields_data()->AddAllocated( + array.release()); break; } diff --git a/internal/core/unittest/test_utils/storage_test_utils.h b/internal/core/unittest/test_utils/storage_test_utils.h index 99718c79ad..e4fbf3f3b7 100644 --- a/internal/core/unittest/test_utils/storage_test_utils.h +++ b/internal/core/unittest/test_utils/storage_test_utils.h @@ -157,13 +157,13 @@ PrepareSingleFieldInsertBinlog(int64_t collection_id, int64_t row_count = 0; for (auto i = 0; i < field_datas.size(); ++i) { auto& field_data = field_datas[i]; - row_count += field_data->Length(); + row_count += field_data->get_num_rows(); auto file = "./data/test/" + std::to_string(collection_id) + "/" + std::to_string(partition_id) + "/" + std::to_string(segment_id) + "/" + std::to_string(field_id) + "/" + std::to_string(i); files.push_back(file); - row_counts.push_back(field_data->Length()); + row_counts.push_back(field_data->get_num_rows()); auto payload_reader = std::make_shared(field_data); auto insert_data = std::make_shared(payload_reader); diff --git a/internal/flushcommon/writebuffer/insert_buffer_test.go b/internal/flushcommon/writebuffer/insert_buffer_test.go index a68d994231..6e514f6cec 100644 --- a/internal/flushcommon/writebuffer/insert_buffer_test.go +++ b/internal/flushcommon/writebuffer/insert_buffer_test.go @@ -141,7 +141,7 @@ func (s *InsertBufferSuite) TestBuffer() { memSize := insertBuffer.Buffer(groups[0], &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) s.EqualValues(100, insertBuffer.MinTimestamp()) - s.EqualValues(5367, memSize) + s.EqualValues(5376, memSize) } func (s *InsertBufferSuite) TestYield() { diff --git a/internal/flushcommon/writebuffer/l0_write_buffer_test.go b/internal/flushcommon/writebuffer/l0_write_buffer_test.go index a470430c47..99c2c43c42 100644 --- a/internal/flushcommon/writebuffer/l0_write_buffer_test.go +++ b/internal/flushcommon/writebuffer/l0_write_buffer_test.go @@ -195,12 +195,12 @@ func (s *L0WriteBufferSuite) TestBufferData() { value, err := metrics.DataNodeFlowGraphBufferDataSize.GetMetricWithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(s.metacache.Collection())) s.NoError(err) - s.MetricsEqual(value, 5607) + s.MetricsEqual(value, 5616) delMsg = s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) })) err = wb.BufferData([]*InsertData{}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200}) s.NoError(err) - s.MetricsEqual(value, 5847) + s.MetricsEqual(value, 5856) }) } diff --git a/internal/proxy/msg_pack.go b/internal/proxy/msg_pack.go index c28d1e6cce..9811948e86 100644 --- a/internal/proxy/msg_pack.go +++ b/internal/proxy/msg_pack.go @@ -65,11 +65,15 @@ func genInsertMsgsByPartition(ctx context.Context, return msg } + fieldsData := insertMsg.GetFieldsData() + idxComputer := typeutil.NewFieldDataIdxComputer(fieldsData) + repackedMsgs := make([]msgstream.TsMsg, 0) requestSize := 0 msg := createInsertMsg(segmentID, channelName) for _, offset := range rowOffsets { - curRowMessageSize, err := typeutil.EstimateEntitySize(insertMsg.GetFieldsData(), offset) + fieldIdxs := idxComputer.Compute(int64(offset)) + curRowMessageSize, err := typeutil.EstimateEntitySize(fieldsData, offset, fieldIdxs...) if err != nil { return nil, err } @@ -81,7 +85,7 @@ func genInsertMsgsByPartition(ctx context.Context, requestSize = 0 } - typeutil.AppendFieldData(msg.FieldsData, insertMsg.GetFieldsData(), int64(offset)) + typeutil.AppendFieldData(msg.FieldsData, fieldsData, int64(offset), fieldIdxs...) msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset]) msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset]) msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset]) diff --git a/internal/proxy/search_reduce_util.go b/internal/proxy/search_reduce_util.go index 5517e61a1c..0d435acc44 100644 --- a/internal/proxy/search_reduce_util.go +++ b/internal/proxy/search_reduce_util.go @@ -258,6 +258,11 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error()) } + idxComputers := make([]*typeutil.FieldDataIdxComputer, subSearchNum) + for i, srd := range subSearchResultData { + idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData) + } + var realTopK int64 = -1 var retSize int64 @@ -316,7 +321,8 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData for _, groupEntity := range groupEntities { subResData := subSearchResultData[groupEntity.subSearchIdx] if len(ret.Results.FieldsData) > 0 { - retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx) + fieldIdxs := idxComputers[groupEntity.subSearchIdx].Compute(groupEntity.resultIdx) + retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx, fieldIdxs...) } typeutil.AppendPKs(ret.Results.Ids, groupEntity.id) ret.Results.Scores = append(ret.Results.Scores, groupEntity.score) @@ -424,6 +430,12 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData [] subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1] } } + + idxComputers := make([]*typeutil.FieldDataIdxComputer, subSearchNum) + for i, srd := range subSearchResultData { + idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData) + } + maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() // reducing nq * topk results for i := int64(0); i < nq; i++ { @@ -456,7 +468,9 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData [] score := subSearchResultData[subSearchIdx].Scores[resultDataIdx] if len(ret.Results.FieldsData) > 0 { - retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx) + fieldsData := subSearchResultData[subSearchIdx].FieldsData + fieldIdxs := idxComputers[subSearchIdx].Compute(resultDataIdx) + retSize += typeutil.AppendFieldData(ret.Results.FieldsData, fieldsData, resultDataIdx, fieldIdxs...) } typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx)) ret.Results.Scores = append(ret.Results.Scores, score) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index b82d1f4c2b..d511f629ba 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -563,9 +563,6 @@ func (t *addCollectionFieldTask) PreExecute(ctx context.Context) error { return merr.WrapErrParameterInvalid("valid field", fmt.Sprintf("field data type: %s is not supported", t.fieldSchema.GetDataType())) } - if typeutil.IsVectorType(t.fieldSchema.DataType) { - return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("not support to add vector field, field name = %s", t.fieldSchema.Name)) - } if funcutil.SliceContain([]string{common.RowIDFieldName, common.TimeStampFieldName, common.MetaFieldName, common.NamespaceFieldName}, t.fieldSchema.GetName()) { return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("not support to add system field, field name = %s", t.fieldSchema.Name)) } @@ -575,6 +572,17 @@ func (t *addCollectionFieldTask) PreExecute(ctx context.Context) error { if !t.fieldSchema.Nullable { return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("added field must be nullable, please check it, field name = %s", t.fieldSchema.Name)) } + if typeutil.IsVectorType(t.fieldSchema.DataType) && t.fieldSchema.Nullable { + if t.fieldSchema.DataType == schemapb.DataType_FloatVector || + t.fieldSchema.DataType == schemapb.DataType_Float16Vector || + t.fieldSchema.DataType == schemapb.DataType_BFloat16Vector || + t.fieldSchema.DataType == schemapb.DataType_BinaryVector || + t.fieldSchema.DataType == schemapb.DataType_Int8Vector { + if len(t.fieldSchema.TypeParams) == 0 { + return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("vector field must have dimension specified, field name = %s", t.fieldSchema.Name)) + } + } + } if t.fieldSchema.AutoID { return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("only primary field can speficy AutoID with true, field name = %s", t.fieldSchema.Name)) } diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index d3ed5b6548..a0c91bdab8 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -790,6 +790,10 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re } cursors := make([]int64, len(validRetrieveResults)) + idxComputers := make([]*typeutil.FieldDataIdxComputer, len(validRetrieveResults)) + for i, vr := range validRetrieveResults { + idxComputers[i] = typeutil.NewFieldDataIdxComputer(vr.GetFieldsData()) + } if queryParams != nil && queryParams.limit != typeutil.Unlimited { // IReduceInOrderForBest will try to get as many results as possible @@ -819,7 +823,8 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) { break } - retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) + fieldIdxs := idxComputers[sel].Compute(cursors[sel]) + retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel], fieldIdxs...) // limit retrieve result to avoid oom if retSize > maxOutputSize { diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index f445865c48..3e3ff2054b 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -1009,16 +1009,30 @@ func TestAddFieldTask(t *testing.T) { assert.Error(t, err) assert.ErrorIs(t, err, merr.ErrParameterInvalid) - // not support vector field fSchema = &schemapb.FieldSchema{ + Name: "vec_field", DataType: schemapb.DataType_FloatVector, + Nullable: true, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "128"}, + }, } bytes, err = proto.Marshal(fSchema) assert.NoError(t, err) task.Schema = bytes err = task.PreExecute(ctx) - assert.Error(t, err) - assert.ErrorIs(t, err, merr.ErrParameterInvalid) + assert.NoError(t, err) + + fSchema = &schemapb.FieldSchema{ + Name: "sparse_vec", + DataType: schemapb.DataType_SparseFloatVector, + Nullable: true, + } + bytes, err = proto.Marshal(fSchema) + assert.NoError(t, err) + task.Schema = bytes + err = task.PreExecute(ctx) + assert.NoError(t, err) // not support system field fSchema = &schemapb.FieldSchema{ @@ -2595,11 +2609,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { }) t.Run("upsert", func(t *testing.T) { - // upsert require pk unique in same batch - hash := make([]uint32, nb) - for i := 0; i < nb; i++ { - hash[i] = uint32(i) - } + hash := testutils.GenerateHashKeys(nb) task := &upsertTask{ upsertMsg: &msgstream.UpsertMsg{ InsertMsg: &BaseInsertTask{ diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 3af8e1fe72..a748e9f3a7 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -392,6 +392,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error { } baseIdx := 0 + idxComputer := typeutil.NewFieldDataIdxComputer(existFieldData) for _, idx := range updateIdxInUpsert { typeutil.AppendIDs(it.deletePKs, upsertIDs, idx) oldPK := typeutil.GetPK(upsertIDs, int64(idx)) @@ -399,7 +400,8 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error { if !ok { return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping") } - typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex)) + fieldIdxs := idxComputer.Compute(int64(existIndex)) + typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex), fieldIdxs...) err := typeutil.UpdateFieldData(it.insertFieldData, it.upsertMsg.InsertMsg.GetFieldsData(), int64(baseIdx), int64(idx)) baseIdx += 1 if err != nil { @@ -438,8 +440,32 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error { insertWithNullField = append(insertWithNullField, fieldData) } } - for _, idx := range insertIdxInUpsert { - typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx)) + vectorIdxMap := make([][]int64, len(insertIdxInUpsert)) + for rowIdx, offset := range insertIdxInUpsert { + vectorIdxMap[rowIdx] = make([]int64, len(insertWithNullField)) + for fieldIdx := range insertWithNullField { + vectorIdxMap[rowIdx][fieldIdx] = int64(offset) + } + } + for fieldIdx, fieldData := range insertWithNullField { + validData := fieldData.GetValidData() + if len(validData) > 0 && typeutil.IsVectorType(fieldData.Type) { + dataIdx := int64(0) + rowIdx := 0 + for i := 0; i < len(validData) && rowIdx < len(insertIdxInUpsert); i++ { + if i == insertIdxInUpsert[rowIdx] { + vectorIdxMap[rowIdx][fieldIdx] = dataIdx + rowIdx++ + } + if validData[i] { + dataIdx++ + } + } + } + } + + for rowIdx, idx := range insertIdxInUpsert { + typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx), vectorIdxMap[rowIdx]...) } } @@ -620,6 +646,10 @@ func ToCompressedFormatNullable(field *schemapb.FieldData) error { return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String())) } + case *schemapb.FieldData_Vectors: + // Vector data is already in compressed format, skip + return nil + default: return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String())) } @@ -1077,7 +1107,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error { } } - // deduplicate upsert data to handle duplicate primary keys in the same batch + // check for duplicate primary keys in the same batch primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema) if err != nil { log.Warn("fail to get primary field schema", zap.Error(err)) diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index b6c9df5c06..c2d55acddd 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -141,7 +141,7 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, helper *typeutil.Sch return err } case schemapb.DataType_SparseFloatVector: - if err := v.checkSparseFloatFieldData(field, fieldSchema); err != nil { + if err := v.checkSparseFloatVectorFieldData(field, fieldSchema); err != nil { return err } case schemapb.DataType_Int8Vector: @@ -219,6 +219,13 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil msg := fmt.Sprintf("the dim (%d) of field data(%s) is not equal to schema dim (%d)", dataDim, fieldName, schemaDim) return merr.WrapErrParameterInvalid(schemaDim, dataDim, msg) } + getExpectedVectorRows := func(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) uint64 { + validData := field.GetValidData() + if fieldSchema.GetNullable() && len(validData) > 0 { + return uint64(getValidNumber(validData)) + } + return numRows + } for _, field := range data { switch field.GetType() { case schemapb.DataType_FloatVector: @@ -241,7 +248,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return errDimMismatch(field.GetFieldName(), dataDim, dim) } - if n != numRows { + expectedRows := getExpectedVectorRows(field, f) + if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -265,7 +273,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } - if n != numRows { + expectedRows := getExpectedVectorRows(field, f) + if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -289,7 +298,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } - if n != numRows { + expectedRows := getExpectedVectorRows(field, f) + if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -313,13 +323,19 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } - if n != numRows { + expectedRows := getExpectedVectorRows(field, f) + if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } case schemapb.DataType_SparseFloatVector: + f, err := schema.GetFieldFromName(field.GetFieldName()) + if err != nil { + return err + } n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents)) - if n != numRows { + expectedRows := getExpectedVectorRows(field, f) + if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -343,7 +359,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return errDimMismatch(field.GetFieldName(), dataDim, dim) } - if n != numRows { + expectedRows := getExpectedVectorRows(field, f) + if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -728,7 +745,7 @@ func getValidNumber(validData []bool) int { func (v *validateUtil) checkFloatVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { floatArray := field.GetVectors().GetFloatVector().GetData() - if floatArray == nil { + if floatArray == nil && !fieldSchema.GetNullable() { msg := fmt.Sprintf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName()) return merr.WrapErrParameterInvalid("need float vector", "got nil", msg) } @@ -743,8 +760,11 @@ func (v *validateUtil) checkFloatVectorFieldData(field *schemapb.FieldData, fiel func (v *validateUtil) checkFloat16VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { float16VecArray := field.GetVectors().GetFloat16Vector() if float16VecArray == nil { - msg := fmt.Sprintf("float16 float field '%v' is illegal, nil Vector_Float16 type", field.GetFieldName()) - return merr.WrapErrParameterInvalid("need vector_float16 array", "got nil", msg) + if !fieldSchema.GetNullable() { + msg := fmt.Sprintf("float16 vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need float16 vector", "got nil", msg) + } + return nil } if v.checkNAN { return typeutil.VerifyFloats16(float16VecArray) @@ -755,8 +775,11 @@ func (v *validateUtil) checkFloat16VectorFieldData(field *schemapb.FieldData, fi func (v *validateUtil) checkBFloat16VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { bfloat16VecArray := field.GetVectors().GetBfloat16Vector() if bfloat16VecArray == nil { - msg := fmt.Sprintf("bfloat16 float field '%v' is illegal, nil Vector_BFloat16 type", field.GetFieldName()) - return merr.WrapErrParameterInvalid("need vector_bfloat16 array", "got nil", msg) + if !fieldSchema.GetNullable() { + msg := fmt.Sprintf("bfloat16 vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need bfloat16 vector", "got nil", msg) + } + return nil } if v.checkNAN { return typeutil.VerifyBFloats16(bfloat16VecArray) @@ -766,31 +789,33 @@ func (v *validateUtil) checkBFloat16VectorFieldData(field *schemapb.FieldData, f func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { bVecArray := field.GetVectors().GetBinaryVector() - if bVecArray == nil { - msg := fmt.Sprintf("binary float vector field '%v' is illegal, array type mismatch", field.GetFieldName()) - return merr.WrapErrParameterInvalid("need bytes array", "got nil", msg) + if bVecArray == nil && !fieldSchema.GetNullable() { + msg := fmt.Sprintf("binary vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need binary vector", "got nil", msg) } return nil } -func (v *validateUtil) checkSparseFloatFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { +func (v *validateUtil) checkSparseFloatVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { if field.GetVectors() == nil || field.GetVectors().GetSparseFloatVector() == nil { - msg := fmt.Sprintf("sparse float field '%v' is illegal, nil SparseFloatVector", field.GetFieldName()) - return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg) + if !fieldSchema.GetNullable() { + msg := fmt.Sprintf("sparse float vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need sparse float vector", "got nil", msg) + } + return nil } sparseRows := field.GetVectors().GetSparseFloatVector().GetContents() - if sparseRows == nil { - msg := fmt.Sprintf("sparse float field '%v' is illegal, array type mismatch", field.GetFieldName()) - return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg) - } return typeutil.ValidateSparseFloatRows(sparseRows...) } func (v *validateUtil) checkInt8VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { int8VecArray := field.GetVectors().GetInt8Vector() if int8VecArray == nil { - msg := fmt.Sprintf("int8 vector field '%v' is illegal, nil Vector_Int8 type", field.GetFieldName()) - return merr.WrapErrParameterInvalid("need vector_int8 array", "got nil", msg) + if !fieldSchema.GetNullable() { + msg := fmt.Sprintf("int8 vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need int8 vector", "got nil", msg) + } + return nil } return nil } diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go index e70e6ff50c..4d015926d4 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/validate_util_test.go @@ -310,23 +310,77 @@ func Test_validateUtil_checkTextFieldData(t *testing.T) { } func Test_validateUtil_checkBinaryVectorFieldData(t *testing.T) { - v := newValidateUtil() - assert.Error(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)) - assert.NoError(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: 128, - Data: &schemapb.VectorField_BinaryVector{ - BinaryVector: []byte(strings.Repeat("1", 128)), + t.Run("not binary vector", func(t *testing.T) { + v := newValidateUtil() + err := v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + v := newValidateUtil() + err := v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 128, + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: []byte(strings.Repeat("1", 128)), + }, }, - }, - }}, nil)) + }}, nil) + assert.NoError(t, err) + }) + + t.Run("nil vector not nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_BinaryVector, + Nullable: false, + } + v := newValidateUtil() + err := v.checkBinaryVectorFieldData(data, schema) + assert.Error(t, err) + }) + + t.Run("nil vector nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_BinaryVector, + Nullable: true, + } + v := newValidateUtil() + err := v.checkBinaryVectorFieldData(data, schema) + assert.NoError(t, err) + }) } func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) { + nb := 5 + dim := int64(8) + data := testutils.GenerateFloatVectors(nb, int(dim)) + invalidData := testutils.GenerateFloatVectorsWithInvalidData(nb, int(dim)) + t.Run("not float vector", func(t *testing.T) { - f := &schemapb.FieldData{} v := newValidateUtil() - err := v.checkFloatVectorFieldData(f, nil) + err := v.checkFloatVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil) assert.Error(t, err) }) @@ -336,7 +390,7 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) { Vectors: &schemapb.VectorField{ Data: &schemapb.VectorField_FloatVector{ FloatVector: &schemapb.FloatArray{ - Data: []float32{1.1, 2.2}, + Data: invalidData, }, }, }, @@ -354,7 +408,7 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) { Vectors: &schemapb.VectorField{ Data: &schemapb.VectorField_FloatVector{ FloatVector: &schemapb.FloatArray{ - Data: []float32{float32(math.NaN())}, + Data: invalidData, }, }, }, @@ -371,7 +425,7 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) { Vectors: &schemapb.VectorField{ Data: &schemapb.VectorField_FloatVector{ FloatVector: &schemapb.FloatArray{ - Data: []float32{1.1, 2.2}, + Data: data, }, }, }, @@ -409,6 +463,49 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) { err = v.fillWithValue(data, h, 1) assert.Error(t, err) }) + + t.Run("nil vector not nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_FloatVector, + Nullable: false, + } + v := newValidateUtil() + err := v.checkFloatVectorFieldData(data, schema) + assert.Error(t, err) + }) + + t.Run("nil vector nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_FloatVector, + Nullable: true, + } + + v := newValidateUtil() + err := v.checkFloatVectorFieldData(data, schema) + assert.NoError(t, err) + }) } func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) { @@ -418,9 +515,8 @@ func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) { invalidData := testutils.GenerateFloat16VectorsWithInvalidData(nb, int(dim)) t.Run("not float16 vector", func(t *testing.T) { - f := &schemapb.FieldData{} v := newValidateUtil() - err := v.checkFloat16VectorFieldData(f, nil) + err := v.checkFloat16VectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil) assert.Error(t, err) }) @@ -500,17 +596,60 @@ func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) { err = v.fillWithValue(data, h, 1) assert.Error(t, err) }) + + t.Run("nil vector not nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_Float16Vector, + Nullable: false, + } + v := newValidateUtil() + err := v.checkFloat16VectorFieldData(data, schema) + assert.Error(t, err) + }) + + t.Run("nil vector nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_Float16Vector, + Nullable: true, + } + + v := newValidateUtil() + err := v.checkFloat16VectorFieldData(data, schema) + assert.NoError(t, err) + }) } -func Test_validateUtil_checkBfloatVectorFieldData(t *testing.T) { +func Test_validateUtil_checkBFloat16VectorFieldData(t *testing.T) { nb := 5 dim := int64(8) - data := testutils.GenerateFloat16Vectors(nb, int(dim)) + data := testutils.GenerateBFloat16Vectors(nb, int(dim)) invalidData := testutils.GenerateBFloat16VectorsWithInvalidData(nb, int(dim)) - t.Run("not float vector", func(t *testing.T) { - f := &schemapb.FieldData{} + + t.Run("not bfloat16 vector", func(t *testing.T) { v := newValidateUtil() - err := v.checkBFloat16VectorFieldData(f, nil) + err := v.checkBFloat16VectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil) assert.Error(t, err) }) @@ -590,6 +729,203 @@ func Test_validateUtil_checkBfloatVectorFieldData(t *testing.T) { err = v.fillWithValue(data, h, 1) assert.Error(t, err) }) + + t.Run("nil vector not nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_BFloat16Vector, + Nullable: false, + } + v := newValidateUtil() + err := v.checkBFloat16VectorFieldData(data, schema) + assert.Error(t, err) + }) + + t.Run("nil vector nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_BFloat16Vector, + Nullable: true, + } + + v := newValidateUtil() + err := v.checkBFloat16VectorFieldData(data, schema) + assert.NoError(t, err) + }) +} + +func Test_validateUtil_checkSparseFloatVectorFieldData(t *testing.T) { + nb := 5 + sparseContents, dim := testutils.GenerateSparseFloatVectorsData(nb) + + t.Run("not sparse float vector", func(t *testing.T) { + v := newValidateUtil() + err := v.checkSparseFloatVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + fieldData := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_SparseFloatVector{ + SparseFloatVector: &schemapb.SparseFloatArray{ + Contents: sparseContents, + Dim: dim, + }, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_SparseFloatVector, + } + v := newValidateUtil() + err := v.checkSparseFloatVectorFieldData(fieldData, schema) + assert.NoError(t, err) + }) + + t.Run("nil vector not nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_SparseFloatVector{ + SparseFloatVector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_SparseFloatVector, + Nullable: false, + } + v := newValidateUtil() + err := v.checkSparseFloatVectorFieldData(data, schema) + assert.Error(t, err) + }) + + t.Run("nil vector nullable", func(t *testing.T) { + data := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_SparseFloatVector{ + SparseFloatVector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_SparseFloatVector, + Nullable: true, + } + + v := newValidateUtil() + err := v.checkSparseFloatVectorFieldData(data, schema) + assert.NoError(t, err) + }) +} + +func Test_validateUtil_checkInt8VectorFieldData(t *testing.T) { + nb := 5 + dim := int64(8) + data := typeutil.Int8ArrayToBytes(testutils.GenerateInt8Vectors(nb, int(dim))) + + t.Run("not int8 vector", func(t *testing.T) { + v := newValidateUtil() + err := v.checkInt8VectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + fieldData := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 128, + Data: &schemapb.VectorField_Int8Vector{ + Int8Vector: data, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_Int8Vector, + } + v := newValidateUtil() + err := v.checkInt8VectorFieldData(fieldData, schema) + assert.NoError(t, err) + }) + + t.Run("nil vector not nullable", func(t *testing.T) { + fieldData := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Int8Vector{ + Int8Vector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_Int8Vector, + Nullable: false, + } + v := newValidateUtil() + err := v.checkInt8VectorFieldData(fieldData, schema) + assert.Error(t, err) + }) + + t.Run("nil vector nullable", func(t *testing.T) { + fieldData := &schemapb.FieldData{ + FieldName: "vec", + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_Int8Vector{ + Int8Vector: nil, + }, + }, + }, + } + schema := &schemapb.FieldSchema{ + Name: "vec", + DataType: schemapb.DataType_Int8Vector, + Nullable: true, + } + + v := newValidateUtil() + err := v.checkInt8VectorFieldData(fieldData, schema) + assert.NoError(t, err) + }) } func Test_validateUtil_checkAligned(t *testing.T) { diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index faed762cdb..02ff760d94 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -325,6 +325,11 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna idTsMap := make(map[interface{}]int64) cursors := make([]int64, len(validRetrieveResults)) + idxComputers := make([]*typeutil.FieldDataIdxComputer, len(validRetrieveResults)) + for i, vr := range validRetrieveResults { + idxComputers[i] = typeutil.NewFieldDataIdxComputer(vr.Result.GetFieldsData()) + } + var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; { @@ -335,9 +340,11 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel]) ts := validRetrieveResults[sel].Timestamps[cursors[sel]] + fieldsData := validRetrieveResults[sel].Result.GetFieldsData() + fieldIdxs := idxComputers[sel].Compute(cursors[sel]) if _, ok := idTsMap[pk]; !ok { typeutil.AppendPKs(ret.Ids, pk) - retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel]) + retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, cursors[sel], fieldIdxs...) idTsMap[pk] = ts j++ } else { @@ -346,7 +353,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna if ts != 0 && ts > idTsMap[pk] { idTsMap[pk] = ts typeutil.DeleteFieldData(ret.FieldsData) - retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel]) + retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, cursors[sel], fieldIdxs...) } } @@ -511,10 +518,17 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore _, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData") defer span2.End() ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].Result.GetFieldsData(), int64(len(selections))) - // cursors = make([]int64, len(validRetrieveResults)) + + idxComputers := make([]*typeutil.FieldDataIdxComputer, len(validRetrieveResults)) + for i, vr := range validRetrieveResults { + idxComputers[i] = typeutil.NewFieldDataIdxComputer(vr.Result.GetFieldsData()) + } + for _, selection := range selections { // cannot use `cursors[sel]` directly, since some of them may be skipped. - retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[selection.batchIndex].Result.GetFieldsData(), selection.resultIndex) + fieldsData := validRetrieveResults[selection.batchIndex].Result.GetFieldsData() + fieldIdxs := idxComputers[selection.batchIndex].Compute(selection.resultIndex) + retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, selection.resultIndex, fieldIdxs...) // limit retrieve result to avoid oom if retSize > maxOutputSize { @@ -564,10 +578,18 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore _, span3 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData") defer span3.End() + + idxComputers := make([]*typeutil.FieldDataIdxComputer, len(segmentResults)) + for i, r := range segmentResults { + idxComputers[i] = typeutil.NewFieldDataIdxComputer(r.GetFieldsData()) + } + // retrieve result is compacted, use 0,1,2...end segmentResOffset := make([]int64, len(segmentResults)) for _, selection := range selections { - retSize += typeutil.AppendFieldData(ret.FieldsData, segmentResults[selection.batchIndex].GetFieldsData(), segmentResOffset[selection.batchIndex]) + fieldsData := segmentResults[selection.batchIndex].GetFieldsData() + fieldIdxs := idxComputers[selection.batchIndex].Compute(segmentResOffset[selection.batchIndex]) + retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, segmentResOffset[selection.batchIndex], fieldIdxs...) segmentResOffset[selection.batchIndex]++ // limit retrieve result to avoid oom if retSize > maxOutputSize { diff --git a/internal/querynodev2/segments/search_reduce.go b/internal/querynodev2/segments/search_reduce.go index 430c6c3ff6..3b68a4fc19 100644 --- a/internal/querynodev2/segments/search_reduce.go +++ b/internal/querynodev2/segments/search_reduce.go @@ -68,6 +68,11 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc ret.AllSearchCount += searchResultData[i].GetAllSearchCount() } + idxComputers := make([]*typeutil.FieldDataIdxComputer, len(searchResultData)) + for i, srd := range searchResultData { + idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData) + } + var skipDupCnt int64 var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() @@ -87,7 +92,9 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc // remove duplicates if _, ok := idSet[id]; !ok { - retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx) + fieldsData := searchResultData[sel].FieldsData + fieldIdxs := idxComputers[sel].Compute(idx) + retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, idx, fieldIdxs...) typeutil.AppendPKs(ret.Ids, id) ret.Scores = append(ret.Scores, score) if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil { @@ -173,6 +180,11 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error()) } + idxComputers := make([]*typeutil.FieldDataIdxComputer, len(searchResultData)) + for i, srd := range searchResultData { + idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData) + } + var filteredCount int64 var retSize int64 maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() @@ -208,7 +220,9 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear // exceed the limit for each group, filter this entity filteredCount++ } else { - retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx) + fieldsData := searchResultData[sel].FieldsData + fieldIdxs := idxComputers[sel].Compute(idx) + retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, idx, fieldIdxs...) typeutil.AppendPKs(ret.Ids, id) ret.Scores = append(ret.Scores, score) if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil { diff --git a/internal/rootcoord/create_collection_task_test.go b/internal/rootcoord/create_collection_task_test.go index b20a10602a..de250c8586 100644 --- a/internal/rootcoord/create_collection_task_test.go +++ b/internal/rootcoord/create_collection_task_test.go @@ -667,14 +667,14 @@ func Test_createCollectionTask_validateSchema(t *testing.T) { DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector, Nullable: true, + TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}, }, }, }, }, } err := task.validateSchema(context.TODO(), schema) - assert.Error(t, err) - assert.Contains(t, err.Error(), "vector type not support null") + assert.NoError(t, err) }) t.Run("struct array field - field with default value", func(t *testing.T) { @@ -980,7 +980,7 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) { assert.Error(t, err) }) - t.Run("vector type not support null", func(t *testing.T) { + t.Run("vector type with nullable", func(t *testing.T) { collectionName := funcutil.GenRandomStr() field1 := funcutil.GenRandomStr() schema := &schemapb.CollectionSchema{ @@ -989,9 +989,17 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) { AutoID: false, Fields: []*schemapb.FieldSchema{ { - Name: field1, - DataType: 101, - Nullable: true, + FieldID: 100, + Name: "pk", + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + }, + { + FieldID: 101, + Name: field1, + DataType: schemapb.DataType_FloatVector, + Nullable: true, + TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}, }, }, } @@ -1005,7 +1013,7 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) { }, } err := task.prepareSchema(context.TODO()) - assert.Error(t, err) + assert.NoError(t, err) }) } diff --git a/internal/rootcoord/util.go b/internal/rootcoord/util.go index 96cd9a8f1b..b340ffad33 100644 --- a/internal/rootcoord/util.go +++ b/internal/rootcoord/util.go @@ -374,10 +374,6 @@ func checkFieldSchema(fieldSchemas []*schemapb.FieldSchema) error { msg := fmt.Sprintf("ArrayOfVector is only supported in struct array field, type:%s, name:%s", fieldSchema.GetDataType().String(), fieldSchema.GetName()) return merr.WrapErrParameterInvalidMsg(msg) } - if fieldSchema.GetNullable() && typeutil.IsVectorType(fieldSchema.GetDataType()) { - msg := fmt.Sprintf("vector type not support null, type:%s, name:%s", fieldSchema.GetDataType().String(), fieldSchema.GetName()) - return merr.WrapErrParameterInvalidMsg(msg) - } if fieldSchema.GetNullable() && fieldSchema.IsPrimaryKey { msg := fmt.Sprintf("primary field not support null, type:%s, name:%s", fieldSchema.GetDataType().String(), fieldSchema.GetName()) return merr.WrapErrParameterInvalidMsg(msg) @@ -502,11 +498,6 @@ func checkStructArrayFieldSchema(schemas []*schemapb.StructArrayFieldSchema) err field.DataType.String(), field.ElementType.String(), field.Name) return merr.WrapErrParameterInvalidMsg(msg) } - if field.GetNullable() && typeutil.IsVectorType(field.ElementType) { - msg := fmt.Sprintf("vector type not support null, data type:%s, element type:%s, name:%s", - field.DataType.String(), field.ElementType.String(), field.Name) - return merr.WrapErrParameterInvalidMsg(msg) - } if field.GetDefaultValue() != nil { msg := fmt.Sprintf("fields in struct array field not support default_value, data type:%s, element type:%s, name:%s", field.DataType.String(), field.ElementType.String(), field.Name) diff --git a/internal/storage/arrow_util.go b/internal/storage/arrow_util.go index 81ad0325cd..b50d8d65af 100644 --- a/internal/storage/arrow_util.go +++ b/internal/storage/arrow_util.go @@ -391,8 +391,12 @@ func NewRecordBuilder(schema *schemapb.CollectionSchema) *RecordBuilder { if field.DataType == schemapb.DataType_ArrayOfVector { elementType = field.GetElementType() } - arrowType := serdeMap[field.DataType].arrowType(int(dim), elementType) - builders[i] = array.NewBuilder(memory.DefaultAllocator, arrowType) + if field.GetNullable() && typeutil.IsVectorType(field.DataType) && !typeutil.IsSparseFloatVectorType(field.DataType) { + builders[i] = array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary) + } else { + arrowType := serdeMap[field.DataType].arrowType(int(dim), elementType) + builders[i] = array.NewBuilder(memory.DefaultAllocator, arrowType) + } } return &RecordBuilder{ diff --git a/internal/storage/data_codec.go b/internal/storage/data_codec.go index 6e94131485..9b8a069afe 100644 --- a/internal/storage/data_codec.go +++ b/internal/storage/data_codec.go @@ -448,19 +448,19 @@ func AddFieldDataToPayload(eventWriter *insertEventWriter, dataType schemapb.Dat } } case schemapb.DataType_BinaryVector: - if err = eventWriter.AddBinaryVectorToPayload(singleData.(*BinaryVectorFieldData).Data, singleData.(*BinaryVectorFieldData).Dim); err != nil { + if err = eventWriter.AddBinaryVectorToPayload(singleData.(*BinaryVectorFieldData).Data, singleData.(*BinaryVectorFieldData).Dim, singleData.(*BinaryVectorFieldData).ValidData); err != nil { return err } case schemapb.DataType_FloatVector: - if err = eventWriter.AddFloatVectorToPayload(singleData.(*FloatVectorFieldData).Data, singleData.(*FloatVectorFieldData).Dim); err != nil { + if err = eventWriter.AddFloatVectorToPayload(singleData.(*FloatVectorFieldData).Data, singleData.(*FloatVectorFieldData).Dim, singleData.(*FloatVectorFieldData).ValidData); err != nil { return err } case schemapb.DataType_Float16Vector: - if err = eventWriter.AddFloat16VectorToPayload(singleData.(*Float16VectorFieldData).Data, singleData.(*Float16VectorFieldData).Dim); err != nil { + if err = eventWriter.AddFloat16VectorToPayload(singleData.(*Float16VectorFieldData).Data, singleData.(*Float16VectorFieldData).Dim, singleData.(*Float16VectorFieldData).ValidData); err != nil { return err } case schemapb.DataType_BFloat16Vector: - if err = eventWriter.AddBFloat16VectorToPayload(singleData.(*BFloat16VectorFieldData).Data, singleData.(*BFloat16VectorFieldData).Dim); err != nil { + if err = eventWriter.AddBFloat16VectorToPayload(singleData.(*BFloat16VectorFieldData).Data, singleData.(*BFloat16VectorFieldData).Dim, singleData.(*BFloat16VectorFieldData).ValidData); err != nil { return err } case schemapb.DataType_SparseFloatVector: @@ -468,7 +468,7 @@ func AddFieldDataToPayload(eventWriter *insertEventWriter, dataType schemapb.Dat return err } case schemapb.DataType_Int8Vector: - if err = eventWriter.AddInt8VectorToPayload(singleData.(*Int8VectorFieldData).Data, singleData.(*Int8VectorFieldData).Dim); err != nil { + if err = eventWriter.AddInt8VectorToPayload(singleData.(*Int8VectorFieldData).Data, singleData.(*Int8VectorFieldData).Dim, singleData.(*Int8VectorFieldData).ValidData); err != nil { return err } case schemapb.DataType_ArrayOfVector: @@ -747,6 +747,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins return length, err } binaryVectorFieldData.Dim = dim + if validData != nil && len(validData) > 0 { + startLogical := len(binaryVectorFieldData.ValidData) + if binaryVectorFieldData.ValidData == nil { + binaryVectorFieldData.ValidData = make([]bool, 0, rowNum) + } + binaryVectorFieldData.ValidData = append(binaryVectorFieldData.ValidData, validData...) + binaryVectorFieldData.Nullable = true + binaryVectorFieldData.L2PMapping.Build(validData, startLogical, len(validData)) + } insertData.Data[fieldID] = binaryVectorFieldData return length, nil @@ -763,6 +772,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins return length, err } float16VectorFieldData.Dim = dim + if validData != nil && len(validData) > 0 { + startLogical := len(float16VectorFieldData.ValidData) + if float16VectorFieldData.ValidData == nil { + float16VectorFieldData.ValidData = make([]bool, 0, rowNum) + } + float16VectorFieldData.ValidData = append(float16VectorFieldData.ValidData, validData...) + float16VectorFieldData.Nullable = true + float16VectorFieldData.L2PMapping.Build(validData, startLogical, len(validData)) + } insertData.Data[fieldID] = float16VectorFieldData return length, nil @@ -779,6 +797,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins return length, err } bfloat16VectorFieldData.Dim = dim + if validData != nil && len(validData) > 0 { + startLogical := len(bfloat16VectorFieldData.ValidData) + if bfloat16VectorFieldData.ValidData == nil { + bfloat16VectorFieldData.ValidData = make([]bool, 0, rowNum) + } + bfloat16VectorFieldData.ValidData = append(bfloat16VectorFieldData.ValidData, validData...) + bfloat16VectorFieldData.Nullable = true + bfloat16VectorFieldData.L2PMapping.Build(validData, startLogical, len(validData)) + } insertData.Data[fieldID] = bfloat16VectorFieldData return length, nil @@ -795,6 +822,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins return 0, err } floatVectorFieldData.Dim = dim + if validData != nil && len(validData) > 0 { + startLogical := len(floatVectorFieldData.ValidData) + if floatVectorFieldData.ValidData == nil { + floatVectorFieldData.ValidData = make([]bool, 0, rowNum) + } + floatVectorFieldData.ValidData = append(floatVectorFieldData.ValidData, validData...) + floatVectorFieldData.Nullable = true + floatVectorFieldData.L2PMapping.Build(validData, startLogical, len(validData)) + } insertData.Data[fieldID] = floatVectorFieldData return length, nil @@ -805,6 +841,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins } vec := fieldData.(*SparseFloatVectorFieldData) vec.AppendAllRows(singleData) + if validData != nil && len(validData) > 0 { + startLogical := len(vec.ValidData) + if vec.ValidData == nil { + vec.ValidData = make([]bool, 0, rowNum) + } + vec.ValidData = append(vec.ValidData, validData...) + vec.Nullable = true + vec.L2PMapping.Build(validData, startLogical, len(validData)) + } insertData.Data[fieldID] = vec return singleData.RowNum(), nil @@ -821,6 +866,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins return 0, err } int8VectorFieldData.Dim = dim + if validData != nil && len(validData) > 0 { + startLogical := len(int8VectorFieldData.ValidData) + if int8VectorFieldData.ValidData == nil { + int8VectorFieldData.ValidData = make([]bool, 0, rowNum) + } + int8VectorFieldData.ValidData = append(int8VectorFieldData.ValidData, validData...) + int8VectorFieldData.Nullable = true + int8VectorFieldData.L2PMapping.Build(validData, startLogical, len(validData)) + } insertData.Data[fieldID] = int8VectorFieldData return length, nil diff --git a/internal/storage/data_codec_test.go b/internal/storage/data_codec_test.go index d2daa9f238..2364c3be78 100644 --- a/internal/storage/data_codec_test.go +++ b/internal/storage/data_codec_test.go @@ -35,30 +35,36 @@ import ( ) const ( - CollectionID = 1 - PartitionID = 1 - SegmentID = 1 - RowIDField = 0 - TimestampField = 1 - BoolField = 100 - Int8Field = 101 - Int16Field = 102 - Int32Field = 103 - Int64Field = 104 - FloatField = 105 - DoubleField = 106 - StringField = 107 - BinaryVectorField = 108 - FloatVectorField = 109 - ArrayField = 110 - JSONField = 111 - Float16VectorField = 112 - BFloat16VectorField = 113 - SparseFloatVectorField = 114 - Int8VectorField = 115 - StructField = 116 - StructSubInt32Field = 117 - StructSubFloatVectorField = 118 + CollectionID = 1 + PartitionID = 1 + SegmentID = 1 + RowIDField = 0 + TimestampField = 1 + BoolField = 100 + Int8Field = 101 + Int16Field = 102 + Int32Field = 103 + Int64Field = 104 + FloatField = 105 + DoubleField = 106 + StringField = 107 + BinaryVectorField = 108 + FloatVectorField = 109 + ArrayField = 110 + JSONField = 111 + Float16VectorField = 112 + BFloat16VectorField = 113 + SparseFloatVectorField = 114 + Int8VectorField = 115 + StructField = 116 + StructSubInt32Field = 117 + StructSubFloatVectorField = 118 + NullableFloatVectorField = 119 + NullableBinaryVectorField = 120 + NullableFloat16VectorField = 121 + NullableBFloat16VectorField = 122 + NullableInt8VectorField = 123 + NullableSparseFloatVectorField = 124 ) func assertTestData(t *testing.T, i int, value *Value) { @@ -284,26 +290,35 @@ func generateTestDataWithSeed(seed, num int) ([]*Blob, error) { 19: &JSONFieldData{Data: field19}, 101: &Int32FieldData{Data: field101}, 102: &FloatVectorFieldData{ - Data: field102, - Dim: 8, + Data: field102, + ValidData: nil, + Dim: 8, + Nullable: false, }, 103: &BinaryVectorFieldData{ - Data: field103, - Dim: 8, + Data: field103, + ValidData: nil, + Dim: 8, + Nullable: false, }, 104: &Float16VectorFieldData{ - Data: field104, - Dim: 8, + Data: field104, + ValidData: nil, + Dim: 8, + Nullable: false, }, 105: &BFloat16VectorFieldData{ - Data: field105, - Dim: 8, + Data: field105, + ValidData: nil, + Dim: 8, + Nullable: false, }, 106: &SparseFloatVectorFieldData{ SparseFloatArray: schemapb.SparseFloatArray{ Dim: 28433, Contents: field106, }, + Nullable: false, }, }} @@ -616,6 +631,79 @@ func genTestCollectionMeta() *etcdpb.CollectionMeta { }, }, }, + { + FieldID: NullableFloatVectorField, + Name: "field_nullable_float_vector", + Description: "nullable_float_vector", + DataType: schemapb.DataType_FloatVector, + Nullable: true, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "4", + }, + }, + }, + { + FieldID: NullableBinaryVectorField, + Name: "field_nullable_binary_vector", + Description: "nullable_binary_vector", + DataType: schemapb.DataType_BinaryVector, + Nullable: true, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + }, + { + FieldID: NullableFloat16VectorField, + Name: "field_nullable_float16_vector", + Description: "nullable_float16_vector", + DataType: schemapb.DataType_Float16Vector, + Nullable: true, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "4", + }, + }, + }, + { + FieldID: NullableBFloat16VectorField, + Name: "field_nullable_bfloat16_vector", + Description: "nullable_bfloat16_vector", + DataType: schemapb.DataType_BFloat16Vector, + Nullable: true, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "4", + }, + }, + }, + { + FieldID: NullableInt8VectorField, + Name: "field_nullable_int8_vector", + Description: "nullable_int8_vector", + DataType: schemapb.DataType_Int8Vector, + Nullable: true, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "4", + }, + }, + }, + { + FieldID: NullableSparseFloatVectorField, + Name: "field_nullable_sparse_float_vector", + Description: "nullable_sparse_float_vector", + DataType: schemapb.DataType_SparseFloatVector, + Nullable: true, + TypeParams: []*commonpb.KeyValuePair{}, + }, }, StructArrayFields: []*schemapb.StructArrayFieldSchema{ { @@ -649,63 +737,6 @@ func genTestCollectionMeta() *etcdpb.CollectionMeta { } } -func TestInsertCodecFailed(t *testing.T) { - t.Run("vector field not support null", func(t *testing.T) { - tests := []struct { - description string - dataType schemapb.DataType - }{ - {"nullable FloatVector field", schemapb.DataType_FloatVector}, - {"nullable Float16Vector field", schemapb.DataType_Float16Vector}, - {"nullable BinaryVector field", schemapb.DataType_BinaryVector}, - {"nullable BFloat16Vector field", schemapb.DataType_BFloat16Vector}, - {"nullable SparseFloatVector field", schemapb.DataType_SparseFloatVector}, - {"nullable Int8Vector field", schemapb.DataType_Int8Vector}, - } - - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { - schema := &etcdpb.CollectionMeta{ - ID: CollectionID, - CreateTime: 1, - SegmentIDs: []int64{SegmentID}, - PartitionTags: []string{"partition_0", "partition_1"}, - Schema: &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - Fields: []*schemapb.FieldSchema{ - { - FieldID: RowIDField, - Name: "row_id", - Description: "row_id", - DataType: schemapb.DataType_Int64, - }, - { - FieldID: TimestampField, - Name: "Timestamp", - Description: "Timestamp", - DataType: schemapb.DataType_Int64, - }, - { - DataType: test.dataType, - }, - }, - }, - } - insertCodec := NewInsertCodecWithSchema(schema) - insertDataEmpty := &InsertData{ - Data: map[int64]FieldData{ - RowIDField: &Int64FieldData{[]int64{}, nil, false}, - TimestampField: &Int64FieldData{[]int64{}, nil, false}, - }, - } - _, err := insertCodec.Serialize(PartitionID, SegmentID, insertDataEmpty) - assert.Error(t, err) - }) - } - }) -} - func TestInsertCodec(t *testing.T) { schema := genTestCollectionMeta() insertCodec := NewInsertCodecWithSchema(schema) @@ -742,12 +773,16 @@ func TestInsertCodec(t *testing.T) { Data: []string{"3", "4"}, }, BinaryVectorField: &BinaryVectorFieldData{ - Data: []byte{0, 255}, - Dim: 8, + Data: []byte{0, 255}, + ValidData: nil, + Dim: 8, + Nullable: false, }, FloatVectorField: &FloatVectorFieldData{ - Data: []float32{4, 5, 6, 7, 4, 5, 6, 7}, - Dim: 4, + Data: []float32{4, 5, 6, 7, 4, 5, 6, 7}, + ValidData: nil, + Dim: 4, + Nullable: false, }, ArrayField: &ArrayFieldData{ ElementType: schemapb.DataType_Int32, @@ -772,13 +807,17 @@ func TestInsertCodec(t *testing.T) { }, Float16VectorField: &Float16VectorFieldData{ // length = 2 * Dim * numRows(2) = 16 - Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, - Dim: 4, + Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, + ValidData: nil, + Dim: 4, + Nullable: false, }, BFloat16VectorField: &BFloat16VectorFieldData{ // length = 2 * Dim * numRows(2) = 16 - Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, - Dim: 4, + Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, + ValidData: nil, + Dim: 4, + Nullable: false, }, SparseFloatVectorField: &SparseFloatVectorFieldData{ SparseFloatArray: schemapb.SparseFloatArray{ @@ -789,10 +828,13 @@ func TestInsertCodec(t *testing.T) { typeutil.CreateSparseFloatRow([]uint32{100, 200, 599}, []float32{3.1, 3.2, 3.3}), }, }, + Nullable: false, }, Int8VectorField: &Int8VectorFieldData{ - Data: []int8{-4, -5, -6, -7, -4, -5, -6, -7}, - Dim: 4, + Data: []int8{-4, -5, -6, -7, -4, -5, -6, -7}, + ValidData: nil, + Dim: 4, + Nullable: false, }, StructSubInt32Field: &ArrayFieldData{ ElementType: schemapb.DataType_Int32, @@ -827,6 +869,36 @@ func TestInsertCodec(t *testing.T) { }, }, }, + NullableFloatVectorField: &FloatVectorFieldData{ + Data: []float32{4.0, 5.0, 6.0, 7.0}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableBinaryVectorField: &BinaryVectorFieldData{ + Data: []byte{255}, + ValidData: []bool{true, false}, + Dim: 8, + Nullable: true, + }, + NullableFloat16VectorField: &Float16VectorFieldData{ + Data: []byte{255, 0, 255, 0, 255, 0, 255, 0}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableBFloat16VectorField: &BFloat16VectorFieldData{ + Data: []byte{255, 0, 255, 0, 255, 0, 255, 0}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableInt8VectorField: &Int8VectorFieldData{ + Data: []int8{-4, -5, -6, -7}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, }, } @@ -863,22 +935,30 @@ func TestInsertCodec(t *testing.T) { Data: []string{"1", "2"}, }, BinaryVectorField: &BinaryVectorFieldData{ - Data: []byte{0, 255}, - Dim: 8, + Data: []byte{0, 255}, + ValidData: nil, + Dim: 8, + Nullable: false, }, FloatVectorField: &FloatVectorFieldData{ - Data: []float32{0, 1, 2, 3, 0, 1, 2, 3}, - Dim: 4, + Data: []float32{0, 1, 2, 3, 0, 1, 2, 3}, + ValidData: nil, + Dim: 4, + Nullable: false, }, Float16VectorField: &Float16VectorFieldData{ // length = 2 * Dim * numRows(2) = 16 - Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, - Dim: 4, + Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, + ValidData: nil, + Dim: 4, + Nullable: false, }, BFloat16VectorField: &BFloat16VectorFieldData{ // length = 2 * Dim * numRows(2) = 16 - Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, - Dim: 4, + Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255}, + ValidData: nil, + Dim: 4, + Nullable: false, }, SparseFloatVectorField: &SparseFloatVectorFieldData{ SparseFloatArray: schemapb.SparseFloatArray{ @@ -889,10 +969,13 @@ func TestInsertCodec(t *testing.T) { typeutil.CreateSparseFloatRow([]uint32{105, 207, 299}, []float32{3.1, 3.2, 3.3}), }, }, + Nullable: false, }, Int8VectorField: &Int8VectorFieldData{ - Data: []int8{0, 1, 2, 3, 0, 1, 2, 3}, - Dim: 4, + Data: []int8{0, 1, 2, 3, 0, 1, 2, 3}, + ValidData: nil, + Dim: 4, + Nullable: false, }, StructSubInt32Field: &ArrayFieldData{ ElementType: schemapb.DataType_Int32, @@ -927,6 +1010,46 @@ func TestInsertCodec(t *testing.T) { }, }, }, + NullableFloatVectorField: &FloatVectorFieldData{ + Data: []float32{0.0, 1.0, 2.0, 3.0}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableBinaryVectorField: &BinaryVectorFieldData{ + Data: []byte{0}, + ValidData: []bool{true, false}, + Dim: 8, + Nullable: true, + }, + NullableFloat16VectorField: &Float16VectorFieldData{ + Data: []byte{0, 255, 0, 255, 0, 255, 0, 255}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableBFloat16VectorField: &BFloat16VectorFieldData{ + Data: []byte{0, 255, 0, 255, 0, 255, 0, 255}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableInt8VectorField: &Int8VectorFieldData{ + Data: []int8{0, 1, 2, 3}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableSparseFloatVectorField: &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Dim: 300, + Contents: [][]byte{ + typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}), + }, + }, + ValidData: []bool{true, false}, + Nullable: true, + }, ArrayField: &ArrayFieldData{ ElementType: schemapb.DataType_Int32, Data: []*schemapb.ScalarField{ @@ -953,27 +1076,84 @@ func TestInsertCodec(t *testing.T) { insertDataEmpty := &InsertData{ Data: map[int64]FieldData{ - RowIDField: &Int64FieldData{[]int64{}, nil, false}, - TimestampField: &Int64FieldData{[]int64{}, nil, false}, - BoolField: &BoolFieldData{[]bool{}, nil, false}, - Int8Field: &Int8FieldData{[]int8{}, nil, false}, - Int16Field: &Int16FieldData{[]int16{}, nil, false}, - Int32Field: &Int32FieldData{[]int32{}, nil, false}, - Int64Field: &Int64FieldData{[]int64{}, nil, false}, - FloatField: &FloatFieldData{[]float32{}, nil, false}, - DoubleField: &DoubleFieldData{[]float64{}, nil, false}, - StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false}, - BinaryVectorField: &BinaryVectorFieldData{[]byte{}, 8}, - FloatVectorField: &FloatVectorFieldData{[]float32{}, 4}, - Float16VectorField: &Float16VectorFieldData{[]byte{}, 4}, - BFloat16VectorField: &BFloat16VectorFieldData{[]byte{}, 4}, + RowIDField: &Int64FieldData{[]int64{}, nil, false}, + TimestampField: &Int64FieldData{[]int64{}, nil, false}, + BoolField: &BoolFieldData{[]bool{}, nil, false}, + Int8Field: &Int8FieldData{[]int8{}, nil, false}, + Int16Field: &Int16FieldData{[]int16{}, nil, false}, + Int32Field: &Int32FieldData{[]int32{}, nil, false}, + Int64Field: &Int64FieldData{[]int64{}, nil, false}, + FloatField: &FloatFieldData{[]float32{}, nil, false}, + DoubleField: &DoubleFieldData{[]float64{}, nil, false}, + StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false}, + BinaryVectorField: &BinaryVectorFieldData{ + Data: []byte{}, + ValidData: nil, + Dim: 8, + Nullable: false, + }, + FloatVectorField: &FloatVectorFieldData{ + Data: []float32{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, + Float16VectorField: &Float16VectorFieldData{ + Data: []byte{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, + BFloat16VectorField: &BFloat16VectorFieldData{ + Data: []byte{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, SparseFloatVectorField: &SparseFloatVectorFieldData{ SparseFloatArray: schemapb.SparseFloatArray{ Dim: 0, Contents: [][]byte{}, }, + ValidData: nil, + Nullable: false, + }, + Int8VectorField: &Int8VectorFieldData{ + Data: []int8{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, + NullableFloatVectorField: &FloatVectorFieldData{ + Data: []float32{}, + ValidData: []bool{}, + Dim: 4, + Nullable: true, + }, + NullableBinaryVectorField: &BinaryVectorFieldData{ + Data: []byte{}, + ValidData: []bool{}, + Dim: 8, + Nullable: true, + }, + NullableFloat16VectorField: &Float16VectorFieldData{ + Data: []byte{}, + ValidData: []bool{}, + Dim: 4, + Nullable: true, + }, + NullableBFloat16VectorField: &BFloat16VectorFieldData{ + Data: []byte{}, + ValidData: []bool{}, + Dim: 4, + Nullable: true, + }, + NullableInt8VectorField: &Int8VectorFieldData{ + Data: []int8{}, + ValidData: []bool{}, + Dim: 4, + Nullable: true, }, - Int8VectorField: &Int8VectorFieldData{[]int8{}, 4}, StructSubInt32Field: &ArrayFieldData{schemapb.DataType_Int32, []*schemapb.ScalarField{}, nil, false}, ArrayField: &ArrayFieldData{schemapb.DataType_Int32, []*schemapb.ScalarField{}, nil, false}, JSONField: &JSONFieldData{[][]byte{}, nil, false}, @@ -1321,24 +1501,74 @@ func TestMemorySize(t *testing.T) { Data: []string{"3"}, }, BinaryVectorField: &BinaryVectorFieldData{ - Data: []byte{0}, - Dim: 8, + Data: []byte{0}, + ValidData: nil, + Dim: 8, + Nullable: false, }, FloatVectorField: &FloatVectorFieldData{ - Data: []float32{4, 5, 6, 7}, - Dim: 4, + Data: []float32{4, 5, 6, 7}, + ValidData: nil, + Dim: 4, + Nullable: false, }, Float16VectorField: &Float16VectorFieldData{ - Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7}, - Dim: 4, + Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7}, + ValidData: nil, + Dim: 4, + Nullable: false, }, BFloat16VectorField: &BFloat16VectorFieldData{ - Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7}, - Dim: 4, + Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7}, + ValidData: nil, + Dim: 4, + Nullable: false, }, Int8VectorField: &Int8VectorFieldData{ - Data: []int8{4, 5, 6, 7}, - Dim: 4, + Data: []int8{4, 5, 6, 7}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, + NullableFloatVectorField: &FloatVectorFieldData{ + Data: []float32{4.0, 5.0, 6.0, 7.0}, + ValidData: []bool{true}, + Dim: 4, + Nullable: true, + }, + NullableBinaryVectorField: &BinaryVectorFieldData{ + Data: []byte{255}, + ValidData: []bool{true}, + Dim: 8, + Nullable: true, + }, + NullableFloat16VectorField: &Float16VectorFieldData{ + Data: []byte{0xff, 0x0, 0xff, 0x0, 0xff, 0x0, 0xff, 0x0}, + ValidData: []bool{true}, + Dim: 4, + Nullable: true, + }, + NullableBFloat16VectorField: &BFloat16VectorFieldData{ + Data: []byte{0xff, 0x0, 0xff, 0x0, 0xff, 0x0, 0xff, 0x0}, + ValidData: []bool{true}, + Dim: 4, + Nullable: true, + }, + NullableInt8VectorField: &Int8VectorFieldData{ + Data: []int8{4, 5, 6, 7}, + ValidData: []bool{true}, + Dim: 4, + Nullable: true, + }, + NullableSparseFloatVectorField: &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Dim: 300, + Contents: [][]byte{ + typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}), + }, + }, + ValidData: []bool{true}, + Nullable: true, }, ArrayField: &ArrayFieldData{ ElementType: schemapb.DataType_Int32, @@ -1389,15 +1619,21 @@ func TestMemorySize(t *testing.T) { assert.Equal(t, insertData1.Data[FloatField].GetMemorySize(), 5) assert.Equal(t, insertData1.Data[DoubleField].GetMemorySize(), 9) assert.Equal(t, insertData1.Data[StringField].GetMemorySize(), 18) - assert.Equal(t, insertData1.Data[BinaryVectorField].GetMemorySize(), 5) - assert.Equal(t, insertData1.Data[FloatVectorField].GetMemorySize(), 20) - assert.Equal(t, insertData1.Data[Float16VectorField].GetMemorySize(), 12) - assert.Equal(t, insertData1.Data[BFloat16VectorField].GetMemorySize(), 12) - assert.Equal(t, insertData1.Data[Int8VectorField].GetMemorySize(), 8) - assert.Equal(t, insertData1.Data[ArrayField].GetMemorySize(), 3*4+1) - assert.Equal(t, insertData1.Data[JSONField].GetMemorySize(), len([]byte(`{"batch":1}`))+16+1) - assert.Equal(t, insertData1.Data[StructSubInt32Field].GetMemorySize(), 4*4+1) - assert.Equal(t, insertData1.Data[StructSubFloatVectorField].GetMemorySize(), 4*4+4) + assert.Equal(t, insertData1.Data[BinaryVectorField].GetMemorySize(), 14) + assert.Equal(t, insertData1.Data[FloatVectorField].GetMemorySize(), 29) + assert.Equal(t, insertData1.Data[Float16VectorField].GetMemorySize(), 21) + assert.Equal(t, insertData1.Data[BFloat16VectorField].GetMemorySize(), 21) + assert.Equal(t, insertData1.Data[Int8VectorField].GetMemorySize(), 17) + assert.Equal(t, insertData1.Data[NullableFloatVectorField].GetMemorySize(), 30) + assert.Equal(t, insertData1.Data[NullableBinaryVectorField].GetMemorySize(), 15) + assert.Equal(t, insertData1.Data[NullableFloat16VectorField].GetMemorySize(), 22) + assert.Equal(t, insertData1.Data[NullableBFloat16VectorField].GetMemorySize(), 22) + assert.Equal(t, insertData1.Data[NullableInt8VectorField].GetMemorySize(), 18) + assert.Equal(t, insertData1.Data[NullableSparseFloatVectorField].GetMemorySize(), 39) + assert.Equal(t, insertData1.Data[ArrayField].GetMemorySize(), 13) + assert.Equal(t, insertData1.Data[JSONField].GetMemorySize(), 28) + assert.Equal(t, insertData1.Data[StructSubInt32Field].GetMemorySize(), 17) + assert.Equal(t, insertData1.Data[StructSubFloatVectorField].GetMemorySize(), 20) insertData2 := &InsertData{ Data: map[int64]FieldData{ @@ -1432,24 +1668,84 @@ func TestMemorySize(t *testing.T) { Data: []string{"1", "23"}, }, BinaryVectorField: &BinaryVectorFieldData{ - Data: []byte{0, 255}, - Dim: 8, + Data: []byte{0, 255}, + ValidData: nil, + Dim: 8, + Nullable: false, }, FloatVectorField: &FloatVectorFieldData{ - Data: []float32{0, 1, 2, 3, 0, 1, 2, 3}, - Dim: 4, + Data: []float32{0, 1, 2, 3, 0, 1, 2, 3}, + ValidData: nil, + Dim: 4, + Nullable: false, }, Float16VectorField: &Float16VectorFieldData{ - Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}, - Dim: 4, + Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}, + ValidData: nil, + Dim: 4, + Nullable: false, }, BFloat16VectorField: &BFloat16VectorFieldData{ - Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}, - Dim: 4, + Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7}, + ValidData: nil, + Dim: 4, + Nullable: false, }, Int8VectorField: &Int8VectorFieldData{ - Data: []int8{0, 1, 2, 3, 0, 1, 2, 3}, - Dim: 4, + Data: []int8{0, 1, 2, 3, 0, 1, 2, 3}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, + SparseFloatVectorField: &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Dim: 300, + Contents: [][]byte{ + typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{1.1, 1.2, 1.3}), + typeutil.CreateSparseFloatRow([]uint32{10, 20, 30}, []float32{2.1, 2.2, 2.3}), + }, + }, + Nullable: false, + }, + NullableFloatVectorField: &FloatVectorFieldData{ + Data: []float32{0.0, 1.0, 2.0, 3.0}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableBinaryVectorField: &BinaryVectorFieldData{ + Data: []byte{0}, + ValidData: []bool{true, false}, + Dim: 8, + Nullable: true, + }, + NullableFloat16VectorField: &Float16VectorFieldData{ + Data: []byte{0, 1, 2, 3, 4, 5, 6, 7}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableBFloat16VectorField: &BFloat16VectorFieldData{ + Data: []byte{0, 1, 2, 3, 4, 5, 6, 7}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableInt8VectorField: &Int8VectorFieldData{ + Data: []int8{0, 1, 2, 3}, + ValidData: []bool{true, false}, + Dim: 4, + Nullable: true, + }, + NullableSparseFloatVectorField: &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Dim: 300, + Contents: [][]byte{ + typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}), + }, + }, + ValidData: []bool{true, false}, + Nullable: true, }, }, } @@ -1464,29 +1760,107 @@ func TestMemorySize(t *testing.T) { assert.Equal(t, insertData2.Data[FloatField].GetMemorySize(), 9) assert.Equal(t, insertData2.Data[DoubleField].GetMemorySize(), 17) assert.Equal(t, insertData2.Data[StringField].GetMemorySize(), 36) - assert.Equal(t, insertData2.Data[BinaryVectorField].GetMemorySize(), 6) - assert.Equal(t, insertData2.Data[FloatVectorField].GetMemorySize(), 36) - assert.Equal(t, insertData2.Data[Float16VectorField].GetMemorySize(), 20) - assert.Equal(t, insertData2.Data[BFloat16VectorField].GetMemorySize(), 20) - assert.Equal(t, insertData2.Data[Int8VectorField].GetMemorySize(), 12) + assert.Equal(t, insertData2.Data[BinaryVectorField].GetMemorySize(), 15) + assert.Equal(t, insertData2.Data[FloatVectorField].GetMemorySize(), 45) + assert.Equal(t, insertData2.Data[Float16VectorField].GetMemorySize(), 29) + assert.Equal(t, insertData2.Data[BFloat16VectorField].GetMemorySize(), 29) + assert.Equal(t, insertData2.Data[Int8VectorField].GetMemorySize(), 21) + assert.Equal(t, insertData2.Data[SparseFloatVectorField].GetMemorySize(), 64) + assert.Equal(t, insertData2.Data[NullableBinaryVectorField].GetMemorySize(), 16) + assert.Equal(t, insertData2.Data[NullableFloatVectorField].GetMemorySize(), 31) + assert.Equal(t, insertData2.Data[NullableFloat16VectorField].GetMemorySize(), 23) + assert.Equal(t, insertData2.Data[NullableBFloat16VectorField].GetMemorySize(), 23) + assert.Equal(t, insertData2.Data[NullableInt8VectorField].GetMemorySize(), 19) + assert.Equal(t, insertData2.Data[NullableSparseFloatVectorField].GetMemorySize(), 40) insertDataEmpty := &InsertData{ Data: map[int64]FieldData{ - RowIDField: &Int64FieldData{[]int64{}, nil, false}, - TimestampField: &Int64FieldData{[]int64{}, nil, false}, - BoolField: &BoolFieldData{[]bool{}, nil, false}, - Int8Field: &Int8FieldData{[]int8{}, nil, false}, - Int16Field: &Int16FieldData{[]int16{}, nil, false}, - Int32Field: &Int32FieldData{[]int32{}, nil, false}, - Int64Field: &Int64FieldData{[]int64{}, nil, false}, - FloatField: &FloatFieldData{[]float32{}, nil, false}, - DoubleField: &DoubleFieldData{[]float64{}, nil, false}, - StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false}, - BinaryVectorField: &BinaryVectorFieldData{[]byte{}, 8}, - FloatVectorField: &FloatVectorFieldData{[]float32{}, 4}, - Float16VectorField: &Float16VectorFieldData{[]byte{}, 4}, - BFloat16VectorField: &BFloat16VectorFieldData{[]byte{}, 4}, - Int8VectorField: &Int8VectorFieldData{[]int8{}, 4}, + RowIDField: &Int64FieldData{[]int64{}, nil, false}, + TimestampField: &Int64FieldData{[]int64{}, nil, false}, + BoolField: &BoolFieldData{[]bool{}, nil, false}, + Int8Field: &Int8FieldData{[]int8{}, nil, false}, + Int16Field: &Int16FieldData{[]int16{}, nil, false}, + Int32Field: &Int32FieldData{[]int32{}, nil, false}, + Int64Field: &Int64FieldData{[]int64{}, nil, false}, + FloatField: &FloatFieldData{[]float32{}, nil, false}, + DoubleField: &DoubleFieldData{[]float64{}, nil, false}, + StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false}, + BinaryVectorField: &BinaryVectorFieldData{ + Data: []byte{}, + ValidData: nil, + Dim: 8, + Nullable: false, + }, + FloatVectorField: &FloatVectorFieldData{ + Data: []float32{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, + Float16VectorField: &Float16VectorFieldData{ + Data: []byte{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, + BFloat16VectorField: &BFloat16VectorFieldData{ + Data: []byte{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, + Int8VectorField: &Int8VectorFieldData{ + Data: []int8{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }, + SparseFloatVectorField: &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Dim: 0, + Contents: [][]byte{}, + }, + ValidData: nil, + Nullable: false, + }, + NullableFloatVectorField: &FloatVectorFieldData{ + Data: []float32{}, + ValidData: []bool{}, + Dim: 4, + Nullable: true, + }, + NullableBinaryVectorField: &BinaryVectorFieldData{ + Data: []byte{}, + ValidData: []bool{}, + Dim: 8, + Nullable: true, + }, + NullableFloat16VectorField: &Float16VectorFieldData{ + Data: []byte{}, + ValidData: []bool{}, + Dim: 4, + Nullable: true, + }, + NullableBFloat16VectorField: &BFloat16VectorFieldData{ + Data: []byte{}, + ValidData: []bool{}, + Dim: 4, + Nullable: true, + }, + NullableInt8VectorField: &Int8VectorFieldData{ + Data: []int8{}, + ValidData: []bool{}, + Dim: 4, + Nullable: true, + }, + NullableSparseFloatVectorField: &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Dim: 0, + Contents: [][]byte{}, + }, + ValidData: []bool{}, + Nullable: true, + }, StructSubFloatVectorField: &VectorArrayFieldData{ Dim: 2, ElementType: schemapb.DataType_FloatVector, @@ -1505,11 +1879,18 @@ func TestMemorySize(t *testing.T) { assert.Equal(t, insertDataEmpty.Data[FloatField].GetMemorySize(), 1) assert.Equal(t, insertDataEmpty.Data[DoubleField].GetMemorySize(), 1) assert.Equal(t, insertDataEmpty.Data[StringField].GetMemorySize(), 1) - assert.Equal(t, insertDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4) - assert.Equal(t, insertDataEmpty.Data[FloatVectorField].GetMemorySize(), 4) - assert.Equal(t, insertDataEmpty.Data[Float16VectorField].GetMemorySize(), 4) - assert.Equal(t, insertDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 4) - assert.Equal(t, insertDataEmpty.Data[Int8VectorField].GetMemorySize(), 4) + assert.Equal(t, insertDataEmpty.Data[BinaryVectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[FloatVectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[Float16VectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[Int8VectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[SparseFloatVectorField].GetMemorySize(), 9) + assert.Equal(t, insertDataEmpty.Data[NullableFloatVectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[NullableBinaryVectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[NullableFloat16VectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[NullableBFloat16VectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[NullableInt8VectorField].GetMemorySize(), 13) + assert.Equal(t, insertDataEmpty.Data[NullableSparseFloatVectorField].GetMemorySize(), 9) assert.Equal(t, insertDataEmpty.Data[StructSubFloatVectorField].GetMemorySize(), 0) } @@ -1589,22 +1970,49 @@ func TestAddFieldDataToPayload(t *testing.T) { assert.Error(t, err) err = AddFieldDataToPayload(e, schemapb.DataType_JSON, &JSONFieldData{[][]byte{[]byte(`"batch":2}`)}, nil, false}) assert.Error(t, err) - err = AddFieldDataToPayload(e, schemapb.DataType_BinaryVector, &BinaryVectorFieldData{[]byte{}, 8}) + err = AddFieldDataToPayload(e, schemapb.DataType_BinaryVector, &BinaryVectorFieldData{ + Data: []byte{}, + ValidData: nil, + Dim: 8, + Nullable: false, + }) assert.Error(t, err) - err = AddFieldDataToPayload(e, schemapb.DataType_FloatVector, &FloatVectorFieldData{[]float32{}, 4}) + err = AddFieldDataToPayload(e, schemapb.DataType_FloatVector, &FloatVectorFieldData{ + Data: []float32{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }) assert.Error(t, err) - err = AddFieldDataToPayload(e, schemapb.DataType_Float16Vector, &Float16VectorFieldData{[]byte{}, 4}) + err = AddFieldDataToPayload(e, schemapb.DataType_Float16Vector, &Float16VectorFieldData{ + Data: []byte{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }) assert.Error(t, err) - err = AddFieldDataToPayload(e, schemapb.DataType_BFloat16Vector, &BFloat16VectorFieldData{[]byte{}, 8}) + err = AddFieldDataToPayload(e, schemapb.DataType_BFloat16Vector, &BFloat16VectorFieldData{ + Data: []byte{}, + ValidData: nil, + Dim: 8, + Nullable: false, + }) assert.Error(t, err) err = AddFieldDataToPayload(e, schemapb.DataType_SparseFloatVector, &SparseFloatVectorFieldData{ SparseFloatArray: schemapb.SparseFloatArray{ Dim: 0, Contents: [][]byte{}, }, + ValidData: nil, + Nullable: false, }) assert.Error(t, err) - err = AddFieldDataToPayload(e, schemapb.DataType_Int8Vector, &Int8VectorFieldData{[]int8{}, 4}) + err = AddFieldDataToPayload(e, schemapb.DataType_Int8Vector, &Int8VectorFieldData{ + Data: []int8{}, + ValidData: nil, + Dim: 4, + Nullable: false, + }) assert.Error(t, err) err = AddFieldDataToPayload(e, schemapb.DataType_ArrayOfVector, &VectorArrayFieldData{ Dim: 2, diff --git a/internal/storage/insert_data.go b/internal/storage/insert_data.go index e026f2c255..654f169a79 100644 --- a/internal/storage/insert_data.go +++ b/internal/storage/insert_data.go @@ -195,70 +195,83 @@ func NewFieldData(dataType schemapb.DataType, fieldSchema *schemapb.FieldSchema, typeParams := fieldSchema.GetTypeParams() switch dataType { case schemapb.DataType_Float16Vector: - if fieldSchema.GetNullable() { - return nil, merr.WrapErrParameterInvalidMsg("vector not support null") - } dim, err := GetDimFromParams(typeParams) if err != nil { return nil, err } - return &Float16VectorFieldData{ - Data: make([]byte, 0, cap), - Dim: dim, - }, nil + data := &Float16VectorFieldData{ + Data: make([]byte, 0, cap), + Dim: dim, + Nullable: fieldSchema.GetNullable(), + } + if fieldSchema.GetNullable() { + data.ValidData = make([]bool, 0, cap) + } + return data, nil case schemapb.DataType_BFloat16Vector: - if fieldSchema.GetNullable() { - return nil, merr.WrapErrParameterInvalidMsg("vector not support null") - } dim, err := GetDimFromParams(typeParams) if err != nil { return nil, err } - return &BFloat16VectorFieldData{ - Data: make([]byte, 0, cap), - Dim: dim, - }, nil + data := &BFloat16VectorFieldData{ + Data: make([]byte, 0, cap), + Dim: dim, + Nullable: fieldSchema.GetNullable(), + } + if fieldSchema.GetNullable() { + data.ValidData = make([]bool, 0, cap) + } + return data, nil case schemapb.DataType_FloatVector: - if fieldSchema.GetNullable() { - return nil, merr.WrapErrParameterInvalidMsg("vector not support null") - } dim, err := GetDimFromParams(typeParams) if err != nil { return nil, err } - return &FloatVectorFieldData{ - Data: make([]float32, 0, cap), - Dim: dim, - }, nil + data := &FloatVectorFieldData{ + Data: make([]float32, 0, cap), + Dim: dim, + Nullable: fieldSchema.GetNullable(), + } + if fieldSchema.GetNullable() { + data.ValidData = make([]bool, 0, cap) + } + return data, nil case schemapb.DataType_BinaryVector: - if fieldSchema.GetNullable() { - return nil, merr.WrapErrParameterInvalidMsg("vector not support null") - } dim, err := GetDimFromParams(typeParams) if err != nil { return nil, err } - return &BinaryVectorFieldData{ - Data: make([]byte, 0, cap), - Dim: dim, - }, nil + data := &BinaryVectorFieldData{ + Data: make([]byte, 0, cap), + Dim: dim, + Nullable: fieldSchema.GetNullable(), + } + if fieldSchema.GetNullable() { + data.ValidData = make([]bool, 0, cap) + } + return data, nil case schemapb.DataType_SparseFloatVector: - if fieldSchema.GetNullable() { - return nil, merr.WrapErrParameterInvalidMsg("vector not support null") + data := &SparseFloatVectorFieldData{ + Nullable: fieldSchema.GetNullable(), } - return &SparseFloatVectorFieldData{}, nil + if fieldSchema.GetNullable() { + data.ValidData = make([]bool, 0, cap) + } + return data, nil case schemapb.DataType_Int8Vector: - if fieldSchema.GetNullable() { - return nil, merr.WrapErrParameterInvalidMsg("vector not support null") - } dim, err := GetDimFromParams(typeParams) if err != nil { return nil, err } - return &Int8VectorFieldData{ - Data: make([]int8, 0, cap), - Dim: dim, - }, nil + data := &Int8VectorFieldData{ + Data: make([]int8, 0, cap), + Dim: dim, + Nullable: fieldSchema.GetNullable(), + } + if fieldSchema.GetNullable() { + data.ValidData = make([]bool, 0, cap) + } + return data, nil case schemapb.DataType_Bool: data := &BoolFieldData{ Data: make([]bool, 0, cap), @@ -453,30 +466,97 @@ type GeometryFieldData struct { ValidData []bool Nullable bool } + +// LogicalToPhysicalMapping maps logical offset to physical offset for nullable vector +type LogicalToPhysicalMapping struct { + validCount int + l2pMap map[int]int +} + +func (m *LogicalToPhysicalMapping) GetPhysicalOffset(logicalOffset int) int { + if m.l2pMap == nil { + return logicalOffset + } + if physicalOffset, ok := m.l2pMap[logicalOffset]; ok { + return physicalOffset + } + return -1 +} + +func (m *LogicalToPhysicalMapping) GetMemorySize() int { + size := 8 // validCount int + size += len(m.l2pMap) * 16 // map[int]int, roughly 16 bytes per entry + return size +} + +func (m *LogicalToPhysicalMapping) GetValidCount() int { + return m.validCount +} + +func (m *LogicalToPhysicalMapping) Build(validData []bool, startLogical, totalCount int) { + if totalCount == 0 { + return + } + if len(validData) < totalCount { + return + } + + if m.l2pMap == nil { + m.l2pMap = make(map[int]int) + } + + physicalIdx := m.validCount + for i := 0; i < totalCount; i++ { + if validData[i] { + m.l2pMap[startLogical+i] = physicalIdx + physicalIdx++ + } + } + m.validCount = physicalIdx +} + type BinaryVectorFieldData struct { - Data []byte - Dim int + Data []byte + ValidData []bool + Dim int + Nullable bool + L2PMapping LogicalToPhysicalMapping } type FloatVectorFieldData struct { - Data []float32 - Dim int + Data []float32 + ValidData []bool + Dim int + Nullable bool + L2PMapping LogicalToPhysicalMapping } type Float16VectorFieldData struct { - Data []byte - Dim int + Data []byte + ValidData []bool + Dim int + Nullable bool + L2PMapping LogicalToPhysicalMapping } type BFloat16VectorFieldData struct { - Data []byte - Dim int + Data []byte + ValidData []bool + Dim int + Nullable bool + L2PMapping LogicalToPhysicalMapping } type SparseFloatVectorFieldData struct { schemapb.SparseFloatArray + ValidData []bool + Nullable bool + L2PMapping LogicalToPhysicalMapping } type Int8VectorFieldData struct { - Data []int8 - Dim int + Data []int8 + ValidData []bool + Dim int + Nullable bool + L2PMapping LogicalToPhysicalMapping } type VectorArrayFieldData struct { @@ -493,29 +573,71 @@ func (dst *SparseFloatVectorFieldData) AppendAllRows(src *SparseFloatVectorField dst.Dim = src.Dim } dst.Contents = append(dst.Contents, src.Contents...) + if src.Nullable { + if dst.ValidData == nil { + dst.ValidData = make([]bool, 0, len(src.ValidData)) + } + dst.L2PMapping.Build(src.ValidData, len(dst.ValidData), len(src.ValidData)) + dst.ValidData = append(dst.ValidData, src.ValidData...) + dst.Nullable = true + } } // RowNum implements FieldData.RowNum -func (data *BoolFieldData) RowNum() int { return len(data.Data) } -func (data *Int8FieldData) RowNum() int { return len(data.Data) } -func (data *Int16FieldData) RowNum() int { return len(data.Data) } -func (data *Int32FieldData) RowNum() int { return len(data.Data) } -func (data *Int64FieldData) RowNum() int { return len(data.Data) } -func (data *FloatFieldData) RowNum() int { return len(data.Data) } -func (data *DoubleFieldData) RowNum() int { return len(data.Data) } -func (data *TimestamptzFieldData) RowNum() int { return len(data.Data) } -func (data *StringFieldData) RowNum() int { return len(data.Data) } -func (data *ArrayFieldData) RowNum() int { return len(data.Data) } -func (data *JSONFieldData) RowNum() int { return len(data.Data) } -func (data *GeometryFieldData) RowNum() int { return len(data.Data) } -func (data *BinaryVectorFieldData) RowNum() int { return len(data.Data) * 8 / data.Dim } -func (data *FloatVectorFieldData) RowNum() int { return len(data.Data) / data.Dim } -func (data *Float16VectorFieldData) RowNum() int { return len(data.Data) / 2 / data.Dim } -func (data *BFloat16VectorFieldData) RowNum() int { +func (data *BoolFieldData) RowNum() int { return len(data.Data) } +func (data *Int8FieldData) RowNum() int { return len(data.Data) } +func (data *Int16FieldData) RowNum() int { return len(data.Data) } +func (data *Int32FieldData) RowNum() int { return len(data.Data) } +func (data *Int64FieldData) RowNum() int { return len(data.Data) } +func (data *FloatFieldData) RowNum() int { return len(data.Data) } +func (data *DoubleFieldData) RowNum() int { return len(data.Data) } +func (data *TimestamptzFieldData) RowNum() int { return len(data.Data) } +func (data *StringFieldData) RowNum() int { return len(data.Data) } +func (data *ArrayFieldData) RowNum() int { return len(data.Data) } +func (data *JSONFieldData) RowNum() int { return len(data.Data) } +func (data *GeometryFieldData) RowNum() int { return len(data.Data) } +func (data *BinaryVectorFieldData) RowNum() int { + if data.Nullable { + return len(data.ValidData) + } + return len(data.Data) * 8 / data.Dim +} + +func (data *FloatVectorFieldData) RowNum() int { + if data.Nullable { + return len(data.ValidData) + } + return len(data.Data) / data.Dim +} + +func (data *Float16VectorFieldData) RowNum() int { + if data.Nullable { + return len(data.ValidData) + } return len(data.Data) / 2 / data.Dim } -func (data *SparseFloatVectorFieldData) RowNum() int { return len(data.Contents) } -func (data *Int8VectorFieldData) RowNum() int { return len(data.Data) / data.Dim } + +func (data *BFloat16VectorFieldData) RowNum() int { + if data.Nullable { + return len(data.ValidData) + } + return len(data.Data) / 2 / data.Dim +} + +func (data *SparseFloatVectorFieldData) RowNum() int { + if data.Nullable { + return len(data.ValidData) + } + return len(data.Contents) +} + +func (data *Int8VectorFieldData) RowNum() int { + if data.Nullable { + return len(data.ValidData) + } + return len(data.Data) / data.Dim +} + func (data *VectorArrayFieldData) RowNum() int { return len(data.Data) } @@ -606,27 +728,51 @@ func (data *GeometryFieldData) GetRow(i int) any { } func (data *BinaryVectorFieldData) GetRow(i int) any { - return data.Data[i*data.Dim/8 : (i+1)*data.Dim/8] + if data.GetNullable() && !data.ValidData[i] { + return nil + } + physicalIdx := data.L2PMapping.GetPhysicalOffset(i) + return data.Data[physicalIdx*data.Dim/8 : (physicalIdx+1)*data.Dim/8] } func (data *SparseFloatVectorFieldData) GetRow(i int) interface{} { - return data.Contents[i] + if data.GetNullable() && !data.ValidData[i] { + return nil + } + physicalIdx := data.L2PMapping.GetPhysicalOffset(i) + return data.Contents[physicalIdx] } func (data *FloatVectorFieldData) GetRow(i int) interface{} { - return data.Data[i*data.Dim : (i+1)*data.Dim] + if data.GetNullable() && !data.ValidData[i] { + return nil + } + physicalIdx := data.L2PMapping.GetPhysicalOffset(i) + return data.Data[physicalIdx*data.Dim : (physicalIdx+1)*data.Dim] } func (data *Float16VectorFieldData) GetRow(i int) interface{} { - return data.Data[i*data.Dim*2 : (i+1)*data.Dim*2] + if data.GetNullable() && !data.ValidData[i] { + return nil + } + physicalIdx := data.L2PMapping.GetPhysicalOffset(i) + return data.Data[physicalIdx*data.Dim*2 : (physicalIdx+1)*data.Dim*2] } func (data *BFloat16VectorFieldData) GetRow(i int) interface{} { - return data.Data[i*data.Dim*2 : (i+1)*data.Dim*2] + if data.GetNullable() && !data.ValidData[i] { + return nil + } + physicalIdx := data.L2PMapping.GetPhysicalOffset(i) + return data.Data[physicalIdx*data.Dim*2 : (physicalIdx+1)*data.Dim*2] } func (data *Int8VectorFieldData) GetRow(i int) interface{} { - return data.Data[i*data.Dim : (i+1)*data.Dim] + if data.GetNullable() && !data.ValidData[i] { + return nil + } + physicalIdx := data.L2PMapping.GetPhysicalOffset(i) + return data.Data[physicalIdx*data.Dim : (physicalIdx+1)*data.Dim] } func (data *VectorArrayFieldData) GetRow(i int) interface{} { @@ -862,42 +1008,83 @@ func (data *GeometryFieldData) AppendRow(row interface{}) error { } func (data *BinaryVectorFieldData) AppendRow(row interface{}) error { + if data.GetNullable() && row == nil { + data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, false) + return nil + } v, ok := row.([]byte) if !ok || len(v) != data.Dim/8 { return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type") } data.Data = append(data.Data, v...) + if data.GetNullable() { + data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, true) + } return nil } func (data *FloatVectorFieldData) AppendRow(row interface{}) error { + if data.GetNullable() && row == nil { + data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, false) + return nil + } v, ok := row.([]float32) if !ok || len(v) != data.Dim { return merr.WrapErrParameterInvalid("[]float32", row, "Wrong row type") } data.Data = append(data.Data, v...) + if data.GetNullable() { + data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, true) + } return nil } func (data *Float16VectorFieldData) AppendRow(row interface{}) error { + if data.GetNullable() && row == nil { + data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, false) + return nil + } v, ok := row.([]byte) if !ok || len(v) != data.Dim*2 { return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type") } data.Data = append(data.Data, v...) + if data.GetNullable() { + data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, true) + } return nil } func (data *BFloat16VectorFieldData) AppendRow(row interface{}) error { + if data.GetNullable() && row == nil { + data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, false) + return nil + } v, ok := row.([]byte) if !ok || len(v) != data.Dim*2 { return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type") } data.Data = append(data.Data, v...) + if data.GetNullable() { + data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, true) + } return nil } func (data *SparseFloatVectorFieldData) AppendRow(row interface{}) error { + if data.GetNullable() && row == nil { + data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, false) + return nil + } v, ok := row.([]byte) if !ok { return merr.WrapErrParameterInvalid("SparseFloatVectorRowData", row, "Wrong row type") @@ -910,15 +1097,28 @@ func (data *SparseFloatVectorFieldData) AppendRow(row interface{}) error { data.Dim = rowDim } data.Contents = append(data.Contents, v) + if data.GetNullable() { + data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, true) + } return nil } func (data *Int8VectorFieldData) AppendRow(row interface{}) error { + if data.GetNullable() && row == nil { + data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, false) + return nil + } v, ok := row.([]int8) if !ok || len(v) != data.Dim { return merr.WrapErrParameterInvalid("[]int8", row, "Wrong row type") } data.Data = append(data.Data, v...) + if data.GetNullable() { + data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1) + data.ValidData = append(data.ValidData, true) + } return nil } @@ -1431,15 +1631,15 @@ func (data *GeometryFieldData) AppendValidDataRows(rows interface{}) error { // AppendValidDataRows appends FLATTEN vectors to field data. func (data *BinaryVectorFieldData) AppendValidDataRows(rows interface{}) error { - if rows != nil { - v, ok := rows.([]bool) - if !ok { - return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") - } - if len(v) != 0 { - return merr.WrapErrParameterInvalidMsg("not support Nullable in vector") - } + if rows == nil { + return nil } + v, ok := rows.([]bool) + if !ok { + return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") + } + data.L2PMapping.Build(v, len(data.ValidData), len(v)) + data.ValidData = append(data.ValidData, v...) return nil } @@ -1458,69 +1658,69 @@ func (data *VectorArrayFieldData) AppendValidDataRows(rows interface{}) error { // AppendValidDataRows appends FLATTEN vectors to field data. func (data *FloatVectorFieldData) AppendValidDataRows(rows interface{}) error { - if rows != nil { - v, ok := rows.([]bool) - if !ok { - return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") - } - if len(v) != 0 { - return merr.WrapErrParameterInvalidMsg("not support Nullable in vector") - } + if rows == nil { + return nil } + v, ok := rows.([]bool) + if !ok { + return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") + } + data.L2PMapping.Build(v, len(data.ValidData), len(v)) + data.ValidData = append(data.ValidData, v...) return nil } // AppendValidDataRows appends FLATTEN vectors to field data. func (data *Float16VectorFieldData) AppendValidDataRows(rows interface{}) error { - if rows != nil { - v, ok := rows.([]bool) - if !ok { - return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") - } - if len(v) != 0 { - return merr.WrapErrParameterInvalidMsg("not support Nullable in vector") - } + if rows == nil { + return nil } + v, ok := rows.([]bool) + if !ok { + return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") + } + data.L2PMapping.Build(v, len(data.ValidData), len(v)) + data.ValidData = append(data.ValidData, v...) return nil } // AppendValidDataRows appends FLATTEN vectors to field data. func (data *BFloat16VectorFieldData) AppendValidDataRows(rows interface{}) error { - if rows != nil { - v, ok := rows.([]bool) - if !ok { - return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") - } - if len(v) != 0 { - return merr.WrapErrParameterInvalidMsg("not support Nullable in vector") - } + if rows == nil { + return nil } + v, ok := rows.([]bool) + if !ok { + return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") + } + data.L2PMapping.Build(v, len(data.ValidData), len(v)) + data.ValidData = append(data.ValidData, v...) return nil } func (data *SparseFloatVectorFieldData) AppendValidDataRows(rows interface{}) error { - if rows != nil { - v, ok := rows.([]bool) - if !ok { - return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") - } - if len(v) != 0 { - return merr.WrapErrParameterInvalidMsg("not support Nullable in vector") - } + if rows == nil { + return nil } + v, ok := rows.([]bool) + if !ok { + return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") + } + data.L2PMapping.Build(v, len(data.ValidData), len(v)) + data.ValidData = append(data.ValidData, v...) return nil } func (data *Int8VectorFieldData) AppendValidDataRows(rows interface{}) error { - if rows != nil { - v, ok := rows.([]bool) - if !ok { - return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") - } - if len(v) != 0 { - return merr.WrapErrParameterInvalidMsg("not support Nullable in vector") - } + if rows == nil { + return nil } + v, ok := rows.([]bool) + if !ok { + return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type") + } + data.L2PMapping.Build(v, len(data.ValidData), len(v)) + data.ValidData = append(data.ValidData, v...) return nil } @@ -1557,17 +1757,32 @@ func (data *TimestamptzFieldData) GetMemorySize() int { return binary.Size(data.Data) + binary.Size(data.ValidData) + binary.Size(data.Nullable) } -func (data *BinaryVectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 } -func (data *FloatVectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 } -func (data *Float16VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 } -func (data *BFloat16VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 } +func (data *BinaryVectorFieldData) GetMemorySize() int { + // Data + ValidData + Dim(4) + Nullable(1) + L2PMapping + return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize() +} + +func (data *FloatVectorFieldData) GetMemorySize() int { + return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize() +} + +func (data *Float16VectorFieldData) GetMemorySize() int { + return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize() +} + +func (data *BFloat16VectorFieldData) GetMemorySize() int { + return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize() +} func (data *SparseFloatVectorFieldData) GetMemorySize() int { // TODO(SPARSE): should this be the memory size of serialzied size? - return proto.Size(&data.SparseFloatArray) + // SparseFloatArray + ValidData + Nullable(1) + L2PMapping + return proto.Size(&data.SparseFloatArray) + binary.Size(data.ValidData) + 1 + data.L2PMapping.GetMemorySize() } -func (data *Int8VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 } +func (data *Int8VectorFieldData) GetMemorySize() int { + return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize() +} func GetVectorSize(vector *schemapb.VectorField, vectorType schemapb.DataType) int { size := 0 @@ -1698,22 +1913,51 @@ func (data *GeometryFieldData) GetMemorySize() int { return size + binary.Size(data.ValidData) + binary.Size(data.Nullable) } -func (data *BoolFieldData) GetRowSize(i int) int { return 1 } -func (data *Int8FieldData) GetRowSize(i int) int { return 1 } -func (data *Int16FieldData) GetRowSize(i int) int { return 2 } -func (data *Int32FieldData) GetRowSize(i int) int { return 4 } -func (data *Int64FieldData) GetRowSize(i int) int { return 8 } -func (data *FloatFieldData) GetRowSize(i int) int { return 4 } -func (data *DoubleFieldData) GetRowSize(i int) int { return 8 } -func (data *TimestamptzFieldData) GetRowSize(i int) int { return 8 } -func (data *BinaryVectorFieldData) GetRowSize(i int) int { return data.Dim / 8 } -func (data *FloatVectorFieldData) GetRowSize(i int) int { return data.Dim * 4 } -func (data *Float16VectorFieldData) GetRowSize(i int) int { return data.Dim * 2 } -func (data *BFloat16VectorFieldData) GetRowSize(i int) int { return data.Dim * 2 } -func (data *Int8VectorFieldData) GetRowSize(i int) int { return data.Dim } -func (data *StringFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 } -func (data *JSONFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 } -func (data *GeometryFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 } +func (data *BoolFieldData) GetRowSize(i int) int { return 1 } +func (data *Int8FieldData) GetRowSize(i int) int { return 1 } +func (data *Int16FieldData) GetRowSize(i int) int { return 2 } +func (data *Int32FieldData) GetRowSize(i int) int { return 4 } +func (data *Int64FieldData) GetRowSize(i int) int { return 8 } +func (data *FloatFieldData) GetRowSize(i int) int { return 4 } +func (data *DoubleFieldData) GetRowSize(i int) int { return 8 } +func (data *TimestamptzFieldData) GetRowSize(i int) int { return 8 } +func (data *BinaryVectorFieldData) GetRowSize(i int) int { + if data.GetNullable() && !data.ValidData[i] { + return 0 + } + return data.Dim / 8 +} + +func (data *FloatVectorFieldData) GetRowSize(i int) int { + if data.GetNullable() && !data.ValidData[i] { + return 0 + } + return data.Dim * 4 +} + +func (data *Float16VectorFieldData) GetRowSize(i int) int { + if data.GetNullable() && !data.ValidData[i] { + return 0 + } + return data.Dim * 2 +} + +func (data *BFloat16VectorFieldData) GetRowSize(i int) int { + if data.GetNullable() && !data.ValidData[i] { + return 0 + } + return data.Dim * 2 +} + +func (data *Int8VectorFieldData) GetRowSize(i int) int { + if data.GetNullable() && !data.ValidData[i] { + return 0 + } + return data.Dim +} +func (data *StringFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 } +func (data *JSONFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 } +func (data *GeometryFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 } func (data *ArrayFieldData) GetRowSize(i int) int { switch data.ElementType { case schemapb.DataType_Bool: @@ -1737,7 +1981,11 @@ func (data *ArrayFieldData) GetRowSize(i int) int { } func (data *SparseFloatVectorFieldData) GetRowSize(i int) int { - return len(data.Contents[i]) + if data.GetNullable() && !data.ValidData[i] { + return 0 + } + physicalIdx := data.L2PMapping.GetPhysicalOffset(i) + return len(data.Contents[physicalIdx]) } func (data *VectorArrayFieldData) GetRowSize(i int) int { @@ -1777,27 +2025,27 @@ func (data *TimestamptzFieldData) GetNullable() bool { } func (data *BFloat16VectorFieldData) GetNullable() bool { - return false + return data.Nullable } func (data *BinaryVectorFieldData) GetNullable() bool { - return false + return data.Nullable } func (data *FloatVectorFieldData) GetNullable() bool { - return false + return data.Nullable } func (data *SparseFloatVectorFieldData) GetNullable() bool { - return false + return data.Nullable } func (data *Float16VectorFieldData) GetNullable() bool { - return false + return data.Nullable } func (data *Int8VectorFieldData) GetNullable() bool { - return false + return data.Nullable } func (data *StringFieldData) GetNullable() bool { @@ -1819,3 +2067,23 @@ func (data *VectorArrayFieldData) GetNullable() bool { func (data *GeometryFieldData) GetNullable() bool { return data.Nullable } + +func (data *BoolFieldData) GetValidData() []bool { return data.ValidData } +func (data *Int8FieldData) GetValidData() []bool { return data.ValidData } +func (data *Int16FieldData) GetValidData() []bool { return data.ValidData } +func (data *Int32FieldData) GetValidData() []bool { return data.ValidData } +func (data *Int64FieldData) GetValidData() []bool { return data.ValidData } +func (data *FloatFieldData) GetValidData() []bool { return data.ValidData } +func (data *DoubleFieldData) GetValidData() []bool { return data.ValidData } +func (data *TimestamptzFieldData) GetValidData() []bool { return data.ValidData } +func (data *StringFieldData) GetValidData() []bool { return data.ValidData } +func (data *ArrayFieldData) GetValidData() []bool { return data.ValidData } +func (data *JSONFieldData) GetValidData() []bool { return data.ValidData } +func (data *GeometryFieldData) GetValidData() []bool { return data.ValidData } +func (data *BinaryVectorFieldData) GetValidData() []bool { return data.ValidData } +func (data *FloatVectorFieldData) GetValidData() []bool { return data.ValidData } +func (data *Float16VectorFieldData) GetValidData() []bool { return data.ValidData } +func (data *BFloat16VectorFieldData) GetValidData() []bool { return data.ValidData } +func (data *SparseFloatVectorFieldData) GetValidData() []bool { return data.ValidData } +func (data *Int8VectorFieldData) GetValidData() []bool { return data.ValidData } +func (data *VectorArrayFieldData) GetValidData() []bool { return nil } diff --git a/internal/storage/insert_data_test.go b/internal/storage/insert_data_test.go index 0b2a75b97c..cb0de86d69 100644 --- a/internal/storage/insert_data_test.go +++ b/internal/storage/insert_data_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/suite" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -45,17 +46,31 @@ func (s *InsertDataSuite) TestInsertData() { tests := []struct { description string dataType schemapb.DataType + typeParams []*commonpb.KeyValuePair + nullable bool }{ - {"nullable bool field", schemapb.DataType_Bool}, - {"nullable int8 field", schemapb.DataType_Int8}, - {"nullable int16 field", schemapb.DataType_Int16}, - {"nullable int32 field", schemapb.DataType_Int32}, - {"nullable int64 field", schemapb.DataType_Int64}, - {"nullable float field", schemapb.DataType_Float}, - {"nullable double field", schemapb.DataType_Double}, - {"nullable json field", schemapb.DataType_JSON}, - {"nullable array field", schemapb.DataType_Array}, - {"nullable string/varchar field", schemapb.DataType_String}, + {"nullable bool field", schemapb.DataType_Bool, nil, true}, + {"nullable int8 field", schemapb.DataType_Int8, nil, true}, + {"nullable int16 field", schemapb.DataType_Int16, nil, true}, + {"nullable int32 field", schemapb.DataType_Int32, nil, true}, + {"nullable int64 field", schemapb.DataType_Int64, nil, true}, + {"nullable float field", schemapb.DataType_Float, nil, true}, + {"nullable double field", schemapb.DataType_Double, nil, true}, + {"nullable json field", schemapb.DataType_JSON, nil, true}, + {"nullable array field", schemapb.DataType_Array, nil, true}, + {"nullable string/varchar field", schemapb.DataType_String, nil, true}, + {"nullable binary vector field", schemapb.DataType_BinaryVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "8"}}, true}, + {"nullable float vector field", schemapb.DataType_FloatVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true}, + {"nullable float16 vector field", schemapb.DataType_Float16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true}, + {"nullable bfloat16 vector field", schemapb.DataType_BFloat16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true}, + {"nullable sparse float vector field", schemapb.DataType_SparseFloatVector, nil, true}, + {"nullable int8 vector field", schemapb.DataType_Int8Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true}, + {"non-nullable binary vector field", schemapb.DataType_BinaryVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "8"}}, false}, + {"non-nullable float vector field", schemapb.DataType_FloatVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false}, + {"non-nullable float16 vector field", schemapb.DataType_Float16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false}, + {"non-nullable bfloat16 vector field", schemapb.DataType_BFloat16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false}, + {"non-nullable sparse float vector field", schemapb.DataType_SparseFloatVector, nil, false}, + {"non-nullable int8 vector field", schemapb.DataType_Int8Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false}, } for _, test := range tests { @@ -63,8 +78,9 @@ func (s *InsertDataSuite) TestInsertData() { schema := &schemapb.CollectionSchema{ Fields: []*schemapb.FieldSchema{ { - DataType: test.dataType, - Nullable: true, + DataType: test.dataType, + Nullable: test.nullable, + TypeParams: test.typeParams, }, }, } @@ -115,15 +131,15 @@ func (s *InsertDataSuite) TestInsertData() { s.Run("init by New", func() { s.True(s.iDataEmpty.IsEmpty()) s.Equal(0, s.iDataEmpty.GetRowNum()) - s.Equal(33, s.iDataEmpty.GetMemorySize()) + s.Equal(161, s.iDataEmpty.GetMemorySize()) s.False(s.iDataOneRow.IsEmpty()) s.Equal(1, s.iDataOneRow.GetRowNum()) - s.Equal(240, s.iDataOneRow.GetMemorySize()) + s.Equal(535, s.iDataOneRow.GetMemorySize()) s.False(s.iDataTwoRows.IsEmpty()) s.Equal(2, s.iDataTwoRows.GetRowNum()) - s.Equal(433, s.iDataTwoRows.GetMemorySize()) + s.Equal(734, s.iDataTwoRows.GetMemorySize()) for _, field := range s.iDataTwoRows.Data { s.Equal(2, field.RowNum()) @@ -147,12 +163,13 @@ func (s *InsertDataSuite) TestMemorySize() { s.Equal(s.iDataEmpty.Data[DoubleField].GetMemorySize(), 1) s.Equal(s.iDataEmpty.Data[StringField].GetMemorySize(), 1) s.Equal(s.iDataEmpty.Data[ArrayField].GetMemorySize(), 1) - s.Equal(s.iDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4) - s.Equal(s.iDataEmpty.Data[FloatVectorField].GetMemorySize(), 4) - s.Equal(s.iDataEmpty.Data[Float16VectorField].GetMemorySize(), 4) - s.Equal(s.iDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 4) - s.Equal(s.iDataEmpty.Data[SparseFloatVectorField].GetMemorySize(), 0) - s.Equal(s.iDataEmpty.Data[Int8VectorField].GetMemorySize(), 4) + // +9 bytes: Nullable(1) + L2PMapping.GetMemorySize()(8) + s.Equal(s.iDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4+9) + s.Equal(s.iDataEmpty.Data[FloatVectorField].GetMemorySize(), 4+9) + s.Equal(s.iDataEmpty.Data[Float16VectorField].GetMemorySize(), 4+9) + s.Equal(s.iDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 4+9) + s.Equal(s.iDataEmpty.Data[SparseFloatVectorField].GetMemorySize(), 0+9) + s.Equal(s.iDataEmpty.Data[Int8VectorField].GetMemorySize(), 4+9) s.Equal(s.iDataEmpty.Data[StructSubInt32Field].GetMemorySize(), 1) s.Equal(s.iDataEmpty.Data[StructSubFloatVectorField].GetMemorySize(), 0) @@ -168,12 +185,13 @@ func (s *InsertDataSuite) TestMemorySize() { s.Equal(s.iDataOneRow.Data[StringField].GetMemorySize(), 20) s.Equal(s.iDataOneRow.Data[JSONField].GetMemorySize(), len([]byte(`{"batch":1}`))+16+1) s.Equal(s.iDataOneRow.Data[ArrayField].GetMemorySize(), 3*4+1) - s.Equal(s.iDataOneRow.Data[BinaryVectorField].GetMemorySize(), 5) - s.Equal(s.iDataOneRow.Data[FloatVectorField].GetMemorySize(), 20) - s.Equal(s.iDataOneRow.Data[Float16VectorField].GetMemorySize(), 12) - s.Equal(s.iDataOneRow.Data[BFloat16VectorField].GetMemorySize(), 12) - s.Equal(s.iDataOneRow.Data[SparseFloatVectorField].GetMemorySize(), 28) - s.Equal(s.iDataOneRow.Data[Int8VectorField].GetMemorySize(), 8) + // +9 bytes: Nullable(1) + L2PMapping.GetMemorySize()(8) + s.Equal(s.iDataOneRow.Data[BinaryVectorField].GetMemorySize(), 5+9) + s.Equal(s.iDataOneRow.Data[FloatVectorField].GetMemorySize(), 20+9) + s.Equal(s.iDataOneRow.Data[Float16VectorField].GetMemorySize(), 12+9) + s.Equal(s.iDataOneRow.Data[BFloat16VectorField].GetMemorySize(), 12+9) + s.Equal(s.iDataOneRow.Data[SparseFloatVectorField].GetMemorySize(), 28+9) + s.Equal(s.iDataOneRow.Data[Int8VectorField].GetMemorySize(), 8+9) s.Equal(s.iDataOneRow.Data[StructSubInt32Field].GetMemorySize(), 3*4+1) s.Equal(s.iDataOneRow.Data[StructSubFloatVectorField].GetMemorySize(), 3*4*2+4) @@ -188,12 +206,13 @@ func (s *InsertDataSuite) TestMemorySize() { s.Equal(s.iDataTwoRows.Data[DoubleField].GetMemorySize(), 17) s.Equal(s.iDataTwoRows.Data[StringField].GetMemorySize(), 39) s.Equal(s.iDataTwoRows.Data[ArrayField].GetMemorySize(), 25) - s.Equal(s.iDataTwoRows.Data[BinaryVectorField].GetMemorySize(), 6) - s.Equal(s.iDataTwoRows.Data[FloatVectorField].GetMemorySize(), 36) - s.Equal(s.iDataTwoRows.Data[Float16VectorField].GetMemorySize(), 20) - s.Equal(s.iDataTwoRows.Data[BFloat16VectorField].GetMemorySize(), 20) - s.Equal(s.iDataTwoRows.Data[SparseFloatVectorField].GetMemorySize(), 54) - s.Equal(s.iDataTwoRows.Data[Int8VectorField].GetMemorySize(), 12) + // +9 bytes: Nullable(1) + L2PMapping.GetMemorySize()(8) + s.Equal(s.iDataTwoRows.Data[BinaryVectorField].GetMemorySize(), 6+9) + s.Equal(s.iDataTwoRows.Data[FloatVectorField].GetMemorySize(), 36+9) + s.Equal(s.iDataTwoRows.Data[Float16VectorField].GetMemorySize(), 20+9) + s.Equal(s.iDataTwoRows.Data[BFloat16VectorField].GetMemorySize(), 20+9) + s.Equal(s.iDataTwoRows.Data[SparseFloatVectorField].GetMemorySize(), 54+9) + s.Equal(s.iDataTwoRows.Data[Int8VectorField].GetMemorySize(), 12+9) s.Equal(s.iDataTwoRows.Data[StructSubInt32Field].GetMemorySize(), 3*4+2*4+1) s.Equal(s.iDataTwoRows.Data[StructSubFloatVectorField].GetMemorySize(), 3*4*2+4+2*4*2+4) } @@ -252,25 +271,31 @@ func (s *InsertDataSuite) SetupTest() { s.Require().NoError(err) s.True(s.iDataEmpty.IsEmpty()) s.Equal(0, s.iDataEmpty.GetRowNum()) - s.Equal(33, s.iDataEmpty.GetMemorySize()) + s.Equal(161, s.iDataEmpty.GetMemorySize()) row1 := map[FieldID]interface{}{ - RowIDField: int64(3), - TimestampField: int64(3), - BoolField: true, - Int8Field: int8(3), - Int16Field: int16(3), - Int32Field: int32(3), - Int64Field: int64(3), - FloatField: float32(3), - DoubleField: float64(3), - StringField: "str", - BinaryVectorField: []byte{0}, - FloatVectorField: []float32{4, 5, 6, 7}, - Float16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255}, - BFloat16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255}, - SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}), - Int8VectorField: []int8{-4, -5, 6, 7}, + RowIDField: int64(3), + TimestampField: int64(3), + BoolField: true, + Int8Field: int8(3), + Int16Field: int16(3), + Int32Field: int32(3), + Int64Field: int64(3), + FloatField: float32(3), + DoubleField: float64(3), + StringField: "str", + BinaryVectorField: []byte{0}, + FloatVectorField: []float32{4, 5, 6, 7}, + Float16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255}, + BFloat16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255}, + SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}), + Int8VectorField: []int8{-4, -5, 6, 7}, + NullableFloatVectorField: []float32{1.0, 2.0, 3.0, 4.0}, + NullableBinaryVectorField: []byte{1}, + NullableFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + NullableBFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + NullableInt8VectorField: []int8{1, 2, 3, 4}, + NullableSparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}), ArrayField: &schemapb.ScalarField{ Data: &schemapb.ScalarField_IntData{ IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}}, @@ -300,22 +325,28 @@ func (s *InsertDataSuite) SetupTest() { } row2 := map[FieldID]interface{}{ - RowIDField: int64(1), - TimestampField: int64(1), - BoolField: false, - Int8Field: int8(1), - Int16Field: int16(1), - Int32Field: int32(1), - Int64Field: int64(1), - FloatField: float32(1), - DoubleField: float64(1), - StringField: string("str"), - BinaryVectorField: []byte{0}, - FloatVectorField: []float32{4, 5, 6, 7}, - Float16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8}, - BFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8}, - SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{2, 3, 4}, []float32{4, 5, 6}), - Int8VectorField: []int8{-128, -5, 6, 127}, + RowIDField: int64(1), + TimestampField: int64(1), + BoolField: false, + Int8Field: int8(1), + Int16Field: int16(1), + Int32Field: int32(1), + Int64Field: int64(1), + FloatField: float32(1), + DoubleField: float64(1), + StringField: string("str"), + BinaryVectorField: []byte{0}, + FloatVectorField: []float32{4, 5, 6, 7}, + Float16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + BFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{2, 3, 4}, []float32{4, 5, 6}), + Int8VectorField: []int8{-128, -5, 6, 127}, + NullableFloatVectorField: nil, + NullableBinaryVectorField: nil, + NullableFloat16VectorField: nil, + NullableBFloat16VectorField: nil, + NullableInt8VectorField: nil, + NullableSparseFloatVectorField: nil, ArrayField: &schemapb.ScalarField{ Data: &schemapb.ScalarField_IntData{ IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}}, diff --git a/internal/storage/payload.go b/internal/storage/payload.go index a5cb7004d2..e7aaaa10dc 100644 --- a/internal/storage/payload.go +++ b/internal/storage/payload.go @@ -40,12 +40,12 @@ type PayloadWriterInterface interface { AddOneArrayToPayload(*schemapb.ScalarField, bool) error AddOneJSONToPayload([]byte, bool) error AddOneGeometryToPayload(msg []byte, isValid bool) error - AddBinaryVectorToPayload([]byte, int) error - AddFloatVectorToPayload([]float32, int) error - AddFloat16VectorToPayload([]byte, int) error - AddBFloat16VectorToPayload([]byte, int) error + AddBinaryVectorToPayload(data []byte, dim int, validData []bool) error + AddFloatVectorToPayload(data []float32, dim int, validData []bool) error + AddFloat16VectorToPayload(data []byte, dim int, validData []bool) error + AddBFloat16VectorToPayload(data []byte, dim int, validData []bool) error AddSparseFloatVectorToPayload(*SparseFloatVectorFieldData) error - AddInt8VectorToPayload([]int8, int) error + AddInt8VectorToPayload(data []int8, dim int, validData []bool) error AddVectorArrayFieldDataToPayload(*VectorArrayFieldData) error FinishPayloadWriter() error GetPayloadBufferFromWriter() ([]byte, error) @@ -72,12 +72,12 @@ type PayloadReaderInterface interface { GetVectorArrayFromPayload() ([]*schemapb.VectorField, error) GetJSONFromPayload() ([][]byte, []bool, error) GetGeometryFromPayload() ([][]byte, []bool, error) - GetBinaryVectorFromPayload() ([]byte, int, error) - GetFloat16VectorFromPayload() ([]byte, int, error) - GetBFloat16VectorFromPayload() ([]byte, int, error) - GetFloatVectorFromPayload() ([]float32, int, error) - GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, error) - GetInt8VectorFromPayload() ([]int8, int, error) + GetBinaryVectorFromPayload() ([]byte, int, []bool, int, error) + GetFloat16VectorFromPayload() ([]byte, int, []bool, int, error) + GetBFloat16VectorFromPayload() ([]byte, int, []bool, int, error) + GetFloatVectorFromPayload() ([]float32, int, []bool, int, error) + GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, []bool, error) + GetInt8VectorFromPayload() ([]int8, int, []bool, int, error) GetPayloadLengthFromReader() (int, error) GetByteArrayDataSet() (*DataSet[parquet.ByteArray, *file.ByteArrayColumnChunkReader], error) diff --git a/internal/storage/payload_reader.go b/internal/storage/payload_reader.go index 90ba26ddf7..73088cb3e9 100644 --- a/internal/storage/payload_reader.go +++ b/internal/storage/payload_reader.go @@ -149,23 +149,23 @@ func (r *PayloadReader) GetDataFromPayload() (interface{}, []bool, int, error) { val, validData, err := r.GetTimestamptzFromPayload() return val, validData, 0, err case schemapb.DataType_BinaryVector: - val, dim, err := r.GetBinaryVectorFromPayload() - return val, nil, dim, err + val, dim, validData, _, err := r.GetBinaryVectorFromPayload() + return val, validData, dim, err case schemapb.DataType_FloatVector: - val, dim, err := r.GetFloatVectorFromPayload() - return val, nil, dim, err + val, dim, validData, _, err := r.GetFloatVectorFromPayload() + return val, validData, dim, err case schemapb.DataType_Float16Vector: - val, dim, err := r.GetFloat16VectorFromPayload() - return val, nil, dim, err + val, dim, validData, _, err := r.GetFloat16VectorFromPayload() + return val, validData, dim, err case schemapb.DataType_BFloat16Vector: - val, dim, err := r.GetBFloat16VectorFromPayload() - return val, nil, dim, err + val, dim, validData, _, err := r.GetBFloat16VectorFromPayload() + return val, validData, dim, err case schemapb.DataType_SparseFloatVector: - val, dim, err := r.GetSparseFloatVectorFromPayload() - return val, nil, dim, err + val, dim, validData, err := r.GetSparseFloatVectorFromPayload() + return val, validData, dim, err case schemapb.DataType_Int8Vector: - val, dim, err := r.GetInt8VectorFromPayload() - return val, nil, dim, err + val, dim, validData, _, err := r.GetInt8VectorFromPayload() + return val, validData, dim, err case schemapb.DataType_String, schemapb.DataType_VarChar: val, validData, err := r.GetStringFromPayload() return val, validData, 0, err @@ -681,96 +681,434 @@ func readByteAndConvert[T any](r *PayloadReader, convert func(parquet.ByteArray) return ret, nil } -// GetBinaryVectorFromPayload returns vector, dimension, error -func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) { +// GetBinaryVectorFromPayload returns vector, dimension, validData, numRows, error +func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, []bool, int, error) { if r.colType != schemapb.DataType_BinaryVector { - return nil, -1, fmt.Errorf("failed to get binary vector from datatype %v", r.colType.String()) + return nil, -1, nil, 0, fmt.Errorf("failed to get binary vector from datatype %v", r.colType.String()) + } + + if r.nullable { + fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator) + if err != nil { + return nil, -1, nil, 0, err + } + + arrowSchema, err := fileReader.Schema() + if err != nil { + return nil, -1, nil, 0, err + } + + if arrowSchema.NumFields() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields()) + } + + field := arrowSchema.Field(0) + var dim int + + if field.Type.ID() == arrow.BINARY { + if !field.HasMetadata() { + return nil, -1, nil, 0, fmt.Errorf("nullable binary vector field is missing metadata") + } + metadata := field.Metadata + dimStr, ok := metadata.GetValue("dim") + if !ok { + return nil, -1, nil, 0, fmt.Errorf("nullable binary vector metadata missing required 'dim' field") + } + var err error + dim, err = strconv.Atoi(dimStr) + if err != nil { + return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err) + } + dim = dim / 8 + } else { + col, err := r.reader.RowGroup(0).Column(0) + if err != nil { + return nil, -1, nil, 0, err + } + dim = col.Descriptor().TypeLength() + } + + table, err := fileReader.ReadTable(context.Background()) + if err != nil { + return nil, -1, nil, 0, err + } + defer table.Release() + + if table.NumCols() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols()) + } + + column := table.Column(0) + validCount := 0 + for _, chunk := range column.Data().Chunks() { + for i := 0; i < chunk.Len(); i++ { + if chunk.IsValid(i) { + validCount++ + } + } + } + + ret := make([]byte, validCount*dim) + validData := make([]bool, r.numRows) + offset := 0 + dataIdx := 0 + for _, chunk := range column.Data().Chunks() { + binaryArray, ok := chunk.(*array.Binary) + if !ok { + return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk) + } + for i := 0; i < binaryArray.Len(); i++ { + if binaryArray.IsValid(i) { + validData[offset+i] = true + bytes := binaryArray.Value(i) + copy(ret[dataIdx*dim:(dataIdx+1)*dim], bytes) + dataIdx++ + } + } + offset += binaryArray.Len() + } + + return ret, dim * 8, validData, int(r.numRows), nil } col, err := r.reader.RowGroup(0).Column(0) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } dim := col.Descriptor().TypeLength() + values := make([]parquet.FixedLenByteArray, r.numRows) valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } if valuesRead != r.numRows { - return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) + return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) } ret := make([]byte, int64(dim)*r.numRows) for i := 0; i < int(r.numRows); i++ { copy(ret[i*dim:(i+1)*dim], values[i]) } - return ret, dim * 8, nil + return ret, dim * 8, nil, int(r.numRows), nil } -// GetFloat16VectorFromPayload returns vector, dimension, error -func (r *PayloadReader) GetFloat16VectorFromPayload() ([]byte, int, error) { +// GetFloat16VectorFromPayload returns vector, dimension, validData, numRows, error +func (r *PayloadReader) GetFloat16VectorFromPayload() ([]byte, int, []bool, int, error) { if r.colType != schemapb.DataType_Float16Vector { - return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String()) + return nil, -1, nil, 0, fmt.Errorf("failed to get float16 vector from datatype %v", r.colType.String()) } + if r.nullable { + fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator) + if err != nil { + return nil, -1, nil, 0, err + } + + arrowSchema, err := fileReader.Schema() + if err != nil { + return nil, -1, nil, 0, err + } + + if arrowSchema.NumFields() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields()) + } + + field := arrowSchema.Field(0) + var dim int + + if field.Type.ID() == arrow.BINARY { + if !field.HasMetadata() { + return nil, -1, nil, 0, fmt.Errorf("nullable float16 vector field is missing metadata") + } + metadata := field.Metadata + dimStr, ok := metadata.GetValue("dim") + if !ok { + return nil, -1, nil, 0, fmt.Errorf("nullable float16 vector metadata missing required 'dim' field") + } + var err error + dim, err = strconv.Atoi(dimStr) + if err != nil { + return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err) + } + } else { + col, err := r.reader.RowGroup(0).Column(0) + if err != nil { + return nil, -1, nil, 0, err + } + dim = col.Descriptor().TypeLength() / 2 + } + + table, err := fileReader.ReadTable(context.Background()) + if err != nil { + return nil, -1, nil, 0, err + } + defer table.Release() + + if table.NumCols() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols()) + } + + column := table.Column(0) + validCount := 0 + for _, chunk := range column.Data().Chunks() { + for i := 0; i < chunk.Len(); i++ { + if chunk.IsValid(i) { + validCount++ + } + } + } + + ret := make([]byte, validCount*dim*2) + validData := make([]bool, r.numRows) + offset := 0 + dataIdx := 0 + for _, chunk := range column.Data().Chunks() { + binaryArray, ok := chunk.(*array.Binary) + if !ok { + return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk) + } + for i := 0; i < binaryArray.Len(); i++ { + if binaryArray.IsValid(i) { + validData[offset+i] = true + bytes := binaryArray.Value(i) + copy(ret[dataIdx*dim*2:(dataIdx+1)*dim*2], bytes) + dataIdx++ + } + } + offset += binaryArray.Len() + } + + return ret, dim, validData, int(r.numRows), nil + } + col, err := r.reader.RowGroup(0).Column(0) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } + dim := col.Descriptor().TypeLength() / 2 + values := make([]parquet.FixedLenByteArray, r.numRows) valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } if valuesRead != r.numRows { - return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) + return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) } ret := make([]byte, int64(dim*2)*r.numRows) for i := 0; i < int(r.numRows); i++ { copy(ret[i*dim*2:(i+1)*dim*2], values[i]) } - return ret, dim, nil + return ret, dim, nil, int(r.numRows), nil } -// GetBFloat16VectorFromPayload returns vector, dimension, error -func (r *PayloadReader) GetBFloat16VectorFromPayload() ([]byte, int, error) { +// GetBFloat16VectorFromPayload returns vector, dimension, validData, numRows, error +func (r *PayloadReader) GetBFloat16VectorFromPayload() ([]byte, int, []bool, int, error) { if r.colType != schemapb.DataType_BFloat16Vector { - return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String()) + return nil, -1, nil, 0, fmt.Errorf("failed to get bfloat16 vector from datatype %v", r.colType.String()) } + if r.nullable { + fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator) + if err != nil { + return nil, -1, nil, 0, err + } + + arrowSchema, err := fileReader.Schema() + if err != nil { + return nil, -1, nil, 0, err + } + + if arrowSchema.NumFields() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields()) + } + + field := arrowSchema.Field(0) + var dim int + + if field.Type.ID() == arrow.BINARY { + if !field.HasMetadata() { + return nil, -1, nil, 0, fmt.Errorf("nullable bfloat16 vector field is missing metadata") + } + metadata := field.Metadata + dimStr, ok := metadata.GetValue("dim") + if !ok { + return nil, -1, nil, 0, fmt.Errorf("nullable bfloat16 vector metadata missing required 'dim' field") + } + var err error + dim, err = strconv.Atoi(dimStr) + if err != nil { + return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err) + } + } else { + col, err := r.reader.RowGroup(0).Column(0) + if err != nil { + return nil, -1, nil, 0, err + } + dim = col.Descriptor().TypeLength() / 2 + } + + table, err := fileReader.ReadTable(context.Background()) + if err != nil { + return nil, -1, nil, 0, err + } + defer table.Release() + + if table.NumCols() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols()) + } + + column := table.Column(0) + validCount := 0 + for _, chunk := range column.Data().Chunks() { + for i := 0; i < chunk.Len(); i++ { + if chunk.IsValid(i) { + validCount++ + } + } + } + + ret := make([]byte, validCount*dim*2) + validData := make([]bool, r.numRows) + offset := 0 + dataIdx := 0 + for _, chunk := range column.Data().Chunks() { + binaryArray, ok := chunk.(*array.Binary) + if !ok { + return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk) + } + for i := 0; i < binaryArray.Len(); i++ { + if binaryArray.IsValid(i) { + validData[offset+i] = true + bytes := binaryArray.Value(i) + copy(ret[dataIdx*dim*2:(dataIdx+1)*dim*2], bytes) + dataIdx++ + } + } + offset += binaryArray.Len() + } + + return ret, dim, validData, int(r.numRows), nil + } + col, err := r.reader.RowGroup(0).Column(0) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } + dim := col.Descriptor().TypeLength() / 2 + values := make([]parquet.FixedLenByteArray, r.numRows) valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } if valuesRead != r.numRows { - return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) + return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) } ret := make([]byte, int64(dim*2)*r.numRows) for i := 0; i < int(r.numRows); i++ { copy(ret[i*dim*2:(i+1)*dim*2], values[i]) } - return ret, dim, nil + return ret, dim, nil, int(r.numRows), nil } -// GetFloatVectorFromPayload returns vector, dimension, error -func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) { +// GetFloatVectorFromPayload returns vector, dimension, validData, numRows, error +func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, []bool, int, error) { if r.colType != schemapb.DataType_FloatVector { - return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String()) + return nil, -1, nil, 0, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String()) } + if r.nullable { + fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator) + if err != nil { + return nil, -1, nil, 0, err + } + + arrowSchema, err := fileReader.Schema() + if err != nil { + return nil, -1, nil, 0, err + } + + if arrowSchema.NumFields() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields()) + } + + field := arrowSchema.Field(0) + var dim int + + if field.Type.ID() == arrow.BINARY { + if !field.HasMetadata() { + return nil, -1, nil, 0, fmt.Errorf("nullable float vector field is missing metadata") + } + metadata := field.Metadata + dimStr, ok := metadata.GetValue("dim") + if !ok { + return nil, -1, nil, 0, fmt.Errorf("nullable float vector metadata missing required 'dim' field") + } + var err error + dim, err = strconv.Atoi(dimStr) + if err != nil { + return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err) + } + } else { + col, err := r.reader.RowGroup(0).Column(0) + if err != nil { + return nil, -1, nil, 0, err + } + dim = col.Descriptor().TypeLength() / 4 + } + + table, err := fileReader.ReadTable(context.Background()) + if err != nil { + return nil, -1, nil, 0, err + } + defer table.Release() + + if table.NumCols() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols()) + } + + column := table.Column(0) + validCount := 0 + for _, chunk := range column.Data().Chunks() { + for i := 0; i < chunk.Len(); i++ { + if chunk.IsValid(i) { + validCount++ + } + } + } + + ret := make([]float32, validCount*dim) + validData := make([]bool, r.numRows) + offset := 0 + dataIdx := 0 + for _, chunk := range column.Data().Chunks() { + binaryArray, ok := chunk.(*array.Binary) + if !ok { + return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk) + } + for i := 0; i < binaryArray.Len(); i++ { + if binaryArray.IsValid(i) { + validData[offset+i] = true + bytes := binaryArray.Value(i) + copy(arrow.Float32Traits.CastToBytes(ret[dataIdx*dim:(dataIdx+1)*dim]), bytes) + dataIdx++ + } + } + offset += binaryArray.Len() + } + + return ret, dim, validData, int(r.numRows), nil + } + col, err := r.reader.RowGroup(0).Column(0) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } dim := col.Descriptor().TypeLength() / 4 @@ -778,38 +1116,89 @@ func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) { values := make([]parquet.FixedLenByteArray, r.numRows) valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } if valuesRead != r.numRows { - return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) + return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) } ret := make([]float32, int64(dim)*r.numRows) for i := 0; i < int(r.numRows); i++ { copy(arrow.Float32Traits.CastToBytes(ret[i*dim:(i+1)*dim]), values[i]) } - return ret, dim, nil + return ret, dim, nil, int(r.numRows), nil } -func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, error) { +// GetSparseFloatVectorFromPayload returns fieldData, dimension, validData, error +func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, []bool, error) { if !typeutil.IsSparseFloatVectorType(r.colType) { - return nil, -1, fmt.Errorf("failed to get sparse float vector from datatype %v", r.colType.String()) + return nil, -1, nil, fmt.Errorf("failed to get sparse float vector from datatype %v", r.colType.String()) } + + if r.nullable { + fieldData := &SparseFloatVectorFieldData{} + validData := make([]bool, r.numRows) + + fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator) + if err != nil { + return nil, -1, nil, err + } + + table, err := fileReader.ReadTable(context.Background()) + if err != nil { + return nil, -1, nil, err + } + defer table.Release() + + if table.NumCols() != 1 { + return nil, -1, nil, fmt.Errorf("expected 1 column, got %d", table.NumCols()) + } + + column := table.Column(0) + offset := 0 + for _, chunk := range column.Data().Chunks() { + binaryArray, ok := chunk.(*array.Binary) + if !ok { + return nil, -1, nil, fmt.Errorf("expected Binary array, got %T", chunk) + } + + for i := 0; i < binaryArray.Len(); i++ { + validData[offset+i] = binaryArray.IsValid(i) + if validData[offset+i] { + value := binaryArray.Value(i) + if len(value)%8 != 0 { + return nil, -1, nil, errors.New("invalid bytesData length") + } + fieldData.Contents = append(fieldData.Contents, value) + rowDim := typeutil.SparseFloatRowDim(value) + if rowDim > fieldData.Dim { + fieldData.Dim = rowDim + } + } else { + fieldData.Contents = append(fieldData.Contents, nil) + } + } + offset += binaryArray.Len() + } + + return fieldData, int(fieldData.Dim), validData, nil + } + values := make([]parquet.ByteArray, r.numRows) valuesRead, err := ReadDataFromAllRowGroups[parquet.ByteArray, *file.ByteArrayColumnChunkReader](r.reader, values, 0, r.numRows) if err != nil { - return nil, -1, err + return nil, -1, nil, err } if valuesRead != r.numRows { - return nil, -1, fmt.Errorf("expect %d binary, but got = %d", r.numRows, valuesRead) + return nil, -1, nil, fmt.Errorf("expect %d binary, but got = %d", r.numRows, valuesRead) } fieldData := &SparseFloatVectorFieldData{} for _, value := range values { if len(value)%8 != 0 { - return nil, -1, errors.New("invalid bytesData length") + return nil, -1, nil, errors.New("invalid bytesData length") } fieldData.Contents = append(fieldData.Contents, value) @@ -819,17 +1208,101 @@ func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFie } } - return fieldData, int(fieldData.Dim), nil + return fieldData, int(fieldData.Dim), nil, nil } -// GetInt8VectorFromPayload returns vector, dimension, error -func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, error) { +// GetInt8VectorFromPayload returns vector, dimension, validData, numRows, error +func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, []bool, int, error) { if r.colType != schemapb.DataType_Int8Vector { - return nil, -1, fmt.Errorf("failed to get int8 vector from datatype %v", r.colType.String()) + return nil, -1, nil, 0, fmt.Errorf("failed to get int8 vector from datatype %v", r.colType.String()) } + if r.nullable { + fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator) + if err != nil { + return nil, -1, nil, 0, err + } + + arrowSchema, err := fileReader.Schema() + if err != nil { + return nil, -1, nil, 0, err + } + + if arrowSchema.NumFields() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields()) + } + + field := arrowSchema.Field(0) + var dim int + + if field.Type.ID() == arrow.BINARY { + if !field.HasMetadata() { + return nil, -1, nil, 0, fmt.Errorf("nullable int8 vector field is missing metadata") + } + metadata := field.Metadata + dimStr, ok := metadata.GetValue("dim") + if !ok { + return nil, -1, nil, 0, fmt.Errorf("nullable int8 vector metadata missing required 'dim' field") + } + var err error + dim, err = strconv.Atoi(dimStr) + if err != nil { + return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err) + } + } else { + col, err := r.reader.RowGroup(0).Column(0) + if err != nil { + return nil, -1, nil, 0, err + } + dim = col.Descriptor().TypeLength() + } + + table, err := fileReader.ReadTable(context.Background()) + if err != nil { + return nil, -1, nil, 0, err + } + defer table.Release() + + if table.NumCols() != 1 { + return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols()) + } + + column := table.Column(0) + validCount := 0 + for _, chunk := range column.Data().Chunks() { + for i := 0; i < chunk.Len(); i++ { + if chunk.IsValid(i) { + validCount++ + } + } + } + + ret := make([]int8, validCount*dim) + validData := make([]bool, r.numRows) + offset := 0 + dataIdx := 0 + for _, chunk := range column.Data().Chunks() { + binaryArray, ok := chunk.(*array.Binary) + if !ok { + return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk) + } + for i := 0; i < binaryArray.Len(); i++ { + if binaryArray.IsValid(i) { + validData[offset+i] = true + bytes := binaryArray.Value(i) + int8Vals := arrow.Int8Traits.CastFromBytes(bytes) + copy(ret[dataIdx*dim:(dataIdx+1)*dim], int8Vals) + dataIdx++ + } + } + offset += binaryArray.Len() + } + + return ret, dim, validData, int(r.numRows), nil + } + col, err := r.reader.RowGroup(0).Column(0) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } dim := col.Descriptor().TypeLength() @@ -837,11 +1310,11 @@ func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, error) { values := make([]parquet.FixedLenByteArray, r.numRows) valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows) if err != nil { - return nil, -1, err + return nil, -1, nil, 0, err } if valuesRead != r.numRows { - return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) + return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead) } ret := make([]int8, int64(dim)*r.numRows) @@ -849,7 +1322,7 @@ func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, error) { int8Vals := arrow.Int8Traits.CastFromBytes(values[i]) copy(ret[i*dim:(i+1)*dim], int8Vals) } - return ret, dim, nil + return ret, dim, nil, int(r.numRows), nil } func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) { diff --git a/internal/storage/payload_test.go b/internal/storage/payload_test.go index 5d848a0bb9..a5978eb8d0 100644 --- a/internal/storage/payload_test.go +++ b/internal/storage/payload_test.go @@ -484,7 +484,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { in2[i] = 1 } - err = w.AddBinaryVectorToPayload(in, 8) + err = w.AddBinaryVectorToPayload(in, 8, nil) assert.NoError(t, err) err = w.AddDataToPayloadForUT(in2, nil) assert.NoError(t, err) @@ -505,7 +505,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.NoError(t, err) assert.Equal(t, length, 24) - binVecs, dim, err := r.GetBinaryVectorFromPayload() + binVecs, dim, _, _, err := r.GetBinaryVectorFromPayload() assert.NoError(t, err) assert.Equal(t, 8, dim) assert.Equal(t, 24, len(binVecs)) @@ -524,7 +524,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { require.Nil(t, err) require.NotNil(t, w) - err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1) + err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1, nil) assert.NoError(t, err) err = w.AddDataToPayloadForUT([]float32{3.0, 4.0}, nil) assert.NoError(t, err) @@ -545,7 +545,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.NoError(t, err) assert.Equal(t, length, 4) - floatVecs, dim, err := r.GetFloatVectorFromPayload() + floatVecs, dim, _, _, err := r.GetFloatVectorFromPayload() assert.NoError(t, err) assert.Equal(t, 1, dim) assert.Equal(t, 4, len(floatVecs)) @@ -566,7 +566,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { require.Nil(t, err) require.NotNil(t, w) - err = w.AddFloat16VectorToPayload([]byte{1, 2}, 1) + err = w.AddFloat16VectorToPayload([]byte{1, 2}, 1, nil) assert.NoError(t, err) err = w.AddDataToPayloadForUT([]byte{3, 4}, nil) assert.NoError(t, err) @@ -587,7 +587,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.NoError(t, err) assert.Equal(t, length, 2) - float16Vecs, dim, err := r.GetFloat16VectorFromPayload() + float16Vecs, dim, _, _, err := r.GetFloat16VectorFromPayload() assert.NoError(t, err) assert.Equal(t, 1, dim) assert.Equal(t, 4, len(float16Vecs)) @@ -608,7 +608,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { require.Nil(t, err) require.NotNil(t, w) - err = w.AddBFloat16VectorToPayload([]byte{1, 2}, 1) + err = w.AddBFloat16VectorToPayload([]byte{1, 2}, 1, nil) assert.NoError(t, err) err = w.AddDataToPayloadForUT([]byte{3, 4}, nil) assert.NoError(t, err) @@ -629,7 +629,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.NoError(t, err) assert.Equal(t, length, 2) - bfloat16Vecs, dim, err := r.GetBFloat16VectorFromPayload() + bfloat16Vecs, dim, _, _, err := r.GetBFloat16VectorFromPayload() assert.NoError(t, err) assert.Equal(t, 1, dim) assert.Equal(t, 4, len(bfloat16Vecs)) @@ -689,7 +689,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.NoError(t, err) assert.Equal(t, length, 6) - floatVecs, dim, err := r.GetSparseFloatVectorFromPayload() + floatVecs, dim, _, err := r.GetSparseFloatVectorFromPayload() assert.NoError(t, err) assert.Equal(t, 600, dim) assert.Equal(t, 6, len(floatVecs.Contents)) @@ -743,7 +743,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.NoError(t, err) assert.Equal(t, length, 3) - floatVecs, dim, err := r.GetSparseFloatVectorFromPayload() + floatVecs, dim, _, err := r.GetSparseFloatVectorFromPayload() assert.NoError(t, err) assert.Equal(t, actualDim, dim) assert.Equal(t, 3, len(floatVecs.Contents)) @@ -951,16 +951,16 @@ func TestPayload_ReaderAndWriter(t *testing.T) { err = w.FinishPayloadWriter() assert.NoError(t, err) - err = w.AddBinaryVectorToPayload([]byte{}, 8) + err = w.AddBinaryVectorToPayload([]byte{}, 8, nil) assert.Error(t, err) - err = w.AddBinaryVectorToPayload([]byte{1}, 0) + err = w.AddBinaryVectorToPayload([]byte{1}, 0, nil) assert.Error(t, err) - err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8) + err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil) assert.Error(t, err) err = w.FinishPayloadWriter() assert.Error(t, err) - err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8) + err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil) assert.Error(t, err) }) t.Run("TestAddFloatVectorAfterFinish", func(t *testing.T) { @@ -972,16 +972,16 @@ func TestPayload_ReaderAndWriter(t *testing.T) { err = w.FinishPayloadWriter() assert.NoError(t, err) - err = w.AddFloatVectorToPayload([]float32{}, 8) + err = w.AddFloatVectorToPayload([]float32{}, 8, nil) assert.Error(t, err) - err = w.AddFloatVectorToPayload([]float32{1.0}, 0) + err = w.AddFloatVectorToPayload([]float32{1.0}, 0, nil) assert.Error(t, err) - err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8) + err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8, nil) assert.Error(t, err) err = w.FinishPayloadWriter() assert.Error(t, err) - err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8) + err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8, nil) assert.Error(t, err) }) t.Run("TestAddFloat16VectorAfterFinish", func(t *testing.T) { @@ -990,22 +990,22 @@ func TestPayload_ReaderAndWriter(t *testing.T) { require.NotNil(t, w) defer w.Close() - err = w.AddFloat16VectorToPayload([]byte{}, 8) + err = w.AddFloat16VectorToPayload([]byte{}, 8, nil) assert.Error(t, err) err = w.FinishPayloadWriter() assert.NoError(t, err) - err = w.AddFloat16VectorToPayload([]byte{}, 8) + err = w.AddFloat16VectorToPayload([]byte{}, 8, nil) assert.Error(t, err) - err = w.AddFloat16VectorToPayload([]byte{1}, 0) + err = w.AddFloat16VectorToPayload([]byte{1}, 0, nil) assert.Error(t, err) - err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8) + err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil) assert.Error(t, err) err = w.FinishPayloadWriter() assert.Error(t, err) - err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8) + err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil) assert.Error(t, err) }) t.Run("TestAddBFloat16VectorAfterFinish", func(t *testing.T) { @@ -1014,22 +1014,22 @@ func TestPayload_ReaderAndWriter(t *testing.T) { require.NotNil(t, w) defer w.Close() - err = w.AddBFloat16VectorToPayload([]byte{}, 8) + err = w.AddBFloat16VectorToPayload([]byte{}, 8, nil) assert.Error(t, err) err = w.FinishPayloadWriter() assert.NoError(t, err) - err = w.AddBFloat16VectorToPayload([]byte{}, 8) + err = w.AddBFloat16VectorToPayload([]byte{}, 8, nil) assert.Error(t, err) - err = w.AddBFloat16VectorToPayload([]byte{1}, 0) + err = w.AddBFloat16VectorToPayload([]byte{1}, 0, nil) assert.Error(t, err) - err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8) + err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil) assert.Error(t, err) err = w.FinishPayloadWriter() assert.Error(t, err) - err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8) + err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil) assert.Error(t, err) }) t.Run("TestAddSparseFloatVectorAfterFinish", func(t *testing.T) { @@ -1481,11 +1481,11 @@ func TestPayload_ReaderAndWriter(t *testing.T) { r, err := NewPayloadReader(schemapb.DataType_BinaryVector, buffer, false) assert.NoError(t, err) - _, _, err = r.GetBinaryVectorFromPayload() + _, _, _, _, err = r.GetBinaryVectorFromPayload() assert.Error(t, err) r.colType = 999 - _, _, err = r.GetBinaryVectorFromPayload() + _, _, _, _, err = r.GetBinaryVectorFromPayload() assert.Error(t, err) }) t.Run("TestGetBinaryVectorError2", func(t *testing.T) { @@ -1493,7 +1493,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { require.Nil(t, err) require.NotNil(t, w) - err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8) + err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil) assert.NoError(t, err) err = w.FinishPayloadWriter() @@ -1506,7 +1506,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.NoError(t, err) r.numRows = 99 - _, _, err = r.GetBinaryVectorFromPayload() + _, _, _, _, err = r.GetBinaryVectorFromPayload() assert.Error(t, err) }) t.Run("TestGetFloatVectorError", func(t *testing.T) { @@ -1526,11 +1526,11 @@ func TestPayload_ReaderAndWriter(t *testing.T) { r, err := NewPayloadReader(schemapb.DataType_FloatVector, buffer, false) assert.NoError(t, err) - _, _, err = r.GetFloatVectorFromPayload() + _, _, _, _, err = r.GetFloatVectorFromPayload() assert.Error(t, err) r.colType = 999 - _, _, err = r.GetFloatVectorFromPayload() + _, _, _, _, err = r.GetFloatVectorFromPayload() assert.Error(t, err) }) t.Run("TestGetFloatVectorError2", func(t *testing.T) { @@ -1538,7 +1538,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { require.Nil(t, err) require.NotNil(t, w) - err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8) + err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8, nil) assert.NoError(t, err) err = w.FinishPayloadWriter() @@ -1551,7 +1551,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { assert.NoError(t, err) r.numRows = 99 - _, _, err = r.GetFloatVectorFromPayload() + _, _, _, _, err = r.GetFloatVectorFromPayload() assert.Error(t, err) }) @@ -1599,7 +1599,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { w, err := NewPayloadWriter(schemapb.DataType_FloatVector) assert.NoError(t, err) - err = w.AddFloatVectorToPayload(vec, 128) + err = w.AddFloatVectorToPayload(vec, 128, nil) assert.NoError(t, err) err = w.FinishPayloadWriter() @@ -2234,19 +2234,548 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) { w.ReleasePayloadWriter() }) - t.Run("TestBinaryVector", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithNullable(true), WithDim(8)) - assert.ErrorIs(t, err, merr.ErrParameterInvalid) + t.Run("TestFloatVector", func(t *testing.T) { + dim := 128 + numRows := 100 + + type testCase struct { + name string + validDataSetup func([]bool) int + } + + testCases := []testCase{ + { + name: "half null", + validDataSetup: func(validData []bool) int { + validCount := 0 + for i := 0; i < numRows; i++ { + if i%2 == 0 { + validData[i] = true + validCount++ + } + } + return validCount + }, + }, + { + name: "all valid", + validDataSetup: func(validData []bool) int { + for i := 0; i < numRows; i++ { + validData[i] = true + } + return numRows + }, + }, + { + name: "all null", + validDataSetup: func(validData []bool) int { + return 0 + }, + }, + } + + for _, tc := range testCases { + w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(dim), WithNullable(true)) + require.NoError(t, err) + + validData := make([]bool, numRows) + validCount := tc.validDataSetup(validData) + + data := make([]float32, validCount*dim) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if validData[i] { + for j := 0; j < dim; j++ { + data[dataIdx*dim+j] = float32(i*100 + j) + } + dataIdx++ + } + } + + err = w.AddFloatVectorToPayload(data, dim, validData) + require.NoError(t, err) + + err = w.FinishPayloadWriter() + require.NoError(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + require.NoError(t, err) + + r, err := NewPayloadReader(schemapb.DataType_FloatVector, buffer, true) + require.NoError(t, err) + + readData, readDim, readValid, readNumRows, err := r.GetFloatVectorFromPayload() + require.NoError(t, err) + require.Equal(t, dim, readDim) + require.Equal(t, numRows, readNumRows) + require.Equal(t, numRows, len(readValid)) + + dataIdx = 0 + for i := 0; i < numRows; i++ { + require.Equal(t, validData[i], readValid[i]) + if validData[i] { + pos := dataIdx + for j := 0; j < dim; j++ { + require.Equal(t, data[dataIdx*dim+j], readData[pos*dim+j]) + } + dataIdx++ + } + } + } }) - t.Run("TestFloatVector", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithNullable(true), WithDim(1)) - assert.ErrorIs(t, err, merr.ErrParameterInvalid) + t.Run("TestBinaryVector", func(t *testing.T) { + dim := 128 + numRows := 100 + + type testCase struct { + name string + validDataSetup func([]bool) int + } + + testCases := []testCase{ + { + name: "partial null", + validDataSetup: func(validData []bool) int { + validCount := 0 + for i := 0; i < numRows; i++ { + if i%3 == 0 { + validData[i] = true + validCount++ + } + } + return validCount + }, + }, + { + name: "all valid", + validDataSetup: func(validData []bool) int { + for i := 0; i < numRows; i++ { + validData[i] = true + } + return numRows + }, + }, + { + name: "all null", + validDataSetup: func(validData []bool) int { + return 0 + }, + }, + } + + for _, tc := range testCases { + w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(dim), WithNullable(true)) + require.NoError(t, err) + + validData := make([]bool, numRows) + validCount := tc.validDataSetup(validData) + + data := make([]byte, validCount*dim/8) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if validData[i] { + for j := 0; j < dim/8; j++ { + data[dataIdx*dim/8+j] = byte(i + j) + } + dataIdx++ + } + } + + err = w.AddBinaryVectorToPayload(data, dim, validData) + require.NoError(t, err) + + err = w.FinishPayloadWriter() + require.NoError(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + require.NoError(t, err) + + r, err := NewPayloadReader(schemapb.DataType_BinaryVector, buffer, true) + require.NoError(t, err) + + readData, readDim, readValid, readNumRows, err := r.GetBinaryVectorFromPayload() + require.NoError(t, err) + require.Equal(t, dim, readDim) + require.Equal(t, numRows, readNumRows) + require.Equal(t, numRows, len(readValid)) + + dataIdx = 0 + for i := 0; i < numRows; i++ { + require.Equal(t, validData[i], readValid[i]) + if validData[i] { + pos := dataIdx + for j := 0; j < dim/8; j++ { + require.Equal(t, data[dataIdx*dim/8+j], readData[pos*dim/8+j]) + } + dataIdx++ + } + } + } }) t.Run("TestFloat16Vector", func(t *testing.T) { - _, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithNullable(true), WithDim(1)) - assert.ErrorIs(t, err, merr.ErrParameterInvalid) + dim := 128 + numRows := 100 + + type testCase struct { + name string + validDataSetup func([]bool) int + } + + testCases := []testCase{ + { + name: "partial null", + validDataSetup: func(validData []bool) int { + validCount := 0 + for i := 0; i < numRows; i++ { + if i%2 == 1 { + validData[i] = true + validCount++ + } + } + return validCount + }, + }, + { + name: "all valid", + validDataSetup: func(validData []bool) int { + for i := 0; i < numRows; i++ { + validData[i] = true + } + return numRows + }, + }, + { + name: "all null", + validDataSetup: func(validData []bool) int { + return 0 + }, + }, + } + + for _, tc := range testCases { + w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithDim(dim), WithNullable(true)) + require.NoError(t, err) + + validData := make([]bool, numRows) + validCount := tc.validDataSetup(validData) + + data := make([]byte, validCount*dim*2) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if validData[i] { + for j := 0; j < dim*2; j++ { + data[dataIdx*dim*2+j] = byte((i*10 + j) % 256) + } + dataIdx++ + } + } + + err = w.AddFloat16VectorToPayload(data, dim, validData) + require.NoError(t, err) + + err = w.FinishPayloadWriter() + require.NoError(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + require.NoError(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Float16Vector, buffer, true) + require.NoError(t, err) + + readData, readDim, readValid, readNumRows, err := r.GetFloat16VectorFromPayload() + require.NoError(t, err) + require.Equal(t, dim, readDim) + require.Equal(t, numRows, readNumRows) + require.Equal(t, numRows, len(readValid)) + + dataIdx = 0 + for i := 0; i < numRows; i++ { + require.Equal(t, validData[i], readValid[i]) + if validData[i] { + pos := dataIdx + for j := 0; j < dim*2; j++ { + require.Equal(t, data[dataIdx*dim*2+j], readData[pos*dim*2+j]) + } + dataIdx++ + } + } + } + }) + + t.Run("TestBFloat16Vector", func(t *testing.T) { + dim := 128 + numRows := 100 + + type testCase struct { + name string + validDataSetup func([]bool) int + } + + testCases := []testCase{ + { + name: "partial null", + validDataSetup: func(validData []bool) int { + validCount := 0 + for i := 0; i < numRows; i++ { + if (i+1)%3 != 0 { + validData[i] = true + validCount++ + } + } + return validCount + }, + }, + { + name: "all valid", + validDataSetup: func(validData []bool) int { + for i := 0; i < numRows; i++ { + validData[i] = true + } + return numRows + }, + }, + { + name: "all null", + validDataSetup: func(validData []bool) int { + return 0 + }, + }, + } + + for _, tc := range testCases { + w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, WithDim(dim), WithNullable(true)) + require.NoError(t, err) + + validData := make([]bool, numRows) + validCount := tc.validDataSetup(validData) + + data := make([]byte, validCount*dim*2) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if validData[i] { + for j := 0; j < dim*2; j++ { + data[dataIdx*dim*2+j] = byte((i*20 + j) % 256) + } + dataIdx++ + } + } + + err = w.AddBFloat16VectorToPayload(data, dim, validData) + require.NoError(t, err) + + err = w.FinishPayloadWriter() + require.NoError(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + require.NoError(t, err) + + r, err := NewPayloadReader(schemapb.DataType_BFloat16Vector, buffer, true) + require.NoError(t, err) + + readData, readDim, readValid, readNumRows, err := r.GetBFloat16VectorFromPayload() + require.NoError(t, err) + require.Equal(t, dim, readDim) + require.Equal(t, numRows, readNumRows) + require.Equal(t, numRows, len(readValid)) + + dataIdx = 0 + for i := 0; i < numRows; i++ { + require.Equal(t, validData[i], readValid[i]) + if validData[i] { + pos := dataIdx + for j := 0; j < dim*2; j++ { + require.Equal(t, data[dataIdx*dim*2+j], readData[pos*dim*2+j]) + } + dataIdx++ + } + } + } + }) + + t.Run("TestInt8Vector", func(t *testing.T) { + dim := 128 + numRows := 100 + + type testCase struct { + name string + validDataSetup func([]bool) int + } + + testCases := []testCase{ + { + name: "partial null", + validDataSetup: func(validData []bool) int { + validCount := 0 + for i := 0; i < numRows; i++ { + if i < numRows/2 { + validData[i] = true + validCount++ + } + } + return validCount + }, + }, + { + name: "all valid", + validDataSetup: func(validData []bool) int { + for i := 0; i < numRows; i++ { + validData[i] = true + } + return numRows + }, + }, + { + name: "all null", + validDataSetup: func(validData []bool) int { + return 0 + }, + }, + } + + for _, tc := range testCases { + w, err := NewPayloadWriter(schemapb.DataType_Int8Vector, WithDim(dim), WithNullable(true)) + require.NoError(t, err) + + validData := make([]bool, numRows) + validCount := tc.validDataSetup(validData) + + data := make([]int8, validCount*dim) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if validData[i] { + for j := 0; j < dim; j++ { + data[dataIdx*dim+j] = int8((i*10 + j) % 128) + } + dataIdx++ + } + } + + err = w.AddInt8VectorToPayload(data, dim, validData) + require.NoError(t, err) + + err = w.FinishPayloadWriter() + require.NoError(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + require.NoError(t, err) + + r, err := NewPayloadReader(schemapb.DataType_Int8Vector, buffer, true) + require.NoError(t, err) + + readData, readDim, readValid, readNumRows, err := r.GetInt8VectorFromPayload() + require.NoError(t, err) + require.Equal(t, dim, readDim) + require.Equal(t, numRows, readNumRows) + require.Equal(t, numRows, len(readValid)) + + dataIdx = 0 + for i := 0; i < numRows; i++ { + require.Equal(t, validData[i], readValid[i]) + if validData[i] { + pos := dataIdx + for j := 0; j < dim; j++ { + require.Equal(t, data[dataIdx*dim+j], readData[pos*dim+j]) + } + dataIdx++ + } + } + } + }) + + t.Run("TestSparseFloatVector", func(t *testing.T) { + numRows := 100 + + type testCase struct { + name string + validDataSetup func([]bool) int + } + + testCases := []testCase{ + { + name: "half null", + validDataSetup: func(validData []bool) int { + validCount := 0 + for i := 0; i < numRows; i++ { + if i%2 == 0 { + validData[i] = true + validCount++ + } + } + return validCount + }, + }, + { + name: "all valid", + validDataSetup: func(validData []bool) int { + for i := 0; i < numRows; i++ { + validData[i] = true + } + return numRows + }, + }, + { + name: "all null", + validDataSetup: func(validData []bool) int { + return 0 + }, + }, + } + + for _, tc := range testCases { + w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector, WithNullable(true)) + require.NoError(t, err) + + validData := make([]bool, numRows) + tc.validDataSetup(validData) + + data := &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Dim: 100, + }, + ValidData: validData, + } + for i := 0; i < numRows; i++ { + if validData[i] { + sparseVec := make([]byte, 16) + for j := 0; j < 16; j++ { + sparseVec[j] = byte((i*10 + j) % 256) + } + data.SparseFloatArray.Contents = append(data.SparseFloatArray.Contents, sparseVec) + } + } + + err = w.AddSparseFloatVectorToPayload(data) + require.NoError(t, err) + + err = w.FinishPayloadWriter() + require.NoError(t, err) + + buffer, err := w.GetPayloadBufferFromWriter() + require.NoError(t, err) + + r, err := NewPayloadReader(schemapb.DataType_SparseFloatVector, buffer, true) + require.NoError(t, err) + + readData, _, readValid, err := r.GetSparseFloatVectorFromPayload() + require.NoError(t, err) + require.Equal(t, numRows, len(readValid)) + require.Equal(t, numRows, len(readData.Contents)) + + for i := 0; i < numRows; i++ { + require.Equal(t, validData[i], readValid[i]) + if validData[i] { + require.NotNil(t, readData.Contents[i]) + require.Equal(t, 16, len(readData.Contents[i])) + for j := 0; j < 16; j++ { + require.Equal(t, byte((i*10+j)%256), readData.Contents[i][j]) + } + } else { + require.Nil(t, readData.Contents[i]) + } + } + } }) t.Run("TestAddBool with wrong valids", func(t *testing.T) { diff --git a/internal/storage/payload_writer.go b/internal/storage/payload_writer.go index e06a8a5a8e..d9e56db416 100644 --- a/internal/storage/payload_writer.go +++ b/internal/storage/payload_writer.go @@ -103,9 +103,6 @@ func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions if w.dim.IsNull() { return nil, merr.WrapErrParameterInvalidMsg("incorrect input numbers") } - if w.nullable { - return nil, merr.WrapErrParameterInvalidMsg("vector type does not support nullable") - } } else { w.dim = NewNullableInt(1) } @@ -125,8 +122,13 @@ func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions w.arrowType = arrow.ListOf(elemType) w.builder = array.NewListBuilder(memory.DefaultAllocator, elemType) } else { - w.arrowType = MilvusDataTypeToArrowType(colType, *w.dim.Value) - w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType) + if w.nullable && typeutil.IsVectorType(colType) && !typeutil.IsSparseFloatVectorType(colType) { + w.arrowType = &arrow.BinaryType{} + w.builder = array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary) + } else { + w.arrowType = MilvusDataTypeToArrowType(colType, *w.dim.Value) + w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType) + } } return w, nil } @@ -262,25 +264,25 @@ func (w *NativePayloadWriter) AddDataToPayloadForUT(data interface{}, validData if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddBinaryVectorToPayload(val, w.dim.GetValue()) + return w.AddBinaryVectorToPayload(val, w.dim.GetValue(), validData) case schemapb.DataType_FloatVector: val, ok := data.([]float32) if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddFloatVectorToPayload(val, w.dim.GetValue()) + return w.AddFloatVectorToPayload(val, w.dim.GetValue(), validData) case schemapb.DataType_Float16Vector: val, ok := data.([]byte) if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddFloat16VectorToPayload(val, w.dim.GetValue()) + return w.AddFloat16VectorToPayload(val, w.dim.GetValue(), validData) case schemapb.DataType_BFloat16Vector: val, ok := data.([]byte) if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddBFloat16VectorToPayload(val, w.dim.GetValue()) + return w.AddBFloat16VectorToPayload(val, w.dim.GetValue(), validData) case schemapb.DataType_SparseFloatVector: val, ok := data.(*SparseFloatVectorFieldData) if !ok { @@ -292,7 +294,7 @@ func (w *NativePayloadWriter) AddDataToPayloadForUT(data interface{}, validData if !ok { return merr.WrapErrParameterInvalidMsg("incorrect data type") } - return w.AddInt8VectorToPayload(val, w.dim.GetValue()) + return w.AddInt8VectorToPayload(val, w.dim.GetValue(), validData) case schemapb.DataType_ArrayOfVector: val, ok := data.(*VectorArrayFieldData) if !ok { @@ -660,106 +662,262 @@ func (w *NativePayloadWriter) AddOneGeometryToPayload(data []byte, isValid bool) return nil } -func (w *NativePayloadWriter) AddBinaryVectorToPayload(data []byte, dim int) error { +func (w *NativePayloadWriter) AddBinaryVectorToPayload(data []byte, dim int, validData []bool) error { if w.finished { return errors.New("can't append data to finished binary vector payload") } - if len(data) == 0 { - return errors.New("can't add empty msgs into binary vector payload") - } - - builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) - if !ok { - return errors.New("failed to cast BinaryVectorBuilder") - } - byteLength := dim / 8 - length := len(data) / byteLength - builder.Reserve(length) - for i := 0; i < length; i++ { - builder.Append(data[i*byteLength : (i+1)*byteLength]) + var numRows int + if w.nullable && len(validData) > 0 { + numRows = len(validData) + validCount := 0 + for _, valid := range validData { + if valid { + validCount++ + } + } + expectedDataLen := validCount * byteLength + if len(data) != expectedDataLen { + msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen) + return merr.WrapErrParameterInvalidMsg(msg) + } + } else { + if len(data) == 0 { + return errors.New("can't add empty msgs into binary vector payload") + } + numRows = len(data) / byteLength + if !w.nullable && len(validData) != 0 { + msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)) + return merr.WrapErrParameterInvalidMsg(msg) + } + } + + if w.nullable { + builder, ok := w.builder.(*array.BinaryBuilder) + if !ok { + return errors.New("failed to cast to BinaryBuilder for nullable BinaryVector") + } + + builder.Reserve(numRows) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if len(validData) > 0 && !validData[i] { + builder.AppendNull() + } else { + builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength]) + dataIdx++ + } + } + } else { + builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) + if !ok { + return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable BinaryVector") + } + + builder.Reserve(numRows) + for i := 0; i < numRows; i++ { + builder.Append(data[i*byteLength : (i+1)*byteLength]) + } } return nil } -func (w *NativePayloadWriter) AddFloatVectorToPayload(data []float32, dim int) error { +func (w *NativePayloadWriter) AddFloatVectorToPayload(data []float32, dim int, validData []bool) error { if w.finished { return errors.New("can't append data to finished float vector payload") } - if len(data) == 0 { - return errors.New("can't add empty msgs into float vector payload") - } - - builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) - if !ok { - return errors.New("failed to cast FloatVectorBuilder") + var numRows int + if w.nullable && len(validData) > 0 { + numRows = len(validData) + validCount := 0 + for _, valid := range validData { + if valid { + validCount++ + } + } + expectedDataLen := validCount * dim + if len(data) != expectedDataLen { + msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * dim(%d) = %d", len(data), validCount, dim, expectedDataLen) + return merr.WrapErrParameterInvalidMsg(msg) + } + } else { + if len(data) == 0 { + return errors.New("can't add empty msgs into float vector payload") + } + numRows = len(data) / dim + if !w.nullable && len(validData) != 0 { + msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)) + return merr.WrapErrParameterInvalidMsg(msg) + } } byteLength := dim * 4 - length := len(data) / dim - builder.Reserve(length) - bytesData := make([]byte, byteLength) - for i := 0; i < length; i++ { - vec := data[i*dim : (i+1)*dim] - for j := range vec { - bytes := math.Float32bits(vec[j]) - common.Endian.PutUint32(bytesData[j*4:], bytes) + if w.nullable { + builder, ok := w.builder.(*array.BinaryBuilder) + if !ok { + return errors.New("failed to cast to BinaryBuilder for nullable FloatVector") + } + + builder.Reserve(numRows) + bytesData := make([]byte, byteLength) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if len(validData) > 0 && !validData[i] { + builder.AppendNull() + } else { + vec := data[dataIdx*dim : (dataIdx+1)*dim] + for j := range vec { + bytes := math.Float32bits(vec[j]) + common.Endian.PutUint32(bytesData[j*4:], bytes) + } + builder.Append(bytesData) + dataIdx++ + } + } + } else { + builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) + if !ok { + return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable FloatVector") + } + + builder.Reserve(numRows) + bytesData := make([]byte, byteLength) + for i := 0; i < numRows; i++ { + vec := data[i*dim : (i+1)*dim] + for j := range vec { + bytes := math.Float32bits(vec[j]) + common.Endian.PutUint32(bytesData[j*4:], bytes) + } + builder.Append(bytesData) } - builder.Append(bytesData) } return nil } -func (w *NativePayloadWriter) AddFloat16VectorToPayload(data []byte, dim int) error { +func (w *NativePayloadWriter) AddFloat16VectorToPayload(data []byte, dim int, validData []bool) error { if w.finished { return errors.New("can't append data to finished float16 payload") } - if len(data) == 0 { - return errors.New("can't add empty msgs into float16 payload") - } - - builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) - if !ok { - return errors.New("failed to cast Float16Builder") - } - byteLength := dim * 2 - length := len(data) / byteLength + var numRows int + if w.nullable && len(validData) > 0 { + numRows = len(validData) + validCount := 0 + for _, valid := range validData { + if valid { + validCount++ + } + } + expectedDataLen := validCount * byteLength + if len(data) != expectedDataLen { + msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen) + return merr.WrapErrParameterInvalidMsg(msg) + } + } else { + if len(data) == 0 { + return errors.New("can't add empty msgs into float16 payload") + } + numRows = len(data) / byteLength + if !w.nullable && len(validData) != 0 { + msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)) + return merr.WrapErrParameterInvalidMsg(msg) + } + } - builder.Reserve(length) - for i := 0; i < length; i++ { - builder.Append(data[i*byteLength : (i+1)*byteLength]) + if w.nullable { + builder, ok := w.builder.(*array.BinaryBuilder) + if !ok { + return errors.New("failed to cast to BinaryBuilder for nullable Float16Vector") + } + + builder.Reserve(numRows) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if len(validData) > 0 && !validData[i] { + builder.AppendNull() + } else { + builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength]) + dataIdx++ + } + } + } else { + builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) + if !ok { + return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable Float16Vector") + } + + builder.Reserve(numRows) + for i := 0; i < numRows; i++ { + builder.Append(data[i*byteLength : (i+1)*byteLength]) + } } return nil } -func (w *NativePayloadWriter) AddBFloat16VectorToPayload(data []byte, dim int) error { +func (w *NativePayloadWriter) AddBFloat16VectorToPayload(data []byte, dim int, validData []bool) error { if w.finished { return errors.New("can't append data to finished BFloat16 payload") } - if len(data) == 0 { - return errors.New("can't add empty msgs into BFloat16 payload") - } - - builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) - if !ok { - return errors.New("failed to cast BFloat16Builder") - } - byteLength := dim * 2 - length := len(data) / byteLength + var numRows int + if w.nullable && len(validData) > 0 { + numRows = len(validData) + validCount := 0 + for _, valid := range validData { + if valid { + validCount++ + } + } + expectedDataLen := validCount * byteLength + if len(data) != expectedDataLen { + msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen) + return merr.WrapErrParameterInvalidMsg(msg) + } + } else { + if len(data) == 0 { + return errors.New("can't add empty msgs into BFloat16 payload") + } + numRows = len(data) / byteLength + if !w.nullable && len(validData) != 0 { + msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)) + return merr.WrapErrParameterInvalidMsg(msg) + } + } - builder.Reserve(length) - for i := 0; i < length; i++ { - builder.Append(data[i*byteLength : (i+1)*byteLength]) + if w.nullable { + builder, ok := w.builder.(*array.BinaryBuilder) + if !ok { + return errors.New("failed to cast to BinaryBuilder for nullable BFloat16Vector") + } + + builder.Reserve(numRows) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if len(validData) > 0 && !validData[i] { + builder.AppendNull() + } else { + builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength]) + dataIdx++ + } + } + } else { + builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) + if !ok { + return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable BFloat16Vector") + } + + builder.Reserve(numRows) + for i := 0; i < numRows; i++ { + builder.Append(data[i*byteLength : (i+1)*byteLength]) + } } return nil @@ -769,41 +927,107 @@ func (w *NativePayloadWriter) AddSparseFloatVectorToPayload(data *SparseFloatVec if w.finished { return errors.New("can't append data to finished sparse float vector payload") } + + var numRows int + if w.nullable && len(data.ValidData) > 0 { + numRows = len(data.ValidData) + validCount := 0 + for _, valid := range data.ValidData { + if valid { + validCount++ + } + } + if len(data.SparseFloatArray.Contents) != validCount { + msg := fmt.Sprintf("when nullable, Contents length(%d) must equal to valid count(%d)", len(data.SparseFloatArray.Contents), validCount) + return merr.WrapErrParameterInvalidMsg(msg) + } + } else { + numRows = len(data.SparseFloatArray.Contents) + if !w.nullable && len(data.ValidData) != 0 { + msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(data.ValidData)) + return merr.WrapErrParameterInvalidMsg(msg) + } + } + builder, ok := w.builder.(*array.BinaryBuilder) if !ok { return errors.New("failed to cast SparseFloatVectorBuilder") } - length := len(data.SparseFloatArray.Contents) - builder.Reserve(length) - for i := 0; i < length; i++ { - builder.Append(data.SparseFloatArray.Contents[i]) + + builder.Reserve(numRows) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if w.nullable && len(data.ValidData) > 0 && !data.ValidData[i] { + builder.AppendNull() + } else { + builder.Append(data.SparseFloatArray.Contents[dataIdx]) + dataIdx++ + } } return nil } -func (w *NativePayloadWriter) AddInt8VectorToPayload(data []int8, dim int) error { +func (w *NativePayloadWriter) AddInt8VectorToPayload(data []int8, dim int, validData []bool) error { if w.finished { return errors.New("can't append data to finished int8 vector payload") } - if len(data) == 0 { - return errors.New("can't add empty msgs into int8 vector payload") + var numRows int + if w.nullable && len(validData) > 0 { + numRows = len(validData) + validCount := 0 + for _, valid := range validData { + if valid { + validCount++ + } + } + expectedDataLen := validCount * dim + if len(data) != expectedDataLen { + msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * dim(%d) = %d", len(data), validCount, dim, expectedDataLen) + return merr.WrapErrParameterInvalidMsg(msg) + } + } else { + if len(data) == 0 { + return errors.New("can't add empty msgs into int8 vector payload") + } + numRows = len(data) / dim + if !w.nullable && len(validData) != 0 { + msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData)) + return merr.WrapErrParameterInvalidMsg(msg) + } } - builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) - if !ok { - return errors.New("failed to cast Int8VectorBuilder") - } + if w.nullable { + builder, ok := w.builder.(*array.BinaryBuilder) + if !ok { + return errors.New("failed to cast to BinaryBuilder for nullable Int8Vector") + } - byteLength := dim - length := len(data) / byteLength + builder.Reserve(numRows) + dataIdx := 0 + for i := 0; i < numRows; i++ { + if len(validData) > 0 && !validData[i] { + builder.AppendNull() + } else { + vec := data[dataIdx*dim : (dataIdx+1)*dim] + vecBytes := arrow.Int8Traits.CastToBytes(vec) + builder.Append(vecBytes) + dataIdx++ + } + } + } else { + builder, ok := w.builder.(*array.FixedSizeBinaryBuilder) + if !ok { + return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable Int8Vector") + } - builder.Reserve(length) - for i := 0; i < length; i++ { - vec := data[i*dim : (i+1)*dim] - vecBytes := arrow.Int8Traits.CastToBytes(vec) - builder.Append(vecBytes) + builder.Reserve(numRows) + for i := 0; i < numRows; i++ { + vec := data[i*dim : (i+1)*dim] + vecBytes := arrow.Int8Traits.CastToBytes(vec) + builder.Append(vecBytes) + } } return nil @@ -827,6 +1051,11 @@ func (w *NativePayloadWriter) FinishPayloadWriter() error { []string{"elementType", "dim"}, []string{fmt.Sprintf("%d", int32(*w.elementType)), fmt.Sprintf("%d", w.dim.GetValue())}, ) + } else if w.nullable && typeutil.IsVectorType(w.dataType) && !typeutil.IsSparseFloatVectorType(w.dataType) { + metadata = arrow.NewMetadata( + []string{"dim"}, + []string{fmt.Sprintf("%d", w.dim.GetValue())}, + ) } field := arrow.Field{ @@ -849,7 +1078,8 @@ func (w *NativePayloadWriter) FinishPayloadWriter() error { defer table.Release() arrowWriterProps := pqarrow.DefaultWriterProps() - if w.dataType == schemapb.DataType_ArrayOfVector { + if w.dataType == schemapb.DataType_ArrayOfVector || + (w.nullable && typeutil.IsVectorType(w.dataType) && !typeutil.IsSparseFloatVectorType(w.dataType)) { // Store metadata in the Arrow writer properties arrowWriterProps = pqarrow.NewArrowWriterProperties( pqarrow.WithStoreSchema(), diff --git a/internal/storage/payload_writer_test.go b/internal/storage/payload_writer_test.go index 8c3c798eb1..f26cb3fa78 100644 --- a/internal/storage/payload_writer_test.go +++ b/internal/storage/payload_writer_test.go @@ -260,14 +260,14 @@ func TestPayloadWriter_Failed(t *testing.T) { err = w.FinishPayloadWriter() require.NoError(t, err) - err = w.AddBinaryVectorToPayload(data, 8) + err = w.AddBinaryVectorToPayload(data, 8, nil) require.Error(t, err) w, err = NewPayloadWriter(schemapb.DataType_Int64) require.Nil(t, err) require.NotNil(t, w) - err = w.AddBinaryVectorToPayload(data, 8) + err = w.AddBinaryVectorToPayload(data, 8, nil) require.Error(t, err) }) diff --git a/internal/storage/print_binlog.go b/internal/storage/print_binlog.go index c046f32e68..133a0d6844 100644 --- a/internal/storage/print_binlog.go +++ b/internal/storage/print_binlog.go @@ -303,7 +303,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface fmt.Printf("\t\t%d : %s\n", i, val[i]) } case schemapb.DataType_BinaryVector: - val, dim, err := reader.GetBinaryVectorFromPayload() + val, dim, _, _, err := reader.GetBinaryVectorFromPayload() if err != nil { return err } @@ -318,7 +318,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface fmt.Println() } case schemapb.DataType_Float16Vector: - val, dim, err := reader.GetFloat16VectorFromPayload() + val, dim, _, _, err := reader.GetFloat16VectorFromPayload() if err != nil { return err } @@ -333,7 +333,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface fmt.Println() } case schemapb.DataType_BFloat16Vector: - val, dim, err := reader.GetBFloat16VectorFromPayload() + val, dim, _, _, err := reader.GetBFloat16VectorFromPayload() if err != nil { return err } @@ -349,7 +349,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface } case schemapb.DataType_FloatVector: - val, dim, err := reader.GetFloatVectorFromPayload() + val, dim, _, _, err := reader.GetFloatVectorFromPayload() if err != nil { return err } @@ -362,6 +362,20 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface } fmt.Println() } + case schemapb.DataType_Int8Vector: + val, dim, _, _, err := reader.GetInt8VectorFromPayload() + if err != nil { + return err + } + length := len(val) / dim + for i := 0; i < length; i++ { + fmt.Printf("\t\t%d :", i) + for j := 0; j < dim; j++ { + idx := i*dim + j + fmt.Printf(" %d", val[idx]) + } + fmt.Println() + } case schemapb.DataType_JSON: rows, err := reader.GetPayloadLengthFromReader() @@ -397,7 +411,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface fmt.Printf("\t\t%d : %v\n", i, v) } case schemapb.DataType_SparseFloatVector: - sparseData, _, err := reader.GetSparseFloatVectorFromPayload() + sparseData, _, _, err := reader.GetSparseFloatVectorFromPayload() if err != nil { return err } diff --git a/internal/storage/print_binlog_test.go b/internal/storage/print_binlog_test.go index 2dea59de37..b7b47ddc99 100644 --- a/internal/storage/print_binlog_test.go +++ b/internal/storage/print_binlog_test.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/etcdpb" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/tsoutil" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" "github.com/milvus-io/milvus/pkg/v2/util/uniquegenerator" ) @@ -215,6 +216,23 @@ func TestPrintBinlogFiles(t *testing.T) { Description: "description_15", DataType: schemapb.DataType_Geometry, }, + { + FieldID: 114, + Name: "field_int8_vector", + IsPrimaryKey: false, + Description: "description_16", + DataType: schemapb.DataType_Int8Vector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "4"}, + }, + }, + { + FieldID: 115, + Name: "field_sparse_float_vector", + IsPrimaryKey: false, + Description: "description_17", + DataType: schemapb.DataType_SparseFloatVector, + }, }, }, } @@ -280,6 +298,19 @@ func TestPrintBinlogFiles(t *testing.T) { {0x01, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A, 0x0D, 0x1B, 0x4F, 0x4F, 0x9A, 0x3D, 0x40, 0x03, 0xA6, 0xB4, 0xA6, 0xA4, 0xD2, 0xC5, 0xC0, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A}, }, }, + 114: &Int8VectorFieldData{ + Data: []int8{1, 2, 3, 4, 5, 6, 7, 8}, + Dim: 4, + }, + 115: &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Dim: 100, + Contents: [][]byte{ + typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{1.1, 1.2, 1.3}), + typeutil.CreateSparseFloatRow([]uint32{10, 20, 30}, []float32{2.1, 2.2, 2.3}), + }, + }, + }, }, } @@ -344,6 +375,19 @@ func TestPrintBinlogFiles(t *testing.T) { {0x01, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A, 0x0D, 0x1B, 0x4F, 0x4F, 0x9A, 0x3D, 0x40, 0x03, 0xA6, 0xB4, 0xA6, 0xA4, 0xD2, 0xC5, 0xC0, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A}, }, }, + 114: &Int8VectorFieldData{ + Data: []int8{11, 12, 13, 14, 15, 16, 17, 18}, + Dim: 4, + }, + 115: &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Dim: 100, + Contents: [][]byte{ + typeutil.CreateSparseFloatRow([]uint32{5, 6, 7}, []float32{3.1, 3.2, 3.3}), + typeutil.CreateSparseFloatRow([]uint32{15, 25, 35}, []float32{4.1, 4.2, 4.3}), + }, + }, + }, }, } firstBlobs, err := insertCodec.Serialize(1, 1, insertDataFirst) diff --git a/internal/storage/schema.go b/internal/storage/schema.go index ddeb748e17..60a6f440ef 100644 --- a/internal/storage/schema.go +++ b/internal/storage/schema.go @@ -38,8 +38,28 @@ func ConvertToArrowSchema(schema *schemapb.CollectionSchema, useFieldID bool) (* } arrowType := serdeMap[field.DataType].arrowType(dim, elementType) + + if field.GetNullable() { + switch field.DataType { + case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector, + schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_Int8Vector: + arrowType = arrow.BinaryTypes.Binary + } + } + arrowField := ConvertToArrowField(field, arrowType, useFieldID) + if field.GetNullable() { + switch field.DataType { + case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector, + schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_Int8Vector: + arrowField.Metadata = arrow.NewMetadata( + []string{packed.ArrowFieldIdMetadataKey, "dim"}, + []string{strconv.Itoa(int(field.GetFieldID())), strconv.Itoa(dim)}, + ) + } + } + // Add extra metadata for ArrayOfVector if field.DataType == schemapb.DataType_ArrayOfVector { arrowField.Metadata = arrow.NewMetadata( diff --git a/internal/storage/schema_test.go b/internal/storage/schema_test.go index 0b72e8aeb7..6784a67ead 100644 --- a/internal/storage/schema_test.go +++ b/internal/storage/schema_test.go @@ -43,12 +43,18 @@ func TestConvertArrowSchema(t *testing.T) { {FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, {FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, {FieldID: 16, Name: "field15", DataType: schemapb.DataType_Int8Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + {FieldID: 17, Name: "field16", DataType: schemapb.DataType_BinaryVector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + {FieldID: 18, Name: "field17", DataType: schemapb.DataType_FloatVector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + {FieldID: 19, Name: "field18", DataType: schemapb.DataType_Float16Vector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + {FieldID: 20, Name: "field19", DataType: schemapb.DataType_BFloat16Vector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + {FieldID: 21, Name: "field20", DataType: schemapb.DataType_Int8Vector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + {FieldID: 22, Name: "field21", DataType: schemapb.DataType_SparseFloatVector, Nullable: true}, } StructArrayFieldSchemas := []*schemapb.StructArrayFieldSchema{ - {FieldID: 17, Name: "struct_field0", Fields: []*schemapb.FieldSchema{ - {FieldID: 18, Name: "field16", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}, - {FieldID: 19, Name: "field17", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}, + {FieldID: 23, Name: "struct_field0", Fields: []*schemapb.FieldSchema{ + {FieldID: 24, Name: "field22", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64}, + {FieldID: 25, Name: "field23", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float}, }}, } @@ -59,6 +65,14 @@ func TestConvertArrowSchema(t *testing.T) { arrowSchema, err := ConvertToArrowSchema(schema, false) assert.NoError(t, err) assert.Equal(t, len(fieldSchemas)+len(StructArrayFieldSchemas[0].Fields), len(arrowSchema.Fields())) + + for i, field := range arrowSchema.Fields() { + if i >= 16 && i <= 20 { + dimVal, ok := field.Metadata.GetValue("dim") + assert.True(t, ok, "nullable vector field should have dim metadata") + assert.Equal(t, "128", dimVal) + } + } } func TestConvertArrowSchemaWithoutDim(t *testing.T) { diff --git a/internal/storage/serde.go b/internal/storage/serde.go index 0e720f90c7..f11d9f811a 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -553,8 +553,12 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { b.AppendNull() return true } - if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok { - if v, ok := v.([]byte); ok { + if v, ok := v.([]byte); ok { + if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok { + builder.Append(v) + return true + } + if builder, ok := b.(*array.BinaryBuilder); ok { builder.Append(v) return true } @@ -607,14 +611,21 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { b.AppendNull() return true } + var bytesData []byte + if vv, ok := v.([]byte); ok { + bytesData = vv + } else if vv, ok := v.([]int8); ok { + bytesData = arrow.Int8Traits.CastToBytes(vv) + } else { + return false + } if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok { - if vv, ok := v.([]byte); ok { - builder.Append(vv) - return true - } else if vv, ok := v.([]int8); ok { - builder.Append(arrow.Int8Traits.CastToBytes(vv)) - return true - } + builder.Append(bytesData) + return true + } + if builder, ok := b.(*array.BinaryBuilder); ok { + builder.Append(bytesData) + return true } return false }, @@ -643,15 +654,19 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { b.AppendNull() return true } - if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok { - if vv, ok := v.([]float32); ok { - dim := len(vv) - byteLength := dim * 4 - bytesData := make([]byte, byteLength) - for i, vec := range vv { - bytes := math.Float32bits(vec) - common.Endian.PutUint32(bytesData[i*4:], bytes) - } + if vv, ok := v.([]float32); ok { + dim := len(vv) + byteLength := dim * 4 + bytesData := make([]byte, byteLength) + for i, vec := range vv { + bytes := math.Float32bits(vec) + common.Endian.PutUint32(bytesData[i*4:], bytes) + } + if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok { + builder.Append(bytesData) + return true + } + if builder, ok := b.(*array.BinaryBuilder); ok { builder.Append(bytesData) return true } @@ -987,7 +1002,16 @@ func newSingleFieldRecordWriter(field *schemapb.FieldSchema, writer io.Writer, o []string{fmt.Sprintf("%d", int32(elementType)), fmt.Sprintf("%d", dim)}, ) } - arrowType = serdeMap[field.DataType].arrowType(int(dim), elementType) + + if field.GetNullable() && typeutil.IsVectorType(field.DataType) && !typeutil.IsSparseFloatVectorType(field.DataType) { + arrowType = arrow.BinaryTypes.Binary + fieldMetadata = arrow.NewMetadata( + []string{"dim"}, + []string{fmt.Sprintf("%d", dim)}, + ) + } else { + arrowType = serdeMap[field.DataType].arrowType(int(dim), elementType) + } w := &singleFieldRecordWriter{ fieldId: field.FieldID, @@ -1199,10 +1223,40 @@ func BuildRecord(b *array.RecordBuilder, data *InsertData, schema *schemapb.Coll elementType = field.GetElementType() } - for j := 0; j < fieldData.RowNum(); j++ { - ok = typeEntry.serialize(fBuilder, fieldData.GetRow(j), elementType) - if !ok { - return merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", field.DataType.String())) + if field.GetNullable() && typeutil.IsVectorType(field.DataType) { + var validData []bool + switch fd := fieldData.(type) { + case *FloatVectorFieldData: + validData = fd.ValidData + case *BinaryVectorFieldData: + validData = fd.ValidData + case *Float16VectorFieldData: + validData = fd.ValidData + case *BFloat16VectorFieldData: + validData = fd.ValidData + case *SparseFloatVectorFieldData: + validData = fd.ValidData + case *Int8VectorFieldData: + validData = fd.ValidData + } + // Use len(validData) as logical row count, GetRow takes logical index + for j := 0; j < len(validData); j++ { + if !validData[j] { + fBuilder.(*array.BinaryBuilder).AppendNull() + } else { + rowData := fieldData.GetRow(j) + ok = typeEntry.serialize(fBuilder, rowData, elementType) + if !ok { + return merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", field.DataType.String())) + } + } + } + } else { + for j := 0; j < fieldData.RowNum(); j++ { + ok = typeEntry.serialize(fBuilder, fieldData.GetRow(j), elementType) + if !ok { + return merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", field.DataType.String())) + } } } return nil diff --git a/internal/storage/serde_test.go b/internal/storage/serde_test.go index 718b8e77bc..a6973e11a8 100644 --- a/internal/storage/serde_test.go +++ b/internal/storage/serde_test.go @@ -103,6 +103,9 @@ func TestSerDe(t *testing.T) { {"test bfloat16 vector null", args{dt: schemapb.DataType_BFloat16Vector, v: nil}, nil, true}, {"test bfloat16 vector negative", args{dt: schemapb.DataType_BFloat16Vector, v: -1}, nil, false}, {"test int8 vector", args{dt: schemapb.DataType_Int8Vector, v: []int8{10}}, []int8{10}, true}, + {"test sparse float vector", args{dt: schemapb.DataType_SparseFloatVector, v: []byte{1, 2, 3, 4}}, []byte{1, 2, 3, 4}, true}, + {"test sparse float vector null", args{dt: schemapb.DataType_SparseFloatVector, v: nil}, nil, true}, + {"test sparse float vector negative", args{dt: schemapb.DataType_SparseFloatVector, v: -1}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 303fc7d772..d8343da31e 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -568,10 +568,17 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche } srcData := srcField.GetVectors().GetFloatVector().GetData() - fieldData = &FloatVectorFieldData{ - Data: srcData, - Dim: dim, + validData := srcField.GetValidData() + fd := &FloatVectorFieldData{ + Data: srcData, + Dim: dim, + ValidData: validData, + Nullable: field.GetNullable(), } + if len(validData) > 0 { + fd.L2PMapping.Build(validData, 0, len(validData)) + } + fieldData = fd case schemapb.DataType_BinaryVector: dim, err := GetDimFromParams(field.TypeParams) @@ -581,11 +588,17 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche } srcData := srcField.GetVectors().GetBinaryVector() - - fieldData = &BinaryVectorFieldData{ - Data: srcData, - Dim: dim, + validData := srcField.GetValidData() + fd := &BinaryVectorFieldData{ + Data: srcData, + Dim: dim, + ValidData: validData, + Nullable: field.GetNullable(), } + if len(validData) > 0 { + fd.L2PMapping.Build(validData, 0, len(validData)) + } + fieldData = fd case schemapb.DataType_Float16Vector: dim, err := GetDimFromParams(field.TypeParams) @@ -595,11 +608,17 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche } srcData := srcField.GetVectors().GetFloat16Vector() - - fieldData = &Float16VectorFieldData{ - Data: srcData, - Dim: dim, + validData := srcField.GetValidData() + fd := &Float16VectorFieldData{ + Data: srcData, + Dim: dim, + ValidData: validData, + Nullable: field.GetNullable(), } + if len(validData) > 0 { + fd.L2PMapping.Build(validData, 0, len(validData)) + } + fieldData = fd case schemapb.DataType_BFloat16Vector: dim, err := GetDimFromParams(field.TypeParams) @@ -609,16 +628,39 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche } srcData := srcField.GetVectors().GetBfloat16Vector() - - fieldData = &BFloat16VectorFieldData{ - Data: srcData, - Dim: dim, + validData := srcField.GetValidData() + fd := &BFloat16VectorFieldData{ + Data: srcData, + Dim: dim, + ValidData: validData, + Nullable: field.GetNullable(), } + if len(validData) > 0 { + fd.L2PMapping.Build(validData, 0, len(validData)) + } + fieldData = fd case schemapb.DataType_SparseFloatVector: - fieldData = &SparseFloatVectorFieldData{ - SparseFloatArray: *srcFields[field.FieldID].GetVectors().GetSparseFloatVector(), + sparseArray := srcFields[field.FieldID].GetVectors().GetSparseFloatVector() + validData := srcField.GetValidData() + var contents [][]byte + var dim int64 + if sparseArray != nil { + contents = sparseArray.GetContents() + dim = sparseArray.GetDim() } + fd := &SparseFloatVectorFieldData{ + SparseFloatArray: schemapb.SparseFloatArray{ + Contents: contents, + Dim: dim, + }, + ValidData: validData, + Nullable: field.GetNullable(), + } + if len(validData) > 0 { + fd.L2PMapping.Build(validData, 0, len(validData)) + } + fieldData = fd case schemapb.DataType_Int8Vector: dim, err := GetDimFromParams(field.TypeParams) @@ -628,10 +670,17 @@ func ColumnBasedInsertMsgToInsertData(msg *msgstream.InsertMsg, collSchema *sche } srcData := srcField.GetVectors().GetInt8Vector() - fieldData = &Int8VectorFieldData{ - Data: lo.Map(srcData, func(v byte, _ int) int8 { return int8(v) }), - Dim: dim, + validData := srcField.GetValidData() + fd := &Int8VectorFieldData{ + Data: lo.Map(srcData, func(v byte, _ int) int8 { return int8(v) }), + Dim: dim, + ValidData: validData, + Nullable: field.GetNullable(), } + if len(validData) > 0 { + fd.L2PMapping.Build(validData, 0, len(validData)) + } + fieldData = fd case schemapb.DataType_Bool: srcData := srcField.GetScalars().GetBoolData().GetData() @@ -987,54 +1036,80 @@ func mergeJSONField(data *InsertData, fid FieldID, field *JSONFieldData) { func mergeBinaryVectorField(data *InsertData, fid FieldID, field *BinaryVectorFieldData) { if _, ok := data.Data[fid]; !ok { fieldData := &BinaryVectorFieldData{ - Data: nil, - Dim: field.Dim, + Data: nil, + Dim: field.Dim, + ValidData: nil, + Nullable: field.Nullable, } data.Data[fid] = fieldData } fieldData := data.Data[fid].(*BinaryVectorFieldData) fieldData.Data = append(fieldData.Data, field.Data...) + if len(field.ValidData) > 0 { + fieldData.L2PMapping.Build(field.ValidData, len(fieldData.ValidData), len(field.ValidData)) + } + fieldData.ValidData = append(fieldData.ValidData, field.ValidData...) } func mergeFloatVectorField(data *InsertData, fid FieldID, field *FloatVectorFieldData) { if _, ok := data.Data[fid]; !ok { fieldData := &FloatVectorFieldData{ - Data: nil, - Dim: field.Dim, + Data: nil, + Dim: field.Dim, + ValidData: nil, + Nullable: field.Nullable, } data.Data[fid] = fieldData } fieldData := data.Data[fid].(*FloatVectorFieldData) fieldData.Data = append(fieldData.Data, field.Data...) + if len(field.ValidData) > 0 { + fieldData.L2PMapping.Build(field.ValidData, len(fieldData.ValidData), len(field.ValidData)) + } + fieldData.ValidData = append(fieldData.ValidData, field.ValidData...) } func mergeFloat16VectorField(data *InsertData, fid FieldID, field *Float16VectorFieldData) { if _, ok := data.Data[fid]; !ok { fieldData := &Float16VectorFieldData{ - Data: nil, - Dim: field.Dim, + Data: nil, + Dim: field.Dim, + ValidData: nil, + Nullable: field.Nullable, } data.Data[fid] = fieldData } fieldData := data.Data[fid].(*Float16VectorFieldData) fieldData.Data = append(fieldData.Data, field.Data...) + if len(field.ValidData) > 0 { + fieldData.L2PMapping.Build(field.ValidData, len(fieldData.ValidData), len(field.ValidData)) + } + fieldData.ValidData = append(fieldData.ValidData, field.ValidData...) } func mergeBFloat16VectorField(data *InsertData, fid FieldID, field *BFloat16VectorFieldData) { if _, ok := data.Data[fid]; !ok { fieldData := &BFloat16VectorFieldData{ - Data: nil, - Dim: field.Dim, + Data: nil, + Dim: field.Dim, + ValidData: nil, + Nullable: field.Nullable, } data.Data[fid] = fieldData } fieldData := data.Data[fid].(*BFloat16VectorFieldData) fieldData.Data = append(fieldData.Data, field.Data...) + if len(field.ValidData) > 0 { + fieldData.L2PMapping.Build(field.ValidData, len(fieldData.ValidData), len(field.ValidData)) + } + fieldData.ValidData = append(fieldData.ValidData, field.ValidData...) } func mergeSparseFloatVectorField(data *InsertData, fid FieldID, field *SparseFloatVectorFieldData) { if _, ok := data.Data[fid]; !ok { - data.Data[fid] = &SparseFloatVectorFieldData{} + data.Data[fid] = &SparseFloatVectorFieldData{ + Nullable: field.Nullable, + } } fieldData := data.Data[fid].(*SparseFloatVectorFieldData) fieldData.AppendAllRows(field) @@ -1043,13 +1118,19 @@ func mergeSparseFloatVectorField(data *InsertData, fid FieldID, field *SparseFlo func mergeInt8VectorField(data *InsertData, fid FieldID, field *Int8VectorFieldData) { if _, ok := data.Data[fid]; !ok { fieldData := &Int8VectorFieldData{ - Data: nil, - Dim: field.Dim, + Data: nil, + Dim: field.Dim, + ValidData: nil, + Nullable: field.Nullable, } data.Data[fid] = fieldData } fieldData := data.Data[fid].(*Int8VectorFieldData) fieldData.Data = append(fieldData.Data, field.Data...) + if len(field.ValidData) > 0 { + fieldData.L2PMapping.Build(field.ValidData, len(fieldData.ValidData), len(field.ValidData)) + } + fieldData.ValidData = append(fieldData.ValidData, field.ValidData...) } // MergeFieldData merge field into data. @@ -1405,6 +1486,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert Dim: int64(rawData.Dim), }, }, + ValidData: rawData.ValidData, } case *BinaryVectorFieldData: fieldData = &schemapb.FieldData{ @@ -1418,6 +1500,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert Dim: int64(rawData.Dim), }, }, + ValidData: rawData.ValidData, } case *Float16VectorFieldData: fieldData = &schemapb.FieldData{ @@ -1431,6 +1514,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert Dim: int64(rawData.Dim), }, }, + ValidData: rawData.ValidData, } case *BFloat16VectorFieldData: fieldData = &schemapb.FieldData{ @@ -1444,6 +1528,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert Dim: int64(rawData.Dim), }, }, + ValidData: rawData.ValidData, } case *SparseFloatVectorFieldData: fieldData = &schemapb.FieldData{ @@ -1456,6 +1541,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert }, }, }, + ValidData: rawData.ValidData, } case *Int8VectorFieldData: dataBytes := arrow.Int8Traits.CastToBytes(rawData.Data) @@ -1470,6 +1556,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert Dim: int64(rawData.Dim), }, }, + ValidData: rawData.ValidData, } case *VectorArrayFieldData: fieldData = &schemapb.FieldData{ diff --git a/internal/util/importutilv2/binlog/reader_test.go b/internal/util/importutilv2/binlog/reader_test.go index dde872823f..4355076bf4 100644 --- a/internal/util/importutilv2/binlog/reader_test.go +++ b/internal/util/importutilv2/binlog/reader_test.go @@ -168,19 +168,19 @@ func createBinlogBuf(t *testing.T, field *schemapb.FieldSchema, data storage.Fie } case schemapb.DataType_BinaryVector: vectors := data.(*storage.BinaryVectorFieldData).Data - err = evt.AddBinaryVectorToPayload(vectors, int(dim)) + err = evt.AddBinaryVectorToPayload(vectors, int(dim), nil) assert.NoError(t, err) case schemapb.DataType_FloatVector: vectors := data.(*storage.FloatVectorFieldData).Data - err = evt.AddFloatVectorToPayload(vectors, int(dim)) + err = evt.AddFloatVectorToPayload(vectors, int(dim), nil) assert.NoError(t, err) case schemapb.DataType_Float16Vector: vectors := data.(*storage.Float16VectorFieldData).Data - err = evt.AddFloat16VectorToPayload(vectors, int(dim)) + err = evt.AddFloat16VectorToPayload(vectors, int(dim), nil) assert.NoError(t, err) case schemapb.DataType_BFloat16Vector: vectors := data.(*storage.BFloat16VectorFieldData).Data - err = evt.AddBFloat16VectorToPayload(vectors, int(dim)) + err = evt.AddBFloat16VectorToPayload(vectors, int(dim), nil) assert.NoError(t, err) case schemapb.DataType_SparseFloatVector: vectors := data.(*storage.SparseFloatVectorFieldData) @@ -188,7 +188,7 @@ func createBinlogBuf(t *testing.T, field *schemapb.FieldSchema, data storage.Fie assert.NoError(t, err) case schemapb.DataType_Int8Vector: vectors := data.(*storage.Int8VectorFieldData).Data - err = evt.AddInt8VectorToPayload(vectors, int(dim)) + err = evt.AddInt8VectorToPayload(vectors, int(dim), nil) assert.NoError(t, err) case schemapb.DataType_ArrayOfVector: elementType := field.GetElementType() diff --git a/internal/util/indexcgowrapper/dataset.go b/internal/util/indexcgowrapper/dataset.go index dc2c0ace9c..a948cc6e46 100644 --- a/internal/util/indexcgowrapper/dataset.go +++ b/internal/util/indexcgowrapper/dataset.go @@ -6,7 +6,8 @@ import ( ) const ( - keyRawArr = "key_raw_arr" + keyRawArr = "key_raw_arr" + keyValidArr = "key_valid_arr" ) type Dataset struct { @@ -48,6 +49,7 @@ func GenSparseFloatVecDataset(data *storage.SparseFloatVectorFieldData) *Dataset // wrapper. Such tests are skipping sparse vector for now. return &Dataset{ DType: schemapb.DataType_SparseFloatVector, + Data: make(map[string]interface{}), } } @@ -72,73 +74,154 @@ func GenInt8VecDataset(vectors []int8) *Dataset { func GenDataset(data storage.FieldData) *Dataset { switch f := data.(type) { case *storage.BoolFieldData: - return &Dataset{ + ds := &Dataset{ DType: schemapb.DataType_Bool, Data: map[string]interface{}{ keyRawArr: f.Data, }, } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.Int8FieldData: - return &Dataset{ + ds := &Dataset{ DType: schemapb.DataType_Int8, Data: map[string]interface{}{ keyRawArr: f.Data, }, } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.Int16FieldData: - return &Dataset{ + ds := &Dataset{ DType: schemapb.DataType_Int16, Data: map[string]interface{}{ keyRawArr: f.Data, }, } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.Int32FieldData: - return &Dataset{ + ds := &Dataset{ DType: schemapb.DataType_Int32, Data: map[string]interface{}{ keyRawArr: f.Data, }, } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.Int64FieldData: - return &Dataset{ + ds := &Dataset{ DType: schemapb.DataType_Int64, Data: map[string]interface{}{ keyRawArr: f.Data, }, } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.FloatFieldData: - return &Dataset{ + ds := &Dataset{ DType: schemapb.DataType_Float, Data: map[string]interface{}{ keyRawArr: f.Data, }, } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.DoubleFieldData: - return &Dataset{ + ds := &Dataset{ DType: schemapb.DataType_Double, Data: map[string]interface{}{ keyRawArr: f.Data, }, } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.StringFieldData: - return &Dataset{ + ds := &Dataset{ DType: schemapb.DataType_VarChar, Data: map[string]interface{}{ keyRawArr: f.Data, }, } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.BinaryVectorFieldData: - return GenBinaryVecDataset(f.Data) + ds := &Dataset{ + DType: schemapb.DataType_BinaryVector, + Data: map[string]interface{}{ + keyRawArr: f.Data, + }, + } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.FloatVectorFieldData: - return GenFloatVecDataset(f.Data) + ds := &Dataset{ + DType: schemapb.DataType_FloatVector, + Data: map[string]interface{}{ + keyRawArr: f.Data, + }, + } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.Float16VectorFieldData: - return GenFloat16VecDataset(f.Data) + ds := &Dataset{ + DType: schemapb.DataType_Float16Vector, + Data: map[string]interface{}{ + keyRawArr: f.Data, + }, + } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.BFloat16VectorFieldData: - return GenBFloat16VecDataset(f.Data) + ds := &Dataset{ + DType: schemapb.DataType_BFloat16Vector, + Data: map[string]interface{}{ + keyRawArr: f.Data, + }, + } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.SparseFloatVectorFieldData: - return GenSparseFloatVecDataset(f) + ds := GenSparseFloatVecDataset(f) + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds case *storage.Int8VectorFieldData: - return GenInt8VecDataset(f.Data) + ds := &Dataset{ + DType: schemapb.DataType_Int8Vector, + Data: map[string]interface{}{ + keyRawArr: f.Data, + }, + } + if f.Nullable && len(f.ValidData) > 0 { + ds.Data[keyValidArr] = f.ValidData + } + return ds default: return &Dataset{ DType: schemapb.DataType_None, diff --git a/internal/util/indexcgowrapper/index.go b/internal/util/indexcgowrapper/index.go index 03da27b346..edaffcb6a1 100644 --- a/internal/util/indexcgowrapper/index.go +++ b/internal/util/indexcgowrapper/index.go @@ -238,36 +238,91 @@ func (index *CgoIndex) Build(dataset *Dataset) error { func (index *CgoIndex) buildFloatVecIndex(dataset *Dataset) error { vectors := dataset.Data[keyRawArr].([]float32) + if validData, ok := dataset.Data[keyValidArr].([]bool); ok && len(validData) > 0 { + status := C.BuildFloatVecIndexWithValidData( + index.indexPtr, + (C.int64_t)(len(vectors)), + (*C.float)(&vectors[0]), + (*C.bool)(&validData[0]), + (C.int64_t)(len(validData))) + return HandleCStatus(&status, "failed to build float vector index with valid data") + } status := C.BuildFloatVecIndex(index.indexPtr, (C.int64_t)(len(vectors)), (*C.float)(&vectors[0])) return HandleCStatus(&status, "failed to build float vector index") } func (index *CgoIndex) buildFloat16VecIndex(dataset *Dataset) error { vectors := dataset.Data[keyRawArr].([]byte) + if validData, ok := dataset.Data[keyValidArr].([]bool); ok && len(validData) > 0 { + status := C.BuildFloat16VecIndexWithValidData( + index.indexPtr, + (C.int64_t)(len(vectors)), + (*C.uint8_t)(&vectors[0]), + (*C.bool)(&validData[0]), + (C.int64_t)(len(validData))) + return HandleCStatus(&status, "failed to build float16 vector index with valid data") + } status := C.BuildFloat16VecIndex(index.indexPtr, (C.int64_t)(len(vectors)), (*C.uint8_t)(&vectors[0])) return HandleCStatus(&status, "failed to build float16 vector index") } func (index *CgoIndex) buildBFloat16VecIndex(dataset *Dataset) error { vectors := dataset.Data[keyRawArr].([]byte) + if validData, ok := dataset.Data[keyValidArr].([]bool); ok && len(validData) > 0 { + status := C.BuildBFloat16VecIndexWithValidData( + index.indexPtr, + (C.int64_t)(len(vectors)), + (*C.uint8_t)(&vectors[0]), + (*C.bool)(&validData[0]), + (C.int64_t)(len(validData))) + return HandleCStatus(&status, "failed to build bfloat16 vector index with valid data") + } status := C.BuildBFloat16VecIndex(index.indexPtr, (C.int64_t)(len(vectors)), (*C.uint8_t)(&vectors[0])) return HandleCStatus(&status, "failed to build bfloat16 vector index") } func (index *CgoIndex) buildSparseFloatVecIndex(dataset *Dataset) error { vectors := dataset.Data[keyRawArr].([]byte) + if validData, ok := dataset.Data[keyValidArr].([]bool); ok && len(validData) > 0 { + status := C.BuildSparseFloatVecIndexWithValidData( + index.indexPtr, + (C.int64_t)(len(validData)), + (C.int64_t)(0), + (*C.uint8_t)(&vectors[0]), + (*C.bool)(&validData[0]), + (C.int64_t)(len(validData))) + return HandleCStatus(&status, "failed to build sparse float vector index with valid data") + } status := C.BuildSparseFloatVecIndex(index.indexPtr, (C.int64_t)(len(vectors)), (C.int64_t)(0), (*C.uint8_t)(&vectors[0])) return HandleCStatus(&status, "failed to build sparse float vector index") } func (index *CgoIndex) buildBinaryVecIndex(dataset *Dataset) error { vectors := dataset.Data[keyRawArr].([]byte) + if validData, ok := dataset.Data[keyValidArr].([]bool); ok && len(validData) > 0 { + status := C.BuildBinaryVecIndexWithValidData( + index.indexPtr, + (C.int64_t)(len(vectors)), + (*C.uint8_t)(&vectors[0]), + (*C.bool)(&validData[0]), + (C.int64_t)(len(validData))) + return HandleCStatus(&status, "failed to build binary vector index with valid data") + } status := C.BuildBinaryVecIndex(index.indexPtr, (C.int64_t)(len(vectors)), (*C.uint8_t)(&vectors[0])) return HandleCStatus(&status, "failed to build binary vector index") } func (index *CgoIndex) buildInt8VecIndex(dataset *Dataset) error { vectors := dataset.Data[keyRawArr].([]int8) + if validData, ok := dataset.Data[keyValidArr].([]bool); ok && len(validData) > 0 { + status := C.BuildInt8VecIndexWithValidData( + index.indexPtr, + (C.int64_t)(len(vectors)), + (*C.int8_t)(&vectors[0]), + (*C.bool)(&validData[0]), + (C.int64_t)(len(validData))) + return HandleCStatus(&status, "failed to build int8 vector index with valid data") + } status := C.BuildInt8VecIndex(index.indexPtr, (C.int64_t)(len(vectors)), (*C.int8_t)(&vectors[0])) return HandleCStatus(&status, "failed to build int8 vector index") } diff --git a/pkg/mq/msgstream/msg.go b/pkg/mq/msgstream/msg.go index c53228d7b7..16d3a2198e 100644 --- a/pkg/mq/msgstream/msg.go +++ b/pkg/mq/msgstream/msg.go @@ -283,9 +283,12 @@ func (it *InsertMsg) rowBasedIndexRequest(index int) *msgpb.InsertRequest { } func (it *InsertMsg) columnBasedIndexRequest(index int) *msgpb.InsertRequest { - colNum := len(it.GetFieldsData()) - fieldsData := make([]*schemapb.FieldData, colNum) - typeutil.AppendFieldData(fieldsData, it.GetFieldsData(), int64(index)) + srcFieldsData := it.GetFieldsData() + fieldsData := make([]*schemapb.FieldData, len(srcFieldsData)) + idxComputer := typeutil.NewFieldDataIdxComputer(srcFieldsData) + vectorIdx := idxComputer.Compute(int64(index)) + + typeutil.AppendFieldData(fieldsData, srcFieldsData, int64(index), vectorIdx...) return &msgpb.InsertRequest{ Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_Insert), diff --git a/pkg/util/funcutil/func.go b/pkg/util/funcutil/func.go index fde131e4aa..27168f0436 100644 --- a/pkg/util/funcutil/func.go +++ b/pkg/util/funcutil/func.go @@ -453,39 +453,67 @@ func GetNumRowOfFieldDataWithSchema(fieldData *schemapb.FieldData, helper *typeu fieldNumRows = getNumRowsOfScalarField(fieldData.GetScalars().GetGeometryWktData().GetData()) } case schemapb.DataType_FloatVector: - dim := fieldData.GetVectors().GetDim() - fieldNumRows, err = GetNumRowsOfFloatVectorField(fieldData.GetVectors().GetFloatVector().GetData(), dim) - if err != nil { - return 0, err + if len(fieldData.GetValidData()) > 0 { + fieldNumRows = uint64(len(fieldData.GetValidData())) + } else { + dim := fieldData.GetVectors().GetDim() + fieldNumRows, err = GetNumRowsOfFloatVectorField(fieldData.GetVectors().GetFloatVector().GetData(), dim) + if err != nil { + return 0, err + } } case schemapb.DataType_BinaryVector: - dim := fieldData.GetVectors().GetDim() - fieldNumRows, err = GetNumRowsOfBinaryVectorField(fieldData.GetVectors().GetBinaryVector(), dim) - if err != nil { - return 0, err + if len(fieldData.GetValidData()) > 0 { + fieldNumRows = uint64(len(fieldData.GetValidData())) + } else { + dim := fieldData.GetVectors().GetDim() + fieldNumRows, err = GetNumRowsOfBinaryVectorField(fieldData.GetVectors().GetBinaryVector(), dim) + if err != nil { + return 0, err + } } case schemapb.DataType_Float16Vector: - dim := fieldData.GetVectors().GetDim() - fieldNumRows, err = GetNumRowsOfFloat16VectorField(fieldData.GetVectors().GetFloat16Vector(), dim) - if err != nil { - return 0, err + if len(fieldData.GetValidData()) > 0 { + fieldNumRows = uint64(len(fieldData.GetValidData())) + } else { + dim := fieldData.GetVectors().GetDim() + fieldNumRows, err = GetNumRowsOfFloat16VectorField(fieldData.GetVectors().GetFloat16Vector(), dim) + if err != nil { + return 0, err + } } case schemapb.DataType_BFloat16Vector: - dim := fieldData.GetVectors().GetDim() - fieldNumRows, err = GetNumRowsOfBFloat16VectorField(fieldData.GetVectors().GetBfloat16Vector(), dim) - if err != nil { - return 0, err + if len(fieldData.GetValidData()) > 0 { + fieldNumRows = uint64(len(fieldData.GetValidData())) + } else { + dim := fieldData.GetVectors().GetDim() + fieldNumRows, err = GetNumRowsOfBFloat16VectorField(fieldData.GetVectors().GetBfloat16Vector(), dim) + if err != nil { + return 0, err + } } case schemapb.DataType_SparseFloatVector: - fieldNumRows = uint64(len(fieldData.GetVectors().GetSparseFloatVector().GetContents())) + if len(fieldData.GetValidData()) > 0 { + fieldNumRows = uint64(len(fieldData.GetValidData())) + } else { + fieldNumRows = uint64(len(fieldData.GetVectors().GetSparseFloatVector().GetContents())) + } case schemapb.DataType_Int8Vector: - dim := fieldData.GetVectors().GetDim() - fieldNumRows, err = GetNumRowsOfInt8VectorField(fieldData.GetVectors().GetInt8Vector(), dim) - if err != nil { - return 0, err + if len(fieldData.GetValidData()) > 0 { + fieldNumRows = uint64(len(fieldData.GetValidData())) + } else { + dim := fieldData.GetVectors().GetDim() + fieldNumRows, err = GetNumRowsOfInt8VectorField(fieldData.GetVectors().GetInt8Vector(), dim) + if err != nil { + return 0, err + } } case schemapb.DataType_ArrayOfVector: - fieldNumRows = getNumRowsOfArrayVectorField(fieldData.GetVectors().GetVectorArray().GetData()) + if len(fieldData.GetValidData()) > 0 { + fieldNumRows = uint64(len(fieldData.GetValidData())) + } else { + fieldNumRows = getNumRowsOfArrayVectorField(fieldData.GetVectors().GetVectorArray().GetData()) + } default: return 0, fmt.Errorf("%s is not supported now", fieldSchema.GetDataType()) } @@ -525,6 +553,9 @@ func GetNumRowOfFieldData(fieldData *schemapb.FieldData) (uint64, error) { return 0, fmt.Errorf("%s is not supported now", scalarType) } case *schemapb.FieldData_Vectors: + if len(fieldData.GetValidData()) > 0 { + return uint64(len(fieldData.GetValidData())), nil + } vectorField := fieldData.GetVectors() switch vectorFieldType := vectorField.Data.(type) { case *schemapb.VectorField_FloatVector: diff --git a/pkg/util/testutils/gen_data.go b/pkg/util/testutils/gen_data.go index 946462eeb5..c8fd47b9ec 100644 --- a/pkg/util/testutils/gen_data.go +++ b/pkg/util/testutils/gen_data.go @@ -367,6 +367,21 @@ func GenerateInt8Vectors(numRows, dim int) []int8 { return ret } +func GenerateFloatVectorsWithInvalidData(numRows, dim int) []float32 { + total := numRows * dim + ret := make([]float32, 0, total) + for i := 0; i < total; i++ { + var f float32 + if i%2 == 0 { + f = float32(math.NaN()) + } else { + f = float32(math.Inf(1)) + } + ret = append(ret, f) + } + return ret +} + func GenerateBFloat16VectorsWithInvalidData(numRows, dim int) []byte { total := numRows * dim ret16 := make([]uint16, 0, total) diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 0e77e37e48..bab34c9895 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -268,9 +268,13 @@ func calcVectorSize(column *schemapb.VectorField, vectorType schemapb.DataType) return res } -func EstimateEntitySize(fieldsData []*schemapb.FieldData, rowOffset int) (int, error) { +func EstimateEntitySize(fieldsData []*schemapb.FieldData, rowOffset int, fieldIdxs ...int64) (int, error) { res := 0 - for _, fs := range fieldsData { + for i, fs := range fieldsData { + fieldIdx := int64(rowOffset) + if i < len(fieldIdxs) { + fieldIdx = fieldIdxs[i] + } switch fs.GetType() { case schemapb.DataType_Bool, schemapb.DataType_Int8: res++ @@ -304,27 +308,40 @@ func EstimateEntitySize(fieldsData []*schemapb.FieldData, rowOffset int) (int, e return 0, fmt.Errorf("offset out range of field datas") } res += len(fs.GetScalars().GetGeometryData().GetData()[rowOffset]) - case schemapb.DataType_BinaryVector: - res += int(fs.GetVectors().GetDim()) - case schemapb.DataType_FloatVector: - res += int(fs.GetVectors().GetDim() * 4) - case schemapb.DataType_Float16Vector: - res += int(fs.GetVectors().GetDim() * 2) - case schemapb.DataType_BFloat16Vector: - res += int(fs.GetVectors().GetDim() * 2) - case schemapb.DataType_SparseFloatVector: - vec := fs.GetVectors().GetSparseFloatVector() - // counting only the size of the vector data, ignoring other - // bytes used in proto. - res += len(vec.Contents[rowOffset]) - case schemapb.DataType_Int8Vector: - res += int(fs.GetVectors().GetDim()) + case schemapb.DataType_BinaryVector, + schemapb.DataType_FloatVector, + schemapb.DataType_Float16Vector, + schemapb.DataType_BFloat16Vector, + schemapb.DataType_Int8Vector, + schemapb.DataType_SparseFloatVector: + validData := fs.GetValidData() + isNullRow := len(validData) > 0 && rowOffset < len(validData) && !validData[rowOffset] + if isNullRow { + continue + } + switch fs.GetType() { + case schemapb.DataType_BinaryVector: + res += int(fs.GetVectors().GetDim() / 8) + case schemapb.DataType_FloatVector: + res += int(fs.GetVectors().GetDim() * 4) + case schemapb.DataType_Float16Vector: + res += int(fs.GetVectors().GetDim() * 2) + case schemapb.DataType_BFloat16Vector: + res += int(fs.GetVectors().GetDim() * 2) + case schemapb.DataType_SparseFloatVector: + vec := fs.GetVectors().GetSparseFloatVector() + if int(fieldIdx) < len(vec.Contents) { + res += len(vec.Contents[fieldIdx]) + } + case schemapb.DataType_Int8Vector: + res += int(fs.GetVectors().GetDim()) + } case schemapb.DataType_ArrayOfVector: arrayVector := fs.GetVectors().GetVectorArray() - if rowOffset >= len(arrayVector.GetData()) { + if int(fieldIdx) >= len(arrayVector.GetData()) { return 0, errors.New("offset out range of field datas") } - res += calcVectorSize(arrayVector.GetData()[rowOffset], arrayVector.GetElementType()) + res += calcVectorSize(arrayVector.GetData()[fieldIdx], arrayVector.GetElementType()) default: panic("Unknown data type:" + fs.GetType().String()) } @@ -820,7 +837,56 @@ func PrepareResultFieldData(sample []*schemapb.FieldData, topK int64) []*schemap return result } -func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int64) { +type FieldDataIdxComputer struct { + fieldsData []*schemapb.FieldData + lastRowIdx int64 + dataIndices []int64 + isVector []bool + resultBuffer []int64 +} + +func NewFieldDataIdxComputer(fieldsData []*schemapb.FieldData) *FieldDataIdxComputer { + c := &FieldDataIdxComputer{ + fieldsData: fieldsData, + lastRowIdx: 0, + dataIndices: make([]int64, len(fieldsData)), + isVector: make([]bool, len(fieldsData)), + resultBuffer: make([]int64, len(fieldsData)), + } + for i, fieldData := range fieldsData { + validData := fieldData.GetValidData() + c.isVector[i] = len(validData) > 0 && IsVectorType(fieldData.Type) + } + return c +} + +func (c *FieldDataIdxComputer) Compute(rowIdx int64) []int64 { + if rowIdx < c.lastRowIdx { + c.lastRowIdx = 0 + for i := range c.dataIndices { + c.dataIndices[i] = 0 + } + } + + for i, fieldData := range c.fieldsData { + if c.isVector[i] { + validData := fieldData.GetValidData() + for j := c.lastRowIdx; j < rowIdx && j < int64(len(validData)); j++ { + if validData[j] { + c.dataIndices[i]++ + } + } + c.resultBuffer[i] = c.dataIndices[i] + } else { + c.resultBuffer[i] = rowIdx + } + } + + c.lastRowIdx = rowIdx + return c.resultBuffer +} + +func AppendFieldData(dst, src []*schemapb.FieldData, idx int64, fieldIdxs ...int64) (appendSize int64) { dstMap := make(map[int64]*schemapb.FieldData) for _, fieldData := range dst { if fieldData != nil { @@ -828,6 +894,10 @@ func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int6 } } for i, fieldData := range src { + fieldIdx := idx + if i < len(fieldIdxs) { + fieldIdx = fieldIdxs[i] + } dstFieldData, ok := dstMap[fieldData.FieldId] if !ok { dstFieldData = &schemapb.FieldData{ @@ -997,96 +1067,112 @@ func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int6 } } dstVector := dstFieldData.GetVectors() + isNullRow := len(fieldData.GetValidData()) > 0 && !fieldData.GetValidData()[idx] + switch srcVector := fieldType.Vectors.Data.(type) { case *schemapb.VectorField_BinaryVector: - if dstVector.GetBinaryVector() == nil { - srcToCopy := srcVector.BinaryVector[idx*(dim/8) : (idx+1)*(dim/8)] - dstVector.Data = &schemapb.VectorField_BinaryVector{ - BinaryVector: make([]byte, len(srcToCopy)), + if !isNullRow { + if dstVector.GetBinaryVector() == nil { + srcToCopy := srcVector.BinaryVector[fieldIdx*(dim/8) : (fieldIdx+1)*(dim/8)] + dstVector.Data = &schemapb.VectorField_BinaryVector{ + BinaryVector: make([]byte, len(srcToCopy)), + } + copy(dstVector.Data.(*schemapb.VectorField_BinaryVector).BinaryVector, srcToCopy) + } else { + dstBinaryVector := dstVector.Data.(*schemapb.VectorField_BinaryVector) + dstBinaryVector.BinaryVector = append(dstBinaryVector.BinaryVector, srcVector.BinaryVector[fieldIdx*(dim/8):(fieldIdx+1)*(dim/8)]...) } - copy(dstVector.Data.(*schemapb.VectorField_BinaryVector).BinaryVector, srcToCopy) - } else { - dstBinaryVector := dstVector.Data.(*schemapb.VectorField_BinaryVector) - dstBinaryVector.BinaryVector = append(dstBinaryVector.BinaryVector, srcVector.BinaryVector[idx*(dim/8):(idx+1)*(dim/8)]...) + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcVector.BinaryVector[fieldIdx*(dim/8) : (fieldIdx+1)*(dim/8)])) } - /* #nosec G103 */ - appendSize += int64(unsafe.Sizeof(srcVector.BinaryVector[idx*(dim/8) : (idx+1)*(dim/8)])) case *schemapb.VectorField_FloatVector: - if dstVector.GetFloatVector() == nil { - srcToCopy := srcVector.FloatVector.Data[idx*dim : (idx+1)*dim] - dstVector.Data = &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: make([]float32, len(srcToCopy)), - }, + if !isNullRow { + if dstVector.GetFloatVector() == nil { + srcToCopy := srcVector.FloatVector.Data[fieldIdx*dim : (fieldIdx+1)*dim] + dstVector.Data = &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: make([]float32, len(srcToCopy)), + }, + } + copy(dstVector.Data.(*schemapb.VectorField_FloatVector).FloatVector.Data, srcToCopy) + } else { + dstVector.GetFloatVector().Data = append(dstVector.GetFloatVector().Data, srcVector.FloatVector.Data[fieldIdx*dim:(fieldIdx+1)*dim]...) } - copy(dstVector.Data.(*schemapb.VectorField_FloatVector).FloatVector.Data, srcToCopy) - } else { - dstVector.GetFloatVector().Data = append(dstVector.GetFloatVector().Data, srcVector.FloatVector.Data[idx*dim:(idx+1)*dim]...) + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcVector.FloatVector.Data[fieldIdx*dim : (fieldIdx+1)*dim])) } - /* #nosec G103 */ - appendSize += int64(unsafe.Sizeof(srcVector.FloatVector.Data[idx*dim : (idx+1)*dim])) case *schemapb.VectorField_Float16Vector: - if dstVector.GetFloat16Vector() == nil { - srcToCopy := srcVector.Float16Vector[idx*(dim*2) : (idx+1)*(dim*2)] - dstVector.Data = &schemapb.VectorField_Float16Vector{ - Float16Vector: make([]byte, len(srcToCopy)), + if !isNullRow { + if dstVector.GetFloat16Vector() == nil { + srcToCopy := srcVector.Float16Vector[fieldIdx*(dim*2) : (fieldIdx+1)*(dim*2)] + dstVector.Data = &schemapb.VectorField_Float16Vector{ + Float16Vector: make([]byte, len(srcToCopy)), + } + copy(dstVector.Data.(*schemapb.VectorField_Float16Vector).Float16Vector, srcToCopy) + } else { + dstFloat16Vector := dstVector.Data.(*schemapb.VectorField_Float16Vector) + dstFloat16Vector.Float16Vector = append(dstFloat16Vector.Float16Vector, srcVector.Float16Vector[fieldIdx*(dim*2):(fieldIdx+1)*(dim*2)]...) } - copy(dstVector.Data.(*schemapb.VectorField_Float16Vector).Float16Vector, srcToCopy) - } else { - dstFloat16Vector := dstVector.Data.(*schemapb.VectorField_Float16Vector) - dstFloat16Vector.Float16Vector = append(dstFloat16Vector.Float16Vector, srcVector.Float16Vector[idx*(dim*2):(idx+1)*(dim*2)]...) + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcVector.Float16Vector[fieldIdx*(dim*2) : (fieldIdx+1)*(dim*2)])) } - /* #nosec G103 */ - appendSize += int64(unsafe.Sizeof(srcVector.Float16Vector[idx*(dim*2) : (idx+1)*(dim*2)])) case *schemapb.VectorField_Bfloat16Vector: - if dstVector.GetBfloat16Vector() == nil { - srcToCopy := srcVector.Bfloat16Vector[idx*(dim*2) : (idx+1)*(dim*2)] - dstVector.Data = &schemapb.VectorField_Bfloat16Vector{ - Bfloat16Vector: make([]byte, len(srcToCopy)), + if !isNullRow { + if dstVector.GetBfloat16Vector() == nil { + srcToCopy := srcVector.Bfloat16Vector[fieldIdx*(dim*2) : (fieldIdx+1)*(dim*2)] + dstVector.Data = &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: make([]byte, len(srcToCopy)), + } + copy(dstVector.Data.(*schemapb.VectorField_Bfloat16Vector).Bfloat16Vector, srcToCopy) + } else { + dstBfloat16Vector := dstVector.Data.(*schemapb.VectorField_Bfloat16Vector) + dstBfloat16Vector.Bfloat16Vector = append(dstBfloat16Vector.Bfloat16Vector, srcVector.Bfloat16Vector[fieldIdx*(dim*2):(fieldIdx+1)*(dim*2)]...) } - copy(dstVector.Data.(*schemapb.VectorField_Bfloat16Vector).Bfloat16Vector, srcToCopy) - } else { - dstBfloat16Vector := dstVector.Data.(*schemapb.VectorField_Bfloat16Vector) - dstBfloat16Vector.Bfloat16Vector = append(dstBfloat16Vector.Bfloat16Vector, srcVector.Bfloat16Vector[idx*(dim*2):(idx+1)*(dim*2)]...) + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcVector.Bfloat16Vector[fieldIdx*(dim*2) : (fieldIdx+1)*(dim*2)])) } - /* #nosec G103 */ - appendSize += int64(unsafe.Sizeof(srcVector.Bfloat16Vector[idx*(dim*2) : (idx+1)*(dim*2)])) case *schemapb.VectorField_SparseFloatVector: - if dstVector.GetSparseFloatVector() == nil { - dstVector.Data = &schemapb.VectorField_SparseFloatVector{ - SparseFloatVector: &schemapb.SparseFloatArray{ - Dim: 0, - Contents: make([][]byte, 0), - }, + if !isNullRow { + if dstVector.GetSparseFloatVector() == nil { + dstVector.Data = &schemapb.VectorField_SparseFloatVector{ + SparseFloatVector: &schemapb.SparseFloatArray{ + Dim: 0, + Contents: make([][]byte, 0), + }, + } + dstVector.Dim = srcVector.SparseFloatVector.Dim } - dstVector.Dim = srcVector.SparseFloatVector.Dim + vec := dstVector.Data.(*schemapb.VectorField_SparseFloatVector).SparseFloatVector + appendSize += appendSparseFloatArraySingleRow(vec, srcVector.SparseFloatVector, fieldIdx) } - vec := dstVector.Data.(*schemapb.VectorField_SparseFloatVector).SparseFloatVector - appendSize += appendSparseFloatArraySingleRow(vec, srcVector.SparseFloatVector, idx) case *schemapb.VectorField_Int8Vector: - if dstVector.GetInt8Vector() == nil { - srcToCopy := srcVector.Int8Vector[idx*dim : (idx+1)*dim] - dstVector.Data = &schemapb.VectorField_Int8Vector{ - Int8Vector: make([]byte, len(srcToCopy)), + if !isNullRow { + if dstVector.GetInt8Vector() == nil { + srcToCopy := srcVector.Int8Vector[fieldIdx*dim : (fieldIdx+1)*dim] + dstVector.Data = &schemapb.VectorField_Int8Vector{ + Int8Vector: make([]byte, len(srcToCopy)), + } + copy(dstVector.Data.(*schemapb.VectorField_Int8Vector).Int8Vector, srcToCopy) + } else { + dstInt8Vector := dstVector.Data.(*schemapb.VectorField_Int8Vector) + dstInt8Vector.Int8Vector = append(dstInt8Vector.Int8Vector, srcVector.Int8Vector[fieldIdx*dim:(fieldIdx+1)*dim]...) } - copy(dstVector.Data.(*schemapb.VectorField_Int8Vector).Int8Vector, srcToCopy) - } else { - dstInt8Vector := dstVector.Data.(*schemapb.VectorField_Int8Vector) - dstInt8Vector.Int8Vector = append(dstInt8Vector.Int8Vector, srcVector.Int8Vector[idx*dim:(idx+1)*dim]...) + /* #nosec G103 */ + appendSize += int64(unsafe.Sizeof(srcVector.Int8Vector[fieldIdx*dim : (fieldIdx+1)*dim])) } - /* #nosec G103 */ - appendSize += int64(unsafe.Sizeof(srcVector.Int8Vector[idx*dim : (idx+1)*dim])) case *schemapb.VectorField_VectorArray: - if dstVector.GetVectorArray() == nil { - dstVector.Data = &schemapb.VectorField_VectorArray{ - VectorArray: &schemapb.VectorArray{ - Data: []*schemapb.VectorField{srcVector.VectorArray.Data[idx]}, - Dim: srcVector.VectorArray.Dim, - ElementType: srcVector.VectorArray.ElementType, - }, + if !isNullRow { + if dstVector.GetVectorArray() == nil { + dstVector.Data = &schemapb.VectorField_VectorArray{ + VectorArray: &schemapb.VectorArray{ + Data: []*schemapb.VectorField{srcVector.VectorArray.Data[fieldIdx]}, + Dim: srcVector.VectorArray.Dim, + ElementType: srcVector.VectorArray.ElementType, + }, + } + } else { + dstVector.GetVectorArray().Data = append(dstVector.GetVectorArray().Data, srcVector.VectorArray.Data[fieldIdx]) } - } else { - dstVector.GetVectorArray().Data = append(dstVector.GetVectorArray().Data, srcVector.VectorArray.Data[idx]) } } } @@ -1512,7 +1598,10 @@ func MergeFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData) error if _, ok := fieldID2Data[srcFieldData.FieldId]; !ok { return errors.New("fields in src but not in dst: " + srcFieldData.Type.String()) } - dstVector := fieldID2Data[srcFieldData.FieldId].GetVectors() + fieldData := fieldID2Data[srcFieldData.FieldId] + // Merge ValidData for nullable vectors + fieldData.ValidData = append(fieldData.ValidData, srcFieldData.GetValidData()...) + dstVector := fieldData.GetVectors() switch srcVector := fieldType.Vectors.Data.(type) { case *schemapb.VectorField_BinaryVector: if dstVector.GetBinaryVector() == nil { diff --git a/tests/go_client/base/milvus_client.go b/tests/go_client/base/milvus_client.go index 72df0ea812..d7903afc2f 100644 --- a/tests/go_client/base/milvus_client.go +++ b/tests/go_client/base/milvus_client.go @@ -9,6 +9,7 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" + "github.com/milvus-io/milvus/client/v2/entity" client "github.com/milvus-io/milvus/client/v2/milvusclient" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" @@ -81,3 +82,13 @@ func (mc *MilvusClient) Close(ctx context.Context) error { err := mc.Client.Close(ctx) return err } + +func (mc *MilvusClient) Compact(ctx context.Context, option client.CompactOption, callOptions ...grpc.CallOption) (int64, error) { + compactID, err := mc.Client.Compact(ctx, option, callOptions...) + return compactID, err +} + +func (mc *MilvusClient) GetCompactionState(ctx context.Context, option client.GetCompactionStateOption, callOptions ...grpc.CallOption) (entity.CompactionState, error) { + state, err := mc.Client.GetCompactionState(ctx, option, callOptions...) + return state, err +} diff --git a/tests/go_client/testcases/add_field_test.go b/tests/go_client/testcases/add_field_test.go index efd429b1cc..094aa2837e 100644 --- a/tests/go_client/testcases/add_field_test.go +++ b/tests/go_client/testcases/add_field_test.go @@ -96,14 +96,14 @@ func TestAddCollectionFieldInvalid(t *testing.T) { expectedError: "type param(max_length) should be specified for the field(" + common.DefaultNewField + ") of collection", }, { - name: "addVectorField", + name: "addVectorFieldWithoutNullable", setupCollection: func(collName string) error { return mc.CreateCollection(ctx, client.SimpleCreateCollectionOptions(collName, common.DefaultDim)) }, fieldBuilder: func() *entity.Field { - return entity.NewField().WithName(common.DefaultNewField).WithDataType(entity.FieldTypeFloatVector).WithNullable(true) + return entity.NewField().WithName(common.DefaultNewField).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim) }, - expectedError: "not support to add vector field, field name = " + common.DefaultNewField + ": invalid parameter", + expectedError: "adding vector field to existing collection requires nullable=true, field name = " + common.DefaultNewField, }, { name: "addFieldAsPrimary", diff --git a/tests/go_client/testcases/nullable_default_value_test.go b/tests/go_client/testcases/nullable_default_value_test.go index cdc6ecefa1..01e53f0334 100644 --- a/tests/go_client/testcases/nullable_default_value_test.go +++ b/tests/go_client/testcases/nullable_default_value_test.go @@ -3,6 +3,7 @@ package testcases import ( "fmt" "math" + "strings" "testing" "time" @@ -18,6 +19,356 @@ import ( hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper" ) +func int64SliceToString(ids []int64) string { + strs := make([]string, len(ids)) + for i, id := range ids { + strs[i] = fmt.Sprintf("%d", id) + } + return strings.Join(strs, ", ") +} + +type NullableVectorType struct { + Name string + FieldType entity.FieldType +} + +func GetVectorTypes() []NullableVectorType { + return []NullableVectorType{ + {"FloatVector", entity.FieldTypeFloatVector}, + {"BinaryVector", entity.FieldTypeBinaryVector}, + {"Float16Vector", entity.FieldTypeFloat16Vector}, + {"BFloat16Vector", entity.FieldTypeBFloat16Vector}, + {"Int8Vector", entity.FieldTypeInt8Vector}, + {"SparseVector", entity.FieldTypeSparseVector}, + } +} + +func GetNullPercents() []int { + return []int{0, 30} +} + +type NullableVectorTestData struct { + ValidData []bool + ValidCount int + PkToVecIdx map[int64]int + OriginalVectors interface{} + VecColumn column.Column + SearchVec entity.Vector +} + +func GenerateNullableVectorTestData(t *testing.T, vt NullableVectorType, nb int, nullPercent int, fieldName string) *NullableVectorTestData { + data := &NullableVectorTestData{ + ValidData: make([]bool, nb), + PkToVecIdx: make(map[int64]int), + } + + for i := range nb { + data.ValidData[i] = (i % 100) >= nullPercent + if data.ValidData[i] { + data.ValidCount++ + } + } + + vecIdx := 0 + for i := range nb { + if data.ValidData[i] { + data.PkToVecIdx[int64(i)] = vecIdx + vecIdx++ + } + } + + var err error + switch vt.FieldType { + case entity.FieldTypeFloatVector: + vectors := make([][]float32, data.ValidCount) + for i := range data.ValidCount { + vec := make([]float32, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = float32(i*common.DefaultDim+j) / 10000.0 + } + vectors[i] = vec + } + data.OriginalVectors = vectors + data.VecColumn, err = column.NewNullableColumnFloatVector(fieldName, common.DefaultDim, vectors, data.ValidData) + if data.ValidCount > 0 { + data.SearchVec = entity.FloatVector(vectors[0]) + } + + case entity.FieldTypeBinaryVector: + byteDim := common.DefaultDim / 8 + vectors := make([][]byte, data.ValidCount) + for i := range data.ValidCount { + vec := make([]byte, byteDim) + for j := range byteDim { + vec[j] = byte((i + j) % 256) + } + vectors[i] = vec + } + data.OriginalVectors = vectors + data.VecColumn, err = column.NewNullableColumnBinaryVector(fieldName, common.DefaultDim, vectors, data.ValidData) + if data.ValidCount > 0 { + data.SearchVec = entity.BinaryVector(vectors[0]) + } + + case entity.FieldTypeFloat16Vector: + vectors := make([][]byte, data.ValidCount) + for i := range data.ValidCount { + vectors[i] = common.GenFloat16Vector(common.DefaultDim) + } + data.OriginalVectors = vectors + data.VecColumn, err = column.NewNullableColumnFloat16Vector(fieldName, common.DefaultDim, vectors, data.ValidData) + if data.ValidCount > 0 { + data.SearchVec = entity.Float16Vector(vectors[0]) + } + + case entity.FieldTypeBFloat16Vector: + vectors := make([][]byte, data.ValidCount) + for i := range data.ValidCount { + vectors[i] = common.GenBFloat16Vector(common.DefaultDim) + } + data.OriginalVectors = vectors + data.VecColumn, err = column.NewNullableColumnBFloat16Vector(fieldName, common.DefaultDim, vectors, data.ValidData) + if data.ValidCount > 0 { + data.SearchVec = entity.BFloat16Vector(vectors[0]) + } + + case entity.FieldTypeInt8Vector: + vectors := make([][]int8, data.ValidCount) + for i := range data.ValidCount { + vec := make([]int8, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = int8((i + j) % 127) + } + vectors[i] = vec + } + data.OriginalVectors = vectors + data.VecColumn, err = column.NewNullableColumnInt8Vector(fieldName, common.DefaultDim, vectors, data.ValidData) + if data.ValidCount > 0 { + data.SearchVec = entity.Int8Vector(vectors[0]) + } + + case entity.FieldTypeSparseVector: + vectors := make([]entity.SparseEmbedding, data.ValidCount) + for i := range data.ValidCount { + positions := []uint32{0, uint32(i + 1), uint32(i + 1000)} + values := []float32{1.0, float32(i+1) / 1000.0, 0.1} + vectors[i], err = entity.NewSliceSparseEmbedding(positions, values) + common.CheckErr(t, err, true) + } + data.OriginalVectors = vectors + data.VecColumn, err = column.NewNullableColumnSparseFloatVector(fieldName, vectors, data.ValidData) + if data.ValidCount > 0 { + data.SearchVec = vectors[0] + } + } + common.CheckErr(t, err, true) + + return data +} + +type IndexConfig struct { + Name string + IndexType string + MetricType entity.MetricType + Params map[string]string +} + +func GetIndexesForVectorType(fieldType entity.FieldType) []IndexConfig { + switch fieldType { + case entity.FieldTypeFloatVector: + return []IndexConfig{ + {"FLAT", "FLAT", entity.L2, nil}, + {"IVF_FLAT", "IVF_FLAT", entity.L2, map[string]string{"nlist": "128"}}, + {"IVF_SQ8", "IVF_SQ8", entity.L2, map[string]string{"nlist": "128"}}, + {"IVF_PQ", "IVF_PQ", entity.L2, map[string]string{"nlist": "128", "m": "8", "nbits": "8"}}, + {"HNSW", "HNSW", entity.L2, map[string]string{"M": "16", "efConstruction": "200"}}, + {"SCANN", "SCANN", entity.L2, map[string]string{"nlist": "128", "with_raw_data": "true"}}, + // {"DISKANN", "DISKANN", entity.L2, nil}, // Skip DISKANN for now + } + case entity.FieldTypeBinaryVector: + return []IndexConfig{ + {"BIN_FLAT", "BIN_FLAT", entity.JACCARD, nil}, + {"BIN_IVF_FLAT", "BIN_IVF_FLAT", entity.JACCARD, map[string]string{"nlist": "128"}}, + } + case entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector: + return []IndexConfig{ + {"FLAT", "FLAT", entity.L2, nil}, + {"IVF_FLAT", "IVF_FLAT", entity.L2, map[string]string{"nlist": "128"}}, + {"IVF_SQ8", "IVF_SQ8", entity.L2, map[string]string{"nlist": "128"}}, + {"HNSW", "HNSW", entity.L2, map[string]string{"M": "16", "efConstruction": "200"}}, + } + case entity.FieldTypeInt8Vector: + return []IndexConfig{ + {"HNSW", "HNSW", entity.COSINE, map[string]string{"M": "16", "efConstruction": "200"}}, + } + case entity.FieldTypeSparseVector: + return []IndexConfig{ + {"SPARSE_INVERTED_INDEX", "SPARSE_INVERTED_INDEX", entity.IP, map[string]string{"drop_ratio_build": "0.1"}}, + {"SPARSE_WAND", "SPARSE_WAND", entity.IP, map[string]string{"drop_ratio_build": "0.1"}}, + } + default: + return []IndexConfig{ + {"FLAT", "FLAT", entity.L2, nil}, + } + } +} + +func CreateIndexFromConfig(fieldName string, cfg IndexConfig) index.Index { + params := map[string]string{ + index.MetricTypeKey: string(cfg.MetricType), + index.IndexTypeKey: cfg.IndexType, + } + for k, v := range cfg.Params { + params[k] = v + } + return index.NewGenericIndex(fieldName, params) +} + +func CreateNullableVectorIndex(vt NullableVectorType) index.Index { + return CreateNullableVectorIndexWithFieldName(vt, "vector") +} + +func CreateNullableVectorIndexWithFieldName(vt NullableVectorType, fieldName string) index.Index { + indexes := GetIndexesForVectorType(vt.FieldType) + if len(indexes) > 0 { + return CreateIndexFromConfig(fieldName, indexes[0]) + } + return index.NewGenericIndex(fieldName, map[string]string{ + index.MetricTypeKey: string(entity.L2), + index.IndexTypeKey: "FLAT", + }) +} + +func VerifyNullableVectorData(t *testing.T, vt NullableVectorType, queryResult client.ResultSet, pkToVecIdx map[int64]int, originalVectors interface{}, context string) { + pkCol := queryResult.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64) + vecCol := queryResult.GetColumn("vector") + for i := 0; i < queryResult.ResultCount; i++ { + pk, _ := pkCol.GetAsInt64(i) + isNull, _ := vecCol.IsNull(i) + + if origIdx, ok := pkToVecIdx[pk]; ok { + require.False(t, isNull, "%s: vector should not be null for pk %d", context, pk) + vecData, _ := vecCol.Get(i) + + switch vt.FieldType { + case entity.FieldTypeFloatVector: + vectors := originalVectors.([][]float32) + queriedVec := []float32(vecData.(entity.FloatVector)) + require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk) + origVec := vectors[origIdx] + for j := range origVec { + require.InDelta(t, origVec[j], queriedVec[j], 1e-6, "%s: vector element %d should match for pk %d", context, j, pk) + } + case entity.FieldTypeInt8Vector: + vectors := originalVectors.([][]int8) + queriedVec := []int8(vecData.(entity.Int8Vector)) + require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk) + origVec := vectors[origIdx] + for j := range origVec { + require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector element %d should match for pk %d", context, j, pk) + } + case entity.FieldTypeBinaryVector: + vectors := originalVectors.([][]byte) + queriedVec := []byte(vecData.(entity.BinaryVector)) + byteDim := common.DefaultDim / 8 + require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk) + origVec := vectors[origIdx] + for j := range origVec { + require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector byte %d should match for pk %d", context, j, pk) + } + case entity.FieldTypeFloat16Vector: + queriedVec := []byte(vecData.(entity.Float16Vector)) + byteDim := common.DefaultDim * 2 + require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk) + case entity.FieldTypeBFloat16Vector: + queriedVec := []byte(vecData.(entity.BFloat16Vector)) + byteDim := common.DefaultDim * 2 + require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk) + case entity.FieldTypeSparseVector: + vectors := originalVectors.([]entity.SparseEmbedding) + queriedVec := vecData.(entity.SparseEmbedding) + origVec := vectors[origIdx] + require.EqualValues(t, origVec.Len(), queriedVec.Len(), "%s: sparse vector length should match for pk %d", context, pk) + for j := 0; j < origVec.Len(); j++ { + origPos, origVal, _ := origVec.Get(j) + queriedPos, queriedVal, _ := queriedVec.Get(j) + require.EqualValues(t, origPos, queriedPos, "%s: sparse vector position %d should match for pk %d", context, j, pk) + require.InDelta(t, origVal, queriedVal, 1e-6, "%s: sparse vector value %d should match for pk %d", context, j, pk) + } + } + } else { + require.True(t, isNull, "%s: vector should be null for pk %d", context, pk) + vecData, _ := vecCol.Get(i) + require.Nil(t, vecData, "%s: null vector data should be nil for pk %d", context, pk) + } + } +} + +func VerifyNullableVectorDataWithFieldName(t *testing.T, vt NullableVectorType, queryResult client.ResultSet, pkToVecIdx map[int64]int, originalVectors interface{}, fieldName string, context string) { + pkCol := queryResult.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64) + vecCol := queryResult.GetColumn(fieldName) + for i := 0; i < queryResult.ResultCount; i++ { + pk, _ := pkCol.GetAsInt64(i) + isNull, _ := vecCol.IsNull(i) + + if origIdx, ok := pkToVecIdx[pk]; ok { + require.False(t, isNull, "%s: vector should not be null for pk %d", context, pk) + vecData, _ := vecCol.Get(i) + + switch vt.FieldType { + case entity.FieldTypeFloatVector: + vectors := originalVectors.([][]float32) + queriedVec := []float32(vecData.(entity.FloatVector)) + require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk) + origVec := vectors[origIdx] + for j := range origVec { + require.InDelta(t, origVec[j], queriedVec[j], 1e-6, "%s: vector element %d should match for pk %d", context, j, pk) + } + case entity.FieldTypeInt8Vector: + vectors := originalVectors.([][]int8) + queriedVec := []int8(vecData.(entity.Int8Vector)) + require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk) + origVec := vectors[origIdx] + for j := range origVec { + require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector element %d should match for pk %d", context, j, pk) + } + case entity.FieldTypeBinaryVector: + vectors := originalVectors.([][]byte) + queriedVec := []byte(vecData.(entity.BinaryVector)) + byteDim := common.DefaultDim / 8 + require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk) + origVec := vectors[origIdx] + for j := range origVec { + require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector byte %d should match for pk %d", context, j, pk) + } + case entity.FieldTypeFloat16Vector: + queriedVec := []byte(vecData.(entity.Float16Vector)) + byteDim := common.DefaultDim * 2 + require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk) + case entity.FieldTypeBFloat16Vector: + queriedVec := []byte(vecData.(entity.BFloat16Vector)) + byteDim := common.DefaultDim * 2 + require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk) + case entity.FieldTypeSparseVector: + vectors := originalVectors.([]entity.SparseEmbedding) + queriedVec := vecData.(entity.SparseEmbedding) + origVec := vectors[origIdx] + require.EqualValues(t, origVec.Len(), queriedVec.Len(), "%s: sparse vector length should match for pk %d", context, pk) + for j := 0; j < origVec.Len(); j++ { + origPos, origVal, _ := origVec.Get(j) + queriedPos, queriedVal, _ := queriedVec.Get(j) + require.EqualValues(t, origPos, queriedPos, "%s: sparse vector position %d should match for pk %d", context, j, pk) + require.InDelta(t, origVal, queriedVal, 1e-6, "%s: sparse vector value %d should match for pk %d", context, j, pk) + } + } + } else { + require.True(t, isNull, "%s: vector should be null for pk %d", context, pk) + vecData, _ := vecCol.Get(i) + require.Nil(t, vecData, "%s: null vector data should be nil for pk %d", context, pk) + } + } +} + // create collection with nullable fields and insert with column / nullableColumn func TestNullableDefault(t *testing.T) { ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) @@ -189,16 +540,16 @@ func TestNullableInvalid(t *testing.T) { err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema)) common.CheckErr(t, err, false, "primary field not support null") - // vector type not support null - notSupportedNullableDataTypes := []entity.FieldType{entity.FieldTypeFloatVector, entity.FieldTypeBinaryVector, entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector, entity.FieldTypeSparseVector, entity.FieldTypeInt8Vector} - for _, fieldType := range notSupportedNullableDataTypes { + supportedNullableVectorTypes := []entity.FieldType{entity.FieldTypeFloatVector, entity.FieldTypeBinaryVector, entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector, entity.FieldTypeSparseVector, entity.FieldTypeInt8Vector} + for _, fieldType := range supportedNullableVectorTypes { nullableVectorField := entity.NewField().WithName(common.GenRandomString("null", 3)).WithDataType(fieldType).WithNullable(true) if fieldType != entity.FieldTypeSparseVector { nullableVectorField.WithDim(128) } - schema := entity.NewSchema().WithName(common.GenRandomString("nullable_invalid_field", 5)).WithField(pkField).WithField(nullableVectorField) + schema := entity.NewSchema().WithName(common.GenRandomString("nullable_vector", 5)).WithField(pkField).WithField(nullableVectorField) err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(schema.CollectionName, schema)) - common.CheckErr(t, err, false, "vector type not support null") + common.CheckErr(t, err, true) + mc.DropCollection(ctx, client.NewDropCollectionOption(schema.CollectionName)) } // partition-key field not support null @@ -891,3 +1242,2041 @@ func TestNullableRows(t *testing.T) { count, _ := countRes.Fields[0].GetAsInt64(0) require.EqualValues(t, common.DefaultNb/2, count) } + +func TestNullableVectorAllTypes(t *testing.T) { + vectorTypes := GetVectorTypes() + nullPercents := GetNullPercents() + + for _, vt := range vectorTypes { + for _, nullPercent := range nullPercents { + testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent) + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create collection + collName := common.GenRandomString("nullable_vec", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField = vecField.WithDim(common.DefaultDim) + } + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + nb := 500 + validData := make([]bool, nb) + validCount := 0 + for i := range nb { + validData[i] = (i % 100) >= nullPercent + if validData[i] { + validCount++ + } + } + + pkData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + pkToVecIdx := make(map[int64]int) + vecIdx := 0 + for i := range nb { + if validData[i] { + pkToVecIdx[int64(i)] = vecIdx + vecIdx++ + } + } + + var vecColumn column.Column + var searchVec entity.Vector + var originalVectors interface{} + + switch vt.FieldType { + case entity.FieldTypeFloatVector: + vectors := make([][]float32, validCount) + for i := range validCount { + vec := make([]float32, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = float32(i*common.DefaultDim+j) / 10000.0 + } + vectors[i] = vec + } + originalVectors = vectors + vecColumn, err = column.NewNullableColumnFloatVector("vector", common.DefaultDim, vectors, validData) + if validCount > 0 { + searchVec = entity.FloatVector(vectors[0]) + } + + case entity.FieldTypeBinaryVector: + vectors := make([][]byte, validCount) + byteDim := common.DefaultDim / 8 + for i := range validCount { + vec := make([]byte, byteDim) + for j := range byteDim { + vec[j] = byte((i + j) % 256) + } + vectors[i] = vec + } + originalVectors = vectors + vecColumn, err = column.NewNullableColumnBinaryVector("vector", common.DefaultDim, vectors, validData) + if validCount > 0 { + searchVec = entity.BinaryVector(vectors[0]) + } + + case entity.FieldTypeFloat16Vector: + vectors := make([][]byte, validCount) + for i := range validCount { + vectors[i] = common.GenFloat16Vector(common.DefaultDim) + } + originalVectors = vectors + vecColumn, err = column.NewNullableColumnFloat16Vector("vector", common.DefaultDim, vectors, validData) + if validCount > 0 { + searchVec = entity.Float16Vector(vectors[0]) + } + + case entity.FieldTypeBFloat16Vector: + vectors := make([][]byte, validCount) + for i := range validCount { + vectors[i] = common.GenBFloat16Vector(common.DefaultDim) + } + originalVectors = vectors + vecColumn, err = column.NewNullableColumnBFloat16Vector("vector", common.DefaultDim, vectors, validData) + if validCount > 0 { + searchVec = entity.BFloat16Vector(vectors[0]) + } + + case entity.FieldTypeInt8Vector: + vectors := make([][]int8, validCount) + for i := range validCount { + vec := make([]int8, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = int8((i + j) % 127) + } + vectors[i] = vec + } + originalVectors = vectors + vecColumn, err = column.NewNullableColumnInt8Vector("vector", common.DefaultDim, vectors, validData) + if validCount > 0 { + searchVec = entity.Int8Vector(vectors[0]) + } + + case entity.FieldTypeSparseVector: + vectors := make([]entity.SparseEmbedding, validCount) + for i := range validCount { + positions := []uint32{0, uint32(i + 1), uint32(i + 1000)} + values := []float32{1.0, float32(i+1) / 1000.0, 0.1} + vectors[i], err = entity.NewSliceSparseEmbedding(positions, values) + common.CheckErr(t, err, true) + } + originalVectors = vectors + vecColumn, err = column.NewNullableColumnSparseFloatVector("vector", vectors, validData) + if validCount > 0 { + searchVec = vectors[0] + } + } + common.CheckErr(t, err, true) + _ = originalVectors + + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, vecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + if validCount > 0 { + var vecIndex index.Index + switch vt.FieldType { + case entity.FieldTypeBinaryVector: + vecIndex = index.NewGenericIndex("vector", map[string]string{ + index.MetricTypeKey: string(entity.JACCARD), + index.IndexTypeKey: "BIN_FLAT", + }) + case entity.FieldTypeInt8Vector: + vecIndex = index.NewGenericIndex("vector", map[string]string{ + index.MetricTypeKey: string(entity.COSINE), + index.IndexTypeKey: "HNSW", + "M": "16", + "efConstruction": "200", + }) + case entity.FieldTypeSparseVector: + vecIndex = index.NewSparseInvertedIndex(entity.IP, 0.1) + default: + vecIndex = index.NewFlatIndex(entity.L2) + } + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).WithANNSField("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data() + require.True(t, len(searchIDs) > 0, "search should return results") + + expectedTopK := 10 + if validCount < expectedTopK { + expectedTopK = validCount + } + require.EqualValues(t, expectedTopK, len(searchIDs), "search should return expected number of results") + + for _, id := range searchIDs { + require.True(t, id >= 0 && id < int64(nb), "search result ID %d should be in range [0, %d)", id, nb) + } + + verifyVectorData := func(queryResult client.ResultSet, context string) { + pkCol := queryResult.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64) + vecCol := queryResult.GetColumn("vector") + for i := 0; i < queryResult.ResultCount; i++ { + pk, _ := pkCol.GetAsInt64(i) + isNull, _ := vecCol.IsNull(i) + + if origIdx, ok := pkToVecIdx[pk]; ok { + require.False(t, isNull, "%s: vector should not be null for pk %d", context, pk) + vecData, _ := vecCol.Get(i) + + switch vt.FieldType { + case entity.FieldTypeFloatVector: + vectors := originalVectors.([][]float32) + queriedVec := []float32(vecData.(entity.FloatVector)) + require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk) + origVec := vectors[origIdx] + for j := range origVec { + require.InDelta(t, origVec[j], queriedVec[j], 1e-6, "%s: vector element %d should match for pk %d", context, j, pk) + } + case entity.FieldTypeInt8Vector: + vectors := originalVectors.([][]int8) + queriedVec := []int8(vecData.(entity.Int8Vector)) + require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk) + origVec := vectors[origIdx] + for j := range origVec { + require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector element %d should match for pk %d", context, j, pk) + } + case entity.FieldTypeBinaryVector: + vectors := originalVectors.([][]byte) + queriedVec := []byte(vecData.(entity.BinaryVector)) + byteDim := common.DefaultDim / 8 + require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk) + origVec := vectors[origIdx] + for j := range origVec { + require.EqualValues(t, origVec[j], queriedVec[j], "%s: vector byte %d should match for pk %d", context, j, pk) + } + case entity.FieldTypeFloat16Vector: + queriedVec := []byte(vecData.(entity.Float16Vector)) + byteDim := common.DefaultDim * 2 + require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk) + case entity.FieldTypeBFloat16Vector: + queriedVec := []byte(vecData.(entity.BFloat16Vector)) + byteDim := common.DefaultDim * 2 + require.EqualValues(t, byteDim, len(queriedVec), "%s: vector byte dimension should match for pk %d", context, pk) + case entity.FieldTypeSparseVector: + vectors := originalVectors.([]entity.SparseEmbedding) + queriedVec := vecData.(entity.SparseEmbedding) + origVec := vectors[origIdx] + require.EqualValues(t, origVec.Len(), queriedVec.Len(), "%s: sparse vector length should match for pk %d", context, pk) + for j := 0; j < origVec.Len(); j++ { + origPos, origVal, _ := origVec.Get(j) + queriedPos, queriedVal, _ := queriedVec.Get(j) + require.EqualValues(t, origPos, queriedPos, "%s: sparse vector position %d should match for pk %d", context, j, pk) + require.InDelta(t, origVal, queriedVal, 1e-6, "%s: sparse vector value %d should match for pk %d", context, j, pk) + } + } + } else { + require.True(t, isNull, "%s: vector should be null for pk %d", context, pk) + vecData, _ := vecCol.Get(i) + require.Nil(t, vecData, "%s: null vector data should be nil for pk %d", context, pk) + } + } + } + + if len(searchIDs) > 0 { + searchQueryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(searchIDs))).WithOutputFields("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, len(searchIDs), searchQueryRes.ResultCount, "query by search IDs should return all IDs") + verifyVectorData(searchQueryRes, "Search results") + } + + queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 10").WithOutputFields("vector")) + common.CheckErr(t, err, true) + expectedQueryCount := 10 + require.EqualValues(t, expectedQueryCount, queryRes.ResultCount, "query should return expected count") + verifyVectorData(queryRes, "Query int64 < 10") + + searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}). + WithANNSField("vector").WithFilter("int64 < 100").WithOutputFields("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + filteredIDs := searchRes[0].IDs.(*column.ColumnInt64).Data() + for _, id := range filteredIDs { + require.True(t, id < 100, "filtered search should only return IDs < 100, got %d", id) + } + hybridSearchQueryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(filteredIDs))).WithOutputFields("vector")) + common.CheckErr(t, err, true) + verifyVectorData(hybridSearchQueryRes, "Hybrid search results") + + countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + totalCount, err := countRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, totalCount, "total count should equal inserted rows") + } + + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } + } +} + +func TestNullableVectorWithScalarFilter(t *testing.T) { + vectorTypes := GetVectorTypes() + nullPercents := GetNullPercents() + + for _, vt := range vectorTypes { + for _, nullPercent := range nullPercents { + testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent) + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_filter", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField = vecField.WithDim(common.DefaultDim) + } + tagField := entity.NewField().WithName("tag").WithDataType(entity.FieldTypeVarChar).WithMaxLength(100).WithNullable(true) + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField).WithField(tagField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + nb := 500 + testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector") + + // tag field: 50% null (even rows are valid) + tagValidData := make([]bool, nb) + tagValidCount := 0 + for i := range nb { + tagValidData[i] = i%2 == 0 + if tagValidData[i] { + tagValidCount++ + } + } + + pkData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + tagData := make([]string, tagValidCount) + for i := range tagValidCount { + tagData[i] = fmt.Sprintf("tag_%d", i) + } + tagColumn, err := column.NewNullableColumnVarChar("tag", tagData, tagValidData) + common.CheckErr(t, err, true) + + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn, tagColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + if testData.ValidCount > 0 { + vecIndex := CreateNullableVectorIndex(vt) + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // Query with scalar filter: tag is not null and int64 < 50 + // int64 < 50 => 50 rows (pk 0-49) + // tag is not null => even rows only (pk 0, 2, 4, ..., 48) => 25 rows + queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("tag is not null and int64 < 50").WithOutputFields("vector", "tag")) + common.CheckErr(t, err, true) + require.EqualValues(t, 25, queryRes.ResultCount, "query should return 25 rows with tag not null and int64 < 50") + VerifyNullableVectorData(t, vt, queryRes, testData.PkToVecIdx, testData.OriginalVectors, "Query with tag filter") + } + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } + } +} + +func TestNullableVectorDelete(t *testing.T) { + vectorTypes := GetVectorTypes() + nullPercents := GetNullPercents() + + for _, vt := range vectorTypes { + for _, nullPercent := range nullPercents { + testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent) + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_del", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField = vecField.WithDim(common.DefaultDim) + } + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + nb := 100 + testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector") + + pkData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + if testData.ValidCount > 0 { + vecIndex := CreateNullableVectorIndex(vt) + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // Delete first 25 rows and last 25 rows + delRes, err := mc.Delete(ctx, client.NewDeleteOption(collName).WithExpr("int64 < 25")) + common.CheckErr(t, err, true) + require.EqualValues(t, 25, delRes.DeleteCount) + + delRes, err = mc.Delete(ctx, client.NewDeleteOption(collName).WithExpr("int64 >= 75")) + common.CheckErr(t, err, true) + require.EqualValues(t, 25, delRes.DeleteCount) + + // Verify remaining count + queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + count, err := queryRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, 50, count, "remaining count should be 100 - 25 - 25 = 50") + + // Verify deleted rows don't exist + queryDeletedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 25").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + deletedCount, err := queryDeletedRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, 0, deletedCount, "deleted rows should not exist") + + queryDeletedValidRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 75").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + deletedValidCount, err := queryDeletedValidRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, 0, deletedValidCount, "deleted valid vector rows should not exist") + + // Verify remaining rows with vector data + queryValidRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 25 and int64 < 75").WithOutputFields("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, 50, queryValidRes.ResultCount, "should have 50 remaining rows") + VerifyNullableVectorData(t, vt, queryValidRes, testData.PkToVecIdx, testData.OriginalVectors, "Remaining vector rows") + } + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } + } +} + +func TestNullableVectorUpsert(t *testing.T) { + autoIDOptions := []bool{false, true} + + for _, autoID := range autoIDOptions { + testName := fmt.Sprintf("AutoID=%v", autoID) + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_ups", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + if autoID { + pkField.AutoID = true + } + vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim).WithNullable(true) + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + // insert initial data with 50% null + nb := 100 + nullPercent := 50 + validData := make([]bool, nb) + validCount := 0 + for i := range nb { + validData[i] = (i % 100) >= nullPercent + if validData[i] { + validCount++ + } + } + + pkData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + pkToVecIdx := make(map[int64]int) + vecIdx := 0 + for i := range nb { + if validData[i] { + pkToVecIdx[int64(i)] = vecIdx + vecIdx++ + } + } + + vectors := make([][]float32, validCount) + for i := range validCount { + vec := make([]float32, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = float32(i*common.DefaultDim+j) / 10000.0 + } + vectors[i] = vec + } + vecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, vectors, validData) + common.CheckErr(t, err, true) + + var insertRes client.InsertResult + if autoID { + insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(vecColumn)) + } else { + insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(pkColumn, vecColumn)) + } + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + var actualPkData []int64 + if autoID { + insertedIDs := insertRes.IDs.(*column.ColumnInt64) + actualPkData = insertedIDs.Data() + require.EqualValues(t, nb, len(actualPkData), "inserted PK count should match") + } else { + actualPkData = pkData + } + + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, common.DefaultFloatVecFieldName, index.NewFlatIndex(entity.L2))) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // upsert: change first 25 rows (originally null) to valid, change rows 50-74 (originally valid) to null + upsertNb := 50 + upsertValidData := make([]bool, upsertNb) + for i := range upsertNb { + upsertValidData[i] = i < 25 + } + + upsertPkData := make([]int64, upsertNb) + for i := range upsertNb { + if i < 25 { + upsertPkData[i] = actualPkData[i] + } else { + upsertPkData[i] = actualPkData[i+25] + } + } + upsertPkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsertPkData) + + upsertVectors := make([][]float32, 25) + for i := range 25 { + vec := make([]float32, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = float32((i+100)*common.DefaultDim+j) / 10000.0 + } + upsertVectors[i] = vec + } + upsertVecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, upsertVectors, upsertValidData) + common.CheckErr(t, err, true) + + upsertRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsertPkColumn, upsertVecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, upsertNb, upsertRes.UpsertCount) + + var upsertedPks []int64 + if autoID { + upsertedIDs := upsertRes.IDs.(*column.ColumnInt64) + upsertedPks = upsertedIDs.Data() + require.EqualValues(t, upsertNb, len(upsertedPks), "upserted PK count should match") + } else { + upsertedPks = upsertPkData + } + + expectedVectorMap := make(map[int64][]float32) + for i := 0; i < 25; i++ { + expectedVectorMap[upsertedPks[i]] = upsertVectors[i] + } + for i := 25; i < 50; i++ { + expectedVectorMap[upsertedPks[i]] = nil + } + for i := 25; i < 50; i++ { + expectedVectorMap[actualPkData[i]] = nil + } + for i := 75; i < 100; i++ { + vecIdx := i - 50 + expectedVectorMap[actualPkData[i]] = vectors[vecIdx] + } + + time.Sleep(10 * time.Second) + flushTask, err = mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + err = mc.ReleaseCollection(ctx, client.NewReleaseCollectionOption(collName)) + common.CheckErr(t, err, true) + + loadTask, err = mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + expectedValidCount := 50 + searchVec := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{searchVec}).WithANNSField(common.DefaultFloatVecFieldName)) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data() + require.EqualValues(t, expectedValidCount, len(searchIDs), "search should return all 50 valid vectors") + + verifyVectorData := func(queryResult client.ResultSet, context string) { + pkCol := queryResult.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64) + vecCol := queryResult.GetColumn(common.DefaultFloatVecFieldName).(*column.ColumnFloatVector) + for i := 0; i < queryResult.ResultCount; i++ { + pk, _ := pkCol.GetAsInt64(i) + isNull, _ := vecCol.IsNull(i) + + expectedVec, exists := expectedVectorMap[pk] + require.True(t, exists, "%s: unexpected PK %d in query results", context, pk) + + if expectedVec != nil { + require.False(t, isNull, "%s: vector should not be null for pk %d", context, pk) + vecData, _ := vecCol.Get(i) + queriedVec := []float32(vecData.(entity.FloatVector)) + require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk) + for j := range expectedVec { + require.InDelta(t, expectedVec[j], queriedVec[j], 1e-6, "%s: vector element %d should match for pk %d", context, j, pk) + } + } else { + require.True(t, isNull, "%s: vector should be null for pk %d", context, pk) + vecData, _ := vecCol.Get(i) + require.Nil(t, vecData, "%s: null vector data should be nil for pk %d", context, pk) + } + } + } + + searchQueryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(searchIDs))).WithOutputFields(common.DefaultFloatVecFieldName)) + common.CheckErr(t, err, true) + verifyVectorData(searchQueryRes, "All valid vectors after upsert") + + upsertedToValidPKs := upsertedPks[0:25] + queryUpsertedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(upsertedToValidPKs))).WithOutputFields(common.DefaultFloatVecFieldName)) + common.CheckErr(t, err, true) + require.EqualValues(t, 25, queryUpsertedRes.ResultCount, "should have 25 rows for upserted to valid") + verifyVectorData(queryUpsertedRes, "Upserted valid rows") + + countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + totalCount, err := countRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, totalCount, "total count after upsert should still be %d", nb) + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } +} + +func TestNullableVectorAllNull(t *testing.T) { + vectorTypes := GetVectorTypes() + + for _, vt := range vectorTypes { + testName := fmt.Sprintf("%s_100%%null", vt.Name) + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_all", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField = vecField.WithDim(common.DefaultDim) + } + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + // Generate test data with 100% null (nullPercent = 100) + nb := 100 + testData := GenerateNullableVectorTestData(t, vt, nb, 100, "vector") + + pkData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + // insert + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + // flush + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + // create index and load + vecIndex := CreateNullableVectorIndex(vt) + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // Generate a search vector (won't match anything since all are null) + var searchVec entity.Vector + switch vt.FieldType { + case entity.FieldTypeFloatVector: + searchVec = entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + case entity.FieldTypeBinaryVector: + searchVec = entity.BinaryVector(make([]byte, common.DefaultDim/8)) + case entity.FieldTypeFloat16Vector: + searchVec = entity.Float16Vector(common.GenFloat16Vector(common.DefaultDim)) + case entity.FieldTypeBFloat16Vector: + searchVec = entity.BFloat16Vector(common.GenBFloat16Vector(common.DefaultDim)) + case entity.FieldTypeInt8Vector: + vec := make([]int8, common.DefaultDim) + searchVec = entity.Int8Vector(vec) + case entity.FieldTypeSparseVector: + searchVec, _ = entity.NewSliceSparseEmbedding([]uint32{0}, []float32{1.0}) + } + + // search should return empty results since all vectors are null (not searchable) + searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).WithANNSField("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data() + require.EqualValues(t, 0, len(searchIDs), "search should return empty results for all-null vectors") + + // query should return all rows + queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + count, err := queryRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, count, "query should return all %d rows even with 100%% null vectors", nb) + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } +} + +func TestNullableVectorMultiFields(t *testing.T) { + vectorTypes := GetVectorTypes() + + for _, vt := range vectorTypes { + testName := vt.Name + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_multi", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField1 := entity.NewField().WithName("vec1").WithDataType(vt.FieldType).WithNullable(true) + vecField2 := entity.NewField().WithName("vec2").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField1 = vecField1.WithDim(common.DefaultDim) + vecField2 = vecField2.WithDim(common.DefaultDim) + } + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField1).WithField(vecField2) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + // generate data: vec1 has 70% valid (first 30 per 100 are invalid), vec2 has 30% valid (first 70 per 100 are invalid) + nb := 100 + nullPercent1 := 30 // vec1: 30% null + nullPercent2 := 70 // vec2: 70% null + + // Generate test data for both vector fields + testData1 := GenerateNullableVectorTestData(t, vt, nb, nullPercent1, "vec1") + testData2 := GenerateNullableVectorTestData(t, vt, nb, nullPercent2, "vec2") + + pkData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + // insert + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData1.VecColumn, testData2.VecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + // flush + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + // create indexes for both vector fields + vecIndex := CreateNullableVectorIndex(vt) + indexTask1, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vec1", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask1.Await(ctx) + common.CheckErr(t, err, true) + + indexTask2, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vec2", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask2.Await(ctx) + common.CheckErr(t, err, true) + + // load + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // search on vec1 + searchRes1, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{testData1.SearchVec}).WithANNSField("vec1").WithOutputFields("vec1")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes1)) + searchIDs1 := searchRes1[0].IDs.(*column.ColumnInt64).Data() + require.EqualValues(t, 10, len(searchIDs1), "search on vec1 should return 10 results") + for _, id := range searchIDs1 { + _, ok := testData1.PkToVecIdx[id] + require.True(t, ok, "search on vec1 should only return rows where vec1 is valid, got pk %d", id) + } + + // search on vec2 + searchRes2, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{testData2.SearchVec}).WithANNSField("vec2").WithOutputFields("vec2")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes2)) + searchIDs2 := searchRes2[0].IDs.(*column.ColumnInt64).Data() + require.EqualValues(t, 10, len(searchIDs2), "search on vec2 should return 10 results") + for _, id := range searchIDs2 { + _, ok := testData2.PkToVecIdx[id] + require.True(t, ok, "search on vec2 should only return rows where vec2 is valid, got pk %d", id) + } + + // query and verify - rows 0-29 both null + queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 30").WithOutputFields("vec1", "vec2")) + common.CheckErr(t, err, true) + VerifyNullableVectorDataWithFieldName(t, vt, queryRes, testData1.PkToVecIdx, testData1.OriginalVectors, "vec1", "query0-29 vec1") + VerifyNullableVectorDataWithFieldName(t, vt, queryRes, testData2.PkToVecIdx, testData2.OriginalVectors, "vec2", "query0-29 vec2") + + // query rows 30-69: vec1 valid, vec2 null + queryMixedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 30 AND int64 < 70").WithOutputFields("vec1", "vec2")) + common.CheckErr(t, err, true) + VerifyNullableVectorDataWithFieldName(t, vt, queryMixedRes, testData1.PkToVecIdx, testData1.OriginalVectors, "vec1", "query30-69 vec1") + VerifyNullableVectorDataWithFieldName(t, vt, queryMixedRes, testData2.PkToVecIdx, testData2.OriginalVectors, "vec2", "query30-69 vec2") + + // query rows 70-99: both valid + queryBothValidRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 70").WithOutputFields("vec1", "vec2")) + common.CheckErr(t, err, true) + VerifyNullableVectorDataWithFieldName(t, vt, queryBothValidRes, testData1.PkToVecIdx, testData1.OriginalVectors, "vec1", "query70-99 vec1") + VerifyNullableVectorDataWithFieldName(t, vt, queryBothValidRes, testData2.PkToVecIdx, testData2.OriginalVectors, "vec2", "query70-99 vec2") + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } +} + +func TestNullableVectorPaginatedQuery(t *testing.T) { + vectorTypes := GetVectorTypes() + nullPercents := GetNullPercents() + + for _, vt := range vectorTypes { + for _, nullPercent := range nullPercents { + testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent) + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_page", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField = vecField.WithDim(common.DefaultDim) + } + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + nb := 200 + testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector") + + pkData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + // insert + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + if testData.ValidCount > 0 { + // create index and load + vecIndex := CreateNullableVectorIndex(vt) + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // Test pagination: page1 offset=0, limit=50 + page1Res, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter("").WithOutputFields("vector").WithOffset(0).WithLimit(50)) + common.CheckErr(t, err, true) + require.EqualValues(t, 50, page1Res.ResultCount, "page 1 should return 50 rows") + VerifyNullableVectorData(t, vt, page1Res, testData.PkToVecIdx, testData.OriginalVectors, "page1") + + // Test pagination: page2 offset=50, limit=50 + page2Res, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter("").WithOutputFields("vector").WithOffset(50).WithLimit(50)) + common.CheckErr(t, err, true) + require.EqualValues(t, 50, page2Res.ResultCount, "page 2 should return 50 rows") + VerifyNullableVectorData(t, vt, page2Res, testData.PkToVecIdx, testData.OriginalVectors, "page2") + + // Test pagination: page3 offset=100, limit=50 + page3Res, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter("").WithOutputFields("vector").WithOffset(100).WithLimit(50)) + common.CheckErr(t, err, true) + require.EqualValues(t, 50, page3Res.ResultCount, "page 3 should return 50 rows") + VerifyNullableVectorData(t, vt, page3Res, testData.PkToVecIdx, testData.OriginalVectors, "page3") + + // Test mixed query with filter + mixedPageRes, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter("int64 >= 40 and int64 < 60"). + WithOutputFields("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, 20, mixedPageRes.ResultCount, "mixed query should return 20 rows") + VerifyNullableVectorData(t, vt, mixedPageRes, testData.PkToVecIdx, testData.OriginalVectors, "mixed query") + } + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } + } +} + +func TestNullableVectorMultiPartitions(t *testing.T) { + vectorTypes := GetVectorTypes() + + for _, vt := range vectorTypes { + testName := vt.Name + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_part", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField = vecField.WithDim(common.DefaultDim) + } + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + // create partitions + partitions := []string{"partition_a", "partition_b", "partition_c"} + for _, p := range partitions { + err = mc.CreatePartition(ctx, client.NewCreatePartitionOption(collName, p)) + common.CheckErr(t, err, true) + } + + // insert data into each partition with different null ratios + nbPerPartition := 100 + nullRatios := []int{0, 30, 50} // 0%, 30%, 50% null for each partition + + // Store all test data and mappings for verification + allPkToVecIdx := make(map[int64]int) + var allOriginalVectors interface{} + var firstSearchVec entity.Vector + globalVecIdx := 0 + + for i, partition := range partitions { + nullRatio := nullRatios[i] + testData := GenerateNullableVectorTestData(t, vt, nbPerPartition, nullRatio, "vector") + + // pk column with unique ids per partition + pkData := make([]int64, nbPerPartition) + for j := range nbPerPartition { + pkData[j] = int64(i*nbPerPartition + j) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + for j := range nbPerPartition { + if testData.ValidData[j] { + allPkToVecIdx[pkData[j]] = globalVecIdx + globalVecIdx++ + } + } + + // Accumulate original vectors for verification + switch vt.FieldType { + case entity.FieldTypeFloatVector: + if allOriginalVectors == nil { + allOriginalVectors = make([][]float32, 0) + } + allOriginalVectors = append(allOriginalVectors.([][]float32), testData.OriginalVectors.([][]float32)...) + case entity.FieldTypeBinaryVector: + if allOriginalVectors == nil { + allOriginalVectors = make([][]byte, 0) + } + allOriginalVectors = append(allOriginalVectors.([][]byte), testData.OriginalVectors.([][]byte)...) + case entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector: + if allOriginalVectors == nil { + allOriginalVectors = make([][]byte, 0) + } + allOriginalVectors = append(allOriginalVectors.([][]byte), testData.OriginalVectors.([][]byte)...) + case entity.FieldTypeInt8Vector: + if allOriginalVectors == nil { + allOriginalVectors = make([][]int8, 0) + } + allOriginalVectors = append(allOriginalVectors.([][]int8), testData.OriginalVectors.([][]int8)...) + case entity.FieldTypeSparseVector: + if allOriginalVectors == nil { + allOriginalVectors = make([]entity.SparseEmbedding, 0) + } + allOriginalVectors = append(allOriginalVectors.([]entity.SparseEmbedding), testData.OriginalVectors.([]entity.SparseEmbedding)...) + } + + // Save first search vector (from partition_a with 0% null) + if i == 0 && testData.SearchVec != nil { + firstSearchVec = testData.SearchVec + } + + // insert into partition + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn).WithPartition(partition)) + common.CheckErr(t, err, true) + require.EqualValues(t, nbPerPartition, insertRes.InsertCount) + } + + // flush + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + // create index and load + vecIndex := CreateNullableVectorIndex(vt) + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // search in specific partition - verify all results from partition_a + searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{firstSearchVec}). + WithANNSField("vector"). + WithPartitions("partition_a")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data() + require.EqualValues(t, 10, len(searchIDs), "search in partition_a should return 10 results") + // partition_a has 0% null, so all 100 vectors are valid, IDs should be 0-99 + for _, id := range searchIDs { + require.True(t, id >= 0 && id < int64(nbPerPartition), "partition_a IDs should be in range [0, %d), got %d", nbPerPartition, id) + // Verify all search results have valid vectors + _, ok := allPkToVecIdx[id] + require.True(t, ok, "search result pk %d should have a valid vector", id) + } + + // search across all partitions - should return results from any partition + searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{firstSearchVec}). + WithANNSField("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + allSearchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data() + require.EqualValues(t, 50, len(allSearchIDs), "search across all partitions should return 50 results") + // Verify all search results have valid vectors + for _, id := range allSearchIDs { + _, ok := allPkToVecIdx[id] + require.True(t, ok, "all partitions search result pk %d should have a valid vector", id) + } + + // query each partition to verify counts + expectedCounts := []int64{100, 100, 100} // total rows in each partition + for i, partition := range partitions { + queryRes, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter(""). + WithOutputFields("count(*)"). + WithPartitions(partition)) + common.CheckErr(t, err, true) + count, err := queryRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, expectedCounts[i], count, "partition %s should have %d rows", partition, expectedCounts[i]) + } + + // query with vector output from specific partition - partition_a (0% null) + queryVecRes, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter("int64 < 10"). + WithOutputFields("vector"). + WithPartitions("partition_a")) + common.CheckErr(t, err, true) + require.EqualValues(t, 10, queryVecRes.ResultCount, "query partition_a with int64 < 10 should return 10 rows") + VerifyNullableVectorData(t, vt, queryVecRes, allPkToVecIdx, allOriginalVectors, "query partition_a int64 < 10") + + // query partition_b which has 30% null (rows 100-129 are null, 130-199 are valid) + queryPartBRes, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter("int64 >= 100 AND int64 < 150"). + WithOutputFields("vector"). + WithPartitions("partition_b")) + common.CheckErr(t, err, true) + require.EqualValues(t, 50, queryPartBRes.ResultCount, "query partition_b with 100 <= int64 < 150 should return 50 rows") + VerifyNullableVectorData(t, vt, queryPartBRes, allPkToVecIdx, allOriginalVectors, "query partition_b int64 100-149") + + // query partition_c which has 50% null (rows 200-249 are null, 250-299 are valid) + queryPartCRes, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter("int64 >= 200 AND int64 < 260"). + WithOutputFields("vector"). + WithPartitions("partition_c")) + common.CheckErr(t, err, true) + require.EqualValues(t, 60, queryPartCRes.ResultCount, "query partition_c with 200 <= int64 < 260 should return 60 rows") + VerifyNullableVectorData(t, vt, queryPartCRes, allPkToVecIdx, allOriginalVectors, "query partition_c int64 200-259") + + // verify total count across all partitions + totalCountRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + totalCount, err := totalCountRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, nbPerPartition*3, totalCount, "total count should be %d", nbPerPartition*3) + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } +} + +func TestNullableVectorCompaction(t *testing.T) { + vectorTypes := GetVectorTypes() + + for _, vt := range vectorTypes { + testName := vt.Name + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*2) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_comp", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField = vecField.WithDim(common.DefaultDim) + } + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + // insert data in multiple batches to create multiple segments + nb := 200 + nullPercent := 30 + + // Store all vectors and mappings for verification + allPkToVecIdx := make(map[int64]int) + var allOriginalVectors interface{} + var searchVec entity.Vector + globalVecIdx := 0 + + // batch 1: generate test data + testData1 := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector") + + pkData1 := make([]int64, nb) + for i := range nb { + pkData1[i] = int64(i) + } + pkColumn1 := column.NewColumnInt64(common.DefaultInt64FieldName, pkData1) + + for i := range nb { + if testData1.ValidData[i] { + allPkToVecIdx[pkData1[i]] = globalVecIdx + globalVecIdx++ + } + } + + // Store original vectors + allOriginalVectors = testData1.OriginalVectors + searchVec = testData1.SearchVec + + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn1, testData1.VecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + // flush to create segment + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + // wait for rate limiter reset before next flush (rate=0.1 means 1 flush per 10s) + time.Sleep(10 * time.Second) + + testData2 := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector") + + pkData2 := make([]int64, nb) + for i := range nb { + pkData2[i] = int64(nb + i) + } + pkColumn2 := column.NewColumnInt64(common.DefaultInt64FieldName, pkData2) + + for i := range nb { + if testData2.ValidData[i] { + allPkToVecIdx[pkData2[i]] = globalVecIdx + globalVecIdx++ + } + } + + // Accumulate original vectors for verification + switch vt.FieldType { + case entity.FieldTypeFloatVector: + allOriginalVectors = append(allOriginalVectors.([][]float32), testData2.OriginalVectors.([][]float32)...) + case entity.FieldTypeBinaryVector: + allOriginalVectors = append(allOriginalVectors.([][]byte), testData2.OriginalVectors.([][]byte)...) + case entity.FieldTypeFloat16Vector, entity.FieldTypeBFloat16Vector: + allOriginalVectors = append(allOriginalVectors.([][]byte), testData2.OriginalVectors.([][]byte)...) + case entity.FieldTypeInt8Vector: + allOriginalVectors = append(allOriginalVectors.([][]int8), testData2.OriginalVectors.([][]int8)...) + case entity.FieldTypeSparseVector: + allOriginalVectors = append(allOriginalVectors.([]entity.SparseEmbedding), testData2.OriginalVectors.([]entity.SparseEmbedding)...) + } + + insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn2, testData2.VecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + // flush to create another segment + flushTask, err = mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + // create index and load + vecIndex := CreateNullableVectorIndex(vt) + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // delete some data (mix of valid and null vectors) - first 50 rows from batch 1 + delRes, err := mc.Delete(ctx, client.NewDeleteOption(collName).WithExpr("int64 < 50")) + common.CheckErr(t, err, true) + require.EqualValues(t, 50, delRes.DeleteCount, "should delete 50 rows") + + // trigger manual compaction + compactID, err := mc.Compact(ctx, client.NewCompactOption(collName)) + common.CheckErr(t, err, true) + t.Logf("Compaction started with ID: %d", compactID) + + // wait for compaction to complete + for i := 0; i < 60; i++ { + state, err := mc.GetCompactionState(ctx, client.NewGetCompactionStateOption(compactID)) + common.CheckErr(t, err, true) + if state == entity.CompactionStateCompleted { + t.Log("Compaction completed") + break + } + time.Sleep(time.Second) + } + + // verify remaining count: 400 total - 50 deleted = 350 + queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + count, err := queryRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, nb*2-50, count, "remaining count should be 400 - 50 = 350") + + // verify deleted rows are gone + queryDeletedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 50").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + deletedCount, err := queryDeletedRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, 0, deletedCount, "deleted rows should not exist") + + // search should still work - verify returns results + searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).WithANNSField("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data() + require.EqualValues(t, 10, len(searchIDs), "search should return 10 results") + // All search results should have IDs >= 50 (since we deleted pk < 50) and have valid vectors + for _, id := range searchIDs { + require.True(t, id >= 50, "search results should not include deleted IDs, got %d", id) + _, ok := allPkToVecIdx[id] + require.True(t, ok, "search result pk %d should have a valid vector", id) + } + + // query with output vector field - verify remaining valid vectors in batch 1 + queryRes, err = mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 50 and int64 < 100").WithOutputFields("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, 50, queryRes.ResultCount, "should have 50 rows in range [50, 100)") + VerifyNullableVectorData(t, vt, queryRes, allPkToVecIdx, allOriginalVectors, "query batch1 remaining 50-99") + + queryMixedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 200 and int64 < 250").WithOutputFields("vector")) + common.CheckErr(t, err, true) + require.EqualValues(t, 50, queryMixedRes.ResultCount, "should have 50 rows in range [200, 250)") + VerifyNullableVectorData(t, vt, queryMixedRes, allPkToVecIdx, allOriginalVectors, "query batch2 200-249") + + queryBatch2CountRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 200").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + batch2Count, err := queryBatch2CountRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, batch2Count, "batch 2 should have all %d rows intact", nb) + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } +} + +func TestNullableVectorAddField(t *testing.T) { + vectorTypes := GetVectorTypes() + + for _, vt := range vectorTypes { + t.Run(vt.Name, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_add", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + origVecField := entity.NewField().WithName("orig_vec").WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim) + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(origVecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + nb := 100 + pkData1 := make([]int64, nb) + for i := range nb { + pkData1[i] = int64(i) + } + pkColumn1 := column.NewColumnInt64(common.DefaultInt64FieldName, pkData1) + + origVecData1 := make([][]float32, nb) + for i := range nb { + vec := make([]float32, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = float32(i*common.DefaultDim+j) / 10000.0 + } + origVecData1[i] = vec + } + origVecColumn1 := column.NewColumnFloatVector("orig_vec", common.DefaultDim, origVecData1) + + insertRes1, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn1, origVecColumn1)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes1.InsertCount) + + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + // wait for rate limiter reset before next flush (rate=0.1 means 1 flush per 10s) + time.Sleep(10 * time.Second) + + // SparseVector does not need dim, but other vectors do + newVecField := entity.NewField().WithName("new_vec").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + newVecField = newVecField.WithDim(common.DefaultDim) + } + err = mc.AddCollectionField(ctx, client.NewAddCollectionFieldOption(collName, newVecField)) + common.CheckErr(t, err, true) + + // verify schema updated + coll, err := mc.DescribeCollection(ctx, client.NewDescribeCollectionOption(collName)) + common.CheckErr(t, err, true) + require.EqualValues(t, 3, len(coll.Schema.Fields), "should have 3 fields after adding new vector field") + + nullPercent := 30 // 30% null + testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "new_vec") + + pkData2 := make([]int64, nb) + for i := range nb { + pkData2[i] = int64(nb + i) // pk starts from nb + } + pkColumn2 := column.NewColumnInt64(common.DefaultInt64FieldName, pkData2) + + origVecData2 := make([][]float32, nb) + for i := range nb { + vec := make([]float32, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = float32((nb+i)*common.DefaultDim+j) / 10000.0 + } + origVecData2[i] = vec + } + origVecColumn2 := column.NewColumnFloatVector("orig_vec", common.DefaultDim, origVecData2) + + insertRes2, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn2, origVecColumn2, testData.VecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes2.InsertCount) + + flushTask2, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask2.Await(ctx) + common.CheckErr(t, err, true) + + // create indexes + origVecIndex := index.NewFlatIndex(entity.L2) + indexTask1, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "orig_vec", origVecIndex)) + common.CheckErr(t, err, true) + err = indexTask1.Await(ctx) + common.CheckErr(t, err, true) + + newVecIndex := CreateNullableVectorIndexWithFieldName(vt, "new_vec") + indexTask2, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "new_vec", newVecIndex)) + common.CheckErr(t, err, true) + err = indexTask2.Await(ctx) + common.CheckErr(t, err, true) + + // load collection + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // verify total count + countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + totalCount, err := countRes.Fields[0].GetAsInt64(0) + common.CheckErr(t, err, true) + require.EqualValues(t, nb*2, totalCount, "total count should be %d", nb*2) + + searchVec := entity.FloatVector(origVecData1[0]) + searchRes1, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}).WithANNSField("orig_vec").WithOutputFields("new_vec")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes1)) + require.EqualValues(t, 10, len(searchRes1[0].IDs.(*column.ColumnInt64).Data()), "search on orig_vec should return 10 results") + + if testData.SearchVec != nil { + searchRes2, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{testData.SearchVec}).WithANNSField("new_vec").WithOutputFields("new_vec")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes2)) + searchIDs2 := searchRes2[0].IDs.(*column.ColumnInt64).Data() + require.EqualValues(t, 10, len(searchIDs2), "search on new_vec should return 10 results") + for _, id := range searchIDs2 { + require.True(t, id >= int64(nb), "search on new_vec should only return batch 2 rows, got pk %d", id) + _, ok := testData.PkToVecIdx[id-int64(nb)] + require.True(t, ok, "search result pk %d should have valid new_vec", id) + } + } + + queryRes1, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 < 100").WithOutputFields("new_vec")) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, queryRes1.ResultCount, "should have %d rows in batch 1", nb) + newVecCol1 := queryRes1.GetColumn("new_vec") + for i := 0; i < queryRes1.ResultCount; i++ { + isNull, _ := newVecCol1.IsNull(i) + require.True(t, isNull, "batch 1 rows should have null new_vec") + } + + queryRes2, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("int64 >= 100").WithOutputFields("new_vec")) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, queryRes2.ResultCount, "should have %d rows in batch 2", nb) + + pkToVecIdx2 := make(map[int64]int) + for pk, idx := range testData.PkToVecIdx { + // original PkToVecIdx uses pk 0..nb-1, need to map to nb..2*nb-1 + pkToVecIdx2[pk+int64(nb)] = idx + } + VerifyNullableVectorDataWithFieldName(t, vt, queryRes2, pkToVecIdx2, testData.OriginalVectors, "new_vec", "query batch 2") + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } +} + +func TestNullableVectorRangeSearch(t *testing.T) { + vectorTypes := GetVectorTypes() + + for _, vt := range vectorTypes { + t.Run(vt.Name, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create collection + collName := common.GenRandomString("nullable_vec_range", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField = vecField.WithDim(common.DefaultDim) + } + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + // generate data with 30% null + nb := 500 + nullPercent := 30 + testData := GenerateNullableVectorTestData(t, vt, nb, nullPercent, "vector") + + // pk column + pkData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + // insert + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, testData.VecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + // flush + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + // create index with appropriate metric type + var vecIndex index.Index + switch vt.FieldType { + case entity.FieldTypeSparseVector: + vecIndex = index.NewSparseInvertedIndex(entity.IP, 0.1) + case entity.FieldTypeBinaryVector: + // BinaryVector uses Hamming distance + vecIndex = index.NewBinFlatIndex(entity.HAMMING) + case entity.FieldTypeInt8Vector: + // Int8Vector uses COSINE metric + vecIndex = index.NewHNSWIndex(entity.COSINE, 8, 96) + default: + // FloatVector, Float16Vector, BFloat16Vector use L2 metric + vecIndex = index.NewHNSWIndex(entity.L2, 8, 96) + } + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + // load + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + if testData.SearchVec != nil { + var searchRes []client.ResultSet + switch vt.FieldType { + case entity.FieldTypeSparseVector: + // For sparse vector, use IP metric with radius and range_filter + // IP metric: higher is better, range is [radius, range_filter] + annParams := index.NewSparseAnnParam() + annParams.WithRadius(0) + annParams.WithRangeFilter(100) + annParams.WithDropRatio(0.2) + searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{testData.SearchVec}). + WithANNSField("vector").WithAnnParam(annParams).WithOutputFields("vector")) + case entity.FieldTypeBinaryVector: + // For binary vector, use Hamming distance + // Hamming distance: smaller is better (number of different bits), range is [range_filter, radius] + // With dim=128, max Hamming distance is 128 + searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{testData.SearchVec}). + WithANNSField("vector").WithSearchParam("radius", "128").WithSearchParam("range_filter", "0").WithOutputFields("vector")) + case entity.FieldTypeInt8Vector: + // For int8 vector, use COSINE metric + // COSINE distance: range is [0, 2], smaller is better, range is [range_filter, radius] + searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{testData.SearchVec}). + WithANNSField("vector").WithSearchParam("radius", "2").WithSearchParam("range_filter", "0").WithOutputFields("vector")) + default: + // For dense vectors (FloatVector, Float16Vector, BFloat16Vector), use L2 metric + // L2 distance: smaller is better, so radius is upper bound, range_filter is lower bound + searchRes, err = mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{testData.SearchVec}). + WithANNSField("vector").WithSearchParam("radius", "100").WithSearchParam("range_filter", "0").WithOutputFields("vector")) + } + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + + // Verify all results have valid vectors (not null) + searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data() + require.Greater(t, len(searchIDs), 0, "range search should return results") + for _, id := range searchIDs { + _, ok := testData.PkToVecIdx[id] + require.True(t, ok, "range search result pk %d should have valid vector", id) + } + + // Verify scores are within range based on metric type + scores := searchRes[0].Scores + for i, score := range scores { + switch vt.FieldType { + case entity.FieldTypeSparseVector: + // IP metric: higher is better, range is [radius, range_filter] = [0, 100] + require.GreaterOrEqual(t, score, float32(0), "sparse vector score should be >= radius(0), got %f for pk %d", score, searchIDs[i]) + require.LessOrEqual(t, score, float32(100), "sparse vector score should be <= range_filter(100), got %f for pk %d", score, searchIDs[i]) + case entity.FieldTypeBinaryVector: + // Hamming distance: range is [range_filter, radius] = [0, 128] + require.GreaterOrEqual(t, score, float32(0), "Hamming score should be >= range_filter(0), got %f for pk %d", score, searchIDs[i]) + require.LessOrEqual(t, score, float32(128), "Hamming score should be <= radius(128), got %f for pk %d", score, searchIDs[i]) + case entity.FieldTypeInt8Vector: + // COSINE distance: range is [range_filter, radius] = [0, 2] + require.GreaterOrEqual(t, score, float32(0), "COSINE score should be >= range_filter(0), got %f for pk %d", score, searchIDs[i]) + require.LessOrEqual(t, score, float32(2), "COSINE score should be <= radius(2), got %f for pk %d", score, searchIDs[i]) + default: + // L2 metric: lower is better, range is [range_filter, radius] = [0, 100] + require.GreaterOrEqual(t, score, float32(0), "L2 score should be >= range_filter(0), got %f for pk %d", score, searchIDs[i]) + require.LessOrEqual(t, score, float32(100), "L2 score should be <= radius(100), got %f for pk %d", score, searchIDs[i]) + } + } + } + + // clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } +} + +// index building on both SegmentGrowingImpl and ChunkedSegmentSealedImpl +func TestNullableVectorDifferentIndexTypes(t *testing.T) { + vectorTypes := GetVectorTypes() + nullPercents := GetNullPercents() + + segmentTypes := []string{"growing", "sealed"} + + for _, vt := range vectorTypes { + indexConfigs := GetIndexesForVectorType(vt.FieldType) + for _, nullPercent := range nullPercents { + for _, segmentType := range segmentTypes { + // For growing segment, only test once with default index (interim index IVF_FLAT_CC is always used) + // For sealed segment, iterate through all user-specified index types + var testIndexConfigs []IndexConfig + if segmentType == "growing" { + // Only use first (default) index config for growing segment + testIndexConfigs = []IndexConfig{indexConfigs[0]} + } else { + // Test all index types for sealed segment + testIndexConfigs = indexConfigs + } + + for _, idxCfg := range testIndexConfigs { + testName := fmt.Sprintf("%s_%s_%d%%null_%s", vt.Name, idxCfg.Name, nullPercent, segmentType) + idxCfgCopy := idxCfg // capture loop variable + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout*10) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // Create collection with nullable vector + collName := common.GenRandomString("nullable_vec_large", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true) + if vt.FieldType != entity.FieldTypeSparseVector { + vecField = vecField.WithDim(common.DefaultDim) + } + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + nb := 10000 + validData := make([]bool, nb) + validCount := 0 + for i := range nb { + validData[i] = (i % 100) >= nullPercent + if validData[i] { + validCount++ + } + } + + pkToVecIdx := make(map[int64]int) + vecIdx := 0 + for i := range nb { + if validData[i] { + pkToVecIdx[int64(i)] = vecIdx + vecIdx++ + } + } + + // Generate pk column + pkData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + + // Generate vector column based on type + var vecColumn column.Column + var searchVec entity.Vector + var originalVectors interface{} + + switch vt.FieldType { + case entity.FieldTypeFloatVector: + vectors := make([][]float32, validCount) + for i := range validCount { + vec := make([]float32, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = float32(i*common.DefaultDim+j) / float32(validCount*common.DefaultDim) + } + vectors[i] = vec + } + vecColumn, err = column.NewNullableColumnFloatVector("vector", common.DefaultDim, vectors, validData) + searchVec = entity.FloatVector(vectors[0]) + originalVectors = vectors + + case entity.FieldTypeBinaryVector: + vectors := make([][]byte, validCount) + byteDim := common.DefaultDim / 8 + for i := range validCount { + vec := make([]byte, byteDim) + for j := range byteDim { + vec[j] = byte((i + j) % 256) + } + vectors[i] = vec + } + vecColumn, err = column.NewNullableColumnBinaryVector("vector", common.DefaultDim, vectors, validData) + searchVec = entity.BinaryVector(vectors[0]) + originalVectors = vectors + + case entity.FieldTypeFloat16Vector: + vectors := make([][]byte, validCount) + for i := range validCount { + vectors[i] = common.GenFloat16Vector(common.DefaultDim) + } + vecColumn, err = column.NewNullableColumnFloat16Vector("vector", common.DefaultDim, vectors, validData) + searchVec = entity.Float16Vector(vectors[0]) + originalVectors = vectors + + case entity.FieldTypeBFloat16Vector: + vectors := make([][]byte, validCount) + for i := range validCount { + vectors[i] = common.GenBFloat16Vector(common.DefaultDim) + } + vecColumn, err = column.NewNullableColumnBFloat16Vector("vector", common.DefaultDim, vectors, validData) + searchVec = entity.BFloat16Vector(vectors[0]) + originalVectors = vectors + + case entity.FieldTypeInt8Vector: + vectors := make([][]int8, validCount) + for i := range validCount { + vec := make([]int8, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = int8((i + j) % 127) + } + vectors[i] = vec + } + vecColumn, err = column.NewNullableColumnInt8Vector("vector", common.DefaultDim, vectors, validData) + searchVec = entity.Int8Vector(vectors[0]) + originalVectors = vectors + + case entity.FieldTypeSparseVector: + vectors := make([]entity.SparseEmbedding, validCount) + for i := range validCount { + positions := []uint32{0, uint32(i%1000 + 1), uint32(i%10000 + 1000)} + values := []float32{1.0, float32(i+1) / 1000.0, 0.1} + vectors[i], err = entity.NewSliceSparseEmbedding(positions, values) + common.CheckErr(t, err, true) + } + vecColumn, err = column.NewNullableColumnSparseFloatVector("vector", vectors, validData) + searchVec = vectors[0] + originalVectors = vectors + } + common.CheckErr(t, err, true) + + // Insert data + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, vecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + // For sealed segment, flush before creating index to convert growing to sealed + if segmentType == "sealed" { + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + } + + // Create index using the config for this test iteration + vecIndex := CreateIndexFromConfig("vector", idxCfgCopy) + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + // Load collection - specify load fields to potentially skip loading vector raw data + // When vector has index and is specified in LoadFields, system may use index instead of field data + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName). + WithLoadFields(common.DefaultInt64FieldName, "vector")) // Load pk and vector (via index) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + // Search + searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}). + WithOutputFields("*"). + WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + require.GreaterOrEqual(t, searchRes[0].ResultCount, 1) + + // Verify search results + VerifyNullableVectorData(t, vt, searchRes[0], pkToVecIdx, originalVectors, "search") + + // Query to count rows + queryRes, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter(fmt.Sprintf("%s >= 0", common.DefaultInt64FieldName)). + WithOutputFields("count(*)")) + common.CheckErr(t, err, true) + countCol := queryRes.GetColumn("count(*)") + count, _ := countCol.GetAsInt64(0) + require.EqualValues(t, nb, count) + + // Query with vector output to verify data + queryVecRes, err := mc.Query(ctx, client.NewQueryOption(collName). + WithFilter(fmt.Sprintf("%s < 100", common.DefaultInt64FieldName)). + WithOutputFields("*")) + common.CheckErr(t, err, true) + + // Verify query results + VerifyNullableVectorData(t, vt, queryVecRes, pkToVecIdx, originalVectors, "query") + + // Clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } + } + } + } +} + +func TestNullableVectorGroupBy(t *testing.T) { + groupByVectorTypes := []NullableVectorType{ + {"FloatVector", entity.FieldTypeFloatVector}, + {"Float16Vector", entity.FieldTypeFloat16Vector}, + {"BFloat16Vector", entity.FieldTypeBFloat16Vector}, + } + nullPercents := GetNullPercents() + + for _, vt := range groupByVectorTypes { + for _, nullPercent := range nullPercents { + testName := fmt.Sprintf("%s_%d%%null", vt.Name, nullPercent) + t.Run(testName, func(t *testing.T) { + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + collName := common.GenRandomString("nullable_vec_groupby", 5) + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName("vector").WithDataType(vt.FieldType).WithNullable(true).WithDim(common.DefaultDim) + groupField := entity.NewField().WithName("group_id").WithDataType(entity.FieldTypeInt64) + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField).WithField(groupField) + + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + nb := 500 + numGroups := 50 + rowsPerGroup := nb / numGroups + + validData := make([]bool, nb) + validCount := 0 + for i := range nb { + validData[i] = (i % 100) >= nullPercent + if validData[i] { + validCount++ + } + } + + pkData := make([]int64, nb) + groupData := make([]int64, nb) + for i := range nb { + pkData[i] = int64(i) + groupData[i] = int64(i / rowsPerGroup) // 0-9 -> group 0, 10-19 -> group 1, etc. + } + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) + groupColumn := column.NewColumnInt64("group_id", groupData) + + var vecColumn column.Column + var searchVec entity.Vector + + switch vt.FieldType { + case entity.FieldTypeFloatVector: + vectors := make([][]float32, validCount) + for i := range validCount { + vec := make([]float32, common.DefaultDim) + for j := range common.DefaultDim { + vec[j] = float32(i*common.DefaultDim+j) / 10000.0 + } + vectors[i] = vec + } + vecColumn, err = column.NewNullableColumnFloatVector("vector", common.DefaultDim, vectors, validData) + if validCount > 0 { + searchVec = entity.FloatVector(vectors[0]) + } + + case entity.FieldTypeFloat16Vector: + vectors := make([][]byte, validCount) + for i := range validCount { + vectors[i] = common.GenFloat16Vector(common.DefaultDim) + } + vecColumn, err = column.NewNullableColumnFloat16Vector("vector", common.DefaultDim, vectors, validData) + if validCount > 0 { + searchVec = entity.Float16Vector(vectors[0]) + } + + case entity.FieldTypeBFloat16Vector: + vectors := make([][]byte, validCount) + for i := range validCount { + vectors[i] = common.GenBFloat16Vector(common.DefaultDim) + } + vecColumn, err = column.NewNullableColumnBFloat16Vector("vector", common.DefaultDim, vectors, validData) + if validCount > 0 { + searchVec = entity.BFloat16Vector(vectors[0]) + } + } + common.CheckErr(t, err, true) + + // Insert + insertRes, err := mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, vecColumn, groupColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, insertRes.InsertCount) + + // Flush + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) + + if validCount > 0 { + vecIndex := index.NewFlatIndex(entity.L2) + indexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "vector", vecIndex)) + common.CheckErr(t, err, true) + err = indexTask.Await(ctx) + common.CheckErr(t, err, true) + + scalarIndexTask, err := mc.CreateIndex(ctx, client.NewCreateIndexOption(collName, "group_id", index.NewAutoIndex(entity.L2))) + common.CheckErr(t, err, true) + err = scalarIndexTask.Await(ctx) + common.CheckErr(t, err, true) + + // Load + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) + + searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 10, []entity.Vector{searchVec}). + WithANNSField("vector"). + WithGroupByField("group_id"). + WithOutputFields(common.DefaultInt64FieldName, "group_id")) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + + // 1. Result count should be <= limit (number of unique groups) + // 2. Each result should have a unique group_id + // 3. All returned PKs should have valid vectors (not null) + resultCount := searchRes[0].ResultCount + require.LessOrEqual(t, resultCount, 10, "result count should be <= limit") + + // Check unique group_ids + seenGroups := make(map[int64]bool) + for i := 0; i < resultCount; i++ { + groupByValue, err := searchRes[0].GroupByValue.Get(i) + require.NoError(t, err) + groupID := groupByValue.(int64) + require.False(t, seenGroups[groupID], "group_id should be unique in GroupBy results") + seenGroups[groupID] = true + + // Verify the returned PK has a valid vector + pkValue, _ := searchRes[0].IDs.GetAsInt64(i) + require.True(t, validData[pkValue], "returned pk %d should have valid vector", pkValue) + } + } + + // Clean up + err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + } + } +} diff --git a/tests/python_client/common/common_func.py b/tests/python_client/common/common_func.py index 0c512f8ca4..27f04bb86f 100644 --- a/tests/python_client/common/common_func.py +++ b/tests/python_client/common/common_func.py @@ -2361,7 +2361,9 @@ def gen_data_by_collection_field(field, nb=None, start=0, random_pk=False): if nullable is False: return gen_vectors(nb, dim, vector_data_type=data_type) else: - raise MilvusException(message=f"gen data failed, vector field does not support nullable") + # gen 20% none data for nullable vector field + vectors = gen_vectors(nb, dim, vector_data_type=data_type) + return [None if i % 2 == 0 and random.random() < 0.4 else vectors[i] for i in range(nb)] elif data_type == DataType.ARRAY: if isinstance(field, dict): max_capacity = field.get('params')['max_capacity'] diff --git a/tests/python_client/milvus_client/test_milvus_client_collection.py b/tests/python_client/milvus_client/test_milvus_client_collection.py index 6a5640ac8c..3b671ccdb1 100644 --- a/tests/python_client/milvus_client/test_milvus_client_collection.py +++ b/tests/python_client/milvus_client/test_milvus_client_collection.py @@ -3628,26 +3628,6 @@ class TestMilvusClientCollectionNullInvalid(TestMilvusClientV2Base): error = {ct.err_code: 1100, ct.err_msg: "primary field not support null"} self.create_collection(client, collection_name, schema=schema, check_task=CheckTasks.err_res, check_items=error) - @pytest.mark.tags(CaseLabel.L1) - @pytest.mark.parametrize("vector_type", ct.all_float_vector_dtypes) - def test_milvus_client_collection_set_nullable_on_vector_field(self, vector_type): - """ - target: test create collection with nullable=True on vector field - method: create collection schema with vector field set as nullable - expected: raise exception - """ - client = self._client() - collection_name = cf.gen_collection_name_by_testcase_name() - # Create schema with nullable vector field - schema = self.create_schema(client, enable_dynamic_field=False)[0] - schema.add_field("id", DataType.INT64, is_primary=True, auto_id=False) - if vector_type == DataType.SPARSE_FLOAT_VECTOR: - schema.add_field("vector", vector_type, nullable=True) - else: - schema.add_field("vector", vector_type, dim=default_dim, nullable=True) - error = {ct.err_code: 1100, ct.err_msg: "vector type not support null"} - self.create_collection(client, collection_name, schema=schema, check_task=CheckTasks.err_res, check_items=error) - @pytest.mark.tags(CaseLabel.L1) def test_milvus_client_collection_set_nullable_on_partition_key_field(self): """ diff --git a/tests/restful_client_v2/testcases/test_collection_operations.py b/tests/restful_client_v2/testcases/test_collection_operations.py index ab63809a32..428f43428a 100644 --- a/tests/restful_client_v2/testcases/test_collection_operations.py +++ b/tests/restful_client_v2/testcases/test_collection_operations.py @@ -793,30 +793,6 @@ class TestCreateCollectionNegative(TestBase): assert rsp['code'] == 1100 assert "partition key field not support nullable" in rsp['message'] - def test_create_collections_with_vector_nullable(self): - """ - vector field not support nullable - """ - name = gen_collection_name() - dim = 128 - client = self.collection_client - payload = { - "collectionName": name, - "schema": { - "fields": [ - {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, - {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, - {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, - {"fieldName": "book_intro", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}, - "nullable": True} - ] - } - } - logging.info(f"create collection {name} with payload: {payload}") - rsp = client.collection_create(payload) - assert rsp['code'] == 1100 - assert "vector type not support null" in rsp['message'] - def test_create_collections_with_primary_default(self): """ primary key field not support defaultValue