mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
feat: impl StructArray -- support more types of vector in STRUCT (#44736)
ref: https://github.com/milvus-io/milvus/issues/42148 --------- Signed-off-by: SpadeA <tangchenjie1210@gmail.com> Signed-off-by: SpadeA-Tang <tangchenjie1210@gmail.com>
This commit is contained in:
parent
5ad8a29c0b
commit
c4f3f0ce4c
@ -438,11 +438,11 @@ class VectorArrayChunk : public Chunk {
|
|||||||
offsets_lens_ = reinterpret_cast<uint32_t*>(data);
|
offsets_lens_ = reinterpret_cast<uint32_t*>(data);
|
||||||
|
|
||||||
auto offset = 0;
|
auto offset = 0;
|
||||||
lims_.reserve(row_nums_ + 1);
|
offsets_.reserve(row_nums_ + 1);
|
||||||
lims_.push_back(offset);
|
offsets_.push_back(offset);
|
||||||
for (int64_t i = 0; i < row_nums_; i++) {
|
for (int64_t i = 0; i < row_nums_; i++) {
|
||||||
offset += offsets_lens_[i * 2 + 1];
|
offset += offsets_lens_[i * 2 + 1];
|
||||||
lims_.push_back(offset);
|
offsets_.push_back(offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -507,17 +507,15 @@ class VectorArrayChunk : public Chunk {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const size_t*
|
const size_t*
|
||||||
Lims() const {
|
Offsets() const {
|
||||||
return lims_.data();
|
return offsets_.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int64_t dim_;
|
int64_t dim_;
|
||||||
uint32_t* offsets_lens_;
|
uint32_t* offsets_lens_;
|
||||||
milvus::DataType element_type_;
|
milvus::DataType element_type_;
|
||||||
// The name 'Lims' is consistent with knowhere::DataSet::SetLims which describes the number of vectors
|
std::vector<size_t> offsets_;
|
||||||
// in each vector array (embedding list). This is needed as vectors are flattened in the chunk.
|
|
||||||
std::vector<size_t> lims_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class SparseFloatVectorChunk : public Chunk {
|
class SparseFloatVectorChunk : public Chunk {
|
||||||
|
|||||||
@ -339,34 +339,9 @@ VectorArrayChunkWriter::write(const arrow::ArrayVector& array_vec) {
|
|||||||
target_ = std::make_shared<MemChunkTarget>(total_size);
|
target_ = std::make_shared<MemChunkTarget>(total_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (element_type_) {
|
// Serialization, the format is: [offsets_lens][all_vector_data_concatenated]
|
||||||
case milvus::DataType::VECTOR_FLOAT:
|
|
||||||
writeFloatVectorArray(array_vec);
|
|
||||||
break;
|
|
||||||
case milvus::DataType::VECTOR_BINARY:
|
|
||||||
ThrowInfo(NotImplemented,
|
|
||||||
"BinaryVector in VectorArray not implemented yet");
|
|
||||||
case milvus::DataType::VECTOR_FLOAT16:
|
|
||||||
ThrowInfo(NotImplemented,
|
|
||||||
"Float16Vector in VectorArray not implemented yet");
|
|
||||||
case milvus::DataType::VECTOR_BFLOAT16:
|
|
||||||
ThrowInfo(NotImplemented,
|
|
||||||
"BFloat16Vector in VectorArray not implemented yet");
|
|
||||||
case milvus::DataType::VECTOR_INT8:
|
|
||||||
ThrowInfo(NotImplemented,
|
|
||||||
"Int8Vector in VectorArray not implemented yet");
|
|
||||||
default:
|
|
||||||
ThrowInfo(NotImplemented,
|
|
||||||
"Unsupported element type in VectorArray: {}",
|
|
||||||
static_cast<int>(element_type_));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void
|
|
||||||
VectorArrayChunkWriter::writeFloatVectorArray(
|
|
||||||
const arrow::ArrayVector& array_vec) {
|
|
||||||
std::vector<uint32_t> offsets_lens;
|
std::vector<uint32_t> offsets_lens;
|
||||||
std::vector<const float*> float_data_ptrs;
|
std::vector<const uint8_t*> vector_data_ptrs;
|
||||||
std::vector<size_t> data_sizes;
|
std::vector<size_t> data_sizes;
|
||||||
|
|
||||||
uint32_t current_offset =
|
uint32_t current_offset =
|
||||||
@ -375,25 +350,27 @@ VectorArrayChunkWriter::writeFloatVectorArray(
|
|||||||
for (const auto& array_data : array_vec) {
|
for (const auto& array_data : array_vec) {
|
||||||
auto list_array =
|
auto list_array =
|
||||||
std::static_pointer_cast<arrow::ListArray>(array_data);
|
std::static_pointer_cast<arrow::ListArray>(array_data);
|
||||||
auto float_values =
|
auto binary_values =
|
||||||
std::static_pointer_cast<arrow::FloatArray>(list_array->values());
|
std::static_pointer_cast<arrow::FixedSizeBinaryArray>(
|
||||||
const float* raw_floats = float_values->raw_values();
|
list_array->values());
|
||||||
const int32_t* list_offsets = list_array->raw_value_offsets();
|
const int32_t* list_offsets = list_array->raw_value_offsets();
|
||||||
|
int byte_width = binary_values->byte_width();
|
||||||
|
|
||||||
// Generate offsets and lengths for each row
|
// Generate offsets and lengths for each row
|
||||||
// Each list contains multiple float vectors which are flattened, so the float count
|
// Each list contains multiple vectors, each stored as a fixed-size binary chunk
|
||||||
// in each list is vector count * dim.
|
|
||||||
for (int64_t i = 0; i < list_array->length(); i++) {
|
for (int64_t i = 0; i < list_array->length(); i++) {
|
||||||
auto start_idx = list_offsets[i];
|
auto start_idx = list_offsets[i];
|
||||||
auto end_idx = list_offsets[i + 1];
|
auto end_idx = list_offsets[i + 1];
|
||||||
auto vector_count = (end_idx - start_idx) / dim_;
|
auto vector_count = end_idx - start_idx;
|
||||||
auto byte_size = (end_idx - start_idx) * sizeof(float);
|
auto byte_size = vector_count * byte_width;
|
||||||
|
|
||||||
offsets_lens.push_back(current_offset);
|
offsets_lens.push_back(current_offset);
|
||||||
offsets_lens.push_back(static_cast<uint32_t>(vector_count));
|
offsets_lens.push_back(static_cast<uint32_t>(vector_count));
|
||||||
|
|
||||||
float_data_ptrs.push_back(raw_floats + start_idx);
|
for (int j = start_idx; j < end_idx; j++) {
|
||||||
data_sizes.push_back(byte_size);
|
vector_data_ptrs.push_back(binary_values->GetValue(j));
|
||||||
|
data_sizes.push_back(byte_width);
|
||||||
|
}
|
||||||
|
|
||||||
current_offset += byte_size;
|
current_offset += byte_size;
|
||||||
}
|
}
|
||||||
@ -409,8 +386,8 @@ VectorArrayChunkWriter::writeFloatVectorArray(
|
|||||||
}
|
}
|
||||||
target_->write(&offsets_lens.back(), sizeof(uint32_t)); // final offset
|
target_->write(&offsets_lens.back(), sizeof(uint32_t)); // final offset
|
||||||
|
|
||||||
for (size_t i = 0; i < float_data_ptrs.size(); i++) {
|
for (size_t i = 0; i < vector_data_ptrs.size(); i++) {
|
||||||
target_->write(float_data_ptrs[i], data_sizes[i]);
|
target_->write(vector_data_ptrs[i], data_sizes[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -427,19 +404,18 @@ VectorArrayChunkWriter::calculateTotalSize(
|
|||||||
std::static_pointer_cast<arrow::ListArray>(array_data);
|
std::static_pointer_cast<arrow::ListArray>(array_data);
|
||||||
|
|
||||||
switch (element_type_) {
|
switch (element_type_) {
|
||||||
case milvus::DataType::VECTOR_FLOAT: {
|
case milvus::DataType::VECTOR_FLOAT:
|
||||||
auto float_values = std::static_pointer_cast<arrow::FloatArray>(
|
|
||||||
list_array->values());
|
|
||||||
total_size += float_values->length() * sizeof(float);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case milvus::DataType::VECTOR_BINARY:
|
case milvus::DataType::VECTOR_BINARY:
|
||||||
case milvus::DataType::VECTOR_FLOAT16:
|
case milvus::DataType::VECTOR_FLOAT16:
|
||||||
case milvus::DataType::VECTOR_BFLOAT16:
|
case milvus::DataType::VECTOR_BFLOAT16:
|
||||||
case milvus::DataType::VECTOR_INT8:
|
case milvus::DataType::VECTOR_INT8: {
|
||||||
ThrowInfo(NotImplemented,
|
auto binary_values =
|
||||||
"Element type {} in VectorArray not implemented yet",
|
std::static_pointer_cast<arrow::FixedSizeBinaryArray>(
|
||||||
static_cast<int>(element_type_));
|
list_array->values());
|
||||||
|
total_size +=
|
||||||
|
binary_values->length() * binary_values->byte_width();
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
ThrowInfo(DataTypeInvalid,
|
ThrowInfo(DataTypeInvalid,
|
||||||
"Invalid element type {} for VectorArray",
|
"Invalid element type {} for VectorArray",
|
||||||
|
|||||||
@ -273,9 +273,6 @@ class VectorArrayChunkWriter : public ChunkWriterBase {
|
|||||||
finish() override;
|
finish() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void
|
|
||||||
writeFloatVectorArray(const arrow::ArrayVector& array_vec);
|
|
||||||
|
|
||||||
size_t
|
size_t
|
||||||
calculateTotalSize(const arrow::ArrayVector& array_vec);
|
calculateTotalSize(const arrow::ArrayVector& array_vec);
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,7 @@
|
|||||||
#include "common/Exception.h"
|
#include "common/Exception.h"
|
||||||
#include "common/FieldDataInterface.h"
|
#include "common/FieldDataInterface.h"
|
||||||
#include "common/Json.h"
|
#include "common/Json.h"
|
||||||
|
#include "index/Utils.h"
|
||||||
#include "simdjson/padded_string.h"
|
#include "simdjson/padded_string.h"
|
||||||
|
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
@ -348,46 +349,47 @@ FieldDataImpl<Type, is_type_entire_row>::FillFieldData(
|
|||||||
std::vector<VectorArray> values(element_count);
|
std::vector<VectorArray> values(element_count);
|
||||||
|
|
||||||
switch (element_type) {
|
switch (element_type) {
|
||||||
case DataType::VECTOR_FLOAT: {
|
case DataType::VECTOR_FLOAT:
|
||||||
auto float_array =
|
case DataType::VECTOR_BINARY:
|
||||||
std::dynamic_pointer_cast<arrow::FloatArray>(
|
case DataType::VECTOR_FLOAT16:
|
||||||
|
case DataType::VECTOR_BFLOAT16:
|
||||||
|
case DataType::VECTOR_INT8: {
|
||||||
|
// All vector types use FixedSizeBinaryArray and have the same serialization logic
|
||||||
|
auto binary_array =
|
||||||
|
std::dynamic_pointer_cast<arrow::FixedSizeBinaryArray>(
|
||||||
values_array);
|
values_array);
|
||||||
AssertInfo(
|
AssertInfo(binary_array != nullptr,
|
||||||
float_array != nullptr,
|
"Expected FixedSizeBinaryArray for VectorArray "
|
||||||
"Expected FloatArray for VECTOR_FLOAT element type");
|
"element");
|
||||||
|
|
||||||
|
// Calculate bytes per vector using the unified function
|
||||||
|
auto bytes_per_vec =
|
||||||
|
milvus::vector_bytes_per_element(element_type, dim);
|
||||||
|
|
||||||
for (size_t index = 0; index < element_count; ++index) {
|
for (size_t index = 0; index < element_count; ++index) {
|
||||||
int64_t start_offset = list_array->value_offset(index);
|
int64_t start_offset = list_array->value_offset(index);
|
||||||
int64_t end_offset =
|
int64_t end_offset =
|
||||||
list_array->value_offset(index + 1);
|
list_array->value_offset(index + 1);
|
||||||
int64_t num_floats = end_offset - start_offset;
|
int64_t num_vectors = end_offset - start_offset;
|
||||||
AssertInfo(num_floats % dim == 0,
|
|
||||||
"Invalid data: number of floats ({}) not "
|
|
||||||
"divisible by "
|
|
||||||
"dimension ({})",
|
|
||||||
num_floats,
|
|
||||||
dim);
|
|
||||||
|
|
||||||
int num_vectors = num_floats / dim;
|
auto data_size = num_vectors * bytes_per_vec;
|
||||||
const float* data_ptr =
|
auto data_ptr = std::make_unique<uint8_t[]>(data_size);
|
||||||
float_array->raw_values() + start_offset;
|
|
||||||
values[index] =
|
for (int64_t i = 0; i < num_vectors; i++) {
|
||||||
VectorArray(static_cast<const void*>(data_ptr),
|
const uint8_t* binary_data =
|
||||||
|
binary_array->GetValue(start_offset + i);
|
||||||
|
uint8_t* dest = data_ptr.get() + i * bytes_per_vec;
|
||||||
|
std::memcpy(dest, binary_data, bytes_per_vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
values[index] = VectorArray(
|
||||||
|
static_cast<const void*>(data_ptr.get()),
|
||||||
num_vectors,
|
num_vectors,
|
||||||
dim,
|
dim,
|
||||||
element_type);
|
element_type);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case DataType::VECTOR_BINARY:
|
|
||||||
case DataType::VECTOR_FLOAT16:
|
|
||||||
case DataType::VECTOR_BFLOAT16:
|
|
||||||
case DataType::VECTOR_INT8:
|
|
||||||
ThrowInfo(
|
|
||||||
NotImplemented,
|
|
||||||
"Element type {} in VectorArray not implemented yet",
|
|
||||||
GetDataTypeName(element_type));
|
|
||||||
break;
|
|
||||||
default:
|
default:
|
||||||
ThrowInfo(DataTypeInvalid,
|
ThrowInfo(DataTypeInvalid,
|
||||||
"Unsupported element type {} in VectorArray",
|
"Unsupported element type {} in VectorArray",
|
||||||
|
|||||||
@ -94,8 +94,8 @@ Schema::ConvertToArrowSchema() const {
|
|||||||
std::shared_ptr<arrow::DataType> arrow_data_type = nullptr;
|
std::shared_ptr<arrow::DataType> arrow_data_type = nullptr;
|
||||||
auto data_type = meta.get_data_type();
|
auto data_type = meta.get_data_type();
|
||||||
if (data_type == DataType::VECTOR_ARRAY) {
|
if (data_type == DataType::VECTOR_ARRAY) {
|
||||||
arrow_data_type =
|
arrow_data_type = GetArrowDataTypeForVectorArray(
|
||||||
GetArrowDataTypeForVectorArray(meta.get_element_type());
|
meta.get_element_type(), meta.get_dim());
|
||||||
} else {
|
} else {
|
||||||
arrow_data_type = GetArrowDataType(data_type, dim);
|
arrow_data_type = GetArrowDataType(data_type, dim);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -205,10 +205,23 @@ GetArrowDataType(DataType data_type, int dim = 1) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline std::shared_ptr<arrow::DataType>
|
inline std::shared_ptr<arrow::DataType>
|
||||||
GetArrowDataTypeForVectorArray(DataType elem_type) {
|
GetArrowDataTypeForVectorArray(DataType elem_type, int dim) {
|
||||||
|
if (dim <= 0) {
|
||||||
|
ThrowInfo(DataTypeInvalid, "dim must be provided for VectorArray");
|
||||||
|
}
|
||||||
|
// VectorArray stores vectors as FixedSizeBinaryArray
|
||||||
|
// We must have dim to create the correct fixed_size_binary type
|
||||||
switch (elem_type) {
|
switch (elem_type) {
|
||||||
case DataType::VECTOR_FLOAT:
|
case DataType::VECTOR_FLOAT:
|
||||||
return arrow::list(arrow::float32());
|
return arrow::list(arrow::fixed_size_binary(dim * sizeof(float)));
|
||||||
|
case DataType::VECTOR_BINARY:
|
||||||
|
return arrow::list(arrow::fixed_size_binary((dim + 7) / 8));
|
||||||
|
case DataType::VECTOR_FLOAT16:
|
||||||
|
return arrow::list(arrow::fixed_size_binary(dim * 2));
|
||||||
|
case DataType::VECTOR_BFLOAT16:
|
||||||
|
return arrow::list(arrow::fixed_size_binary(dim * 2));
|
||||||
|
case DataType::VECTOR_INT8:
|
||||||
|
return arrow::list(arrow::fixed_size_binary(dim));
|
||||||
default: {
|
default: {
|
||||||
ThrowInfo(DataTypeInvalid,
|
ThrowInfo(DataTypeInvalid,
|
||||||
fmt::format("failed to get arrow type for vector array, "
|
fmt::format("failed to get arrow type for vector array, "
|
||||||
@ -534,7 +547,10 @@ IsFloatVectorMetricType(const MetricType& metric_type) {
|
|||||||
metric_type == knowhere::metric::IP ||
|
metric_type == knowhere::metric::IP ||
|
||||||
metric_type == knowhere::metric::COSINE ||
|
metric_type == knowhere::metric::COSINE ||
|
||||||
metric_type == knowhere::metric::BM25 ||
|
metric_type == knowhere::metric::BM25 ||
|
||||||
metric_type == knowhere::metric::MAX_SIM;
|
metric_type == knowhere::metric::MAX_SIM ||
|
||||||
|
metric_type == knowhere::metric::MAX_SIM_COSINE ||
|
||||||
|
metric_type == knowhere::metric::MAX_SIM_IP ||
|
||||||
|
metric_type == knowhere::metric::MAX_SIM_L2;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool
|
inline bool
|
||||||
@ -543,14 +559,20 @@ IsBinaryVectorMetricType(const MetricType& metric_type) {
|
|||||||
metric_type == knowhere::metric::JACCARD ||
|
metric_type == knowhere::metric::JACCARD ||
|
||||||
metric_type == knowhere::metric::SUPERSTRUCTURE ||
|
metric_type == knowhere::metric::SUPERSTRUCTURE ||
|
||||||
metric_type == knowhere::metric::SUBSTRUCTURE ||
|
metric_type == knowhere::metric::SUBSTRUCTURE ||
|
||||||
metric_type == knowhere::metric::MHJACCARD;
|
metric_type == knowhere::metric::MHJACCARD ||
|
||||||
|
metric_type == knowhere::metric::MAX_SIM_HAMMING ||
|
||||||
|
metric_type == knowhere::metric::MAX_SIM_JACCARD;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool
|
inline bool
|
||||||
IsIntVectorMetricType(const MetricType& metric_type) {
|
IsIntVectorMetricType(const MetricType& metric_type) {
|
||||||
return metric_type == knowhere::metric::L2 ||
|
return metric_type == knowhere::metric::L2 ||
|
||||||
metric_type == knowhere::metric::IP ||
|
metric_type == knowhere::metric::IP ||
|
||||||
metric_type == knowhere::metric::COSINE;
|
metric_type == knowhere::metric::COSINE ||
|
||||||
|
metric_type == knowhere::metric::MAX_SIM ||
|
||||||
|
metric_type == knowhere::metric::MAX_SIM_COSINE ||
|
||||||
|
metric_type == knowhere::metric::MAX_SIM_IP ||
|
||||||
|
metric_type == knowhere::metric::MAX_SIM_L2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Plus 1 because we can't use greater(>) symbol
|
// Plus 1 because we can't use greater(>) symbol
|
||||||
@ -746,6 +768,30 @@ FromValCase(milvus::proto::plan::GenericValue::ValCase val_case) {
|
|||||||
return DataType::NONE;
|
return DataType::NONE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calculate bytes per vector element for different vector types
|
||||||
|
// Used by VectorArray and storage utilities
|
||||||
|
inline size_t
|
||||||
|
vector_bytes_per_element(const DataType data_type, int64_t dim) {
|
||||||
|
switch (data_type) {
|
||||||
|
case DataType::VECTOR_BINARY:
|
||||||
|
// Binary vector stores bits, so dim represents bit count
|
||||||
|
// Need (dim + 7) / 8 bytes to store dim bits
|
||||||
|
return (dim + 7) / 8;
|
||||||
|
case DataType::VECTOR_FLOAT:
|
||||||
|
return dim * sizeof(float);
|
||||||
|
case DataType::VECTOR_FLOAT16:
|
||||||
|
return dim * sizeof(float16);
|
||||||
|
case DataType::VECTOR_BFLOAT16:
|
||||||
|
return dim * sizeof(bfloat16);
|
||||||
|
case DataType::VECTOR_INT8:
|
||||||
|
return dim * sizeof(int8);
|
||||||
|
default:
|
||||||
|
ThrowInfo(UnexpectedError,
|
||||||
|
fmt::format("invalid data type: {}", data_type));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace milvus
|
} // namespace milvus
|
||||||
template <>
|
template <>
|
||||||
struct fmt::formatter<milvus::DataType> : formatter<string_view> {
|
struct fmt::formatter<milvus::DataType> : formatter<string_view> {
|
||||||
|
|||||||
@ -41,16 +41,8 @@ class VectorArray : public milvus::VectorTrait {
|
|||||||
assert(num_vectors > 0);
|
assert(num_vectors > 0);
|
||||||
assert(dim > 0);
|
assert(dim > 0);
|
||||||
|
|
||||||
switch (element_type) {
|
size_ =
|
||||||
case DataType::VECTOR_FLOAT:
|
num_vectors * milvus::vector_bytes_per_element(element_type, dim);
|
||||||
size_ = num_vectors * dim * sizeof(float);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
ThrowInfo(NotImplemented,
|
|
||||||
"Direct VectorArray construction only supports "
|
|
||||||
"VECTOR_FLOAT, got {}",
|
|
||||||
GetDataTypeName(element_type));
|
|
||||||
}
|
|
||||||
|
|
||||||
data_ = std::make_unique<char[]>(size_);
|
data_ = std::make_unique<char[]>(size_);
|
||||||
std::memcpy(data_.get(), data, size_);
|
std::memcpy(data_.get(), data, size_);
|
||||||
@ -73,8 +65,49 @@ class VectorArray : public milvus::VectorTrait {
|
|||||||
data_ = std::unique_ptr<char[]>(reinterpret_cast<char*>(data));
|
data_ = std::unique_ptr<char[]>(reinterpret_cast<char*>(data));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case VectorFieldProto::kBinaryVector: {
|
||||||
|
element_type_ = DataType::VECTOR_BINARY;
|
||||||
|
int bytes_per_vector = (dim_ + 7) / 8;
|
||||||
|
length_ =
|
||||||
|
vector_field.binary_vector().size() / bytes_per_vector;
|
||||||
|
size_ = vector_field.binary_vector().size();
|
||||||
|
data_ = std::make_unique<char[]>(size_);
|
||||||
|
std::memcpy(
|
||||||
|
data_.get(), vector_field.binary_vector().data(), size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case VectorFieldProto::kFloat16Vector: {
|
||||||
|
element_type_ = DataType::VECTOR_FLOAT16;
|
||||||
|
int bytes_per_element = 2; // 2 bytes per float16
|
||||||
|
length_ = vector_field.float16_vector().size() /
|
||||||
|
(dim_ * bytes_per_element);
|
||||||
|
size_ = vector_field.float16_vector().size();
|
||||||
|
data_ = std::make_unique<char[]>(size_);
|
||||||
|
std::memcpy(
|
||||||
|
data_.get(), vector_field.float16_vector().data(), size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case VectorFieldProto::kBfloat16Vector: {
|
||||||
|
element_type_ = DataType::VECTOR_BFLOAT16;
|
||||||
|
int bytes_per_element = 2; // 2 bytes per bfloat16
|
||||||
|
length_ = vector_field.bfloat16_vector().size() /
|
||||||
|
(dim_ * bytes_per_element);
|
||||||
|
size_ = vector_field.bfloat16_vector().size();
|
||||||
|
data_ = std::make_unique<char[]>(size_);
|
||||||
|
std::memcpy(
|
||||||
|
data_.get(), vector_field.bfloat16_vector().data(), size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case VectorFieldProto::kInt8Vector: {
|
||||||
|
element_type_ = DataType::VECTOR_INT8;
|
||||||
|
length_ = vector_field.int8_vector().size() / dim_;
|
||||||
|
size_ = vector_field.int8_vector().size();
|
||||||
|
data_ = std::make_unique<char[]>(size_);
|
||||||
|
std::memcpy(
|
||||||
|
data_.get(), vector_field.int8_vector().data(), size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
// TODO(SpadeA): add other vector types
|
|
||||||
ThrowInfo(NotImplemented,
|
ThrowInfo(NotImplemented,
|
||||||
"Not implemented vector type: {}",
|
"Not implemented vector type: {}",
|
||||||
static_cast<int>(vector_field.data_case()));
|
static_cast<int>(vector_field.data_case()));
|
||||||
@ -160,8 +193,24 @@ class VectorArray : public milvus::VectorTrait {
|
|||||||
return reinterpret_cast<VectorElement*>(data_.get()) +
|
return reinterpret_cast<VectorElement*>(data_.get()) +
|
||||||
index * dim_;
|
index * dim_;
|
||||||
}
|
}
|
||||||
|
case DataType::VECTOR_BINARY: {
|
||||||
|
// Binary vectors are packed bits
|
||||||
|
int bytes_per_vector = (dim_ + 7) / 8;
|
||||||
|
return reinterpret_cast<VectorElement*>(
|
||||||
|
data_.get() + index * bytes_per_vector);
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_FLOAT16:
|
||||||
|
case DataType::VECTOR_BFLOAT16: {
|
||||||
|
// Float16/BFloat16 are 2 bytes per element
|
||||||
|
return reinterpret_cast<VectorElement*>(data_.get() +
|
||||||
|
index * dim_ * 2);
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_INT8: {
|
||||||
|
// Int8 is 1 byte per element
|
||||||
|
return reinterpret_cast<VectorElement*>(data_.get() +
|
||||||
|
index * dim_);
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
// TODO(SpadeA): add other vector types
|
|
||||||
ThrowInfo(NotImplemented,
|
ThrowInfo(NotImplemented,
|
||||||
"Not implemented vector type: {}",
|
"Not implemented vector type: {}",
|
||||||
static_cast<int>(element_type_));
|
static_cast<int>(element_type_));
|
||||||
@ -180,8 +229,23 @@ class VectorArray : public milvus::VectorTrait {
|
|||||||
data, data + length_ * dim_);
|
data, data + length_ * dim_);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case DataType::VECTOR_BINARY: {
|
||||||
|
vector_field.set_binary_vector(data_.get(), size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_FLOAT16: {
|
||||||
|
vector_field.set_float16_vector(data_.get(), size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_BFLOAT16: {
|
||||||
|
vector_field.set_bfloat16_vector(data_.get(), size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_INT8: {
|
||||||
|
vector_field.set_int8_vector(data_.get(), size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
// TODO(SpadeA): add other vector types
|
|
||||||
ThrowInfo(NotImplemented,
|
ThrowInfo(NotImplemented,
|
||||||
"Not implemented vector type: {}",
|
"Not implemented vector type: {}",
|
||||||
static_cast<int>(element_type_));
|
static_cast<int>(element_type_));
|
||||||
@ -300,8 +364,23 @@ class VectorArrayView {
|
|||||||
"VectorElement must be float for VECTOR_FLOAT");
|
"VectorElement must be float for VECTOR_FLOAT");
|
||||||
return reinterpret_cast<VectorElement*>(data_) + index * dim_;
|
return reinterpret_cast<VectorElement*>(data_) + index * dim_;
|
||||||
}
|
}
|
||||||
|
case DataType::VECTOR_BINARY: {
|
||||||
|
// Binary vectors are packed bits
|
||||||
|
int bytes_per_vector = (dim_ + 7) / 8;
|
||||||
|
return reinterpret_cast<VectorElement*>(
|
||||||
|
data_ + index * bytes_per_vector);
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_FLOAT16:
|
||||||
|
case DataType::VECTOR_BFLOAT16: {
|
||||||
|
// Float16/BFloat16 are 2 bytes per element
|
||||||
|
return reinterpret_cast<VectorElement*>(data_ +
|
||||||
|
index * dim_ * 2);
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_INT8: {
|
||||||
|
// Int8 is 1 byte per element
|
||||||
|
return reinterpret_cast<VectorElement*>(data_ + index * dim_);
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
// TODO(SpadeA): add other vector types.
|
|
||||||
ThrowInfo(NotImplemented,
|
ThrowInfo(NotImplemented,
|
||||||
"Not implemented vector type: {}",
|
"Not implemented vector type: {}",
|
||||||
static_cast<int>(element_type_));
|
static_cast<int>(element_type_));
|
||||||
@ -320,8 +399,23 @@ class VectorArrayView {
|
|||||||
data, data + length_ * dim_);
|
data, data + length_ * dim_);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case DataType::VECTOR_BINARY: {
|
||||||
|
vector_array.set_binary_vector(data_, size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_FLOAT16: {
|
||||||
|
vector_array.set_float16_vector(data_, size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_BFLOAT16: {
|
||||||
|
vector_array.set_bfloat16_vector(data_, size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_INT8: {
|
||||||
|
vector_array.set_int8_vector(data_, size_);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
// TODO(SpadeA): add other vector types
|
|
||||||
ThrowInfo(NotImplemented,
|
ThrowInfo(NotImplemented,
|
||||||
"Not implemented vector type: {}",
|
"Not implemented vector type: {}",
|
||||||
static_cast<int>(element_type_));
|
static_cast<int>(element_type_));
|
||||||
|
|||||||
@ -40,12 +40,57 @@ class VectorArrayChunkTest : public ::testing::Test {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<uint16_t>
|
||||||
|
generateFloat16Vector(int64_t seed, int64_t N, int64_t dim) {
|
||||||
|
std::vector<uint16_t> result(dim * N);
|
||||||
|
std::default_random_engine gen(seed);
|
||||||
|
std::uniform_int_distribution<uint16_t> dist(0, 65535);
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < dim * N; ++i) {
|
||||||
|
result[i] = dist(gen);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<uint16_t>
|
||||||
|
generateBFloat16Vector(int64_t seed, int64_t N, int64_t dim) {
|
||||||
|
// Same as Float16 for testing purposes
|
||||||
|
return generateFloat16Vector(seed, N, dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int8_t>
|
||||||
|
generateInt8Vector(int64_t seed, int64_t N, int64_t dim) {
|
||||||
|
std::vector<int8_t> result(dim * N);
|
||||||
|
std::default_random_engine gen(seed);
|
||||||
|
std::uniform_int_distribution<int> dist(-128, 127);
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < dim * N; ++i) {
|
||||||
|
result[i] = static_cast<int8_t>(dist(gen));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<uint8_t>
|
||||||
|
generateBinaryVector(int64_t seed, int64_t N, int64_t dim) {
|
||||||
|
std::vector<uint8_t> result((dim * N + 7) / 8);
|
||||||
|
std::default_random_engine gen(seed);
|
||||||
|
std::uniform_int_distribution<int> dist(0, 255);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < result.size(); ++i) {
|
||||||
|
result[i] = static_cast<uint8_t>(dist(gen));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<arrow::ListArray>
|
std::shared_ptr<arrow::ListArray>
|
||||||
createFloatVectorListArray(const std::vector<float>& data,
|
createFloatVectorListArray(const std::vector<float>& data,
|
||||||
const std::vector<int32_t>& offsets) {
|
const std::vector<int32_t>& offsets,
|
||||||
auto float_builder = std::make_shared<arrow::FloatBuilder>();
|
int64_t dim) {
|
||||||
|
int byte_width = dim * sizeof(float);
|
||||||
|
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
auto list_builder = std::make_shared<arrow::ListBuilder>(
|
auto list_builder = std::make_shared<arrow::ListBuilder>(
|
||||||
arrow::default_memory_pool(), float_builder);
|
arrow::default_memory_pool(), value_builder);
|
||||||
|
|
||||||
arrow::Status ast;
|
arrow::Status ast;
|
||||||
for (size_t i = 0; i < offsets.size() - 1; ++i) {
|
for (size_t i = 0; i < offsets.size() - 1; ++i) {
|
||||||
@ -54,8 +99,112 @@ class VectorArrayChunkTest : public ::testing::Test {
|
|||||||
int32_t start = offsets[i];
|
int32_t start = offsets[i];
|
||||||
int32_t end = offsets[i + 1];
|
int32_t end = offsets[i + 1];
|
||||||
|
|
||||||
for (int32_t j = start; j < end; ++j) {
|
// Each vector is dim floats
|
||||||
float_builder->Append(data[j]);
|
for (int32_t j = start; j < end; j += dim) {
|
||||||
|
// Convert float vector to binary
|
||||||
|
const uint8_t* binary_data =
|
||||||
|
reinterpret_cast<const uint8_t*>(&data[j]);
|
||||||
|
ast = value_builder->Append(binary_data);
|
||||||
|
assert(ast.ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<arrow::Array> array;
|
||||||
|
ast = list_builder->Finish(&array);
|
||||||
|
assert(ast.ok());
|
||||||
|
return std::static_pointer_cast<arrow::ListArray>(array);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<arrow::ListArray>
|
||||||
|
createFloat16VectorListArray(const std::vector<uint16_t>& data,
|
||||||
|
const std::vector<int32_t>& offsets,
|
||||||
|
int64_t dim) {
|
||||||
|
int byte_width = dim * 2; // 2 bytes per float16
|
||||||
|
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
|
auto list_builder = std::make_shared<arrow::ListBuilder>(
|
||||||
|
arrow::default_memory_pool(), value_builder);
|
||||||
|
|
||||||
|
arrow::Status ast;
|
||||||
|
for (size_t i = 0; i < offsets.size() - 1; ++i) {
|
||||||
|
ast = list_builder->Append();
|
||||||
|
assert(ast.ok());
|
||||||
|
int32_t start = offsets[i];
|
||||||
|
int32_t end = offsets[i + 1];
|
||||||
|
|
||||||
|
for (int32_t j = start; j < end; j += dim) {
|
||||||
|
const uint8_t* binary_data =
|
||||||
|
reinterpret_cast<const uint8_t*>(&data[j]);
|
||||||
|
ast = value_builder->Append(binary_data);
|
||||||
|
assert(ast.ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<arrow::Array> array;
|
||||||
|
ast = list_builder->Finish(&array);
|
||||||
|
assert(ast.ok());
|
||||||
|
return std::static_pointer_cast<arrow::ListArray>(array);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<arrow::ListArray>
|
||||||
|
createBFloat16VectorListArray(const std::vector<uint16_t>& data,
|
||||||
|
const std::vector<int32_t>& offsets,
|
||||||
|
int64_t dim) {
|
||||||
|
// Same as Float16 but for bfloat16
|
||||||
|
return createFloat16VectorListArray(data, offsets, dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<arrow::ListArray>
|
||||||
|
createInt8VectorListArray(const std::vector<int8_t>& data,
|
||||||
|
const std::vector<int32_t>& offsets,
|
||||||
|
int64_t dim) {
|
||||||
|
int byte_width = dim; // 1 byte per int8
|
||||||
|
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
|
auto list_builder = std::make_shared<arrow::ListBuilder>(
|
||||||
|
arrow::default_memory_pool(), value_builder);
|
||||||
|
|
||||||
|
arrow::Status ast;
|
||||||
|
for (size_t i = 0; i < offsets.size() - 1; ++i) {
|
||||||
|
ast = list_builder->Append();
|
||||||
|
assert(ast.ok());
|
||||||
|
int32_t start = offsets[i];
|
||||||
|
int32_t end = offsets[i + 1];
|
||||||
|
|
||||||
|
for (int32_t j = start; j < end; j += dim) {
|
||||||
|
const uint8_t* binary_data =
|
||||||
|
reinterpret_cast<const uint8_t*>(&data[j]);
|
||||||
|
ast = value_builder->Append(binary_data);
|
||||||
|
assert(ast.ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<arrow::Array> array;
|
||||||
|
ast = list_builder->Finish(&array);
|
||||||
|
assert(ast.ok());
|
||||||
|
return std::static_pointer_cast<arrow::ListArray>(array);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<arrow::ListArray>
|
||||||
|
createBinaryVectorListArray(const std::vector<uint8_t>& data,
|
||||||
|
const std::vector<int32_t>& offsets,
|
||||||
|
int64_t dim) {
|
||||||
|
int byte_width = (dim + 7) / 8; // bits packed into bytes
|
||||||
|
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
|
auto list_builder = std::make_shared<arrow::ListBuilder>(
|
||||||
|
arrow::default_memory_pool(), value_builder);
|
||||||
|
|
||||||
|
arrow::Status ast;
|
||||||
|
for (size_t i = 0; i < offsets.size() - 1; ++i) {
|
||||||
|
ast = list_builder->Append();
|
||||||
|
assert(ast.ok());
|
||||||
|
int32_t start = offsets[i];
|
||||||
|
int32_t end = offsets[i + 1];
|
||||||
|
|
||||||
|
for (int32_t j = start; j < end; j += byte_width) {
|
||||||
|
ast = value_builder->Append(&data[j]);
|
||||||
|
assert(ast.ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,52 +215,6 @@ class VectorArrayChunkTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Writes 100 rows of five 128-dim float vectors through
// VectorArrayChunkWriter and checks that each row read back through View()
// reproduces the generated input value for value.
TEST_F(VectorArrayChunkTest, TestWriteFloatVectorArray) {
    const int64_t dim = 128;
    const int num_rows = 100;
    const int vectors_per_row = 5;  // Each row contains 5 vectors
    // Number of floats that make up one row.
    const int64_t floats_per_row = vectors_per_row * dim;

    // Flatten all rows into one buffer and record where each row starts.
    std::vector<float> flat_values;
    std::vector<int32_t> row_offsets{0};
    for (int r = 0; r < num_rows; ++r) {
        const auto generated = generateFloatVector(r, vectors_per_row, dim);
        flat_values.insert(
            flat_values.end(), generated.begin(), generated.end());
        row_offsets.push_back(row_offsets.back() + floats_per_row);
    }

    // Wrap the flat buffer in an Arrow ListArray and feed it to the writer.
    auto list_array = createFloatVectorListArray(flat_values, row_offsets);
    arrow::ArrayVector batches = {list_array};

    VectorArrayChunkWriter writer(dim, DataType::VECTOR_FLOAT);
    writer.write(batches);
    auto chunk = writer.finish();
    auto typed_chunk = static_cast<VectorArrayChunk*>(chunk.get());

    EXPECT_EQ(typed_chunk->RowNums(), num_rows);

    // Convert every row back to a VectorFieldProto and compare with the input.
    for (int r = 0; r < num_rows; ++r) {
        auto view = typed_chunk->View(r);
        auto proto = view.output_data();
        EXPECT_EQ(proto.dim(), dim);
        EXPECT_EQ(proto.float_vector().data_size(), floats_per_row);

        const float* expected = flat_values.data() + r * floats_per_row;
        for (int i = 0; i < floats_per_row; ++i) {
            EXPECT_FLOAT_EQ(proto.float_vector().data(i), expected[i]);
        }
    }
}
|
|
||||||
|
|
||||||
TEST_F(VectorArrayChunkTest, TestWriteMultipleBatches) {
|
TEST_F(VectorArrayChunkTest, TestWriteMultipleBatches) {
|
||||||
const int64_t dim = 64;
|
const int64_t dim = 64;
|
||||||
const int batch_size = 50;
|
const int batch_size = 50;
|
||||||
@ -137,7 +240,7 @@ TEST_F(VectorArrayChunkTest, TestWriteMultipleBatches) {
|
|||||||
|
|
||||||
all_batch_data.push_back(batch_data);
|
all_batch_data.push_back(batch_data);
|
||||||
array_vec.push_back(
|
array_vec.push_back(
|
||||||
createFloatVectorListArray(batch_data, batch_offsets));
|
createFloatVectorListArray(batch_data, batch_offsets, dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write using VectorArrayChunkWriter
|
// Write using VectorArrayChunkWriter
|
||||||
@ -191,7 +294,7 @@ TEST_F(VectorArrayChunkTest, TestWriteWithMmap) {
|
|||||||
offsets.push_back(offsets.back() + vectors_per_row * dim);
|
offsets.push_back(offsets.back() + vectors_per_row * dim);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto list_array = createFloatVectorListArray(all_data, offsets);
|
auto list_array = createFloatVectorListArray(all_data, offsets, dim);
|
||||||
arrow::ArrayVector array_vec = {list_array};
|
arrow::ArrayVector array_vec = {list_array};
|
||||||
|
|
||||||
// Write with mmap
|
// Write with mmap
|
||||||
@ -235,3 +338,189 @@ TEST_F(VectorArrayChunkTest, TestEmptyVectorArray) {
|
|||||||
|
|
||||||
EXPECT_EQ(vector_array_chunk->RowNums(), 0);
|
EXPECT_EQ(vector_array_chunk->RowNums(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct VectorArrayTestParam {
|
||||||
|
DataType data_type;
|
||||||
|
int64_t dim;
|
||||||
|
int num_rows;
|
||||||
|
int vectors_per_row;
|
||||||
|
std::string test_name;
|
||||||
|
};
|
||||||
|
|
||||||
|
class VectorArrayChunkParameterizedTest
|
||||||
|
: public VectorArrayChunkTest,
|
||||||
|
public ::testing::WithParamInterface<VectorArrayTestParam> {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::shared_ptr<arrow::ListArray>
|
||||||
|
createVectorListArray(const std::vector<T>& data,
|
||||||
|
const std::vector<int32_t>& offsets,
|
||||||
|
int64_t dim,
|
||||||
|
DataType dtype) {
|
||||||
|
int byte_width;
|
||||||
|
switch (dtype) {
|
||||||
|
case DataType::VECTOR_FLOAT:
|
||||||
|
byte_width = dim * sizeof(float);
|
||||||
|
break;
|
||||||
|
case DataType::VECTOR_FLOAT16:
|
||||||
|
case DataType::VECTOR_BFLOAT16:
|
||||||
|
byte_width = dim * 2;
|
||||||
|
break;
|
||||||
|
case DataType::VECTOR_INT8:
|
||||||
|
byte_width = dim;
|
||||||
|
break;
|
||||||
|
case DataType::VECTOR_BINARY:
|
||||||
|
byte_width = (dim + 7) / 8;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Unsupported data type");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto value_builder = std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
|
auto list_builder = std::make_shared<arrow::ListBuilder>(
|
||||||
|
arrow::default_memory_pool(), value_builder);
|
||||||
|
|
||||||
|
arrow::Status ast;
|
||||||
|
int element_size = (dtype == DataType::VECTOR_BINARY) ? byte_width : dim;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < offsets.size() - 1; ++i) {
|
||||||
|
ast = list_builder->Append();
|
||||||
|
assert(ast.ok());
|
||||||
|
int32_t start = offsets[i];
|
||||||
|
int32_t end = offsets[i + 1];
|
||||||
|
|
||||||
|
for (int32_t j = start; j < end; j += element_size) {
|
||||||
|
const uint8_t* binary_data =
|
||||||
|
reinterpret_cast<const uint8_t*>(&data[j]);
|
||||||
|
ast = value_builder->Append(binary_data);
|
||||||
|
assert(ast.ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<arrow::Array> array;
|
||||||
|
ast = list_builder->Finish(&array);
|
||||||
|
assert(ast.ok());
|
||||||
|
return std::static_pointer_cast<arrow::ListArray>(array);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pushes one vector element type through VectorArrayChunkWriter. Rows are
// generated as raw bytes, wrapped in an Arrow list array and written; the
// resulting chunk must report the right row count, and every row's proto must
// carry the expected dim and a non-empty payload in the field that matches
// the element type.
TEST_P(VectorArrayChunkParameterizedTest, TestWriteVectorArray) {
    const auto& param = GetParam();

    // Copies `byte_count` raw bytes starting at `src` into a new byte buffer.
    const auto to_bytes = [](const void* src, size_t byte_count) {
        std::vector<uint8_t> bytes(byte_count);
        memcpy(bytes.data(), src, byte_count);
        return bytes;
    };

    std::vector<uint8_t> flat_bytes;
    std::vector<int32_t> byte_offsets{0};

    for (int row = 0; row < param.num_rows; ++row) {
        std::vector<uint8_t> row_bytes;
        // Amount by which the row offset advances, computed per element type.
        int row_byte_count = 0;

        switch (param.data_type) {
            case DataType::VECTOR_FLOAT: {
                const auto values = generateFloatVector(
                    row, param.vectors_per_row, param.dim);
                row_bytes =
                    to_bytes(values.data(), values.size() * sizeof(float));
                row_byte_count =
                    param.vectors_per_row * param.dim * sizeof(float);
                break;
            }
            case DataType::VECTOR_FLOAT16:
            case DataType::VECTOR_BFLOAT16: {
                const auto values = generateFloat16Vector(
                    row, param.vectors_per_row, param.dim);
                row_bytes =
                    to_bytes(values.data(), values.size() * sizeof(uint16_t));
                row_byte_count = param.vectors_per_row * param.dim * 2;
                break;
            }
            case DataType::VECTOR_INT8: {
                const auto values =
                    generateInt8Vector(row, param.vectors_per_row, param.dim);
                row_bytes = to_bytes(values.data(), values.size());
                row_byte_count = param.vectors_per_row * param.dim;
                break;
            }
            case DataType::VECTOR_BINARY: {
                row_bytes = generateBinaryVector(
                    row, param.vectors_per_row, param.dim);
                row_byte_count =
                    param.vectors_per_row * ((param.dim + 7) / 8);
                break;
            }
            default:
                FAIL() << "Unsupported data type";
        }

        flat_bytes.insert(flat_bytes.end(), row_bytes.begin(), row_bytes.end());
        byte_offsets.push_back(byte_offsets.back() + row_byte_count);
    }

    // Create Arrow ListArray
    auto list_array = createVectorListArray(
        flat_bytes, byte_offsets, param.dim, param.data_type);
    arrow::ArrayVector batches = {list_array};

    // Run the batch through the writer and take the finished chunk.
    VectorArrayChunkWriter writer(param.dim, param.data_type);
    writer.write(batches);
    auto chunk = writer.finish();
    auto typed_chunk = static_cast<VectorArrayChunk*>(chunk.get());

    EXPECT_EQ(typed_chunk->RowNums(), param.num_rows);

    // Basic verification: View() must not crash and must return a proto whose
    // type-specific field is populated.
    for (int row = 0; row < param.num_rows; ++row) {
        auto view = typed_chunk->View(row);
        auto proto = view.output_data();
        EXPECT_EQ(proto.dim(), param.dim);

        switch (param.data_type) {
            case DataType::VECTOR_FLOAT:
                EXPECT_GT(proto.float_vector().data_size(), 0);
                break;
            case DataType::VECTOR_FLOAT16:
                EXPECT_GT(proto.float16_vector().size(), 0);
                break;
            case DataType::VECTOR_BFLOAT16:
                EXPECT_GT(proto.bfloat16_vector().size(), 0);
                break;
            case DataType::VECTOR_INT8:
                EXPECT_GT(proto.int8_vector().size(), 0);
                break;
            case DataType::VECTOR_BINARY:
                EXPECT_GT(proto.binary_vector().size(), 0);
                break;
            default:
                break;
        }
    }
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
VectorTypes,
|
||||||
|
VectorArrayChunkParameterizedTest,
|
||||||
|
::testing::Values(
|
||||||
|
VectorArrayTestParam{
|
||||||
|
DataType::VECTOR_FLOAT, 128, 100, 5, "FloatVector"},
|
||||||
|
VectorArrayTestParam{
|
||||||
|
DataType::VECTOR_FLOAT16, 128, 100, 3, "Float16Vector"},
|
||||||
|
VectorArrayTestParam{
|
||||||
|
DataType::VECTOR_BFLOAT16, 64, 50, 2, "BFloat16Vector"},
|
||||||
|
VectorArrayTestParam{DataType::VECTOR_INT8, 256, 80, 4, "Int8Vector"},
|
||||||
|
VectorArrayTestParam{
|
||||||
|
DataType::VECTOR_BINARY, 512, 60, 3, "BinaryVector"}),
|
||||||
|
[](const testing::TestParamInfo<VectorArrayTestParam>& info) {
|
||||||
|
return info.param.test_name;
|
||||||
|
});
|
||||||
|
|||||||
@ -159,8 +159,12 @@ class TestVectorArrayStorageV2 : public testing::Test {
|
|||||||
|
|
||||||
// Create appropriate value builder based on element type
|
// Create appropriate value builder based on element type
|
||||||
std::shared_ptr<arrow::ArrayBuilder> value_builder;
|
std::shared_ptr<arrow::ArrayBuilder> value_builder;
|
||||||
|
int byte_width = 0;
|
||||||
if (element_type == DataType::VECTOR_FLOAT) {
|
if (element_type == DataType::VECTOR_FLOAT) {
|
||||||
value_builder = std::make_shared<arrow::FloatBuilder>();
|
byte_width = DIM * sizeof(float);
|
||||||
|
value_builder =
|
||||||
|
std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
} else {
|
} else {
|
||||||
FAIL() << "Unsupported element type for VECTOR_ARRAY "
|
FAIL() << "Unsupported element type for VECTOR_ARRAY "
|
||||||
"in test";
|
"in test";
|
||||||
@ -176,11 +180,13 @@ class TestVectorArrayStorageV2 : public testing::Test {
|
|||||||
|
|
||||||
// Generate 3 vectors for this row
|
// Generate 3 vectors for this row
|
||||||
auto data = generate_float_vector(3, DIM);
|
auto data = generate_float_vector(3, DIM);
|
||||||
auto float_builder =
|
auto binary_builder = std::static_pointer_cast<
|
||||||
std::static_pointer_cast<arrow::FloatBuilder>(
|
arrow::FixedSizeBinaryBuilder>(value_builder);
|
||||||
value_builder);
|
// Append each vector as a fixed-size binary value
|
||||||
for (const auto& value : data) {
|
for (int vec_idx = 0; vec_idx < 3; vec_idx++) {
|
||||||
status = float_builder->Append(value);
|
status = binary_builder->Append(
|
||||||
|
reinterpret_cast<const uint8_t*>(
|
||||||
|
data.data() + vec_idx * DIM));
|
||||||
EXPECT_TRUE(status.ok());
|
EXPECT_TRUE(status.ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -288,7 +294,7 @@ TEST_F(TestVectorArrayStorageV2, BuildEmbListHNSWIndex) {
|
|||||||
milvus::index::CreateIndexInfo create_index_info;
|
milvus::index::CreateIndexInfo create_index_info;
|
||||||
create_index_info.field_type = DataType::VECTOR_ARRAY;
|
create_index_info.field_type = DataType::VECTOR_ARRAY;
|
||||||
create_index_info.metric_type = knowhere::metric::MAX_SIM;
|
create_index_info.metric_type = knowhere::metric::MAX_SIM;
|
||||||
create_index_info.index_type = knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
|
create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW;
|
||||||
create_index_info.index_engine_version =
|
create_index_info.index_engine_version =
|
||||||
knowhere::Version::GetCurrentVersion().VersionNumber();
|
knowhere::Version::GetCurrentVersion().VersionNumber();
|
||||||
|
|
||||||
@ -299,8 +305,7 @@ TEST_F(TestVectorArrayStorageV2, BuildEmbListHNSWIndex) {
|
|||||||
|
|
||||||
// Build index with storage v2 configuration
|
// Build index with storage v2 configuration
|
||||||
Config config;
|
Config config;
|
||||||
config[milvus::index::INDEX_TYPE] =
|
config[milvus::index::INDEX_TYPE] = knowhere::IndexEnum::INDEX_HNSW;
|
||||||
knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
|
|
||||||
config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type;
|
config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type;
|
||||||
config[knowhere::indexparam::M] = "16";
|
config[knowhere::indexparam::M] = "16";
|
||||||
config[knowhere::indexparam::EF] = "10";
|
config[knowhere::indexparam::EF] = "10";
|
||||||
@ -330,11 +335,12 @@ TEST_F(TestVectorArrayStorageV2, BuildEmbListHNSWIndex) {
|
|||||||
std::vector<float> query_vec = generate_float_vector(vec_num, DIM);
|
std::vector<float> query_vec = generate_float_vector(vec_num, DIM);
|
||||||
auto query_dataset =
|
auto query_dataset =
|
||||||
knowhere::GenDataSet(vec_num, DIM, query_vec.data());
|
knowhere::GenDataSet(vec_num, DIM, query_vec.data());
|
||||||
std::vector<size_t> query_vec_lims;
|
std::vector<size_t> query_vec_offsets;
|
||||||
query_vec_lims.push_back(0);
|
query_vec_offsets.push_back(0);
|
||||||
query_vec_lims.push_back(3);
|
query_vec_offsets.push_back(3);
|
||||||
query_vec_lims.push_back(10);
|
query_vec_offsets.push_back(10);
|
||||||
query_dataset->SetLims(query_vec_lims.data());
|
query_dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
|
||||||
|
const_cast<const size_t*>(query_vec_offsets.data()));
|
||||||
|
|
||||||
auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
|
auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
|
||||||
milvus::SearchInfo searchInfo;
|
milvus::SearchInfo searchInfo;
|
||||||
|
|||||||
@ -32,19 +32,35 @@ namespace milvus {
|
|||||||
using elem_type = std::conditional_t< \
|
using elem_type = std::conditional_t< \
|
||||||
std::is_same_v<TraitType, milvus::EmbListFloatVector>, \
|
std::is_same_v<TraitType, milvus::EmbListFloatVector>, \
|
||||||
milvus::EmbListFloatVector::embedded_type, \
|
milvus::EmbListFloatVector::embedded_type, \
|
||||||
|
std::conditional_t< \
|
||||||
|
std::is_same_v<TraitType, milvus::EmbListBinaryVector>, \
|
||||||
|
milvus::EmbListBinaryVector::embedded_type, \
|
||||||
|
std::conditional_t< \
|
||||||
|
std::is_same_v<TraitType, milvus::EmbListFloat16Vector>, \
|
||||||
|
milvus::EmbListFloat16Vector::embedded_type, \
|
||||||
|
std::conditional_t< \
|
||||||
|
std::is_same_v<TraitType, milvus::EmbListBFloat16Vector>, \
|
||||||
|
milvus::EmbListBFloat16Vector::embedded_type, \
|
||||||
|
std::conditional_t< \
|
||||||
|
std::is_same_v<TraitType, milvus::EmbListInt8Vector>, \
|
||||||
|
milvus::EmbListInt8Vector::embedded_type, \
|
||||||
std::conditional_t< \
|
std::conditional_t< \
|
||||||
std::is_same_v<TraitType, milvus::FloatVector>, \
|
std::is_same_v<TraitType, milvus::FloatVector>, \
|
||||||
milvus::FloatVector::embedded_type, \
|
milvus::FloatVector::embedded_type, \
|
||||||
std::conditional_t< \
|
std::conditional_t< \
|
||||||
std::is_same_v<TraitType, milvus::Float16Vector>, \
|
std::is_same_v<TraitType, \
|
||||||
|
milvus::Float16Vector>, \
|
||||||
milvus::Float16Vector::embedded_type, \
|
milvus::Float16Vector::embedded_type, \
|
||||||
std::conditional_t< \
|
std::conditional_t< \
|
||||||
std::is_same_v<TraitType, milvus::BFloat16Vector>, \
|
std::is_same_v<TraitType, \
|
||||||
|
milvus::BFloat16Vector>, \
|
||||||
milvus::BFloat16Vector::embedded_type, \
|
milvus::BFloat16Vector::embedded_type, \
|
||||||
std::conditional_t< \
|
std::conditional_t< \
|
||||||
std::is_same_v<TraitType, milvus::Int8Vector>, \
|
std::is_same_v<TraitType, \
|
||||||
|
milvus::Int8Vector>, \
|
||||||
milvus::Int8Vector::embedded_type, \
|
milvus::Int8Vector::embedded_type, \
|
||||||
milvus::BinaryVector::embedded_type>>>>>;
|
milvus::BinaryVector:: \
|
||||||
|
embedded_type>>>>>>>>>;
|
||||||
|
|
||||||
#define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \
|
#define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \
|
||||||
auto schema_data_type = \
|
auto schema_data_type = \
|
||||||
@ -164,6 +180,82 @@ class EmbListFloatVector : public VectorTrait {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class EmbListBinaryVector : public VectorTrait {
|
||||||
|
public:
|
||||||
|
using embedded_type = uint8_t;
|
||||||
|
static constexpr int32_t dim_factor = 8;
|
||||||
|
static constexpr auto data_type = DataType::VECTOR_ARRAY;
|
||||||
|
static constexpr auto c_data_type = CDataType::VectorArray;
|
||||||
|
static constexpr auto schema_data_type =
|
||||||
|
proto::schema::DataType::ArrayOfVector;
|
||||||
|
static constexpr auto vector_type =
|
||||||
|
proto::plan::VectorType::EmbListBinaryVector;
|
||||||
|
static constexpr auto placeholder_type =
|
||||||
|
proto::common::PlaceholderType::EmbListBinaryVector;
|
||||||
|
|
||||||
|
static constexpr bool
|
||||||
|
is_embedding_list() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class EmbListFloat16Vector : public VectorTrait {
|
||||||
|
public:
|
||||||
|
using embedded_type = float16;
|
||||||
|
static constexpr int32_t dim_factor = 1;
|
||||||
|
static constexpr auto data_type = DataType::VECTOR_ARRAY;
|
||||||
|
static constexpr auto c_data_type = CDataType::VectorArray;
|
||||||
|
static constexpr auto schema_data_type =
|
||||||
|
proto::schema::DataType::ArrayOfVector;
|
||||||
|
static constexpr auto vector_type =
|
||||||
|
proto::plan::VectorType::EmbListFloat16Vector;
|
||||||
|
static constexpr auto placeholder_type =
|
||||||
|
proto::common::PlaceholderType::EmbListFloat16Vector;
|
||||||
|
|
||||||
|
static constexpr bool
|
||||||
|
is_embedding_list() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class EmbListBFloat16Vector : public VectorTrait {
|
||||||
|
public:
|
||||||
|
using embedded_type = bfloat16;
|
||||||
|
static constexpr int32_t dim_factor = 1;
|
||||||
|
static constexpr auto data_type = DataType::VECTOR_ARRAY;
|
||||||
|
static constexpr auto c_data_type = CDataType::VectorArray;
|
||||||
|
static constexpr auto schema_data_type =
|
||||||
|
proto::schema::DataType::ArrayOfVector;
|
||||||
|
static constexpr auto vector_type =
|
||||||
|
proto::plan::VectorType::EmbListBFloat16Vector;
|
||||||
|
static constexpr auto placeholder_type =
|
||||||
|
proto::common::PlaceholderType::EmbListBFloat16Vector;
|
||||||
|
|
||||||
|
static constexpr bool
|
||||||
|
is_embedding_list() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class EmbListInt8Vector : public VectorTrait {
|
||||||
|
public:
|
||||||
|
using embedded_type = int8;
|
||||||
|
static constexpr int32_t dim_factor = 1;
|
||||||
|
static constexpr auto data_type = DataType::VECTOR_ARRAY;
|
||||||
|
static constexpr auto c_data_type = CDataType::VectorArray;
|
||||||
|
static constexpr auto schema_data_type =
|
||||||
|
proto::schema::DataType::ArrayOfVector;
|
||||||
|
static constexpr auto vector_type =
|
||||||
|
proto::plan::VectorType::EmbListInt8Vector;
|
||||||
|
static constexpr auto placeholder_type =
|
||||||
|
proto::common::PlaceholderType::EmbListInt8Vector;
|
||||||
|
|
||||||
|
static constexpr bool
|
||||||
|
is_embedding_list() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct FundamentalTag {};
|
struct FundamentalTag {};
|
||||||
struct StringTag {};
|
struct StringTag {};
|
||||||
|
|
||||||
|
|||||||
@ -76,7 +76,7 @@ PhyVectorSearchNode::GetOutput() {
|
|||||||
|
|
||||||
auto& ph = placeholder_group_->at(0);
|
auto& ph = placeholder_group_->at(0);
|
||||||
auto src_data = ph.get_blob();
|
auto src_data = ph.get_blob();
|
||||||
auto src_lims = ph.get_lims();
|
auto src_offsets = ph.get_offsets();
|
||||||
auto num_queries = ph.num_of_queries_;
|
auto num_queries = ph.num_of_queries_;
|
||||||
milvus::SearchResult search_result;
|
milvus::SearchResult search_result;
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ PhyVectorSearchNode::GetOutput() {
|
|||||||
auto op_context = query_context_->get_op_context();
|
auto op_context = query_context_->get_op_context();
|
||||||
segment_->vector_search(search_info_,
|
segment_->vector_search(search_info_,
|
||||||
src_data,
|
src_data,
|
||||||
src_lims,
|
src_offsets,
|
||||||
num_queries,
|
num_queries,
|
||||||
query_timestamp_,
|
query_timestamp_,
|
||||||
final_view,
|
final_view,
|
||||||
|
|||||||
@ -234,9 +234,42 @@ IndexFactory::VecIndexLoadResource(
|
|||||||
num_rows,
|
num_rows,
|
||||||
dim,
|
dim,
|
||||||
config);
|
config);
|
||||||
has_raw_data =
|
break;
|
||||||
knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(
|
case milvus::DataType::VECTOR_FLOAT16:
|
||||||
index_type, index_version, config);
|
resource = knowhere::IndexStaticFaced<knowhere::fp16>::
|
||||||
|
EstimateLoadResource(index_type,
|
||||||
|
index_version,
|
||||||
|
index_size_in_bytes,
|
||||||
|
num_rows,
|
||||||
|
dim,
|
||||||
|
config);
|
||||||
|
break;
|
||||||
|
case milvus::DataType::VECTOR_BFLOAT16:
|
||||||
|
resource = knowhere::IndexStaticFaced<knowhere::bf16>::
|
||||||
|
EstimateLoadResource(index_type,
|
||||||
|
index_version,
|
||||||
|
index_size_in_bytes,
|
||||||
|
num_rows,
|
||||||
|
dim,
|
||||||
|
config);
|
||||||
|
break;
|
||||||
|
case milvus::DataType::VECTOR_BINARY:
|
||||||
|
resource = knowhere::IndexStaticFaced<knowhere::bin1>::
|
||||||
|
EstimateLoadResource(index_type,
|
||||||
|
index_version,
|
||||||
|
index_size_in_bytes,
|
||||||
|
num_rows,
|
||||||
|
dim,
|
||||||
|
config);
|
||||||
|
break;
|
||||||
|
case milvus::DataType::VECTOR_INT8:
|
||||||
|
resource = knowhere::IndexStaticFaced<knowhere::int8>::
|
||||||
|
EstimateLoadResource(index_type,
|
||||||
|
index_version,
|
||||||
|
index_size_in_bytes,
|
||||||
|
num_rows,
|
||||||
|
dim,
|
||||||
|
config);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@ -247,6 +280,9 @@ IndexFactory::VecIndexLoadResource(
|
|||||||
element_type);
|
element_type);
|
||||||
return LoadResourceRequest{0, 0, 0, 0, true};
|
return LoadResourceRequest{0, 0, 0, 0, true};
|
||||||
}
|
}
|
||||||
|
// For VectorArray, has_raw_data is always false as get_vector of index does not provide offsets which
|
||||||
|
// is required for reconstructing the raw data
|
||||||
|
has_raw_data = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -641,6 +677,42 @@ IndexFactory::CreateVectorIndex(
|
|||||||
version,
|
version,
|
||||||
use_knowhere_build_pool,
|
use_knowhere_build_pool,
|
||||||
file_manager_context);
|
file_manager_context);
|
||||||
|
case DataType::VECTOR_FLOAT16: {
|
||||||
|
return std::make_unique<VectorMemIndex<float16>>(
|
||||||
|
element_type,
|
||||||
|
index_type,
|
||||||
|
metric_type,
|
||||||
|
version,
|
||||||
|
use_knowhere_build_pool,
|
||||||
|
file_manager_context);
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_BFLOAT16: {
|
||||||
|
return std::make_unique<VectorMemIndex<bfloat16>>(
|
||||||
|
element_type,
|
||||||
|
index_type,
|
||||||
|
metric_type,
|
||||||
|
version,
|
||||||
|
use_knowhere_build_pool,
|
||||||
|
file_manager_context);
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_BINARY: {
|
||||||
|
return std::make_unique<VectorMemIndex<bin1>>(
|
||||||
|
element_type,
|
||||||
|
index_type,
|
||||||
|
metric_type,
|
||||||
|
version,
|
||||||
|
use_knowhere_build_pool,
|
||||||
|
file_manager_context);
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_INT8: {
|
||||||
|
return std::make_unique<VectorMemIndex<int8>>(
|
||||||
|
element_type,
|
||||||
|
index_type,
|
||||||
|
metric_type,
|
||||||
|
version,
|
||||||
|
use_knowhere_build_pool,
|
||||||
|
file_manager_context);
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
ThrowInfo(NotImplemented,
|
ThrowInfo(NotImplemented,
|
||||||
fmt::format("not implemented data type to "
|
fmt::format("not implemented data type to "
|
||||||
|
|||||||
@ -234,21 +234,4 @@ void inline SetBitsetGrowing(void* bitset,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline size_t
|
|
||||||
vector_element_size(const DataType data_type) {
|
|
||||||
switch (data_type) {
|
|
||||||
case DataType::VECTOR_FLOAT:
|
|
||||||
return sizeof(float);
|
|
||||||
case DataType::VECTOR_FLOAT16:
|
|
||||||
return sizeof(float16);
|
|
||||||
case DataType::VECTOR_BFLOAT16:
|
|
||||||
return sizeof(bfloat16);
|
|
||||||
case DataType::VECTOR_INT8:
|
|
||||||
return sizeof(int8);
|
|
||||||
default:
|
|
||||||
ThrowInfo(UnexpectedError,
|
|
||||||
fmt::format("invalid data type: {}", data_type));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace milvus::index
|
} // namespace milvus::index
|
||||||
|
|||||||
@ -313,11 +313,6 @@ VectorMemIndex<T>::BuildWithDataset(const DatasetPtr& dataset,
|
|||||||
SetDim(index_.Dim());
|
SetDim(index_.Dim());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool
|
|
||||||
is_embedding_list_index(const IndexType& index_type) {
|
|
||||||
return index_type == knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void
|
void
|
||||||
VectorMemIndex<T>::Build(const Config& config) {
|
VectorMemIndex<T>::Build(const Config& config) {
|
||||||
@ -344,30 +339,19 @@ VectorMemIndex<T>::Build(const Config& config) {
|
|||||||
total_size += data->Size();
|
total_size += data->Size();
|
||||||
total_num_rows += data->get_num_rows();
|
total_num_rows += data->get_num_rows();
|
||||||
|
|
||||||
// todo(SapdeA): now, vector arrays (embedding list) are serialized
|
|
||||||
// to parquet by using binary format which does not provide dim
|
|
||||||
// information so we use this temporary solution.
|
|
||||||
if (is_embedding_list_index(index_type_)) {
|
|
||||||
AssertInfo(elem_type_ != DataType::NONE,
|
|
||||||
"embedding list index must have elem_type");
|
|
||||||
dim = config[DIM_KEY].get<int64_t>();
|
|
||||||
} else {
|
|
||||||
AssertInfo(dim == 0 || dim == data->get_dim(),
|
AssertInfo(dim == 0 || dim == data->get_dim(),
|
||||||
"inconsistent dim value between field datas!");
|
"inconsistent dim value between field datas!");
|
||||||
|
|
||||||
dim = data->get_dim();
|
dim = data->get_dim();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[total_size]);
|
auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[total_size]);
|
||||||
|
|
||||||
size_t lim_offset = 0;
|
size_t lim_offset = 0;
|
||||||
std::vector<size_t> lims;
|
std::vector<size_t> offsets;
|
||||||
lims.reserve(total_num_rows + 1);
|
|
||||||
lims.push_back(lim_offset);
|
|
||||||
|
|
||||||
int64_t offset = 0;
|
int64_t offset = 0;
|
||||||
if (!is_embedding_list_index(index_type_)) {
|
// For embedding list index, elem_type_ is not NONE
|
||||||
|
if (elem_type_ == DataType::NONE) {
|
||||||
// TODO: avoid copying
|
// TODO: avoid copying
|
||||||
for (auto data : field_datas) {
|
for (auto data : field_datas) {
|
||||||
std::memcpy(buf.get() + offset, data->Data(), data->Size());
|
std::memcpy(buf.get() + offset, data->Data(), data->Size());
|
||||||
@ -375,7 +359,9 @@ VectorMemIndex<T>::Build(const Config& config) {
|
|||||||
data.reset();
|
data.reset();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
auto elem_size = vector_element_size(elem_type_);
|
offsets.reserve(total_num_rows + 1);
|
||||||
|
offsets.push_back(lim_offset);
|
||||||
|
auto bytes_per_vec = vector_bytes_per_element(elem_type_, dim);
|
||||||
for (auto data : field_datas) {
|
for (auto data : field_datas) {
|
||||||
auto vec_array_data =
|
auto vec_array_data =
|
||||||
dynamic_cast<FieldData<VectorArray>*>(data.get());
|
dynamic_cast<FieldData<VectorArray>*>(data.get());
|
||||||
@ -385,16 +371,16 @@ VectorMemIndex<T>::Build(const Config& config) {
|
|||||||
auto rows = vec_array_data->get_num_rows();
|
auto rows = vec_array_data->get_num_rows();
|
||||||
for (auto i = 0; i < rows; ++i) {
|
for (auto i = 0; i < rows; ++i) {
|
||||||
auto size = vec_array_data->DataSize(i);
|
auto size = vec_array_data->DataSize(i);
|
||||||
assert(size % (dim * elem_size) == 0);
|
assert(size % bytes_per_vec == 0);
|
||||||
assert(dim * elem_size != 0);
|
assert(bytes_per_vec != 0);
|
||||||
|
|
||||||
auto vec_array = vec_array_data->value_at(i);
|
auto vec_array = vec_array_data->value_at(i);
|
||||||
|
|
||||||
std::memcpy(buf.get() + offset, vec_array->data(), size);
|
std::memcpy(buf.get() + offset, vec_array->data(), size);
|
||||||
offset += size;
|
offset += size;
|
||||||
|
|
||||||
lim_offset += size / (dim * elem_size);
|
lim_offset += size / bytes_per_vec;
|
||||||
lims.push_back(lim_offset);
|
offsets.push_back(lim_offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(data->Size() == offset);
|
assert(data->Size() == offset);
|
||||||
@ -411,8 +397,9 @@ VectorMemIndex<T>::Build(const Config& config) {
|
|||||||
if (!scalar_info.empty()) {
|
if (!scalar_info.empty()) {
|
||||||
dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info));
|
dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info));
|
||||||
}
|
}
|
||||||
if (!lims.empty()) {
|
if (!offsets.empty()) {
|
||||||
dataset->SetLims(lims.data());
|
dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
|
||||||
|
const_cast<const size_t*>(offsets.data()));
|
||||||
}
|
}
|
||||||
BuildWithDataset(dataset, build_config);
|
BuildWithDataset(dataset, build_config);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -282,12 +282,12 @@ class ChunkedColumnBase : public ChunkedColumnInterface {
|
|||||||
"VectorArrayViews only supported for ChunkedVectorArrayColumn");
|
"VectorArrayViews only supported for ChunkedVectorArrayColumn");
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual PinWrapper<const size_t*>
|
PinWrapper<const size_t*>
|
||||||
VectorArrayLims(milvus::OpContext* op_ctx,
|
VectorArrayOffsets(milvus::OpContext* op_ctx,
|
||||||
int64_t chunk_id) const override {
|
int64_t chunk_id) const override {
|
||||||
ThrowInfo(
|
ThrowInfo(
|
||||||
ErrorCode::Unsupported,
|
ErrorCode::Unsupported,
|
||||||
"VectorArrayLims only supported for ChunkedVectorArrayColumn");
|
"VectorArrayOffsets only supported for ChunkedVectorArrayColumn");
|
||||||
}
|
}
|
||||||
|
|
||||||
PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>
|
PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>
|
||||||
@ -691,13 +691,13 @@ class ChunkedVectorArrayColumn : public ChunkedColumnBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
PinWrapper<const size_t*>
|
PinWrapper<const size_t*>
|
||||||
VectorArrayLims(milvus::OpContext* op_ctx,
|
VectorArrayOffsets(milvus::OpContext* op_ctx,
|
||||||
int64_t chunk_id) const override {
|
int64_t chunk_id) const override {
|
||||||
auto ca = SemiInlineGet(
|
auto ca = SemiInlineGet(
|
||||||
slot_->PinCells(op_ctx, {static_cast<cid_t>(chunk_id)}));
|
slot_->PinCells(op_ctx, {static_cast<cid_t>(chunk_id)}));
|
||||||
auto chunk = ca->get_cell_of(chunk_id);
|
auto chunk = ca->get_cell_of(chunk_id);
|
||||||
return PinWrapper<const size_t*>(
|
return PinWrapper<const size_t*>(
|
||||||
ca, static_cast<VectorArrayChunk*>(chunk)->Lims());
|
ca, static_cast<VectorArrayChunk*>(chunk)->Offsets());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -328,17 +328,18 @@ class ProxyChunkColumn : public ChunkedColumnInterface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
PinWrapper<const size_t*>
|
PinWrapper<const size_t*>
|
||||||
VectorArrayLims(milvus::OpContext* op_ctx,
|
VectorArrayOffsets(milvus::OpContext* op_ctx,
|
||||||
int64_t chunk_id) const override {
|
int64_t chunk_id) const override {
|
||||||
if (!IsChunkedVectorArrayColumnDataType(data_type_)) {
|
if (!IsChunkedVectorArrayColumnDataType(data_type_)) {
|
||||||
ThrowInfo(ErrorCode::Unsupported,
|
ThrowInfo(ErrorCode::Unsupported,
|
||||||
"VectorArrayLims only supported for "
|
"VectorArrayOffsets only supported for "
|
||||||
"ChunkedVectorArrayColumn");
|
"ChunkedVectorArrayColumn");
|
||||||
}
|
}
|
||||||
auto chunk_wrapper = group_->GetGroupChunk(op_ctx, chunk_id);
|
auto chunk_wrapper = group_->GetGroupChunk(op_ctx, chunk_id);
|
||||||
auto chunk = chunk_wrapper.get()->GetChunk(field_id_);
|
auto chunk = chunk_wrapper.get()->GetChunk(field_id_);
|
||||||
return PinWrapper<const size_t*>(
|
return PinWrapper<const size_t*>(
|
||||||
chunk_wrapper, static_cast<VectorArrayChunk*>(chunk.get())->Lims());
|
chunk_wrapper,
|
||||||
|
static_cast<VectorArrayChunk*>(chunk.get())->Offsets());
|
||||||
}
|
}
|
||||||
|
|
||||||
PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>
|
PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>
|
||||||
|
|||||||
@ -92,7 +92,7 @@ class ChunkedColumnInterface {
|
|||||||
std::optional<std::pair<int64_t, int64_t>> offset_len) const = 0;
|
std::optional<std::pair<int64_t, int64_t>> offset_len) const = 0;
|
||||||
|
|
||||||
virtual PinWrapper<const size_t*>
|
virtual PinWrapper<const size_t*>
|
||||||
VectorArrayLims(milvus::OpContext* op_ctx, int64_t chunk_id) const = 0;
|
VectorArrayOffsets(milvus::OpContext* op_ctx, int64_t chunk_id) const = 0;
|
||||||
|
|
||||||
virtual PinWrapper<
|
virtual PinWrapper<
|
||||||
std::pair<std::vector<std::string_view>, FixedVector<bool>>>
|
std::pair<std::vector<std::string_view>, FixedVector<bool>>>
|
||||||
|
|||||||
@ -34,8 +34,23 @@ bool
|
|||||||
check_data_type(const FieldMeta& field_meta,
|
check_data_type(const FieldMeta& field_meta,
|
||||||
const milvus::proto::common::PlaceholderType type) {
|
const milvus::proto::common::PlaceholderType type) {
|
||||||
if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
|
if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
|
||||||
|
if (field_meta.get_element_type() == DataType::VECTOR_FLOAT) {
|
||||||
return type ==
|
return type ==
|
||||||
milvus::proto::common::PlaceholderType::EmbListFloatVector;
|
milvus::proto::common::PlaceholderType::EmbListFloatVector;
|
||||||
|
} else if (field_meta.get_element_type() == DataType::VECTOR_FLOAT16) {
|
||||||
|
return type ==
|
||||||
|
milvus::proto::common::PlaceholderType::EmbListFloat16Vector;
|
||||||
|
} else if (field_meta.get_element_type() == DataType::VECTOR_BFLOAT16) {
|
||||||
|
return type == milvus::proto::common::PlaceholderType::
|
||||||
|
EmbListBFloat16Vector;
|
||||||
|
} else if (field_meta.get_element_type() == DataType::VECTOR_BINARY) {
|
||||||
|
return type ==
|
||||||
|
milvus::proto::common::PlaceholderType::EmbListBinaryVector;
|
||||||
|
} else if (field_meta.get_element_type() == DataType::VECTOR_INT8) {
|
||||||
|
return type ==
|
||||||
|
milvus::proto::common::PlaceholderType::EmbListInt8Vector;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
return static_cast<int>(field_meta.get_data_type()) ==
|
return static_cast<int>(field_meta.get_data_type()) ==
|
||||||
static_cast<int>(type);
|
static_cast<int>(type);
|
||||||
@ -96,25 +111,25 @@ ParsePlaceholderGroup(const Plan* plan,
|
|||||||
// If the vector is embedding list, line contains multiple vectors.
|
// If the vector is embedding list, line contains multiple vectors.
|
||||||
// And we should record the offsets so that we can identify each
|
// And we should record the offsets so that we can identify each
|
||||||
// embedding list in a flattened vectors.
|
// embedding list in a flattened vectors.
|
||||||
auto& lims = element.lims_;
|
auto& offsets = element.offsets_;
|
||||||
lims.reserve(element.num_of_queries_ + 1);
|
offsets.reserve(element.num_of_queries_ + 1);
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
lims.push_back(offset);
|
offsets.push_back(offset);
|
||||||
|
|
||||||
auto elem_size = milvus::index::vector_element_size(
|
auto bytes_per_vec = milvus::vector_bytes_per_element(
|
||||||
field_meta.get_element_type());
|
field_meta.get_element_type(), dim);
|
||||||
for (auto& line : info.values()) {
|
for (auto& line : info.values()) {
|
||||||
target.insert(target.end(), line.begin(), line.end());
|
target.insert(target.end(), line.begin(), line.end());
|
||||||
AssertInfo(
|
AssertInfo(
|
||||||
line.size() % (dim * elem_size) == 0,
|
line.size() % bytes_per_vec == 0,
|
||||||
"line.size() % (dim * elem_size) == 0 assert failed, "
|
"line.size() % bytes_per_vec == 0 assert failed, "
|
||||||
"line.size() = {}, dim = {}, elem_size = {}",
|
"line.size() = {}, dim = {}, bytes_per_vec = {}",
|
||||||
line.size(),
|
line.size(),
|
||||||
dim,
|
dim,
|
||||||
elem_size);
|
bytes_per_vec);
|
||||||
|
|
||||||
offset += line.size() / (dim * elem_size);
|
offset += line.size() / bytes_per_vec;
|
||||||
lims.push_back(offset);
|
offsets.push_back(offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -70,8 +70,8 @@ struct Plan {
|
|||||||
struct Placeholder {
|
struct Placeholder {
|
||||||
std::string tag_;
|
std::string tag_;
|
||||||
// note: for embedding list search, num_of_queries_ stands for the number of vectors.
|
// note: for embedding list search, num_of_queries_ stands for the number of vectors.
|
||||||
// lims_ records the offsets of embedding list in the flattened vector and
|
// offsets_ records the offsets of embedding list in the flattened vector and
|
||||||
// hence lims_.size() - 1 is the number of queries in embedding list search.
|
// hence offsets_.size() - 1 is the number of queries in embedding list search.
|
||||||
int64_t num_of_queries_;
|
int64_t num_of_queries_;
|
||||||
// TODO(SPARSE): add a dim_ field here, use the dim passed in search request
|
// TODO(SPARSE): add a dim_ field here, use the dim passed in search request
|
||||||
// instead of the dim in schema, since the dim of sparse float column is
|
// instead of the dim in schema, since the dim of sparse float column is
|
||||||
@ -84,7 +84,7 @@ struct Placeholder {
|
|||||||
std::unique_ptr<knowhere::sparse::SparseRow<SparseValueType>[]>
|
std::unique_ptr<knowhere::sparse::SparseRow<SparseValueType>[]>
|
||||||
sparse_matrix_;
|
sparse_matrix_;
|
||||||
// offsets for embedding list
|
// offsets for embedding list
|
||||||
aligned_vector<size_t> lims_;
|
aligned_vector<size_t> offsets_;
|
||||||
|
|
||||||
const void*
|
const void*
|
||||||
get_blob() const {
|
get_blob() const {
|
||||||
@ -103,13 +103,13 @@ struct Placeholder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const size_t*
|
const size_t*
|
||||||
get_lims() const {
|
get_offsets() const {
|
||||||
return lims_.data();
|
return offsets_.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t*
|
size_t*
|
||||||
get_lims() {
|
get_offsets() {
|
||||||
return lims_.data();
|
return offsets_.data();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -89,21 +89,25 @@ PrepareBFDataSet(const dataset::SearchDataset& query_ds,
|
|||||||
DataType data_type) {
|
DataType data_type) {
|
||||||
auto base_dataset =
|
auto base_dataset =
|
||||||
knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data);
|
knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data);
|
||||||
if (raw_ds.raw_data_lims != nullptr) {
|
if (raw_ds.raw_data_offsets != nullptr) {
|
||||||
// knowhere::DataSet count vectors in a flattened manner where as the num_raw_data here is the number
|
// knowhere::DataSet count vectors in a flattened manner where as the num_raw_data here is the number
|
||||||
// of embedding lists where each embedding list contains multiple vectors. So we should use the last element
|
// of embedding lists where each embedding list contains multiple vectors. So we should use the last element
|
||||||
// in lims which equals to the total number of vectors.
|
// in offsets which equals to the total number of vectors.
|
||||||
base_dataset->SetLims(raw_ds.raw_data_lims);
|
base_dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
|
||||||
// the length of lims equals to the number of embedding lists + 1
|
raw_ds.raw_data_offsets);
|
||||||
base_dataset->SetRows(raw_ds.raw_data_lims[raw_ds.num_raw_data]);
|
|
||||||
|
// the length of offsets equals to the number of embedding lists + 1
|
||||||
|
base_dataset->SetRows(raw_ds.raw_data_offsets[raw_ds.num_raw_data]);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto query_dataset = knowhere::GenDataSet(
|
auto query_dataset = knowhere::GenDataSet(
|
||||||
query_ds.num_queries, query_ds.dim, query_ds.query_data);
|
query_ds.num_queries, query_ds.dim, query_ds.query_data);
|
||||||
if (query_ds.query_lims != nullptr) {
|
if (query_ds.query_offsets != nullptr) {
|
||||||
// ditto
|
// ditto
|
||||||
query_dataset->SetLims(query_ds.query_lims);
|
query_dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
|
||||||
query_dataset->SetRows(query_ds.query_lims[query_ds.num_queries]);
|
query_ds.query_offsets);
|
||||||
|
|
||||||
|
query_dataset->SetRows(query_ds.query_offsets[query_ds.num_queries]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
|
if (data_type == DataType::VECTOR_SPARSE_U32_F32) {
|
||||||
|
|||||||
@ -73,7 +73,7 @@ void
|
|||||||
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||||
const SearchInfo& info,
|
const SearchInfo& info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t num_queries,
|
int64_t num_queries,
|
||||||
Timestamp timestamp,
|
Timestamp timestamp,
|
||||||
const BitsetView& bitset,
|
const BitsetView& bitset,
|
||||||
@ -141,7 +141,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
|||||||
round_decimal,
|
round_decimal,
|
||||||
dim,
|
dim,
|
||||||
query_data,
|
query_data,
|
||||||
query_lims};
|
query_offsets};
|
||||||
int32_t current_chunk_id = 0;
|
int32_t current_chunk_id = 0;
|
||||||
|
|
||||||
// get K1 and B from index for bm25 brute force
|
// get K1 and B from index for bm25 brute force
|
||||||
@ -222,8 +222,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
|||||||
|
|
||||||
if (data_type == DataType::VECTOR_ARRAY) {
|
if (data_type == DataType::VECTOR_ARRAY) {
|
||||||
AssertInfo(
|
AssertInfo(
|
||||||
query_lims != nullptr,
|
query_offsets != nullptr,
|
||||||
"query_lims is nullptr, but data_type is vector array");
|
"query_offsets is nullptr, but data_type is vector array");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (milvus::exec::UseVectorIterator(info)) {
|
if (milvus::exec::UseVectorIterator(info)) {
|
||||||
|
|||||||
@ -21,7 +21,7 @@ void
|
|||||||
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
|
||||||
const SearchInfo& info,
|
const SearchInfo& info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t num_queries,
|
int64_t num_queries,
|
||||||
Timestamp timestamp,
|
Timestamp timestamp,
|
||||||
const BitsetView& bitset,
|
const BitsetView& bitset,
|
||||||
|
|||||||
@ -31,7 +31,7 @@ SearchOnSealedIndex(const Schema& schema,
|
|||||||
const segcore::SealedIndexingRecord& record,
|
const segcore::SealedIndexingRecord& record,
|
||||||
const SearchInfo& search_info,
|
const SearchInfo& search_info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t num_queries,
|
int64_t num_queries,
|
||||||
const BitsetView& bitset,
|
const BitsetView& bitset,
|
||||||
milvus::OpContext* op_context,
|
milvus::OpContext* op_context,
|
||||||
@ -55,15 +55,15 @@ SearchOnSealedIndex(const Schema& schema,
|
|||||||
search_info.metric_type_);
|
search_info.metric_type_);
|
||||||
|
|
||||||
knowhere::DataSetPtr dataset;
|
knowhere::DataSetPtr dataset;
|
||||||
if (query_lims == nullptr) {
|
if (query_offsets == nullptr) {
|
||||||
dataset = knowhere::GenDataSet(num_queries, dim, query_data);
|
dataset = knowhere::GenDataSet(num_queries, dim, query_data);
|
||||||
} else {
|
} else {
|
||||||
// Rather than non-embedding list search where num_queries equals to the number of vectors,
|
// Rather than non-embedding list search where num_queries equals to the number of vectors,
|
||||||
// in embedding list search, multiple vectors form an embedding list and the last element of query_lims
|
// in embedding list search, multiple vectors form an embedding list and the last element of query_offsets
|
||||||
// stands for the total number of vectors.
|
// stands for the total number of vectors.
|
||||||
auto num_vectors = query_lims[num_queries];
|
auto num_vectors = query_offsets[num_queries];
|
||||||
dataset = knowhere::GenDataSet(num_vectors, dim, query_data);
|
dataset = knowhere::GenDataSet(num_vectors, dim, query_data);
|
||||||
dataset->SetLims(query_lims);
|
dataset->Set(knowhere::meta::EMB_LIST_OFFSET, query_offsets);
|
||||||
}
|
}
|
||||||
|
|
||||||
dataset->SetIsSparse(is_sparse);
|
dataset->SetIsSparse(is_sparse);
|
||||||
@ -107,7 +107,7 @@ SearchOnSealedColumn(const Schema& schema,
|
|||||||
const SearchInfo& search_info,
|
const SearchInfo& search_info,
|
||||||
const std::map<std::string, std::string>& index_info,
|
const std::map<std::string, std::string>& index_info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t num_queries,
|
int64_t num_queries,
|
||||||
int64_t row_count,
|
int64_t row_count,
|
||||||
const BitsetView& bitview,
|
const BitsetView& bitview,
|
||||||
@ -128,7 +128,7 @@ SearchOnSealedColumn(const Schema& schema,
|
|||||||
search_info.round_decimal_,
|
search_info.round_decimal_,
|
||||||
dim,
|
dim,
|
||||||
query_data,
|
query_data,
|
||||||
query_lims};
|
query_offsets};
|
||||||
|
|
||||||
CheckBruteForceSearchParam(field, search_info);
|
CheckBruteForceSearchParam(field, search_info);
|
||||||
|
|
||||||
@ -158,13 +158,14 @@ SearchOnSealedColumn(const Schema& schema,
|
|||||||
auto raw_dataset =
|
auto raw_dataset =
|
||||||
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
|
query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
|
||||||
|
|
||||||
PinWrapper<const size_t*> lims_pw;
|
PinWrapper<const size_t*> offsets_pw;
|
||||||
if (data_type == DataType::VECTOR_ARRAY) {
|
if (data_type == DataType::VECTOR_ARRAY) {
|
||||||
AssertInfo(query_lims != nullptr,
|
AssertInfo(
|
||||||
"query_lims is nullptr, but data_type is vector array");
|
query_offsets != nullptr,
|
||||||
|
"query_offsets is nullptr, but data_type is vector array");
|
||||||
|
|
||||||
lims_pw = column->VectorArrayLims(op_context, i);
|
offsets_pw = column->VectorArrayOffsets(op_context, i);
|
||||||
raw_dataset.raw_data_lims = lims_pw.get();
|
raw_dataset.raw_data_offsets = offsets_pw.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (milvus::exec::UseVectorIterator(search_info)) {
|
if (milvus::exec::UseVectorIterator(search_info)) {
|
||||||
|
|||||||
@ -23,7 +23,7 @@ SearchOnSealedIndex(const Schema& schema,
|
|||||||
const segcore::SealedIndexingRecord& record,
|
const segcore::SealedIndexingRecord& record,
|
||||||
const SearchInfo& search_info,
|
const SearchInfo& search_info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t num_queries,
|
int64_t num_queries,
|
||||||
const BitsetView& view,
|
const BitsetView& view,
|
||||||
milvus::OpContext* op_context,
|
milvus::OpContext* op_context,
|
||||||
@ -35,7 +35,7 @@ SearchOnSealedColumn(const Schema& schema,
|
|||||||
const SearchInfo& search_info,
|
const SearchInfo& search_info,
|
||||||
const std::map<std::string, std::string>& index_info,
|
const std::map<std::string, std::string>& index_info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t num_queries,
|
int64_t num_queries,
|
||||||
int64_t row_count,
|
int64_t row_count,
|
||||||
const BitsetView& bitset,
|
const BitsetView& bitset,
|
||||||
|
|||||||
@ -24,7 +24,7 @@ struct RawDataset {
|
|||||||
int64_t dim;
|
int64_t dim;
|
||||||
int64_t num_raw_data;
|
int64_t num_raw_data;
|
||||||
const void* raw_data;
|
const void* raw_data;
|
||||||
const size_t* raw_data_lims = nullptr;
|
const size_t* raw_data_offsets = nullptr;
|
||||||
};
|
};
|
||||||
struct SearchDataset {
|
struct SearchDataset {
|
||||||
knowhere::MetricType metric_type;
|
knowhere::MetricType metric_type;
|
||||||
@ -34,7 +34,7 @@ struct SearchDataset {
|
|||||||
int64_t dim;
|
int64_t dim;
|
||||||
const void* query_data;
|
const void* query_data;
|
||||||
// used for embedding list query
|
// used for embedding list query
|
||||||
const size_t* query_lims = nullptr;
|
const size_t* query_offsets = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|||||||
@ -756,7 +756,7 @@ ChunkedSegmentSealedImpl::mask_with_delete(BitsetTypeView& bitset,
|
|||||||
void
|
void
|
||||||
ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
|
ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t query_count,
|
int64_t query_count,
|
||||||
Timestamp timestamp,
|
Timestamp timestamp,
|
||||||
const BitsetView& bitset,
|
const BitsetView& bitset,
|
||||||
@ -783,7 +783,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
|
|||||||
vector_indexings_,
|
vector_indexings_,
|
||||||
binlog_search_info,
|
binlog_search_info,
|
||||||
query_data,
|
query_data,
|
||||||
query_lims,
|
query_offsets,
|
||||||
query_count,
|
query_count,
|
||||||
bitset,
|
bitset,
|
||||||
op_context,
|
op_context,
|
||||||
@ -798,7 +798,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
|
|||||||
vector_indexings_,
|
vector_indexings_,
|
||||||
search_info,
|
search_info,
|
||||||
query_data,
|
query_data,
|
||||||
query_lims,
|
query_offsets,
|
||||||
query_count,
|
query_count,
|
||||||
bitset,
|
bitset,
|
||||||
op_context,
|
op_context,
|
||||||
@ -826,7 +826,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
|
|||||||
search_info,
|
search_info,
|
||||||
index_info,
|
index_info,
|
||||||
query_data,
|
query_data,
|
||||||
query_lims,
|
query_offsets,
|
||||||
query_count,
|
query_count,
|
||||||
row_count,
|
row_count,
|
||||||
bitset,
|
bitset,
|
||||||
|
|||||||
@ -459,7 +459,7 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
|
|||||||
void
|
void
|
||||||
vector_search(SearchInfo& search_info,
|
vector_search(SearchInfo& search_info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t query_count,
|
int64_t query_count,
|
||||||
Timestamp timestamp,
|
Timestamp timestamp,
|
||||||
const BitsetView& bitset,
|
const BitsetView& bitset,
|
||||||
|
|||||||
@ -696,7 +696,7 @@ SegmentGrowingImpl::search_batch_pks(
|
|||||||
void
|
void
|
||||||
SegmentGrowingImpl::vector_search(SearchInfo& search_info,
|
SegmentGrowingImpl::vector_search(SearchInfo& search_info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t query_count,
|
int64_t query_count,
|
||||||
Timestamp timestamp,
|
Timestamp timestamp,
|
||||||
const BitsetView& bitset,
|
const BitsetView& bitset,
|
||||||
@ -705,7 +705,7 @@ SegmentGrowingImpl::vector_search(SearchInfo& search_info,
|
|||||||
query::SearchOnGrowing(*this,
|
query::SearchOnGrowing(*this,
|
||||||
search_info,
|
search_info,
|
||||||
query_data,
|
query_data,
|
||||||
query_lims,
|
query_offsets,
|
||||||
query_count,
|
query_count,
|
||||||
timestamp,
|
timestamp,
|
||||||
bitset,
|
bitset,
|
||||||
|
|||||||
@ -359,7 +359,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
|
|||||||
void
|
void
|
||||||
vector_search(SearchInfo& search_info,
|
vector_search(SearchInfo& search_info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t query_count,
|
int64_t query_count,
|
||||||
Timestamp timestamp,
|
Timestamp timestamp,
|
||||||
const BitsetView& bitset,
|
const BitsetView& bitset,
|
||||||
|
|||||||
@ -581,7 +581,7 @@ TEST(GrowingTest, SearchVectorArray) {
|
|||||||
config.set_enable_interim_segment_index(true);
|
config.set_enable_interim_segment_index(true);
|
||||||
|
|
||||||
std::map<std::string, std::string> index_params = {
|
std::map<std::string, std::string> index_params = {
|
||||||
{"index_type", knowhere::IndexEnum::INDEX_EMB_LIST_HNSW},
|
{"index_type", knowhere::IndexEnum::INDEX_HNSW},
|
||||||
{"metric_type", metric_type},
|
{"metric_type", metric_type},
|
||||||
{"nlist", "128"}};
|
{"nlist", "128"}};
|
||||||
std::map<std::string, std::string> type_params = {
|
std::map<std::string, std::string> type_params = {
|
||||||
@ -612,11 +612,11 @@ TEST(GrowingTest, SearchVectorArray) {
|
|||||||
int vec_num = 10; // Total number of query vectors
|
int vec_num = 10; // Total number of query vectors
|
||||||
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
|
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
|
||||||
|
|
||||||
// Create query dataset with lims for VectorArray
|
// Create query dataset with offsets for VectorArray
|
||||||
std::vector<size_t> query_vec_lims;
|
std::vector<size_t> query_vec_offsets;
|
||||||
query_vec_lims.push_back(0); // First query has 3 vectors
|
query_vec_offsets.push_back(0); // First query has 3 vectors
|
||||||
query_vec_lims.push_back(3);
|
query_vec_offsets.push_back(3);
|
||||||
query_vec_lims.push_back(10); // Second query has 7 vectors
|
query_vec_offsets.push_back(10); // Second query has 7 vectors
|
||||||
|
|
||||||
// Create search plan
|
// Create search plan
|
||||||
const char* raw_plan = R"(vector_anns: <
|
const char* raw_plan = R"(vector_anns: <
|
||||||
@ -636,7 +636,7 @@ TEST(GrowingTest, SearchVectorArray) {
|
|||||||
|
|
||||||
// Use CreatePlaceholderGroupFromBlob for VectorArray
|
// Use CreatePlaceholderGroupFromBlob for VectorArray
|
||||||
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
|
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
|
||||||
vec_num, dim, query_vec.data(), query_vec_lims);
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
auto ph_group =
|
auto ph_group =
|
||||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||||
|
|
||||||
|
|||||||
@ -373,14 +373,14 @@ class SegmentInternalInterface : public SegmentInterface {
|
|||||||
const std::string& nested_path) const override;
|
const std::string& nested_path) const override;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// `query_lims` is not null only for vector array (embedding list) search
|
// `query_offsets` is not null only for vector array (embedding list) search
|
||||||
// where it denotes the number of vectors in each embedding list. The length
|
// where it denotes the number of vectors in each embedding list. The length
|
||||||
// of `query_lims` is the number of queries in the search plus one (the first
|
// of `query_offsets` is the number of queries in the search plus one (the first
|
||||||
// element in query_lims is 0).
|
// element in query_offsets is 0).
|
||||||
virtual void
|
virtual void
|
||||||
vector_search(SearchInfo& search_info,
|
vector_search(SearchInfo& search_info,
|
||||||
const void* query_data,
|
const void* query_data,
|
||||||
const size_t* query_lims,
|
const size_t* query_offsets,
|
||||||
int64_t query_count,
|
int64_t query_count,
|
||||||
Timestamp timestamp,
|
Timestamp timestamp,
|
||||||
const BitsetView& bitset,
|
const BitsetView& bitset,
|
||||||
|
|||||||
@ -35,7 +35,7 @@ struct LoadIndexInfo {
|
|||||||
int64_t segment_id;
|
int64_t segment_id;
|
||||||
int64_t field_id;
|
int64_t field_id;
|
||||||
DataType field_type;
|
DataType field_type;
|
||||||
// The element type of the field. It's DataType::NONE if field_type is array/vector_array.
|
// The element type of the field. It's not DataType::NONE if field_type is array/vector_array.
|
||||||
DataType element_type;
|
DataType element_type;
|
||||||
bool enable_mmap;
|
bool enable_mmap;
|
||||||
std::string mmap_dir_path;
|
std::string mmap_dir_path;
|
||||||
|
|||||||
@ -611,22 +611,40 @@ CreateVectorDataArrayFrom(const void* data_raw,
|
|||||||
case DataType::VECTOR_ARRAY: {
|
case DataType::VECTOR_ARRAY: {
|
||||||
auto data = reinterpret_cast<const VectorFieldProto*>(data_raw);
|
auto data = reinterpret_cast<const VectorFieldProto*>(data_raw);
|
||||||
auto vector_type = field_meta.get_element_type();
|
auto vector_type = field_meta.get_element_type();
|
||||||
switch (vector_type) {
|
|
||||||
case DataType::VECTOR_FLOAT: {
|
|
||||||
auto obj = vector_array->mutable_vector_array();
|
auto obj = vector_array->mutable_vector_array();
|
||||||
|
obj->set_dim(dim);
|
||||||
|
|
||||||
|
// Set element type based on vector type
|
||||||
|
switch (vector_type) {
|
||||||
|
case DataType::VECTOR_FLOAT:
|
||||||
obj->set_element_type(
|
obj->set_element_type(
|
||||||
milvus::proto::schema::DataType::FloatVector);
|
milvus::proto::schema::DataType::FloatVector);
|
||||||
obj->set_dim(dim);
|
|
||||||
for (auto i = 0; i < count; i++) {
|
|
||||||
*(obj->mutable_data()->Add()) = data[i];
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
case DataType::VECTOR_FLOAT16:
|
||||||
default: {
|
obj->set_element_type(
|
||||||
|
milvus::proto::schema::DataType::Float16Vector);
|
||||||
|
break;
|
||||||
|
case DataType::VECTOR_BFLOAT16:
|
||||||
|
obj->set_element_type(
|
||||||
|
milvus::proto::schema::DataType::BFloat16Vector);
|
||||||
|
break;
|
||||||
|
case DataType::VECTOR_BINARY:
|
||||||
|
obj->set_element_type(
|
||||||
|
milvus::proto::schema::DataType::BinaryVector);
|
||||||
|
break;
|
||||||
|
case DataType::VECTOR_INT8:
|
||||||
|
obj->set_element_type(
|
||||||
|
milvus::proto::schema::DataType::Int8Vector);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
ThrowInfo(NotImplemented,
|
ThrowInfo(NotImplemented,
|
||||||
fmt::format("not implemented vector type {}",
|
fmt::format("not implemented vector type {}",
|
||||||
vector_type));
|
vector_type));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add all vector data
|
||||||
|
for (auto i = 0; i < count; i++) {
|
||||||
|
*(obj->mutable_data()->Add()) = data[i];
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -241,15 +241,28 @@ AddPayloadToArrowBuilder(std::shared_ptr<arrow::ArrayBuilder> builder,
|
|||||||
if (length > 0) {
|
if (length > 0) {
|
||||||
auto element_type = vector_arrays[0].get_element_type();
|
auto element_type = vector_arrays[0].get_element_type();
|
||||||
|
|
||||||
|
// Validate element type
|
||||||
switch (element_type) {
|
switch (element_type) {
|
||||||
case DataType::VECTOR_FLOAT: {
|
case DataType::VECTOR_FLOAT:
|
||||||
auto value_builder = static_cast<arrow::FloatBuilder*>(
|
case DataType::VECTOR_BINARY:
|
||||||
|
case DataType::VECTOR_FLOAT16:
|
||||||
|
case DataType::VECTOR_BFLOAT16:
|
||||||
|
case DataType::VECTOR_INT8:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
ThrowInfo(DataTypeInvalid,
|
||||||
|
"Unsupported element type in VectorArray: {}",
|
||||||
|
element_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
// All supported vector types use FixedSizeBinaryBuilder
|
||||||
|
auto value_builder =
|
||||||
|
static_cast<arrow::FixedSizeBinaryBuilder*>(
|
||||||
list_builder->value_builder());
|
list_builder->value_builder());
|
||||||
AssertInfo(value_builder != nullptr,
|
AssertInfo(value_builder != nullptr,
|
||||||
"value_builder must be FloatBuilder for "
|
"value_builder must be FixedSizeBinaryBuilder for "
|
||||||
"FloatVector");
|
"VectorArray");
|
||||||
|
|
||||||
arrow::Status ast;
|
|
||||||
for (int i = 0; i < length; ++i) {
|
for (int i = 0; i < length; ++i) {
|
||||||
auto status = list_builder->Append();
|
auto status = list_builder->Append();
|
||||||
AssertInfo(status.ok(),
|
AssertInfo(status.ok(),
|
||||||
@ -257,53 +270,20 @@ AddPayloadToArrowBuilder(std::shared_ptr<arrow::ArrayBuilder> builder,
|
|||||||
status.ToString());
|
status.ToString());
|
||||||
|
|
||||||
const auto& array = vector_arrays[i];
|
const auto& array = vector_arrays[i];
|
||||||
AssertInfo(
|
AssertInfo(array.get_element_type() == element_type,
|
||||||
array.get_element_type() ==
|
|
||||||
DataType::VECTOR_FLOAT,
|
|
||||||
"Inconsistent element types in VectorArray");
|
"Inconsistent element types in VectorArray");
|
||||||
|
|
||||||
int num_vectors = array.length();
|
int num_vectors = array.length();
|
||||||
int dim = array.dim();
|
auto ast = value_builder->AppendValues(
|
||||||
|
reinterpret_cast<const uint8_t*>(array.data()),
|
||||||
for (int j = 0; j < num_vectors; ++j) {
|
num_vectors);
|
||||||
auto vec_data = array.get_data<float>(j);
|
|
||||||
ast =
|
|
||||||
value_builder->AppendValues(vec_data, dim);
|
|
||||||
AssertInfo(ast.ok(),
|
AssertInfo(ast.ok(),
|
||||||
"Failed to append list: {}",
|
"Failed to batch append vectors: {}",
|
||||||
ast.ToString());
|
ast.ToString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case DataType::VECTOR_BINARY:
|
|
||||||
ThrowInfo(
|
|
||||||
NotImplemented,
|
|
||||||
"BinaryVector in VectorArray not implemented yet");
|
|
||||||
break;
|
|
||||||
case DataType::VECTOR_FLOAT16:
|
|
||||||
ThrowInfo(
|
|
||||||
NotImplemented,
|
|
||||||
"Float16Vector in VectorArray not implemented yet");
|
|
||||||
break;
|
|
||||||
case DataType::VECTOR_BFLOAT16:
|
|
||||||
ThrowInfo(NotImplemented,
|
|
||||||
"BFloat16Vector in VectorArray not "
|
|
||||||
"implemented yet");
|
|
||||||
break;
|
|
||||||
case DataType::VECTOR_INT8:
|
|
||||||
ThrowInfo(
|
|
||||||
NotImplemented,
|
|
||||||
"Int8Vector in VectorArray not implemented yet");
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
ThrowInfo(DataTypeInvalid,
|
|
||||||
"Unsupported element type in VectorArray: {}",
|
|
||||||
element_type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default: {
|
default: {
|
||||||
ThrowInfo(DataTypeInvalid, "unsupported data type {}", data_type);
|
ThrowInfo(DataTypeInvalid, "unsupported data type {}", data_type);
|
||||||
}
|
}
|
||||||
@ -428,7 +408,38 @@ CreateArrowBuilder(DataType data_type, DataType element_type, int dim) {
|
|||||||
std::shared_ptr<arrow::ArrayBuilder> value_builder;
|
std::shared_ptr<arrow::ArrayBuilder> value_builder;
|
||||||
switch (element_type) {
|
switch (element_type) {
|
||||||
case DataType::VECTOR_FLOAT: {
|
case DataType::VECTOR_FLOAT: {
|
||||||
value_builder = std::make_shared<arrow::FloatBuilder>();
|
int byte_width = dim * sizeof(float);
|
||||||
|
value_builder =
|
||||||
|
std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_BINARY: {
|
||||||
|
int byte_width = (dim + 7) / 8;
|
||||||
|
value_builder =
|
||||||
|
std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_FLOAT16: {
|
||||||
|
int byte_width = dim * 2;
|
||||||
|
value_builder =
|
||||||
|
std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_BFLOAT16: {
|
||||||
|
int byte_width = dim * 2;
|
||||||
|
value_builder =
|
||||||
|
std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_INT8: {
|
||||||
|
int byte_width = dim;
|
||||||
|
value_builder =
|
||||||
|
std::make_shared<arrow::FixedSizeBinaryBuilder>(
|
||||||
|
arrow::fixed_size_binary(byte_width));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
@ -610,7 +621,7 @@ CreateArrowSchema(DataType data_type, int dim, DataType element_type) {
|
|||||||
"This overload is only for VECTOR_ARRAY type");
|
"This overload is only for VECTOR_ARRAY type");
|
||||||
AssertInfo(dim > 0, "invalid dim value");
|
AssertInfo(dim > 0, "invalid dim value");
|
||||||
|
|
||||||
auto value_type = GetArrowDataTypeForVectorArray(element_type);
|
auto value_type = GetArrowDataTypeForVectorArray(element_type, dim);
|
||||||
auto metadata = arrow::KeyValueMetadata::Make(
|
auto metadata = arrow::KeyValueMetadata::Make(
|
||||||
{ELEMENT_TYPE_KEY_FOR_ARROW, DIM_KEY},
|
{ELEMENT_TYPE_KEY_FOR_ARROW, DIM_KEY},
|
||||||
{std::to_string(static_cast<int>(element_type)), std::to_string(dim)});
|
{std::to_string(static_cast<int>(element_type)), std::to_string(dim)});
|
||||||
|
|||||||
@ -14,7 +14,7 @@
|
|||||||
# Update KNOWHERE_VERSION for the first occurrence
|
# Update KNOWHERE_VERSION for the first occurrence
|
||||||
milvus_add_pkg_config("knowhere")
|
milvus_add_pkg_config("knowhere")
|
||||||
set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES "")
|
set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES "")
|
||||||
set( KNOWHERE_VERSION v2.6.3 )
|
set( KNOWHERE_VERSION c4d5dd8 )
|
||||||
set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git")
|
set( GIT_REPOSITORY "https://github.com/zilliztech/knowhere.git")
|
||||||
|
|
||||||
message(STATUS "Knowhere repo: ${GIT_REPOSITORY}")
|
message(STATUS "Knowhere repo: ${GIT_REPOSITORY}")
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
#include "common/Types.h"
|
#include "common/Types.h"
|
||||||
#include "index/IndexFactory.h"
|
#include "index/IndexFactory.h"
|
||||||
#include "knowhere/version.h"
|
#include "knowhere/version.h"
|
||||||
|
#include "knowhere/comp/index_param.h"
|
||||||
#include "storage/RemoteChunkManagerSingleton.h"
|
#include "storage/RemoteChunkManagerSingleton.h"
|
||||||
#include "storage/Util.h"
|
#include "storage/Util.h"
|
||||||
#include "common/VectorArray.h"
|
#include "common/VectorArray.h"
|
||||||
@ -2304,13 +2305,81 @@ TEST(Sealed, SearchSortedPk) {
|
|||||||
EXPECT_EQ(100, offsets2[0].get());
|
EXPECT_EQ(100, offsets2[0].get());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Sealed, QueryVectorArrayAllFields) {
|
using VectorArrayTestParam =
|
||||||
|
std::tuple<DataType, std::string, int, std::string>;
|
||||||
|
|
||||||
|
class SealedVectorArrayTest
|
||||||
|
: public ::testing::TestWithParam<VectorArrayTestParam> {
|
||||||
|
protected:
|
||||||
|
DataType element_type;
|
||||||
|
std::string metric_type;
|
||||||
|
int dim;
|
||||||
|
std::string test_name;
|
||||||
|
|
||||||
|
void
|
||||||
|
SetUp() override {
|
||||||
|
auto param = GetParam();
|
||||||
|
element_type = std::get<0>(param);
|
||||||
|
metric_type = std::get<1>(param);
|
||||||
|
dim = std::get<2>(param);
|
||||||
|
test_name = std::get<3>(param);
|
||||||
|
|
||||||
|
// Ensure dim is valid for binary vectors
|
||||||
|
if (element_type == DataType::VECTOR_BINARY) {
|
||||||
|
ASSERT_EQ(dim % 8, 0) << "Binary vector dim must be multiple of 8";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
VerifyVectorResults(const VectorFieldProto& result_vec,
|
||||||
|
const VectorFieldProto& expected_vec,
|
||||||
|
DataType element_type) {
|
||||||
|
switch (element_type) {
|
||||||
|
case DataType::VECTOR_FLOAT: {
|
||||||
|
auto result_data = result_vec.float_vector().data();
|
||||||
|
auto expected_data = expected_vec.float_vector().data();
|
||||||
|
EXPECT_EQ(result_data.size(), expected_data.size());
|
||||||
|
for (int64_t i = 0; i < result_data.size(); ++i) {
|
||||||
|
EXPECT_NEAR(result_data[i], expected_data[i], 1e-6f);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_BINARY: {
|
||||||
|
auto result_data = result_vec.binary_vector();
|
||||||
|
auto expected_data = expected_vec.binary_vector();
|
||||||
|
EXPECT_EQ(result_data, expected_data);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_FLOAT16: {
|
||||||
|
auto result_data = result_vec.float16_vector();
|
||||||
|
auto expected_data = expected_vec.float16_vector();
|
||||||
|
EXPECT_EQ(result_data, expected_data);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_BFLOAT16: {
|
||||||
|
auto result_data = result_vec.bfloat16_vector();
|
||||||
|
auto expected_data = expected_vec.bfloat16_vector();
|
||||||
|
EXPECT_EQ(result_data, expected_data);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DataType::VECTOR_INT8: {
|
||||||
|
auto result_data = result_vec.int8_vector();
|
||||||
|
auto expected_data = expected_vec.int8_vector();
|
||||||
|
EXPECT_EQ(result_data, expected_data);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_P(SealedVectorArrayTest, QueryVectorArrayAllFields) {
|
||||||
auto schema = std::make_shared<Schema>();
|
auto schema = std::make_shared<Schema>();
|
||||||
auto metric_type = knowhere::metric::MAX_SIM;
|
|
||||||
int64_t dim = 4;
|
|
||||||
auto int64_field = schema->AddDebugField("int64", DataType::INT64);
|
auto int64_field = schema->AddDebugField("int64", DataType::INT64);
|
||||||
auto array_vec = schema->AddDebugVectorArrayField(
|
auto array_vec = schema->AddDebugVectorArrayField(
|
||||||
"array_vec", DataType::VECTOR_FLOAT, dim, metric_type);
|
"array_vec", element_type, dim, metric_type);
|
||||||
schema->set_primary_field_id(int64_field);
|
schema->set_primary_field_id(int64_field);
|
||||||
|
|
||||||
std::map<FieldId, FieldIndexMeta> filedMap{};
|
std::map<FieldId, FieldIndexMeta> filedMap{};
|
||||||
@ -2329,49 +2398,36 @@ TEST(Sealed, QueryVectorArrayAllFields) {
|
|||||||
auto ids_ds = GenRandomIds(dataset_size);
|
auto ids_ds = GenRandomIds(dataset_size);
|
||||||
auto int64_result = segment->bulk_subscript(
|
auto int64_result = segment->bulk_subscript(
|
||||||
nullptr, int64_field, ids_ds->GetIds(), dataset_size);
|
nullptr, int64_field, ids_ds->GetIds(), dataset_size);
|
||||||
auto array_float_vector_result = segment->bulk_subscript(
|
auto array_vector_result = segment->bulk_subscript(
|
||||||
nullptr, array_vec, ids_ds->GetIds(), dataset_size);
|
nullptr, array_vec, ids_ds->GetIds(), dataset_size);
|
||||||
|
|
||||||
EXPECT_EQ(int64_result->scalars().long_data().data_size(), dataset_size);
|
EXPECT_EQ(int64_result->scalars().long_data().data_size(), dataset_size);
|
||||||
EXPECT_EQ(array_float_vector_result->vectors().vector_array().data_size(),
|
EXPECT_EQ(array_vector_result->vectors().vector_array().data_size(),
|
||||||
dataset_size);
|
dataset_size);
|
||||||
|
|
||||||
auto verify_float_vectors = [](auto arr1, auto arr2) {
|
|
||||||
static constexpr float EPSILON = 1e-6;
|
|
||||||
EXPECT_EQ(arr1.size(), arr2.size());
|
|
||||||
for (int64_t i = 0; i < arr1.size(); ++i) {
|
|
||||||
EXPECT_NEAR(arr1[i], arr2[i], EPSILON);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
for (int64_t i = 0; i < dataset_size; ++i) {
|
for (int64_t i = 0; i < dataset_size; ++i) {
|
||||||
auto arrow_array = array_float_vector_result->vectors()
|
auto result_vec =
|
||||||
.vector_array()
|
array_vector_result->vectors().vector_array().data()[i];
|
||||||
.data()[i]
|
auto expected_vec = array_vec_values[ids_ds->GetIds()[i]];
|
||||||
.float_vector()
|
VerifyVectorResults(result_vec, expected_vec, element_type);
|
||||||
.data();
|
|
||||||
auto expected_array =
|
|
||||||
array_vec_values[ids_ds->GetIds()[i]].float_vector().data();
|
|
||||||
verify_float_vectors(arrow_array, expected_array);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_EQ(int64_result->valid_data_size(), 0);
|
EXPECT_EQ(int64_result->valid_data_size(), 0);
|
||||||
EXPECT_EQ(array_float_vector_result->valid_data_size(), 0);
|
EXPECT_EQ(array_vector_result->valid_data_size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Sealed, SearchVectorArray) {
|
TEST_P(SealedVectorArrayTest, SearchVectorArray) {
|
||||||
int64_t collection_id = 1;
|
int64_t collection_id = 1;
|
||||||
int64_t partition_id = 2;
|
int64_t partition_id = 2;
|
||||||
int64_t segment_id = 3;
|
int64_t segment_id = 3;
|
||||||
int64_t index_build_id = 4000;
|
int64_t index_build_id = 4000;
|
||||||
int64_t index_version = 4000;
|
int64_t index_version = 4000;
|
||||||
int64_t index_id = 5000;
|
int64_t index_id = 5000;
|
||||||
int64_t dim = 4;
|
|
||||||
|
|
||||||
auto schema = std::make_shared<Schema>();
|
auto schema = std::make_shared<Schema>();
|
||||||
auto metric_type = knowhere::metric::MAX_SIM;
|
|
||||||
auto int64_field = schema->AddDebugField("int64", DataType::INT64);
|
auto int64_field = schema->AddDebugField("int64", DataType::INT64);
|
||||||
auto array_vec = schema->AddDebugVectorArrayField(
|
auto array_vec = schema->AddDebugVectorArrayField(
|
||||||
"array_vec", DataType::VECTOR_FLOAT, dim, metric_type);
|
"array_vec", element_type, dim, metric_type);
|
||||||
schema->set_primary_field_id(int64_field);
|
schema->set_primary_field_id(int64_field);
|
||||||
|
|
||||||
auto field_meta = milvus::segcore::gen_field_meta(collection_id,
|
auto field_meta = milvus::segcore::gen_field_meta(collection_id,
|
||||||
@ -2379,7 +2435,7 @@ TEST(Sealed, SearchVectorArray) {
|
|||||||
segment_id,
|
segment_id,
|
||||||
array_vec.get(),
|
array_vec.get(),
|
||||||
DataType::VECTOR_ARRAY,
|
DataType::VECTOR_ARRAY,
|
||||||
DataType::VECTOR_FLOAT,
|
element_type,
|
||||||
false);
|
false);
|
||||||
auto index_meta = gen_index_meta(
|
auto index_meta = gen_index_meta(
|
||||||
segment_id, array_vec.get(), index_build_id, index_version);
|
segment_id, array_vec.get(), index_build_id, index_version);
|
||||||
@ -2403,7 +2459,7 @@ TEST(Sealed, SearchVectorArray) {
|
|||||||
vector_arrays.push_back(milvus::VectorArray(v));
|
vector_arrays.push_back(milvus::VectorArray(v));
|
||||||
}
|
}
|
||||||
auto field_data = storage::CreateFieldData(
|
auto field_data = storage::CreateFieldData(
|
||||||
DataType::VECTOR_ARRAY, DataType::VECTOR_FLOAT, false, dim);
|
DataType::VECTOR_ARRAY, element_type, false, dim);
|
||||||
field_data->FillFieldData(vector_arrays.data(), vector_arrays.size());
|
field_data->FillFieldData(vector_arrays.data(), vector_arrays.size());
|
||||||
|
|
||||||
// create sealed segment
|
// create sealed segment
|
||||||
@ -2445,8 +2501,8 @@ TEST(Sealed, SearchVectorArray) {
|
|||||||
// create index
|
// create index
|
||||||
milvus::index::CreateIndexInfo create_index_info;
|
milvus::index::CreateIndexInfo create_index_info;
|
||||||
create_index_info.field_type = DataType::VECTOR_ARRAY;
|
create_index_info.field_type = DataType::VECTOR_ARRAY;
|
||||||
create_index_info.metric_type = knowhere::metric::MAX_SIM;
|
create_index_info.metric_type = metric_type;
|
||||||
create_index_info.index_type = knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
|
create_index_info.index_type = knowhere::IndexEnum::INDEX_HNSW;
|
||||||
create_index_info.index_engine_version =
|
create_index_info.index_engine_version =
|
||||||
knowhere::Version::GetCurrentVersion().VersionNumber();
|
knowhere::Version::GetCurrentVersion().VersionNumber();
|
||||||
|
|
||||||
@ -2457,8 +2513,7 @@ TEST(Sealed, SearchVectorArray) {
|
|||||||
|
|
||||||
// build index
|
// build index
|
||||||
Config config;
|
Config config;
|
||||||
config[milvus::index::INDEX_TYPE] =
|
config[milvus::index::INDEX_TYPE] = knowhere::IndexEnum::INDEX_HNSW;
|
||||||
knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
|
|
||||||
config[INSERT_FILES_KEY] = std::vector<std::string>{log_path};
|
config[INSERT_FILES_KEY] = std::vector<std::string>{log_path};
|
||||||
config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type;
|
config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type;
|
||||||
config[knowhere::indexparam::M] = "16";
|
config[knowhere::indexparam::M] = "16";
|
||||||
@ -2473,18 +2528,37 @@ TEST(Sealed, SearchVectorArray) {
|
|||||||
|
|
||||||
// search
|
// search
|
||||||
auto vec_num = 10;
|
auto vec_num = 10;
|
||||||
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
|
|
||||||
auto query_dataset = knowhere::GenDataSet(vec_num, dim, query_vec.data());
|
// Generate query vectors based on element type
|
||||||
std::vector<size_t> query_vec_lims;
|
std::vector<uint8_t> query_vec_bin;
|
||||||
query_vec_lims.push_back(0);
|
std::vector<float> query_vec_f32;
|
||||||
query_vec_lims.push_back(3);
|
knowhere::DataSetPtr query_dataset;
|
||||||
query_vec_lims.push_back(10);
|
if (element_type == DataType::VECTOR_BINARY) {
|
||||||
query_dataset->SetLims(query_vec_lims.data());
|
auto byte_dim = (dim + 7) / 8;
|
||||||
|
auto total_bytes = vec_num * byte_dim;
|
||||||
|
query_vec_bin.resize(total_bytes);
|
||||||
|
for (size_t i = 0; i < total_bytes; ++i) {
|
||||||
|
query_vec_bin[i] = rand() % 256;
|
||||||
|
}
|
||||||
|
query_dataset =
|
||||||
|
knowhere::GenDataSet(vec_num, dim, query_vec_bin.data());
|
||||||
|
} else {
|
||||||
|
// For float-like types (FLOAT, FLOAT16, BFLOAT16, INT8)
|
||||||
|
query_vec_f32 = generate_float_vector(vec_num, dim);
|
||||||
|
query_dataset =
|
||||||
|
knowhere::GenDataSet(vec_num, dim, query_vec_f32.data());
|
||||||
|
}
|
||||||
|
std::vector<size_t> query_vec_offsets;
|
||||||
|
query_vec_offsets.push_back(0);
|
||||||
|
query_vec_offsets.push_back(3);
|
||||||
|
query_vec_offsets.push_back(10);
|
||||||
|
query_dataset->Set(knowhere::meta::EMB_LIST_OFFSET,
|
||||||
|
const_cast<const size_t*>(query_vec_offsets.data()));
|
||||||
|
|
||||||
auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
|
auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
|
||||||
milvus::SearchInfo searchInfo;
|
milvus::SearchInfo searchInfo;
|
||||||
searchInfo.topk_ = 5;
|
searchInfo.topk_ = 5;
|
||||||
searchInfo.metric_type_ = knowhere::metric::MAX_SIM;
|
searchInfo.metric_type_ = metric_type;
|
||||||
searchInfo.search_params_ = search_conf;
|
searchInfo.search_params_ = search_conf;
|
||||||
SearchResult result;
|
SearchResult result;
|
||||||
vec_index->Query(query_dataset, searchInfo, nullptr, nullptr, result);
|
vec_index->Query(query_dataset, searchInfo, nullptr, nullptr, result);
|
||||||
@ -2498,21 +2572,62 @@ TEST(Sealed, SearchVectorArray) {
|
|||||||
|
|
||||||
// brute force search
|
// brute force search
|
||||||
{
|
{
|
||||||
const char* raw_plan = R"(vector_anns: <
|
std::string raw_plan = fmt::format(R"(vector_anns: <
|
||||||
field_id: 101
|
field_id: 101
|
||||||
query_info: <
|
query_info: <
|
||||||
topk: 5
|
topk: 5
|
||||||
round_decimal: 3
|
round_decimal: 3
|
||||||
metric_type: "MAX_SIM"
|
metric_type: "{}"
|
||||||
search_params: "{\"nprobe\": 10}"
|
search_params: "{{\"nprobe\": 10}}"
|
||||||
>
|
>
|
||||||
placeholder_tag: "$0"
|
placeholder_tag: "$0"
|
||||||
>)";
|
>)",
|
||||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
metric_type);
|
||||||
|
auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
|
||||||
auto plan =
|
auto plan =
|
||||||
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
|
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
|
||||||
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
|
|
||||||
vec_num, dim, query_vec.data(), query_vec_lims);
|
// Create placeholder based on element type
|
||||||
|
milvus::proto::common::PlaceholderGroup ph_group_raw;
|
||||||
|
if (element_type == DataType::VECTOR_BINARY) {
|
||||||
|
auto byte_dim = (dim + 7) / 8;
|
||||||
|
auto total_bytes = vec_num * byte_dim;
|
||||||
|
std::vector<uint8_t> query_vec(total_bytes);
|
||||||
|
for (size_t i = 0; i < total_bytes; ++i) {
|
||||||
|
query_vec[i] = rand() % 256;
|
||||||
|
}
|
||||||
|
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListBinaryVector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
} else if (element_type == DataType::VECTOR_FLOAT16) {
|
||||||
|
std::vector<float> float_vec = generate_float_vector(vec_num, dim);
|
||||||
|
std::vector<float16> query_vec(vec_num * dim);
|
||||||
|
for (size_t i = 0; i < vec_num * dim; ++i) {
|
||||||
|
query_vec[i] = float16(float_vec[i]);
|
||||||
|
}
|
||||||
|
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloat16Vector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
} else if (element_type == DataType::VECTOR_BFLOAT16) {
|
||||||
|
std::vector<float> float_vec = generate_float_vector(vec_num, dim);
|
||||||
|
std::vector<bfloat16> query_vec(vec_num * dim);
|
||||||
|
for (size_t i = 0; i < vec_num * dim; ++i) {
|
||||||
|
query_vec[i] = bfloat16(float_vec[i]);
|
||||||
|
}
|
||||||
|
ph_group_raw =
|
||||||
|
CreatePlaceholderGroupFromBlob<EmbListBFloat16Vector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
} else if (element_type == DataType::VECTOR_INT8) {
|
||||||
|
std::vector<int8_t> query_vec(vec_num * dim);
|
||||||
|
for (size_t i = 0; i < vec_num * dim; ++i) {
|
||||||
|
query_vec[i] = static_cast<int8_t>(rand() % 256 - 128);
|
||||||
|
}
|
||||||
|
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListInt8Vector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
} else {
|
||||||
|
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
|
||||||
|
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
}
|
||||||
|
|
||||||
auto ph_group =
|
auto ph_group =
|
||||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||||
Timestamp timestamp = 1000000;
|
Timestamp timestamp = 1000000;
|
||||||
@ -2528,30 +2643,71 @@ TEST(Sealed, SearchVectorArray) {
|
|||||||
LoadIndexInfo load_info;
|
LoadIndexInfo load_info;
|
||||||
load_info.field_id = array_vec.get();
|
load_info.field_id = array_vec.get();
|
||||||
load_info.field_type = DataType::VECTOR_ARRAY;
|
load_info.field_type = DataType::VECTOR_ARRAY;
|
||||||
load_info.element_type = DataType::VECTOR_FLOAT;
|
load_info.element_type = element_type;
|
||||||
load_info.index_params = GenIndexParams(emb_list_hnsw_index.get());
|
load_info.index_params = GenIndexParams(emb_list_hnsw_index.get());
|
||||||
load_info.cache_index =
|
load_info.cache_index =
|
||||||
CreateTestCacheIndex("test", std::move(emb_list_hnsw_index));
|
CreateTestCacheIndex("test", std::move(emb_list_hnsw_index));
|
||||||
load_info.index_params["metric_type"] = knowhere::metric::MAX_SIM;
|
load_info.index_params["metric_type"] = metric_type;
|
||||||
|
|
||||||
sealed_segment->DropFieldData(array_vec);
|
sealed_segment->DropFieldData(array_vec);
|
||||||
sealed_segment->LoadIndex(load_info);
|
sealed_segment->LoadIndex(load_info);
|
||||||
|
|
||||||
const char* raw_plan = R"(vector_anns: <
|
std::string raw_plan = fmt::format(R"(vector_anns: <
|
||||||
field_id: 101
|
field_id: 101
|
||||||
query_info: <
|
query_info: <
|
||||||
topk: 5
|
topk: 5
|
||||||
round_decimal: 3
|
round_decimal: 3
|
||||||
metric_type: "MAX_SIM"
|
metric_type: "{}"
|
||||||
search_params: "{\"nprobe\": 10}"
|
search_params: "{{\"nprobe\": 10}}"
|
||||||
>
|
>
|
||||||
placeholder_tag: "$0"
|
placeholder_tag: "$0"
|
||||||
>)";
|
>)",
|
||||||
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
|
metric_type);
|
||||||
|
auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
|
||||||
auto plan =
|
auto plan =
|
||||||
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
|
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
|
||||||
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
|
|
||||||
vec_num, dim, query_vec.data(), query_vec_lims);
|
// Create placeholder based on element type
|
||||||
|
milvus::proto::common::PlaceholderGroup ph_group_raw;
|
||||||
|
if (element_type == DataType::VECTOR_BINARY) {
|
||||||
|
auto byte_dim = (dim + 7) / 8;
|
||||||
|
auto total_bytes = vec_num * byte_dim;
|
||||||
|
std::vector<uint8_t> query_vec(total_bytes);
|
||||||
|
for (size_t i = 0; i < total_bytes; ++i) {
|
||||||
|
query_vec[i] = rand() % 256;
|
||||||
|
}
|
||||||
|
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListBinaryVector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
} else if (element_type == DataType::VECTOR_FLOAT16) {
|
||||||
|
std::vector<float> float_vec = generate_float_vector(vec_num, dim);
|
||||||
|
std::vector<float16> query_vec(vec_num * dim);
|
||||||
|
for (size_t i = 0; i < vec_num * dim; ++i) {
|
||||||
|
query_vec[i] = float16(float_vec[i]);
|
||||||
|
}
|
||||||
|
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloat16Vector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
} else if (element_type == DataType::VECTOR_BFLOAT16) {
|
||||||
|
std::vector<float> float_vec = generate_float_vector(vec_num, dim);
|
||||||
|
std::vector<bfloat16> query_vec(vec_num * dim);
|
||||||
|
for (size_t i = 0; i < vec_num * dim; ++i) {
|
||||||
|
query_vec[i] = bfloat16(float_vec[i]);
|
||||||
|
}
|
||||||
|
ph_group_raw =
|
||||||
|
CreatePlaceholderGroupFromBlob<EmbListBFloat16Vector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
} else if (element_type == DataType::VECTOR_INT8) {
|
||||||
|
std::vector<int8_t> query_vec(vec_num * dim);
|
||||||
|
for (size_t i = 0; i < vec_num * dim; ++i) {
|
||||||
|
query_vec[i] = static_cast<int8_t>(rand() % 256 - 128);
|
||||||
|
}
|
||||||
|
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListInt8Vector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
} else {
|
||||||
|
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
|
||||||
|
ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
|
||||||
|
vec_num, dim, query_vec.data(), query_vec_offsets);
|
||||||
|
}
|
||||||
|
|
||||||
auto ph_group =
|
auto ph_group =
|
||||||
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
|
||||||
Timestamp timestamp = 1000000;
|
Timestamp timestamp = 1000000;
|
||||||
@ -2562,3 +2718,27 @@ TEST(Sealed, SearchVectorArray) {
|
|||||||
std::cout << sr_parsed.dump(1) << std::endl;
|
std::cout << sr_parsed.dump(1) << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
VectorArrayTypes,
|
||||||
|
SealedVectorArrayTest,
|
||||||
|
::testing::Values(
|
||||||
|
std::make_tuple(DataType::VECTOR_FLOAT, "MAX_SIM", 4, "float_max_sim"),
|
||||||
|
std::make_tuple(DataType::VECTOR_FLOAT, "MAX_SIM_L2", 4, "float_l2"),
|
||||||
|
std::make_tuple(
|
||||||
|
DataType::VECTOR_FLOAT16, "MAX_SIM", 4, "float16_max_sim"),
|
||||||
|
std::make_tuple(
|
||||||
|
DataType::VECTOR_FLOAT16, "MAX_SIM_L2", 4, "float16_l2"),
|
||||||
|
std::make_tuple(
|
||||||
|
DataType::VECTOR_BFLOAT16, "MAX_SIM", 4, "bfloat16_max_sim"),
|
||||||
|
std::make_tuple(
|
||||||
|
DataType::VECTOR_BFLOAT16, "MAX_SIM_L2", 4, "bfloat16_l2"),
|
||||||
|
std::make_tuple(DataType::VECTOR_INT8, "MAX_SIM", 4, "int8_max_sim"),
|
||||||
|
std::make_tuple(DataType::VECTOR_INT8, "MAX_SIM_L2", 4, "int8_l2"),
|
||||||
|
std::make_tuple(
|
||||||
|
DataType::VECTOR_BINARY, "MAX_SIM_HAMMING", 32, "binary_hamming"),
|
||||||
|
std::make_tuple(
|
||||||
|
DataType::VECTOR_BINARY, "MAX_SIM_JACCARD", 32, "binary_jaccard")),
|
||||||
|
[](const ::testing::TestParamInfo<VectorArrayTestParam>& info) {
|
||||||
|
return std::get<3>(info.param);
|
||||||
|
});
|
||||||
|
|||||||
@ -249,14 +249,22 @@ func CreateSearchPlanArgs(schema *typeutil.SchemaHelper, exprStr string, vectorF
|
|||||||
switch elementType {
|
switch elementType {
|
||||||
case schemapb.DataType_FloatVector:
|
case schemapb.DataType_FloatVector:
|
||||||
vectorType = planpb.VectorType_EmbListFloatVector
|
vectorType = planpb.VectorType_EmbListFloatVector
|
||||||
|
case schemapb.DataType_BinaryVector:
|
||||||
|
vectorType = planpb.VectorType_EmbListBinaryVector
|
||||||
|
case schemapb.DataType_Float16Vector:
|
||||||
|
vectorType = planpb.VectorType_EmbListFloat16Vector
|
||||||
|
case schemapb.DataType_BFloat16Vector:
|
||||||
|
vectorType = planpb.VectorType_EmbListBFloat16Vector
|
||||||
|
case schemapb.DataType_Int8Vector:
|
||||||
|
vectorType = planpb.VectorType_EmbListInt8Vector
|
||||||
default:
|
default:
|
||||||
log.Error("Invalid elementType", zap.Any("elementType", elementType))
|
log.Error("Invalid elementType for ArrayOfVector", zap.Any("elementType", elementType))
|
||||||
return nil, err
|
return nil, fmt.Errorf("unsupported element type for ArrayOfVector: %v", elementType)
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
log.Error("Invalid dataType", zap.Any("dataType", dataType))
|
log.Error("Invalid dataType", zap.Any("dataType", dataType))
|
||||||
return nil, err
|
return nil, fmt.Errorf("unsupported vector data type: %v", dataType)
|
||||||
}
|
}
|
||||||
|
|
||||||
scorers, options, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues)
|
scorers, options, err := CreateSearchScorers(schema, functionScorer, exprTemplateValues)
|
||||||
|
|||||||
@ -29,7 +29,7 @@ func Test_CheckVecIndexWithDataTypeExist(t *testing.T) {
|
|||||||
want bool
|
want bool
|
||||||
}{
|
}{
|
||||||
{"HNSW", schemapb.DataType_FloatVector, true},
|
{"HNSW", schemapb.DataType_FloatVector, true},
|
||||||
{"HNSW", schemapb.DataType_BinaryVector, false},
|
{"HNSW", schemapb.DataType_BinaryVector, true},
|
||||||
{"HNSW", schemapb.DataType_Float16Vector, true},
|
{"HNSW", schemapb.DataType_Float16Vector, true},
|
||||||
|
|
||||||
{"SPARSE_WAND", schemapb.DataType_SparseFloatVector, true},
|
{"SPARSE_WAND", schemapb.DataType_SparseFloatVector, true},
|
||||||
|
|||||||
@ -555,7 +555,7 @@ func constructTestCreateIndexRequest(dbName, collectionName string, dataType sch
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
Key: common.IndexTypeKey,
|
Key: common.IndexTypeKey,
|
||||||
Value: "EMB_LIST_HNSW",
|
Value: "HNSW",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Key: "nlist",
|
Key: "nlist",
|
||||||
@ -1744,7 +1744,6 @@ func TestProxy(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||||
})
|
})
|
||||||
fmt.Println("create index for binVec field")
|
|
||||||
|
|
||||||
fieldName := ConcatStructFieldName(structField, subFieldFVec)
|
fieldName := ConcatStructFieldName(structField, subFieldFVec)
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
@ -1757,8 +1756,6 @@ func TestProxy(t *testing.T) {
|
|||||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
fmt.Println("create index for embedding list field")
|
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
t.Run("alter index for embedding list field", func(t *testing.T) {
|
t.Run("alter index for embedding list field", func(t *testing.T) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@ -1778,7 +1775,6 @@ func TestProxy(t *testing.T) {
|
|||||||
err = merr.CheckRPCCall(resp, err)
|
err = merr.CheckRPCCall(resp, err)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
fmt.Println("alter index for embedding list field")
|
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
t.Run("describe index for embedding list field", func(t *testing.T) {
|
t.Run("describe index for embedding list field", func(t *testing.T) {
|
||||||
@ -1796,7 +1792,6 @@ func TestProxy(t *testing.T) {
|
|||||||
enableMmap, _ := common.IsMmapDataEnabled(resp.IndexDescriptions[0].GetParams()...)
|
enableMmap, _ := common.IsMmapDataEnabled(resp.IndexDescriptions[0].GetParams()...)
|
||||||
assert.True(t, enableMmap, "params: %+v", resp.IndexDescriptions[0])
|
assert.True(t, enableMmap, "params: %+v", resp.IndexDescriptions[0])
|
||||||
})
|
})
|
||||||
fmt.Println("describe index for embedding list field")
|
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
t.Run("describe index with indexName for embedding list field", func(t *testing.T) {
|
t.Run("describe index with indexName for embedding list field", func(t *testing.T) {
|
||||||
@ -1812,7 +1807,6 @@ func TestProxy(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode())
|
||||||
})
|
})
|
||||||
fmt.Println("describe index with indexName for embedding list field")
|
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
t.Run("get index statistics for embedding list field", func(t *testing.T) {
|
t.Run("get index statistics for embedding list field", func(t *testing.T) {
|
||||||
|
|||||||
@ -71,8 +71,6 @@ const (
|
|||||||
RoundDecimalKey = "round_decimal"
|
RoundDecimalKey = "round_decimal"
|
||||||
OffsetKey = "offset"
|
OffsetKey = "offset"
|
||||||
LimitKey = "limit"
|
LimitKey = "limit"
|
||||||
// offsets for embedding list search
|
|
||||||
LimsKey = "lims"
|
|
||||||
// key for timestamptz translation
|
// key for timestamptz translation
|
||||||
TimezoneKey = "timezone"
|
TimezoneKey = "timezone"
|
||||||
TimefieldsKey = "time_fields"
|
TimefieldsKey = "time_fields"
|
||||||
|
|||||||
@ -1206,7 +1206,7 @@ func Test_checkEmbeddingListIndex(t *testing.T) {
|
|||||||
ExtraParams: []*commonpb.KeyValuePair{
|
ExtraParams: []*commonpb.KeyValuePair{
|
||||||
{
|
{
|
||||||
Key: common.IndexTypeKey,
|
Key: common.IndexTypeKey,
|
||||||
Value: "EMB_LIST_HNSW",
|
Value: "HNSW",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Key: common.MetricTypeKey,
|
Key: common.MetricTypeKey,
|
||||||
@ -1237,7 +1237,7 @@ func Test_checkEmbeddingListIndex(t *testing.T) {
|
|||||||
ExtraParams: []*commonpb.KeyValuePair{
|
ExtraParams: []*commonpb.KeyValuePair{
|
||||||
{
|
{
|
||||||
Key: common.IndexTypeKey,
|
Key: common.IndexTypeKey,
|
||||||
Value: "EMB_LIST_HNSW",
|
Value: "HNSW",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Key: common.MetricTypeKey,
|
Key: common.MetricTypeKey,
|
||||||
@ -1290,37 +1290,6 @@ func Test_checkEmbeddingListIndex(t *testing.T) {
|
|||||||
err := cit.parseIndexParams(context.TODO())
|
err := cit.parseIndexParams(context.TODO())
|
||||||
assert.True(t, strings.Contains(err.Error(), "float vector index does not support metric type: MAX_SIM"))
|
assert.True(t, strings.Contains(err.Error(), "float vector index does not support metric type: MAX_SIM"))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("data type wrong", func(t *testing.T) {
|
|
||||||
cit := &createIndexTask{
|
|
||||||
Condition: nil,
|
|
||||||
req: &milvuspb.CreateIndexRequest{
|
|
||||||
ExtraParams: []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: common.IndexTypeKey,
|
|
||||||
Value: "EMB_LIST_HNSW",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: common.MetricTypeKey,
|
|
||||||
Value: metric.L2,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IndexName: "",
|
|
||||||
},
|
|
||||||
fieldSchema: &schemapb.FieldSchema{
|
|
||||||
FieldID: 101,
|
|
||||||
Name: "FieldFloat",
|
|
||||||
IsPrimaryKey: false,
|
|
||||||
DataType: schemapb.DataType_FloatVector,
|
|
||||||
TypeParams: []*commonpb.KeyValuePair{
|
|
||||||
{Key: common.DimKey, Value: "128"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := cit.parseIndexParams(context.TODO())
|
|
||||||
assert.True(t, strings.Contains(err.Error(), "data type FloatVector can't build with this index EMB_LIST_HNSW"))
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_ngram_parseIndexParams(t *testing.T) {
|
func Test_ngram_parseIndexParams(t *testing.T) {
|
||||||
|
|||||||
@ -1087,12 +1087,61 @@ func (v *validateUtil) checkArrayOfVectorFieldData(field *schemapb.FieldData, fi
|
|||||||
return merr.WrapErrParameterInvalid("need float vector array", "got nil", msg)
|
return merr.WrapErrParameterInvalid("need float vector array", "got nil", msg)
|
||||||
}
|
}
|
||||||
if v.checkNAN {
|
if v.checkNAN {
|
||||||
return typeutil.VerifyFloats32(floatVector.GetData())
|
if err := typeutil.VerifyFloats32(floatVector.GetData()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case schemapb.DataType_BinaryVector:
|
||||||
|
for _, vector := range data.GetData() {
|
||||||
|
binaryVector := vector.GetBinaryVector()
|
||||||
|
if binaryVector == nil {
|
||||||
|
msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||||
|
return merr.WrapErrParameterInvalid("need binary vector array", "got nil", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case schemapb.DataType_Float16Vector:
|
||||||
|
for _, vector := range data.GetData() {
|
||||||
|
float16Vector := vector.GetFloat16Vector()
|
||||||
|
if float16Vector == nil {
|
||||||
|
msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||||
|
return merr.WrapErrParameterInvalid("need float16 vector array", "got nil", msg)
|
||||||
|
}
|
||||||
|
if v.checkNAN {
|
||||||
|
if err := typeutil.VerifyFloats16(float16Vector); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case schemapb.DataType_BFloat16Vector:
|
||||||
|
for _, vector := range data.GetData() {
|
||||||
|
bfloat16Vector := vector.GetBfloat16Vector()
|
||||||
|
if bfloat16Vector == nil {
|
||||||
|
msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||||
|
return merr.WrapErrParameterInvalid("need bfloat16 vector array", "got nil", msg)
|
||||||
|
}
|
||||||
|
if v.checkNAN {
|
||||||
|
if err := typeutil.VerifyBFloats16(bfloat16Vector); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case schemapb.DataType_Int8Vector:
|
||||||
|
for _, vector := range data.GetData() {
|
||||||
|
int8Vector := vector.GetInt8Vector()
|
||||||
|
if int8Vector == nil {
|
||||||
|
msg := fmt.Sprintf("array of vector field '%v' is illegal, array type mismatch", field.GetFieldName())
|
||||||
|
return merr.WrapErrParameterInvalid("need int8 vector array", "got nil", msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
panic("not implemented")
|
msg := fmt.Sprintf("unsupported element type for ArrayOfVector: %v", fieldSchema.GetElementType())
|
||||||
|
return merr.WrapErrParameterInvalid("supported vector type", fieldSchema.GetElementType().String(), msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -251,16 +251,18 @@ func appendValueAt(builder array.Builder, a arrow.Array, idx int, defaultValue *
|
|||||||
b.Append(true)
|
b.Append(true)
|
||||||
|
|
||||||
valuesArray := la.ListValues()
|
valuesArray := la.ListValues()
|
||||||
valueBuilder := b.ValueBuilder()
|
|
||||||
|
|
||||||
var totalSize uint64 = 0
|
var totalSize uint64 = 0
|
||||||
|
valueBuilder := b.ValueBuilder()
|
||||||
switch vb := valueBuilder.(type) {
|
switch vb := valueBuilder.(type) {
|
||||||
case *array.Float32Builder:
|
case *array.FixedSizeBinaryBuilder:
|
||||||
if floatArray, ok := valuesArray.(*array.Float32); ok {
|
fixedArray, ok := valuesArray.(*array.FixedSizeBinary)
|
||||||
for i := start; i < end; i++ {
|
if !ok {
|
||||||
vb.Append(floatArray.Value(int(i)))
|
return 0, fmt.Errorf("invalid value type %T, expect %T", valuesArray.DataType(), vb.Type())
|
||||||
totalSize += 4
|
|
||||||
}
|
}
|
||||||
|
for i := start; i < end; i++ {
|
||||||
|
val := fixedArray.Value(int(i))
|
||||||
|
vb.Append(val)
|
||||||
|
totalSize += uint64(len(val))
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("unsupported value builder type in ListBuilder: %T", valueBuilder)
|
return 0, fmt.Errorf("unsupported value builder type in ListBuilder: %T", valueBuilder)
|
||||||
|
|||||||
@ -376,7 +376,12 @@ func NewFieldData(dataType schemapb.DataType, fieldSchema *schemapb.FieldSchema,
|
|||||||
}
|
}
|
||||||
return data, nil
|
return data, nil
|
||||||
case schemapb.DataType_ArrayOfVector:
|
case schemapb.DataType_ArrayOfVector:
|
||||||
|
dim, err := GetDimFromParams(typeParams)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
data := &VectorArrayFieldData{
|
data := &VectorArrayFieldData{
|
||||||
|
Dim: int64(dim),
|
||||||
Data: make([]*schemapb.VectorField, 0, cap),
|
Data: make([]*schemapb.VectorField, 0, cap),
|
||||||
ElementType: fieldSchema.GetElementType(),
|
ElementType: fieldSchema.GetElementType(),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -628,46 +628,17 @@ func readVectorArrayFromListArray(r *PayloadReader) ([]*schemapb.VectorField, er
|
|||||||
return nil, fmt.Errorf("expected ListArray, got %T", chunk)
|
return nil, fmt.Errorf("expected ListArray, got %T", chunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
valuesArray := listArray.ListValues()
|
|
||||||
switch elementType {
|
|
||||||
case schemapb.DataType_FloatVector:
|
|
||||||
floatArray, ok := valuesArray.(*array.Float32)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("expected Float32 array for FloatVector, got %T", valuesArray)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process each row which contains multiple vectors
|
|
||||||
for i := 0; i < listArray.Len(); i++ {
|
for i := 0; i < listArray.Len(); i++ {
|
||||||
if listArray.IsNull(i) {
|
value, ok := deserializeArrayOfVector(listArray, i, elementType, dim, true)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to deserialize VectorArray at row %d", len(result))
|
||||||
|
}
|
||||||
|
vectorField, _ := value.(*schemapb.VectorField)
|
||||||
|
if vectorField == nil {
|
||||||
return nil, fmt.Errorf("null value in VectorArray")
|
return nil, fmt.Errorf("null value in VectorArray")
|
||||||
}
|
}
|
||||||
|
|
||||||
start, end := listArray.ValueOffsets(i)
|
|
||||||
vectorData := make([]float32, end-start)
|
|
||||||
copy(vectorData, floatArray.Float32Values()[start:end])
|
|
||||||
|
|
||||||
vectorField := &schemapb.VectorField{
|
|
||||||
Dim: dim,
|
|
||||||
Data: &schemapb.VectorField_FloatVector{
|
|
||||||
FloatVector: &schemapb.FloatArray{
|
|
||||||
Data: vectorData,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
result = append(result, vectorField)
|
result = append(result, vectorField)
|
||||||
}
|
}
|
||||||
|
|
||||||
case schemapb.DataType_BinaryVector:
|
|
||||||
return nil, fmt.Errorf("BinaryVector in VectorArray not implemented yet")
|
|
||||||
case schemapb.DataType_Float16Vector:
|
|
||||||
return nil, fmt.Errorf("Float16Vector in VectorArray not implemented yet")
|
|
||||||
case schemapb.DataType_BFloat16Vector:
|
|
||||||
return nil, fmt.Errorf("BFloat16Vector in VectorArray not implemented yet")
|
|
||||||
case schemapb.DataType_Int8Vector:
|
|
||||||
return nil, fmt.Errorf("Int8Vector in VectorArray not implemented yet")
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
|
|||||||
@ -1719,6 +1719,7 @@ func TestPayload_ReaderAndWriter(t *testing.T) {
|
|||||||
|
|
||||||
// Create VectorArrayFieldData with 3 rows
|
// Create VectorArrayFieldData with 3 rows
|
||||||
vectorArrayData := &VectorArrayFieldData{
|
vectorArrayData := &VectorArrayFieldData{
|
||||||
|
Dim: int64(dim),
|
||||||
Data: []*schemapb.VectorField{
|
Data: []*schemapb.VectorField{
|
||||||
{
|
{
|
||||||
Dim: int64(dim),
|
Dim: int64(dim),
|
||||||
|
|||||||
@ -18,6 +18,7 @@ package storage
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"sync"
|
"sync"
|
||||||
@ -114,7 +115,10 @@ func NewPayloadWriter(colType schemapb.DataType, options ...PayloadWriterOptions
|
|||||||
if w.elementType == nil {
|
if w.elementType == nil {
|
||||||
return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires elementType, use WithElementType option")
|
return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires elementType, use WithElementType option")
|
||||||
}
|
}
|
||||||
elemType, err := VectorArrayToArrowType(*w.elementType)
|
if w.dim == nil {
|
||||||
|
return nil, merr.WrapErrParameterInvalidMsg("ArrayOfVector requires dim to be specified")
|
||||||
|
}
|
||||||
|
elemType, err := VectorArrayToArrowType(*w.elementType, *w.dim.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -962,13 +966,13 @@ func (w *NativePayloadWriter) AddVectorArrayFieldDataToPayload(data *VectorArray
|
|||||||
case schemapb.DataType_FloatVector:
|
case schemapb.DataType_FloatVector:
|
||||||
return w.addFloatVectorArrayToPayload(builder, data)
|
return w.addFloatVectorArrayToPayload(builder, data)
|
||||||
case schemapb.DataType_BinaryVector:
|
case schemapb.DataType_BinaryVector:
|
||||||
return merr.WrapErrParameterInvalidMsg("BinaryVector in VectorArray not implemented yet")
|
return w.addBinaryVectorArrayToPayload(builder, data)
|
||||||
case schemapb.DataType_Float16Vector:
|
case schemapb.DataType_Float16Vector:
|
||||||
return merr.WrapErrParameterInvalidMsg("Float16Vector in VectorArray not implemented yet")
|
return w.addFloat16VectorArrayToPayload(builder, data)
|
||||||
case schemapb.DataType_BFloat16Vector:
|
case schemapb.DataType_BFloat16Vector:
|
||||||
return merr.WrapErrParameterInvalidMsg("BFloat16Vector in VectorArray not implemented yet")
|
return w.addBFloat16VectorArrayToPayload(builder, data)
|
||||||
case schemapb.DataType_Int8Vector:
|
case schemapb.DataType_Int8Vector:
|
||||||
return merr.WrapErrParameterInvalidMsg("Int8Vector in VectorArray not implemented yet")
|
return w.addInt8VectorArrayToPayload(builder, data)
|
||||||
default:
|
default:
|
||||||
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", data.ElementType.String()))
|
return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", data.ElementType.String()))
|
||||||
}
|
}
|
||||||
@ -976,21 +980,159 @@ func (w *NativePayloadWriter) AddVectorArrayFieldDataToPayload(data *VectorArray
|
|||||||
|
|
||||||
// addFloatVectorArrayToPayload handles FloatVector elements in VectorArray
|
// addFloatVectorArrayToPayload handles FloatVector elements in VectorArray
|
||||||
func (w *NativePayloadWriter) addFloatVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
func (w *NativePayloadWriter) addFloatVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
||||||
valueBuilder := builder.ValueBuilder().(*array.Float32Builder)
|
if data.Dim <= 0 {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
||||||
|
|
||||||
|
// Each element in data.Data represents one row of VectorArray
|
||||||
for _, vectorField := range data.Data {
|
for _, vectorField := range data.Data {
|
||||||
if vectorField.GetFloatVector() == nil {
|
if vectorField.GetFloatVector() == nil {
|
||||||
return merr.WrapErrParameterInvalidMsg("expected FloatVector but got different type")
|
return merr.WrapErrParameterInvalidMsg("expected FloatVector but got different type")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start a new list for this row
|
||||||
builder.Append(true)
|
builder.Append(true)
|
||||||
|
|
||||||
floatData := vectorField.GetFloatVector().GetData()
|
floatData := vectorField.GetFloatVector().GetData()
|
||||||
if len(floatData) == 0 {
|
|
||||||
return merr.WrapErrParameterInvalidMsg("empty vector data not allowed")
|
numVectors := len(floatData) / int(data.Dim)
|
||||||
|
for i := 0; i < numVectors; i++ {
|
||||||
|
start := i * int(data.Dim)
|
||||||
|
end := start + int(data.Dim)
|
||||||
|
vectorSlice := floatData[start:end]
|
||||||
|
|
||||||
|
bytes := make([]byte, data.Dim*4)
|
||||||
|
for j, f := range vectorSlice {
|
||||||
|
binary.LittleEndian.PutUint32(bytes[j*4:], math.Float32bits(f))
|
||||||
}
|
}
|
||||||
|
|
||||||
valueBuilder.AppendValues(floatData, nil)
|
valueBuilder.Append(bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addBinaryVectorArrayToPayload handles BinaryVector elements in VectorArray
|
||||||
|
func (w *NativePayloadWriter) addBinaryVectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
||||||
|
if data.Dim <= 0 {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
||||||
|
|
||||||
|
// Each element in data.Data represents one row of VectorArray
|
||||||
|
for _, vectorField := range data.Data {
|
||||||
|
if vectorField.GetBinaryVector() == nil {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("expected BinaryVector but got different type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a new list for this row
|
||||||
|
builder.Append(true)
|
||||||
|
|
||||||
|
binaryData := vectorField.GetBinaryVector()
|
||||||
|
byteWidth := (data.Dim + 7) / 8
|
||||||
|
numVectors := len(binaryData) / int(byteWidth)
|
||||||
|
|
||||||
|
for i := 0; i < numVectors; i++ {
|
||||||
|
start := i * int(byteWidth)
|
||||||
|
end := start + int(byteWidth)
|
||||||
|
valueBuilder.Append(binaryData[start:end])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addFloat16VectorArrayToPayload handles Float16Vector elements in VectorArray
|
||||||
|
func (w *NativePayloadWriter) addFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
||||||
|
if data.Dim <= 0 {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
||||||
|
|
||||||
|
// Each element in data.Data represents one row of VectorArray
|
||||||
|
for _, vectorField := range data.Data {
|
||||||
|
if vectorField.GetFloat16Vector() == nil {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("expected Float16Vector but got different type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a new list for this row
|
||||||
|
builder.Append(true)
|
||||||
|
|
||||||
|
float16Data := vectorField.GetFloat16Vector()
|
||||||
|
byteWidth := data.Dim * 2
|
||||||
|
numVectors := len(float16Data) / int(byteWidth)
|
||||||
|
|
||||||
|
for i := 0; i < numVectors; i++ {
|
||||||
|
start := i * int(byteWidth)
|
||||||
|
end := start + int(byteWidth)
|
||||||
|
valueBuilder.Append(float16Data[start:end])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addBFloat16VectorArrayToPayload handles BFloat16Vector elements in VectorArray
|
||||||
|
func (w *NativePayloadWriter) addBFloat16VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
||||||
|
if data.Dim <= 0 {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
||||||
|
|
||||||
|
// Each element in data.Data represents one row of VectorArray
|
||||||
|
for _, vectorField := range data.Data {
|
||||||
|
if vectorField.GetBfloat16Vector() == nil {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("expected BFloat16Vector but got different type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a new list for this row
|
||||||
|
builder.Append(true)
|
||||||
|
|
||||||
|
bfloat16Data := vectorField.GetBfloat16Vector()
|
||||||
|
byteWidth := data.Dim * 2
|
||||||
|
numVectors := len(bfloat16Data) / int(byteWidth)
|
||||||
|
|
||||||
|
for i := 0; i < numVectors; i++ {
|
||||||
|
start := i * int(byteWidth)
|
||||||
|
end := start + int(byteWidth)
|
||||||
|
valueBuilder.Append(bfloat16Data[start:end])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addInt8VectorArrayToPayload handles Int8Vector elements in VectorArray
|
||||||
|
func (w *NativePayloadWriter) addInt8VectorArrayToPayload(builder *array.ListBuilder, data *VectorArrayFieldData) error {
|
||||||
|
if data.Dim <= 0 {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("vector dimension must be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
||||||
|
|
||||||
|
// Each element in data.Data represents one row of VectorArray
|
||||||
|
for _, vectorField := range data.Data {
|
||||||
|
if vectorField.GetInt8Vector() == nil {
|
||||||
|
return merr.WrapErrParameterInvalidMsg("expected Int8Vector but got different type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a new list for this row
|
||||||
|
builder.Append(true)
|
||||||
|
|
||||||
|
int8Data := vectorField.GetInt8Vector()
|
||||||
|
numVectors := len(int8Data) / int(data.Dim)
|
||||||
|
|
||||||
|
for i := 0; i < numVectors; i++ {
|
||||||
|
start := i * int(data.Dim)
|
||||||
|
end := start + int(data.Dim)
|
||||||
|
valueBuilder.Append(int8Data[start:end])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -309,6 +309,7 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) {
|
|||||||
vectorArrayData := &VectorArrayFieldData{
|
vectorArrayData := &VectorArrayFieldData{
|
||||||
Data: make([]*schemapb.VectorField, numRows),
|
Data: make([]*schemapb.VectorField, numRows),
|
||||||
ElementType: schemapb.DataType_FloatVector,
|
ElementType: schemapb.DataType_FloatVector,
|
||||||
|
Dim: int64(dim),
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numRows; i++ {
|
for i := 0; i < numRows; i++ {
|
||||||
@ -408,6 +409,7 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) {
|
|||||||
batchData := &VectorArrayFieldData{
|
batchData := &VectorArrayFieldData{
|
||||||
Data: make([]*schemapb.VectorField, batchSize),
|
Data: make([]*schemapb.VectorField, batchSize),
|
||||||
ElementType: schemapb.DataType_FloatVector,
|
ElementType: schemapb.DataType_FloatVector,
|
||||||
|
Dim: int64(dim),
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < batchSize; i++ {
|
for i := 0; i < batchSize; i++ {
|
||||||
@ -454,6 +456,7 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) {
|
|||||||
vectorArrayData := &VectorArrayFieldData{
|
vectorArrayData := &VectorArrayFieldData{
|
||||||
Data: make([]*schemapb.VectorField, numRows),
|
Data: make([]*schemapb.VectorField, numRows),
|
||||||
ElementType: schemapb.DataType_FloatVector,
|
ElementType: schemapb.DataType_FloatVector,
|
||||||
|
Dim: int64(dim),
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numRows; i++ {
|
for i := 0; i < numRows; i++ {
|
||||||
@ -485,6 +488,159 @@ func TestPayloadWriter_ArrayOfVector(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, numRows, length)
|
require.Equal(t, numRows, length)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Test ArrayOfFloat16Vector - Basic", func(t *testing.T) {
|
||||||
|
dim := 64
|
||||||
|
numRows := 50
|
||||||
|
vectorsPerRow := 4
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
vectorArrayData := &VectorArrayFieldData{
|
||||||
|
Data: make([]*schemapb.VectorField, numRows),
|
||||||
|
ElementType: schemapb.DataType_Float16Vector,
|
||||||
|
Dim: int64(dim),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numRows; i++ {
|
||||||
|
// Float16 vectors are stored as bytes (2 bytes per element)
|
||||||
|
byteData := make([]byte, vectorsPerRow*dim*2)
|
||||||
|
for j := 0; j < len(byteData); j++ {
|
||||||
|
byteData[j] = byte((i*1000 + j) % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
vectorArrayData.Data[i] = &schemapb.VectorField{
|
||||||
|
Dim: int64(dim),
|
||||||
|
Data: &schemapb.VectorField_Float16Vector{
|
||||||
|
Float16Vector: byteData,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := NewPayloadWriter(
|
||||||
|
schemapb.DataType_ArrayOfVector,
|
||||||
|
WithDim(dim),
|
||||||
|
WithElementType(schemapb.DataType_Float16Vector),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, w)
|
||||||
|
|
||||||
|
err = w.AddVectorArrayFieldDataToPayload(vectorArrayData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = w.FinishPayloadWriter()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
buffer, err := w.GetPayloadBufferFromWriter()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, buffer)
|
||||||
|
|
||||||
|
length, err := w.GetPayloadLengthFromWriter()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, numRows, length)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test ArrayOfBinaryVector - Basic", func(t *testing.T) {
|
||||||
|
dim := 128 // Must be multiple of 8
|
||||||
|
numRows := 50
|
||||||
|
vectorsPerRow := 3
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
vectorArrayData := &VectorArrayFieldData{
|
||||||
|
Data: make([]*schemapb.VectorField, numRows),
|
||||||
|
ElementType: schemapb.DataType_BinaryVector,
|
||||||
|
Dim: int64(dim),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numRows; i++ {
|
||||||
|
// Binary vectors use 1 bit per dimension, so dim/8 bytes per vector
|
||||||
|
byteData := make([]byte, vectorsPerRow*dim/8)
|
||||||
|
for j := 0; j < len(byteData); j++ {
|
||||||
|
byteData[j] = byte((i + j) % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
vectorArrayData.Data[i] = &schemapb.VectorField{
|
||||||
|
Dim: int64(dim),
|
||||||
|
Data: &schemapb.VectorField_BinaryVector{
|
||||||
|
BinaryVector: byteData,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := NewPayloadWriter(
|
||||||
|
schemapb.DataType_ArrayOfVector,
|
||||||
|
WithDim(dim),
|
||||||
|
WithElementType(schemapb.DataType_BinaryVector),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, w)
|
||||||
|
|
||||||
|
err = w.AddVectorArrayFieldDataToPayload(vectorArrayData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = w.FinishPayloadWriter()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
buffer, err := w.GetPayloadBufferFromWriter()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, buffer)
|
||||||
|
|
||||||
|
length, err := w.GetPayloadLengthFromWriter()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, numRows, length)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test ArrayOfBFloat16Vector - Basic", func(t *testing.T) {
|
||||||
|
dim := 64
|
||||||
|
numRows := 50
|
||||||
|
vectorsPerRow := 4
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
vectorArrayData := &VectorArrayFieldData{
|
||||||
|
Data: make([]*schemapb.VectorField, numRows),
|
||||||
|
ElementType: schemapb.DataType_BFloat16Vector,
|
||||||
|
Dim: int64(dim),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < numRows; i++ {
|
||||||
|
// BFloat16 vectors are stored as bytes (2 bytes per element)
|
||||||
|
byteData := make([]byte, vectorsPerRow*dim*2)
|
||||||
|
for j := 0; j < len(byteData); j++ {
|
||||||
|
byteData[j] = byte((i*100 + j) % 256)
|
||||||
|
}
|
||||||
|
|
||||||
|
vectorArrayData.Data[i] = &schemapb.VectorField{
|
||||||
|
Dim: int64(dim),
|
||||||
|
Data: &schemapb.VectorField_Bfloat16Vector{
|
||||||
|
Bfloat16Vector: byteData,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w, err := NewPayloadWriter(
|
||||||
|
schemapb.DataType_ArrayOfVector,
|
||||||
|
WithDim(dim),
|
||||||
|
WithElementType(schemapb.DataType_BFloat16Vector),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, w)
|
||||||
|
|
||||||
|
err = w.AddVectorArrayFieldDataToPayload(vectorArrayData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = w.FinishPayloadWriter()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify results
|
||||||
|
buffer, err := w.GetPayloadBufferFromWriter()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, buffer)
|
||||||
|
|
||||||
|
length, err := w.GetPayloadLengthFromWriter()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, numRows, length)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParquetEncoding(t *testing.T) {
|
func TestParquetEncoding(t *testing.T) {
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
package storage
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
@ -448,8 +449,8 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
|
|||||||
|
|
||||||
// ArrayOfVector now implements the standard interface with elementType parameter
|
// ArrayOfVector now implements the standard interface with elementType parameter
|
||||||
m[schemapb.DataType_ArrayOfVector] = serdeEntry{
|
m[schemapb.DataType_ArrayOfVector] = serdeEntry{
|
||||||
arrowType: func(_ int, elementType schemapb.DataType) arrow.DataType {
|
arrowType: func(dim int, elementType schemapb.DataType) arrow.DataType {
|
||||||
return getArrayOfVectorArrowType(elementType)
|
return getArrayOfVectorArrowType(elementType, dim)
|
||||||
},
|
},
|
||||||
deserialize: func(a arrow.Array, i int, elementType schemapb.DataType, dim int, shouldCopy bool) (any, bool) {
|
deserialize: func(a arrow.Array, i int, elementType schemapb.DataType, dim int, shouldCopy bool) (any, bool) {
|
||||||
return deserializeArrayOfVector(a, i, elementType, int64(dim), shouldCopy)
|
return deserializeArrayOfVector(a, i, elementType, int64(dim), shouldCopy)
|
||||||
@ -471,24 +472,64 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
builder.Append(true)
|
builder.Append(true)
|
||||||
|
valueBuilder := builder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
||||||
|
dim := vf.GetDim()
|
||||||
|
|
||||||
|
appendVectorChunks := func(data []byte, bytesPerVector int) bool {
|
||||||
|
numVectors := len(data) / bytesPerVector
|
||||||
|
for i := 0; i < numVectors; i++ {
|
||||||
|
start := i * bytesPerVector
|
||||||
|
end := start + bytesPerVector
|
||||||
|
valueBuilder.Append(data[start:end])
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
switch elementType {
|
switch elementType {
|
||||||
case schemapb.DataType_FloatVector:
|
case schemapb.DataType_FloatVector:
|
||||||
if vf.GetFloatVector() == nil {
|
if vf.GetFloatVector() == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
valueBuilder := builder.ValueBuilder().(*array.Float32Builder)
|
floatData := vf.GetFloatVector().GetData()
|
||||||
valueBuilder.AppendValues(vf.GetFloatVector().GetData(), nil)
|
numVectors := len(floatData) / int(dim)
|
||||||
|
// Convert float data to binary
|
||||||
|
for i := 0; i < numVectors; i++ {
|
||||||
|
start := i * int(dim)
|
||||||
|
end := start + int(dim)
|
||||||
|
vectorSlice := floatData[start:end]
|
||||||
|
|
||||||
|
bytes := make([]byte, dim*4)
|
||||||
|
for j, f := range vectorSlice {
|
||||||
|
binary.LittleEndian.PutUint32(bytes[j*4:], math.Float32bits(f))
|
||||||
|
}
|
||||||
|
valueBuilder.Append(bytes)
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
|
|
||||||
case schemapb.DataType_BinaryVector:
|
case schemapb.DataType_BinaryVector:
|
||||||
panic("BinaryVector in VectorArray not implemented yet")
|
if vf.GetBinaryVector() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return appendVectorChunks(vf.GetBinaryVector(), int((dim+7)/8))
|
||||||
|
|
||||||
case schemapb.DataType_Float16Vector:
|
case schemapb.DataType_Float16Vector:
|
||||||
panic("Float16Vector in VectorArray not implemented yet")
|
if vf.GetFloat16Vector() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return appendVectorChunks(vf.GetFloat16Vector(), int(dim)*2)
|
||||||
|
|
||||||
case schemapb.DataType_BFloat16Vector:
|
case schemapb.DataType_BFloat16Vector:
|
||||||
panic("BFloat16Vector in VectorArray not implemented yet")
|
if vf.GetBfloat16Vector() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return appendVectorChunks(vf.GetBfloat16Vector(), int(dim)*2)
|
||||||
|
|
||||||
case schemapb.DataType_Int8Vector:
|
case schemapb.DataType_Int8Vector:
|
||||||
panic("Int8Vector in VectorArray not implemented yet")
|
if vf.GetInt8Vector() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return appendVectorChunks(vf.GetInt8Vector(), int(dim))
|
||||||
|
|
||||||
case schemapb.DataType_SparseFloatVector:
|
case schemapb.DataType_SparseFloatVector:
|
||||||
panic("SparseFloatVector in VectorArray not implemented yet")
|
panic("SparseFloatVector in VectorArray not implemented yet")
|
||||||
default:
|
default:
|
||||||
@ -806,18 +847,18 @@ func (sfw *singleFieldRecordWriter) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getArrayOfVectorArrowType returns the appropriate Arrow type for ArrayOfVector based on element type
|
// getArrayOfVectorArrowType returns the appropriate Arrow type for ArrayOfVector based on element type
|
||||||
func getArrayOfVectorArrowType(elementType schemapb.DataType) arrow.DataType {
|
func getArrayOfVectorArrowType(elementType schemapb.DataType, dim int) arrow.DataType {
|
||||||
switch elementType {
|
switch elementType {
|
||||||
case schemapb.DataType_FloatVector:
|
case schemapb.DataType_FloatVector:
|
||||||
return arrow.ListOf(arrow.PrimitiveTypes.Float32)
|
return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim * 4})
|
||||||
case schemapb.DataType_BinaryVector:
|
case schemapb.DataType_BinaryVector:
|
||||||
return arrow.ListOf(arrow.PrimitiveTypes.Uint8)
|
return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: (dim + 7) / 8})
|
||||||
case schemapb.DataType_Float16Vector:
|
case schemapb.DataType_Float16Vector:
|
||||||
return arrow.ListOf(arrow.PrimitiveTypes.Uint8)
|
return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim * 2})
|
||||||
case schemapb.DataType_BFloat16Vector:
|
case schemapb.DataType_BFloat16Vector:
|
||||||
return arrow.ListOf(arrow.PrimitiveTypes.Uint8)
|
return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim * 2})
|
||||||
case schemapb.DataType_Int8Vector:
|
case schemapb.DataType_Int8Vector:
|
||||||
return arrow.ListOf(arrow.PrimitiveTypes.Int8)
|
return arrow.ListOf(&arrow.FixedSizeBinaryType{ByteWidth: dim})
|
||||||
case schemapb.DataType_SparseFloatVector:
|
case schemapb.DataType_SparseFloatVector:
|
||||||
return arrow.ListOf(arrow.BinaryTypes.Binary)
|
return arrow.ListOf(arrow.BinaryTypes.Binary)
|
||||||
default:
|
default:
|
||||||
@ -842,54 +883,78 @@ func deserializeArrayOfVector(a arrow.Array, i int, elementType schemapb.DataTyp
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate dimension for vector types that have fixed dimensions
|
valuesArray := arr.ListValues()
|
||||||
if dim > 0 && totalElements%dim != 0 {
|
binaryArray, ok := valuesArray.(*array.FixedSizeBinary)
|
||||||
// Dimension mismatch - data corruption or schema inconsistency
|
if !ok {
|
||||||
return nil, false
|
// empty array
|
||||||
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
valuesArray := arr.ListValues()
|
numVectors := int(totalElements)
|
||||||
|
|
||||||
|
// Helper function to extract byte vectors from FixedSizeBinary array
|
||||||
|
extractByteVectors := func(bytesPerVector int64) []byte {
|
||||||
|
totalBytes := numVectors * int(bytesPerVector)
|
||||||
|
data := make([]byte, totalBytes)
|
||||||
|
for j := 0; j < numVectors; j++ {
|
||||||
|
vectorIndex := int(start) + j
|
||||||
|
vectorData := binaryArray.Value(vectorIndex)
|
||||||
|
copy(data[j*int(bytesPerVector):], vectorData)
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
switch elementType {
|
switch elementType {
|
||||||
case schemapb.DataType_FloatVector:
|
case schemapb.DataType_FloatVector:
|
||||||
floatArray, ok := valuesArray.(*array.Float32)
|
totalFloats := numVectors * int(dim)
|
||||||
if !ok {
|
floatData := make([]float32, totalFloats)
|
||||||
return nil, false
|
for j := 0; j < numVectors; j++ {
|
||||||
|
vectorIndex := int(start) + j
|
||||||
|
binaryData := binaryArray.Value(vectorIndex)
|
||||||
|
vectorFloats := arrow.Float32Traits.CastFromBytes(binaryData)
|
||||||
|
copy(floatData[j*int(dim):], vectorFloats)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle data copying based on shouldCopy parameter
|
return &schemapb.VectorField{
|
||||||
var floatData []float32
|
|
||||||
if shouldCopy {
|
|
||||||
// Explicitly requested copy
|
|
||||||
floatData = make([]float32, totalElements)
|
|
||||||
for j := start; j < end; j++ {
|
|
||||||
floatData[j-start] = floatArray.Value(int(j))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Try to avoid copying - use a slice of the underlying data
|
|
||||||
// This creates a slice that references the same underlying array
|
|
||||||
allData := floatArray.Float32Values()
|
|
||||||
floatData = allData[start:end]
|
|
||||||
}
|
|
||||||
|
|
||||||
vectorField := &schemapb.VectorField{
|
|
||||||
Dim: dim,
|
Dim: dim,
|
||||||
Data: &schemapb.VectorField_FloatVector{
|
Data: &schemapb.VectorField_FloatVector{
|
||||||
FloatVector: &schemapb.FloatArray{
|
FloatVector: &schemapb.FloatArray{
|
||||||
Data: floatData,
|
Data: floatData,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}, true
|
||||||
return vectorField, true
|
|
||||||
|
|
||||||
case schemapb.DataType_BinaryVector:
|
case schemapb.DataType_BinaryVector:
|
||||||
panic("BinaryVector in VectorArray deserialization not implemented yet")
|
return &schemapb.VectorField{
|
||||||
|
Dim: dim,
|
||||||
|
Data: &schemapb.VectorField_BinaryVector{
|
||||||
|
BinaryVector: extractByteVectors((dim + 7) / 8),
|
||||||
|
},
|
||||||
|
}, true
|
||||||
|
|
||||||
case schemapb.DataType_Float16Vector:
|
case schemapb.DataType_Float16Vector:
|
||||||
panic("Float16Vector in VectorArray deserialization not implemented yet")
|
return &schemapb.VectorField{
|
||||||
|
Dim: dim,
|
||||||
|
Data: &schemapb.VectorField_Float16Vector{
|
||||||
|
Float16Vector: extractByteVectors(dim * 2),
|
||||||
|
},
|
||||||
|
}, true
|
||||||
|
|
||||||
case schemapb.DataType_BFloat16Vector:
|
case schemapb.DataType_BFloat16Vector:
|
||||||
panic("BFloat16Vector in VectorArray deserialization not implemented yet")
|
return &schemapb.VectorField{
|
||||||
|
Dim: dim,
|
||||||
|
Data: &schemapb.VectorField_Bfloat16Vector{
|
||||||
|
Bfloat16Vector: extractByteVectors(dim * 2),
|
||||||
|
},
|
||||||
|
}, true
|
||||||
|
|
||||||
case schemapb.DataType_Int8Vector:
|
case schemapb.DataType_Int8Vector:
|
||||||
panic("Int8Vector in VectorArray deserialization not implemented yet")
|
return &schemapb.VectorField{
|
||||||
|
Dim: dim,
|
||||||
|
Data: &schemapb.VectorField_Int8Vector{
|
||||||
|
Int8Vector: extractByteVectors(dim),
|
||||||
|
},
|
||||||
|
}, true
|
||||||
case schemapb.DataType_SparseFloatVector:
|
case schemapb.DataType_SparseFloatVector:
|
||||||
panic("SparseFloatVector in VectorArray deserialization not implemented yet")
|
panic("SparseFloatVector in VectorArray deserialization not implemented yet")
|
||||||
default:
|
default:
|
||||||
|
|||||||
@ -268,41 +268,48 @@ func TestCalculateArraySize(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestArrayOfVectorArrowType(t *testing.T) {
|
func TestArrayOfVectorArrowType(t *testing.T) {
|
||||||
|
dim := 128 // Test dimension
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
elementType schemapb.DataType
|
elementType schemapb.DataType
|
||||||
|
dim int
|
||||||
expectedChild arrow.DataType
|
expectedChild arrow.DataType
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "FloatVector",
|
name: "FloatVector",
|
||||||
elementType: schemapb.DataType_FloatVector,
|
elementType: schemapb.DataType_FloatVector,
|
||||||
expectedChild: arrow.PrimitiveTypes.Float32,
|
dim: dim,
|
||||||
|
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim * 4},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "BinaryVector",
|
name: "BinaryVector",
|
||||||
elementType: schemapb.DataType_BinaryVector,
|
elementType: schemapb.DataType_BinaryVector,
|
||||||
expectedChild: arrow.PrimitiveTypes.Uint8,
|
dim: dim,
|
||||||
|
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: (dim + 7) / 8},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Float16Vector",
|
name: "Float16Vector",
|
||||||
elementType: schemapb.DataType_Float16Vector,
|
elementType: schemapb.DataType_Float16Vector,
|
||||||
expectedChild: arrow.PrimitiveTypes.Uint8,
|
dim: dim,
|
||||||
|
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "BFloat16Vector",
|
name: "BFloat16Vector",
|
||||||
elementType: schemapb.DataType_BFloat16Vector,
|
elementType: schemapb.DataType_BFloat16Vector,
|
||||||
expectedChild: arrow.PrimitiveTypes.Uint8,
|
dim: dim,
|
||||||
|
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim * 2},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Int8Vector",
|
name: "Int8Vector",
|
||||||
elementType: schemapb.DataType_Int8Vector,
|
elementType: schemapb.DataType_Int8Vector,
|
||||||
expectedChild: arrow.PrimitiveTypes.Int8,
|
dim: dim,
|
||||||
|
expectedChild: &arrow.FixedSizeBinaryType{ByteWidth: dim},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
arrowType := getArrayOfVectorArrowType(tt.elementType)
|
arrowType := getArrayOfVectorArrowType(tt.elementType, tt.dim)
|
||||||
assert.NotNil(t, arrowType)
|
assert.NotNil(t, arrowType)
|
||||||
|
|
||||||
listType, ok := arrowType.(*arrow.ListType)
|
listType, ok := arrowType.(*arrow.ListType)
|
||||||
|
|||||||
@ -1665,20 +1665,21 @@ func SortFieldBinlogs(fieldBinlogs map[int64]*datapb.FieldBinlog) []*datapb.Fiel
|
|||||||
}
|
}
|
||||||
|
|
||||||
// VectorArrayToArrowType converts VectorArray element type to the corresponding Arrow type
|
// VectorArrayToArrowType converts VectorArray element type to the corresponding Arrow type
|
||||||
// Note: This returns the element type (e.g., float32), not a list type
|
// Note: This returns the element type (e.g., FixedSizeBinary), not a list type
|
||||||
// The caller is responsible for wrapping it in a list if needed
|
// The caller is responsible for wrapping it in a list if needed
|
||||||
func VectorArrayToArrowType(elementType schemapb.DataType) (arrow.DataType, error) {
|
func VectorArrayToArrowType(elementType schemapb.DataType, dim int) (arrow.DataType, error) {
|
||||||
switch elementType {
|
switch elementType {
|
||||||
case schemapb.DataType_FloatVector:
|
case schemapb.DataType_FloatVector:
|
||||||
return arrow.PrimitiveTypes.Float32, nil
|
// Each vector is stored as a fixed-size binary chunk
|
||||||
|
return &arrow.FixedSizeBinaryType{ByteWidth: dim * 4}, nil
|
||||||
case schemapb.DataType_BinaryVector:
|
case schemapb.DataType_BinaryVector:
|
||||||
return nil, merr.WrapErrParameterInvalidMsg("BinaryVector in VectorArray not implemented yet")
|
return &arrow.FixedSizeBinaryType{ByteWidth: (dim + 7) / 8}, nil
|
||||||
case schemapb.DataType_Float16Vector:
|
case schemapb.DataType_Float16Vector:
|
||||||
return nil, merr.WrapErrParameterInvalidMsg("Float16Vector in VectorArray not implemented yet")
|
return &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, nil
|
||||||
case schemapb.DataType_BFloat16Vector:
|
case schemapb.DataType_BFloat16Vector:
|
||||||
return nil, merr.WrapErrParameterInvalidMsg("BFloat16Vector in VectorArray not implemented yet")
|
return &arrow.FixedSizeBinaryType{ByteWidth: dim * 2}, nil
|
||||||
case schemapb.DataType_Int8Vector:
|
case schemapb.DataType_Int8Vector:
|
||||||
return nil, merr.WrapErrParameterInvalidMsg("Int8Vector in VectorArray not implemented yet")
|
return &arrow.FixedSizeBinaryType{ByteWidth: dim}, nil
|
||||||
default:
|
default:
|
||||||
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", elementType.String()))
|
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("unsupported element type in VectorArray: %s", elementType.String()))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1845,34 +1845,49 @@ func ReadVectorArrayData(pcr *FieldReader, count int64) (any, error) {
|
|||||||
if chunk.NullN() > 0 {
|
if chunk.NullN() > 0 {
|
||||||
return nil, WrapNullRowErr(pcr.field)
|
return nil, WrapNullRowErr(pcr.field)
|
||||||
}
|
}
|
||||||
|
// ArrayOfVector is stored as list of fixed size binary
|
||||||
listReader, ok := chunk.(*array.List)
|
listReader, ok := chunk.(*array.List)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
||||||
}
|
}
|
||||||
listFloat32Reader, ok := listReader.ListValues().(*array.Float32)
|
|
||||||
|
fixedBinaryReader, ok := listReader.ListValues().(*array.FixedSizeBinary)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, WrapTypeErr(pcr.field, chunk.DataType().Name())
|
return nil, WrapTypeErr(pcr.field, listReader.ListValues().DataType().Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that each vector has the correct byte size (dim * 4 bytes for float32)
|
||||||
|
expectedByteSize := int(dim) * 4
|
||||||
|
actualByteSize := fixedBinaryReader.DataType().(*arrow.FixedSizeBinaryType).ByteWidth
|
||||||
|
if actualByteSize != expectedByteSize {
|
||||||
|
return nil, merr.WrapErrImportFailed(fmt.Sprintf("vector byte size mismatch: expected %d, got %d for field '%s'",
|
||||||
|
expectedByteSize, actualByteSize, pcr.field.GetName()))
|
||||||
|
}
|
||||||
|
|
||||||
offsets := listReader.Offsets()
|
offsets := listReader.Offsets()
|
||||||
for i := 1; i < len(offsets); i++ {
|
for i := 1; i < len(offsets); i++ {
|
||||||
start, end := offsets[i-1], offsets[i]
|
start, end := offsets[i-1], offsets[i]
|
||||||
floatCount := end - start
|
vectorCount := end - start
|
||||||
if floatCount%int32(dim) != 0 {
|
|
||||||
return nil, merr.WrapErrImportFailed(fmt.Sprintf("vectors in VectorArray should be aligned with dim: %d", dim))
|
|
||||||
}
|
|
||||||
|
|
||||||
arrLength := floatCount / int32(dim)
|
if err = common.CheckArrayCapacity(int(vectorCount), maxCapacity, pcr.field); err != nil {
|
||||||
if err = common.CheckArrayCapacity(int(arrLength), maxCapacity, pcr.field); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
arrData := make([]float32, floatCount)
|
// Convert binary data to float32 array using arrow's built-in conversion
|
||||||
copy(arrData, listFloat32Reader.Float32Values()[start:end])
|
totalFloats := vectorCount * int32(dim)
|
||||||
|
floatData := make([]float32, totalFloats)
|
||||||
|
for j := int32(0); j < vectorCount; j++ {
|
||||||
|
vectorIndex := start + j
|
||||||
|
binaryData := fixedBinaryReader.Value(int(vectorIndex))
|
||||||
|
vectorFloats := arrow.Float32Traits.CastFromBytes(binaryData)
|
||||||
|
copy(floatData[j*int32(dim):(j+1)*int32(dim)], vectorFloats)
|
||||||
|
}
|
||||||
|
|
||||||
data = append(data, &schemapb.VectorField{
|
data = append(data, &schemapb.VectorField{
|
||||||
Dim: dim,
|
Dim: dim,
|
||||||
Data: &schemapb.VectorField_FloatVector{
|
Data: &schemapb.VectorField_FloatVector{
|
||||||
FloatVector: &schemapb.FloatArray{
|
FloatVector: &schemapb.FloatArray{
|
||||||
Data: arrData,
|
Data: floatData,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|||||||
@ -195,6 +195,8 @@ func isArrowDataTypeConvertible(src arrow.DataType, dst arrow.DataType, field *s
|
|||||||
return valid
|
return valid
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
|
case arrow.FIXED_SIZE_BINARY:
|
||||||
|
return dstType == arrow.FIXED_SIZE_BINARY
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -293,9 +295,11 @@ func convertToArrowDataType(field *schemapb.FieldSchema, isArray bool) (arrow.Da
|
|||||||
Metadata: arrow.Metadata{},
|
Metadata: arrow.Metadata{},
|
||||||
}), nil
|
}), nil
|
||||||
case schemapb.DataType_ArrayOfVector:
|
case schemapb.DataType_ArrayOfVector:
|
||||||
// VectorArrayToArrowType now returns the element type (e.g., float32)
|
dim, err := typeutil.GetDim(field)
|
||||||
// We wrap it in a single list to get list<float32> (flattened)
|
if err != nil {
|
||||||
elemType, err := storage.VectorArrayToArrowType(field.GetElementType())
|
return nil, err
|
||||||
|
}
|
||||||
|
elemType, err := storage.VectorArrayToArrowType(field.GetElementType(), int(dim))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@ -52,11 +52,12 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
FloatVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE} // const
|
// all consts
|
||||||
SparseFloatVectorMetrics = []string{metric.IP, metric.BM25} // const
|
FloatVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE}
|
||||||
BinaryVectorMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE, metric.MHJACCARD} // const
|
SparseFloatVectorMetrics = []string{metric.IP, metric.BM25}
|
||||||
IntVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE} // const
|
BinaryVectorMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE, metric.MHJACCARD}
|
||||||
EmbListMetrics = []string{metric.MaxSim} // const
|
IntVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE}
|
||||||
|
EmbListMetrics = []string{metric.MaxSim, metric.MaxSimCosine, metric.MaxSimL2, metric.MaxSimIP, metric.MaxSimHamming, metric.MaxSimJaccard}
|
||||||
)
|
)
|
||||||
|
|
||||||
// BinIDMapMetrics is a set of all metric types supported for binary vector.
|
// BinIDMapMetrics is a set of all metric types supported for binary vector.
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package testutil
|
package testutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@ -734,31 +735,114 @@ func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.Inser
|
|||||||
columns = append(columns, builder.NewListArray())
|
columns = append(columns, builder.NewListArray())
|
||||||
}
|
}
|
||||||
case schemapb.DataType_ArrayOfVector:
|
case schemapb.DataType_ArrayOfVector:
|
||||||
data := insertData.Data[fieldID].(*storage.VectorArrayFieldData).Data
|
vectorArrayData := insertData.Data[fieldID].(*storage.VectorArrayFieldData)
|
||||||
rows := len(data)
|
dim, err := typeutil.GetDim(field)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
elemType, err := storage.VectorArrayToArrowType(elementType, int(dim))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create ListBuilder with "item" field name to match convertToArrowDataType
|
||||||
|
// Always represented as a list of fixed-size binary values
|
||||||
|
listBuilder := array.NewListBuilderWithField(mem, arrow.Field{
|
||||||
|
Name: "item",
|
||||||
|
Type: elemType,
|
||||||
|
Nullable: true,
|
||||||
|
Metadata: arrow.Metadata{},
|
||||||
|
})
|
||||||
|
fixedSizeBuilder, ok := listBuilder.ValueBuilder().(*array.FixedSizeBinaryBuilder)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected list value builder for VectorArray field %s: %T", field.GetName(), listBuilder.ValueBuilder())
|
||||||
|
}
|
||||||
|
|
||||||
|
vectorArrayData.Dim = dim
|
||||||
|
|
||||||
|
bytesPerVector := fixedSizeBuilder.Type().(*arrow.FixedSizeBinaryType).ByteWidth
|
||||||
|
|
||||||
|
appendBinarySlice := func(data []byte, stride int) error {
|
||||||
|
if stride == 0 {
|
||||||
|
return fmt.Errorf("zero stride for VectorArray field %s", field.GetName())
|
||||||
|
}
|
||||||
|
if len(data)%stride != 0 {
|
||||||
|
return fmt.Errorf("vector array data length %d is not divisible by stride %d for field %s", len(data), stride, field.GetName())
|
||||||
|
}
|
||||||
|
for offset := 0; offset < len(data); offset += stride {
|
||||||
|
fixedSizeBuilder.Append(data[offset : offset+stride])
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, vectorField := range vectorArrayData.Data {
|
||||||
|
if vectorField == nil {
|
||||||
|
listBuilder.Append(false)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
listBuilder.Append(true)
|
||||||
|
|
||||||
switch elementType {
|
switch elementType {
|
||||||
case schemapb.DataType_FloatVector:
|
case schemapb.DataType_FloatVector:
|
||||||
// ArrayOfVector is flattened in Arrow - just a list of floats
|
floatArray := vectorField.GetFloatVector()
|
||||||
// where total floats = dim * num_vectors
|
if floatArray == nil {
|
||||||
builder := array.NewListBuilder(mem, &arrow.Float32Type{})
|
return nil, fmt.Errorf("expected FloatVector data for field %s", field.GetName())
|
||||||
valueBuilder := builder.ValueBuilder().(*array.Float32Builder)
|
}
|
||||||
|
data := floatArray.GetData()
|
||||||
for i := 0; i < rows; i++ {
|
if len(data) == 0 {
|
||||||
vectorArray := data[i].GetFloatVector()
|
|
||||||
if vectorArray == nil || len(vectorArray.GetData()) == 0 {
|
|
||||||
builder.AppendNull()
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
builder.Append(true)
|
if len(data)%int(dim) != 0 {
|
||||||
// Append all flattened vector data
|
return nil, fmt.Errorf("float vector data length %d is not divisible by dim %d for field %s", len(data), dim, field.GetName())
|
||||||
valueBuilder.AppendValues(vectorArray.GetData(), nil)
|
}
|
||||||
|
for offset := 0; offset < len(data); offset += int(dim) {
|
||||||
|
vectorBytes := make([]byte, bytesPerVector)
|
||||||
|
for j := 0; j < int(dim); j++ {
|
||||||
|
binary.LittleEndian.PutUint32(vectorBytes[j*4:], math.Float32bits(data[offset+j]))
|
||||||
|
}
|
||||||
|
fixedSizeBuilder.Append(vectorBytes)
|
||||||
|
}
|
||||||
|
case schemapb.DataType_BinaryVector:
|
||||||
|
binaryData := vectorField.GetBinaryVector()
|
||||||
|
if len(binaryData) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
bytesPer := int((dim + 7) / 8)
|
||||||
|
if err := appendBinarySlice(binaryData, bytesPer); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case schemapb.DataType_Float16Vector:
|
||||||
|
float16Data := vectorField.GetFloat16Vector()
|
||||||
|
if len(float16Data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := appendBinarySlice(float16Data, int(dim)*2); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case schemapb.DataType_BFloat16Vector:
|
||||||
|
bfloat16Data := vectorField.GetBfloat16Vector()
|
||||||
|
if len(bfloat16Data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := appendBinarySlice(bfloat16Data, int(dim)*2); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case schemapb.DataType_Int8Vector:
|
||||||
|
int8Data := vectorField.GetInt8Vector()
|
||||||
|
if len(int8Data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := appendBinarySlice(int8Data, int(dim)); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
columns = append(columns, builder.NewListArray())
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String())
|
return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
columns = append(columns, listBuilder.NewListArray())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return columns, nil
|
return columns, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -45,6 +45,7 @@ enum VectorType {
|
|||||||
EmbListFloat16Vector = 7;
|
EmbListFloat16Vector = 7;
|
||||||
EmbListBFloat16Vector = 8;
|
EmbListBFloat16Vector = 8;
|
||||||
EmbListInt8Vector = 9;
|
EmbListInt8Vector = 9;
|
||||||
|
EmbListBinaryVector = 10;
|
||||||
};
|
};
|
||||||
|
|
||||||
message GenericValue {
|
message GenericValue {
|
||||||
|
|||||||
@ -184,6 +184,7 @@ const (
|
|||||||
VectorType_EmbListFloat16Vector VectorType = 7
|
VectorType_EmbListFloat16Vector VectorType = 7
|
||||||
VectorType_EmbListBFloat16Vector VectorType = 8
|
VectorType_EmbListBFloat16Vector VectorType = 8
|
||||||
VectorType_EmbListInt8Vector VectorType = 9
|
VectorType_EmbListInt8Vector VectorType = 9
|
||||||
|
VectorType_EmbListBinaryVector VectorType = 10
|
||||||
)
|
)
|
||||||
|
|
||||||
// Enum value maps for VectorType.
|
// Enum value maps for VectorType.
|
||||||
@ -199,6 +200,7 @@ var (
|
|||||||
7: "EmbListFloat16Vector",
|
7: "EmbListFloat16Vector",
|
||||||
8: "EmbListBFloat16Vector",
|
8: "EmbListBFloat16Vector",
|
||||||
9: "EmbListInt8Vector",
|
9: "EmbListInt8Vector",
|
||||||
|
10: "EmbListBinaryVector",
|
||||||
}
|
}
|
||||||
VectorType_value = map[string]int32{
|
VectorType_value = map[string]int32{
|
||||||
"BinaryVector": 0,
|
"BinaryVector": 0,
|
||||||
@ -211,6 +213,7 @@ var (
|
|||||||
"EmbListFloat16Vector": 7,
|
"EmbListFloat16Vector": 7,
|
||||||
"EmbListBFloat16Vector": 8,
|
"EmbListBFloat16Vector": 8,
|
||||||
"EmbListInt8Vector": 9,
|
"EmbListInt8Vector": 9,
|
||||||
|
"EmbListBinaryVector": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -3808,7 +3811,7 @@ var file_plan_proto_rawDesc = []byte{
|
|||||||
0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03,
|
0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03,
|
||||||
0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64,
|
0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64,
|
||||||
0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74,
|
0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61, 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74,
|
||||||
0x68, 0x10, 0x06, 0x2a, 0xe1, 0x01, 0x0a, 0x0a, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79,
|
0x68, 0x10, 0x06, 0x2a, 0xfa, 0x01, 0x0a, 0x0a, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79,
|
||||||
0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74,
|
0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74,
|
||||||
0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63,
|
0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63,
|
||||||
0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36,
|
0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36,
|
||||||
@ -3822,22 +3825,23 @@ var file_plan_proto_rawDesc = []byte{
|
|||||||
0x74, 0x6f, 0x72, 0x10, 0x07, 0x12, 0x19, 0x0a, 0x15, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74,
|
0x74, 0x6f, 0x72, 0x10, 0x07, 0x12, 0x19, 0x0a, 0x15, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74,
|
||||||
0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x08,
|
0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x08,
|
||||||
0x12, 0x15, 0x0a, 0x11, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x49, 0x6e, 0x74, 0x38, 0x56,
|
0x12, 0x15, 0x0a, 0x11, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x49, 0x6e, 0x74, 0x38, 0x56,
|
||||||
0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x09, 0x2a, 0x3e, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74,
|
0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x09, 0x12, 0x17, 0x0a, 0x13, 0x45, 0x6d, 0x62, 0x4c, 0x69,
|
||||||
0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, 0x74,
|
0x73, 0x74, 0x42, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x0a,
|
||||||
0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x57, 0x65, 0x69, 0x67, 0x68, 0x74, 0x10, 0x00, 0x12,
|
0x2a, 0x3e, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65,
|
||||||
0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x52,
|
0x12, 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65,
|
||||||
0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x10, 0x01, 0x2a, 0x3d, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74,
|
0x57, 0x65, 0x69, 0x67, 0x68, 0x74, 0x10, 0x00, 0x12, 0x16, 0x0a, 0x12, 0x46, 0x75, 0x6e, 0x63,
|
||||||
0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x14, 0x46, 0x75, 0x6e, 0x63, 0x74,
|
0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x61, 0x6e, 0x64, 0x6f, 0x6d, 0x10, 0x01,
|
||||||
0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x79, 0x10,
|
0x2a, 0x3d, 0x0a, 0x0c, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65,
|
||||||
0x00, 0x12, 0x13, 0x0a, 0x0f, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64,
|
0x12, 0x18, 0x0a, 0x14, 0x46, 0x75, 0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65,
|
||||||
0x65, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x2a, 0x34, 0x0a, 0x09, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d,
|
0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x79, 0x10, 0x00, 0x12, 0x13, 0x0a, 0x0f, 0x46, 0x75,
|
||||||
0x6f, 0x64, 0x65, 0x12, 0x15, 0x0a, 0x11, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65,
|
0x6e, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x4d, 0x6f, 0x64, 0x65, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x2a,
|
||||||
0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x79, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x6f,
|
0x34, 0x0a, 0x09, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x15, 0x0a, 0x11,
|
||||||
0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x53, 0x75, 0x6d, 0x10, 0x01, 0x42, 0x31, 0x5a, 0x2f,
|
0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c,
|
||||||
0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75,
|
0x79, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x6f, 0x6f, 0x73, 0x74, 0x4d, 0x6f, 0x64, 0x65,
|
||||||
0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f,
|
0x53, 0x75, 0x6d, 0x10, 0x01, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
|
||||||
0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62,
|
0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69,
|
||||||
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74,
|
||||||
|
0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@ -44,5 +44,11 @@ const (
|
|||||||
|
|
||||||
EMPTY MetricType = ""
|
EMPTY MetricType = ""
|
||||||
|
|
||||||
|
// The same with MaxSimCosine
|
||||||
MaxSim MetricType = "MAX_SIM"
|
MaxSim MetricType = "MAX_SIM"
|
||||||
|
MaxSimCosine MetricType = "MAX_SIM_COSINE"
|
||||||
|
MaxSimL2 MetricType = "MAX_SIM_L2"
|
||||||
|
MaxSimIP MetricType = "MAX_SIM_IP"
|
||||||
|
MaxSimHamming MetricType = "MAX_SIM_HAMMING"
|
||||||
|
MaxSimJaccard MetricType = "MAX_SIM_JACCARD"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -176,7 +176,7 @@ func (s *ArrayStructDataNodeSuite) loadCollection(collectionName string) {
|
|||||||
CollectionName: collectionName,
|
CollectionName: collectionName,
|
||||||
FieldName: subFieldName,
|
FieldName: subFieldName,
|
||||||
IndexName: "array_of_vector_index",
|
IndexName: "array_of_vector_index",
|
||||||
ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexEmbListHNSW, metric.MaxSim),
|
ExtraParams: integration.ConstructIndexParam(s.dim, integration.IndexHNSW, metric.MaxSim),
|
||||||
})
|
})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success)
|
s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||||
@ -318,7 +318,7 @@ func (s *ArrayStructDataNodeSuite) query(collectionName string) {
|
|||||||
roundDecimal := -1
|
roundDecimal := -1
|
||||||
|
|
||||||
subFieldName := proxy.ConcatStructFieldName(integration.StructArrayField, integration.StructSubFloatVecField)
|
subFieldName := proxy.ConcatStructFieldName(integration.StructArrayField, integration.StructSubFloatVecField)
|
||||||
params := integration.GetSearchParams(integration.IndexEmbListHNSW, metric.MaxSim)
|
params := integration.GetSearchParams(integration.IndexHNSW, metric.MaxSim)
|
||||||
searchReq := integration.ConstructEmbeddingListSearchRequest("", collectionName, expr,
|
searchReq := integration.ConstructEmbeddingListSearchRequest("", collectionName, expr,
|
||||||
subFieldName, schemapb.DataType_FloatVector, []string{integration.StructArrayField}, metric.MaxSim, params, nq, s.dim, topk, roundDecimal)
|
subFieldName, schemapb.DataType_FloatVector, []string{integration.StructArrayField}, metric.MaxSim, params, nq, s.dim, topk, roundDecimal)
|
||||||
|
|
||||||
|
|||||||
@ -246,7 +246,7 @@ func (s *TestArrayStructSuite) run() {
|
|||||||
func (s *TestArrayStructSuite) TestGetVector_ArrayStruct_FloatVector() {
|
func (s *TestArrayStructSuite) TestGetVector_ArrayStruct_FloatVector() {
|
||||||
s.nq = 10
|
s.nq = 10
|
||||||
s.topK = 10
|
s.topK = 10
|
||||||
s.indexType = integration.IndexEmbListHNSW
|
s.indexType = integration.IndexHNSW
|
||||||
s.metricType = metric.MaxSim
|
s.metricType = metric.MaxSim
|
||||||
s.vecType = schemapb.DataType_FloatVector
|
s.vecType = schemapb.DataType_FloatVector
|
||||||
s.run()
|
s.run()
|
||||||
|
|||||||
@ -102,7 +102,7 @@ func (s *BulkInsertSuite) PrepareSourceCollection(dim int, dmlGroup *DMLGroup) *
|
|||||||
CollectionName: collectionName,
|
CollectionName: collectionName,
|
||||||
FieldName: name,
|
FieldName: name,
|
||||||
IndexName: "array_of_vector_index",
|
IndexName: "array_of_vector_index",
|
||||||
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexEmbListHNSW, metric.MaxSim),
|
ExtraParams: integration.ConstructIndexParam(dim, integration.IndexHNSW, metric.MaxSim),
|
||||||
})
|
})
|
||||||
s.NoError(err)
|
s.NoError(err)
|
||||||
s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success)
|
s.Require().Equal(createIndexResult.GetErrorCode(), commonpb.ErrorCode_Success)
|
||||||
|
|||||||
@ -303,7 +303,7 @@ func (s *BulkInsertSuite) TestImportWithVectorArray() {
|
|||||||
for _, fileType := range fileTypeArr {
|
for _, fileType := range fileTypeArr {
|
||||||
s.fileType = fileType
|
s.fileType = fileType
|
||||||
s.vecType = schemapb.DataType_FloatVector
|
s.vecType = schemapb.DataType_FloatVector
|
||||||
s.indexType = integration.IndexEmbListHNSW
|
s.indexType = integration.IndexHNSW
|
||||||
s.metricType = metric.MaxSim
|
s.metricType = metric.MaxSim
|
||||||
s.runForStructArray()
|
s.runForStructArray()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,7 +43,6 @@ const (
|
|||||||
IndexDISKANN = "DISKANN"
|
IndexDISKANN = "DISKANN"
|
||||||
IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX"
|
IndexSparseInvertedIndex = "SPARSE_INVERTED_INDEX"
|
||||||
IndexSparseWand = "SPARSE_WAND"
|
IndexSparseWand = "SPARSE_WAND"
|
||||||
IndexEmbListHNSW = "EMB_LIST_HNSW"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
|
func (s *MiniClusterSuite) WaitForIndexBuiltWithDB(ctx context.Context, dbName, collection, field string) {
|
||||||
@ -169,15 +168,6 @@ func ConstructIndexParam(dim int, indexType string, metricType string) []*common
|
|||||||
Key: "efConstruction",
|
Key: "efConstruction",
|
||||||
Value: "200",
|
Value: "200",
|
||||||
})
|
})
|
||||||
case IndexEmbListHNSW:
|
|
||||||
params = append(params, &commonpb.KeyValuePair{
|
|
||||||
Key: "M",
|
|
||||||
Value: "16",
|
|
||||||
})
|
|
||||||
params = append(params, &commonpb.KeyValuePair{
|
|
||||||
Key: "efConstruction",
|
|
||||||
Value: "200",
|
|
||||||
})
|
|
||||||
case IndexSparseInvertedIndex:
|
case IndexSparseInvertedIndex:
|
||||||
case IndexSparseWand:
|
case IndexSparseWand:
|
||||||
case IndexDISKANN:
|
case IndexDISKANN:
|
||||||
@ -195,7 +185,6 @@ func GetSearchParams(indexType string, metricType string) map[string]any {
|
|||||||
case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8, IndexFaissIvfPQ, IndexScaNN:
|
case IndexFaissIvfFlat, IndexFaissBinIvfFlat, IndexFaissIvfSQ8, IndexFaissIvfPQ, IndexScaNN:
|
||||||
params["nprobe"] = 8
|
params["nprobe"] = 8
|
||||||
case IndexHNSW:
|
case IndexHNSW:
|
||||||
case IndexEmbListHNSW:
|
|
||||||
params["ef"] = 200
|
params["ef"] = 200
|
||||||
case IndexDISKANN:
|
case IndexDISKANN:
|
||||||
params["search_list"] = 20
|
params["search_list"] = 20
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user