feat: impl StructArray -- support creating an index on vector arrays (embedding lists) and searching on them (#43726)

Ref https://github.com/milvus-io/milvus/issues/42148

This PR adds support for creating an index on a vector array (embedding list) field, currently only with `DataType.FLOAT_VECTOR` elements, and for searching on it.
The only index type supported in this PR is `EMB_LIST_HNSW`, and the only metric type is `MAX_SIM`.

Usage:
```python
milvus_client = MilvusClient("xxx:19530")
schema = milvus_client.create_schema(enable_dynamic_field=True, auto_id=True)
...
struct_schema = milvus_client.create_struct_array_field_schema("struct_array_field")
...
struct_schema.add_field("struct_float_vec", DataType.ARRAY_OF_VECTOR, element_type=DataType.FLOAT_VECTOR, dim=128, max_capacity=1000)
...
schema.add_struct_array_field(struct_schema)
index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name="struct_float_vec", index_type="EMB_LIST_HNSW", metric_type="MAX_SIM", index_params={"nlist": 128})
...
milvus_client.create_index(COLLECTION_NAME, schema=schema, index_params=index_params)
```
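
For completeness, here is a hedged sketch of the search side. The exact query payload shape for embedding-list queries is an assumption (each query is taken to be a list of vectors scored with `MAX_SIM`), and the parameter names follow the usual `MilvusClient.search` signature:
```python
import numpy as np

# One query = one embedding list, i.e. several 128-dim vectors.
query_emb_list = [np.random.rand(128).tolist() for _ in range(4)]

results = milvus_client.search(
    COLLECTION_NAME,
    data=[query_emb_list],  # one embedding list per query
    anns_field="struct_float_vec",
    search_params={"metric_type": "MAX_SIM"},
    limit=5,
)
```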

Note: this PR uses `Lims` to convey the offsets of the vector arrays to knowhere. Vectors from multiple vector arrays are concatenated into a single flat buffer, so offsets are needed to specify which vectors belong to which vector array.
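
As a toy illustration (plain Python, not Milvus code) of the `Lims` convention: with three embedding lists containing 2, 3, and 1 vectors, the flattened buffer is described by prefix offsets, and vectors `lims[i]..lims[i+1]` belong to embedding list `i`:
```python
num_vectors_per_list = [2, 3, 1]  # three embedding lists

lims = [0]
for n in num_vectors_per_list:
    lims.append(lims[-1] + n)

# len(lims) is the number of embedding lists + 1, and lims[-1]
# is the total number of flattened vectors.
assert lims == [0, 2, 5, 6]
```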

---------

Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
Signed-off-by: SpadeA-Tang <tangchenjie1210@gmail.com>
Spade A 2025-08-20 10:27:46 +08:00 committed by GitHub
parent cfdb17a088
commit d6a428e880
138 changed files with 3562 additions and 1819 deletions

View File

@@ -395,6 +395,14 @@ class VectorArrayChunk : public Chunk {
           dim_(dim),
           element_type_(element_type) {
         offsets_lens_ = reinterpret_cast<uint32_t*>(data);
+        auto offset = 0;
+        lims_.reserve(row_nums_ + 1);
+        lims_.push_back(offset);
+        for (int64_t i = 0; i < row_nums_; i++) {
+            offset += offsets_lens_[i * 2 + 1];
+            lims_.push_back(offset);
+        }
     }

     VectorArrayView
@@ -424,10 +432,23 @@ class VectorArrayChunk : public Chunk {
                   "VectorArrayChunk::ValueAt is not supported");
     }

+    const char*
+    Data() const override {
+        return data_ + offsets_lens_[0];
+    }
+
+    const size_t*
+    Lims() const {
+        return lims_.data();
+    }
+
  private:
     int64_t dim_;
     uint32_t* offsets_lens_;
     milvus::DataType element_type_;
+    // The name 'Lims' is consistent with knowhere::DataSet::SetLims which describes the number of vectors
+    // in each vector array (embedding list). This is needed as vectors are flattened in the chunk.
+    std::vector<size_t> lims_;
 };

 class SparseFloatVectorChunk : public Chunk {

View File

@@ -92,6 +92,15 @@ class FieldData<VectorArray> : public FieldDataVectorArrayImpl {
         ThrowInfo(Unsupported,
                   "Call get_dim on FieldData<VectorArray> is not supported");
     }
+
+    const VectorArray*
+    value_at(ssize_t offset) const {
+        AssertInfo(offset < get_num_rows(),
+                   "field data subscript out of range");
+        AssertInfo(offset < length(),
+                   "subscript position don't has valid value");
+        return &data_[offset];
+    }
 };

 template <>

View File

@@ -47,6 +47,7 @@ constexpr bool IsVariableType =
     IsSparse<T> || std::is_same_v<T, VectorArray> ||
     std::is_same_v<T, VectorArrayView>;

+// todo(SpadeA): support vector array
 template <typename T>
 constexpr bool IsVariableTypeSupportInChunk =
     std::is_same_v<T, std::string> || std::is_same_v<T, Array> ||

View File

@@ -493,7 +493,8 @@ IsFloatVectorMetricType(const MetricType& metric_type) {
     return metric_type == knowhere::metric::L2 ||
            metric_type == knowhere::metric::IP ||
            metric_type == knowhere::metric::COSINE ||
-           metric_type == knowhere::metric::BM25;
+           metric_type == knowhere::metric::BM25 ||
+           metric_type == knowhere::metric::MAX_SIM;
 }

 inline bool

View File

@@ -28,20 +28,23 @@
 namespace milvus {

 #define GET_ELEM_TYPE_FOR_VECTOR_TRAIT                                  \
     using elem_type = std::conditional_t<                               \
-        std::is_same_v<TraitType, milvus::FloatVector>,                 \
-        milvus::FloatVector::embedded_type,                             \
+        std::is_same_v<TraitType, milvus::EmbListFloatVector>,          \
+        milvus::EmbListFloatVector::embedded_type,                      \
         std::conditional_t<                                             \
-            std::is_same_v<TraitType, milvus::Float16Vector>,           \
-            milvus::Float16Vector::embedded_type,                       \
+            std::is_same_v<TraitType, milvus::FloatVector>,             \
+            milvus::FloatVector::embedded_type,                         \
             std::conditional_t<                                         \
-                std::is_same_v<TraitType, milvus::BFloat16Vector>,      \
-                milvus::BFloat16Vector::embedded_type,                  \
+                std::is_same_v<TraitType, milvus::Float16Vector>,       \
+                milvus::Float16Vector::embedded_type,                   \
                 std::conditional_t<                                     \
-                    std::is_same_v<TraitType, milvus::Int8Vector>,      \
-                    milvus::Int8Vector::embedded_type,                  \
-                    milvus::BinaryVector::embedded_type>>>>;
+                    std::is_same_v<TraitType, milvus::BFloat16Vector>,  \
+                    milvus::BFloat16Vector::embedded_type,              \
+                    std::conditional_t<                                 \
+                        std::is_same_v<TraitType, milvus::Int8Vector>,  \
+                        milvus::Int8Vector::embedded_type,              \
+                        milvus::BinaryVector::embedded_type>>>>>;

 #define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT                           \
     auto schema_data_type =                                             \
@@ -55,7 +58,13 @@ namespace milvus {
             ? milvus::Int8Vector::schema_data_type                      \
             : milvus::BinaryVector::schema_data_type;

-class VectorTrait {};
+class VectorTrait {
+ public:
+    static constexpr bool
+    is_embedding_list() {
+        return false;
+    }
+};

 class FloatVector : public VectorTrait {
  public:
@@ -136,6 +145,25 @@ class Int8Vector : public VectorTrait {
         proto::common::PlaceholderType::Int8Vector;
 };

+class EmbListFloatVector : public VectorTrait {
+ public:
+    using embedded_type = float;
+    static constexpr int32_t dim_factor = 1;
+    static constexpr auto data_type = DataType::VECTOR_ARRAY;
+    static constexpr auto c_data_type = CDataType::VectorArray;
+    static constexpr auto schema_data_type =
+        proto::schema::DataType::ArrayOfVector;
+    static constexpr auto vector_type =
+        proto::plan::VectorType::EmbListFloatVector;
+    static constexpr auto placeholder_type =
+        proto::common::PlaceholderType::EmbListFloatVector;
+
+    static constexpr bool
+    is_embedding_list() {
+        return true;
+    }
+};
+
 struct FundamentalTag {};
 struct StringTag {};

View File

@@ -55,6 +55,7 @@ enum CDataType {
     BFloat16Vector = 103,
     SparseFloatVector = 104,
     Int8Vector = 105,
+    VectorArray = 106,
 };

 typedef enum CDataType CDataType;

View File

@@ -69,6 +69,7 @@ PhyVectorSearchNode::GetOutput() {
     auto& ph = placeholder_group_->at(0);
     auto src_data = ph.get_blob();
+    auto src_lims = ph.get_lims();
     auto num_queries = ph.num_of_queries_;
     milvus::SearchResult search_result;
@@ -85,6 +86,7 @@ PhyVectorSearchNode::GetOutput() {
                                  col_input->size());
         segment_->vector_search(search_info_,
                                 src_data,
+                                src_lims,
                                 num_queries,
                                 query_timestamp_,
                                 final_view,

View File

@@ -98,13 +98,18 @@ IndexFactory::CreatePrimitiveScalarIndex<std::string>(
 LoadResourceRequest
 IndexFactory::IndexLoadResource(
     DataType field_type,
+    DataType element_type,
     IndexVersion index_version,
     float index_size,
     const std::map<std::string, std::string>& index_params,
     bool mmap_enable) {
     if (milvus::IsVectorDataType(field_type)) {
-        return VecIndexLoadResource(
-            field_type, index_version, index_size, index_params, mmap_enable);
+        return VecIndexLoadResource(field_type,
+                                    element_type,
+                                    index_version,
+                                    index_size,
+                                    index_params,
+                                    mmap_enable);
     } else {
         return ScalarIndexLoadResource(
             field_type, index_version, index_size, index_params, mmap_enable);
@@ -114,6 +119,7 @@ IndexFactory::IndexLoadResource(
 LoadResourceRequest
 IndexFactory::VecIndexLoadResource(
     DataType field_type,
+    DataType element_type,
     IndexVersion index_version,
     float index_size,
     const std::map<std::string, std::string>& index_params,
@@ -198,6 +204,29 @@ IndexFactory::VecIndexLoadResource(
                 knowhere::IndexStaticFaced<knowhere::int8>::HasRawData(
                     index_type, index_version, config);
             break;
+        case milvus::DataType::VECTOR_ARRAY: {
+            switch (element_type) {
+                case milvus::DataType::VECTOR_FLOAT:
+                    resource = knowhere::IndexStaticFaced<
+                        knowhere::fp32>::EstimateLoadResource(index_type,
+                                                              index_version,
+                                                              index_size_gb,
+                                                              config);
+                    has_raw_data =
+                        knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(
+                            index_type, index_version, config);
+                    break;
+                default:
+                    LOG_ERROR(
+                        "invalid data type to estimate index load resource: "
+                        "field_type {}, element_type {}",
+                        field_type,
+                        element_type);
+                    return LoadResourceRequest{0, 0, 0, 0, true};
+            }
+            break;
+        }
         default:
             LOG_ERROR("invalid data type to estimate index load resource: {}",
                       field_type);
@@ -491,8 +520,14 @@ IndexFactory::CreateVectorIndex(
                 return std::make_unique<VectorDiskAnnIndex<float>>(
                     index_type, metric_type, version, file_manager_context);
             }
+            case DataType::VECTOR_ARRAY: {
+                ThrowInfo(Unsupported,
+                          "VECTOR_ARRAY for DiskAnnIndex is not supported");
+            }
             case DataType::VECTOR_INT8: {
                 // TODO caiyd, not support yet
+                ThrowInfo(Unsupported,
+                          "VECTOR_INT8 for DiskAnnIndex is not supported");
             }
             default:
                 ThrowInfo(
@@ -505,6 +540,7 @@ IndexFactory::CreateVectorIndex(
             case DataType::VECTOR_FLOAT:
             case DataType::VECTOR_SPARSE_FLOAT: {
                 return std::make_unique<VectorMemIndex<float>>(
+                    DataType::NONE,
                     index_type,
                     metric_type,
                     version,
@@ -513,6 +549,7 @@ IndexFactory::CreateVectorIndex(
             }
             case DataType::VECTOR_BINARY: {
                 return std::make_unique<VectorMemIndex<bin1>>(
+                    DataType::NONE,
                     index_type,
                     metric_type,
                     version,
@@ -521,6 +558,7 @@ IndexFactory::CreateVectorIndex(
             }
             case DataType::VECTOR_FLOAT16: {
                 return std::make_unique<VectorMemIndex<float16>>(
+                    DataType::NONE,
                     index_type,
                     metric_type,
                     version,
@@ -529,6 +567,7 @@ IndexFactory::CreateVectorIndex(
             }
             case DataType::VECTOR_BFLOAT16: {
                 return std::make_unique<VectorMemIndex<bfloat16>>(
+                    DataType::NONE,
                     index_type,
                     metric_type,
                     version,
@@ -537,12 +576,33 @@ IndexFactory::CreateVectorIndex(
             }
             case DataType::VECTOR_INT8: {
                 return std::make_unique<VectorMemIndex<int8>>(
+                    DataType::NONE,
                     index_type,
                     metric_type,
                     version,
                     use_knowhere_build_pool,
                     file_manager_context);
             }
+            case DataType::VECTOR_ARRAY: {
+                auto element_type =
+                    static_cast<DataType>(file_manager_context.fieldDataMeta
+                                              .field_schema.element_type());
+                switch (element_type) {
+                    case DataType::VECTOR_FLOAT:
+                        return std::make_unique<VectorMemIndex<float>>(
+                            element_type,
+                            index_type,
+                            metric_type,
+                            version,
+                            use_knowhere_build_pool,
+                            file_manager_context);
+                    default:
+                        ThrowInfo(NotImplemented,
+                                  fmt::format("not implemented data type to "
+                                              "build mem index: {}",
+                                              data_type));
+                }
+            }
             default:
                 ThrowInfo(
                     DataTypeInvalid,

View File

@@ -56,6 +56,7 @@ class IndexFactory {
     LoadResourceRequest
     IndexLoadResource(DataType field_type,
+                      DataType element_type,
                       IndexVersion index_version,
                       float index_size,
                       const std::map<std::string, std::string>& index_params,
@@ -63,6 +64,7 @@ class IndexFactory {
     LoadResourceRequest
     VecIndexLoadResource(DataType field_type,
+                         DataType element_type,
                          IndexVersion index_version,
                          float index_size,
                          const std::map<std::string, std::string>& index_params,

View File

@@ -229,4 +229,21 @@ void inline SetBitsetGrowing(void* bitset,
     }
 }

+inline size_t
+vector_element_size(const DataType data_type) {
+    switch (data_type) {
+        case DataType::VECTOR_FLOAT:
+            return sizeof(float);
+        case DataType::VECTOR_FLOAT16:
+            return sizeof(float16);
+        case DataType::VECTOR_BFLOAT16:
+            return sizeof(bfloat16);
+        case DataType::VECTOR_INT8:
+            return sizeof(int8);
+        default:
+            ThrowInfo(UnexpectedError,
+                      fmt::format("invalid data type: {}", data_type));
+    }
+}
+
 }  // namespace milvus::index

View File

@@ -245,7 +245,7 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
                              SearchResult& search_result) const {
     AssertInfo(GetMetricType() == search_info.metric_type_,
                "Metric type of field index isn't the same with search info");
-    auto num_queries = dataset->GetRows();
+    auto num_rows = dataset->GetRows();
     auto topk = search_info.topk_;

     knowhere::Json search_config = PrepareSearchParams(search_info);
@@ -277,7 +277,7 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
                               res.what()));
         }
         return ReGenRangeSearchResult(
-            res.value(), topk, num_queries, GetMetricType());
+            res.value(), topk, num_rows, GetMetricType());
     } else {
         auto res = index_.Search(dataset, search_config, bitset);
         if (!res.has_value()) {
@@ -291,6 +291,8 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
     }();

     auto ids = final->GetIds();
+    // In embedding list query, final->GetRows() can be different from dataset->GetRows().
+    auto num_queries = final->GetRows();
     float* distances = const_cast<float*>(final->GetDistance());
     final->SetIsOwner(true);

View File

@@ -59,12 +59,14 @@ namespace milvus::index {

 template <typename T>
 VectorMemIndex<T>::VectorMemIndex(
+    DataType elem_type,
     const IndexType& index_type,
     const MetricType& metric_type,
     const IndexVersion& version,
     bool use_knowhere_build_pool,
     const storage::FileManagerContext& file_manager_context)
     : VectorIndex(index_type, metric_type),
+      elem_type_(elem_type),
       use_knowhere_build_pool_(use_knowhere_build_pool) {
     CheckMetricTypeSupport<T>(metric_type);
     AssertInfo(!is_unsupported(index_type, metric_type),
@@ -89,12 +91,14 @@ VectorMemIndex<T>::VectorMemIndex(
 }

 template <typename T>
-VectorMemIndex<T>::VectorMemIndex(const IndexType& index_type,
+VectorMemIndex<T>::VectorMemIndex(DataType elem_type,
+                                  const IndexType& index_type,
                                   const MetricType& metric_type,
                                   const IndexVersion& version,
                                   const knowhere::ViewDataOp view_data,
                                   bool use_knowhere_build_pool)
     : VectorIndex(index_type, metric_type),
+      elem_type_(elem_type),
       use_knowhere_build_pool_(use_knowhere_build_pool) {
     CheckMetricTypeSupport<T>(metric_type);
     AssertInfo(!is_unsupported(index_type, metric_type),
@@ -304,6 +308,11 @@ VectorMemIndex<T>::BuildWithDataset(const DatasetPtr& dataset,
     SetDim(index_.Dim());
 }

+bool
+is_embedding_list_index(const IndexType& index_type) {
+    return index_type == knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
+}
+
 template <typename T>
 void
 VectorMemIndex<T>::Build(const Config& config) {
@@ -331,23 +340,74 @@ VectorMemIndex<T>::Build(const Config& config) {
             total_num_rows += data->get_num_rows();
             AssertInfo(dim == 0 || dim == data->get_dim(),
                        "inconsistent dim value between field datas!");
-            dim = data->get_dim();
+            // todo(SapdeA): now, vector arrays (embedding list) are serialized
+            // to parquet by using binary format which does not provide dim
+            // information so we use this temporary solution.
+            if (is_embedding_list_index(index_type_)) {
+                AssertInfo(elem_type_ != DataType::NONE,
+                           "embedding list index must have elem_type");
+                dim = config[DIM_KEY].get<int64_t>();
+            } else {
+                dim = data->get_dim();
+            }
         }

         auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[total_size]);
+
+        size_t lim_offset = 0;
+        std::vector<size_t> lims;
+        lims.reserve(total_num_rows + 1);
+        lims.push_back(lim_offset);
+
         int64_t offset = 0;
-        // TODO: avoid copying
-        for (auto data : field_datas) {
-            std::memcpy(buf.get() + offset, data->Data(), data->Size());
-            offset += data->Size();
-            data.reset();
+        if (!is_embedding_list_index(index_type_)) {
+            // TODO: avoid copying
+            for (auto data : field_datas) {
+                std::memcpy(buf.get() + offset, data->Data(), data->Size());
+                offset += data->Size();
+                data.reset();
+            }
+        } else {
+            auto elem_size = vector_element_size(elem_type_);
+            for (auto data : field_datas) {
+                auto vec_array_data =
+                    dynamic_cast<FieldData<VectorArray>*>(data.get());
+                AssertInfo(vec_array_data != nullptr,
+                           "failed to cast field data to vector array");
+                auto rows = vec_array_data->get_num_rows();
+                for (auto i = 0; i < rows; ++i) {
+                    auto size = vec_array_data->DataSize(i);
+                    assert(size % (dim * elem_size) == 0);
+                    assert(dim * elem_size != 0);
+                    auto vec_array = vec_array_data->value_at(i);
+                    std::memcpy(buf.get() + offset, vec_array->data(), size);
+                    offset += size;
+                    lim_offset += size / (dim * elem_size);
+                    lims.push_back(lim_offset);
+                }
+                assert(data->Size() == offset);
+                data.reset();
+            }
+            total_num_rows = lim_offset;
         }
         field_datas.clear();

         auto dataset = GenDataset(total_num_rows, dim, buf.get());
         if (!scalar_info.empty()) {
             dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info));
         }
+        if (!lims.empty()) {
+            dataset->SetLims(lims.data());
+        }
         BuildWithDataset(dataset, build_config);
     } else {
         // sparse
@@ -409,7 +469,7 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
     //     AssertInfo(GetMetricType() == search_info.metric_type_,
     //                "Metric type of field index isn't the same with search info");
-    auto num_queries = dataset->GetRows();
+    auto num_vectors = dataset->GetRows();
     knowhere::Json search_conf = PrepareSearchParams(search_info);
     auto topk = search_info.topk_;
     // TODO :: check dim of search data
@@ -427,7 +487,7 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
                       res.what());
         }
         auto result = ReGenRangeSearchResult(
-            res.value(), topk, num_queries, GetMetricType());
+            res.value(), topk, num_vectors, GetMetricType());
         milvus::tracer::AddEvent("finish_ReGenRangeSearchResult");
         return result;
     } else {
@@ -448,6 +508,8 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
     }();

     auto ids = final->GetIds();
+    // In embedding list query, final->GetRows() can be different from dataset->GetRows().
+    auto num_queries = final->GetRows();
     float* distances = const_cast<float*>(final->GetDistance());
     final->SetIsOwner(true);
     auto round_decimal = search_info.round_decimal_;

View File

@@ -35,6 +35,7 @@ template <typename T>
 class VectorMemIndex : public VectorIndex {
  public:
     explicit VectorMemIndex(
+        DataType elem_type /* used for embedding list only */,
         const IndexType& index_type,
         const MetricType& metric_type,
         const IndexVersion& version,
@@ -43,7 +44,8 @@ class VectorMemIndex : public VectorIndex {
         storage::FileManagerContext());

     // knowhere data view index special constucter for intermin index, no need to hold file_manager_ to upload or download files
-    VectorMemIndex(const IndexType& index_type,
+    VectorMemIndex(DataType elem_type /* used for embedding list only */,
+                   const IndexType& index_type,
                    const MetricType& metric_type,
                    const IndexVersion& version,
                    const knowhere::ViewDataOp view_data,
@@ -108,6 +110,8 @@ class VectorMemIndex : public VectorIndex {
     Config config_;
     knowhere::Index<knowhere::IndexNode> index_;
     std::shared_ptr<storage::MemFileManagerImpl> file_manager_;
+    // used for embedding list only
+    DataType elem_type_;
     CreateIndexInfo create_index_info_;
     bool use_knowhere_build_pool_;

View File

@@ -70,11 +70,8 @@ class IndexFactory {
         case DataType::VECTOR_BINARY:
         case DataType::VECTOR_SPARSE_FLOAT:
         case DataType::VECTOR_INT8:
-            return std::make_unique<VecIndexCreator>(type, config, context);
         case DataType::VECTOR_ARRAY:
-            ThrowInfo(DataTypeInvalid,
-                      fmt::format("VECTOR_ARRAY is not implemented"));
+            return std::make_unique<VecIndexCreator>(type, config, context);
         default:
             ThrowInfo(DataTypeInvalid,

View File

@@ -34,6 +34,13 @@ VecIndexCreator::VecIndexCreator(
     Config& config,
     const storage::FileManagerContext& file_manager_context)
     : config_(config), data_type_(data_type) {
+    if (data_type == DataType::VECTOR_ARRAY) {
+        // TODO(SpadeA): record dim in config as there's the dim cannot be inferred in
+        // parquet due to the serialize method of vector array.
+        // This should be a temp solution.
+        config_[DIM_KEY] = file_manager_context.indexMeta.dim;
+    }
     index::CreateIndexInfo index_info;
     index_info.field_type = data_type_;
     index_info.index_type = index::GetIndexTypeFromConfig(config_);

View File

@@ -273,6 +273,13 @@ class ChunkedColumnBase : public ChunkedColumnInterface {
             "VectorArrayViews only supported for ChunkedVectorArrayColumn");
     }

+    virtual PinWrapper<const size_t*>
+    VectorArrayLims(int64_t chunk_id) const override {
+        ThrowInfo(
+            ErrorCode::Unsupported,
+            "VectorArrayLims only supported for ChunkedVectorArrayColumn");
+    }
+
     PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>
     StringViewsByOffsets(int64_t chunk_id,
                          const FixedVector<int32_t>& offsets) const override {
@@ -621,6 +628,15 @@ class ChunkedVectorArrayColumn : public ChunkedColumnBase {
         return PinWrapper<std::vector<VectorArrayView>>(
             ca, static_cast<VectorArrayChunk*>(chunk)->Views());
     }
+
+    PinWrapper<const size_t*>
+    VectorArrayLims(int64_t chunk_id) const override {
+        auto ca =
+            SemiInlineGet(slot_->PinCells({static_cast<cid_t>(chunk_id)}));
+        auto chunk = ca->get_cell_of(chunk_id);
+        return PinWrapper<const size_t*>(
+            ca, static_cast<VectorArrayChunk*>(chunk)->Lims());
+    }
 };

 inline std::shared_ptr<ChunkedColumnInterface>

View File

@@ -319,6 +319,19 @@ class ProxyChunkColumn : public ChunkedColumnInterface {
             static_cast<VectorArrayChunk*>(chunk.get())->Views());
     }

+    PinWrapper<const size_t*>
+    VectorArrayLims(int64_t chunk_id) const override {
+        if (!IsChunkedVectorArrayColumnDataType(data_type_)) {
+            ThrowInfo(ErrorCode::Unsupported,
+                      "VectorArrayLims only supported for "
+                      "ChunkedVectorArrayColumn");
+        }
+        auto chunk_wrapper = group_->GetGroupChunk(chunk_id);
+        auto chunk = chunk_wrapper.get()->GetChunk(field_id_);
+        return PinWrapper<const size_t*>(
+            chunk_wrapper, static_cast<VectorArrayChunk*>(chunk.get())->Lims());
+    }
+
     PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>
     StringViewsByOffsets(int64_t chunk_id,
                          const FixedVector<int32_t>& offsets) const override {

View File

@@ -84,6 +84,9 @@ class ChunkedColumnInterface {
     virtual PinWrapper<std::vector<VectorArrayView>>
     VectorArrayViews(int64_t chunk_id) const = 0;

+    virtual PinWrapper<const size_t*>
+    VectorArrayLims(int64_t chunk_id) const = 0;
+
     virtual PinWrapper<
         std::pair<std::vector<std::string_view>, FixedVector<bool>>>
     StringViewsByOffsets(int64_t chunk_id,

View File

@@ -242,4 +242,9 @@ ExecPlanNodeVisitor::visit(Int8VectorANNS& node) {
     VectorVisitorImpl<Int8Vector>(node);
 }

+void
+ExecPlanNodeVisitor::visit(EmbListFloatVectorANNS& node) {
+    VectorVisitorImpl<EmbListFloatVector>(node);
+}
+
 }  // namespace milvus::query

View File

@@ -43,6 +43,9 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
     void
     visit(RetrievePlanNode& node) override;

+    void
+    visit(EmbListFloatVectorANNS& node) override;
+
  public:
     ExecPlanNodeVisitor(const segcore::SegmentInterface& segment,
                        Timestamp timestamp,

View File

@@ -30,6 +30,17 @@ ParsePlaceholderGroup(const Plan* plan,
                                  placeholder_group_blob.size());
 }

+bool
+check_data_type(const FieldMeta& field_meta,
+                const milvus::proto::common::PlaceholderType type) {
+    if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
+        return type ==
+               milvus::proto::common::PlaceholderType::EmbListFloatVector;
+    }
+    return static_cast<int>(field_meta.get_data_type()) ==
+           static_cast<int>(type);
+}
+
 std::unique_ptr<PlaceholderGroup>
 ParsePlaceholderGroup(const Plan* plan,
                       const uint8_t* blob,
@@ -44,8 +55,7 @@ ParsePlaceholderGroup(const Plan* plan,
         Assert(plan->tag2field_.count(element.tag_));
         auto field_id = plan->tag2field_.at(element.tag_);
         auto& field_meta = plan->schema_->operator[](field_id);
-        AssertInfo(static_cast<int>(field_meta.get_data_type()) ==
-                       static_cast<int>(info.type()),
+        AssertInfo(check_data_type(field_meta, info.type()),
                    "vector type must be the same, field {} - type {}, search "
                    "info type {}",
                    field_meta.get_name().get(),
@@ -59,23 +69,47 @@ ParsePlaceholderGroup(const Plan* plan,
                 SparseBytesToRows(info.values(), /*validate=*/true);
         } else {
             auto line_size = info.values().Get(0).size();
-            if (field_meta.get_sizeof() != line_size) {
-                ThrowInfo(
-                    DimNotMatch,
-                    fmt::format("vector dimension mismatch, expected vector "
-                                "size(byte) {}, actual {}.",
-                                field_meta.get_sizeof(),
-                                line_size));
-            }
             auto& target = element.blob_;
-            target.reserve(line_size * element.num_of_queries_);
-            for (auto& line : info.values()) {
-                AssertInfo(line_size == line.size(),
-                           "vector dimension mismatch, expected vector "
-                           "size(byte) {}, actual {}.",
-                           line_size,
-                           line.size());
-                target.insert(target.end(), line.begin(), line.end());
+
+            if (field_meta.get_data_type() != DataType::VECTOR_ARRAY) {
+                if (field_meta.get_sizeof() != line_size) {
+                    ThrowInfo(DimNotMatch,
+                              fmt::format(
+                                  "vector dimension mismatch, expected vector "
+                                  "size(byte) {}, actual {}.",
+                                  field_meta.get_sizeof(),
+                                  line_size));
+                }
+                target.reserve(line_size * element.num_of_queries_);
+                for (auto& line : info.values()) {
+                    AssertInfo(line_size == line.size(),
+                               "vector dimension mismatch, expected vector "
+                               "size(byte) {}, actual {}.",
+                               line_size,
+                               line.size());
+                    target.insert(target.end(), line.begin(), line.end());
+                }
+            } else {
+                target.reserve(line_size * element.num_of_queries_);
+                auto dim = field_meta.get_dim();
+                // If the vector is embedding list, line contains multiple vectors.
+                // And we should record the offsets so that we can identify each
+                // embedding list in a flattened vectors.
+                auto& lims = element.lims_;
+                lims.reserve(element.num_of_queries_ + 1);
+                size_t offset = 0;
+                lims.push_back(offset);
+                auto elem_size = milvus::index::vector_element_size(
+                    field_meta.get_element_type());
+                for (auto& line : info.values()) {
+                    target.insert(target.end(), line.begin(), line.end());
+                    Assert(line.size() % (dim * elem_size) == 0);
+                    offset += line.size() / (dim * elem_size);
+                    lims.push_back(offset);
+                }
             }
         }
         result->emplace_back(std::move(element));

View File

@@ -68,6 +68,9 @@ struct Plan {

 struct Placeholder {
     std::string tag_;
+    // note: for embedding list search, num_of_queries_ stands for the number of vectors.
+    // lims_ records the offsets of embedding list in the flattened vector and
+    // hence lims_.size() - 1 is the number of queries in embedding list search.
     int64_t num_of_queries_;
     // TODO(SPARSE): add a dim_ field here, use the dim passed in search request
     // instead of the dim in schema, since the dim of sparse float column is
@@ -78,6 +81,8 @@ struct Placeholder {
     // dense vector search and sparse_matrix_ is for sparse vector search.
     aligned_vector<char> blob_;
     std::unique_ptr<knowhere::sparse::SparseRow<float>[]> sparse_matrix_;
+    // offsets for embedding list
+    aligned_vector<size_t> lims_;

     const void*
     get_blob() const {
@@ -94,6 +99,16 @@ struct Placeholder {
         }
         return blob_.data();
     }
+
+    const size_t*
+    get_lims() const {
+        return lims_.data();
+    }
+
+    size_t*
+    get_lims() {
+        return lims_.data();
+    }
 };

 struct RetrievePlan {

View File

@@ -50,4 +50,9 @@ RetrievePlanNode::accept(PlanNodeVisitor& visitor) {
     visitor.visit(*this);
 }

+void
+EmbListFloatVectorANNS::accept(PlanNodeVisitor& visitor) {
+    visitor.visit(*this);
+}
+
 }  // namespace milvus::query

View File

@@ -77,6 +77,12 @@ struct Int8VectorANNS : VectorPlanNode {
     accept(PlanNodeVisitor&) override;
 };

+struct EmbListFloatVectorANNS : VectorPlanNode {
+ public:
+    void
+    accept(PlanNodeVisitor&) override;
+};
+
 struct RetrievePlanNode : PlanNode {
  public:
     void

View File

@@ -39,5 +39,8 @@ class PlanNodeVisitor {
     virtual void
     visit(RetrievePlanNode&) = 0;
+
+    virtual void
+    visit(EmbListFloatVectorANNS&) = 0;
 };

 }  // namespace milvus::query

View File

@@ -127,6 +127,9 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
         } else if (anns_proto.vector_type() ==
                    milvus::proto::plan::VectorType::Int8Vector) {
             return std::make_unique<Int8VectorANNS>();
+        } else if (anns_proto.vector_type() ==
+                   milvus::proto::plan::VectorType::EmbListFloatVector) {
+            return std::make_unique<EmbListFloatVectorANNS>();
         } else {
             return std::make_unique<FloatVectorANNS>();
         }

View File

@@ -89,8 +89,23 @@ PrepareBFDataSet(const dataset::SearchDataset& query_ds,
                  DataType data_type) {
     auto base_dataset =
         knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data);
+    if (raw_ds.raw_data_lims != nullptr) {
+        // knowhere::DataSet count vectors in a flattened manner where as the num_raw_data here is the number
+        // of embedding lists where each embedding list contains multiple vectors. So we should use the last element
+        // in lims which equals to the total number of vectors.
+        base_dataset->SetLims(raw_ds.raw_data_lims);
+        // the length of lims equals to the number of embedding lists + 1
+        base_dataset->SetRows(raw_ds.raw_data_lims[raw_ds.num_raw_data]);
+    }
+
     auto query_dataset = knowhere::GenDataSet(
         query_ds.num_queries, query_ds.dim, query_ds.query_data);
+    if (query_ds.query_lims != nullptr) {
+        // ditto
+        query_dataset->SetLims(query_ds.query_lims);
+        query_dataset->SetRows(query_ds.query_lims[query_ds.num_queries]);
+    }
+
     if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
         base_dataset->SetIsSparse(true);
         query_dataset->SetIsSparse(true);
@@ -105,7 +120,8 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
                  const SearchInfo& search_info,
                  const std::map<std::string, std::string>& index_info,
                  const BitsetView& bitset,
-                 DataType data_type) {
+                 DataType data_type,
+                 DataType element_type) {
     SubSearchResult sub_result(query_ds.num_queries,
                                query_ds.topk,
                                query_ds.metric_type,
@@ -122,7 +138,18 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
     sub_result.mutable_seg_offsets().resize(nq * topk);
     sub_result.mutable_distances().resize(nq * topk);

+    // For vector array (embedding list), element type is used to determine how to operate search.
+    if (data_type == DataType::VECTOR_ARRAY) {
+        AssertInfo(element_type != DataType::NONE,
+                   "Element type is not specified for vector array");
+        data_type = element_type;
+    }
+
     if (search_cfg.contains(RADIUS)) {
+        AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                   "Vector array(embedding list) is not supported for range "
+                   "search");
         if (search_cfg.contains(RANGE_FILTER)) {
             CheckRangeSearchParam(search_cfg[RADIUS],
                                   search_cfg[RANGE_FILTER],
@@ -238,7 +265,10 @@ DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset,
                                      const knowhere::DataSetPtr& query_dataset,
                                      const knowhere::Json& config,
                                      const BitsetView& bitset,
-                                     const milvus::DataType& data_type) {
+                                     milvus::DataType data_type) {
+    AssertInfo(data_type != DataType::VECTOR_ARRAY,
+               "VECTOR_ARRAY is not supported for brute force iterator");
     switch (data_type) {
         case DataType::VECTOR_FLOAT:
             return knowhere::BruteForce::AnnIterator<float>(

View File

@@ -29,7 +29,8 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
                  const SearchInfo& search_info,
                  const std::map<std::string, std::string>& index_info,
                  const BitsetView& bitset,
-                 DataType data_type);
+                 DataType data_type,
+                 DataType element_type);

 knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
 GetBruteForceSearchIterators(

View File

@@ -71,6 +71,7 @@ void
 SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
                 const SearchInfo& info,
                 const void* query_data,
+                const size_t* query_lims,
                 int64_t num_queries,
                 Timestamp timestamp,
                 const BitsetView& bitset,
@@ -87,6 +88,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
     CheckBruteForceSearchParam(field, info);

     auto data_type = field.get_data_type();
+    auto element_type = field.get_element_type();
     AssertInfo(IsVectorDataType(data_type),
                "[SearchOnGrowing]Data type isn't vector type");

@@ -96,6 +98,11 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
     // step 2: small indexing search
     if (segment.get_indexing_record().SyncDataWithIndex(field.get_id())) {
+        AssertInfo(
+            data_type != DataType::VECTOR_ARRAY,
+            "vector array(embedding list) is not supported for growing segment "
+            "indexing search");
         FloatSegmentIndexSearch(
             segment, info, query_data, num_queries, bitset, search_result);
     } else {
@@ -103,6 +110,10 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
             segment.get_chunk_mutex());
         // check SyncDataWithIndex() again, in case the vector chunks has been removed.
         if (segment.get_indexing_record().SyncDataWithIndex(field.get_id())) {
+            AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                       "vector array(embedding list) is not supported for "
+                       "growing segment indexing search");
             return FloatSegmentIndexSearch(
                 segment, info, query_data, num_queries, bitset, search_result);
         }
@@ -111,8 +122,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
         auto dim = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT
                        ? 0
                        : field.get_dim();
-        dataset::SearchDataset search_dataset{
-            metric_type, num_queries, topk, round_decimal, dim, query_data};
+        dataset::SearchDataset search_dataset{metric_type,
+                                              num_queries,
+                                              topk,
+                                              round_decimal,
+                                              dim,
+                                              query_data,
+                                              query_lims};
         int32_t current_chunk_id = 0;

         // get K1 and B from index for bm25 brute force
@@ -127,6 +143,10 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
         auto vec_ptr = record.get_data_base(vecfield_id);

         if (info.iterator_v2_info_.has_value()) {
+            AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                       "vector array(embedding list) is not supported for "
+                       "vector iterator");
             CachedSearchIterator cached_iter(search_dataset,
                                              vec_ptr,
                                              active_count,
@@ -150,9 +170,54 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
                 std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
             auto size_per_chunk = element_end - element_begin;

-            auto sub_data = query::dataset::RawDataset{
-                element_begin, dim, size_per_chunk, chunk_data};
+            query::dataset::RawDataset sub_data;
+            std::unique_ptr<uint8_t[]> buf = nullptr;
+            std::vector<size_t> offsets;
+            if (data_type != DataType::VECTOR_ARRAY) {
+                sub_data = query::dataset::RawDataset{
+                    element_begin, dim, size_per_chunk, chunk_data};
+            } else {
+                // TODO(SpadeA): For VectorArray(Embedding List), data is
+                // discreted stored in FixedVector which means we will copy the
+                // data to a contiguous memory buffer. This is inefficient and
+                // will be optimized in the future.
+                auto vec_ptr = reinterpret_cast<const VectorArray*>(chunk_data);
+                auto size = 0;
+                for (int i = 0; i < size_per_chunk; ++i) {
+                    size += vec_ptr[i].byte_size();
+                }
+                buf = std::make_unique<uint8_t[]>(size);
+                offsets.reserve(size_per_chunk + 1);
+                offsets.push_back(0);
+                auto offset = 0;
+                auto ptr = buf.get();
+                for (int i = 0; i < size_per_chunk; ++i) {
+                    memcpy(ptr, vec_ptr[i].data(), vec_ptr[i].byte_size());
+                    ptr += vec_ptr[i].byte_size();
+                    offset += vec_ptr[i].length();
+                    offsets.push_back(offset);
+                }
+                sub_data = query::dataset::RawDataset{element_begin,
+                                                      dim,
+                                                      size_per_chunk,
+                                                      buf.get(),
+                                                      offsets.data()};
+            }
+
+            if (data_type == DataType::VECTOR_ARRAY) {
+                AssertInfo(
+                    query_lims != nullptr,
+                    "query_lims is nullptr, but data_type is vector array");
+            }
+
             if (milvus::exec::UseVectorIterator(info)) {
+                AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                           "vector array(embedding list) is not supported for "
+                           "vector iterator");
                 auto sub_qr =
                     PackBruteForceSearchIteratorsIntoSubResult(search_dataset,
                                                                sub_data,
@@ -167,7 +232,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
                                                   info,
                                                   index_info,
                                                   bitset,
-                                                  data_type);
+                                                  data_type,
+                                                  element_type);
                 final_qr.merge(sub_qr);
             }
         }

View File

@@ -20,6 +20,7 @@ void
 SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
                 const SearchInfo& info,
                 const void* query_data,
+                const size_t* query_lims,
                 int64_t num_queries,
                 Timestamp timestamp,
                 const BitsetView& bitset,

View File

@@ -31,6 +31,7 @@ SearchOnSealedIndex(const Schema& schema,
                     const segcore::SealedIndexingRecord& record,
                     const SearchInfo& search_info,
                     const void* query_data,
+                    const size_t* query_lims,
                     int64_t num_queries,
                     const BitsetView& bitset,
                     SearchResult& search_result) {
@@ -52,7 +53,18 @@ SearchOnSealedIndex(const Schema& schema,
                field_indexing->metric_type_,
                search_info.metric_type_);

-    auto dataset = knowhere::GenDataSet(num_queries, dim, query_data);
+    knowhere::DataSetPtr dataset;
+    if (query_lims == nullptr) {
+        dataset = knowhere::GenDataSet(num_queries, dim, query_data);
+    } else {
+        // Rather than non-embedding list search where num_queries equals to the number of vectors,
+        // in embedding list search, multiple vectors form an embedding list and the last element of query_lims
+        // stands for the total number of vectors.
+        auto num_vectors = query_lims[num_queries];
+        dataset = knowhere::GenDataSet(num_vectors, dim, query_data);
+        dataset->SetLims(query_lims);
+    }
     dataset->SetIsSparse(is_sparse);
     auto accessor = SemiInlineGet(field_indexing->indexing_->PinCells({0}));
     auto vec_index =
@@ -92,6 +104,7 @@ SearchOnSealedColumn(const Schema& schema,
                      const SearchInfo& search_info,
                      const std::map<std::string, std::string>& index_info,
                      const void* query_data,
+                     const size_t* query_lims,
                      int64_t num_queries,
                      int64_t row_count,
                      const BitsetView& bitview,
@@ -99,22 +112,26 @@ SearchOnSealedColumn(const Schema& schema,
     auto field_id = search_info.field_id_;
     auto& field = schema[field_id];

+    auto data_type = field.get_data_type();
+    auto element_type = field.get_element_type();
+
     // TODO(SPARSE): see todo in PlanImpl.h::PlaceHolder.
-    auto dim = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT
-                   ? 0
-                   : field.get_dim();
+    auto dim = data_type == DataType::VECTOR_SPARSE_FLOAT ? 0 : field.get_dim();

     query::dataset::SearchDataset query_dataset{search_info.metric_type_,
                                                 num_queries,
                                                 search_info.topk_,
                                                 search_info.round_decimal_,
                                                 dim,
-                                                query_data};
+                                                query_data,
+                                                query_lims};

-    auto data_type = field.get_data_type();
     CheckBruteForceSearchParam(field, search_info);

     if (search_info.iterator_v2_info_.has_value()) {
+        AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                   "vector array(embedding list) is not supported for "
+                   "vector iterator");
         CachedSearchIterator cached_iter(
             column, query_dataset, search_info, index_info, bitview, data_type);
         cached_iter.NextBatch(search_info, result);
@@ -135,7 +152,20 @@ SearchOnSealedColumn(const Schema& schema,
         auto chunk_size = column->chunk_row_nums(i);
         auto raw_dataset =
             query::dataset::RawDataset{offset, dim, chunk_size, vec_data};
+
+        PinWrapper<const size_t*> lims_pw;
+        if (data_type == DataType::VECTOR_ARRAY) {
+            AssertInfo(query_lims != nullptr,
+                       "query_lims is nullptr, but data_type is vector array");
+            lims_pw = column->VectorArrayLims(i);
+            raw_dataset.raw_data_lims = lims_pw.get();
+        }
+
         if (milvus::exec::UseVectorIterator(search_info)) {
+            AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                       "vector array(embedding list) is not supported for "
+                       "vector iterator");
             auto sub_qr =
                 PackBruteForceSearchIteratorsIntoSubResult(query_dataset,
                                                            raw_dataset,
@@ -150,7 +180,8 @@ SearchOnSealedColumn(const Schema& schema,
                                               search_info,
                                               index_info,
                                               bitview,
-                                              data_type);
+                                              data_type,
+                                              element_type);
             final_qr.merge(sub_qr);
         }
         offset += chunk_size;

View File

@@ -23,6 +23,7 @@ SearchOnSealedIndex(const Schema& schema,
                     const segcore::SealedIndexingRecord& record,
                     const SearchInfo& search_info,
                     const void* query_data,
+                    const size_t* query_lims,
                     int64_t num_queries,
                     const BitsetView& view,
                     SearchResult& search_result);
@@ -33,6 +34,7 @@ SearchOnSealedColumn(const Schema& schema,
                      const SearchInfo& search_info,
                      const std::map<std::string, std::string>& index_info,
                      const void* query_data,
+                     const size_t* query_lims,
                      int64_t num_queries,
                      int64_t row_count,
                      const BitsetView& bitset,

View File

@@ -24,6 +24,7 @@ struct RawDataset {
     int64_t dim;
     int64_t num_raw_data;
     const void* raw_data;
+    const size_t* raw_data_lims = nullptr;
 };
 struct SearchDataset {
     knowhere::MetricType metric_type;
@@ -32,6 +33,8 @@ struct SearchDataset {
     int64_t round_decimal;
     int64_t dim;
     const void* query_data;
+    // used for embedding list query
+    const size_t* query_lims = nullptr;
 };
 }  // namespace dataset

View File

@@ -95,10 +95,6 @@ ChunkedSegmentSealedImpl::LoadIndex(const LoadIndexInfo& info) {
     auto field_id = FieldId(info.field_id);
     auto& field_meta = schema_->operator[](field_id);

-    if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
-        ThrowInfo(DataTypeInvalid, "VECTOR_ARRAY is not implemented");
-    }
-
     if (field_meta.is_vector()) {
         LoadVecIndex(info);
     } else {
@@ -127,6 +123,7 @@ ChunkedSegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
     LoadResourceRequest request =
         milvus::index::IndexFactory::GetInstance().VecIndexLoadResource(
             field_meta.get_data_type(),
+            info.element_type,
             info.index_engine_version,
             info.index_size,
             info.index_params,
@@ -498,10 +495,6 @@ int64_t
 ChunkedSegmentSealedImpl::num_chunk_index(FieldId field_id) const {
     auto& field_meta = schema_->operator[](field_id);

-    if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
-        ThrowInfo(DataTypeInvalid, "VECTOR_ARRAY is not implemented");
-    }
-
     if (field_meta.is_vector()) {
         return int64_t(vector_indexings_.is_ready(field_id));
     }
@@ -720,6 +713,7 @@ ChunkedSegmentSealedImpl::mask_with_delete(BitsetTypeView& bitset,
 void
 ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
                                         const void* query_data,
+                                        const size_t* query_lims,
                                         int64_t query_count,
                                         Timestamp timestamp,
                                         const BitsetView& bitset,
@@ -745,6 +739,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
                             vector_indexings_,
                             binlog_search_info,
                             query_data,
+                            query_lims,
                             query_count,
                             bitset,
                             output);
@@ -758,6 +753,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
                             vector_indexings_,
                             search_info,
                             query_data,
+                            query_lims,
                             query_count,
                             bitset,
                             output);
@@ -782,6 +778,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
                              search_info,
                              index_info,
                              query_data,
+                             query_lims,
                              query_count,
                              row_count,
                              bitset,

View File

@@ -395,6 +395,7 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
     void
     vector_search(SearchInfo& search_info,
                   const void* query_data,
+                  const size_t* query_lims,
                   int64_t query_count,
                   Timestamp timestamp,
                   const BitsetView& bitset,

View File

@@ -47,6 +47,7 @@ VectorFieldIndexing::recreate_index(DataType data_type,
                                     const VectorBase* field_raw_data) {
     if (IsSparseFloatVectorDataType(data_type)) {
         index_ = std::make_unique<index::VectorMemIndex<float>>(
+            DataType::NONE,
             config_->GetIndexType(),
             config_->GetMetricType(),
             knowhere::Version::GetCurrentVersion().VersionNumber());
@@ -62,6 +63,7 @@ VectorFieldIndexing::recreate_index(DataType data_type,
             return (const void*)field_raw_data_ptr->get_element(id);
         };
         index_ = std::make_unique<index::VectorMemIndex<float>>(
+            DataType::NONE,
             config_->GetIndexType(),
             config_->GetMetricType(),
             knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -78,6 +80,7 @@ VectorFieldIndexing::recreate_index(DataType data_type,
             return (const void*)field_raw_data_ptr->get_element(id);
         };
         index_ = std::make_unique<index::VectorMemIndex<float16>>(
+            DataType::NONE,
             config_->GetIndexType(),
             config_->GetMetricType(),
             knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -94,6 +97,7 @@ VectorFieldIndexing::recreate_index(DataType data_type,
             return (const void*)field_raw_data_ptr->get_element(id);
         };
         index_ = std::make_unique<index::VectorMemIndex<bfloat16>>(
+            DataType::NONE,
             config_->GetIndexType(),
             config_->GetMetricType(),
             knowhere::Version::GetCurrentVersion().VersionNumber(),
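
The recurring `DataType::NONE` first argument reflects a new leading parameter on the `VectorMemIndex` constructor, which appears to carry the leaf element type for embedding-list indexes. A minimal sketch of the new call shape for a plain vector field, mirroring the call sites above (the HNSW index type and L2 metric are placeholders, not taken from this hunk):

```cpp
// Sketch only: plain (non embedding-list) vector fields pass DataType::NONE
// as the new first argument; an embedding-list index would presumably carry
// the leaf vector type instead.
auto index = std::make_unique<milvus::index::VectorMemIndex<float>>(
    milvus::DataType::NONE,           // element type: NONE for a plain field
    knowhere::IndexEnum::INDEX_HNSW,  // placeholder index type
    knowhere::metric::L2,             // placeholder metric
    knowhere::Version::GetCurrentVersion().VersionNumber());
```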

View File

@@ -292,8 +292,9 @@ class IndexingRecord {
                 index_meta_->HasFiled(field_id)) {
                 auto vec_field_meta =
                     index_meta_->GetFieldIndexMeta(field_id);
-                //Disable growing index for flat
-                if (!vec_field_meta.IsFlatIndex()) {
+                //Disable growing index for flat and embedding list
+                if (!vec_field_meta.IsFlatIndex() &&
+                    field_meta.get_data_type() != DataType::VECTOR_ARRAY) {
                     auto field_raw_data =
                         insert_record->get_data_base(field_id);
                     field_indexings_.try_emplace(

View File

@@ -695,12 +695,19 @@ SegmentGrowingImpl::search_batch_pks(
 void
 SegmentGrowingImpl::vector_search(SearchInfo& search_info,
                                   const void* query_data,
+                                  const size_t* query_lims,
                                   int64_t query_count,
                                   Timestamp timestamp,
                                   const BitsetView& bitset,
                                   SearchResult& output) const {
-    query::SearchOnGrowing(
-        *this, search_info, query_data, query_count, timestamp, bitset, output);
+    query::SearchOnGrowing(*this,
+                           search_info,
+                           query_data,
+                           query_lims,
+                           query_count,
+                           timestamp,
+                           bitset,
+                           output);
 }

 std::unique_ptr<DataArray>

View File

@@ -326,6 +326,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
     void
     vector_search(SearchInfo& search_info,
                   const void* query_data,
+                  const size_t* query_lims,
                   int64_t query_count,
                   Timestamp timestamp,
                   const BitsetView& bitset,

View File

@@ -394,9 +394,14 @@ class SegmentInternalInterface : public SegmentInterface {
         const std::string& nested_path) const override;

 public:
+    // `query_lims` is not null only for vector array (embedding list) search,
+    // where it denotes the number of vectors in each embedding list. The
+    // length of `query_lims` is the number of queries in the search plus one
+    // (the first element in query_lims is 0).
     virtual void
     vector_search(SearchInfo& search_info,
                   const void* query_data,
+                  const size_t* query_lims,
                   int64_t query_count,
                   Timestamp timestamp,
                   const BitsetView& bitset,
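
To make the `query_lims` convention concrete, here is a small illustrative sketch (not part of the patch): two embedding-list queries packed into one flat buffer, so `query_count` is 2 while the buffer holds 10 vectors.

```cpp
#include <cstddef>
#include <vector>

int main() {
    const size_t dim = 128;
    // Query 0 owns vectors [0, 3); query 1 owns vectors [3, 10).
    std::vector<float> query_data(10 * dim);
    std::vector<size_t> query_lims = {0, 3, 10};  // query_count + 1 entries
    const size_t query_count = query_lims.size() - 1;
    for (size_t q = 0; q < query_count; ++q) {
        const float* first = query_data.data() + query_lims[q] * dim;
        const size_t num_vectors = query_lims[q + 1] - query_lims[q];
        (void)first;        // each (first, num_vectors) slice is one
        (void)num_vectors;  // embedding-list query
    }
    return 0;
}
```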

View File

@@ -35,6 +35,8 @@ struct LoadIndexInfo {
     int64_t segment_id;
     int64_t field_id;
     DataType field_type;
+    // The element type of the field. It's DataType::NONE unless field_type is array/vector_array.
+    DataType element_type;
     bool enable_mmap;
     std::string mmap_dir_path;
     int64_t index_id;
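
A short sketch of how the two fields relate for an embedding-list index, mirroring the `LoadIndexInfo` usage in the sealed-segment test later in this diff:

```cpp
// Sketch: loading an index built on a float-vector embedding list.
milvus::segcore::LoadIndexInfo load_info;
load_info.field_type = milvus::DataType::VECTOR_ARRAY;    // outer field type
load_info.element_type = milvus::DataType::VECTOR_FLOAT;  // leaf vector type
```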

View File

@@ -668,7 +668,11 @@ MergeDataArray(std::vector<MergeBase>& merge_bases,
             auto obj = vector_array->mutable_int8_vector();
             obj->assign(data, dim * sizeof(int8));
         } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
-            ThrowInfo(DataTypeInvalid, "VECTOR_ARRAY is not implemented");
+            auto data = src_field_data->vectors().vector_array();
+            auto obj = vector_array->mutable_vector_array();
+            obj->set_element_type(
+                proto::schema::DataType(field_meta.get_element_type()));
+            obj->CopyFrom(data);
         } else {
             ThrowInfo(DataTypeInvalid,
                       fmt::format("unsupported datatype {}", data_type));

View File

@@ -15,7 +15,11 @@
 #include "knowhere/comp/knowhere_check.h"

 bool
-CheckVecIndexWithDataType(const char* index_type, enum CDataType data_type) {
+CheckVecIndexWithDataType(const char* index_type,
+                          enum CDataType data_type,
+                          bool is_emb_list_data) {
     return knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(
-        std::string(index_type), knowhere::VecType(data_type));
+        std::string(index_type),
+        knowhere::VecType(data_type),
+        is_emb_list_data);
 }

View File

@@ -16,7 +16,9 @@ extern "C" {
 #endif

 #include "common/type_c.h"

 bool
-CheckVecIndexWithDataType(const char* index_type, enum CDataType data_type);
+CheckVecIndexWithDataType(const char* index_type,
+                          enum CDataType data_type,
+                          bool is_emb_list_data);

 #ifdef __cplusplus
 }

View File

@@ -98,40 +98,6 @@ AppendIndexParam(CLoadIndexInfo c_load_index_info,
     }
 }

-CStatus
-AppendFieldInfo(CLoadIndexInfo c_load_index_info,
-                int64_t collection_id,
-                int64_t partition_id,
-                int64_t segment_id,
-                int64_t field_id,
-                enum CDataType field_type,
-                bool enable_mmap,
-                const char* mmap_dir_path) {
-    SCOPE_CGO_CALL_METRIC();
-
-    try {
-        auto load_index_info =
-            (milvus::segcore::LoadIndexInfo*)c_load_index_info;
-        load_index_info->collection_id = collection_id;
-        load_index_info->partition_id = partition_id;
-        load_index_info->segment_id = segment_id;
-        load_index_info->field_id = field_id;
-        load_index_info->field_type = milvus::DataType(field_type);
-        load_index_info->enable_mmap = enable_mmap;
-        load_index_info->mmap_dir_path = std::string(mmap_dir_path);
-
-        auto status = CStatus();
-        status.error_code = milvus::Success;
-        status.error_msg = "";
-        return status;
-    } catch (std::exception& e) {
-        auto status = CStatus();
-        status.error_code = milvus::UnexpectedError;
-        status.error_msg = strdup(e.what());
-        return status;
-    }
-}
-
 CStatus
 appendVecIndex(CLoadIndexInfo c_load_index_info, CBinarySet c_binary_set) {
     SCOPE_CGO_CALL_METRIC();
@@ -252,6 +218,7 @@ EstimateLoadIndexResource(CLoadIndexInfo c_load_index_info) {
         auto load_index_info =
             (milvus::segcore::LoadIndexInfo*)c_load_index_info;
         auto field_type = load_index_info->field_type;
+        auto element_type = load_index_info->element_type;
         auto& index_params = load_index_info->index_params;
         bool find_index_type =
             index_params.count("index_type") > 0 ? true : false;
@@ -261,6 +228,7 @@ EstimateLoadIndexResource(CLoadIndexInfo c_load_index_info) {
         LoadResourceRequest request =
             milvus::index::IndexFactory::GetInstance().IndexLoadResource(
                 field_type,
+                element_type,
                 load_index_info->index_engine_version,
                 load_index_info->index_size,
                 index_params,
@@ -581,6 +549,8 @@ FinishLoadIndexInfo(CLoadIndexInfo c_load_index_info,
         load_index_info->field_id = info_proto->field().fieldid();
         load_index_info->field_type =
             static_cast<milvus::DataType>(info_proto->field().data_type());
+        load_index_info->element_type = static_cast<milvus::DataType>(
+            info_proto->field().element_type());
         load_index_info->enable_mmap = info_proto->enable_mmap();
         load_index_info->mmap_dir_path = info_proto->mmap_dir_path();
         load_index_info->index_id = info_proto->indexid();

View File

@@ -38,16 +38,6 @@ AppendIndexParam(CLoadIndexInfo c_load_index_info,
                  const char* index_key,
                  const char* index_value);

-CStatus
-AppendFieldInfo(CLoadIndexInfo c_load_index_info,
-                int64_t collection_id,
-                int64_t partition_id,
-                int64_t segment_id,
-                int64_t field_id,
-                enum CDataType field_type,
-                bool enable_mmap,
-                const char* mmap_dir_path);
-
 LoadResourceRequest
 EstimateLoadIndexResource(CLoadIndexInfo c_load_index_info);

View File

@@ -440,7 +440,10 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
                 ->set_element_type(
                     proto::schema::DataType(field_meta.get_element_type()));
         } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
-            ThrowInfo(NotImplemented, "VECTOR_ARRAY is not implemented");
+            field_data->mutable_vectors()
+                ->mutable_vector_array()
+                ->set_element_type(
+                    proto::schema::DataType(field_meta.get_element_type()));
         }
         search_result_data->mutable_fields_data()->AddAllocated(
             field_data.release());

View File

@@ -155,7 +155,10 @@ StreamReducerHelper::AssembleMergedResult() {
                     ->set_element_type(
                         proto::schema::DataType(field_meta.get_element_type()));
             } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
-                ThrowInfo(NotImplemented, "VECTOR_ARRAY is not implemented");
+                field_data->mutable_vectors()
+                    ->mutable_vector_array()
+                    ->set_element_type(
+                        proto::schema::DataType(field_meta.get_element_type()));
             }

             new_merged_result->output_fields_data_[field_id] =
@@ -674,7 +677,10 @@ StreamReducerHelper::GetSearchResultDataSlice(int slice_index) {
                 ->set_element_type(
                     proto::schema::DataType(field_meta.get_element_type()));
         } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
-            ThrowInfo(NotImplemented, "VECTOR_ARRAY is not implemented");
+            field_data->mutable_vectors()
+                ->mutable_vector_array()
+                ->set_element_type(
+                    proto::schema::DataType(field_meta.get_element_type()));
         }
         search_result_data->mutable_fields_data()->AddAllocated(
             field_data.release());

View File

@@ -81,6 +81,7 @@ InterimSealedIndexTranslator::get_cells(
         if (vec_data_type_ == DataType::VECTOR_FLOAT) {
             vec_index = std::make_unique<index::VectorMemIndex<float>>(
+                DataType::NONE,
                 index_type_,
                 metric_type_,
                 knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -88,6 +89,7 @@ InterimSealedIndexTranslator::get_cells(
                 false);
         } else if (vec_data_type_ == DataType::VECTOR_FLOAT16) {
             vec_index = std::make_unique<index::VectorMemIndex<knowhere::fp16>>(
+                DataType::NONE,
                 index_type_,
                 metric_type_,
                 knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -95,6 +97,7 @@ InterimSealedIndexTranslator::get_cells(
                 false);
         } else if (vec_data_type_ == DataType::VECTOR_BFLOAT16) {
             vec_index = std::make_unique<index::VectorMemIndex<knowhere::bf16>>(
+                DataType::NONE,
                 index_type_,
                 metric_type_,
                 knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -103,6 +106,7 @@ InterimSealedIndexTranslator::get_cells(
         }
     } else {
         vec_index = std::make_unique<index::VectorMemIndex<float>>(
+            DataType::NONE,
             index_type_,
             metric_type_,
             knowhere::Version::GetCurrentVersion().VersionNumber(),

View File

@@ -22,6 +22,7 @@ SealedIndexTranslator::SealedIndexTranslator(
       index_load_info_({load_index_info->enable_mmap,
                         load_index_info->mmap_dir_path,
                         load_index_info->field_type,
+                        load_index_info->element_type,
                         load_index_info->index_params,
                         load_index_info->index_size,
                         load_index_info->index_engine_version,
@@ -54,6 +55,7 @@ SealedIndexTranslator::estimated_byte_size_of_cell(
     LoadResourceRequest request =
         milvus::index::IndexFactory::GetInstance().IndexLoadResource(
             index_load_info_.field_type,
+            index_load_info_.element_type,
             index_load_info_.index_engine_version,
             index_load_info_.index_size,
             index_load_info_.index_params,

View File

@@ -45,6 +45,7 @@ class SealedIndexTranslator
         bool enable_mmap;
         std::string mmap_dir_path;
         DataType field_type;
+        DataType element_type;
         std::map<std::string, std::string> index_params;
         int64_t index_size;
         int64_t index_engine_version;

View File

@@ -13,6 +13,7 @@ V1SealedIndexTranslator::V1SealedIndexTranslator(
           load_index_info->enable_mmap,
           load_index_info->mmap_dir_path,
           load_index_info->field_type,
+          load_index_info->element_type,
           load_index_info->index_params,
           load_index_info->index_files,
           load_index_info->index_size,

View File

@@ -44,6 +44,7 @@ class V1SealedIndexTranslator : public Translator<milvus::index::IndexBase> {
         bool enable_mmap;
         std::string mmap_dir_path;
         DataType field_type;
+        DataType element_type;
         std::map<std::string, std::string> index_params;
         std::vector<std::string> index_files;
         int64_t index_size;

View File

@@ -24,6 +24,7 @@
 CStatus
 ValidateIndexParams(const char* index_type,
                     enum CDataType data_type,
+                    enum CDataType element_type,
                     const uint8_t* serialized_index_params,
                     const uint64_t length) {
     try {
@@ -44,45 +45,64 @@ ValidateIndexParams(const char* index_type,
         knowhere::Status status;
         std::string error_msg;
-        if (dataType == milvus::DataType::VECTOR_BINARY) {
-            status = knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(
-                index_type,
-                knowhere::Version::GetCurrentVersion().VersionNumber(),
-                json,
-                error_msg);
-        } else if (dataType == milvus::DataType::VECTOR_FLOAT) {
-            status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
-                index_type,
-                knowhere::Version::GetCurrentVersion().VersionNumber(),
-                json,
-                error_msg);
-        } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) {
-            status = knowhere::IndexStaticFaced<knowhere::bf16>::ConfigCheck(
-                index_type,
-                knowhere::Version::GetCurrentVersion().VersionNumber(),
-                json,
-                error_msg);
-        } else if (dataType == milvus::DataType::VECTOR_FLOAT16) {
-            status = knowhere::IndexStaticFaced<knowhere::fp16>::ConfigCheck(
-                index_type,
-                knowhere::Version::GetCurrentVersion().VersionNumber(),
-                json,
-                error_msg);
-        } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) {
-            status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
-                index_type,
-                knowhere::Version::GetCurrentVersion().VersionNumber(),
-                json,
-                error_msg);
-        } else if (dataType == milvus::DataType::VECTOR_INT8) {
-            status = knowhere::IndexStaticFaced<knowhere::int8>::ConfigCheck(
-                index_type,
-                knowhere::Version::GetCurrentVersion().VersionNumber(),
-                json,
-                error_msg);
-        } else {
-            status = knowhere::Status::invalid_args;
-        }
+        auto check_leaf_type = [&index_type, &json, &error_msg, &status](
+                                   milvus::DataType dataType) {
+            if (dataType == milvus::DataType::VECTOR_BINARY) {
+                status =
+                    knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(
+                        index_type,
+                        knowhere::Version::GetCurrentVersion().VersionNumber(),
+                        json,
+                        error_msg);
+            } else if (dataType == milvus::DataType::VECTOR_FLOAT) {
+                status =
+                    knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
+                        index_type,
+                        knowhere::Version::GetCurrentVersion().VersionNumber(),
+                        json,
+                        error_msg);
+            } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) {
+                status =
+                    knowhere::IndexStaticFaced<knowhere::bf16>::ConfigCheck(
+                        index_type,
+                        knowhere::Version::GetCurrentVersion().VersionNumber(),
+                        json,
+                        error_msg);
+            } else if (dataType == milvus::DataType::VECTOR_FLOAT16) {
+                status =
+                    knowhere::IndexStaticFaced<knowhere::fp16>::ConfigCheck(
+                        index_type,
+                        knowhere::Version::GetCurrentVersion().VersionNumber(),
+                        json,
+                        error_msg);
+            } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) {
+                status =
+                    knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
+                        index_type,
+                        knowhere::Version::GetCurrentVersion().VersionNumber(),
+                        json,
+                        error_msg);
+            } else if (dataType == milvus::DataType::VECTOR_INT8) {
+                status =
+                    knowhere::IndexStaticFaced<knowhere::int8>::ConfigCheck(
+                        index_type,
+                        knowhere::Version::GetCurrentVersion().VersionNumber(),
+                        json,
+                        error_msg);
+            } else {
+                status = knowhere::Status::invalid_args;
+            }
+        };
+        if (dataType == milvus::DataType::VECTOR_ARRAY) {
+            milvus::DataType elementType(
+                static_cast<milvus::DataType>(element_type));
+            check_leaf_type(elementType);
+        } else {
+            check_leaf_type(dataType);
+        }

         CStatus cStatus;
         if (status == knowhere::Status::success) {
             cStatus.error_code = milvus::Success;
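
Spelled out for readability (this snippet is a paraphrase, not code from the patch): for `VECTOR_ARRAY` fields the knowhere config check is dispatched on the leaf element type rather than on the outer type.

```cpp
// Paraphrase of the dispatch above: VECTOR_ARRAY validates against its
// element type; every other vector type validates against itself.
milvus::DataType effective_type =
    (dataType == milvus::DataType::VECTOR_ARRAY)
        ? static_cast<milvus::DataType>(element_type)
        : dataType;
// effective_type then selects the IndexStaticFaced<T>::ConfigCheck
// specialization (bin1, fp32, bf16, fp16, int8).
check_leaf_type(effective_type);
```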

View File

@@ -20,6 +20,7 @@ extern "C" {
 CStatus
 ValidateIndexParams(const char* index_type,
                     enum CDataType data_type,
+                    enum CDataType element_type,
                     const uint8_t* index_params,
                     const uint64_t length);

View File

@@ -141,7 +141,8 @@ class TestFloatSearchBruteForce : public ::testing::Test {
                                        search_info,
                                        index_info,
                                        bitset_view,
-                                       DataType::VECTOR_FLOAT);
+                                       DataType::VECTOR_FLOAT,
+                                       DataType::NONE);
         for (int i = 0; i < nq; i++) {
             auto ref = Ref(base.data(),
                            query.data() + i * dim,

View File

@@ -113,7 +113,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
                 search_info,
                 index_info,
                 bitset_view,
-                DataType::VECTOR_SPARSE_FLOAT));
+                DataType::VECTOR_SPARSE_FLOAT,
+                DataType::NONE));
             return;
         }
         auto result = BruteForceSearch(query_dataset,
@@ -121,7 +122,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
                                        search_info,
                                        index_info,
                                        bitset_view,
-                                       DataType::VECTOR_SPARSE_FLOAT);
+                                       DataType::VECTOR_SPARSE_FLOAT,
+                                       DataType::NONE);
         for (int i = 0; i < nq; i++) {
             auto ref = SearchRef(base.get(), *(query.get() + i), nb, topk);
             auto ans = result.get_seg_offsets() + i * topk;
@@ -135,7 +137,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
                                        search_info,
                                        index_info,
                                        bitset_view,
-                                       DataType::VECTOR_SPARSE_FLOAT);
+                                       DataType::VECTOR_SPARSE_FLOAT,
+                                       DataType::NONE);
         for (int i = 0; i < nq; i++) {
             auto ref = RangeSearchRef(
                 base.get(), *(query.get() + i), nb, 0.1, 0.5, topk);

View File

@@ -1936,6 +1936,7 @@ TEST(CApiTest, LoadIndexSearch) {
     auto& index_params = load_index_info.index_params;
     index_params["index_type"] = knowhere::IndexEnum::INDEX_FAISS_IVFSQ8;
     auto index = std::make_unique<VectorMemIndex<float>>(
+        DataType::NONE,
         index_params["index_type"],
         knowhere::metric::L2,
         knowhere::Version::GetCurrentVersion().VersionNumber());

View File

@@ -128,6 +128,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
                        search_info,
                        index_info,
                        query_data,
+                       nullptr,
                        1,
                        total_row_count,
                        bv,
@@ -153,6 +154,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
                        search_info,
                        index_info,
                        query_data,
+                       nullptr,
                        1,
                        total_row_count,
                        bv,

View File

@@ -16674,6 +16674,7 @@ TEST(JsonIndexTest, TestJsonNotEqualExpr) {
     file_manager_ctx.fieldDataMeta.field_schema.set_data_type(
         milvus::proto::schema::JSON);
     file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get());
+    file_manager_ctx.fieldDataMeta.field_id = json_fid.get();

     auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex(
         index::CreateIndexInfo{
@@ -16784,6 +16785,7 @@ TEST_P(JsonIndexExistsTest, TestExistsExpr) {
         milvus::proto::schema::JSON);
     file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get());
     file_manager_ctx.fieldDataMeta.field_schema.set_nullable(true);
+    file_manager_ctx.fieldDataMeta.field_id = json_fid.get();

     auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex(
         index::CreateIndexInfo{
             .index_type = index::INVERTED_INDEX_TYPE,
@@ -16971,6 +16973,7 @@ TEST_P(JsonIndexBinaryExprTest, TestBinaryRangeExpr) {
     file_manager_ctx.fieldDataMeta.field_schema.set_data_type(
         milvus::proto::schema::JSON);
     file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get());
+    file_manager_ctx.fieldDataMeta.field_id = json_fid.get();

     auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex(
         index::CreateIndexInfo{

View File

@@ -541,3 +541,87 @@ TEST(GrowingTest, LoadVectorArrayData) {
        verify_float_vectors(arrow_array, expected_array);
    }
}
TEST(GrowingTest, SearchVectorArray) {
using namespace milvus::query;
auto schema = std::make_shared<Schema>();
auto metric_type = knowhere::metric::MAX_SIM;
// Add fields
auto int64_field = schema->AddDebugField("int64", DataType::INT64);
auto array_vec = schema->AddDebugVectorArrayField(
"array_vec", DataType::VECTOR_FLOAT, 128, metric_type);
schema->set_primary_field_id(int64_field);
// Configure segment
auto config = SegcoreConfig::default_config();
config.set_chunk_rows(1024);
config.set_enable_interim_segment_index(true);
std::map<std::string, std::string> index_params = {
{"index_type", knowhere::IndexEnum::INDEX_EMB_LIST_HNSW},
{"metric_type", metric_type},
{"nlist", "128"}};
std::map<std::string, std::string> type_params = {{"dim", "128"}};
FieldIndexMeta fieldIndexMeta(
array_vec, std::move(index_params), std::move(type_params));
std::map<FieldId, FieldIndexMeta> fieldMap = {{array_vec, fieldIndexMeta}};
IndexMetaPtr metaPtr =
std::make_shared<CollectionIndexMeta>(100000, std::move(fieldMap));
auto segment = CreateGrowingSegment(schema, metaPtr, 1, config);
auto segmentImplPtr = dynamic_cast<SegmentGrowingImpl*>(segment.get());
// Insert data
int64_t N = 100;
uint64_t seed = 42;
int emb_list_len = 5; // Each row contains 5 vectors
auto dataset = DataGen(schema, N, seed, 0, 1, emb_list_len);
auto offset = 0;
segment->Insert(offset,
N,
dataset.row_ids_.data(),
dataset.timestamps_.data(),
dataset.raw_);
// Prepare search query
int vec_num = 10; // Total number of query vectors
int dim = 128;
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
// Create query lims for VectorArray: prefix offsets over the 10 query vectors
std::vector<size_t> query_vec_lims;
query_vec_lims.push_back(0);
query_vec_lims.push_back(3);   // first query has 3 vectors: [0, 3)
query_vec_lims.push_back(10);  // second query has 7 vectors: [3, 10)
// Create search plan
const char* raw_plan = R"(vector_anns: <
field_id: 101
query_info: <
topk: 5
round_decimal: 3
metric_type: "MAX_SIM"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
// Use CreatePlaceholderGroupFromBlob for VectorArray
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
vec_num, dim, query_vec.data(), query_vec_lims);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
// Execute search
Timestamp timestamp = 10000000;
auto sr = segment->Search(plan.get(), ph_group.get(), timestamp);
auto sr_parsed = SearchResultToJson(*sr);
std::cout << sr_parsed.dump(1) << std::endl;
}
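
For context on the metric these tests exercise: `MAX_SIM` over embedding lists is, as far as this PR shows, the late-interaction style score, where a query list Q = {q_1, ..., q_m} is matched against a stored list D = {d_1, ..., d_n} by letting each query vector contribute its best match (stated here as background, not taken from the patch itself):

```latex
\mathrm{score}(Q, D) = \sum_{i=1}^{m} \max_{1 \le j \le n} \operatorname{sim}(q_i, d_j)
```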

View File

@@ -360,6 +360,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) {
     if (data_type == DataType::VECTOR_FLOAT) {
         auto index = std::make_unique<milvus::index::VectorMemIndex<float>>(
+            DataType::NONE,
             index_type,
             metric_type,
             knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -375,6 +376,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) {
         EXPECT_EQ(index->Count(), (add_cont + 1) * N);
     } else if (data_type == DataType::VECTOR_FLOAT16) {
         auto index = std::make_unique<milvus::index::VectorMemIndex<float16>>(
+            DataType::NONE,
             index_type,
             metric_type,
             knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -391,6 +393,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) {
         EXPECT_EQ(index->Count(), (add_cont + 1) * N);
     } else if (data_type == DataType::VECTOR_BFLOAT16) {
         auto index = std::make_unique<milvus::index::VectorMemIndex<bfloat16>>(
+            DataType::NONE,
             index_type,
             metric_type,
             knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -407,6 +410,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) {
         EXPECT_EQ(index->Count(), (add_cont + 1) * N);
     } else if (is_sparse) {
         auto index = std::make_unique<milvus::index::VectorMemIndex<float>>(
+            DataType::NONE,
             index_type,
             metric_type,
             knowhere::Version::GetCurrentVersion().VersionNumber(),

View File

@@ -184,7 +184,8 @@ TEST(Indexing, BinaryBruteForce) {
                                    search_info,
                                    index_info,
                                    nullptr,
-                                   DataType::VECTOR_BINARY);
+                                   DataType::VECTOR_BINARY,
+                                   DataType::NONE);
     SearchResult sr;
     sr.total_nq_ = num_queries;

View File

@@ -564,6 +564,7 @@ class JsonFlatIndexExprTest : public ::testing::Test {
         file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(
             json_fid_.get());
         file_manager_ctx.fieldDataMeta.field_schema.set_nullable(true);
+        file_manager_ctx.fieldDataMeta.field_id = json_fid_.get();
         auto index = index::IndexFactory::GetInstance().CreateJsonIndex(
             index::CreateIndexInfo{
                 .index_type = index::INVERTED_INDEX_TYPE,

View File

@@ -45,13 +45,13 @@ test_ngram_with_data(const boost::container::vector<std::string>& data,
     auto schema = std::make_shared<Schema>();
     auto field_id = schema->AddDebugField("ngram", DataType::VARCHAR);
-    auto field_meta = gen_field_meta(collection_id,
-                                     partition_id,
-                                     segment_id,
-                                     field_id.get(),
-                                     DataType::VARCHAR,
-                                     DataType::NONE,
-                                     false);
+    auto field_meta = milvus::segcore::gen_field_meta(collection_id,
+                                                      partition_id,
+                                                      segment_id,
+                                                      field_id.get(),
+                                                      DataType::VARCHAR,
+                                                      DataType::NONE,
+                                                      false);
     auto index_meta = gen_index_meta(
         segment_id, field_id.get(), index_build_id, index_version);

View File

@@ -19,6 +19,7 @@
 #include "knowhere/version.h"
 #include "storage/RemoteChunkManagerSingleton.h"
 #include "storage/Util.h"
+#include "common/VectorArray.h"
 #include "test_utils/cachinglayer_test_utils.h"
 #include "test_utils/DataGen.h"
@@ -2333,3 +2334,257 @@ TEST(Sealed, QueryVectorArrayAllFields) {
    EXPECT_EQ(int64_result->valid_data_size(), 0);
    EXPECT_EQ(array_float_vector_result->valid_data_size(), 0);
}
TEST(Sealed, SearchVectorArray) {
int64_t collection_id = 1;
int64_t partition_id = 2;
int64_t segment_id = 3;
int64_t index_build_id = 4000;
int64_t index_version = 4000;
int64_t index_id = 5000;
auto schema = std::make_shared<Schema>();
auto metric_type = knowhere::metric::L2;
auto int64_field = schema->AddDebugField("int64", DataType::INT64);
auto array_vec = schema->AddDebugVectorArrayField(
"array_vec", DataType::VECTOR_FLOAT, 128, metric_type);
schema->set_primary_field_id(int64_field);
auto field_meta = milvus::segcore::gen_field_meta(collection_id,
partition_id,
segment_id,
array_vec.get(),
DataType::VECTOR_ARRAY,
DataType::VECTOR_FLOAT,
false);
auto index_meta = gen_index_meta(
segment_id, array_vec.get(), index_build_id, index_version);
std::map<FieldId, FieldIndexMeta> filedMap{};
IndexMetaPtr metaPtr =
std::make_shared<CollectionIndexMeta>(100000, std::move(filedMap));
int64_t dataset_size = 1000;
int64_t dim = 128;
auto emb_list_len = 10;
auto dataset = DataGen(schema, dataset_size, 42, 0, 1, emb_list_len);
// create field data
std::string root_path = "/tmp/test-vector-array/";
auto storage_config = gen_local_storage_config(root_path);
auto cm = CreateChunkManager(storage_config);
auto vec_array_col = dataset.get_col<VectorFieldProto>(array_vec);
std::vector<milvus::VectorArray> vector_arrays;
for (auto& v : vec_array_col) {
vector_arrays.push_back(milvus::VectorArray(v));
}
auto field_data = storage::CreateFieldData(DataType::VECTOR_ARRAY, false);
field_data->FillFieldData(vector_arrays.data(), vector_arrays.size());
// create sealed segment
auto segment = CreateSealedSegment(schema);
auto field_data_info = PrepareSingleFieldInsertBinlog(collection_id,
partition_id,
segment_id,
array_vec.get(),
{field_data},
cm);
segment->LoadFieldData(field_data_info);
// serialize bin logs
auto payload_reader =
std::make_shared<milvus::storage::PayloadReader>(field_data);
storage::InsertData insert_data(payload_reader);
insert_data.SetFieldDataMeta(field_meta);
insert_data.SetTimestamps(0, 100);
auto serialized_bytes = insert_data.Serialize(storage::Remote);
auto get_binlog_path = [=](int64_t log_id) {
return fmt::format("{}/{}/{}/{}/{}",
collection_id,
partition_id,
segment_id,
array_vec.get(),
log_id);
};
auto log_path = get_binlog_path(0);
auto cm_w = ChunkManagerWrapper(cm);
cm_w.Write(log_path, serialized_bytes.data(), serialized_bytes.size());
storage::FileManagerContext ctx(field_meta, index_meta, cm);
std::vector<std::string> index_files;
// create index
milvus::index::CreateIndexInfo create_index_info;
create_index_info.field_type = DataType::VECTOR_ARRAY;
create_index_info.metric_type = knowhere::metric::MAX_SIM;
create_index_info.index_type = knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
create_index_info.index_engine_version =
knowhere::Version::GetCurrentVersion().VersionNumber();
auto emb_list_hnsw_index =
milvus::index::IndexFactory::GetInstance().CreateIndex(
create_index_info,
storage::FileManagerContext(field_meta, index_meta, cm));
// build index
Config config;
config[milvus::index::INDEX_TYPE] =
knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
config[INSERT_FILES_KEY] = std::vector<std::string>{log_path};
config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type;
config[knowhere::indexparam::M] = "16";
config[knowhere::indexparam::EF] = "10";
config[DIM_KEY] = dim;
emb_list_hnsw_index->Build(config);
auto vec_index =
dynamic_cast<milvus::index::VectorIndex*>(emb_list_hnsw_index.get());
EXPECT_EQ(vec_index->Count(), dataset_size * emb_list_len);
EXPECT_EQ(vec_index->GetDim(), dim);
// search
auto vec_num = 10;
std::vector<float> query_vec = generate_float_vector(vec_num, dim);
auto query_dataset = knowhere::GenDataSet(vec_num, dim, query_vec.data());
std::vector<size_t> query_vec_lims;
query_vec_lims.push_back(0);
query_vec_lims.push_back(3);
query_vec_lims.push_back(10);
query_dataset->SetLims(query_vec_lims.data());
auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
milvus::SearchInfo searchInfo;
searchInfo.topk_ = 5;
searchInfo.metric_type_ = knowhere::metric::L2;
searchInfo.search_params_ = search_conf;
SearchResult result;
vec_index->Query(query_dataset, searchInfo, nullptr, result);
auto ref_result = SearchResultToJson(result);
std::cout << ref_result.dump(1) << std::endl;
EXPECT_EQ(result.total_nq_, 2);
EXPECT_EQ(result.distances_.size(), 2 * searchInfo.topk_);
// create sealed segment
auto sealed_segment = CreateSealedWithFieldDataLoaded(schema, dataset);
// brute force search
{
const char* raw_plan = R"(vector_anns: <
field_id: 101
query_info: <
topk: 5
round_decimal: 3
metric_type: "MAX_SIM"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
vec_num, dim, query_vec.data(), query_vec_lims);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp timestamp = 1000000;
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
auto sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
auto sr_parsed = SearchResultToJson(*sr);
std::cout << sr_parsed.dump(1) << std::endl;
}
// // brute force search with iterative filter
// {
// auto [min, max] =
// std::minmax_element(int_values.begin(), int_values.end());
// auto min_val = *min;
// auto max_val = *max;
// auto raw_plan = fmt::format(R"(vector_anns: <
// field_id: 101
// predicates: <
// binary_range_expr: <
// column_info: <
// field_id: 100
// data_type: Int64
// >
// lower_inclusive: true
// upper_inclusive: true
// lower_value: <
// int64_val: {}
// >
// upper_value: <
// int64_val: {}
// >
// >
// >
// query_info: <
// topk: 5
// round_decimal: 3
// metric_type: "MAX_SIM"
// hints: "iterative_filter"
// search_params: "{{\"nprobe\": 10}}"
// >
// placeholder_tag: "$0"
// >)",
// min_val,
// max_val);
// auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
// auto plan =
// CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
// auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
// vec_num, dim, query_vec.data(), query_vec_lims);
// auto ph_group =
// ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
// Timestamp timestamp = 1000000;
// std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
// auto sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
// auto sr_parsed = SearchResultToJson(*sr);
// std::cout << sr_parsed.dump(1) << std::endl;
// }
// search with index
{
LoadIndexInfo load_info;
load_info.field_id = array_vec.get();
load_info.field_type = DataType::VECTOR_ARRAY;
load_info.element_type = DataType::VECTOR_FLOAT;
load_info.index_params = GenIndexParams(emb_list_hnsw_index.get());
load_info.cache_index =
CreateTestCacheIndex("test", std::move(emb_list_hnsw_index));
load_info.index_params["metric_type"] = knowhere::metric::MAX_SIM;
sealed_segment->DropFieldData(array_vec);
sealed_segment->LoadIndex(load_info);
const char* raw_plan = R"(vector_anns: <
field_id: 101
query_info: <
topk: 5
round_decimal: 3
metric_type: "MAX_SIM"
search_params: "{\"nprobe\": 10}"
>
placeholder_tag: "$0"
>)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan =
CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
vec_num, dim, query_vec.data(), query_vec_lims);
auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
Timestamp timestamp = 1000000;
std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};
auto sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
auto sr_parsed = SearchResultToJson(*sr);
std::cout << sr_parsed.dump(1) << std::endl;
}
}

View File

@@ -1562,7 +1562,8 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
                                    search_info,
                                    index_info,
                                    nullptr,
-                                   DataType::VECTOR_FLOAT);
+                                   DataType::VECTOR_FLOAT,
+                                   DataType::NONE);
     auto sr = segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP);
     segment->FillPrimaryKeys(plan.get(), *sr);

View File

@@ -1004,7 +1004,10 @@ CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {
 template <class TraitType = milvus::FloatVector>
 inline auto
-CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const void* src) {
+CreatePlaceholderGroupFromBlob(int64_t num_queries,
+                               int dim,
+                               const void* src,
+                               std::vector<size_t> offsets = {}) {
     if (std::is_same_v<TraitType, milvus::BinaryVector>) {
         assert(dim % 8 == 0);
     }
@@ -1017,12 +1020,27 @@ CreatePlaceholderGroupFromBlob(int64_t num_queries,
     value->set_type(TraitType::placeholder_type);
     int64_t src_index = 0;
-    for (int i = 0; i < num_queries; ++i) {
-        std::vector<elem_type> vec;
-        for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
-            vec.push_back(((elem_type*)src)[src_index++]);
-        }
-        value->add_values(vec.data(), vec.size() * sizeof(elem_type));
-    }
+    if (offsets.empty()) {
+        for (int i = 0; i < num_queries; ++i) {
+            std::vector<elem_type> vec;
+            for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
+                vec.push_back(((elem_type*)src)[src_index++]);
+            }
+            value->add_values(vec.data(), vec.size() * sizeof(elem_type));
+        }
+    } else {
+        assert(offsets.back() == num_queries);
+        for (int i = 0; i < offsets.size() - 1; i++) {
+            auto start = offsets[i];
+            auto end = offsets[i + 1];
+            std::vector<elem_type> vec;
+            for (int j = start; j < end; j++) {
+                for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
+                    vec.push_back(((elem_type*)src)[src_index++]);
+                }
+            }
+            value->add_values(vec.data(), vec.size() * sizeof(elem_type));
+        }
+    }
     return raw_group;
 }
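
Usage sketch for the new `offsets` parameter, mirroring how the embedding-list tests in this PR drive it (`EmbListFloatVector` and `generate_float_vector` are the trait and helper used there):

```cpp
// 10 query vectors of dim 128, grouped into two embedding lists (3 and 7
// vectors) via the prefix offsets {0, 3, 10}.
auto query_vec = milvus::segcore::generate_float_vector(10, 128);
auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
    10, 128, query_vec.data(), {0, 3, 10});
```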
@@ -1362,6 +1380,7 @@ GenVecIndexing(int64_t N,
     milvus::storage::FileManagerContext file_manager_context(
         field_data_meta, index_meta, chunk_manager);
     auto indexing = std::make_unique<index::VectorMemIndex<float>>(
+        DataType::NONE,
         index_type,
         knowhere::metric::L2,
         knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -1631,4 +1650,29 @@ GenChunkedSegmentTestSchema(bool pk_is_string) {
    return schema;
}
inline std::vector<float>
generate_float_vector(int64_t N, int64_t dim) {
auto seed = 42;
auto offset = 0;
std::vector<float> final(dim * N);
for (int n = 0; n < N; ++n) {
std::vector<float> data(dim);
float sum = 0;
std::default_random_engine er2(seed + n);
std::normal_distribution<> distr2(0, 1);
for (auto& x : data) {
x = distr2(er2) + offset++;
sum += x * x;
}
sum = sqrt(sum);
for (auto& x : data) {
x /= sum;
}
std::copy(data.begin(), data.end(), final.begin() + dim * n);
}
return final;
};
}  // namespace milvus::segcore

View File

@@ -31,6 +31,7 @@
 #include "segcore/segment_c.h"
 #include "futures/Future.h"
 #include "futures/future_c.h"
+#include "segcore/load_index_c.h"
 #include "DataGen.h"
 #include "PbHelper.h"
 #include "indexbuilder_test_utils.h"
@@ -38,6 +39,39 @@
using namespace milvus;
using namespace milvus::segcore;
// Test utility function for AppendFieldInfo
inline CStatus
AppendFieldInfo(CLoadIndexInfo c_load_index_info,
int64_t collection_id,
int64_t partition_id,
int64_t segment_id,
int64_t field_id,
enum CDataType field_type,
bool enable_mmap,
const char* mmap_dir_path) {
try {
auto load_index_info =
(milvus::segcore::LoadIndexInfo*)c_load_index_info;
load_index_info->collection_id = collection_id;
load_index_info->partition_id = partition_id;
load_index_info->segment_id = segment_id;
load_index_info->field_id = field_id;
load_index_info->field_type = milvus::DataType(field_type);
load_index_info->enable_mmap = enable_mmap;
load_index_info->mmap_dir_path = std::string(mmap_dir_path);
auto status = CStatus();
status.error_code = milvus::Success;
status.error_msg = "";
return status;
} catch (std::exception& e) {
auto status = CStatus();
status.error_code = milvus::UnexpectedError;
status.error_msg = strdup(e.what());
return status;
}
}
namespace {

std::string

View File

@@ -109,7 +109,7 @@ TEST(VectorArray, TestConstructVectorArray) {
     field_float_vector_array.mutable_float_vector()->mutable_data()->Add(
         data.begin(), data.end());

-    auto float_vector_array = VectorArray(field_float_vector_array);
+    auto float_vector_array = milvus::VectorArray(field_float_vector_array);
     ASSERT_EQ(float_vector_array.length(), N);
     ASSERT_EQ(float_vector_array.dim(), dim);
     ASSERT_EQ(float_vector_array.get_element_type(), DataType::VECTOR_FLOAT);
@@ -117,16 +117,16 @@ TEST(VectorArray, TestConstructVectorArray) {
     ASSERT_TRUE(float_vector_array.is_same_array(field_float_vector_array));

-    auto float_vector_array_tmp = VectorArray(float_vector_array);
+    auto float_vector_array_tmp = milvus::VectorArray(float_vector_array);
     ASSERT_TRUE(float_vector_array_tmp.is_same_array(field_float_vector_array));

     auto float_vector_array_view =
-        VectorArrayView(const_cast<char*>(float_vector_array.data()),
-                        float_vector_array.length(),
-                        float_vector_array.dim(),
-                        float_vector_array.byte_size(),
-                        float_vector_array.get_element_type());
+        milvus::VectorArrayView(const_cast<char*>(float_vector_array.data()),
+                                float_vector_array.length(),
+                                float_vector_array.dim(),
+                                float_vector_array.byte_size(),
+                                float_vector_array.get_element_type());
     ASSERT_TRUE(
         float_vector_array_view.is_same_array(field_float_vector_array));

View File

@@ -70,13 +70,29 @@ func (s *Server) getSchema(ctx context.Context, collID int64) (*schemapb.Collect
    return resp.GetSchema(), nil
}

-func isJsonField(schema *schemapb.CollectionSchema, fieldID int64) (bool, error) {
+func FieldExists(schema *schemapb.CollectionSchema, fieldID int64) bool {
    for _, f := range schema.Fields {
        if f.FieldID == fieldID {
-            return typeutil.IsJSONType(f.DataType), nil
+            return true
        }
    }
-    return false, merr.WrapErrFieldNotFound(fieldID)
+    for _, structField := range schema.StructArrayFields {
+        for _, f := range structField.Fields {
+            if f.FieldID == fieldID {
+                return true
+            }
+        }
+    }
+    return false
+}
+
+func isJsonField(schema *schemapb.CollectionSchema, fieldID int64) bool {
+    for _, f := range schema.Fields {
+        if f.FieldID == fieldID {
+            return typeutil.IsJSONType(f.DataType)
+        }
+    }
+    return false
}

func getIndexParam(indexParams []*commonpb.KeyValuePair, key string) (string, error) {
@@ -154,11 +170,12 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques
    if err != nil {
        return merr.Status(err), nil
    }
-    isJson, err := isJsonField(schema, req.GetFieldID())
-    if err != nil {
-        return merr.Status(err), nil
+    if !FieldExists(schema, req.GetFieldID()) {
+        return merr.Status(merr.WrapErrFieldNotFound(req.GetFieldID())), nil
    }
+    isJson := isJsonField(schema, req.GetFieldID())
    if isJson {
        // check json_path and json_cast_type exist
        jsonPath, err := getIndexParam(req.GetIndexParams(), common.JSONPathKey)

View File

@@ -250,7 +250,8 @@ func (it *indexBuildTask) prepareJobRequest(ctx context.Context, segment *Segmen
    schema := collectionInfo.Schema
    var field *schemapb.FieldSchema
-    for _, f := range schema.Fields {
+    allFields := typeutil.GetAllFieldSchemas(schema)
+    for _, f := range allFields {
        if f.FieldID == fieldID {
            field = f
            break
@@ -263,7 +264,11 @@ func (it *indexBuildTask) prepareJobRequest(ctx context.Context, segment *Segmen
    // Extract dim only for vector types to avoid unnecessary warnings
    dim := -1
-    if typeutil.IsFixDimVectorType(field.GetDataType()) {
+    dataType := field.GetDataType()
+    if typeutil.IsVectorArrayType(dataType) {
+        dataType = field.GetElementType()
+    }
+    if typeutil.IsFixDimVectorType(dataType) {
        if dimVal, err := storage.GetDimFromParams(field.GetTypeParams()); err != nil {
            log.Warn("failed to get dim from field type params",
                zap.String("field type", field.GetDataType().String()), zap.Error(err))

View File

@@ -180,6 +180,7 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
    // plan ok with schema, check ann field
    fieldID := vectorField.FieldID
    dataType := vectorField.DataType
+    elementType := vectorField.ElementType
    var vectorType planpb.VectorType

    if !typeutil.IsVectorType(dataType) {
@@ -198,6 +199,15 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
        vectorType = planpb.VectorType_SparseFloatVector
    case schemapb.DataType_Int8Vector:
        vectorType = planpb.VectorType_Int8Vector
+    case schemapb.DataType_ArrayOfVector:
+        switch elementType {
+        case schemapb.DataType_FloatVector:
+            vectorType = planpb.VectorType_EmbListFloatVector
+        default:
+            log.Error("Invalid elementType", zap.Any("elementType", elementType))
+            return nil, err
+        }
    default:
        log.Error("Invalid dataType", zap.Any("dataType", dataType))
        return nil, err

View File

@@ -62,13 +62,18 @@ func GetDynamicPool() *conc.Pool[any] {
    return dp.Load()
}

-func CheckVecIndexWithDataTypeExist(name string, dType schemapb.DataType) bool {
+func CheckVecIndexWithDataTypeExist(name string, dataType schemapb.DataType, elementType schemapb.DataType) bool {
+    isEmbeddingList := dataType == schemapb.DataType_ArrayOfVector
+    if isEmbeddingList {
+        dataType = elementType
+    }
    var result bool
    GetDynamicPool().Submit(func() (any, error) {
        cIndexName := C.CString(name)
-        cType := uint32(dType)
+        cType := uint32(dataType)
        defer C.free(unsafe.Pointer(cIndexName))
-        result = bool(C.CheckVecIndexWithDataType(cIndexName, cType))
+        result = bool(C.CheckVecIndexWithDataType(cIndexName, cType, C.bool(isEmbeddingList)))
        return nil, nil
    }).Await()

View File

@@ -50,7 +50,7 @@ func Test_CheckVecIndexWithDataTypeExist(t *testing.T) {
    }

    for _, test := range cases {
-        if got := CheckVecIndexWithDataTypeExist(test.indexType, test.dataType); got != test.want {
+        if got := CheckVecIndexWithDataTypeExist(test.indexType, test.dataType, schemapb.DataType_None); got != test.want {
            t.Errorf("CheckVecIndexWithDataTypeExist(%v, %v) = %v", test.indexType, test.dataType, test.want)
        }
    }

File diff suppressed because it is too large

View File

@@ -549,24 +549,24 @@ func (op *lambdaOperator) run(ctx context.Context, span trace.Span, inputs ...an
type filterFieldOperator struct {
    outputFieldNames []string
-    schema           *schemaInfo
+    fieldSchemas     []*schemapb.FieldSchema
}

func newFilterFieldOperator(t *searchTask, _ map[string]any) (operator, error) {
    return &filterFieldOperator{
        outputFieldNames: t.translatedOutputFields,
-        schema:           t.schema,
+        fieldSchemas:     typeutil.GetAllFieldSchemas(t.schema.CollectionSchema),
    }, nil
}

func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
    result := inputs[0].(*milvuspb.SearchResults)
    for _, retField := range result.Results.FieldsData {
-        for _, schemaField := range op.schema.Fields {
-            if retField != nil && retField.FieldId == schemaField.FieldID {
-                retField.FieldName = schemaField.Name
-                retField.Type = schemaField.DataType
-                retField.IsDynamic = schemaField.IsDynamic
+        for _, fieldSchema := range op.fieldSchemas {
+            if retField != nil && retField.FieldId == fieldSchema.FieldID {
+                retField.FieldName = fieldSchema.Name
+                retField.Type = fieldSchema.DataType
+                retField.IsDynamic = fieldSchema.IsDynamic
            }
        }
    }

View File

@@ -560,6 +560,79 @@ func (s *SearchPipelineSuite) TestHybridSearchPipe() {
    s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
}
func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
{FieldID: 102, Name: "floatField", DataType: schemapb.DataType_Float},
},
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
Name: "structArray",
Fields: []*schemapb.FieldSchema{
{FieldID: 104, Name: "structArrayField", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32},
{FieldID: 105, Name: "structVectorField", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector},
},
},
},
}
task := &searchTask{
schema: &schemaInfo{
CollectionSchema: schema,
},
translatedOutputFields: []string{"intField", "floatField", "structArrayField", "structVectorField"},
}
op, err := newFilterFieldOperator(task, nil)
s.NoError(err)
// Create mock search results with fields including struct array fields
searchResults := &milvuspb.SearchResults{
Results: &schemapb.SearchResultData{
FieldsData: []*schemapb.FieldData{
{FieldId: 101}, // intField
{FieldId: 102}, // floatField
{FieldId: 104}, // structArrayField
{FieldId: 105}, // structVectorField
},
},
}
results, err := op.run(context.Background(), s.span, searchResults)
s.NoError(err)
s.NotNil(results)
resultData := results[0].(*milvuspb.SearchResults)
s.NotNil(resultData.Results.FieldsData)
s.Len(resultData.Results.FieldsData, 4)
// Verify all fields including struct array fields got their names and types set
for _, field := range resultData.Results.FieldsData {
switch field.FieldId {
case 101:
s.Equal("intField", field.FieldName)
s.Equal(schemapb.DataType_Int64, field.Type)
s.False(field.IsDynamic)
case 102:
s.Equal("floatField", field.FieldName)
s.Equal(schemapb.DataType_Float, field.Type)
s.False(field.IsDynamic)
case 104:
// Struct array field should be handled by GetAllFieldSchemas
s.Equal("structArrayField", field.FieldName)
s.Equal(schemapb.DataType_Array, field.Type)
s.False(field.IsDynamic)
case 105:
// Struct array vector field should be handled by GetAllFieldSchemas
s.Equal("structVectorField", field.FieldName)
s.Equal(schemapb.DataType_ArrayOfVector, field.Type)
s.False(field.IsDynamic)
}
}
}
func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
    task := getHybridSearchTask("test_collection", [][]string{
        {"1", "2"},

View File

@@ -281,6 +281,28 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
        return nil, fmt.Errorf("parse iterator v2 info failed: %w", err)
    }
// 7. check search for embedding list
annsFieldName, _ := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, searchParamsPair)
if annsFieldName != "" {
annField := typeutil.GetFieldByName(schema, annsFieldName)
if annField != nil && annField.GetDataType() == schemapb.DataType_ArrayOfVector {
if strings.Contains(searchParamStr, radiusKey) {
return nil, merr.WrapErrParameterInvalid("", "",
"range search is not supported for vector array (embedding list) fields, fieldName: %s", annsFieldName)
}
if groupByFieldId > 0 {
return nil, merr.WrapErrParameterInvalid("", "",
"group by search is not supported for vector array (embedding list) fields, fieldName: %s", annsFieldName)
}
if isIterator {
return nil, merr.WrapErrParameterInvalid("", "",
"search iterator is not supported for vector array (embedding list) fields, fieldName: %s", annsFieldName)
}
}
}
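The added block gates all three unsupported modes in one place once the anns field resolves to an ArrayOfVector. A minimal sketch of the same gating (validateEmbListSearch is a hypothetical name; the production check stays inline in parseSearchInfo and uses merr wrappers):

```go
// Sketch only: a hypothetical extraction of the inline check above,
// using plain errors instead of merr wrappers.
func validateEmbListSearch(field *schemapb.FieldSchema, hasRadius bool, groupByFieldID int64, isIterator bool) error {
	if field == nil || field.GetDataType() != schemapb.DataType_ArrayOfVector {
		return nil // restrictions apply only to embedding-list fields
	}
	switch {
	case hasRadius:
		return fmt.Errorf("range search is not supported for vector array (embedding list) fields, fieldName: %s", field.GetName())
	case groupByFieldID > 0:
		return fmt.Errorf("group by search is not supported for vector array (embedding list) fields, fieldName: %s", field.GetName())
	case isIterator:
		return fmt.Errorf("search iterator is not supported for vector array (embedding list) fields, fieldName: %s", field.GetName())
	}
	return nil
}
```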
return &SearchInfo{ return &SearchInfo{
planInfo: &planpb.QueryInfo{ planInfo: &planpb.QueryInfo{
Topk: queryTopK, Topk: queryTopK,

View File

@ -68,6 +68,8 @@ const (
RoundDecimalKey = "round_decimal" RoundDecimalKey = "round_decimal"
OffsetKey = "offset" OffsetKey = "offset"
LimitKey = "limit" LimitKey = "limit"
// offsets for embedding list search
LimsKey = "lims"
SearchIterV2Key = "search_iter_v2" SearchIterV2Key = "search_iter_v2"
SearchIterBatchSizeKey = "search_iter_batch_size" SearchIterBatchSizeKey = "search_iter_batch_size"
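The new LimsKey carries row offsets into the flattened vector buffer: with one embedding list per row, lims presumably has rows+1 entries and entry i marks where row i's vectors start. A small sketch of that construction (buildLims is illustrative, not part of this PR):

```go
// buildLims turns per-row vector counts into exclusive prefix sums:
// counts [2, 4, 3] yield lims [0, 2, 6, 9], so row i owns the
// flattened vectors in [lims[i], lims[i+1]).
func buildLims(vectorCounts []int) []int {
	lims := make([]int, 0, len(vectorCounts)+1)
	offset := 0
	lims = append(lims, offset)
	for _, n := range vectorCounts {
		offset += n
		lims = append(lims, offset)
	}
	return lims
}
```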
@ -2047,7 +2049,8 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
loadFieldsSet := typeutil.NewSet(loadFields...) loadFieldsSet := typeutil.NewSet(loadFields...)
unindexedVecFields := make([]string, 0) unindexedVecFields := make([]string, 0)
for _, field := range collSchema.GetFields() { allFields := typeutil.GetAllFieldSchemas(collSchema.CollectionSchema)
for _, field := range allFields {
if typeutil.IsVectorType(field.GetDataType()) && loadFieldsSet.Contain(field.GetFieldID()) { if typeutil.IsVectorType(field.GetDataType()) && loadFieldsSet.Contain(field.GetFieldID()) {
if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok { if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok {
unindexedVecFields = append(unindexedVecFields, field.GetName()) unindexedVecFields = append(unindexedVecFields, field.GetName())
@ -2055,8 +2058,6 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
} }
} }
// todo(SpadeA): check vector field in StructArrayField when index is implemented
if len(unindexedVecFields) != 0 { if len(unindexedVecFields) != 0 {
errMsg := fmt.Sprintf("there is no vector index on field: %v, please create index firstly", unindexedVecFields) errMsg := fmt.Sprintf("there is no vector index on field: %v, please create index firstly", unindexedVecFields)
log.Debug(errMsg) log.Debug(errMsg)
@ -2305,7 +2306,8 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
loadFieldsSet := typeutil.NewSet(loadFields...) loadFieldsSet := typeutil.NewSet(loadFields...)
unindexedVecFields := make([]string, 0) unindexedVecFields := make([]string, 0)
for _, field := range collSchema.GetFields() { allFields := typeutil.GetAllFieldSchemas(collSchema.CollectionSchema)
for _, field := range allFields {
if typeutil.IsVectorType(field.GetDataType()) && loadFieldsSet.Contain(field.GetFieldID()) { if typeutil.IsVectorType(field.GetDataType()) && loadFieldsSet.Contain(field.GetFieldID()) {
if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok { if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok {
unindexedVecFields = append(unindexedVecFields, field.GetName()) unindexedVecFields = append(unindexedVecFields, field.GetName())
@ -2313,8 +2315,6 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
} }
} }
// todo(SpadeA): check vector field in StructArrayField when index is implemented
if len(unindexedVecFields) != 0 { if len(unindexedVecFields) != 0 {
errMsg := fmt.Sprintf("there is no vector index on field: %v, please create index firstly", unindexedVecFields) errMsg := fmt.Sprintf("there is no vector index on field: %v, please create index firstly", unindexedVecFields)
log.Ctx(ctx).Debug(errMsg) log.Ctx(ctx).Debug(errMsg)
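Both load paths now call typeutil.GetAllFieldSchemas instead of iterating collSchema.GetFields() directly, so vector fields nested in struct arrays are covered by the unindexed-field check. Judging from the inline loops it replaces in the insert and upsert pre-execute hunks later in this diff, the helper amounts to this flattening (a sketch, not necessarily the exact implementation):

```go
// GetAllFieldSchemas returns the top-level fields plus every sub-field
// of each StructArrayField, letting callers treat struct-array
// sub-fields (including ArrayOfVector ones) like ordinary fields.
func GetAllFieldSchemas(schema *schemapb.CollectionSchema) []*schemapb.FieldSchema {
	all := make([]*schemapb.FieldSchema, 0, len(schema.GetFields()))
	all = append(all, schema.GetFields()...)
	for _, structField := range schema.GetStructArrayFields() {
		all = append(all, structField.GetFields()...)
	}
	return all
}
```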

View File

@ -202,10 +202,12 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
specifyIndexType, exist := indexParamsMap[common.IndexTypeKey] specifyIndexType, exist := indexParamsMap[common.IndexTypeKey]
if exist && specifyIndexType != "" { if exist && specifyIndexType != "" {
// todo(SpadeA): mmap check for struct array index
if err := indexparamcheck.ValidateMmapIndexParams(specifyIndexType, indexParamsMap); err != nil { if err := indexparamcheck.ValidateMmapIndexParams(specifyIndexType, indexParamsMap); err != nil {
log.Ctx(ctx).Warn("Invalid mmap type params", zap.String(common.IndexTypeKey, specifyIndexType), zap.Error(err)) log.Ctx(ctx).Warn("Invalid mmap type params", zap.String(common.IndexTypeKey, specifyIndexType), zap.Error(err))
return merr.WrapErrParameterInvalidMsg("invalid mmap type params: %s", err.Error()) return merr.WrapErrParameterInvalidMsg("invalid mmap type params: %s", err.Error())
} }
// todo(SpadeA): check for struct array index
checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(specifyIndexType) checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(specifyIndexType)
// not enable hybrid index for user, used in milvus internally // not enable hybrid index for user, used in milvus internally
if err != nil || indexparamcheck.IsHYBRIDChecker(checker) { if err != nil || indexparamcheck.IsHYBRIDChecker(checker) {
@ -327,16 +329,20 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
} }
var config map[string]string var config map[string]string
if typeutil.IsDenseFloatVectorType(cit.fieldSchema.DataType) { if typeutil.IsDenseFloatVectorType(cit.fieldSchema.DataType) ||
(typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsDenseFloatVectorType(cit.fieldSchema.ElementType)) {
// override float vector index params by autoindex // override float vector index params by autoindex
config = Params.AutoIndexConfig.IndexParams.GetAsJSONMap() config = Params.AutoIndexConfig.IndexParams.GetAsJSONMap()
} else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) { } else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) ||
(typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsSparseFloatVectorType(cit.fieldSchema.ElementType)) {
// override sparse float vector index params by autoindex // override sparse float vector index params by autoindex
config = Params.AutoIndexConfig.SparseIndexParams.GetAsJSONMap() config = Params.AutoIndexConfig.SparseIndexParams.GetAsJSONMap()
} else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) { } else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) ||
(typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsBinaryVectorType(cit.fieldSchema.ElementType)) {
// override binary vector index params by autoindex // override binary vector index params by autoindex
config = Params.AutoIndexConfig.BinaryIndexParams.GetAsJSONMap() config = Params.AutoIndexConfig.BinaryIndexParams.GetAsJSONMap()
} else if typeutil.IsIntVectorType(cit.fieldSchema.DataType) { } else if typeutil.IsIntVectorType(cit.fieldSchema.DataType) ||
(typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsIntVectorType(cit.fieldSchema.ElementType)) {
// override int vector index params by autoindex // override int vector index params by autoindex
config = Params.AutoIndexConfig.IndexParams.GetAsJSONMap() config = Params.AutoIndexConfig.IndexParams.GetAsJSONMap()
} }
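Each branch above repeats the pattern IsXVectorType(DataType) || (IsArrayOfVectorType(DataType) && IsXVectorType(ElementType)): for vector arrays, the element type picks the autoindex override. A hypothetical condensation of that rule (effectiveVectorType is not in this PR):

```go
// effectiveVectorType returns the type that decides which autoindex
// config applies: the element type for ArrayOfVector fields, the
// field's own data type otherwise.
func effectiveVectorType(f *schemapb.FieldSchema) schemapb.DataType {
	if typeutil.IsArrayOfVectorType(f.GetDataType()) {
		return f.GetElementType()
	}
	return f.GetDataType()
}
```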
@ -397,6 +403,12 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
if !funcutil.SliceContain(indexparamcheck.IntVectorMetrics, metricType) { if !funcutil.SliceContain(indexparamcheck.IntVectorMetrics, metricType) {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "int vector index does not support metric type: "+metricType) return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "int vector index does not support metric type: "+metricType)
} }
} else if typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) {
// TODO(SpadeA): adjust this when more metric types are supported, especially when different
// metric types are supported for different element types.
if !funcutil.SliceContain(indexparamcheck.EmbListMetrics, metricType) {
return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector index does not support metric type: "+metricType)
}
} }
} }
@ -500,7 +512,7 @@ func checkTrain(ctx context.Context, field *schemapb.FieldSchema, indexParams ma
} }
if typeutil.IsVectorType(field.DataType) && indexType != indexparamcheck.AutoIndex { if typeutil.IsVectorType(field.DataType) && indexType != indexparamcheck.AutoIndex {
exist := CheckVecIndexWithDataTypeExist(indexType, field.DataType) exist := CheckVecIndexWithDataTypeExist(indexType, field.DataType, field.ElementType)
if !exist { if !exist {
return fmt.Errorf("data type %s can't build with this index %s", schemapb.DataType_name[int32(field.GetDataType())], indexType) return fmt.Errorf("data type %s can't build with this index %s", schemapb.DataType_name[int32(field.GetDataType())], indexType)
} }
@ -519,7 +531,7 @@ func checkTrain(ctx context.Context, field *schemapb.FieldSchema, indexParams ma
return err return err
} }
if err := checker.CheckTrain(field.DataType, indexParams); err != nil { if err := checker.CheckTrain(field.DataType, field.ElementType, indexParams); err != nil {
log.Ctx(ctx).Info("create index with invalid parameters", zap.Error(err)) log.Ctx(ctx).Info("create index with invalid parameters", zap.Error(err))
return err return err
} }
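Both CheckVecIndexWithDataTypeExist and CheckTrain now take the element type as well, so index/type compatibility for vector arrays is decided per element. Given that this PR supports only EMB_LIST_HNSW over float vectors, the new branch presumably reduces to something like the following (a sketch; the real table-driven lookup lives in indexparamcheck):

```go
// embListIndexSupported sketches the ArrayOfVector branch of the
// extended existence check; plain vector fields keep their existing
// lookup, which is elided here.
func embListIndexSupported(indexType string, elementType schemapb.DataType) bool {
	// Only EMB_LIST_HNSW over FloatVector elements is supported in this PR.
	return indexType == "EMB_LIST_HNSW" && elementType == schemapb.DataType_FloatVector
}
```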

View File

@ -21,6 +21,7 @@ import (
"os" "os"
"sort" "sort"
"strconv" "strconv"
"strings"
"testing" "testing"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
@ -1164,6 +1165,131 @@ func Test_parseIndexParams(t *testing.T) {
}) })
} }
func Test_checkEmbeddingListIndex(t *testing.T) {
t.Run("check embedding list index", func(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "EMB_LIST_HNSW",
},
{
Key: common.MetricTypeKey,
Value: metric.MaxSim,
},
},
IndexName: "",
},
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "EmbListFloat",
IsPrimaryKey: false,
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
},
}
err := cit.parseIndexParams(context.TODO())
assert.NoError(t, err)
})
t.Run("metrics wrong for embedding list index", func(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "EMB_LIST_HNSW",
},
{
Key: common.MetricTypeKey,
Value: metric.L2,
},
},
IndexName: "",
},
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "EmbListFloat",
IsPrimaryKey: false,
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
},
}
err := cit.parseIndexParams(context.TODO())
assert.True(t, strings.Contains(err.Error(), "array of vector index does not support metric type: L2"))
})
t.Run("metric type wrong", func(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "HNSW",
},
{
Key: common.MetricTypeKey,
Value: metric.MaxSim,
},
},
IndexName: "",
},
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "FieldFloat",
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
},
}
err := cit.parseIndexParams(context.TODO())
assert.True(t, strings.Contains(err.Error(), "float vector index does not support metric type: MAX_SIM"))
})
t.Run("data type wrong", func(t *testing.T) {
cit := &createIndexTask{
Condition: nil,
req: &milvuspb.CreateIndexRequest{
ExtraParams: []*commonpb.KeyValuePair{
{
Key: common.IndexTypeKey,
Value: "EMB_LIST_HNSW",
},
{
Key: common.MetricTypeKey,
Value: metric.L2,
},
},
IndexName: "",
},
fieldSchema: &schemapb.FieldSchema{
FieldID: 101,
Name: "FieldFloat",
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
},
}
err := cit.parseIndexParams(context.TODO())
assert.True(t, strings.Contains(err.Error(), "data type FloatVector can't build with this index EMB_LIST_HNSW"))
})
}
func Test_ngram_parseIndexParams(t *testing.T) { func Test_ngram_parseIndexParams(t *testing.T) {
t.Run("valid ngram index params", func(t *testing.T) { t.Run("valid ngram index params", func(t *testing.T) {
cit := &createIndexTask{ cit := &createIndexTask{

View File

@ -217,11 +217,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return err return err
} }
allFields := make([]*schemapb.FieldSchema, 0, len(it.schema.Fields)+5) allFields := typeutil.GetAllFieldSchemas(it.schema)
allFields = append(allFields, it.schema.Fields...)
for _, structField := range it.schema.GetStructArrayFields() {
allFields = append(allFields, structField.GetFields()...)
}
// check primaryFieldData whether autoID is true or not // check primaryFieldData whether autoID is true or not
// set rowIDs as primary data if autoID == true // set rowIDs as primary data if autoID == true

View File

@ -650,7 +650,9 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
return err return err
} }
t.result.OutputFields = t.userOutputFields t.result.OutputFields = t.userOutputFields
reconstructStructFieldData(t.result, t.schema.CollectionSchema) if !t.reQuery {
reconstructStructFieldDataForQuery(t.result, t.schema.CollectionSchema)
}
primaryFieldSchema, err := t.schema.GetPkField() primaryFieldSchema, err := t.schema.GetPkField()
if err != nil { if err != nil {

View File

@ -1292,470 +1292,3 @@ func TestQueryTask_CanSkipAllocTimestamp(t *testing.T) {
assert.True(t, skip) assert.True(t, skip)
}) })
} }
func Test_reconstructStructFieldData(t *testing.T) {
t.Run("count(*) query - should return early", func(t *testing.T) {
results := &milvuspb.QueryResults{
OutputFields: []string{"count(*)"},
FieldsData: []*schemapb.FieldData{
{
FieldName: "count(*)",
FieldId: 0,
Type: schemapb.DataType_Int64,
},
},
}
schema := &schemapb.CollectionSchema{
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
FieldID: 102,
Name: "test_struct",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1021,
Name: "sub_field",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
},
},
},
},
}
originalFieldsData := make([]*schemapb.FieldData, len(results.FieldsData))
copy(originalFieldsData, results.FieldsData)
originalOutputFields := make([]string, len(results.OutputFields))
copy(originalOutputFields, results.OutputFields)
reconstructStructFieldData(results, schema)
// Should not modify anything for count(*) query
assert.Equal(t, originalFieldsData, results.FieldsData)
assert.Equal(t, originalOutputFields, results.OutputFields)
})
t.Run("no struct array fields - should return early", func(t *testing.T) {
results := &milvuspb.QueryResults{
OutputFields: []string{"field1", "field2"},
FieldsData: []*schemapb.FieldData{
{
FieldName: "field1",
FieldId: 100,
Type: schemapb.DataType_Int64,
},
{
FieldName: "field2",
FieldId: 101,
Type: schemapb.DataType_VarChar,
},
},
}
schema := &schemapb.CollectionSchema{
StructArrayFields: []*schemapb.StructArrayFieldSchema{},
}
originalFieldsData := make([]*schemapb.FieldData, len(results.FieldsData))
copy(originalFieldsData, results.FieldsData)
originalOutputFields := make([]string, len(results.OutputFields))
copy(originalOutputFields, results.OutputFields)
reconstructStructFieldData(results, schema)
// Should not modify anything when no struct array fields
assert.Equal(t, originalFieldsData, results.FieldsData)
assert.Equal(t, originalOutputFields, results.OutputFields)
})
t.Run("reconstruct single struct field", func(t *testing.T) {
// Create mock data
subField1Data := &schemapb.FieldData{
FieldName: "sub_int_array",
FieldId: 1021,
Type: schemapb.DataType_Array,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
ElementType: schemapb.DataType_Int32,
Data: []*schemapb.ScalarField{
{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},
},
},
},
},
},
},
},
}
subField2Data := &schemapb.FieldData{
FieldName: "sub_text_array",
FieldId: 1022,
Type: schemapb.DataType_Array,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
ElementType: schemapb.DataType_VarChar,
Data: []*schemapb.ScalarField{
{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{Data: []string{"hello", "world"}},
},
},
},
},
},
},
},
}
results := &milvuspb.QueryResults{
OutputFields: []string{"sub_int_array", "sub_text_array"},
FieldsData: []*schemapb.FieldData{subField1Data, subField2Data},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "pk",
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
},
},
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
FieldID: 102,
Name: "test_struct",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1021,
Name: "sub_int_array",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
},
{
FieldID: 1022,
Name: "sub_text_array",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_VarChar,
},
},
},
},
}
reconstructStructFieldData(results, schema)
// Check result
assert.Len(t, results.FieldsData, 1, "Should only have one reconstructed struct field")
assert.Len(t, results.OutputFields, 1, "Output fields should only have one")
structField := results.FieldsData[0]
assert.Equal(t, "test_struct", structField.FieldName)
assert.Equal(t, int64(102), structField.FieldId)
assert.Equal(t, schemapb.DataType_ArrayOfStruct, structField.Type)
assert.Equal(t, "test_struct", results.OutputFields[0])
// Check fields inside struct
structArrays := structField.GetStructArrays()
assert.NotNil(t, structArrays)
assert.Len(t, structArrays.Fields, 2, "Struct should contain 2 sub fields")
// Check sub fields
var foundIntField, foundTextField bool
for _, field := range structArrays.Fields {
switch field.FieldId {
case 1021:
assert.Equal(t, "sub_int_array", field.FieldName)
assert.Equal(t, schemapb.DataType_Array, field.Type)
foundIntField = true
case 1022:
assert.Equal(t, "sub_text_array", field.FieldName)
assert.Equal(t, schemapb.DataType_Array, field.Type)
foundTextField = true
}
}
assert.True(t, foundIntField, "Should find int array field")
assert.True(t, foundTextField, "Should find text array field")
})
t.Run("mixed regular and struct fields", func(t *testing.T) {
// Create regular field data
regularField := &schemapb.FieldData{
FieldName: "regular_field",
FieldId: 100,
Type: schemapb.DataType_Int64,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
},
},
},
}
// Create struct sub field data
subFieldData := &schemapb.FieldData{
FieldName: "sub_field",
FieldId: 1021,
Type: schemapb.DataType_Array,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
ElementType: schemapb.DataType_Int32,
Data: []*schemapb.ScalarField{
{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{Data: []int32{10, 20}},
},
},
},
},
},
},
},
}
results := &milvuspb.QueryResults{
OutputFields: []string{"regular_field", "sub_field"},
FieldsData: []*schemapb.FieldData{regularField, subFieldData},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "regular_field",
DataType: schemapb.DataType_Int64,
},
},
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
FieldID: 102,
Name: "test_struct",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1021,
Name: "sub_field",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
},
},
},
},
}
reconstructStructFieldData(results, schema)
// Check result: should have 2 fields (1 regular + 1 reconstructed struct)
assert.Len(t, results.FieldsData, 2)
assert.Len(t, results.OutputFields, 2)
// Check regular and struct fields both exist
var foundRegularField, foundStructField bool
for i, field := range results.FieldsData {
switch field.FieldId {
case 100:
assert.Equal(t, "regular_field", field.FieldName)
assert.Equal(t, schemapb.DataType_Int64, field.Type)
assert.Equal(t, "regular_field", results.OutputFields[i])
foundRegularField = true
case 102:
assert.Equal(t, "test_struct", field.FieldName)
assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
assert.Equal(t, "test_struct", results.OutputFields[i])
foundStructField = true
}
}
assert.True(t, foundRegularField, "Should find regular field")
assert.True(t, foundStructField, "Should find reconstructed struct field")
})
t.Run("multiple struct fields", func(t *testing.T) {
// Create sub field for first struct
struct1SubField := &schemapb.FieldData{
FieldName: "struct1_sub",
FieldId: 1021,
Type: schemapb.DataType_Array,
}
// Create sub fields for second struct
struct2SubField1 := &schemapb.FieldData{
FieldName: "struct2_sub1",
FieldId: 1031,
Type: schemapb.DataType_Array,
}
struct2SubField2 := &schemapb.FieldData{
FieldName: "struct2_sub2",
FieldId: 1032,
Type: schemapb.DataType_Array,
}
results := &milvuspb.QueryResults{
OutputFields: []string{"struct1_sub", "struct2_sub1", "struct2_sub2"},
FieldsData: []*schemapb.FieldData{struct1SubField, struct2SubField1, struct2SubField2},
}
schema := &schemapb.CollectionSchema{
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
FieldID: 102,
Name: "struct1",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1021,
Name: "struct1_sub",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
},
},
},
{
FieldID: 103,
Name: "struct2",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1031,
Name: "struct2_sub1",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_VarChar,
},
{
FieldID: 1032,
Name: "struct2_sub2",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Float,
},
},
},
},
}
reconstructStructFieldData(results, schema)
// Check result: should have 2 reconstructed struct fields
assert.Len(t, results.FieldsData, 2)
assert.Len(t, results.OutputFields, 2)
// Check both struct fields are reconstructed correctly
foundStruct1, foundStruct2 := false, false
for i, field := range results.FieldsData {
switch field.FieldId {
case 102:
assert.Equal(t, "struct1", field.FieldName)
assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
assert.Equal(t, "struct1", results.OutputFields[i])
structArrays := field.GetStructArrays()
assert.NotNil(t, structArrays)
assert.Len(t, structArrays.Fields, 1)
assert.Equal(t, int64(1021), structArrays.Fields[0].FieldId)
foundStruct1 = true
case 103:
assert.Equal(t, "struct2", field.FieldName)
assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
assert.Equal(t, "struct2", results.OutputFields[i])
structArrays := field.GetStructArrays()
assert.NotNil(t, structArrays)
assert.Len(t, structArrays.Fields, 2)
// Check struct2 contains two sub fields
subFieldIds := make([]int64, 0, 2)
for _, subField := range structArrays.Fields {
subFieldIds = append(subFieldIds, subField.FieldId)
}
assert.Contains(t, subFieldIds, int64(1031))
assert.Contains(t, subFieldIds, int64(1032))
foundStruct2 = true
}
}
assert.True(t, foundStruct1, "Should find struct1")
assert.True(t, foundStruct2, "Should find struct2")
})
t.Run("empty fields data", func(t *testing.T) {
results := &milvuspb.QueryResults{
OutputFields: []string{},
FieldsData: []*schemapb.FieldData{},
}
schema := &schemapb.CollectionSchema{
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
FieldID: 102,
Name: "test_struct",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1021,
Name: "sub_field",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
},
},
},
},
}
reconstructStructFieldData(results, schema)
// Empty data should remain unchanged
assert.Len(t, results.FieldsData, 0)
assert.Len(t, results.OutputFields, 0)
})
t.Run("no matching sub fields", func(t *testing.T) {
// Field data does not match any struct definition
regularField := &schemapb.FieldData{
FieldName: "regular_field",
FieldId: 200, // Not in any struct
Type: schemapb.DataType_Int64,
}
results := &milvuspb.QueryResults{
OutputFields: []string{"regular_field"},
FieldsData: []*schemapb.FieldData{regularField},
}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 200,
Name: "regular_field",
DataType: schemapb.DataType_Int64,
},
},
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
FieldID: 102,
Name: "test_struct",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1021,
Name: "sub_field",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
},
},
},
},
}
reconstructStructFieldData(results, schema)
// Should only keep the regular field, no struct field
assert.Len(t, results.FieldsData, 1)
assert.Len(t, results.OutputFields, 1)
assert.Equal(t, int64(200), results.FieldsData[0].FieldId)
assert.Equal(t, "regular_field", results.OutputFields[0])
})
}

View File

@ -544,7 +544,8 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
} }
} }
vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool { allFields := typeutil.GetAllFieldSchemas(t.schema.CollectionSchema)
vectorOutputFields := lo.Filter(allFields, func(field *schemapb.FieldSchema, _ int) bool {
return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType()) return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
}) })
t.needRequery = len(vectorOutputFields) > 0 t.needRequery = len(vectorOutputFields) > 0
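With GetAllFieldSchemas feeding the filter, an ArrayOfVector sub-field in the output list now also triggers a requery, since vector payloads are fetched in a second pass. The decision is equivalent to the following predicate (a sketch using samber/lo, mirroring the filter above):

```go
// needRequeryFor reports whether any requested output field is a
// vector type; struct-array vector sub-fields now count as well.
func needRequeryFor(schema *schemapb.CollectionSchema, outputFields []string) bool {
	return lo.SomeBy(typeutil.GetAllFieldSchemas(schema), func(f *schemapb.FieldSchema) bool {
		return lo.Contains(outputFields, f.GetName()) && typeutil.IsVectorType(f.GetDataType())
	})
}
```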
@ -765,6 +766,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
} }
t.fillResult() t.fillResult()
t.result.Results.OutputFields = t.userOutputFields t.result.Results.OutputFields = t.userOutputFields
reconstructStructFieldDataForSearch(t.result, t.schema.CollectionSchema)
t.result.CollectionName = t.request.GetCollectionName() t.result.CollectionName = t.request.GetCollectionName()
primaryFieldSchema, _ := t.schema.GetPkField() primaryFieldSchema, _ := t.schema.GetPkField()

View File

@ -3548,6 +3548,278 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
assert.ErrorContains(t, err, "failed to parse input last bound") assert.ErrorContains(t, err, "failed to parse input last bound")
}) })
}) })
t.Run("check vector array unsupported features", func(t *testing.T) {
// Helper function to create a schema with vector array field
createSchemaWithVectorArray := func(annsFieldName string) *schemapb.CollectionSchema {
return &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "id",
DataType: schemapb.DataType_Int64,
IsPrimaryKey: true,
},
{
FieldID: 101,
Name: "normal_vector",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
},
{
FieldID: 102,
Name: "group_field",
DataType: schemapb.DataType_VarChar,
},
},
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
FieldID: 103,
Name: "struct_array_field",
Fields: []*schemapb.FieldSchema{
{
FieldID: 104,
Name: annsFieldName,
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
},
},
},
},
}
}
// Helper function to create search params with anns field
createSearchParams := func(annsFieldName string) []*commonpb.KeyValuePair {
return []*commonpb.KeyValuePair{
{
Key: AnnsFieldKey,
Value: annsFieldName,
},
{
Key: TopKKey,
Value: "10",
},
{
Key: common.MetricTypeKey,
Value: metric.MaxSim,
},
{
Key: ParamsKey,
Value: `{"nprobe": 10}`,
},
{
Key: RoundDecimalKey,
Value: "-1",
},
{
Key: IgnoreGrowingKey,
Value: "false",
},
}
}
t.Run("vector array with range search", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("embeddings_list")
// Add radius parameter for range search
resetSearchParamsValue(params, ParamsKey, `{"nprobe": 10, "radius": 0.2}`)
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.Error(t, err)
assert.Nil(t, searchInfo)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.Contains(t, err.Error(), "range search is not supported for vector array (embedding list) fields")
})
t.Run("vector array with group by", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("embeddings_list")
// Add group by parameter
params = append(params, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "group_field",
})
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.Error(t, err)
assert.Nil(t, searchInfo)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.Contains(t, err.Error(), "group by search is not supported for vector array (embedding list) fields")
assert.Contains(t, err.Error(), "embeddings_list")
})
t.Run("vector array with iterator", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("embeddings_list")
// Add iterator parameter
params = append(params, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
})
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.Error(t, err)
assert.Nil(t, searchInfo)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.Contains(t, err.Error(), "search iterator is not supported for vector array (embedding list) fields")
assert.Contains(t, err.Error(), "embeddings_list")
})
t.Run("vector array with iterator v2", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("embeddings_list")
// Add iterator v2 parameters
params = append(params,
&commonpb.KeyValuePair{
Key: SearchIterV2Key,
Value: "True",
},
&commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
},
&commonpb.KeyValuePair{
Key: SearchIterBatchSizeKey,
Value: "10",
},
)
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.Error(t, err)
assert.Nil(t, searchInfo)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.Contains(t, err.Error(), "search iterator is not supported for vector array (embedding list) fields")
assert.Contains(t, err.Error(), "embeddings_list")
})
t.Run("normal search on vector array should succeed", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("embeddings_list")
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.NoError(t, err)
assert.NotNil(t, searchInfo)
assert.NotNil(t, searchInfo.planInfo)
})
t.Run("normal vector field with range search should succeed", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("normal_vector")
// Add radius parameter for range search
resetSearchParamsValue(params, ParamsKey, `{"nprobe": 10, "radius": 0.2}`)
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.NoError(t, err)
assert.NotNil(t, searchInfo)
assert.NotNil(t, searchInfo.planInfo)
})
t.Run("normal vector field with group by should succeed", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("normal_vector")
// Add group by parameter
params = append(params, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "group_field",
})
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.NoError(t, err)
assert.NotNil(t, searchInfo)
assert.NotNil(t, searchInfo.planInfo)
assert.Equal(t, int64(102), searchInfo.planInfo.GroupByFieldId)
})
t.Run("normal vector field with iterator should succeed", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("normal_vector")
// Add iterator parameter
params = append(params, &commonpb.KeyValuePair{
Key: IteratorField,
Value: "True",
})
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.NoError(t, err)
assert.NotNil(t, searchInfo)
assert.NotNil(t, searchInfo.planInfo)
})
t.Run("vector array with range search", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("embeddings_list")
resetSearchParamsValue(params, ParamsKey, `{"nprobe": 10, "radius": 0.2}`)
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.Error(t, err)
assert.Nil(t, searchInfo)
// Should fail on range search first
assert.Contains(t, err.Error(), "range search is not supported for vector array (embedding list) fields")
})
t.Run("no anns field specified", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := getValidSearchParams()
// Don't specify anns field
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.NoError(t, err)
assert.NotNil(t, searchInfo)
// Should not trigger vector array validation without anns field
})
t.Run("non-existent anns field", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
params := createSearchParams("non_existent_field")
searchInfo, err := parseSearchInfo(params, schema, nil)
assert.NoError(t, err)
assert.NotNil(t, searchInfo)
// Should not trigger vector array validation for non-existent field
})
t.Run("hybrid search with outer group by on vector array", func(t *testing.T) {
schema := createSchemaWithVectorArray("embeddings_list")
// Create rank params with group by
rankParams := getValidSearchParams()
rankParams = append(rankParams,
&commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "group_field",
},
&commonpb.KeyValuePair{
Key: LimitKey,
Value: "100",
},
)
parsedRankParams, err := parseRankParams(rankParams, schema)
assert.NoError(t, err)
searchParams := createSearchParams("embeddings_list")
// Parse search info with rank params
searchInfo, err := parseSearchInfo(searchParams, schema, parsedRankParams)
assert.Error(t, err)
assert.Nil(t, searchInfo)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
assert.Contains(t, err.Error(), "group by search is not supported for vector array (embedding list) fields")
})
})
} }
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData { func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
@ -4301,3 +4573,110 @@ func genTestSearchResultData(nq int64, topk int64, dType schemapb.DataType, fiel
} }
return result return result
} }
func TestSearchTask_InitSearchRequestWithStructArrayFields(t *testing.T) {
ctx := context.Background()
schema := &schemapb.CollectionSchema{
Name: "test_collection",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "regular_vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}},
{FieldID: 102, Name: "regular_scalar", DataType: schemapb.DataType_Int32},
},
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
Name: "structArray",
Fields: []*schemapb.FieldSchema{
{FieldID: 104, Name: "struct_vec_array", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "64"}}},
{FieldID: 105, Name: "struct_scalar_array", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32},
},
},
},
}
schemaInfo := newSchemaInfo(schema)
tests := []struct {
name string
outputFields []string
expectedRequery bool
description string
}{
{
name: "regular_vector_field",
outputFields: []string{"pk", "regular_vec"},
expectedRequery: true,
description: "Should require requery when regular vector field in output",
},
{
name: "struct_array_vector_field",
outputFields: []string{"pk", "struct_vec_array"},
expectedRequery: true,
description: "Should require requery when struct array vector field in output (tests GetAllFieldSchemas)",
},
{
name: "both_vector_fields",
outputFields: []string{"pk", "regular_vec", "struct_vec_array"},
expectedRequery: true,
description: "Should require requery when both regular and struct array vector fields in output",
},
{
name: "struct_scalar_array_only",
outputFields: []string{"pk", "struct_scalar_array"},
expectedRequery: false,
description: "Should not require requery when only struct scalar array field in output",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
task := &searchTask{
ctx: ctx,
collectionName: "test_collection",
SearchRequest: &internalpb.SearchRequest{
CollectionID: 1,
PartitionIDs: []int64{1},
Dsl: "",
PlaceholderGroup: nil,
DslType: commonpb.DslType_BoolExprV1,
OutputFieldsId: []int64{},
},
request: &milvuspb.SearchRequest{
CollectionName: "test_collection",
OutputFields: tt.outputFields,
Dsl: "",
SearchParams: []*commonpb.KeyValuePair{
{Key: AnnsFieldKey, Value: "regular_vec"},
{Key: TopKKey, Value: "10"},
{Key: common.MetricTypeKey, Value: metric.L2},
{Key: ParamsKey, Value: `{"nprobe": 10}`},
},
PlaceholderGroup: nil,
ConsistencyLevel: commonpb.ConsistencyLevel_Session,
},
schema: schemaInfo,
translatedOutputFields: tt.outputFields,
tr: timerecord.NewTimeRecorder("test"),
queryInfos: []*planpb.QueryInfo{{}},
}
// Set translated output field IDs based on the schema
outputFieldIDs := []int64{}
allFields := typeutil.GetAllFieldSchemas(schema)
for _, fieldName := range tt.outputFields {
for _, field := range allFields {
if field.Name == fieldName {
outputFieldIDs = append(outputFieldIDs, field.FieldID)
break
}
}
}
task.SearchRequest.OutputFieldsId = outputFieldIDs
err := task.initSearchRequest(ctx)
assert.NoError(t, err)
assert.Equal(t, tt.expectedRequery, task.needRequery, tt.description)
})
}
}

View File

@ -585,11 +585,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
return err return err
} }
allFields := make([]*schemapb.FieldSchema, 0, len(it.schema.Fields)+5) allFields := typeutil.GetAllFieldSchemas(it.schema.CollectionSchema)
allFields = append(allFields, it.schema.Fields...)
for _, structField := range it.schema.GetStructArrayFields() {
allFields = append(allFields, structField.GetFields()...)
}
// use the passed pk as new pk when autoID == false // use the passed pk as new pk when autoID == false
// automatically generate pk as new pk when autoID == true // automatically generate pk as new pk when autoID == true

View File

@ -619,9 +619,6 @@ func ValidateFieldsInStruct(field *schemapb.FieldSchema, schema *schemapb.Collec
if field.GetNullable() { if field.GetNullable() {
return fmt.Errorf("nullable is not supported for fields in struct array now, fieldName = %s", field.Name) return fmt.Errorf("nullable is not supported for fields in struct array now, fieldName = %s", field.Name)
} }
// todo(SpadeA): add more check when index is enabled
return nil return nil
} }
@ -2547,3 +2544,88 @@ func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 {
return 0 return 0
} }
// reconstructStructFieldDataCommon reconstructs struct fields from flattened sub-fields
// It works with both QueryResults and SearchResults by operating on the common data structures
func reconstructStructFieldDataCommon(
fieldsData []*schemapb.FieldData,
outputFields []string,
schema *schemapb.CollectionSchema,
) ([]*schemapb.FieldData, []string) {
if len(outputFields) == 1 && outputFields[0] == "count(*)" {
return fieldsData, outputFields
}
if len(schema.StructArrayFields) == 0 {
return fieldsData, outputFields
}
regularFieldIDs := make(map[int64]interface{})
subFieldToStructMap := make(map[int64]int64)
groupedStructFields := make(map[int64][]*schemapb.FieldData)
structFieldNames := make(map[int64]string)
reconstructedOutputFields := make([]string, 0, len(fieldsData))
// record all regular field IDs
for _, field := range schema.Fields {
regularFieldIDs[field.GetFieldID()] = nil
}
// build the mapping from sub-field ID to struct field ID
for _, structField := range schema.StructArrayFields {
for _, subField := range structField.GetFields() {
subFieldToStructMap[subField.GetFieldID()] = structField.GetFieldID()
}
structFieldNames[structField.GetFieldID()] = structField.GetName()
}
newFieldsData := make([]*schemapb.FieldData, 0, len(fieldsData))
for _, field := range fieldsData {
fieldID := field.GetFieldId()
if _, ok := regularFieldIDs[fieldID]; ok {
newFieldsData = append(newFieldsData, field)
reconstructedOutputFields = append(reconstructedOutputFields, field.GetFieldName())
} else {
structFieldID := subFieldToStructMap[fieldID]
groupedStructFields[structFieldID] = append(groupedStructFields[structFieldID], field)
}
}
for structFieldID, fields := range groupedStructFields {
fieldData := &schemapb.FieldData{
FieldName: structFieldNames[structFieldID],
FieldId: structFieldID,
Type: schemapb.DataType_ArrayOfStruct,
Field: &schemapb.FieldData_StructArrays{StructArrays: &schemapb.StructArrayField{Fields: fields}},
}
newFieldsData = append(newFieldsData, fieldData)
reconstructedOutputFields = append(reconstructedOutputFields, structFieldNames[structFieldID])
}
return newFieldsData, reconstructedOutputFields
}
// Wrapper for QueryResults
func reconstructStructFieldDataForQuery(results *milvuspb.QueryResults, schema *schemapb.CollectionSchema) {
fieldsData, outputFields := reconstructStructFieldDataCommon(
results.FieldsData,
results.OutputFields,
schema,
)
results.FieldsData = fieldsData
results.OutputFields = outputFields
}
// Wrapper for SearchResults
func reconstructStructFieldDataForSearch(results *milvuspb.SearchResults, schema *schemapb.CollectionSchema) {
if results.Results == nil {
return
}
fieldsData, outputFields := reconstructStructFieldDataCommon(
results.Results.FieldsData,
results.Results.OutputFields,
schema,
)
results.Results.FieldsData = fieldsData
results.Results.OutputFields = outputFields
}
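A compact usage sketch of the common helper: given a struct field 102 owning sub-fields 1021 and 1022 (names and IDs here are illustrative), flattened results fold back into a single ArrayOfStruct entry while regular fields pass through:

```go
schema := &schemapb.CollectionSchema{
	Fields: []*schemapb.FieldSchema{
		{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64},
	},
	StructArrayFields: []*schemapb.StructArrayFieldSchema{{
		FieldID: 102, Name: "test_struct",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 1021, Name: "sub_a", DataType: schemapb.DataType_Array},
			{FieldID: 1022, Name: "sub_b", DataType: schemapb.DataType_Array},
		},
	}},
}
fieldsData := []*schemapb.FieldData{
	{FieldName: "pk", FieldId: 100, Type: schemapb.DataType_Int64},
	{FieldName: "sub_a", FieldId: 1021, Type: schemapb.DataType_Array},
	{FieldName: "sub_b", FieldId: 1022, Type: schemapb.DataType_Array},
}
newFields, newOutputs := reconstructStructFieldDataCommon(fieldsData, []string{"pk", "sub_a", "sub_b"}, schema)
// newFields  -> [pk, test_struct{sub_a, sub_b}] (test_struct is ArrayOfStruct)
// newOutputs -> ["pk", "test_struct"]
```

Note that groupedStructFields is a Go map, so when several struct fields are reconstructed their relative order in the output is not deterministic; the tests below accordingly match by field ID rather than position.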

View File

@ -3825,5 +3825,644 @@ func TestCheckAndFlattenStructFieldData(t *testing.T) {
} }
func TestValidateFieldsInStruct(t *testing.T) { func TestValidateFieldsInStruct(t *testing.T) {
// todo(SpadeA): add test cases schema := &schemapb.CollectionSchema{
Name: "test_collection",
}
t.Run("valid array field", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "valid_array",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
}
err := ValidateFieldsInStruct(field, schema)
assert.NoError(t, err)
})
t.Run("valid array of vector field", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "valid_array_vector",
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
}
err := ValidateFieldsInStruct(field, schema)
assert.NoError(t, err)
})
t.Run("invalid field name", func(t *testing.T) {
testCases := []struct {
name string
expected string
}{
{"", "field name should not be empty"},
{"123abc", "The first character of a field name must be an underscore or letter"},
{"abc-def", "Field name can only contain numbers, letters, and underscores"},
}
for _, tc := range testCases {
field := &schemapb.FieldSchema{
Name: tc.name,
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expected)
}
})
t.Run("invalid data type", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "invalid_type",
DataType: schemapb.DataType_Int32, // Not array or array of vector
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "Fields in StructArrayField can only be array or array of struct")
})
t.Run("nested array not supported", func(t *testing.T) {
testCases := []struct {
elementType schemapb.DataType
}{
{schemapb.DataType_ArrayOfStruct},
{schemapb.DataType_ArrayOfVector},
{schemapb.DataType_Array},
}
for _, tc := range testCases {
field := &schemapb.FieldSchema{
Name: "nested_array",
DataType: schemapb.DataType_Array,
ElementType: tc.elementType,
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "Nested array is not supported")
}
})
t.Run("array field with vector element type", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "array_with_vector",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_FloatVector,
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "element type of array field array_with_vector is a vector type")
})
t.Run("array of vector field with non-vector element type", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "array_vector_with_scalar",
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_Int32,
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "element type of array field array_vector_with_scalar is not a vector type")
})
t.Run("array of vector missing dimension", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "array_vector_no_dim",
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{}, // No dimension specified
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "dimension is not defined in field")
})
t.Run("array of vector with invalid dimension", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "array_vector_invalid_dim",
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "not_a_number"},
},
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
})
t.Run("varchar array without max_length", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "varchar_array_no_max_length",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{}, // No max_length specified
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "type param(max_length) should be specified")
})
t.Run("varchar array with valid max_length", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "varchar_array_valid",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: "100"},
},
}
err := ValidateFieldsInStruct(field, schema)
assert.NoError(t, err)
})
t.Run("varchar array with invalid max_length", func(t *testing.T) {
// Test with max_length exceeding limit
field := &schemapb.FieldSchema{
Name: "varchar_array_invalid_length",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_VarChar,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.MaxLengthKey, Value: "99999999"}, // Exceeds limit
},
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "the maximum length specified for the field")
})
t.Run("nullable field not supported", func(t *testing.T) {
field := &schemapb.FieldSchema{
Name: "nullable_field",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
Nullable: true,
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "nullable is not supported for fields in struct array now")
})
t.Run("sparse float vector in array of vector", func(t *testing.T) {
// Note: ArrayOfVector with sparse vector element type still requires dimension
// because validateDimension checks the field's DataType (ArrayOfVector), not ElementType
field := &schemapb.FieldSchema{
Name: "sparse_vector_array",
DataType: schemapb.DataType_ArrayOfVector,
ElementType: schemapb.DataType_SparseFloatVector,
TypeParams: []*commonpb.KeyValuePair{},
}
err := ValidateFieldsInStruct(field, schema)
assert.Error(t, err)
assert.Contains(t, err.Error(), "dimension is not defined")
})
t.Run("array with various scalar element types", func(t *testing.T) {
validScalarTypes := []schemapb.DataType{
schemapb.DataType_Bool,
schemapb.DataType_Int8,
schemapb.DataType_Int16,
schemapb.DataType_Int32,
schemapb.DataType_Int64,
schemapb.DataType_Float,
schemapb.DataType_Double,
schemapb.DataType_String,
}
for _, dt := range validScalarTypes {
field := &schemapb.FieldSchema{
Name: "array_" + dt.String(),
DataType: schemapb.DataType_Array,
ElementType: dt,
}
err := ValidateFieldsInStruct(field, schema)
assert.NoError(t, err)
}
})
t.Run("array of vector with various vector types", func(t *testing.T) {
validVectorTypes := []schemapb.DataType{
schemapb.DataType_FloatVector,
schemapb.DataType_BinaryVector,
schemapb.DataType_Float16Vector,
schemapb.DataType_BFloat16Vector,
// Note: SparseFloatVector is excluded because validateDimension checks
// the field's DataType (ArrayOfVector), not ElementType, so it still requires dimension
}
for _, vt := range validVectorTypes {
field := &schemapb.FieldSchema{
Name: "vector_array_" + vt.String(),
DataType: schemapb.DataType_ArrayOfVector,
ElementType: vt,
TypeParams: []*commonpb.KeyValuePair{
{Key: common.DimKey, Value: "128"},
},
}
err := ValidateFieldsInStruct(field, schema)
assert.NoError(t, err)
}
})
}
func Test_reconstructStructFieldDataCommon(t *testing.T) {
t.Run("count(*) query - should return early", func(t *testing.T) {
fieldsData := []*schemapb.FieldData{
{
FieldName: "count(*)",
FieldId: 0,
Type: schemapb.DataType_Int64,
},
}
outputFields := []string{"count(*)"}
schema := &schemapb.CollectionSchema{
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
FieldID: 102,
Name: "test_struct",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1021,
Name: "sub_field",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
},
},
},
},
}
originalFieldsData := make([]*schemapb.FieldData, len(fieldsData))
copy(originalFieldsData, fieldsData)
originalOutputFields := make([]string, len(outputFields))
copy(originalOutputFields, outputFields)
resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
// Should not modify anything for count(*) query
assert.Equal(t, originalFieldsData, resultFieldsData)
assert.Equal(t, originalOutputFields, resultOutputFields)
})
t.Run("no struct array fields - should return early", func(t *testing.T) {
fieldsData := []*schemapb.FieldData{
{
FieldName: "field1",
FieldId: 100,
Type: schemapb.DataType_Int64,
},
{
FieldName: "field2",
FieldId: 101,
Type: schemapb.DataType_VarChar,
},
}
outputFields := []string{"field1", "field2"}
schema := &schemapb.CollectionSchema{
StructArrayFields: []*schemapb.StructArrayFieldSchema{},
}
originalFieldsData := make([]*schemapb.FieldData, len(fieldsData))
copy(originalFieldsData, fieldsData)
originalOutputFields := make([]string, len(outputFields))
copy(originalOutputFields, outputFields)
resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
// Should not modify anything when no struct array fields
assert.Equal(t, originalFieldsData, resultFieldsData)
assert.Equal(t, originalOutputFields, resultOutputFields)
})
t.Run("reconstruct single struct field", func(t *testing.T) {
// Create mock data
subField1Data := &schemapb.FieldData{
FieldName: "sub_int_array",
FieldId: 1021,
Type: schemapb.DataType_Array,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
ElementType: schemapb.DataType_Int32,
Data: []*schemapb.ScalarField{
{
Data: &schemapb.ScalarField_IntData{
IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},
},
},
},
},
},
},
},
}
subField2Data := &schemapb.FieldData{
FieldName: "sub_text_array",
FieldId: 1022,
Type: schemapb.DataType_Array,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
ElementType: schemapb.DataType_VarChar,
Data: []*schemapb.ScalarField{
{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{Data: []string{"hello", "world"}},
},
},
},
},
},
},
},
}
fieldsData := []*schemapb.FieldData{subField1Data, subField2Data}
outputFields := []string{"sub_int_array", "sub_text_array"}
schema := &schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
{
FieldID: 100,
Name: "pk",
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
},
},
StructArrayFields: []*schemapb.StructArrayFieldSchema{
{
FieldID: 102,
Name: "test_struct",
Fields: []*schemapb.FieldSchema{
{
FieldID: 1021,
Name: "sub_int_array",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_Int32,
},
{
FieldID: 1022,
Name: "sub_text_array",
DataType: schemapb.DataType_Array,
ElementType: schemapb.DataType_VarChar,
},
},
},
},
}
resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
// Check result
assert.Len(t, resultFieldsData, 1, "Should only have one reconstructed struct field")
		assert.Len(t, resultOutputFields, 1, "Output fields should only have one")
		structField := resultFieldsData[0]
		assert.Equal(t, "test_struct", structField.FieldName)
		assert.Equal(t, int64(102), structField.FieldId)
		assert.Equal(t, schemapb.DataType_ArrayOfStruct, structField.Type)
		assert.Equal(t, "test_struct", resultOutputFields[0])
		// Check fields inside struct
		structArrays := structField.GetStructArrays()
		assert.NotNil(t, structArrays)
		assert.Len(t, structArrays.Fields, 2, "Struct should contain 2 sub fields")
		// Check sub fields
		var foundIntField, foundTextField bool
		for _, field := range structArrays.Fields {
			switch field.FieldId {
			case 1021:
				assert.Equal(t, "sub_int_array", field.FieldName)
				assert.Equal(t, schemapb.DataType_Array, field.Type)
				foundIntField = true
			case 1022:
				assert.Equal(t, "sub_text_array", field.FieldName)
				assert.Equal(t, schemapb.DataType_Array, field.Type)
				foundTextField = true
			}
		}
		assert.True(t, foundIntField, "Should find int array field")
		assert.True(t, foundTextField, "Should find text array field")
	})

	t.Run("mixed regular and struct fields", func(t *testing.T) {
		// Create regular field data
		regularField := &schemapb.FieldData{
			FieldName: "regular_field",
			FieldId:   100,
			Type:      schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_LongData{
						LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
					},
				},
			},
		}
		// Create struct sub field data
		subFieldData := &schemapb.FieldData{
			FieldName: "sub_field",
			FieldId:   1021,
			Type:      schemapb.DataType_Array,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_ArrayData{
						ArrayData: &schemapb.ArrayArray{
							ElementType: schemapb.DataType_Int32,
							Data: []*schemapb.ScalarField{
								{
									Data: &schemapb.ScalarField_IntData{
										IntData: &schemapb.IntArray{Data: []int32{10, 20}},
									},
								},
							},
						},
					},
				},
			},
		}
		fieldsData := []*schemapb.FieldData{regularField, subFieldData}
		outputFields := []string{"regular_field", "sub_field"}
		schema := &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:  100,
					Name:     "regular_field",
					DataType: schemapb.DataType_Int64,
				},
			},
			StructArrayFields: []*schemapb.StructArrayFieldSchema{
				{
					FieldID: 102,
					Name:    "test_struct",
					Fields: []*schemapb.FieldSchema{
						{
							FieldID:     1021,
							Name:        "sub_field",
							DataType:    schemapb.DataType_Array,
							ElementType: schemapb.DataType_Int32,
						},
					},
				},
			},
		}
		resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
		// Check result: should have 2 fields (1 regular + 1 reconstructed struct)
		assert.Len(t, resultFieldsData, 2)
		assert.Len(t, resultOutputFields, 2)
		// Check regular and struct fields both exist
		var foundRegularField, foundStructField bool
		for i, field := range resultFieldsData {
			switch field.FieldId {
			case 100:
				assert.Equal(t, "regular_field", field.FieldName)
				assert.Equal(t, schemapb.DataType_Int64, field.Type)
				assert.Equal(t, "regular_field", resultOutputFields[i])
				foundRegularField = true
			case 102:
				assert.Equal(t, "test_struct", field.FieldName)
				assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
				assert.Equal(t, "test_struct", resultOutputFields[i])
				foundStructField = true
			}
		}
		assert.True(t, foundRegularField, "Should find regular field")
		assert.True(t, foundStructField, "Should find reconstructed struct field")
	})

	t.Run("multiple struct fields", func(t *testing.T) {
		// Create sub field for first struct
		struct1SubField := &schemapb.FieldData{
			FieldName: "struct1_sub",
			FieldId:   1021,
			Type:      schemapb.DataType_Array,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_ArrayData{
						ArrayData: &schemapb.ArrayArray{
							ElementType: schemapb.DataType_Int32,
							Data:        []*schemapb.ScalarField{},
						},
					},
				},
			},
		}
		// Create sub fields for second struct
		struct2SubField1 := &schemapb.FieldData{
			FieldName: "struct2_sub1",
			FieldId:   1031,
			Type:      schemapb.DataType_Array,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_ArrayData{
						ArrayData: &schemapb.ArrayArray{
							ElementType: schemapb.DataType_Int32,
							Data:        []*schemapb.ScalarField{},
						},
					},
				},
			},
		}
		struct2SubField2 := &schemapb.FieldData{
			FieldName: "struct2_sub2",
			FieldId:   1032,
			Type:      schemapb.DataType_VarChar,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_StringData{
						StringData: &schemapb.StringArray{Data: []string{"test"}},
					},
				},
			},
		}
		fieldsData := []*schemapb.FieldData{struct1SubField, struct2SubField1, struct2SubField2}
		outputFields := []string{"struct1_sub", "struct2_sub1", "struct2_sub2"}
		schema := &schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:      100,
					Name:         "pk",
					IsPrimaryKey: true,
					DataType:     schemapb.DataType_Int64,
				},
			},
			StructArrayFields: []*schemapb.StructArrayFieldSchema{
				{
					FieldID: 102,
					Name:    "struct1",
					Fields: []*schemapb.FieldSchema{
						{
							FieldID:     1021,
							Name:        "struct1_sub",
							DataType:    schemapb.DataType_Array,
							ElementType: schemapb.DataType_Int32,
						},
					},
				},
				{
					FieldID: 103,
					Name:    "struct2",
					Fields: []*schemapb.FieldSchema{
						{
							FieldID:     1031,
							Name:        "struct2_sub1",
							DataType:    schemapb.DataType_Array,
							ElementType: schemapb.DataType_Int32,
						},
						{
							FieldID:  1032,
							Name:     "struct2_sub2",
							DataType: schemapb.DataType_VarChar,
						},
					},
				},
			},
		}
		resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
		// Check result: should have 2 struct fields
		assert.Len(t, resultFieldsData, 2)
		assert.Len(t, resultOutputFields, 2)
		// Check both struct fields
		var foundStruct1, foundStruct2 bool
		for _, field := range resultFieldsData {
			switch field.FieldId {
			case 102:
				assert.Equal(t, "struct1", field.FieldName)
				assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
				foundStruct1 = true
				structArrays := field.GetStructArrays()
				assert.NotNil(t, structArrays)
				assert.Len(t, structArrays.Fields, 1)
			case 103:
				assert.Equal(t, "struct2", field.FieldName)
				assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
				foundStruct2 = true
				structArrays := field.GetStructArrays()
				assert.NotNil(t, structArrays)
				assert.Len(t, structArrays.Fields, 2)
			}
		}
		assert.True(t, foundStruct1, "Should find struct1")
		assert.True(t, foundStruct2, "Should find struct2")
	})
}
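For orientation, here is a minimal sketch of the regrouping behavior these tests exercise: flat sub-field data is folded back under its parent `StructArrayFieldSchema` as a single `ArrayOfStruct` field, while regular fields pass through untouched. The helper name and the `FieldData_StructArrays` oneof wrapper are assumptions inferred from the assertions above (via `GetStructArrays()`), not the actual `reconstructStructFieldDataCommon` implementation.

```go
// Sketch only: regroup flat sub-fields under their parent struct array field.
// regroupStructFields and the FieldData_StructArrays wrapper are assumed from
// the test assertions; the real implementation may differ.
func regroupStructFields(fieldsData []*schemapb.FieldData, schema *schemapb.CollectionSchema) ([]*schemapb.FieldData, []string) {
	parent := make(map[int64]*schemapb.StructArrayFieldSchema) // sub-field ID -> parent schema
	for _, s := range schema.GetStructArrayFields() {
		for _, f := range s.GetFields() {
			parent[f.GetFieldID()] = s
		}
	}
	built := make(map[int64]*schemapb.FieldData) // parent field ID -> reconstructed field
	var outData []*schemapb.FieldData
	var outNames []string
	for _, fd := range fieldsData {
		s, isSub := parent[fd.GetFieldId()]
		if !isSub { // regular field: pass through unchanged
			outData = append(outData, fd)
			outNames = append(outNames, fd.GetFieldName())
			continue
		}
		sf, seen := built[s.GetFieldID()]
		if !seen { // first sub-field of this struct: create the ArrayOfStruct wrapper
			sf = &schemapb.FieldData{
				FieldName: s.GetName(),
				FieldId:   s.GetFieldID(),
				Type:      schemapb.DataType_ArrayOfStruct,
				Field:     &schemapb.FieldData_StructArrays{StructArrays: &schemapb.StructArrayField{}},
			}
			built[s.GetFieldID()] = sf
			outData = append(outData, sf)
			outNames = append(outNames, s.GetName())
		}
		sf.GetStructArrays().Fields = append(sf.GetStructArrays().Fields, fd)
	}
	return outData, outNames
}
```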

View File

@@ -33,7 +33,6 @@ import (
 	"google.golang.org/protobuf/proto"

-	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/pkg/v2/metrics"
 	"github.com/milvus-io/milvus/pkg/v2/proto/cgopb"
 	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@@ -115,25 +114,6 @@ func (li *LoadIndexInfo) appendIndexFile(ctx context.Context, filePath string) e
 	return HandleCStatus(ctx, &status, "AppendIndexIFile failed")
 }

-// appendFieldInfo appends fieldID & fieldType to index
-func (li *LoadIndexInfo) appendFieldInfo(ctx context.Context, collectionID int64, partitionID int64, segmentID int64, fieldID int64, fieldType schemapb.DataType, enableMmap bool, mmapDirPath string) error {
-	var status C.CStatus
-	GetDynamicPool().Submit(func() (any, error) {
-		cColID := C.int64_t(collectionID)
-		cParID := C.int64_t(partitionID)
-		cSegID := C.int64_t(segmentID)
-		cFieldID := C.int64_t(fieldID)
-		cintDType := uint32(fieldType)
-		cEnableMmap := C.bool(enableMmap)
-		cMmapDirPath := C.CString(mmapDirPath)
-		defer C.free(unsafe.Pointer(cMmapDirPath))
-
-		status = C.AppendFieldInfo(li.cLoadIndexInfo, cColID, cParID, cSegID, cFieldID, cintDType, cEnableMmap, cMmapDirPath)
-		return nil, nil
-	}).Await()
-
-	return HandleCStatus(ctx, &status, "AppendFieldInfo failed")
-}
 func (li *LoadIndexInfo) appendStorageInfo(uri string, version int64) {
 	GetDynamicPool().Submit(func() (any, error) {
 		cURI := C.CString(uri)

View File

@@ -476,9 +476,16 @@ func checkFieldSchema(fieldSchemas []*schemapb.FieldSchema) error {
 func checkStructArrayFieldSchema(schemas []*schemapb.StructArrayFieldSchema) error {
 	for _, schema := range schemas {
-		// todo(SpadeA): check struct array field schema
+		if len(schema.GetFields()) == 0 {
+			return merr.WrapErrParameterInvalidMsg("empty fields in StructArrayField is not allowed")
+		}
 		for _, field := range schema.GetFields() {
+			if field.GetDataType() != schemapb.DataType_Array && field.GetDataType() != schemapb.DataType_ArrayOfVector {
+				msg := fmt.Sprintf("Fields in StructArrayField can only be array or array of vector, but field %s is %s", field.Name, field.DataType.String())
+				return merr.WrapErrParameterInvalidMsg(msg)
+			}
 			if field.IsPartitionKey || field.IsPrimaryKey {
 				msg := fmt.Sprintf("partition key or primary key can not be in struct array field. data type:%s, element type:%s, name:%s",
 					field.DataType.String(), field.ElementType.String(), field.Name)
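As a usage note, the strengthened check above means a struct array field must be non-empty and may only contain `Array` or `ArrayOfVector` members. A minimal sketch of a schema that would now be rejected (the field names are illustrative; the error strings come from the added code above):

```go
// Illustrative only: "plain_float" is a scalar, so checkStructArrayFieldSchema
// should reject it with "Fields in StructArrayField can only be array or
// array of vector, but field plain_float is Float".
bad := []*schemapb.StructArrayFieldSchema{{
	FieldID: 102,
	Name:    "test_struct",
	Fields: []*schemapb.FieldSchema{
		{FieldID: 1021, Name: "plain_float", DataType: schemapb.DataType_Float},
	},
}}
err := checkStructArrayFieldSchema(bad)
_ = err // expected: non-nil parameter-invalid error
```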

View File

@@ -19,6 +19,7 @@ package storage
 import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/pkg/v2/common"
+	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )

 // DataSorter sorts insert data
@@ -52,11 +53,7 @@ func (ds *DataSorter) Len() int {
 // Swap swaps each field's i-th and j-th element
 func (ds *DataSorter) Swap(i, j int) {
 	if ds.AllFields == nil {
-		allFields := ds.InsertCodec.Schema.Schema.Fields
-		for _, field := range ds.InsertCodec.Schema.Schema.StructArrayFields {
-			allFields = append(allFields, field.Fields...)
-		}
-		ds.AllFields = allFields
+		ds.AllFields = typeutil.GetAllFieldSchemas(ds.InsertCodec.Schema.Schema)
 	}

 	for _, field := range ds.AllFields {
 		singleData, has := ds.InsertData.Data[field.FieldID]
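The helper call centralizes exactly the loop being deleted; judging from the removed lines, its behavior should be equivalent to this sketch (the real implementation lives in `pkg/v2/util/typeutil` and may differ in detail):

```go
// Behavioral sketch reconstructed from the deleted inline loop: collect the
// top-level fields, then flatten every struct array field's sub-fields in.
func GetAllFieldSchemas(schema *schemapb.CollectionSchema) []*schemapb.FieldSchema {
	all := make([]*schemapb.FieldSchema, 0, len(schema.GetFields()))
	all = append(all, schema.GetFields()...)
	for _, structField := range schema.GetStructArrayFields() {
		all = append(all, structField.GetFields()...)
	}
	return all
}
```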

View File

@@ -282,9 +282,9 @@ func ValueDeserializerWithSchema(r Record, v []*Value, schema *schemapb.Collecti
 	return valueDeserializer(r, v, allFields, shouldCopy)
 }

-func valueDeserializer(r Record, v []*Value, fieldSchema []*schemapb.FieldSchema, shouldCopy bool) error {
+func valueDeserializer(r Record, v []*Value, fields []*schemapb.FieldSchema, shouldCopy bool) error {
 	pkField := func() *schemapb.FieldSchema {
-		for _, field := range fieldSchema {
+		for _, field := range fields {
 			if field.GetIsPrimaryKey() {
 				return field
 			}
@@ -299,12 +299,12 @@ func valueDeserializer(r Record, v []*Value, fieldSchema []*schemapb.FieldSchema
 		value := v[i]
 		if value == nil {
 			value = &Value{}
-			value.Value = make(map[FieldID]interface{}, len(fieldSchema))
+			value.Value = make(map[FieldID]interface{}, len(fields))
 			v[i] = value
 		}

 		m := value.Value.(map[FieldID]interface{})
-		for _, f := range fieldSchema {
+		for _, f := range fields {
 			j := f.FieldID
 			dt := f.DataType
 			if r.Column(j).IsNull(i) {

View File

@@ -1532,7 +1532,9 @@ func GetDefaultValue(fieldSchema *schemapb.FieldSchema) interface{} {
 func fillMissingFields(schema *schemapb.CollectionSchema, insertData *InsertData) error {
 	batchRows := int64(insertData.GetRowNum())

-	for _, field := range schema.Fields {
+	allFields := typeutil.GetAllFieldSchemas(schema)
+	for _, field := range allFields {
 		// Skip function output fields and system fields
 		if field.GetIsFunctionOutput() || field.GetFieldID() < 100 {
 			continue

View File

@@ -9,7 +9,7 @@ type AUTOINDEXChecker struct {
 	baseChecker
 }

-func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
+func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error {
 	return nil
 }

View File

@@ -24,7 +24,7 @@ import (
 type baseChecker struct{}

-func (c baseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
+func (c baseChecker) CheckTrain(_ schemapb.DataType, _ schemapb.DataType, _ map[string]string) error {
 	return nil
 }
@@ -35,7 +35,7 @@ func (c baseChecker) CheckValidDataType(indexType IndexType, field *schemapb.Fie
 func (c baseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, m map[string]string) {}

-func (c baseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
+func (c baseChecker) StaticCheck(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error {
 	return errors.New("unsupported index type")
 }
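With the widened signature, every checker now receives the element type alongside the field's data type; existing callers for plain vector fields pass `DataType_None`, as the updated tests below show, while an embedding-list index would presumably receive the element type of the vector array. A hedged sketch of the two call shapes (the `ArrayOfVector` call is an assumption for `EMB_LIST_HNSW`, whose checker is not part of this excerpt):

```go
// Fragment: assumes a checker from the registry and prepared train params.
c, _ := GetIndexCheckerMgrInstance().GetChecker("HNSW")
params := map[string]string{common.IndexTypeKey: "HNSW", "M": "30", "efConstruction": "360"}
err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, params)         // plain vector field
err = c.CheckTrain(schemapb.DataType_ArrayOfVector, schemapb.DataType_FloatVector, params) // vector array (embedding list), assumed shape
_ = err
```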

View File

@@ -47,9 +47,9 @@ func Test_baseChecker_CheckTrain(t *testing.T) {
 		test.params[common.IndexTypeKey] = "HNSW"
 		var err error
 		if test.params[common.IsSparseKey] == "True" {
-			err = c.CheckTrain(schemapb.DataType_SparseFloatVector, test.params)
+			err = c.CheckTrain(schemapb.DataType_SparseFloatVector, schemapb.DataType_None, test.params)
 		} else {
-			err = c.CheckTrain(schemapb.DataType_FloatVector, test.params)
+			err = c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params)
 		}
 		if test.errIsNil {
 			assert.NoError(t, err)
@@ -132,5 +132,5 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) {
 func Test_baseChecker_StaticCheck(t *testing.T) {
 	// TODO
-	assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, nil))
+	assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, schemapb.DataType_None, nil))
 }

View File

@@ -68,7 +68,7 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) {
 	c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
 	for _, test := range cases {
 		test.params[common.IndexTypeKey] = "BINFLAT"
-		err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
+		err := c.CheckTrain(schemapb.DataType_BinaryVector, schemapb.DataType_None, test.params)
 		if test.errIsNil {
 			assert.NoError(t, err)
 		} else {

View File

@@ -119,7 +119,7 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
 	c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT")
 	for _, test := range cases {
 		test.params[common.IndexTypeKey] = "BIN_IVF_FLAT"
-		err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
+		err := c.CheckTrain(schemapb.DataType_BinaryVector, schemapb.DataType_None, test.params)
 		if test.errIsNil {
 			assert.NoError(t, err)
 		} else {

Some files were not shown because too many files have changed in this diff.