feat: impl StructArray -- support more types of vector in STRUCT (#44736)

ref: https://github.com/milvus-io/milvus/issues/42148

---------

Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
Signed-off-by: SpadeA-Tang <tangchenjie1210@gmail.com>
This commit is contained in:
Spade A 2025-10-15 10:25:59 +08:00 committed by GitHub
parent 5ad8a29c0b
commit c4f3f0ce4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
63 changed files with 1928 additions and 684 deletions

View File

@ -438,11 +438,11 @@ class VectorArrayChunk : public Chunk {
offsets_lens_ = reinterpret_cast<uint32_t*>(data); offsets_lens_ = reinterpret_cast<uint32_t*>(data);
auto offset = 0; auto offset = 0;
lims_.reserve(row_nums_ + 1); offsets_.reserve(row_nums_ + 1);
lims_.push_back(offset); offsets_.push_back(offset);
for (int64_t i = 0; i < row_nums_; i++) { for (int64_t i = 0; i < row_nums_; i++) {
offset += offsets_lens_[i * 2 + 1]; offset += offsets_lens_[i * 2 + 1];
lims_.push_back(offset); offsets_.push_back(offset);
} }
} }
@ -507,17 +507,15 @@ class VectorArrayChunk : public Chunk {
} }
const size_t* const size_t*
Lims() const { Offsets() const {
return lims_.data(); return offsets_.data();
} }
private: private:
int64_t dim_; int64_t dim_;
uint32_t* offsets_lens_; uint32_t* offsets_lens_;
milvus::DataType element_type_; milvus::DataType element_type_;
// The name 'Lims' is consistent with knowhere::DataSet::SetLims which describes the number of vectors std::vector<size_t> offsets_;
// in each vector array (embedding list). This is needed as vectors are flattened in the chunk.
std::vector<size_t> lims_;
}; };
class SparseFloatVectorChunk : public Chunk { class SparseFloatVectorChunk : public Chunk {

View File

@ -339,34 +339,9 @@ VectorArrayChunkWriter::write(const arrow::ArrayVector& array_vec) {
target_ = std::make_shared<MemChunkTarget>(total_size); target_ = std::make_shared<MemChunkTarget>(total_size);
} }
switch (element_type_) { // Seirialization, the format is: [offsets_lens][all_vector_data_concatenated]
case milvus::DataType::VECTOR_FLOAT:
writeFloatVectorArray(array_vec);
break;
case milvus::DataType::VECTOR_BINARY:
ThrowInfo(NotImplemented,
"BinaryVector in VectorArray not implemented yet");
case milvus::DataType::VECTOR_FLOAT16:
ThrowInfo(NotImplemented,
"Float16Vector in VectorArray not implemented yet");
case milvus::DataType::VECTOR_BFLOAT16:
ThrowInfo(NotImplemented,
"BFloat16Vector in VectorArray not implemented yet");
case milvus::DataType::VECTOR_INT8:
ThrowInfo(NotImplemented,
"Int8Vector in VectorArray not implemented yet");
default:
ThrowInfo(NotImplemented,
"Unsupported element type in VectorArray: {}",
static_cast<int>(element_type_));
}
}
void
VectorArrayChunkWriter::writeFloatVectorArray(
const arrow::ArrayVector& array_vec) {
std::vector<uint32_t> offsets_lens; std::vector<uint32_t> offsets_lens;
std::vector<const float*> float_data_ptrs; std::vector<const uint8_t*> vector_data_ptrs;
std::vector<size_t> data_sizes; std::vector<size_t> data_sizes;
uint32_t current_offset = uint32_t current_offset =
@ -375,25 +350,27 @@ VectorArrayChunkWriter::writeFloatVectorArray(
for (const auto& array_data : array_vec) { for (const auto& array_data : array_vec) {
auto list_array = auto list_array =
std::static_pointer_cast<arrow::ListArray>(array_data); std::static_pointer_cast<arrow::ListArray>(array_data);
auto float_values = auto binary_values =
std::static_pointer_cast<arrow::FloatArray>(list_array->values()); std::static_pointer_cast<arrow::FixedSizeBinaryArray>(
const float* raw_floats = float_values->raw_values(); list_array->values());
const int32_t* list_offsets = list_array->raw_value_offsets(); const int32_t* list_offsets = list_array->raw_value_offsets();
int byte_width = binary_values->byte_width();
// Generate offsets and lengths for each row // Generate offsets and lengths for each row
// Each list contains multiple float vectors which are flattened, so the float count // Each list contains multiple vectors, each stored as a fixed-size binary chunk
// in each list is vector count * dim.
for (int64_t i = 0; i < list_array->length(); i++) { for (int64_t i = 0; i < list_array->length(); i++) {
auto start_idx = list_offsets[i]; auto start_idx = list_offsets[i];
auto end_idx = list_offsets[i + 1]; auto end_idx = list_offsets[i + 1];
auto vector_count = (end_idx - start_idx) / dim_; auto vector_count = end_idx - start_idx;
auto byte_size = (end_idx - start_idx) * sizeof(float); auto byte_size = vector_count * byte_width;
offsets_lens.push_back(current_offset); offsets_lens.push_back(current_offset);
offsets_lens.push_back(static_cast<uint32_t>(vector_count)); offsets_lens.push_back(static_cast<uint32_t>(vector_count));
float_data_ptrs.push_back(raw_floats + start_idx); for (int j = start_idx; j < end_idx; j++) {
data_sizes.push_back(byte_size); vector_data_ptrs.push_back(binary_values->GetValue(j));
data_sizes.push_back(byte_width);
}
current_offset += byte_size; current_offset += byte_size;
} }
@ -409,8 +386,8 @@ VectorArrayChunkWriter::writeFloatVectorArray(
} }
target_->write(&offsets_lens.back(), sizeof(uint32_t)); // final offset target_->write(&offsets_lens.back(), sizeof(uint32_t)); // final offset
for (size_t i = 0; i < float_data_ptrs.size(); i++) { for (size_t i = 0; i < vector_data_ptrs.size(); i++) {
target_->write(float_data_ptrs[i], data_sizes[i]); target_->write(vector_data_ptrs[i], data_sizes[i]);
} }
} }
@ -427,19 +404,18 @@ VectorArrayChunkWriter::calculateTotalSize(
std::static_pointer_cast<arrow::ListArray>(array_data); std::static_pointer_cast<arrow::ListArray>(array_data);
switch (element_type_) { switch (element_type_) {
case milvus::DataType::VECTOR_FLOAT: { case milvus::DataType::VECTOR_FLOAT:
auto float_values = std::static_pointer_cast<arrow::FloatArray>(
list_array->values());
total_size += float_values->length() * sizeof(float);
break;
}
case milvus::DataType::VECTOR_BINARY: case milvus::DataType::VECTOR_BINARY:
case milvus::DataType::VECTOR_FLOAT16: case milvus::DataType::VECTOR_FLOAT16:
case milvus::DataType::VECTOR_BFLOAT16: case milvus::DataType::VECTOR_BFLOAT16:
case milvus::DataType::VECTOR_INT8: case milvus::DataType::VECTOR_INT8: {
ThrowInfo(NotImplemented, auto binary_values =
"Element type {} in VectorArray not implemented yet", std::static_pointer_cast<arrow::FixedSizeBinaryArray>(
static_cast<int>(element_type_)); list_array->values());
total_size +=
binary_values->length() * binary_values->byte_width();
break;
}
default: default:
ThrowInfo(DataTypeInvalid, ThrowInfo(DataTypeInvalid,
"Invalid element type {} for VectorArray", "Invalid element type {} for VectorArray",

View File

@ -273,9 +273,6 @@ class VectorArrayChunkWriter : public ChunkWriterBase {
finish() override; finish() override;
private: private:
void
writeFloatVectorArray(const arrow::ArrayVector& array_vec);
size_t size_t
calculateTotalSize(const arrow::ArrayVector& array_vec); calculateTotalSize(const arrow::ArrayVector& array_vec);

View File

@ -25,6 +25,7 @@
#include "common/Exception.h" #include "common/Exception.h"
#include "common/FieldDataInterface.h" #include "common/FieldDataInterface.h"
#include "common/Json.h" #include "common/Json.h"
#include "index/Utils.h"
#include "simdjson/padded_string.h" #include "simdjson/padded_string.h"
namespace milvus { namespace milvus {
@ -348,46 +349,47 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
std::vector<VectorArray> values(element_count); std::vector<VectorArray> values(element_count);
switch (element_type) { switch (element_type) {
case DataType::VECTOR_FLOAT: { case DataType::VECTOR_FLOAT:
auto float_array = case DataType::VECTOR_BINARY:
std::dynamic_pointer_cast<arrow::FloatArray>( case DataType::VECTOR_FLOAT16:
case DataType::VECTOR_BFLOAT16:
case DataType::VECTOR_INT8: {
// All vector types use FixedSizeBinaryArray and have the same serialization logic
auto binary_array =
std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(
values_array); values_array);
AssertInfo( AssertInfo(binary_array != nullptr,
float_array != nullptr, "Expected FixedSizeBinaryArray for VectorArray "
"Expected FloatArray for VECTOR_FLOAT element type"); "element");
// Calculate bytes per vector using the unified function
auto bytes_per_vec =
milvus::vector_bytes_per_element(element_type, dim);
for (size_t index = 0; index < element_count; ++index) { for (size_t index = 0; index < element_count; ++index) {
int64_t start_offset = list_array->value_offset(index); int64_t start_offset = list_array->value_offset(index);
int64_t end_offset = int64_t end_offset =
list_array->value_offset(index + 1); list_array->value_offset(index + 1);
int64_t num_floats = end_offset - start_offset; int64_t num_vectors = end_offset - start_offset;
AssertInfo(num_floats % dim == 0,
"Invalid data: number of floats ({}) not "
"divisible by "
"dimension ({})",
num_floats,
dim);
int num_vectors = num_floats / dim; auto data_size = num_vectors * bytes_per_vec;
const float* data_ptr = auto data_ptr = std::make_unique<uint8_t[]>(data_size);
float_array->raw_values() + start_offset;
values[index] = for (int64_t i = 0; i < num_vectors; i++) {
VectorArray(static_cast<const void*>(data_ptr), const uint8_t* binary_data =
binary_array->GetValue(start_offset + i);
uint8_t* dest = data_ptr.get() + i * bytes_per_vec;
std::memcpy(dest, binary_data, bytes_per_vec);
}
values[index] = VectorArray(
static_cast<const void*>(data_ptr.get()),
num_vectors, num_vectors,
dim, dim,
element_type); element_type);
} }
break; break;
} }
case DataType::VECTOR_BINARY:
case DataType::VECTOR_FLOAT16:
case DataType::VECTOR_BFLOAT16:
case DataType::VECTOR_INT8:
ThrowInfo(
NotImplemented,
"Element type {} in VectorArray not implemented yet",
GetDataTypeName(element_type));
break;
default: default:
ThrowInfo(DataTypeInvalid, ThrowInfo(DataTypeInvalid,
"Unsupported element type {} in VectorArray", "Unsupported element type {} in VectorArray",

View File

@ -94,8 +94,8 @@ Schema::ConvertToArrowSchema() const {
std::shared_ptr<arrow::DataType> arrow_data_type = nullptr; std::shared_ptr<arrow::DataType> arrow_data_type = nullptr;
auto data_type = meta.get_data_type(); auto data_type = meta.get_data_type();
if (data_type == DataType::VECTOR_ARRAY) { if (data_type == DataType::VECTOR_ARRAY) {
arrow_data_type = arrow_data_type = GetArrowDataTypeForVectorArray(
GetArrowDataTypeForVectorArray(meta.get_element_type()); meta.get_element_type(), meta.get_dim());
} else { } else {
arrow_data_type = GetArrowDataType(data_type, dim); arrow_data_type = GetArrowDataType(data_type, dim);
} }

View File

@ -205,10 +205,23 @@ GetArrowDataType(DataType data_type, int dim = 1) {
} }
inline std::shared_ptr<arrow::DataType> inline std::shared_ptr<arrow::DataType>
GetArrowDataTypeForVectorArray(DataType elem_type) { GetArrowDataTypeForVectorArray(DataType elem_type, int dim) {
if (dim <= 0) {
ThrowInfo(DataTypeInvalid, "dim must be provided for VectorArray");
}
// VectorArray stores vectors as FixedSizeBinaryArray
// We must have dim to create the correct fixed_size_binary type
switch (elem_type) { switch (elem_type) {
case DataType::VECTOR_FLOAT: case DataType::VECTOR_FLOAT:
return arrow::list(arrow::float32()); return arrow::list(arrow::fixed_size_binary(dim * sizeof(float)));
case DataType::VECTOR_BINARY:
return arrow::list(arrow::fixed_size_binary((dim + 7) / 8));
case DataType::VECTOR_FLOAT16:
return arrow::list(arrow::fixed_size_binary(dim * 2));
case DataType::VECTOR_BFLOAT16:
return arrow::list(arrow::fixed_size_binary(dim * 2));
case DataType::VECTOR_INT8:
return arrow::list(arrow::fixed_size_binary(dim));
default: { default: {
ThrowInfo(DataTypeInvalid, ThrowInfo(DataTypeInvalid,
fmt::format("failed to get arrow type for vector array, " fmt::format("failed to get arrow type for vector array, "
@ -534,7 +547,10 @@ IsFloatVectorMetricType(const MetricType& metric_type) {
metric_type == knowhere::metric::IP || metric_type == knowhere::metric::IP ||
metric_type == knowhere::metric::COSINE || metric_type == knowhere::metric::COSINE ||
metric_type == knowhere::metric::BM25 || metric_type == knowhere::metric::BM25 ||
metric_type == knowhere::metric::MAX_SIM; metric_type == knowhere::metric::MAX_SIM ||
metric_type == knowhere::metric::MAX_SIM_COSINE ||
metric_type == knowhere::metric::MAX_SIM_IP ||
metric_type == knowhere::metric::MAX_SIM_L2;
} }
inline bool inline bool
@ -543,14 +559,20 @@ IsBinaryVectorMetricType(const MetricType& metric_type) {
metric_type == knowhere::metric::JACCARD || metric_type == knowhere::metric::JACCARD ||
metric_type == knowhere::metric::SUPERSTRUCTURE || metric_type == knowhere::metric::SUPERSTRUCTURE ||
metric_type == knowhere::metric::SUBSTRUCTURE || metric_type == knowhere::metric::SUBSTRUCTURE ||
metric_type == knowhere::metric::MHJACCARD; metric_type == knowhere::metric::MHJACCARD ||
metric_type == knowhere::metric::MAX_SIM_HAMMING ||
metric_type == knowhere::metric::MAX_SIM_JACCARD;
} }
inline bool inline bool
IsIntVectorMetricType(const MetricType& metric_type) { IsIntVectorMetricType(const MetricType& metric_type) {
return metric_type == knowhere::metric::L2 || return metric_type == knowhere::metric::L2 ||
metric_type == knowhere::metric::IP || metric_type == knowhere::metric::IP ||
metric_type == knowhere::metric::COSINE; metric_type == knowhere::metric::COSINE ||
metric_type == knowhere::metric::MAX_SIM ||
metric_type == knowhere::metric::MAX_SIM_COSINE ||
metric_type == knowhere::metric::MAX_SIM_IP ||
metric_type == knowhere::metric::MAX_SIM_L2;
} }
// Plus 1 because we can't use greater(>) symbol // Plus 1 because we can't use greater(>) symbol
@ -746,6 +768,30 @@ FromValCase(milvus::proto::plan::GenericValue::ValCase val_case) {
return DataType::NONE; return DataType::NONE;
} }
} }
// Calculate bytes per vector element for different vector types
// Used by VectorArray and storage utilities
inline size_t
vector_bytes_per_element(const DataType data_type, int64_t dim) {
switch (data_type) {
case DataType::VECTOR_BINARY:
// Binary vector stores bits, so dim represents bit count
// Need (dim + 7) / 8 bytes to store dim bits
return (dim + 7) / 8;
case DataType::VECTOR_FLOAT:
return dim * sizeof(float);
case DataType::VECTOR_FLOAT16:
return dim * sizeof(float16);
case DataType::VECTOR_BFLOAT16:
return dim * sizeof(bfloat16);
case DataType::VECTOR_INT8:
return dim * sizeof(int8);
default:
ThrowInfo(UnexpectedError,
fmt::format("invalid data type: {}", data_type));
}
}
} // namespace milvus } // namespace milvus
template <> template <>
struct fmt::formatter<milvus::DataType> : formatter<string_view> { struct fmt::formatter<milvus::DataType> : formatter<string_view> {

View File

@ -41,16 +41,8 @@ class VectorArray : public milvus::VectorTrait {
assert(num_vectors > 0); assert(num_vectors > 0);
assert(dim > 0); assert(dim > 0);
switch (element_type) { size_ =
case DataType::VECTOR_FLOAT: num_vectors * milvus::vector_bytes_per_element(element_type, dim);
size_ = num_vectors * dim * sizeof(float);
break;
default:
ThrowInfo(NotImplemented,
"Direct VectorArray construction only supports "
"VECTOR_FLOAT, got {}",
GetDataTypeName(element_type));
}
data_ = std::make_unique<char[]>(size_); data_ = std::make_unique<char[]>(size_);
std::memcpy(data_.get(), data, size_); std::memcpy(data_.get(), data, size_);
@ -73,8 +65,49 @@ class VectorArray : public milvus::VectorTrait {
data_ = std::unique_ptr<char[]>(reinterpret_cast<char*>(data)); data_ = std::unique_ptr<char[]>(reinterpret_cast<char*>(data));
break; break;
} }
case VectorFieldProto::kBinaryVector: {
element_type_ = DataType::VECTOR_BINARY;
int bytes_per_vector = (dim_ + 7) / 8;
length_ =
vector_field.binary_vector().size() / bytes_per_vector;
size_ = vector_field.binary_vector().size();
data_ = std::make_unique<char[]>(size_);
std::memcpy(
data_.get(), vector_field.binary_vector().data(), size_);
break;
}
case VectorFieldProto::kFloat16Vector: {
element_type_ = DataType::VECTOR_FLOAT16;
int bytes_per_element = 2; // 2 bytes per float16
length_ = vector_field.float16_vector().size() /
(dim_ * bytes_per_element);
size_ = vector_field.float16_vector().size();
data_ = std::make_unique<char[]>(size_);
std::memcpy(
data_.get(), vector_field.float16_vector().data(), size_);
break;
}
case VectorFieldProto::kBfloat16Vector: {
element_type_ = DataType::VECTOR_BFLOAT16;
int bytes_per_element = 2; // 2 bytes per bfloat16
length_ = vector_field.bfloat16_vector().size() /
(dim_ * bytes_per_element);
size_ = vector_field.bfloat16_vector().size();
data_ = std::make_unique<char[]>(size_);
std::memcpy(
data_.get(), vector_field.bfloat16_vector().data(), size_);
break;
}
case VectorFieldProto::kInt8Vector: {
element_type_ = DataType::VECTOR_INT8;
length_ = vector_field.int8_vector().size() / dim_;
size_ = vector_field.int8_vector().size();
data_ = std::make_unique<char[]>(size_);
std::memcpy(
data_.get(), vector_field.int8_vector().data(), size_);
break;
}
default: { default: {
// TODO(SpadeA): add other vector types
ThrowInfo(NotImplemented, ThrowInfo(NotImplemented,
"Not implemented vector type: {}", "Not implemented vector type: {}",
static_cast<int>(vector_field.data_case())); static_cast<int>(vector_field.data_case()));
@ -160,8 +193,24 @@ class VectorArray : public milvus::VectorTrait {
return reinterpret_cast<VectorElement*>(data_.get()) + return reinterpret_cast<VectorElement*>(data_.get()) +
index * dim_; index * dim_;
} }
case DataType::VECTOR_BINARY: {
// Binary vectors are packed bits
int bytes_per_vector = (dim_ + 7) / 8;
return reinterpret_cast<VectorElement*>(
data_.get() + index * bytes_per_vector);
}
case DataType::VECTOR_FLOAT16:
case DataType::VECTOR_BFLOAT16: {
// Float16/BFloat16 are 2 bytes per element
return reinterpret_cast<VectorElement*>(data_.get() +
index * dim_ * 2);
}
case DataType::VECTOR_INT8: {
// Int8 is 1 byte per element
return reinterpret_cast<VectorElement*>(data_.get() +
index * dim_);
}
default: { default: {
// TODO(SpadeA): add other vector types
ThrowInfo(NotImplemented, ThrowInfo(NotImplemented,
"Not implemented vector type: {}", "Not implemented vector type: {}",
static_cast<int>(element_type_)); static_cast<int>(element_type_));
@ -180,8 +229,23 @@ class VectorArray : public milvus::VectorTrait {
data, data + length_ * dim_); data, data + length_ * dim_);
break; break;
} }
case DataType::VECTOR_BINARY: {
vector_field.set_binary_vector(data_.get(), size_);
break;
}
case DataType::VECTOR_FLOAT16: {
vector_field.set_float16_vector(data_.get(), size_);
break;
}
case DataType::VECTOR_BFLOAT16: {
vector_field.set_bfloat16_vector(data_.get(), size_);
break;
}
case DataType::VECTOR_INT8: {
vector_field.set_int8_vector(data_.get(), size_);
break;
}
default: { default: {
// TODO(SpadeA): add other vector types
ThrowInfo(NotImplemented, ThrowInfo(NotImplemented,
"Not implemented vector type: {}", "Not implemented vector type: {}",
static_cast<int>(element_type_)); static_cast<int>(element_type_));
@ -300,8 +364,23 @@ class VectorArrayView {
"VectorElement must be float for VECTOR_FLOAT"); "VectorElement must be float for VECTOR_FLOAT");
return reinterpret_cast<VectorElement*>(data_) + index * dim_; return reinterpret_cast<VectorElement*>(data_) + index * dim_;
} }
case DataType::VECTOR_BINARY: {
// Binary vectors are packed bits
int bytes_per_vector = (dim_ + 7) / 8;
return reinterpret_cast<VectorElement*>(
data_ + index * bytes_per_vector);
}
case DataType::VECTOR_FLOAT16:
case DataType::VECTOR_BFLOAT16: {
// Float16/BFloat16 are 2 bytes per element
return reinterpret_cast<VectorElement*>(data_ +
index * dim_ * 2);
}
case DataType::VECTOR_INT8: {
// Int8 is 1 byte per element
return reinterpret_cast<VectorElement*>(data_ + index * dim_);
}
default: { default: {
// TODO(SpadeA): add other vector types.
ThrowInfo(NotImplemented, ThrowInfo(NotImplemented,
"Not implemented vector type: {}", "Not implemented vector type: {}",
static_cast<int>(element_type_)); static_cast<int>(element_type_));
@ -320,8 +399,23 @@ class VectorArrayView {
data, data + length_ * dim_); data, data + length_ * dim_);
break; break;
} }
case DataType::VECTOR_BINARY: {
vector_array.set_binary_vector(data_, size_);
break;
}
case DataType::VECTOR_FLOAT16: {
vector_array.set_float16_vector(data_, size_);
break;
}
case DataType::VECTOR_BFLOAT16: {
vector_array.set_bfloat16_vector(data_, size_);
break;
}
case DataType::VECTOR_INT8: {
vector_array.set_int8_vector(data_, size_);
break;
}
default: { default: {
// TODO(SpadeA): add other vector types
ThrowInfo(NotImplemented, ThrowInfo(NotImplemented,
"Not implemented vector type: {}", "Not implemented vector type: {}",
static_cast<int>(element_type_)); static_cast<int>(element_type_));

View File

@ -40,12 +40,57 @@ class VectorArrayChunkTest : public ::testing::Test {
return result; return result;
} }
std::vector<uint16_t>
generateFloat16Vector(int64_t seed, int64_t N, int64_t dim) {
std::vector<uint16_t> result(dim * N);
std::default_random_engine gen(seed);
std::uniform_int_distribution<uint16_t> dist(0, 65535);
for (int64_t i = 0; i < dim * N; ++i) {
result[i] = dist(gen);
}
return result;
}
std::vector<uint16_t>
generateBFloat16Vector(int64_t seed, int64_t N, int64_t dim) {
// Same as Float16 for testing purposes
return generateFloat16Vector(seed, N, dim);
}
std::vector<int8_t>
generateInt8Vector(int64_t seed, int64_t N, int64_t dim) {
std::vector<int8_t> result(dim * N);
std::default_random_engine gen(seed);
std::uniform_int_distribution<int> dist(-128, 127);
for (int64_t i = 0; i < dim * N; ++i) {
result[i] = static_cast<int8_t>(dist(gen));
}
return result;
}
std::vector<uint8_t>
generateBinaryVector(int64_t seed, int64_t N, int64_t dim) {
std::vector<uint8_t> result((dim * N + 7) / 8);
std::default_random_engine gen(seed);
std::uniform_int_distribution<int> dist(0, 255);
for (size_t i = 0; i < result.size(); ++i) {
result[i] = static_cast<uint8_t>(dist(gen));
}
return result;
}
std::shared_ptr<arrow::ListArray> std::shared_ptr<arrow::ListArray>
createFloatVectorListArray(const std::vector<float>& data, createFloatVectorListArray(const std::vector<float>& data,
const std::vector<int32_t>& offsets) { const std::vector<int32_t>& offsets,
auto float_builder = std::make_shared<arrow::FloatBuilder>(); int64_t dim) {
int byte_width = dim * sizeof(float);
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
auto list_builder = std::make_shared<arrow::ListBuilder>( auto list_builder = std::make_shared<arrow::ListBuilder>(
arrow::default_memory_pool(), float_builder); arrow::default_memory_pool(), value_builder);
arrow::Status ast; arrow::Status ast;
for (size_t i = 0; i < offsets.size() - 1; ++i) { for (size_t i = 0; i < offsets.size() - 1; ++i) {
@ -54,8 +99,112 @@ class VectorArrayChunkTest : public ::testing::Test {
int32_t start = offsets[i]; int32_t start = offsets[i];
int32_t end = offsets[i + 1]; int32_t end = offsets[i + 1];
for (int32_t j = start; j < end; ++j) { // Each vector is dim floats
float_builder->Append(data[j]); for (int32_t j = start; j < end; j += dim) {
// Convert float vector to binary
const uint8_t* binary_data =
reinterpret_cast<const uint8_t*>(&data[j]);
ast = value_builder->Append(binary_data);
assert(ast.ok());
}
}
std::shared_ptr<arrow::Array> array;
ast = list_builder->Finish(&array);
assert(ast.ok());
return std::static_pointer_cast<arrow::ListArray>(array);
}
std::shared_ptr<arrow::ListArray>
createFloat16VectorListArray(const std::vector<uint16_t>& data,
const std::vector<int32_t>& offsets,
int64_t dim) {
int byte_width = dim * 2; // 2 bytes per float16
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
auto list_builder = std::make_shared<arrow::ListBuilder>(
arrow::default_memory_pool(), value_builder);
arrow::Status ast;
for (size_t i = 0; i < offsets.size() - 1; ++i) {
ast = list_builder->Append();
assert(ast.ok());
int32_t start = offsets[i];
int32_t end = offsets[i + 1];
for (int32_t j = start; j < end; j += dim) {
const uint8_t* binary_data =
reinterpret_cast<const uint8_t*>(&data[j]);
ast = value_builder->Append(binary_data);
assert(ast.ok());
}
}
std::shared_ptr<arrow::Array> array;
ast = list_builder->Finish(&array);
assert(ast.ok());
return std::static_pointer_cast<arrow::ListArray>(array);
}
std::shared_ptr<arrow::ListArray>
createBFloat16VectorListArray(const std::vector<uint16_t>& data,
const std::vector<int32_t>& offsets,
int64_t dim) {
// Same as Float16 but for bfloat16
return createFloat16VectorListArray(data, offsets, dim);
}
std::shared_ptr<arrow::ListArray>
createInt8VectorListArray(const std::vector<int8_t>& data,
const std::vector<int32_t>& offsets,
int64_t dim) {
int byte_width = dim; // 1 byte per int8
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
auto list_builder = std::make_shared<arrow::ListBuilder>(
arrow::default_memory_pool(), value_builder);
arrow::Status ast;
for (size_t i = 0; i < offsets.size() - 1; ++i) {
ast = list_builder->Append();
assert(ast.ok());
int32_t start = offsets[i];
int32_t end = offsets[i + 1];
for (int32_t j = start; j < end; j += dim) {
const uint8_t* binary_data =
reinterpret_cast<const uint8_t*>(&data[j]);
ast = value_builder->Append(binary_data);
assert(ast.ok());
}
}
std::shared_ptr<arrow::Array> array;
ast = list_builder->Finish(&array);
assert(ast.ok());
return std::static_pointer_cast<arrow::ListArray>(array);
}
std::shared_ptr<arrow::ListArray>
createBinaryVectorListArray(const std::vector<uint8_t>& data,
const std::vector<int32_t>& offsets,
int64_t dim) {
int byte_width = (dim + 7) / 8; // bits packed into bytes
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
auto list_builder = std::make_shared<arrow::ListBuilder>(
arrow::default_memory_pool(), value_builder);
arrow::Status ast;
for (size_t i = 0; i < offsets.size() - 1; ++i) {
ast = list_builder->Append();
assert(ast.ok());
int32_t start = offsets[i];
int32_t end = offsets[i + 1];
for (int32_t j = start; j < end; j += byte_width) {
ast = value_builder->Append(&data[j]);
assert(ast.ok());
} }
} }
@ -66,52 +215,6 @@ class VectorArrayChunkTest : public ::testing::Test {
} }
}; };
TEST_F(VectorArrayChunkTest, TestWriteFloatVectorArray) {
// Test parameters
const int64_t dim = 128;
const int num_rows = 100;
const int vectors_per_row = 5; // Each row contains 5 vectors
// Generate test data
std::vector<float> all_data;
std::vector<int32_t> offsets = {0};
for (int row = 0; row < num_rows; ++row) {
auto row_data = generateFloatVector(row, vectors_per_row, dim);
all_data.insert(all_data.end(), row_data.begin(), row_data.end());
offsets.push_back(offsets.back() + vectors_per_row * dim);
}
// Create Arrow ListArray
auto list_array = createFloatVectorListArray(all_data, offsets);
arrow::ArrayVector array_vec = {list_array};
// Test VectorArrayChunkWriter
VectorArrayChunkWriter writer(dim, DataType::VECTOR_FLOAT);
writer.write(array_vec);
auto chunk = writer.finish();
auto vector_array_chunk = static_cast<VectorArrayChunk*>(chunk.get());
// Verify results
EXPECT_EQ(vector_array_chunk->RowNums(), num_rows);
// Verify data integrity using View method
for (int row = 0; row < num_rows; ++row) {
auto view = vector_array_chunk->View(row);
// Verify by converting back to VectorFieldProto and checking
auto proto = view.output_data();
EXPECT_EQ(proto.dim(), dim);
EXPECT_EQ(proto.float_vector().data_size(), vectors_per_row * dim);
const float* expected = all_data.data() + row * vectors_per_row * dim;
for (int i = 0; i < vectors_per_row * dim; ++i) {
EXPECT_FLOAT_EQ(proto.float_vector().data(i), expected[i]);
}
}
}
TEST_F(VectorArrayChunkTest, TestWriteMultipleBatches) { TEST_F(VectorArrayChunkTest, TestWriteMultipleBatches) {
const int64_t dim = 64; const int64_t dim = 64;
const int batch_size = 50; const int batch_size = 50;
@ -137,7 +240,7 @@ TEST_F(VectorArrayChunkTest, TestWriteMultipleBatches) {
all_batch_data.push_back(batch_data); all_batch_data.push_back(batch_data);
array_vec.push_back( array_vec.push_back(
createFloatVectorListArray(batch_data, batch_offsets)); createFloatVectorListArray(batch_data, batch_offsets, dim));
} }
// Write using VectorArrayChunkWriter // Write using VectorArrayChunkWriter
@ -191,7 +294,7 @@ TEST_F(VectorArrayChunkTest, TestWriteWithMmap) {
offsets.push_back(offsets.back() + vectors_per_row * dim); offsets.push_back(offsets.back() + vectors_per_row * dim);
} }
auto list_array = createFloatVectorListArray(all_data, offsets); auto list_array = createFloatVectorListArray(all_data, offsets, dim);
arrow::ArrayVector array_vec = {list_array}; arrow::ArrayVector array_vec = {list_array};
// Write with mmap // Write with mmap
@ -235,3 +338,189 @@ TEST_F(VectorArrayChunkTest, TestEmptyVectorArray) {
EXPECT_EQ(vector_array_chunk->RowNums(), 0); EXPECT_EQ(vector_array_chunk->RowNums(), 0);
} }
struct VectorArrayTestParam {
DataType data_type;
int64_t dim;
int num_rows;
int vectors_per_row;
std::string test_name;
};
class VectorArrayChunkParameterizedTest
: public VectorArrayChunkTest,
public ::testing::WithParamInterface<VectorArrayTestParam> {};
template <typename T>
std::shared_ptr<arrow::ListArray>
createVectorListArray(const std::vector<T>& data,
const std::vector<int32_t>& offsets,
int64_t dim,
DataType dtype) {
int byte_width;
switch (dtype) {
case DataType::VECTOR_FLOAT:
byte_width = dim * sizeof(float);
break;
case DataType::VECTOR_FLOAT16:
case DataType::VECTOR_BFLOAT16:
byte_width = dim * 2;
break;
case DataType::VECTOR_INT8:
byte_width = dim;
break;
case DataType::VECTOR_BINARY:
byte_width = (dim + 7) / 8;
break;
default:
throw std::invalid_argument("Unsupported data type");
}
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
auto list_builder = std::make_shared<arrow::ListBuilder>(
arrow::default_memory_pool(), value_builder);
arrow::Status ast;
int element_size = (dtype == DataType::VECTOR_BINARY) ? byte_width : dim;
for (size_t i = 0; i < offsets.size() - 1; ++i) {
ast = list_builder->Append();
assert(ast.ok());
int32_t start = offsets[i];
int32_t end = offsets[i + 1];
for (int32_t j = start; j < end; j += element_size) {
const uint8_t* binary_data =
reinterpret_cast<const uint8_t*>(&data[j]);
ast = value_builder->Append(binary_data);
assert(ast.ok());
}
}
std::shared_ptr<arrow::Array> array;
ast = list_builder->Finish(&array);
assert(ast.ok());
return std::static_pointer_cast<arrow::ListArray>(array);
}
TEST_P(VectorArrayChunkParameterizedTest, TestWriteVectorArray) {
    auto param = GetParam();
    // Generate test data based on type.
    // `all_data` holds every row's vectors flattened into one byte buffer;
    // `offsets` records the byte offset where each row begins, so row i
    // occupies [offsets[i], offsets[i + 1]).
    std::vector<uint8_t> all_data;
    std::vector<int32_t> offsets = {0};
    for (int row = 0; row < param.num_rows; ++row) {
        std::vector<uint8_t> row_data;
        switch (param.data_type) {
            case DataType::VECTOR_FLOAT: {
                auto float_data =
                    generateFloatVector(row, param.vectors_per_row, param.dim);
                row_data.resize(float_data.size() * sizeof(float));
                memcpy(row_data.data(), float_data.data(), row_data.size());
                break;
            }
            case DataType::VECTOR_FLOAT16:
            case DataType::VECTOR_BFLOAT16: {
                auto uint16_data = generateFloat16Vector(
                    row, param.vectors_per_row, param.dim);
                row_data.resize(uint16_data.size() * sizeof(uint16_t));
                memcpy(row_data.data(), uint16_data.data(), row_data.size());
                break;
            }
            case DataType::VECTOR_INT8: {
                auto int8_data =
                    generateInt8Vector(row, param.vectors_per_row, param.dim);
                row_data.resize(int8_data.size());
                memcpy(row_data.data(), int8_data.data(), row_data.size());
                break;
            }
            case DataType::VECTOR_BINARY: {
                row_data =
                    generateBinaryVector(row, param.vectors_per_row, param.dim);
                break;
            }
            default:
                FAIL() << "Unsupported data type";
        }
        all_data.insert(all_data.end(), row_data.begin(), row_data.end());
        // The next row starts immediately after this row's bytes. Using
        // row_data.size() instead of re-deriving the per-type byte width
        // (sizeof(float), 2 bytes for fp16/bf16, (dim + 7) / 8 for binary,
        // 1 byte for int8) keeps the offset arithmetic from silently
        // drifting out of sync with the generators above.
        offsets.push_back(offsets.back() +
                          static_cast<int32_t>(row_data.size()));
    }
    // Create Arrow ListArray from the flattened bytes and row offsets.
    auto list_array =
        createVectorListArray(all_data, offsets, param.dim, param.data_type);
    arrow::ArrayVector array_vec = {list_array};
    // Test VectorArrayChunkWriter: serialize the rows into a chunk.
    VectorArrayChunkWriter writer(param.dim, param.data_type);
    writer.write(array_vec);
    auto chunk = writer.finish();
    auto vector_array_chunk = static_cast<VectorArrayChunk*>(chunk.get());
    // Verify results: the chunk must report one row per input row.
    EXPECT_EQ(vector_array_chunk->RowNums(), param.num_rows);
    // Basic verification - ensure View doesn't crash and returns valid data
    // with the right proto field populated for each element type.
    for (int row = 0; row < param.num_rows; ++row) {
        auto view = vector_array_chunk->View(row);
        auto proto = view.output_data();
        EXPECT_EQ(proto.dim(), param.dim);
        // Verify the correct field is populated based on data type
        switch (param.data_type) {
            case DataType::VECTOR_FLOAT:
                EXPECT_GT(proto.float_vector().data_size(), 0);
                break;
            case DataType::VECTOR_FLOAT16:
                EXPECT_GT(proto.float16_vector().size(), 0);
                break;
            case DataType::VECTOR_BFLOAT16:
                EXPECT_GT(proto.bfloat16_vector().size(), 0);
                break;
            case DataType::VECTOR_INT8:
                EXPECT_GT(proto.int8_vector().size(), 0);
                break;
            case DataType::VECTOR_BINARY:
                EXPECT_GT(proto.binary_vector().size(), 0);
                break;
            default:
                break;
        }
    }
}
// Instantiate the parameterized suite once per supported vector element
// type. The trailing name generator turns each parameter set's test_name
// into the test's human-readable suffix (e.g. ".../FloatVector").
INSTANTIATE_TEST_SUITE_P(
    VectorTypes,
    VectorArrayChunkParameterizedTest,
    ::testing::Values(
        VectorArrayTestParam{
            DataType::VECTOR_FLOAT, 128, 100, 5, "FloatVector"},
        VectorArrayTestParam{
            DataType::VECTOR_FLOAT16, 128, 100, 3, "Float16Vector"},
        VectorArrayTestParam{
            DataType::VECTOR_BFLOAT16, 64, 50, 2, "BFloat16Vector"},
        VectorArrayTestParam{DataType::VECTOR_INT8, 256, 80, 4, "Int8Vector"},
        VectorArrayTestParam{
            DataType::VECTOR_BINARY, 512, 60, 3, "BinaryVector"}),
    [](const testing::TestParamInfo<VectorArrayTestParam>& param_info) {
        return param_info.param.test_name;
    });

View File

@ -159,8 +159,12 @@ class TestVectorArrayStorageV2 : public testing::Test {
// Create appropriate value builder based on element type // Create appropriate value builder based on element type
std::shared_ptr<arrow::ArrayBuilder> value_builder; std::shared_ptr<arrow::ArrayBuilder> value_builder;
int byte_width = 0;
if (element_type == DataType::VECTOR_FLOAT) { if (element_type == DataType::VECTOR_FLOAT) {
value_builder = std::make_shared<arrow::FloatBuilder>(); byte_width = DIM * sizeof(float);
value_builder =
std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
} else { } else {
FAIL() << "Unsupported element type for VECTOR_ARRAY " FAIL() << "Unsupported element type for VECTOR_ARRAY "
"in test"; "in test";
@ -176,11 +180,13 @@ class TestVectorArrayStorageV2 : public testing::Test {
// Generate 3 vectors for this row // Generate 3 vectors for this row
auto data = generate_float_vector(3, DIM); auto data = generate_float_vector(3, DIM);
auto float_builder = auto binary_builder = std::static_pointer_cast<
std::static_pointer_cast<arrow::FloatBuilder>( arrow::FixedSizeBinaryBuilder>(value_builder);
value_builder); // Append each vector as a fixed-size binary value
for (const auto& value : data) { for (int vec_idx = 0; vec_idx < 3; vec_idx++) {
status = float_builder->Append(value); status = binary_builder->Append(
reinterpret_cast<const uint8_t*>(
data.data() + vec_idx * DIM));
EXPECT_TRUE(status.ok()); EXPECT_TRUE(status.ok());
} }
} }
@ -288,7 +294,7 @@ TEST_F(TestVectorArrayStorageV2, BuildEmbListHNSWIndex) {
milvus::index::CreateIndexInfo create_index_info; milvus::index::CreateIndexInfo create_index_info;
create_index_info.field_type = DataType::VECTOR_ARRAY; create_index_info.field_type = DataType::VECTOR_ARRAY;
create_index_info.metric_type = knowhere::metric::MAX_SIM; create_index_info.metric_type = knowhere::metric::MAX_SIM;
create_index_info.index_type = knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW;
create_index_info.index_engine_version = create_index_info.index_engine_version =
knowhere::Version::GetCurrentVersion().VersionNumber(); knowhere::Version::GetCurrentVersion().VersionNumber();
@ -299,8 +305,7 @@ TEST_F(TestVectorArrayStorageV2, BuildEmbListHNSWIndex) {
// Build index with storage v2 configuration // Build index with storage v2 configuration
Config config; Config config;
config[milvus::index::INDEX_TYPE] = config[milvus::index::INDEX_TYPE] = knowhere::IndexEnum::INDEX_HNSW;
knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type; config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type;
config[knowhere::indexparam::M] = "16"; config[knowhere::indexparam::M] = "16";
config[knowhere::indexparam::EF] = "10"; config[knowhere::indexparam::EF] = "10";
@ -330,11 +335,12 @@ TEST_F(TestVectorArrayStorageV2, BuildEmbListHNSWIndex) {
std::vector<float> query_vec = generate_float_vector(vec_num, DIM); std::vector<float> query_vec = generate_float_vector(vec_num, DIM);
auto query_dataset = auto query_dataset =
knowhere::GenDataSet(vec_num, DIM, query_vec.data()); knowhere::GenDataSet(vec_num, DIM, query_vec.data());
std::vector<size_t> query_vec_lims; std::vector<size_t> query_vec_offsets;
query_vec_lims.push_back(0); query_vec_offsets.push_back(0);
query_vec_lims.push_back(3); query_vec_offsets.push_back(3);
query_vec_lims.push_back(10); query_vec_offsets.push_back(10);
query_dataset->SetLims(query_vec_lims.data()); query_dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
const_cast<const size_t*>(query_vec_offsets.data()));
auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}}; auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
milvus::SearchInfo searchInfo; milvus::SearchInfo searchInfo;

View File

@ -32,19 +32,35 @@ namespace milvus {
using elem_type = std::conditional_t< \ using elem_type = std::conditional_t< \
std::is_same_v<TraitType, milvus::EmbListFloatVector>, \ std::is_same_v<TraitType, milvus::EmbListFloatVector>, \
milvus::EmbListFloatVector::embedded_type, \ milvus::EmbListFloatVector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::EmbListBinaryVector>, \
milvus::EmbListBinaryVector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::EmbListFloat16Vector>, \
milvus::EmbListFloat16Vector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::EmbListBFloat16Vector>, \
milvus::EmbListBFloat16Vector::embedded_type, \
std::conditional_t< \
std::is_same_v<TraitType, milvus::EmbListInt8Vector>, \
milvus::EmbListInt8Vector::embedded_type, \
std::conditional_t< \ std::conditional_t< \
std::is_same_v<TraitType, milvus::FloatVector>, \ std::is_same_v<TraitType, milvus::FloatVector>, \
milvus::FloatVector::embedded_type, \ milvus::FloatVector::embedded_type, \
std::conditional_t< \ std::conditional_t< \
std::is_same_v<TraitType, milvus::Float16Vector>, \ std::is_same_v<TraitType, \
milvus::Float16Vector>, \
milvus::Float16Vector::embedded_type, \ milvus::Float16Vector::embedded_type, \
std::conditional_t< \ std::conditional_t< \
std::is_same_v<TraitType, milvus::BFloat16Vector>, \ std::is_same_v<TraitType, \
milvus::BFloat16Vector>, \
milvus::BFloat16Vector::embedded_type, \ milvus::BFloat16Vector::embedded_type, \
std::conditional_t< \ std::conditional_t< \
std::is_same_v<TraitType, milvus::Int8Vector>, \ std::is_same_v<TraitType, \
milvus::Int8Vector>, \
milvus::Int8Vector::embedded_type, \ milvus::Int8Vector::embedded_type, \
milvus::BinaryVector::embedded_type>>>>>; milvus::BinaryVector:: \
embedded_type>>>>>>>>>;
#define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \ #define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \
auto schema_data_type = \ auto schema_data_type = \
@ -164,6 +180,82 @@ class EmbListFloatVector : public VectorTrait {
} }
}; };
class EmbListBinaryVector : public VectorTrait {
public:
using embedded_type = uint8_t;
static constexpr int32_t dim_factor = 8;
static constexpr auto data_type = DataType::VECTOR_ARRAY;
static constexpr auto c_data_type = CDataType::VectorArray;
static constexpr auto schema_data_type =
proto::schema::DataType::ArrayOfVector;
static constexpr auto vector_type =
proto::plan::VectorType::EmbListBinaryVector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::EmbListBinaryVector;
static constexpr bool
is_embedding_list() {
return true;
}
};
class EmbListFloat16Vector : public VectorTrait {
public:
using embedded_type = float16;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_ARRAY;
static constexpr auto c_data_type = CDataType::VectorArray;
static constexpr auto schema_data_type =
proto::schema::DataType::ArrayOfVector;
static constexpr auto vector_type =
proto::plan::VectorType::EmbListFloat16Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::EmbListFloat16Vector;
static constexpr bool
is_embedding_list() {
return true;
}
};
class EmbListBFloat16Vector : public VectorTrait {
public:
using embedded_type = bfloat16;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_ARRAY;
static constexpr auto c_data_type = CDataType::VectorArray;
static constexpr auto schema_data_type =
proto::schema::DataType::ArrayOfVector;
static constexpr auto vector_type =
proto::plan::VectorType::EmbListBFloat16Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::EmbListBFloat16Vector;
static constexpr bool
is_embedding_list() {
return true;
}
};
class EmbListInt8Vector : public VectorTrait {
public:
using embedded_type = int8;
static constexpr int32_t dim_factor = 1;
static constexpr auto data_type = DataType::VECTOR_ARRAY;
static constexpr auto c_data_type = CDataType::VectorArray;
static constexpr auto schema_data_type =
proto::schema::DataType::ArrayOfVector;
static constexpr auto vector_type =
proto::plan::VectorType::EmbListInt8Vector;
static constexpr auto placeholder_type =
proto::common::PlaceholderType::EmbListInt8Vector;
static constexpr bool
is_embedding_list() {
return true;
}
};
struct FundamentalTag {}; struct FundamentalTag {};
struct StringTag {}; struct StringTag {};

View File

@ -76,7 +76,7 @@ PhyVectorSearchNode::GetOutput() {
auto& ph = placeholder_group_->at(0); auto& ph = placeholder_group_->at(0);
auto src_data = ph.get_blob(); auto src_data = ph.get_blob();
auto src_lims = ph.get_lims(); auto src_offsets = ph.get_offsets();
auto num_queries = ph.num_of_queries_; auto num_queries = ph.num_of_queries_;
milvus::SearchResult search_result; milvus::SearchResult search_result;
@ -94,7 +94,7 @@ PhyVectorSearchNode::GetOutput() {
auto op_context = query_context_->get_op_context(); auto op_context = query_context_->get_op_context();
segment_->vector_search(search_info_, segment_->vector_search(search_info_,
src_data, src_data,
src_lims, src_offsets,
num_queries, num_queries,
query_timestamp_, query_timestamp_,
final_view, final_view,

View File

@ -234,9 +234,42 @@ IndexFactory::VecIndexLoadResource(
num_rows, num_rows,
dim, dim,
config); config);
has_raw_data = break;
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData( case milvus::DataType::VECTOR_FLOAT16:
index_type, index_version, config); resource = knowhere::IndexStaticFaced<knowhere::fp16>::
EstimateLoadResource(index_type,
index_version,
index_size_in_bytes,
num_rows,
dim,
config);
break;
case milvus::DataType::VECTOR_BFLOAT16:
resource = knowhere::IndexStaticFaced<knowhere::bf16>::
EstimateLoadResource(index_type,
index_version,
index_size_in_bytes,
num_rows,
dim,
config);
break;
case milvus::DataType::VECTOR_BINARY:
resource = knowhere::IndexStaticFaced<knowhere::bin1>::
EstimateLoadResource(index_type,
index_version,
index_size_in_bytes,
num_rows,
dim,
config);
break;
case milvus::DataType::VECTOR_INT8:
resource = knowhere::IndexStaticFaced<knowhere::int8>::
EstimateLoadResource(index_type,
index_version,
index_size_in_bytes,
num_rows,
dim,
config);
break; break;
default: default:
@ -247,6 +280,9 @@ IndexFactory::VecIndexLoadResource(
element_type); element_type);
return LoadResourceRequest{0, 0, 0, 0, true}; return LoadResourceRequest{0, 0, 0, 0, true};
} }
// For VectorArray, has_raw_data is always false as get_vector of index does not provide offsets which
// is required for reconstructing the raw data
has_raw_data = false;
break; break;
} }
default: default:
@ -641,6 +677,42 @@ IndexFactory::CreateVectorIndex(
version, version,
use_knowhere_build_pool, use_knowhere_build_pool,
file_manager_context); file_manager_context);
case DataType::VECTOR_FLOAT16: {
return std::make_unique<VectorMemIndex<float16>>(
element_type,
index_type,
metric_type,
version,
use_knowhere_build_pool,
file_manager_context);
}
case DataType::VECTOR_BFLOAT16: {
return std::make_unique<VectorMemIndex<bfloat16>>(
element_type,
index_type,
metric_type,
version,
use_knowhere_build_pool,
file_manager_context);
}
case DataType::VECTOR_BINARY: {
return std::make_unique<VectorMemIndex<bin1>>(
element_type,
index_type,
metric_type,
version,
use_knowhere_build_pool,
file_manager_context);
}
case DataType::VECTOR_INT8: {
return std::make_unique<VectorMemIndex<int8>>(
element_type,
index_type,
metric_type,
version,
use_knowhere_build_pool,
file_manager_context);
}
default: default:
ThrowInfo(NotImplemented, ThrowInfo(NotImplemented,
fmt::format("not implemented data type to " fmt::format("not implemented data type to "

View File

@ -234,21 +234,4 @@ void inline SetBitsetGrowing(void* bitset,
} }
} }
inline size_t
vector_element_size(const DataType data_type) {
switch (data_type) {
case DataType::VECTOR_FLOAT:
return sizeof(float);
case DataType::VECTOR_FLOAT16:
return sizeof(float16);
case DataType::VECTOR_BFLOAT16:
return sizeof(bfloat16);
case DataType::VECTOR_INT8:
return sizeof(int8);
default:
ThrowInfo(UnexpectedError,
fmt::format("invalid data type: {}", data_type));
}
}
} // namespace milvus::index } // namespace milvus::index

View File

@ -313,11 +313,6 @@ VectorMemIndex<T>::BuildWithDataset(const DatasetPtr& dataset,
SetDim(index_.Dim()); SetDim(index_.Dim());
} }
bool
is_embedding_list_index(const IndexType& index_type) {
return index_type == knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
}
template <typename T> template <typename T>
void void
VectorMemIndex<T>::Build(const Config& config) { VectorMemIndex<T>::Build(const Config& config) {
@ -344,30 +339,19 @@ VectorMemIndex<T>::Build(const Config& config) {
total_size += data->Size(); total_size += data->Size();
total_num_rows += data->get_num_rows(); total_num_rows += data->get_num_rows();
// todo(SapdeA): now, vector arrays (embedding list) are serialized
// to parquet by using binary format which does not provide dim
// information so we use this temporary solution.
if (is_embedding_list_index(index_type_)) {
AssertInfo(elem_type_ != DataType::NONE,
"embedding list index must have elem_type");
dim = config[DIM_KEY].get<int64_t>();
} else {
AssertInfo(dim == 0 || dim == data->get_dim(), AssertInfo(dim == 0 || dim == data->get_dim(),
"inconsistent dim value between field datas!"); "inconsistent dim value between field datas!");
dim = data->get_dim(); dim = data->get_dim();
} }
}
auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[total_size]); auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[total_size]);
size_t lim_offset = 0; size_t lim_offset = 0;
std::vector<size_t> lims; std::vector<size_t> offsets;
lims.reserve(total_num_rows + 1);
lims.push_back(lim_offset);
int64_t offset = 0; int64_t offset = 0;
if (!is_embedding_list_index(index_type_)) { // For embedding list index, elem_type_ is not NONE
if (elem_type_ == DataType::NONE) {
// TODO: avoid copying // TODO: avoid copying
for (auto data : field_datas) { for (auto data : field_datas) {
std::memcpy(buf.get() + offset, data->Data(), data->Size()); std::memcpy(buf.get() + offset, data->Data(), data->Size());
@ -375,7 +359,9 @@ VectorMemIndex<T>::Build(const Config& config) {
data.reset(); data.reset();
} }
} else { } else {
auto elem_size = vector_element_size(elem_type_); offsets.reserve(total_num_rows + 1);
offsets.push_back(lim_offset);
auto bytes_per_vec = vector_bytes_per_element(elem_type_, dim);
for (auto data : field_datas) { for (auto data : field_datas) {
auto vec_array_data = auto vec_array_data =
dynamic_cast<FieldData<VectorArray>*>(data.get()); dynamic_cast<FieldData<VectorArray>*>(data.get());
@ -385,16 +371,16 @@ VectorMemIndex<T>::Build(const Config& config) {
auto rows = vec_array_data->get_num_rows(); auto rows = vec_array_data->get_num_rows();
for (auto i = 0; i < rows; ++i) { for (auto i = 0; i < rows; ++i) {
auto size = vec_array_data->DataSize(i); auto size = vec_array_data->DataSize(i);
assert(size % (dim * elem_size) == 0); assert(size % bytes_per_vec == 0);
assert(dim * elem_size != 0); assert(bytes_per_vec != 0);
auto vec_array = vec_array_data->value_at(i); auto vec_array = vec_array_data->value_at(i);
std::memcpy(buf.get() + offset, vec_array->data(), size); std::memcpy(buf.get() + offset, vec_array->data(), size);
offset += size; offset += size;
lim_offset += size / (dim * elem_size); lim_offset += size / bytes_per_vec;
lims.push_back(lim_offset); offsets.push_back(lim_offset);
} }
assert(data->Size() == offset); assert(data->Size() == offset);
@ -411,8 +397,9 @@ VectorMemIndex<T>::Build(const Config& config) {
if (!scalar_info.empty()) { if (!scalar_info.empty()) {
dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info)); dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info));
} }
if (!lims.empty()) { if (!offsets.empty()) {
dataset->SetLims(lims.data()); dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
const_cast<const size_t*>(offsets.data()));
} }
BuildWithDataset(dataset, build_config); BuildWithDataset(dataset, build_config);
} else { } else {

View File

@ -282,12 +282,12 @@ class ChunkedColumnBase : public ChunkedColumnInterface {
"VectorArrayViews only supported for ChunkedVectorArrayColumn"); "VectorArrayViews only supported for ChunkedVectorArrayColumn");
} }
virtual PinWrapper<const size_t*> PinWrapper<const size_t*>
VectorArrayLims(milvus::OpContext* op_ctx, VectorArrayOffsets(milvus::OpContext* op_ctx,
int64_t chunk_id) const override { int64_t chunk_id) const override {
ThrowInfo( ThrowInfo(
ErrorCode::Unsupported, ErrorCode::Unsupported,
"VectorArrayLims only supported for ChunkedVectorArrayColumn"); "VectorArrayOffsets only supported for ChunkedVectorArrayColumn");
} }
PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>> PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>
@ -691,13 +691,13 @@ class ChunkedVectorArrayColumn : public ChunkedColumnBase {
} }
PinWrapper<const size_t*> PinWrapper<const size_t*>
VectorArrayLims(milvus::OpContext* op_ctx, VectorArrayOffsets(milvus::OpContext* op_ctx,
int64_t chunk_id) const override { int64_t chunk_id) const override {
auto ca = SemiInlineGet( auto ca = SemiInlineGet(
slot_->PinCells(op_ctx, {static_cast<cid_t>(chunk_id)})); slot_->PinCells(op_ctx, {static_cast<cid_t>(chunk_id)}));
auto chunk = ca->get_cell_of(chunk_id); auto chunk = ca->get_cell_of(chunk_id);
return PinWrapper<const size_t*>( return PinWrapper<const size_t*>(
ca, static_cast<VectorArrayChunk*>(chunk)->Lims()); ca, static_cast<VectorArrayChunk*>(chunk)->Offsets());
} }
}; };

View File

@ -328,17 +328,18 @@ class ProxyChunkColumn : public ChunkedColumnInterface {
} }
PinWrapper<const size_t*> PinWrapper<const size_t*>
VectorArrayLims(milvus::OpContext* op_ctx, VectorArrayOffsets(milvus::OpContext* op_ctx,
int64_t chunk_id) const override { int64_t chunk_id) const override {
if (!IsChunkedVectorArrayColumnDataType(data_type_)) { if (!IsChunkedVectorArrayColumnDataType(data_type_)) {
ThrowInfo(ErrorCode::Unsupported, ThrowInfo(ErrorCode::Unsupported,
"VectorArrayLims only supported for " "VectorArrayOffsets only supported for "
"ChunkedVectorArrayColumn"); "ChunkedVectorArrayColumn");
} }
auto chunk_wrapper = group_->GetGroupChunk(op_ctx, chunk_id); auto chunk_wrapper = group_->GetGroupChunk(op_ctx, chunk_id);
auto chunk = chunk_wrapper.get()->GetChunk(field_id_); auto chunk = chunk_wrapper.get()->GetChunk(field_id_);
return PinWrapper<const size_t*>( return PinWrapper<const size_t*>(
chunk_wrapper, static_cast<VectorArrayChunk*>(chunk.get())->Lims()); chunk_wrapper,
static_cast<VectorArrayChunk*>(chunk.get())->Offsets());
} }
PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>> PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>

View File

@ -92,7 +92,7 @@ class ChunkedColumnInterface {
std::optional<std::pair<int64_t, int64_t>> offset_len) const = 0; std::optional<std::pair<int64_t, int64_t>> offset_len) const = 0;
virtual PinWrapper<const size_t*> virtual PinWrapper<const size_t*>
VectorArrayLims(milvus::OpContext* op_ctx, int64_t chunk_id) const = 0; VectorArrayOffsets(milvus::OpContext* op_ctx, int64_t chunk_id) const = 0;
virtual PinWrapper< virtual PinWrapper<
std::pair<std::vector<std::string_view>, FixedVector<bool>>> std::pair<std::vector<std::string_view>, FixedVector<bool>>>

View File

@ -34,8 +34,23 @@ bool
check_data_type(const FieldMeta& field_meta, check_data_type(const FieldMeta& field_meta,
const milvus::proto::common::PlaceholderType type) { const milvus::proto::common::PlaceholderType type) {
if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
if (field_meta.get_element_type() == DataType::VECTOR_FLOAT) {
return type == return type ==
milvus::proto::common::PlaceholderType::EmbListFloatVector; milvus::proto::common::PlaceholderType::EmbListFloatVector;
} else if (field_meta.get_element_type() == DataType::VECTOR_FLOAT16) {
return type ==
milvus::proto::common::PlaceholderType::EmbListFloat16Vector;
} else if (field_meta.get_element_type() == DataType::VECTOR_BFLOAT16) {
return type == milvus::proto::common::PlaceholderType::
EmbListBFloat16Vector;
} else if (field_meta.get_element_type() == DataType::VECTOR_BINARY) {
return type ==
milvus::proto::common::PlaceholderType::EmbListBinaryVector;
} else if (field_meta.get_element_type() == DataType::VECTOR_INT8) {
return type ==
milvus::proto::common::PlaceholderType::EmbListInt8Vector;
}
return false;
} }
return static_cast<int>(field_meta.get_data_type()) == return static_cast<int>(field_meta.get_data_type()) ==
static_cast<int>(type); static_cast<int>(type);
@ -96,25 +111,25 @@ ParsePlaceholderGroup(const Plan* plan,
// If the vector is embedding list, line contains multiple vectors. // If the vector is embedding list, line contains multiple vectors.
// And we should record the offsets so that we can identify each // And we should record the offsets so that we can identify each
// embedding list in a flattened vectors. // embedding list in a flattened vectors.
auto& lims = element.lims_; auto& offsets = element.offsets_;
lims.reserve(element.num_of_queries_ + 1); offsets.reserve(element.num_of_queries_ + 1);
size_t offset = 0; size_t offset = 0;
lims.push_back(offset); offsets.push_back(offset);
auto elem_size = milvus::index::vector_element_size( auto bytes_per_vec = milvus::vector_bytes_per_element(
field_meta.get_element_type()); field_meta.get_element_type(), dim);
for (auto& line : info.values()) { for (auto& line : info.values()) {
target.insert(target.end(), line.begin(), line.end()); target.insert(target.end(), line.begin(), line.end());
AssertInfo( AssertInfo(
line.size() % (dim * elem_size) == 0, line.size() % bytes_per_vec == 0,
"line.size() % (dim * elem_size) == 0 assert failed, " "line.size() % bytes_per_vec == 0 assert failed, "
"line.size() = {}, dim = {}, elem_size = {}", "line.size() = {}, dim = {}, bytes_per_vec = {}",
line.size(), line.size(),
dim, dim,
elem_size); bytes_per_vec);
offset += line.size() / (dim * elem_size); offset += line.size() / bytes_per_vec;
lims.push_back(offset); offsets.push_back(offset);
} }
} }
} }

View File

@ -70,8 +70,8 @@ struct Plan {
struct Placeholder { struct Placeholder {
std::string tag_; std::string tag_;
// note: for embedding list search, num_of_queries_ stands for the number of vectors. // note: for embedding list search, num_of_queries_ stands for the number of vectors.
// lims_ records the offsets of embedding list in the flattened vector and // offsets_ records the offsets of embedding list in the flattened vector and
// hence lims_.size() - 1 is the number of queries in embedding list search. // hence offsets_.size() - 1 is the number of queries in embedding list search.
int64_t num_of_queries_; int64_t num_of_queries_;
// TODO(SPARSE): add a dim_ field here, use the dim passed in search request // TODO(SPARSE): add a dim_ field here, use the dim passed in search request
// instead of the dim in schema, since the dim of sparse float column is // instead of the dim in schema, since the dim of sparse float column is
@ -84,7 +84,7 @@ struct Placeholder {
std::unique_ptr<knowhere::sparse::SparseRow<SparseValueType>[]> std::unique_ptr<knowhere::sparse::SparseRow<SparseValueType>[]>
sparse_matrix_; sparse_matrix_;
// offsets for embedding list // offsets for embedding list
aligned_vector<size_t> lims_; aligned_vector<size_t> offsets_;
const void* const void*
get_blob() const { get_blob() const {
@ -103,13 +103,13 @@ struct Placeholder {
} }
const size_t* const size_t*
get_lims() const { get_offsets() const {
return lims_.data(); return offsets_.data();
} }
size_t* size_t*
get_lims() { get_offsets() {
return lims_.data(); return offsets_.data();
} }
}; };

View File

@ -89,21 +89,25 @@ PrepareBFDataSet(const dataset::SearchDataset& query_ds,
DataType data_type) { DataType data_type) {
auto base_dataset = auto base_dataset =
knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data); knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data);
if (raw_ds.raw_data_lims != nullptr) { if (raw_ds.raw_data_offsets != nullptr) {
// knowhere::DataSet count vectors in a flattened manner where as the num_raw_data here is the number // knowhere::DataSet count vectors in a flattened manner where as the num_raw_data here is the number
// of embedding lists where each embedding list contains multiple vectors. So we should use the last element // of embedding lists where each embedding list contains multiple vectors. So we should use the last element
// in lims which equals to the total number of vectors. // in offsets which equals to the total number of vectors.
base_dataset->SetLims(raw_ds.raw_data_lims); base_dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
// the length of lims equals to the number of embedding lists + 1 raw_ds.raw_data_offsets);
base_dataset->SetRows(raw_ds.raw_data_lims[raw_ds.num_raw_data]);
// the length of offsets equals to the number of embedding lists + 1
base_dataset->SetRows(raw_ds.raw_data_offsets[raw_ds.num_raw_data]);
} }
auto query_dataset = knowhere::GenDataSet( auto query_dataset = knowhere::GenDataSet(
query_ds.num_queries, query_ds.dim, query_ds.query_data); query_ds.num_queries, query_ds.dim, query_ds.query_data);
if (query_ds.query_lims != nullptr) { if (query_ds.query_offsets != nullptr) {
// ditto // ditto
query_dataset->SetLims(query_ds.query_lims); query_dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
query_dataset->SetRows(query_ds.query_lims[query_ds.num_queries]); query_ds.query_offsets);
query_dataset->SetRows(query_ds.query_offsets[query_ds.num_queries]);
} }
if (data_type == DataType::VECTOR_SPARSE_U32_F32) { if (data_type == DataType::VECTOR_SPARSE_U32_F32) {

View File

@ -73,7 +73,7 @@ void
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
const SearchInfo& info, const SearchInfo& info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t num_queries, int64_t num_queries,
Timestamp timestamp, Timestamp timestamp,
const BitsetView& bitset, const BitsetView& bitset,
@ -141,7 +141,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
round_decimal, round_decimal,
dim, dim,
query_data, query_data,
query_lims}; query_offsets};
int32_t current_chunk_id = 0; int32_t current_chunk_id = 0;
// get K1 and B from index for bm25 brute force // get K1 and B from index for bm25 brute force
@ -222,8 +222,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
if (data_type == DataType::VECTOR_ARRAY) { if (data_type == DataType::VECTOR_ARRAY) {
AssertInfo( AssertInfo(
query_lims != nullptr, query_offsets != nullptr,
"query_lims is nullptr, but data_type is vector array"); "query_offsets is nullptr, but data_type is vector array");
} }
if (milvus::exec::UseVectorIterator(info)) { if (milvus::exec::UseVectorIterator(info)) {

View File

@ -21,7 +21,7 @@ void
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
const SearchInfo& info, const SearchInfo& info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t num_queries, int64_t num_queries,
Timestamp timestamp, Timestamp timestamp,
const BitsetView& bitset, const BitsetView& bitset,

View File

@ -31,7 +31,7 @@ SearchOnSealedIndex(const Schema& schema,
const segcore::SealedIndexingRecord& record, const segcore::SealedIndexingRecord& record,
const SearchInfo& search_info, const SearchInfo& search_info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t num_queries, int64_t num_queries,
const BitsetView& bitset, const BitsetView& bitset,
milvus::OpContext* op_context, milvus::OpContext* op_context,
@ -55,15 +55,15 @@ SearchOnSealedIndex(const Schema& schema,
search_info.metric_type_); search_info.metric_type_);
knowhere::DataSetPtr dataset; knowhere::DataSetPtr dataset;
if (query_lims == nullptr) { if (query_offsets == nullptr) {
dataset = knowhere::GenDataSet(num_queries, dim, query_data); dataset = knowhere::GenDataSet(num_queries, dim, query_data);
} else { } else {
// Rather than non-embedding list search where num_queries equals to the number of vectors, // Rather than non-embedding list search where num_queries equals to the number of vectors,
// in embedding list search, multiple vectors form an embedding list and the last element of query_lims // in embedding list search, multiple vectors form an embedding list and the last element of query_offsets
// stands for the total number of vectors. // stands for the total number of vectors.
auto num_vectors = query_lims[num_queries]; auto num_vectors = query_offsets[num_queries];
dataset = knowhere::GenDataSet(num_vectors, dim, query_data); dataset = knowhere::GenDataSet(num_vectors, dim, query_data);
dataset->SetLims(query_lims); dataset->Set(knowhere::meta::EMB_LIST_OFFSET, query_offsets);
} }
dataset->SetIsSparse(is_sparse); dataset->SetIsSparse(is_sparse);
@ -107,7 +107,7 @@ SearchOnSealedColumn(const Schema& schema,
const SearchInfo& search_info, const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info, const std::map<std::string, std::string>& index_info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t num_queries, int64_t num_queries,
int64_t row_count, int64_t row_count,
const BitsetView& bitview, const BitsetView& bitview,
@ -128,7 +128,7 @@ SearchOnSealedColumn(const Schema& schema,
search_info.round_decimal_, search_info.round_decimal_,
dim, dim,
query_data, query_data,
query_lims}; query_offsets};
CheckBruteForceSearchParam(field, search_info); CheckBruteForceSearchParam(field, search_info);
@ -158,13 +158,14 @@ SearchOnSealedColumn(const Schema& schema,
auto raw_dataset = auto raw_dataset =
query::dataset::RawDataset{offset, dim, chunk_size, vec_data}; query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
PinWrapper<const size_t*> lims_pw; PinWrapper<const size_t*> offsets_pw;
if (data_type == DataType::VECTOR_ARRAY) { if (data_type == DataType::VECTOR_ARRAY) {
AssertInfo(query_lims != nullptr, AssertInfo(
"query_lims is nullptr, but data_type is vector array"); query_offsets != nullptr,
"query_offsets is nullptr, but data_type is vector array");
lims_pw = column->VectorArrayLims(op_context, i); offsets_pw = column->VectorArrayOffsets(op_context, i);
raw_dataset.raw_data_lims = lims_pw.get(); raw_dataset.raw_data_offsets = offsets_pw.get();
} }
if (milvus::exec::UseVectorIterator(search_info)) { if (milvus::exec::UseVectorIterator(search_info)) {

View File

@ -23,7 +23,7 @@ SearchOnSealedIndex(const Schema& schema,
const segcore::SealedIndexingRecord& record, const segcore::SealedIndexingRecord& record,
const SearchInfo& search_info, const SearchInfo& search_info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t num_queries, int64_t num_queries,
const BitsetView& view, const BitsetView& view,
milvus::OpContext* op_context, milvus::OpContext* op_context,
@ -35,7 +35,7 @@ SearchOnSealedColumn(const Schema& schema,
const SearchInfo& search_info, const SearchInfo& search_info,
const std::map<std::string, std::string>& index_info, const std::map<std::string, std::string>& index_info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t num_queries, int64_t num_queries,
int64_t row_count, int64_t row_count,
const BitsetView& bitset, const BitsetView& bitset,

View File

@ -24,7 +24,7 @@ struct RawDataset {
int64_t dim; int64_t dim;
int64_t num_raw_data; int64_t num_raw_data;
const void* raw_data; const void* raw_data;
const size_t* raw_data_lims = nullptr; const size_t* raw_data_offsets = nullptr;
}; };
struct SearchDataset { struct SearchDataset {
knowhere::MetricType metric_type; knowhere::MetricType metric_type;
@ -34,7 +34,7 @@ struct SearchDataset {
int64_t dim; int64_t dim;
const void* query_data; const void* query_data;
// used for embedding list query // used for embedding list query
const size_t* query_lims = nullptr; const size_t* query_offsets = nullptr;
}; };
} // namespace dataset } // namespace dataset

View File

@ -756,7 +756,7 @@ ChunkedSegmentSealedImpl::mask_with_delete(BitsetTypeView& bitset,
void void
ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t query_count, int64_t query_count,
Timestamp timestamp, Timestamp timestamp,
const BitsetView& bitset, const BitsetView& bitset,
@ -783,7 +783,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
vector_indexings_, vector_indexings_,
binlog_search_info, binlog_search_info,
query_data, query_data,
query_lims, query_offsets,
query_count, query_count,
bitset, bitset,
op_context, op_context,
@ -798,7 +798,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
vector_indexings_, vector_indexings_,
search_info, search_info,
query_data, query_data,
query_lims, query_offsets,
query_count, query_count,
bitset, bitset,
op_context, op_context,
@ -826,7 +826,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
search_info, search_info,
index_info, index_info,
query_data, query_data,
query_lims, query_offsets,
query_count, query_count,
row_count, row_count,
bitset, bitset,

View File

@ -459,7 +459,7 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
void void
vector_search(SearchInfo& search_info, vector_search(SearchInfo& search_info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t query_count, int64_t query_count,
Timestamp timestamp, Timestamp timestamp,
const BitsetView& bitset, const BitsetView& bitset,

View File

@ -696,7 +696,7 @@ SegmentGrowingImpl::search_batch_pks(
void void
SegmentGrowingImpl::vector_search(SearchInfo& search_info, SegmentGrowingImpl::vector_search(SearchInfo& search_info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t query_count, int64_t query_count,
Timestamp timestamp, Timestamp timestamp,
const BitsetView& bitset, const BitsetView& bitset,
@ -705,7 +705,7 @@ SegmentGrowingImpl::vector_search(SearchInfo& search_info,
query::SearchOnGrowing(*this, query::SearchOnGrowing(*this,
search_info, search_info,
query_data, query_data,
query_lims, query_offsets,
query_count, query_count,
timestamp, timestamp,
bitset, bitset,

View File

@ -359,7 +359,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
void void
vector_search(SearchInfo& search_info, vector_search(SearchInfo& search_info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t query_count, int64_t query_count,
Timestamp timestamp, Timestamp timestamp,
const BitsetView& bitset, const BitsetView& bitset,

View File

@ -581,7 +581,7 @@ TEST(GrowingTest, SearchVectorArray) {
config.set_enable_interim_segment_index(true); config.set_enable_interim_segment_index(true);
std::map<std::string, std::string> index_params = { std::map<std::string, std::string> index_params = {
{"index_type", knowhere::IndexEnum::INDEX_EMB_LIST_HNSW}, {"index_type", knowhere::IndexEnum::INDEX_HNSW},
{"metric_type", metric_type}, {"metric_type", metric_type},
{"nlist", "128"}}; {"nlist", "128"}};
std::map<std::string, std::string> type_params = { std::map<std::string, std::string> type_params = {
@ -612,11 +612,11 @@ TEST(GrowingTest, SearchVectorArray) {
int vec_num = 10; // Total number of query vectors int vec_num = 10; // Total number of query vectors
std::vector<float> query_vec = generate_float_vector(vec_num, dim); std::vector<float> query_vec = generate_float_vector(vec_num, dim);
// Create query dataset with lims for VectorArray // Create query dataset with offsets for VectorArray
std::vector<size_t> query_vec_lims; std::vector<size_t> query_vec_offsets;
query_vec_lims.push_back(0); // First query has 3 vectors query_vec_offsets.push_back(0); // First query has 3 vectors
query_vec_lims.push_back(3); query_vec_offsets.push_back(3);
query_vec_lims.push_back(10); // Second query has 7 vectors query_vec_offsets.push_back(10); // Second query has 7 vectors
// Create search plan // Create search plan
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
@ -636,7 +636,7 @@ TEST(GrowingTest, SearchVectorArray) {
// Use CreatePlaceholderGroupFromBlob for VectorArray // Use CreatePlaceholderGroupFromBlob for VectorArray
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>( auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
vec_num, dim, query_vec.data(), query_vec_lims); vec_num, dim, query_vec.data(), query_vec_offsets);
auto ph_group = auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());

View File

@ -373,14 +373,14 @@ class SegmentInternalInterface : public SegmentInterface {
const std::string& nested_path) const override; const std::string& nested_path) const override;
public: public:
// `query_lims` is not null only for vector array (embedding list) search // `query_offsets` is not null only for vector array (embedding list) search
// where it denotes the number of vectors in each embedding list. The length // where it denotes the number of vectors in each embedding list. The length
// of `query_lims` is the number of queries in the search plus one (the first // of `query_offsets` is the number of queries in the search plus one (the first
// element in query_lims is 0). // element in query_offsets is 0).
virtual void virtual void
vector_search(SearchInfo& search_info, vector_search(SearchInfo& search_info,
const void* query_data, const void* query_data,
const size_t* query_lims, const size_t* query_offsets,
int64_t query_count, int64_t query_count,
Timestamp timestamp, Timestamp timestamp,
const BitsetView& bitset, const BitsetView& bitset,

View File

@ -35,7 +35,7 @@ struct LoadIndexInfo {
int64_t segment_id; int64_t segment_id;
int64_t field_id; int64_t field_id;
DataType field_type; DataType field_type;
// The element type of the field. It's DataType::NONE if field_type is array/vector_array. // The element type of the field. It's not DataType::NONE if field_type is array/vector_array.
DataType element_type; DataType element_type;
bool enable_mmap; bool enable_mmap;
std::string mmap_dir_path; std::string mmap_dir_path;

View File

@ -611,22 +611,40 @@ CreateVectorDataArrayFrom(const void* data_raw,
case DataType::VECTOR_ARRAY: { case DataType::VECTOR_ARRAY: {
auto data = reinterpret_cast<const VectorFieldProto*>(data_raw); auto data = reinterpret_cast<const VectorFieldProto*>(data_raw);
auto vector_type = field_meta.get_element_type(); auto vector_type = field_meta.get_element_type();
switch (vector_type) {
case DataType::VECTOR_FLOAT: {
auto obj = vector_array->mutable_vector_array(); auto obj = vector_array->mutable_vector_array();
obj->set_dim(dim);
// Set element type based on vector type
switch (vector_type) {
case DataType::VECTOR_FLOAT:
obj->set_element_type( obj->set_element_type(
milvus::proto::schema::DataType::FloatVector); milvus::proto::schema::DataType::FloatVector);
obj->set_dim(dim);
for (auto i = 0; i < count; i++) {
*(obj->mutable_data()->Add()) = data[i];
}
break; break;
} case DataType::VECTOR_FLOAT16:
default: { obj->set_element_type(
milvus::proto::schema::DataType::Float16Vector);
break;
case DataType::VECTOR_BFLOAT16:
obj->set_element_type(
milvus::proto::schema::DataType::BFloat16Vector);
break;
case DataType::VECTOR_BINARY:
obj->set_element_type(
milvus::proto::schema::DataType::BinaryVector);
break;
case DataType::VECTOR_INT8:
obj->set_element_type(
milvus::proto::schema::DataType::Int8Vector);
break;
default:
ThrowInfo(NotImplemented, ThrowInfo(NotImplemented,
fmt::format("not implemented vector type {}", fmt::format("not implemented vector type {}",
vector_type)); vector_type));
} }
// Add all vector data
for (auto i = 0; i < count; i++) {
*(obj->mutable_data()->Add()) = data[i];
} }
break; break;
} }

View File

@ -241,15 +241,28 @@ AddPayloadToArrowBuilder(std::shared_ptr<arrow::ArrayBuilder> builder,
if (length > 0) { if (length > 0) {
auto element_type = vector_arrays[0].get_element_type(); auto element_type = vector_arrays[0].get_element_type();
// Validate element type
switch (element_type) { switch (element_type) {
case DataType::VECTOR_FLOAT: { case DataType::VECTOR_FLOAT:
auto value_builder = static_cast<arrow::FloatBuilder*>( case DataType::VECTOR_BINARY:
case DataType::VECTOR_FLOAT16:
case DataType::VECTOR_BFLOAT16:
case DataType::VECTOR_INT8:
break;
default:
ThrowInfo(DataTypeInvalid,
"Unsupported element type in VectorArray: {}",
element_type);
}
// All supported vector types use FixedSizeBinaryBuilder
auto value_builder =
static_cast<arrow::FixedSizeBinaryBuilder*>(
list_builder->value_builder()); list_builder->value_builder());
AssertInfo(value_builder != nullptr, AssertInfo(value_builder != nullptr,
"value_builder must be FloatBuilder for " "value_builder must be FixedSizeBinaryBuilder for "
"FloatVector"); "VectorArray");
arrow::Status ast;
for (int i = 0; i < length; ++i) { for (int i = 0; i < length; ++i) {
auto status = list_builder->Append(); auto status = list_builder->Append();
AssertInfo(status.ok(), AssertInfo(status.ok(),
@ -257,53 +270,20 @@ AddPayloadToArrowBuilder(std::shared_ptr<arrow::ArrayBuilder> builder,
status.ToString()); status.ToString());
const auto& array = vector_arrays[i]; const auto& array = vector_arrays[i];
AssertInfo( AssertInfo(array.get_element_type() == element_type,
array.get_element_type() ==
DataType::VECTOR_FLOAT,
"Inconsistent element types in VectorArray"); "Inconsistent element types in VectorArray");
int num_vectors = array.length(); int num_vectors = array.length();
int dim = array.dim(); auto ast = value_builder->AppendValues(
reinterpret_cast<const uint8_t*>(array.data()),
for (int j = 0; j < num_vectors; ++j) { num_vectors);
auto vec_data = array.get_data<float>(j);
ast =
value_builder->AppendValues(vec_data, dim);
AssertInfo(ast.ok(), AssertInfo(ast.ok(),
"Failed to append list: {}", "Failed to batch append vectors: {}",
ast.ToString()); ast.ToString());
} }
} }
break; break;
} }
case DataType::VECTOR_BINARY:
ThrowInfo(
NotImplemented,
"BinaryVector in VectorArray not implemented yet");
break;
case DataType::VECTOR_FLOAT16:
ThrowInfo(
NotImplemented,
"Float16Vector in VectorArray not implemented yet");
break;
case DataType::VECTOR_BFLOAT16:
ThrowInfo(NotImplemented,
"BFloat16Vector in VectorArray not "
"implemented yet");
break;
case DataType::VECTOR_INT8:
ThrowInfo(
NotImplemented,
"Int8Vector in VectorArray not implemented yet");
break;
default:
ThrowInfo(DataTypeInvalid,
"Unsupported element type in VectorArray: {}",
element_type);
}
}
break;
}
default: { default: {
ThrowInfo(DataTypeInvalid, "unsupported data type {}", data_type); ThrowInfo(DataTypeInvalid, "unsupported data type {}", data_type);
} }
@ -428,7 +408,38 @@ CreateArrowBuilder(DataType data_type, DataType element_type, int dim) {
std::shared_ptr<arrow::ArrayBuilder> value_builder; std::shared_ptr<arrow::ArrayBuilder> value_builder;
switch (element_type) { switch (element_type) {
case DataType::VECTOR_FLOAT: { case DataType::VECTOR_FLOAT: {
value_builder = std::make_shared<arrow::FloatBuilder>(); int byte_width = dim * sizeof(float);
value_builder =
std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
break;
}
case DataType::VECTOR_BINARY: {
int byte_width = (dim + 7) / 8;
value_builder =
std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
break;
}
case DataType::VECTOR_FLOAT16: {
int byte_width = dim * 2;
value_builder =
std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
break;
}
case DataType::VECTOR_BFLOAT16: {
int byte_width = dim * 2;
value_builder =
std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
break;
}
case DataType::VECTOR_INT8: {
int byte_width = dim;
value_builder =
std::make_shared<arrow::FixedSizeBinaryBuilder>(
arrow::fixed_size_binary(byte_width));
break; break;
} }
default: { default: {
@ -610,7 +621,7 @@ CreateArrowSchema(DataType data_type, int dim, DataType element_type) {
"This overload is only for VECTOR_ARRAY type"); "This overload is only for VECTOR_ARRAY type");
AssertInfo(dim > 0, "invalid dim value"); AssertInfo(dim > 0, "invalid dim value");
auto value_type = GetArrowDataTypeForVectorArray(element_type); auto value_type = GetArrowDataTypeForVectorArray(element_type, dim);
auto metadata = arrow::KeyValueMetadata::Make( auto metadata = arrow::KeyValueMetadata::Make(
{ELEMENT_TYPE_KEY_FOR_ARROW, DIM_KEY}, {ELEMENT_TYPE_KEY_FOR_ARROW, DIM_KEY},
{std::to_string(static_cast<int>(element_type)), std::to_string(dim)}); {std::to_string(static_cast<int>(element_type)), std::to_string(dim)});

View File

@ -14,7 +14,7 @@
# Update KNOWHERE_VERSION for the first occurrence # Update KNOWHERE_VERSION for the first occurrence
milvus_add_pkg_config("knowhere") milvus_add_pkg_config("knowhere")
set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES "") set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES "")
set( KNOWHERE_VERSION v2.6.3 ) set( KNOWHERE_VERSION c4d5dd8 )
set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git") set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git")
message(STATUS "Knowhere repo: ${GIT_REPOSITORY}") message(STATUS "Knowhere repo: ${GIT_REPOSITORY}")

View File

@ -17,6 +17,7 @@
#include "common/Types.h" #include "common/Types.h"
#include "index/IndexFactory.h" #include "index/IndexFactory.h"
#include "knowhere/version.h" #include "knowhere/version.h"
#include "knowhere/comp/index_param.h"
#include "storage/RemoteChunkManagerSingleton.h" #include "storage/RemoteChunkManagerSingleton.h"
#include "storage/Util.h" #include "storage/Util.h"
#include "common/VectorArray.h" #include "common/VectorArray.h"
@ -2304,13 +2305,81 @@ TEST(Sealed, SearchSortedPk) {
EXPECT_EQ(100, offsets2[0].get()); EXPECT_EQ(100, offsets2[0].get());
} }
TEST(Sealed, QueryVectorArrayAllFields) { using VectorArrayTestParam =
std::tuple<DataType, std::string, int, std::string>;
class SealedVectorArrayTest
: public ::testing::TestWithParam<VectorArrayTestParam> {
protected:
DataType element_type;
std::string metric_type;
int dim;
std::string test_name;
void
SetUp() override {
auto param = GetParam();
element_type = std::get<0>(param);
metric_type = std::get<1>(param);
dim = std::get<2>(param);
test_name = std::get<3>(param);
// Ensure dim is valid for binary vectors
if (element_type == DataType::VECTOR_BINARY) {
ASSERT_EQ(dim % 8, 0) << "Binary vector dim must be multiple of 8";
}
}
void
VerifyVectorResults(const VectorFieldProto& result_vec,
const VectorFieldProto& expected_vec,
DataType element_type) {
switch (element_type) {
case DataType::VECTOR_FLOAT: {
auto result_data = result_vec.float_vector().data();
auto expected_data = expected_vec.float_vector().data();
EXPECT_EQ(result_data.size(), expected_data.size());
for (int64_t i = 0; i < result_data.size(); ++i) {
EXPECT_NEAR(result_data[i], expected_data[i], 1e-6f);
}
break;
}
case DataType::VECTOR_BINARY: {
auto result_data = result_vec.binary_vector();
auto expected_data = expected_vec.binary_vector();
EXPECT_EQ(result_data, expected_data);
break;
}
case DataType::VECTOR_FLOAT16: {
auto result_data = result_vec.float16_vector();
auto expected_data = expected_vec.float16_vector();
EXPECT_EQ(result_data, expected_data);
break;
}
case DataType::VECTOR_BFLOAT16: {
auto result_data = result_vec.bfloat16_vector();
auto expected_data = expected_vec.bfloat16_vector();
EXPECT_EQ(result_data, expected_data);
break;
}
case DataType::VECTOR_INT8: {
auto result_data = result_vec.int8_vector();
auto expected_data = expected_vec.int8_vector();
EXPECT_EQ(result_data, expected_data);
break;
}
default:
break;
}
}
};
TEST_P(SealedVectorArrayTest, QueryVectorArrayAllFields) {
auto schema = std::make_shared<Schema>(); auto schema = std::make_shared<Schema>();
auto metric_type = knowhere::metric::MAX_SIM;
int64_t dim = 4;
auto int64_field = schema->AddDebugField("int64", DataType::INT64); auto int64_field = schema->AddDebugField("int64", DataType::INT64);
auto array_vec = schema->AddDebugVectorArrayField( auto array_vec = schema->AddDebugVectorArrayField(
"array_vec", DataType::VECTOR_FLOAT, dim, metric_type); "array_vec", element_type, dim, metric_type);
schema->set_primary_field_id(int64_field); schema->set_primary_field_id(int64_field);
std::map<FieldId, FieldIndexMeta> filedMap{}; std::map<FieldId, FieldIndexMeta> filedMap{};
@ -2329,49 +2398,36 @@ TEST(Sealed, QueryVectorArrayAllFields) {
auto ids_ds = GenRandomIds(dataset_size); auto ids_ds = GenRandomIds(dataset_size);
auto int64_result = segment->bulk_subscript( auto int64_result = segment->bulk_subscript(
nullptr, int64_field, ids_ds->GetIds(), dataset_size); nullptr, int64_field, ids_ds->GetIds(), dataset_size);
auto array_float_vector_result = segment->bulk_subscript( auto array_vector_result = segment->bulk_subscript(
nullptr, array_vec, ids_ds->GetIds(), dataset_size); nullptr, array_vec, ids_ds->GetIds(), dataset_size);
EXPECT_EQ(int64_result->scalars().long_data().data_size(), dataset_size); EXPECT_EQ(int64_result->scalars().long_data().data_size(), dataset_size);
EXPECT_EQ(array_float_vector_result->vectors().vector_array().data_size(), EXPECT_EQ(array_vector_result->vectors().vector_array().data_size(),
dataset_size); dataset_size);
auto verify_float_vectors = [](auto arr1, auto arr2) {
static constexpr float EPSILON = 1e-6;
EXPECT_EQ(arr1.size(), arr2.size());
for (int64_t i = 0; i < arr1.size(); ++i) {
EXPECT_NEAR(arr1[i], arr2[i], EPSILON);
}
};
for (int64_t i = 0; i < dataset_size; ++i) { for (int64_t i = 0; i < dataset_size; ++i) {
auto arrow_array = array_float_vector_result->vectors() auto result_vec =
.vector_array() array_vector_result->vectors().vector_array().data()[i];
.data()[i] auto expected_vec = array_vec_values[ids_ds->GetIds()[i]];
.float_vector() VerifyVectorResults(result_vec, expected_vec, element_type);
.data();
auto expected_array =
array_vec_values[ids_ds->GetIds()[i]].float_vector().data();
verify_float_vectors(arrow_array, expected_array);
} }
EXPECT_EQ(int64_result->valid_data_size(), 0); EXPECT_EQ(int64_result->valid_data_size(), 0);
EXPECT_EQ(array_float_vector_result->valid_data_size(), 0); EXPECT_EQ(array_vector_result->valid_data_size(), 0);
} }
TEST(Sealed, SearchVectorArray) { TEST_P(SealedVectorArrayTest, SearchVectorArray) {
int64_t collection_id = 1; int64_t collection_id = 1;
int64_t partition_id = 2; int64_t partition_id = 2;
int64_t segment_id = 3; int64_t segment_id = 3;
int64_t index_build_id = 4000; int64_t index_build_id = 4000;
int64_t index_version = 4000; int64_t index_version = 4000;
int64_t index_id = 5000; int64_t index_id = 5000;
int64_t dim = 4;
auto schema = std::make_shared<Schema>(); auto schema = std::make_shared<Schema>();
auto metric_type = knowhere::metric::MAX_SIM;
auto int64_field = schema->AddDebugField("int64", DataType::INT64); auto int64_field = schema->AddDebugField("int64", DataType::INT64);
auto array_vec = schema->AddDebugVectorArrayField( auto array_vec = schema->AddDebugVectorArrayField(
"array_vec", DataType::VECTOR_FLOAT, dim, metric_type); "array_vec", element_type, dim, metric_type);
schema->set_primary_field_id(int64_field); schema->set_primary_field_id(int64_field);
auto field_meta = milvus::segcore::gen_field_meta(collection_id, auto field_meta = milvus::segcore::gen_field_meta(collection_id,
@ -2379,7 +2435,7 @@ TEST(Sealed, SearchVectorArray) {
segment_id, segment_id,
array_vec.get(), array_vec.get(),
DataType::VECTOR_ARRAY, DataType::VECTOR_ARRAY,
DataType::VECTOR_FLOAT, element_type,
false); false);
auto index_meta = gen_index_meta( auto index_meta = gen_index_meta(
segment_id, array_vec.get(), index_build_id, index_version); segment_id, array_vec.get(), index_build_id, index_version);
@ -2403,7 +2459,7 @@ TEST(Sealed, SearchVectorArray) {
vector_arrays.push_back(milvus::VectorArray(v)); vector_arrays.push_back(milvus::VectorArray(v));
} }
auto field_data = storage::CreateFieldData( auto field_data = storage::CreateFieldData(
DataType::VECTOR_ARRAY, DataType::VECTOR_FLOAT, false, dim); DataType::VECTOR_ARRAY, element_type, false, dim);
field_data->FillFieldData(vector_arrays.data(), vector_arrays.size()); field_data->FillFieldData(vector_arrays.data(), vector_arrays.size());
// create sealed segment // create sealed segment
@ -2445,8 +2501,8 @@ TEST(Sealed, SearchVectorArray) {
// create index // create index
milvus::index::CreateIndexInfo create_index_info; milvus::index::CreateIndexInfo create_index_info;
create_index_info.field_type = DataType::VECTOR_ARRAY; create_index_info.field_type = DataType::VECTOR_ARRAY;
create_index_info.metric_type = knowhere::metric::MAX_SIM; create_index_info.metric_type = metric_type;
create_index_info.index_type = knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW;
create_index_info.index_engine_version = create_index_info.index_engine_version =
knowhere::Version::GetCurrentVersion().VersionNumber(); knowhere::Version::GetCurrentVersion().VersionNumber();
@ -2457,8 +2513,7 @@ TEST(Sealed, SearchVectorArray) {
// build index // build index
Config config; Config config;
config[milvus::index::INDEX_TYPE] = config[milvus::index::INDEX_TYPE] = knowhere::IndexEnum::INDEX_HNSW;
knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
config[INSERT_FILES_KEY] = std::vector<std::string>{log_path}; config[INSERT_FILES_KEY] = std::vector<std::string>{log_path};
config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type; config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type;
config[knowhere::indexparam::M] = "16"; config[knowhere::indexparam::M] = "16";
@ -2473,18 +2528,37 @@ TEST(Sealed, SearchVectorArray) {
// search // search
auto vec_num = 10; auto vec_num = 10;
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
auto query_dataset = knowhere::GenDataSet(vec_num, dim, query_vec.data()); // Generate query vectors based on element type
std::vector<size_t> query_vec_lims; std::vector<uint8_t> query_vec_bin;
query_vec_lims.push_back(0); std::vector<float> query_vec_f32;
query_vec_lims.push_back(3); knowhere::DataSetPtr query_dataset;
query_vec_lims.push_back(10); if (element_type == DataType::VECTOR_BINARY) {
query_dataset->SetLims(query_vec_lims.data()); auto byte_dim = (dim + 7) / 8;
auto total_bytes = vec_num * byte_dim;
query_vec_bin.resize(total_bytes);
for (size_t i = 0; i < total_bytes; ++i) {
query_vec_bin[i] = rand() % 256;
}
query_dataset =
knowhere::GenDataSet(vec_num, dim, query_vec_bin.data());
} else {
// For float-like types (FLOAT, FLOAT16, BFLOAT16, INT8)
query_vec_f32 = generate_float_vector(vec_num, dim);
query_dataset =
knowhere::GenDataSet(vec_num, dim, query_vec_f32.data());
}
std::vector<size_t> query_vec_offsets;
query_vec_offsets.push_back(0);
query_vec_offsets.push_back(3);
query_vec_offsets.push_back(10);
query_dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
const_cast<const size_t*>(query_vec_offsets.data()));
auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}}; auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
milvus::SearchInfo searchInfo; milvus::SearchInfo searchInfo;
searchInfo.topk_ = 5; searchInfo.topk_ = 5;
searchInfo.metric_type_ = knowhere::metric::MAX_SIM; searchInfo.metric_type_ = metric_type;
searchInfo.search_params_ = search_conf; searchInfo.search_params_ = search_conf;
SearchResult result; SearchResult result;
vec_index->Query(query_dataset, searchInfo, nullptr, nullptr, result); vec_index->Query(query_dataset, searchInfo, nullptr, nullptr, result);
@ -2498,21 +2572,62 @@ TEST(Sealed, SearchVectorArray) {
// brute force search // brute force search
{ {
const char* raw_plan = R"(vector_anns: < std::string raw_plan = fmt::format(R"(vector_anns: <
field_id: 101 field_id: 101
query_info: < query_info: <
topk: 5 topk: 5
round_decimal: 3 round_decimal: 3
metric_type: "MAX_SIM" metric_type: "{}"
search_params: "{\"nprobe\": 10}" search_params: "{{\"nprobe\": 10}}"
> >
placeholder_tag: "$0" placeholder_tag: "$0"
>)"; >)",
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); metric_type);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
auto plan = auto plan =
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size()); CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
vec_num, dim, query_vec.data(), query_vec_lims); // Create placeholder based on element type
milvus::proto::common::PlaceholderGroup ph_group_raw;
if (element_type == DataType::VECTOR_BINARY) {
auto byte_dim = (dim + 7) / 8;
auto total_bytes = vec_num * byte_dim;
std::vector<uint8_t> query_vec(total_bytes);
for (size_t i = 0; i < total_bytes; ++i) {
query_vec[i] = rand() % 256;
}
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListBinaryVector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
} else if (element_type == DataType::VECTOR_FLOAT16) {
std::vector<float> float_vec = generate_float_vector(vec_num, dim);
std::vector<float16> query_vec(vec_num * dim);
for (size_t i = 0; i < vec_num * dim; ++i) {
query_vec[i] = float16(float_vec[i]);
}
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloat16Vector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
} else if (element_type == DataType::VECTOR_BFLOAT16) {
std::vector<float> float_vec = generate_float_vector(vec_num, dim);
std::vector<bfloat16> query_vec(vec_num * dim);
for (size_t i = 0; i < vec_num * dim; ++i) {
query_vec[i] = bfloat16(float_vec[i]);
}
ph_group_raw =
CreatePlaceholderGroupFromBlob<EmbListBFloat16Vector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
} else if (element_type == DataType::VECTOR_INT8) {
std::vector<int8_t> query_vec(vec_num * dim);
for (size_t i = 0; i < vec_num * dim; ++i) {
query_vec[i] = static_cast<int8_t>(rand() % 256 - 128);
}
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListInt8Vector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
} else {
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
}
auto ph_group = auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp timestamp = 1000000; Timestamp timestamp = 1000000;
@ -2528,30 +2643,71 @@ TEST(Sealed, SearchVectorArray) {
LoadIndexInfo load_info; LoadIndexInfo load_info;
load_info.field_id = array_vec.get(); load_info.field_id = array_vec.get();
load_info.field_type = DataType::VECTOR_ARRAY; load_info.field_type = DataType::VECTOR_ARRAY;
load_info.element_type = DataType::VECTOR_FLOAT; load_info.element_type = element_type;
load_info.index_params = GenIndexParams(emb_list_hnsw_index.get()); load_info.index_params = GenIndexParams(emb_list_hnsw_index.get());
load_info.cache_index = load_info.cache_index =
CreateTestCacheIndex("test", std::move(emb_list_hnsw_index)); CreateTestCacheIndex("test", std::move(emb_list_hnsw_index));
load_info.index_params["metric_type"] = knowhere::metric::MAX_SIM; load_info.index_params["metric_type"] = metric_type;
sealed_segment->DropFieldData(array_vec); sealed_segment->DropFieldData(array_vec);
sealed_segment->LoadIndex(load_info); sealed_segment->LoadIndex(load_info);
const char* raw_plan = R"(vector_anns: < std::string raw_plan = fmt::format(R"(vector_anns: <
field_id: 101 field_id: 101
query_info: < query_info: <
topk: 5 topk: 5
round_decimal: 3 round_decimal: 3
metric_type: "MAX_SIM" metric_type: "{}"
search_params: "{\"nprobe\": 10}" search_params: "{{\"nprobe\": 10}}"
> >
placeholder_tag: "$0" placeholder_tag: "$0"
>)"; >)",
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); metric_type);
auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
auto plan = auto plan =
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size()); CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
vec_num, dim, query_vec.data(), query_vec_lims); // Create placeholder based on element type
milvus::proto::common::PlaceholderGroup ph_group_raw;
if (element_type == DataType::VECTOR_BINARY) {
auto byte_dim = (dim + 7) / 8;
auto total_bytes = vec_num * byte_dim;
std::vector<uint8_t> query_vec(total_bytes);
for (size_t i = 0; i < total_bytes; ++i) {
query_vec[i] = rand() % 256;
}
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListBinaryVector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
} else if (element_type == DataType::VECTOR_FLOAT16) {
std::vector<float> float_vec = generate_float_vector(vec_num, dim);
std::vector<float16> query_vec(vec_num * dim);
for (size_t i = 0; i < vec_num * dim; ++i) {
query_vec[i] = float16(float_vec[i]);
}
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloat16Vector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
} else if (element_type == DataType::VECTOR_BFLOAT16) {
std::vector<float> float_vec = generate_float_vector(vec_num, dim);
std::vector<bfloat16> query_vec(vec_num * dim);
for (size_t i = 0; i < vec_num * dim; ++i) {
query_vec[i] = bfloat16(float_vec[i]);
}
ph_group_raw =
CreatePlaceholderGroupFromBlob<EmbListBFloat16Vector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
} else if (element_type == DataType::VECTOR_INT8) {
std::vector<int8_t> query_vec(vec_num * dim);
for (size_t i = 0; i < vec_num * dim; ++i) {
query_vec[i] = static_cast<int8_t>(rand() % 256 - 128);
}
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListInt8Vector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
} else {
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
vec_num, dim, query_vec.data(), query_vec_offsets);
}
auto ph_group = auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp timestamp = 1000000; Timestamp timestamp = 1000000;
@ -2562,3 +2718,27 @@ TEST(Sealed, SearchVectorArray) {
std::cout << sr_parsed.dump(1) << std::endl; std::cout << sr_parsed.dump(1) << std::endl;
} }
} }
INSTANTIATE_TEST_SUITE_P(
VectorArrayTypes,
SealedVectorArrayTest,
::testing::Values(
std::make_tuple(DataType::VECTOR_FLOAT, "MAX_SIM", 4, "float_max_sim"),
std::make_tuple(DataType::VECTOR_FLOAT, "MAX_SIM_L2", 4, "float_l2"),
std::make_tuple(
DataType::VECTOR_FLOAT16, "MAX_SIM", 4, "float16_max_sim"),
std::make_tuple(
DataType::VECTOR_FLOAT16, "MAX_SIM_L2", 4, "float16_l2"),
std::make_tuple(
DataType::VECTOR_BFLOAT16, "MAX_SIM", 4, "bfloat16_max_sim"),
std::make_tuple(
DataType::VECTOR_BFLOAT16, "MAX_SIM_L2", 4, "bfloat16_l2"),
std::make_tuple(DataType::VECTOR_INT8, "MAX_SIM", 4, "int8_max_sim"),
std::make_tuple(DataType::VECTOR_INT8, "MAX_SIM_L2", 4, "int8_l2"),
std::make_tuple(
DataType::VECTOR_BINARY, "MAX_SIM_HAMMING", 32, "binary_hamming"),
std::make_tuple(
DataType::VECTOR_BINARY, "MAX_SIM_JACCARD", 32, "binary_jaccard")),
[](const ::testing::TestParamInfo<VectorArrayTestParam>& info) {
return std::get<3>(info.param);
});

View File

@ -249,14 +249,22 @@ func CreateSearchPlanArgs(schema *typeutil.SchemaHelper, exprStr string, vectorF
switch elementType { switch elementType {
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
vectorType = planpb.VectorType_EmbListFloatVector vectorType = planpb.VectorType_EmbListFloatVector
case schemapb.DataType_BinaryVector:
vectorType = planpb.VectorType_EmbListBinaryVector
case schemapb.DataType_Float16Vector:
vectorType = planpb.VectorType_EmbListFloat16Vector
case schemapb.DataType_BFloat16Vector:
vectorType = planpb.VectorType_EmbListBFloat16Vector
case schemapb.DataType_Int8Vector:
vectorType = planpb.VectorType_EmbListInt8Vector
default: default:
log.Error("Invalid elementType", zap.Any("elementType", elementType)) log.Error("Invalid elementType for ArrayOfVector", zap.Any("elementType", elementType))
return nil, err return nil, fmt.Errorf("unsupported element type for ArrayOfVector: %v", elementType)
} }
default: default:
log.Error("Invalid dataType", zap.Any("dataType", dataType)) log.Error("Invalid dataType", zap.Any("dataType", dataType))
return nil, err return nil, fmt.Errorf("unsupported vector data type: %v", dataType)
} }
scorers, options, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues) scorers, options, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues)

View File

@ -29,7 +29,7 @@ func Test_CheckVecIndexWithDataTypeExist(t *testing.T) {
want bool want bool
}{ }{
{"HNSW", schemapb.DataType_FloatVector, true}, {"HNSW", schemapb.DataType_FloatVector, true},
{"HNSW", schemapb.DataType_BinaryVector, false}, {"HNSW", schemapb.DataType_BinaryVector, true},
{"HNSW", schemapb.DataType_Float16Vector, true}, {"HNSW", schemapb.DataType_Float16Vector, true},
{"SPARSE_WAND", schemapb.DataType_SparseFloatVector, true}, {"SPARSE_WAND", schemapb.DataType_SparseFloatVector, true},

View File

@ -555,7 +555,7 @@ func constructTestCreateIndexRequest(dbName, collectionName string, dataType sch
}, },
{ {
Key: common.IndexTypeKey, Key: common.IndexTypeKey,
Value: "EMB_LIST_HNSW", Value: "HNSW",
}, },
{ {
Key: "nlist", Key: "nlist",
@ -1744,7 +1744,6 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
}) })
fmt.Println("create index for binVec field")
fieldName := ConcatStructFieldName(structField, subFieldFVec) fieldName := ConcatStructFieldName(structField, subFieldFVec)
wg.Add(1) wg.Add(1)
@ -1757,8 +1756,6 @@ func TestProxy(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
}) })
fmt.Println("create index for embedding list field")
wg.Add(1) wg.Add(1)
t.Run("alter index for embedding list field", func(t *testing.T) { t.Run("alter index for embedding list field", func(t *testing.T) {
defer wg.Done() defer wg.Done()
@ -1778,7 +1775,6 @@ func TestProxy(t *testing.T) {
err = merr.CheckRPCCall(resp, err) err = merr.CheckRPCCall(resp, err)
assert.NoError(t, err) assert.NoError(t, err)
}) })
fmt.Println("alter index for embedding list field")
wg.Add(1) wg.Add(1)
t.Run("describe index for embedding list field", func(t *testing.T) { t.Run("describe index for embedding list field", func(t *testing.T) {
@ -1796,7 +1792,6 @@ func TestProxy(t *testing.T) {
enableMmap, _ := common.IsMmapDataEnabled(resp.IndexDescriptions[0].GetParams()...) enableMmap, _ := common.IsMmapDataEnabled(resp.IndexDescriptions[0].GetParams()...)
assert.True(t, enableMmap, "params: %+v", resp.IndexDescriptions[0]) assert.True(t, enableMmap, "params: %+v", resp.IndexDescriptions[0])
}) })
fmt.Println("describe index for embedding list field")
wg.Add(1) wg.Add(1)
t.Run("describe index with indexName for embedding list field", func(t *testing.T) { t.Run("describe index with indexName for embedding list field", func(t *testing.T) {
@ -1812,7 +1807,6 @@ func TestProxy(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
}) })
fmt.Println("describe index with indexName for embedding list field")
wg.Add(1) wg.Add(1)
t.Run("get index statistics for embedding list field", func(t *testing.T) { t.Run("get index statistics for embedding list field", func(t *testing.T) {

View File

@ -71,8 +71,6 @@ const (
RoundDecimalKey = "round_decimal" RoundDecimalKey = "round_decimal"
OffsetKey = "offset" OffsetKey = "offset"
LimitKey = "limit" LimitKey = "limit"
// offsets for embedding list search
LimsKey = "lims"
// key for timestamptz translation // key for timestamptz translation
TimezoneKey = "timezone" TimezoneKey = "timezone"
TimefieldsKey = "time_fields" TimefieldsKey = "time_fields"

View File

@ -1206,7 +1206,7 @@ func Test_checkEmbeddingListIndex(t *testing.T) {
ExtraParams: []*commonpb.KeyValuePair{ ExtraParams: []*commonpb.KeyValuePair{
{ {
Key: common.IndexTypeKey, Key: common.IndexTypeKey,
Value: "EMB_LIST_HNSW", Value: "HNSW",
}, },
{ {
Key: common.MetricTypeKey, Key: common.MetricTypeKey,
@ -1237,7 +1237,7 @@ func Test_checkEmbeddingListIndex(t *testing.T) {
ExtraParams: []*commonpb.KeyValuePair{ ExtraParams: []*commonpb.KeyValuePair{
{ {
Key: common.IndexTypeKey, Key: common.IndexTypeKey,
Value: "EMB_LIST_HNSW", Value: "HNSW",
}, },
{ {
Key: common.MetricTypeKey, Key: common.MetricTypeKey,
@ -1290,37 +1290,6 @@ func Test_checkEmbeddingListIndex(t *testing.T) {
err := cit.parseIndexParams(context.TODO()) err := cit.parseIndexParams(context.TODO())
assert.True(t, strings.Contains(err.Error(), "float vector index does not support metric type: MAX_SIM")) assert.True(t, strings.Contains(err.Error(), "float vector index does not support metric type: MAX_SIM"))
}) })
t.Run("data type wrong", func(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "EMB_LIST_HNSW",
},
{
Key: common.MetricTypeKey,
Value: metric.L2,
},
},
IndexName: "",
},
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "FieldFloat",
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
},
}
err := cit.parseIndexParams(context.TODO())
assert.True(t, strings.Contains(err.Error(), "data type FloatVector can't build with this index EMB_LIST_HNSW"))
})
} }
func Test_ngram_parseIndexParams(t *testing.T) { func Test_ngram_parseIndexParams(t *testing.T) {

View File

@ -1087,12 +1087,61 @@ func (v *validateUtil) checkArrayOfVectorFieldData(field *schemapb.FieldData, fi
return merr.WrapErrParameterInvalid("need float vector array", "got nil", msg) return merr.WrapErrParameterInvalid("need float vector array", "got nil", msg)
} }
if v.checkNAN { if v.checkNAN {
return typeutil.VerifyFloats32(floatVector.GetData()) if err := typeutil.VerifyFloats32(floatVector.GetData()); err != nil {
return err
}
}
}
return nil
case schemapb.DataType_BinaryVector:
for _, vector := range data.GetData() {
binaryVector := vector.GetBinaryVector()
if binaryVector == nil {
msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need binary vector array", "got nil", msg)
}
}
return nil
case schemapb.DataType_Float16Vector:
for _, vector := range data.GetData() {
float16Vector := vector.GetFloat16Vector()
if float16Vector == nil {
msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need float16 vector array", "got nil", msg)
}
if v.checkNAN {
if err := typeutil.VerifyFloats16(float16Vector); err != nil {
return err
}
}
}
return nil
case schemapb.DataType_BFloat16Vector:
for _, vector := range data.GetData() {
bfloat16Vector := vector.GetBfloat16Vector()
if bfloat16Vector == nil {
msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need bfloat16 vector array", "got nil", msg)
}
if v.checkNAN {
if err := typeutil.VerifyBFloats16(bfloat16Vector); err != nil {
return err
}
}
}
return nil
case schemapb.DataType_Int8Vector:
for _, vector := range data.GetData() {
int8Vector := vector.GetInt8Vector()
if int8Vector == nil {
msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName())
return merr.WrapErrParameterInvalid("need int8 vector array", "got nil", msg)
} }
} }
return nil return nil
default: default:
panic("not implemented") msg := fmt.Sprintf("unsupported element type for ArrayOfVector: %v", fieldSchema.GetElementType())
return merr.WrapErrParameterInvalid("supported vector type", fieldSchema.GetElementType().String(), msg)
} }
} }

View File

@ -251,16 +251,18 @@ func appendValueAt(builder array.Builder, a arrow.Array, idx int, defaultValue *
b.Append(true) b.Append(true)
valuesArray := la.ListValues() valuesArray := la.ListValues()
valueBuilder := b.ValueBuilder()
var totalSize uint64 = 0 var totalSize uint64 = 0
valueBuilder := b.ValueBuilder()
switch vb := valueBuilder.(type) { switch vb := valueBuilder.(type) {
case *array.Float32Builder: case *array.FixedSizeBinaryBuilder:
if floatArray, ok := valuesArray.(*array.Float32); ok { fixedArray, ok := valuesArray.(*array.FixedSizeBinary)
for i := start; i < end; i++ { if !ok {
vb.Append(floatArray.Value(int(i))) return 0, fmt.Errorf("invalid value type %T, expect %T", valuesArray.DataType(), vb.Type())
totalSize += 4
} }
for i := start; i < end; i++ {
val := fixedArray.Value(int(i))
vb.Append(val)
totalSize += uint64(len(val))
} }
default: default:
return 0, fmt.Errorf("unsupported value builder type in ListBuilder: %T", valueBuilder) return 0, fmt.Errorf("unsupported value builder type in ListBuilder: %T", valueBuilder)

View File

@ -376,7 +376,12 @@ func NewFieldData(dataType schemapb.DataType, fieldSchema *schemapb.FieldSchema,
} }
return data, nil return data, nil
case schemapb.DataType_ArrayOfVector: case schemapb.DataType_ArrayOfVector:
dim, err := GetDimFromParams(typeParams)
if err != nil {
return nil, err
}
data := &VectorArrayFieldData{ data := &VectorArrayFieldData{
Dim: int64(dim),
Data: make([]*schemapb.VectorField, 0, cap), Data: make([]*schemapb.VectorField, 0, cap),
ElementType: fieldSchema.GetElementType(), ElementType: fieldSchema.GetElementType(),
} }

View File

@ -628,46 +628,17 @@ func readVectorArrayFromListArray(r *PayloadReader) ([]*schemapb.VectorField, er
return nil, fmt.Errorf("expected ListArray, got %T", chunk) return nil, fmt.Errorf("expected ListArray, got %T", chunk)
} }
valuesArray := listArray.ListValues()
switch elementType {
case schemapb.DataType_FloatVector:
floatArray, ok := valuesArray.(*array.Float32)
if !ok {
return nil, fmt.Errorf("expected Float32 array for FloatVector, got %T", valuesArray)
}
// Process each row which contains multiple vectors
for i := 0; i < listArray.Len(); i++ { for i := 0; i < listArray.Len(); i++ {
if listArray.IsNull(i) { value, ok := deserializeArrayOfVector(listArray, i, elementType, dim, true)
if !ok {
return nil, fmt.Errorf("failed to deserialize VectorArray at row %d", len(result))
}
vectorField, _ := value.(*schemapb.VectorField)
if vectorField == nil {
return nil, fmt.Errorf("null value in VectorArray") return nil, fmt.Errorf("null value in VectorArray")
} }
start, end := listArray.ValueOffsets(i)
vectorData := make([]float32, end-start)
copy(vectorData, floatArray.Float32Values()[start:end])
vectorField := &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: vectorData,
},
},
}
result = append(result, vectorField) result = append(result, vectorField)
} }
case schemapb.DataType_BinaryVector:
return nil, fmt.Errorf("BinaryVector in VectorArray not implemented yet")
case schemapb.DataType_Float16Vector:
return nil, fmt.Errorf("Float16Vector in VectorArray not implemented yet")
case schemapb.DataType_BFloat16Vector:
return nil, fmt.Errorf("BFloat16Vector in VectorArray not implemented yet")
case schemapb.DataType_Int8Vector:
return nil, fmt.Errorf("Int8Vector in VectorArray not implemented yet")
default:
return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String())
}
} }
return result, nil return result, nil

View File

@ -1719,6 +1719,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
// Create VectorArrayFieldData with 3 rows // Create VectorArrayFieldData with 3 rows
vectorArrayData := &VectorArrayFieldData{ vectorArrayData := &VectorArrayFieldData{
Dim: int64(dim),
Data: []*schemapb.VectorField{ Data: []*schemapb.VectorField{
{ {
Dim: int64(dim), Dim: int64(dim),

View File

@ -18,6 +18,7 @@ package storage
import ( import (
"bytes" "bytes"
"encoding/binary"
"fmt" "fmt"
"math" "math"
"sync" "sync"
@ -114,7 +115,10 @@ func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions
if w.elementType == nil { if w.elementType == nil {
return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires elementType, use WithElementType option") return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires elementType, use WithElementType option")
} }
elemType, err := VectorArrayToArrowType(*w.elementType) if w.dim == nil {
return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires dim to be specified")
}
elemType, err := VectorArrayToArrowType(*w.elementType, *w.dim.Value)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -962,13 +966,13 @@ func (w *NativePayloadWriter) AddVectorArrayFieldDataToPayload(data *VectorArray
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
return w.addFloatVectorArrayToPayload(builder, data) return w.addFloatVectorArrayToPayload(builder, data)
case schemapb.DataType_BinaryVector: case schemapb.DataType_BinaryVector:
return merr.WrapErrParameterInvalidMsg("BinaryVector in VectorArray not implemented yet") return w.addBinaryVectorArrayToPayload(builder, data)
case schemapb.DataType_Float16Vector: case schemapb.DataType_Float16Vector:
return merr.WrapErrParameterInvalidMsg("Float16Vector in VectorArray not implemented yet") return w.addFloat16VectorArrayToPayload(builder, data)
case schemapb.DataType_BFloat16Vector: case schemapb.DataType_BFloat16Vector:
return merr.WrapErrParameterInvalidMsg("BFloat16Vector in VectorArray not implemented yet") return w.addBFloat16VectorArrayToPayload(builder, data)
case schemapb.DataType_Int8Vector: case schemapb.DataType_Int8Vector:
return merr.WrapErrParameterInvalidMsg("Int8Vector in VectorArray not implemented yet") return w.addInt8VectorArrayToPayload(builder, data)
default: default:
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", data.ElementType.String())) return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", data.ElementType.String()))
} }
@ -976,21 +980,159 @@ func (w *NativePayloadWriter) AddVectorArrayFieldDataToPayload(data *VectorArray
// addFloatVectorArrayToPayload handles FloatVector elements in VectorArray // addFloatVectorArrayToPayload handles FloatVector elements in VectorArray
func (w *NativePayloadWriter) addFloatVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error { func (w *NativePayloadWriter) addFloatVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
valueBuilder := builder.ValueBuilder().(*array.Float32Builder) if data.Dim <= 0 {
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
}
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
// Each element in data.Data represents one row of VectorArray
for _, vectorField := range data.Data { for _, vectorField := range data.Data {
if vectorField.GetFloatVector() == nil { if vectorField.GetFloatVector() == nil {
return merr.WrapErrParameterInvalidMsg("expected FloatVector but got different type") return merr.WrapErrParameterInvalidMsg("expected FloatVector but got different type")
} }
// Start a new list for this row
builder.Append(true) builder.Append(true)
floatData := vectorField.GetFloatVector().GetData() floatData := vectorField.GetFloatVector().GetData()
if len(floatData) == 0 {
return merr.WrapErrParameterInvalidMsg("empty vector data not allowed") numVectors := len(floatData) / int(data.Dim)
for i := 0; i < numVectors; i++ {
start := i * int(data.Dim)
end := start + int(data.Dim)
vectorSlice := floatData[start:end]
bytes := make([]byte, data.Dim*4)
for j, f := range vectorSlice {
binary.LittleEndian.PutUint32(bytes[j*4:], math.Float32bits(f))
} }
valueBuilder.AppendValues(floatData, nil) valueBuilder.Append(bytes)
}
}
return nil
}
// addBinaryVectorArrayToPayload handles BinaryVector elements in VectorArray
func (w *NativePayloadWriter) addBinaryVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
if data.Dim <= 0 {
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
}
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
// Each element in data.Data represents one row of VectorArray
for _, vectorField := range data.Data {
if vectorField.GetBinaryVector() == nil {
return merr.WrapErrParameterInvalidMsg("expected BinaryVector but got different type")
}
// Start a new list for this row
builder.Append(true)
binaryData := vectorField.GetBinaryVector()
byteWidth := (data.Dim + 7) / 8
numVectors := len(binaryData) / int(byteWidth)
for i := 0; i < numVectors; i++ {
start := i * int(byteWidth)
end := start + int(byteWidth)
valueBuilder.Append(binaryData[start:end])
}
}
return nil
}
// addFloat16VectorArrayToPayload handles Float16Vector elements in VectorArray
func (w *NativePayloadWriter) addFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
if data.Dim <= 0 {
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
}
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
// Each element in data.Data represents one row of VectorArray
for _, vectorField := range data.Data {
if vectorField.GetFloat16Vector() == nil {
return merr.WrapErrParameterInvalidMsg("expected Float16Vector but got different type")
}
// Start a new list for this row
builder.Append(true)
float16Data := vectorField.GetFloat16Vector()
byteWidth := data.Dim * 2
numVectors := len(float16Data) / int(byteWidth)
for i := 0; i < numVectors; i++ {
start := i * int(byteWidth)
end := start + int(byteWidth)
valueBuilder.Append(float16Data[start:end])
}
}
return nil
}
// addBFloat16VectorArrayToPayload handles BFloat16Vector elements in VectorArray
func (w *NativePayloadWriter) addBFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
if data.Dim <= 0 {
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
}
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
// Each element in data.Data represents one row of VectorArray
for _, vectorField := range data.Data {
if vectorField.GetBfloat16Vector() == nil {
return merr.WrapErrParameterInvalidMsg("expected BFloat16Vector but got different type")
}
// Start a new list for this row
builder.Append(true)
bfloat16Data := vectorField.GetBfloat16Vector()
byteWidth := data.Dim * 2
numVectors := len(bfloat16Data) / int(byteWidth)
for i := 0; i < numVectors; i++ {
start := i * int(byteWidth)
end := start + int(byteWidth)
valueBuilder.Append(bfloat16Data[start:end])
}
}
return nil
}
// addInt8VectorArrayToPayload handles Int8Vector elements in VectorArray
func (w *NativePayloadWriter) addInt8VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
if data.Dim <= 0 {
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
}
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
// Each element in data.Data represents one row of VectorArray
for _, vectorField := range data.Data {
if vectorField.GetInt8Vector() == nil {
return merr.WrapErrParameterInvalidMsg("expected Int8Vector but got different type")
}
// Start a new list for this row
builder.Append(true)
int8Data := vectorField.GetInt8Vector()
numVectors := len(int8Data) / int(data.Dim)
for i := 0; i < numVectors; i++ {
start := i * int(data.Dim)
end := start + int(data.Dim)
valueBuilder.Append(int8Data[start:end])
}
} }
return nil return nil

View File

@ -309,6 +309,7 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) {
vectorArrayData := &VectorArrayFieldData{ vectorArrayData := &VectorArrayFieldData{
Data: make([]*schemapb.VectorField, numRows), Data: make([]*schemapb.VectorField, numRows),
ElementType: schemapb.DataType_FloatVector, ElementType: schemapb.DataType_FloatVector,
Dim: int64(dim),
} }
for i := 0; i < numRows; i++ { for i := 0; i < numRows; i++ {
@ -408,6 +409,7 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) {
batchData := &VectorArrayFieldData{ batchData := &VectorArrayFieldData{
Data: make([]*schemapb.VectorField, batchSize), Data: make([]*schemapb.VectorField, batchSize),
ElementType: schemapb.DataType_FloatVector, ElementType: schemapb.DataType_FloatVector,
Dim: int64(dim),
} }
for i := 0; i < batchSize; i++ { for i := 0; i < batchSize; i++ {
@ -454,6 +456,7 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) {
vectorArrayData := &VectorArrayFieldData{ vectorArrayData := &VectorArrayFieldData{
Data: make([]*schemapb.VectorField, numRows), Data: make([]*schemapb.VectorField, numRows),
ElementType: schemapb.DataType_FloatVector, ElementType: schemapb.DataType_FloatVector,
Dim: int64(dim),
} }
for i := 0; i < numRows; i++ { for i := 0; i < numRows; i++ {
@ -485,6 +488,159 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, numRows, length) require.Equal(t, numRows, length)
}) })
t.Run("Test ArrayOfFloat16Vector - Basic", func(t *testing.T) {
dim := 64
numRows := 50
vectorsPerRow := 4
// Create test data
vectorArrayData := &VectorArrayFieldData{
Data: make([]*schemapb.VectorField, numRows),
ElementType: schemapb.DataType_Float16Vector,
Dim: int64(dim),
}
for i := 0; i < numRows; i++ {
// Float16 vectors are stored as bytes (2 bytes per element)
byteData := make([]byte, vectorsPerRow*dim*2)
for j := 0; j < len(byteData); j++ {
byteData[j] = byte((i*1000 + j) % 256)
}
vectorArrayData.Data[i] = &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_Float16Vector{
Float16Vector: byteData,
},
}
}
w, err := NewPayloadWriter(
schemapb.DataType_ArrayOfVector,
WithDim(dim),
WithElementType(schemapb.DataType_Float16Vector),
)
require.NoError(t, err)
require.NotNil(t, w)
err = w.AddVectorArrayFieldDataToPayload(vectorArrayData)
require.NoError(t, err)
err = w.FinishPayloadWriter()
require.NoError(t, err)
// Verify results
buffer, err := w.GetPayloadBufferFromWriter()
require.NoError(t, err)
require.NotEmpty(t, buffer)
length, err := w.GetPayloadLengthFromWriter()
require.NoError(t, err)
require.Equal(t, numRows, length)
})
t.Run("Test ArrayOfBinaryVector - Basic", func(t *testing.T) {
dim := 128 // Must be multiple of 8
numRows := 50
vectorsPerRow := 3
// Create test data
vectorArrayData := &VectorArrayFieldData{
Data: make([]*schemapb.VectorField, numRows),
ElementType: schemapb.DataType_BinaryVector,
Dim: int64(dim),
}
for i := 0; i < numRows; i++ {
// Binary vectors use 1 bit per dimension, so dim/8 bytes per vector
byteData := make([]byte, vectorsPerRow*dim/8)
for j := 0; j < len(byteData); j++ {
byteData[j] = byte((i + j) % 256)
}
vectorArrayData.Data[i] = &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: byteData,
},
}
}
w, err := NewPayloadWriter(
schemapb.DataType_ArrayOfVector,
WithDim(dim),
WithElementType(schemapb.DataType_BinaryVector),
)
require.NoError(t, err)
require.NotNil(t, w)
err = w.AddVectorArrayFieldDataToPayload(vectorArrayData)
require.NoError(t, err)
err = w.FinishPayloadWriter()
require.NoError(t, err)
// Verify results
buffer, err := w.GetPayloadBufferFromWriter()
require.NoError(t, err)
require.NotEmpty(t, buffer)
length, err := w.GetPayloadLengthFromWriter()
require.NoError(t, err)
require.Equal(t, numRows, length)
})
t.Run("Test ArrayOfBFloat16Vector - Basic", func(t *testing.T) {
dim := 64
numRows := 50
vectorsPerRow := 4
// Create test data
vectorArrayData := &VectorArrayFieldData{
Data: make([]*schemapb.VectorField, numRows),
ElementType: schemapb.DataType_BFloat16Vector,
Dim: int64(dim),
}
for i := 0; i < numRows; i++ {
// BFloat16 vectors are stored as bytes (2 bytes per element)
byteData := make([]byte, vectorsPerRow*dim*2)
for j := 0; j < len(byteData); j++ {
byteData[j] = byte((i*100 + j) % 256)
}
vectorArrayData.Data[i] = &schemapb.VectorField{
Dim: int64(dim),
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: byteData,
},
}
}
w, err := NewPayloadWriter(
schemapb.DataType_ArrayOfVector,
WithDim(dim),
WithElementType(schemapb.DataType_BFloat16Vector),
)
require.NoError(t, err)
require.NotNil(t, w)
err = w.AddVectorArrayFieldDataToPayload(vectorArrayData)
require.NoError(t, err)
err = w.FinishPayloadWriter()
require.NoError(t, err)
// Verify results
buffer, err := w.GetPayloadBufferFromWriter()
require.NoError(t, err)
require.NotEmpty(t, buffer)
length, err := w.GetPayloadLengthFromWriter()
require.NoError(t, err)
require.Equal(t, numRows, length)
})
} }
func TestParquetEncoding(t *testing.T) { func TestParquetEncoding(t *testing.T) {

View File

@ -17,6 +17,7 @@
package storage package storage
import ( import (
"encoding/binary"
"fmt" "fmt"
"io" "io"
"math" "math"
@ -448,8 +449,8 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
// ArrayOfVector now implements the standard interface with elementType parameter // ArrayOfVector now implements the standard interface with elementType parameter
m[schemapb.DataType_ArrayOfVector] = serdeEntry{ m[schemapb.DataType_ArrayOfVector] = serdeEntry{
arrowType: func(_ int, elementType schemapb.DataType) arrow.DataType { arrowType: func(dim int, elementType schemapb.DataType) arrow.DataType {
return getArrayOfVectorArrowType(elementType) return getArrayOfVectorArrowType(elementType, dim)
}, },
deserialize: func(a arrow.Array, i int, elementType schemapb.DataType, dim int, shouldCopy bool) (any, bool) { deserialize: func(a arrow.Array, i int, elementType schemapb.DataType, dim int, shouldCopy bool) (any, bool) {
return deserializeArrayOfVector(a, i, elementType, int64(dim), shouldCopy) return deserializeArrayOfVector(a, i, elementType, int64(dim), shouldCopy)
@ -471,24 +472,64 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
} }
builder.Append(true) builder.Append(true)
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
dim := vf.GetDim()
appendVectorChunks := func(data []byte, bytesPerVector int) bool {
numVectors := len(data) / bytesPerVector
for i := 0; i < numVectors; i++ {
start := i * bytesPerVector
end := start + bytesPerVector
valueBuilder.Append(data[start:end])
}
return true
}
switch elementType { switch elementType {
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
if vf.GetFloatVector() == nil { if vf.GetFloatVector() == nil {
return false return false
} }
valueBuilder := builder.ValueBuilder().(*array.Float32Builder) floatData := vf.GetFloatVector().GetData()
valueBuilder.AppendValues(vf.GetFloatVector().GetData(), nil) numVectors := len(floatData) / int(dim)
// Convert float data to binary
for i := 0; i < numVectors; i++ {
start := i * int(dim)
end := start + int(dim)
vectorSlice := floatData[start:end]
bytes := make([]byte, dim*4)
for j, f := range vectorSlice {
binary.LittleEndian.PutUint32(bytes[j*4:], math.Float32bits(f))
}
valueBuilder.Append(bytes)
}
return true return true
case schemapb.DataType_BinaryVector: case schemapb.DataType_BinaryVector:
panic("BinaryVector in VectorArray not implemented yet") if vf.GetBinaryVector() == nil {
return false
}
return appendVectorChunks(vf.GetBinaryVector(), int((dim+7)/8))
case schemapb.DataType_Float16Vector: case schemapb.DataType_Float16Vector:
panic("Float16Vector in VectorArray not implemented yet") if vf.GetFloat16Vector() == nil {
return false
}
return appendVectorChunks(vf.GetFloat16Vector(), int(dim)*2)
case schemapb.DataType_BFloat16Vector: case schemapb.DataType_BFloat16Vector:
panic("BFloat16Vector in VectorArray not implemented yet") if vf.GetBfloat16Vector() == nil {
return false
}
return appendVectorChunks(vf.GetBfloat16Vector(), int(dim)*2)
case schemapb.DataType_Int8Vector: case schemapb.DataType_Int8Vector:
panic("Int8Vector in VectorArray not implemented yet") if vf.GetInt8Vector() == nil {
return false
}
return appendVectorChunks(vf.GetInt8Vector(), int(dim))
case schemapb.DataType_SparseFloatVector: case schemapb.DataType_SparseFloatVector:
panic("SparseFloatVector in VectorArray not implemented yet") panic("SparseFloatVector in VectorArray not implemented yet")
default: default:
@ -806,18 +847,18 @@ func (sfw *singleFieldRecordWriter) Close() error {
} }
// getArrayOfVectorArrowType returns the appropriate Arrow type for ArrayOfVector based on element type // getArrayOfVectorArrowType returns the appropriate Arrow type for ArrayOfVector based on element type
func getArrayOfVectorArrowType(elementType schemapb.DataType) arrow.DataType { func getArrayOfVectorArrowType(elementType schemapb.DataType, dim int) arrow.DataType {
switch elementType { switch elementType {
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
return arrow.ListOf(arrow.PrimitiveTypes.Float32) return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim * 4})
case schemapb.DataType_BinaryVector: case schemapb.DataType_BinaryVector:
return arrow.ListOf(arrow.PrimitiveTypes.Uint8) return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: (dim + 7) / 8})
case schemapb.DataType_Float16Vector: case schemapb.DataType_Float16Vector:
return arrow.ListOf(arrow.PrimitiveTypes.Uint8) return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim * 2})
case schemapb.DataType_BFloat16Vector: case schemapb.DataType_BFloat16Vector:
return arrow.ListOf(arrow.PrimitiveTypes.Uint8) return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim * 2})
case schemapb.DataType_Int8Vector: case schemapb.DataType_Int8Vector:
return arrow.ListOf(arrow.PrimitiveTypes.Int8) return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim})
case schemapb.DataType_SparseFloatVector: case schemapb.DataType_SparseFloatVector:
return arrow.ListOf(arrow.BinaryTypes.Binary) return arrow.ListOf(arrow.BinaryTypes.Binary)
default: default:
@ -842,54 +883,78 @@ func deserializeArrayOfVector(a arrow.Array, i int, elementType schemapb.DataTyp
return nil, false return nil, false
} }
// Validate dimension for vector types that have fixed dimensions valuesArray := arr.ListValues()
if dim > 0 && totalElements%dim != 0 { binaryArray, ok := valuesArray.(*array.FixedSizeBinary)
// Dimension mismatch - data corruption or schema inconsistency if !ok {
return nil, false // empty array
return nil, true
} }
valuesArray := arr.ListValues() numVectors := int(totalElements)
// Helper function to extract byte vectors from FixedSizeBinary array
extractByteVectors := func(bytesPerVector int64) []byte {
totalBytes := numVectors * int(bytesPerVector)
data := make([]byte, totalBytes)
for j := 0; j < numVectors; j++ {
vectorIndex := int(start) + j
vectorData := binaryArray.Value(vectorIndex)
copy(data[j*int(bytesPerVector):], vectorData)
}
return data
}
switch elementType { switch elementType {
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
floatArray, ok := valuesArray.(*array.Float32) totalFloats := numVectors * int(dim)
if !ok { floatData := make([]float32, totalFloats)
return nil, false for j := 0; j < numVectors; j++ {
vectorIndex := int(start) + j
binaryData := binaryArray.Value(vectorIndex)
vectorFloats := arrow.Float32Traits.CastFromBytes(binaryData)
copy(floatData[j*int(dim):], vectorFloats)
} }
// Handle data copying based on shouldCopy parameter return &schemapb.VectorField{
var floatData []float32
if shouldCopy {
// Explicitly requested copy
floatData = make([]float32, totalElements)
for j := start; j < end; j++ {
floatData[j-start] = floatArray.Value(int(j))
}
} else {
// Try to avoid copying - use a slice of the underlying data
// This creates a slice that references the same underlying array
allData := floatArray.Float32Values()
floatData = allData[start:end]
}
vectorField := &schemapb.VectorField{
Dim: dim, Dim: dim,
Data: &schemapb.VectorField_FloatVector{ Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{ FloatVector: &schemapb.FloatArray{
Data: floatData, Data: floatData,
}, },
}, },
} }, true
return vectorField, true
case schemapb.DataType_BinaryVector: case schemapb.DataType_BinaryVector:
panic("BinaryVector in VectorArray deserialization not implemented yet") return &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: extractByteVectors((dim + 7) / 8),
},
}, true
case schemapb.DataType_Float16Vector: case schemapb.DataType_Float16Vector:
panic("Float16Vector in VectorArray deserialization not implemented yet") return &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_Float16Vector{
Float16Vector: extractByteVectors(dim * 2),
},
}, true
case schemapb.DataType_BFloat16Vector: case schemapb.DataType_BFloat16Vector:
panic("BFloat16Vector in VectorArray deserialization not implemented yet") return &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_Bfloat16Vector{
Bfloat16Vector: extractByteVectors(dim * 2),
},
}, true
case schemapb.DataType_Int8Vector: case schemapb.DataType_Int8Vector:
panic("Int8Vector in VectorArray deserialization not implemented yet") return &schemapb.VectorField{
Dim: dim,
Data: &schemapb.VectorField_Int8Vector{
Int8Vector: extractByteVectors(dim),
},
}, true
case schemapb.DataType_SparseFloatVector: case schemapb.DataType_SparseFloatVector:
panic("SparseFloatVector in VectorArray deserialization not implemented yet") panic("SparseFloatVector in VectorArray deserialization not implemented yet")
default: default:

View File

@ -268,41 +268,48 @@ func TestCalculateArraySize(t *testing.T) {
} }
func TestArrayOfVectorArrowType(t *testing.T) { func TestArrayOfVectorArrowType(t *testing.T) {
dim := 128 // Test dimension
tests := []struct { tests := []struct {
name string name string
elementType schemapb.DataType elementType schemapb.DataType
dim int
expectedChild arrow.DataType expectedChild arrow.DataType
}{ }{
{ {
name: "FloatVector", name: "FloatVector",
elementType: schemapb.DataType_FloatVector, elementType: schemapb.DataType_FloatVector,
expectedChild: arrow.PrimitiveTypes.Float32, dim: dim,
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim * 4},
}, },
{ {
name: "BinaryVector", name: "BinaryVector",
elementType: schemapb.DataType_BinaryVector, elementType: schemapb.DataType_BinaryVector,
expectedChild: arrow.PrimitiveTypes.Uint8, dim: dim,
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: (dim + 7) / 8},
}, },
{ {
name: "Float16Vector", name: "Float16Vector",
elementType: schemapb.DataType_Float16Vector, elementType: schemapb.DataType_Float16Vector,
expectedChild: arrow.PrimitiveTypes.Uint8, dim: dim,
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
}, },
{ {
name: "BFloat16Vector", name: "BFloat16Vector",
elementType: schemapb.DataType_BFloat16Vector, elementType: schemapb.DataType_BFloat16Vector,
expectedChild: arrow.PrimitiveTypes.Uint8, dim: dim,
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
}, },
{ {
name: "Int8Vector", name: "Int8Vector",
elementType: schemapb.DataType_Int8Vector, elementType: schemapb.DataType_Int8Vector,
expectedChild: arrow.PrimitiveTypes.Int8, dim: dim,
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
arrowType := getArrayOfVectorArrowType(tt.elementType) arrowType := getArrayOfVectorArrowType(tt.elementType, tt.dim)
assert.NotNil(t, arrowType) assert.NotNil(t, arrowType)
listType, ok := arrowType.(*arrow.ListType) listType, ok := arrowType.(*arrow.ListType)

View File

@ -1665,20 +1665,21 @@ func SortFieldBinlogs(fieldBinlogs map[int64]*datapb.FieldBinlog) []*datapb.Fiel
} }
// VectorArrayToArrowType converts VectorArray element type to the corresponding Arrow type // VectorArrayToArrowType converts VectorArray element type to the corresponding Arrow type
// Note: This returns the element type (e.g., float32), not a list type // Note: This returns the element type (e.g., FixedSizeBinary), not a list type
// The caller is responsible for wrapping it in a list if needed // The caller is responsible for wrapping it in a list if needed
func VectorArrayToArrowType(elementType schemapb.DataType) (arrow.DataType, error) { func VectorArrayToArrowType(elementType schemapb.DataType, dim int) (arrow.DataType, error) {
switch elementType { switch elementType {
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
return arrow.PrimitiveTypes.Float32, nil // Each vector is stored as a fixed-size binary chunk
return &arrow.FixedSizeBinaryType{ByteWidth: dim * 4}, nil
case schemapb.DataType_BinaryVector: case schemapb.DataType_BinaryVector:
return nil, merr.WrapErrParameterInvalidMsg("BinaryVector in VectorArray not implemented yet") return &arrow.FixedSizeBinaryType{ByteWidth: (dim + 7) / 8}, nil
case schemapb.DataType_Float16Vector: case schemapb.DataType_Float16Vector:
return nil, merr.WrapErrParameterInvalidMsg("Float16Vector in VectorArray not implemented yet") return &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, nil
case schemapb.DataType_BFloat16Vector: case schemapb.DataType_BFloat16Vector:
return nil, merr.WrapErrParameterInvalidMsg("BFloat16Vector in VectorArray not implemented yet") return &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, nil
case schemapb.DataType_Int8Vector: case schemapb.DataType_Int8Vector:
return nil, merr.WrapErrParameterInvalidMsg("Int8Vector in VectorArray not implemented yet") return &arrow.FixedSizeBinaryType{ByteWidth: dim}, nil
default: default:
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", elementType.String())) return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", elementType.String()))
} }

View File

@ -1845,34 +1845,49 @@ func ReadVectorArrayData(pcr *FieldReader, count int64) (any, error) {
if chunk.NullN() > 0 { if chunk.NullN() > 0 {
return nil, WrapNullRowErr(pcr.field) return nil, WrapNullRowErr(pcr.field)
} }
// ArrayOfVector is stored as list of fixed size binary
listReader, ok := chunk.(*array.List) listReader, ok := chunk.(*array.List)
if !ok { if !ok {
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name()) return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
} }
listFloat32Reader, ok := listReader.ListValues().(*array.Float32)
fixedBinaryReader, ok := listReader.ListValues().(*array.FixedSizeBinary)
if !ok { if !ok {
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name()) return nil, WrapTypeErr(pcr.field, listReader.ListValues().DataType().Name())
} }
// Check that each vector has the correct byte size (dim * 4 bytes for float32)
expectedByteSize := int(dim) * 4
actualByteSize := fixedBinaryReader.DataType().(*arrow.FixedSizeBinaryType).ByteWidth
if actualByteSize != expectedByteSize {
return nil, merr.WrapErrImportFailed(fmt.Sprintf("vector byte size mismatch: expected %d, got %d for field '%s'",
expectedByteSize, actualByteSize, pcr.field.GetName()))
}
offsets := listReader.Offsets() offsets := listReader.Offsets()
for i := 1; i < len(offsets); i++ { for i := 1; i < len(offsets); i++ {
start, end := offsets[i-1], offsets[i] start, end := offsets[i-1], offsets[i]
floatCount := end - start vectorCount := end - start
if floatCount%int32(dim) != 0 {
return nil, merr.WrapErrImportFailed(fmt.Sprintf("vectors in VectorArray should be aligned with dim: %d", dim))
}
arrLength := floatCount / int32(dim) if err = common.CheckArrayCapacity(int(vectorCount), maxCapacity, pcr.field); err != nil {
if err = common.CheckArrayCapacity(int(arrLength), maxCapacity, pcr.field); err != nil {
return nil, err return nil, err
} }
arrData := make([]float32, floatCount) // Convert binary data to float32 array using arrow's built-in conversion
copy(arrData, listFloat32Reader.Float32Values()[start:end]) totalFloats := vectorCount * int32(dim)
floatData := make([]float32, totalFloats)
for j := int32(0); j < vectorCount; j++ {
vectorIndex := start + j
binaryData := fixedBinaryReader.Value(int(vectorIndex))
vectorFloats := arrow.Float32Traits.CastFromBytes(binaryData)
copy(floatData[j*int32(dim):(j+1)*int32(dim)], vectorFloats)
}
data = append(data, &schemapb.VectorField{ data = append(data, &schemapb.VectorField{
Dim: dim, Dim: dim,
Data: &schemapb.VectorField_FloatVector{ Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{ FloatVector: &schemapb.FloatArray{
Data: arrData, Data: floatData,
}, },
}, },
}) })

View File

@ -195,6 +195,8 @@ func isArrowDataTypeConvertible(src arrow.DataType, dst arrow.DataType, field *s
return valid return valid
} }
return false return false
case arrow.FIXED_SIZE_BINARY:
return dstType == arrow.FIXED_SIZE_BINARY
default: default:
return false return false
} }
@ -293,9 +295,11 @@ func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.Da
Metadata: arrow.Metadata{}, Metadata: arrow.Metadata{},
}), nil }), nil
case schemapb.DataType_ArrayOfVector: case schemapb.DataType_ArrayOfVector:
// VectorArrayToArrowType now returns the element type (e.g., float32) dim, err := typeutil.GetDim(field)
// We wrap it in a single list to get list<float32> (flattened) if err != nil {
elemType, err := storage.VectorArrayToArrowType(field.GetElementType()) return nil, err
}
elemType, err := storage.VectorArrayToArrowType(field.GetElementType(), int(dim))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -52,11 +52,12 @@ const (
) )
var ( var (
FloatVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE} // const // all consts
SparseFloatVectorMetrics = []string{metric.IP, metric.BM25} // const FloatVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE}
BinaryVectorMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE, metric.MHJACCARD} // const SparseFloatVectorMetrics = []string{metric.IP, metric.BM25}
IntVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE} // const BinaryVectorMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE, metric.MHJACCARD}
EmbListMetrics = []string{metric.MaxSim} // const IntVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE}
EmbListMetrics = []string{metric.MaxSim, metric.MaxSimCosine, metric.MaxSimL2, metric.MaxSimIP, metric.MaxSimHamming, metric.MaxSimJaccard}
) )
// BinIDMapMetrics is a set of all metric types supported for binary vector. // BinIDMapMetrics is a set of all metric types supported for binary vector.

View File

@ -1,6 +1,7 @@
package testutil package testutil
import ( import (
"encoding/binary"
"fmt" "fmt"
"math" "math"
"math/rand" "math/rand"
@ -734,31 +735,114 @@ func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.Inser
columns = append(columns, builder.NewListArray()) columns = append(columns, builder.NewListArray())
} }
case schemapb.DataType_ArrayOfVector: case schemapb.DataType_ArrayOfVector:
data := insertData.Data[fieldID].(*storage.VectorArrayFieldData).Data vectorArrayData := insertData.Data[fieldID].(*storage.VectorArrayFieldData)
rows := len(data) dim, err := typeutil.GetDim(field)
if err != nil {
return nil, err
}
elemType, err := storage.VectorArrayToArrowType(elementType, int(dim))
if err != nil {
return nil, err
}
// Create ListBuilder with "item" field name to match convertToArrowDataType
// Always represented as a list of fixed-size binary values
listBuilder := array.NewListBuilderWithField(mem, arrow.Field{
Name: "item",
Type: elemType,
Nullable: true,
Metadata: arrow.Metadata{},
})
fixedSizeBuilder, ok := listBuilder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
if !ok {
return nil, fmt.Errorf("unexpected list value builder for VectorArray field %s: %T", field.GetName(), listBuilder.ValueBuilder())
}
vectorArrayData.Dim = dim
bytesPerVector := fixedSizeBuilder.Type().(*arrow.FixedSizeBinaryType).ByteWidth
appendBinarySlice := func(data []byte, stride int) error {
if stride == 0 {
return fmt.Errorf("zero stride for VectorArray field %s", field.GetName())
}
if len(data)%stride != 0 {
return fmt.Errorf("vector array data length %d is not divisible by stride %d for field %s", len(data), stride, field.GetName())
}
for offset := 0; offset < len(data); offset += stride {
fixedSizeBuilder.Append(data[offset : offset+stride])
}
return nil
}
for _, vectorField := range vectorArrayData.Data {
if vectorField == nil {
listBuilder.Append(false)
continue
}
listBuilder.Append(true)
switch elementType { switch elementType {
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
// ArrayOfVector is flattened in Arrow - just a list of floats floatArray := vectorField.GetFloatVector()
// where total floats = dim * num_vectors if floatArray == nil {
builder := array.NewListBuilder(mem, &arrow.Float32Type{}) return nil, fmt.Errorf("expected FloatVector data for field %s", field.GetName())
valueBuilder := builder.ValueBuilder().(*array.Float32Builder) }
data := floatArray.GetData()
for i := 0; i < rows; i++ { if len(data) == 0 {
vectorArray := data[i].GetFloatVector()
if vectorArray == nil || len(vectorArray.GetData()) == 0 {
builder.AppendNull()
continue continue
} }
builder.Append(true) if len(data)%int(dim) != 0 {
// Append all flattened vector data return nil, fmt.Errorf("float vector data length %d is not divisible by dim %d for field %s", len(data), dim, field.GetName())
valueBuilder.AppendValues(vectorArray.GetData(), nil) }
for offset := 0; offset < len(data); offset += int(dim) {
vectorBytes := make([]byte, bytesPerVector)
for j := 0; j < int(dim); j++ {
binary.LittleEndian.PutUint32(vectorBytes[j*4:], math.Float32bits(data[offset+j]))
}
fixedSizeBuilder.Append(vectorBytes)
}
case schemapb.DataType_BinaryVector:
binaryData := vectorField.GetBinaryVector()
if len(binaryData) == 0 {
continue
}
bytesPer := int((dim + 7) / 8)
if err := appendBinarySlice(binaryData, bytesPer); err != nil {
return nil, err
}
case schemapb.DataType_Float16Vector:
float16Data := vectorField.GetFloat16Vector()
if len(float16Data) == 0 {
continue
}
if err := appendBinarySlice(float16Data, int(dim)*2); err != nil {
return nil, err
}
case schemapb.DataType_BFloat16Vector:
bfloat16Data := vectorField.GetBfloat16Vector()
if len(bfloat16Data) == 0 {
continue
}
if err := appendBinarySlice(bfloat16Data, int(dim)*2); err != nil {
return nil, err
}
case schemapb.DataType_Int8Vector:
int8Data := vectorField.GetInt8Vector()
if len(int8Data) == 0 {
continue
}
if err := appendBinarySlice(int8Data, int(dim)); err != nil {
return nil, err
} }
columns = append(columns, builder.NewListArray())
default: default:
return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String()) return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String())
} }
} }
columns = append(columns, listBuilder.NewListArray())
}
} }
return columns, nil return columns, nil
} }

View File

@ -45,6 +45,7 @@ enum VectorType {
EmbListFloat16Vector = 7; EmbListFloat16Vector = 7;
EmbListBFloat16Vector = 8; EmbListBFloat16Vector = 8;
EmbListInt8Vector = 9; EmbListInt8Vector = 9;
EmbListBinaryVector = 10;
}; };
message GenericValue { message GenericValue {

View File

@ -184,6 +184,7 @@ const (
VectorType_EmbListFloat16Vector VectorType = 7 VectorType_EmbListFloat16Vector VectorType = 7
VectorType_EmbListBFloat16Vector VectorType = 8 VectorType_EmbListBFloat16Vector VectorType = 8
VectorType_EmbListInt8Vector VectorType = 9 VectorType_EmbListInt8Vector VectorType = 9
VectorType_EmbListBinaryVector VectorType = 10
) )
// Enum value maps for VectorType. // Enum value maps for VectorType.
@ -199,6 +200,7 @@ var (
7: "EmbListFloat16Vector", 7: "EmbListFloat16Vector",
8: "EmbListBFloat16Vector", 8: "EmbListBFloat16Vector",
9: "EmbListInt8Vector", 9: "EmbListInt8Vector",
10: "EmbListBinaryVector",
} }
VectorType_value = map[string]int32{ VectorType_value = map[string]int32{
"BinaryVector": 0, "BinaryVector": 0,
@ -211,6 +213,7 @@ var (
"EmbListFloat16Vector": 7, "EmbListFloat16Vector": 7,
"EmbListBFloat16Vector": 8, "EmbListBFloat16Vector": 8,
"EmbListInt8Vector": 9, "EmbListInt8Vector": 9,
"EmbListBinaryVector": 10,
} }
) )
@ -3808,7 +3811,7 @@ var file_plan_proto_rawDesc = []byte{
0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03, 0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03,
0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64,
0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74,
0x68, 0x10, 0x06, 0x2a, 0xe1, 0x01, 0x0a, 0x0a, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79, 0x68, 0x10, 0x06, 0x2a, 0xfa, 0x01, 0x0a, 0x0a, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79,
0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74,
0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63,
0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36,
@ -3822,22 +3825,23 @@ var file_plan_proto_rawDesc = []byte{
0x74, 0x6f, 0x72, 0x10, 0x07, 0x12, 0x19, 0x0a, 0x15, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x74, 0x6f, 0x72, 0x10, 0x07, 0x12, 0x19, 0x0a, 0x15, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74,
0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x08, 0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x08,
0x12, 0x15, 0x0a, 0x11, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x49, 0x6e, 0x74, 0x38, 0x56, 0x12, 0x15, 0x0a, 0x11, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x49, 0x6e, 0x74, 0x38, 0x56,
0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x09, 0x2a, 0x3e, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x09, 0x12, 0x17, 0x0a, 0x13, 0x45, 0x6d, 0x62, 0x4c, 0x69,
0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x73, 0x74, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x0a,
0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x57, 0x65, 0x69, 0x67, 0x68, 0x74, 0x10, 0x00, 0x12, 0x2a, 0x3e, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65,
0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x12, 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65,
0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x10, 0x01, 0x2a, 0x3d, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x57, 0x65, 0x69, 0x67, 0x68, 0x74, 0x10, 0x00, 0x12, 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63,
0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x10, 0x01,
0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x79, 0x10, 0x2a, 0x3d, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65,
0x00, 0x12, 0x13, 0x0a, 0x0f, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x12, 0x18, 0x0a, 0x14, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65,
0x65, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x2a, 0x34, 0x0a, 0x09, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x79, 0x10, 0x00, 0x12, 0x13, 0x0a, 0x0f, 0x46, 0x75,
0x6f, 0x64, 0x65, 0x12, 0x15, 0x0a, 0x11, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x2a,
0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x79, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x6f, 0x34, 0x0a, 0x09, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x15, 0x0a, 0x11,
0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x42, 0x31, 0x5a, 0x2f, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c,
0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x79, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65,
0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (

View File

@ -44,5 +44,11 @@ const (
EMPTY MetricType = "" EMPTY MetricType = ""
// The same with MaxSimCosine
MaxSim MetricType = "MAX_SIM" MaxSim MetricType = "MAX_SIM"
MaxSimCosine MetricType = "MAX_SIM_COSINE"
MaxSimL2 MetricType = "MAX_SIM_L2"
MaxSimIP MetricType = "MAX_SIM_IP"
MaxSimHamming MetricType = "MAX_SIM_HAMMING"
MaxSimJaccard MetricType = "MAX_SIM_JACCARD"
) )

View File

@ -176,7 +176,7 @@ func (s *ArrayStructDataNodeSuite) loadCollection(collectionName string) {
CollectionName: collectionName, CollectionName: collectionName,
FieldName: subFieldName, FieldName: subFieldName,
IndexName: "array_of_vector_index", IndexName: "array_of_vector_index",
ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexEmbListHNSW, metric.MaxSim), ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexHNSW, metric.MaxSim),
}) })
s.NoError(err) s.NoError(err)
s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success) s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success)
@ -318,7 +318,7 @@ func (s *ArrayStructDataNodeSuite) query(collectionName string) {
roundDecimal := -1 roundDecimal := -1
subFieldName := proxy.ConcatStructFieldName(integration.StructArrayField, integration.StructSubFloatVecField) subFieldName := proxy.ConcatStructFieldName(integration.StructArrayField, integration.StructSubFloatVecField)
params := integration.GetSearchParams(integration.IndexEmbListHNSW, metric.MaxSim) params := integration.GetSearchParams(integration.IndexHNSW, metric.MaxSim)
searchReq := integration.ConstructEmbeddingListSearchRequest("", collectionName, expr, searchReq := integration.ConstructEmbeddingListSearchRequest("", collectionName, expr,
subFieldName, schemapb.DataType_FloatVector, []string{integration.StructArrayField}, metric.MaxSim, params, nq, s.dim, topk, roundDecimal) subFieldName, schemapb.DataType_FloatVector, []string{integration.StructArrayField}, metric.MaxSim, params, nq, s.dim, topk, roundDecimal)

View File

@ -246,7 +246,7 @@ func (s *TestArrayStructSuite) run() {
func (s *TestArrayStructSuite) TestGetVector_ArrayStruct_FloatVector() { func (s *TestArrayStructSuite) TestGetVector_ArrayStruct_FloatVector() {
s.nq = 10 s.nq = 10
s.topK = 10 s.topK = 10
s.indexType = integration.IndexEmbListHNSW s.indexType = integration.IndexHNSW
s.metricType = metric.MaxSim s.metricType = metric.MaxSim
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.run() s.run()

View File

@ -102,7 +102,7 @@ func (s *BulkInsertSuite) PrepareSourceCollection(dim int, dmlGroup *DMLGroup) *
CollectionName: collectionName, CollectionName: collectionName,
FieldName: name, FieldName: name,
IndexName: "array_of_vector_index", IndexName: "array_of_vector_index",
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexEmbListHNSW, metric.MaxSim), ExtraParams: integration.ConstructIndexParam(dim, integration.IndexHNSW, metric.MaxSim),
}) })
s.NoError(err) s.NoError(err)
s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success) s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success)

View File

@ -303,7 +303,7 @@ func (s *BulkInsertSuite) TestImportWithVectorArray() {
for _, fileType := range fileTypeArr { for _, fileType := range fileTypeArr {
s.fileType = fileType s.fileType = fileType
s.vecType = schemapb.DataType_FloatVector s.vecType = schemapb.DataType_FloatVector
s.indexType = integration.IndexEmbListHNSW s.indexType = integration.IndexHNSW
s.metricType = metric.MaxSim s.metricType = metric.MaxSim
s.runForStructArray() s.runForStructArray()
} }

View File

@ -43,7 +43,6 @@ const (
IndexDISKANN = "DISKANN" IndexDISKANN = "DISKANN"
IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX" IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX"
IndexSparseWand = "SPARSE_WAND" IndexSparseWand = "SPARSE_WAND"
IndexEmbListHNSW = "EMB_LIST_HNSW"
) )
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) { func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
@ -169,15 +168,6 @@ func ConstructIndexParam(dim int, indexType string, metricType string) []*common
Key: "efConstruction", Key: "efConstruction",
Value: "200", Value: "200",
}) })
case IndexEmbListHNSW:
params = append(params, &commonpb.KeyValuePair{
Key: "M",
Value: "16",
})
params = append(params, &commonpb.KeyValuePair{
Key: "efConstruction",
Value: "200",
})
case IndexSparseInvertedIndex: case IndexSparseInvertedIndex:
case IndexSparseWand: case IndexSparseWand:
case IndexDISKANN: case IndexDISKANN:
@ -195,7 +185,6 @@ func GetSearchParams(indexType string, metricType string) map[string]any {
case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8, IndexFaissIvfPQ, IndexScaNN: case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8, IndexFaissIvfPQ, IndexScaNN:
params["nprobe"] = 8 params["nprobe"] = 8
case IndexHNSW: case IndexHNSW:
case IndexEmbListHNSW:
params["ef"] = 200 params["ef"] = 200
case IndexDISKANN: case IndexDISKANN:
params["search_list"] = 20 params["search_list"] = 20