feat: Add nullable vector support for proxy and querynode (#46305)

related: #45993 

This commit extends nullable vector support to the proxy layer,
querynode,
and adds comprehensive validation, search reduce, and field data
handling
    for nullable vectors with sparse storage.
    
    Proxy layer changes:
- Update validate_util.go checkAligned() with getExpectedVectorRows()
helper
      to validate nullable vector field alignment using valid data count
- Update checkFloatVectorFieldData/checkSparseFloatVectorFieldData for
      nullable vector validation with proper row count expectations
- Add FieldDataIdxComputer in typeutil/schema.go for logical-to-physical
      index translation during search reduce operations
- Update search_reduce_util.go reduceSearchResultData to use
idxComputers
      for correct field data indexing with nullable vectors
- Update task.go, task_query.go, task_upsert.go for nullable vector
handling
    - Update msg_pack.go with nullable vector field data processing
    
    QueryNode layer changes:
    - Update segments/result.go for nullable vector result handling
- Update segments/search_reduce.go with nullable vector offset
translation
    
    Storage and index changes:
- Update data_codec.go and utils.go for nullable vector serialization
- Update indexcgowrapper/dataset.go and index.go for nullable vector
indexing
    
    Utility changes:
- Add FieldDataIdxComputer struct with Compute() method for efficient
      logical-to-physical index mapping across multiple field data
- Update EstimateEntitySize() and AppendFieldData() with fieldIdxs
parameter
    - Update funcutil.go with nullable vector support functions

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Full support for nullable vector fields (float, binary, float16,
bfloat16, int8, sparse) across ingest, storage, indexing, search and
retrieval; logical↔physical offset mapping preserves row semantics.
  * Client: compaction control and compaction-state APIs.

* **Bug Fixes**
* Improved validation for adding vector fields (nullable + dimension
checks) and corrected search/query behavior for nullable vectors.

* **Chores**
  * Persisted validity maps with indexes and on-disk formats.

* **Tests**
  * Extensive new and updated end-to-end nullable-vector tests.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: marcelo-cjl <marcelo.chen@zilliz.com>
This commit is contained in:
marcelo-cjl 2025-12-24 10:13:19 +08:00 committed by GitHub
parent e4b0f48bc0
commit 3b599441fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
115 changed files with 11431 additions and 1717 deletions

View File

@ -46,6 +46,7 @@ type Column interface {
SetNullable(bool)
ValidateNullable() error
CompactNullableValues()
ValidCount() int
}
var errFieldDataTypeNotMatch = errors.New("FieldData type not matched")
@ -239,10 +240,39 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
}
data := x.FloatVector.GetData()
dim := int(vectors.GetDim())
if len(validData) > 0 {
if end < 0 {
end = len(validData)
}
vector := make([][]float32, 0, end-begin)
dataIdx := 0
for i := 0; i < begin; i++ {
if validData[i] {
dataIdx++
}
}
for i := begin; i < end; i++ {
if validData[i] {
v := make([]float32, dim)
copy(v, data[dataIdx*dim:(dataIdx+1)*dim])
vector = append(vector, v)
dataIdx++
} else {
vector = append(vector, nil)
}
}
col := NewColumnFloatVector(fd.GetFieldName(), dim, vector)
col.withValidData(validData[begin:end])
col.nullable = true
col.sparseMode = true
return col, nil
}
if end < 0 {
end = len(data) / dim
}
vector := make([][]float32, 0, end-begin) // shall not have remanunt
vector := make([][]float32, 0, end-begin)
for i := begin; i < end; i++ {
v := make([]float32, dim)
copy(v, data[i*dim:(i+1)*dim])
@ -262,6 +292,35 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
}
dim := int(vectors.GetDim())
blen := dim / 8
if len(validData) > 0 {
if end < 0 {
end = len(validData)
}
vector := make([][]byte, 0, end-begin)
dataIdx := 0
for i := 0; i < begin; i++ {
if validData[i] {
dataIdx++
}
}
for i := begin; i < end; i++ {
if validData[i] {
v := make([]byte, blen)
copy(v, data[dataIdx*blen:(dataIdx+1)*blen])
vector = append(vector, v)
dataIdx++
} else {
vector = append(vector, nil)
}
}
col := NewColumnBinaryVector(fd.GetFieldName(), dim, vector)
col.withValidData(validData[begin:end])
col.nullable = true
col.sparseMode = true
return col, nil
}
if end < 0 {
end = len(data) / blen
}
@ -281,13 +340,43 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
}
data := x.Float16Vector
dim := int(vectors.GetDim())
bytePerRow := dim * 2
if len(validData) > 0 {
if end < 0 {
end = len(validData)
}
vector := make([][]byte, 0, end-begin)
dataIdx := 0
for i := 0; i < begin; i++ {
if validData[i] {
dataIdx++
}
}
for i := begin; i < end; i++ {
if validData[i] {
v := make([]byte, bytePerRow)
copy(v, data[dataIdx*bytePerRow:(dataIdx+1)*bytePerRow])
vector = append(vector, v)
dataIdx++
} else {
vector = append(vector, nil)
}
}
col := NewColumnFloat16Vector(fd.GetFieldName(), dim, vector)
col.withValidData(validData[begin:end])
col.nullable = true
col.sparseMode = true
return col, nil
}
if end < 0 {
end = len(data) / dim / 2
end = len(data) / bytePerRow
}
vector := make([][]byte, 0, end-begin)
for i := begin; i < end; i++ {
v := make([]byte, dim*2)
copy(v, data[i*dim*2:(i+1)*dim*2])
v := make([]byte, bytePerRow)
copy(v, data[i*bytePerRow:(i+1)*bytePerRow])
vector = append(vector, v)
}
return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
@ -300,13 +389,43 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
}
data := x.Bfloat16Vector
dim := int(vectors.GetDim())
if end < 0 {
end = len(data) / dim / 2
bytePerRow := dim * 2
if len(validData) > 0 {
if end < 0 {
end = len(validData)
}
vector := make([][]byte, 0, end-begin)
dataIdx := 0
for i := 0; i < begin; i++ {
if validData[i] {
dataIdx++
}
}
for i := begin; i < end; i++ {
if validData[i] {
v := make([]byte, bytePerRow)
copy(v, data[dataIdx*bytePerRow:(dataIdx+1)*bytePerRow])
vector = append(vector, v)
dataIdx++
} else {
vector = append(vector, nil)
}
}
col := NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector)
col.withValidData(validData[begin:end])
col.nullable = true
col.sparseMode = true
return col, nil
}
vector := make([][]byte, 0, end-begin) // shall not have remanunt
if end < 0 {
end = len(data) / bytePerRow
}
vector := make([][]byte, 0, end-begin)
for i := begin; i < end; i++ {
v := make([]byte, dim*2)
copy(v, data[i*dim*2:(i+1)*dim*2])
v := make([]byte, bytePerRow)
copy(v, data[i*bytePerRow:(i+1)*bytePerRow])
vector = append(vector, v)
}
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
@ -317,6 +436,37 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
return nil, errFieldDataTypeNotMatch
}
data := sparseVectors.Contents
if len(validData) > 0 {
if end < 0 {
end = len(validData)
}
vectors := make([]entity.SparseEmbedding, 0, end-begin)
dataIdx := 0
for i := 0; i < begin; i++ {
if validData[i] {
dataIdx++
}
}
for i := begin; i < end; i++ {
if validData[i] {
vector, err := entity.DeserializeSliceSparseEmbedding(data[dataIdx])
if err != nil {
return nil, err
}
vectors = append(vectors, vector)
dataIdx++
} else {
vectors = append(vectors, nil)
}
}
col := NewColumnSparseVectors(fd.GetFieldName(), vectors)
col.withValidData(validData[begin:end])
col.nullable = true
col.sparseMode = true
return col, nil
}
if end < 0 {
end = len(data)
}
@ -339,11 +489,41 @@ func FieldDataColumn(fd *schemapb.FieldData, begin, end int) (Column, error) {
}
data := x.Int8Vector
dim := int(vectors.GetDim())
if len(validData) > 0 {
if end < 0 {
end = len(validData)
}
vector := make([][]int8, 0, end-begin)
dataIdx := 0
for i := 0; i < begin; i++ {
if validData[i] {
dataIdx++
}
}
for i := begin; i < end; i++ {
if validData[i] {
v := make([]int8, dim)
for j := 0; j < dim; j++ {
v[j] = int8(data[dataIdx*dim+j])
}
vector = append(vector, v)
dataIdx++
} else {
vector = append(vector, nil)
}
}
col := NewColumnInt8Vector(fd.GetFieldName(), dim, vector)
col.withValidData(validData[begin:end])
col.nullable = true
col.sparseMode = true
return col, nil
}
if end < 0 {
end = len(data) / dim
}
vector := make([][]int8, 0, end-begin) // shall not have remanunt
// TODO caiyd: has better way to convert []byte to []int8 ?
vector := make([][]int8, 0, end-begin)
for i := begin; i < end; i++ {
v := make([]int8, dim)
for j := 0; j < dim; j++ {

View File

@ -301,6 +301,19 @@ func (c *genericColumnBase[T]) CompactNullableValues() {
c.values = c.values[0:cnt]
}
// ValidCount returns how many rows of this column hold a valid (non-null)
// value. Columns that are not nullable, or that carry no validity mask,
// report every stored value as valid.
func (c *genericColumnBase[T]) ValidCount() int {
	if !c.nullable || len(c.validData) == 0 {
		return len(c.values)
	}
	valid := 0
	for _, ok := range c.validData {
		if ok {
			valid++
		}
	}
	return valid
}
func (c *genericColumnBase[T]) withValidData(validData []bool) {
if len(validData) > 0 {
c.nullable = true

View File

@ -16,6 +16,12 @@
package column
import (
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/client/v2/entity"
)
var (
// scalars
NewNullableColumnBool NullableColumnCreateFunc[bool, *ColumnBool] = NewNullableColumnCreator(NewColumnBool).New
@ -41,6 +47,76 @@ var (
NewNullableColumnDoubleArray NullableColumnCreateFunc[[]float64, *ColumnDoubleArray] = NewNullableColumnCreator(NewColumnDoubleArray).New
)
// NewNullableColumnFloatVector builds a nullable float-vector column in
// sparse storage: values holds only the valid rows, while validData marks
// which logical rows are present. Returns an error when the number of
// vectors does not match the number of true flags in validData.
func NewNullableColumnFloatVector(fieldName string, dim int, values [][]float32, validData []bool) (*ColumnFloatVector, error) {
	validCount := getValidCount(validData)
	if len(values) != validCount {
		return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), validCount)
	}
	col := NewColumnFloatVector(fieldName, dim, values)
	col.withValidData(validData)
	col.nullable = true
	return col, nil
}
// NewNullableColumnBinaryVector builds a nullable binary-vector column in
// sparse storage: values holds only the valid rows, validData marks which
// logical rows are present. Errors when len(values) disagrees with the
// count of true flags in validData.
func NewNullableColumnBinaryVector(fieldName string, dim int, values [][]byte, validData []bool) (*ColumnBinaryVector, error) {
	validCount := getValidCount(validData)
	if len(values) != validCount {
		return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), validCount)
	}
	col := NewColumnBinaryVector(fieldName, dim, values)
	col.withValidData(validData)
	col.nullable = true
	return col, nil
}
// NewNullableColumnFloat16Vector builds a nullable float16-vector column in
// sparse storage: values holds only the valid rows, validData marks which
// logical rows are present. Errors when len(values) disagrees with the
// count of true flags in validData.
func NewNullableColumnFloat16Vector(fieldName string, dim int, values [][]byte, validData []bool) (*ColumnFloat16Vector, error) {
	validCount := getValidCount(validData)
	if len(values) != validCount {
		return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), validCount)
	}
	col := NewColumnFloat16Vector(fieldName, dim, values)
	col.withValidData(validData)
	col.nullable = true
	return col, nil
}
// NewNullableColumnBFloat16Vector builds a nullable bfloat16-vector column in
// sparse storage: values holds only the valid rows, validData marks which
// logical rows are present. Errors when len(values) disagrees with the
// count of true flags in validData.
func NewNullableColumnBFloat16Vector(fieldName string, dim int, values [][]byte, validData []bool) (*ColumnBFloat16Vector, error) {
	validCount := getValidCount(validData)
	if len(values) != validCount {
		return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), validCount)
	}
	col := NewColumnBFloat16Vector(fieldName, dim, values)
	col.withValidData(validData)
	col.nullable = true
	return col, nil
}
// NewNullableColumnInt8Vector builds a nullable int8-vector column in
// sparse storage: values holds only the valid rows, validData marks which
// logical rows are present. Errors when len(values) disagrees with the
// count of true flags in validData.
func NewNullableColumnInt8Vector(fieldName string, dim int, values [][]int8, validData []bool) (*ColumnInt8Vector, error) {
	validCount := getValidCount(validData)
	if len(values) != validCount {
		return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), validCount)
	}
	col := NewColumnInt8Vector(fieldName, dim, values)
	col.withValidData(validData)
	col.nullable = true
	return col, nil
}
// NewNullableColumnSparseFloatVector builds a nullable sparse-float-vector
// column: values holds only the valid rows, validData marks which logical
// rows are present. Errors when len(values) disagrees with the count of
// true flags in validData.
func NewNullableColumnSparseFloatVector(fieldName string, values []entity.SparseEmbedding, validData []bool) (*ColumnSparseFloatVector, error) {
	validCount := getValidCount(validData)
	if len(values) != validCount {
		return nil, errors.Newf("values length (%d) must equal valid count (%d) in validData", len(values), validCount)
	}
	col := NewColumnSparseVectors(fieldName, values)
	col.withValidData(validData)
	col.nullable = true
	return col, nil
}
// getValidCount counts the true entries in validData, i.e. the number of
// logical rows that carry an actual value.
func getValidCount(validData []bool) int {
	n := 0
	for _, ok := range validData {
		if ok {
			n++
		}
	}
	return n
}
type NullableColumnCreateFunc[T any, Col interface {
Column
Data() []T

View File

@ -38,11 +38,15 @@ func NewColumnSparseVectors(name string, values []entity.SparseEmbedding) *Colum
func (c *ColumnSparseFloatVector) FieldData() *schemapb.FieldData {
fd := c.vectorBase.FieldData()
max := lo.MaxBy(c.values, func(a, b entity.SparseEmbedding) bool {
return a.Dim() > b.Dim()
})
vectors := fd.GetVectors()
vectors.Dim = int64(max.Dim())
if len(c.values) > 0 {
max := lo.MaxBy(c.values, func(a, b entity.SparseEmbedding) bool {
return a.Dim() > b.Dim()
})
vectors.Dim = int64(max.Dim())
} else {
vectors.Dim = 0
}
return fd
}

View File

@ -136,3 +136,7 @@ func (c *columnStructArray) CompactNullableValues() {
field.CompactNullableValues()
}
}
// ValidCount returns the number of rows with valid data. Struct-array
// columns treat every row as valid, so this is simply the row count.
func (c *columnStructArray) ValidCount() int {
return c.Len()
}

View File

@ -206,6 +206,17 @@ const (
FieldTypeStruct FieldType = 201
)
// IsVectorType reports whether t denotes one of the supported vector field
// types (binary, float, float16, bfloat16, sparse, or int8 vectors).
func (t FieldType) IsVectorType() bool {
	switch t {
	case FieldTypeBinaryVector,
		FieldTypeFloatVector,
		FieldTypeFloat16Vector,
		FieldTypeBFloat16Vector,
		FieldTypeSparseVector,
		FieldTypeInt8Vector:
		return true
	}
	return false
}
// Field represent field schema in milvus
type Field struct {
ID int64 // field id, generated when collection is created, input value is ignored

View File

@ -185,6 +185,10 @@ func (c *Client) GetCollectionStats(ctx context.Context, opt GetCollectionOption
// AddCollectionField adds a field to a collection.
func (c *Client) AddCollectionField(ctx context.Context, opt AddCollectionFieldOption, callOpts ...grpc.CallOption) error {
if err := opt.Validate(); err != nil {
return err
}
req := opt.Request()
err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error {

View File

@ -405,6 +405,8 @@ func NewGetCollectionStatsOption(collectionName string) *getCollectionStatsOptio
type AddCollectionFieldOption interface {
Request() *milvuspb.AddCollectionFieldRequest
// Validate validates the option before sending request
Validate() error
}
type addCollectionFieldOption struct {
@ -420,6 +422,15 @@ func (c *addCollectionFieldOption) Request() *milvuspb.AddCollectionFieldRequest
}
}
// Validate checks the option before the request is sent. Vector fields may
// only be added to an existing collection when they are nullable, since
// rows already present cannot be backfilled with vector data.
func (c *addCollectionFieldOption) Validate() error {
	if !c.fieldSch.DataType.IsVectorType() || c.fieldSch.Nullable {
		return nil
	}
	return fmt.Errorf("adding vector field to existing collection requires nullable=true, field name = %s", c.fieldSch.Name)
}
func NewAddCollectionFieldOption(collectionName string, field *entity.Field) *addCollectionFieldOption {
return &addCollectionFieldOption{
collectionName: collectionName,

View File

@ -441,6 +441,37 @@ func (s *CollectionSuite) TestAddCollectionField() {
err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field))
s.Error(err)
})
s.Run("vector_field_without_nullable", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
fieldName := fmt.Sprintf("field_%s", s.randString(6))
// no mock expected because validation should fail before RPC call
field := entity.NewField().WithName(fieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(128)
err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field))
s.Error(err)
s.Contains(err.Error(), "adding vector field to existing collection requires nullable=true")
})
s.Run("vector_field_with_nullable", func() {
collName := fmt.Sprintf("coll_%s", s.randString(6))
fieldName := fmt.Sprintf("field_%s", s.randString(6))
s.mock.EXPECT().AddCollectionField(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, acfr *milvuspb.AddCollectionFieldRequest) (*commonpb.Status, error) {
fieldProto := &schemapb.FieldSchema{}
err := proto.Unmarshal(acfr.GetSchema(), fieldProto)
s.Require().NoError(err)
s.Equal(fieldName, fieldProto.GetName())
s.Equal(schemapb.DataType_FloatVector, fieldProto.GetDataType())
s.True(fieldProto.GetNullable())
return merr.Success(), nil
}).Once()
field := entity.NewField().WithName(fieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(128).WithNullable(true)
err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field))
s.NoError(err)
})
}
func TestCollection(t *testing.T) {

View File

@ -126,6 +126,11 @@ class Chunk {
return data_;
}
// Valid returns a mutable reference to the chunk's per-row validity flags.
// NOTE(review): presumably only populated for nullable chunks — confirm
// against the constructor, which is outside this view.
FixedVector<bool>&
Valid() {
return valid_;
}
virtual bool
isValid(int offset) const {
if (nullable_) {
@ -559,17 +564,32 @@ class SparseFloatVectorChunk : public Chunk {
bool nullable,
std::shared_ptr<ChunkMmapGuard> chunk_mmap_guard)
: Chunk(row_nums, data, size, nullable, chunk_mmap_guard) {
vec_.resize(row_nums);
auto null_bitmap_bytes_num = nullable ? (row_nums + 7) / 8 : 0;
auto offsets_ptr =
reinterpret_cast<uint64_t*>(data + null_bitmap_bytes_num);
for (int i = 0; i < row_nums; i++) {
vec_[i] = {(offsets_ptr[i + 1] - offsets_ptr[i]) /
knowhere::sparse::SparseRow<
SparseValueType>::element_size(),
reinterpret_cast<uint8_t*>(data + offsets_ptr[i]),
false};
dim_ = std::max(dim_, vec_[i].dim());
if (nullable_) {
for (int i = 0; i < row_nums; i++) {
if (isValid(i)) {
vec_.emplace_back(
(offsets_ptr[i + 1] - offsets_ptr[i]) /
knowhere::sparse::SparseRow<
SparseValueType>::element_size(),
reinterpret_cast<uint8_t*>(data + offsets_ptr[i]),
false);
dim_ = std::max(dim_, vec_.back().dim());
}
}
} else {
vec_.resize(row_nums);
for (int i = 0; i < row_nums; i++) {
vec_[i] = {(offsets_ptr[i + 1] - offsets_ptr[i]) /
knowhere::sparse::SparseRow<
SparseValueType>::element_size(),
reinterpret_cast<uint8_t*>(data + offsets_ptr[i]),
false};
dim_ = std::max(dim_, vec_[i].dim());
}
}
}

View File

@ -435,8 +435,10 @@ SparseFloatVectorChunkWriter::calculate_size(
for (const auto& data : array_vec) {
auto array = std::dynamic_pointer_cast<arrow::BinaryArray>(data);
for (int64_t i = 0; i < array->length(); ++i) {
auto str = array->GetView(i);
size += str.size();
if (!nullable_ || !array->IsNull(i)) {
auto str = array->GetView(i);
size += str.size();
}
}
row_nums_ += array->length();
}
@ -459,8 +461,10 @@ SparseFloatVectorChunkWriter::write_to_target(
for (const auto& data : array_vec) {
auto array = std::dynamic_pointer_cast<arrow::BinaryArray>(data);
for (int64_t i = 0; i < array->length(); ++i) {
auto str = array->GetView(i);
strs.emplace_back(str);
if (!nullable_ || !array->IsNull(i)) {
auto str = array->GetView(i);
strs.emplace_back(str);
}
}
if (nullable_) {
null_bitmaps.emplace_back(
@ -478,9 +482,23 @@ SparseFloatVectorChunkWriter::write_to_target(
std::vector<uint64_t> offsets;
offsets.reserve(offset_num);
for (const auto& str : strs) {
offsets.push_back(offset_start_pos);
offset_start_pos += str.size();
if (nullable_) {
size_t str_idx = 0;
for (const auto& data : array_vec) {
auto array = std::dynamic_pointer_cast<arrow::BinaryArray>(data);
for (int i = 0; i < array->length(); i++) {
offsets.push_back(offset_start_pos);
if (!array->IsNull(i)) {
offset_start_pos += strs[str_idx].size();
str_idx++;
}
}
}
} else {
for (const auto& str : strs) {
offsets.push_back(offset_start_pos);
offset_start_pos += str.size();
}
}
offsets.push_back(offset_start_pos);
@ -524,22 +542,43 @@ create_chunk_writer(const FieldMeta& field_meta) {
return std::make_shared<ChunkWriter<arrow::Int64Array, int64_t>>(
dim, nullable);
case milvus::DataType::VECTOR_FLOAT:
if (nullable) {
return std::make_shared<
NullableVectorChunkWriter<knowhere::fp32>>(dim, nullable);
}
return std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::fp32>>(
dim, nullable);
case milvus::DataType::VECTOR_BINARY:
if (nullable) {
return std::make_shared<
NullableVectorChunkWriter<knowhere::bin1>>(dim / 8,
nullable);
}
return std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::bin1>>(
dim / 8, nullable);
case milvus::DataType::VECTOR_FLOAT16:
if (nullable) {
return std::make_shared<
NullableVectorChunkWriter<knowhere::fp16>>(dim, nullable);
}
return std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::fp16>>(
dim, nullable);
case milvus::DataType::VECTOR_BFLOAT16:
if (nullable) {
return std::make_shared<
NullableVectorChunkWriter<knowhere::bf16>>(dim, nullable);
}
return std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::bf16>>(
dim, nullable);
case milvus::DataType::VECTOR_INT8:
if (nullable) {
return std::make_shared<
NullableVectorChunkWriter<knowhere::int8>>(dim, nullable);
}
return std::make_shared<
ChunkWriter<arrow::FixedSizeBinaryArray, knowhere::int8>>(
dim, nullable);

View File

@ -129,6 +129,57 @@ class ChunkWriter final : public ChunkWriterBase {
const int64_t dim_;
};
// NullableVectorChunkWriter serializes a nullable fixed-width vector column
// (element type T) into a chunk: a null bitmap followed by the vector bytes
// of valid rows only, packed densely — null rows consume no vector storage.
// Input arrays must be arrow::BinaryArray (one vector per entry); the writer
// requires nullable=true at construction.
template <typename T>
class NullableVectorChunkWriter final : public ChunkWriterBase {
public:
NullableVectorChunkWriter(int64_t dim, bool nullable)
: ChunkWriterBase(nullable), dim_(dim) {
Assert(nullable && "NullableVectorChunkWriter requires nullable=true");
}
// Returns {total byte size, total row count} for the given arrays.
// Only valid (non-null) rows contribute vector bytes; the null bitmap is
// sized from the full logical row count.
std::pair<size_t, size_t>
calculate_size(const arrow::ArrayVector& array_vec) override {
size_t size = 0;
size_t row_nums = 0;
for (const auto& data : array_vec) {
row_nums += data->length();
auto binary_array =
std::static_pointer_cast<arrow::BinaryArray>(data);
int64_t valid_count = data->length() - binary_array->null_count();
size += valid_count * dim_ * sizeof(T);
}
// null bitmap size
size += (row_nums + 7) / 8;
row_nums_ = row_nums;
return {size, row_nums};
}
// Writes the null bitmaps first, then the raw value bytes of each array.
// Relies on Arrow's BinaryArray layout keeping valid values contiguous in
// the value buffer starting at value_offset(0).
void
write_to_target(const arrow::ArrayVector& array_vec,
const std::shared_ptr<ChunkTarget>& target) override {
std::vector<std::tuple<const uint8_t*, int64_t, int64_t>> null_bitmaps;
for (const auto& data : array_vec) {
null_bitmaps.emplace_back(
data->null_bitmap_data(), data->length(), data->offset());
}
write_null_bit_maps(null_bitmaps, target);
for (const auto& data : array_vec) {
auto binary_array =
std::static_pointer_cast<arrow::BinaryArray>(data);
auto data_offset = binary_array->value_offset(0);
auto data_ptr = binary_array->value_data()->data() + data_offset;
int64_t valid_count = data->length() - binary_array->null_count();
target->write(data_ptr, valid_count * dim_ * sizeof(T));
}
}
private:
// Element count per row (already divided by 8 for binary vectors by the
// caller in create_chunk_writer).
const int64_t dim_;
};
template <>
inline void
ChunkWriter<arrow::BooleanArray, bool>::write_to_target(

View File

@ -20,6 +20,7 @@
#include "arrow/array/array_binary.h"
#include "arrow/chunked_array.h"
#include "bitset/detail/element_wise.h"
#include "bitset/detail/popcount.h"
#include "common/Array.h"
#include "common/EasyAssert.h"
#include "common/Exception.h"
@ -310,6 +311,18 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
case DataType::VECTOR_BFLOAT16:
case DataType::VECTOR_INT8:
case DataType::VECTOR_BINARY: {
if (nullable_) {
auto binary_array =
std::dynamic_pointer_cast<arrow::BinaryArray>(array);
AssertInfo(binary_array != nullptr,
"nullable vector must use BinaryArray");
auto data_offset = binary_array->value_offset(0);
return FillFieldData(
binary_array->value_data()->data() + data_offset,
binary_array->null_bitmap_data(),
binary_array->length(),
binary_array->offset());
}
auto array_info =
GetDataInfoFromArray<arrow::FixedSizeBinaryArray,
arrow::Type::type::FIXED_SIZE_BINARY>(
@ -321,6 +334,20 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
"inconsistent data type");
auto arr = std::dynamic_pointer_cast<arrow::BinaryArray>(array);
std::vector<knowhere::sparse::SparseRow<SparseValueType>> values;
if (nullable_) {
for (int64_t i = 0; i < element_count; ++i) {
if (arr->IsValid(i)) {
auto view = arr->GetString(i);
values.push_back(
CopyAndWrapSparseRow(view.data(), view.size()));
}
}
return FillFieldData(values.data(),
arr->null_bitmap_data(),
arr->length(),
arr->offset());
}
for (size_t index = 0; index < element_count; ++index) {
auto view = arr->GetString(index);
values.push_back(
@ -572,6 +599,96 @@ template class FieldDataImpl<knowhere::sparse::SparseRow<SparseValueType>,
true>;
template class FieldDataImpl<VectorArray, true>;
// Appends total_element_count logical rows of nullable vector data.
// field_data holds only the valid rows packed densely; valid_data is an
// Arrow-style validity bitmap (LSB-first within each byte), with `offset`
// giving the starting bit. Updates the column's validity bitmap, the
// logical->physical offset mapping, the dense data buffer, and the
// length/valid/null counters.
template <typename Type, bool is_type_entire_row>
void
FieldDataVectorImpl<Type, is_type_entire_row>::FillFieldData(
const void* field_data,
const uint8_t* valid_data,
ssize_t total_element_count,
ssize_t offset) {
AssertInfo(this->nullable_, "requires nullable to be true");
if (total_element_count == 0) {
return;
}
// Count set bits in valid_data over [offset, offset + count): scan the
// unaligned head and tail bit-by-bit and use popcount on full bytes.
int64_t valid_count = 0;
if (valid_data) {
int64_t bit_start = offset;
int64_t bit_end = offset + total_element_count;
// Handle head: unaligned bits before first full byte
int64_t first_full_byte = (bit_start + 7) / 8;
int64_t last_full_byte = bit_end / 8;
// Process unaligned head bits
for (int64_t bit_idx = bit_start;
bit_idx < std::min(first_full_byte * 8, bit_end);
++bit_idx) {
if ((valid_data[bit_idx >> 3] >> (bit_idx & 7)) & 1) {
valid_count++;
}
}
// Process aligned full bytes with popcount
for (int64_t byte_idx = first_full_byte; byte_idx < last_full_byte;
++byte_idx) {
valid_count += bitset::detail::PopCountHelper<uint8_t>::count(
valid_data[byte_idx]);
}
// Process unaligned tail bits
// (the max() guards the single-byte case where the head loop has
// already consumed every bit up to bit_end)
for (int64_t bit_idx =
std::max(last_full_byte * 8, first_full_byte * 8);
bit_idx < bit_end;
++bit_idx) {
if ((valid_data[bit_idx >> 3] >> (bit_idx & 7)) & 1) {
valid_count++;
}
}
} else {
// No bitmap supplied: every incoming row is valid.
valid_count = total_element_count;
}
std::lock_guard lck(this->tell_mutex_);
// Grow validity bitmap and dense storage for the incoming rows.
resize_field_data(this->length_ + total_element_count,
this->valid_count_ + valid_count);
if (valid_data) {
// Copy the incoming validity bits to bit position length_ of our bitmap.
bitset::detail::ElementWiseBitsetPolicy<uint8_t>::op_copy(
valid_data,
offset,
this->valid_data_.data(),
this->length_,
total_element_count);
}
// update logical to physical offset mapping
l2p_mapping_.build(this->valid_data_.data(),
this->valid_count_,
this->length_,
total_element_count,
valid_count);
if (valid_count > 0) {
// Append the densely-packed valid vectors after the existing ones.
std::copy_n(static_cast<const Type*>(field_data),
valid_count * this->dim_,
this->data_.data() + this->valid_count_ * this->dim_);
this->valid_count_ += valid_count;
}
this->null_count_ = total_element_count - valid_count;
this->length_ += total_element_count;
}
// explicit instantiations for FieldDataVectorImpl
template class FieldDataVectorImpl<uint8_t, false>;
template class FieldDataVectorImpl<int8_t, false>;
template class FieldDataVectorImpl<float, false>;
template class FieldDataVectorImpl<float16, false>;
template class FieldDataVectorImpl<bfloat16, false>;
template class FieldDataVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
true>;
FieldDataPtr
InitScalarFieldData(const DataType& type, bool nullable, int64_t cap_rows) {
switch (type) {

View File

@ -19,6 +19,8 @@
#include <string>
#include <memory>
#include <utility>
#include <unordered_map>
#include <vector>
#include <oneapi/tbb/concurrent_queue.h>
@ -133,24 +135,251 @@ class FieldData<VectorArray> : public FieldDataVectorArrayImpl {
DataType element_type_;
};
// FieldDataVectorImpl extends FieldDataImpl with nullable vector storage:
// null rows occupy no space in data_; a logical->physical mapping translates
// row offsets into indices of the dense value buffer.
template <typename Type, bool is_type_entire_row = false>
class FieldDataVectorImpl : public FieldDataImpl<Type, is_type_entire_row> {
private:
// Maps a logical row offset to its physical index in the dense data_
// buffer. Backed either by a hash map (sparse validity) or a flat vector
// (dense validity); -1 marks a null row with no physical slot.
struct LogicalToPhysicalMapping {
// False until build() has run; then lookups go through the map/vector.
bool mapping{false};
std::unordered_map<int64_t, int64_t> l2p_map;
std::vector<int64_t> l2p_vec;
// Returns the physical index for logical_offset, -1 when the row is
// null (or unknown). Identity mapping before build() is first called.
int64_t
get_physical_offset(int64_t logical_offset) const {
if (!mapping) {
return logical_offset;
}
if (!l2p_map.empty()) {
auto it = l2p_map.find(logical_offset);
if (it != l2p_map.end()) {
return it->second;
}
return -1;
}
if (logical_offset < static_cast<int64_t>(l2p_vec.size())) {
return l2p_vec[logical_offset];
}
return -1;
}
// Extends the mapping for total_count logical rows starting at
// start_logical, assigning consecutive physical indices (from
// start_physical) to rows whose bit is set in valid_data. A null
// valid_data means all rows are valid.
void
build(const uint8_t* valid_data,
int64_t start_physical,
int64_t start_logical,
int64_t total_count,
int64_t valid_count) {
if (total_count == 0) {
return;
}
mapping = true;
// use map when valid ratio < 10%
bool use_map = (valid_count * 10 < total_count);
if (use_map) {
int64_t physical_idx = start_physical;
for (int64_t i = 0; i < total_count; ++i) {
int64_t bit_pos = start_logical + i;
if (valid_data == nullptr ||
((valid_data[bit_pos >> 3] >> (bit_pos & 0x07)) & 1)) {
l2p_map[start_logical + i] = physical_idx++;
}
}
} else {
// resize l2p_vec if needed
int64_t required_size = start_logical + total_count;
if (static_cast<int64_t>(l2p_vec.size()) < required_size) {
l2p_vec.resize(required_size, -1);
}
int64_t physical_idx = start_physical;
for (int64_t i = 0; i < total_count; ++i) {
int64_t bit_pos = start_logical + i;
if (valid_data == nullptr ||
((valid_data[bit_pos >> 3] >> (bit_pos & 0x07)) & 1)) {
l2p_vec[start_logical + i] = physical_idx++;
} else {
l2p_vec[start_logical + i] = -1;
}
}
}
}
};
// Grows the validity bitmap to cover num_rows logical rows and the dense
// buffer to hold valid_count vectors. Nullable-only overload; the base
// class's single-argument resize_field_data is re-exported below.
void
resize_field_data(int64_t num_rows, int64_t valid_count) {
Assert(this->nullable_);
std::lock_guard lck(this->num_rows_mutex_);
if (num_rows > this->num_rows_) {
this->num_rows_ = num_rows;
this->valid_data_.resize((num_rows + 7) / 8, 0x00);
}
if (valid_count > this->valid_count_) {
this->data_.resize(valid_count * this->dim_);
}
}
LogicalToPhysicalMapping l2p_mapping_;
public:
using FieldDataImpl<Type, is_type_entire_row>::FieldDataImpl;
using FieldDataImpl<Type, is_type_entire_row>::resize_field_data;
// Nullable append; defined out-of-line in the .cpp (valid_data is an
// Arrow-style bitmap, field_data holds only the valid rows densely).
void
FillFieldData(const void* field_data,
const uint8_t* valid_data,
ssize_t element_count,
ssize_t offset) override;
// Returns a pointer to the vector at the logical row `offset`, or
// nullptr when that row is null.
const void*
RawValue(ssize_t offset) const override {
auto physical_offset = l2p_mapping_.get_physical_offset(offset);
if (physical_offset == -1) {
return nullptr;
}
return &this->data_[physical_offset * this->dim_];
}
// Total bytes of stored vector data (valid rows only when nullable).
int64_t
DataSize() const override {
auto dim = this->dim_;
if (this->nullable_) {
return sizeof(Type) * this->valid_count_ * dim;
}
return sizeof(Type) * this->length_ * dim;
}
// Byte size of a single row's vector (fixed-width per row).
int64_t
DataSize(ssize_t offset) const override {
auto dim = this->dim_;
AssertInfo(offset < this->get_num_rows(),
"field data subscript out of range");
return sizeof(Type) * dim;
}
// Number of non-null rows; equals the row count for non-nullable data.
int64_t
get_valid_rows() const override {
if (this->nullable_) {
return this->valid_count_;
}
return this->get_num_rows();
}
};
// FieldDataSparseVectorImpl stores sparse float vectors (one SparseRow per
// logical slot, dim fixed at 1 so indices map one row per element) on top of
// the nullable vector machinery. vec_dim_ tracks the maximum sparse
// dimension seen so far.
class FieldDataSparseVectorImpl
: public FieldDataVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
true> {
using Base =
FieldDataVectorImpl<knowhere::sparse::SparseRow<SparseValueType>, true>;
public:
// Bring base class FillFieldData overloads into scope (for nullable support)
using Base::FillFieldData;
explicit FieldDataSparseVectorImpl(DataType data_type,
bool nullable = false,
int64_t total_num_rows = 0)
: FieldDataVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
true>(
/*dim=*/1, data_type, nullable, total_num_rows),
vec_dim_(0) {
AssertInfo(data_type == DataType::VECTOR_SPARSE_U32_F32,
"invalid data type for sparse vector");
}
// Total payload bytes across stored rows; rows are variable-sized, so
// this sums each row's byte size (valid rows only when nullable).
int64_t
DataSize() const override {
int64_t data_size = 0;
size_t count = nullable_ ? valid_count_ : length_;
for (size_t i = 0; i < count; ++i) {
data_size += data_[i].data_byte_size();
}
return data_size;
}
// Byte size of the row stored at the given physical position.
// NOTE(review): `offset < count` compares ssize_t against size_t —
// confirm offset is never negative at the call sites.
int64_t
DataSize(ssize_t offset) const override {
AssertInfo(offset < get_num_rows(),
"field data subscript out of range");
size_t count = nullable_ ? valid_count_ : length_;
AssertInfo(
offset < count,
"subscript position don't has valid value offset={}, count={}",
offset,
count);
return data_[offset].data_byte_size();
}
// Appends element_count pre-built SparseRows (non-null ingest path;
// nullable ingest goes through the inherited overloads re-exported above).
void
FillFieldData(const void* source, ssize_t element_count) override {
if (element_count == 0) {
return;
}
std::lock_guard lck(tell_mutex_);
if (length_ + element_count > get_num_rows()) {
FieldDataImpl::resize_field_data(length_ + element_count);
}
auto ptr =
static_cast<const knowhere::sparse::SparseRow<SparseValueType>*>(
source);
// Track the maximum sparse dimension before taking ownership.
for (int64_t i = 0; i < element_count; ++i) {
auto& row = ptr[i];
vec_dim_ = std::max(vec_dim_, row.dim());
}
std::copy_n(ptr, element_count, data_.data() + length_);
length_ += element_count;
}
// Appends rows deserialized from an Arrow BinaryArray, copying each
// serialized row into an owned SparseRow.
void
FillFieldData(const std::shared_ptr<arrow::BinaryArray>& array) override {
auto n = array->length();
if (n == 0) {
return;
}
std::lock_guard lck(tell_mutex_);
if (length_ + n > get_num_rows()) {
FieldDataImpl::resize_field_data(length_ + n);
}
for (int64_t i = 0; i < array->length(); ++i) {
auto view = array->GetView(i);
auto& row = data_[length_ + i];
row = CopyAndWrapSparseRow(view.data(), view.size());
vec_dim_ = std::max(vec_dim_, row.dim());
}
length_ += n;
}
// Maximum sparse dimension observed across all ingested rows.
int64_t
Dim() const {
return vec_dim_;
}
private:
int64_t vec_dim_ = 0;
};
// Float vector field data; `nullable` is forwarded to the base so null rows
// are tracked via the validity bitmap.
template <>
class FieldData<FloatVector> : public FieldDataVectorImpl<float, false> {
 public:
    explicit FieldData(int64_t dim,
                       DataType data_type,
                       bool nullable,
                       int64_t buffered_num_rows = 0)
        : FieldDataVectorImpl<float, false>::FieldDataVectorImpl(
              dim, data_type, nullable, buffered_num_rows) {
    }
};
template <>
class FieldData<BinaryVector> : public FieldDataImpl<uint8_t, false> {
class FieldData<BinaryVector> : public FieldDataVectorImpl<uint8_t, false> {
public:
explicit FieldData(int64_t dim,
DataType data_type,
bool nullable,
int64_t buffered_num_rows = 0)
: FieldDataImpl(dim / 8, data_type, false, buffered_num_rows),
: FieldDataVectorImpl(dim / 8, data_type, nullable, buffered_num_rows),
binary_dim_(dim) {
Assert(dim % 8 == 0);
}
@ -165,43 +394,48 @@ class FieldData<BinaryVector> : public FieldDataImpl<uint8_t, false> {
};
// Float16 vector field data; `nullable` is forwarded to the base so null rows
// are tracked via the validity bitmap.
template <>
class FieldData<Float16Vector> : public FieldDataVectorImpl<float16, false> {
 public:
    explicit FieldData(int64_t dim,
                       DataType data_type,
                       bool nullable,
                       int64_t buffered_num_rows = 0)
        : FieldDataVectorImpl<float16, false>::FieldDataVectorImpl(
              dim, data_type, nullable, buffered_num_rows) {
    }
};
// BFloat16 vector field data; `nullable` is forwarded to the base so null rows
// are tracked via the validity bitmap.
template <>
class FieldData<BFloat16Vector> : public FieldDataVectorImpl<bfloat16, false> {
 public:
    explicit FieldData(int64_t dim,
                       DataType data_type,
                       bool nullable,
                       int64_t buffered_num_rows = 0)
        : FieldDataVectorImpl<bfloat16, false>::FieldDataVectorImpl(
              dim, data_type, nullable, buffered_num_rows) {
    }
};
// Sparse float vector field data; dimension is discovered from the rows, so
// only the data type, nullability and optional row-buffer hint are needed.
template <>
class FieldData<SparseFloatVector> : public FieldDataSparseVectorImpl {
 public:
    explicit FieldData(DataType data_type,
                       bool nullable = false,
                       int64_t buffered_num_rows = 0)
        : FieldDataSparseVectorImpl(data_type, nullable, buffered_num_rows) {
    }
};
// Int8 vector field data; `nullable` is forwarded to the base so null rows
// are tracked via the validity bitmap.
template <>
class FieldData<Int8Vector> : public FieldDataVectorImpl<int8, false> {
 public:
    explicit FieldData(int64_t dim,
                       DataType data_type,
                       bool nullable,
                       int64_t buffered_num_rows = 0)
        : FieldDataVectorImpl<int8, false>::FieldDataVectorImpl(
              dim, data_type, nullable, buffered_num_rows) {
    }
};

View File

@ -127,6 +127,9 @@ class FieldDataBase {
virtual bool
is_valid(ssize_t offset) const = 0;
virtual int64_t
get_valid_rows() const = 0;
protected:
const DataType data_type_;
const bool nullable_;
@ -309,6 +312,12 @@ class FieldBitsetImpl : public FieldDataBase {
"is_valid(ssize_t offset) not implemented for bitset");
}
int64_t
get_valid_rows() const override {
ThrowInfo(NotImplemented,
"get_valid_rows() not implemented for bitset");
}
private:
FixedVector<Type> data_{};
// capacity that data_ can store
@ -340,9 +349,6 @@ class FieldDataImpl : public FieldDataBase {
dim_(is_type_entire_row ? 1 : dim) {
data_.resize(num_rows_ * dim_);
if (nullable) {
if (IsVectorDataType(data_type)) {
ThrowInfo(NotImplemented, "vector type not support null");
}
valid_data_.resize((num_rows_ + 7) / 8, 0xFF);
}
}
@ -492,6 +498,12 @@ class FieldDataImpl : public FieldDataBase {
return num_rows_;
}
int64_t
get_valid_rows() const override {
std::shared_lock lck(tell_mutex_);
return static_cast<int64_t>(length_) - null_count_;
}
void
resize_field_data(int64_t num_rows) {
std::lock_guard lck(num_rows_mutex_);
@ -540,13 +552,12 @@ class FieldDataImpl : public FieldDataBase {
FixedVector<uint8_t> valid_data_{};
// number of elements data_ can hold
int64_t num_rows_;
size_t valid_count_{0};
mutable std::shared_mutex num_rows_mutex_;
int64_t null_count_{0};
// number of actual elements in data_
size_t length_{};
mutable std::shared_mutex tell_mutex_;
private:
const ssize_t dim_;
};
@ -803,90 +814,6 @@ class FieldDataJsonImpl : public FieldDataImpl<Json, true> {
}
};
class FieldDataSparseVectorImpl
: public FieldDataImpl<knowhere::sparse::SparseRow<SparseValueType>, true> {
public:
explicit FieldDataSparseVectorImpl(DataType data_type,
int64_t total_num_rows = 0)
: FieldDataImpl<knowhere::sparse::SparseRow<SparseValueType>, true>(
/*dim=*/1, data_type, false, total_num_rows),
vec_dim_(0) {
AssertInfo(data_type == DataType::VECTOR_SPARSE_U32_F32,
"invalid data type for sparse vector");
}
int64_t
DataSize() const override {
int64_t data_size = 0;
for (size_t i = 0; i < length(); ++i) {
data_size += data_[i].data_byte_size();
}
return data_size;
}
int64_t
DataSize(ssize_t offset) const override {
AssertInfo(offset < get_num_rows(),
"field data subscript out of range");
AssertInfo(offset < length(),
"subscript position don't has valid value");
return data_[offset].data_byte_size();
}
// source is a pointer to element_count of
// knowhere::sparse::SparseRow<SparseValueType>
void
FillFieldData(const void* source, ssize_t element_count) override {
if (element_count == 0) {
return;
}
std::lock_guard lck(tell_mutex_);
if (length_ + element_count > get_num_rows()) {
resize_field_data(length_ + element_count);
}
auto ptr =
static_cast<const knowhere::sparse::SparseRow<SparseValueType>*>(
source);
for (int64_t i = 0; i < element_count; ++i) {
auto& row = ptr[i];
vec_dim_ = std::max(vec_dim_, row.dim());
}
std::copy_n(ptr, element_count, data_.data() + length_);
length_ += element_count;
}
// each binary in array is a knowhere::sparse::SparseRow<SparseValueType>
void
FillFieldData(const std::shared_ptr<arrow::BinaryArray>& array) override {
auto n = array->length();
if (n == 0) {
return;
}
std::lock_guard lck(tell_mutex_);
if (length_ + n > get_num_rows()) {
resize_field_data(length_ + n);
}
for (int64_t i = 0; i < array->length(); ++i) {
auto view = array->GetView(i);
auto& row = data_[length_ + i];
row = CopyAndWrapSparseRow(view.data(), view.size());
vec_dim_ = std::max(vec_dim_, row.dim());
}
length_ += n;
}
int64_t
Dim() const {
return vec_dim_;
}
private:
int64_t vec_dim_ = 0;
};
class FieldDataArrayImpl : public FieldDataImpl<Array, true> {
public:
explicit FieldDataArrayImpl(DataType data_type,

View File

@ -168,6 +168,8 @@ FieldMeta::ParseFrom(const milvus::proto::schema::FieldSchema& schema_proto) {
}
if (IsVectorDataType(data_type)) {
AssertInfo(!default_value.has_value(),
"vector fields do not support default values");
auto type_map = RepeatedKeyValToMap(schema_proto.type_params());
auto index_map = RepeatedKeyValToMap(schema_proto.index_params());
@ -183,12 +185,17 @@ FieldMeta::ParseFrom(const milvus::proto::schema::FieldSchema& schema_proto) {
data_type,
dim,
std::nullopt,
false,
nullable,
default_value};
}
auto metric_type = index_map.at("metric_type");
return FieldMeta{
name, field_id, data_type, dim, metric_type, false, default_value};
return FieldMeta{name,
field_id,
data_type,
dim,
metric_type,
nullable,
default_value};
}
if (IsStringDataType(data_type)) {

View File

@ -125,7 +125,8 @@ class FieldMeta {
vector_info_(VectorInfo{dim, std::move(metric_type)}),
default_value_(std::move(default_value)) {
Assert(IsVectorDataType(type_));
Assert(!nullable);
Assert(!default_value_.has_value() &&
"vector fields do not support default values");
}
// array of vector type

View File

@ -0,0 +1,194 @@
#include "common/OffsetMapping.h"
namespace milvus {
// Build the logical<->physical mapping for a batch of rows from a bool
// validity array. Logical offsets count every row (nulls included); physical
// offsets count only valid rows. Storage mode is auto-selected per call:
// map mode when fewer than 10% of rows are valid, dense vectors otherwise.
// NOTE(review): calling Build more than once may re-decide use_map_ and leave
// earlier entries stranded in the other storage; confirm callers use Build
// once and BuildIncremental for appends.
void
OffsetMapping::Build(const bool* valid_data,
                     int64_t total_count,
                     int64_t start_logical,
                     int64_t start_physical) {
    // Nothing to do for an empty batch or a missing validity array.
    if (total_count == 0 || valid_data == nullptr) {
        return;
    }
    std::unique_lock<std::shared_mutex> lck(mutex_);
    enabled_ = true;
    // Total logical rows seen so far (assumes batches arrive in order).
    total_count_ = start_logical + total_count;
    // Count valid elements first
    int64_t valid_count = 0;
    for (int64_t i = 0; i < total_count; ++i) {
        if (valid_data[i]) {
            valid_count++;
        }
    }
    // Auto-select storage mode: use map when valid ratio < 10%
    use_map_ = (valid_count * 10 < total_count);
    if (use_map_) {
        // Map mode: only store valid entries
        int64_t physical_idx = start_physical;
        for (int64_t i = 0; i < total_count; ++i) {
            if (valid_data[i]) {
                l2p_map_[start_logical + i] = physical_idx;
                p2l_map_[physical_idx] = start_logical + i;
                physical_idx++;
            }
        }
    } else {
        // Vec mode: store all entries
        int64_t required_size = start_logical + total_count;
        if (static_cast<int64_t>(l2p_vec_.size()) < required_size) {
            l2p_vec_.resize(required_size, -1);
        }
        int64_t physical_idx = start_physical;
        for (int64_t i = 0; i < total_count; ++i) {
            if (valid_data[i]) {
                l2p_vec_[start_logical + i] = physical_idx;
                // p2l grows lazily, one valid row at a time.
                if (physical_idx >= static_cast<int64_t>(p2l_vec_.size())) {
                    p2l_vec_.resize(physical_idx + 1, -1);
                }
                p2l_vec_[physical_idx] = start_logical + i;
                physical_idx++;
            } else {
                // -1 marks a null row (no physical slot).
                l2p_vec_[start_logical + i] = -1;
            }
        }
    }
    valid_count_ += valid_count;
}
// Append a batch of rows to the mapping. Incremental appends always use vec
// mode; if an earlier Build chose map mode, the existing maps are first
// converted to vectors so both representations never coexist.
void
OffsetMapping::BuildIncremental(const bool* valid_data,
                                int64_t count,
                                int64_t start_logical,
                                int64_t start_physical) {
    // Nothing to do for an empty batch or a missing validity array.
    if (count == 0 || valid_data == nullptr) {
        return;
    }
    std::unique_lock<std::shared_mutex> lck(mutex_);
    enabled_ = true;
    // Total logical rows seen so far (assumes batches arrive in order).
    total_count_ = start_logical + count;
    // Incremental builds always use vec mode
    if (use_map_ && !l2p_map_.empty()) {
        // Convert from map to vec if needed
        int64_t max_logical = 0;
        for (const auto& [logical, physical] : l2p_map_) {
            if (logical > max_logical) {
                max_logical = logical;
            }
        }
        l2p_vec_.resize(max_logical + 1, -1);
        for (const auto& [logical, physical] : l2p_map_) {
            l2p_vec_[logical] = physical;
        }
        int64_t max_physical = 0;
        for (const auto& [physical, logical] : p2l_map_) {
            if (physical > max_physical) {
                max_physical = physical;
            }
        }
        p2l_vec_.resize(max_physical + 1, -1);
        for (const auto& [physical, logical] : p2l_map_) {
            p2l_vec_[physical] = logical;
        }
        // Drop map storage; from here on only the vectors are authoritative.
        l2p_map_.clear();
        p2l_map_.clear();
        use_map_ = false;
    }
    // Resize l2p_vec if needed
    int64_t required_size = start_logical + count;
    if (static_cast<int64_t>(l2p_vec_.size()) < required_size) {
        l2p_vec_.resize(required_size, -1);
    }
    int64_t physical_idx = start_physical;
    for (int64_t i = 0; i < count; ++i) {
        if (valid_data[i]) {
            l2p_vec_[start_logical + i] = physical_idx;
            // p2l grows lazily, one valid row at a time.
            if (physical_idx >= static_cast<int64_t>(p2l_vec_.size())) {
                p2l_vec_.resize(physical_idx + 1, -1);
            }
            p2l_vec_[physical_idx] = start_logical + i;
            physical_idx++;
            valid_count_++;
        } else {
            // -1 marks a null row (no physical slot).
            l2p_vec_[start_logical + i] = -1;
        }
    }
}
// Translate a logical offset into its physical slot.
// Returns the logical offset unchanged when no mapping was built,
// and -1 when the row is null or out of range.
int64_t
OffsetMapping::GetPhysicalOffset(int64_t logical_offset) const {
    std::shared_lock<std::shared_mutex> guard(mutex_);
    if (!enabled_) {
        // Identity mapping: every row is valid.
        return logical_offset;
    }
    if (use_map_) {
        const auto it = l2p_map_.find(static_cast<int32_t>(logical_offset));
        return it == l2p_map_.end() ? -1 : it->second;
    }
    const auto limit = static_cast<int64_t>(l2p_vec_.size());
    return logical_offset < limit ? l2p_vec_[logical_offset] : -1;
}
// Translate a physical slot back into its logical offset.
// Returns the physical offset unchanged when no mapping was built,
// and -1 when the slot is unknown or out of range.
int64_t
OffsetMapping::GetLogicalOffset(int64_t physical_offset) const {
    std::shared_lock<std::shared_mutex> guard(mutex_);
    if (!enabled_) {
        // Identity mapping: every row is valid.
        return physical_offset;
    }
    if (use_map_) {
        const auto it = p2l_map_.find(static_cast<int32_t>(physical_offset));
        return it == p2l_map_.end() ? -1 : it->second;
    }
    const auto limit = static_cast<int64_t>(p2l_vec_.size());
    return physical_offset < limit ? p2l_vec_[physical_offset] : -1;
}
// A logical row is valid (non-null) iff it maps to a real physical slot.
bool
OffsetMapping::IsValid(int64_t logical_offset) const {
    const auto physical = GetPhysicalOffset(logical_offset);
    return physical >= 0;
}
// Number of valid (non-null) rows recorded so far.
int64_t
OffsetMapping::GetValidCount() const {
    std::shared_lock<std::shared_mutex> guard(mutex_);
    return valid_count_;
}
// Whether any mapping has been built (false means identity mapping).
bool
OffsetMapping::IsEnabled() const {
    std::shared_lock<std::shared_mutex> guard(mutex_);
    return enabled_;
}
// Next free physical slot for incremental builds; physical slots are dense,
// so this equals the count of valid rows stored so far.
int64_t
OffsetMapping::GetNextPhysicalOffset() const {
    std::shared_lock<std::shared_mutex> guard(mutex_);
    return valid_count_;
}
// Total logical row count, nulls included.
int64_t
OffsetMapping::GetTotalCount() const {
    std::shared_lock<std::shared_mutex> guard(mutex_);
    return total_count_;
}
} // namespace milvus

View File

@ -0,0 +1,96 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstdint>
#include <mutex>
#include <shared_mutex>
#include <unordered_map>
#include <vector>
namespace milvus {
// Bidirectional offset mapping for nullable vector storage.
// Maps between logical offsets (row positions including nulls) and physical
// offsets (positions among only the stored, valid rows).
// Supports two storage modes:
//  - vec mode: uses vectors for both L2P and P2L, efficient when valid ratio >= 10%
//  - map mode: uses unordered_maps, efficient when valid ratio < 10%
class OffsetMapping {
 public:
    OffsetMapping() = default;

    // Build mapping from valid_data (bool array format).
    // The storage mode is auto-selected from the valid ratio (< 10% uses map).
    void
    Build(const bool* valid_data,
          int64_t total_count,
          int64_t start_logical = 0,
          int64_t start_physical = 0);

    // Build mapping incrementally (always uses vec mode for incremental builds;
    // converts any existing map-mode storage to vec mode first).
    void
    BuildIncremental(const bool* valid_data,
                     int64_t count,
                     int64_t start_logical,
                     int64_t start_physical);

    // Get physical offset from logical offset. Returns -1 if null.
    // When the mapping is disabled, returns the input unchanged.
    int64_t
    GetPhysicalOffset(int64_t logical_offset) const;

    // Get logical offset from physical offset. Returns -1 if not found.
    // When the mapping is disabled, returns the input unchanged.
    int64_t
    GetLogicalOffset(int64_t physical_offset) const;

    // Check if a logical offset is valid (not null).
    bool
    IsValid(int64_t logical_offset) const;

    // Get count of valid (non-null) elements.
    int64_t
    GetValidCount() const;

    // Check if mapping is enabled (i.e. a Build has happened).
    bool
    IsEnabled() const;

    // Get next physical offset (for incremental builds).
    int64_t
    GetNextPhysicalOffset() const;

    // Get total logical count (including nulls).
    int64_t
    GetTotalCount() const;

 private:
    bool enabled_{false};
    bool use_map_{false};  // true: use map for L2P, false: use vec
    // Vec mode storage (uses int32_t to save memory; caps offsets at ~2.1B rows).
    std::vector<int32_t> l2p_vec_;  // logical -> physical, -1 means null
    std::vector<int32_t> p2l_vec_;  // physical -> logical
    // Map mode storage (for sparse valid data).
    std::unordered_map<int32_t, int32_t> l2p_map_;  // logical -> physical
    std::unordered_map<int32_t, int32_t> p2l_map_;  // physical -> logical
    int64_t valid_count_{0};
    int64_t total_count_{0};  // total logical count (including nulls)
    // Guards all state above; getters take shared locks, builders unique locks.
    mutable std::shared_mutex mutex_;
};
} // namespace milvus

View File

@ -29,6 +29,8 @@
#include "common/FieldMeta.h"
#include "common/ArrayOffsets.h"
#include "common/OffsetMapping.h"
#include "query/Utils.h"
#include "pb/schema.pb.h"
#include "knowhere/index/index_node.h"
@ -156,9 +158,10 @@ class VectorIterator {
class ChunkMergeIterator : public VectorIterator {
public:
ChunkMergeIterator(int chunk_count,
const milvus::OffsetMapping& offset_mapping,
const std::vector<int64_t>& total_rows_until_chunk = {},
bool larger_is_closer = false)
: total_rows_until_chunk_(total_rows_until_chunk),
: offset_mapping_(&offset_mapping),
larger_is_closer_(larger_is_closer),
heap_(OffsetDisPairComparator(larger_is_closer)) {
iterators_.reserve(chunk_count);
@ -180,7 +183,11 @@ class ChunkMergeIterator : public VectorIterator {
origin_pair, top->GetIteratorIdx());
heap_.push(off_dis_pair);
}
return top->GetOffDis();
auto result = top->GetOffDis();
if (offset_mapping_ != nullptr) {
result.first = offset_mapping_->GetLogicalOffset(result.first);
}
return result;
}
return std::nullopt;
}
@ -231,6 +238,7 @@ class ChunkMergeIterator : public VectorIterator {
OffsetDisPairComparator>
heap_;
bool sealed = false;
const milvus::OffsetMapping* offset_mapping_ = nullptr;
std::vector<int64_t> total_rows_until_chunk_;
bool larger_is_closer_ = false;
//currently, ChunkMergeIterator is guaranteed to be used serially without concurrent problem, in the future
@ -258,6 +266,7 @@ struct SearchResult {
int chunk_count,
const std::vector<int64_t>& total_rows_until_chunk,
const std::vector<knowhere::IndexNode::IteratorPtr>& kw_iterators,
const milvus::OffsetMapping& offset_mapping,
bool larger_is_closer = false) {
AssertInfo(kw_iterators.size() == nq * chunk_count,
"kw_iterators count:{} is not equal to nq*chunk_count:{}, "
@ -269,8 +278,11 @@ struct SearchResult {
for (int i = 0, vec_iter_idx = 0; i < kw_iterators.size(); i++) {
vec_iter_idx = vec_iter_idx % nq;
if (vector_iterators.size() < nq) {
auto chunk_merge_iter = std::make_shared<ChunkMergeIterator>(
chunk_count, total_rows_until_chunk, larger_is_closer);
auto chunk_merge_iter =
std::make_shared<ChunkMergeIterator>(chunk_count,
offset_mapping,
total_rows_until_chunk,
larger_is_closer);
vector_iterators.emplace_back(chunk_merge_iter);
}
const auto& kw_iterator = kw_iterators[i];

View File

@ -95,7 +95,8 @@ class Schema {
AddDebugField(const std::string& name,
DataType data_type,
int64_t dim,
std::optional<knowhere::MetricType> metric_type) {
std::optional<knowhere::MetricType> metric_type,
bool nullable = false) {
auto field_id = FieldId(debug_id);
debug_id++;
auto field_meta = FieldMeta(FieldName(name),
@ -103,7 +104,7 @@ class Schema {
data_type,
dim,
metric_type,
false,
nullable,
std::nullopt);
this->AddField(std::move(field_meta));
return field_id;
@ -225,7 +226,7 @@ class Schema {
std::optional<knowhere::MetricType> metric_type,
bool nullable) {
auto field_meta = FieldMeta(
name, id, data_type, dim, metric_type, false, std::nullopt);
name, id, data_type, dim, metric_type, nullable, std::nullopt);
this->AddField(std::move(field_meta));
}

View File

@ -304,7 +304,9 @@ CopyAndWrapSparseRow(const void* data,
template <typename Iterable>
std::unique_ptr<knowhere::sparse::SparseRow<SparseValueType>[]>
SparseBytesToRows(const Iterable& rows, const bool validate = false) {
AssertInfo(rows.size() > 0, "at least 1 sparse row should be provided");
if (rows.size() == 0) {
return nullptr;
}
auto res = std::make_unique<knowhere::sparse::SparseRow<SparseValueType>[]>(
rows.size());
for (size_t i = 0; i < rows.size(); ++i) {

View File

@ -57,7 +57,12 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info,
bool larger_is_closer =
PositivelyRelated(search_info.metric_type_);
search_result.AssembleChunkVectorIterators(
nq, 1, {0}, iterators_val.value(), larger_is_closer);
nq,
1,
{0},
iterators_val.value(),
index.GetOffsetMapping(),
larger_is_closer);
} else {
std::string operator_type = "";
if (search_info.group_by_field_id_.has_value()) {

View File

@ -120,6 +120,26 @@ VectorDiskAnnIndex<T>::Load(milvus::tracer::TraceContext ctx,
"failed to Deserialize index, " + KnowhereStatusString(stat));
span_load_engine->End();
auto local_chunk_manager =
storage::LocalChunkManagerSingleton::GetInstance().GetChunkManager();
auto local_index_path_prefix = file_manager_->GetLocalIndexObjectPrefix();
auto valid_data_path = local_index_path_prefix + "/" + VALID_DATA_KEY;
if (local_chunk_manager->Exist(valid_data_path)) {
size_t count;
local_chunk_manager->Read(valid_data_path, 0, &count, sizeof(size_t));
size_t byte_size = (count + 7) / 8;
std::vector<uint8_t> valid_bitmap(byte_size);
local_chunk_manager->Read(
valid_data_path, sizeof(size_t), valid_bitmap.data(), byte_size);
// Convert bitmap to bool array
std::unique_ptr<bool[]> valid_data(new bool[count]);
for (size_t i = 0; i < count; ++i) {
valid_data[i] = (valid_bitmap[i / 8] >> (i % 8)) & 1;
}
BuildValidData(valid_data.get(), count);
}
SetDim(index_.Dim());
}
@ -298,6 +318,23 @@ VectorDiskAnnIndex<T>::BuildWithDataset(const DatasetPtr& dataset,
if (stat != knowhere::Status::success)
ThrowInfo(ErrorCode::IndexBuildError,
"failed to build index, " + KnowhereStatusString(stat));
if (HasValidData()) {
auto valid_data_path = local_index_path_prefix + "/" + VALID_DATA_KEY;
size_t count = offset_mapping_.GetTotalCount();
local_chunk_manager->Write(valid_data_path, 0, &count, sizeof(size_t));
size_t byte_size = (count + 7) / 8;
std::vector<uint8_t> packed_data(byte_size, 0);
for (size_t i = 0; i < count; ++i) {
if (offset_mapping_.IsValid(i)) {
packed_data[i / 8] |= (1 << (i % 8));
}
}
local_chunk_manager->Write(
valid_data_path, sizeof(size_t), packed_data.data(), byte_size);
file_manager_->AddFile(valid_data_path);
}
local_chunk_manager->RemoveDir(
storage::GetSegmentRawDataPathPrefix(local_chunk_manager, segment_id));

View File

@ -27,6 +27,7 @@
#include "index/Index.h"
#include "common/Types.h"
#include "common/BitsetView.h"
#include "common/OffsetMapping.h"
#include "common/QueryResult.h"
#include "common/QueryInfo.h"
#include "common/OpContext.h"
@ -34,6 +35,10 @@
namespace milvus::index {
// valid data keys for nullable vector index serialization
constexpr const char* VALID_DATA_KEY = "valid_data";
constexpr const char* VALID_DATA_COUNT_KEY = "valid_data_count";
class VectorIndex : public IndexBase {
public:
explicit VectorIndex(const IndexType& index_type,
@ -145,6 +150,56 @@ class VectorIndex : public IndexBase {
return search_cfg;
}
void
UpdateValidData(const bool* valid_data, int64_t count) {
offset_mapping_.BuildIncremental(
valid_data,
count,
offset_mapping_.GetTotalCount(),
offset_mapping_.GetNextPhysicalOffset());
}
void
BuildValidData(const bool* valid_data, int64_t total_count) {
offset_mapping_.Build(valid_data, total_count);
}
bool
IsRowValid(int64_t logical_offset) const {
if (!offset_mapping_.IsEnabled()) {
return true;
}
return offset_mapping_.IsValid(logical_offset);
}
bool
HasValidData() const {
return offset_mapping_.IsEnabled();
}
int64_t
GetValidCount() const {
return offset_mapping_.GetValidCount();
}
int64_t
GetPhysicalOffset(int64_t logical_offset) const {
return offset_mapping_.GetPhysicalOffset(logical_offset);
}
int64_t
GetLogicalOffset(int64_t physical_offset) const {
return offset_mapping_.GetLogicalOffset(physical_offset);
}
const milvus::OffsetMapping&
GetOffsetMapping() const {
return offset_mapping_;
}
protected:
milvus::OffsetMapping offset_mapping_;
private:
MetricType metric_type_;
int64_t dim_;

View File

@ -146,6 +146,27 @@ VectorMemIndex<T>::Serialize(const Config& config) {
ThrowInfo(ErrorCode::UnexpectedError,
"failed to serialize index: {}",
KnowhereStatusString(stat));
// Serialize valid_data from offset_mapping if enabled
if (offset_mapping_.IsEnabled()) {
auto total_count = offset_mapping_.GetTotalCount();
std::shared_ptr<uint8_t[]> count_buf(new uint8_t[sizeof(size_t)]);
size_t count = static_cast<size_t>(total_count);
std::memcpy(count_buf.get(), &count, sizeof(size_t));
ret.Append(VALID_DATA_COUNT_KEY, count_buf, sizeof(size_t));
size_t byte_size = (count + 7) / 8;
std::shared_ptr<uint8_t[]> data(new uint8_t[byte_size]);
std::memset(data.get(), 0, byte_size);
for (size_t i = 0; i < count; ++i) {
if (offset_mapping_.IsValid(i)) {
data[i / 8] |= (1 << (i % 8));
}
}
ret.Append(VALID_DATA_KEY, data, byte_size);
}
Disassemble(ret);
return ret;
@ -160,6 +181,25 @@ VectorMemIndex<T>::LoadWithoutAssemble(const BinarySet& binary_set,
ThrowInfo(ErrorCode::UnexpectedError,
"failed to Deserialize index: {}",
KnowhereStatusString(stat));
// Deserialize valid_data bitmap and rebuild offset_mapping
if (binary_set.Contains(VALID_DATA_COUNT_KEY) &&
binary_set.Contains(VALID_DATA_KEY)) {
knowhere::BinaryPtr ptr;
ptr = binary_set.GetByName(VALID_DATA_COUNT_KEY);
size_t count;
std::memcpy(&count, ptr->data.get(), sizeof(size_t));
ptr = binary_set.GetByName(VALID_DATA_KEY);
// Convert bitmap to bool array
std::unique_ptr<bool[]> valid_data(new bool[count]);
auto bitmap = ptr->data.get();
for (size_t i = 0; i < count; ++i) {
valid_data[i] = (bitmap[i / 8] >> (i % 8)) & 1;
}
BuildValidData(valid_data.get(), count);
}
SetDim(index_.Dim());
}
@ -339,19 +379,48 @@ VectorMemIndex<T>::Build(const Config& config) {
build_config.update(config);
build_config.erase(INSERT_FILES_KEY);
build_config.erase(VEC_OPT_FIELDS);
if (!IndexIsSparse(GetIndexType())) {
int64_t total_size = 0;
int64_t total_num_rows = 0;
int64_t dim = 0;
for (auto data : field_datas) {
total_size += data->Size();
total_num_rows += data->get_num_rows();
bool nullable = false;
int64_t total_valid_rows = 0;
int64_t total_num_rows = 0;
for (auto data : field_datas) {
auto num_rows = data->get_num_rows();
auto valid_rows = data->get_valid_rows();
total_valid_rows += valid_rows;
total_num_rows += num_rows;
if (data->IsNullable()) {
nullable = true;
}
}
std::unique_ptr<bool[]> valid_data;
if (nullable) {
valid_data.reset(new bool[total_num_rows]);
int64_t chunk_offset = 0;
for (auto data : field_datas) {
auto rows = data->get_num_rows();
// Copy valid data from FieldData (bitmap format to bool array)
auto src_bitmap = data->ValidData();
for (int64_t i = 0; i < rows; ++i) {
valid_data[chunk_offset + i] =
(src_bitmap[i >> 3] >> (i & 7)) & 1;
}
chunk_offset += rows;
}
}
if (!IndexIsSparse(GetIndexType())) {
int64_t dim = 0;
int64_t total_size = 0;
for (auto data : field_datas) {
AssertInfo(dim == 0 || dim == data->get_dim(),
"inconsistent dim value between field datas!");
dim = data->get_dim();
if (elem_type_ == DataType::NONE) {
total_size += data->DataSize();
} else {
total_size += data->Size();
}
}
auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[total_size]);
size_t lim_offset = 0;
@ -362,8 +431,9 @@ VectorMemIndex<T>::Build(const Config& config) {
if (elem_type_ == DataType::NONE) {
// TODO: avoid copying
for (auto data : field_datas) {
std::memcpy(buf.get() + offset, data->Data(), data->Size());
offset += data->Size();
auto valid_size = data->DataSize();
std::memcpy(buf.get() + offset, data->Data(), valid_size);
offset += valid_size;
data.reset();
}
} else {
@ -396,12 +466,12 @@ VectorMemIndex<T>::Build(const Config& config) {
data.reset();
}
total_num_rows = lim_offset;
total_valid_rows = lim_offset;
}
field_datas.clear();
auto dataset = GenDataset(total_num_rows, dim, buf.get());
auto dataset = GenDataset(total_valid_rows, dim, buf.get());
if (!scalar_info.empty()) {
dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info));
}
@ -410,12 +480,13 @@ VectorMemIndex<T>::Build(const Config& config) {
const_cast<const size_t*>(offsets.data()));
}
BuildWithDataset(dataset, build_config);
if (nullable) {
BuildValidData(valid_data.get(), total_num_rows);
}
} else {
// sparse
int64_t total_rows = 0;
int64_t dim = 0;
for (auto field_data : field_datas) {
total_rows += field_data->Length();
dim = std::max(
dim,
std::dynamic_pointer_cast<FieldData<SparseFloatVector>>(
@ -423,28 +494,31 @@ VectorMemIndex<T>::Build(const Config& config) {
->Dim());
}
std::vector<knowhere::sparse::SparseRow<SparseValueType>> vec(
total_rows);
total_valid_rows);
int64_t offset = 0;
for (auto field_data : field_datas) {
auto ptr = static_cast<
const knowhere::sparse::SparseRow<SparseValueType>*>(
field_data->Data());
AssertInfo(ptr, "failed to cast field data to sparse rows");
for (size_t i = 0; i < field_data->Length(); ++i) {
for (size_t i = 0; i < field_data->get_valid_rows(); ++i) {
// this does a deep copy of field_data's data.
// TODO: avoid copying by enforcing field data to give up
// ownership.
AssertInfo(dim >= ptr[i].dim(), "bad dim");
dim = std::max(dim, static_cast<int64_t>(ptr[i].dim()));
vec[offset + i] = ptr[i];
}
offset += field_data->Length();
offset += field_data->get_valid_rows();
}
auto dataset = GenDataset(total_rows, dim, vec.data());
auto dataset = GenDataset(total_valid_rows, dim, vec.data());
dataset->SetIsSparse(true);
if (!scalar_info.empty()) {
dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info));
}
BuildWithDataset(dataset, build_config);
if (nullable) {
BuildValidData(valid_data.get(), total_num_rows);
}
}
}
@ -572,6 +646,10 @@ VectorMemIndex<T>::GetVector(const DatasetPtr dataset) const {
template <typename T>
std::unique_ptr<const knowhere::sparse::SparseRow<SparseValueType>[]>
VectorMemIndex<T>::GetSparseVector(const DatasetPtr dataset) const {
if (dataset->GetRows() == 0) {
return nullptr;
}
auto res = index_.GetVectorByIds(dataset);
if (!res.has_value()) {
ThrowInfo(ErrorCode::UnexpectedError,
@ -646,6 +724,8 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
LOG_INFO("load with slice meta: {}", !slice_meta_filepath.empty());
std::chrono::duration<double> load_duration_sum;
std::chrono::duration<double> write_disk_duration_sum;
std::unique_ptr<storage::DataCodec> valid_data_count_codec;
std::unique_ptr<storage::DataCodec> valid_data_codec;
// load files in two parts:
// 1. EMB_LIST_META: Written separately to embedding_list_meta_writer_ptr (if embedding list type)
// 2. All other binaries: Merged and written to file_writer, forming a unified index file for knowhere
@ -683,6 +763,10 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
embedding_list_meta_writer_ptr) {
embedding_list_meta_writer_ptr->Write(
data->PayloadData(), data->PayloadSize());
} else if (prefix == VALID_DATA_COUNT_KEY) {
valid_data_count_codec = std::move(data);
} else if (prefix == VALID_DATA_KEY) {
valid_data_codec = std::move(data);
} else {
file_writer.Write(data->PayloadData(),
data->PayloadSize());
@ -724,6 +808,10 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
embedding_list_meta_writer_ptr) {
embedding_list_meta_writer_ptr->Write(
index_data->PayloadData(), index_data->PayloadSize());
} else if (prefix == VALID_DATA_COUNT_KEY) {
valid_data_count_codec = std::move(index_data);
} else if (prefix == VALID_DATA_KEY) {
valid_data_codec = std::move(index_data);
} else {
file_writer.Write(index_data->PayloadData(),
index_data->PayloadSize());
@ -768,6 +856,20 @@ void VectorMemIndex<T>::LoadFromFile(const Config& config) {
auto dim = index_.Dim();
this->SetDim(index_.Dim());
// Restore valid_data for nullable vector support
if (valid_data_count_codec && valid_data_codec) {
size_t count;
std::memcpy(
&count, valid_data_count_codec->PayloadData(), sizeof(size_t));
std::unique_ptr<bool[]> valid_data(new bool[count]);
auto bitmap = valid_data_codec->PayloadData();
for (size_t i = 0; i < count; ++i) {
valid_data[i] = (bitmap[i / 8] >> (i % 8)) & 1;
}
BuildValidData(valid_data.get(), count);
}
this->mmap_file_raii_ =
std::make_unique<MmapFileRAII>(local_filepath.value());
LOG_INFO(

View File

@ -22,7 +22,9 @@ class IndexCreatorBase {
virtual ~IndexCreatorBase() = default;
virtual void
Build(const milvus::DatasetPtr& dataset) = 0;
Build(const milvus::DatasetPtr& dataset,
const bool* valid_data = nullptr,
const int64_t valid_data_len = 0) = 0;
virtual void
Build() = 0;

View File

@ -83,7 +83,11 @@ ScalarIndexCreator::ScalarIndexCreator(
}
void
ScalarIndexCreator::Build(const milvus::DatasetPtr& dataset) {
ScalarIndexCreator::Build(const milvus::DatasetPtr& dataset,
const bool* valid_data,
const int64_t valid_data_len) {
(void)valid_data;
(void)valid_data_len;
auto size = dataset->GetRows();
auto data = dataset->GetTensor();
index_->BuildWithRawDataForUT(size, data);

View File

@ -27,7 +27,9 @@ class ScalarIndexCreator : public IndexCreatorBase {
const storage::FileManagerContext& file_manager_context);
void
Build(const milvus::DatasetPtr& dataset) override;
Build(const milvus::DatasetPtr& dataset,
const bool* valid_data = nullptr,
const int64_t valid_data_len = 0) override;
void
Build() override;

View File

@ -65,8 +65,15 @@ VecIndexCreator::dim() {
}
// Build the wrapped vector index from `dataset`. When `valid_data` is
// provided (nullable vector field), also record the validity bitmap on the
// index so it can translate logical <-> physical row offsets afterwards.
void
VecIndexCreator::Build(const milvus::DatasetPtr& dataset,
                       const bool* valid_data,
                       const int64_t valid_data_len) {
    index_->BuildWithDataset(dataset, config_);
    if (valid_data != nullptr && valid_data_len > 0) {
        auto vec_index = dynamic_cast<index::VectorIndex*>(index_.get());
        AssertInfo(vec_index != nullptr, "failed to cast index to VectorIndex");
        vec_index->BuildValidData(valid_data, valid_data_len);
    }
}
void

View File

@ -39,7 +39,9 @@ class VecIndexCreator : public IndexCreatorBase {
const storage::FileManagerContext& file_manager_context);
void
Build(const milvus::DatasetPtr& dataset) override;
Build(const milvus::DatasetPtr& dataset,
const bool* valid_data = nullptr,
const int64_t valid_data_len = 0) override;
void
Build() override;

View File

@ -564,6 +564,35 @@ BuildFloatVecIndex(CIndex index,
return status;
}
CStatus
// Builds a float vector index over `vectors` (`float_value_num` scalar
// elements, i.e. row count = float_value_num / dim) and attaches the
// nullable-field validity bitmap (`valid_data`, `valid_data_len` flags).
// Returns Success, or UnexpectedError with a strdup'ed message on failure.
BuildFloatVecIndexWithValidData(CIndex index,
                                int64_t float_value_num,
                                const float* vectors,
                                const bool* valid_data,
                                int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();
    auto status = CStatus();
    try {
        AssertInfo(index,
                   "failed to build float vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a
        // VecIndexCreator; fail with a clear error instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build float vector index, index is not a "
                   "VecIndexCreator");
        auto dim = cIndex->dim();
        // Guard the division below against a zero dimension.
        AssertInfo(dim > 0, "failed to build float vector index, dim is zero");
        auto row_nums = float_value_num / dim;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
CStatus
BuildFloat16VecIndex(CIndex index,
int64_t float16_value_num,
@ -592,6 +621,36 @@ BuildFloat16VecIndex(CIndex index,
return status;
}
CStatus
// Builds a float16 vector index over `vectors` (raw bytes; each element is
// 2 bytes, so row count = float16_value_num / dim / 2) and attaches the
// nullable-field validity bitmap (`valid_data`, `valid_data_len` flags).
// Returns Success, or UnexpectedError with a strdup'ed message on failure.
BuildFloat16VecIndexWithValidData(CIndex index,
                                  int64_t float16_value_num,
                                  const uint8_t* vectors,
                                  const bool* valid_data,
                                  int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();
    auto status = CStatus();
    try {
        AssertInfo(
            index,
            "failed to build float16 vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a
        // VecIndexCreator; fail with a clear error instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build float16 vector index, index is not a "
                   "VecIndexCreator");
        auto dim = cIndex->dim();
        // Guard the division below against a zero dimension.
        AssertInfo(dim > 0,
                   "failed to build float16 vector index, dim is zero");
        auto row_nums = float16_value_num / dim / 2;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
CStatus
BuildBFloat16VecIndex(CIndex index,
int64_t bfloat16_value_num,
@ -620,6 +679,36 @@ BuildBFloat16VecIndex(CIndex index,
return status;
}
CStatus
// Builds a bfloat16 vector index over `vectors` (raw bytes; each element is
// 2 bytes, so row count = bfloat16_value_num / dim / 2) and attaches the
// nullable-field validity bitmap (`valid_data`, `valid_data_len` flags).
// Returns Success, or UnexpectedError with a strdup'ed message on failure.
BuildBFloat16VecIndexWithValidData(CIndex index,
                                   int64_t bfloat16_value_num,
                                   const uint8_t* vectors,
                                   const bool* valid_data,
                                   int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();
    auto status = CStatus();
    try {
        AssertInfo(
            index,
            "failed to build bfloat16 vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a
        // VecIndexCreator; fail with a clear error instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build bfloat16 vector index, index is not a "
                   "VecIndexCreator");
        auto dim = cIndex->dim();
        // Guard the division below against a zero dimension.
        AssertInfo(dim > 0,
                   "failed to build bfloat16 vector index, dim is zero");
        auto row_nums = bfloat16_value_num / dim / 2;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
CStatus
BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors) {
SCOPE_CGO_CALL_METRIC();
@ -646,6 +735,36 @@ BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors) {
return status;
}
CStatus
// Builds a binary vector index over `vectors` (`data_size` bytes; each row
// packs `dim` bits, so row count = data_size * 8 / dim) and attaches the
// nullable-field validity bitmap (`valid_data`, `valid_data_len` flags).
// Returns Success, or UnexpectedError with a strdup'ed message on failure.
BuildBinaryVecIndexWithValidData(CIndex index,
                                 int64_t data_size,
                                 const uint8_t* vectors,
                                 const bool* valid_data,
                                 int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();
    auto status = CStatus();
    try {
        AssertInfo(
            index,
            "failed to build binary vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a
        // VecIndexCreator; fail with a clear error instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build binary vector index, index is not a "
                   "VecIndexCreator");
        auto dim = cIndex->dim();
        // Guard the division below against a zero dimension.
        AssertInfo(dim > 0, "failed to build binary vector index, dim is zero");
        auto row_nums = (data_size * 8) / dim;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
CStatus
BuildSparseFloatVecIndex(CIndex index,
int64_t row_num,
@ -674,6 +793,36 @@ BuildSparseFloatVecIndex(CIndex index,
return status;
}
CStatus
// Builds a sparse float vector index over `row_num` serialized sparse rows
// (`dim` is the declared dimension) and attaches the nullable-field validity
// bitmap (`valid_data`, `valid_data_len` flags).
// Returns Success, or UnexpectedError with a strdup'ed message on failure.
BuildSparseFloatVecIndexWithValidData(CIndex index,
                                      int64_t row_num,
                                      int64_t dim,
                                      const uint8_t* vectors,
                                      const bool* valid_data,
                                      int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();
    auto status = CStatus();
    try {
        AssertInfo(
            index,
            "failed to build sparse float vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a
        // VecIndexCreator; fail with a clear error instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build sparse float vector index, index is not a "
                   "VecIndexCreator");
        auto ds = knowhere::GenDataSet(row_num, dim, vectors);
        ds->SetIsSparse(true);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
CStatus
BuildInt8VecIndex(CIndex index, int64_t int8_value_num, const int8_t* vectors) {
SCOPE_CGO_CALL_METRIC();
@ -699,6 +848,35 @@ BuildInt8VecIndex(CIndex index, int64_t int8_value_num, const int8_t* vectors) {
return status;
}
CStatus
// Builds an int8 vector index over `vectors` (`int8_value_num` scalar
// elements, i.e. row count = int8_value_num / dim) and attaches the
// nullable-field validity bitmap (`valid_data`, `valid_data_len` flags).
// Returns Success, or UnexpectedError with a strdup'ed message on failure.
BuildInt8VecIndexWithValidData(CIndex index,
                               int64_t int8_value_num,
                               const int8_t* vectors,
                               const bool* valid_data,
                               int64_t valid_data_len) {
    SCOPE_CGO_CALL_METRIC();
    auto status = CStatus();
    try {
        AssertInfo(index,
                   "failed to build int8 vector index, passed index was null");
        auto real_index =
            reinterpret_cast<milvus::indexbuilder::IndexCreatorBase*>(index);
        auto cIndex =
            dynamic_cast<milvus::indexbuilder::VecIndexCreator*>(real_index);
        // dynamic_cast yields nullptr when the handle is not a
        // VecIndexCreator; fail with a clear error instead of dereferencing.
        AssertInfo(cIndex != nullptr,
                   "failed to build int8 vector index, index is not a "
                   "VecIndexCreator");
        auto dim = cIndex->dim();
        // Guard the division below against a zero dimension.
        AssertInfo(dim > 0, "failed to build int8 vector index, dim is zero");
        auto row_nums = int8_value_num / dim;
        auto ds = knowhere::GenDataSet(row_nums, dim, vectors);
        cIndex->Build(ds, valid_data, valid_data_len);
        status.error_code = Success;
        status.error_msg = "";
    } catch (std::exception& e) {
        status.error_code = UnexpectedError;
        status.error_msg = strdup(e.what());
    }
    return status;
}
// field_data:
// 1, serialized proto::schema::BoolArray, if type is bool;
// 2, serialized proto::schema::StringArray, if type is string;

View File

@ -55,24 +55,67 @@ CreateIndexForUT(enum CDataType dtype,
CStatus
BuildFloatVecIndex(CIndex index, int64_t float_value_num, const float* vectors);
CStatus
BuildFloatVecIndexWithValidData(CIndex index,
int64_t float_value_num,
const float* vectors,
const bool* valid_data,
int64_t valid_data_len);
CStatus
BuildBinaryVecIndex(CIndex index, int64_t data_size, const uint8_t* vectors);
CStatus
BuildBinaryVecIndexWithValidData(CIndex index,
int64_t data_size,
const uint8_t* vectors,
const bool* valid_data,
int64_t valid_data_len);
CStatus
BuildFloat16VecIndex(CIndex index, int64_t data_size, const uint8_t* vectors);
CStatus
BuildFloat16VecIndexWithValidData(CIndex index,
int64_t data_size,
const uint8_t* vectors,
const bool* valid_data,
int64_t valid_data_len);
CStatus
BuildBFloat16VecIndex(CIndex index, int64_t data_size, const uint8_t* vectors);
CStatus
BuildBFloat16VecIndexWithValidData(CIndex index,
int64_t data_size,
const uint8_t* vectors,
const bool* valid_data,
int64_t valid_data_len);
CStatus
BuildSparseFloatVecIndex(CIndex index,
int64_t row_num,
int64_t dim,
const uint8_t* vectors);
CStatus
BuildSparseFloatVecIndexWithValidData(CIndex index,
int64_t row_num,
int64_t dim,
const uint8_t* vectors,
const bool* valid_data,
int64_t valid_data_len);
CStatus
BuildInt8VecIndex(CIndex index, int64_t data_size, const int8_t* vectors);
CStatus
BuildInt8VecIndexWithValidData(CIndex index,
int64_t data_size,
const int8_t* vectors,
const bool* valid_data,
int64_t valid_data_len);
// field_data:
// 1, serialized proto::schema::BoolArray, if type is bool;
// 2, serialized proto::schema::StringArray, if type is string;

View File

@ -368,6 +368,33 @@ class ChunkedColumnBase : public ChunkedColumnInterface {
return meta->num_rows_until_chunk_;
}
void
BuildValidRowIds(milvus::OpContext* op_ctx) override {
    // Non-nullable columns carry no validity bitmap; nothing to materialize.
    if (!nullable_) {
        return;
    }
    auto cells = SemiInlineGet(slot_->PinAllCells(op_ctx));
    valid_data_.resize(num_rows_);
    valid_count_per_chunk_.resize(num_chunks_);
    int64_t base = 0;
    for (size_t chunk_idx = 0; chunk_idx < num_chunks_; chunk_idx++) {
        auto chunk = cells->get_cell_of(chunk_idx);
        auto row_count = chunk_row_nums(chunk_idx);
        int64_t valid_in_chunk = 0;
        for (int64_t row = 0; row < row_count; row++) {
            bool ok = chunk->isValid(row);
            valid_data_[base + row] = ok;
            if (ok) {
                valid_in_chunk++;
            }
        }
        valid_count_per_chunk_[chunk_idx] = valid_in_chunk;
        base += row_count;
    }
    // Derive the logical<->physical offset mapping from the fresh bitmap.
    BuildOffsetMapping();
}
protected:
bool nullable_{false};
DataType data_type_{DataType::NONE};

View File

@ -667,6 +667,36 @@ class ProxyChunkColumn : public ChunkedColumnInterface {
}
}
void
BuildValidRowIds(milvus::OpContext* op_ctx) override {
    // Only nullable fields need a validity bitmap.
    if (!field_meta_.is_nullable()) {
        return;
    }
    auto total_rows = NumRows();
    auto total_chunks = num_chunks();
    valid_data_.resize(total_rows);
    valid_count_per_chunk_.resize(total_chunks);
    int64_t base = 0;
    for (int64_t chunk_idx = 0; chunk_idx < total_chunks; chunk_idx++) {
        auto group_chunk = group_->GetGroupChunk(op_ctx, chunk_idx);
        auto chunk = group_chunk.get()->GetChunk(field_id_);
        auto row_count = chunk->RowNums();
        int64_t valid_in_chunk = 0;
        for (int64_t row = 0; row < row_count; row++) {
            bool ok = chunk->isValid(row);
            valid_data_[base + row] = ok;
            if (ok) {
                valid_in_chunk++;
            }
        }
        valid_count_per_chunk_[chunk_idx] = valid_in_chunk;
        base += row_count;
    }
    // Derive the logical<->physical offset mapping from the fresh bitmap.
    BuildOffsetMapping();
}
private:
std::shared_ptr<ChunkedColumnGroup> group_;
FieldId field_id_;

View File

@ -17,6 +17,7 @@
#include "cachinglayer/CacheSlot.h"
#include "common/Chunk.h"
#include "common/OffsetMapping.h"
#include "common/bson_view.h"
namespace milvus {
@ -131,6 +132,35 @@ class ChunkedColumnInterface {
virtual const std::vector<int64_t>&
GetNumRowsUntilChunk() const = 0;
// Per-logical-row validity flags; empty unless the column is nullable and
// BuildValidRowIds has been called.
const FixedVector<bool>&
GetValidData() const {
    return valid_data_;
}

// Count of valid (non-null) rows in each chunk, indexed by chunk id.
const std::vector<int64_t>&
GetValidCountPerChunk() const {
    return valid_count_per_chunk_;
}

// Logical<->physical offset translation built from the validity flags.
const OffsetMapping&
GetOffsetMapping() const {
    return offset_mapping_;
}

// Populate valid_data_, valid_count_per_chunk_ and the offset mapping.
// Base implementation rejects the call; column types that can scan their
// chunks for validity override it.
virtual void
BuildValidRowIds(milvus::OpContext* op_ctx) {
    ThrowInfo(ErrorCode::Unsupported,
              "BuildValidRowIds not supported for this column type");
}

// Build offset mapping from valid_data
void
BuildOffsetMapping() {
    // No-op when there is no validity bitmap (non-nullable column).
    if (!valid_data_.empty()) {
        offset_mapping_.Build(valid_data_.data(), valid_data_.size());
    }
}
virtual void
BulkValueAt(milvus::OpContext* op_ctx,
std::function<void(const char*, size_t)> fn,
@ -237,6 +267,10 @@ class ChunkedColumnInterface {
}
protected:
FixedVector<bool> valid_data_;
std::vector<int64_t> valid_count_per_chunk_;
OffsetMapping offset_mapping_;
std::pair<std::vector<milvus::cachinglayer::cid_t>, std::vector<int64_t>>
ToChunkIdAndOffset(const int64_t* offsets, int64_t count) const {
AssertInfo(offsets != nullptr, "Offsets cannot be nullptr");

View File

@ -22,6 +22,7 @@
#include "query/CachedSearchIterator.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnIndex.h"
#include "query/Utils.h"
#include "exec/operator/Utils.h"
namespace milvus::query {
@ -82,8 +83,6 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
SearchResult& search_result) {
auto& schema = segment.get_schema();
auto& record = segment.get_insert_record();
auto active_row_count =
std::min(int64_t(bitset.size()), segment.get_active_count(timestamp));
// step 1.1: get meta
// step 1.2: get which vector field to search
@ -155,6 +154,19 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
// step 3: brute force search where small indexing is unavailable
auto vec_ptr = record.get_data_base(vecfield_id);
const auto& offset_mapping = vec_ptr->get_offset_mapping();
TargetBitmap transformed_bitset;
BitsetView search_bitset = bitset;
if (offset_mapping.IsEnabled()) {
transformed_bitset = TransformBitset(bitset, offset_mapping);
search_bitset = BitsetView(transformed_bitset);
}
auto active_count = offset_mapping.IsEnabled()
? offset_mapping.GetValidCount()
: std::min(int64_t(bitset.size()),
segment.get_active_count(timestamp));
if (info.iterator_v2_info_.has_value()) {
AssertInfo(data_type != DataType::VECTOR_ARRAY,
@ -163,17 +175,20 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
CachedSearchIterator cached_iter(search_dataset,
vec_ptr,
active_row_count,
active_count,
info,
index_info,
bitset,
search_bitset,
data_type);
cached_iter.NextBatch(info, search_result);
if (offset_mapping.IsEnabled()) {
TransformOffset(search_result.seg_offsets_, offset_mapping);
}
return;
}
auto vec_size_per_chunk = vec_ptr->get_size_per_chunk();
auto max_chunk = upper_div(active_row_count, vec_size_per_chunk);
auto max_chunk = upper_div(active_count, vec_size_per_chunk);
// embedding search embedding on embedding list
bool embedding_search = false;
@ -188,7 +203,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
auto row_begin = chunk_id * vec_size_per_chunk;
auto row_end =
std::min(active_row_count, (chunk_id + 1) * vec_size_per_chunk);
std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
auto size_per_chunk = row_end - row_begin;
query::dataset::RawDataset sub_data;
@ -260,7 +275,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
sub_data,
info,
index_info,
bitset,
search_bitset,
vector_type);
final_qr.merge(sub_qr);
} else {
@ -268,7 +283,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
sub_data,
info,
index_info,
bitset,
search_bitset,
vector_type,
element_type,
op_context);
@ -286,6 +301,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
max_chunk,
chunk_rows,
final_qr.chunk_iterators(),
offset_mapping,
larger_is_closer);
} else {
if (info.array_offsets_ != nullptr) {
@ -300,6 +316,9 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
std::move(final_qr.mutable_offsets());
}
search_result.distances_ = std::move(final_qr.mutable_distances());
if (offset_mapping.IsEnabled()) {
TransformOffset(search_result.seg_offsets_, offset_mapping);
}
}
search_result.unity_topK_ = topk;
search_result.total_nq_ = num_queries;

View File

@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "SearchOnIndex.h"
#include "Utils.h"
#include "exec/operator/Utils.h"
#include "CachedSearchIterator.h"
@ -28,23 +29,39 @@ SearchOnIndex(const dataset::SearchDataset& search_dataset,
auto dataset =
knowhere::GenDataSet(num_queries, dim, search_dataset.query_data);
dataset->SetIsSparse(is_sparse);
const auto& offset_mapping = indexing.GetOffsetMapping();
TargetBitmap transformed_bitset;
BitsetView search_bitset = bitset;
if (offset_mapping.IsEnabled()) {
transformed_bitset = TransformBitset(bitset, offset_mapping);
search_bitset = BitsetView(transformed_bitset);
}
if (milvus::exec::PrepareVectorIteratorsFromIndex(search_conf,
num_queries,
dataset,
search_result,
bitset,
search_bitset,
indexing)) {
return;
}
if (search_conf.iterator_v2_info_.has_value()) {
auto iter =
CachedSearchIterator(indexing, dataset, search_conf, bitset);
CachedSearchIterator(indexing, dataset, search_conf, search_bitset);
iter.NextBatch(search_conf, search_result);
if (offset_mapping.IsEnabled()) {
TransformOffset(search_result.seg_offsets_, offset_mapping);
}
return;
}
indexing.Query(dataset, search_conf, bitset, op_context, search_result);
indexing.Query(
dataset, search_conf, search_bitset, op_context, search_result);
if (offset_mapping.IsEnabled()) {
TransformOffset(search_result.seg_offsets_, offset_mapping);
}
}
} // namespace milvus::query

View File

@ -16,12 +16,14 @@
#include "bitset/detail/element_wise.h"
#include "cachinglayer/Utils.h"
#include "common/BitsetView.h"
#include "common/Consts.h"
#include "common/QueryInfo.h"
#include "common/Types.h"
#include "common/Utils.h"
#include "query/CachedSearchIterator.h"
#include "query/SearchBruteForce.h"
#include "query/SearchOnSealed.h"
#include "query/Utils.h"
#include "query/helper.h"
#include "exec/operator/Utils.h"
@ -73,21 +75,40 @@ SearchOnSealedIndex(const Schema& schema,
auto vec_index =
dynamic_cast<index::VectorIndex*>(accessor->get_cell_of(0));
const auto& offset_mapping = vec_index->GetOffsetMapping();
TargetBitmap transformed_bitset;
BitsetView search_bitset = bitset;
if (offset_mapping.IsEnabled()) {
transformed_bitset = TransformBitset(bitset, offset_mapping);
search_bitset = BitsetView(transformed_bitset);
if (offset_mapping.GetValidCount() == 0) {
auto total_num = num_queries * topK;
search_result.seg_offsets_.resize(total_num, INVALID_SEG_OFFSET);
search_result.distances_.resize(total_num, 0.0f);
search_result.total_nq_ = num_queries;
search_result.unity_topK_ = topK;
return;
}
}
if (search_info.iterator_v2_info_.has_value()) {
CachedSearchIterator cached_iter(
*vec_index, dataset, search_info, bitset);
*vec_index, dataset, search_info, search_bitset);
cached_iter.NextBatch(search_info, search_result);
TransformOffset(search_result.seg_offsets_, offset_mapping);
return;
}
if (!milvus::exec::PrepareVectorIteratorsFromIndex(search_info,
num_queries,
dataset,
search_result,
bitset,
*vec_index)) {
bool use_iterator =
milvus::exec::PrepareVectorIteratorsFromIndex(search_info,
num_queries,
dataset,
search_result,
search_bitset,
*vec_index);
if (!use_iterator) {
vec_index->Query(
dataset, search_info, bitset, op_context, search_result);
dataset, search_info, search_bitset, op_context, search_result);
float* distances = search_result.distances_.data();
auto total_num = num_queries * topK;
if (round_decimal != -1) {
@ -120,6 +141,7 @@ SearchOnSealedIndex(const Schema& schema,
search_result.element_level_ = true;
}
}
TransformOffset(search_result.seg_offsets_, offset_mapping);
search_result.total_nq_ = num_queries;
search_result.unity_topK_ = topK;
}
@ -185,12 +207,30 @@ SearchOnSealedColumn(const Schema& schema,
}
auto offset = 0;
const auto& offset_mapping = column->GetOffsetMapping();
TargetBitmap transformed_bitset;
BitsetView search_bitview = bitview;
if (offset_mapping.IsEnabled()) {
transformed_bitset = TransformBitset(bitview, offset_mapping);
search_bitview = BitsetView(transformed_bitset);
if (offset_mapping.GetValidCount() == 0) {
auto total_num = num_queries * search_info.topk_;
result.seg_offsets_.resize(total_num, INVALID_SEG_OFFSET);
result.distances_.resize(total_num, 0.0f);
result.total_nq_ = num_queries;
result.unity_topK_ = search_info.topk_;
return;
}
}
auto vector_chunks = column->GetAllChunks(op_context);
const auto& valid_count_per_chunk = column->GetValidCountPerChunk();
for (int i = 0; i < num_chunk; ++i) {
auto pw = vector_chunks[i];
auto vec_data = pw.get()->Data();
auto chunk_size = column->chunk_row_nums(i);
if (offset_mapping.IsEnabled() && !valid_count_per_chunk.empty()) {
chunk_size = valid_count_per_chunk[i];
}
// For element-level search, get element count from VectorArrayOffsets
if (is_element_level_search) {
@ -221,7 +261,7 @@ SearchOnSealedColumn(const Schema& schema,
raw_dataset,
search_info,
index_info,
bitview,
search_bitview,
data_type);
final_qr.merge(sub_qr);
} else {
@ -229,7 +269,7 @@ SearchOnSealedColumn(const Schema& schema,
raw_dataset,
search_info,
index_info,
bitview,
search_bitview,
data_type,
element_type,
op_context);
@ -243,6 +283,7 @@ SearchOnSealedColumn(const Schema& schema,
num_chunk,
column->GetNumRowsUntilChunk(),
final_qr.chunk_iterators(),
offset_mapping,
larger_is_closer);
} else {
if (search_info.array_offsets_ != nullptr) {
@ -256,6 +297,9 @@ SearchOnSealedColumn(const Schema& schema,
result.seg_offsets_ = std::move(final_qr.mutable_offsets());
}
result.distances_ = std::move(final_qr.mutable_distances());
if (offset_mapping.IsEnabled()) {
TransformOffset(result.seg_offsets_, offset_mapping);
}
}
result.unity_topK_ = query_dataset.topk;
result.total_nq_ = query_dataset.num_queries;

View File

@ -14,9 +14,37 @@
#include <limits>
#include <string>
#include "common/BitsetView.h"
#include "common/OffsetMapping.h"
#include "common/Types.h"
#include "common/Utils.h"
namespace milvus::query {
inline TargetBitmap
TransformBitset(const BitsetView& bitset,
                const milvus::OffsetMapping& mapping) {
    // Re-express a logical-offset bitset in physical (valid-row) space:
    // each physical slot mirrors the bit at its mapped logical offset.
    const int64_t valid_count = mapping.GetValidCount();
    TargetBitmap physical_bits;
    physical_bits.resize(valid_count);
    const auto logical_size = static_cast<int64_t>(bitset.size());
    for (int64_t physical = 0; physical < valid_count; physical++) {
        const auto logical = mapping.GetLogicalOffset(physical);
        // Out-of-range or unmapped offsets leave the bit at its default.
        if (logical >= 0 && logical < logical_size) {
            physical_bits[physical] = bitset.test(logical);
        }
    }
    return physical_bits;
}
inline void
TransformOffset(std::vector<int64_t>& seg_offsets,
                const milvus::OffsetMapping& mapping) {
    // Rewrite physical (valid-row) segment offsets back to logical offsets
    // in place; negative sentinel offsets pass through untouched.
    for (size_t i = 0; i < seg_offsets.size(); i++) {
        const auto offset = seg_offsets[i];
        if (offset >= 0) {
            seg_offsets[i] = mapping.GetLogicalOffset(offset);
        }
    }
}
template <typename T, typename U>
inline bool

View File

@ -918,6 +918,61 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
}
}
ChunkedSegmentSealedImpl::ValidResult
// For `count` requested logical segment offsets of a (possibly nullable)
// vector field, compute per-request validity flags plus the physical
// offsets of the valid rows. Validity comes from the loaded vector index
// when one is ready, otherwise from the raw column's offset mapping.
// If neither source carries validity data, result.valid_data stays null
// and result.valid_count == count (all rows treated as valid).
ChunkedSegmentSealedImpl::FilterVectorValidOffsets(milvus::OpContext* op_ctx,
                                                   FieldId field_id,
                                                   const int64_t* seg_offsets,
                                                   int64_t count) const {
    ValidResult result;
    result.valid_count = count;
    if (vector_indexings_.is_ready(field_id)) {
        auto field_indexing = vector_indexings_.get_field_indexing(field_id);
        auto cache_index = field_indexing->indexing_;
        // Pin cell 0 to keep the index resident while we query it.
        auto ca = SemiInlineGet(cache_index->PinCells(op_ctx, {0}));
        auto vec_index = dynamic_cast<index::VectorIndex*>(ca->get_cell_of(0));
        if (vec_index != nullptr && vec_index->HasValidData()) {
            result.valid_data = std::make_unique<bool[]>(count);
            result.valid_offsets.reserve(count);
            for (int64_t i = 0; i < count; ++i) {
                bool is_valid = vec_index->IsRowValid(seg_offsets[i]);
                result.valid_data[i] = is_valid;
                if (is_valid) {
                    // Translate the logical offset into the index's physical
                    // (valid-rows-only) space; negative means unmapped.
                    int64_t physical_offset =
                        vec_index->GetPhysicalOffset(seg_offsets[i]);
                    if (physical_offset >= 0) {
                        result.valid_offsets.push_back(physical_offset);
                    }
                }
            }
            // valid_count reflects rows that actually mapped to a physical
            // offset, which may be fewer than the rows flagged valid.
            result.valid_count = result.valid_offsets.size();
        }
    } else {
        // No index ready: consult the raw column's validity/offset mapping.
        auto column = get_column(field_id);
        if (column != nullptr && column->IsNullable()) {
            result.valid_data = std::make_unique<bool[]>(count);
            result.valid_offsets.reserve(count);
            const auto& offset_mapping = column->GetOffsetMapping();
            for (int64_t i = 0; i < count; ++i) {
                bool is_valid = offset_mapping.IsValid(seg_offsets[i]);
                result.valid_data[i] = is_valid;
                if (is_valid) {
                    int64_t physical_offset =
                        offset_mapping.GetPhysicalOffset(seg_offsets[i]);
                    if (physical_offset >= 0) {
                        result.valid_offsets.push_back(physical_offset);
                    }
                }
            }
            result.valid_count = result.valid_offsets.size();
        }
    }
    return result;
}
std::unique_ptr<DataArray>
ChunkedSegmentSealedImpl::get_vector(milvus::OpContext* op_ctx,
FieldId field_id,
@ -945,16 +1000,29 @@ ChunkedSegmentSealedImpl::get_vector(milvus::OpContext* op_ctx,
if (has_raw_data) {
// If index has raw data, get vector from memory.
auto ids_ds = GenIdsDataset(count, ids);
ValidResult filter_result;
knowhere::DataSetPtr ids_ds;
int64_t valid_count = count;
const bool* valid_data = nullptr;
if (field_meta.is_nullable()) {
filter_result =
FilterVectorValidOffsets(op_ctx, field_id, ids, count);
ids_ds = GenIdsDataset(filter_result.valid_count,
filter_result.valid_offsets.data());
valid_count = filter_result.valid_count;
valid_data = filter_result.valid_data.get();
} else {
ids_ds = GenIdsDataset(count, ids);
}
if (field_meta.get_data_type() == DataType::VECTOR_SPARSE_U32_F32) {
auto res = vec_index->GetSparseVector(ids_ds);
return segcore::CreateVectorDataArrayFrom(
res.get(), count, field_meta);
res.get(), valid_data, count, valid_count, field_meta);
} else {
// dense vector:
auto vector = vec_index->GetVector(ids_ds);
return segcore::CreateVectorDataArrayFrom(
vector.data(), count, field_meta);
vector.data(), valid_data, count, valid_count, field_meta);
}
}
@ -1529,10 +1597,13 @@ ChunkedSegmentSealedImpl::ClearData() {
std::unique_ptr<DataArray>
ChunkedSegmentSealedImpl::fill_with_empty(FieldId field_id,
int64_t count) const {
int64_t count,
int64_t valid_count,
const void* valid_data) const {
auto& field_meta = schema_->operator[](field_id);
if (IsVectorDataType(field_meta.get_data_type())) {
return CreateEmptyVectorDataArray(count, field_meta);
return CreateEmptyVectorDataArray(
count, valid_count, valid_data, field_meta);
}
return CreateEmptyScalarDataArray(count, field_meta);
}
@ -1682,8 +1753,22 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
AssertInfo(column != nullptr,
"field {} must exist when getting raw data",
field_id.get());
auto ret = fill_with_empty(field_id, count);
if (column->IsNullable()) {
int64_t valid_count = count;
const bool* valid_data = nullptr;
const int64_t* valid_offsets = seg_offsets;
ValidResult filter_result;
if (field_meta.is_vector() && field_meta.is_nullable()) {
filter_result =
FilterVectorValidOffsets(op_ctx, field_id, seg_offsets, count);
valid_count = filter_result.valid_count;
valid_data = filter_result.valid_data.get();
valid_offsets = filter_result.valid_offsets.data();
}
auto ret = fill_with_empty(field_id, count, valid_count, valid_data);
if (!field_meta.is_vector() && column->IsNullable()) {
auto dst = ret->mutable_valid_data()->mutable_data();
column->BulkIsValid(
op_ctx,
@ -1691,6 +1776,7 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
seg_offsets,
count);
}
switch (field_meta.get_data_type()) {
case DataType::VARCHAR:
case DataType::STRING:
@ -1827,8 +1913,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
bulk_subscript_impl(op_ctx,
field_meta.get_sizeof(),
column.get(),
seg_offsets,
count,
valid_offsets,
valid_count,
ret->mutable_vectors()
->mutable_float_vector()
->mutable_data()
@ -1840,8 +1926,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
op_ctx,
field_meta.get_sizeof(),
column.get(),
seg_offsets,
count,
valid_offsets,
valid_count,
ret->mutable_vectors()->mutable_float16_vector()->data());
break;
}
@ -1850,8 +1936,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
op_ctx,
field_meta.get_sizeof(),
column.get(),
seg_offsets,
count,
valid_offsets,
valid_count,
ret->mutable_vectors()->mutable_bfloat16_vector()->data());
break;
}
@ -1860,8 +1946,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
op_ctx,
field_meta.get_sizeof(),
column.get(),
seg_offsets,
count,
valid_offsets,
valid_count,
ret->mutable_vectors()->mutable_binary_vector()->data());
break;
}
@ -1870,8 +1956,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
op_ctx,
field_meta.get_sizeof(),
column.get(),
seg_offsets,
count,
valid_offsets,
valid_count,
ret->mutable_vectors()->mutable_int8_vector()->data());
break;
}
@ -1881,7 +1967,7 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
column->BulkValueAt(
op_ctx,
[&](const char* value, size_t i) mutable {
auto offset = seg_offsets[i];
auto offset = valid_offsets[i];
auto row =
offset != INVALID_SEG_OFFSET
? static_cast<const knowhere::sparse::SparseRow<
@ -1895,8 +1981,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
max_dim = std::max(max_dim, row->dim());
dst->add_contents(row->data(), row->data_byte_size());
},
seg_offsets,
count);
valid_offsets,
valid_count);
dst->set_dim(max_dim);
ret->mutable_vectors()->set_dim(dst->dim());
break;
@ -1905,8 +1991,8 @@ ChunkedSegmentSealedImpl::get_raw_data(milvus::OpContext* op_ctx,
bulk_subscript_vector_array_impl(
op_ctx,
column.get(),
seg_offsets,
count,
valid_offsets,
valid_count,
ret->mutable_vectors()->mutable_vector_array()->mutable_data());
break;
}
@ -2267,7 +2353,12 @@ ChunkedSegmentSealedImpl::generate_interim_index(const FieldId field_id,
return false;
}
try {
int64_t row_count = num_rows;
std::shared_ptr<ChunkedColumnInterface> vec_data = get_column(field_id);
AssertInfo(
vec_data != nullptr, "vector field {} not loaded", field_id.get());
int64_t row_count = field_meta.is_nullable()
? vec_data->GetOffsetMapping().GetValidCount()
: num_rows;
// generate index params
auto field_binlog_config = std::unique_ptr<VecIndexConfig>(
@ -2279,9 +2370,6 @@ ChunkedSegmentSealedImpl::generate_interim_index(const FieldId field_id,
if (row_count < field_binlog_config->GetBuildThreshold()) {
return false;
}
std::shared_ptr<ChunkedColumnInterface> vec_data = get_column(field_id);
AssertInfo(
vec_data != nullptr, "vector field {} not loaded", field_id.get());
auto dim = is_sparse ? std::numeric_limits<uint32_t>::max()
: field_meta.get_dim();
auto interim_index_type = field_binlog_config->GetIndexType();
@ -2397,6 +2485,10 @@ ChunkedSegmentSealedImpl::load_field_data_common(
return;
}
if (column->IsNullable() && IsVectorDataType(data_type)) {
column->BuildValidRowIds(nullptr);
}
if (!enable_mmap) {
if (!is_proxy_column ||
is_proxy_column &&
@ -2517,12 +2609,7 @@ ChunkedSegmentSealedImpl::Reopen(SchemaPtr sch) {
auto absent_fields = sch->AbsentFields(*schema_);
for (const auto& field_meta : *absent_fields) {
// vector field is not supported to be "added field", thus if a vector
// field is absent, it means for some reason we want to skip loading this
// field.
if (!IsVectorDataType(field_meta.get_data_type())) {
fill_empty_field(field_meta);
}
fill_empty_field(field_meta);
}
schema_ = sch;
@ -2597,10 +2684,6 @@ ChunkedSegmentSealedImpl::FinishLoad() {
// no filling fields that index already loaded and has raw data
continue;
}
if (IsVectorDataType(field_meta.get_data_type())) {
// no filling vector fields
continue;
}
fill_empty_field(field_meta);
}
}
@ -2608,10 +2691,11 @@ ChunkedSegmentSealedImpl::FinishLoad() {
void
ChunkedSegmentSealedImpl::fill_empty_field(const FieldMeta& field_meta) {
auto field_id = field_meta.get_id();
auto data_type = field_meta.get_data_type();
LOG_INFO(
"start fill empty field {} (data type {}) for sealed segment "
"{}",
field_meta.get_data_type(),
data_type,
field_id.get(),
id_);
int64_t size = num_rows_.value();
@ -2620,40 +2704,11 @@ ChunkedSegmentSealedImpl::fill_empty_field(const FieldMeta& field_meta) {
std::unique_ptr<Translator<milvus::Chunk>> translator =
std::make_unique<storagev1translator::DefaultValueChunkTranslator>(
get_segment_id(), field_meta, field_data_info, false);
std::shared_ptr<milvus::ChunkedColumnBase> column{};
switch (field_meta.get_data_type()) {
case milvus::DataType::STRING:
case milvus::DataType::VARCHAR:
case milvus::DataType::TEXT: {
column = std::make_shared<ChunkedVariableColumn<std::string>>(
std::move(translator), field_meta);
break;
}
case milvus::DataType::JSON: {
column = std::make_shared<ChunkedVariableColumn<milvus::Json>>(
std::move(translator), field_meta);
break;
}
case milvus::DataType::GEOMETRY: {
column = std::make_shared<ChunkedVariableColumn<std::string>>(
std::move(translator), field_meta);
break;
}
case milvus::DataType::ARRAY: {
column = std::make_shared<ChunkedArrayColumn>(std::move(translator),
field_meta);
break;
}
case milvus::DataType::VECTOR_ARRAY: {
column = std::make_shared<ChunkedVectorArrayColumn>(
std::move(translator), field_meta);
break;
}
default: {
column = std::make_shared<ChunkedColumn>(std::move(translator),
field_meta);
break;
}
auto column =
MakeChunkedColumnBase(data_type, std::move(translator), field_meta);
if (column->IsNullable() && IsVectorDataType(data_type)) {
column->BuildValidRowIds(nullptr);
}
fields_.wlock()->emplace(field_id, column);

View File

@ -880,7 +880,10 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
google::protobuf::RepeatedPtrField<T>* dst);
std::unique_ptr<DataArray>
fill_with_empty(FieldId field_id, int64_t count) const;
fill_with_empty(FieldId field_id,
int64_t count,
int64_t valid_count = 0,
const void* valid_data = nullptr) const;
std::unique_ptr<DataArray>
get_raw_data(milvus::OpContext* op_ctx,
@ -889,6 +892,18 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
const int64_t* seg_offsets,
int64_t count) const;
// Result of checking a batch of segment offsets against a nullable vector
// field's validity data (see FilterVectorValidOffsets below).
struct ValidResult {
// Number of requested offsets whose rows are non-null.
int64_t valid_count = 0;
// Per-offset validity flags for the requested batch.
// NOTE(review): presumably length == the `count` passed to the filter —
// confirm against the .cpp implementation.
std::unique_ptr<bool[]> valid_data;
// Segment offsets of the valid (non-null) rows, in input order —
// TODO confirm ordering against the implementation.
std::vector<int64_t> valid_offsets;
};
// Filters `count` entries of `seg_offsets` for nullable vector field
// `field_id`, producing the validity summary above. Used by bulk_subscript
// to fetch only physically-stored (non-null) vectors.
ValidResult
FilterVectorValidOffsets(milvus::OpContext* op_ctx,
FieldId field_id,
const int64_t* seg_offsets,
int64_t count) const;
void
update_row_count(int64_t row_count) {
num_rows_ = row_count;

View File

@ -17,6 +17,7 @@
#include <atomic>
#include <cassert>
#include <deque>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
@ -29,6 +30,7 @@
#include "common/FieldMeta.h"
#include "common/FieldData.h"
#include "common/Json.h"
#include "common/OffsetMapping.h"
#include "common/Span.h"
#include "common/Types.h"
#include "common/Utils.h"
@ -79,19 +81,30 @@ class ThreadSafeValidData {
}
bool
is_valid(size_t offset) {
is_valid(size_t offset) const {
std::shared_lock<std::shared_mutex> lck(mutex_);
Assert(offset < length_);
AssertInfo(offset < length_,
"offset out of range, offset={}, length_={}",
offset,
length_);
return data_[offset];
}
bool*
get_chunk_data(size_t offset) {
std::shared_lock<std::shared_mutex> lck(mutex_);
Assert(offset < length_);
AssertInfo(offset < length_,
"offset out of range, offset={}, length_={}",
offset,
length_);
return &data_[offset];
}
const FixedVector<bool>&
get_data() const {
return data_;
}
private:
mutable std::shared_mutex mutex_{};
FixedVector<bool> data_;
@ -155,6 +168,40 @@ class VectorBase {
virtual void
clear() = 0;
virtual bool
is_mapping_storage() const {
return false;
}
// Translate a logical (row) offset into a physical (storage) offset.
// The base-class default is the identity mapping — it never returns -1.
// Overrides backed by an OffsetMapping (mapping storage for nullable
// vectors) return -1 when the logical offset has no stored row (null row).
virtual int64_t
get_physical_offset(int64_t logical_offset) const {
return logical_offset;  // default: no mapping
}
// Translate a physical (storage) offset back into a logical (row) offset.
// The base-class default is the identity mapping — it never returns -1.
// Mapping-storage overrides delegate to OffsetMapping and presumably
// return -1 for out-of-range physical offsets — TODO confirm.
virtual int64_t
get_logical_offset(int64_t physical_offset) const {
return physical_offset;  // default: no mapping
}
virtual int64_t
get_valid_count() const {
return 0;
}
virtual const FixedVector<bool>&
get_valid_data() const {
static const FixedVector<bool> empty;
return empty;
}
virtual const OffsetMapping&
get_offset_mapping() const {
static const OffsetMapping empty;
return empty;
}
protected:
const int64_t size_per_chunk_;
};
@ -191,10 +238,12 @@ class ConcurrentVectorImpl : public VectorBase {
ssize_t elements_per_row,
int64_t size_per_chunk,
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
bool use_mapping_storage = false)
: VectorBase(size_per_chunk),
elements_per_row_(is_type_entire_row ? 1 : elements_per_row),
valid_data_ptr_(valid_data_ptr) {
valid_data_ptr_(valid_data_ptr),
use_mapping_storage_(use_mapping_storage) {
chunks_ptr_ = SelectChunkVectorPtr<Type>(mmap_descriptor);
}
@ -221,19 +270,7 @@ class ConcurrentVectorImpl : public VectorBase {
void
fill_chunk_data(const std::vector<FieldDataPtr>& datas) override {
AssertInfo(chunks_ptr_->size() == 0, "non empty concurrent vector");
int64_t element_count = 0;
for (auto& field_data : datas) {
element_count += field_data->get_num_rows();
}
chunks_ptr_->emplace_to_at_least(1, elements_per_row_ * element_count);
int64_t offset = 0;
for (auto& field_data : datas) {
auto num_rows = field_data->get_num_rows();
set_data(
offset, static_cast<const Type*>(field_data->Data()), num_rows);
offset += num_rows;
}
set_data_raw(0, datas);
}
void
@ -250,16 +287,39 @@ class ConcurrentVectorImpl : public VectorBase {
set_data_raw(ssize_t element_offset,
const void* source,
ssize_t element_count) override {
if (element_count == 0) {
return;
ssize_t valid_count = 0;
ssize_t storage_offset = 0;
if (use_mapping_storage_) {
if constexpr (!std::is_same_v<Type, bool>) {
storage_offset = offset_mapping_.GetNextPhysicalOffset();
// Build valid_data array for offset mapping
std::unique_ptr<bool[]> valid_data(new bool[element_count]);
for (ssize_t i = 0; i < element_count; ++i) {
bool is_valid =
valid_data_ptr_->is_valid(element_offset + i);
valid_data[i] = is_valid;
if (is_valid) {
valid_count++;
}
}
offset_mapping_.BuildIncremental(valid_data.get(),
element_count,
element_offset,
storage_offset);
}
} else {
valid_count = element_count;
storage_offset = element_offset;
}
if (valid_count > 0) {
auto size = size_per_chunk_ == MAX_ROW_COUNT ? valid_count
: size_per_chunk_;
chunks_ptr_->emplace_to_at_least(
upper_div(storage_offset + valid_count, size),
elements_per_row_ * size);
set_data(
storage_offset, static_cast<const Type*>(source), valid_count);
}
auto size =
size_per_chunk_ == MAX_ROW_COUNT ? element_count : size_per_chunk_;
chunks_ptr_->emplace_to_at_least(
upper_div(element_offset + element_count, size),
elements_per_row_ * size);
set_data(
element_offset, static_cast<const Type*>(source), element_count);
}
const void*
@ -297,8 +357,24 @@ class ConcurrentVectorImpl : public VectorBase {
// just for fun, don't use it directly
const Type*
get_element(ssize_t element_index) const {
auto chunk_id = element_index / size_per_chunk_;
auto chunk_offset = element_index % size_per_chunk_;
auto physical_index = offset_mapping_.GetPhysicalOffset(element_index);
if (physical_index == -1) {
return nullptr;
}
auto chunk_id = physical_index / size_per_chunk_;
auto chunk_offset = physical_index % size_per_chunk_;
auto data =
static_cast<const Type*>(chunks_ptr_->get_chunk_data(chunk_id));
return data + chunk_offset * elements_per_row_;
}
const Type*
get_physical_element(ssize_t physical_index) const {
if (physical_index == -1) {
return nullptr;
}
auto chunk_id = physical_index / size_per_chunk_;
auto chunk_offset = physical_index % size_per_chunk_;
auto data =
static_cast<const Type*>(chunks_ptr_->get_chunk_data(chunk_id));
return data + chunk_offset * elements_per_row_;
@ -344,6 +420,40 @@ class ConcurrentVectorImpl : public VectorBase {
return chunks_ptr_->is_mmap();
}
bool
is_mapping_storage() const override {
return use_mapping_storage_;
}
int64_t
get_physical_offset(int64_t logical_offset) const override {
return offset_mapping_.GetPhysicalOffset(logical_offset);
}
int64_t
get_logical_offset(int64_t physical_offset) const override {
return offset_mapping_.GetLogicalOffset(physical_offset);
}
int64_t
get_valid_count() const override {
return offset_mapping_.GetValidCount();
}
const milvus::OffsetMapping&
get_offset_mapping() const override {
return offset_mapping_;
}
const FixedVector<bool>&
get_valid_data() const override {
if (valid_data_ptr_ != nullptr) {
return valid_data_ptr_->get_data();
}
static const FixedVector<bool> empty;
return empty;
}
private:
void
set_data(ssize_t element_offset,
@ -395,9 +505,10 @@ class ConcurrentVectorImpl : public VectorBase {
fmt::format("chunk_id out of chunk num, chunk_id={}, chunk_num={}",
chunk_id,
chunk_num));
size_t chunk_id_offset = chunk_id * size_per_chunk_ * elements_per_row_;
std::optional<CheckDataValid> check_data_valid = std::nullopt;
if (valid_data_ptr_ != nullptr) {
if (valid_data_ptr_ != nullptr && !use_mapping_storage_) {
size_t chunk_id_offset =
chunk_id * size_per_chunk_ * elements_per_row_;
check_data_valid = [valid_data_ptr = valid_data_ptr_,
beg_id = chunk_id_offset](size_t offset) {
return valid_data_ptr->is_valid(beg_id + offset);
@ -414,6 +525,9 @@ class ConcurrentVectorImpl : public VectorBase {
const ssize_t elements_per_row_;
ChunkVectorPtr<Type> chunks_ptr_ = nullptr;
ThreadSafeValidDataPtr valid_data_ptr_ = nullptr;
const bool use_mapping_storage_;
milvus::OffsetMapping offset_mapping_;
};
template <typename Type>
@ -496,9 +610,14 @@ class ConcurrentVector<VectorArray>
int64_t dim /* not use it*/,
int64_t size_per_chunk,
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
bool use_mapping_storage = false)
: ConcurrentVectorImpl<VectorArray, true>::ConcurrentVectorImpl(
1, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) {
1,
size_per_chunk,
std::move(mmap_descriptor),
valid_data_ptr,
use_mapping_storage) {
}
};
@ -510,13 +629,15 @@ class ConcurrentVector<SparseFloatVector>
explicit ConcurrentVector(
int64_t size_per_chunk,
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
bool use_mapping_storage = false)
: ConcurrentVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
true>::ConcurrentVectorImpl(1,
size_per_chunk,
std::move(
mmap_descriptor),
valid_data_ptr),
valid_data_ptr,
use_mapping_storage),
dim_(0) {
}
@ -527,7 +648,16 @@ class ConcurrentVector<SparseFloatVector>
auto* src =
static_cast<const knowhere::sparse::SparseRow<SparseValueType>*>(
source);
for (int i = 0; i < element_count; ++i) {
ssize_t source_count = element_count;
if (this->use_mapping_storage_) {
source_count = 0;
for (ssize_t i = 0; i < element_count; ++i) {
if (this->valid_data_ptr_->is_valid(element_offset + i)) {
source_count++;
}
}
}
for (ssize_t i = 0; i < source_count; ++i) {
dim_ = std::max(dim_, src[i].dim());
}
ConcurrentVectorImpl<knowhere::sparse::SparseRow<SparseValueType>,
@ -552,9 +682,14 @@ class ConcurrentVector<FloatVector>
ConcurrentVector(int64_t dim,
int64_t size_per_chunk,
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
bool use_mapping_storage = false)
: ConcurrentVectorImpl<float, false>::ConcurrentVectorImpl(
dim, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) {
dim,
size_per_chunk,
std::move(mmap_descriptor),
valid_data_ptr,
use_mapping_storage) {
}
};
@ -566,11 +701,13 @@ class ConcurrentVector<BinaryVector>
int64_t dim,
int64_t size_per_chunk,
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
bool use_mapping_storage = false)
: ConcurrentVectorImpl(dim / 8,
size_per_chunk,
std::move(mmap_descriptor),
valid_data_ptr) {
valid_data_ptr,
use_mapping_storage) {
AssertInfo(dim % 8 == 0,
fmt::format("dim is not a multiple of 8, dim={}", dim));
}
@ -583,9 +720,14 @@ class ConcurrentVector<Float16Vector>
ConcurrentVector(int64_t dim,
int64_t size_per_chunk,
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
bool use_mapping_storage = false)
: ConcurrentVectorImpl<float16, false>::ConcurrentVectorImpl(
dim, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) {
dim,
size_per_chunk,
std::move(mmap_descriptor),
valid_data_ptr,
use_mapping_storage) {
}
};
@ -596,9 +738,14 @@ class ConcurrentVector<BFloat16Vector>
ConcurrentVector(int64_t dim,
int64_t size_per_chunk,
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
bool use_mapping_storage = false)
: ConcurrentVectorImpl<bfloat16, false>::ConcurrentVectorImpl(
dim, size_per_chunk, std::move(mmap_descriptor), valid_data_ptr) {
dim,
size_per_chunk,
std::move(mmap_descriptor),
valid_data_ptr,
use_mapping_storage) {
}
};
@ -608,9 +755,14 @@ class ConcurrentVector<Int8Vector> : public ConcurrentVectorImpl<int8, false> {
ConcurrentVector(int64_t dim,
int64_t size_per_chunk,
storage::MmapChunkDescriptorPtr mmap_descriptor = nullptr,
ThreadSafeValidDataPtr valid_data_ptr = nullptr)
ThreadSafeValidDataPtr valid_data_ptr = nullptr,
bool use_mapping_storage = false)
: ConcurrentVectorImpl<int8, false>::ConcurrentVectorImpl(
dim, size_per_chunk, std::move(mmap_descriptor)) {
dim,
size_per_chunk,
std::move(mmap_descriptor),
valid_data_ptr,
use_mapping_storage) {
}
};

View File

@ -9,8 +9,10 @@
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include <map>
#include <string>
#include <thread>
#include <unordered_map>
#include "common/EasyAssert.h"
#include "common/Types.h"
@ -19,6 +21,7 @@
#include "index/StringIndexMarisa.h"
#include "common/SystemProperty.h"
#include "segcore/ConcurrentVector.h"
#include "segcore/FieldIndexing.h"
#include "index/VectorMemIndex.h"
#include "IndexConfigGenerator.h"
@ -29,6 +32,104 @@
namespace milvus::segcore {
using std::unique_ptr;
// Append a freshly inserted batch to the growing index of `fieldId`, if one
// is registered. DataArray (proto) overload, used on the insert path.
//
// For nullable vector fields stored compactly (mapping storage), the build
// threshold is compared against the count of valid (non-null) rows instead
// of the logical row count, since null rows are not physically stored.
void
IndexingRecord::AppendingIndex(int64_t reserved_offset,
                               int64_t size,
                               FieldId fieldId,
                               const DataArray* stream_data,
                               const InsertRecord<false>& record) {
    // No growing index registered for this field: nothing to do.
    if (!is_in(fieldId)) {
        return;
    }
    auto& field_indexing = field_indexings_.at(fieldId);
    auto data_type = field_indexing->get_data_type();
    auto raw_base = record.get_data_base(fieldId);
    auto field_meta = schema_.get_fields().at(fieldId);

    // Rows available for index building: logical row count by default,
    // valid-row count for nullable fields with compact (mapping) storage.
    int64_t valid_count = reserved_offset + size;
    if (field_meta.is_nullable() && raw_base->is_mapping_storage()) {
        valid_count = raw_base->get_valid_count();
    }

    switch (data_type) {
        case DataType::VECTOR_FLOAT: {
            if (valid_count >= field_indexing->get_build_threshold()) {
                field_indexing->AppendSegmentIndexDense(
                    reserved_offset,
                    size,
                    raw_base,
                    stream_data->vectors().float_vector().data().data());
            }
            break;
        }
        case DataType::VECTOR_FLOAT16: {
            if (valid_count >= field_indexing->get_build_threshold()) {
                field_indexing->AppendSegmentIndexDense(
                    reserved_offset,
                    size,
                    raw_base,
                    stream_data->vectors().float16_vector().data());
            }
            break;
        }
        case DataType::VECTOR_BFLOAT16: {
            if (valid_count >= field_indexing->get_build_threshold()) {
                field_indexing->AppendSegmentIndexDense(
                    reserved_offset,
                    size,
                    raw_base,
                    stream_data->vectors().bfloat16_vector().data());
            }
            break;
        }
        case DataType::VECTOR_SPARSE_U32_F32: {
            if (valid_count >= field_indexing->get_build_threshold()) {
                auto rows = SparseBytesToRows(
                    stream_data->vectors().sparse_float_vector().contents());
                field_indexing->AppendSegmentIndexSparse(
                    reserved_offset,
                    size,
                    stream_data->vectors().sparse_float_vector().dim(),
                    raw_base,
                    rows.get());
            }
            break;
        }
        case DataType::GEOMETRY: {
            // Geometry rows are appended to the RTree index incrementally,
            // with no build threshold.
            field_indexing->AppendSegmentIndex(
                reserved_offset, size, raw_base, stream_data);
            break;
        }
        default:
            break;
    }
}
// concurrent, reentrant
void
IndexingRecord::AppendingIndex(int64_t reserved_offset,
int64_t size,
FieldId fieldId,
const FieldDataPtr data,
const InsertRecord<false>& record) {
if (!is_in(fieldId)) {
return;
}
auto& indexing = field_indexings_.at(fieldId);
auto type = indexing->get_data_type();
const void* p = data->Data();
auto vec_base = record.get_data_base(fieldId);
auto field_meta = schema_.get_fields().at(fieldId);
int64_t valid_count = reserved_offset + size;
if (field_meta.is_nullable() && vec_base->is_mapping_storage()) {
valid_count = vec_base->get_valid_count();
}
if ((type == DataType::VECTOR_FLOAT || type == DataType::VECTOR_FLOAT16 ||
type == DataType::VECTOR_BFLOAT16) &&
valid_count >= indexing->get_build_threshold()) {
auto vec_base = record.get_data_base(fieldId);
indexing->AppendSegmentIndexDense(
reserved_offset, size, vec_base, data->Data());
} else if (type == DataType::VECTOR_SPARSE_U32_F32 &&
valid_count >= indexing->get_build_threshold()) {
auto vec_base = record.get_data_base(fieldId);
indexing->AppendSegmentIndexSparse(
reserved_offset,
size,
std::dynamic_pointer_cast<const FieldData<SparseFloatVector>>(data)
->Dim(),
vec_base,
p);
} else if (type == DataType::GEOMETRY) {
// For geometry fields, append data incrementally to RTree index
auto vec_base = record.get_data_base(fieldId);
indexing->AppendSegmentIndex(reserved_offset, size, vec_base, data);
}
}
VectorFieldIndexing::VectorFieldIndexing(const FieldMeta& field_meta,
const FieldIndexMeta& field_index_meta,
int64_t segment_max_row_count,
@ -140,54 +241,133 @@ VectorFieldIndexing::AppendSegmentIndexSparse(int64_t reserved_offset,
int64_t new_data_dim,
const VectorBase* field_raw_data,
const void* data_source) {
using value_type = knowhere::sparse::SparseRow<SparseValueType>;
AssertInfo(get_data_type() == DataType::VECTOR_SPARSE_U32_F32,
"Data type of vector field is not VECTOR_SPARSE_U32_F32");
auto conf = get_build_params(get_data_type());
auto source = dynamic_cast<const ConcurrentVector<SparseFloatVector>*>(
field_raw_data);
AssertInfo(source,
auto field_source =
dynamic_cast<const ConcurrentVector<SparseFloatVector>*>(
field_raw_data);
AssertInfo(field_source,
"field_raw_data can't cast to "
"ConcurrentVector<SparseFloatVector> type");
AssertInfo(size > 0, "append 0 sparse rows to index is not allowed");
if (!built_) {
AssertInfo(!sync_with_index_, "index marked synced before built");
idx_t total_rows = reserved_offset + size;
idx_t chunk_id = 0;
auto dim = source->Dim();
auto source = static_cast<const value_type*>(data_source);
while (total_rows > 0) {
auto mat = static_cast<
const knowhere::sparse::SparseRow<SparseValueType>*>(
source->get_chunk_data(chunk_id));
auto rows = std::min(source->get_size_per_chunk(), total_rows);
auto dataset = knowhere::GenDataSet(rows, dim, mat);
dataset->SetIsSparse(true);
try {
if (chunk_id == 0) {
index_->BuildWithDataset(dataset, conf);
} else {
index_->AddWithDataset(dataset, conf);
auto dim = new_data_dim;
auto size_per_chunk = field_raw_data->get_size_per_chunk();
auto build_threshold = get_build_threshold();
bool is_mapping_storage = field_raw_data->is_mapping_storage();
auto& valid_data = field_raw_data->get_valid_data();
if (!built_) {
const void* data_ptr = nullptr;
std::vector<value_type> data_buf;
int64_t start_chunk = 0;
int64_t end_chunk = (build_threshold - 1) / size_per_chunk;
if (start_chunk == end_chunk) {
data_ptr = field_raw_data->get_chunk_data(start_chunk);
} else {
data_buf.resize(build_threshold);
int64_t actual_copy_count = 0;
for (int64_t chunk_id = start_chunk; chunk_id <= end_chunk;
++chunk_id) {
int64_t copy_start =
std::max((int64_t)0, chunk_id * size_per_chunk);
int64_t copy_end =
std::min(build_threshold, (chunk_id + 1) * size_per_chunk);
int64_t copy_count = copy_end - copy_start;
// For mapping storage, chunk data is already compactly stored,
// so we can copy directly from chunk
auto chunk_data = static_cast<const value_type*>(
field_raw_data->get_chunk_data(chunk_id));
int64_t chunk_offset = copy_start - chunk_id * size_per_chunk;
for (int64_t i = 0; i < copy_count; ++i) {
data_buf[actual_copy_count + i] =
chunk_data[chunk_offset + i];
}
} catch (SegcoreError& error) {
LOG_ERROR("growing sparse index build error: {}", error.what());
recreate_index(get_data_type(), nullptr);
index_cur_ = 0;
return;
actual_copy_count += copy_count;
}
index_cur_.fetch_add(rows);
total_rows -= rows;
chunk_id++;
data_ptr = data_buf.data();
}
auto dataset = knowhere::GenDataSet(build_threshold, dim, data_ptr);
dataset->SetIsSparse(true);
try {
index_->BuildWithDataset(dataset, conf);
if (is_mapping_storage) {
auto logical_offset =
field_raw_data->get_logical_offset(build_threshold - 1);
auto update_count = logical_offset + 1;
index_->UpdateValidData(valid_data.data(), update_count);
}
built_ = true;
index_cur_.fetch_add(build_threshold);
} catch (SegcoreError& error) {
LOG_ERROR("growing sparse index build error: {}", error.what());
recreate_index(get_data_type(), field_raw_data);
return;
}
built_ = true;
sync_with_index_ = true;
// if not built_, new rows in data_source have already been added to
// source(ConcurrentVector<SparseFloatVector>) and thus added to the
// index, thus no need to add again.
return;
}
auto dataset = knowhere::GenDataSet(size, new_data_dim, data_source);
dataset->SetIsSparse(true);
index_->AddWithDataset(dataset, conf);
index_cur_.fetch_add(size);
// Append rest data when index has been built
int64_t add_count = 0;
int64_t total_count = 0;
if (valid_data.empty()) {
// Non-nullable case: add all rows
add_count = reserved_offset + size - index_cur_.load();
total_count = size;
if (add_count <= 0) {
sync_with_index_.store(true);
return;
}
auto data_ptr = source + (total_count - add_count);
auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr);
dataset->SetIsSparse(true);
try {
index_->AddWithDataset(dataset, conf);
index_cur_.fetch_add(add_count);
sync_with_index_.store(true);
} catch (SegcoreError& error) {
LOG_ERROR("growing sparse index add error: {}", error.what());
recreate_index(get_data_type(), field_raw_data);
}
} else {
// Nullable case: only add valid rows (matching dense vector approach)
auto index_total_count = index_->GetOffsetMapping().GetTotalCount();
auto add_valid_data_count = reserved_offset + size - index_total_count;
for (auto i = reserved_offset; i < reserved_offset + size; i++) {
if (valid_data[i]) {
total_count++;
if (i >= index_total_count) {
add_count++;
}
}
}
if (add_count <= 0 && add_valid_data_count <= 0) {
sync_with_index_.store(true);
return;
}
if (add_count > 0) {
auto data_ptr = source + (total_count - add_count);
auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr);
dataset->SetIsSparse(true);
try {
index_->AddWithDataset(dataset, conf);
} catch (SegcoreError& error) {
LOG_ERROR("growing sparse index add error: {}", error.what());
recreate_index(get_data_type(), field_raw_data);
}
}
if (add_valid_data_count > 0) {
index_->UpdateValidData(valid_data.data() + index_total_count,
add_valid_data_count);
}
index_cur_.fetch_add(add_count);
sync_with_index_.store(true);
}
}
void
@ -203,8 +383,10 @@ VectorFieldIndexing::AppendSegmentIndexDense(int64_t reserved_offset,
auto dim = get_dim();
auto conf = get_build_params(get_data_type());
auto size_per_chunk = field_raw_data->get_size_per_chunk();
//append vector [vector_id_beg, vector_id_end] into index
//build index [vector_id_beg, build_threshold) when index not exist
auto build_threshold = get_build_threshold();
bool is_mapping_storage = field_raw_data->is_mapping_storage();
auto& valid_data = field_raw_data->get_valid_data();
AssertInfo(ConcurrentDenseVectorCheck(field_raw_data, get_data_type()),
"vec_base can't cast to ConcurrentVector type");
size_t vec_length;
@ -216,88 +398,112 @@ VectorFieldIndexing::AppendSegmentIndexDense(int64_t reserved_offset,
vec_length = dim * sizeof(bfloat16);
}
if (!built_) {
idx_t vector_id_beg = index_cur_.load();
Assert(vector_id_beg == 0);
idx_t vector_id_end = get_build_threshold() - 1;
auto chunk_id_beg = vector_id_beg / size_per_chunk;
auto chunk_id_end = vector_id_end / size_per_chunk;
const void* data_ptr;
std::unique_ptr<char[]> data_buf;
// Chunk data stores valid vectors compactly for both nullable and non-nullable
int64_t start_chunk = 0;
int64_t end_chunk = (build_threshold - 1) / size_per_chunk;
int64_t vec_num = vector_id_end - vector_id_beg + 1;
// for train index
const void* data_addr;
unique_ptr<char[]> vec_data;
//all train data in one chunk
if (chunk_id_beg == chunk_id_end) {
data_addr = field_raw_data->get_chunk_data(chunk_id_beg);
if (start_chunk == end_chunk) {
auto chunk_data = static_cast<const char*>(
field_raw_data->get_chunk_data(start_chunk));
data_ptr = chunk_data;
} else {
//merge data from multiple chunks together
vec_data = std::make_unique<char[]>(vec_num * vec_length);
int64_t offset = 0;
//copy vector data [vector_id_beg, vector_id_end]
for (int chunk_id = chunk_id_beg; chunk_id <= chunk_id_end;
chunk_id++) {
int chunk_offset = 0;
int chunk_copysz =
chunk_id == chunk_id_end
? vector_id_end - chunk_id * size_per_chunk + 1
: size_per_chunk;
std::memcpy(
(void*)((const char*)vec_data.get() + offset * vec_length),
(void*)((const char*)field_raw_data->get_chunk_data(
chunk_id) +
chunk_offset * vec_length),
chunk_copysz * vec_length);
offset += chunk_copysz;
data_buf = std::make_unique<char[]>(build_threshold * vec_length);
int64_t actual_copy_count = 0;
for (int64_t chunk_id = start_chunk; chunk_id <= end_chunk;
++chunk_id) {
auto chunk_data = static_cast<const char*>(
field_raw_data->get_chunk_data(chunk_id));
int64_t copy_start =
std::max((int64_t)0, chunk_id * size_per_chunk);
int64_t copy_end =
std::min(build_threshold, (chunk_id + 1) * size_per_chunk);
int64_t copy_count = copy_end - copy_start;
auto src =
chunk_data +
(copy_start - chunk_id * size_per_chunk) * vec_length;
std::memcpy(data_buf.get() + actual_copy_count * vec_length,
src,
copy_count * vec_length);
actual_copy_count += copy_count;
}
data_addr = vec_data.get();
data_ptr = data_buf.get();
}
auto dataset = knowhere::GenDataSet(vec_num, dim, data_addr);
dataset->SetIsOwner(false);
auto dataset = knowhere::GenDataSet(build_threshold, dim, data_ptr);
try {
index_->BuildWithDataset(dataset, conf);
if (is_mapping_storage) {
auto logical_offset =
field_raw_data->get_logical_offset(build_threshold - 1);
auto update_count = logical_offset + 1;
index_->UpdateValidData(valid_data.data(), update_count);
}
built_ = true;
index_cur_.fetch_add(build_threshold);
} catch (SegcoreError& error) {
LOG_ERROR("growing index build error: {}", error.what());
recreate_index(get_data_type(), field_raw_data);
return;
}
index_cur_.fetch_add(vec_num);
built_ = true;
}
//append rest data when index has built
idx_t vector_id_beg = index_cur_.load();
idx_t vector_id_end = reserved_offset + size - 1;
auto chunk_id_beg = vector_id_beg / size_per_chunk;
auto chunk_id_end = vector_id_end / size_per_chunk;
int64_t vec_num = vector_id_end - vector_id_beg + 1;
if (vec_num <= 0) {
sync_with_index_.store(true);
return;
}
if (sync_with_index_.load()) {
Assert(size == vec_num);
auto dataset = knowhere::GenDataSet(vec_num, dim, data_source);
index_->AddWithDataset(dataset, conf);
index_cur_.fetch_add(vec_num);
} else {
for (int chunk_id = chunk_id_beg; chunk_id <= chunk_id_end;
chunk_id++) {
int chunk_offset = chunk_id == chunk_id_beg
? index_cur_ - chunk_id * size_per_chunk
: 0;
int chunk_sz =
chunk_id == chunk_id_end
? vector_id_end % size_per_chunk - chunk_offset + 1
: size_per_chunk - chunk_offset;
auto dataset = knowhere::GenDataSet(
chunk_sz,
dim,
(const char*)field_raw_data->get_chunk_data(chunk_id) +
chunk_offset * vec_length);
index_->AddWithDataset(dataset, conf);
index_cur_.fetch_add(chunk_sz);
int64_t add_count = 0;
int64_t total_count = 0;
if (valid_data.empty()) {
add_count = reserved_offset + size - index_cur_.load();
total_count = size;
if (add_count <= 0) {
sync_with_index_.store(true);
return;
}
auto data_ptr = static_cast<const char*>(data_source) +
(total_count - add_count) * vec_length;
auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr);
try {
index_->AddWithDataset(dataset, conf);
index_cur_.fetch_add(add_count);
sync_with_index_.store(true);
} catch (SegcoreError& error) {
LOG_ERROR("growing index add error: {}", error.what());
recreate_index(get_data_type(), field_raw_data);
}
} else {
// Nullable dense vectors: data_source (proto) contains valid vectors compactly
auto index_total_count = index_->GetOffsetMapping().GetTotalCount();
auto add_valid_data_count = reserved_offset + size - index_total_count;
auto index_cur_val = index_cur_.load();
// Count valid vectors in this batch range
for (auto i = reserved_offset; i < reserved_offset + size; i++) {
if (valid_data[i]) {
total_count++;
if (i >= index_total_count) {
add_count++;
}
}
}
if (add_count <= 0 && add_valid_data_count <= 0) {
sync_with_index_.store(true);
return;
}
if (add_count > 0) {
// data_source contains valid vectors compactly, skip already indexed ones
auto data_ptr = static_cast<const char*>(data_source) +
(total_count - add_count) * vec_length;
auto dataset = knowhere::GenDataSet(add_count, dim, data_ptr);
try {
index_->AddWithDataset(dataset, conf);
} catch (SegcoreError& error) {
LOG_ERROR("growing index add error: {}", error.what());
recreate_index(get_data_type(), field_raw_data);
}
}
if (add_valid_data_count > 0) {
index_->UpdateValidData(valid_data.data() + index_total_count,
add_valid_data_count);
}
index_cur_.fetch_add(add_count);
sync_with_index_.store(true);
}
}

View File

@ -434,93 +434,19 @@ class IndexingRecord {
assert(offset_id == schema_.size());
}
// concurrent, reentrant
void
AppendingIndex(int64_t reserved_offset,
int64_t size,
FieldId fieldId,
const DataArray* stream_data,
const InsertRecord<false>& record) {
if (!is_in(fieldId)) {
return;
}
auto& indexing = field_indexings_.at(fieldId);
auto type = indexing->get_data_type();
auto field_raw_data = record.get_data_base(fieldId);
if (type == DataType::VECTOR_FLOAT &&
reserved_offset + size >= indexing->get_build_threshold()) {
indexing->AppendSegmentIndexDense(
reserved_offset,
size,
field_raw_data,
stream_data->vectors().float_vector().data().data());
} else if (type == DataType::VECTOR_FLOAT16 &&
reserved_offset + size >= indexing->get_build_threshold()) {
indexing->AppendSegmentIndexDense(
reserved_offset,
size,
field_raw_data,
stream_data->vectors().float16_vector().data());
} else if (type == DataType::VECTOR_BFLOAT16 &&
reserved_offset + size >= indexing->get_build_threshold()) {
indexing->AppendSegmentIndexDense(
reserved_offset,
size,
field_raw_data,
stream_data->vectors().bfloat16_vector().data());
} else if (type == DataType::VECTOR_SPARSE_U32_F32) {
auto data = SparseBytesToRows(
stream_data->vectors().sparse_float_vector().contents());
indexing->AppendSegmentIndexSparse(
reserved_offset,
size,
stream_data->vectors().sparse_float_vector().dim(),
field_raw_data,
data.get());
} else if (type == DataType::GEOMETRY) {
// For geometry fields, append data incrementally to RTree index
indexing->AppendSegmentIndex(
reserved_offset, size, field_raw_data, stream_data);
}
}
const InsertRecord<false>& record);
// concurrent, reentrant
void
AppendingIndex(int64_t reserved_offset,
int64_t size,
FieldId fieldId,
const FieldDataPtr data,
const InsertRecord<false>& record) {
if (!is_in(fieldId)) {
return;
}
auto& indexing = field_indexings_.at(fieldId);
auto type = indexing->get_data_type();
const void* p = data->Data();
if ((type == DataType::VECTOR_FLOAT ||
type == DataType::VECTOR_FLOAT16 ||
type == DataType::VECTOR_BFLOAT16) &&
reserved_offset + size >= indexing->get_build_threshold()) {
auto vec_base = record.get_data_base(fieldId);
indexing->AppendSegmentIndexDense(
reserved_offset, size, vec_base, data->Data());
} else if (type == DataType::VECTOR_SPARSE_U32_F32) {
auto vec_base = record.get_data_base(fieldId);
indexing->AppendSegmentIndexSparse(
reserved_offset,
size,
std::dynamic_pointer_cast<const FieldData<SparseFloatVector>>(
data)
->Dim(),
vec_base,
p);
} else if (type == DataType::GEOMETRY) {
// For geometry fields, append data incrementally to RTree index
auto vec_base = record.get_data_base(fieldId);
indexing->AppendSegmentIndex(reserved_offset, size, vec_base, data);
}
}
const InsertRecord<false>& record);
// for sparse float vector:
// * element_size is not used

View File

@ -87,12 +87,6 @@ VecIndexConfig::VecIndexConfig(const int64_t max_index_row_cout,
int64_t
VecIndexConfig::GetBuildThreshold() const noexcept {
// For sparse, do not impose a threshold and start using index with any
// number of rows. Unlike dense vector index, growing sparse vector index
// does not require a minimum number of rows to train.
if (is_sparse_) {
return 0;
}
auto ratio = config_.get_build_ratio();
assert(ratio >= 0.0 && ratio < 1.0);
return std::max(int64_t(max_index_row_count_ * ratio),

View File

@ -1025,7 +1025,6 @@ class InsertRecordGrowing {
}
// append a column of vector type
// vector not support nullable, not pass valid data ptr
template <typename VectorType>
void
append_data(FieldId field_id,
@ -1033,9 +1032,15 @@ class InsertRecordGrowing {
int64_t size_per_chunk,
const storage::MmapChunkDescriptorPtr mmap_descriptor) {
static_assert(std::is_base_of_v<VectorTrait, VectorType>);
data_.emplace(field_id,
std::make_unique<ConcurrentVector<VectorType>>(
dim, size_per_chunk, mmap_descriptor));
bool use_mapping_storage = is_valid_data_exist(field_id);
data_.emplace(
field_id,
std::make_unique<ConcurrentVector<VectorType>>(
dim,
size_per_chunk,
mmap_descriptor,
use_mapping_storage ? get_valid_data(field_id) : nullptr,
use_mapping_storage));
}
// append a column of scalar or sparse float vector type
@ -1045,13 +1050,23 @@ class InsertRecordGrowing {
int64_t size_per_chunk,
const storage::MmapChunkDescriptorPtr mmap_descriptor) {
static_assert(IsScalar<Type> || IsSparse<Type>);
data_.emplace(
field_id,
std::make_unique<ConcurrentVector<Type>>(
size_per_chunk,
mmap_descriptor,
is_valid_data_exist(field_id) ? get_valid_data(field_id)
: nullptr));
bool use_mapping_storage = is_valid_data_exist(field_id);
if constexpr (IsSparse<Type>) {
data_.emplace(
field_id,
std::make_unique<ConcurrentVector<Type>>(
size_per_chunk,
mmap_descriptor,
use_mapping_storage ? get_valid_data(field_id) : nullptr,
use_mapping_storage));
} else {
data_.emplace(
field_id,
std::make_unique<ConcurrentVector<Type>>(
size_per_chunk,
mmap_descriptor,
use_mapping_storage ? get_valid_data(field_id) : nullptr));
}
}
void

View File

@ -312,13 +312,13 @@ SegmentGrowingImpl::Insert(int64_t reserved_offset,
AssertInfo(field_id_to_offset.count(field_id),
fmt::format("can't find field {}", field_id.get()));
auto data_offset = field_id_to_offset[field_id];
if (field_meta.is_nullable()) {
insert_record_.get_valid_data(field_id)->set_data_raw(
num_rows,
&insert_record_proto->fields_data(data_offset),
field_meta);
}
if (!indexing_record_.HasRawData(field_id)) {
if (field_meta.is_nullable()) {
insert_record_.get_valid_data(field_id)->set_data_raw(
num_rows,
&insert_record_proto->fields_data(data_offset),
field_meta);
}
insert_record_.get_data_base(field_id)->set_data_raw(
reserved_offset,
num_rows,
@ -937,14 +937,31 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
auto& field_meta = schema_->operator[](field_id);
auto vec_ptr = insert_record_.get_data_base(field_id);
if (field_meta.is_vector()) {
auto result = CreateEmptyVectorDataArray(count, field_meta);
int64_t valid_count = count;
const bool* valid_data = nullptr;
const int64_t* valid_offsets = seg_offsets;
ValidResult filter_result;
if (field_meta.is_nullable()) {
filter_result =
FilterVectorValidOffsets(op_ctx, field_id, seg_offsets, count);
valid_count = filter_result.valid_count;
valid_data = filter_result.valid_data.get();
valid_offsets = filter_result.valid_offsets.data();
}
auto result = CreateEmptyVectorDataArray(
count, valid_count, valid_data, field_meta);
if (valid_count == 0) {
return result;
}
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
bulk_subscript_impl<FloatVector>(op_ctx,
field_id,
field_meta.get_sizeof(),
vec_ptr,
seg_offsets,
count,
valid_offsets,
valid_count,
result->mutable_vectors()
->mutable_float_vector()
->mutable_data()
@ -955,8 +972,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
field_id,
field_meta.get_sizeof(),
vec_ptr,
seg_offsets,
count,
valid_offsets,
valid_count,
result->mutable_vectors()->mutable_binary_vector()->data());
} else if (field_meta.get_data_type() == DataType::VECTOR_FLOAT16) {
bulk_subscript_impl<Float16Vector>(
@ -964,8 +981,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
field_id,
field_meta.get_sizeof(),
vec_ptr,
seg_offsets,
count,
valid_offsets,
valid_count,
result->mutable_vectors()->mutable_float16_vector()->data());
} else if (field_meta.get_data_type() == DataType::VECTOR_BFLOAT16) {
bulk_subscript_impl<BFloat16Vector>(
@ -973,8 +990,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
field_id,
field_meta.get_sizeof(),
vec_ptr,
seg_offsets,
count,
valid_offsets,
valid_count,
result->mutable_vectors()->mutable_bfloat16_vector()->data());
} else if (field_meta.get_data_type() ==
DataType::VECTOR_SPARSE_U32_F32) {
@ -982,8 +999,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
op_ctx,
field_id,
(const ConcurrentVector<SparseFloatVector>*)vec_ptr,
seg_offsets,
count,
valid_offsets,
valid_count,
result->mutable_vectors()->mutable_sparse_float_vector());
result->mutable_vectors()->set_dim(
result->vectors().sparse_float_vector().dim());
@ -993,8 +1010,8 @@ SegmentGrowingImpl::bulk_subscript(milvus::OpContext* op_ctx,
field_id,
field_meta.get_sizeof(),
vec_ptr,
seg_offsets,
count,
valid_offsets,
valid_count,
result->mutable_vectors()->mutable_int8_vector()->data());
} else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
bulk_subscript_vector_array_impl(op_ctx,
@ -1190,7 +1207,7 @@ SegmentGrowingImpl::bulk_subscript_sparse_float_vector_impl(
[&](size_t i) {
auto offset = seg_offsets[i];
return offset != INVALID_SEG_OFFSET
? vec_raw->get_element(offset)
? vec_raw->get_physical_element(offset)
: nullptr;
},
count,
@ -1257,12 +1274,8 @@ SegmentGrowingImpl::bulk_subscript_impl(milvus::OpContext* op_ctx,
for (int i = 0; i < count; ++i) {
auto dst = output_base + i * element_sizeof;
auto offset = seg_offsets[i];
if (offset == INVALID_SEG_OFFSET) {
memset(dst, 0, element_sizeof);
} else {
auto src = (const uint8_t*)vec.get_element(offset);
memcpy(dst, src, element_sizeof);
}
auto src = (const uint8_t*)vec.get_physical_element(offset);
memcpy(dst, src, element_sizeof);
}
return;
}
@ -1860,4 +1873,68 @@ SegmentGrowingImpl::BuildGeometryCacheForLoad(
}
}
// Translates the logical segment offsets in `seg_offsets` into the physical
// offsets used by nullable-vector sparse storage, and records per-row
// validity.
//
// Returns a ValidResult where:
//  * valid_data[i]  - whether seg_offsets[i] refers to a non-null row
//                     (nullptr when the field carries no validity info,
//                     in which case valid_count == count).
//  * valid_offsets  - physical offsets of the valid rows, in input order.
//  * valid_count    - number of entries in valid_offsets.
//
// Invariant maintained here: the number of `true` entries in valid_data
// always equals valid_offsets.size(), so callers can zip them safely.
SegmentGrowingImpl::ValidResult
SegmentGrowingImpl::FilterVectorValidOffsets(milvus::OpContext* op_ctx,
                                             FieldId field_id,
                                             const int64_t* seg_offsets,
                                             int64_t count) const {
    ValidResult result;
    result.valid_count = count;
    if (indexing_record_.SyncDataWithIndex(field_id)) {
        // Rows already synced into the growing vector index: consult the
        // index for validity and logical->physical translation.
        const auto& field_indexing =
            indexing_record_.get_vec_field_indexing(field_id);
        auto indexing = field_indexing.get_segment_indexing();
        auto vec_index = dynamic_cast<index::VectorIndex*>(indexing.get());
        if (vec_index != nullptr && vec_index->HasValidData()) {
            result.valid_data = std::make_unique<bool[]>(count);
            result.valid_offsets.reserve(count);
            for (int64_t i = 0; i < count; ++i) {
                bool is_valid = vec_index->IsRowValid(seg_offsets[i]);
                if (is_valid) {
                    int64_t physical_offset =
                        vec_index->GetPhysicalOffset(seg_offsets[i]);
                    if (physical_offset >= 0) {
                        result.valid_offsets.push_back(physical_offset);
                    } else {
                        // The index claims the row is valid but cannot map
                        // it to a physical slot; demote it to null so that
                        // valid_data stays consistent with valid_offsets.
                        is_valid = false;
                    }
                }
                result.valid_data[i] = is_valid;
            }
            result.valid_count = result.valid_offsets.size();
        }
    } else {
        auto vec_base = insert_record_.get_data_base(field_id);
        if (vec_base != nullptr) {
            const auto& valid_data_vec = vec_base->get_valid_data();
            bool is_mapping_storage = vec_base->is_mapping_storage();
            if (!valid_data_vec.empty()) {
                result.valid_data = std::make_unique<bool[]>(count);
                result.valid_offsets.reserve(count);
                for (int64_t i = 0; i < count; ++i) {
                    auto offset = seg_offsets[i];
                    bool is_valid =
                        offset >= 0 &&
                        offset <
                            static_cast<int64_t>(valid_data_vec.size()) &&
                        valid_data_vec[offset];
                    if (is_valid) {
                        if (is_mapping_storage) {
                            int64_t physical_offset =
                                vec_base->get_physical_offset(offset);
                            if (physical_offset >= 0) {
                                result.valid_offsets.push_back(
                                    physical_offset);
                            } else {
                                // No physical slot for a supposedly valid
                                // row: treat as null for consistency.
                                is_valid = false;
                            }
                        } else {
                            // Dense (non-mapping) storage keeps null rows in
                            // place, so logical and physical offsets
                            // coincide. Previously nothing was pushed on
                            // this branch, which silently dropped every
                            // valid row (valid_count became 0).
                            result.valid_offsets.push_back(offset);
                        }
                    }
                    result.valid_data[i] = is_valid;
                }
                result.valid_count = result.valid_offsets.size();
            }
        }
    }
    return result;
}
} // namespace milvus::segcore

View File

@ -504,6 +504,17 @@ class SegmentGrowingImpl : public SegmentGrowing {
}
return nullptr;
}
// Result of translating logical segment offsets of a nullable vector
// field into the physical offsets of its sparse (valid-rows-only) storage.
struct ValidResult {
    // Number of valid (non-null) rows among the requested offsets; equals
    // the original count when the field carries no validity information.
    int64_t valid_count = 0;
    // Per requested row: true iff the row holds a non-null vector.
    // Remains null when no validity information exists for the field.
    std::unique_ptr<bool[]> valid_data;
    // Physical storage offsets of the valid rows, in input order.
    std::vector<int64_t> valid_offsets;
};

// Computes validity flags and the logical->physical offset translation
// for `count` entries of `seg_offsets` on the given vector field.
ValidResult
FilterVectorValidOffsets(milvus::OpContext* op_ctx,
                         FieldId field_id,
                         const int64_t* seg_offsets,
                         int64_t count) const;
protected:
int64_t

View File

@ -280,11 +280,10 @@ TEST_P(GrowingIndexTest, Correctness) {
auto inserted = (i + 1) * per_batch;
// once index built, chunk data will be removed.
// growing index will only be built when num rows reached
// get_build_threshold(). This value for sparse is 0, thus sparse index
// will be built since the first chunk. Dense segment buffers the first
// get_build_threshold(). Both sparse and dense segment buffer the first
// 2 chunks before building an index in this test case.
if ((!is_sparse && i < 2) || !intermin_index_with_raw_data) {
if (i < 2 || !intermin_index_with_raw_data) {
EXPECT_EQ(field_data->num_chunk(),
upper_div(inserted, field_data->get_size_per_chunk()));
} else {

View File

@ -12,12 +12,18 @@
#include <gtest/gtest.h>
#include "common/Types.h"
#include "common/IndexMeta.h"
#include "knowhere/comp/index_param.h"
#include "segcore/SegmentGrowing.h"
#include "segcore/SegmentGrowingImpl.h"
#include "pb/schema.pb.h"
#include "pb/plan.pb.h"
#include "query/Plan.h"
#include "expr/ITypeExpr.h"
#include "plan/PlanNode.h"
#include "test_utils/DataGen.h"
#include "test_utils/storage_test_utils.h"
#include "test_utils/GenExprProto.h"
using namespace milvus::segcore;
using namespace milvus;
@ -435,6 +441,325 @@ TEST(Growing, FillNullableData) {
}
}
// Parameterized fixture for nullable-vector growing-segment tests.
// Tuple parameter: (vector data type, metric type, index type,
// percentage of null rows, whether the interim segment index is enabled).
class GrowingNullableTest : public ::testing::TestWithParam<
                                std::tuple</*data_type*/ DataType,
                                           /*metric_type*/ knowhere::MetricType,
                                           /*index_type*/ std::string,
                                           /*null_percent*/ int,
                                           /*enable_interim_index*/ bool>> {
 public:
    void
    SetUp() override {
        // Unpack the tuple parameter into named members for readability.
        std::tie(data_type,
                 metric_type,
                 index_type,
                 null_percent,
                 enable_interim_index) = GetParam();
    }

    DataType data_type;
    knowhere::MetricType metric_type;
    std::string index_type;
    // Percentage (0-100) of rows whose vector value is null.
    int null_percent;
    bool enable_interim_index;
};
// Enumerates the cartesian product of (base vector configuration) x
// (null percentage) x (interim-index on/off) used to instantiate
// GrowingNullableTest.
static std::vector<
    std::tuple<DataType, knowhere::MetricType, std::string, int, bool>>
GenerateGrowingNullableTestParams() {
    using BaseConfig =
        std::tuple<DataType, knowhere::MetricType, std::string>;
    // Dense float vectors with IVF_FLAT under three metrics, plus a
    // sparse configuration with the inverted index.
    const std::vector<BaseConfig> base_configs = {
        {DataType::VECTOR_FLOAT,
         knowhere::metric::L2,
         knowhere::IndexEnum::INDEX_FAISS_IVFFLAT},
        {DataType::VECTOR_FLOAT,
         knowhere::metric::IP,
         knowhere::IndexEnum::INDEX_FAISS_IVFFLAT},
        {DataType::VECTOR_FLOAT,
         knowhere::metric::COSINE,
         knowhere::IndexEnum::INDEX_FAISS_IVFFLAT},
        {DataType::VECTOR_SPARSE_U32_F32,
         knowhere::metric::IP,
         knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX},
    };
    const std::vector<int> null_percents = {0, 20, 100};
    const std::vector<bool> interim_flags = {true, false};

    std::vector<
        std::tuple<DataType, knowhere::MetricType, std::string, int, bool>>
        params;
    params.reserve(base_configs.size() * null_percents.size() *
                   interim_flags.size());
    for (const auto& [dtype, metric, idx_type] : base_configs) {
        for (int null_pct : null_percents) {
            for (bool enable_interim : interim_flags) {
                params.emplace_back(
                    dtype, metric, idx_type, null_pct, enable_interim);
            }
        }
    }
    return params;
}
INSTANTIATE_TEST_SUITE_P(
NullableVectorParameters,
GrowingNullableTest,
::testing::ValuesIn(GenerateGrowingNullableTestParams()));
// End-to-end check of nullable vector support in a growing segment.
// Inserts data in rounds at the configured null percentage, then after
// each round runs a vector search and a bulk_subscript fetch, verifying:
//  * search never returns a null row (and returns nothing at 100% null),
//  * fetched vectors match the originally generated data via the
//    logical->physical offset mapping implied by sparse storage.
TEST_P(GrowingNullableTest, SearchAndQueryNullableVectors) {
    using namespace milvus::query;

    bool nullable = true;
    auto schema = std::make_shared<Schema>();
    auto int64_field = schema->AddDebugField("int64", DataType::INT64);
    int64_t dim = 8;
    auto vec = schema->AddDebugField(
        "embeddings", data_type, dim, metric_type, nullable);
    schema->set_primary_field_id(int64_field);

    // Sparse indexes take no dim/nlist; dense IVF needs both.
    std::map<std::string, std::string> index_params;
    std::map<std::string, std::string> type_params;
    if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
        index_params = {{"index_type", index_type},
                        {"metric_type", metric_type}};
        type_params = {};
    } else {
        index_params = {{"index_type", index_type},
                        {"metric_type", metric_type},
                        {"nlist", "128"}};
        type_params = {{"dim", std::to_string(dim)}};
    }
    FieldIndexMeta fieldIndexMeta(
        vec, std::move(index_params), std::move(type_params));
    auto config = SegcoreConfig::default_config();
    config.set_chunk_rows(1024);
    config.set_enable_interim_segment_index(enable_interim_index);
    // Explicitly set interim index type to avoid contamination from other tests
    config.set_dense_vector_intermin_index_type(
        knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC);
    std::map<FieldId, FieldIndexMeta> filedMap = {{vec, fieldIndexMeta}};
    IndexMetaPtr metaPtr =
        std::make_shared<CollectionIndexMeta>(100000, std::move(filedMap));
    auto segment_growing = CreateGrowingSegment(schema, metaPtr, 1, config);
    auto segment = dynamic_cast<SegmentGrowingImpl*>(segment_growing.get());

    int64_t batch_size = 2000;
    int64_t num_rounds = 10;
    int64_t topk = 5;
    int64_t num_queries = 2;
    Timestamp timestamp = 10000000;

    // Prepare search plan
    std::string search_params_fmt;
    if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
        search_params_fmt = R"(
vector_anns:<
field_id: {}
query_info:<
topk: {}
round_decimal: 3
metric_type: "{}"
search_params: "{{\"drop_ratio_search\": 0.1}}"
>
placeholder_tag: "$0"
>
)";
    } else {
        search_params_fmt = R"(
vector_anns:<
field_id: {}
query_info:<
topk: {}
round_decimal: 3
metric_type: "{}"
search_params: "{{\"nprobe\": 10}}"
>
placeholder_tag: "$0"
>
)";
    }
    auto raw_plan =
        fmt::format(search_params_fmt, vec.get(), topk, metric_type);
    auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
    auto plan =
        CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());

    // Create query vectors
    proto::common::PlaceholderGroup ph_group_raw;
    if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
        ph_group_raw = CreateSparseFloatPlaceholderGroup(num_queries, 42);
    } else {
        auto query_data = generate_float_vector(num_queries, dim);
        ph_group_raw =
            CreatePlaceholderGroupFromBlob(num_queries, dim, query_data.data());
    }
    auto ph_group =
        ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());

    // Store all inserted data for verification
    // For nullable vectors, data is stored sparsely (only valid vectors)
    // We need a mapping from logical offset to physical offset
    std::vector<float> all_float_vectors;  // Physical storage (only valid)
    std::vector<knowhere::sparse::SparseRow<float>> all_sparse_vectors;
    std::vector<bool> all_valid_data;  // Logical storage (all rows)
    std::vector<int64_t>
        logical_to_physical;  // Maps logical offset to physical

    // Insert data in multiple rounds and test after each round
    for (int64_t round = 0; round < num_rounds; round++) {
        int64_t total_rows = (round + 1) * batch_size;
        int64_t expected_valid_count =
            total_rows - (total_rows * null_percent / 100);
        auto dataset = DataGen(schema,
                               batch_size,
                               42 + round,
                               0,
                               1,
                               10,
                               1,
                               false,
                               true,
                               false,
                               null_percent);
        // Build logical to physical mapping for this batch
        int64_t base_physical = all_float_vectors.size() / dim;
        if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
            base_physical = all_sparse_vectors.size();
        }
        auto valid_data_from_dataset = dataset.get_col_valid(vec);
        int64_t physical_idx = base_physical;
        for (size_t i = 0; i < valid_data_from_dataset.size(); i++) {
            if (valid_data_from_dataset[i]) {
                logical_to_physical.push_back(physical_idx);
                physical_idx++;
            } else {
                logical_to_physical.push_back(-1);  // null
            }
        }
        // Get original data directly from proto (sparse storage for nullable)
        // Data is stored sparsely - only valid vectors are in the proto
        if (data_type == DataType::VECTOR_FLOAT) {
            auto field_data = dataset.get_col(vec);
            auto& float_data = field_data->vectors().float_vector().data();
            all_float_vectors.insert(
                all_float_vectors.end(), float_data.begin(), float_data.end());
        } else if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
            auto field_data = dataset.get_col(vec);
            auto& sparse_array = field_data->vectors().sparse_float_vector();
            for (int i = 0; i < sparse_array.contents_size(); i++) {
                auto& content = sparse_array.contents(i);
                auto row = CopyAndWrapSparseRow(content.data(), content.size());
                all_sparse_vectors.push_back(std::move(row));
            }
        }
        all_valid_data.insert(all_valid_data.end(),
                              valid_data_from_dataset.begin(),
                              valid_data_from_dataset.end());

        auto offset = segment->PreInsert(batch_size);
        segment->Insert(offset,
                        batch_size,
                        dataset.row_ids_.data(),
                        dataset.timestamps_.data(),
                        dataset.raw_);
        auto& insert_record = segment->get_insert_record();
        ASSERT_TRUE(insert_record.is_valid_data_exist(vec));
        // NOTE(review): despite the name, this is the field's data base;
        // the per-row validity bitmap is fetched from it on the next line.
        auto valid_data_ptr = insert_record.get_data_base(vec);
        const auto& valid_data = valid_data_ptr->get_valid_data();

        // Test search
        auto sr =
            segment_growing->Search(plan.get(), ph_group.get(), timestamp);
        ASSERT_EQ(sr->total_nq_, num_queries);
        ASSERT_EQ(sr->unity_topK_, topk);
        if (expected_valid_count == 0) {
            auto total_results = sr->get_total_result_count();
            EXPECT_EQ(total_results, 0)
                << "Round " << round
                << ": 100% null should return 0 results, but got "
                << total_results;
        } else {
            // Verify search results don't contain null vectors
            for (size_t i = 0; i < sr->seg_offsets_.size(); i++) {
                auto seg_offset = sr->seg_offsets_[i];
                if (seg_offset < 0) {
                    continue;
                }
                ASSERT_TRUE(valid_data[seg_offset])
                    << "Round " << round
                    << ": Search returned null vector at offset " << seg_offset;
            }
        }

        // Fetch the result rows back and compare against the generated
        // data. The fetched payload is dense over valid rows only, so
        // valid_idx advances only for non-negative logical offsets.
        auto vec_result = segment->bulk_subscript(
            nullptr, vec, sr->seg_offsets_.data(), sr->seg_offsets_.size());
        ASSERT_TRUE(vec_result != nullptr);
        if (data_type == DataType::VECTOR_FLOAT) {
            auto& float_data = vec_result->vectors().float_vector();
            size_t valid_idx = 0;
            for (size_t i = 0; i < sr->seg_offsets_.size(); i++) {
                auto offset = sr->seg_offsets_[i];
                if (offset < 0) {
                    continue;  // Skip invalid offsets
                }
                auto physical_idx = logical_to_physical[offset];
                for (int d = 0; d < dim; d++) {
                    float expected_val =
                        all_float_vectors[physical_idx * dim + d];
                    float actual_val = float_data.data(valid_idx * dim + d);
                    ASSERT_FLOAT_EQ(expected_val, actual_val)
                        << "Round " << round << ": Mismatch at logical offset "
                        << offset << " dim " << d;
                }
                valid_idx++;
            }
        } else if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
            auto& sparse_data = vec_result->vectors().sparse_float_vector();
            size_t valid_idx = 0;
            for (size_t i = 0; i < sr->seg_offsets_.size(); i++) {
                auto offset = sr->seg_offsets_[i];
                if (offset < 0) {
                    continue;  // Skip invalid offsets
                }
                auto physical_idx = logical_to_physical[offset];
                auto& content = sparse_data.contents(valid_idx);
                auto retrieved_row =
                    CopyAndWrapSparseRow(content.data(), content.size());
                const auto& expected_row = all_sparse_vectors[physical_idx];
                ASSERT_EQ(retrieved_row.size(), expected_row.size())
                    << "Round " << round
                    << ": Sparse vector size mismatch at logical offset "
                    << offset;
                for (size_t j = 0; j < retrieved_row.size(); j++) {
                    ASSERT_EQ(retrieved_row[j].id, expected_row[j].id)
                        << "Round " << round
                        << ": Sparse vector id mismatch at logical offset "
                        << offset << " element " << j;
                    ASSERT_FLOAT_EQ(retrieved_row[j].val, expected_row[j].val)
                        << "Round " << round
                        << ": Sparse vector val mismatch at logical offset "
                        << offset << " element " << j;
                }
                valid_idx++;
            }
        }
    }
}
TEST_P(GrowingTest, FillVectorArrayData) {
auto schema = std::make_shared<Schema>();
auto int64_field = schema->AddDebugField("int64", DataType::INT64);

View File

@ -500,10 +500,18 @@ SegmentInternalInterface::bulk_subscript_not_exist_field(
const milvus::FieldMeta& field_meta, int64_t count) const {
auto data_type = field_meta.get_data_type();
if (IsVectorDataType(data_type)) {
ThrowInfo(DataTypeInvalid,
fmt::format("unsupported added field type {}",
field_meta.get_data_type()));
AssertInfo(field_meta.is_nullable(),
"Non-nullable vector field should not reach here");
auto result = CreateEmptyVectorDataArray(0, field_meta);
auto valid_data = result->mutable_valid_data();
for (int64_t i = 0; i < count; ++i) {
valid_data->Add(false);
}
return result;
}
auto result = CreateEmptyScalarDataArray(count, field_meta);
if (field_meta.default_value().has_value()) {
auto res = result->mutable_valid_data()->mutable_data();

View File

@ -434,6 +434,23 @@ CreateEmptyVectorDataArray(int64_t count, const FieldMeta& field_meta) {
return data_array;
}
// Builds an empty vector DataArray for a field. For a nullable field with
// a supplied validity bitmap, the payload is sized by valid_count (sparse
// storage carries only non-null rows) while the valid-data section still
// covers all `count` logical rows.
std::unique_ptr<DataArray>
CreateEmptyVectorDataArray(int64_t count,
                           int64_t valid_count,
                           const void* valid_data,
                           const FieldMeta& field_meta) {
    const bool has_validity =
        field_meta.is_nullable() && valid_data != nullptr;
    auto data_array = CreateEmptyVectorDataArray(
        has_validity ? valid_count : count, field_meta);
    if (has_validity) {
        const bool* flags = reinterpret_cast<const bool*>(valid_data);
        data_array->mutable_valid_data()->Add(flags, flags + count);
    }
    return data_array;
}
std::unique_ptr<DataArray>
CreateScalarDataArrayFrom(const void* data_raw,
const void* valid_data,
@ -444,7 +461,7 @@ CreateScalarDataArrayFrom(const void* data_raw,
data_array->set_field_id(field_meta.get_id().get());
data_array->set_type(static_cast<milvus::proto::schema::DataType>(
field_meta.get_data_type()));
if (field_meta.is_nullable()) {
if (field_meta.is_nullable() && valid_data != nullptr) {
auto valid_data_ = reinterpret_cast<const bool*>(valid_data);
auto obj = data_array->mutable_valid_data();
obj->Add(valid_data_, valid_data_ + count);
@ -659,6 +676,22 @@ CreateVectorDataArrayFrom(const void* data_raw,
return data_array;
}
// Builds a vector DataArray from raw data holding `valid_count` vectors
// (sparse storage: non-null rows only). For a nullable field with a
// supplied validity bitmap, the valid-data section is populated for all
// `count` logical rows.
std::unique_ptr<DataArray>
CreateVectorDataArrayFrom(const void* data_raw,
                          const void* valid_data,
                          int64_t count,
                          int64_t valid_count,
                          const FieldMeta& field_meta) {
    auto data_array =
        CreateVectorDataArrayFrom(data_raw, valid_count, field_meta);
    const bool has_validity =
        field_meta.is_nullable() && valid_data != nullptr;
    if (has_validity) {
        const bool* flags = reinterpret_cast<const bool*>(valid_data);
        data_array->mutable_valid_data()->Add(flags, flags + count);
    }
    return data_array;
}
std::unique_ptr<DataArray>
CreateDataArrayFrom(const void* data_raw,
const void* valid_data,
@ -691,6 +724,21 @@ MergeDataArray(std::vector<MergeBase>& merge_bases,
AssertInfo(data_type == DataType(src_field_data->type()),
"merge field data type not consistent");
if (field_meta.is_vector()) {
bool is_valid = true;
if (nullable) {
auto data = src_field_data->valid_data().data();
auto obj = data_array->mutable_valid_data();
is_valid = data[src_offset];
*(obj->Add()) = is_valid;
}
if (!is_valid) {
continue;
}
int64_t physical_offset =
merge_base.getValidDataOffset(field_meta.get_id());
auto vector_array = data_array->mutable_vectors();
auto dim = 0;
if (!IsSparseFloatVectorDataType(data_type)) {
@ -700,17 +748,19 @@ MergeDataArray(std::vector<MergeBase>& merge_bases,
if (field_meta.get_data_type() == DataType::VECTOR_FLOAT) {
auto data = VEC_FIELD_DATA(src_field_data, float).data();
auto obj = vector_array->mutable_float_vector();
obj->mutable_data()->Add(data + src_offset * dim,
data + (src_offset + 1) * dim);
obj->mutable_data()->Add(data + physical_offset * dim,
data + (physical_offset + 1) * dim);
} else if (field_meta.get_data_type() == DataType::VECTOR_FLOAT16) {
auto data = VEC_FIELD_DATA(src_field_data, float16);
auto obj = vector_array->mutable_float16_vector();
obj->assign(data, dim * sizeof(float16));
obj->assign(data + physical_offset * dim * sizeof(float16),
dim * sizeof(float16));
} else if (field_meta.get_data_type() ==
DataType::VECTOR_BFLOAT16) {
auto data = VEC_FIELD_DATA(src_field_data, bfloat16);
auto obj = vector_array->mutable_bfloat16_vector();
obj->assign(data, dim * sizeof(bfloat16));
obj->assign(data + physical_offset * dim * sizeof(bfloat16),
dim * sizeof(bfloat16));
} else if (field_meta.get_data_type() == DataType::VECTOR_BINARY) {
AssertInfo(
dim % 8 == 0,
@ -718,26 +768,28 @@ MergeDataArray(std::vector<MergeBase>& merge_bases,
auto num_bytes = dim / 8;
auto data = VEC_FIELD_DATA(src_field_data, binary);
auto obj = vector_array->mutable_binary_vector();
obj->assign(data + src_offset * num_bytes, num_bytes);
obj->assign(data + physical_offset * num_bytes, num_bytes);
} else if (field_meta.get_data_type() ==
DataType::VECTOR_SPARSE_U32_F32) {
auto src = src_field_data->vectors().sparse_float_vector();
auto& src_vec = src_field_data->vectors().sparse_float_vector();
auto dst = vector_array->mutable_sparse_float_vector();
if (src.dim() > dst->dim()) {
dst->set_dim(src.dim());
if (src_vec.dim() > dst->dim()) {
dst->set_dim(src_vec.dim());
}
vector_array->set_dim(dst->dim());
*dst->mutable_contents() = src.contents();
auto& src_contents = src_vec.contents(physical_offset);
*(dst->mutable_contents()->Add()) = src_contents;
} else if (field_meta.get_data_type() == DataType::VECTOR_INT8) {
auto data = VEC_FIELD_DATA(src_field_data, int8);
auto obj = vector_array->mutable_int8_vector();
obj->assign(data, dim * sizeof(int8));
obj->assign(data + physical_offset * dim * sizeof(int8),
dim * sizeof(int8));
} else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
auto data = src_field_data->vectors().vector_array();
auto& data = src_field_data->vectors().vector_array();
auto obj = vector_array->mutable_vector_array();
obj->set_element_type(
proto::schema::DataType(field_meta.get_element_type()));
obj->CopyFrom(data);
*(obj->mutable_data()->Add()) = data.data(physical_offset);
} else {
ThrowInfo(DataTypeInvalid,
fmt::format("unsupported datatype {}", data_type));

View File

@ -55,6 +55,12 @@ CreateEmptyScalarDataArray(int64_t count, const FieldMeta& field_meta);
std::unique_ptr<DataArray>
CreateEmptyVectorDataArray(int64_t count, const FieldMeta& field_meta);
std::unique_ptr<DataArray>
CreateEmptyVectorDataArray(int64_t count,
int64_t valid_count,
const void* valid_data,
const FieldMeta& field_meta);
std::unique_ptr<DataArray>
CreateScalarDataArrayFrom(const void* data_raw,
const void* valid_data,
@ -66,6 +72,13 @@ CreateVectorDataArrayFrom(const void* data_raw,
int64_t count,
const FieldMeta& field_meta);
std::unique_ptr<DataArray>
CreateVectorDataArrayFrom(const void* data_raw,
const void* valid_data,
int64_t count,
int64_t valid_count,
const FieldMeta& field_meta);
std::unique_ptr<DataArray>
CreateDataArrayFrom(const void* data_raw,
const void* valid_data,
@ -77,6 +90,7 @@ struct MergeBase {
private:
std::map<FieldId, std::unique_ptr<milvus::DataArray>>* output_fields_data_;
size_t offset_;
std::map<FieldId, size_t> valid_data_offsets_;
public:
MergeBase() {
@ -93,6 +107,20 @@ struct MergeBase {
return offset_;
}
// Records the physical (valid-data) offset for a nullable vector field,
// used later to index into sparsely stored vector payloads.
void
setValidDataOffset(FieldId fieldId, size_t valid_offset) {
    valid_data_offsets_[fieldId] = valid_offset;
}
// Returns the recorded physical (valid-data) offset for the field, or
// falls back to the logical row offset when none was recorded (i.e. the
// field is not a nullable vector or has no validity bitmap).
size_t
getValidDataOffset(FieldId fieldId) const {
    const auto it = valid_data_offsets_.find(fieldId);
    return it == valid_data_offsets_.end() ? offset_ : it->second;
}
milvus::DataArray*
get_field_data(FieldId fieldId) const {
return (*output_fields_data_)[fieldId].get();

View File

@ -94,3 +94,96 @@ TEST(Util_Segcore, GetDeleteBitmap) {
delete_record.Query(res_view, insert_barrier, query_timestamp);
ASSERT_EQ(res_view.count(), 0);
}
// CreateVectorDataArrayFrom with a nullable float vector field: the
// payload holds only the valid rows, while valid_data covers all rows.
TEST(Util_Segcore, CreateVectorDataArrayFromNullableVectors) {
    using namespace milvus;
    using namespace milvus::segcore;
    auto schema = std::make_shared<Schema>();
    auto vec = schema->AddDebugField(
        "embeddings", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2, true);
    auto& field_meta = (*schema)[vec];
    int64_t dim = 16;
    int64_t total_count = 10;
    int64_t valid_count = 5;

    // Sparse payload: only the valid rows carry data (values 0,1,2,...).
    std::vector<float> payload(valid_count * dim);
    float next = 0.0f;
    for (auto& v : payload) {
        v = next;
        next += 1.0f;
    }
    // Even logical rows are valid, odd rows are null.
    auto flags = std::make_unique<bool[]>(total_count);
    for (int64_t row = 0; row < total_count; ++row) {
        flags[row] = (row % 2 == 0);
    }

    auto result = CreateVectorDataArrayFrom(
        payload.data(), flags.get(), total_count, valid_count, field_meta);
    ASSERT_TRUE(result->valid_data().size() > 0);
    ASSERT_EQ(result->valid_data().size(), total_count);
    ASSERT_EQ(result->vectors().float_vector().data_size(), valid_count * dim);
}
// MergeDataArray with a nullable float vector field: merging the five
// valid rows (logical offsets 0,2,4,6,8 -> physical offsets 0..4) must
// yield five valid entries with five dense vector payloads.
TEST(Util_Segcore, MergeDataArrayWithNullableVectors) {
    using namespace milvus;
    using namespace milvus::segcore;
    auto schema = std::make_shared<Schema>();
    auto vec = schema->AddDebugField(
        "embeddings", DataType::VECTOR_FLOAT, 16, knowhere::metric::L2, true);
    auto& field_meta = (*schema)[vec];
    int64_t dim = 16;
    int64_t total_count = 10;
    int64_t valid_count = 5;

    std::vector<float> payload(valid_count * dim);
    for (size_t i = 0; i < payload.size(); ++i) {
        payload[i] = static_cast<float>(i);
    }
    // Even logical rows are valid, odd rows are null.
    auto flags = std::make_unique<bool[]>(total_count);
    for (int64_t row = 0; row < total_count; ++row) {
        flags[row] = (row % 2 == 0);
    }
    auto data_array = CreateVectorDataArrayFrom(
        payload.data(), flags.get(), total_count, valid_count, field_meta);
    std::map<FieldId, std::unique_ptr<milvus::DataArray>> output_fields_data;
    output_fields_data[vec] = std::move(data_array);

    // Each merge base points at logical offset 2k with physical offset k.
    std::vector<MergeBase> merge_bases;
    for (size_t k = 0; k < 5; ++k) {
        merge_bases.emplace_back(&output_fields_data, 2 * k);
        merge_bases.back().setValidDataOffset(vec, k);
    }

    auto merged_result = MergeDataArray(merge_bases, field_meta);
    ASSERT_TRUE(merged_result->valid_data().size() > 0);
    ASSERT_EQ(merged_result->valid_data().size(), 5);
    ASSERT_EQ(merged_result->vectors().float_vector().data_size(), 5 * dim);
    for (int k = 0; k < 5; ++k) {
        ASSERT_TRUE(merged_result->valid_data(k));
    }
}

View File

@ -541,6 +541,27 @@ ReduceHelper::GetSearchResultDataSlice(const int slice_index,
// set result offset to fill output fields data
result_pairs[loc] = {&search_result->output_fields_data_, ki};
for (auto field_id : plan_->target_entries_) {
auto& field_meta = plan_->schema_->operator[](field_id);
if (field_meta.is_vector() && field_meta.is_nullable()) {
auto it =
search_result->output_fields_data_.find(field_id);
if (it != search_result->output_fields_data_.end()) {
auto& field_data = it->second;
if (field_data->valid_data_size() > 0) {
int64_t valid_idx = 0;
for (int64_t i = 0; i < ki; ++i) {
if (field_data->valid_data(i)) {
valid_idx++;
}
}
result_pairs[loc].setValidDataOffset(field_id,
valid_idx);
}
}
}
}
}
}

View File

@ -18,6 +18,32 @@
namespace milvus::segcore {
// For every nullable vector output field, records on `merge_base` the
// physical offset of row `ki`: the count of valid (non-null) rows before
// it, i.e. its position within the sparsely stored vector payload.
void
StreamReducerHelper::SetNullableVectorValidDataOffsets(
    const std::map<FieldId, std::unique_ptr<milvus::DataArray>>&
        output_fields_data,
    int64_t ki,
    MergeBase& merge_base) {
    for (auto field_id : plan_->target_entries_) {
        auto& field_meta = plan_->schema_->operator[](field_id);
        if (!field_meta.is_vector() || !field_meta.is_nullable()) {
            continue;
        }
        auto it = output_fields_data.find(field_id);
        if (it == output_fields_data.end()) {
            continue;
        }
        const auto& field_data = it->second;
        if (field_data->valid_data_size() <= 0) {
            continue;
        }
        // Valid rows preceding ki == physical offset of row ki.
        int64_t physical_offset = 0;
        for (int64_t row = 0; row < ki; ++row) {
            if (field_data->valid_data(row)) {
                ++physical_offset;
            }
        }
        merge_base.setValidDataOffset(field_id, physical_offset);
    }
}
void
StreamReducerHelper::FillEntryData() {
for (auto search_result : search_results_to_merge_) {
@ -98,6 +124,10 @@ StreamReducerHelper::AssembleMergedResult() {
}
merge_output_data_bases[nq_base_offset + loc] = {
&search_result->output_fields_data_, ki};
SetNullableVectorValidDataOffsets(
search_result->output_fields_data_,
ki,
merge_output_data_bases[nq_base_offset + loc]);
new_result_offsets[nq_base_offset + loc] = loc;
real_topKs[qi]++;
}
@ -127,6 +157,10 @@ StreamReducerHelper::AssembleMergedResult() {
}
merge_output_data_bases[nq_base_offset + loc] = {
&merged_search_result->output_fields_data_, ki};
SetNullableVectorValidDataOffsets(
merged_search_result->output_fields_data_,
ki,
merge_output_data_bases[nq_base_offset + loc]);
new_result_offsets[nq_base_offset + loc] = loc;
real_topKs[qi]++;
}

View File

@ -19,6 +19,7 @@
#include "query/PlanImpl.h"
#include "common/QueryResult.h"
#include "segcore/ReduceStructure.h"
#include "segcore/Utils.h"
#include "common/EasyAssert.h"
namespace milvus::segcore {
@ -207,6 +208,13 @@ class StreamReducerHelper {
void
CleanReduceStatus();
void
SetNullableVectorValidDataOffsets(
const std::map<FieldId, std::unique_ptr<milvus::DataArray>>&
output_fields_data,
int64_t ki,
MergeBase& merge_base);
std::unique_ptr<MergedSearchResult> merged_search_result;
milvus::query::Plan* plan_;
std::vector<int64_t> slice_nqs_;

View File

@ -107,6 +107,16 @@ DefaultValueChunkTranslator::estimated_byte_size_of_cell(
case milvus::DataType::ARRAY:
value_size = sizeof(Array);
break;
case milvus::DataType::VECTOR_FLOAT:
case milvus::DataType::VECTOR_BINARY:
case milvus::DataType::VECTOR_FLOAT16:
case milvus::DataType::VECTOR_BFLOAT16:
case milvus::DataType::VECTOR_INT8:
case milvus::DataType::VECTOR_SPARSE_U32_F32:
AssertInfo(field_meta_.is_nullable(),
"only nullable vector fields can be dynamically added");
value_size = 0;
break;
default:
ThrowInfo(DataTypeInvalid,
"unsupported default value data type {}",
@ -128,8 +138,15 @@ DefaultValueChunkTranslator::get_cells(
AssertInfo(cids.size() == 1 && cids[0] == 0,
"DefaultValueChunkTranslator only supports one cell");
auto num_rows = meta_.num_rows_until_chunk_[1];
auto builder =
milvus::storage::CreateArrowBuilder(field_meta_.get_data_type());
auto data_type = field_meta_.get_data_type();
std::shared_ptr<arrow::ArrayBuilder> builder;
if (IsVectorDataType(data_type)) {
AssertInfo(field_meta_.is_nullable(),
"only nullable vector fields can be dynamically added");
builder = std::make_shared<arrow::BinaryBuilder>();
} else {
builder = milvus::storage::CreateArrowBuilder(data_type);
}
arrow::Status ast;
if (field_meta_.default_value().has_value()) {
ast = builder->Reserve(num_rows);

View File

@ -143,20 +143,54 @@ InterimSealedIndexTranslator::get_cells(
}
auto num_chunk = vec_data_->num_chunks();
const auto& offset_mapping = vec_data_->GetOffsetMapping();
bool nullable = offset_mapping.IsEnabled();
const auto& valid_count_per_chunk =
nullable ? vec_data_->GetValidCountPerChunk() : std::vector<int64_t>{};
int64_t total_valid_count =
nullable ? offset_mapping.GetValidCount() : vec_data_->NumRows();
if (total_valid_count == 0) {
if (nullable) {
const auto& valid_data = vec_data_->GetValidData();
vec_index->BuildValidData(valid_data.data(), valid_data.size());
}
std::vector<std::pair<cid_t, std::unique_ptr<milvus::index::IndexBase>>>
result;
result.emplace_back(std::make_pair(0, std::move(vec_index)));
return result;
}
bool first_build = true;
for (int i = 0; i < num_chunk; ++i) {
auto pw = vec_data_->GetChunk(nullptr, i);
auto chunk = pw.get();
auto dataset = knowhere::GenDataSet(
vec_data_->chunk_row_nums(i), dim_, chunk->Data());
int64_t actual_row_count =
nullable ? valid_count_per_chunk[i] : vec_data_->chunk_row_nums(i);
if (actual_row_count == 0) {
continue;
}
auto dataset =
knowhere::GenDataSet(actual_row_count, dim_, chunk->Data());
dataset->SetIsOwner(false);
dataset->SetIsSparse(is_sparse_);
if (i == 0) {
if (first_build) {
vec_index->BuildWithDataset(dataset, build_config_);
first_build = false;
} else {
vec_index->AddWithDataset(dataset, build_config_);
}
}
if (nullable) {
const auto& valid_data = vec_data_->GetValidData();
vec_index->BuildValidData(valid_data.data(), valid_data.size());
}
std::vector<std::pair<cid_t, std::unique_ptr<milvus::index::IndexBase>>>
result;
result.emplace_back(std::make_pair(0, std::move(vec_index)));

View File

@ -763,7 +763,80 @@ TEST(storage, InsertDataFloatVector) {
ASSERT_EQ(data, new_data);
}
TEST(storage, InsertDataSparseFloat) {
// Round-trip test: serialize a nullable float vector column through
// InsertData and verify that row counts, null accounting, and per-row
// values all survive the codec.
TEST(storage, InsertDataFloatVectorNullable) {
    int DIM = 4;
    int num_rows = 100;
    // Cover no nulls, partial nulls, and an all-null column.
    for (int null_percent : {0, 20, 100}) {
        int valid_count = num_rows * (100 - null_percent) / 100;
        bool is_nullable = true;
        // Sparse storage: only valid rows carry a vector payload.
        std::vector<float> data(valid_count * DIM);
        for (int i = 0; i < valid_count * DIM; ++i) {
            data[i] = static_cast<float>(i) * 0.5f;
        }
        FieldDataPtr field_data;
        // LSB-first validity bitmap: the first valid_count rows are valid.
        std::vector<uint8_t> valid_data((num_rows + 7) / 8, 0);
        for (int i = 0; i < valid_count; ++i) {
            valid_data[i >> 3] |= (1 << (i & 0x07));
        }
        field_data = milvus::storage::CreateFieldData(
            storage::DataType::VECTOR_FLOAT, DataType::NONE, true, DIM);
        auto field_data_impl =
            std::dynamic_pointer_cast<milvus::FieldData<milvus::FloatVector>>(
                field_data);
        field_data_impl->FillFieldData(
            data.data(), valid_data.data(), num_rows, 0);
        ASSERT_EQ(field_data->get_num_rows(), num_rows);
        ASSERT_EQ(field_data->get_valid_rows(), valid_count);
        ASSERT_EQ(field_data->get_null_count(), num_rows - valid_count);
        ASSERT_EQ(field_data->IsNullable(), is_nullable);
        auto payload_reader =
            std::make_shared<milvus::storage::PayloadReader>(field_data);
        storage::InsertData insert_data(payload_reader);
        storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
        insert_data.SetFieldDataMeta(field_data_meta);
        insert_data.SetTimestamps(0, 100);
        auto serialized_bytes =
            insert_data.Serialize(storage::StorageType::Remote);
        // Non-owning alias (no-op deleter); serialized_bytes keeps ownership.
        std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
                                                       [&](uint8_t*) {});
        auto new_insert_data = storage::DeserializeFileData(
            serialized_data_ptr, serialized_bytes.size());
        ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
        ASSERT_EQ(new_insert_data->GetTimeRage(),
                  std::make_pair(Timestamp(0), Timestamp(100)));
        auto new_payload = new_insert_data->GetFieldData();
        ASSERT_EQ(new_payload->get_data_type(),
                  storage::DataType::VECTOR_FLOAT);
        ASSERT_EQ(new_payload->get_num_rows(), num_rows);
        ASSERT_EQ(new_payload->get_valid_rows(), valid_count);
        ASSERT_EQ(new_payload->get_null_count(), num_rows - valid_count);
        ASSERT_EQ(new_payload->IsNullable(), is_nullable);
        // Walk logical rows; each valid logical row maps to the next
        // physical vector in the densely packed payload.
        int valid_idx = 0;
        for (int i = 0; i < num_rows; ++i) {
            if (new_payload->is_valid(i)) {
                // RawValue takes logical offset, internally converts to physical
                auto vec_ptr =
                    static_cast<const float*>(new_payload->RawValue(i));
                for (int j = 0; j < DIM; ++j) {
                    ASSERT_FLOAT_EQ(vec_ptr[j], data[valid_idx * DIM + j]);
                }
                valid_idx++;
            }
        }
    }
}
TEST(storage, InsertDataSparseFloatVector) {
auto n_rows = 100;
auto vecs = milvus::segcore::GenerateRandomSparseFloatVector(
n_rows, kTestSparseDim, kTestSparseVectorDensity);
@ -810,6 +883,75 @@ TEST(storage, InsertDataSparseFloat) {
}
}
// Round-trip test for a nullable sparse float vector column: sparse rows
// exist only for valid entries, and the codec must preserve every
// (id, val) element of each valid row.
TEST(storage, InsertDataSparseFloatVectorNullable) {
    int num_rows = 100;
    // Cover no nulls, partial nulls, and an all-null column.
    for (int null_percent : {0, 20, 100}) {
        int valid_count = num_rows * (100 - null_percent) / 100;
        bool is_nullable = true;
        // Generate sparse rows only for the valid entries (sparse storage).
        auto vecs = milvus::segcore::GenerateRandomSparseFloatVector(
            valid_count, kTestSparseDim, kTestSparseVectorDensity);
        FieldDataPtr field_data;
        // LSB-first validity bitmap: the first valid_count rows are valid.
        std::vector<uint8_t> valid_data((num_rows + 7) / 8, 0);
        for (int i = 0; i < valid_count; ++i) {
            valid_data[i >> 3] |= (1 << (i & 0x07));
        }
        field_data =
            milvus::storage::CreateFieldData(DataType::VECTOR_SPARSE_U32_F32,
                                             DataType::NONE,
                                             true,
                                             kTestSparseDim,
                                             num_rows);
        auto field_data_impl = std::dynamic_pointer_cast<
            milvus::FieldData<milvus::SparseFloatVector>>(field_data);
        field_data_impl->FillFieldData(
            vecs.get(), valid_data.data(), num_rows, 0);
        ASSERT_EQ(field_data->get_num_rows(), num_rows);
        ASSERT_EQ(field_data->get_valid_rows(), valid_count);
        ASSERT_EQ(field_data->IsNullable(), is_nullable);
        auto payload_reader =
            std::make_shared<milvus::storage::PayloadReader>(field_data);
        storage::InsertData insert_data(payload_reader);
        storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
        insert_data.SetFieldDataMeta(field_data_meta);
        insert_data.SetTimestamps(0, 100);
        auto serialized_bytes =
            insert_data.Serialize(storage::StorageType::Remote);
        // Non-owning alias (no-op deleter); serialized_bytes keeps ownership.
        std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
                                                       [&](uint8_t*) {});
        auto new_insert_data = storage::DeserializeFileData(
            serialized_data_ptr, serialized_bytes.size());
        ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
        auto new_payload = new_insert_data->GetFieldData();
        ASSERT_TRUE(new_payload->get_data_type() ==
                    storage::DataType::VECTOR_SPARSE_U32_F32);
        ASSERT_EQ(new_payload->get_num_rows(), num_rows);
        ASSERT_EQ(new_payload->IsNullable(), is_nullable);
        // Compare each valid logical row with the next physical sparse row.
        int valid_idx = 0;
        for (int i = 0; i < num_rows; ++i) {
            if (new_payload->is_valid(i)) {
                auto& original = vecs[valid_idx];
                auto new_vec = static_cast<const knowhere::sparse::SparseRow<
                    milvus::SparseValueType>*>(new_payload->RawValue(i));
                ASSERT_EQ(original.size(), new_vec->size());
                for (size_t j = 0; j < original.size(); ++j) {
                    ASSERT_EQ(original[j].id, (*new_vec)[j].id);
                    ASSERT_EQ(original[j].val, (*new_vec)[j].val);
                }
                valid_idx++;
            }
        }
    }
}
TEST(storage, InsertDataBinaryVector) {
std::vector<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8};
int DIM = 16;
@ -841,6 +983,72 @@ TEST(storage, InsertDataBinaryVector) {
ASSERT_EQ(data, new_data);
}
// Round-trip test: serialize a nullable binary vector column through
// InsertData and verify row counts, the nullable flag, and the packed
// bit payload (DIM/8 bytes per vector) survive the codec.
TEST(storage, InsertDataBinaryVectorNullable) {
    int DIM = 128;
    int num_rows = 100;
    // Cover no nulls, partial nulls, and an all-null column.
    for (int null_percent : {0, 20, 100}) {
        int valid_count = num_rows * (100 - null_percent) / 100;
        bool is_nullable = true;
        // Sparse storage: DIM/8 bytes per valid row only.
        std::vector<uint8_t> data(valid_count * DIM / 8);
        for (size_t i = 0; i < data.size(); ++i) {
            data[i] = static_cast<uint8_t>(i % 256);
        }
        FieldDataPtr field_data;
        // LSB-first validity bitmap: the first valid_count rows are valid.
        std::vector<uint8_t> valid_data((num_rows + 7) / 8, 0);
        for (int i = 0; i < valid_count; ++i) {
            valid_data[i >> 3] |= (1 << (i & 0x07));
        }
        field_data = milvus::storage::CreateFieldData(
            storage::DataType::VECTOR_BINARY, DataType::NONE, true, DIM);
        auto field_data_impl =
            std::dynamic_pointer_cast<milvus::FieldData<milvus::BinaryVector>>(
                field_data);
        field_data_impl->FillFieldData(
            data.data(), valid_data.data(), num_rows, 0);
        ASSERT_EQ(field_data->get_num_rows(), num_rows);
        ASSERT_EQ(field_data->get_valid_rows(), valid_count);
        ASSERT_EQ(field_data->IsNullable(), is_nullable);
        auto payload_reader =
            std::make_shared<milvus::storage::PayloadReader>(field_data);
        storage::InsertData insert_data(payload_reader);
        storage::FieldDataMeta field_data_meta{100, 101, 102, 103};
        insert_data.SetFieldDataMeta(field_data_meta);
        insert_data.SetTimestamps(0, 100);
        auto serialized_bytes =
            insert_data.Serialize(storage::StorageType::Remote);
        // Non-owning alias (no-op deleter); serialized_bytes keeps ownership.
        std::shared_ptr<uint8_t[]> serialized_data_ptr(serialized_bytes.data(),
                                                       [&](uint8_t*) {});
        auto new_insert_data = storage::DeserializeFileData(
            serialized_data_ptr, serialized_bytes.size());
        ASSERT_EQ(new_insert_data->GetCodecType(), storage::InsertDataType);
        auto new_payload = new_insert_data->GetFieldData();
        ASSERT_EQ(new_payload->get_data_type(),
                  storage::DataType::VECTOR_BINARY);
        ASSERT_EQ(new_payload->get_num_rows(), num_rows);
        ASSERT_EQ(new_payload->IsNullable(), is_nullable);
        // Each valid logical row maps to the next physical DIM/8-byte vector.
        int valid_idx = 0;
        for (int i = 0; i < num_rows; ++i) {
            if (new_payload->is_valid(i)) {
                auto vec_ptr =
                    static_cast<const uint8_t*>(new_payload->RawValue(i));
                for (int j = 0; j < DIM / 8; ++j) {
                    ASSERT_EQ(vec_ptr[j], data[valid_idx * DIM / 8 + j]);
                }
                valid_idx++;
            }
        }
    }
}
TEST(storage, InsertDataFloat16Vector) {
std::vector<float16> data = {1, 2, 3, 4, 5, 6, 7, 8};
int DIM = 2;
@ -874,6 +1082,72 @@ TEST(storage, InsertDataFloat16Vector) {
ASSERT_EQ(data, new_data);
}
// Verifies that a nullable float16 vector field survives an InsertData
// serialize/deserialize round trip: row counts, the nullable flag, and
// every valid row's payload must come back unchanged.
TEST(storage, InsertDataFloat16VectorNullable) {
    const int dim = 4;
    const int total_rows = 100;
    for (int null_percent : {0, 20, 100}) {
        const int n_valid = total_rows * (100 - null_percent) / 100;
        const bool expect_nullable = true;
        // Only valid rows carry payload; fill them with a deterministic ramp.
        std::vector<float16> payload(n_valid * dim);
        for (size_t k = 0; k < payload.size(); ++k) {
            payload[k] = static_cast<float16>(static_cast<int>(k) * 0.5f);
        }
        // LSB-first validity bitmap marking the first n_valid rows valid.
        std::vector<uint8_t> validity((total_rows + 7) / 8, 0);
        for (int r = 0; r < n_valid; ++r) {
            validity[r >> 3] |= (1 << (r & 0x07));
        }
        FieldDataPtr field_data = milvus::storage::CreateFieldData(
            storage::DataType::VECTOR_FLOAT16, DataType::NONE, true, dim);
        auto typed =
            std::dynamic_pointer_cast<milvus::FieldData<milvus::Float16Vector>>(
                field_data);
        typed->FillFieldData(payload.data(), validity.data(), total_rows, 0);
        ASSERT_EQ(field_data->get_num_rows(), total_rows);
        ASSERT_EQ(field_data->get_valid_rows(), n_valid);
        ASSERT_EQ(field_data->IsNullable(), expect_nullable);
        // Serialize through InsertData and decode the bytes back.
        auto reader =
            std::make_shared<milvus::storage::PayloadReader>(field_data);
        storage::InsertData insert_data(reader);
        storage::FieldDataMeta meta{100, 101, 102, 103};
        insert_data.SetFieldDataMeta(meta);
        insert_data.SetTimestamps(0, 100);
        auto bytes = insert_data.Serialize(storage::StorageType::Remote);
        // Non-owning view: the no-op deleter leaves ownership with `bytes`.
        std::shared_ptr<uint8_t[]> view(bytes.data(), [&](uint8_t*) {});
        auto decoded = storage::DeserializeFileData(view, bytes.size());
        ASSERT_EQ(decoded->GetCodecType(), storage::InsertDataType);
        auto round_tripped = decoded->GetFieldData();
        ASSERT_EQ(round_tripped->get_data_type(),
                  storage::DataType::VECTOR_FLOAT16);
        ASSERT_EQ(round_tripped->get_num_rows(), total_rows);
        ASSERT_EQ(round_tripped->IsNullable(), expect_nullable);
        // Walk logical rows; each valid row maps to the next physical vector.
        int physical = 0;
        for (int row = 0; row < total_rows; ++row) {
            if (!round_tripped->is_valid(row)) {
                continue;
            }
            auto vec =
                static_cast<const float16*>(round_tripped->RawValue(row));
            for (int d = 0; d < dim; ++d) {
                ASSERT_EQ(vec[d], payload[physical * dim + d]);
            }
            physical++;
        }
    }
}
TEST(storage, IndexData) {
std::vector<uint8_t> data = {1, 2, 3, 4, 5, 6, 7, 8};
storage::IndexData index_data(data.data(), data.size());

View File

@ -536,7 +536,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_internal(const Config& config) {
int batch_size = batch_files.size();
for (int i = 0; i < batch_size; i++) {
auto field_data = field_datas[i].get()->GetFieldData();
num_rows += uint32_t(field_data->get_num_rows());
num_rows += uint32_t(field_data->get_valid_rows());
cache_raw_data_to_disk_common<DataType>(
field_data,
local_chunk_manager,
@ -634,7 +634,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_common(
auto sparse_rows =
static_cast<const knowhere::sparse::SparseRow<SparseValueType>*>(
field_data->Data());
for (size_t i = 0; i < field_data->Length(); ++i) {
for (size_t i = 0; i < field_data->get_valid_rows(); ++i) {
auto row = sparse_rows[i];
auto row_byte_size = row.data_byte_size();
uint32_t nnz = row.size();
@ -689,7 +689,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_common(
} else {
dim = field_data->get_dim();
auto data_size =
field_data->get_num_rows() * milvus::GetVecRowSize<DataType>(dim);
field_data->get_valid_rows() * milvus::GetVecRowSize<DataType>(dim);
local_chunk_manager->Write(local_data_path,
write_offset,
const_cast<void*>(field_data->Data()),
@ -761,7 +761,7 @@ DiskFileManagerImpl::cache_raw_data_to_disk_storage_v2(const Config& config) {
fs_);
}
for (auto& field_data : field_datas) {
num_rows += uint32_t(field_data->get_num_rows());
num_rows += uint32_t(field_data->get_valid_rows());
cache_raw_data_to_disk_common<T>(field_data,
local_chunk_manager,
local_data_path,

View File

@ -523,6 +523,262 @@ TEST_F(DiskAnnFileManagerTest, CacheOptFieldToDiskOnlyOneCategory) {
}
}
// End-to-end coverage for nullable vectors of every type: write an insert
// binlog to the remote chunk manager, then verify that DiskFileManagerImpl
// caches only the valid (non-null) rows to the local raw-data file, or — for
// sparse vectors — that the binlog round-trips with validity preserved.
TEST_F(DiskAnnFileManagerTest, CacheRawDataToDiskNullableVector) {
    const int64_t collection_id = 1;
    const int64_t partition_id = 2;
    const int64_t segment_id = 3;
    const int64_t field_id = 100;
    const int64_t dim = 128;
    const int64_t num_rows = 1000;
    // Descriptor for each vector type under test.
    struct VectorTypeInfo {
        DataType data_type;
        std::string type_name;  // used in file names and failure messages
        size_t element_size;    // bytes per element (0 for sparse)
        bool is_sparse;
    };
    std::vector<VectorTypeInfo> vector_types = {
        {DataType::VECTOR_FLOAT, "FLOAT", sizeof(float), false},
        {DataType::VECTOR_FLOAT16, "FLOAT16", sizeof(knowhere::fp16), false},
        {DataType::VECTOR_BFLOAT16, "BFLOAT16", sizeof(knowhere::bf16), false},
        {DataType::VECTOR_INT8, "INT8", sizeof(int8_t), false},
        {DataType::VECTOR_BINARY, "BINARY", dim / 8, false},
        {DataType::VECTOR_SPARSE_U32_F32, "SPARSE", 0, true}};
    for (const auto& vec_type : vector_types) {
        // Cover no nulls, partial nulls, and an all-null column.
        for (int null_percent : {0, 20, 100}) {
            int64_t valid_count = num_rows * (100 - null_percent) / 100;
            // LSB-first validity bitmap: the first valid_count rows are valid.
            std::vector<uint8_t> valid_data((num_rows + 7) / 8, 0);
            for (int64_t i = 0; i < valid_count; ++i) {
                valid_data[i >> 3] |= (1 << (i & 0x07));
            }
            FieldDataPtr field_data;
            std::vector<uint8_t> vec_data;
            std::unique_ptr<knowhere::sparse::SparseRow<float>[]> sparse_vecs;
            if (vec_type.is_sparse) {
                const int64_t sparse_dim = 1000;
                const float sparse_density = 0.1;
                // Sparse storage: rows exist only for valid entries.
                sparse_vecs = milvus::segcore::GenerateRandomSparseFloatVector(
                    valid_count, sparse_dim, sparse_density);
                field_data =
                    storage::CreateFieldData(DataType::VECTOR_SPARSE_U32_F32,
                                             DataType::NONE,
                                             true,
                                             sparse_dim,
                                             num_rows);
                auto field_data_impl = std::dynamic_pointer_cast<
                    milvus::FieldData<milvus::SparseFloatVector>>(field_data);
                field_data_impl->FillFieldData(
                    sparse_vecs.get(), valid_data.data(), num_rows, 0);
            } else {
                // Dense storage: allocate bytes for valid rows only and fill
                // with a deterministic byte pattern for later comparison.
                if (vec_type.data_type == DataType::VECTOR_BINARY) {
                    vec_data.resize(valid_count * dim / 8);
                } else {
                    vec_data.resize(valid_count * dim * vec_type.element_size);
                }
                for (size_t i = 0; i < vec_data.size(); ++i) {
                    vec_data[i] = static_cast<uint8_t>(i % 256);
                }
                field_data = storage::CreateFieldData(
                    vec_type.data_type, DataType::NONE, true, dim);
                // Dispatch to the concrete FieldData type for FillFieldData.
                if (vec_type.data_type == DataType::VECTOR_FLOAT) {
                    auto impl = std::dynamic_pointer_cast<
                        milvus::FieldData<milvus::FloatVector>>(field_data);
                    impl->FillFieldData(
                        vec_data.data(), valid_data.data(), num_rows, 0);
                } else if (vec_type.data_type == DataType::VECTOR_FLOAT16) {
                    auto impl = std::dynamic_pointer_cast<
                        milvus::FieldData<milvus::Float16Vector>>(field_data);
                    impl->FillFieldData(
                        vec_data.data(), valid_data.data(), num_rows, 0);
                } else if (vec_type.data_type == DataType::VECTOR_BFLOAT16) {
                    auto impl = std::dynamic_pointer_cast<
                        milvus::FieldData<milvus::BFloat16Vector>>(field_data);
                    impl->FillFieldData(
                        vec_data.data(), valid_data.data(), num_rows, 0);
                } else if (vec_type.data_type == DataType::VECTOR_INT8) {
                    auto impl = std::dynamic_pointer_cast<
                        milvus::FieldData<milvus::Int8Vector>>(field_data);
                    impl->FillFieldData(
                        vec_data.data(), valid_data.data(), num_rows, 0);
                } else if (vec_type.data_type == DataType::VECTOR_BINARY) {
                    auto impl = std::dynamic_pointer_cast<
                        milvus::FieldData<milvus::BinaryVector>>(field_data);
                    impl->FillFieldData(
                        vec_data.data(), valid_data.data(), num_rows, 0);
                }
            }
            ASSERT_EQ(field_data->get_num_rows(), num_rows);
            ASSERT_EQ(field_data->get_valid_rows(), valid_count);
            // Serialize to a remote-format binlog and stage it via cm_.
            auto payload_reader =
                std::make_shared<milvus::storage::PayloadReader>(field_data);
            storage::InsertData insert_data(payload_reader);
            FieldDataMeta field_data_meta = {
                collection_id, partition_id, segment_id, field_id};
            insert_data.SetFieldDataMeta(field_data_meta);
            insert_data.SetTimestamps(0, 100);
            auto serialized_data =
                insert_data.Serialize(storage::StorageType::Remote);
            std::string insert_file_path = "/tmp/diskann/nullable_" +
                                           vec_type.type_name + "_" +
                                           std::to_string(null_percent);
            boost::filesystem::remove_all(insert_file_path);
            cm_->Write(insert_file_path,
                       serialized_data.data(),
                       serialized_data.size());
            if (vec_type.is_sparse) {
                // Sparse path: read the binlog back and verify the decoded
                // payload keeps row counts, validity, and all elements.
                int64_t file_size = cm_->Size(insert_file_path);
                std::vector<uint8_t> buffer(file_size);
                cm_->Read(insert_file_path, buffer.data(), file_size);
                // Non-owning alias (no-op deleter); buffer owns the bytes.
                std::shared_ptr<uint8_t[]> serialized_data_ptr(
                    buffer.data(), [&](uint8_t*) {});
                auto new_insert_data = storage::DeserializeFileData(
                    serialized_data_ptr, buffer.size());
                ASSERT_EQ(new_insert_data->GetCodecType(),
                          storage::InsertDataType);
                auto new_payload = new_insert_data->GetFieldData();
                ASSERT_TRUE(new_payload->get_data_type() ==
                            DataType::VECTOR_SPARSE_U32_F32);
                ASSERT_EQ(new_payload->get_num_rows(), num_rows)
                    << "num_rows mismatch for " << vec_type.type_name
                    << " with null_percent=" << null_percent;
                ASSERT_EQ(new_payload->get_valid_rows(), valid_count)
                    << "valid_rows mismatch for " << vec_type.type_name
                    << " with null_percent=" << null_percent;
                ASSERT_TRUE(new_payload->IsNullable());
                // The bitmap marks the first valid_count rows valid, so
                // logical row i maps to physical row i while i < valid_count.
                for (int i = 0; i < num_rows; ++i) {
                    if (i < valid_count) {
                        ASSERT_TRUE(new_payload->is_valid(i))
                            << "Row " << i
                            << " should be valid for null_percent="
                            << null_percent;
                        auto original = &sparse_vecs[i];
                        auto new_vec =
                            static_cast<const knowhere::sparse::SparseRow<
                                milvus::SparseValueType>*>(
                                new_payload->RawValue(i));
                        ASSERT_EQ(original->size(), new_vec->size())
                            << "Size mismatch at row " << i
                            << " for null_percent=" << null_percent;
                        for (size_t j = 0; j < original->size(); ++j) {
                            ASSERT_EQ((*original)[j].id, (*new_vec)[j].id)
                                << "ID mismatch at row " << i << ", element "
                                << j << " for null_percent=" << null_percent;
                            ASSERT_EQ((*original)[j].val, (*new_vec)[j].val)
                                << "Value mismatch at row " << i << ", element "
                                << j << " for null_percent=" << null_percent;
                        }
                    } else {
                        ASSERT_FALSE(new_payload->is_valid(i))
                            << "Row " << i
                            << " should be null for null_percent="
                            << null_percent;
                    }
                }
            } else {
                // Dense path: run the real CacheRawDataToDisk and check the
                // local raw-data file (header: uint32 num_rows, uint32 dim).
                IndexMeta index_meta = {segment_id,
                                        field_id,
                                        1000,
                                        1,
                                        "test",
                                        "vec_field",
                                        vec_type.data_type,
                                        dim};
                auto file_manager = std::make_shared<DiskFileManagerImpl>(
                    storage::FileManagerContext(
                        field_data_meta, index_meta, cm_, fs_));
                milvus::Config config;
                config[INSERT_FILES_KEY] =
                    std::vector<std::string>{insert_file_path};
                std::string local_data_path;
                // Element type must match the template parameter.
                if (vec_type.data_type == DataType::VECTOR_FLOAT) {
                    local_data_path =
                        file_manager->CacheRawDataToDisk<float>(config);
                } else if (vec_type.data_type == DataType::VECTOR_INT8) {
                    local_data_path =
                        file_manager->CacheRawDataToDisk<int8_t>(config);
                } else if (vec_type.data_type == DataType::VECTOR_FLOAT16) {
                    local_data_path =
                        file_manager->CacheRawDataToDisk<knowhere::fp16>(
                            config);
                } else if (vec_type.data_type == DataType::VECTOR_BFLOAT16) {
                    local_data_path =
                        file_manager->CacheRawDataToDisk<knowhere::bf16>(
                            config);
                } else if (vec_type.data_type == DataType::VECTOR_BINARY) {
                    local_data_path =
                        file_manager->CacheRawDataToDisk<uint8_t>(config);
                }
                ASSERT_FALSE(local_data_path.empty())
                    << "Failed for " << vec_type.type_name
                    << " with null_percent=" << null_percent;
                auto local_chunk_manager =
                    LocalChunkManagerSingleton::GetInstance().GetChunkManager();
                uint32_t read_num_rows = 0;
                uint32_t read_dim = 0;
                local_chunk_manager->Read(
                    local_data_path, 0, &read_num_rows, sizeof(read_num_rows));
                local_chunk_manager->Read(local_data_path,
                                          sizeof(read_num_rows),
                                          &read_dim,
                                          sizeof(read_dim));
                // The header must count only valid rows (nulls are skipped).
                EXPECT_EQ(read_num_rows, valid_count)
                    << "Mismatch for " << vec_type.type_name
                    << " with null_percent=" << null_percent;
                EXPECT_EQ(read_dim, dim);
                size_t bytes_per_vector =
                    (vec_type.data_type == DataType::VECTOR_BINARY)
                        ? (dim / 8)
                        : (dim * vec_type.element_size);
                auto data_size = read_num_rows * bytes_per_vector;
                std::vector<uint8_t> buffer(data_size);
                local_chunk_manager->Read(
                    local_data_path,
                    sizeof(read_num_rows) + sizeof(read_dim),
                    buffer.data(),
                    data_size);
                // The cached bytes must equal the densely packed valid rows.
                EXPECT_EQ(buffer.size(), vec_data.size())
                    << "Data size mismatch for " << vec_type.type_name;
                for (size_t i = 0; i < std::min(buffer.size(), vec_data.size());
                     ++i) {
                    EXPECT_EQ(buffer[i], vec_data[i])
                        << "Data mismatch at byte " << i << " for "
                        << vec_type.type_name
                        << " with null_percent=" << null_percent;
                }
                local_chunk_manager->Remove(local_data_path);
            }
            cm_->Remove(insert_file_path);
        }
    }
}
TEST_F(DiskAnnFileManagerTest, FileCleanup) {
std::string local_index_file_path;
std::string local_text_index_file_path;

View File

@ -336,9 +336,13 @@ BaseEventData::Serialize() {
auto row = static_cast<
const knowhere::sparse::SparseRow<SparseValueType>*>(
field_data->RawValue(offset));
payload_writer->add_one_binary_payload(
static_cast<const uint8_t*>(row->data()),
row->data_byte_size());
if (row) {
payload_writer->add_one_binary_payload(
static_cast<const uint8_t*>(row->data()),
row->data_byte_size());
} else {
payload_writer->add_one_binary_payload(nullptr, -1);
}
}
break;
}

View File

@ -71,13 +71,44 @@ PayloadReader::init(const uint8_t* data, int length, bool is_field_data) {
auto file_meta = arrow_reader->parquet_reader()->metadata();
// dim is unused for sparse float vector
dim_ =
(IsVectorDataType(column_type_) &&
!IsVectorArrayDataType(column_type_) &&
!IsSparseFloatVectorDataType(column_type_))
? GetDimensionFromFileMetaData(
file_meta->schema()->Column(column_index), column_type_)
: 1;
// For nullable vectors, dim is stored in Arrow schema metadata
if (IsVectorDataType(column_type_) &&
!IsVectorArrayDataType(column_type_) &&
!IsSparseFloatVectorDataType(column_type_)) {
if (nullable_) {
std::shared_ptr<arrow::Schema> arrow_schema;
auto st = arrow_reader->GetSchema(&arrow_schema);
AssertInfo(st.ok(), "Failed to get arrow schema");
AssertInfo(arrow_schema->num_fields() == 1,
"Vector field should have exactly 1 field, got {}",
arrow_schema->num_fields());
auto field = arrow_schema->field(0);
if (field->HasMetadata()) {
auto metadata = field->metadata();
if (metadata->Contains(DIM_KEY)) {
auto dim_str = metadata->Get(DIM_KEY).ValueOrDie();
dim_ = std::stoi(dim_str);
AssertInfo(
dim_ > 0,
"nullable vector dim must be positive, got {}",
dim_);
} else {
ThrowInfo(DataTypeInvalid,
"nullable vector field metadata missing "
"required 'dim' field");
}
} else {
ThrowInfo(DataTypeInvalid,
"nullable vector field is missing metadata");
}
} else {
dim_ = GetDimensionFromFileMetaData(
file_meta->schema()->Column(column_index), column_type_);
}
} else {
dim_ = 1;
}
// For VectorArray, get element type and dim from Arrow schema metadata
auto element_type = DataType::NONE;
@ -133,8 +164,10 @@ PayloadReader::init(const uint8_t* data, int length, bool is_field_data) {
field_data_->FillFieldData(array);
}
AssertInfo(field_data_->IsFull(),
"field data hasn't been filled done");
if (!nullable_ || !IsVectorDataType(column_type_)) {
AssertInfo(field_data_->IsFull(),
"field data hasn't been filled done");
}
} else {
arrow_reader_ = std::move(arrow_reader);
record_batch_reader_ = std::move(rb_reader);

View File

@ -35,7 +35,6 @@ PayloadWriter::PayloadWriter(const DataType column_type, int dim, bool nullable)
AssertInfo(column_type != DataType::VECTOR_SPARSE_U32_F32,
"PayloadWriter for Sparse Float Vector should be created "
"using the constructor without dimension");
AssertInfo(nullable == false, "only scalcar type support null now");
init_dimension(dim);
}
@ -63,7 +62,7 @@ PayloadWriter::init_dimension(int dim) {
}
dimension_ = dim;
builder_ = CreateArrowBuilder(column_type_, element_type_, dim);
builder_ = CreateArrowBuilder(column_type_, element_type_, dim, nullable_);
schema_ = CreateArrowSchema(column_type_, dim, nullable_);
}
@ -112,8 +111,10 @@ PayloadWriter::finish() {
std::shared_ptr<parquet::ArrowWriterProperties> arrow_properties =
parquet::default_arrow_writer_properties();
if (column_type_ == DataType::VECTOR_ARRAY) {
// For VectorArray, we need to store schema metadata
if (column_type_ == DataType::VECTOR_ARRAY ||
(nullable_ && IsVectorDataType(column_type_) &&
!IsSparseFloatVectorDataType(column_type_))) {
// For VectorArray and nullable vectors, we need to store schema metadata
parquet::ArrowWriterProperties::Builder arrow_props_builder;
arrow_props_builder.store_schema();
arrow_properties = arrow_props_builder.build();

View File

@ -128,13 +128,40 @@ ReadMediumType(BinlogReaderPtr reader) {
void
add_vector_payload(std::shared_ptr<arrow::ArrayBuilder> builder,
uint8_t* values,
int length) {
const uint8_t* valid_data,
bool nullable,
int length,
int byte_width) {
AssertInfo(builder != nullptr, "empty arrow builder");
auto binary_builder =
std::dynamic_pointer_cast<arrow::FixedSizeBinaryBuilder>(builder);
auto ast = binary_builder->AppendValues(values, length);
AssertInfo(
ast.ok(), "append value to arrow builder failed: {}", ast.ToString());
AssertInfo((nullable && valid_data) || !nullable,
"valid_data is required for nullable vectors");
arrow::Status ast;
if (nullable) {
auto binary_builder =
std::dynamic_pointer_cast<arrow::BinaryBuilder>(builder);
int valid_index = 0;
for (int i = 0; i < length; ++i) {
auto bit = (valid_data[i >> 3] >> (i & 0x07)) & 1;
if (bit) {
ast = binary_builder->Append(values + valid_index * byte_width,
byte_width);
valid_index++;
} else {
ast = binary_builder->AppendNull();
}
AssertInfo(ast.ok(),
"append value to arrow builder failed: {}",
ast.ToString());
}
} else {
auto binary_builder =
std::dynamic_pointer_cast<arrow::FixedSizeBinaryBuilder>(builder);
ast = binary_builder->AppendValues(values, length);
AssertInfo(ast.ok(),
"append value to arrow builder failed: {}",
ast.ToString());
}
}
// append values for numeric data
@ -223,12 +250,64 @@ AddPayloadToArrowBuilder(std::shared_ptr<arrow::ArrayBuilder> builder,
break;
}
case DataType::VECTOR_FLOAT16:
case DataType::VECTOR_BFLOAT16:
case DataType::VECTOR_BINARY:
case DataType::VECTOR_INT8:
case DataType::VECTOR_FLOAT: {
add_vector_payload(builder, const_cast<uint8_t*>(raw_data), length);
AssertInfo(payload.dimension.has_value(),
"dimension is required for VECTOR_FLOAT");
int byte_width = payload.dimension.value() * sizeof(float);
add_vector_payload(builder,
const_cast<uint8_t*>(raw_data),
payload.valid_data,
nullable,
length,
byte_width);
break;
}
case DataType::VECTOR_BINARY: {
AssertInfo(payload.dimension.has_value(),
"dimension is required for VECTOR_BINARY");
int byte_width = (payload.dimension.value() + 7) / 8;
add_vector_payload(builder,
const_cast<uint8_t*>(raw_data),
payload.valid_data,
nullable,
length,
byte_width);
break;
}
case DataType::VECTOR_FLOAT16: {
AssertInfo(payload.dimension.has_value(),
"dimension is required for VECTOR_FLOAT16");
int byte_width = payload.dimension.value() * 2;
add_vector_payload(builder,
const_cast<uint8_t*>(raw_data),
payload.valid_data,
nullable,
length,
byte_width);
break;
}
case DataType::VECTOR_BFLOAT16: {
AssertInfo(payload.dimension.has_value(),
"dimension is required for VECTOR_BFLOAT16");
int byte_width = payload.dimension.value() * 2;
add_vector_payload(builder,
const_cast<uint8_t*>(raw_data),
payload.valid_data,
nullable,
length,
byte_width);
break;
}
case DataType::VECTOR_INT8: {
AssertInfo(payload.dimension.has_value(),
"dimension is required for VECTOR_INT8");
int byte_width = payload.dimension.value() * sizeof(int8_t);
add_vector_payload(builder,
const_cast<uint8_t*>(raw_data),
payload.valid_data,
nullable,
length,
byte_width);
break;
}
case DataType::VECTOR_SPARSE_U32_F32: {
@ -380,30 +459,48 @@ CreateArrowBuilder(DataType data_type) {
}
std::shared_ptr<arrow::ArrayBuilder>
CreateArrowBuilder(DataType data_type, DataType element_type, int dim) {
CreateArrowBuilder(DataType data_type,
DataType element_type,
int dim,
bool nullable) {
switch (static_cast<DataType>(data_type)) {
case DataType::VECTOR_FLOAT: {
AssertInfo(dim > 0, "invalid dim value: {}", dim);
if (nullable) {
return std::make_shared<arrow::BinaryBuilder>();
}
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(dim * sizeof(float)));
}
case DataType::VECTOR_BINARY: {
AssertInfo(dim % 8 == 0 && dim > 0, "invalid dim value: {}", dim);
if (nullable) {
return std::make_shared<arrow::BinaryBuilder>();
}
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(dim / 8));
}
case DataType::VECTOR_FLOAT16: {
AssertInfo(dim > 0, "invalid dim value: {}", dim);
if (nullable) {
return std::make_shared<arrow::BinaryBuilder>();
}
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(dim * sizeof(float16)));
}
case DataType::VECTOR_BFLOAT16: {
AssertInfo(dim > 0, "invalid dim value");
if (nullable) {
return std::make_shared<arrow::BinaryBuilder>();
}
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(dim * sizeof(bfloat16)));
}
case DataType::VECTOR_INT8: {
AssertInfo(dim > 0, "invalid dim value");
if (nullable) {
return std::make_shared<arrow::BinaryBuilder>();
}
return std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(dim * sizeof(int8)));
}
@ -576,6 +673,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) {
switch (static_cast<DataType>(data_type)) {
case DataType::VECTOR_FLOAT: {
AssertInfo(dim > 0, "invalid dim value: {}", dim);
if (nullable) {
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
new arrow::KeyValueMetadata());
metadata->Append(DIM_KEY, std::to_string(dim));
return arrow::schema(
{arrow::field("val", arrow::binary(), nullable, metadata)});
}
return arrow::schema(
{arrow::field("val",
arrow::fixed_size_binary(dim * sizeof(float)),
@ -583,11 +687,25 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) {
}
case DataType::VECTOR_BINARY: {
AssertInfo(dim % 8 == 0 && dim > 0, "invalid dim value: {}", dim);
if (nullable) {
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
new arrow::KeyValueMetadata());
metadata->Append(DIM_KEY, std::to_string(dim));
return arrow::schema(
{arrow::field("val", arrow::binary(), nullable, metadata)});
}
return arrow::schema({arrow::field(
"val", arrow::fixed_size_binary(dim / 8), nullable)});
}
case DataType::VECTOR_FLOAT16: {
AssertInfo(dim > 0, "invalid dim value: {}", dim);
if (nullable) {
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
new arrow::KeyValueMetadata());
metadata->Append(DIM_KEY, std::to_string(dim));
return arrow::schema(
{arrow::field("val", arrow::binary(), nullable, metadata)});
}
return arrow::schema(
{arrow::field("val",
arrow::fixed_size_binary(dim * sizeof(float16)),
@ -595,6 +713,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) {
}
case DataType::VECTOR_BFLOAT16: {
AssertInfo(dim > 0, "invalid dim value");
if (nullable) {
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
new arrow::KeyValueMetadata());
metadata->Append(DIM_KEY, std::to_string(dim));
return arrow::schema(
{arrow::field("val", arrow::binary(), nullable, metadata)});
}
return arrow::schema(
{arrow::field("val",
arrow::fixed_size_binary(dim * sizeof(bfloat16)),
@ -606,6 +731,13 @@ CreateArrowSchema(DataType data_type, int dim, bool nullable) {
}
case DataType::VECTOR_INT8: {
AssertInfo(dim > 0, "invalid dim value");
if (nullable) {
auto metadata = std::shared_ptr<arrow::KeyValueMetadata>(
new arrow::KeyValueMetadata());
metadata->Append(DIM_KEY, std::to_string(dim));
return arrow::schema(
{arrow::field("val", arrow::binary(), nullable, metadata)});
}
return arrow::schema(
{arrow::field("val",
arrow::fixed_size_binary(dim * sizeof(int8)),
@ -1103,22 +1235,22 @@ CreateFieldData(const DataType& type,
type, nullable, total_num_rows);
case DataType::VECTOR_FLOAT:
return std::make_shared<FieldData<FloatVector>>(
dim, type, total_num_rows);
dim, type, nullable, total_num_rows);
case DataType::VECTOR_BINARY:
return std::make_shared<FieldData<BinaryVector>>(
dim, type, total_num_rows);
dim, type, nullable, total_num_rows);
case DataType::VECTOR_FLOAT16:
return std::make_shared<FieldData<Float16Vector>>(
dim, type, total_num_rows);
dim, type, nullable, total_num_rows);
case DataType::VECTOR_BFLOAT16:
return std::make_shared<FieldData<BFloat16Vector>>(
dim, type, total_num_rows);
dim, type, nullable, total_num_rows);
case DataType::VECTOR_SPARSE_U32_F32:
return std::make_shared<FieldData<SparseFloatVector>>(
type, total_num_rows);
type, nullable, total_num_rows);
case DataType::VECTOR_INT8:
return std::make_shared<FieldData<Int8Vector>>(
dim, type, total_num_rows);
dim, type, nullable, total_num_rows);
case DataType::VECTOR_ARRAY:
return std::make_shared<FieldData<VectorArray>>(
dim, element_type, total_num_rows);

View File

@ -59,7 +59,10 @@ std::shared_ptr<arrow::ArrayBuilder>
CreateArrowBuilder(DataType data_type);
std::shared_ptr<arrow::ArrayBuilder>
CreateArrowBuilder(DataType data_type, DataType element_type, int dim);
CreateArrowBuilder(DataType data_type,
DataType element_type,
int dim,
bool nullable = false);
/// \brief Utility function to create arrow:Scalar from FieldMeta.default_value
///

View File

@ -326,6 +326,10 @@ GenerateRandomSparseFloatVector(size_t rows,
size_t cols = kTestSparseDim,
float density = kTestSparseVectorDensity,
int seed = 42) {
if (rows == 0) {
return std::make_unique<
knowhere::sparse::SparseRow<milvus::SparseValueType>[]>(0);
}
int32_t num_elements = static_cast<int32_t>(rows * cols * density);
std::mt19937 rng(seed);
@ -542,7 +546,8 @@ DataGen(SchemaPtr schema,
int group_count = 1,
bool random_pk = false,
bool random_val = true,
bool random_valid = false) {
bool random_valid = false,
int null_percent = 50) {
using std::vector;
std::default_random_engine random(seed);
std::normal_distribution<> distr(0, 1);
@ -635,41 +640,138 @@ DataGen(SchemaPtr schema,
return data;
};
// Builds the validity bitmap for one field of generated test data.
//
// For a nullable field, each of the N rows is flagged valid either at
// random (when `random_valid` is set; note this uses the global rand(),
// not the seeded `random` engine declared above) or deterministically:
// rows whose index modulo 100 falls below `null_percent` are null, the
// rest are valid. `valid_count` tallies the valid rows so call sites can
// generate exactly that many physical vectors for the field.
// For a non-nullable field the bitmap contents are unused by callers
// (they pass nullptr instead) and valid_count == N.
auto generate_valid_data = [&](const FieldMeta& field_meta, int64_t N) {
    // Returned as a small aggregate so call sites can use structured
    // bindings: auto [valid_count, valid_data] = generate_valid_data(...).
    struct Result {
        int64_t valid_count;
        FixedVector<bool> valid_data;
    };
    Result result;
    result.valid_data.resize(N);
    result.valid_count = 0;
    bool is_nullable = field_meta.is_nullable();
    if (is_nullable) {
        for (int i = 0; i < N; ++i) {
            if (random_valid) {
                int x = rand();
                result.valid_data[i] = x % 2 == 0 ? true : false;
            } else {
                // Deterministic pattern: with the default null_percent=50,
                // indices 0..49 of each 100-row stripe are null.
                result.valid_data[i] = (i % 100) >= null_percent;
            }
            if (result.valid_data[i]) {
                result.valid_count++;
            }
        }
    } else {
        result.valid_count = N;
    }
    return result;
};
for (auto field_id : schema->get_field_ids()) {
auto field_meta = schema->operator[](field_id);
switch (field_meta.get_data_type()) {
case DataType::VECTOR_FLOAT: {
auto data = generate_float_vector(field_meta, N);
insert_cols(data, N, field_meta, random_valid);
auto [valid_count, valid_data] =
generate_valid_data(field_meta, N);
bool is_nullable = field_meta.is_nullable();
auto data = generate_float_vector(field_meta, valid_count);
auto array = milvus::segcore::CreateVectorDataArrayFrom(
data.data(),
is_nullable ? valid_data.data() : nullptr,
N,
valid_count,
field_meta);
insert_data->mutable_fields_data()->AddAllocated(
array.release());
break;
}
case DataType::VECTOR_BINARY: {
auto data = generate_binary_vector(field_meta, N);
insert_cols(data, N, field_meta, random_valid);
auto [valid_count, valid_data] =
generate_valid_data(field_meta, N);
bool is_nullable = field_meta.is_nullable();
auto data = generate_binary_vector(field_meta, valid_count);
auto array = milvus::segcore::CreateVectorDataArrayFrom(
data.data(),
is_nullable ? valid_data.data() : nullptr,
N,
valid_count,
field_meta);
insert_data->mutable_fields_data()->AddAllocated(
array.release());
break;
}
case DataType::VECTOR_FLOAT16: {
auto data = generate_float16_vector(field_meta, N);
insert_cols(data, N, field_meta, random_valid);
auto [valid_count, valid_data] =
generate_valid_data(field_meta, N);
bool is_nullable = field_meta.is_nullable();
auto data = generate_float16_vector(field_meta, valid_count);
auto array = milvus::segcore::CreateVectorDataArrayFrom(
data.data(),
is_nullable ? valid_data.data() : nullptr,
N,
valid_count,
field_meta);
insert_data->mutable_fields_data()->AddAllocated(
array.release());
break;
}
case DataType::VECTOR_BFLOAT16: {
auto data = generate_bfloat16_vector(field_meta, N);
insert_cols(data, N, field_meta, random_valid);
auto [valid_count, valid_data] =
generate_valid_data(field_meta, N);
bool is_nullable = field_meta.is_nullable();
auto data = generate_bfloat16_vector(field_meta, valid_count);
auto array = milvus::segcore::CreateVectorDataArrayFrom(
data.data(),
is_nullable ? valid_data.data() : nullptr,
N,
valid_count,
field_meta);
insert_data->mutable_fields_data()->AddAllocated(
array.release());
break;
}
case DataType::VECTOR_SPARSE_U32_F32: {
auto res = GenerateRandomSparseFloatVector(
N, kTestSparseDim, kTestSparseVectorDensity, seed);
auto array = milvus::segcore::CreateDataArrayFrom(
res.get(), nullptr, N, field_meta);
auto [valid_count, valid_data] =
generate_valid_data(field_meta, N);
bool is_nullable = field_meta.is_nullable();
auto res =
GenerateRandomSparseFloatVector(valid_count,
kTestSparseDim,
kTestSparseVectorDensity,
seed);
auto array = milvus::segcore::CreateVectorDataArrayFrom(
res.get(),
is_nullable ? valid_data.data() : nullptr,
N,
valid_count,
field_meta);
insert_data->mutable_fields_data()->AddAllocated(
array.release());
break;
}
case DataType::VECTOR_INT8: {
auto data = generate_int8_vector(field_meta, N);
insert_cols(data, N, field_meta, random_valid);
auto [valid_count, valid_data] =
generate_valid_data(field_meta, N);
bool is_nullable = field_meta.is_nullable();
auto data = generate_int8_vector(field_meta, valid_count);
auto array = milvus::segcore::CreateVectorDataArrayFrom(
data.data(),
is_nullable ? valid_data.data() : nullptr,
N,
valid_count,
field_meta);
insert_data->mutable_fields_data()->AddAllocated(
array.release());
break;
}

View File

@ -157,13 +157,13 @@ PrepareSingleFieldInsertBinlog(int64_t collection_id,
int64_t row_count = 0;
for (auto i = 0; i < field_datas.size(); ++i) {
auto& field_data = field_datas[i];
row_count += field_data->Length();
row_count += field_data->get_num_rows();
auto file = "./data/test/" + std::to_string(collection_id) + "/" +
std::to_string(partition_id) + "/" +
std::to_string(segment_id) + "/" +
std::to_string(field_id) + "/" + std::to_string(i);
files.push_back(file);
row_counts.push_back(field_data->Length());
row_counts.push_back(field_data->get_num_rows());
auto payload_reader =
std::make_shared<milvus::storage::PayloadReader>(field_data);
auto insert_data = std::make_shared<InsertData>(payload_reader);

View File

@ -141,7 +141,7 @@ func (s *InsertBufferSuite) TestBuffer() {
memSize := insertBuffer.Buffer(groups[0], &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200})
s.EqualValues(100, insertBuffer.MinTimestamp())
s.EqualValues(5367, memSize)
s.EqualValues(5376, memSize)
}
func (s *InsertBufferSuite) TestYield() {

View File

@ -195,12 +195,12 @@ func (s *L0WriteBufferSuite) TestBufferData() {
value, err := metrics.DataNodeFlowGraphBufferDataSize.GetMetricWithLabelValues(fmt.Sprint(paramtable.GetNodeID()), fmt.Sprint(s.metacache.Collection()))
s.NoError(err)
s.MetricsEqual(value, 5607)
s.MetricsEqual(value, 5616)
delMsg = s.composeDeleteMsg(lo.Map(pks, func(id int64, _ int) storage.PrimaryKey { return storage.NewInt64PrimaryKey(id) }))
err = wb.BufferData([]*InsertData{}, []*msgstream.DeleteMsg{delMsg}, &msgpb.MsgPosition{Timestamp: 100}, &msgpb.MsgPosition{Timestamp: 200})
s.NoError(err)
s.MetricsEqual(value, 5847)
s.MetricsEqual(value, 5856)
})
}

View File

@ -65,11 +65,15 @@ func genInsertMsgsByPartition(ctx context.Context,
return msg
}
fieldsData := insertMsg.GetFieldsData()
idxComputer := typeutil.NewFieldDataIdxComputer(fieldsData)
repackedMsgs := make([]msgstream.TsMsg, 0)
requestSize := 0
msg := createInsertMsg(segmentID, channelName)
for _, offset := range rowOffsets {
curRowMessageSize, err := typeutil.EstimateEntitySize(insertMsg.GetFieldsData(), offset)
fieldIdxs := idxComputer.Compute(int64(offset))
curRowMessageSize, err := typeutil.EstimateEntitySize(fieldsData, offset, fieldIdxs...)
if err != nil {
return nil, err
}
@ -81,7 +85,7 @@ func genInsertMsgsByPartition(ctx context.Context,
requestSize = 0
}
typeutil.AppendFieldData(msg.FieldsData, insertMsg.GetFieldsData(), int64(offset))
typeutil.AppendFieldData(msg.FieldsData, fieldsData, int64(offset), fieldIdxs...)
msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset])
msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset])
msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset])

View File

@ -258,6 +258,11 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
}
idxComputers := make([]*typeutil.FieldDataIdxComputer, subSearchNum)
for i, srd := range subSearchResultData {
idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
}
var realTopK int64 = -1
var retSize int64
@ -316,7 +321,8 @@ func reduceSearchResultDataWithGroupBy(ctx context.Context, subSearchResultData
for _, groupEntity := range groupEntities {
subResData := subSearchResultData[groupEntity.subSearchIdx]
if len(ret.Results.FieldsData) > 0 {
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx)
fieldIdxs := idxComputers[groupEntity.subSearchIdx].Compute(groupEntity.resultIdx)
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subResData.FieldsData, groupEntity.resultIdx, fieldIdxs...)
}
typeutil.AppendPKs(ret.Results.Ids, groupEntity.id)
ret.Results.Scores = append(ret.Results.Scores, groupEntity.score)
@ -424,6 +430,12 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
subSearchNqOffset[i][j] = subSearchNqOffset[i][j-1] + subSearchResultData[i].Topks[j-1]
}
}
idxComputers := make([]*typeutil.FieldDataIdxComputer, subSearchNum)
for i, srd := range subSearchResultData {
idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
}
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
// reducing nq * topk results
for i := int64(0); i < nq; i++ {
@ -456,7 +468,9 @@ func reduceSearchResultDataNoGroupBy(ctx context.Context, subSearchResultData []
score := subSearchResultData[subSearchIdx].Scores[resultDataIdx]
if len(ret.Results.FieldsData) > 0 {
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, subSearchResultData[subSearchIdx].FieldsData, resultDataIdx)
fieldsData := subSearchResultData[subSearchIdx].FieldsData
fieldIdxs := idxComputers[subSearchIdx].Compute(resultDataIdx)
retSize += typeutil.AppendFieldData(ret.Results.FieldsData, fieldsData, resultDataIdx, fieldIdxs...)
}
typeutil.CopyPk(ret.Results.Ids, subSearchResultData[subSearchIdx].GetIds(), int(resultDataIdx))
ret.Results.Scores = append(ret.Results.Scores, score)

View File

@ -563,9 +563,6 @@ func (t *addCollectionFieldTask) PreExecute(ctx context.Context) error {
return merr.WrapErrParameterInvalid("valid field", fmt.Sprintf("field data type: %s is not supported", t.fieldSchema.GetDataType()))
}
if typeutil.IsVectorType(t.fieldSchema.DataType) {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("not support to add vector field, field name = %s", t.fieldSchema.Name))
}
if funcutil.SliceContain([]string{common.RowIDFieldName, common.TimeStampFieldName, common.MetaFieldName, common.NamespaceFieldName}, t.fieldSchema.GetName()) {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("not support to add system field, field name = %s", t.fieldSchema.Name))
}
@ -575,6 +572,17 @@ func (t *addCollectionFieldTask) PreExecute(ctx context.Context) error {
if !t.fieldSchema.Nullable {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("added field must be nullable, please check it, field name = %s", t.fieldSchema.Name))
}
if typeutil.IsVectorType(t.fieldSchema.DataType) && t.fieldSchema.Nullable {
if t.fieldSchema.DataType == schemapb.DataType_FloatVector ||
t.fieldSchema.DataType == schemapb.DataType_Float16Vector ||
t.fieldSchema.DataType == schemapb.DataType_BFloat16Vector ||
t.fieldSchema.DataType == schemapb.DataType_BinaryVector ||
t.fieldSchema.DataType == schemapb.DataType_Int8Vector {
if len(t.fieldSchema.TypeParams) == 0 {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("vector field must have dimension specified, field name = %s", t.fieldSchema.Name))
}
}
}
if t.fieldSchema.AutoID {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("only primary field can speficy AutoID with true, field name = %s", t.fieldSchema.Name))
}

View File

@ -790,6 +790,10 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
}
cursors := make([]int64, len(validRetrieveResults))
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(validRetrieveResults))
for i, vr := range validRetrieveResults {
idxComputers[i] = typeutil.NewFieldDataIdxComputer(vr.GetFieldsData())
}
if queryParams != nil && queryParams.limit != typeutil.Unlimited {
// IReduceInOrderForBest will try to get as many results as possible
@ -819,7 +823,8 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re
if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) {
break
}
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
fieldIdxs := idxComputers[sel].Compute(cursors[sel])
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel], fieldIdxs...)
// limit retrieve result to avoid oom
if retSize > maxOutputSize {

View File

@ -1009,16 +1009,30 @@ func TestAddFieldTask(t *testing.T) {
assert.Error(t, err)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
// not support vector field
fSchema = &schemapb.FieldSchema{
Name: "vec_field",
DataType: schemapb.DataType_FloatVector,
Nullable: true,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "128"},
},
}
bytes, err = proto.Marshal(fSchema)
assert.NoError(t, err)
task.Schema = bytes
err = task.PreExecute(ctx)
assert.Error(t, err)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.NoError(t, err)
fSchema = &schemapb.FieldSchema{
Name: "sparse_vec",
DataType: schemapb.DataType_SparseFloatVector,
Nullable: true,
}
bytes, err = proto.Marshal(fSchema)
assert.NoError(t, err)
task.Schema = bytes
err = task.PreExecute(ctx)
assert.NoError(t, err)
// not support system field
fSchema = &schemapb.FieldSchema{
@ -2595,11 +2609,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
})
t.Run("upsert", func(t *testing.T) {
// upsert require pk unique in same batch
hash := make([]uint32, nb)
for i := 0; i < nb; i++ {
hash[i] = uint32(i)
}
hash := testutils.GenerateHashKeys(nb)
task := &upsertTask{
upsertMsg: &msgstream.UpsertMsg{
InsertMsg: &BaseInsertTask{

View File

@ -392,6 +392,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
}
baseIdx := 0
idxComputer := typeutil.NewFieldDataIdxComputer(existFieldData)
for _, idx := range updateIdxInUpsert {
typeutil.AppendIDs(it.deletePKs, upsertIDs, idx)
oldPK := typeutil.GetPK(upsertIDs, int64(idx))
@ -399,7 +400,8 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
if !ok {
return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping")
}
typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex))
fieldIdxs := idxComputer.Compute(int64(existIndex))
typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex), fieldIdxs...)
err := typeutil.UpdateFieldData(it.insertFieldData, it.upsertMsg.InsertMsg.GetFieldsData(), int64(baseIdx), int64(idx))
baseIdx += 1
if err != nil {
@ -438,8 +440,32 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
insertWithNullField = append(insertWithNullField, fieldData)
}
}
for _, idx := range insertIdxInUpsert {
typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx))
vectorIdxMap := make([][]int64, len(insertIdxInUpsert))
for rowIdx, offset := range insertIdxInUpsert {
vectorIdxMap[rowIdx] = make([]int64, len(insertWithNullField))
for fieldIdx := range insertWithNullField {
vectorIdxMap[rowIdx][fieldIdx] = int64(offset)
}
}
for fieldIdx, fieldData := range insertWithNullField {
validData := fieldData.GetValidData()
if len(validData) > 0 && typeutil.IsVectorType(fieldData.Type) {
dataIdx := int64(0)
rowIdx := 0
for i := 0; i < len(validData) && rowIdx < len(insertIdxInUpsert); i++ {
if i == insertIdxInUpsert[rowIdx] {
vectorIdxMap[rowIdx][fieldIdx] = dataIdx
rowIdx++
}
if validData[i] {
dataIdx++
}
}
}
}
for rowIdx, idx := range insertIdxInUpsert {
typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx), vectorIdxMap[rowIdx]...)
}
}
@ -620,6 +646,10 @@ func ToCompressedFormatNullable(field *schemapb.FieldData) error {
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
}
case *schemapb.FieldData_Vectors:
// Vector data is already in compressed format, skip
return nil
default:
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
}
@ -1077,7 +1107,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error {
}
}
// deduplicate upsert data to handle duplicate primary keys in the same batch
// check for duplicate primary keys in the same batch
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
if err != nil {
log.Warn("fail to get primary field schema", zap.Error(err))

View File

@ -141,7 +141,7 @@ func (v *validateUtil) Validate(data []*schemapb.FieldData, helper *typeutil.Sch
return err
}
case schemapb.DataType_SparseFloatVector:
if err := v.checkSparseFloatFieldData(field, fieldSchema); err != nil {
if err := v.checkSparseFloatVectorFieldData(field, fieldSchema); err != nil {
return err
}
case schemapb.DataType_Int8Vector:
@ -219,6 +219,13 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
msg := fmt.Sprintf("the dim (%d) of field data(%s) is not equal to schema dim (%d)", dataDim, fieldName, schemaDim)
return merr.WrapErrParameterInvalid(schemaDim, dataDim, msg)
}
// getExpectedVectorRows returns how many physical vector rows this
// field's payload should carry. A nullable vector field that ships a
// validity bitmap stores data only for valid (non-null) rows, so the
// expectation is the count of true entries in ValidData; in every other
// case each logical row must be present, i.e. numRows (captured from
// the enclosing checkAligned scope).
getExpectedVectorRows := func(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) uint64 {
	validData := field.GetValidData()
	if fieldSchema.GetNullable() && len(validData) > 0 {
		return uint64(getValidNumber(validData))
	}
	return numRows
}
for _, field := range data {
switch field.GetType() {
case schemapb.DataType_FloatVector:
@ -241,7 +248,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
return errDimMismatch(field.GetFieldName(), dataDim, dim)
}
if n != numRows {
expectedRows := getExpectedVectorRows(field, f)
if n != expectedRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}
@ -265,7 +273,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
return err
}
if n != numRows {
expectedRows := getExpectedVectorRows(field, f)
if n != expectedRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}
@ -289,7 +298,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
return err
}
if n != numRows {
expectedRows := getExpectedVectorRows(field, f)
if n != expectedRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}
@ -313,13 +323,19 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
return err
}
if n != numRows {
expectedRows := getExpectedVectorRows(field, f)
if n != expectedRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}
case schemapb.DataType_SparseFloatVector:
f, err := schema.GetFieldFromName(field.GetFieldName())
if err != nil {
return err
}
n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents))
if n != numRows {
expectedRows := getExpectedVectorRows(field, f)
if n != expectedRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}
@ -343,7 +359,8 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
return errDimMismatch(field.GetFieldName(), dataDim, dim)
}
if n != numRows {
expectedRows := getExpectedVectorRows(field, f)
if n != expectedRows {
return errNumRowsMismatch(field.GetFieldName(), n)
}
@ -728,7 +745,7 @@ func getValidNumber(validData []bool) int {
func (v *validateUtil) checkFloatVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
floatArray := field.GetVectors().GetFloatVector().GetData()
if floatArray == nil {
if floatArray == nil && !fieldSchema.GetNullable() {
msg := fmt.Sprintf("float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need float vector", "got nil", msg)
}
@ -743,8 +760,11 @@ func (v *validateUtil) checkFloatVectorFieldData(field *schemapb.FieldData, fiel
func (v *validateUtil) checkFloat16VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
float16VecArray := field.GetVectors().GetFloat16Vector()
if float16VecArray == nil {
msg := fmt.Sprintf("float16 float field '%v' is illegal, nil Vector_Float16 type", field.GetFieldName())
return merr.WrapErrParameterInvalid("need vector_float16 array", "got nil", msg)
if !fieldSchema.GetNullable() {
msg := fmt.Sprintf("float16 vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need float16 vector", "got nil", msg)
}
return nil
}
if v.checkNAN {
return typeutil.VerifyFloats16(float16VecArray)
@ -755,8 +775,11 @@ func (v *validateUtil) checkFloat16VectorFieldData(field *schemapb.FieldData, fi
func (v *validateUtil) checkBFloat16VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
bfloat16VecArray := field.GetVectors().GetBfloat16Vector()
if bfloat16VecArray == nil {
msg := fmt.Sprintf("bfloat16 float field '%v' is illegal, nil Vector_BFloat16 type", field.GetFieldName())
return merr.WrapErrParameterInvalid("need vector_bfloat16 array", "got nil", msg)
if !fieldSchema.GetNullable() {
msg := fmt.Sprintf("bfloat16 vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need bfloat16 vector", "got nil", msg)
}
return nil
}
if v.checkNAN {
return typeutil.VerifyBFloats16(bfloat16VecArray)
@ -766,31 +789,33 @@ func (v *validateUtil) checkBFloat16VectorFieldData(field *schemapb.FieldData, f
func (v *validateUtil) checkBinaryVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
bVecArray := field.GetVectors().GetBinaryVector()
if bVecArray == nil {
msg := fmt.Sprintf("binary float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need bytes array", "got nil", msg)
if bVecArray == nil && !fieldSchema.GetNullable() {
msg := fmt.Sprintf("binary vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need binary vector", "got nil", msg)
}
return nil
}
func (v *validateUtil) checkSparseFloatFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
func (v *validateUtil) checkSparseFloatVectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
if field.GetVectors() == nil || field.GetVectors().GetSparseFloatVector() == nil {
msg := fmt.Sprintf("sparse float field '%v' is illegal, nil SparseFloatVector", field.GetFieldName())
return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg)
if !fieldSchema.GetNullable() {
msg := fmt.Sprintf("sparse float vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need sparse float vector", "got nil", msg)
}
return nil
}
sparseRows := field.GetVectors().GetSparseFloatVector().GetContents()
if sparseRows == nil {
msg := fmt.Sprintf("sparse float field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need sparse float array", "got nil", msg)
}
return typeutil.ValidateSparseFloatRows(sparseRows...)
}
// checkInt8VectorFieldData validates an int8-vector field payload.
//
// A nil payload is only legal when the schema marks the field as nullable;
// otherwise it is rejected. No element-level validation is performed on a
// non-nil payload here: every byte is a valid int8 value, and dim/row
// alignment is verified separately by checkAligned.
func (v *validateUtil) checkInt8VectorFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error {
	int8VecArray := field.GetVectors().GetInt8Vector()
	if int8VecArray == nil {
		if !fieldSchema.GetNullable() {
			msg := fmt.Sprintf("int8 vector field '%v' is illegal, array type mismatch", field.GetFieldName())
			return merr.WrapErrParameterInvalid("need int8 vector", "got nil", msg)
		}
		return nil
	}
	return nil
}

View File

@ -310,23 +310,77 @@ func Test_validateUtil_checkTextFieldData(t *testing.T) {
}
func Test_validateUtil_checkBinaryVectorFieldData(t *testing.T) {
v := newValidateUtil()
assert.Error(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil))
assert.NoError(t, v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 128,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: []byte(strings.Repeat("1", 128)),
t.Run("not binary vector", func(t *testing.T) {
v := newValidateUtil()
err := v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
assert.Error(t, err)
})
t.Run("normal case", func(t *testing.T) {
v := newValidateUtil()
err := v.checkBinaryVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Dim: 128,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: []byte(strings.Repeat("1", 128)),
},
},
},
}}, nil))
}}, nil)
assert.NoError(t, err)
})
t.Run("nil vector not nullable", func(t *testing.T) {
data := &schemapb.FieldData{
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: nil,
},
},
},
}
schema := &schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_BinaryVector,
Nullable: false,
}
v := newValidateUtil()
err := v.checkBinaryVectorFieldData(data, schema)
assert.Error(t, err)
})
t.Run("nil vector nullable", func(t *testing.T) {
data := &schemapb.FieldData{
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: nil,
},
},
},
}
schema := &schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_BinaryVector,
Nullable: true,
}
v := newValidateUtil()
err := v.checkBinaryVectorFieldData(data, schema)
assert.NoError(t, err)
})
}
func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
nb := 5
dim := int64(8)
data := testutils.GenerateFloatVectors(nb, int(dim))
invalidData := testutils.GenerateFloatVectorsWithInvalidData(nb, int(dim))
t.Run("not float vector", func(t *testing.T) {
f := &schemapb.FieldData{}
v := newValidateUtil()
err := v.checkFloatVectorFieldData(f, nil)
err := v.checkFloatVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
assert.Error(t, err)
})
@ -336,7 +390,7 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: []float32{1.1, 2.2},
Data: invalidData,
},
},
},
@ -354,7 +408,7 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: []float32{float32(math.NaN())},
Data: invalidData,
},
},
},
@ -371,7 +425,7 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: []float32{1.1, 2.2},
Data: data,
},
},
},
@ -409,6 +463,49 @@ func Test_validateUtil_checkFloatVectorFieldData(t *testing.T) {
err = v.fillWithValue(data, h, 1)
assert.Error(t, err)
})
t.Run("nil vector not nullable", func(t *testing.T) {
data := &schemapb.FieldData{
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: nil,
},
},
},
}
schema := &schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_FloatVector,
Nullable: false,
}
v := newValidateUtil()
err := v.checkFloatVectorFieldData(data, schema)
assert.Error(t, err)
})
t.Run("nil vector nullable", func(t *testing.T) {
data := &schemapb.FieldData{
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_FloatVector{
FloatVector: nil,
},
},
},
}
schema := &schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_FloatVector,
Nullable: true,
}
v := newValidateUtil()
err := v.checkFloatVectorFieldData(data, schema)
assert.NoError(t, err)
})
}
func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) {
@ -418,9 +515,8 @@ func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) {
invalidData := testutils.GenerateFloat16VectorsWithInvalidData(nb, int(dim))
t.Run("not float16 vector", func(t *testing.T) {
f := &schemapb.FieldData{}
v := newValidateUtil()
err := v.checkFloat16VectorFieldData(f, nil)
err := v.checkFloat16VectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
assert.Error(t, err)
})
@ -500,17 +596,60 @@ func Test_validateUtil_checkFloat16VectorFieldData(t *testing.T) {
err = v.fillWithValue(data, h, 1)
assert.Error(t, err)
})
t.Run("nil vector not nullable", func(t *testing.T) {
data := &schemapb.FieldData{
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_Float16Vector{
Float16Vector: nil,
},
},
},
}
schema := &schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_Float16Vector,
Nullable: false,
}
v := newValidateUtil()
err := v.checkFloat16VectorFieldData(data, schema)
assert.Error(t, err)
})
t.Run("nil vector nullable", func(t *testing.T) {
data := &schemapb.FieldData{
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_Float16Vector{
Float16Vector: nil,
},
},
},
}
schema := &schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_Float16Vector,
Nullable: true,
}
v := newValidateUtil()
err := v.checkFloat16VectorFieldData(data, schema)
assert.NoError(t, err)
})
}
func Test_validateUtil_checkBfloatVectorFieldData(t *testing.T) {
func Test_validateUtil_checkBFloat16VectorFieldData(t *testing.T) {
nb := 5
dim := int64(8)
data := testutils.GenerateFloat16Vectors(nb, int(dim))
data := testutils.GenerateBFloat16Vectors(nb, int(dim))
invalidData := testutils.GenerateBFloat16VectorsWithInvalidData(nb, int(dim))
t.Run("not float vector", func(t *testing.T) {
f := &schemapb.FieldData{}
t.Run("not bfloat16 vector", func(t *testing.T) {
v := newValidateUtil()
err := v.checkBFloat16VectorFieldData(f, nil)
err := v.checkBFloat16VectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
assert.Error(t, err)
})
@ -590,6 +729,203 @@ func Test_validateUtil_checkBfloatVectorFieldData(t *testing.T) {
err = v.fillWithValue(data, h, 1)
assert.Error(t, err)
})
t.Run("nil vector not nullable", func(t *testing.T) {
data := &schemapb.FieldData{
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: nil,
},
},
},
}
schema := &schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_BFloat16Vector,
Nullable: false,
}
v := newValidateUtil()
err := v.checkBFloat16VectorFieldData(data, schema)
assert.Error(t, err)
})
t.Run("nil vector nullable", func(t *testing.T) {
data := &schemapb.FieldData{
FieldName: "vec",
Field: &schemapb.FieldData_Vectors{
Vectors: &schemapb.VectorField{
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: nil,
},
},
},
}
schema := &schemapb.FieldSchema{
Name: "vec",
DataType: schemapb.DataType_BFloat16Vector,
Nullable: true,
}
v := newValidateUtil()
err := v.checkBFloat16VectorFieldData(data, schema)
assert.NoError(t, err)
})
}
// Test_validateUtil_checkSparseFloatVectorFieldData covers sparse float vector
// validation: a field of the wrong kind, a well-formed payload, and nil
// payloads under both non-nullable and nullable schemas.
func Test_validateUtil_checkSparseFloatVectorFieldData(t *testing.T) {
	rowCount := 5
	sparseContents, dim := testutils.GenerateSparseFloatVectorsData(rowCount)

	// newNilSparseField builds a FieldData whose sparse vector payload is nil.
	newNilSparseField := func() *schemapb.FieldData {
		return &schemapb.FieldData{
			FieldName: "vec",
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					Data: &schemapb.VectorField_SparseFloatVector{
						SparseFloatVector: nil,
					},
				},
			},
		}
	}
	// newSparseSchema builds a sparse-float-vector field schema with the given nullability.
	newSparseSchema := func(nullable bool) *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			Name:     "vec",
			DataType: schemapb.DataType_SparseFloatVector,
			Nullable: nullable,
		}
	}

	t.Run("not sparse float vector", func(t *testing.T) {
		checker := newValidateUtil()
		err := checker.checkSparseFloatVectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
		assert.Error(t, err)
	})

	t.Run("normal case", func(t *testing.T) {
		payload := &schemapb.FieldData{
			FieldName: "vec",
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					Dim: dim,
					Data: &schemapb.VectorField_SparseFloatVector{
						SparseFloatVector: &schemapb.SparseFloatArray{
							Contents: sparseContents,
							Dim:      dim,
						},
					},
				},
			},
		}
		checker := newValidateUtil()
		err := checker.checkSparseFloatVectorFieldData(payload, &schemapb.FieldSchema{
			Name:     "vec",
			DataType: schemapb.DataType_SparseFloatVector,
		})
		assert.NoError(t, err)
	})

	t.Run("nil vector not nullable", func(t *testing.T) {
		checker := newValidateUtil()
		err := checker.checkSparseFloatVectorFieldData(newNilSparseField(), newSparseSchema(false))
		assert.Error(t, err)
	})

	t.Run("nil vector nullable", func(t *testing.T) {
		checker := newValidateUtil()
		err := checker.checkSparseFloatVectorFieldData(newNilSparseField(), newSparseSchema(true))
		assert.NoError(t, err)
	})
}
// Test_validateUtil_checkInt8VectorFieldData covers int8 vector validation:
// a field of the wrong kind, a well-formed payload, and nil payloads under
// both non-nullable and nullable schemas.
func Test_validateUtil_checkInt8VectorFieldData(t *testing.T) {
	rowCount := 5
	dim := int64(8)
	vecBytes := typeutil.Int8ArrayToBytes(testutils.GenerateInt8Vectors(rowCount, int(dim)))

	// newNilInt8Field builds a FieldData whose int8 vector payload is nil.
	newNilInt8Field := func() *schemapb.FieldData {
		return &schemapb.FieldData{
			FieldName: "vec",
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					Data: &schemapb.VectorField_Int8Vector{
						Int8Vector: nil,
					},
				},
			},
		}
	}
	// newInt8Schema builds an int8-vector field schema with the given nullability.
	newInt8Schema := func(nullable bool) *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			Name:     "vec",
			DataType: schemapb.DataType_Int8Vector,
			Nullable: nullable,
		}
	}

	t.Run("not int8 vector", func(t *testing.T) {
		checker := newValidateUtil()
		err := checker.checkInt8VectorFieldData(&schemapb.FieldData{Field: &schemapb.FieldData_Scalars{}}, nil)
		assert.Error(t, err)
	})

	t.Run("normal case", func(t *testing.T) {
		payload := &schemapb.FieldData{
			FieldName: "vec",
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					// NOTE(review): Dim is 128 here while the generated data uses
					// dim=8 and the check still passes — confirm the checker is not
					// expected to cross-validate Dim against the payload length.
					Dim: 128,
					Data: &schemapb.VectorField_Int8Vector{
						Int8Vector: vecBytes,
					},
				},
			},
		}
		checker := newValidateUtil()
		err := checker.checkInt8VectorFieldData(payload, &schemapb.FieldSchema{
			Name:     "vec",
			DataType: schemapb.DataType_Int8Vector,
		})
		assert.NoError(t, err)
	})

	t.Run("nil vector not nullable", func(t *testing.T) {
		checker := newValidateUtil()
		err := checker.checkInt8VectorFieldData(newNilInt8Field(), newInt8Schema(false))
		assert.Error(t, err)
	})

	t.Run("nil vector nullable", func(t *testing.T) {
		checker := newValidateUtil()
		err := checker.checkInt8VectorFieldData(newNilInt8Field(), newInt8Schema(true))
		assert.NoError(t, err)
	})
}
func Test_validateUtil_checkAligned(t *testing.T) {

View File

@ -325,6 +325,11 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
idTsMap := make(map[interface{}]int64)
cursors := make([]int64, len(validRetrieveResults))
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(validRetrieveResults))
for i, vr := range validRetrieveResults {
idxComputers[i] = typeutil.NewFieldDataIdxComputer(vr.Result.GetFieldsData())
}
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
for j := 0; j < loopEnd; {
@ -335,9 +340,11 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
pk := typeutil.GetPK(validRetrieveResults[sel].GetIds(), cursors[sel])
ts := validRetrieveResults[sel].Timestamps[cursors[sel]]
fieldsData := validRetrieveResults[sel].Result.GetFieldsData()
fieldIdxs := idxComputers[sel].Compute(cursors[sel])
if _, ok := idTsMap[pk]; !ok {
typeutil.AppendPKs(ret.Ids, pk)
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel])
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, cursors[sel], fieldIdxs...)
idTsMap[pk] = ts
j++
} else {
@ -346,7 +353,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna
if ts != 0 && ts > idTsMap[pk] {
idTsMap[pk] = ts
typeutil.DeleteFieldData(ret.FieldsData)
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].Result.GetFieldsData(), cursors[sel])
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, cursors[sel], fieldIdxs...)
}
}
@ -511,10 +518,17 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
_, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData")
defer span2.End()
ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].Result.GetFieldsData(), int64(len(selections)))
// cursors = make([]int64, len(validRetrieveResults))
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(validRetrieveResults))
for i, vr := range validRetrieveResults {
idxComputers[i] = typeutil.NewFieldDataIdxComputer(vr.Result.GetFieldsData())
}
for _, selection := range selections {
// cannot use `cursors[sel]` directly, since some of them may be skipped.
retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[selection.batchIndex].Result.GetFieldsData(), selection.resultIndex)
fieldsData := validRetrieveResults[selection.batchIndex].Result.GetFieldsData()
fieldIdxs := idxComputers[selection.batchIndex].Compute(selection.resultIndex)
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, selection.resultIndex, fieldIdxs...)
// limit retrieve result to avoid oom
if retSize > maxOutputSize {
@ -564,10 +578,18 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
_, span3 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData")
defer span3.End()
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(segmentResults))
for i, r := range segmentResults {
idxComputers[i] = typeutil.NewFieldDataIdxComputer(r.GetFieldsData())
}
// retrieve result is compacted, use 0,1,2...end
segmentResOffset := make([]int64, len(segmentResults))
for _, selection := range selections {
retSize += typeutil.AppendFieldData(ret.FieldsData, segmentResults[selection.batchIndex].GetFieldsData(), segmentResOffset[selection.batchIndex])
fieldsData := segmentResults[selection.batchIndex].GetFieldsData()
fieldIdxs := idxComputers[selection.batchIndex].Compute(segmentResOffset[selection.batchIndex])
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, segmentResOffset[selection.batchIndex], fieldIdxs...)
segmentResOffset[selection.batchIndex]++
// limit retrieve result to avoid oom
if retSize > maxOutputSize {

View File

@ -68,6 +68,11 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
ret.AllSearchCount += searchResultData[i].GetAllSearchCount()
}
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(searchResultData))
for i, srd := range searchResultData {
idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
}
var skipDupCnt int64
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
@ -87,7 +92,9 @@ func (scr *SearchCommonReduce) ReduceSearchResultData(ctx context.Context, searc
// remove duplicates
if _, ok := idSet[id]; !ok {
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
fieldsData := searchResultData[sel].FieldsData
fieldIdxs := idxComputers[sel].Compute(idx)
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, idx, fieldIdxs...)
typeutil.AppendPKs(ret.Ids, id)
ret.Scores = append(ret.Scores, score)
if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil {
@ -173,6 +180,11 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
return ret, merr.WrapErrServiceInternal("failed to construct group by field data builder, this is abnormal as segcore should always set up a group by field, no matter data status, check code on qn", err.Error())
}
idxComputers := make([]*typeutil.FieldDataIdxComputer, len(searchResultData))
for i, srd := range searchResultData {
idxComputers[i] = typeutil.NewFieldDataIdxComputer(srd.FieldsData)
}
var filteredCount int64
var retSize int64
maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
@ -208,7 +220,9 @@ func (sbr *SearchGroupByReduce) ReduceSearchResultData(ctx context.Context, sear
// exceed the limit for each group, filter this entity
filteredCount++
} else {
retSize += typeutil.AppendFieldData(ret.FieldsData, searchResultData[sel].FieldsData, idx)
fieldsData := searchResultData[sel].FieldsData
fieldIdxs := idxComputers[sel].Compute(idx)
retSize += typeutil.AppendFieldData(ret.FieldsData, fieldsData, idx, fieldIdxs...)
typeutil.AppendPKs(ret.Ids, id)
ret.Scores = append(ret.Scores, score)
if searchResultData[sel].ElementIndices != nil && ret.ElementIndices != nil {

View File

@ -667,14 +667,14 @@ func Test_createCollectionTask_validateSchema(t *testing.T) {
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_FloatVector,
Nullable: true,
TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}},
},
},
},
},
}
err := task.validateSchema(context.TODO(), schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "vector type not support null")
assert.NoError(t, err)
})
t.Run("struct array field - field with default value", func(t *testing.T) {
@ -980,7 +980,7 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) {
assert.Error(t, err)
})
t.Run("vector type not support null", func(t *testing.T) {
t.Run("vector type with nullable", func(t *testing.T) {
collectionName := funcutil.GenRandomStr()
field1 := funcutil.GenRandomStr()
schema := &schemapb.CollectionSchema{
@ -989,9 +989,17 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) {
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: field1,
DataType: 101,
Nullable: true,
FieldID: 100,
Name: "pk",
DataType: schemapb.DataType_Int64,
IsPrimaryKey: true,
},
{
FieldID: 101,
Name: field1,
DataType: schemapb.DataType_FloatVector,
Nullable: true,
TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}},
},
},
}
@ -1005,7 +1013,7 @@ func Test_createCollectionTask_prepareSchema(t *testing.T) {
},
}
err := task.prepareSchema(context.TODO())
assert.Error(t, err)
assert.NoError(t, err)
})
}

View File

@ -374,10 +374,6 @@ func checkFieldSchema(fieldSchemas []*schemapb.FieldSchema) error {
msg := fmt.Sprintf("ArrayOfVector is only supported in struct array field, type:%s, name:%s", fieldSchema.GetDataType().String(), fieldSchema.GetName())
return merr.WrapErrParameterInvalidMsg(msg)
}
if fieldSchema.GetNullable() && typeutil.IsVectorType(fieldSchema.GetDataType()) {
msg := fmt.Sprintf("vector type not support null, type:%s, name:%s", fieldSchema.GetDataType().String(), fieldSchema.GetName())
return merr.WrapErrParameterInvalidMsg(msg)
}
if fieldSchema.GetNullable() && fieldSchema.IsPrimaryKey {
msg := fmt.Sprintf("primary field not support null, type:%s, name:%s", fieldSchema.GetDataType().String(), fieldSchema.GetName())
return merr.WrapErrParameterInvalidMsg(msg)
@ -502,11 +498,6 @@ func checkStructArrayFieldSchema(schemas []*schemapb.StructArrayFieldSchema) err
field.DataType.String(), field.ElementType.String(), field.Name)
return merr.WrapErrParameterInvalidMsg(msg)
}
if field.GetNullable() && typeutil.IsVectorType(field.ElementType) {
msg := fmt.Sprintf("vector type not support null, data type:%s, element type:%s, name:%s",
field.DataType.String(), field.ElementType.String(), field.Name)
return merr.WrapErrParameterInvalidMsg(msg)
}
if field.GetDefaultValue() != nil {
msg := fmt.Sprintf("fields in struct array field not support default_value, data type:%s, element type:%s, name:%s",
field.DataType.String(), field.ElementType.String(), field.Name)

View File

@ -391,8 +391,12 @@ func NewRecordBuilder(schema *schemapb.CollectionSchema) *RecordBuilder {
if field.DataType == schemapb.DataType_ArrayOfVector {
elementType = field.GetElementType()
}
arrowType := serdeMap[field.DataType].arrowType(int(dim), elementType)
builders[i] = array.NewBuilder(memory.DefaultAllocator, arrowType)
if field.GetNullable() && typeutil.IsVectorType(field.DataType) && !typeutil.IsSparseFloatVectorType(field.DataType) {
builders[i] = array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary)
} else {
arrowType := serdeMap[field.DataType].arrowType(int(dim), elementType)
builders[i] = array.NewBuilder(memory.DefaultAllocator, arrowType)
}
}
return &RecordBuilder{

View File

@ -448,19 +448,19 @@ func AddFieldDataToPayload(eventWriter *insertEventWriter, dataType schemapb.Dat
}
}
case schemapb.DataType_BinaryVector:
if err = eventWriter.AddBinaryVectorToPayload(singleData.(*BinaryVectorFieldData).Data, singleData.(*BinaryVectorFieldData).Dim); err != nil {
if err = eventWriter.AddBinaryVectorToPayload(singleData.(*BinaryVectorFieldData).Data, singleData.(*BinaryVectorFieldData).Dim, singleData.(*BinaryVectorFieldData).ValidData); err != nil {
return err
}
case schemapb.DataType_FloatVector:
if err = eventWriter.AddFloatVectorToPayload(singleData.(*FloatVectorFieldData).Data, singleData.(*FloatVectorFieldData).Dim); err != nil {
if err = eventWriter.AddFloatVectorToPayload(singleData.(*FloatVectorFieldData).Data, singleData.(*FloatVectorFieldData).Dim, singleData.(*FloatVectorFieldData).ValidData); err != nil {
return err
}
case schemapb.DataType_Float16Vector:
if err = eventWriter.AddFloat16VectorToPayload(singleData.(*Float16VectorFieldData).Data, singleData.(*Float16VectorFieldData).Dim); err != nil {
if err = eventWriter.AddFloat16VectorToPayload(singleData.(*Float16VectorFieldData).Data, singleData.(*Float16VectorFieldData).Dim, singleData.(*Float16VectorFieldData).ValidData); err != nil {
return err
}
case schemapb.DataType_BFloat16Vector:
if err = eventWriter.AddBFloat16VectorToPayload(singleData.(*BFloat16VectorFieldData).Data, singleData.(*BFloat16VectorFieldData).Dim); err != nil {
if err = eventWriter.AddBFloat16VectorToPayload(singleData.(*BFloat16VectorFieldData).Data, singleData.(*BFloat16VectorFieldData).Dim, singleData.(*BFloat16VectorFieldData).ValidData); err != nil {
return err
}
case schemapb.DataType_SparseFloatVector:
@ -468,7 +468,7 @@ func AddFieldDataToPayload(eventWriter *insertEventWriter, dataType schemapb.Dat
return err
}
case schemapb.DataType_Int8Vector:
if err = eventWriter.AddInt8VectorToPayload(singleData.(*Int8VectorFieldData).Data, singleData.(*Int8VectorFieldData).Dim); err != nil {
if err = eventWriter.AddInt8VectorToPayload(singleData.(*Int8VectorFieldData).Data, singleData.(*Int8VectorFieldData).Dim, singleData.(*Int8VectorFieldData).ValidData); err != nil {
return err
}
case schemapb.DataType_ArrayOfVector:
@ -747,6 +747,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
return length, err
}
binaryVectorFieldData.Dim = dim
if validData != nil && len(validData) > 0 {
startLogical := len(binaryVectorFieldData.ValidData)
if binaryVectorFieldData.ValidData == nil {
binaryVectorFieldData.ValidData = make([]bool, 0, rowNum)
}
binaryVectorFieldData.ValidData = append(binaryVectorFieldData.ValidData, validData...)
binaryVectorFieldData.Nullable = true
binaryVectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
}
insertData.Data[fieldID] = binaryVectorFieldData
return length, nil
@ -763,6 +772,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
return length, err
}
float16VectorFieldData.Dim = dim
if validData != nil && len(validData) > 0 {
startLogical := len(float16VectorFieldData.ValidData)
if float16VectorFieldData.ValidData == nil {
float16VectorFieldData.ValidData = make([]bool, 0, rowNum)
}
float16VectorFieldData.ValidData = append(float16VectorFieldData.ValidData, validData...)
float16VectorFieldData.Nullable = true
float16VectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
}
insertData.Data[fieldID] = float16VectorFieldData
return length, nil
@ -779,6 +797,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
return length, err
}
bfloat16VectorFieldData.Dim = dim
if validData != nil && len(validData) > 0 {
startLogical := len(bfloat16VectorFieldData.ValidData)
if bfloat16VectorFieldData.ValidData == nil {
bfloat16VectorFieldData.ValidData = make([]bool, 0, rowNum)
}
bfloat16VectorFieldData.ValidData = append(bfloat16VectorFieldData.ValidData, validData...)
bfloat16VectorFieldData.Nullable = true
bfloat16VectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
}
insertData.Data[fieldID] = bfloat16VectorFieldData
return length, nil
@ -795,6 +822,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
return 0, err
}
floatVectorFieldData.Dim = dim
if validData != nil && len(validData) > 0 {
startLogical := len(floatVectorFieldData.ValidData)
if floatVectorFieldData.ValidData == nil {
floatVectorFieldData.ValidData = make([]bool, 0, rowNum)
}
floatVectorFieldData.ValidData = append(floatVectorFieldData.ValidData, validData...)
floatVectorFieldData.Nullable = true
floatVectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
}
insertData.Data[fieldID] = floatVectorFieldData
return length, nil
@ -805,6 +841,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
}
vec := fieldData.(*SparseFloatVectorFieldData)
vec.AppendAllRows(singleData)
if validData != nil && len(validData) > 0 {
startLogical := len(vec.ValidData)
if vec.ValidData == nil {
vec.ValidData = make([]bool, 0, rowNum)
}
vec.ValidData = append(vec.ValidData, validData...)
vec.Nullable = true
vec.L2PMapping.Build(validData, startLogical, len(validData))
}
insertData.Data[fieldID] = vec
return singleData.RowNum(), nil
@ -821,6 +866,15 @@ func AddInsertData(dataType schemapb.DataType, data interface{}, insertData *Ins
return 0, err
}
int8VectorFieldData.Dim = dim
if validData != nil && len(validData) > 0 {
startLogical := len(int8VectorFieldData.ValidData)
if int8VectorFieldData.ValidData == nil {
int8VectorFieldData.ValidData = make([]bool, 0, rowNum)
}
int8VectorFieldData.ValidData = append(int8VectorFieldData.ValidData, validData...)
int8VectorFieldData.Nullable = true
int8VectorFieldData.L2PMapping.Build(validData, startLogical, len(validData))
}
insertData.Data[fieldID] = int8VectorFieldData
return length, nil

View File

@ -35,30 +35,36 @@ import (
)
// Field IDs used throughout the storage codec tests. The system fields
// (row ID, timestamp) occupy the reserved IDs 0 and 1; user-defined fields
// start at 100. The diff rendering duplicated every constant (old plus new
// declaration set inside one const block), which cannot compile in Go; this
// keeps the single, gofmt-aligned set including the nullable vector IDs.
const (
	CollectionID = 1
	PartitionID  = 1
	SegmentID    = 1

	// System field IDs.
	RowIDField     = 0
	TimestampField = 1

	// User field IDs, one per data type exercised by the tests.
	BoolField                 = 100
	Int8Field                 = 101
	Int16Field                = 102
	Int32Field                = 103
	Int64Field                = 104
	FloatField                = 105
	DoubleField               = 106
	StringField               = 107
	BinaryVectorField         = 108
	FloatVectorField          = 109
	ArrayField                = 110
	JSONField                 = 111
	Float16VectorField        = 112
	BFloat16VectorField       = 113
	SparseFloatVectorField    = 114
	Int8VectorField           = 115
	StructField               = 116
	StructSubInt32Field       = 117
	StructSubFloatVectorField = 118

	// Nullable vector variants added for the nullable-vector storage tests.
	NullableFloatVectorField       = 119
	NullableBinaryVectorField      = 120
	NullableFloat16VectorField     = 121
	NullableBFloat16VectorField    = 122
	NullableInt8VectorField        = 123
	NullableSparseFloatVectorField = 124
)
func assertTestData(t *testing.T, i int, value *Value) {
@ -284,26 +290,35 @@ func generateTestDataWithSeed(seed, num int) ([]*Blob, error) {
19: &JSONFieldData{Data: field19},
101: &Int32FieldData{Data: field101},
102: &FloatVectorFieldData{
Data: field102,
Dim: 8,
Data: field102,
ValidData: nil,
Dim: 8,
Nullable: false,
},
103: &BinaryVectorFieldData{
Data: field103,
Dim: 8,
Data: field103,
ValidData: nil,
Dim: 8,
Nullable: false,
},
104: &Float16VectorFieldData{
Data: field104,
Dim: 8,
Data: field104,
ValidData: nil,
Dim: 8,
Nullable: false,
},
105: &BFloat16VectorFieldData{
Data: field105,
Dim: 8,
Data: field105,
ValidData: nil,
Dim: 8,
Nullable: false,
},
106: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 28433,
Contents: field106,
},
Nullable: false,
},
}}
@ -616,6 +631,79 @@ func genTestCollectionMeta() *etcdpb.CollectionMeta {
},
},
},
{
FieldID: NullableFloatVectorField,
Name: "field_nullable_float_vector",
Description: "nullable_float_vector",
DataType: schemapb.DataType_FloatVector,
Nullable: true,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "4",
},
},
},
{
FieldID: NullableBinaryVectorField,
Name: "field_nullable_binary_vector",
Description: "nullable_binary_vector",
DataType: schemapb.DataType_BinaryVector,
Nullable: true,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "8",
},
},
},
{
FieldID: NullableFloat16VectorField,
Name: "field_nullable_float16_vector",
Description: "nullable_float16_vector",
DataType: schemapb.DataType_Float16Vector,
Nullable: true,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "4",
},
},
},
{
FieldID: NullableBFloat16VectorField,
Name: "field_nullable_bfloat16_vector",
Description: "nullable_bfloat16_vector",
DataType: schemapb.DataType_BFloat16Vector,
Nullable: true,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "4",
},
},
},
{
FieldID: NullableInt8VectorField,
Name: "field_nullable_int8_vector",
Description: "nullable_int8_vector",
DataType: schemapb.DataType_Int8Vector,
Nullable: true,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "4",
},
},
},
{
FieldID: NullableSparseFloatVectorField,
Name: "field_nullable_sparse_float_vector",
Description: "nullable_sparse_float_vector",
DataType: schemapb.DataType_SparseFloatVector,
Nullable: true,
TypeParams: []*commonpb.KeyValuePair{},
},
},
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
@ -649,63 +737,6 @@ func genTestCollectionMeta() *etcdpb.CollectionMeta {
}
}
// TestInsertCodecFailed exercises Serialize with minimal collection metas
// whose last field carries only a vector data type (no other attributes set);
// every case expects Serialize to return an error.
func TestInsertCodecFailed(t *testing.T) {
	t.Run("vector field not support null", func(t *testing.T) {
		// collectionMetaWithVector builds a minimal collection meta whose last
		// field has the given vector data type and nothing else populated.
		collectionMetaWithVector := func(dt schemapb.DataType) *etcdpb.CollectionMeta {
			return &etcdpb.CollectionMeta{
				ID:            CollectionID,
				CreateTime:    1,
				SegmentIDs:    []int64{SegmentID},
				PartitionTags: []string{"partition_0", "partition_1"},
				Schema: &schemapb.CollectionSchema{
					Name:        "schema",
					Description: "schema",
					Fields: []*schemapb.FieldSchema{
						{
							FieldID:     RowIDField,
							Name:        "row_id",
							Description: "row_id",
							DataType:    schemapb.DataType_Int64,
						},
						{
							FieldID:     TimestampField,
							Name:        "Timestamp",
							Description: "Timestamp",
							DataType:    schemapb.DataType_Int64,
						},
						{
							DataType: dt,
						},
					},
				},
			}
		}

		cases := []struct {
			description string
			dataType    schemapb.DataType
		}{
			{"nullable FloatVector field", schemapb.DataType_FloatVector},
			{"nullable Float16Vector field", schemapb.DataType_Float16Vector},
			{"nullable BinaryVector field", schemapb.DataType_BinaryVector},
			{"nullable BFloat16Vector field", schemapb.DataType_BFloat16Vector},
			{"nullable SparseFloatVector field", schemapb.DataType_SparseFloatVector},
			{"nullable Int8Vector field", schemapb.DataType_Int8Vector},
		}
		for _, tc := range cases {
			t.Run(tc.description, func(t *testing.T) {
				codec := NewInsertCodecWithSchema(collectionMetaWithVector(tc.dataType))
				// Only the system fields carry (empty) data; the vector field has none.
				emptyData := &InsertData{
					Data: map[int64]FieldData{
						RowIDField:     &Int64FieldData{[]int64{}, nil, false},
						TimestampField: &Int64FieldData{[]int64{}, nil, false},
					},
				}
				_, err := codec.Serialize(PartitionID, SegmentID, emptyData)
				assert.Error(t, err)
			})
		}
	})
}
func TestInsertCodec(t *testing.T) {
schema := genTestCollectionMeta()
insertCodec := NewInsertCodecWithSchema(schema)
@ -742,12 +773,16 @@ func TestInsertCodec(t *testing.T) {
Data: []string{"3", "4"},
},
BinaryVectorField: &BinaryVectorFieldData{
Data: []byte{0, 255},
Dim: 8,
Data: []byte{0, 255},
ValidData: nil,
Dim: 8,
Nullable: false,
},
FloatVectorField: &FloatVectorFieldData{
Data: []float32{4, 5, 6, 7, 4, 5, 6, 7},
Dim: 4,
Data: []float32{4, 5, 6, 7, 4, 5, 6, 7},
ValidData: nil,
Dim: 4,
Nullable: false,
},
ArrayField: &ArrayFieldData{
ElementType: schemapb.DataType_Int32,
@ -772,13 +807,17 @@ func TestInsertCodec(t *testing.T) {
},
Float16VectorField: &Float16VectorFieldData{
// length = 2 * Dim * numRows(2) = 16
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
Dim: 4,
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
ValidData: nil,
Dim: 4,
Nullable: false,
},
BFloat16VectorField: &BFloat16VectorFieldData{
// length = 2 * Dim * numRows(2) = 16
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
Dim: 4,
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
ValidData: nil,
Dim: 4,
Nullable: false,
},
SparseFloatVectorField: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
@ -789,10 +828,13 @@ func TestInsertCodec(t *testing.T) {
typeutil.CreateSparseFloatRow([]uint32{100, 200, 599}, []float32{3.1, 3.2, 3.3}),
},
},
Nullable: false,
},
Int8VectorField: &Int8VectorFieldData{
Data: []int8{-4, -5, -6, -7, -4, -5, -6, -7},
Dim: 4,
Data: []int8{-4, -5, -6, -7, -4, -5, -6, -7},
ValidData: nil,
Dim: 4,
Nullable: false,
},
StructSubInt32Field: &ArrayFieldData{
ElementType: schemapb.DataType_Int32,
@ -827,6 +869,36 @@ func TestInsertCodec(t *testing.T) {
},
},
},
NullableFloatVectorField: &FloatVectorFieldData{
Data: []float32{4.0, 5.0, 6.0, 7.0},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableBinaryVectorField: &BinaryVectorFieldData{
Data: []byte{255},
ValidData: []bool{true, false},
Dim: 8,
Nullable: true,
},
NullableFloat16VectorField: &Float16VectorFieldData{
Data: []byte{255, 0, 255, 0, 255, 0, 255, 0},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableBFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{255, 0, 255, 0, 255, 0, 255, 0},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableInt8VectorField: &Int8VectorFieldData{
Data: []int8{-4, -5, -6, -7},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
},
}
@ -863,22 +935,30 @@ func TestInsertCodec(t *testing.T) {
Data: []string{"1", "2"},
},
BinaryVectorField: &BinaryVectorFieldData{
Data: []byte{0, 255},
Dim: 8,
Data: []byte{0, 255},
ValidData: nil,
Dim: 8,
Nullable: false,
},
FloatVectorField: &FloatVectorFieldData{
Data: []float32{0, 1, 2, 3, 0, 1, 2, 3},
Dim: 4,
Data: []float32{0, 1, 2, 3, 0, 1, 2, 3},
ValidData: nil,
Dim: 4,
Nullable: false,
},
Float16VectorField: &Float16VectorFieldData{
// length = 2 * Dim * numRows(2) = 16
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
Dim: 4,
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
ValidData: nil,
Dim: 4,
Nullable: false,
},
BFloat16VectorField: &BFloat16VectorFieldData{
// length = 2 * Dim * numRows(2) = 16
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
Dim: 4,
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255},
ValidData: nil,
Dim: 4,
Nullable: false,
},
SparseFloatVectorField: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
@ -889,10 +969,13 @@ func TestInsertCodec(t *testing.T) {
typeutil.CreateSparseFloatRow([]uint32{105, 207, 299}, []float32{3.1, 3.2, 3.3}),
},
},
Nullable: false,
},
Int8VectorField: &Int8VectorFieldData{
Data: []int8{0, 1, 2, 3, 0, 1, 2, 3},
Dim: 4,
Data: []int8{0, 1, 2, 3, 0, 1, 2, 3},
ValidData: nil,
Dim: 4,
Nullable: false,
},
StructSubInt32Field: &ArrayFieldData{
ElementType: schemapb.DataType_Int32,
@ -927,6 +1010,46 @@ func TestInsertCodec(t *testing.T) {
},
},
},
NullableFloatVectorField: &FloatVectorFieldData{
Data: []float32{0.0, 1.0, 2.0, 3.0},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableBinaryVectorField: &BinaryVectorFieldData{
Data: []byte{0},
ValidData: []bool{true, false},
Dim: 8,
Nullable: true,
},
NullableFloat16VectorField: &Float16VectorFieldData{
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableBFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{0, 255, 0, 255, 0, 255, 0, 255},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableInt8VectorField: &Int8VectorFieldData{
Data: []int8{0, 1, 2, 3},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableSparseFloatVectorField: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 300,
Contents: [][]byte{
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
},
},
ValidData: []bool{true, false},
Nullable: true,
},
ArrayField: &ArrayFieldData{
ElementType: schemapb.DataType_Int32,
Data: []*schemapb.ScalarField{
@ -953,27 +1076,84 @@ func TestInsertCodec(t *testing.T) {
insertDataEmpty := &InsertData{
Data: map[int64]FieldData{
RowIDField: &Int64FieldData{[]int64{}, nil, false},
TimestampField: &Int64FieldData{[]int64{}, nil, false},
BoolField: &BoolFieldData{[]bool{}, nil, false},
Int8Field: &Int8FieldData{[]int8{}, nil, false},
Int16Field: &Int16FieldData{[]int16{}, nil, false},
Int32Field: &Int32FieldData{[]int32{}, nil, false},
Int64Field: &Int64FieldData{[]int64{}, nil, false},
FloatField: &FloatFieldData{[]float32{}, nil, false},
DoubleField: &DoubleFieldData{[]float64{}, nil, false},
StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false},
BinaryVectorField: &BinaryVectorFieldData{[]byte{}, 8},
FloatVectorField: &FloatVectorFieldData{[]float32{}, 4},
Float16VectorField: &Float16VectorFieldData{[]byte{}, 4},
BFloat16VectorField: &BFloat16VectorFieldData{[]byte{}, 4},
RowIDField: &Int64FieldData{[]int64{}, nil, false},
TimestampField: &Int64FieldData{[]int64{}, nil, false},
BoolField: &BoolFieldData{[]bool{}, nil, false},
Int8Field: &Int8FieldData{[]int8{}, nil, false},
Int16Field: &Int16FieldData{[]int16{}, nil, false},
Int32Field: &Int32FieldData{[]int32{}, nil, false},
Int64Field: &Int64FieldData{[]int64{}, nil, false},
FloatField: &FloatFieldData{[]float32{}, nil, false},
DoubleField: &DoubleFieldData{[]float64{}, nil, false},
StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false},
BinaryVectorField: &BinaryVectorFieldData{
Data: []byte{},
ValidData: nil,
Dim: 8,
Nullable: false,
},
FloatVectorField: &FloatVectorFieldData{
Data: []float32{},
ValidData: nil,
Dim: 4,
Nullable: false,
},
Float16VectorField: &Float16VectorFieldData{
Data: []byte{},
ValidData: nil,
Dim: 4,
Nullable: false,
},
BFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{},
ValidData: nil,
Dim: 4,
Nullable: false,
},
SparseFloatVectorField: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 0,
Contents: [][]byte{},
},
ValidData: nil,
Nullable: false,
},
Int8VectorField: &Int8VectorFieldData{
Data: []int8{},
ValidData: nil,
Dim: 4,
Nullable: false,
},
NullableFloatVectorField: &FloatVectorFieldData{
Data: []float32{},
ValidData: []bool{},
Dim: 4,
Nullable: true,
},
NullableBinaryVectorField: &BinaryVectorFieldData{
Data: []byte{},
ValidData: []bool{},
Dim: 8,
Nullable: true,
},
NullableFloat16VectorField: &Float16VectorFieldData{
Data: []byte{},
ValidData: []bool{},
Dim: 4,
Nullable: true,
},
NullableBFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{},
ValidData: []bool{},
Dim: 4,
Nullable: true,
},
NullableInt8VectorField: &Int8VectorFieldData{
Data: []int8{},
ValidData: []bool{},
Dim: 4,
Nullable: true,
},
Int8VectorField: &Int8VectorFieldData{[]int8{}, 4},
StructSubInt32Field: &ArrayFieldData{schemapb.DataType_Int32, []*schemapb.ScalarField{}, nil, false},
ArrayField: &ArrayFieldData{schemapb.DataType_Int32, []*schemapb.ScalarField{}, nil, false},
JSONField: &JSONFieldData{[][]byte{}, nil, false},
@ -1321,24 +1501,74 @@ func TestMemorySize(t *testing.T) {
Data: []string{"3"},
},
BinaryVectorField: &BinaryVectorFieldData{
Data: []byte{0},
Dim: 8,
Data: []byte{0},
ValidData: nil,
Dim: 8,
Nullable: false,
},
FloatVectorField: &FloatVectorFieldData{
Data: []float32{4, 5, 6, 7},
Dim: 4,
Data: []float32{4, 5, 6, 7},
ValidData: nil,
Dim: 4,
Nullable: false,
},
Float16VectorField: &Float16VectorFieldData{
Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
Dim: 4,
Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
ValidData: nil,
Dim: 4,
Nullable: false,
},
BFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
Dim: 4,
Data: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
ValidData: nil,
Dim: 4,
Nullable: false,
},
Int8VectorField: &Int8VectorFieldData{
Data: []int8{4, 5, 6, 7},
Dim: 4,
Data: []int8{4, 5, 6, 7},
ValidData: nil,
Dim: 4,
Nullable: false,
},
NullableFloatVectorField: &FloatVectorFieldData{
Data: []float32{4.0, 5.0, 6.0, 7.0},
ValidData: []bool{true},
Dim: 4,
Nullable: true,
},
NullableBinaryVectorField: &BinaryVectorFieldData{
Data: []byte{255},
ValidData: []bool{true},
Dim: 8,
Nullable: true,
},
NullableFloat16VectorField: &Float16VectorFieldData{
Data: []byte{0xff, 0x0, 0xff, 0x0, 0xff, 0x0, 0xff, 0x0},
ValidData: []bool{true},
Dim: 4,
Nullable: true,
},
NullableBFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{0xff, 0x0, 0xff, 0x0, 0xff, 0x0, 0xff, 0x0},
ValidData: []bool{true},
Dim: 4,
Nullable: true,
},
NullableInt8VectorField: &Int8VectorFieldData{
Data: []int8{4, 5, 6, 7},
ValidData: []bool{true},
Dim: 4,
Nullable: true,
},
NullableSparseFloatVectorField: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 300,
Contents: [][]byte{
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
},
},
ValidData: []bool{true},
Nullable: true,
},
ArrayField: &ArrayFieldData{
ElementType: schemapb.DataType_Int32,
@ -1389,15 +1619,21 @@ func TestMemorySize(t *testing.T) {
assert.Equal(t, insertData1.Data[FloatField].GetMemorySize(), 5)
assert.Equal(t, insertData1.Data[DoubleField].GetMemorySize(), 9)
assert.Equal(t, insertData1.Data[StringField].GetMemorySize(), 18)
assert.Equal(t, insertData1.Data[BinaryVectorField].GetMemorySize(), 5)
assert.Equal(t, insertData1.Data[FloatVectorField].GetMemorySize(), 20)
assert.Equal(t, insertData1.Data[Float16VectorField].GetMemorySize(), 12)
assert.Equal(t, insertData1.Data[BFloat16VectorField].GetMemorySize(), 12)
assert.Equal(t, insertData1.Data[Int8VectorField].GetMemorySize(), 8)
assert.Equal(t, insertData1.Data[ArrayField].GetMemorySize(), 3*4+1)
assert.Equal(t, insertData1.Data[JSONField].GetMemorySize(), len([]byte(`{"batch":1}`))+16+1)
assert.Equal(t, insertData1.Data[StructSubInt32Field].GetMemorySize(), 4*4+1)
assert.Equal(t, insertData1.Data[StructSubFloatVectorField].GetMemorySize(), 4*4+4)
assert.Equal(t, insertData1.Data[BinaryVectorField].GetMemorySize(), 14)
assert.Equal(t, insertData1.Data[FloatVectorField].GetMemorySize(), 29)
assert.Equal(t, insertData1.Data[Float16VectorField].GetMemorySize(), 21)
assert.Equal(t, insertData1.Data[BFloat16VectorField].GetMemorySize(), 21)
assert.Equal(t, insertData1.Data[Int8VectorField].GetMemorySize(), 17)
assert.Equal(t, insertData1.Data[NullableFloatVectorField].GetMemorySize(), 30)
assert.Equal(t, insertData1.Data[NullableBinaryVectorField].GetMemorySize(), 15)
assert.Equal(t, insertData1.Data[NullableFloat16VectorField].GetMemorySize(), 22)
assert.Equal(t, insertData1.Data[NullableBFloat16VectorField].GetMemorySize(), 22)
assert.Equal(t, insertData1.Data[NullableInt8VectorField].GetMemorySize(), 18)
assert.Equal(t, insertData1.Data[NullableSparseFloatVectorField].GetMemorySize(), 39)
assert.Equal(t, insertData1.Data[ArrayField].GetMemorySize(), 13)
assert.Equal(t, insertData1.Data[JSONField].GetMemorySize(), 28)
assert.Equal(t, insertData1.Data[StructSubInt32Field].GetMemorySize(), 17)
assert.Equal(t, insertData1.Data[StructSubFloatVectorField].GetMemorySize(), 20)
insertData2 := &InsertData{
Data: map[int64]FieldData{
@ -1432,24 +1668,84 @@ func TestMemorySize(t *testing.T) {
Data: []string{"1", "23"},
},
BinaryVectorField: &BinaryVectorFieldData{
Data: []byte{0, 255},
Dim: 8,
Data: []byte{0, 255},
ValidData: nil,
Dim: 8,
Nullable: false,
},
FloatVectorField: &FloatVectorFieldData{
Data: []float32{0, 1, 2, 3, 0, 1, 2, 3},
Dim: 4,
Data: []float32{0, 1, 2, 3, 0, 1, 2, 3},
ValidData: nil,
Dim: 4,
Nullable: false,
},
Float16VectorField: &Float16VectorFieldData{
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7},
Dim: 4,
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7},
ValidData: nil,
Dim: 4,
Nullable: false,
},
BFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7},
Dim: 4,
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7},
ValidData: nil,
Dim: 4,
Nullable: false,
},
Int8VectorField: &Int8VectorFieldData{
Data: []int8{0, 1, 2, 3, 0, 1, 2, 3},
Dim: 4,
Data: []int8{0, 1, 2, 3, 0, 1, 2, 3},
ValidData: nil,
Dim: 4,
Nullable: false,
},
SparseFloatVectorField: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 300,
Contents: [][]byte{
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{1.1, 1.2, 1.3}),
typeutil.CreateSparseFloatRow([]uint32{10, 20, 30}, []float32{2.1, 2.2, 2.3}),
},
},
Nullable: false,
},
NullableFloatVectorField: &FloatVectorFieldData{
Data: []float32{0.0, 1.0, 2.0, 3.0},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableBinaryVectorField: &BinaryVectorFieldData{
Data: []byte{0},
ValidData: []bool{true, false},
Dim: 8,
Nullable: true,
},
NullableFloat16VectorField: &Float16VectorFieldData{
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableBFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{0, 1, 2, 3, 4, 5, 6, 7},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableInt8VectorField: &Int8VectorFieldData{
Data: []int8{0, 1, 2, 3},
ValidData: []bool{true, false},
Dim: 4,
Nullable: true,
},
NullableSparseFloatVectorField: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 300,
Contents: [][]byte{
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
},
},
ValidData: []bool{true, false},
Nullable: true,
},
},
}
@ -1464,29 +1760,107 @@ func TestMemorySize(t *testing.T) {
assert.Equal(t, insertData2.Data[FloatField].GetMemorySize(), 9)
assert.Equal(t, insertData2.Data[DoubleField].GetMemorySize(), 17)
assert.Equal(t, insertData2.Data[StringField].GetMemorySize(), 36)
assert.Equal(t, insertData2.Data[BinaryVectorField].GetMemorySize(), 6)
assert.Equal(t, insertData2.Data[FloatVectorField].GetMemorySize(), 36)
assert.Equal(t, insertData2.Data[Float16VectorField].GetMemorySize(), 20)
assert.Equal(t, insertData2.Data[BFloat16VectorField].GetMemorySize(), 20)
assert.Equal(t, insertData2.Data[Int8VectorField].GetMemorySize(), 12)
assert.Equal(t, insertData2.Data[BinaryVectorField].GetMemorySize(), 15)
assert.Equal(t, insertData2.Data[FloatVectorField].GetMemorySize(), 45)
assert.Equal(t, insertData2.Data[Float16VectorField].GetMemorySize(), 29)
assert.Equal(t, insertData2.Data[BFloat16VectorField].GetMemorySize(), 29)
assert.Equal(t, insertData2.Data[Int8VectorField].GetMemorySize(), 21)
assert.Equal(t, insertData2.Data[SparseFloatVectorField].GetMemorySize(), 64)
assert.Equal(t, insertData2.Data[NullableBinaryVectorField].GetMemorySize(), 16)
assert.Equal(t, insertData2.Data[NullableFloatVectorField].GetMemorySize(), 31)
assert.Equal(t, insertData2.Data[NullableFloat16VectorField].GetMemorySize(), 23)
assert.Equal(t, insertData2.Data[NullableBFloat16VectorField].GetMemorySize(), 23)
assert.Equal(t, insertData2.Data[NullableInt8VectorField].GetMemorySize(), 19)
assert.Equal(t, insertData2.Data[NullableSparseFloatVectorField].GetMemorySize(), 40)
insertDataEmpty := &InsertData{
Data: map[int64]FieldData{
RowIDField: &Int64FieldData{[]int64{}, nil, false},
TimestampField: &Int64FieldData{[]int64{}, nil, false},
BoolField: &BoolFieldData{[]bool{}, nil, false},
Int8Field: &Int8FieldData{[]int8{}, nil, false},
Int16Field: &Int16FieldData{[]int16{}, nil, false},
Int32Field: &Int32FieldData{[]int32{}, nil, false},
Int64Field: &Int64FieldData{[]int64{}, nil, false},
FloatField: &FloatFieldData{[]float32{}, nil, false},
DoubleField: &DoubleFieldData{[]float64{}, nil, false},
StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false},
BinaryVectorField: &BinaryVectorFieldData{[]byte{}, 8},
FloatVectorField: &FloatVectorFieldData{[]float32{}, 4},
Float16VectorField: &Float16VectorFieldData{[]byte{}, 4},
BFloat16VectorField: &BFloat16VectorFieldData{[]byte{}, 4},
Int8VectorField: &Int8VectorFieldData{[]int8{}, 4},
RowIDField: &Int64FieldData{[]int64{}, nil, false},
TimestampField: &Int64FieldData{[]int64{}, nil, false},
BoolField: &BoolFieldData{[]bool{}, nil, false},
Int8Field: &Int8FieldData{[]int8{}, nil, false},
Int16Field: &Int16FieldData{[]int16{}, nil, false},
Int32Field: &Int32FieldData{[]int32{}, nil, false},
Int64Field: &Int64FieldData{[]int64{}, nil, false},
FloatField: &FloatFieldData{[]float32{}, nil, false},
DoubleField: &DoubleFieldData{[]float64{}, nil, false},
StringField: &StringFieldData{[]string{}, schemapb.DataType_VarChar, nil, false},
BinaryVectorField: &BinaryVectorFieldData{
Data: []byte{},
ValidData: nil,
Dim: 8,
Nullable: false,
},
FloatVectorField: &FloatVectorFieldData{
Data: []float32{},
ValidData: nil,
Dim: 4,
Nullable: false,
},
Float16VectorField: &Float16VectorFieldData{
Data: []byte{},
ValidData: nil,
Dim: 4,
Nullable: false,
},
BFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{},
ValidData: nil,
Dim: 4,
Nullable: false,
},
Int8VectorField: &Int8VectorFieldData{
Data: []int8{},
ValidData: nil,
Dim: 4,
Nullable: false,
},
SparseFloatVectorField: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 0,
Contents: [][]byte{},
},
ValidData: nil,
Nullable: false,
},
NullableFloatVectorField: &FloatVectorFieldData{
Data: []float32{},
ValidData: []bool{},
Dim: 4,
Nullable: true,
},
NullableBinaryVectorField: &BinaryVectorFieldData{
Data: []byte{},
ValidData: []bool{},
Dim: 8,
Nullable: true,
},
NullableFloat16VectorField: &Float16VectorFieldData{
Data: []byte{},
ValidData: []bool{},
Dim: 4,
Nullable: true,
},
NullableBFloat16VectorField: &BFloat16VectorFieldData{
Data: []byte{},
ValidData: []bool{},
Dim: 4,
Nullable: true,
},
NullableInt8VectorField: &Int8VectorFieldData{
Data: []int8{},
ValidData: []bool{},
Dim: 4,
Nullable: true,
},
NullableSparseFloatVectorField: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 0,
Contents: [][]byte{},
},
ValidData: []bool{},
Nullable: true,
},
StructSubFloatVectorField: &VectorArrayFieldData{
Dim: 2,
ElementType: schemapb.DataType_FloatVector,
@ -1505,11 +1879,18 @@ func TestMemorySize(t *testing.T) {
assert.Equal(t, insertDataEmpty.Data[FloatField].GetMemorySize(), 1)
assert.Equal(t, insertDataEmpty.Data[DoubleField].GetMemorySize(), 1)
assert.Equal(t, insertDataEmpty.Data[StringField].GetMemorySize(), 1)
assert.Equal(t, insertDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4)
assert.Equal(t, insertDataEmpty.Data[FloatVectorField].GetMemorySize(), 4)
assert.Equal(t, insertDataEmpty.Data[Float16VectorField].GetMemorySize(), 4)
assert.Equal(t, insertDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 4)
assert.Equal(t, insertDataEmpty.Data[Int8VectorField].GetMemorySize(), 4)
assert.Equal(t, insertDataEmpty.Data[BinaryVectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[FloatVectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[Float16VectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[Int8VectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[SparseFloatVectorField].GetMemorySize(), 9)
assert.Equal(t, insertDataEmpty.Data[NullableFloatVectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[NullableBinaryVectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[NullableFloat16VectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[NullableBFloat16VectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[NullableInt8VectorField].GetMemorySize(), 13)
assert.Equal(t, insertDataEmpty.Data[NullableSparseFloatVectorField].GetMemorySize(), 9)
assert.Equal(t, insertDataEmpty.Data[StructSubFloatVectorField].GetMemorySize(), 0)
}
@ -1589,22 +1970,49 @@ func TestAddFieldDataToPayload(t *testing.T) {
assert.Error(t, err)
err = AddFieldDataToPayload(e, schemapb.DataType_JSON, &JSONFieldData{[][]byte{[]byte(`"batch":2}`)}, nil, false})
assert.Error(t, err)
err = AddFieldDataToPayload(e, schemapb.DataType_BinaryVector, &BinaryVectorFieldData{[]byte{}, 8})
err = AddFieldDataToPayload(e, schemapb.DataType_BinaryVector, &BinaryVectorFieldData{
Data: []byte{},
ValidData: nil,
Dim: 8,
Nullable: false,
})
assert.Error(t, err)
err = AddFieldDataToPayload(e, schemapb.DataType_FloatVector, &FloatVectorFieldData{[]float32{}, 4})
err = AddFieldDataToPayload(e, schemapb.DataType_FloatVector, &FloatVectorFieldData{
Data: []float32{},
ValidData: nil,
Dim: 4,
Nullable: false,
})
assert.Error(t, err)
err = AddFieldDataToPayload(e, schemapb.DataType_Float16Vector, &Float16VectorFieldData{[]byte{}, 4})
err = AddFieldDataToPayload(e, schemapb.DataType_Float16Vector, &Float16VectorFieldData{
Data: []byte{},
ValidData: nil,
Dim: 4,
Nullable: false,
})
assert.Error(t, err)
err = AddFieldDataToPayload(e, schemapb.DataType_BFloat16Vector, &BFloat16VectorFieldData{[]byte{}, 8})
err = AddFieldDataToPayload(e, schemapb.DataType_BFloat16Vector, &BFloat16VectorFieldData{
Data: []byte{},
ValidData: nil,
Dim: 8,
Nullable: false,
})
assert.Error(t, err)
err = AddFieldDataToPayload(e, schemapb.DataType_SparseFloatVector, &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 0,
Contents: [][]byte{},
},
ValidData: nil,
Nullable: false,
})
assert.Error(t, err)
err = AddFieldDataToPayload(e, schemapb.DataType_Int8Vector, &Int8VectorFieldData{[]int8{}, 4})
err = AddFieldDataToPayload(e, schemapb.DataType_Int8Vector, &Int8VectorFieldData{
Data: []int8{},
ValidData: nil,
Dim: 4,
Nullable: false,
})
assert.Error(t, err)
err = AddFieldDataToPayload(e, schemapb.DataType_ArrayOfVector, &VectorArrayFieldData{
Dim: 2,

View File

@ -195,70 +195,83 @@ func NewFieldData(dataType schemapb.DataType, fieldSchema *schemapb.FieldSchema,
typeParams := fieldSchema.GetTypeParams()
switch dataType {
case schemapb.DataType_Float16Vector:
if fieldSchema.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
}
dim, err := GetDimFromParams(typeParams)
if err != nil {
return nil, err
}
return &Float16VectorFieldData{
Data: make([]byte, 0, cap),
Dim: dim,
}, nil
data := &Float16VectorFieldData{
Data: make([]byte, 0, cap),
Dim: dim,
Nullable: fieldSchema.GetNullable(),
}
if fieldSchema.GetNullable() {
data.ValidData = make([]bool, 0, cap)
}
return data, nil
case schemapb.DataType_BFloat16Vector:
if fieldSchema.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
}
dim, err := GetDimFromParams(typeParams)
if err != nil {
return nil, err
}
return &BFloat16VectorFieldData{
Data: make([]byte, 0, cap),
Dim: dim,
}, nil
data := &BFloat16VectorFieldData{
Data: make([]byte, 0, cap),
Dim: dim,
Nullable: fieldSchema.GetNullable(),
}
if fieldSchema.GetNullable() {
data.ValidData = make([]bool, 0, cap)
}
return data, nil
case schemapb.DataType_FloatVector:
if fieldSchema.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
}
dim, err := GetDimFromParams(typeParams)
if err != nil {
return nil, err
}
return &FloatVectorFieldData{
Data: make([]float32, 0, cap),
Dim: dim,
}, nil
data := &FloatVectorFieldData{
Data: make([]float32, 0, cap),
Dim: dim,
Nullable: fieldSchema.GetNullable(),
}
if fieldSchema.GetNullable() {
data.ValidData = make([]bool, 0, cap)
}
return data, nil
case schemapb.DataType_BinaryVector:
if fieldSchema.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
}
dim, err := GetDimFromParams(typeParams)
if err != nil {
return nil, err
}
return &BinaryVectorFieldData{
Data: make([]byte, 0, cap),
Dim: dim,
}, nil
data := &BinaryVectorFieldData{
Data: make([]byte, 0, cap),
Dim: dim,
Nullable: fieldSchema.GetNullable(),
}
if fieldSchema.GetNullable() {
data.ValidData = make([]bool, 0, cap)
}
return data, nil
case schemapb.DataType_SparseFloatVector:
if fieldSchema.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
data := &SparseFloatVectorFieldData{
Nullable: fieldSchema.GetNullable(),
}
return &SparseFloatVectorFieldData{}, nil
if fieldSchema.GetNullable() {
data.ValidData = make([]bool, 0, cap)
}
return data, nil
case schemapb.DataType_Int8Vector:
if fieldSchema.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("vector not support null")
}
dim, err := GetDimFromParams(typeParams)
if err != nil {
return nil, err
}
return &Int8VectorFieldData{
Data: make([]int8, 0, cap),
Dim: dim,
}, nil
data := &Int8VectorFieldData{
Data: make([]int8, 0, cap),
Dim: dim,
Nullable: fieldSchema.GetNullable(),
}
if fieldSchema.GetNullable() {
data.ValidData = make([]bool, 0, cap)
}
return data, nil
case schemapb.DataType_Bool:
data := &BoolFieldData{
Data: make([]bool, 0, cap),
@ -453,30 +466,97 @@ type GeometryFieldData struct {
ValidData []bool
Nullable bool
}
// LogicalToPhysicalMapping maps logical offset to physical offset for nullable vector
//
// Nullable vector fields store only their valid rows physically; this mapping
// translates a row's logical offset (counting null rows) into the index of its
// stored vector. The zero value acts as the identity mapping, which covers
// non-nullable fields.
type LogicalToPhysicalMapping struct {
	validCount int         // number of valid (physically stored) rows recorded so far
	l2pMap     map[int]int // logical offset -> physical offset; nil means identity
}

// GetPhysicalOffset returns the physical offset for logicalOffset.
// Before any mapping is built it is the identity; for a null (or unknown)
// logical offset it returns -1.
func (m *LogicalToPhysicalMapping) GetPhysicalOffset(logicalOffset int) int {
	if m.l2pMap == nil {
		return logicalOffset
	}
	physical, found := m.l2pMap[logicalOffset]
	if !found {
		return -1
	}
	return physical
}

// GetMemorySize estimates the in-memory footprint of the mapping in bytes.
func (m *LogicalToPhysicalMapping) GetMemorySize() int {
	// 8 bytes for validCount plus roughly 16 bytes (two ints) per map entry.
	return 8 + 16*len(m.l2pMap)
}

// GetValidCount returns how many valid rows have been recorded.
func (m *LogicalToPhysicalMapping) GetValidCount() int {
	return m.validCount
}

// Build records totalCount rows of validity information, mapping each valid
// logical offset (starting at startLogical) to the next physical offset.
// It is a no-op when totalCount is zero or validData is shorter than
// totalCount.
func (m *LogicalToPhysicalMapping) Build(validData []bool, startLogical, totalCount int) {
	if totalCount == 0 || len(validData) < totalCount {
		return
	}
	if m.l2pMap == nil {
		m.l2pMap = make(map[int]int)
	}
	next := m.validCount
	for i, valid := range validData[:totalCount] {
		if valid {
			m.l2pMap[startLogical+i] = next
			next++
		}
	}
	m.validCount = next
}
type BinaryVectorFieldData struct {
Data []byte
Dim int
Data []byte
ValidData []bool
Dim int
Nullable bool
L2PMapping LogicalToPhysicalMapping
}
type FloatVectorFieldData struct {
Data []float32
Dim int
Data []float32
ValidData []bool
Dim int
Nullable bool
L2PMapping LogicalToPhysicalMapping
}
type Float16VectorFieldData struct {
Data []byte
Dim int
Data []byte
ValidData []bool
Dim int
Nullable bool
L2PMapping LogicalToPhysicalMapping
}
type BFloat16VectorFieldData struct {
Data []byte
Dim int
Data []byte
ValidData []bool
Dim int
Nullable bool
L2PMapping LogicalToPhysicalMapping
}
// SparseFloatVectorFieldData holds sparse float vector rows.
// It embeds schemapb.SparseFloatArray for the row contents; for nullable
// fields ValidData flags each logical row and L2PMapping translates
// logical offsets to physical positions in Contents.
type SparseFloatVectorFieldData struct {
schemapb.SparseFloatArray
ValidData []bool
Nullable bool
L2PMapping LogicalToPhysicalMapping
}
type Int8VectorFieldData struct {
Data []int8
Dim int
Data []int8
ValidData []bool
Dim int
Nullable bool
L2PMapping LogicalToPhysicalMapping
}
type VectorArrayFieldData struct {
@ -493,29 +573,71 @@ func (dst *SparseFloatVectorFieldData) AppendAllRows(src *SparseFloatVectorField
dst.Dim = src.Dim
}
dst.Contents = append(dst.Contents, src.Contents...)
if src.Nullable {
if dst.ValidData == nil {
dst.ValidData = make([]bool, 0, len(src.ValidData))
}
dst.L2PMapping.Build(src.ValidData, len(dst.ValidData), len(src.ValidData))
dst.ValidData = append(dst.ValidData, src.ValidData...)
dst.Nullable = true
}
}
// RowNum implements FieldData.RowNum
func (data *BoolFieldData) RowNum() int { return len(data.Data) }
func (data *Int8FieldData) RowNum() int { return len(data.Data) }
func (data *Int16FieldData) RowNum() int { return len(data.Data) }
func (data *Int32FieldData) RowNum() int { return len(data.Data) }
func (data *Int64FieldData) RowNum() int { return len(data.Data) }
func (data *FloatFieldData) RowNum() int { return len(data.Data) }
func (data *DoubleFieldData) RowNum() int { return len(data.Data) }
func (data *TimestamptzFieldData) RowNum() int { return len(data.Data) }
func (data *StringFieldData) RowNum() int { return len(data.Data) }
func (data *ArrayFieldData) RowNum() int { return len(data.Data) }
func (data *JSONFieldData) RowNum() int { return len(data.Data) }
func (data *GeometryFieldData) RowNum() int { return len(data.Data) }
func (data *BinaryVectorFieldData) RowNum() int { return len(data.Data) * 8 / data.Dim }
func (data *FloatVectorFieldData) RowNum() int { return len(data.Data) / data.Dim }
func (data *Float16VectorFieldData) RowNum() int { return len(data.Data) / 2 / data.Dim }
func (data *BFloat16VectorFieldData) RowNum() int {
func (data *BoolFieldData) RowNum() int { return len(data.Data) }
func (data *Int8FieldData) RowNum() int { return len(data.Data) }
func (data *Int16FieldData) RowNum() int { return len(data.Data) }
func (data *Int32FieldData) RowNum() int { return len(data.Data) }
func (data *Int64FieldData) RowNum() int { return len(data.Data) }
func (data *FloatFieldData) RowNum() int { return len(data.Data) }
func (data *DoubleFieldData) RowNum() int { return len(data.Data) }
func (data *TimestamptzFieldData) RowNum() int { return len(data.Data) }
func (data *StringFieldData) RowNum() int { return len(data.Data) }
func (data *ArrayFieldData) RowNum() int { return len(data.Data) }
func (data *JSONFieldData) RowNum() int { return len(data.Data) }
func (data *GeometryFieldData) RowNum() int { return len(data.Data) }
func (data *BinaryVectorFieldData) RowNum() int {
if data.Nullable {
return len(data.ValidData)
}
return len(data.Data) * 8 / data.Dim
}
func (data *FloatVectorFieldData) RowNum() int {
if data.Nullable {
return len(data.ValidData)
}
return len(data.Data) / data.Dim
}
func (data *Float16VectorFieldData) RowNum() int {
if data.Nullable {
return len(data.ValidData)
}
return len(data.Data) / 2 / data.Dim
}
func (data *SparseFloatVectorFieldData) RowNum() int { return len(data.Contents) }
func (data *Int8VectorFieldData) RowNum() int { return len(data.Data) / data.Dim }
func (data *BFloat16VectorFieldData) RowNum() int {
if data.Nullable {
return len(data.ValidData)
}
return len(data.Data) / 2 / data.Dim
}
func (data *SparseFloatVectorFieldData) RowNum() int {
if data.Nullable {
return len(data.ValidData)
}
return len(data.Contents)
}
func (data *Int8VectorFieldData) RowNum() int {
if data.Nullable {
return len(data.ValidData)
}
return len(data.Data) / data.Dim
}
func (data *VectorArrayFieldData) RowNum() int {
return len(data.Data)
}
@ -606,27 +728,51 @@ func (data *GeometryFieldData) GetRow(i int) any {
}
// GetRow returns the i-th binary vector as a byte slice, or nil when the
// logical row i is null. For nullable fields the logical index is first
// translated to the physical storage offset via L2PMapping.
func (data *BinaryVectorFieldData) GetRow(i int) any {
	if data.GetNullable() && !data.ValidData[i] {
		return nil
	}
	physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
	return data.Data[physicalIdx*data.Dim/8 : (physicalIdx+1)*data.Dim/8]
}
// GetRow returns the i-th sparse vector row, or nil when the logical row i
// is null. For nullable fields the logical index is first translated to the
// physical position in Contents via L2PMapping.
func (data *SparseFloatVectorFieldData) GetRow(i int) interface{} {
	if data.GetNullable() && !data.ValidData[i] {
		return nil
	}
	physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
	return data.Contents[physicalIdx]
}
// GetRow returns the i-th float vector as a []float32 slice, or nil when the
// logical row i is null. For nullable fields the logical index is first
// translated to the physical storage offset via L2PMapping.
func (data *FloatVectorFieldData) GetRow(i int) interface{} {
	if data.GetNullable() && !data.ValidData[i] {
		return nil
	}
	physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
	return data.Data[physicalIdx*data.Dim : (physicalIdx+1)*data.Dim]
}
// GetRow returns the i-th float16 vector as a byte slice (2 bytes per
// element), or nil when the logical row i is null. For nullable fields the
// logical index is first translated to the physical offset via L2PMapping.
func (data *Float16VectorFieldData) GetRow(i int) interface{} {
	if data.GetNullable() && !data.ValidData[i] {
		return nil
	}
	physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
	return data.Data[physicalIdx*data.Dim*2 : (physicalIdx+1)*data.Dim*2]
}
// GetRow returns the i-th bfloat16 vector as a byte slice (2 bytes per
// element), or nil when the logical row i is null. For nullable fields the
// logical index is first translated to the physical offset via L2PMapping.
func (data *BFloat16VectorFieldData) GetRow(i int) interface{} {
	if data.GetNullable() && !data.ValidData[i] {
		return nil
	}
	physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
	return data.Data[physicalIdx*data.Dim*2 : (physicalIdx+1)*data.Dim*2]
}
// GetRow returns the i-th int8 vector as an []int8 slice, or nil when the
// logical row i is null. For nullable fields the logical index is first
// translated to the physical storage offset via L2PMapping.
func (data *Int8VectorFieldData) GetRow(i int) interface{} {
	if data.GetNullable() && !data.ValidData[i] {
		return nil
	}
	physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
	return data.Data[physicalIdx*data.Dim : (physicalIdx+1)*data.Dim]
}
func (data *VectorArrayFieldData) GetRow(i int) interface{} {
@ -862,42 +1008,83 @@ func (data *GeometryFieldData) AppendRow(row interface{}) error {
}
// AppendRow appends one binary vector row. For a nullable field a nil row is
// recorded as a null (validity flag only, no physical storage); valid rows
// also extend the logical-to-physical mapping.
func (data *BinaryVectorFieldData) AppendRow(row interface{}) error {
	if data.GetNullable() && row == nil {
		data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, false)
		return nil
	}
	vec, typeOK := row.([]byte)
	// Dim counts bits for binary vectors, so a row holds Dim/8 bytes.
	if !typeOK || len(vec) != data.Dim/8 {
		return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type")
	}
	data.Data = append(data.Data, vec...)
	if data.GetNullable() {
		data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, true)
	}
	return nil
}
// AppendRow appends one float32 vector row. For a nullable field a nil row
// is recorded as a null (validity flag only, no physical storage); valid
// rows also extend the logical-to-physical mapping.
func (data *FloatVectorFieldData) AppendRow(row interface{}) error {
	if data.GetNullable() && row == nil {
		data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, false)
		return nil
	}
	vec, typeOK := row.([]float32)
	if !typeOK || len(vec) != data.Dim {
		return merr.WrapErrParameterInvalid("[]float32", row, "Wrong row type")
	}
	data.Data = append(data.Data, vec...)
	if data.GetNullable() {
		data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, true)
	}
	return nil
}
// AppendRow appends one float16 vector row (2 bytes per element). For a
// nullable field a nil row is recorded as a null (validity flag only, no
// physical storage); valid rows also extend the logical-to-physical mapping.
func (data *Float16VectorFieldData) AppendRow(row interface{}) error {
	if data.GetNullable() && row == nil {
		data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, false)
		return nil
	}
	vec, typeOK := row.([]byte)
	if !typeOK || len(vec) != data.Dim*2 {
		return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type")
	}
	data.Data = append(data.Data, vec...)
	if data.GetNullable() {
		data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, true)
	}
	return nil
}
// AppendRow appends one bfloat16 vector row (2 bytes per element). For a
// nullable field a nil row is recorded as a null (validity flag only, no
// physical storage); valid rows also extend the logical-to-physical mapping.
func (data *BFloat16VectorFieldData) AppendRow(row interface{}) error {
	if data.GetNullable() && row == nil {
		data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, false)
		return nil
	}
	vec, typeOK := row.([]byte)
	if !typeOK || len(vec) != data.Dim*2 {
		return merr.WrapErrParameterInvalid("[]byte", row, "Wrong row type")
	}
	data.Data = append(data.Data, vec...)
	if data.GetNullable() {
		data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, true)
	}
	return nil
}
func (data *SparseFloatVectorFieldData) AppendRow(row interface{}) error {
if data.GetNullable() && row == nil {
data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
data.ValidData = append(data.ValidData, false)
return nil
}
v, ok := row.([]byte)
if !ok {
return merr.WrapErrParameterInvalid("SparseFloatVectorRowData", row, "Wrong row type")
@ -910,15 +1097,28 @@ func (data *SparseFloatVectorFieldData) AppendRow(row interface{}) error {
data.Dim = rowDim
}
data.Contents = append(data.Contents, v)
if data.GetNullable() {
data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
data.ValidData = append(data.ValidData, true)
}
return nil
}
// AppendRow appends one int8 vector row. For a nullable field a nil row is
// recorded as a null (validity flag only, no physical storage); valid rows
// also extend the logical-to-physical mapping.
func (data *Int8VectorFieldData) AppendRow(row interface{}) error {
	if data.GetNullable() && row == nil {
		data.L2PMapping.Build([]bool{false}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, false)
		return nil
	}
	vec, typeOK := row.([]int8)
	if !typeOK || len(vec) != data.Dim {
		return merr.WrapErrParameterInvalid("[]int8", row, "Wrong row type")
	}
	data.Data = append(data.Data, vec...)
	if data.GetNullable() {
		data.L2PMapping.Build([]bool{true}, len(data.ValidData), 1)
		data.ValidData = append(data.ValidData, true)
	}
	return nil
}
@ -1431,15 +1631,15 @@ func (data *GeometryFieldData) AppendValidDataRows(rows interface{}) error {
// AppendValidDataRows appends FLATTEN vectors to field data.
func (data *BinaryVectorFieldData) AppendValidDataRows(rows interface{}) error {
if rows != nil {
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
if len(v) != 0 {
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
}
if rows == nil {
return nil
}
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
data.L2PMapping.Build(v, len(data.ValidData), len(v))
data.ValidData = append(data.ValidData, v...)
return nil
}
@ -1458,69 +1658,69 @@ func (data *VectorArrayFieldData) AppendValidDataRows(rows interface{}) error {
// AppendValidDataRows appends FLATTEN vectors to field data.
func (data *FloatVectorFieldData) AppendValidDataRows(rows interface{}) error {
if rows != nil {
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
if len(v) != 0 {
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
}
if rows == nil {
return nil
}
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
data.L2PMapping.Build(v, len(data.ValidData), len(v))
data.ValidData = append(data.ValidData, v...)
return nil
}
// AppendValidDataRows appends FLATTEN vectors to field data.
func (data *Float16VectorFieldData) AppendValidDataRows(rows interface{}) error {
if rows != nil {
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
if len(v) != 0 {
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
}
if rows == nil {
return nil
}
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
data.L2PMapping.Build(v, len(data.ValidData), len(v))
data.ValidData = append(data.ValidData, v...)
return nil
}
// AppendValidDataRows appends FLATTEN vectors to field data.
func (data *BFloat16VectorFieldData) AppendValidDataRows(rows interface{}) error {
if rows != nil {
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
if len(v) != 0 {
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
}
if rows == nil {
return nil
}
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
data.L2PMapping.Build(v, len(data.ValidData), len(v))
data.ValidData = append(data.ValidData, v...)
return nil
}
func (data *SparseFloatVectorFieldData) AppendValidDataRows(rows interface{}) error {
if rows != nil {
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
if len(v) != 0 {
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
}
if rows == nil {
return nil
}
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
data.L2PMapping.Build(v, len(data.ValidData), len(v))
data.ValidData = append(data.ValidData, v...)
return nil
}
func (data *Int8VectorFieldData) AppendValidDataRows(rows interface{}) error {
if rows != nil {
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
if len(v) != 0 {
return merr.WrapErrParameterInvalidMsg("not support Nullable in vector")
}
if rows == nil {
return nil
}
v, ok := rows.([]bool)
if !ok {
return merr.WrapErrParameterInvalid("[]bool", rows, "Wrong rows type")
}
data.L2PMapping.Build(v, len(data.ValidData), len(v))
data.ValidData = append(data.ValidData, v...)
return nil
}
@ -1557,17 +1757,32 @@ func (data *TimestamptzFieldData) GetMemorySize() int {
return binary.Size(data.Data) + binary.Size(data.ValidData) + binary.Size(data.Nullable)
}
func (data *BinaryVectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
func (data *FloatVectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
func (data *Float16VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
func (data *BFloat16VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
func (data *BinaryVectorFieldData) GetMemorySize() int {
// Data + ValidData + Dim(4) + Nullable(1) + L2PMapping
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
}
func (data *FloatVectorFieldData) GetMemorySize() int {
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
}
func (data *Float16VectorFieldData) GetMemorySize() int {
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
}
func (data *BFloat16VectorFieldData) GetMemorySize() int {
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
}
func (data *SparseFloatVectorFieldData) GetMemorySize() int {
// TODO(SPARSE): should this be the memory size of serialzied size?
return proto.Size(&data.SparseFloatArray)
// SparseFloatArray + ValidData + Nullable(1) + L2PMapping
return proto.Size(&data.SparseFloatArray) + binary.Size(data.ValidData) + 1 + data.L2PMapping.GetMemorySize()
}
func (data *Int8VectorFieldData) GetMemorySize() int { return binary.Size(data.Data) + 4 }
func (data *Int8VectorFieldData) GetMemorySize() int {
return binary.Size(data.Data) + binary.Size(data.ValidData) + 4 + 1 + data.L2PMapping.GetMemorySize()
}
func GetVectorSize(vector *schemapb.VectorField, vectorType schemapb.DataType) int {
size := 0
@ -1698,22 +1913,51 @@ func (data *GeometryFieldData) GetMemorySize() int {
return size + binary.Size(data.ValidData) + binary.Size(data.Nullable)
}
func (data *BoolFieldData) GetRowSize(i int) int { return 1 }
func (data *Int8FieldData) GetRowSize(i int) int { return 1 }
func (data *Int16FieldData) GetRowSize(i int) int { return 2 }
func (data *Int32FieldData) GetRowSize(i int) int { return 4 }
func (data *Int64FieldData) GetRowSize(i int) int { return 8 }
func (data *FloatFieldData) GetRowSize(i int) int { return 4 }
func (data *DoubleFieldData) GetRowSize(i int) int { return 8 }
func (data *TimestamptzFieldData) GetRowSize(i int) int { return 8 }
func (data *BinaryVectorFieldData) GetRowSize(i int) int { return data.Dim / 8 }
func (data *FloatVectorFieldData) GetRowSize(i int) int { return data.Dim * 4 }
func (data *Float16VectorFieldData) GetRowSize(i int) int { return data.Dim * 2 }
func (data *BFloat16VectorFieldData) GetRowSize(i int) int { return data.Dim * 2 }
func (data *Int8VectorFieldData) GetRowSize(i int) int { return data.Dim }
func (data *StringFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
func (data *JSONFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
func (data *GeometryFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
func (data *BoolFieldData) GetRowSize(i int) int { return 1 }
func (data *Int8FieldData) GetRowSize(i int) int { return 1 }
func (data *Int16FieldData) GetRowSize(i int) int { return 2 }
func (data *Int32FieldData) GetRowSize(i int) int { return 4 }
func (data *Int64FieldData) GetRowSize(i int) int { return 8 }
func (data *FloatFieldData) GetRowSize(i int) int { return 4 }
func (data *DoubleFieldData) GetRowSize(i int) int { return 8 }
func (data *TimestamptzFieldData) GetRowSize(i int) int { return 8 }
func (data *BinaryVectorFieldData) GetRowSize(i int) int {
if data.GetNullable() && !data.ValidData[i] {
return 0
}
return data.Dim / 8
}
func (data *FloatVectorFieldData) GetRowSize(i int) int {
if data.GetNullable() && !data.ValidData[i] {
return 0
}
return data.Dim * 4
}
func (data *Float16VectorFieldData) GetRowSize(i int) int {
if data.GetNullable() && !data.ValidData[i] {
return 0
}
return data.Dim * 2
}
func (data *BFloat16VectorFieldData) GetRowSize(i int) int {
if data.GetNullable() && !data.ValidData[i] {
return 0
}
return data.Dim * 2
}
func (data *Int8VectorFieldData) GetRowSize(i int) int {
if data.GetNullable() && !data.ValidData[i] {
return 0
}
return data.Dim
}
func (data *StringFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
func (data *JSONFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
func (data *GeometryFieldData) GetRowSize(i int) int { return len(data.Data[i]) + 16 }
func (data *ArrayFieldData) GetRowSize(i int) int {
switch data.ElementType {
case schemapb.DataType_Bool:
@ -1737,7 +1981,11 @@ func (data *ArrayFieldData) GetRowSize(i int) int {
}
func (data *SparseFloatVectorFieldData) GetRowSize(i int) int {
return len(data.Contents[i])
if data.GetNullable() && !data.ValidData[i] {
return 0
}
physicalIdx := data.L2PMapping.GetPhysicalOffset(i)
return len(data.Contents[physicalIdx])
}
func (data *VectorArrayFieldData) GetRowSize(i int) int {
@ -1777,27 +2025,27 @@ func (data *TimestamptzFieldData) GetNullable() bool {
}
func (data *BFloat16VectorFieldData) GetNullable() bool {
return false
return data.Nullable
}
func (data *BinaryVectorFieldData) GetNullable() bool {
return false
return data.Nullable
}
func (data *FloatVectorFieldData) GetNullable() bool {
return false
return data.Nullable
}
func (data *SparseFloatVectorFieldData) GetNullable() bool {
return false
return data.Nullable
}
func (data *Float16VectorFieldData) GetNullable() bool {
return false
return data.Nullable
}
func (data *Int8VectorFieldData) GetNullable() bool {
return false
return data.Nullable
}
func (data *StringFieldData) GetNullable() bool {
@ -1819,3 +2067,23 @@ func (data *VectorArrayFieldData) GetNullable() bool {
func (data *GeometryFieldData) GetNullable() bool {
return data.Nullable
}
func (data *BoolFieldData) GetValidData() []bool { return data.ValidData }
func (data *Int8FieldData) GetValidData() []bool { return data.ValidData }
func (data *Int16FieldData) GetValidData() []bool { return data.ValidData }
func (data *Int32FieldData) GetValidData() []bool { return data.ValidData }
func (data *Int64FieldData) GetValidData() []bool { return data.ValidData }
func (data *FloatFieldData) GetValidData() []bool { return data.ValidData }
func (data *DoubleFieldData) GetValidData() []bool { return data.ValidData }
func (data *TimestamptzFieldData) GetValidData() []bool { return data.ValidData }
func (data *StringFieldData) GetValidData() []bool { return data.ValidData }
func (data *ArrayFieldData) GetValidData() []bool { return data.ValidData }
func (data *JSONFieldData) GetValidData() []bool { return data.ValidData }
func (data *GeometryFieldData) GetValidData() []bool { return data.ValidData }
func (data *BinaryVectorFieldData) GetValidData() []bool { return data.ValidData }
func (data *FloatVectorFieldData) GetValidData() []bool { return data.ValidData }
func (data *Float16VectorFieldData) GetValidData() []bool { return data.ValidData }
func (data *BFloat16VectorFieldData) GetValidData() []bool { return data.ValidData }
func (data *SparseFloatVectorFieldData) GetValidData() []bool { return data.ValidData }
func (data *Int8VectorFieldData) GetValidData() []bool { return data.ValidData }
func (data *VectorArrayFieldData) GetValidData() []bool { return nil }

View File

@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -45,17 +46,31 @@ func (s *InsertDataSuite) TestInsertData() {
tests := []struct {
description string
dataType schemapb.DataType
typeParams []*commonpb.KeyValuePair
nullable bool
}{
{"nullable bool field", schemapb.DataType_Bool},
{"nullable int8 field", schemapb.DataType_Int8},
{"nullable int16 field", schemapb.DataType_Int16},
{"nullable int32 field", schemapb.DataType_Int32},
{"nullable int64 field", schemapb.DataType_Int64},
{"nullable float field", schemapb.DataType_Float},
{"nullable double field", schemapb.DataType_Double},
{"nullable json field", schemapb.DataType_JSON},
{"nullable array field", schemapb.DataType_Array},
{"nullable string/varchar field", schemapb.DataType_String},
{"nullable bool field", schemapb.DataType_Bool, nil, true},
{"nullable int8 field", schemapb.DataType_Int8, nil, true},
{"nullable int16 field", schemapb.DataType_Int16, nil, true},
{"nullable int32 field", schemapb.DataType_Int32, nil, true},
{"nullable int64 field", schemapb.DataType_Int64, nil, true},
{"nullable float field", schemapb.DataType_Float, nil, true},
{"nullable double field", schemapb.DataType_Double, nil, true},
{"nullable json field", schemapb.DataType_JSON, nil, true},
{"nullable array field", schemapb.DataType_Array, nil, true},
{"nullable string/varchar field", schemapb.DataType_String, nil, true},
{"nullable binary vector field", schemapb.DataType_BinaryVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "8"}}, true},
{"nullable float vector field", schemapb.DataType_FloatVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true},
{"nullable float16 vector field", schemapb.DataType_Float16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true},
{"nullable bfloat16 vector field", schemapb.DataType_BFloat16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true},
{"nullable sparse float vector field", schemapb.DataType_SparseFloatVector, nil, true},
{"nullable int8 vector field", schemapb.DataType_Int8Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, true},
{"non-nullable binary vector field", schemapb.DataType_BinaryVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "8"}}, false},
{"non-nullable float vector field", schemapb.DataType_FloatVector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false},
{"non-nullable float16 vector field", schemapb.DataType_Float16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false},
{"non-nullable bfloat16 vector field", schemapb.DataType_BFloat16Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false},
{"non-nullable sparse float vector field", schemapb.DataType_SparseFloatVector, nil, false},
{"non-nullable int8 vector field", schemapb.DataType_Int8Vector, []*commonpb.KeyValuePair{{Key: "dim", Value: "4"}}, false},
}
for _, test := range tests {
@ -63,8 +78,9 @@ func (s *InsertDataSuite) TestInsertData() {
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
DataType: test.dataType,
Nullable: true,
DataType: test.dataType,
Nullable: test.nullable,
TypeParams: test.typeParams,
},
},
}
@ -115,15 +131,15 @@ func (s *InsertDataSuite) TestInsertData() {
s.Run("init by New", func() {
s.True(s.iDataEmpty.IsEmpty())
s.Equal(0, s.iDataEmpty.GetRowNum())
s.Equal(33, s.iDataEmpty.GetMemorySize())
s.Equal(161, s.iDataEmpty.GetMemorySize())
s.False(s.iDataOneRow.IsEmpty())
s.Equal(1, s.iDataOneRow.GetRowNum())
s.Equal(240, s.iDataOneRow.GetMemorySize())
s.Equal(535, s.iDataOneRow.GetMemorySize())
s.False(s.iDataTwoRows.IsEmpty())
s.Equal(2, s.iDataTwoRows.GetRowNum())
s.Equal(433, s.iDataTwoRows.GetMemorySize())
s.Equal(734, s.iDataTwoRows.GetMemorySize())
for _, field := range s.iDataTwoRows.Data {
s.Equal(2, field.RowNum())
@ -147,12 +163,13 @@ func (s *InsertDataSuite) TestMemorySize() {
s.Equal(s.iDataEmpty.Data[DoubleField].GetMemorySize(), 1)
s.Equal(s.iDataEmpty.Data[StringField].GetMemorySize(), 1)
s.Equal(s.iDataEmpty.Data[ArrayField].GetMemorySize(), 1)
s.Equal(s.iDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4)
s.Equal(s.iDataEmpty.Data[FloatVectorField].GetMemorySize(), 4)
s.Equal(s.iDataEmpty.Data[Float16VectorField].GetMemorySize(), 4)
s.Equal(s.iDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 4)
s.Equal(s.iDataEmpty.Data[SparseFloatVectorField].GetMemorySize(), 0)
s.Equal(s.iDataEmpty.Data[Int8VectorField].GetMemorySize(), 4)
// +9 bytes: Nullable(1) + L2PMapping.GetMemorySize()(8)
s.Equal(s.iDataEmpty.Data[BinaryVectorField].GetMemorySize(), 4+9)
s.Equal(s.iDataEmpty.Data[FloatVectorField].GetMemorySize(), 4+9)
s.Equal(s.iDataEmpty.Data[Float16VectorField].GetMemorySize(), 4+9)
s.Equal(s.iDataEmpty.Data[BFloat16VectorField].GetMemorySize(), 4+9)
s.Equal(s.iDataEmpty.Data[SparseFloatVectorField].GetMemorySize(), 0+9)
s.Equal(s.iDataEmpty.Data[Int8VectorField].GetMemorySize(), 4+9)
s.Equal(s.iDataEmpty.Data[StructSubInt32Field].GetMemorySize(), 1)
s.Equal(s.iDataEmpty.Data[StructSubFloatVectorField].GetMemorySize(), 0)
@ -168,12 +185,13 @@ func (s *InsertDataSuite) TestMemorySize() {
s.Equal(s.iDataOneRow.Data[StringField].GetMemorySize(), 20)
s.Equal(s.iDataOneRow.Data[JSONField].GetMemorySize(), len([]byte(`{"batch":1}`))+16+1)
s.Equal(s.iDataOneRow.Data[ArrayField].GetMemorySize(), 3*4+1)
s.Equal(s.iDataOneRow.Data[BinaryVectorField].GetMemorySize(), 5)
s.Equal(s.iDataOneRow.Data[FloatVectorField].GetMemorySize(), 20)
s.Equal(s.iDataOneRow.Data[Float16VectorField].GetMemorySize(), 12)
s.Equal(s.iDataOneRow.Data[BFloat16VectorField].GetMemorySize(), 12)
s.Equal(s.iDataOneRow.Data[SparseFloatVectorField].GetMemorySize(), 28)
s.Equal(s.iDataOneRow.Data[Int8VectorField].GetMemorySize(), 8)
// +9 bytes: Nullable(1) + L2PMapping.GetMemorySize()(8)
s.Equal(s.iDataOneRow.Data[BinaryVectorField].GetMemorySize(), 5+9)
s.Equal(s.iDataOneRow.Data[FloatVectorField].GetMemorySize(), 20+9)
s.Equal(s.iDataOneRow.Data[Float16VectorField].GetMemorySize(), 12+9)
s.Equal(s.iDataOneRow.Data[BFloat16VectorField].GetMemorySize(), 12+9)
s.Equal(s.iDataOneRow.Data[SparseFloatVectorField].GetMemorySize(), 28+9)
s.Equal(s.iDataOneRow.Data[Int8VectorField].GetMemorySize(), 8+9)
s.Equal(s.iDataOneRow.Data[StructSubInt32Field].GetMemorySize(), 3*4+1)
s.Equal(s.iDataOneRow.Data[StructSubFloatVectorField].GetMemorySize(), 3*4*2+4)
@ -188,12 +206,13 @@ func (s *InsertDataSuite) TestMemorySize() {
s.Equal(s.iDataTwoRows.Data[DoubleField].GetMemorySize(), 17)
s.Equal(s.iDataTwoRows.Data[StringField].GetMemorySize(), 39)
s.Equal(s.iDataTwoRows.Data[ArrayField].GetMemorySize(), 25)
s.Equal(s.iDataTwoRows.Data[BinaryVectorField].GetMemorySize(), 6)
s.Equal(s.iDataTwoRows.Data[FloatVectorField].GetMemorySize(), 36)
s.Equal(s.iDataTwoRows.Data[Float16VectorField].GetMemorySize(), 20)
s.Equal(s.iDataTwoRows.Data[BFloat16VectorField].GetMemorySize(), 20)
s.Equal(s.iDataTwoRows.Data[SparseFloatVectorField].GetMemorySize(), 54)
s.Equal(s.iDataTwoRows.Data[Int8VectorField].GetMemorySize(), 12)
// +9 bytes: Nullable(1) + L2PMapping.GetMemorySize()(8)
s.Equal(s.iDataTwoRows.Data[BinaryVectorField].GetMemorySize(), 6+9)
s.Equal(s.iDataTwoRows.Data[FloatVectorField].GetMemorySize(), 36+9)
s.Equal(s.iDataTwoRows.Data[Float16VectorField].GetMemorySize(), 20+9)
s.Equal(s.iDataTwoRows.Data[BFloat16VectorField].GetMemorySize(), 20+9)
s.Equal(s.iDataTwoRows.Data[SparseFloatVectorField].GetMemorySize(), 54+9)
s.Equal(s.iDataTwoRows.Data[Int8VectorField].GetMemorySize(), 12+9)
s.Equal(s.iDataTwoRows.Data[StructSubInt32Field].GetMemorySize(), 3*4+2*4+1)
s.Equal(s.iDataTwoRows.Data[StructSubFloatVectorField].GetMemorySize(), 3*4*2+4+2*4*2+4)
}
@ -252,25 +271,31 @@ func (s *InsertDataSuite) SetupTest() {
s.Require().NoError(err)
s.True(s.iDataEmpty.IsEmpty())
s.Equal(0, s.iDataEmpty.GetRowNum())
s.Equal(33, s.iDataEmpty.GetMemorySize())
s.Equal(161, s.iDataEmpty.GetMemorySize())
row1 := map[FieldID]interface{}{
RowIDField: int64(3),
TimestampField: int64(3),
BoolField: true,
Int8Field: int8(3),
Int16Field: int16(3),
Int32Field: int32(3),
Int64Field: int64(3),
FloatField: float32(3),
DoubleField: float64(3),
StringField: "str",
BinaryVectorField: []byte{0},
FloatVectorField: []float32{4, 5, 6, 7},
Float16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255},
BFloat16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255},
SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
Int8VectorField: []int8{-4, -5, 6, 7},
RowIDField: int64(3),
TimestampField: int64(3),
BoolField: true,
Int8Field: int8(3),
Int16Field: int16(3),
Int32Field: int32(3),
Int64Field: int64(3),
FloatField: float32(3),
DoubleField: float64(3),
StringField: "str",
BinaryVectorField: []byte{0},
FloatVectorField: []float32{4, 5, 6, 7},
Float16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255},
BFloat16VectorField: []byte{0, 0, 0, 0, 255, 255, 255, 255},
SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
Int8VectorField: []int8{-4, -5, 6, 7},
NullableFloatVectorField: []float32{1.0, 2.0, 3.0, 4.0},
NullableBinaryVectorField: []byte{1},
NullableFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
NullableBFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
NullableInt8VectorField: []int8{1, 2, 3, 4},
NullableSparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{4, 5, 6}),
ArrayField: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},
@ -300,22 +325,28 @@ func (s *InsertDataSuite) SetupTest() {
}
row2 := map[FieldID]interface{}{
RowIDField: int64(1),
TimestampField: int64(1),
BoolField: false,
Int8Field: int8(1),
Int16Field: int16(1),
Int32Field: int32(1),
Int64Field: int64(1),
FloatField: float32(1),
DoubleField: float64(1),
StringField: string("str"),
BinaryVectorField: []byte{0},
FloatVectorField: []float32{4, 5, 6, 7},
Float16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
BFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{2, 3, 4}, []float32{4, 5, 6}),
Int8VectorField: []int8{-128, -5, 6, 127},
RowIDField: int64(1),
TimestampField: int64(1),
BoolField: false,
Int8Field: int8(1),
Int16Field: int16(1),
Int32Field: int32(1),
Int64Field: int64(1),
FloatField: float32(1),
DoubleField: float64(1),
StringField: string("str"),
BinaryVectorField: []byte{0},
FloatVectorField: []float32{4, 5, 6, 7},
Float16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
BFloat16VectorField: []byte{1, 2, 3, 4, 5, 6, 7, 8},
SparseFloatVectorField: typeutil.CreateSparseFloatRow([]uint32{2, 3, 4}, []float32{4, 5, 6}),
Int8VectorField: []int8{-128, -5, 6, 127},
NullableFloatVectorField: nil,
NullableBinaryVectorField: nil,
NullableFloat16VectorField: nil,
NullableBFloat16VectorField: nil,
NullableInt8VectorField: nil,
NullableSparseFloatVectorField: nil,
ArrayField: &schemapb.ScalarField{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},

View File

@ -40,12 +40,12 @@ type PayloadWriterInterface interface {
AddOneArrayToPayload(*schemapb.ScalarField, bool) error
AddOneJSONToPayload([]byte, bool) error
AddOneGeometryToPayload(msg []byte, isValid bool) error
AddBinaryVectorToPayload([]byte, int) error
AddFloatVectorToPayload([]float32, int) error
AddFloat16VectorToPayload([]byte, int) error
AddBFloat16VectorToPayload([]byte, int) error
AddBinaryVectorToPayload(data []byte, dim int, validData []bool) error
AddFloatVectorToPayload(data []float32, dim int, validData []bool) error
AddFloat16VectorToPayload(data []byte, dim int, validData []bool) error
AddBFloat16VectorToPayload(data []byte, dim int, validData []bool) error
AddSparseFloatVectorToPayload(*SparseFloatVectorFieldData) error
AddInt8VectorToPayload([]int8, int) error
AddInt8VectorToPayload(data []int8, dim int, validData []bool) error
AddVectorArrayFieldDataToPayload(*VectorArrayFieldData) error
FinishPayloadWriter() error
GetPayloadBufferFromWriter() ([]byte, error)
@ -72,12 +72,12 @@ type PayloadReaderInterface interface {
GetVectorArrayFromPayload() ([]*schemapb.VectorField, error)
GetJSONFromPayload() ([][]byte, []bool, error)
GetGeometryFromPayload() ([][]byte, []bool, error)
GetBinaryVectorFromPayload() ([]byte, int, error)
GetFloat16VectorFromPayload() ([]byte, int, error)
GetBFloat16VectorFromPayload() ([]byte, int, error)
GetFloatVectorFromPayload() ([]float32, int, error)
GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, error)
GetInt8VectorFromPayload() ([]int8, int, error)
GetBinaryVectorFromPayload() ([]byte, int, []bool, int, error)
GetFloat16VectorFromPayload() ([]byte, int, []bool, int, error)
GetBFloat16VectorFromPayload() ([]byte, int, []bool, int, error)
GetFloatVectorFromPayload() ([]float32, int, []bool, int, error)
GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, []bool, error)
GetInt8VectorFromPayload() ([]int8, int, []bool, int, error)
GetPayloadLengthFromReader() (int, error)
GetByteArrayDataSet() (*DataSet[parquet.ByteArray, *file.ByteArrayColumnChunkReader], error)

View File

@ -149,23 +149,23 @@ func (r *PayloadReader) GetDataFromPayload() (interface{}, []bool, int, error) {
val, validData, err := r.GetTimestamptzFromPayload()
return val, validData, 0, err
case schemapb.DataType_BinaryVector:
val, dim, err := r.GetBinaryVectorFromPayload()
return val, nil, dim, err
val, dim, validData, _, err := r.GetBinaryVectorFromPayload()
return val, validData, dim, err
case schemapb.DataType_FloatVector:
val, dim, err := r.GetFloatVectorFromPayload()
return val, nil, dim, err
val, dim, validData, _, err := r.GetFloatVectorFromPayload()
return val, validData, dim, err
case schemapb.DataType_Float16Vector:
val, dim, err := r.GetFloat16VectorFromPayload()
return val, nil, dim, err
val, dim, validData, _, err := r.GetFloat16VectorFromPayload()
return val, validData, dim, err
case schemapb.DataType_BFloat16Vector:
val, dim, err := r.GetBFloat16VectorFromPayload()
return val, nil, dim, err
val, dim, validData, _, err := r.GetBFloat16VectorFromPayload()
return val, validData, dim, err
case schemapb.DataType_SparseFloatVector:
val, dim, err := r.GetSparseFloatVectorFromPayload()
return val, nil, dim, err
val, dim, validData, err := r.GetSparseFloatVectorFromPayload()
return val, validData, dim, err
case schemapb.DataType_Int8Vector:
val, dim, err := r.GetInt8VectorFromPayload()
return val, nil, dim, err
val, dim, validData, _, err := r.GetInt8VectorFromPayload()
return val, validData, dim, err
case schemapb.DataType_String, schemapb.DataType_VarChar:
val, validData, err := r.GetStringFromPayload()
return val, validData, 0, err
@ -681,96 +681,434 @@ func readByteAndConvert[T any](r *PayloadReader, convert func(parquet.ByteArray)
return ret, nil
}
// GetBinaryVectorFromPayload returns vector, dimension, error
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) {
// GetBinaryVectorFromPayload returns vector, dimension, validData, numRows, error
func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, []bool, int, error) {
if r.colType != schemapb.DataType_BinaryVector {
return nil, -1, fmt.Errorf("failed to get binary vector from datatype %v", r.colType.String())
return nil, -1, nil, 0, fmt.Errorf("failed to get binary vector from datatype %v", r.colType.String())
}
if r.nullable {
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
if err != nil {
return nil, -1, nil, 0, err
}
arrowSchema, err := fileReader.Schema()
if err != nil {
return nil, -1, nil, 0, err
}
if arrowSchema.NumFields() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
}
field := arrowSchema.Field(0)
var dim int
if field.Type.ID() == arrow.BINARY {
if !field.HasMetadata() {
return nil, -1, nil, 0, fmt.Errorf("nullable binary vector field is missing metadata")
}
metadata := field.Metadata
dimStr, ok := metadata.GetValue("dim")
if !ok {
return nil, -1, nil, 0, fmt.Errorf("nullable binary vector metadata missing required 'dim' field")
}
var err error
dim, err = strconv.Atoi(dimStr)
if err != nil {
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
}
dim = dim / 8
} else {
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, nil, 0, err
}
dim = col.Descriptor().TypeLength()
}
table, err := fileReader.ReadTable(context.Background())
if err != nil {
return nil, -1, nil, 0, err
}
defer table.Release()
if table.NumCols() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
}
column := table.Column(0)
validCount := 0
for _, chunk := range column.Data().Chunks() {
for i := 0; i < chunk.Len(); i++ {
if chunk.IsValid(i) {
validCount++
}
}
}
ret := make([]byte, validCount*dim)
validData := make([]bool, r.numRows)
offset := 0
dataIdx := 0
for _, chunk := range column.Data().Chunks() {
binaryArray, ok := chunk.(*array.Binary)
if !ok {
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
}
for i := 0; i < binaryArray.Len(); i++ {
if binaryArray.IsValid(i) {
validData[offset+i] = true
bytes := binaryArray.Value(i)
copy(ret[dataIdx*dim:(dataIdx+1)*dim], bytes)
dataIdx++
}
}
offset += binaryArray.Len()
}
return ret, dim * 8, validData, int(r.numRows), nil
}
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
dim := col.Descriptor().TypeLength()
values := make([]parquet.FixedLenByteArray, r.numRows)
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
if valuesRead != r.numRows {
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
}
ret := make([]byte, int64(dim)*r.numRows)
for i := 0; i < int(r.numRows); i++ {
copy(ret[i*dim:(i+1)*dim], values[i])
}
return ret, dim * 8, nil
return ret, dim * 8, nil, int(r.numRows), nil
}
// GetFloat16VectorFromPayload returns vector, dimension, error
func (r *PayloadReader) GetFloat16VectorFromPayload() ([]byte, int, error) {
// GetFloat16VectorFromPayload returns vector, dimension, validData, numRows, error
func (r *PayloadReader) GetFloat16VectorFromPayload() ([]byte, int, []bool, int, error) {
if r.colType != schemapb.DataType_Float16Vector {
return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String())
return nil, -1, nil, 0, fmt.Errorf("failed to get float16 vector from datatype %v", r.colType.String())
}
if r.nullable {
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
if err != nil {
return nil, -1, nil, 0, err
}
arrowSchema, err := fileReader.Schema()
if err != nil {
return nil, -1, nil, 0, err
}
if arrowSchema.NumFields() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
}
field := arrowSchema.Field(0)
var dim int
if field.Type.ID() == arrow.BINARY {
if !field.HasMetadata() {
return nil, -1, nil, 0, fmt.Errorf("nullable float16 vector field is missing metadata")
}
metadata := field.Metadata
dimStr, ok := metadata.GetValue("dim")
if !ok {
return nil, -1, nil, 0, fmt.Errorf("nullable float16 vector metadata missing required 'dim' field")
}
var err error
dim, err = strconv.Atoi(dimStr)
if err != nil {
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
}
} else {
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, nil, 0, err
}
dim = col.Descriptor().TypeLength() / 2
}
table, err := fileReader.ReadTable(context.Background())
if err != nil {
return nil, -1, nil, 0, err
}
defer table.Release()
if table.NumCols() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
}
column := table.Column(0)
validCount := 0
for _, chunk := range column.Data().Chunks() {
for i := 0; i < chunk.Len(); i++ {
if chunk.IsValid(i) {
validCount++
}
}
}
ret := make([]byte, validCount*dim*2)
validData := make([]bool, r.numRows)
offset := 0
dataIdx := 0
for _, chunk := range column.Data().Chunks() {
binaryArray, ok := chunk.(*array.Binary)
if !ok {
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
}
for i := 0; i < binaryArray.Len(); i++ {
if binaryArray.IsValid(i) {
validData[offset+i] = true
bytes := binaryArray.Value(i)
copy(ret[dataIdx*dim*2:(dataIdx+1)*dim*2], bytes)
dataIdx++
}
}
offset += binaryArray.Len()
}
return ret, dim, validData, int(r.numRows), nil
}
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
dim := col.Descriptor().TypeLength() / 2
values := make([]parquet.FixedLenByteArray, r.numRows)
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
if valuesRead != r.numRows {
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
}
ret := make([]byte, int64(dim*2)*r.numRows)
for i := 0; i < int(r.numRows); i++ {
copy(ret[i*dim*2:(i+1)*dim*2], values[i])
}
return ret, dim, nil
return ret, dim, nil, int(r.numRows), nil
}
// GetBFloat16VectorFromPayload returns vector, dimension, error
func (r *PayloadReader) GetBFloat16VectorFromPayload() ([]byte, int, error) {
// GetBFloat16VectorFromPayload returns vector, dimension, validData, numRows, error
func (r *PayloadReader) GetBFloat16VectorFromPayload() ([]byte, int, []bool, int, error) {
if r.colType != schemapb.DataType_BFloat16Vector {
return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String())
return nil, -1, nil, 0, fmt.Errorf("failed to get bfloat16 vector from datatype %v", r.colType.String())
}
if r.nullable {
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
if err != nil {
return nil, -1, nil, 0, err
}
arrowSchema, err := fileReader.Schema()
if err != nil {
return nil, -1, nil, 0, err
}
if arrowSchema.NumFields() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
}
field := arrowSchema.Field(0)
var dim int
if field.Type.ID() == arrow.BINARY {
if !field.HasMetadata() {
return nil, -1, nil, 0, fmt.Errorf("nullable bfloat16 vector field is missing metadata")
}
metadata := field.Metadata
dimStr, ok := metadata.GetValue("dim")
if !ok {
return nil, -1, nil, 0, fmt.Errorf("nullable bfloat16 vector metadata missing required 'dim' field")
}
var err error
dim, err = strconv.Atoi(dimStr)
if err != nil {
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
}
} else {
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, nil, 0, err
}
dim = col.Descriptor().TypeLength() / 2
}
table, err := fileReader.ReadTable(context.Background())
if err != nil {
return nil, -1, nil, 0, err
}
defer table.Release()
if table.NumCols() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
}
column := table.Column(0)
validCount := 0
for _, chunk := range column.Data().Chunks() {
for i := 0; i < chunk.Len(); i++ {
if chunk.IsValid(i) {
validCount++
}
}
}
ret := make([]byte, validCount*dim*2)
validData := make([]bool, r.numRows)
offset := 0
dataIdx := 0
for _, chunk := range column.Data().Chunks() {
binaryArray, ok := chunk.(*array.Binary)
if !ok {
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
}
for i := 0; i < binaryArray.Len(); i++ {
if binaryArray.IsValid(i) {
validData[offset+i] = true
bytes := binaryArray.Value(i)
copy(ret[dataIdx*dim*2:(dataIdx+1)*dim*2], bytes)
dataIdx++
}
}
offset += binaryArray.Len()
}
return ret, dim, validData, int(r.numRows), nil
}
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
dim := col.Descriptor().TypeLength() / 2
values := make([]parquet.FixedLenByteArray, r.numRows)
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
if valuesRead != r.numRows {
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
}
ret := make([]byte, int64(dim*2)*r.numRows)
for i := 0; i < int(r.numRows); i++ {
copy(ret[i*dim*2:(i+1)*dim*2], values[i])
}
return ret, dim, nil
return ret, dim, nil, int(r.numRows), nil
}
// GetFloatVectorFromPayload returns vector, dimension, error
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
// GetFloatVectorFromPayload returns vector, dimension, validData, numRows, error
func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, []bool, int, error) {
if r.colType != schemapb.DataType_FloatVector {
return nil, -1, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String())
return nil, -1, nil, 0, fmt.Errorf("failed to get float vector from datatype %v", r.colType.String())
}
if r.nullable {
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
if err != nil {
return nil, -1, nil, 0, err
}
arrowSchema, err := fileReader.Schema()
if err != nil {
return nil, -1, nil, 0, err
}
if arrowSchema.NumFields() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
}
field := arrowSchema.Field(0)
var dim int
if field.Type.ID() == arrow.BINARY {
if !field.HasMetadata() {
return nil, -1, nil, 0, fmt.Errorf("nullable float vector field is missing metadata")
}
metadata := field.Metadata
dimStr, ok := metadata.GetValue("dim")
if !ok {
return nil, -1, nil, 0, fmt.Errorf("nullable float vector metadata missing required 'dim' field")
}
var err error
dim, err = strconv.Atoi(dimStr)
if err != nil {
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
}
} else {
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, nil, 0, err
}
dim = col.Descriptor().TypeLength() / 4
}
table, err := fileReader.ReadTable(context.Background())
if err != nil {
return nil, -1, nil, 0, err
}
defer table.Release()
if table.NumCols() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
}
column := table.Column(0)
validCount := 0
for _, chunk := range column.Data().Chunks() {
for i := 0; i < chunk.Len(); i++ {
if chunk.IsValid(i) {
validCount++
}
}
}
ret := make([]float32, validCount*dim)
validData := make([]bool, r.numRows)
offset := 0
dataIdx := 0
for _, chunk := range column.Data().Chunks() {
binaryArray, ok := chunk.(*array.Binary)
if !ok {
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
}
for i := 0; i < binaryArray.Len(); i++ {
if binaryArray.IsValid(i) {
validData[offset+i] = true
bytes := binaryArray.Value(i)
copy(arrow.Float32Traits.CastToBytes(ret[dataIdx*dim:(dataIdx+1)*dim]), bytes)
dataIdx++
}
}
offset += binaryArray.Len()
}
return ret, dim, validData, int(r.numRows), nil
}
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
dim := col.Descriptor().TypeLength() / 4
@ -778,38 +1116,89 @@ func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) {
values := make([]parquet.FixedLenByteArray, r.numRows)
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
if valuesRead != r.numRows {
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
}
ret := make([]float32, int64(dim)*r.numRows)
for i := 0; i < int(r.numRows); i++ {
copy(arrow.Float32Traits.CastToBytes(ret[i*dim:(i+1)*dim]), values[i])
}
return ret, dim, nil
return ret, dim, nil, int(r.numRows), nil
}
func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, error) {
// GetSparseFloatVectorFromPayload returns fieldData, dimension, validData, error
func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFieldData, int, []bool, error) {
if !typeutil.IsSparseFloatVectorType(r.colType) {
return nil, -1, fmt.Errorf("failed to get sparse float vector from datatype %v", r.colType.String())
return nil, -1, nil, fmt.Errorf("failed to get sparse float vector from datatype %v", r.colType.String())
}
if r.nullable {
fieldData := &SparseFloatVectorFieldData{}
validData := make([]bool, r.numRows)
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
if err != nil {
return nil, -1, nil, err
}
table, err := fileReader.ReadTable(context.Background())
if err != nil {
return nil, -1, nil, err
}
defer table.Release()
if table.NumCols() != 1 {
return nil, -1, nil, fmt.Errorf("expected 1 column, got %d", table.NumCols())
}
column := table.Column(0)
offset := 0
for _, chunk := range column.Data().Chunks() {
binaryArray, ok := chunk.(*array.Binary)
if !ok {
return nil, -1, nil, fmt.Errorf("expected Binary array, got %T", chunk)
}
for i := 0; i < binaryArray.Len(); i++ {
validData[offset+i] = binaryArray.IsValid(i)
if validData[offset+i] {
value := binaryArray.Value(i)
if len(value)%8 != 0 {
return nil, -1, nil, errors.New("invalid bytesData length")
}
fieldData.Contents = append(fieldData.Contents, value)
rowDim := typeutil.SparseFloatRowDim(value)
if rowDim > fieldData.Dim {
fieldData.Dim = rowDim
}
} else {
fieldData.Contents = append(fieldData.Contents, nil)
}
}
offset += binaryArray.Len()
}
return fieldData, int(fieldData.Dim), validData, nil
}
values := make([]parquet.ByteArray, r.numRows)
valuesRead, err := ReadDataFromAllRowGroups[parquet.ByteArray, *file.ByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
if err != nil {
return nil, -1, err
return nil, -1, nil, err
}
if valuesRead != r.numRows {
return nil, -1, fmt.Errorf("expect %d binary, but got = %d", r.numRows, valuesRead)
return nil, -1, nil, fmt.Errorf("expect %d binary, but got = %d", r.numRows, valuesRead)
}
fieldData := &SparseFloatVectorFieldData{}
for _, value := range values {
if len(value)%8 != 0 {
return nil, -1, errors.New("invalid bytesData length")
return nil, -1, nil, errors.New("invalid bytesData length")
}
fieldData.Contents = append(fieldData.Contents, value)
@ -819,17 +1208,101 @@ func (r *PayloadReader) GetSparseFloatVectorFromPayload() (*SparseFloatVectorFie
}
}
return fieldData, int(fieldData.Dim), nil
return fieldData, int(fieldData.Dim), nil, nil
}
// GetInt8VectorFromPayload returns vector, dimension, error
func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, error) {
// GetInt8VectorFromPayload returns vector, dimension, validData, numRows, error
func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, []bool, int, error) {
if r.colType != schemapb.DataType_Int8Vector {
return nil, -1, fmt.Errorf("failed to get int8 vector from datatype %v", r.colType.String())
return nil, -1, nil, 0, fmt.Errorf("failed to get int8 vector from datatype %v", r.colType.String())
}
if r.nullable {
fileReader, err := pqarrow.NewFileReader(r.reader, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
if err != nil {
return nil, -1, nil, 0, err
}
arrowSchema, err := fileReader.Schema()
if err != nil {
return nil, -1, nil, 0, err
}
if arrowSchema.NumFields() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 field, got %d", arrowSchema.NumFields())
}
field := arrowSchema.Field(0)
var dim int
if field.Type.ID() == arrow.BINARY {
if !field.HasMetadata() {
return nil, -1, nil, 0, fmt.Errorf("nullable int8 vector field is missing metadata")
}
metadata := field.Metadata
dimStr, ok := metadata.GetValue("dim")
if !ok {
return nil, -1, nil, 0, fmt.Errorf("nullable int8 vector metadata missing required 'dim' field")
}
var err error
dim, err = strconv.Atoi(dimStr)
if err != nil {
return nil, -1, nil, 0, fmt.Errorf("invalid dim value in metadata: %v", err)
}
} else {
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, nil, 0, err
}
dim = col.Descriptor().TypeLength()
}
table, err := fileReader.ReadTable(context.Background())
if err != nil {
return nil, -1, nil, 0, err
}
defer table.Release()
if table.NumCols() != 1 {
return nil, -1, nil, 0, fmt.Errorf("expected 1 column, got %d", table.NumCols())
}
column := table.Column(0)
validCount := 0
for _, chunk := range column.Data().Chunks() {
for i := 0; i < chunk.Len(); i++ {
if chunk.IsValid(i) {
validCount++
}
}
}
ret := make([]int8, validCount*dim)
validData := make([]bool, r.numRows)
offset := 0
dataIdx := 0
for _, chunk := range column.Data().Chunks() {
binaryArray, ok := chunk.(*array.Binary)
if !ok {
return nil, -1, nil, 0, fmt.Errorf("expected Binary array for nullable vector, got %T", chunk)
}
for i := 0; i < binaryArray.Len(); i++ {
if binaryArray.IsValid(i) {
validData[offset+i] = true
bytes := binaryArray.Value(i)
int8Vals := arrow.Int8Traits.CastFromBytes(bytes)
copy(ret[dataIdx*dim:(dataIdx+1)*dim], int8Vals)
dataIdx++
}
}
offset += binaryArray.Len()
}
return ret, dim, validData, int(r.numRows), nil
}
col, err := r.reader.RowGroup(0).Column(0)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
dim := col.Descriptor().TypeLength()
@ -837,11 +1310,11 @@ func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, error) {
values := make([]parquet.FixedLenByteArray, r.numRows)
valuesRead, err := ReadDataFromAllRowGroups[parquet.FixedLenByteArray, *file.FixedLenByteArrayColumnChunkReader](r.reader, values, 0, r.numRows)
if err != nil {
return nil, -1, err
return nil, -1, nil, 0, err
}
if valuesRead != r.numRows {
return nil, -1, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
return nil, -1, nil, 0, fmt.Errorf("expect %d rows, but got valuesRead = %d", r.numRows, valuesRead)
}
ret := make([]int8, int64(dim)*r.numRows)
@ -849,7 +1322,7 @@ func (r *PayloadReader) GetInt8VectorFromPayload() ([]int8, int, error) {
int8Vals := arrow.Int8Traits.CastFromBytes(values[i])
copy(ret[i*dim:(i+1)*dim], int8Vals)
}
return ret, dim, nil
return ret, dim, nil, int(r.numRows), nil
}
func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) {

View File

@ -484,7 +484,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
in2[i] = 1
}
err = w.AddBinaryVectorToPayload(in, 8)
err = w.AddBinaryVectorToPayload(in, 8, nil)
assert.NoError(t, err)
err = w.AddDataToPayloadForUT(in2, nil)
assert.NoError(t, err)
@ -505,7 +505,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, length, 24)
binVecs, dim, err := r.GetBinaryVectorFromPayload()
binVecs, dim, _, _, err := r.GetBinaryVectorFromPayload()
assert.NoError(t, err)
assert.Equal(t, 8, dim)
assert.Equal(t, 24, len(binVecs))
@ -524,7 +524,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1)
err = w.AddFloatVectorToPayload([]float32{1.0, 2.0}, 1, nil)
assert.NoError(t, err)
err = w.AddDataToPayloadForUT([]float32{3.0, 4.0}, nil)
assert.NoError(t, err)
@ -545,7 +545,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, length, 4)
floatVecs, dim, err := r.GetFloatVectorFromPayload()
floatVecs, dim, _, _, err := r.GetFloatVectorFromPayload()
assert.NoError(t, err)
assert.Equal(t, 1, dim)
assert.Equal(t, 4, len(floatVecs))
@ -566,7 +566,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddFloat16VectorToPayload([]byte{1, 2}, 1)
err = w.AddFloat16VectorToPayload([]byte{1, 2}, 1, nil)
assert.NoError(t, err)
err = w.AddDataToPayloadForUT([]byte{3, 4}, nil)
assert.NoError(t, err)
@ -587,7 +587,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, length, 2)
float16Vecs, dim, err := r.GetFloat16VectorFromPayload()
float16Vecs, dim, _, _, err := r.GetFloat16VectorFromPayload()
assert.NoError(t, err)
assert.Equal(t, 1, dim)
assert.Equal(t, 4, len(float16Vecs))
@ -608,7 +608,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddBFloat16VectorToPayload([]byte{1, 2}, 1)
err = w.AddBFloat16VectorToPayload([]byte{1, 2}, 1, nil)
assert.NoError(t, err)
err = w.AddDataToPayloadForUT([]byte{3, 4}, nil)
assert.NoError(t, err)
@ -629,7 +629,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, length, 2)
bfloat16Vecs, dim, err := r.GetBFloat16VectorFromPayload()
bfloat16Vecs, dim, _, _, err := r.GetBFloat16VectorFromPayload()
assert.NoError(t, err)
assert.Equal(t, 1, dim)
assert.Equal(t, 4, len(bfloat16Vecs))
@ -689,7 +689,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, length, 6)
floatVecs, dim, err := r.GetSparseFloatVectorFromPayload()
floatVecs, dim, _, err := r.GetSparseFloatVectorFromPayload()
assert.NoError(t, err)
assert.Equal(t, 600, dim)
assert.Equal(t, 6, len(floatVecs.Contents))
@ -743,7 +743,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, length, 3)
floatVecs, dim, err := r.GetSparseFloatVectorFromPayload()
floatVecs, dim, _, err := r.GetSparseFloatVectorFromPayload()
assert.NoError(t, err)
assert.Equal(t, actualDim, dim)
assert.Equal(t, 3, len(floatVecs.Contents))
@ -951,16 +951,16 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
err = w.FinishPayloadWriter()
assert.NoError(t, err)
err = w.AddBinaryVectorToPayload([]byte{}, 8)
err = w.AddBinaryVectorToPayload([]byte{}, 8, nil)
assert.Error(t, err)
err = w.AddBinaryVectorToPayload([]byte{1}, 0)
err = w.AddBinaryVectorToPayload([]byte{1}, 0, nil)
assert.Error(t, err)
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
assert.Error(t, err)
err = w.FinishPayloadWriter()
assert.Error(t, err)
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
assert.Error(t, err)
})
t.Run("TestAddFloatVectorAfterFinish", func(t *testing.T) {
@ -972,16 +972,16 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
err = w.FinishPayloadWriter()
assert.NoError(t, err)
err = w.AddFloatVectorToPayload([]float32{}, 8)
err = w.AddFloatVectorToPayload([]float32{}, 8, nil)
assert.Error(t, err)
err = w.AddFloatVectorToPayload([]float32{1.0}, 0)
err = w.AddFloatVectorToPayload([]float32{1.0}, 0, nil)
assert.Error(t, err)
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8)
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8, nil)
assert.Error(t, err)
err = w.FinishPayloadWriter()
assert.Error(t, err)
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8)
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8, nil)
assert.Error(t, err)
})
t.Run("TestAddFloat16VectorAfterFinish", func(t *testing.T) {
@ -990,22 +990,22 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
require.NotNil(t, w)
defer w.Close()
err = w.AddFloat16VectorToPayload([]byte{}, 8)
err = w.AddFloat16VectorToPayload([]byte{}, 8, nil)
assert.Error(t, err)
err = w.FinishPayloadWriter()
assert.NoError(t, err)
err = w.AddFloat16VectorToPayload([]byte{}, 8)
err = w.AddFloat16VectorToPayload([]byte{}, 8, nil)
assert.Error(t, err)
err = w.AddFloat16VectorToPayload([]byte{1}, 0)
err = w.AddFloat16VectorToPayload([]byte{1}, 0, nil)
assert.Error(t, err)
err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
assert.Error(t, err)
err = w.FinishPayloadWriter()
assert.Error(t, err)
err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
err = w.AddFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
assert.Error(t, err)
})
t.Run("TestAddBFloat16VectorAfterFinish", func(t *testing.T) {
@ -1014,22 +1014,22 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
require.NotNil(t, w)
defer w.Close()
err = w.AddBFloat16VectorToPayload([]byte{}, 8)
err = w.AddBFloat16VectorToPayload([]byte{}, 8, nil)
assert.Error(t, err)
err = w.FinishPayloadWriter()
assert.NoError(t, err)
err = w.AddBFloat16VectorToPayload([]byte{}, 8)
err = w.AddBFloat16VectorToPayload([]byte{}, 8, nil)
assert.Error(t, err)
err = w.AddBFloat16VectorToPayload([]byte{1}, 0)
err = w.AddBFloat16VectorToPayload([]byte{1}, 0, nil)
assert.Error(t, err)
err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
assert.Error(t, err)
err = w.FinishPayloadWriter()
assert.Error(t, err)
err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
err = w.AddBFloat16VectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
assert.Error(t, err)
})
t.Run("TestAddSparseFloatVectorAfterFinish", func(t *testing.T) {
@ -1481,11 +1481,11 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
r, err := NewPayloadReader(schemapb.DataType_BinaryVector, buffer, false)
assert.NoError(t, err)
_, _, err = r.GetBinaryVectorFromPayload()
_, _, _, _, err = r.GetBinaryVectorFromPayload()
assert.Error(t, err)
r.colType = 999
_, _, err = r.GetBinaryVectorFromPayload()
_, _, _, _, err = r.GetBinaryVectorFromPayload()
assert.Error(t, err)
})
t.Run("TestGetBinaryVectorError2", func(t *testing.T) {
@ -1493,7 +1493,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8)
err = w.AddBinaryVectorToPayload([]byte{1, 0, 0, 0, 0, 0, 0, 0}, 8, nil)
assert.NoError(t, err)
err = w.FinishPayloadWriter()
@ -1506,7 +1506,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
assert.NoError(t, err)
r.numRows = 99
_, _, err = r.GetBinaryVectorFromPayload()
_, _, _, _, err = r.GetBinaryVectorFromPayload()
assert.Error(t, err)
})
t.Run("TestGetFloatVectorError", func(t *testing.T) {
@ -1526,11 +1526,11 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
r, err := NewPayloadReader(schemapb.DataType_FloatVector, buffer, false)
assert.NoError(t, err)
_, _, err = r.GetFloatVectorFromPayload()
_, _, _, _, err = r.GetFloatVectorFromPayload()
assert.Error(t, err)
r.colType = 999
_, _, err = r.GetFloatVectorFromPayload()
_, _, _, _, err = r.GetFloatVectorFromPayload()
assert.Error(t, err)
})
t.Run("TestGetFloatVectorError2", func(t *testing.T) {
@ -1538,7 +1538,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8)
err = w.AddFloatVectorToPayload([]float32{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, 8, nil)
assert.NoError(t, err)
err = w.FinishPayloadWriter()
@ -1551,7 +1551,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
assert.NoError(t, err)
r.numRows = 99
_, _, err = r.GetFloatVectorFromPayload()
_, _, _, _, err = r.GetFloatVectorFromPayload()
assert.Error(t, err)
})
@ -1599,7 +1599,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
w, err := NewPayloadWriter(schemapb.DataType_FloatVector)
assert.NoError(t, err)
err = w.AddFloatVectorToPayload(vec, 128)
err = w.AddFloatVectorToPayload(vec, 128, nil)
assert.NoError(t, err)
err = w.FinishPayloadWriter()
@ -2234,19 +2234,548 @@ func TestPayload_NullableReaderAndWriter(t *testing.T) {
w.ReleasePayloadWriter()
})
t.Run("TestBinaryVector", func(t *testing.T) {
_, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithNullable(true), WithDim(8))
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
t.Run("TestFloatVector", func(t *testing.T) {
dim := 128
numRows := 100
type testCase struct {
name string
validDataSetup func([]bool) int
}
testCases := []testCase{
{
name: "half null",
validDataSetup: func(validData []bool) int {
validCount := 0
for i := 0; i < numRows; i++ {
if i%2 == 0 {
validData[i] = true
validCount++
}
}
return validCount
},
},
{
name: "all valid",
validDataSetup: func(validData []bool) int {
for i := 0; i < numRows; i++ {
validData[i] = true
}
return numRows
},
},
{
name: "all null",
validDataSetup: func(validData []bool) int {
return 0
},
},
}
for _, tc := range testCases {
w, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithDim(dim), WithNullable(true))
require.NoError(t, err)
validData := make([]bool, numRows)
validCount := tc.validDataSetup(validData)
data := make([]float32, validCount*dim)
dataIdx := 0
for i := 0; i < numRows; i++ {
if validData[i] {
for j := 0; j < dim; j++ {
data[dataIdx*dim+j] = float32(i*100 + j)
}
dataIdx++
}
}
err = w.AddFloatVectorToPayload(data, dim, validData)
require.NoError(t, err)
err = w.FinishPayloadWriter()
require.NoError(t, err)
buffer, err := w.GetPayloadBufferFromWriter()
require.NoError(t, err)
r, err := NewPayloadReader(schemapb.DataType_FloatVector, buffer, true)
require.NoError(t, err)
readData, readDim, readValid, readNumRows, err := r.GetFloatVectorFromPayload()
require.NoError(t, err)
require.Equal(t, dim, readDim)
require.Equal(t, numRows, readNumRows)
require.Equal(t, numRows, len(readValid))
dataIdx = 0
for i := 0; i < numRows; i++ {
require.Equal(t, validData[i], readValid[i])
if validData[i] {
pos := dataIdx
for j := 0; j < dim; j++ {
require.Equal(t, data[dataIdx*dim+j], readData[pos*dim+j])
}
dataIdx++
}
}
}
})
t.Run("TestFloatVector", func(t *testing.T) {
_, err := NewPayloadWriter(schemapb.DataType_FloatVector, WithNullable(true), WithDim(1))
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
t.Run("TestBinaryVector", func(t *testing.T) {
dim := 128
numRows := 100
type testCase struct {
name string
validDataSetup func([]bool) int
}
testCases := []testCase{
{
name: "partial null",
validDataSetup: func(validData []bool) int {
validCount := 0
for i := 0; i < numRows; i++ {
if i%3 == 0 {
validData[i] = true
validCount++
}
}
return validCount
},
},
{
name: "all valid",
validDataSetup: func(validData []bool) int {
for i := 0; i < numRows; i++ {
validData[i] = true
}
return numRows
},
},
{
name: "all null",
validDataSetup: func(validData []bool) int {
return 0
},
},
}
for _, tc := range testCases {
w, err := NewPayloadWriter(schemapb.DataType_BinaryVector, WithDim(dim), WithNullable(true))
require.NoError(t, err)
validData := make([]bool, numRows)
validCount := tc.validDataSetup(validData)
data := make([]byte, validCount*dim/8)
dataIdx := 0
for i := 0; i < numRows; i++ {
if validData[i] {
for j := 0; j < dim/8; j++ {
data[dataIdx*dim/8+j] = byte(i + j)
}
dataIdx++
}
}
err = w.AddBinaryVectorToPayload(data, dim, validData)
require.NoError(t, err)
err = w.FinishPayloadWriter()
require.NoError(t, err)
buffer, err := w.GetPayloadBufferFromWriter()
require.NoError(t, err)
r, err := NewPayloadReader(schemapb.DataType_BinaryVector, buffer, true)
require.NoError(t, err)
readData, readDim, readValid, readNumRows, err := r.GetBinaryVectorFromPayload()
require.NoError(t, err)
require.Equal(t, dim, readDim)
require.Equal(t, numRows, readNumRows)
require.Equal(t, numRows, len(readValid))
dataIdx = 0
for i := 0; i < numRows; i++ {
require.Equal(t, validData[i], readValid[i])
if validData[i] {
pos := dataIdx
for j := 0; j < dim/8; j++ {
require.Equal(t, data[dataIdx*dim/8+j], readData[pos*dim/8+j])
}
dataIdx++
}
}
}
})
t.Run("TestFloat16Vector", func(t *testing.T) {
_, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithNullable(true), WithDim(1))
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
dim := 128
numRows := 100
type testCase struct {
name string
validDataSetup func([]bool) int
}
testCases := []testCase{
{
name: "partial null",
validDataSetup: func(validData []bool) int {
validCount := 0
for i := 0; i < numRows; i++ {
if i%2 == 1 {
validData[i] = true
validCount++
}
}
return validCount
},
},
{
name: "all valid",
validDataSetup: func(validData []bool) int {
for i := 0; i < numRows; i++ {
validData[i] = true
}
return numRows
},
},
{
name: "all null",
validDataSetup: func(validData []bool) int {
return 0
},
},
}
for _, tc := range testCases {
w, err := NewPayloadWriter(schemapb.DataType_Float16Vector, WithDim(dim), WithNullable(true))
require.NoError(t, err)
validData := make([]bool, numRows)
validCount := tc.validDataSetup(validData)
data := make([]byte, validCount*dim*2)
dataIdx := 0
for i := 0; i < numRows; i++ {
if validData[i] {
for j := 0; j < dim*2; j++ {
data[dataIdx*dim*2+j] = byte((i*10 + j) % 256)
}
dataIdx++
}
}
err = w.AddFloat16VectorToPayload(data, dim, validData)
require.NoError(t, err)
err = w.FinishPayloadWriter()
require.NoError(t, err)
buffer, err := w.GetPayloadBufferFromWriter()
require.NoError(t, err)
r, err := NewPayloadReader(schemapb.DataType_Float16Vector, buffer, true)
require.NoError(t, err)
readData, readDim, readValid, readNumRows, err := r.GetFloat16VectorFromPayload()
require.NoError(t, err)
require.Equal(t, dim, readDim)
require.Equal(t, numRows, readNumRows)
require.Equal(t, numRows, len(readValid))
dataIdx = 0
for i := 0; i < numRows; i++ {
require.Equal(t, validData[i], readValid[i])
if validData[i] {
pos := dataIdx
for j := 0; j < dim*2; j++ {
require.Equal(t, data[dataIdx*dim*2+j], readData[pos*dim*2+j])
}
dataIdx++
}
}
}
})
t.Run("TestBFloat16Vector", func(t *testing.T) {
dim := 128
numRows := 100
type testCase struct {
name string
validDataSetup func([]bool) int
}
testCases := []testCase{
{
name: "partial null",
validDataSetup: func(validData []bool) int {
validCount := 0
for i := 0; i < numRows; i++ {
if (i+1)%3 != 0 {
validData[i] = true
validCount++
}
}
return validCount
},
},
{
name: "all valid",
validDataSetup: func(validData []bool) int {
for i := 0; i < numRows; i++ {
validData[i] = true
}
return numRows
},
},
{
name: "all null",
validDataSetup: func(validData []bool) int {
return 0
},
},
}
for _, tc := range testCases {
w, err := NewPayloadWriter(schemapb.DataType_BFloat16Vector, WithDim(dim), WithNullable(true))
require.NoError(t, err)
validData := make([]bool, numRows)
validCount := tc.validDataSetup(validData)
data := make([]byte, validCount*dim*2)
dataIdx := 0
for i := 0; i < numRows; i++ {
if validData[i] {
for j := 0; j < dim*2; j++ {
data[dataIdx*dim*2+j] = byte((i*20 + j) % 256)
}
dataIdx++
}
}
err = w.AddBFloat16VectorToPayload(data, dim, validData)
require.NoError(t, err)
err = w.FinishPayloadWriter()
require.NoError(t, err)
buffer, err := w.GetPayloadBufferFromWriter()
require.NoError(t, err)
r, err := NewPayloadReader(schemapb.DataType_BFloat16Vector, buffer, true)
require.NoError(t, err)
readData, readDim, readValid, readNumRows, err := r.GetBFloat16VectorFromPayload()
require.NoError(t, err)
require.Equal(t, dim, readDim)
require.Equal(t, numRows, readNumRows)
require.Equal(t, numRows, len(readValid))
dataIdx = 0
for i := 0; i < numRows; i++ {
require.Equal(t, validData[i], readValid[i])
if validData[i] {
pos := dataIdx
for j := 0; j < dim*2; j++ {
require.Equal(t, data[dataIdx*dim*2+j], readData[pos*dim*2+j])
}
dataIdx++
}
}
}
})
t.Run("TestInt8Vector", func(t *testing.T) {
dim := 128
numRows := 100
type testCase struct {
name string
validDataSetup func([]bool) int
}
testCases := []testCase{
{
name: "partial null",
validDataSetup: func(validData []bool) int {
validCount := 0
for i := 0; i < numRows; i++ {
if i < numRows/2 {
validData[i] = true
validCount++
}
}
return validCount
},
},
{
name: "all valid",
validDataSetup: func(validData []bool) int {
for i := 0; i < numRows; i++ {
validData[i] = true
}
return numRows
},
},
{
name: "all null",
validDataSetup: func(validData []bool) int {
return 0
},
},
}
for _, tc := range testCases {
w, err := NewPayloadWriter(schemapb.DataType_Int8Vector, WithDim(dim), WithNullable(true))
require.NoError(t, err)
validData := make([]bool, numRows)
validCount := tc.validDataSetup(validData)
data := make([]int8, validCount*dim)
dataIdx := 0
for i := 0; i < numRows; i++ {
if validData[i] {
for j := 0; j < dim; j++ {
data[dataIdx*dim+j] = int8((i*10 + j) % 128)
}
dataIdx++
}
}
err = w.AddInt8VectorToPayload(data, dim, validData)
require.NoError(t, err)
err = w.FinishPayloadWriter()
require.NoError(t, err)
buffer, err := w.GetPayloadBufferFromWriter()
require.NoError(t, err)
r, err := NewPayloadReader(schemapb.DataType_Int8Vector, buffer, true)
require.NoError(t, err)
readData, readDim, readValid, readNumRows, err := r.GetInt8VectorFromPayload()
require.NoError(t, err)
require.Equal(t, dim, readDim)
require.Equal(t, numRows, readNumRows)
require.Equal(t, numRows, len(readValid))
dataIdx = 0
for i := 0; i < numRows; i++ {
require.Equal(t, validData[i], readValid[i])
if validData[i] {
pos := dataIdx
for j := 0; j < dim; j++ {
require.Equal(t, data[dataIdx*dim+j], readData[pos*dim+j])
}
dataIdx++
}
}
}
})
t.Run("TestSparseFloatVector", func(t *testing.T) {
numRows := 100
type testCase struct {
name string
validDataSetup func([]bool) int
}
testCases := []testCase{
{
name: "half null",
validDataSetup: func(validData []bool) int {
validCount := 0
for i := 0; i < numRows; i++ {
if i%2 == 0 {
validData[i] = true
validCount++
}
}
return validCount
},
},
{
name: "all valid",
validDataSetup: func(validData []bool) int {
for i := 0; i < numRows; i++ {
validData[i] = true
}
return numRows
},
},
{
name: "all null",
validDataSetup: func(validData []bool) int {
return 0
},
},
}
for _, tc := range testCases {
w, err := NewPayloadWriter(schemapb.DataType_SparseFloatVector, WithNullable(true))
require.NoError(t, err)
validData := make([]bool, numRows)
tc.validDataSetup(validData)
data := &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 100,
},
ValidData: validData,
}
for i := 0; i < numRows; i++ {
if validData[i] {
sparseVec := make([]byte, 16)
for j := 0; j < 16; j++ {
sparseVec[j] = byte((i*10 + j) % 256)
}
data.SparseFloatArray.Contents = append(data.SparseFloatArray.Contents, sparseVec)
}
}
err = w.AddSparseFloatVectorToPayload(data)
require.NoError(t, err)
err = w.FinishPayloadWriter()
require.NoError(t, err)
buffer, err := w.GetPayloadBufferFromWriter()
require.NoError(t, err)
r, err := NewPayloadReader(schemapb.DataType_SparseFloatVector, buffer, true)
require.NoError(t, err)
readData, _, readValid, err := r.GetSparseFloatVectorFromPayload()
require.NoError(t, err)
require.Equal(t, numRows, len(readValid))
require.Equal(t, numRows, len(readData.Contents))
for i := 0; i < numRows; i++ {
require.Equal(t, validData[i], readValid[i])
if validData[i] {
require.NotNil(t, readData.Contents[i])
require.Equal(t, 16, len(readData.Contents[i]))
for j := 0; j < 16; j++ {
require.Equal(t, byte((i*10+j)%256), readData.Contents[i][j])
}
} else {
require.Nil(t, readData.Contents[i])
}
}
}
})
t.Run("TestAddBool with wrong valids", func(t *testing.T) {

View File

@ -103,9 +103,6 @@ func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions
if w.dim.IsNull() {
return nil, merr.WrapErrParameterInvalidMsg("incorrect input numbers")
}
if w.nullable {
return nil, merr.WrapErrParameterInvalidMsg("vector type does not support nullable")
}
} else {
w.dim = NewNullableInt(1)
}
@ -125,8 +122,13 @@ func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions
w.arrowType = arrow.ListOf(elemType)
w.builder = array.NewListBuilder(memory.DefaultAllocator, elemType)
} else {
w.arrowType = MilvusDataTypeToArrowType(colType, *w.dim.Value)
w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType)
if w.nullable && typeutil.IsVectorType(colType) && !typeutil.IsSparseFloatVectorType(colType) {
w.arrowType = &arrow.BinaryType{}
w.builder = array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary)
} else {
w.arrowType = MilvusDataTypeToArrowType(colType, *w.dim.Value)
w.builder = array.NewBuilder(memory.DefaultAllocator, w.arrowType)
}
}
return w, nil
}
@ -262,25 +264,25 @@ func (w *NativePayloadWriter) AddDataToPayloadForUT(data interface{}, validData
if !ok {
return merr.WrapErrParameterInvalidMsg("incorrect data type")
}
return w.AddBinaryVectorToPayload(val, w.dim.GetValue())
return w.AddBinaryVectorToPayload(val, w.dim.GetValue(), validData)
case schemapb.DataType_FloatVector:
val, ok := data.([]float32)
if !ok {
return merr.WrapErrParameterInvalidMsg("incorrect data type")
}
return w.AddFloatVectorToPayload(val, w.dim.GetValue())
return w.AddFloatVectorToPayload(val, w.dim.GetValue(), validData)
case schemapb.DataType_Float16Vector:
val, ok := data.([]byte)
if !ok {
return merr.WrapErrParameterInvalidMsg("incorrect data type")
}
return w.AddFloat16VectorToPayload(val, w.dim.GetValue())
return w.AddFloat16VectorToPayload(val, w.dim.GetValue(), validData)
case schemapb.DataType_BFloat16Vector:
val, ok := data.([]byte)
if !ok {
return merr.WrapErrParameterInvalidMsg("incorrect data type")
}
return w.AddBFloat16VectorToPayload(val, w.dim.GetValue())
return w.AddBFloat16VectorToPayload(val, w.dim.GetValue(), validData)
case schemapb.DataType_SparseFloatVector:
val, ok := data.(*SparseFloatVectorFieldData)
if !ok {
@ -292,7 +294,7 @@ func (w *NativePayloadWriter) AddDataToPayloadForUT(data interface{}, validData
if !ok {
return merr.WrapErrParameterInvalidMsg("incorrect data type")
}
return w.AddInt8VectorToPayload(val, w.dim.GetValue())
return w.AddInt8VectorToPayload(val, w.dim.GetValue(), validData)
case schemapb.DataType_ArrayOfVector:
val, ok := data.(*VectorArrayFieldData)
if !ok {
@ -660,106 +662,262 @@ func (w *NativePayloadWriter) AddOneGeometryToPayload(data []byte, isValid bool)
return nil
}
func (w *NativePayloadWriter) AddBinaryVectorToPayload(data []byte, dim int) error {
func (w *NativePayloadWriter) AddBinaryVectorToPayload(data []byte, dim int, validData []bool) error {
if w.finished {
return errors.New("can't append data to finished binary vector payload")
}
if len(data) == 0 {
return errors.New("can't add empty msgs into binary vector payload")
}
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast BinaryVectorBuilder")
}
byteLength := dim / 8
length := len(data) / byteLength
builder.Reserve(length)
for i := 0; i < length; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])
var numRows int
if w.nullable && len(validData) > 0 {
numRows = len(validData)
validCount := 0
for _, valid := range validData {
if valid {
validCount++
}
}
expectedDataLen := validCount * byteLength
if len(data) != expectedDataLen {
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen)
return merr.WrapErrParameterInvalidMsg(msg)
}
} else {
if len(data) == 0 {
return errors.New("can't add empty msgs into binary vector payload")
}
numRows = len(data) / byteLength
if !w.nullable && len(validData) != 0 {
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
return merr.WrapErrParameterInvalidMsg(msg)
}
}
if w.nullable {
builder, ok := w.builder.(*array.BinaryBuilder)
if !ok {
return errors.New("failed to cast to BinaryBuilder for nullable BinaryVector")
}
builder.Reserve(numRows)
dataIdx := 0
for i := 0; i < numRows; i++ {
if len(validData) > 0 && !validData[i] {
builder.AppendNull()
} else {
builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength])
dataIdx++
}
}
} else {
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable BinaryVector")
}
builder.Reserve(numRows)
for i := 0; i < numRows; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])
}
}
return nil
}
func (w *NativePayloadWriter) AddFloatVectorToPayload(data []float32, dim int) error {
func (w *NativePayloadWriter) AddFloatVectorToPayload(data []float32, dim int, validData []bool) error {
if w.finished {
return errors.New("can't append data to finished float vector payload")
}
if len(data) == 0 {
return errors.New("can't add empty msgs into float vector payload")
}
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast FloatVectorBuilder")
var numRows int
if w.nullable && len(validData) > 0 {
numRows = len(validData)
validCount := 0
for _, valid := range validData {
if valid {
validCount++
}
}
expectedDataLen := validCount * dim
if len(data) != expectedDataLen {
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * dim(%d) = %d", len(data), validCount, dim, expectedDataLen)
return merr.WrapErrParameterInvalidMsg(msg)
}
} else {
if len(data) == 0 {
return errors.New("can't add empty msgs into float vector payload")
}
numRows = len(data) / dim
if !w.nullable && len(validData) != 0 {
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
return merr.WrapErrParameterInvalidMsg(msg)
}
}
byteLength := dim * 4
length := len(data) / dim
builder.Reserve(length)
bytesData := make([]byte, byteLength)
for i := 0; i < length; i++ {
vec := data[i*dim : (i+1)*dim]
for j := range vec {
bytes := math.Float32bits(vec[j])
common.Endian.PutUint32(bytesData[j*4:], bytes)
if w.nullable {
builder, ok := w.builder.(*array.BinaryBuilder)
if !ok {
return errors.New("failed to cast to BinaryBuilder for nullable FloatVector")
}
builder.Reserve(numRows)
bytesData := make([]byte, byteLength)
dataIdx := 0
for i := 0; i < numRows; i++ {
if len(validData) > 0 && !validData[i] {
builder.AppendNull()
} else {
vec := data[dataIdx*dim : (dataIdx+1)*dim]
for j := range vec {
bytes := math.Float32bits(vec[j])
common.Endian.PutUint32(bytesData[j*4:], bytes)
}
builder.Append(bytesData)
dataIdx++
}
}
} else {
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable FloatVector")
}
builder.Reserve(numRows)
bytesData := make([]byte, byteLength)
for i := 0; i < numRows; i++ {
vec := data[i*dim : (i+1)*dim]
for j := range vec {
bytes := math.Float32bits(vec[j])
common.Endian.PutUint32(bytesData[j*4:], bytes)
}
builder.Append(bytesData)
}
builder.Append(bytesData)
}
return nil
}
func (w *NativePayloadWriter) AddFloat16VectorToPayload(data []byte, dim int) error {
func (w *NativePayloadWriter) AddFloat16VectorToPayload(data []byte, dim int, validData []bool) error {
if w.finished {
return errors.New("can't append data to finished float16 payload")
}
if len(data) == 0 {
return errors.New("can't add empty msgs into float16 payload")
}
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast Float16Builder")
}
byteLength := dim * 2
length := len(data) / byteLength
var numRows int
if w.nullable && len(validData) > 0 {
numRows = len(validData)
validCount := 0
for _, valid := range validData {
if valid {
validCount++
}
}
expectedDataLen := validCount * byteLength
if len(data) != expectedDataLen {
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen)
return merr.WrapErrParameterInvalidMsg(msg)
}
} else {
if len(data) == 0 {
return errors.New("can't add empty msgs into float16 payload")
}
numRows = len(data) / byteLength
if !w.nullable && len(validData) != 0 {
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
return merr.WrapErrParameterInvalidMsg(msg)
}
}
builder.Reserve(length)
for i := 0; i < length; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])
if w.nullable {
builder, ok := w.builder.(*array.BinaryBuilder)
if !ok {
return errors.New("failed to cast to BinaryBuilder for nullable Float16Vector")
}
builder.Reserve(numRows)
dataIdx := 0
for i := 0; i < numRows; i++ {
if len(validData) > 0 && !validData[i] {
builder.AppendNull()
} else {
builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength])
dataIdx++
}
}
} else {
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable Float16Vector")
}
builder.Reserve(numRows)
for i := 0; i < numRows; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])
}
}
return nil
}
func (w *NativePayloadWriter) AddBFloat16VectorToPayload(data []byte, dim int) error {
func (w *NativePayloadWriter) AddBFloat16VectorToPayload(data []byte, dim int, validData []bool) error {
if w.finished {
return errors.New("can't append data to finished BFloat16 payload")
}
if len(data) == 0 {
return errors.New("can't add empty msgs into BFloat16 payload")
}
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast BFloat16Builder")
}
byteLength := dim * 2
length := len(data) / byteLength
var numRows int
if w.nullable && len(validData) > 0 {
numRows = len(validData)
validCount := 0
for _, valid := range validData {
if valid {
validCount++
}
}
expectedDataLen := validCount * byteLength
if len(data) != expectedDataLen {
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * byteLength(%d) = %d", len(data), validCount, byteLength, expectedDataLen)
return merr.WrapErrParameterInvalidMsg(msg)
}
} else {
if len(data) == 0 {
return errors.New("can't add empty msgs into BFloat16 payload")
}
numRows = len(data) / byteLength
if !w.nullable && len(validData) != 0 {
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
return merr.WrapErrParameterInvalidMsg(msg)
}
}
builder.Reserve(length)
for i := 0; i < length; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])
if w.nullable {
builder, ok := w.builder.(*array.BinaryBuilder)
if !ok {
return errors.New("failed to cast to BinaryBuilder for nullable BFloat16Vector")
}
builder.Reserve(numRows)
dataIdx := 0
for i := 0; i < numRows; i++ {
if len(validData) > 0 && !validData[i] {
builder.AppendNull()
} else {
builder.Append(data[dataIdx*byteLength : (dataIdx+1)*byteLength])
dataIdx++
}
}
} else {
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable BFloat16Vector")
}
builder.Reserve(numRows)
for i := 0; i < numRows; i++ {
builder.Append(data[i*byteLength : (i+1)*byteLength])
}
}
return nil
@ -769,41 +927,107 @@ func (w *NativePayloadWriter) AddSparseFloatVectorToPayload(data *SparseFloatVec
if w.finished {
return errors.New("can't append data to finished sparse float vector payload")
}
var numRows int
if w.nullable && len(data.ValidData) > 0 {
numRows = len(data.ValidData)
validCount := 0
for _, valid := range data.ValidData {
if valid {
validCount++
}
}
if len(data.SparseFloatArray.Contents) != validCount {
msg := fmt.Sprintf("when nullable, Contents length(%d) must equal to valid count(%d)", len(data.SparseFloatArray.Contents), validCount)
return merr.WrapErrParameterInvalidMsg(msg)
}
} else {
numRows = len(data.SparseFloatArray.Contents)
if !w.nullable && len(data.ValidData) != 0 {
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(data.ValidData))
return merr.WrapErrParameterInvalidMsg(msg)
}
}
builder, ok := w.builder.(*array.BinaryBuilder)
if !ok {
return errors.New("failed to cast SparseFloatVectorBuilder")
}
length := len(data.SparseFloatArray.Contents)
builder.Reserve(length)
for i := 0; i < length; i++ {
builder.Append(data.SparseFloatArray.Contents[i])
builder.Reserve(numRows)
dataIdx := 0
for i := 0; i < numRows; i++ {
if w.nullable && len(data.ValidData) > 0 && !data.ValidData[i] {
builder.AppendNull()
} else {
builder.Append(data.SparseFloatArray.Contents[dataIdx])
dataIdx++
}
}
return nil
}
func (w *NativePayloadWriter) AddInt8VectorToPayload(data []int8, dim int) error {
func (w *NativePayloadWriter) AddInt8VectorToPayload(data []int8, dim int, validData []bool) error {
if w.finished {
return errors.New("can't append data to finished int8 vector payload")
}
if len(data) == 0 {
return errors.New("can't add empty msgs into int8 vector payload")
var numRows int
if w.nullable && len(validData) > 0 {
numRows = len(validData)
validCount := 0
for _, valid := range validData {
if valid {
validCount++
}
}
expectedDataLen := validCount * dim
if len(data) != expectedDataLen {
msg := fmt.Sprintf("when nullable, data length(%d) must equal to valid count(%d) * dim(%d) = %d", len(data), validCount, dim, expectedDataLen)
return merr.WrapErrParameterInvalidMsg(msg)
}
} else {
if len(data) == 0 {
return errors.New("can't add empty msgs into int8 vector payload")
}
numRows = len(data) / dim
if !w.nullable && len(validData) != 0 {
msg := fmt.Sprintf("length of validData(%d) must be 0 when not nullable", len(validData))
return merr.WrapErrParameterInvalidMsg(msg)
}
}
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast Int8VectorBuilder")
}
if w.nullable {
builder, ok := w.builder.(*array.BinaryBuilder)
if !ok {
return errors.New("failed to cast to BinaryBuilder for nullable Int8Vector")
}
byteLength := dim
length := len(data) / byteLength
builder.Reserve(numRows)
dataIdx := 0
for i := 0; i < numRows; i++ {
if len(validData) > 0 && !validData[i] {
builder.AppendNull()
} else {
vec := data[dataIdx*dim : (dataIdx+1)*dim]
vecBytes := arrow.Int8Traits.CastToBytes(vec)
builder.Append(vecBytes)
dataIdx++
}
}
} else {
builder, ok := w.builder.(*array.FixedSizeBinaryBuilder)
if !ok {
return errors.New("failed to cast to FixedSizeBinaryBuilder for non-nullable Int8Vector")
}
builder.Reserve(length)
for i := 0; i < length; i++ {
vec := data[i*dim : (i+1)*dim]
vecBytes := arrow.Int8Traits.CastToBytes(vec)
builder.Append(vecBytes)
builder.Reserve(numRows)
for i := 0; i < numRows; i++ {
vec := data[i*dim : (i+1)*dim]
vecBytes := arrow.Int8Traits.CastToBytes(vec)
builder.Append(vecBytes)
}
}
return nil
@ -827,6 +1051,11 @@ func (w *NativePayloadWriter) FinishPayloadWriter() error {
[]string{"elementType", "dim"},
[]string{fmt.Sprintf("%d", int32(*w.elementType)), fmt.Sprintf("%d", w.dim.GetValue())},
)
} else if w.nullable && typeutil.IsVectorType(w.dataType) && !typeutil.IsSparseFloatVectorType(w.dataType) {
metadata = arrow.NewMetadata(
[]string{"dim"},
[]string{fmt.Sprintf("%d", w.dim.GetValue())},
)
}
field := arrow.Field{
@ -849,7 +1078,8 @@ func (w *NativePayloadWriter) FinishPayloadWriter() error {
defer table.Release()
arrowWriterProps := pqarrow.DefaultWriterProps()
if w.dataType == schemapb.DataType_ArrayOfVector {
if w.dataType == schemapb.DataType_ArrayOfVector ||
(w.nullable && typeutil.IsVectorType(w.dataType) && !typeutil.IsSparseFloatVectorType(w.dataType)) {
// Store metadata in the Arrow writer properties
arrowWriterProps = pqarrow.NewArrowWriterProperties(
pqarrow.WithStoreSchema(),

View File

@ -260,14 +260,14 @@ func TestPayloadWriter_Failed(t *testing.T) {
err = w.FinishPayloadWriter()
require.NoError(t, err)
err = w.AddBinaryVectorToPayload(data, 8)
err = w.AddBinaryVectorToPayload(data, 8, nil)
require.Error(t, err)
w, err = NewPayloadWriter(schemapb.DataType_Int64)
require.Nil(t, err)
require.NotNil(t, w)
err = w.AddBinaryVectorToPayload(data, 8)
err = w.AddBinaryVectorToPayload(data, 8, nil)
require.Error(t, err)
})

View File

@ -303,7 +303,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
fmt.Printf("\t\t%d : %s\n", i, val[i])
}
case schemapb.DataType_BinaryVector:
val, dim, err := reader.GetBinaryVectorFromPayload()
val, dim, _, _, err := reader.GetBinaryVectorFromPayload()
if err != nil {
return err
}
@ -318,7 +318,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
fmt.Println()
}
case schemapb.DataType_Float16Vector:
val, dim, err := reader.GetFloat16VectorFromPayload()
val, dim, _, _, err := reader.GetFloat16VectorFromPayload()
if err != nil {
return err
}
@ -333,7 +333,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
fmt.Println()
}
case schemapb.DataType_BFloat16Vector:
val, dim, err := reader.GetBFloat16VectorFromPayload()
val, dim, _, _, err := reader.GetBFloat16VectorFromPayload()
if err != nil {
return err
}
@ -349,7 +349,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
}
case schemapb.DataType_FloatVector:
val, dim, err := reader.GetFloatVectorFromPayload()
val, dim, _, _, err := reader.GetFloatVectorFromPayload()
if err != nil {
return err
}
@ -362,6 +362,20 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
}
fmt.Println()
}
case schemapb.DataType_Int8Vector:
val, dim, _, _, err := reader.GetInt8VectorFromPayload()
if err != nil {
return err
}
length := len(val) / dim
for i := 0; i < length; i++ {
fmt.Printf("\t\t%d :", i)
for j := 0; j < dim; j++ {
idx := i*dim + j
fmt.Printf(" %d", val[idx])
}
fmt.Println()
}
case schemapb.DataType_JSON:
rows, err := reader.GetPayloadLengthFromReader()
@ -397,7 +411,7 @@ func printPayloadValues(colType schemapb.DataType, reader PayloadReaderInterface
fmt.Printf("\t\t%d : %v\n", i, v)
}
case schemapb.DataType_SparseFloatVector:
sparseData, _, err := reader.GetSparseFloatVectorFromPayload()
sparseData, _, _, err := reader.GetSparseFloatVectorFromPayload()
if err != nil {
return err
}

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/etcdpb"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
"github.com/milvus-io/milvus/pkg/v2/util/uniquegenerator"
)
@ -215,6 +216,23 @@ func TestPrintBinlogFiles(t *testing.T) {
Description: "description_15",
DataType: schemapb.DataType_Geometry,
},
{
FieldID: 114,
Name: "field_int8_vector",
IsPrimaryKey: false,
Description: "description_16",
DataType: schemapb.DataType_Int8Vector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "4"},
},
},
{
FieldID: 115,
Name: "field_sparse_float_vector",
IsPrimaryKey: false,
Description: "description_17",
DataType: schemapb.DataType_SparseFloatVector,
},
},
},
}
@ -280,6 +298,19 @@ func TestPrintBinlogFiles(t *testing.T) {
{0x01, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A, 0x0D, 0x1B, 0x4F, 0x4F, 0x9A, 0x3D, 0x40, 0x03, 0xA6, 0xB4, 0xA6, 0xA4, 0xD2, 0xC5, 0xC0, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A},
},
},
114: &Int8VectorFieldData{
Data: []int8{1, 2, 3, 4, 5, 6, 7, 8},
Dim: 4,
},
115: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 100,
Contents: [][]byte{
typeutil.CreateSparseFloatRow([]uint32{0, 1, 2}, []float32{1.1, 1.2, 1.3}),
typeutil.CreateSparseFloatRow([]uint32{10, 20, 30}, []float32{2.1, 2.2, 2.3}),
},
},
},
},
}
@ -344,6 +375,19 @@ func TestPrintBinlogFiles(t *testing.T) {
{0x01, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A, 0x0D, 0x1B, 0x4F, 0x4F, 0x9A, 0x3D, 0x40, 0x03, 0xA6, 0xB4, 0xA6, 0xA4, 0xD2, 0xC5, 0xC0, 0xD2, 0x4A, 0x4D, 0x6A, 0x8B, 0x3C, 0x5C, 0x0A},
},
},
114: &Int8VectorFieldData{
Data: []int8{11, 12, 13, 14, 15, 16, 17, 18},
Dim: 4,
},
115: &SparseFloatVectorFieldData{
SparseFloatArray: schemapb.SparseFloatArray{
Dim: 100,
Contents: [][]byte{
typeutil.CreateSparseFloatRow([]uint32{5, 6, 7}, []float32{3.1, 3.2, 3.3}),
typeutil.CreateSparseFloatRow([]uint32{15, 25, 35}, []float32{4.1, 4.2, 4.3}),
},
},
},
},
}
firstBlobs, err := insertCodec.Serialize(1, 1, insertDataFirst)

View File

@ -38,8 +38,28 @@ func ConvertToArrowSchema(schema *schemapb.CollectionSchema, useFieldID bool) (*
}
arrowType := serdeMap[field.DataType].arrowType(dim, elementType)
if field.GetNullable() {
switch field.DataType {
case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector,
schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_Int8Vector:
arrowType = arrow.BinaryTypes.Binary
}
}
arrowField := ConvertToArrowField(field, arrowType, useFieldID)
if field.GetNullable() {
switch field.DataType {
case schemapb.DataType_BinaryVector, schemapb.DataType_FloatVector,
schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_Int8Vector:
arrowField.Metadata = arrow.NewMetadata(
[]string{packed.ArrowFieldIdMetadataKey, "dim"},
[]string{strconv.Itoa(int(field.GetFieldID())), strconv.Itoa(dim)},
)
}
}
// Add extra metadata for ArrayOfVector
if field.DataType == schemapb.DataType_ArrayOfVector {
arrowField.Metadata = arrow.NewMetadata(

View File

@ -43,12 +43,18 @@ func TestConvertArrowSchema(t *testing.T) {
{FieldID: 14, Name: "field13", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 15, Name: "field14", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 16, Name: "field15", DataType: schemapb.DataType_Int8Vector, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 17, Name: "field16", DataType: schemapb.DataType_BinaryVector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 18, Name: "field17", DataType: schemapb.DataType_FloatVector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 19, Name: "field18", DataType: schemapb.DataType_Float16Vector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 20, Name: "field19", DataType: schemapb.DataType_BFloat16Vector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 21, Name: "field20", DataType: schemapb.DataType_Int8Vector, Nullable: true, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
{FieldID: 22, Name: "field21", DataType: schemapb.DataType_SparseFloatVector, Nullable: true},
}
StructArrayFieldSchemas := []*schemapb.StructArrayFieldSchema{
{FieldID: 17, Name: "struct_field0", Fields: []*schemapb.FieldSchema{
{FieldID: 18, Name: "field16", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
{FieldID: 19, Name: "field17", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float},
{FieldID: 23, Name: "struct_field0", Fields: []*schemapb.FieldSchema{
{FieldID: 24, Name: "field22", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64},
{FieldID: 25, Name: "field23", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float},
}},
}
@ -59,6 +65,14 @@ func TestConvertArrowSchema(t *testing.T) {
arrowSchema, err := ConvertToArrowSchema(schema, false)
assert.NoError(t, err)
assert.Equal(t, len(fieldSchemas)+len(StructArrayFieldSchemas[0].Fields), len(arrowSchema.Fields()))
for i, field := range arrowSchema.Fields() {
if i >= 16 && i <= 20 {
dimVal, ok := field.Metadata.GetValue("dim")
assert.True(t, ok, "nullable vector field should have dim metadata")
assert.Equal(t, "128", dimVal)
}
}
}
func TestConvertArrowSchemaWithoutDim(t *testing.T) {

View File

@ -553,8 +553,12 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
b.AppendNull()
return true
}
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
if v, ok := v.([]byte); ok {
if v, ok := v.([]byte); ok {
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
builder.Append(v)
return true
}
if builder, ok := b.(*array.BinaryBuilder); ok {
builder.Append(v)
return true
}
@ -607,14 +611,21 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
b.AppendNull()
return true
}
var bytesData []byte
if vv, ok := v.([]byte); ok {
bytesData = vv
} else if vv, ok := v.([]int8); ok {
bytesData = arrow.Int8Traits.CastToBytes(vv)
} else {
return false
}
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
if vv, ok := v.([]byte); ok {
builder.Append(vv)
return true
} else if vv, ok := v.([]int8); ok {
builder.Append(arrow.Int8Traits.CastToBytes(vv))
return true
}
builder.Append(bytesData)
return true
}
if builder, ok := b.(*array.BinaryBuilder); ok {
builder.Append(bytesData)
return true
}
return false
},
@ -643,15 +654,19 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
b.AppendNull()
return true
}
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
if vv, ok := v.([]float32); ok {
dim := len(vv)
byteLength := dim * 4
bytesData := make([]byte, byteLength)
for i, vec := range vv {
bytes := math.Float32bits(vec)
common.Endian.PutUint32(bytesData[i*4:], bytes)
}
if vv, ok := v.([]float32); ok {
dim := len(vv)
byteLength := dim * 4
bytesData := make([]byte, byteLength)
for i, vec := range vv {
bytes := math.Float32bits(vec)
common.Endian.PutUint32(bytesData[i*4:], bytes)
}
if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok {
builder.Append(bytesData)
return true
}
if builder, ok := b.(*array.BinaryBuilder); ok {
builder.Append(bytesData)
return true
}
@ -987,7 +1002,16 @@ func newSingleFieldRecordWriter(field *schemapb.FieldSchema, writer io.Writer, o
[]string{fmt.Sprintf("%d", int32(elementType)), fmt.Sprintf("%d", dim)},
)
}
arrowType = serdeMap[field.DataType].arrowType(int(dim), elementType)
if field.GetNullable() && typeutil.IsVectorType(field.DataType) && !typeutil.IsSparseFloatVectorType(field.DataType) {
arrowType = arrow.BinaryTypes.Binary
fieldMetadata = arrow.NewMetadata(
[]string{"dim"},
[]string{fmt.Sprintf("%d", dim)},
)
} else {
arrowType = serdeMap[field.DataType].arrowType(int(dim), elementType)
}
w := &singleFieldRecordWriter{
fieldId: field.FieldID,
@ -1199,10 +1223,40 @@ func BuildRecord(b *array.RecordBuilder, data *InsertData, schema *schemapb.Coll
elementType = field.GetElementType()
}
for j := 0; j < fieldData.RowNum(); j++ {
ok = typeEntry.serialize(fBuilder, fieldData.GetRow(j), elementType)
if !ok {
return merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", field.DataType.String()))
if field.GetNullable() && typeutil.IsVectorType(field.DataType) {
var validData []bool
switch fd := fieldData.(type) {
case *FloatVectorFieldData:
validData = fd.ValidData
case *BinaryVectorFieldData:
validData = fd.ValidData
case *Float16VectorFieldData:
validData = fd.ValidData
case *BFloat16VectorFieldData:
validData = fd.ValidData
case *SparseFloatVectorFieldData:
validData = fd.ValidData
case *Int8VectorFieldData:
validData = fd.ValidData
}
// Use len(validData) as logical row count, GetRow takes logical index
for j := 0; j < len(validData); j++ {
if !validData[j] {
fBuilder.(*array.BinaryBuilder).AppendNull()
} else {
rowData := fieldData.GetRow(j)
ok = typeEntry.serialize(fBuilder, rowData, elementType)
if !ok {
return merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", field.DataType.String()))
}
}
}
} else {
for j := 0; j < fieldData.RowNum(); j++ {
ok = typeEntry.serialize(fBuilder, fieldData.GetRow(j), elementType)
if !ok {
return merr.WrapErrServiceInternal(fmt.Sprintf("serialize error on type %s", field.DataType.String()))
}
}
}
return nil

Some files were not shown because too many files have changed in this diff Show More