Mirror of https://gitee.com/milvus-io/milvus.git (synced 2025-12-06 17:18:35 +08:00)
feat: impl StructArray -- support create index for vector array (embedding list) and search on it (#43726)
Ref https://github.com/milvus-io/milvus/issues/42148

This PR supports creating an index on a vector array (embedding list) field (for now, only with `DataType.FLOAT_VECTOR` elements) and searching on it. The only index type supported in this PR is `EMB_LIST_HNSW`, and the only metric type is `MAX_SIM`. Example usage:

```python
milvus_client = MilvusClient("xxx:19530")
schema = milvus_client.create_schema(enable_dynamic_field=True, auto_id=True)
...
struct_schema = milvus_client.create_struct_array_field_schema("struct_array_field")
...
struct_schema.add_field("struct_float_vec",
                        DataType.ARRAY_OF_VECTOR,
                        element_type=DataType.FLOAT_VECTOR,
                        dim=128,
                        max_capacity=1000)
...
schema.add_struct_array_field(struct_schema)

index_params = milvus_client.prepare_index_params()
index_params.add_index(field_name="struct_float_vec",
                       index_type="EMB_LIST_HNSW",
                       metric_type="MAX_SIM",
                       index_params={"nlist": 128})
...
milvus_client.create_index(COLLECTION_NAME, schema=schema, index_params=index_params)
```

Note: This PR uses `Lims` to convey the offsets of the vector arrays to knowhere: the vectors of multiple vector arrays are concatenated, so offsets are needed to specify which vectors belong to which vector array.

---------

Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
Signed-off-by: SpadeA-Tang <tangchenjie1210@gmail.com>
This commit is contained in: parent cfdb17a088, commit d6a428e880
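Before the diff, a minimal standalone illustration of the `Lims` idea described in the commit message (plain C++, no Milvus or knowhere dependencies; the values are made up for illustration):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // Three embedding lists holding 2, 3, and 1 vectors respectively.
    // After flattening, all 6 vectors sit back-to-back in one buffer.
    std::vector<size_t> list_sizes = {2, 3, 1};

    std::vector<size_t> lims;
    lims.reserve(list_sizes.size() + 1);
    size_t offset = 0;
    lims.push_back(offset);
    for (size_t n : list_sizes) {
        offset += n;
        lims.push_back(offset);
    }

    // lims == {0, 2, 5, 6}: vectors [lims[i], lims[i+1]) belong to list i,
    // and lims.back() is the total number of flattened vectors.
    for (size_t i = 0; i + 1 < lims.size(); ++i) {
        std::printf("list %zu: vectors [%zu, %zu)\n", i, lims[i], lims[i + 1]);
    }
}
```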
@@ -395,6 +395,14 @@ class VectorArrayChunk : public Chunk {
           dim_(dim),
           element_type_(element_type) {
+        offsets_lens_ = reinterpret_cast<uint32_t*>(data);
+
+        auto offset = 0;
+        lims_.reserve(row_nums_ + 1);
+        lims_.push_back(offset);
+        for (int64_t i = 0; i < row_nums_; i++) {
+            offset += offsets_lens_[i * 2 + 1];
+            lims_.push_back(offset);
+        }
     }

     VectorArrayView
@@ -424,10 +432,23 @@ class VectorArrayChunk : public Chunk {
                   "VectorArrayChunk::ValueAt is not supported");
     }

+    const char*
+    Data() const override {
+        return data_ + offsets_lens_[0];
+    }
+
+    const size_t*
+    Lims() const {
+        return lims_.data();
+    }
+
  private:
     int64_t dim_;
+    uint32_t* offsets_lens_;
     milvus::DataType element_type_;
+    // The name 'Lims' is consistent with knowhere::DataSet::SetLims which describes the number of vectors
+    // in each vector array (embedding list). This is needed as vectors are flattened in the chunk.
+    std::vector<size_t> lims_;
 };

 class SparseFloatVectorChunk : public Chunk {
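The constructor above derives `lims_` as a running prefix sum over per-row vector counts read from an interleaved `(offset, length)` header. A standalone sketch of that decoding with made-up header values (the even/odd slot layout is inferred from the `i * 2 + 1` indexing above, so treat it as illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Interleaved per-row header: {offset0, len0, offset1, len1, ...}.
    // The odd slots mirror offsets_lens_[i * 2 + 1]: vectors per row.
    std::vector<uint32_t> offsets_lens = {0, 4, 512, 2, 768, 5};
    const int64_t row_nums = 3;

    std::vector<size_t> lims;
    lims.reserve(row_nums + 1);
    size_t offset = 0;
    lims.push_back(offset);
    for (int64_t i = 0; i < row_nums; i++) {
        offset += offsets_lens[i * 2 + 1];  // length slot of row i
        lims.push_back(offset);
    }

    // lims == {0, 4, 6, 11}: row i owns flattened vectors [lims[i], lims[i+1]).
    for (size_t v : lims) {
        std::printf("%zu ", v);
    }
    std::printf("\n");
}
```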
@@ -92,6 +92,15 @@ class FieldData<VectorArray> : public FieldDataVectorArrayImpl {
         ThrowInfo(Unsupported,
                   "Call get_dim on FieldData<VectorArray> is not supported");
     }

+    const VectorArray*
+    value_at(ssize_t offset) const {
+        AssertInfo(offset < get_num_rows(),
+                   "field data subscript out of range");
+        AssertInfo(offset < length(),
+                   "subscript position don't has valid value");
+        return &data_[offset];
+    }
 };

 template <>
@@ -47,6 +47,7 @@ constexpr bool IsVariableType =
     IsSparse<T> || std::is_same_v<T, VectorArray> ||
+    std::is_same_v<T, VectorArrayView>;

 // todo(SpadeA): support vector array
 template <typename T>
 constexpr bool IsVariableTypeSupportInChunk =
     std::is_same_v<T, std::string> || std::is_same_v<T, Array> ||
@@ -493,7 +493,8 @@ IsFloatVectorMetricType(const MetricType& metric_type) {
     return metric_type == knowhere::metric::L2 ||
            metric_type == knowhere::metric::IP ||
            metric_type == knowhere::metric::COSINE ||
-           metric_type == knowhere::metric::BM25;
+           metric_type == knowhere::metric::BM25 ||
+           metric_type == knowhere::metric::MAX_SIM;
 }

 inline bool
@@ -28,20 +28,23 @@

 namespace milvus {

-#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT                             \
-    using elem_type = std::conditional_t<                          \
-        std::is_same_v<TraitType, milvus::FloatVector>,            \
-        milvus::FloatVector::embedded_type,                        \
-        std::conditional_t<                                        \
-            std::is_same_v<TraitType, milvus::Float16Vector>,      \
-            milvus::Float16Vector::embedded_type,                  \
-            std::conditional_t<                                    \
-                std::is_same_v<TraitType, milvus::BFloat16Vector>, \
-                milvus::BFloat16Vector::embedded_type,             \
-                std::conditional_t<                                \
-                    std::is_same_v<TraitType, milvus::Int8Vector>, \
-                    milvus::Int8Vector::embedded_type,             \
-                    milvus::BinaryVector::embedded_type>>>>;
+#define GET_ELEM_TYPE_FOR_VECTOR_TRAIT                                 \
+    using elem_type = std::conditional_t<                              \
+        std::is_same_v<TraitType, milvus::EmbListFloatVector>,         \
+        milvus::EmbListFloatVector::embedded_type,                     \
+        std::conditional_t<                                            \
+            std::is_same_v<TraitType, milvus::FloatVector>,            \
+            milvus::FloatVector::embedded_type,                        \
+            std::conditional_t<                                        \
+                std::is_same_v<TraitType, milvus::Float16Vector>,      \
+                milvus::Float16Vector::embedded_type,                  \
+                std::conditional_t<                                    \
+                    std::is_same_v<TraitType, milvus::BFloat16Vector>, \
+                    milvus::BFloat16Vector::embedded_type,             \
+                    std::conditional_t<                                \
+                        std::is_same_v<TraitType, milvus::Int8Vector>, \
+                        milvus::Int8Vector::embedded_type,             \
+                        milvus::BinaryVector::embedded_type>>>>>;

 #define GET_SCHEMA_DATA_TYPE_FOR_VECTOR_TRAIT \
     auto schema_data_type =                   \
@@ -55,7 +58,13 @@ namespace milvus {
                 ? milvus::Int8Vector::schema_data_type \
                 : milvus::BinaryVector::schema_data_type;

-class VectorTrait {};
+class VectorTrait {
+ public:
+    static constexpr bool
+    is_embedding_list() {
+        return false;
+    }
+};

 class FloatVector : public VectorTrait {
  public:
@@ -136,6 +145,25 @@ class Int8Vector : public VectorTrait {
         proto::common::PlaceholderType::Int8Vector;
 };

+class EmbListFloatVector : public VectorTrait {
+ public:
+    using embedded_type = float;
+    static constexpr int32_t dim_factor = 1;
+    static constexpr auto data_type = DataType::VECTOR_ARRAY;
+    static constexpr auto c_data_type = CDataType::VectorArray;
+    static constexpr auto schema_data_type =
+        proto::schema::DataType::ArrayOfVector;
+    static constexpr auto vector_type =
+        proto::plan::VectorType::EmbListFloatVector;
+    static constexpr auto placeholder_type =
+        proto::common::PlaceholderType::EmbListFloatVector;
+
+    static constexpr bool
+    is_embedding_list() {
+        return true;
+    }
+};
+
 struct FundamentalTag {};
 struct StringTag {};

@@ -55,6 +55,7 @@ enum CDataType {
     BFloat16Vector = 103,
     SparseFloatVector = 104,
     Int8Vector = 105,
+    VectorArray = 106,
 };
 typedef enum CDataType CDataType;

@@ -69,6 +69,7 @@ PhyVectorSearchNode::GetOutput() {

     auto& ph = placeholder_group_->at(0);
     auto src_data = ph.get_blob();
+    auto src_lims = ph.get_lims();
     auto num_queries = ph.num_of_queries_;
     milvus::SearchResult search_result;

@@ -85,6 +86,7 @@ PhyVectorSearchNode::GetOutput() {
                  col_input->size());
     segment_->vector_search(search_info_,
                             src_data,
+                            src_lims,
                             num_queries,
                             query_timestamp_,
                             final_view,
@@ -98,13 +98,18 @@ IndexFactory::CreatePrimitiveScalarIndex<std::string>(
 LoadResourceRequest
 IndexFactory::IndexLoadResource(
     DataType field_type,
+    DataType element_type,
     IndexVersion index_version,
     float index_size,
     const std::map<std::string, std::string>& index_params,
     bool mmap_enable) {
     if (milvus::IsVectorDataType(field_type)) {
-        return VecIndexLoadResource(
-            field_type, index_version, index_size, index_params, mmap_enable);
+        return VecIndexLoadResource(field_type,
+                                    element_type,
+                                    index_version,
+                                    index_size,
+                                    index_params,
+                                    mmap_enable);
     } else {
         return ScalarIndexLoadResource(
             field_type, index_version, index_size, index_params, mmap_enable);
@@ -114,6 +119,7 @@ IndexFactory::IndexLoadResource(
 LoadResourceRequest
 IndexFactory::VecIndexLoadResource(
     DataType field_type,
+    DataType element_type,
     IndexVersion index_version,
     float index_size,
     const std::map<std::string, std::string>& index_params,
@@ -198,6 +204,29 @@ IndexFactory::VecIndexLoadResource(
                 knowhere::IndexStaticFaced<knowhere::int8>::HasRawData(
                     index_type, index_version, config);
             break;
+        case milvus::DataType::VECTOR_ARRAY: {
+            switch (element_type) {
+                case milvus::DataType::VECTOR_FLOAT:
+                    resource = knowhere::IndexStaticFaced<
+                        knowhere::fp32>::EstimateLoadResource(index_type,
+                                                              index_version,
+                                                              index_size_gb,
+                                                              config);
+                    has_raw_data =
+                        knowhere::IndexStaticFaced<knowhere::fp32>::HasRawData(
+                            index_type, index_version, config);
+                    break;
+                default:
+                    LOG_ERROR(
+                        "invalid data type to estimate index load resource: "
+                        "field_type {}, element_type {}",
+                        field_type,
+                        element_type);
+                    return LoadResourceRequest{0, 0, 0, 0, true};
+            }
+            break;
+        }
         default:
             LOG_ERROR("invalid data type to estimate index load resource: {}",
                       field_type);
@@ -491,8 +520,14 @@ IndexFactory::CreateVectorIndex(
             return std::make_unique<VectorDiskAnnIndex<float>>(
                 index_type, metric_type, version, file_manager_context);
         }
+        case DataType::VECTOR_ARRAY: {
+            ThrowInfo(Unsupported,
+                      "VECTOR_ARRAY for DiskAnnIndex is not supported");
+        }
         case DataType::VECTOR_INT8: {
             // TODO caiyd, not support yet
             ThrowInfo(Unsupported,
                       "VECTOR_INT8 for DiskAnnIndex is not supported");
         }
         default:
             ThrowInfo(
@@ -505,6 +540,7 @@ IndexFactory::CreateVectorIndex(
         case DataType::VECTOR_FLOAT:
         case DataType::VECTOR_SPARSE_FLOAT: {
             return std::make_unique<VectorMemIndex<float>>(
+                DataType::NONE,
                 index_type,
                 metric_type,
                 version,
@@ -513,6 +549,7 @@ IndexFactory::CreateVectorIndex(
         }
         case DataType::VECTOR_BINARY: {
             return std::make_unique<VectorMemIndex<bin1>>(
+                DataType::NONE,
                 index_type,
                 metric_type,
                 version,
@@ -521,6 +558,7 @@ IndexFactory::CreateVectorIndex(
         }
         case DataType::VECTOR_FLOAT16: {
             return std::make_unique<VectorMemIndex<float16>>(
+                DataType::NONE,
                 index_type,
                 metric_type,
                 version,
@@ -529,6 +567,7 @@ IndexFactory::CreateVectorIndex(
         }
         case DataType::VECTOR_BFLOAT16: {
             return std::make_unique<VectorMemIndex<bfloat16>>(
+                DataType::NONE,
                 index_type,
                 metric_type,
                 version,
@@ -537,12 +576,33 @@ IndexFactory::CreateVectorIndex(
         }
         case DataType::VECTOR_INT8: {
             return std::make_unique<VectorMemIndex<int8>>(
+                DataType::NONE,
                 index_type,
                 metric_type,
                 version,
                 use_knowhere_build_pool,
                 file_manager_context);
         }
+        case DataType::VECTOR_ARRAY: {
+            auto element_type =
+                static_cast<DataType>(file_manager_context.fieldDataMeta
+                                          .field_schema.element_type());
+            switch (element_type) {
+                case DataType::VECTOR_FLOAT:
+                    return std::make_unique<VectorMemIndex<float>>(
+                        element_type,
+                        index_type,
+                        metric_type,
+                        version,
+                        use_knowhere_build_pool,
+                        file_manager_context);
+                default:
+                    ThrowInfo(NotImplemented,
+                              fmt::format("not implemented data type to "
+                                          "build mem index: {}",
+                                          data_type));
+            }
+        }
         default:
             ThrowInfo(
                 DataTypeInvalid,
@@ -56,6 +56,7 @@ class IndexFactory {

     LoadResourceRequest
     IndexLoadResource(DataType field_type,
+                      DataType element_type,
                       IndexVersion index_version,
                       float index_size,
                       const std::map<std::string, std::string>& index_params,
@@ -63,6 +64,7 @@ class IndexFactory {

     LoadResourceRequest
     VecIndexLoadResource(DataType field_type,
+                         DataType element_type,
                          IndexVersion index_version,
                          float index_size,
                          const std::map<std::string, std::string>& index_params,
@@ -229,4 +229,21 @@ void inline SetBitsetGrowing(void* bitset,
     }
 }

+inline size_t
+vector_element_size(const DataType data_type) {
+    switch (data_type) {
+        case DataType::VECTOR_FLOAT:
+            return sizeof(float);
+        case DataType::VECTOR_FLOAT16:
+            return sizeof(float16);
+        case DataType::VECTOR_BFLOAT16:
+            return sizeof(bfloat16);
+        case DataType::VECTOR_INT8:
+            return sizeof(int8);
+        default:
+            ThrowInfo(UnexpectedError,
+                      fmt::format("invalid data type: {}", data_type));
+    }
+}
+
 }  // namespace milvus::index
@@ -245,7 +245,7 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
                              SearchResult& search_result) const {
     AssertInfo(GetMetricType() == search_info.metric_type_,
                "Metric type of field index isn't the same with search info");
-    auto num_queries = dataset->GetRows();
+    auto num_rows = dataset->GetRows();
     auto topk = search_info.topk_;

     knowhere::Json search_config = PrepareSearchParams(search_info);
@@ -277,7 +277,7 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
                 res.what()));
         }
         return ReGenRangeSearchResult(
-            res.value(), topk, num_queries, GetMetricType());
+            res.value(), topk, num_rows, GetMetricType());
     } else {
         auto res = index_.Search(dataset, search_config, bitset);
         if (!res.has_value()) {
@@ -291,6 +291,8 @@ VectorDiskAnnIndex<T>::Query(const DatasetPtr dataset,
     }();

     auto ids = final->GetIds();
+    // In embedding list query, final->GetRows() can be different from dataset->GetRows().
+    auto num_queries = final->GetRows();
     float* distances = const_cast<float*>(final->GetDistance());
     final->SetIsOwner(true);

@@ -59,12 +59,14 @@ namespace milvus::index {

 template <typename T>
 VectorMemIndex<T>::VectorMemIndex(
+    DataType elem_type,
     const IndexType& index_type,
     const MetricType& metric_type,
     const IndexVersion& version,
     bool use_knowhere_build_pool,
     const storage::FileManagerContext& file_manager_context)
     : VectorIndex(index_type, metric_type),
+      elem_type_(elem_type),
       use_knowhere_build_pool_(use_knowhere_build_pool) {
     CheckMetricTypeSupport<T>(metric_type);
     AssertInfo(!is_unsupported(index_type, metric_type),
@@ -89,12 +91,14 @@ VectorMemIndex<T>::VectorMemIndex(
 }

 template <typename T>
-VectorMemIndex<T>::VectorMemIndex(const IndexType& index_type,
+VectorMemIndex<T>::VectorMemIndex(DataType elem_type,
+                                  const IndexType& index_type,
                                   const MetricType& metric_type,
                                   const IndexVersion& version,
                                   const knowhere::ViewDataOp view_data,
                                   bool use_knowhere_build_pool)
     : VectorIndex(index_type, metric_type),
+      elem_type_(elem_type),
       use_knowhere_build_pool_(use_knowhere_build_pool) {
     CheckMetricTypeSupport<T>(metric_type);
     AssertInfo(!is_unsupported(index_type, metric_type),
@@ -304,6 +308,11 @@ VectorMemIndex<T>::BuildWithDataset(const DatasetPtr& dataset,
     SetDim(index_.Dim());
 }

+bool
+is_embedding_list_index(const IndexType& index_type) {
+    return index_type == knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
+}
+
 template <typename T>
 void
 VectorMemIndex<T>::Build(const Config& config) {
@@ -331,23 +340,74 @@ VectorMemIndex<T>::Build(const Config& config) {
             total_num_rows += data->get_num_rows();
             AssertInfo(dim == 0 || dim == data->get_dim(),
                        "inconsistent dim value between field datas!");
-            dim = data->get_dim();
+
+            // todo(SapdeA): now, vector arrays (embedding list) are serialized
+            // to parquet by using binary format which does not provide dim
+            // information so we use this temporary solution.
+            if (is_embedding_list_index(index_type_)) {
+                AssertInfo(elem_type_ != DataType::NONE,
+                           "embedding list index must have elem_type");
+                dim = config[DIM_KEY].get<int64_t>();
+            } else {
+                dim = data->get_dim();
+            }
         }

         auto buf = std::shared_ptr<uint8_t[]>(new uint8_t[total_size]);

+        size_t lim_offset = 0;
+        std::vector<size_t> lims;
+        lims.reserve(total_num_rows + 1);
+        lims.push_back(lim_offset);
+
         int64_t offset = 0;
-        // TODO: avoid copying
-        for (auto data : field_datas) {
-            std::memcpy(buf.get() + offset, data->Data(), data->Size());
-            offset += data->Size();
-            data.reset();
+        if (!is_embedding_list_index(index_type_)) {
+            // TODO: avoid copying
+            for (auto data : field_datas) {
+                std::memcpy(buf.get() + offset, data->Data(), data->Size());
+                offset += data->Size();
+                data.reset();
+            }
+        } else {
+            auto elem_size = vector_element_size(elem_type_);
+            for (auto data : field_datas) {
+                auto vec_array_data =
+                    dynamic_cast<FieldData<VectorArray>*>(data.get());
+                AssertInfo(vec_array_data != nullptr,
+                           "failed to cast field data to vector array");
+
+                auto rows = vec_array_data->get_num_rows();
+                for (auto i = 0; i < rows; ++i) {
+                    auto size = vec_array_data->DataSize(i);
+                    assert(size % (dim * elem_size) == 0);
+                    assert(dim * elem_size != 0);
+
+                    auto vec_array = vec_array_data->value_at(i);
+
+                    std::memcpy(buf.get() + offset, vec_array->data(), size);
+                    offset += size;
+
+                    lim_offset += size / (dim * elem_size);
+                    lims.push_back(lim_offset);
+                }
+
+                assert(data->Size() == offset);
+
+                data.reset();
+            }
+
+            total_num_rows = lim_offset;
+        }
+
         field_datas.clear();

         auto dataset = GenDataset(total_num_rows, dim, buf.get());
         if (!scalar_info.empty()) {
             dataset->Set(knowhere::meta::SCALAR_INFO, std::move(scalar_info));
         }
+        if (!lims.empty()) {
+            dataset->SetLims(lims.data());
+        }
         BuildWithDataset(dataset, build_config);
     } else {
         // sparse
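To make the bytes-to-vectors conversion in the build path above concrete, here is a small standalone sketch with assumed numbers (dim = 4, float elements, so each vector occupies 16 bytes); it mirrors the `lim_offset += size / (dim * elem_size)` accounting without any Milvus types:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const size_t dim = 4;
    const size_t elem_size = sizeof(float);        // 4 bytes per element
    const size_t bytes_per_vec = dim * elem_size;  // 16 bytes per vector

    // Byte sizes of three vector-array rows: 2, 3, and 1 vectors each.
    std::vector<size_t> row_byte_sizes = {32, 48, 16};

    std::vector<size_t> lims = {0};
    size_t lim_offset = 0;
    for (size_t size : row_byte_sizes) {
        assert(size % bytes_per_vec == 0);   // rows hold whole vectors
        lim_offset += size / bytes_per_vec;  // bytes -> vector count
        lims.push_back(lim_offset);
    }

    // lims == {0, 2, 5, 6}; total_num_rows becomes lims.back() == 6,
    // the number of flattened vectors handed to the index build.
    std::printf("total vectors: %zu\n", lims.back());
}
```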
@@ -409,7 +469,7 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
     // AssertInfo(GetMetricType() == search_info.metric_type_,
     //            "Metric type of field index isn't the same with search info");

-    auto num_queries = dataset->GetRows();
+    auto num_vectors = dataset->GetRows();
     knowhere::Json search_conf = PrepareSearchParams(search_info);
     auto topk = search_info.topk_;
     // TODO :: check dim of search data
@@ -427,7 +487,7 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
                           res.what());
             }
             auto result = ReGenRangeSearchResult(
-                res.value(), topk, num_queries, GetMetricType());
+                res.value(), topk, num_vectors, GetMetricType());
             milvus::tracer::AddEvent("finish_ReGenRangeSearchResult");
             return result;
         } else {
@@ -448,6 +508,8 @@ VectorMemIndex<T>::Query(const DatasetPtr dataset,
     }();

     auto ids = final->GetIds();
+    // In embedding list query, final->GetRows() can be different from dataset->GetRows().
+    auto num_queries = final->GetRows();
     float* distances = const_cast<float*>(final->GetDistance());
     final->SetIsOwner(true);
     auto round_decimal = search_info.round_decimal_;

@@ -35,6 +35,7 @@ template <typename T>
 class VectorMemIndex : public VectorIndex {
  public:
     explicit VectorMemIndex(
+        DataType elem_type /* used for embedding list only */,
         const IndexType& index_type,
         const MetricType& metric_type,
         const IndexVersion& version,
@@ -43,7 +44,8 @@ class VectorMemIndex : public VectorIndex {
             storage::FileManagerContext());

     // knowhere data view index special constucter for intermin index, no need to hold file_manager_ to upload or download files
-    VectorMemIndex(const IndexType& index_type,
+    VectorMemIndex(DataType elem_type /* used for embedding list only */,
+                   const IndexType& index_type,
                    const MetricType& metric_type,
                    const IndexVersion& version,
                    const knowhere::ViewDataOp view_data,
@@ -108,6 +110,8 @@ class VectorMemIndex : public VectorIndex {
     Config config_;
     knowhere::Index<knowhere::IndexNode> index_;
     std::shared_ptr<storage::MemFileManagerImpl> file_manager_;
+    // used for embedding list only
+    DataType elem_type_;

     CreateIndexInfo create_index_info_;
     bool use_knowhere_build_pool_;

@@ -70,11 +70,8 @@ class IndexFactory {
         case DataType::VECTOR_BINARY:
         case DataType::VECTOR_SPARSE_FLOAT:
         case DataType::VECTOR_INT8:
-            return std::make_unique<VecIndexCreator>(type, config, context);
-
         case DataType::VECTOR_ARRAY:
-            ThrowInfo(DataTypeInvalid,
-                      fmt::format("VECTOR_ARRAY is not implemented"));
+            return std::make_unique<VecIndexCreator>(type, config, context);

         default:
             ThrowInfo(DataTypeInvalid,
@@ -34,6 +34,13 @@ VecIndexCreator::VecIndexCreator(
     Config& config,
     const storage::FileManagerContext& file_manager_context)
     : config_(config), data_type_(data_type) {
+    if (data_type == DataType::VECTOR_ARRAY) {
+        // TODO(SpadeA): record dim in config as there's the dim cannot be inferred in
+        // parquet due to the serialize method of vector array.
+        // This should be a temp solution.
+        config_[DIM_KEY] = file_manager_context.indexMeta.dim;
+    }
+
     index::CreateIndexInfo index_info;
     index_info.field_type = data_type_;
     index_info.index_type = index::GetIndexTypeFromConfig(config_);

@@ -273,6 +273,13 @@ class ChunkedColumnBase : public ChunkedColumnInterface {
             "VectorArrayViews only supported for ChunkedVectorArrayColumn");
     }

+    virtual PinWrapper<const size_t*>
+    VectorArrayLims(int64_t chunk_id) const override {
+        ThrowInfo(
+            ErrorCode::Unsupported,
+            "VectorArrayLims only supported for ChunkedVectorArrayColumn");
+    }
+
     PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>
     StringViewsByOffsets(int64_t chunk_id,
                          const FixedVector<int32_t>& offsets) const override {
@@ -621,6 +628,15 @@ class ChunkedVectorArrayColumn : public ChunkedColumnBase {
         return PinWrapper<std::vector<VectorArrayView>>(
             ca, static_cast<VectorArrayChunk*>(chunk)->Views());
     }

+    PinWrapper<const size_t*>
+    VectorArrayLims(int64_t chunk_id) const override {
+        auto ca =
+            SemiInlineGet(slot_->PinCells({static_cast<cid_t>(chunk_id)}));
+        auto chunk = ca->get_cell_of(chunk_id);
+        return PinWrapper<const size_t*>(
+            ca, static_cast<VectorArrayChunk*>(chunk)->Lims());
+    }
 };

 inline std::shared_ptr<ChunkedColumnInterface>
@@ -319,6 +319,19 @@ class ProxyChunkColumn : public ChunkedColumnInterface {
             static_cast<VectorArrayChunk*>(chunk.get())->Views());
     }

+    PinWrapper<const size_t*>
+    VectorArrayLims(int64_t chunk_id) const override {
+        if (!IsChunkedVectorArrayColumnDataType(data_type_)) {
+            ThrowInfo(ErrorCode::Unsupported,
+                      "VectorArrayLims only supported for "
+                      "ChunkedVectorArrayColumn");
+        }
+        auto chunk_wrapper = group_->GetGroupChunk(chunk_id);
+        auto chunk = chunk_wrapper.get()->GetChunk(field_id_);
+        return PinWrapper<const size_t*>(
+            chunk_wrapper, static_cast<VectorArrayChunk*>(chunk.get())->Lims());
+    }
+
     PinWrapper<std::pair<std::vector<std::string_view>, FixedVector<bool>>>
     StringViewsByOffsets(int64_t chunk_id,
                          const FixedVector<int32_t>& offsets) const override {
@@ -84,6 +84,9 @@ class ChunkedColumnInterface {
     virtual PinWrapper<std::vector<VectorArrayView>>
     VectorArrayViews(int64_t chunk_id) const = 0;

+    virtual PinWrapper<const size_t*>
+    VectorArrayLims(int64_t chunk_id) const = 0;
+
     virtual PinWrapper<
         std::pair<std::vector<std::string_view>, FixedVector<bool>>>
     StringViewsByOffsets(int64_t chunk_id,
@@ -242,4 +242,9 @@ ExecPlanNodeVisitor::visit(Int8VectorANNS& node) {
     VectorVisitorImpl<Int8Vector>(node);
 }

+void
+ExecPlanNodeVisitor::visit(EmbListFloatVectorANNS& node) {
+    VectorVisitorImpl<EmbListFloatVector>(node);
+}
+
 }  // namespace milvus::query
@@ -43,6 +43,9 @@ class ExecPlanNodeVisitor : public PlanNodeVisitor {
     void
     visit(RetrievePlanNode& node) override;

+    void
+    visit(EmbListFloatVectorANNS& node) override;
+
  public:
     ExecPlanNodeVisitor(const segcore::SegmentInterface& segment,
                         Timestamp timestamp,
@@ -30,6 +30,17 @@ ParsePlaceholderGroup(const Plan* plan,
                                  placeholder_group_blob.size());
 }

+bool
+check_data_type(const FieldMeta& field_meta,
+                const milvus::proto::common::PlaceholderType type) {
+    if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
+        return type ==
+               milvus::proto::common::PlaceholderType::EmbListFloatVector;
+    }
+    return static_cast<int>(field_meta.get_data_type()) ==
+           static_cast<int>(type);
+}
+
 std::unique_ptr<PlaceholderGroup>
 ParsePlaceholderGroup(const Plan* plan,
                       const uint8_t* blob,
@@ -44,8 +55,7 @@ ParsePlaceholderGroup(const Plan* plan,
         Assert(plan->tag2field_.count(element.tag_));
         auto field_id = plan->tag2field_.at(element.tag_);
         auto& field_meta = plan->schema_->operator[](field_id);
-        AssertInfo(static_cast<int>(field_meta.get_data_type()) ==
-                       static_cast<int>(info.type()),
+        AssertInfo(check_data_type(field_meta, info.type()),
                    "vector type must be the same, field {} - type {}, search "
                    "info type {}",
                    field_meta.get_name().get(),
@@ -59,23 +69,47 @@ ParsePlaceholderGroup(const Plan* plan,
                 SparseBytesToRows(info.values(), /*validate=*/true);
         } else {
             auto line_size = info.values().Get(0).size();
             auto& target = element.blob_;
-            if (field_meta.get_sizeof() != line_size) {
-                ThrowInfo(
-                    DimNotMatch,
-                    fmt::format("vector dimension mismatch, expected vector "
-                                "size(byte) {}, actual {}.",
-                                field_meta.get_sizeof(),
-                                line_size));
-            }
-            target.reserve(line_size * element.num_of_queries_);
-            for (auto& line : info.values()) {
-                AssertInfo(line_size == line.size(),
-                           "vector dimension mismatch, expected vector "
-                           "size(byte) {}, actual {}.",
-                           line_size,
-                           line.size());
-                target.insert(target.end(), line.begin(), line.end());
+
+            if (field_meta.get_data_type() != DataType::VECTOR_ARRAY) {
+                if (field_meta.get_sizeof() != line_size) {
+                    ThrowInfo(DimNotMatch,
+                              fmt::format(
+                                  "vector dimension mismatch, expected vector "
+                                  "size(byte) {}, actual {}.",
+                                  field_meta.get_sizeof(),
+                                  line_size));
+                }
+                target.reserve(line_size * element.num_of_queries_);
+                for (auto& line : info.values()) {
+                    AssertInfo(line_size == line.size(),
+                               "vector dimension mismatch, expected vector "
+                               "size(byte) {}, actual {}.",
+                               line_size,
+                               line.size());
+                    target.insert(target.end(), line.begin(), line.end());
+                }
+            } else {
+                target.reserve(line_size * element.num_of_queries_);
+                auto dim = field_meta.get_dim();
+
+                // If the vector is embedding list, line contains multiple vectors.
+                // And we should record the offsets so that we can identify each
+                // embedding list in a flattened vectors.
+                auto& lims = element.lims_;
+                lims.reserve(element.num_of_queries_ + 1);
+                size_t offset = 0;
+                lims.push_back(offset);
+
+                auto elem_size = milvus::index::vector_element_size(
+                    field_meta.get_element_type());
+                for (auto& line : info.values()) {
+                    target.insert(target.end(), line.begin(), line.end());
+
+                    Assert(line.size() % (dim * elem_size) == 0);
+                    offset += line.size() / (dim * elem_size);
+                    lims.push_back(offset);
+                }
+            }
         }
         result->emplace_back(std::move(element));
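The `check_data_type` helper introduced above relaxes the exact field/placeholder match for vector-array fields: a `VECTOR_ARRAY` field only accepts an `EmbListFloatVector` placeholder. A minimal standalone sketch of that dispatch (the enums and their values here are illustrative stand-ins, not the real Milvus proto definitions):

```cpp
#include <iostream>

// Illustrative stand-ins for the real Milvus/proto enums.
enum class DataType { VECTOR_FLOAT = 0, VECTOR_ARRAY = 1 };
enum class PlaceholderType { FloatVector = 0, EmbListFloatVector = 1 };

bool check_data_type(DataType field_type, PlaceholderType type) {
    if (field_type == DataType::VECTOR_ARRAY) {
        // Embedding-list fields are searched with an embedding-list
        // placeholder, not a placeholder of the element type.
        return type == PlaceholderType::EmbListFloatVector;
    }
    // All other fields keep the exact numeric match.
    return static_cast<int>(field_type) == static_cast<int>(type);
}

int main() {
    std::cout << check_data_type(DataType::VECTOR_ARRAY,
                                 PlaceholderType::EmbListFloatVector)  // 1
              << check_data_type(DataType::VECTOR_FLOAT,
                                 PlaceholderType::FloatVector)  // 1
              << "\n";
}
```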
@@ -68,6 +68,9 @@ struct Plan {

 struct Placeholder {
     std::string tag_;
+    // note: for embedding list search, num_of_queries_ stands for the number of vectors.
+    // lims_ records the offsets of embedding list in the flattened vector and
+    // hence lims_.size() - 1 is the number of queries in embedding list search.
     int64_t num_of_queries_;
     // TODO(SPARSE): add a dim_ field here, use the dim passed in search request
     // instead of the dim in schema, since the dim of sparse float column is
@@ -78,6 +81,8 @@ struct Placeholder {
     // dense vector search and sparse_matrix_ is for sparse vector search.
     aligned_vector<char> blob_;
     std::unique_ptr<knowhere::sparse::SparseRow<float>[]> sparse_matrix_;
+    // offsets for embedding list
+    aligned_vector<size_t> lims_;

     const void*
     get_blob() const {
@@ -94,6 +99,16 @@ struct Placeholder {
         }
         return blob_.data();
     }
+
+    const size_t*
+    get_lims() const {
+        return lims_.data();
+    }
+
+    size_t*
+    get_lims() {
+        return lims_.data();
+    }
 };

 struct RetrievePlan {

@@ -50,4 +50,9 @@ RetrievePlanNode::accept(PlanNodeVisitor& visitor) {
     visitor.visit(*this);
 }

+void
+EmbListFloatVectorANNS::accept(PlanNodeVisitor& visitor) {
+    visitor.visit(*this);
+}
+
 }  // namespace milvus::query
@@ -77,6 +77,12 @@ struct Int8VectorANNS : VectorPlanNode {
     accept(PlanNodeVisitor&) override;
 };

+struct EmbListFloatVectorANNS : VectorPlanNode {
+ public:
+    void
+    accept(PlanNodeVisitor&) override;
+};
+
 struct RetrievePlanNode : PlanNode {
  public:
     void
@@ -39,5 +39,8 @@ class PlanNodeVisitor {

     virtual void
     visit(RetrievePlanNode&) = 0;
+
+    virtual void
+    visit(EmbListFloatVectorANNS&) = 0;
 };
 }  // namespace milvus::query
@@ -127,6 +127,9 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
         } else if (anns_proto.vector_type() ==
                    milvus::proto::plan::VectorType::Int8Vector) {
             return std::make_unique<Int8VectorANNS>();
+        } else if (anns_proto.vector_type() ==
+                   milvus::proto::plan::VectorType::EmbListFloatVector) {
+            return std::make_unique<EmbListFloatVectorANNS>();
         } else {
             return std::make_unique<FloatVectorANNS>();
         }
@@ -89,8 +89,23 @@ PrepareBFDataSet(const dataset::SearchDataset& query_ds,
                  DataType data_type) {
     auto base_dataset =
         knowhere::GenDataSet(raw_ds.num_raw_data, raw_ds.dim, raw_ds.raw_data);
+    if (raw_ds.raw_data_lims != nullptr) {
+        // knowhere::DataSet count vectors in a flattened manner where as the num_raw_data here is the number
+        // of embedding lists where each embedding list contains multiple vectors. So we should use the last element
+        // in lims which equals to the total number of vectors.
+        base_dataset->SetLims(raw_ds.raw_data_lims);
+        // the length of lims equals to the number of embedding lists + 1
+        base_dataset->SetRows(raw_ds.raw_data_lims[raw_ds.num_raw_data]);
+    }

     auto query_dataset = knowhere::GenDataSet(
         query_ds.num_queries, query_ds.dim, query_ds.query_data);
+    if (query_ds.query_lims != nullptr) {
+        // ditto
+        query_dataset->SetLims(query_ds.query_lims);
+        query_dataset->SetRows(query_ds.query_lims[query_ds.num_queries]);
+    }

     if (data_type == DataType::VECTOR_SPARSE_FLOAT) {
         base_dataset->SetIsSparse(true);
         query_dataset->SetIsSparse(true);
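As the comments above note, knowhere counts rows per flattened vector while `num_raw_data` counts embedding lists, so the row count handed to the dataset must be the last lims entry. A tiny standalone sketch of that accounting (plain C++, with an illustrative struct standing in for `knowhere::DataSet`):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>

// Illustrative stand-in for the subset of knowhere::DataSet used here.
struct DataSet {
    size_t rows = 0;
    const size_t* lims = nullptr;
};

int main() {
    // Three embedding lists -> lims has 3 + 1 entries; 6 vectors total.
    const size_t num_lists = 3;
    const size_t lims[] = {0, 2, 5, 6};

    DataSet ds;
    ds.lims = lims;
    // Rows are counted per flattened vector, not per list, so use the
    // last lims element (== total number of vectors).
    ds.rows = lims[num_lists];

    assert(ds.rows == 6);
    std::printf("rows = %zu\n", ds.rows);
}
```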
@@ -105,7 +120,8 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
                  const SearchInfo& search_info,
                  const std::map<std::string, std::string>& index_info,
                  const BitsetView& bitset,
-                 DataType data_type) {
+                 DataType data_type,
+                 DataType element_type) {
     SubSearchResult sub_result(query_ds.num_queries,
                                query_ds.topk,
                                query_ds.metric_type,
@@ -122,7 +138,18 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
     sub_result.mutable_seg_offsets().resize(nq * topk);
     sub_result.mutable_distances().resize(nq * topk);

+    // For vector array (embedding list), element type is used to determine how to operate search.
+    if (data_type == DataType::VECTOR_ARRAY) {
+        AssertInfo(element_type != DataType::NONE,
+                   "Element type is not specified for vector array");
+        data_type = element_type;
+    }
+
     if (search_cfg.contains(RADIUS)) {
+        AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                   "Vector array(embedding list) is not supported for range "
+                   "search");
+
         if (search_cfg.contains(RANGE_FILTER)) {
             CheckRangeSearchParam(search_cfg[RADIUS],
                                   search_cfg[RANGE_FILTER],
@@ -238,7 +265,10 @@ DispatchBruteForceIteratorByDataType(const knowhere::DataSetPtr& base_dataset,
                                      const knowhere::DataSetPtr& query_dataset,
                                      const knowhere::Json& config,
                                      const BitsetView& bitset,
-                                     const milvus::DataType& data_type) {
+                                     milvus::DataType data_type) {
+    AssertInfo(data_type != DataType::VECTOR_ARRAY,
+               "VECTOR_ARRAY is not supported for brute force iterator");
+
     switch (data_type) {
         case DataType::VECTOR_FLOAT:
             return knowhere::BruteForce::AnnIterator<float>(
@@ -29,7 +29,8 @@ BruteForceSearch(const dataset::SearchDataset& query_ds,
                  const SearchInfo& search_info,
                  const std::map<std::string, std::string>& index_info,
                  const BitsetView& bitset,
-                 DataType data_type);
+                 DataType data_type,
+                 DataType element_type);

 knowhere::expected<std::vector<knowhere::IndexNode::IteratorPtr>>
 GetBruteForceSearchIterators(
@@ -71,6 +71,7 @@ void
 SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
                 const SearchInfo& info,
                 const void* query_data,
+                const size_t* query_lims,
                 int64_t num_queries,
                 Timestamp timestamp,
                 const BitsetView& bitset,
@@ -87,6 +88,7 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
     CheckBruteForceSearchParam(field, info);

     auto data_type = field.get_data_type();
+    auto element_type = field.get_element_type();
     AssertInfo(IsVectorDataType(data_type),
                "[SearchOnGrowing]Data type isn't vector type");

@@ -96,6 +98,11 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,

     // step 2: small indexing search
     if (segment.get_indexing_record().SyncDataWithIndex(field.get_id())) {
+        AssertInfo(
+            data_type != DataType::VECTOR_ARRAY,
+            "vector array(embedding list) is not supported for growing segment "
+            "indexing search");
+
         FloatSegmentIndexSearch(
             segment, info, query_data, num_queries, bitset, search_result);
     } else {
@@ -103,6 +110,10 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
             segment.get_chunk_mutex());
         // check SyncDataWithIndex() again, in case the vector chunks has been removed.
         if (segment.get_indexing_record().SyncDataWithIndex(field.get_id())) {
+            AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                       "vector array(embedding list) is not supported for "
+                       "growing segment indexing search");
+
             return FloatSegmentIndexSearch(
                 segment, info, query_data, num_queries, bitset, search_result);
         }
@@ -111,8 +122,13 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
         auto dim = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT
                        ? 0
                        : field.get_dim();
-        dataset::SearchDataset search_dataset{
-            metric_type, num_queries, topk, round_decimal, dim, query_data};
+        dataset::SearchDataset search_dataset{metric_type,
+                                              num_queries,
+                                              topk,
+                                              round_decimal,
+                                              dim,
+                                              query_data,
+                                              query_lims};
         int32_t current_chunk_id = 0;

         // get K1 and B from index for bm25 brute force
@@ -127,6 +143,10 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
         auto vec_ptr = record.get_data_base(vecfield_id);

         if (info.iterator_v2_info_.has_value()) {
+            AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                       "vector array(embedding list) is not supported for "
+                       "vector iterator");
+
             CachedSearchIterator cached_iter(search_dataset,
                                              vec_ptr,
                                              active_count,
@@ -150,9 +170,54 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
                 std::min(active_count, (chunk_id + 1) * vec_size_per_chunk);
             auto size_per_chunk = element_end - element_begin;

-            auto sub_data = query::dataset::RawDataset{
-                element_begin, dim, size_per_chunk, chunk_data};
+            query::dataset::RawDataset sub_data;
+            std::unique_ptr<uint8_t[]> buf = nullptr;
+            std::vector<size_t> offsets;
+            if (data_type != DataType::VECTOR_ARRAY) {
+                sub_data = query::dataset::RawDataset{
+                    element_begin, dim, size_per_chunk, chunk_data};
+            } else {
+                // TODO(SpadeA): For VectorArray(Embedding List), data is
+                // discreted stored in FixedVector which means we will copy the
+                // data to a contiguous memory buffer. This is inefficient and
+                // will be optimized in the future.
+                auto vec_ptr = reinterpret_cast<const VectorArray*>(chunk_data);
+                auto size = 0;
+                for (int i = 0; i < size_per_chunk; ++i) {
+                    size += vec_ptr[i].byte_size();
+                }
+
+                buf = std::make_unique<uint8_t[]>(size);
+                offsets.reserve(size_per_chunk + 1);
+                offsets.push_back(0);
+
+                auto offset = 0;
+                auto ptr = buf.get();
+                for (int i = 0; i < size_per_chunk; ++i) {
+                    memcpy(ptr, vec_ptr[i].data(), vec_ptr[i].byte_size());
+                    ptr += vec_ptr[i].byte_size();

+                    offset += vec_ptr[i].length();
+                    offsets.push_back(offset);
+                }
+                sub_data = query::dataset::RawDataset{element_begin,
+                                                      dim,
+                                                      size_per_chunk,
+                                                      buf.get(),
+                                                      offsets.data()};
+            }
+
+            if (data_type == DataType::VECTOR_ARRAY) {
+                AssertInfo(
+                    query_lims != nullptr,
+                    "query_lims is nullptr, but data_type is vector array");
+            }
+
             if (milvus::exec::UseVectorIterator(info)) {
+                AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                           "vector array(embedding list) is not supported for "
+                           "vector iterator");
+
                 auto sub_qr =
                     PackBruteForceSearchIteratorsIntoSubResult(search_dataset,
                                                                sub_data,
@@ -167,7 +232,8 @@ SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
                                            info,
                                            index_info,
                                            bitset,
-                                           data_type);
+                                           data_type,
+                                           element_type);
                 final_qr.merge(sub_qr);
             }
         }
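The growing-segment path above copies each row's vectors into one contiguous buffer and records per-list offsets as it goes. A self-contained sketch of the same pattern (plain `std::vector<float>` rows stand in for the stored `VectorArray` values; dim is assumed to be 2):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <vector>

int main() {
    const size_t dim = 2;
    // Each inner vector is one embedding list, flattened floats (dim = 2).
    std::vector<std::vector<float>> rows = {
        {1, 2, 3, 4},           // 2 vectors
        {5, 6},                 // 1 vector
        {7, 8, 9, 10, 11, 12},  // 3 vectors
    };

    size_t total_bytes = 0;
    for (const auto& r : rows) total_bytes += r.size() * sizeof(float);

    auto buf = std::make_unique<uint8_t[]>(total_bytes);
    std::vector<size_t> offsets;
    offsets.reserve(rows.size() + 1);
    offsets.push_back(0);

    uint8_t* ptr = buf.get();
    size_t offset = 0;
    for (const auto& r : rows) {
        std::memcpy(ptr, r.data(), r.size() * sizeof(float));  // flatten row
        ptr += r.size() * sizeof(float);
        offset += r.size() / dim;  // number of vectors in this row
        offsets.push_back(offset);
    }

    // offsets == {0, 2, 3, 6}; buf now holds all 6 vectors contiguously.
    for (size_t v : offsets) std::printf("%zu ", v);
    std::printf("\n");
}
```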
@@ -20,6 +20,7 @@ void
 SearchOnGrowing(const segcore::SegmentGrowingImpl& segment,
                 const SearchInfo& info,
                 const void* query_data,
+                const size_t* query_lims,
                 int64_t num_queries,
                 Timestamp timestamp,
                 const BitsetView& bitset,
@@ -31,6 +31,7 @@ SearchOnSealedIndex(const Schema& schema,
                     const segcore::SealedIndexingRecord& record,
                     const SearchInfo& search_info,
                     const void* query_data,
+                    const size_t* query_lims,
                     int64_t num_queries,
                     const BitsetView& bitset,
                     SearchResult& search_result) {
@@ -52,7 +53,18 @@ SearchOnSealedIndex(const Schema& schema,
                field_indexing->metric_type_,
                search_info.metric_type_);

-    auto dataset = knowhere::GenDataSet(num_queries, dim, query_data);
+    knowhere::DataSetPtr dataset;
+    if (query_lims == nullptr) {
+        dataset = knowhere::GenDataSet(num_queries, dim, query_data);
+    } else {
+        // Rather than non-embedding list search where num_queries equals to the number of vectors,
+        // in embedding list search, multiple vectors form an embedding list and the last element of query_lims
+        // stands for the total number of vectors.
+        auto num_vectors = query_lims[num_queries];
+        dataset = knowhere::GenDataSet(num_vectors, dim, query_data);
+        dataset->SetLims(query_lims);
+    }

     dataset->SetIsSparse(is_sparse);
     auto accessor = SemiInlineGet(field_indexing->indexing_->PinCells({0}));
     auto vec_index =
@@ -92,6 +104,7 @@ SearchOnSealedColumn(const Schema& schema,
                      const SearchInfo& search_info,
                      const std::map<std::string, std::string>& index_info,
                      const void* query_data,
+                     const size_t* query_lims,
                      int64_t num_queries,
                      int64_t row_count,
                      const BitsetView& bitview,
@@ -99,22 +112,26 @@ SearchOnSealedColumn(const Schema& schema,
     auto field_id = search_info.field_id_;
     auto& field = schema[field_id];

+    auto data_type = field.get_data_type();
+    auto element_type = field.get_element_type();
     // TODO(SPARSE): see todo in PlanImpl.h::PlaceHolder.
-    auto dim = field.get_data_type() == DataType::VECTOR_SPARSE_FLOAT
-                   ? 0
-                   : field.get_dim();
+    auto dim = data_type == DataType::VECTOR_SPARSE_FLOAT ? 0 : field.get_dim();

     query::dataset::SearchDataset query_dataset{search_info.metric_type_,
                                                 num_queries,
                                                 search_info.topk_,
                                                 search_info.round_decimal_,
                                                 dim,
-                                                query_data};
+                                                query_data,
+                                                query_lims};

-    auto data_type = field.get_data_type();
     CheckBruteForceSearchParam(field, search_info);

     if (search_info.iterator_v2_info_.has_value()) {
+        AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                   "vector array(embedding list) is not supported for "
+                   "vector iterator");
+
         CachedSearchIterator cached_iter(
             column, query_dataset, search_info, index_info, bitview, data_type);
         cached_iter.NextBatch(search_info, result);
@@ -135,7 +152,20 @@ SearchOnSealedColumn(const Schema& schema,
         auto chunk_size = column->chunk_row_nums(i);
         auto raw_dataset =
             query::dataset::RawDataset{offset, dim, chunk_size, vec_data};

+        PinWrapper<const size_t*> lims_pw;
+        if (data_type == DataType::VECTOR_ARRAY) {
+            AssertInfo(query_lims != nullptr,
+                       "query_lims is nullptr, but data_type is vector array");
+
+            lims_pw = column->VectorArrayLims(i);
+            raw_dataset.raw_data_lims = lims_pw.get();
+        }
+
         if (milvus::exec::UseVectorIterator(search_info)) {
+            AssertInfo(data_type != DataType::VECTOR_ARRAY,
+                       "vector array(embedding list) is not supported for "
+                       "vector iterator");
             auto sub_qr =
                 PackBruteForceSearchIteratorsIntoSubResult(query_dataset,
                                                            raw_dataset,
@@ -150,7 +180,8 @@ SearchOnSealedColumn(const Schema& schema,
                                        search_info,
                                        index_info,
                                        bitview,
-                                       data_type);
+                                       data_type,
+                                       element_type);
             final_qr.merge(sub_qr);
         }
         offset += chunk_size;
@@ -23,6 +23,7 @@ SearchOnSealedIndex(const Schema& schema,
                     const segcore::SealedIndexingRecord& record,
                     const SearchInfo& search_info,
                     const void* query_data,
+                    const size_t* query_lims,
                     int64_t num_queries,
                     const BitsetView& view,
                     SearchResult& search_result);
@@ -33,6 +34,7 @@ SearchOnSealedColumn(const Schema& schema,
                      const SearchInfo& search_info,
                      const std::map<std::string, std::string>& index_info,
                      const void* query_data,
+                     const size_t* query_lims,
                      int64_t num_queries,
                      int64_t row_count,
                      const BitsetView& bitset,
@@ -24,6 +24,7 @@ struct RawDataset {
     int64_t dim;
     int64_t num_raw_data;
     const void* raw_data;
+    const size_t* raw_data_lims = nullptr;
 };
 struct SearchDataset {
     knowhere::MetricType metric_type;
@@ -32,6 +33,8 @@ struct SearchDataset {
     int64_t round_decimal;
     int64_t dim;
     const void* query_data;
+    // used for embedding list query
+    const size_t* query_lims = nullptr;
 };

 }  // namespace dataset
@@ -95,10 +95,6 @@ ChunkedSegmentSealedImpl::LoadIndex(const LoadIndexInfo& info) {
     auto field_id = FieldId(info.field_id);
     auto& field_meta = schema_->operator[](field_id);

-    if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
-        ThrowInfo(DataTypeInvalid, "VECTOR_ARRAY is not implemented");
-    }
-
     if (field_meta.is_vector()) {
         LoadVecIndex(info);
     } else {
@@ -127,6 +123,7 @@ ChunkedSegmentSealedImpl::LoadVecIndex(const LoadIndexInfo& info) {
     LoadResourceRequest request =
         milvus::index::IndexFactory::GetInstance().VecIndexLoadResource(
             field_meta.get_data_type(),
+            info.element_type,
             info.index_engine_version,
             info.index_size,
             info.index_params,
@@ -498,10 +495,6 @@ int64_t
 ChunkedSegmentSealedImpl::num_chunk_index(FieldId field_id) const {
     auto& field_meta = schema_->operator[](field_id);

-    if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
-        ThrowInfo(DataTypeInvalid, "VECTOR_ARRAY is not implemented");
-    }
-
     if (field_meta.is_vector()) {
         return int64_t(vector_indexings_.is_ready(field_id));
     }
@@ -720,6 +713,7 @@ ChunkedSegmentSealedImpl::mask_with_delete(BitsetTypeView& bitset,
 void
 ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
                                         const void* query_data,
+                                        const size_t* query_lims,
                                         int64_t query_count,
                                         Timestamp timestamp,
                                         const BitsetView& bitset,
@@ -745,6 +739,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
                             vector_indexings_,
                             binlog_search_info,
                             query_data,
+                            query_lims,
                             query_count,
                             bitset,
                             output);
@@ -758,6 +753,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
                             vector_indexings_,
                             search_info,
                             query_data,
+                            query_lims,
                             query_count,
                             bitset,
                             output);
@@ -782,6 +778,7 @@ ChunkedSegmentSealedImpl::vector_search(SearchInfo& search_info,
                              search_info,
                              index_info,
                              query_data,
+                             query_lims,
                              query_count,
                              row_count,
                              bitset,
@@ -395,6 +395,7 @@ class ChunkedSegmentSealedImpl : public SegmentSealed {
     void
     vector_search(SearchInfo& search_info,
                   const void* query_data,
+                  const size_t* query_lims,
                   int64_t query_count,
                   Timestamp timestamp,
                   const BitsetView& bitset,
@@ -47,6 +47,7 @@ VectorFieldIndexing::recreate_index(DataType data_type,
                                     const VectorBase* field_raw_data) {
     if (IsSparseFloatVectorDataType(data_type)) {
         index_ = std::make_unique<index::VectorMemIndex<float>>(
+            DataType::NONE,
             config_->GetIndexType(),
             config_->GetMetricType(),
             knowhere::Version::GetCurrentVersion().VersionNumber());
@@ -62,6 +63,7 @@ VectorFieldIndexing::recreate_index(DataType data_type,
             return (const void*)field_raw_data_ptr->get_element(id);
         };
         index_ = std::make_unique<index::VectorMemIndex<float>>(
+            DataType::NONE,
            config_->GetIndexType(),
            config_->GetMetricType(),
            knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -78,6 +80,7 @@ VectorFieldIndexing::recreate_index(DataType data_type,
             return (const void*)field_raw_data_ptr->get_element(id);
         };
         index_ = std::make_unique<index::VectorMemIndex<float16>>(
+            DataType::NONE,
             config_->GetIndexType(),
             config_->GetMetricType(),
             knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -94,6 +97,7 @@ VectorFieldIndexing::recreate_index(DataType data_type,
             return (const void*)field_raw_data_ptr->get_element(id);
         };
         index_ = std::make_unique<index::VectorMemIndex<bfloat16>>(
+            DataType::NONE,
             config_->GetIndexType(),
             config_->GetMetricType(),
             knowhere::Version::GetCurrentVersion().VersionNumber(),
@@ -292,8 +292,9 @@ class IndexingRecord {
                     index_meta_->HasFiled(field_id)) {
                     auto vec_field_meta =
                         index_meta_->GetFieldIndexMeta(field_id);
-                    //Disable growing index for flat
-                    if (!vec_field_meta.IsFlatIndex()) {
+                    //Disable growing index for flat and embedding list
+                    if (!vec_field_meta.IsFlatIndex() &&
+                        field_meta.get_data_type() != DataType::VECTOR_ARRAY) {
                         auto field_raw_data =
                             insert_record->get_data_base(field_id);
                         field_indexings_.try_emplace(
@@ -695,12 +695,19 @@ SegmentGrowingImpl::search_batch_pks(
 void
 SegmentGrowingImpl::vector_search(SearchInfo& search_info,
                                   const void* query_data,
+                                  const size_t* query_lims,
                                   int64_t query_count,
                                   Timestamp timestamp,
                                   const BitsetView& bitset,
                                   SearchResult& output) const {
-    query::SearchOnGrowing(
-        *this, search_info, query_data, query_count, timestamp, bitset, output);
+    query::SearchOnGrowing(*this,
+                           search_info,
+                           query_data,
+                           query_lims,
+                           query_count,
+                           timestamp,
+                           bitset,
+                           output);
 }

 std::unique_ptr<DataArray>
@@ -326,6 +326,7 @@ class SegmentGrowingImpl : public SegmentGrowing {
     void
     vector_search(SearchInfo& search_info,
                   const void* query_data,
+                  const size_t* query_lims,
                   int64_t query_count,
                   Timestamp timestamp,
                   const BitsetView& bitset,
@@ -394,9 +394,14 @@ class SegmentInternalInterface : public SegmentInterface {
         const std::string& nested_path) const override;

  public:
+    // `query_lims` is not null only for vector array (embedding list) search
+    // where it denotes the number of vectors in each embedding list. The length
+    // of `query_lims` is the number of queries in the search plus one (the first
+    // element in query_lims is 0).
     virtual void
     vector_search(SearchInfo& search_info,
                   const void* query_data,
+                  const size_t* query_lims,
                   int64_t query_count,
                   Timestamp timestamp,
                   const BitsetView& bitset,
@@ -35,6 +35,8 @@ struct LoadIndexInfo {
     int64_t segment_id;
     int64_t field_id;
     DataType field_type;
+    // The element type of the field. It's DataType::NONE if field_type is array/vector_array.
+    DataType element_type;
     bool enable_mmap;
     std::string mmap_dir_path;
     int64_t index_id;
@@ -668,7 +668,11 @@ MergeDataArray(std::vector<MergeBase>& merge_bases,
             auto obj = vector_array->mutable_int8_vector();
             obj->assign(data, dim * sizeof(int8));
         } else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
-            ThrowInfo(DataTypeInvalid, "VECTOR_ARRAY is not implemented");
+            auto data = src_field_data->vectors().vector_array();
+            auto obj = vector_array->mutable_vector_array();
+            obj->set_element_type(
+                proto::schema::DataType(field_meta.get_element_type()));
+            obj->CopyFrom(data);
         } else {
             ThrowInfo(DataTypeInvalid,
                       fmt::format("unsupported datatype {}", data_type));
@@ -15,7 +15,11 @@
 #include "knowhere/comp/knowhere_check.h"

 bool
-CheckVecIndexWithDataType(const char* index_type, enum CDataType data_type) {
+CheckVecIndexWithDataType(const char* index_type,
+                          enum CDataType data_type,
+                          bool is_emb_list_data) {
     return knowhere::KnowhereCheck::IndexTypeAndDataTypeCheck(
-        std::string(index_type), knowhere::VecType(data_type));
+        std::string(index_type),
+        knowhere::VecType(data_type),
+        is_emb_list_data);
 }
@@ -16,7 +16,9 @@ extern "C" {
 #endif
 #include "common/type_c.h"
 bool
-CheckVecIndexWithDataType(const char* index_type, enum CDataType data_type);
+CheckVecIndexWithDataType(const char* index_type,
+                          enum CDataType data_type,
+                          bool is_emb_list_data);

 #ifdef __cplusplus
 }
@@ -98,40 +98,6 @@ AppendIndexParam(CLoadIndexInfo c_load_index_info,
     }
 }

-CStatus
-AppendFieldInfo(CLoadIndexInfo c_load_index_info,
-                int64_t collection_id,
-                int64_t partition_id,
-                int64_t segment_id,
-                int64_t field_id,
-                enum CDataType field_type,
-                bool enable_mmap,
-                const char* mmap_dir_path) {
-    SCOPE_CGO_CALL_METRIC();
-
-    try {
-        auto load_index_info =
-            (milvus::segcore::LoadIndexInfo*)c_load_index_info;
-        load_index_info->collection_id = collection_id;
-        load_index_info->partition_id = partition_id;
-        load_index_info->segment_id = segment_id;
-        load_index_info->field_id = field_id;
-        load_index_info->field_type = milvus::DataType(field_type);
-        load_index_info->enable_mmap = enable_mmap;
-        load_index_info->mmap_dir_path = std::string(mmap_dir_path);
-
-        auto status = CStatus();
-        status.error_code = milvus::Success;
-        status.error_msg = "";
-        return status;
-    } catch (std::exception& e) {
-        auto status = CStatus();
-        status.error_code = milvus::UnexpectedError;
-        status.error_msg = strdup(e.what());
-        return status;
-    }
-}
-
 CStatus
 appendVecIndex(CLoadIndexInfo c_load_index_info, CBinarySet c_binary_set) {
     SCOPE_CGO_CALL_METRIC();
@@ -252,6 +218,7 @@ EstimateLoadIndexResource(CLoadIndexInfo c_load_index_info) {
         auto load_index_info =
             (milvus::segcore::LoadIndexInfo*)c_load_index_info;
         auto field_type = load_index_info->field_type;
+        auto element_type = load_index_info->element_type;
         auto& index_params = load_index_info->index_params;
         bool find_index_type =
             index_params.count("index_type") > 0 ? true : false;
@@ -261,6 +228,7 @@ EstimateLoadIndexResource(CLoadIndexInfo c_load_index_info) {
         LoadResourceRequest request =
             milvus::index::IndexFactory::GetInstance().IndexLoadResource(
                 field_type,
+                element_type,
                 load_index_info->index_engine_version,
                 load_index_info->index_size,
                 index_params,
@@ -581,6 +549,8 @@ FinishLoadIndexInfo(CLoadIndexInfo c_load_index_info,
         load_index_info->field_id = info_proto->field().fieldid();
         load_index_info->field_type =
             static_cast<milvus::DataType>(info_proto->field().data_type());
+        load_index_info->element_type = static_cast<milvus::DataType>(
+            info_proto->field().element_type());
         load_index_info->enable_mmap = info_proto->enable_mmap();
         load_index_info->mmap_dir_path = info_proto->mmap_dir_path();
         load_index_info->index_id = info_proto->indexid();
@@ -38,16 +38,6 @@ AppendIndexParam(CLoadIndexInfo c_load_index_info,
                  const char* index_key,
                  const char* index_value);

-CStatus
-AppendFieldInfo(CLoadIndexInfo c_load_index_info,
-                int64_t collection_id,
-                int64_t partition_id,
-                int64_t segment_id,
-                int64_t field_id,
-                enum CDataType field_type,
-                bool enable_mmap,
-                const char* mmap_dir_path);
-
 LoadResourceRequest
 EstimateLoadIndexResource(CLoadIndexInfo c_load_index_info);
@ -440,7 +440,10 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
|
||||
->set_element_type(
|
||||
proto::schema::DataType(field_meta.get_element_type()));
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
|
||||
ThrowInfo(NotImplemented, "VECTOR_ARRAY is not implemented");
|
||||
field_data->mutable_vectors()
|
||||
->mutable_vector_array()
|
||||
->set_element_type(
|
||||
proto::schema::DataType(field_meta.get_element_type()));
|
||||
}
|
||||
search_result_data->mutable_fields_data()->AddAllocated(
|
||||
field_data.release());
|
||||
|
||||
@ -155,7 +155,10 @@ StreamReducerHelper::AssembleMergedResult() {
|
||||
->set_element_type(
|
||||
proto::schema::DataType(field_meta.get_element_type()));
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
|
||||
ThrowInfo(NotImplemented, "VECTOR_ARRAY is not implemented");
|
||||
field_data->mutable_vectors()
|
||||
->mutable_vector_array()
|
||||
->set_element_type(
|
||||
proto::schema::DataType(field_meta.get_element_type()));
|
||||
}
|
||||
|
||||
new_merged_result->output_fields_data_[field_id] =
|
||||
@ -674,7 +677,10 @@ StreamReducerHelper::GetSearchResultDataSlice(int slice_index) {
|
||||
->set_element_type(
|
||||
proto::schema::DataType(field_meta.get_element_type()));
|
||||
} else if (field_meta.get_data_type() == DataType::VECTOR_ARRAY) {
|
||||
ThrowInfo(NotImplemented, "VECTOR_ARRAY is not implemented");
|
||||
field_data->mutable_vectors()
|
||||
->mutable_vector_array()
|
||||
->set_element_type(
|
||||
proto::schema::DataType(field_meta.get_element_type()));
|
||||
}
|
||||
search_result_data->mutable_fields_data()->AddAllocated(
|
||||
field_data.release());

@ -81,6 +81,7 @@ InterimSealedIndexTranslator::get_cells(

    if (vec_data_type_ == DataType::VECTOR_FLOAT) {
        vec_index = std::make_unique<index::VectorMemIndex<float>>(
            DataType::NONE,
            index_type_,
            metric_type_,
            knowhere::Version::GetCurrentVersion().VersionNumber(),
@ -88,6 +89,7 @@ InterimSealedIndexTranslator::get_cells(
            false);
    } else if (vec_data_type_ == DataType::VECTOR_FLOAT16) {
        vec_index = std::make_unique<index::VectorMemIndex<knowhere::fp16>>(
            DataType::NONE,
            index_type_,
            metric_type_,
            knowhere::Version::GetCurrentVersion().VersionNumber(),
@ -95,6 +97,7 @@ InterimSealedIndexTranslator::get_cells(
            false);
    } else if (vec_data_type_ == DataType::VECTOR_BFLOAT16) {
        vec_index = std::make_unique<index::VectorMemIndex<knowhere::bf16>>(
            DataType::NONE,
            index_type_,
            metric_type_,
            knowhere::Version::GetCurrentVersion().VersionNumber(),
@ -103,6 +106,7 @@ InterimSealedIndexTranslator::get_cells(
        }
    } else {
        vec_index = std::make_unique<index::VectorMemIndex<float>>(
            DataType::NONE,
            index_type_,
            metric_type_,
            knowhere::Version::GetCurrentVersion().VersionNumber(),

@ -22,6 +22,7 @@ SealedIndexTranslator::SealedIndexTranslator(
      index_load_info_({load_index_info->enable_mmap,
                        load_index_info->mmap_dir_path,
                        load_index_info->field_type,
                        load_index_info->element_type,
                        load_index_info->index_params,
                        load_index_info->index_size,
                        load_index_info->index_engine_version,
@ -54,6 +55,7 @@ SealedIndexTranslator::estimated_byte_size_of_cell(
    LoadResourceRequest request =
        milvus::index::IndexFactory::GetInstance().IndexLoadResource(
            index_load_info_.field_type,
            index_load_info_.element_type,
            index_load_info_.index_engine_version,
            index_load_info_.index_size,
            index_load_info_.index_params,

@ -45,6 +45,7 @@ class SealedIndexTranslator
        bool enable_mmap;
        std::string mmap_dir_path;
        DataType field_type;
        DataType element_type;
        std::map<std::string, std::string> index_params;
        int64_t index_size;
        int64_t index_engine_version;

@ -13,6 +13,7 @@ V1SealedIndexTranslator::V1SealedIndexTranslator(
          load_index_info->enable_mmap,
          load_index_info->mmap_dir_path,
          load_index_info->field_type,
          load_index_info->element_type,
          load_index_info->index_params,
          load_index_info->index_files,
          load_index_info->index_size,

@ -44,6 +44,7 @@ class V1SealedIndexTranslator : public Translator<milvus::index::IndexBase> {
        bool enable_mmap;
        std::string mmap_dir_path;
        DataType field_type;
        DataType element_type;
        std::map<std::string, std::string> index_params;
        std::vector<std::string> index_files;
        int64_t index_size;

@ -24,6 +24,7 @@
CStatus
ValidateIndexParams(const char* index_type,
                    enum CDataType data_type,
                    enum CDataType element_type,
                    const uint8_t* serialized_index_params,
                    const uint64_t length) {
    try {
@ -44,45 +45,64 @@ ValidateIndexParams(const char* index_type,

        knowhere::Status status;
        std::string error_msg;
        if (dataType == milvus::DataType::VECTOR_BINARY) {
            status = knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(
                index_type,
                knowhere::Version::GetCurrentVersion().VersionNumber(),
                json,
                error_msg);
        } else if (dataType == milvus::DataType::VECTOR_FLOAT) {
            status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
                index_type,
                knowhere::Version::GetCurrentVersion().VersionNumber(),
                json,
                error_msg);
        } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) {
            status = knowhere::IndexStaticFaced<knowhere::bf16>::ConfigCheck(
                index_type,
                knowhere::Version::GetCurrentVersion().VersionNumber(),
                json,
                error_msg);
        } else if (dataType == milvus::DataType::VECTOR_FLOAT16) {
            status = knowhere::IndexStaticFaced<knowhere::fp16>::ConfigCheck(
                index_type,
                knowhere::Version::GetCurrentVersion().VersionNumber(),
                json,
                error_msg);
        } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) {
            status = knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
                index_type,
                knowhere::Version::GetCurrentVersion().VersionNumber(),
                json,
                error_msg);
        } else if (dataType == milvus::DataType::VECTOR_INT8) {
            status = knowhere::IndexStaticFaced<knowhere::int8>::ConfigCheck(
                index_type,
                knowhere::Version::GetCurrentVersion().VersionNumber(),
                json,
                error_msg);
        auto check_leaf_type = [&index_type, &json, &error_msg, &status](
                                   milvus::DataType dataType) {
            if (dataType == milvus::DataType::VECTOR_BINARY) {
                status =
                    knowhere::IndexStaticFaced<knowhere::bin1>::ConfigCheck(
                        index_type,
                        knowhere::Version::GetCurrentVersion().VersionNumber(),
                        json,
                        error_msg);
            } else if (dataType == milvus::DataType::VECTOR_FLOAT) {
                status =
                    knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
                        index_type,
                        knowhere::Version::GetCurrentVersion().VersionNumber(),
                        json,
                        error_msg);
            } else if (dataType == milvus::DataType::VECTOR_BFLOAT16) {
                status =
                    knowhere::IndexStaticFaced<knowhere::bf16>::ConfigCheck(
                        index_type,
                        knowhere::Version::GetCurrentVersion().VersionNumber(),
                        json,
                        error_msg);
            } else if (dataType == milvus::DataType::VECTOR_FLOAT16) {
                status =
                    knowhere::IndexStaticFaced<knowhere::fp16>::ConfigCheck(
                        index_type,
                        knowhere::Version::GetCurrentVersion().VersionNumber(),
                        json,
                        error_msg);
            } else if (dataType == milvus::DataType::VECTOR_SPARSE_FLOAT) {
                status =
                    knowhere::IndexStaticFaced<knowhere::fp32>::ConfigCheck(
                        index_type,
                        knowhere::Version::GetCurrentVersion().VersionNumber(),
                        json,
                        error_msg);
            } else if (dataType == milvus::DataType::VECTOR_INT8) {
                status =
                    knowhere::IndexStaticFaced<knowhere::int8>::ConfigCheck(
                        index_type,
                        knowhere::Version::GetCurrentVersion().VersionNumber(),
                        json,
                        error_msg);
            } else {
                status = knowhere::Status::invalid_args;
            }
        };

        if (dataType == milvus::DataType::VECTOR_ARRAY) {
            milvus::DataType elementType(
                static_cast<milvus::DataType>(element_type));

            check_leaf_type(elementType);
        } else {
            status = knowhere::Status::invalid_args;
            check_leaf_type(dataType);
        }

        CStatus cStatus;
        if (status == knowhere::Status::success) {
            cStatus.error_code = milvus::Success;
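
The refactor above funnels every leaf vector type through one lambda so that a VECTOR_ARRAY field can reuse the same checks on its element type. A compact Python sketch of the dispatch, with a made-up validator standing in for knowhere's ConfigCheck:

```python
# Dispatch sketch: VECTOR_ARRAY validates its element type, everything else
# validates itself. `config_check` is a stand-in for knowhere's ConfigCheck.
LEAF_TYPES = {"BINARY", "FLOAT", "FLOAT16", "BFLOAT16", "SPARSE_FLOAT", "INT8"}

def config_check(index_type: str, leaf_type: str, params: dict) -> bool:
    return leaf_type in LEAF_TYPES          # assumed always-pass validator

def validate_index_params(index_type, data_type, element_type, params):
    leaf = element_type if data_type == "VECTOR_ARRAY" else data_type
    return config_check(index_type, leaf, params)

assert validate_index_params("EMB_LIST_HNSW", "VECTOR_ARRAY", "FLOAT", {})
assert not validate_index_params("HNSW", "VECTOR_ARRAY", "VECTOR_ARRAY", {})
```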

@ -20,6 +20,7 @@ extern "C" {
CStatus
ValidateIndexParams(const char* index_type,
                    enum CDataType data_type,
                    enum CDataType element_type,
                    const uint8_t* index_params,
                    const uint64_t length);

@ -141,7 +141,8 @@ class TestFloatSearchBruteForce : public ::testing::Test {
                                        search_info,
                                        index_info,
                                        bitset_view,
                                        DataType::VECTOR_FLOAT);
                                        DataType::VECTOR_FLOAT,
                                        DataType::NONE);
        for (int i = 0; i < nq; i++) {
            auto ref = Ref(base.data(),
                           query.data() + i * dim,

@ -113,7 +113,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
                                 search_info,
                                 index_info,
                                 bitset_view,
                                 DataType::VECTOR_SPARSE_FLOAT));
                                 DataType::VECTOR_SPARSE_FLOAT,
                                 DataType::NONE));
            return;
        }
        auto result = BruteForceSearch(query_dataset,
@ -121,7 +122,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
                                       search_info,
                                       index_info,
                                       bitset_view,
                                       DataType::VECTOR_SPARSE_FLOAT);
                                       DataType::VECTOR_SPARSE_FLOAT,
                                       DataType::NONE);
        for (int i = 0; i < nq; i++) {
            auto ref = SearchRef(base.get(), *(query.get() + i), nb, topk);
            auto ans = result.get_seg_offsets() + i * topk;
@ -135,7 +137,8 @@ class TestSparseFloatSearchBruteForce : public ::testing::Test {
                                       search_info,
                                       index_info,
                                       bitset_view,
                                       DataType::VECTOR_SPARSE_FLOAT);
                                       DataType::VECTOR_SPARSE_FLOAT,
                                       DataType::NONE);
        for (int i = 0; i < nq; i++) {
            auto ref = RangeSearchRef(
                base.get(), *(query.get() + i), nb, 0.1, 0.5, topk);

@ -1936,6 +1936,7 @@ TEST(CApiTest, LoadIndexSearch) {
    auto& index_params = load_index_info.index_params;
    index_params["index_type"] = knowhere::IndexEnum::INDEX_FAISS_IVFSQ8;
    auto index = std::make_unique<VectorMemIndex<float>>(
        DataType::NONE,
        index_params["index_type"],
        knowhere::metric::L2,
        knowhere::Version::GetCurrentVersion().VersionNumber());

@ -128,6 +128,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
        search_info,
        index_info,
        query_data,
        nullptr,
        1,
        total_row_count,
        bv,
@ -153,6 +154,7 @@ TEST(test_chunk_segment, TestSearchOnSealed) {
        search_info,
        index_info,
        query_data,
        nullptr,
        1,
        total_row_count,
        bv,

@ -16674,6 +16674,7 @@ TEST(JsonIndexTest, TestJsonNotEqualExpr) {
    file_manager_ctx.fieldDataMeta.field_schema.set_data_type(
        milvus::proto::schema::JSON);
    file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get());
    file_manager_ctx.fieldDataMeta.field_id = json_fid.get();

    auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex(
        index::CreateIndexInfo{
@ -16784,6 +16785,7 @@ TEST_P(JsonIndexExistsTest, TestExistsExpr) {
        milvus::proto::schema::JSON);
    file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get());
    file_manager_ctx.fieldDataMeta.field_schema.set_nullable(true);
    file_manager_ctx.fieldDataMeta.field_id = json_fid.get();
    auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex(
        index::CreateIndexInfo{
            .index_type = index::INVERTED_INDEX_TYPE,
@ -16971,6 +16973,7 @@ TEST_P(JsonIndexBinaryExprTest, TestBinaryRangeExpr) {
    file_manager_ctx.fieldDataMeta.field_schema.set_data_type(
        milvus::proto::schema::JSON);
    file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(json_fid.get());
    file_manager_ctx.fieldDataMeta.field_id = json_fid.get();

    auto inv_index = index::IndexFactory::GetInstance().CreateJsonIndex(
        index::CreateIndexInfo{

@ -541,3 +541,87 @@ TEST(GrowingTest, LoadVectorArrayData) {
        verify_float_vectors(arrow_array, expected_array);
    }
}

TEST(GrowingTest, SearchVectorArray) {
    using namespace milvus::query;

    auto schema = std::make_shared<Schema>();
    auto metric_type = knowhere::metric::MAX_SIM;

    // Add fields
    auto int64_field = schema->AddDebugField("int64", DataType::INT64);
    auto array_vec = schema->AddDebugVectorArrayField(
        "array_vec", DataType::VECTOR_FLOAT, 128, metric_type);
    schema->set_primary_field_id(int64_field);

    // Configure segment
    auto config = SegcoreConfig::default_config();
    config.set_chunk_rows(1024);
    config.set_enable_interim_segment_index(true);

    std::map<std::string, std::string> index_params = {
        {"index_type", knowhere::IndexEnum::INDEX_EMB_LIST_HNSW},
        {"metric_type", metric_type},
        {"nlist", "128"}};
    std::map<std::string, std::string> type_params = {{"dim", "128"}};
    FieldIndexMeta fieldIndexMeta(
        array_vec, std::move(index_params), std::move(type_params));
    std::map<FieldId, FieldIndexMeta> fieldMap = {{array_vec, fieldIndexMeta}};

    IndexMetaPtr metaPtr =
        std::make_shared<CollectionIndexMeta>(100000, std::move(fieldMap));
    auto segment = CreateGrowingSegment(schema, metaPtr, 1, config);
    auto segmentImplPtr = dynamic_cast<SegmentGrowingImpl*>(segment.get());

    // Insert data
    int64_t N = 100;
    uint64_t seed = 42;
    int emb_list_len = 5;  // Each row contains 5 vectors
    auto dataset = DataGen(schema, N, seed, 0, 1, emb_list_len);

    auto offset = 0;
    segment->Insert(offset,
                    N,
                    dataset.row_ids_.data(),
                    dataset.timestamps_.data(),
                    dataset.raw_);

    // Prepare search query
    int vec_num = 10;  // Total number of query vectors
    int dim = 128;
    std::vector<float> query_vec = generate_float_vector(vec_num, dim);

    // Create query dataset with lims for VectorArray: lims are boundary
    // offsets into the flat query buffer, so [0, 3, 10] means the first
    // query owns vectors [0, 3) and the second owns vectors [3, 10).
    std::vector<size_t> query_vec_lims;
    query_vec_lims.push_back(0);
    query_vec_lims.push_back(3);   // first query has 3 vectors
    query_vec_lims.push_back(10);  // second query has 7 vectors

    // Create search plan
    const char* raw_plan = R"(vector_anns: <
                                field_id: 101
                                query_info: <
                                  topk: 5
                                  round_decimal: 3
                                  metric_type: "MAX_SIM"
                                  search_params: "{\"nprobe\": 10}"
                                >
                                placeholder_tag: "$0"
                              >)";

    auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
    auto plan =
        CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());

    // Use CreatePlaceholderGroupFromBlob for VectorArray
    auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
        vec_num, dim, query_vec.data(), query_vec_lims);
    auto ph_group =
        ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());

    // Execute search
    Timestamp timestamp = 10000000;
    auto sr = segment->Search(plan.get(), ph_group.get(), timestamp);
    auto sr_parsed = SearchResultToJson(*sr);
    std::cout << sr_parsed.dump(1) << std::endl;
}
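
Lims are the one new moving part in this test: a single flat buffer of query vectors plus boundary offsets saying which slice belongs to which embedding-list query. A small Python sketch of that partitioning, with plain lists standing in for the float buffer:

```python
# Partition a flat query buffer by lims boundaries: query i owns the
# vectors in [lims[i], lims[i+1]). Mirrors the test's lims = [0, 3, 10].
def split_by_lims(flat_vectors, lims):
    assert lims[0] == 0 and lims[-1] == len(flat_vectors)
    return [flat_vectors[lims[i]:lims[i + 1]] for i in range(len(lims) - 1)]

vectors = [[float(i)] * 4 for i in range(10)]   # 10 toy 4-dim vectors
queries = split_by_lims(vectors, [0, 3, 10])
assert [len(q) for q in queries] == [3, 7]      # two queries: 3 and 7 vectors
```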

@ -360,6 +360,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) {

    if (data_type == DataType::VECTOR_FLOAT) {
        auto index = std::make_unique<milvus::index::VectorMemIndex<float>>(
            DataType::NONE,
            index_type,
            metric_type,
            knowhere::Version::GetCurrentVersion().VersionNumber(),
@ -375,6 +376,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) {
        EXPECT_EQ(index->Count(), (add_cont + 1) * N);
    } else if (data_type == DataType::VECTOR_FLOAT16) {
        auto index = std::make_unique<milvus::index::VectorMemIndex<float16>>(
            DataType::NONE,
            index_type,
            metric_type,
            knowhere::Version::GetCurrentVersion().VersionNumber(),
@ -391,6 +393,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) {
        EXPECT_EQ(index->Count(), (add_cont + 1) * N);
    } else if (data_type == DataType::VECTOR_BFLOAT16) {
        auto index = std::make_unique<milvus::index::VectorMemIndex<bfloat16>>(
            DataType::NONE,
            index_type,
            metric_type,
            knowhere::Version::GetCurrentVersion().VersionNumber(),
@ -407,6 +410,7 @@ TEST_P(GrowingIndexTest, AddWithoutBuildPool) {
        EXPECT_EQ(index->Count(), (add_cont + 1) * N);
    } else if (is_sparse) {
        auto index = std::make_unique<milvus::index::VectorMemIndex<float>>(
            DataType::NONE,
            index_type,
            metric_type,
            knowhere::Version::GetCurrentVersion().VersionNumber(),

@ -184,7 +184,8 @@ TEST(Indexing, BinaryBruteForce) {
                                        search_info,
                                        index_info,
                                        nullptr,
                                        DataType::VECTOR_BINARY);
                                        DataType::VECTOR_BINARY,
                                        DataType::NONE);

    SearchResult sr;
    sr.total_nq_ = num_queries;

@ -564,6 +564,7 @@ class JsonFlatIndexExprTest : public ::testing::Test {
        file_manager_ctx.fieldDataMeta.field_schema.set_fieldid(
            json_fid_.get());
        file_manager_ctx.fieldDataMeta.field_schema.set_nullable(true);
        file_manager_ctx.fieldDataMeta.field_id = json_fid_.get();
        auto index = index::IndexFactory::GetInstance().CreateJsonIndex(
            index::CreateIndexInfo{
                .index_type = index::INVERTED_INDEX_TYPE,

@ -45,13 +45,13 @@ test_ngram_with_data(const boost::container::vector<std::string>& data,
    auto schema = std::make_shared<Schema>();
    auto field_id = schema->AddDebugField("ngram", DataType::VARCHAR);

    auto field_meta = gen_field_meta(collection_id,
                                     partition_id,
                                     segment_id,
                                     field_id.get(),
                                     DataType::VARCHAR,
                                     DataType::NONE,
                                     false);
    auto field_meta = milvus::segcore::gen_field_meta(collection_id,
                                                      partition_id,
                                                      segment_id,
                                                      field_id.get(),
                                                      DataType::VARCHAR,
                                                      DataType::NONE,
                                                      false);
    auto index_meta = gen_index_meta(
        segment_id, field_id.get(), index_build_id, index_version);

@ -19,6 +19,7 @@
#include "knowhere/version.h"
#include "storage/RemoteChunkManagerSingleton.h"
#include "storage/Util.h"
#include "common/VectorArray.h"

#include "test_utils/cachinglayer_test_utils.h"
#include "test_utils/DataGen.h"
@ -2333,3 +2334,257 @@ TEST(Sealed, QueryVectorArrayAllFields) {
    EXPECT_EQ(int64_result->valid_data_size(), 0);
    EXPECT_EQ(array_float_vector_result->valid_data_size(), 0);
}

TEST(Sealed, SearchVectorArray) {
    int64_t collection_id = 1;
    int64_t partition_id = 2;
    int64_t segment_id = 3;
    int64_t index_build_id = 4000;
    int64_t index_version = 4000;
    int64_t index_id = 5000;

    auto schema = std::make_shared<Schema>();
    auto metric_type = knowhere::metric::L2;
    auto int64_field = schema->AddDebugField("int64", DataType::INT64);
    auto array_vec = schema->AddDebugVectorArrayField(
        "array_vec", DataType::VECTOR_FLOAT, 128, metric_type);
    schema->set_primary_field_id(int64_field);

    auto field_meta = milvus::segcore::gen_field_meta(collection_id,
                                                      partition_id,
                                                      segment_id,
                                                      array_vec.get(),
                                                      DataType::VECTOR_ARRAY,
                                                      DataType::VECTOR_FLOAT,
                                                      false);
    auto index_meta = gen_index_meta(
        segment_id, array_vec.get(), index_build_id, index_version);

    std::map<FieldId, FieldIndexMeta> fieldMap{};
    IndexMetaPtr metaPtr =
        std::make_shared<CollectionIndexMeta>(100000, std::move(fieldMap));

    int64_t dataset_size = 1000;
    int64_t dim = 128;
    auto emb_list_len = 10;
    auto dataset = DataGen(schema, dataset_size, 42, 0, 1, emb_list_len);

    // create field data
    std::string root_path = "/tmp/test-vector-array/";
    auto storage_config = gen_local_storage_config(root_path);
    auto cm = CreateChunkManager(storage_config);
    auto vec_array_col = dataset.get_col<VectorFieldProto>(array_vec);
    std::vector<milvus::VectorArray> vector_arrays;
    for (auto& v : vec_array_col) {
        vector_arrays.push_back(milvus::VectorArray(v));
    }
    auto field_data = storage::CreateFieldData(DataType::VECTOR_ARRAY, false);
    field_data->FillFieldData(vector_arrays.data(), vector_arrays.size());

    // create sealed segment
    auto segment = CreateSealedSegment(schema);
    auto field_data_info = PrepareSingleFieldInsertBinlog(collection_id,
                                                          partition_id,
                                                          segment_id,
                                                          array_vec.get(),
                                                          {field_data},
                                                          cm);
    segment->LoadFieldData(field_data_info);

    // serialize bin logs
    auto payload_reader =
        std::make_shared<milvus::storage::PayloadReader>(field_data);
    storage::InsertData insert_data(payload_reader);
    insert_data.SetFieldDataMeta(field_meta);
    insert_data.SetTimestamps(0, 100);

    auto serialized_bytes = insert_data.Serialize(storage::Remote);

    auto get_binlog_path = [=](int64_t log_id) {
        return fmt::format("{}/{}/{}/{}/{}",
                           collection_id,
                           partition_id,
                           segment_id,
                           array_vec.get(),
                           log_id);
    };

    auto log_path = get_binlog_path(0);

    auto cm_w = ChunkManagerWrapper(cm);
    cm_w.Write(log_path, serialized_bytes.data(), serialized_bytes.size());

    storage::FileManagerContext ctx(field_meta, index_meta, cm);
    std::vector<std::string> index_files;

    // create index
    milvus::index::CreateIndexInfo create_index_info;
    create_index_info.field_type = DataType::VECTOR_ARRAY;
    create_index_info.metric_type = knowhere::metric::MAX_SIM;
    create_index_info.index_type = knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
    create_index_info.index_engine_version =
        knowhere::Version::GetCurrentVersion().VersionNumber();

    auto emb_list_hnsw_index =
        milvus::index::IndexFactory::GetInstance().CreateIndex(
            create_index_info,
            storage::FileManagerContext(field_meta, index_meta, cm));

    // build index
    Config config;
    config[milvus::index::INDEX_TYPE] =
        knowhere::IndexEnum::INDEX_EMB_LIST_HNSW;
    config[INSERT_FILES_KEY] = std::vector<std::string>{log_path};
    config[knowhere::meta::METRIC_TYPE] = create_index_info.metric_type;
    config[knowhere::indexparam::M] = "16";
    config[knowhere::indexparam::EF] = "10";
    config[DIM_KEY] = dim;
    emb_list_hnsw_index->Build(config);

    auto vec_index =
        dynamic_cast<milvus::index::VectorIndex*>(emb_list_hnsw_index.get());
    EXPECT_EQ(vec_index->Count(), dataset_size * emb_list_len);
    EXPECT_EQ(vec_index->GetDim(), dim);

    // search
    auto vec_num = 10;
    std::vector<float> query_vec = generate_float_vector(vec_num, dim);
    auto query_dataset = knowhere::GenDataSet(vec_num, dim, query_vec.data());
    std::vector<size_t> query_vec_lims;
    query_vec_lims.push_back(0);
    query_vec_lims.push_back(3);
    query_vec_lims.push_back(10);
    query_dataset->SetLims(query_vec_lims.data());

    auto search_conf = knowhere::Json{{knowhere::indexparam::NPROBE, 10}};
    milvus::SearchInfo searchInfo;
    searchInfo.topk_ = 5;
    searchInfo.metric_type_ = knowhere::metric::L2;
    searchInfo.search_params_ = search_conf;
    SearchResult result;
    vec_index->Query(query_dataset, searchInfo, nullptr, result);
    auto ref_result = SearchResultToJson(result);
    std::cout << ref_result.dump(1) << std::endl;
    EXPECT_EQ(result.total_nq_, 2);
    EXPECT_EQ(result.distances_.size(), 2 * searchInfo.topk_);

    // create sealed segment
    auto sealed_segment = CreateSealedWithFieldDataLoaded(schema, dataset);

    // brute force search
    {
        const char* raw_plan = R"(vector_anns: <
                                    field_id: 101
                                    query_info: <
                                      topk: 5
                                      round_decimal: 3
                                      metric_type: "MAX_SIM"
                                      search_params: "{\"nprobe\": 10}"
                                    >
                                    placeholder_tag: "$0"
                                  >)";
        auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
        auto plan =
            CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
        auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
            vec_num, dim, query_vec.data(), query_vec_lims);
        auto ph_group =
            ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
        Timestamp timestamp = 1000000;
        std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};

        auto sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
        auto sr_parsed = SearchResultToJson(*sr);
        std::cout << sr_parsed.dump(1) << std::endl;
    }

    // // brute force search with iterative filter
    // {
    //     auto [min, max] =
    //         std::minmax_element(int_values.begin(), int_values.end());
    //     auto min_val = *min;
    //     auto max_val = *max;

    //     auto raw_plan = fmt::format(R"(vector_anns: <
    //                                     field_id: 101
    //                                     predicates: <
    //                                       binary_range_expr: <
    //                                         column_info: <
    //                                           field_id: 100
    //                                           data_type: Int64
    //                                         >
    //                                         lower_inclusive: true
    //                                         upper_inclusive: true
    //                                         lower_value: <
    //                                           int64_val: {}
    //                                         >
    //                                         upper_value: <
    //                                           int64_val: {}
    //                                         >
    //                                       >
    //                                     >
    //                                     query_info: <
    //                                       topk: 5
    //                                       round_decimal: 3
    //                                       metric_type: "MAX_SIM"
    //                                       hints: "iterative_filter"
    //                                       search_params: "{{\"nprobe\": 10}}"
    //                                     >
    //                                     placeholder_tag: "$0"
    //                                   >)",
    //                                 min_val,
    //                                 max_val);
    //     auto plan_str = translate_text_plan_to_binary_plan(raw_plan.c_str());
    //     auto plan =
    //         CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
    //     auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
    //         vec_num, dim, query_vec.data(), query_vec_lims);
    //     auto ph_group =
    //         ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
    //     Timestamp timestamp = 1000000;
    //     std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};

    //     auto sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
    //     auto sr_parsed = SearchResultToJson(*sr);
    //     std::cout << sr_parsed.dump(1) << std::endl;
    // }

    // search with index
    {
        LoadIndexInfo load_info;
        load_info.field_id = array_vec.get();
        load_info.field_type = DataType::VECTOR_ARRAY;
        load_info.element_type = DataType::VECTOR_FLOAT;
        load_info.index_params = GenIndexParams(emb_list_hnsw_index.get());
        load_info.cache_index =
            CreateTestCacheIndex("test", std::move(emb_list_hnsw_index));
        load_info.index_params["metric_type"] = knowhere::metric::MAX_SIM;

        sealed_segment->DropFieldData(array_vec);
        sealed_segment->LoadIndex(load_info);

        const char* raw_plan = R"(vector_anns: <
                                    field_id: 101
                                    query_info: <
                                      topk: 5
                                      round_decimal: 3
                                      metric_type: "MAX_SIM"
                                      search_params: "{\"nprobe\": 10}"
                                    >
                                    placeholder_tag: "$0"
                                  >)";
        auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
        auto plan =
            CreateSearchPlanByExpr(schema, plan_str.data(), plan_str.size());
        auto ph_group_raw = CreatePlaceholderGroupFromBlob<EmbListFloatVector>(
            vec_num, dim, query_vec.data(), query_vec_lims);
        auto ph_group =
            ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
        Timestamp timestamp = 1000000;
        std::vector<const PlaceholderGroup*> ph_group_arr = {ph_group.get()};

        auto sr = sealed_segment->Search(plan.get(), ph_group.get(), timestamp);
        auto sr_parsed = SearchResultToJson(*sr);
        std::cout << sr_parsed.dump(1) << std::endl;
    }
}
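
The assertions above (total_nq_ of 2 for 10 query vectors) follow from the lims boundaries and MAX_SIM semantics. A toy Python rendering of the scoring, under the assumption that MAX_SIM here is the usual sum over query vectors of the best inner product against the stored list:

```python
# Toy MAX_SIM between one query embedding list and one stored embedding list,
# assuming the metric sums, over query vectors, the best inner product
# against any vector in the stored list.
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def max_sim(query_list, doc_list):
    return sum(max(dot(q, d) for d in doc_list) for q in query_list)

q = [[1.0, 0.0], [0.0, 1.0]]          # one query = 2 vectors
doc = [[0.9, 0.1], [0.2, 0.8]]        # one stored row = 2 vectors
print(max_sim(q, doc))                # 0.9 + 0.8 = 1.7
```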

@ -1562,7 +1562,8 @@ TEST(AlwaysTrueStringPlan, SearchWithOutputFields) {
                                      search_info,
                                      index_info,
                                      nullptr,
                                      DataType::VECTOR_FLOAT);
                                      DataType::VECTOR_FLOAT,
                                      DataType::NONE);

    auto sr = segment->Search(plan.get(), ph_group.get(), MAX_TIMESTAMP);
    segment->FillPrimaryKeys(plan.get(), *sr);

@ -1004,7 +1004,10 @@ CreatePlaceholderGroup(int64_t num_queries, int dim, int64_t seed = 42) {

template <class TraitType = milvus::FloatVector>
inline auto
CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const void* src) {
CreatePlaceholderGroupFromBlob(int64_t num_queries,
                               int dim,
                               const void* src,
                               std::vector<size_t> offsets = {}) {
    if (std::is_same_v<TraitType, milvus::BinaryVector>) {
        assert(dim % 8 == 0);
    }
@ -1017,12 +1020,27 @@ CreatePlaceholderGroupFromBlob(int64_t num_queries, int dim, const void* src) {
    value->set_type(TraitType::placeholder_type);
    int64_t src_index = 0;

    for (int i = 0; i < num_queries; ++i) {
        std::vector<elem_type> vec;
        for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
            vec.push_back(((elem_type*)src)[src_index++]);
    if (offsets.empty()) {
        for (int i = 0; i < num_queries; ++i) {
            std::vector<elem_type> vec;
            for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
                vec.push_back(((elem_type*)src)[src_index++]);
            }
            value->add_values(vec.data(), vec.size() * sizeof(elem_type));
        }
    } else {
        assert(offsets.back() == num_queries);
        for (int i = 0; i < offsets.size() - 1; i++) {
            auto start = offsets[i];
            auto end = offsets[i + 1];
            std::vector<elem_type> vec;
            for (int j = start; j < end; j++) {
                for (int d = 0; d < dim / TraitType::dim_factor; ++d) {
                    vec.push_back(((elem_type*)src)[src_index++]);
                }
            }
            value->add_values(vec.data(), vec.size() * sizeof(elem_type));
        }
        value->add_values(vec.data(), vec.size() * sizeof(elem_type));
    }
    return raw_group;
}
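
The helper above is the test-side mirror of lims: with no offsets every query is a single vector, and with boundary offsets each placeholder value packs a contiguous run of vectors. A rough Python equivalent of the packing loop, with byte payloads standing in for the protobuf values:

```python
# Pack a flat [num_queries x dim] float blob into placeholder payloads.
# With offsets = [] each vector is its own query; with boundary offsets
# (e.g. [0, 3, 10]) each payload holds one embedding list.
import struct

def pack_placeholders(flat, dim, offsets=None):
    payloads = []
    if not offsets:
        groups = [(i, i + 1) for i in range(len(flat) // dim)]
    else:
        groups = list(zip(offsets, offsets[1:]))
    for start, end in groups:
        vals = flat[start * dim:end * dim]
        payloads.append(struct.pack(f"{len(vals)}f", *vals))
    return payloads

flat = [0.0] * (10 * 4)                      # 10 vectors, dim 4
assert len(pack_placeholders(flat, 4)) == 10
assert len(pack_placeholders(flat, 4, [0, 3, 10])) == 2
```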

@ -1362,6 +1380,7 @@ GenVecIndexing(int64_t N,
    milvus::storage::FileManagerContext file_manager_context(
        field_data_meta, index_meta, chunk_manager);
    auto indexing = std::make_unique<index::VectorMemIndex<float>>(
        DataType::NONE,
        index_type,
        knowhere::metric::L2,
        knowhere::Version::GetCurrentVersion().VersionNumber(),
@ -1631,4 +1650,29 @@ GenChunkedSegmentTestSchema(bool pk_is_string) {
    return schema;
}

inline std::vector<float>
generate_float_vector(int64_t N, int64_t dim) {
    auto seed = 42;
    auto offset = 0;
    std::vector<float> final(dim * N);
    for (int n = 0; n < N; ++n) {
        std::vector<float> data(dim);
        float sum = 0;

        std::default_random_engine er2(seed + n);
        std::normal_distribution<> distr2(0, 1);
        for (auto& x : data) {
            x = distr2(er2) + offset++;
            sum += x * x;
        }
        sum = sqrt(sum);
        for (auto& x : data) {
            x /= sum;
        }

        std::copy(data.begin(), data.end(), final.begin() + dim * n);
    }
    return final;
}
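
For reference, the generator above just emits unit-norm vectors so distances stay well scaled; a one-screen numpy sketch of the same idea (numpy is an assumption of this sketch, not a dependency of the test suite, and the per-component offset trick is dropped):

```python
# Unit-norm random vectors, equivalent in spirit to generate_float_vector.
import numpy as np

def generate_float_vectors(n: int, dim: int, seed: int = 42) -> np.ndarray:
    rng = np.random.default_rng(seed)
    vecs = rng.normal(size=(n, dim)).astype(np.float32)
    return vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

v = generate_float_vectors(10, 128)
assert np.allclose(np.linalg.norm(v, axis=1), 1.0, atol=1e-5)
```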

} // namespace milvus::segcore

@ -31,6 +31,7 @@
#include "segcore/segment_c.h"
#include "futures/Future.h"
#include "futures/future_c.h"
#include "segcore/load_index_c.h"
#include "DataGen.h"
#include "PbHelper.h"
#include "indexbuilder_test_utils.h"
@ -38,6 +39,39 @@
using namespace milvus;
using namespace milvus::segcore;

// Test utility function for AppendFieldInfo
inline CStatus
AppendFieldInfo(CLoadIndexInfo c_load_index_info,
                int64_t collection_id,
                int64_t partition_id,
                int64_t segment_id,
                int64_t field_id,
                enum CDataType field_type,
                bool enable_mmap,
                const char* mmap_dir_path) {
    try {
        auto load_index_info =
            (milvus::segcore::LoadIndexInfo*)c_load_index_info;
        load_index_info->collection_id = collection_id;
        load_index_info->partition_id = partition_id;
        load_index_info->segment_id = segment_id;
        load_index_info->field_id = field_id;
        load_index_info->field_type = milvus::DataType(field_type);
        load_index_info->enable_mmap = enable_mmap;
        load_index_info->mmap_dir_path = std::string(mmap_dir_path);

        auto status = CStatus();
        status.error_code = milvus::Success;
        status.error_msg = "";
        return status;
    } catch (std::exception& e) {
        auto status = CStatus();
        status.error_code = milvus::UnexpectedError;
        status.error_msg = strdup(e.what());
        return status;
    }
}

namespace {

std::string

@ -109,7 +109,7 @@ TEST(VectorArray, TestConstructVectorArray) {
        field_float_vector_array.mutable_float_vector()->mutable_data()->Add(
            data.begin(), data.end());

        auto float_vector_array = VectorArray(field_float_vector_array);
        auto float_vector_array = milvus::VectorArray(field_float_vector_array);
        ASSERT_EQ(float_vector_array.length(), N);
        ASSERT_EQ(float_vector_array.dim(), dim);
        ASSERT_EQ(float_vector_array.get_element_type(), DataType::VECTOR_FLOAT);
@ -117,16 +117,16 @@ TEST(VectorArray, TestConstructVectorArray) {

        ASSERT_TRUE(float_vector_array.is_same_array(field_float_vector_array));

        auto float_vector_array_tmp = VectorArray(float_vector_array);
        auto float_vector_array_tmp = milvus::VectorArray(float_vector_array);

        ASSERT_TRUE(float_vector_array_tmp.is_same_array(field_float_vector_array));

        auto float_vector_array_view =
            VectorArrayView(const_cast<char*>(float_vector_array.data()),
                            float_vector_array.length(),
                            float_vector_array.dim(),
                            float_vector_array.byte_size(),
                            float_vector_array.get_element_type());
            milvus::VectorArrayView(const_cast<char*>(float_vector_array.data()),
                                    float_vector_array.length(),
                                    float_vector_array.dim(),
                                    float_vector_array.byte_size(),
                                    float_vector_array.get_element_type());

        ASSERT_TRUE(
            float_vector_array_view.is_same_array(field_float_vector_array));

@ -70,13 +70,29 @@ func (s *Server) getSchema(ctx context.Context, collID int64) (*schemapb.Collect
	return resp.GetSchema(), nil
}

func isJsonField(schema *schemapb.CollectionSchema, fieldID int64) (bool, error) {
func FieldExists(schema *schemapb.CollectionSchema, fieldID int64) bool {
	for _, f := range schema.Fields {
		if f.FieldID == fieldID {
			return typeutil.IsJSONType(f.DataType), nil
			return true
		}
	}
	return false, merr.WrapErrFieldNotFound(fieldID)
	for _, structField := range schema.StructArrayFields {
		for _, f := range structField.Fields {
			if f.FieldID == fieldID {
				return true
			}
		}
	}
	return false
}

func isJsonField(schema *schemapb.CollectionSchema, fieldID int64) bool {
	for _, f := range schema.Fields {
		if f.FieldID == fieldID {
			return typeutil.IsJSONType(f.DataType)
		}
	}
	return false
}
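
FieldExists now has to look in two places, since vector-array fields live inside StructArrayFields rather than the top-level field list. A hedged Python sketch of the same two-level lookup over plain dicts (the schema shape is illustrative, not the actual protobuf):

```python
# Two-level field lookup: top-level fields first, then struct-array sub-fields.
def field_exists(schema: dict, field_id: int) -> bool:
    if any(f["field_id"] == field_id for f in schema.get("fields", [])):
        return True
    return any(f["field_id"] == field_id
               for struct in schema.get("struct_array_fields", [])
               for f in struct["fields"])

schema = {"fields": [{"field_id": 100}],
          "struct_array_fields": [{"fields": [{"field_id": 105}]}]}
assert field_exists(schema, 105) and not field_exists(schema, 999)
```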

func getIndexParam(indexParams []*commonpb.KeyValuePair, key string) (string, error) {
@ -154,11 +170,12 @@ func (s *Server) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques
	if err != nil {
		return merr.Status(err), nil
	}
	isJson, err := isJsonField(schema, req.GetFieldID())
	if err != nil {
		return merr.Status(err), nil

	if !FieldExists(schema, req.GetFieldID()) {
		return merr.Status(merr.WrapErrFieldNotFound(req.GetFieldID())), nil
	}

	isJson := isJsonField(schema, req.GetFieldID())
	if isJson {
		// check json_path and json_cast_type exist
		jsonPath, err := getIndexParam(req.GetIndexParams(), common.JSONPathKey)

@ -250,7 +250,8 @@ func (it *indexBuildTask) prepareJobRequest(ctx context.Context, segment *Segmen
	schema := collectionInfo.Schema
	var field *schemapb.FieldSchema

	for _, f := range schema.Fields {
	allFields := typeutil.GetAllFieldSchemas(schema)
	for _, f := range allFields {
		if f.FieldID == fieldID {
			field = f
			break
@ -263,7 +264,11 @@ func (it *indexBuildTask) prepareJobRequest(ctx context.Context, segment *Segmen

	// Extract dim only for vector types to avoid unnecessary warnings
	dim := -1
	if typeutil.IsFixDimVectorType(field.GetDataType()) {
	dataType := field.GetDataType()
	if typeutil.IsVectorArrayType(dataType) {
		dataType = field.GetElementType()
	}
	if typeutil.IsFixDimVectorType(dataType) {
		if dimVal, err := storage.GetDimFromParams(field.GetTypeParams()); err != nil {
			log.Warn("failed to get dim from field type params",
				zap.String("field type", field.GetDataType().String()), zap.Error(err))

@ -180,6 +180,7 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
	// plan ok with schema, check ann field
	fieldID := vectorField.FieldID
	dataType := vectorField.DataType
	elementType := vectorField.ElementType

	var vectorType planpb.VectorType
	if !typeutil.IsVectorType(dataType) {
@ -198,6 +199,15 @@ func CreateSearchPlan(schema *typeutil.SchemaHelper, exprStr string, vectorField
		vectorType = planpb.VectorType_SparseFloatVector
	case schemapb.DataType_Int8Vector:
		vectorType = planpb.VectorType_Int8Vector
	case schemapb.DataType_ArrayOfVector:
		switch elementType {
		case schemapb.DataType_FloatVector:
			vectorType = planpb.VectorType_EmbListFloatVector
		default:
			log.Error("Invalid elementType", zap.Any("elementType", elementType))
			return nil, err
		}

	default:
		log.Error("Invalid dataType", zap.Any("dataType", dataType))
		return nil, err
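
The plan-building switch above reduces an ArrayOfVector field to a plan-level vector type via its element type. A small Python sketch of that mapping, with names mirroring (not importing) the planpb enums:

```python
# Map (data_type, element_type) to a plan vector type, mirroring the switch.
PLAN_TYPES = {
    "FloatVector": "FloatVector",
    "SparseFloatVector": "SparseFloatVector",
    "Int8Vector": "Int8Vector",
}

def plan_vector_type(data_type, element_type=None):
    if data_type == "ArrayOfVector":
        if element_type == "FloatVector":
            return "EmbListFloatVector"   # only float elements so far
        raise ValueError(f"invalid element type: {element_type}")
    return PLAN_TYPES[data_type]

assert plan_vector_type("ArrayOfVector", "FloatVector") == "EmbListFloatVector"
```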

@ -62,13 +62,18 @@ func GetDynamicPool() *conc.Pool[any] {
	return dp.Load()
}

func CheckVecIndexWithDataTypeExist(name string, dType schemapb.DataType) bool {
func CheckVecIndexWithDataTypeExist(name string, dataType schemapb.DataType, elementType schemapb.DataType) bool {
	isEmbeddingList := dataType == schemapb.DataType_ArrayOfVector
	if isEmbeddingList {
		dataType = elementType
	}

	var result bool
	GetDynamicPool().Submit(func() (any, error) {
		cIndexName := C.CString(name)
		cType := uint32(dType)
		cType := uint32(dataType)
		defer C.free(unsafe.Pointer(cIndexName))
		result = bool(C.CheckVecIndexWithDataType(cIndexName, cType))
		result = bool(C.CheckVecIndexWithDataType(cIndexName, cType, C.bool(isEmbeddingList)))
		return nil, nil
	}).Await()

@ -50,7 +50,7 @@ func Test_CheckVecIndexWithDataTypeExist(t *testing.T) {
	}

	for _, test := range cases {
		if got := CheckVecIndexWithDataTypeExist(test.indexType, test.dataType); got != test.want {
		if got := CheckVecIndexWithDataTypeExist(test.indexType, test.dataType, schemapb.DataType_None); got != test.want {
			t.Errorf("CheckVecIndexWithDataTypeExist(%v, %v) = %v", test.indexType, test.dataType, test.want)
		}
	}

File diff suppressed because it is too large
@ -549,24 +549,24 @@ func (op *lambdaOperator) run(ctx context.Context, span trace.Span, inputs ...an

type filterFieldOperator struct {
	outputFieldNames []string
	schema           *schemaInfo
	fieldSchemas     []*schemapb.FieldSchema
}

func newFilterFieldOperator(t *searchTask, _ map[string]any) (operator, error) {
	return &filterFieldOperator{
		outputFieldNames: t.translatedOutputFields,
		schema:           t.schema,
		fieldSchemas:     typeutil.GetAllFieldSchemas(t.schema.CollectionSchema),
	}, nil
}

func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
	result := inputs[0].(*milvuspb.SearchResults)
	for _, retField := range result.Results.FieldsData {
		for _, schemaField := range op.schema.Fields {
			if retField != nil && retField.FieldId == schemaField.FieldID {
				retField.FieldName = schemaField.Name
				retField.Type = schemaField.DataType
				retField.IsDynamic = schemaField.IsDynamic
		for _, fieldSchema := range op.fieldSchemas {
			if retField != nil && retField.FieldId == fieldSchema.FieldID {
				retField.FieldName = fieldSchema.Name
				retField.Type = fieldSchema.DataType
				retField.IsDynamic = fieldSchema.IsDynamic
			}
		}
	}

@ -560,6 +560,79 @@ func (s *SearchPipelineSuite) TestHybridSearchPipe() {
	s.Len(results.Results.Scores, 20) // 2 queries * 10 topk
}

func (s *SearchPipelineSuite) TestFilterFieldOperatorWithStructArrayFields() {
	schema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "int64", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "intField", DataType: schemapb.DataType_Int64},
			{FieldID: 102, Name: "floatField", DataType: schemapb.DataType_Float},
		},
		StructArrayFields: []*schemapb.StructArrayFieldSchema{
			{
				Name: "structArray",
				Fields: []*schemapb.FieldSchema{
					{FieldID: 104, Name: "structArrayField", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32},
					{FieldID: 105, Name: "structVectorField", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector},
				},
			},
		},
	}

	task := &searchTask{
		schema: &schemaInfo{
			CollectionSchema: schema,
		},
		translatedOutputFields: []string{"intField", "floatField", "structArrayField", "structVectorField"},
	}

	op, err := newFilterFieldOperator(task, nil)
	s.NoError(err)

	// Create mock search results with fields including struct array fields
	searchResults := &milvuspb.SearchResults{
		Results: &schemapb.SearchResultData{
			FieldsData: []*schemapb.FieldData{
				{FieldId: 101}, // intField
				{FieldId: 102}, // floatField
				{FieldId: 104}, // structArrayField
				{FieldId: 105}, // structVectorField
			},
		},
	}

	results, err := op.run(context.Background(), s.span, searchResults)
	s.NoError(err)
	s.NotNil(results)

	resultData := results[0].(*milvuspb.SearchResults)
	s.NotNil(resultData.Results.FieldsData)
	s.Len(resultData.Results.FieldsData, 4)

	// Verify all fields including struct array fields got their names and types set
	for _, field := range resultData.Results.FieldsData {
		switch field.FieldId {
		case 101:
			s.Equal("intField", field.FieldName)
			s.Equal(schemapb.DataType_Int64, field.Type)
			s.False(field.IsDynamic)
		case 102:
			s.Equal("floatField", field.FieldName)
			s.Equal(schemapb.DataType_Float, field.Type)
			s.False(field.IsDynamic)
		case 104:
			// Struct array field should be handled by GetAllFieldSchemas
			s.Equal("structArrayField", field.FieldName)
			s.Equal(schemapb.DataType_Array, field.Type)
			s.False(field.IsDynamic)
		case 105:
			// Struct array vector field should be handled by GetAllFieldSchemas
			s.Equal("structVectorField", field.FieldName)
			s.Equal(schemapb.DataType_ArrayOfVector, field.Type)
			s.False(field.IsDynamic)
		}
	}
}

func (s *SearchPipelineSuite) TestHybridSearchWithRequeryPipe() {
	task := getHybridSearchTask("test_collection", [][]string{
		{"1", "2"},

@ -281,6 +281,28 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
		return nil, fmt.Errorf("parse iterator v2 info failed: %w", err)
	}

	// 7. check search for embedding list
	annsFieldName, _ := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, searchParamsPair)
	if annsFieldName != "" {
		annField := typeutil.GetFieldByName(schema, annsFieldName)
		if annField != nil && annField.GetDataType() == schemapb.DataType_ArrayOfVector {
			if strings.Contains(searchParamStr, radiusKey) {
				return nil, merr.WrapErrParameterInvalid("", "",
					"range search is not supported for vector array (embedding list) fields, fieldName: %s", annsFieldName)
			}

			if groupByFieldId > 0 {
				return nil, merr.WrapErrParameterInvalid("", "",
					"group by search is not supported for vector array (embedding list) fields, fieldName: %s", annsFieldName)
			}

			if isIterator {
				return nil, merr.WrapErrParameterInvalid("", "",
					"search iterator is not supported for vector array (embedding list) fields, fieldName: %s", annsFieldName)
			}
		}
	}
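
These proxy-side guards keep unsupported search modes off embedding-list fields. A condensed Python sketch of the same gatekeeping, with a dict standing in for the parsed search params:

```python
# Reject search modes that embedding-list (ArrayOfVector) fields don't
# support yet: range search, group-by, and search iterators.
def check_emb_list_search(params: dict, is_emb_list_field: bool) -> None:
    if not is_emb_list_field:
        return
    if "radius" in params:
        raise ValueError("range search is not supported for embedding lists")
    if params.get("group_by_field"):
        raise ValueError("group by is not supported for embedding lists")
    if params.get("iterator"):
        raise ValueError("search iterator is not supported for embedding lists")

check_emb_list_search({"nprobe": 10}, True)          # fine
try:
    check_emb_list_search({"radius": 0.5}, True)
except ValueError as e:
    print(e)
```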

	return &SearchInfo{
		planInfo: &planpb.QueryInfo{
			Topk: queryTopK,

@ -68,6 +68,8 @@ const (
	RoundDecimalKey = "round_decimal"
	OffsetKey       = "offset"
	LimitKey        = "limit"
	// offsets for embedding list search
	LimsKey = "lims"

	SearchIterV2Key        = "search_iter_v2"
	SearchIterBatchSizeKey = "search_iter_batch_size"
@ -2047,7 +2049,8 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {

	loadFieldsSet := typeutil.NewSet(loadFields...)
	unindexedVecFields := make([]string, 0)
	for _, field := range collSchema.GetFields() {
	allFields := typeutil.GetAllFieldSchemas(collSchema.CollectionSchema)
	for _, field := range allFields {
		if typeutil.IsVectorType(field.GetDataType()) && loadFieldsSet.Contain(field.GetFieldID()) {
			if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok {
				unindexedVecFields = append(unindexedVecFields, field.GetName())
@ -2055,8 +2058,6 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) {
		}
	}

	// todo(SpadeA): check vector field in StructArrayField when index is implemented

	if len(unindexedVecFields) != 0 {
		errMsg := fmt.Sprintf("there is no vector index on field: %v, please create index firstly", unindexedVecFields)
		log.Debug(errMsg)
@ -2305,7 +2306,8 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {

	loadFieldsSet := typeutil.NewSet(loadFields...)
	unindexedVecFields := make([]string, 0)
	for _, field := range collSchema.GetFields() {
	allFields := typeutil.GetAllFieldSchemas(collSchema.CollectionSchema)
	for _, field := range allFields {
		if typeutil.IsVectorType(field.GetDataType()) && loadFieldsSet.Contain(field.GetFieldID()) {
			if _, ok := fieldIndexIDs[field.GetFieldID()]; !ok {
				unindexedVecFields = append(unindexedVecFields, field.GetName())
@ -2313,8 +2315,6 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error {
		}
	}

	// todo(SpadeA): check vector field in StructArrayField when index is implemented

	if len(unindexedVecFields) != 0 {
		errMsg := fmt.Sprintf("there is no vector index on field: %v, please create index firstly", unindexedVecFields)
		log.Ctx(ctx).Debug(errMsg)

@ -202,10 +202,12 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {

	specifyIndexType, exist := indexParamsMap[common.IndexTypeKey]
	if exist && specifyIndexType != "" {
		// todo(SpadeA): mmap check for struct array index
		if err := indexparamcheck.ValidateMmapIndexParams(specifyIndexType, indexParamsMap); err != nil {
			log.Ctx(ctx).Warn("Invalid mmap type params", zap.String(common.IndexTypeKey, specifyIndexType), zap.Error(err))
			return merr.WrapErrParameterInvalidMsg("invalid mmap type params: %s", err.Error())
		}
		// todo(SpadeA): check for struct array index
		checker, err := indexparamcheck.GetIndexCheckerMgrInstance().GetChecker(specifyIndexType)
		// not enable hybrid index for user, used in milvus internally
		if err != nil || indexparamcheck.IsHYBRIDChecker(checker) {
@ -327,16 +329,20 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
	}

	var config map[string]string
	if typeutil.IsDenseFloatVectorType(cit.fieldSchema.DataType) {
	if typeutil.IsDenseFloatVectorType(cit.fieldSchema.DataType) ||
		(typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsDenseFloatVectorType(cit.fieldSchema.ElementType)) {
		// override float vector index params by autoindex
		config = Params.AutoIndexConfig.IndexParams.GetAsJSONMap()
	} else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) {
	} else if typeutil.IsSparseFloatVectorType(cit.fieldSchema.DataType) ||
		(typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsSparseFloatVectorType(cit.fieldSchema.ElementType)) {
		// override sparse float vector index params by autoindex
		config = Params.AutoIndexConfig.SparseIndexParams.GetAsJSONMap()
	} else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) {
	} else if typeutil.IsBinaryVectorType(cit.fieldSchema.DataType) ||
		(typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsBinaryVectorType(cit.fieldSchema.ElementType)) {
		// override binary vector index params by autoindex
		config = Params.AutoIndexConfig.BinaryIndexParams.GetAsJSONMap()
	} else if typeutil.IsIntVectorType(cit.fieldSchema.DataType) {
	} else if typeutil.IsIntVectorType(cit.fieldSchema.DataType) ||
		(typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) && typeutil.IsIntVectorType(cit.fieldSchema.ElementType)) {
		// override int vector index params by autoindex
		config = Params.AutoIndexConfig.IndexParams.GetAsJSONMap()
	}
@ -397,6 +403,12 @@ func (cit *createIndexTask) parseIndexParams(ctx context.Context) error {
			if !funcutil.SliceContain(indexparamcheck.IntVectorMetrics, metricType) {
				return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "int vector index does not support metric type: "+metricType)
			}
		} else if typeutil.IsArrayOfVectorType(cit.fieldSchema.DataType) {
			// TODO(SpadeA): adjust it when more metric types are supported. Especially, when different metric types
			// are supported for different element types.
			if !funcutil.SliceContain(indexparamcheck.EmbListMetrics, metricType) {
				return merr.WrapErrParameterInvalid("valid index params", "invalid index params", "array of vector index does not support metric type: "+metricType)
			}
		}
	}
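
Per the commit message, MAX_SIM is the only metric wired up for embedding lists so far, so the check above amounts to a one-set membership test. A tiny Python sketch of that validation (the metric sets are assumptions mirroring indexparamcheck, not its actual tables):

```python
# Metric validation per field kind; EMB_LIST_METRICS reflects that only
# MAX_SIM is accepted for vector-array (embedding list) indexes so far.
FLOAT_METRICS = {"L2", "IP", "COSINE"}    # assumed plain float vector metrics
EMB_LIST_METRICS = {"MAX_SIM"}

def validate_metric(data_type: str, metric: str) -> None:
    allowed = EMB_LIST_METRICS if data_type == "ArrayOfVector" else FLOAT_METRICS
    if metric not in allowed:
        raise ValueError(f"{data_type} index does not support metric type: {metric}")

validate_metric("ArrayOfVector", "MAX_SIM")     # ok
try:
    validate_metric("ArrayOfVector", "L2")      # rejected, as in the test below
except ValueError as e:
    print(e)
```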

@ -500,7 +512,7 @@ func checkTrain(ctx context.Context, field *schemapb.FieldSchema, indexParams ma
	}

	if typeutil.IsVectorType(field.DataType) && indexType != indexparamcheck.AutoIndex {
		exist := CheckVecIndexWithDataTypeExist(indexType, field.DataType)
		exist := CheckVecIndexWithDataTypeExist(indexType, field.DataType, field.ElementType)
		if !exist {
			return fmt.Errorf("data type %s can't build with this index %s", schemapb.DataType_name[int32(field.GetDataType())], indexType)
		}
@ -519,7 +531,7 @@ func checkTrain(ctx context.Context, field *schemapb.FieldSchema, indexParams ma
		return err
	}

	if err := checker.CheckTrain(field.DataType, indexParams); err != nil {
	if err := checker.CheckTrain(field.DataType, field.ElementType, indexParams); err != nil {
		log.Ctx(ctx).Info("create index with invalid parameters", zap.Error(err))
		return err
	}

@ -21,6 +21,7 @@ import (
	"os"
	"sort"
	"strconv"
	"strings"
	"testing"

	"github.com/cockroachdb/errors"

@ -1164,6 +1165,131 @@ func Test_parseIndexParams(t *testing.T) {
	})
}

func Test_checkEmbeddingListIndex(t *testing.T) {
	t.Run("check embedding list index", func(t *testing.T) {
		cit := &createIndexTask{
			Condition: nil,
			req: &milvuspb.CreateIndexRequest{
				ExtraParams: []*commonpb.KeyValuePair{
					{
						Key:   common.IndexTypeKey,
						Value: "EMB_LIST_HNSW",
					},
					{
						Key:   common.MetricTypeKey,
						Value: metric.MaxSim,
					},
				},
				IndexName: "",
			},
			fieldSchema: &schemapb.FieldSchema{
				FieldID:      101,
				Name:         "EmbListFloat",
				IsPrimaryKey: false,
				DataType:     schemapb.DataType_ArrayOfVector,
				ElementType:  schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: common.DimKey, Value: "128"},
				},
			},
		}
		err := cit.parseIndexParams(context.TODO())
		assert.NoError(t, err)
	})

	t.Run("metrics wrong for embedding list index", func(t *testing.T) {
		cit := &createIndexTask{
			Condition: nil,
			req: &milvuspb.CreateIndexRequest{
				ExtraParams: []*commonpb.KeyValuePair{
					{
						Key:   common.IndexTypeKey,
						Value: "EMB_LIST_HNSW",
					},
					{
						Key:   common.MetricTypeKey,
						Value: metric.L2,
					},
				},
				IndexName: "",
			},
			fieldSchema: &schemapb.FieldSchema{
				FieldID:      101,
				Name:         "EmbListFloat",
				IsPrimaryKey: false,
				DataType:     schemapb.DataType_ArrayOfVector,
				ElementType:  schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: common.DimKey, Value: "128"},
				},
			},
		}
		err := cit.parseIndexParams(context.TODO())
		assert.True(t, strings.Contains(err.Error(), "array of vector index does not support metric type: L2"))
	})

	t.Run("metric type wrong", func(t *testing.T) {
		cit := &createIndexTask{
			Condition: nil,
			req: &milvuspb.CreateIndexRequest{
				ExtraParams: []*commonpb.KeyValuePair{
					{
						Key:   common.IndexTypeKey,
						Value: "HNSW",
					},
					{
						Key:   common.MetricTypeKey,
						Value: metric.MaxSim,
					},
				},
				IndexName: "",
			},
			fieldSchema: &schemapb.FieldSchema{
				FieldID:      101,
				Name:         "FieldFloat",
				IsPrimaryKey: false,
				DataType:     schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: common.DimKey, Value: "128"},
				},
			},
		}
		err := cit.parseIndexParams(context.TODO())
		assert.True(t, strings.Contains(err.Error(), "float vector index does not support metric type: MAX_SIM"))
	})

	t.Run("data type wrong", func(t *testing.T) {
		cit := &createIndexTask{
			Condition: nil,
			req: &milvuspb.CreateIndexRequest{
				ExtraParams: []*commonpb.KeyValuePair{
					{
						Key:   common.IndexTypeKey,
						Value: "EMB_LIST_HNSW",
					},
					{
						Key:   common.MetricTypeKey,
						Value: metric.L2,
					},
				},
				IndexName: "",
			},
			fieldSchema: &schemapb.FieldSchema{
				FieldID:      101,
				Name:         "FieldFloat",
				IsPrimaryKey: false,
				DataType:     schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: common.DimKey, Value: "128"},
				},
			},
		}

		err := cit.parseIndexParams(context.TODO())
		assert.True(t, strings.Contains(err.Error(), "data type FloatVector can't build with this index EMB_LIST_HNSW"))
	})
}
|
||||
|
||||
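The four cases above pin the metric/index pairing in both directions: EMB_LIST_HNSW only accepts MAX_SIM, and MAX_SIM is rejected for a plain float-vector index. A condensed sketch of that gate, assuming a simple string comparison; this is a hypothetical restatement of what the tests assert, not the actual proxy implementation behind parseIndexParams:

package main

import "fmt"

// validateEmbListMetric restates the pairing the tests above check.
func validateEmbListMetric(indexType, metricType string) error {
    isEmbList := indexType == "EMB_LIST_HNSW"
    isMaxSim := metricType == "MAX_SIM"
    switch {
    case isEmbList && !isMaxSim:
        return fmt.Errorf("array of vector index does not support metric type: %s", metricType)
    case !isEmbList && isMaxSim:
        return fmt.Errorf("float vector index does not support metric type: %s", metricType)
    }
    return nil
}

func main() {
    fmt.Println(validateEmbListMetric("EMB_LIST_HNSW", "L2"))      // rejected
    fmt.Println(validateEmbListMetric("HNSW", "MAX_SIM"))          // rejected
    fmt.Println(validateEmbListMetric("EMB_LIST_HNSW", "MAX_SIM")) // nil
}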
func Test_ngram_parseIndexParams(t *testing.T) {
    t.Run("valid ngram index params", func(t *testing.T) {
        cit := &createIndexTask{

@@ -217,11 +217,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
        return err
    }

    allFields := make([]*schemapb.FieldSchema, 0, len(it.schema.Fields)+5)
    allFields = append(allFields, it.schema.Fields...)
    for _, structField := range it.schema.GetStructArrayFields() {
        allFields = append(allFields, structField.GetFields()...)
    }
    allFields := typeutil.GetAllFieldSchemas(it.schema)

    // check primaryFieldData whether autoID is true or not
    // set rowIDs as primary data if autoID == true

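The removed block and its one-line replacement suggest the shape of typeutil.GetAllFieldSchemas: flatten the top-level fields plus every struct array's sub-fields into one slice. A sketch inferred from the deleted inline code; the real helper lives in the typeutil package and may differ in detail:

import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"

// GetAllFieldSchemas, as inferred from the inlined code this diff deletes.
func GetAllFieldSchemas(schema *schemapb.CollectionSchema) []*schemapb.FieldSchema {
    allFields := make([]*schemapb.FieldSchema, 0, len(schema.GetFields()))
    allFields = append(allFields, schema.GetFields()...)
    for _, structField := range schema.GetStructArrayFields() {
        allFields = append(allFields, structField.GetFields()...)
    }
    return allFields
}

Centralizing the flattening keeps insert, upsert, and search paths from drifting apart as more struct-array-aware checks are added.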
@@ -650,7 +650,9 @@ func (t *queryTask) PostExecute(ctx context.Context) error {
        return err
    }
    t.result.OutputFields = t.userOutputFields
    reconstructStructFieldData(t.result, t.schema.CollectionSchema)
    if !t.reQuery {
        reconstructStructFieldDataForQuery(t.result, t.schema.CollectionSchema)
    }

    primaryFieldSchema, err := t.schema.GetPkField()
    if err != nil {

@@ -1292,470 +1292,3 @@ func TestQueryTask_CanSkipAllocTimestamp(t *testing.T) {
        assert.True(t, skip)
    })
}

func Test_reconstructStructFieldData(t *testing.T) {
    t.Run("count(*) query - should return early", func(t *testing.T) {
        results := &milvuspb.QueryResults{
            OutputFields: []string{"count(*)"},
            FieldsData: []*schemapb.FieldData{
                {
                    FieldName: "count(*)",
                    FieldId: 0,
                    Type: schemapb.DataType_Int64,
                },
            },
        }

        schema := &schemapb.CollectionSchema{
            StructArrayFields: []*schemapb.StructArrayFieldSchema{
                {
                    FieldID: 102,
                    Name: "test_struct",
                    Fields: []*schemapb.FieldSchema{
                        {
                            FieldID: 1021,
                            Name: "sub_field",
                            DataType: schemapb.DataType_Array,
                            ElementType: schemapb.DataType_Int32,
                        },
                    },
                },
            },
        }

        originalFieldsData := make([]*schemapb.FieldData, len(results.FieldsData))
        copy(originalFieldsData, results.FieldsData)
        originalOutputFields := make([]string, len(results.OutputFields))
        copy(originalOutputFields, results.OutputFields)

        reconstructStructFieldData(results, schema)

        // Should not modify anything for count(*) query
        assert.Equal(t, originalFieldsData, results.FieldsData)
        assert.Equal(t, originalOutputFields, results.OutputFields)
    })

    t.Run("no struct array fields - should return early", func(t *testing.T) {
        results := &milvuspb.QueryResults{
            OutputFields: []string{"field1", "field2"},
            FieldsData: []*schemapb.FieldData{
                {
                    FieldName: "field1",
                    FieldId: 100,
                    Type: schemapb.DataType_Int64,
                },
                {
                    FieldName: "field2",
                    FieldId: 101,
                    Type: schemapb.DataType_VarChar,
                },
            },
        }

        schema := &schemapb.CollectionSchema{
            StructArrayFields: []*schemapb.StructArrayFieldSchema{},
        }

        originalFieldsData := make([]*schemapb.FieldData, len(results.FieldsData))
        copy(originalFieldsData, results.FieldsData)
        originalOutputFields := make([]string, len(results.OutputFields))
        copy(originalOutputFields, results.OutputFields)

        reconstructStructFieldData(results, schema)

        // Should not modify anything when no struct array fields
        assert.Equal(t, originalFieldsData, results.FieldsData)
        assert.Equal(t, originalOutputFields, results.OutputFields)
    })

    t.Run("reconstruct single struct field", func(t *testing.T) {
        // Create mock data
        subField1Data := &schemapb.FieldData{
            FieldName: "sub_int_array",
            FieldId: 1021,
            Type: schemapb.DataType_Array,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_ArrayData{
                        ArrayData: &schemapb.ArrayArray{
                            ElementType: schemapb.DataType_Int32,
                            Data: []*schemapb.ScalarField{
                                {
                                    Data: &schemapb.ScalarField_IntData{
                                        IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},
                                    },
                                },
                            },
                        },
                    },
                },
            },
        }

        subField2Data := &schemapb.FieldData{
            FieldName: "sub_text_array",
            FieldId: 1022,
            Type: schemapb.DataType_Array,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_ArrayData{
                        ArrayData: &schemapb.ArrayArray{
                            ElementType: schemapb.DataType_VarChar,
                            Data: []*schemapb.ScalarField{
                                {
                                    Data: &schemapb.ScalarField_StringData{
                                        StringData: &schemapb.StringArray{Data: []string{"hello", "world"}},
                                    },
                                },
                            },
                        },
                    },
                },
            },
        }

        results := &milvuspb.QueryResults{
            OutputFields: []string{"sub_int_array", "sub_text_array"},
            FieldsData: []*schemapb.FieldData{subField1Data, subField2Data},
        }

        schema := &schemapb.CollectionSchema{
            Fields: []*schemapb.FieldSchema{
                {
                    FieldID: 100,
                    Name: "pk",
                    IsPrimaryKey: true,
                    DataType: schemapb.DataType_Int64,
                },
            },
            StructArrayFields: []*schemapb.StructArrayFieldSchema{
                {
                    FieldID: 102,
                    Name: "test_struct",
                    Fields: []*schemapb.FieldSchema{
                        {
                            FieldID: 1021,
                            Name: "sub_int_array",
                            DataType: schemapb.DataType_Array,
                            ElementType: schemapb.DataType_Int32,
                        },
                        {
                            FieldID: 1022,
                            Name: "sub_text_array",
                            DataType: schemapb.DataType_Array,
                            ElementType: schemapb.DataType_VarChar,
                        },
                    },
                },
            },
        }

        reconstructStructFieldData(results, schema)

        // Check result
        assert.Len(t, results.FieldsData, 1, "Should only have one reconstructed struct field")
        assert.Len(t, results.OutputFields, 1, "Output fields should only have one")

        structField := results.FieldsData[0]
        assert.Equal(t, "test_struct", structField.FieldName)
        assert.Equal(t, int64(102), structField.FieldId)
        assert.Equal(t, schemapb.DataType_ArrayOfStruct, structField.Type)
        assert.Equal(t, "test_struct", results.OutputFields[0])

        // Check fields inside struct
        structArrays := structField.GetStructArrays()
        assert.NotNil(t, structArrays)
        assert.Len(t, structArrays.Fields, 2, "Struct should contain 2 sub fields")

        // Check sub fields
        var foundIntField, foundTextField bool
        for _, field := range structArrays.Fields {
            switch field.FieldId {
            case 1021:
                assert.Equal(t, "sub_int_array", field.FieldName)
                assert.Equal(t, schemapb.DataType_Array, field.Type)
                foundIntField = true
            case 1022:
                assert.Equal(t, "sub_text_array", field.FieldName)
                assert.Equal(t, schemapb.DataType_Array, field.Type)
                foundTextField = true
            }
        }
        assert.True(t, foundIntField, "Should find int array field")
        assert.True(t, foundTextField, "Should find text array field")
    })

    t.Run("mixed regular and struct fields", func(t *testing.T) {
        // Create regular field data
        regularField := &schemapb.FieldData{
            FieldName: "regular_field",
            FieldId: 100,
            Type: schemapb.DataType_Int64,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_LongData{
                        LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
                    },
                },
            },
        }

        // Create struct sub field data
        subFieldData := &schemapb.FieldData{
            FieldName: "sub_field",
            FieldId: 1021,
            Type: schemapb.DataType_Array,
            Field: &schemapb.FieldData_Scalars{
                Scalars: &schemapb.ScalarField{
                    Data: &schemapb.ScalarField_ArrayData{
                        ArrayData: &schemapb.ArrayArray{
                            ElementType: schemapb.DataType_Int32,
                            Data: []*schemapb.ScalarField{
                                {
                                    Data: &schemapb.ScalarField_IntData{
                                        IntData: &schemapb.IntArray{Data: []int32{10, 20}},
                                    },
                                },
                            },
                        },
                    },
                },
            },
        }

        results := &milvuspb.QueryResults{
            OutputFields: []string{"regular_field", "sub_field"},
            FieldsData: []*schemapb.FieldData{regularField, subFieldData},
        }

        schema := &schemapb.CollectionSchema{
            Fields: []*schemapb.FieldSchema{
                {
                    FieldID: 100,
                    Name: "regular_field",
                    DataType: schemapb.DataType_Int64,
                },
            },
            StructArrayFields: []*schemapb.StructArrayFieldSchema{
                {
                    FieldID: 102,
                    Name: "test_struct",
                    Fields: []*schemapb.FieldSchema{
                        {
                            FieldID: 1021,
                            Name: "sub_field",
                            DataType: schemapb.DataType_Array,
                            ElementType: schemapb.DataType_Int32,
                        },
                    },
                },
            },
        }

        reconstructStructFieldData(results, schema)

        // Check result: should have 2 fields (1 regular + 1 reconstructed struct)
        assert.Len(t, results.FieldsData, 2)
        assert.Len(t, results.OutputFields, 2)

        // Check regular and struct fields both exist
        var foundRegularField, foundStructField bool
        for i, field := range results.FieldsData {
            switch field.FieldId {
            case 100:
                assert.Equal(t, "regular_field", field.FieldName)
                assert.Equal(t, schemapb.DataType_Int64, field.Type)
                assert.Equal(t, "regular_field", results.OutputFields[i])
                foundRegularField = true
            case 102:
                assert.Equal(t, "test_struct", field.FieldName)
                assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
                assert.Equal(t, "test_struct", results.OutputFields[i])
                foundStructField = true
            }
        }
        assert.True(t, foundRegularField, "Should find regular field")
        assert.True(t, foundStructField, "Should find reconstructed struct field")
    })

    t.Run("multiple struct fields", func(t *testing.T) {
        // Create sub field for first struct
        struct1SubField := &schemapb.FieldData{
            FieldName: "struct1_sub",
            FieldId: 1021,
            Type: schemapb.DataType_Array,
        }

        // Create sub fields for second struct
        struct2SubField1 := &schemapb.FieldData{
            FieldName: "struct2_sub1",
            FieldId: 1031,
            Type: schemapb.DataType_Array,
        }

        struct2SubField2 := &schemapb.FieldData{
            FieldName: "struct2_sub2",
            FieldId: 1032,
            Type: schemapb.DataType_Array,
        }

        results := &milvuspb.QueryResults{
            OutputFields: []string{"struct1_sub", "struct2_sub1", "struct2_sub2"},
            FieldsData: []*schemapb.FieldData{struct1SubField, struct2SubField1, struct2SubField2},
        }

        schema := &schemapb.CollectionSchema{
            StructArrayFields: []*schemapb.StructArrayFieldSchema{
                {
                    FieldID: 102,
                    Name: "struct1",
                    Fields: []*schemapb.FieldSchema{
                        {
                            FieldID: 1021,
                            Name: "struct1_sub",
                            DataType: schemapb.DataType_Array,
                            ElementType: schemapb.DataType_Int32,
                        },
                    },
                },
                {
                    FieldID: 103,
                    Name: "struct2",
                    Fields: []*schemapb.FieldSchema{
                        {
                            FieldID: 1031,
                            Name: "struct2_sub1",
                            DataType: schemapb.DataType_Array,
                            ElementType: schemapb.DataType_VarChar,
                        },
                        {
                            FieldID: 1032,
                            Name: "struct2_sub2",
                            DataType: schemapb.DataType_Array,
                            ElementType: schemapb.DataType_Float,
                        },
                    },
                },
            },
        }

        reconstructStructFieldData(results, schema)

        // Check result: should have 2 reconstructed struct fields
        assert.Len(t, results.FieldsData, 2)
        assert.Len(t, results.OutputFields, 2)

        // Check both struct fields are reconstructed correctly
        foundStruct1, foundStruct2 := false, false
        for i, field := range results.FieldsData {
            switch field.FieldId {
            case 102:
                assert.Equal(t, "struct1", field.FieldName)
                assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
                assert.Equal(t, "struct1", results.OutputFields[i])

                structArrays := field.GetStructArrays()
                assert.NotNil(t, structArrays)
                assert.Len(t, structArrays.Fields, 1)
                assert.Equal(t, int64(1021), structArrays.Fields[0].FieldId)
                foundStruct1 = true

            case 103:
                assert.Equal(t, "struct2", field.FieldName)
                assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
                assert.Equal(t, "struct2", results.OutputFields[i])

                structArrays := field.GetStructArrays()
                assert.NotNil(t, structArrays)
                assert.Len(t, structArrays.Fields, 2)

                // Check struct2 contains two sub fields
                subFieldIds := make([]int64, 0, 2)
                for _, subField := range structArrays.Fields {
                    subFieldIds = append(subFieldIds, subField.FieldId)
                }
                assert.Contains(t, subFieldIds, int64(1031))
                assert.Contains(t, subFieldIds, int64(1032))
                foundStruct2 = true
            }
        }
        assert.True(t, foundStruct1, "Should find struct1")
        assert.True(t, foundStruct2, "Should find struct2")
    })

    t.Run("empty fields data", func(t *testing.T) {
        results := &milvuspb.QueryResults{
            OutputFields: []string{},
            FieldsData: []*schemapb.FieldData{},
        }

        schema := &schemapb.CollectionSchema{
            StructArrayFields: []*schemapb.StructArrayFieldSchema{
                {
                    FieldID: 102,
                    Name: "test_struct",
                    Fields: []*schemapb.FieldSchema{
                        {
                            FieldID: 1021,
                            Name: "sub_field",
                            DataType: schemapb.DataType_Array,
                            ElementType: schemapb.DataType_Int32,
                        },
                    },
                },
            },
        }

        reconstructStructFieldData(results, schema)

        // Empty data should remain unchanged
        assert.Len(t, results.FieldsData, 0)
        assert.Len(t, results.OutputFields, 0)
    })

    t.Run("no matching sub fields", func(t *testing.T) {
        // Field data does not match any struct definition
        regularField := &schemapb.FieldData{
            FieldName: "regular_field",
            FieldId: 200, // Not in any struct
            Type: schemapb.DataType_Int64,
        }

        results := &milvuspb.QueryResults{
            OutputFields: []string{"regular_field"},
            FieldsData: []*schemapb.FieldData{regularField},
        }

        schema := &schemapb.CollectionSchema{
            Fields: []*schemapb.FieldSchema{
                {
                    FieldID: 200,
                    Name: "regular_field",
                    DataType: schemapb.DataType_Int64,
                },
            },
            StructArrayFields: []*schemapb.StructArrayFieldSchema{
                {
                    FieldID: 102,
                    Name: "test_struct",
                    Fields: []*schemapb.FieldSchema{
                        {
                            FieldID: 1021,
                            Name: "sub_field",
                            DataType: schemapb.DataType_Array,
                            ElementType: schemapb.DataType_Int32,
                        },
                    },
                },
            },
        }

        reconstructStructFieldData(results, schema)

        // Should only keep the regular field, no struct field
        assert.Len(t, results.FieldsData, 1)
        assert.Len(t, results.OutputFields, 1)
        assert.Equal(t, int64(200), results.FieldsData[0].FieldId)
        assert.Equal(t, "regular_field", results.OutputFields[0])
    })
}

@@ -544,7 +544,8 @@ func (t *searchTask) initSearchRequest(ctx context.Context) error {
        }
    }

    vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
    allFields := typeutil.GetAllFieldSchemas(t.schema.CollectionSchema)
    vectorOutputFields := lo.Filter(allFields, func(field *schemapb.FieldSchema, _ int) bool {
        return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
    })
    t.needRequery = len(vectorOutputFields) > 0
@@ -765,6 +766,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
    }
    t.fillResult()
    t.result.Results.OutputFields = t.userOutputFields
    reconstructStructFieldDataForSearch(t.result, t.schema.CollectionSchema)
    t.result.CollectionName = t.request.GetCollectionName()

    primaryFieldSchema, _ := t.schema.GetPkField()

@@ -3548,6 +3548,278 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
            assert.ErrorContains(t, err, "failed to parse input last bound")
        })
    })

    t.Run("check vector array unsupported features", func(t *testing.T) {
        // Helper function to create a schema with vector array field
        createSchemaWithVectorArray := func(annsFieldName string) *schemapb.CollectionSchema {
            return &schemapb.CollectionSchema{
                Fields: []*schemapb.FieldSchema{
                    {
                        FieldID: 100,
                        Name: "id",
                        DataType: schemapb.DataType_Int64,
                        IsPrimaryKey: true,
                    },
                    {
                        FieldID: 101,
                        Name: "normal_vector",
                        DataType: schemapb.DataType_FloatVector,
                        TypeParams: []*commonpb.KeyValuePair{
                            {Key: common.DimKey, Value: "128"},
                        },
                    },
                    {
                        FieldID: 102,
                        Name: "group_field",
                        DataType: schemapb.DataType_VarChar,
                    },
                },
                StructArrayFields: []*schemapb.StructArrayFieldSchema{
                    {
                        FieldID: 103,
                        Name: "struct_array_field",
                        Fields: []*schemapb.FieldSchema{
                            {
                                FieldID: 104,
                                Name: annsFieldName,
                                DataType: schemapb.DataType_ArrayOfVector,
                                ElementType: schemapb.DataType_FloatVector,
                                TypeParams: []*commonpb.KeyValuePair{
                                    {Key: common.DimKey, Value: "128"},
                                },
                            },
                        },
                    },
                },
            }
        }

        // Helper function to create search params with anns field
        createSearchParams := func(annsFieldName string) []*commonpb.KeyValuePair {
            return []*commonpb.KeyValuePair{
                {
                    Key: AnnsFieldKey,
                    Value: annsFieldName,
                },
                {
                    Key: TopKKey,
                    Value: "10",
                },
                {
                    Key: common.MetricTypeKey,
                    Value: metric.MaxSim,
                },
                {
                    Key: ParamsKey,
                    Value: `{"nprobe": 10}`,
                },
                {
                    Key: RoundDecimalKey,
                    Value: "-1",
                },
                {
                    Key: IgnoreGrowingKey,
                    Value: "false",
                },
            }
        }

        t.Run("vector array with range search", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("embeddings_list")

            // Add radius parameter for range search
            resetSearchParamsValue(params, ParamsKey, `{"nprobe": 10, "radius": 0.2}`)

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.Error(t, err)
            assert.Nil(t, searchInfo)
            assert.ErrorIs(t, err, merr.ErrParameterInvalid)
            fmt.Println(err.Error())
            assert.Contains(t, err.Error(), "range search is not supported for vector array (embedding list) fields")
        })

        t.Run("vector array with group by", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("embeddings_list")

            // Add group by parameter
            params = append(params, &commonpb.KeyValuePair{
                Key: GroupByFieldKey,
                Value: "group_field",
            })

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.Error(t, err)
            assert.Nil(t, searchInfo)
            assert.ErrorIs(t, err, merr.ErrParameterInvalid)
            assert.Contains(t, err.Error(), "group by search is not supported for vector array (embedding list) fields")
            assert.Contains(t, err.Error(), "embeddings_list")
        })

        t.Run("vector array with iterator", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("embeddings_list")

            // Add iterator parameter
            params = append(params, &commonpb.KeyValuePair{
                Key: IteratorField,
                Value: "True",
            })

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.Error(t, err)
            assert.Nil(t, searchInfo)
            assert.ErrorIs(t, err, merr.ErrParameterInvalid)
            assert.Contains(t, err.Error(), "search iterator is not supported for vector array (embedding list) fields")
            assert.Contains(t, err.Error(), "embeddings_list")
        })

        t.Run("vector array with iterator v2", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("embeddings_list")

            // Add iterator v2 parameters
            params = append(params,
                &commonpb.KeyValuePair{
                    Key: SearchIterV2Key,
                    Value: "True",
                },
                &commonpb.KeyValuePair{
                    Key: IteratorField,
                    Value: "True",
                },
                &commonpb.KeyValuePair{
                    Key: SearchIterBatchSizeKey,
                    Value: "10",
                },
            )

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.Error(t, err)
            assert.Nil(t, searchInfo)
            assert.ErrorIs(t, err, merr.ErrParameterInvalid)
            assert.Contains(t, err.Error(), "search iterator is not supported for vector array (embedding list) fields")
            assert.Contains(t, err.Error(), "embeddings_list")
        })

        t.Run("normal search on vector array should succeed", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("embeddings_list")

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.NoError(t, err)
            assert.NotNil(t, searchInfo)
            assert.NotNil(t, searchInfo.planInfo)
        })

        t.Run("normal vector field with range search should succeed", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("normal_vector")

            // Add radius parameter for range search
            resetSearchParamsValue(params, ParamsKey, `{"nprobe": 10, "radius": 0.2}`)

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.NoError(t, err)
            assert.NotNil(t, searchInfo)
            assert.NotNil(t, searchInfo.planInfo)
        })

        t.Run("normal vector field with group by should succeed", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("normal_vector")

            // Add group by parameter
            params = append(params, &commonpb.KeyValuePair{
                Key: GroupByFieldKey,
                Value: "group_field",
            })

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.NoError(t, err)
            assert.NotNil(t, searchInfo)
            assert.NotNil(t, searchInfo.planInfo)
            assert.Equal(t, int64(102), searchInfo.planInfo.GroupByFieldId)
        })

        t.Run("normal vector field with iterator should succeed", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("normal_vector")

            // Add iterator parameter
            params = append(params, &commonpb.KeyValuePair{
                Key: IteratorField,
                Value: "True",
            })

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.NoError(t, err)
            assert.NotNil(t, searchInfo)
            assert.NotNil(t, searchInfo.planInfo)
        })

        t.Run("vector array with range search", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("embeddings_list")
            resetSearchParamsValue(params, ParamsKey, `{"nprobe": 10, "radius": 0.2}`)

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.Error(t, err)
            assert.Nil(t, searchInfo)
            // Should fail on range search first
            assert.Contains(t, err.Error(), "range search is not supported for vector array (embedding list) fields")
        })

        t.Run("no anns field specified", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := getValidSearchParams()
            // Don't specify anns field

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.NoError(t, err)
            assert.NotNil(t, searchInfo)
            // Should not trigger vector array validation without anns field
        })

        t.Run("non-existent anns field", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")
            params := createSearchParams("non_existent_field")

            searchInfo, err := parseSearchInfo(params, schema, nil)
            assert.NoError(t, err)
            assert.NotNil(t, searchInfo)
            // Should not trigger vector array validation for non-existent field
        })

        t.Run("hybrid search with outer group by on vector array", func(t *testing.T) {
            schema := createSchemaWithVectorArray("embeddings_list")

            // Create rank params with group by
            rankParams := getValidSearchParams()
            rankParams = append(rankParams,
                &commonpb.KeyValuePair{
                    Key: GroupByFieldKey,
                    Value: "group_field",
                },
                &commonpb.KeyValuePair{
                    Key: LimitKey,
                    Value: "100",
                },
            )

            parsedRankParams, err := parseRankParams(rankParams, schema)
            assert.NoError(t, err)

            searchParams := createSearchParams("embeddings_list")
            // Parse search info with rank params
            searchInfo, err := parseSearchInfo(searchParams, schema, parsedRankParams)
            assert.Error(t, err)
            assert.Nil(t, searchInfo)
            assert.ErrorIs(t, err, merr.ErrParameterInvalid)
            assert.Contains(t, err.Error(), "group by search is not supported for vector array (embedding list) fields")
        })
    })
}

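Together, these subtests fix three restrictions for embedding-list ANNS fields: no range search, no group-by, and no iterators, while identical parameters stay legal on regular vector fields. A hypothetical condensation of the guard they exercise; the real checks are inlined in parseSearchInfo and wrap errors via merr as asserted above:

import "fmt"

// checkEmbListSearchParams is an illustrative stand-in for the validation the
// tests exercise; parameter plumbing in the real parseSearchInfo differs.
func checkEmbListSearchParams(annsFieldIsEmbList bool, fieldName string, hasRadius, hasGroupBy, useIterator bool) error {
    if !annsFieldIsEmbList {
        return nil // regular vector fields keep all features
    }
    switch {
    case hasRadius:
        return fmt.Errorf("range search is not supported for vector array (embedding list) fields: %s", fieldName)
    case hasGroupBy:
        return fmt.Errorf("group by search is not supported for vector array (embedding list) fields: %s", fieldName)
    case useIterator:
        return fmt.Errorf("search iterator is not supported for vector array (embedding list) fields: %s", fieldName)
    }
    return nil
}

Note that the gate only fires once the ANNS field resolves to an ArrayOfVector field in the schema, which is why the "no anns field" and "non-existent anns field" cases pass through untouched.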
func getSearchResultData(nq, topk int64) *schemapb.SearchResultData {
@@ -4301,3 +4573,110 @@ func genTestSearchResultData(nq int64, topk int64, dType schemapb.DataType, fiel
    }
    return result
}

func TestSearchTask_InitSearchRequestWithStructArrayFields(t *testing.T) {
    ctx := context.Background()

    schema := &schemapb.CollectionSchema{
        Name: "test_collection",
        Fields: []*schemapb.FieldSchema{
            {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
            {FieldID: 101, Name: "regular_vec", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}}},
            {FieldID: 102, Name: "regular_scalar", DataType: schemapb.DataType_Int32},
        },
        StructArrayFields: []*schemapb.StructArrayFieldSchema{
            {
                Name: "structArray",
                Fields: []*schemapb.FieldSchema{
                    {FieldID: 104, Name: "struct_vec_array", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "64"}}},
                    {FieldID: 105, Name: "struct_scalar_array", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32},
                },
            },
        },
    }

    schemaInfo := newSchemaInfo(schema)

    tests := []struct {
        name string
        outputFields []string
        expectedRequery bool
        description string
    }{
        {
            name: "regular_vector_field",
            outputFields: []string{"pk", "regular_vec"},
            expectedRequery: true,
            description: "Should require requery when regular vector field in output",
        },
        {
            name: "struct_array_vector_field",
            outputFields: []string{"pk", "struct_vec_array"},
            expectedRequery: true,
            description: "Should require requery when struct array vector field in output (tests GetAllFieldSchemas)",
        },
        {
            name: "both_vector_fields",
            outputFields: []string{"pk", "regular_vec", "struct_vec_array"},
            expectedRequery: true,
            description: "Should require requery when both regular and struct array vector fields in output",
        },
        {
            name: "struct_scalar_array_only",
            outputFields: []string{"pk", "struct_scalar_array"},
            expectedRequery: false,
            description: "Should not require requery when only struct scalar array field in output",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            task := &searchTask{
                ctx: ctx,
                collectionName: "test_collection",
                SearchRequest: &internalpb.SearchRequest{
                    CollectionID: 1,
                    PartitionIDs: []int64{1},
                    Dsl: "",
                    PlaceholderGroup: nil,
                    DslType: commonpb.DslType_BoolExprV1,
                    OutputFieldsId: []int64{},
                },
                request: &milvuspb.SearchRequest{
                    CollectionName: "test_collection",
                    OutputFields: tt.outputFields,
                    Dsl: "",
                    SearchParams: []*commonpb.KeyValuePair{
                        {Key: AnnsFieldKey, Value: "regular_vec"},
                        {Key: TopKKey, Value: "10"},
                        {Key: common.MetricTypeKey, Value: metric.L2},
                        {Key: ParamsKey, Value: `{"nprobe": 10}`},
                    },
                    PlaceholderGroup: nil,
                    ConsistencyLevel: commonpb.ConsistencyLevel_Session,
                },
                schema: schemaInfo,
                translatedOutputFields: tt.outputFields,
                tr: timerecord.NewTimeRecorder("test"),
                queryInfos: []*planpb.QueryInfo{{}},
            }

            // Set translated output field IDs based on the schema
            outputFieldIDs := []int64{}
            allFields := typeutil.GetAllFieldSchemas(schema)
            for _, fieldName := range tt.outputFields {
                for _, field := range allFields {
                    if field.Name == fieldName {
                        outputFieldIDs = append(outputFieldIDs, field.FieldID)
                        break
                    }
                }
            }
            task.SearchRequest.OutputFieldsId = outputFieldIDs

            err := task.initSearchRequest(ctx)
            assert.NoError(t, err)
            assert.Equal(t, tt.expectedRequery, task.needRequery, tt.description)
        })
    }
}

@@ -585,11 +585,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
        return err
    }

    allFields := make([]*schemapb.FieldSchema, 0, len(it.schema.Fields)+5)
    allFields = append(allFields, it.schema.Fields...)
    for _, structField := range it.schema.GetStructArrayFields() {
        allFields = append(allFields, structField.GetFields()...)
    }
    allFields := typeutil.GetAllFieldSchemas(it.schema.CollectionSchema)

    // use the passed pk as new pk when autoID == false
    // automatically generate pk as new pk when autoID == true

@@ -619,9 +619,6 @@ func ValidateFieldsInStruct(field *schemapb.FieldSchema, schema *schemapb.Collec
    if field.GetNullable() {
        return fmt.Errorf("nullable is not supported for fields in struct array now, fieldName = %s", field.Name)
    }

    // todo(SpadeA): add more check when index is enabled

    return nil
}

@@ -2547,3 +2544,88 @@ func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 {

    return 0
}

// reconstructStructFieldDataCommon reconstructs struct fields from flattened sub-fields
// It works with both QueryResults and SearchResults by operating on the common data structures
func reconstructStructFieldDataCommon(
    fieldsData []*schemapb.FieldData,
    outputFields []string,
    schema *schemapb.CollectionSchema,
) ([]*schemapb.FieldData, []string) {
    if len(outputFields) == 1 && outputFields[0] == "count(*)" {
        return fieldsData, outputFields
    }

    if len(schema.StructArrayFields) == 0 {
        return fieldsData, outputFields
    }

    regularFieldIDs := make(map[int64]interface{})
    subFieldToStructMap := make(map[int64]int64)
    groupedStructFields := make(map[int64][]*schemapb.FieldData)
    structFieldNames := make(map[int64]string)
    reconstructedOutputFields := make([]string, 0, len(fieldsData))

    // record all regular field IDs
    for _, field := range schema.Fields {
        regularFieldIDs[field.GetFieldID()] = nil
    }

    // build the mapping from sub-field ID to struct field ID
    for _, structField := range schema.StructArrayFields {
        for _, subField := range structField.GetFields() {
            subFieldToStructMap[subField.GetFieldID()] = structField.GetFieldID()
        }
        structFieldNames[structField.GetFieldID()] = structField.GetName()
    }

    newFieldsData := make([]*schemapb.FieldData, 0, len(fieldsData))
    for _, field := range fieldsData {
        fieldID := field.GetFieldId()
        if _, ok := regularFieldIDs[fieldID]; ok {
            newFieldsData = append(newFieldsData, field)
            reconstructedOutputFields = append(reconstructedOutputFields, field.GetFieldName())
        } else {
            structFieldID := subFieldToStructMap[fieldID]
            groupedStructFields[structFieldID] = append(groupedStructFields[structFieldID], field)
        }
    }

    for structFieldID, fields := range groupedStructFields {
        fieldData := &schemapb.FieldData{
            FieldName: structFieldNames[structFieldID],
            FieldId: structFieldID,
            Type: schemapb.DataType_ArrayOfStruct,
            Field: &schemapb.FieldData_StructArrays{StructArrays: &schemapb.StructArrayField{Fields: fields}},
        }
        newFieldsData = append(newFieldsData, fieldData)
        reconstructedOutputFields = append(reconstructedOutputFields, structFieldNames[structFieldID])
    }

    return newFieldsData, reconstructedOutputFields
}

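In effect, sub-field columns are folded back into one ArrayOfStruct column per struct, while regular columns pass through with their names and relative order intact. A small usage sketch, assuming schema is a *schemapb.CollectionSchema whose struct array field 102 ("test_struct") declares sub-field 1021; the field IDs here are made up for illustration:

flat := []*schemapb.FieldData{
    {FieldName: "pk", FieldId: 100, Type: schemapb.DataType_Int64},
    {FieldName: "sub_field", FieldId: 1021, Type: schemapb.DataType_Array},
}
fields, outputs := reconstructStructFieldDataCommon(flat, []string{"pk", "sub_field"}, schema)
// fields now holds "pk" plus one ArrayOfStruct entry wrapping "sub_field";
// outputs mirrors that as ["pk", "test_struct"].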
// Wrapper for QueryResults
func reconstructStructFieldDataForQuery(results *milvuspb.QueryResults, schema *schemapb.CollectionSchema) {
    fieldsData, outputFields := reconstructStructFieldDataCommon(
        results.FieldsData,
        results.OutputFields,
        schema,
    )
    results.FieldsData = fieldsData
    results.OutputFields = outputFields
}

// New wrapper for SearchResults
func reconstructStructFieldDataForSearch(results *milvuspb.SearchResults, schema *schemapb.CollectionSchema) {
    if results.Results == nil {
        return
    }
    fieldsData, outputFields := reconstructStructFieldDataCommon(
        results.Results.FieldsData,
        results.Results.OutputFields,
        schema,
    )
    results.Results.FieldsData = fieldsData
    results.Results.OutputFields = outputFields
}

@@ -3825,5 +3825,644 @@ func TestCheckAndFlattenStructFieldData(t *testing.T) {
}

func TestValidateFieldsInStruct(t *testing.T) {
    // todo(SpadeA): add test cases
    schema := &schemapb.CollectionSchema{
        Name: "test_collection",
    }

    t.Run("valid array field", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "valid_array",
            DataType: schemapb.DataType_Array,
            ElementType: schemapb.DataType_Int32,
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.NoError(t, err)
    })

    t.Run("valid array of vector field", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "valid_array_vector",
            DataType: schemapb.DataType_ArrayOfVector,
            ElementType: schemapb.DataType_FloatVector,
            TypeParams: []*commonpb.KeyValuePair{
                {Key: common.DimKey, Value: "128"},
            },
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.NoError(t, err)
    })

    t.Run("invalid field name", func(t *testing.T) {
        testCases := []struct {
            name string
            expected string
        }{
            {"", "field name should not be empty"},
            {"123abc", "The first character of a field name must be an underscore or letter"},
            {"abc-def", "Field name can only contain numbers, letters, and underscores"},
        }

        for _, tc := range testCases {
            field := &schemapb.FieldSchema{
                Name: tc.name,
                DataType: schemapb.DataType_Array,
                ElementType: schemapb.DataType_Int32,
            }
            err := ValidateFieldsInStruct(field, schema)
            assert.Error(t, err)
            assert.Contains(t, err.Error(), tc.expected)
        }
    })

    t.Run("invalid data type", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "invalid_type",
            DataType: schemapb.DataType_Int32, // Not array or array of vector
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "Fields in StructArrayField can only be array or array of struct")
    })

    t.Run("nested array not supported", func(t *testing.T) {
        testCases := []struct {
            elementType schemapb.DataType
        }{
            {schemapb.DataType_ArrayOfStruct},
            {schemapb.DataType_ArrayOfVector},
            {schemapb.DataType_Array},
        }

        for _, tc := range testCases {
            field := &schemapb.FieldSchema{
                Name: "nested_array",
                DataType: schemapb.DataType_Array,
                ElementType: tc.elementType,
            }
            err := ValidateFieldsInStruct(field, schema)
            assert.Error(t, err)
            assert.Contains(t, err.Error(), "Nested array is not supported")
        }
    })

    t.Run("array field with vector element type", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "array_with_vector",
            DataType: schemapb.DataType_Array,
            ElementType: schemapb.DataType_FloatVector,
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "element type of array field array_with_vector is a vector type")
    })

    t.Run("array of vector field with non-vector element type", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "array_vector_with_scalar",
            DataType: schemapb.DataType_ArrayOfVector,
            ElementType: schemapb.DataType_Int32,
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "element type of array field array_vector_with_scalar is not a vector type")
    })

    t.Run("array of vector missing dimension", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "array_vector_no_dim",
            DataType: schemapb.DataType_ArrayOfVector,
            ElementType: schemapb.DataType_FloatVector,
            TypeParams: []*commonpb.KeyValuePair{}, // No dimension specified
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "dimension is not defined in field")
    })

    t.Run("array of vector with invalid dimension", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "array_vector_invalid_dim",
            DataType: schemapb.DataType_ArrayOfVector,
            ElementType: schemapb.DataType_FloatVector,
            TypeParams: []*commonpb.KeyValuePair{
                {Key: common.DimKey, Value: "not_a_number"},
            },
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.Error(t, err)
    })

    t.Run("varchar array without max_length", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "varchar_array_no_max_length",
            DataType: schemapb.DataType_Array,
            ElementType: schemapb.DataType_VarChar,
            TypeParams: []*commonpb.KeyValuePair{}, // No max_length specified
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "type param(max_length) should be specified")
    })

    t.Run("varchar array with valid max_length", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "varchar_array_valid",
            DataType: schemapb.DataType_Array,
            ElementType: schemapb.DataType_VarChar,
            TypeParams: []*commonpb.KeyValuePair{
                {Key: common.MaxLengthKey, Value: "100"},
            },
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.NoError(t, err)
    })

    t.Run("varchar array with invalid max_length", func(t *testing.T) {
        // Test with max_length exceeding limit
        field := &schemapb.FieldSchema{
            Name: "varchar_array_invalid_length",
            DataType: schemapb.DataType_Array,
            ElementType: schemapb.DataType_VarChar,
            TypeParams: []*commonpb.KeyValuePair{
                {Key: common.MaxLengthKey, Value: "99999999"}, // Exceeds limit
            },
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "the maximum length specified for the field")
    })

    t.Run("nullable field not supported", func(t *testing.T) {
        field := &schemapb.FieldSchema{
            Name: "nullable_field",
            DataType: schemapb.DataType_Array,
            ElementType: schemapb.DataType_Int32,
            Nullable: true,
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "nullable is not supported for fields in struct array now")
    })

    t.Run("sparse float vector in array of vector", func(t *testing.T) {
        // Note: ArrayOfVector with sparse vector element type still requires dimension
        // because validateDimension checks the field's DataType (ArrayOfVector), not ElementType
        field := &schemapb.FieldSchema{
            Name: "sparse_vector_array",
            DataType: schemapb.DataType_ArrayOfVector,
            ElementType: schemapb.DataType_SparseFloatVector,
            TypeParams: []*commonpb.KeyValuePair{},
        }
        err := ValidateFieldsInStruct(field, schema)
        assert.Error(t, err)
        assert.Contains(t, err.Error(), "dimension is not defined")
    })

    t.Run("array with various scalar element types", func(t *testing.T) {
        validScalarTypes := []schemapb.DataType{
            schemapb.DataType_Bool,
            schemapb.DataType_Int8,
            schemapb.DataType_Int16,
            schemapb.DataType_Int32,
            schemapb.DataType_Int64,
            schemapb.DataType_Float,
            schemapb.DataType_Double,
            schemapb.DataType_String,
        }

        for _, dt := range validScalarTypes {
            field := &schemapb.FieldSchema{
                Name: "array_" + dt.String(),
                DataType: schemapb.DataType_Array,
                ElementType: dt,
            }
            err := ValidateFieldsInStruct(field, schema)
            assert.NoError(t, err)
        }
    })

    t.Run("array of vector with various vector types", func(t *testing.T) {
        validVectorTypes := []schemapb.DataType{
            schemapb.DataType_FloatVector,
            schemapb.DataType_BinaryVector,
            schemapb.DataType_Float16Vector,
            schemapb.DataType_BFloat16Vector,
            // Note: SparseFloatVector is excluded because validateDimension checks
            // the field's DataType (ArrayOfVector), not ElementType, so it still requires dimension
        }

        for _, vt := range validVectorTypes {
            field := &schemapb.FieldSchema{
                Name: "vector_array_" + vt.String(),
                DataType: schemapb.DataType_ArrayOfVector,
                ElementType: vt,
                TypeParams: []*commonpb.KeyValuePair{
                    {Key: common.DimKey, Value: "128"},
                },
            }
            err := ValidateFieldsInStruct(field, schema)
            assert.NoError(t, err)
        }
    })
}

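Read together, these cases pin down the contract for fields inside a struct array: only Array and ArrayOfVector are accepted, element types must match the container kind, nested arrays and nullable fields are rejected, and the dim / max_length type params are enforced. A condensed restatement of those rules as a sketch, not the actual ValidateFieldsInStruct body, with error strings taken from the assertions above:

func validateStructSubFieldSketch(f *schemapb.FieldSchema) error {
    if f.GetNullable() {
        return fmt.Errorf("nullable is not supported for fields in struct array now, fieldName = %s", f.GetName())
    }
    switch f.GetDataType() {
    case schemapb.DataType_Array:
        switch f.GetElementType() {
        case schemapb.DataType_Array, schemapb.DataType_ArrayOfStruct, schemapb.DataType_ArrayOfVector:
            return fmt.Errorf("Nested array is not supported")
        }
        if typeutil.IsVectorType(f.GetElementType()) {
            return fmt.Errorf("element type of array field %s is a vector type", f.GetName())
        }
        // VarChar elements additionally need a valid max_length type param.
    case schemapb.DataType_ArrayOfVector:
        if !typeutil.IsVectorType(f.GetElementType()) {
            return fmt.Errorf("element type of array field %s is not a vector type", f.GetName())
        }
        // dim must be present and numeric; per the tests, validateDimension keys
        // off the ArrayOfVector data type, so even sparse elements require it.
    default:
        return fmt.Errorf("Fields in StructArrayField can only be array or array of struct")
    }
    return nil
}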
func Test_reconstructStructFieldDataCommon(t *testing.T) {
|
||||
t.Run("count(*) query - should return early", func(t *testing.T) {
|
||||
fieldsData := []*schemapb.FieldData{
|
||||
{
|
||||
FieldName: "count(*)",
|
||||
FieldId: 0,
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
}
|
||||
outputFields := []string{"count(*)"}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
StructArrayFields: []*schemapb.StructArrayFieldSchema{
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "test_struct",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 1021,
|
||||
Name: "sub_field",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
originalFieldsData := make([]*schemapb.FieldData, len(fieldsData))
|
||||
copy(originalFieldsData, fieldsData)
|
||||
originalOutputFields := make([]string, len(outputFields))
|
||||
copy(originalOutputFields, outputFields)
|
||||
|
||||
resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
|
||||
|
||||
// Should not modify anything for count(*) query
|
||||
assert.Equal(t, originalFieldsData, resultFieldsData)
|
||||
assert.Equal(t, originalOutputFields, resultOutputFields)
|
||||
})
|
||||
|
||||
t.Run("no struct array fields - should return early", func(t *testing.T) {
|
||||
fieldsData := []*schemapb.FieldData{
|
||||
{
|
||||
FieldName: "field1",
|
||||
FieldId: 100,
|
||||
Type: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldName: "field2",
|
||||
FieldId: 101,
|
||||
Type: schemapb.DataType_VarChar,
|
||||
},
|
||||
}
|
||||
outputFields := []string{"field1", "field2"}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
StructArrayFields: []*schemapb.StructArrayFieldSchema{},
|
||||
}
|
||||
|
||||
originalFieldsData := make([]*schemapb.FieldData, len(fieldsData))
|
||||
copy(originalFieldsData, fieldsData)
|
||||
originalOutputFields := make([]string, len(outputFields))
|
||||
copy(originalOutputFields, outputFields)
|
||||
|
||||
resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
|
||||
|
||||
// Should not modify anything when no struct array fields
|
||||
assert.Equal(t, originalFieldsData, resultFieldsData)
|
||||
assert.Equal(t, originalOutputFields, resultOutputFields)
|
||||
})
|
||||
|
||||
t.Run("reconstruct single struct field", func(t *testing.T) {
|
||||
// Create mock data
|
||||
subField1Data := &schemapb.FieldData{
|
||||
FieldName: "sub_int_array",
|
||||
FieldId: 1021,
|
||||
Type: schemapb.DataType_Array,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
Data: []*schemapb.ScalarField{
|
||||
{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{Data: []int32{1, 2, 3}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
subField2Data := &schemapb.FieldData{
|
||||
FieldName: "sub_text_array",
|
||||
FieldId: 1022,
|
||||
Type: schemapb.DataType_Array,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
ElementType: schemapb.DataType_VarChar,
|
||||
Data: []*schemapb.ScalarField{
|
||||
{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{Data: []string{"hello", "world"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldsData := []*schemapb.FieldData{subField1Data, subField2Data}
|
||||
outputFields := []string{"sub_int_array", "sub_text_array"}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "pk",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
StructArrayFields: []*schemapb.StructArrayFieldSchema{
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "test_struct",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 1021,
|
||||
Name: "sub_int_array",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
},
|
||||
{
|
||||
FieldID: 1022,
|
||||
Name: "sub_text_array",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_VarChar,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
|
||||
|
||||
// Check result
|
||||
assert.Len(t, resultFieldsData, 1, "Should only have one reconstructed struct field")
|
||||
assert.Len(t, resultOutputFields, 1, "Output fields should only have one")
|
||||
|
||||
structField := resultFieldsData[0]
|
||||
assert.Equal(t, "test_struct", structField.FieldName)
|
||||
assert.Equal(t, int64(102), structField.FieldId)
|
||||
assert.Equal(t, schemapb.DataType_ArrayOfStruct, structField.Type)
|
||||
assert.Equal(t, "test_struct", resultOutputFields[0])
|
||||
|
||||
// Check fields inside struct
|
||||
structArrays := structField.GetStructArrays()
|
||||
assert.NotNil(t, structArrays)
|
||||
assert.Len(t, structArrays.Fields, 2, "Struct should contain 2 sub fields")
|
||||
|
||||
// Check sub fields
|
||||
var foundIntField, foundTextField bool
|
||||
for _, field := range structArrays.Fields {
|
||||
switch field.FieldId {
|
||||
case 1021:
|
||||
assert.Equal(t, "sub_int_array", field.FieldName)
|
||||
assert.Equal(t, schemapb.DataType_Array, field.Type)
|
||||
foundIntField = true
|
||||
case 1022:
|
||||
assert.Equal(t, "sub_text_array", field.FieldName)
|
||||
assert.Equal(t, schemapb.DataType_Array, field.Type)
|
||||
foundTextField = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundIntField, "Should find int array field")
|
||||
assert.True(t, foundTextField, "Should find text array field")
|
||||
})
|
||||
|
||||
t.Run("mixed regular and struct fields", func(t *testing.T) {
|
||||
// Create regular field data
|
||||
regularField := &schemapb.FieldData{
|
||||
FieldName: "regular_field",
|
||||
FieldId: 100,
|
||||
Type: schemapb.DataType_Int64,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create struct sub field data
|
||||
subFieldData := &schemapb.FieldData{
|
||||
FieldName: "sub_field",
|
||||
FieldId: 1021,
|
||||
Type: schemapb.DataType_Array,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
Data: []*schemapb.ScalarField{
|
||||
{
|
||||
Data: &schemapb.ScalarField_IntData{
|
||||
IntData: &schemapb.IntArray{Data: []int32{10, 20}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldsData := []*schemapb.FieldData{regularField, subFieldData}
|
||||
outputFields := []string{"regular_field", "sub_field"}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "regular_field",
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
StructArrayFields: []*schemapb.StructArrayFieldSchema{
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "test_struct",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 1021,
|
||||
Name: "sub_field",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
|
||||
|
||||
// Check result: should have 2 fields (1 regular + 1 reconstructed struct)
|
||||
assert.Len(t, resultFieldsData, 2)
|
||||
assert.Len(t, resultOutputFields, 2)
|
||||
|
||||
// Check regular and struct fields both exist
|
||||
var foundRegularField, foundStructField bool
|
||||
for i, field := range resultFieldsData {
|
||||
switch field.FieldId {
|
||||
case 100:
|
||||
assert.Equal(t, "regular_field", field.FieldName)
|
||||
assert.Equal(t, schemapb.DataType_Int64, field.Type)
|
||||
assert.Equal(t, "regular_field", resultOutputFields[i])
|
||||
foundRegularField = true
|
||||
case 102:
|
||||
assert.Equal(t, "test_struct", field.FieldName)
|
||||
assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
|
||||
assert.Equal(t, "test_struct", resultOutputFields[i])
|
||||
foundStructField = true
|
||||
}
|
||||
}
|
||||
assert.True(t, foundRegularField, "Should find regular field")
|
||||
assert.True(t, foundStructField, "Should find reconstructed struct field")
|
||||
})
|
||||
|
||||
t.Run("multiple struct fields", func(t *testing.T) {
|
||||
// Create sub field for first struct
|
||||
struct1SubField := &schemapb.FieldData{
|
||||
FieldName: "struct1_sub",
|
||||
FieldId: 1021,
|
||||
Type: schemapb.DataType_Array,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
Data: []*schemapb.ScalarField{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create sub fields for second struct
|
||||
struct2SubField1 := &schemapb.FieldData{
|
||||
FieldName: "struct2_sub1",
|
||||
FieldId: 1031,
|
||||
Type: schemapb.DataType_Array,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_ArrayData{
|
||||
ArrayData: &schemapb.ArrayArray{
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
Data: []*schemapb.ScalarField{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
struct2SubField2 := &schemapb.FieldData{
|
||||
FieldName: "struct2_sub2",
|
||||
FieldId: 1032,
|
||||
Type: schemapb.DataType_VarChar,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_StringData{
|
||||
StringData: &schemapb.StringArray{Data: []string{"test"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fieldsData := []*schemapb.FieldData{struct1SubField, struct2SubField1, struct2SubField2}
|
||||
outputFields := []string{"struct1_sub", "struct2_sub1", "struct2_sub2"}
|
||||
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 100,
|
||||
Name: "pk",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
},
|
||||
StructArrayFields: []*schemapb.StructArrayFieldSchema{
|
||||
{
|
||||
FieldID: 102,
|
||||
Name: "struct1",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 1021,
|
||||
Name: "struct1_sub",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
FieldID: 103,
|
||||
Name: "struct2",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 1031,
|
||||
Name: "struct2_sub1",
|
||||
DataType: schemapb.DataType_Array,
|
||||
ElementType: schemapb.DataType_Int32,
|
||||
},
|
||||
{
|
||||
FieldID: 1032,
|
||||
Name: "struct2_sub2",
|
||||
DataType: schemapb.DataType_VarChar,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resultFieldsData, resultOutputFields := reconstructStructFieldDataCommon(fieldsData, outputFields, schema)
|
||||
|
||||
// Check result: should have 2 struct fields
|
||||
assert.Len(t, resultFieldsData, 2)
|
||||
assert.Len(t, resultOutputFields, 2)
|
||||
|
||||
// Check both struct fields
|
||||
var foundStruct1, foundStruct2 bool
|
||||
for _, field := range resultFieldsData {
|
||||
switch field.FieldId {
|
||||
case 102:
|
||||
assert.Equal(t, "struct1", field.FieldName)
|
||||
assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
|
||||
foundStruct1 = true
|
||||
structArrays := field.GetStructArrays()
|
||||
assert.NotNil(t, structArrays)
|
||||
assert.Len(t, structArrays.Fields, 1)
|
||||
case 103:
|
||||
assert.Equal(t, "struct2", field.FieldName)
|
||||
assert.Equal(t, schemapb.DataType_ArrayOfStruct, field.Type)
|
||||
foundStruct2 = true
|
||||
structArrays := field.GetStructArrays()
|
||||
assert.NotNil(t, structArrays)
|
||||
assert.Len(t, structArrays.Fields, 2)
|
||||
}
|
||||
}
|
||||
assert.True(t, foundStruct1, "Should find struct1")
|
||||
assert.True(t, foundStruct2, "Should find struct2")
|
||||
})
|
||||
}
|
||||
|
||||
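For readers skimming the tests above: `reconstructStructFieldDataCommon` takes the flat list of sub-field results produced by the query path and regroups them under their parent `ArrayOfStruct` fields, using the collection schema to decide which sub-field belongs to which struct. A minimal sketch of that regrouping, inferred from the assertions in the tests; the helper name `regroupStructFields` and the `FieldData_StructArrays` oneof wrapper are assumptions here, not code from this PR:

```go
// Sketch: regroup flat sub-field data under parent struct fields.
// Hypothetical helper for illustration; the real function also tracks
// output-field names alongside the field data.
func regroupStructFields(fieldsData []*schemapb.FieldData, schema *schemapb.CollectionSchema) []*schemapb.FieldData {
	// Map sub-field ID -> parent struct schema.
	parent := make(map[int64]*schemapb.StructArrayFieldSchema)
	for _, s := range schema.GetStructArrayFields() {
		for _, f := range s.GetFields() {
			parent[f.GetFieldID()] = s
		}
	}

	result := make([]*schemapb.FieldData, 0, len(fieldsData))
	structs := make(map[int64]*schemapb.FieldData) // parent field ID -> reconstructed field
	for _, fd := range fieldsData {
		s, ok := parent[fd.GetFieldId()]
		if !ok {
			// Regular field: pass through unchanged.
			result = append(result, fd)
			continue
		}
		sf, ok := structs[s.GetFieldID()]
		if !ok {
			sf = &schemapb.FieldData{
				FieldId:   s.GetFieldID(),
				FieldName: s.GetName(),
				Type:      schemapb.DataType_ArrayOfStruct,
				Field:     &schemapb.FieldData_StructArrays{StructArrays: &schemapb.StructArrayField{}},
			}
			structs[s.GetFieldID()] = sf
			result = append(result, sf)
		}
		// Attach the sub-field under its parent struct.
		sf.GetStructArrays().Fields = append(sf.GetStructArrays().Fields, fd)
	}
	return result
}
```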
@ -33,7 +33,6 @@ import (

	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/metrics"
	"github.com/milvus-io/milvus/pkg/v2/proto/cgopb"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -115,25 +114,6 @@ func (li *LoadIndexInfo) appendIndexFile(ctx context.Context, filePath string) e
	return HandleCStatus(ctx, &status, "AppendIndexIFile failed")
}

// appendFieldInfo appends fieldID & fieldType to index
func (li *LoadIndexInfo) appendFieldInfo(ctx context.Context, collectionID int64, partitionID int64, segmentID int64, fieldID int64, fieldType schemapb.DataType, enableMmap bool, mmapDirPath string) error {
	var status C.CStatus
	GetDynamicPool().Submit(func() (any, error) {
		cColID := C.int64_t(collectionID)
		cParID := C.int64_t(partitionID)
		cSegID := C.int64_t(segmentID)
		cFieldID := C.int64_t(fieldID)
		cintDType := uint32(fieldType)
		cEnableMmap := C.bool(enableMmap)
		cMmapDirPath := C.CString(mmapDirPath)
		defer C.free(unsafe.Pointer(cMmapDirPath))
		status = C.AppendFieldInfo(li.cLoadIndexInfo, cColID, cParID, cSegID, cFieldID, cintDType, cEnableMmap, cMmapDirPath)
		return nil, nil
	}).Await()

	return HandleCStatus(ctx, &status, "AppendFieldInfo failed")
}

func (li *LoadIndexInfo) appendStorageInfo(uri string, version int64) {
	GetDynamicPool().Submit(func() (any, error) {
		cURI := C.CString(uri)

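The removed `appendFieldInfo` above (the hunk header shows 25 lines shrinking to 6) followed the cgo convention used throughout this file: marshal Go values into C types inside a pooled closure, free any C strings before returning, and surface the `C.CStatus` through `HandleCStatus`. A condensed sketch of that pattern, where `C.CDoSomething` and `callIntoSegcore` are placeholders rather than real exports:

```go
// Sketch of the pooled cgo call pattern used by LoadIndexInfo methods.
func (li *LoadIndexInfo) callIntoSegcore(ctx context.Context, path string) error {
	var status C.CStatus
	GetDynamicPool().Submit(func() (any, error) {
		cPath := C.CString(path)
		defer C.free(unsafe.Pointer(cPath)) // C strings must be freed manually
		status = C.CDoSomething(li.cLoadIndexInfo, cPath)
		return nil, nil
	}).Await() // block until the pooled task finishes so status is populated
	return HandleCStatus(ctx, &status, "CDoSomething failed")
}
```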
@ -476,9 +476,16 @@ func checkFieldSchema(fieldSchemas []*schemapb.FieldSchema) error {

func checkStructArrayFieldSchema(schemas []*schemapb.StructArrayFieldSchema) error {
	for _, schema := range schemas {
		// todo(SpadeA): check struct array field schema
		if len(schema.GetFields()) == 0 {
			return merr.WrapErrParameterInvalidMsg("empty fields in StructArrayField is not allowed")
		}

		for _, field := range schema.GetFields() {
			if field.GetDataType() != schemapb.DataType_Array && field.GetDataType() != schemapb.DataType_ArrayOfVector {
				msg := fmt.Sprintf("Fields in StructArrayField can only be array or array of vector, but field %s is %s", field.Name, field.DataType.String())
				return merr.WrapErrParameterInvalidMsg(msg)
			}

			if field.IsPartitionKey || field.IsPrimaryKey {
				msg := fmt.Sprintf("partition key or primary key can not be in struct array field. data type:%s, element type:%s, name:%s",
					field.DataType.String(), field.ElementType.String(), field.Name)

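Concretely, the validation above accepts only `Array` and `ArrayOfVector` sub-fields and rejects key fields inside a struct. The sketch below shows a schema that would pass and one that would fail; the field names and IDs are arbitrary examples, not taken from this PR:

```go
// Passes checkStructArrayFieldSchema: both sub-fields are array-typed.
valid := []*schemapb.StructArrayFieldSchema{{
	FieldID: 102,
	Name:    "doc_chunks",
	Fields: []*schemapb.FieldSchema{
		{FieldID: 1021, Name: "chunk_tags", DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int32},
		{FieldID: 1022, Name: "chunk_vecs", DataType: schemapb.DataType_ArrayOfVector, ElementType: schemapb.DataType_FloatVector},
	},
}}

// Fails: a plain VarChar sub-field is neither Array nor ArrayOfVector.
invalid := []*schemapb.StructArrayFieldSchema{{
	FieldID: 103,
	Name:    "bad_struct",
	Fields: []*schemapb.FieldSchema{
		{FieldID: 1031, Name: "title", DataType: schemapb.DataType_VarChar},
	},
}}
_, _ = valid, invalid // fragment: silence unused-variable errors
```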
@ -19,6 +19,7 @@ package storage
import (
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

// DataSorter sorts insert data
@ -52,11 +53,7 @@ func (ds *DataSorter) Len() int {
// Swap swaps each field's i-th and j-th element
func (ds *DataSorter) Swap(i, j int) {
	if ds.AllFields == nil {
		allFields := ds.InsertCodec.Schema.Schema.Fields
		for _, field := range ds.InsertCodec.Schema.Schema.StructArrayFields {
			allFields = append(allFields, field.Fields...)
		}
		ds.AllFields = allFields
		ds.AllFields = typeutil.GetAllFieldSchemas(ds.InsertCodec.Schema.Schema)
	}
	for _, field := range ds.AllFields {
		singleData, has := ds.InsertData.Data[field.FieldID]

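The five removed lines are exactly what the new `typeutil.GetAllFieldSchemas` helper centralizes: top-level fields plus every struct array's sub-fields. A plausible shape for the helper, inferred from the removed inline code rather than copied from the PR:

```go
// Inferred sketch of typeutil.GetAllFieldSchemas: flatten top-level
// fields and struct-array sub-fields into a single slice.
func GetAllFieldSchemas(schema *schemapb.CollectionSchema) []*schemapb.FieldSchema {
	all := make([]*schemapb.FieldSchema, 0, len(schema.GetFields()))
	all = append(all, schema.GetFields()...)
	for _, structField := range schema.GetStructArrayFields() {
		all = append(all, structField.GetFields()...)
	}
	return all
}
```

Using one helper keeps `DataSorter.Swap`, `fillMissingFields`, and the deserializers below in agreement about what "all fields" means once struct sub-fields exist.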
@ -282,9 +282,9 @@ func ValueDeserializerWithSchema(r Record, v []*Value, schema *schemapb.Collecti
	return valueDeserializer(r, v, allFields, shouldCopy)
}

func valueDeserializer(r Record, v []*Value, fieldSchema []*schemapb.FieldSchema, shouldCopy bool) error {
func valueDeserializer(r Record, v []*Value, fields []*schemapb.FieldSchema, shouldCopy bool) error {
	pkField := func() *schemapb.FieldSchema {
		for _, field := range fieldSchema {
		for _, field := range fields {
			if field.GetIsPrimaryKey() {
				return field
			}

@ -299,12 +299,12 @@ func valueDeserializer(r Record, v []*Value, fieldSchema []*schemapb.FieldSchema
	value := v[i]
	if value == nil {
		value = &Value{}
		value.Value = make(map[FieldID]interface{}, len(fieldSchema))
		value.Value = make(map[FieldID]interface{}, len(fields))
		v[i] = value
	}

	m := value.Value.(map[FieldID]interface{})
	for _, f := range fieldSchema {
	for _, f := range fields {
		j := f.FieldID
		dt := f.DataType
		if r.Column(j).IsNull(i) {

@ -1532,7 +1532,9 @@ func GetDefaultValue(fieldSchema *schemapb.FieldSchema) interface{} {
func fillMissingFields(schema *schemapb.CollectionSchema, insertData *InsertData) error {
	batchRows := int64(insertData.GetRowNum())

	for _, field := range schema.Fields {
	allFields := typeutil.GetAllFieldSchemas(schema)

	for _, field := range allFields {
		// Skip function output fields and system fields
		if field.GetIsFunctionOutput() || field.GetFieldID() < 100 {
			continue

@ -9,7 +9,7 @@ type AUTOINDEXChecker struct {
	baseChecker
}

func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
func (c *AUTOINDEXChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error {
	return nil
}

@ -24,7 +24,7 @@ import (

type baseChecker struct{}

func (c baseChecker) CheckTrain(dataType schemapb.DataType, params map[string]string) error {
func (c baseChecker) CheckTrain(_ schemapb.DataType, _ schemapb.DataType, _ map[string]string) error {
	return nil
}

@ -35,7 +35,7 @@ func (c baseChecker) CheckValidDataType(indexType IndexType, field *schemapb.Fie

func (c baseChecker) SetDefaultMetricTypeIfNotExist(dType schemapb.DataType, m map[string]string) {}

func (c baseChecker) StaticCheck(dataType schemapb.DataType, params map[string]string) error {
func (c baseChecker) StaticCheck(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error {
	return errors.New("unsupported index type")
}

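Every checker in the interface now receives the element type alongside the field's data type, which is what lets an embedding-list index validate the vectors stored inside an `ArrayOfVector` field. A hypothetical concrete checker showing how the extra parameter would typically be used; this is not the PR's actual EMB_LIST_HNSW checker, only an illustration of the signature:

```go
// Hypothetical checker illustrating the new elementType parameter.
type embListChecker struct {
	baseChecker
}

func (c embListChecker) CheckTrain(dataType schemapb.DataType, elementType schemapb.DataType, params map[string]string) error {
	if dataType != schemapb.DataType_ArrayOfVector {
		return errors.New("embedding list index requires an ArrayOfVector field")
	}
	// Per this PR, only float-vector elements are supported for now.
	if elementType != schemapb.DataType_FloatVector {
		return errors.New("embedding list index only supports FloatVector elements")
	}
	return nil
}
```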
@ -47,9 +47,9 @@ func Test_baseChecker_CheckTrain(t *testing.T) {
	test.params[common.IndexTypeKey] = "HNSW"
	var err error
	if test.params[common.IsSparseKey] == "True" {
		err = c.CheckTrain(schemapb.DataType_SparseFloatVector, test.params)
		err = c.CheckTrain(schemapb.DataType_SparseFloatVector, schemapb.DataType_None, test.params)
	} else {
		err = c.CheckTrain(schemapb.DataType_FloatVector, test.params)
		err = c.CheckTrain(schemapb.DataType_FloatVector, schemapb.DataType_None, test.params)
	}
	if test.errIsNil {
		assert.NoError(t, err)

@ -132,5 +132,5 @@ func Test_baseChecker_CheckValidDataType(t *testing.T) {

func Test_baseChecker_StaticCheck(t *testing.T) {
	// TODO
	assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, nil))
	assert.Error(t, newBaseChecker().StaticCheck(schemapb.DataType_FloatVector, schemapb.DataType_None, nil))
}

@ -68,7 +68,7 @@ func Test_binFlatChecker_CheckTrain(t *testing.T) {
	c, _ := GetIndexCheckerMgrInstance().GetChecker("BINFLAT")
	for _, test := range cases {
		test.params[common.IndexTypeKey] = "BINFLAT"
		err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
		err := c.CheckTrain(schemapb.DataType_BinaryVector, schemapb.DataType_None, test.params)
		if test.errIsNil {
			assert.NoError(t, err)
		} else {

@ -119,7 +119,7 @@ func Test_binIVFFlatChecker_CheckTrain(t *testing.T) {
	c, _ := GetIndexCheckerMgrInstance().GetChecker("BIN_IVF_FLAT")
	for _, test := range cases {
		test.params[common.IndexTypeKey] = "BIN_IVF_FLAT"
		err := c.CheckTrain(schemapb.DataType_BinaryVector, test.params)
		err := c.CheckTrain(schemapb.DataType_BinaryVector, schemapb.DataType_None, test.params)
		if test.errIsNil {
			assert.NoError(t, err)
		} else {

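The scalar-vector checkers above simply pass `DataType_None` for the element type; for the new vector-array path the element type carries real information. A sketch of how a caller might exercise it through the checker manager, assuming an EMB_LIST_HNSW checker registers under that name (this excerpt does not show the registration):

```go
// Sketch: validating train params for an embedding-list index.
c, err := GetIndexCheckerMgrInstance().GetChecker("EMB_LIST_HNSW")
if err == nil {
	params := map[string]string{
		common.IndexTypeKey:  "EMB_LIST_HNSW",
		common.MetricTypeKey: "MAX_SIM", // the only metric type supported in this PR
	}
	// The element type distinguishes the vectors inside the array field.
	checkErr := c.CheckTrain(schemapb.DataType_ArrayOfVector, schemapb.DataType_FloatVector, params)
	_ = checkErr
}
```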