diff --git a/internal/core/src/common/Chunk.h b/internal/core/src/common/Chunk.h index ca90ebc2f1..45a9ca47b2 100644 --- a/internal/core/src/common/Chunk.h +++ b/internal/core/src/common/Chunk.h @@ -395,6 +395,14 @@ class VectorArrayChunk : public Chunk { dim_(dim), element_type_(element_type) { offsets_lens_ = reinterpret_cast(data); + + auto offset = 0; + lims_.reserve(row_nums_ + 1); + lims_.push_back(offset); + for (int64_t i = 0; i < row_nums_; i++) { + offset += offsets_lens_[i * 2 + 1]; + lims_.push_back(offset); + } } VectorArrayView @@ -424,10 +432,23 @@ class VectorArrayChunk : public Chunk { "VectorArrayChunk::ValueAt is not supported"); } + const char* + Data() const override { + return data_ + offsets_lens_[0]; + } + + const size_t* + Lims() const { + return lims_.data(); + } + private: int64_t dim_; uint32_t* offsets_lens_; milvus::DataType element_type_; + // The name 'Lims' is consistent with knowhere::DataSet::SetLims which describes the number of vectors + // in each vector array (embedding list). This is needed as vectors are flattened in the chunk. 
+ std::vector lims_; }; class SparseFloatVectorChunk : public Chunk { diff --git a/internal/core/src/common/FieldData.h b/internal/core/src/common/FieldData.h index 0d5a6c6830..36b45c965c 100644 --- a/internal/core/src/common/FieldData.h +++ b/internal/core/src/common/FieldData.h @@ -92,6 +92,15 @@ class FieldData : public FieldDataVectorArrayImpl { ThrowInfo(Unsupported, "Call get_dim on FieldData is not supported"); } + + const VectorArray* + value_at(ssize_t offset) const { + AssertInfo(offset < get_num_rows(), + "field data subscript out of range"); + AssertInfo(offset < length(), + "subscript position don't has valid value"); + return &data_[offset]; + } }; template <> diff --git a/internal/core/src/common/TypeTraits.h b/internal/core/src/common/TypeTraits.h index f3875e52d6..c99ca0ebcd 100644 --- a/internal/core/src/common/TypeTraits.h +++ b/internal/core/src/common/TypeTraits.h @@ -47,6 +47,7 @@ constexpr bool IsVariableType = IsSparse || std::is_same_v || std::is_same_v; +// todo(SpadeA): support vector array template constexpr bool IsVariableTypeSupportInChunk = std::is_same_v || std::is_same_v || diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index d1b1125cb2..443687765f 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -493,7 +493,8 @@ IsFloatVectorMetricType(const MetricType& metric_type) { return metric_type == knowhere::metric::L2 || metric_type == knowhere::metric::IP || metric_type == knowhere::metric::COSINE || - metric_type == knowhere::metric::BM25; + metric_type == knowhere::metric::BM25 || + metric_type == knowhere::metric::MAX_SIM; } inline bool diff --git a/internal/core/src/common/VectorTrait.h b/internal/core/src/common/VectorTrait.h index a0c4b9bdba..c853577c46 100644 --- a/internal/core/src/common/VectorTrait.h +++ b/internal/core/src/common/VectorTrait.h @@ -28,20 +28,23 @@ namespace milvus { -#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT \ - using elem_type = 
std::conditional_t< \ - std::is_same_v, \ - milvus::FloatVector::embedded_type, \ - std::conditional_t< \ - std::is_same_v, \ - milvus::Float16Vector::embedded_type, \ - std::conditional_t< \ - std::is_same_v, \ - milvus::BFloat16Vector::embedded_type, \ - std::conditional_t< \ - std::is_same_v, \ - milvus::Int8Vector::embedded_type, \ - milvus::BinaryVector::embedded_type>>>>; +#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT \ + using elem_type = std::conditional_t< \ + std::is_same_v, \ + milvus::EmbListFloatVector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::FloatVector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::Float16Vector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::BFloat16Vector::embedded_type, \ + std::conditional_t< \ + std::is_same_v, \ + milvus::Int8Vector::embedded_type, \ + milvus::BinaryVector::embedded_type>>>>>; #define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \ auto schema_data_type = \ @@ -55,7 +58,13 @@ namespace milvus { ? 
milvus::Int8Vector::schema_data_type \ : milvus::BinaryVector::schema_data_type; -class VectorTrait {}; +class VectorTrait { + public: + static constexpr bool + is_embedding_list() { + return false; + } +}; class FloatVector : public VectorTrait { public: @@ -136,6 +145,25 @@ class Int8Vector : public VectorTrait { proto::common::PlaceholderType::Int8Vector; }; +class EmbListFloatVector : public VectorTrait { + public: + using embedded_type = float; + static constexpr int32_t dim_factor = 1; + static constexpr auto data_type = DataType::VECTOR_ARRAY; + static constexpr auto c_data_type = CDataType::VectorArray; + static constexpr auto schema_data_type = + proto::schema::DataType::ArrayOfVector; + static constexpr auto vector_type = + proto::plan::VectorType::EmbListFloatVector; + static constexpr auto placeholder_type = + proto::common::PlaceholderType::EmbListFloatVector; + + static constexpr bool + is_embedding_list() { + return true; + } +}; + struct FundamentalTag {}; struct StringTag {}; diff --git a/internal/core/src/common/type_c.h b/internal/core/src/common/type_c.h index cd3b2ea21e..73e6f8c46c 100644 --- a/internal/core/src/common/type_c.h +++ b/internal/core/src/common/type_c.h @@ -55,6 +55,7 @@ enum CDataType { BFloat16Vector = 103, SparseFloatVector = 104, Int8Vector = 105, + VectorArray = 106, }; typedef enum CDataType CDataType; diff --git a/internal/core/src/exec/operator/VectorSearchNode.cpp b/internal/core/src/exec/operator/VectorSearchNode.cpp index be7ec74404..8a3656f8a3 100644 --- a/internal/core/src/exec/operator/VectorSearchNode.cpp +++ b/internal/core/src/exec/operator/VectorSearchNode.cpp @@ -69,6 +69,7 @@ PhyVectorSearchNode::GetOutput() { auto& ph = placeholder_group_->at(0); auto src_data = ph.get_blob(); + auto src_lims = ph.get_lims(); auto num_queries = ph.num_of_queries_; milvus::SearchResult search_result; @@ -85,6 +86,7 @@ PhyVectorSearchNode::GetOutput() { col_input->size()); segment_->vector_search(search_info_, src_data, + 
src_lims, num_queries, query_timestamp_, final_view, diff --git a/internal/core/src/index/IndexFactory.cpp b/internal/core/src/index/IndexFactory.cpp index 47a60cbb60..dd9ddd2536 100644 --- a/internal/core/src/index/IndexFactory.cpp +++ b/internal/core/src/index/IndexFactory.cpp @@ -98,13 +98,18 @@ IndexFactory::CreatePrimitiveScalarIndex( LoadResourceRequest IndexFactory::IndexLoadResource( DataType field_type, + DataType element_type, IndexVersion index_version, float index_size, const std::map& index_params, bool mmap_enable) { if (milvus::IsVectorDataType(field_type)) { - return VecIndexLoadResource( - field_type, index_version, index_size, index_params, mmap_enable); + return VecIndexLoadResource(field_type, + element_type, + index_version, + index_size, + index_params, + mmap_enable); } else { return ScalarIndexLoadResource( field_type, index_version, index_size, index_params, mmap_enable); @@ -114,6 +119,7 @@ IndexFactory::IndexLoadResource( LoadResourceRequest IndexFactory::VecIndexLoadResource( DataType field_type, + DataType element_type, IndexVersion index_version, float index_size, const std::map& index_params, @@ -198,6 +204,29 @@ IndexFactory::VecIndexLoadResource( knowhere::IndexStaticFaced::HasRawData( index_type, index_version, config); break; + case milvus::DataType::VECTOR_ARRAY: { + switch (element_type) { + case milvus::DataType::VECTOR_FLOAT: + resource = knowhere::IndexStaticFaced< + knowhere::fp32>::EstimateLoadResource(index_type, + index_version, + index_size_gb, + config); + has_raw_data = + knowhere::IndexStaticFaced::HasRawData( + index_type, index_version, config); + break; + + default: + LOG_ERROR( + "invalid data type to estimate index load resource: " + "field_type {}, element_type {}", + field_type, + element_type); + return LoadResourceRequest{0, 0, 0, 0, true}; + } + break; + } default: LOG_ERROR("invalid data type to estimate index load resource: {}", field_type); @@ -491,8 +520,14 @@ IndexFactory::CreateVectorIndex( return 
std::make_unique>( index_type, metric_type, version, file_manager_context); } + case DataType::VECTOR_ARRAY: { + ThrowInfo(Unsupported, + "VECTOR_ARRAY for DiskAnnIndex is not supported"); + } case DataType::VECTOR_INT8: { // TODO caiyd, not support yet + ThrowInfo(Unsupported, + "VECTOR_INT8 for DiskAnnIndex is not supported"); } default: ThrowInfo( @@ -505,6 +540,7 @@ IndexFactory::CreateVectorIndex( case DataType::VECTOR_FLOAT: case DataType::VECTOR_SPARSE_FLOAT: { return std::make_unique>( + DataType::NONE, index_type, metric_type, version, @@ -513,6 +549,7 @@ IndexFactory::CreateVectorIndex( } case DataType::VECTOR_BINARY: { return std::make_unique>( + DataType::NONE, index_type, metric_type, version, @@ -521,6 +558,7 @@ IndexFactory::CreateVectorIndex( } case DataType::VECTOR_FLOAT16: { return std::make_unique>( + DataType::NONE, index_type, metric_type, version, @@ -529,6 +567,7 @@ IndexFactory::CreateVectorIndex( } case DataType::VECTOR_BFLOAT16: { return std::make_unique>( + DataType::NONE, index_type, metric_type, version, @@ -537,12 +576,33 @@ IndexFactory::CreateVectorIndex( } case DataType::VECTOR_INT8: { return std::make_unique>( + DataType::NONE, index_type, metric_type, version, use_knowhere_build_pool, file_manager_context); } + case DataType::VECTOR_ARRAY: { + auto element_type = + static_cast(file_manager_context.fieldDataMeta + .field_schema.element_type()); + switch (element_type) { + case DataType::VECTOR_FLOAT: + return std::make_unique>( + element_type, + index_type, + metric_type, + version, + use_knowhere_build_pool, + file_manager_context); + default: + ThrowInfo(NotImplemented, + fmt::format("not implemented data type to " + "build mem index: {}", + data_type)); + } + } default: ThrowInfo( DataTypeInvalid, diff --git a/internal/core/src/index/IndexFactory.h b/internal/core/src/index/IndexFactory.h index b682423b51..f0cf7e6e95 100644 --- a/internal/core/src/index/IndexFactory.h +++ b/internal/core/src/index/IndexFactory.h @@ -56,6 +56,7 
@@ class IndexFactory { LoadResourceRequest IndexLoadResource(DataType field_type, + DataType element_type, IndexVersion index_version, float index_size, const std::map& index_params, @@ -63,6 +64,7 @@ class IndexFactory { LoadResourceRequest VecIndexLoadResource(DataType field_type, + DataType element_type, IndexVersion index_version, float index_size, const std::map& index_params, diff --git a/internal/core/src/index/Utils.h b/internal/core/src/index/Utils.h index bcb374bf73..5c3be6006f 100644 --- a/internal/core/src/index/Utils.h +++ b/internal/core/src/index/Utils.h @@ -229,4 +229,21 @@ void inline SetBitsetGrowing(void* bitset, } } +inline size_t +vector_element_size(const DataType data_type) { + switch (data_type) { + case DataType::VECTOR_FLOAT: + return sizeof(float); + case DataType::VECTOR_FLOAT16: + return sizeof(float16); + case DataType::VECTOR_BFLOAT16: + return sizeof(bfloat16); + case DataType::VECTOR_INT8: + return sizeof(int8); + default: + ThrowInfo(UnexpectedError, + fmt::format("invalid data type: {}", data_type)); + } +} + } // namespace milvus::index diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index fad6409dac..ec06f6c750 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -245,7 +245,7 @@ VectorDiskAnnIndex::Query(const DatasetPtr dataset, SearchResult& search_result) const { AssertInfo(GetMetricType() == search_info.metric_type_, "Metric type of field index isn't the same with search info"); - auto num_queries = dataset->GetRows(); + auto num_rows = dataset->GetRows(); auto topk = search_info.topk_; knowhere::Json search_config = PrepareSearchParams(search_info); @@ -277,7 +277,7 @@ VectorDiskAnnIndex::Query(const DatasetPtr dataset, res.what())); } return ReGenRangeSearchResult( - res.value(), topk, num_queries, GetMetricType()); + res.value(), topk, num_rows, GetMetricType()); } else { auto res = index_.Search(dataset, 
search_config, bitset); if (!res.has_value()) { @@ -291,6 +291,8 @@ VectorDiskAnnIndex::Query(const DatasetPtr dataset, }(); auto ids = final->GetIds(); + // In embedding list query, final->GetRows() can be different from dataset->GetRows(). + auto num_queries = final->GetRows(); float* distances = const_cast(final->GetDistance()); final->SetIsOwner(true); diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index 31be8b3f0c..8de6e5c2cf 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -59,12 +59,14 @@ namespace milvus::index { template VectorMemIndex::VectorMemIndex( + DataType elem_type, const IndexType& index_type, const MetricType& metric_type, const IndexVersion& version, bool use_knowhere_build_pool, const storage::FileManagerContext& file_manager_context) : VectorIndex(index_type, metric_type), + elem_type_(elem_type), use_knowhere_build_pool_(use_knowhere_build_pool) { CheckMetricTypeSupport(metric_type); AssertInfo(!is_unsupported(index_type, metric_type), @@ -89,12 +91,14 @@ VectorMemIndex::VectorMemIndex( } template -VectorMemIndex::VectorMemIndex(const IndexType& index_type, +VectorMemIndex::VectorMemIndex(DataType elem_type, + const IndexType& index_type, const MetricType& metric_type, const IndexVersion& version, const knowhere::ViewDataOp view_data, bool use_knowhere_build_pool) : VectorIndex(index_type, metric_type), + elem_type_(elem_type), use_knowhere_build_pool_(use_knowhere_build_pool) { CheckMetricTypeSupport(metric_type); AssertInfo(!is_unsupported(index_type, metric_type), @@ -304,6 +308,11 @@ VectorMemIndex::BuildWithDataset(const DatasetPtr& dataset, SetDim(index_.Dim()); } +bool +is_embedding_list_index(const IndexType& index_type) { + return index_type == knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; +} + template void VectorMemIndex::Build(const Config& config) { @@ -331,23 +340,74 @@ VectorMemIndex::Build(const Config& config) { 
total_num_rows += data->get_num_rows(); AssertInfo(dim == 0 || dim == data->get_dim(), "inconsistent dim value between field datas!"); - dim = data->get_dim(); + + // todo(SpadeA): for now, vector arrays (embedding lists) are serialized + // to parquet in a binary format which does not provide dim + // information, so we use this temporary solution. + if (is_embedding_list_index(index_type_)) { + AssertInfo(elem_type_ != DataType::NONE, + "embedding list index must have elem_type"); + dim = config[DIM_KEY].get(); + } else { + dim = data->get_dim(); + } } auto buf = std::shared_ptr(new uint8_t[total_size]); + + size_t lim_offset = 0; + std::vector lims; + lims.reserve(total_num_rows + 1); + lims.push_back(lim_offset); + int64_t offset = 0; - // TODO: avoid copying - for (auto data : field_datas) { - std::memcpy(buf.get() + offset, data->Data(), data->Size()); - offset += data->Size(); - data.reset(); + if (!is_embedding_list_index(index_type_)) { + // TODO: avoid copying + for (auto data : field_datas) { + std::memcpy(buf.get() + offset, data->Data(), data->Size()); + offset += data->Size(); + data.reset(); + } + } else { + auto elem_size = vector_element_size(elem_type_); + for (auto data : field_datas) { + auto vec_array_data = + dynamic_cast*>(data.get()); + AssertInfo(vec_array_data != nullptr, + "failed to cast field data to vector array"); + + auto rows = vec_array_data->get_num_rows(); + for (auto i = 0; i < rows; ++i) { + auto size = vec_array_data->DataSize(i); + assert(size % (dim * elem_size) == 0); + assert(dim * elem_size != 0); + + auto vec_array = vec_array_data->value_at(i); + + std::memcpy(buf.get() + offset, vec_array->data(), size); + offset += size; + + lim_offset += size / (dim * elem_size); + lims.push_back(lim_offset); + } + + assert(data->Size() == offset); + + data.reset(); + } + + total_num_rows = lim_offset; } + field_datas.clear(); auto dataset = GenDataset(total_num_rows, dim, buf.get()); if (!scalar_info.empty()) { 
dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info)); } + if (!lims.empty()) { + dataset->SetLims(lims.data()); + } BuildWithDataset(dataset, build_config); } else { // sparse @@ -409,7 +469,7 @@ VectorMemIndex::Query(const DatasetPtr dataset, // AssertInfo(GetMetricType() == search_info.metric_type_, // "Metric type of field index isn't the same with search info"); - auto num_queries = dataset->GetRows(); + auto num_vectors = dataset->GetRows(); knowhere::Json search_conf = PrepareSearchParams(search_info); auto topk = search_info.topk_; // TODO :: check dim of search data @@ -427,7 +487,7 @@ VectorMemIndex::Query(const DatasetPtr dataset, res.what()); } auto result = ReGenRangeSearchResult( - res.value(), topk, num_queries, GetMetricType()); + res.value(), topk, num_vectors, GetMetricType()); milvus::tracer::AddEvent("finish_ReGenRangeSearchResult"); return result; } else { @@ -448,6 +508,8 @@ VectorMemIndex::Query(const DatasetPtr dataset, }(); auto ids = final->GetIds(); + // In embedding list query, final->GetRows() can be different from dataset->GetRows(). 
+ auto num_queries = final->GetRows(); float* distances = const_cast(final->GetDistance()); final->SetIsOwner(true); auto round_decimal = search_info.round_decimal_; diff --git a/internal/core/src/index/VectorMemIndex.h b/internal/core/src/index/VectorMemIndex.h index 0d2b98a8b0..50a0e37fc7 100644 --- a/internal/core/src/index/VectorMemIndex.h +++ b/internal/core/src/index/VectorMemIndex.h @@ -35,6 +35,7 @@ template class VectorMemIndex : public VectorIndex { public: explicit VectorMemIndex( + DataType elem_type /* used for embedding list only */, const IndexType& index_type, const MetricType& metric_type, const IndexVersion& version, @@ -43,7 +44,8 @@ class VectorMemIndex : public VectorIndex { storage::FileManagerContext()); // knowhere data view index special constucter for intermin index, no need to hold file_manager_ to upload or download files - VectorMemIndex(const IndexType& index_type, + VectorMemIndex(DataType elem_type /* used for embedding list only */, + const IndexType& index_type, const MetricType& metric_type, const IndexVersion& version, const knowhere::ViewDataOp view_data, @@ -108,6 +110,8 @@ class VectorMemIndex : public VectorIndex { Config config_; knowhere::Index index_; std::shared_ptr file_manager_; + // used for embedding list only + DataType elem_type_; CreateIndexInfo create_index_info_; bool use_knowhere_build_pool_; diff --git a/internal/core/src/indexbuilder/IndexFactory.h b/internal/core/src/indexbuilder/IndexFactory.h index fa75990dd7..d5019f9ab3 100644 --- a/internal/core/src/indexbuilder/IndexFactory.h +++ b/internal/core/src/indexbuilder/IndexFactory.h @@ -70,11 +70,8 @@ class IndexFactory { case DataType::VECTOR_BINARY: case DataType::VECTOR_SPARSE_FLOAT: case DataType::VECTOR_INT8: - return std::make_unique(type, config, context); - case DataType::VECTOR_ARRAY: - ThrowInfo(DataTypeInvalid, - fmt::format("VECTOR_ARRAY is not implemented")); + return std::make_unique(type, config, context); default: ThrowInfo(DataTypeInvalid, 
diff --git a/internal/core/src/indexbuilder/VecIndexCreator.cpp b/internal/core/src/indexbuilder/VecIndexCreator.cpp index 68afdff4ed..124dc4cd39 100644 --- a/internal/core/src/indexbuilder/VecIndexCreator.cpp +++ b/internal/core/src/indexbuilder/VecIndexCreator.cpp @@ -34,6 +34,13 @@ VecIndexCreator::VecIndexCreator( Config& config, const storage::FileManagerContext& file_manager_context) : config_(config), data_type_(data_type) { + if (data_type == DataType::VECTOR_ARRAY) { + // TODO(SpadeA): record dim in config because the dim cannot be inferred from + // parquet due to the serialization method of vector array. + // This should be a temporary solution. + config_[DIM_KEY] = file_manager_context.indexMeta.dim; + } + index::CreateIndexInfo index_info; index_info.field_type = data_type_; index_info.index_type = index::GetIndexTypeFromConfig(config_); diff --git a/internal/core/src/mmap/ChunkedColumn.h b/internal/core/src/mmap/ChunkedColumn.h index e90e33b75c..19bb6d990b 100644 --- a/internal/core/src/mmap/ChunkedColumn.h +++ b/internal/core/src/mmap/ChunkedColumn.h @@ -273,6 +273,13 @@ class ChunkedColumnBase : public ChunkedColumnInterface { "VectorArrayViews only supported for ChunkedVectorArrayColumn"); } + virtual PinWrapper + VectorArrayLims(int64_t chunk_id) const override { + ThrowInfo( + ErrorCode::Unsupported, + "VectorArrayLims only supported for ChunkedVectorArrayColumn"); + } + PinWrapper, FixedVector>> StringViewsByOffsets(int64_t chunk_id, const FixedVector& offsets) const override { @@ -621,6 +628,15 @@ class ChunkedVectorArrayColumn : public ChunkedColumnBase { return PinWrapper>( ca, static_cast(chunk)->Views()); } + + PinWrapper + VectorArrayLims(int64_t chunk_id) const override { + auto ca = + SemiInlineGet(slot_->PinCells({static_cast(chunk_id)})); + auto chunk = ca->get_cell_of(chunk_id); + return PinWrapper( + ca, static_cast(chunk)->Lims()); + } }; inline std::shared_ptr diff --git a/internal/core/src/mmap/ChunkedColumnGroup.h 
b/internal/core/src/mmap/ChunkedColumnGroup.h index 4b6416180c..70267cb4c0 100644 --- a/internal/core/src/mmap/ChunkedColumnGroup.h +++ b/internal/core/src/mmap/ChunkedColumnGroup.h @@ -319,6 +319,19 @@ class ProxyChunkColumn : public ChunkedColumnInterface { static_cast(chunk.get())->Views()); } + PinWrapper + VectorArrayLims(int64_t chunk_id) const override { + if (!IsChunkedVectorArrayColumnDataType(data_type_)) { + ThrowInfo(ErrorCode::Unsupported, + "VectorArrayLims only supported for " + "ChunkedVectorArrayColumn"); + } + auto chunk_wrapper = group_->GetGroupChunk(chunk_id); + auto chunk = chunk_wrapper.get()->GetChunk(field_id_); + return PinWrapper( + chunk_wrapper, static_cast(chunk.get())->Lims()); + } + PinWrapper, FixedVector>> StringViewsByOffsets(int64_t chunk_id, const FixedVector& offsets) const override { diff --git a/internal/core/src/mmap/ChunkedColumnInterface.h b/internal/core/src/mmap/ChunkedColumnInterface.h index 69698ec30e..715c2f36e1 100644 --- a/internal/core/src/mmap/ChunkedColumnInterface.h +++ b/internal/core/src/mmap/ChunkedColumnInterface.h @@ -84,6 +84,9 @@ class ChunkedColumnInterface { virtual PinWrapper> VectorArrayViews(int64_t chunk_id) const = 0; + virtual PinWrapper + VectorArrayLims(int64_t chunk_id) const = 0; + virtual PinWrapper< std::pair, FixedVector>> StringViewsByOffsets(int64_t chunk_id, diff --git a/internal/core/src/query/ExecPlanNodeVisitor.cpp b/internal/core/src/query/ExecPlanNodeVisitor.cpp index e38cb1cf5c..40bddc1130 100644 --- a/internal/core/src/query/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/ExecPlanNodeVisitor.cpp @@ -242,4 +242,9 @@ ExecPlanNodeVisitor::visit(Int8VectorANNS& node) { VectorVisitorImpl(node); } +void +ExecPlanNodeVisitor::visit(EmbListFloatVectorANNS& node) { + VectorVisitorImpl(node); +} + } // namespace milvus::query diff --git a/internal/core/src/query/ExecPlanNodeVisitor.h b/internal/core/src/query/ExecPlanNodeVisitor.h index 803a282881..6dcbae97ae 100644 --- 
a/internal/core/src/query/ExecPlanNodeVisitor.h +++ b/internal/core/src/query/ExecPlanNodeVisitor.h @@ -43,6 +43,9 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor { void visit(RetrievePlanNode& node) override; + void + visit(EmbListFloatVectorANNS& node) override; + public: ExecPlanNodeVisitor(const segcore::SegmentInterface& segment, Timestamp timestamp, diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index 9df4cda216..695a4f5ab1 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -30,6 +30,17 @@ ParsePlaceholderGroup(const Plan* plan, placeholder_group_blob.size()); } +bool +check_data_type(const FieldMeta& field_meta, + const milvus::proto::common::PlaceholderType type) { + if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { + return type == + milvus::proto::common::PlaceholderType::EmbListFloatVector; + } + return static_cast(field_meta.get_data_type()) == + static_cast(type); +} + std::unique_ptr ParsePlaceholderGroup(const Plan* plan, const uint8_t* blob, @@ -44,8 +55,7 @@ ParsePlaceholderGroup(const Plan* plan, Assert(plan->tag2field_.count(element.tag_)); auto field_id = plan->tag2field_.at(element.tag_); auto& field_meta = plan->schema_->operator[](field_id); - AssertInfo(static_cast(field_meta.get_data_type()) == - static_cast(info.type()), + AssertInfo(check_data_type(field_meta, info.type()), "vector type must be the same, field {} - type {}, search " "info type {}", field_meta.get_name().get(), @@ -59,23 +69,47 @@ ParsePlaceholderGroup(const Plan* plan, SparseBytesToRows(info.values(), /*validate=*/true); } else { auto line_size = info.values().Get(0).size(); - if (field_meta.get_sizeof() != line_size) { - ThrowInfo( - DimNotMatch, - fmt::format("vector dimension mismatch, expected vector " - "size(byte) {}, actual {}.", - field_meta.get_sizeof(), - line_size)); - } auto& target = element.blob_; - target.reserve(line_size * element.num_of_queries_); - for (auto& line : 
info.values()) { - AssertInfo(line_size == line.size(), - "vector dimension mismatch, expected vector " - "size(byte) {}, actual {}.", - line_size, - line.size()); - target.insert(target.end(), line.begin(), line.end()); + + if (field_meta.get_data_type() != DataType::VECTOR_ARRAY) { + if (field_meta.get_sizeof() != line_size) { + ThrowInfo(DimNotMatch, + fmt::format( + "vector dimension mismatch, expected vector " + "size(byte) {}, actual {}.", + field_meta.get_sizeof(), + line_size)); + } + target.reserve(line_size * element.num_of_queries_); + for (auto& line : info.values()) { + AssertInfo(line_size == line.size(), + "vector dimension mismatch, expected vector " + "size(byte) {}, actual {}.", + line_size, + line.size()); + target.insert(target.end(), line.begin(), line.end()); + } + } else { + target.reserve(line_size * element.num_of_queries_); + auto dim = field_meta.get_dim(); + + // If the vector is embedding list, line contains multiple vectors. + // And we should record the offsets so that we can identify each + // embedding list in a flattened vectors. + auto& lims = element.lims_; + lims.reserve(element.num_of_queries_ + 1); + size_t offset = 0; + lims.push_back(offset); + + auto elem_size = milvus::index::vector_element_size( + field_meta.get_element_type()); + for (auto& line : info.values()) { + target.insert(target.end(), line.begin(), line.end()); + + Assert(line.size() % (dim * elem_size) == 0); + offset += line.size() / (dim * elem_size); + lims.push_back(offset); + } } } result->emplace_back(std::move(element)); diff --git a/internal/core/src/query/PlanImpl.h b/internal/core/src/query/PlanImpl.h index 085f4a21b0..19f98fffaa 100644 --- a/internal/core/src/query/PlanImpl.h +++ b/internal/core/src/query/PlanImpl.h @@ -68,6 +68,9 @@ struct Plan { struct Placeholder { std::string tag_; + // note: for embedding list search, num_of_queries_ stands for the number of vectors. 
+ // lims_ records the offsets of embedding list in the flattened vector and + // hence lims_.size() - 1 is the number of queries in embedding list search. int64_t num_of_queries_; // TODO(SPARSE): add a dim_ field here, use the dim passed in search request // instead of the dim in schema, since the dim of sparse float column is @@ -78,6 +81,8 @@ struct Placeholder { // dense vector search and sparse_matrix_ is for sparse vector search. aligned_vector blob_; std::unique_ptr[]> sparse_matrix_; + // offsets for embedding list + aligned_vector lims_; const void* get_blob() const { @@ -94,6 +99,16 @@ struct Placeholder { } return blob_.data(); } + + const size_t* + get_lims() const { + return lims_.data(); + } + + size_t* + get_lims() { + return lims_.data(); + } }; struct RetrievePlan { diff --git a/internal/core/src/query/PlanNode.cpp b/internal/core/src/query/PlanNode.cpp index 65214706bd..ea07e804e0 100644 --- a/internal/core/src/query/PlanNode.cpp +++ b/internal/core/src/query/PlanNode.cpp @@ -50,4 +50,9 @@ RetrievePlanNode::accept(PlanNodeVisitor& visitor) { visitor.visit(*this); } +void +EmbListFloatVectorANNS::accept(PlanNodeVisitor& visitor) { + visitor.visit(*this); +} + } // namespace milvus::query diff --git a/internal/core/src/query/PlanNode.h b/internal/core/src/query/PlanNode.h index af5b113231..c6b9e53b52 100644 --- a/internal/core/src/query/PlanNode.h +++ b/internal/core/src/query/PlanNode.h @@ -77,6 +77,12 @@ struct Int8VectorANNS : VectorPlanNode { accept(PlanNodeVisitor&) override; }; +struct EmbListFloatVectorANNS : VectorPlanNode { + public: + void + accept(PlanNodeVisitor&) override; +}; + struct RetrievePlanNode : PlanNode { public: void diff --git a/internal/core/src/query/PlanNodeVisitor.h b/internal/core/src/query/PlanNodeVisitor.h index 9f4620ef47..f912165e4b 100644 --- a/internal/core/src/query/PlanNodeVisitor.h +++ b/internal/core/src/query/PlanNodeVisitor.h @@ -39,5 +39,8 @@ class PlanNodeVisitor { virtual void visit(RetrievePlanNode&) = 
0; + + virtual void + visit(EmbListFloatVectorANNS&) = 0; }; } // namespace milvus::query diff --git a/internal/core/src/query/PlanProto.cpp b/internal/core/src/query/PlanProto.cpp index ae08d72366..9dcc11fb22 100644 --- a/internal/core/src/query/PlanProto.cpp +++ b/internal/core/src/query/PlanProto.cpp @@ -127,6 +127,9 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { } else if (anns_proto.vector_type() == milvus::proto::plan::VectorType::Int8Vector) { return std::make_unique(); + } else if (anns_proto.vector_type() == + milvus::proto::plan::VectorType::EmbListFloatVector) { + return std::make_unique(); } else { return std::make_unique(); } diff --git a/internal/core/src/query/SearchBruteForce.cpp b/internal/core/src/query/SearchBruteForce.cpp index 8bfe880595..18435755dd 100644 --- a/internal/core/src/query/SearchBruteForce.cpp +++ b/internal/core/src/query/SearchBruteForce.cpp @@ -89,8 +89,23 @@ PrepareBFDataSet(const dataset::SearchDataset& query_ds, DataType data_type) { auto base_dataset = knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data); + if (raw_ds.raw_data_lims != nullptr) { + // knowhere::DataSet counts vectors in a flattened manner, whereas num_raw_data here is the number + // of embedding lists, where each embedding list contains multiple vectors. So we should use the last element + // in lims, which equals the total number of vectors. 
+ base_dataset->SetLims(raw_ds.raw_data_lims); + // the length of lims equals to the number of embedding lists + 1 + base_dataset->SetRows(raw_ds.raw_data_lims[raw_ds.num_raw_data]); + } + auto query_dataset = knowhere::GenDataSet( query_ds.num_queries, query_ds.dim, query_ds.query_data); + if (query_ds.query_lims != nullptr) { + // ditto + query_dataset->SetLims(query_ds.query_lims); + query_dataset->SetRows(query_ds.query_lims[query_ds.num_queries]); + } + if (data_type == DataType::VECTOR_SPARSE_FLOAT) { base_dataset->SetIsSparse(true); query_dataset->SetIsSparse(true); @@ -105,7 +120,8 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, const SearchInfo& search_info, const std::map& index_info, const BitsetView& bitset, - DataType data_type) { + DataType data_type, + DataType element_type) { SubSearchResult sub_result(query_ds.num_queries, query_ds.topk, query_ds.metric_type, @@ -122,7 +138,18 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, sub_result.mutable_seg_offsets().resize(nq * topk); sub_result.mutable_distances().resize(nq * topk); + // For vector array (embedding list), element type is used to determine how to operate search. 
+ if (data_type == DataType::VECTOR_ARRAY) { + AssertInfo(element_type != DataType::NONE, + "Element type is not specified for vector array"); + data_type = element_type; + } + if (search_cfg.contains(RADIUS)) { + AssertInfo(data_type != DataType::VECTOR_ARRAY, + "Vector array(embedding list) is not supported for range " + "search"); + if (search_cfg.contains(RANGE_FILTER)) { CheckRangeSearchParam(search_cfg[RADIUS], search_cfg[RANGE_FILTER], @@ -238,7 +265,10 @@ DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset, const knowhere::DataSetPtr& query_dataset, const knowhere::Json& config, const BitsetView& bitset, - const milvus::DataType& data_type) { + milvus::DataType data_type) { + AssertInfo(data_type != DataType::VECTOR_ARRAY, + "VECTOR_ARRAY is not supported for brute force iterator"); + switch (data_type) { case DataType::VECTOR_FLOAT: return knowhere::BruteForce::AnnIterator( diff --git a/internal/core/src/query/SearchBruteForce.h b/internal/core/src/query/SearchBruteForce.h index 51e348bfa2..05389db4e5 100644 --- a/internal/core/src/query/SearchBruteForce.h +++ b/internal/core/src/query/SearchBruteForce.h @@ -29,7 +29,8 @@ BruteForceSearch(const dataset::SearchDataset& query_ds, const SearchInfo& search_info, const std::map& index_info, const BitsetView& bitset, - DataType data_type); + DataType data_type, + DataType element_type); knowhere::expected> GetBruteForceSearchIterators( diff --git a/internal/core/src/query/SearchOnGrowing.cpp b/internal/core/src/query/SearchOnGrowing.cpp index 982d69e846..4e367db994 100644 --- a/internal/core/src/query/SearchOnGrowing.cpp +++ b/internal/core/src/query/SearchOnGrowing.cpp @@ -71,6 +71,7 @@ void SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, const SearchInfo& info, const void* query_data, + const size_t* query_lims, int64_t num_queries, Timestamp timestamp, const BitsetView& bitset, @@ -87,6 +88,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, 
CheckBruteForceSearchParam(field, info); auto data_type = field.get_data_type(); + auto element_type = field.get_element_type(); AssertInfo(IsVectorDataType(data_type), "[SearchOnGrowing]Data type isn't vector type"); @@ -96,6 +98,11 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, // step 2: small indexing search if (segment.get_indexing_record().SyncDataWithIndex(field.get_id())) { + AssertInfo( + data_type != DataType::VECTOR_ARRAY, + "vector array(embedding list) is not supported for growing segment " + "indexing search"); + FloatSegmentIndexSearch( segment, info, query_data, num_queries, bitset, search_result); } else { @@ -103,6 +110,10 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, segment.get_chunk_mutex()); // check SyncDataWithIndex() again, in case the vector chunks has been removed. if (segment.get_indexing_record().SyncDataWithIndex(field.get_id())) { + AssertInfo(data_type != DataType::VECTOR_ARRAY, + "vector array(embedding list) is not supported for " + "growing segment indexing search"); + return FloatSegmentIndexSearch( segment, info, query_data, num_queries, bitset, search_result); } @@ -111,8 +122,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, auto dim = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT ? 
0 : field.get_dim(); - dataset::SearchDataset search_dataset{ - metric_type, num_queries, topk, round_decimal, dim, query_data}; + dataset::SearchDataset search_dataset{metric_type, + num_queries, + topk, + round_decimal, + dim, + query_data, + query_lims}; int32_t current_chunk_id = 0; // get K1 and B from index for bm25 brute force @@ -127,6 +143,10 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, auto vec_ptr = record.get_data_base(vecfield_id); if (info.iterator_v2_info_.has_value()) { + AssertInfo(data_type != DataType::VECTOR_ARRAY, + "vector array(embedding list) is not supported for " + "vector iterator"); + CachedSearchIterator cached_iter(search_dataset, vec_ptr, active_count, @@ -150,9 +170,54 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, std::min(active_count, (chunk_id + 1) * vec_size_per_chunk); auto size_per_chunk = element_end - element_begin; - auto sub_data = query::dataset::RawDataset{ - element_begin, dim, size_per_chunk, chunk_data}; + query::dataset::RawDataset sub_data; + std::unique_ptr buf = nullptr; + std::vector offsets; + if (data_type != DataType::VECTOR_ARRAY) { + sub_data = query::dataset::RawDataset{ + element_begin, dim, size_per_chunk, chunk_data}; + } else { + // TODO(SpadeA): For VectorArray(Embedding List), data is + // stored discretely in FixedVector, which means we must copy the + // data to a contiguous memory buffer. This is inefficient and + // will be optimized in the future.
+ auto vec_ptr = reinterpret_cast(chunk_data); + auto size = 0; + for (int i = 0; i < size_per_chunk; ++i) { + size += vec_ptr[i].byte_size(); + } + + buf = std::make_unique(size); + offsets.reserve(size_per_chunk + 1); + offsets.push_back(0); + + auto offset = 0; + auto ptr = buf.get(); + for (int i = 0; i < size_per_chunk; ++i) { + memcpy(ptr, vec_ptr[i].data(), vec_ptr[i].byte_size()); + ptr += vec_ptr[i].byte_size(); + + offset += vec_ptr[i].length(); + offsets.push_back(offset); + } + sub_data = query::dataset::RawDataset{element_begin, + dim, + size_per_chunk, + buf.get(), + offsets.data()}; + } + + if (data_type == DataType::VECTOR_ARRAY) { + AssertInfo( + query_lims != nullptr, + "query_lims is nullptr, but data_type is vector array"); + } + if (milvus::exec::UseVectorIterator(info)) { + AssertInfo(data_type != DataType::VECTOR_ARRAY, + "vector array(embedding list) is not supported for " + "vector iterator"); + auto sub_qr = PackBruteForceSearchIteratorsIntoSubResult(search_dataset, sub_data, @@ -167,7 +232,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, info, index_info, bitset, - data_type); + data_type, + element_type); final_qr.merge(sub_qr); } } diff --git a/internal/core/src/query/SearchOnGrowing.h b/internal/core/src/query/SearchOnGrowing.h index 0b6aeb1add..c487c9532e 100644 --- a/internal/core/src/query/SearchOnGrowing.h +++ b/internal/core/src/query/SearchOnGrowing.h @@ -20,6 +20,7 @@ void SearchOnGrowing(const segcore::SegmentGrowingImpl& segment, const SearchInfo& info, const void* query_data, + const size_t* query_lims, int64_t num_queries, Timestamp timestamp, const BitsetView& bitset, diff --git a/internal/core/src/query/SearchOnSealed.cpp b/internal/core/src/query/SearchOnSealed.cpp index b40a5b7d95..ed76aa40ac 100644 --- a/internal/core/src/query/SearchOnSealed.cpp +++ b/internal/core/src/query/SearchOnSealed.cpp @@ -31,6 +31,7 @@ SearchOnSealedIndex(const Schema& schema, const segcore::SealedIndexingRecord& record, const 
SearchInfo& search_info, const void* query_data, + const size_t* query_lims, int64_t num_queries, const BitsetView& bitset, SearchResult& search_result) { @@ -52,7 +53,18 @@ SearchOnSealedIndex(const Schema& schema, field_indexing->metric_type_, search_info.metric_type_); - auto dataset = knowhere::GenDataSet(num_queries, dim, query_data); + knowhere::DataSetPtr dataset; + if (query_lims == nullptr) { + dataset = knowhere::GenDataSet(num_queries, dim, query_data); + } else { + // Unlike non-embedding-list search, where num_queries equals the number of vectors, + // in embedding list search multiple vectors form an embedding list, and the last element of query_lims + // is the total number of vectors. + auto num_vectors = query_lims[num_queries]; + dataset = knowhere::GenDataSet(num_vectors, dim, query_data); + dataset->SetLims(query_lims); + } + dataset->SetIsSparse(is_sparse); auto accessor = SemiInlineGet(field_indexing->indexing_->PinCells({0})); auto vec_index = @@ -92,6 +104,7 @@ SearchOnSealedColumn(const Schema& schema, const SearchInfo& search_info, const std::map& index_info, const void* query_data, + const size_t* query_lims, int64_t num_queries, int64_t row_count, const BitsetView& bitview, @@ -99,22 +112,26 @@ SearchOnSealedColumn(const Schema& schema, auto field_id = search_info.field_id_; auto& field = schema[field_id]; + auto data_type = field.get_data_type(); + auto element_type = field.get_element_type(); // TODO(SPARSE): see todo in PlanImpl.h::PlaceHolder. - auto dim = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT - ? 0 - : field.get_dim(); + auto dim = data_type == DataType::VECTOR_SPARSE_FLOAT ?
0 : field.get_dim(); query::dataset::SearchDataset query_dataset{search_info.metric_type_, num_queries, search_info.topk_, search_info.round_decimal_, dim, - query_data}; + query_data, + query_lims}; - auto data_type = field.get_data_type(); CheckBruteForceSearchParam(field, search_info); if (search_info.iterator_v2_info_.has_value()) { + AssertInfo(data_type != DataType::VECTOR_ARRAY, + "vector array(embedding list) is not supported for " + "vector iterator"); + CachedSearchIterator cached_iter( column, query_dataset, search_info, index_info, bitview, data_type); cached_iter.NextBatch(search_info, result); @@ -135,7 +152,20 @@ SearchOnSealedColumn(const Schema& schema, auto chunk_size = column->chunk_row_nums(i); auto raw_dataset = query::dataset::RawDataset{offset, dim, chunk_size, vec_data}; + + PinWrapper lims_pw; + if (data_type == DataType::VECTOR_ARRAY) { + AssertInfo(query_lims != nullptr, + "query_lims is nullptr, but data_type is vector array"); + + lims_pw = column->VectorArrayLims(i); + raw_dataset.raw_data_lims = lims_pw.get(); + } + if (milvus::exec::UseVectorIterator(search_info)) { + AssertInfo(data_type != DataType::VECTOR_ARRAY, + "vector array(embedding list) is not supported for " + "vector iterator"); auto sub_qr = PackBruteForceSearchIteratorsIntoSubResult(query_dataset, raw_dataset, @@ -150,7 +180,8 @@ SearchOnSealedColumn(const Schema& schema, search_info, index_info, bitview, - data_type); + data_type, + element_type); final_qr.merge(sub_qr); } offset += chunk_size; diff --git a/internal/core/src/query/SearchOnSealed.h b/internal/core/src/query/SearchOnSealed.h index 9340c5a664..20314c69da 100644 --- a/internal/core/src/query/SearchOnSealed.h +++ b/internal/core/src/query/SearchOnSealed.h @@ -23,6 +23,7 @@ SearchOnSealedIndex(const Schema& schema, const segcore::SealedIndexingRecord& record, const SearchInfo& search_info, const void* query_data, + const size_t* query_lims, int64_t num_queries, const BitsetView& view, SearchResult& 
search_result); @@ -33,6 +34,7 @@ SearchOnSealedColumn(const Schema& schema, const SearchInfo& search_info, const std::map& index_info, const void* query_data, + const size_t* query_lims, int64_t num_queries, int64_t row_count, const BitsetView& bitset, diff --git a/internal/core/src/query/helper.h b/internal/core/src/query/helper.h index 56bf1d1261..034d4854fb 100644 --- a/internal/core/src/query/helper.h +++ b/internal/core/src/query/helper.h @@ -24,6 +24,7 @@ struct RawDataset { int64_t dim; int64_t num_raw_data; const void* raw_data; + const size_t* raw_data_lims = nullptr; }; struct SearchDataset { knowhere::MetricType metric_type; @@ -32,6 +33,8 @@ struct SearchDataset { int64_t round_decimal; int64_t dim; const void* query_data; + // used for embedding list query + const size_t* query_lims = nullptr; }; } // namespace dataset diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp index 878ad0d882..c088ef665d 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.cpp @@ -95,10 +95,6 @@ ChunkedSegmentSealedImpl::LoadIndex(const LoadIndexInfo& info) { auto field_id = FieldId(info.field_id); auto& field_meta = schema_->operator[](field_id); - if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { - ThrowInfo(DataTypeInvalid, "VECTOR_ARRAY is not implemented"); - } - if (field_meta.is_vector()) { LoadVecIndex(info); } else { @@ -127,6 +123,7 @@ ChunkedSegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) { LoadResourceRequest request = milvus::index::IndexFactory::GetInstance().VecIndexLoadResource( field_meta.get_data_type(), + info.element_type, info.index_engine_version, info.index_size, info.index_params, @@ -498,10 +495,6 @@ int64_t ChunkedSegmentSealedImpl::num_chunk_index(FieldId field_id) const { auto& field_meta = schema_->operator[](field_id); - if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { - 
ThrowInfo(DataTypeInvalid, "VECTOR_ARRAY is not implemented"); - } - if (field_meta.is_vector()) { return int64_t(vector_indexings_.is_ready(field_id)); } @@ -720,6 +713,7 @@ ChunkedSegmentSealedImpl::mask_with_delete(BitsetTypeView& bitset, void ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, const void* query_data, + const size_t* query_lims, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, @@ -745,6 +739,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, vector_indexings_, binlog_search_info, query_data, + query_lims, query_count, bitset, output); @@ -758,6 +753,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, vector_indexings_, search_info, query_data, + query_lims, query_count, bitset, output); @@ -782,6 +778,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info, search_info, index_info, query_data, + query_lims, query_count, row_count, bitset, diff --git a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h index b4a63a2bcc..3e94aebe0b 100644 --- a/internal/core/src/segcore/ChunkedSegmentSealedImpl.h +++ b/internal/core/src/segcore/ChunkedSegmentSealedImpl.h @@ -395,6 +395,7 @@ class ChunkedSegmentSealedImpl : public SegmentSealed { void vector_search(SearchInfo& search_info, const void* query_data, + const size_t* query_lims, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, diff --git a/internal/core/src/segcore/FieldIndexing.cpp b/internal/core/src/segcore/FieldIndexing.cpp index 88d6c1098e..27b719cb6b 100644 --- a/internal/core/src/segcore/FieldIndexing.cpp +++ b/internal/core/src/segcore/FieldIndexing.cpp @@ -47,6 +47,7 @@ VectorFieldIndexing::recreate_index(DataType data_type, const VectorBase* field_raw_data) { if (IsSparseFloatVectorDataType(data_type)) { index_ = std::make_unique>( + DataType::NONE, config_->GetIndexType(), config_->GetMetricType(), 
knowhere::Version::GetCurrentVersion().VersionNumber()); @@ -62,6 +63,7 @@ VectorFieldIndexing::recreate_index(DataType data_type, return (const void*)field_raw_data_ptr->get_element(id); }; index_ = std::make_unique>( + DataType::NONE, config_->GetIndexType(), config_->GetMetricType(), knowhere::Version::GetCurrentVersion().VersionNumber(), @@ -78,6 +80,7 @@ VectorFieldIndexing::recreate_index(DataType data_type, return (const void*)field_raw_data_ptr->get_element(id); }; index_ = std::make_unique>( + DataType::NONE, config_->GetIndexType(), config_->GetMetricType(), knowhere::Version::GetCurrentVersion().VersionNumber(), @@ -94,6 +97,7 @@ VectorFieldIndexing::recreate_index(DataType data_type, return (const void*)field_raw_data_ptr->get_element(id); }; index_ = std::make_unique>( + DataType::NONE, config_->GetIndexType(), config_->GetMetricType(), knowhere::Version::GetCurrentVersion().VersionNumber(), diff --git a/internal/core/src/segcore/FieldIndexing.h b/internal/core/src/segcore/FieldIndexing.h index 4df1231aeb..68d21bd764 100644 --- a/internal/core/src/segcore/FieldIndexing.h +++ b/internal/core/src/segcore/FieldIndexing.h @@ -292,8 +292,9 @@ class IndexingRecord { index_meta_->HasFiled(field_id)) { auto vec_field_meta = index_meta_->GetFieldIndexMeta(field_id); - //Disable growing index for flat - if (!vec_field_meta.IsFlatIndex()) { + //Disable growing index for flat and embedding list + if (!vec_field_meta.IsFlatIndex() && + field_meta.get_data_type() != DataType::VECTOR_ARRAY) { auto field_raw_data = insert_record->get_data_base(field_id); field_indexings_.try_emplace( diff --git a/internal/core/src/segcore/SegmentGrowingImpl.cpp b/internal/core/src/segcore/SegmentGrowingImpl.cpp index d9d6d08f15..2c4c8c887a 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.cpp +++ b/internal/core/src/segcore/SegmentGrowingImpl.cpp @@ -695,12 +695,19 @@ SegmentGrowingImpl::search_batch_pks( void SegmentGrowingImpl::vector_search(SearchInfo& search_info, const 
void* query_data, + const size_t* query_lims, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, SearchResult& output) const { - query::SearchOnGrowing( - *this, search_info, query_data, query_count, timestamp, bitset, output); + query::SearchOnGrowing(*this, + search_info, + query_data, + query_lims, + query_count, + timestamp, + bitset, + output); } std::unique_ptr diff --git a/internal/core/src/segcore/SegmentGrowingImpl.h b/internal/core/src/segcore/SegmentGrowingImpl.h index db7de56c38..07adab0865 100644 --- a/internal/core/src/segcore/SegmentGrowingImpl.h +++ b/internal/core/src/segcore/SegmentGrowingImpl.h @@ -326,6 +326,7 @@ class SegmentGrowingImpl : public SegmentGrowing { void vector_search(SearchInfo& search_info, const void* query_data, + const size_t* query_lims, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, diff --git a/internal/core/src/segcore/SegmentInterface.h b/internal/core/src/segcore/SegmentInterface.h index 0354130cd1..e26e03e0fb 100644 --- a/internal/core/src/segcore/SegmentInterface.h +++ b/internal/core/src/segcore/SegmentInterface.h @@ -394,9 +394,14 @@ class SegmentInternalInterface : public SegmentInterface { const std::string& nested_path) const override; public: + // `query_lims` is non-null only for vector array (embedding list) search, + // where it holds the cumulative vector counts that delimit each embedding list. The length + // of `query_lims` is the number of queries in the search plus one (the first + // element in query_lims is 0 and the last is the total number of vectors).
virtual void vector_search(SearchInfo& search_info, const void* query_data, + const size_t* query_lims, int64_t query_count, Timestamp timestamp, const BitsetView& bitset, diff --git a/internal/core/src/segcore/Types.h b/internal/core/src/segcore/Types.h index 032d5aa20a..0dcb227029 100644 --- a/internal/core/src/segcore/Types.h +++ b/internal/core/src/segcore/Types.h @@ -35,6 +35,8 @@ struct LoadIndexInfo { int64_t segment_id; int64_t field_id; DataType field_type; + // The element type of the field. It's DataType::NONE if field_type is not array/vector_array. + DataType element_type; bool enable_mmap; std::string mmap_dir_path; int64_t index_id; diff --git a/internal/core/src/segcore/Utils.cpp b/internal/core/src/segcore/Utils.cpp index 8f1e44ba90..06dc899c65 100644 --- a/internal/core/src/segcore/Utils.cpp +++ b/internal/core/src/segcore/Utils.cpp @@ -668,7 +668,11 @@ MergeDataArray(std::vector& merge_bases, auto obj = vector_array->mutable_int8_vector(); obj->assign(data, dim * sizeof(int8)); } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { - ThrowInfo(DataTypeInvalid, "VECTOR_ARRAY is not implemented"); + auto data = src_field_data->vectors().vector_array(); + auto obj = vector_array->mutable_vector_array(); + obj->set_element_type( + proto::schema::DataType(field_meta.get_element_type())); + obj->CopyFrom(data); } else { ThrowInfo(DataTypeInvalid, fmt::format("unsupported datatype {}", data_type)); diff --git a/internal/core/src/segcore/check_vec_index_c.cpp b/internal/core/src/segcore/check_vec_index_c.cpp index 5008a348fb..d917f00c73 100644 --- a/internal/core/src/segcore/check_vec_index_c.cpp +++ b/internal/core/src/segcore/check_vec_index_c.cpp @@ -15,7 +15,11 @@ #include "knowhere/comp/knowhere_check.h" bool -CheckVecIndexWithDataType(const char* index_type, enum CDataType data_type) { +CheckVecIndexWithDataType(const char* index_type, + enum CDataType data_type, + bool is_emb_list_data) { return
knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck( - std::string(index_type), knowhere::VecType(data_type)); + std::string(index_type), + knowhere::VecType(data_type), + is_emb_list_data); } diff --git a/internal/core/src/segcore/check_vec_index_c.h b/internal/core/src/segcore/check_vec_index_c.h index 11496b582e..13e9d7f0bf 100644 --- a/internal/core/src/segcore/check_vec_index_c.h +++ b/internal/core/src/segcore/check_vec_index_c.h @@ -16,7 +16,9 @@ extern "C" { #endif #include "common/type_c.h" bool -CheckVecIndexWithDataType(const char* index_type, enum CDataType data_type); +CheckVecIndexWithDataType(const char* index_type, + enum CDataType data_type, + bool is_emb_list_data); #ifdef __cplusplus } diff --git a/internal/core/src/segcore/load_index_c.cpp b/internal/core/src/segcore/load_index_c.cpp index 55ac8891e1..9aa7453354 100644 --- a/internal/core/src/segcore/load_index_c.cpp +++ b/internal/core/src/segcore/load_index_c.cpp @@ -98,40 +98,6 @@ AppendIndexParam(CLoadIndexInfo c_load_index_info, } } -CStatus -AppendFieldInfo(CLoadIndexInfo c_load_index_info, - int64_t collection_id, - int64_t partition_id, - int64_t segment_id, - int64_t field_id, - enum CDataType field_type, - bool enable_mmap, - const char* mmap_dir_path) { - SCOPE_CGO_CALL_METRIC(); - - try { - auto load_index_info = - (milvus::segcore::LoadIndexInfo*)c_load_index_info; - load_index_info->collection_id = collection_id; - load_index_info->partition_id = partition_id; - load_index_info->segment_id = segment_id; - load_index_info->field_id = field_id; - load_index_info->field_type = milvus::DataType(field_type); - load_index_info->enable_mmap = enable_mmap; - load_index_info->mmap_dir_path = std::string(mmap_dir_path); - - auto status = CStatus(); - status.error_code = milvus::Success; - status.error_msg = ""; - return status; - } catch (std::exception& e) { - auto status = CStatus(); - status.error_code = milvus::UnexpectedError; - status.error_msg = strdup(e.what()); - return status; - } 
-} - CStatus appendVecIndex(CLoadIndexInfo c_load_index_info, CBinarySet c_binary_set) { SCOPE_CGO_CALL_METRIC(); @@ -252,6 +218,7 @@ EstimateLoadIndexResource(CLoadIndexInfo c_load_index_info) { auto load_index_info = (milvus::segcore::LoadIndexInfo*)c_load_index_info; auto field_type = load_index_info->field_type; + auto element_type = load_index_info->element_type; auto& index_params = load_index_info->index_params; bool find_index_type = index_params.count("index_type") > 0 ? true : false; @@ -261,6 +228,7 @@ EstimateLoadIndexResource(CLoadIndexInfo c_load_index_info) { LoadResourceRequest request = milvus::index::IndexFactory::GetInstance().IndexLoadResource( field_type, + element_type, load_index_info->index_engine_version, load_index_info->index_size, index_params, @@ -581,6 +549,8 @@ FinishLoadIndexInfo(CLoadIndexInfo c_load_index_info, load_index_info->field_id = info_proto->field().fieldid(); load_index_info->field_type = static_cast(info_proto->field().data_type()); + load_index_info->element_type = static_cast( + info_proto->field().element_type()); load_index_info->enable_mmap = info_proto->enable_mmap(); load_index_info->mmap_dir_path = info_proto->mmap_dir_path(); load_index_info->index_id = info_proto->indexid(); diff --git a/internal/core/src/segcore/load_index_c.h b/internal/core/src/segcore/load_index_c.h index 1b92ea778e..fb659238b5 100644 --- a/internal/core/src/segcore/load_index_c.h +++ b/internal/core/src/segcore/load_index_c.h @@ -38,16 +38,6 @@ AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* index_key, const char* index_value); -CStatus -AppendFieldInfo(CLoadIndexInfo c_load_index_info, - int64_t collection_id, - int64_t partition_id, - int64_t segment_id, - int64_t field_id, - enum CDataType field_type, - bool enable_mmap, - const char* mmap_dir_path); - LoadResourceRequest EstimateLoadIndexResource(CLoadIndexInfo c_load_index_info); diff --git a/internal/core/src/segcore/reduce/Reduce.cpp 
b/internal/core/src/segcore/reduce/Reduce.cpp index 524923ff2b..2c0af354cd 100644 --- a/internal/core/src/segcore/reduce/Reduce.cpp +++ b/internal/core/src/segcore/reduce/Reduce.cpp @@ -440,7 +440,10 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) { ->set_element_type( proto::schema::DataType(field_meta.get_element_type())); } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { - ThrowInfo(NotImplemented, "VECTOR_ARRAY is not implemented"); + field_data->mutable_vectors() + ->mutable_vector_array() + ->set_element_type( + proto::schema::DataType(field_meta.get_element_type())); } search_result_data->mutable_fields_data()->AddAllocated( field_data.release()); diff --git a/internal/core/src/segcore/reduce/StreamReduce.cpp b/internal/core/src/segcore/reduce/StreamReduce.cpp index a0728b2f7b..38b519da1a 100644 --- a/internal/core/src/segcore/reduce/StreamReduce.cpp +++ b/internal/core/src/segcore/reduce/StreamReduce.cpp @@ -155,7 +155,10 @@ StreamReducerHelper::AssembleMergedResult() { ->set_element_type( proto::schema::DataType(field_meta.get_element_type())); } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { - ThrowInfo(NotImplemented, "VECTOR_ARRAY is not implemented"); + field_data->mutable_vectors() + ->mutable_vector_array() + ->set_element_type( + proto::schema::DataType(field_meta.get_element_type())); } new_merged_result->output_fields_data_[field_id] = @@ -674,7 +677,10 @@ StreamReducerHelper::GetSearchResultDataSlice(int slice_index) { ->set_element_type( proto::schema::DataType(field_meta.get_element_type())); } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) { - ThrowInfo(NotImplemented, "VECTOR_ARRAY is not implemented"); + field_data->mutable_vectors() + ->mutable_vector_array() + ->set_element_type( + proto::schema::DataType(field_meta.get_element_type())); } search_result_data->mutable_fields_data()->AddAllocated( field_data.release()); diff --git 
a/internal/core/src/segcore/storagev1translator/InterimSealedIndexTranslator.cpp b/internal/core/src/segcore/storagev1translator/InterimSealedIndexTranslator.cpp index 0fe6d02715..33b8ef3d61 100644 --- a/internal/core/src/segcore/storagev1translator/InterimSealedIndexTranslator.cpp +++ b/internal/core/src/segcore/storagev1translator/InterimSealedIndexTranslator.cpp @@ -81,6 +81,7 @@ InterimSealedIndexTranslator::get_cells( if (vec_data_type_ == DataType::VECTOR_FLOAT) { vec_index = std::make_unique>( + DataType::NONE, index_type_, metric_type_, knowhere::Version::GetCurrentVersion().VersionNumber(), @@ -88,6 +89,7 @@ InterimSealedIndexTranslator::get_cells( false); } else if (vec_data_type_ == DataType::VECTOR_FLOAT16) { vec_index = std::make_unique>( + DataType::NONE, index_type_, metric_type_, knowhere::Version::GetCurrentVersion().VersionNumber(), @@ -95,6 +97,7 @@ InterimSealedIndexTranslator::get_cells( false); } else if (vec_data_type_ == DataType::VECTOR_BFLOAT16) { vec_index = std::make_unique>( + DataType::NONE, index_type_, metric_type_, knowhere::Version::GetCurrentVersion().VersionNumber(), @@ -103,6 +106,7 @@ InterimSealedIndexTranslator::get_cells( } } else { vec_index = std::make_unique>( + DataType::NONE, index_type_, metric_type_, knowhere::Version::GetCurrentVersion().VersionNumber(), diff --git a/internal/core/src/segcore/storagev1translator/SealedIndexTranslator.cpp b/internal/core/src/segcore/storagev1translator/SealedIndexTranslator.cpp index 0ffcbafd3a..a7cff71bed 100644 --- a/internal/core/src/segcore/storagev1translator/SealedIndexTranslator.cpp +++ b/internal/core/src/segcore/storagev1translator/SealedIndexTranslator.cpp @@ -22,6 +22,7 @@ SealedIndexTranslator::SealedIndexTranslator( index_load_info_({load_index_info->enable_mmap, load_index_info->mmap_dir_path, load_index_info->field_type, + load_index_info->element_type, load_index_info->index_params, load_index_info->index_size, load_index_info->index_engine_version, @@ -54,6 +55,7 @@ 
SealedIndexTranslator::estimated_byte_size_of_cell( LoadResourceRequest request = milvus::index::IndexFactory::GetInstance().IndexLoadResource( index_load_info_.field_type, + index_load_info_.element_type, index_load_info_.index_engine_version, index_load_info_.index_size, index_load_info_.index_params, diff --git a/internal/core/src/segcore/storagev1translator/SealedIndexTranslator.h b/internal/core/src/segcore/storagev1translator/SealedIndexTranslator.h index b880970745..375b047746 100644 --- a/internal/core/src/segcore/storagev1translator/SealedIndexTranslator.h +++ b/internal/core/src/segcore/storagev1translator/SealedIndexTranslator.h @@ -45,6 +45,7 @@ class SealedIndexTranslator bool enable_mmap; std::string mmap_dir_path; DataType field_type; + DataType element_type; std::map index_params; int64_t index_size; int64_t index_engine_version; diff --git a/internal/core/src/segcore/storagev1translator/V1SealedIndexTranslator.cpp b/internal/core/src/segcore/storagev1translator/V1SealedIndexTranslator.cpp index 3a1148871c..273348d379 100644 --- a/internal/core/src/segcore/storagev1translator/V1SealedIndexTranslator.cpp +++ b/internal/core/src/segcore/storagev1translator/V1SealedIndexTranslator.cpp @@ -13,6 +13,7 @@ V1SealedIndexTranslator::V1SealedIndexTranslator( load_index_info->enable_mmap, load_index_info->mmap_dir_path, load_index_info->field_type, + load_index_info->element_type, load_index_info->index_params, load_index_info->index_files, load_index_info->index_size, diff --git a/internal/core/src/segcore/storagev1translator/V1SealedIndexTranslator.h b/internal/core/src/segcore/storagev1translator/V1SealedIndexTranslator.h index 7a2cb6e516..971e39082e 100644 --- a/internal/core/src/segcore/storagev1translator/V1SealedIndexTranslator.h +++ b/internal/core/src/segcore/storagev1translator/V1SealedIndexTranslator.h @@ -44,6 +44,7 @@ class V1SealedIndexTranslator : public Translator { bool enable_mmap; std::string mmap_dir_path; DataType field_type; + DataType 
element_type; std::map index_params; std::vector index_files; int64_t index_size; diff --git a/internal/core/src/segcore/vector_index_c.cpp b/internal/core/src/segcore/vector_index_c.cpp index 43a55bc04c..9018ddc700 100644 --- a/internal/core/src/segcore/vector_index_c.cpp +++ b/internal/core/src/segcore/vector_index_c.cpp @@ -24,6 +24,7 @@ CStatus ValidateIndexParams(const char* index_type, enum CDataType data_type, + enum CDataType element_type, const uint8_t* serialized_index_params, const uint64_t length) { try { @@ -44,45 +45,64 @@ ValidateIndexParams(const char* index_type, knowhere::Status status; std::string error_msg; - if (dataType == milvus::DataType::VECTOR_BINARY) { - status = knowhere::IndexStaticFaced::ConfigCheck( - index_type, - knowhere::Version::GetCurrentVersion().VersionNumber(), - json, - error_msg); - } else if (dataType == milvus::DataType::VECTOR_FLOAT) { - status = knowhere::IndexStaticFaced::ConfigCheck( - index_type, - knowhere::Version::GetCurrentVersion().VersionNumber(), - json, - error_msg); - } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) { - status = knowhere::IndexStaticFaced::ConfigCheck( - index_type, - knowhere::Version::GetCurrentVersion().VersionNumber(), - json, - error_msg); - } else if (dataType == milvus::DataType::VECTOR_FLOAT16) { - status = knowhere::IndexStaticFaced::ConfigCheck( - index_type, - knowhere::Version::GetCurrentVersion().VersionNumber(), - json, - error_msg); - } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) { - status = knowhere::IndexStaticFaced::ConfigCheck( - index_type, - knowhere::Version::GetCurrentVersion().VersionNumber(), - json, - error_msg); - } else if (dataType == milvus::DataType::VECTOR_INT8) { - status = knowhere::IndexStaticFaced::ConfigCheck( - index_type, - knowhere::Version::GetCurrentVersion().VersionNumber(), - json, - error_msg); + auto check_leaf_type = [&index_type, &json, &error_msg, &status]( + milvus::DataType dataType) { + if (dataType == 
milvus::DataType::VECTOR_BINARY) { + status = + knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else if (dataType == milvus::DataType::VECTOR_FLOAT) { + status = + knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) { + status = + knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else if (dataType == milvus::DataType::VECTOR_FLOAT16) { + status = + knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) { + status = + knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else if (dataType == milvus::DataType::VECTOR_INT8) { + status = + knowhere::IndexStaticFaced::ConfigCheck( + index_type, + knowhere::Version::GetCurrentVersion().VersionNumber(), + json, + error_msg); + } else { + status = knowhere::Status::invalid_args; + } + }; + + if (dataType == milvus::DataType::VECTOR_ARRAY) { + milvus::DataType elementType( + static_cast(element_type)); + + check_leaf_type(elementType); } else { - status = knowhere::Status::invalid_args; + check_leaf_type(dataType); } + CStatus cStatus; if (status == knowhere::Status::success) { cStatus.error_code = milvus::Success; diff --git a/internal/core/src/segcore/vector_index_c.h b/internal/core/src/segcore/vector_index_c.h index 7e9b8f5239..45ec40080a 100644 --- a/internal/core/src/segcore/vector_index_c.h +++ b/internal/core/src/segcore/vector_index_c.h @@ -20,6 +20,7 @@ extern "C" { CStatus ValidateIndexParams(const char* index_type, enum CDataType data_type, + enum CDataType element_type, 
const uint8_t* index_params, const uint64_t length); diff --git a/internal/core/unittest/test_bf.cpp b/internal/core/unittest/test_bf.cpp index 4428a20eef..6588824c0b 100644 --- a/internal/core/unittest/test_bf.cpp +++ b/internal/core/unittest/test_bf.cpp @@ -141,7 +141,8 @@ class TestFloatSearchBruteForce : public ::testing::Test { search_info, index_info, bitset_view, - DataType::VECTOR_FLOAT); + DataType::VECTOR_FLOAT, + DataType::NONE); for (int i = 0; i < nq; i++) { auto ref = Ref(base.data(), query.data() + i * dim, diff --git a/internal/core/unittest/test_bf_sparse.cpp b/internal/core/unittest/test_bf_sparse.cpp index 4f82e42ecb..a707454840 100644 --- a/internal/core/unittest/test_bf_sparse.cpp +++ b/internal/core/unittest/test_bf_sparse.cpp @@ -113,7 +113,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { search_info, index_info, bitset_view, - DataType::VECTOR_SPARSE_FLOAT)); + DataType::VECTOR_SPARSE_FLOAT, + DataType::NONE)); return; } auto result = BruteForceSearch(query_dataset, @@ -121,7 +122,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { search_info, index_info, bitset_view, - DataType::VECTOR_SPARSE_FLOAT); + DataType::VECTOR_SPARSE_FLOAT, + DataType::NONE); for (int i = 0; i < nq; i++) { auto ref = SearchRef(base.get(), *(query.get() + i), nb, topk); auto ans = result.get_seg_offsets() + i * topk; @@ -135,7 +137,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test { search_info, index_info, bitset_view, - DataType::VECTOR_SPARSE_FLOAT); + DataType::VECTOR_SPARSE_FLOAT, + DataType::NONE); for (int i = 0; i < nq; i++) { auto ref = RangeSearchRef( base.get(), *(query.get() + i), nb, 0.1, 0.5, topk); diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index 3966944ec8..51233b5a4a 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -1936,6 +1936,7 @@ TEST(CApiTest, LoadIndexSearch) { auto& index_params = 
load_index_info.index_params; index_params["index_type"] = knowhere::IndexEnum::INDEX_FAISS_IVFSQ8; auto index = std::make_unique>( + DataType::NONE, index_params["index_type"], knowhere::metric::L2, knowhere::Version::GetCurrentVersion().VersionNumber()); diff --git a/internal/core/unittest/test_chunked_segment.cpp b/internal/core/unittest/test_chunked_segment.cpp index 7d0c4842a6..c8887929e6 100644 --- a/internal/core/unittest/test_chunked_segment.cpp +++ b/internal/core/unittest/test_chunked_segment.cpp @@ -128,6 +128,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) { search_info, index_info, query_data, + nullptr, 1, total_row_count, bv, @@ -153,6 +154,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) { search_info, index_info, query_data, + nullptr, 1, total_row_count, bv, diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index 29efa3850b..c6204c119b 100644 --- a/internal/core/unittest/test_expr.cpp +++ b/internal/core/unittest/test_expr.cpp @@ -16674,6 +16674,7 @@ TEST(JsonIndexTest, TestJsonNotEqualExpr) { file_manager_ctx.fieldDataMeta.field_schema.set_data_type( milvus::proto::schema::JSON); file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get()); + file_manager_ctx.fieldDataMeta.field_id = json_fid.get(); auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex( index::CreateIndexInfo{ @@ -16784,6 +16785,7 @@ TEST_P(JsonIndexExistsTest, TestExistsExpr) { milvus::proto::schema::JSON); file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get()); file_manager_ctx.fieldDataMeta.field_schema.set_nullable(true); + file_manager_ctx.fieldDataMeta.field_id = json_fid.get(); auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex( index::CreateIndexInfo{ .index_type = index::INVERTED_INDEX_TYPE, @@ -16971,6 +16973,7 @@ TEST_P(JsonIndexBinaryExprTest, TestBinaryRangeExpr) { file_manager_ctx.fieldDataMeta.field_schema.set_data_type( milvus::proto::schema::JSON); 
file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get()); + file_manager_ctx.fieldDataMeta.field_id = json_fid.get(); auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex( index::CreateIndexInfo{ diff --git a/internal/core/unittest/test_growing.cpp b/internal/core/unittest/test_growing.cpp index c1016b9cd3..c94156483b 100644 --- a/internal/core/unittest/test_growing.cpp +++ b/internal/core/unittest/test_growing.cpp @@ -540,4 +540,88 @@ TEST(GrowingTest, LoadVectorArrayData) { array_vec_values[ids_ds->GetIds()[i]].float_vector().data(); verify_float_vectors(arrow_array, expected_array); } -} \ No newline at end of file +} + +TEST(GrowingTest, SearchVectorArray) { + using namespace milvus::query; + + auto schema = std::make_shared(); + auto metric_type = knowhere::metric::MAX_SIM; + + // Add fields + auto int64_field = schema->AddDebugField("int64", DataType::INT64); + auto array_vec = schema->AddDebugVectorArrayField( + "array_vec", DataType::VECTOR_FLOAT, 128, metric_type); + schema->set_primary_field_id(int64_field); + + // Configure segment + auto config = SegcoreConfig::default_config(); + config.set_chunk_rows(1024); + config.set_enable_interim_segment_index(true); + + std::map index_params = { + {"index_type", knowhere::IndexEnum::INDEX_EMB_LIST_HNSW}, + {"metric_type", metric_type}, + {"nlist", "128"}}; + std::map type_params = {{"dim", "128"}}; + FieldIndexMeta fieldIndexMeta( + array_vec, std::move(index_params), std::move(type_params)); + std::map fieldMap = {{array_vec, fieldIndexMeta}}; + + IndexMetaPtr metaPtr = + std::make_shared(100000, std::move(fieldMap)); + auto segment = CreateGrowingSegment(schema, metaPtr, 1, config); + auto segmentImplPtr = dynamic_cast(segment.get()); + + // Insert data + int64_t N = 100; + uint64_t seed = 42; + int emb_list_len = 5; // Each row contains 5 vectors + auto dataset = DataGen(schema, N, seed, 0, 1, emb_list_len); + + auto offset = 0; + segment->Insert(offset, + N, + 
dataset.row_ids_.data(), + dataset.timestamps_.data(), + dataset.raw_); + + // Prepare search query + int vec_num = 10; // Total number of query vectors + int dim = 128; + std::vector query_vec = generate_float_vector(vec_num, dim); + + // Create query dataset with lims for VectorArray + std::vector query_vec_lims; + query_vec_lims.push_back(0); // First query has 3 vectors + query_vec_lims.push_back(3); + query_vec_lims.push_back(10); // Second query has 7 vectors + + // Create search plan + const char* raw_plan = R"(vector_anns: < + field_id: 101 + query_info: < + topk: 5 + round_decimal: 3 + metric_type: "MAX_SIM" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + + auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + auto plan = + CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size()); + + // Use CreatePlaceholderGroupFromBlob for VectorArray + auto ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_lims); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + + // Execute search + Timestamp timestamp = 10000000; + auto sr = segment->Search(plan.get(), ph_group.get(), timestamp); + auto sr_parsed = SearchResultToJson(*sr); + std::cout << sr_parsed.dump(1) << std::endl; +} diff --git a/internal/core/unittest/test_growing_index.cpp b/internal/core/unittest/test_growing_index.cpp index be0df50910..38939221bc 100644 --- a/internal/core/unittest/test_growing_index.cpp +++ b/internal/core/unittest/test_growing_index.cpp @@ -360,6 +360,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) { if (data_type == DataType::VECTOR_FLOAT) { auto index = std::make_unique>( + DataType::NONE, index_type, metric_type, knowhere::Version::GetCurrentVersion().VersionNumber(), @@ -375,6 +376,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) { EXPECT_EQ(index->Count(), (add_cont + 1) * N); } else if (data_type == DataType::VECTOR_FLOAT16) { auto index = 
std::make_unique>( + DataType::NONE, index_type, metric_type, knowhere::Version::GetCurrentVersion().VersionNumber(), @@ -391,6 +393,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) { EXPECT_EQ(index->Count(), (add_cont + 1) * N); } else if (data_type == DataType::VECTOR_BFLOAT16) { auto index = std::make_unique>( + DataType::NONE, index_type, metric_type, knowhere::Version::GetCurrentVersion().VersionNumber(), @@ -407,6 +410,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) { EXPECT_EQ(index->Count(), (add_cont + 1) * N); } else if (is_sparse) { auto index = std::make_unique>( + DataType::NONE, index_type, metric_type, knowhere::Version::GetCurrentVersion().VersionNumber(), diff --git a/internal/core/unittest/test_indexing.cpp b/internal/core/unittest/test_indexing.cpp index 3168bd48c7..e07708d209 100644 --- a/internal/core/unittest/test_indexing.cpp +++ b/internal/core/unittest/test_indexing.cpp @@ -184,7 +184,8 @@ TEST(Indexing, BinaryBruteForce) { search_info, index_info, nullptr, - DataType::VECTOR_BINARY); + DataType::VECTOR_BINARY, + DataType::NONE); SearchResult sr; sr.total_nq_ = num_queries; diff --git a/internal/core/unittest/test_json_flat_index.cpp b/internal/core/unittest/test_json_flat_index.cpp index b81a23639d..60346a1922 100644 --- a/internal/core/unittest/test_json_flat_index.cpp +++ b/internal/core/unittest/test_json_flat_index.cpp @@ -564,6 +564,7 @@ class JsonFlatIndexExprTest : public ::testing::Test { file_manager_ctx.fieldDataMeta.field_schema.set_fieldid( json_fid_.get()); file_manager_ctx.fieldDataMeta.field_schema.set_nullable(true); + file_manager_ctx.fieldDataMeta.field_id = json_fid_.get(); auto index = index::IndexFactory::GetInstance().CreateJsonIndex( index::CreateIndexInfo{ .index_type = index::INVERTED_INDEX_TYPE, diff --git a/internal/core/unittest/test_ngram_query.cpp b/internal/core/unittest/test_ngram_query.cpp index a476835278..6d833dcece 100644 --- a/internal/core/unittest/test_ngram_query.cpp +++ 
b/internal/core/unittest/test_ngram_query.cpp @@ -45,13 +45,13 @@ test_ngram_with_data(const boost::container::vector& data, auto schema = std::make_shared(); auto field_id = schema->AddDebugField("ngram", DataType::VARCHAR); - auto field_meta = gen_field_meta(collection_id, - partition_id, - segment_id, - field_id.get(), - DataType::VARCHAR, - DataType::NONE, - false); + auto field_meta = milvus::segcore::gen_field_meta(collection_id, + partition_id, + segment_id, + field_id.get(), + DataType::VARCHAR, + DataType::NONE, + false); auto index_meta = gen_index_meta( segment_id, field_id.get(), index_build_id, index_version); diff --git a/internal/core/unittest/test_sealed.cpp b/internal/core/unittest/test_sealed.cpp index 2e5d337a13..9ff7d5d13d 100644 --- a/internal/core/unittest/test_sealed.cpp +++ b/internal/core/unittest/test_sealed.cpp @@ -19,6 +19,7 @@ #include "knowhere/version.h" #include "storage/RemoteChunkManagerSingleton.h" #include "storage/Util.h" +#include "common/VectorArray.h" #include "test_utils/cachinglayer_test_utils.h" #include "test_utils/DataGen.h" @@ -2333,3 +2334,257 @@ TEST(Sealed, QueryVectorArrayAllFields) { EXPECT_EQ(int64_result->valid_data_size(), 0); EXPECT_EQ(array_float_vector_result->valid_data_size(), 0); } + +TEST(Sealed, SearchVectorArray) { + int64_t collection_id = 1; + int64_t partition_id = 2; + int64_t segment_id = 3; + int64_t index_build_id = 4000; + int64_t index_version = 4000; + int64_t index_id = 5000; + + auto schema = std::make_shared(); + auto metric_type = knowhere::metric::L2; + auto int64_field = schema->AddDebugField("int64", DataType::INT64); + auto array_vec = schema->AddDebugVectorArrayField( + "array_vec", DataType::VECTOR_FLOAT, 128, metric_type); + schema->set_primary_field_id(int64_field); + + auto field_meta = milvus::segcore::gen_field_meta(collection_id, + partition_id, + segment_id, + array_vec.get(), + DataType::VECTOR_ARRAY, + DataType::VECTOR_FLOAT, + false); + auto index_meta = gen_index_meta( + 
segment_id, array_vec.get(), index_build_id, index_version); + + std::map filedMap{}; + IndexMetaPtr metaPtr = + std::make_shared(100000, std::move(filedMap)); + + int64_t dataset_size = 1000; + int64_t dim = 128; + auto emb_list_len = 10; + auto dataset = DataGen(schema, dataset_size, 42, 0, 1, emb_list_len); + + // create field data + std::string root_path = "/tmp/test-vector-array/"; + auto storage_config = gen_local_storage_config(root_path); + auto cm = CreateChunkManager(storage_config); + auto vec_array_col = dataset.get_col(array_vec); + std::vector vector_arrays; + for (auto& v : vec_array_col) { + vector_arrays.push_back(milvus::VectorArray(v)); + } + auto field_data = storage::CreateFieldData(DataType::VECTOR_ARRAY, false); + field_data->FillFieldData(vector_arrays.data(), vector_arrays.size()); + + // create sealed segment + auto segment = CreateSealedSegment(schema); + auto field_data_info = PrepareSingleFieldInsertBinlog(collection_id, + partition_id, + segment_id, + array_vec.get(), + {field_data}, + cm); + segment->LoadFieldData(field_data_info); + + // serialize bin logs + auto payload_reader = + std::make_shared(field_data); + storage::InsertData insert_data(payload_reader); + insert_data.SetFieldDataMeta(field_meta); + insert_data.SetTimestamps(0, 100); + + auto serialized_bytes = insert_data.Serialize(storage::Remote); + + auto get_binlog_path = [=](int64_t log_id) { + return fmt::format("{}/{}/{}/{}/{}", + collection_id, + partition_id, + segment_id, + array_vec.get(), + log_id); + }; + + auto log_path = get_binlog_path(0); + + auto cm_w = ChunkManagerWrapper(cm); + cm_w.Write(log_path, serialized_bytes.data(), serialized_bytes.size()); + + storage::FileManagerContext ctx(field_meta, index_meta, cm); + std::vector index_files; + + // create index + milvus::index::CreateIndexInfo create_index_info; + create_index_info.field_type = DataType::VECTOR_ARRAY; + create_index_info.metric_type = knowhere::metric::MAX_SIM; + create_index_info.index_type 
= knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; + create_index_info.index_engine_version = + knowhere::Version::GetCurrentVersion().VersionNumber(); + + auto emb_list_hnsw_index = + milvus::index::IndexFactory::GetInstance().CreateIndex( + create_index_info, + storage::FileManagerContext(field_meta, index_meta, cm)); + + // build index + Config config; + config[milvus::index::INDEX_TYPE] = + knowhere::IndexEnum::INDEX_EMB_LIST_HNSW; + config[INSERT_FILES_KEY] = std::vector{log_path}; + config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type; + config[knowhere::indexparam::M] = "16"; + config[knowhere::indexparam::EF] = "10"; + config[DIM_KEY] = dim; + emb_list_hnsw_index->Build(config); + + auto vec_index = + dynamic_cast(emb_list_hnsw_index.get()); + EXPECT_EQ(vec_index->Count(), dataset_size * emb_list_len); + EXPECT_EQ(vec_index->GetDim(), dim); + + // search + auto vec_num = 10; + std::vector query_vec = generate_float_vector(vec_num, dim); + auto query_dataset = knowhere::GenDataSet(vec_num, dim, query_vec.data()); + std::vector query_vec_lims; + query_vec_lims.push_back(0); + query_vec_lims.push_back(3); + query_vec_lims.push_back(10); + query_dataset->SetLims(query_vec_lims.data()); + + auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}}; + milvus::SearchInfo searchInfo; + searchInfo.topk_ = 5; + searchInfo.metric_type_ = knowhere::metric::L2; + searchInfo.search_params_ = search_conf; + SearchResult result; + vec_index->Query(query_dataset, searchInfo, nullptr, result); + auto ref_result = SearchResultToJson(result); + std::cout << ref_result.dump(1) << std::endl; + EXPECT_EQ(result.total_nq_, 2); + EXPECT_EQ(result.distances_.size(), 2 * searchInfo.topk_); + + // create sealed segment + auto sealed_segment = CreateSealedWithFieldDataLoaded(schema, dataset); + + // brute force search + { + const char* raw_plan = R"(vector_anns: < + field_id: 101 + query_info: < + topk: 5 + round_decimal: 3 + metric_type: "MAX_SIM" + search_params: 
"{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + auto plan = + CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size()); + auto ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_lims); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + Timestamp timestamp = 1000000; + std::vector ph_group_arr = {ph_group.get()}; + + auto sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp); + auto sr_parsed = SearchResultToJson(*sr); + std::cout << sr_parsed.dump(1) << std::endl; + } + + // // brute force search with iterative filter + // { + // auto [min, max] = + // std::minmax_element(int_values.begin(), int_values.end()); + // auto min_val = *min; + // auto max_val = *max; + + // auto raw_plan = fmt::format(R"(vector_anns: < + // field_id: 101 + // predicates: < + // binary_range_expr: < + // column_info: < + // field_id: 100 + // data_type: Int64 + // > + // lower_inclusive: true + // upper_inclusive: true + // lower_value: < + // int64_val: {} + // > + // upper_value: < + // int64_val: {} + // > + // > + // > + // query_info: < + // topk: 5 + // round_decimal: 3 + // metric_type: "MAX_SIM" + // hints: "iterative_filter" + // search_params: "{{\"nprobe\": 10}}" + // > + // placeholder_tag: "$0" + // >)", + // min_val, + // max_val); + // auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str()); + // auto plan = + // CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size()); + // auto ph_group_raw = CreatePlaceholderGroupFromBlob( + // vec_num, dim, query_vec.data(), query_vec_lims); + // auto ph_group = + // ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + // Timestamp timestamp = 1000000; + // std::vector ph_group_arr = {ph_group.get()}; + + // auto sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp); + // auto sr_parsed = SearchResultToJson(*sr); + 
// std::cout << sr_parsed.dump(1) << std::endl; + // } + + // search with index + { + LoadIndexInfo load_info; + load_info.field_id = array_vec.get(); + load_info.field_type = DataType::VECTOR_ARRAY; + load_info.element_type = DataType::VECTOR_FLOAT; + load_info.index_params = GenIndexParams(emb_list_hnsw_index.get()); + load_info.cache_index = + CreateTestCacheIndex("test", std::move(emb_list_hnsw_index)); + load_info.index_params["metric_type"] = knowhere::metric::MAX_SIM; + + sealed_segment->DropFieldData(array_vec); + sealed_segment->LoadIndex(load_info); + + const char* raw_plan = R"(vector_anns: < + field_id: 101 + query_info: < + topk: 5 + round_decimal: 3 + metric_type: "MAX_SIM" + search_params: "{\"nprobe\": 10}" + > + placeholder_tag: "$0" + >)"; + auto plan_str = translate_text_plan_to_binary_plan(raw_plan); + auto plan = + CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size()); + auto ph_group_raw = CreatePlaceholderGroupFromBlob( + vec_num, dim, query_vec.data(), query_vec_lims); + auto ph_group = + ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); + Timestamp timestamp = 1000000; + std::vector ph_group_arr = {ph_group.get()}; + + auto sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp); + auto sr_parsed = SearchResultToJson(*sr); + std::cout << sr_parsed.dump(1) << std::endl; + } +} diff --git a/internal/core/unittest/test_string_expr.cpp b/internal/core/unittest/test_string_expr.cpp index 86400b83fb..c3cee52b90 100644 --- a/internal/core/unittest/test_string_expr.cpp +++ b/internal/core/unittest/test_string_expr.cpp @@ -1562,7 +1562,8 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) { search_info, index_info, nullptr, - DataType::VECTOR_FLOAT); + DataType::VECTOR_FLOAT, + DataType::NONE); auto sr = segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP); segment->FillPrimaryKeys(plan.get(), *sr); diff --git a/internal/core/unittest/test_utils/DataGen.h b/internal/core/unittest/test_utils/DataGen.h 
index b4e0aaa1b5..372b53d33c 100644 --- a/internal/core/unittest/test_utils/DataGen.h +++ b/internal/core/unittest/test_utils/DataGen.h @@ -1004,7 +1004,10 @@ CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) { template inline auto -CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const void* src) { +CreatePlaceholderGroupFromBlob(int64_t num_queries, + int dim, + const void* src, + std::vector offsets = {}) { if (std::is_same_v) { assert(dim % 8 == 0); } @@ -1017,12 +1020,27 @@ CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const void* src) { value->set_type(TraitType::placeholder_type); int64_t src_index = 0; - for (int i = 0; i < num_queries; ++i) { - std::vector vec; - for (int d = 0; d < dim / TraitType::dim_factor; ++d) { - vec.push_back(((elem_type*)src)[src_index++]); + if (offsets.empty()) { + for (int i = 0; i < num_queries; ++i) { + std::vector vec; + for (int d = 0; d < dim / TraitType::dim_factor; ++d) { + vec.push_back(((elem_type*)src)[src_index++]); + } + value->add_values(vec.data(), vec.size() * sizeof(elem_type)); + } + } else { + assert(offsets.back() == num_queries); + for (int i = 0; i < offsets.size() - 1; i++) { + auto start = offsets[i]; + auto end = offsets[i + 1]; + std::vector vec; + for (int j = start; j < end; j++) { + for (int d = 0; d < dim / TraitType::dim_factor; ++d) { + vec.push_back(((elem_type*)src)[src_index++]); + } + } + value->add_values(vec.data(), vec.size() * sizeof(elem_type)); } - value->add_values(vec.data(), vec.size() * sizeof(elem_type)); } return raw_group; } @@ -1362,6 +1380,7 @@ GenVecIndexing(int64_t N, milvus::storage::FileManagerContext file_manager_context( field_data_meta, index_meta, chunk_manager); auto indexing = std::make_unique>( + DataType::NONE, index_type, knowhere::metric::L2, knowhere::Version::GetCurrentVersion().VersionNumber(), @@ -1631,4 +1650,29 @@ GenChunkedSegmentTestSchema(bool pk_is_string) { return schema; } +inline std::vector 
+generate_float_vector(int64_t N, int64_t dim) { + auto seed = 42; + auto offset = 0; + std::vector final(dim * N); + for (int n = 0; n < N; ++n) { + std::vector data(dim); + float sum = 0; + + std::default_random_engine er2(seed + n); + std::normal_distribution<> distr2(0, 1); + for (auto& x : data) { + x = distr2(er2) + offset++; + sum += x * x; + } + sum = sqrt(sum); + for (auto& x : data) { + x /= sum; + } + + std::copy(data.begin(), data.end(), final.begin() + dim * n); + } + return final; +}; + } // namespace milvus::segcore diff --git a/internal/core/unittest/test_utils/c_api_test_utils.h b/internal/core/unittest/test_utils/c_api_test_utils.h index 979178c835..3ca4c9d7ce 100644 --- a/internal/core/unittest/test_utils/c_api_test_utils.h +++ b/internal/core/unittest/test_utils/c_api_test_utils.h @@ -31,6 +31,7 @@ #include "segcore/segment_c.h" #include "futures/Future.h" #include "futures/future_c.h" +#include "segcore/load_index_c.h" #include "DataGen.h" #include "PbHelper.h" #include "indexbuilder_test_utils.h" @@ -38,6 +39,39 @@ using namespace milvus; using namespace milvus::segcore; +// Test utility function for AppendFieldInfo +inline CStatus +AppendFieldInfo(CLoadIndexInfo c_load_index_info, + int64_t collection_id, + int64_t partition_id, + int64_t segment_id, + int64_t field_id, + enum CDataType field_type, + bool enable_mmap, + const char* mmap_dir_path) { + try { + auto load_index_info = + (milvus::segcore::LoadIndexInfo*)c_load_index_info; + load_index_info->collection_id = collection_id; + load_index_info->partition_id = partition_id; + load_index_info->segment_id = segment_id; + load_index_info->field_id = field_id; + load_index_info->field_type = milvus::DataType(field_type); + load_index_info->enable_mmap = enable_mmap; + load_index_info->mmap_dir_path = std::string(mmap_dir_path); + + auto status = CStatus(); + status.error_code = milvus::Success; + status.error_msg = ""; + return status; + } catch (std::exception& e) { + auto status = 
CStatus(); + status.error_code = milvus::UnexpectedError; + status.error_msg = strdup(e.what()); + return status; + } +} + namespace { std::string diff --git a/internal/core/unittest/test_vector_array.cpp b/internal/core/unittest/test_vector_array.cpp index 3f5d568d34..df5f48bdbe 100644 --- a/internal/core/unittest/test_vector_array.cpp +++ b/internal/core/unittest/test_vector_array.cpp @@ -109,7 +109,7 @@ TEST(VectorArray, TestConstructVectorArray) { field_float_vector_array.mutable_float_vector()->mutable_data()->Add( data.begin(), data.end()); - auto float_vector_array = VectorArray(field_float_vector_array); + auto float_vector_array = milvus::VectorArray(field_float_vector_array); ASSERT_EQ(float_vector_array.length(), N); ASSERT_EQ(float_vector_array.dim(), dim); ASSERT_EQ(float_vector_array.get_element_type(), DataType::VECTOR_FLOAT); @@ -117,16 +117,16 @@ TEST(VectorArray, TestConstructVectorArray) { ASSERT_TRUE(float_vector_array.is_same_array(field_float_vector_array)); - auto float_vector_array_tmp = VectorArray(float_vector_array); + auto float_vector_array_tmp = milvus::VectorArray(float_vector_array); ASSERT_TRUE(float_vector_array_tmp.is_same_array(field_float_vector_array)); auto float_vector_array_view = - VectorArrayView(const_cast(float_vector_array.data()), - float_vector_array.length(), - float_vector_array.dim(), - float_vector_array.byte_size(), - float_vector_array.get_element_type()); + milvus::VectorArrayView(const_cast(float_vector_array.data()), + float_vector_array.length(), + float_vector_array.dim(), + float_vector_array.byte_size(), + float_vector_array.get_element_type()); ASSERT_TRUE( float_vector_array_view.is_same_array(field_float_vector_array)); diff --git a/internal/datacoord/index_service.go b/internal/datacoord/index_service.go index 651ac477ef..40af386962 100644 --- a/internal/datacoord/index_service.go +++ b/internal/datacoord/index_service.go @@ -70,13 +70,29 @@ func (s *Server) getSchema(ctx context.Context, collID 
int64) (*schemapb.Collect return resp.GetSchema(), nil } -func isJsonField(schema *schemapb.CollectionSchema, fieldID int64) (bool, error) { +func FieldExists(schema *schemapb.CollectionSchema, fieldID int64) bool { for _, f := range schema.Fields { if f.FieldID == fieldID { - return typeutil.IsJSONType(f.DataType), nil + return true } } - return false, merr.WrapErrFieldNotFound(fieldID) + for _, structField := range schema.StructArrayFields { + for _, f := range structField.Fields { + if f.FieldID == fieldID { + return true + } + } + } + return false +} + +func isJsonField(schema *schemapb.CollectionSchema, fieldID int64) bool { + for _, f := range schema.Fields { + if f.FieldID == fieldID { + return typeutil.IsJSONType(f.DataType) + } + } + return false } func getIndexParam(indexParams []*commonpb.KeyValuePair, key string) (string, error) { @@ -154,11 +170,12 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques if err != nil { return merr.Status(err), nil } - isJson, err := isJsonField(schema, req.GetFieldID()) - if err != nil { - return merr.Status(err), nil + + if !FieldExists(schema, req.GetFieldID()) { + return merr.Status(merr.WrapErrFieldNotFound(req.GetFieldID())), nil } + isJson := isJsonField(schema, req.GetFieldID()) if isJson { // check json_path and json_cast_type exist jsonPath, err := getIndexParam(req.GetIndexParams(), common.JSONPathKey) diff --git a/internal/datacoord/task_index.go b/internal/datacoord/task_index.go index 1e58145344..bc949f8ea3 100644 --- a/internal/datacoord/task_index.go +++ b/internal/datacoord/task_index.go @@ -250,7 +250,8 @@ func (it *indexBuildTask) prepareJobRequest(ctx context.Context, segment *Segmen schema := collectionInfo.Schema var field *schemapb.FieldSchema - for _, f := range schema.Fields { + allFields := typeutil.GetAllFieldSchemas(schema) + for _, f := range allFields { if f.FieldID == fieldID { field = f break @@ -263,7 +264,11 @@ func (it *indexBuildTask) prepareJobRequest(ctx 
context.Context, segment *Segmen // Extract dim only for vector types to avoid unnecessary warnings dim := -1 - if typeutil.IsFixDimVectorType(field.GetDataType()) { + dataType := field.GetDataType() + if typeutil.IsVectorArrayType(dataType) { + dataType = field.GetElementType() + } + if typeutil.IsFixDimVectorType(dataType) { if dimVal, err := storage.GetDimFromParams(field.GetTypeParams()); err != nil { log.Warn("failed to get dim from field type params", zap.String("field type", field.GetDataType().String()), zap.Error(err)) diff --git a/internal/parser/planparserv2/plan_parser_v2.go b/internal/parser/planparserv2/plan_parser_v2.go index a04a1fee88..b6ae5d45af 100644 --- a/internal/parser/planparserv2/plan_parser_v2.go +++ b/internal/parser/planparserv2/plan_parser_v2.go @@ -180,6 +180,7 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField // plan ok with schema, check ann field fieldID := vectorField.FieldID dataType := vectorField.DataType + elementType := vectorField.ElementType var vectorType planpb.VectorType if !typeutil.IsVectorType(dataType) { @@ -198,6 +199,15 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField vectorType = planpb.VectorType_SparseFloatVector case schemapb.DataType_Int8Vector: vectorType = planpb.VectorType_Int8Vector + case schemapb.DataType_ArrayOfVector: + switch elementType { + case schemapb.DataType_FloatVector: + vectorType = planpb.VectorType_EmbListFloatVector + default: + log.Error("Invalid elementType", zap.Any("elementType", elementType)) + return nil, err + } + default: log.Error("Invalid dataType", zap.Any("dataType", dataType)) return nil, err diff --git a/internal/proxy/cgo_util.go b/internal/proxy/cgo_util.go index 6ae963d89f..910044dd61 100644 --- a/internal/proxy/cgo_util.go +++ b/internal/proxy/cgo_util.go @@ -62,13 +62,18 @@ func GetDynamicPool() *conc.Pool[any] { return dp.Load() } -func CheckVecIndexWithDataTypeExist(name string, dType 
schemapb.DataType) bool { +func CheckVecIndexWithDataTypeExist(name string, dataType schemapb.DataType, elementType schemapb.DataType) bool { + isEmbeddingList := dataType == schemapb.DataType_ArrayOfVector + if isEmbeddingList { + dataType = elementType + } + var result bool GetDynamicPool().Submit(func() (any, error) { cIndexName := C.CString(name) - cType := uint32(dType) + cType := uint32(dataType) defer C.free(unsafe.Pointer(cIndexName)) - result = bool(C.CheckVecIndexWithDataType(cIndexName, cType)) + result = bool(C.CheckVecIndexWithDataType(cIndexName, cType, C.bool(isEmbeddingList))) return nil, nil }).Await() diff --git a/internal/proxy/cgo_util_test.go b/internal/proxy/cgo_util_test.go index c3277bdce9..ae68f76430 100644 --- a/internal/proxy/cgo_util_test.go +++ b/internal/proxy/cgo_util_test.go @@ -50,7 +50,7 @@ func Test_CheckVecIndexWithDataTypeExist(t *testing.T) { } for _, test := range cases { - if got := CheckVecIndexWithDataTypeExist(test.indexType, test.dataType); got != test.want { + if got := CheckVecIndexWithDataTypeExist(test.indexType, test.dataType, schemapb.DataType_None); got != test.want { t.Errorf("CheckVecIndexWithDataTypeExist(%v, %v) = %v", test.indexType, test.dataType, test.want) } } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 9fd3beac0c..8c99015087 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -77,10 +77,41 @@ import ( ) const ( + // Test framework constants attempts = 1000000 sleepDuration = time.Millisecond * 200 ) +const ( + // Collection and partition naming + prefix = "test_proxy_" + partitionPrefix = "test_proxy_partition_" + + // Collection configuration + shardsNum = common.DefaultShardsNum + dim = 128 + rowNum = 100 + nlist = 10 + nq = 10 +) + +const ( + // Field names + int64Field = "int64" + floatVecField = "fVec" + binaryVecField = "bVec" + structField = "structField" + subFieldI32 = "structI32" + subFieldFVec = "structFVec" +) + +const ( + // Index 
names + testFloatIndexName = "float_index" + testBinaryIndexName = "binary_index" + testStructFVecIndexName = "structFVecIndex" +) + var Registry *prometheus.Registry func init() { @@ -294,6 +325,621 @@ func (s *proxyTestServer) gracefulStop() { } } +func checkFlushState(ctx context.Context, proxy *Proxy, segmentIDs []int64) bool { + resp, err := proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ + SegmentIDs: segmentIDs, + }) + if err != nil { + return false + } + return resp.GetFlushed() +} + +func checkCollectionLoaded(ctx context.Context, proxy *Proxy, dbName, collectionName string) bool { + resp, err := proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{ + Base: nil, + DbName: dbName, + TimeStamp: 0, + Type: milvuspb.ShowType_InMemory, + CollectionNames: []string{collectionName}, + }) + if err != nil { + return false + } + if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return false + } + + for idx, name := range resp.CollectionNames { + if name == collectionName && resp.InMemoryPercentages[idx] == 100 { + return true + } + } + + return false +} + +func constructTestCollectionSchema(collectionName, int64Field, floatVecField, binaryVecField, structField string, dim int) *schemapb.CollectionSchema { + pk := &schemapb.FieldSchema{ + FieldID: 100, + Name: int64Field, + IsPrimaryKey: true, + Description: "", + DataType: schemapb.DataType_Int64, + TypeParams: nil, + IndexParams: nil, + AutoID: true, + } + fVec := &schemapb.FieldSchema{ + FieldID: 101, + Name: floatVecField, + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + }, + IndexParams: nil, + AutoID: false, + } + bVec := &schemapb.FieldSchema{ + FieldID: 102, + Name: binaryVecField, + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: 
strconv.Itoa(dim), + }, + }, + IndexParams: nil, + AutoID: false, + } + // struct schema fields + sId := &schemapb.FieldSchema{ + FieldID: 104, + Name: subFieldI32, + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + IndexParams: nil, + AutoID: false, + } + sFVec := &schemapb.FieldSchema{ + FieldID: 105, + Name: subFieldFVec, + IsPrimaryKey: false, + Description: "", + DataType: schemapb.DataType_ArrayOfVector, + ElementType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + { + Key: common.MaxCapacityKey, + Value: "100", + }, + }, + IndexParams: nil, + AutoID: false, + } + structF := &schemapb.StructArrayFieldSchema{ + FieldID: 103, + Name: structField, + Fields: []*schemapb.FieldSchema{sId, sFVec}, + } + return &schemapb.CollectionSchema{ + Name: collectionName, + Description: "", + AutoID: false, + Fields: []*schemapb.FieldSchema{ + pk, + fVec, + bVec, + }, + StructArrayFields: []*schemapb.StructArrayFieldSchema{structF}, + } +} + +func constructTestCreateCollectionRequest(dbName, collectionName string, schema *schemapb.CollectionSchema, shardsNum int32) *milvuspb.CreateCollectionRequest { + bs, err := proto.Marshal(schema) + if err != nil { + panic(err) + } + return &milvuspb.CreateCollectionRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + Schema: bs, + ShardsNum: shardsNum, + } +} + +func constructTestCollectionInsertRequest(dbName, collectionName, floatVecField, binaryVecField, structField string, schema *schemapb.CollectionSchema, rowNum, dim int) *milvuspb.InsertRequest { + fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) + structColumn := newStructArrayFieldData(schema.StructArrayFields[0], 
structField, rowNum, dim) + hashKeys := testutils.GenerateHashKeys(rowNum) + return &milvuspb.InsertRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionName: "", + FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn, structColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + } +} + +func constructTestCreateIndexRequest(dbName, collectionName string, dataType schemapb.DataType, fieldName string, dim, nlist int) *milvuspb.CreateIndexRequest { + req := &milvuspb.CreateIndexRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + } + switch dataType { + case schemapb.DataType_FloatVector: + { + req.FieldName = fieldName + req.IndexName = testFloatIndexName + req.ExtraParams = []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + { + Key: common.MetricTypeKey, + Value: metric.L2, + }, + { + Key: common.IndexTypeKey, + Value: "IVF_FLAT", + }, + { + Key: "nlist", + Value: strconv.Itoa(nlist), + }, + } + } + case schemapb.DataType_BinaryVector: + { + req.FieldName = fieldName + req.IndexName = testBinaryIndexName + req.ExtraParams = []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + { + Key: common.MetricTypeKey, + Value: metric.JACCARD, + }, + { + Key: common.IndexTypeKey, + Value: "BIN_IVF_FLAT", + }, + { + Key: "nlist", + Value: strconv.Itoa(nlist), + }, + } + } + case schemapb.DataType_ArrayOfVector: + { + req.FieldName = fieldName + req.IndexName = testStructFVecIndexName + req.ExtraParams = []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + { + Key: common.MetricTypeKey, + Value: metric.MaxSim, + }, + { + Key: common.IndexTypeKey, + Value: "EMB_LIST_HNSW", + }, + { + Key: "nlist", + Value: strconv.Itoa(nlist), + }, + } + } + } + + return req +} + +func constructTestVectorsPlaceholderGroup(nq int, dim int, isEmbedingList bool) *commonpb.PlaceholderGroup { + values := make([][]byte, 0, nq) + for i := 0; 
i < nq; i++ { + bs := make([]byte, 0, dim*4) + count := dim + if isEmbedingList { + count = (rand.Intn(5) + 2) * dim + } + + for j := 0; j < count; j++ { + var buffer bytes.Buffer + f := rand.Float32() + err := binary.Write(&buffer, common.Endian, f) + if err != nil { + panic(err) + } + bs = append(bs, buffer.Bytes()...) + } + values = append(values, bs) + } + + vectorType := commonpb.PlaceholderType_FloatVector + if isEmbedingList { + vectorType = commonpb.PlaceholderType_EmbListFloatVector + } + + return &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + { + Tag: "$0", + Type: vectorType, + Values: values, + }, + }, + } +} + +func constructTestSearchRequest(dbName, collectionName, floatVecField, expr string, nq, nprobe, topk, roundDecimal, dim int) *milvuspb.SearchRequest { + plg := constructTestVectorsPlaceholderGroup(nq, dim, false) + plgBs, err := proto.Marshal(plg) + if err != nil { + panic(err) + } + + params := make(map[string]string) + params["nprobe"] = strconv.Itoa(nprobe) + b, err := json.Marshal(params) + if err != nil { + panic(err) + } + searchParams := []*commonpb.KeyValuePair{ + {Key: MetricTypeKey, Value: metric.L2}, + {Key: ParamsKey, Value: string(b)}, + {Key: AnnsFieldKey, Value: floatVecField}, + {Key: TopKKey, Value: strconv.Itoa(topk)}, + {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + } + + return &milvuspb.SearchRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionNames: nil, + Dsl: expr, + PlaceholderGroup: plgBs, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: nil, + SearchParams: searchParams, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + SearchByPrimaryKeys: false, + } +} + +func constructTestSubSearchRequest(floatVecField, expr string, nq, nprobe, topk, roundDecimal, dim int) *milvuspb.SubSearchRequest { + plg := constructTestVectorsPlaceholderGroup(nq, dim, false) + plgBs, err := proto.Marshal(plg) + if err != nil { + panic(err) + } + + params := 
make(map[string]string) + params["nprobe"] = strconv.Itoa(nprobe) + b, err := json.Marshal(params) + if err != nil { + panic(err) + } + searchParams := []*commonpb.KeyValuePair{ + {Key: MetricTypeKey, Value: metric.L2}, + {Key: ParamsKey, Value: string(b)}, + {Key: AnnsFieldKey, Value: floatVecField}, + {Key: TopKKey, Value: strconv.Itoa(topk)}, + {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + } + + return &milvuspb.SubSearchRequest{ + Dsl: expr, + PlaceholderGroup: plgBs, + DslType: commonpb.DslType_BoolExprV1, + SearchParams: searchParams, + } +} + +func constructTestAdvancedSearchRequest(dbName, collectionName, floatVecField, expr string, nq, nprobe, topk, roundDecimal, dim int) *milvuspb.SearchRequest { + params := make(map[string]float64) + params[RRFParamsKey] = 60 + b, err := json.Marshal(params) + if err != nil { + panic(err) + } + rankParams := []*commonpb.KeyValuePair{ + {Key: RankTypeKey, Value: "rrf"}, + {Key: ParamsKey, Value: string(b)}, + {Key: LimitKey, Value: strconv.Itoa(topk)}, + {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + } + + req1 := constructTestSubSearchRequest(floatVecField, expr, nq, nprobe, topk, roundDecimal, dim) + req2 := constructTestSubSearchRequest(floatVecField, expr, nq, nprobe, topk, roundDecimal, dim) + ret := &milvuspb.SearchRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionNames: nil, + OutputFields: nil, + SearchParams: rankParams, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + } + ret.SubReqs = append(ret.SubReqs, req1) + ret.SubReqs = append(ret.SubReqs, req2) + return ret +} + +func constructTestEmbeddingListSearchRequest(dbName, collectionName, structFVec, expr string, nq, nprobe, topk, roundDecimal, dim int) *milvuspb.SearchRequest { + plg := constructTestVectorsPlaceholderGroup(nq, dim, true) + plgBs, err := proto.Marshal(plg) + if err != nil { + panic(err) + } + params := make(map[string]string) + params["nprobe"] = strconv.Itoa(nprobe) + b, err := 
json.Marshal(params) + if err != nil { + panic(err) + } + searchParams := []*commonpb.KeyValuePair{ + {Key: MetricTypeKey, Value: metric.MaxSim}, + {Key: ParamsKey, Value: string(b)}, + {Key: AnnsFieldKey, Value: structFVec}, + {Key: TopKKey, Value: strconv.Itoa(topk)}, + {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + } + + return &milvuspb.SearchRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionNames: nil, + Dsl: expr, + PlaceholderGroup: plgBs, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: nil, + SearchParams: searchParams, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + SearchByPrimaryKeys: false, + } +} + +// Helper functions for TestProxy +func constructPrimaryKeysPlaceholderGroup(int64Field string, insertedIDs []int64) *commonpb.PlaceholderGroup { + expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIDs[0]) + exprBytes := []byte(expr) + + return &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + { + Tag: "$0", + Type: commonpb.PlaceholderType_None, + Values: [][]byte{exprBytes}, + }, + }, + } +} + +func constructSearchByPksRequest(t *testing.T, dbName, collectionName, floatVecField, int64Field string, insertedIDs []int64, nprobe, topk, roundDecimal int) *milvuspb.SearchRequest { + plg := constructPrimaryKeysPlaceholderGroup(int64Field, insertedIDs) + plgBs, err := proto.Marshal(plg) + assert.NoError(t, err) + + params := make(map[string]string) + params["nprobe"] = strconv.Itoa(nprobe) + b, err := json.Marshal(params) + assert.NoError(t, err) + searchParams := []*commonpb.KeyValuePair{ + {Key: MetricTypeKey, Value: metric.L2}, + {Key: ParamsKey, Value: string(b)}, + {Key: AnnsFieldKey, Value: floatVecField}, + {Key: TopKKey, Value: strconv.Itoa(topk)}, + {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + } + + return &milvuspb.SearchRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionNames: nil, + Dsl: "", + PlaceholderGroup: 
plgBs, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: nil, + SearchParams: searchParams, + TravelTimestamp: 0, + GuaranteeTimestamp: 0, + SearchByPrimaryKeys: true, + } +} + +func constructPartitionInsertRequest(dbName, collectionName, partitionName, floatVecField, binaryVecField, structField string, schema *schemapb.CollectionSchema, rowNum, dim int) *milvuspb.InsertRequest { + fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) + structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) + hashKeys := testutils.GenerateHashKeys(rowNum) + return &milvuspb.InsertRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionName: partitionName, + FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn, structColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + } +} + +func constructCollectionUpsertRequestNoPK(dbName, collectionName, floatVecField, binaryVecField, structField string, schema *schemapb.CollectionSchema, rowNum, dim int) *milvuspb.UpsertRequest { + fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) + structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) + hashKeys := testutils.GenerateHashKeys(rowNum) + return &milvuspb.UpsertRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn, structColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + } +} + +func constructCollectionUpsertRequestWithPK(dbName, collectionName, floatVecField, binaryVecField, structField string, schema *schemapb.CollectionSchema, rowNum, dim int) *milvuspb.UpsertRequest { + pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) + fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + 
bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) + structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) + hashKeys := testutils.GenerateHashKeys(rowNum) + return &milvuspb.UpsertRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn, structColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + } +} + +func constructCreateCredentialRequest(username, password string) *milvuspb.CreateCredentialRequest { + return &milvuspb.CreateCredentialRequest{ + Base: nil, + Username: username, + Password: password, + } +} + +func constructUpdateCredentialRequest(username, oldPassword, newPassword string) *milvuspb.UpdateCredentialRequest { + return &milvuspb.UpdateCredentialRequest{ + Base: nil, + Username: username, + OldPassword: oldPassword, + NewPassword: newPassword, + } +} + +func constructGetCredentialRequest(username string) *rootcoordpb.GetCredentialRequest { + return &rootcoordpb.GetCredentialRequest{ + Base: nil, + Username: username, + } +} + +func constructListCredUsersRequest() *milvuspb.ListCredUsersRequest { + return &milvuspb.ListCredUsersRequest{ + Base: nil, + } +} + +func constructDelCredRequest(username string) *milvuspb.DeleteCredentialRequest { + return &milvuspb.DeleteCredentialRequest{ + Base: nil, + Username: username, + } +} + +func constructPartitionReqUpsertRequestValid(dbName, collectionName, partitionName, floatVecField, binaryVecField, structField string, schema *schemapb.CollectionSchema, rowNum, dim int) *milvuspb.UpsertRequest { + pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) + fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) + structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) + hashKeys := testutils.GenerateHashKeys(rowNum) + return 
&milvuspb.UpsertRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionName: partitionName, + FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn, structColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + } +} + +func constructPartitionReqUpsertRequestInvalid(dbName, collectionName, floatVecField, binaryVecField, structField string, schema *schemapb.CollectionSchema, rowNum, dim int) *milvuspb.UpsertRequest { + pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) + fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) + structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) + hashKeys := testutils.GenerateHashKeys(rowNum) + return &milvuspb.UpsertRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + PartitionName: "%$@", + FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn, structColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + } +} + +func constructCollectionUpsertRequestValid(dbName, collectionName, floatVecField, binaryVecField, structField string, schema *schemapb.CollectionSchema, rowNum, dim int) *milvuspb.UpsertRequest { + pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) + fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) + bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) + structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) + hashKeys := testutils.GenerateHashKeys(rowNum) + return &milvuspb.UpsertRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn, structColumn}, + HashKeys: hashKeys, + NumRows: uint32(rowNum), + } +} + +func checkPartitionInMemory(t *testing.T, ctx context.Context, proxy *Proxy, dbName, collectionName, 
partitionName string, collectionID int64) bool { + resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + CollectionID: collectionID, + PartitionNames: []string{partitionName}, + Type: milvuspb.ShowType_InMemory, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + + for idx, name := range resp.PartitionNames { + if name == partitionName && resp.InMemoryPercentages[idx] == 100 { + return true + } + } + + return false +} + func TestProxy(t *testing.T) { var err error var wg sync.WaitGroup @@ -304,6 +950,11 @@ func TestProxy(t *testing.T) { streamingutil.SetStreamingServiceEnabled() defer streamingutil.UnsetStreamingServiceEnabled() + // params.Save(params.EtcdCfg.RequestTimeout.Key, "300000") + // params.Save(params.CommonCfg.SessionTTL.Key, "300") + // params.Save(params.CommonCfg.SessionRetryTimes.Key, "500") + // params.Save(params.CommonCfg.GracefulStopTimeout.Key, "3600") + params.RootCoordGrpcServerCfg.IP = "localhost" params.QueryCoordGrpcServerCfg.IP = "localhost" params.DataCoordGrpcServerCfg.IP = "localhost" @@ -434,261 +1085,17 @@ func TestProxy(t *testing.T) { assert.Equal(t, "", resp.Value) }) - prefix := "test_proxy_" - partitionPrefix := "test_proxy_partition_" dbName := GetCurDBNameFromContextOrDefault(ctx) collectionName := prefix + funcutil.GenRandomStr() otherCollectionName := collectionName + "_other_" + funcutil.GenRandomStr() partitionName := partitionPrefix + funcutil.GenRandomStr() otherPartitionName := partitionPrefix + "_other_" + funcutil.GenRandomStr() - shardsNum := common.DefaultShardsNum - int64Field := "int64" - floatVecField := "fVec" - binaryVecField := "bVec" - dim := 128 - rowNum := 500 - floatIndexName := "float_index" - binaryIndexName := "binary_index" - structId := "structI32" - structFVec := "structFVec" - structField := "structField" - nlist := 10 - nq := 10 + var segmentIDs []int64 // 
an int64 field (pk) & a float vector field - constructCollectionSchema := func() *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ - FieldID: 100, - Name: int64Field, - IsPrimaryKey: true, - Description: "", - DataType: schemapb.DataType_Int64, - TypeParams: nil, - IndexParams: nil, - AutoID: true, - } - fVec := &schemapb.FieldSchema{ - FieldID: 101, - Name: floatVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - bVec := &schemapb.FieldSchema{ - FieldID: 102, - Name: binaryVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_BinaryVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - // struct schema fields - sId := &schemapb.FieldSchema{ - FieldID: 104, - Name: structId, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_Int32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.MaxCapacityKey, - Value: "100", - }, - }, - IndexParams: nil, - AutoID: false, - } - sFVec := &schemapb.FieldSchema{ - FieldID: 105, - Name: structFVec, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_ArrayOfVector, - ElementType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - { - Key: common.MaxCapacityKey, - Value: "100", - }, - }, - IndexParams: nil, - AutoID: false, - } - structF := &schemapb.StructArrayFieldSchema{ - FieldID: 103, - Name: structField, - Fields: []*schemapb.FieldSchema{sId, sFVec}, - } - return &schemapb.CollectionSchema{ - Name: collectionName, - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - pk, - fVec, - bVec, - }, - StructArrayFields: 
[]*schemapb.StructArrayFieldSchema{structF}, - } - } - schema := constructCollectionSchema() - - constructCreateCollectionRequest := func() *milvuspb.CreateCollectionRequest { - bs, err := proto.Marshal(schema) - assert.NoError(t, err) - return &milvuspb.CreateCollectionRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - Schema: bs, - ShardsNum: shardsNum, - } - } - createCollectionReq := constructCreateCollectionRequest() - - constructCollectionInsertRequest := func() *milvuspb.InsertRequest { - fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) - bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) - structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) - hashKeys := testutils.GenerateHashKeys(rowNum) - return &milvuspb.InsertRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionName: "", - FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn, structColumn}, - HashKeys: hashKeys, - NumRows: uint32(rowNum), - } - } - - constructPartitionInsertRequest := func() *milvuspb.InsertRequest { - fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) - bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) - structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) - hashKeys := testutils.GenerateHashKeys(rowNum) - return &milvuspb.InsertRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionName: partitionName, - FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn, structColumn}, - HashKeys: hashKeys, - NumRows: uint32(rowNum), - } - } - - constructCollectionUpsertRequestNoPK := func() *milvuspb.UpsertRequest { - fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) - bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) - structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) - 
hashKeys := testutils.GenerateHashKeys(rowNum) - return &milvuspb.UpsertRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionName: partitionName, - FieldsData: []*schemapb.FieldData{fVecColumn, bVecColumn, structColumn}, - HashKeys: hashKeys, - NumRows: uint32(rowNum), - } - } - - constructCollectionUpsertRequestWithPK := func() *milvuspb.UpsertRequest { - pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) - fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) - bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) - structColumn := newStructArrayFieldData(schema.StructArrayFields[0], structField, rowNum, dim) - hashKeys := testutils.GenerateHashKeys(rowNum) - return &milvuspb.UpsertRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionName: partitionName, - FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn, structColumn}, - HashKeys: hashKeys, - NumRows: uint32(rowNum), - } - } - - constructCreateIndexRequest := func(dataType schemapb.DataType, fieldName string) *milvuspb.CreateIndexRequest { - req := &milvuspb.CreateIndexRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - } - switch dataType { - case schemapb.DataType_FloatVector: - { - req.FieldName = fieldName - req.IndexName = floatIndexName - req.ExtraParams = []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - { - Key: common.MetricTypeKey, - Value: metric.L2, - }, - { - Key: common.IndexTypeKey, - Value: "IVF_FLAT", - }, - { - Key: "nlist", - Value: strconv.Itoa(nlist), - }, - } - } - case schemapb.DataType_BinaryVector: - { - req.FieldName = fieldName - req.IndexName = binaryIndexName - req.ExtraParams = []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - { - Key: common.MetricTypeKey, - Value: metric.JACCARD, - }, - { - Key: common.IndexTypeKey, - Value: "BIN_IVF_FLAT", - }, - { - 
Key: "nlist", - Value: strconv.Itoa(nlist), - }, - } - } - } - - return req - } + schema := constructTestCollectionSchema(collectionName, int64Field, floatVecField, binaryVecField, structField, dim) + createCollectionReq := constructTestCreateCollectionRequest(dbName, collectionName, schema, shardsNum) wg.Add(1) t.Run("create collection", func(t *testing.T) { @@ -698,16 +1105,12 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - reqInvalidField := constructCreateCollectionRequest() - schema := constructCollectionSchema() - schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + invalidSchema := constructTestCollectionSchema(collectionName, int64Field, floatVecField, binaryVecField, structField, dim) + invalidSchema.Fields = append(invalidSchema.Fields, &schemapb.FieldSchema{ Name: "StringField", DataType: schemapb.DataType_String, }) - bs, err := proto.Marshal(schema) - assert.NoError(t, err) - reqInvalidField.CollectionName = "invalid_field_coll" - reqInvalidField.Schema = bs + reqInvalidField := constructTestCreateCollectionRequest(dbName, "invalid_field_coll", invalidSchema, shardsNum) resp, err = proxy.CreateCollection(ctx, reqInvalidField) assert.NoError(t, err) @@ -1097,7 +1500,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("insert", func(t *testing.T) { defer wg.Done() - req := constructCollectionInsertRequest() + req := constructTestCollectionInsertRequest(dbName, collectionName, floatVecField, binaryVecField, structField, schema, rowNum, dim) resp, err := proxy.Insert(ctx, req) assert.NoError(t, err) @@ -1134,19 +1537,9 @@ func TestProxy(t *testing.T) { time.Sleep(5 * time.Second) log.Info("flush collection", zap.Int64s("segments to be flushed", segmentIDs)) - f := func() bool { - resp, err := proxy.GetFlushState(ctx, &milvuspb.GetFlushStateRequest{ - SegmentIDs: segmentIDs, - }) - if err != nil { - return false - } - return resp.GetFlushed() - } - // waiting for flush 
operation to be done counter := 0 - for !f() { + for !checkFlushState(ctx, proxy, segmentIDs) { if counter > 100 { flushed = false break @@ -1159,6 +1552,7 @@ func TestProxy(t *testing.T) { if !flushed { log.Warn("flush operation was not sure to be done") } + wg.Add(1) t.Run("get statistics after flush", func(t *testing.T) { defer wg.Done() @@ -1184,10 +1578,11 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) }) + wg.Add(1) t.Run("create index for floatVec field", func(t *testing.T) { defer wg.Done() - req := constructCreateIndexRequest(schemapb.DataType_FloatVector, floatVecField) + req := constructTestCreateIndexRequest(dbName, collectionName, schemapb.DataType_FloatVector, floatVecField, dim, nlist) resp, err := proxy.CreateIndex(ctx, req) assert.NoError(t, err) @@ -1200,7 +1595,7 @@ func TestProxy(t *testing.T) { req := &milvuspb.AlterIndexRequest{ DbName: dbName, CollectionName: collectionName, - IndexName: floatIndexName, + IndexName: testFloatIndexName, ExtraParams: []*commonpb.KeyValuePair{ { Key: common.MmapEnabledKey, @@ -1226,7 +1621,7 @@ func TestProxy(t *testing.T) { }) err = merr.CheckRPCCall(resp, err) assert.NoError(t, err) - assert.Equal(t, floatIndexName, resp.IndexDescriptions[0].IndexName) + assert.Equal(t, testFloatIndexName, resp.IndexDescriptions[0].IndexName) enableMmap, _ := common.IsMmapDataEnabled(resp.IndexDescriptions[0].GetParams()...) 
assert.True(t, enableMmap, "params: %+v", resp.IndexDescriptions[0]) @@ -1234,7 +1629,7 @@ func TestProxy(t *testing.T) { req := &milvuspb.AlterIndexRequest{ DbName: dbName, CollectionName: collectionName, - IndexName: floatIndexName, + IndexName: testFloatIndexName, ExtraParams: []*commonpb.KeyValuePair{ { Key: common.MmapEnabledKey, @@ -1255,7 +1650,7 @@ func TestProxy(t *testing.T) { DbName: dbName, CollectionName: collectionName, FieldName: floatVecField, - IndexName: floatIndexName, + IndexName: testFloatIndexName, }) err = merr.CheckRPCCall(resp, err) assert.NoError(t, err) @@ -1273,7 +1668,7 @@ func TestProxy(t *testing.T) { }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - assert.Equal(t, floatIndexName, resp.IndexDescriptions[0].IndexName) + assert.Equal(t, testFloatIndexName, resp.IndexDescriptions[0].IndexName) }) wg.Add(1) @@ -1284,7 +1679,7 @@ func TestProxy(t *testing.T) { DbName: dbName, CollectionName: collectionName, FieldName: floatVecField, - IndexName: floatIndexName, + IndexName: testFloatIndexName, }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) @@ -1298,7 +1693,7 @@ func TestProxy(t *testing.T) { DbName: dbName, CollectionName: collectionName, FieldName: floatVecField, - IndexName: floatIndexName, + IndexName: testFloatIndexName, }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) @@ -1329,13 +1724,117 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("create index for binVec field", func(t *testing.T) { defer wg.Done() - req := constructCreateIndexRequest(schemapb.DataType_BinaryVector, binaryVecField) + req := constructTestCreateIndexRequest(dbName, collectionName, schemapb.DataType_BinaryVector, binaryVecField, dim, nlist) resp, err := proxy.CreateIndex(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) }) + wg.Add(1) + t.Run("create index 
for embedding list field", func(t *testing.T) { + defer wg.Done() + req := constructTestCreateIndexRequest(dbName, collectionName, schemapb.DataType_ArrayOfVector, subFieldFVec, dim, nlist) + + resp, err := proxy.CreateIndex(ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) + }) + + wg.Add(1) + t.Run("alter index for embedding list field", func(t *testing.T) { + defer wg.Done() + req := &milvuspb.AlterIndexRequest{ + DbName: dbName, + CollectionName: collectionName, + IndexName: testStructFVecIndexName, + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.MmapEnabledKey, + Value: "true", + }, + }, + } + + resp, err := proxy.AlterIndex(ctx, req) + err = merr.CheckRPCCall(resp, err) + assert.NoError(t, err) + }) + + wg.Add(1) + t.Run("describe index for embedding list field", func(t *testing.T) { + defer wg.Done() + resp, err := proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + FieldName: subFieldFVec, + IndexName: testStructFVecIndexName, + }) + err = merr.CheckRPCCall(resp, err) + assert.NoError(t, err) + assert.Equal(t, testStructFVecIndexName, resp.IndexDescriptions[0].IndexName) + enableMmap, _ := common.IsMmapDataEnabled(resp.IndexDescriptions[0].GetParams()...) 
+ assert.True(t, enableMmap, "params: %+v", resp.IndexDescriptions[0]) + }) + + wg.Add(1) + t.Run("describe index with indexName for embedding list field", func(t *testing.T) { + defer wg.Done() + resp, err := proxy.DescribeIndex(ctx, &milvuspb.DescribeIndexRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + FieldName: subFieldFVec, + IndexName: testStructFVecIndexName, + }) + err = merr.CheckRPCCall(resp, err) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + wg.Add(1) + t.Run("get index statistics for embedding list field", func(t *testing.T) { + defer wg.Done() + resp, err := proxy.GetIndexStatistics(ctx, &milvuspb.GetIndexStatisticsRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + IndexName: testStructFVecIndexName, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + assert.Equal(t, testStructFVecIndexName, resp.IndexDescriptions[0].IndexName) + }) + + wg.Add(1) + t.Run("get index build progress for embedding list field", func(t *testing.T) { + defer wg.Done() + resp, err := proxy.GetIndexBuildProgress(ctx, &milvuspb.GetIndexBuildProgressRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + FieldName: subFieldFVec, + IndexName: testStructFVecIndexName, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + + wg.Add(1) + t.Run("get index state for embedding list field", func(t *testing.T) { + defer wg.Done() + resp, err := proxy.GetIndexState(ctx, &milvuspb.GetIndexStateRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + FieldName: subFieldFVec, + IndexName: testStructFVecIndexName, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + }) + loaded := true wg.Add(1) t.Run("load collection", func(t *testing.T) { @@ -1367,29 +1866,9 @@ 
func TestProxy(t *testing.T) { assert.NoError(t, err) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode) - f := func() bool { - resp, err := proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{ - Base: nil, - DbName: dbName, - TimeStamp: 0, - Type: milvuspb.ShowType_InMemory, - CollectionNames: []string{collectionName}, - }) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - - for idx, name := range resp.CollectionNames { - if name == collectionName && resp.InMemoryPercentages[idx] == 100 { - return true - } - } - - return false - } - // waiting for collection to be loaded counter := 0 - for !f() { + for !checkCollectionLoaded(ctx, proxy, dbName, collectionName) { if counter > 100 { loaded = false break @@ -1515,93 +1994,11 @@ func TestProxy(t *testing.T) { topk := 10 roundDecimal := 6 expr := fmt.Sprintf("%s > 0", int64Field) - constructVectorsPlaceholderGroup := func(nq int) *commonpb.PlaceholderGroup { - values := make([][]byte, 0, nq) - for i := 0; i < nq; i++ { - bs := make([]byte, 0, dim*4) - for j := 0; j < dim; j++ { - var buffer bytes.Buffer - f := rand.Float32() - err := binary.Write(&buffer, common.Endian, f) - assert.NoError(t, err) - bs = append(bs, buffer.Bytes()...) 
- } - values = append(values, bs) - } - - return &commonpb.PlaceholderGroup{ - Placeholders: []*commonpb.PlaceholderValue{ - { - Tag: "$0", - Type: commonpb.PlaceholderType_FloatVector, - Values: values, - }, - }, - } - } - - constructSearchRequest := func(nq int) *milvuspb.SearchRequest { - plg := constructVectorsPlaceholderGroup(nq) - plgBs, err := proto.Marshal(plg) - assert.NoError(t, err) - - params := make(map[string]string) - params["nprobe"] = strconv.Itoa(nprobe) - b, err := json.Marshal(params) - assert.NoError(t, err) - searchParams := []*commonpb.KeyValuePair{ - {Key: MetricTypeKey, Value: metric.L2}, - {Key: ParamsKey, Value: string(b)}, - {Key: AnnsFieldKey, Value: floatVecField}, - {Key: TopKKey, Value: strconv.Itoa(topk)}, - {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, - } - - return &milvuspb.SearchRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionNames: nil, - Dsl: expr, - PlaceholderGroup: plgBs, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: nil, - SearchParams: searchParams, - TravelTimestamp: 0, - GuaranteeTimestamp: 0, - SearchByPrimaryKeys: false, - } - } - - constructSubSearchRequest := func(nq int) *milvuspb.SubSearchRequest { - plg := constructVectorsPlaceholderGroup(nq) - plgBs, err := proto.Marshal(plg) - assert.NoError(t, err) - - params := make(map[string]string) - params["nprobe"] = strconv.Itoa(nprobe) - b, err := json.Marshal(params) - assert.NoError(t, err) - searchParams := []*commonpb.KeyValuePair{ - {Key: MetricTypeKey, Value: metric.L2}, - {Key: ParamsKey, Value: string(b)}, - {Key: AnnsFieldKey, Value: floatVecField}, - {Key: TopKKey, Value: strconv.Itoa(topk)}, - {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, - } - - return &milvuspb.SubSearchRequest{ - Dsl: expr, - PlaceholderGroup: plgBs, - DslType: commonpb.DslType_BoolExprV1, - SearchParams: searchParams, - } - } wg.Add(1) t.Run("search", func(t *testing.T) { defer wg.Done() - req := 
constructSearchRequest(nq) + req := constructTestSearchRequest(dbName, collectionName, floatVecField, expr, nq, nprobe, topk, roundDecimal, dim) resp, err := proxy.Search(ctx, req) assert.NoError(t, err) @@ -1616,208 +2013,42 @@ func TestProxy(t *testing.T) { } }) - constructAdvancedSearchRequest := func() *milvuspb.SearchRequest { - params := make(map[string]float64) - params[RRFParamsKey] = 60 - b, err := json.Marshal(params) - assert.NoError(t, err) - rankParams := []*commonpb.KeyValuePair{ - {Key: RankTypeKey, Value: "rrf"}, - {Key: ParamsKey, Value: string(b)}, - {Key: LimitKey, Value: strconv.Itoa(topk)}, - {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, - } - - req1 := constructSubSearchRequest(nq) - req2 := constructSubSearchRequest(nq) - ret := &milvuspb.SearchRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionNames: nil, - OutputFields: nil, - SearchParams: rankParams, - TravelTimestamp: 0, - GuaranteeTimestamp: 0, - } - ret.SubReqs = append(ret.SubReqs, req1) - ret.SubReqs = append(ret.SubReqs, req2) - return ret - } - wg.Add(1) t.Run("advanced search", func(t *testing.T) { defer wg.Done() - req := constructAdvancedSearchRequest() + req := constructTestAdvancedSearchRequest(dbName, collectionName, floatVecField, expr, nq, nprobe, topk, roundDecimal, dim) resp, err := proxy.Search(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) - nq = 10 - constructPrimaryKeysPlaceholderGroup := func() *commonpb.PlaceholderGroup { - expr := fmt.Sprintf("%v in [%v]", int64Field, insertedIDs[0]) - exprBytes := []byte(expr) + wg.Add(1) + t.Run("embedding list search", func(t *testing.T) { + defer wg.Done() + req := constructTestEmbeddingListSearchRequest(dbName, collectionName, subFieldFVec, expr, nq, nprobe, topk, roundDecimal, dim) - return &commonpb.PlaceholderGroup{ - Placeholders: []*commonpb.PlaceholderValue{ - { - Tag: "$0", - Type: commonpb.PlaceholderType_None, - 
Values: [][]byte{exprBytes}, - }, - }, - } - } - - constructSearchByPksRequest := func() *milvuspb.SearchRequest { - plg := constructPrimaryKeysPlaceholderGroup() - plgBs, err := proto.Marshal(plg) + resp, err := proxy.Search(ctx, req) assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) - params := make(map[string]string) - params["nprobe"] = strconv.Itoa(nprobe) - b, err := json.Marshal(params) - assert.NoError(t, err) - searchParams := []*commonpb.KeyValuePair{ - {Key: MetricTypeKey, Value: metric.L2}, - {Key: ParamsKey, Value: string(b)}, - {Key: AnnsFieldKey, Value: floatVecField}, - {Key: TopKKey, Value: strconv.Itoa(topk)}, - {Key: RoundDecimalKey, Value: strconv.Itoa(roundDecimal)}, + { + Params.Save(Params.ProxyCfg.MustUsePartitionKey.Key, "true") + resp, err := proxy.Search(ctx, req) + assert.NoError(t, err) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + Params.Reset(Params.ProxyCfg.MustUsePartitionKey.Key) } - - return &milvuspb.SearchRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionNames: nil, - Dsl: "", - PlaceholderGroup: plgBs, - DslType: commonpb.DslType_BoolExprV1, - OutputFields: nil, - SearchParams: searchParams, - TravelTimestamp: 0, - GuaranteeTimestamp: 0, - SearchByPrimaryKeys: true, - } - } + }) wg.Add(1) t.Run("search by primary keys", func(t *testing.T) { defer wg.Done() - req := constructSearchByPksRequest() + req := constructSearchByPksRequest(t, dbName, collectionName, floatVecField, int64Field, insertedIDs, nprobe, topk, roundDecimal) resp, err := proxy.Search(ctx, req) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) - // nprobe := 10 - // topk := 10 - // roundDecimal := 6 - // expr := fmt.Sprintf("%s > 0", int64Field) - // constructPlaceholderGroup := func() *milvuspb.PlaceholderGroup { - // values := make([][]byte, 0, nq) - // for i := 0; i < nq; i++ { - // bs := make([]byte, 0, dim*4) - 
// for j := 0; j < dim; j++ { - // var buffer bytes.Buffer - // f := rand.Float32() - // err := binary.Write(&buffer, common.Endian, f) - // assert.NoError(t, err) - // bs = append(bs, buffer.Bytes()...) - // } - // values = append(values, bs) - // } - // - // return &milvuspb.PlaceholderGroup{ - // Placeholders: []*milvuspb.PlaceholderValue{ - // { - // Tag: "$0", - // Type: milvuspb.PlaceholderType_FloatVector, - // Values: values, - // }, - // }, - // } - // } - // - // constructSearchRequest := func() *milvuspb.SearchRequest { - // params := make(map[string]string) - // params["nprobe"] = strconv.Itoa(nprobe) - // b, err := json.Marshal(params) - // assert.NoError(t, err) - // plg := constructPlaceholderGroup() - // plgBs, err := proto.Marshal(plg) - // assert.NoError(t, err) - // - // return &milvuspb.SearchRequest{ - // Base: nil, - // DbName: dbName, - // CollectionName: collectionName, - // PartitionNames: nil, - // Dsl: expr, - // PlaceholderGroup: plgBs, - // DslType: commonpb.DslType_BoolExprV1, - // OutputFields: nil, - // SearchParams: []*commonpb.KeyValuePair{ - // { - // Key: MetricTypeKey, - // Value: distance.L2, - // }, - // { - // Key: SearchParamsKey, - // Value: string(b), - // }, - // { - // Key: AnnsFieldKey, - // Value: floatVecField, - // }, - // { - // Key: TopKKey, - // Value: strconv.Itoa(topk), - // }, - // { - // Key: RoundDecimalKey, - // Value: strconv.Itoa(roundDecimal), - // }, - // }, - // TravelTimestamp: 0, - // GuaranteeTimestamp: 0, - // } - // } - - // TODO(Goose): reopen after joint-tests - // if loaded { - // wg.Add(1) - // t.Run("search", func(t *testing.T) { - // defer wg.Done() - // req := constructSearchRequest() - // - // resp, err := proxy.Search(ctx, req) - // assert.NoError(t, err) - // assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - // }) - // - // wg.Add(1) - // t.Run("query", func(t *testing.T) { - // defer wg.Done() - // //resp, err := proxy.Query(ctx, &milvuspb.QueryRequest{ - // 
_, err := proxy.Query(ctx, &milvuspb.QueryRequest{ - // Base: nil, - // DbName: dbName, - // CollectionName: collectionName, - // Expr: expr, - // OutputFields: nil, - // PartitionNames: nil, - // TravelTimestamp: 0, - // GuaranteeTimestamp: 0, - // }) - // assert.NoError(t, err) - // // FIXME(dragondriver) - // // assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - // // TODO(dragondriver): compare query result - // }) - wg.Add(1) t.Run("calculate distance", func(t *testing.T) { defer wg.Done() @@ -2066,30 +2297,9 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode) - f := func() bool { - resp, err := proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - CollectionID: collectionID, - PartitionNames: []string{partitionName}, - Type: milvuspb.ShowType_InMemory, - }) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) - - for idx, name := range resp.PartitionNames { - if name == partitionName && resp.InMemoryPercentages[idx] == 100 { - return true - } - } - - return false - } - // waiting for collection to be loaded counter := 0 - for !f() { + for !checkPartitionInMemory(t, ctx, proxy, dbName, collectionName, partitionName, collectionID) { if counter > 100 { pLoaded = false break @@ -2100,6 +2310,7 @@ func TestProxy(t *testing.T) { } }) assert.True(t, pLoaded) + wg.Add(1) t.Run("show in-memory partitions", func(t *testing.T) { defer wg.Done() @@ -2169,7 +2380,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("insert partition", func(t *testing.T) { defer wg.Done() - req := constructPartitionInsertRequest() + req := constructPartitionInsertRequest(dbName, collectionName, partitionName, floatVecField, binaryVecField, structField, schema, rowNum, dim) resp, err := proxy.Insert(ctx, req) assert.NoError(t, err) @@ -2250,7 +2461,7 @@ func TestProxy(t 
*testing.T) { t.Run("upsert when autoID == true", func(t *testing.T) { defer wg.Done() // autoID==true but not pass pk in upsert, failed - req := constructCollectionUpsertRequestNoPK() + req := constructCollectionUpsertRequestNoPK(dbName, collectionName, floatVecField, binaryVecField, structField, schema, rowNum, dim) resp, err := proxy.Upsert(ctx, req) assert.NoError(t, err) @@ -2260,7 +2471,7 @@ func TestProxy(t *testing.T) { assert.Equal(t, int64(0), resp.UpsertCnt) // autoID==true and pass pk in upsert, succeed - req = constructCollectionUpsertRequestWithPK() + req = constructCollectionUpsertRequestWithPK(dbName, collectionName, floatVecField, binaryVecField, structField, schema, rowNum, dim) resp, err = proxy.Upsert(ctx, req) assert.NoError(t, err) @@ -2406,7 +2617,7 @@ func TestProxy(t *testing.T) { DbName: dbName, CollectionName: collectionName, FieldName: floatVecField, - IndexName: floatIndexName, + IndexName: testFloatIndexName, }) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) @@ -2503,14 +2714,7 @@ func TestProxy(t *testing.T) { defer wg.Done() // 1. create credential - constructCreateCredentialRequest := func() *milvuspb.CreateCredentialRequest { - return &milvuspb.CreateCredentialRequest{ - Base: nil, - Username: username, - Password: crypto.Base64Encode(password), - } - } - createCredentialReq := constructCreateCredentialRequest() + createCredentialReq := constructCreateCredentialRequest(username, crypto.Base64Encode(password)) // success resp, err := proxy.CreateCredential(ctx, createCredentialReq) assert.NoError(t, err) @@ -2552,16 +2756,8 @@ func TestProxy(t *testing.T) { // 2. 
update credential newPassword := "new_password" - constructUpdateCredentialRequest := func() *milvuspb.UpdateCredentialRequest { - return &milvuspb.UpdateCredentialRequest{ - Base: nil, - Username: username, - OldPassword: crypto.Base64Encode(password), - NewPassword: crypto.Base64Encode(newPassword), - } - } // cannot update non-existing user's password - updateCredentialReq := constructUpdateCredentialRequest() + updateCredentialReq := constructUpdateCredentialRequest(username, crypto.Base64Encode(password), crypto.Base64Encode(newPassword)) updateCredentialReq.Username = "test_username_" + funcutil.RandomString(15) updateResp, err := proxy.UpdateCredential(ctx, updateCredentialReq) assert.NoError(t, err) @@ -2617,13 +2813,7 @@ func TestProxy(t *testing.T) { // 3. get credential newPassword := "new_password" - constructGetCredentialRequest := func() *rootcoordpb.GetCredentialRequest { - return &rootcoordpb.GetCredentialRequest{ - Base: nil, - Username: username, - } - } - getCredentialReq := constructGetCredentialRequest() + getCredentialReq := constructGetCredentialRequest(username) getResp, err := rootCoordClient.GetCredential(ctx, getCredentialReq) assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, getResp.GetStatus().GetErrorCode()) @@ -2640,11 +2830,6 @@ func TestProxy(t *testing.T) { defer wg.Done() // 4. list credential usernames - constructListCredUsersRequest := func() *milvuspb.ListCredUsersRequest { - return &milvuspb.ListCredUsersRequest{ - Base: nil, - } - } listCredUsersReq := constructListCredUsersRequest() listUsersResp, err := proxy.ListCredUsers(ctx, listCredUsersReq) assert.NoError(t, err) @@ -2656,13 +2841,7 @@ func TestProxy(t *testing.T) { defer wg.Done() // 5. 
delete credential - constructDelCredRequest := func() *milvuspb.DeleteCredentialRequest { - return &milvuspb.DeleteCredentialRequest{ - Base: nil, - Username: username, - } - } - delCredReq := constructDelCredRequest() + delCredReq := constructDelCredRequest(username) deleteResp, err := proxy.DeleteCredential(ctx, delCredReq) assert.NoError(t, err) @@ -3676,120 +3855,8 @@ func TestProxy(t *testing.T) { testProxyRoleTimeout(shortCtx, t, proxy) testProxyPrivilegeTimeout(shortCtx, t, proxy) - constructCollectionSchema = func() *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ - FieldID: 100, - Name: int64Field, - IsPrimaryKey: true, - Description: "", - DataType: schemapb.DataType_Int64, - TypeParams: nil, - IndexParams: nil, - AutoID: false, - } - fVec := &schemapb.FieldSchema{ - FieldID: 101, - Name: floatVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - bVec := &schemapb.FieldSchema{ - FieldID: 102, - Name: binaryVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_BinaryVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - return &schemapb.CollectionSchema{ - Name: collectionName, - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - pk, - fVec, - bVec, - }, - } - } - schema = constructCollectionSchema() - - constructCreateCollectionRequest = func() *milvuspb.CreateCollectionRequest { - bs, err := proto.Marshal(schema) - assert.NoError(t, err) - return &milvuspb.CreateCollectionRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - Schema: bs, - ShardsNum: shardsNum, - } - } - createCollectionReq = constructCreateCollectionRequest() - - constructPartitionReqUpsertRequestValid := func() 
*milvuspb.UpsertRequest { - pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) - fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) - bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) - hashKeys := testutils.GenerateHashKeys(rowNum) - return &milvuspb.UpsertRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionName: partitionName, - FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn}, - HashKeys: hashKeys, - NumRows: uint32(rowNum), - } - } - - constructPartitionReqUpsertRequestInvalid := func() *milvuspb.UpsertRequest { - pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) - fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) - bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) - hashKeys := testutils.GenerateHashKeys(rowNum) - return &milvuspb.UpsertRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionName: "%$@", - FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn}, - HashKeys: hashKeys, - NumRows: uint32(rowNum), - } - } - - constructCollectionUpsertRequestValid := func() *milvuspb.UpsertRequest { - pkFieldData := newScalarFieldData(schema.Fields[0], int64Field, rowNum) - fVecColumn := newFloatVectorFieldData(floatVecField, rowNum, dim) - bVecColumn := newBinaryVectorFieldData(binaryVecField, rowNum, dim) - hashKeys := testutils.GenerateHashKeys(rowNum) - return &milvuspb.UpsertRequest{ - Base: nil, - DbName: dbName, - CollectionName: collectionName, - PartitionName: partitionName, - FieldsData: []*schemapb.FieldData{pkFieldData, fVecColumn, bVecColumn}, - HashKeys: hashKeys, - NumRows: uint32(rowNum), - } - } + schema = constructTestCollectionSchema(collectionName, int64Field, floatVecField, binaryVecField, structField, dim) + createCollectionReq = constructTestCreateCollectionRequest(dbName, collectionName, schema, shardsNum) wg.Add(1) t.Run("create 
collection upsert valid", func(t *testing.T) { @@ -3799,8 +3866,8 @@ func TestProxy(t *testing.T) { assert.NoError(t, err) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - reqInvalidField := constructCreateCollectionRequest() - schema := constructCollectionSchema() + reqInvalidField := constructTestCreateCollectionRequest(dbName, collectionName, schema, shardsNum) + schema := constructTestCollectionSchema(collectionName, int64Field, floatVecField, binaryVecField, structField, dim) schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ Name: "StringField", DataType: schemapb.DataType_String, @@ -3841,7 +3908,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("upsert partition", func(t *testing.T) { defer wg.Done() - req := constructPartitionReqUpsertRequestValid() + req := constructPartitionReqUpsertRequestValid(dbName, collectionName, partitionName, floatVecField, binaryVecField, structField, schema, rowNum, dim) resp, err := proxy.Upsert(ctx, req) assert.NoError(t, err) @@ -3854,7 +3921,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("upsert when occurs unexpected error like illegal partition name", func(t *testing.T) { defer wg.Done() - req := constructPartitionReqUpsertRequestInvalid() + req := constructPartitionReqUpsertRequestInvalid(dbName, collectionName, floatVecField, binaryVecField, structField, schema, rowNum, dim) resp, err := proxy.Upsert(ctx, req) assert.NoError(t, err) @@ -3867,7 +3934,7 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("upsert when autoID == false", func(t *testing.T) { defer wg.Done() - req := constructCollectionUpsertRequestValid() + req := constructCollectionUpsertRequestValid(dbName, collectionName, floatVecField, binaryVecField, structField, schema, rowNum, dim) resp, err := proxy.Upsert(ctx, req) assert.NoError(t, err) @@ -3877,195 +3944,6 @@ func TestProxy(t *testing.T) { assert.Equal(t, int64(rowNum), resp.UpsertCnt) }) - wg.Add(1) - // todo: when struct array field is done, this will be merged with 
above logic - t.Run("test struct array field", func(t *testing.T) { - defer wg.Done() - - structCollectionName := prefix + "struct_" + funcutil.GenRandomStr() - structId := "structI32Array" - structVec := "structFloatVecArray" - structField := "struct" - - constructStructCollectionSchema := func() *schemapb.CollectionSchema { - pk := &schemapb.FieldSchema{ - FieldID: 100, - Name: int64Field, - IsPrimaryKey: true, - Description: "", - DataType: schemapb.DataType_Int64, - TypeParams: nil, - IndexParams: nil, - AutoID: true, - } - fVec := &schemapb.FieldSchema{ - FieldID: 101, - Name: floatVecField, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - }, - IndexParams: nil, - AutoID: false, - } - // struct schema fields - sId := &schemapb.FieldSchema{ - FieldID: 103, - Name: structId, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_Int32, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.MaxCapacityKey, - Value: "100", - }, - }, - IndexParams: nil, - AutoID: false, - } - sVec := &schemapb.FieldSchema{ - FieldID: 104, - Name: structVec, - IsPrimaryKey: false, - Description: "", - DataType: schemapb.DataType_ArrayOfVector, - ElementType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: common.DimKey, - Value: strconv.Itoa(dim), - }, - { - Key: common.MaxCapacityKey, - Value: "100", - }, - }, - IndexParams: nil, - AutoID: false, - } - structF := &schemapb.StructArrayFieldSchema{ - FieldID: 105, - Name: structField, - Fields: []*schemapb.FieldSchema{sId, sVec}, - } - return &schemapb.CollectionSchema{ - Name: structCollectionName, - Description: "", - AutoID: false, - Fields: []*schemapb.FieldSchema{ - pk, - fVec, - }, - StructArrayFields: []*schemapb.StructArrayFieldSchema{structF}, - } - } - - structSchema := 
constructStructCollectionSchema() - constructStructCreateCollectionRequest := func() *milvuspb.CreateCollectionRequest { - bs, err := proto.Marshal(structSchema) - assert.NoError(t, err) - return &milvuspb.CreateCollectionRequest{ - Base: nil, - DbName: dbName, - CollectionName: structCollectionName, - Schema: bs, - ShardsNum: shardsNum, - } - } - - createStructCollectionReq := constructStructCreateCollectionRequest() - resp, err := proxy.CreateCollection(ctx, createStructCollectionReq) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) - - reqInvalidField := constructStructCreateCollectionRequest() - invalidSchema := constructStructCollectionSchema() - invalidSchema.Fields = append(invalidSchema.Fields, &schemapb.FieldSchema{ - Name: "StringField", - DataType: schemapb.DataType_String, - }) - bs, err := proto.Marshal(invalidSchema) - assert.NoError(t, err) - reqInvalidField.CollectionName = "invalid_field_coll" - reqInvalidField.Schema = bs - - resp, err = proxy.CreateCollection(ctx, reqInvalidField) - assert.NoError(t, err) - assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode) - - hasResp, err := proxy.HasCollection(ctx, &milvuspb.HasCollectionRequest{ - Base: nil, - DbName: dbName, - CollectionName: structCollectionName, - TimeStamp: 0, - }) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, hasResp.GetStatus().GetErrorCode()) - assert.True(t, hasResp.Value) - - collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, structCollectionName) - assert.NoError(t, err) - - descResp, err := proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{ - Base: nil, - DbName: dbName, - CollectionName: structCollectionName, - CollectionID: collectionID, - TimeStamp: 0, - }) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, descResp.GetStatus().GetErrorCode()) - assert.Equal(t, collectionID, descResp.CollectionID) - - fieldsMap := make(map[string]*schemapb.FieldSchema) - 
for _, field := range descResp.Schema.Fields { - fieldsMap[field.Name] = field - } - for _, structField := range descResp.Schema.StructArrayFields { - for _, field := range structField.Fields { - fieldsMap[field.Name] = field - } - } - assert.Equal(t, len(fieldsMap), len(structSchema.Fields)+2) - for _, field := range structSchema.Fields { - fSchema, ok := fieldsMap[field.Name] - assert.True(t, ok) - assert.True(t, proto.Equal(field, fSchema)) - } - for _, structField := range structSchema.StructArrayFields { - for _, field := range structField.Fields { - fSchema, ok := fieldsMap[field.Name] - assert.True(t, ok) - assert.True(t, proto.Equal(field, fSchema)) - } - } - - statsResp, err := proxy.GetCollectionStatistics(ctx, &milvuspb.GetCollectionStatisticsRequest{ - Base: nil, - DbName: dbName, - CollectionName: structCollectionName, - }) - assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, statsResp.GetStatus().GetErrorCode()) - - showResp, err := proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{ - Base: nil, - DbName: dbName, - TimeStamp: 0, - Type: milvuspb.ShowType_All, - }) - assert.NoError(t, err) - assert.True(t, merr.Ok(showResp.GetStatus())) - assert.Contains(t, showResp.CollectionNames, structCollectionName) - }) - testServer.gracefulStop() wg.Wait() log.Info("case done") diff --git a/internal/proxy/search_pipeline.go b/internal/proxy/search_pipeline.go index 729738ad56..b199ed7674 100644 --- a/internal/proxy/search_pipeline.go +++ b/internal/proxy/search_pipeline.go @@ -549,24 +549,24 @@ func (op *lambdaOperator) run(ctx context.Context, span trace.Span, inputs ...an type filterFieldOperator struct { outputFieldNames []string - schema *schemaInfo + fieldSchemas []*schemapb.FieldSchema } func newFilterFieldOperator(t *searchTask, _ map[string]any) (operator, error) { return &filterFieldOperator{ outputFieldNames: t.translatedOutputFields, - schema: t.schema, + fieldSchemas: typeutil.GetAllFieldSchemas(t.schema.CollectionSchema), 
}, nil } func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) { result := inputs[0].(*milvuspb.SearchResults) for _, retField := range result.Results.FieldsData { - for _, schemaField := range op.schema.Fields { - if retField != nil && retField.FieldId == schemaField.FieldID { - retField.FieldName = schemaField.Name - retField.Type = schemaField.DataType - retField.IsDynamic = schemaField.IsDynamic + for _, fieldSchema := range op.fieldSchemas { + if retField != nil && retField.FieldId == fieldSchema.FieldID { + retField.FieldName = fieldSchema.Name + retField.Type = fieldSchema.DataType + retField.IsDynamic = fieldSchema.IsDynamic } } } diff --git a/internal/proxy/search_pipeline_test.go b/internal/proxy/search_pipeline_test.go index 38831a3c30..e884de0d06 100644 --- a/internal/proxy/search_pipeline_test.go +++ b/internal/proxy/search_pipeline_test.go @@ -560,6 +560,79 @@ func (s *SearchPipelineSuite) TestHybridSearchPipe() { s.Len(results.Results.Scores, 20) // 2 queries * 10 topk } +func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() { + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64}, + {FieldID: 102, Name: "floatField", DataType: schemapb.DataType_Float}, + }, + StructArrayFields: []*schemapb.StructArrayFieldSchema{ + { + Name: "structArray", + Fields: []*schemapb.FieldSchema{ + {FieldID: 104, Name: "structArrayField", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}, + {FieldID: 105, Name: "structVectorField", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector}, + }, + }, + }, + } + + task := &searchTask{ + schema: &schemaInfo{ + CollectionSchema: schema, + }, + translatedOutputFields: []string{"intField", "floatField", "structArrayField", 
"structVectorField"}, + } + + op, err := newFilterFieldOperator(task, nil) + s.NoError(err) + + // Create mock search results with fields including struct array fields + searchResults := &milvuspb.SearchResults{ + Results: &schemapb.SearchResultData{ + FieldsData: []*schemapb.FieldData{ + {FieldId: 101}, // intField + {FieldId: 102}, // floatField + {FieldId: 104}, // structArrayField + {FieldId: 105}, // structVectorField + }, + }, + } + + results, err := op.run(context.Background(), s.span, searchResults) + s.NoError(err) + s.NotNil(results) + + resultData := results[0].(*milvuspb.SearchResults) + s.NotNil(resultData.Results.FieldsData) + s.Len(resultData.Results.FieldsData, 4) + + // Verify all fields including struct array fields got their names and types set + for _, field := range resultData.Results.FieldsData { + switch field.FieldId { + case 101: + s.Equal("intField", field.FieldName) + s.Equal(schemapb.DataType_Int64, field.Type) + s.False(field.IsDynamic) + case 102: + s.Equal("floatField", field.FieldName) + s.Equal(schemapb.DataType_Float, field.Type) + s.False(field.IsDynamic) + case 104: + // Struct array field should be handled by GetAllFieldSchemas + s.Equal("structArrayField", field.FieldName) + s.Equal(schemapb.DataType_Array, field.Type) + s.False(field.IsDynamic) + case 105: + // Struct array vector field should be handled by GetAllFieldSchemas + s.Equal("structVectorField", field.FieldName) + s.Equal(schemapb.DataType_ArrayOfVector, field.Type) + s.False(field.IsDynamic) + } + } +} + func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() { task := getHybridSearchTask("test_collection", [][]string{ {"1", "2"}, diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 27afde7413..b6d07c5a94 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -281,6 +281,28 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb return nil, fmt.Errorf("parse iterator v2 info 
failed: %w", err) } + // 7. check search for embedding list + annsFieldName, _ := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, searchParamsPair) + if annsFieldName != "" { + annField := typeutil.GetFieldByName(schema, annsFieldName) + if annField != nil && annField.GetDataType() == schemapb.DataType_ArrayOfVector { + if strings.Contains(searchParamStr, radiusKey) { + return nil, merr.WrapErrParameterInvalid("", "", + "range search is not supported for vector array (embedding list) fields, fieldName: %s", annsFieldName) + } + + if groupByFieldId > 0 { + return nil, merr.WrapErrParameterInvalid("", "", + "group by search is not supported for vector array (embedding list) fields, fieldName: %s", annsFieldName) + } + + if isIterator { + return nil, merr.WrapErrParameterInvalid("", "", + "search iterator is not supported for vector array (embedding list) fields, fieldName: %s", annsFieldName) + } + } + } + return &SearchInfo{ planInfo: &planpb.QueryInfo{ Topk: queryTopK, diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 8a1d543e67..1b92871a7c 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -68,6 +68,8 @@ const ( RoundDecimalKey = "round_decimal" OffsetKey = "offset" LimitKey = "limit" + // offsets for embedding list search + LimsKey = "lims" SearchIterV2Key = "search_iter_v2" SearchIterBatchSizeKey = "search_iter_batch_size" @@ -2047,7 +2049,8 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) { loadFieldsSet := typeutil.NewSet(loadFields...) 
unindexedVecFields := make([]string, 0) - for _, field := range collSchema.GetFields() { + allFields := typeutil.GetAllFieldSchemas(collSchema.CollectionSchema) + for _, field := range allFields { if typeutil.IsVectorType(field.GetDataType()) && loadFieldsSet.Contain(field.GetFieldID()) { if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok { unindexedVecFields = append(unindexedVecFields, field.GetName()) @@ -2055,8 +2058,6 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) { } } - // todo(SpadeA): check vector field in StructArrayField when index is implemented - if len(unindexedVecFields) != 0 { errMsg := fmt.Sprintf("there is no vector index on field: %v, please create index firstly", unindexedVecFields) log.Debug(errMsg) @@ -2305,7 +2306,8 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error { loadFieldsSet := typeutil.NewSet(loadFields...) unindexedVecFields := make([]string, 0) - for _, field := range collSchema.GetFields() { + allFields := typeutil.GetAllFieldSchemas(collSchema.CollectionSchema) + for _, field := range allFields { if typeutil.IsVectorType(field.GetDataType()) && loadFieldsSet.Contain(field.GetFieldID()) { if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok { unindexedVecFields = append(unindexedVecFields, field.GetName()) @@ -2313,8 +2315,6 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error { } } - // todo(SpadeA): check vector field in StructArrayField when index is implemented - if len(unindexedVecFields) != 0 { errMsg := fmt.Sprintf("there is no vector index on field: %v, please create index firstly", unindexedVecFields) log.Ctx(ctx).Debug(errMsg) diff --git a/internal/proxy/task_index.go b/internal/proxy/task_index.go index 12dea6e8aa..90512249a8 100644 --- a/internal/proxy/task_index.go +++ b/internal/proxy/task_index.go @@ -202,10 +202,12 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error { specifyIndexType, exist := indexParamsMap[common.IndexTypeKey] if exist 
&& specifyIndexType != "" { + // todo(SpadeA): mmap check for struct array index if err := indexparamcheck.ValidateMmapIndexParams(specifyIndexType, indexParamsMap); err != nil { log.Ctx(ctx).Warn("Invalid mmap type params", zap.String(common.IndexTypeKey, specifyIndexType), zap.Error(err)) return merr.WrapErrParameterInvalidMsg("invalid mmap type params: %s", err.Error()) } + // todo(SpadeA): check for struct array index checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(specifyIndexType) // not enable hybrid index for user, used in milvus internally if err != nil || indexparamcheck.IsHYBRIDChecker(checker) { @@ -327,16 +329,20 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error { } var config map[string]string - if typeutil.IsDenseFloatVectorType(cit.fieldSchema.DataType) { + if typeutil.IsDenseFloatVectorType(cit.fieldSchema.DataType) || + (typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsDenseFloatVectorType(cit.fieldSchema.ElementType)) { // override float vector index params by autoindex config = Params.AutoIndexConfig.IndexParams.GetAsJSONMap() - } else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) { + } else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) || + (typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsSparseFloatVectorType(cit.fieldSchema.ElementType)) { // override sparse float vector index params by autoindex config = Params.AutoIndexConfig.SparseIndexParams.GetAsJSONMap() - } else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) { + } else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) || + (typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsBinaryVectorType(cit.fieldSchema.ElementType)) { // override binary vector index params by autoindex config = Params.AutoIndexConfig.BinaryIndexParams.GetAsJSONMap() - } else if typeutil.IsIntVectorType(cit.fieldSchema.DataType) { + } else if 
typeutil.IsIntVectorType(cit.fieldSchema.DataType) || + (typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsIntVectorType(cit.fieldSchema.ElementType)) { // override int vector index params by autoindex config = Params.AutoIndexConfig.IndexParams.GetAsJSONMap() } @@ -397,6 +403,12 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error { if !funcutil.SliceContain(indexparamcheck.IntVectorMetrics, metricType) { return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "int vector index does not support metric type: "+metricType) } + } else if typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) { + // TODO(SpadeA): adjust it when more metric types are supported. Especially, when different metric types + // are supported for different element types. + if !funcutil.SliceContain(indexparamcheck.EmbListMetrics, metricType) { + return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector index does not support metric type: "+metricType) + } } } @@ -500,7 +512,7 @@ func checkTrain(ctx context.Context, field *schemapb.FieldSchema, indexParams ma } if typeutil.IsVectorType(field.DataType) && indexType != indexparamcheck.AutoIndex { - exist := CheckVecIndexWithDataTypeExist(indexType, field.DataType) + exist := CheckVecIndexWithDataTypeExist(indexType, field.DataType, field.ElementType) if !exist { return fmt.Errorf("data type %s can't build with this index %s", schemapb.DataType_name[int32(field.GetDataType())], indexType) } @@ -519,7 +531,7 @@ func checkTrain(ctx context.Context, field *schemapb.FieldSchema, indexParams ma return err } - if err := checker.CheckTrain(field.DataType, indexParams); err != nil { + if err := checker.CheckTrain(field.DataType, field.ElementType, indexParams); err != nil { log.Ctx(ctx).Info("create index with invalid parameters", zap.Error(err)) return err } diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 
155158fafe..da723b72bc 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -21,6 +21,7 @@ import ( "os" "sort" "strconv" + "strings" "testing" "github.com/cockroachdb/errors" @@ -1164,6 +1165,131 @@ func Test_parseIndexParams(t *testing.T) { }) } +func Test_checkEmbeddingListIndex(t *testing.T) { + t.Run("check embedding list index", func(t *testing.T) { + cit := &createIndexTask{ + Condition: nil, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "EMB_LIST_HNSW", + }, + { + Key: common.MetricTypeKey, + Value: metric.MaxSim, + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "EmbListFloat", + IsPrimaryKey: false, + DataType: schemapb.DataType_ArrayOfVector, + ElementType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "128"}, + }, + }, + } + err := cit.parseIndexParams(context.TODO()) + assert.NoError(t, err) + }) + + t.Run("metrics wrong for embedding list index", func(t *testing.T) { + cit := &createIndexTask{ + Condition: nil, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "EMB_LIST_HNSW", + }, + { + Key: common.MetricTypeKey, + Value: metric.L2, + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "EmbListFloat", + IsPrimaryKey: false, + DataType: schemapb.DataType_ArrayOfVector, + ElementType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "128"}, + }, + }, + } + err := cit.parseIndexParams(context.TODO()) + assert.True(t, strings.Contains(err.Error(), "array of vector index does not support metric type: L2")) + }) + + t.Run("metric type wrong", func(t *testing.T) { + cit := &createIndexTask{ + Condition: nil, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + 
Key: common.IndexTypeKey, + Value: "HNSW", + }, + { + Key: common.MetricTypeKey, + Value: metric.MaxSim, + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldFloat", + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "128"}, + }, + }, + } + err := cit.parseIndexParams(context.TODO()) + assert.True(t, strings.Contains(err.Error(), "float vector index does not support metric type: MAX_SIM")) + }) + + t.Run("data type wrong", func(t *testing.T) { + cit := &createIndexTask{ + Condition: nil, + req: &milvuspb.CreateIndexRequest{ + ExtraParams: []*commonpb.KeyValuePair{ + { + Key: common.IndexTypeKey, + Value: "EMB_LIST_HNSW", + }, + { + Key: common.MetricTypeKey, + Value: metric.L2, + }, + }, + IndexName: "", + }, + fieldSchema: &schemapb.FieldSchema{ + FieldID: 101, + Name: "FieldFloat", + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "128"}, + }, + }, + } + + err := cit.parseIndexParams(context.TODO()) + assert.True(t, strings.Contains(err.Error(), "data type FloatVector can't build with this index EMB_LIST_HNSW")) + }) +} + func Test_ngram_parseIndexParams(t *testing.T) { t.Run("valid ngram index params", func(t *testing.T) { cit := &createIndexTask{ diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 910724f4b1..13491e09f6 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -217,11 +217,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error { return err } - allFields := make([]*schemapb.FieldSchema, 0, len(it.schema.Fields)+5) - allFields = append(allFields, it.schema.Fields...) - for _, structField := range it.schema.GetStructArrayFields() { - allFields = append(allFields, structField.GetFields()...) 
- } + allFields := typeutil.GetAllFieldSchemas(it.schema) // check primaryFieldData whether autoID is true or not // set rowIDs as primary data if autoID == true diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index de02b70195..2d39ea18d1 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -650,7 +650,9 @@ func (t *queryTask) PostExecute(ctx context.Context) error { return err } t.result.OutputFields = t.userOutputFields - reconstructStructFieldData(t.result, t.schema.CollectionSchema) + if !t.reQuery { + reconstructStructFieldDataForQuery(t.result, t.schema.CollectionSchema) + } primaryFieldSchema, err := t.schema.GetPkField() if err != nil { diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index d6db8352b4..c20fdb313e 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -1292,470 +1292,3 @@ func TestQueryTask_CanSkipAllocTimestamp(t *testing.T) { assert.True(t, skip) }) } - -func Test_reconstructStructFieldData(t *testing.T) { - t.Run("count(*) query - should return early", func(t *testing.T) { - results := &milvuspb.QueryResults{ - OutputFields: []string{"count(*)"}, - FieldsData: []*schemapb.FieldData{ - { - FieldName: "count(*)", - FieldId: 0, - Type: schemapb.DataType_Int64, - }, - }, - } - - schema := &schemapb.CollectionSchema{ - StructArrayFields: []*schemapb.StructArrayFieldSchema{ - { - FieldID: 102, - Name: "test_struct", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 1021, - Name: "sub_field", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_Int32, - }, - }, - }, - }, - } - - originalFieldsData := make([]*schemapb.FieldData, len(results.FieldsData)) - copy(originalFieldsData, results.FieldsData) - originalOutputFields := make([]string, len(results.OutputFields)) - copy(originalOutputFields, results.OutputFields) - - reconstructStructFieldData(results, schema) - - // Should not modify anything for count(*) 
query - assert.Equal(t, originalFieldsData, results.FieldsData) - assert.Equal(t, originalOutputFields, results.OutputFields) - }) - - t.Run("no struct array fields - should return early", func(t *testing.T) { - results := &milvuspb.QueryResults{ - OutputFields: []string{"field1", "field2"}, - FieldsData: []*schemapb.FieldData{ - { - FieldName: "field1", - FieldId: 100, - Type: schemapb.DataType_Int64, - }, - { - FieldName: "field2", - FieldId: 101, - Type: schemapb.DataType_VarChar, - }, - }, - } - - schema := &schemapb.CollectionSchema{ - StructArrayFields: []*schemapb.StructArrayFieldSchema{}, - } - - originalFieldsData := make([]*schemapb.FieldData, len(results.FieldsData)) - copy(originalFieldsData, results.FieldsData) - originalOutputFields := make([]string, len(results.OutputFields)) - copy(originalOutputFields, results.OutputFields) - - reconstructStructFieldData(results, schema) - - // Should not modify anything when no struct array fields - assert.Equal(t, originalFieldsData, results.FieldsData) - assert.Equal(t, originalOutputFields, results.OutputFields) - }) - - t.Run("reconstruct single struct field", func(t *testing.T) { - // Create mock data - subField1Data := &schemapb.FieldData{ - FieldName: "sub_int_array", - FieldId: 1021, - Type: schemapb.DataType_Array, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_ArrayData{ - ArrayData: &schemapb.ArrayArray{ - ElementType: schemapb.DataType_Int32, - Data: []*schemapb.ScalarField{ - { - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}}, - }, - }, - }, - }, - }, - }, - }, - } - - subField2Data := &schemapb.FieldData{ - FieldName: "sub_text_array", - FieldId: 1022, - Type: schemapb.DataType_Array, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_ArrayData{ - ArrayData: &schemapb.ArrayArray{ - ElementType: schemapb.DataType_VarChar, - Data: []*schemapb.ScalarField{ 
- { - Data: &schemapb.ScalarField_StringData{ - StringData: &schemapb.StringArray{Data: []string{"hello", "world"}}, - }, - }, - }, - }, - }, - }, - }, - } - - results := &milvuspb.QueryResults{ - OutputFields: []string{"sub_int_array", "sub_text_array"}, - FieldsData: []*schemapb.FieldData{subField1Data, subField2Data}, - } - - schema := &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - { - FieldID: 100, - Name: "pk", - IsPrimaryKey: true, - DataType: schemapb.DataType_Int64, - }, - }, - StructArrayFields: []*schemapb.StructArrayFieldSchema{ - { - FieldID: 102, - Name: "test_struct", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 1021, - Name: "sub_int_array", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_Int32, - }, - { - FieldID: 1022, - Name: "sub_text_array", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_VarChar, - }, - }, - }, - }, - } - - reconstructStructFieldData(results, schema) - - // Check result - assert.Len(t, results.FieldsData, 1, "Should only have one reconstructed struct field") - assert.Len(t, results.OutputFields, 1, "Output fields should only have one") - - structField := results.FieldsData[0] - assert.Equal(t, "test_struct", structField.FieldName) - assert.Equal(t, int64(102), structField.FieldId) - assert.Equal(t, schemapb.DataType_ArrayOfStruct, structField.Type) - assert.Equal(t, "test_struct", results.OutputFields[0]) - - // Check fields inside struct - structArrays := structField.GetStructArrays() - assert.NotNil(t, structArrays) - assert.Len(t, structArrays.Fields, 2, "Struct should contain 2 sub fields") - - // Check sub fields - var foundIntField, foundTextField bool - for _, field := range structArrays.Fields { - switch field.FieldId { - case 1021: - assert.Equal(t, "sub_int_array", field.FieldName) - assert.Equal(t, schemapb.DataType_Array, field.Type) - foundIntField = true - case 1022: - assert.Equal(t, "sub_text_array", field.FieldName) - assert.Equal(t, 
schemapb.DataType_Array, field.Type) - foundTextField = true - } - } - assert.True(t, foundIntField, "Should find int array field") - assert.True(t, foundTextField, "Should find text array field") - }) - - t.Run("mixed regular and struct fields", func(t *testing.T) { - // Create regular field data - regularField := &schemapb.FieldData{ - FieldName: "regular_field", - FieldId: 100, - Type: schemapb.DataType_Int64, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}, - }, - }, - }, - } - - // Create struct sub field data - subFieldData := &schemapb.FieldData{ - FieldName: "sub_field", - FieldId: 1021, - Type: schemapb.DataType_Array, - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_ArrayData{ - ArrayData: &schemapb.ArrayArray{ - ElementType: schemapb.DataType_Int32, - Data: []*schemapb.ScalarField{ - { - Data: &schemapb.ScalarField_IntData{ - IntData: &schemapb.IntArray{Data: []int32{10, 20}}, - }, - }, - }, - }, - }, - }, - }, - } - - results := &milvuspb.QueryResults{ - OutputFields: []string{"regular_field", "sub_field"}, - FieldsData: []*schemapb.FieldData{regularField, subFieldData}, - } - - schema := &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - { - FieldID: 100, - Name: "regular_field", - DataType: schemapb.DataType_Int64, - }, - }, - StructArrayFields: []*schemapb.StructArrayFieldSchema{ - { - FieldID: 102, - Name: "test_struct", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 1021, - Name: "sub_field", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_Int32, - }, - }, - }, - }, - } - - reconstructStructFieldData(results, schema) - - // Check result: should have 2 fields (1 regular + 1 reconstructed struct) - assert.Len(t, results.FieldsData, 2) - assert.Len(t, results.OutputFields, 2) - - // Check regular and struct fields both exist - var 
foundRegularField, foundStructField bool - for i, field := range results.FieldsData { - switch field.FieldId { - case 100: - assert.Equal(t, "regular_field", field.FieldName) - assert.Equal(t, schemapb.DataType_Int64, field.Type) - assert.Equal(t, "regular_field", results.OutputFields[i]) - foundRegularField = true - case 102: - assert.Equal(t, "test_struct", field.FieldName) - assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type) - assert.Equal(t, "test_struct", results.OutputFields[i]) - foundStructField = true - } - } - assert.True(t, foundRegularField, "Should find regular field") - assert.True(t, foundStructField, "Should find reconstructed struct field") - }) - - t.Run("multiple struct fields", func(t *testing.T) { - // Create sub field for first struct - struct1SubField := &schemapb.FieldData{ - FieldName: "struct1_sub", - FieldId: 1021, - Type: schemapb.DataType_Array, - } - - // Create sub fields for second struct - struct2SubField1 := &schemapb.FieldData{ - FieldName: "struct2_sub1", - FieldId: 1031, - Type: schemapb.DataType_Array, - } - - struct2SubField2 := &schemapb.FieldData{ - FieldName: "struct2_sub2", - FieldId: 1032, - Type: schemapb.DataType_Array, - } - - results := &milvuspb.QueryResults{ - OutputFields: []string{"struct1_sub", "struct2_sub1", "struct2_sub2"}, - FieldsData: []*schemapb.FieldData{struct1SubField, struct2SubField1, struct2SubField2}, - } - - schema := &schemapb.CollectionSchema{ - StructArrayFields: []*schemapb.StructArrayFieldSchema{ - { - FieldID: 102, - Name: "struct1", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 1021, - Name: "struct1_sub", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_Int32, - }, - }, - }, - { - FieldID: 103, - Name: "struct2", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 1031, - Name: "struct2_sub1", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_VarChar, - }, - { - FieldID: 1032, - Name: "struct2_sub2", - DataType: schemapb.DataType_Array, - 
ElementType: schemapb.DataType_Float, - }, - }, - }, - }, - } - - reconstructStructFieldData(results, schema) - - // Check result: should have 2 reconstructed struct fields - assert.Len(t, results.FieldsData, 2) - assert.Len(t, results.OutputFields, 2) - - // Check both struct fields are reconstructed correctly - foundStruct1, foundStruct2 := false, false - for i, field := range results.FieldsData { - switch field.FieldId { - case 102: - assert.Equal(t, "struct1", field.FieldName) - assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type) - assert.Equal(t, "struct1", results.OutputFields[i]) - - structArrays := field.GetStructArrays() - assert.NotNil(t, structArrays) - assert.Len(t, structArrays.Fields, 1) - assert.Equal(t, int64(1021), structArrays.Fields[0].FieldId) - foundStruct1 = true - - case 103: - assert.Equal(t, "struct2", field.FieldName) - assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type) - assert.Equal(t, "struct2", results.OutputFields[i]) - - structArrays := field.GetStructArrays() - assert.NotNil(t, structArrays) - assert.Len(t, structArrays.Fields, 2) - - // Check struct2 contains two sub fields - subFieldIds := make([]int64, 0, 2) - for _, subField := range structArrays.Fields { - subFieldIds = append(subFieldIds, subField.FieldId) - } - assert.Contains(t, subFieldIds, int64(1031)) - assert.Contains(t, subFieldIds, int64(1032)) - foundStruct2 = true - } - } - assert.True(t, foundStruct1, "Should find struct1") - assert.True(t, foundStruct2, "Should find struct2") - }) - - t.Run("empty fields data", func(t *testing.T) { - results := &milvuspb.QueryResults{ - OutputFields: []string{}, - FieldsData: []*schemapb.FieldData{}, - } - - schema := &schemapb.CollectionSchema{ - StructArrayFields: []*schemapb.StructArrayFieldSchema{ - { - FieldID: 102, - Name: "test_struct", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 1021, - Name: "sub_field", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_Int32, - }, - }, - }, - 
}, - } - - reconstructStructFieldData(results, schema) - - // Empty data should remain unchanged - assert.Len(t, results.FieldsData, 0) - assert.Len(t, results.OutputFields, 0) - }) - - t.Run("no matching sub fields", func(t *testing.T) { - // Field data does not match any struct definition - regularField := &schemapb.FieldData{ - FieldName: "regular_field", - FieldId: 200, // Not in any struct - Type: schemapb.DataType_Int64, - } - - results := &milvuspb.QueryResults{ - OutputFields: []string{"regular_field"}, - FieldsData: []*schemapb.FieldData{regularField}, - } - - schema := &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ - { - FieldID: 200, - Name: "regular_field", - DataType: schemapb.DataType_Int64, - }, - }, - StructArrayFields: []*schemapb.StructArrayFieldSchema{ - { - FieldID: 102, - Name: "test_struct", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 1021, - Name: "sub_field", - DataType: schemapb.DataType_Array, - ElementType: schemapb.DataType_Int32, - }, - }, - }, - }, - } - - reconstructStructFieldData(results, schema) - - // Should only keep the regular field, no struct field - assert.Len(t, results.FieldsData, 1) - assert.Len(t, results.OutputFields, 1) - assert.Equal(t, int64(200), results.FieldsData[0].FieldId) - assert.Equal(t, "regular_field", results.OutputFields[0]) - }) -} diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index cb570d5dec..7c6e894a0e 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -544,7 +544,8 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error { } } - vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool { + allFields := typeutil.GetAllFieldSchemas(t.schema.CollectionSchema) + vectorOutputFields := lo.Filter(allFields, func(field *schemapb.FieldSchema, _ int) bool { return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType()) }) 
t.needRequery = len(vectorOutputFields) > 0 @@ -765,6 +766,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error { } t.fillResult() t.result.Results.OutputFields = t.userOutputFields + reconstructStructFieldDataForSearch(t.result, t.schema.CollectionSchema) t.result.CollectionName = t.request.GetCollectionName() primaryFieldSchema, _ := t.schema.GetPkField() diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index b168c025ed..99ac3997dc 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -3548,6 +3548,278 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { assert.ErrorContains(t, err, "failed to parse input last bound") }) }) + + t.Run("check vector array unsupported features", func(t *testing.T) { + // Helper function to create a schema with vector array field + createSchemaWithVectorArray := func(annsFieldName string) *schemapb.CollectionSchema { + return &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "id", + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + }, + { + FieldID: 101, + Name: "normal_vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "128"}, + }, + }, + { + FieldID: 102, + Name: "group_field", + DataType: schemapb.DataType_VarChar, + }, + }, + StructArrayFields: []*schemapb.StructArrayFieldSchema{ + { + FieldID: 103, + Name: "struct_array_field", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 104, + Name: annsFieldName, + DataType: schemapb.DataType_ArrayOfVector, + ElementType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "128"}, + }, + }, + }, + }, + }, + } + } + + // Helper function to create search params with anns field + createSearchParams := func(annsFieldName string) []*commonpb.KeyValuePair { + return []*commonpb.KeyValuePair{ + { + Key: AnnsFieldKey, + Value: annsFieldName, 
+ }, + { + Key: TopKKey, + Value: "10", + }, + { + Key: common.MetricTypeKey, + Value: metric.MaxSim, + }, + { + Key: ParamsKey, + Value: `{"nprobe": 10}`, + }, + { + Key: RoundDecimalKey, + Value: "-1", + }, + { + Key: IgnoreGrowingKey, + Value: "false", + }, + } + } + + t.Run("vector array with range search", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("embeddings_list") + + // Add radius parameter for range search + resetSearchParamsValue(params, ParamsKey, `{"nprobe": 10, "radius": 0.2}`) + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.Error(t, err) + assert.Nil(t, searchInfo) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + fmt.Println(err.Error()) + assert.Contains(t, err.Error(), "range search is not supported for vector array (embedding list) fields") + }) + + t.Run("vector array with group by", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("embeddings_list") + + // Add group by parameter + params = append(params, &commonpb.KeyValuePair{ + Key: GroupByFieldKey, + Value: "group_field", + }) + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.Error(t, err) + assert.Nil(t, searchInfo) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + assert.Contains(t, err.Error(), "group by search is not supported for vector array (embedding list) fields") + assert.Contains(t, err.Error(), "embeddings_list") + }) + + t.Run("vector array with iterator", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("embeddings_list") + + // Add iterator parameter + params = append(params, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.Error(t, err) + assert.Nil(t, searchInfo) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + assert.Contains(t, err.Error(), "search 
iterator is not supported for vector array (embedding list) fields") + assert.Contains(t, err.Error(), "embeddings_list") + }) + + t.Run("vector array with iterator v2", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("embeddings_list") + + // Add iterator v2 parameters + params = append(params, + &commonpb.KeyValuePair{ + Key: SearchIterV2Key, + Value: "True", + }, + &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }, + &commonpb.KeyValuePair{ + Key: SearchIterBatchSizeKey, + Value: "10", + }, + ) + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.Error(t, err) + assert.Nil(t, searchInfo) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + assert.Contains(t, err.Error(), "search iterator is not supported for vector array (embedding list) fields") + assert.Contains(t, err.Error(), "embeddings_list") + }) + + t.Run("normal search on vector array should succeed", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("embeddings_list") + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.NoError(t, err) + assert.NotNil(t, searchInfo) + assert.NotNil(t, searchInfo.planInfo) + }) + + t.Run("normal vector field with range search should succeed", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("normal_vector") + + // Add radius parameter for range search + resetSearchParamsValue(params, ParamsKey, `{"nprobe": 10, "radius": 0.2}`) + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.NoError(t, err) + assert.NotNil(t, searchInfo) + assert.NotNil(t, searchInfo.planInfo) + }) + + t.Run("normal vector field with group by should succeed", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("normal_vector") + + // Add group by parameter + params = append(params, 
&commonpb.KeyValuePair{ + Key: GroupByFieldKey, + Value: "group_field", + }) + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.NoError(t, err) + assert.NotNil(t, searchInfo) + assert.NotNil(t, searchInfo.planInfo) + assert.Equal(t, int64(102), searchInfo.planInfo.GroupByFieldId) + }) + + t.Run("normal vector field with iterator should succeed", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("normal_vector") + + // Add iterator parameter + params = append(params, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.NoError(t, err) + assert.NotNil(t, searchInfo) + assert.NotNil(t, searchInfo.planInfo) + }) + + t.Run("vector array with range search", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("embeddings_list") + resetSearchParamsValue(params, ParamsKey, `{"nprobe": 10, "radius": 0.2}`) + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.Error(t, err) + assert.Nil(t, searchInfo) + // Should fail on range search first + assert.Contains(t, err.Error(), "range search is not supported for vector array (embedding list) fields") + }) + + t.Run("no anns field specified", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := getValidSearchParams() + // Don't specify anns field + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.NoError(t, err) + assert.NotNil(t, searchInfo) + // Should not trigger vector array validation without anns field + }) + + t.Run("non-existent anns field", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + params := createSearchParams("non_existent_field") + + searchInfo, err := parseSearchInfo(params, schema, nil) + assert.NoError(t, err) + assert.NotNil(t, searchInfo) + // Should not trigger vector array validation 
for non-existent field + }) + + t.Run("hybrid search with outer group by on vector array", func(t *testing.T) { + schema := createSchemaWithVectorArray("embeddings_list") + + // Create rank params with group by + rankParams := getValidSearchParams() + rankParams = append(rankParams, + &commonpb.KeyValuePair{ + Key: GroupByFieldKey, + Value: "group_field", + }, + &commonpb.KeyValuePair{ + Key: LimitKey, + Value: "100", + }, + ) + + parsedRankParams, err := parseRankParams(rankParams, schema) + assert.NoError(t, err) + + searchParams := createSearchParams("embeddings_list") + // Parse search info with rank params + searchInfo, err := parseSearchInfo(searchParams, schema, parsedRankParams) + assert.Error(t, err) + assert.Nil(t, searchInfo) + assert.ErrorIs(t, err, merr.ErrParameterInvalid) + assert.Contains(t, err.Error(), "group by search is not supported for vector array (embedding list) fields") + }) + }) } func getSearchResultData(nq, topk int64) *schemapb.SearchResultData { @@ -4301,3 +4573,110 @@ func genTestSearchResultData(nq int64, topk int64, dType schemapb.DataType, fiel } return result } + +func TestSearchTask_InitSearchRequestWithStructArrayFields(t *testing.T) { + ctx := context.Background() + + schema := &schemapb.CollectionSchema{ + Name: "test_collection", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "regular_vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}}, + {FieldID: 102, Name: "regular_scalar", DataType: schemapb.DataType_Int32}, + }, + StructArrayFields: []*schemapb.StructArrayFieldSchema{ + { + Name: "structArray", + Fields: []*schemapb.FieldSchema{ + {FieldID: 104, Name: "struct_vec_array", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "64"}}}, + {FieldID: 105, Name: 
"struct_scalar_array", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32}, + }, + }, + }, + } + + schemaInfo := newSchemaInfo(schema) + + tests := []struct { + name string + outputFields []string + expectedRequery bool + description string + }{ + { + name: "regular_vector_field", + outputFields: []string{"pk", "regular_vec"}, + expectedRequery: true, + description: "Should require requery when regular vector field in output", + }, + { + name: "struct_array_vector_field", + outputFields: []string{"pk", "struct_vec_array"}, + expectedRequery: true, + description: "Should require requery when struct array vector field in output (tests GetAllFieldSchemas)", + }, + { + name: "both_vector_fields", + outputFields: []string{"pk", "regular_vec", "struct_vec_array"}, + expectedRequery: true, + description: "Should require requery when both regular and struct array vector fields in output", + }, + { + name: "struct_scalar_array_only", + outputFields: []string{"pk", "struct_scalar_array"}, + expectedRequery: false, + description: "Should not require requery when only struct scalar array field in output", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + task := &searchTask{ + ctx: ctx, + collectionName: "test_collection", + SearchRequest: &internalpb.SearchRequest{ + CollectionID: 1, + PartitionIDs: []int64{1}, + Dsl: "", + PlaceholderGroup: nil, + DslType: commonpb.DslType_BoolExprV1, + OutputFieldsId: []int64{}, + }, + request: &milvuspb.SearchRequest{ + CollectionName: "test_collection", + OutputFields: tt.outputFields, + Dsl: "", + SearchParams: []*commonpb.KeyValuePair{ + {Key: AnnsFieldKey, Value: "regular_vec"}, + {Key: TopKKey, Value: "10"}, + {Key: common.MetricTypeKey, Value: metric.L2}, + {Key: ParamsKey, Value: `{"nprobe": 10}`}, + }, + PlaceholderGroup: nil, + ConsistencyLevel: commonpb.ConsistencyLevel_Session, + }, + schema: schemaInfo, + translatedOutputFields: tt.outputFields, + tr: 
timerecord.NewTimeRecorder("test"), + queryInfos: []*planpb.QueryInfo{{}}, + } + + // Set translated output field IDs based on the schema + outputFieldIDs := []int64{} + allFields := typeutil.GetAllFieldSchemas(schema) + for _, fieldName := range tt.outputFields { + for _, field := range allFields { + if field.Name == fieldName { + outputFieldIDs = append(outputFieldIDs, field.FieldID) + break + } + } + } + task.SearchRequest.OutputFieldsId = outputFieldIDs + + err := task.initSearchRequest(ctx) + assert.NoError(t, err) + assert.Equal(t, tt.expectedRequery, task.needRequery, tt.description) + }) + } +} diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index bde01bad29..ea6705f69d 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -585,11 +585,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { return err } - allFields := make([]*schemapb.FieldSchema, 0, len(it.schema.Fields)+5) - allFields = append(allFields, it.schema.Fields...) - for _, structField := range it.schema.GetStructArrayFields() { - allFields = append(allFields, structField.GetFields()...) 
- } + allFields := typeutil.GetAllFieldSchemas(it.schema.CollectionSchema) // use the passed pk as new pk when autoID == false // automatic generate pk as new pk wehen autoID == true diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 48949a818c..1984a57c09 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -619,9 +619,6 @@ func ValidateFieldsInStruct(field *schemapb.FieldSchema, schema *schemapb.Collec if field.GetNullable() { return fmt.Errorf("nullable is not supported for fields in struct array now, fieldName = %s", field.Name) } - - // todo(SpadeA): add more check when index is enabled - return nil } @@ -2547,3 +2544,88 @@ func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 { return 0 } + +// reconstructStructFieldDataCommon reconstructs struct fields from flattened sub-fields +// It works with both QueryResults and SearchResults by operating on the common data structures +func reconstructStructFieldDataCommon( + fieldsData []*schemapb.FieldData, + outputFields []string, + schema *schemapb.CollectionSchema, +) ([]*schemapb.FieldData, []string) { + if len(outputFields) == 1 && outputFields[0] == "count(*)" { + return fieldsData, outputFields + } + + if len(schema.StructArrayFields) == 0 { + return fieldsData, outputFields + } + + regularFieldIDs := make(map[int64]interface{}) + subFieldToStructMap := make(map[int64]int64) + groupedStructFields := make(map[int64][]*schemapb.FieldData) + structFieldNames := make(map[int64]string) + reconstructedOutputFields := make([]string, 0, len(fieldsData)) + + // record all regular field IDs + for _, field := range schema.Fields { + regularFieldIDs[field.GetFieldID()] = nil + } + + // build the mapping from sub-field ID to struct field ID + for _, structField := range schema.StructArrayFields { + for _, subField := range structField.GetFields() { + subFieldToStructMap[subField.GetFieldID()] = structField.GetFieldID() + } + structFieldNames[structField.GetFieldID()] = 
structField.GetName() + } + + newFieldsData := make([]*schemapb.FieldData, 0, len(fieldsData)) + for _, field := range fieldsData { + fieldID := field.GetFieldId() + if _, ok := regularFieldIDs[fieldID]; ok { + newFieldsData = append(newFieldsData, field) + reconstructedOutputFields = append(reconstructedOutputFields, field.GetFieldName()) + } else { + structFieldID := subFieldToStructMap[fieldID] + groupedStructFields[structFieldID] = append(groupedStructFields[structFieldID], field) + } + } + + for structFieldID, fields := range groupedStructFields { + fieldData := &schemapb.FieldData{ + FieldName: structFieldNames[structFieldID], + FieldId: structFieldID, + Type: schemapb.DataType_ArrayOfStruct, + Field: &schemapb.FieldData_StructArrays{StructArrays: &schemapb.StructArrayField{Fields: fields}}, + } + newFieldsData = append(newFieldsData, fieldData) + reconstructedOutputFields = append(reconstructedOutputFields, structFieldNames[structFieldID]) + } + + return newFieldsData, reconstructedOutputFields +} + +// Wrapper for QueryResults +func reconstructStructFieldDataForQuery(results *milvuspb.QueryResults, schema *schemapb.CollectionSchema) { + fieldsData, outputFields := reconstructStructFieldDataCommon( + results.FieldsData, + results.OutputFields, + schema, + ) + results.FieldsData = fieldsData + results.OutputFields = outputFields +} + +// New wrapper for SearchResults +func reconstructStructFieldDataForSearch(results *milvuspb.SearchResults, schema *schemapb.CollectionSchema) { + if results.Results == nil { + return + } + fieldsData, outputFields := reconstructStructFieldDataCommon( + results.Results.FieldsData, + results.Results.OutputFields, + schema, + ) + results.Results.FieldsData = fieldsData + results.Results.OutputFields = outputFields +} diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index b33e0ad80a..57a8accade 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -3825,5 +3825,644 @@ func 
TestCheckAndFlattenStructFieldData(t *testing.T) { } func TestValidateFieldsInStruct(t *testing.T) { - // todo(SpadeA): add test cases + schema := &schemapb.CollectionSchema{ + Name: "test_collection", + } + + t.Run("valid array field", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "valid_array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + } + err := ValidateFieldsInStruct(field, schema) + assert.NoError(t, err) + }) + + t.Run("valid array of vector field", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "valid_array_vector", + DataType: schemapb.DataType_ArrayOfVector, + ElementType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "128"}, + }, + } + err := ValidateFieldsInStruct(field, schema) + assert.NoError(t, err) + }) + + t.Run("invalid field name", func(t *testing.T) { + testCases := []struct { + name string + expected string + }{ + {"", "field name should not be empty"}, + {"123abc", "The first character of a field name must be an underscore or letter"}, + {"abc-def", "Field name can only contain numbers, letters, and underscores"}, + } + + for _, tc := range testCases { + field := &schemapb.FieldSchema{ + Name: tc.name, + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expected) + } + }) + + t.Run("invalid data type", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "invalid_type", + DataType: schemapb.DataType_Int32, // Not array or array of vector + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Fields in StructArrayField can only be array or array of struct") + }) + + t.Run("nested array not supported", func(t *testing.T) { + testCases := []struct { + elementType schemapb.DataType + }{ + {schemapb.DataType_ArrayOfStruct}, + 
{schemapb.DataType_ArrayOfVector}, + {schemapb.DataType_Array}, + } + + for _, tc := range testCases { + field := &schemapb.FieldSchema{ + Name: "nested_array", + DataType: schemapb.DataType_Array, + ElementType: tc.elementType, + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Nested array is not supported") + } + }) + + t.Run("array field with vector element type", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "array_with_vector", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_FloatVector, + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "element type of array field array_with_vector is a vector type") + }) + + t.Run("array of vector field with non-vector element type", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "array_vector_with_scalar", + DataType: schemapb.DataType_ArrayOfVector, + ElementType: schemapb.DataType_Int32, + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "element type of array field array_vector_with_scalar is not a vector type") + }) + + t.Run("array of vector missing dimension", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "array_vector_no_dim", + DataType: schemapb.DataType_ArrayOfVector, + ElementType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{}, // No dimension specified + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "dimension is not defined in field") + }) + + t.Run("array of vector with invalid dimension", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "array_vector_invalid_dim", + DataType: schemapb.DataType_ArrayOfVector, + ElementType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "not_a_number"}, + }, + } + err := 
ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + }) + + t.Run("varchar array without max_length", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "varchar_array_no_max_length", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{}, // No max_length specified + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "type param(max_length) should be specified") + }) + + t.Run("varchar array with valid max_length", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "varchar_array_valid", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.MaxLengthKey, Value: "100"}, + }, + } + err := ValidateFieldsInStruct(field, schema) + assert.NoError(t, err) + }) + + t.Run("varchar array with invalid max_length", func(t *testing.T) { + // Test with max_length exceeding limit + field := &schemapb.FieldSchema{ + Name: "varchar_array_invalid_length", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.MaxLengthKey, Value: "99999999"}, // Exceeds limit + }, + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "the maximum length specified for the field") + }) + + t.Run("nullable field not supported", func(t *testing.T) { + field := &schemapb.FieldSchema{ + Name: "nullable_field", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + Nullable: true, + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "nullable is not supported for fields in struct array now") + }) + + t.Run("sparse float vector in array of vector", func(t *testing.T) { + // Note: ArrayOfVector with sparse vector element type still requires dimension + // because 
validateDimension checks the field's DataType (ArrayOfVector), not ElementType + field := &schemapb.FieldSchema{ + Name: "sparse_vector_array", + DataType: schemapb.DataType_ArrayOfVector, + ElementType: schemapb.DataType_SparseFloatVector, + TypeParams: []*commonpb.KeyValuePair{}, + } + err := ValidateFieldsInStruct(field, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "dimension is not defined") + }) + + t.Run("array with various scalar element types", func(t *testing.T) { + validScalarTypes := []schemapb.DataType{ + schemapb.DataType_Bool, + schemapb.DataType_Int8, + schemapb.DataType_Int16, + schemapb.DataType_Int32, + schemapb.DataType_Int64, + schemapb.DataType_Float, + schemapb.DataType_Double, + schemapb.DataType_String, + } + + for _, dt := range validScalarTypes { + field := &schemapb.FieldSchema{ + Name: "array_" + dt.String(), + DataType: schemapb.DataType_Array, + ElementType: dt, + } + err := ValidateFieldsInStruct(field, schema) + assert.NoError(t, err) + } + }) + + t.Run("array of vector with various vector types", func(t *testing.T) { + validVectorTypes := []schemapb.DataType{ + schemapb.DataType_FloatVector, + schemapb.DataType_BinaryVector, + schemapb.DataType_Float16Vector, + schemapb.DataType_BFloat16Vector, + // Note: SparseFloatVector is excluded because validateDimension checks + // the field's DataType (ArrayOfVector), not ElementType, so it still requires dimension + } + + for _, vt := range validVectorTypes { + field := &schemapb.FieldSchema{ + Name: "vector_array_" + vt.String(), + DataType: schemapb.DataType_ArrayOfVector, + ElementType: vt, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "128"}, + }, + } + err := ValidateFieldsInStruct(field, schema) + assert.NoError(t, err) + } + }) +} + +func Test_reconstructStructFieldDataCommon(t *testing.T) { + t.Run("count(*) query - should return early", func(t *testing.T) { + fieldsData := []*schemapb.FieldData{ + { + FieldName: "count(*)", + FieldId: 0, 
+ Type: schemapb.DataType_Int64, + }, + } + outputFields := []string{"count(*)"} + + schema := &schemapb.CollectionSchema{ + StructArrayFields: []*schemapb.StructArrayFieldSchema{ + { + FieldID: 102, + Name: "test_struct", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1021, + Name: "sub_field", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + }, + }, + }, + }, + } + + originalFieldsData := make([]*schemapb.FieldData, len(fieldsData)) + copy(originalFieldsData, fieldsData) + originalOutputFields := make([]string, len(outputFields)) + copy(originalOutputFields, outputFields) + + resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema) + + // Should not modify anything for count(*) query + assert.Equal(t, originalFieldsData, resultFieldsData) + assert.Equal(t, originalOutputFields, resultOutputFields) + }) + + t.Run("no struct array fields - should return early", func(t *testing.T) { + fieldsData := []*schemapb.FieldData{ + { + FieldName: "field1", + FieldId: 100, + Type: schemapb.DataType_Int64, + }, + { + FieldName: "field2", + FieldId: 101, + Type: schemapb.DataType_VarChar, + }, + } + outputFields := []string{"field1", "field2"} + + schema := &schemapb.CollectionSchema{ + StructArrayFields: []*schemapb.StructArrayFieldSchema{}, + } + + originalFieldsData := make([]*schemapb.FieldData, len(fieldsData)) + copy(originalFieldsData, fieldsData) + originalOutputFields := make([]string, len(outputFields)) + copy(originalOutputFields, outputFields) + + resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema) + + // Should not modify anything when no struct array fields + assert.Equal(t, originalFieldsData, resultFieldsData) + assert.Equal(t, originalOutputFields, resultOutputFields) + }) + + t.Run("reconstruct single struct field", func(t *testing.T) { + // Create mock data + subField1Data := &schemapb.FieldData{ + FieldName: 
"sub_int_array", + FieldId: 1021, + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + ElementType: schemapb.DataType_Int32, + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}}, + }, + }, + }, + }, + }, + }, + }, + } + + subField2Data := &schemapb.FieldData{ + FieldName: "sub_text_array", + FieldId: 1022, + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + ElementType: schemapb.DataType_VarChar, + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{Data: []string{"hello", "world"}}, + }, + }, + }, + }, + }, + }, + }, + } + + fieldsData := []*schemapb.FieldData{subField1Data, subField2Data} + outputFields := []string{"sub_int_array", "sub_text_array"} + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + }, + StructArrayFields: []*schemapb.StructArrayFieldSchema{ + { + FieldID: 102, + Name: "test_struct", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1021, + Name: "sub_int_array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + }, + { + FieldID: 1022, + Name: "sub_text_array", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_VarChar, + }, + }, + }, + }, + } + + resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema) + + // Check result + assert.Len(t, resultFieldsData, 1, "Should only have one reconstructed struct field") + assert.Len(t, resultOutputFields, 1, "Output fields should only have one") + + structField := resultFieldsData[0] + 
assert.Equal(t, "test_struct", structField.FieldName) + assert.Equal(t, int64(102), structField.FieldId) + assert.Equal(t, schemapb.DataType_ArrayOfStruct, structField.Type) + assert.Equal(t, "test_struct", resultOutputFields[0]) + + // Check fields inside struct + structArrays := structField.GetStructArrays() + assert.NotNil(t, structArrays) + assert.Len(t, structArrays.Fields, 2, "Struct should contain 2 sub fields") + + // Check sub fields + var foundIntField, foundTextField bool + for _, field := range structArrays.Fields { + switch field.FieldId { + case 1021: + assert.Equal(t, "sub_int_array", field.FieldName) + assert.Equal(t, schemapb.DataType_Array, field.Type) + foundIntField = true + case 1022: + assert.Equal(t, "sub_text_array", field.FieldName) + assert.Equal(t, schemapb.DataType_Array, field.Type) + foundTextField = true + } + } + assert.True(t, foundIntField, "Should find int array field") + assert.True(t, foundTextField, "Should find text array field") + }) + + t.Run("mixed regular and struct fields", func(t *testing.T) { + // Create regular field data + regularField := &schemapb.FieldData{ + FieldName: "regular_field", + FieldId: 100, + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}, + }, + }, + }, + } + + // Create struct sub field data + subFieldData := &schemapb.FieldData{ + FieldName: "sub_field", + FieldId: 1021, + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + ElementType: schemapb.DataType_Int32, + Data: []*schemapb.ScalarField{ + { + Data: &schemapb.ScalarField_IntData{ + IntData: &schemapb.IntArray{Data: []int32{10, 20}}, + }, + }, + }, + }, + }, + }, + }, + } + + fieldsData := []*schemapb.FieldData{regularField, subFieldData} + outputFields := 
[]string{"regular_field", "sub_field"} + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "regular_field", + DataType: schemapb.DataType_Int64, + }, + }, + StructArrayFields: []*schemapb.StructArrayFieldSchema{ + { + FieldID: 102, + Name: "test_struct", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1021, + Name: "sub_field", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + }, + }, + }, + }, + } + + resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema) + + // Check result: should have 2 fields (1 regular + 1 reconstructed struct) + assert.Len(t, resultFieldsData, 2) + assert.Len(t, resultOutputFields, 2) + + // Check regular and struct fields both exist + var foundRegularField, foundStructField bool + for i, field := range resultFieldsData { + switch field.FieldId { + case 100: + assert.Equal(t, "regular_field", field.FieldName) + assert.Equal(t, schemapb.DataType_Int64, field.Type) + assert.Equal(t, "regular_field", resultOutputFields[i]) + foundRegularField = true + case 102: + assert.Equal(t, "test_struct", field.FieldName) + assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type) + assert.Equal(t, "test_struct", resultOutputFields[i]) + foundStructField = true + } + } + assert.True(t, foundRegularField, "Should find regular field") + assert.True(t, foundStructField, "Should find reconstructed struct field") + }) + + t.Run("multiple struct fields", func(t *testing.T) { + // Create sub field for first struct + struct1SubField := &schemapb.FieldData{ + FieldName: "struct1_sub", + FieldId: 1021, + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + ElementType: schemapb.DataType_Int32, + Data: []*schemapb.ScalarField{}, + }, + }, + }, + }, + } + + // Create sub fields for second struct + 
struct2SubField1 := &schemapb.FieldData{ + FieldName: "struct2_sub1", + FieldId: 1031, + Type: schemapb.DataType_Array, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_ArrayData{ + ArrayData: &schemapb.ArrayArray{ + ElementType: schemapb.DataType_Int32, + Data: []*schemapb.ScalarField{}, + }, + }, + }, + }, + } + + struct2SubField2 := &schemapb.FieldData{ + FieldName: "struct2_sub2", + FieldId: 1032, + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{Data: []string{"test"}}, + }, + }, + }, + } + + fieldsData := []*schemapb.FieldData{struct1SubField, struct2SubField1, struct2SubField2} + outputFields := []string{"struct1_sub", "struct2_sub1", "struct2_sub2"} + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "pk", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + }, + StructArrayFields: []*schemapb.StructArrayFieldSchema{ + { + FieldID: 102, + Name: "struct1", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1021, + Name: "struct1_sub", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + }, + }, + }, + { + FieldID: 103, + Name: "struct2", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 1031, + Name: "struct2_sub1", + DataType: schemapb.DataType_Array, + ElementType: schemapb.DataType_Int32, + }, + { + FieldID: 1032, + Name: "struct2_sub2", + DataType: schemapb.DataType_VarChar, + }, + }, + }, + }, + } + + resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema) + + // Check result: should have 2 struct fields + assert.Len(t, resultFieldsData, 2) + assert.Len(t, resultOutputFields, 2) + + // Check both struct fields + var foundStruct1, foundStruct2 bool + for _, field := range resultFieldsData { + switch field.FieldId { + case 102: + 
assert.Equal(t, "struct1", field.FieldName) + assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type) + foundStruct1 = true + structArrays := field.GetStructArrays() + assert.NotNil(t, structArrays) + assert.Len(t, structArrays.Fields, 1) + case 103: + assert.Equal(t, "struct2", field.FieldName) + assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type) + foundStruct2 = true + structArrays := field.GetStructArrays() + assert.NotNil(t, structArrays) + assert.Len(t, structArrays.Fields, 2) + } + } + assert.True(t, foundStruct1, "Should find struct1") + assert.True(t, foundStruct2, "Should find struct2") + }) } diff --git a/internal/querynodev2/segments/load_index_info.go b/internal/querynodev2/segments/load_index_info.go index 9068ce5a47..a055c22fd5 100644 --- a/internal/querynodev2/segments/load_index_info.go +++ b/internal/querynodev2/segments/load_index_info.go @@ -33,7 +33,6 @@ import ( "google.golang.org/protobuf/proto" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/proto/cgopb" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" @@ -115,25 +114,6 @@ func (li *LoadIndexInfo) appendIndexFile(ctx context.Context, filePath string) e return HandleCStatus(ctx, &status, "AppendIndexIFile failed") } -// appendFieldInfo appends fieldID & fieldType to index -func (li *LoadIndexInfo) appendFieldInfo(ctx context.Context, collectionID int64, partitionID int64, segmentID int64, fieldID int64, fieldType schemapb.DataType, enableMmap bool, mmapDirPath string) error { - var status C.CStatus - GetDynamicPool().Submit(func() (any, error) { - cColID := C.int64_t(collectionID) - cParID := C.int64_t(partitionID) - cSegID := C.int64_t(segmentID) - cFieldID := C.int64_t(fieldID) - cintDType := uint32(fieldType) - cEnableMmap := C.bool(enableMmap) - cMmapDirPath := C.CString(mmapDirPath) - defer C.free(unsafe.Pointer(cMmapDirPath)) - status = C.AppendFieldInfo(li.cLoadIndexInfo, cColID, 
cParID, cSegID, cFieldID, cintDType, cEnableMmap, cMmapDirPath) - return nil, nil - }).Await() - - return HandleCStatus(ctx, &status, "AppendFieldInfo failed") -} - func (li *LoadIndexInfo) appendStorageInfo(uri string, version int64) { GetDynamicPool().Submit(func() (any, error) { cURI := C.CString(uri) diff --git a/internal/rootcoord/util.go b/internal/rootcoord/util.go index 9c587e26ac..e7d94c35a8 100644 --- a/internal/rootcoord/util.go +++ b/internal/rootcoord/util.go @@ -476,9 +476,16 @@ func checkFieldSchema(fieldSchemas []*schemapb.FieldSchema) error { func checkStructArrayFieldSchema(schemas []*schemapb.StructArrayFieldSchema) error { for _, schema := range schemas { - // todo(SpadeA): check struct array field schema + if len(schema.GetFields()) == 0 { + return merr.WrapErrParameterInvalidMsg("empty fields in StructArrayField is not allowed") + } for _, field := range schema.GetFields() { + if field.GetDataType() != schemapb.DataType_Array && field.GetDataType() != schemapb.DataType_ArrayOfVector { + msg := fmt.Sprintf("Fields in StructArrayField can only be array or array of vector, but field %s is %s", field.Name, field.DataType.String()) + return merr.WrapErrParameterInvalidMsg(msg) + } + if field.IsPartitionKey || field.IsPrimaryKey { msg := fmt.Sprintf("partition key or primary key can not be in struct array field. 
data type:%s, element type:%s, name:%s", field.DataType.String(), field.ElementType.String(), field.Name) diff --git a/internal/storage/data_sorter.go b/internal/storage/data_sorter.go index 8b97a6f23c..0c57f5d662 100644 --- a/internal/storage/data_sorter.go +++ b/internal/storage/data_sorter.go @@ -19,6 +19,7 @@ package storage import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/v2/common" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) // DataSorter sorts insert data @@ -52,11 +53,7 @@ func (ds *DataSorter) Len() int { // Swap swaps each field's i-th and j-th element func (ds *DataSorter) Swap(i, j int) { if ds.AllFields == nil { - allFields := ds.InsertCodec.Schema.Schema.Fields - for _, field := range ds.InsertCodec.Schema.Schema.StructArrayFields { - allFields = append(allFields, field.Fields...) - } - ds.AllFields = allFields + ds.AllFields = typeutil.GetAllFieldSchemas(ds.InsertCodec.Schema.Schema) } for _, field := range ds.AllFields { singleData, has := ds.InsertData.Data[field.FieldID] diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go index 5d569c4767..7d016b18ba 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -282,9 +282,9 @@ func ValueDeserializerWithSchema(r Record, v []*Value, schema *schemapb.Collecti return valueDeserializer(r, v, allFields, shouldCopy) } -func valueDeserializer(r Record, v []*Value, fieldSchema []*schemapb.FieldSchema, shouldCopy bool) error { +func valueDeserializer(r Record, v []*Value, fields []*schemapb.FieldSchema, shouldCopy bool) error { pkField := func() *schemapb.FieldSchema { - for _, field := range fieldSchema { + for _, field := range fields { if field.GetIsPrimaryKey() { return field } @@ -299,12 +299,12 @@ func valueDeserializer(r Record, v []*Value, fieldSchema []*schemapb.FieldSchema value := v[i] if value == nil { value = &Value{} - value.Value = make(map[FieldID]interface{}, len(fieldSchema)) 
+ value.Value = make(map[FieldID]interface{}, len(fields)) v[i] = value } m := value.Value.(map[FieldID]interface{}) - for _, f := range fieldSchema { + for _, f := range fields { j := f.FieldID dt := f.DataType if r.Column(j).IsNull(i) { diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 7f03c8e001..0cbe762651 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -1532,7 +1532,9 @@ func GetDefaultValue(fieldSchema *schemapb.FieldSchema) interface{} { func fillMissingFields(schema *schemapb.CollectionSchema, insertData *InsertData) error { batchRows := int64(insertData.GetRowNum()) - for _, field := range schema.Fields { + allFields := typeutil.GetAllFieldSchemas(schema) + + for _, field := range allFields { // Skip function output fields and system fields if field.GetIsFunctionOutput() || field.GetFieldID() < 100 { continue diff --git a/internal/util/indexparamcheck/auto_index_checker.go b/internal/util/indexparamcheck/auto_index_checker.go index f56a2887b1..ee7156c102 100644 --- a/internal/util/indexparamcheck/auto_index_checker.go +++ b/internal/util/indexparamcheck/auto_index_checker.go @@ -9,7 +9,7 @@ type AUTOINDEXChecker struct { baseChecker } -func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { +func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { return nil } diff --git a/internal/util/indexparamcheck/base_checker.go b/internal/util/indexparamcheck/base_checker.go index 5041b5f232..2b423a1f73 100644 --- a/internal/util/indexparamcheck/base_checker.go +++ b/internal/util/indexparamcheck/base_checker.go @@ -24,7 +24,7 @@ import ( type baseChecker struct{} -func (c baseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { +func (c baseChecker) CheckTrain(_ schemapb.DataType, _ schemapb.DataType, _ map[string]string) error { return nil } @@ -35,7 +35,7 @@ func (c 
baseChecker) CheckValidDataType(indexType IndexType, field *schemapb.Fie func (c baseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, m map[string]string) {} -func (c baseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { +func (c baseChecker) StaticCheck(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { return errors.New("unsupported index type") } diff --git a/internal/util/indexparamcheck/base_checker_test.go b/internal/util/indexparamcheck/base_checker_test.go index 9cdc206436..132c6e83ea 100644 --- a/internal/util/indexparamcheck/base_checker_test.go +++ b/internal/util/indexparamcheck/base_checker_test.go @@ -47,9 +47,9 @@ func Test_baseChecker_CheckTrain(t *testing.T) { test.params[common.IndexTypeKey] = "HNSW" var err error if test.params[common.IsSparseKey] == "True" { - err = c.CheckTrain(schemapb.DataType_SparseFloatVector, test.params) + err = c.CheckTrain(schemapb.DataType_SparseFloatVector, schemapb.DataType_None, test.params) } else { - err = c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err = c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) } if test.errIsNil { assert.NoError(t, err) @@ -132,5 +132,5 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) { func Test_baseChecker_StaticCheck(t *testing.T) { // TODO - assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, nil)) + assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, schemapb.DataType_None, nil)) } diff --git a/internal/util/indexparamcheck/bin_flat_checker_test.go b/internal/util/indexparamcheck/bin_flat_checker_test.go index 93896e285f..d84f76ffaa 100644 --- a/internal/util/indexparamcheck/bin_flat_checker_test.go +++ b/internal/util/indexparamcheck/bin_flat_checker_test.go @@ -68,7 +68,7 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT") for 
_, test := range cases { test.params[common.IndexTypeKey] = "BINFLAT" - err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params) + err := c.CheckTrain(schemapb.DataType_BinaryVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go b/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go index d3cff273e6..114704b6dd 100644 --- a/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go +++ b/internal/util/indexparamcheck/bin_ivf_flat_checker_test.go @@ -119,7 +119,7 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT") for _, test := range cases { test.params[common.IndexTypeKey] = "BIN_IVF_FLAT" - err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params) + err := c.CheckTrain(schemapb.DataType_BinaryVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/bitmap_index_checker.go b/internal/util/indexparamcheck/bitmap_index_checker.go index 6996b70052..60bd08c07b 100644 --- a/internal/util/indexparamcheck/bitmap_index_checker.go +++ b/internal/util/indexparamcheck/bitmap_index_checker.go @@ -11,8 +11,8 @@ type BITMAPChecker struct { scalarIndexChecker } -func (c *BITMAPChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(dataType, params) +func (c *BITMAPChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(dataType, elementType, params) } func (c *BITMAPChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { diff --git a/internal/util/indexparamcheck/cagra_checker_test.go b/internal/util/indexparamcheck/cagra_checker_test.go index a258cf1935..353222a6b8 100644 --- 
a/internal/util/indexparamcheck/cagra_checker_test.go +++ b/internal/util/indexparamcheck/cagra_checker_test.go @@ -109,7 +109,7 @@ func Test_cagraChecker_CheckTrain(t *testing.T) { return } for _, test := range cases { - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/constraints.go b/internal/util/indexparamcheck/constraints.go index c6b2a883d1..d9696cb2fc 100644 --- a/internal/util/indexparamcheck/constraints.go +++ b/internal/util/indexparamcheck/constraints.go @@ -56,6 +56,7 @@ var ( SparseFloatVectorMetrics = []string{metric.IP, metric.BM25} // const BinaryVectorMetrics = []string{metric.HAMMING, metric.JACCARD, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE, metric.MHJACCARD} // const IntVectorMetrics = []string{metric.L2, metric.IP, metric.COSINE} // const + EmbListMetrics = []string{metric.MaxSim} // const ) // BinIDMapMetrics is a set of all metric types supported for binary vector. 
diff --git a/internal/util/indexparamcheck/diskann_checker_test.go b/internal/util/indexparamcheck/diskann_checker_test.go index 752e82580c..0937aa29b8 100644 --- a/internal/util/indexparamcheck/diskann_checker_test.go +++ b/internal/util/indexparamcheck/diskann_checker_test.go @@ -76,7 +76,7 @@ func Test_diskannChecker_CheckTrain(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("DISKANN") for _, test := range cases { test.params[common.IndexTypeKey] = "DISKANN" - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/flat_checker_test.go b/internal/util/indexparamcheck/flat_checker_test.go index aee817dd43..355c037053 100644 --- a/internal/util/indexparamcheck/flat_checker_test.go +++ b/internal/util/indexparamcheck/flat_checker_test.go @@ -57,7 +57,7 @@ func Test_flatChecker_CheckTrain(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT") for _, test := range cases { test.params[common.IndexTypeKey] = "FLAT" - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { @@ -95,7 +95,7 @@ func Test_flatChecker_StaticCheck(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("FLAT") for _, test := range cases { test.params[common.IndexTypeKey] = "FLAT" - err := c.StaticCheck(schemapb.DataType_FloatVector, test.params) + err := c.StaticCheck(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/hnsw_checker_test.go b/internal/util/indexparamcheck/hnsw_checker_test.go index 8dfad02025..4ee8d1112a 100644 --- a/internal/util/indexparamcheck/hnsw_checker_test.go +++ 
b/internal/util/indexparamcheck/hnsw_checker_test.go @@ -98,9 +98,9 @@ func Test_hnswChecker_CheckTrain(t *testing.T) { test.params[common.IndexTypeKey] = "HNSW" var err error if CheckStrByValues(test.params, common.MetricTypeKey, BinaryVectorMetrics) { - err = c.CheckTrain(schemapb.DataType_BinaryVector, test.params) + err = c.CheckTrain(schemapb.DataType_BinaryVector, schemapb.DataType_None, test.params) } else { - err = c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err = c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) } if test.errIsNil { assert.NoError(t, err) diff --git a/internal/util/indexparamcheck/hybrid_checker_test.go b/internal/util/indexparamcheck/hybrid_checker_test.go index 4d09f7aca1..5f00c2285d 100644 --- a/internal/util/indexparamcheck/hybrid_checker_test.go +++ b/internal/util/indexparamcheck/hybrid_checker_test.go @@ -11,7 +11,7 @@ import ( func Test_HybridIndexChecker(t *testing.T) { c := newHYBRIDChecker() - assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{"bitmap_cardinality_limit": "100"})) + assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, schemapb.DataType_None, map[string]string{"bitmap_cardinality_limit": "100"})) assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Bool})) assert.NoError(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Int8})) @@ -31,7 +31,7 @@ func Test_HybridIndexChecker(t *testing.T) { assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Double})) assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Float})) assert.Error(t, c.CheckValidDataType(IndexHybrid, &schemapb.FieldSchema{DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Double})) - assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, 
map[string]string{})) - assert.Error(t, c.CheckTrain(schemapb.DataType_Float, map[string]string{"bitmap_cardinality_limit": "0"})) - assert.Error(t, c.CheckTrain(schemapb.DataType_Double, map[string]string{"bitmap_cardinality_limit": "2000"})) + assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, schemapb.DataType_None, map[string]string{})) + assert.Error(t, c.CheckTrain(schemapb.DataType_Float, schemapb.DataType_None, map[string]string{"bitmap_cardinality_limit": "0"})) + assert.Error(t, c.CheckTrain(schemapb.DataType_Double, schemapb.DataType_None, map[string]string{"bitmap_cardinality_limit": "2000"})) } diff --git a/internal/util/indexparamcheck/hybrid_index_checker.go b/internal/util/indexparamcheck/hybrid_index_checker.go index 7e7258ac72..6e7e3e8801 100644 --- a/internal/util/indexparamcheck/hybrid_index_checker.go +++ b/internal/util/indexparamcheck/hybrid_index_checker.go @@ -14,12 +14,12 @@ type HYBRIDChecker struct { scalarIndexChecker } -func (c *HYBRIDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { +func (c *HYBRIDChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { if !CheckIntByRange(params, common.BitmapCardinalityLimitKey, 1, MaxBitmapCardinalityLimit) { return fmt.Errorf("failed to check bitmap cardinality limit, should be larger than 0 and smaller than %d", MaxBitmapCardinalityLimit) } - return c.scalarIndexChecker.CheckTrain(dataType, params) + return c.scalarIndexChecker.CheckTrain(dataType, elementType, params) } func (c *HYBRIDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { diff --git a/internal/util/indexparamcheck/index_checker.go b/internal/util/indexparamcheck/index_checker.go index 610ddffc2c..ee6bafc398 100644 --- a/internal/util/indexparamcheck/index_checker.go +++ b/internal/util/indexparamcheck/index_checker.go @@ -21,8 +21,8 @@ import ( ) type IndexChecker interface { - CheckTrain(schemapb.DataType, 
map[string]string) error + CheckTrain(schemapb.DataType, schemapb.DataType, map[string]string) error CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error SetDefaultMetricTypeIfNotExist(schemapb.DataType, map[string]string) - StaticCheck(schemapb.DataType, map[string]string) error + StaticCheck(schemapb.DataType, schemapb.DataType, map[string]string) error } diff --git a/internal/util/indexparamcheck/inverted_checker.go b/internal/util/indexparamcheck/inverted_checker.go index 5dade936b1..4d414ec54d 100644 --- a/internal/util/indexparamcheck/inverted_checker.go +++ b/internal/util/indexparamcheck/inverted_checker.go @@ -20,7 +20,7 @@ var validJSONCastTypes = []string{"BOOL", "DOUBLE", "VARCHAR", "ARRAY_BOOL", "AR var validJSONCastFunctions = []string{"STRING_TO_DOUBLE"} -func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { +func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { // check json index params isJSONIndex := typeutil.IsJSONType(dataType) if isJSONIndex { @@ -44,7 +44,7 @@ func (c *INVERTEDChecker) CheckTrain(dataType schemapb.DataType, params map[stri } } } - return c.scalarIndexChecker.CheckTrain(dataType, params) + return c.scalarIndexChecker.CheckTrain(dataType, elementType, params) } func (c *INVERTEDChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { diff --git a/internal/util/indexparamcheck/inverted_checker_test.go b/internal/util/indexparamcheck/inverted_checker_test.go index a5726a3a1d..9513c780a4 100644 --- a/internal/util/indexparamcheck/inverted_checker_test.go +++ b/internal/util/indexparamcheck/inverted_checker_test.go @@ -11,7 +11,7 @@ import ( func Test_INVERTEDIndexChecker(t *testing.T) { c := newINVERTEDChecker() - assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, map[string]string{})) + assert.NoError(t, c.CheckTrain(schemapb.DataType_Bool, 
schemapb.DataType_None, map[string]string{})) assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) assert.NoError(t, c.CheckValidDataType(IndexINVERTED, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) @@ -26,7 +26,7 @@ func Test_INVERTEDIndexChecker(t *testing.T) { func Test_CheckTrain(t *testing.T) { c := newINVERTEDChecker() - assert.NoError(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{"json_cast_type": "BOOL", "json_path": "json['a']"})) - assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{"json_cast_type": "array", "json_path": "json['a']"})) - assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, map[string]string{"json_cast_type": "abc", "json_path": "json['a']"})) + assert.NoError(t, c.CheckTrain(schemapb.DataType_JSON, schemapb.DataType_None, map[string]string{"json_cast_type": "BOOL", "json_path": "json['a']"})) + assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, schemapb.DataType_None, map[string]string{"json_cast_type": "array", "json_path": "json['a']"})) + assert.Error(t, c.CheckTrain(schemapb.DataType_JSON, schemapb.DataType_None, map[string]string{"json_cast_type": "abc", "json_path": "json['a']"})) } diff --git a/internal/util/indexparamcheck/ivf_base_checker_test.go b/internal/util/indexparamcheck/ivf_base_checker_test.go index fa76f1b6bf..ce27c38b98 100644 --- a/internal/util/indexparamcheck/ivf_base_checker_test.go +++ b/internal/util/indexparamcheck/ivf_base_checker_test.go @@ -74,7 +74,7 @@ func Test_ivfBaseChecker_CheckTrain(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_FLAT") for _, test := range cases { test.params[common.IndexTypeKey] = "IVF_FLAT" - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git 
a/internal/util/indexparamcheck/ivf_pq_checker_test.go b/internal/util/indexparamcheck/ivf_pq_checker_test.go index ccf88bec82..a27042e6a7 100644 --- a/internal/util/indexparamcheck/ivf_pq_checker_test.go +++ b/internal/util/indexparamcheck/ivf_pq_checker_test.go @@ -145,7 +145,7 @@ func Test_ivfPQChecker_CheckTrain(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_PQ") for _, test := range cases { test.params[common.IndexTypeKey] = "IVF_PQ" - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/ivf_sq_checker_test.go b/internal/util/indexparamcheck/ivf_sq_checker_test.go index b57781e3d3..475b29540c 100644 --- a/internal/util/indexparamcheck/ivf_sq_checker_test.go +++ b/internal/util/indexparamcheck/ivf_sq_checker_test.go @@ -93,7 +93,7 @@ func Test_ivfSQChecker_CheckTrain(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("IVF_SQ") for _, test := range cases { test.params[common.IndexTypeKey] = "IVF_SQ" - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/ngram_index_checker.go b/internal/util/indexparamcheck/ngram_index_checker.go index 416183dd81..65eaad885c 100644 --- a/internal/util/indexparamcheck/ngram_index_checker.go +++ b/internal/util/indexparamcheck/ngram_index_checker.go @@ -22,7 +22,7 @@ func newNgramIndexChecker() *NgramIndexChecker { return &NgramIndexChecker{} } -func (c *NgramIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { +func (c *NgramIndexChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { if dataType != 
schemapb.DataType_VarChar && dataType != schemapb.DataType_JSON { return merr.WrapErrParameterInvalidMsg("Ngram index can only be created on VARCHAR or JSON field") } @@ -47,7 +47,7 @@ func (c *NgramIndexChecker) CheckTrain(dataType schemapb.DataType, params map[st return merr.WrapErrParameterInvalidMsg("invalid min_gram or max_gram value for Ngram index, min_gram: %d, max_gram: %d", minGram, maxGram) } - return c.scalarIndexChecker.CheckTrain(dataType, params) + return c.scalarIndexChecker.CheckTrain(dataType, elementType, params) } func (c *NgramIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { diff --git a/internal/util/indexparamcheck/raft_brute_force_checker_test.go b/internal/util/indexparamcheck/raft_brute_force_checker_test.go index 1a0b7fb4eb..056aa238d7 100644 --- a/internal/util/indexparamcheck/raft_brute_force_checker_test.go +++ b/internal/util/indexparamcheck/raft_brute_force_checker_test.go @@ -62,7 +62,7 @@ func Test_raftbfChecker_CheckTrain(t *testing.T) { } for _, test := range cases { test.params[common.IndexTypeKey] = "GPU_BRUTE_FORCE" - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go b/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go index 863d31cc3b..fdfaa6d9d1 100644 --- a/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go +++ b/internal/util/indexparamcheck/raft_ivf_flat_checker_test.go @@ -93,7 +93,7 @@ func Test_raftIvfFlatChecker_CheckTrain(t *testing.T) { } for _, test := range cases { test.params[common.IndexTypeKey] = "GPU_IVF_FLAT" - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git 
a/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go b/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go index 04d87dd2ad..e30b381e26 100644 --- a/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go +++ b/internal/util/indexparamcheck/raft_ivf_pq_checker_test.go @@ -153,7 +153,7 @@ func Test_raftIVFPQChecker_CheckTrain(t *testing.T) { } for _, test := range cases { test.params[common.IndexTypeKey] = "GPU_IVF_PQ" - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/scalar_index_checker.go b/internal/util/indexparamcheck/scalar_index_checker.go index a1272ae388..5019d7fe61 100644 --- a/internal/util/indexparamcheck/scalar_index_checker.go +++ b/internal/util/indexparamcheck/scalar_index_checker.go @@ -6,6 +6,6 @@ type scalarIndexChecker struct { baseChecker } -func (c scalarIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { +func (c scalarIndexChecker) CheckTrain(dataType schemapb.DataType, _ schemapb.DataType, params map[string]string) error { return nil } diff --git a/internal/util/indexparamcheck/scalar_index_checker_test.go b/internal/util/indexparamcheck/scalar_index_checker_test.go index faf8ea2419..4bd913cf3f 100644 --- a/internal/util/indexparamcheck/scalar_index_checker_test.go +++ b/internal/util/indexparamcheck/scalar_index_checker_test.go @@ -10,5 +10,5 @@ import ( func TestCheckIndexValid(t *testing.T) { scalarIndexChecker := &scalarIndexChecker{} - assert.NoError(t, scalarIndexChecker.CheckTrain(schemapb.DataType_Bool, map[string]string{})) + assert.NoError(t, scalarIndexChecker.CheckTrain(schemapb.DataType_Bool, schemapb.DataType_None, map[string]string{})) } diff --git a/internal/util/indexparamcheck/scann_checker_test.go b/internal/util/indexparamcheck/scann_checker_test.go index 
00f4d0a26f..ed7b6a40ee 100644 --- a/internal/util/indexparamcheck/scann_checker_test.go +++ b/internal/util/indexparamcheck/scann_checker_test.go @@ -91,7 +91,7 @@ func Test_scaNNChecker_CheckTrain(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("SCANN") for _, test := range cases { test.params[common.IndexTypeKey] = "SCANN" - err := c.CheckTrain(schemapb.DataType_FloatVector, test.params) + err := c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params) if test.errIsNil { assert.NoError(t, err) } else { diff --git a/internal/util/indexparamcheck/sparse_float_vector_base_checker_test.go b/internal/util/indexparamcheck/sparse_float_vector_base_checker_test.go index 079a0b5d55..90c7ad7671 100644 --- a/internal/util/indexparamcheck/sparse_float_vector_base_checker_test.go +++ b/internal/util/indexparamcheck/sparse_float_vector_base_checker_test.go @@ -23,12 +23,12 @@ func Test_sparseFloatVectorBaseChecker_StaticCheck(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX") t.Run("valid metric", func(t *testing.T) { - err := c.StaticCheck(schemapb.DataType_SparseFloatVector, validParams) + err := c.StaticCheck(schemapb.DataType_SparseFloatVector, schemapb.DataType_None, validParams) assert.NoError(t, err) }) t.Run("invalid metric", func(t *testing.T) { - err := c.StaticCheck(schemapb.DataType_SparseFloatVector, invalidParams) + err := c.StaticCheck(schemapb.DataType_SparseFloatVector, schemapb.DataType_None, invalidParams) assert.Error(t, err) }) } @@ -63,22 +63,22 @@ func Test_sparseFloatVectorBaseChecker_CheckTrain(t *testing.T) { c, _ := GetIndexCheckerMgrInstance().GetChecker("SPARSE_INVERTED_INDEX") t.Run("valid params", func(t *testing.T) { - err := c.CheckTrain(schemapb.DataType_SparseFloatVector, validParams) + err := c.CheckTrain(schemapb.DataType_SparseFloatVector, schemapb.DataType_None, validParams) assert.NoError(t, err) }) t.Run("invalid drop ratio", func(t *testing.T) { - err := 
c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidDropRatio) + err := c.CheckTrain(schemapb.DataType_SparseFloatVector, schemapb.DataType_None, invalidDropRatio) assert.Error(t, err) }) t.Run("invalid BM25K1", func(t *testing.T) { - err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25K1) + err := c.CheckTrain(schemapb.DataType_SparseFloatVector, schemapb.DataType_None, invalidBM25K1) assert.Error(t, err) }) t.Run("invalid BM25B", func(t *testing.T) { - err := c.CheckTrain(schemapb.DataType_SparseFloatVector, invalidBM25B) + err := c.CheckTrain(schemapb.DataType_SparseFloatVector, schemapb.DataType_None, invalidBM25B) assert.Error(t, err) }) } diff --git a/internal/util/indexparamcheck/stl_sort_checker.go b/internal/util/indexparamcheck/stl_sort_checker.go index 58abf8f8ba..224617f0f5 100644 --- a/internal/util/indexparamcheck/stl_sort_checker.go +++ b/internal/util/indexparamcheck/stl_sort_checker.go @@ -12,8 +12,8 @@ type STLSORTChecker struct { scalarIndexChecker } -func (c *STLSORTChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(dataType, params) +func (c *STLSORTChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(dataType, elementType, params) } func (c *STLSORTChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { diff --git a/internal/util/indexparamcheck/stl_sort_checker_test.go b/internal/util/indexparamcheck/stl_sort_checker_test.go index 7bb21f95d7..1fc527bf6c 100644 --- a/internal/util/indexparamcheck/stl_sort_checker_test.go +++ b/internal/util/indexparamcheck/stl_sort_checker_test.go @@ -11,7 +11,7 @@ import ( func Test_STLSORTIndexChecker(t *testing.T) { c := newSTLSORTChecker() - assert.NoError(t, c.CheckTrain(schemapb.DataType_Int64, map[string]string{})) + assert.NoError(t, c.CheckTrain(schemapb.DataType_Int64, 
schemapb.DataType_None, map[string]string{})) assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Int64})) assert.NoError(t, c.CheckValidDataType(IndexSTLSORT, &schemapb.FieldSchema{DataType: schemapb.DataType_Float})) diff --git a/internal/util/indexparamcheck/trie_checker.go b/internal/util/indexparamcheck/trie_checker.go index ad9745071d..83fbe30903 100644 --- a/internal/util/indexparamcheck/trie_checker.go +++ b/internal/util/indexparamcheck/trie_checker.go @@ -12,8 +12,8 @@ type TRIEChecker struct { scalarIndexChecker } -func (c *TRIEChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { - return c.scalarIndexChecker.CheckTrain(dataType, params) +func (c *TRIEChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { + return c.scalarIndexChecker.CheckTrain(dataType, elementType, params) } func (c *TRIEChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { diff --git a/internal/util/indexparamcheck/trie_checker_test.go b/internal/util/indexparamcheck/trie_checker_test.go index fb81c90b2c..4680ca84d3 100644 --- a/internal/util/indexparamcheck/trie_checker_test.go +++ b/internal/util/indexparamcheck/trie_checker_test.go @@ -11,7 +11,7 @@ import ( func Test_TrieIndexChecker(t *testing.T) { c := newTRIEChecker() - assert.NoError(t, c.CheckTrain(schemapb.DataType_VarChar, map[string]string{})) + assert.NoError(t, c.CheckTrain(schemapb.DataType_VarChar, schemapb.DataType_None, map[string]string{})) assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar})) assert.NoError(t, c.CheckValidDataType(IndexTRIE, &schemapb.FieldSchema{DataType: schemapb.DataType_String})) diff --git a/internal/util/indexparamcheck/utils.go b/internal/util/indexparamcheck/utils.go index 2717436e51..db9735b16d 100644 --- a/internal/util/indexparamcheck/utils.go +++ 
b/internal/util/indexparamcheck/utils.go @@ -84,7 +84,7 @@ func CheckAutoIndexHelper(key string, m map[string]string, dtype schemapb.DataTy panic(fmt.Sprintf("%s invalid, unsupported index type: %s", key, indexType)) } - if err := checker.StaticCheck(dtype, m); err != nil { + if err := checker.StaticCheck(dtype, schemapb.DataType_None, m); err != nil { panic(fmt.Sprintf("%s invalid, parameters invalid, error: %s", key, err.Error())) } } diff --git a/internal/util/indexparamcheck/vector_index_checker.go b/internal/util/indexparamcheck/vector_index_checker.go index 16dc33c9b0..1568c9cb5c 100644 --- a/internal/util/indexparamcheck/vector_index_checker.go +++ b/internal/util/indexparamcheck/vector_index_checker.go @@ -40,7 +40,7 @@ func HandleCStatus(status *C.CStatus) error { return fmt.Errorf("%s", errorMsg) } -func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error { +func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { if typeutil.IsDenseFloatVectorType(dataType) { if !CheckStrByValues(params, Metric, FloatVectorMetrics) { return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], FloatVectorMetrics) @@ -57,6 +57,10 @@ func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, params map[stri if !CheckStrByValues(params, Metric, IntVectorMetrics) { return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], IntVectorMetrics) } + } else if typeutil.IsArrayOfVectorType(dataType) { + if !CheckStrByValues(params, Metric, EmbListMetrics) { + return fmt.Errorf("metric type %s not found or not supported, supported: %v", params[Metric], EmbListMetrics) + } } indexType, exist := params[common.IndexTypeKey] @@ -86,31 +90,32 @@ func (c vecIndexChecker) StaticCheck(dataType schemapb.DataType, params map[stri cIndexType := C.CString(indexType) cDataType := uint32(dataType) - status = 
C.ValidateIndexParams(cIndexType, cDataType, (*C.uint8_t)(unsafe.Pointer(&indexParamsBlob[0])), (C.uint64_t)(len(indexParamsBlob))) + cElementType := uint32(elementType) + status = C.ValidateIndexParams(cIndexType, cDataType, cElementType, (*C.uint8_t)(unsafe.Pointer(&indexParamsBlob[0])), (C.uint64_t)(len(indexParamsBlob))) C.free(unsafe.Pointer(cIndexType)) return HandleCStatus(&status) } -func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error { - if err := c.StaticCheck(dataType, params); err != nil { +func (c vecIndexChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error { + if err := c.StaticCheck(dataType, elementType, params); err != nil { return err } - if typeutil.IsFixDimVectorType(dataType) { + if typeutil.IsFixDimVectorType(dataType) || (typeutil.IsArrayOfVectorType(dataType) && typeutil.IsFixDimVectorType(elementType)) { if !CheckIntByRange(params, DIM, 1, math.MaxInt) { return errors.New("failed to check vector dimension, should be larger than 0 and smaller than math.MaxInt") } } - return c.baseChecker.CheckTrain(dataType, params) + return c.baseChecker.CheckTrain(dataType, elementType, params) } func (c vecIndexChecker) CheckValidDataType(indexType IndexType, field *schemapb.FieldSchema) error { if !typeutil.IsVectorType(field.GetDataType()) { return fmt.Errorf("index %s only supports vector data type", indexType) } - if !vecindexmgr.GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType()) { + if !vecindexmgr.GetVecIndexMgrInstance().IsDataTypeSupport(indexType, field.GetDataType(), field.GetElementType()) { return fmt.Errorf("index %s do not support data type: %s", indexType, schemapb.DataType_name[int32(field.GetDataType())]) } return nil diff --git a/internal/util/indexparamcheck/vector_index_checker_test.go b/internal/util/indexparamcheck/vector_index_checker_test.go index f562730ee0..d2e45944d5 100644 --- 
a/internal/util/indexparamcheck/vector_index_checker_test.go +++ b/internal/util/indexparamcheck/vector_index_checker_test.go @@ -45,7 +45,7 @@ func TestVecIndexChecker_StaticCheck(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := checker.StaticCheck(tt.dataType, tt.params) + err := checker.StaticCheck(tt.dataType, schemapb.DataType_None, tt.params) if tt.wantErr { assert.Error(t, err) } else { diff --git a/internal/util/vecindexmgr/vector_index_mgr.go b/internal/util/vecindexmgr/vector_index_mgr.go index 504b5c441b..e46324f7b9 100644 --- a/internal/util/vecindexmgr/vector_index_mgr.go +++ b/internal/util/vecindexmgr/vector_index_mgr.go @@ -42,6 +42,8 @@ const ( SparseFloat32Flag uint64 = 1 << 4 Int8Flag uint64 = 1 << 5 + EmbeddingListFlag uint64 = 1 << 15 + // NOTrainFlag This flag indicates that there is no need to create any index structure NOTrainFlag uint64 = 1 << 16 // KNNFlag This flag indicates that the index defaults to KNN search, meaning the recall rate is 100% @@ -63,13 +65,13 @@ type VecIndexMgr interface { GetFeature(indexType IndexType) (uint64, bool) - IsBinaryVectorSupport(indexType IndexType) bool - IsFloat32VectorSupport(indexType IndexType) bool - IsFloat16VectorSupport(indexType IndexType) bool - IsBFloat16VectorSupport(indexType IndexType) bool - IsSparseFloat32VectorSupport(indexType IndexType) bool - IsInt8VectorSupport(indexType IndexType) bool - IsDataTypeSupport(indexType IndexType, dataType schemapb.DataType) bool + IsBinaryVectorSupport(indexType IndexType, isEmbeddingList bool) bool + IsFloat32VectorSupport(indexType IndexType, isEmbeddingList bool) bool + IsFloat16VectorSupport(indexType IndexType, isEmbeddingList bool) bool + IsBFloat16VectorSupport(indexType IndexType, isEmbeddingList bool) bool + IsSparseFloat32VectorSupport(indexType IndexType, isEmbeddingList bool) bool + IsInt8VectorSupport(indexType IndexType, isEmbeddingList bool) bool + IsDataTypeSupport(indexType IndexType, dataType 
schemapb.DataType, elementType schemapb.DataType) bool IsFlatVecIndex(indexType IndexType) bool IsNoTrainIndex(indexType IndexType) bool @@ -126,67 +128,67 @@ func (mgr *vecIndexMgrImpl) init() { log.Info("init vector indexes with features : " + featureLog.String()) } -func (mgr *vecIndexMgrImpl) IsBinaryVectorSupport(indexType IndexType) bool { +func (mgr *vecIndexMgrImpl) isVectorTypeSupported(indexType IndexType, vectorFlag uint64, isEmbeddingList bool) bool { feature, ok := mgr.GetFeature(indexType) if !ok { return false } - return (feature & BinaryFlag) == BinaryFlag -} -func (mgr *vecIndexMgrImpl) IsFloat32VectorSupport(indexType IndexType) bool { - feature, ok := mgr.GetFeature(indexType) - if !ok { + // check if the vector type is supported + if (feature & vectorFlag) != vectorFlag { return false } - return (feature & Float32Flag) == Float32Flag -} -func (mgr *vecIndexMgrImpl) IsFloat16VectorSupport(indexType IndexType) bool { - feature, ok := mgr.GetFeature(indexType) - if !ok { + // if it is embedding list, also check EmbeddingListFlag + if isEmbeddingList && (feature&EmbeddingListFlag) != EmbeddingListFlag { return false } - return (feature & Float16Flag) == Float16Flag + + return true } -func (mgr *vecIndexMgrImpl) IsBFloat16VectorSupport(indexType IndexType) bool { - feature, ok := mgr.GetFeature(indexType) - if !ok { - return false +func (mgr *vecIndexMgrImpl) IsBinaryVectorSupport(indexType IndexType, isEmbeddingList bool) bool { + return mgr.isVectorTypeSupported(indexType, BinaryFlag, isEmbeddingList) +} + +func (mgr *vecIndexMgrImpl) IsFloat32VectorSupport(indexType IndexType, isEmbeddingList bool) bool { + return mgr.isVectorTypeSupported(indexType, Float32Flag, isEmbeddingList) +} + +func (mgr *vecIndexMgrImpl) IsFloat16VectorSupport(indexType IndexType, isEmbeddingList bool) bool { + return mgr.isVectorTypeSupported(indexType, Float16Flag, isEmbeddingList) +} + +func (mgr *vecIndexMgrImpl) IsBFloat16VectorSupport(indexType IndexType, 
isEmbeddingList bool) bool { + return mgr.isVectorTypeSupported(indexType, BFloat16Flag, isEmbeddingList) +} + +func (mgr *vecIndexMgrImpl) IsSparseFloat32VectorSupport(indexType IndexType, isEmbeddingList bool) bool { + return mgr.isVectorTypeSupported(indexType, SparseFloat32Flag, isEmbeddingList) +} + +func (mgr *vecIndexMgrImpl) IsInt8VectorSupport(indexType IndexType, isEmbeddingList bool) bool { + return mgr.isVectorTypeSupported(indexType, Int8Flag, isEmbeddingList) +} + +func (mgr *vecIndexMgrImpl) IsDataTypeSupport(indexType IndexType, dataType schemapb.DataType, elementType schemapb.DataType) bool { + isEmbeddingList := dataType == schemapb.DataType_ArrayOfVector + if isEmbeddingList { + dataType = elementType } - return (feature & BFloat16Flag) == BFloat16Flag -} -func (mgr *vecIndexMgrImpl) IsSparseFloat32VectorSupport(indexType IndexType) bool { - feature, ok := mgr.GetFeature(indexType) - if !ok { - return false - } - return (feature & SparseFloat32Flag) == SparseFloat32Flag -} - -func (mgr *vecIndexMgrImpl) IsInt8VectorSupport(indexType IndexType) bool { - feature, ok := mgr.GetFeature(indexType) - if !ok { - return false - } - return (feature & Int8Flag) == Int8Flag -} - -func (mgr *vecIndexMgrImpl) IsDataTypeSupport(indexType IndexType, dataType schemapb.DataType) bool { if dataType == schemapb.DataType_BinaryVector { - return mgr.IsBinaryVectorSupport(indexType) + return mgr.IsBinaryVectorSupport(indexType, isEmbeddingList) } else if dataType == schemapb.DataType_FloatVector { - return mgr.IsFloat32VectorSupport(indexType) + return mgr.IsFloat32VectorSupport(indexType, isEmbeddingList) } else if dataType == schemapb.DataType_BFloat16Vector { - return mgr.IsBFloat16VectorSupport(indexType) + return mgr.IsBFloat16VectorSupport(indexType, isEmbeddingList) } else if dataType == schemapb.DataType_Float16Vector { - return mgr.IsFloat16VectorSupport(indexType) + return mgr.IsFloat16VectorSupport(indexType, isEmbeddingList) } else if dataType == 
schemapb.DataType_SparseFloatVector { - return mgr.IsSparseFloat32VectorSupport(indexType) + return mgr.IsSparseFloat32VectorSupport(indexType, isEmbeddingList) } else if dataType == schemapb.DataType_Int8Vector { - return mgr.IsInt8VectorSupport(indexType) + return mgr.IsInt8VectorSupport(indexType, isEmbeddingList) } return false } diff --git a/internal/util/vecindexmgr/vector_index_mgr_test.go b/internal/util/vecindexmgr/vector_index_mgr_test.go index eca7f52dbe..6abcefd6a2 100644 --- a/internal/util/vecindexmgr/vector_index_mgr_test.go +++ b/internal/util/vecindexmgr/vector_index_mgr_test.go @@ -98,7 +98,7 @@ func Test_VecIndex_DataType_Support(t *testing.T) { for _, tt := range tests { t.Run(string(tt.indexType), func(t *testing.T) { for i, dataType := range tt.dataTypes { - got := mgr.IsDataTypeSupport(tt.indexType, dataType) + got := mgr.IsDataTypeSupport(tt.indexType, dataType, schemapb.DataType_None) if got != tt.wants[i] { t.Errorf("IsDataTypeSupport(%v, %v) = %v, want %v", tt.indexType, dataType, got, tt.wants[i]) } diff --git a/pkg/proto/plan.proto b/pkg/proto/plan.proto index 8bbecf1568..74dfd3723e 100644 --- a/pkg/proto/plan.proto +++ b/pkg/proto/plan.proto @@ -40,6 +40,10 @@ enum VectorType { BFloat16Vector = 3; SparseFloatVector = 4; Int8Vector = 5; + EmbListFloatVector = 6; + EmbListFloat16Vector = 7; + EmbListBFloat16Vector = 8; + EmbListInt8Vector = 9; }; message GenericValue { diff --git a/pkg/proto/planpb/plan.pb.go b/pkg/proto/planpb/plan.pb.go index c238917c50..25eb4a3f6b 100644 --- a/pkg/proto/planpb/plan.pb.go +++ b/pkg/proto/planpb/plan.pb.go @@ -173,12 +173,16 @@ func (ArithOpType) EnumDescriptor() ([]byte, []int) { type VectorType int32 const ( - VectorType_BinaryVector VectorType = 0 - VectorType_FloatVector VectorType = 1 - VectorType_Float16Vector VectorType = 2 - VectorType_BFloat16Vector VectorType = 3 - VectorType_SparseFloatVector VectorType = 4 - VectorType_Int8Vector VectorType = 5 + VectorType_BinaryVector VectorType = 0 + 
VectorType_FloatVector VectorType = 1 + VectorType_Float16Vector VectorType = 2 + VectorType_BFloat16Vector VectorType = 3 + VectorType_SparseFloatVector VectorType = 4 + VectorType_Int8Vector VectorType = 5 + VectorType_EmbListFloatVector VectorType = 6 + VectorType_EmbListFloat16Vector VectorType = 7 + VectorType_EmbListBFloat16Vector VectorType = 8 + VectorType_EmbListInt8Vector VectorType = 9 ) // Enum value maps for VectorType. @@ -190,14 +194,22 @@ var ( 3: "BFloat16Vector", 4: "SparseFloatVector", 5: "Int8Vector", + 6: "EmbListFloatVector", + 7: "EmbListFloat16Vector", + 8: "EmbListBFloat16Vector", + 9: "EmbListInt8Vector", } VectorType_value = map[string]int32{ - "BinaryVector": 0, - "FloatVector": 1, - "Float16Vector": 2, - "BFloat16Vector": 3, - "SparseFloatVector": 4, - "Int8Vector": 5, + "BinaryVector": 0, + "FloatVector": 1, + "Float16Vector": 2, + "BFloat16Vector": 3, + "SparseFloatVector": 4, + "Int8Vector": 5, + "EmbListFloatVector": 6, + "EmbListFloat16Vector": 7, + "EmbListBFloat16Vector": 8, + "EmbListInt8Vector": 9, } ) @@ -3058,19 +3070,25 @@ var file_plan_proto_rawDesc = []byte{ 0x64, 0x64, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x53, 0x75, 0x62, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x75, 0x6c, 0x10, 0x03, 0x12, 0x07, 0x0a, 0x03, 0x44, 0x69, 0x76, 0x10, 0x04, 0x12, 0x07, 0x0a, 0x03, 0x4d, 0x6f, 0x64, 0x10, 0x05, 0x12, 0x0f, 0x0a, 0x0b, 0x41, 0x72, 0x72, 0x61, - 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x10, 0x06, 0x2a, 0x7d, 0x0a, 0x0a, 0x56, 0x65, 0x63, - 0x74, 0x6f, 0x72, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, 0x72, - 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, 0x6f, - 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, 0x6c, - 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x02, 0x12, 0x12, 0x0a, - 0x0e, 0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, - 
0x03, 0x12, 0x15, 0x0a, 0x11, 0x53, 0x70, 0x61, 0x72, 0x73, 0x65, 0x46, 0x6c, 0x6f, 0x61, 0x74, - 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x04, 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x74, 0x38, - 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x05, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, - 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2d, 0x69, 0x6f, - 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x76, 0x32, 0x2f, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x79, 0x4c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x10, 0x06, 0x2a, 0xe1, 0x01, 0x0a, 0x0a, 0x56, 0x65, + 0x63, 0x74, 0x6f, 0x72, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10, 0x0a, 0x0c, 0x42, 0x69, 0x6e, 0x61, + 0x72, 0x79, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x46, 0x6c, + 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x01, 0x12, 0x11, 0x0a, 0x0d, 0x46, + 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x02, 0x12, 0x12, + 0x0a, 0x0e, 0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, + 0x10, 0x03, 0x12, 0x15, 0x0a, 0x11, 0x53, 0x70, 0x61, 0x72, 0x73, 0x65, 0x46, 0x6c, 0x6f, 0x61, + 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x04, 0x12, 0x0e, 0x0a, 0x0a, 0x49, 0x6e, 0x74, + 0x38, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x05, 0x12, 0x16, 0x0a, 0x12, 0x45, 0x6d, 0x62, + 0x4c, 0x69, 0x73, 0x74, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, + 0x06, 0x12, 0x18, 0x0a, 0x14, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x46, 0x6c, 0x6f, 0x61, + 0x74, 0x31, 0x36, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x07, 0x12, 0x19, 0x0a, 0x15, 0x45, + 0x6d, 0x62, 0x4c, 0x69, 0x73, 0x74, 0x42, 0x46, 0x6c, 0x6f, 0x61, 0x74, 0x31, 0x36, 0x56, 0x65, + 0x63, 0x74, 0x6f, 0x72, 0x10, 0x08, 0x12, 0x15, 0x0a, 0x11, 0x45, 0x6d, 0x62, 0x4c, 0x69, 0x73, + 0x74, 0x49, 0x6e, 
0x74, 0x38, 0x56, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x10, 0x09, 0x42, 0x31, 0x5a, + 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x6c, 0x76, + 0x75, 0x73, 0x2d, 0x69, 0x6f, 0x2f, 0x6d, 0x69, 0x6c, 0x76, 0x75, 0x73, 0x2f, 0x70, 0x6b, 0x67, + 0x2f, 0x76, 0x32, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x70, 0x6c, 0x61, 0x6e, 0x70, 0x62, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/pkg/util/metric/metric_type.go b/pkg/util/metric/metric_type.go index 3b764d24c2..691feb2e77 100644 --- a/pkg/util/metric/metric_type.go +++ b/pkg/util/metric/metric_type.go @@ -43,4 +43,6 @@ const ( BM25 MetricType = "BM25" EMPTY MetricType = "" + + MaxSim MetricType = "MAX_SIM" ) diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index 49db69b33e..230de52f7d 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -175,7 +175,7 @@ func estimateSizeBy(schema *schemapb.CollectionSchema, policy getVariableFieldLe return res, nil } -func CalcColumnSize(column *schemapb.FieldData) int { +func CalcScalarSize(column *schemapb.FieldData) int { res := 0 switch column.GetType() { case schemapb.DataType_Bool: @@ -198,7 +198,7 @@ func CalcColumnSize(column *schemapb.FieldData) int { } case schemapb.DataType_Array: for _, array := range column.GetScalars().GetArrayData().GetData() { - res += CalcColumnSize(&schemapb.FieldData{ + res += CalcScalarSize(&schemapb.FieldData{ Field: &schemapb.FieldData_Scalars{Scalars: array}, Type: column.GetScalars().GetArrayData().GetElementType(), }) @@ -228,6 +228,8 @@ func calcVectorSize(column *schemapb.VectorField, vectorType schemapb.DataType) panic("unimplemented") case schemapb.DataType_Int8Vector: res += len(column.GetInt8Vector()) + case schemapb.DataType_ArrayOfVector: + panic("unreachable") default: panic("Unknown data type:" + vectorType.String()) } @@ -256,7 +258,7 @@ func EstimateEntitySize(fieldsData []*schemapb.FieldData, rowOffset int) (int, e return 
0, errors.New("offset out range of field datas") } array := fs.GetScalars().GetArrayData().GetData()[rowOffset] - res += CalcColumnSize(&schemapb.FieldData{ + res += CalcScalarSize(&schemapb.FieldData{ Field: &schemapb.FieldData_Scalars{Scalars: array}, Type: fs.GetScalars().GetArrayData().GetElementType(), }) @@ -312,11 +314,7 @@ func CreateSchemaHelper(schema *schemapb.CollectionSchema) (*SchemaHelper, error return nil, errors.New("schema is nil") } - allFields := make([]*schemapb.FieldSchema, 0, len(schema.Fields)+5) - allFields = append(allFields, schema.Fields...) - for _, structField := range schema.GetStructArrayFields() { - allFields = append(allFields, structField.GetFields()...) - } + allFields := GetAllFieldSchemas(schema) schemaHelper := SchemaHelper{ schema: schema, @@ -525,6 +523,10 @@ func IsDenseFloatVectorType(dataType schemapb.DataType) bool { } } +func IsArrayOfVectorType(dataType schemapb.DataType) bool { + return dataType == schemapb.DataType_ArrayOfVector +} + // return VectorTypeSize for each dim (byte) func VectorTypeSize(dataType schemapb.DataType) float64 { switch dataType { diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index c339b1136a..5a77e1826c 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -1677,7 +1677,7 @@ func TestCalcColumnSize(t *testing.T) { for _, field := range schema.GetFields() { values := fieldValues[field.GetFieldID()] fieldData := genFieldData(field.GetName(), field.GetFieldID(), field.GetDataType(), values, 0) - size := CalcColumnSize(fieldData) + size := CalcScalarSize(fieldData) expected := 0 switch field.GetDataType() { case schemapb.DataType_VarChar: diff --git a/tests/integration/datanode/struct_array_test.go b/tests/integration/datanode/struct_array_test.go index 8e139143ae..ea74a37ca8 100644 --- a/tests/integration/datanode/struct_array_test.go +++ b/tests/integration/datanode/struct_array_test.go @@ -304,7 +304,7 @@ func (s 
*ArrayStructDataNodeSuite) query(collectionName string) { params := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.IP) params["radius"] = radius searchReq := integration.ConstructSearchRequest("", collectionName, expr, - integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, s.dim, topk, roundDecimal) + integration.StructSubFloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, s.dim, topk, roundDecimal) searchResult, _ := c.MilvusClient.Search(context.TODO(), searchReq) diff --git a/tests/integration/util_query.go b/tests/integration/util_query.go index 7d6d1bcf47..af8fe44926 100644 --- a/tests/integration/util_query.go +++ b/tests/integration/util_query.go @@ -173,12 +173,39 @@ func ConstructSearchRequest( metricType string, params map[string]any, nq, dim int, topk, roundDecimal int, +) *milvuspb.SearchRequest { + return constructSearchRequest(dbName, collectionName, expr, vecField, false, vectorType, outputFields, metricType, params, nq, dim, topk, roundDecimal) +} + +func ConstructEmbeddingListSearchRequest( + dbName, collectionName string, + expr string, + vecField string, + vectorType schemapb.DataType, + outputFields []string, + metricType string, + params map[string]any, + nq, dim int, topk, roundDecimal int, +) *milvuspb.SearchRequest { + return constructSearchRequest(dbName, collectionName, expr, vecField, true, vectorType, outputFields, metricType, params, nq, dim, topk, roundDecimal) +} + +func constructSearchRequest( + dbName, collectionName string, + expr string, + vecField string, + isEmbeddingList bool, + vectorType schemapb.DataType, + outputFields []string, + metricType string, + params map[string]any, + nq, dim int, topk, roundDecimal int, ) *milvuspb.SearchRequest { b, err := json.Marshal(params) if err != nil { panic(err) } - plg := constructPlaceholderGroup(nq, dim, vectorType) + plg := constructPlaceholderGroup(nq, dim, vectorType, isEmbeddingList) plgBs, err := 
proto.Marshal(plg) if err != nil { panic(err) @@ -237,7 +264,7 @@ func ConstructSearchRequestWithConsistencyLevel( if err != nil { panic(err) } - plg := constructPlaceholderGroup(nq, dim, vectorType) + plg := constructPlaceholderGroup(nq, dim, vectorType, false) plgBs, err := proto.Marshal(plg) if err != nil { panic(err) @@ -281,12 +308,16 @@ func ConstructSearchRequestWithConsistencyLevel( } } -func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType) *commonpb.PlaceholderGroup { +func constructPlaceholderGroup(nq, dim int, vectorType schemapb.DataType, isEmbeddingList bool) *commonpb.PlaceholderGroup { values := make([][]byte, 0, nq) var placeholderType commonpb.PlaceholderType switch vectorType { case schemapb.DataType_FloatVector: - placeholderType = commonpb.PlaceholderType_FloatVector + if !isEmbeddingList { + placeholderType = commonpb.PlaceholderType_FloatVector + } else { + placeholderType = commonpb.PlaceholderType_EmbListFloatVector + } for i := 0; i < nq; i++ { bs := make([]byte, 0, dim*4) for j := 0; j < dim; j++ {