diff --git a/internal/core/src/common/Chunk.h b/internal/core/src/common/Chunk.h index 387edf59bd..9d3512f24c 100644 --- a/internal/core/src/common/Chunk.h +++ b/internal/core/src/common/Chunk.h @@ -438,11 +438,11 @@ class VectorArrayChunk : public Chunk { offsets_lens_ = reinterpret_cast(data); auto offset = 0; - lims_.reserve(row_nums_ + 1); - lims_.push_back(offset); + offsets_.reserve(row_nums_ + 1); + offsets_.push_back(offset); for (int64_t i = 0; i < row_nums_; i++) { offset += offsets_lens_[i * 2 + 1]; - lims_.push_back(offset); + offsets_.push_back(offset); } } @@ -507,17 +507,15 @@ class VectorArrayChunk : public Chunk { } const size_t* - Lims() const { - return lims_.data(); + Offsets() const { + return offsets_.data(); } private: int64_t dim_; uint32_t* offsets_lens_; milvus::DataType element_type_; - // The name 'Lims' is consistent with knowhere::DataSet::SetLims which describes the number of vectors - // in each vector array (embedding list). This is needed as vectors are flattened in the chunk. - std::vector lims_; + std::vector offsets_; }; class SparseFloatVectorChunk : public Chunk { diff --git a/internal/core/src/common/ChunkWriter.cpp b/internal/core/src/common/ChunkWriter.cpp index dcad7c1888..0808279f42 100644 --- a/internal/core/src/common/ChunkWriter.cpp +++ b/internal/core/src/common/ChunkWriter.cpp @@ -339,34 +339,9 @@ VectorArrayChunkWriter::write(const arrow::ArrayVector& array_vec) { target_ = std::make_shared(total_size); } - switch (element_type_) { - case milvus::DataType::VECTOR_FLOAT: - writeFloatVectorArray(array_vec); - break; - case milvus::DataType::VECTOR_BINARY: - ThrowInfo(NotImplemented, - "BinaryVector in VectorArray not implemented yet"); - case milvus::DataType::VECTOR_FLOAT16: - ThrowInfo(NotImplemented, - "Float16Vector in VectorArray not implemented yet"); - case milvus::DataType::VECTOR_BFLOAT16: - ThrowInfo(NotImplemented, - "BFloat16Vector in VectorArray not implemented yet"); - case milvus::DataType::VECTOR_INT8: - ThrowInfo(NotImplemented, - "Int8Vector in VectorArray not implemented yet"); - default: - ThrowInfo(NotImplemented, - "Unsupported element type in VectorArray: {}", - static_cast(element_type_)); - } -} - -void -VectorArrayChunkWriter::writeFloatVectorArray( - const arrow::ArrayVector& array_vec) { + // Seirialization, the format is: [offsets_lens][all_vector_data_concatenated] std::vector offsets_lens; - std::vector float_data_ptrs; + std::vector vector_data_ptrs; std::vector data_sizes; uint32_t current_offset = @@ -375,25 +350,27 @@ VectorArrayChunkWriter::writeFloatVectorArray( for (const auto& array_data : array_vec) { auto list_array = std::static_pointer_cast(array_data); - auto float_values = - std::static_pointer_cast(list_array->values()); - const float* raw_floats = float_values->raw_values(); + auto binary_values = + std::static_pointer_cast( + list_array->values()); const int32_t* list_offsets = list_array->raw_value_offsets(); + int byte_width = binary_values->byte_width(); // Generate offsets and lengths for each row - // Each list contains multiple float vectors which are flattened, so the float count - // in each list is vector count * dim. 
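For reference, the VectorArrayChunk constructor above rebuilds cumulative vector counts from the serialized per-row header that the writer below produces (one [byte_offset, vector_count] uint32 pair per row). A minimal standalone sketch of that decoding step; decode_row_offsets() is a hypothetical name used only for illustration, not the chunk's API:

#include <cstdint>
#include <vector>

// Rebuild cumulative vector counts ("offsets") from the serialized header,
// which stores one (byte_offset, vector_count) uint32 pair per row.
inline std::vector<size_t>
decode_row_offsets(const uint32_t* offsets_lens, int64_t row_nums) {
    std::vector<size_t> offsets;
    offsets.reserve(row_nums + 1);
    size_t acc = 0;
    offsets.push_back(acc);
    for (int64_t i = 0; i < row_nums; ++i) {
        acc += offsets_lens[i * 2 + 1];  // vector count of row i
        offsets.push_back(acc);
    }
    return offsets;  // offsets[i + 1] - offsets[i] == vectors in row i
}

For example, three rows holding 5, 2 and 4 vectors yield offsets {0, 5, 7, 11}; the last element is the total number of flattened vectors, which is what the knowhere embedding-list code paths later in this diff rely on.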
+ // Each list contains multiple vectors, each stored as a fixed-size binary chunk for (int64_t i = 0; i < list_array->length(); i++) { auto start_idx = list_offsets[i]; auto end_idx = list_offsets[i + 1]; - auto vector_count = (end_idx - start_idx) / dim_; - auto byte_size = (end_idx - start_idx) * sizeof(float); + auto vector_count = end_idx - start_idx; + auto byte_size = vector_count * byte_width; offsets_lens.push_back(current_offset); offsets_lens.push_back(static_cast(vector_count)); - float_data_ptrs.push_back(raw_floats + start_idx); - data_sizes.push_back(byte_size); + for (int j = start_idx; j < end_idx; j++) { + vector_data_ptrs.push_back(binary_values->GetValue(j)); + data_sizes.push_back(byte_width); + } current_offset += byte_size; } @@ -409,8 +386,8 @@ VectorArrayChunkWriter::writeFloatVectorArray( } target_->write(&offsets_lens.back(), sizeof(uint32_t)); // final offset - for (size_t i = 0; i < float_data_ptrs.size(); i++) { - target_->write(float_data_ptrs[i], data_sizes[i]); + for (size_t i = 0; i < vector_data_ptrs.size(); i++) { + target_->write(vector_data_ptrs[i], data_sizes[i]); } } @@ -427,19 +404,18 @@ VectorArrayChunkWriter::calculateTotalSize( std::static_pointer_cast(array_data); switch (element_type_) { - case milvus::DataType::VECTOR_FLOAT: { - auto float_values = std::static_pointer_cast( - list_array->values()); - total_size += float_values->length() * sizeof(float); - break; - } + case milvus::DataType::VECTOR_FLOAT: case milvus::DataType::VECTOR_BINARY: case milvus::DataType::VECTOR_FLOAT16: case milvus::DataType::VECTOR_BFLOAT16: - case milvus::DataType::VECTOR_INT8: - ThrowInfo(NotImplemented, - "Element type {} in VectorArray not implemented yet", - static_cast(element_type_)); + case milvus::DataType::VECTOR_INT8: { + auto binary_values = + std::static_pointer_cast( + list_array->values()); + total_size += + binary_values->length() * binary_values->byte_width(); + break; + } default: ThrowInfo(DataTypeInvalid, "Invalid element type {} for VectorArray", diff --git a/internal/core/src/common/ChunkWriter.h b/internal/core/src/common/ChunkWriter.h index 767fd60548..18ef8e0241 100644 --- a/internal/core/src/common/ChunkWriter.h +++ b/internal/core/src/common/ChunkWriter.h @@ -273,9 +273,6 @@ class VectorArrayChunkWriter : public ChunkWriterBase { finish() override; private: - void - writeFloatVectorArray(const arrow::ArrayVector& array_vec); - size_t calculateTotalSize(const arrow::ArrayVector& array_vec); diff --git a/internal/core/src/common/FieldData.cpp b/internal/core/src/common/FieldData.cpp index a85f219c7b..4ed29bb422 100644 --- a/internal/core/src/common/FieldData.cpp +++ b/internal/core/src/common/FieldData.cpp @@ -25,6 +25,7 @@ #include "common/Exception.h" #include "common/FieldDataInterface.h" #include "common/Json.h" +#include "index/Utils.h" #include "simdjson/padded_string.h" namespace milvus { @@ -348,46 +349,47 @@ FieldDataImpl::FillFieldData( std::vector values(element_count); switch (element_type) { - case DataType::VECTOR_FLOAT: { - auto float_array = - std::dynamic_pointer_cast( + case DataType::VECTOR_FLOAT: + case DataType::VECTOR_BINARY: + case DataType::VECTOR_FLOAT16: + case DataType::VECTOR_BFLOAT16: + case DataType::VECTOR_INT8: { + // All vector types use FixedSizeBinaryArray and have the same serialization logic + auto binary_array = + std::dynamic_pointer_cast( values_array); - AssertInfo( - float_array != nullptr, - "Expected FloatArray for VECTOR_FLOAT element type"); + AssertInfo(binary_array != nullptr, + "Expected 
FixedSizeBinaryArray for VectorArray " + "element"); + + // Calculate bytes per vector using the unified function + auto bytes_per_vec = + milvus::vector_bytes_per_element(element_type, dim); for (size_t index = 0; index < element_count; ++index) { int64_t start_offset = list_array->value_offset(index); int64_t end_offset = list_array->value_offset(index + 1); - int64_t num_floats = end_offset - start_offset; - AssertInfo(num_floats % dim == 0, - "Invalid data: number of floats ({}) not " - "divisible by " - "dimension ({})", - num_floats, - dim); + int64_t num_vectors = end_offset - start_offset; - int num_vectors = num_floats / dim; - const float* data_ptr = - float_array->raw_values() + start_offset; - values[index] = - VectorArray(static_cast(data_ptr), - num_vectors, - dim, - element_type); + auto data_size = num_vectors * bytes_per_vec; + auto data_ptr = std::make_unique(data_size); + + for (int64_t i = 0; i < num_vectors; i++) { + const uint8_t* binary_data = + binary_array->GetValue(start_offset + i); + uint8_t* dest = data_ptr.get() + i * bytes_per_vec; + std::memcpy(dest, binary_data, bytes_per_vec); + } + + values[index] = VectorArray( + static_cast(data_ptr.get()), + num_vectors, + dim, + element_type); } break; } - case DataType::VECTOR_BINARY: - case DataType::VECTOR_FLOAT16: - case DataType::VECTOR_BFLOAT16: - case DataType::VECTOR_INT8: - ThrowInfo( - NotImplemented, - "Element type {} in VectorArray not implemented yet", - GetDataTypeName(element_type)); - break; default: ThrowInfo(DataTypeInvalid, "Unsupported element type {} in VectorArray", diff --git a/internal/core/src/common/Schema.cpp b/internal/core/src/common/Schema.cpp index 5e0810e173..c221efb629 100644 --- a/internal/core/src/common/Schema.cpp +++ b/internal/core/src/common/Schema.cpp @@ -94,8 +94,8 @@ Schema::ConvertToArrowSchema() const { std::shared_ptr arrow_data_type = nullptr; auto data_type = meta.get_data_type(); if (data_type == DataType::VECTOR_ARRAY) { - arrow_data_type = - GetArrowDataTypeForVectorArray(meta.get_element_type()); + arrow_data_type = GetArrowDataTypeForVectorArray( + meta.get_element_type(), meta.get_dim()); } else { arrow_data_type = GetArrowDataType(data_type, dim); } diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index 434575a20f..7a32e6875c 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -205,10 +205,23 @@ GetArrowDataType(DataType data_type, int dim = 1) { } inline std::shared_ptr -GetArrowDataTypeForVectorArray(DataType elem_type) { +GetArrowDataTypeForVectorArray(DataType elem_type, int dim) { + if (dim <= 0) { + ThrowInfo(DataTypeInvalid, "dim must be provided for VectorArray"); + } + // VectorArray stores vectors as FixedSizeBinaryArray + // We must have dim to create the correct fixed_size_binary type switch (elem_type) { case DataType::VECTOR_FLOAT: - return arrow::list(arrow::float32()); + return arrow::list(arrow::fixed_size_binary(dim * sizeof(float))); + case DataType::VECTOR_BINARY: + return arrow::list(arrow::fixed_size_binary((dim + 7) / 8)); + case DataType::VECTOR_FLOAT16: + return arrow::list(arrow::fixed_size_binary(dim * 2)); + case DataType::VECTOR_BFLOAT16: + return arrow::list(arrow::fixed_size_binary(dim * 2)); + case DataType::VECTOR_INT8: + return arrow::list(arrow::fixed_size_binary(dim)); default: { ThrowInfo(DataTypeInvalid, fmt::format("failed to get arrow type for vector array, " @@ -534,7 +547,10 @@ IsFloatVectorMetricType(const MetricType& metric_type) { metric_type == 
knowhere::metric::IP || metric_type == knowhere::metric::COSINE || metric_type == knowhere::metric::BM25 || - metric_type == knowhere::metric::MAX_SIM; + metric_type == knowhere::metric::MAX_SIM || + metric_type == knowhere::metric::MAX_SIM_COSINE || + metric_type == knowhere::metric::MAX_SIM_IP || + metric_type == knowhere::metric::MAX_SIM_L2; } inline bool @@ -543,14 +559,20 @@ IsBinaryVectorMetricType(const MetricType& metric_type) { metric_type == knowhere::metric::JACCARD || metric_type == knowhere::metric::SUPERSTRUCTURE || metric_type == knowhere::metric::SUBSTRUCTURE || - metric_type == knowhere::metric::MHJACCARD; + metric_type == knowhere::metric::MHJACCARD || + metric_type == knowhere::metric::MAX_SIM_HAMMING || + metric_type == knowhere::metric::MAX_SIM_JACCARD; } inline bool IsIntVectorMetricType(const MetricType& metric_type) { return metric_type == knowhere::metric::L2 || metric_type == knowhere::metric::IP || - metric_type == knowhere::metric::COSINE; + metric_type == knowhere::metric::COSINE || + metric_type == knowhere::metric::MAX_SIM || + metric_type == knowhere::metric::MAX_SIM_COSINE || + metric_type == knowhere::metric::MAX_SIM_IP || + metric_type == knowhere::metric::MAX_SIM_L2; } // Plus 1 because we can't use greater(>) symbol @@ -746,6 +768,30 @@ FromValCase(milvus::proto::plan::GenericValue::ValCase val_case) { return DataType::NONE; } } + +// Calculate bytes per vector element for different vector types +// Used by VectorArray and storage utilities +inline size_t +vector_bytes_per_element(const DataType data_type, int64_t dim) { + switch (data_type) { + case DataType::VECTOR_BINARY: + // Binary vector stores bits, so dim represents bit count + // Need (dim + 7) / 8 bytes to store dim bits + return (dim + 7) / 8; + case DataType::VECTOR_FLOAT: + return dim * sizeof(float); + case DataType::VECTOR_FLOAT16: + return dim * sizeof(float16); + case DataType::VECTOR_BFLOAT16: + return dim * sizeof(bfloat16); + case DataType::VECTOR_INT8: + return dim * sizeof(int8); + default: + ThrowInfo(UnexpectedError, + fmt::format("invalid data type: {}", data_type)); + } +} + } // namespace milvus template <> struct fmt::formatter : formatter { diff --git a/internal/core/src/common/VectorArray.h b/internal/core/src/common/VectorArray.h index 936ca7a417..aa7f57c64b 100644 --- a/internal/core/src/common/VectorArray.h +++ b/internal/core/src/common/VectorArray.h @@ -41,16 +41,8 @@ class VectorArray : public milvus::VectorTrait { assert(num_vectors > 0); assert(dim > 0); - switch (element_type) { - case DataType::VECTOR_FLOAT: - size_ = num_vectors * dim * sizeof(float); - break; - default: - ThrowInfo(NotImplemented, - "Direct VectorArray construction only supports " - "VECTOR_FLOAT, got {}", - GetDataTypeName(element_type)); - } + size_ = + num_vectors * milvus::vector_bytes_per_element(element_type, dim); data_ = std::make_unique(size_); std::memcpy(data_.get(), data, size_); @@ -73,8 +65,49 @@ class VectorArray : public milvus::VectorTrait { data_ = std::unique_ptr(reinterpret_cast(data)); break; } + case VectorFieldProto::kBinaryVector: { + element_type_ = DataType::VECTOR_BINARY; + int bytes_per_vector = (dim_ + 7) / 8; + length_ = + vector_field.binary_vector().size() / bytes_per_vector; + size_ = vector_field.binary_vector().size(); + data_ = std::make_unique(size_); + std::memcpy( + data_.get(), vector_field.binary_vector().data(), size_); + break; + } + case VectorFieldProto::kFloat16Vector: { + element_type_ = DataType::VECTOR_FLOAT16; + int bytes_per_element = 2; // 2 
bytes per float16 + length_ = vector_field.float16_vector().size() / + (dim_ * bytes_per_element); + size_ = vector_field.float16_vector().size(); + data_ = std::make_unique(size_); + std::memcpy( + data_.get(), vector_field.float16_vector().data(), size_); + break; + } + case VectorFieldProto::kBfloat16Vector: { + element_type_ = DataType::VECTOR_BFLOAT16; + int bytes_per_element = 2; // 2 bytes per bfloat16 + length_ = vector_field.bfloat16_vector().size() / + (dim_ * bytes_per_element); + size_ = vector_field.bfloat16_vector().size(); + data_ = std::make_unique(size_); + std::memcpy( + data_.get(), vector_field.bfloat16_vector().data(), size_); + break; + } + case VectorFieldProto::kInt8Vector: { + element_type_ = DataType::VECTOR_INT8; + length_ = vector_field.int8_vector().size() / dim_; + size_ = vector_field.int8_vector().size(); + data_ = std::make_unique(size_); + std::memcpy( + data_.get(), vector_field.int8_vector().data(), size_); + break; + } default: { - // TODO(SpadeA): add other vector types ThrowInfo(NotImplemented, "Not implemented vector type: {}", static_cast(vector_field.data_case())); @@ -160,8 +193,24 @@ class VectorArray : public milvus::VectorTrait { return reinterpret_cast(data_.get()) + index * dim_; } + case DataType::VECTOR_BINARY: { + // Binary vectors are packed bits + int bytes_per_vector = (dim_ + 7) / 8; + return reinterpret_cast( + data_.get() + index * bytes_per_vector); + } + case DataType::VECTOR_FLOAT16: + case DataType::VECTOR_BFLOAT16: { + // Float16/BFloat16 are 2 bytes per element + return reinterpret_cast(data_.get() + + index * dim_ * 2); + } + case DataType::VECTOR_INT8: { + // Int8 is 1 byte per element + return reinterpret_cast(data_.get() + + index * dim_); + } default: { - // TODO(SpadeA): add other vector types ThrowInfo(NotImplemented, "Not implemented vector type: {}", static_cast(element_type_)); @@ -180,8 +229,23 @@ class VectorArray : public milvus::VectorTrait { data, data + length_ * dim_); break; } + case DataType::VECTOR_BINARY: { + vector_field.set_binary_vector(data_.get(), size_); + break; + } + case DataType::VECTOR_FLOAT16: { + vector_field.set_float16_vector(data_.get(), size_); + break; + } + case DataType::VECTOR_BFLOAT16: { + vector_field.set_bfloat16_vector(data_.get(), size_); + break; + } + case DataType::VECTOR_INT8: { + vector_field.set_int8_vector(data_.get(), size_); + break; + } default: { - // TODO(SpadeA): add other vector types ThrowInfo(NotImplemented, "Not implemented vector type: {}", static_cast(element_type_)); @@ -300,8 +364,23 @@ class VectorArrayView { "VectorElement must be float for VECTOR_FLOAT"); return reinterpret_cast(data_) + index * dim_; } + case DataType::VECTOR_BINARY: { + // Binary vectors are packed bits + int bytes_per_vector = (dim_ + 7) / 8; + return reinterpret_cast( + data_ + index * bytes_per_vector); + } + case DataType::VECTOR_FLOAT16: + case DataType::VECTOR_BFLOAT16: { + // Float16/BFloat16 are 2 bytes per element + return reinterpret_cast(data_ + + index * dim_ * 2); + } + case DataType::VECTOR_INT8: { + // Int8 is 1 byte per element + return reinterpret_cast(data_ + index * dim_); + } default: { - // TODO(SpadeA): add other vector types. 
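The per-type branches above (and the matching ones in VectorArrayView below) all reduce to the same arithmetic: the index-th vector starts at index * bytes_per_vector. A minimal sketch of that equivalence, assuming the vector_bytes_per_element() helper added to common/Types.h in this patch; vector_start() itself and the include path are illustrative, not part of the change:

#include <cstdint>
#include "common/Types.h"  // vector_bytes_per_element (added in this patch)

// Byte address of the index-th vector in a flattened VectorArray buffer.
inline const uint8_t*
vector_start(const uint8_t* base,
             milvus::DataType element_type,
             int64_t dim,
             int64_t index) {
    // float: dim * 4, fp16/bf16: dim * 2, int8: dim, binary: (dim + 7) / 8
    return base + index * milvus::vector_bytes_per_element(element_type, dim);
}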
ThrowInfo(NotImplemented, "Not implemented vector type: {}", static_cast(element_type_)); @@ -320,8 +399,23 @@ class VectorArrayView { data, data + length_ * dim_); break; } + case DataType::VECTOR_BINARY: { + vector_array.set_binary_vector(data_, size_); + break; + } + case DataType::VECTOR_FLOAT16: { + vector_array.set_float16_vector(data_, size_); + break; + } + case DataType::VECTOR_BFLOAT16: { + vector_array.set_bfloat16_vector(data_, size_); + break; + } + case DataType::VECTOR_INT8: { + vector_array.set_int8_vector(data_, size_); + break; + } default: { - // TODO(SpadeA): add other vector types ThrowInfo(NotImplemented, "Not implemented vector type: {}", static_cast(element_type_)); diff --git a/internal/core/src/common/VectorArrayChunkTest.cpp b/internal/core/src/common/VectorArrayChunkTest.cpp index dde4d22c99..dd07782e04 100644 --- a/internal/core/src/common/VectorArrayChunkTest.cpp +++ b/internal/core/src/common/VectorArrayChunkTest.cpp @@ -40,12 +40,57 @@ class VectorArrayChunkTest : public ::testing::Test { return result; } + std::vector + generateFloat16Vector(int64_t seed, int64_t N, int64_t dim) { + std::vector result(dim * N); + std::default_random_engine gen(seed); + std::uniform_int_distribution dist(0, 65535); + + for (int64_t i = 0; i < dim * N; ++i) { + result[i] = dist(gen); + } + return result; + } + + std::vector + generateBFloat16Vector(int64_t seed, int64_t N, int64_t dim) { + // Same as Float16 for testing purposes + return generateFloat16Vector(seed, N, dim); + } + + std::vector + generateInt8Vector(int64_t seed, int64_t N, int64_t dim) { + std::vector result(dim * N); + std::default_random_engine gen(seed); + std::uniform_int_distribution dist(-128, 127); + + for (int64_t i = 0; i < dim * N; ++i) { + result[i] = static_cast(dist(gen)); + } + return result; + } + + std::vector + generateBinaryVector(int64_t seed, int64_t N, int64_t dim) { + std::vector result((dim * N + 7) / 8); + std::default_random_engine gen(seed); + std::uniform_int_distribution dist(0, 255); + + for (size_t i = 0; i < result.size(); ++i) { + result[i] = static_cast(dist(gen)); + } + return result; + } + std::shared_ptr createFloatVectorListArray(const std::vector& data, - const std::vector& offsets) { - auto float_builder = std::make_shared(); + const std::vector& offsets, + int64_t dim) { + int byte_width = dim * sizeof(float); + auto value_builder = std::make_shared( + arrow::fixed_size_binary(byte_width)); auto list_builder = std::make_shared( - arrow::default_memory_pool(), float_builder); + arrow::default_memory_pool(), value_builder); arrow::Status ast; for (size_t i = 0; i < offsets.size() - 1; ++i) { @@ -54,8 +99,112 @@ class VectorArrayChunkTest : public ::testing::Test { int32_t start = offsets[i]; int32_t end = offsets[i + 1]; - for (int32_t j = start; j < end; ++j) { - float_builder->Append(data[j]); + // Each vector is dim floats + for (int32_t j = start; j < end; j += dim) { + // Convert float vector to binary + const uint8_t* binary_data = + reinterpret_cast(&data[j]); + ast = value_builder->Append(binary_data); + assert(ast.ok()); + } + } + + std::shared_ptr array; + ast = list_builder->Finish(&array); + assert(ast.ok()); + return std::static_pointer_cast(array); + } + + std::shared_ptr + createFloat16VectorListArray(const std::vector& data, + const std::vector& offsets, + int64_t dim) { + int byte_width = dim * 2; // 2 bytes per float16 + auto value_builder = std::make_shared( + arrow::fixed_size_binary(byte_width)); + auto list_builder = std::make_shared( + 
arrow::default_memory_pool(), value_builder); + + arrow::Status ast; + for (size_t i = 0; i < offsets.size() - 1; ++i) { + ast = list_builder->Append(); + assert(ast.ok()); + int32_t start = offsets[i]; + int32_t end = offsets[i + 1]; + + for (int32_t j = start; j < end; j += dim) { + const uint8_t* binary_data = + reinterpret_cast(&data[j]); + ast = value_builder->Append(binary_data); + assert(ast.ok()); + } + } + + std::shared_ptr array; + ast = list_builder->Finish(&array); + assert(ast.ok()); + return std::static_pointer_cast(array); + } + + std::shared_ptr + createBFloat16VectorListArray(const std::vector& data, + const std::vector& offsets, + int64_t dim) { + // Same as Float16 but for bfloat16 + return createFloat16VectorListArray(data, offsets, dim); + } + + std::shared_ptr + createInt8VectorListArray(const std::vector& data, + const std::vector& offsets, + int64_t dim) { + int byte_width = dim; // 1 byte per int8 + auto value_builder = std::make_shared( + arrow::fixed_size_binary(byte_width)); + auto list_builder = std::make_shared( + arrow::default_memory_pool(), value_builder); + + arrow::Status ast; + for (size_t i = 0; i < offsets.size() - 1; ++i) { + ast = list_builder->Append(); + assert(ast.ok()); + int32_t start = offsets[i]; + int32_t end = offsets[i + 1]; + + for (int32_t j = start; j < end; j += dim) { + const uint8_t* binary_data = + reinterpret_cast(&data[j]); + ast = value_builder->Append(binary_data); + assert(ast.ok()); + } + } + + std::shared_ptr array; + ast = list_builder->Finish(&array); + assert(ast.ok()); + return std::static_pointer_cast(array); + } + + std::shared_ptr + createBinaryVectorListArray(const std::vector& data, + const std::vector& offsets, + int64_t dim) { + int byte_width = (dim + 7) / 8; // bits packed into bytes + auto value_builder = std::make_shared( + arrow::fixed_size_binary(byte_width)); + auto list_builder = std::make_shared( + arrow::default_memory_pool(), value_builder); + + arrow::Status ast; + for (size_t i = 0; i < offsets.size() - 1; ++i) { + ast = list_builder->Append(); + assert(ast.ok()); + int32_t start = offsets[i]; + int32_t end = offsets[i + 1]; + + for (int32_t j = start; j < end; j += byte_width) { + ast = value_builder->Append(&data[j]); + assert(ast.ok()); } } @@ -66,52 +215,6 @@ class VectorArrayChunkTest : public ::testing::Test { } }; -TEST_F(VectorArrayChunkTest, TestWriteFloatVectorArray) { - // Test parameters - const int64_t dim = 128; - const int num_rows = 100; - const int vectors_per_row = 5; // Each row contains 5 vectors - - // Generate test data - std::vector all_data; - std::vector offsets = {0}; - - for (int row = 0; row < num_rows; ++row) { - auto row_data = generateFloatVector(row, vectors_per_row, dim); - all_data.insert(all_data.end(), row_data.begin(), row_data.end()); - offsets.push_back(offsets.back() + vectors_per_row * dim); - } - - // Create Arrow ListArray - auto list_array = createFloatVectorListArray(all_data, offsets); - arrow::ArrayVector array_vec = {list_array}; - - // Test VectorArrayChunkWriter - VectorArrayChunkWriter writer(dim, DataType::VECTOR_FLOAT); - writer.write(array_vec); - - auto chunk = writer.finish(); - auto vector_array_chunk = static_cast(chunk.get()); - - // Verify results - EXPECT_EQ(vector_array_chunk->RowNums(), num_rows); - - // Verify data integrity using View method - for (int row = 0; row < num_rows; ++row) { - auto view = vector_array_chunk->View(row); - - // Verify by converting back to VectorFieldProto and checking - auto proto = view.output_data(); - 
EXPECT_EQ(proto.dim(), dim); - EXPECT_EQ(proto.float_vector().data_size(), vectors_per_row * dim); - - const float* expected = all_data.data() + row * vectors_per_row * dim; - for (int i = 0; i < vectors_per_row * dim; ++i) { - EXPECT_FLOAT_EQ(proto.float_vector().data(i), expected[i]); - } - } -} - TEST_F(VectorArrayChunkTest, TestWriteMultipleBatches) { const int64_t dim = 64; const int batch_size = 50; @@ -137,7 +240,7 @@ TEST_F(VectorArrayChunkTest, TestWriteMultipleBatches) { all_batch_data.push_back(batch_data); array_vec.push_back( - createFloatVectorListArray(batch_data, batch_offsets)); + createFloatVectorListArray(batch_data, batch_offsets, dim)); } // Write using VectorArrayChunkWriter @@ -191,7 +294,7 @@ TEST_F(VectorArrayChunkTest, TestWriteWithMmap) { offsets.push_back(offsets.back() + vectors_per_row * dim); } - auto list_array = createFloatVectorListArray(all_data, offsets); + auto list_array = createFloatVectorListArray(all_data, offsets, dim); arrow::ArrayVector array_vec = {list_array}; // Write with mmap @@ -235,3 +338,189 @@ TEST_F(VectorArrayChunkTest, TestEmptyVectorArray) { EXPECT_EQ(vector_array_chunk->RowNums(), 0); } + +struct VectorArrayTestParam { + DataType data_type; + int64_t dim; + int num_rows; + int vectors_per_row; + std::string test_name; +}; + +class VectorArrayChunkParameterizedTest + : public VectorArrayChunkTest, + public ::testing::WithParamInterface {}; + +template +std::shared_ptr +createVectorListArray(const std::vector& data, + const std::vector& offsets, + int64_t dim, + DataType dtype) { + int byte_width; + switch (dtype) { + case DataType::VECTOR_FLOAT: + byte_width = dim * sizeof(float); + break; + case DataType::VECTOR_FLOAT16: + case DataType::VECTOR_BFLOAT16: + byte_width = dim * 2; + break; + case DataType::VECTOR_INT8: + byte_width = dim; + break; + case DataType::VECTOR_BINARY: + byte_width = (dim + 7) / 8; + break; + default: + throw std::invalid_argument("Unsupported data type"); + } + + auto value_builder = std::make_shared( + arrow::fixed_size_binary(byte_width)); + auto list_builder = std::make_shared( + arrow::default_memory_pool(), value_builder); + + arrow::Status ast; + int element_size = (dtype == DataType::VECTOR_BINARY) ? 
byte_width : dim; + + for (size_t i = 0; i < offsets.size() - 1; ++i) { + ast = list_builder->Append(); + assert(ast.ok()); + int32_t start = offsets[i]; + int32_t end = offsets[i + 1]; + + for (int32_t j = start; j < end; j += element_size) { + const uint8_t* binary_data = + reinterpret_cast(&data[j]); + ast = value_builder->Append(binary_data); + assert(ast.ok()); + } + } + + std::shared_ptr array; + ast = list_builder->Finish(&array); + assert(ast.ok()); + return std::static_pointer_cast(array); +} + +TEST_P(VectorArrayChunkParameterizedTest, TestWriteVectorArray) { + auto param = GetParam(); + + // Generate test data based on type + std::vector all_data; + std::vector offsets = {0}; + + for (int row = 0; row < param.num_rows; ++row) { + std::vector row_data; + + switch (param.data_type) { + case DataType::VECTOR_FLOAT: { + auto float_data = + generateFloatVector(row, param.vectors_per_row, param.dim); + row_data.resize(float_data.size() * sizeof(float)); + memcpy(row_data.data(), float_data.data(), row_data.size()); + break; + } + case DataType::VECTOR_FLOAT16: + case DataType::VECTOR_BFLOAT16: { + auto uint16_data = generateFloat16Vector( + row, param.vectors_per_row, param.dim); + row_data.resize(uint16_data.size() * sizeof(uint16_t)); + memcpy(row_data.data(), uint16_data.data(), row_data.size()); + break; + } + case DataType::VECTOR_INT8: { + auto int8_data = + generateInt8Vector(row, param.vectors_per_row, param.dim); + row_data.resize(int8_data.size()); + memcpy(row_data.data(), int8_data.data(), row_data.size()); + break; + } + case DataType::VECTOR_BINARY: { + row_data = + generateBinaryVector(row, param.vectors_per_row, param.dim); + break; + } + default: + FAIL() << "Unsupported data type"; + } + + all_data.insert(all_data.end(), row_data.begin(), row_data.end()); + + // Calculate offset based on data type + int offset_increment; + if (param.data_type == DataType::VECTOR_BINARY) { + offset_increment = param.vectors_per_row * ((param.dim + 7) / 8); + } else if (param.data_type == DataType::VECTOR_FLOAT) { + offset_increment = + param.vectors_per_row * param.dim * sizeof(float); + } else if (param.data_type == DataType::VECTOR_FLOAT16 || + param.data_type == DataType::VECTOR_BFLOAT16) { + offset_increment = param.vectors_per_row * param.dim * 2; + } else { + offset_increment = param.vectors_per_row * param.dim; + } + offsets.push_back(offsets.back() + offset_increment); + } + + // Create Arrow ListArray + auto list_array = + createVectorListArray(all_data, offsets, param.dim, param.data_type); + arrow::ArrayVector array_vec = {list_array}; + + // Test VectorArrayChunkWriter + VectorArrayChunkWriter writer(param.dim, param.data_type); + writer.write(array_vec); + + auto chunk = writer.finish(); + auto vector_array_chunk = static_cast(chunk.get()); + + // Verify results + EXPECT_EQ(vector_array_chunk->RowNums(), param.num_rows); + + // Basic verification - ensure View doesn't crash and returns valid data + for (int row = 0; row < param.num_rows; ++row) { + auto view = vector_array_chunk->View(row); + auto proto = view.output_data(); + EXPECT_EQ(proto.dim(), param.dim); + + // Verify the correct field is populated based on data type + switch (param.data_type) { + case DataType::VECTOR_FLOAT: + EXPECT_GT(proto.float_vector().data_size(), 0); + break; + case DataType::VECTOR_FLOAT16: + EXPECT_GT(proto.float16_vector().size(), 0); + break; + case DataType::VECTOR_BFLOAT16: + EXPECT_GT(proto.bfloat16_vector().size(), 0); + break; + case DataType::VECTOR_INT8: + 
EXPECT_GT(proto.int8_vector().size(), 0); + break; + case DataType::VECTOR_BINARY: + EXPECT_GT(proto.binary_vector().size(), 0); + break; + default: + break; + } + } +} + +INSTANTIATE_TEST_SUITE_P( + VectorTypes, + VectorArrayChunkParameterizedTest, + ::testing::Values( + VectorArrayTestParam{ + DataType::VECTOR_FLOAT, 128, 100, 5, "FloatVector"}, + VectorArrayTestParam{ + DataType::VECTOR_FLOAT16, 128, 100, 3, "Float16Vector"}, + VectorArrayTestParam{ + DataType::VECTOR_BFLOAT16, 64, 50, 2, "BFloat16Vector"}, + VectorArrayTestParam{DataType::VECTOR_INT8, 256, 80, 4, "Int8Vector"}, + VectorArrayTestParam{ + DataType::VECTOR_BINARY, 512, 60, 3, "BinaryVector"}), + [](const testing::TestParamInfo& info) { + return info.param.test_name; + }); diff --git a/internal/core/src/common/VectorArrayStorageV2Test.cpp b/internal/core/src/common/VectorArrayStorageV2Test.cpp index 46a1485af2..c035ccfb69 100644 --- a/internal/core/src/common/VectorArrayStorageV2Test.cpp +++ b/internal/core/src/common/VectorArrayStorageV2Test.cpp @@ -159,8 +159,12 @@ class TestVectorArrayStorageV2 : public testing::Test { // Create appropriate value builder based on element type std::shared_ptr value_builder; + int byte_width = 0; if (element_type == DataType::VECTOR_FLOAT) { - value_builder = std::make_shared(); + byte_width = DIM * sizeof(float); + value_builder = + std::make_shared( + arrow::fixed_size_binary(byte_width)); } else { FAIL() << "Unsupported element type for VECTOR_ARRAY " "in test"; @@ -176,11 +180,13 @@ class TestVectorArrayStorageV2 : public testing::Test { // Generate 3 vectors for this row auto data = generate_float_vector(3, DIM); - auto float_builder = - std::static_pointer_cast( - value_builder); - for (const auto& value : data) { - status = float_builder->Append(value); + auto binary_builder = std::static_pointer_cast< + arrow::FixedSizeBinaryBuilder>(value_builder); + // Append each vector as a fixed-size binary value + for (int vec_idx = 0; vec_idx < 3; vec_idx++) { + status = binary_builder->Append( + reinterpret_cast( + data.data() + vec_idx * DIM)); EXPECT_TRUE(status.ok()); } } @@ -288,7 +294,7 @@ TEST_F(TestVectorArrayStorageV2, BuildEmbListHNSWIndex) { milvus::index::CreateIndexInfo create_index_info; create_index_info.field_type = DataType::VECTOR_ARRAY; create_index_info.metric_type = knowhere::metric::MAX_SIM; - create_index_info.index_type = knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; + create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW; create_index_info.index_engine_version = knowhere::Version::GetCurrentVersion().VersionNumber(); @@ -299,8 +305,7 @@ TEST_F(TestVectorArrayStorageV2, BuildEmbListHNSWIndex) { // Build index with storage v2 configuration Config config; - config[milvus::index::INDEX_TYPE] = - knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; + config[milvus::index::INDEX_TYPE] = knowhere::IndexEnum::INDEX_HNSW; config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type; config[knowhere::indexparam::M] = "16"; config[knowhere::indexparam::EF] = "10"; @@ -330,11 +335,12 @@ TEST_F(TestVectorArrayStorageV2, BuildEmbListHNSWIndex) { std::vector query_vec = generate_float_vector(vec_num, DIM); auto query_dataset = knowhere::GenDataSet(vec_num, DIM, query_vec.data()); - std::vector query_vec_lims; - query_vec_lims.push_back(0); - query_vec_lims.push_back(3); - query_vec_lims.push_back(10); - query_dataset->SetLims(query_vec_lims.data()); + std::vector query_vec_offsets; + query_vec_offsets.push_back(0); + query_vec_offsets.push_back(3); + 
query_vec_offsets.push_back(10); + query_dataset->Set(knowhere::meta::EMB_LIST_OFFSET, + const_cast(query_vec_offsets.data())); auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}}; milvus::SearchInfo searchInfo; diff --git a/internal/core/src/common/VectorTrait.h b/internal/core/src/common/VectorTrait.h index 11e9c8b968..4e34532e89 100644 --- a/internal/core/src/common/VectorTrait.h +++ b/internal/core/src/common/VectorTrait.h @@ -28,23 +28,39 @@ namespace milvus { -#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT \ - using elem_type = std::conditional_t< \ - std::is_same_v, \ - milvus::EmbListFloatVector::embedded_type, \ - std::conditional_t< \ - std::is_same_v, \ - milvus::FloatVector::embedded_type, \ - std::conditional_t< \ - std::is_same_v, \ - milvus::Float16Vector::embedded_type, \ - std::conditional_t< \ - std::is_same_v, \ - milvus::BFloat16Vector::embedded_type, \ - std::conditional_t< \ - std::is_same_v, \ - milvus::Int8Vector::embedded_type, \ - milvus::BinaryVector::embedded_type>>>>>; +#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT \ + using elem_type = std::conditional_t< \ + std::is_same_v, \ + milvus::EmbListFloatVector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::EmbListBinaryVector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::EmbListFloat16Vector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::EmbListBFloat16Vector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::EmbListInt8Vector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::FloatVector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::Float16Vector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::BFloat16Vector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::Int8Vector::embedded_type, \ + milvus::BinaryVector:: \ + embedded_type>>>>>>>>>; #define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \ auto schema_data_type = \ @@ -164,6 +180,82 @@ class EmbListFloatVector : public VectorTrait { } }; +class EmbListBinaryVector : public VectorTrait { + public: + using embedded_type = uint8_t; + static constexpr int32_t dim_factor = 8; + static constexpr auto data_type = DataType::VECTOR_ARRAY; + static constexpr auto c_data_type = CDataType::VectorArray; + static constexpr auto schema_data_type = + proto::schema::DataType::ArrayOfVector; + static constexpr auto vector_type = + proto::plan::VectorType::EmbListBinaryVector; + static constexpr auto placeholder_type = + proto::common::PlaceholderType::EmbListBinaryVector; + + static constexpr bool + is_embedding_list() { + return true; + } +}; + +class EmbListFloat16Vector : public VectorTrait { + public: + using embedded_type = float16; + static constexpr int32_t dim_factor = 1; + static constexpr auto data_type = DataType::VECTOR_ARRAY; + static constexpr auto c_data_type = CDataType::VectorArray; + static constexpr auto schema_data_type = + proto::schema::DataType::ArrayOfVector; + static constexpr auto vector_type = + proto::plan::VectorType::EmbListFloat16Vector; + static constexpr auto placeholder_type = + proto::common::PlaceholderType::EmbListFloat16Vector; + + static constexpr bool + is_embedding_list() { + return true; + } +}; + +class EmbListBFloat16Vector : public VectorTrait { + public: + using embedded_type = bfloat16; + static constexpr int32_t dim_factor = 1; + static constexpr auto data_type = DataType::VECTOR_ARRAY; + static constexpr auto c_data_type = CDataType::VectorArray; + 
static constexpr auto schema_data_type = + proto::schema::DataType::ArrayOfVector; + static constexpr auto vector_type = + proto::plan::VectorType::EmbListBFloat16Vector; + static constexpr auto placeholder_type = + proto::common::PlaceholderType::EmbListBFloat16Vector; + + static constexpr bool + is_embedding_list() { + return true; + } +}; + +class EmbListInt8Vector : public VectorTrait { + public: + using embedded_type = int8; + static constexpr int32_t dim_factor = 1; + static constexpr auto data_type = DataType::VECTOR_ARRAY; + static constexpr auto c_data_type = CDataType::VectorArray; + static constexpr auto schema_data_type = + proto::schema::DataType::ArrayOfVector; + static constexpr auto vector_type = + proto::plan::VectorType::EmbListInt8Vector; + static constexpr auto placeholder_type = + proto::common::PlaceholderType::EmbListInt8Vector; + + static constexpr bool + is_embedding_list() { + return true; + } +}; + struct FundamentalTag {}; struct StringTag {}; diff --git a/internal/core/src/exec/operator/VectorSearchNode.cpp b/internal/core/src/exec/operator/VectorSearchNode.cpp index 3f6199980f..355cae39c0 100644 --- a/internal/core/src/exec/operator/VectorSearchNode.cpp +++ b/internal/core/src/exec/operator/VectorSearchNode.cpp @@ -76,7 +76,7 @@ PhyVectorSearchNode::GetOutput() { auto& ph = placeholder_group_->at(0); auto src_data = ph.get_blob(); - auto src_lims = ph.get_lims(); + auto src_offsets = ph.get_offsets(); auto num_queries = ph.num_of_queries_; milvus::SearchResult search_result; @@ -94,7 +94,7 @@ PhyVectorSearchNode::GetOutput() { auto op_context = query_context_->get_op_context(); segment_->vector_search(search_info_, src_data, - src_lims, + src_offsets, num_queries, query_timestamp_, final_view, diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index 725d1ba090..173fa94a24 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -234,9 +234,42 @@ IndexFactory::VecIndexLoadResource( num_rows, dim, config); - has_raw_data = - knowhere::IndexStaticFaced::HasRawData( - index_type, index_version, config); + break; + case milvus::DataType::VECTOR_FLOAT16: + resource = knowhere::IndexStaticFaced:: + EstimateLoadResource(index_type, + index_version, + index_size_in_bytes, + num_rows, + dim, + config); + break; + case milvus::DataType::VECTOR_BFLOAT16: + resource = knowhere::IndexStaticFaced:: + EstimateLoadResource(index_type, + index_version, + index_size_in_bytes, + num_rows, + dim, + config); + break; + case milvus::DataType::VECTOR_BINARY: + resource = knowhere::IndexStaticFaced:: + EstimateLoadResource(index_type, + index_version, + index_size_in_bytes, + num_rows, + dim, + config); + break; + case milvus::DataType::VECTOR_INT8: + resource = knowhere::IndexStaticFaced:: + EstimateLoadResource(index_type, + index_version, + index_size_in_bytes, + num_rows, + dim, + config); break; default: @@ -247,6 +280,9 @@ IndexFactory::VecIndexLoadResource( element_type); return LoadResourceRequest{0, 0, 0, 0, true}; } + // For VectorArray, has_raw_data is always false as get_vector of index does not provide offsets which + // is required for reconstructing the raw data + has_raw_data = false; break; } default: @@ -641,6 +677,42 @@ IndexFactory::CreateVectorIndex( version, use_knowhere_build_pool, file_manager_context); + case DataType::VECTOR_FLOAT16: { + return std::make_unique>( + element_type, + index_type, + metric_type, + version, + use_knowhere_build_pool, + 
file_manager_context); + } + case DataType::VECTOR_BFLOAT16: { + return std::make_unique>( + element_type, + index_type, + metric_type, + version, + use_knowhere_build_pool, + file_manager_context); + } + case DataType::VECTOR_BINARY: { + return std::make_unique>( + element_type, + index_type, + metric_type, + version, + use_knowhere_build_pool, + file_manager_context); + } + case DataType::VECTOR_INT8: { + return std::make_unique>( + element_type, + index_type, + metric_type, + version, + use_knowhere_build_pool, + file_manager_context); + } default: ThrowInfo(NotImplemented, fmt::format("not implemented data type to " diff --git a/internal/core/src/index/Utils.h b/internal/core/src/index/Utils.h index f6d95455f2..cef0ab0f66 100644 --- a/internal/core/src/index/Utils.h +++ b/internal/core/src/index/Utils.h @@ -234,21 +234,4 @@ void inline SetBitsetGrowing(void* bitset, } } -inline size_t -vector_element_size(const DataType data_type) { - switch (data_type) { - case DataType::VECTOR_FLOAT: - return sizeof(float); - case DataType::VECTOR_FLOAT16: - return sizeof(float16); - case DataType::VECTOR_BFLOAT16: - return sizeof(bfloat16); - case DataType::VECTOR_INT8: - return sizeof(int8); - default: - ThrowInfo(UnexpectedError, - fmt::format("invalid data type: {}", data_type)); - } -} - } // namespace milvus::index diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index a2f399f871..842c9b69a1 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -313,11 +313,6 @@ VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset, SetDim(index_.Dim()); } -bool -is_embedding_list_index(const IndexType& index_type) { - return index_type == knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; -} - template void VectorMemIndex::Build(const Config& config) { @@ -344,30 +339,19 @@ VectorMemIndex::Build(const Config& config) { total_size += data->Size(); total_num_rows += data->get_num_rows(); - // todo(SapdeA): now, vector arrays (embedding list) are serialized - // to parquet by using binary format which does not provide dim - // information so we use this temporary solution. 
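The removed index::vector_element_size() below returned the size of one scalar element; its replacement, milvus::vector_bytes_per_element(), returns the size of one whole vector, which is what the embedding-list build path divides row sizes by. A small sanity-check sketch of the resulting widths, assuming dim = 128; check_vector_byte_widths() is a hypothetical helper, the values follow directly from the definition in Types.h above:

#include <cassert>
#include <cstdint>
#include "common/Types.h"  // vector_bytes_per_element (added in this patch)

inline void
check_vector_byte_widths() {
    constexpr int64_t dim = 128;
    using milvus::DataType;
    using milvus::vector_bytes_per_element;
    assert(vector_bytes_per_element(DataType::VECTOR_FLOAT, dim) == 512);
    assert(vector_bytes_per_element(DataType::VECTOR_FLOAT16, dim) == 256);
    assert(vector_bytes_per_element(DataType::VECTOR_BFLOAT16, dim) == 256);
    assert(vector_bytes_per_element(DataType::VECTOR_INT8, dim) == 128);
    // binary vectors pack dim bits into (dim + 7) / 8 bytes
    assert(vector_bytes_per_element(DataType::VECTOR_BINARY, dim) == 16);
}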
- if (is_embedding_list_index(index_type_)) { - AssertInfo(elem_type_ != DataType::NONE, - "embedding list index must have elem_type"); - dim = config[DIM_KEY].get(); - } else { - AssertInfo(dim == 0 || dim == data->get_dim(), - "inconsistent dim value between field datas!"); - - dim = data->get_dim(); - } + AssertInfo(dim == 0 || dim == data->get_dim(), + "inconsistent dim value between field datas!"); + dim = data->get_dim(); } auto buf = std::shared_ptr(new uint8_t[total_size]); size_t lim_offset = 0; - std::vector lims; - lims.reserve(total_num_rows + 1); - lims.push_back(lim_offset); + std::vector offsets; int64_t offset = 0; - if (!is_embedding_list_index(index_type_)) { + // For embedding list index, elem_type_ is not NONE + if (elem_type_ == DataType::NONE) { // TODO: avoid copying for (auto data : field_datas) { std::memcpy(buf.get() + offset, data->Data(), data->Size()); @@ -375,7 +359,9 @@ VectorMemIndex::Build(const Config& config) { data.reset(); } } else { - auto elem_size = vector_element_size(elem_type_); + offsets.reserve(total_num_rows + 1); + offsets.push_back(lim_offset); + auto bytes_per_vec = vector_bytes_per_element(elem_type_, dim); for (auto data : field_datas) { auto vec_array_data = dynamic_cast*>(data.get()); @@ -385,16 +371,16 @@ VectorMemIndex::Build(const Config& config) { auto rows = vec_array_data->get_num_rows(); for (auto i = 0; i < rows; ++i) { auto size = vec_array_data->DataSize(i); - assert(size % (dim * elem_size) == 0); - assert(dim * elem_size != 0); + assert(size % bytes_per_vec == 0); + assert(bytes_per_vec != 0); auto vec_array = vec_array_data->value_at(i); std::memcpy(buf.get() + offset, vec_array->data(), size); offset += size; - lim_offset += size / (dim * elem_size); - lims.push_back(lim_offset); + lim_offset += size / bytes_per_vec; + offsets.push_back(lim_offset); } assert(data->Size() == offset); @@ -411,8 +397,9 @@ VectorMemIndex::Build(const Config& config) { if (!scalar_info.empty()) { dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info)); } - if (!lims.empty()) { - dataset->SetLims(lims.data()); + if (!offsets.empty()) { + dataset->Set(knowhere::meta::EMB_LIST_OFFSET, + const_cast(offsets.data())); } BuildWithDataset(dataset, build_config); } else { diff --git a/internal/core/src/mmap/ChunkedColumn.h b/internal/core/src/mmap/ChunkedColumn.h index 7c2a0ea266..b4fee58d5c 100644 --- a/internal/core/src/mmap/ChunkedColumn.h +++ b/internal/core/src/mmap/ChunkedColumn.h @@ -282,12 +282,12 @@ class ChunkedColumnBase : public ChunkedColumnInterface { "VectorArrayViews only supported for ChunkedVectorArrayColumn"); } - virtual PinWrapper - VectorArrayLims(milvus::OpContext* op_ctx, - int64_t chunk_id) const override { + PinWrapper + VectorArrayOffsets(milvus::OpContext* op_ctx, + int64_t chunk_id) const override { ThrowInfo( ErrorCode::Unsupported, - "VectorArrayLims only supported for ChunkedVectorArrayColumn"); + "VectorArrayOffsets only supported for ChunkedVectorArrayColumn"); } PinWrapper, FixedVector>> @@ -691,13 +691,13 @@ class ChunkedVectorArrayColumn : public ChunkedColumnBase { } PinWrapper - VectorArrayLims(milvus::OpContext* op_ctx, - int64_t chunk_id) const override { + VectorArrayOffsets(milvus::OpContext* op_ctx, + int64_t chunk_id) const override { auto ca = SemiInlineGet( slot_->PinCells(op_ctx, {static_cast(chunk_id)})); auto chunk = ca->get_cell_of(chunk_id); return PinWrapper( - ca, static_cast(chunk)->Lims()); + ca, static_cast(chunk)->Offsets()); } }; diff --git a/internal/core/src/mmap/ChunkedColumnGroup.h 
b/internal/core/src/mmap/ChunkedColumnGroup.h index 9707bab22c..a9d882228e 100644 --- a/internal/core/src/mmap/ChunkedColumnGroup.h +++ b/internal/core/src/mmap/ChunkedColumnGroup.h @@ -328,17 +328,18 @@ class ProxyChunkColumn : public ChunkedColumnInterface { } PinWrapper - VectorArrayLims(milvus::OpContext* op_ctx, - int64_t chunk_id) const override { + VectorArrayOffsets(milvus::OpContext* op_ctx, + int64_t chunk_id) const override { if (!IsChunkedVectorArrayColumnDataType(data_type_)) { ThrowInfo(ErrorCode::Unsupported, - "VectorArrayLims only supported for " + "VectorArrayOffsets only supported for " "ChunkedVectorArrayColumn"); } auto chunk_wrapper = group_->GetGroupChunk(op_ctx, chunk_id); auto chunk = chunk_wrapper.get()->GetChunk(field_id_); return PinWrapper( - chunk_wrapper, static_cast(chunk.get())->Lims()); + chunk_wrapper, + static_cast(chunk.get())->Offsets()); } PinWrapper, FixedVector>> diff --git a/internal/core/src/mmap/ChunkedColumnInterface.h b/internal/core/src/mmap/ChunkedColumnInterface.h index 481fb5c9f4..631ad50df8 100644 --- a/internal/core/src/mmap/ChunkedColumnInterface.h +++ b/internal/core/src/mmap/ChunkedColumnInterface.h @@ -92,7 +92,7 @@ class ChunkedColumnInterface { std::optional> offset_len) const = 0; virtual PinWrapper - VectorArrayLims(milvus::OpContext* op_ctx, int64_t chunk_id) const = 0; + VectorArrayOffsets(milvus::OpContext* op_ctx, int64_t chunk_id) const = 0; virtual PinWrapper< std::pair, FixedVector>> diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index 66670888ca..00f300a061 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -34,8 +34,23 @@ bool check_data_type(const FieldMeta& field_meta, const milvus::proto::common::PlaceholderType type) { if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { - return type == - milvus::proto::common::PlaceholderType::EmbListFloatVector; + if (field_meta.get_element_type() == DataType::VECTOR_FLOAT) { + return type == + milvus::proto::common::PlaceholderType::EmbListFloatVector; + } else if (field_meta.get_element_type() == DataType::VECTOR_FLOAT16) { + return type == + milvus::proto::common::PlaceholderType::EmbListFloat16Vector; + } else if (field_meta.get_element_type() == DataType::VECTOR_BFLOAT16) { + return type == milvus::proto::common::PlaceholderType:: + EmbListBFloat16Vector; + } else if (field_meta.get_element_type() == DataType::VECTOR_BINARY) { + return type == + milvus::proto::common::PlaceholderType::EmbListBinaryVector; + } else if (field_meta.get_element_type() == DataType::VECTOR_INT8) { + return type == + milvus::proto::common::PlaceholderType::EmbListInt8Vector; + } + return false; } return static_cast(field_meta.get_data_type()) == static_cast(type); @@ -96,25 +111,25 @@ ParsePlaceholderGroup(const Plan* plan, // If the vector is embedding list, line contains multiple vectors. // And we should record the offsets so that we can identify each // embedding list in a flattened vectors. 
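Worked example for the loop below: two placeholder lines carrying 3 and 7 vectors produce offsets {0, 3, 10}, matching the query shapes used by the tests in this patch. A minimal sketch of that bookkeeping, taking each line's byte length and the per-vector byte width as inputs; compute_query_offsets() is illustrative only:

#include <cstddef>
#include <vector>

// Cumulative vector offsets per embedding list, given the byte length of each
// placeholder line and the bytes occupied by a single vector.
inline std::vector<size_t>
compute_query_offsets(const std::vector<size_t>& line_sizes,
                      size_t bytes_per_vec) {
    std::vector<size_t> offsets;
    offsets.reserve(line_sizes.size() + 1);
    size_t offset = 0;
    offsets.push_back(offset);
    for (auto line_size : line_sizes) {
        // mirrors the AssertInfo below: a line must be a whole number of vectors
        offset += line_size / bytes_per_vec;
        offsets.push_back(offset);
    }
    return offsets;  // e.g. lines of 3 and 7 vectors -> {0, 3, 10}
}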
- auto& lims = element.lims_; - lims.reserve(element.num_of_queries_ + 1); + auto& offsets = element.offsets_; + offsets.reserve(element.num_of_queries_ + 1); size_t offset = 0; - lims.push_back(offset); + offsets.push_back(offset); - auto elem_size = milvus::index::vector_element_size( - field_meta.get_element_type()); + auto bytes_per_vec = milvus::vector_bytes_per_element( + field_meta.get_element_type(), dim); for (auto& line : info.values()) { target.insert(target.end(), line.begin(), line.end()); AssertInfo( - line.size() % (dim * elem_size) == 0, - "line.size() % (dim * elem_size) == 0 assert failed, " - "line.size() = {}, dim = {}, elem_size = {}", + line.size() % bytes_per_vec == 0, + "line.size() % bytes_per_vec == 0 assert failed, " + "line.size() = {}, dim = {}, bytes_per_vec = {}", line.size(), dim, - elem_size); + bytes_per_vec); - offset += line.size() / (dim * elem_size); - lims.push_back(offset); + offset += line.size() / bytes_per_vec; + offsets.push_back(offset); } } } diff --git a/internal/core/src/query/PlanImpl.h b/internal/core/src/query/PlanImpl.h index b9b0989e36..1ade71c335 100644 --- a/internal/core/src/query/PlanImpl.h +++ b/internal/core/src/query/PlanImpl.h @@ -70,8 +70,8 @@ struct Plan { struct Placeholder { std::string tag_; // note: for embedding list search, num_of_queries_ stands for the number of vectors. - // lims_ records the offsets of embedding list in the flattened vector and - // hence lims_.size() - 1 is the number of queries in embedding list search. + // offsets_ records the offsets of embedding list in the flattened vector and + // hence offsets_.size() - 1 is the number of queries in embedding list search. int64_t num_of_queries_; // TODO(SPARSE): add a dim_ field here, use the dim passed in search request // instead of the dim in schema, since the dim of sparse float column is @@ -84,7 +84,7 @@ struct Placeholder { std::unique_ptr[]> sparse_matrix_; // offsets for embedding list - aligned_vector lims_; + aligned_vector offsets_; const void* get_blob() const { @@ -103,13 +103,13 @@ struct Placeholder { } const size_t* - get_lims() const { - return lims_.data(); + get_offsets() const { + return offsets_.data(); } size_t* - get_lims() { - return lims_.data(); + get_offsets() { + return offsets_.data(); } }; diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 099801457b..ae1a332d40 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -89,21 +89,25 @@ PrepareBFDataSet(const dataset::SearchDataset& query_ds, DataType data_type) { auto base_dataset = knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data); - if (raw_ds.raw_data_lims != nullptr) { + if (raw_ds.raw_data_offsets != nullptr) { // knowhere::DataSet count vectors in a flattened manner where as the num_raw_data here is the number // of embedding lists where each embedding list contains multiple vectors. So we should use the last element - // in lims which equals to the total number of vectors. - base_dataset->SetLims(raw_ds.raw_data_lims); - // the length of lims equals to the number of embedding lists + 1 - base_dataset->SetRows(raw_ds.raw_data_lims[raw_ds.num_raw_data]); + // in offsets which equals to the total number of vectors. 
+ base_dataset->Set(knowhere::meta::EMB_LIST_OFFSET, + raw_ds.raw_data_offsets); + + // the length of offsets equals to the number of embedding lists + 1 + base_dataset->SetRows(raw_ds.raw_data_offsets[raw_ds.num_raw_data]); } auto query_dataset = knowhere::GenDataSet( query_ds.num_queries, query_ds.dim, query_ds.query_data); - if (query_ds.query_lims != nullptr) { + if (query_ds.query_offsets != nullptr) { // ditto - query_dataset->SetLims(query_ds.query_lims); - query_dataset->SetRows(query_ds.query_lims[query_ds.num_queries]); + query_dataset->Set(knowhere::meta::EMB_LIST_OFFSET, + query_ds.query_offsets); + + query_dataset->SetRows(query_ds.query_offsets[query_ds.num_queries]); } if (data_type == DataType::VECTOR_SPARSE_U32_F32) { diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index 8499e6aa65..f3db538b81 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -73,7 +73,7 @@ void SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, const SearchInfo& info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t num_queries, Timestamp timestamp, const BitsetView& bitset, @@ -141,7 +141,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, round_decimal, dim, query_data, - query_lims}; + query_offsets}; int32_t current_chunk_id = 0; // get K1 and B from index for bm25 brute force @@ -222,8 +222,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, if (data_type == DataType::VECTOR_ARRAY) { AssertInfo( - query_lims != nullptr, - "query_lims is nullptr, but data_type is vector array"); + query_offsets != nullptr, + "query_offsets is nullptr, but data_type is vector array"); } if (milvus::exec::UseVectorIterator(info)) { diff --git a/internal/core/src/query/SearchOnGrowing.h b/internal/core/src/query/SearchOnGrowing.h index 99fc3cdad3..69d06c7615 100644 --- a/internal/core/src/query/SearchOnGrowing.h +++ b/internal/core/src/query/SearchOnGrowing.h @@ -21,7 +21,7 @@ void SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, const SearchInfo& info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t num_queries, Timestamp timestamp, const BitsetView& bitset, diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index 32df192b30..effe7ecbb8 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -31,7 +31,7 @@ SearchOnSealedIndex(const Schema& schema, const segcore::SealedIndexingRecord& record, const SearchInfo& search_info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t num_queries, const BitsetView& bitset, milvus::OpContext* op_context, @@ -55,15 +55,15 @@ SearchOnSealedIndex(const Schema& schema, search_info.metric_type_); knowhere::DataSetPtr dataset; - if (query_lims == nullptr) { + if (query_offsets == nullptr) { dataset = knowhere::GenDataSet(num_queries, dim, query_data); } else { // Rather than non-embedding list search where num_queries equals to the number of vectors, - // in embedding list search, multiple vectors form an embedding list and the last element of query_lims + // in embedding list search, multiple vectors form an embedding list and the last element of query_offsets // stands for the total number of vectors. 
- auto num_vectors = query_lims[num_queries]; + auto num_vectors = query_offsets[num_queries]; dataset = knowhere::GenDataSet(num_vectors, dim, query_data); - dataset->SetLims(query_lims); + dataset->Set(knowhere::meta::EMB_LIST_OFFSET, query_offsets); } dataset->SetIsSparse(is_sparse); @@ -107,7 +107,7 @@ SearchOnSealedColumn(const Schema& schema, const SearchInfo& search_info, const std::map& index_info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t num_queries, int64_t row_count, const BitsetView& bitview, @@ -128,7 +128,7 @@ SearchOnSealedColumn(const Schema& schema, search_info.round_decimal_, dim, query_data, - query_lims}; + query_offsets}; CheckBruteForceSearchParam(field, search_info); @@ -158,13 +158,14 @@ SearchOnSealedColumn(const Schema& schema, auto raw_dataset = query::dataset::RawDataset{offset, dim, chunk_size, vec_data}; - PinWrapper lims_pw; + PinWrapper offsets_pw; if (data_type == DataType::VECTOR_ARRAY) { - AssertInfo(query_lims != nullptr, - "query_lims is nullptr, but data_type is vector array"); + AssertInfo( + query_offsets != nullptr, + "query_offsets is nullptr, but data_type is vector array"); - lims_pw = column->VectorArrayLims(op_context, i); - raw_dataset.raw_data_lims = lims_pw.get(); + offsets_pw = column->VectorArrayOffsets(op_context, i); + raw_dataset.raw_data_offsets = offsets_pw.get(); } if (milvus::exec::UseVectorIterator(search_info)) { diff --git a/internal/core/src/query/SearchOnSealed.h b/internal/core/src/query/SearchOnSealed.h index 5c8ab0312b..08b531b3e7 100644 --- a/internal/core/src/query/SearchOnSealed.h +++ b/internal/core/src/query/SearchOnSealed.h @@ -23,7 +23,7 @@ SearchOnSealedIndex(const Schema& schema, const segcore::SealedIndexingRecord& record, const SearchInfo& search_info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t num_queries, const BitsetView& view, milvus::OpContext* op_context, @@ -35,7 +35,7 @@ SearchOnSealedColumn(const Schema& schema, const SearchInfo& search_info, const std::map& index_info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t num_queries, int64_t row_count, const BitsetView& bitset, diff --git a/internal/core/src/query/helper.h b/internal/core/src/query/helper.h index 034d4854fb..ef7b470b22 100644 --- a/internal/core/src/query/helper.h +++ b/internal/core/src/query/helper.h @@ -24,7 +24,7 @@ struct RawDataset { int64_t dim; int64_t num_raw_data; const void* raw_data; - const size_t* raw_data_lims = nullptr; + const size_t* raw_data_offsets = nullptr; }; struct SearchDataset { knowhere::MetricType metric_type; @@ -34,7 +34,7 @@ struct SearchDataset { int64_t dim; const void* query_data; // used for embedding list query - const size_t* query_lims = nullptr; + const size_t* query_offsets = nullptr; }; } // namespace dataset diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index 80510298da..8b6b8a0919 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -756,7 +756,7 @@ ChunkedSegmentSealedImpl::mask_with_delete(BitsetTypeView& bitset, void ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, @@ -783,7 +783,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& 
search_info, vector_indexings_, binlog_search_info, query_data, - query_lims, + query_offsets, query_count, bitset, op_context, @@ -798,7 +798,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, vector_indexings_, search_info, query_data, - query_lims, + query_offsets, query_count, bitset, op_context, @@ -826,7 +826,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, search_info, index_info, query_data, - query_lims, + query_offsets, query_count, row_count, bitset, diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h index ed73f18431..21bc60e4ff 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h @@ -459,7 +459,7 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { void vector_search(SearchInfo& search_info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index 972f157adc..6fabd7582b 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -696,7 +696,7 @@ SegmentGrowingImpl::search_batch_pks( void SegmentGrowingImpl::vector_search(SearchInfo& search_info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, @@ -705,7 +705,7 @@ SegmentGrowingImpl::vector_search(SearchInfo& search_info, query::SearchOnGrowing(*this, search_info, query_data, - query_lims, + query_offsets, query_count, timestamp, bitset, diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index e3b598af80..d3a018c671 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -359,7 +359,7 @@ class SegmentGrowingImpl : public SegmentGrowing { void vector_search(SearchInfo& search_info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, diff --git a/internal/core/src/segcore/SegmentGrowingTest.cpp b/internal/core/src/segcore/SegmentGrowingTest.cpp index 27517adb3f..903a1bb5e9 100644 --- a/internal/core/src/segcore/SegmentGrowingTest.cpp +++ b/internal/core/src/segcore/SegmentGrowingTest.cpp @@ -581,7 +581,7 @@ TEST(GrowingTest, SearchVectorArray) { config.set_enable_interim_segment_index(true); std::map index_params = { - {"index_type", knowhere::IndexEnum::INDEX_EMB_LIST_HNSW}, + {"index_type", knowhere::IndexEnum::INDEX_HNSW}, {"metric_type", metric_type}, {"nlist", "128"}}; std::map type_params = { @@ -612,11 +612,11 @@ TEST(GrowingTest, SearchVectorArray) { int vec_num = 10; // Total number of query vectors std::vector query_vec = generate_float_vector(vec_num, dim); - // Create query dataset with lims for VectorArray - std::vector query_vec_lims; - query_vec_lims.push_back(0); // First query has 3 vectors - query_vec_lims.push_back(3); - query_vec_lims.push_back(10); // Second query has 7 vectors + // Create query dataset with offsets for VectorArray + std::vector query_vec_offsets; + query_vec_offsets.push_back(0); // First query has 3 vectors + query_vec_offsets.push_back(3); + query_vec_offsets.push_back(10); // Second query has 7 vectors // Create search plan 
const char* raw_plan = R"(vector_anns: < @@ -636,7 +636,7 @@ TEST(GrowingTest, SearchVectorArray) { // Use CreatePlaceholderGroupFromBlob for VectorArray auto ph_group_raw = CreatePlaceholderGroupFromBlob( - vec_num, dim, query_vec.data(), query_vec_lims); + vec_num, dim, query_vec.data(), query_vec_offsets); auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index 81cb525375..aee52ef354 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -373,14 +373,14 @@ class SegmentInternalInterface : public SegmentInterface { const std::string& nested_path) const override; public: - // `query_lims` is not null only for vector array (embedding list) search + // `query_offsets` is not null only for vector array (embedding list) search // where it denotes the number of vectors in each embedding list. The length - // of `query_lims` is the number of queries in the search plus one (the first - // element in query_lims is 0). + // of `query_offsets` is the number of queries in the search plus one (the first + // element in query_offsets is 0). virtual void vector_search(SearchInfo& search_info, const void* query_data, - const size_t* query_lims, + const size_t* query_offsets, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, diff --git a/internal/core/src/segcore/Types.h b/internal/core/src/segcore/Types.h index 85a808726d..a4e36650bc 100644 --- a/internal/core/src/segcore/Types.h +++ b/internal/core/src/segcore/Types.h @@ -35,7 +35,7 @@ struct LoadIndexInfo { int64_t segment_id; int64_t field_id; DataType field_type; - // The element type of the field. It's DataType::NONE if field_type is array/vector_array. + // The element type of the field. It's not DataType::NONE if field_type is array/vector_array. 
DataType element_type; bool enable_mmap; std::string mmap_dir_path; diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index 90625a144e..f4095f15cc 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -611,22 +611,40 @@ CreateVectorDataArrayFrom(const void* data_raw, case DataType::VECTOR_ARRAY: { auto data = reinterpret_cast(data_raw); auto vector_type = field_meta.get_element_type(); + auto obj = vector_array->mutable_vector_array(); + obj->set_dim(dim); + + // Set element type based on vector type switch (vector_type) { - case DataType::VECTOR_FLOAT: { - auto obj = vector_array->mutable_vector_array(); + case DataType::VECTOR_FLOAT: obj->set_element_type( milvus::proto::schema::DataType::FloatVector); - obj->set_dim(dim); - for (auto i = 0; i < count; i++) { - *(obj->mutable_data()->Add()) = data[i]; - } break; - } - default: { + case DataType::VECTOR_FLOAT16: + obj->set_element_type( + milvus::proto::schema::DataType::Float16Vector); + break; + case DataType::VECTOR_BFLOAT16: + obj->set_element_type( + milvus::proto::schema::DataType::BFloat16Vector); + break; + case DataType::VECTOR_BINARY: + obj->set_element_type( + milvus::proto::schema::DataType::BinaryVector); + break; + case DataType::VECTOR_INT8: + obj->set_element_type( + milvus::proto::schema::DataType::Int8Vector); + break; + default: ThrowInfo(NotImplemented, fmt::format("not implemented vector type {}", vector_type)); - } + } + + // Add all vector data + for (auto i = 0; i < count; i++) { + *(obj->mutable_data()->Add()) = data[i]; } break; } diff --git a/internal/core/src/storage/Util.cpp b/internal/core/src/storage/Util.cpp index 57795aced5..8027993381 100644 --- a/internal/core/src/storage/Util.cpp +++ b/internal/core/src/storage/Util.cpp @@ -241,66 +241,46 @@ AddPayloadToArrowBuilder(std::shared_ptr builder, if (length > 0) { auto element_type = vector_arrays[0].get_element_type(); + // Validate element type switch (element_type) { - case DataType::VECTOR_FLOAT: { - auto value_builder = static_cast( - list_builder->value_builder()); - AssertInfo(value_builder != nullptr, - "value_builder must be FloatBuilder for " - "FloatVector"); - - arrow::Status ast; - for (int i = 0; i < length; ++i) { - auto status = list_builder->Append(); - AssertInfo(status.ok(), - "Failed to append list: {}", - status.ToString()); - - const auto& array = vector_arrays[i]; - AssertInfo( - array.get_element_type() == - DataType::VECTOR_FLOAT, - "Inconsistent element types in VectorArray"); - - int num_vectors = array.length(); - int dim = array.dim(); - - for (int j = 0; j < num_vectors; ++j) { - auto vec_data = array.get_data(j); - ast = - value_builder->AppendValues(vec_data, dim); - AssertInfo(ast.ok(), - "Failed to append list: {}", - ast.ToString()); - } - } - break; - } + case DataType::VECTOR_FLOAT: case DataType::VECTOR_BINARY: - ThrowInfo( - NotImplemented, - "BinaryVector in VectorArray not implemented yet"); - break; case DataType::VECTOR_FLOAT16: - ThrowInfo( - NotImplemented, - "Float16Vector in VectorArray not implemented yet"); - break; case DataType::VECTOR_BFLOAT16: - ThrowInfo(NotImplemented, - "BFloat16Vector in VectorArray not " - "implemented yet"); - break; case DataType::VECTOR_INT8: - ThrowInfo( - NotImplemented, - "Int8Vector in VectorArray not implemented yet"); break; default: ThrowInfo(DataTypeInvalid, "Unsupported element type in VectorArray: {}", element_type); } + + // All supported vector types use FixedSizeBinaryBuilder + auto 
value_builder = + static_cast( + list_builder->value_builder()); + AssertInfo(value_builder != nullptr, + "value_builder must be FixedSizeBinaryBuilder for " + "VectorArray"); + + for (int i = 0; i < length; ++i) { + auto status = list_builder->Append(); + AssertInfo(status.ok(), + "Failed to append list: {}", + status.ToString()); + + const auto& array = vector_arrays[i]; + AssertInfo(array.get_element_type() == element_type, + "Inconsistent element types in VectorArray"); + + int num_vectors = array.length(); + auto ast = value_builder->AppendValues( + reinterpret_cast(array.data()), + num_vectors); + AssertInfo(ast.ok(), + "Failed to batch append vectors: {}", + ast.ToString()); + } } break; } @@ -428,7 +408,38 @@ CreateArrowBuilder(DataType data_type, DataType element_type, int dim) { std::shared_ptr value_builder; switch (element_type) { case DataType::VECTOR_FLOAT: { - value_builder = std::make_shared(); + int byte_width = dim * sizeof(float); + value_builder = + std::make_shared( + arrow::fixed_size_binary(byte_width)); + break; + } + case DataType::VECTOR_BINARY: { + int byte_width = (dim + 7) / 8; + value_builder = + std::make_shared( + arrow::fixed_size_binary(byte_width)); + break; + } + case DataType::VECTOR_FLOAT16: { + int byte_width = dim * 2; + value_builder = + std::make_shared( + arrow::fixed_size_binary(byte_width)); + break; + } + case DataType::VECTOR_BFLOAT16: { + int byte_width = dim * 2; + value_builder = + std::make_shared( + arrow::fixed_size_binary(byte_width)); + break; + } + case DataType::VECTOR_INT8: { + int byte_width = dim; + value_builder = + std::make_shared( + arrow::fixed_size_binary(byte_width)); break; } default: { @@ -610,7 +621,7 @@ CreateArrowSchema(DataType data_type, int dim, DataType element_type) { "This overload is only for VECTOR_ARRAY type"); AssertInfo(dim > 0, "invalid dim value"); - auto value_type = GetArrowDataTypeForVectorArray(element_type); + auto value_type = GetArrowDataTypeForVectorArray(element_type, dim); auto metadata = arrow::KeyValueMetadata::Make( {ELEMENT_TYPE_KEY_FOR_ARROW, DIM_KEY}, {std::to_string(static_cast(element_type)), std::to_string(dim)}); diff --git a/internal/core/thirdparty/knowhere/CMakeLists.txt b/internal/core/thirdparty/knowhere/CMakeLists.txt index bd92bca5cb..9d6ab2f4b6 100644 --- a/internal/core/thirdparty/knowhere/CMakeLists.txt +++ b/internal/core/thirdparty/knowhere/CMakeLists.txt @@ -14,7 +14,7 @@ # Update KNOWHERE_VERSION for the first occurrence milvus_add_pkg_config("knowhere") set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES "") -set( KNOWHERE_VERSION v2.6.3 ) +set( KNOWHERE_VERSION c4d5dd8 ) set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git") message(STATUS "Knowhere repo: ${GIT_REPOSITORY}") diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index 74d55cfb60..e54a3d780e 100644 --- a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -17,6 +17,7 @@ #include "common/Types.h" #include "index/IndexFactory.h" #include "knowhere/version.h" +#include "knowhere/comp/index_param.h" #include "storage/RemoteChunkManagerSingleton.h" #include "storage/Util.h" #include "common/VectorArray.h" @@ -2304,13 +2305,81 @@ TEST(Sealed, SearchSortedPk) { EXPECT_EQ(100, offsets2[0].get()); } -TEST(Sealed, QueryVectorArrayAllFields) { +using VectorArrayTestParam = + std::tuple; + +class SealedVectorArrayTest + : public ::testing::TestWithParam { + protected: + DataType element_type; + std::string 
metric_type; + int dim; + std::string test_name; + + void + SetUp() override { + auto param = GetParam(); + element_type = std::get<0>(param); + metric_type = std::get<1>(param); + dim = std::get<2>(param); + test_name = std::get<3>(param); + + // Ensure dim is valid for binary vectors + if (element_type == DataType::VECTOR_BINARY) { + ASSERT_EQ(dim % 8, 0) << "Binary vector dim must be multiple of 8"; + } + } + + void + VerifyVectorResults(const VectorFieldProto& result_vec, + const VectorFieldProto& expected_vec, + DataType element_type) { + switch (element_type) { + case DataType::VECTOR_FLOAT: { + auto result_data = result_vec.float_vector().data(); + auto expected_data = expected_vec.float_vector().data(); + EXPECT_EQ(result_data.size(), expected_data.size()); + for (int64_t i = 0; i < result_data.size(); ++i) { + EXPECT_NEAR(result_data[i], expected_data[i], 1e-6f); + } + break; + } + case DataType::VECTOR_BINARY: { + auto result_data = result_vec.binary_vector(); + auto expected_data = expected_vec.binary_vector(); + EXPECT_EQ(result_data, expected_data); + break; + } + case DataType::VECTOR_FLOAT16: { + auto result_data = result_vec.float16_vector(); + auto expected_data = expected_vec.float16_vector(); + EXPECT_EQ(result_data, expected_data); + break; + } + case DataType::VECTOR_BFLOAT16: { + auto result_data = result_vec.bfloat16_vector(); + auto expected_data = expected_vec.bfloat16_vector(); + EXPECT_EQ(result_data, expected_data); + break; + } + case DataType::VECTOR_INT8: { + auto result_data = result_vec.int8_vector(); + auto expected_data = expected_vec.int8_vector(); + EXPECT_EQ(result_data, expected_data); + break; + } + default: + break; + } + } +}; + +TEST_P(SealedVectorArrayTest, QueryVectorArrayAllFields) { auto schema = std::make_shared(); - auto metric_type = knowhere::metric::MAX_SIM; - int64_t dim = 4; + auto int64_field = schema->AddDebugField("int64", DataType::INT64); auto array_vec = schema->AddDebugVectorArrayField( - "array_vec", DataType::VECTOR_FLOAT, dim, metric_type); + "array_vec", element_type, dim, metric_type); schema->set_primary_field_id(int64_field); std::map filedMap{}; @@ -2329,49 +2398,36 @@ TEST(Sealed, QueryVectorArrayAllFields) { auto ids_ds = GenRandomIds(dataset_size); auto int64_result = segment->bulk_subscript( nullptr, int64_field, ids_ds->GetIds(), dataset_size); - auto array_float_vector_result = segment->bulk_subscript( + auto array_vector_result = segment->bulk_subscript( nullptr, array_vec, ids_ds->GetIds(), dataset_size); EXPECT_EQ(int64_result->scalars().long_data().data_size(), dataset_size); - EXPECT_EQ(array_float_vector_result->vectors().vector_array().data_size(), + EXPECT_EQ(array_vector_result->vectors().vector_array().data_size(), dataset_size); - auto verify_float_vectors = [](auto arr1, auto arr2) { - static constexpr float EPSILON = 1e-6; - EXPECT_EQ(arr1.size(), arr2.size()); - for (int64_t i = 0; i < arr1.size(); ++i) { - EXPECT_NEAR(arr1[i], arr2[i], EPSILON); - } - }; for (int64_t i = 0; i < dataset_size; ++i) { - auto arrow_array = array_float_vector_result->vectors() - .vector_array() - .data()[i] - .float_vector() - .data(); - auto expected_array = - array_vec_values[ids_ds->GetIds()[i]].float_vector().data(); - verify_float_vectors(arrow_array, expected_array); + auto result_vec = + array_vector_result->vectors().vector_array().data()[i]; + auto expected_vec = array_vec_values[ids_ds->GetIds()[i]]; + VerifyVectorResults(result_vec, expected_vec, element_type); } EXPECT_EQ(int64_result->valid_data_size(), 0); - 
EXPECT_EQ(array_float_vector_result->valid_data_size(), 0); + EXPECT_EQ(array_vector_result->valid_data_size(), 0); } -TEST(Sealed, SearchVectorArray) { +TEST_P(SealedVectorArrayTest, SearchVectorArray) { int64_t collection_id = 1; int64_t partition_id = 2; int64_t segment_id = 3; int64_t index_build_id = 4000; int64_t index_version = 4000; int64_t index_id = 5000; - int64_t dim = 4; auto schema = std::make_shared(); - auto metric_type = knowhere::metric::MAX_SIM; auto int64_field = schema->AddDebugField("int64", DataType::INT64); auto array_vec = schema->AddDebugVectorArrayField( - "array_vec", DataType::VECTOR_FLOAT, dim, metric_type); + "array_vec", element_type, dim, metric_type); schema->set_primary_field_id(int64_field); auto field_meta = milvus::segcore::gen_field_meta(collection_id, @@ -2379,7 +2435,7 @@ TEST(Sealed, SearchVectorArray) { segment_id, array_vec.get(), DataType::VECTOR_ARRAY, - DataType::VECTOR_FLOAT, + element_type, false); auto index_meta = gen_index_meta( segment_id, array_vec.get(), index_build_id, index_version); @@ -2403,7 +2459,7 @@ TEST(Sealed, SearchVectorArray) { vector_arrays.push_back(milvus::VectorArray(v)); } auto field_data = storage::CreateFieldData( - DataType::VECTOR_ARRAY, DataType::VECTOR_FLOAT, false, dim); + DataType::VECTOR_ARRAY, element_type, false, dim); field_data->FillFieldData(vector_arrays.data(), vector_arrays.size()); // create sealed segment @@ -2445,8 +2501,8 @@ TEST(Sealed, SearchVectorArray) { // create index milvus::index::CreateIndexInfo create_index_info; create_index_info.field_type = DataType::VECTOR_ARRAY; - create_index_info.metric_type = knowhere::metric::MAX_SIM; - create_index_info.index_type = knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; + create_index_info.metric_type = metric_type; + create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW; create_index_info.index_engine_version = knowhere::Version::GetCurrentVersion().VersionNumber(); @@ -2457,8 +2513,7 @@ TEST(Sealed, SearchVectorArray) { // build index Config config; - config[milvus::index::INDEX_TYPE] = - knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; + config[milvus::index::INDEX_TYPE] = knowhere::IndexEnum::INDEX_HNSW; config[INSERT_FILES_KEY] = std::vector{log_path}; config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type; config[knowhere::indexparam::M] = "16"; @@ -2473,18 +2528,37 @@ TEST(Sealed, SearchVectorArray) { // search auto vec_num = 10; - std::vector query_vec = generate_float_vector(vec_num, dim); - auto query_dataset = knowhere::GenDataSet(vec_num, dim, query_vec.data()); - std::vector query_vec_lims; - query_vec_lims.push_back(0); - query_vec_lims.push_back(3); - query_vec_lims.push_back(10); - query_dataset->SetLims(query_vec_lims.data()); + + // Generate query vectors based on element type + std::vector query_vec_bin; + std::vector query_vec_f32; + knowhere::DataSetPtr query_dataset; + if (element_type == DataType::VECTOR_BINARY) { + auto byte_dim = (dim + 7) / 8; + auto total_bytes = vec_num * byte_dim; + query_vec_bin.resize(total_bytes); + for (size_t i = 0; i < total_bytes; ++i) { + query_vec_bin[i] = rand() % 256; + } + query_dataset = + knowhere::GenDataSet(vec_num, dim, query_vec_bin.data()); + } else { + // For float-like types (FLOAT, FLOAT16, BFLOAT16, INT8) + query_vec_f32 = generate_float_vector(vec_num, dim); + query_dataset = + knowhere::GenDataSet(vec_num, dim, query_vec_f32.data()); + } + std::vector query_vec_offsets; + query_vec_offsets.push_back(0); + query_vec_offsets.push_back(3); + query_vec_offsets.push_back(10); + 
query_dataset->Set(knowhere::meta::EMB_LIST_OFFSET, + const_cast(query_vec_offsets.data())); auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}}; milvus::SearchInfo searchInfo; searchInfo.topk_ = 5; - searchInfo.metric_type_ = knowhere::metric::MAX_SIM; + searchInfo.metric_type_ = metric_type; searchInfo.search_params_ = search_conf; SearchResult result; vec_index->Query(query_dataset, searchInfo, nullptr, nullptr, result); @@ -2498,21 +2572,62 @@ TEST(Sealed, SearchVectorArray) { // brute force search { - const char* raw_plan = R"(vector_anns: < + std::string raw_plan = fmt::format(R"(vector_anns: < field_id: 101 query_info: < topk: 5 round_decimal: 3 - metric_type: "MAX_SIM" - search_params: "{\"nprobe\": 10}" + metric_type: "{}" + search_params: "{{\"nprobe\": 10}}" > placeholder_tag: "$0" - >)"; - auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + >)", + metric_type); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size()); - auto ph_group_raw = CreatePlaceholderGroupFromBlob( - vec_num, dim, query_vec.data(), query_vec_lims); + + // Create placeholder based on element type + milvus::proto::common::PlaceholderGroup ph_group_raw; + if (element_type == DataType::VECTOR_BINARY) { + auto byte_dim = (dim + 7) / 8; + auto total_bytes = vec_num * byte_dim; + std::vector query_vec(total_bytes); + for (size_t i = 0; i < total_bytes; ++i) { + query_vec[i] = rand() % 256; + } + ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } else if (element_type == DataType::VECTOR_FLOAT16) { + std::vector float_vec = generate_float_vector(vec_num, dim); + std::vector query_vec(vec_num * dim); + for (size_t i = 0; i < vec_num * dim; ++i) { + query_vec[i] = float16(float_vec[i]); + } + ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } else if (element_type == DataType::VECTOR_BFLOAT16) { + std::vector float_vec = generate_float_vector(vec_num, dim); + std::vector query_vec(vec_num * dim); + for (size_t i = 0; i < vec_num * dim; ++i) { + query_vec[i] = bfloat16(float_vec[i]); + } + ph_group_raw = + CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } else if (element_type == DataType::VECTOR_INT8) { + std::vector query_vec(vec_num * dim); + for (size_t i = 0; i < vec_num * dim; ++i) { + query_vec[i] = static_cast(rand() % 256 - 128); + } + ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } else { + std::vector query_vec = generate_float_vector(vec_num, dim); + ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } + auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); Timestamp timestamp = 1000000; @@ -2528,30 +2643,71 @@ TEST(Sealed, SearchVectorArray) { LoadIndexInfo load_info; load_info.field_id = array_vec.get(); load_info.field_type = DataType::VECTOR_ARRAY; - load_info.element_type = DataType::VECTOR_FLOAT; + load_info.element_type = element_type; load_info.index_params = GenIndexParams(emb_list_hnsw_index.get()); load_info.cache_index = CreateTestCacheIndex("test", std::move(emb_list_hnsw_index)); - load_info.index_params["metric_type"] = knowhere::metric::MAX_SIM; + load_info.index_params["metric_type"] = metric_type; sealed_segment->DropFieldData(array_vec); sealed_segment->LoadIndex(load_info); - const 
char* raw_plan = R"(vector_anns: < + std::string raw_plan = fmt::format(R"(vector_anns: < field_id: 101 query_info: < topk: 5 round_decimal: 3 - metric_type: "MAX_SIM" - search_params: "{\"nprobe\": 10}" + metric_type: "{}" + search_params: "{{\"nprobe\": 10}}" > placeholder_tag: "$0" - >)"; - auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + >)", + metric_type); + auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); auto plan = CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size()); - auto ph_group_raw = CreatePlaceholderGroupFromBlob( - vec_num, dim, query_vec.data(), query_vec_lims); + + // Create placeholder based on element type + milvus::proto::common::PlaceholderGroup ph_group_raw; + if (element_type == DataType::VECTOR_BINARY) { + auto byte_dim = (dim + 7) / 8; + auto total_bytes = vec_num * byte_dim; + std::vector query_vec(total_bytes); + for (size_t i = 0; i < total_bytes; ++i) { + query_vec[i] = rand() % 256; + } + ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } else if (element_type == DataType::VECTOR_FLOAT16) { + std::vector float_vec = generate_float_vector(vec_num, dim); + std::vector query_vec(vec_num * dim); + for (size_t i = 0; i < vec_num * dim; ++i) { + query_vec[i] = float16(float_vec[i]); + } + ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } else if (element_type == DataType::VECTOR_BFLOAT16) { + std::vector float_vec = generate_float_vector(vec_num, dim); + std::vector query_vec(vec_num * dim); + for (size_t i = 0; i < vec_num * dim; ++i) { + query_vec[i] = bfloat16(float_vec[i]); + } + ph_group_raw = + CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } else if (element_type == DataType::VECTOR_INT8) { + std::vector query_vec(vec_num * dim); + for (size_t i = 0; i < vec_num * dim; ++i) { + query_vec[i] = static_cast(rand() % 256 - 128); + } + ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } else { + std::vector query_vec = generate_float_vector(vec_num, dim); + ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_offsets); + } + auto ph_group = ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); Timestamp timestamp = 1000000; @@ -2562,3 +2718,27 @@ TEST(Sealed, SearchVectorArray) { std::cout << sr_parsed.dump(1) << std::endl; } } + +INSTANTIATE_TEST_SUITE_P( + VectorArrayTypes, + SealedVectorArrayTest, + ::testing::Values( + std::make_tuple(DataType::VECTOR_FLOAT, "MAX_SIM", 4, "float_max_sim"), + std::make_tuple(DataType::VECTOR_FLOAT, "MAX_SIM_L2", 4, "float_l2"), + std::make_tuple( + DataType::VECTOR_FLOAT16, "MAX_SIM", 4, "float16_max_sim"), + std::make_tuple( + DataType::VECTOR_FLOAT16, "MAX_SIM_L2", 4, "float16_l2"), + std::make_tuple( + DataType::VECTOR_BFLOAT16, "MAX_SIM", 4, "bfloat16_max_sim"), + std::make_tuple( + DataType::VECTOR_BFLOAT16, "MAX_SIM_L2", 4, "bfloat16_l2"), + std::make_tuple(DataType::VECTOR_INT8, "MAX_SIM", 4, "int8_max_sim"), + std::make_tuple(DataType::VECTOR_INT8, "MAX_SIM_L2", 4, "int8_l2"), + std::make_tuple( + DataType::VECTOR_BINARY, "MAX_SIM_HAMMING", 32, "binary_hamming"), + std::make_tuple( + DataType::VECTOR_BINARY, "MAX_SIM_JACCARD", 32, "binary_jaccard")), + [](const ::testing::TestParamInfo& info) { + return std::get<3>(info.param); + }); diff --git a/internal/parser/planparserv2/plan_parser_v2.go 
b/internal/parser/planparserv2/plan_parser_v2.go index d6d13eb968..cf897b908c 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -249,14 +249,22 @@ func CreateSearchPlanArgs(schema *typeutil.SchemaHelper, exprStr string, vectorF switch elementType { case schemapb.DataType_FloatVector: vectorType = planpb.VectorType_EmbListFloatVector + case schemapb.DataType_BinaryVector: + vectorType = planpb.VectorType_EmbListBinaryVector + case schemapb.DataType_Float16Vector: + vectorType = planpb.VectorType_EmbListFloat16Vector + case schemapb.DataType_BFloat16Vector: + vectorType = planpb.VectorType_EmbListBFloat16Vector + case schemapb.DataType_Int8Vector: + vectorType = planpb.VectorType_EmbListInt8Vector default: - log.Error("Invalid elementType", zap.Any("elementType", elementType)) - return nil, err + log.Error("Invalid elementType for ArrayOfVector", zap.Any("elementType", elementType)) + return nil, fmt.Errorf("unsupported element type for ArrayOfVector: %v", elementType) } default: log.Error("Invalid dataType", zap.Any("dataType", dataType)) - return nil, err + return nil, fmt.Errorf("unsupported vector data type: %v", dataType) } scorers, options, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues) diff --git a/internal/proxy/cgo_util_test.go b/internal/proxy/cgo_util_test.go index ae68f76430..d698c7aa38 100644 --- a/internal/proxy/cgo_util_test.go +++ b/internal/proxy/cgo_util_test.go @@ -29,7 +29,7 @@ func Test_CheckVecIndexWithDataTypeExist(t *testing.T) { want bool }{ {"HNSW", schemapb.DataType_FloatVector, true}, - {"HNSW", schemapb.DataType_BinaryVector, false}, + {"HNSW", schemapb.DataType_BinaryVector, true}, {"HNSW", schemapb.DataType_Float16Vector, true}, {"SPARSE_WAND", schemapb.DataType_SparseFloatVector, true}, diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 2232dcdbc7..badd2a2182 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -555,7 +555,7 @@ func constructTestCreateIndexRequest(dbName, collectionName string, dataType sch }, { Key: common.IndexTypeKey, - Value: "EMB_LIST_HNSW", + Value: "HNSW", }, { Key: "nlist", @@ -1744,7 +1744,6 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) }) - fmt.Println("create index for binVec field") fieldName := ConcatStructFieldName(structField, subFieldFVec) wg.Add(1) @@ -1757,8 +1756,6 @@ func TestProxy(t *testing.T) { assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) }) - fmt.Println("create index for embedding list field") - wg.Add(1) t.Run("alter index for embedding list field", func(t *testing.T) { defer wg.Done() @@ -1778,7 +1775,6 @@ func TestProxy(t *testing.T) { err = merr.CheckRPCCall(resp, err) assert.NoError(t, err) }) - fmt.Println("alter index for embedding list field") wg.Add(1) t.Run("describe index for embedding list field", func(t *testing.T) { @@ -1796,7 +1792,6 @@ func TestProxy(t *testing.T) { enableMmap, _ := common.IsMmapDataEnabled(resp.IndexDescriptions[0].GetParams()...) 
assert.True(t, enableMmap, "params: %+v", resp.IndexDescriptions[0]) }) - fmt.Println("describe index for embedding list field") wg.Add(1) t.Run("describe index with indexName for embedding list field", func(t *testing.T) { @@ -1812,7 +1807,6 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) - fmt.Println("describe index with indexName for embedding list field") wg.Add(1) t.Run("get index statistics for embedding list field", func(t *testing.T) { diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 76737ed058..9c51b576ac 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -71,8 +71,6 @@ const ( RoundDecimalKey = "round_decimal" OffsetKey = "offset" LimitKey = "limit" - // offsets for embedding list search - LimsKey = "lims" // key for timestamptz translation TimezoneKey = "timezone" TimefieldsKey = "time_fields" diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 89e3c6afe6..a1c5585472 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -1206,7 +1206,7 @@ func Test_checkEmbeddingListIndex(t *testing.T) { ExtraParams: []*commonpb.KeyValuePair{ { Key: common.IndexTypeKey, - Value: "EMB_LIST_HNSW", + Value: "HNSW", }, { Key: common.MetricTypeKey, @@ -1237,7 +1237,7 @@ func Test_checkEmbeddingListIndex(t *testing.T) { ExtraParams: []*commonpb.KeyValuePair{ { Key: common.IndexTypeKey, - Value: "EMB_LIST_HNSW", + Value: "HNSW", }, { Key: common.MetricTypeKey, @@ -1290,37 +1290,6 @@ func Test_checkEmbeddingListIndex(t *testing.T) { err := cit.parseIndexParams(context.TODO()) assert.True(t, strings.Contains(err.Error(), "float vector index does not support metric type: MAX_SIM")) }) - - t.Run("data type wrong", func(t *testing.T) { - cit := &createIndexTask{ - Condition: nil, - req: &milvuspb.CreateIndexRequest{ - ExtraParams: []*commonpb.KeyValuePair{ - { - Key: common.IndexTypeKey, - Value: "EMB_LIST_HNSW", - }, - { - Key: common.MetricTypeKey, - Value: metric.L2, - }, - }, - IndexName: "", - }, - fieldSchema: &schemapb.FieldSchema{ - FieldID: 101, - Name: "FieldFloat", - IsPrimaryKey: false, - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: common.DimKey, Value: "128"}, - }, - }, - } - - err := cit.parseIndexParams(context.TODO()) - assert.True(t, strings.Contains(err.Error(), "data type FloatVector can't build with this index EMB_LIST_HNSW")) - }) } func Test_ngram_parseIndexParams(t *testing.T) { diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index 108be82177..b79383fb2e 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -1087,12 +1087,61 @@ func (v *validateUtil) checkArrayOfVectorFieldData(field *schemapb.FieldData, fi return merr.WrapErrParameterInvalid("need float vector array", "got nil", msg) } if v.checkNAN { - return typeutil.VerifyFloats32(floatVector.GetData()) + if err := typeutil.VerifyFloats32(floatVector.GetData()); err != nil { + return err + } + } + } + return nil + case schemapb.DataType_BinaryVector: + for _, vector := range data.GetData() { + binaryVector := vector.GetBinaryVector() + if binaryVector == nil { + msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need binary vector array", "got nil", msg) + } + } + return nil + case schemapb.DataType_Float16Vector: + for _, vector := range 
data.GetData() { + float16Vector := vector.GetFloat16Vector() + if float16Vector == nil { + msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need float16 vector array", "got nil", msg) + } + if v.checkNAN { + if err := typeutil.VerifyFloats16(float16Vector); err != nil { + return err + } + } + } + return nil + case schemapb.DataType_BFloat16Vector: + for _, vector := range data.GetData() { + bfloat16Vector := vector.GetBfloat16Vector() + if bfloat16Vector == nil { + msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need bfloat16 vector array", "got nil", msg) + } + if v.checkNAN { + if err := typeutil.VerifyBFloats16(bfloat16Vector); err != nil { + return err + } + } + } + return nil + case schemapb.DataType_Int8Vector: + for _, vector := range data.GetData() { + int8Vector := vector.GetInt8Vector() + if int8Vector == nil { + msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName()) + return merr.WrapErrParameterInvalid("need int8 vector array", "got nil", msg) } } return nil default: - panic("not implemented") + msg := fmt.Sprintf("unsupported element type for ArrayOfVector: %v", fieldSchema.GetElementType()) + return merr.WrapErrParameterInvalid("supported vector type", fieldSchema.GetElementType().String(), msg) } } diff --git a/internal/storage/arrow_util.go b/internal/storage/arrow_util.go index 9b0cbad83d..3f0c30184c 100644 --- a/internal/storage/arrow_util.go +++ b/internal/storage/arrow_util.go @@ -251,16 +251,18 @@ func appendValueAt(builder array.Builder, a arrow.Array, idx int, defaultValue * b.Append(true) valuesArray := la.ListValues() - valueBuilder := b.ValueBuilder() - var totalSize uint64 = 0 + valueBuilder := b.ValueBuilder() switch vb := valueBuilder.(type) { - case *array.Float32Builder: - if floatArray, ok := valuesArray.(*array.Float32); ok { - for i := start; i < end; i++ { - vb.Append(floatArray.Value(int(i))) - totalSize += 4 - } + case *array.FixedSizeBinaryBuilder: + fixedArray, ok := valuesArray.(*array.FixedSizeBinary) + if !ok { + return 0, fmt.Errorf("invalid value type %T, expect %T", valuesArray.DataType(), vb.Type()) + } + for i := start; i < end; i++ { + val := fixedArray.Value(int(i)) + vb.Append(val) + totalSize += uint64(len(val)) } default: return 0, fmt.Errorf("unsupported value builder type in ListBuilder: %T", valueBuilder) diff --git a/internal/storage/insert_data.go b/internal/storage/insert_data.go index 8e00d490a6..3b281dd5cd 100644 --- a/internal/storage/insert_data.go +++ b/internal/storage/insert_data.go @@ -376,7 +376,12 @@ func NewFieldData(dataType schemapb.DataType, fieldSchema *schemapb.FieldSchema, } return data, nil case schemapb.DataType_ArrayOfVector: + dim, err := GetDimFromParams(typeParams) + if err != nil { + return nil, err + } data := &VectorArrayFieldData{ + Dim: int64(dim), Data: make([]*schemapb.VectorField, 0, cap), ElementType: fieldSchema.GetElementType(), } diff --git a/internal/storage/payload_reader.go b/internal/storage/payload_reader.go index 03f31d7961..90ba26ddf7 100644 --- a/internal/storage/payload_reader.go +++ b/internal/storage/payload_reader.go @@ -628,45 +628,16 @@ func readVectorArrayFromListArray(r *PayloadReader) ([]*schemapb.VectorField, er return nil, fmt.Errorf("expected ListArray, got %T", chunk) } - valuesArray := listArray.ListValues() - switch elementType { - case 
schemapb.DataType_FloatVector: - floatArray, ok := valuesArray.(*array.Float32) + for i := 0; i < listArray.Len(); i++ { + value, ok := deserializeArrayOfVector(listArray, i, elementType, dim, true) if !ok { - return nil, fmt.Errorf("expected Float32 array for FloatVector, got %T", valuesArray) + return nil, fmt.Errorf("failed to deserialize VectorArray at row %d", len(result)) } - - // Process each row which contains multiple vectors - for i := 0; i < listArray.Len(); i++ { - if listArray.IsNull(i) { - return nil, fmt.Errorf("null value in VectorArray") - } - - start, end := listArray.ValueOffsets(i) - vectorData := make([]float32, end-start) - copy(vectorData, floatArray.Float32Values()[start:end]) - - vectorField := &schemapb.VectorField{ - Dim: dim, - Data: &schemapb.VectorField_FloatVector{ - FloatVector: &schemapb.FloatArray{ - Data: vectorData, - }, - }, - } - result = append(result, vectorField) + vectorField, _ := value.(*schemapb.VectorField) + if vectorField == nil { + return nil, fmt.Errorf("null value in VectorArray") } - - case schemapb.DataType_BinaryVector: - return nil, fmt.Errorf("BinaryVector in VectorArray not implemented yet") - case schemapb.DataType_Float16Vector: - return nil, fmt.Errorf("Float16Vector in VectorArray not implemented yet") - case schemapb.DataType_BFloat16Vector: - return nil, fmt.Errorf("BFloat16Vector in VectorArray not implemented yet") - case schemapb.DataType_Int8Vector: - return nil, fmt.Errorf("Int8Vector in VectorArray not implemented yet") - default: - return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String()) + result = append(result, vectorField) } } diff --git a/internal/storage/payload_test.go b/internal/storage/payload_test.go index 25c94cacb3..5d848a0bb9 100644 --- a/internal/storage/payload_test.go +++ b/internal/storage/payload_test.go @@ -1719,6 +1719,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) { // Create VectorArrayFieldData with 3 rows vectorArrayData := &VectorArrayFieldData{ + Dim: int64(dim), Data: []*schemapb.VectorField{ { Dim: int64(dim), diff --git a/internal/storage/payload_writer.go b/internal/storage/payload_writer.go index e93cf8c68b..0ac894ea68 100644 --- a/internal/storage/payload_writer.go +++ b/internal/storage/payload_writer.go @@ -18,6 +18,7 @@ package storage import ( "bytes" + "encoding/binary" "fmt" "math" "sync" @@ -114,7 +115,10 @@ func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions if w.elementType == nil { return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires elementType, use WithElementType option") } - elemType, err := VectorArrayToArrowType(*w.elementType) + if w.dim == nil { + return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires dim to be specified") + } + elemType, err := VectorArrayToArrowType(*w.elementType, *w.dim.Value) if err != nil { return nil, err } @@ -962,13 +966,13 @@ func (w *NativePayloadWriter) AddVectorArrayFieldDataToPayload(data *VectorArray case schemapb.DataType_FloatVector: return w.addFloatVectorArrayToPayload(builder, data) case schemapb.DataType_BinaryVector: - return merr.WrapErrParameterInvalidMsg("BinaryVector in VectorArray not implemented yet") + return w.addBinaryVectorArrayToPayload(builder, data) case schemapb.DataType_Float16Vector: - return merr.WrapErrParameterInvalidMsg("Float16Vector in VectorArray not implemented yet") + return w.addFloat16VectorArrayToPayload(builder, data) case schemapb.DataType_BFloat16Vector: - return 
merr.WrapErrParameterInvalidMsg("BFloat16Vector in VectorArray not implemented yet") + return w.addBFloat16VectorArrayToPayload(builder, data) case schemapb.DataType_Int8Vector: - return merr.WrapErrParameterInvalidMsg("Int8Vector in VectorArray not implemented yet") + return w.addInt8VectorArrayToPayload(builder, data) default: return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", data.ElementType.String())) } @@ -976,21 +980,159 @@ func (w *NativePayloadWriter) AddVectorArrayFieldDataToPayload(data *VectorArray // addFloatVectorArrayToPayload handles FloatVector elements in VectorArray func (w *NativePayloadWriter) addFloatVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error { - valueBuilder := builder.ValueBuilder().(*array.Float32Builder) + if data.Dim <= 0 { + return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0") + } + valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder) + + // Each element in data.Data represents one row of VectorArray for _, vectorField := range data.Data { if vectorField.GetFloatVector() == nil { return merr.WrapErrParameterInvalidMsg("expected FloatVector but got different type") } + // Start a new list for this row builder.Append(true) floatData := vectorField.GetFloatVector().GetData() - if len(floatData) == 0 { - return merr.WrapErrParameterInvalidMsg("empty vector data not allowed") - } - valueBuilder.AppendValues(floatData, nil) + numVectors := len(floatData) / int(data.Dim) + for i := 0; i < numVectors; i++ { + start := i * int(data.Dim) + end := start + int(data.Dim) + vectorSlice := floatData[start:end] + + bytes := make([]byte, data.Dim*4) + for j, f := range vectorSlice { + binary.LittleEndian.PutUint32(bytes[j*4:], math.Float32bits(f)) + } + + valueBuilder.Append(bytes) + } + } + + return nil +} + +// addBinaryVectorArrayToPayload handles BinaryVector elements in VectorArray +func (w *NativePayloadWriter) addBinaryVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error { + if data.Dim <= 0 { + return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0") + } + + valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder) + + // Each element in data.Data represents one row of VectorArray + for _, vectorField := range data.Data { + if vectorField.GetBinaryVector() == nil { + return merr.WrapErrParameterInvalidMsg("expected BinaryVector but got different type") + } + + // Start a new list for this row + builder.Append(true) + + binaryData := vectorField.GetBinaryVector() + byteWidth := (data.Dim + 7) / 8 + numVectors := len(binaryData) / int(byteWidth) + + for i := 0; i < numVectors; i++ { + start := i * int(byteWidth) + end := start + int(byteWidth) + valueBuilder.Append(binaryData[start:end]) + } + } + + return nil +} + +// addFloat16VectorArrayToPayload handles Float16Vector elements in VectorArray +func (w *NativePayloadWriter) addFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error { + if data.Dim <= 0 { + return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0") + } + + valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder) + + // Each element in data.Data represents one row of VectorArray + for _, vectorField := range data.Data { + if vectorField.GetFloat16Vector() == nil { + return merr.WrapErrParameterInvalidMsg("expected Float16Vector but got different type") + } + + // Start a new list for this 
row + builder.Append(true) + + float16Data := vectorField.GetFloat16Vector() + byteWidth := data.Dim * 2 + numVectors := len(float16Data) / int(byteWidth) + + for i := 0; i < numVectors; i++ { + start := i * int(byteWidth) + end := start + int(byteWidth) + valueBuilder.Append(float16Data[start:end]) + } + } + + return nil +} + +// addBFloat16VectorArrayToPayload handles BFloat16Vector elements in VectorArray +func (w *NativePayloadWriter) addBFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error { + if data.Dim <= 0 { + return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0") + } + + valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder) + + // Each element in data.Data represents one row of VectorArray + for _, vectorField := range data.Data { + if vectorField.GetBfloat16Vector() == nil { + return merr.WrapErrParameterInvalidMsg("expected BFloat16Vector but got different type") + } + + // Start a new list for this row + builder.Append(true) + + bfloat16Data := vectorField.GetBfloat16Vector() + byteWidth := data.Dim * 2 + numVectors := len(bfloat16Data) / int(byteWidth) + + for i := 0; i < numVectors; i++ { + start := i * int(byteWidth) + end := start + int(byteWidth) + valueBuilder.Append(bfloat16Data[start:end]) + } + } + + return nil +} + +// addInt8VectorArrayToPayload handles Int8Vector elements in VectorArray +func (w *NativePayloadWriter) addInt8VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error { + if data.Dim <= 0 { + return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0") + } + + valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder) + + // Each element in data.Data represents one row of VectorArray + for _, vectorField := range data.Data { + if vectorField.GetInt8Vector() == nil { + return merr.WrapErrParameterInvalidMsg("expected Int8Vector but got different type") + } + + // Start a new list for this row + builder.Append(true) + + int8Data := vectorField.GetInt8Vector() + numVectors := len(int8Data) / int(data.Dim) + + for i := 0; i < numVectors; i++ { + start := i * int(data.Dim) + end := start + int(data.Dim) + valueBuilder.Append(int8Data[start:end]) + } } return nil diff --git a/internal/storage/payload_writer_test.go b/internal/storage/payload_writer_test.go index fb7561b736..8c3c798eb1 100644 --- a/internal/storage/payload_writer_test.go +++ b/internal/storage/payload_writer_test.go @@ -309,6 +309,7 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) { vectorArrayData := &VectorArrayFieldData{ Data: make([]*schemapb.VectorField, numRows), ElementType: schemapb.DataType_FloatVector, + Dim: int64(dim), } for i := 0; i < numRows; i++ { @@ -408,6 +409,7 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) { batchData := &VectorArrayFieldData{ Data: make([]*schemapb.VectorField, batchSize), ElementType: schemapb.DataType_FloatVector, + Dim: int64(dim), } for i := 0; i < batchSize; i++ { @@ -454,6 +456,7 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) { vectorArrayData := &VectorArrayFieldData{ Data: make([]*schemapb.VectorField, numRows), ElementType: schemapb.DataType_FloatVector, + Dim: int64(dim), } for i := 0; i < numRows; i++ { @@ -485,6 +488,159 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) { require.NoError(t, err) require.Equal(t, numRows, length) }) + + t.Run("Test ArrayOfFloat16Vector - Basic", func(t *testing.T) { + dim := 64 + numRows := 50 + vectorsPerRow := 4 + + // Create test data + 
vectorArrayData := &VectorArrayFieldData{ + Data: make([]*schemapb.VectorField, numRows), + ElementType: schemapb.DataType_Float16Vector, + Dim: int64(dim), + } + + for i := 0; i < numRows; i++ { + // Float16 vectors are stored as bytes (2 bytes per element) + byteData := make([]byte, vectorsPerRow*dim*2) + for j := 0; j < len(byteData); j++ { + byteData[j] = byte((i*1000 + j) % 256) + } + + vectorArrayData.Data[i] = &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: byteData, + }, + } + } + + w, err := NewPayloadWriter( + schemapb.DataType_ArrayOfVector, + WithDim(dim), + WithElementType(schemapb.DataType_Float16Vector), + ) + require.NoError(t, err) + require.NotNil(t, w) + + err = w.AddVectorArrayFieldDataToPayload(vectorArrayData) + require.NoError(t, err) + + err = w.FinishPayloadWriter() + require.NoError(t, err) + + // Verify results + buffer, err := w.GetPayloadBufferFromWriter() + require.NoError(t, err) + require.NotEmpty(t, buffer) + + length, err := w.GetPayloadLengthFromWriter() + require.NoError(t, err) + require.Equal(t, numRows, length) + }) + + t.Run("Test ArrayOfBinaryVector - Basic", func(t *testing.T) { + dim := 128 // Must be multiple of 8 + numRows := 50 + vectorsPerRow := 3 + + // Create test data + vectorArrayData := &VectorArrayFieldData{ + Data: make([]*schemapb.VectorField, numRows), + ElementType: schemapb.DataType_BinaryVector, + Dim: int64(dim), + } + + for i := 0; i < numRows; i++ { + // Binary vectors use 1 bit per dimension, so dim/8 bytes per vector + byteData := make([]byte, vectorsPerRow*dim/8) + for j := 0; j < len(byteData); j++ { + byteData[j] = byte((i + j) % 256) + } + + vectorArrayData.Data[i] = &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: byteData, + }, + } + } + + w, err := NewPayloadWriter( + schemapb.DataType_ArrayOfVector, + WithDim(dim), + WithElementType(schemapb.DataType_BinaryVector), + ) + require.NoError(t, err) + require.NotNil(t, w) + + err = w.AddVectorArrayFieldDataToPayload(vectorArrayData) + require.NoError(t, err) + + err = w.FinishPayloadWriter() + require.NoError(t, err) + + // Verify results + buffer, err := w.GetPayloadBufferFromWriter() + require.NoError(t, err) + require.NotEmpty(t, buffer) + + length, err := w.GetPayloadLengthFromWriter() + require.NoError(t, err) + require.Equal(t, numRows, length) + }) + + t.Run("Test ArrayOfBFloat16Vector - Basic", func(t *testing.T) { + dim := 64 + numRows := 50 + vectorsPerRow := 4 + + // Create test data + vectorArrayData := &VectorArrayFieldData{ + Data: make([]*schemapb.VectorField, numRows), + ElementType: schemapb.DataType_BFloat16Vector, + Dim: int64(dim), + } + + for i := 0; i < numRows; i++ { + // BFloat16 vectors are stored as bytes (2 bytes per element) + byteData := make([]byte, vectorsPerRow*dim*2) + for j := 0; j < len(byteData); j++ { + byteData[j] = byte((i*100 + j) % 256) + } + + vectorArrayData.Data[i] = &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: byteData, + }, + } + } + + w, err := NewPayloadWriter( + schemapb.DataType_ArrayOfVector, + WithDim(dim), + WithElementType(schemapb.DataType_BFloat16Vector), + ) + require.NoError(t, err) + require.NotNil(t, w) + + err = w.AddVectorArrayFieldDataToPayload(vectorArrayData) + require.NoError(t, err) + + err = w.FinishPayloadWriter() + require.NoError(t, err) + + // Verify results + buffer, err := w.GetPayloadBufferFromWriter() + require.NoError(t, err) + 
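Aside (illustrative only, not part of the patch): the payload-writer tests above, the C++ CreateArrowBuilder cases earlier in this diff, and the getArrayOfVectorArrowType / VectorArrayToArrowType helpers below all rely on the same per-element-type byte width for the FixedSizeBinary representation, plus a float32-to-bytes round trip for FloatVector rows. The sketch below restates both with only the Go standard library; bytesPerVector, packFloat32, and unpackFloat32 are names invented for the example, and the string element-type keys stand in for the schemapb enums.

    package main

    import (
        "encoding/binary"
        "fmt"
        "math"
    )

    // bytesPerVector mirrors the fixed-size-binary widths used per element type:
    // float32 -> dim*4, binary -> ceil(dim/8), float16/bfloat16 -> dim*2, int8 -> dim.
    func bytesPerVector(elementType string, dim int) (int, error) {
        switch elementType {
        case "FloatVector":
            return dim * 4, nil
        case "BinaryVector":
            return (dim + 7) / 8, nil
        case "Float16Vector", "BFloat16Vector":
            return dim * 2, nil
        case "Int8Vector":
            return dim, nil
        default:
            return 0, fmt.Errorf("unsupported element type: %s", elementType)
        }
    }

    // packFloat32 serializes one float32 vector into its dim*4-byte chunk
    // (little-endian), the layout appended to a FixedSizeBinaryBuilder.
    func packFloat32(vec []float32) []byte {
        out := make([]byte, len(vec)*4)
        for i, f := range vec {
            binary.LittleEndian.PutUint32(out[i*4:], math.Float32bits(f))
        }
        return out
    }

    // unpackFloat32 is the inverse, turning a fixed-size chunk back into floats.
    func unpackFloat32(chunk []byte) []float32 {
        out := make([]float32, len(chunk)/4)
        for i := range out {
            out[i] = math.Float32frombits(binary.LittleEndian.Uint32(chunk[i*4:]))
        }
        return out
    }

    func main() {
        width, _ := bytesPerVector("FloatVector", 4)
        chunk := packFloat32([]float32{0.1, 0.2, 0.3, 0.4})
        fmt.Println(width, len(chunk))    // 16 16
        fmt.Println(unpackFloat32(chunk)) // [0.1 0.2 0.3 0.4]
    }

Keeping every element type behind one fixed byte width per vector is what lets the writer, reader, and Arrow schema share a single FixedSizeBinary code path instead of per-type builders.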
require.NotEmpty(t, buffer) + + length, err := w.GetPayloadLengthFromWriter() + require.NoError(t, err) + require.Equal(t, numRows, length) + }) } func TestParquetEncoding(t *testing.T) { diff --git a/internal/storage/serde.go b/internal/storage/serde.go index 79fd88a800..445a9567ab 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -17,6 +17,7 @@ package storage import ( + "encoding/binary" "fmt" "io" "math" @@ -448,8 +449,8 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { // ArrayOfVector now implements the standard interface with elementType parameter m[schemapb.DataType_ArrayOfVector] = serdeEntry{ - arrowType: func(_ int, elementType schemapb.DataType) arrow.DataType { - return getArrayOfVectorArrowType(elementType) + arrowType: func(dim int, elementType schemapb.DataType) arrow.DataType { + return getArrayOfVectorArrowType(elementType, dim) }, deserialize: func(a arrow.Array, i int, elementType schemapb.DataType, dim int, shouldCopy bool) (any, bool) { return deserializeArrayOfVector(a, i, elementType, int64(dim), shouldCopy) @@ -471,24 +472,64 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { } builder.Append(true) + valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder) + dim := vf.GetDim() + + appendVectorChunks := func(data []byte, bytesPerVector int) bool { + numVectors := len(data) / bytesPerVector + for i := 0; i < numVectors; i++ { + start := i * bytesPerVector + end := start + bytesPerVector + valueBuilder.Append(data[start:end]) + } + return true + } switch elementType { case schemapb.DataType_FloatVector: if vf.GetFloatVector() == nil { return false } - valueBuilder := builder.ValueBuilder().(*array.Float32Builder) - valueBuilder.AppendValues(vf.GetFloatVector().GetData(), nil) + floatData := vf.GetFloatVector().GetData() + numVectors := len(floatData) / int(dim) + // Convert float data to binary + for i := 0; i < numVectors; i++ { + start := i * int(dim) + end := start + int(dim) + vectorSlice := floatData[start:end] + + bytes := make([]byte, dim*4) + for j, f := range vectorSlice { + binary.LittleEndian.PutUint32(bytes[j*4:], math.Float32bits(f)) + } + valueBuilder.Append(bytes) + } return true case schemapb.DataType_BinaryVector: - panic("BinaryVector in VectorArray not implemented yet") + if vf.GetBinaryVector() == nil { + return false + } + return appendVectorChunks(vf.GetBinaryVector(), int((dim+7)/8)) + case schemapb.DataType_Float16Vector: - panic("Float16Vector in VectorArray not implemented yet") + if vf.GetFloat16Vector() == nil { + return false + } + return appendVectorChunks(vf.GetFloat16Vector(), int(dim)*2) + case schemapb.DataType_BFloat16Vector: - panic("BFloat16Vector in VectorArray not implemented yet") + if vf.GetBfloat16Vector() == nil { + return false + } + return appendVectorChunks(vf.GetBfloat16Vector(), int(dim)*2) + case schemapb.DataType_Int8Vector: - panic("Int8Vector in VectorArray not implemented yet") + if vf.GetInt8Vector() == nil { + return false + } + return appendVectorChunks(vf.GetInt8Vector(), int(dim)) + case schemapb.DataType_SparseFloatVector: panic("SparseFloatVector in VectorArray not implemented yet") default: @@ -806,18 +847,18 @@ func (sfw *singleFieldRecordWriter) Close() error { } // getArrayOfVectorArrowType returns the appropriate Arrow type for ArrayOfVector based on element type -func getArrayOfVectorArrowType(elementType schemapb.DataType) arrow.DataType { +func getArrayOfVectorArrowType(elementType schemapb.DataType, dim int) arrow.DataType { switch elementType 
{ case schemapb.DataType_FloatVector: - return arrow.ListOf(arrow.PrimitiveTypes.Float32) + return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim * 4}) case schemapb.DataType_BinaryVector: - return arrow.ListOf(arrow.PrimitiveTypes.Uint8) + return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: (dim + 7) / 8}) case schemapb.DataType_Float16Vector: - return arrow.ListOf(arrow.PrimitiveTypes.Uint8) + return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim * 2}) case schemapb.DataType_BFloat16Vector: - return arrow.ListOf(arrow.PrimitiveTypes.Uint8) + return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim * 2}) case schemapb.DataType_Int8Vector: - return arrow.ListOf(arrow.PrimitiveTypes.Int8) + return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim}) case schemapb.DataType_SparseFloatVector: return arrow.ListOf(arrow.BinaryTypes.Binary) default: @@ -842,54 +883,78 @@ func deserializeArrayOfVector(a arrow.Array, i int, elementType schemapb.DataTyp return nil, false } - // Validate dimension for vector types that have fixed dimensions - if dim > 0 && totalElements%dim != 0 { - // Dimension mismatch - data corruption or schema inconsistency - return nil, false + valuesArray := arr.ListValues() + binaryArray, ok := valuesArray.(*array.FixedSizeBinary) + if !ok { + // empty array + return nil, true } - valuesArray := arr.ListValues() + numVectors := int(totalElements) + + // Helper function to extract byte vectors from FixedSizeBinary array + extractByteVectors := func(bytesPerVector int64) []byte { + totalBytes := numVectors * int(bytesPerVector) + data := make([]byte, totalBytes) + for j := 0; j < numVectors; j++ { + vectorIndex := int(start) + j + vectorData := binaryArray.Value(vectorIndex) + copy(data[j*int(bytesPerVector):], vectorData) + } + return data + } switch elementType { case schemapb.DataType_FloatVector: - floatArray, ok := valuesArray.(*array.Float32) - if !ok { - return nil, false + totalFloats := numVectors * int(dim) + floatData := make([]float32, totalFloats) + for j := 0; j < numVectors; j++ { + vectorIndex := int(start) + j + binaryData := binaryArray.Value(vectorIndex) + vectorFloats := arrow.Float32Traits.CastFromBytes(binaryData) + copy(floatData[j*int(dim):], vectorFloats) } - // Handle data copying based on shouldCopy parameter - var floatData []float32 - if shouldCopy { - // Explicitly requested copy - floatData = make([]float32, totalElements) - for j := start; j < end; j++ { - floatData[j-start] = floatArray.Value(int(j)) - } - } else { - // Try to avoid copying - use a slice of the underlying data - // This creates a slice that references the same underlying array - allData := floatArray.Float32Values() - floatData = allData[start:end] - } - - vectorField := &schemapb.VectorField{ + return &schemapb.VectorField{ Dim: dim, Data: &schemapb.VectorField_FloatVector{ FloatVector: &schemapb.FloatArray{ Data: floatData, }, }, - } - return vectorField, true + }, true case schemapb.DataType_BinaryVector: - panic("BinaryVector in VectorArray deserialization not implemented yet") + return &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_BinaryVector{ + BinaryVector: extractByteVectors((dim + 7) / 8), + }, + }, true + case schemapb.DataType_Float16Vector: - panic("Float16Vector in VectorArray deserialization not implemented yet") + return &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Float16Vector{ + Float16Vector: extractByteVectors(dim * 2), + }, + }, true + case schemapb.DataType_BFloat16Vector: - 
panic("BFloat16Vector in VectorArray deserialization not implemented yet") + return &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: extractByteVectors(dim * 2), + }, + }, true + case schemapb.DataType_Int8Vector: - panic("Int8Vector in VectorArray deserialization not implemented yet") + return &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Int8Vector{ + Int8Vector: extractByteVectors(dim), + }, + }, true case schemapb.DataType_SparseFloatVector: panic("SparseFloatVector in VectorArray deserialization not implemented yet") default: diff --git a/internal/storage/serde_test.go b/internal/storage/serde_test.go index 98f4dcf658..a2678c53c8 100644 --- a/internal/storage/serde_test.go +++ b/internal/storage/serde_test.go @@ -268,41 +268,48 @@ func TestCalculateArraySize(t *testing.T) { } func TestArrayOfVectorArrowType(t *testing.T) { + dim := 128 // Test dimension tests := []struct { name string elementType schemapb.DataType + dim int expectedChild arrow.DataType }{ { name: "FloatVector", elementType: schemapb.DataType_FloatVector, - expectedChild: arrow.PrimitiveTypes.Float32, + dim: dim, + expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim * 4}, }, { name: "BinaryVector", elementType: schemapb.DataType_BinaryVector, - expectedChild: arrow.PrimitiveTypes.Uint8, + dim: dim, + expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: (dim + 7) / 8}, }, { name: "Float16Vector", elementType: schemapb.DataType_Float16Vector, - expectedChild: arrow.PrimitiveTypes.Uint8, + dim: dim, + expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, }, { name: "BFloat16Vector", elementType: schemapb.DataType_BFloat16Vector, - expectedChild: arrow.PrimitiveTypes.Uint8, + dim: dim, + expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, }, { name: "Int8Vector", elementType: schemapb.DataType_Int8Vector, - expectedChild: arrow.PrimitiveTypes.Int8, + dim: dim, + expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - arrowType := getArrayOfVectorArrowType(tt.elementType) + arrowType := getArrayOfVectorArrowType(tt.elementType, tt.dim) assert.NotNil(t, arrowType) listType, ok := arrowType.(*arrow.ListType) diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 75f9ce79b2..382cb6f154 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -1665,20 +1665,21 @@ func SortFieldBinlogs(fieldBinlogs map[int64]*datapb.FieldBinlog) []*datapb.Fiel } // VectorArrayToArrowType converts VectorArray element type to the corresponding Arrow type -// Note: This returns the element type (e.g., float32), not a list type +// Note: This returns the element type (e.g., FixedSizeBinary), not a list type // The caller is responsible for wrapping it in a list if needed -func VectorArrayToArrowType(elementType schemapb.DataType) (arrow.DataType, error) { +func VectorArrayToArrowType(elementType schemapb.DataType, dim int) (arrow.DataType, error) { switch elementType { case schemapb.DataType_FloatVector: - return arrow.PrimitiveTypes.Float32, nil + // Each vector is stored as a fixed-size binary chunk + return &arrow.FixedSizeBinaryType{ByteWidth: dim * 4}, nil case schemapb.DataType_BinaryVector: - return nil, merr.WrapErrParameterInvalidMsg("BinaryVector in VectorArray not implemented yet") + return &arrow.FixedSizeBinaryType{ByteWidth: (dim + 7) / 8}, nil case schemapb.DataType_Float16Vector: - return nil, 
merr.WrapErrParameterInvalidMsg("Float16Vector in VectorArray not implemented yet") + return &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, nil case schemapb.DataType_BFloat16Vector: - return nil, merr.WrapErrParameterInvalidMsg("BFloat16Vector in VectorArray not implemented yet") + return &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, nil case schemapb.DataType_Int8Vector: - return nil, merr.WrapErrParameterInvalidMsg("Int8Vector in VectorArray not implemented yet") + return &arrow.FixedSizeBinaryType{ByteWidth: dim}, nil default: return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", elementType.String())) } diff --git a/internal/util/importutilv2/parquet/field_reader.go b/internal/util/importutilv2/parquet/field_reader.go index a204794502..3bc875e191 100644 --- a/internal/util/importutilv2/parquet/field_reader.go +++ b/internal/util/importutilv2/parquet/field_reader.go @@ -1845,34 +1845,49 @@ func ReadVectorArrayData(pcr *FieldReader, count int64) (any, error) { if chunk.NullN() > 0 { return nil, WrapNullRowErr(pcr.field) } + // ArrayOfVector is stored as list of fixed size binary listReader, ok := chunk.(*array.List) if !ok { return nil, WrapTypeErr(pcr.field, chunk.DataType().Name()) } - listFloat32Reader, ok := listReader.ListValues().(*array.Float32) + + fixedBinaryReader, ok := listReader.ListValues().(*array.FixedSizeBinary) if !ok { - return nil, WrapTypeErr(pcr.field, chunk.DataType().Name()) + return nil, WrapTypeErr(pcr.field, listReader.ListValues().DataType().Name()) } + + // Check that each vector has the correct byte size (dim * 4 bytes for float32) + expectedByteSize := int(dim) * 4 + actualByteSize := fixedBinaryReader.DataType().(*arrow.FixedSizeBinaryType).ByteWidth + if actualByteSize != expectedByteSize { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("vector byte size mismatch: expected %d, got %d for field '%s'", + expectedByteSize, actualByteSize, pcr.field.GetName())) + } + offsets := listReader.Offsets() for i := 1; i < len(offsets); i++ { start, end := offsets[i-1], offsets[i] - floatCount := end - start - if floatCount%int32(dim) != 0 { - return nil, merr.WrapErrImportFailed(fmt.Sprintf("vectors in VectorArray should be aligned with dim: %d", dim)) - } + vectorCount := end - start - arrLength := floatCount / int32(dim) - if err = common.CheckArrayCapacity(int(arrLength), maxCapacity, pcr.field); err != nil { + if err = common.CheckArrayCapacity(int(vectorCount), maxCapacity, pcr.field); err != nil { return nil, err } - arrData := make([]float32, floatCount) - copy(arrData, listFloat32Reader.Float32Values()[start:end]) + // Convert binary data to float32 array using arrow's built-in conversion + totalFloats := vectorCount * int32(dim) + floatData := make([]float32, totalFloats) + for j := int32(0); j < vectorCount; j++ { + vectorIndex := start + j + binaryData := fixedBinaryReader.Value(int(vectorIndex)) + vectorFloats := arrow.Float32Traits.CastFromBytes(binaryData) + copy(floatData[j*int32(dim):(j+1)*int32(dim)], vectorFloats) + } + data = append(data, &schemapb.VectorField{ Dim: dim, Data: &schemapb.VectorField_FloatVector{ FloatVector: &schemapb.FloatArray{ - Data: arrData, + Data: floatData, }, }, }) diff --git a/internal/util/importutilv2/parquet/util.go b/internal/util/importutilv2/parquet/util.go index f6079aa893..3ba18992b7 100644 --- a/internal/util/importutilv2/parquet/util.go +++ b/internal/util/importutilv2/parquet/util.go @@ -195,6 +195,8 @@ func isArrowDataTypeConvertible(src arrow.DataType, 
dst arrow.DataType, field *s return valid } return false + case arrow.FIXED_SIZE_BINARY: + return dstType == arrow.FIXED_SIZE_BINARY default: return false } @@ -293,9 +295,11 @@ func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.Da Metadata: arrow.Metadata{}, }), nil case schemapb.DataType_ArrayOfVector: - // VectorArrayToArrowType now returns the element type (e.g., float32) - // We wrap it in a single list to get list (flattened) - elemType, err := storage.VectorArrayToArrowType(field.GetElementType()) + dim, err := typeutil.GetDim(field) + if err != nil { + return nil, err + } + elemType, err := storage.VectorArrayToArrowType(field.GetElementType(), int(dim)) if err != nil { return nil, err } diff --git a/internal/util/indexparamcheck/constraints.go b/internal/util/indexparamcheck/constraints.go index c55693b0d2..a004b33bc6 100644 --- a/internal/util/indexparamcheck/constraints.go +++ b/internal/util/indexparamcheck/constraints.go @@ -52,11 +52,12 @@ const ( ) var ( - FloatVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE} // const - SparseFloatVectorMetrics = []string{metric.IP, metric.BM25} // const - BinaryVectorMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE, metric.MHJACCARD} // const - IntVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE} // const - EmbListMetrics = []string{metric.MaxSim} // const + // all consts + FloatVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE} + SparseFloatVectorMetrics = []string{metric.IP, metric.BM25} + BinaryVectorMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE, metric.MHJACCARD} + IntVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE} + EmbListMetrics = []string{metric.MaxSim, metric.MaxSimCosine, metric.MaxSimL2, metric.MaxSimIP, metric.MaxSimHamming, metric.MaxSimJaccard} ) // BinIDMapMetrics is a set of all metric types supported for binary vector. 
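Editorial note on the layout used by the storage and import hunks above: every element type of ArrayOfVector is now carried as an Arrow list<fixed_size_binary>, one FixedSizeBinary value per vector, with ByteWidth = dim*4 for float32, (dim+7)/8 for binary, dim*2 for fp16/bf16, and dim for int8 (see getArrayOfVectorArrowType and VectorArrayToArrowType). The following is a minimal, self-contained sketch of that per-vector encoding, assuming the little-endian float32 layout produced by binary.LittleEndian.PutUint32 in serde.go; the helper names below are illustrative only and do not appear in the patch.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// encodeFloatVector serializes one float32 vector into the dim*4-byte
// little-endian payload that a FixedSizeBinaryBuilder receives per vector.
func encodeFloatVector(vec []float32) []byte {
	buf := make([]byte, len(vec)*4)
	for i, f := range vec {
		binary.LittleEndian.PutUint32(buf[i*4:], math.Float32bits(f))
	}
	return buf
}

// byteWidthFor mirrors the ByteWidth choices made in the patch for each
// supported element type (illustrative strings instead of schemapb enums).
func byteWidthFor(elementType string, dim int) int {
	switch elementType {
	case "FloatVector":
		return dim * 4
	case "BinaryVector":
		return (dim + 7) / 8
	case "Float16Vector", "BFloat16Vector":
		return dim * 2
	case "Int8Vector":
		return dim
	default:
		return 0
	}
}

func main() {
	vec := []float32{1, 2, 3, 4}
	// Both print 16: the encoded length matches the declared ByteWidth.
	fmt.Println(len(encodeFloatVector(vec)), byteWidthFor("FloatVector", len(vec)))
}
```

Keeping each vector as one fixed-width value makes per-vector boundaries explicit in the Arrow schema, which is what lets the writer, reader, and importer in this patch share a single code path for all element types instead of the previous float-only special case.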
diff --git a/internal/util/testutil/test_util.go b/internal/util/testutil/test_util.go index 20439e36ae..897522bfbc 100644 --- a/internal/util/testutil/test_util.go +++ b/internal/util/testutil/test_util.go @@ -1,6 +1,7 @@ package testutil import ( + "encoding/binary" "fmt" "math" "math/rand" @@ -734,30 +735,113 @@ func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.Inser columns = append(columns, builder.NewListArray()) } case schemapb.DataType_ArrayOfVector: - data := insertData.Data[fieldID].(*storage.VectorArrayFieldData).Data - rows := len(data) + vectorArrayData := insertData.Data[fieldID].(*storage.VectorArrayFieldData) + dim, err := typeutil.GetDim(field) + if err != nil { + return nil, err + } + elemType, err := storage.VectorArrayToArrowType(elementType, int(dim)) + if err != nil { + return nil, err + } - switch elementType { - case schemapb.DataType_FloatVector: - // ArrayOfVector is flattened in Arrow - just a list of floats - // where total floats = dim * num_vectors - builder := array.NewListBuilder(mem, &arrow.Float32Type{}) - valueBuilder := builder.ValueBuilder().(*array.Float32Builder) + // Create ListBuilder with "item" field name to match convertToArrowDataType + // Always represented as a list of fixed-size binary values + listBuilder := array.NewListBuilderWithField(mem, arrow.Field{ + Name: "item", + Type: elemType, + Nullable: true, + Metadata: arrow.Metadata{}, + }) + fixedSizeBuilder, ok := listBuilder.ValueBuilder().(*array.FixedSizeBinaryBuilder) + if !ok { + return nil, fmt.Errorf("unexpected list value builder for VectorArray field %s: %T", field.GetName(), listBuilder.ValueBuilder()) + } - for i := 0; i < rows; i++ { - vectorArray := data[i].GetFloatVector() - if vectorArray == nil || len(vectorArray.GetData()) == 0 { - builder.AppendNull() + vectorArrayData.Dim = dim + + bytesPerVector := fixedSizeBuilder.Type().(*arrow.FixedSizeBinaryType).ByteWidth + + appendBinarySlice := func(data []byte, stride int) error { + if stride == 0 { + return fmt.Errorf("zero stride for VectorArray field %s", field.GetName()) + } + if len(data)%stride != 0 { + return fmt.Errorf("vector array data length %d is not divisible by stride %d for field %s", len(data), stride, field.GetName()) + } + for offset := 0; offset < len(data); offset += stride { + fixedSizeBuilder.Append(data[offset : offset+stride]) + } + return nil + } + + for _, vectorField := range vectorArrayData.Data { + if vectorField == nil { + listBuilder.Append(false) + continue + } + + listBuilder.Append(true) + + switch elementType { + case schemapb.DataType_FloatVector: + floatArray := vectorField.GetFloatVector() + if floatArray == nil { + return nil, fmt.Errorf("expected FloatVector data for field %s", field.GetName()) + } + data := floatArray.GetData() + if len(data) == 0 { continue } - builder.Append(true) - // Append all flattened vector data - valueBuilder.AppendValues(vectorArray.GetData(), nil) + if len(data)%int(dim) != 0 { + return nil, fmt.Errorf("float vector data length %d is not divisible by dim %d for field %s", len(data), dim, field.GetName()) + } + for offset := 0; offset < len(data); offset += int(dim) { + vectorBytes := make([]byte, bytesPerVector) + for j := 0; j < int(dim); j++ { + binary.LittleEndian.PutUint32(vectorBytes[j*4:], math.Float32bits(data[offset+j])) + } + fixedSizeBuilder.Append(vectorBytes) + } + case schemapb.DataType_BinaryVector: + binaryData := vectorField.GetBinaryVector() + if len(binaryData) == 0 { + continue + } + bytesPer := int((dim + 7) / 8) + 
if err := appendBinarySlice(binaryData, bytesPer); err != nil { + return nil, err + } + case schemapb.DataType_Float16Vector: + float16Data := vectorField.GetFloat16Vector() + if len(float16Data) == 0 { + continue + } + if err := appendBinarySlice(float16Data, int(dim)*2); err != nil { + return nil, err + } + case schemapb.DataType_BFloat16Vector: + bfloat16Data := vectorField.GetBfloat16Vector() + if len(bfloat16Data) == 0 { + continue + } + if err := appendBinarySlice(bfloat16Data, int(dim)*2); err != nil { + return nil, err + } + case schemapb.DataType_Int8Vector: + int8Data := vectorField.GetInt8Vector() + if len(int8Data) == 0 { + continue + } + if err := appendBinarySlice(int8Data, int(dim)); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String()) } - columns = append(columns, builder.NewListArray()) - default: - return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String()) } + + columns = append(columns, listBuilder.NewListArray()) } } return columns, nil diff --git a/pkg/proto/plan.proto b/pkg/proto/plan.proto index 0df3b2274c..c12ac79abd 100644 --- a/pkg/proto/plan.proto +++ b/pkg/proto/plan.proto @@ -45,6 +45,7 @@ enum VectorType { EmbListFloat16Vector = 7; EmbListBFloat16Vector = 8; EmbListInt8Vector = 9; + EmbListBinaryVector = 10; }; message GenericValue { diff --git a/pkg/proto/planpb/plan.pb.go b/pkg/proto/planpb/plan.pb.go index c4257066ba..33569908c0 100644 --- a/pkg/proto/planpb/plan.pb.go +++ b/pkg/proto/planpb/plan.pb.go @@ -184,21 +184,23 @@ const ( VectorType_EmbListFloat16Vector VectorType = 7 VectorType_EmbListBFloat16Vector VectorType = 8 VectorType_EmbListInt8Vector VectorType = 9 + VectorType_EmbListBinaryVector VectorType = 10 ) // Enum value maps for VectorType. 
var ( VectorType_name = map[int32]string{ - 0: "BinaryVector", - 1: "FloatVector", - 2: "Float16Vector", - 3: "BFloat16Vector", - 4: "SparseFloatVector", - 5: "Int8Vector", - 6: "EmbListFloatVector", - 7: "EmbListFloat16Vector", - 8: "EmbListBFloat16Vector", - 9: "EmbListInt8Vector", + 0: "BinaryVector", + 1: "FloatVector", + 2: "Float16Vector", + 3: "BFloat16Vector", + 4: "SparseFloatVector", + 5: "Int8Vector", + 6: "EmbListFloatVector", + 7: "EmbListFloat16Vector", + 8: "EmbListBFloat16Vector", + 9: "EmbListInt8Vector", + 10: "EmbListBinaryVector", } VectorType_value = map[string]int32{ "BinaryVector": 0, @@ -211,6 +213,7 @@ var ( "EmbListFloat16Vector": 7, "EmbListBFloat16Vector": 8, "EmbListInt8Vector": 9, + "EmbListBinaryVector": 10, } ) @@ -3808,7 +3811,7 @@ var file_plan_proto_rawDesc = []byte{ 0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64, 0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, - 0x68, 0x10, 0x06, 0x2a, 0xe1, 0x01, 0x0a, 0x0a, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79, + 0x68, 0x10, 0x06, 0x2a, 0xfa, 0x01, 0x0a, 0x0a, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, @@ -3822,22 +3825,23 @@ var file_plan_proto_rawDesc = []byte{ 0x74, 0x6f, 0x72, 0x10, 0x07, 0x12, 0x19, 0x0a, 0x15, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x08, 0x12, 0x15, 0x0a, 0x11, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x49, 0x6e, 0x74, 0x38, 0x56, - 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x09, 0x2a, 0x3e, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x57, 0x65, 0x69, 0x67, 0x68, 0x74, 0x10, 0x00, 0x12, - 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x52, - 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x10, 0x01, 0x2a, 0x3d, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x46, 0x75, 0x6e, 0x63, 0x74, - 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x79, 0x10, - 0x00, 0x12, 0x13, 0x0a, 0x0f, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, - 0x65, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x2a, 0x34, 0x0a, 0x09, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, - 0x6f, 0x64, 0x65, 0x12, 0x15, 0x0a, 0x11, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, - 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x79, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x6f, - 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x42, 0x31, 0x5a, 0x2f, - 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, - 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, - 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x09, 0x12, 0x17, 0x0a, 0x13, 0x45, 0x6d, 0x62, 0x4c, 0x69, + 0x73, 
0x74, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x0a, + 0x2a, 0x3e, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, + 0x57, 0x65, 0x69, 0x67, 0x68, 0x74, 0x10, 0x00, 0x12, 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x10, 0x01, + 0x2a, 0x3d, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, + 0x12, 0x18, 0x0a, 0x14, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, + 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x79, 0x10, 0x00, 0x12, 0x13, 0x0a, 0x0f, 0x46, 0x75, + 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x2a, + 0x34, 0x0a, 0x09, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x15, 0x0a, 0x11, + 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, + 0x79, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, + 0x53, 0x75, 0x6d, 0x10, 0x01, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, + 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, + 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/pkg/util/metric/metric_type.go b/pkg/util/metric/metric_type.go index 691feb2e77..9dc68776a9 100644 --- a/pkg/util/metric/metric_type.go +++ b/pkg/util/metric/metric_type.go @@ -44,5 +44,11 @@ const ( EMPTY MetricType = "" - MaxSim MetricType = "MAX_SIM" + // Same as MaxSimCosine + MaxSim MetricType = "MAX_SIM" + MaxSimCosine MetricType = "MAX_SIM_COSINE" + MaxSimL2 MetricType = "MAX_SIM_L2" + MaxSimIP MetricType = "MAX_SIM_IP" + MaxSimHamming MetricType = "MAX_SIM_HAMMING" + MaxSimJaccard MetricType = "MAX_SIM_JACCARD" ) diff --git a/tests/integration/datanode/struct_array_test.go b/tests/integration/datanode/struct_array_test.go index 15aeaffecb..06cb69fb0d 100644 --- a/tests/integration/datanode/struct_array_test.go +++ b/tests/integration/datanode/struct_array_test.go @@ -176,7 +176,7 @@ func (s *ArrayStructDataNodeSuite) loadCollection(collectionName string) { CollectionName: collectionName, FieldName: subFieldName, IndexName: "array_of_vector_index", - ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexEmbListHNSW, metric.MaxSim), + ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexHNSW, metric.MaxSim), }) s.NoError(err) s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success) @@ -318,7 +318,7 @@ func (s *ArrayStructDataNodeSuite) query(collectionName string) { roundDecimal := -1 subFieldName := proxy.ConcatStructFieldName(integration.StructArrayField, integration.StructSubFloatVecField) - params := integration.GetSearchParams(integration.IndexEmbListHNSW, metric.MaxSim) + params := integration.GetSearchParams(integration.IndexHNSW, metric.MaxSim) searchReq := integration.ConstructEmbeddingListSearchRequest("", collectionName, expr, subFieldName, schemapb.DataType_FloatVector, []string{integration.StructArrayField}, metric.MaxSim, params, nq, s.dim, topk, roundDecimal) diff --git a/tests/integration/getvector/array_struct_test.go 
b/tests/integration/getvector/array_struct_test.go index f917dc6323..67b3d41bb1 100644 --- a/tests/integration/getvector/array_struct_test.go +++ b/tests/integration/getvector/array_struct_test.go @@ -246,7 +246,7 @@ func (s *TestArrayStructSuite) run() { func (s *TestArrayStructSuite) TestGetVector_ArrayStruct_FloatVector() { s.nq = 10 s.topK = 10 - s.indexType = integration.IndexEmbListHNSW + s.indexType = integration.IndexHNSW s.metricType = metric.MaxSim s.vecType = schemapb.DataType_FloatVector s.run() diff --git a/tests/integration/import/binlog_test.go b/tests/integration/import/binlog_test.go index 38c680e7a8..9128b0322a 100644 --- a/tests/integration/import/binlog_test.go +++ b/tests/integration/import/binlog_test.go @@ -102,7 +102,7 @@ func (s *BulkInsertSuite) PrepareSourceCollection(dim int, dmlGroup *DMLGroup) * CollectionName: collectionName, FieldName: name, IndexName: "array_of_vector_index", - ExtraParams: integration.ConstructIndexParam(dim, integration.IndexEmbListHNSW, metric.MaxSim), + ExtraParams: integration.ConstructIndexParam(dim, integration.IndexHNSW, metric.MaxSim), }) s.NoError(err) s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success) diff --git a/tests/integration/import/vector_array_test.go b/tests/integration/import/vector_array_test.go index 64225b6fb6..7a4aee2cff 100644 --- a/tests/integration/import/vector_array_test.go +++ b/tests/integration/import/vector_array_test.go @@ -303,7 +303,7 @@ func (s *BulkInsertSuite) TestImportWithVectorArray() { for _, fileType := range fileTypeArr { s.fileType = fileType s.vecType = schemapb.DataType_FloatVector - s.indexType = integration.IndexEmbListHNSW + s.indexType = integration.IndexHNSW s.metricType = metric.MaxSim s.runForStructArray() } diff --git a/tests/integration/util_index.go b/tests/integration/util_index.go index 9eb32171cd..4e29ab04fa 100644 --- a/tests/integration/util_index.go +++ b/tests/integration/util_index.go @@ -43,7 +43,6 @@ const ( IndexDISKANN = "DISKANN" IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX" IndexSparseWand = "SPARSE_WAND" - IndexEmbListHNSW = "EMB_LIST_HNSW" ) func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) { @@ -169,15 +168,6 @@ func ConstructIndexParam(dim int, indexType string, metricType string) []*common Key: "efConstruction", Value: "200", }) - case IndexEmbListHNSW: - params = append(params, &commonpb.KeyValuePair{ - Key: "M", - Value: "16", - }) - params = append(params, &commonpb.KeyValuePair{ - Key: "efConstruction", - Value: "200", - }) case IndexSparseInvertedIndex: case IndexSparseWand: case IndexDISKANN: @@ -195,7 +185,6 @@ func GetSearchParams(indexType string, metricType string) map[string]any { case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8, IndexFaissIvfPQ, IndexScaNN: params["nprobe"] = 8 case IndexHNSW: - case IndexEmbListHNSW: params["ef"] = 200 case IndexDISKANN: params["search_list"] = 20
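To close the loop on the read path introduced earlier in this patch (deserializeArrayOfVector in serde.go and ReadVectorArrayData in field_reader.go), each FixedSizeBinary value is turned back into a typed vector; for float32 the patch uses arrow.Float32Traits.CastFromBytes, which reinterprets the bytes in place on little-endian platforms. Below is a small round-trip sketch of the equivalent decode, using an illustrative helper name that is not part of the patch.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math"
)

// decodeFloatVector is the inverse of the per-vector encoding: it turns one
// dim*4-byte FixedSizeBinary value back into a []float32. On little-endian
// machines this matches what arrow.Float32Traits.CastFromBytes would return.
func decodeFloatVector(b []byte, dim int) []float32 {
	out := make([]float32, dim)
	for i := 0; i < dim; i++ {
		out[i] = math.Float32frombits(binary.LittleEndian.Uint32(b[i*4:]))
	}
	return out
}

func main() {
	// Round-trip a single 4-dim vector through the fixed-size binary layout.
	vec := []float32{0.5, 1.5, 2.5, 3.5}
	buf := make([]byte, len(vec)*4)
	for i, f := range vec {
		binary.LittleEndian.PutUint32(buf[i*4:], math.Float32bits(f))
	}
	fmt.Println(decodeFloatVector(buf, len(vec))) // prints [0.5 1.5 2.5 3.5]
}
```

The encode/decode pair is lossless as long as writer and reader agree on dim and on little-endian byte order, which is the invariant the byte-width checks in ReadVectorArrayData enforce.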